├── Training_code
│   ├── readme.md
│   ├── modules
│   │   ├── readme.md
│   │   └── model.py
│   ├── data_loader
│   │   ├── readme.md
│   │   ├── __init__.py
│   │   ├── batch_graph.py
│   │   └── dataset.py
│   ├── mlp_readout.py
│   ├── utils.py
│   ├── models
│   │   └── FFMPeg
│   │       ├── valid_acc.txt
│   │       ├── train_acc.txt
│   │       ├── train_loss.txt
│   │       └── valid_loss.txt
│   ├── layers.py
│   ├── trainer.py
│   ├── main_fan.py
│   ├── main.py
│   ├── main_devign.py
│   └── main_reveal.py
├── LICENSE
└── README.md
/Training_code/readme.md:
--------------------------------------------------------------------------------
1 | Training code
2 | 
--------------------------------------------------------------------------------
/Training_code/modules/readme.md:
--------------------------------------------------------------------------------
1 | GNN network
2 | 
--------------------------------------------------------------------------------
/Training_code/data_loader/readme.md:
--------------------------------------------------------------------------------
1 | Data_loader
2 | 
--------------------------------------------------------------------------------
/Training_code/data_loader/__init__.py:
--------------------------------------------------------------------------------
1 | n_identifier = 'features'
2 | g_identifier = 'structure'
3 | l_identifier = 'label'
--------------------------------------------------------------------------------
/Training_code/mlp_readout.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | """
6 | MLP Layer used after graph vector representation
7 | """
8 | 
9 | 
10 | class MLPReadout(nn.Module):
11 | 
12 |     def __init__(self, input_dim, output_dim, L=2):  # L=nb_hidden_layers
13 |         super().__init__()
14 |         list_FC_layers = [nn.Linear(input_dim // 2 ** l, input_dim // 2 ** (l + 1), bias=True) for l in range(L)]
15 |         list_FC_layers.append(nn.Linear(input_dim // 2 ** L, output_dim, bias=True))
16 |         self.FC_layers = nn.ModuleList(list_FC_layers)
17 |         self.L = L
18 | 
19 |     def forward(self, x):
20 |         y = x
21 |         for l in range(self.L):
22 |             y = self.FC_layers[l](y)
23 |             y = F.relu(y)
24 | 
25 |         y = self.FC_layers[self.L](y)
26 |         return y
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 xmwenxincheng
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /Training_code/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from data_loader import n_identifier, g_identifier, l_identifier 4 | import inspect 5 | from datetime import datetime 6 | import logging 7 | 8 | def load_default_identifiers(n, g, l): 9 | if n is None: 10 | n = n_identifier 11 | if g is None: 12 | g = g_identifier 13 | if l is None: 14 | l = l_identifier 15 | return n, g, l 16 | 17 | 18 | def initialize_batch(entries, batch_size, shuffle=False): 19 | total = len(entries) 20 | #print(str(total)+'k'*35) 21 | indices = np.arange(0, total , 1) 22 | if shuffle: 23 | np.random.shuffle(indices) 24 | batch_indices = [] 25 | start = 0 26 | end = len(indices) 27 | curr = start 28 | while curr < end: 29 | c_end = curr + batch_size 30 | if c_end > end: 31 | c_end = end 32 | batch_indices.append(indices[curr:c_end]) 33 | curr = c_end 34 | return batch_indices[::-1] 35 | 36 | 37 | def tally_param(model): 38 | total = 0 39 | for param in model.parameters(): 40 | total += param.data.nelement() 41 | return total 42 | 43 | 44 | def debug(*msg, sep='\t'): 45 | caller = inspect.stack()[1] 46 | file_name = caller.filename 47 | ln = caller.lineno 48 | now = datetime.now() 49 | time = now.strftime("%m/%d/%Y - %H:%M:%S") 50 | print('[' + str(time) + '] File \"' + file_name + '\", line ' + str(ln) + ' ', end='\t') 51 | for m in msg: 52 | print(m, end=sep) 53 | print('') 54 | 55 | def set_logger(log_path): 56 | logger = logging.getLogger() 57 | logger.setLevel(logging.INFO) 58 | 59 | if not logger.handlers: 60 | file_handler = logging.FileHandler(log_path, mode="w", encoding='utf-8') 61 | file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 62 | logger.addHandler(file_handler) 63 | 64 | #logging to console 65 | stream_handler = logging.StreamHandler() 66 | stream_handler.setFormatter(logging.Formatter('%(message)s')) 67 | logger.addHandler(stream_handler) 68 | -------------------------------------------------------------------------------- /Training_code/models/FFMPeg/valid_acc.txt: -------------------------------------------------------------------------------- 1 | 55.52455357142857 2 | 58.984375 3 | 61.27232142857143 4 | 62.94642857142857 5 | 62.27678571428571 6 | 61.60714285714286 7 | 63.44866071428571 8 | 63.44866071428571 9 | 63.671875 10 | 63.72767857142857 11 | 62.05357142857143 12 | 62.22098214285714 13 | 62.109375 14 | 62.55580357142857 15 | 61.99776785714286 16 | 61.04910714285714 17 | 62.5 18 | 62.16517857142857 19 | 61.328125 20 | 62.44419642857143 21 | 62.27678571428571 22 | 62.5 23 | 62.109375 24 | 60.21205357142857 25 | 59.98883928571429 26 | 59.98883928571429 27 | 60.9375 28 | 61.60714285714286 29 | 63.11383928571429 30 | 62.44419642857143 31 | 61.49553571428571 32 | 61.99776785714286 33 | 61.83035714285714 34 | 62.109375 35 | 62.61160714285714 36 | 61.99776785714286 37 | 62.05357142857143 38 | 60.546875 39 | 61.10491071428571 40 | 62.33258928571429 41 | 61.99776785714286 42 | 62.890625 43 | 62.55580357142857 44 | 62.109375 45 | 62.77901785714286 46 | 62.83482142857143 47 | 63.11383928571429 48 | 62.77901785714286 49 | 62.94642857142857 50 | 62.55580357142857 51 | 63.00223214285714 52 | 62.83482142857143 53 | 62.61160714285714 54 | 62.55580357142857 55 | 62.61160714285714 56 | 61.60714285714286 57 | 62.44419642857143 58 | 63.22544642857143 59 | 62.94642857142857 60 | 63.22544642857143 61 | 
63.11383928571429 62 | 63.05803571428571 63 | 63.28125 64 | 62.94642857142857 65 | 63.61607142857143 66 | 63.22544642857143 67 | 63.56026785714286 68 | 63.44866071428571 69 | 63.39285714285714 70 | 63.39285714285714 71 | 63.33705357142857 72 | 63.16964285714286 73 | 63.11383928571429 74 | 62.94642857142857 75 | 62.5 76 | 63.50446428571429 77 | 62.38839285714286 78 | 62.66741071428571 79 | 62.890625 80 | 62.66741071428571 81 | 62.94642857142857 82 | 62.38839285714286 83 | 62.5 84 | 62.77901785714286 85 | 63.11383928571429 86 | 63.16964285714286 87 | 63.22544642857143 88 | 62.66741071428571 89 | 63.05803571428571 90 | 63.16964285714286 91 | 62.94642857142857 92 | 62.55580357142857 93 | 63.22544642857143 94 | 63.22544642857143 95 | 63.05803571428571 96 | 63.22544642857143 97 | 63.33705357142857 98 | 63.28125 99 | 63.28125 100 | 63.11383928571429 101 | -------------------------------------------------------------------------------- /Training_code/models/FFMPeg/train_acc.txt: -------------------------------------------------------------------------------- 1 | 56.88048245614035 2 | 60.69764254385965 3 | 63.60334429824561 4 | 67.9139254385965 5 | 71.08004385964912 6 | 74.35581140350878 7 | 77.7891995614035 8 | 81.3048245614035 9 | 84.14884868421053 10 | 86.10197368421053 11 | 85.75932017543859 12 | 78.29632675438597 13 | 85.07401315789474 14 | 88.25383771929825 15 | 83.0455043859649 16 | 79.5984100877193 17 | 89.24753289473685 18 | 86.6296600877193 19 | 84.99177631578947 20 | 89.34347587719299 21 | 89.70668859649122 22 | 92.82483552631578 23 | 93.1263706140351 24 | 83.8610197368421 25 | 79.44764254385966 26 | 91.4610745614035 27 | 87.47258771929825 28 | 83.41557017543859 29 | 90.81003289473685 30 | 88.46628289473685 31 | 85.98547149122807 32 | 92.86595394736842 33 | 95.77850877192982 34 | 96.5734649122807 35 | 95.76480263157895 36 | 93.79797149122807 37 | 94.22286184210526 38 | 78.0016447368421 39 | 83.81990131578947 40 | 85.69078947368422 41 | 96.42269736842105 42 | 96.83388157894737 43 | 97.40268640350878 44 | 96.9983552631579 45 | 97.84813596491229 46 | 97.89610745614034 47 | 97.96463815789474 48 | 97.9920504385965 49 | 98.00575657894737 50 | 98.02631578947368 51 | 98.04002192982456 52 | 98.05372807017544 53 | 98.06058114035088 54 | 98.07428728070175 55 | 98.08799342105263 56 | 98.1359649122807 57 | 98.1359649122807 58 | 98.14967105263158 59 | 98.14967105263158 60 | 98.14967105263158 61 | 98.14967105263158 62 | 98.14967105263158 63 | 98.14967105263158 64 | 98.15652412280701 65 | 98.15652412280701 66 | 98.16337719298247 67 | 98.16337719298247 68 | 98.16337719298247 69 | 98.16337719298247 70 | 98.16337719298247 71 | 98.16337719298247 72 | 98.16337719298247 73 | 98.1702302631579 74 | 98.18393640350878 75 | 98.19078947368422 76 | 98.19078947368422 77 | 98.19078947368422 78 | 98.19078947368422 79 | 98.19764254385966 80 | 98.18393640350878 81 | 98.18393640350878 82 | 98.17708333333334 83 | 98.17708333333334 84 | 98.17708333333334 85 | 98.17708333333334 86 | 98.17708333333334 87 | 98.1702302631579 88 | 98.18393640350878 89 | 98.19764254385966 90 | 98.2044956140351 91 | 98.21820175438597 92 | 98.21820175438597 93 | 98.21820175438597 94 | 98.21820175438597 95 | 98.21820175438597 96 | 98.21820175438597 97 | 98.21820175438597 98 | 98.21820175438597 99 | 98.21820175438597 100 | 98.21820175438597 101 | -------------------------------------------------------------------------------- /Training_code/models/FFMPeg/train_loss.txt: -------------------------------------------------------------------------------- 
1 | 190.6135248552289 2 | 186.83300915099028 3 | 182.07138007983826 4 | 172.8624985008909 5 | 164.89226491827714 6 | 157.26750076026246 7 | 149.3726198966043 8 | 141.07457371761924 9 | 133.98771332857902 10 | 129.13849479273745 11 | 128.3739511590255 12 | 148.0837164092482 13 | 131.5058783815618 14 | 121.28719249524568 15 | 135.37261547958641 16 | 144.84769961708471 17 | 118.34907116806298 18 | 122.76924601772376 19 | 125.94456655937329 20 | 116.31134662293552 21 | 117.58514109829017 22 | 109.3832703975209 23 | 108.36965862073396 24 | 127.18655368738007 25 | 137.58479523240473 26 | 110.56509961579975 27 | 119.39040950306675 28 | 128.824333324767 29 | 114.60997102971663 30 | 121.8375253510057 31 | 127.71105167321991 32 | 108.89676559180545 33 | 100.2262555841814 34 | 97.71139312208744 35 | 99.47922355250309 36 | 104.06525983308491 37 | 105.14555051033958 38 | 149.13153169866195 39 | 133.55626450923452 40 | 128.77901525664748 41 | 97.68219877544202 42 | 97.06841104072437 43 | 94.95469491523609 44 | 95.93935755679482 45 | 93.89449979547868 46 | 93.42618815104167 47 | 93.32474477667557 48 | 93.18217066714638 49 | 93.09159436142235 50 | 93.03336253919099 51 | 92.99658778675816 52 | 92.95493182801364 53 | 92.91759571276214 54 | 92.87924863581071 55 | 92.81979316577576 56 | 92.80134689598752 57 | 92.70530794377913 58 | 92.66217362253289 59 | 92.65147814834327 60 | 92.64965097527755 61 | 92.64442443847656 62 | 92.64160865649842 63 | 92.63467715079324 64 | 92.62392612925747 65 | 92.61750124211898 66 | 92.60694390012507 67 | 92.60015869140625 68 | 92.59997973525734 69 | 92.59792515269497 70 | 92.59708351001404 71 | 92.59574421665124 72 | 92.59381023206208 73 | 92.57034355297424 74 | 92.55378522370991 75 | 92.53779762669613 76 | 92.53705181991845 77 | 92.52469179923074 78 | 92.52220167193497 79 | 92.50880418744003 80 | 92.5361611884937 81 | 92.53332452606736 82 | 92.55173947518333 83 | 92.5468166418243 84 | 92.5445417437637 85 | 92.54287920500103 86 | 92.54234153346012 87 | 92.56542808131168 88 | 92.56256357828777 89 | 92.49583796450966 90 | 92.48505267762302 91 | 92.43970302113316 92 | 92.44785268683182 93 | 92.4447270443565 94 | 92.44311978524192 95 | 92.4420425682737 96 | 92.44147785922937 97 | 92.44108072916667 98 | 92.44074998822128 99 | 92.44042339659573 100 | 92.43951656943874 101 | -------------------------------------------------------------------------------- /Training_code/models/FFMPeg/valid_loss.txt: -------------------------------------------------------------------------------- 1 | 191.06351797921317 2 | 188.51788548060827 3 | 186.54474312918526 4 | 182.59664699009485 5 | 180.86619785853796 6 | 180.48620387486048 7 | 180.26580592564173 8 | 180.93647112165178 9 | 183.1347198486328 10 | 184.07951136997767 11 | 185.65510777064733 12 | 194.4710235595703 13 | 191.34454781668526 14 | 185.3497554234096 15 | 191.63638523646765 16 | 196.07667105538505 17 | 186.45994785853796 18 | 185.02945382254464 19 | 186.290043422154 20 | 184.40681675502233 21 | 190.01944623674666 22 | 187.95960126604353 23 | 188.57422092982702 24 | 188.3047158377511 25 | 188.6261749267578 26 | 188.81195940290178 27 | 187.13639177594865 28 | 186.82730320521765 29 | 190.9355686732701 30 | 195.1559055873326 31 | 198.83881051199776 32 | 193.62359183175224 33 | 190.25687517438615 34 | 188.29340253557478 35 | 186.8502938406808 36 | 186.2771497453962 37 | 194.34925842285156 38 | 204.2268284388951 39 | 200.4159938267299 40 | 197.99866594587053 41 | 188.7580108642578 42 | 189.52193777901786 43 | 187.31442042759485 44 | 
187.17162214006697 45 | 189.63689313616072 46 | 188.14918300083704 47 | 187.456059047154 48 | 187.86625017438615 49 | 188.08323887416296 50 | 187.93886021205358 51 | 188.19422694614954 52 | 188.41443089076452 53 | 188.5118887765067 54 | 188.56909397670202 55 | 188.77310398646765 56 | 189.7388153076172 57 | 189.01932634626115 58 | 188.84686497279577 59 | 188.4554683140346 60 | 188.62923104422433 61 | 188.62123543875558 62 | 188.69256373814173 63 | 188.72161865234375 64 | 188.89974757603235 65 | 188.8606916155134 66 | 189.32758004324776 67 | 188.8723863874163 68 | 188.90605817522322 69 | 188.8577423095703 70 | 188.92562430245536 71 | 188.922123500279 72 | 188.99642944335938 73 | 189.4260711669922 74 | 188.98912484305245 75 | 189.5516640799386 76 | 189.48152378627233 77 | 189.89412798200334 78 | 189.7627912248884 79 | 189.67321559361048 80 | 190.85924639020647 81 | 189.52085876464844 82 | 189.83466230119978 83 | 189.88701520647322 84 | 189.9333038330078 85 | 189.97372000558036 86 | 189.96190534319197 87 | 189.1787872314453 88 | 191.8114253452846 89 | 189.72842843191964 90 | 190.73050362723214 91 | 189.48109000069755 92 | 189.90345982142858 93 | 189.3923797607422 94 | 189.2856903076172 95 | 189.3286634172712 96 | 189.3504638671875 97 | 189.35826764787947 98 | 189.37970406668526 99 | 189.4085409981864 100 | 189.47732543945312 101 | -------------------------------------------------------------------------------- /Training_code/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import dgl 4 | from dgl.nn import GraphConv, AvgPooling, MaxPooling 5 | 6 | 7 | class SAGPool(torch.nn.Module): 8 | """The Self-Attention Pooling layer in paper 9 | `Self Attention Graph Pooling ` 10 | Args: 11 | in_dim (int): The dimension of node feature. 12 | ratio (float, optional): The pool ratio which determines the amount of nodes 13 | remain after pooling. (default: :obj:`0.5`) 14 | conv_op (torch.nn.Module, optional): The graph convolution layer in dgl used to 15 | compute scale for each node. (default: :obj:`dgl.nn.GraphConv`) 16 | non_linearity (Callable, optional): The non-linearity function, a pytorch function. 17 | (default: :obj:`torch.tanh`) 18 | """ 19 | 20 | def __init__(self, in_dim: int, ratio=0.5, conv_op=GraphConv, non_linearity=torch.tanh): 21 | super(SAGPool, self).__init__() 22 | self.in_dim = in_dim 23 | self.ratio = ratio 24 | self.score_layer = conv_op(in_dim, 1) 25 | self.non_linearity = non_linearity 26 | 27 | def forward(self, graph: dgl.DGLGraph, feature: torch.Tensor): 28 | score = self.score_layer(graph, feature).squeeze() 29 | perm, next_batch_num_nodes = topk(score, self.ratio, get_batch_id(graph.batch_num_nodes()), 30 | graph.batch_num_nodes()) 31 | feature = feature[perm] * self.non_linearity(score[perm]).view(-1, 1) 32 | graph = dgl.node_subgraph(graph, perm) 33 | 34 | # node_subgraph currently does not support batch-graph, 35 | # the 'batch_num_nodes' of the result subgraph is None. 36 | # So we manually set the 'batch_num_nodes' here. 37 | # Since global pooling has nothing to do with 'batch_num_edges', 38 | # we can leave it to be None or unchanged. 39 | graph.set_batch_num_nodes(next_batch_num_nodes) 40 | 41 | return graph, feature, perm 42 | 43 | 44 | class ConvPoolBlock(torch.nn.Module): 45 | """A combination of GCN layer and SAGPool layer, 46 | followed by a concatenated (mean||sum) readout operation. 
47 | """ 48 | 49 | def __init__(self, in_dim: int, out_dim: int, pool_ratio=0.8): 50 | super(ConvPoolBlock, self).__init__() 51 | self.conv = GraphConv(in_dim, out_dim) 52 | self.pool = SAGPool(out_dim, ratio=pool_ratio) 53 | self.avgpool = AvgPooling() 54 | self.maxpool = MaxPooling() 55 | 56 | def forward(self, graph, feature): 57 | out = F.relu(self.conv(graph, feature)) 58 | graph, out, _ = self.pool(graph, out) 59 | g_out = torch.cat([self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1) 60 | return graph, out, g_out -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MAGNET - Implementation 2 | ## Meta-Path Based Attentional Graph Learning Model for Vulnerability Detection 3 | 4 | ## Introduction 5 | Deep Learning (DL) methods based on the graph are widely used in code vulnerability detection. However, the existing studies focus on employing contextual value in the graph, which ignores the heterogeneous relations (i.e., the relations between different nodes and edges types). In addition, subject to a large number of nodes, current methods lack the ability to capture the long-range dependencies in the code graph (i.e., the relationship between distant nodes). These limitations may obstruct the learning 6 | of vulnerable code patterns. In this paper, we propose MAGNET, a Meta-path based Attentional Graph learning model for code vulNErability deTection. We design a multi-granularity meta-path to consider the heterogeneous relations between different node and edge types to better learn the structural information in the graph. Furthermore, we propose a multi-level attentional graph neural network called MHAGNN, which considers the heterogeneous relations and exploits the long-range dependencies between distant nodes. Comprehensive experimental results on three public benchmarks show that MAGNET achieves 6.32%, 21.50% and 25.40% improvement in F1 score compared to state-of-the-art methods. These results demonstrate that MAGNET can effectively capture structural information of graph and perform well on vulnerability detection. 7 | 8 | ## Dataset 9 | To investigate the effectiveness of MAGNET, we adopt three vulnerability datasets from these paper: 10 | * FFMPeg+Qemu [1]: https://drive.google.com/file/d/1x6hoF7G-tSYxg8AFybggypLZgMGDNHfF 11 | * Reveal [2]: https://drive.google.com/drive/folders/1KuIYgFcvWUXheDhT--cBALsfy1I4utOyF 12 | * Fan et al. [3]: 13 | * 14 | ## Requirement 15 | Our code is based on Python3 (>= 3.7). There are a few dependencies to run the code. The major libraries are listed as follows: 16 | * torch (==1.9.0) 17 | * dgl (==0.7.2) 18 | * numpy (==1.22.3) 19 | * sklearn (==0.0) 20 | * pandas (==1.4.1) 21 | * tqdm 22 | 23 | **Default settings in MAGNET**: 24 | * Training configs: 25 | * batch_size = 512 (FFMpeg+Qemu), 512 (Reveal), 256 (Fan et al.) 26 | * lr = 5e-4, epoch = 100, patience = 30 27 | * opt ='Adam', weight_decay=1.2e-6 28 | optim = Adam(model.parameters(), lr=5e-4, weight_decay=1.2e-6) 29 | 30 | ## Preprocessing 31 | We use the Reveal[2]'s Joern to generate the code structure graph [here](https://github.com/VulDetProject/ReVeal). It is worth noting that the structure of the generated diagrams differs significantly between versions of Joern due to the rapidity of the iterative versions. After Joern had generated the graph, we processed it into a meta-path graph. 
32 | 33 | ## Training 34 | The model implementation code is under the ``` Training_code\``` folder. 35 | 36 | ## References 37 | 38 | [1] Yaqin Zhou, Shangqing Liu, Jingkai Siow, Xiaoning Du, and Yang Liu. 2019. Devign: Effective vulnerability identification by learning comprehensive program semantics via graph neural networks. In Advances in Neural Information Processing Systems. 10197–10207. 39 | 40 | [2] Saikat Chakraborty, Rahul Krishna, Yangruibo Ding, and Baishakhi Ray. 2020. Deep Learning based Vulnerability Detection: Are We There Yet? arXiv preprint arXiv:2009.07235 (2020). 41 | 42 | [3] Jiahao Fan, Yi Li, Shaohua Wang, and Tien Nguyen. 2020. A C/C++ Code Vulnerability Dataset with Code Changes and CVE Summaries. In The 2020 International Conference on Mining Software Repositories (MSR). IEEE. 43 | -------------------------------------------------------------------------------- /Training_code/data_loader/batch_graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dgl import DGLGraph 3 | 4 | import numpy 5 | import dgl 6 | import dgl.function as fn 7 | from torch import nn 8 | 9 | 10 | class BatchGraph: 11 | def __init__(self): 12 | #self.graph = dgl.heterograph()#DGLGraph() 13 | self.graph = dgl.heterograph({ 14 | #Expression 15 | ('Expression', 'Expression0Expression', 'Expression'): ((),()), 16 | ('Expression', 'Expression1Expression', 'Expression'): ((),()), 17 | #('Expression', 'Expression2Expression', 'Expression'): ((),()), 18 | ('Expression', 'Expression3Expression', 'Expression'): ((),()), 19 | 20 | ('Expression', 'Expression0Function', 'Function'): ((),()), 21 | ('Expression', 'Expression1Function', 'Function'): ((),()), 22 | ('Expression', 'Expression2Function', 'Function'): ((),()), 23 | ('Expression', 'Expression3Function', 'Function'): ((),()), 24 | 25 | ('Expression', 'Expression0Statement', 'Statement'): ((),()), 26 | ('Expression', 'Expression1Statement', 'Statement'): ((),()), 27 | #('Expression', 'Expression2Statement', 'Statement'): ((),()), 28 | ('Expression', 'Expression3Statement', 'Statement'): ((),()), 29 | 30 | #Function 31 | ('Function', 'Function0Expression', 'Expression'): ((),()), 32 | ('Function', 'Function1Expression', 'Expression'): ((),()), 33 | ('Function', 'Function2Expression', 'Expression'): ((),()), 34 | ('Function', 'Function3Expression', 'Expression'): ((),()), 35 | 36 | ('Function', 'Function0Function', 'Function'): ((),()), 37 | ('Function', 'Function1Function', 'Function'): ((),()), 38 | ('Function', 'Function2Function', 'Function'): ((),()), 39 | ('Function', 'Function3Function', 'Function'): ((),()), 40 | 41 | ('Function', 'Function0Statement', 'Statement'): ((),()), 42 | ('Function', 'Function1Statement', 'Statement'): ((),()), 43 | ('Function', 'Function2Statement', 'Statement'): ((),()), 44 | ('Function', 'Function3Statement', 'Statement'): ((),()), 45 | 46 | #Statement 47 | ('Statement', 'Statement0Expression', 'Expression'): ((),()), 48 | ('Statement', 'Statement1Expression', 'Expression'): ((),()), 49 | #('Statement', 'Statement2Expression', 'Expression'): ((),()), 50 | ('Statement', 'Statement3Expression', 'Expression'): ((),()), 51 | 52 | ('Statement', 'Statement0Function', 'Function'): ((),()), 53 | ('Statement', 'Statement1Function', 'Function'): ((),()), 54 | ('Statement', 'Statement2Function', 'Function'): ((),()), 55 | ('Statement', 'Statement3Function', 'Function'): ((),()), 56 | 57 | ('Statement', 'Statement0Statement', 'Statement'): ((),()), 58 | 
#('Statement', 'Statement1Statement', 'Statement'): ((),()), 59 | ('Statement', 'Statement2Statement', 'Statement'): ((),()), 60 | ('Statement', 'Statement3Statement', 'Statement'): ((),()), 61 | }) 62 | ''' 63 | #Expression 64 | ('Expression', '0', 'Expression'): ((),()), 65 | ('Expression', '1', 'Expression'): ((),()), 66 | ('Expression', '2', 'Expression'): ((),()), 67 | ('Expression', '3', 'Expression'): ((),()), 68 | 69 | ('Expression', '0', 'Function'): ((),()), 70 | ('Expression', '1', 'Function'): ((),()), 71 | ('Expression', '2', 'Function'): ((),()), 72 | ('Expression', '3', 'Function'): ((),()), 73 | 74 | ('Expression', '0', 'Statement'): ((),()), 75 | ('Expression', '1', 'Statement'): ((),()), 76 | ('Expression', '2', 'Statement'): ((),()), 77 | ('Expression', '3', 'Statement'): ((),()), 78 | 79 | #Function 80 | ('Function', '0', 'Expression'): ((),()), 81 | ('Function', '1', 'Expression'): ((),()), 82 | ('Function', '2', 'Expression'): ((),()), 83 | ('Function', '3', 'Expression'): ((),()), 84 | 85 | ('Function', '0', 'Function'): ((),()), 86 | ('Function', '1', 'Function'): ((),()), 87 | ('Function', '2', 'Function'): ((),()), 88 | ('Function', '3', 'Function'): ((),()), 89 | 90 | ('Function', '0', 'Statement'): ((),()), 91 | ('Function', '1', 'Statement'): ((),()), 92 | ('Function', '2', 'Statement'): ((),()), 93 | ('Function', '3', 'Statement'): ((),()), 94 | 95 | #Statement 96 | ('Statement', '0', 'Expression'): ((),()), 97 | ('Statement', '1', 'Expression'): ((),()), 98 | ('Statement', '2', 'Expression'): ((),()), 99 | ('Statement', '3', 'Expression'): ((),()), 100 | 101 | ('Statement', '0', 'Function'): ((),()), 102 | ('Statement', '1', 'Function'): ((),()), 103 | ('Statement', '2', 'Function'): ((),()), 104 | ('Statement', '3', 'Function'): ((),()) 105 | 106 | ('Statement', '0', 'Statement'): ((),()), 107 | ('Statement', '1', 'Statement'): ((),()), 108 | ('Statement', '2', 'Statement'): ((),()), 109 | ('Statement', '3', 'Statement'): ((),()), 110 | 111 | 112 | }) 113 | ''' 114 | self.number_of_nodes = 0 115 | self.graphid_to_nodeids = {} 116 | self.num_of_subgraphs = 0 117 | 118 | 119 | def add_subgraph(self, _g): 120 | assert isinstance(_g, DGLGraph) 121 | num_new_nodes = _g.number_of_nodes() 122 | num_edge_type = len(_g.canonical_etypes) 123 | self.graphid_to_nodeids[self.num_of_subgraphs] = torch.LongTensor( 124 | list(range(self.number_of_nodes, self.number_of_nodes + num_new_nodes))).to(torch.device('cuda:0')) 125 | #self.graph.add_nodes(num_new_nodes, data=_g.ndata) 126 | #self.graph.add_nodes(num_new_nodes, data={'features': self.features}) 127 | #print(self.graph) 128 | #print(self.number_of_nodes) 129 | ''' 130 | print(_g) 131 | 132 | for i in range (0,num_edge_type - 1): 133 | str_etype = str(i) 134 | sources, dests = _g.all_edges(etype = str_etype) 135 | self.graph.add_edge(sources, dests, etype= str_etype) 136 | 137 | 138 | sources += self.number_of_nodes 139 | dests += self.number_of_nodes 140 | #self.graph.add_edges(sources, dests, data=_g.edata) 141 | ''' 142 | if(self.number_of_nodes == 0): 143 | self.graph = _g 144 | else: 145 | self.graph = dgl.batch([self.graph, _g]) 146 | 147 | self.number_of_nodes += num_new_nodes 148 | self.num_of_subgraphs += 1 149 | 150 | 151 | 152 | def cuda(self, device='cuda:0'): 153 | for k in self.graphid_to_nodeids.keys(): 154 | self.graphid_to_nodeids[k] = self.graphid_to_nodeids[k].cuda(device=device) 155 | 156 | 157 | 158 | def de_batchify_graphs(self, features=None): 159 | print(self.graphid_to_nodeids.keys()) 160 | ''' 
161 | assert isinstance(features, torch.Tensor) 162 | #print(features) 163 | #print(self.graphid_to_nodeids.keys()) 164 | vectors = [features.index_select(dim=0, index=self.graphid_to_nodeids[gid]) for gid in 165 | self.graphid_to_nodeids.keys()] 166 | #for i in self.graphid_to_nodeids.keys(): 167 | # vectors = features.index_select(dim=0, index=self.graphid_to_nodeids[gid]) 168 | 169 | lengths = [f.size(0) for f in vectors] 170 | max_len = max(lengths) 171 | for i, v in enumerate(vectors): 172 | #print(v.device) 173 | vectors[i] = torch.cat((v, torch.zeros(size=(max_len - v.size(0), *(v.shape[1:])), requires_grad=v.requires_grad, device=v.device)), dim=0) 174 | output_vectors = torch.stack(vectors).to(torch.device('cuda:0')) 175 | #lengths = torch.LongTensor(lengths).to(device=output_vectors.device) 176 | ''' 177 | return output_vectors#, lengths 178 | 179 | def get_network_inputs(self, cuda=False): 180 | raise NotImplementedError('Must be implemented by subclasses.') 181 | 182 | from scipy import sparse as sp 183 | 184 | 185 | 186 | class GGNNBatchGraph(BatchGraph): 187 | def __init__(self): 188 | super(GGNNBatchGraph, self).__init__() 189 | 190 | def get_network_inputs(self, cuda=False, device=None): 191 | #self.graph = dgl.add_self_loop(self.graph) 192 | #features = self.graph.ndata['features'] 193 | #features = self.graph.nodes.data['h'].to(torch.device('cuda:0')) 194 | #图结构信息 195 | #edge_types = self.graph.edata['etype'] 196 | if cuda: 197 | #self.cuda(device=device) 198 | return self.graph#, features#, edge_types#, _lap_pos_enc.cuda(device=device) 199 | else: 200 | return self.graph#, features, edge_types#, h_lap_pos_enc 201 | pass 202 | -------------------------------------------------------------------------------- /Training_code/trainer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | from sys import stderr 4 | 5 | import numpy as np 6 | import torch 7 | from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score 8 | from tqdm import tqdm 9 | 10 | from utils import debug 11 | 12 | 13 | def evaluate_loss(model, loss_function, num_batches, data_iter, cuda=False): 14 | model.eval() 15 | with torch.no_grad(): 16 | _loss = [] 17 | all_predictions, all_targets = [], [] 18 | for _ in range(num_batches): 19 | graph, targets = data_iter() 20 | targets = targets.cuda() 21 | predictions = model(graph, cuda=True) 22 | 23 | #print(predictions) 24 | batch_loss = loss_function(predictions, targets.long()) 25 | _loss.append(batch_loss.detach().cpu().item()) 26 | predictions = predictions.detach().cpu() 27 | if predictions.ndim == 2: 28 | all_predictions.extend(np.argmax(predictions.numpy(), axis=-1).tolist()) 29 | else: 30 | all_predictions.extend( 31 | predictions.ge(torch.ones(size=predictions.size()).fill_(0.5)).to( 32 | dtype=torch.int32).numpy().tolist() 33 | ) 34 | all_targets.extend(targets.detach().cpu().numpy().tolist()) 35 | model.train() 36 | return np.mean(_loss).item(), accuracy_score(all_targets, all_predictions) * 100 37 | pass 38 | 39 | def evaluate_metrics(model, loss_function, num_batches, data_iter): 40 | model.eval() 41 | with torch.no_grad(): 42 | _loss = [] 43 | all_predictions, all_targets = [], [] 44 | for _ in range(num_batches): 45 | graph, targets = data_iter() 46 | targets = targets.cuda() 47 | predictions = model(graph, cuda=True) 48 | 49 | #print(predictions) 50 | batch_loss = loss_function(predictions, targets.long()) 51 | 
_loss.append(batch_loss.detach().cpu().item()) 52 | predictions = predictions.detach().cpu() 53 | if predictions.ndim == 2: 54 | all_predictions.extend(np.argmax(predictions.numpy(), axis=-1).tolist()) 55 | else: 56 | all_predictions.extend( 57 | predictions.ge(torch.ones(size=predictions.size()).fill_(0.5)).to( 58 | dtype=torch.int32).numpy().tolist() 59 | ) 60 | all_targets.extend(targets.detach().cpu().numpy().tolist()) 61 | model.train() 62 | return np.mean(_loss).item(), \ 63 | accuracy_score(all_targets, all_predictions) * 100, \ 64 | precision_score(all_targets, all_predictions) * 100, \ 65 | recall_score(all_targets, all_predictions) * 100, \ 66 | f1_score(all_targets, all_predictions) * 100 67 | pass 68 | 69 | 70 | def train(model, dataset, epoches, dev_every, loss_function, optimizer, save_path, log_every=5, max_patience=5): 71 | debug('Start Training') 72 | debug(dev_every) 73 | logging.info('Start training!') 74 | train_losses = [] 75 | best_model = None 76 | patience_counter = 0 77 | best_f1 = 0 78 | log_flag = 0 79 | max_steps = epoches * dev_every 80 | all_train_acc = [] 81 | all_train_loss = [] 82 | all_valid_acc = [] 83 | all_valid_loss = [] 84 | try: 85 | for step_count in range(max_steps): 86 | #print("begin training") 87 | #print(step_count % dev_every) 88 | #if(step_count % dev_every==0): 89 | # continue 90 | #训练 91 | model.train() 92 | #模型的参数梯度设成0: 93 | model.zero_grad() 94 | graph, targets = dataset.get_next_train_batch() #first 95 | #print(graph) 96 | #print(dataset) 97 | #print(targets.size(0)) 98 | 99 | targets = targets.cuda() 100 | predictions = model(graph, cuda=True) 101 | 102 | #print(predictions) 103 | batch_loss = loss_function(predictions, targets.long()) 104 | ''' 105 | if log_every is not None and (step_count % log_every == log_every - 1): 106 | debug('Step %d\t\tTrain Loss %10.3f' % (step_count, batch_loss.detach().cpu().item())) 107 | logging.info('Step %d\t\tTrain Loss %10.3f' % (step_count, batch_loss.detach().cpu().item())) 108 | ''' 109 | #print(batch_loss.detach().cpu().item()) 110 | #train_losses.append(batch_loss.detach().cpu().item()) 111 | train_losses.append(batch_loss.detach().item()) 112 | batch_loss.backward() 113 | optimizer.step() 114 | 115 | if step_count % dev_every == (dev_every - 1): 116 | #print(step_count % dev_every) 117 | log_flag += 1 118 | debug('@@@' * 35) 119 | debug(step_count) 120 | debug(log_flag) 121 | train_loss, train_acc, train_pr, train_rc, train_f1 = evaluate_metrics(model, loss_function, dataset.initialize_train_batch(), dataset.get_next_train_batch) 122 | all_train_acc.append(train_acc) 123 | all_train_loss.append(train_loss) 124 | 125 | logging.info('-' * 100) 126 | logging.info('Epoch %d\t---Train--- Average Loss: %10.4f\t Patience %d\t Loss: %10.4f\tAccuracy: %0.4f\tPrecision: %0.4f\tRecall: %0.4f\tf1: %5.3f\t' % ( 127 | log_flag, np.mean(train_losses).item(), patience_counter, train_loss, train_acc, train_pr, train_rc, train_f1)) 128 | loss, acc, pr, rc, f1 = evaluate_metrics(model, loss_function, dataset.initialize_valid_batch(), dataset.get_next_valid_batch) 129 | logging.info('Epoch %d\t----Valid---- Loss: %0.4f\tAccuracy: %0.4f\tPrecision: %0.4f\tRecall: %0.4f\tF1: %0.4f' % (log_flag, loss, acc, pr, rc, f1)) 130 | all_valid_acc.append(acc) 131 | all_valid_loss.append(loss) 132 | if f1 > best_f1 or f1 > 28: 133 | patience_counter = 0 134 | best_f1 = f1 135 | best_model = copy.deepcopy(model.state_dict()) 136 | _save_file = open(save_path + str(log_flag) + '-model.bin', 'wb') 137 | 
torch.save(model.state_dict(), _save_file) 138 | _save_file.close() 139 | else: 140 | patience_counter += 1 141 | train_losses = [] 142 | loss, acc, pr, rc, f1 = evaluate_metrics(model, loss_function, dataset.initialize_test_batch(), dataset.get_next_test_batch) 143 | logging.info('Epoch %d\t----Tset---- Loss: %0.4f\tAccuracy: %0.4f\tPrecision: %0.4f\tRecall: %0.4f\tF1: %0.4f' % (log_flag, loss, acc, pr, rc, f1)) 144 | if patience_counter == max_patience: 145 | break 146 | except KeyboardInterrupt: 147 | debug('Training Interrupted by user!') 148 | logging.info('Training Interrupted by user!') 149 | logging.info('Finish training!') 150 | 151 | if best_model is not None: 152 | model.load_state_dict(best_model) 153 | _save_file = open(save_path + '-model.bin', 'wb') 154 | torch.save(model.state_dict(), _save_file) 155 | _save_file.close() 156 | 157 | 158 | #model.load_state_dict(torch.load('./Models/FFmpeg/'+'DevignModel187-model.bin')) 159 | #torch.no_grad() 160 | logging.info('#' * 100) 161 | logging.info("Test result") 162 | loss, acc, pr, rc, f1 = evaluate_metrics(model, loss_function, dataset.initialize_test_batch(), 163 | dataset.get_next_test_batch) 164 | debug('%s\tTest Accuracy: %0.2f\tPrecision: %0.2f\tRecall: %0.2f\tF1: %0.2f' % (save_path, acc, pr, rc, f1)) 165 | logging.info('%s\t----Test---- Loss: %0.4f\tAccuracy: %0.4f\tPrecision: %0.4f\tRecall: %0.4f\tF1: %0.4f' % (save_path, loss, acc, pr, rc, f1)) 166 | 167 | 168 | import os 169 | if not os.path.exists('models/FFmpeg/'): 170 | os.makedirs('models/FFmpeg/') 171 | with open('models/FFmpeg/train_acc.txt', 'w', encoding='utf-8') as f: 172 | for i in all_train_acc: 173 | f.writelines(str(i) + '\n') 174 | with open('models/FFmpeg/train_loss.txt', 'w', encoding='utf-8') as f: 175 | for i in all_train_loss: 176 | f.writelines(str(i) + '\n') 177 | with open('models/FFmpeg/valid_acc.txt', 'w', encoding='utf-8') as f: 178 | for i in all_valid_acc: 179 | f.writelines(str(i) + '\n') 180 | with open('models/FFmpeg/valid_loss.txt', 'w', encoding='utf-8') as f: 181 | for i in all_valid_loss: 182 | f.writelines(str(i) + '\n') 183 | 184 | -------------------------------------------------------------------------------- /Training_code/main_fan.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import pickle 5 | import sys 6 | 7 | os.chdir(sys.path[0]) 8 | 9 | import numpy as np 10 | import torch 11 | from torch.nn import BCELoss, BCEWithLogitsLoss, CrossEntropyLoss 12 | from torch.optim import Adam 13 | 14 | from data_loader.dataset import DataSet 15 | from modules.model import DevignModel, GGNNSum 16 | from trainer import train 17 | from utils import tally_param, debug, set_logger 18 | 19 | torch.backends.cudnn.enable =True 20 | torch.backends.cudnn.benchmark = True 21 | 22 | 23 | import math 24 | from torch.optim.optimizer import Optimizer, required 25 | 26 | # from tensorboardX import SummaryWriter 27 | # writer = SummaryWriter(logdir='/cps/gadam/n_cifa/') 28 | # iter_idx = 0 29 | 30 | # from ipdb import set_trace 31 | import torch.optim 32 | 33 | torch.backends.cudnn.enable =True 34 | torch.backends.cudnn.benchmark = True 35 | 36 | class RAdam(Optimizer): 37 | 38 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 39 | weight_decay=0): 40 | defaults = dict(lr=lr, betas=betas, eps=eps, 41 | weight_decay=weight_decay) 42 | 43 | super(RAdam, self).__init__(params, defaults) 44 | 45 | def __setstate__(self, state): 46 | super(RAdam, 
self).__setstate__(state) 47 | 48 | def step(self, closure=None): 49 | loss = None 50 | beta2_t = None 51 | ratio = None 52 | N_sma_max = None 53 | N_sma = None 54 | 55 | if closure is not None: 56 | loss = closure() 57 | 58 | for group in self.param_groups: 59 | 60 | for p in group['params']: 61 | if p.grad is None: 62 | continue 63 | grad = p.grad.data.float() 64 | if grad.is_sparse: 65 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 66 | 67 | p_data_fp32 = p.data.float() 68 | 69 | state = self.state[p] 70 | 71 | if len(state) == 0: 72 | state['step'] = 0 73 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 74 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 75 | else: 76 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 77 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 78 | 79 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 80 | beta1, beta2 = group['betas'] 81 | 82 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 83 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 84 | 85 | state['step'] += 1 86 | if beta2_t is None: 87 | beta2_t = beta2 ** state['step'] 88 | N_sma_max = 2 / (1 - beta2) - 1 89 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 90 | beta1_t = 1 - beta1 ** state['step'] 91 | if N_sma >= 5: 92 | ratio = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / beta1_t 93 | 94 | if group['weight_decay'] != 0: 95 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 96 | 97 | # more conservative since it's an approximated value 98 | if N_sma >= 5: 99 | step_size = group['lr'] * ratio 100 | denom = exp_avg_sq.sqrt().add_(group['eps']) 101 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 102 | else: 103 | step_size = group['lr'] / beta1_t 104 | p_data_fp32.add_(-step_size, exp_avg) 105 | 106 | p.data.copy_(p_data_fp32) 107 | 108 | return loss 109 | 110 | if __name__ == '__main__': 111 | torch.manual_seed(10) 112 | 113 | np.random.seed(10) 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument('--model_type', type=str, help='Type of the model (devign/ggnn)', 116 | choices=['devign', 'ggnn'], default='devign') 117 | parser.add_argument('--dataset', type=str, help='Name of the dataset for experiment.', default='Fan') 118 | parser.add_argument('--input_dir', type=str, help='Input Directory of the parser', default='../fan_type1') 119 | parser.add_argument('--log_dir', default='Devign.log', type=str) 120 | parser.add_argument('--node_tag', type=str, help='Name of the node feature.', default='node_features') 121 | parser.add_argument('--graph_tag', type=str, help='Name of the graph feature.', default='graph') 122 | parser.add_argument('--label_tag', type=str, help='Name of the label feature.', default='targets') 123 | 124 | parser.add_argument('--feature_size', type=int, help='Size of feature vector for each node', default=100) 125 | parser.add_argument('--graph_embed_size', type=int, help='Size of the Graph Embedding', default=100) 126 | parser.add_argument('--num_steps', type=int, help='Number of steps in GGNN', default=6) 127 | parser.add_argument('--batch_size', type=int, help='Batch Size for training', default=1) 128 | args = parser.parse_args() 129 | 130 | model_dir = os.path.join('models', args.dataset) 131 | if not os.path.exists(model_dir): 132 | os.makedirs(model_dir) 133 | #设置日志输出 134 | log_dir = os.path.join(model_dir, args.log_dir) 135 | set_logger(log_dir) 136 | 137 | 
logging.info('Check up feature_size: %d', args.feature_size) 138 | if args.feature_size > args.graph_embed_size: 139 | print('Warning!!! Graph Embed dimension should be at least equal to the feature dimension.\n' 140 | 'Setting graph embedding size to feature size', file=sys.stderr) 141 | logging.info('Warning!!! Graph Embed dimension should be at least equal to the feature dimension') 142 | args.graph_embed_size = args.feature_size 143 | 144 | input_dir = args.input_dir 145 | processed_data_path = os.path.join(input_dir, 'hgt_c1_1_fan.bin') 146 | logging.info('#' * 100) 147 | if True and os.path.exists(processed_data_path): 148 | debug('Reading already processed data from %s!' % processed_data_path) 149 | dataset = pickle.load(open(processed_data_path, 'rb')) 150 | logging.info('Reading already processed data from %s!' % processed_data_path) 151 | else: 152 | logging.info('Loading the dataset from %s' % input_dir) 153 | dataset = DataSet(train_src=os.path.join(input_dir, './fan-train-v0.json'), 154 | valid_src=os.path.join(input_dir, './fan-valid-v0.json'), 155 | test_src=os.path.join(input_dir, './fan-test-v0.json'), 156 | batch_size=args.batch_size, n_ident=args.node_tag, g_ident=args.graph_tag, 157 | l_ident=args.label_tag) 158 | file = open(processed_data_path, 'wb') 159 | pickle.dump(dataset, file) #../dataset/FFmpeg_input/ 160 | file.close() 161 | logging.info('train_dataset: %d; valid_dataset: %d; test_dataset: %d', len(dataset.train_examples), len(dataset.valid_examples), len(dataset.test_examples)) 162 | logging.info("train_batch: %d, valid_batch: %d, test_batch: %d", len(dataset.train_batches), len(dataset.valid_batches), len(dataset.test_batches)) 163 | logging.info('#' * 100) 164 | ''' 165 | assert args.feature_size == dataset.feature_size, \ 166 | 'Dataset contains different feature vector than argument feature size. ' \ 167 | 'Either change the feature vector size in argument, or provide different dataset.' 168 | ''' 169 | logging.info('Check up model_type: ' + args.model_type) 170 | if args.model_type == 'ggnn': 171 | model = GGNNSum(input_dim=dataset.feature_size, output_dim=args.graph_embed_size, 172 | num_steps=args.num_steps, max_edge_types=dataset.max_edge_type) 173 | else: 174 | #使用DevignModel 175 | 176 | ''' 177 | dataset.feature_size : 100 178 | args.graph_embed_size : 200 179 | args.num_steps : 6 180 | dataset.max_edge_type : 4 181 | ''' 182 | 183 | 184 | model = DevignModel(input_dim= 100, output_dim=100, 185 | num_steps=args.num_steps, max_edge_types=dataset.max_edge_type) 186 | 187 | debug('Total Parameters : %d' % tally_param(model)) 188 | debug('#' * 100) 189 | logging.info('Total Parameters : %d' % tally_param(model)) 190 | logging.info('#' * 100) 191 | #device = torch.device("cuda:2") 192 | #model.to(device = device) 193 | model.cuda() 194 | #loss_function = BCELoss(reduction='sum') 195 | loss_function = CrossEntropyLoss(weight=torch.from_numpy(np.array([1,16.0])).float(),reduction='sum') 196 | loss_function.cuda() 197 | #loss_function = CrossEntropyLoss(reduction='sum') 198 | LR = 5e-4 199 | optim = Adam(model.parameters(), lr=1e-3, weight_decay=1.3e-6) 200 | #optim = RAdam(model.parameters(),lr=LR,weight_decay=1e-5) 201 | #logging.info('Start to train!') 202 | #开始训练模型 203 | train(model=model, dataset=dataset, epoches=100, dev_every=len(dataset.train_batches), 204 | loss_function=loss_function, optimizer=optim, 205 | save_path=model_dir + '/DevignModel', max_patience=100, log_every=5) #models/FFmpeg/GGNNSumModel.... 
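    # A minimal sketch (not required for training) of reloading the final checkpoint
    # written by trainer.train(), which saves the best weights to '<save_path>-model.bin',
    # and re-running the test evaluation on its own. The exact checkpoint path is assumed
    # from the save_path argument above; per-epoch files follow 'DevignModel<N>-model.bin'.
    final_ckpt = model_dir + '/DevignModel-model.bin'
    if os.path.exists(final_ckpt):
        from trainer import evaluate_metrics
        model.load_state_dict(torch.load(final_ckpt))
        test_loss, test_acc, test_pr, test_rc, test_f1 = evaluate_metrics(
            model, loss_function, dataset.initialize_test_batch(), dataset.get_next_test_batch)
        logging.info('Reloaded %s ----Test---- Loss: %0.4f\tAccuracy: %0.4f\tPrecision: %0.4f\tRecall: %0.4f\tF1: %0.4f',
                     final_ckpt, test_loss, test_acc, test_pr, test_rc, test_f1)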
206 | 207 | -------------------------------------------------------------------------------- /Training_code/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import pickle 5 | import sys 6 | 7 | os.chdir(sys.path[0]) 8 | 9 | import numpy as np 10 | import torch 11 | from torch.nn import BCELoss, BCEWithLogitsLoss, CrossEntropyLoss 12 | from torch.optim import Adam 13 | 14 | from data_loader.dataset import DataSet 15 | from modules.model import DevignModel, GGNNSum 16 | from trainer import train 17 | from utils import tally_param, debug, set_logger 18 | 19 | torch.backends.cudnn.enable =True 20 | torch.backends.cudnn.benchmark = True 21 | 22 | 23 | import math 24 | from torch.optim.optimizer import Optimizer, required 25 | 26 | # from tensorboardX import SummaryWriter 27 | # writer = SummaryWriter(logdir='/cps/gadam/n_cifa/') 28 | # iter_idx = 0 29 | 30 | # from ipdb import set_trace 31 | import torch.optim 32 | 33 | torch.backends.cudnn.enable =True 34 | torch.backends.cudnn.benchmark = True 35 | 36 | class RAdam(Optimizer): 37 | 38 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 39 | weight_decay=0): 40 | defaults = dict(lr=lr, betas=betas, eps=eps, 41 | weight_decay=weight_decay) 42 | 43 | super(RAdam, self).__init__(params, defaults) 44 | 45 | def __setstate__(self, state): 46 | super(RAdam, self).__setstate__(state) 47 | 48 | def step(self, closure=None): 49 | loss = None 50 | beta2_t = None 51 | ratio = None 52 | N_sma_max = None 53 | N_sma = None 54 | 55 | if closure is not None: 56 | loss = closure() 57 | 58 | for group in self.param_groups: 59 | 60 | for p in group['params']: 61 | if p.grad is None: 62 | continue 63 | grad = p.grad.data.float() 64 | if grad.is_sparse: 65 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 66 | 67 | p_data_fp32 = p.data.float() 68 | 69 | state = self.state[p] 70 | 71 | if len(state) == 0: 72 | state['step'] = 0 73 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 74 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 75 | else: 76 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 77 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 78 | 79 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 80 | beta1, beta2 = group['betas'] 81 | 82 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 83 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 84 | 85 | state['step'] += 1 86 | if beta2_t is None: 87 | beta2_t = beta2 ** state['step'] 88 | N_sma_max = 2 / (1 - beta2) - 1 89 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 90 | beta1_t = 1 - beta1 ** state['step'] 91 | if N_sma >= 5: 92 | ratio = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / beta1_t 93 | 94 | if group['weight_decay'] != 0: 95 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 96 | 97 | # more conservative since it's an approximated value 98 | if N_sma >= 5: 99 | step_size = group['lr'] * ratio 100 | denom = exp_avg_sq.sqrt().add_(group['eps']) 101 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 102 | else: 103 | step_size = group['lr'] / beta1_t 104 | p_data_fp32.add_(-step_size, exp_avg) 105 | 106 | p.data.copy_(p_data_fp32) 107 | 108 | return loss 109 | 110 | if __name__ == '__main__': 111 | torch.manual_seed(10) 112 | 113 | np.random.seed(10) 114 | parser = argparse.ArgumentParser() 115 | 
parser.add_argument('--model_type', type=str, help='Type of the model (devign/ggnn)', 116 | choices=['devign', 'ggnn'], default='devign') 117 | parser.add_argument('--dataset', type=str, help='Name of the dataset for experiment.', default='Reveal') 118 | parser.add_argument('--input_dir', type=str, help='Input Directory of the parser', default='../reveal_type1') 119 | parser.add_argument('--log_dir', default='Devign.log', type=str) 120 | parser.add_argument('--node_tag', type=str, help='Name of the node feature.', default='node_features') 121 | parser.add_argument('--graph_tag', type=str, help='Name of the graph feature.', default='graph') 122 | parser.add_argument('--label_tag', type=str, help='Name of the label feature.', default='targets') 123 | 124 | parser.add_argument('--feature_size', type=int, help='Size of feature vector for each node', default=100) 125 | parser.add_argument('--graph_embed_size', type=int, help='Size of the Graph Embedding', default=100) 126 | parser.add_argument('--num_steps', type=int, help='Number of steps in GGNN', default=6) 127 | parser.add_argument('--batch_size', type=int, help='Batch Size for training', default=256) 128 | args = parser.parse_args() 129 | 130 | model_dir = os.path.join('models', args.dataset) 131 | if not os.path.exists(model_dir): 132 | os.makedirs(model_dir) 133 | #设置日志输出 134 | log_dir = os.path.join(model_dir, args.log_dir) 135 | set_logger(log_dir) 136 | 137 | logging.info('Check up feature_size: %d', args.feature_size) 138 | if args.feature_size > args.graph_embed_size: 139 | print('Warning!!! Graph Embed dimension should be at least equal to the feature dimension.\n' 140 | 'Setting graph embedding size to feature size', file=sys.stderr) 141 | logging.info('Warning!!! Graph Embed dimension should be at least equal to the feature dimension') 142 | args.graph_embed_size = args.feature_size 143 | 144 | input_dir = args.input_dir 145 | processed_data_path = os.path.join(input_dir, 'hgt_c4_256_reveal.bin') 146 | logging.info('#' * 100) 147 | if True and os.path.exists(processed_data_path): 148 | debug('Reading already processed data from %s!' % processed_data_path) 149 | dataset = pickle.load(open(processed_data_path, 'rb')) 150 | logging.info('Reading already processed data from %s!' % processed_data_path) 151 | else: 152 | logging.info('Loading the dataset from %s' % input_dir) 153 | dataset = DataSet(train_src=os.path.join(input_dir, './reveal-train-v0.json'), 154 | valid_src=os.path.join(input_dir, './reveal-valid-v0.json'), 155 | test_src=os.path.join(input_dir, './reveal-test-v0.json'), 156 | batch_size=args.batch_size, n_ident=args.node_tag, g_ident=args.graph_tag, 157 | l_ident=args.label_tag) 158 | file = open(processed_data_path, 'wb') 159 | pickle.dump(dataset, file) #../dataset/FFmpeg_input/ 160 | file.close() 161 | logging.info('train_dataset: %d; valid_dataset: %d; test_dataset: %d', len(dataset.train_examples), len(dataset.valid_examples), len(dataset.test_examples)) 162 | logging.info("train_batch: %d, valid_batch: %d, test_batch: %d", len(dataset.train_batches), len(dataset.valid_batches), len(dataset.test_batches)) 163 | logging.info('#' * 100) 164 | ''' 165 | assert args.feature_size == dataset.feature_size, \ 166 | 'Dataset contains different feature vector than argument feature size. ' \ 167 | 'Either change the feature vector size in argument, or provide different dataset.' 
168 | ''' 169 | logging.info('Check up model_type: ' + args.model_type) 170 | if args.model_type == 'ggnn': 171 | model = GGNNSum(input_dim=dataset.feature_size, output_dim=args.graph_embed_size, 172 | num_steps=args.num_steps, max_edge_types=dataset.max_edge_type) 173 | else: 174 | #使用DevignModel 175 | 176 | ''' 177 | dataset.feature_size : 100 178 | args.graph_embed_size : 200 179 | args.num_steps : 6 180 | dataset.max_edge_type : 4 181 | ''' 182 | 183 | 184 | model = DevignModel(input_dim= 100, output_dim=100, 185 | num_steps=args.num_steps, max_edge_types=dataset.max_edge_type) 186 | 187 | debug('Total Parameters : %d' % tally_param(model)) 188 | debug('#' * 100) 189 | logging.info('Total Parameters : %d' % tally_param(model)) 190 | logging.info('#' * 100) 191 | #device = torch.device("cuda:2") 192 | #model.to(device = device) 193 | model.cuda() 194 | #loss_function = BCELoss(reduction='sum') 195 | loss_function = CrossEntropyLoss(weight=torch.from_numpy(np.array([1,9.9])).float(),reduction='sum') 196 | loss_function.cuda() 197 | #loss_function = CrossEntropyLoss(reduction='sum') 198 | LR = 5e-4 199 | optim = Adam(model.parameters(), lr=5e-4, weight_decay=1.3e-6) 200 | #optim = RAdam(model.parameters(),lr=LR,weight_decay=1e-5) 201 | #logging.info('Start to train!') 202 | #开始训练模型 203 | train(model=model, dataset=dataset, epoches=100, dev_every=len(dataset.train_batches), 204 | loss_function=loss_function, optimizer=optim, 205 | save_path=model_dir + '/DevignModel', max_patience=100, log_every=5) #models/FFmpeg/GGNNSumModel.... 206 | 207 | -------------------------------------------------------------------------------- /Training_code/main_devign.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import pickle 5 | import sys 6 | 7 | os.chdir(sys.path[0]) 8 | 9 | import numpy as np 10 | import torch 11 | from torch.nn import BCELoss, BCEWithLogitsLoss, CrossEntropyLoss 12 | from torch.optim import Adam 13 | 14 | from data_loader.dataset import DataSet 15 | from modules.model import DevignModel, GGNNSum 16 | from trainer import train 17 | from utils import tally_param, debug, set_logger 18 | 19 | torch.backends.cudnn.enable =True 20 | torch.backends.cudnn.benchmark = True 21 | 22 | 23 | import math 24 | from torch.optim.optimizer import Optimizer, required 25 | 26 | # from tensorboardX import SummaryWriter 27 | # writer = SummaryWriter(logdir='/cps/gadam/n_cifa/') 28 | # iter_idx = 0 29 | 30 | # from ipdb import set_trace 31 | import torch.optim 32 | 33 | torch.backends.cudnn.enable =True 34 | torch.backends.cudnn.benchmark = True 35 | 36 | class RAdam(Optimizer): 37 | 38 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 39 | weight_decay=0): 40 | defaults = dict(lr=lr, betas=betas, eps=eps, 41 | weight_decay=weight_decay) 42 | 43 | super(RAdam, self).__init__(params, defaults) 44 | 45 | def __setstate__(self, state): 46 | super(RAdam, self).__setstate__(state) 47 | 48 | def step(self, closure=None): 49 | loss = None 50 | beta2_t = None 51 | ratio = None 52 | N_sma_max = None 53 | N_sma = None 54 | 55 | if closure is not None: 56 | loss = closure() 57 | 58 | for group in self.param_groups: 59 | 60 | for p in group['params']: 61 | if p.grad is None: 62 | continue 63 | grad = p.grad.data.float() 64 | if grad.is_sparse: 65 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 66 | 67 | p_data_fp32 = p.data.float() 68 | 69 | state = 
self.state[p] 70 | 71 | if len(state) == 0: 72 | state['step'] = 0 73 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 74 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 75 | else: 76 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 77 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 78 | 79 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 80 | beta1, beta2 = group['betas'] 81 | 82 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 83 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 84 | 85 | state['step'] += 1 86 | if beta2_t is None: 87 | beta2_t = beta2 ** state['step'] 88 | N_sma_max = 2 / (1 - beta2) - 1 89 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 90 | beta1_t = 1 - beta1 ** state['step'] 91 | if N_sma >= 5: 92 | ratio = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / beta1_t 93 | 94 | if group['weight_decay'] != 0: 95 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 96 | 97 | # more conservative since it's an approximated value 98 | if N_sma >= 5: 99 | step_size = group['lr'] * ratio 100 | denom = exp_avg_sq.sqrt().add_(group['eps']) 101 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 102 | else: 103 | step_size = group['lr'] / beta1_t 104 | p_data_fp32.add_(-step_size, exp_avg) 105 | 106 | p.data.copy_(p_data_fp32) 107 | 108 | return loss 109 | 110 | if __name__ == '__main__': 111 | torch.manual_seed(10) 112 | 113 | np.random.seed(10) 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument('--model_type', type=str, help='Type of the model (devign/ggnn)', 116 | choices=['devign', 'ggnn'], default='devign') 117 | parser.add_argument('--dataset', type=str, help='Name of the dataset for experiment.', default='Devign') 118 | parser.add_argument('--input_dir', type=str, help='Input Directory of the parser', default='../devign_type1') 119 | parser.add_argument('--log_dir', default='Devign.log', type=str) 120 | parser.add_argument('--node_tag', type=str, help='Name of the node feature.', default='node_features') 121 | parser.add_argument('--graph_tag', type=str, help='Name of the graph feature.', default='graph') 122 | parser.add_argument('--label_tag', type=str, help='Name of the label feature.', default='targets') 123 | 124 | parser.add_argument('--feature_size', type=int, help='Size of feature vector for each node', default=100) 125 | parser.add_argument('--graph_embed_size', type=int, help='Size of the Graph Embedding', default=100) 126 | parser.add_argument('--num_steps', type=int, help='Number of steps in GGNN', default=6) 127 | parser.add_argument('--batch_size', type=int, help='Batch Size for training', default=512) 128 | args = parser.parse_args() 129 | 130 | model_dir = os.path.join('models', args.dataset) 131 | if not os.path.exists(model_dir): 132 | os.makedirs(model_dir) 133 | #设置日志输出 134 | log_dir = os.path.join(model_dir, args.log_dir) 135 | set_logger(log_dir) 136 | 137 | logging.info('Check up feature_size: %d', args.feature_size) 138 | if args.feature_size > args.graph_embed_size: 139 | print('Warning!!! Graph Embed dimension should be at least equal to the feature dimension.\n' 140 | 'Setting graph embedding size to feature size', file=sys.stderr) 141 | logging.info('Warning!!! 
Graph Embed dimension should be at least equal to the feature dimension')
142 |         args.graph_embed_size = args.feature_size
143 | 
144 |     input_dir = args.input_dir
145 |     processed_data_path = os.path.join(input_dir, 'hgt_c1_512.bin')
146 |     logging.info('#' * 100)
147 |     if True and os.path.exists(processed_data_path):
148 |         debug('Reading already processed data from %s!' % processed_data_path)
149 |         dataset = pickle.load(open(processed_data_path, 'rb'))
150 |         logging.info('Reading already processed data from %s!' % processed_data_path)
151 |     else:
152 |         logging.info('Loading the dataset from %s' % input_dir)
153 |         dataset = DataSet(train_src=os.path.join(input_dir, './devign-train-v0.json'),
154 |                           valid_src=os.path.join(input_dir, './devign-valid-v0.json'),
155 |                           test_src=os.path.join(input_dir, './devign-test-v0.json'),
156 |                           batch_size=args.batch_size, n_ident=args.node_tag, g_ident=args.graph_tag,
157 |                           l_ident=args.label_tag)
158 |         file = open(processed_data_path, 'wb')
159 |         pickle.dump(dataset, file) #../dataset/FFmpeg_input/
160 |         file.close()
161 |     logging.info('train_dataset: %d; valid_dataset: %d; test_dataset: %d', len(dataset.train_examples), len(dataset.valid_examples), len(dataset.test_examples))
162 |     logging.info("train_batch: %d, valid_batch: %d, test_batch: %d", len(dataset.train_batches), len(dataset.valid_batches), len(dataset.test_batches))
163 |     logging.info('#' * 100)
164 |     '''
165 |     assert args.feature_size == dataset.feature_size, \
166 |         'Dataset contains different feature vector than argument feature size. ' \
167 |         'Either change the feature vector size in argument, or provide different dataset.'
168 |     '''
169 |     logging.info('Check up model_type: ' + args.model_type)
170 |     if args.model_type == 'ggnn':
171 |         model = GGNNSum(input_dim=dataset.feature_size, output_dim=args.graph_embed_size,
172 |                         num_steps=args.num_steps, max_edge_types=dataset.max_edge_type)
173 |     else:
174 |         # use DevignModel
175 | 
176 |         '''
177 |         dataset.feature_size : 100
178 |         args.graph_embed_size : 200
179 |         args.num_steps : 6
180 |         dataset.max_edge_type : 4
181 |         '''
182 | 
183 | 
184 |         model = DevignModel(input_dim=100, output_dim=100,
185 |                             num_steps=args.num_steps, max_edge_types=dataset.max_edge_type)
186 | 
187 |     debug('Total Parameters : %d' % tally_param(model))
188 |     debug('#' * 100)
189 |     logging.info('Total Parameters : %d' % tally_param(model))
190 |     logging.info('#' * 100)
191 |     #device = torch.device("cuda:2")
192 |     #model.to(device = device)
193 |     model.cuda()
194 |     #loss_function = BCELoss(reduction='sum')
195 |     loss_function = CrossEntropyLoss(weight=torch.from_numpy(np.array([1,1])).float(),reduction='sum')
196 |     loss_function.cuda()
197 |     #loss_function = CrossEntropyLoss(reduction='sum')
198 |     LR = 5e-4
199 |     optim = Adam(model.parameters(), lr=5e-4, weight_decay=1.2e-6)
200 |     #optim = RAdam(model.parameters(),lr=LR,weight_decay=1e-5)
201 |     #logging.info('Start to train!')
202 |     # start training the model
203 |     train(model=model, dataset=dataset, epoches=100, dev_every=len(dataset.train_batches),
204 |           loss_function=loss_function, optimizer=optim,
205 |           save_path=model_dir + '/DevignModel', max_patience=30, log_every=5) #models/FFmpeg/GGNNSumModel....
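# Usage sketch (an illustration, assuming the preprocessed graph JSON lives under
# ../devign_type1 as the --input_dir default above suggests):
#
#     python main_devign.py --model_type devign --dataset Devign --batch_size 512
#
# The first run pickles the processed DataSet to <input_dir>/hgt_c1_512.bin and later
# runs reuse it; logs go to models/Devign/Devign.log, and the trainer presumably
# saves checkpoints under models/Devign/DevignModel.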
206 | 207 | -------------------------------------------------------------------------------- /Training_code/main_reveal.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import pickle 5 | import sys 6 | 7 | os.chdir(sys.path[0]) 8 | 9 | import numpy as np 10 | import torch 11 | from torch.nn import BCELoss, BCEWithLogitsLoss, CrossEntropyLoss 12 | from torch.optim import Adam 13 | 14 | from data_loader.dataset import DataSet 15 | from modules.model import DevignModel, GGNNSum 16 | from trainer import train 17 | from utils import tally_param, debug, set_logger 18 | 19 | torch.backends.cudnn.enable =True 20 | torch.backends.cudnn.benchmark = True 21 | 22 | 23 | import math 24 | from torch.optim.optimizer import Optimizer, required 25 | 26 | # from tensorboardX import SummaryWriter 27 | # writer = SummaryWriter(logdir='/cps/gadam/n_cifa/') 28 | # iter_idx = 0 29 | 30 | # from ipdb import set_trace 31 | import torch.optim 32 | 33 | torch.backends.cudnn.enable =True 34 | torch.backends.cudnn.benchmark = True 35 | 36 | class RAdam(Optimizer): 37 | 38 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 39 | weight_decay=0): 40 | defaults = dict(lr=lr, betas=betas, eps=eps, 41 | weight_decay=weight_decay) 42 | 43 | super(RAdam, self).__init__(params, defaults) 44 | 45 | def __setstate__(self, state): 46 | super(RAdam, self).__setstate__(state) 47 | 48 | def step(self, closure=None): 49 | loss = None 50 | beta2_t = None 51 | ratio = None 52 | N_sma_max = None 53 | N_sma = None 54 | 55 | if closure is not None: 56 | loss = closure() 57 | 58 | for group in self.param_groups: 59 | 60 | for p in group['params']: 61 | if p.grad is None: 62 | continue 63 | grad = p.grad.data.float() 64 | if grad.is_sparse: 65 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 66 | 67 | p_data_fp32 = p.data.float() 68 | 69 | state = self.state[p] 70 | 71 | if len(state) == 0: 72 | state['step'] = 0 73 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 74 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 75 | else: 76 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 77 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 78 | 79 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 80 | beta1, beta2 = group['betas'] 81 | 82 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 83 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 84 | 85 | state['step'] += 1 86 | if beta2_t is None: 87 | beta2_t = beta2 ** state['step'] 88 | N_sma_max = 2 / (1 - beta2) - 1 89 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 90 | beta1_t = 1 - beta1 ** state['step'] 91 | if N_sma >= 5: 92 | ratio = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / beta1_t 93 | 94 | if group['weight_decay'] != 0: 95 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 96 | 97 | # more conservative since it's an approximated value 98 | if N_sma >= 5: 99 | step_size = group['lr'] * ratio 100 | denom = exp_avg_sq.sqrt().add_(group['eps']) 101 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 102 | else: 103 | step_size = group['lr'] / beta1_t 104 | p_data_fp32.add_(-step_size, exp_avg) 105 | 106 | p.data.copy_(p_data_fp32) 107 | 108 | return loss 109 | 110 | if __name__ == '__main__': 111 | torch.manual_seed(10) 112 | 113 | np.random.seed(10) 114 | parser = argparse.ArgumentParser() 
115 | parser.add_argument('--model_type', type=str, help='Type of the model (devign/ggnn)', 116 | choices=['devign', 'ggnn'], default='devign') 117 | parser.add_argument('--dataset', type=str, help='Name of the dataset for experiment.', default='Reveal') 118 | parser.add_argument('--input_dir', type=str, help='Input Directory of the parser', default='../reveal_type1') 119 | parser.add_argument('--log_dir', default='Devign.log', type=str) 120 | parser.add_argument('--node_tag', type=str, help='Name of the node feature.', default='node_features') 121 | parser.add_argument('--graph_tag', type=str, help='Name of the graph feature.', default='graph') 122 | parser.add_argument('--label_tag', type=str, help='Name of the label feature.', default='targets') 123 | 124 | parser.add_argument('--feature_size', type=int, help='Size of feature vector for each node', default=100) 125 | parser.add_argument('--graph_embed_size', type=int, help='Size of the Graph Embedding', default=100) 126 | parser.add_argument('--num_steps', type=int, help='Number of steps in GGNN', default=6) 127 | parser.add_argument('--batch_size', type=int, help='Batch Size for training', default=512) 128 | args = parser.parse_args() 129 | 130 | model_dir = os.path.join('models', args.dataset) 131 | if not os.path.exists(model_dir): 132 | os.makedirs(model_dir) 133 | #设置日志输出 134 | log_dir = os.path.join(model_dir, args.log_dir) 135 | set_logger(log_dir) 136 | 137 | logging.info('Check up feature_size: %d', args.feature_size) 138 | if args.feature_size > args.graph_embed_size: 139 | print('Warning!!! Graph Embed dimension should be at least equal to the feature dimension.\n' 140 | 'Setting graph embedding size to feature size', file=sys.stderr) 141 | logging.info('Warning!!! Graph Embed dimension should be at least equal to the feature dimension') 142 | args.graph_embed_size = args.feature_size 143 | 144 | input_dir = args.input_dir 145 | processed_data_path = os.path.join(input_dir, 'hgt_c1_512_reveal.bin') 146 | logging.info('#' * 100) 147 | if True and os.path.exists(processed_data_path): 148 | debug('Reading already processed data from %s!' % processed_data_path) 149 | dataset = pickle.load(open(processed_data_path, 'rb')) 150 | logging.info('Reading already processed data from %s!' % processed_data_path) 151 | else: 152 | logging.info('Loading the dataset from %s' % input_dir) 153 | dataset = DataSet(train_src=os.path.join(input_dir, './reveal-train-v0.json'), 154 | valid_src=os.path.join(input_dir, './reveal-valid-v0.json'), 155 | test_src=os.path.join(input_dir, './reveal-test-v0.json'), 156 | batch_size=args.batch_size, n_ident=args.node_tag, g_ident=args.graph_tag, 157 | l_ident=args.label_tag) 158 | file = open(processed_data_path, 'wb') 159 | pickle.dump(dataset, file) #../dataset/FFmpeg_input/ 160 | file.close() 161 | logging.info('train_dataset: %d; valid_dataset: %d; test_dataset: %d', len(dataset.train_examples), len(dataset.valid_examples), len(dataset.test_examples)) 162 | logging.info("train_batch: %d, valid_batch: %d, test_batch: %d", len(dataset.train_batches), len(dataset.valid_batches), len(dataset.test_batches)) 163 | logging.info('#' * 100) 164 | ''' 165 | assert args.feature_size == dataset.feature_size, \ 166 | 'Dataset contains different feature vector than argument feature size. ' \ 167 | 'Either change the feature vector size in argument, or provide different dataset.' 
168 |     '''
169 |     logging.info('Check up model_type: ' + args.model_type)
170 |     if args.model_type == 'ggnn':
171 |         model = GGNNSum(input_dim=dataset.feature_size, output_dim=args.graph_embed_size,
172 |                         num_steps=args.num_steps, max_edge_types=dataset.max_edge_type)
173 |     else:
174 |         # use DevignModel
175 | 
176 |         '''
177 |         dataset.feature_size : 100
178 |         args.graph_embed_size : 200
179 |         args.num_steps : 6
180 |         dataset.max_edge_type : 4
181 |         '''
182 | 
183 | 
184 |         model = DevignModel(input_dim=100, output_dim=100,
185 |                             num_steps=args.num_steps, max_edge_types=dataset.max_edge_type)
186 | 
187 |     debug('Total Parameters : %d' % tally_param(model))
188 |     debug('#' * 100)
189 |     logging.info('Total Parameters : %d' % tally_param(model))
190 |     logging.info('#' * 100)
191 |     #device = torch.device("cuda:2")
192 |     #model.to(device = device)
193 |     model.cuda()
194 |     #loss_function = BCELoss(reduction='sum')
195 |     loss_function = CrossEntropyLoss(weight=torch.from_numpy(np.array([1,9.9])).float(),reduction='sum')
196 |     loss_function.cuda()
197 |     #loss_function = CrossEntropyLoss(reduction='sum')
198 |     LR = 5e-4
199 |     optim = Adam(model.parameters(), lr=5e-4, weight_decay=1.3e-6)
200 |     #optim = RAdam(model.parameters(),lr=LR,weight_decay=1e-5)
201 |     #logging.info('Start to train!')
202 |     # start training the model
203 |     train(model=model, dataset=dataset, epoches=100, dev_every=len(dataset.train_batches),
204 |           loss_function=loss_function, optimizer=optim,
205 |           save_path=model_dir + '/DevignModel', max_patience=20, log_every=5) #models/FFmpeg/GGNNSumModel....
206 | 
207 | 
--------------------------------------------------------------------------------
/Training_code/data_loader/dataset.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | import logging
4 | import sys
5 | import os
6 | os.chdir(sys.path[0])
7 | import torch
8 | from dgl import DGLGraph
9 | import dgl
10 | import numpy as np
11 | from tqdm import tqdm
12 | 
13 | from data_loader.batch_graph import GGNNBatchGraph
14 | from utils import load_default_identifiers, initialize_batch, debug
15 | 
16 | type_0 = [1,7,9,12,16,17,21,22,27,33,34,35,37,40,47,48,55,56,59,63]
17 | type_1 = [5,13,18,19,25,30,31,46,49,51,54,60,62,64,66,67,69]
18 | node_type_0_set = set(type_0)
19 | node_type_1_set = set(type_1)
20 | ## for each function
21 | class DataEntry:
22 |     def __init__(self, datset, num_nodes, features, edges, target):
23 |         self.dataset = datset
24 |         self.num_nodes = num_nodes
25 |         self.target = target
26 |         self.graph = dgl.heterograph({
27 | 
28 |             #Expression
29 |             ('Expression', 'Expression0Expression', 'Expression'): ((),()),
30 |             ('Expression', 'Expression1Expression', 'Expression'): ((),()),
31 |             ('Expression', 'Expression2Expression', 'Expression'): ((),()),
32 |             ('Expression', 'Expression3Expression', 'Expression'): ((),()),
33 | 
34 |             ('Expression', 'Expression0Function', 'Function'): ((),()),
35 |             ('Expression', 'Expression1Function', 'Function'): ((),()),
36 |             ('Expression', 'Expression2Function', 'Function'): ((),()),
37 |             ('Expression', 'Expression3Function', 'Function'): ((),()),
38 | 
39 |             ('Expression', 'Expression0Statement', 'Statement'): ((),()),
40 |             ('Expression', 'Expression1Statement', 'Statement'): ((),()),
41 |             ('Expression', 'Expression2Statement', 'Statement'): ((),()),
42 |             ('Expression', 'Expression3Statement', 'Statement'): ((),()),
43 | 
44 |             #Function
45 |             ('Function', 'Function0Expression', 'Expression'): ((),()),
46 |             ('Function', 'Function1Expression', 'Expression'): 
((),()), 47 | ('Function', 'Function2Expression', 'Expression'): ((),()), 48 | ('Function', 'Function3Expression', 'Expression'): ((),()), 49 | 50 | ('Function', 'Function0Function', 'Function'): ((),()), 51 | ('Function', 'Function1Function', 'Function'): ((),()), 52 | ('Function', 'Function2Function', 'Function'): ((),()), 53 | ('Function', 'Function3Function', 'Function'): ((),()), 54 | 55 | ('Function', 'Function0Statement', 'Statement'): ((),()), 56 | ('Function', 'Function1Statement', 'Statement'): ((),()), 57 | ('Function', 'Function2Statement', 'Statement'): ((),()), 58 | ('Function', 'Function3Statement', 'Statement'): ((),()), 59 | 60 | #Statement 61 | ('Statement', 'Statement0Expression', 'Expression'): ((),()), 62 | ('Statement', 'Statement1Expression', 'Expression'): ((),()), 63 | ('Statement', 'Statement2Expression', 'Expression'): ((),()), 64 | ('Statement', 'Statement3Expression', 'Expression'): ((),()), 65 | 66 | ('Statement', 'Statement0Function', 'Function'): ((),()), 67 | ('Statement', 'Statement1Function', 'Function'): ((),()), 68 | ('Statement', 'Statement2Function', 'Function'): ((),()), 69 | ('Statement', 'Statement3Function', 'Function'): ((),()), 70 | 71 | ('Statement', 'Statement0Statement', 'Statement'): ((),()), 72 | ('Statement', 'Statement1Statement', 'Statement'): ((),()), 73 | ('Statement', 'Statement2Statement', 'Statement'): ((),()), 74 | ('Statement', 'Statement3Statement', 'Statement'): ((),()), 75 | 76 | 77 | }) 78 | self.features = torch.FloatTensor(features) 79 | #self.graph.add_nodes(self.num_nodes, data={'features': self.features}) ## 80 | type_list = [] 81 | ex_list = [] 82 | st_list = [] 83 | fu_list = [] 84 | #print(self.features.shape) 85 | for i in range (0,self.features.shape[0]): 86 | features_new = self.features[i:i+1,:100] 87 | #print(features_new.shape) 88 | if(features[i][100] in node_type_0_set): 89 | type_new ='Expression' 90 | type_list.append(type_new) 91 | ex_list.append(i) 92 | elif(features[i][100] in node_type_1_set): 93 | type_new ='Statement' 94 | type_list.append(type_new) 95 | st_list.append(i) 96 | else: 97 | type_new='Function' 98 | type_list.append(type_new) 99 | fu_list.append(i) 100 | #print(features_type) 101 | self.graph = dgl.add_nodes(self.graph, num = 1, ntype = type_new,data = {'h':features_new})# 102 | #self.graph.nodes[features_type].data = torch.hstack((self.graph.nodes[features_type].data, features_new)) 103 | #print(ex_list) 104 | #self.graph = dgl.add_nodes(self.graph, self.num_nodes, ntype = 'features') ## 105 | #self.graph.nodes['features'].data['h'] = self.features 106 | #self.graph.add_nodes(self.num_nodes, data={self.features}, ntype = 'features') ## 107 | #for i in range (0,self.num_nodes): 108 | # self.graph.add_nodes(self.graph, 1, data=self.features[i], ntype='features') ## 109 | 110 | 111 | 112 | for s, _type, t in edges: 113 | etype_number = str(self.dataset.get_edge_type_number(_type)) 114 | #print(str(s)+" "+str(t)+" "+ str(etype_number)) 115 | if etype_number == '4': 116 | continue 117 | ''' 118 | if type_list[s]+etype_number+type_list[t] == 'Expression2Expression': 119 | continue 120 | if type_list[s]+etype_number+type_list[t] == 'Expression2Statement': 121 | continue 122 | if type_list[s]+etype_number+type_list[t] == 'Statement2Expression': 123 | continue 124 | if type_list[s]+etype_number+type_list[t] == 'Statement1Statement': 125 | continue 126 | ''' 127 | 128 | for number in range (len(ex_list)): 129 | if s == ex_list[number]: 130 | s_type = 'Expression' 131 | s_number = number 132 | if t 
== ex_list[number]: 133 | t_type = 'Expression' 134 | t_number = number 135 | 136 | for number in range (len(st_list)): 137 | if s == st_list[number]: 138 | s_type = 'Statement' 139 | s_number = number 140 | if t == st_list[number]: 141 | t_type = 'Statement' 142 | t_number = number 143 | 144 | for number in range (len(fu_list)): 145 | if s == fu_list[number]: 146 | s_type = 'Function' 147 | s_number = number 148 | if t == fu_list[number]: 149 | t_type = 'Function' 150 | t_number = number 151 | 152 | self.graph.add_edge(s_number, t_number, etype=(s_type, type_list[s]+etype_number+type_list[t], t_type)) 153 | #self.graph.add_edge(t, s, etype=etype_number) 154 | #self.graph.add_edge(s, t, data={'etype': torch.LongTensor([etype_number])}) ## 155 | #print(self.graph) 156 | 157 | 158 | class DataSet: 159 | def __init__(self, train_src, valid_src, test_src, batch_size, n_ident=None, g_ident=None, l_ident=None): 160 | self.train_examples = [] 161 | self.valid_examples = [] 162 | self.test_examples = [] 163 | self.train_batches = [] 164 | self.valid_batches = [] 165 | self.test_batches = [] 166 | self.batch_size = batch_size 167 | self.edge_types = {} 168 | self.max_etype = 0 169 | self.feature_size = 0 170 | self.n_ident, self.g_ident, self.l_ident= load_default_identifiers(n_ident, g_ident, l_ident) 171 | self.read_dataset(train_src, valid_src, test_src) 172 | self.initialize_dataset() 173 | 174 | def initialize_dataset(self): 175 | 176 | self.initialize_train_batch() 177 | self.initialize_valid_batch() 178 | self.initialize_test_batch() 179 | 180 | def read_dataset(self, train_src, valid_src, test_src): 181 | debug('Reading Train File!') 182 | #logging.info('train:' + train_src + '; valid:' + valid_src + '; test:' + test_src) 183 | 184 | with open(train_src, "r") as fp: 185 | train_data = [] 186 | #for i in fp.readlines(): 187 | # train_data.append(json.loads(i)) 188 | #for line in fp.readlines(): 189 | # train_data.append(json.loads(line)) 190 | train_data = json.load(fp) 191 | for entry in tqdm(train_data): 192 | example = DataEntry(datset=self, num_nodes=len(entry[self.n_ident]), features=entry[self.n_ident], 193 | edges=entry[self.g_ident], target=entry[self.l_ident][0][0]) 194 | 195 | if self.feature_size == 0: 196 | self.feature_size = example.features.size(1)-1 197 | debug('Feature Size %d' % self.feature_size) 198 | self.train_examples.append(example) 199 | ''' 200 | 201 | if valid_src is not None: 202 | debug('Reading Validation File!') 203 | with open(valid_src, "r") as fp: 204 | valid_data = [] 205 | #for i in fp.readlines(): 206 | # valid_data.append(json.loads(i)) 207 | valid_data = json.load(fp) 208 | for entry in tqdm(valid_data): 209 | 210 | example = DataEntry(datset=self, num_nodes=len(entry[self.n_ident]), 211 | features=entry[self.n_ident], 212 | edges=entry[self.g_ident], target=entry[self.l_ident][0][0]) 213 | self.valid_examples.append(example) 214 | 215 | if test_src is not None: 216 | debug('Reading Test File!') 217 | with open(test_src, "r") as fp: 218 | test_data = [] 219 | #for i in fp.readlines(): 220 | # test_data.append(json.loads(i)) 221 | test_data = json.load(fp) 222 | for entry in tqdm(test_data): 223 | 224 | example = DataEntry(datset=self, num_nodes=len(entry[self.n_ident]), 225 | features=entry[self.n_ident], 226 | edges=entry[self.g_ident], target=entry[self.l_ident][0][0]) 227 | self.test_examples.append(example) 228 | ''' 229 | 230 | def get_edge_type_number(self, _type): 231 | if _type not in self.edge_types: 232 | self.edge_types[_type] = 
self.max_etype 233 | self.max_etype += 1 234 | return self.edge_types[_type] 235 | 236 | @property 237 | def max_edge_type(self): 238 | return self.max_etype 239 | 240 | def initialize_train_batch(self, batch_size=-1): 241 | if batch_size == -1: 242 | batch_size = self.batch_size 243 | 244 | self.train_batches = initialize_batch(self.train_examples, batch_size, shuffle=False) 245 | 246 | 247 | return len(self.train_batches) 248 | pass 249 | 250 | def initialize_valid_batch(self, batch_size=-1): 251 | if batch_size == -1: 252 | batch_size = self.batch_size 253 | self.valid_batches = initialize_batch(self.valid_examples, batch_size, shuffle=False) 254 | return len(self.valid_batches) 255 | pass 256 | 257 | def initialize_test_batch(self, batch_size=-1): 258 | if batch_size == -1: 259 | batch_size = self.batch_size 260 | self.test_batches = initialize_batch(self.test_examples, batch_size, shuffle=False) 261 | return len(self.test_batches) 262 | pass 263 | 264 | def get_dataset_by_ids_for_GGNN(self, entries, ids): 265 | taken_entries = [entries[i] for i in ids] 266 | labels = [e.target for e in taken_entries] 267 | batch_graph = GGNNBatchGraph() 268 | for entry in taken_entries: 269 | batch_graph.add_subgraph(copy.deepcopy(entry.graph)) 270 | return batch_graph, torch.FloatTensor(labels) 271 | 272 | def get_next_train_batch(self): 273 | 274 | #print(len(self.train_batches)) 275 | if len(self.train_batches) == 0: 276 | #print('k'*40) 277 | self.initialize_train_batch() 278 | 279 | 280 | ids = self.train_batches.pop() 281 | if(len(self.train_batches) == 1): 282 | ids1 = self.train_batches.pop() 283 | 284 | return self.get_dataset_by_ids_for_GGNN(self.train_examples, ids) 285 | 286 | def get_next_valid_batch(self): 287 | if len(self.valid_batches) == 0: 288 | self.initialize_valid_batch() 289 | ids = self.valid_batches.pop() 290 | if (len(self.valid_batches) == 1): 291 | ids1 = self.valid_batches.pop() 292 | return self.get_dataset_by_ids_for_GGNN(self.valid_examples, ids) 293 | 294 | def get_next_test_batch(self): 295 | if len(self.test_batches) == 0: 296 | self.initialize_test_batch() 297 | ids = self.test_batches.pop() 298 | if (len(self.test_batches) == 1): 299 | ids1 = self.test_batches.pop() 300 | return self.get_dataset_by_ids_for_GGNN(self.test_examples, ids) 301 | -------------------------------------------------------------------------------- /Training_code/modules/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dgl.nn import GatedGraphConv 3 | from torch import nn 4 | import torch.nn.functional as f 5 | 6 | from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp 7 | from torch_geometric.nn import GCNConv 8 | from dgl.nn import GraphConv, AvgPooling, MaxPooling 9 | from torch.autograd import Variable 10 | import dgl 11 | import dgl.function as fn 12 | import math 13 | import numpy as np 14 | #from graph_transformer_edge_layer import GraphTransformerLayer 15 | 16 | from mlp_readout import MLPReadout 17 | 18 | from dgl.nn.pytorch import edge_softmax, GATConv 19 | from dgl.nn.pytorch import GraphConv, EdgeWeightNorm 20 | 21 | # graph transformer 22 | # HGT dgl version 23 | class HGTLayer(nn.Module): 24 | def __init__(self, in_dim, out_dim, num_types, num_relations, n_heads, dropout=0.0, use_norm=False): 25 | super(HGTLayer, self).__init__() 26 | 27 | self.in_dim = in_dim 28 | self.out_dim = out_dim 29 | self.num_types = num_types 30 | self.num_relations = num_relations 31 | self.n_heads = n_heads 32 
| self.d_k = out_dim // n_heads 33 | self.sqrt_dk = math.sqrt(self.d_k) 34 | 35 | self.k_linears = nn.ModuleList() 36 | self.q_linears = nn.ModuleList() 37 | self.v_linears = nn.ModuleList() 38 | self.a_linears = nn.ModuleList() 39 | self.norms = nn.ModuleList() 40 | self.use_norm = use_norm 41 | 42 | for t in range(num_types): 43 | self.k_linears.append(nn.Linear(in_dim, out_dim)) 44 | self.q_linears.append(nn.Linear(in_dim, out_dim)) 45 | self.v_linears.append(nn.Linear(in_dim, out_dim)) 46 | self.a_linears.append(nn.Linear(out_dim, out_dim)) 47 | if use_norm: 48 | self.norms.append(nn.LayerNorm(out_dim)) 49 | 50 | self.relation_pri = nn.Parameter(torch.ones(num_relations, self.n_heads)) 51 | self.relation_att = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k)) 52 | self.relation_msg = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k)) 53 | 54 | 55 | self.node_type_att = nn.Parameter(torch.ones(num_types)) 56 | self.node_type_att1 = nn.Parameter(torch.ones(num_types)) 57 | 58 | 59 | self.skip = nn.Parameter(torch.ones(num_types)) 60 | 61 | self.weight = nn.Parameter(torch.ones(1)) 62 | 63 | self.drop = nn.Dropout(dropout) 64 | 65 | self.attn_fc = nn.Linear(2*self.d_k, 1, bias=False) 66 | 67 | nn.init.xavier_uniform_(self.relation_att) 68 | nn.init.xavier_uniform_(self.relation_msg) 69 | 70 | def edge_attention(self, edges): 71 | if(len(edges.data['id'])!= 0): 72 | 73 | etype = edges.data['id'][0] 74 | 75 | relation_att = self.relation_att[etype] 76 | relation_pri = self.relation_pri[etype] 77 | relation_msg = self.relation_msg[etype] 78 | key = torch.bmm(edges.src['k'].transpose(1, 0).to(torch.device('cuda:0')), relation_att).transpose(1, 0) 79 | att = ((edges.dst['q'].to(torch.device('cuda:0')) * key).sum(dim=-1) * relation_pri / self.sqrt_dk) 80 | val = torch.bmm(edges.src['v'].transpose(1, 0).to(torch.device('cuda:0')), relation_msg).transpose(1, 0) 81 | 82 | if etype <= 9: 83 | src_ntype = 0 84 | elif etype <= 21: 85 | src_ntype = 1 86 | else: src_ntype = 2 87 | 88 | if etype <= 2 or (etype >= 10 and etype <=13) or (etype >= 22 and etype <=24): 89 | dst_ntype = 0 90 | elif (etype >=3 and etype <= 6) or (etype >= 14 and etype <=17) or (etype >= 25 and etype <=28): 91 | dst_ntype = 1 92 | else: dst_ntype = 2 93 | att_src = self.node_type_att[src_ntype] 94 | att_dst = self.node_type_att1[dst_ntype] 95 | att2 = torch.cat([att_dst * edges.dst['q'], att_src * edges.src['k']], dim=2) 96 | att2 = self.attn_fc(att2) 97 | att2 = att2.sum(dim=-1) 98 | att2 = f.leaky_relu(att2) 99 | else: 100 | key = edges.src['k'].transpose(1, 0).to(torch.device('cuda:0')).transpose(1, 0) 101 | att = ((edges.dst['q'].to(torch.device('cuda:0')) * key).sum(dim=-1) / self.sqrt_dk) * 0 102 | val = (edges.src['v'].transpose(1, 0).to(torch.device('cuda:0'))).transpose(1, 0) 103 | att2 = torch.cat([edges.dst['q'], edges.src['k']], dim=2) 104 | att2 = self.attn_fc(att2) 105 | att2 = att2.sum(dim=-1) 106 | att2 = f.leaky_relu(att2) 107 | 108 | 109 | 110 | return {'a': att, 'na': att2, 'v': val} 111 | 112 | def message_func(self, edges): 113 | return {'v': edges.data['v'], 'a': edges.data['a'], 'na':edges.data['na']} 114 | 115 | def reduce_func(self, nodes): 116 | beta = torch.sigmoid(self.weight) 117 | att = f.softmax(nodes.mailbox['a'] + (beta) *nodes.mailbox['na'], dim=1) 118 | h = torch.sum(att.unsqueeze(dim=-1) * nodes.mailbox['v'], dim=1) 119 | 120 | 121 | return {'t': h.view(-1, self.out_dim)} 122 | 123 | def forward(self, G, inp_key, out_key): 124 | node_dict, edge_dict = 
G.node_dict, G.edge_dict 125 | for srctype, etype, dsttype in G.canonical_etypes: 126 | #print(srctype) 127 | k_linear = self.k_linears[node_dict[srctype]] 128 | v_linear = self.v_linears[node_dict[srctype]] 129 | q_linear = self.q_linears[node_dict[dsttype]] 130 | 131 | G.nodes[srctype].data['k'] = k_linear(G.nodes[srctype].data[inp_key]).view(-1, self.n_heads, self.d_k) 132 | G.nodes[srctype].data['v'] = v_linear(G.nodes[srctype].data[inp_key]).view(-1, self.n_heads, self.d_k) 133 | G.nodes[dsttype].data['q'] = q_linear(G.nodes[dsttype].data[inp_key]).view(-1, self.n_heads, self.d_k) 134 | G.apply_edges(func=self.edge_attention, etype=etype) 135 | G.multi_update_all({etype: (self.message_func, self.reduce_func) \ 136 | for etype in edge_dict}, cross_reducer='mean') 137 | for ntype in G.ntypes: 138 | n_id = node_dict[ntype] 139 | alpha = torch.sigmoid(self.skip[n_id]) 140 | #print(G.nodes[ntype].data['t'].to(torch.device('cuda:0')).shape) 141 | h = G.nodes[ntype].data['t'].to(torch.device('cuda:0')).shape[0] 142 | 143 | t = G.nodes[ntype].data['t'].to(torch.device('cuda:0')) 144 | trans_out = self.a_linears[n_id](t) 145 | trans_out = trans_out * alpha + G.nodes[ntype].data[inp_key].to(torch.device('cuda:0')) * (1 - alpha) 146 | if self.use_norm: 147 | G.nodes[ntype].data[out_key] = self.drop(self.norms[n_id](trans_out)) 148 | else: 149 | G.nodes[ntype].data[out_key] = self.drop(trans_out) 150 | 151 | def __repr__(self): 152 | return '{}(in_dim={}, out_dim={}, num_types={}, num_types={})'.format( 153 | self.__class__.__name__, self.in_dim, self.out_dim, 154 | self.num_types, self.num_relations) 155 | 156 | 157 | class DevignModel(nn.Module): 158 | def __init__(self, input_dim, output_dim, max_edge_types, num_steps=8): 159 | super(DevignModel, self).__init__() 160 | self.inp_dim = input_dim 161 | self.out_dim = output_dim 162 | self.max_edge_types = max_edge_types 163 | self.num_timesteps = num_steps 164 | 165 | self.gcs = nn.ModuleList() 166 | n_layers = 4 167 | n_heads = 4 168 | len_graph_ntypes = 3 169 | len_graph_etypes = 32 170 | self.n_layers = n_layers 171 | self.adapt_ws = nn.ModuleList() 172 | output_dim = 64 173 | for t in range(3): 174 | self.adapt_ws.append(nn.Linear(input_dim, output_dim)) 175 | for _ in range(n_layers): 176 | self.gcs.append(HGTLayer(output_dim, output_dim, len_graph_ntypes, len_graph_etypes, n_heads, use_norm=True)) 177 | 178 | 179 | self.hidden_dim = 64 180 | self.batch_size = 256 181 | self.num_layers = 1 182 | self.bigru1 = nn.GRU(output_dim, self.hidden_dim, num_layers=self.num_layers, bidirectional=True, 183 | batch_first=True) 184 | self.bigru2 = nn.GRU(output_dim, self.hidden_dim, num_layers=self.num_layers, bidirectional=True, 185 | batch_first=True) 186 | self.bigru3 = nn.GRU(output_dim, self.hidden_dim, num_layers=self.num_layers, bidirectional=True, 187 | batch_first=True) 188 | self.MPL_layer = MLPReadout(2*self.hidden_dim, 2) 189 | self.MPL_layer1 = MLPReadout(2*self.hidden_dim, 2) 190 | 191 | self.hidden1 = self.init_hidden() 192 | self.hidden2 = self.init_hidden() 193 | self.hidden3 = self.init_hidden() 194 | self.sigmoid = nn.Sigmoid() 195 | 196 | self.weight = Variable(torch.ones(len_graph_ntypes).cuda()) 197 | self.weight1 = Variable(torch.ones(len_graph_ntypes).cuda()) 198 | 199 | def init_hidden(self): 200 | if True: 201 | if isinstance(self.bigru1, nn.LSTM): 202 | h0 = Variable(torch.zeros(self.num_layers * 2, self.batch_size, self.hidden_dim).cuda()) 203 | c0 = Variable(torch.zeros(self.num_layers * 2, self.batch_size, 
self.hidden_dim).cuda()) 204 | return h0, c0 205 | if isinstance(self.bigru2, nn.LSTM): 206 | h0 = Variable(torch.zeros(self.num_layers * 2, self.batch_size, self.hidden_dim).cuda()) 207 | c0 = Variable(torch.zeros(self.num_layers * 2, self.batch_size, self.hidden_dim).cuda()) 208 | return h0, c0 209 | if isinstance(self.bigru3, nn.LSTM): 210 | h0 = Variable(torch.zeros(self.num_layers * 2, self.batch_size, self.hidden_dim).cuda()) 211 | c0 = Variable(torch.zeros(self.num_layers * 2, self.batch_size, self.hidden_dim).cuda()) 212 | return h0, c0 213 | return Variable(torch.zeros(self.num_layers * 2, self.batch_size, self.hidden_dim)).cuda() 214 | else: 215 | return Variable(torch.zeros(self.num_layers * 2, self.batch_size, self.hidden_dim)) 216 | 217 | def get_zeros(self, num): 218 | zeros = Variable(torch.zeros(num, self.hidden_dim)) 219 | return zeros.cuda() 220 | 221 | def forward(self, batch, cuda=False): 222 | graph = batch.get_network_inputs(cuda=cuda) 223 | 224 | graph = graph.to(torch.device('cuda:0')) 225 | #print(graph) 226 | graph.node_dict = {} 227 | graph.edge_dict = {} 228 | for ntype in graph.ntypes: 229 | graph.node_dict[ntype] = len(graph.node_dict) 230 | graph.nodes[ntype].data['id'] = torch.ones(graph.number_of_nodes(ntype), dtype=torch.long, device=graph.device) * graph.node_dict[ntype] 231 | for etype in graph.etypes: 232 | graph.edge_dict[etype] = len(graph.edge_dict) 233 | graph.edges[etype].data['id'] = torch.ones(graph.number_of_edges(etype), dtype=torch.long, device=graph.device) * graph.edge_dict[etype] 234 | 235 | for ntype in graph.ntypes: 236 | n_id = graph.node_dict[ntype] 237 | #self.adapt_ws[n_id](graph.nodes[ntype].data['h'].to(torch.device('cuda:0'))) 238 | graph.nodes[ntype].data['new_h'] = torch.tanh(self.adapt_ws[n_id](graph.nodes[ntype].data['h'].to(torch.device('cuda:0')))) 239 | 240 | 241 | 242 | for i in range(self.n_layers): 243 | self.gcs[i](graph, 'new_h', 'new_h') 244 | #outputs = graph.nodes['Statement'].data['new_h'] 245 | 246 | #node_Statement1 = dgl.readout_nodes(graph, 'new_h', op = 'sum', ntype = 'Statement') 247 | #node_Expression1 = dgl.readout_nodes(graph, 'new_h', op = 'sum', ntype = 'Expression') 248 | #node_Function1 = dgl.readout_nodes(graph, 'new_h', op = 'sum', ntype = 'Function') 249 | 250 | statement = graph.nodes['Statement'].data['new_h'] 251 | expression = graph.nodes['Expression'].data['new_h'] 252 | function = graph.nodes['Function'].data['new_h'] 253 | st = graph.batch_num_nodes('Statement') 254 | ex = graph.batch_num_nodes('Expression') 255 | fu = graph.batch_num_nodes('Function') 256 | 257 | max_len_st = max(st) 258 | max_len_ex = max(ex) 259 | max_len_fu = max(fu) 260 | batch_size = len(st) 261 | st_seq, st_start, st_end = [], 0, 0 262 | ex_seq, ex_start, ex_end = [], 0, 0 263 | fu_seq, fu_start, fu_end = [], 0, 0 264 | for i in range (batch_size): 265 | st_end = st_start + st[i] 266 | ex_end = ex_start + ex[i] 267 | fu_end = fu_start + fu[i] 268 | if max_len_st - st[i]: 269 | st_seq.append(self.get_zeros(max_len_st-st[i])) 270 | if max_len_ex - ex[i]: 271 | ex_seq.append(self.get_zeros(max_len_ex-ex[i])) 272 | if max_len_fu - fu[i]: 273 | fu_seq.append(self.get_zeros(max_len_fu-fu[i])) 274 | st_seq.append(statement[st_start:st_end]) 275 | ex_seq.append(expression[ex_start:ex_end]) 276 | fu_seq.append(function[fu_start:fu_end]) 277 | 278 | st_start = st_end 279 | ex_start = ex_end 280 | fu_start = fu_end 281 | 282 | st = torch.cat(st_seq) 283 | ex = torch.cat(ex_seq) 284 | fu = torch.cat(fu_seq) 285 | st = 
st.view(batch_size, max_len_st, -1) 286 | ex = ex.view(batch_size, max_len_ex, -1) 287 | fu = fu.view(batch_size, max_len_fu, -1) 288 | 289 | st, hidden = self.bigru1(st, self.hidden1) 290 | st = torch.transpose(st, 1, 2) 291 | ex, hidden = self.bigru2(ex, self.hidden2) 292 | ex = torch.transpose(ex, 1, 2) 293 | fu, hidden = self.bigru3(fu, self.hidden3) 294 | fu = torch.transpose(fu, 1, 2) 295 | 296 | 297 | # pooling 298 | #print(st.shape) 299 | st1 = f.max_pool1d(st, st.size(2)).squeeze(2) 300 | ex1 = f.max_pool1d(ex, ex.size(2)).squeeze(2) 301 | fu1 = f.max_pool1d(fu, fu.size(2)).squeeze(2) 302 | 303 | st2 = f.avg_pool1d(st, st.size(2)).squeeze(2) 304 | ex2 = f.avg_pool1d(ex, ex.size(2)).squeeze(2) 305 | fu2 = f.avg_pool1d(fu, fu.size(2)).squeeze(2) 306 | #print(st.shape) 307 | ''' 308 | all = st + ex + fu 309 | 310 | st = st / all 311 | ex = ex / all 312 | fu = fu / all 313 | 314 | length = len(st) 315 | st = st.reshape(length,1) 316 | ex = ex.reshape(length,1) 317 | fu = fu.reshape(length,1) 318 | outputs = st*node_Statement1 + ex * node_Expression1 + fu * node_Function1 319 | ''' 320 | #print(outputs.shape) 321 | 322 | outputs = self.MPL_layer(self.weight[0] * st1+ self.weight[1] * ex1+self.weight[2] * fu1 + self.weight1[0] * st2+ self.weight1[1] * ex2+self.weight1[2] * fu2) 323 | 324 | 325 | 326 | outputs = nn.Softmax(dim=1)(outputs) 327 | #print(outputs) 328 | # outputs = avg.squeeze(dim = -1) 329 | return outputs 330 | 331 | 332 | 333 | 334 | 335 | 336 | class GGNNSum(nn.Module): 337 | def __init__(self, input_dim, output_dim, max_edge_types, num_steps=8): 338 | super(GGNNSum, self).__init__() 339 | self.inp_dim = input_dim 340 | self.out_dim = output_dim 341 | self.max_edge_types = max_edge_types 342 | self.num_timesteps = num_steps 343 | self.ggnn = GatedGraphConv(in_feats=input_dim, out_feats=output_dim, n_steps=num_steps, 344 | n_etypes=max_edge_types) 345 | self.classifier = nn.Linear(in_features=output_dim, out_features=1) 346 | self.sigmoid = nn.Sigmoid() 347 | 348 | #前向传播函数 349 | def forward(self, batch, cuda=False): 350 | graph, features, edge_types = batch.get_network_inputs(cuda=cuda) 351 | outputs = self.ggnn(graph, features, edge_types) 352 | h_i, _ = batch.de_batchify_graphs(outputs) 353 | ggnn_sum = self.classifier(h_i.sum(dim=1)) 354 | result = self.sigmoid(ggnn_sum).squeeze(dim=-1) 355 | return result --------------------------------------------------------------------------------
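A minimal construction sketch for the DevignModel defined above, assuming the repository's dependencies (torch, dgl, torch_geometric) are installed and a CUDA device is available, since the constructor moves some weights to the GPU; the hyperparameters mirror the values hard-coded in main_devign.py and its inline comments, not a verified configuration:

    from modules.model import DevignModel
    from utils import tally_param

    # input_dim/output_dim/num_steps follow main_devign.py; max_edge_types=4 matches
    # the dataset.max_edge_type value noted in that script's inline comments
    model = DevignModel(input_dim=100, output_dim=100, num_steps=6, max_edge_types=4)
    print('Total Parameters : %d' % tally_param(model))

A full forward pass would additionally need batched heterographs from DataSet.get_next_train_batch(), and it appears to require batches whose size matches the hard-coded self.batch_size = 256 used by init_hidden, since the pre-built GRU hidden states are sized for that batch dimension.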