├── data
├── data.zip
├── config_graph_sim.ini
├── config_graph_sim_nokb.ini
├── config_graph_sim_nogcn.ini
├── README.md
├── requirements.txt
├── model.py
├── tensorboard_logger.py
├── Radm.py
├── model_batch.py
├── graph_sim_no_gcn_dej_X.py
├── graph_sim_dej_X.py
└── graph_sim_no_kb_dej_X.py
/data:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/data.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daixixiwang/KGroot/HEAD/data.zip
--------------------------------------------------------------------------------
/config_graph_sim.ini:
--------------------------------------------------------------------------------
1 | [model]
2 | # Maximum number of nodes per graph; affects the pooling kernel and the padding of input data
3 | max_node_num=30
4 | # Feature dimension of each event's w2v embedding fed to the model
5 | input_dim=100
6 | # Dimension of each node vector output by the GCN
7 | gcn_hidden_dim=50
8 | # Number of hidden units in the MLP
9 | linear_hidden_dim=10
10 | 
11 | num_bases=-1
12 | # Fraction of units dropped (not updated) during training
13 | dropout=0.
14 | # Number of adjacency matrices fed to the model; fixed at 3 in this setting, do not change
15 | support=3
16 | pool_step=10
17 | 
18 | [data]
19 | DATASET = 1
20 | # When resplit is set, the dataset is re-split
21 | resplit = True
22 | resplit_each_time = False
23 | batch_size = 100
24 | # Used when resplit is true and negatives outnumber positives: 1 repeats positive samples adaptively, values > 1 repeat them that many times
25 | repeat_pos_data = 1
26 | # Dataset version; "raw" means not yet optimized
27 | dataset_version=final_same
28 | 
29 | [train]
30 | NB_EPOCH = 2000
31 | LR = 0.001
32 | l2norm = 0.
33 | # Deprecated
34 | # cross_weight=0.06
35 | # Comment appended to the run directory name under runs/
36 | comment=D2_n30step10datasetraw
37 | ;[print_logging]
38 | ;level = "error"
--------------------------------------------------------------------------------
/config_graph_sim_nokb.ini:
--------------------------------------------------------------------------------
1 | [model]
2 | # Maximum number of nodes per graph; affects the pooling kernel and the padding of input data
3 | max_node_num=100
4 | # Feature dimension of each event's w2v embedding fed to the model
5 | input_dim=100
6 | # Dimension of each node vector output by the GCN
7 | gcn_hidden_dim=64
8 | # Number of hidden units in the MLP
9 | linear_hidden_dim=32
10 | 
11 | num_bases=-1
12 | # Fraction of units dropped (not updated) during training
13 | dropout=0.1
14 | # Number of adjacency matrices fed to the model; fixed at 3 in this setting, do not change
15 | support=3
16 | pool_step=20
17 | 
18 | [data]
19 | DATASET = 0
20 | # When resplit is set, the dataset is re-split
21 | resplit = False
22 | resplit_each_time = False
23 | # batch_size should preferably be 16 or 32
24 | batch_size = 16
25 | # Used when resplit is true and negatives outnumber positives: 1 repeats positive samples adaptively, values > 1 repeat them that many times
26 | repeat_pos_data = 1
27 | # Dataset version; "raw" means not yet optimized
28 | dataset_version=final_same
29 | 
30 | [train]
31 | NB_EPOCH = 2000
32 | LR = 1e-3
33 | # L2 regularization; generally unused here. To counter overfitting, prefer batch norm or dropout
34 | l2norm = 0.
35 | # Deprecated
36 | # cross_weight=0.06
37 | # Comment appended to the run directory name under runs/
38 | ;comment=D2_n30step10datasetrawspliteveryerrortype
39 | comment=D1_n100step20Nokbdatasetfinal_samebatch16drop01r-3split631
40 | ;[print_logging]
41 | ;level = "error"
--------------------------------------------------------------------------------
/config_graph_sim_nogcn.ini:
--------------------------------------------------------------------------------
1 | [model]
2 | # Maximum number of nodes per graph; affects the pooling kernel and the padding of input data
3 | max_node_num=100
4 | # Feature dimension of each event's w2v embedding fed to the model
5 | input_dim=100
6 | # Dimension of each node vector output by the GCN
7 | gcn_hidden_dim=50
8 | # Number of hidden units in the MLP
9 | linear_hidden_dim=10
10 | 
11 | num_bases=-1
12 | # Fraction of units dropped (not updated) during training
13 | dropout=0.1
14 | # Number of adjacency matrices fed to the model; fixed at 3 in this setting, do not change
15 | support=3
16 | pool_step=20
17 | 
18 | [data]
19 | DATASET = 1
20 | # When resplit is set, the dataset is re-split
21 | resplit = False
22 | resplit_each_time = False
23 | # batch_size should preferably be 16 or 32
24 | batch_size = 16
25 | # Used when resplit is true and negatives outnumber positives: 1 repeats positive samples adaptively, values > 1 repeat them that many times
26 | repeat_pos_data = 1
27 | # Dataset version; "raw" means not yet optimized
28 | dataset_version=final_same
29 | 
30 | [train]
31 | NB_EPOCH = 2000
32 | LR = 1e-3
33 | # L2 regularization; generally unused here. To counter overfitting, prefer batch norm or dropout
34 | l2norm = 0.
35 | # Deprecated
36 | # cross_weight=0.06
37 | # Comment appended to the run directory name under runs/
38 | ;comment=D2_n30step10datasetrawspliteveryerrortype
39 | comment=D1_n100step20Nogcndatasetfinal_samebatch16drop01r-3split631
40 | ;[print_logging]
41 | ;level = "error"
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 1 Dataset
2 | Raw data download: dataset C and dataset D are available at https://www.dropbox.com/sh/ist4ojr03e2oeuw/AAD5NkpAFg1nOI2Ttug3h2qja?dl=0
3 | Initial events: the raw data sets are converted into initial events, dataset C --> data/events_initial_A, dataset D --> data/events_initial_B
4 | 
5 | 2 Data processing
6 | 
7 | step1: cal_faults.py reads the raw faults.csv from dataset C and dataset D and constructs offline_data_set_info.json and online_data_set_info.json
8 | 
9 | step2: runtime environment configuration, log4j
10 | 
11 | step3: use the method proposed in this paper to preprocess the event data for training & testing
12 | 
13 | step4: DataSetGraphSimGenerator.py splits the data into train & test sets
14 | 
15 | 
16 | 3 Method
17 | 
18 | #-----------fault KG construction----------
19 | 
20 | KBConstruction.py
21 | 
22 | #---------------KGroot--------------------
23 | 
24 | graph_sim_dej_X.py #X represents Dataset A / Dataset B.
25 | 26 | 27 | #-------KGroot without GCN-------------- 28 | 29 | graph_sim_no_gcn_dej_X.py 30 | 31 | 32 | #---------KGroot without KG-------------- 33 | 34 | graph_sim_no_kb_dej_X.py 35 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | alabaster==0.7.12 3 | aliyun-log-cli==0.1.17 4 | aliyun-log-python-sdk==0.6.47.7 5 | aliyun-python-sdk-cms==7.0.4 6 | aliyun-python-sdk-core==2.13.5 7 | aliyun-python-sdk-core-v3==2.13.3 8 | aliyun-python-sdk-ecs==4.16.11 9 | aliyun-python-sdk-rds==2.3.9 10 | aliyun-python-sdk-slb==3.2.10 11 | aliyun-python-sdk-vpc==3.0.5 12 | anaconda-client==1.7.2 13 | anaconda-navigator==1.9.7 14 | anaconda-project==0.8.3 15 | aniso8601==8.0.0 16 | asn1crypto==0.24.0 17 | astor==0.8.0 18 | astroid==2.2.5 19 | astropy==3.2.1 20 | atomicwrites==1.3.0 21 | attrs==19.1.0 22 | Babel==2.7.0 23 | backcall==0.1.0 24 | backports.functools-lru-cache==1.5 25 | backports.os==0.1.1 26 | backports.shutil-get-terminal-size==1.0.0 27 | backports.tempfile==1.0 28 | backports.weakref==1.0.post1 29 | beautifulsoup4==4.7.1 30 | bitarray==0.9.3 31 | bkcharts==0.2 32 | bleach==3.1.0 33 | blinker==1.4 34 | bokeh==1.2.0 35 | boto==2.49.0 36 | boto3==1.12.12 37 | botocore==1.15.12 38 | Bottleneck==1.2.1 39 | bz2file==0.98 40 | cachetools==3.1.1 41 | certifi==2020.4.5.1 42 | cffi==1.12.3 43 | chardet==3.0.4 44 | Click==7.0 45 | cloudpickle==1.2.1 46 | clyent==1.2.2 47 | colorama==0.4.1 48 | coloredlogs==10.0 49 | comtypes==1.1.7 50 | conda==4.8.3 51 | conda-build==3.18.8 52 | conda-package-handling==1.3.11 53 | conda-verify==3.4.2 54 | contextlib2==0.5.5 55 | cryptography==2.7 56 | cvxopt==1.2.0 57 | cycler==0.10.0 58 | Cython==0.29.12 59 | cytoolz==0.10.0 60 | dask==2.1.0 61 | dateparser==0.7.2 62 | decorator==4.4.0 63 | defusedxml==0.6.0 64 | distributed==2.1.0 65 | docopt==0.6.2 66 | docutils==0.14 67 | dtw==1.3.3 68 | elasticsearch==6.4.0 69 | entrypoints==0.3 70 | et-xmlfile==1.0.1 71 | fastcache==1.1.0 72 | filelock==3.0.12 73 | Flask==1.1.1 74 | Flask-RESTful==0.3.7 75 | fp-growth==0.1.3 76 | future==0.17.1 77 | gast==0.2.2 78 | gensim==3.8.0 79 | gevent==1.4.0 80 | glob2==0.7 81 | google-auth==1.13.1 82 | google-auth-oauthlib==0.4.1 83 | google-pasta==0.2.0 84 | grakel==0.1.8 85 | graphviz==0.13 86 | greenlet==0.4.15 87 | grpcio==1.27.2 88 | h5py==2.9.0 89 | heapdict==1.0.0 90 | html5lib==1.0.1 91 | humanfriendly==4.18 92 | idna==2.8 93 | imageio==2.5.0 94 | imagesize==1.1.0 95 | importlib-metadata==0.17 96 | ipykernel==5.1.1 97 | ipython==7.6.1 98 | ipython-genutils==0.2.0 99 | ipywidgets==7.5.0 100 | isort==4.3.21 101 | itsdangerous==1.1.0 102 | jdcal==1.4.1 103 | jedi==0.13.3 104 | jieba==0.42.1 105 | Jinja2==2.10.1 106 | jmespath==0.9.4 107 | joblib==0.13.2 108 | JPype1==0.7.0 109 | json5==0.8.4 110 | jsonschema==3.0.1 111 | jupyter==1.0.0 112 | jupyter-client==5.3.1 113 | jupyter-console==6.0.0 114 | jupyter-core==4.5.0 115 | jupyterlab==1.0.2 116 | jupyterlab-server==1.0.0 117 | Keras-Applications==1.0.8 118 | Keras-Preprocessing==1.1.0 119 | keyring==18.0.0 120 | kiwisolver==1.1.0 121 | lazy-object-proxy==1.4.1 122 | libarchive-c==2.8 123 | llvmlite==0.29.0 124 | locket==0.2.0 125 | lxml==4.3.4 126 | Markdown==3.1.1 127 | MarkupSafe==1.1.1 128 | matplotlib==3.1.0 129 | mccabe==0.6.1 130 | menuinst==1.4.16 131 | mistune==0.8.4 132 | mkl-fft==1.0.12 133 | mkl-random==1.0.2 134 | mkl-service==2.0.2 135 | mock==3.0.5 136 | 
more-itertools==7.0.0 137 | MouseInfo==0.1.2 138 | mpmath==1.1.0 139 | msgpack==0.6.1 140 | multipledispatch==0.6.0 141 | navigator-updater==0.2.1 142 | nbconvert==5.5.0 143 | nbformat==4.4.0 144 | neo4j==1.7.4 145 | neobolt==1.7.13 146 | neotime==1.7.4 147 | networkx==2.3 148 | nltk==3.4.4 149 | nose==1.3.7 150 | notebook==6.0.0 151 | numba==0.44.1 152 | numexpr==2.6.9 153 | numpy==1.20.1 154 | numpydoc==0.9.1 155 | oauthlib==3.1.0 156 | olefile==0.46 157 | openpyxl==2.6.2 158 | opt-einsum==3.1.0 159 | packaging==19.0 160 | pandas==0.24.2 161 | pandocfilters==1.4.2 162 | parso==0.5.0 163 | partd==1.0.0 164 | path.py==12.0.1 165 | pathlib2==2.3.4 166 | patsy==0.5.1 167 | pep8==1.7.1 168 | pgmpy==0.1.7 169 | pickleshare==0.7.5 170 | Pillow==6.1.0 171 | pkgconfig==1.5.2 172 | pkginfo==1.5.0.1 173 | plotly==4.5.2 174 | pluggy==0.12.0 175 | ply==3.11 176 | prettytable==0.7.2 177 | prometheus-client==0.7.1 178 | prompt-toolkit==2.0.9 179 | protobuf==3.11.4 180 | psutil==5.6.3 181 | py==1.8.0 182 | py2neo==4.3.0 183 | pyasn1==0.4.8 184 | pyasn1-modules==0.2.7 185 | PyAutoGUI==0.9.48 186 | pycodestyle==2.5.0 187 | pycosat==0.6.3 188 | pycparser==2.19 189 | pycrypto==2.6.1 190 | pycurl==7.43.0.3 191 | pyflakes==2.1.1 192 | PyGetWindow==0.0.8 193 | Pygments==2.6.1 194 | PyJWT==1.7.1 195 | pylint==2.3.1 196 | PyMsgBox==1.0.7 197 | pyodbc==4.0.26 198 | pyOpenSSL==19.0.0 199 | pyparsing==2.4.0 200 | pyperclip==1.7.0 201 | pyreadline==2.1 202 | PyRect==0.1.4 203 | pyrsistent==0.14.11 204 | PyScreeze==0.1.26 205 | PySocks==1.7.0 206 | pytest==5.0.1 207 | pytest-arraydiff==0.3 208 | pytest-astropy==0.5.0 209 | pytest-doctestplus==0.3.0 210 | pytest-openfiles==0.3.2 211 | pytest-remotedata==0.3.1 212 | python-dateutil==2.8.0 213 | python-Levenshtein==0.12.2 214 | PyTweening==1.0.3 215 | pytz==2019.1 216 | PyWavelets==1.0.3 217 | pywin32==223 218 | pywinpty==0.5.5 219 | PyYAML==5.1.1 220 | pyzmq==18.0.0 221 | QtAwesome==0.5.7 222 | qtconsole==4.5.1 223 | QtPy==1.8.0 224 | regex==2019.8.19 225 | requests==2.22.0 226 | requests-oauthlib==1.3.0 227 | retrying==1.3.3 228 | rope==0.14.0 229 | rsa==4.0 230 | ruamel-yaml==0.15.46 231 | s3transfer==0.3.3 232 | scikit-image==0.15.0 233 | scikit-learn==0.21.2 234 | scipy==1.4.1 235 | seaborn==0.9.0 236 | selenium==3.141.0 237 | Send2Trash==1.5.0 238 | simplegeneric==0.8.1 239 | singledispatch==3.4.0.3 240 | six==1.12.0 241 | smart-open==1.9.0 242 | snowballstemmer==1.9.0 243 | sortedcollections==1.1.2 244 | sortedcontainers==2.1.0 245 | soupsieve==1.8 246 | Sphinx==2.1.2 247 | sphinxcontrib-applehelp==1.0.1 248 | sphinxcontrib-devhelp==1.0.1 249 | sphinxcontrib-htmlhelp==1.0.2 250 | sphinxcontrib-jsmath==1.0.1 251 | sphinxcontrib-qthelp==1.0.2 252 | sphinxcontrib-serializinghtml==1.1.3 253 | sphinxcontrib-websupport==1.1.2 254 | spyder==3.3.6 255 | spyder-kernels==0.5.1 256 | SQLAlchemy==1.3.5 257 | statsmodels==0.10.0 258 | style==1.1.0 259 | sympy==1.4 260 | tables==3.5.2 261 | tblib==1.4.0 262 | tensorboard==2.1.0 263 | tensorboardX==2.0 264 | tensorflow==2.1.0 265 | tensorflow-estimator==2.1.0 266 | termcolor==1.1.0 267 | terminado==0.8.2 268 | testpath==0.4.2 269 | toolz==0.10.0 270 | torch==1.4.0 271 | torchvision==0.5.0 272 | tornado==6.0.3 273 | tqdm==4.32.1 274 | traitlets==4.3.2 275 | tzlocal==2.0.0 276 | unicodecsv==0.14.1 277 | update==0.0.1 278 | urllib3==1.24.2 279 | uuid==1.30 280 | wcwidth==0.1.7 281 | webencodings==0.5.1 282 | Werkzeug==0.15.4 283 | widgetsnbextension==3.5.0 284 | win-inet-pton==1.1.0 285 | win-unicode-console==0.5 286 | 
wincertstore==0.2 287 | word2vec==0.9.4+2.g8204e5c 288 | wrapt==1.11.2 289 | xlrd==1.2.0 290 | XlsxWriter==1.1.8 291 | xlwings==0.15.8 292 | xlwt==1.3.0 293 | zict==1.0.0 294 | zipp==0.5.1 295 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from scipy import sparse 7 | 8 | 9 | device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 10 | 11 | 12 | # device = torch.device("cpu") 13 | 14 | class GraphConvolution(nn.Module): 15 | def __init__(self, input_dim, output_dim, support=1, featureless=True, 16 | init='glorot_uniform', activation='linear', 17 | weights=None, W_regularizer=None, num_bases=-1, 18 | b_regularizer=None, bias=False, dropout=0., **kwargs): 19 | """ 20 | 21 | :param input_dim: 输入维度,对应A[0]的行数,即图的结点数 22 | :param output_dim: 超参数,隐藏层的units数目 23 | :param support: A 的长度 感觉应该是 边种类数 * 2 + 1 24 | :param featureless: 使用或者忽略输入的features 25 | :param init: 26 | :param activation: 激活函数 27 | :param weights: 28 | :param W_regularizer: 29 | :param num_bases: 使用的bases数量 (-1:all) 30 | :param b_regularizer: 31 | :param bias: 32 | :param dropout: 33 | :param kwargs: 34 | """ 35 | super(GraphConvolution, self).__init__() 36 | # self.init = initializations.get(init) 37 | # self.activation = activations.get(activation) 38 | if activation == "relu": 39 | self.activation = nn.ReLU() 40 | elif activation == "softmax": 41 | self.activation = nn.Softmax(dim=-1) 42 | else: 43 | self.activation = F.ReLU() 44 | self.input_dim = input_dim 45 | self.output_dim = output_dim # number of features per node 46 | self.support = support # filter support / number of weights 47 | self.featureless = featureless # use/ignore input features 48 | self.dropout = dropout 49 | self.w_regularizer = nn.L1Loss() 50 | 51 | assert support >= 1 52 | 53 | # TODO 54 | 55 | self.bias = bias 56 | self.initial_weights = weights 57 | self.num_bases = num_bases 58 | 59 | # these will be defined during build() 60 | # self.input_dim = None 61 | if self.num_bases > 0: 62 | # 使用的bases数大于0 63 | self.W = nn.Parameter( 64 | torch.empty(self.input_dim * self.num_bases, self.output_dim, dtype=torch.float32, device=device)) 65 | self.W_comp = nn.Parameter(torch.empty(self.support, self.num_bases, dtype=torch.float32, device=device)) 66 | nn.init.xavier_uniform_(self.W_comp) # 通过网络层时,输入和输出的方差相同 67 | else: 68 | self.W = nn.Parameter( 69 | torch.empty(self.input_dim * self.support, self.output_dim, dtype=torch.float32, device=device)) 70 | nn.init.xavier_uniform_(self.W) 71 | 72 | if self.bias: 73 | self.b = nn.Parameter(torch.empty(self.output_dim, dtype=torch.float32, device=device)) 74 | nn.init.xavier_uniform_(self.b) 75 | """ 76 | Dropout就是在不同的训练过程中随机扔掉一部分神经元。也就是让某个神经元的激活值以一定的概率p,让其停止工作, 77 | 这次训练过程中不更新权值,也不参加神经网络的计算。但是它的权重得保留下来(只是暂时不更新而已), 78 | 因为下次样本输入时它可能又得工作了 79 | """ 80 | self.dropout = nn.Dropout(dropout) 81 | 82 | def get_output_shape_for(self, input_shapes): 83 | features_shape = input_shapes[0] 84 | output_shape = (features_shape[0], self.output_dim) 85 | return output_shape # (batch_size, output_dim) 86 | 87 | def forward(self, inputs, mask=None): 88 | # inputs是 [x] + A; x是一个A[0]shape的稀疏矩阵,是一个特征矩阵 89 | features = torch.tensor(inputs[0], dtype=torch.float32, device=device) 90 | A = inputs[1:] # list of basis functions 91 | # 92 | # 寻找矩阵所有不为0的值 sparse.find(a)[-1] 93 | # a.nonzero() 返回 每个非0元素 
所在行和列 94 | A = [torch.sparse.FloatTensor(torch.LongTensor(a.nonzero()) 95 | , torch.FloatTensor(sparse.find(a)[-1]) 96 | , torch.Size(a.shape)).to(device) 97 | if len(sparse.find(a)[-1]) > 0 else torch.sparse.FloatTensor(a.shape[0], a.shape[1]) 98 | for a in A] 99 | # convolve 100 | if not self.featureless: 101 | # 使用特征矩阵x featureless = False 102 | supports = list() 103 | for i in range(self.support): 104 | # 稀疏矩阵相乘 最终维度是 【(结点数, 特征数),,,,】 105 | # 相当于 依据连接关系,将结点周围结点的特征相加最为其特征 106 | # logging.info("A[i]:{} feature:{}".format(A[i].shape, features.shape)) 107 | supports.append(torch.spmm(A[i], features)) 108 | # 按照维度1 拼接 即成为了一个大矩阵 (结点数, self.support * 特证数) 特证数是 feature(x)的列数,默认为结点数 109 | supports = torch.cat(supports, dim=1) 110 | else: 111 | # 不适用特征x featureless = True 112 | values = torch.cat([i._values() for i in A], dim=-1) 113 | temp_list = list() 114 | for i, j in enumerate(A): 115 | j_index = j._indices() 116 | j_index_0 = j._indices()[0] 117 | j_index_re = j._indices()[0].reshape(1, -1) 118 | 119 | tt = [j._indices()[0].reshape(1, -1), 120 | (j._indices()[1] + (i * self.input_dim)).reshape(1, -1)] 121 | temp1 = torch.cat(tt) 122 | temp_list.append(temp1) 123 | indices = torch.cat(temp_list, dim=-1) 124 | # indices = torch.cat([torch.cat([j._indices()[0].reshape(1,-1), 125 | # (j._indices()[1] + (i*self.input_dim)).reshape(1,-1)]) 126 | # for i, j in enumerate(A)], dim=-1) 127 | print("featureless:{} indices:{} values:{}".format(self.featureless, indices.shape, values.shape)) 128 | # 没有特征的输入,就将A[:]拼接 为 (结点数, A长度*结点数)的矩阵 129 | supports = torch.sparse.FloatTensor(indices, values, torch.Size([A[0].shape[0], 130 | len(A) * self.input_dim])) 131 | self.num_nodes = supports.shape[0] 132 | if self.num_bases > 0: 133 | 134 | V = torch.matmul(self.W_comp, 135 | self.W.reshape(self.num_bases, self.input_dim, self.output_dim).permute(1, 0, 2)) 136 | V = torch.reshape(V, (self.support * self.input_dim, self.output_dim)) 137 | output = torch.spmm(supports, V) 138 | else: 139 | # (结点数, A长度* 特征数) * (输入维度*support, output_dim) 140 | output = torch.spmm(supports, self.W) 141 | 142 | # if featureless add dropout to output, by elementwise matmultiplying with column vector of ones, 143 | # with dropout applied to the vector of ones. 
144 | if self.featureless: 145 | tmp = torch.ones(self.num_nodes, device=device) 146 | tmp_do = self.dropout(tmp) 147 | output = (output.transpose(1, 0) * tmp_do).transpose(1, 0) 148 | 149 | if self.bias: 150 | output += self.b 151 | return self.activation(output) 152 | 153 | 154 | class GraphSimilarity(nn.Module): 155 | 156 | def __init__(self, input_dim, gcn_hidden_dim, linear_hidden_dim, num_bases, dropout, support=3, max_node_num=100): 157 | """ 158 | 计算两个图的相似度,即二分类,输出为(2, ) 159 | :param input_dim: 初始每个结点的特征维度 160 | :param gcn_hidden_dim: gcn输出的每个结点的特征维度 161 | :param linear_hidden_dim: 多层感知机,中间隐层的神经单元数量 162 | :param num_bases: 默认为-1 163 | :param dropout: 舍弃dropout百分比的参数,不进行更新 164 | :param support: 采用多少个邻接矩阵,因为本项目中 因果事件图,只有3个邻接表 165 | :param max_node_num: 每个图的结点数目。不够的会进行padding 166 | """ 167 | super(GraphSimilarity, self).__init__() 168 | 169 | # input_dim 需要与 feature列数一致 170 | self.gcn_online = GraphConvolution(input_dim, gcn_hidden_dim, num_bases=num_bases, activation="relu", 171 | featureless=False, support=support) 172 | self.gcn_kb = GraphConvolution(input_dim, gcn_hidden_dim, num_bases=num_bases, activation="relu", 173 | featureless=False, support=support) 174 | kernel_size = (max_node_num, 1) 175 | self.pool_1 = torch.nn.MaxPool2d(kernel_size, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=False) 176 | self.pool_2 = torch.nn.MaxPool2d(kernel_size, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=False) 177 | 178 | # 池化层输出的 向量 拼接后,为(gcn_hidden_dim*2, ) 179 | self.linear_1 = torch.nn.Linear(in_features=gcn_hidden_dim*2, out_features=linear_hidden_dim, bias=True) 180 | self.linear_2 = torch.nn.Linear(in_features=linear_hidden_dim, out_features=2, bias=True) 181 | 182 | self.dropout = nn.Dropout(dropout) 183 | 184 | def forward(self, graph_1, graph_2, mask=None): 185 | # 两个图分别输入 RGCN 中,并最大池化 186 | gcn_info_1 = self.gcn_online(graph_1, mask=mask) 187 | gcn_info_1_drop = self.dropout(gcn_info_1) 188 | gcn_info_1_un = torch.unsqueeze(gcn_info_1_drop, 0) 189 | graph_info_1_pool = self.pool_2(gcn_info_1_un) 190 | graph_info_1_sq = torch.squeeze(graph_info_1_pool, 0) 191 | logging.info("gcn_info_1{} gcn_info_1_drop{} gcn_info_1_un{} graph_info_1_pool{} graph_info_1_sq{}".format(gcn_info_1.shape, gcn_info_1_drop.shape, gcn_info_1_un.shape, 192 | graph_info_1_pool.shape, graph_info_1_sq.shape)) 193 | 194 | # graph_info_1 = self.pool_1(self.dropout(gcn_info_1)) 195 | gcn_info_2 = self.gcn_kb(graph_2, mask=mask) 196 | gcn_info_2_drop = self.dropout(gcn_info_2) 197 | gcn_info_2_un = torch.unsqueeze(gcn_info_2_drop, 0) 198 | graph_info_2_pool = self.pool_2(gcn_info_2_un) 199 | graph_info_2_sq = torch.squeeze(graph_info_2_pool, 0) 200 | logging.info("gcn_info_2:{} gcn_info_2_drop:{} gcn_info_2_un:{} graph_info_2_pool:{} graph_info_2_sq:{}".format( 201 | gcn_info_2.shape, gcn_info_2_drop.shape, gcn_info_2_un.shape, 202 | graph_info_2_pool.shape, graph_info_2_sq.shape)) 203 | # concat 输入多层感知机 204 | 205 | cat_info = torch.cat([graph_info_1_sq, graph_info_2_sq], dim=1) 206 | 207 | cat_info = self.linear_1(cat_info) 208 | cat_info = F.relu(cat_info) 209 | 210 | output = self.linear_2(cat_info) 211 | output = F.relu(output) 212 | # print(output.shape) 213 | output = torch.transpose(output, 0, 1) # Shape: torch.Size([2, 71]) 214 | 215 | # Add batch dimension 216 | # output = output.unsqueeze(0) # Shape: torch.Size([1, 2, 71]) 217 | logging.info("output:{}".format(output.shape)) 218 | # 输出维度为 (2,)tensor 219 | return output 
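GraphSimilarity above consumes a pair of graphs, each passed as a Python list [features] + A: features is a dense node-feature matrix with max_node_num rows and input_dim columns, and A holds support scipy sparse adjacency matrices, one per relation type. The sketch below is a minimal smoke test of that interface, assuming a CPU-only environment with the versions pinned in requirements.txt; the random_graph helper and the toy sizes (which mirror config_graph_sim.ini) are illustrative only and are not part of the repository.

import numpy as np
import torch
from scipy import sparse

from model import GraphSimilarity

max_node_num, input_dim, support = 30, 100, 3

def random_graph():
    # Node features, padded to max_node_num rows so that the
    # (max_node_num, 1) max-pooling kernel in GraphSimilarity fits.
    features = np.random.rand(max_node_num, input_dim).astype(np.float32)
    # One sparse adjacency matrix per relation type (random ~5% density here).
    adjs = [sparse.random(max_node_num, max_node_num, density=0.05, format="csr")
            for _ in range(support)]
    return [features] + adjs

model = GraphSimilarity(input_dim=input_dim, gcn_hidden_dim=50,
                        linear_hidden_dim=10, num_bases=-1, dropout=0.0,
                        support=support, max_node_num=max_node_num)

with torch.no_grad():
    logits = model(random_graph(), random_graph())

# Two-class (similar / not similar) scores for the single graph pair.
print(logits.shape)

Note that the batched variant in model_batch.py expects a different input layout (a stacked dense feature tensor plus a stacked adjacency tensor), so this sketch applies only to the non-batched module in model.py.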
-------------------------------------------------------------------------------- /tensorboard_logger.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 2 | from tensorboardX import SummaryWriter 3 | from torch.utils.tensorboard import SummaryWriter 4 | 5 | 6 | try: 7 | from StringIO import StringIO # Python 2.7 8 | except ImportError: 9 | from io import BytesIO # Python 3.x 10 | 11 | 12 | """ 13 | # tensorboard 在 tensorflow2.1里的 生成图有问题,需要trace计算图,但不知道怎么设置pytorch model 被trace 14 | log_dir = os.path.join(os.path.dirname(__file__), 'logs/%s' % datetime.now().strftime("%Y%m%d-%H%M%S")) 15 | log_dir = log_dir.replace("\\", os.sep) 16 | log_dir = log_dir.replace("/", os.sep) 17 | 18 | tb_logger = TensorBoardLogger(log_dir, 19 | trace_on=True) 20 | tb_logger.print_tensoroard_logs(model=model, step=epoch, loss=loss_all.item(), accuracy=accuary_all_num.item()/preds_all_num.item()) 21 | logging.error("epoch:{} loss:{} accuracy:{}/{}={}".format(epoch, loss_all, accuary_all_num, preds_all_num, 22 | int(accuary_all_num)/int(preds_all_num))) 23 | tb_logger.write_flush() 24 | 25 | """ 26 | # class TensorBoardLogger(object): 27 | # # https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/04-utils/tensorboard/logger.py 28 | # # https://tensorflow.google.cn/api_docs/python/tf/summary?hl=zh-CN 29 | # # tensorbard --logdir='./logs' 30 | # # https://www.cnblogs.com/rainydayfmb/p/7944224.html 31 | # def __init__(self, log_dir, trace_on=True): 32 | # """Create a summary writer logging to log_dir.""" 33 | # self.log_dir = log_dir 34 | # self.writer = tf.summary.create_file_writer(log_dir) 35 | # if trace_on: 36 | # self.trace_on() 37 | # 38 | # def scalar_summary(self, tag, value, step): 39 | # """Log a scalar variable.""" 40 | # with self.writer.as_default(): 41 | # tf.summary.scalar(name=tag, data=value, step=step) 42 | # 43 | # 44 | # def image_summary(self, tag, images, step): 45 | # """Log a list of images.""" 46 | # 47 | # img_summaries = [] 48 | # for i, img in enumerate(images): 49 | # # Write the image to a string 50 | # try: 51 | # s = StringIO() 52 | # except: 53 | # s = BytesIO() 54 | # scipy.misc.toimage(img).save(s, format="png") 55 | # 56 | # # Create an Image object 57 | # img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 58 | # height=img.shape[0], 59 | # width=img.shape[1]) 60 | # # Create a Summary value 61 | # img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 62 | # 63 | # # Create and write Summary 64 | # summary = tf.Summary(value=img_summaries) 65 | # self.writer.add_summary(summary, step) 66 | # 67 | # def histo_summary(self, tag, values, step): 68 | # with self.writer.as_default(): 69 | # tf.summary.histogram(name=tag, data=values, step=step) 70 | # 71 | # 72 | # def trace_on(self): 73 | # tf.summary.trace_on( 74 | # graph=True, profiler=True 75 | # ) 76 | # 77 | # def trace_export(self, tag, step, log_dir): 78 | # with self.writer.as_default(): 79 | # tf.summary.trace_export( 80 | # # tag, step=step, profiler_outdir=os.path.join(os.path.dirname(__file__), self.log_dir) 81 | # tag, step=step, profiler_outdir=log_dir 82 | # ) 83 | # 84 | # def write_flush(self): 85 | # self.writer.flush() 86 | # 87 | # def print_tensoroard_logs(self, model, step, loss, accuracy): 88 | # if step == 1: 89 | # self.trace_export(tag="graph", step=step, log_dir=self.log_dir) 90 | # 91 | # def to_np(x): 92 | # return x.cpu().data.numpy() 93 
| # 94 | # info = { 95 | # 'loss': loss, 96 | # 'accuracy': accuracy 97 | # } 98 | # 99 | # for tag, value in info.items(): 100 | # self.scalar_summary(tag, value, step) 101 | # 102 | # # (2) Log values and gradients of the parameters (histogram) 103 | # for tag, value in model.named_parameters(): 104 | # tag = tag.replace('.', '/') 105 | # self.histo_summary(tag, to_np(value), step) 106 | # self.histo_summary(tag + '/grad', to_np(value.grad), step) 107 | # 108 | # # (3) Log the images 109 | # # info = { 110 | # # 'images': to_np(img.view(-1, 28, 28)[:10]) 111 | # # } 112 | # # 113 | # # for tag, images in info.items(): 114 | # # logger.image_summary(tag, images, step) 115 | 116 | # 从tensorboardx中引入的api 117 | class TensorBoardWritter(): 118 | # https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/04-utils/tensorboard/logger.py 119 | # https://tensorflow.google.cn/api_docs/python/tf/summary?hl=zh-CN 120 | # tensorbard --logdir='./logs' 121 | # https://www.cnblogs.com/rainydayfmb/p/7944224.html 122 | def __init__(self, log_dir=None, comment=''): 123 | """Create a summary writer logging to log_dir. comment会赋在文件夹名后""" 124 | self.log_dir = log_dir 125 | if log_dir: 126 | self.writer = SummaryWriter(log_dir=log_dir, comment=comment) 127 | else: 128 | self.writer = SummaryWriter(comment=comment) 129 | 130 | def print_tensoroard_logs(self, model, info_dict): 131 | sample_data = info_dict["sample_data"] 132 | step = info_dict["step"] 133 | loss = info_dict["loss"] 134 | loss_val = info_dict["loss_val"] 135 | loss_test = info_dict["loss_test"] 136 | 137 | accuracy = info_dict["accuracy"] 138 | accuary_val = info_dict["accuary_val"] 139 | accuary_test = info_dict["accuary_test"] 140 | recall_train = info_dict['recall_train'] 141 | recall_val = info_dict['recall_val'] 142 | recall_test = info_dict['recall_test'] 143 | precision_train = info_dict['precision_train'] 144 | precision_val = info_dict['precision_val'] 145 | precision_test = info_dict['precision_test'] 146 | F1_train = info_dict['F1_train'] 147 | F1_val = info_dict['F1_val'] 148 | F1_test = info_dict['F1_test'] 149 | outputs_all = info_dict["outputs_all"] 150 | 151 | train_pos_neg = info_dict["train_pos_neg"] 152 | val_pos_neg = info_dict["val_pos_neg"] 153 | test_pos_neg = info_dict["test_pos_neg"] 154 | cross_weight_auto = info_dict["cross_weight_auto"] 155 | class_ac_train1, class_ac_train3, class_ac_train2 = info_dict["class_ac_train"] 156 | class_ac_val1, class_ac_val3, class_ac_val2 = info_dict["class_ac_val"] 157 | class_ac_test1, class_ac_test3, class_ac_test2 = info_dict["class_ac_test"] 158 | 159 | if step == 0: 160 | self.writer.add_graph(model, sample_data) 161 | self.writer.flush() 162 | 163 | def to_np(x): 164 | return x.cpu().data.numpy() 165 | 166 | info = { 167 | 'loss/train': loss, 168 | 'loss/test': loss_test, 169 | 'loss/val': loss_val, 170 | 'accuracy/train': accuracy, 171 | "accuracy/val": accuary_val, 172 | "accuracy/test":accuary_test, 173 | "class_ac/class_ac1_train1":class_ac_train1, 174 | "class_ac/class_ac1_val":class_ac_val1, 175 | "class_ac/class_ac1_test":class_ac_test1, 176 | "class_ac/class_ac3_train": class_ac_train3, 177 | "class_ac/class_ac3_val": class_ac_val3, 178 | "class_ac/class_ac3_test": class_ac_test3, 179 | "class_ac/class_ac2_train": class_ac_train2, 180 | "class_ac/class_ac2_val": class_ac_val2, 181 | "class_ac/class_ac2_test": class_ac_test2, 182 | 183 | "assess/precision/train": precision_train, 184 | "assess/precision/val": precision_val, 185 | "assess/precision/test": 
precision_test, 186 | "assess/recall/train": recall_train, 187 | "assess/recall/val": recall_val, 188 | "assess/recall/test": recall_test, 189 | "assess/F1/train": F1_train, 190 | "assess/F1/val": F1_val, 191 | "assess/F1/test": F1_test, 192 | 193 | } 194 | 195 | for tag, value in info.items(): 196 | self.writer.add_scalar(tag, value, step) 197 | 198 | # (2) Log values and gradients of the parameters (histogram) 199 | for tag, value in model.named_parameters(): 200 | tag = tag.replace('.', '/') 201 | self.writer.add_histogram(tag, to_np(value), step) 202 | self.writer.add_histogram(tag + '/grad', to_np(value.grad), step) 203 | 204 | self.writer.add_histogram("outputs", to_np(outputs_all), step) 205 | self.writer.add_scalar("accuracy_base/train", train_pos_neg.max(), step) 206 | self.writer.add_scalar("accuracy_base/test", test_pos_neg.max(), step) 207 | self.writer.add_scalar("accuracy_base/val", val_pos_neg.max(), step) 208 | self.writer.add_scalar("accuracy_base/cross_weight_auto_0", cross_weight_auto[0], step) 209 | self.writer.add_scalar("accuracy_base/cross_weight_auto_1", cross_weight_auto[1], step) 210 | 211 | 212 | # (3) Log the images 213 | # info = { 214 | # 'images': to_np(img.view(-1, 28, 28)[:10]) 215 | # } 216 | # 217 | # for tag, images in info.items(): 218 | # logger.image_summary(tag, images, step) 219 | 220 | 221 | if __name__ == '__main__': 222 | """ 223 | 启动 tensorboard时 -log_dir=logs -log_dir=".//logs" 224 | 保存graph时 需要路径分割符号为 "\" 225 | """ 226 | pass 227 | # # writer = tf.summary.create_file_writer("./logs") 228 | # # 229 | # # 230 | # # @tf.function 231 | # # def my_func(step): 232 | # # with writer.as_default(): 233 | # # # other model code would go here 234 | # # tf.summary.scalar("my_metric", 0.5, step=step) 235 | # # 236 | # # 237 | # # for step in tf.range(100, dtype=tf.int64): 238 | # # my_func(step) 239 | # # writer.flush() 240 | # # The function to be traced. 241 | # @tf.function 242 | # def my_func(x, y): 243 | # # A simple hand-rolled layer. 244 | # return tf.nn.relu(tf.matmul(x, y)) 245 | # 246 | # 247 | # # Set up logging. 248 | # stamp = datetime.now().strftime("%Y%m%d-%H%M%S") 249 | # logdir = 'logs\%s' % stamp 250 | # writer = tf.summary.create_file_writer(logdir) 251 | # 252 | # # Sample data for your function. 253 | # x = tf.random.uniform((3, 3)) 254 | # y = tf.random.uniform((3, 3)) 255 | # 256 | # # Bracket the function call with 257 | # # tf.summary.trace_on() and tf.summary.trace_export(). 258 | # tf.summary.trace_on(graph=True, profiler=True) 259 | # # Call only one tf.function when tracing. 
260 | # z = my_func(x, y) 261 | # with writer.as_default(): 262 | # tf.summary.trace_export( 263 | # name="my_func_trace", 264 | # step=0, 265 | # profiler_outdir=logdir) -------------------------------------------------------------------------------- /Radm.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | 6 | class RAdam(Optimizer): 7 | 8 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 9 | if not 0.0 <= lr: 10 | raise ValueError("Invalid learning rate: {}".format(lr)) 11 | if not 0.0 <= eps: 12 | raise ValueError("Invalid epsilon value: {}".format(eps)) 13 | if not 0.0 <= betas[0] < 1.0: 14 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 15 | if not 0.0 <= betas[1] < 1.0: 16 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 17 | 18 | self.degenerated_to_sgd = degenerated_to_sgd 19 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 20 | for param in params: 21 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 22 | param['buffer'] = [[None, None, None] for _ in range(10)] 23 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 24 | buffer=[[None, None, None] for _ in range(10)]) 25 | super(RAdam, self).__init__(params, defaults) 26 | 27 | def __setstate__(self, state): 28 | super(RAdam, self).__setstate__(state) 29 | 30 | def step(self, closure=None): 31 | 32 | loss = None 33 | if closure is not None: 34 | loss = closure() 35 | 36 | for group in self.param_groups: 37 | 38 | for p in group['params']: 39 | if p.grad is None: 40 | continue 41 | grad = p.grad.data.float() 42 | if grad.is_sparse: 43 | raise RuntimeError('RAdam does not support sparse gradients') 44 | 45 | p_data_fp32 = p.data.float() 46 | 47 | state = self.state[p] 48 | 49 | if len(state) == 0: 50 | state['step'] = 0 51 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 52 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 53 | else: 54 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 55 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 56 | 57 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 58 | beta1, beta2 = group['betas'] 59 | 60 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 61 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 62 | 63 | state['step'] += 1 64 | buffered = group['buffer'][int(state['step'] % 10)] 65 | if state['step'] == buffered[0]: 66 | N_sma, step_size = buffered[1], buffered[2] 67 | else: 68 | buffered[0] = state['step'] 69 | beta2_t = beta2 ** state['step'] 70 | N_sma_max = 2 / (1 - beta2) - 1 71 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 72 | buffered[1] = N_sma 73 | 74 | # more conservative since it's an approximated value 75 | if N_sma >= 5: 76 | step_size = math.sqrt( 77 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 78 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 79 | elif self.degenerated_to_sgd: 80 | step_size = 1.0 / (1 - beta1 ** state['step']) 81 | else: 82 | step_size = -1 83 | buffered[2] = step_size 84 | 85 | # more conservative since it's an approximated value 86 | if N_sma >= 5: 87 | if group['weight_decay'] != 0: 88 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 89 | denom = 
exp_avg_sq.sqrt().add_(group['eps']) 90 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 91 | p.data.copy_(p_data_fp32) 92 | elif step_size > 0: 93 | if group['weight_decay'] != 0: 94 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 95 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 96 | p.data.copy_(p_data_fp32) 97 | 98 | return loss 99 | 100 | 101 | class PlainRAdam(Optimizer): 102 | 103 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 104 | if not 0.0 <= lr: 105 | raise ValueError("Invalid learning rate: {}".format(lr)) 106 | if not 0.0 <= eps: 107 | raise ValueError("Invalid epsilon value: {}".format(eps)) 108 | if not 0.0 <= betas[0] < 1.0: 109 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 110 | if not 0.0 <= betas[1] < 1.0: 111 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 112 | 113 | self.degenerated_to_sgd = degenerated_to_sgd 114 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 115 | 116 | super(PlainRAdam, self).__init__(params, defaults) 117 | 118 | def __setstate__(self, state): 119 | super(PlainRAdam, self).__setstate__(state) 120 | 121 | def step(self, closure=None): 122 | 123 | loss = None 124 | if closure is not None: 125 | loss = closure() 126 | 127 | for group in self.param_groups: 128 | 129 | for p in group['params']: 130 | if p.grad is None: 131 | continue 132 | grad = p.grad.data.float() 133 | if grad.is_sparse: 134 | raise RuntimeError('RAdam does not support sparse gradients') 135 | 136 | p_data_fp32 = p.data.float() 137 | 138 | state = self.state[p] 139 | 140 | if len(state) == 0: 141 | state['step'] = 0 142 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 143 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 144 | else: 145 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 146 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 147 | 148 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 149 | beta1, beta2 = group['betas'] 150 | 151 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 152 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 153 | 154 | state['step'] += 1 155 | beta2_t = beta2 ** state['step'] 156 | N_sma_max = 2 / (1 - beta2) - 1 157 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 158 | 159 | # more conservative since it's an approximated value 160 | if N_sma >= 5: 161 | if group['weight_decay'] != 0: 162 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 163 | step_size = group['lr'] * math.sqrt( 164 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 165 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 166 | denom = exp_avg_sq.sqrt().add_(group['eps']) 167 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 168 | p.data.copy_(p_data_fp32) 169 | elif self.degenerated_to_sgd: 170 | if group['weight_decay'] != 0: 171 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 172 | step_size = group['lr'] / (1 - beta1 ** state['step']) 173 | p_data_fp32.add_(-step_size, exp_avg) 174 | p.data.copy_(p_data_fp32) 175 | 176 | return loss 177 | 178 | 179 | class AdamW(Optimizer): 180 | 181 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0): 182 | if not 0.0 <= lr: 183 | raise ValueError("Invalid learning rate: {}".format(lr)) 184 | if not 0.0 <= eps: 185 | raise ValueError("Invalid epsilon value: {}".format(eps)) 
186 | if not 0.0 <= betas[0] < 1.0: 187 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 188 | if not 0.0 <= betas[1] < 1.0: 189 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 190 | 191 | defaults = dict(lr=lr, betas=betas, eps=eps, 192 | weight_decay=weight_decay, warmup=warmup) 193 | super(AdamW, self).__init__(params, defaults) 194 | 195 | def __setstate__(self, state): 196 | super(AdamW, self).__setstate__(state) 197 | 198 | def step(self, closure=None): 199 | loss = None 200 | if closure is not None: 201 | loss = closure() 202 | 203 | for group in self.param_groups: 204 | 205 | for p in group['params']: 206 | if p.grad is None: 207 | continue 208 | grad = p.grad.data.float() 209 | if grad.is_sparse: 210 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 211 | 212 | p_data_fp32 = p.data.float() 213 | 214 | state = self.state[p] 215 | 216 | if len(state) == 0: 217 | state['step'] = 0 218 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 219 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 220 | else: 221 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 222 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 223 | 224 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 225 | beta1, beta2 = group['betas'] 226 | 227 | state['step'] += 1 228 | 229 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 230 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 231 | 232 | denom = exp_avg_sq.sqrt().add_(group['eps']) 233 | bias_correction1 = 1 - beta1 ** state['step'] 234 | bias_correction2 = 1 - beta2 ** state['step'] 235 | 236 | if group['warmup'] > state['step']: 237 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 238 | else: 239 | scheduled_lr = group['lr'] 240 | 241 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 242 | 243 | if group['weight_decay'] != 0: 244 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 245 | 246 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 247 | 248 | p.data.copy_(p_data_fp32) 249 | 250 | return loss -------------------------------------------------------------------------------- /model_batch.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from scipy import sparse 7 | import numpy as np 8 | 9 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 10 | 11 | 12 | # device = torch.device("cpu") 13 | 14 | class GraphConvolution(nn.Module): 15 | def __init__(self, input_dim, output_dim, support=1, featureless=True, 16 | init='glorot_uniform', activation='linear', 17 | weights=None, W_regularizer=None, num_bases=-1, 18 | b_regularizer=None, bias=False, dropout=0., max_node_num=100, **kwargs) : 19 | """ 20 | 21 | :param input_dim: 输入维度,有特征矩阵时对应初始特征矩阵的词嵌入维度,无特征矩阵时,对应输入结点个数 22 | :param output_dim: 超参数,隐藏层的units数目 23 | :param support: A 的长度 感觉应该是 边种类数 * 2 + 1 24 | :param featureless: 使用或者忽略输入的features 25 | :param init: 26 | :param activation: 激活函数 27 | :param weights: 28 | :param W_regularizer: 29 | :param num_bases: 使用的bases数量 (-1:all) 30 | :param b_regularizer: 31 | :param bias: 32 | :param dropout: 33 | :param kwargs: 34 | """ 35 | super(GraphConvolution, self).__init__() 36 | # self.init = initializations.get(init) 37 | # self.activation = activations.get(activation) 38 | if activation == "relu": 39 | 
self.activation = nn.ReLU() 40 | elif activation == "softmax": 41 | self.activation = nn.Softmax(dim=-1) 42 | else: 43 | self.activation = F.ReLU() 44 | self.input_dim = input_dim 45 | self.output_dim = output_dim # number of features per node 46 | self.support = support # filter support / number of weights 47 | self.featureless = featureless # use/ignore input features 48 | self.dropout = dropout 49 | self.w_regularizer = nn.L1Loss() 50 | 51 | assert support >= 1 52 | 53 | # TODO 54 | 55 | self.bias = bias 56 | self.initial_weights = weights 57 | self.num_bases = num_bases 58 | 59 | # these will be defined during build() 60 | # self.input_dim = None 61 | if self.num_bases > 0: 62 | # 使用的bases数大于0 63 | self.W = nn.Parameter( 64 | torch.empty(self.input_dim * self.num_bases, self.output_dim, dtype=torch.float32, device=device)) 65 | self.W_comp = nn.Parameter(torch.empty(self.support, self.num_bases, dtype=torch.float32, device=device)) 66 | # nn.init.xavier_uniform_(self.W_comp) # 通过网络层时,输入和输出的方差相同 67 | torch.nn.init.kaiming_normal_(self.W_comp, a=0, mode='fan_in', nonlinearity='leaky_relu') 68 | else: 69 | self.W = nn.Parameter( 70 | torch.empty(self.input_dim * self.support, self.output_dim, dtype=torch.float32, device=device)) 71 | # nn.init.xavier_uniform_(self.W) 72 | torch.nn.init.kaiming_normal_(self.W, a=0, mode='fan_in', nonlinearity='leaky_relu') 73 | 74 | if self.bias: 75 | self.b = nn.Parameter(torch.empty(max_node_num, self.output_dim, dtype=torch.float32, device=device)) 76 | # nn.init.xavier_uniform_(self.b) 77 | torch.nn.init.kaiming_normal_(self.b, a=0, mode='fan_in', nonlinearity='leaky_relu') 78 | """ 79 | Dropout就是在不同的训练过程中随机扔掉一部分神经元。也就是让某个神经元的激活值以一定的概率p,让其停止工作, 80 | 这次训练过程中不更新权值,也不参加神经网络的计算。但是它的权重得保留下来(只是暂时不更新而已), 81 | 因为下次样本输入时它可能又得工作了 82 | """ 83 | self.dropout = nn.Dropout(dropout) 84 | 85 | def get_output_shape_for(self, input_shapes): 86 | features_shape = input_shapes[0] 87 | output_shape = (features_shape[0], self.output_dim) 88 | return output_shape # (batch_size, output_dim) 89 | 90 | def forward(self, inputs, mask=None): 91 | """ 92 | 输入多张图 93 | :param inputs: () 94 | :param mask: 95 | :return: 96 | """ 97 | features, A_list = inputs[0], inputs[1] 98 | batch_size = features.shape[0] 99 | node_num = features.shape[1] 100 | feature_dim = features.shape[2] 101 | # 相当于 cir 进行归一化 102 | A_list = F.normalize(A_list, p=1, dim=3) 103 | if not self.featureless: 104 | A_list = A_list.narrow(dim=1, start=0, length=self.support) 105 | features = features.unsqueeze(dim=1) 106 | supports = torch.matmul(A_list, features) 107 | supports = supports.transpose(1, 2) 108 | supports = supports.reshape([batch_size, 1, node_num, self.support*feature_dim]) 109 | else: 110 | supports = A_list.transpose(1, 2) 111 | supports = supports.reshape([batch_size, 1, node_num, A_list.shape[1] * node_num]) 112 | supports = supports.squeeze(dim=1) 113 | 114 | if self.num_bases > 0: 115 | 116 | V = torch.matmul(self.W_comp, 117 | self.W.reshape(self.num_bases, self.input_dim, self.output_dim).permute(1, 0, 2)) 118 | V = torch.reshape(V, (self.support * self.input_dim, self.output_dim)) 119 | output = torch.matmul(supports, V) 120 | else: 121 | # (结点数, A长度* 特征数) * (输入维度*support, output_dim) 122 | output = torch.matmul(supports, self.W) 123 | 124 | # if featureless add dropout to output, by elementwise matmultiplying with column vector of ones, 125 | # with dropout applied to the vector of ones. 
126 | if self.featureless: 127 | tmp = torch.ones(self.num_nodes, device=device) 128 | tmp_do = self.dropout(tmp) 129 | tmp_do_stack = torch.stack([tmp_do for i in range(output.shape[0])], dim=0) 130 | output = (output.transpose(1, 0) * tmp_do_stack).transpose(1, 0) 131 | 132 | if self.bias: 133 | output += self.b 134 | return output 135 | 136 | 137 | 138 | 139 | class GraphSimilarity(nn.Module): 140 | 141 | def __init__(self, input_dim, gcn_hidden_dim, linear_hidden_dim, pool_step, num_bases, dropout, support=3, max_node_num=100): 142 | """ 143 | 计算两个图的相似度,即二分类,输出为(2, ) 144 | :param input_dim: 初始每个结点的特征维度 145 | :param gcn_hidden_dim: gcn输出的每个结点的特征维度 146 | :param linear_hidden_dim: 多层感知机,中间隐层的神经单元数量 147 | :param num_bases: 默认为-1 148 | :param dropout: 舍弃dropout百分比的参数,不进行更新 149 | :param support: 采用多少个邻接矩阵,因为本项目中 因果事件图,只有3个邻接表 150 | :param max_node_num: 每个图的结点数目。不够的会进行padding 151 | """ 152 | super(GraphSimilarity, self).__init__() 153 | 154 | # input_dim 需要与 feature列数一致 155 | self.gcn_online = GraphConvolution(input_dim, gcn_hidden_dim, num_bases=num_bases, activation="relu", 156 | featureless=False, support=support, bias=True, max_node_num=max_node_num) 157 | self.gcn_kb = GraphConvolution(input_dim, gcn_hidden_dim, num_bases=num_bases, activation="relu", 158 | featureless=False, support=support, bias=True, max_node_num=max_node_num) 159 | pool_step = pool_step 160 | kernel_size = (max_node_num//pool_step, 1) 161 | self.pool_1 = torch.nn.MaxPool2d(kernel_size, stride=(max_node_num//pool_step, 1), padding=0, dilation=1, return_indices=False, ceil_mode=False) 162 | self.pool_2 = torch.nn.MaxPool2d(kernel_size, stride=(max_node_num//pool_step, 1), padding=0, dilation=1, return_indices=False, ceil_mode=False) 163 | 164 | # 池化层输出的 向量 拼接后,为(gcn_hidden_dim*2, ) 165 | self.linear_1 = torch.nn.Linear(in_features=gcn_hidden_dim, out_features=linear_hidden_dim, bias=True) 166 | self.linear_2 = torch.nn.Linear(in_features=linear_hidden_dim, out_features=2, bias=True) 167 | 168 | self.door = torch.nn.Parameter(torch.empty(1, pool_step, dtype=torch.float32, device=device)) 169 | # nn.init.xavier_uniform_(self.door) # 通过网络层时,输入和输出的方差相同 170 | torch.nn.init.kaiming_normal_(self.door, a=0, mode='fan_in', nonlinearity='leaky_relu') 171 | 172 | self.node_choose_w = torch.nn.Parameter( 173 | torch.empty(1, pool_step, dtype=torch.float32, device=device)) 174 | # nn.init.xavier_uniform_(self.node_choose_w) # 通过网络层时,输入和输出的方差相同 175 | torch.nn.init.kaiming_normal_(self.node_choose_w, a=0, mode='fan_in', nonlinearity='leaky_relu') 176 | 177 | self.dropout = nn.Dropout(dropout) 178 | self.activation = F.relu 179 | 180 | def forward(self, graphs_1, graphs_2, mask=None): 181 | # 两个图分别输入 RGCN 中,并最大池化 182 | gcn_info_1 = self.gcn_online(graphs_1, mask=mask) 183 | gcn_info_1_drop = self.dropout(gcn_info_1) 184 | # gcn_info_1_drop = gcn_info_1 185 | gcn_info_1_ac = self.activation(gcn_info_1_drop) 186 | graph_info_1_pool = self.pool_1(gcn_info_1_ac) 187 | # logging.info("gcn_info_1{} gcn_info_1_drop{} graph_info_1_pool{} ".format( 188 | # np.array(gcn_info_1.shape), np.array(gcn_info_1_drop.shape), np.array(graph_info_1_pool.shape))) 189 | 190 | gcn_info_2 = self.gcn_kb(graphs_2, mask=mask) 191 | gcn_info_2_drop = self.dropout(gcn_info_2) 192 | # gcn_info_2_drop = gcn_info_2 193 | gcn_info_2_ac = self.activation(gcn_info_2_drop) 194 | graph_info_2_pool = self.pool_2(gcn_info_2_ac) 195 | # logging.info("gcn_info_2:{} gcn_info_2_drop:{} graph_info_2_pool:{} ".format( 196 | # np.array(gcn_info_2.shape), 
np.array(gcn_info_2_drop.shape), np.array(graph_info_2_pool.shape))) 197 | 198 | # concat 输入多层感知机 199 | # 方式1: 前后相连 200 | # cat_info = torch.cat([graph_info_1_pool, graph_info_2_pool], dim=2) 201 | # 方式2: 门控制 202 | cat_info = (self.door * graph_info_1_pool.transpose(1, 2) + (1.0-self.door) * graph_info_2_pool.transpose(1, 2) 203 | ).transpose(1, 2) 204 | 205 | # 方式3:zhuan两维以上向量 206 | # pass 207 | 208 | cat_info = torch.matmul(self.node_choose_w, cat_info).squeeze(dim=1) 209 | # cat_info = cat_info.squeeze(dim=1) 210 | 211 | cat_info = self.linear_1(cat_info) 212 | cat_info = self.activation(cat_info) 213 | 214 | output = self.linear_2(cat_info) 215 | # output = torch.nn.functional.softmax(output, dim=1) 216 | output = self.activation(output) 217 | # logging.info("output:{}".format(output.shape)) 218 | # 输出维度为 (2,)tensor 219 | # assert int(output.shape[0]) == int(graphs_1[0].shape[0]) == int(graphs_2[0].shape[0]) 220 | return output 221 | 222 | 223 | class GraphSimilarity_No_Gcn(nn.Module): 224 | """去除gcn""" 225 | def __init__(self, input_dim, gcn_hidden_dim, linear_hidden_dim, pool_step, num_bases, dropout, support=3, max_node_num=100): 226 | super(GraphSimilarity_No_Gcn, self).__init__() 227 | 228 | # input_dim 需要与 feature列数一致 229 | 230 | pool_step = pool_step 231 | kernel_size = (max_node_num // pool_step, 1) 232 | self.pool_1 = torch.nn.MaxPool2d(kernel_size, stride=(max_node_num // pool_step, 1), padding=0, dilation=1, 233 | return_indices=False, ceil_mode=False) 234 | self.pool_2 = torch.nn.MaxPool2d(kernel_size, stride=(max_node_num // pool_step, 1), padding=0, dilation=1, 235 | return_indices=False, ceil_mode=False) 236 | 237 | # 池化层输出的 向量 拼接后,为(gcn_hidden_dim*2, ) 238 | self.linear_1 = torch.nn.Linear(in_features=input_dim, out_features=linear_hidden_dim, bias=True) 239 | self.linear_2 = torch.nn.Linear(in_features=linear_hidden_dim, out_features=2, bias=True) 240 | 241 | self.door = torch.nn.Parameter(torch.empty(1, pool_step, dtype=torch.float32, device=device)) 242 | # nn.init.xavier_uniform_(self.door) # 通过网络层时,输入和输出的方差相同 243 | torch.nn.init.kaiming_normal_(self.door, a=0, mode='fan_in', nonlinearity='leaky_relu') 244 | 245 | self.node_choose_w = torch.nn.Parameter( 246 | torch.empty(1, pool_step, dtype=torch.float32, device=device)) 247 | # nn.init.xavier_uniform_(self.node_choose_w) # 通过网络层时,输入和输出的方差相同 248 | torch.nn.init.kaiming_normal_(self.node_choose_w, a=0, mode='fan_in', nonlinearity='leaky_relu') 249 | 250 | self.dropout = nn.Dropout(dropout) 251 | self.activation = F.relu 252 | pass 253 | 254 | def forward(self, graphs_1, graphs_2, mask=None): 255 | # 两个图分别输入 RGCN 中,并最大池化 256 | gcn_info_1 = graphs_1[0] 257 | gcn_info_2 = graphs_2[0] 258 | 259 | graph_info_1_pool = self.pool_1(gcn_info_1) 260 | graph_info_2_pool = self.pool_2(gcn_info_2) 261 | 262 | cat_info = (self.door * graph_info_1_pool.transpose(1, 2) + (1.0 - self.door) * graph_info_2_pool.transpose(1, 263 | 2) 264 | ).transpose(1, 2) 265 | 266 | cat_info = torch.matmul(self.node_choose_w, cat_info).squeeze(dim=1) 267 | # cat_info = cat_info.squeeze(dim=1) 268 | 269 | cat_info = self.linear_1(cat_info) 270 | cat_info = self.activation(cat_info) 271 | 272 | output = self.linear_2(cat_info) 273 | # output = torch.nn.functional.softmax(output, dim=1) 274 | output = self.activation(output) 275 | # logging.info("output:{}".format(output.shape)) 276 | # 输出维度为 (2,)tensor 277 | # assert int(output.shape[0]) == int(graphs_1[0].shape[0]) == int(graphs_2[0].shape[0]) 278 | return output 279 | pass 280 | 281 | 282 | class 
GraphSimilarity_No_KB(nn.Module): 283 | """去除KB 直接输入图之后 进行分类""" 284 | def __init__(self, input_dim, gcn_hidden_dim, linear_hidden_dim, out_dim, pool_step, num_bases, dropout, support=3, max_node_num=100): 285 | super(GraphSimilarity_No_KB, self).__init__() 286 | 287 | # input_dim 需要与 feature列数一致 288 | self.gcn_online = GraphConvolution(input_dim, gcn_hidden_dim, num_bases=num_bases, activation="relu", 289 | featureless=False, support=support, bias=True, max_node_num=max_node_num) 290 | 291 | pool_step = pool_step 292 | kernel_size = (max_node_num//pool_step, 1) 293 | self.pool_1 = torch.nn.MaxPool2d(kernel_size, stride=(max_node_num//pool_step, 1), padding=0, dilation=1, return_indices=False, ceil_mode=False) 294 | 295 | # 池化层输出的 向量 拼接后,为(gcn_hidden_dim*2, ) 296 | self.linear_1 = torch.nn.Linear(in_features=gcn_hidden_dim, out_features=linear_hidden_dim, bias=True) 297 | self.linear_2 = torch.nn.Linear(in_features=linear_hidden_dim, out_features=out_dim, bias=True) 298 | 299 | self.node_choose_w = torch.nn.Parameter( 300 | torch.empty(1, pool_step, dtype=torch.float32, device=device)) 301 | # nn.init.xavier_uniform_(self.node_choose_w) # 通过网络层时,输入和输出的方差相同 302 | torch.nn.init.kaiming_normal_(self.node_choose_w, a=0, mode='fan_in', nonlinearity='leaky_relu') 303 | 304 | self.dropout = nn.Dropout(dropout) 305 | self.activation = F.relu 306 | 307 | 308 | def forward(self, graphs_1, graphs_2, mask=None): 309 | gcn_info_1 = self.gcn_online(graphs_1, mask=mask) 310 | gcn_info_1_drop = self.dropout(gcn_info_1) 311 | gcn_info_1_ac = self.activation(gcn_info_1_drop) 312 | graph_info_1_pool = self.pool_1(gcn_info_1_ac) 313 | 314 | cat_info = torch.matmul(self.node_choose_w, graph_info_1_pool).squeeze(dim=1) 315 | 316 | cat_info = self.linear_1(cat_info) 317 | cat_info = self.activation(cat_info) 318 | 319 | output = self.linear_2(cat_info) 320 | # output = torch.nn.functional.softmax(output, dim=1) 321 | output = self.activation(output) 322 | return output 323 | 324 | 325 | if __name__ == '__main__': 326 | a = [ 327 | [[[1,1],[2,2]], 328 | [[3,3],[4,4]], 329 | [[5,5],[6,6]]], 330 | 331 | [[[7,7],[8,8]], 332 | [[9,9],[9,9]], 333 | [[10,10],[10,10]] 334 | ] 335 | ] 336 | b = [ 337 | [[1,1,1], 338 | [2,2,2]], 339 | [[3,3,3], 340 | [4,4,4]] 341 | ] 342 | 343 | c = [[1,1], 344 | [2,2], 345 | [3,3]] 346 | 347 | 348 | 349 | 350 | tensor_a = torch.as_tensor(a, dtype=torch.float32) 351 | # logging.info("narrow:{} after:\n{}".format(tensor_a, tensor_a.narrow( dim=1, start=0, length=1))) 352 | tensor_b = torch.as_tensor(b, dtype=torch.float32) 353 | tensor_c = torch.as_tensor(c, dtype=torch.float32) 354 | # tensor_b = tensor_b.unsqueeze(dim=1) 355 | tensor_mal = torch.matmul(tensor_b, tensor_c) 356 | print(tensor_a.shape) 357 | print(tensor_b.shape) 358 | print(tensor_c.shape) 359 | print(tensor_mal.shape) 360 | # logging.info("tensor_a:{} ".format(tensor_a)) 361 | print(tensor_b) 362 | print(tensor_c) 363 | print("tensor_mal:{}".format(tensor_mal)) 364 | 365 | tmp = torch.ones(2) 366 | drop = nn.Dropout(0.5) 367 | tmp_do = drop(tmp) 368 | print("tmp_do:{}".format(tmp_do)) 369 | tmp_do_stack = torch.stack([tmp_do for i in range(tensor_mal.shape[0])], dim=0) 370 | output1 = (tensor_mal[0].transpose(1, 0) * tmp_do).transpose(1, 0) 371 | tensor_mal_trans = tensor_mal.transpose(1,2) 372 | output2 = (tensor_mal_trans * tmp_do_stack).transpose(1, 2) 373 | print("output1:{}".format(output1)) 374 | print("output2:{}".format(output2)) 375 | 376 | bias = torch.as_tensor([1,2,3], dtype=torch.float32) 377 | 
print("b:{}".format(tensor_b)) 378 | print("bias:{}".format(bias)) 379 | print("sum:{}".format(tensor_b + bias)) 380 | 381 | b_2 = [ 382 | [[1,1,1], 383 | [2,2,2], 384 | [1,1,1], 385 | [2,2,2]], 386 | [[3,3,3], 387 | [4,4,4], 388 | [3,3,3], 389 | [4,4,4]] 390 | ] 391 | tensor_b_2 = torch.as_tensor(b_2, device=device, dtype=torch.float32) 392 | torch.rand((4,), dtype=torch.float32, device=device) 393 | tensor_b_2.mul() 394 | 395 | pool = torch.nn.MaxPool2d((2,1), stride=(2,1), padding=0, dilation=1, return_indices=False, 396 | ceil_mode=False) 397 | pool_r = pool.forward(tensor_b_2) 398 | print("tensor_b_2:{}".format(tensor_b_2)) 399 | print("pool_r:{}".format(pool_r)) 400 | rr = torch.cat([pool_r, pool_r], dim=2) 401 | cat = rr.squeeze(dim=1) 402 | print("cat:{}".format(cat)) 403 | linear = torch.nn.Linear(in_features= 6, out_features=2, bias=True) 404 | pre = linear.forward(cat) 405 | print("pre:{}".format(pre)) 406 | 407 | linear2 = torch.nn.Linear(in_features=3, out_features=2, bias=True) 408 | batch_2 = linear2.forward(tensor_b) 409 | print("batch_2:{}".format(batch_2)) 410 | 411 | # y = tensor_mal.transpose(1,2) 412 | # print(y) 413 | # y = y.reshape([2, 1, 2, 9]) 414 | # print(y) 415 | 416 | # 归一化 417 | m = [ 418 | [[[0,1],[1,1]], 419 | [[2,3],[2,4]], 420 | [[5,5],[4,6]]], 421 | 422 | [[[7,7],[0,8]], 423 | [[9,3],[9,9]], 424 | [[1,10],[10,10]] 425 | ] 426 | ] 427 | tensor_m = torch.as_tensor(m, dtype=torch.float32) 428 | m_nor = F.normalize(tensor_m, p=1, dim=3) 429 | print("normal:{}".format(m_nor)) -------------------------------------------------------------------------------- /graph_sim_no_gcn_dej_X.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import json 3 | import os 4 | import sys 5 | 6 | from collections import defaultdict 7 | from functools import partial 8 | from typing import Set, List, Any, Optional 9 | 10 | from datetime import datetime 11 | import socket 12 | import time 13 | import shutil 14 | import numpy as np 15 | import torch.optim as optim 16 | from torch.utils.data import DataLoader 17 | 18 | sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) 19 | from Radm import RAdam 20 | 21 | sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) 22 | import pandas as pd 23 | from DataSetGraphSimGenerator import DataSetGraphSimGenerator, CustomDataset 24 | from tensorboard_logger import TensorBoardWritter 25 | from model_batch import * 26 | 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | 29 | def get_rank(y_true: Set[Any], y_pred: List[Any], max_rank: Optional[int] = None) -> List[float]: 30 | rank_dict = defaultdict(lambda: len(y_pred) + 1 if max_rank is None else (max_rank + len(y_pred)) / 2) 31 | for idx, item in enumerate(y_pred, start=1): 32 | if item in y_true: 33 | rank_dict[item] = idx 34 | return [rank_dict[_] for _ in y_true] 35 | 36 | # noinspection PyPep8Naming 37 | def MAR(y_true: List[Set[Any]], y_pred: List[List[Any]], max_rank: Optional[int] = None): 38 | return np.mean([ 39 | np.mean(get_rank(a, b, max_rank)) 40 | for a, b in zip(y_true, y_pred) 41 | ]) 42 | class ModelInferenceNoGcn: 43 | def __init__(self): 44 | # 初始化配置 45 | logging.error("device:{}".format(device)) 46 | torch.set_printoptions(linewidth=120) 47 | torch.set_grad_enabled(True) 48 | np.random.seed(5) 49 | torch.manual_seed(0) 50 | 51 | # 超参数配置文件 52 | self.config = configparser.ConfigParser() 53 | self.config_file_path = os.path.join(os.path.dirname(__file__), 
"config_graph_sim_nogcn.ini") 54 | self.config.read(self.config_file_path, encoding='utf-8') 55 | 56 | # 超参数 57 | self.__load_super_paras() 58 | self.cross_weight_auto = None 59 | 60 | # 模型 61 | self.model = None 62 | self.model_saved_path = None 63 | self.model_saved_dir = None 64 | 65 | # tensor board类 66 | self.tb_comment = self.data_set_name 67 | self.tb_logger = None 68 | 69 | pass 70 | 71 | def __load_super_paras(self): 72 | self.data_set_id = self.config.getint("data", "DATASET") 73 | self.data_set_name = "train_ticket" if self.data_set_id==1 else "sock_shop" 74 | self.input_dim = self.config.getint("model", "input_dim") 75 | self.gcn_hidden_dim = self.config.getint("model", "gcn_hidden_dim") 76 | self.linear_hidden_dim = self.config.getint("model", "linear_hidden_dim") 77 | self.num_bases = self.config.getint("model", "num_bases") 78 | self.dropout = self.config.getfloat("model", "dropout") 79 | self.support = self.config.getint("model", "support") 80 | self.max_node_num = self.config.getint("model", "max_node_num") 81 | self.pool_step = self.config.getint("model", "pool_step") 82 | self.lr = self.config.getfloat("train", "LR") 83 | self.weight_decay = self.config.getfloat("train", "l2norm") 84 | self.resplit = self.config.getboolean("data", "resplit") 85 | self.batch_size = self.config.getint("data", "batch_size") 86 | self.resplit_each_time = self.config.getboolean("data", "resplit_each_time") 87 | self.repeat_pos_data = self.config.getint("data", "repeat_pos_data") 88 | self.dataset_version = self.config.get("data", "dataset_version") 89 | 90 | self.epoch = self.config.getint("train", "NB_EPOCH") 91 | self.user_comment = self.config.get("train", "comment") 92 | self.criterion = F.cross_entropy 93 | 94 | def __start_tb_logger(self, time_str): 95 | self.tb_log_dir = os.path.join(os.path.dirname(__file__), 'runs/%s' % time_str 96 | ).replace("\\", os.sep).replace("/", os.sep) 97 | self.tb_logger = TensorBoardWritter(log_dir="{}_{}{}".format(self.tb_log_dir, socket.gethostname(), self.tb_comment + self.user_comment), 98 | comment=self.tb_comment) 99 | 100 | def __stop_tb_logger(self): 101 | del self.tb_logger 102 | self.tb_logger = None 103 | 104 | def __print_paras(self, model): 105 | for name, param in model.named_parameters(): 106 | logging.warning("name:{} param:{}".format(name, param.requires_grad)) 107 | 108 | def generate_labeled_data(self): 109 | ds = DataSetGraphSimGenerator(data_set_id=self.data_set_id, dataset_version=self.dataset_version) 110 | ds.generate_dataset_pickle() 111 | del ds 112 | pass 113 | 114 | def __new_model_obj(self): 115 | return GraphSimilarity_No_Gcn(input_dim=self.input_dim, 116 | gcn_hidden_dim=self.gcn_hidden_dim, 117 | linear_hidden_dim=self.linear_hidden_dim, 118 | pool_step=self.pool_step, 119 | num_bases=self.num_bases, 120 | dropout=self.dropout, 121 | support=self.support, 122 | max_node_num=self.max_node_num) 123 | 124 | def __print_data_info(self): 125 | train_data = CustomDataset(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode="train", 126 | repeat_pos_data=self.repeat_pos_data, resplit=False) 127 | test_data = CustomDataset(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode="test", 128 | repeat_pos_data=self.repeat_pos_data, resplit=False) 129 | val_data = CustomDataset(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode="val", 130 | repeat_pos_data=self.repeat_pos_data, 
resplit=False) 131 | train_data.print_data_set_info() 132 | test_data.print_data_set_info() 133 | val_data.print_data_set_info() 134 | for datas in [train_data, test_data, val_data]: 135 | for index, data in enumerate(datas): 136 | adj_1 = np.array(data["graph_online_adj"].cpu())[0] 137 | f_1 = np.array(data["graph_online_feature"].cpu()) 138 | adj_2 = np.array(data["graph_kb_adj"].cpu())[0] 139 | f_2 = np.array(data["graph_kb_feature"].cpu()) 140 | self.tb_logger.writer.add_histogram("graph_online/adj", adj_1, index) 141 | self.tb_logger.writer.add_histogram("graph_online/feature", f_1, index) 142 | self.tb_logger.writer.add_histogram("graph_kb/adj", adj_2, index) 143 | self.tb_logger.writer.add_histogram("graph_kb/feature", f_2, index) 144 | 145 | def crossentropy_loss(self, output, label, num_list): 146 | """ num_list 表示 从 0,1,2,3每种类别的数目 本处只有两个类别[不相似,相似]""" 147 | num_list.reverse() 148 | weight_ = torch.as_tensor(num_list, dtype=torch.float32, device=device) 149 | weight_ = weight_ / torch.sum(weight_) 150 | self.cross_weight_auto = np.array(weight_.cpu()) 151 | return self.criterion(output, label, weight=weight_) 152 | 153 | def train_model(self): 154 | """训练并记录参数和模型""" 155 | start_time_train = str(datetime.now().strftime("%Y%m%d-%H%M%S")) 156 | self.__start_tb_logger(time_str= start_time_train) 157 | # 模型 158 | self.model = self.__new_model_obj() 159 | self.__print_paras(self.model) 160 | self.model = self.model.to(device) 161 | 162 | # 交叉熵 163 | criterion = self.crossentropy_loss 164 | # 优化器 165 | optimizer = RAdam(self.model.parameters(), 166 | lr=self.lr, 167 | weight_decay=self.weight_decay) 168 | # 学习律 169 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=0.1, 170 | patience=8, threshold=1e-4, threshold_mode="rel", 171 | cooldown=0, min_lr=0, eps=1e-8) 172 | # 数据 173 | 174 | train_data = CustomDataset(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode="train", 175 | repeat_pos_data=self.repeat_pos_data, resplit=self.resplit) 176 | self.__print_data_info() 177 | pos_train_num, neg_train_num = train_data.pos_neg_num() 178 | 179 | # 训练 180 | for epoch in range(self.epoch): 181 | if self.resplit_each_time: 182 | train_data = CustomDataset(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode="train", 183 | repeat_pos_data=self.repeat_pos_data,resplit=self.resplit) 184 | train_loader = DataLoader(dataset=train_data, batch_size=self.batch_size, shuffle=True) 185 | loss_all = 0 186 | accuary_all_num = 0 187 | preds_all_num = 0 188 | FN, FP, TN, TP = 0, 0, 0, 0 189 | batch_num = 0 190 | outputs_all = list() 191 | for batch in train_loader: 192 | graphs_online = (batch["graph_online_feature"], batch["graph_online_adj"]) 193 | graphs_offline = (batch["graph_kb_feature"], batch["graph_kb_adj"]) 194 | labels = batch["label"] 195 | outputs = self.model(graphs_online, graphs_offline) 196 | outputs_all.append(outputs) 197 | loss = criterion(outputs, labels, num_list=[neg_train_num, pos_train_num]) 198 | 199 | preds = torch.argmax(outputs, dim=1) 200 | accuary_all_num += torch.sum(preds == labels) 201 | preds_all_num += torch.as_tensor(labels.shape[0]) 202 | FN += int(torch.sum(preds[labels == 1] == 0)) 203 | FP += int(torch.sum(preds[labels == 0] == 1)) 204 | TN += int(torch.sum(preds[labels == 0] == 0)) 205 | TP += int(torch.sum(preds[labels == 1] == 1)) 206 | 207 | batch_num += 1 208 | loss_all += loss 209 | optimizer.zero_grad() 210 | 
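# Note on crossentropy_loss (defined above): the class weights are derived
# automatically from the training split. With num_list = [neg_train_num,
# pos_train_num], say [80, 20] (illustrative counts only), reverse() turns it into
# [20, 80] and normalisation gives weights [0.2, 0.8], so class 1 ("similar", the
# minority here) is weighted by the majority-class frequency and vice versa. The
# real counts come from train_data.pos_neg_num().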
loss.backward() 211 | optimizer.step() 212 | scheduler.step(loss_all) 213 | sample_data = None 214 | if epoch == 0: 215 | batch = next(iter(train_loader)) 216 | graphs_online = (batch["graph_online_feature"], batch["graph_online_adj"]) 217 | graphs_offline = (batch["graph_kb_feature"], batch["graph_kb_adj"]) 218 | sample_data = (graphs_online, graphs_offline) 219 | 220 | 221 | recall_train = TP / (TP + FN) if (TP + FN) else 0 222 | precision_train = TP / (TP + FP) if (TP + FP) else 0 223 | F1_train = ((2 * precision_train * recall_train) / (precision_train + recall_train)) if (precision_train and recall_train) else 0 224 | accuary_val, loss_all_val, base_ac_val, precision_val, recall_val, F1_val = self.test_val_model(mode="val") 225 | accuary_test, loss_all_test, base_ac_test, precision_test, recall_test, F1_test = self.test_val_model(mode="test") 226 | class_ac_train = self.judge_graph_class_ac(mode="train") 227 | class_ac_val = self.judge_graph_class_ac(mode="val") 228 | class_ac_test = self.judge_graph_class_ac(mode="test") 229 | accuary_train = accuary_all_num.item() / preds_all_num.item() 230 | info_dict = dict( 231 | sample_data=sample_data, step=epoch, loss=loss_all.item(), 232 | loss_val=loss_all_val.item(), 233 | loss_test=loss_all_test.item(), 234 | accuracy=accuary_train, 235 | accuary_val=accuary_val, 236 | accuary_test=accuary_test, 237 | outputs_all=torch.cat(outputs_all, dim=0), 238 | train_pos_neg=np.array([pos_train_num/len(train_data), neg_train_num/len(train_data)]), 239 | val_pos_neg=np.array(base_ac_val), 240 | test_pos_neg=np.array(base_ac_test), 241 | cross_weight_auto=self.cross_weight_auto, 242 | class_ac_train=class_ac_train, 243 | class_ac_val=class_ac_val, 244 | class_ac_test=class_ac_test, 245 | recall_train=recall_train, 246 | recall_val=recall_val, 247 | recall_test=recall_test, 248 | precision_train=precision_train, 249 | precision_val=precision_val, 250 | precision_test=precision_test, 251 | F1_train=F1_train, 252 | F1_val=F1_val, 253 | F1_test=F1_test 254 | ) 255 | self.tb_logger.print_tensoroard_logs(model=self.model, info_dict=info_dict) 256 | 257 | logging.error("epoch:{} loss:{} accuracy:{}/{}={}".format(epoch, loss_all, accuary_all_num, preds_all_num, 258 | int(accuary_all_num) / int(preds_all_num))) 259 | if accuary_train >= 0.77 or (epoch >= 100 and epoch % 10 == 0 and accuary_train >= 0.98): 260 | self.save_model(time_str=start_time_train) 261 | 262 | self.test_val_model(mode="test") 263 | self.save_model(time_str=start_time_train) 264 | self.__stop_tb_logger() 265 | 266 | @torch.no_grad() 267 | def test_val_model(self, mode): 268 | test_data = CustomDataset(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode=mode, 269 | repeat_pos_data=self.repeat_pos_data, resplit=False) 270 | test_loader = DataLoader(dataset=test_data, batch_size=self.batch_size, shuffle=True) 271 | accuary_all_num = 0 272 | preds_all_num = 0 273 | FN, FP, TN, TP = 0, 0, 0, 0 274 | loss_all = 0.0 275 | pred_class = list() 276 | pred_class_t3 = list() 277 | pred_class_t2 = list() 278 | pred_class_t5 = list() 279 | label_class = list() 280 | time = 0 281 | pos_num, neg_num = test_data.pos_neg_num() 282 | for batch in test_loader: 283 | graphs_online = (batch["graph_online_feature"], batch["graph_online_adj"]) 284 | graphs_offline = (batch["graph_kb_feature"], batch["graph_kb_adj"]) 285 | labels = batch["label"] 286 | start_time = time.perf_counter() 287 | outputs = self.model(graphs_online, graphs_offline) 288 | end_time = 
time.perf_counter() 289 | # logging.error("out{} label{} negnum{}".format(type(outputs), type(labels), type(neg_num))) 290 | loss = self.criterion(outputs, labels, weight=torch.as_tensor(self.cross_weight_auto, dtype=torch.float32, device=device)) 291 | loss_all += loss 292 | outputs_sm = torch.nn.functional.softmax(outputs, dim=1) 293 | # 记录最为相似的索引,即top1 294 | outputs_max_index = torch.argmax(outputs_sm.narrow(1, 1, 1)) 295 | # if save_error_excel and (outputs_max_index.item() != (torch.argmax(labels)).item()): 296 | # preds_index = outputs_max_index.item() 297 | # preds_name = online_data_info[3][preds_index] 298 | # one_error = list(online_data_info[0:3]) + [preds_index, preds_name] 299 | # labeled_error_data.append(one_error) 300 | # 记录top_3 301 | outputs_sq = torch.squeeze(outputs_sm.narrow(1, 1, 1)) 302 | top_3 = torch.topk(outputs_sq, k=3)[1] 303 | if torch.argmax(labels) in top_3: 304 | pred_class_t3.append(torch.argmax(labels)) 305 | else: 306 | pred_class_t3.append(top_3[0]) 307 | 308 | # 记录top2 309 | top_2 = torch.topk(outputs_sq, k=2)[1] 310 | if torch.argmax(labels) in top_2: 311 | pred_class_t2.append(torch.argmax(labels)) 312 | else: 313 | pred_class_t2.append(top_2[0]) 314 | # record top5 315 | top_5 = torch.topk(outputs_sq, k=5)[1] 316 | if torch.argmax(labels) in top_5: 317 | pred_class_t5.append(torch.argmax(labels)) 318 | else: 319 | pred_class_t5.append(top_5[0]) 320 | pred_class.append(outputs_max_index) 321 | label_class.append(torch.argmax(labels)) 322 | preds = torch.argmax(outputs, dim=1) 323 | accuary_all_num += torch.sum(preds == labels) 324 | preds_all_num += torch.as_tensor(labels.shape[0]) 325 | FN += int(torch.sum(preds[labels==1]==0)) 326 | FP += int(torch.sum(preds[labels==0]==1)) 327 | TN += int(torch.sum(preds[labels==0]==0)) 328 | TP += int(torch.sum(preds[labels==1]==1)) 329 | recall = TP / (TP + FN) if (TP + FN) else 0 330 | precision = TP / (TP + FP) if (TP + FP) else 0 331 | time = end_time - start_time 332 | F1 = ((2 * precision * recall) / (precision + recall)) if (precision and recall) else 0 333 | pred_class_s = torch.stack(pred_class) 334 | label_class_s = torch.stack(label_class) 335 | pred_class_t3_s = torch.stack(pred_class_t3) 336 | pred_class_t2_s = torch.stack(pred_class_t2) 337 | pred_class_t5_s = torch.stack(pred_class_t5) 338 | MAR_res = MAR(label_class_s, pred_class_s) 339 | top1_a = torch.sum(pred_class_s == label_class_s).item() / pred_class_s.size()[0] 340 | top3_a = torch.sum(pred_class_t3_s == label_class_s).item() / pred_class_t3_s.size()[0] 341 | top2_a = torch.sum(pred_class_t2_s == label_class_s).item() / pred_class_t2_s.size()[0] 342 | top5_a = torch.sum(pred_class_t5_s == label_class_s).item() / pred_class_t5_s.size()[0] 343 | logging.error("{}_data : accuracy1:{}/{}={} accuracy2:{} accuracy3:{} accuracy5:{} MAR: {} precision:{}/{}={} recall:{}/{}={} F1:{} time: {}".format( 344 | mode, accuary_all_num, preds_all_num, int(accuary_all_num) / int(preds_all_num),top2_a,top3_a,top5_a,MAR_res, 345 | TP, (TP + FP), precision, 346 | TP, (TP + FN), recall, 347 | F1,time 348 | )) 349 | pos, neg = test_data.pos_neg_num() 350 | base_ac = [pos/(pos+neg), neg/(pos+neg)] 351 | return int(accuary_all_num) / int(preds_all_num), loss_all, base_ac, precision, recall, F1 352 | pass 353 | 354 | @torch.no_grad() 355 | def judge_graph_class_ac(self, mode, save_error_excel=False, all_test=False): 356 | assert self.model 357 | data = CustomDataset(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, 
mode=mode, 358 | repeat_pos_data=self.repeat_pos_data, resplit=False) 359 | pred_class = list() 360 | pred_class_t3 = list() 361 | pred_class_t2 = list() 362 | label_class = list() 363 | 364 | labeled_error_data = list() 365 | for sample, online_data_info in data.graph_class_data(): 366 | """(online_data_path, e_name, e_index, error_name_list)""" 367 | graphs_online = (sample["graph_online_feature"], sample["graph_online_adj"]) 368 | graphs_offline = (sample["graph_kb_feature"], sample["graph_kb_adj"]) 369 | labels = sample["label"] 370 | outputs = self.model(graphs_online, graphs_offline) 371 | outputs_sm = torch.nn.functional.softmax(outputs, dim=1) 372 | # 记录最为相似的索引,即top1 373 | outputs_max_index = torch.argmax(outputs_sm.narrow(1, 1, 1)) 374 | if save_error_excel and (outputs_max_index.item() != (torch.argmax(labels)).item()): 375 | preds_index = outputs_max_index.item() 376 | preds_name = online_data_info[3][preds_index] 377 | one_error = list(online_data_info[0:3]) + [preds_index, preds_name] 378 | labeled_error_data.append(one_error) 379 | # 记录top_3 380 | outputs_sq = torch.squeeze(outputs_sm.narrow(1, 1, 1)) 381 | top_3 = torch.topk(outputs_sq, k=3)[1] 382 | if torch.argmax(labels) in top_3: 383 | pred_class_t3.append(torch.argmax(labels)) 384 | else: 385 | pred_class_t3.append(top_3[0]) 386 | 387 | # 记录top2 388 | top_2 = torch.topk(outputs_sq, k=2)[1] 389 | if torch.argmax(labels) in top_2: 390 | pred_class_t2.append(torch.argmax(labels)) 391 | else: 392 | pred_class_t2.append(top_2[0]) 393 | 394 | 395 | pred_class.append(outputs_max_index) 396 | label_class.append(torch.argmax(labels)) 397 | pred_class_s = torch.stack(pred_class) 398 | label_class_s = torch.stack(label_class) 399 | pred_class_t3_s = torch.stack(pred_class_t3) 400 | pred_class_t2_s = torch.stack(pred_class_t2) 401 | 402 | top1_a = torch.sum(pred_class_s == label_class_s).item() / pred_class_s.size()[0] 403 | top3_a = torch.sum(pred_class_t3_s == label_class_s).item() / pred_class_t3_s.size()[0] 404 | top2_a = torch.sum(pred_class_t2_s == label_class_s).item() / pred_class_t2_s.size()[0] 405 | if save_error_excel: 406 | excel_anme = os.path.join(os.path.dirname(__file__), "{}_error_label_{}.xls".format(self.data_set_name, mode)) 407 | df = pd.DataFrame(labeled_error_data, columns=["online_data_path", "e_name", "e_index", "preds_index", "preds_name"]) 408 | df.to_excel(excel_anme, index=False) 409 | if all_test: 410 | return df 411 | return top1_a, top3_a, top2_a 412 | 413 | 414 | 415 | 416 | def save_model(self, time_str): 417 | dir_name = "{}_{}".format(time_str, socket.gethostname()+self.user_comment) 418 | save_path_dir = os.path.join(os.path.dirname(__file__), "..", "data", "graph_sim_model_parameters", self.data_set_name, 419 | dir_name) 420 | os.makedirs(save_path_dir, exist_ok=True) 421 | self.model_saved_path = os.path.join(save_path_dir, "model.pth") 422 | self.model_saved_dir = save_path_dir 423 | torch.save(self.model.state_dict(), self.model_saved_path) 424 | shutil.copy(self.config_file_path, os.path.join(save_path_dir, "config_graph_sim.ini")) 425 | 426 | def load_model(self, model_saved_dir): 427 | "https://blog.csdn.net/dss_dssssd/article/details/89409183" 428 | if model_saved_dir: 429 | self.model_saved_dir = model_saved_dir 430 | self.model_saved_path = os.path.join(model_saved_dir, "model.pth") 431 | self.config_file_path = os.path.join(model_saved_dir, "config_graph_sim.ini") 432 | self.config.read(self.config_file_path, encoding='utf-8') 433 | self.__load_super_paras() 434 | 435 | self.model = 
self.__new_model_obj() 436 | self.model = self.model.to(device) 437 | self.model.load_state_dict(torch.load(self.model_saved_path, map_location=device)) 438 | self.model.eval() 439 | pass 440 | 441 | def get_error_summary(did = 1): 442 | model_paths = [ 443 | "20200423-095359_amaxD2_n100step100dataset70attetionjumpReduceLROnPlateauLRRAdmweight46", 444 | "20200423-092330_amaxD2_n100step100dataset70attetionjumpReduceLROnPlateauLRRAdmweight46", 445 | "20200423-084810_amaxD2_n100step100dataset70attetionjumpReduceLROnPlateauLRRAdm", 446 | "20200423-081103_amaxD2_n100step100dataset70attetionjumpCosineAnnealingLRRAdm", 447 | "20200423-132639_amaxD2_n100step100dataset70allattetionjumpReduceLROnPlateauLRRAdm" 448 | ] 449 | minf = ModelInferenceNoGcn() 450 | dataset_name = "train_ticket" if did == 1 else "sock_shop" 451 | all_dir = os.path.join(os.path.dirname(__file__), "..", "data", "graph_sim_model_parameters", dataset_name) 452 | error_info = list() 453 | for root, dirs, files in os.walk(all_dir): 454 | for dir in dirs: 455 | if dir not in model_paths: 456 | continue 457 | model_dir = os.path.join(all_dir, dir) 458 | minf.load_model(model_saved_dir=model_dir) 459 | df_train = minf.judge_graph_class_ac(mode="train", save_error_excel=True, all_test=True) 460 | df_test = minf.judge_graph_class_ac(mode="test", save_error_excel=True, all_test=True) 461 | df_val = minf.judge_graph_class_ac(mode="val", save_error_excel=True, all_test=True) 462 | df_all = pd.concat([df_train, df_test, df_val]) 463 | online_data_path = list(set((df_all["online_data_path"]))) 464 | online_data_path.sort() 465 | e_name = list(set(df_all["e_name"])) 466 | e_name.sort() 467 | error_info.append(dict( 468 | model_name=dir, 469 | online_names=online_data_path, 470 | error_names=e_name 471 | )) 472 | online_name_sets = [set(info["online_names"]) for info in error_info] 473 | error_name_sets = [set(info["error_names"]) for info in error_info] 474 | online_final = online_name_sets[0] 475 | error_final = error_name_sets[0] 476 | for _ in range(1, len(online_name_sets)): 477 | online_final.intersection(online_name_sets[_]) 478 | for _ in range(1, len(error_name_sets)): 479 | error_final.intersection(error_name_sets[_]) 480 | error_info.insert(0, dict( 481 | online_names_jiaoset=sorted(list(online_final)), 482 | e_name_jiaoset=sorted(list(error_final)) 483 | )) 484 | 485 | save_json_data(os.path.join(os.path.dirname(__file__), "{}_error_info.json".format(dataset_name)), error_info) 486 | 487 | 488 | def save_json_data(save_path, pre_save_data): 489 | with open(save_path, 'w', encoding='utf-8') as file_writer: 490 | raw_data = json.dumps(pre_save_data, indent=4) 491 | file_writer.write(raw_data) 492 | 493 | 494 | 495 | if __name__ == '__main__': 496 | minf = ModelInferenceNoGcn() 497 | minf.train_model() 498 | -------------------------------------------------------------------------------- /graph_sim_dej_X.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import json 3 | import os 4 | import sys 5 | 6 | from collections import defaultdict 7 | from functools import partial 8 | from typing import Set, List, Any, Optional 9 | 10 | import time 11 | from datetime import datetime 12 | import socket 13 | import time 14 | import shutil 15 | import numpy as np 16 | import torch.optim as optim 17 | from torch.utils.data import DataLoader 18 | sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) 19 | 20 | from Radm import RAdam 21 | 22 | 
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) 23 | import pandas as pd 24 | from DataSetGraphSimGenerator import DataSetGraphSimGenerator, CustomDataset 25 | from tensorboard_logger import TensorBoardWritter 26 | from model_batch import * 27 | 28 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 29 | 30 | def get_rank(y_true: Set[Any], y_pred: List[Any], max_rank: Optional[int] = None) -> List[float]: 31 | rank_dict = defaultdict(lambda: len(y_pred) + 1 if max_rank is None else (max_rank + len(y_pred)) / 2) 32 | for idx, item in enumerate(y_pred, start=1): 33 | if item in y_true: 34 | rank_dict[item] = idx 35 | return [rank_dict[_] for _ in y_true] 36 | 37 | # noinspection PyPep8Naming 38 | def MAR(y_true: List[Set[Any]], y_pred: List[List[Any]], max_rank: Optional[int] = None): 39 | return np.mean([ 40 | np.mean(get_rank(a, b, max_rank)) 41 | for a, b in zip(y_true, y_pred) 42 | ]) 43 | 44 | 45 | class ModelInference: 46 | def __init__(self): 47 | # 初始化配置 48 | logging.error("device:{}".format(device)) 49 | torch.set_printoptions(linewidth=120) 50 | torch.set_grad_enabled(True) 51 | np.random.seed(5) 52 | torch.manual_seed(0) 53 | 54 | # 超参数配置文件 55 | self.config = configparser.ConfigParser() 56 | self.config_file_path = os.path.join(os.path.dirname(__file__), "config_graph_sim.ini") 57 | self.config.read(self.config_file_path, encoding='utf-8') 58 | 59 | # 超参数 60 | self.__load_super_paras() 61 | self.cross_weight_auto = None 62 | 63 | # 模型 64 | self.model = None 65 | self.model_saved_path = None 66 | self.model_saved_dir = None 67 | 68 | # tensor board类 69 | self.tb_comment = self.data_set_name 70 | self.tb_logger = None 71 | 72 | # 控制台打印 73 | # coloredlogs.install( 74 | # level=self.logging_print_level, 75 | # fmt="[%(levelname)s] [%(asctime)s] [%(filename)s:%(lineno)d] %(message)s", 76 | # level_styles=LEVEL_STYLES, 77 | # field_styles=FIELD_STYLES, 78 | # logger=logger 79 | # ) 80 | 81 | pass 82 | 83 | def __load_super_paras(self): 84 | self.data_set_id = self.config.getint("data", "DATASET") 85 | self.data_set_name = "train_ticket" if self.data_set_id==1 else "sock_shop" 86 | self.input_dim = self.config.getint("model", "input_dim") 87 | self.gcn_hidden_dim = self.config.getint("model", "gcn_hidden_dim") 88 | self.linear_hidden_dim = self.config.getint("model", "linear_hidden_dim") 89 | self.num_bases = self.config.getint("model", "num_bases") 90 | self.dropout = self.config.getfloat("model", "dropout") 91 | self.support = self.config.getint("model", "support") 92 | self.max_node_num = self.config.getint("model", "max_node_num") 93 | self.pool_step = self.config.getint("model", "pool_step") 94 | self.lr = self.config.getfloat("train", "LR") 95 | self.weight_decay = self.config.getfloat("train", "l2norm") 96 | self.resplit = self.config.getboolean("data", "resplit") 97 | self.batch_size = self.config.getint("data", "batch_size") 98 | self.resplit_each_time = self.config.getboolean("data", "resplit_each_time") 99 | self.repeat_pos_data = self.config.getint("data", "repeat_pos_data") 100 | self.dataset_version = self.config.get("data", "dataset_version") 101 | 102 | self.epoch = self.config.getint("train", "NB_EPOCH") 103 | self.user_comment = self.config.get("train", "comment") 104 | # self.cross_weight = self.config.getfloat("train", "cross_weight") 105 | # self.logging_print_level = str(self.config.get("print_logging", "level")) 106 | self.criterion = F.cross_entropy 107 | 108 | def __start_tb_logger(self, time_str): 109 | # 
self.start_time = str(datetime.now().strftime("%Y%m%d-%H%M%S")) 110 | self.tb_log_dir = os.path.join(os.path.dirname(__file__), 'runs/%s' % time_str 111 | ).replace("\\", os.sep).replace("/", os.sep) 112 | self.tb_logger = TensorBoardWritter(log_dir="{}_{}{}".format(self.tb_log_dir, socket.gethostname(), self.tb_comment + self.user_comment), 113 | comment=self.tb_comment) 114 | 115 | def __stop_tb_logger(self): 116 | del self.tb_logger 117 | self.tb_logger = None 118 | 119 | def __print_paras(self, model): 120 | for name, param in model.named_parameters(): 121 | logging.warning("name:{} param:{}".format(name, param.requires_grad)) 122 | 123 | def generate_labeled_data(self): 124 | ds = DataSetGraphSimGenerator(data_set_id=self.data_set_id, dataset_version=self.dataset_version) 125 | ds.generate_dataset_pickle() 126 | del ds 127 | pass 128 | 129 | def __new_model_obj(self): 130 | return GraphSimilarity(input_dim=self.input_dim, 131 | gcn_hidden_dim=self.gcn_hidden_dim, 132 | linear_hidden_dim=self.linear_hidden_dim, 133 | pool_step=self.pool_step, 134 | num_bases=self.num_bases, 135 | dropout=self.dropout, 136 | support=self.support, 137 | max_node_num=self.max_node_num) 138 | 139 | def __print_data_info(self): 140 | train_data = CustomDataset(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode="train", 141 | repeat_pos_data=self.repeat_pos_data, resplit=False) 142 | test_data = CustomDataset(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode="test", 143 | repeat_pos_data=self.repeat_pos_data, resplit=False) 144 | val_data = CustomDataset(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode="val", 145 | repeat_pos_data=self.repeat_pos_data, resplit=False) 146 | train_data.print_data_set_info() 147 | test_data.print_data_set_info() 148 | val_data.print_data_set_info() 149 | for datas in [train_data, test_data, val_data]: 150 | for index, data in enumerate(datas): 151 | adj_1 = np.array(data["graph_online_adj"].cpu())[0] 152 | f_1 = np.array(data["graph_online_feature"].cpu()) 153 | adj_2 = np.array(data["graph_kb_adj"].cpu())[0] 154 | f_2 = np.array(data["graph_kb_feature"].cpu()) 155 | self.tb_logger.writer.add_histogram("graph_online/adj", adj_1, index) 156 | self.tb_logger.writer.add_histogram("graph_online/feature", f_1, index) 157 | self.tb_logger.writer.add_histogram("graph_kb/adj", adj_2, index) 158 | self.tb_logger.writer.add_histogram("graph_kb/feature", f_2, index) 159 | 160 | def crossentropy_loss(self, output, label, num_list): 161 | """ num_list 表示 从 0,1,2,3每种类别的数目 本处只有两个类别[不相似,相似]""" 162 | # 方式1 直接翻转后除以总数 163 | num_list.reverse() 164 | weight_ = torch.as_tensor(num_list, dtype=torch.float32, device=device) 165 | weight_ = weight_ / torch.sum(weight_) 166 | # 方式2中值平均 167 | # weight_ = torch.as_tensor(num_list, dtype=torch.float32, device=device) 168 | # weight_ = torch.mean(weight_) * torch.rsqrt(weight_) 169 | self.cross_weight_auto = np.array(weight_.cpu()) 170 | # return self.criterion(output, label, weight=torch.as_tensor([0.4,0.6], dtype=torch.float32, device=device)) 171 | return self.criterion(output, label, weight=weight_) 172 | 173 | def train_model(self): 174 | """训练并记录参数和模型""" 175 | start_time_train = str(datetime.now().strftime("%Y%m%d-%H%M%S")) 176 | self.__start_tb_logger(time_str= start_time_train) 177 | # 模型 178 | self.model = self.__new_model_obj() 179 | self.__print_paras(self.model) 180 | self.model = 
self.model.to(device) 181 | 182 | # 交叉熵 183 | criterion = self.crossentropy_loss 184 | # 优化器 185 | # optimizer = optim.Adam(self.model.parameters(), 186 | # lr=self.lr, 187 | # weight_decay=self.weight_decay) 188 | optimizer = RAdam(self.model.parameters(), 189 | lr=self.lr, 190 | weight_decay=self.weight_decay) 191 | # 学习律 192 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=0.1, 193 | patience=8, threshold=1e-4, threshold_mode="rel", 194 | cooldown=0, min_lr=0, eps=1e-8) 195 | # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = 32) 196 | # 数据 197 | 198 | train_data = CustomDataset(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode="train", 199 | repeat_pos_data=self.repeat_pos_data, resplit=self.resplit) 200 | self.__print_data_info() 201 | pos_train_num, neg_train_num = train_data.pos_neg_num() 202 | 203 | # 训练 204 | for epoch in range(self.epoch): 205 | if self.resplit_each_time: 206 | train_data = CustomDataset(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode="train", 207 | repeat_pos_data=self.repeat_pos_data,resplit=self.resplit) 208 | train_loader = DataLoader(dataset=train_data, batch_size=self.batch_size, shuffle=True) 209 | loss_all = 0 210 | accuary_all_num = 0 211 | preds_all_num = 0 212 | FN, FP, TN, TP = 0, 0, 0, 0 213 | batch_num = 0 214 | outputs_all = list() 215 | for batch in train_loader: 216 | graphs_online = (batch["graph_online_feature"], batch["graph_online_adj"]) 217 | graphs_offline = (batch["graph_kb_feature"], batch["graph_kb_adj"]) 218 | labels = batch["label"] 219 | outputs = self.model(graphs_online, graphs_offline) 220 | outputs_all.append(outputs) 221 | loss = criterion(outputs, labels, num_list=[neg_train_num, pos_train_num]) 222 | 223 | 224 | 225 | preds = torch.argmax(outputs, dim=1) 226 | accuary_all_num += torch.sum(preds == labels) 227 | preds_all_num += torch.as_tensor(labels.shape[0]) 228 | FN += int(torch.sum(preds[labels == 1] == 0)) 229 | FP += int(torch.sum(preds[labels == 0] == 1)) 230 | TN += int(torch.sum(preds[labels == 0] == 0)) 231 | TP += int(torch.sum(preds[labels == 1] == 1)) 232 | 233 | batch_num += 1 234 | loss_all += loss 235 | optimizer.zero_grad() 236 | loss.backward() 237 | optimizer.step() 238 | scheduler.step(loss_all) 239 | sample_data = None 240 | if epoch == 0: 241 | batch = next(iter(train_loader)) 242 | graphs_online = (batch["graph_online_feature"], batch["graph_online_adj"]) 243 | graphs_offline = (batch["graph_kb_feature"], batch["graph_kb_adj"]) 244 | sample_data = (graphs_online, graphs_offline) 245 | 246 | 247 | recall_train = TP / (TP + FN) if (TP + FN) else 0 248 | precision_train = TP / (TP + FP) if (TP + FP) else 0 249 | F1_train = ((2 * precision_train * recall_train) / (precision_train + recall_train)) if (precision_train and recall_train) else 0 250 | accuary_val, loss_all_val, base_ac_val, precision_val, recall_val, F1_val = self.test_val_model(mode="val") 251 | accuary_test, loss_all_test, base_ac_test, precision_test, recall_test, F1_test = self.test_val_model(mode="test") 252 | class_ac_train = self.judge_graph_class_ac(mode="train") 253 | class_ac_val = self.judge_graph_class_ac(mode="val") 254 | class_ac_test = self.judge_graph_class_ac(mode="test") 255 | accuary_train = accuary_all_num.item() / preds_all_num.item() 256 | info_dict = dict( 257 | sample_data=sample_data, step=epoch, loss=loss_all.item(), 258 | 
loss_val=loss_all_val.item(), 259 | loss_test=loss_all_test.item(), 260 | accuracy=accuary_train, 261 | accuary_val=accuary_val, 262 | accuary_test=accuary_test, 263 | outputs_all=torch.cat(outputs_all, dim=0), 264 | train_pos_neg=np.array([pos_train_num/len(train_data), neg_train_num/len(train_data)]), 265 | val_pos_neg=np.array(base_ac_val), 266 | test_pos_neg=np.array(base_ac_test), 267 | cross_weight_auto=self.cross_weight_auto, 268 | class_ac_train=class_ac_train, 269 | class_ac_val=class_ac_val, 270 | class_ac_test=class_ac_test, 271 | recall_train=recall_train, 272 | recall_val=recall_val, 273 | recall_test=recall_test, 274 | precision_train=precision_train, 275 | precision_val=precision_val, 276 | precision_test=precision_test, 277 | F1_train=F1_train, 278 | F1_val=F1_val, 279 | F1_test=F1_test 280 | ) 281 | self.tb_logger.print_tensoroard_logs(model=self.model, info_dict=info_dict) 282 | 283 | logging.error("epoch:{} loss:{} accuracy:{}/{}={}".format(epoch, loss_all, accuary_all_num, preds_all_num, 284 | int(accuary_all_num) / int(preds_all_num))) 285 | if accuary_train >= 0.99 or (epoch >= 100 and epoch % 10 == 0 and accuary_train >= 0.98): 286 | self.save_model(time_str=start_time_train) 287 | 288 | self.test_val_model(mode="test") 289 | # 保存模型和超参数 290 | self.save_model(time_str=start_time_train) 291 | self.__stop_tb_logger() 292 | 293 | @torch.no_grad() 294 | def test_val_model(self, mode): 295 | test_data = CustomDataset(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode=mode, 296 | repeat_pos_data=self.repeat_pos_data, resplit=False) 297 | # test_data.print_data_set_info() 298 | test_loader = DataLoader(dataset=test_data, batch_size=self.batch_size, shuffle=True) 299 | accuary_all_num = 0 300 | accuary_all_num2 = 0 301 | accuary_all_num3 = 0 302 | preds_all_num = 0 303 | FN, FP, TN, TP = 0, 0, 0, 0 304 | loss_all = 0.0 305 | pred_class = list() 306 | pred_class_t3 = list() 307 | pred_class_t2 = list() 308 | pred_class_t5 = list() 309 | label_class = list() 310 | time = 0 311 | pos_num, neg_num = test_data.pos_neg_num() 312 | for batch in test_loader: 313 | graphs_online = (batch["graph_online_feature"], batch["graph_online_adj"]) 314 | graphs_offline = (batch["graph_kb_feature"], batch["graph_kb_adj"]) 315 | labels = batch["label"] 316 | start_time = time.perf_counter() 317 | outputs = self.model(graphs_online, graphs_offline) 318 | end_time = time.perf_counter() 319 | # logging.error("out{} label{} negnum{}".format(type(outputs), type(labels), type(neg_num))) 320 | loss = self.criterion(outputs, labels, weight=torch.as_tensor(self.cross_weight_auto, dtype=torch.float32, device=device)) 321 | loss_all += loss 322 | outputs_sm = torch.nn.functional.softmax(outputs, dim=1) 323 | # 记录最为相似的索引,即top1 324 | outputs_max_index = torch.argmax(outputs_sm.narrow(1, 1, 1)) 325 | # if save_error_excel and (outputs_max_index.item() != (torch.argmax(labels)).item()): 326 | # preds_index = outputs_max_index.item() 327 | # preds_name = online_data_info[3][preds_index] 328 | # one_error = list(online_data_info[0:3]) + [preds_index, preds_name] 329 | # labeled_error_data.append(one_error) 330 | # 记录top_3 331 | outputs_sq = torch.squeeze(outputs_sm.narrow(1, 1, 1)) 332 | top_3 = torch.topk(outputs_sq, k=3)[1] 333 | if torch.argmax(labels) in top_3: 334 | pred_class_t3.append(torch.argmax(labels)) 335 | else: 336 | pred_class_t3.append(top_3[0]) 337 | 338 | # 记录top2 339 | top_2 = torch.topk(outputs_sq, k=2)[1] 340 | if torch.argmax(labels) in 
top_2: 341 | pred_class_t2.append(torch.argmax(labels)) 342 | else: 343 | pred_class_t2.append(top_2[0]) 344 | # record top5 345 | top_5 = torch.topk(outputs_sq, k=5)[1] 346 | if torch.argmax(labels) in top_5: 347 | pred_class_t5.append(torch.argmax(labels)) 348 | else: 349 | pred_class_t5.append(top_5[0]) 350 | 351 | 352 | pred_class.append(outputs_max_index) 353 | label_class.append(torch.argmax(labels)) 354 | 355 | 356 | # 从中拿到不同的top top-1 top-2 top-3 357 | 358 | # preds2 = torch.topk(outputs, k=2, dim=1) 359 | # accuary_all_num2 += torch.sum(preds2 == labels) 360 | 361 | # values2, indices2 = preds2 362 | 363 | # Convert boolean tensor to float tensor 364 | # accuary_all_num2 += torch.sum((indices2 == labels).float()) 365 | 366 | # Similarly, for preds3 367 | 368 | 369 | 370 | # preds3 = torch.topk(outputs, k=3, dim=1) 371 | # values3, indices3 = preds3 372 | # accuary_all_num3 += torch.sum((indices3 == labels).float()) 373 | 374 | 375 | preds = torch.argmax(outputs, dim=1) 376 | 377 | accuary_all_num += torch.sum(preds == labels) 378 | preds_all_num += torch.as_tensor(labels.shape[0]) 379 | FN += int(torch.sum(preds[labels==1]==0)) 380 | FP += int(torch.sum(preds[labels==0]==1)) 381 | TN += int(torch.sum(preds[labels==0]==0)) 382 | TP += int(torch.sum(preds[labels==1]==1)) 383 | recall = TP / (TP + FN) if (TP + FN) else 0 384 | time = (end_time - start_time)*1000 385 | precision = TP / (TP + FP) if (TP + FP) else 0 386 | F1 = ((2 * precision * recall) / (precision + recall)) if (precision and recall) else 0 387 | pred_class_s = torch.stack(pred_class) 388 | label_class_s = torch.stack(label_class) 389 | pred_class_t3_s = torch.stack(pred_class_t3) 390 | pred_class_t2_s = torch.stack(pred_class_t2) 391 | pred_class_t5_s = torch.stack(pred_class_t5) 392 | MAR_res = MAR(label_class_s, pred_class_s) 393 | top1_a = torch.sum(pred_class_s == label_class_s).item() / pred_class_s.size()[0] 394 | top3_a = torch.sum(pred_class_t3_s == label_class_s).item() / pred_class_t3_s.size()[0] 395 | top2_a = torch.sum(pred_class_t2_s == label_class_s).item() / pred_class_t2_s.size()[0] 396 | top5_a = torch.sum(pred_class_t5_s == label_class_s).item() / pred_class_t5_s.size()[0] 397 | logging.error("{}_data : accuracy2:{} accuracy3:{} accuracy5:{} accuracy1:{}/{}={} MAR={} precision:{}/{}={} recall:{}/{}={} F1:{} time:{}".format( 398 | mode, top2_a,top3_a,top5_a,accuary_all_num, preds_all_num, int(accuary_all_num) / int(preds_all_num),MAR_res, 399 | TP, (TP + FP), precision, 400 | TP, (TP + FN), recall, 401 | F1,time 402 | )) 403 | pos, neg = test_data.pos_neg_num() 404 | base_ac = [pos/(pos+neg), neg/(pos+neg)] 405 | return int(accuary_all_num) / int(preds_all_num), loss_all, base_ac, precision, recall, F1 406 | pass 407 | 408 | @torch.no_grad() 409 | def judge_graph_class_ac(self, mode, save_error_excel=False, all_test=False): 410 | assert self.model 411 | data = CustomDataset(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode=mode, 412 | repeat_pos_data=self.repeat_pos_data, resplit=False) 413 | pred_class = list() 414 | pred_class_t3 = list() 415 | pred_class_t2 = list() 416 | label_class = list() 417 | 418 | labeled_error_data = list() 419 | for sample, online_data_info in data.graph_class_data(): 420 | """(online_data_path, e_name, e_index, error_name_list)""" 421 | graphs_online = (sample["graph_online_feature"], sample["graph_online_adj"]) 422 | graphs_offline = (sample["graph_kb_feature"], sample["graph_kb_adj"]) 423 | labels = sample["label"] 
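# The block below scores one online graph against every KB graph at once:
# outputs_sm.narrow(1, 1, 1) keeps only the "similar" probability column, so the
# argmax over it is the fault class ranked first (top-1). For top-2/top-3 the code
# appends the true label whenever it appears in torch.topk(...) and the best-ranked
# wrong class otherwise, so the equality count over the stacked tensors at the end
# of this method directly yields top-k accuracy.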
424 | outputs = self.model(graphs_online, graphs_offline) 425 | outputs_sm = torch.nn.functional.softmax(outputs, dim=1) 426 | # 记录最为相似的索引,即top1 427 | outputs_max_index = torch.argmax(outputs_sm.narrow(1, 1, 1)) 428 | if save_error_excel and (outputs_max_index.item() != (torch.argmax(labels)).item()): 429 | preds_index = outputs_max_index.item() 430 | preds_name = online_data_info[3][preds_index] 431 | one_error = list(online_data_info[0:3]) + [preds_index, preds_name] 432 | labeled_error_data.append(one_error) 433 | # 记录top_3 434 | outputs_sq = torch.squeeze(outputs_sm.narrow(1, 1, 1)) 435 | top_3 = torch.topk(outputs_sq, k=3)[1] 436 | if torch.argmax(labels) in top_3: 437 | pred_class_t3.append(torch.argmax(labels)) 438 | else: 439 | pred_class_t3.append(top_3[0]) 440 | 441 | # 记录top2 442 | top_2 = torch.topk(outputs_sq, k=2)[1] 443 | if torch.argmax(labels) in top_2: 444 | pred_class_t2.append(torch.argmax(labels)) 445 | else: 446 | pred_class_t2.append(top_2[0]) 447 | 448 | 449 | pred_class.append(outputs_max_index) 450 | label_class.append(torch.argmax(labels)) 451 | pred_class_s = torch.stack(pred_class) 452 | label_class_s = torch.stack(label_class) 453 | pred_class_t3_s = torch.stack(pred_class_t3) 454 | pred_class_t2_s = torch.stack(pred_class_t2) 455 | 456 | top1_a = torch.sum(pred_class_s == label_class_s).item() / pred_class_s.size()[0] 457 | top3_a = torch.sum(pred_class_t3_s == label_class_s).item() / pred_class_t3_s.size()[0] 458 | top2_a = torch.sum(pred_class_t2_s == label_class_s).item() / pred_class_t2_s.size()[0] 459 | if save_error_excel: 460 | excel_anme = os.path.join(os.path.dirname(__file__), "{}_error_label_{}.xls".format(self.data_set_name, mode)) 461 | df = pd.DataFrame(labeled_error_data, columns=["online_data_path", "e_name", "e_index", "preds_index", "preds_name"]) 462 | df.to_excel(excel_anme, index=False) 463 | if all_test: 464 | return df 465 | # logging.error("mode:{} top1_a:{} top2_a:{} top3_a:{}".format(mode, top1_a, top2_a, top3_a)) 466 | return top1_a, top3_a, top2_a 467 | 468 | 469 | 470 | 471 | def save_model(self, time_str): 472 | dir_name = "{}_{}".format(time_str, socket.gethostname()+self.user_comment) 473 | save_path_dir = os.path.join(os.path.dirname(__file__), "..", "data", "graph_sim_model_parameters", self.data_set_name, 474 | dir_name) 475 | os.makedirs(save_path_dir, exist_ok=True) 476 | # torch.save(self.model, save_path) 477 | self.model_saved_path = os.path.join(save_path_dir, "model.pth") 478 | self.model_saved_dir = save_path_dir 479 | torch.save(self.model.state_dict(), self.model_saved_path) 480 | shutil.copy(self.config_file_path, os.path.join(save_path_dir, "config_graph_sim.ini")) 481 | 482 | def load_model(self, model_saved_dir): 483 | """https://blog.csdn.net/dss_dssssd/article/details/89409183""" 484 | if model_saved_dir: 485 | self.model_saved_dir = model_saved_dir 486 | self.model_saved_path = os.path.join(model_saved_dir, "model.pth") 487 | self.config_file_path = os.path.join(model_saved_dir, "config_graph_sim.ini") 488 | self.config.read(self.config_file_path, encoding='utf-8') 489 | self.__load_super_paras() 490 | 491 | self.model = self.__new_model_obj() 492 | self.model = self.model.to(device) 493 | self.model.load_state_dict(torch.load(self.model_saved_path, map_location=device)) 494 | self.model.eval() 495 | pass 496 | 497 | def get_error_summary(did = 1): 498 | model_paths = [ 499 | "20200423-095359_amaxD2_n100step100dataset70attetionjumpReduceLROnPlateauLRRAdmweight46", 500 | 
"20200423-092330_amaxD2_n100step100dataset70attetionjumpReduceLROnPlateauLRRAdmweight46", 501 | "20200423-084810_amaxD2_n100step100dataset70attetionjumpReduceLROnPlateauLRRAdm", 502 | "20200423-081103_amaxD2_n100step100dataset70attetionjumpCosineAnnealingLRRAdm", 503 | "20200423-132639_amaxD2_n100step100dataset70allattetionjumpReduceLROnPlateauLRRAdm" 504 | ] 505 | minf = ModelInference() 506 | dataset_name = "train_ticket" if did == 1 else "sock_shop" 507 | all_dir = os.path.join(os.path.dirname(__file__), "..", "data", "graph_sim_model_parameters", dataset_name) 508 | error_info = list() 509 | for root, dirs, files in os.walk(all_dir): 510 | # if all_dir.find(root) == -1 : 511 | # break 512 | for dir in dirs: 513 | if dir not in model_paths: 514 | continue 515 | model_dir = os.path.join(all_dir, dir) 516 | minf.load_model(model_saved_dir=model_dir) 517 | df_train = minf.judge_graph_class_ac(mode="train", save_error_excel=True, all_test=True) 518 | df_test = minf.judge_graph_class_ac(mode="test", save_error_excel=True, all_test=True) 519 | df_val = minf.judge_graph_class_ac(mode="val", save_error_excel=True, all_test=True) 520 | df_all = pd.concat([df_train, df_test, df_val]) 521 | online_data_path = list(set((df_all["online_data_path"]))) 522 | online_data_path.sort() 523 | e_name = list(set(df_all["e_name"])) 524 | e_name.sort() 525 | error_info.append(dict( 526 | model_name=dir, 527 | online_names=online_data_path, 528 | error_names=e_name 529 | )) 530 | online_name_sets = [set(info["online_names"]) for info in error_info] 531 | error_name_sets = [set(info["error_names"]) for info in error_info] 532 | online_final = online_name_sets[0] 533 | error_final = error_name_sets[0] 534 | for _ in range(1, len(online_name_sets)): 535 | online_final.intersection(online_name_sets[_]) 536 | for _ in range(1, len(error_name_sets)): 537 | error_final.intersection(error_name_sets[_]) 538 | error_info.insert(0, dict( 539 | online_names_jiaoset=sorted(list(online_final)), 540 | e_name_jiaoset=sorted(list(error_final)) 541 | )) 542 | 543 | save_json_data(os.path.join(os.path.dirname(__file__), "{}_error_info.json".format(dataset_name)), error_info) 544 | 545 | 546 | def save_json_data(save_path, pre_save_data): 547 | with open(save_path, 'w', encoding='utf-8') as file_writer: 548 | raw_data = json.dumps(pre_save_data, indent=4) 549 | file_writer.write(raw_data) 550 | 551 | 552 | 553 | if __name__ == '__main__': 554 | # # get_error_summary(did=2) 555 | minf = ModelInference() 556 | # # # minf.generate_labeled_data() 557 | minf.train_model() 558 | # # path = "E:\code\python\\neo4j_aliyun\kb_algorithm\data\graph_sim_model_parameters\sock_shop\\20200422-082211_amaxD2_n30step10datasetraw" 559 | # path = "/data/sudong/code/graphsim_part/kb_algorithm/data/graph_sim_model_parameters/sock_shop/20200422-113035_amaxD2_n100step20datasetraw" 560 | # path = "/data/sudong/code/graphsim_part/kb_algorithm/data/graph_sim_model_parameters/train_ticket/20200416-131444_DESKTOP-8FNC2N9D1_n100step20repeatposauto" 561 | # minf.load_model(model_saved_dir=path) 562 | # minf.judge_graph_class_ac(mode="train", save_error_excel=False) 563 | # minf.judge_graph_class_ac(mode="test", save_error_excel=False) 564 | # minf.judge_graph_class_ac(mode="val", save_error_excel=False) 565 | # minf.test_val_model(mode="train") 566 | # minf.test_val_model(mode="test") 567 | # minf.test_val_model(mode="val") 568 | # # logging.info("done!") 569 | -------------------------------------------------------------------------------- 
/graph_sim_no_kb_dej_X.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import socket 4 | import sys 5 | from datetime import datetime 6 | 7 | 8 | from collections import defaultdict 9 | from functools import partial 10 | from typing import Set, List, Any, Optional 11 | 12 | from torch.utils.data import DataLoader 13 | sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) 14 | 15 | from Radm import RAdam 16 | 17 | sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) 18 | import pandas as pd 19 | from DataSetGraphSimGenerator import DataSetGraphSimGenerator 20 | from tensorboard_logger import TensorBoardWritter 21 | from model_batch import * 22 | 23 | 24 | import configparser 25 | import json 26 | import logging 27 | import os 28 | import random 29 | import sys 30 | from copy import deepcopy 31 | 32 | import numpy as np 33 | import scipy.sparse as sp 34 | import pickle 35 | 36 | import torch 37 | from torch.utils.data import Dataset 38 | sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) 39 | 40 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 41 | 42 | def get_rank(y_true: Set[Any], y_pred: List[Any], max_rank: Optional[int] = None) -> List[float]: 43 | rank_dict = defaultdict(lambda: len(y_pred) + 1 if max_rank is None else (max_rank + len(y_pred)) / 2) 44 | for idx, item in enumerate(y_pred, start=1): 45 | if item in y_true: 46 | rank_dict[item] = idx 47 | return [rank_dict[_] for _ in y_true] 48 | 49 | # noinspection PyPep8Naming 50 | def MAR(y_true: List[Set[Any]], y_pred: List[List[Any]], max_rank: Optional[int] = None): 51 | return np.mean([ 52 | np.mean(get_rank(a, b, max_rank)) 53 | for a, b in zip(y_true, y_pred) 54 | ]) 55 | 56 | def read_json_data(read_path): 57 | with open(read_path, 'r', encoding='utf-8') as file_reader: 58 | raw_data = file_reader.read() 59 | paths_list = json.loads(raw_data) 60 | return paths_list 61 | 62 | 63 | def read_pickle_data(read_path): 64 | with open(read_path, 'rb') as file_reader: 65 | return pickle.load(file_reader) 66 | 67 | 68 | def save_json_data(save_path, pre_save_data): 69 | with open(save_path, 'w', encoding='utf-8') as file_writer: 70 | raw_data = json.dumps(pre_save_data, indent=4) 71 | file_writer.write(raw_data) 72 | 73 | 74 | def save_pickle_data(save_path, pre_save_data): 75 | with open(save_path, 'wb') as f: 76 | pickle.dump(pre_save_data, f, pickle.HIGHEST_PROTOCOL) 77 | 78 | class CustomDatasetNoKB(Dataset):#需要继承data.Dataset 79 | #https://blog.csdn.net/liuweiyuxiang/article/details/84037973 80 | def __init__(self, data_set_id, dataset_version=None, mode="train", max_node_num=100, repeat_pos_data=0, resplit=False, 81 | random_expand_train_data_flag=False, dataset_dir=None): 82 | self.data_set_id = data_set_id 83 | self.dataset_name = "train_ticket" if data_set_id == 1 else "sock_shop" 84 | dataset_special = "{}_{}".format(self.dataset_name, dataset_version) if dataset_version else self.dataset_name 85 | self.dataset_dir = os.path.join(os.path.dirname(__file__), "..", "data", "data_for_graph_sim", 86 | dataset_special) 87 | 88 | if dataset_dir: 89 | self.dataset_dir = dataset_dir 90 | 91 | self.train_data_path = os.path.join(self.dataset_dir, "train_labeled_data.json") 92 | self.test_data_path = os.path.join(self.dataset_dir, "test_labeled_data.json") 93 | self.validation_data_path = os.path.join(self.dataset_dir, "validation_labeled_data.json") 94 | self.label_data_path = 
os.path.join(self.dataset_dir, "labeled_data.json") 95 | self.label_graph_class_path = os.path.join(self.dataset_dir, "labeled_graph_class_data.json") 96 | 97 | self.max_node_num = max_node_num 98 | self.repeat_pos_data = repeat_pos_data 99 | self.mode = mode 100 | if resplit or not os.path.exists(self.train_data_path): 101 | self.split_labeled_data() 102 | if random_expand_train_data_flag: 103 | self.random_expand_train_data(expand_num=9, delete_part=0.1) 104 | if self.repeat_pos_data: 105 | self.guocaiyang_train() 106 | if mode == "train": 107 | self.labeled_data = read_json_data(self.train_data_path) 108 | elif mode == "test": 109 | self.labeled_data = read_json_data(self.test_data_path) 110 | elif mode == "val": 111 | self.labeled_data = read_json_data(self.validation_data_path) 112 | else: 113 | self.labeled_data = read_json_data(self.label_data_path) 114 | 115 | # 在没有kb的场景中 只关心 正样本 116 | label_graph_class_data = read_json_data(self.label_graph_class_path) 117 | error_name_list = label_graph_class_data["error_name_list"] 118 | error_name_list.sort() 119 | 120 | df = pd.DataFrame(self.labeled_data, columns=["o_p", "f_p", "label"]) 121 | self.labeled_data = df[df["label"]==1].values.tolist() 122 | self.class_names = error_name_list 123 | 124 | def __getitem__(self, index): 125 | online_path, kb_path, label = self.labeled_data[index] 126 | online_path = (os.path.dirname(__file__) + online_path).replace("\\", os.sep).replace("/", os.sep) 127 | kb_path = (os.path.dirname(__file__) + kb_path).replace("\\", os.sep).replace("/", os.sep) 128 | # print("online_path") 129 | # print(online_path) 130 | online_path = online_path.split('/')[-1] 131 | online_real_path = os.path.join("xxxx/pickle_data",online_path) 132 | graph_online = read_pickle_data(online_real_path) 133 | # print("graph_online") 134 | # print(graph_online) 135 | kb_path = kb_path.split('/')[-1] 136 | kb_real_path = os.path.join("xxxx/pickle_data",kb_path) 137 | graph_kb = read_pickle_data(kb_real_path) 138 | graph_online_feature, graph_online_A_list = self.process_graph( 139 | (graph_online["fetures"], graph_online["adj"], graph_online["node_index_value"])) 140 | graph_kb_feature, graph_kb_A_list = self.process_graph( 141 | (graph_kb["fetures"], graph_kb["adj"], graph_kb["node_index_value"])) 142 | # 由类别序号作为label 143 | class_name = str(kb_path.split(os.sep)[-1].split("___")[0].replace("kb_","")) 144 | label = self.class_names.index(class_name) 145 | 146 | sample = { 147 | 'graph_online_adj': torch.as_tensor(graph_online_A_list, dtype=torch.float32, device=device), 148 | 'graph_online_feature': torch.as_tensor(graph_online_feature, dtype=torch.float32, device=device), 149 | 'graph_kb_adj': torch.as_tensor(graph_kb_A_list, dtype=torch.float32, device=device), 150 | 'graph_kb_feature': torch.as_tensor(graph_kb_feature, dtype=torch.float32, device=device), 151 | 'label': torch.as_tensor(label, dtype=torch.long, device=device) 152 | } 153 | return sample 154 | 155 | def __len__(self): 156 | # You should change 0 to the total size of your dataset. 
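# Note: in this no-KB variant __getitem__ (above) does not use the original 0/1
# similarity label. Only positive pairs survive the df[df["label"]==1] filter in
# __init__, and the target is re-derived from the KB pickle filename: for a
# hypothetical file "kb_cpu_load___0.pickle", split("___")[0].replace("kb_", "")
# gives "cpu_load" and self.class_names.index("cpu_load") becomes the label, so the
# task turns into direct multi-class fault classification rather than pair similarity.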
157 | return len(self.labeled_data) 158 | 159 | def graph_class_data(self, use_best_path=False): 160 | """ 161 | 返回 一个online 图与所有 kb图 两两成对的数据 162 | :param use_best_path: 163 | :param graph_accuary: 小数 表示正确点数占比 164 | :param online_offline_proportion: 小数或整数 表示online:offline 's proportion 165 | :return: 166 | """ 167 | online_path_in_mode = list(set([o_p for o_p, kb_p, label in self.labeled_data])) 168 | label_graph_class_data = read_json_data(self.label_graph_class_path) 169 | online_info = label_graph_class_data["online_info"] 170 | kb_data = label_graph_class_data["kb_data"] 171 | error_name_list = label_graph_class_data["error_name_list"] 172 | error_name_list.sort() 173 | 174 | for online_data_path, e_name, e_index in online_info: 175 | # 只加载该mode状态里的数据 176 | if online_data_path not in online_path_in_mode: 177 | continue 178 | online_path = (os.path.dirname(__file__) + online_data_path).replace("\\", os.sep).replace("/", os.sep) 179 | if use_best_path: 180 | online_path = online_path.replace(self.dataset_name, "{}_best".format(self.dataset_name)) 181 | online_path = online_path.split('/')[-1] 182 | online_real_path = os.path.join("/home/mfm/experiment/kb_algorithm/graph_sim/data/data_for_graph_sim/train_ticket_final_same/pickle_data",online_path) 183 | 184 | graph_online = read_pickle_data(online_real_path) 185 | 186 | graph_online_feature, graph_online_A_list = self.process_graph( 187 | (graph_online["fetures"], graph_online["adj"], graph_online["node_index_value"])) 188 | graph_online_adj_list, graph_online_feature_list, graph_kb_adj_list = list(), list(), list() 189 | graph_kb_feature_list, label_list = list(), list() 190 | 191 | for e_kb_name in [e_name]: 192 | kb_data_path = kb_data[e_kb_name] 193 | label = error_name_list.index(e_name) 194 | kb_path = (os.path.dirname(__file__) + kb_data_path).replace("\\", os.sep).replace("/", os.sep) 195 | if use_best_path: 196 | kb_path = kb_path.replace(self.dataset_name, "{}_best".format(self.dataset_name)) 197 | kb_path = kb_path.split('/')[-1] 198 | kb_real_path = os.path.join("/home/mfm/experiment/kb_algorithm/graph_sim/data/data_for_graph_sim/train_ticket_final_same/pickle_data",kb_path) 199 | 200 | graph_kb = read_pickle_data(kb_real_path) 201 | 202 | 203 | graph_kb_feature, graph_kb_A_list = self.process_graph( 204 | (graph_kb["fetures"], graph_kb["adj"], graph_kb["node_index_value"])) 205 | graph_online_adj_list.append(torch.tensor(graph_online_A_list, dtype=torch.float32, device=device)) 206 | graph_online_feature_list.append(torch.tensor(graph_online_feature, dtype=torch.float32, device=device)) 207 | graph_kb_adj_list.append(torch.tensor(graph_kb_A_list, dtype=torch.float32, device=device)) 208 | graph_kb_feature_list.append(torch.tensor(graph_kb_feature, dtype=torch.float32, device=device)) 209 | label_list.append(torch.tensor(label, dtype=torch.long, device=device)) 210 | sample = { 211 | 'graph_online_adj': torch.stack(graph_online_adj_list), 212 | 'graph_online_feature': torch.stack(graph_online_feature_list), 213 | 'graph_kb_adj': torch.stack(graph_kb_adj_list), 214 | 'graph_kb_feature': torch.stack(graph_kb_feature_list), 215 | 'label': torch.stack(label_list), 216 | 'error_name_list': error_name_list, 217 | 'online_data_path': online_data_path, 218 | } 219 | 220 | yield sample, (online_data_path, e_name, e_index, error_name_list) 221 | 222 | def graph_class_data_bk(self): 223 | online_path_in_mode = list(set([o_p for o_p, kb_p, label in self.labeled_data])) 224 | label_graph_class_data = 
read_json_data(self.label_graph_class_path) 225 | online_info = label_graph_class_data["online_info"] 226 | kb_data = label_graph_class_data["kb_data"] 227 | error_name_list = label_graph_class_data["error_name_list"] 228 | error_name_list.sort() 229 | for online_data_path, e_name, e_index in online_info: 230 | # 只加载该mode状态里的数据 231 | if online_data_path not in online_path_in_mode: 232 | continue 233 | online_path = (os.path.dirname(__file__) + online_data_path).replace("\\", os.sep).replace("/", os.sep) 234 | graph_online = read_pickle_data(online_path) 235 | graph_online_feature, graph_online_A_list = self.process_graph( 236 | (graph_online["fetures"], graph_online["adj"], graph_online["node_index_value"])) 237 | graph_online_adj_list, graph_online_feature_list, graph_kb_adj_list = list(), list(), list() 238 | graph_kb_feature_list, label_list = list(), list() 239 | for e_kb_name in error_name_list: 240 | kb_data_path = kb_data[e_kb_name] 241 | label = 1.0 if e_name == e_kb_name else 0. 242 | kb_path = (os.path.dirname(__file__) + kb_data_path).replace("\\", os.sep).replace("/", os.sep) 243 | graph_kb = read_pickle_data(kb_path) 244 | graph_kb_feature, graph_kb_A_list = self.process_graph( 245 | (graph_kb["fetures"], graph_kb["adj"], graph_kb["node_index_value"])) 246 | graph_online_adj_list.append(torch.tensor(graph_online_A_list, dtype=torch.float32, device=device)) 247 | graph_online_feature_list.append(torch.tensor(graph_online_feature, dtype=torch.float32, device=device)) 248 | graph_kb_adj_list.append(torch.tensor(graph_kb_A_list, dtype=torch.float32, device=device)) 249 | graph_kb_feature_list.append(torch.tensor(graph_kb_feature, dtype=torch.float32, device=device)) 250 | label_list.append(torch.tensor(label, dtype=torch.long, device=device)) 251 | sample = { 252 | 'graph_online_adj': torch.stack(graph_online_adj_list), 253 | 'graph_online_feature': torch.stack(graph_online_feature_list), 254 | 'graph_kb_adj': torch.stack(graph_kb_adj_list), 255 | 'graph_kb_feature': torch.stack(graph_kb_feature_list), 256 | 'label': torch.stack(label_list), 257 | 'error_name_list': error_name_list, 258 | 'online_data_path': online_data_path, 259 | } 260 | assert torch.sum(sample["label"]).item() == 1.0 261 | 262 | yield sample, (online_data_path, e_name, e_index, error_name_list) 263 | 264 | def pos_neg_num(self): 265 | pos, neg = 0, 0 266 | for data in self.labeled_data: 267 | if data[2] == 1: 268 | pos += 1 269 | elif data[2] == 0: 270 | neg += 1 271 | return pos, neg 272 | 273 | def print_data_set_info(self): 274 | pos, neg = self.pos_neg_num() 275 | logging.error("dataset_id:{} mode:{} pos:{} neg:{}".format(self.data_set_id, self.mode, 276 | pos, neg)) 277 | 278 | def random_expand_train_data(self, expand_num, delete_part=0.1): 279 | 280 | # 获取训练数据 中的online pickle ptha 281 | train_labeled_data = read_json_data(self.train_data_path) 282 | train_online_timepieces = list(set([data[0] for data in train_labeled_data])) 283 | df_raw = pd.DataFrame(train_labeled_data, columns=["o_p", "f_p", "label"]) 284 | # 对于每个pickle 随机删除一些点和 边 285 | for piece in train_online_timepieces: 286 | # 匹配的 kb 和不匹配的kb 287 | match_kb = list(df_raw[(df_raw["o_p"]==piece) & (df_raw["label"]==1)]["f_p"])[0] 288 | no_match_kb = list(df_raw[(df_raw["o_p"]==piece) & (df_raw["label"]==0)]["f_p"]) 289 | 290 | online_path = (os.path.dirname(__file__) + piece).replace("\\", os.sep).replace("/", os.sep) 291 | graph_online = read_pickle_data(online_path) 292 | for index in range(expand_num): 293 | graph_online_copy = 
deepcopy(graph_online) 294 | new_pickle_path = online_path.replace(".pickle", "diy{}.pickle".format(index)) 295 | 296 | index_delete = random.sample(range(graph_online_copy["node_num"]), int(graph_online_copy["node_num"] * (1-delete_part))) 297 | id_delete = [graph_online_copy["node_index_id"][index] for index in index_delete] 298 | 299 | # 删除 结点关系字典 300 | for node in graph_online["nodes_dict"]: 301 | if node["id"] in id_delete: 302 | graph_online_copy["nodes_dict"].remove(node) 303 | 304 | for relation in graph_online["relations_dict"]: 305 | if (relation["_start_node_id"] in id_delete) or (relation["_end_node_id"] in id_delete): 306 | graph_online_copy["relations_dict"].remove(relation) 307 | graph_online_copy["relation_num"] = len(graph_online_copy["relations_dict"]) 308 | graph_online_copy["node_num"] = len(graph_online_copy["nodes_dict"]) 309 | 310 | for id in id_delete: 311 | del graph_online_copy["node_id_index"][id] 312 | del graph_online_copy["node_id_value"][id] 313 | 314 | graph_online_copy["node_index_id"] = np.delete(graph_online_copy["node_index_id"], np.array(index_delete), axis=0).tolist() 315 | graph_online_copy["node_index_value"] = np.delete(graph_online_copy["node_index_value"], np.array(index_delete), axis=0).tolist() 316 | 317 | graph_online_copy["fetures"] = np.delete(graph_online_copy["fetures"], np.array(index_delete), axis=0) 318 | 319 | graph_online_copy["adj"] = np.delete(graph_online_copy["adj"], np.array(index_delete), axis=0) 320 | graph_online_copy["adj"] = np.delete(graph_online_copy["adj"], np.array(index_delete), axis=1) 321 | 322 | loc = np.where(graph_online_copy["adj"]==1) 323 | graph_online_copy["adj_sparse"] = sp.csr_matrix((np.ones(loc[0].shape), (loc[0], loc[1])), shape=graph_online_copy["adj"].shape, 324 | dtype=np.int8) 325 | 326 | save_pickle_data(new_pickle_path, graph_online_copy) 327 | replace_part = str(os.path.dirname(__file__)) 328 | replace_part = replace_part.replace("\\", os.sep).replace("/", os.sep) 329 | save_pickle_path = new_pickle_path.replace(replace_part, "") 330 | train_labeled_data.append([save_pickle_path, match_kb, 1]) 331 | for kb in no_match_kb: 332 | train_labeled_data.append([save_pickle_path, kb, 0]) 333 | 334 | df_expand = pd.DataFrame(train_labeled_data, columns=["o_p", "f_p", "label"]) 335 | assert (len(df_raw) * (expand_num+1)) == len(df_expand) 336 | save_json_data(self.train_data_path, train_labeled_data) 337 | if self.mode == "train": 338 | self.labeled_data = read_json_data(self.train_data_path) 339 | logging.warning("expand train data done! 
raw:{}+{}={} new{}+{}={} expand_num:{}".format(
340 |             len(df_raw[df_raw["label"]==0]),
341 |             len(df_raw[df_raw["label"]==1]),
342 |             len(df_raw),
343 |             len(df_expand[df_expand["label"]==0]),
344 |             len(df_expand[df_expand["label"]==1]),
345 |             len(df_expand),
346 |             expand_num
347 |         ))
348 | 
349 |     def guocaiyang_train(self):
350 |         """Oversample the positive samples of the training set."""
351 |         train_data = read_json_data(self.train_data_path)
352 |         df_train = pd.DataFrame(train_data, columns=["o_p", "kb_p", "label"])
353 |         pos_train = df_train[df_train["label"]==1]
354 |         neg_train = df_train[df_train["label"]==0]
355 | 
356 |         pos_train = self.repeat_df(pos_train, self.repeat_pos_data, len(neg_train) // len(pos_train))
357 |         train = np.array(pd.concat([pos_train, neg_train]).sample(frac=1)).tolist()
358 |         save_json_data(self.train_data_path, train)
359 | 
360 |     def repeat_df(self, df_data, repeat_pos_data_flag, neg_pos):
361 |         if repeat_pos_data_flag >= 1:
362 |             if repeat_pos_data_flag == 1:
363 |                 # adaptive: repeat the positives according to the neg/pos ratio
364 |                 if neg_pos > 1:
365 |                     df_data = pd.concat([df_data] * neg_pos)
366 |             else:
367 |                 df_data = pd.concat([df_data] * repeat_pos_data_flag)
368 |         return df_data
369 | 
370 |     def split_labeled_data(self):
371 |         all_data = read_json_data(self.label_data_path)
372 |         import pandas as pd
373 | 
374 |         def repeat_df(df_data, repeat_pos_data_flag, neg_pos):
375 |             if repeat_pos_data_flag >= 1:
376 |                 if repeat_pos_data_flag == 1:
377 |                     # adaptive: repeat the positives according to the neg/pos ratio
378 |                     if neg_pos > 1:
379 |                         df_data = pd.concat([df_data] * neg_pos)
380 |                 else:
381 |                     df_data = pd.concat([df_data] * repeat_pos_data_flag)
382 |             return df_data
383 | 
384 |         df = pd.DataFrame(all_data, columns=["o_p", "kb_p", "label"])
385 | 
386 |         # # Option 1: randomly split the online time windows 6:2:2
387 |         # online_pickle_names = list(set(df["o_p"]))
388 |         # random.shuffle(online_pickle_names)
389 |         # occupy = [0.6, 0.2, 0.2]
390 |         # train_num, test_num, val_num = int(occupy[0] * len(online_pickle_names)), int(occupy[1] * len(online_pickle_names)), int(
391 |         #     occupy[2] * len(online_pickle_names))
392 |         # train_online_names = online_pickle_names[:train_num]
393 |         # test_online_names = online_pickle_names[train_num:train_num + test_num]
394 |         # val_online_names = online_pickle_names[train_num + test_num:]
395 |         # Option 2: make sure every fault type has samples in the train/test/val splits
396 |         train_online_names, test_online_names, val_online_names = list(), list(), list()
397 |         kb_unique_names = list(set(df["kb_p"]))
398 |         for kb_name in kb_unique_names:
399 |             new_df = df[(df["kb_p"]==kb_name) & (df["label"]==1)]
400 |             kb_online_names = list(set(new_df["o_p"]))
401 |             random.shuffle(kb_online_names)
402 |             occupy = [0.6, 0.3, 0.1]
403 |             train_num, test_num, val_num = int(occupy[0] * len(kb_online_names)), int(
404 |                 occupy[1] * len(kb_online_names)), int(
405 |                 occupy[2] * len(kb_online_names))
406 |             train_l = kb_online_names[:train_num]
407 |             test_l = kb_online_names[train_num:train_num + test_num]
408 |             val_l = kb_online_names[train_num + test_num:]
409 |             train_online_names.extend(train_l)
410 |             test_online_names.extend(test_l)
411 |             val_online_names.extend(val_l)
412 |             logging.info("kb_name:{}\n online_times:{} train:{} test:{} val:{}".format(kb_name, len(new_df), len(train_l), len(test_l), len(val_l)))
413 | 
414 |         df_train = pd.concat([df.loc[df["o_p"]==name] for name in train_online_names])
415 |         df_test = pd.concat([df.loc[df["o_p"]==name] for name in test_online_names])
416 |         df_val = pd.concat([df.loc[df["o_p"]==name] for name in val_online_names])
417 |         # pos_df = repeat_df(pos_df, self.repeat_pos_data, len(neg_df) // len(pos_df))
418 |         pos_df = df[df["label"]==1]
419 |         neg_df = df[df["label"]==0]
420 | 421 | pos_train = df_train[df_train["label"]==1] 422 | # pos_train = self.repeat_df(pos_train, self.repeat_pos_data, len(neg_df) // len(pos_df)) 423 | pos_test = df_test[df_test["label"]==1] 424 | # pos_test = repeat_df(pos_test, self.repeat_pos_data, len(neg_df) // len(pos_df)) 425 | pos_val = df_val[df_val["label"]==1] 426 | # pos_val = repeat_df(pos_val, self.repeat_pos_data, len(neg_df) // len(pos_df)) 427 | 428 | neg_train = df_train[df_train["label"]==0] 429 | neg_test = df_test[df_test["label"]==0] 430 | neg_val = df_val[df_val["label"]==0] 431 | 432 | train = np.array(pd.concat([pos_train, neg_train]).sample(frac=1)).tolist() 433 | test = np.array(pd.concat([pos_test, neg_test]).sample(frac=1)).tolist() 434 | val = np.array(pd.concat([pos_val, neg_val]).sample(frac=1)).tolist() 435 | 436 | save_json_data(self.train_data_path, train) 437 | save_json_data(self.test_data_path, test) 438 | save_json_data(self.validation_data_path, val) 439 | logging.info("split data done! all:{} train:{}({},{}) test:{}({},{}) val:{}({},{})".format(len(df), 440 | len(train),len(pos_train), len(neg_train), 441 | len(test), len(pos_test), len(neg_test), 442 | len(val), len(pos_val), len(neg_val))) 443 | 444 | def split_labeled_data_backup(self): 445 | all_data = read_json_data(self.label_data_path) 446 | import pandas as pd 447 | 448 | def repeat_df(df_data, repeat_pos_data_flag, neg_pos): 449 | if repeat_pos_data_flag >= 1: 450 | if repeat_pos_data_flag == 1: 451 | # 自适应 452 | if neg_pos > 1: 453 | df_data = pd.concat([df_data] * neg_pos) 454 | else: 455 | df_data = pd.concat([df_data] * neg_pos) 456 | return df_data 457 | 458 | df = pd.DataFrame(all_data, columns=["o_p", "kb_p", "label"]) 459 | pos_df = df[df["label"] == 1] 460 | neg_df = df[df["label"] == 0] 461 | # pos_df = repeat_df(pos_df, self.repeat_pos_data, len(neg_df) // len(pos_df)) 462 | 463 | pos_df = pos_df.sample(frac=1) 464 | neg_df = neg_df.sample(frac=1) 465 | 466 | occupy = [0.6, 0.2, 0.2] 467 | train_num, test_num, val_num = int(occupy[0] * len(pos_df)), int(occupy[1] * len(pos_df)), int(occupy[2] * len(pos_df)) 468 | pos_train = pos_df[:train_num] 469 | pos_train = repeat_df(pos_train, self.repeat_pos_data, len(neg_df) // len(pos_df)) 470 | pos_test = pos_df[train_num:train_num + test_num] 471 | pos_test = repeat_df(pos_test, self.repeat_pos_data, len(neg_df) // len(pos_df)) 472 | pos_val = pos_df[train_num + test_num:] 473 | pos_val = repeat_df(pos_val, self.repeat_pos_data, len(neg_df) // len(pos_df)) 474 | 475 | train_num, test_num, val_num = int(occupy[0] * len(neg_df)), int(occupy[1] * len(neg_df)), int(occupy[2] * len(neg_df)) 476 | neg_train = neg_df[:train_num] 477 | neg_test = neg_df[train_num:train_num + test_num] 478 | neg_val = neg_df[train_num + test_num:] 479 | 480 | train = np.array(pd.concat([pos_train, neg_train]).sample(frac=1)).tolist() 481 | test = np.array(pd.concat([pos_test, neg_test]).sample(frac=1)).tolist() 482 | val = np.array(pd.concat([pos_val, neg_val]).sample(frac=1)).tolist() 483 | 484 | save_json_data(self.train_data_path, train) 485 | save_json_data(self.test_data_path, test) 486 | save_json_data(self.validation_data_path, val) 487 | logging.info("split data done! 
all:{} train:{} test:{} val:{}".format(len(df), len(train), len(test), len(val))) 488 | 489 | def process_graph(self, graph): 490 | """ 491 | 将一个feature矩阵和邻接矩阵 padding 后 转为 [feature, r_r, r_reverse, r_self] 492 | :param graph: (feature矩阵:np.array, 邻接稀疏矩阵: sp.csr_matrix, index_关键性:list) 493 | :param max_node_num: 494 | :return: 495 | """ 496 | A_list = list() 497 | feature = graph[0] 498 | adj = graph[1] 499 | logging.info("process_graph_init: features:{} adj:{}".format(graph[0].shape, graph[1].shape)) 500 | 501 | max_node_num = self.max_node_num 502 | if max_node_num - feature.shape[0] > 0: 503 | """ 需要 padding""" 504 | padding_num = max_node_num - feature.shape[0] 505 | # 处理特征 506 | feature_new = np.concatenate([feature, np.zeros((padding_num, feature.shape[1]))], axis=0) 507 | adj_new = np.concatenate([adj, np.zeros((padding_num, adj.shape[1]))], axis=0) 508 | adj_r = np.concatenate([adj_new, np.zeros((adj_new.shape[0], padding_num))], axis=1) 509 | adj_l = np.transpose(adj_r) 510 | 511 | elif max_node_num - feature.shape[0] < 0: 512 | """ 需要删除一些 """ 513 | # 处理邻接矩阵 514 | num_delete = feature.shape[0] - max_node_num 515 | node_index_value = {_: graph[2][_] for _ in range(len(graph[2]))} 516 | for _ in graph[2]: 517 | # assert not isinstance(_, str) 518 | if isinstance(_, str): 519 | node_index_value = {_: 0 for _ in range(len(node_index_value))} 520 | logging.warning("some value is str!") 521 | break 522 | 523 | node_index_value_sort = sorted(node_index_value.items(), key=lambda x: x[1]) 524 | indexs_delete = np.array([node_index_value_sort[_][0] for _ in range(num_delete)]) 525 | # 处理特征矩阵 526 | feature_new = np.delete(feature, indexs_delete, axis=0) 527 | # 处理邻接矩阵 528 | adj_r_array = adj 529 | adj_r_array = np.delete(adj_r_array, indexs_delete, axis=0) 530 | adj_r_array = np.delete(adj_r_array, indexs_delete, axis=1) 531 | 532 | adj_r = adj_r_array 533 | adj_l = np.transpose(adj_r) 534 | else: 535 | """形状不需要改变""" 536 | feature_new = feature 537 | adj_r = adj 538 | adj_l = np.transpose(adj_r) 539 | 540 | A_list.append(adj_r) 541 | A_list.append(adj_l) 542 | A_list.append(np.identity(adj_r.shape[0])) 543 | logging.info( 544 | "process_graph_done: features:{} adj_r:{} adj_l:{} self:{}".format(feature_new.shape, adj_r.shape, 545 | adj_l.shape, A_list[2].shape)) 546 | return feature_new, np.array(A_list) 547 | 548 | 549 | class ModelInferenceNoKb: 550 | def __init__(self): 551 | # 初始化配置 552 | logging.error("device:{}".format(device)) 553 | torch.set_printoptions(linewidth=120) 554 | torch.set_grad_enabled(True) 555 | np.random.seed(5) 556 | torch.manual_seed(0) 557 | 558 | # 超参数配置文件 559 | self.config = configparser.ConfigParser() 560 | self.config_file_path = os.path.join(os.path.dirname(__file__), "config_graph_sim_nokb.ini") 561 | self.config.read(self.config_file_path, encoding='utf-8') 562 | 563 | # 超参数 564 | self.__load_super_paras() 565 | self.cross_weight_auto = None 566 | 567 | # 模型 568 | self.model = None 569 | self.model_saved_path = None 570 | self.model_saved_dir = None 571 | 572 | # tensor board类 573 | self.tb_comment = self.data_set_name 574 | self.tb_logger = None 575 | 576 | 577 | label_graph_class_path = os.path.join(os.path.dirname(__file__), "..", "data", "data_for_graph_sim", 578 | "{}_{}".format(self.data_set_name, self.dataset_version), "labeled_graph_class_data.json") 579 | label_graph_class_data = read_json_data(label_graph_class_path) 580 | error_name_list = label_graph_class_data["error_name_list"] 581 | error_name_list.sort() 582 | 583 | self.class_names = 
error_name_list 584 | 585 | # 控制台打印 586 | # coloredlogs.install( 587 | # level=self.logging_print_level, 588 | # fmt="[%(levelname)s] [%(asctime)s] [%(filename)s:%(lineno)d] %(message)s", 589 | # level_styles=LEVEL_STYLES, 590 | # field_styles=FIELD_STYLES, 591 | # logger=logger 592 | # ) 593 | 594 | pass 595 | 596 | def __load_super_paras(self): 597 | self.data_set_id = self.config.getint("data", "DATASET") 598 | self.data_set_name = "train_ticket" if self.data_set_id==1 else "sock_shop" 599 | self.input_dim = self.config.getint("model", "input_dim") 600 | self.gcn_hidden_dim = self.config.getint("model", "gcn_hidden_dim") 601 | self.linear_hidden_dim = self.config.getint("model", "linear_hidden_dim") 602 | self.num_bases = self.config.getint("model", "num_bases") 603 | self.dropout = self.config.getfloat("model", "dropout") 604 | self.support = self.config.getint("model", "support") 605 | self.max_node_num = self.config.getint("model", "max_node_num") 606 | self.pool_step = self.config.getint("model", "pool_step") 607 | self.lr = self.config.getfloat("train", "LR") 608 | self.weight_decay = self.config.getfloat("train", "l2norm") 609 | self.resplit = self.config.getboolean("data", "resplit") 610 | self.batch_size = self.config.getint("data", "batch_size") 611 | self.resplit_each_time = self.config.getboolean("data", "resplit_each_time") 612 | self.repeat_pos_data = self.config.getint("data", "repeat_pos_data") 613 | self.dataset_version = self.config.get("data", "dataset_version") 614 | 615 | self.epoch = self.config.getint("train", "NB_EPOCH") 616 | self.user_comment = self.config.get("train", "comment") 617 | # self.cross_weight = self.config.getfloat("train", "cross_weight") 618 | # self.logging_print_level = str(self.config.get("print_logging", "level")) 619 | self.criterion = F.cross_entropy 620 | 621 | label_graph_class_path = os.path.join(os.path.dirname(__file__), "..", "data", "data_for_graph_sim", 622 | "{}_{}".format(self.data_set_name, self.dataset_version), 623 | "labeled_graph_class_data.json") 624 | label_graph_class_data = read_json_data(label_graph_class_path) 625 | error_name_list = label_graph_class_data["error_name_list"] 626 | error_name_list.sort() 627 | 628 | self.class_names = error_name_list 629 | 630 | 631 | def __start_tb_logger(self, time_str): 632 | # self.start_time = str(datetime.now().strftime("%Y%m%d-%H%M%S")) 633 | self.tb_log_dir = os.path.join(os.path.dirname(__file__), 'runs/%s' % time_str 634 | ).replace("\\", os.sep).replace("/", os.sep) 635 | self.tb_logger = TensorBoardWritter(log_dir="{}_{}{}".format(self.tb_log_dir, socket.gethostname(), self.tb_comment + self.user_comment), 636 | comment=self.tb_comment) 637 | 638 | def __stop_tb_logger(self): 639 | del self.tb_logger 640 | self.tb_logger = None 641 | 642 | def __print_paras(self, model): 643 | for name, param in model.named_parameters(): 644 | logging.warning("name:{} param:{}".format(name, param.requires_grad)) 645 | 646 | def generate_labeled_data(self): 647 | ds = DataSetGraphSimGenerator(data_set_id=self.data_set_id, dataset_version=self.dataset_version) 648 | ds.generate_dataset_pickle() 649 | del ds 650 | pass 651 | 652 | def __new_model_obj(self): 653 | return GraphSimilarity_No_KB(input_dim=self.input_dim, 654 | gcn_hidden_dim=self.gcn_hidden_dim, 655 | linear_hidden_dim=self.linear_hidden_dim, 656 | out_dim=len(self.class_names), 657 | pool_step=self.pool_step, 658 | num_bases=self.num_bases, 659 | dropout=self.dropout, 660 | support=self.support, 661 | max_node_num=self.max_node_num) 
662 | 663 | def __print_data_info(self): 664 | train_data = CustomDatasetNoKB(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode="train", 665 | repeat_pos_data=self.repeat_pos_data, resplit=False) 666 | test_data = CustomDatasetNoKB(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode="test", 667 | repeat_pos_data=self.repeat_pos_data, resplit=False) 668 | val_data = CustomDatasetNoKB(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode="val", 669 | repeat_pos_data=self.repeat_pos_data, resplit=False) 670 | train_data.print_data_set_info() 671 | test_data.print_data_set_info() 672 | val_data.print_data_set_info() 673 | for datas in [train_data, test_data, val_data]: 674 | for index, data in enumerate(datas): 675 | adj_1 = np.array(data["graph_online_adj"].cpu())[0] 676 | f_1 = np.array(data["graph_online_feature"].cpu()) 677 | adj_2 = np.array(data["graph_kb_adj"].cpu())[0] 678 | f_2 = np.array(data["graph_kb_feature"].cpu()) 679 | self.tb_logger.writer.add_histogram("graph_online/adj", adj_1, index) 680 | self.tb_logger.writer.add_histogram("graph_online/feature", f_1, index) 681 | self.tb_logger.writer.add_histogram("graph_kb/adj", adj_2, index) 682 | self.tb_logger.writer.add_histogram("graph_kb/feature", f_2, index) 683 | 684 | def crossentropy_loss(self, output, label, num_list): 685 | """ num_list 表示 从 0,1,2,3每种类别的数目 本处只有两个类别[不相似,相似]""" 686 | # 方式1 直接翻转后除以总数 687 | num_list.reverse() 688 | weight_ = torch.as_tensor(num_list, dtype=torch.float32, device=device) 689 | weight_ = weight_ / torch.sum(weight_) 690 | # 方式2中值平均 691 | # weight_ = torch.as_tensor(num_list, dtype=torch.float32, device=device) 692 | # weight_ = torch.mean(weight_) * torch.rsqrt(weight_) 693 | self.cross_weight_auto = np.array(weight_.cpu()) 694 | # return self.criterion(output, label, weight=torch.as_tensor([0.4,0.6], dtype=torch.float32, device=device)) 695 | return self.criterion(output, label) 696 | 697 | def train_model(self): 698 | """训练并记录参数和模型""" 699 | start_time_train = str(datetime.now().strftime("%Y%m%d-%H%M%S")) 700 | self.__start_tb_logger(time_str= start_time_train) 701 | # 模型 702 | self.model = self.__new_model_obj() 703 | self.__print_paras(self.model) 704 | self.model = self.model.to(device) 705 | 706 | # 交叉熵 707 | criterion = self.crossentropy_loss 708 | optimizer = RAdam(self.model.parameters(), 709 | lr=self.lr, 710 | weight_decay=self.weight_decay) 711 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=0.1, 712 | patience=8, threshold=1e-4, threshold_mode="rel", 713 | cooldown=0, min_lr=0, eps=1e-8) 714 | 715 | train_data = CustomDatasetNoKB(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode="train", 716 | repeat_pos_data=self.repeat_pos_data, resplit=self.resplit) 717 | self.__print_data_info() 718 | pos_train_num, neg_train_num = train_data.pos_neg_num() 719 | 720 | # 训练 721 | for epoch in range(self.epoch): 722 | if self.resplit_each_time: 723 | train_data = CustomDatasetNoKB(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode="train", 724 | repeat_pos_data=self.repeat_pos_data, resplit=self.resplit) 725 | train_loader = DataLoader(dataset=train_data, batch_size=self.batch_size, shuffle=True) 726 | loss_all = 0 727 | accuary_all_num = 0 728 | preds_all_num = 0 729 | FN, FP, TN, TP = 0, 
0, 0, 0 730 | batch_num = 0 731 | outputs_all = list() 732 | for batch in train_loader: 733 | graphs_online = (batch["graph_online_feature"], batch["graph_online_adj"]) 734 | graphs_offline = (batch["graph_kb_feature"], batch["graph_kb_adj"]) 735 | labels = batch["label"] 736 | outputs = self.model(graphs_online, graphs_offline) 737 | outputs_all.append(outputs) 738 | loss = criterion(outputs, labels, num_list=[neg_train_num, pos_train_num]) 739 | 740 | preds = torch.argmax(outputs, dim=1) 741 | accuary_all_num += torch.sum(preds == labels) 742 | preds_all_num += torch.as_tensor(labels.shape[0]) 743 | FN += int(torch.sum(preds[labels == 1] == 0)) 744 | FP += int(torch.sum(preds[labels == 0] == 1)) 745 | TN += int(torch.sum(preds[labels == 0] == 0)) 746 | TP += int(torch.sum(preds[labels == 1] == 1)) 747 | 748 | batch_num += 1 749 | loss_all += loss 750 | optimizer.zero_grad() 751 | loss.backward() 752 | optimizer.step() 753 | scheduler.step(loss_all) 754 | sample_data = None 755 | if epoch == 0: 756 | batch = next(iter(train_loader)) 757 | graphs_online = (batch["graph_online_feature"], batch["graph_online_adj"]) 758 | graphs_offline = (batch["graph_kb_feature"], batch["graph_kb_adj"]) 759 | sample_data = (graphs_online, graphs_offline) 760 | 761 | 762 | recall_train = TP / (TP + FN) if (TP + FN) else 0 763 | precision_train = TP / (TP + FP) if (TP + FP) else 0 764 | F1_train = ((2 * precision_train * recall_train) / (precision_train + recall_train)) if (precision_train and recall_train) else 0 765 | accuary_val, loss_all_val, base_ac_val, precision_val, recall_val, F1_val = self.test_val_model(mode="val") 766 | accuary_test, loss_all_test, base_ac_test, precision_test, recall_test, F1_test = self.test_val_model(mode="test") 767 | class_ac_train = self.judge_graph_class_ac(mode="train") 768 | class_ac_val = self.judge_graph_class_ac(mode="val") 769 | class_ac_test = self.judge_graph_class_ac(mode="test") 770 | accuary_train = accuary_all_num.item() / preds_all_num.item() 771 | info_dict = dict( 772 | sample_data=sample_data, step=epoch, loss=loss_all.item(), 773 | loss_val=loss_all_val.item(), 774 | loss_test=loss_all_test.item(), 775 | accuracy=accuary_train, 776 | accuary_val=accuary_val, 777 | accuary_test=accuary_test, 778 | outputs_all=torch.cat(outputs_all, dim=0), 779 | train_pos_neg=np.array([pos_train_num/len(train_data), neg_train_num/len(train_data)]), 780 | val_pos_neg=np.array(base_ac_val), 781 | test_pos_neg=np.array(base_ac_test), 782 | cross_weight_auto=self.cross_weight_auto, 783 | class_ac_train=class_ac_train, 784 | class_ac_val=class_ac_val, 785 | class_ac_test=class_ac_test, 786 | recall_train=recall_train, 787 | recall_val=recall_val, 788 | recall_test=recall_test, 789 | precision_train=precision_train, 790 | precision_val=precision_val, 791 | precision_test=precision_test, 792 | F1_train=F1_train, 793 | F1_val=F1_val, 794 | F1_test=F1_test 795 | ) 796 | self.tb_logger.print_tensoroard_logs(model=self.model, info_dict=info_dict) 797 | 798 | logging.error("epoch:{} loss:{} accuracy:{}/{}={}".format(epoch, loss_all, accuary_all_num, preds_all_num, 799 | int(accuary_all_num) / int(preds_all_num))) 800 | if accuary_train >= 0.79 or (epoch >= 100 and epoch % 10 == 0 and accuary_train >= 0.7): 801 | self.save_model(time_str=start_time_train) 802 | 803 | self.test_val_model(mode="test") 804 | # 保存模型和超参数 805 | self.save_model(time_str=start_time_train) 806 | self.__stop_tb_logger() 807 | 808 | @torch.no_grad() 809 | def test_val_model(self, mode): 810 | test_data = 
CustomDatasetNoKB(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode=mode,
811 |                                      repeat_pos_data=self.repeat_pos_data, resplit=False)
812 |         # test_data.print_data_set_info()
813 |         test_loader = DataLoader(dataset=test_data, batch_size=self.batch_size, shuffle=True)
814 |         accuary_all_num = 0
815 |         preds_all_num = 0
816 |         FN, FP, TN, TP = 0, 0, 0, 0
817 |         loss_all = 0.0
818 |         pred_class_t3 = list()
819 |         pred_class_t2 = list()
820 |         pred_class_t5 = list()
821 |         pred_class = list()
822 |         label_class = list()
823 |         infer_time_ms = 0  # forward-pass latency in milliseconds
824 |         pos_num, neg_num = test_data.pos_neg_num()
825 |         for batch in test_loader:
826 |             graphs_online = (batch["graph_online_feature"], batch["graph_online_adj"])
827 |             graphs_offline = (batch["graph_kb_feature"], batch["graph_kb_adj"])
828 |             labels = batch["label"]
829 |             start_time = time.perf_counter()
830 |             outputs = self.model(graphs_online, graphs_offline)
831 |             end_time = time.perf_counter()
832 |             # logging.error("out{} label{} negnum{}".format(type(outputs), type(labels), type(neg_num)))
833 |             loss = self.criterion(outputs, labels)
834 |             loss_all += loss
835 |             outputs_sm = torch.nn.functional.softmax(outputs, dim=1)
836 |             # record the index of the most similar KB graph, i.e. the top-1 prediction
837 |             outputs_max_index = torch.argmax(outputs_sm.narrow(1, 1, 1))
838 |             # if save_error_excel and (outputs_max_index.item() != (torch.argmax(labels)).item()):
839 |             #     preds_index = outputs_max_index.item()
840 |             #     preds_name = online_data_info[3][preds_index]
841 |             #     one_error = list(online_data_info[0:3]) + [preds_index, preds_name]
842 |             #     labeled_error_data.append(one_error)
843 |             # record top-3
844 |             outputs_sq = torch.squeeze(outputs_sm.narrow(1, 1, 1))
845 |             top_3 = torch.topk(outputs_sq, k=3)[1]
846 |             if torch.argmax(labels) in top_3:
847 |                 pred_class_t3.append(torch.argmax(labels))
848 |             else:
849 |                 pred_class_t3.append(top_3[0])
850 | 
851 |             # record top-2
852 |             top_2 = torch.topk(outputs_sq, k=2)[1]
853 |             if torch.argmax(labels) in top_2:
854 |                 pred_class_t2.append(torch.argmax(labels))
855 |             else:
856 |                 pred_class_t2.append(top_2[0])
857 |             # record top-5
858 |             top_5 = torch.topk(outputs_sq, k=5)[1]
859 |             if torch.argmax(labels) in top_5:
860 |                 pred_class_t5.append(torch.argmax(labels))
861 |             else:
862 |                 pred_class_t5.append(top_5[0])
863 |             pred_class.append(outputs_max_index)
864 |             label_class.append(torch.argmax(labels))
865 |             preds = torch.argmax(outputs, dim=1)
866 |             accuary_all_num += torch.sum(preds == labels)
867 |             preds_all_num += torch.as_tensor(labels.shape[0])
868 |             FN += int(torch.sum(preds[labels==1]==0))
869 |             FP += int(torch.sum(preds[labels==0]==1))
870 |             TN += int(torch.sum(preds[labels==0]==0))
871 |             TP += int(torch.sum(preds[labels==1]==1))
872 |         recall = TP / (TP + FN) if (TP + FN) else 0
873 |         precision = TP / (TP + FP) if (TP + FP) else 0
874 |         infer_time_ms = (end_time - start_time) * 1000  # latency of the last batch only
875 |         F1 = ((2 * precision * recall) / (precision + recall)) if (precision and recall) else 0
876 |         pred_class_s = torch.stack(pred_class)
877 |         label_class_s = torch.stack(label_class)
878 |         pred_class_t3_s = torch.stack(pred_class_t3)
879 |         pred_class_t2_s = torch.stack(pred_class_t2)
880 |         pred_class_t5_s = torch.stack(pred_class_t5)
881 |         MAR_res = MAR(label_class_s, pred_class_s)
882 |         top1_a = torch.sum(pred_class_s == label_class_s).item() / pred_class_s.size()[0]
883 |         top3_a = torch.sum(pred_class_t3_s == label_class_s).item() / pred_class_t3_s.size()[0]
884 |         top2_a = torch.sum(pred_class_t2_s == label_class_s).item() / pred_class_t2_s.size()[0]
885 |         top5_a = torch.sum(pred_class_t5_s == label_class_s).item() / pred_class_t5_s.size()[0]
886 |         logging.error("{}_data : accuracy1:{}/{}={} accuracy2:{} accuracy3:{} accuracy5:{} MAR: {} precision:{}/{}={} recall:{}/{}={} F1:{} time: {}".format(
887 |             mode, accuary_all_num, preds_all_num, int(accuary_all_num) / int(preds_all_num), top2_a, top3_a, top5_a, MAR_res,
888 |             TP, (TP + FP), precision,
889 |             TP, (TP + FN), recall,
890 |             F1, infer_time_ms
891 |         ))
892 |         pos, neg = test_data.pos_neg_num()
893 |         base_ac = [pos/(pos+neg), neg/(pos+neg)]
894 |         return int(accuary_all_num) / int(preds_all_num), loss_all, base_ac, precision, recall, F1
895 |         pass
896 | 
897 |     @torch.no_grad()
898 |     def judge_graph_class_ac(self, mode, save_error_excel=False, all_test=False):
899 |         assert self.model
900 |         data = CustomDatasetNoKB(data_set_id=self.data_set_id, dataset_version=self.dataset_version, max_node_num=self.max_node_num, mode=mode,
901 |                                  repeat_pos_data=self.repeat_pos_data, resplit=False)
902 |         pred_class = list()
903 |         pred_class_t3 = list()
904 |         pred_class_t2 = list()
905 |         label_class = list()
906 | 
907 |         labeled_error_data = list()
908 |         for sample, online_data_info in data.graph_class_data():
909 |             # online_data_info is (online_data_path, e_name, e_index, error_name_list)
910 |             graphs_online = (sample["graph_online_feature"], sample["graph_online_adj"])
911 |             graphs_offline = (sample["graph_kb_feature"], sample["graph_kb_adj"])
912 |             labels = sample["label"]
913 |             outputs = self.model(graphs_online, graphs_offline)
914 |             outputs_sm = torch.nn.functional.softmax(outputs, dim=1)
915 |             outputs_sq = np.squeeze(outputs_sm)
916 |             label_err_index = sample["label"][0]
917 | 
918 |             # record the most similar class index, i.e. the top-1 prediction
919 |             outputs_max_index = torch.argmax(outputs_sq)
920 |             if save_error_excel and (outputs_max_index.item() != label_err_index):
921 |                 preds_index = outputs_max_index.item()
922 |                 preds_name = online_data_info[3][preds_index]
923 |                 one_error = list(online_data_info[0:3]) + [preds_index, preds_name]
924 |                 labeled_error_data.append(one_error)
925 |             # record top-3
926 |             top_3 = torch.topk(outputs_sq, k=3)[1]
927 |             if label_err_index in top_3:
928 |                 pred_class_t3.append(label_err_index)
929 |             else:
930 |                 pred_class_t3.append(top_3[0])
931 | 
932 |             # record top-2
933 |             top_2 = torch.topk(outputs_sq, k=2)[1]
934 |             if label_err_index in top_2:
935 |                 pred_class_t2.append(label_err_index)
936 |             else:
937 |                 pred_class_t2.append(top_2[0])
938 | 
939 | 
940 |             pred_class.append(outputs_max_index)
941 |             label_class.append(label_err_index)
942 |         pred_class_s = torch.stack(pred_class)
943 |         label_class_s = torch.stack(label_class)
944 |         pred_class_t3_s = torch.stack(pred_class_t3)
945 |         pred_class_t2_s = torch.stack(pred_class_t2)
946 | 
947 |         top1_a = torch.sum(pred_class_s == label_class_s).item() / pred_class_s.size()[0]
948 |         top3_a = torch.sum(pred_class_t3_s == label_class_s).item() / pred_class_t3_s.size()[0]
949 |         top2_a = torch.sum(pred_class_t2_s == label_class_s).item() / pred_class_t2_s.size()[0]
950 |         if save_error_excel:
951 |             excel_name = os.path.join(os.path.dirname(__file__), "{}_error_label_{}.xls".format(self.data_set_name, mode))
952 |             df = pd.DataFrame(labeled_error_data, columns=["online_data_path", "e_name", "e_index", "preds_index", "preds_name"])
953 |             df.to_excel(excel_name, index=False)
954 |         if all_test:
955 |             return df
956 |         return top1_a, top3_a, top2_a
957 | 
958 | 
959 | 
960 | 
961 |     def save_model(self, time_str):
962 |         dir_name = "{}_{}".format(time_str, socket.gethostname()+self.user_comment)
963 |         save_path_dir = os.path.join(os.path.dirname(__file__), "..", "data", "graph_sim_model_parameters", self.data_set_name,
964 |                                      dir_name)
965 |         os.makedirs(save_path_dir, exist_ok=True)
966 |         self.model_saved_path = os.path.join(save_path_dir, "model.pth")
967 |         self.model_saved_dir = save_path_dir
968 |         torch.save(self.model.state_dict(), self.model_saved_path)
969 |         shutil.copy(self.config_file_path, os.path.join(save_path_dir, "config_graph_sim.ini"))
970 | 
971 |     def load_model(self, model_saved_dir):
972 |         "https://blog.csdn.net/dss_dssssd/article/details/89409183"
973 |         if model_saved_dir:
974 |             self.model_saved_dir = model_saved_dir
975 |             self.model_saved_path = os.path.join(model_saved_dir, "model.pth")
976 |             self.config_file_path = os.path.join(model_saved_dir, "config_graph_sim.ini")
977 |             self.config.read(self.config_file_path, encoding='utf-8')
978 |             self.__load_super_paras()
979 | 
980 |         self.model = self.__new_model_obj()
981 |         self.model = self.model.to(device)
982 |         self.model.load_state_dict(torch.load(self.model_saved_path, map_location=device))
983 |         self.model.eval()
984 |         pass
985 | 
986 | def get_error_summary(did=1):
987 |     model_paths = [
988 |         "20200423-095359_amaxD2_n100step100dataset70attetionjumpReduceLROnPlateauLRRAdmweight46",
989 |         "20200423-092330_amaxD2_n100step100dataset70attetionjumpReduceLROnPlateauLRRAdmweight46",
990 |         "20200423-084810_amaxD2_n100step100dataset70attetionjumpReduceLROnPlateauLRRAdm",
991 |         "20200423-081103_amaxD2_n100step100dataset70attetionjumpCosineAnnealingLRRAdm",
992 |         "20200423-132639_amaxD2_n100step100dataset70allattetionjumpReduceLROnPlateauLRRAdm"
993 |     ]
994 |     minf = ModelInferenceNoKb()
995 |     dataset_name = "train_ticket" if did == 1 else "sock_shop"
996 |     all_dir = os.path.join(os.path.dirname(__file__), "..", "data", "graph_sim_model_parameters", dataset_name)
997 |     error_info = list()
998 |     for root, dirs, files in os.walk(all_dir):
999 |         # if all_dir.find(root) == -1 :
1000 |         #     break
1001 |         for dir in dirs:
1002 |             if dir not in model_paths:
1003 |                 continue
1004 |             model_dir = os.path.join(all_dir, dir)
1005 |             minf.load_model(model_saved_dir=model_dir)
1006 |             df_train = minf.judge_graph_class_ac(mode="train", save_error_excel=True, all_test=True)
1007 |             df_test = minf.judge_graph_class_ac(mode="test", save_error_excel=True, all_test=True)
1008 |             df_val = minf.judge_graph_class_ac(mode="val", save_error_excel=True, all_test=True)
1009 |             df_all = pd.concat([df_train, df_test, df_val])
1010 |             online_data_path = list(set((df_all["online_data_path"])))
1011 |             online_data_path.sort()
1012 |             e_name = list(set(df_all["e_name"]))
1013 |             e_name.sort()
1014 |             error_info.append(dict(
1015 |                 model_name=dir,
1016 |                 online_names=online_data_path,
1017 |                 error_names=e_name
1018 |             ))
1019 |     online_name_sets = [set(info["online_names"]) for info in error_info]
1020 |     error_name_sets = [set(info["error_names"]) for info in error_info]
1021 |     online_final = online_name_sets[0]
1022 |     error_final = error_name_sets[0]
1023 |     for _ in range(1, len(online_name_sets)):
1024 |         online_final = online_final.intersection(online_name_sets[_])  # accumulate the intersection across models
1025 |     for _ in range(1, len(error_name_sets)):
1026 |         error_final = error_final.intersection(error_name_sets[_])
1027 |     error_info.insert(0, dict(
1028 |         online_names_jiaoset=sorted(list(online_final)),
1029 |         e_name_jiaoset=sorted(list(error_final))
1030 |     ))
1031 | 
1032 |     save_json_data(os.path.join(os.path.dirname(__file__), "{}_error_info.json".format(dataset_name)), error_info)
1033 | 
1034 | 
1035 | def save_json_data(save_path, pre_save_data):
1036 |     with 
open(save_path, 'w', encoding='utf-8') as file_writer: 1037 | raw_data = json.dumps(pre_save_data, indent=4) 1038 | file_writer.write(raw_data) 1039 | 1040 | 1041 | 1042 | if __name__ == '__main__': 1043 | minf = ModelInferenceNoKb() 1044 | minf.train_model() 1045 | --------------------------------------------------------------------------------
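A minimal usage sketch for evaluating a saved no-KB checkpoint (assumptions: the class above lives in graph_sim_no_kb_dej_X.py, and <model_dir> stands for a directory previously written by save_model()):

    from graph_sim_no_kb_dej_X import ModelInferenceNoKb

    minf = ModelInferenceNoKb()
    # restores config_graph_sim.ini and model.pth from the checkpoint directory and sets eval mode
    minf.load_model(model_saved_dir="<model_dir>")
    # top-k accuracy over fault types on the held-out split
    top1, top3, top2 = minf.judge_graph_class_ac(mode="test")
    print("top1:{} top2:{} top3:{}".format(top1, top2, top3))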