├── Figures ├── CE.png ├── Overview.png └── SupCon.png ├── LICENSE ├── README.md ├── models ├── __init__.py ├── detectors │ ├── DeepWuKong │ │ ├── __init__.py │ │ ├── configurations.py │ │ ├── model.py │ │ ├── pretrain.py │ │ └── train.py │ ├── Devign │ │ ├── __init__.py │ │ ├── configurations.py │ │ ├── model.py │ │ ├── train.py │ │ └── util.py │ ├── Reveal │ │ ├── __init__.py │ │ ├── configurations.py │ │ ├── model.py │ │ ├── preprocessing.py │ │ ├── pretrain.py │ │ ├── train.py │ │ └── util.py │ ├── __init__.py │ └── common_model.py ├── self_supervised │ ├── __init__.py │ ├── moco.py │ ├── simclr.py │ └── utils.py └── vulexplainer │ ├── __init__.py │ ├── explainer_models.py │ └── util.py └── preprocess ├── __init__.py ├── code2class_vocab.py ├── graphs_vocab.py ├── parser ├── languages.exp ├── languages.lib └── languages.so ├── sitter-libs ├── c │ ├── Cargo.toml │ ├── LICENSE │ ├── Package.swift │ ├── README.md │ ├── binding.gyp │ ├── bindings │ │ ├── node │ │ │ ├── binding.cc │ │ │ └── index.js │ │ ├── rust │ │ │ ├── README.md │ │ │ ├── build.rs │ │ │ └── lib.rs │ │ └── swift │ │ │ └── TreeSitterC │ │ │ └── c.h │ ├── examples │ │ ├── cluster.c │ │ ├── malloc.c │ │ └── parser.c │ ├── grammar.js │ ├── package.json │ ├── queries │ │ └── highlights.scm │ ├── src │ │ ├── grammar.json │ │ ├── node-types.json │ │ ├── parser.c │ │ └── tree_sitter │ │ │ └── parser.h │ └── test │ │ ├── corpus │ │ ├── ambiguities.txt │ │ ├── crlf.txt │ │ ├── declarations.txt │ │ ├── expressions.txt │ │ ├── microsoft.txt │ │ ├── preprocessor.txt │ │ ├── statements.txt │ │ └── types.txt │ │ └── highlight │ │ ├── keywords.c │ │ └── names.c └── cpp │ ├── Cargo.toml │ ├── LICENSE │ ├── Package.swift │ ├── README.md │ ├── binding.gyp │ ├── bindings │ ├── node │ │ ├── binding.cc │ │ └── index.js │ ├── rust │ │ ├── README.md │ │ ├── build.rs │ │ └── lib.rs │ └── swift │ │ └── TreeSitterCPP │ │ └── cpp.h │ ├── examples │ ├── marker-index.h │ └── rule.cc │ ├── grammar.js │ ├── package.json │ ├── queries │ ├── highlights.scm │ └── injections.scm │ ├── src │ ├── grammar.json │ ├── node-types.json │ ├── parser.c │ ├── scanner.cc │ └── tree_sitter │ │ ├── parser.h │ │ └── runtime.h │ └── test │ ├── corpus │ ├── ambiguities.txt │ ├── concepts.txt │ ├── declarations.txt │ ├── definitions.txt │ ├── expressions.txt │ ├── microsoft.txt │ ├── statements.txt │ └── types.txt │ └── highlight │ ├── keywords.cpp │ └── names.cpp ├── tokenize.py └── transformations ├── __init__.py ├── block_swap_transformations.py ├── confusion_remove.py ├── dead_code_inserter.py ├── demo_transformation.py ├── for_while_transformation.py ├── no_transform.py ├── operand_swap_transformations.py ├── syntactic_noising_transformation.py ├── transformation_base.py ├── transformation_main.py └── var_renaming_transformation.py /Figures/CE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CocaVul/Coca/e4c6a6b54c2937fed7415dad3ebec9896f3c697b/Figures/CE.png -------------------------------------------------------------------------------- /Figures/Overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CocaVul/Coca/e4c6a6b54c2937fed7415dad3ebec9896f3c697b/Figures/Overview.png -------------------------------------------------------------------------------- /Figures/SupCon.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CocaVul/Coca/e4c6a6b54c2937fed7415dad3ebec9896f3c697b/Figures/SupCon.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | For more information, please refer to <https://unlicense.org> 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Coca: Improving and Explaining Deep Learning-based Vulnerability Detection Systems 3 |
![Overview of Coca](Figures/Overview.png)
4 | 5 |

6 | Deep Learning (DL) models are increasingly integrated into vulnerability detection systems and have achieved remarkable success. However, the lack of explainability poses a critical challenge to deploying black-box models in security-related domains. For this reason, several approaches have been proposed to explain the decision logic of a detection model by providing the set of crucial statements that contribute positively to its predictions. Unfortunately, due to the weak robustness of DL models and the failure to satisfy certain special requirements of security domains, existing explanation approaches are not directly applicable to DL-based vulnerability detection systems. 7 | 8 | In this paper, we propose Coca, a general framework that aims to 1) enhance the robustness of existing neural vulnerability detection models to avoid spurious explanations; and 2) provide both _concise_ and _effective_ explanations to reason about the detected vulnerabilities. Coca consists of two core parts, referred to as _Trainer_ and _Explainer_. The former trains a detection model that is robust to random perturbations via contrastive learning, while the latter builds an explainer that derives the crucial statements most decisive for the detected vulnerability via dual-view causal inference. We apply _Trainer_ to three types of DL-based vulnerability detectors and provide a prototype implementation of _Explainer_ for GNN-based models. Experimental results show that Coca can effectively improve existing DL-based vulnerability detection systems and provide high-quality explanations. 9 | 10 | 11 | ## Prerequisites 12 | 13 | Install the necessary dependencies before running the project: 14 | 15 | ### Environment Requirements 16 | ``` 17 | torch==1.9.0 18 | torchvision==0.10.0 19 | pytorch-lightning==1.4.2 20 | tqdm>=4.62.1 21 | wandb==0.12.0 22 | pytest>=6.2.4 23 | wget>=3.2 24 | split-folders==0.4.3 25 | omegaconf==2.1.1 26 | torchmetrics==0.5.0 27 | joblib>=1.0.1 28 | ``` 29 | 30 | 31 | ### Third-Party Libraries 32 | 33 | - [Joern](https://github.com/joernio/joern) 34 | - [tree-sitter](https://github.com/tree-sitter/tree-sitter) 35 | 36 | ## Datasets 37 | The datasets used in the paper: 38 | 39 | Big-Vul [1]: https://drive.google.com/file/d/1-0VhnHBp9IGh90s2wCNjeCMuy70HPl8X/view?usp=sharing 40 | 41 | Reveal [2]: https://drive.google.com/drive/folders/1KuIYgFcvWUXheDhT--cBALsfy1I4utOy 42 | 43 | Devign [3]: https://drive.google.com/file/d/1x6hoF7G-tSYxg8AFybggypLZgMGDNHfF 44 | 45 | CrossVul [4]: https://zenodo.org/record/4734050 46 | 47 | CVEfixes [5]: https://zenodo.org/record/4476563 48 | 49 | ## t-SNE Visualization 50 | 51 | **(1) DeepWuKong (Standard Cross-Entropy)** 52 |
![t-SNE visualization of DeepWuKong embeddings trained with standard cross-entropy](Figures/CE.png)
53 | 54 |

55 | 56 | **(2) DeepWuKong (Supervised Contrastive Learning)** 57 |
![t-SNE visualization of DeepWuKong embeddings trained with supervised contrastive learning](Figures/SupCon.png)
58 | 59 |

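The figures above project the learned detector embeddings into two dimensions. A minimal sketch of how such a plot can be reproduced (assuming `features` and `labels` are NumPy arrays collected from a trained detector; the plotting script itself is not part of this repository):

```python
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# features: [num_samples, hidden_size] embeddings from a trained detector
# labels:   [num_samples] with 0 = non-vulnerable, 1 = vulnerable
points = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(features)
for label, color, name in [(0, "tab:blue", "non-vulnerable"), (1, "tab:red", "vulnerable")]:
    mask = labels == label
    plt.scatter(points[mask, 0], points[mask, 1], s=4, c=color, label=name)
plt.legend()
plt.savefig("tsne.png", dpi=300)
```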
60 | 61 | 62 | ## References 63 | 64 | [1] Jiahao Fan, Yi Li, Shaohua Wang, and Tien Nguyen. A C/C++ Code Vulnerability Dataset with Code Changes and CVE Summaries. MSR 2020. 65 | 66 | [2] Saikat Chakraborty, Rahul Krishna, Yangruibo Ding, and Baishakhi Ray. Deep Learning based Vulnerability Detection: Are We There Yet? IEEE Transactions on Software Engineering, 2022. 67 | 68 | [3] Yaqin Zhou, Shangqing Liu, Jingkai Siow, Xiaoning Du, and Yang Liu. Devign: Effective Vulnerability Identification by Learning Comprehensive Program Semantics via Graph Neural Networks. NeurIPS 2019. 69 | 70 | [4] Georgios Nikitopoulos, Konstantina Dritsa, Panos Louridas, and Dimitris Mitropoulos. CrossVul: A Cross-Language Vulnerability Dataset with Commit Data. ESEC/FSE 2021. 71 | 72 | [5] Guru Bhandari, Amara Naseer, and Leon Moonen. CVEfixes: Automated Collection of Vulnerabilities and Their Fixes from Open-Source Software. PROMISE 2021. -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoders import encoder_models 2 | from .self_supervised import MocoV2Model, BYOLModel, ssl_models, ssl_models_transforms 3 | 4 | __all__ = [ 5 | "MocoV2Model", 6 | "BYOLModel", 7 | "ssl_models", 8 | "ssl_models_transforms", 9 | "encoder_models" 10 | ] 11 | -------------------------------------------------------------------------------- /models/detectors/DeepWuKong/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | # https://github.com/jumormt/DeepWukong -------------------------------------------------------------------------------- /models/detectors/DeepWuKong/configurations.py: -------------------------------------------------------------------------------- 1 | from tap import Tap 2 | import random 3 | import sys 4 | from global_defines import vul_types, cur_dir, device, num_classes, cur_vul_type_idx 5 | ### DeepWuKong Configuration 6 | 7 | type = vul_types[cur_vul_type_idx] 8 | 9 | 10 | class ModelParser(Tap): 11 | pretrain_word2vec_model: str = f"{cur_dir}/models/{type}/word/w2v_slice.model" 12 | vector_size: int = 128 # dimension of graph node vectors 13 | hidden_size: int = 128 # dimension of GNN hidden states 14 | layer_num: int = 3 # number of GNN layers 15 | rnn_layer_num: int = 1 # number of RNN layers 16 | num_classes: int = num_classes 17 | model_dir = f"{cur_dir}/models/{type}/model/" 18 | device = device 19 | model_name = 'gcn' 20 | detector = 'dwk' 21 | 22 | 23 | class DataParser(Tap): 24 | dataset_dir: str = f'{cur_dir}/datasets/{type}' 25 | 26 | shuffle_data: bool = True # whether to shuffle the dataset 27 | num_workers: int = 8 28 | 29 | random_split: bool = True 30 | batch_size: int = 64 31 | test_batch_size: int = 64 32 | 33 | device = device 34 | num_classes = 2 35 | 36 | 37 | class TrainParser(Tap): 38 | max_epochs: int = 100 39 | early_stopping: int = 5 40 | save_epoch: int = 5 41 | learning_rate: float = 0.002 42 | weight_decay: float = 1.3e-6 43 | 44 | 45 | random.seed(2) 46 | model_args = ModelParser().parse_args(known_only=True) 47 | data_args = DataParser().parse_args(known_only=True) 48 | train_args = TrainParser().parse_args(known_only=True) -------------------------------------------------------------------------------- /models/detectors/DeepWuKong/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch_geometric.nn.conv import GCNConv, gcn_conv 5 | from torch_geometric.data import Batch, Data 6 | from
detectors.DeepWuKong.configurations import model_args 7 | from torch_geometric.typing import Adj, OptTensor 8 | from torch import Tensor 9 | from torch_sparse import SparseTensor 10 | 11 | from detectors.common_model import GlobalMaxMeanPool 12 | 13 | class GCNConvGrad(GCNConv): 14 | def __init__(self, *args, **kwargs): 15 | super().__init__(*args, **kwargs) 16 | self.edge_weight = None 17 | 18 | def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: 19 | """""" 20 | if self.normalize and edge_weight is None: 21 | if isinstance(edge_index, Tensor): 22 | cache = self._cached_edge_index 23 | if cache is None: 24 | edge_index, edge_weight = gcn_conv.gcn_norm( # yapf: disable 25 | edge_index, edge_weight, x.size(self.node_dim), 26 | self.improved, self.add_self_loops, dtype=x.dtype) 27 | if self.cached: 28 | self._cached_edge_index = (edge_index, edge_weight) 29 | else: 30 | edge_index, edge_weight = cache[0], cache[1] 31 | 32 | elif isinstance(edge_index, SparseTensor): 33 | cache = self._cached_adj_t 34 | if cache is None: 35 | edge_index = gcn_conv.gcn_norm( # yapf: disable 36 | edge_index, edge_weight, x.size(self.node_dim), 37 | self.improved, self.add_self_loops, dtype=x.dtype) 38 | if self.cached: 39 | self._cached_adj_t = edge_index 40 | else: 41 | edge_index = cache 42 | 43 | # --- add require_grad --- 44 | edge_weight.requires_grad_(True) 45 | x = torch.matmul(x, self.weight) 46 | # propagate_type: (x: Tensor, edge_weight: OptTensor) 47 | out = self.propagate(edge_index, x=x, edge_weight=edge_weight, 48 | size=None) 49 | if self.bias is not None: 50 | out += self.bias 51 | 52 | # --- My: record edge_weight --- 53 | self.edge_weight = edge_weight 54 | 55 | return out 56 | 57 | 58 | class DeepWuKongModel(nn.Module): 59 | def __init__(self, need_node_emb=False): 60 | super(DeepWuKongModel, self).__init__() 61 | self.need_node_emb = need_node_emb 62 | self.hidden_dim = model_args.hidden_size 63 | cons = [GCNConv(model_args.vector_size, self.hidden_dim)] 64 | cons.extend([ 65 | GCNConv(self.hidden_dim, self.hidden_dim) 66 | for _ in range(model_args.layer_num - 1) 67 | ]) 68 | self.convs = nn.ModuleList(cons) 69 | 70 | self.relus = nn.ModuleList( 71 | [ 72 | nn.ReLU() 73 | for _ in range(model_args.layer_num) 74 | ] 75 | ) 76 | self.readout = GlobalMaxMeanPool() 77 | self.ffn = nn.Sequential(*( 78 | [nn.Linear(self.hidden_dim * 2, self.hidden_dim)] + 79 | [nn.ReLU(), nn.Dropout()] 80 | )) 81 | 82 | self.final_layer = nn.Linear(self.hidden_dim, model_args.num_classes) 83 | self.dropout = nn.Dropout() 84 | 85 | # def generateGraphData(self, sequenceData: List[torch.FloatTensor], edge_index: torch.LongTensor, y: int) -> Data: 86 | # ''' 87 | # :param sequenceData: List of tensor size [seq_length, num_embedding] 88 | # :param edge_index: 89 | # :return: 90 | # ''' 91 | # feature = [] 92 | # for seq_data in sequenceData: 93 | # feature.append(self.gru(seq_data.unsqueeze(dim=0))[0][0, -1, :]) 94 | # feature = torch.stack(feature) 95 | # # feature = pack_sequence(sequenceData, enforce_sorted=False) 96 | # # feature, _ = self.gru(feature.float()) 97 | # # feature, out_len = pad_packed_sequence(feature, batch_first=True) # [node_num, max_seq_len, feature_size] 98 | # # feature = feature[:, -1:, :].squeeze() # [node_num, 2 * hidden_size] 99 | # return Data(x=feature, edge_index=edge_index, y=y) 100 | 101 | # support backwards-based explainers defined in DIG 102 | def arguments_read(self, *args, **kwargs): 103 | data: Batch = kwargs.get('data') or None 104 | 105 | if 
not data: 106 | if not args: 107 | assert 'x' in kwargs 108 | assert 'edge_index' in kwargs 109 | x, edge_index = kwargs['x'], kwargs['edge_index'], 110 | batch = kwargs.get('batch') 111 | if batch is None: 112 | batch = torch.zeros(kwargs['x'].shape[0], dtype=torch.int64, device=x.device) 113 | elif len(args) == 2: 114 | x, edge_index, batch = args[0], args[1], \ 115 | torch.zeros(args[0].shape[0], dtype=torch.int64, device=args[0].device) 116 | elif len(args) == 3: 117 | x, edge_index, batch = args[0], args[1], args[2] 118 | else: 119 | raise ValueError(f"forward's args should take 2 or 3 arguments but got {len(args)}") 120 | else: 121 | x, edge_index, batch = data.x, data.edge_index, data.batch 122 | return x, edge_index, batch 123 | 124 | def forward(self, *args, **kwargs): 125 | """ 126 | :param Required[data]: Batch - input data 127 | :return: 128 | """ 129 | x, edge_index, batch = self.arguments_read(*args, **kwargs) 130 | post_conv = x 131 | for conv, relu in zip(self.convs, self.relus): 132 | post_conv = relu(conv(post_conv, edge_index)) 133 | out_readout = self.readout(post_conv, batch) # [batch_graph, hidden_size] 134 | out = self.ffn(out_readout) 135 | out = self.final_layer(out) 136 | 137 | if self.need_node_emb: 138 | return out, post_conv 139 | else: 140 | return out 141 | 142 | 143 | if __name__ == '__main__': 144 | x1 = torch.ones(size=(5, model_args.vector_size)) 145 | x2 = torch.ones(size=(6, model_args.vector_size)) 146 | edge_index = torch.LongTensor([0, 1]).reshape(2, -1) 147 | y = 1 148 | 149 | dwk_model = DeepWuKongModel() 150 | # generateGraphData is commented out above, so aggregate each token sequence into a single node feature 151 | data: Data = Data(x=torch.stack([x1.mean(dim=0), x2.mean(dim=0)]), edge_index=edge_index, y=y) 152 | result = dwk_model(data=Batch.from_data_list([data])) 153 | print(result) -------------------------------------------------------------------------------- /models/detectors/DeepWuKong/pretrain.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from gensim.models.word2vec import Word2Vec 4 | from gensim.models.callbacks import CallbackAny2Vec 5 | from detectors.DeepWuKong.configurations import data_args, model_args 6 | 7 | 8 | 9 | src_file = os.path.join(data_args.dataset_dir, 'all_xfg.json') 10 | pretrain_word2vec_model_path = model_args.pretrain_word2vec_model 11 | vector_size = model_args.vector_size 12 | window_size = 10 13 | 14 | 15 | class PrintStatus(CallbackAny2Vec): 16 | def __init__(self): 17 | super().__init__() 18 | self.epoch = 0 19 | self.batch = 0 20 | 21 | def on_epoch_begin(self, model): 22 | self.epoch += 1 23 | print(f"epoch {self.epoch} start") 24 | 25 | def on_epoch_end(self, model): 26 | self.batch = 0 27 | print(f"epoch {self.epoch} end") 28 | 29 | def on_batch_begin(self, model): 30 | self.batch += 1 31 | print(f"epoch {self.epoch} - batch {self.batch} start") 32 | 33 | 34 | class Sentences: 35 | def __init__(self): 36 | self.datas: list = json.load(open(src_file, 'r', encoding='utf-8')) 37 | print(len(self.datas)) 38 | 39 | def __iter__(self): 40 | for data in self.datas: 41 | for node_info in data["line-nodes"]: 42 | raw_data = json.loads(node_info) 43 | yield raw_data["contents"][0][1].split(" ") 44 | # for content in data["line-contents"]: 45 | # yield content.split(" ") 46 | 47 | 48 | 49 | 50 | if __name__ == '__main__': 51 | sentences = Sentences() 52 | model = Word2Vec(size=vector_size, window=window_size, hs=1, min_count=1, workers=4) 53 | model.build_vocab(sentences) 54 | model.train(sentences, epochs=20, total_examples=model.corpus_count) 55 |
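    # The saved model is consumed by the detectors via gensim's Word2Vec.load
    # (see e.g. detectors/Reveal/util.py). An illustrative lookup, assuming the
    # gensim 3.x API pinned by this repo (not part of the original script):
    #   w2v = Word2Vec.load(pretrain_word2vec_model_path)
    #   vec = w2v.wv["int"] if "int" in w2v.wv.vocab else None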
model.save(pretrain_word2vec_model_path) -------------------------------------------------------------------------------- /models/detectors/Devign/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | ## https://github.com/saikat107/Devign -------------------------------------------------------------------------------- /models/detectors/Devign/configurations.py: -------------------------------------------------------------------------------- 1 | from tap import Tap 2 | import random 3 | import sys 4 | from global_defines import vul_types, cur_dir, device, num_classes, cur_vul_type_idx 5 | 6 | 7 | ### Devign Configuration 8 | 9 | type = vul_types[cur_vul_type_idx] 10 | 11 | type_map = { 12 | 'AndExpression': 1, 'Sizeof': 2, 'Identifier': 3, 'ForInit': 4, 'ReturnStatement': 5, 'SizeofOperand': 6, 13 | 'InclusiveOrExpression': 7, 'PtrMemberAccess': 8, 'AssignmentExpr': 9, 'ParameterList': 10, 14 | 'IdentifierDeclType': 11, 'SizeofExpr': 12, 'SwitchStatement': 13, 'IncDec': 14, 'Function': 15, 15 | 'BitAndExpression': 16, 'UnaryOp': 17, 'DoStatement': 18, 'GotoStatement': 19, 'Callee': 20, 16 | 'OrExpression': 21, 'ShiftExpression': 22, 'Decl': 23, 'CFGErrorNode': 24, 'WhileStatement': 25, 17 | 'InfiniteForNode': 26, 'RelationalExpression': 27, 'CFGExitNode': 28, 'Condition': 29, 'BreakStatement': 30, 18 | 'CompoundStatement': 31, 'UnaryOperator': 32, 'CallExpression': 33, 'CastExpression': 34, 19 | 'ConditionalExpression': 35, 'ArrayIndexing': 36, 'PostfixExpression': 37, 'Label': 38, 20 | 'ArgumentList': 39, 'EqualityExpression': 40, 'ReturnType': 41, 'Parameter': 42, 'Argument': 43, 'Symbol': 44, 21 | 'ParameterType': 45, 'Statement': 46, 'AdditiveExpression': 47, 'PrimaryExpression': 48, 'DeclStmt': 49, 22 | 'CastTarget': 50, 'IdentifierDeclStatement': 51, 'IdentifierDecl': 52, 'CFGEntryNode': 53, 'TryStatement': 54, 23 | 'Expression': 55, 'ExclusiveOrExpression': 56, 'ClassDef': 57, 'ClassStaticIdentifier': 58, 'ForRangeInit': 59, 24 | 'ClassDefStatement': 60, 'FunctionDef': 61, 'IfStatement': 62, 'MultiplicativeExpression': 63, 25 | 'ContinueStatement': 64, 'MemberAccess': 65, 'ExpressionStatement': 66, 'ForStatement': 67, 'InitializerList': 68, 26 | 'ElseStatement': 69, 'ThrowExpression': 70, 'IncDecOp': 71, 'NewExpression': 72, 'DeleteExpression': 73, 27 | 'BoolExpression': 74, 28 | 'CharExpression': 75, 'DoubleExpression': 76, 'IntegerExpression': 77, 'PointerExpression': 78, 29 | 'StringExpression': 79, 30 | 'ExpressionHolderStatement': 80 31 | } 32 | 33 | 34 | class ModelParser(Tap): 35 | pretrain_word2vec_model: str = f"{cur_dir}/models/{type}/word/w2v.model" 36 | vector_size: int = 100 # dimension of graph node vectors 37 | hidden_size: int = 256 # dimension of GNN hidden states 38 | layer_num: int = 3 # number of GNN layers 39 | num_classes: int = num_classes 40 | model_dir = f"{cur_dir}/models/{type}/model/" 41 | device = device 42 | model_name = 'ggnn' 43 | detector = 'devign' 44 | 45 | 46 | class DataParser(Tap): 47 | dataset_dir: str = f'{cur_dir}/datasets/{type}' 48 | 49 | shuffle_data: bool = True # whether to shuffle the dataset 50 | num_workers: int = 8 51 | 52 | random_split: bool = True 53 | batch_size: int = 64 54 | test_batch_size: int = 64 55 | 56 | device = device 57 | num_classes = 2 58 | 59 | 60 | class TrainParser(Tap): 61 | max_epochs: int = 100 62 | early_stopping: int = 5 63 | save_epoch: int = 5 64 | learning_rate: float = 1e-4 65 | weight_decay: float = 1.3e-6 66 | 67 | 68 | random.seed(2) 69 | model_args = ModelParser().parse_args(known_only=True) 70 | data_args =
DataParser().parse_args(known_only=True) 71 | train_args = TrainParser().parse_args(known_only=True) 72 | -------------------------------------------------------------------------------- /models/detectors/Devign/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | from torch_geometric.nn.conv import GatedGraphConv 5 | from torch_geometric.data import Batch 6 | from detectors.Devign.configurations import model_args, type_map 7 | from detectors.common_model import GlobalMaxPool 8 | 9 | 10 | # Devign classification model 11 | class DevignModel(nn.Module): 12 | def __init__(self, num_layers=1, MLP_hidden_dim=256, need_node_emb=False): 13 | super().__init__() 14 | MLP_internal_dim = int(MLP_hidden_dim / 2) 15 | input_dim = len(type_map) + model_args.vector_size 16 | self.hidden_dim = input_dim 17 | # GGNN layer 18 | self.GGNN = GatedGraphConv(out_channels=input_dim, num_layers=5) 19 | self.readout = GlobalMaxPool() 20 | self.dropout_p = 0.2 21 | self.need_node_emb = need_node_emb 22 | 23 | # MLP layers 24 | self.layer1 = nn.Sequential( 25 | nn.Linear(in_features=input_dim, out_features=MLP_hidden_dim, bias=True), 26 | nn.ReLU(), 27 | nn.Dropout(p=self.dropout_p) 28 | ) 29 | self.feature = nn.ModuleList([nn.Sequential( 30 | nn.Linear(in_features=MLP_hidden_dim, out_features=MLP_internal_dim, bias=True), 31 | nn.ReLU(), 32 | nn.Dropout(p=self.dropout_p), 33 | nn.Linear(in_features=MLP_internal_dim, out_features=MLP_hidden_dim, bias=True), 34 | nn.ReLU(), 35 | nn.Dropout(p=self.dropout_p), 36 | ) for _ in range(num_layers)]) 37 | 38 | self.classifier = nn.Sequential( 39 | nn.Linear(in_features=MLP_hidden_dim, out_features=2), 40 | ) 41 | 42 | def extract_feature(self, x): 43 | out = self.layer1(x) 44 | for layer in self.feature: 45 | out = layer(out) 46 | return out 47 | 48 | def embed_graph(self, x, edge_index, batch): 49 | node_emb = self.GGNN(x, edge_index) 50 | if batch is None: 51 | batch = torch.zeros(x.shape[0], dtype=torch.int64, device=x.device) 52 | graph_emb = self.readout(node_emb, batch) # [batch_size, embedding_dim] 53 | return graph_emb, node_emb 54 | 55 | def arguments_read(self, *args, **kwargs): 56 | data: Batch = kwargs.get('data') or None 57 | 58 | if not data: 59 | if not args: 60 | assert 'x' in kwargs 61 | assert 'edge_index' in kwargs 62 | x, edge_index = kwargs['x'], kwargs['edge_index'], 63 | batch = kwargs.get('batch') 64 | if batch is None: 65 | batch = torch.zeros(kwargs['x'].shape[0], dtype=torch.int64, device=x.device) 66 | elif len(args) == 2: 67 | x, edge_index, batch = args[0], args[1], \ 68 | torch.zeros(args[0].shape[0], dtype=torch.int64, device=args[0].device) 69 | elif len(args) == 3: 70 | x, edge_index, batch = args[0], args[1], args[2] 71 | else: 72 | raise ValueError(f"forward's args should take 2 or 3 arguments but got {len(args)}") 73 | else: 74 | x, edge_index, batch = data.x, data.edge_index, data.batch 75 | return x, edge_index, batch 76 | 77 | def forward(self, *args, **kwargs): 78 | x, edge_index, batch = self.arguments_read(*args, **kwargs) 79 | graph_emb, node_emb = self.embed_graph(x, edge_index, batch) 80 | feature_emb = self.extract_feature(graph_emb) 81 | probs = self.classifier(feature_emb) # [batch_size, 2] 82 | # return node_emb to support PGExplainer 83 | if self.need_node_emb: 84 | return probs, node_emb 85 | else: 86 | return probs 87 | -------------------------------------------------------------------------------- /models/detectors/Devign/util.py:
-------------------------------------------------------------------------------- 1 | from gensim.models import Word2Vec 2 | import numpy as np 3 | import torch 4 | from torch_geometric.data import Data, Batch 5 | 6 | import json 7 | from typing import Dict, List, Tuple, Set 8 | 9 | from detectors.Devign.configurations import type_map, model_args, data_args 10 | from detectors.Devign.model import DevignModel 11 | 12 | 13 | class DevignUtil(object): 14 | def __init__(self, pretrain_model: Word2Vec, devign_model: DevignModel): 15 | self.pretrain_model = pretrain_model 16 | self.devign_model = devign_model 17 | self.arrays = np.eye(80) 18 | 19 | # Generate the initial embedding for every AST node in the graph; during training this function runs only once 20 | # nodeContent[0] is the node type, nodeContent[1] is the token sequence 21 | def generate_initial_astNode_embedding(self, nodeContent: List[str]) -> np.array: 22 | # type vector 23 | n_c = self.arrays[type_map[nodeContent[0]]] 24 | # token sequence 25 | token_seq: List[str] = nodeContent[1].split(' ') 26 | n_v = np.array([self.pretrain_model[word] if word in self.pretrain_model.wv.vocab else 27 | np.zeros(model_args.vector_size) for word in token_seq]).mean(axis=0) 28 | 29 | v = np.concatenate([n_c, n_v]) 30 | return v 31 | 32 | # Generate the initial information for each AST node 33 | # Compared with Reveal, it additionally adds NCS edges between AST terminal nodes 34 | def generate_initial_node_info(self, ast: Dict) -> Data: 35 | astEmbedding: np.array = np.array( 36 | [self.generate_initial_astNode_embedding(node_info) for node_info in ast["contents"]]) 37 | x: torch.FloatTensor = torch.FloatTensor(astEmbedding) 38 | edges: List[List[int]] = [[edge[1], edge[0]] for edge in ast["edges"]] 39 | # Add NCS edges: 40 | # first find all AST terminal nodes in the graph, in ascending index order 41 | # Collect the indices of non-terminal nodes 42 | parent_idxs: Set[int] = set() 43 | terminal_idxs: List[int] = list() 44 | for edge in edges: 45 | parent_idxs.add(edge[1]) 46 | # Collect the indices of terminal nodes 47 | for i in range(len(x)): 48 | if i not in parent_idxs: 49 | terminal_idxs.append(i) 50 | # Add the NCS edges 51 | for i in range(1, len(terminal_idxs)): 52 | edges.append([terminal_idxs[i - 1], terminal_idxs[i]]) 53 | 54 | edge_index: torch.LongTensor = torch.LongTensor(edges).t() 55 | return Batch.from_data_list([Data(x=x, edge_index=edge_index)]).to(device=data_args.device) 56 | 57 | # Preprocess the training data; called only once during training 58 | def generate_initial_training_datas(self, data: Dict) -> Tuple[int, List[Data], torch.LongTensor]: 59 | # label, List[ASTNode], edge 60 | label: int = data["target"] 61 | 62 | # edges 63 | cfgEdges = [json.loads(edge)[:2] for edge in data["cfgEdges"]] 64 | ddgEdges = [json.loads(edge)[:2] for edge in data["ddgEdges"]] 65 | edges = cfgEdges + ddgEdges 66 | edge_index: torch.LongTensor = torch.LongTensor(edges).t() 67 | 68 | # nodes 69 | nodes_info: List[Dict] = [json.loads(node_infos) for node_infos in data["nodes"]] 70 | graph_data_for_each_nodes: List[Data] = [self.generate_initial_node_info(node_info) for node_info in nodes_info] 71 | 72 | return (label, graph_data_for_each_nodes, edge_index) 73 | 74 | # Generate the initial graph embedding; called once per epoch 75 | def generate_initial_graph_embedding(self, graph_info: Tuple[int, List[Data], torch.LongTensor]) -> Data: 76 | # self.devign_model.embed_graph(data)[0] returns the graph embedding for an initial CPG node 77 | initial_embeddings: List[torch.FloatTensor] = [ 78 | self.devign_model.embed_graph(data.x, data.edge_index, None)[0].reshape(-1, ) 79 | # some AST subtrees may have no children; use their value directly as the node embedding 80 | if len(data.edge_index) > 0 else data.x[0] 81 | for data in graph_info[1]] 82 | X: torch.FloatTensor = torch.stack(initial_embeddings) 83 | return Data(x=X, edge_index=graph_info[2],
y=torch.tensor([graph_info[0]], dtype=torch.long)) 84 | -------------------------------------------------------------------------------- /models/detectors/Reveal/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | ## https://github.com/VulDetProject/ReVeal -------------------------------------------------------------------------------- /models/detectors/Reveal/configurations.py: -------------------------------------------------------------------------------- 1 | from tap import Tap 2 | import random 3 | from global_defines import vul_types, cur_dir, device, num_classes, cur_vul_type_idx 4 | 5 | # Reveal Configuration 6 | 7 | type = vul_types[cur_vul_type_idx] 8 | 9 | type_map = { 10 | 'AndExpression': 1, 'Sizeof': 2, 'Identifier': 3, 'ForInit': 4, 'ReturnStatement': 5, 'SizeofOperand': 6, 11 | 'InclusiveOrExpression': 7, 'PtrMemberAccess': 8, 'AssignmentExpr': 9, 'ParameterList': 10, 12 | 'IdentifierDeclType': 11, 'SizeofExpr': 12, 'SwitchStatement': 13, 'IncDec': 14, 'Function': 15, 13 | 'BitAndExpression': 16, 'UnaryOp': 17, 'DoStatement': 18, 'GotoStatement': 19, 'Callee': 20, 14 | 'OrExpression': 21, 'ShiftExpression': 22, 'Decl': 23, 'CFGErrorNode': 24, 'WhileStatement': 25, 15 | 'InfiniteForNode': 26, 'RelationalExpression': 27, 'CFGExitNode': 28, 'Condition': 29, 'BreakStatement': 30, 16 | 'CompoundStatement': 31, 'UnaryOperator': 32, 'CallExpression': 33, 'CastExpression': 34, 17 | 'ConditionalExpression': 35, 'ArrayIndexing': 36, 'PostfixExpression': 37, 'Label': 38, 18 | 'ArgumentList': 39, 'EqualityExpression': 40, 'ReturnType': 41, 'Parameter': 42, 'Argument': 43, 'Symbol': 44, 19 | 'ParameterType': 45, 'Statement': 46, 'AdditiveExpression': 47, 'PrimaryExpression': 48, 'DeclStmt': 49, 20 | 'CastTarget': 50, 'IdentifierDeclStatement': 51, 'IdentifierDecl': 52, 'CFGEntryNode': 53, 'TryStatement': 54, 21 | 'Expression': 55, 'ExclusiveOrExpression': 56, 'ClassDef': 57, 'ClassStaticIdentifier': 58, 'ForRangeInit': 59, 22 | 'ClassDefStatement': 60, 'FunctionDef': 61, 'IfStatement': 62, 'MultiplicativeExpression': 63, 23 | 'ContinueStatement': 64, 'MemberAccess': 65, 'ExpressionStatement': 66, 'ForStatement': 67, 'InitializerList': 68, 24 | 'ElseStatement': 69, 'ThrowExpression': 70, 'IncDecOp': 71, 'NewExpression': 72, 'DeleteExpression': 73, 25 | 'BoolExpression': 74, 26 | 'CharExpression': 75, 'DoubleExpression': 76, 'IntegerExpression': 77, 'PointerExpression': 78, 27 | 'StringExpression': 79, 28 | 'ExpressionHolderStatement': 80 29 | } 30 | 31 | 32 | class ModelParser(Tap): 33 | pretrain_word2vec_model: str = f"{cur_dir}/models/{type}/word/w2v.model" 34 | vector_size: int = 100 # dimension of graph node vectors 35 | hidden_size: int = 256 # dimension of GNN hidden states 36 | layer_num: int = 3 # number of GNN layers 37 | num_classes: int = num_classes 38 | model_dir = f"{cur_dir}/models/{type}/model/" 39 | device = device 40 | model_name = 'ggnn' 41 | detector = 'reveal' 42 | 43 | 44 | class DataParser(Tap): 45 | dataset_dir: str = f'{cur_dir}/datasets/{type}' 46 | 47 | shuffle_data: bool = True # whether to shuffle the dataset 48 | num_workers: int = 8 49 | 50 | random_split: bool = True 51 | batch_size: int = 64 52 | test_batch_size: int = 64 53 | 54 | device = device 55 | num_classes = 2 56 | 57 | 58 | class TrainParser(Tap): 59 | max_epochs: int = 1 60 | early_stopping: int = 5 61 | save_epoch: int = 5 62 | learning_rate: float = 1e-4 63 | weight_decay: float = 1.3e-6 64 | 65 | 66 | random.seed(2) 67 | model_args = ModelParser().parse_args(known_only=True) 68 | data_args =
DataParser().parse_args(known_only=True) 69 | train_args = TrainParser().parse_args(known_only=True) 70 | -------------------------------------------------------------------------------- /models/detectors/Reveal/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import sys 4 | import os 5 | 6 | from torch_geometric.nn.conv import GatedGraphConv 7 | from torch_geometric.data import Batch 8 | from detectors.Reveal.configurations import model_args, type_map 9 | 10 | from detectors.common_model import GlobalAddPool 11 | 12 | 13 | # Reveal classification model 14 | class ClassifyModel(nn.Module): 15 | def __init__(self, num_layers=1, MLP_hidden_dim=256, need_node_emb=False): 16 | super().__init__() 17 | MLP_internal_dim = int(MLP_hidden_dim / 2) 18 | input_dim = len(type_map) + model_args.vector_size 19 | self.hidden_dim = input_dim 20 | # GGNN layer 21 | self.GGNN = GatedGraphConv(out_channels=input_dim, num_layers=5) 22 | self.readout = GlobalAddPool() 23 | self.dropout_p = 0.2 24 | self.need_node_emb = need_node_emb 25 | 26 | # MLP layers 27 | self.layer1 = nn.Sequential( 28 | nn.Linear(in_features=input_dim, out_features=MLP_hidden_dim, bias=True), 29 | nn.ReLU(), 30 | nn.Dropout(p=self.dropout_p) 31 | ) 32 | self.feature = nn.ModuleList([nn.Sequential( 33 | nn.Linear(in_features=MLP_hidden_dim, out_features=MLP_internal_dim, bias=True), 34 | nn.ReLU(), 35 | nn.Dropout(p=self.dropout_p), 36 | nn.Linear(in_features=MLP_internal_dim, out_features=MLP_hidden_dim, bias=True), 37 | nn.ReLU(), 38 | nn.Dropout(p=self.dropout_p), 39 | ) for _ in range(num_layers)]) 40 | 41 | self.classifier = nn.Sequential( 42 | nn.Linear(in_features=MLP_hidden_dim, out_features=2), 43 | ) 44 | 45 | def extract_feature(self, x): 46 | out = self.layer1(x) 47 | for layer in self.feature: 48 | out = layer(out) 49 | return out 50 | 51 | def embed_graph(self, x, edge_index, batch): 52 | node_emb = self.GGNN(x, edge_index) 53 | if batch is None: 54 | batch = torch.zeros(x.shape[0], dtype=torch.int64, device=x.device) 55 | graph_emb = self.readout(node_emb, batch) # [batch_size, embedding_dim] 56 | return graph_emb, node_emb 57 | 58 | def arguments_read(self, *args, **kwargs): 59 | data: Batch = kwargs.get('data') or None 60 | 61 | if not data: 62 | if not args: 63 | assert 'x' in kwargs 64 | assert 'edge_index' in kwargs 65 | x, edge_index = kwargs['x'], kwargs['edge_index'], 66 | batch = kwargs.get('batch') 67 | if batch is None: 68 | batch = torch.zeros(kwargs['x'].shape[0], dtype=torch.int64, device=x.device) 69 | elif len(args) == 2: 70 | x, edge_index, batch = args[0], args[1], \ 71 | torch.zeros(args[0].shape[0], dtype=torch.int64, device=args[0].device) 72 | elif len(args) == 3: 73 | x, edge_index, batch = args[0], args[1], args[2] 74 | else: 75 | raise ValueError(f"forward's args should take 2 or 3 arguments but got {len(args)}") 76 | else: 77 | x, edge_index, batch = data.x, data.edge_index, data.batch 78 | return x, edge_index, batch 79 | 80 | def forward(self, *args, **kwargs): 81 | x, edge_index, batch = self.arguments_read(*args, **kwargs) 82 | graph_emb, node_emb = self.embed_graph(x, edge_index, batch) 83 | feature_emb = self.extract_feature(graph_emb) 84 | probs = self.classifier(feature_emb) # [batch_size, 2] 85 | # return node_emb to support PGExplainer 86 | if self.need_node_emb: 87 | return probs, node_emb 88 | else: 89 | return feature_emb, probs 90 | --------------------------------------------------------------------------------
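A minimal smoke test for the classifier above (an illustrative sketch under assumed graph shapes, not part of the original repository). Note that with the default `need_node_emb=False`, `forward` returns the pair `(feature_emb, probs)`:

```python
import torch
from torch_geometric.data import Batch, Data
from detectors.Reveal.configurations import model_args, type_map
from detectors.Reveal.model import ClassifyModel

model = ClassifyModel()
input_dim = len(type_map) + model_args.vector_size  # node feature width expected by the GGNN
data = Data(x=torch.randn(4, input_dim),
            edge_index=torch.LongTensor([[0, 1, 2], [1, 2, 3]]))
feature_emb, probs = model(data=Batch.from_data_list([data]))
print(feature_emb.shape, probs.shape)  # torch.Size([1, 256]) torch.Size([1, 2])
```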
/models/detectors/Reveal/preprocessing.py: -------------------------------------------------------------------------------- 1 | from CppCodeAnalyzer.extraTools.vuldetect.deepwukong import * 2 | from CppCodeAnalyzer.mainTool.CPG import initialCalleeInfos, CFGToUDGConverter, ASTDefUseAnalyzer 3 | from time import time 4 | import sys 5 | import os 6 | import tqdm 7 | import json 8 | 9 | from global_defines import cur_dir 10 | 11 | 12 | def test_vul(): 13 | start = time() 14 | dataset = f"{cur_dir}/datasets/vulgen/test" 15 | configuration = f"{cur_dir}/detectors/Reveal/calleeInfos.json" 16 | 17 | calleeInfs = json.load(open(configuration, 'r', encoding='utf-8')) 18 | calleeInfos = initialCalleeInfos(calleeInfs) 19 | 20 | astAnalyzer: ASTDefUseAnalyzer = ASTDefUseAnalyzer() 21 | astAnalyzer.calleeInfos = calleeInfos 22 | converter: CFGToUDGConverter = CFGToUDGConverter() 23 | converter.astAnalyzer = astAnalyzer 24 | defUseConverter: CFGAndUDGToDefUseCFG = CFGAndUDGToDefUseCFG() 25 | ddgCreator: DDGCreator = DDGCreator() 26 | 27 | vul_json_content = [] 28 | nor_json_content = [] 29 | json_vul_path = f"{cur_dir}/datasets/test_vul.json" 30 | json_nor_path = f"{cur_dir}/datasets/test_nor.json" 31 | json_vul = open(json_vul_path, mode='w') 32 | json_nor = open(json_nor_path, mode='w') 33 | num_vul_file = len(os.listdir(dataset)) 34 | failure = 0 35 | 36 | with tqdm.tqdm(total=num_vul_file, leave=True, ncols=200, unit_scale=False) as bar: 37 | for root, dirs, files in os.walk(dataset): 38 | for file in files: 39 | if file.endswith("_vul.c"): 40 | bar.update(1) 41 | bar.set_postfix({"Current Processed File": file}) 42 | path = os.path.join(root, file) 43 | try: 44 | cpgs: List[CPG] = fileParse(path, converter, defUseConverter, ddgCreator) 45 | for cpg in cpgs: 46 | # json.dump(cpg.toSerializedJson(), json_file, indent=2) 47 | vul_json_content.append(cpg.toSerializedJson()) 48 | except: 49 | failure += 1 50 | break 51 | """ 52 | if file.endswith("_nonvul.c"): 53 | bar.update(1) 54 | bar.set_postfix({"Current Processed File": file}) 55 | path = os.path.join(root, file) 56 | try: 57 | cpgs: List[CPG] = fileParse(path, converter, defUseConverter, ddgCreator) 58 | for cpg in cpgs: 59 | # json.dump(cpg.toSerializedJson(), json_file, indent=2) 60 | nor_json_content.append(cpg.toSerializedJson()) 61 | except: 62 | failure += 1 63 | break 64 | """ 65 | 66 | json.dump(vul_json_content, json_vul, indent=2) 67 | # json.dump(nor_json_content, json_nor, indent=2) 68 | 69 | end = time() 70 | print(f"Successfully compile {num_vul_file - failure} samples") 71 | print(f"Fail to compile {failure} samples") 72 | print('Total time: {:.2f}s'.format(end - start)) 73 | return 74 | 75 | 76 | def test_nor(): 77 | start = time() 78 | dataset = f"{cur_dir}/datasets/vulgen/test" 79 | configuration = f"{cur_dir}/detectors/Reveal/calleeInfos.json" 80 | 81 | calleeInfs = json.load(open(configuration, 'r', encoding='utf-8')) 82 | calleeInfos = initialCalleeInfos(calleeInfs) 83 | 84 | astAnalyzer: ASTDefUseAnalyzer = ASTDefUseAnalyzer() 85 | astAnalyzer.calleeInfos = calleeInfos 86 | converter: CFGToUDGConverter = CFGToUDGConverter() 87 | converter.astAnalyzer = astAnalyzer 88 | defUseConverter: CFGAndUDGToDefUseCFG = CFGAndUDGToDefUseCFG() 89 | ddgCreator: DDGCreator = DDGCreator() 90 | 91 | save_json_content = [] 92 | json_nor_path = f"{cur_dir}/datasets/test_nor.json" 93 | json_nor = open(json_nor_path, mode='w') 94 | num_vul_file = len(os.listdir(dataset)) 95 | failure = 0 96 | 97 | with tqdm.tqdm(total=num_vul_file, leave=True, 
ncols=200, unit_scale=False) as bar: 98 | for root, dirs, files in os.walk(dataset): 99 | for file in files: 100 | if file.endswith("_nonvul.c"): 101 | bar.update(1) 102 | bar.set_postfix({"Current Processed File": file}) 103 | path = os.path.join(root, file) 104 | try: 105 | cpgs: List[CPG] = fileParse(path, converter, defUseConverter, ddgCreator) 106 | for cpg in cpgs: 107 | # json.dump(cpg.toSerializedJson(), json_file, indent=2) 108 | save_json_content.append(cpg.toSerializedJson()) 109 | except: 110 | failure += 1 111 | break 112 | 113 | json.dump(save_json_content, json_nor, indent=2) 114 | 115 | end = time() 116 | print(f"Successfully compile {num_vul_file - failure} samples") 117 | print(f"Fail to compile {failure} samples") 118 | print('Total time: {:.2f}s'.format(end - start)) 119 | return 120 | 121 | 122 | if __name__ == '__main__': 123 | test_vul() 124 | # test_nor() 125 | -------------------------------------------------------------------------------- /models/detectors/Reveal/pretrain.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from gensim.models.word2vec import Word2Vec 4 | from gensim.models.callbacks import CallbackAny2Vec 5 | from detectors.Reveal.configurations import data_args, model_args 6 | 7 | src_file = os.path.join(data_args.dataset_dir, 'pretrainCorpus.json') 8 | pretrain_word2vec_model_path = model_args.pretrain_word2vec_model 9 | vector_size = model_args.vector_size 10 | window_size = 10 11 | 12 | 13 | class PrintStatus(CallbackAny2Vec): 14 | def __init__(self): 15 | super().__init__() 16 | self.epoch = 0 17 | self.batch = 0 18 | 19 | def on_epoch_begin(self, model): 20 | self.epoch += 1 21 | print(f"epoch {self.epoch} start") 22 | 23 | def on_epoch_end(self, model): 24 | self.batch = 0 25 | print(f"epoch {self.epoch} end") 26 | 27 | def on_batch_begin(self, model): 28 | self.batch += 1 29 | print(f"epoch {self.epoch} - batch {self.batch} start") 30 | 31 | 32 | class Sentences: 33 | def __init__(self): 34 | self.datas: list = json.load(open(src_file, 'r', encoding='utf-8')) 35 | print(len(self.datas)) 36 | 37 | def __iter__(self): 38 | for data in self.datas: 39 | contents = data.split(' ') 40 | yield contents 41 | 42 | 43 | if __name__ == '__main__': 44 | sentences = Sentences() 45 | model = Word2Vec(size=vector_size, window=window_size, hs=1, min_count=1, workers=4) 46 | model.build_vocab(sentences) 47 | model.train(sentences, epochs=20, total_examples=model.corpus_count) 48 | model.save(pretrain_word2vec_model_path) 49 | # import numpy as np 50 | # arr = np.array([np.ones(100), np.zeros(100)]) 51 | # arr = np.concatenate([np.ones(100), np.zeros(20)]) 52 | # print(arr) 53 | # import torch 54 | # v = torch.FloatTensor(arr) 55 | # 56 | # data = torch.stack([torch.ones(size=(100,)), torch.ones(size=(100,))]) 57 | # print(data.size()) 58 | # 399 batch 59 | -------------------------------------------------------------------------------- /models/detectors/Reveal/util.py: -------------------------------------------------------------------------------- 1 | from gensim.models import Word2Vec 2 | import numpy as np 3 | import torch 4 | from torch_geometric.data import Data, Batch 5 | 6 | import json 7 | from typing import Dict, List, Tuple 8 | 9 | from detectors.Reveal.configurations import type_map, model_args, data_args 10 | from detectors.Reveal.model import ClassifyModel 11 | 12 | class RevealUtil(object): 13 | def __init__(self, pretrain_model: Word2Vec, reveal_model: ClassifyModel): 14 |
self.pretrain_model = pretrain_model 15 | self.reveal_model = reveal_model 16 | self.arrays = np.eye(80) 17 | 18 | # Generate the initial embedding for every AST node in the graph; during training this function runs only once 19 | # nodeContent[0] is the node type, nodeContent[1] is the token sequence 20 | def generate_initial_astNode_embedding(self, nodeContent: List[str]) -> np.array: 21 | # type vector 22 | n_c = self.arrays[type_map[nodeContent[0]]] 23 | # token sequence 24 | token_seq: List[str] = nodeContent[1].split(' ') 25 | n_v = np.array([self.pretrain_model[word] if word in self.pretrain_model.wv.vocab else 26 | np.zeros(model_args.vector_size) for word in token_seq]).mean(axis=0) 27 | 28 | v = np.concatenate([n_c, n_v]) 29 | return v 30 | 31 | # Generate the initial information for each AST node 32 | def generate_initial_node_info(self, ast: Dict) -> Data: 33 | astEmbedding: np.array = np.array( 34 | [self.generate_initial_astNode_embedding(node_info) for node_info in ast["contents"]]) 35 | x: torch.FloatTensor = torch.FloatTensor(astEmbedding) 36 | edges: List[List[int]] = [[edge[1], edge[0]] for edge in ast["edges"]] 37 | edge_index: torch.LongTensor = torch.LongTensor(edges).t() 38 | return Batch.from_data_list([Data(x=x, edge_index=edge_index)]).to(device=data_args.device) 39 | 40 | # Preprocess the training data; called only once during training 41 | def generate_initial_training_datas(self, data: Dict) -> Tuple[int, List[Data], torch.LongTensor]: 42 | # label, List[ASTNode], edge 43 | label: int = data["target"] 44 | 45 | # edges 46 | cfgEdges = [json.loads(edge)[:2] for edge in data["cfgEdges"]] 47 | cdgEdges = [json.loads(edge) for edge in data["cdgEdges"]] 48 | ddgEdges = [json.loads(edge)[:2] for edge in data["ddgEdges"]] 49 | edges = cfgEdges + cdgEdges + ddgEdges 50 | edge_index: torch.LongTensor = torch.LongTensor(edges).t() 51 | 52 | # nodes 53 | nodes_info: List[Dict] = [json.loads(node_infos) for node_infos in data["nodes"]] 54 | graph_data_for_each_nodes: List[Data] = [self.generate_initial_node_info(node_info) for node_info in nodes_info] 55 | 56 | return (label, graph_data_for_each_nodes, edge_index) 57 | 58 | # Generate the initial graph embedding; called once per epoch 59 | def generate_initial_graph_embedding(self, graph_info: Tuple[int, List[Data], torch.LongTensor]) -> Data: 60 | # self.reveal_model.embed_graph(data)[0] returns the graph embedding for an initial CPG node 61 | initial_embeddings: List[torch.FloatTensor] = [ 62 | self.reveal_model.embed_graph(data.x, data.edge_index, None)[0].reshape(-1, ) 63 | # some AST subtrees may have no children; use their value directly as the node embedding 64 | if len(data.edge_index) > 0 else data.x[0] 65 | for data in graph_info[1]] 66 | X: torch.FloatTensor = torch.stack(initial_embeddings) 67 | return Data(x=X, edge_index=graph_info[2], y=torch.tensor([graph_info[0]], dtype=torch.long)) 68 | 69 | 70 | if __name__ == '__main__': 71 | pretrain_model = Word2Vec.load(model_args.pretrain_word2vec_model) 72 | reveal_model: ClassifyModel = ClassifyModel().to(model_args.device) 73 | reveal_util = RevealUtil(pretrain_model, reveal_model) 74 | -------------------------------------------------------------------------------- /models/detectors/common_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_geometric.nn import global_add_pool, global_max_pool, global_mean_pool 4 | 5 |
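# Illustration (an assumption, not repo code) of what the pooling wrappers below compute.
# GlobalMaxMeanPool concatenates max- and mean-pooled node states, so the readout width
# doubles -- which is why DeepWuKongModel's ffn takes hidden_dim * 2 as input:
#   x = torch.randn(5, 8)                  # 5 nodes across the batch, hidden size 8
#   batch = torch.tensor([0, 0, 0, 1, 1])  # node-to-graph assignment: two graphs
#   out = torch.cat((global_max_pool(x, batch), global_mean_pool(x, batch)), dim=1)
#   out.shape  # torch.Size([2, 16]) == [num_graphs, 2 * hidden]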
6 | # suit the API in DIG/xgraph 7 | class GNNPool(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | 12 | class GlobalAddPool(GNNPool): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def forward(self, x, batch): 17 | return global_add_pool(x, batch) 18 | 19 | 20 | class GlobalMaxPool(GNNPool): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x, batch): 25 | return global_max_pool(x, batch) 26 | 27 | 28 | class GlobalMaxMeanPool(GNNPool): 29 | def __init__(self): 30 | super().__init__() 31 | 32 | def forward(self, x, batch): 33 | return torch.cat((global_max_pool(x, batch), global_mean_pool(x, batch)), dim=1) 34 | -------------------------------------------------------------------------------- /models/self_supervised/__init__.py: -------------------------------------------------------------------------------- 1 | from .byol import BYOLModel, BYOLTransform 2 | from .moco import MocoV2Model 3 | from .simclr import SimCLRModel, SimCLRTransform 4 | from .swav import SwAVModel 5 | 6 | __all__ = [ 7 | "MocoV2Model", 8 | "BYOLModel", 9 | "SimCLRModel", 10 | "SwAVModel", 11 | "ssl_models", 12 | "ssl_models_transforms" 13 | ] 14 | 15 | ssl_models = { 16 | "MocoV2": MocoV2Model, 17 | "BYOL": BYOLModel, 18 | "SimCLR": SimCLRModel, 19 | "SwAV": SwAVModel 20 | } 21 | 22 | ssl_models_transforms = { 23 | "BYOL": BYOLTransform, 24 | "SimCLR": SimCLRTransform 25 | } 26 | -------------------------------------------------------------------------------- /models/self_supervised/moco.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from omegaconf import DictConfig 6 | from pl_bolts.models.self_supervised import Moco_v2 7 | 8 | from models.self_supervised.utils import ( 9 | validation_metrics, 10 | init_model, 11 | roc_auc, 12 | configure_optimizers, compute_num_samples 13 | ) 14 | 15 | 16 | class MocoV2Model(Moco_v2): 17 | def __init__( 18 | self, 19 | config: DictConfig, 20 | **kwargs 21 | ): 22 | self.save_hyperparameters() 23 | self.config = config 24 | 25 | super().__init__( 26 | base_encoder=config.name, 27 | emb_dim=config.num_classes, 28 | num_negatives=config.ssl.num_negatives, 29 | encoder_momentum=config.ssl.encoder_momentum, 30 | softmax_temperature=config.ssl.softmax_temperature, 31 | learning_rate=config.ssl.learning_rate, 32 | weight_decay=config.ssl.weight_decay, 33 | use_mlp=config.ssl.use_mlp, 34 | batch_size=config.hyper_parameters.batch_size, 35 | **kwargs 36 | ) 37 | 38 | # create the validation queue 39 | self.register_buffer("labels_queue", torch.zeros(config.ssl.num_negatives).long() - 1) 40 | 41 | train_data_path = join( 42 | config.data_folder, 43 | config.dataset.name, 44 | "raw", 45 | config.train_holdout 46 | ) 47 | 48 | self.train_iters_per_epoch = compute_num_samples(train_data_path) // self.config.hyper_parameters.batch_size 49 | 50 | def init_encoders(self, base_encoder: str): 51 | encoder_q = init_model(self.config) 52 | encoder_k = init_model(self.config) 53 | return encoder_q, encoder_k 54 | 55 | def forward(self, x): 56 | x = self.encoder_q(x) 57 | x = F.normalize(x, dim=1) 58 | return x 59 | 60 | @torch.no_grad() 61 | def _dequeue_and_enqueue(self, keys, labels): 62 | # gather keys before updating queue 63 | 64 | batch_size = keys.shape[0] 65 | 66 | ptr = int(self.queue_ptr) 67 | assert self.hparams.num_negatives % batch_size == 0 # for simplicity 68 | 69 | # replace the keys at ptr (dequeue 
and enqueue) 70 | self.queue[:, ptr:ptr + batch_size] = keys.T 71 | self.labels_queue[ptr:ptr + batch_size] = labels 72 | ptr = (ptr + batch_size) % self.hparams.num_negatives # move pointer 73 | 74 | self.queue_ptr[0] = ptr 75 | 76 | def representation(self, q, k): 77 | # compute query features 78 | q = self.encoder_q(q) # queries: NxC 79 | q = F.normalize(q, dim=1) 80 | 81 | # compute key features 82 | with torch.no_grad(): # no gradient to keys 83 | k = self.encoder_k(k) # keys: NxC 84 | k = F.normalize(k, dim=1) 85 | return q, k 86 | 87 | def uni_con(self, logits, target): 88 | sum_neg = ((1 - target) * torch.exp(logits)).sum(1) 89 | sum_pos = (target * torch.exp(-logits)).sum(1) 90 | loss = torch.log(1 + sum_neg * sum_pos) 91 | return torch.mean(loss) 92 | 93 | def _loss(self, q, k, labels, queue, labels_queue): 94 | # Einstein sum is more intuitive 95 | # positive logits: Nx1 96 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 97 | # negative logits: NxK 98 | l_neg = torch.einsum('nc,ck->nk', [q, queue.clone().detach()]) 99 | 100 | # logits: Nx(1+K) 101 | logits = torch.cat([l_pos, l_neg], dim=1) 102 | 103 | # apply temperature 104 | logits /= self.hparams.softmax_temperature 105 | 106 | batch_size, *_ = q.shape 107 | # positive label for the augmented version 108 | target_aug = torch.ones((batch_size, 1), device=q.device) 109 | # comparing the query label with l_que 110 | target_que = torch.eq(labels.reshape(-1, 1), labels_queue) 111 | target_que = target_que.float() 112 | # labels: Nx(1+K) 113 | target = torch.cat([target_aug, target_que], dim=1) 114 | # calculate the contrastive loss, Eqn.(7) 115 | loss = self.uni_con(logits=logits, target=target) 116 | return loss 117 | 118 | def training_step(self, batch, batch_idx): 119 | (q, k), labels = batch 120 | 121 | # update the key encoder 122 | self._momentum_update_key_encoder() 123 | 124 | queries, keys = self.representation(q=q, k=k) 125 | loss = self._loss( 126 | q=queries, 127 | k=keys, 128 | labels=labels, 129 | queue=self.queue, 130 | labels_queue=self.labels_queue 131 | ) 132 | 133 | # dequeue and enqueue 134 | self._dequeue_and_enqueue(keys, labels) 135 | 136 | roc_auc_ = roc_auc(queries=queries, keys=keys, labels=labels) 137 | self.log_dict({"train_loss": loss, "train_roc_auc": roc_auc_}) 138 | return loss 139 | 140 | def validation_step(self, batch, batch_idx): 141 | features, labels = batch 142 | features = self(features) 143 | labels = labels.contiguous().view(-1, 1) 144 | 145 | return {"features": features, "labels": labels} 146 | 147 | def validation_epoch_end(self, outputs): 148 | log = validation_metrics(outputs, task=self.config.dataset.name) 149 | self.log_dict(log) 150 | 151 | def test_step(self, batch, batch_idx): 152 | return self.validation_step(batch=batch, batch_idx=batch_idx) 153 | 154 | def test_epoch_end(self, outputs): 155 | self.validation_epoch_end(outputs=outputs) 156 | 157 | def configure_optimizers(self): 158 | return configure_optimizers( 159 | self, 160 | learning_rate=self.config.ssl.learning_rate, 161 | weight_decay=self.config.ssl.weight_decay, 162 | warmup_epochs=self.config.ssl.warmup_epochs, 163 | max_epochs=self.config.hyper_parameters.n_epochs, 164 | exclude_bn_bias=self.config.ssl.exclude_bn_bias, 165 | train_iters_per_epoch=self.train_iters_per_epoch 166 | ) 167 | -------------------------------------------------------------------------------- /models/self_supervised/simclr.py: -------------------------------------------------------------------------------- 1 | from os.path import 
join 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from omegaconf import DictConfig 6 | from pl_bolts.models.self_supervised import SimCLR 7 | 8 | from models.self_supervised.utils import ( 9 | validation_metrics, 10 | prepare_features, 11 | clone_classification_step, 12 | compute_num_samples, 13 | init_model, 14 | roc_auc, configure_optimizers 15 | ) 16 | 17 | 18 | class SimCLRModel(SimCLR): 19 | def __init__( 20 | self, 21 | config: DictConfig, 22 | **kwargs 23 | ): 24 | self.save_hyperparameters() 25 | self.config = config 26 | self.base_encoder = config.name 27 | train_data_path = join( 28 | config.data_folder, 29 | config.dataset.name, 30 | "raw", 31 | config.train_holdout 32 | ) 33 | 34 | num_samples = compute_num_samples(train_data_path) 35 | 36 | super().__init__( 37 | gpus=config.ssl.gpus, 38 | num_nodes=config.ssl.num_nodes, 39 | batch_size=config.hyper_parameters.batch_size, 40 | max_epochs=config.hyper_parameters.n_epochs, 41 | hidden_mlp=config.num_classes, 42 | feat_dim=config.num_classes, 43 | temperature=config.ssl.temperature, 44 | warmup_epochs=config.ssl.warmup_epochs, 45 | start_lr=config.ssl.start_lr, 46 | learning_rate=config.ssl.learning_rate, 47 | weight_decay=config.ssl.weight_decay, 48 | exclude_bn_bias=config.ssl.exclude_bn_bias, 49 | num_samples=num_samples, 50 | dataset="", 51 | **kwargs 52 | ) 53 | 54 | def init_model(self): 55 | encoder = init_model(self.config) 56 | return encoder 57 | 58 | def forward(self, x): 59 | x = self.encoder(x) 60 | x = F.normalize(x, dim=1) 61 | return x 62 | 63 | def _loss(self, logits, mask): 64 | batch_size = mask.shape[0] // 2 65 | 66 | # compute logits 67 | anchor_dot_contrast = logits / self.temperature 68 | 69 | # for numerical stability 70 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 71 | logits = anchor_dot_contrast - logits_max.detach() 72 | 73 | # mask-out self-contrast cases 74 | logits_mask = torch.scatter( 75 | torch.ones_like(mask, device=self.device), 76 | 1, 77 | torch.arange(2 * batch_size, device=self.device).view(-1, 1), 78 | 0 79 | ) 80 | mask_ = mask * logits_mask 81 | 82 | # compute log_prob 83 | exp_logits = torch.exp(logits) * logits_mask 84 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 85 | 86 | # compute mean of log-likelihood over positive 87 | mean_log_prob_pos = (mask_ * log_prob).sum(1) / mask_.sum(1) 88 | loss = -mean_log_prob_pos.view(2, batch_size).mean() 89 | 90 | return loss 91 | 92 | def training_step(self, batch, batch_idx): 93 | (q, k, _), labels = batch 94 | queries, keys = self(q), self(k) 95 | 96 | # get z representations 97 | z1 = self.projection(queries) 98 | z2 = self.projection(keys) 99 | 100 | embeddings, loss_labels = prepare_features(z1, z2, labels) 101 | loss_logits, loss_mask = clone_classification_step(embeddings, loss_labels) 102 | loss = self._loss(loss_logits, loss_mask) 103 | 104 | roc_auc_ = roc_auc(queries=queries, keys=keys, labels=labels) 105 | self.log_dict({"train_loss": loss, "train_roc_auc": roc_auc_}) 106 | return loss 107 | 108 | def validation_step(self, batch, batch_idx): 109 | features, labels = batch 110 | features = self(features) 111 | labels = labels.contiguous().view(-1, 1) 112 | 113 | return {"features": features, "labels": labels} 114 | 115 | def validation_epoch_end(self, outputs): 116 | log = validation_metrics(outputs, task=self.config.dataset.name) 117 | self.log_dict(log) 118 | 119 | def test_step(self, batch, batch_idx): 120 | return self.validation_step(batch=batch, batch_idx=batch_idx) 121 | 
122 | def test_epoch_end(self, outputs): 123 | self.validation_epoch_end(outputs=outputs) 124 | 125 | def configure_optimizers(self): 126 | return configure_optimizers( 127 | self, 128 | learning_rate=self.config.ssl.learning_rate, 129 | weight_decay=self.config.ssl.weight_decay, 130 | warmup_epochs=self.config.ssl.warmup_epochs, 131 | max_epochs=self.config.hyper_parameters.n_epochs, 132 | exclude_bn_bias=self.config.ssl.exclude_bn_bias, 133 | train_iters_per_epoch=self.train_iters_per_epoch 134 | ) 135 | 136 | 137 | class SimCLRTransform: 138 | def __call__(self, batch): 139 | (x1, x2), y = batch 140 | return (x1, x2, None), y 141 | -------------------------------------------------------------------------------- /models/self_supervised/utils.py: -------------------------------------------------------------------------------- 1 | from os import listdir 2 | from os.path import isdir, join 3 | 4 | import torch 5 | from code2seq.data.vocabulary import Vocabulary 6 | from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay 7 | from torch.optim import Adam 8 | from torch_cluster import knn 9 | from torchmetrics.functional import auroc 10 | 11 | from models import encoder_models 12 | 13 | 14 | def exclude_from_wt_decay(named_params, weight_decay, skip_list=("bias", "bn")): 15 | params = [] 16 | excluded_params = [] 17 | 18 | for name, param in named_params: 19 | if not param.requires_grad: 20 | continue 21 | elif any(layer_name in name for layer_name in skip_list): 22 | excluded_params.append(param) 23 | else: 24 | params.append(param) 25 | 26 | return [{"params": params, "weight_decay": weight_decay}, {"params": excluded_params, "weight_decay": 0.0}] 27 | 28 | 29 | def configure_optimizers( 30 | model, 31 | learning_rate: float, 32 | weight_decay: float, 33 | warmup_epochs: int, 34 | max_epochs: int, 35 | exclude_bn_bias: bool, 36 | train_iters_per_epoch: int 37 | ): 38 | if exclude_bn_bias: 39 | params = exclude_from_wt_decay(model.named_parameters(), weight_decay=weight_decay) 40 | else: 41 | params = model.parameters() 42 | 43 | optimizer = Adam(params, lr=learning_rate, weight_decay=weight_decay) 44 | 45 | warmup_steps = train_iters_per_epoch * warmup_epochs 46 | total_steps = train_iters_per_epoch * max_epochs 47 | 48 | scheduler = { 49 | "scheduler": torch.optim.lr_scheduler.LambdaLR( 50 | optimizer, 51 | linear_warmup_decay(warmup_steps, total_steps, cosine=True), 52 | ), 53 | "interval": "step", 54 | "frequency": 1, 55 | } 56 | return [optimizer], [scheduler] 57 | 58 | 59 | def init_model(config): 60 | if config.name in ["transformer", "gnn", "code-transformer"]: 61 | encoder = encoder_models[config.name](config) 62 | elif config.name == "code2class": 63 | _vocabulary = Vocabulary( 64 | join( 65 | config.data_folder, 66 | config.dataset.name, 67 | config.dataset.dir, 68 | config.vocabulary_name 69 | ), 70 | config.dataset.max_labels, 71 | config.dataset.max_tokens 72 | ) 73 | encoder = encoder_models[config.name](config=config, vocabulary=_vocabulary) 74 | else: 75 | raise ValueError(f"Unknown model: {config.name}") 76 | return encoder 77 | 78 | 79 | @torch.no_grad() 80 | def roc_auc(queries, keys, labels): 81 | features, labels = prepare_features(queries, keys, labels) 82 | logits, mask = clone_classification_step(features, labels) 83 | logits = scale(logits) 84 | logits = logits.reshape(-1) 85 | mask = mask.reshape(-1) 86 | 87 | return auroc(logits, mask) 88 | 89 | 90 | def compute_f1(conf_matrix): 91 | assert conf_matrix.shape == (2, 2) 92 | tn, fn, fp, tp = 
conf_matrix.reshape(-1).tolist() 93 | f1 = tp / (tp + 0.5 * (fp + fn)) 94 | return f1 95 | 96 | 97 | def compute_map_at_k(preds): 98 | avg_precisions = [] 99 | 100 | k = preds.shape[1] 101 | for pred in preds: 102 | positions = torch.arange(1, k + 1, device=preds.device)[pred > 0] 103 | if positions.shape[0]: 104 | avg = torch.arange(1, positions.shape[0] + 1, device=positions.device) / positions 105 | avg_precisions.append(avg.sum() / k) 106 | else: 107 | avg_precisions.append(torch.tensor(0.0, device=preds.device)) 108 | return torch.stack(avg_precisions).mean().item() 109 | 110 | 111 | def validation_metrics(outputs, task: str = "poj_104"): 112 | features = torch.cat([out["features"] for out in outputs]) 113 | _, hidden_size = features.shape 114 | 115 | labels = torch.cat([out["labels"] for out in outputs]).reshape(-1) 116 | 117 | if task == "poj_104": 118 | ks = [100, 200, 500] 119 | elif task == "codeforces": 120 | ks = [5, 10, 15] 121 | else: 122 | raise ValueError(f"Unknown task {task}") 123 | 124 | logs = {} 125 | for k in ks: 126 | if k < labels.shape[0]: 127 | top_ids = knn(x=features, y=features, k=k + 1) 128 | top_ids = top_ids[1, :].reshape(-1, k + 1) 129 | top_ids = top_ids[:, 1:] 130 | 131 | top_labels = labels[top_ids] 132 | preds = torch.eq(top_labels, labels.reshape(-1, 1)) 133 | logs[f"val_map@{k}"] = compute_map_at_k(preds) 134 | return logs 135 | 136 | 137 | def clone_classification_step(features, labels): 138 | logits = torch.matmul(features, features.T) 139 | mask = torch.eq(labels, labels.T) 140 | return logits, mask 141 | 142 | 143 | def prepare_features(queries, keys, labels): 144 | features = torch.cat([queries, keys], dim=0) 145 | labels = labels.contiguous().view(-1, 1) 146 | labels = labels.repeat(2, 1) 147 | return features, labels 148 | 149 | 150 | def scale(x): 151 | x = torch.clamp(x, min=-1, max=1) 152 | return (x + 1) / 2 153 | 154 | 155 | def compute_num_samples(train_data_path: str): 156 | num_samples = 0 157 | for class_ in listdir(train_data_path): 158 | class_path = join(train_data_path, class_) 159 | if isdir(class_path): 160 | num_files = len([_ for _ in listdir(class_path)]) 161 | num_samples += num_files * (num_files - 1) // 2 162 | return num_samples 163 | -------------------------------------------------------------------------------- /models/vulexplainer/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | cur_detector = "ivdetect" 3 | slice_level: bool = (cur_detector == "deepwukong") 4 | 5 | data_path = { 6 | "reveal": "function/explain_reveal.json", 7 | "devign": "function/explain_devign.json", 8 | "ivdetect": "function/explain_ivdetect.json", 9 | "deepwukong": "slice/explain_deepwukong.json" 10 | } -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from .code2class_vocab import build_code2seq_vocab, process_astminer_csv 2 | from .graphs_vocab import build_graphs_vocab 3 | from .joern import process_graphs 4 | from .tokenize import tokenize 5 | 6 | __all__ = [ 7 | "tokenize", 8 | "process_graphs", 9 | "process_astminer_csv", 10 | "build_graphs_vocab", 11 | "build_code2seq_vocab", 12 | ] 13 | -------------------------------------------------------------------------------- /preprocess/code2class_vocab.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from collections import Counter 3 | from os import 
remove
4 | from os.path import join, dirname, exists
5 | from random import shuffle
6 | from typing import Counter as CounterType, Type, Dict
7 | 
8 | from code2seq.data.vocabulary import Vocabulary
9 | from commode_utils.filesystem import count_lines_in_file
10 | from commode_utils.vocabulary import BaseVocabulary
11 | from tqdm import tqdm
12 | 
13 | 
14 | def build_code2seq_vocab(
15 |     train_data: str,
16 |     test_data: str,
17 |     val_data: str,
18 |     vocabulary_cls: Type[BaseVocabulary] = Vocabulary
19 | ):
20 |     counters: Dict[str, CounterType[str]] = {
21 |         key: Counter() for key in [vocabulary_cls.LABEL, vocabulary_cls.TOKEN, vocabulary_cls.NODE]
22 |     }
23 |     with open(train_data, "r") as f_in:
24 |         for raw_sample in tqdm(f_in, total=count_lines_in_file(train_data)):
25 |             vocabulary_cls.process_raw_sample(raw_sample, counters)
26 | 
27 |     for data in [test_data, val_data]:
28 |         with open(data, "r") as f_in:
29 |             for raw_sample in tqdm(f_in, total=count_lines_in_file(data)):
30 |                 label, *_ = raw_sample.split(" ")
31 |                 counters[vocabulary_cls.LABEL].update(label.split(vocabulary_cls._separator))
32 | 
33 |     for feature, counter in counters.items():
34 |         print(f"Count {len(counter)} {feature}, top-5: {counter.most_common(5)}")
35 | 
36 |     dataset_dir = dirname(train_data)
37 |     vocabulary_file = join(dataset_dir, vocabulary_cls.vocab_filename)
38 |     with open(vocabulary_file, "wb") as f_out:
39 |         pickle.dump(counters, f_out)
40 | 
41 | 
42 | def _get_id2value_from_csv(path_: str) -> Dict[str, str]:
43 |     with open(path_, "r") as f:
44 |         lines = f.read().strip().split("\n")[1:]
45 |         lines = [line.split(",", maxsplit=1) for line in lines]
46 |         return {k: v for k, v in lines}
47 | 
48 | 
49 | def process_astminer_csv(data_folder: str, dataset_name: str, holdout_name: str, is_shuffled: bool):
50 |     """
51 |     Convert astminer CSV output (tokens, node_types, paths, path_contexts) for one holdout into a single .c2s file.
52 |     """
53 |     dataset_path = join(data_folder, dataset_name)
54 |     id_to_token_data_path = join(dataset_path, f"tokens.{holdout_name}.csv")
55 |     id_to_type_data_path = join(dataset_path, f"node_types.{holdout_name}.csv")
56 |     id_to_paths_data_path = join(dataset_path, f"paths.{holdout_name}.csv")
57 |     path_contexts_path = join(dataset_path, f"path_contexts.{holdout_name}.csv")
58 |     output_c2s_path = join(dataset_path, f"{dataset_name}.{holdout_name}.c2s")
59 | 
60 |     id_to_paths_stored = _get_id2value_from_csv(id_to_paths_data_path)
61 |     id_to_paths = {index: [n for n in nodes.split()] for index, nodes in id_to_paths_stored.items()}
62 | 
63 |     id_to_node_types = _get_id2value_from_csv(id_to_type_data_path)
64 |     id_to_node_types = {index: node_type.rsplit(" ", maxsplit=1)[0] for index, node_type in
65 |                         id_to_node_types.items()}
66 | 
67 |     id_to_tokens = _get_id2value_from_csv(id_to_token_data_path)
68 | 
69 |     if exists(output_c2s_path):
70 |         remove(output_c2s_path)
71 |     with open(path_contexts_path, "r") as path_contexts_file, open(output_c2s_path, "a+") as c2s_output:
72 |         output_lines = []
73 |         for line in tqdm(path_contexts_file, total=count_lines_in_file(path_contexts_path)):
74 |             label, *path_contexts = line.split()
75 |             parsed_line = [label]
76 |             for path_context in path_contexts:
77 |                 from_token_id, path_types_id, to_token_id = path_context.split(",")
78 |                 from_token, to_token = id_to_tokens[from_token_id], id_to_tokens[to_token_id]
79 |                 if (" " in from_token) or (" " in to_token):  # tokens with spaces would corrupt the space-separated c2s format
80 |                     continue
81 |                 nodes = [id_to_node_types[p_] for p_ in id_to_paths[path_types_id]]
82 |                 # likewise, skip path contexts whose node types contain spaces
83 |                 if any(" " in node for node in nodes):
84 |                     continue
85 | 
parsed_line.append(",".join([from_token, "|".join(nodes), to_token])) 86 | output_lines.append(" ".join(parsed_line + ["\n"])) 87 | if is_shuffled: 88 | shuffle(output_lines) 89 | c2s_output.write("".join(output_lines)) 90 | -------------------------------------------------------------------------------- /preprocess/graphs_vocab.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os import listdir 3 | from os.path import join, isdir, isfile 4 | 5 | from omegaconf import DictConfig 6 | from tqdm import tqdm 7 | 8 | 9 | def is_json_file(path: str): 10 | ext = path.rsplit(".", 1)[-1] 11 | return isfile(path) and (ext == "json") 12 | 13 | 14 | def build_graphs_vocab(config: DictConfig): 15 | graphs_storage = join(config.data_folder, config.dataset.name, config.dataset.dir) 16 | 17 | edges_types = set() 18 | vertexes_types = set() 19 | vertexes_names = set() 20 | 21 | holdout_path = join(graphs_storage, config.train_holdout) 22 | 23 | for class_ in tqdm(listdir(holdout_path)): 24 | class_path = join(holdout_path, class_) 25 | if isdir(class_path): 26 | paths = [ 27 | join(class_path, file) for file in listdir(class_path) if is_json_file(join(class_path, file)) 28 | ] 29 | 30 | for graph_path in tqdm(paths): 31 | with open(graph_path, "r") as f: 32 | graph = json.load(f) 33 | e = json.loads(graph["edges"]) 34 | v = json.loads(graph["vertexes"]) 35 | vertexes_types.update(set(v_["label"] for v_ in v)) 36 | vertexes_names.update(set(v_["name"] for v_ in v)) 37 | edges_types.update(set(e_["label"] for e_ in e)) 38 | 39 | vertexes_types.add("UNKNOWN") 40 | vertexes_types = sorted(list(vertexes_types)) 41 | edges_types.add("UNKNOWN") 42 | edges_types = sorted(list(edges_types)) 43 | vocab = { 44 | "v_type2id": {v_type: id_ for id_, v_type in enumerate(vertexes_types)}, 45 | "v_name2id": {v_name: id_ for id_, v_name in enumerate(vertexes_names)}, 46 | "e_type2id": {e_type: id_ for id_, e_type in enumerate(edges_types)} 47 | } 48 | 49 | with open(join(graphs_storage, config.dataset.vocab_file), "w") as vocab_f: 50 | json.dump(vocab, vocab_f) 51 | -------------------------------------------------------------------------------- /preprocess/parser/languages.exp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CocaVul/Coca/e4c6a6b54c2937fed7415dad3ebec9896f3c697b/preprocess/parser/languages.exp -------------------------------------------------------------------------------- /preprocess/parser/languages.lib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CocaVul/Coca/e4c6a6b54c2937fed7415dad3ebec9896f3c697b/preprocess/parser/languages.lib -------------------------------------------------------------------------------- /preprocess/parser/languages.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CocaVul/Coca/e4c6a6b54c2937fed7415dad3ebec9896f3c697b/preprocess/parser/languages.so -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tree-sitter-c" 3 | description = "C grammar for the tree-sitter parsing library" 4 | version = "0.20.2" 5 | authors = ["Max Brunsfeld "] 6 | license = "MIT" 7 | readme = "bindings/rust/README.md" 8 | keywords = ["incremental", "parsing", "c"] 9 | 
categories = ["parsing", "text-editors"] 10 | repository = "https://github.com/tree-sitter/tree-sitter-c" 11 | edition = "2018" 12 | 13 | build = "bindings/rust/build.rs" 14 | include = ["bindings/rust/*", "grammar.js", "queries/*", "src/*"] 15 | 16 | [lib] 17 | path = "bindings/rust/lib.rs" 18 | 19 | [dependencies] 20 | tree-sitter = "0.20" 21 | 22 | [build-dependencies] 23 | cc = "1.0" 24 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Max Brunsfeld 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/Package.swift: -------------------------------------------------------------------------------- 1 | // swift-tools-version:5.3 2 | import PackageDescription 3 | 4 | let package = Package( 5 | name: "TreeSitterC", 6 | platforms: [.macOS(.v10_13), .iOS(.v11)], 7 | products: [ 8 | .library(name: "TreeSitterC", targets: ["TreeSitterC"]), 9 | ], 10 | dependencies: [], 11 | targets: [ 12 | .target(name: "TreeSitterC", 13 | path: ".", 14 | exclude: [ 15 | "binding.gyp", 16 | "bindings", 17 | "Cargo.toml", 18 | "examples", 19 | "grammar.js", 20 | "LICENSE", 21 | "Makefile", 22 | "package.json", 23 | "README.md", 24 | "src/grammar.json", 25 | "src/node-types.json", 26 | ], 27 | sources: [ 28 | "src/parser.c", 29 | ], 30 | resources: [ 31 | .copy("queries") 32 | ], 33 | publicHeadersPath: "bindings/swift", 34 | cSettings: [.headerSearchPath("src")]) 35 | ] 36 | ) -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/README.md: -------------------------------------------------------------------------------- 1 | tree-sitter-c 2 | ================== 3 | 4 | [![Build Status](https://travis-ci.org/tree-sitter/tree-sitter-c.svg?branch=master)](https://travis-ci.org/tree-sitter/tree-sitter-c) 5 | [![Build status](https://ci.appveyor.com/api/projects/status/7u0sy6ajmxro4wfh/branch/master?svg=true)](https://ci.appveyor.com/project/maxbrunsfeld/tree-sitter-c/branch/master) 6 | 7 | C grammar for [tree-sitter](https://github.com/tree-sitter/tree-sitter). Adapted from [this C99 grammar](http://slps.github.io/zoo/c/iso-9899-tc3.html). 
8 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/binding.gyp: -------------------------------------------------------------------------------- 1 | { 2 | "targets": [ 3 | { 4 | "target_name": "tree_sitter_c_binding", 5 | "include_dirs": [ 6 | " 3 | #include "nan.h" 4 | 5 | using namespace v8; 6 | 7 | extern "C" TSLanguage * tree_sitter_c(); 8 | 9 | namespace { 10 | 11 | NAN_METHOD(New) {} 12 | 13 | void Init(Local exports, Local module) { 14 | Local tpl = Nan::New(New); 15 | tpl->SetClassName(Nan::New("Language").ToLocalChecked()); 16 | tpl->InstanceTemplate()->SetInternalFieldCount(1); 17 | 18 | Local constructor = Nan::GetFunction(tpl).ToLocalChecked(); 19 | Local instance = constructor->NewInstance(Nan::GetCurrentContext()).ToLocalChecked(); 20 | Nan::SetInternalFieldPointer(instance, 0, tree_sitter_c()); 21 | 22 | Nan::Set(instance, Nan::New("name").ToLocalChecked(), Nan::New("c").ToLocalChecked()); 23 | Nan::Set(module, Nan::New("exports").ToLocalChecked(), instance); 24 | } 25 | 26 | NODE_MODULE(tree_sitter_c_binding, Init) 27 | 28 | } // namespace 29 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/bindings/node/index.js: -------------------------------------------------------------------------------- 1 | try { 2 | module.exports = require("../../build/Release/tree_sitter_c_binding"); 3 | } catch (error1) { 4 | if (error1.code !== 'MODULE_NOT_FOUND') { 5 | throw error1; 6 | } 7 | try { 8 | module.exports = require("../../build/Debug/tree_sitter_c_binding"); 9 | } catch (error2) { 10 | if (error2.code !== 'MODULE_NOT_FOUND') { 11 | throw error2; 12 | } 13 | throw error1 14 | } 15 | } 16 | 17 | try { 18 | module.exports.nodeTypeInfo = require("../../src/node-types.json"); 19 | } catch (_) {} 20 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/bindings/rust/README.md: -------------------------------------------------------------------------------- 1 | # tree-sitter-c 2 | 3 | This crate provides a C grammar for the [tree-sitter][] parsing library. To 4 | use this crate, add it to the `[dependencies]` section of your `Cargo.toml` 5 | file. (Note that you will probably also need to depend on the 6 | [`tree-sitter`][tree-sitter crate] crate to use the parsed result in any useful 7 | way.) 8 | 9 | ``` toml 10 | [dependencies] 11 | tree-sitter = "0.17" 12 | tree-sitter-c = "0.16" 13 | ``` 14 | 15 | Typically, you will use the [language][language func] function to add this 16 | grammar to a tree-sitter [Parser][], and then use the parser to parse some code: 17 | 18 | ``` rust 19 | let code = r#" 20 | int double(int x) { 21 | return x * 2; 22 | } 23 | "#; 24 | let mut parser = Parser::new(); 25 | parser.set_language(tree_sitter_c::language()).expect("Error loading C grammar"); 26 | let parsed = parser.parse(code, None); 27 | ``` 28 | 29 | If you have any questions, please reach out to us in the [tree-sitter 30 | discussions] page. 
31 | 32 | [Language]: https://docs.rs/tree-sitter/*/tree_sitter/struct.Language.html 33 | [language func]: https://docs.rs/tree-sitter-c/*/tree_sitter_c/fn.language.html 34 | [Parser]: https://docs.rs/tree-sitter/*/tree_sitter/struct.Parser.html 35 | [tree-sitter]: https://tree-sitter.github.io/ 36 | [tree-sitter crate]: https://crates.io/crates/tree-sitter 37 | [tree-sitter discussions]: https://github.com/tree-sitter/tree-sitter/discussions 38 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/bindings/rust/build.rs: -------------------------------------------------------------------------------- 1 | use std::path::Path; 2 | extern crate cc; 3 | 4 | fn main() { 5 | let src_dir = Path::new("src"); 6 | 7 | let mut c_config = cc::Build::new(); 8 | c_config.include(&src_dir); 9 | c_config 10 | .flag_if_supported("-Wno-unused-parameter") 11 | .flag_if_supported("-Wno-unused-but-set-variable") 12 | .flag_if_supported("-Wno-trigraphs"); 13 | let parser_path = src_dir.join("parser.c"); 14 | c_config.file(&parser_path); 15 | println!("cargo:rerun-if-changed={}", parser_path.to_str().unwrap()); 16 | c_config.compile("parser"); 17 | } 18 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/bindings/rust/lib.rs: -------------------------------------------------------------------------------- 1 | // -*- coding: utf-8 -*- 2 | // ------------------------------------------------------------------------------------------------ 3 | // Copyright © 2021, tree-sitter-c authors. 4 | // See the LICENSE file in this repo for license details. 5 | // ------------------------------------------------------------------------------------------------ 6 | 7 | //! This crate provides a C grammar for the [tree-sitter][] parsing library. 8 | //! 9 | //! Typically, you will use the [language][language func] function to add this grammar to a 10 | //! tree-sitter [Parser][], and then use the parser to parse some code: 11 | //! 12 | //! ``` 13 | //! use tree_sitter::Parser; 14 | //! 15 | //! let code = r#" 16 | //! int double(int x) { 17 | //! return x * 2; 18 | //! } 19 | //! "#; 20 | //! let mut parser = Parser::new(); 21 | //! parser.set_language(tree_sitter_c::language()).expect("Error loading C grammar"); 22 | //! let parsed = parser.parse(code, None); 23 | //! # let parsed = parsed.unwrap(); 24 | //! # let root = parsed.root_node(); 25 | //! # assert!(!root.has_error()); 26 | //! ``` 27 | //! 28 | //! [Language]: https://docs.rs/tree-sitter/*/tree_sitter/struct.Language.html 29 | //! [language func]: fn.language.html 30 | //! [Parser]: https://docs.rs/tree-sitter/*/tree_sitter/struct.Parser.html 31 | //! [tree-sitter]: https://tree-sitter.github.io/ 32 | 33 | use tree_sitter::Language; 34 | 35 | extern "C" { 36 | fn tree_sitter_c() -> Language; 37 | } 38 | 39 | /// Returns the tree-sitter [Language][] for this grammar. 40 | /// 41 | /// [Language]: https://docs.rs/tree-sitter/*/tree_sitter/struct.Language.html 42 | pub fn language() -> Language { 43 | unsafe { tree_sitter_c() } 44 | } 45 | 46 | /// The source of the C tree-sitter grammar description. 47 | pub const GRAMMAR: &str = include_str!("../../grammar.js"); 48 | 49 | /// The syntax highlighting query for this language. 50 | pub const HIGHLIGHT_QUERY: &str = include_str!("../../queries/highlights.scm"); 51 | 52 | /// The content of the [`node-types.json`][] file for this grammar. 
53 | /// 54 | /// [`node-types.json`]: https://tree-sitter.github.io/tree-sitter/using-parsers#static-node-types 55 | pub const NODE_TYPES: &str = include_str!("../../src/node-types.json"); 56 | 57 | #[cfg(test)] 58 | mod tests { 59 | #[test] 60 | fn can_load_grammar() { 61 | let mut parser = tree_sitter::Parser::new(); 62 | parser 63 | .set_language(super::language()) 64 | .expect("Error loading C grammar"); 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/bindings/swift/TreeSitterC/c.h: -------------------------------------------------------------------------------- 1 | #ifndef TREE_SITTER_C_H_ 2 | #define TREE_SITTER_C_H_ 3 | 4 | typedef struct TSLanguage TSLanguage; 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | extern TSLanguage *tree_sitter_c(); 11 | 12 | #ifdef __cplusplus 13 | } 14 | #endif 15 | 16 | #endif // TREE_SITTER_C_H_ -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tree-sitter-c", 3 | "version": "0.20.2", 4 | "description": "C grammar for node-tree-sitter", 5 | "main": "bindings/node", 6 | "keywords": [ 7 | "parser", 8 | "lexer" 9 | ], 10 | "repository": { 11 | "type": "git", 12 | "url": "https://github.com/tree-sitter/tree-sitter-c.git" 13 | }, 14 | "author": "Max Brunsfeld", 15 | "license": "MIT", 16 | "dependencies": { 17 | "nan": "^2.14.0" 18 | }, 19 | "devDependencies": { 20 | "tree-sitter-cli": "^0.20.0" 21 | }, 22 | "scripts": { 23 | "build": "tree-sitter generate && node-gyp build", 24 | "test": "tree-sitter test && tree-sitter parse examples/* --quiet --time", 25 | "test-windows": "tree-sitter test" 26 | }, 27 | "tree-sitter": [ 28 | { 29 | "scope": "source.c", 30 | "file-types": [ 31 | "c", 32 | "h" 33 | ] 34 | } 35 | ] 36 | } 37 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/queries/highlights.scm: -------------------------------------------------------------------------------- 1 | "break" @keyword 2 | "case" @keyword 3 | "const" @keyword 4 | "continue" @keyword 5 | "default" @keyword 6 | "do" @keyword 7 | "else" @keyword 8 | "enum" @keyword 9 | "extern" @keyword 10 | "for" @keyword 11 | "if" @keyword 12 | "inline" @keyword 13 | "return" @keyword 14 | "sizeof" @keyword 15 | "static" @keyword 16 | "struct" @keyword 17 | "switch" @keyword 18 | "typedef" @keyword 19 | "union" @keyword 20 | "volatile" @keyword 21 | "while" @keyword 22 | 23 | "#define" @keyword 24 | "#elif" @keyword 25 | "#else" @keyword 26 | "#endif" @keyword 27 | "#if" @keyword 28 | "#ifdef" @keyword 29 | "#ifndef" @keyword 30 | "#include" @keyword 31 | (preproc_directive) @keyword 32 | 33 | "--" @operator 34 | "-" @operator 35 | "-=" @operator 36 | "->" @operator 37 | "=" @operator 38 | "!=" @operator 39 | "*" @operator 40 | "&" @operator 41 | "&&" @operator 42 | "+" @operator 43 | "++" @operator 44 | "+=" @operator 45 | "<" @operator 46 | "==" @operator 47 | ">" @operator 48 | "||" @operator 49 | 50 | "." 
@delimiter 51 | ";" @delimiter 52 | 53 | (string_literal) @string 54 | (system_lib_string) @string 55 | 56 | (null) @constant 57 | (number_literal) @number 58 | (char_literal) @number 59 | 60 | (call_expression 61 | function: (identifier) @function) 62 | (call_expression 63 | function: (field_expression 64 | field: (field_identifier) @function)) 65 | (function_declarator 66 | declarator: (identifier) @function) 67 | (preproc_function_def 68 | name: (identifier) @function.special) 69 | 70 | (field_identifier) @property 71 | (statement_identifier) @label 72 | (type_identifier) @type 73 | (primitive_type) @type 74 | (sized_type_specifier) @type 75 | 76 | ((identifier) @constant 77 | (#match? @constant "^[A-Z][A-Z\\d_]*$")) 78 | 79 | (identifier) @variable 80 | 81 | (comment) @comment 82 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/src/tree_sitter/parser.h: -------------------------------------------------------------------------------- 1 | #ifndef TREE_SITTER_PARSER_H_ 2 | #define TREE_SITTER_PARSER_H_ 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #define ts_builtin_sym_error ((TSSymbol)-1) 13 | #define ts_builtin_sym_end 0 14 | #define TREE_SITTER_SERIALIZATION_BUFFER_SIZE 1024 15 | 16 | typedef uint16_t TSStateId; 17 | 18 | #ifndef TREE_SITTER_API_H_ 19 | typedef uint16_t TSSymbol; 20 | typedef uint16_t TSFieldId; 21 | typedef struct TSLanguage TSLanguage; 22 | #endif 23 | 24 | typedef struct { 25 | TSFieldId field_id; 26 | uint8_t child_index; 27 | bool inherited; 28 | } TSFieldMapEntry; 29 | 30 | typedef struct { 31 | uint16_t index; 32 | uint16_t length; 33 | } TSFieldMapSlice; 34 | 35 | typedef struct { 36 | bool visible; 37 | bool named; 38 | bool supertype; 39 | } TSSymbolMetadata; 40 | 41 | typedef struct TSLexer TSLexer; 42 | 43 | struct TSLexer { 44 | int32_t lookahead; 45 | TSSymbol result_symbol; 46 | void (*advance)(TSLexer *, bool); 47 | void (*mark_end)(TSLexer *); 48 | uint32_t (*get_column)(TSLexer *); 49 | bool (*is_at_included_range_start)(const TSLexer *); 50 | bool (*eof)(const TSLexer *); 51 | }; 52 | 53 | typedef enum { 54 | TSParseActionTypeShift, 55 | TSParseActionTypeReduce, 56 | TSParseActionTypeAccept, 57 | TSParseActionTypeRecover, 58 | } TSParseActionType; 59 | 60 | typedef union { 61 | struct { 62 | uint8_t type; 63 | TSStateId state; 64 | bool extra; 65 | bool repetition; 66 | } shift; 67 | struct { 68 | uint8_t type; 69 | uint8_t child_count; 70 | TSSymbol symbol; 71 | int16_t dynamic_precedence; 72 | uint16_t production_id; 73 | } reduce; 74 | uint8_t type; 75 | } TSParseAction; 76 | 77 | typedef struct { 78 | uint16_t lex_state; 79 | uint16_t external_lex_state; 80 | } TSLexMode; 81 | 82 | typedef union { 83 | TSParseAction action; 84 | struct { 85 | uint8_t count; 86 | bool reusable; 87 | } entry; 88 | } TSParseActionEntry; 89 | 90 | struct TSLanguage { 91 | uint32_t version; 92 | uint32_t symbol_count; 93 | uint32_t alias_count; 94 | uint32_t token_count; 95 | uint32_t external_token_count; 96 | uint32_t state_count; 97 | uint32_t large_state_count; 98 | uint32_t production_id_count; 99 | uint32_t field_count; 100 | uint16_t max_alias_sequence_length; 101 | const uint16_t *parse_table; 102 | const uint16_t *small_parse_table; 103 | const uint32_t *small_parse_table_map; 104 | const TSParseActionEntry *parse_actions; 105 | const char * const *symbol_names; 106 | const char * const *field_names; 107 | const TSFieldMapSlice 
*field_map_slices; 108 | const TSFieldMapEntry *field_map_entries; 109 | const TSSymbolMetadata *symbol_metadata; 110 | const TSSymbol *public_symbol_map; 111 | const uint16_t *alias_map; 112 | const TSSymbol *alias_sequences; 113 | const TSLexMode *lex_modes; 114 | bool (*lex_fn)(TSLexer *, TSStateId); 115 | bool (*keyword_lex_fn)(TSLexer *, TSStateId); 116 | TSSymbol keyword_capture_token; 117 | struct { 118 | const bool *states; 119 | const TSSymbol *symbol_map; 120 | void *(*create)(void); 121 | void (*destroy)(void *); 122 | bool (*scan)(void *, TSLexer *, const bool *symbol_whitelist); 123 | unsigned (*serialize)(void *, char *); 124 | void (*deserialize)(void *, const char *, unsigned); 125 | } external_scanner; 126 | const TSStateId *primary_state_ids; 127 | }; 128 | 129 | /* 130 | * Lexer Macros 131 | */ 132 | 133 | #define START_LEXER() \ 134 | bool result = false; \ 135 | bool skip = false; \ 136 | bool eof = false; \ 137 | int32_t lookahead; \ 138 | goto start; \ 139 | next_state: \ 140 | lexer->advance(lexer, skip); \ 141 | start: \ 142 | skip = false; \ 143 | lookahead = lexer->lookahead; 144 | 145 | #define ADVANCE(state_value) \ 146 | { \ 147 | state = state_value; \ 148 | goto next_state; \ 149 | } 150 | 151 | #define SKIP(state_value) \ 152 | { \ 153 | skip = true; \ 154 | state = state_value; \ 155 | goto next_state; \ 156 | } 157 | 158 | #define ACCEPT_TOKEN(symbol_value) \ 159 | result = true; \ 160 | lexer->result_symbol = symbol_value; \ 161 | lexer->mark_end(lexer); 162 | 163 | #define END_STATE() return result; 164 | 165 | /* 166 | * Parse Table Macros 167 | */ 168 | 169 | #define SMALL_STATE(id) id - LARGE_STATE_COUNT 170 | 171 | #define STATE(id) id 172 | 173 | #define ACTIONS(id) id 174 | 175 | #define SHIFT(state_value) \ 176 | {{ \ 177 | .shift = { \ 178 | .type = TSParseActionTypeShift, \ 179 | .state = state_value \ 180 | } \ 181 | }} 182 | 183 | #define SHIFT_REPEAT(state_value) \ 184 | {{ \ 185 | .shift = { \ 186 | .type = TSParseActionTypeShift, \ 187 | .state = state_value, \ 188 | .repetition = true \ 189 | } \ 190 | }} 191 | 192 | #define SHIFT_EXTRA() \ 193 | {{ \ 194 | .shift = { \ 195 | .type = TSParseActionTypeShift, \ 196 | .extra = true \ 197 | } \ 198 | }} 199 | 200 | #define REDUCE(symbol_val, child_count_val, ...) 
\ 201 | {{ \ 202 | .reduce = { \ 203 | .type = TSParseActionTypeReduce, \ 204 | .symbol = symbol_val, \ 205 | .child_count = child_count_val, \ 206 | __VA_ARGS__ \ 207 | }, \ 208 | }} 209 | 210 | #define RECOVER() \ 211 | {{ \ 212 | .type = TSParseActionTypeRecover \ 213 | }} 214 | 215 | #define ACCEPT_INPUT() \ 216 | {{ \ 217 | .type = TSParseActionTypeAccept \ 218 | }} 219 | 220 | #ifdef __cplusplus 221 | } 222 | #endif 223 | 224 | #endif // TREE_SITTER_PARSER_H_ 225 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/test/corpus/ambiguities.txt: -------------------------------------------------------------------------------- 1 | ======================================================================== 2 | pointer declarations vs multiplications 3 | ======================================================================== 4 | 5 | int main() { 6 | // declare a function pointer 7 | T1 * b(T2 a); 8 | 9 | // evaluate expressions 10 | c * d(5); 11 | e(f * g); 12 | } 13 | 14 | --- 15 | 16 | (translation_unit (function_definition 17 | (primitive_type) 18 | (function_declarator (identifier) (parameter_list)) 19 | (compound_statement 20 | (comment) 21 | (declaration 22 | (type_identifier) 23 | (pointer_declarator (function_declarator 24 | (identifier) 25 | (parameter_list (parameter_declaration (type_identifier) (identifier)))))) 26 | (comment) 27 | (expression_statement (binary_expression 28 | (identifier) 29 | (call_expression (identifier) (argument_list (number_literal))))) 30 | (expression_statement (call_expression 31 | (identifier) 32 | (argument_list (binary_expression (identifier) (identifier)))))))) 33 | 34 | ======================================================================== 35 | casts vs multiplications 36 | ======================================================================== 37 | 38 | /* 39 | * ambiguities 40 | */ 41 | 42 | int main() { 43 | // cast 44 | a((B *)c); 45 | 46 | // parenthesized product 47 | d((e * f)); 48 | } 49 | 50 | --- 51 | 52 | (translation_unit 53 | (comment) 54 | (function_definition 55 | (primitive_type) 56 | (function_declarator (identifier) (parameter_list)) 57 | (compound_statement 58 | (comment) 59 | (expression_statement (call_expression 60 | (identifier) 61 | (argument_list (cast_expression (type_descriptor (type_identifier) (abstract_pointer_declarator)) (identifier))))) 62 | (comment) 63 | (expression_statement (call_expression 64 | (identifier) 65 | (argument_list (parenthesized_expression (binary_expression (identifier) (identifier))))))))) 66 | 67 | ======================================================================== 68 | function-like type macros vs function calls 69 | ======================================================================== 70 | 71 | // this is a macro 72 | GIT_INLINE(int *) x = 5; 73 | 74 | --- 75 | 76 | (translation_unit 77 | (comment) 78 | (declaration 79 | (macro_type_specifier (identifier) (type_descriptor (primitive_type) (abstract_pointer_declarator))) 80 | (init_declarator (identifier) (number_literal)))) 81 | 82 | ======================================================================== 83 | function calls vs parenthesized declarators vs macro types 84 | ======================================================================== 85 | 86 | int main() { 87 | /* 88 | * Could be either: 89 | * - function call 90 | * - declaration w/ parenthesized declarator 91 | * - declaration w/ macro type, no declarator 92 | */ 93 | ABC(d); 94 | 95 | /* 96 | * Normal declaration 
97 | */ 98 | efg hij; 99 | } 100 | 101 | --- 102 | 103 | (translation_unit 104 | (function_definition 105 | (primitive_type) 106 | (function_declarator (identifier) (parameter_list)) 107 | (compound_statement 108 | (comment) 109 | (expression_statement (call_expression (identifier) (argument_list (identifier)))) 110 | (comment) 111 | (declaration (type_identifier) (identifier))))) 112 | 113 | ======================================================================== 114 | Call expressions vs empty declarations w/ macros as types 115 | ======================================================================== 116 | 117 | int main() { 118 | int a = 1; 119 | b(a); 120 | A(A *); 121 | } 122 | 123 | --- 124 | 125 | (translation_unit 126 | (function_definition 127 | (primitive_type) 128 | (function_declarator (identifier) (parameter_list)) 129 | (compound_statement 130 | (declaration (primitive_type) (init_declarator (identifier) (number_literal))) 131 | (expression_statement (call_expression (identifier) (argument_list (identifier)))) 132 | (macro_type_specifier 133 | (identifier) 134 | (type_descriptor (type_identifier) (abstract_pointer_declarator)))))) 135 | 136 | ======================================================================== 137 | Comments after for loops with ambiguities 138 | ======================================================================== 139 | 140 | int main() { 141 | for (a *b = c; d; e) { 142 | aff; 143 | } 144 | 145 | // a-comment 146 | 147 | g; 148 | } 149 | 150 | --- 151 | 152 | (translation_unit (function_definition 153 | (primitive_type) 154 | (function_declarator (identifier) (parameter_list)) 155 | (compound_statement 156 | (for_statement 157 | (declaration (type_identifier) (init_declarator 158 | (pointer_declarator (identifier)) 159 | (identifier))) 160 | (identifier) 161 | (identifier) 162 | (compound_statement 163 | (expression_statement (identifier)))) 164 | (comment) 165 | (expression_statement (identifier))))) 166 | 167 | =============================================== 168 | Top-level macro invocations 169 | =============================================== 170 | 171 | DEFINE_SOMETHING(THING_A, "this is a thing a"); 172 | DEFINE_SOMETHING(THING_B, "this is a thing b", "thanks"); 173 | 174 | --- 175 | 176 | (translation_unit 177 | (expression_statement (call_expression (identifier) (argument_list (identifier) (string_literal)))) 178 | (expression_statement (call_expression (identifier) (argument_list (identifier) (string_literal) (string_literal))))) 179 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/test/corpus/crlf.txt: -------------------------------------------------------------------------------- 1 | ============================================ 2 | Line comments with escaped CRLF line endings 3 | ============================================ 4 | 5 | // hello \ 6 | this is still a comment 7 | this_is_not a_comment; 8 | 9 | --- 10 | 11 | (translation_unit 12 | (comment) 13 | (declaration (type_identifier) (identifier))) 14 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/test/corpus/microsoft.txt: -------------------------------------------------------------------------------- 1 | ================================ 2 | declaration specs 3 | ================================ 4 | 5 | struct __declspec(dllexport) s2 6 | { 7 | }; 8 | 9 | union __declspec(noinline) u2 { 10 | }; 11 | 12 | --- 13 | 14 | (translation_unit 15 | (struct_specifier 16 | 
(ms_declspec_modifier 17 | (identifier)) 18 | name: (type_identifier) 19 | body: (field_declaration_list)) 20 | (union_specifier 21 | (ms_declspec_modifier 22 | (identifier)) 23 | name: (type_identifier) 24 | body: (field_declaration_list))) 25 | 26 | ================================ 27 | pointers 28 | ================================ 29 | 30 | struct s2 31 | { 32 | int * __restrict x; 33 | int * __sptr psp; 34 | int * __uptr pup; 35 | int * __unaligned pup; 36 | }; 37 | 38 | void sum2(int n, int * __restrict a, int * __restrict b, 39 | int * c, int * d) { 40 | int i; 41 | for (i = 0; i < n; i++) { 42 | a[i] = b[i] + c[i]; 43 | c[i] = b[i] + d[i]; 44 | } 45 | } 46 | 47 | void MyFunction(char * __uptr myValue); 48 | 49 | --- 50 | 51 | (translation_unit 52 | (struct_specifier 53 | name: (type_identifier) 54 | body: (field_declaration_list 55 | (field_declaration 56 | type: (primitive_type) 57 | declarator: (pointer_declarator 58 | (ms_pointer_modifier 59 | (ms_restrict_modifier)) 60 | declarator: (field_identifier))) 61 | (field_declaration 62 | type: (primitive_type) 63 | declarator: (pointer_declarator 64 | (ms_pointer_modifier 65 | (ms_signed_ptr_modifier)) 66 | declarator: (field_identifier))) 67 | (field_declaration 68 | type: (primitive_type) 69 | declarator: (pointer_declarator 70 | (ms_pointer_modifier 71 | (ms_unsigned_ptr_modifier)) 72 | declarator: (field_identifier))) 73 | (field_declaration 74 | type: (primitive_type) 75 | declarator: (pointer_declarator 76 | (ms_pointer_modifier 77 | (ms_unaligned_ptr_modifier)) 78 | declarator: (field_identifier))))) 79 | (function_definition 80 | type: (primitive_type) 81 | declarator: (function_declarator 82 | declarator: (identifier) 83 | parameters: (parameter_list 84 | (parameter_declaration 85 | type: (primitive_type) 86 | declarator: (identifier)) 87 | (parameter_declaration 88 | type: (primitive_type) 89 | declarator: (pointer_declarator 90 | (ms_pointer_modifier 91 | (ms_restrict_modifier)) 92 | declarator: (identifier))) 93 | (parameter_declaration 94 | type: (primitive_type) 95 | declarator: (pointer_declarator 96 | (ms_pointer_modifier 97 | (ms_restrict_modifier)) 98 | declarator: (identifier))) 99 | (parameter_declaration 100 | type: (primitive_type) 101 | declarator: (pointer_declarator 102 | declarator: (identifier))) 103 | (parameter_declaration 104 | type: (primitive_type) 105 | declarator: (pointer_declarator 106 | declarator: (identifier))))) 107 | body: (compound_statement 108 | (declaration 109 | type: (primitive_type) 110 | declarator: (identifier)) 111 | (for_statement 112 | initializer: (assignment_expression 113 | left: (identifier) 114 | right: (number_literal)) 115 | condition: (binary_expression 116 | left: (identifier) 117 | right: (identifier)) 118 | update: (update_expression 119 | argument: (identifier)) 120 | body: (compound_statement 121 | (expression_statement 122 | (assignment_expression 123 | left: (subscript_expression 124 | argument: (identifier) 125 | index: (identifier)) 126 | right: (binary_expression 127 | left: (subscript_expression 128 | argument: (identifier) 129 | index: (identifier)) 130 | right: (subscript_expression 131 | argument: (identifier) 132 | index: (identifier))))) 133 | (expression_statement 134 | (assignment_expression 135 | left: (subscript_expression 136 | argument: (identifier) 137 | index: (identifier)) 138 | right: (binary_expression 139 | left: (subscript_expression 140 | argument: (identifier) 141 | index: (identifier)) 142 | right: (subscript_expression 143 | argument: 
(identifier) 144 | index: (identifier))))))))) 145 | (declaration 146 | type: (primitive_type) 147 | declarator: (function_declarator 148 | declarator: (identifier) 149 | parameters: (parameter_list 150 | (parameter_declaration 151 | type: (primitive_type) 152 | declarator: (pointer_declarator 153 | (ms_pointer_modifier 154 | (ms_unsigned_ptr_modifier)) 155 | declarator: (identifier))))))) 156 | 157 | ================================ 158 | call modifiers 159 | ================================ 160 | 161 | __cdecl void mymethod(){ 162 | return; 163 | } 164 | 165 | __fastcall void mymethod(){ 166 | return; 167 | } 168 | 169 | --- 170 | 171 | (translation_unit 172 | (function_definition 173 | (ms_call_modifier) 174 | type: (primitive_type) 175 | declarator: (function_declarator 176 | declarator: (identifier) 177 | parameters: (parameter_list)) 178 | body: (compound_statement 179 | (return_statement))) 180 | (function_definition 181 | (ms_call_modifier) 182 | type: (primitive_type) 183 | declarator: (function_declarator 184 | declarator: (identifier) 185 | parameters: (parameter_list)) 186 | body: (compound_statement 187 | (return_statement)))) 188 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/test/corpus/preprocessor.txt: -------------------------------------------------------------------------------- 1 | ============================================ 2 | Include directives 3 | ============================================ 4 | 5 | #include "some/path.h" 6 | #include 7 | #include MACRO 8 | #include MACRO(arg1, arg2) 9 | 10 | --- 11 | 12 | (translation_unit 13 | (preproc_include path: (string_literal)) 14 | (preproc_include path: (system_lib_string)) 15 | (preproc_include path: (identifier)) 16 | (preproc_include path: 17 | (call_expression 18 | function: (identifier) 19 | arguments: (argument_list (identifier) (identifier))))) 20 | 21 | ============================================ 22 | Object-like macro definitions 23 | ============================================ 24 | 25 | #define ONE 26 | #define TWO int a = b; 27 | #define THREE \ 28 | c == d ? \ 29 | e : \ 30 | f 31 | #define FOUR (mno * pq) 32 | #define FIVE(a,b) x \ 33 | + y 34 | #define SIX(a, \ 35 | b) x \ 36 | + y 37 | 38 | --- 39 | 40 | (translation_unit 41 | (preproc_def name: (identifier)) 42 | (preproc_def name: (identifier) value: (preproc_arg)) 43 | (preproc_def name: (identifier) value: (preproc_arg)) 44 | (preproc_def name: (identifier) value: (preproc_arg)) 45 | (preproc_function_def name: (identifier) parameters: (preproc_params (identifier) (identifier)) value: (preproc_arg)) 46 | (preproc_function_def name: (identifier) parameters: (preproc_params (identifier) (identifier)) value: (preproc_arg))) 47 | 48 | ============================================ 49 | Function-like macro definitions 50 | ============================================ 51 | 52 | #define ONE() a 53 | #define TWO(b) c 54 | #define THREE(d, e) f 55 | #define FOUR(...) g 56 | #define FIVE(h, i, ...) 
j 57 | 58 | --- 59 | 60 | (translation_unit 61 | (preproc_function_def 62 | name: (identifier) 63 | parameters: (preproc_params) 64 | value: (preproc_arg)) 65 | (preproc_function_def 66 | name: (identifier) 67 | parameters: (preproc_params (identifier)) 68 | value: (preproc_arg)) 69 | (preproc_function_def 70 | name: (identifier) 71 | parameters: (preproc_params (identifier) (identifier)) 72 | value: (preproc_arg)) 73 | (preproc_function_def 74 | name: (identifier) 75 | parameters: (preproc_params) 76 | value: (preproc_arg)) 77 | (preproc_function_def 78 | name: (identifier) 79 | parameters: (preproc_params (identifier) (identifier)) 80 | value: (preproc_arg))) 81 | 82 | ============================================ 83 | Ifdefs 84 | ============================================ 85 | 86 | #ifndef DEFINE1 87 | int j; 88 | #endif 89 | 90 | #ifdef DEFINE2 91 | ssize_t b; 92 | #define c 32 93 | #elif defined DEFINE3 94 | #else 95 | int b; 96 | #define c 16 97 | #endif 98 | 99 | #ifdef DEFINE2 100 | #else 101 | # ifdef DEFINE3 102 | # else 103 | # endif 104 | #endif 105 | 106 | --- 107 | 108 | (translation_unit 109 | (preproc_ifdef 110 | name: (identifier) 111 | (declaration 112 | type: (primitive_type) 113 | declarator: (identifier))) 114 | 115 | (preproc_ifdef 116 | name: (identifier) 117 | (declaration 118 | type: (primitive_type) 119 | declarator: (identifier)) 120 | (preproc_def 121 | name: (identifier) 122 | value: (preproc_arg)) 123 | alternative: (preproc_elif 124 | condition: (preproc_defined (identifier)) 125 | alternative: (preproc_else 126 | (declaration 127 | type: (primitive_type) 128 | declarator: (identifier)) 129 | (preproc_def 130 | name: (identifier) 131 | value: (preproc_arg))))) 132 | 133 | (preproc_ifdef 134 | name: (identifier) 135 | alternative: (preproc_else 136 | (preproc_ifdef 137 | name: (identifier) 138 | alternative: (preproc_else))))) 139 | 140 | =============================================================== 141 | General if blocks 142 | ========================================== 143 | 144 | #if defined(__GNUC__) && defined(__PIC__) 145 | #define inline inline __attribute__((always_inline)) 146 | #elif defined(_WIN32) 147 | #define something 148 | #elif !defined(SOMETHING_ELSE) 149 | #define SOMETHING_ELSE 150 | #else 151 | #include 152 | #endif 153 | 154 | --- 155 | 156 | (translation_unit 157 | (preproc_if 158 | condition: (binary_expression 159 | left: (preproc_defined (identifier)) 160 | right: (preproc_defined (identifier))) 161 | (preproc_def 162 | name: (identifier) 163 | value: (preproc_arg)) 164 | alternative: (preproc_elif 165 | condition: (preproc_defined (identifier)) 166 | (preproc_def 167 | name: (identifier)) 168 | alternative: (preproc_elif 169 | condition: (unary_expression 170 | argument: (preproc_defined (identifier))) 171 | (preproc_def 172 | name: (identifier)) 173 | alternative: (preproc_else 174 | (preproc_include path: (system_lib_string))))))) 175 | 176 | ============================================ 177 | Preprocessor conditionals in functions 178 | ============================================ 179 | 180 | int main() { 181 | #if d 182 | puts("1"); 183 | #else 184 | puts("2"); 185 | #endif 186 | 187 | #if a 188 | return 0; 189 | #elif b 190 | return 1; 191 | #elif c 192 | return 2; 193 | #else 194 | return 3; 195 | #endif 196 | } 197 | 198 | --- 199 | 200 | (translation_unit 201 | (function_definition 202 | (primitive_type) 203 | (function_declarator (identifier) (parameter_list)) 204 | (compound_statement 205 | (preproc_if 206 | (identifier) 
207 | (expression_statement (call_expression (identifier) (argument_list (string_literal)))) 208 | (preproc_else 209 | (expression_statement (call_expression (identifier) (argument_list (string_literal)))))) 210 | (preproc_if 211 | (identifier) 212 | (return_statement (number_literal)) 213 | (preproc_elif 214 | (identifier) 215 | (return_statement (number_literal)) 216 | (preproc_elif 217 | (identifier) 218 | (return_statement (number_literal)) 219 | (preproc_else 220 | (return_statement (number_literal))))))))) 221 | 222 | ================================================= 223 | Preprocessor conditionals in struct/union bodies 224 | ================================================= 225 | 226 | struct S { 227 | #ifdef _WIN32 228 | LONG f2; 229 | #else 230 | uint32_t f2; 231 | #endif 232 | }; 233 | 234 | --- 235 | 236 | (translation_unit 237 | (struct_specifier (type_identifier) (field_declaration_list 238 | (preproc_ifdef (identifier) 239 | (field_declaration (type_identifier) (field_identifier)) 240 | (preproc_else 241 | (field_declaration (primitive_type) (field_identifier))))))) 242 | 243 | ==================================== 244 | Unknown preprocessor directives 245 | ==================================== 246 | 247 | #pragma mark - UIViewController 248 | 249 | --- 250 | 251 | (translation_unit (preproc_call 252 | directive: (preproc_directive) 253 | argument: (preproc_arg))) 254 | 255 | ====================================== 256 | Preprocessor expressions 257 | ====================================== 258 | 259 | #if A(B || C) && \ 260 | !D(F) 261 | 262 | uint32_t a; 263 | 264 | #endif 265 | 266 | --- 267 | 268 | (translation_unit 269 | (preproc_if 270 | (binary_expression 271 | (call_expression (identifier) (argument_list (binary_expression (identifier) (identifier)))) 272 | (unary_expression 273 | (call_expression (identifier) (argument_list (identifier))))) 274 | (declaration (primitive_type) (identifier)))) 275 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/test/corpus/types.txt: -------------------------------------------------------------------------------- 1 | ======================================== 2 | Primitive types 3 | ======================================== 4 | 5 | int a; 6 | uint8_t a; 7 | uint16_t a; 8 | uint32_t a; 9 | uint64_t a; 10 | uintptr_t a; 11 | 12 | int8_t a; 13 | int16_t a; 14 | int32_t a; 15 | int64_t a; 16 | intptr_t a; 17 | 18 | char16_t a; 19 | char32_t a; 20 | 21 | size_t a; 22 | ssize_t a; 23 | 24 | --- 25 | 26 | (translation_unit 27 | (declaration (primitive_type) (identifier)) 28 | (declaration (primitive_type) (identifier)) 29 | (declaration (primitive_type) (identifier)) 30 | (declaration (primitive_type) (identifier)) 31 | (declaration (primitive_type) (identifier)) 32 | (declaration (primitive_type) (identifier)) 33 | (declaration (primitive_type) (identifier)) 34 | (declaration (primitive_type) (identifier)) 35 | (declaration (primitive_type) (identifier)) 36 | (declaration (primitive_type) (identifier)) 37 | (declaration (primitive_type) (identifier)) 38 | (declaration (primitive_type) (identifier)) 39 | (declaration (primitive_type) (identifier)) 40 | (declaration (primitive_type) (identifier)) 41 | (declaration (primitive_type) (identifier))) 42 | 43 | ======================================== 44 | Type modifiers 45 | ======================================== 46 | 47 | void f(unsigned); 48 | void f(unsigned int); 49 | void f(signed long int); 50 | void f(unsigned v1); 51 | void 
f(unsigned long v2); 52 | 53 | --- 54 | 55 | (translation_unit 56 | (declaration 57 | (primitive_type) 58 | (function_declarator 59 | (identifier) 60 | (parameter_list (parameter_declaration (sized_type_specifier))))) 61 | (declaration 62 | (primitive_type) 63 | (function_declarator 64 | (identifier) 65 | (parameter_list (parameter_declaration (sized_type_specifier (primitive_type)))))) 66 | (declaration 67 | (primitive_type) 68 | (function_declarator 69 | (identifier) 70 | (parameter_list (parameter_declaration (sized_type_specifier (primitive_type)))))) 71 | (declaration 72 | (primitive_type) 73 | (function_declarator 74 | (identifier) 75 | (parameter_list (parameter_declaration (sized_type_specifier) (identifier))))) 76 | (declaration 77 | (primitive_type) 78 | (function_declarator 79 | (identifier) 80 | (parameter_list (parameter_declaration (sized_type_specifier) (identifier)))))) 81 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/test/highlight/keywords.c: -------------------------------------------------------------------------------- 1 | #include 2 | // ^ keyword 3 | // ^ string 4 | 5 | #include "something.h" 6 | // ^ string 7 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/c/test/highlight/names.c: -------------------------------------------------------------------------------- 1 | typedef struct { 2 | // ^ keyword 3 | // ^ keyword 4 | a_t b; 5 | // <- type 6 | // ^ property 7 | 8 | unsigned c_t (*d)[2]; 9 | // ^ type 10 | // ^ type 11 | // ^ property 12 | }, T, V; 13 | // ^ type 14 | // ^ type 15 | 16 | int main(const char string[SIZE]) { 17 | // <- type 18 | // ^ function 19 | // ^ keyword 20 | // ^ type 21 | // ^ variable 22 | // ^ constant 23 | 24 | return foo.bar + foo.baz(); 25 | // ^ keyword 26 | // ^ variable 27 | // ^ property 28 | // ^ function 29 | 30 | error: 31 | // <- label 32 | return 0; 33 | } 34 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tree-sitter-cpp" 3 | description = "Cpp grammar for the tree-sitter parsing library" 4 | version = "0.20.0" 5 | authors = ["Max Brunsfeld "] 6 | license = "MIT" 7 | readme = "bindings/rust/README.md" 8 | keywords = ["incremental", "parsing", "cpp"] 9 | categories = ["parsing", "text-editors"] 10 | repository = "https://github.com/tree-sitter/tree-sitter-cpp" 11 | edition = "2018" 12 | 13 | build = "bindings/rust/build.rs" 14 | include = [ 15 | "bindings/rust/*", 16 | "grammar.js", 17 | "queries/*", 18 | "src/*", 19 | ] 20 | 21 | [lib] 22 | path = "bindings/rust/lib.rs" 23 | 24 | [dependencies] 25 | tree-sitter = "0.20" 26 | 27 | [build-dependencies] 28 | cc = "1.0" 29 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Max Brunsfeld 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | 
furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/Package.swift: -------------------------------------------------------------------------------- 1 | // swift-tools-version:5.3 2 | import PackageDescription 3 | 4 | let package = Package( 5 | name: "TreeSitterCPP", 6 | platforms: [.macOS(.v10_13), .iOS(.v11)], 7 | products: [ 8 | .library(name: "TreeSitterCPP", targets: ["TreeSitterCPP"]), 9 | ], 10 | dependencies: [], 11 | targets: [ 12 | .target(name: "TreeSitterCPP", 13 | path: ".", 14 | exclude: [ 15 | "binding.gyp", 16 | "bindings", 17 | "Cargo.toml", 18 | "corpus", 19 | "examples", 20 | "grammar.js", 21 | "LICENSE", 22 | "Makefile", 23 | "package.json", 24 | "README.md", 25 | "src/grammar.json", 26 | "src/node-types.json", 27 | ], 28 | sources: [ 29 | "src/parser.c", 30 | "src/scanner.cc", 31 | ], 32 | resources: [ 33 | .copy("queries") 34 | ], 35 | publicHeadersPath: "bindings/swift", 36 | cSettings: [.headerSearchPath("src")]) 37 | ] 38 | ) -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/README.md: -------------------------------------------------------------------------------- 1 | tree-sitter-cpp 2 | ================== 3 | 4 | [![Build Status](https://travis-ci.org/tree-sitter/tree-sitter-cpp.svg?branch=master)](https://travis-ci.org/tree-sitter/tree-sitter-cpp) 5 | [![Build status](https://ci.appveyor.com/api/projects/status/fbj5gq4plxaiakiw/branch/master?svg=true)](https://ci.appveyor.com/project/maxbrunsfeld/tree-sitter-cpp/branch/master) 6 | 7 | C++ grammar for [tree-sitter](https://github.com/tree-sitter/tree-sitter). 
8 | 9 | # References 10 | 11 | * [Hyperlinked C++ BNF Grammar](http://www.nongnu.org/hcb/) 12 | * [EBNF Syntax: C++](http://www.externsoft.ch/download/cpp-iso.html) 13 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/binding.gyp: -------------------------------------------------------------------------------- 1 | { 2 | "targets": [ 3 | { 4 | "target_name": "tree_sitter_cpp_binding", 5 | "include_dirs": [ 6 | " 3 | #include "nan.h" 4 | 5 | using namespace v8; 6 | 7 | extern "C" TSLanguage * tree_sitter_cpp(); 8 | 9 | namespace { 10 | 11 | NAN_METHOD(New) {} 12 | 13 | void Init(Local exports, Local module) { 14 | Local tpl = Nan::New(New); 15 | tpl->SetClassName(Nan::New("Language").ToLocalChecked()); 16 | tpl->InstanceTemplate()->SetInternalFieldCount(1); 17 | 18 | Local constructor = Nan::GetFunction(tpl).ToLocalChecked(); 19 | Local instance = constructor->NewInstance(Nan::GetCurrentContext()).ToLocalChecked(); 20 | Nan::SetInternalFieldPointer(instance, 0, tree_sitter_cpp()); 21 | 22 | Nan::Set(instance, Nan::New("name").ToLocalChecked(), Nan::New("cpp").ToLocalChecked()); 23 | Nan::Set(module, Nan::New("exports").ToLocalChecked(), instance); 24 | } 25 | 26 | NODE_MODULE(tree_sitter_cpp_binding, Init) 27 | 28 | } // namespace 29 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/bindings/node/index.js: -------------------------------------------------------------------------------- 1 | try { 2 | module.exports = require("../../build/Release/tree_sitter_cpp_binding"); 3 | } catch (error1) { 4 | if (error1.code !== 'MODULE_NOT_FOUND') { 5 | throw error1; 6 | } 7 | try { 8 | module.exports = require("../../build/Debug/tree_sitter_cpp_binding"); 9 | } catch (error2) { 10 | if (error2.code !== 'MODULE_NOT_FOUND') { 11 | throw error2; 12 | } 13 | throw error1 14 | } 15 | } 16 | 17 | try { 18 | module.exports.nodeTypeInfo = require("../../src/node-types.json"); 19 | } catch (_) {} 20 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/bindings/rust/README.md: -------------------------------------------------------------------------------- 1 | # tree-sitter-cpp 2 | 3 | This crate provides a CPP grammar for the [tree-sitter][] parsing library. To 4 | use this crate, add it to the `[dependencies]` section of your `Cargo.toml` 5 | file. (Note that you will probably also need to depend on the 6 | [`tree-sitter`][tree-sitter crate] crate to use the parsed result in any useful 7 | way.) 8 | 9 | ``` toml 10 | [dependencies] 11 | tree-sitter = "0.17" 12 | tree-sitter-cpp = "0.16" 13 | ``` 14 | 15 | Typically, you will use the [language][language func] function to add this 16 | grammar to a tree-sitter [Parser][], and then use the parser to parse some code: 17 | 18 | ``` rust 19 | let code = r#" 20 | int double(int x) { 21 | return x * 2; 22 | } 23 | "#; 24 | let mut parser = Parser::new(); 25 | parser.set_language(tree_sitter_cpp::language()).expect("Error loading CPP grammar"); 26 | let parsed = parser.parse(code, None); 27 | ``` 28 | 29 | If you have any questions, please reach out to us in the [tree-sitter 30 | discussions] page. 
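A minimal follow-up, assuming the snippet above compiled (the names `parser` and `parsed` come from that example): the root node of any C++ file is a `translation_unit`, and `Node::to_sexp` dumps the whole tree, which is handy when checking what the grammar actually produced.

``` rust
// Sketch: inspect the tree returned by `parser.parse` above.
let tree = parsed.expect("a tree should be produced for this input");
let root = tree.root_node();
assert_eq!(root.kind(), "translation_unit");
// Print the parse tree as an S-expression, the same format used
// by this repository's corpus tests.
println!("{}", root.to_sexp());
```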
31 | 32 | [Language]: https://docs.rs/tree-sitter/*/tree_sitter/struct.Language.html 33 | [language func]: https://docs.rs/tree-sitter-cpp/*/tree_sitter_cpp/fn.language.html 34 | [Parser]: https://docs.rs/tree-sitter/*/tree_sitter/struct.Parser.html 35 | [tree-sitter]: https://tree-sitter.github.io/ 36 | [tree-sitter crate]: https://crates.io/crates/tree-sitter 37 | [tree-sitter discussions]: https://github.com/tree-sitter/tree-sitter/discussions 38 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/bindings/rust/build.rs: -------------------------------------------------------------------------------- 1 | fn main() { 2 | let src_dir = std::path::Path::new("src"); 3 | 4 | let mut c_config = cc::Build::new(); 5 | c_config.include(&src_dir); 6 | c_config 7 | .flag_if_supported("-Wno-unused-parameter") 8 | .flag_if_supported("-Wno-unused-but-set-variable") 9 | .flag_if_supported("-Wno-trigraphs"); 10 | let parser_path = src_dir.join("parser.c"); 11 | c_config.file(&parser_path); 12 | println!("cargo:rerun-if-changed={}", parser_path.to_str().unwrap()); 13 | c_config.compile("parser"); 14 | 15 | let mut cpp_config = cc::Build::new(); 16 | cpp_config.cpp(true); 17 | cpp_config.include(&src_dir); 18 | cpp_config 19 | .flag_if_supported("-Wno-unused-parameter") 20 | .flag_if_supported("-Wno-unused-but-set-variable"); 21 | let scanner_path = src_dir.join("scanner.cc"); 22 | cpp_config.file(&scanner_path); 23 | println!("cargo:rerun-if-changed={}", scanner_path.to_str().unwrap()); 24 | cpp_config.compile("scanner"); 25 | } 26 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/bindings/rust/lib.rs: -------------------------------------------------------------------------------- 1 | // -*- coding: utf-8 -*- 2 | // ------------------------------------------------------------------------------------------------ 3 | // Copyright © 2021, tree-sitter-cpp authors. 4 | // See the LICENSE file in this repo for license details. 5 | // ------------------------------------------------------------------------------------------------ 6 | 7 | //! This crate provides a Cpp grammar for the [tree-sitter][] parsing library. 8 | //! 9 | //! Typically, you will use the [language][language func] function to add this grammar to a 10 | //! tree-sitter [Parser][], and then use the parser to parse some code: 11 | //! 12 | //! ``` 13 | //! use tree_sitter::Parser; 14 | //! 15 | //! let code = r#" 16 | //! int double(int x) { 17 | //! return x * 2; 18 | //! } 19 | //! "#; 20 | //! let mut parser = Parser::new(); 21 | //! parser.set_language(tree_sitter_cpp::language()).expect("Error loading Cpp grammar"); 22 | //! let parsed = parser.parse(code, None); 23 | //! # let parsed = parsed.unwrap(); 24 | //! # let root = parsed.root_node(); 25 | //! # assert!(!root.has_error()); 26 | //! ``` 27 | //! 28 | //! [Language]: https://docs.rs/tree-sitter/*/tree_sitter/struct.Language.html 29 | //! [language func]: fn.language.html 30 | //! [Parser]: https://docs.rs/tree-sitter/*/tree_sitter/struct.Parser.html 31 | //! [tree-sitter]: https://tree-sitter.github.io/ 32 | 33 | use tree_sitter::Language; 34 | 35 | extern "C" { 36 | fn tree_sitter_cpp() -> Language; 37 | } 38 | 39 | /// Returns the tree-sitter [Language][] for this grammar. 
40 | /// 41 | /// [Language]: https://docs.rs/tree-sitter/*/tree_sitter/struct.Language.html 42 | pub fn language() -> Language { 43 | unsafe { tree_sitter_cpp() } 44 | } 45 | 46 | /// The source of the Cpp tree-sitter grammar description. 47 | pub const GRAMMAR: &str = include_str!("../../grammar.js"); 48 | 49 | /// The syntax highlighting query for this language. 50 | pub const HIGHLIGHT_QUERY: &str = include_str!("../../queries/highlights.scm"); 51 | 52 | /// The content of the [`node-types.json`][] file for this grammar. 53 | /// 54 | /// [`node-types.json`]: https://tree-sitter.github.io/tree-sitter/using-parsers#static-node-types 55 | pub const NODE_TYPES: &str = include_str!("../../src/node-types.json"); 56 | 57 | #[cfg(test)] 58 | mod tests { 59 | #[test] 60 | fn can_load_grammar() { 61 | let mut parser = tree_sitter::Parser::new(); 62 | parser 63 | .set_language(super::language()) 64 | .expect("Error loading Cpp grammar"); 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/bindings/swift/TreeSitterCPP/cpp.h: -------------------------------------------------------------------------------- 1 | #ifndef TREE_SITTER_CPP_H_ 2 | #define TREE_SITTER_CPP_H_ 3 | 4 | typedef struct TSLanguage TSLanguage; 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | extern TSLanguage *tree_sitter_cpp(); 11 | 12 | #ifdef __cplusplus 13 | } 14 | #endif 15 | 16 | #endif // TREE_SITTER_CPP_H_ -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/examples/marker-index.h: -------------------------------------------------------------------------------- 1 | #ifndef MARKER_INDEX_H_ 2 | #define MARKER_INDEX_H_ 3 | 4 | #include 5 | #include 6 | #include "flat_set.h" 7 | #include "point.h" 8 | #include "range.h" 9 | 10 | class MarkerIndex { 11 | public: 12 | using MarkerId = unsigned; 13 | using MarkerIdSet = flat_set; 14 | 15 | struct SpliceResult { 16 | flat_set touch; 17 | flat_set inside; 18 | flat_set overlap; 19 | flat_set surround; 20 | }; 21 | 22 | struct Boundary { 23 | Point position; 24 | flat_set starting; 25 | flat_set ending; 26 | }; 27 | 28 | struct BoundaryQueryResult { 29 | std::vector containing_start; 30 | std::vector boundaries; 31 | }; 32 | 33 | MarkerIndex(unsigned seed = 0u); 34 | ~MarkerIndex(); 35 | int generate_random_number(); 36 | void insert(MarkerId id, Point start, Point end); 37 | void set_exclusive(MarkerId id, bool exclusive); 38 | void remove(MarkerId id); 39 | bool has(MarkerId id); 40 | SpliceResult splice(Point start, Point old_extent, Point new_extent); 41 | Point get_start(MarkerId id) const; 42 | Point get_end(MarkerId id) const; 43 | Range get_range(MarkerId id) const; 44 | 45 | int compare(MarkerId id1, MarkerId id2) const; 46 | flat_set find_intersecting(Point start, Point end); 47 | flat_set find_containing(Point start, Point end); 48 | flat_set find_contained_in(Point start, Point end); 49 | flat_set find_starting_in(Point start, Point end); 50 | flat_set find_starting_at(Point position); 51 | flat_set find_ending_in(Point start, Point end); 52 | flat_set find_ending_at(Point position); 53 | BoundaryQueryResult find_boundaries_after(Point start, size_t max_count); 54 | 55 | std::unordered_map dump(); 56 | 57 | private: 58 | friend class Iterator; 59 | 60 | struct Node { 61 | Node *parent; 62 | Node *left; 63 | Node *right; 64 | Point left_extent; 65 | flat_set left_marker_ids; 66 | flat_set right_marker_ids; 67 | flat_set 
start_marker_ids; 68 | flat_set end_marker_ids; 69 | int priority; 70 | 71 | Node(Node *parent, Point left_extent); 72 | bool is_marker_endpoint(); 73 | }; 74 | 75 | class Iterator { 76 | public: 77 | Iterator(MarkerIndex *marker_index); 78 | void reset(); 79 | Node* insert_marker_start(const MarkerId &id, const Point &start_position, const Point &end_position); 80 | Node* insert_marker_end(const MarkerId &id, const Point &start_position, const Point &end_position); 81 | Node* insert_splice_boundary(const Point &position, bool is_insertion_end); 82 | void find_intersecting(const Point &start, const Point &end, flat_set *result); 83 | void find_contained_in(const Point &start, const Point &end, flat_set *result); 84 | void find_starting_in(const Point &start, const Point &end, flat_set *result); 85 | void find_ending_in(const Point &start, const Point &end, flat_set *result); 86 | void find_boundaries_after(Point start, size_t max_count, BoundaryQueryResult *result); 87 | std::unordered_map dump(); 88 | 89 | private: 90 | void ascend(); 91 | void descend_left(); 92 | void descend_right(); 93 | void move_to_successor(); 94 | void seek_to_first_node_greater_than_or_equal_to(const Point &position); 95 | void mark_right(const MarkerId &id, const Point &start_position, const Point &end_position); 96 | void mark_left(const MarkerId &id, const Point &start_position, const Point &end_position); 97 | Node* insert_left_child(const Point &position); 98 | Node* insert_right_child(const Point &position); 99 | void check_intersection(const Point &start, const Point &end, flat_set *results); 100 | void cache_node_position() const; 101 | 102 | MarkerIndex *marker_index; 103 | Node *current_node; 104 | Point current_node_position; 105 | Point left_ancestor_position; 106 | Point right_ancestor_position; 107 | std::vector left_ancestor_position_stack; 108 | std::vector right_ancestor_position_stack; 109 | }; 110 | 111 | Point get_node_position(const Node *node) const; 112 | void delete_node(Node *node); 113 | void delete_subtree(Node *node); 114 | void bubble_node_up(Node *node); 115 | void bubble_node_down(Node *node); 116 | void rotate_node_left(Node *pivot); 117 | void rotate_node_right(Node *pivot); 118 | void get_starting_and_ending_markers_within_subtree(const Node *node, flat_set *starting, flat_set *ending); 119 | void populate_splice_invalidation_sets(SpliceResult *invalidated, const Node *start_node, const Node *end_node, const flat_set &starting_inside_splice, const flat_set &ending_inside_splice); 120 | 121 | std::default_random_engine random_engine; 122 | std::uniform_int_distribution random_distribution; 123 | Node *root; 124 | std::unordered_map start_nodes_by_id; 125 | std::unordered_map end_nodes_by_id; 126 | Iterator iterator; 127 | flat_set exclusive_marker_ids; 128 | mutable std::unordered_map node_position_cache; 129 | }; 130 | 131 | #endif // MARKER_INDEX_H_ 132 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tree-sitter-cpp", 3 | "version": "0.20.0", 4 | "description": "C++ grammar for tree-sitter", 5 | "main": "bindings/node", 6 | "keywords": [ 7 | "parser", 8 | "c++" 9 | ], 10 | "repository": { 11 | "type": "git", 12 | "url": "https://github.com/tree-sitter/tree-sitter-cpp.git" 13 | }, 14 | "author": "Max Brunsfeld", 15 | "license": "MIT", 16 | "dependencies": { 17 | "nan": "^2.14.0" 18 | }, 19 | 
"devDependencies": { 20 | "tree-sitter-c": "^0.20.2", 21 | "tree-sitter-cli": "^0.20.0" 22 | }, 23 | "scripts": { 24 | "test": "tree-sitter test && tree-sitter parse examples/* --quiet --time", 25 | "test-windows": "tree-sitter test" 26 | }, 27 | "tree-sitter": [ 28 | { 29 | "scope": "source.cpp", 30 | "file-types": [ 31 | "cc", 32 | "cpp", 33 | "hpp", 34 | "h" 35 | ], 36 | "highlights": [ 37 | "queries/highlights.scm", 38 | "node_modules/tree-sitter-c/queries/highlights.scm" 39 | ], 40 | "injections": "queries/injections.scm", 41 | "injection-regex": "^(cc|cpp)$" 42 | } 43 | ] 44 | } 45 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/queries/highlights.scm: -------------------------------------------------------------------------------- 1 | ; Functions 2 | 3 | (call_expression 4 | function: (qualified_identifier 5 | name: (identifier) @function)) 6 | 7 | (template_function 8 | name: (identifier) @function) 9 | 10 | (template_method 11 | name: (field_identifier) @function) 12 | 13 | (template_function 14 | name: (identifier) @function) 15 | 16 | (function_declarator 17 | declarator: (qualified_identifier 18 | name: (identifier) @function)) 19 | 20 | (function_declarator 21 | declarator: (qualified_identifier 22 | name: (identifier) @function)) 23 | 24 | (function_declarator 25 | declarator: (field_identifier) @function) 26 | 27 | ; Types 28 | 29 | ((namespace_identifier) @type 30 | (#match? @type "^[A-Z]")) 31 | 32 | (auto) @type 33 | 34 | ; Constants 35 | 36 | (this) @variable.builtin 37 | (nullptr) @constant 38 | 39 | ; Keywords 40 | 41 | [ 42 | "catch" 43 | "class" 44 | "co_await" 45 | "co_return" 46 | "co_yield" 47 | "constexpr" 48 | "constinit" 49 | "consteval" 50 | "delete" 51 | "explicit" 52 | "final" 53 | "friend" 54 | "mutable" 55 | "namespace" 56 | "noexcept" 57 | "new" 58 | "override" 59 | "private" 60 | "protected" 61 | "public" 62 | "template" 63 | "throw" 64 | "try" 65 | "typename" 66 | "using" 67 | "virtual" 68 | "concept" 69 | "requires" 70 | ] @keyword 71 | 72 | ; Strings 73 | 74 | (raw_string_literal) @string 75 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/queries/injections.scm: -------------------------------------------------------------------------------- 1 | (raw_string_literal 2 | delimiter: (raw_string_delimiter) @injection.language 3 | (raw_string_content) @injection.content) 4 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/src/scanner.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | namespace { 8 | 9 | using std::wstring; 10 | using std::iswspace; 11 | 12 | enum TokenType { 13 | RAW_STRING_DELIMITER, 14 | RAW_STRING_CONTENT, 15 | }; 16 | 17 | // The spec limits delimiters to 16 chars, enforce this to bound serialization. 18 | const unsigned RAW_STRING_DELIMITER_MAX = 16; 19 | 20 | struct Scanner { 21 | // Last raw_string_delimiter, used to detect when raw_string_content ends. 22 | wstring delimiter; 23 | 24 | // Scan a raw string delimiter in R"delimiter(content)delimiter". 25 | bool scan_raw_string_delimiter(TSLexer *lexer) { 26 | if (!delimiter.empty()) { 27 | // Closing delimiter: must exactly match the opening delimiter. 28 | // We already checked this when scanning content, but this is how we know 29 | // when to stop. 
We can't stop at ", because R"""hello""" is valid. 30 | for (std::size_t i = 0; i < delimiter.size(); ++i) { 31 | if (lexer->lookahead != delimiter[i]) 32 | return false; 33 | lexer->advance(lexer, false); 34 | } 35 | delimiter.clear(); 36 | return true; 37 | } 38 | // Opening delimiter: record the d-char-sequence up to (. 39 | // d-char is any basic character except parens, backslashes, and spaces. 40 | for (;;) { 41 | if (delimiter.size() > RAW_STRING_DELIMITER_MAX || 42 | lexer->eof(lexer) || lexer->lookahead == '\\' || iswspace(lexer->lookahead)) { 43 | return false; 44 | } 45 | if (lexer->lookahead == '(') { 46 | // Rather than create a token for an empty delimiter, we fail and let 47 | // the grammar fall back to a delimiter-less rule. 48 | return !delimiter.empty(); 49 | } 50 | delimiter += lexer->lookahead; 51 | lexer->advance(lexer, false); 52 | } 53 | } 54 | 55 | // Scan the raw string content in R"delimiter(content)delimiter". 56 | bool scan_raw_string_content(TSLexer *lexer) { 57 | // The progress made through the delimiter since the last ')'. 58 | // The delimiter may not contain ')' so a single counter suffices. 59 | int delimiter_index = -1; 60 | for (;;) { 61 | // If we hit EOF, consider the content to terminate there. 62 | // This forms an incomplete raw_string_literal, and models the code well. 63 | if (lexer->eof(lexer)) { 64 | lexer->mark_end(lexer); 65 | return true; 66 | } 67 | 68 | if (delimiter_index >= 0) { 69 | if (static_cast(delimiter_index) == delimiter.size()) { 70 | if (lexer->lookahead == '"') { 71 | return true; 72 | } else { 73 | delimiter_index = -1; 74 | } 75 | } else { 76 | if (lexer->lookahead == delimiter[delimiter_index]) { 77 | delimiter_index++; 78 | } else { 79 | delimiter_index = -1; 80 | } 81 | } 82 | } 83 | 84 | if (delimiter_index == -1 && lexer->lookahead == ')') { 85 | // The content doesn't include the )delimiter" part. 86 | // We must still scan through it, but exclude it from the token. 87 | lexer->mark_end(lexer); 88 | delimiter_index = 0; 89 | } 90 | 91 | lexer->advance(lexer, false); 92 | } 93 | } 94 | 95 | bool scan(TSLexer *lexer, const bool *valid_symbols) { 96 | // No skipping leading whitespace: raw-string grammar is space-sensitive. 
97 | 98 | if (valid_symbols[RAW_STRING_DELIMITER]) { 99 | lexer->result_symbol = RAW_STRING_DELIMITER; 100 | return scan_raw_string_delimiter(lexer); 101 | } 102 | 103 | if (valid_symbols[RAW_STRING_CONTENT]) { 104 | lexer->result_symbol = RAW_STRING_CONTENT; 105 | return scan_raw_string_content(lexer); 106 | } 107 | 108 | return false; 109 | } 110 | }; 111 | 112 | } 113 | 114 | extern "C" { 115 | 116 | void *tree_sitter_cpp_external_scanner_create() { 117 | return new Scanner(); 118 | } 119 | 120 | bool tree_sitter_cpp_external_scanner_scan(void *payload, TSLexer *lexer, 121 | const bool *valid_symbols) { 122 | Scanner *scanner = static_cast(payload); 123 | return scanner->scan(lexer, valid_symbols); 124 | } 125 | 126 | unsigned tree_sitter_cpp_external_scanner_serialize(void *payload, char *buffer) { 127 | #if __cpp_static_assert >= 200410L 128 | static_assert(RAW_STRING_DELIMITER_MAX * sizeof(wchar_t) < 129 | TREE_SITTER_SERIALIZATION_BUFFER_SIZE, 130 | "Raw string delimiters may not be serializable!"); 131 | #endif 132 | 133 | Scanner *scanner = static_cast(payload); 134 | size_t size = scanner->delimiter.size() * sizeof(wchar_t); 135 | memcpy(buffer, scanner->delimiter.data(), size); 136 | return size; 137 | } 138 | 139 | void tree_sitter_cpp_external_scanner_deserialize(void *payload, const char *buffer, unsigned length) { 140 | assert(length % sizeof(wchar_t) == 0 && "Can't decode serialized delimiter!"); 141 | 142 | Scanner *scanner = static_cast(payload); 143 | scanner->delimiter.resize(length/sizeof(wchar_t)); 144 | memcpy(&scanner->delimiter[0], buffer, length); 145 | } 146 | 147 | void tree_sitter_cpp_external_scanner_destroy(void *payload) { 148 | Scanner *scanner = static_cast(payload); 149 | delete scanner; 150 | } 151 | 152 | } 153 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/src/tree_sitter/parser.h: -------------------------------------------------------------------------------- 1 | #ifndef TREE_SITTER_PARSER_H_ 2 | #define TREE_SITTER_PARSER_H_ 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #define ts_builtin_sym_error ((TSSymbol)-1) 13 | #define ts_builtin_sym_end 0 14 | #define TREE_SITTER_SERIALIZATION_BUFFER_SIZE 1024 15 | 16 | typedef uint16_t TSStateId; 17 | 18 | #ifndef TREE_SITTER_API_H_ 19 | typedef uint16_t TSSymbol; 20 | typedef uint16_t TSFieldId; 21 | typedef struct TSLanguage TSLanguage; 22 | #endif 23 | 24 | typedef struct { 25 | TSFieldId field_id; 26 | uint8_t child_index; 27 | bool inherited; 28 | } TSFieldMapEntry; 29 | 30 | typedef struct { 31 | uint16_t index; 32 | uint16_t length; 33 | } TSFieldMapSlice; 34 | 35 | typedef struct { 36 | bool visible; 37 | bool named; 38 | bool supertype; 39 | } TSSymbolMetadata; 40 | 41 | typedef struct TSLexer TSLexer; 42 | 43 | struct TSLexer { 44 | int32_t lookahead; 45 | TSSymbol result_symbol; 46 | void (*advance)(TSLexer *, bool); 47 | void (*mark_end)(TSLexer *); 48 | uint32_t (*get_column)(TSLexer *); 49 | bool (*is_at_included_range_start)(const TSLexer *); 50 | bool (*eof)(const TSLexer *); 51 | }; 52 | 53 | typedef enum { 54 | TSParseActionTypeShift, 55 | TSParseActionTypeReduce, 56 | TSParseActionTypeAccept, 57 | TSParseActionTypeRecover, 58 | } TSParseActionType; 59 | 60 | typedef union { 61 | struct { 62 | uint8_t type; 63 | TSStateId state; 64 | bool extra; 65 | bool repetition; 66 | } shift; 67 | struct { 68 | uint8_t type; 69 | uint8_t child_count; 70 | TSSymbol symbol; 71 | 
int16_t dynamic_precedence; 72 | uint16_t production_id; 73 | } reduce; 74 | uint8_t type; 75 | } TSParseAction; 76 | 77 | typedef struct { 78 | uint16_t lex_state; 79 | uint16_t external_lex_state; 80 | } TSLexMode; 81 | 82 | typedef union { 83 | TSParseAction action; 84 | struct { 85 | uint8_t count; 86 | bool reusable; 87 | } entry; 88 | } TSParseActionEntry; 89 | 90 | struct TSLanguage { 91 | uint32_t version; 92 | uint32_t symbol_count; 93 | uint32_t alias_count; 94 | uint32_t token_count; 95 | uint32_t external_token_count; 96 | uint32_t state_count; 97 | uint32_t large_state_count; 98 | uint32_t production_id_count; 99 | uint32_t field_count; 100 | uint16_t max_alias_sequence_length; 101 | const uint16_t *parse_table; 102 | const uint16_t *small_parse_table; 103 | const uint32_t *small_parse_table_map; 104 | const TSParseActionEntry *parse_actions; 105 | const char * const *symbol_names; 106 | const char * const *field_names; 107 | const TSFieldMapSlice *field_map_slices; 108 | const TSFieldMapEntry *field_map_entries; 109 | const TSSymbolMetadata *symbol_metadata; 110 | const TSSymbol *public_symbol_map; 111 | const uint16_t *alias_map; 112 | const TSSymbol *alias_sequences; 113 | const TSLexMode *lex_modes; 114 | bool (*lex_fn)(TSLexer *, TSStateId); 115 | bool (*keyword_lex_fn)(TSLexer *, TSStateId); 116 | TSSymbol keyword_capture_token; 117 | struct { 118 | const bool *states; 119 | const TSSymbol *symbol_map; 120 | void *(*create)(void); 121 | void (*destroy)(void *); 122 | bool (*scan)(void *, TSLexer *, const bool *symbol_whitelist); 123 | unsigned (*serialize)(void *, char *); 124 | void (*deserialize)(void *, const char *, unsigned); 125 | } external_scanner; 126 | const TSStateId *primary_state_ids; 127 | }; 128 | 129 | /* 130 | * Lexer Macros 131 | */ 132 | 133 | #define START_LEXER() \ 134 | bool result = false; \ 135 | bool skip = false; \ 136 | bool eof = false; \ 137 | int32_t lookahead; \ 138 | goto start; \ 139 | next_state: \ 140 | lexer->advance(lexer, skip); \ 141 | start: \ 142 | skip = false; \ 143 | lookahead = lexer->lookahead; 144 | 145 | #define ADVANCE(state_value) \ 146 | { \ 147 | state = state_value; \ 148 | goto next_state; \ 149 | } 150 | 151 | #define SKIP(state_value) \ 152 | { \ 153 | skip = true; \ 154 | state = state_value; \ 155 | goto next_state; \ 156 | } 157 | 158 | #define ACCEPT_TOKEN(symbol_value) \ 159 | result = true; \ 160 | lexer->result_symbol = symbol_value; \ 161 | lexer->mark_end(lexer); 162 | 163 | #define END_STATE() return result; 164 | 165 | /* 166 | * Parse Table Macros 167 | */ 168 | 169 | #define SMALL_STATE(id) id - LARGE_STATE_COUNT 170 | 171 | #define STATE(id) id 172 | 173 | #define ACTIONS(id) id 174 | 175 | #define SHIFT(state_value) \ 176 | {{ \ 177 | .shift = { \ 178 | .type = TSParseActionTypeShift, \ 179 | .state = state_value \ 180 | } \ 181 | }} 182 | 183 | #define SHIFT_REPEAT(state_value) \ 184 | {{ \ 185 | .shift = { \ 186 | .type = TSParseActionTypeShift, \ 187 | .state = state_value, \ 188 | .repetition = true \ 189 | } \ 190 | }} 191 | 192 | #define SHIFT_EXTRA() \ 193 | {{ \ 194 | .shift = { \ 195 | .type = TSParseActionTypeShift, \ 196 | .extra = true \ 197 | } \ 198 | }} 199 | 200 | #define REDUCE(symbol_val, child_count_val, ...) 
\ 201 | {{ \ 202 | .reduce = { \ 203 | .type = TSParseActionTypeReduce, \ 204 | .symbol = symbol_val, \ 205 | .child_count = child_count_val, \ 206 | __VA_ARGS__ \ 207 | }, \ 208 | }} 209 | 210 | #define RECOVER() \ 211 | {{ \ 212 | .type = TSParseActionTypeRecover \ 213 | }} 214 | 215 | #define ACCEPT_INPUT() \ 216 | {{ \ 217 | .type = TSParseActionTypeAccept \ 218 | }} 219 | 220 | #ifdef __cplusplus 221 | } 222 | #endif 223 | 224 | #endif // TREE_SITTER_PARSER_H_ 225 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/src/tree_sitter/runtime.h: -------------------------------------------------------------------------------- 1 | #ifndef TREE_SITTER_RUNTIME_H_ 2 | #define TREE_SITTER_RUNTIME_H_ 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | #include 9 | #include 10 | 11 | typedef unsigned short TSSymbol; 12 | typedef struct TSLanguage TSLanguage; 13 | typedef struct TSDocument TSDocument; 14 | 15 | typedef enum { 16 | TSInputEncodingUTF8, 17 | TSInputEncodingUTF16, 18 | } TSInputEncoding; 19 | 20 | typedef struct { 21 | void *payload; 22 | const char *(*read_fn)(void *payload, size_t *bytes_read); 23 | int (*seek_fn)(void *payload, size_t character, size_t byte); 24 | TSInputEncoding encoding; 25 | } TSInput; 26 | 27 | typedef enum { 28 | TSDebugTypeParse, 29 | TSDebugTypeLex, 30 | } TSDebugType; 31 | 32 | typedef struct { 33 | void *payload; 34 | void (*debug_fn)(void *payload, TSDebugType, const char *); 35 | } TSDebugger; 36 | 37 | typedef struct { 38 | size_t position; 39 | size_t chars_inserted; 40 | size_t chars_removed; 41 | } TSInputEdit; 42 | 43 | typedef struct { 44 | size_t row; 45 | size_t column; 46 | } TSPoint; 47 | 48 | typedef struct { 49 | const void *data; 50 | size_t offset[3]; 51 | } TSNode; 52 | 53 | typedef struct { 54 | TSSymbol value; 55 | bool done; 56 | void *data; 57 | } TSSymbolIterator; 58 | 59 | size_t ts_node_start_char(TSNode); 60 | size_t ts_node_start_byte(TSNode); 61 | TSPoint ts_node_start_point(TSNode); 62 | size_t ts_node_end_char(TSNode); 63 | size_t ts_node_end_byte(TSNode); 64 | TSPoint ts_node_end_point(TSNode); 65 | TSSymbol ts_node_symbol(TSNode); 66 | TSSymbolIterator ts_node_symbols(TSNode); 67 | void ts_symbol_iterator_next(TSSymbolIterator *); 68 | const char *ts_node_name(TSNode, const TSDocument *); 69 | char *ts_node_string(TSNode, const TSDocument *); 70 | bool ts_node_eq(TSNode, TSNode); 71 | bool ts_node_is_named(TSNode); 72 | bool ts_node_has_changes(TSNode); 73 | TSNode ts_node_parent(TSNode); 74 | TSNode ts_node_child(TSNode, size_t); 75 | TSNode ts_node_named_child(TSNode, size_t); 76 | size_t ts_node_child_count(TSNode); 77 | size_t ts_node_named_child_count(TSNode); 78 | TSNode ts_node_next_sibling(TSNode); 79 | TSNode ts_node_next_named_sibling(TSNode); 80 | TSNode ts_node_prev_sibling(TSNode); 81 | TSNode ts_node_prev_named_sibling(TSNode); 82 | TSNode ts_node_descendant_for_range(TSNode, size_t, size_t); 83 | TSNode ts_node_named_descendant_for_range(TSNode, size_t, size_t); 84 | 85 | TSDocument *ts_document_make(); 86 | void ts_document_free(TSDocument *); 87 | const TSLanguage *ts_document_language(TSDocument *); 88 | void ts_document_set_language(TSDocument *, const TSLanguage *); 89 | TSInput ts_document_input(TSDocument *); 90 | void ts_document_set_input(TSDocument *, TSInput); 91 | void ts_document_set_input_string(TSDocument *, const char *); 92 | TSDebugger ts_document_debugger(const TSDocument *); 93 | void ts_document_set_debugger(TSDocument 
*, TSDebugger); 94 | void ts_document_print_debugging_graphs(TSDocument *, bool); 95 | void ts_document_edit(TSDocument *, TSInputEdit); 96 | int ts_document_parse(TSDocument *); 97 | void ts_document_invalidate(TSDocument *); 98 | TSNode ts_document_root_node(const TSDocument *); 99 | size_t ts_document_parse_count(const TSDocument *); 100 | 101 | size_t ts_language_symbol_count(const TSLanguage *); 102 | const char *ts_language_symbol_name(const TSLanguage *, TSSymbol); 103 | 104 | #define ts_builtin_sym_error ((TSSymbol)-1) 105 | #define ts_builtin_sym_end 0 106 | #define ts_builtin_sym_start 1 107 | 108 | #ifdef __cplusplus 109 | } 110 | #endif 111 | 112 | #endif // TREE_SITTER_RUNTIME_H_ 113 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/test/corpus/ambiguities.txt: -------------------------------------------------------------------------------- 1 | ================================================ 2 | template functions vs relational expressions 3 | ================================================ 4 | 5 | T1 a = b < c > d; 6 | T2 e = f(g); 7 | int a = std::get<0>(t); 8 | 9 | --- 10 | 11 | (translation_unit 12 | (declaration 13 | (type_identifier) 14 | (init_declarator 15 | (identifier) 16 | (binary_expression 17 | (binary_expression (identifier) (identifier)) 18 | (identifier)))) 19 | (declaration 20 | (type_identifier) 21 | (init_declarator 22 | (identifier) 23 | (call_expression 24 | (template_function (identifier) (template_argument_list 25 | (type_descriptor (type_identifier)))) 26 | (argument_list (identifier))))) 27 | (declaration 28 | (primitive_type) 29 | (init_declarator 30 | (identifier) 31 | (call_expression 32 | (qualified_identifier 33 | (namespace_identifier) 34 | (template_function 35 | (identifier) 36 | (template_argument_list (number_literal)))) 37 | (argument_list (identifier)))))) 38 | 39 | ================================================= 40 | function declarations vs variable initializations 41 | ================================================= 42 | 43 | // Function declarations 44 | T1 a(T2 *b); 45 | T3 c(T4 &d, T5 &&e); 46 | 47 | // Variable declarations with initializers 48 | T7 f(g.h); 49 | T6 i{j}; 50 | 51 | --- 52 | 53 | (translation_unit 54 | (comment) 55 | (declaration 56 | (type_identifier) 57 | (function_declarator 58 | (identifier) 59 | (parameter_list (parameter_declaration (type_identifier) (pointer_declarator (identifier)))))) 60 | (declaration 61 | (type_identifier) 62 | (function_declarator 63 | (identifier) 64 | (parameter_list 65 | (parameter_declaration (type_identifier) (reference_declarator (identifier))) 66 | (parameter_declaration (type_identifier) (reference_declarator (identifier)))))) 67 | 68 | (comment) 69 | (declaration 70 | (type_identifier) 71 | (init_declarator 72 | (identifier) 73 | (argument_list (field_expression (identifier) (field_identifier))))) 74 | (declaration 75 | (type_identifier) 76 | (init_declarator 77 | (identifier) 78 | (initializer_list (identifier))))) 79 | 80 | ================================================ 81 | template classes vs relational expressions 82 | ================================================ 83 | 84 | int main() { 85 | T1 v1; 86 | T1 v2 = v3; 87 | } 88 | 89 | --- 90 | 91 | (translation_unit (function_definition 92 | (primitive_type) 93 | (function_declarator (identifier) (parameter_list)) 94 | (compound_statement 95 | (declaration 96 | (template_type (type_identifier) 97 | (template_argument_list (type_descriptor (type_identifier)))) 
98 | (identifier)) 99 | (declaration 100 | (template_type (type_identifier) 101 | (template_argument_list (type_descriptor (type_identifier)))) 102 | (init_declarator (identifier) (identifier)))))) 103 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/test/corpus/definitions.txt: -------------------------------------------------------------------------------- 1 | ===================================== 2 | Scoped function definitions 3 | ===================================== 4 | 5 | int T::foo() { return 1; } 6 | int T::foo() const { return 0; } 7 | 8 | --- 9 | 10 | (translation_unit 11 | (function_definition 12 | (primitive_type) 13 | (function_declarator 14 | (qualified_identifier (namespace_identifier) (identifier)) 15 | (parameter_list)) 16 | (compound_statement (return_statement (number_literal)))) 17 | (function_definition 18 | (primitive_type) 19 | (function_declarator 20 | (qualified_identifier (namespace_identifier) (identifier)) 21 | (parameter_list) 22 | (type_qualifier)) 23 | (compound_statement (return_statement (number_literal))))) 24 | 25 | ===================================== 26 | Constructor definitions 27 | ===================================== 28 | 29 | T::T() {} 30 | 31 | T::T() : f1(0), f2(1, 2) { 32 | puts("HI"); 33 | } 34 | 35 | T::T() : Base() {} 36 | 37 | T::T() try : f1(0) {} catch(...) {} 38 | 39 | --- 40 | 41 | (translation_unit 42 | (function_definition 43 | (function_declarator 44 | (qualified_identifier (namespace_identifier) (identifier)) 45 | (parameter_list)) 46 | (compound_statement)) 47 | (function_definition 48 | (function_declarator 49 | (qualified_identifier (namespace_identifier) (identifier)) 50 | (parameter_list)) 51 | (field_initializer_list 52 | (field_initializer (field_identifier) (argument_list (number_literal))) 53 | (field_initializer (field_identifier) (argument_list (number_literal) (number_literal)))) 54 | (compound_statement 55 | (expression_statement (call_expression (identifier) (argument_list (string_literal)))))) 56 | (function_definition 57 | (function_declarator 58 | (qualified_identifier (namespace_identifier) (identifier)) 59 | (parameter_list)) 60 | (field_initializer_list 61 | (field_initializer 62 | (template_method 63 | (field_identifier) 64 | (template_argument_list (type_descriptor (type_identifier)))) 65 | (argument_list))) 66 | (compound_statement)) 67 | (function_definition 68 | (function_declarator 69 | (qualified_identifier 70 | (namespace_identifier) 71 | (identifier)) 72 | (parameter_list)) 73 | (try_statement 74 | (field_initializer_list 75 | (field_initializer 76 | (field_identifier) 77 | (argument_list 78 | (number_literal)))) 79 | (compound_statement) 80 | (catch_clause 81 | (parameter_list) 82 | (compound_statement))))) 83 | 84 | ===================================== 85 | Explicit constructor definitions 86 | ===================================== 87 | 88 | class C { 89 | explicit C(int f) : f_(f) {} 90 | 91 | private: 92 | int f_; 93 | }; 94 | 95 | --- 96 | 97 | (translation_unit 98 | (class_specifier 99 | (type_identifier) 100 | (field_declaration_list 101 | (function_definition 102 | (explicit_function_specifier) 103 | (function_declarator 104 | (identifier) 105 | (parameter_list (parameter_declaration (primitive_type) (identifier)))) 106 | (field_initializer_list 107 | (field_initializer (field_identifier) (argument_list (identifier)))) 108 | (compound_statement)) 109 | (access_specifier) 110 | (field_declaration (primitive_type) 
(field_identifier))))) 111 | 112 | ===================================== 113 | Explicit constructor declaration 114 | ===================================== 115 | 116 | class C { 117 | explicit C(int f); 118 | explicit(true) C(long f); 119 | }; 120 | 121 | --- 122 | 123 | (translation_unit 124 | (class_specifier 125 | (type_identifier) 126 | (field_declaration_list 127 | (declaration 128 | (explicit_function_specifier) 129 | (function_declarator (identifier) (parameter_list (parameter_declaration (primitive_type) (identifier))))) 130 | (declaration 131 | (explicit_function_specifier (true)) 132 | (function_declarator (identifier) (parameter_list (parameter_declaration (sized_type_specifier) (identifier)))))))) 133 | 134 | ===================================== 135 | Default and deleted methods 136 | ===================================== 137 | 138 | class A : public B { 139 | A() = default; 140 | A(A &&) = delete; 141 | void f() = delete; 142 | A& operator=(const A&) = default; 143 | A& operator=(A&&) = delete; 144 | }; 145 | 146 | --- 147 | 148 | (translation_unit 149 | (class_specifier 150 | (type_identifier) 151 | (base_class_clause 152 | (access_specifier) 153 | (type_identifier)) 154 | (field_declaration_list 155 | (function_definition 156 | (function_declarator (identifier) (parameter_list)) 157 | (default_method_clause)) 158 | (function_definition 159 | (function_declarator 160 | (identifier) 161 | (parameter_list (parameter_declaration (type_identifier) (abstract_reference_declarator)))) 162 | (delete_method_clause)) 163 | (function_definition 164 | (primitive_type) 165 | (function_declarator (field_identifier) (parameter_list)) (delete_method_clause)) 166 | (function_definition 167 | (type_identifier) 168 | (reference_declarator 169 | (function_declarator 170 | (operator_name) 171 | (parameter_list (parameter_declaration (type_qualifier) (type_identifier) (abstract_reference_declarator))))) 172 | (default_method_clause)) 173 | (function_definition 174 | (type_identifier) 175 | (reference_declarator 176 | (function_declarator 177 | (operator_name) 178 | (parameter_list (parameter_declaration (type_identifier) (abstract_reference_declarator))))) 179 | (delete_method_clause))))) 180 | 181 | ===================================== 182 | Destructor definitions 183 | ===================================== 184 | 185 | ~T() {} 186 | T::~T() {} 187 | 188 | --- 189 | 190 | (translation_unit 191 | (function_definition 192 | (function_declarator (destructor_name (identifier)) (parameter_list)) 193 | (compound_statement)) 194 | (function_definition 195 | (function_declarator 196 | (qualified_identifier (namespace_identifier) (destructor_name (identifier))) (parameter_list)) 197 | (compound_statement))) 198 | 199 | ===================================== 200 | Function-try-block definitions 201 | ===================================== 202 | 203 | void foo() try {} catch(...) {} 204 | 205 | --- 206 | 207 | (translation_unit 208 | (function_definition 209 | (primitive_type) 210 | (function_declarator 211 | (identifier) 212 | (parameter_list)) 213 | (try_statement 214 | (compound_statement) 215 | (catch_clause 216 | (parameter_list) 217 | (compound_statement))))) 218 | 219 | 220 | ===================================== 221 | Conversion operator definitions 222 | ===================================== 223 | 224 | T::operator int() try { throw 1; } catch (...) 
{ return 2; } 225 | 226 | --- 227 | 228 | (translation_unit 229 | (function_definition 230 | (qualified_identifier 231 | (namespace_identifier) 232 | (operator_cast 233 | (primitive_type) 234 | (abstract_function_declarator 235 | (parameter_list)))) 236 | (try_statement 237 | (compound_statement 238 | (throw_statement 239 | (number_literal))) 240 | (catch_clause 241 | (parameter_list) 242 | (compound_statement 243 | (return_statement 244 | (number_literal))))))) 245 | 246 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/test/corpus/microsoft.txt: -------------------------------------------------------------------------------- 1 | ================================ 2 | declaration specs 3 | ================================ 4 | 5 | struct __declspec(dllexport) s2 6 | { 7 | }; 8 | 9 | union __declspec(noinline) u2 { 10 | }; 11 | 12 | class __declspec(uuid) u2 { 13 | }; 14 | 15 | --- 16 | 17 | (translation_unit 18 | (struct_specifier 19 | (ms_declspec_modifier 20 | (identifier)) 21 | name: (type_identifier) 22 | body: (field_declaration_list)) 23 | (union_specifier 24 | (ms_declspec_modifier 25 | (identifier)) 26 | name: (type_identifier) 27 | body: (field_declaration_list)) 28 | (class_specifier 29 | (ms_declspec_modifier 30 | (identifier)) 31 | name: (type_identifier) 32 | body: (field_declaration_list))) 33 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/test/corpus/types.txt: -------------------------------------------------------------------------------- 1 | ========================================== 2 | The auto type 3 | ========================================== 4 | 5 | void foo() { 6 | auto x = 1; 7 | } 8 | 9 | --- 10 | 11 | (translation_unit 12 | (function_definition 13 | (primitive_type) 14 | (function_declarator (identifier) (parameter_list)) 15 | (compound_statement 16 | (declaration (placeholder_type_specifier (auto)) (init_declarator (identifier) (number_literal)))))) 17 | 18 | ========================================== 19 | Namespaced types 20 | ========================================== 21 | 22 | std::string my_string; 23 | std::vector::size_typ my_string; 24 | 25 | --- 26 | 27 | (translation_unit 28 | (declaration 29 | (qualified_identifier (namespace_identifier) (type_identifier)) 30 | (identifier)) 31 | (declaration 32 | (qualified_identifier 33 | (namespace_identifier) 34 | (qualified_identifier 35 | (template_type 36 | (type_identifier) 37 | (template_argument_list (type_descriptor (primitive_type)))) 38 | (type_identifier))) 39 | (identifier))) 40 | 41 | ========================================== 42 | Dependent type names 43 | ========================================== 44 | 45 | template 46 | struct X : B 47 | { 48 | typename T::A* pa; 49 | }; 50 | 51 | --- 52 | 53 | (translation_unit 54 | (template_declaration 55 | (template_parameter_list (type_parameter_declaration (type_identifier))) 56 | (struct_specifier 57 | (type_identifier) 58 | (base_class_clause 59 | (template_type (type_identifier) (template_argument_list (type_descriptor (type_identifier))))) 60 | (field_declaration_list 61 | (field_declaration 62 | (dependent_type (qualified_identifier (namespace_identifier) (type_identifier))) 63 | (pointer_declarator (field_identifier))))))) 64 | 65 | ========================================== 66 | Template types with empty argument lists 67 | ========================================== 68 | 69 | use_future_t<> use_future; 70 | 71 | --- 72 | 73 | 
(translation_unit 74 | (declaration (template_type (type_identifier) (template_argument_list)) (identifier))) 75 | 76 | ================================ 77 | Function types as template arguments 78 | ================================ 79 | 80 | typedef std::function MyFunc; 81 | typedef std::function b; 82 | 83 | --- 84 | 85 | (translation_unit 86 | (type_definition 87 | (qualified_identifier 88 | (namespace_identifier) 89 | (template_type 90 | (type_identifier) 91 | (template_argument_list 92 | (type_descriptor 93 | (type_identifier) 94 | (abstract_function_declarator (parameter_list 95 | (parameter_declaration (primitive_type)))))))) 96 | (type_identifier)) 97 | (type_definition 98 | (qualified_identifier 99 | (namespace_identifier) 100 | (template_type 101 | (type_identifier) 102 | (template_argument_list 103 | (type_descriptor 104 | (primitive_type) 105 | (abstract_function_declarator (parameter_list 106 | (parameter_declaration (primitive_type)))))))) 107 | (type_identifier))) 108 | 109 | ==================================================== 110 | Decltype 111 | ==================================================== 112 | 113 | decltype(A) x; 114 | decltype(B) foo(void x, decltype(C) y); 115 | template auto add(T t, U u) -> decltype(t + u); 116 | array arr; 117 | 118 | --- 119 | 120 | (translation_unit 121 | (declaration 122 | (decltype (identifier)) 123 | (identifier)) 124 | (declaration 125 | (decltype (identifier)) 126 | (function_declarator (identifier) 127 | (parameter_list 128 | (parameter_declaration (primitive_type) (identifier)) 129 | (parameter_declaration (decltype (identifier)) (identifier))))) 130 | (template_declaration 131 | (template_parameter_list 132 | (type_parameter_declaration (type_identifier)) (type_parameter_declaration (type_identifier))) 133 | (declaration 134 | (placeholder_type_specifier (auto)) 135 | (function_declarator 136 | (identifier) 137 | (parameter_list 138 | (parameter_declaration (type_identifier) (identifier)) 139 | (parameter_declaration (type_identifier) (identifier))) 140 | (trailing_return_type 141 | (type_descriptor 142 | (decltype (binary_expression (identifier) (identifier)))))))) 143 | (declaration 144 | (template_type 145 | (type_identifier) 146 | (template_argument_list 147 | (type_descriptor 148 | (qualified_identifier 149 | (decltype (identifier)) 150 | (type_identifier))) 151 | (number_literal))) 152 | (identifier))) 153 | 154 | ==================================================== 155 | Trailing return type 156 | ==================================================== 157 | 158 | auto a::foo() const -> const A& {} 159 | auto b::foo() const -> A const& {} 160 | 161 | --- 162 | 163 | (translation_unit 164 | (function_definition 165 | (placeholder_type_specifier (auto)) 166 | (function_declarator 167 | (qualified_identifier (namespace_identifier) (identifier)) 168 | (parameter_list) 169 | (type_qualifier) 170 | (trailing_return_type 171 | (type_descriptor 172 | (type_qualifier) 173 | (template_type (type_identifier) (template_argument_list (type_descriptor (type_identifier)))) 174 | (abstract_reference_declarator)))) 175 | (compound_statement)) 176 | (function_definition 177 | (placeholder_type_specifier (auto)) 178 | (function_declarator 179 | (qualified_identifier (namespace_identifier) (identifier)) 180 | (parameter_list) 181 | (type_qualifier) 182 | (trailing_return_type 183 | (type_descriptor 184 | (template_type (type_identifier) (template_argument_list (type_descriptor (type_identifier)))) 185 | (type_qualifier) 186 | 
(abstract_reference_declarator)))) 187 | (compound_statement)) 188 | ) 189 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/test/highlight/keywords.cpp: -------------------------------------------------------------------------------- 1 | using namespace std; 2 | // ^ keyword 3 | 4 | namespace foo {} 5 | // ^ keyword 6 | 7 | template 8 | // ^ keyword 9 | // ^ keyword 10 | 11 | class A { 12 | // <- keyword 13 | 14 | public: 15 | // <- keyword 16 | private: 17 | // <- keyword 18 | protected: 19 | // <- keyword 20 | virtual ~A() = 0; 21 | // <- keyword 22 | }; 23 | 24 | int main() { 25 | throw new Error(); 26 | // ^ keyword 27 | // ^ keyword 28 | 29 | try { 30 | // <- keyword 31 | } catch (e) { 32 | // <- keyword 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /preprocess/sitter-libs/cpp/test/highlight/names.cpp: -------------------------------------------------------------------------------- 1 | int main() { 2 | a(); 3 | // <- function 4 | 5 | a::b(); 6 | // ^ function 7 | 8 | a::b(); 9 | // ^ function 10 | 11 | this->b(); 12 | // ^ function 13 | 14 | auto x = y; 15 | // <- type 16 | 17 | vector a; 18 | // <- type 19 | 20 | std::vector a; 21 | // ^ type 22 | } 23 | 24 | class C : D{ 25 | A(); 26 | // <- function 27 | 28 | void efg() { 29 | // ^ function 30 | } 31 | } 32 | 33 | void A::b() { 34 | // ^ function 35 | } 36 | -------------------------------------------------------------------------------- /preprocess/tokenize.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from os import listdir, remove 3 | from os.path import splitext, join, isdir, exists 4 | 5 | import youtokentome as yttm 6 | from omegaconf import DictConfig, OmegaConf 7 | 8 | 9 | def tokenize(config: DictConfig): 10 | data_path = join(config.data_folder, config.dataset.name, config.dataset.dir) 11 | model_path = join(data_path, config.dataset.tokenizer_name) 12 | buffer_path = "text.yttm" 13 | if exists(buffer_path): 14 | remove(buffer_path) 15 | 16 | for file in listdir(join(data_path, config.train_holdout)): 17 | transformed_files_path = join(data_path, config.train_holdout, file) 18 | if isdir(transformed_files_path): 19 | for transformed_file in listdir(transformed_files_path): 20 | file_path = join(transformed_files_path, transformed_file) 21 | _, ext = splitext(file_path) 22 | if ext in [".cpp", ".c"]: 23 | with open(file_path, "r", encoding="utf8", errors='ignore') as file_: 24 | text = file_.read() + "\n" 25 | with open(buffer_path, "a") as buffer_: 26 | buffer_.write(text) 27 | 28 | _ = yttm.BPE.train( 29 | data="text.yttm", 30 | model=model_path, 31 | pad_id=config.dataset.pad_id, 32 | unk_id=config.dataset.unk_id, 33 | bos_id=config.dataset.bos_id, 34 | eos_id=config.dataset.eos_id, 35 | vocab_size=config.dataset.vocab_size, 36 | n_threads=config.num_workers 37 | ) 38 | 39 | remove("text.yttm") 40 | 41 | 42 | if __name__ == "__main__": 43 | arg_parser = ArgumentParser() 44 | arg_parser.add_argument("--config_path", type=str) 45 | args = arg_parser.parse_args() 46 | config_ = OmegaConf.load(args.config_path) 47 | tokenize(config=config_) 48 | -------------------------------------------------------------------------------- /preprocess/transformations/__init__.py: -------------------------------------------------------------------------------- 1 | from .block_swap_transformations import * 2 | from .confusion_remove import * 3 | 
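# Star-imports below re-export each transformation so they are available
# directly from this package.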
from .dead_code_inserter import * 4 | from .demo_transformation import * 5 | from .for_while_transformation import * 6 | from .operand_swap_transformations import * 7 | from .transformation_base import * 8 | from .no_transform import * 9 | from .syntactic_noising_transformation import * 10 | from .transformation_main import * 11 | from .var_renaming_transformation import * -------------------------------------------------------------------------------- /preprocess/transformations/block_swap_transformations.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import re 4 | from typing import Union, Tuple 5 | 6 | import numpy as np 7 | 8 | from src.data_preprocessors.language_processors import ( 9 | JavaAndCPPProcessor, 10 | CSharpProcessor, 11 | PythonProcessor, 12 | JavascriptProcessor, 13 | PhpProcessor, 14 | RubyProcessor, 15 | GoProcessor 16 | ) 17 | from src.data_preprocessors.transformations.transformation_base import TransformationBase 18 | 19 | processor_function = { 20 | "java": [JavaAndCPPProcessor.block_swap_java], 21 | "c": [JavaAndCPPProcessor.block_swap_c], 22 | "cpp": [JavaAndCPPProcessor.block_swap_c], 23 | "c_sharp": [CSharpProcessor.block_swap], 24 | "python": [PythonProcessor.block_swap], 25 | "javascript": [JavascriptProcessor.block_swap], 26 | "go": [GoProcessor.block_swap], 27 | "php": [PhpProcessor.block_swap], 28 | "ruby": [RubyProcessor.block_swap], 29 | } 30 | 31 | 32 | class BlockSwap(TransformationBase): 33 | """ 34 | Swapping if_else block 35 | """ 36 | 37 | def __init__(self, parser_path, language): 38 | super(BlockSwap, self).__init__(parser_path=parser_path, language=language) 39 | self.language = language 40 | self.transformations = processor_function[language] 41 | processor_map = { 42 | "java": self.get_tokens_with_node_type, 43 | "c": self.get_tokens_with_node_type, 44 | "cpp": self.get_tokens_with_node_type, 45 | "c_sharp": self.get_tokens_with_node_type, 46 | "javascript": JavascriptProcessor.get_tokens, 47 | "python": PythonProcessor.get_tokens, 48 | "php": PhpProcessor.get_tokens, 49 | "ruby": self.get_tokens_with_node_type, 50 | "go": self.get_tokens_with_node_type, 51 | } 52 | self.final_processor = processor_map[self.language] 53 | 54 | def transform_code( 55 | self, 56 | code: Union[str, bytes], 57 | ) -> Tuple[str, object]: 58 | success = False 59 | transform_functions = copy.deepcopy(self.transformations) 60 | while not success and len(transform_functions) > 0: 61 | function = np.random.choice(transform_functions) 62 | transform_functions.remove(function) 63 | modified_code, success = function(code, self) 64 | if success: 65 | code = modified_code 66 | root_node = self.parse_code( 67 | code=code 68 | ) 69 | return_values = self.final_processor( 70 | code=code.encode(), 71 | root=root_node 72 | ) 73 | if isinstance(return_values, tuple): 74 | tokens, types = return_values 75 | else: 76 | tokens, types = return_values, None 77 | return re.sub("[ \t\n]+", " ", " ".join(tokens)), \ 78 | { 79 | "types": types, 80 | "success": success 81 | } 82 | 83 | 84 | if __name__ == '__main__': 85 | java_code = """ 86 | void foo(){ 87 | int time = 20; 88 | if (time < 18) { 89 | time=10; 90 | } 91 | else { 92 | System.out.println("Good evening."); 93 | } 94 | } 95 | """ 96 | python_code = """ 97 | from typing import List 98 | 99 | def factorize(n: int) -> List[int]: 100 | import math 101 | fact = [] 102 | i = 2 103 | while i <= int(math.sqrt(n) + 1): 104 | if n % i == 0: 105 | fact.append(i) 106 | n 
//= i 107 |         else: 108 |             i += 1 109 |     if n > 1: 110 |         fact.append(n) 111 |     return fact 112 | """ 113 |     c_code = """ 114 |     void foo(){ 115 |         int time = 20; 116 |         if (time < 18) { 117 |             time=10; 118 |         } 119 |         else { 120 |             System.out.println("Good evening."); 121 |         } 122 |     } 123 |     """ 124 |     cs_code = """ 125 |     void foo(){ 126 |         int time = 20; 127 |         if (time < 18) { 128 |             time=10; 129 |         } 130 |         else { 131 |             System.out.println("Good evening."); 132 |         } 133 |     } 134 |     """ 135 |     js_code = """function foo(n) { 136 |         if (time < 10) { 137 |             greeting = "Good morning"; 138 |         } 139 |         else { 140 |             greeting = "Good evening"; 141 |         } 142 |     } 143 |     """ 144 |     ruby_code = """ 145 |     x = 1 146 |     if x > 2 147 |         puts "x is greater than 2" 148 |     else 149 |         puts "I can't guess the number" 150 |     end 151 |     """ 152 |     go_code = """ 153 |     func main() { 154 |        /* local variable definition */ 155 |        var a int = 100; 156 |  157 |        /* check the boolean condition */ 158 |        if( a < 20 ) { 159 |           /* if condition is true then print the following */ 160 |           fmt.Printf("a is less than 20\n" ); 161 |        } else { 162 |           /* if condition is false then print the following */ 163 |           fmt.Printf("a is not less than 20\n" ); 164 |        } 165 |        fmt.Printf("value of a is : %d\n", a); 166 |     } 167 |     """ 168 |     php_code = """ 169 |  177 |     """ 178 |     input_map = { 179 |         "java": ("java", java_code), 180 |         "c": ("c", c_code), 181 |         "cpp": ("cpp", c_code), 182 |         "cs": ("c_sharp", cs_code), 183 |         "js": ("javascript", js_code), 184 |         "python": ("python", python_code), 185 |         "php": ("php", php_code), 186 |         "ruby": ("ruby", ruby_code), 187 |         "go": ("go", go_code), 188 |     } 189 |     code_directory = os.path.realpath(os.path.join(os.path.realpath(__file__), '../../../../')) 190 |     parser_path = os.path.join(code_directory, "parser/languages.so") 191 |     for lang in ["java", "python", "js", "c", "cpp", "php", "go", "ruby", "cs"]: 192 |         lang, code = input_map[lang] 193 |         block_swap = BlockSwap( 194 |             parser_path, lang  # use the parser path computed above rather than a hardcoded absolute path 195 |         ) 196 |         print(lang) 197 |         code, meta = block_swap.transform_code(code) 198 |         code = re.sub("[ \t\n]+", " ", code) 199 |         if lang == "python": 200 |             code = PythonProcessor.beautify_python_code(code.split()) 201 |         print(code) 202 |         # print(re.sub("[ \t\n]+", " ", code)) 203 |         print(meta) 204 |         print("=" * 150) 205 | -------------------------------------------------------------------------------- /preprocess/transformations/confusion_remove.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Union, Tuple 3 |  4 | import numpy as np 5 | import os 6 |  7 | from src.data_preprocessors.language_processors import ( 8 |     JavaAndCPPProcessor, 9 |     PythonProcessor, 10 |     JavascriptProcessor, 11 |     PhpProcessor 12 | ) 13 | from src.data_preprocessors.transformations import TransformationBase 14 |  15 | processor_function = { 16 |     "java": [JavaAndCPPProcessor.incre_decre_removal, JavaAndCPPProcessor.ternary_removal], 17 |     "c": [JavaAndCPPProcessor.incre_decre_removal, JavaAndCPPProcessor.conditional_removal], 18 |     "cpp": [JavaAndCPPProcessor.incre_decre_removal, JavaAndCPPProcessor.conditional_removal], 19 |     "c_sharp": [JavaAndCPPProcessor.incre_decre_removal, JavaAndCPPProcessor.conditional_removal] 20 | } 21 |  22 |  23 | class ConfusionRemover(TransformationBase): 24 |     """ 25 |     Removes potentially confusing shorthand: rewrites increment/decrement operators and ternary (conditional) expressions into equivalent explicit forms. 
26 | """ 27 | 28 | def __init__(self, parser_path, language): 29 | super(ConfusionRemover, self).__init__(parser_path=parser_path, language=language) 30 | self.language = language 31 | if language in processor_function: 32 | self.transformations = processor_function[language] 33 | else: 34 | self.transformations = [] 35 | processor_map = { 36 | "java": self.get_tokens_with_node_type, # yes 37 | "c": self.get_tokens_with_node_type, # yes 38 | "cpp": self.get_tokens_with_node_type, # yes 39 | "c_sharp": self.get_tokens_with_node_type, # yes 40 | "javascript": JavascriptProcessor.get_tokens, # yes 41 | "python": PythonProcessor.get_tokens, # no 42 | "php": PhpProcessor.get_tokens, # yes 43 | "ruby": self.get_tokens_with_node_type, # yes 44 | "go": self.get_tokens_with_node_type, # no 45 | } 46 | self.final_processor = processor_map[self.language] 47 | 48 | def transform_code( 49 | self, 50 | code: Union[str, bytes], 51 | ) -> Tuple[str, object]: 52 | success = False 53 | transform_functions = copy.deepcopy(self.transformations) 54 | while not success and len(transform_functions) > 0: 55 | function = np.random.choice(transform_functions) 56 | transform_functions.remove(function) 57 | modified_root, modified_code, success = function(code, self) 58 | if success: 59 | code = modified_code 60 | root_node = self.parse_code( 61 | code=code 62 | ) 63 | return_values = self.final_processor( 64 | code=code.encode(), 65 | root=root_node 66 | ) 67 | if isinstance(return_values, tuple): 68 | tokens, types = return_values 69 | else: 70 | tokens, types = return_values, None 71 | return " ".join(tokens), \ 72 | { 73 | "types": types, 74 | "success": success 75 | } 76 | 77 | 78 | if __name__ == '__main__': 79 | java_code = """ 80 | class A{ 81 | int foo(int[] nums, int lower, upper){ 82 | for(int i = 0; i < n; i++) { 83 | static long start = i == 0 ? lower : (long)nums[i - 1] + 1; 84 | static long end = i == nums.length ? upper : (long)nums[i] - 1; 85 | start = (lower + nums[j] > upper) ? lower + nums[j] : upper; 86 | lower += 1; 87 | lower = upper++; 88 | lower = ++upper; 89 | } 90 | return i == end ? -1 : start; 91 | } 92 | } 93 | """ 94 | python_code = """def foo(n): 95 | res = 0 96 | for i in range(0, 19, 2): 97 | res += i 98 | i = 0 99 | while i in range(n): 100 | res += i 101 | i += 1 102 | return res 103 | """ 104 | c_code = """ 105 | int foo(int n){ 106 | int res; 107 | for(int i = 0; i < n; i++) { 108 | int j = 0; 109 | if (j == 0) { i = j; } 110 | j = (j == 0) ? (i + j) : i - j; 111 | int i = (i == 0) ? (i + j) : i - j; 112 | } 113 | i = j ++; 114 | j = i--; 115 | j = -- i; 116 | j = ++i; 117 | return i == 0 ? -1 : j; 118 | } 119 | 120 | """ 121 | cs_code = """int foo(int n){ 122 | x = n++; 123 | n = x--; 124 | x = ++n; 125 | n = ++x; 126 | return x != 0.0 ? Math.Sin(x) / x : 1.0; 127 | } 128 | """ 129 | js_code = """function foo(n) { 130 | let res = ''; 131 | for(let i = 0; i < 10; i++){ 132 | res += i.toString(); 133 | res += '
<br>'; 134 |     } 135 |     while ( i < 10 ; ) { 136 |         res += 'bk'; 137 |     } 138 |     return res; 139 | } 140 | """ 141 |     ruby_code = """ 142 |     for i in 0..5 143 |        puts "Value of local variable is #{i}" 144 |     end 145 |     """ 146 |     go_code = """ 147 |     func main() { 148 |        sum := 0; 149 |        i := 0; 150 |        for ; i < 10; { 151 |           sum += i; 152 |        } 153 |        i++; 154 |        fmt.Println(sum); 155 |     } 156 |     """ 157 |     php_code = """ 158 | <?php 159 | for ( $x = 0 ; $x <= 10 ; $x++ ) { 160 |     echo "The number is: $x <br>"; 161 | } 162 | $x = 0 ; 163 | while ( $x <= 10 ) { 164 |     echo "The number is: $x <br>
"; 165 | $x++; 166 | } 167 | ?> 168 | """ 169 | input_map = { 170 | "java": ("java", java_code), 171 | "c": ("c", c_code), 172 | "cpp": ("cpp", c_code), 173 | "cs": ("c_sharp", cs_code), 174 | "js": ("javascript", js_code), 175 | "python": ("python", python_code), 176 | "php": ("php", php_code), 177 | "ruby": ("ruby", ruby_code), 178 | "go": ("go", go_code), 179 | } 180 | code_directory = os.path.realpath(os.path.join(os.path.realpath(__file__), '../../../../')) 181 | parser_path = os.path.join(code_directory, "parser/languages.so") 182 | for lang in ["c", "cpp", "java", "cs", "python", "php", "go", "ruby"]: 183 | # lang = "php" 184 | lang, code = input_map[lang] 185 | confusion_remover = ConfusionRemover( 186 | parser_path, lang 187 | ) 188 | print(lang) 189 | code, types = confusion_remover.transform_code(code) 190 | print(types["success"]) 191 | print("=" * 100) 192 | -------------------------------------------------------------------------------- /preprocess/transformations/dead_code_inserter.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Union, Tuple 3 | import os 4 | 5 | import numpy as np 6 | 7 | from src.data_preprocessors.language_processors import ( 8 | JavaAndCPPProcessor, 9 | CSharpProcessor, 10 | PythonProcessor, 11 | JavascriptProcessor, 12 | PhpProcessor, 13 | ) 14 | from src.data_preprocessors.language_processors.go_processor import GoProcessor 15 | from src.data_preprocessors.language_processors.ruby_processor import RubyProcessor 16 | from src.data_preprocessors.language_processors.utils import extract_statement_within_size, get_tokens, \ 17 | get_tokens_insert_before, count_nodes 18 | from src.data_preprocessors.transformations import TransformationBase 19 | 20 | processor_function = { 21 | "java": JavaAndCPPProcessor, 22 | "c": JavaAndCPPProcessor, 23 | "cpp": JavaAndCPPProcessor, 24 | "c_sharp": CSharpProcessor, 25 | "python": PythonProcessor, 26 | "javascript": JavascriptProcessor, 27 | "go": GoProcessor, 28 | "php": PhpProcessor, 29 | "ruby": RubyProcessor, 30 | } 31 | 32 | tokenizer_function = { 33 | "java": get_tokens, 34 | "c": get_tokens, 35 | "cpp": get_tokens, 36 | "c_sharp": get_tokens, 37 | "python": PythonProcessor.get_tokens, 38 | "javascript": JavascriptProcessor.get_tokens, 39 | "go": get_tokens, 40 | "php": PhpProcessor.get_tokens, 41 | "ruby": get_tokens, 42 | } 43 | 44 | insertion_function = { 45 | "java": get_tokens_insert_before, 46 | "c": get_tokens_insert_before, 47 | "cpp": get_tokens_insert_before, 48 | "c_sharp": get_tokens_insert_before, 49 | "python": PythonProcessor.get_tokens_insert_before, 50 | "javascript": JavascriptProcessor.get_tokens_insert_before, 51 | "go": get_tokens_insert_before, 52 | "php": PhpProcessor.get_tokens_insert_before, 53 | "ruby": get_tokens_insert_before, 54 | } 55 | 56 | 57 | class DeadCodeInserter(TransformationBase): 58 | def __init__( 59 | self, 60 | parser_path: str, 61 | language: str 62 | ): 63 | super(DeadCodeInserter, self).__init__( 64 | parser_path=parser_path, 65 | language=language, 66 | ) 67 | self.language = language 68 | self.processor = processor_function[self.language] 69 | self.tokenizer_function = tokenizer_function[self.language] 70 | self.insertion_function = insertion_function[self.language] 71 | 72 | def insert_random_dead_code(self, code_string, max_node_in_statement=-1): 73 | root = self.parse_code(code_string) 74 | original_node_count = count_nodes(root) 75 | if max_node_in_statement == -1: 76 | max_node_in_statement 
= int(original_node_count / 2) 77 |         if self.language == "ruby": 78 |             statement_markers = ["assignment", "until", "call", "if", "for", "while"] 79 |         else: 80 |             statement_markers = None 81 |         statements = extract_statement_within_size( 82 |             root, max_node_in_statement, statement_markers, 83 |             code_string=code_string, tokenizer=self.tokenizer_function, 84 |         ) 85 |         original_code = " ".join(self.tokenizer_function(code_string, root)) 86 |         try: 87 |             while len(statements) > 0: 88 |                 random_stmt, insert_before = np.random.choice(statements, 2) 89 |                 statements.remove(random_stmt) 90 |                 dead_code_body = " ".join(self.tokenizer_function(code_string, random_stmt)).strip() 91 |                 dead_code_function = np.random.choice( 92 |                     [ 93 |                         self.processor.create_dead_for_loop, 94 |                         self.processor.create_dead_while_loop, 95 |                         self.processor.create_dead_if 96 |                     ] 97 |                 ) 98 |                 dead_code = dead_code_function(dead_code_body) 99 |                 modified_code = " ".join( 100 |                     self.insertion_function( 101 |                         code_str=code_string, root=root, insertion_code=dead_code, 102 |                         insert_before_node=insert_before 103 |                     ) 104 |                 ) 105 |                 if modified_code != original_code: 106 |                     modified_root = self.parse_code(modified_code)  # modified_code is already a string; re-joining it would split it into characters 107 |                     return modified_root, modified_code, True 108 |         except Exception: 109 |             pass 110 |         return root, original_code, False 111 |  112 |     def transform_code( 113 |             self, 114 |             code: Union[str, bytes] 115 |     ) -> Tuple[str, object]: 116 |         root, code, success = self.insert_random_dead_code(code, -1) 117 |         code = re.sub("[ \n\t]+", " ", code) 118 |         return code, { 119 |             "success": success 120 |         } 121 |  122 |  123 | if __name__ == '__main__': 124 |     java_code = """ 125 | class A{ 126 |     int foo(int n){ 127 |         int res = 0; 128 |         for(int i = 0; i < n; i++) { 129 |             int j = 0; 130 |             while (j < i){ 131 |                 res += j; 132 |             } 133 |         } 134 |         return res; 135 |     } 136 | } 137 | """ 138 |     python_code = """ 139 | def foo(n): 140 |     res = 0 141 |     for i in range(0, 19, 2): 142 |         res += i 143 |     i = 0 144 |     while i in range(n): 145 |         res += i 146 |         i += 1 147 |     return res 148 | """ 149 |     c_code = """ 150 | int foo(int n){ 151 |     int res = 0; 152 |     for(int i = 0; i < n; i++) { 153 |         int j = 0; 154 |         while (j < i){ 155 |             res += j; 156 |         } 157 |     } 158 |     return res; 159 | } 160 | """ 161 |     cs_code = """ 162 | int foo(int n){ 163 |     int res = 0, i = 0; 164 |     while(i < n) { 165 |         int j = 0; 166 |         while (j < i){ 167 |             res += j; 168 |         } 169 |     } 170 |     return res; 171 | } 172 | """ 173 |     js_code = """function foo(n) { 174 |     let res = ''; 175 |     for(let i = 0; i < 10; i++){ 176 |        res += i.toString(); 177 |        res += '<br>
'; 178 |     } 179 |     while ( i < 10 ; ) { 180 |         res += 'bk'; 181 |     } 182 |     return res; 183 | } 184 | """ 185 |     ruby_code = """ 186 |     for i in 0..5 do 187 |        puts "Value of local variable is #{i}" 188 |        if false then 189 |           puts "False printed" 190 |           while i == 10 do 191 |             print i; 192 |           end 193 |           i = u + 8 194 |        end 195 |     end 196 |     """ 197 |     go_code = """ 198 |     func main() { 199 |        sum := 0; 200 |        i := 0; 201 |        for ; i < 10; { 202 |           sum += i; 203 |        } 204 |        i++; 205 |        fmt.Println(sum); 206 |     } 207 |     """ 208 |     php_code = """ 209 | <?php 210 | for ( $x = 0 ; $x <= 10 ; $x++ ) { 211 |     echo "The number is: $x <br>"; 212 | } 213 | $x = 0 ; 214 | while ( $x <= 10 ) { 215 |     echo "The number is: $x <br>
"; 216 | $x++; 217 | } 218 | ?> 219 | """ 220 | input_map = { 221 | "java": ("java", java_code), 222 | "c": ("c", c_code), 223 | "cpp": ("cpp", c_code), 224 | "cs": ("c_sharp", cs_code), 225 | "js": ("javascript", js_code), 226 | "python": ("python", python_code), 227 | "php": ("php", php_code), 228 | "ruby": ("ruby", ruby_code), 229 | "go": ("go", go_code), 230 | } 231 | code_directory = os.path.realpath(os.path.join(os.path.realpath(__file__), '../../../../')) 232 | parser_path = os.path.join(code_directory, "parser/languages.so") 233 | for lang in ["c", "cpp", "java", "python", "php", "ruby", "js", "go", "cs"]: 234 | lang, code = input_map[lang] 235 | dead_code_inserter = DeadCodeInserter( 236 | parser_path, lang 237 | ) 238 | print(lang) 239 | code, meta = dead_code_inserter.transform_code(code) 240 | if lang == "python": 241 | code = PythonProcessor.beautify_python_code(code.split()) 242 | print(code) 243 | print(meta) 244 | print("=" * 150) 245 | -------------------------------------------------------------------------------- /preprocess/transformations/demo_transformation.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple 2 | 3 | from src.data_preprocessors.transformations import TransformationBase 4 | 5 | 6 | class DemoTransformation(TransformationBase): 7 | def __init__(self, parser, language): 8 | super(DemoTransformation, self).__init__( 9 | parser_path=parser, 10 | language=language, 11 | ) 12 | 13 | def transform_code( 14 | self, 15 | code: Union[str, bytes] 16 | ) -> Tuple[str, object]: 17 | root_node = self.parse_code( 18 | code=code 19 | ) 20 | tokens, types = self.get_tokens_with_node_type( 21 | code=code.encode(), 22 | root=root_node 23 | ) 24 | return " ".join(tokens), types 25 | -------------------------------------------------------------------------------- /preprocess/transformations/no_transform.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from tree_sitter import Language, Parser 4 | from typing import Union, Tuple 5 | 6 | from src.data_preprocessors.language_processors import ( 7 | PythonProcessor, 8 | JavascriptProcessor, 9 | PhpProcessor 10 | ) 11 | from src.data_preprocessors.transformations import TransformationBase 12 | 13 | 14 | class NoTransformation(TransformationBase): 15 | def __init__(self, parser_path: str, language: str) -> object: 16 | super().__init__(parser_path, language) 17 | if not os.path.exists(parser_path): 18 | raise ValueError( 19 | f"Language parser does not exist at {parser_path}. 
Please run `setup.sh` to properly set the " 20 |                 f"environment!") 21 |         self.lang_object = Language(parser_path, language) 22 |         self.parser = Parser() 23 |         self.parser.set_language(self.lang_object) 24 |         processor_map = { 25 |             "java": self.get_tokens_with_node_type, 26 |             "c": self.get_tokens_with_node_type, 27 |             "cpp": self.get_tokens_with_node_type, 28 |             "c_sharp": self.get_tokens_with_node_type, 29 |             "javascript": JavascriptProcessor.get_tokens, 30 |             "python": PythonProcessor.get_tokens, 31 |             "php": PhpProcessor.get_tokens, 32 |             "ruby": self.get_tokens_with_node_type, 33 |             "go": self.get_tokens_with_node_type, 34 |         } 35 |         self.processor = processor_map[language] 36 |  37 |     def transform_code( 38 |             self, 39 |             code: Union[str, bytes] 40 |     ) -> Tuple[str, object]: 41 |         root_node = self.parse_code( 42 |             code=code 43 |         ) 44 |         return_values = self.processor( 45 |             code=code.encode(), 46 |             root=root_node 47 |         ) 48 |         if isinstance(return_values, tuple): 49 |             tokens, types = return_values 50 |         else: 51 |             tokens, types = return_values, None 52 |         return re.sub("[ \t\n]+", " ", " ".join(tokens)), \ 53 |                { 54 |                    "types": types, 55 |                    "success": False 56 |                } 57 | -------------------------------------------------------------------------------- /preprocess/transformations/operand_swap_transformations.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import re 4 | from typing import Union, Tuple 5 |  6 | import numpy as np 7 |  8 | from src.data_preprocessors.language_processors import ( 9 |     JavaAndCPPProcessor, 10 |     CSharpProcessor, 11 |     PythonProcessor, 12 |     JavascriptProcessor, 13 |     PhpProcessor, 14 |     GoProcessor, 15 |     RubyProcessor 16 | ) 17 | from src.data_preprocessors.transformations import TransformationBase 18 |  19 | processor_function = { 20 |     "java": [JavaAndCPPProcessor.operand_swap], 21 |     "c": [JavaAndCPPProcessor.operand_swap], 22 |     "cpp": [JavaAndCPPProcessor.operand_swap], 23 |     "c_sharp": [CSharpProcessor.operand_swap], 24 |     "python": [PythonProcessor.operand_swap], 25 |     "javascript": [JavascriptProcessor.operand_swap], 26 |     "go": [GoProcessor.operand_swap], 27 |     "php": [PhpProcessor.operand_swap], 28 |     "ruby": [RubyProcessor.operand_swap], 29 | } 30 |  31 |  32 | class OperandSwap(TransformationBase): 33 |     """ 34 |     Swapping Operand: "a>b" becomes "b<a" 35 |     """ 36 |  37 |     def __init__(self, parser_path, language): 38 |         super(OperandSwap, self).__init__(parser_path=parser_path, language=language) 39 |         self.language = language 40 |         self.transformations = processor_function[language] 41 |         processor_map = { 42 |             "java": self.get_tokens_with_node_type, 43 |             "c": self.get_tokens_with_node_type, 44 |             "cpp": self.get_tokens_with_node_type, 45 |             "c_sharp": self.get_tokens_with_node_type, 46 |             "javascript": JavascriptProcessor.get_tokens, 47 |             "python": PythonProcessor.get_tokens, 48 |             "php": PhpProcessor.get_tokens, 49 |             "ruby": self.get_tokens_with_node_type, 50 |             "go": self.get_tokens_with_node_type, 51 |         } 52 |         self.final_processor = processor_map[self.language] 53 |  54 |     def transform_code( 55 |             self, 56 |             code: Union[str, bytes], 57 |     ) -> Tuple[str, object]: 58 |         success = False 59 |         transform_functions = copy.deepcopy(self.transformations) 60 |         while not success and len(transform_functions) > 0: 61 |             function = np.random.choice(transform_functions) 62 |             transform_functions.remove(function) 63 |             modified_code, success = function(code, self) 64 |             if success: 65 |                 code = modified_code 66 |         root_node = self.parse_code( 67 |             code=code 68 |         ) 69 |         return_values = self.final_processor( 70 |             code=code.encode(), 71 |             root=root_node 72 |         ) 73 |         if isinstance(return_values, tuple): 74 |             tokens, types = return_values 75 |         else: 76 |             tokens, types = return_values, None 77 |         return re.sub("[ \t\n]+", " ", " ".join(tokens)), \ 78 |                { 79 |                    "types": types, 80 |                    "success": success 81 |                } 82 |  83 |  84 | if __name__ == '__main__': 85 |     java_code = """ 86 |     void foo(){ 87 |         int time = 20; 88 |         if (time < 18) { 89 |             time=10; 90 |         } 91 |         else { 92 |             System.out.println("Good evening."); 93 |         } 94 |     } 95 |     """ 96 |     python_code = """ 97 | from typing import List 98 |  99 | def factorize(n: int) -> List[int]: 100 |     import math 101 |     fact = [] 102 |     i = 2 103 |     while i <= int(math.sqrt(n) + 1): 104 |         if n % i == 0: 105 |             fact.append(i) 106 |             n 
//= i 107 |         else: 108 |             i += 1 109 |     if n > 1: 110 |         fact.append(n) 111 |     return fact 112 | """ 113 |     c_code = """ 114 |     void foo(){ 115 |         int time = 20; 116 |         if (time < 18) { 117 |             time=10; 118 |         } 119 |         else { 120 |             System.out.println("Good evening."); 121 |         } 122 |     } 123 |     """ 124 |     cs_code = """ 125 |     void foo(){ 126 |         int time = 20; 127 |         if (time < 18) { 128 |             time=10; 129 |         } 130 |         else { 131 |             System.out.println("Good evening."); 132 |         } 133 |     } 134 |     """ 135 |     js_code = """function foo(n) { 136 |         if (time < 10) { 137 |             greeting = "Good morning"; 138 |         } 139 |         else { 140 |             greeting = "Good evening"; 141 |         } 142 |     } 143 |     """ 144 |     ruby_code = """ 145 |     x = 1 146 |     if x > 2 147 |         puts "x is greater than 2" 148 |     else 149 |         puts "I can't guess the number" 150 |     end 151 |     """ 152 |     go_code = """ 153 |     func main() { 154 |        /* local variable definition */ 155 |        var a int = 100; 156 |  157 |        /* check the boolean condition */ 158 |        if( a < 20 ) { 159 |           /* if condition is true then print the following */ 160 |           fmt.Printf("a is less than 20\n" ); 161 |        } else { 162 |           /* if condition is false then print the following */ 163 |           fmt.Printf("a is not less than 20\n" ); 164 |        } 165 |        fmt.Printf("value of a is : %d\n", a); 166 |     } 167 |     """ 168 |     php_code = """ 169 |  177 |     """ 178 |     input_map = { 179 |         "java": ("java", java_code), 180 |         "c": ("c", c_code), 181 |         "cpp": ("cpp", c_code), 182 |         "cs": ("c_sharp", cs_code), 183 |         "js": ("javascript", js_code), 184 |         "python": ("python", python_code), 185 |         "php": ("php", php_code), 186 |         "ruby": ("ruby", ruby_code), 187 |         "go": ("go", go_code), 188 |     } 189 |     code_directory = os.path.realpath(os.path.join(os.path.realpath(__file__), '../../../../')) 190 |     parser_path = os.path.join(code_directory, "parser/languages.so") 191 |     for lang in ["java", "python", "js", "c", "cpp", "php", "go", "ruby", 192 |                  "cs"]:  # ["c", "cpp", "java", "cs", "python", 193 |         # "php", "go", "ruby"]: 194 |         # lang = "php" 195 |         lang, code = input_map[lang] 196 |         operandswap = OperandSwap( 197 |             parser_path, lang 198 |         ) 199 |         print(lang) 200 |         # print("-" * 150) 201 |         # print(code) 202 |         # print("-" * 150) 203 |         code, meta = operandswap.transform_code(code) 204 |         print(meta["success"]) 205 |         print("=" * 150) 206 | -------------------------------------------------------------------------------- /preprocess/transformations/syntactic_noising_transformation.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple 2 |  3 | import nltk 4 | import numpy as np 5 |  6 | from src.data_preprocessors.transformations import TransformationBase, NoTransformation 7 |  8 |  9 | def masking(tokens, p): 10 |     new_tokens = [] 11 |     for t in tokens: 12 |         if np.random.uniform() < p: 13 |             new_tokens.append("<mask>")  # mask placeholder token (BART-style "<mask>" assumed) 14 |         else: 15 |             new_tokens.append(t) 16 |     return " ".join(new_tokens) 17 |  18 |  19 | def deletion(tokens, p): 20 |     new_tokens = [] 21 |     for t in tokens: 22 |         if np.random.uniform() >= p: 23 |             new_tokens.append(t) 24 |     return " ".join(new_tokens) 25 |  26 |  27 | def token_infilling(tokens, p): 28 |     new_tokens = [] 29 |     max_infilling_len = round(int(p * len(tokens)) / 2.) 
30 | infilling_len = np.random.randint(1, max_infilling_len) 31 | start_index = np.random.uniform(high=(len(tokens) - infilling_len)) 32 | end_index = start_index + infilling_len 33 | for i, t in enumerate(tokens): 34 | if i < start_index or i > end_index: 35 | new_tokens.append(t) 36 | return " ".join(new_tokens) 37 | 38 | 39 | class SyntacticNoisingTransformation(TransformationBase): 40 | def __init__(self, parser_path: str, language: str, noise_ratio=0.15): 41 | # super().__init__(parser_path, language) 42 | self.language = language 43 | if self.language == "nl": 44 | self.tokenizer = nltk.word_tokenize 45 | else: 46 | self.tokenizer = NoTransformation(parser_path, language) 47 | self.noise_ratio = noise_ratio 48 | 49 | def transform_code( 50 | self, 51 | code: Union[str, bytes] 52 | ) -> Tuple[str, object]: 53 | if self.language == "nl": 54 | tokens = self.tokenizer(code) 55 | else: 56 | tokenized_code, _ = self.tokenizer.transform_code(code) 57 | tokens = tokenized_code.split() 58 | p = np.random.uniform() 59 | if p < 0.33: 60 | transformed_code = masking(tokens, self.noise_ratio) 61 | elif p < 0.66: 62 | transformed_code = deletion(tokens, self.noise_ratio) 63 | else: 64 | transformed_code = token_infilling(tokens, self.noise_ratio) 65 | return transformed_code, { 66 | "success": True 67 | } 68 | -------------------------------------------------------------------------------- /preprocess/transformations/transformation_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union, Tuple, List 3 | 4 | import tree_sitter 5 | from tree_sitter import Language, Parser 6 | 7 | 8 | def get_ancestor_type_chains( 9 | node: tree_sitter.Node 10 | ) -> List[str]: 11 | types = [str(node.type)] 12 | while node.parent is not None: 13 | node = node.parent 14 | types.append(str(node.type)) 15 | return types 16 | 17 | 18 | class TransformationBase: 19 | def __init__( 20 | self, 21 | parser_path: str, 22 | language: str 23 | ): 24 | if not os.path.exists(parser_path): 25 | raise ValueError( 26 | f"Language parser does not exist at {parser_path}. Please run `setup.sh` to properly set the " 27 | f"environment!") 28 | self.lang_object = Language(parser_path, language) 29 | self.parser = Parser() 30 | self.parser.set_language(self.lang_object) 31 | pass 32 | 33 | def parse_code( 34 | self, 35 | code: Union[str, bytes] 36 | ) -> tree_sitter.Node: 37 | """ 38 | This function parses a given code and return the root node. 39 | :param code: 40 | :return: tree_sitter.Node, the root node of the parsed tree. 41 | """ 42 | if isinstance(code, bytes): 43 | tree = self.parser.parse(code) 44 | elif isinstance(code, str): 45 | tree = self.parser.parse(code.encode()) 46 | else: 47 | raise ValueError("Code must be character string or bytes string") 48 | return tree.root_node 49 | 50 | def get_tokens( 51 | self, 52 | code: bytes, 53 | root: tree_sitter.Node 54 | ) -> List[str]: 55 | """ 56 | This function is for getting tokens recursively from a tree. 57 | :param code: the byte string corresponding to the code. 58 | :param root: the root node of the parsed tree 59 | :return: List of Tokens. 
60 |         """ 61 |         tokens = [] 62 |         if root.type == "comment": 63 |             return tokens 64 |         if "string" in str(root.type): 65 |             parent = root.parent 66 |             if "list" not in str(parent.type) and len(parent.children) == 1: 67 |                 return tokens 68 |             else: 69 |                 return [code[root.start_byte:root.end_byte].decode()] 70 |         if len(root.children) == 0: 71 |             tokens.append(code[root.start_byte:root.end_byte].decode()) 72 |         else: 73 |             for child in root.children: 74 |                 tokens += self.get_tokens(code, child) 75 |         return tokens 76 |  77 |     def get_token_string( 78 |             self, 79 |             code: str, 80 |             root: tree_sitter.Node 81 |     ) -> str: 82 |         """ 83 |         This is an auxiliary function for extracting the parsed token string. 84 |         :param code: the code as a character string. 85 |         :param root: the root node of the parsed tree 86 |         :return: str, the parsed code as a string of tokens. 87 |         """ 88 |         tokens = self.get_tokens(code.encode(), root) 89 |         return " ".join(tokens) 90 |  91 |     def get_tokens_with_node_type( 92 |             self, 93 |             code: bytes, 94 |             root: tree_sitter.Node 95 |     ) -> Tuple[List[str], List[List[str]]]: 96 |         """ 97 |         This function extracts the tokens and the types of the tokens. 98 |         It returns a list of strings as tokens, and a list of lists of strings as types. 99 |         For every token, it extracts the sequence of AST node types starting from the token all the way to the root. 100 |         :param code: the byte string corresponding to the code. 101 |         :param root: the root node of the parsed tree 102 |         :return: 103 |             List[str]: The list of tokens. 104 |             List[List[str]]: The AST node types corresponding to every token. 105 |         """ 106 |         tokens, types = [], [] 107 |         if root.type == "comment": 108 |             return tokens, types 109 |         if "string" in str(root.type): 110 |             return [code[root.start_byte:root.end_byte].decode()], [["string"]] 111 |         if len(root.children) == 0: 112 |             tokens.append(code[root.start_byte:root.end_byte].decode()) 113 |             types.append(get_ancestor_type_chains(root)) 114 |         else: 115 |             for child in root.children: 116 |                 _tokens, _types = self.get_tokens_with_node_type(code, child) 117 |                 tokens += _tokens 118 |                 types += _types 119 |         return tokens, types 120 |  121 |     def transform_code( 122 |             self, 123 |             code: Union[str, bytes] 124 |     ) -> Tuple[str, object]: 125 |         """ 126 |         Transforms a piece of code and returns the transformed version. 127 |         :param code: The code to be transformed, either as a character string or a bytes string. 128 |         :return: 129 |             A tuple, where the first member is the transformed code. 130 |             The second member might be other metadata (e.g. node types) of the transformed code. It can be None as well. 
131 | """ 132 | pass 133 | -------------------------------------------------------------------------------- /preprocess/transformations/transformation_main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict, Callable 3 | 4 | from src.data_preprocessors.transformations import ( 5 | BlockSwap, ConfusionRemover, DeadCodeInserter, ForWhileTransformer, OperandSwap, SyntacticNoisingTransformation 6 | ) 7 | 8 | 9 | class SemanticPreservingTransformation: 10 | def __init__( 11 | self, 12 | parser_path: str, 13 | language: str, 14 | transform_functions: Dict[Callable, int] = None, 15 | ): 16 | self.language = language 17 | if transform_functions is not None: 18 | self.transform_functions = transform_functions 19 | else: 20 | self.transform_functions = { 21 | BlockSwap: 1, 22 | ConfusionRemover: 1, 23 | DeadCodeInserter: 1, 24 | ForWhileTransformer: 1, 25 | OperandSwap: 1, 26 | SyntacticNoisingTransformation: 1 27 | } 28 | self.transformations = [] 29 | if self.language == "nl": 30 | self.transformations.append(SyntacticNoisingTransformation(parser_path=parser_path, language="nl")) 31 | else: 32 | for t in self.transform_functions: 33 | for _ in range(self.transform_functions[t]): 34 | self.transformations.append(t(parser_path=parser_path, language=language)) 35 | 36 | def transform_code( 37 | self, 38 | code: str 39 | ): 40 | transformed_code, transformation_name = None, None 41 | indices = list(range(len(self.transformations))) 42 | np.random.shuffle(indices) 43 | success = False 44 | while not success and len(indices) > 0: 45 | si = np.random.choice(indices) 46 | indices.remove(si) 47 | t = self.transformations[si] 48 | transformed_code, metadata = t.transform_code(code) 49 | success = metadata["success"] 50 | if success: 51 | transformation_name = type(t).__name__ 52 | if not success: 53 | return code, None 54 | return transformed_code, transformation_name 55 | -------------------------------------------------------------------------------- /preprocess/transformations/var_renaming_transformation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import re 4 | from typing import Union, Tuple 5 | import os 6 | 7 | from src.data_preprocessors.language_processors import ( 8 | JavaAndCPPProcessor, 9 | CSharpProcessor, 10 | PythonProcessor, 11 | JavascriptProcessor, 12 | PhpProcessor, 13 | ) 14 | from src.data_preprocessors.language_processors.go_processor import GoProcessor 15 | from src.data_preprocessors.language_processors.ruby_processor import RubyProcessor 16 | from src.data_preprocessors.language_processors.utils import get_tokens 17 | from src.data_preprocessors.transformations import TransformationBase 18 | import os 19 | 20 | processor_function = { 21 | "java": JavaAndCPPProcessor, 22 | "c": JavaAndCPPProcessor, 23 | "cpp": JavaAndCPPProcessor, 24 | "c_sharp": CSharpProcessor, 25 | "python": PythonProcessor, 26 | "javascript": JavascriptProcessor, 27 | "go": GoProcessor, 28 | "php": PhpProcessor, 29 | "ruby": RubyProcessor, 30 | } 31 | 32 | tokenizer_function = { 33 | "java": get_tokens, 34 | "c": get_tokens, 35 | "cpp": get_tokens, 36 | "c_sharp": get_tokens, 37 | "python": PythonProcessor.get_tokens, 38 | "javascript": JavascriptProcessor.get_tokens, 39 | "go": get_tokens, 40 | "php": PhpProcessor.get_tokens, 41 | "ruby": get_tokens, 42 | } 43 | 44 | 45 | class VarRenamer(TransformationBase): 46 | def __init__( 47 | self, 48 | 
parser_path: str, 49 | language: str 50 | ): 51 | super(VarRenamer, self).__init__( 52 | parser_path=parser_path, 53 | language=language, 54 | ) 55 | self.language = language 56 | self.processor = processor_function[self.language] 57 | self.tokenizer_function = tokenizer_function[self.language] 58 | # C/CPP: function_declarator 59 | # Java: class_declaration, method_declaration 60 | # python: function_definition, call 61 | # js: function_declaration 62 | self.not_var_ptype = ["function_declarator", "class_declaration", "method_declaration", "function_definition", 63 | "function_declaration", "call", "local_function_statement"] 64 | 65 | def extract_var_names(self, root, code_string): 66 | var_names = [] 67 | queue = [root] 68 | 69 | while len(queue) > 0: 70 | current_node = queue[0] 71 | queue = queue[1:] 72 | if (current_node.type == "identifier" or current_node.type == "variable_name") and str( 73 | current_node.parent.type) not in self.not_var_ptype: 74 | var_names.append(self.tokenizer_function(code_string, current_node)[0]) 75 | for child in current_node.children: 76 | queue.append(child) 77 | return var_names 78 | 79 | def var_renaming(self, code_string): 80 | root = self.parse_code(code_string) 81 | original_code = self.tokenizer_function(code_string, root) 82 | # print(" ".join(original_code)) 83 | var_names = self.extract_var_names(root, code_string) 84 | var_names = list(set(var_names)) 85 | num_to_rename = math.ceil(0.2 * len(var_names)) 86 | random.shuffle(var_names) 87 | var_names = var_names[:num_to_rename] 88 | var_map = {} 89 | for idx, v in enumerate(var_names): 90 | var_map[v] = f"VAR_{idx}" 91 | modified_code = [] 92 | for t in original_code: 93 | if t in var_names: 94 | modified_code.append(var_map[t]) 95 | else: 96 | modified_code.append(t) 97 | 98 | modified_code_string = " ".join(modified_code) 99 | if modified_code != original_code: 100 | modified_root = self.parse_code(modified_code_string) 101 | return modified_root, modified_code_string, True 102 | else: 103 | return root, code_string, False 104 | 105 | def transform_code( 106 | self, 107 | code: Union[str, bytes] 108 | ) -> Tuple[str, object]: 109 | root, code, success = self.var_renaming(code) 110 | code = re.sub("[ \n\t]+", " ", code) 111 | return code, { 112 | "success": success 113 | } 114 | 115 | 116 | if __name__ == '__main__': 117 | java_code = """ 118 | class A{ 119 | int foo(int n){ 120 | int res = 0; 121 | for(int i = 0; i < n; i++) { 122 | int j = 0; 123 | while (j < i){ 124 | res += j; 125 | } 126 | } 127 | return res; 128 | } 129 | } 130 | """ 131 | python_code = """def foo(n): 132 | res = 0 133 | for i in range(0, 19, 2): 134 | res += i 135 | i = 0 136 | while i in range(n): 137 | res += i 138 | i += 1 139 | return res 140 | """ 141 | c_code = """ 142 | int foo(int n){ 143 | int res = 0; 144 | for(int i = 0; i < n; i++) { 145 | int j = 0; 146 | while (j < i){ 147 | res += j; 148 | } 149 | } 150 | return res; 151 | } 152 | """ 153 | cs_code = """ 154 | int foo(int n){ 155 | int res = 0, i = 0; 156 | while(i < n) { 157 | int j = 0; 158 | while (j < i){ 159 | res += j; 160 | } 161 | } 162 | return res; 163 | } 164 | """ 165 | js_code = """function foo(n) { 166 | let res = ''; 167 | for(let i = 0; i < 10; i++){ 168 | res += i.toString(); 169 | res += '
<br>'; 170 |     } 171 |     while ( i < 10 ; ) { 172 |         res += 'bk'; 173 |     } 174 |     return res; 175 | } 176 | """ 177 |     ruby_code = """ 178 |     for i in 0..5 do 179 |        puts "Value of local variable is #{i}" 180 |        if false then 181 |           puts "False printed" 182 |           while i == 10 do 183 |             print i; 184 |           end 185 |           i = u + 8 186 |        end 187 |     end 188 |     """ 189 |     go_code = """ 190 |     func main() { 191 |        sum := 0; 192 |        i := 0; 193 |        for ; i < 10; { 194 |           sum += i; 195 |        } 196 |        i++; 197 |        fmt.Println(sum); 198 |     } 199 |     """ 200 |     php_code = """ 201 | <?php 202 | for ( $x = 0 ; $x <= 10 ; $x++ ) { 203 |     echo "The number is: $x <br>"; 204 | } 205 | $x = 0 ; 206 | while ( $x <= 10 ) { 207 |     echo "The number is: $x <br>
"; 208 | $x++; 209 | } 210 | ?> 211 | """ 212 | input_map = { 213 | "java": ("java", java_code), 214 | "c": ("c", c_code), 215 | "cpp": ("cpp", c_code), 216 | "cs": ("c_sharp", cs_code), 217 | "js": ("javascript", js_code), 218 | "python": ("python", python_code), 219 | "php": ("php", php_code), 220 | "ruby": ("ruby", ruby_code), 221 | "go": ("go", go_code), 222 | } 223 | code_directory = os.path.realpath(os.path.join(os.path.realpath(__file__), '../../../..')) 224 | parser_path = os.path.join(code_directory, "parser/languages.so") 225 | for lang in ["c", "cpp", "java", "python", "php", "ruby", "js", "go", "cs"]: 226 | lang, code = input_map[lang] 227 | var_renamer = VarRenamer( 228 | parser_path, lang 229 | ) 230 | print(lang) 231 | code, meta = var_renamer.transform_code(code) 232 | print(re.sub("[ \t\n]+", " ", code)) 233 | print(meta) 234 | print("=" * 150) 235 | --------------------------------------------------------------------------------