├── Workflow.png
├── cg_generation.png
├── LICENSE
├── README.md
├── Binary_data_preprocessing.ipynb
├── Ternary_data_preprocessing.ipynb
└── scdiag_gintopk_roc_ternary.ipynb
/Workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Docurdt/Cell-Graph_Signature/HEAD/Workflow.png
--------------------------------------------------------------------------------
/cg_generation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Docurdt/Cell-Graph_Signature/HEAD/cg_generation.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Yanan Wang & Yuguang Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cell-Graph Signature (CGSignature)
2 | #### GNN-based cancer prognosis prediction using image-derived Cell-Graphs
3 | > Code repository for the manuscript "Cell graph neural networks enable digital staging of tumour microenvironment and precisely predict patient survival in gastric cancer".
4 |
5 | ### Abstract
6 | Gastric cancer is one of the deadliest cancers worldwide. Accurate prognosis is essential for effective clinical assessment and treatment. Spatial patterns in the tumor microenvironment (TME) are conceptually indicative of the staging and progression of gastric cancer patients. By integrating multiplexed immunohistochemistry (mIHC) images and transforming them into Cell-Graphs that capture the spatial patterns of the TME, we propose a novel graph neural network-based approach, termed **Cell-Graph Signature** or **CGSignature**, for artificial intelligence-powered digital staging of the TME and precise prediction of patient survival in gastric cancer. In this study, patient survival prediction is formulated as either a binary (**short-term** and **long-term**) or a ternary (**short-term**, **medium-term**, and **long-term**) classification task. Extensive benchmarking experiments demonstrate that CGSignature achieves outstanding model performance, with an Area Under the Receiver Operating Characteristic curve (AUROC) of 0.960 ± 0.01 for binary classification and 0.771 ± 0.024 to 0.904 ± 0.012 for ternary classification. Moreover, Kaplan-Meier survival analysis indicates that the 'digital-grade' cancer staging produced by CGSignature discriminates both the binary and the ternary classes with statistical significance (P-value < 0.0001), significantly outperforming the AJCC 8th edition Tumor-Node-Metastasis (TNM) staging system. Using Cell-Graphs extracted from mIHC images, CGSignature improves the assessment of the link between TME spatial patterns and patient prognosis. Our study suggests the feasibility and benefits of such an artificial intelligence-powered digital staging system in diagnostic pathology and precision oncology.
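The ternary result is quoted as a range, which suggests one-vs-rest AUROCs reported per class, consistent with the per-class `roc_curve` loop in `calculate_metrics` of the evaluation notebook. A minimal NumPy sketch of that computation, assuming the scores form an `(n_samples, n_classes)` array such as the sigmoid outputs collected by `test()`, uses the rank (Mann-Whitney) form of the AUROC; `binary_auroc` and `ovr_auroc` are hypothetical helpers, not part of this repository.

```python
import numpy as np


def binary_auroc(y_true, scores):
    """AUROC as P(score of a positive > score of a negative), ties counted as 1/2."""
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true], scores[~y_true]
    if pos.size == 0 or neg.size == 0:
        raise ValueError("need at least one positive and one negative sample")
    # Compare every positive score against every negative score.
    diff = pos[:, None] - neg[None, :]
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())


def ovr_auroc(labels, probs):
    """One-vs-rest AUROC for each class column of probs, shape (n_samples, n_classes)."""
    labels = np.asarray(labels)
    probs = np.asarray(probs, dtype=float)
    return [binary_auroc(labels == c, probs[:, c]) for c in range(probs.shape[1])]


print(binary_auroc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # → 0.75
```

The pairwise comparison needs memory proportional to n_pos x n_neg, which is unproblematic for a cohort of this size; `sklearn.metrics.roc_auc_score` yields the same values and scales better.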
7 |
8 | ### Workflow
9 | 
10 | **Figure 1. An overall workflow of graph neural network-based prognosis prediction using Cell-Graphs.**
11 | **(a)** Specimen processing: tumor tissues were collected from gastric cancer patients and stained with seven different biomarkers, including DAPI, Pan-CK, CD8, CD68, CD163/CD45, Foxp3, and PD-L1. **(b)** Image pre-processing: sub-sampling and Cell-Graph construction. **(c)** The cohort: 172 gastric cancer patients were collected. **(d)** Data split: the training, validation, and test datasets account for 64%, 16%, and 20% of the data, respectively. **(f)** Data binning: overall survival time ranged from 0 to 88 months, and two data binning strategies were applied to generate the binary- and ternary-class datasets. **(e)** Model construction: four GNN architectures, GCNSag, GCNTopK, GINSag, and GINTopK, were constructed and compared. Multi-run model training, five-fold cross-validation, and an independent test were conducted to evaluate the performance of the constructed GNN models. **(g)** Model architecture: the four models share the same architecture, consisting of four consecutive convolutional-and-pooling blocks followed by a summary layer and three fully connected layers that produce the final classification outcome; they differ only in the type of convolutional unit and pooling layer. The architecture of GINTopK is illustrated here, as it outperformed the other three architectures and also achieved the best performance on the test dataset. The corresponding hidden or feature dimensions are indicated at the bottom of each box. FC stands for "fully connected layer".
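The 64/16/20 percentages are consistent with holding out 20% of the data as the independent test set and running five-fold cross-validation on the remaining 80%: each fold then trains on 4/5 x 80% = 64% and validates on 1/5 x 80% = 16%. The sketch below assumes the split is made at the patient level (so that Cell-Graphs from one patient never land in two partitions) and uses a hypothetical `split_patients` helper; it is an illustration, not the repository's splitting code.

```python
import numpy as np


def split_patients(n_patients, n_folds=5, test_frac=0.2, seed=0):
    """Hypothetical helper: hold out a test set, then build k CV folds on the rest.

    Returns (test_ids, [(train_ids, val_ids), ...]) as arrays of patient indices.
    """
    rng = np.random.default_rng(seed)
    ids = rng.permutation(n_patients)
    n_test = int(round(test_frac * n_patients))
    test_ids, dev_ids = ids[:n_test], ids[n_test:]
    # Fold sizes differ by at most one patient when the remainder does not divide evenly.
    folds = np.array_split(dev_ids, n_folds)
    splits = []
    for k in range(n_folds):
        train_ids = np.concatenate([folds[j] for j in range(n_folds) if j != k])
        splits.append((train_ids, folds[k]))
    return test_ids, splits


test_ids, splits = split_patients(172)
print(len(test_ids), [(len(tr), len(va)) for tr, va in splits])
# → 34 [(110, 28), (110, 28), (110, 28), (111, 27), (111, 27)]
```

With 172 patients this gives roughly 64% / 16% / 20% (110, 28, and 34 patients) in every fold.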
12 |
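In the model classes of the training notebook (`GCNTopK`, `GCNSag`, `GINTopK`, `GINSAG`), the summary layer of panel (g) is a hierarchical readout: after each conv-pool block the node features are summarized per graph as the concatenation `[gmp(x, batch), gap(x, batch)]`, and the four block summaries are added (`x = x1 + x2 + x3 + x4`) before the first fully connected layer. The NumPy sketch below reproduces that arithmetic, assuming `gmp` and `gap` are PyTorch Geometric's `global_max_pool` and `global_mean_pool` (the imports are not part of this excerpt); the function names are hypothetical.

```python
import numpy as np


def global_max_mean(x, batch, num_graphs):
    """Per-graph [max || mean] of node features x (n_nodes, d) -> (num_graphs, 2*d).

    batch[i] is the graph index of node i; every graph must keep at least one node.
    """
    d = x.shape[1]
    out_max = np.full((num_graphs, d), -np.inf)
    out_sum = np.zeros((num_graphs, d))
    np.maximum.at(out_max, batch, x)  # scatter-max of node rows into their graph
    np.add.at(out_sum, batch, x)      # scatter-sum of node rows into their graph
    counts = np.bincount(batch, minlength=num_graphs)
    out_mean = out_sum / np.maximum(counts, 1)[:, None]
    return np.concatenate([out_max, out_mean], axis=1)


def hierarchical_readout(block_outputs, num_graphs):
    """Sum the per-block summaries, mirroring x1 + x2 + x3 + x4 in the notebook."""
    return sum(global_max_mean(x, b, num_graphs) for x, b in block_outputs)
```

Because each pooling step changes which nodes survive, every block contributes its own `(x, batch)` pair; summing keeps the readout width at `2 * nhid`, which is why `lin1` takes `2*nhid` inputs.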
13 | 
14 | **Figure 2. An overview of procedures from multiplexed staining to Cell-Graph generation.** Detailed information can be seen in the descriptions of the figure.
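`Binary_data_preprocessing.ipynb` builds the Cell-Graphs from the per-cell table: each cell centre is the midpoint of `XMin`/`XMax` and `YMin`/`YMax`, cells are taken in table order in chunks of `num_node = 100`, and pairs of cells in the same chunk are compared against a critical distance of 20 µm / 0.5 µm per pixel = 40 px. The sketch below vectorizes that double loop with NumPy. The edge bookkeeping of the original cell is not recoverable from this copy, so returning both edge directions (the usual PyTorch Geometric `edge_index` convention), excluding self-loops, and keeping the distance as an edge attribute are assumptions; `build_cell_graphs` is a hypothetical name.

```python
import numpy as np


def build_cell_graphs(xmin, xmax, ymin, ymax, num_node=100, critical=20 / 0.5):
    """Split cells (in table order) into chunks of num_node and link close cell pairs.

    critical is in pixels: 20 um critical distance / 0.5 um pixel width = 40 px.
    """
    xc = (np.asarray(xmin, float) + np.asarray(xmax, float)) / 2
    yc = (np.asarray(ymin, float) + np.asarray(ymax, float)) / 2
    coords = np.stack([xc, yc], axis=1)  # one (x, y) row per cell
    graphs = []
    for start in range(0, len(coords), num_node):
        pts = coords[start:start + num_node]  # the last chunk may be smaller
        # Pairwise Euclidean distances between cells of this chunk.
        dist = np.linalg.norm(pts[:, None, :] - pts[None, :, :], axis=-1)
        mask = (dist < critical) & ~np.eye(len(pts), dtype=bool)
        src, dst = np.nonzero(mask)
        graphs.append({
            "coords": pts,
            "edge_index": np.stack([src, dst]),  # (2, n_edges), both directions
            "edge_dist": dist[src, dst],         # assumed edge attribute
        })
    return graphs
```

Since chunks follow table order, cells grouped into one graph are spatially related only insofar as the CSV rows are; the sketch does not reorder them.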
15 |
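The Kaplan-Meier comparison in the abstract evaluates survival curves for the patient groups defined by CGSignature's predicted classes. To illustrate the estimator itself, S(t) = prod over event times t_i <= t of (1 - d_i / n_i), where d_i is the number of deaths at t_i and n_i the number of patients still at risk, a minimal NumPy version follows; `kaplan_meier` is a hypothetical helper, and the log-rank test behind the reported P-values is not included.

```python
import numpy as np


def kaplan_meier(times, events):
    """Return distinct event times and the Kaplan-Meier survival just after each.

    times: follow-up times (e.g. months); events: 1 = death observed, 0 = censored.
    """
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=bool)
    event_times = np.unique(times[events])
    surv, s = [], 1.0
    for t in event_times:
        at_risk = np.sum(times >= t)            # still under observation at t
        deaths = np.sum((times == t) & events)  # deaths occurring at t
        s *= 1.0 - deaths / at_risk
        surv.append(s)
    return event_times, np.array(surv)


t, s = kaplan_meier([1, 2, 3, 4], [1, 1, 0, 1])
print(t.tolist(), s.tolist())  # → [1.0, 2.0, 4.0] [0.75, 0.5, 0.0]
```

Fitting one curve per predicted class, for example `kaplan_meier(t_all[pred == c], e_all[pred == c])`, gives the group-wise curves such a comparison plots.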
--------------------------------------------------------------------------------
/Binary_data_preprocessing.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "#!/usr/bin/env python3\n",
10 | "# -*- coding: utf-8 -*-\n",
11 | "\"\"\"\n",
12 | "Created on Thu Aug 8 01:49:00 2019\n",
13 | "\n",
14 | "@author: Yuguang Wang & Yanan Wang\n",
15 | "\"\"\"\n",
16 | "\n",
17 | "import pandas as pd\n",
18 | "import numpy as np\n",
19 | "import scipy.io as sio\n",
20 | "import os\n",
21 | "from scipy.sparse import csr_matrix"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {},
28 | "outputs": [
29 | {
30 | "name": "stdout",
31 | "output_type": "stream",
32 | "text": [
33 | "****** Processing data for data_file/test_data_surv.xlsx ******\n"
34 | ]
35 | },
36 | {
37 | "name": "stderr",
38 | "output_type": "stream",
39 | "text": [
40 | "/home/song-lab/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:123: RuntimeWarning: divide by zero encountered in double_scalars\n",
41 | "/home/song-lab/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:124: RuntimeWarning: divide by zero encountered in double_scalars\n"
42 | ]
43 | }
44 | ],
45 | "source": [
46 | "def create_samples(file_name):\n",
47 | " subdir = '_{}/'.format(file_name.split(\".\")[0])\n",
48 | " # critical distance 20 \\mu m, pix width 0.5 \\mu m\n",
49 | " critical = 20/0.5\n",
50 | " num_node = 100\n",
51 | " adj = list()\n",
52 | " feature = list()\n",
53 | " label = list()\n",
54 | " pid = list()\n",
55 | " pid_name = list()\n",
56 | " coor = list()\n",
57 | " edge_index = list()\n",
58 | " edge_coor = list()\n",
59 | " edge_attr = list()\n",
60 | " factor_flag = 'selected_new'\n",
61 | " #factor_flag = 'selected'\n",
62 | " #factor_flag = 'full'\n",
63 | " sv_dir = \"data/\" + factor_flag + subdir\n",
64 | " if not os.path.exists(sv_dir):\n",
65 | " os.makedirs(sv_dir)\n",
66 | " \n",
67 | " print('****** Processing data for %s ******' % file_name)\n",
68 | " df_diagnosis = pd.read_excel(file_name)\n",
69 | " id_participant = df_diagnosis['Id']\n",
70 | " num_id_1 = len(id_participant)\n",
71 | " survival_time_all = df_diagnosis['survival_time'].array\n",
72 | " \n",
73 | " for kid, row in df_diagnosis.iterrows():\n",
74 | " ld_csv = row['file_path']\n",
75 | " old_new_flag = ld_csv.split(\"/\")[2]\n",
76 | " #print(ld_csv)\n",
77 | " #print(old_new_flag)\n",
78 | " \n",
79 | " label1 = row['prognosis_label']\n",
80 | " df_csv = pd.read_csv(ld_csv)\n",
81 | " # extract the features\n",
82 | " if factor_flag == 'full': # from \"DAPI (DAPI) Nucleus Intensity\" to end\n",
83 | " lv = np.linspace(22,59,38,dtype=int)\n",
84 | " if factor_flag == 'selected_old':\n",
85 | " lv = np.array([11,12,18,35,36,37,38,39],dtype=np.int)\n",
86 | " if(old_new_flag == 'stomach_csv_1' and factor_flag == 'selected_new'):\n",
87 | " lv = np.concatenate((np.linspace(20,44,25,dtype=int),np.linspace(50,59,10,dtype=int)))\n",
88 | " elif(old_new_flag == 'stomach_csv_2' and factor_flag == 'selected_new'):\n",
89 | " lv = np.concatenate((np.linspace(32,56,25,dtype=int),np.linspace(62,71,10,dtype=int)))\n",
90 | " feature_all = df_csv.take(lv,axis=1).values\n",
91 | " #print(feature_all)\n",
92 | " feature_all = feature_all/(feature_all.max(axis=0)+0.00000000000001) # Add normalization\n",
93 | " #print(feature_all)\n",
94 | " \n",
95 | " \n",
96 | " # compute adjacency matrix\n",
97 | " xmin = np.array(df_csv['XMin'])\n",
98 | " xmax = np.array(df_csv['XMax'])\n",
99 | " ymin = np.array(df_csv['YMin'])\n",
100 | " ymax = np.array(df_csv['YMax'])\n",
101 | " num_cell = len(xmin)\n",
102 | " if np.mod(num_cell,num_node)==0:\n",
103 | " num_graph = int(num_cell/num_node)\n",
104 | " else:\n",
105 | " num_graph = int(num_cell/num_node)+1\n",
106 | " # compute the centre coordinates of cells\n",
107 | " xc = (xmin+xmax)/2\n",
108 | " yc = (ymin+ymax)/2\n",
109 | " #% compute the adjacency matrix for each graph\n",
110 | " # deal with the graphs except the last\n",
111 | " for i in range(num_graph-1):\n",
112 | " A = np.zeros([num_node,num_node])\n",
113 | " coor1 = list()\n",
114 | " coor1.append(xc[i*num_node:(i+1)*num_node])\n",
115 | " coor1.append(yc[i*num_node:(i+1)*num_node])\n",
116 | " coor1 = np.array(coor1).T # shape (num_node, 2): one (x, y) row per cell\n",
117 | " edge_coor_1 = list()\n",
118 | " edge_index_1 =list()\n",
119 | " edge_attr_1 = list()\n",
120 | " for k in range(num_node):\n",
121 | " for j in range(k+1,num_node):\n",
122 | " # turn to global coordinates\n",
123 | " k1 = i*num_node + k\n",
124 | " j1 = i*num_node + j\n",
125 | " dist = np.sqrt((xc[k1]-xc[j1])**2+(yc[k1]-yc[j1])**2)\n",
126 | " if dist0:\n",
154 | " num_node_last = int(np.mod(num_cell,num_node))\n",
155 | " else:\n",
156 | " num_node_last = num_node\n",
157 | " coor1 = list()\n",
158 | " coor1.append(xc[(num_graph-1)*num_node:])\n",
159 | " coor1.append(yc[(num_graph-1)*num_node:])\n",
160 | " coor1 = np.array(coor1).T # shape (num_node_last, 2): one (x, y) row per cell\n",
161 | " A = np.zeros([num_node_last,num_node_last])\n",
162 | " for k in range(num_node_last):\n",
163 | " for j in range(k,num_node_last):\n",
164 | " k1 = (num_graph-1)*num_node + k # global index within the last (possibly partial) graph\n",
165 | " j1 = (num_graph-1)*num_node + j\n",
166 | " dist = np.sqrt((xc[k1]-xc[j1])**2+(yc[k1]-yc[j1])**2)\n",
167 | " if dist0:\n",
170 | " num_node_last = int(np.mod(num_cell,num_node))\n",
171 | " else:\n",
172 | " num_node_last = num_node\n",
173 | " coor1 = list()\n",
174 | " coor1.append(xc[(num_graph-1)*num_node:])\n",
175 | " coor1.append(yc[(num_graph-1)*num_node:])\n",
176 | " coor1 = np.array(coor1).T # shape (num_node_last, 2): one (x, y) row per cell\n",
177 | " A = np.zeros([num_node_last,num_node_last])\n",
178 | " for k in range(num_node_last):\n",
179 | " for j in range(k,num_node_last):\n",
180 | " k1 = (num_graph-1)*num_node + k # global index within the last (possibly partial) graph\n",
181 | " j1 = (num_graph-1)*num_node + j\n",
182 | " dist = np.sqrt((xc[k1]-xc[j1])**2+(yc[k1]-yc[j1])**2)\n",
183 | " if dist=4 else class_num\n",
59 | "\n",
60 | " for label_col in range(class_num):\n",
61 | " y_true_label = y_true[:, label_col]\n",
62 | " y_pred_label = y_pred[:, label_col]\n",
63 | "\n",
64 | " print(y_true_label)\n",
65 | " print(y_pred_label)\n",
66 | " conf_mat_dict[labels[label_col]] = confusion_matrix(y_pred=y_pred_label, y_true=y_true_label)\n",
67 | "\n",
68 | "\n",
69 | " fig, axes = plt.subplots(nrows=plot_rows, ncols=plot_cols, sharex=False, sharey=False,gridspec_kw = {'wspace':0.5, 'hspace':0.05},figsize=(10,10))\n",
70 | " axes = trim_axs(axes, class_num)\n",
71 | " for ii in range(len(labels)):\n",
72 | " _label = labels[ii]\n",
73 | " _matrix = conf_mat_dict[_label]\n",
74 | " axes[ii].imshow(_matrix,interpolation='nearest', cmap=plt.cm.Blues)\n",
75 | " axes[ii].set(xticks=np.arange(_matrix.shape[1]),\n",
76 | " yticks=np.arange(_matrix.shape[0]),\n",
77 | " # ... and label them with the respective list entries\n",
78 | " xticklabels=[\"Neg\",\"Pos\"], yticklabels=[\"Neg\",\"Pos\"],\n",
79 | " title=_label,\n",
80 | " ylabel='True label',\n",
81 | " xlabel='Predicted label')\n",
82 | " fmt = 'd'\n",
83 | " thresh = _matrix.max() / 2.\n",
84 | " for i in range(_matrix.shape[0]):\n",
85 | " for j in range(_matrix.shape[1]):\n",
86 | " axes[ii].text(j, i, format(_matrix[i, j], fmt),\n",
87 | " ha=\"center\", va=\"center\", fontsize=8,\n",
88 | " color=\"white\" if _matrix[i, j] > thresh else \"black\")\n",
89 | "\n",
90 | " plt.savefig(_save_path, dpi=100,pad_inches = 0.1,bbox_inches = 'tight')\n",
91 | "\n",
92 | "\n",
93 | "# In[ ]:\n",
94 | "\n",
95 | "def calculate_metrics(gts, ops, preds, class_num, labels, outputs, mode):\n",
96 | " if mode:\n",
97 | " gts = np.vstack([gts, labels.cpu()]) if gts.size else labels.cpu()\n",
98 | " y_pred = outputs.unsqueeze(1)\n",
99 | " y_pred = torch.cat([1.0 - y_pred, y_pred], dim=1)\n",
100 | " y_pred = torch.max(y_pred, dim=1)[1]\n",
101 | " # print(\"Predict is %s\"%y_pred)\n",
102 | " preds = np.vstack([preds, y_pred.cpu()]) if preds.size else y_pred.cpu()\n",
103 | " else:\n",
104 | " _labels = labels.cpu()\n",
105 | " tmp = torch.zeros(len(_labels), class_num)\n",
106 | " for idx, ele in enumerate(_labels):\n",
107 | " tmp[idx][ele] = 1\n",
108 | " gts = np.vstack([gts, tmp]) if gts.size else tmp\n",
109 | " view = outputs.view(-1, class_num)\n",
110 | " y_pred = (view == view.max(dim=1, keepdim=True)[0]).view_as(outputs).type(torch.ByteTensor)\n",
111 | " # y_pred = torch.max(outputs, 1)[1].view(labels.size())\n",
112 | " # y_pred = np.argmax(y_pred.cpu())\n",
113 | " # print(y_pred)\n",
114 | " preds = np.vstack([preds, y_pred.cpu()]) if preds.size else y_pred.cpu()\n",
115 | "\n",
116 | " acc_list = []\n",
117 | " auc_list = []\n",
118 | " f1 = f1_score(gts, preds, average=\"micro\")\n",
119 | " for j in range(0, class_num):\n",
120 | " gts_i = gts[:,j]\n",
121 | " preds_i = preds[:,j]\n",
122 | " ops_i = ops[:,j]\n",
123 | " fpr, tpr, thresholds = roc_curve(gts_i, ops_i)\n",
124 | " acc_score = accuracy_score(gts_i, preds_i)\n",
125 | " auc_score = auc(fpr, tpr)\n",
126 | " acc_list.append(acc_score)\n",
127 | " auc_list.append(auc_score)\n",
128 | " print(\"class_num: %d, acc_score: %f, auc_score: %f\"%(j, acc_score, auc_score))\n",
129 | " return acc_list, auc_list, f1, gts, ops, preds\n",
130 | "\n",
131 | "\n",
132 | "def plot_confusion_matrix(_model, y_true, y_pred, classes, normalize=False, title=None, cmap=plt.cm.Blues):\n",
133 | "\n",
134 | " plot_multi_label_confusion_matrix('/home/yuguang/cellstar/figures/%s_Confusion_matrix.png' % _model, y_true, y_pred, classes)\n",
135 | "\n",
136 | " \n",
137 | "def plot_roc_curve(pred_y, test_y, class_label, n_classes, fig_name=\"roc_auc.png\"):\n",
138 | " #pred_y = pred_y/pred_y.max(axis=0)\n",
139 | " colors = [\"#E69F00\", \"#56B4E9\", \"#009E73\", \"#F0E442\", \"#0072B2\", \"#D55E00\", \"#CC79A7\", \"#000000\", \"#66CC99\", \"#999999\"]\n",
140 | " plt.close('all')\n",
141 | " plt.style.use(\"ggplot\")\n",
142 | " matplotlib.rcParams['font.family'] = \"Arial\"\n",
143 | " plt.figure(figsize=(8, 8), dpi=400)\n",
144 | " for i in range(n_classes):\n",
145 | " _tmp_pred = pred_y\n",
146 | " _tmp_label = test_y\n",
147 | " #print(_tmp_label[:, 0], _tmp_pred[:, 0])\n",
148 | " _fpr, _tpr, _ = roc_curve(_tmp_label[:, i], _tmp_pred[:, i])\n",
149 | " _auc = auc(_fpr, _tpr)\n",
150 | " plt.plot(_fpr, _tpr, color=colors[i],\n",
151 | " label=r'%s ROC (AUC = %0.3f)' % (class_label[i], _auc), lw=2, alpha=.9)\n",
152 | " plt.plot([0, 1], [0, 1], 'k--', lw=2)\n",
153 | " plt.xlim([0.0, 1.01])\n",
154 | " plt.ylim([0.0, 1.01])\n",
155 | " plt.xlabel('False Positive Rate')\n",
156 | " plt.ylabel('True Positive Rate')\n",
157 | " #plt.title('ROC curve of')\n",
158 | " plt.legend(loc=\"lower right\")\n",
159 | " plt.savefig(fig_name, dpi=400)\n",
160 | " plt.close('all')\n",
161 | "\n",
162 | "##Define Model Class\n",
163 | "class GCNTopK(torch.nn.Module):\n",
164 | " def __init__(self, num_feature, num_class, nhid=256, pooling_ratio=0.75):\n",
165 | " super(GCNTopK, self).__init__()\n",
166 | " self.nhid = nhid\n",
167 | " self.pooling_ratio = pooling_ratio\n",
168 | " self.conv1 = GraphConv(int(num_feature), self.nhid)\n",
169 | " self.pool1 = TopKPooling(self.nhid, ratio = self.pooling_ratio) # edited by Ming with concern for further extension\n",
170 | " self.conv2 = GraphConv(self.nhid, self.nhid)\n",
171 | " self.pool2 = TopKPooling(self.nhid, ratio = self.pooling_ratio)\n",
172 | " self.conv3 = GraphConv(self.nhid, self.nhid)\n",
173 | " self.pool3 = TopKPooling(self.nhid, ratio = self.pooling_ratio)\n",
174 | " #add one more conv-pooling block, i.e., conv4 and pool4\n",
175 | " self.conv4 = GraphConv(self.nhid, self.nhid)\n",
176 | " self.pool4 = TopKPooling(self.nhid, ratio = self.pooling_ratio)\n",
177 | "\n",
178 | " self.lin1 = torch.nn.Linear(self.nhid*2, self.nhid) # edited by Ming with concern for further extension\n",
179 | " self.lin2 = torch.nn.Linear(self.nhid, self.nhid//2)\n",
180 | " self.lin3 = torch.nn.Linear(self.nhid//2, num_class) # edited by Ming with concern for further extension\n",
181 | "\n",
182 | " def forward(self, data):\n",
183 | " x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch\n",
184 | "\n",
185 | " x = F.relu(self.conv1(x, edge_index))\n",
186 | " x, edge_index, edge_attr, batch, _, _ = self.pool1(x, edge_index, edge_attr, batch)\n",
187 | " x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n",
188 | "\n",
189 | " x = F.relu(self.conv2(x, edge_index))\n",
190 | " x, edge_index, edge_attr, batch, _, _ = self.pool2(x, edge_index, edge_attr, batch)\n",
191 | " x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n",
192 | "\n",
193 | " x = F.relu(self.conv3(x, edge_index))\n",
194 | " x, edge_index, edge_attr, batch, _, _ = self.pool3(x, edge_index, edge_attr, batch)\n",
195 | " x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n",
196 | " #add one more conv-pooling block, corresponding to conv4 and pool4\n",
197 | " x = F.relu(self.conv4(x, edge_index))\n",
198 | " x, edge_index, edge_attr, batch, _, _ = self.pool4(x, edge_index, edge_attr, batch)\n",
199 | " x4 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n",
200 | "\n",
201 | " x = x1 + x2 + x3 + x4\n",
202 | "# x = x1 + x2 + x3\n",
203 | "\n",
204 | " x = F.relu(self.lin1(x))\n",
205 | " x = F.dropout(x, p=0.5, training=self.training)\n",
206 | " x = F.relu(self.lin2(x))\n",
207 | "# print('shape of x before log_softmax: ',x.shape)\n",
208 | " y1 = F.log_softmax(self.lin3(x), dim=-1)\n",
209 | "# print('shape of x after log_softmax: ',x.shape)\n",
210 | " y2 = torch.sigmoid(self.lin3(x))\n",
211 | "\n",
212 | " return y1, y2\n",
213 | " \n",
214 | "##GCNSag\n",
215 | "class GCNSag(torch.nn.Module):\n",
216 | " def __init__(self, num_feature, num_class, nhid=256, pooling_ratio=0.75):\n",
217 | " super(GCNSag, self).__init__()\n",
218 | " self.nhid = nhid\n",
219 | " self.pooling_ratio = pooling_ratio\n",
220 | " self.conv1 = GCNConv(int(num_feature), self.nhid)\n",
221 | " self.pool1 = SAGPooling(self.nhid, min_score=0.001, GNN=GCNConv) # edited by Ming with concern for further extension\n",
222 | " self.conv2 = GCNConv(self.nhid, self.nhid)\n",
223 | " self.pool2 = SAGPooling(self.nhid, min_score=0.001, GNN=GCNConv)\n",
224 | " self.conv3 = GCNConv(self.nhid, self.nhid)\n",
225 | " self.pool3 = SAGPooling(self.nhid, min_score=0.001, GNN=GCNConv)\n",
226 | " #add one more conv-pooling block, i.e., conv4 and pool4\n",
227 | " self.conv4 = GCNConv(self.nhid, self.nhid)\n",
228 | " self.pool4 = SAGPooling(self.nhid, min_score=0.001, GNN=GCNConv)\n",
229 | "\n",
230 | " self.lin1 = torch.nn.Linear(self.nhid*2, self.nhid) # edited by Ming with concern for further extension\n",
231 | " self.lin2 = torch.nn.Linear(self.nhid, self.nhid//2)\n",
232 | " self.lin3 = torch.nn.Linear(self.nhid//2, num_class) # edited by Ming with concern for further extension\n",
233 | "\n",
234 | " def forward(self, data):\n",
235 | " x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch\n",
236 | "\n",
237 | " x = F.relu(self.conv1(x, edge_index))\n",
238 | " x, edge_index, edge_attr, batch, _, _ = self.pool1(x, edge_index, edge_attr, batch)\n",
239 | " x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n",
240 | "\n",
241 | " x = F.relu(self.conv2(x, edge_index))\n",
242 | " x, edge_index, edge_attr, batch, _, _ = self.pool2(x, edge_index, edge_attr, batch)\n",
243 | " x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n",
244 | "\n",
245 | " x = F.relu(self.conv3(x, edge_index))\n",
246 | " x, edge_index, edge_attr, batch, _, _ = self.pool3(x, edge_index, edge_attr, batch)\n",
247 | " x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n",
248 | " #add one more conv-pooling block, corresponding to conv4 and pool4\n",
249 | " x = F.relu(self.conv4(x, edge_index))\n",
250 | " x, edge_index, edge_attr, batch, _, _ = self.pool4(x, edge_index, edge_attr, batch)\n",
251 | " x4 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n",
252 | "\n",
253 | " x = x1 + x2 + x3 + x4\n",
254 | "# x = x1 + x2 + x3\n",
255 | "\n",
256 | " x = F.relu(self.lin1(x))\n",
257 | " x = F.dropout(x, p=0.5, training=self.training)\n",
258 | " x = F.relu(self.lin2(x))\n",
259 | "# print('shape of x before log_softmax: ',x.shape)\n",
260 | " y1 = F.log_softmax(self.lin3(x), dim=-1)\n",
261 | "# print('shape of x after log_softmax: ',x.shape)\n",
262 | " y2 = torch.sigmoid(self.lin3(x))\n",
263 | "\n",
264 | " return y1, y2 \n",
265 | "\n",
266 | " \n",
267 | "##GINTopK\n",
268 | "class GINTopK(torch.nn.Module):\n",
269 | " def __init__(self, num_feature, num_class, nhid):\n",
270 | " super(GINTopK, self).__init__()\n",
271 | " self.conv1 = GINConv(Seq(Lin(num_feature, nhid), ReLU(), Lin(nhid, nhid)))\n",
272 | " self.pool1 = TopKPooling(nhid, ratio=0.8)\n",
273 | " self.conv2 = GINConv(Seq(Lin(nhid, nhid), ReLU(), Lin(nhid, nhid)))\n",
274 | " self.pool2 = TopKPooling(nhid, ratio=0.8)\n",
275 | " self.conv3 = GINConv(Seq(Lin(nhid, nhid), ReLU(), Lin(nhid, nhid)))\n",
276 | " self.pool3 = TopKPooling(nhid, ratio=0.8)\n",
277 | " self.conv4 = GINConv(Seq(Lin(nhid, nhid), ReLU(), Lin(nhid, nhid)))\n",
278 | " self.pool4 = TopKPooling(nhid, ratio=0.8)\n",
279 | "\n",
280 | " self.lin1 = torch.nn.Linear(2*nhid, nhid)\n",
281 | " self.lin2 = torch.nn.Linear(nhid, nhid//2)\n",
282 | " self.lin3 = torch.nn.Linear(nhid//2, num_class)\n",
283 | "\n",
284 | " def forward(self, data):\n",
285 | " x, edge_index, batch = data.x, data.edge_index, data.batch\n",
286 | "\n",
287 | " x = F.relu(self.conv1(x, edge_index))\n",
288 | " x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)\n",
289 | " x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n",
290 | "\n",
291 | " x = F.relu(self.conv2(x, edge_index))\n",
292 | " x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)\n",
293 | " x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n",
294 | "\n",
295 | " x = F.relu(self.conv3(x, edge_index))\n",
296 | " x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)\n",
297 | " x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n",
298 | " \n",
299 | " x = F.relu(self.conv4(x, edge_index))\n",
300 | " x, edge_index, _, batch, _, _ = self.pool4(x, edge_index, None, batch)\n",
301 | " x4 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n",
302 | "\n",
303 | " x = x1 + x2 + x3 + x4\n",
304 | "\n",
305 | " x = F.relu(self.lin1(x))\n",
306 | " x = F.dropout(x, p=0.5, training=self.training)\n",
307 | " x = F.relu(self.lin2(x))\n",
308 | " y1 = F.log_softmax(self.lin3(x), dim=-1)\n",
309 | " y2 = torch.sigmoid(self.lin3(x))\n",
310 | "\n",
311 | " return y1, y2\n",
312 | " \n",
313 | "\n",
314 | "##GINSAG\n",
315 | "class GINSAG(torch.nn.Module):\n",
316 | " def __init__(self, num_feature, num_class, nhid):\n",
317 | " super(GINSAG, self).__init__()\n",
318 | " self.conv1 = GINConv(Seq(Lin(num_feature, nhid), ReLU(), Lin(nhid, nhid)))\n",
319 | " self.pool1 = SAGPooling(nhid, min_score=0.001, GNN=GCNConv)\n",
320 | " self.conv2 = GINConv(Seq(Lin(nhid, nhid), ReLU(), Lin(nhid, nhid)))\n",
321 | " self.pool2 = SAGPooling(nhid, min_score=0.001, GNN=GCNConv)\n",
322 | " self.conv3 = GINConv(Seq(Lin(nhid, nhid), ReLU(), Lin(nhid, nhid)))\n",
323 | " self.pool3 = SAGPooling(nhid, min_score=0.001, GNN=GCNConv)\n",
324 | " self.conv4 = GINConv(Seq(Lin(nhid, nhid), ReLU(), Lin(nhid, nhid)))\n",
325 | " self.pool4 = SAGPooling(nhid, min_score=0.001, GNN=GCNConv)\n",
326 | "\n",
327 | " self.lin1 = torch.nn.Linear(2*nhid, nhid)\n",
328 | " self.lin2 = torch.nn.Linear(nhid, nhid//2)\n",
329 | " self.lin3 = torch.nn.Linear(nhid//2, num_class)\n",
330 | "\n",
331 | " def forward(self, data):\n",
332 | " x, edge_index, batch = data.x, data.edge_index, data.batch\n",
333 | "\n",
334 | " x = F.relu(self.conv1(x, edge_index))\n",
335 | " x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)\n",
336 | " x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n",
337 | "\n",
338 | " x = F.relu(self.conv2(x, edge_index))\n",
339 | " x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)\n",
340 | " x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n",
341 | "\n",
342 | " x = F.relu(self.conv3(x, edge_index))\n",
343 | " x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)\n",
344 | " x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n",
345 | " \n",
346 | " x = F.relu(self.conv4(x, edge_index))\n",
347 | " x, edge_index, _, batch, _, _ = self.pool4(x, edge_index, None, batch)\n",
348 | " x4 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n",
349 | "\n",
350 | " x = x1 + x2 + x3 + x4\n",
351 | "\n",
352 | " x = F.relu(self.lin1(x))\n",
353 | " x = F.dropout(x, p=0.5, training=self.training)\n",
354 | " x = F.relu(self.lin2(x))\n",
355 | " y1 = F.log_softmax(self.lin3(x), dim=-1)\n",
356 | " y2 = torch.sigmoid(self.lin3(x))\n",
357 | "\n",
358 | " return y1, y2 \n",
359 | " \n",
360 | "\n",
361 | "def train(model,train_loader,device):\n",
362 | " model.train()\n",
363 | "\n",
364 | " loss_all = 0\n",
365 | " for data in train_loader:\n",
366 | " data = data.to(device)\n",
367 | " optimizer.zero_grad()\n",
368 | " output, _ = model(data)\n",
369 | " loss = F.nll_loss(output, data.y)\n",
370 | " loss.backward()\n",
371 | " loss_all += data.num_graphs * loss.item()\n",
372 | " optimizer.step()\n",
373 | " return loss_all / len(train_loader.dataset)\n",
374 | " \n",
375 | "def test(model,loader):\n",
376 | " model.eval()\n",
377 | " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
378 | " correct = 0.\n",
379 | " loss = 0. # edited by Ming with concern for further extension\n",
380 | " pred_1 = list()\n",
381 | " out_1 = np.array([])\n",
382 | " gt_l = np.array([])\n",
383 | " pred_bi = np.array([])\n",
384 | " label = np.array([])\n",
385 | " for data in loader:\n",
386 | " data = data.to(device)\n",
387 | " out, out2 = model(data)\n",
388 | "# print('out, out2 in test: ',out,out2)\n",
389 | " pred = out.max(dim=1)[1]\n",
390 | " correct += pred.eq(data.y).sum().item()\n",
391 | " loss += F.nll_loss(out, data.y,reduction='sum').item()\n",
392 | " \n",
393 | " pred_1.append(pred.cpu().detach().numpy())\n",
394 | " out_1 = np.vstack([out_1, out2.cpu().detach().numpy()]) if out_1.size else out2.cpu().detach().numpy()\n",
395 | " _tmp_label = data.y.cpu().detach().numpy()\n",
396 | " for _label in _tmp_label:\n",
397 | " if(_label == 0):\n",
398 | " _label_3d = np.array([1, 0, 0])\n",
399 | " elif(_label == 1):\n",
400 | " _label_3d = np.array([0, 1, 0])\n",
401 | " elif(_label == 2):\n",
402 | " _label_3d = np.array([0, 0, 1])\n",
403 | " gt_l = np.vstack([gt_l, _label_3d]) if gt_l.size else _label_3d\n",
404 | " for _pred in pred:\n",
405 | " if(_pred == 0):\n",
406 | " _pred_bi = np.array([1, 0, 0])\n",
407 | " if(_pred == 1):\n",
408 | " _pred_bi = np.array([0, 1, 0])\n",
409 | " if(_pred == 2):\n",
410 | " _pred_bi = np.array([0, 0, 1])\n",
411 | " pred_bi = np.vstack([pred_bi,_pred_bi]) if pred_bi.size else _pred_bi\n",
412 | " label = np.hstack([label,_tmp_label]) if label.size else _tmp_label\n",
413 | " # pred_1 = np.array(pred_1).reshape(pred_1)\n",
414 | " return correct *1.0 / len(loader.dataset), loss / len(loader.dataset), pred_1, out_1, gt_l, label, pred_bi"
415 | ]
416 | },
417 | {
418 | "cell_type": "code",
419 | "execution_count": 6,
420 | "metadata": {},
421 | "outputs": [
422 | {
423 | "name": "stdout",
424 | "output_type": "stream",
425 | "text": [
426 | "Device: cuda:0\n"
427 | ]
428 | }
429 | ],
430 | "source": [
431 | "# import argparse\n",
432 | "#def hyperopt_train(batch_size=256, learning_rate=0.01, weight_decay=0.0005, nhid=256, pooling_ratio=0.75, epochs=200, runs=1):\n",
433 | " ## Parameter Setting\n",
434 | " #added by ming for future pooling extensions\n",
435 | " \n",
436 | "# parser = argparse.ArgumentParser()\n",
437 | "# parser.add_argument('--batch_size', type=int, default=256,\n",
438 | "# help='batch size')\n",
439 | "# parser.add_argument('--learning_rate', type=float, default=5e-4,\n",
440 | "# help='learning rate')\n",
441 | "# parser.add_argument('--weight_decay', type=float, default=1e-4,\n",
442 | "# help='weight decay')\n",
443 | "# parser.add_argument('--nhid', type=int, default=512,\n",
444 | "# help='hidden size')\n",
445 | "# parser.add_argument('--pooling_ratio', type=float, default=0.5,\n",
446 | "# help='pooling ratio')\n",
447 | "# parser.add_argument('--epochs', type=int, default=200,\n",
448 | "# help='maximum number of epochs')\n",
449 | "# # parser.add_argument('--early_stopping', type=int, default=100,\n",
450 | "# # help='patience for earlystopping')\n",
451 | "# parser.add_argument('--num_layers', type=int, default=4,\n",
452 | "# help='number of layers')\n",
453 | "# parser.add_argument('--runs', type=int, default=1,\n",
454 | "# help='number of runs')\n",
455 | "# args = parser.parse_args()\n",
456 | "\n",
457 | "# batch_size = args.batch_size\n",
458 | "# learning_rate = args.learning_rate\n",
459 | "# weight_decay = args.weight_decay\n",
460 | "# nhid = args.nhid\n",
461 | "# pooling_ratio = args.pooling_ratio\n",
462 | "# epochs = args.epochs\n",
463 | "# # early_stopping = args.early_stopping\n",
464 | "# num_layers = args.num_layers\n",
465 | "# runs = args.runs\n",
466 | "\n",
467 | "batch_size = 256\n",
468 | "learning_rate = 5e-4\n",
469 | "weight_decay = 1e-4\n",
470 | "nhid = 512\n",
471 | "pooling_ratio = 0.5\n",
472 | "epochs = 200\n",
473 | "# early_stopping = args.early_stopping\n",
474 | "num_layers = 4\n",
475 | "\n",
476 | "model_name = \"gintopk\"\n",
477 | "runs = 1\n",
478 | "fold = 4\n",
479 | "\n",
480 | "# early_stopping = epochs\n",
481 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
482 | "print('Device: {}'.format(device))"
483 | ]
484 | },
485 | {
486 | "cell_type": "code",
487 | "execution_count": 7,
488 | "metadata": {},
489 | "outputs": [],
490 | "source": [
491 | "import os\n",
492 | "\n",
493 | "def load_dataset(dataset_path):\n",
494 | " ## load and preprocess data for stomach cancer\n",
495 | " ld_edge_index = \"\"\n",
496 | " ld_edge_attr = \"\"\n",
497 | " ld_feature = \"\"\n",
498 | " ld_label = \"\"\n",
499 | " ld_pid = \"\"\n",
500 | " for _root, _dirs, _files in os.walk(dataset_path):\n",
501 | " for _file in _files:\n",
502 | " #print(_file)\n",
503 | " if(\"weighted_edge_index\" in _file):\n",
504 | " ld_edge_index = os.path.join(_root, _file)\n",
505 | " elif(\"weighted_edge_attr\" in _file):\n",
506 | " ld_edge_attr = os.path.join(_root, _file)\n",
507 | " elif(\"weighted_feature\" in _file):\n",
508 | " ld_feature = os.path.join(_root, _file)\n",
509 | " elif(\"weighted_label\" in _file):\n",
510 | " ld_label = os.path.join(_root, _file)\n",
511 | " elif(\"weighted_pid.mat\" in _file):\n",
512 | " ld_pid = os.path.join(_root, _file)\n",
513 | "# print(ld_edge_index)\n",
514 | "# print(ld_edge_attr)\n",
515 | "# print(ld_feature)\n",
516 | "# print(ld_label)\n",
517 | "# print(ld_pid)\n",
518 | "\n",
519 | " edge_index = sio.loadmat(ld_edge_index)\n",
520 | " edge_index = edge_index['edge_index'][0]\n",
521 | " # load edge_attr\n",
522 | " edge_attr = sio.loadmat(ld_edge_attr)\n",
523 | " edge_attr = edge_attr['edge_attr'][0]\n",
524 | " # load feature\n",
525 | " feature = sio.loadmat(ld_feature)\n",
526 | " feature = feature['feature']\n",
527 | " #print(feature)\n",
528 | " # load label\n",
529 | " label = sio.loadmat(ld_label)\n",
530 | " label = label['label'][0]\n",
531 | " # load label_pid\n",
532 | " pid = sio.loadmat(ld_pid)\n",
533 | " pid = pid['pid'][0]\n",
534 | " \n",
535 | " stomach = list()\n",
536 | " num_edge = 0\n",
537 | " #num_feature = 0\n",
538 | " num_node = 0\n",
539 | " num_class = 3\n",
540 | " num_graph = edge_index.shape[0]\n",
541 | "\n",
542 | " for i in range(num_graph):\n",
543 | " # extract edge index, turn to tensor\n",
544 | " edge_index_1 = np.array(edge_index[i][:,0:2],dtype=np.int64)\n",
545 | " edge_index_1 = torch.tensor(edge_index_1, dtype=torch.long).to(device)\n",
546 | " # number of edges\n",
547 | " num_edge = num_edge + edge_index_1.shape[0]\n",
548 | " # extract edge_attr, turn to tensor\n",
549 | " edge_attr_1 = np.array(edge_attr[i][:,0:1],dtype=np.int64)\n",
550 | " edge_attr_1 = torch.tensor(edge_attr_1, dtype=torch.float).to(device)\n",
551 | " # extract feature, turn to tensor\n",
552 | " \n",
553 | " feature_1 = torch.tensor(feature[i], dtype=torch.float).to(device)\n",
554 | " #print(feature_1.shape)\n",
555 | " # number of nodes\n",
556 | " num_node = num_node + feature_1.shape[0]\n",
557 | " # number of features\n",
558 | " if i==0:\n",
559 | " num_feature = feature_1.shape[1]\n",
560 | " # extract label, turn to tensor\n",
561 | " label_1 = torch.tensor([label[i]],dtype=torch.long).to(device)\n",
562 | " # extract patient id, turn to tensor\n",
563 | " \n",
564 | " pid_1 = torch.tensor([pid[i]],dtype=torch.long).to(device)\n",
565 | " # put edge, feature, label together to form graph information in \"Data\" format\n",
566 | " data_1 = Data(x=feature_1, edge_index=edge_index_1.t().contiguous(), edge_attr=edge_attr_1, y=label_1, pid=pid_1)\n",
567 | " stomach.append(data_1)\n",
568 | " return(stomach, num_feature, num_edge, num_node)\n",
569 | " \n",
570 | "train_data_list, num_feature, num_edge, num_node = load_dataset(\"data/selected_new_data_file/train_data_fold_{}/\".format(fold))\n",
571 | "val_data_list, _, _, _ = load_dataset(\"data/selected_new_data_file/val_data_fold_{}/\".format(fold))\n",
572 | "test_data_list, _, _, _ = load_dataset(\"data/selected_new_data_file/test_data/\")\n",
573 | "\n",
574 | "train_val_list = train_data_list + val_data_list\n",
575 | "# pool the train and val folds, shuffle, and re-split 80/20 into training and validation sets\n",
576 | "nv = np.random.permutation(len(train_val_list))\n",
577 | "stomach_1 = train_val_list\n",
578 | "stomach = list()\n",
579 | "for i in nv:\n",
580 | " stomach.append(stomach_1[i])\n",
581 | "num_train_val = len(stomach)\n",
582 | "num_train = int(num_train_val * 0.8)\n",
583 | "#num_val = num_train_val - num_train\n",
584 | "\n",
585 | "train_loader = DataLoader(stomach[0:num_train], batch_size=batch_size, shuffle = True)\n",
586 | "val_loader = DataLoader(stomach[num_train:], batch_size=batch_size, shuffle = True)\n",
587 | "test_loader = DataLoader(test_data_list, batch_size=1, shuffle = False)"
588 | ]
589 | },
590 | {
591 | "cell_type": "code",
592 | "execution_count": 8,
593 | "metadata": {
594 | "scrolled": false
595 | },
596 | "outputs": [
597 | {
598 | "name": "stdout",
599 | "output_type": "stream",
600 | "text": [
601 | "**Data Set\n",
602 | "Ave.#Edge: 5553410.0, Ave.#Feature: 35.0, Ave.#Node: 1053500.0, #Classes: 3\n",
603 | "Train-val-test ratio: 7:1:2, Shuffle: True\n",
604 | "- number of training data: 42\n",
605 | "- number of validation data: 11\n",
606 | "- number of test data: 3660\n",
607 | "**Network Parameter Setting\n",
608 | "- batch size: 256\n",
609 | "- learning rate: 0.0005\n",
610 | "- weight decay: 0.0001\n",
611 | "- hidden size: 512\n",
612 | "- pooling_ratio: 0.5\n",
613 | "- maximum number of epochs: 200\n",
614 | "- graph convolution: GCNConv\n",
615 | "- number of graph convoluational layers: 1x4\n",
616 | "- graph pooling: TopKPooling\n",
617 | "- number of pooling layers: 4\n",
618 | "- number of fully connected layers: 4\n",
619 | "Run: 001, Epoch: 001, Val loss: 1.05696, Val acc: 0.42664\n",
620 | "Validation loss decreased (inf --> 1.056962). Saving model ...\n",
621 | "Run: 001, Epoch: 002, Val loss: 1.01926, Val acc: 0.45636\n",
622 | "Validation loss decreased (1.056962 --> 1.019258). Saving model ...\n",
623 | "Run: 001, Epoch: 003, Val loss: 0.96705, Val acc: 0.50489\n",
624 | "Validation loss decreased (1.019258 --> 0.967054). Saving model ...\n",
625 | "Run: 001, Epoch: 004, Val loss: 0.97122, Val acc: 0.50263\n",
626 | "EarlyStopping counter: 1 out of 20\n",
627 | "Run: 001, Epoch: 005, Val loss: 0.91673, Val acc: 0.54891\n",
628 | "Validation loss decreased (0.967054 --> 0.916735). Saving model ...\n",
629 | "Run: 001, Epoch: 006, Val loss: 0.85152, Val acc: 0.59707\n",
630 | "Validation loss decreased (0.916735 --> 0.851520). Saving model ...\n",
631 | "Run: 001, Epoch: 007, Val loss: 0.84072, Val acc: 0.61362\n",
632 | "Validation loss decreased (0.851520 --> 0.840718). Saving model ...\n",
633 | "Run: 001, Epoch: 008, Val loss: 0.78786, Val acc: 0.65124\n",
634 | "Validation loss decreased (0.840718 --> 0.787856). Saving model ...\n",
635 | "Run: 001, Epoch: 009, Val loss: 0.75665, Val acc: 0.66290\n",
636 | "Validation loss decreased (0.787856 --> 0.756647). Saving model ...\n",
637 | "Run: 001, Epoch: 010, Val loss: 0.68001, Val acc: 0.70542\n",
638 | "Validation loss decreased (0.756647 --> 0.680006). Saving model ...\n",
639 | "Run: 001, Epoch: 011, Val loss: 0.69914, Val acc: 0.68811\n",
640 | "EarlyStopping counter: 1 out of 20\n",
641 | "Run: 001, Epoch: 012, Val loss: 0.62928, Val acc: 0.72649\n",
642 | "Validation loss decreased (0.680006 --> 0.629284). Saving model ...\n",
643 | "Run: 001, Epoch: 013, Val loss: 0.62442, Val acc: 0.73890\n",
644 | "Validation loss decreased (0.629284 --> 0.624422). Saving model ...\n",
645 | "Run: 001, Epoch: 014, Val loss: 0.57847, Val acc: 0.75508\n",
646 | "Validation loss decreased (0.624422 --> 0.578474). Saving model ...\n",
647 | "Run: 001, Epoch: 015, Val loss: 0.55423, Val acc: 0.77163\n",
648 | "Validation loss decreased (0.578474 --> 0.554229). Saving model ...\n",
649 | "Run: 001, Epoch: 016, Val loss: 0.56123, Val acc: 0.77351\n",
650 | "EarlyStopping counter: 1 out of 20\n",
651 | "Run: 001, Epoch: 017, Val loss: 0.54422, Val acc: 0.76185\n",
652 | "Validation loss decreased (0.554229 --> 0.544223). Saving model ...\n",
653 | "Run: 001, Epoch: 018, Val loss: 0.48902, Val acc: 0.80023\n",
654 | "Validation loss decreased (0.544223 --> 0.489018). Saving model ...\n",
655 | "Run: 001, Epoch: 019, Val loss: 0.51223, Val acc: 0.79834\n",
656 | "EarlyStopping counter: 1 out of 20\n",
657 | "Run: 001, Epoch: 020, Val loss: 0.48600, Val acc: 0.80813\n",
658 | "Validation loss decreased (0.489018 --> 0.486000). Saving model ...\n",
659 | "Run: 001, Epoch: 021, Val loss: 0.45923, Val acc: 0.81264\n",
660 | "Validation loss decreased (0.486000 --> 0.459229). Saving model ...\n",
661 | "Run: 001, Epoch: 022, Val loss: 0.44748, Val acc: 0.82054\n",
662 | "Validation loss decreased (0.459229 --> 0.447480). Saving model ...\n",
663 | "Run: 001, Epoch: 023, Val loss: 0.46807, Val acc: 0.81716\n",
664 | "EarlyStopping counter: 1 out of 20\n",
665 | "Run: 001, Epoch: 024, Val loss: 0.41482, Val acc: 0.84274\n",
666 | "Validation loss decreased (0.447480 --> 0.414825). Saving model ...\n",
667 | "Run: 001, Epoch: 025, Val loss: 0.43815, Val acc: 0.83747\n",
668 | "EarlyStopping counter: 1 out of 20\n",
669 | "Run: 001, Epoch: 026, Val loss: 0.42474, Val acc: 0.84048\n",
670 | "EarlyStopping counter: 2 out of 20\n",
671 | "Run: 001, Epoch: 027, Val loss: 0.42213, Val acc: 0.84349\n",
672 | "EarlyStopping counter: 3 out of 20\n",
673 | "Run: 001, Epoch: 028, Val loss: 0.38146, Val acc: 0.86494\n",
674 | "Validation loss decreased (0.414825 --> 0.381457). Saving model ...\n",
675 | "Run: 001, Epoch: 029, Val loss: 0.43542, Val acc: 0.85403\n",
676 | "EarlyStopping counter: 1 out of 20\n",
677 | "Run: 001, Epoch: 030, Val loss: 0.38875, Val acc: 0.86305\n",
678 | "EarlyStopping counter: 2 out of 20\n",
679 | "Run: 001, Epoch: 031, Val loss: 0.44091, Val acc: 0.84462\n",
680 | "EarlyStopping counter: 3 out of 20\n",
681 | "Run: 001, Epoch: 032, Val loss: 0.40274, Val acc: 0.87020\n",
682 | "EarlyStopping counter: 4 out of 20\n",
683 | "Run: 001, Epoch: 033, Val loss: 0.44874, Val acc: 0.86117\n",
684 | "EarlyStopping counter: 5 out of 20\n",
685 | "Run: 001, Epoch: 034, Val loss: 0.39264, Val acc: 0.86268\n",
686 | "EarlyStopping counter: 6 out of 20\n",
687 | "Run: 001, Epoch: 035, Val loss: 0.41636, Val acc: 0.86719\n",
688 | "EarlyStopping counter: 7 out of 20\n",
689 | "Run: 001, Epoch: 036, Val loss: 0.39624, Val acc: 0.87359\n",
690 | "EarlyStopping counter: 8 out of 20\n",
691 | "Run: 001, Epoch: 037, Val loss: 0.46559, Val acc: 0.86230\n",
692 | "EarlyStopping counter: 9 out of 20\n",
693 | "Run: 001, Epoch: 038, Val loss: 0.42229, Val acc: 0.87547\n",
694 | "EarlyStopping counter: 10 out of 20\n",
695 | "Run: 001, Epoch: 039, Val loss: 0.45045, Val acc: 0.86531\n",
696 | "EarlyStopping counter: 11 out of 20\n",
697 | "Run: 001, Epoch: 040, Val loss: 0.38113, Val acc: 0.88751\n",
698 | "Validation loss decreased (0.381457 --> 0.381129). Saving model ...\n",
699 | "Run: 001, Epoch: 041, Val loss: 0.43642, Val acc: 0.87434\n",
700 | "EarlyStopping counter: 1 out of 20\n",
701 | "Run: 001, Epoch: 042, Val loss: 0.40980, Val acc: 0.88149\n",
702 | "EarlyStopping counter: 2 out of 20\n",
703 | "Run: 001, Epoch: 043, Val loss: 0.38917, Val acc: 0.88864\n",
704 | "EarlyStopping counter: 3 out of 20\n",
705 | "Run: 001, Epoch: 044, Val loss: 0.45189, Val acc: 0.87698\n",
706 | "EarlyStopping counter: 4 out of 20\n",
707 | "Run: 001, Epoch: 045, Val loss: 0.46373, Val acc: 0.87660\n",
708 | "EarlyStopping counter: 5 out of 20\n",
709 | "Run: 001, Epoch: 046, Val loss: 0.54833, Val acc: 0.86569\n",
710 | "EarlyStopping counter: 6 out of 20\n",
711 | "Run: 001, Epoch: 047, Val loss: 0.42891, Val acc: 0.87472\n",
712 | "EarlyStopping counter: 7 out of 20\n",
713 | "Run: 001, Epoch: 048, Val loss: 0.42434, Val acc: 0.89391\n",
714 | "EarlyStopping counter: 8 out of 20\n",
715 | "Run: 001, Epoch: 049, Val loss: 0.41616, Val acc: 0.89052\n",
716 | "EarlyStopping counter: 9 out of 20\n",
717 | "Run: 001, Epoch: 050, Val loss: 0.40190, Val acc: 0.89541\n",
718 | "EarlyStopping counter: 10 out of 20\n",
719 | "Run: 001, Epoch: 051, Val loss: 0.44651, Val acc: 0.89278\n",
720 | "EarlyStopping counter: 11 out of 20\n",
721 | "Run: 001, Epoch: 052, Val loss: 0.46134, Val acc: 0.89278\n",
722 | "EarlyStopping counter: 12 out of 20\n",
723 | "Run: 001, Epoch: 053, Val loss: 0.47696, Val acc: 0.87848\n",
724 | "EarlyStopping counter: 13 out of 20\n",
725 | "Run: 001, Epoch: 054, Val loss: 0.55726, Val acc: 0.88375\n",
726 | "EarlyStopping counter: 14 out of 20\n",
727 | "Run: 001, Epoch: 055, Val loss: 0.51403, Val acc: 0.86305\n",
728 | "EarlyStopping counter: 15 out of 20\n",
729 | "Run: 001, Epoch: 056, Val loss: 0.47854, Val acc: 0.89353\n",
730 | "EarlyStopping counter: 16 out of 20\n",
731 | "Run: 001, Epoch: 057, Val loss: 0.46898, Val acc: 0.88939\n",
732 | "EarlyStopping counter: 17 out of 20\n",
733 | "Run: 001, Epoch: 058, Val loss: 0.47342, Val acc: 0.89014\n",
734 | "EarlyStopping counter: 18 out of 20\n",
735 | "Run: 001, Epoch: 059, Val loss: 0.42647, Val acc: 0.89165\n",
736 | "EarlyStopping counter: 19 out of 20\n",
737 | "Run: 001, Epoch: 060, Val loss: 0.51179, Val acc: 0.87923\n",
738 | "EarlyStopping counter: 20 out of 20\n",
739 | "Early stopping\n",
740 | "** Run: 001, test loss: 1.96189, test acc: 0.74235\n",
741 | "Test accuarcy at patient level: 74.29\n",
742 | "** Model 1600507584.0485966, mean test acc (cell): 0.74235\n"
743 | ]
744 | }
745 | ],
746 | "source": [
747 | "# import EarlyStopping\n",
748 | "from pytorchtools import EarlyStopping\n",
749 | "\n",
750 | "os.makedirs(model_name, exist_ok=True)  # make sure the output directory exists\n",
751 | "sv_dat = '{}/test_data.pt'.format(model_name)\n",
752 | "torch.save(test_data_list, sv_dat)\n",
753 | "num_class = 3\n",
754 | "\n",
755 | "print('**Data Set')\n",
756 | "#print('Data name: {}, Data type: {}, #Graph: {}'.format('Stomach',data_type,num_graph))\n",
757 | "print('#Edge (train, total): {:d}, #Feature: {:d}, #Node (train, total): {:d}, #Classes: {:d}'.format(num_edge,num_feature,num_node,num_class))\n",
758 | "print('Split: train/val = 80/20 of the pooled train+val folds, separate held-out test set. Shuffle: True')\n",
759 | "print('- number of training data:',len(train_loader.dataset))\n",
760 | "print('- number of validation data:',len(val_loader.dataset))\n",
761 | "print('- number of test data:',len(test_loader.dataset))\n",
762 | "\n",
763 | "print('**Network Parameter Setting')\n",
764 | "print('- batch size: ',batch_size)\n",
765 | "print('- learning rate: ',learning_rate)\n",
766 | "print('- weight decay: ',weight_decay)\n",
767 | "print('- hidden size: ',nhid)\n",
768 | "print('- pooling_ratio: ',pooling_ratio)\n",
769 | "print('- maximum number of epochs: ',epochs)\n",
770 | "# print('- patience for earlystopping: ',early_stopping)\n",
771 | "print('- graph convolution: ','GCNConv')\n",
772 | "print('- number of graph convolutional layers: {}x{}'.format(1,num_layers))\n",
773 | "print('- graph pooling: ','TopKPooling')\n",
774 | "print('- number of pooling layers: ',num_layers)\n",
775 | "print('- number of fully connected layers: ',num_layers)\n",
776 | " \n",
777 | "###############################################################\n",
778 | "\n",
779 | "train_loss = np.zeros((runs,epochs),dtype=float)\n",
780 | "val_acc = np.zeros((runs,epochs))\n",
781 | "val_loss = np.zeros((runs,epochs))\n",
782 | "test_acc_c = np.zeros(runs)\n",
783 | "test_loss_c = np.zeros(runs)\n",
784 | "test_pred_c = np.zeros(runs)\n",
785 | "test_out_c = np.zeros((runs,num_class)) \n",
786 | "ground_truth_c = np.zeros((runs,num_class))\n",
787 | "test_acc_p = np.zeros(runs)\n",
788 | "min_loss = 1e10*np.ones(runs)\n",
789 | "# num_test_p = num_test\n",
790 | "# pid_test_p = np.zeros((runs,num_test_p))\n",
791 | "for run in range(runs):\n",
792 | "# print('\\n*** Training ***')\n",
793 | "# print('** Run {} of total {} runs ...'.format(run+1,runs))\n",
794 | "# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
795 | " model = GINTopK(num_feature=num_feature, num_class=num_class, nhid=nhid).to(device)\n",
796 | " optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay = weight_decay)\n",
797 | " \n",
798 | " ## Training\n",
799 | "\n",
800 | " # initialize the early_stopping object\n",
801 | " patience = 20\n",
802 | " early_stopping = EarlyStopping(patience=patience, verbose=True, path=\"{}/model_{}_fold{}_run{}.pth\".format(model_name, model_name, fold, run))\n",
803 | "# val_acc_c = np.zeros((runs,epochs))\n",
804 | "# val_loss_c = np.zeros((runs,epochs))\n",
805 | "# test_acc_c = np.zeros(runs)\n",
806 | "# test_acc_p = np.zeros(runs)\n",
807 | " for epoch in range(epochs):\n",
808 | " model.train()\n",
809 | " loss_all = 0\n",
810 | " for i, data in enumerate(train_loader):\n",
811 | " data = data.to(device)\n",
812 | "# print('data in train: ',data)\n",
813 | " out, out2 = model(data)\n",
814 | " loss = F.nll_loss(out, data.y)\n",
815 | " #print(out, data.y)\n",
816 | " #writer.add_scalar('train/loss', loss, len(train_loader)*epoch+i)\n",
817 | " #print(\"Training loss: {:.5f}\".format(loss.item()))\n",
818 | " loss.backward()\n",
819 | " loss_all += data.num_graphs * loss.item()\n",
820 | " optimizer.step()\n",
821 | " optimizer.zero_grad()\n",
822 | " loss = loss_all / len(train_loader.dataset) \n",
823 | " train_loss[run,epoch] = loss\n",
824 | " val_acc[run,epoch], val_loss[run, epoch], _, _, _, _, _ = test(model, val_loader)\n",
825 | " print(\"Run: {:03d}, Epoch: {:03d}, Val loss: {:.5f}, Val acc: {:.5f}\".format(run+1,epoch+1,val_loss[run,epoch],val_acc[run,epoch]))\n",
826 | " \n",
827 | " # early_stopping needs the validation loss to check if it has decreased, \n",
828 | " # and if it has, it will make a checkpoint of the current model\n",
829 | " early_stopping(val_loss[run, epoch], model)\n",
830 | " \n",
831 | " if early_stopping.early_stop:\n",
832 | " print(\"Early stopping\")\n",
833 | " break\n",
834 | " \n",
835 | "# if val_loss[run,epoch] < min_loss[run]:\n",
836 | "# torch.save(model.state_dict(), 'model_gintopk.pth') # save the model and reuse later in test\n",
837 | "# #print(\"Model saved at epoch: {:03d}\".format(epoch))\n",
838 | "# min_loss[run] = val_loss[run,epoch]\n",
839 | " # model = GCNTopK(num_feature=num_feature, num_class=num_class, nhid=nhid, pooling_ratio=pooling_ratio).to(device)\n",
840 | " model = GINTopK(num_feature=num_feature, num_class=num_class, nhid=nhid).to(device)\n",
841 | " model.load_state_dict(torch.load(\"{}/model_{}_fold{}_run{}.pth\".format(model_name, model_name, fold, run)))\n",
842 | " test_acc_c[run], test_loss_c[run], test_pred_c, test_out_c, ground_truth_c, test_label_c, test_pred_bi_c = test(model,test_loader)\n",
843 | " print(\"** Run: {:03d}, test loss: {:.5f}, test acc: {:.5f}\".format(run+1,test_loss_c[run],test_acc_c[run]))\n",
844 | " pid_list = list()\n",
845 | " test_data = list([None] * len(test_loader))\n",
846 | " for i, data in enumerate(test_loader):\n",
847 | " pid_temp = data.pid.cpu().numpy()\n",
848 | " gt = data.y.cpu().numpy()\n",
849 | " test_data[i] = [pid_temp,gt,test_pred_c[i]]\n",
850 | " if pid_temp not in pid_list:\n",
851 | " pid_list.append(pid_temp)\n",
852 | " num_test_p = len(pid_list)\n",
853 | " test_pred_1 = np.zeros([num_class,num_test_p],dtype=int)\n",
854 | " pred_p = np.zeros(num_test_p,dtype=int)\n",
855 | " test_label_p = np.zeros(num_test_p,dtype=int)\n",
856 | " pid_test = np.array(pid_list)\n",
857 | " for j in range(num_test_p):\n",
858 | " pid_1 = pid_list[j]\n",
859 | " k = 0\n",
860 | " for i, data in enumerate(test_loader):\n",
861 | " if data.pid.cpu().numpy()==pid_1:\n",
862 | " if k==0:\n",
863 | " test_label_p[j] = data.y.item()\n",
864 | " k = 1\n",
865 | " test_pred_i = int(test_pred_c[i])\n",
866 | " test_pred_1[test_pred_i,j] = test_pred_1[test_pred_i,j] + 1\n",
867 | " pred_p[j] = np.argmax(test_pred_1[:,j])\n",
868 | " # print('j: {}, pred_p[j]: {}, test_pred_p[j]: {}'.format(j,pred_p[j],test_label_p[j]))\n",
869 | " test_acc_p[run] = (pred_p==test_label_p).sum()*1.0/num_test_p\n",
870 | " print(\"Test accuracy at patient level: {:.2f}\".format(test_acc_p[run]*100))\n",
871 | " ## save data\n",
872 | " t1 = time.time()\n",
873 | " print(\"** Model {}, mean test acc (cell): {:.5f}\".format(t1,np.mean(test_acc_c)))\n",
874 | " sv = model_name + '/scdiag_' + model_name + '_fold' + str(fold) + '_runs' + str(runs) + '_run' + str(run) + '_epochs' + str(epochs) + '.mat'\n",
875 | " sio.savemat(sv,mdict={'val_loss':val_loss,'val_acc':val_acc,'test_loss_c':test_loss_c,'test_acc_c':test_acc_c,'train_loss':train_loss,'test_pred_c':test_pred_c,'test_out_c':test_out_c,'ground_truth_c':ground_truth_c,'test_label_c':test_label_c,'test_pred_bi_c':test_pred_bi_c,'test_acc_p':test_acc_p,'test_pred_p':pred_p,'pid_test':pid_test,'test_data':test_data})"
876 | ]
877 | },
878 | {
879 | "cell_type": "code",
880 | "execution_count": 9,
881 | "metadata": {},
882 | "outputs": [
883 | {
884 | "data": {
885 | "image/png": "\n",
886 | "text/plain": [
887 | ""
888 | ]
889 | },
890 | "metadata": {
891 | "needs_background": "light"
892 | },
893 | "output_type": "display_data"
894 | }
895 | ],
896 | "source": [
897 | "for run in range(runs):\n",
898 | " # visualize the loss as the network trained\n",
899 | " fig = plt.figure(figsize=(10,8))\n",
900 | " t_loss = train_loss[run][np.where(train_loss[run] > 0)]\n",
901 | " v_loss = val_loss[run][np.where(val_loss[run] > 0)]\n",
902 | " \n",
903 | " plt.plot(range(1,len(t_loss)+1),t_loss, label='Training Loss')\n",
904 | " plt.plot(range(1,len(v_loss)+1),v_loss,label='Validation Loss')\n",
905 | "\n",
906 | " # find position of lowest validation loss\n",
907 | " #print(np.where(v_loss == np.min(v_loss))[0][0])\n",
908 | " minposs = np.where(v_loss == np.min(v_loss))[0][0] + 1\n",
909 | " plt.axvline(minposs, linestyle='--', color='r',label='Early Stopping Checkpoint')\n",
910 | "\n",
911 | " plt.xlabel('epochs')\n",
912 | " plt.ylabel('loss')\n",
913 | " plt.ylim(0, 1) # consistent scale\n",
914 | " plt.xlim(0, len(v_loss)+1) # consistent scale\n",
915 | " plt.grid(True)\n",
916 | " plt.legend()\n",
917 | " plt.tight_layout()\n",
918 | " plt.show()\n",
919 | " fig.savefig('{}/loss_plot_fold{}_run{}.png'.format(model_name, fold, run), bbox_inches='tight',dpi=400)"
920 | ]
921 | },
922 | {
923 | "cell_type": "code",
924 | "execution_count": null,
925 | "metadata": {},
926 | "outputs": [],
927 | "source": []
928 | },
929 | {
930 | "cell_type": "code",
931 | "execution_count": null,
932 | "metadata": {},
933 | "outputs": [],
934 | "source": []
935 | },
936 | {
937 | "cell_type": "code",
938 | "execution_count": null,
939 | "metadata": {},
940 | "outputs": [],
941 | "source": []
942 | },
943 | {
944 | "cell_type": "code",
945 | "execution_count": null,
946 | "metadata": {},
947 | "outputs": [],
948 | "source": []
949 | },
950 | {
951 | "cell_type": "code",
952 | "execution_count": null,
953 | "metadata": {},
954 | "outputs": [],
955 | "source": []
956 | },
957 | {
958 | "cell_type": "code",
959 | "execution_count": null,
960 | "metadata": {},
961 | "outputs": [],
962 | "source": []
963 | },
964 | {
965 | "cell_type": "code",
966 | "execution_count": null,
967 | "metadata": {},
968 | "outputs": [],
969 | "source": []
970 | }
971 | ],
972 | "metadata": {
973 | "kernelspec": {
974 | "display_name": "CellStar",
975 | "language": "python",
976 | "name": "cellstar"
977 | },
978 | "language_info": {
979 | "codemirror_mode": {
980 | "name": "ipython",
981 | "version": 3
982 | },
983 | "file_extension": ".py",
984 | "mimetype": "text/x-python",
985 | "name": "python",
986 | "nbconvert_exporter": "python",
987 | "pygments_lexer": "ipython3",
988 | "version": "3.7.6"
989 | }
990 | },
991 | "nbformat": 4,
992 | "nbformat_minor": 4
993 | }
994 |
--------------------------------------------------------------------------------
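The patient-level accuracy in the training cell above is obtained by majority vote: each test cell graph carries a patient id (`pid`), the per-graph predictions of one patient are counted per class, and the most frequent class becomes the patient prediction. The notebook does this with a nested loop that re-iterates `test_loader` once per patient. The sketch below shows the same aggregation vectorized with NumPy. It is an illustration under stated assumptions, not the authors' code: the helper name `patient_majority_vote` is hypothetical, and it assumes the per-graph patient ids, labels and predictions have already been collected into 1-D integer arrays. As in the notebook, ties go to the lowest class index (the behaviour of `np.argmax`) and each patient's label is taken from that patient's first graph.

```python
import numpy as np


def patient_majority_vote(pids, labels, preds, num_class):
    """Hypothetical helper: aggregate per-graph predictions into per-patient predictions.

    pids, labels, preds: 1-D integer arrays with one entry per test graph.
    Returns (unique patient ids, patient predictions, patient labels, patient accuracy).
    """
    pids = np.asarray(pids).ravel()
    labels = np.asarray(labels).ravel()
    preds = np.asarray(preds).ravel().astype(int)

    # unique patient ids, index of each patient's first graph,
    # and for every graph the row index of its patient
    uniq_pids, first_idx, inverse = np.unique(pids, return_index=True, return_inverse=True)

    # vote counts: one row per patient, one column per class
    votes = np.zeros((len(uniq_pids), num_class), dtype=int)
    np.add.at(votes, (inverse, preds), 1)

    # majority class per patient; np.argmax picks the lowest index on ties
    pred_p = votes.argmax(axis=1)
    # patient label taken from the patient's first graph, as in the notebook loop
    label_p = labels[first_idx]
    acc_p = float((pred_p == label_p).mean())
    return uniq_pids, pred_p, label_p, acc_p


if __name__ == "__main__":
    # two patients: id 7 has three graphs, id 3 has two
    uniq, pred_p, label_p, acc = patient_majority_vote(
        pids=[7, 7, 7, 3, 3], labels=[1, 1, 1, 0, 0], preds=[1, 2, 1, 0, 2], num_class=3
    )
    print(uniq, pred_p, label_p, acc)
```

`np.add.at` is used instead of `votes[inverse, preds] += 1` because plain fancy-index assignment does not accumulate repeated (patient, class) pairs, whereas `np.add.at` counts every vote. Patients come out sorted by id rather than in first-appearance order, which does not change the accuracy.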