├── LICENSE
├── README.md
├── data
│   ├── STRINGDB.graph.csv.zip
│   └── expr.zip
└── src
    ├── dataset.py
    ├── demo.ipynb
    ├── gen_data.py
    ├── model.py
    ├── scGraph.py
    └── scheduler.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Qijin Yin
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # scGraph
2 | ScGraph is a GNN-based automatic cell identification algorithm that leverages gene interaction relationships to improve the performance of cell type identification.
3 | 
4 | # Requirements
5 | 
6 | - python = 3.6.7
7 | - pytorch = 1.1.0
8 | - pytorch-geometric = 1.3.1
9 | - sklearn
10 | 
11 | # Installation
12 | 
13 | Download scGraph by
14 | 
15 | ```shell
16 | git clone https://github.com/QijinYin/scGraph
17 | ```
18 | 
19 | Installation has been tested on a Linux platform with Python 3.6.
20 | 
21 | # Instructions
22 | 
23 | A demo covering data preprocessing and model training is provided in the ``src/demo.ipynb`` notebook.
24 | 
25 | 
26 | ## Preprocessing data for model training
27 | 
28 | ```shell
29 | python gen_data.py -expr <expr_mat_file> -label <expr_label_file> -net <network_backbone_file> -out <outputfile>
30 | ```
31 | ```
32 | Arguments:
33 |     expr_mat_file: scRNA-seq expression matrix with genes as rows and cells as columns (csv format)
34 |         e.g. EntrezID,barcode1,barcode2,barcode3,barcode4
35 |              5685,1,0,0,0
36 |              5692,0,0,0,0
37 |              6193,0,0,0,1
38 | 
39 |     expr_label_file: cell type assignments (csv format)
40 |         e.g. Barcodes,label
41 |              barcode1,celltype1
42 |              barcode2,celltype1
43 |              barcode3,celltype2
44 |              barcode4,celltype3
45 | 
46 |     network_backbone_file: gene interaction network backbone (csv format)
47 |         e.g. STRING database; columns 1-3 indicate gene1, gene2 and combined_score, respectively. Genes are given as Entrez IDs.
48 |              23521,6193,999
49 |              5692,5685,999
50 |              5591,2547,999
51 |              6222,25873,999
52 | 
53 |     outputfile: preprocessed data for model training (npz format)
54 | 
55 | Options:
56 |     -q    only network edges whose combined score is at or above the q-th score quantile are kept (default: 0.99 for the STRING database)
57 | ```
58 | 
59 | ## Run scGraph model
60 | 
61 | ```shell
62 | python scGraph.py -in <inputfile> -out-dir <outputfolder> -bs <batch_size>
63 | ```
64 | 
65 | ```
66 | Arguments:
67 |     inputfile: preprocessed data for model training (npz format)
68 |     outputfolder: the folder in which prediction results are saved
69 |     batch_size: batch size for model training
70 | ```
71 | 
72 | # License
73 | 
74 | This project is licensed under the MIT License - see the LICENSE file for details.
75 | 
--------------------------------------------------------------------------------
/data/STRINGDB.graph.csv.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QijinYin/scGraph/9d656f54001e299403fd3a71b6ff8e8a7723714a/data/STRINGDB.graph.csv.zip


--------------------------------------------------------------------------------
/data/expr.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QijinYin/scGraph/9d656f54001e299403fd3a71b6ff8e8a7723714a/data/expr.zip


--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | import torch as t
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | import scipy
6 | from copy import deepcopy
7 | import numpy as np
8 | import pandas as pd
9 | import sys
10 | import math
11 | import pickle as pkl
12 | import math
13 | import torch as t
14 | from torch import Tensor
15 | import torch.nn.functional as F
16 | from torch.nn import Linear
17 | from sklearn.metrics import precision_score,f1_score
18 | from torch_geometric.utils import to_undirected,remove_self_loops
19 | from torch.nn.init import xavier_normal_,kaiming_normal_
20 | from torch.nn.init import uniform_,kaiming_uniform_,constant
21 | from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
22 | from torch_geometric.data import Batch,Data
23 | from collections import Counter
24 | from torch.utils import data as tdata
25 | from sklearn.model_selection import StratifiedKFold
26 | 
27 | 
28 | 
29 | def collate_func(batch):
30 |     data0 = batch[0]
31 |     if isinstance(data0,Data):
32 |         tmp_x = [xx['x'] for xx in batch]
33 |         tmp_y = [xx['y'] for xx in batch]
34 |         # tmp_data = Data()
35 |         # tmp_data['x']= t.stack(tmp_x,dim=1)
36 |         # tmp_data['y']= t.cat(tmp_y) #
37 |         # tmp_data['edge_index']=data0.edge_index
38 |         # return Batch.from_data_list([tmp_data])
39 |     elif isinstance(data0,(list,tuple)):
40 |         tmp_x = [xx[0] for xx in batch]
41 |         tmp_y = [xx[1] for xx in batch]
42 | 
43 |     tmp_data = Data()
44 |     tmp_data['x']= t.stack(tmp_x,dim=1)
45 |     tmp_data['y']= t.cat(tmp_y) #
46 |     tmp_data['edge_index']=data0.edge_index
47 |     tmp_data['batch'] = t.zeros_like(tmp_data['y'])
48 |     tmp_data['num_graphs'] = 1
49 |     return tmp_data
50 |     # return Batch.from_data_list([tmp_data])
51 | 
52 | 
53 | class DataLoader(torch.utils.data.DataLoader):
54 |     def __init__(self, dataset, batch_size=1, shuffle=False, follow_batch=[],
55 |                  **kwargs):
56 |         if 'collate_fn' not in kwargs.keys():
57 |             raise ValueError('a collate_fn (e.g. collate_func) must be provided')
58 | 
59 |         super(DataLoader,
60 |               self).__init__(dataset, batch_size, shuffle, **kwargs)
61 | 
62 | 
63 | class ExprDataset(tdata.Dataset):  # must inherit from torch.utils.data.Dataset
64 |     def __init__(self,Expr,edge,y,device='cuda',):
65 |         super(ExprDataset, self).__init__()
66 | 
67 |         print('processing...')
68 |         self.gene_num = Expr.shape[1]
69 |         if isinstance(edge,list):
70 |             print('multi graphs:',len(edge))
71 |             self.edge_num = [x.shape[1] for x in edge]
72 | 
73 |             self.common_edge =[t.tensor(x).long().to(device) if not isinstance(x,t.Tensor) else x for x in edge]
74 | 
75 |         elif isinstance(edge,(np.ndarray,t.Tensor)):
76 |             print('only has 1 graph.')
77 |             self.edge_num = edge.shape[1]
78 |             self.common_edge = edge if isinstance(edge,t.Tensor) else t.tensor(edge).long().to(device)
79 | 
80 | 
81 |         self.Expr = Expr
82 |         self.y = y
83 |         self.num_sam = len(self.y)
84 |         self.sample_mapping_list = np.arange(self.num_sam)
85 | 
86 |         if len(self.Expr.shape) ==2:
87 |             self.num_expr_feaure = 1
88 |         else:
89 |             self.num_expr_feaure = self.Expr.shape[2]
90 | 
91 | 
92 |     def duplicate_minor_types(self,dup_odds=50,random_seed=2240):
93 | 
94 |         counter = Counter(self.y)
95 |         max_num_types = max(counter.values() )
96 |         impute_indexs = np.arange(self.num_sam).tolist()
97 |         np.random.seed(random_seed)
98 |         for lab in np.unique(self.y):
99 |             # print('123,',max_num_types,np.sum(self.y==lab),dup_odds)
100 |             if max_num_types/np.sum(self.y==lab) >dup_odds:
101 |                 impute_size = int(max_num_types/dup_odds) - np.sum(self.y==lab)
102 |                 print('duplicate #celltype %d with %d cells'%(lab,impute_size))
103 |                 # print(impute_size)
104 |                 impute_idx = np.random.choice(np.where(self.y==lab)[0],size=impute_size,replace=True).tolist()
105 |                 impute_indexs += impute_idx
106 |         impute_indexs = np.random.permutation(impute_indexs)
107 |         print('org/imputed #cells:',self.num_sam,len(impute_indexs))
108 |         print('imputed amounts of each cell types',Counter(self.y[impute_indexs]))
109 |         self.num_sam = len(impute_indexs)
110 |         self.sample_mapping_list = impute_indexs
111 | 
112 | 
113 |     def __getitem__(self, idx):
114 |         if isinstance(idx, int):
115 | 
116 |             idx = self.sample_mapping_list[idx]
117 |             data = self.get(idx)
118 |             return data
119 |         raise IndexError(
120 |             'Only integers are valid '
121 |             'indices (got {}).'.format(type(idx).__name__))
122 |         pass
123 | 
124 |     def split(self,idx):
125 |         return ExprDataset(self.Expr[idx,:],self.common_edge,self.y[idx],)
126 | 
127 |     def __len__(self):
128 |         # return the total number of samples (including duplicated minor-type cells)
129 | return self.num_sam 130 | 131 | def get(self,index): 132 | data = Data() 133 | data['x']= t.tensor(self.Expr[index,:].reshape([-1,self.num_expr_feaure])).float() 134 | data['y']= t.tensor(self.y[index].reshape([1,1])).long() # 135 | data['edge_index']=self.common_edge 136 | 137 | # data.to(device) 138 | return data -------------------------------------------------------------------------------- /src/demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "6c86a11e", 6 | "metadata": {}, 7 | "source": [ 8 | "# data process" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "959d6d8b", 14 | "metadata": {}, 15 | "source": [ 16 | "## uncompress tiny data" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 3, 22 | "id": "16379445", 23 | "metadata": { 24 | "ExecuteTime": { 25 | "end_time": "2021-09-22T14:54:51.087055Z", 26 | "start_time": "2021-09-22T14:54:50.968250Z" 27 | } 28 | }, 29 | "outputs": [ 30 | { 31 | "name": "stdout", 32 | "output_type": "stream", 33 | "text": [ 34 | "expr.zip STRINGDB.graph.csv.zip\r\n" 35 | ] 36 | } 37 | ], 38 | "source": [ 39 | "!ls ../data" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 4, 45 | "id": "e8587e20", 46 | "metadata": { 47 | "ExecuteTime": { 48 | "end_time": "2021-09-22T14:57:36.848577Z", 49 | "start_time": "2021-09-22T14:57:35.497939Z" 50 | } 51 | }, 52 | "outputs": [ 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "Archive: ../data/expr.zip\n", 58 | " inflating: ../data/expr.label.subset.csv \n", 59 | " inflating: ../data/expr.mat.subset.csv \n" 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "!unzip ../data/expr.zip -d ../data" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 5, 70 | "id": "4cf91391", 71 | "metadata": { 72 | "ExecuteTime": { 73 | "end_time": "2021-09-22T14:57:57.003171Z", 74 | "start_time": "2021-09-22T14:57:53.328431Z" 75 | } 76 | }, 77 | "outputs": [ 78 | { 79 | "name": "stdout", 80 | "output_type": "stream", 81 | "text": [ 82 | "Archive: ../data/STRINGDB.graph.csv.zip\n", 83 | " inflating: ../data/STRINGDB.graph.csv \n" 84 | ] 85 | } 86 | ], 87 | "source": [ 88 | "!unzip ../data/STRINGDB.graph.csv.zip -d ../data" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "id": "b24151e2", 94 | "metadata": {}, 95 | "source": [ 96 | "## construct dataset" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 8, 102 | "id": "f4d951bb", 103 | "metadata": { 104 | "ExecuteTime": { 105 | "end_time": "2021-09-22T15:21:44.841792Z", 106 | "start_time": "2021-09-22T15:21:13.564764Z" 107 | }, 108 | "scrolled": true 109 | }, 110 | "outputs": [ 111 | { 112 | "name": "stdout", 113 | "output_type": "stream", 114 | "text": [ 115 | "args: Namespace(expr='../data/expr.mat.subset.csv', label='../data/expr.label.subset.csv', net='../../gen_data/sc_data/graph/STRINGDB.graph.csv', outfile='../data/dataset.npz', quantile=0.99)\n", 116 | "shape of expression matrix [#genes,#cells]: (23459, 1000)\n", 117 | "shape of cell labels: 1000\n", 118 | "number of cell types: 5\n", 119 | "shape of backbone network: (133373, 2)\n", 120 | "Finished.\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "! 
python gen_data.py -expr ../data/expr.mat.subset.csv -label ../data/expr.label.subset.csv \\\n", 126 | " -net ../../gen_data/sc_data/graph/STRINGDB.graph.csv -out ../data/dataset.npz" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "id": "9e80e450", 132 | "metadata": {}, 133 | "source": [ 134 | "# model training" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 9, 140 | "id": "8a871115", 141 | "metadata": { 142 | "ExecuteTime": { 143 | "end_time": "2021-09-22T15:23:10.546152Z", 144 | "start_time": "2021-09-22T15:21:44.844759Z" 145 | } 146 | }, 147 | "outputs": [ 148 | { 149 | "name": "stdout", 150 | "output_type": "stream", 151 | "text": [ 152 | "args: Namespace(batch_size=64, cuda=True, infile='../data/dataset.npz', outdir='../results')\n", 153 | "use wegithed cross entropy.... \n", 154 | "processing...\n", 155 | "only has 1 graph.\n", 156 | "processing...\n", 157 | "only has 1 graph.\n", 158 | "processing...\n", 159 | "only has 1 graph.\n", 160 | "org/imputed #cells: 800 800\n", 161 | "imputed amounts of each cell types Counter({4: 160, 0: 160, 1: 160, 3: 160, 2: 160})\n", 162 | "model dropout raito: 0.1\n", 163 | "/home/yinqijin/WorkSpace/9.gnn/scGraph/src/model.py:104: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.\n", 164 | " nn.init.xavier_uniform(m.weight)\n", 165 | "scGraph(\n", 166 | " (conv1): SAGEConv(1, 8)\n", 167 | " (bn1): LayerNorm(torch.Size([23459, 8]), eps=1e-05, elementwise_affine=True)\n", 168 | " (act1): ReLU()\n", 169 | " (global_conv1): Conv2d(8, 12, kernel_size=[1, 1], stride=(1, 1))\n", 170 | " (global_bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 171 | " (global_act1): ReLU()\n", 172 | " (global_conv2): Conv2d(12, 4, kernel_size=[1, 1], stride=(1, 1))\n", 173 | " (global_bn2): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 174 | " (global_act2): ReLU()\n", 175 | " (global_fc_nn): Sequential(\n", 176 | " (0): Linear(in_features=93836, out_features=256, bias=True)\n", 177 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 178 | " (2): Dropout(p=0.3)\n", 179 | " (3): ReLU()\n", 180 | " (4): Linear(in_features=256, out_features=64, bias=True)\n", 181 | " (5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 182 | " (6): Dropout(p=0.3)\n", 183 | " (7): ReLU()\n", 184 | " )\n", 185 | " (fc1): Linear(in_features=64, out_features=5, bias=True)\n", 186 | ")\n", 187 | "/home/yinqijin/Software/anaconda3/lib/python3.6/site-packages/sklearn/metrics/_classification.py:1245: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", 188 | " _warn_prf(average, modifier, msg_start, len(result))\n", 189 | "epoch\t001,lr : 0.005657,loss: 1.433628,T-acc: 0.2420,T-f1: 0.1578\n", 190 | "/home/yinqijin/Software/anaconda3/lib/python3.6/site-packages/sklearn/metrics/_classification.py:1245: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. 
Use `zero_division` parameter to control this behavior.\n", 191 | " _warn_prf(average, modifier, msg_start, len(result))\n", 192 | "epoch\t002,lr : 0.000053,loss: 1.224099,T-acc: 0.4267,T-f1: 0.2706\n", 193 | "epoch\t003,lr : 0.004381,loss: 1.132943,T-acc: 0.8060,T-f1: 0.5439\n", 194 | "epoch\t004,lr : 0.002668,loss: 1.031652,T-acc: 0.9695,T-f1: 0.9676\n", 195 | "epoch\t005,lr : 0.000860,loss: 0.982878,T-acc: 0.9860,T-f1: 0.9856\n", 196 | "epoch\t006,lr : 0.000015,loss: 0.960804,T-acc: 0.9922,T-f1: 0.9921\n", 197 | "epoch\t007,lr : 0.002420,loss: 0.952400,T-acc: 0.9974,T-f1: 0.9974\n", 198 | "epoch\t008,lr : 0.002164,loss: 0.940969,T-acc: 0.9987,T-f1: 0.9987\n", 199 | "epoch\t009,lr : 0.001769,loss: 0.930289,T-acc: 0.9987,T-f1: 0.9987\n", 200 | "epoch\t010,lr : 0.001296,loss: 0.929939,T-acc: 0.9987,T-f1: 0.9987\n", 201 | "epoch\t011,lr : 0.000816,loss: 0.924681,T-acc: 0.9987,T-f1: 0.9987\n", 202 | "epoch\t012,lr : 0.000404,loss: 0.916773,T-acc: 0.9987,T-f1: 0.9987\n", 203 | "epoch\t013,lr : 0.000121,loss: 0.918066,T-acc: 0.9987,T-f1: 0.9987\n", 204 | "epoch\t014,lr : 0.000011,loss: 0.919388,T-acc: 1.0000,T-f1: 1.0000\n", 205 | "epoch\t015,lr : 0.001240,loss: 0.918222,T-acc: 1.0000,T-f1: 1.0000\n", 206 | "stage 2 training...\n", 207 | "processing...\n", 208 | "only has 1 graph.\n", 209 | "processing...\n", 210 | "only has 1 graph.\n", 211 | "org/imputed #cells: 640 640\n", 212 | "imputed amounts of each cell types Counter({2: 128, 3: 128, 1: 128, 0: 128, 4: 128})\n", 213 | "stage2 initilize lr: 0.0012399845236747882\n", 214 | "epoch\t016,lr : 0.001240,loss: 0.920864,T-acc: 1.0000,T-f1: 1.0000,acc: 1.0000,f1: 1.0000\n", 215 | "epoch\t017,lr : 0.001240,loss: 0.916894,T-acc: 1.0000,T-f1: 1.0000,acc: 1.0000,f1: 1.0000\n", 216 | "epoch\t018,lr : 0.001240,loss: 0.915296,T-acc: 1.0000,T-f1: 1.0000,acc: 1.0000,f1: 1.0000\n", 217 | "epoch\t019,lr : 0.001240,loss: 0.919547,T-acc: 1.0000,T-f1: 1.0000,acc: 1.0000,f1: 1.0000\n", 218 | "Epoch 3: reducing learning rate of group 0 to 1.2400e-04.\n", 219 | "reset max_metric_count to 0 due to updating lr from 0.001240 to 0.000124\n", 220 | "epoch\t020,lr : 0.000124,loss: 0.912670,T-acc: 1.0000,T-f1: 1.0000,acc: 1.0000,f1: 1.0000\n", 221 | "epoch\t021,lr : 0.000124,loss: 0.913889,T-acc: 1.0000,T-f1: 1.0000,acc: 1.0000,f1: 1.0000\n", 222 | "epoch\t022,lr : 0.000124,loss: 0.911622,T-acc: 1.0000,T-f1: 1.0000,acc: 1.0000,f1: 1.0000\n", 223 | "Epoch 6: reducing learning rate of group 0 to 1.2400e-05.\n", 224 | "reset max_metric_count to 0 due to updating lr from 0.000124 to 0.000012\n", 225 | "epoch\t023,lr : 0.000012,loss: 0.912611,T-acc: 1.0000,T-f1: 1.0000,acc: 1.0000,f1: 1.0000\n", 226 | "epoch\t024,lr : 0.000012,loss: 0.912024,T-acc: 1.0000,T-f1: 1.0000,acc: 1.0000,f1: 1.0000\n", 227 | "epoch\t025,lr : 0.000012,loss: 0.913034,T-acc: 1.0000,T-f1: 1.0000,acc: 1.0000,f1: 1.0000\n", 228 | "Epoch 9: reducing learning rate of group 0 to 1.0000e-05.\n", 229 | "F1: 0.930,Acc: 0.930\n" 230 | ] 231 | } 232 | ], 233 | "source": [ 234 | "! 
CUDA_VISIBLE_DEVICES=5 python -u scGraph.py -in ../data/dataset.npz -out-dir ../results" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "id": "9df74c50", 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [] 244 | } 245 | ], 246 | "metadata": { 247 | "hide_input": false, 248 | "kernelspec": { 249 | "display_name": "pytorch1.1", 250 | "language": "python", 251 | "name": "pytorch1.1" 252 | }, 253 | "language_info": { 254 | "codemirror_mode": { 255 | "name": "ipython", 256 | "version": 3 257 | }, 258 | "file_extension": ".py", 259 | "mimetype": "text/x-python", 260 | "name": "python", 261 | "nbconvert_exporter": "python", 262 | "pygments_lexer": "ipython3", 263 | "version": "3.6.7" 264 | }, 265 | "toc": { 266 | "base_numbering": 1, 267 | "nav_menu": {}, 268 | "number_sections": true, 269 | "sideBar": true, 270 | "skip_h1_title": false, 271 | "title_cell": "Table of Contents", 272 | "title_sidebar": "Contents", 273 | "toc_cell": false, 274 | "toc_position": {}, 275 | "toc_section_display": true, 276 | "toc_window_display": false 277 | } 278 | }, 279 | "nbformat": 4, 280 | "nbformat_minor": 5 281 | } 282 | -------------------------------------------------------------------------------- /src/gen_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import argparse 4 | from scipy.io import mmread 5 | 6 | 7 | # + 8 | 9 | def get_parser(parser=None): 10 | if parser == None: 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('-expr','--expr',type=str,) 13 | parser.add_argument('-label','--label',type=str,) 14 | parser.add_argument('-net','--net',type=str,) 15 | parser.add_argument('-out','--outfile',type=str,) 16 | parser.add_argument('-q','--quantile',type=float,default='0.99') 17 | return parser 18 | 19 | 20 | 21 | def add_remaining_self_loop_for_edge_df(edge_df, 22 | edge_weight_column = 'score', 23 | fill_value = 1., 24 | num_nodes= None): 25 | 26 | ''' 27 | edge_df : #num_edges x 2 28 | 29 | ''' 30 | assert 'node1' in edge_df.columns 31 | assert 'node2' in edge_df.columns 32 | edge_index = edge_df[['node1','node2']].T.values 33 | row, col = edge_index[0], edge_index[1] 34 | N = num_nodes if num_nodes is not None else np.max(edge_index)+1 35 | 36 | mask = row == col 37 | added_index = list(set( np.arange(0, N, dtype=int ))-set(row[mask])) 38 | 39 | new_df = pd.DataFrame() 40 | new_df['node1'] = added_index 41 | new_df['node2'] = added_index 42 | 43 | if edge_weight_column in edge_df.columns: 44 | new_df[edge_weight_column] = fill_value 45 | 46 | edge_df = edge_df.append(new_df, ignore_index=True) 47 | return edge_df 48 | 49 | 50 | def coding_edge_with_ref_gene_idx(converted_graph_gene,converted_expr_gene): 51 | ''' 52 | converted_graph_gene: #2 Dim 53 | converted_expr_gene: #1 Dim 54 | convert graph_gene to index which is the order of converted_expr_gene and drop nan 55 | ''' 56 | graph_edge_df = pd.DataFrame(converted_graph_gene,columns=['node1','node2']) 57 | gene_to_index_dict = {g:idx for idx,g in enumerate(converted_expr_gene)} 58 | 59 | mapfunc = lambda x: gene_to_index_dict.get(x,np.nan) 60 | graph_edge_df = graph_edge_df.applymap(mapfunc) 61 | graph_edge_df = graph_edge_df.dropna(axis='index') 62 | 63 | # print(graph_edge_df.shape) 64 | graph_edge_df = add_remaining_self_loop_for_edge_df(graph_edge_df,edge_weight_column='score',fill_value=1,num_nodes=len(converted_expr_gene)) 65 | # print(graph_edge_df.shape) 66 | return 
graph_edge_df.values 67 | 68 | 69 | 70 | if __name__ == '__main__': 71 | parser = get_parser() 72 | args = parser.parse_args() 73 | print('args:',args) 74 | 75 | expr_file = args.expr 76 | label_file = args.label 77 | net_file = args.net 78 | thres = args.quantile 79 | save_file = args.outfile 80 | assert 0<=thres<=1,"quantile should be a float value in [0,1]." 81 | 82 | 83 | data_df = pd.read_csv(expr_file,header=0,index_col=0) 84 | label_df = pd.read_csv(label_file,header=0,index_col=0) 85 | 86 | graph_df = pd.read_csv(net_file,header=None,index_col=None,) 87 | graph_df.columns = ['node1', 'node2', 'score'] 88 | graph_df = graph_df.loc[graph_df.score.ge(graph_df.score.quantile(0.99)).values,['node1','node2']] # quantile 0.99 89 | 90 | 91 | # normalize + log1p transform for read counts 92 | data_df = data_df.apply(lambda x: 1e6* x/x.sum()+1e-5,axis=0) 93 | data_df = data_df.applymap(np.log1p) 94 | 95 | 96 | str_labels = np.unique(label_df.values).tolist() 97 | label = [str_labels.index(x) for x in label_df.values ] 98 | gene = data_df.index.values 99 | barcode = data_df.columns.values 100 | edge_index = coding_edge_with_ref_gene_idx(graph_df.values,gene) 101 | 102 | print('shape of expression matrix [#genes,#cells]:',data_df.shape) 103 | print('shape of cell labels:',len(label)) 104 | print('number of cell types:',len(str_labels)) 105 | print('shape of backbone network:',edge_index.shape) 106 | 107 | 108 | data_dict = {} 109 | data_dict['barcode'] = barcode 110 | data_dict['gene'] = gene 111 | data_dict['logExpr'] = data_df.values 112 | data_dict['str_labels'] = str_labels 113 | data_dict['label'] = label 114 | data_dict['edge_index'] = edge_index 115 | 116 | np.savez(save_file,**data_dict) 117 | 118 | print('Finished.') -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | import scipy 6 | from copy import deepcopy 7 | import numpy as np 8 | import pandas as pd 9 | import sys 10 | import math 11 | import pickle as pkl 12 | import math 13 | import torch as t 14 | from torch import Tensor 15 | import torch.nn.functional as F 16 | from torch.nn import Linear 17 | from sklearn.metrics import precision_score,f1_score 18 | from torch_geometric.utils import to_undirected,remove_self_loops 19 | from torch.nn.init import xavier_normal_,kaiming_normal_ 20 | from torch.nn.init import uniform_,kaiming_uniform_,constant 21 | from torch_geometric.utils import remove_self_loops, add_self_loops, softmax 22 | from torch_geometric.data import Batch,Data 23 | from collections import Counter 24 | from torch.utils import data as tdata 25 | from sklearn.model_selection import StratifiedKFold 26 | 27 | 28 | 29 | import torch 30 | import torch.nn.functional as F 31 | from torch.nn import Parameter 32 | from torch_geometric.nn.conv import MessagePassing 33 | from torch_geometric.utils import add_remaining_self_loops 34 | import math 35 | def uniform(size, tensor): 36 | bound = 1.0 / math.sqrt(size) 37 | if tensor is not None: 38 | tensor.data.uniform_(-bound, bound) 39 | 40 | class SAGEConv(MessagePassing): 41 | def __init__(self, in_channels, out_channels, normalize=False, bias=True,activate=False,alphas=[0,1],shared_weight=False,aggr = 'mean', 42 | **kwargs): 43 | super(SAGEConv, self).__init__(aggr=aggr, **kwargs) 44 | self.shared_weight = shared_weight 45 | self.activate = activate 46 
| self.in_channels = in_channels 47 | self.out_channels = out_channels 48 | self.normalize = normalize 49 | 50 | self.weight = Parameter(torch.Tensor(self.in_channels, out_channels)) 51 | if self.shared_weight: 52 | self.self_weight = self.weight 53 | else: 54 | self.self_weight = Parameter(torch.Tensor(self.in_channels, out_channels)) 55 | self.alphas = alphas #[self_alpha, pro_alpha] 56 | if bias: 57 | self.bias = Parameter(torch.Tensor(out_channels)) 58 | else: 59 | self.register_parameter('bias', None) 60 | 61 | self.reset_parameters() 62 | 63 | def reset_parameters(self): 64 | uniform(self.in_channels, self.weight) 65 | uniform(self.in_channels, self.bias) 66 | uniform(self.in_channels,self.self_weight) 67 | 68 | def forward(self, x, edge_index, edge_weight=None, size=None): 69 | 70 | out = torch.matmul(x,self.self_weight ) 71 | out2 = self.propagate(edge_index, size=size, x=x, 72 | edge_weight=edge_weight) 73 | return self.alphas[0]*out+ self.alphas[1]* out2 74 | 75 | def message(self, x_j, edge_weight): 76 | return x_j if edge_weight is None else edge_weight.view(-1, 1,1) * x_j 77 | 78 | def update(self, aggr_out): 79 | 80 | if self.activate: 81 | aggr_out = F.relu(aggr_out) 82 | 83 | if torch.is_tensor(aggr_out): 84 | aggr_out = torch.matmul(aggr_out,self.weight ) 85 | else: 86 | aggr_out = (None if aggr_out[0] is None else torch.matmul(aggr_out[0], self.weight), 87 | None if aggr_out[1] is None else torch.matmul(aggr_out[1], self.weight)) 88 | if self.bias is not None: 89 | aggr_out = aggr_out + self.bias 90 | if self.normalize: 91 | aggr_out = F.normalize(aggr_out, p=2, dim=-1) 92 | return aggr_out 93 | 94 | def __repr__(self): 95 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 96 | self.out_channels) 97 | 98 | 99 | 100 | from torch.nn.init import normal,uniform_ 101 | def init_weights(m): 102 | if type(m) == nn.Linear: 103 | # nn.init.kaiming_normal_(m.weight, mode='fan_out') 104 | nn.init.xavier_uniform(m.weight) 105 | m.bias.data.fill_(0.01) 106 | 107 | def help_bn(bn1,x): 108 | # dim1, dim2,dim3 = x.shape 109 | x = x.permute(1,0,2) # #samples x #nodes x #features 110 | x = bn1(x) 111 | x = x.permute(1,0,2) # #nodes x #samples x #features 112 | return x 113 | 114 | class scGraph(nn.Module): 115 | def __init__(self,in_channel=1,mid_channel=8,out_channel=2,num_nodes=2207,edge_num=151215, 116 | **args): 117 | super(scGraph,self).__init__() 118 | self.mid_channel = mid_channel 119 | self.dropout_ratio = args.get('dropout_ratio',0.3) 120 | print('model dropout raito:',self.dropout_ratio) 121 | n_out_nodes = num_nodes 122 | self.global_conv1_dim = 4*3 123 | self.global_conv2_dim = args.get('global_conv2_dim',4) 124 | self.conv1 = SAGEConv(in_channel, mid_channel, ) 125 | self.bn1 = torch.nn.LayerNorm((num_nodes,mid_channel)) 126 | self.act1 = nn.ReLU() 127 | 128 | self.global_conv1 = t.nn.Conv2d(mid_channel*1,self.global_conv1_dim,[1,1]) 129 | self.global_bn1 = torch.nn.BatchNorm2d(self.global_conv1_dim) 130 | self.global_act1 = nn.ReLU() 131 | self.global_conv2 = t.nn.Conv2d(self.global_conv1_dim,self.global_conv2_dim,[1,1]) 132 | self.global_bn2 = torch.nn.BatchNorm2d(self.global_conv2_dim) 133 | self.global_act2 = nn.ReLU() 134 | 135 | 136 | last_feature_node = 64 137 | channel_list = [self.global_conv2_dim*n_out_nodes,256,64] 138 | if args.get('channel_list',False): 139 | channel_list = [self.global_conv2_dim*n_out_nodes,128] 140 | last_feature_node = 128 141 | 142 | self.nn = [] 143 | for idx,num in enumerate(channel_list[:-1]): 144 | 
self.nn.append(nn.Linear(channel_list[idx],channel_list[idx+1])) 145 | self.nn.append(nn.BatchNorm1d(channel_list[idx+1])) 146 | if self.dropout_ratio >0: 147 | self.nn.append(nn.Dropout(0.3)) 148 | self.nn.append(nn.ReLU()) 149 | self.global_fc_nn =nn.Sequential(*self.nn) 150 | self.fc1 = nn.Linear(last_feature_node,out_channel) 151 | 152 | self.edge_num = edge_num 153 | self.weight_edge_flag = True 154 | # print('trainalbe edges :',self.weight_edge_flag) 155 | if self.weight_edge_flag: 156 | self.edge_weight = nn.Parameter(t.ones(edge_num).float()*0.01) 157 | # _=normal(self.edge_weight,mean=0,std=0.01) 158 | # _ = uniform_(self.edge_weight,a=-2,b=2) 159 | else: 160 | self.edge_weight = None 161 | 162 | self.reset_parameters() 163 | 164 | def reset_parameters(self,): 165 | # uniform(self.mid_channel, self.global_conv1.weight 166 | # 167 | self.conv1.apply(init_weights) 168 | nn.init.kaiming_normal_(self.global_conv1.weight, mode='fan_out') 169 | uniform(self.mid_channel, self.global_conv1.bias) 170 | 171 | 172 | nn.init.kaiming_normal_(self.global_conv2.weight, mode='fan_out') 173 | uniform(self.global_conv1_dim, self.global_conv2.bias) 174 | 175 | self.global_fc_nn.apply(init_weights) 176 | self.fc1.apply(init_weights) 177 | # uniform(self.in_channels, self.bias) 178 | # uniform(self.in_channels,self.self_weight) 179 | pass 180 | 181 | def get_gcn_weight_penalty(self,mode='L2'): 182 | 183 | if mode == 'L1': 184 | func = lambda x: t.sum(t.abs(x)) 185 | elif mode =='L2': 186 | func = lambda x: t.sqrt(t.sum(x**2)) 187 | 188 | loss = 0 189 | 190 | tmp = getattr(self.conv1,'weight',None) 191 | if tmp is not None: 192 | loss += func(tmp) 193 | 194 | tmp = getattr(self.conv1,'self_weight',None) 195 | if tmp is not None: 196 | loss += 1* func(tmp) 197 | 198 | tmp = getattr(self.global_conv1,'weight',None) 199 | if tmp is not None: 200 | loss += func(tmp) 201 | tmp = getattr(self.global_conv2,'weight',None) 202 | if tmp is not None: 203 | loss += func(tmp) 204 | 205 | return loss 206 | 207 | 208 | def forward(self,data,get_latent_varaible=False): 209 | x, edge_index, batch = data.x, data.edge_index, data.batch 210 | 211 | if self.weight_edge_flag: 212 | one_graph_edge_weight=torch.sigmoid(self.edge_weight)#*self.edge_num 213 | edge_weight = one_graph_edge_weight 214 | else: 215 | edge_weight = None 216 | 217 | x = self.act1(self.conv1(x, edge_index,edge_weight=edge_weight)) 218 | x = help_bn(self.bn1,x) 219 | if self.dropout_ratio >0: x = F.dropout(x, p=0.1, training=self.training) 220 | 221 | 222 | x = x.permute(1,2,0) # #samples x #features x #nodes 223 | x = x.unsqueeze(dim=-1) # #samples x #features x #nodes x 1 224 | x = self.global_conv1(x) # #samples x #features x #nodes x 1 225 | x = self.global_act1(x) 226 | x = self.global_bn1(x) 227 | if self.dropout_ratio >0: x = F.dropout(x, p=0.3, training=self.training) 228 | x = self.global_conv2(x) 229 | x = self.global_act1(x) 230 | x = self.global_bn2(x) 231 | if self.dropout_ratio >0: x = F.dropout(x, p=0.3, training=self.training) 232 | x = x.squeeze(dim=-1) # #samples x #features x #nodes 233 | num_samples = x.shape[0] 234 | 235 | x = x .view(num_samples,-1) 236 | x = self.global_fc_nn(x) 237 | if get_latent_varaible: 238 | return x 239 | else: 240 | x = self.fc1(x) 241 | 242 | return F.softmax(x, dim=-1) 243 | 244 | 245 | 246 | def edge_transform_func(org_edge): 247 | edge = org_edge 248 | edge = t.tensor(edge.T) 249 | edge = remove_self_loops(edge)[0] 250 | edge = edge.numpy() 251 | return edge 
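The dataset, collate function and model above fit together in a slightly unusual way: genes are the graph nodes, every cell shares the same backbone `edge_index`, and `collate_func` stacks the cells of a mini-batch along dimension 1 instead of building a standard PyG `Batch`. The sketch below is a minimal smoke test of that wiring; it is not part of the repository, assumes it is run from `src/` with the pinned torch / torch-geometric versions installed, and uses a made-up random expression matrix, label vector and edge list purely for illustration.

```python
import numpy as np
import torch as t

from dataset import ExprDataset, DataLoader, collate_func
from model import scGraph, edge_transform_func

num_genes, num_cells, num_classes = 50, 32, 3
expr = np.random.rand(num_cells, num_genes).astype(np.float32)   # cells x genes (already log-scaled)
labels = np.random.randint(0, num_classes, size=num_cells)

# toy gene-gene pairs; edge_transform_func transposes to [2, E] and drops self-loops
edges = edge_transform_func(np.random.randint(0, num_genes, size=(200, 2)))

dataset = ExprDataset(Expr=expr, edge=edges, y=labels, device='cpu')
loader = DataLoader(dataset, batch_size=8, shuffle=True,
                    collate_fn=collate_func, drop_last=True)

model = scGraph(in_channel=dataset.num_expr_feaure,   # one expression value per gene
                mid_channel=8,
                out_channel=num_classes,
                num_nodes=dataset.gene_num,
                edge_num=dataset.edge_num,
                dropout_ratio=0.1)

batch = next(iter(loader))
out = model(batch)
# batch.x is stacked as [#genes, #cells in batch, #features]; the output is one
# softmax distribution over cell types per cell in the batch.
print(batch.x.shape, out.shape)   # e.g. torch.Size([50, 8, 1]) torch.Size([8, 3])
```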
-------------------------------------------------------------------------------- /src/scGraph.py: -------------------------------------------------------------------------------- 1 | import os 2 | # os.environ ['CUDA_VISIBLE_DEVICES']='0' 3 | import torch as t 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | import scipy 8 | from copy import deepcopy 9 | import numpy as np 10 | import pandas as pd 11 | import sys 12 | import math 13 | import pickle as pkl 14 | import math 15 | import torch as t 16 | from torch import Tensor 17 | import torch.nn.functional as F 18 | from torch.nn import Linear 19 | from sklearn.metrics import precision_score,f1_score 20 | from torch_geometric.utils import to_undirected,remove_self_loops 21 | from torch.nn.init import xavier_normal_,kaiming_normal_ 22 | from torch.nn.init import uniform_,kaiming_uniform_,constant 23 | from torch_geometric.utils import remove_self_loops, add_self_loops, softmax 24 | from torch_geometric.data import Batch,Data 25 | from collections import Counter 26 | from torch.utils import data as tdata 27 | from sklearn.model_selection import StratifiedKFold 28 | import argparse 29 | 30 | 31 | from dataset import (collate_func,DataLoader,ExprDataset) 32 | from scheduler import CosineAnnealingWarmRestarts 33 | from model import (scGraph,edge_transform_func) 34 | 35 | pathjoin = os.path.join 36 | def get_parser(parser=None): 37 | if parser == None: 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('-in','--infile',type=str,) 40 | parser.add_argument('-out-dir','--outdir',type=str,default='../results') 41 | parser.add_argument('-cuda','--cuda',type=bool,default=True) 42 | parser.add_argument('-bs','--batch_size',type=int,default=64) 43 | return parser 44 | 45 | def train2(model,optimizer,train_loader,epoch,device,loss_fn =None,scheduler =None,verbose=False ): 46 | model.train() 47 | 48 | loss_all = 0 49 | iters = len(train_loader) 50 | for idx,data in enumerate( train_loader): 51 | # print('epoch,idx',epoch,idx) 52 | # data.x = data.x + t.rand_like(data.x)*1e-5 53 | data = data.to(device) 54 | if verbose: 55 | print(data.y.shape,data.edge_index.shape) 56 | optimizer.zero_grad() 57 | output = model(data) 58 | if loss_fn is None: 59 | loss = F.cross_entropy(output, data.y.reshape(-1), weight=None,) 60 | else: 61 | loss = loss_fn(output, data.y.reshape(-1)) 62 | 63 | if model.edge_weight is not None: 64 | l2_loss = 0 65 | if isinstance(model.edge_weight,nn.Module): # nn.ParamterList 66 | for edge_weight in model.edge_weight : 67 | l2_loss += 0.1* t.mean((edge_weight)**2) 68 | elif isinstance(model.edge_weight,t.Tensor): 69 | l2_loss =0.1* t.mean((model.edge_weight)**2) 70 | # print(loss.cpu().detach().numpy(),l2_loss.cpu().detach().numpy()) 71 | loss+=l2_loss 72 | 73 | 74 | 75 | loss.backward() 76 | loss_all += loss.item() * data.num_graphs 77 | optimizer.step() 78 | 79 | if not (scheduler is None): 80 | scheduler.step( (epoch -1) + idx/iters) # let "epoch" begin from 0 81 | 82 | return loss_all / iters # len(train_dataset) 83 | 84 | def test2(model,loader,predicts=False): 85 | model.eval() 86 | 87 | correct = 0 88 | y_pred =[] 89 | y_true=[] 90 | y_output=[] 91 | for data in loader: 92 | data = data.to(device) 93 | # print(data.y.shape) 94 | output = model(data) 95 | pred = output.max(dim=1)[1].cpu().data.numpy() 96 | y = data.y.cpu().data.numpy() 97 | y_pred.extend(pred) 98 | y_true.extend(y) 99 | y_output.extend(output.cpu().data.numpy()) 100 | 101 | acc = 
precision_score(y_true,y_pred,average='macro') 102 | f1 = f1_score(y_true,y_pred,average='macro') 103 | if predicts: 104 | return acc,f1,y_true,np.array(y_pred),y_output 105 | else: 106 | return acc,f1 107 | 108 | 109 | if __name__ == '__main__': 110 | 111 | 112 | parser = get_parser() 113 | args = parser.parse_args() 114 | print('args:',args) 115 | 116 | cuda_flag = args.cuda 117 | npz_file = args.infile 118 | save_folder = args.outdir 119 | batch_size = args.batch_size 120 | 121 | device = torch.device('cuda' if torch.cuda.is_available() and cuda_flag else 'cpu') 122 | 123 | prob_file = pathjoin(save_folder,'predicted_probabilities.txt') 124 | pred_file = pathjoin(save_folder,'predicted_label.txt') 125 | true_file = pathjoin(save_folder,'true_label.txt') 126 | os.makedirs(pathjoin(save_folder,'models'),exist_ok=True) 127 | 128 | data= np.load(npz_file,allow_pickle=True) 129 | logExpr = data['logExpr'].T # logExpr: row-cell, column-gene 130 | label = data['label'] 131 | str_labels = data['str_labels'] 132 | used_edge = edge_transform_func(data['edge_index'],) 133 | 134 | num_samples = logExpr.shape[0] 135 | 136 | 137 | 138 | init_lr =0.01 139 | min_lr = 0.00001 140 | max_epoch= 16 141 | # batch_size = 64 142 | weight_decay = 1e-4 143 | dropout_ratio = 0.1 144 | 145 | print('use wegithed cross entropy.... ') 146 | label_type = np.unique(label.reshape(-1)) 147 | alpha = np.array([ np.sum(label == x) for x in label_type]) 148 | alpha = np.max(alpha) / alpha 149 | alpha = np.clip(alpha,1,50) 150 | alpha = alpha/ np.sum(alpha) 151 | loss_fn = t.nn.CrossEntropyLoss(weight = t.tensor(alpha).float()) 152 | loss_fn = loss_fn.to(device) 153 | 154 | 155 | dataset = ExprDataset(Expr=logExpr,edge=used_edge,y=label,device=device) 156 | gene_num = dataset.gene_num 157 | class_num = len(np.unique(label)) 158 | 159 | 160 | kf = StratifiedKFold(n_splits=5,shuffle=True) 161 | for tr, ts in kf.split(X=label,y=label): 162 | train_index = tr 163 | test_index = ts 164 | 165 | train_dataset = dataset.split(t.tensor(train_index).long()) 166 | test_dataset = dataset.split(t.tensor(test_index).long()) 167 | # add more samples for those small celltypes 168 | train_dataset.duplicate_minor_types(dup_odds=50) 169 | 170 | 171 | num_workers = 0 172 | assert num_workers == 0 173 | train_loader = DataLoader(train_dataset, batch_size=batch_size,num_workers=num_workers, shuffle=True,collate_fn = collate_func,drop_last=True) 174 | test_loader = DataLoader(test_dataset, batch_size=1,num_workers=num_workers,collate_fn = collate_func) 175 | 176 | model = scGraph(in_channel = dataset.num_expr_feaure , num_nodes=gene_num, 177 | out_channel=class_num,edge_num=dataset.edge_num, 178 | dropout_ratio = dropout_ratio, 179 | ).to(device) 180 | 181 | print(model) 182 | 183 | optimizer = torch.optim.Adam(model.parameters(),lr=init_lr ,weight_decay=weight_decay,) 184 | scheduler = CosineAnnealingWarmRestarts(optimizer,2, 2, eta_min=min_lr, lr_max_decay=0.5) 185 | max_metric = float(0) 186 | max_metric_count = 0 187 | weights_list = [] 188 | 189 | for epoch in range(1, max_epoch): 190 | train_loss = train2(model,optimizer,train_loader,epoch,device,loss_fn,scheduler =scheduler ) 191 | train_acc,train_f1= test2(model,train_loader,predicts=False) 192 | lr = optimizer.param_groups[0]['lr'] 193 | print('epoch\t%03d,lr : %.06f,loss: %.06f,T-acc: %.04f,T-f1: %.04f'%( 194 | epoch,lr,train_loss,train_acc,train_f1)) 195 | 196 | 197 | #stage two 198 | extend_epoch = 50 199 | print('stage 2 training...') 200 | from sklearn.model_selection import 
StratifiedShuffleSplit 201 | sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0) 202 | for final_, valid_ in sss.split(dataset.y[train_index],dataset.y[train_index]): 203 | train_index2,valid_index2 =train_index[final_],train_index[valid_] 204 | valid_dataset = dataset.split(t.tensor(valid_index2).long()) 205 | train_dataset = dataset.split(t.tensor(train_index2).long()) 206 | if True: 207 | # add more samples for those small celltypes 208 | train_dataset.duplicate_minor_types(dup_odds=50) 209 | 210 | train_loader = DataLoader(train_dataset, batch_size=batch_size,num_workers=num_workers, shuffle=True,collate_fn = collate_func,drop_last=True) 211 | valid_loader = DataLoader(valid_dataset, batch_size=batch_size,num_workers=num_workers,shuffle=True,collate_fn = collate_func) 212 | lr = optimizer.param_groups[0]['lr'] 213 | print('stage2 initilize lr:',lr) 214 | 215 | max_metric = float(0) 216 | max_metric_count = 0 217 | optimizer = torch.optim.Adam(model.parameters(),lr=lr ,weight_decay=weight_decay,) 218 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max',factor=0.1, patience=2, verbose=True,min_lr=0.00001) 219 | old_lr = lr 220 | for epoch_idx,epoch in enumerate(range(max_epoch,(max_epoch+extend_epoch))): 221 | if old_lr != lr: 222 | max_metric_count = 0 223 | print('reset max_metric_count to 0 due to updating lr from %f to %f'%(old_lr,lr)) 224 | old_lr = lr 225 | 226 | train_loss = train2(model,optimizer,train_loader,epoch,device,loss_fn,verbose=False ) 227 | train_acc,train_f1= test2(model,train_loader,predicts=False) 228 | valid_acc,valid_f1= test2(model,valid_loader,predicts=False) 229 | 230 | lr = optimizer.param_groups[0]['lr'] 231 | print('epoch\t%03d,lr : %.06f,loss: %.06f,T-acc: %.04f,T-f1: %.04f,acc: %.04f,f1: %.04f'%(epoch, 232 | lr,train_loss,train_acc,train_f1,valid_acc,valid_f1)) 233 | scheduler.step(valid_f1) 234 | lr = optimizer.param_groups[0]['lr'] 235 | 236 | if valid_f1 >max_metric: 237 | max_metric=valid_f1 238 | tmp_file = pathjoin(save_folder,'models','model.pth') 239 | weights_list.append(tmp_file) 240 | t.save(model,tmp_file) 241 | max_metric_count=0 242 | max_metric=valid_f1 243 | else: 244 | if epoch_idx >=2: #ignore first two epochs 245 | max_metric_count+=1 246 | if max_metric_count >3: 247 | print('break at epoch',epoch) 248 | break 249 | 250 | if lr <= 0.00001: 251 | break 252 | 253 | test_acc,test_f1,y_true,y_pred,y_output = test2(model,test_loader,predicts=True) 254 | print('F1: %.03f,Acc: %.03f'%(test_acc,test_f1)) 255 | 256 | np.savetxt(prob_file,y_output,) 257 | t.save(model,pathjoin(save_folder,'models','final_model.pth')) 258 | np.savetxt(pred_file,[str_labels[x] for x in np.reshape(y_pred,-1)],fmt="%s") 259 | np.savetxt(true_file,[str_labels[x] for x in np.reshape(y_true,-1)],fmt="%s") 260 | -------------------------------------------------------------------------------- /src/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | class CosineAnnealingWarmRestarts(_LRScheduler): 4 | r"""Set the learning rate of each parameter group using a cosine annealing 5 | schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` 6 | is the number of epochs since the last restart and :math:`T_{i}` is the number 7 | of epochs between two warm restarts in SGDR: 8 | .. 
math:: 9 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 + 10 | \cos(\frac{T_{cur}}{T_{i}}\pi)) 11 | When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. 12 | When :math:`T_{cur}=0`(after restart), set :math:`\eta_t=\eta_{max}`. 13 | It has been proposed in 14 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_. 15 | Args: 16 | optimizer (Optimizer): Wrapped optimizer. 17 | T_0 (int): Number of iterations for the first restart. 18 | T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1. 19 | eta_min (float, optional): Minimum learning rate. Default: 0. 20 | last_epoch (int, optional): The index of last epoch. Default: -1. 21 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: 22 | https://arxiv.org/abs/1608.03983 23 | """ 24 | 25 | def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1,lr_max_decay=0.9): 26 | if T_0 <= 0 or not isinstance(T_0, int): 27 | raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) 28 | if T_mult < 1 or not isinstance(T_mult, int): 29 | raise ValueError("Expected integer T_mul >= 1, but got {}".format(T_mul)) 30 | self.T_0 = T_0 31 | self.T_i = T_0 32 | self.T_mult = T_mult 33 | self.eta_min = eta_min 34 | self.lr_max_decay = lr_max_decay 35 | self.lr_max_cum_decay = 1 36 | super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch) 37 | self.T_cur = last_epoch 38 | self.cur_n = 0 39 | 40 | 41 | def get_lr(self): 42 | return [self.eta_min + (self.lr_max_cum_decay * base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2 43 | for base_lr in self.base_lrs] 44 | 45 | def step(self, epoch=None): 46 | """Step could be called after every update, i.e. if one epoch has 10 iterations 47 | (number_of_train_examples / batch_size), we should call SGDR.step(0.1), SGDR.step(0.2), etc. 48 | This function can be called in an interleaved way. 49 | Example: 50 | >>> scheduler = SGDR(optimizer, T_0, T_mult) 51 | >>> for epoch in range(20): 52 | >>> scheduler.step() 53 | >>> scheduler.step(26) 54 | >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) 55 | """ 56 | if epoch is None: 57 | epoch = self.last_epoch + 1 58 | self.T_cur = self.T_cur + 1 59 | if self.T_cur >= self.T_i: 60 | self.T_cur = self.T_cur - self.T_i 61 | self.T_i = self.T_i * self.T_mult 62 | else: 63 | if epoch >= self.T_0: 64 | if self.T_mult == 1: 65 | self.T_cur = epoch % self.T_0 66 | else: 67 | n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult)) 68 | if n!= self.cur_n : 69 | #print('diff',self.cur_n,n) 70 | self.cur_n = n 71 | self.lr_max_cum_decay *= self.lr_max_decay 72 | self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1) 73 | self.T_i = self.T_0 * self.T_mult ** (n) 74 | else: 75 | self.T_i = self.T_0 76 | self.T_cur = epoch 77 | self.last_epoch = math.floor(epoch) 78 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 79 | param_group['lr'] = lr --------------------------------------------------------------------------------
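
For reference, the way `scGraph.py` drives this scheduler can be reproduced in isolation. The snippet below is a minimal sketch (not part of the repository) that only traces the learning-rate curve with a dummy model, assuming the pinned PyTorch version (1.1.0); as in `train2`, `scheduler.step()` is called once per mini-batch with a fractional epoch, and `lr_max_decay` shrinks the peak learning rate after each warm restart.

```python
import torch
from scheduler import CosineAnnealingWarmRestarts

model = torch.nn.Linear(4, 2)                      # stand-in for the real scGraph model
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# same settings as scGraph.py: T_0=2, T_mult=2, eta_min=min_lr, halve the peak lr per restart
scheduler = CosineAnnealingWarmRestarts(optimizer, 2, 2, eta_min=1e-5, lr_max_decay=0.5)

iters_per_epoch = 10                               # pretend every epoch has 10 mini-batches
for epoch in range(1, 16):
    for idx in range(iters_per_epoch):
        # ... forward / backward / optimizer.step() would happen here ...
        scheduler.step((epoch - 1) + idx / iters_per_epoch)   # fractional epoch, as in train2()
    print('epoch %02d  lr %.6f' % (epoch, optimizer.param_groups[0]['lr']))
```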