├── Model ├── __init__.py └── module.py ├── figure ├── GiGCN.png └── ensemble.png ├── Monitor.py ├── dataInfo.ini ├── integrate.py ├── trainTestSplit.py ├── predict.py ├── utils.py ├── requirements.txt ├── Trainer.py ├── train.py └── README.md /Model/__init__.py: -------------------------------------------------------------------------------- 1 | from . import module -------------------------------------------------------------------------------- /figure/GiGCN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuGuoJ/GiGCN/HEAD/figure/GiGCN.png -------------------------------------------------------------------------------- /figure/ensemble.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuGuoJ/GiGCN/HEAD/figure/ensemble.png -------------------------------------------------------------------------------- /Monitor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Monitoring the change of gradient 3 | """ 4 | import torch 5 | import numpy as np 6 | 7 | 8 | class GradMonitor(object): 9 | def __init__(self): 10 | self.grads = [] 11 | 12 | def clear(self): 13 | self.grads.clear() 14 | return self 15 | 16 | def add(self, parameters_list: list, ord=1): 17 | grad = [] 18 | for parameters in parameters_list: 19 | grad_norm = [] 20 | for p in parameters: 21 | if p.requires_grad and p.grad is not None: 22 | norm = p.grad.norm(ord) 23 | grad_norm.append(norm) 24 | else: 25 | continue 26 | grad_norm = torch.tensor(grad_norm, dtype=torch.float) 27 | grad.append(grad_norm.norm(ord).item()) 28 | self.grads.append(grad) 29 | 30 | def get(self): 31 | if len(self.grads) == 0: 32 | return None 33 | else: 34 | # return float(np.mean(self.grads)) 35 | return np.mean(self.grads, axis=0).tolist() 36 | 37 | 38 | -------------------------------------------------------------------------------- 
/dataInfo.ini: -------------------------------------------------------------------------------- 1 | [PaviaU] 2 | data_key = paviaU 3 | gt_key = paviaU_gt 4 | band_begin = 430 5 | band_end = 860 6 | band = 103 7 | h = 610 8 | w = 340 9 | nc = 9 10 | 11 | [Salinas] 12 | data_key = salinas_corrected 13 | gt_key = salinas_gt 14 | band_begin = 200 15 | band_end = 2400 16 | band = 204 17 | h = 512 18 | w = 217 19 | nc = 16 20 | 21 | [KSC] 22 | data_key = KSC 23 | gt_key = KSC_gt 24 | band_begin = 400 25 | band_end = 2500 26 | band = 176 27 | h = 512 28 | w = 614 29 | nc = 13 30 | 31 | [gf5] 32 | data_key = gf5 33 | gt_key = gf5_gt 34 | band_begin = 400 35 | band_end = 2500 36 | band = 280 37 | h = 1400 38 | w = 1400 39 | nc = 20 40 | 41 | [Xiongan] 42 | data_key = xiongan 43 | gt_key = xiongan_gt 44 | band_begin = 400 45 | band_end = 1000 46 | band = 256 47 | h = 1580 48 | w = 3750 49 | nc = 20 50 | 51 | [Indian_pines] 52 | data_key = indian_pines_corrected 53 | gt_key = indian_pines_gt 54 | band_begin = 400 55 | band_end = 2500 56 | band = 200 57 | h = 145 58 | w = 145 59 | nc = 16 60 | 61 | [Yancheng] 62 | data_key = yancheng 63 | gt_key = yancheng_gt 64 | band_begin = 400 65 | band_end = 2500 66 | band = 266 67 | h = 585 68 | w = 1175 69 | nc = 20 70 | 71 | [Houston2018] 72 | data_key = Houston 73 | gt_key = Houston2018_gt 74 | band = 48 75 | h = 601 76 | w = 2384 77 | nc = 20 78 | 79 | -------------------------------------------------------------------------------- /integrate.py: -------------------------------------------------------------------------------- 1 | '''Ensemble learning''' 2 | from scipy.io import loadmat, savemat 3 | import argparse 4 | import numpy as np 5 | import os 6 | from tqdm import tqdm 7 | from configparser import ConfigParser 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser(description='INTEGRATED LEARNING') 11 | parser.add_argument('--name', type=str, default='PaviaU', 12 | help='DATASET NAME') 13 | 
parser.add_argument('--run', type=int, default=10, 14 | help='RUN TIMES') 15 | parser.add_argument('--spc', type=int, default=10, 16 | help='SAMPLE EACH CLASS') 17 | arg = parser.parse_args() 18 | config = ConfigParser() 19 | config.read('dataInfo.ini') 20 | block = [50, 100, 150, 200] 21 | save_root = 'prediction/{}/{}/integratedLearning_l1_clip'.format(arg.name, arg.spc) 22 | if not os.path.exists(save_root): 23 | os.makedirs(save_root) 24 | for r in tqdm(range(arg.run)): 25 | res = [] 26 | for b in block: 27 | path = 'prediction/{}/{}/{}_overall_skip_2_SGConv_l1_clip/{}.mat'.format(arg.name, arg.spc, b, r) 28 | res.append(loadmat(path)['pred']) 29 | # fr = np.where(res[0] == res[1], res[0], np.full(res[0].shape, -1)) 30 | # fr = np.where(res[1] == res[2], res[1], fr) 31 | # fr = np.where(res[0] == res[2], res[0], fr) 32 | # fr = np.where(fr == -1, res[1], fr) 33 | # res -> h x w x len(block) 34 | res = np.stack(res, axis=-1) 35 | h, w = res.shape[:2] 36 | nc = config.getint(arg.name, 'nc') 37 | res = res.reshape((h * w, -1)) 38 | fr = [np.bincount(x, minlength=nc) for x in res] 39 | fr = np.stack(fr, axis=0) 40 | fr = fr.reshape((h, w, -1)) 41 | fr = np.argmax(fr, axis=-1) 42 | savemat(os.path.join(save_root, '{}.mat'.format(r)), {'pred': fr}) 43 | print('*'*5 + 'FINISH' + '*'*5) 44 | -------------------------------------------------------------------------------- /trainTestSplit.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.io import loadmat 3 | import os 4 | import random 5 | from scipy import io 6 | # get train test split 7 | keys = {'PaviaU':'paviaU_gt', 8 | 'Salinas':'salinas_gt', 9 | 'KSC':'KSC_gt', 10 | 'Houston':'Houston2018_gt', 11 | 'gf5': 'gf5_gt', 12 | 'Xiongan': 'xiongan_gt'} 13 | TRAIN_SIZE = [10] 14 | RUN = 10 15 | 16 | 17 | def sample_gt(gt, train_size, mode='fixed_withone'): 18 | indices = np.nonzero(gt) 19 | X = list(zip(*indices)) # x,y features 20 | y = 
gt[indices].ravel() # classes 21 | train_gt = np.zeros_like(gt) 22 | test_gt = np.zeros_like(gt) 23 | if train_size > 1: 24 | train_size = int(train_size) 25 | if mode == 'random': 26 | train_size = float(train_size) / 100 # dengbin:20181011 27 | 28 | if mode == 'random_withone': 29 | train_indices = [] 30 | test_gt = np.copy(gt) 31 | for c in np.unique(gt): 32 | if c == 0: 33 | continue 34 | indices = np.nonzero(gt == c) 35 | X = list(zip(*indices)) # x,y features 36 | train_len = int(np.ceil(train_size * len(X))) 37 | train_indices += random.sample(X, train_len) 38 | index = tuple(zip(*train_indices)) 39 | train_gt[index] = gt[index] 40 | test_gt[index] = 0 41 | 42 | elif mode == 'fixed_withone': 43 | train_indices = [] 44 | test_gt = np.copy(gt) 45 | for c in np.unique(gt): 46 | if c == 0: 47 | continue 48 | indices = np.nonzero(gt == c) 49 | X = list(zip(*indices)) # x,y features 50 | 51 | train_indices += random.sample(X, train_size) 52 | index = tuple(zip(*train_indices)) 53 | train_gt[index] = gt[index] 54 | test_gt[index] = 0 55 | else: 56 | raise ValueError("{} sampling is not implemented yet.".format(mode)) 57 | return train_gt, test_gt 58 | 59 | 60 | # 保存样本 61 | def save_sample(train_gt, test_gt, dataset_name, sample_size, run): 62 | sample_dir = './trainTestSplit/' + dataset_name + '/' 63 | if not os.path.isdir(sample_dir): 64 | os.makedirs(sample_dir) 65 | sample_file = sample_dir + 'sample' + str(sample_size) + '_run' + str(run) + '.mat' 66 | io.savemat(sample_file, {'train_gt':train_gt, 'test_gt':test_gt}) 67 | 68 | 69 | def load(dname): 70 | path = os.path.join(dname,'{}_gt.mat'.format(dname)) 71 | dataset = loadmat(path) 72 | key = keys[dname] 73 | gt = dataset[key] 74 | # # 采样背景像素点 75 | # gt += 1 76 | return gt 77 | 78 | 79 | def TrainTestSplit(datasetName): 80 | gt = load(datasetName) 81 | for size in TRAIN_SIZE: 82 | for r in range(RUN): 83 | train_gt, test_gt = sample_gt(gt, size) 84 | save_sample(train_gt, test_gt, datasetName, size, r) 85 | 
print('Finish split {}'.format(datasetName)) 86 | 87 | 88 | if __name__ == '__main__': 89 | dataseteName = ['Xiongan'] 90 | for name in dataseteName: 91 | TrainTestSplit(name) 92 | print('*'*8 + 'FINISH' + '*'*8) -------------------------------------------------------------------------------- /Model/module.py: -------------------------------------------------------------------------------- 1 | from torch_geometric import nn as gnn 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | 7 | # Internal graph convolution 8 | class SubGcn(nn.Module): 9 | def __init__(self, c_in, hidden_size, nc): 10 | super().__init__() 11 | self.gcn = gnn.SGConv(c_in, hidden_size, K=3) 12 | # 1 13 | self.classifier = nn.Sequential( 14 | nn.Dropout(), 15 | nn.Linear(hidden_size, hidden_size // 2), 16 | nn.ReLU(), 17 | # nn.Dropout(), 18 | nn.Linear(hidden_size // 2, nc) 19 | ) 20 | 21 | def forward(self, graph): 22 | h = F.relu(self.gcn(graph.x, graph.edge_index)) 23 | h_avg = gnn.global_mean_pool(h, graph.batch) 24 | logits = self.classifier(h_avg) 25 | return logits 26 | 27 | 28 | # Internal graph convolution feature module 29 | class SubGcnFeature(nn.Module): 30 | def __init__(self, c_in, hidden_size): 31 | super().__init__() 32 | self.gcn = gnn.SGConv(c_in, hidden_size, K=3) 33 | 34 | def forward(self, graph): 35 | h = F.relu(self.gcn(graph.x, graph.edge_index)) 36 | h_avg = gnn.global_mean_pool(h, graph.batch) 37 | return h_avg 38 | 39 | 40 | # External graph convolution 41 | class GraphNet(nn.Module): 42 | def __init__(self, c_in, hidden_size, nc): 43 | super().__init__() 44 | self.bn_0 = gnn.BatchNorm(c_in) 45 | self.gcn_1 = gnn.GCNConv(c_in, hidden_size) 46 | self.bn_1 = gnn.BatchNorm(hidden_size) 47 | self.gcn_2 = gnn.GraphConv(hidden_size, hidden_size) 48 | self.bn_2 = gnn.BatchNorm(hidden_size) 49 | # self.gcn_3 = gnn.GraphConv(hidden_size, hidden_size) 50 | # self.bn_3 = gnn.BatchNorm(hidden_size) 51 | self.classifier = nn.Sequential( 52 | 
nn.Dropout(), 53 | nn.Linear(hidden_size, hidden_size // 2), 54 | nn.ReLU(), 55 | # nn.Dropout(), 56 | nn.Linear(hidden_size // 2, nc) 57 | ) 58 | 59 | def forward(self, graph): 60 | # x_normalization = graph.x 61 | # h = F.relu(self.gcn_1(x_normalization, graph.edge_index)) 62 | # h = F.relu(self.gcn_2(h, graph.edge_index)) 63 | x_normalization = self.bn_0(graph.x) 64 | h = self.bn_1(F.relu(self.gcn_1(x_normalization, graph.edge_index))) 65 | h = self.bn_2(F.relu(self.gcn_2(h, graph.edge_index))) 66 | # h = self.bn_3(F.relu(self.gcn_3(h, graph.edge_index))) 67 | # h = F.relu(self.gcn_2(h, graph.edge_index)) 68 | logits = self.classifier(h + x_normalization) 69 | # logits = self.classifier(h) 70 | return logits 71 | 72 | 73 | # External graph convolution feature module 74 | class GraphNetFeature(nn.Module): 75 | def __init__(self, c_in, hidden_size): 76 | super().__init__() 77 | self.bn_0 = gnn.BatchNorm(c_in) 78 | self.gcn_1 = gnn.GCNConv(c_in, hidden_size) 79 | self.bn_1 = gnn.BatchNorm(hidden_size) 80 | self.gcn_2 = gnn.GCNConv(hidden_size, hidden_size) 81 | self.bn_2 = gnn.BatchNorm(hidden_size) 82 | 83 | def forward(self, graph): 84 | x_normalization = self.bn_0(graph.x) 85 | # x_normalization = graph.x 86 | h = self.bn_1(F.relu(self.gcn_1(x_normalization, graph.edge_index))) 87 | h = self.bn_2(F.relu(self.gcn_2(h, graph.edge_index))) 88 | return x_normalization + h 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | '''Predicting''' 2 | from scipy.io import loadmat, savemat 3 | import numpy as np 4 | import argparse 5 | import configparser 6 | import torch 7 | from torch import nn 8 | from torch_geometric.data import Data, Batch 9 | from skimage.segmentation import slic, mark_boundaries 10 | from sklearn.preprocessing import scale 11 | import os 12 | from PIL import Image 13 | from utils import get_graph_list, 
get_edge_index 14 | import math 15 | from Model.module import SubGcnFeature, GraphNet 16 | from Trainer import JointTrainer 17 | 18 | if __name__ == '__main__': 19 | parser = argparse.ArgumentParser(description='TRAIN THE OVERALL') 20 | parser.add_argument('--name', type=str, default='PaviaU', 21 | help='DATASET NAME') 22 | parser.add_argument('--block', type=int, default=100, 23 | help='BLOCK SIZE') 24 | parser.add_argument('--gpu', type=int, default=-1, 25 | help='GPU ID') 26 | parser.add_argument('--comp', type=int, default=10, 27 | help='COMPACTNESS') 28 | parser.add_argument('--batchsz', type=int, default=64, 29 | help='BATCH SIZE') 30 | parser.add_argument('--run', type=int, default=10, 31 | help='EXPERIMENT AMOUNT') 32 | parser.add_argument('--spc', type=int, default=10, 33 | help='SAMPLE per CLASS') 34 | parser.add_argument('--hsz', type=int, default=128, 35 | help='HIDDEN SIZE') 36 | parser.add_argument('--lr', type=float, default=1e-3, 37 | help='LEARNING RATE') 38 | parser.add_argument('--wd', type=float, default=0., 39 | help='WEIGHT DECAY') 40 | arg = parser.parse_args() 41 | config = configparser.ConfigParser() 42 | config.read('dataInfo.ini') 43 | 44 | # Data processing 45 | # Reading hyperspectral image 46 | data_path = 'data/{0}/{0}.mat'.format(arg.name) 47 | m = loadmat(data_path) 48 | data = m[config.get(arg.name, 'data_key')] 49 | gt_path = 'data/{0}/{0}_gt.mat'.format(arg.name) 50 | m = loadmat(gt_path) 51 | gt = m[config.get(arg.name, 'gt_key')] 52 | # Normalizing data 53 | h, w, c = data.shape 54 | data = data.reshape((h * w, c)) 55 | data_normalization = scale(data).reshape((h, w, c)) 56 | 57 | # Superpixel segmentation 58 | seg_root = 'data/rgb' 59 | seg_path = os.path.join(seg_root, '{}_seg_{}.npy'.format(arg.name, arg.block)) 60 | if os.path.exists(seg_path): 61 | seg = np.load(seg_path) 62 | else: 63 | rgb_path = os.path.join(seg_root, '{}_rgb.jpg'.format(arg.name)) 64 | img = Image.open(rgb_path) 65 | img_array = np.array(img) 66 | # 
The number of superpixel 67 | n_superpixel = int(math.ceil((h * w) / arg.block)) 68 | seg = slic(img_array, n_superpixel, arg.comp) 69 | # Saving 70 | np.save(seg_path, seg) 71 | 72 | # Constructing graphs 73 | graph_path = 'data/{}/{}_graph.pkl'.format(arg.name, arg.block) 74 | if os.path.exists(graph_path): 75 | graph_list = torch.load(graph_path) 76 | else: 77 | graph_list = get_graph_list(data_normalization, seg) 78 | torch.save(graph_list, graph_path) 79 | subGraph = Batch.from_data_list(graph_list) 80 | 81 | # Constructing full graphs 82 | full_edge_index_path = 'data/{}/{}_edge_index.npy'.format(arg.name, arg.block) 83 | if os.path.exists(full_edge_index_path): 84 | edge_index = np.load(full_edge_index_path) 85 | else: 86 | edge_index, _ = get_edge_index(seg) 87 | np.save(full_edge_index_path, 88 | edge_index if isinstance(edge_index, np.ndarray) else edge_index.cpu().numpy()) 89 | fullGraph = Data(None, 90 | edge_index=torch.from_numpy(edge_index) if isinstance(edge_index, np.ndarray) else edge_index, 91 | seg=torch.from_numpy(seg) if isinstance(seg, np.ndarray) else seg) 92 | 93 | 94 | gcn1 = SubGcnFeature(config.getint(arg.name, 'band'), arg.hsz) 95 | gcn2 = GraphNet(arg.hsz, arg.hsz, config.getint(arg.name, 'nc')) 96 | 97 | device = torch.device('cuda:{}'.format(arg.gpu)) if arg.gpu != -1 else torch.device('cpu') 98 | 99 | for r in range(arg.run): 100 | # Loading pretraining parameters 101 | gcn1.load_state_dict( 102 | torch.load(f"models/{arg.name}/{arg.block}_overall_skip_2_SGConv_l1_clip/intNet_best_{arg.spc}_{r}.pkl")) 103 | gcn2.load_state_dict( 104 | torch.load(f"models/{arg.name}/{arg.block}_overall_skip_2_SGConv_l1_clip/extNet_best_{arg.spc}_{r}.pkl")) 105 | trainer = JointTrainer([gcn1, gcn2]) 106 | # predicting 107 | preds = trainer.predict(subGraph, fullGraph, device) 108 | seg_torch = torch.from_numpy(seg) 109 | map = preds[seg_torch] 110 | save_root = 'prediction/{}/{}/{}_overall_skip_2_SGConv_l1_clip'.format(arg.name, arg.spc, arg.block) 
111 | if not os.path.exists(save_root): 112 | os.makedirs(save_root) 113 | save_path = os.path.join(save_root, '{}.mat'.format(r)) 114 | savemat(save_path, {'pred': map.cpu().numpy()}) 115 | print('*'*5 + 'FINISH' + '*'*5) 116 | 117 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Raw image -> Superpixel segmentation -> graph 3 | """ 4 | import numpy as np 5 | import torch 6 | import cv2 as cv 7 | from torch_scatter import scatter 8 | from torch_geometric.data import Data 9 | import copy 10 | from torch import nn 11 | 12 | 13 | # Getting adjacent relationship among nodes 14 | def get_edge_index(segment): 15 | if isinstance(segment, torch.Tensor): 16 | segment = segment.numpy() 17 | # 扩张 18 | img = segment.astype(np.uint8) 19 | kernel = np.ones((3,3), np.uint8) 20 | expansion = cv.dilate(img, kernel) 21 | mask = segment == expansion 22 | mask = np.invert(mask) 23 | # 构图 24 | h, w = segment.shape 25 | edge_index = set() 26 | directions = ((-1, -1), (-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1)) 27 | indices = list(zip(*np.nonzero(mask))) 28 | for x, y in indices: 29 | for dx, dy in directions: 30 | adj_x, adj_y = x + dx, y + dy 31 | if -1 < adj_x < h and -1 < adj_y < w: 32 | source, target = segment[x, y], segment[adj_x, adj_y] 33 | if source != target: 34 | edge_index.add((source, target)) 35 | edge_index.add((target, source)) 36 | return torch.tensor(list(edge_index), dtype=torch.long).T, edge_index 37 | 38 | 39 | # Getting node features 40 | def get_node(x, segment, mode='mean'): 41 | assert x.ndim == 3 and segment.ndim == 2 42 | if isinstance(x, np.ndarray): 43 | x = torch.from_numpy(x) 44 | if isinstance(segment, np.ndarray): 45 | segment = torch.from_numpy(segment).to(torch.long) 46 | c = x.shape[2] 47 | x = x.reshape((-1, c)) 48 | mask = segment.flatten() 49 | nodes = scatter(x, mask, dim=0, reduce=mode) 50 | 
return nodes.to(torch.float32) 51 | 52 | 53 | # Constructing graphs by shifting 54 | def get_grid_adj(grid): 55 | edge_index = list() 56 | # 上偏移 57 | a = np.full_like(grid, -1, dtype=np.int32) 58 | a[:-1] = grid[1:] 59 | adj = np.stack([grid, a], axis=-1) 60 | mask = adj != -1 61 | mask = np.logical_and(mask[..., 0], mask[..., 1]) 62 | tmp = adj[mask] 63 | tmp = tmp.tolist() 64 | edge_index += tmp 65 | # 下偏移 66 | a = np.full_like(grid, -1, dtype=np.int32) 67 | a[1:] = grid[:-1] 68 | adj = np.stack([grid, a], axis=-1) 69 | mask = adj != -1 70 | mask = np.logical_and(mask[..., 0], mask[..., 1]) 71 | tmp = adj[mask] 72 | tmp = tmp.tolist() 73 | edge_index += tmp 74 | # 左偏移 75 | a = np.full_like(grid, -1, dtype=np.int32) 76 | a[:, :-1] = grid[:, 1:] 77 | adj = np.stack([grid, a], axis=-1) 78 | mask = adj != -1 79 | mask = np.logical_and(mask[..., 0], mask[..., 1]) 80 | tmp = adj[mask] 81 | tmp = tmp.tolist() 82 | edge_index += tmp 83 | # 右偏移 84 | a = np.full_like(grid, -1, dtype=np.int32) 85 | a[:, 1:] = grid[:, :-1] 86 | adj = np.stack([grid, a], axis=-1) 87 | mask = adj != -1 88 | mask = np.logical_and(mask[..., 0], mask[..., 1]) 89 | tmp = adj[mask] 90 | tmp = tmp.tolist() 91 | edge_index += tmp 92 | return edge_index 93 | 94 | 95 | # Getting graph list 96 | def get_graph_list(data, seg): 97 | graph_node_feature = [] 98 | graph_edge_index = [] 99 | for i in np.unique(seg): 100 | # 获取节点特征 101 | graph_node_feature.append(data[seg == i]) 102 | # 获取邻接信息 103 | x, y = np.nonzero(seg == i) 104 | n = len(x) 105 | x_min, x_max = x.min(), x.max() 106 | y_min, y_max = y.min(), y.max() 107 | grid = np.full((x_max - x_min + 1, y_max - y_min + 1), -1, dtype=np.int32) 108 | x_hat, y_hat = x - x_min, y - y_min 109 | grid[x_hat, y_hat] = np.arange(n) 110 | graph_edge_index.append(get_grid_adj(grid)) 111 | graph_list = [] 112 | # 数据变换 113 | for node, edge_index in zip(graph_node_feature, graph_edge_index): 114 | node = torch.tensor(node, dtype=torch.float) 115 | edge_index = 
torch.tensor(edge_index, dtype=torch.long).T 116 | graph_list.append(Data(node, edge_index=edge_index)) 117 | return graph_list 118 | 119 | 120 | def split(graph_list, gt, mask): 121 | indices = np.nonzero(gt) 122 | ans = [] 123 | number = mask[indices] 124 | gt = gt[indices] 125 | for i, n in enumerate(number): 126 | graph = copy.deepcopy(graph_list[n]) 127 | graph.y = torch.tensor([gt[i]], dtype=torch.long) 128 | ans.append(graph) 129 | return ans 130 | 131 | 132 | def summary(net: nn.Module): 133 | single_dotted_line = '-' * 40 134 | double_dotted_line = '=' * 40 135 | star_line = '*' * 40 136 | content = [] 137 | def backward(m: nn.Module, chain: list): 138 | children = m.children() 139 | params = 0 140 | chain.append(m._get_name()) 141 | try: 142 | child = next(children) 143 | params += backward(child, chain) 144 | for child in children: 145 | params += backward(child, chain) 146 | # print('*' * 40) 147 | # print('{:>25}{:>15,}'.format('->'.join(chain), params)) 148 | # print('*' * 40) 149 | if content[-1] is not star_line: 150 | content.append(star_line) 151 | content.append('{:>25}{:>15,}'.format('->'.join(chain), params)) 152 | content.append(star_line) 153 | except: 154 | for p in m.parameters(): 155 | if p.requires_grad: 156 | params += p.numel() 157 | # print('{:>25}{:>15,}'.format(chain[-1], params)) 158 | content.append('{:>25}{:>15,}'.format(chain[-1], params)) 159 | chain.pop() 160 | return params 161 | # print('-' * 40) 162 | # print('{:>25}{:>15}'.format('Layer(type)', 'Param')) 163 | # print('=' * 40) 164 | content.append(single_dotted_line) 165 | content.append('{:>25}{:>15}'.format('Layer(type)', 'Param')) 166 | content.append(double_dotted_line) 167 | params = backward(net, []) 168 | # print('=' * 40) 169 | # print('-' * 40) 170 | content.pop() 171 | content.append(single_dotted_line) 172 | print('\n'.join(content)) 173 | return params 174 | 175 | 176 | -------------------------------------------------------------------------------- 
/requirements.txt: -------------------------------------------------------------------------------- 1 | name: pytorch 2 | channels: 3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ 10 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/ 11 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/bioconda/ 12 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/menpo/ 13 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 14 | - defaults 15 | dependencies: 16 | - _libgcc_mutex=0.1=main 17 | - _pytorch_select=0.1=cpu_0 18 | - attrs=20.3.0=pyhd3eb1b0_0 19 | - blas=1.0=mkl 20 | - blosc=1.20.1=hd408876_0 21 | - brotli=1.0.9=he6710b0_2 22 | - brotlipy=0.7.0=py37h27cfd23_1003 23 | - brunsli=0.1=h2531618_0 24 | - bzip2=1.0.8=h7b6447c_0 25 | - ca-certificates=2022.4.26=h06a4308_0 26 | - cairo=1.14.12=h8948797_3 27 | - certifi=2022.5.18.1=py37h06a4308_0 28 | - cffi=1.14.4=py37h261ae71_0 29 | - chardet=4.0.0=py37h06a4308_1003 30 | - charls=2.1.0=he6710b0_2 31 | - cloudpickle=1.6.0=py_0 32 | - cryptography=3.3.1=py37h3c74f83_0 33 | - cudatoolkit=10.2.89=hfd86e86_1 34 | - cycler=0.10.0=py37_0 35 | - cytoolz=0.11.0=py37h7b6447c_0 36 | - dask-core=2021.1.1=pyhd3eb1b0_0 37 | - dbus=1.13.18=hb2f20db_0 38 | - decorator=4.4.2=py_0 39 | - einops=0.3.0=py_0 40 | - expat=2.3.0=h2531618_2 41 | - ffmpeg=4.0=hcdf2ecd_0 42 | - fontconfig=2.13.0=h9420a91_0 43 | - freeglut=2.8.1=0 44 | - freetype=2.10.4=h5ab3b9f_0 45 | - future=0.18.2=py37_1 46 | - giflib=5.1.4=h14c3975_1 47 | - glib=2.66.1=h92f7085_0 48 | - graphite2=1.3.14=h23475e2_0 49 | - 
gst-plugins-base=1.14.0=h8213a91_2 50 | - gstreamer=1.14.0=h28cd5cc_2 51 | - harfbuzz=1.8.8=hffaf4a1_0 52 | - hdf5=1.10.2=hba1933b_1 53 | - hyperopt=0.1.1=py_0 54 | - icu=58.2=he6710b0_3 55 | - idna=2.10=pyhd3eb1b0_0 56 | - imagecodecs=2021.1.11=py37h581e88b_1 57 | - imageio=2.9.0=py_0 58 | - importlib-metadata=2.0.0=py_1 59 | - importlib_metadata=2.0.0=1 60 | - iniconfig=1.1.1=pyhd3eb1b0_0 61 | - intel-openmp=2020.2=254 62 | - jasper=2.0.14=h07fcdf6_0 63 | - jpeg=9b=0 64 | - jsonpatch=1.28=pyhd3eb1b0_0 65 | - jsonpath=0.82=py_0 66 | - jsonpointer=2.0=py_0 67 | - jxrlib=1.1=h7b6447c_2 68 | - kiwisolver=1.3.0=py37h2531618_0 69 | - lcms2=2.11=h396b838_0 70 | - ld_impl_linux-64=2.33.1=h53a641e_7 71 | - lerc=2.2.1=h2531618_0 72 | - libaec=1.0.4=he6710b0_1 73 | - libdeflate=1.7=h27cfd23_5 74 | - libedit=3.1.20191231=h14c3975_1 75 | - libffi=3.3=he6710b0_2 76 | - libgcc-ng=9.1.0=hdf63c60_0 77 | - libgfortran-ng=7.3.0=hdf63c60_0 78 | - libglu=9.0.0=hf484d3e_1 79 | - libopencv=3.4.2=hb342d67_1 80 | - libopus=1.3.1=h7b6447c_0 81 | - libpng=1.6.37=hbc83047_0 82 | - libsodium=1.0.18=h7b6447c_0 83 | - libstdcxx-ng=9.1.0=hdf63c60_0 84 | - libtiff=4.1.0=h2733197_1 85 | - libuuid=1.0.3=0 86 | - libuv=1.40.0=h7b6447c_0 87 | - libvpx=1.7.0=h439df22_0 88 | - libwebp=1.0.1=h8e7db2f_0 89 | - libxcb=1.14=h7b6447c_0 90 | - libxml2=2.9.10=hb55368b_3 91 | - libzopfli=1.0.3=he6710b0_0 92 | - lz4-c=1.9.3=h2531618_0 93 | - matplotlib=3.3.2=h06a4308_0 94 | - matplotlib-base=3.3.2=py37h817c723_0 95 | - mkl=2020.2=256 96 | - mkl-service=2.3.0=py37he8ac12f_0 97 | - mkl_fft=1.2.0=py37h23d657b_0 98 | - mkl_random=1.1.1=py37h0573a6f_0 99 | - more-itertools=8.6.0=pyhd3eb1b0_0 100 | - ncurses=6.2=he6710b0_1 101 | - networkx=1.11=py37_1 102 | - ninja=1.7.2=0 103 | - numpy=1.19.2=py37h54aff64_0 104 | - numpy-base=1.19.2=py37hfa32c7d_0 105 | - olefile=0.46=py_0 106 | - opencv=3.4.2=py37h6fd60c2_1 107 | - openjpeg=2.3.0=h05c96fa_1 108 | - openssl=1.1.1o=h7f8727e_0 109 | - packaging=20.8=pyhd3eb1b0_0 110 
| - pcre=8.44=he6710b0_0 111 | - pillow=8.1.0=py37he98fc37_0 112 | - pip=20.3.3=py37h06a4308_0 113 | - pixman=0.34.0=0 114 | - pluggy=0.13.1=py37_0 115 | - py=1.10.0=pyhd3eb1b0_0 116 | - py-opencv=3.4.2=py37hb342d67_1 117 | - pycparser=2.20=py_2 118 | - pymongo=3.11.3=py37h2531618_0 119 | - pyopenssl=20.0.1=pyhd3eb1b0_1 120 | - pyparsing=2.4.7=pyhd3eb1b0_0 121 | - pyqt=5.9.2=py37h05f1152_2 122 | - pysocks=1.7.1=py37_1 123 | - pytest=6.2.2=py37h06a4308_1 124 | - python=3.7.9=h7579374_0 125 | - python-dateutil=2.8.1=py_0 126 | - pytorch=1.7.1=py3.7_cuda10.2.89_cudnn7.6.5_0 127 | - pywavelets=1.1.1=py37h7b6447c_2 128 | - pyyaml=5.4.1=py37h27cfd23_1 129 | - pyzmq=20.0.0=py37h2531618_1 130 | - qt=5.9.7=h5867ecd_1 131 | - readline=8.0=h7b6447c_0 132 | - requests=2.25.1=pyhd3eb1b0_0 133 | - scikit-image=0.15.0=py37he6710b0_0 134 | - scikit-learn=0.24.1=py37ha9443f7_0 135 | - scipy=1.6.2=py37h91f5cce_0 136 | - setuptools=52.0.0=py37h06a4308_0 137 | - sip=4.19.8=py37hf484d3e_0 138 | - six=1.15.0=py37h06a4308_0 139 | - snappy=1.1.8=he6710b0_0 140 | - sqlite=3.33.0=h62c20be_0 141 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 142 | - tifffile=2021.1.14=pyhd3eb1b0_1 143 | - timm=0.4.12=pyhd8ed1ab_0 144 | - tk=8.6.10=hbc83047_0 145 | - toml=0.10.1=py_0 146 | - toolz=0.11.1=pyhd3eb1b0_0 147 | - torchaudio=0.7.2=py37 148 | - torchfile=0.1.0=py_0 149 | - torchvision=0.8.2=cpu_py37ha229d99_0 150 | - tornado=6.1=py37h27cfd23_0 151 | - typing_extensions=3.7.4.3=pyh06a4308_0 152 | - urllib3=1.26.3=pyhd3eb1b0_0 153 | - visdom=0.1.8.9=0 154 | - websocket-client=0.57.0=py37_2 155 | - wheel=0.36.2=pyhd3eb1b0_0 156 | - xlrd=1.2.0=py37_0 157 | - xz=5.2.5=h7b6447c_0 158 | - yaml=0.2.5=h7b6447c_0 159 | - zeromq=4.3.3=he6710b0_3 160 | - zfp=0.5.5=h2531618_4 161 | - zipp=3.4.0=pyhd3eb1b0_0 162 | - zlib=1.2.11=0 163 | - zstd=1.4.5=h9ceee32_0 164 | - pip: 165 | - ase==3.21.1 166 | - cached-property==1.5.2 167 | - googledrivedownloader==0.4 168 | - h5py==3.1.0 169 | - isodate==0.6.0 170 | - jinja2==2.11.2 
171 | - joblib==1.0.0 172 | - llvmlite==0.35.0 173 | - markupsafe==1.1.1 174 | - numba==0.52.0 175 | - openexr==1.3.8 176 | - pandas==1.2.1 177 | - python-louvain==0.15 178 | - pytz==2020.5 179 | - rdflib==5.0.0 180 | - torch-cluster==1.5.8 181 | - torch-geometric==1.6.3 182 | - torch-scatter==2.0.5 183 | - torch-sparse==0.6.8 184 | - torch-spline-conv==1.2.0 185 | - torchsummary==1.5.1 186 | - tqdm==4.56.0 187 | prefix: /home/jsg/anaconda3/envs/pytorch 188 | -------------------------------------------------------------------------------- /Trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import Data, Batch 3 | from torch.optim import optimizer as optimizer_ 4 | from torch_geometric.utils import accuracy 5 | from torch_geometric.nn import DataParallel 6 | from torch.nn.utils import clip_grad_norm_ 7 | import time 8 | 9 | 10 | class JointTrainer(object): 11 | r'''Joint trainer''' 12 | def __init__(self, models: list): 13 | super().__init__() 14 | self.models = models 15 | 16 | def train(self, subGraph: Batch, fullGraph: Data, optimizer, criterion, device, monitor = None, is_l1=False, is_clip=False): 17 | intNet = DataParallel(self.models[0]) 18 | extNet = self.models[1] 19 | intNet.train() 20 | extNet.train() 21 | intNet.to(device) 22 | extNet.to(device) 23 | criterion.to(device) 24 | # Internal graph features 25 | # if subGraph.num_graphs < 5000: 26 | # subGraph = subGraph.to(device) 27 | # fe = intNet(subGraph) 28 | # else: 29 | # batchsz = 5000 30 | # fe = [] 31 | # n_iteration = subGraph.num_graphs // batchsz 32 | # graph_list = subGraph.to_data_list() 33 | # for n in range(n_iteration): 34 | # batch = Batch.from_data_list(graph_list[n * batchsz: (n + 1) * batchsz]) 35 | # batch = batch.to(device) 36 | # f = intNet(batch) 37 | # fe.append(f) 38 | # if subGraph.num_graphs % batchsz != 0: 39 | # batch = Batch.from_data_list(graph_list[(n + 1) * batchsz:]) 40 | # batch = 
batch.to(device) 41 | # f = intNet(batch) 42 | # fe.append(f) 43 | # fe = torch.cat(fe, dim=0) 44 | # subGraph = subGraph.to(device) 45 | fe = intNet(subGraph.to_data_list()) 46 | 47 | # External graph features 48 | fullGraph.x = fe 49 | fullGraph = fullGraph.to(device) 50 | logits = extNet(fullGraph) 51 | indices = torch.nonzero(fullGraph.tr_gt, as_tuple=True) 52 | y = fullGraph.tr_gt[indices].to(device) - 1 53 | node_number = fullGraph.seg[indices] 54 | pixel_logits = logits[node_number] 55 | loss = criterion(pixel_logits, y) 56 | # l1 norm 57 | if is_l1: 58 | l1 = 0 59 | for p in intNet.parameters(): 60 | l1 += p.norm(1) 61 | loss += 1e-4 * l1 62 | # Back propagation 63 | optimizer.zero_grad() 64 | loss.backward() 65 | # Clipping gradient 66 | if is_clip: 67 | # External gradient 68 | clip_grad_norm_(extNet.parameters(), max_norm=2., norm_type=2) 69 | # Internal gradient 70 | clip_grad_norm_(intNet.parameters(), max_norm=3., norm_type=2) 71 | optimizer.step() 72 | 73 | if monitor is not None: 74 | monitor.add([intNet.parameters(), extNet.parameters()], ord=2) 75 | return loss.item() 76 | 77 | def evaluate(self, subGraph, fullGraph, criterion, device): 78 | intNet = DataParallel(self.models[0]) 79 | extNet = self.models[1] 80 | intNet.eval() 81 | extNet.eval() 82 | intNet.to(device) 83 | extNet.to(device) 84 | criterion.to(device) 85 | with torch.no_grad(): 86 | # subGraph = subGraph.to(device) 87 | fe = intNet(subGraph.to_data_list()) 88 | # if subGraph.num_graphs < 5000: 89 | # subGraph = subGraph.to(device) 90 | # fe = intNet(subGraph) 91 | # else: 92 | # batchsz = 5000 93 | # fe = [] 94 | # n_iteration = subGraph.num_graphs // batchsz 95 | # graph_list = subGraph.to_data_list() 96 | # for n in range(n_iteration): 97 | # batch = Batch.from_data_list(graph_list[n * batchsz: (n + 1) * batchsz]) 98 | # batch = batch.to(device) 99 | # f = intNet(batch) 100 | # fe.append(f) 101 | # if subGraph.num_graphs % batchsz != 0: 102 | # batch = 
Batch.from_data_list(graph_list[(n + 1) * batchsz:]) 103 | # batch = batch.to(device) 104 | # f = intNet(batch) 105 | # fe.append(f) 106 | # fe = torch.cat(fe, dim=0) 107 | fullGraph.x = fe 108 | fullGraph = fullGraph.to(device) 109 | logits = extNet(fullGraph) 110 | pred = torch.argmax(logits, dim=-1) 111 | indices = torch.nonzero(fullGraph.te_gt, as_tuple=True) 112 | y = fullGraph.te_gt[indices].to(device) - 1 113 | node_number = fullGraph.seg[indices] 114 | pixel_pred = pred[node_number] 115 | pixel_logits = logits[node_number] 116 | loss = criterion(pixel_logits, y) 117 | return loss.item(), accuracy(pixel_pred, y) 118 | 119 | # Getting prediction results 120 | def predict(self, subGraph, fullGraph, device: torch.device): 121 | intNet = DataParallel(self.models[0]) 122 | extNet = self.models[1] 123 | intNet.eval() 124 | extNet.eval() 125 | intNet.to(device) 126 | extNet.to(device) 127 | # begin_time = time.time() 128 | with torch.no_grad(): 129 | # Internal graph features 130 | fe = intNet(subGraph.to_data_list()) 131 | 132 | # External graph features 133 | fullGraph.x = fe 134 | fullGraph = fullGraph.to(device) 135 | logits = extNet(fullGraph) 136 | pred = torch.argmax(logits, dim=-1) 137 | # end_time = time.time() 138 | # print(f"time: {end_time - begin_time}") 139 | # exit(0) 140 | # indices = torch.nonzero(fullGraph, as_tuple=True) 141 | # node_number = fullGraph.seg[indices] 142 | # pixel_pred = pred[node_number] 143 | 144 | return pred 145 | 146 | # Getting hidden features 147 | def getHiddenFeature(self, subGraph, fullGraph, device, gt = None, seg = None): 148 | intNet = DataParallel(self.models[0]) 149 | extNet = self.models[1] 150 | intNet.eval() 151 | extNet.eval() 152 | intNet.to(device) 153 | extNet.to(device) 154 | with torch.no_grad(): 155 | fe = intNet(subGraph.to_data_list()) 156 | fullGraph.x = fe 157 | fullGraph = fullGraph.to(device) 158 | fe = extNet(fullGraph) 159 | if gt is not None and seg is not None: 160 | indices = torch.nonzero(gt, 
as_tuple=True) 161 | gt = gt[indices] - 1 162 | node_number = seg[indices].to(device) 163 | fe = fe[node_number] 164 | return fe.cpu(), gt 165 | else: 166 | return fe.cpu() 167 | 168 | def get_parameters(self): 169 | return self.models[0].parameters(), self.models[1].parameters() 170 | 171 | def save(self, paths): 172 | torch.save(self.models[0].cpu().state_dict(), paths[0]) 173 | torch.save(self.models[1].cpu().state_dict(), paths[1]) 174 | 175 | 176 | 177 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | '''Training''' 2 | from scipy.io import loadmat 3 | import numpy as np 4 | import argparse 5 | import configparser 6 | import torch 7 | from torch import nn 8 | from skimage.segmentation import slic 9 | from torch_geometric.data import Data, Batch 10 | from sklearn.preprocessing import scale, minmax_scale 11 | import os 12 | from PIL import Image 13 | from utils import get_graph_list, split, get_edge_index 14 | import math 15 | from Model.module import SubGcnFeature, GraphNet 16 | from Trainer import JointTrainer 17 | from Monitor import GradMonitor 18 | from visdom import Visdom 19 | from tqdm import tqdm 20 | 21 | if __name__ == '__main__': 22 | parser = argparse.ArgumentParser(description='TRAIN SUBGRAPH') 23 | parser.add_argument('--name', type=str, default='PaviaU', 24 | help='DATASET NAME') 25 | parser.add_argument('--block', type=int, default=100, 26 | help='BLOCK SIZE') 27 | parser.add_argument('--epoch', type=int, default=1, 28 | help='ITERATION') 29 | parser.add_argument('--gpu', type=int, default=-1, 30 | help='GPU ID') 31 | parser.add_argument('--comp', type=int, default=10, 32 | help='COMPACTNESS') 33 | parser.add_argument('--batchsz', type=int, default=64, 34 | help='BATCH SIZE') 35 | parser.add_argument('--run', type=int, default=10, 36 | help='EXPERIMENT AMOUNT') 37 | parser.add_argument('--spc', type=int, default=10, 38 
| help='SAMPLE per CLASS') 39 | parser.add_argument('--hsz', type=int, default=128, 40 | help='HIDDEN SIZE') 41 | parser.add_argument('--lr', type=float, default=1e-3, 42 | help='LEARNING RATE') 43 | parser.add_argument('--wd', type=float, default=0., 44 | help='WEIGHT DECAY') 45 | arg = parser.parse_args() 46 | config = configparser.ConfigParser() 47 | config.read('dataInfo.ini') 48 | viz = Visdom(port=17000) 49 | 50 | # Data processing 51 | # Reading hyperspectral image 52 | data_path = 'data/{0}/{0}.mat'.format(arg.name) 53 | m = loadmat(data_path) 54 | data = m[config.get(arg.name, 'data_key')] 55 | gt_path = 'data/{0}/{0}_gt.mat'.format(arg.name) 56 | m = loadmat(gt_path) 57 | gt = m[config.get(arg.name, 'gt_key')] 58 | # Normalizing data 59 | h, w, c = data.shape 60 | data = data.reshape((h * w, c)) 61 | data = data.astype(np.float) 62 | if arg.name == 'Xiongan': 63 | minmax_scale(data, copy=False) 64 | data_normalization = scale(data).reshape((h, w, c)) 65 | 66 | # Superpixel segmentation 67 | seg_root = 'data/rgb' 68 | seg_path = os.path.join(seg_root, '{}_seg_{}.npy'.format(arg.name, arg.block)) 69 | if os.path.exists(seg_path): 70 | seg = np.load(seg_path) 71 | else: 72 | rgb_path = os.path.join(seg_root, '{}_rgb.jpg'.format(arg.name)) 73 | img = Image.open(rgb_path) 74 | img_array = np.array(img) 75 | # The number of superpixel 76 | n_superpixel = int(math.ceil((h * w) / arg.block)) 77 | seg = slic(img_array, n_superpixel, arg.comp) 78 | # Saving 79 | np.save(seg_path, seg) 80 | 81 | # Constructing graphs 82 | graph_path = 'data/{}/{}_graph.pkl'.format(arg.name, arg.block) 83 | if os.path.exists(graph_path): 84 | graph_list = torch.load(graph_path) 85 | else: 86 | graph_list = get_graph_list(data_normalization, seg) 87 | torch.save(graph_list, graph_path) 88 | subGraph = Batch.from_data_list(graph_list) 89 | 90 | # Constructing full graphs 91 | full_edge_index_path = 'data/{}/{}_edge_index.npy'.format(arg.name, arg.block) 92 | if 
os.path.exists(full_edge_index_path): 93 | edge_index = np.load(full_edge_index_path) 94 | else: 95 | edge_index, _ = get_edge_index(seg) 96 | np.save(full_edge_index_path, edge_index if isinstance(edge_index, np.ndarray) else edge_index.cpu().numpy()) 97 | fullGraph = Data(None, 98 | edge_index=torch.from_numpy(edge_index) if isinstance(edge_index, np.ndarray) else edge_index, 99 | seg=torch.from_numpy(seg) if isinstance(seg, np.ndarray) else seg) 100 | 101 | for r in range(arg.run): 102 | print('*'*5 + 'Run {}'.format(r) + '*'*5) 103 | # Reading the training data set and testing data set 104 | m = loadmat('trainTestSplit/{}/sample{}_run{}.mat'.format(arg.name, arg.spc, r)) 105 | tr_gt, te_gt = m['train_gt'], m['test_gt'] 106 | tr_gt_torch, te_gt_torch = torch.from_numpy(tr_gt).long(), torch.from_numpy(te_gt).long() 107 | fullGraph.tr_gt, fullGraph.te_gt = tr_gt_torch, te_gt_torch 108 | 109 | gcn1 = SubGcnFeature(config.getint(arg.name, 'band'), arg.hsz) 110 | gcn2 = GraphNet(arg.hsz, arg.hsz, config.getint(arg.name, 'nc')) 111 | optimizer = torch.optim.Adam([{'params': gcn1.parameters()}, 112 | {'params': gcn2.parameters()}], 113 | weight_decay=arg.wd) 114 | criterion = nn.CrossEntropyLoss() 115 | trainer = JointTrainer([gcn1, gcn2]) 116 | monitor = GradMonitor() 117 | 118 | # Plotting a learning curve and gradient curve 119 | viz.line([[0., 0., 0.]], [0], win='{}_train_test_acc_{}'.format(arg.name, r), 120 | opts={'title': '{} train&test&acc {}'.format(arg.name, r), 121 | 'legend': ['train', 'test', 'acc']}) 122 | viz.line([[0., 0.]], [0], win='{}_grad_{}'.format(arg.name, r), opts={'title': '{} grad {}'.format(arg.name, r), 123 | 'legend': ['internal', 'external']}) 124 | 125 | device = torch.device('cuda:{}'.format(arg.gpu)) if arg.gpu != -1 else torch.device('cpu') 126 | max_acc = 0 127 | save_root = 'models/{}/{}/{}_overall_skip_2_SGConv_l1_clip'.format(arg.name, arg.spc, arg.block) 128 | pbar = tqdm(range(arg.epoch)) 129 | # Training 130 | for epoch in 
pbar: 131 | pbar.set_description_str('Epoch: {}'.format(epoch)) 132 | tr_loss = trainer.train(subGraph, fullGraph, optimizer, criterion, device, monitor.clear(), is_l1=True, is_clip=True) 133 | te_loss, acc = trainer.evaluate(subGraph, fullGraph, criterion, device) 134 | pbar.set_postfix_str('train loss: {} test loss:{} acc:{}'.format(tr_loss, te_loss, acc)) 135 | viz.line([[tr_loss, te_loss, acc]], [epoch], win='{}_train_test_acc_{}'.format(arg.name, r), update='append') 136 | viz.line([monitor.get()], [epoch], win='{}_grad_{}'.format(arg.name, r), update='append') 137 | 138 | if acc > max_acc: 139 | max_acc = acc 140 | if not os.path.exists(save_root): 141 | os.makedirs(save_root) 142 | trainer.save([os.path.join(save_root, 'intNet_best_{}_{}.pkl'.format(arg.spc, r)), 143 | os.path.join(save_root, 'extNet_best_{}_{}.pkl'.format(arg.spc, r))]) 144 | print('*'*5 + 'FINISH' + '*'*5) 145 | 146 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph-in-Graph Convolutional Network for Hyperspectral Image Classification (TNNLS 2022) 2 | 7 | 8 | [Sen Jia](https://scholar.google.com.hk/citations?hl=zh-CN&user=UxbDMKoAAAAJ), [Shuguo Jiang](https://scholar.google.com.hk/citations?hl=zh-CN&user=B1YTGUgAAAAJ), [Shuyu Zhang](https://scholar.google.com.hk/citations?hl=zh-CN&user=O48TQQ4AAAAJ), [Meng Xu](https://scholar.google.com.hk/citations?hl=zh-CN&user=Hw1TFzQAAAAJ), [Xiuping Jia]() 9 | 10 | 17 | 18 | 19 | 20 |
21 | 22 | > **Abstract:** *With the development of hyperspectral sensors, accessible hyperspectral images (HSIs) are increasing, and pixel-oriented classification has attracted much attention. Recently, graph convolutional networks (GCN) have been proposed to process graph-structured data in non-euclidean domains, and have been employed in HSI classification. But most methods based on GCN are hard to sufficiently exploit information of ground objects due to feature aggregation. To solve this issue, in this paper, we proposed a graph-in-graph (GiG) model and a related GiG convolutional network (GiGCN) for HSI classification from superpixel viewpoint. The graph-in-graph representation covers information inside and outside superpixels, respectively corresponding to the local and global characteristics of ground objects. Concretely, after segmenting HSI into disjoint superpixels, each one is converted to an internal graph. Meanwhile, an external graph is constructed according to the spatial adjacent relationships among superpixels. Significantly, each node in the external graph embeds a corresponding internal graph, forming the so-called graph-in-graph structure. Then, GiGCN composed of internal and external graph convolution is designed to extract hierarchical features and integrate them into multiple scales, improving the discriminability of GiGCN. Ensemble learning is incorporated to further boost the robustness of GiGCN. It is worth noting that we are the first to propose graph-in-graph framework from superpixel point and the GiGCN scheme for HSI classification. Experiment results on four benchmark data sets demonstrate that our proposed method is effective and feasible for HSI classification with limited labeled samples. For study replication, the code developed for this study is available at https://github.com/ShuGuoJ/GiGCN.git.* 23 |
24 | 25 | 26 | 27 | ## Network Architecture 28 | 29 |
30 | 31 | This is our ensemble flowchart. With ensemble learning, GiGCN is more robust across various scenes. 32 | 33 | 34 | 35 | 36 | ## Comparison with State-of-the-art Methods 37 | 38 | 39 | 40 | In comparative experiments, GiGCN is compared with eight other state-of-the-art methods for hyperspectral image classification. Ten samples per class are chosen for training, and the remaining samples are used for testing. To reduce bias, this procedure is repeated ten times. 41 | 42 | 43 | 59 | 60 | 61 | 62 | 113 | 114 | 115 | 116 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 |
PaviaUSalinasGF5YC
OA (%)AA (%)kappaOA (%)AA (%)kappaOA (%)AA (%)kappaOA (%)AA (%)kappa
Two-CNN78.2076.020.7176.3785.250.7487.0882.370.8585.2081.180.82
3DVSCNN76.4375.470.6989.1794.070.8884.9979.020.8377.3375.140.73
HSGAN72.1774.660.6483.4489.670.8288.8481.500.8792.7488.460.91
SSLstm69.5972.770.6281.2087.040.7970.2958.550.6689.0178.200.87
MDGCN75.4479.750.6993.4995.600.9385.2375.600.8392.2888.090.91
S-DMM83.7790.980.7988.5394.530.8788.7284.200.8781.1782.850.78
3DCAE59.1471.580.5172.6475.580.7070.5860.980.6670.6581.470.68
MDL4OW76.5581.420.7082.4490.570.8187.4683.410.8693.6594.980.92
GiGCN93.5194.070.9297.3498.340.9792.5086.520.9197.5195.880.97
276 | 277 | Our GiGCN significantly outperforms the other methods, especially when only a few labeled samples are available. 278 | 279 | Note: the access code for `Baidu Disk` is `mst1`. 280 | 281 | ## 1. Create Environment: 282 | 283 | - Python 3 (we recommend [Anaconda](https://www.anaconda.com/download/#linux)) 284 | 285 | - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads) 286 | 287 | - Python packages: 288 | 289 | ```shell 290 | cd graph-in-graph 291 | pip install -r requirements.txt 292 | ``` 293 | 294 | ## 2. Data Preparation: 295 | - Download the data, including the raw `.mat` files and the corresponding `.jpg` files used for superpixel segmentation, from here (code: 4zyf) for a quick start, and place them in `GiGCN/`. 296 | 297 | - Before training, split each data set by running `trainTestSplit.py` as follows: 298 | 299 | ```shell 300 | python trainTestSplit.py --name PaviaU (data set name) 301 | ``` 302 | 303 | ## 3. Training 304 | 305 | To train a model, run 306 | 307 | ```shell 308 | # Training on PaviaU data set 309 | python train.py --name PaviaU --block 100 --gpu 0 310 | ``` 311 | Here, `--block` denotes the superpixel size, i.e. the approximate number of pixels per superpixel (the image is segmented into about `ceil(h * w / block)` superpixels); in our ensemble setup it takes values in `[50, 100, 150, 200]`. 312 | 313 | The model with the best test accuracy is saved. 314 | 315 | Note: The `scikit-image` package in our experimental configuration is version 0.15.0, whose parameter `start_label` defaults to 0. However, in the latest version it defaults to 1. So if you encounter an index-out-of-bounds error at `Line 54` in `Trainer.py`, set `start_label` to 0 explicitly. 316 | 317 | ## 4. Prediction: 318 | 319 | To test a trained model, run 320 | 321 | ```shell 322 | # Testing on PaviaU data set 323 | python predict.py --name PaviaU --block 100 --gpu 0 324 | ``` 325 | The code automatically loads the best model saved in the previous (training) phase.
326 | 327 | 328 | ## Citation 329 | If this repo helps you, please consider citing our work: 330 | 331 | 332 | ``` 333 | @ARTICLE{9801664, 334 | author={Jia, Sen and Jiang, Shuguo and Zhang, Shuyu and Xu, Meng and Jia, Xiuping}, 335 | journal={IEEE Transactions on Neural Networks and Learning Systems}, 336 | title={Graph-in-Graph Convolutional Network for Hyperspectral Image Classification}, 337 | year={2022}, 338 | volume={}, 339 | number={}, 340 | pages={1-15}, 341 | } 342 | ``` 343 | --------------------------------------------------------------------------------