├── .gitignore ├── README.md ├── Untitled.ipynb ├── analysis ├── fig2a.py ├── fig2b.py ├── fig2c.py ├── fig2d.py ├── fig3a.py ├── fig3b.py ├── fig3c.py ├── fig3d.py └── load_dataset.py ├── gnn.py ├── gnn_mini_batch.py ├── logger.py ├── models ├── __init__.py ├── gat.py ├── gat_neighsampler.py ├── gcn.py ├── mlp.py ├── rgcn.py ├── sage.py └── sage_neighsampler.py ├── run_tgat.py └── utils ├── __init__.py ├── dgraphfin.py ├── evaluator.py ├── tricks.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | dataset 2 | model_results 3 | tmp.ipynb 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 109 | __pypackages__/ 110 | 111 | # Celery stuff 112 | celerybeat-schedule 113 | celerybeat.pid 114 | 115 | # SageMath parsed files 116 | *.sage.py 117 | 118 | # Environments 119 | .env 120 | .venv 121 | env/ 122 | venv/ 123 | ENV/ 124 | env.bak/ 125 | venv.bak/ 126 | 127 | # Spyder project settings 128 | .spyderproject 129 | .spyproject 130 | 131 | # Rope project settings 132 | .ropeproject 133 | 134 | # mkdocs documentation 135 | /site 136 | 137 | # mypy 138 | .mypy_cache/ 139 | .dmypy.json 140 | dmypy.json 141 | 142 | # Pyre type checker 143 | .pyre/ 144 | 145 | # pytype static type analyzer 146 | .pytype/ 147 | 148 | # Cython debug symbols 149 | cython_debug/ 150 | 151 | # PyCharm 152 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 153 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 154 | # and can be added to the global gitignore or merged into this file. For a more nuclear 155 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 156 | #.idea/ 157 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repo provides a collection of baselines for DGraphFin dataset. Please download the dataset from the [DGraph](http://dgraph.xinye.com) web and place it under the folder './dataset/DGraphFin/raw'. 2 | 3 | ## Environments 4 | Implementing environment: 5 | - numpy = 1.21.2 6 | - pytorch = 1.6.0 7 | - torch_geometric = 1.7.2 8 | - torch_scatter = 2.0.8 9 | - torch_sparse = 0.6.9 10 | 11 | - GPU: Tesla V100 32G 12 | 13 | 14 | ## Training 15 | - **TGAT** 16 | ```bash 17 | python run_tgat.py 18 | ``` 19 | 20 | - **RGCN** 21 | ```bash 22 | python gnn.py --model rgcn --dataset DGraphFin --epochs 400 --runs 10 --device 1 --MV_trick=‘null’ --BN_trick='hetro' --BN_ratio 1.0 23 | ``` 24 | 25 | - **MLP** 26 | ```bash 27 | python gnn.py --model mlp --dataset DGraphFin --epochs 200 --runs 10 --device 0 28 | ``` 29 | 30 | - **GCN** 31 | ```bash 32 | python gnn.py --model gcn --dataset DGraphFin --epochs 200 --runs 10 --device 0 33 | ``` 34 | 35 | - **GraphSAGE** 36 | ```bash 37 | python gnn.py --model sage --dataset DGraphFin --epochs 200 --runs 10 --device 0 38 | ``` 39 | 40 | - **GraphSAGE (NeighborSampler)** 41 | ```bash 42 | python gnn_mini_batch.py --model sage_neighsampler --dataset DGraphFin --epochs 200 --runs 10 --device 0 43 | ``` 44 | 45 | - **GAT (NeighborSampler)** 46 | ```bash 47 | python gnn_mini_batch.py --model gat_neighsampler --dataset DGraphFin --epochs 200 --runs 10 --device 0 48 | ``` 49 | 50 | - **GATv2 (NeighborSampler)** 51 | ```bash 52 | python gnn_mini_batch.py --model gatv2_neighsampler --dataset DGraphFin --epochs 200 --runs 10 --device 0 53 | ``` 54 | 55 | 56 | ## Results: 57 | Performance on **DGraphFin**(10 runs): 58 | 59 | | Methods | Train AUC | Valid AUC | Test AUC | 60 | | :---- | ---- | ---- | ---- | 61 | | MLP | 0.7221 ± 0.0014 | 0.7135 ± 0.0010 | 0.7192 ± 0.0009 | 62 | | GCN | 0.7108 ± 0.0027 | 0.7078 ± 0.0027 | 0.7078 ± 0.0023 | 63 | | GraphSAGE| 0.7682 ± 0.0014 | 0.7548 ± 0.0013 | 0.7621 ± 0.0017 | 64 | | GraphSAGE (NeighborSampler) | 0.7845 ± 0.0013 | 0.7674 ± 0.0005 | **0.7761 ± 0.0018** | 65 | | GAT (NeighborSampler) | 0.7396 ± 0.0018 | 0.7233 ± 0.0012 | 0.7333 ± 0.0024 | 66 | | GATv2 (NeighborSampler) | 0.7698 ± 0.0083 | 0.7526 ± 0.0089 | 0.7624 ± 0.0081 | 67 | -------------------------------------------------------------------------------- /Untitled.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "\n", 10 | "from utils import DGraphFin\n", 11 | "from utils.utils import prepare_folder\n", 12 | "from utils.evaluator import Evaluator\n", 13 | "from utils.investigate import Missvalues\n", 14 | "from utils.investigate import Background\n", 15 | "from models import MLP, MLPLinear, GCN, SAGE, GAT, GATv2\n", 16 | "from logger import Logger\n", 17 | "\n", 18 | "import argparse\n", 19 | "\n", 20 | "import torch\n", 21 | "import torch.nn.functional as F\n", 22 | "import torch.nn as nn\n", 23 | "\n", 24 | "import torch_geometric.transforms as T\n", 25 | "from torch_sparse import SparseTensor\n", 26 | "from torch_geometric.utils import to_undirected\n", 27 | "import pandas as pd" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "name": "stdout", 37 | "output_type": "stream", 38 | "text": [ 39 | "Data(x=[3700550, 17], edge_attr=[4300999], y=[3700550, 1], train_mask=[857899], valid_mask=[183862], test_mask=[183840], adj_t=[3700550, 3700550, nnz=4300999])\n", 40 | "None\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "device = f'cuda:1' if torch.cuda.is_available() else 'cpu'\n", 46 | "device = torch.device(device)\n", 47 | "\n", 48 | "dataset = DGraphFin(root='./dataset/', name='DGraphFin', transform=T.ToSparseTensor())\n", 49 | "\n", 50 | "nlabels = dataset.num_classes\n", 51 | "\n", 52 | "data = dataset[0]\n", 53 | "print(data.coalesce())\n", 54 | "#data.adj_t = data.adj_t.to_symmetric()\n", 55 | "\n", 56 | "x = data.x\n", 57 | "x = (x-x.mean())/x.std()\n", 58 | "data.x = x\n", 59 | "data.y = data.y.squeeze(1) \n", 60 | "\n", 61 | "split_idx = {'train':data.train_mask, 'valid':data.valid_mask, 'test':data.test_mask}\n", 62 | "\n", 63 | "if split_idx['train'].dim()>1 and split_idx['train'].shape[1] >1:\n", 64 | " kfolds = True\n", 65 | " print('There are {} folds of splits'.format(split_idx['train'].shape[1]))\n", 66 | " split_idx['train'] = split_idx['train'][:, fold]\n", 67 | " split_idx['valid'] = split_idx['valid'][:, fold]\n", 68 | " split_idx['test'] = split_idx['test'][:, fold]\n", 69 | "else:\n", 70 | " kfolds = False\n", 71 | "\n", 72 | "missvalues = Missvalues('trickB')\n", 73 | "data = missvalues.process(data)\n", 74 | "\n", 75 | "print(data.edge_index)\n", 76 | "\n", 77 | "#BN = Background(args.BN_trick)\n", 78 | "#data = BN.process(data,args.BN_ratio)\n", 79 | "data.edge_index = data.adj_t\n", 80 | "data = data.to(device)\n", 81 | "train_idx = split_idx['train'].to(device)\n" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 3, 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "data": { 91 | "text/plain": [ 92 | "Data(x=[3700550, 34], edge_attr=[4300999], y=[3700550], train_mask=[857899], valid_mask=[183862], test_mask=[183840], adj_t=[3700550, 3700550, nnz=4300999], edge_index=[3700550, 3700550, nnz=4300999])" 93 | ] 94 | }, 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "output_type": "execute_result" 98 | } 99 | ], 100 | "source": [ 101 | "data" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 4, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "data.edge_index = torch.cat([data.edge_index.coo()[0].view(1,-1),data.edge_index.coo()[1].view(1,-1)],dim=0)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 5, 116 | "metadata": {}, 117 | "outputs": [ 118 | { 119 | "data": { 120 | "text/plain": [ 121 | "tensor([[ 0, 1, 1, ..., 3700548, 3700549, 3700549],\n", 122 | " [ 146154, 1746169, 1752576, ..., 2626653, 3249617, 3464387]],\n", 123 | " device='cuda:1')" 124 | ] 125 | }, 126 | "execution_count": 5, 127 | "metadata": {}, 128 | "output_type": "execute_result" 129 | } 130 | ], 131 | "source": [ 132 | "data.edge_index " 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 6, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "import pandas as pd\n", 142 | "import numpy as np\n", 143 | "import matplotlib.pyplot as plt\n", 144 | "import numpy as np\n", 145 | "import seaborn as sns\n", 146 | "import torch\n", 147 | "import networkx as nx\n", 148 | "from sklearn.decomposition import PCA\n", 149 | "from sklearn.manifold import TSNE\n", 150 | "import torch_geometric as tg\n", 151 | "from torch_geometric.nn.conv import MessagePassing\n", 152 | "from typing import Callable, Optional, Union\n", 153 | "from torch import Tensor\n", 154 | "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n", 155 | "from torch_sparse import SparseTensor, matmul\n", 156 | "from torch_geometric.utils import remove_self_loops, degree, add_self_loops\n", 157 | "from torch_scatter import scatter_add\n", 158 | "\n", 159 | "class MyConv(MessagePassing):\n", 160 | " \n", 161 | " def __init__(self,**kwargs):\n", 162 | " kwargs.setdefault('aggr', 'mean')\n", 163 | " super().__init__(**kwargs)\n", 164 | " \n", 165 | " def reset_parameters(self):\n", 166 | " reset(self.nn)\n", 167 | " self.eps.data.fill_(self.initial_eps)\n", 168 | " def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,\n", 169 | " size: Size = None) -> Tensor:\n", 170 | " if isinstance(x, Tensor):\n", 171 | " x: OptPairTensor = (x, x)\n", 172 | "\n", 173 | " # propagate_type: (x: OptPairTensor)\n", 174 | " out = self.propagate(edge_index, x=x, size=size)\n", 175 | " return out\n", 176 | "\n", 177 | "\n", 178 | " def message(self, x_j: Tensor) -> Tensor:\n", 179 | " return x_j\n", 180 | "\n", 181 | " def message_and_aggregate(self, adj_t: SparseTensor,\n", 182 | " x: OptPairTensor) -> Tensor:\n", 183 | " adj_t = adj_t.set_value(None, layout=None)\n", 184 | " return matmul(adj_t, x[0], reduce=self.aggr)\n", 185 | "\n", 186 | " def __repr__(self) -> str:\n", 187 | " return f'{self.__class__.__name__}(nn={self.nn})'\n", 188 | "class FeatureBlock(MessagePassing):\n", 189 | " def __init__(self, aggr='add', **kwargs):\n", 190 | " super(FeatureBlock, self).__init__(aggr=aggr, **kwargs)\n", 191 | "\n", 192 | " def forward(self, x, edge_index, edge_weight=None):\n", 193 | " h = x\n", 194 | " h = self.propagate(edge_index, x=h, edge_weight=edge_weight)\n", 195 | " return h\n", 196 | "\n", 197 | " def message(self, x_j, edge_weight):\n", 198 | " if edge_weight is not None:\n", 199 | " return edge_weight.view(-1, 1) * x_j\n", 200 | " return x_j\n", 201 | "\n", 202 | " def __repr__(self):\n", 203 | " return '{}({}, num_layers={})'.format(self.__class__.__name__,self.out_channels, self.num_layers)\n", 204 | "def label_norm(edge_index, num_nodes, edge_weight=None, improved=False,dtype=None):\n", 205 | " if edge_weight is None:\n", 206 | " edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,device=edge_index.device)\n", 207 | " edge_index,edge_weight= remove_self_loops(edge_index,edge_weight)\n", 208 | " row, col = edge_index\n", 209 | " out_deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)\n", 210 | " in_deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)\n", 211 | " in_deg = in_deg+1e-9\n", 212 | " out_deg = out_deg+1e-9\n", 213 | " deg_inv_sqrt =1/(in_deg.sqrt())\n", 214 | " deg_out_sqrt = 1/(out_deg.sqrt())\n", 215 | " return edge_index,deg_inv_sqrt[col]*deg_out_sqrt[row]\n", 216 | "def label_norm1(edge_index, num_nodes, edge_weight=None, improved=False,dtype=None):\n", 217 | " if edge_weight is None:\n", 218 | " edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,device=edge_index.device)\n", 219 | " edge_index,edge_weight= remove_self_loops(edge_index,edge_weight)\n", 220 | " row, col = edge_index\n", 221 | " out_deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)\n", 222 | " in_deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)\n", 223 | " print(col)\n", 224 | " in_deg = in_deg+1e-9\n", 225 | " out_deg = out_deg+1e-9\n", 226 | " deg_inv_sqrt = 1/(in_deg.sqrt())\n", 227 | " deg_out_sqrt = 1/(out_deg.sqrt())\n", 228 | " return edge_index,deg_inv_sqrt[col]*deg_out_sqrt[row],in_deg.view(-1,1)\n", 229 | "\n", 230 | "class Label_Extract(torch.nn.Module):\n", 231 | " def __init__(self):\n", 232 | " super(Label_Extract, self).__init__()\n", 233 | " self.conv1 = FeatureBlock('add')\n", 234 | " self.conv2 = FeatureBlock('add')\n", 235 | " def forward(self, data,is_direct):\n", 236 | " x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight\n", 237 | " \n", 238 | " label = data.label\n", 239 | " x = label\n", 240 | " edge_index, norm,in_degree = label_norm1(edge_index, x.shape[0],None, dtype=x.dtype)\n", 241 | " edge_index1, norm1,out_degree = label_norm1(edge_index[[1,0],:], x.shape[0],None, dtype=x.dtype)\n", 242 | " #norm_l = norm.view(-1,1)\n", 243 | " #norm_r = norm1.view(-1,1)\n", 244 | " h1 = self.conv1(x, edge_index[[1,0],:],norm1)\n", 245 | " h2 = self.conv1(x,edge_index[[0,1],:],norm)\n", 246 | " h3 = self.conv1(h2, edge_index[[1,0],:],norm1)\n", 247 | " norm2 = norm*norm1\n", 248 | " re = self.conv1(torch.ones((x.shape[0],1),dtype = torch.float), edge_index[[1,0],:],norm2)\n", 249 | " print(re)\n", 250 | " h3 -= label*re\n", 251 | " #h3[h3<0.000001]-=h3[h3<0.000001]\n", 252 | " \n", 253 | " h4 = self.conv1(h1, edge_index[[0,1],:],norm)\n", 254 | " norm2 = norm*norm1\n", 255 | " re = self.conv1(torch.ones((x.shape[0],1),dtype = torch.float), edge_index[[0,1],:],norm2)\n", 256 | " print(re.max())\n", 257 | " h4 -= label*re\n", 258 | " #h4[h4<0.000001]-=h4[h4<0.0000011]\n", 259 | " \n", 260 | " edge_index = tg.utils.to_undirected(edge_index)\n", 261 | " edge_index = tg.utils.remove_self_loops(edge_index)[0]\n", 262 | " #edge_weight = (edge_weight-edge_weight.min())/(edge_weight.max()-edge_weight.min()+1)\n", 263 | " edge_index, norm= label_norm(edge_index, x.shape[0],None, dtype=x.dtype)\n", 264 | " #edge_index2 = data.edge_index[[1,0],:]\n", 265 | " xx = None\n", 266 | " x1 = self.conv1(x, edge_index,norm)\n", 267 | " x2 = self.conv1(x1,edge_index,norm)\n", 268 | " re = self.conv1(torch.ones((x1.shape[0],1),dtype = torch.float), edge_index,norm*norm)\n", 269 | " #x1 = self.conv1(x, edge_index,norm*norm)\n", 270 | " x2 = x2 - label*re\n", 271 | " x2 = x2/(x2.sum(dim=1).view(-1,1)+1e-5)\n", 272 | " x1 = x1/(x1.sum(dim=1).view(-1,1)+1e-5)\n", 273 | " h1 = h1/(h1.sum(dim=1).view(-1,1)+1e-5)\n", 274 | " h2 = h2/(h2.sum(dim=1).view(-1,1)+1e-5)\n", 275 | " h3 = h3/(h3.sum(dim=1).view(-1,1)+1e-5)\n", 276 | " h4 = h4/(h4.sum(dim=1).view(-1,1)+1e-5)\n", 277 | " return torch.cat([x2,x1,h3,h2,h1,h4],dim=1)\n", 278 | "def label_feature(data):\n", 279 | " data = data.to('cpu')\n", 280 | " y = torch.zeros(data.x.shape[0],4)\n", 281 | " y[:,3]+=(data.y==3)\n", 282 | " y[:,2]+=(data.y==2)\n", 283 | " y[:,1]+=(data.y==1)\n", 284 | " y[:,0]+=(data.y==0)\n", 285 | " y[data.test_mask]-=y[data.test_mask]\n", 286 | " #y[data.valid_mask]-=y[data.valid_mask]\n", 287 | " conv=Label_Extract()\n", 288 | " #data.edge_index = torch.cat([data.edge_index.coo()[0].view(1,-1),data.edge_index.coo()[1].view(1,-1)],dim=0)\n", 289 | "\n", 290 | " data.label = y\n", 291 | " y1 = conv(data,is_direct=False)\n", 292 | " #x = torch.cat((y1,data.x),dim=1)\n", 293 | " return y1" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 7, 299 | "metadata": {}, 300 | "outputs": [ 301 | { 302 | "name": "stdout", 303 | "output_type": "stream", 304 | "text": [ 305 | "tensor([ 146154, 1746169, 1752576, ..., 2626653, 3249617, 3464387])\n", 306 | "tensor([ 0, 1, 1, ..., 3700548, 3700549, 3700549])\n", 307 | "tensor([[0.2500],\n", 308 | " [0.2917],\n", 309 | " [0.3333],\n", 310 | " ...,\n", 311 | " [0.3333],\n", 312 | " [0.2917],\n", 313 | " [0.7500]])\n", 314 | "tensor(1.)\n" 315 | ] 316 | } 317 | ], 318 | "source": [ 319 | "\n", 320 | "y1 = label_feature(data)" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 8, 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "conv = MyConv()" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 9, 335 | "metadata": {}, 336 | "outputs": [ 337 | { 338 | "data": { 339 | "text/plain": [ 340 | "tensor([[ 0, 1, 1, ..., 3700548, 3700549, 3700549],\n", 341 | " [ 146154, 1746169, 1752576, ..., 2626653, 3249617, 3464387]])" 342 | ] 343 | }, 344 | "execution_count": 9, 345 | "metadata": {}, 346 | "output_type": "execute_result" 347 | } 348 | ], 349 | "source": [ 350 | "data.edge_index" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": 10, 356 | "metadata": {}, 357 | "outputs": [], 358 | "source": [ 359 | "#data.edge_index = tg.utils.to_undirected(data.edge_index)\n", 360 | "#h11 = conv(data.x,data.edge_index)\n", 361 | "#h12 = conv(data.x,data.edge_index[[1,0],:])" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 11, 367 | "metadata": {}, 368 | "outputs": [], 369 | "source": [ 370 | "#h21 = conv(h11,data.edge_index[[1,0],:])" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": 12, 376 | "metadata": {}, 377 | "outputs": [], 378 | "source": [ 379 | "#h22 = conv(h12,data.edge_index)" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": 13, 385 | "metadata": {}, 386 | "outputs": [], 387 | "source": [ 388 | "data = data.to('cpu')\n", 389 | "x = torch.cat([data.x,y1],dim=1)" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 14, 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [ 398 | "#data = label_feature(data)" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": 15, 404 | "metadata": {}, 405 | "outputs": [ 406 | { 407 | "data": { 408 | "text/plain": [ 409 | "SparseTensor(row=tensor([ 0, 1, 1, ..., 3700548, 3700549, 3700549]),\n", 410 | " col=tensor([ 146154, 1746169, 1752576, ..., 2626653, 3249617, 3464387]),\n", 411 | " size=(3700550, 3700550), nnz=4300999, density=0.00%)" 412 | ] 413 | }, 414 | "execution_count": 15, 415 | "metadata": {}, 416 | "output_type": "execute_result" 417 | } 418 | ], 419 | "source": [ 420 | "data.adj_t" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 16, 426 | "metadata": {}, 427 | "outputs": [ 428 | { 429 | "data": { 430 | "text/plain": [ 431 | "torch.Size([3700550, 24])" 432 | ] 433 | }, 434 | "execution_count": 16, 435 | "metadata": {}, 436 | "output_type": "execute_result" 437 | } 438 | ], 439 | "source": [ 440 | "y1.shape" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": 17, 446 | "metadata": {}, 447 | "outputs": [ 448 | { 449 | "data": { 450 | "text/plain": [ 451 | "Data(x=[3700550, 34], edge_attr=[4300999], y=[3700550], train_mask=[857899], valid_mask=[183862], test_mask=[183840], adj_t=[3700550, 3700550, nnz=4300999], edge_index=[2, 4300999], label=[3700550, 4])" 452 | ] 453 | }, 454 | "execution_count": 17, 455 | "metadata": {}, 456 | "output_type": "execute_result" 457 | } 458 | ], 459 | "source": [ 460 | "\n", 461 | "data.to(device)" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": 18, 467 | "metadata": {}, 468 | "outputs": [], 469 | "source": [ 470 | "\n", 471 | "mlp_parameters = {'lr':0.01\n", 472 | " , 'num_layers':2\n", 473 | " , 'hidden_channels':64\n", 474 | " , 'dropout':0.0\n", 475 | " , 'batchnorm': False\n", 476 | " , 'l2':5e-7\n", 477 | " }" 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": 19, 483 | "metadata": {}, 484 | "outputs": [], 485 | "source": [ 486 | "\n", 487 | "para_dict = mlp_parameters\n", 488 | "model_para = mlp_parameters.copy()\n", 489 | "model_para.pop('lr')\n", 490 | "model_para.pop('l2')\n", 491 | "model = MLP(in_channels = data.x.size(-1), out_channels = 2, **model_para).to(device)" 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": 22, 497 | "metadata": {}, 498 | "outputs": [], 499 | "source": [ 500 | "data = data.to('cpu')" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": 23, 506 | "metadata": {}, 507 | "outputs": [], 508 | "source": [ 509 | "X_train = x[list(data.train_mask.numpy())+list(data.valid_mask.numpy())].cpu().numpy()" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": 24, 515 | "metadata": {}, 516 | "outputs": [], 517 | "source": [ 518 | "y_train = data.y[list(data.train_mask.numpy())+list(data.valid_mask.numpy())].cpu().numpy()" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": 25, 524 | "metadata": {}, 525 | "outputs": [], 526 | "source": [ 527 | "X_test = x[data.test_mask].cpu().numpy()" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": 26, 533 | "metadata": {}, 534 | "outputs": [], 535 | "source": [ 536 | "y_test = data.y[data.test_mask].cpu().numpy()" 537 | ] 538 | }, 539 | { 540 | "cell_type": "code", 541 | "execution_count": 27, 542 | "metadata": {}, 543 | "outputs": [ 544 | { 545 | "data": { 546 | "text/plain": [ 547 | "torch.Size([3700550, 58])" 548 | ] 549 | }, 550 | "execution_count": 27, 551 | "metadata": {}, 552 | "output_type": "execute_result" 553 | } 554 | ], 555 | "source": [ 556 | "x.shape" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": 77, 562 | "metadata": {}, 563 | "outputs": [ 564 | { 565 | "ename": "NameError", 566 | "evalue": "name 'StratifiedKFold' is not defined", 567 | "output_type": "error", 568 | "traceback": [ 569 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 570 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 571 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mxgboost\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mXGBClassifier\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mkfold\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mStratifiedKFold\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_splits\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrandom_state\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mn_estimators\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1000\u001b[0m \u001b[0;31m#数值大没关系,cv会自动返回合适的n_estimators\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m xgb1 = XGBClassifier(\n", 572 | "\u001b[0;31mNameError\u001b[0m: name 'StratifiedKFold' is not defined" 573 | ] 574 | } 575 | ], 576 | "source": [ 577 | "from xgboost import XGBClassifier\n", 578 | "import xgboost as xgb\n", 579 | " \n", 580 | "import pandas as pd \n", 581 | "import numpy as np\n", 582 | " \n", 583 | "from sklearn.model_selection import GridSearchCV\n", 584 | "from sklearn.model_selection import StratifiedKFold\n", 585 | " \n", 586 | "from sklearn.metrics import log_loss\n", 587 | "from xgboost import XGBClassifier\n", 588 | "\n", 589 | "kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=3)\n", 590 | " \n", 591 | "n_estimators = 1000 #数值大没关系,cv会自动返回合适的n_estimators\n", 592 | "xgb1 = XGBClassifier(\n", 593 | " learning_rate =0.1,\n", 594 | " n_estimators=n_estimators, \n", 595 | " max_depth=5,\n", 596 | " min_child_weight=1,\n", 597 | " gamma=0,\n", 598 | " subsample=0.3,\n", 599 | " colsample_bytree=0.8,\n", 600 | " colsample_bylevel=0.7,\n", 601 | " objective= 'multi:softprob',\n", 602 | " seed=3)\n", 603 | " \n", 604 | "xgtrain = xgb.DMatrix(x_train, label = y_train)\n", 605 | "xgb1.set_params(num_class = 3)\n", 606 | " \n", 607 | "cvresult = xgb.cv(xgb1.get_xgb_params(), xgtrain, num_boost_round=n_estimators, folds =kfold,metrics='mlogloss', early_stopping_rounds=10)\n", 608 | "n_estimators = cvresult.shape[0]\n", 609 | " \n", 610 | " # 采用交叉验证得到的最佳参数n_estimators,训练模型\n", 611 | "xgb1.set_params(n_estimators = n_estimators)\n", 612 | "xgb1.fit(X_train, y_train)" 613 | ] 614 | }, 615 | { 616 | "cell_type": "code", 617 | "execution_count": 88, 618 | "metadata": {}, 619 | "outputs": [ 620 | { 621 | "name": "stdout", 622 | "output_type": "stream", 623 | "text": [ 624 | "[LightGBM] [Warning] num_threads is set=16, n_jobs=-1 will be ignored. Current value: num_threads=16\n", 625 | "[LightGBM] [Warning] boosting is set=gbdt, boosting_type=gbdt will be ignored. Current value: boosting=gbdt\n", 626 | "[LightGBM] [Warning] bagging_freq is set=0, subsample_freq=0 will be ignored. Current value: bagging_freq=0\n", 627 | "(183840, 2)\n", 628 | "auc: 0.7940159371366042\n" 629 | ] 630 | } 631 | ], 632 | "source": [ 633 | "params = {\n", 634 | " #'objective': 'regression',\n", 635 | " 'objective':'binary',\n", 636 | " # 'objective': 'cross_entropy',\n", 637 | " #'num_class': 1,\n", 638 | " 'num_leaves':63,\n", 639 | " # 'max_depth': 6,\n", 640 | " 'num_threads': 16,\n", 641 | " 'n_estimators':100,\n", 642 | " #'min_data_in_leaf':1000,\n", 643 | " 'boosting':'gbdt',\n", 644 | " 'max_bin':10,\n", 645 | " # 'device_type': 'gpu',\n", 646 | " # 'gpu_device_id': 6,\n", 647 | " 'seed': 0,\n", 648 | " #'min_split_gain': 0.9,\n", 649 | " 'colsample_bytree': 0.8,\n", 650 | " 'subsample': 0.8,\n", 651 | " 'bagging_freq': 0,\n", 652 | " 'reg_lambda': 0.1,\n", 653 | " 'reg_alpha': 0.1,\n", 654 | " 'learning_rate': 0.05,\n", 655 | " 'metric':'auc',\n", 656 | " # 'min_child_samples': 1024,\n", 657 | " # 'num_leaves': 127,\n", 658 | " # 'learning_rate': 0.05,\n", 659 | " 'scale_pos_weight': 300\n", 660 | "}\n", 661 | "lgb_train = lgb.Dataset(X_train, y_train, free_raw_data=False)\n", 662 | "lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train,free_raw_data=False)\n", 663 | " \n", 664 | "model=lgb.LGBMClassifier(**params)\n", 665 | "model.fit(X_train,y_train)\n", 666 | "y_pre=model.predict_proba(X_test)\n", 667 | "print(y_pre.shape)\n", 668 | "y_pre = y_pre[:,1]\n", 669 | "y_pre = y_pre.reshape(len(y_pre),1)\n", 670 | "import sklearn\n", 671 | "print(\"auc:\",sklearn.metrics.roc_auc_score(y_test,y_pre))" 672 | ] 673 | }, 674 | { 675 | "cell_type": "code", 676 | "execution_count": 89, 677 | "metadata": {}, 678 | "outputs": [], 679 | "source": [ 680 | "y_pre=model.predict_proba(X_test)\n" 681 | ] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "execution_count": 90, 686 | "metadata": {}, 687 | "outputs": [ 688 | { 689 | "data": { 690 | "text/plain": [ 691 | "0.4087848128807659" 692 | ] 693 | }, 694 | "execution_count": 90, 695 | "metadata": {}, 696 | "output_type": "execute_result" 697 | } 698 | ], 699 | "source": [ 700 | "model.score(X_test,y_test)" 701 | ] 702 | }, 703 | { 704 | "cell_type": "code", 705 | "execution_count": null, 706 | "metadata": {}, 707 | "outputs": [], 708 | "source": [] 709 | }, 710 | { 711 | "cell_type": "code", 712 | "execution_count": 30, 713 | "metadata": {}, 714 | "outputs": [ 715 | { 716 | "name": "stdout", 717 | "output_type": "stream", 718 | "text": [ 719 | "设置参数\n", 720 | "交叉验证\n", 721 | "调参1:提高准确率\n" 722 | ] 723 | }, 724 | { 725 | "name": "stderr", 726 | "output_type": "stream", 727 | "text": [ 728 | "/home/xwhuang/.local/lib/python3.6/site-packages/lightgbm/engine.py:577: UserWarning: 'early_stopping_rounds' argument is deprecated and will be removed in a future release of LightGBM. Pass 'early_stopping()' callback via 'callbacks' argument instead.\n", 729 | " _log_warning(\"'early_stopping_rounds' argument is deprecated and will be removed in a future release of LightGBM. \"\n" 730 | ] 731 | }, 732 | { 733 | "ename": "KeyboardInterrupt", 734 | "evalue": "", 735 | "output_type": "error", 736 | "traceback": [ 737 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 738 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 739 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'auc'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0mearly_stopping_rounds\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 37\u001b[0;31m \u001b[0mverbose_eval\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 38\u001b[0m )\n\u001b[1;32m 39\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 740 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/lightgbm/engine.py\u001b[0m in \u001b[0;36mcv\u001b[0;34m(params, train_set, num_boost_round, folds, nfold, stratified, shuffle, metrics, fobj, feval, init_model, feature_name, categorical_feature, early_stopping_rounds, fpreproc, verbose_eval, show_stdv, seed, callbacks, eval_train_metric, return_cvbooster)\u001b[0m\n\u001b[1;32m 606\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mseed\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mseed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfpreproc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfpreproc\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 607\u001b[0m \u001b[0mstratified\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mstratified\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mshuffle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 608\u001b[0;31m eval_train_metric=eval_train_metric)\n\u001b[0m\u001b[1;32m 609\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 610\u001b[0m \u001b[0;31m# setup callbacks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 741 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/lightgbm/engine.py\u001b[0m in \u001b[0;36m_make_n_folds\u001b[0;34m(full_data, folds, nfold, params, seed, fpreproc, stratified, shuffle, eval_train_metric)\u001b[0m\n\u001b[1;32m 359\u001b[0m shuffle=True, eval_train_metric=False):\n\u001b[1;32m 360\u001b[0m \u001b[0;34m\"\"\"Make a n-fold list of Booster from random indices.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 361\u001b[0;31m \u001b[0mfull_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfull_data\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconstruct\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 362\u001b[0m \u001b[0mnum_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfull_data\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 363\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfolds\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 742 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/lightgbm/basic.py\u001b[0m in \u001b[0;36mconstruct\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1817\u001b[0m \u001b[0minit_score\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit_score\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpredictor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_predictor\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1818\u001b[0m \u001b[0msilent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msilent\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeature_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeature_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1819\u001b[0;31m categorical_feature=self.categorical_feature, params=self.params)\n\u001b[0m\u001b[1;32m 1820\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfree_raw_data\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1821\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 743 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/lightgbm/basic.py\u001b[0m in \u001b[0;36m_lazy_init\u001b[0;34m(self, data, label, reference, weight, group, init_score, predictor, silent, feature_name, categorical_feature, params)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init_from_csc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams_str\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mref_dataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1537\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1538\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init_from_np2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams_str\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mref_dataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1539\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1540\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 744 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/lightgbm/basic.py\u001b[0m in \u001b[0;36m__init_from_np2d\u001b[0;34m(self, mat, params_str, ref_dataset)\u001b[0m\n\u001b[1;32m 1665\u001b[0m \u001b[0mc_str\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams_str\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1666\u001b[0m \u001b[0mref_dataset\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1667\u001b[0;31m ctypes.byref(self.handle)))\n\u001b[0m\u001b[1;32m 1668\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1669\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 745 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 746 | ] 747 | } 748 | ], 749 | "source": [ 750 | "import pandas as pd\n", 751 | "import lightgbm as lgb\n", 752 | "from sklearn.datasets import load_breast_cancer\n", 753 | "from sklearn.model_selection import train_test_split\n", 754 | "lgb_train = lgb.Dataset(X_train, y_train, free_raw_data=False)\n", 755 | "lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train,free_raw_data=False)\n", 756 | " \n", 757 | "### 设置初始参数--不含交叉验证参数\n", 758 | "print('设置参数')\n", 759 | "params = {\n", 760 | " 'boosting_type': 'gbdt',\n", 761 | " 'objective': 'binary',\n", 762 | " 'metric': 'auc',\n", 763 | " 'nthread':4,\n", 764 | " 'learning_rate':0.1\n", 765 | " }\n", 766 | " \n", 767 | "### 交叉验证(调参)\n", 768 | "print('交叉验证')\n", 769 | "max_auc = float('0')\n", 770 | "best_params = {}\n", 771 | " \n", 772 | "# 准确率\n", 773 | "print(\"调参1:提高准确率\")\n", 774 | "for num_leaves in range(5,100,5):\n", 775 | " for max_depth in range(3,8,1):\n", 776 | " params['num_leaves'] = num_leaves\n", 777 | " params['max_depth'] = max_depth\n", 778 | " \n", 779 | " cv_results = lgb.cv(\n", 780 | " params,\n", 781 | " lgb_train,\n", 782 | " seed=1,\n", 783 | " nfold=5,\n", 784 | " metrics=['auc'],\n", 785 | " early_stopping_rounds=10,\n", 786 | " verbose_eval=True\n", 787 | " )\n", 788 | " \n", 789 | " mean_auc = pd.Series(cv_results['auc-mean']).max()\n", 790 | " boost_rounds = pd.Series(cv_results['auc-mean']).idxmax()\n", 791 | " \n", 792 | " if mean_auc >= max_auc:\n", 793 | " max_auc = mean_auc\n", 794 | " best_params['num_leaves'] = num_leaves\n", 795 | " best_params['max_depth'] = max_depth\n", 796 | "if 'num_leaves' and 'max_depth' in best_params.keys(): \n", 797 | " params['num_leaves'] = best_params['num_leaves']\n", 798 | " params['max_depth'] = best_params['max_depth']\n", 799 | " \n", 800 | "print(\"调参2:降低过拟合\")\n", 801 | "for max_bin in range(5,256,10):\n", 802 | " for min_data_in_leaf in range(1,102,10):\n", 803 | " params['max_bin'] = max_bin\n", 804 | " params['min_data_in_leaf'] = min_data_in_leaf\n", 805 | " \n", 806 | " cv_results = lgb.cv(\n", 807 | " params,\n", 808 | " lgb_train,\n", 809 | " seed=1,\n", 810 | " nfold=5,\n", 811 | " metrics=['auc'],\n", 812 | " early_stopping_rounds=10,\n", 813 | " verbose_eval=True\n", 814 | " )\n", 815 | " \n", 816 | " mean_auc = pd.Series(cv_results['auc-mean']).max()\n", 817 | " boost_rounds = pd.Series(cv_results['auc-mean']).idxmax()\n", 818 | " \n", 819 | " if mean_auc >= max_auc:\n", 820 | " max_auc = mean_auc\n", 821 | " best_params['max_bin']= max_bin\n", 822 | " best_params['min_data_in_leaf'] = min_data_in_leaf\n", 823 | "if 'max_bin' and 'min_data_in_leaf' in best_params.keys():\n", 824 | " params['min_data_in_leaf'] = best_params['min_data_in_leaf']\n", 825 | " params['max_bin'] = best_params['max_bin']\n", 826 | " \n", 827 | "print(\"调参3:降低过拟合\")\n", 828 | "for feature_fraction in [0.6,0.7,0.8,0.9,1.0]:\n", 829 | " for bagging_fraction in [0.6,0.7,0.8,0.9,1.0]:\n", 830 | " for bagging_freq in range(0,50,5):\n", 831 | " params['feature_fraction'] = feature_fraction\n", 832 | " params['bagging_fraction'] = bagging_fraction\n", 833 | " params['bagging_freq'] = bagging_freq\n", 834 | " \n", 835 | " cv_results = lgb.cv(\n", 836 | " params,\n", 837 | " lgb_train,\n", 838 | " seed=1,\n", 839 | " nfold=5,\n", 840 | " metrics=['auc'],\n", 841 | " early_stopping_rounds=10,\n", 842 | " verbose_eval=True\n", 843 | " )\n", 844 | " \n", 845 | " mean_auc = pd.Series(cv_results['auc-mean']).max()\n", 846 | " boost_rounds = pd.Series(cv_results['auc-mean']).idxmax()\n", 847 | " \n", 848 | " if mean_auc >= max_auc:\n", 849 | " max_auc=mean_auc\n", 850 | " best_params['feature_fraction'] = feature_fraction\n", 851 | " best_params['bagging_fraction'] = bagging_fraction\n", 852 | " best_params['bagging_freq'] = bagging_freq\n", 853 | " \n", 854 | "if 'feature_fraction' and 'bagging_fraction' and 'bagging_freq' in best_params.keys():\n", 855 | " params['feature_fraction'] = best_params['feature_fraction']\n", 856 | " params['bagging_fraction'] = best_params['bagging_fraction']\n", 857 | " params['bagging_freq'] = best_params['bagging_freq']\n", 858 | " \n", 859 | " \n", 860 | "print(\"调参4:降低过拟合\")\n", 861 | "for lambda_l1 in [1e-5,1e-3,1e-1,0.0,0.1,0.3,0.5,0.7,0.9,1.0]:\n", 862 | " for lambda_l2 in [1e-5,1e-3,1e-1,0.0,0.1,0.4,0.6,0.7,0.9,1.0]:\n", 863 | " params['lambda_l1'] = lambda_l1\n", 864 | " params['lambda_l2'] = lambda_l2\n", 865 | " cv_results = lgb.cv(\n", 866 | " params,\n", 867 | " lgb_train,\n", 868 | " seed=1,\n", 869 | " nfold=5,\n", 870 | " metrics=['auc'],\n", 871 | " early_stopping_rounds=10,\n", 872 | " verbose_eval=True\n", 873 | " )\n", 874 | " \n", 875 | " mean_auc = pd.Series(cv_results['auc-mean']).max()\n", 876 | " boost_rounds = pd.Series(cv_results['auc-mean']).idxmax()\n", 877 | " \n", 878 | " if mean_auc >= max_auc:\n", 879 | " max_auc=mean_auc\n", 880 | " best_params['lambda_l1'] = lambda_l1\n", 881 | " best_params['lambda_l2'] = lambda_l2\n", 882 | "if 'lambda_l1' and 'lambda_l2' in best_params.keys():\n", 883 | " params['lambda_l1'] = best_params['lambda_l1']\n", 884 | " params['lambda_l2'] = best_params['lambda_l2']\n", 885 | " \n", 886 | "print(\"调参5:降低过拟合2\")\n", 887 | "for min_split_gain in [0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0]:\n", 888 | " params['min_split_gain'] = min_split_gain\n", 889 | " \n", 890 | " cv_results = lgb.cv(\n", 891 | " params,\n", 892 | " lgb_train,\n", 893 | " seed=1,\n", 894 | " nfold=5,\n", 895 | " metrics=['auc'],\n", 896 | " early_stopping_rounds=10,\n", 897 | " verbose_eval=True\n", 898 | " )\n", 899 | " \n", 900 | " mean_auc = pd.Series(cv_results['auc-mean']).max()\n", 901 | " boost_rounds = pd.Series(cv_results['auc-mean']).idxmax()\n", 902 | " \n", 903 | " if mean_auc >= max_auc:\n", 904 | " max_auc=mean_auc\n", 905 | " \n", 906 | " best_params['min_split_gain'] = min_split_gain\n", 907 | "if 'min_split_gain' in best_params.keys():\n", 908 | " params['min_split_gain'] = best_params['min_split_gain']\n", 909 | " \n", 910 | "# print(best_params)\n", 911 | "\n", 912 | "# {'bagging_fraction': 0.7,\n", 913 | "# 'bagging_freq': 30,\n", 914 | "# 'feature_fraction': 0.8,\n", 915 | "# 'lambda_l1': 0.1,\n", 916 | "# 'lambda_l2': 0.0,\n", 917 | "# 'max_bin': 255,\n", 918 | "# 'max_depth': 4,\n", 919 | "# 'min_data_in_leaf': 81,\n", 920 | "# 'min_split_gain': 0.1,\n", 921 | "# 'num_leaves': 10}\n", 922 | "\n", 923 | "model=lgb.LGBMClassifier(boosting_type='gbdt',objective='binary',metrics='auc',learning_rate=0.01, n_estimators=1000, max_depth=4, num_leaves=10,max_bin=255,min_data_in_leaf=81,bagging_fraction=0.7,bagging_freq= 30, feature_fraction= 0.8,\n", 924 | "lambda_l1=0.1,lambda_l2=0,min_split_gain=0.1)\n", 925 | "model.fit(X_train,y_train)\n", 926 | "y_pre=model.predict(X_test)\n", 927 | "print(\"acc:\",metrics.accuracy_score(y_test,y_pre))\n", 928 | "print(\"auc:\",metrics.roc_auc_score(y_test,y_pre))" 929 | ] 930 | }, 931 | { 932 | "cell_type": "code", 933 | "execution_count": 37, 934 | "metadata": {}, 935 | "outputs": [ 936 | { 937 | "ename": "AttributeError", 938 | "evalue": "'MLP' object has no attribute 'predict'", 939 | "output_type": "error", 940 | "traceback": [ 941 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 942 | "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", 943 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0my_pre\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 944 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1129\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1130\u001b[0m raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0;32m-> 1131\u001b[0;31m type(self).__name__, name))\n\u001b[0m\u001b[1;32m 1132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1133\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__setattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'Module'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 945 | "\u001b[0;31mAttributeError\u001b[0m: 'MLP' object has no attribute 'predict'" 946 | ] 947 | } 948 | ], 949 | "source": [ 950 | "y_pre=model.predict(X_test)" 951 | ] 952 | }, 953 | { 954 | "cell_type": "code", 955 | "execution_count": 33, 956 | "metadata": {}, 957 | "outputs": [ 958 | { 959 | "name": "stdout", 960 | "output_type": "stream", 961 | "text": [ 962 | "3394\n", 963 | "Run: 01, Epoch: 10, Loss: 0.0696, Train AUC: 0.638 Train AP: 0.022 Valid AUC: 0.588 Valid AP: 0.015 Test AUC: 0.583 Test AP: 0.015\n", 964 | "Run: 01, Epoch: 20, Loss: 0.0884, Train AUC: 0.715 Train AP: 0.051 Valid AUC: 0.653 Valid AP: 0.019 Test AUC: 0.654 Test AP: 0.019\n", 965 | "Run: 01, Epoch: 30, Loss: 0.0722, Train AUC: 0.741 Train AP: 0.137 Valid AUC: 0.674 Valid AP: 0.021 Test AUC: 0.676 Test AP: 0.021\n", 966 | "Run: 01, Epoch: 40, Loss: 0.0626, Train AUC: 0.773 Train AP: 0.290 Valid AUC: 0.673 Valid AP: 0.021 Test AUC: 0.675 Test AP: 0.021\n", 967 | "Run: 01, Epoch: 50, Loss: 0.0568, Train AUC: 0.793 Train AP: 0.414 Valid AUC: 0.670 Valid AP: 0.023 Test AUC: 0.674 Test AP: 0.022\n", 968 | "Run: 01, Epoch: 60, Loss: 0.0536, Train AUC: 0.801 Train AP: 0.476 Valid AUC: 0.679 Valid AP: 0.023 Test AUC: 0.683 Test AP: 0.023\n", 969 | "Run: 01, Epoch: 70, Loss: 0.0486, Train AUC: 0.808 Train AP: 0.537 Valid AUC: 0.713 Valid AP: 0.029 Test AUC: 0.721 Test AP: 0.030\n", 970 | "Run: 01, Epoch: 80, Loss: 0.0452, Train AUC: 0.805 Train AP: 0.572 Valid AUC: 0.721 Valid AP: 0.030 Test AUC: 0.730 Test AP: 0.031\n", 971 | "Run: 01, Epoch: 90, Loss: 0.0430, Train AUC: 0.802 Train AP: 0.602 Valid AUC: 0.732 Valid AP: 0.032 Test AUC: 0.740 Test AP: 0.034\n", 972 | "Run: 01, Epoch: 100, Loss: 0.0417, Train AUC: 0.795 Train AP: 0.621 Valid AUC: 0.736 Valid AP: 0.034 Test AUC: 0.746 Test AP: 0.037\n", 973 | "Run: 01, Epoch: 110, Loss: 0.0409, Train AUC: 0.788 Train AP: 0.630 Valid AUC: 0.738 Valid AP: 0.035 Test AUC: 0.749 Test AP: 0.038\n", 974 | "Run: 01, Epoch: 120, Loss: 0.0405, Train AUC: 0.787 Train AP: 0.634 Valid AUC: 0.743 Valid AP: 0.036 Test AUC: 0.755 Test AP: 0.039\n", 975 | "Run: 01, Epoch: 130, Loss: 0.0402, Train AUC: 0.787 Train AP: 0.639 Valid AUC: 0.744 Valid AP: 0.037 Test AUC: 0.757 Test AP: 0.040\n", 976 | "Run: 01, Epoch: 140, Loss: 0.0400, Train AUC: 0.787 Train AP: 0.642 Valid AUC: 0.745 Valid AP: 0.037 Test AUC: 0.757 Test AP: 0.040\n", 977 | "Run: 01, Epoch: 150, Loss: 0.0398, Train AUC: 0.787 Train AP: 0.643 Valid AUC: 0.745 Valid AP: 0.037 Test AUC: 0.757 Test AP: 0.040\n", 978 | "Run: 01, Epoch: 160, Loss: 0.0395, Train AUC: 0.788 Train AP: 0.645 Valid AUC: 0.745 Valid AP: 0.038 Test AUC: 0.757 Test AP: 0.040\n", 979 | "Run: 01, Epoch: 170, Loss: 0.0393, Train AUC: 0.790 Train AP: 0.647 Valid AUC: 0.745 Valid AP: 0.038 Test AUC: 0.757 Test AP: 0.040\n" 980 | ] 981 | }, 982 | { 983 | "ename": "KeyboardInterrupt", 984 | "evalue": "", 985 | "output_type": "error", 986 | "traceback": [ 987 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 988 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 989 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m400\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mno_conv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0meval_results\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlosses\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msplit_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevaluator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mno_conv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m \u001b[0mtrain_auc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalid_auc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_auc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0meval_results\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'auc'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_results\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'valid'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'auc'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_results\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'test'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'auc'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0mtrain_ap\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalid_ap\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_ap\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0meval_results\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'ap'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_results\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'valid'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'ap'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_results\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'test'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'ap'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 990 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/torch/autograd/grad_mode.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 991 | "\u001b[0;32m\u001b[0m in \u001b[0;36mtest\u001b[0;34m(model, data, split_idx, evaluator, no_conv)\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0mnode_id\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msplit_idx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0mlosses\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnll_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnode_id\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnode_id\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m \u001b[0meval_results\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnode_id\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnode_id\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0meval_results\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlosses\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 992 | "\u001b[0;32m~/DGraph_Experiments/utils/evaluator.py\u001b[0m in \u001b[0;36meval\u001b[0;34m(self, y_true, y_pred)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval_metric\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'auc'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_eval_rocauc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_eval_rocauc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 993 | "\u001b[0;32m~/DGraph_Experiments/utils/evaluator.py\u001b[0m in \u001b[0;36m_eval_rocauc\u001b[0;34m(self, y_true, y_pred)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m \u001b[0mauc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mroc_auc_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 53\u001b[0m \u001b[0map\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0maverage_precision_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 994 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mall_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0;31m# extra_args > 0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 995 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/sklearn/metrics/_ranking.py\u001b[0m in \u001b[0;36mroc_auc_score\u001b[0;34m(y_true, y_score, average, sample_weight, max_fpr, multi_class, labels)\u001b[0m\n\u001b[1;32m 539\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0my_type\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"binary\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 540\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munique\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 541\u001b[0;31m \u001b[0my_true\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlabel_binarize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclasses\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 542\u001b[0m return _average_binary_score(partial(_binary_roc_auc_score,\n\u001b[1;32m 543\u001b[0m max_fpr=max_fpr),\n", 996 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mall_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0;31m# extra_args > 0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 997 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/sklearn/preprocessing/_label.py\u001b[0m in \u001b[0;36mlabel_binarize\u001b[0;34m(y, classes, neg_label, pos_label, sparse_output)\u001b[0m\n\u001b[1;32m 528\u001b[0m \u001b[0my_seen\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0my_in_classes\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 529\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msearchsorted\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msorted_class\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_seen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 530\u001b[0;31m \u001b[0mindptr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcumsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_in_classes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 531\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 532\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mempty_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 998 | "\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36mcumsum\u001b[0;34m(*args, **kwargs)\u001b[0m\n", 999 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/numpy/core/fromnumeric.py\u001b[0m in \u001b[0;36mcumsum\u001b[0;34m(a, axis, dtype, out)\u001b[0m\n\u001b[1;32m 2481\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2482\u001b[0m \"\"\"\n\u001b[0;32m-> 2483\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_wrapfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cumsum'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2484\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2485\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 1000 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/numpy/core/fromnumeric.py\u001b[0m in \u001b[0;36m_wrapfunc\u001b[0;34m(obj, method, *args, **kwds)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mbound\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;31m# A TypeError occurs if the object does have such a method in its\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 1001 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 1002 | ] 1003 | } 1004 | ], 1005 | "source": [ 1006 | "eval_metric = 'auc'\n", 1007 | "evaluator = Evaluator(eval_metric)\n", 1008 | "#logger = Logger(5, args)\n", 1009 | "no_conv = True \n", 1010 | "for run in range(5):\n", 1011 | " import gc\n", 1012 | " gc.collect()\n", 1013 | " print(sum(p.numel() for p in model.parameters()))\n", 1014 | "\n", 1015 | " model.reset_parameters()\n", 1016 | " optimizer = torch.optim.AdamW(model.parameters(), lr=para_dict['lr'], weight_decay=para_dict['l2'])\n", 1017 | " best_valid = 0\n", 1018 | " min_valid_loss = 1e8\n", 1019 | " best_out = None\n", 1020 | "\n", 1021 | " for epoch in range(1, 400):\n", 1022 | " loss = train(model, data, train_idx, optimizer, no_conv)\n", 1023 | " eval_results, losses, out = test(model, data, split_idx, evaluator, no_conv)\n", 1024 | " train_auc, valid_auc, test_auc = eval_results['train']['auc'], eval_results['valid']['auc'], eval_results['test']['auc']\n", 1025 | " train_ap, valid_ap, test_ap = eval_results['train']['ap'], eval_results['valid']['ap'], eval_results['test']['ap']\n", 1026 | "\n", 1027 | " train_loss, valid_loss, test_loss = losses['train'], losses['valid'], losses['test']\n", 1028 | " #print(eval_results['train'])\n", 1029 | "# if valid_eval > best_valid:\n", 1030 | "# best_valid = valid_result\n", 1031 | "# best_out = out.cpu().exp()\n", 1032 | " if valid_loss < min_valid_loss:\n", 1033 | " min_valid_loss = valid_loss\n", 1034 | " best_out = out.cpu()\n", 1035 | "\n", 1036 | " if epoch % 10 == 0:\n", 1037 | " print(f'Run: {run + 1:02d}, '\n", 1038 | " f'Epoch: {epoch:02d}, '\n", 1039 | " f'Loss: {loss:.4f}, '\n", 1040 | " f'Train AUC: {train_auc:.3f} '\n", 1041 | " f'Train AP: {train_ap:.3f} '\n", 1042 | " f'Valid AUC: {valid_auc:.3f} '\n", 1043 | " f'Valid AP: {valid_ap:.3f} '\n", 1044 | " f'Test AUC: { test_auc:.3f} '\n", 1045 | " f'Test AP: { test_ap:.3f}')\n", 1046 | " #logger.add_result(run, [train_auc, valid_auc, test_auc])\n", 1047 | "\n", 1048 | " #logger.print_statistics(run)\n", 1049 | "\n", 1050 | "final_results = logger.print_statistics()\n", 1051 | "print('final_results:', final_results)" 1052 | ] 1053 | }, 1054 | { 1055 | "cell_type": "code", 1056 | "execution_count": null, 1057 | "metadata": {}, 1058 | "outputs": [], 1059 | "source": [] 1060 | }, 1061 | { 1062 | "cell_type": "code", 1063 | "execution_count": 32, 1064 | "metadata": {}, 1065 | "outputs": [], 1066 | "source": [ 1067 | "def train(model, data, train_idx, optimizer, no_conv=False):\n", 1068 | " # data.y is labels of shape (N, ) \n", 1069 | " model.train()\n", 1070 | "\n", 1071 | " optimizer.zero_grad()\n", 1072 | " if no_conv:\n", 1073 | " out = model(data.x[train_idx])\n", 1074 | " else:\n", 1075 | " out = model(data.x, data.edge_index)[train_idx]\n", 1076 | " loss = F.nll_loss(out, data.y[train_idx])\n", 1077 | " loss.backward()\n", 1078 | " optimizer.step()\n", 1079 | "\n", 1080 | " return loss.item()\n", 1081 | "\n", 1082 | "\n", 1083 | "@torch.no_grad()\n", 1084 | "def test(model, data, split_idx, evaluator, no_conv=False):\n", 1085 | " # data.y is labels of shape (N, )\n", 1086 | " model.eval()\n", 1087 | " \n", 1088 | " if no_conv:\n", 1089 | " out = model(data.x)\n", 1090 | " else:\n", 1091 | " out = model(data.x, data.edge_index)\n", 1092 | " \n", 1093 | " y_pred = out.exp() # (N,num_classes)\n", 1094 | " \n", 1095 | " losses, eval_results = dict(), dict()\n", 1096 | " for key in ['train', 'valid', 'test']:\n", 1097 | " node_id = split_idx[key]\n", 1098 | " losses[key] = F.nll_loss(out[node_id], data.y[node_id]).item()\n", 1099 | " eval_results[key] = evaluator.eval(data.y[node_id], y_pred[node_id])\n", 1100 | " \n", 1101 | " return eval_results, losses, y_pred" 1102 | ] 1103 | }, 1104 | { 1105 | "cell_type": "code", 1106 | "execution_count": null, 1107 | "metadata": {}, 1108 | "outputs": [], 1109 | "source": [] 1110 | } 1111 | ], 1112 | "metadata": { 1113 | "kernelspec": { 1114 | "display_name": "Python 3", 1115 | "language": "python", 1116 | "name": "python3" 1117 | }, 1118 | "language_info": { 1119 | "codemirror_mode": { 1120 | "name": "ipython", 1121 | "version": 3 1122 | }, 1123 | "file_extension": ".py", 1124 | "mimetype": "text/x-python", 1125 | "name": "python", 1126 | "nbconvert_exporter": "python", 1127 | "pygments_lexer": "ipython3", 1128 | "version": "3.6.9" 1129 | } 1130 | }, 1131 | "nbformat": 4, 1132 | "nbformat_minor": 4 1133 | } 1134 | -------------------------------------------------------------------------------- /analysis/fig2a.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import sys 4 | sys.path.append("..") 5 | from load_dataset import build_tg_data 6 | import pandas as pd 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import seaborn as sns 11 | def main(): 12 | datapath = '../dataset/DGraphFin/raw/dgraphfin.npz' 13 | origin_data = np.load(datapath) 14 | data = build_tg_data(is_undirected=False,) 15 | 16 | 17 | degree = pd.DataFrame(data.edge_index.T.numpy()).groupby(0).count().values 18 | ids = pd.DataFrame(data.edge_index.T.numpy()).groupby(0).count().index.values 19 | key = {} 20 | for i in range(data.x.shape[0]): 21 | key[i]=0 22 | for i in range(len(ids)): 23 | key[ids[i]]=degree[i][0] 24 | all_degree = np.array(list(key.values())) 25 | all_out_degree = all_degree 26 | 27 | 28 | degree = pd.DataFrame(data.edge_index.T.numpy()).groupby(1).count().values 29 | ids = pd.DataFrame(data.edge_index.T.numpy()).groupby(1).count().index.values 30 | key = {} 31 | for i in range(data.x.shape[0]): 32 | key[i]=0 33 | for i in range(len(ids)): 34 | key[ids[i]]=degree[i][0] 35 | all_degree = np.array(list(key.values())) 36 | all_in_degree = all_degree 37 | 38 | ab_in_d = all_in_degree[data.y==1] 39 | ab_out_d = all_out_degree[data.y==1] 40 | normal_in_d = all_in_degree[data.y==0] 41 | normal_out_d = all_out_degree[data.y==0] 42 | 43 | plot_data = pd.DataFrame() 44 | plot_data['y']=list(ab_in_d)+list(normal_in_d)+list(ab_out_d)+list(normal_out_d) 45 | plot_data['x']=['In deg.']*(len(ab_in_d)+len(normal_in_d))+['Out deg.']*(len(ab_out_d)+len(normal_out_d)) 46 | plot_data['label']=['Fraudsters']*len(ab_in_d)+['Normal users']*len(normal_in_d)+['Fraudsters']*len(ab_out_d)+['Normal users']*len(normal_out_d) 47 | 48 | #plt.rcParams['font.sans-serif'] = ['Times New Roman'] 49 | 50 | sns.set_color_codes("pastel") 51 | plt.rc('font', family='Times New Roman') 52 | 53 | pic_id = 2 54 | 55 | plt.figure(figsize=(10, 8)) 56 | plt.xticks(fontsize=45) 57 | plt.yticks([0,0.5,1.0,1.5,2.0,2.5],fontsize=45) 58 | ax = sns.barplot(x="x", y="y", hue="label", data=plot_data, palette=['r','b'],capsize=0.02,errwidth=1.5,linewidth=1.0,edgecolor=".2") 59 | #plt.ylim([0, 10]) 60 | plt.xlabel(' ',fontsize=50) 61 | plt.ylabel('Avg deg.',fontsize=50) 62 | 63 | plt.ylim(0,2.5) 64 | plt.legend(loc = 'best',fontsize=40,markerscale=2) 65 | plt.tight_layout() 66 | plt.savefig('./figure1.pdf',bbox_inches='tight', format='pdf') 67 | 68 | if __name__ == "__main__": 69 | main() 70 | 71 | #plt.show() -------------------------------------------------------------------------------- /analysis/fig2b.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | 4 | import sys 5 | sys.path.append("..") 6 | from load_dataset import build_tg_data 7 | import pandas as pd 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import seaborn as sns 12 | 13 | from torch_geometric.nn.conv import MessagePassing 14 | from typing import Callable, Optional, Union 15 | from torch import Tensor 16 | from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size 17 | from torch_sparse import SparseTensor, matmul 18 | 19 | import torch 20 | 21 | 22 | def main(): 23 | datapath = '../dataset/DGraphFin/raw/dgraphfin.npz' 24 | origin_data = np.load(datapath) 25 | data = build_tg_data(is_undirected=False,) 26 | class MyConv(MessagePassing): 27 | 28 | def __init__(self,**kwargs): 29 | kwargs.setdefault('aggr', 'mean') 30 | super().__init__(**kwargs) 31 | 32 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, 33 | size: Size = None) -> Tensor: 34 | out = self.propagate(edge_index, x=x, size=size) 35 | return out 36 | 37 | def message(self, x_i,x_j: Tensor) -> Tensor: 38 | p = (x_i*x_j).sum(dim=1).view(-1,1)/(torch.norm(x_i,dim=1).view(-1,1)*torch.norm(x_j,dim=1).view(-1,1)+1e-5) 39 | 40 | return p 41 | 42 | def __repr__(self) -> str: 43 | return f'{self.__class__.__name__}(nn={self.nn})' 44 | conv=MyConv() 45 | data = build_tg_data(is_undirected=False) 46 | edge_index = data.edge_index[[0,1],:] 47 | y = torch.zeros(data.x.shape[0],3) 48 | y[:,2]+=(data.y>1) 49 | y[:,1]+=(data.y==1) 50 | y[:,0]+=(data.y==0) 51 | y1 = conv(data.x,edge_index) 52 | edge_index = data.edge_index[[1,0],:] 53 | y = torch.zeros(data.x.shape[0],3) 54 | y[:,2]+=(data.y>1) 55 | y[:,1]+=(data.y==1) 56 | y[:,0]+=(data.y==0) 57 | y2 = conv(data.x,edge_index) 58 | 59 | 60 | 61 | plotdata = pd.DataFrame() 62 | plotdata['y'] = list(y1[data.y==1].view(-1).numpy())+list(y1[data.y==0].view(-1).numpy())+list(y2[data.y==1].view(-1).numpy())+list(y2[data.y==0].view(-1).numpy()) 63 | plotdata['label']=['Fraudsters']*len(y1[data.y==1])+['Normal users']*len(y1[data.y==0])+['Fraudsters']*len(y1[data.y==1])+['Normal users']*len(y1[data.y==0]) 64 | plotdata['x'] = ['In-neighbors']*(data.y<=1).sum()+['Out-neighbors']*(data.y<=1).sum() 65 | #plotdata = plotdata.sample(1000000) 66 | 67 | #plt.rcParams['font.sans-serif'] = ['Times New Roman'] 68 | 69 | sns.set_color_codes("pastel") 70 | plt.rc('font', family='Times New Roman') 71 | 72 | pic_id = 2 73 | 74 | plt.figure(figsize=(10, 8)) 75 | plt.xticks(fontsize=45) 76 | plt.yticks([0,0.2,0.4,0.6,0.9],fontsize=45) 77 | #ax = sns.barplot(data=plotdata,y= 78 | # 'Ratio of missing values',x='x',hue='label',palette=['r','b']) 79 | 80 | ax = sns.barplot(data=plotdata,y= 81 | 'y',x='x',hue='label', palette=['r','b'],capsize=0.02,errwidth=1.5,linewidth=1.0,edgecolor=".2") 82 | ax.legend(loc="upper right",fontsize=45) 83 | #plt.ylim([0, 10]) 84 | plt.xlabel(' ',fontsize=50) 85 | plt.ylabel('Avg cosine similarity',fontsize=50) 86 | plt.legend(loc = 'best',fontsize=40,markerscale=2) 87 | plt.ylim(0,0.6) 88 | #plt.xlim(-0.1,1) 89 | plt.tight_layout() 90 | plt.savefig('./figure2.pdf',bbox_inches='tight', format='pdf') 91 | if __name__ == "__main__": 92 | main() 93 | #plt.show() 94 | #ax.legend(loc="upper right",fontsize=32) -------------------------------------------------------------------------------- /analysis/fig2c.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import sys 4 | sys.path.append("..") 5 | from load_dataset import build_tg_data 6 | import pandas as pd 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import seaborn as sns 11 | import torch 12 | def main(): 13 | datapath = '../dataset/DGraphFin/raw/dgraphfin.npz' 14 | origin_data = np.load(datapath) 15 | data = build_tg_data(is_undirected=False,) 16 | 17 | 18 | MVR=(data.x==-1).sum(dim=1)*1.0 19 | MVR=MVR.float() 20 | plotdata = pd.DataFrame() 21 | plotdata['Ratio of missing values'] = list(MVR[data.y==1].numpy()/17)+list(MVR[data.y==0].numpy()/17) 22 | plotdata['label']=['Fraudsters']*len(MVR[data.y==1])+['Normal users']*len(MVR[data.y==0]) 23 | #plotdata = plotdata.sample(1000000) 24 | sns.set_color_codes("pastel") 25 | plt.rc('font', family='Times New Roman') 26 | 27 | pic_id = 2 28 | 29 | plt.figure(figsize=(10, 8)) 30 | plt.xticks(fontsize=45) 31 | plt.yticks(fontsize=45) 32 | ax = sns.kdeplot(data=plotdata[plotdata['label']=='Fraudsters'],x= 33 | 'Ratio of missing values',multiple="layer",common_norm=False,common_grid=False,fill=True,color='r',legend=False,label='Fraudsters') 34 | ax = sns.kdeplot(data=plotdata[plotdata['label']=='Normal users'],x= 35 | 'Ratio of missing values',multiple="layer",common_norm=False,common_grid=False,fill=True,color='b',legend=False,label='Normal users') 36 | 37 | #ax.legend(loc="upper right",fontsize=45) 38 | #plt.ylim([0, 10]) 39 | plt.xlabel('Ratio of missing values',fontsize=50) 40 | plt.ylabel('Density',fontsize=50) 41 | plt.legend(loc = 'best',fontsize=40,markerscale=2) 42 | plt.ylim(0,10) 43 | plt.xlim(-0.1,1) 44 | plt.tight_layout() 45 | plt.savefig('./figure3.pdf',bbox_inches='tight', format='pdf') 46 | if __name__ == "__main__": 47 | main() 48 | 49 | # 50 | 51 | #plt.show() -------------------------------------------------------------------------------- /analysis/fig2d.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import sys 4 | sys.path.append("..") 5 | from load_dataset import build_tg_data 6 | import pandas as pd 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import seaborn as sns 11 | import torch 12 | 13 | if __name__ == "__main__": 14 | datapath = '../dataset/DGraphFin/raw/dgraphfin.npz' 15 | origin_data = np.load(datapath) 16 | data = build_tg_data(is_undirected=False) 17 | 18 | x = (data.edge_time.max()-data.edge_time.min())/30 19 | edge_time = ((data.edge_time-data.edge_time.min())/x).long() 20 | 21 | du_data = pd.DataFrame(data.edge_index.T.numpy()) 22 | du_data['time']= edge_time.view(-1).numpy() 23 | #data = pd.DataFrame() 24 | x = [] 25 | hue = [] 26 | y = [] 27 | for i in range(2,6): 28 | ids = du_data.groupby(0).count().index.values 29 | degree = du_data.groupby(0).count().values 30 | ids2 = ids[degree[:,0]==i] 31 | values = du_data.groupby(0).max()['time'].values-du_data.groupby(0).min()['time'].values 32 | values = values[degree[:,0]==i] 33 | label = data.y[ids2].numpy() 34 | y = y+list(values[label==1]/(i-1))+list(values[label==0]/(i-1)) 35 | x = x+[i]*len(values[label==1])+[i]*len(values[label==0]) 36 | hue = hue+['Fraudsters']*len(values[label==1])+['Normal users']*len(values[label==0]) 37 | print(values[label==1].mean()/(i-1)) 38 | print(values[label==0].mean()/(i-1)) 39 | plot_data = pd.DataFrame() 40 | plot_data['x']=x 41 | plot_data['y']=y 42 | plot_data['label']=hue 43 | 44 | 45 | #plt.rcParams['font.sans-serif'] = ['Times New Roman'] 46 | 47 | sns.set_color_codes("pastel") 48 | plt.rc('font', family='Times New Roman') 49 | 50 | #pic_id = 2 51 | 52 | plt.figure(figsize=(10, 8)) 53 | plt.xticks([2,3,4,5],fontsize=45) 54 | plt.yticks([0,2,4,6],fontsize=45) 55 | ax = sns.lineplot(data=plot_data,x='x',y='y',hue='label',style="label",palette=['r','b'],sizes=[12,23],markersize='15',markers=['o']*2,markeredgecolor=None,legend='full') 56 | #for h in ax.legend_.legendHandles: 57 | # h.set_marker('o') 58 | #ax.legend(loc="upper right",fontsize=32,markerscale=32) 59 | 60 | #plt.ylim([0, 10]) 61 | plt.xlabel('Deg.',fontsize=50) 62 | plt.ylabel('Gap of each edges',fontsize=50) 63 | plt.legend(loc = 'best',fontsize=40,markerscale=2) 64 | plt.ylim(0,6) 65 | #plt.xlim(-0.1,1) 66 | plt.tight_layout() 67 | plt.savefig('./figure4.pdf',bbox_inches='tight', format='pdf') 68 | 69 | #plt.show() -------------------------------------------------------------------------------- /analysis/fig3a.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import sys 4 | sys.path.append("..") 5 | from load_dataset import build_tg_data 6 | import pandas as pd 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import seaborn as sns 11 | import torch 12 | 13 | from sklearn.decomposition import PCA 14 | from sklearn.manifold import TSNE 15 | if __name__ == "__main__": 16 | 17 | datapath = '../dataset/DGraphFin/raw/dgraphfin.npz' 18 | origin_data = np.load(datapath) 19 | data = build_tg_data(is_undirected=True) 20 | tsne = TSNE() 21 | Y = tsne.fit_transform(data.x[20000:120000]) 22 | # Y = np.load('./tsne.npy') 23 | label = data.y[20000:120000] 24 | 25 | sns.set_color_codes("pastel") 26 | plt.rc('font', family='Times New Roman') 27 | 28 | pic_id = 2 29 | 30 | 31 | 32 | plt.figure(figsize=(10, 8)) 33 | plt.xticks(fontsize=45) 34 | plt.yticks(fontsize=45) 35 | s1 = plt.scatter(x=Y[:,0][label<2][:5000],y=Y[:,1][label<2][:5000],color='b',marker='.',alpha=0.2,linewidths=5.2,edgecolors=None) 36 | 37 | s2 = plt.scatter(x=Y[:,0][label>1][:5000],y=Y[:,1][label>1][:5000],color='r',marker='.',alpha=0.2,linewidths=5.2,edgecolors=None) 38 | 39 | 40 | 41 | #plt.legend(fontsize=32) 42 | 43 | 44 | plt.xlabel('$x$',fontsize=50) 45 | plt.ylabel('$y$',fontsize=50) 46 | plt.legend((s1,s2),('Other nodes','Background nodes') ,loc = 'upper left',fontsize=40,markerscale=12,handletextpad=0) 47 | #plt.ylim(0,6) 48 | #plt.xlim(-0.1,1) 49 | plt.tight_layout() 50 | plt.savefig('./figure5.pdf',bbox_inches='tight', format='pdf') 51 | 52 | #plt.show() 53 | -------------------------------------------------------------------------------- /analysis/fig3b.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import sys 4 | sys.path.append("..") 5 | from load_dataset import build_tg_data 6 | import pandas as pd 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import seaborn as sns 11 | import torch 12 | import networkx as nx 13 | from sklearn.decomposition import PCA 14 | from sklearn.manifold import TSNE 15 | import torch_geometric as tg 16 | if __name__ == "__main__": 17 | 18 | datapath = '../dataset/DGraphFin/raw/dgraphfin.npz' 19 | origin_data = np.load(datapath) 20 | data = build_tg_data(is_undirected=True) 21 | #tsne = TSNE() 22 | #Y = tsne.fit_transform(data.x[20000:120000]) 23 | pure_idx = torch.arange(data.x.shape[0])[data.y<=1] 24 | 25 | pure_edge_index = tg.utils.subgraph(pure_idx,data.edge_index,relabel_nodes=True)[0] 26 | pure_data = tg.data.Data() 27 | pure_data.x = data.x[pure_idx] 28 | pure_data.edge_index = pure_edge_index 29 | 30 | netgraph = tg.utils.to_networkx(data,to_undirected=True) 31 | new_netgraph = tg.utils.to_networkx(pure_data,to_undirected=True) 32 | connected_components = [len(c) for c in sorted(nx.connected_components(new_netgraph), key=len, reverse=True)] 33 | connected_components = np.array(connected_components) 34 | print(connected_components.sum()) 35 | data_list = [] 36 | x = [] 37 | for i in range(1,557): 38 | num = (connected_components==i).sum() 39 | print(num) 40 | if(num>0): 41 | data_list.append(num) 42 | x.append(i) 43 | sns.set_color_codes("pastel") 44 | plt.rc('font', family='Times New Roman') 45 | 46 | pic_id = 2 47 | 48 | 49 | 50 | fig = plt.figure(figsize=(10, 8)) 51 | ax = fig.add_subplot(1,1,1) 52 | s1 = plt.scatter(y=data_list,x=x,color='w',marker='o',edgecolors='r',s=90,linewidths=3) 53 | s2 = plt.scatter(y=[1],x=[3700550],color='r',marker='*',edgecolors='r',s=120,linewidths=3) 54 | 55 | #s1 = plt.scatter(x=Y[:,0][label<2],y=Y[:,1][label<2],color='r',marker='o',alpha=0.9) 56 | 57 | ax.set_xscale("log") 58 | ax.set_yscale("log") 59 | #s2 = plt.scatter(x=Y[:,0][label>1],y=Y[:,1][label>1],color='b',marker='+',alpha=0.9) 60 | 61 | plt.xticks([1e0,1e3,1e6],fontsize=45) 62 | plt.yticks([1e0,1e2,1e4,1e6],fontsize=45) 63 | #plt.legend(fontsize=32) 64 | 65 | 66 | plt.legend((s1,s2),('w/o BN','original graph'),handletextpad=0 ,loc = 'best',fontsize=40,markerscale=2) 67 | plt.xlabel('Size of components',fontsize=50) 68 | plt.ylabel('Count',fontsize=50) 69 | plt.ylim(0,1e7) 70 | #plt.xlim(-0.1,1) 71 | plt.tight_layout() 72 | plt.savefig('./figure6.pdf',bbox_inches='tight', format='pdf') 73 | 74 | #plt.show() -------------------------------------------------------------------------------- /analysis/fig3c.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | 4 | import sys 5 | sys.path.append("..") 6 | from load_dataset import build_tg_data 7 | import pandas as pd 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import seaborn as sns 12 | import torch 13 | import networkx as nx 14 | from sklearn.decomposition import PCA 15 | from sklearn.manifold import TSNE 16 | import torch_geometric as tg 17 | from torch_geometric.nn.conv import MessagePassing 18 | from typing import Callable, Optional, Union 19 | from torch import Tensor 20 | from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size 21 | from torch_sparse import SparseTensor, matmul 22 | if __name__ == "__main__": 23 | 24 | class MyConv(MessagePassing): 25 | 26 | def __init__(self,**kwargs): 27 | kwargs.setdefault('aggr', 'add') 28 | super().__init__(**kwargs) 29 | 30 | def reset_parameters(self): 31 | reset(self.nn) 32 | self.eps.data.fill_(self.initial_eps) 33 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, 34 | size: Size = None) -> Tensor: 35 | if isinstance(x, Tensor): 36 | x: OptPairTensor = (x, x) 37 | 38 | # propagate_type: (x: OptPairTensor) 39 | out = self.propagate(edge_index, x=x, size=size) 40 | return out 41 | 42 | 43 | def message(self, x_j: Tensor) -> Tensor: 44 | return x_j 45 | 46 | def message_and_aggregate(self, adj_t: SparseTensor, 47 | x: OptPairTensor) -> Tensor: 48 | adj_t = adj_t.set_value(None, layout=None) 49 | return matmul(adj_t, x[0], reduce=self.aggr) 50 | 51 | def __repr__(self) -> str: 52 | return f'{self.__class__.__name__}(nn={self.nn})' 53 | 54 | datapath = '../dataset/DGraphFin/raw/dgraphfin.npz' 55 | origin_data = np.load(datapath) 56 | data = build_tg_data(is_undirected=False) 57 | 58 | 59 | conv=MyConv() 60 | 61 | flag = (data.y[data.edge_index[0,:]]<2)&(data.y[data.edge_index[1,:]]<2) 62 | #flag1 = (data.y[data.edge_index[0,:]]==0)&(data.y[data.edge_index[1,:]]<2) 63 | #lag = (flag|flag1) 64 | edge_index = data.edge_index[:,:] 65 | y = torch.zeros(data.x.shape[0],3) 66 | y[:,1]+=(data.y==1) 67 | y[:,0]+=(data.y==0) 68 | y[:,2]+=(data.y>1) 69 | y1 = conv(y,edge_index[:,:]) 70 | #y2 = conv(y1,edge_index[:,flag1]) 71 | #count_edge_homo(y.T.matmul(y1)) 72 | in_ratio=(y1[:,2]/(y1.sum(dim=1)+1e-4)) 73 | 74 | 75 | flag = (data.y[data.edge_index[0,:]]<2)&(data.y[data.edge_index[1,:]]<2) 76 | #flag1 = (data.y[data.edge_index[0,:]]==0)&(data.y[data.edge_index[1,:]]<2) 77 | #lag = (flag|flag1) 78 | edge_index = data.edge_index[:,:] 79 | y = torch.zeros(data.x.shape[0],3) 80 | y[:,1]+=(data.y==1) 81 | y[:,0]+=(data.y==0) 82 | y[:,2]+=(data.y>1) 83 | y1 = conv(y,edge_index[[1,0],:]) 84 | #y2 = conv(y1,edge_index[:,flag1]) 85 | #count_edge_homo(y.T.matmul(y1)) 86 | out_ratio=(y1[:,2]/(y1.sum(dim=1)+1e-4)) 87 | 88 | 89 | y = list(in_ratio[data.y==1].numpy())+list(in_ratio[data.y==0].numpy())+list(out_ratio[data.y==1].numpy())+list(out_ratio[data.y==0].numpy()) 90 | x = ['In-Neighbors']*((data.y<=1).sum())+['Out-Neighbors']*((data.y<=1).sum()) 91 | label = ['Fraudsters']*((data.y==1).sum())+['Normal users']*((data.y==0).sum())+['Fraudsters']*((data.y==1).sum())+['Normal users']*((data.y==0).sum()) 92 | plot_data=pd.DataFrame() 93 | plot_data['y']=y 94 | plot_data['x']=x 95 | plot_data['label']=label 96 | 97 | sns.set_color_codes("pastel") 98 | plt.rc('font', family='Times New Roman') 99 | 100 | pic_id = 2 101 | 102 | plt.figure(figsize=(10, 8)) 103 | plt.xticks(fontsize=45) 104 | plt.yticks([0,0.5,1.0,1.5,2.0,2.5],fontsize=45) 105 | 106 | ax = sns.barplot(x="x", y="y", hue="label", 107 | data=plot_data, palette=['r','b'],capsize=0.02,errwidth=1.5,linewidth=1.0,edgecolor=".2") 108 | 109 | #plt.ylim([0, 10]) 110 | plt.xlabel(' ',fontsize=50) 111 | plt.ylabel('Ratio of BN',fontsize=50) 112 | 113 | plt.ylim(0,0.8) 114 | plt.legend(loc = 'best',fontsize=40,markerscale=2) 115 | plt.tight_layout() 116 | plt.savefig('./figure7.pdf',bbox_inches='tight', format='pdf') 117 | 118 | #plt.show() -------------------------------------------------------------------------------- /analysis/fig3d.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | 4 | import sys 5 | sys.path.append("..") 6 | from load_dataset import build_tg_data 7 | import pandas as pd 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import seaborn as sns 12 | import torch 13 | import networkx as nx 14 | from sklearn.decomposition import PCA 15 | from sklearn.manifold import TSNE 16 | import torch_geometric as tg 17 | from torch_geometric.nn.conv import MessagePassing 18 | from typing import Callable, Optional, Union 19 | from torch import Tensor 20 | from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size 21 | from torch_sparse import SparseTensor, matmul 22 | if __name__ == "__main__": 23 | 24 | class MyConv(MessagePassing): 25 | 26 | def __init__(self,**kwargs): 27 | kwargs.setdefault('aggr', 'add') 28 | super().__init__(**kwargs) 29 | 30 | def reset_parameters(self): 31 | reset(self.nn) 32 | self.eps.data.fill_(self.initial_eps) 33 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, 34 | size: Size = None) -> Tensor: 35 | if isinstance(x, Tensor): 36 | x: OptPairTensor = (x, x) 37 | 38 | # propagate_type: (x: OptPairTensor) 39 | out = self.propagate(edge_index, x=x, size=size) 40 | return out 41 | 42 | 43 | def message(self, x_j: Tensor) -> Tensor: 44 | return x_j 45 | 46 | def message_and_aggregate(self, adj_t: SparseTensor, 47 | x: OptPairTensor) -> Tensor: 48 | adj_t = adj_t.set_value(None, layout=None) 49 | return matmul(adj_t, x[0], reduce=self.aggr) 50 | 51 | def __repr__(self) -> str: 52 | return f'{self.__class__.__name__}(nn={self.nn})' 53 | 54 | datapath = '../dataset/DGraphFin/raw/dgraphfin.npz' 55 | origin_data = np.load(datapath) 56 | data = build_tg_data(is_undirected=True) 57 | 58 | 59 | conv=MyConv() 60 | 61 | flag = (data.y[data.edge_index[0,:]]<5)&(data.y[data.edge_index[1,:]]>=2) 62 | flag1 = (data.y[data.edge_index[0,:]]>=2)&(data.y[data.edge_index[1,:]]<5) 63 | #lag = (flag|flag1) 64 | edge_index = data.edge_index[:,:] 65 | y = torch.zeros(data.x.shape[0],3) 66 | y[:,2]+=(data.y>1) 67 | y[:,1]+=(data.y==1) 68 | y[:,0]+=(data.y==0) 69 | y1 = conv(y,edge_index[:,flag]) 70 | y11 = conv(y,edge_index[:,flag][[1,0],:]) 71 | y2 = conv(y1,edge_index[:,flag1])-y*y11.sum(dim=1).view(-1,1) 72 | 73 | ans=0 74 | for i in range(3): 75 | ans += ((y2[data.y==i,i]).sum()*1.0/(y2[data.y==i].sum()+1e-5))-(data.y==i).sum()*1.0/len(data.y) 76 | print(ans) 77 | print('XBX h',ans/2) 78 | 79 | flag = (data.y[data.edge_index[0,:]]<5)&(data.y[data.edge_index[1,:]]<=1) 80 | flag1 = (data.y[data.edge_index[0,:]]<=1)&(data.y[data.edge_index[1,:]]<5) 81 | #lag = (flag|flag1) 82 | edge_index = data.edge_index[:,:] 83 | y = torch.zeros(data.x.shape[0],3) 84 | y[:,2]+=(data.y>1) 85 | y[:,1]+=(data.y==1) 86 | y[:,0]+=(data.y==0) 87 | y1 = conv(y,edge_index[:,flag]) 88 | y11 = conv(y,edge_index[:,flag][[1,0],:]) 89 | y2 = conv(y1,edge_index[:,flag1])-y*y11.sum(dim=1).view(-1,1) 90 | 91 | ans=0 92 | for i in range(3): 93 | ans += ((y2[data.y==i,i]).sum()*1.0/(y2[data.y==i].sum()+1e-5))-(data.y==i).sum()*1.0/len(data.y) 94 | print(ans) 95 | print('XTX h',ans/2) 96 | 97 | edge_index = data.edge_index[:,:] 98 | y = torch.zeros(data.x.shape[0],3) 99 | y[:,2]+=(data.y>1) 100 | y[:,1]+=(data.y==1) 101 | y[:,0]+=(data.y==0) 102 | y1 = conv(y,edge_index) 103 | y2=y1 104 | 105 | ans=0 106 | for i in range(3): 107 | ans += ((y2[data.y==i,i]).sum()*1.0/(y2[data.y==i].sum()+1e-5))-(data.y==i).sum()*1.0/len(data.y) 108 | print(ans) 109 | print('XX h',ans/2) 110 | 111 | 112 | sns.set_color_codes("pastel") 113 | plt.rc('font', family='Times New Roman') 114 | 115 | pic_id = 2 116 | 117 | plt.figure(figsize=(10, 8)) 118 | plt.xticks(fontsize=45) 119 | plt.yticks([0,0.2,0.4,0.6],fontsize=45) 120 | 121 | ax = plt.scatter(x=["$\\times$B$\\times$",'$\\times$T$\\times$','$\\times\\times$'], y=[0.2225,0.1460,0.1285], color='black',marker='+',edgecolors='black',s=500,linewidths=30) 122 | ax = plt.scatter(x=["$\\times$B$\\times$",'$\\times$T$\\timse$','$\\times\\times$'], y=[0.2225,0.1460,0.1285], color='black',marker='.',edgecolors='black',s=200,linewidths=None) 123 | 124 | ax = plt.scatter(x=['$\\times\\times$'], y=[0.011], color='r',marker='x',edgecolors='r',s=500,linewidths=30) 125 | ax = plt.scatter(x=['$\\times\\times$'], y=[0.011], color='r',marker='.',edgecolors='r',s=100,linewidths=1) 126 | 127 | ax = plt.scatter(x=['$\\times\\times$'], y=[0.416], color='r',marker='x',edgecolors='r',s=500,linewidths=30) 128 | ax = plt.scatter(x=['$\\times\\times$'], y=[0.416], color='r',marker='.',edgecolors='r',s=100,linewidths=1) 129 | 130 | plt.xlabel('Edge type',fontsize=50) 131 | plt.ylabel('$h$',fontsize=50) 132 | plt.ylim(0,0.5) 133 | plt.xlim(-0.5,2.5) 134 | plt.tight_layout() 135 | plt.savefig('./figure8.pdf',bbox_inches='tight', format='pdf') 136 | 137 | #plt.show() -------------------------------------------------------------------------------- /analysis/load_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch_geometric as tg 4 | 5 | datapath = '../dataset/DGraphFin/raw/dgraphfin.npz' 6 | 7 | def build_tg_data(is_undirected=True,datapath=None): 8 | origin_data = np.load(datapath) 9 | data = tg.data.Data() 10 | data.x = torch.tensor(origin_data['x']).float() 11 | data.y = torch.tensor(origin_data['y']).long() 12 | data.edge_index = torch.tensor(origin_data['edge_index']).long().T 13 | data.train_mask = torch.tensor(origin_data['train_mask']).long() 14 | data.val_mask = torch.tensor(origin_data['valid_mask']).long() 15 | data.test_mask = torch.tensor(origin_data['test_mask']).long() 16 | data.edge_time = torch.tensor(origin_data['edge_timestamp']).long() 17 | if(is_undirected): 18 | data.edge_index = tg.utils.to_undirected(data.edge_index) 19 | return data 20 | -------------------------------------------------------------------------------- /gnn.py: -------------------------------------------------------------------------------- 1 | # dataset name: DGraphFin 2 | 3 | from utils import DGraphFin 4 | from utils.utils import prepare_folder 5 | from utils.evaluator import Evaluator 6 | from utils.tricks import Missvalues 7 | from utils.tricks import Background 8 | from utils.tricks import Structure 9 | from models import MLP, MLPLinear, GCN, SAGE, GAT, GATv2,RGCN 10 | from logger import Logger 11 | 12 | import argparse 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | import torch.nn as nn 17 | 18 | import torch_geometric as tg 19 | import torch_geometric.transforms as T 20 | 21 | from torch_sparse import SparseTensor 22 | from torch_geometric.utils import to_undirected 23 | import pandas as pd 24 | import numpy as np 25 | import time 26 | eval_metric = 'auc' 27 | 28 | mlp_parameters = {'lr':0.01 29 | , 'num_layers':2 30 | , 'hidden_channels':64 31 | , 'dropout':0.0 32 | , 'batchnorm': False 33 | , 'l2':5e-7 34 | } 35 | 36 | gcn_parameters = {'lr':0.01 37 | , 'num_layers':2 38 | , 'hidden_channels':64 39 | , 'dropout':0.0 40 | , 'batchnorm': False 41 | , 'l2':5e-7 42 | } 43 | 44 | sage_parameters = {'lr':0.01 45 | , 'num_layers':2 46 | , 'hidden_channels':64 47 | , 'dropout':0 48 | , 'batchnorm': False 49 | , 'l2':5e-7 50 | } 51 | 52 | 53 | def train(model, data, train_idx, optimizer, weight=None, no_conv=False,is_rgcn=False): 54 | # data.y is labels of shape (N, ) 55 | model.train() 56 | 57 | optimizer.zero_grad() 58 | if no_conv: 59 | out = model(data.x[train_idx]) 60 | else: 61 | if(is_rgcn): 62 | out = model(data.x, data.edge_index, data.edge_type)[train_idx] 63 | else: 64 | out = model(data.x, data.edge_index)[train_idx] 65 | loss = F.nll_loss(out, data.y[train_idx],weight = weight) 66 | loss.backward() 67 | optimizer.step() 68 | 69 | return loss.item() 70 | 71 | 72 | @torch.no_grad() 73 | def test(model, data, split_idx, evaluator, no_conv=False,is_rgcn=True): 74 | # data.y is labels of shape (N, ) 75 | model.eval() 76 | 77 | if no_conv: 78 | out = model(data.x) 79 | else: 80 | if(is_rgcn): 81 | out = model(data.x, data.edge_index, data.edge_type) 82 | else: 83 | out = model(data.x, data.edge_index) 84 | 85 | y_pred = out.exp() # (N,num_classes) 86 | 87 | losses, eval_results = dict(), dict() 88 | for key in ['train', 'valid', 'test']: 89 | node_id = split_idx[key] 90 | losses[key] = F.nll_loss(out[node_id], data.y[node_id]).item() 91 | eval_results[key] = evaluator.eval(data.y[node_id], y_pred[node_id]) 92 | 93 | return eval_results, losses, y_pred 94 | 95 | 96 | def main(): 97 | parser = argparse.ArgumentParser(description='gnn_models') 98 | parser.add_argument('--device', type=int, default=0) 99 | parser.add_argument('--dataset', type=str, default='DGraphFin') 100 | parser.add_argument('--log_steps', type=int, default=10) 101 | parser.add_argument('--model', type=str, default='mlp') 102 | parser.add_argument('--use_embeddings', action='store_true') 103 | parser.add_argument('--epochs', type=int, default=1000) 104 | parser.add_argument('--runs', type=int, default=5) 105 | parser.add_argument('--fold', type=int, default=0) 106 | parser.add_argument('--MV_trick', type=str, default='null') 107 | parser.add_argument('--BN_trick', type=str, default='null') 108 | parser.add_argument('--BN_ratio', type=float, default=0.1) 109 | parser.add_argument('--Structure', type=str, default='original') 110 | args = parser.parse_args() 111 | print(args) 112 | 113 | no_conv = False 114 | if args.model in ['mlp']: no_conv = True 115 | 116 | device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu' 117 | device = torch.device(device) 118 | # device = torch.device('cpu') 119 | dataset = DGraphFin(root='./dataset/', name=args.dataset, transform=T.ToSparseTensor()) 120 | 121 | nlabels = dataset.num_classes 122 | if args.dataset in ['DGraphFin']: nlabels = 2 123 | 124 | data = dataset[0] 125 | data.edge_index = data.adj_t 126 | data.adj_t = torch.cat([data.edge_index.coo()[0].view(1,-1),data.edge_index.coo()[1].view(1,-1)],dim=0) 127 | data.edge_index = data.adj_t 128 | structure = Structure(args.Structure) 129 | data = structure.process(data) 130 | data.adj_t = data.edge_index 131 | data.adj_t = tg.utils.to_undirected(data.adj_t) 132 | data.edge_index = data.adj_t 133 | if args.dataset in ['DGraphFin']: 134 | x = data.x 135 | x = (x-x.mean(0))/x.std(0) 136 | data.x = x 137 | if data.y.dim()==2: 138 | data.y = data.y.squeeze(1) 139 | 140 | 141 | print(data) 142 | 143 | 144 | missvalues = Missvalues(args.MV_trick) 145 | data = missvalues.process(data) 146 | 147 | #print(data.edge_index) 148 | 149 | data.edge_index = data.adj_t 150 | BN = Background(args.BN_trick) 151 | data = BN.process(data,args.BN_ratio) 152 | 153 | split_idx = {'train':data.train_mask, 'valid':data.valid_mask, 'test':data.test_mask} 154 | 155 | fold = args.fold 156 | if split_idx['train'].dim()>1 and split_idx['train'].shape[1] >1: 157 | kfolds = True 158 | print('There are {} folds of splits'.format(split_idx['train'].shape[1])) 159 | split_idx['train'] = split_idx['train'][:, fold] 160 | split_idx['valid'] = split_idx['valid'][:, fold] 161 | split_idx['test'] = split_idx['test'][:, fold] 162 | else: 163 | kfolds = False 164 | 165 | split_idx = {'train':data.train_mask, 'valid':data.valid_mask, 'test':data.test_mask} 166 | 167 | data = data.to(device) 168 | train_idx = split_idx['train'].to(device) 169 | 170 | result_dir = prepare_folder(args.dataset, args.model) 171 | print('result_dir:', result_dir) 172 | 173 | is_rgcn=False 174 | if args.model == 'mlp': 175 | para_dict = mlp_parameters 176 | model_para = mlp_parameters.copy() 177 | model_para.pop('lr') 178 | model_para.pop('l2') 179 | model = MLP(in_channels = data.x.size(-1), out_channels = nlabels, **model_para).to(device) 180 | if args.model == 'gcn': 181 | para_dict = gcn_parameters 182 | model_para = gcn_parameters.copy() 183 | model_para.pop('lr') 184 | model_para.pop('l2') 185 | model = GCN(in_channels = data.x.size(-1), out_channels = nlabels, **model_para).to(device) 186 | if args.model == 'sage': 187 | para_dict = sage_parameters 188 | model_para = sage_parameters.copy() 189 | model_para.pop('lr') 190 | model_para.pop('l2') 191 | model = SAGE(in_channels = data.x.size(-1), out_channels = nlabels, **model_para).to(device) 192 | if args.model == 'rgcn': 193 | para_dict = gcn_parameters 194 | model = RGCN(data.x.size(-1),16,2,4).to(device) 195 | is_rgcn=True 196 | print(f'Model {args.model} initialized') 197 | 198 | evaluator = Evaluator(eval_metric) 199 | logger = Logger(args.runs, args) 200 | weight = torch.tensor([1,50]).to(device).float() 201 | 202 | for run in range(args.runs): 203 | import gc 204 | gc.collect() 205 | print(sum(p.numel() for p in model.parameters())) 206 | 207 | model.reset_parameters() 208 | optimizer = torch.optim.Adam(model.parameters(), lr=para_dict['lr'], weight_decay=para_dict['l2']) 209 | best_valid = 0 210 | min_valid_loss = 1e8 211 | best_out = None 212 | 213 | time_ls = [] 214 | starttime = time.time() 215 | for epoch in range(1, args.epochs+1): 216 | starttime = time.time() 217 | loss = train(model, data, train_idx, optimizer, weight,no_conv,is_rgcn) 218 | 219 | endtime = time.time() 220 | time_ls.append(endtime-starttime) 221 | eval_results, losses, out = test(model, data, split_idx, evaluator,no_conv,is_rgcn) 222 | train_auc, valid_auc, test_auc = eval_results['train']['auc'], eval_results['valid']['auc'], eval_results['test']['auc'] 223 | train_ap, valid_ap, test_ap = eval_results['train']['ap'], eval_results['valid']['ap'], eval_results['test']['ap'] 224 | 225 | train_loss, valid_loss, test_loss = losses['train'], losses['valid'], losses['test'] 226 | #print(eval_results['train']) 227 | # if valid_eval > best_valid: 228 | # best_valid = valid_result 229 | # best_out = out.cpu().exp() 230 | 231 | if valid_loss < min_valid_loss: 232 | min_valid_loss = valid_loss 233 | best_out = out.cpu() 234 | 235 | 236 | if epoch % args.log_steps == 0: 237 | print(f'Run: {run + 1:02d}, ' 238 | f'Epoch: {epoch:02d}, ' 239 | f'Loss: {loss:.4f}, ' 240 | f'Train AUC: {train_auc:.3f} ' 241 | f'Train AP: {train_ap:.3f} ' 242 | f'Valid AUC: {valid_auc:.3f} ' 243 | f'Valid AP: {valid_ap:.3f} ' 244 | f'Test AUC: { test_auc:.3f} ' 245 | f'Test AP: { test_ap:.3f} ' 246 | f'Train time(s): {np.mean(time_ls):.3f}') 247 | 248 | 249 | time_ls = [] 250 | logger.add_result(run, [train_auc, valid_auc, test_auc]) 251 | 252 | logger.print_statistics(run) 253 | 254 | final_results = logger.print_statistics() 255 | print('final_results:', final_results) 256 | para_dict.update(final_results) 257 | pd.DataFrame(para_dict, index=[args.model]).to_csv(result_dir+'/results.csv') 258 | 259 | 260 | if __name__ == "__main__": 261 | main() 262 | -------------------------------------------------------------------------------- /gnn_mini_batch.py: -------------------------------------------------------------------------------- 1 | # dataset name: DGraphFin 2 | 3 | from utils import DGraphFin 4 | from utils.utils import prepare_folder 5 | from utils.evaluator import Evaluator 6 | from torch_geometric.data import NeighborSampler 7 | from models import SAGE_NeighSampler, GAT_NeighSampler, GATv2_NeighSampler 8 | from logger import Logger 9 | from tqdm import tqdm 10 | 11 | import argparse 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.nn as nn 16 | 17 | import torch_geometric.transforms as T 18 | from torch_sparse import SparseTensor 19 | from torch_geometric.utils import to_undirected 20 | import pandas as pd 21 | 22 | eval_metric = 'auc' 23 | 24 | sage_neighsampler_parameters = {'lr':0.003 25 | , 'num_layers':2 26 | , 'hidden_channels':128 27 | , 'dropout':0.0 28 | , 'batchnorm': False 29 | , 'l2':5e-7 30 | } 31 | 32 | gat_neighsampler_parameters = {'lr':0.003 33 | , 'num_layers':2 34 | , 'hidden_channels':128 35 | , 'dropout':0.0 36 | , 'batchnorm': False 37 | , 'l2':5e-7 38 | , 'layer_heads':[4,1] 39 | } 40 | 41 | gatv2_neighsampler_parameters = {'lr':0.003 42 | , 'num_layers':2 43 | , 'hidden_channels':128 44 | , 'dropout':0.0 45 | , 'batchnorm': False 46 | , 'l2':5e-6 47 | , 'layer_heads':[4,1] 48 | } 49 | 50 | 51 | def train(epoch, train_loader, model, data, train_idx, optimizer, device, no_conv=False): 52 | model.train() 53 | 54 | pbar = tqdm(total=train_idx.size(0), ncols=80) 55 | pbar.set_description(f'Epoch {epoch:02d}') 56 | 57 | total_loss = total_correct = 0 58 | for batch_size, n_id, adjs in train_loader: 59 | # `adjs` holds a list of `(edge_index, e_id, size)` tuples. 60 | adjs = [adj.to(device) for adj in adjs] 61 | 62 | optimizer.zero_grad() 63 | out = model(data.x[n_id], adjs) 64 | loss = F.nll_loss(out, data.y[n_id[:batch_size]]) 65 | loss.backward() 66 | optimizer.step() 67 | 68 | total_loss += float(loss) 69 | pbar.update(batch_size) 70 | 71 | pbar.close() 72 | loss = total_loss / len(train_loader) 73 | 74 | return loss 75 | 76 | 77 | @torch.no_grad() 78 | def test(layer_loader, model, data, split_idx, evaluator, device, no_conv=False): 79 | # data.y is labels of shape (N, ) 80 | model.eval() 81 | 82 | out = model.inference(data.x, layer_loader, device) 83 | # out = model.inference_all(data) 84 | y_pred = out.exp() # (N,num_classes) 85 | 86 | losses, eval_results = dict(), dict() 87 | for key in ['train', 'valid', 'test']: 88 | node_id = split_idx[key] 89 | node_id = node_id.to(device) 90 | losses[key] = F.nll_loss(out[node_id], data.y[node_id]).item() 91 | eval_results[key] = evaluator.eval(data.y[node_id], y_pred[node_id])[eval_metric] 92 | 93 | return eval_results, losses, y_pred 94 | 95 | 96 | def main(): 97 | parser = argparse.ArgumentParser(description='minibatch_gnn_models') 98 | parser.add_argument('--device', type=int, default=0) 99 | parser.add_argument('--dataset', type=str, default='DGraphFin') 100 | parser.add_argument('--log_steps', type=int, default=10) 101 | parser.add_argument('--model', type=str, default='mlp') 102 | parser.add_argument('--use_embeddings', action='store_true') 103 | parser.add_argument('--epochs', type=int, default=100) 104 | parser.add_argument('--runs', type=int, default=10) 105 | parser.add_argument('--fold', type=int, default=0) 106 | 107 | args = parser.parse_args() 108 | print(args) 109 | 110 | no_conv = False 111 | if args.model in ['mlp']: no_conv = True 112 | 113 | device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu' 114 | device = torch.device(device) 115 | 116 | dataset = DGraphFin(root='./dataset/', name=args.dataset, transform=T.ToSparseTensor()) 117 | 118 | nlabels = dataset.num_classes 119 | if args.dataset =='DGraphFin': nlabels = 2 120 | 121 | data = dataset[0] 122 | data.adj_t = data.adj_t.to_symmetric() 123 | 124 | if args.dataset in ['DGraphFin']: 125 | x = data.x 126 | x = (x-x.mean(0))/x.std(0) 127 | data.x = x 128 | if data.y.dim()==2: 129 | data.y = data.y.squeeze(1) 130 | 131 | split_idx = {'train':data.train_mask, 'valid':data.valid_mask, 'test':data.test_mask} 132 | 133 | fold = args.fold 134 | if split_idx['train'].dim()>1 and split_idx['train'].shape[1] >1: 135 | kfolds = True 136 | print('There are {} folds of splits'.format(split_idx['train'].shape[1])) 137 | split_idx['train'] = split_idx['train'][:, fold] 138 | split_idx['valid'] = split_idx['valid'][:, fold] 139 | split_idx['test'] = split_idx['test'][:, fold] 140 | else: 141 | kfolds = False 142 | 143 | data = data.to(device) 144 | train_idx = split_idx['train'].to(device) 145 | 146 | result_dir = prepare_folder(args.dataset, args.model) 147 | print('result_dir:', result_dir) 148 | 149 | train_loader = NeighborSampler(data.adj_t, node_idx=train_idx, sizes=[10, 5], batch_size=1024, shuffle=True, num_workers=12) 150 | layer_loader = NeighborSampler(data.adj_t, node_idx=None, sizes=[-1], batch_size=4096, shuffle=False, num_workers=12) 151 | 152 | if args.model == 'sage_neighsampler': 153 | para_dict = sage_neighsampler_parameters 154 | model_para = sage_neighsampler_parameters.copy() 155 | model_para.pop('lr') 156 | model_para.pop('l2') 157 | model = SAGE_NeighSampler(in_channels = data.x.size(-1), out_channels = nlabels, **model_para).to(device) 158 | if args.model == 'gat_neighsampler': 159 | para_dict = gat_neighsampler_parameters 160 | model_para = gat_neighsampler_parameters.copy() 161 | model_para.pop('lr') 162 | model_para.pop('l2') 163 | model = GAT_NeighSampler(in_channels = data.x.size(-1), out_channels = nlabels, **model_para).to(device) 164 | if args.model == 'gatv2_neighsampler': 165 | para_dict = gatv2_neighsampler_parameters 166 | model_para = gatv2_neighsampler_parameters.copy() 167 | model_para.pop('lr') 168 | model_para.pop('l2') 169 | model = GATv2_NeighSampler(in_channels = data.x.size(-1), out_channels = nlabels, **model_para).to(device) 170 | 171 | print(f'Model {args.model} initialized') 172 | 173 | evaluator = Evaluator(eval_metric) 174 | logger = Logger(args.runs, args) 175 | 176 | for run in range(args.runs): 177 | import gc 178 | gc.collect() 179 | print(sum(p.numel() for p in model.parameters())) 180 | 181 | model.reset_parameters() 182 | optimizer = torch.optim.Adam(model.parameters(), lr=para_dict['lr'], weight_decay=para_dict['l2']) 183 | best_valid = 0 184 | min_valid_loss = 1e8 185 | best_out = None 186 | 187 | for epoch in range(1, args.epochs+1): 188 | loss = train(epoch, train_loader, model, data, train_idx, optimizer, device, no_conv) 189 | eval_results, losses, out = test(layer_loader, model, data, split_idx, evaluator, device, no_conv) 190 | train_eval, valid_eval, test_eval = eval_results['train'], eval_results['valid'], eval_results['test'] 191 | train_loss, valid_loss, test_loss = losses['train'], losses['valid'], losses['test'] 192 | 193 | # if valid_eval > best_valid: 194 | # best_valid = valid_result 195 | # best_out = out.cpu().exp() 196 | if valid_loss < min_valid_loss: 197 | min_valid_loss = valid_loss 198 | best_out = out.cpu() 199 | 200 | if epoch % args.log_steps == 0: 201 | print(f'Run: {run + 1:02d}, ' 202 | f'Epoch: {epoch:02d}, ' 203 | f'Loss: {loss:.4f}, ' 204 | f'Train: {100 * train_eval:.3f}%, ' 205 | f'Valid: {100 * valid_eval:.3f}% ' 206 | f'Test: {100 * test_eval:.3f}%') 207 | logger.add_result(run, [train_eval, valid_eval, test_eval]) 208 | 209 | logger.print_statistics(run) 210 | 211 | final_results = logger.print_statistics() 212 | print('final_results:', final_results) 213 | para_dict.update(final_results) 214 | for k, v in para_dict.items(): 215 | if type(v) is list: para_dict.update({k:str(v)}) 216 | pd.DataFrame(para_dict, index=[args.model]).to_csv(result_dir+'/results.csv') 217 | 218 | 219 | if __name__ == "__main__": 220 | main() 221 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Logger(object): 4 | def __init__(self, runs, info=None): 5 | self.info = info 6 | self.results = [[] for _ in range(runs)] 7 | 8 | def add_result(self, run, result): 9 | assert len(result) == 3 10 | assert run >= 0 and run < len(self.results) 11 | self.results[run].append(result) 12 | 13 | def print_statistics(self, run=None): 14 | if run is not None: 15 | result = 100 * torch.tensor(self.results[run]) 16 | argmax = result[:, 1].argmax().item() 17 | print(f'Run {run + 1:02d}:') 18 | print(f'Highest Train: {result[:, 0].max():.2f}') 19 | print(f'Highest Valid: {result[:, 1].max():.2f}') 20 | print(f' Final Train: {result[argmax, 0]:.2f}') 21 | print(f' Final Test: {result[argmax, 2]:.2f}') 22 | else: 23 | result = 100 * torch.tensor(self.results) 24 | 25 | best_results = [] 26 | for r in result: 27 | train1 = r[:, 0].max().item() 28 | valid = r[:, 1].max().item() 29 | train2 = r[r[:, 1].argmax(), 0].item() 30 | test = r[r[:, 1].argmax(), 2].item() 31 | best_results.append((train1, valid, train2, test)) 32 | 33 | best_result = torch.tensor(best_results) 34 | 35 | print(f'All runs:') 36 | r = best_result[:, 0] 37 | highest_train, highest_train_std = r.mean().item(), r.std().item() 38 | print(f'Highest Train: {r.mean():.4f} ± {r.std():.4f}') 39 | r = best_result[:, 1] 40 | highest_valid, highest_valid_std = r.mean().item(), r.std().item() 41 | print(f'Highest Valid: {r.mean():.4f} ± {r.std():.4f}') 42 | r = best_result[:, 2] 43 | final_train, final_train_std = r.mean().item(), r.std().item() 44 | print(f' Final Train: {r.mean():.4f} ± {r.std():.4f}') 45 | r = best_result[:, 3] 46 | final_test, final_test_std = r.mean().item(), r.std().item() 47 | print(f' Final Test: {r.mean():.4f} ± {r.std():.4f}') 48 | 49 | return {'train': round(final_train, 4) 50 | , 'train_std': round(final_train_std, 4) 51 | , 'valid': round(highest_valid, 4) 52 | , 'valid_std': round(highest_valid_std, 4) 53 | , 'test': round(final_test, 4) 54 | , 'test_std': round(final_test_std, 4) 55 | } 56 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import MLP, MLPLinear 2 | from .gcn import GCN 3 | from .sage import SAGE 4 | from .sage_neighsampler import SAGE_NeighSampler 5 | from .gat import GAT, GATv2 6 | from .gat_neighsampler import GAT_NeighSampler, GATv2_NeighSampler 7 | from .rgcn import RGCN -------------------------------------------------------------------------------- /models/gat.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from torch import Tensor 4 | from torch_sparse import SparseTensor 5 | import torch 6 | import torch.nn.functional as F 7 | from torch_geometric.nn import GATConv, GATv2Conv 8 | 9 | class GAT(torch.nn.Module): 10 | def __init__(self 11 | , in_channels 12 | , hidden_channels 13 | , out_channels 14 | , num_layers 15 | , dropout 16 | , layer_heads = [] 17 | , batchnorm=True): 18 | super(GAT, self).__init__() 19 | 20 | self.convs = torch.nn.ModuleList() 21 | self.convs.append(GATConv(in_channels, hidden_channels, heads=layer_heads[0], concat=True)) 22 | self.bns = torch.nn.ModuleList() 23 | self.batchnorm = batchnorm 24 | if self.batchnorm: 25 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels*layer_heads[0])) 26 | for _ in range(num_layers - 2): 27 | self.convs.append(GATConv(hidden_channels*layer_heads[i-1], hidden_channels, heads=layer_heads[i], concat=True)) 28 | if self.batchnorm: 29 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels*layer_heads[i-1])) 30 | self.convs.append(GATConv(hidden_channels*layer_heads[num_layers-2] 31 | , out_channels 32 | , heads=layer_heads[num_layers-1] 33 | , concat=False)) 34 | 35 | self.dropout = dropout 36 | 37 | def reset_parameters(self): 38 | for conv in self.convs: 39 | conv.reset_parameters() 40 | if self.batchnorm: 41 | for bn in self.bns: 42 | bn.reset_parameters() 43 | 44 | def forward(self, x, edge_index: Union[Tensor, SparseTensor]): 45 | for i, conv in enumerate(self.convs[:-1]): 46 | x = conv(x, edge_index) 47 | if self.batchnorm: 48 | x = self.bns[i](x) 49 | x = F.relu(x) 50 | x = F.dropout(x, p=self.dropout, training=self.training) 51 | x = self.convs[-1](x, edge_index) 52 | return x.log_softmax(dim=-1) 53 | 54 | 55 | 56 | 57 | class GATv2(torch.nn.Module): 58 | def __init__(self 59 | , in_channels 60 | , hidden_channels 61 | , out_channels 62 | , num_layers 63 | , dropout 64 | , layer_heads = [] 65 | , batchnorm=True): 66 | super(GATv2, self).__init__() 67 | 68 | self.convs = torch.nn.ModuleList() 69 | self.convs.append(GATv2Conv(in_channels, hidden_channels, heads=layer_heads[0], concat=True)) 70 | self.bns = torch.nn.ModuleList() 71 | self.batchnorm = batchnorm 72 | if self.batchnorm: 73 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels*layer_heads[0])) 74 | for _ in range(num_layers - 2): 75 | self.convs.append(GATv2Conv(hidden_channels*layer_heads[i-1], hidden_channels, heads=layer_heads[i], concat=True)) 76 | if self.batchnorm: 77 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels*layer_heads[i-1])) 78 | self.convs.append(GATv2Conv(hidden_channels*layer_heads[num_layers-2] 79 | , out_channels 80 | , heads=layer_heads[num_layers-1] 81 | , concat=False)) 82 | 83 | self.dropout = dropout 84 | 85 | def reset_parameters(self): 86 | for conv in self.convs: 87 | conv.reset_parameters() 88 | if self.batchnorm: 89 | for bn in self.bns: 90 | bn.reset_parameters() 91 | 92 | def forward(self, x, edge_index: Union[Tensor, SparseTensor]): 93 | for i, conv in enumerate(self.convs[:-1]): 94 | x = conv(x, edge_index) 95 | if self.batchnorm: 96 | x = self.bns[i](x) 97 | x = F.relu(x) 98 | x = F.dropout(x, p=self.dropout, training=self.training) 99 | x = self.convs[-1](x, edge_index) 100 | return x.log_softmax(dim=-1) -------------------------------------------------------------------------------- /models/gat_neighsampler.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from torch import Tensor 4 | from torch_sparse import SparseTensor 5 | import torch 6 | import torch.nn.functional as F 7 | from torch_geometric.nn import GATConv, GATv2Conv 8 | from tqdm import tqdm 9 | 10 | class GAT_NeighSampler(torch.nn.Module): 11 | def __init__(self 12 | , in_channels 13 | , hidden_channels 14 | , out_channels 15 | , num_layers 16 | , dropout 17 | , layer_heads = [] 18 | , batchnorm=True): 19 | super(GAT_NeighSampler, self).__init__() 20 | 21 | self.convs = torch.nn.ModuleList() 22 | self.batchnorm = batchnorm 23 | self.num_layers = num_layers 24 | 25 | if len(layer_heads)>1: 26 | self.convs.append(GATConv(in_channels, hidden_channels, heads=layer_heads[0], concat=True)) 27 | if self.batchnorm: 28 | self.bns = torch.nn.ModuleList() 29 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels*layer_heads[0])) 30 | for i in range(num_layers - 2): 31 | self.convs.append(GATConv(hidden_channels*layer_heads[i-1], hidden_channels, heads=layer_heads[i], concat=True)) 32 | if self.batchnorm: 33 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels*layer_heads[i-1])) 34 | self.convs.append(GATConv(hidden_channels*layer_heads[num_layers-2] 35 | , out_channels 36 | , heads=layer_heads[num_layers-1] 37 | , concat=False)) 38 | else: 39 | self.convs.append(GATConv(in_channels, out_channels, heads=layer_heads[0], concat=False)) 40 | 41 | self.dropout = dropout 42 | 43 | def reset_parameters(self): 44 | for conv in self.convs: 45 | conv.reset_parameters() 46 | if self.batchnorm: 47 | for bn in self.bns: 48 | bn.reset_parameters() 49 | 50 | 51 | def forward(self, x, adjs): 52 | for i, (edge_index, _, size) in enumerate(adjs): 53 | x_target = x[:size[1]] 54 | x = self.convs[i]((x, x_target), edge_index) 55 | if i != self.num_layers-1: 56 | if self.batchnorm: 57 | x = self.bns[i](x) 58 | x = F.relu(x) 59 | x = F.dropout(x, p=0.5, training=self.training) 60 | 61 | return x.log_softmax(dim=-1) 62 | 63 | ''' 64 | subgraph_loader: size = NeighborSampler(data.edge_index, node_idx=None, sizes=[-1], 65 | batch_size=**, shuffle=False, 66 | num_workers=12) 67 | You can also sample the complete k-hop neighborhood, but this is rather expensive (especially for Reddit). 68 | We apply here trick here to compute the node embeddings efficiently: 69 | Instead of sampling multiple layers for a mini-batch, we instead compute the node embeddings layer-wise. 70 | Doing this exactly k times mimics a k-layer GNN. 71 | ''' 72 | 73 | def inference_all(self, data): 74 | x, adj_t = data.x, data.adj_t 75 | for i, conv in enumerate(self.convs[:-1]): 76 | x = conv(x, adj_t) 77 | if self.batchnorm: 78 | x = self.bns[i](x) 79 | x = F.relu(x) 80 | x = F.dropout(x, p=self.dropout, training=self.training) 81 | x = self.convs[-1](x, adj_t) 82 | return x.log_softmax(dim=-1) 83 | 84 | def inference(self, x_all, layer_loader, device): 85 | pbar = tqdm(total=x_all.size(0) * self.num_layers, ncols=80) 86 | pbar.set_description('Evaluating') 87 | 88 | # Compute representations of nodes layer by layer, using *all* 89 | # available edges. This leads to faster computation in contrast to 90 | # immediately computing the final representations of each batch. 91 | for i in range(self.num_layers): 92 | xs = [] 93 | for batch_size, n_id, adj in layer_loader: 94 | edge_index, _, size = adj.to(device) 95 | x = x_all[n_id].to(device) 96 | x_target = x[:size[1]] 97 | x = self.convs[i]((x, x_target), edge_index) 98 | if i != self.num_layers - 1: 99 | x = F.relu(x) 100 | if self.batchnorm: 101 | x = self.bns[i](x) 102 | xs.append(x) 103 | 104 | pbar.update(batch_size) 105 | 106 | x_all = torch.cat(xs, dim=0) 107 | 108 | pbar.close() 109 | 110 | return x_all.log_softmax(dim=-1) 111 | 112 | 113 | 114 | class GATv2_NeighSampler(torch.nn.Module): 115 | def __init__(self 116 | , in_channels 117 | , hidden_channels 118 | , out_channels 119 | , num_layers 120 | , dropout 121 | , layer_heads = [] 122 | , batchnorm=True): 123 | super(GATv2_NeighSampler, self).__init__() 124 | 125 | self.convs = torch.nn.ModuleList() 126 | self.batchnorm = batchnorm 127 | self.num_layers = num_layers 128 | 129 | if len(layer_heads)>1: 130 | self.convs.append(GATv2Conv(in_channels, hidden_channels, heads=layer_heads[0], concat=True)) 131 | if self.batchnorm: 132 | self.bns = torch.nn.ModuleList() 133 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels*layer_heads[0])) 134 | for i in range(num_layers - 2): 135 | self.convs.append(GATv2Conv(hidden_channels*layer_heads[i-1], hidden_channels, heads=layer_heads[i], concat=True)) 136 | if self.batchnorm: 137 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels*layer_heads[i-1])) 138 | self.convs.append(GATv2Conv(hidden_channels*layer_heads[num_layers-2] 139 | , out_channels 140 | , heads=layer_heads[num_layers-1] 141 | , concat=False)) 142 | else: 143 | self.convs.append(GATv2Conv(in_channels, out_channels, heads=layer_heads[0], concat=False)) 144 | 145 | self.dropout = dropout 146 | 147 | def reset_parameters(self): 148 | for conv in self.convs: 149 | conv.reset_parameters() 150 | if self.batchnorm: 151 | for bn in self.bns: 152 | bn.reset_parameters() 153 | 154 | 155 | def forward(self, x, adjs): 156 | for i, (edge_index, _, size) in enumerate(adjs): 157 | x_target = x[:size[1]] 158 | x = self.convs[i]((x, x_target), edge_index) 159 | if i != self.num_layers-1: 160 | if self.batchnorm: 161 | x = self.bns[i](x) 162 | x = F.relu(x) 163 | x = F.dropout(x, p=0.5, training=self.training) 164 | 165 | return x.log_softmax(dim=-1) 166 | 167 | ''' 168 | subgraph_loader: size = NeighborSampler(data.edge_index, node_idx=None, sizes=[-1], 169 | batch_size=**, shuffle=False, 170 | num_workers=12) 171 | You can also sample the complete k-hop neighborhood, but this is rather expensive (especially for Reddit). 172 | We apply here trick here to compute the node embeddings efficiently: 173 | Instead of sampling multiple layers for a mini-batch, we instead compute the node embeddings layer-wise. 174 | Doing this exactly k times mimics a k-layer GNN. 175 | ''' 176 | 177 | def inference_all(self, data): 178 | x, adj_t = data.x, data.adj_t 179 | for i, conv in enumerate(self.convs[:-1]): 180 | x = conv(x, adj_t) 181 | if self.batchnorm: 182 | x = self.bns[i](x) 183 | x = F.relu(x) 184 | x = F.dropout(x, p=self.dropout, training=self.training) 185 | x = self.convs[-1](x, adj_t) 186 | return x.log_softmax(dim=-1) 187 | 188 | def inference(self, x_all, layer_loader, device): 189 | pbar = tqdm(total=x_all.size(0) * self.num_layers, ncols=80) 190 | pbar.set_description('Evaluating') 191 | 192 | # Compute representations of nodes layer by layer, using *all* 193 | # available edges. This leads to faster computation in contrast to 194 | # immediately computing the final representations of each batch. 195 | for i in range(self.num_layers): 196 | xs = [] 197 | for batch_size, n_id, adj in layer_loader: 198 | edge_index, _, size = adj.to(device) 199 | x = x_all[n_id].to(device) 200 | x_target = x[:size[1]] 201 | x = self.convs[i]((x, x_target), edge_index) 202 | if i != self.num_layers - 1: 203 | x = F.relu(x) 204 | if self.batchnorm: 205 | x = self.bns[i](x) 206 | xs.append(x) 207 | 208 | pbar.update(batch_size) 209 | 210 | x_all = torch.cat(xs, dim=0) 211 | 212 | pbar.close() 213 | 214 | return x_all.log_softmax(dim=-1) 215 | 216 | -------------------------------------------------------------------------------- /models/gcn.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from torch import Tensor 4 | from torch_sparse import SparseTensor 5 | import torch 6 | import torch.nn.functional as F 7 | from torch_geometric.nn import GCNConv 8 | 9 | ''' 10 | 此模型邻居矩阵采用sparse tensor的形式,可以大大减少计算量, 11 | 如果不使用sparse tensor形式传递,将adj_t替换成edge_index 12 | ''' 13 | class GCN(torch.nn.Module): 14 | def __init__(self 15 | , in_channels 16 | , hidden_channels 17 | , out_channels 18 | , num_layers 19 | , dropout 20 | , batchnorm=True): 21 | super(GCN, self).__init__() 22 | 23 | self.convs = torch.nn.ModuleList() 24 | self.convs.append(GCNConv(in_channels, hidden_channels, cached=True)) 25 | self.batchnorm = batchnorm 26 | if self.batchnorm: 27 | self.bns = torch.nn.ModuleList() 28 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 29 | for _ in range(num_layers - 2): 30 | self.convs.append( 31 | GCNConv(hidden_channels, hidden_channels, cached=True)) 32 | if self.batchnorm: 33 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 34 | self.convs.append(GCNConv(hidden_channels, out_channels, cached=True)) 35 | 36 | self.dropout = dropout 37 | 38 | def reset_parameters(self): 39 | for conv in self.convs: 40 | conv.reset_parameters() 41 | if self.batchnorm: 42 | for bn in self.bns: 43 | bn.reset_parameters() 44 | 45 | def forward(self, x, edge_index: Union[Tensor, SparseTensor]): 46 | for i, conv in enumerate(self.convs[:-1]): 47 | x = conv(x, edge_index) 48 | if self.batchnorm: 49 | x = self.bns[i](x) 50 | x = F.relu(x) 51 | x = F.dropout(x, p=self.dropout, training=self.training) 52 | x = self.convs[-1](x, edge_index) 53 | return x.log_softmax(dim=-1) -------------------------------------------------------------------------------- /models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | class MLP(torch.nn.Module): 5 | def __init__(self 6 | , in_channels 7 | , hidden_channels 8 | , out_channels 9 | , num_layers 10 | , dropout 11 | , batchnorm=True): 12 | super(MLP, self).__init__() 13 | self.lins = torch.nn.ModuleList() 14 | self.lins.append(torch.nn.Linear(in_channels, hidden_channels)) 15 | self.batchnorm = batchnorm 16 | if self.batchnorm: 17 | self.bns = torch.nn.ModuleList() 18 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 19 | for _ in range(num_layers - 2): 20 | self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels)) 21 | if self.batchnorm: 22 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 23 | self.lins.append(torch.nn.Linear(hidden_channels, out_channels)) 24 | 25 | self.dropout = dropout 26 | 27 | def reset_parameters(self): 28 | for lin in self.lins: 29 | lin.reset_parameters() 30 | if self.batchnorm: 31 | for bn in self.bns: 32 | bn.reset_parameters() 33 | 34 | def forward(self, x): 35 | for i, lin in enumerate(self.lins[:-1]): 36 | x = lin(x) 37 | if self.batchnorm: 38 | x = self.bns[i](x) 39 | x = F.relu(x) 40 | x = F.dropout(x, p=self.dropout, training=self.training) 41 | x = self.lins[-1](x) 42 | return F.log_softmax(x, dim=-1) 43 | 44 | 45 | 46 | class MLPLinear(torch.nn.Module): 47 | def __init__(self, in_channels, out_channels): 48 | super(MLPLinear, self).__init__() 49 | self.lin = torch.nn.Linear(in_channels, out_channels) 50 | 51 | def reset_parameters(self): 52 | self.lin.reset_parameters() 53 | 54 | def forward(self, x): 55 | return F.log_softmax(self.lin(x), dim=-1) 56 | -------------------------------------------------------------------------------- /models/rgcn.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.nn import RGCNConv 2 | from typing import Union 3 | 4 | from torch import Tensor 5 | from torch_sparse import SparseTensor 6 | import torch 7 | import torch.nn.functional as F 8 | from torch_geometric.nn import GCNConv 9 | 10 | class RGCN(torch.nn.Module): 11 | def __init__(self 12 | , in_channels 13 | , hidden_channels 14 | , out_channels 15 | , num_relations 16 | , batchnorm=True): 17 | super(RGCN, self).__init__() 18 | 19 | self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations) 20 | self.conv2 = RGCNConv(hidden_channels, hidden_channels, num_relations) 21 | self.lin = torch.nn.Linear(hidden_channels, out_channels) 22 | 23 | def reset_parameters(self): 24 | self.conv1.reset_parameters() 25 | self.conv2.reset_parameters() 26 | self.lin.reset_parameters() 27 | 28 | def forward(self, x, edge_index, edge_type): 29 | x = self.conv1(x, edge_index, edge_type).relu() 30 | x = self.conv2(x, edge_index, edge_type).relu() 31 | x = self.lin(x) 32 | return F.log_softmax(x, dim=-1) 33 | -------------------------------------------------------------------------------- /models/sage.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from torch import Tensor 4 | from torch_sparse import SparseTensor 5 | import torch 6 | import torch.nn.functional as F 7 | from torch_geometric.nn import SAGEConv 8 | 9 | class SAGE(torch.nn.Module): 10 | def __init__(self 11 | , in_channels 12 | , hidden_channels 13 | , out_channels 14 | , num_layers 15 | , dropout 16 | , batchnorm=True): 17 | super(SAGE, self).__init__() 18 | 19 | self.convs = torch.nn.ModuleList() 20 | self.convs.append(SAGEConv(in_channels, hidden_channels)) 21 | self.bns = torch.nn.ModuleList() 22 | self.batchnorm = batchnorm 23 | if self.batchnorm: 24 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 25 | for _ in range(num_layers - 2): 26 | self.convs.append(SAGEConv(hidden_channels, hidden_channels)) 27 | if self.batchnorm: 28 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 29 | self.convs.append(SAGEConv(hidden_channels, out_channels)) 30 | 31 | self.dropout = dropout 32 | 33 | def reset_parameters(self): 34 | for conv in self.convs: 35 | conv.reset_parameters() 36 | if self.batchnorm: 37 | for bn in self.bns: 38 | bn.reset_parameters() 39 | 40 | def forward(self, x, edge_index: Union[Tensor, SparseTensor]): 41 | for i, conv in enumerate(self.convs[:-1]): 42 | x = conv(x, edge_index) 43 | if self.batchnorm: 44 | x = self.bns[i](x) 45 | x = F.relu(x) 46 | x = F.dropout(x, p=self.dropout, training=self.training) 47 | x = self.convs[-1](x, edge_index) 48 | return x.log_softmax(dim=-1) 49 | -------------------------------------------------------------------------------- /models/sage_neighsampler.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from torch import Tensor 4 | from torch_sparse import SparseTensor 5 | import torch 6 | import torch.nn.functional as F 7 | from torch_geometric.nn import SAGEConv 8 | from tqdm import tqdm 9 | 10 | class SAGE_NeighSampler(torch.nn.Module): 11 | def __init__(self 12 | , in_channels 13 | , hidden_channels 14 | , out_channels 15 | , num_layers 16 | , dropout 17 | , batchnorm=True): 18 | super(SAGE_NeighSampler, self).__init__() 19 | 20 | self.convs = torch.nn.ModuleList() 21 | self.convs.append(SAGEConv(in_channels, hidden_channels)) 22 | self.bns = torch.nn.ModuleList() 23 | self.batchnorm = batchnorm 24 | self.num_layers = num_layers 25 | if self.batchnorm: 26 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 27 | for i in range(num_layers - 2): 28 | self.convs.append(SAGEConv(hidden_channels, hidden_channels)) 29 | if self.batchnorm: 30 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 31 | self.convs.append(SAGEConv(hidden_channels, out_channels)) 32 | 33 | self.dropout = dropout 34 | 35 | def reset_parameters(self): 36 | for conv in self.convs: 37 | conv.reset_parameters() 38 | if self.batchnorm: 39 | for bn in self.bns: 40 | bn.reset_parameters() 41 | 42 | 43 | def forward(self, x, adjs): 44 | for i, (edge_index, _, size) in enumerate(adjs): 45 | x_target = x[:size[1]] 46 | x = self.convs[i]((x, x_target), edge_index) 47 | if i != self.num_layers-1: 48 | if self.batchnorm: 49 | x = self.bns[i](x) 50 | x = F.relu(x) 51 | x = F.dropout(x, p=0.5, training=self.training) 52 | 53 | return x.log_softmax(dim=-1) 54 | 55 | ''' 56 | subgraph_loader: size = NeighborSampler(data.edge_index, node_idx=None, sizes=[-1], 57 | batch_size=**, shuffle=False, 58 | num_workers=12) 59 | You can also sample the complete k-hop neighborhood, but this is rather expensive (especially for Reddit). 60 | We apply here trick here to compute the node embeddings efficiently: 61 | Instead of sampling multiple layers for a mini-batch, we instead compute the node embeddings layer-wise. 62 | Doing this exactly k times mimics a k-layer GNN. 63 | ''' 64 | 65 | def inference_all(self, data): 66 | x, adj_t = data.x, data.adj_t 67 | for i, conv in enumerate(self.convs[:-1]): 68 | x = conv(x, adj_t) 69 | if self.batchnorm: 70 | x = self.bns[i](x) 71 | x = F.relu(x) 72 | x = F.dropout(x, p=self.dropout, training=self.training) 73 | x = self.convs[-1](x, adj_t) 74 | return x.log_softmax(dim=-1) 75 | 76 | def inference(self, x_all, layer_loader, device): 77 | pbar = tqdm(total=x_all.size(0) * self.num_layers, ncols=80) 78 | pbar.set_description('Evaluating') 79 | 80 | # Compute representations of nodes layer by layer, using *all* 81 | # available edges. This leads to faster computation in contrast to 82 | # immediately computing the final representations of each batch. 83 | for i in range(self.num_layers): 84 | xs = [] 85 | for batch_size, n_id, adj in layer_loader: 86 | edge_index, _, size = adj.to(device) 87 | x = x_all[n_id].to(device) 88 | x_target = x[:size[1]] 89 | x = self.convs[i]((x, x_target), edge_index) 90 | if i != self.num_layers - 1: 91 | x = F.relu(x) 92 | if self.batchnorm: 93 | x = self.bns[i](x) 94 | xs.append(x) 95 | 96 | pbar.update(batch_size) 97 | 98 | x_all = torch.cat(xs, dim=0) 99 | 100 | pbar.close() 101 | 102 | return x_all.log_softmax(dim=-1) 103 | -------------------------------------------------------------------------------- /run_tgat.py: -------------------------------------------------------------------------------- 1 | from analysis.load_dataset import build_tg_data 2 | import numpy as np 3 | import torch 4 | import pandas as pd 5 | 6 | import torch.nn.functional as F 7 | from sklearn import metrics 8 | from torch_geometric.nn import TransformerConv 9 | def evaluate(y_truth,y_pred): 10 | auc = metrics.roc_auc_score(y_truth, y_pred, multi_class='ovo',labels=[0,1],average='macro') 11 | ap = metrics.average_precision_score(y_truth, y_pred, average='macro', pos_label=1, sample_weight=None) 12 | return ap,auc 13 | def process_data(data,max_time_steps=32): 14 | data.edge_time = data.edge_time-data.edge_time.min() #process edge time 15 | data.edge_time = data.edge_time/data.edge_time.max() 16 | data.edge_time = (data.edge_time*max_time_steps).long() 17 | data.edge_time = data.edge_time.view(-1,1).float() 18 | 19 | 20 | edge = torch.cat([data.edge_index,data.edge_time.view(1,-1)],dim=0) #process node time 21 | degree = pd.DataFrame(edge.T.numpy()).groupby(0).min().values 22 | ids = pd.DataFrame(data.edge_index.T.numpy()).groupby(0).count().index.values 23 | key = {} 24 | for i in range(data.x.shape[0]): 25 | key[i]=0 26 | for i in range(len(ids)): 27 | key[ids[i]]=degree[i][1] 28 | node_time = np.array(list(key.values())) 29 | data.node_time=torch.tensor(node_time) 30 | 31 | # trans to undirected graph 32 | data.edge_index = torch.cat((data.edge_index,data.edge_index[[1,0],:]),dim=1) 33 | data.edge_time = torch.cat((data.edge_time,data.edge_time),dim=0) 34 | 35 | return data 36 | 37 | class TimeEncode(torch.nn.Module): 38 | # https://github.com/StatsDLMathsRecomSys/Inductive-representation-learning-on-temporal-graphs 39 | def __init__(self, expand_dim, factor=5): 40 | super(TimeEncode, self).__init__() 41 | #init_len = np.array([1e8**(i/(time_dim-1)) for i in range(time_dim)]) 42 | 43 | time_dim = expand_dim 44 | self.factor = factor 45 | self.basis_freq = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, time_dim))).float()) 46 | self.phase = torch.nn.Parameter(torch.zeros(time_dim).float()) 47 | 48 | #self.dense = torch.nn.Linear(time_dim, expand_dim, bias=False)# 49 | 50 | #torch.nn.init.xavier_normal_(self.dense.weight) 51 | 52 | def forward(self, ts): 53 | # ts: [N, L] 54 | batch_size = ts.size(0) 55 | seq_len = ts.size(1) 56 | 57 | ts = ts.view(batch_size, seq_len, 1)# [N, L, 1] 58 | map_ts = ts * self.basis_freq.view(1, 1, -1) # [N, L, time_dim] 59 | map_ts += self.phase.view(1, 1, -1) 60 | 61 | harmonic = torch.cos(map_ts) 62 | 63 | return harmonic #self.dense(harmonic) 64 | 65 | class TGAT(torch.nn.Module): 66 | def __init__(self, in_channels, out_channels): 67 | super().__init__() 68 | self.time_enc = TimeEncode(32) 69 | edge_dim =32 70 | self.lin = torch.nn.Linear(17,32) 71 | self.conv = TransformerConv(32, 32 // 2, heads=2, 72 | dropout=0.1, edge_dim=edge_dim) 73 | self.conv1 = TransformerConv(32, 32 // 2, heads=2, 74 | dropout=0.1, edge_dim=edge_dim) 75 | self.out = torch.nn.Linear(32,2) 76 | def forward(self, x, edge_index, t): 77 | rel_t = data.node_time[edge_index[0]].view(-1,1) - t 78 | rel_t_enc = self.time_enc(rel_t.to(x.dtype)) 79 | #edge_attr = torch.cat([rel_t_enc, msg], dim=-1) 80 | h1 = self.lin(x) 81 | h1 = F.relu(h1) 82 | #print(h1.shape) 83 | h1 = self.conv(h1, edge_index, rel_t_enc) 84 | #h1 = F.relu(h1) 85 | #h2 = self.conv1(h1, edge_index, rel_t_enc) 86 | out = self.out(h1) 87 | return F.log_softmax(out,dim=1) 88 | 89 | 90 | 91 | 92 | if __name__ == "__main__": 93 | datapath = './dataset/DGraphFin/raw/dgraphfin.npz' 94 | origin_data = np.load(datapath) 95 | data = build_tg_data(is_undirected=False,datapath=datapath) 96 | data = process_data(data) 97 | 98 | device = torch.device('cuda:1') 99 | #model = GCN(data.x.shape[1],2) 100 | model = TGAT(in_channels=17, out_channels=2) 101 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0e-4) 102 | data = data.to(device) 103 | model = model.to(device) 104 | lossf = torch.nn.CrossEntropyLoss()#.cuda(1) 105 | loss=None 106 | val_acc = 0 107 | test = 0 108 | weight = torch.tensor([1,50]).to(device).float() 109 | duration = 0 110 | y_valid = data.y[data.val_mask].cpu()#.numpy() 111 | for i in range(1000): 112 | model.train() 113 | optimizer.zero_grad() 114 | out = model(x = data.x,edge_index = data.edge_index,t = data.edge_time) 115 | loss = F.nll_loss(out[data.train_mask],data.y[data.train_mask],weight=weight) 116 | loss.backward() 117 | optimizer.step() 118 | 119 | model.eval() 120 | with torch.no_grad(): 121 | scores = model(x = data.x,edge_index = data.edge_index,t = data.edge_time) 122 | val_ap , val_auc = evaluate(y_valid.numpy(),scores[data.val_mask,1].cpu().numpy()) 123 | val=val_auc 124 | if(val>val_acc): 125 | y_true = data.y[data.test_mask].cpu().numpy() 126 | y_scores = scores[:,1][data.test_mask].cpu().numpy() 127 | ap,auc=evaluate(y_true,y_scores) 128 | print('best (epoch,val_ap,val_auc,test_ap,test_auc):',i,val_ap,val_auc,ap,auc) 129 | val_acc=val 130 | duration=0 -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dgraphfin import DGraphFin 2 | 3 | -------------------------------------------------------------------------------- /utils/dgraphfin.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable, List 2 | import os.path as osp 3 | 4 | import numpy as np 5 | import torch 6 | from torch_geometric.data import InMemoryDataset 7 | from torch_geometric.data import Data 8 | 9 | 10 | def read_dgraphfin(folder): 11 | print('read_dgraphfin') 12 | names = ['dgraphfin.npz'] 13 | items = [np.load(folder+'/'+name) for name in names] 14 | 15 | x = items[0]['x'] 16 | y = items[0]['y'].reshape(-1,1) 17 | edge_index = items[0]['edge_index'] 18 | edge_type = items[0]['edge_type'] 19 | train_mask = items[0]['train_mask'] 20 | valid_mask = items[0]['valid_mask'] 21 | test_mask = items[0]['test_mask'] 22 | 23 | x = torch.tensor(x, dtype=torch.float).contiguous() 24 | y = torch.tensor(y, dtype=torch.int64) 25 | edge_index_tensor = torch.tensor(edge_index).long().T 26 | edge_index = torch.tensor(edge_index.transpose(), dtype=torch.int64).contiguous() 27 | edge_type = torch.tensor(edge_type, dtype=torch.float) 28 | train_mask = torch.tensor(train_mask, dtype=torch.int64) 29 | valid_mask = torch.tensor(valid_mask, dtype=torch.int64) 30 | test_mask = torch.tensor(test_mask, dtype=torch.int64) 31 | 32 | data = Data(x=x, edge_index=edge_index_tensor, edge_attr=edge_type, y=y) 33 | data.train_mask = train_mask 34 | data.valid_mask = valid_mask 35 | data.test_mask = test_mask 36 | #data.edge_index = edge_index_tensor 37 | print(data.edge_index) 38 | return data 39 | 40 | class DGraphFin(InMemoryDataset): 41 | r""" 42 | Args: 43 | root (string): Root directory where the dataset should be saved. 44 | name (string): The name of the dataset (:obj:`"dgraphfin"`). 45 | transform (callable, optional): A function/transform that takes in an 46 | :obj:`torch_geometric.data.Data` object and returns a transformed 47 | version. The data object will be transformed before every access. 48 | (default: :obj:`None`) 49 | pre_transform (callable, optional): A function/transform that takes in 50 | an :obj:`torch_geometric.data.Data` object and returns a 51 | transformed version. The data object will be transformed before 52 | being saved to disk. (default: :obj:`None`) 53 | """ 54 | 55 | url = '' 56 | 57 | def __init__(self, root: str, name: str, 58 | transform: Optional[Callable] = None, 59 | pre_transform: Optional[Callable] = None): 60 | 61 | self.name = name 62 | super().__init__(root, transform, pre_transform) 63 | self.data, self.slices = torch.load(self.processed_paths[0]) 64 | 65 | @property 66 | def raw_dir(self) -> str: 67 | return osp.join(self.root, self.name, 'raw') 68 | 69 | @property 70 | def processed_dir(self) -> str: 71 | return osp.join(self.root, self.name, 'processed') 72 | 73 | @property 74 | def raw_file_names(self) -> List[str]: 75 | names = ['dgraphfin.npz'] 76 | return names 77 | 78 | @property 79 | def processed_file_names(self) -> str: 80 | return 'data.pt' 81 | 82 | def download(self): 83 | pass 84 | # for name in self.raw_file_names: 85 | # download_url('{}/{}'.format(self.url, name), self.raw_dir) 86 | 87 | def process(self): 88 | data = read_dgraphfin(self.raw_dir) 89 | data = data if self.pre_transform is None else self.pre_transform(data) 90 | torch.save(self.collate([data]), self.processed_paths[0]) 91 | 92 | def __repr__(self) -> str: 93 | return f'{self.name}()' -------------------------------------------------------------------------------- /utils/evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | from sklearn.metrics import roc_auc_score 5 | from sklearn.metrics import average_precision_score 6 | try: 7 | import torch 8 | except ImportError: 9 | torch = None 10 | 11 | ### Evaluator for node property prediction 12 | class Evaluator: 13 | def __init__(self, eval_metric): 14 | if eval_metric not in ['auc']: 15 | raise ValueError('eval_metric should be acc or auc') 16 | 17 | self.eval_metric = eval_metric 18 | 19 | def _check_input(self, y_true, y_pred): 20 | ''' 21 | y_true: numpy ndarray or torch tensor of shape (num_node) 22 | y_pred: numpy ndarray or torch tensor of shape (num_node, num_tasks) 23 | ''' 24 | 25 | # converting to torch.Tensor to numpy on cpu 26 | if torch is not None and isinstance(y_true, torch.Tensor): 27 | y_true = y_true.detach().cpu().numpy() 28 | 29 | if torch is not None and isinstance(y_pred, torch.Tensor): 30 | y_pred = y_pred.detach().cpu().numpy() 31 | 32 | ## check type 33 | if not (isinstance(y_true, np.ndarray) and isinstance(y_true, np.ndarray)): 34 | raise RuntimeError('Arguments to Evaluator need to be either numpy ndarray or torch tensor') 35 | 36 | if not y_pred.ndim == 2: 37 | raise RuntimeError('y_pred must to 2-dim arrray, {}-dim array given'.format(y_true.ndim)) 38 | 39 | return y_true, y_pred 40 | 41 | def eval(self, y_true, y_pred): 42 | if self.eval_metric == 'auc': 43 | y_true, y_pred = self._check_input(y_true, y_pred) 44 | return self._eval_rocauc(y_true, y_pred) 45 | 46 | def _eval_rocauc(self, y_true, y_pred): 47 | ''' 48 | compute ROC-AUC and AP score averaged across tasks 49 | ''' 50 | 51 | if y_pred.shape[1] ==2: 52 | auc = roc_auc_score(y_true, y_pred[:, 1]) 53 | ap = average_precision_score(y_true, y_pred[:, 1]) 54 | else: 55 | onehot_code = np.eye(y_pred.shape[1]) 56 | y_true_onehot = onehot_code[y_true] 57 | auc = roc_auc_score(y_true_onehot, y_pred) 58 | ap = average_precision_score(y_true_onehot, y_pred) 59 | 60 | return {'auc': auc, 'ap': ap} 61 | 62 | -------------------------------------------------------------------------------- /utils/tricks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric as tg 3 | import copy 4 | import numpy as np 5 | class Structure: 6 | def __init__(self,trick): 7 | if trick not in ['original', 'knn', 'random']: 8 | raise ValueError('trick should be original, knn, or random') 9 | self.trick = trick 10 | def process(self,data): 11 | if self.trick == 'original': 12 | return data 13 | if self.trick == 'knn': 14 | return self._KNNGraph(data) 15 | if self.trick == 'random': 16 | return self._RandomGraph(data) 17 | def _KNNGraph(self,data,num=4300999): 18 | num=4300999 19 | node_num = data.x.shape[0] 20 | L = (torch.rand(num*50)*node_num).long().view(-1,1) 21 | R = (torch.rand(num*50)*node_num).long().view(-1,1) 22 | flag = (L.view(-1)!=R.view(-1)) 23 | L = L[flag,:] 24 | R = R[flag,:] 25 | edge_index = torch.cat((L,R),dim=1).T 26 | x = data.x/data.x.norm(dim=1).view(-1,1) 27 | L = x[L.view(-1)] 28 | R = x[R.view(-1)] 29 | score = (L*R).sum(dim=1) 30 | score_a= score[score>0.9] 31 | edge_index = edge_index[:,score>0.9] 32 | score = score_a.numpy() 33 | index = np.argsort(-score) 34 | index = torch.tensor(index).long() 35 | edge_index = edge_index[:,index[:4300999]] 36 | data.edge_index = edge_index 37 | return data 38 | def _RandomGraph(self,data,num=4300999): 39 | node_num = data.x.shape[0] 40 | L = (torch.rand(num*2)*node_num).long().view(-1,1) 41 | R = (torch.rand(num*2)*node_num).long().view(-1,1) 42 | flag = (L.view(-1)!=R.view(-1)) 43 | L = L[flag,:] 44 | R = R[flag,:] 45 | L = L[:num,:] 46 | R = R[:num,:] 47 | edge_index = torch.cat((L,R),dim=1).T 48 | data.edge_index = edge_index 49 | return data 50 | 51 | class Missvalues: 52 | def __init__(self,trick): 53 | if trick not in ['null', 'default', 'trickA', 'trickB', 'trickC']: 54 | raise ValueError('trick should be null, default, trickA, trickB or trickC') 55 | self.trick = trick 56 | 57 | def process(self, data): 58 | if self.trick == 'null': 59 | return self._null(data) 60 | elif self.trick == 'default': 61 | return self._default(data) 62 | elif self.trick == 'trickA': 63 | return self._trickA(data) 64 | elif self.trick == 'trickB': 65 | return self._trickB(data) 66 | elif self.trick == 'trickC': 67 | return self._trickC(data) 68 | def _null(self, data): 69 | return data 70 | 71 | def _default(self, data): 72 | x = torch.cat([data.x,data.x],dim=1) 73 | data.x = x 74 | return data 75 | 76 | def _trickA(self, data): 77 | x = torch.cat([data.x,(data.x==-1).long()],dim=1) 78 | data.x = x 79 | return data 80 | 81 | def _trickB(self, data): 82 | x = data.x 83 | x[x==-1]+=1 84 | x = torch.cat([x,(data.x==-1).long()],dim=1) 85 | data.x = x 86 | return data 87 | 88 | def _trickC(self, data): 89 | x = data.x 90 | x[x==-1]+=1 91 | x = torch.cat([x,(data.x==-1).long()],dim=1) 92 | data.x = x 93 | return data 94 | 95 | class Background: 96 | def __init__(self, trick,): 97 | if trick not in ['null', 'remove', 'flag', 'hetro']: 98 | raise ValueError('trick should be null, remove, flag, hetro') 99 | self.trick = trick 100 | def process(self, data, ratio = 0.5): 101 | 102 | if self.trick == 'null': 103 | return self._null(data) 104 | 105 | elif self.trick == 'remove': 106 | return self._remove(data, ratio) 107 | 108 | elif self.trick == 'flag': 109 | return self._flag(data) 110 | 111 | elif self.trick == 'hetro': 112 | return self._trans2hetro(data) 113 | 114 | def _null(self, data): 115 | return data 116 | 117 | def _remove(self, data,ratio=0.5): 118 | def build_new_mask(tmask,pure_idx): 119 | mask = torch.zeros(3700550) 120 | mask[tmask.long()]=1 121 | mask = mask[pure_idx] 122 | ids = torch.arange(len(mask)) 123 | ids = ids[mask.bool()] 124 | return ids 125 | pure_idx = torch.arange(data.x.shape[0])[data.y<=1] 126 | pure_idx2 = torch.arange(data.x.shape[0])[data.y>1] 127 | randn = torch.rand(len(pure_idx2)) 128 | pure_idx2 = pure_idx2[randn<=ratio] 129 | pure_idx = torch.cat([pure_idx,pure_idx2],dim=0) 130 | print(pure_idx) 131 | print(data.edge_index) 132 | pure_edge_index = tg.utils.subgraph(pure_idx,data.edge_index,relabel_nodes=True)[0] 133 | pure_data = copy.deepcopy(data) 134 | pure_data.x = data.x[pure_idx] 135 | pure_data.edge_index = pure_edge_index 136 | pure_data.y = data.y[pure_idx] 137 | pure_data.train_mask = build_new_mask(data.train_mask,pure_idx) 138 | pure_data.valid_mask = build_new_mask(data.valid_mask,pure_idx) 139 | pure_data.test_mask = build_new_mask(data.test_mask,pure_idx) 140 | print(pure_data.train_mask.max()) 141 | return pure_data 142 | 143 | def _flag(self, data): 144 | flag=(data.y<=1).float().view(-1,1) 145 | flag1 = (data.y>1).float().view(-1,1) 146 | data.x = torch.cat([flag,flag1,data.x],dim=1) 147 | return data 148 | 149 | def _trans2hetro(self, data): 150 | l = data.y[data.edge_index[0,:]] 151 | r = data.y[data.edge_index[1,:]] 152 | edge_type = torch.zeros(data.edge_index.shape[1]) 153 | edge_type[(l<=1) & (r<=1)]=0 154 | edge_type[(l<=1) & (r>1)]=1 155 | edge_type[(l>1) & (r<=1)]=2 156 | edge_type[(l>1) & (r>1)]=3 157 | data.edge_type=edge_type.long() 158 | data.edge_type=edge_type.long() 159 | return data -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from datetime import datetime 4 | import shutil 5 | 6 | 7 | def prepare_folder(name, model_name): 8 | model_dir = f'./model_results/{name}/{model_name}' 9 | 10 | if os.path.exists(model_dir): 11 | shutil.rmtree(model_dir) 12 | os.makedirs(model_dir) 13 | return model_dir 14 | 15 | def prepare_tune_folder(name, model_name): 16 | str_time = datetime.strftime(datetime.now(), '%Y%m%d_%H%M%S') 17 | tune_model_dir = f'./tune_results/{name}/{model_name}/{str_time}/' 18 | 19 | if os.path.exists(tune_model_dir): 20 | print(f'rm tune_model_dir {tune_model_dir}') 21 | shutil.rmtree(tune_model_dir) 22 | os.makedirs(tune_model_dir) 23 | print(f'make tune_model_dir {tune_model_dir}') 24 | return tune_model_dir 25 | 26 | def save_preds_and_params(parameters, preds, model, file): 27 | save_dict = {'parameters':parameters, 'preds': preds, 'params': model.state_dict() 28 | , 'nparams': sum(p.numel() for p in model.parameters())} 29 | torch.save(save_dict, file) 30 | return 31 | 32 | 33 | 34 | 35 | --------------------------------------------------------------------------------