├── README.md
├── cfg.py
├── config.py
├── data.py
├── data
    ├── CodeNet_test.csv
    ├── CodeNet_train.csv
    ├── FixEval_complete_test.csv
    ├── FixEval_complete_train.csv
    ├── FixEval_incomplete_test.csv
    └── FixEval_incomplete_train.csv
├── fuzz_testing.py
├── fuzz_testing_dataset
    ├── code_1.py
    ├── code_10.py
    ├── code_11.py
    ├── code_12.py
    ├── code_13.py
    ├── code_14.py
    ├── code_15.py
    ├── code_16.py
    ├── code_17.py
    ├── code_18.py
    ├── code_19.py
    ├── code_2.py
    ├── code_20.py
    ├── code_21.py
    ├── code_22.py
    ├── code_23.py
    ├── code_24.py
    ├── code_25.py
    ├── code_26.py
    ├── code_27.py
    ├── code_28.py
    ├── code_29.py
    ├── code_3.py
    ├── code_30.py
    ├── code_31.py
    ├── code_32.py
    ├── code_33.py
    ├── code_34.py
    ├── code_35.py
    ├── code_36.py
    ├── code_37.py
    ├── code_38.py
    ├── code_39.py
    ├── code_4.py
    ├── code_40.py
    ├── code_41.py
    ├── code_42.py
    ├── code_43.py
    ├── code_44.py
    ├── code_45.py
    ├── code_46.py
    ├── code_47.py
    ├── code_48.py
    ├── code_49.py
    ├── code_5.py
    ├── code_50.py
    ├── code_6.py
    ├── code_7.py
    ├── code_8.py
    └── code_9.py
├── generate_dataset
    ├── cfg.py
    ├── generate_dataset.py
    └── trace_execution.py
├── img
    ├── architecture.png
    └── pipeline.png
├── main.py
├── model.py
├── requirements.txt
├── setup.sh
├── trace_execution.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | # CodeFlow: Predicting Program Behavior with Dynamic Dependencies Learning
4 | [![arXiv](https://img.shields.io/badge/arXiv-2408.02816-b31b1b.svg)](https://arxiv.org/abs/2408.02816)
5 | 
6 | 
7 | 
8 | ## Introduction
9 | 
10 | We introduce **CodeFlow**, a novel machine learning-based approach designed to predict program behavior by learning both static and dynamic dependencies within the code. CodeFlow constructs control flow graphs (CFGs) to represent all possible execution paths and uses these graphs to predict code coverage and detect runtime errors. Our empirical evaluation demonstrates that CodeFlow significantly improves code coverage prediction accuracy and effectively localizes runtime errors, outperforming state-of-the-art models.
11 | 
12 | ### Paper: [CodeFlow: Predicting Program Behavior with Dynamic Dependencies Learning](https://arxiv.org/abs/2408.02816)
13 | 
14 | ## Installation
15 | 
16 | To set up the environment and install the necessary libraries, run the following command:
17 | 
18 | ```sh
19 | ./setup.sh
20 | ```
21 | 
22 | ## Architecture
23 | ![CodeFlow architecture](img/architecture.png)
24 | 
25 | CodeFlow consists of several key components:
26 | 1. **CFG Building**: Constructs CFGs from the source code.
27 | 2. **Source Code Representation Learning**: Learns vector representations of CFG nodes.
28 | 3. **Dynamic Dependencies Learning**: Captures dynamic dependencies among statements using execution traces.
29 | 4. **Code Coverage Prediction**: Classifies nodes for code coverage using learned embeddings.
30 | 5. **Runtime Error Detection and Localization**: Detects and localizes runtime errors by analyzing code coverage continuity within CFGs.
31 | 
32 | ## Usage
33 | 
34 | ### Running the CodeFlow Model
35 | 
36 | To run the CodeFlow model, use the following command:
37 | 
38 | ```sh
39 | python main.py --data <dataset> [--runtime_detection] [--bug_localization]
40 | ```
41 | 
42 | #### Configuration Options
43 | 
44 | - `--data`: Specify the dataset to be used for training. Options:
45 |   - `CodeNet`: Train with only non-buggy Python code from the CodeNet dataset.
46 |   - `FixEval_complete`: Train with both non-buggy and buggy code from the FixEval and CodeNet datasets.
47 |   - `FixEval_incomplete`: Train with the incomplete version of the FixEval_complete dataset.
48 | 
49 | - `--runtime_detection`: Validate runtime error detection.
50 | 
51 | - `--bug_localization`: Validate bug localization in buggy code.
52 | 
53 | #### Example Usage
54 | 
55 | 1. **Training with the CodeNet dataset (RQ1):**
56 |    ```sh
57 |    python main.py --data CodeNet
58 |    ```
59 | 
60 | 2. **Training with the complete FixEval dataset and validating runtime error detection (RQ2):**
61 |    ```sh
62 |    python main.py --data FixEval_complete --runtime_detection
63 |    ```
64 | 
65 | 3. **Training with the complete and incomplete FixEval datasets and validating bug localization (RQ3):**
66 |    ```sh
67 |    python main.py --data FixEval_complete --bug_localization
68 |    python main.py --data FixEval_incomplete --bug_localization
69 |    ```
70 | ### Fuzz Testing with LLM Integration (RQ4)
71 | 
72 | After training CodeFlow and saving the corresponding checkpoint, you can use it for fuzz testing by integrating it with a Large Language Model (LLM). Use the following command:
73 | 
74 | ```sh
75 | python fuzz_testing.py --checkpoint <checkpoint> --epoch <epoch> --time <time> --claude_api_key <claude_api_key> --model <model>
76 | ```
77 | - `checkpoint`: The chosen checkpoint.
78 | - `epoch`: The chosen epoch of the checkpoint.
79 | - `time`: Time in seconds to run fuzz testing on each code file.
80 | - `claude_api_key`: Your API key for Claude.
81 | - `model`: The Claude model to use; default is `claude-3-5-sonnet-20240620`.
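Conceptually, `fuzz_testing.py` runs a generate-inject-check loop on each file until the time budget is exhausted. Below is a minimal sketch of that loop using the helpers defined in `fuzz_testing.py`; `has_runtime_error` is a hypothetical stand-in for the actual check, which builds the CFG of the modified program and runs the trained CodeFlow model over it:

```python
import time

def fuzz_one_file(path, opt):
    """Sketch: ask Claude for candidate inputs, inject them into the
    program, and keep probing until the time limit is reached."""
    with open(path) as f:
        code = f.read()
    tried = []
    start = time.time()
    while time.time() - start < opt.time:
        # Tell the model which inputs were already tried without effect.
        feedback = f"\nAvoid inputs that did not raise errors:\n{tried}" if tried else ""
        raw = get_generated_inputs(opt.claude_api_key, opt.model, code, feedback)
        inputs = extract_inputs(raw)
        candidate = add_generated_inputs_to_code(code, inputs)
        if has_runtime_error(candidate):  # hypothetical: the real script runs CodeFlow on the CFG
            return inputs                 # found an error-triggering input
        tried.append(inputs)
    return None
```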
82 | #### Example
83 | ```sh
84 | python fuzz_testing.py --checkpoint 1 --epoch 600 --time 120 --claude_api_key YOUR_API_KEY --model claude-3-5-sonnet-20240620
85 | ```
86 | ### Generating Your Own Dataset
87 | 
88 | To generate your own dataset, including the CFG, forward and backward edges, and the true execution trace as ground truth for your Python code, follow these steps:
89 | 
90 | 1. **Navigate to the `generate_dataset` folder**:
91 |    ```sh
92 |    cd generate_dataset
93 |    ```
94 | 
95 | 2. **Place your Python code files in the `dataset` folder**.
96 | 
97 | 3. **Run the dataset generation script**:
98 |    ```sh
99 |    python generate_dataset.py
100 |    ```
101 | To build and visualize the CFG of a Python file, use this command:
102 | ```sh
103 | python cfg.py <path_to_Python_file>
104 | ```
105 | 
106 | ## License
107 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
108 | 
109 | ## Acknowledgements
110 | This codebase is adapted from:
111 | - [ConditionBugs](https://github.com/zhangj111/ConditionBugs)
112 | - [CFG-Generator](https://github.com/Tiankai-Jiang/CFG-Generator)
113 | - [trace_python](https://github.com/python/cpython/blob/3.12/Lib/trace.py)
114 | 
115 | ## Citation Information
116 | 
117 | If you use CodeFlow, please cite it with this BibTeX:
118 | ```bibtex
119 | @misc{le2024learningpredictprogramexecution,
120 |       title={Learning to Predict Program Execution by Modeling Dynamic Dependency on Code Graphs},
121 |       author={Cuong Chi Le and Hoang Nhat Phan and Huy Nhat Phan and Tien N. Nguyen and Nghi D. Q. Bui},
122 |       year={2024},
123 |       eprint={2408.02816},
124 |       archivePrefix={arXiv},
125 |       primaryClass={cs.SE},
126 |       url={https://arxiv.org/abs/2408.02816},
127 | }
128 | ```
129 | 
--------------------------------------------------------------------------------
/cfg.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations 2 | import ast, astor, autopep8, tokenize, io, sys 3 | import graphviz as gv 4 | from typing import Dict, List, Tuple, Set, Optional, Type 5 | import re 6 | import sys 7 | import trace_execution 8 | import os 9 | import io 10 | import linecache 11 | import random 12 | 13 | class SingletonMeta(type): 14 | _instance: Optional[BlockId] = None 15 | 16 | def __call__(self) -> BlockId: 17 | if self._instance is None: 18 | self._instance = super().__call__() 19 | return self._instance 20 | 21 | 22 | class BlockId(metaclass=SingletonMeta): 23 | counter: int = 0 24 | 25 | def gen(self) -> int: 26 | self.counter += 1 27 | return self.counter 28 | 29 | 30 | class BasicBlock: 31 | 32 | def __init__(self, bid: int): 33 | self.bid: int = bid 34 | self.stmts: List[Type[ast.AST]] = [] 35 | self.calls: List[str] = [] 36 | self.prev: List[int] = [] 37 | self.next: List[int] = [] 38 | self.condition = False 39 | self.for_loop = 0 40 | self.for_name = Type[ast.AST] 41 | 42 | def is_empty(self) -> bool: 43 | return len(self.stmts) == 0 44 | 45 | def has_next(self) -> bool: 46 | return len(self.next) != 0 47 | 48 | def has_previous(self) -> bool: 49 | return len(self.prev) != 0 50 | 51 | def remove_from_prev(self, prev_bid: int) -> None: 52 | if prev_bid in self.prev: 53 | self.prev.remove(prev_bid) 54 | 55 | def remove_from_next(self, next_bid: int) -> None: 56 | if next_bid in self.next: 57 | self.next.remove(next_bid) 58 | 59 | def stmts_to_code(self) -> str: 60 | code = '' 61 | for stmt in self.stmts: 62 | line = astor.to_source(stmt) 63 | code += line.split('\n')[0] + "\n" if
type(stmt) in [ast.If, ast.For, ast.While, ast.FunctionDef, 64 | ast.AsyncFunctionDef] else line 65 | return code 66 | 67 | def calls_to_code(self) -> str: 68 | return '\n'.join(self.calls) 69 | 70 | 71 | class CFG: 72 | 73 | def __init__(self, name: str): 74 | self.name: str = name 75 | self.filename = name 76 | 77 | # I am sure that in original code variable asynchr is not used 78 | # And I think list finalblocks is also not used. 79 | 80 | self.start: Optional[BasicBlock] = None 81 | self.func_calls: Dict[str, CFG] = {} 82 | self.blocks: Dict[int, BasicBlock] = {} 83 | self.edges: Dict[Tuple[int, int], Type[ast.AST]] = {} 84 | self.back_edges: List[Tuple[int, int]] = [] 85 | self.lst_temp : List[Tuple[int, int]] = [] 86 | self.graph: Optional[gv.dot.Digraph] = None 87 | self.execution_path: List[int] = [] 88 | self.path: List[int] = [] 89 | self.revert: Dict[int, int] = {} 90 | self.store_True: List[int] = [] 91 | 92 | def _traverse(self, block: BasicBlock, visited: Set[int] = set(), calls: bool = True) -> None: 93 | if block.bid not in visited: 94 | visited.add(block.bid) 95 | st = block.stmts_to_code() 96 | if st.startswith('if'): 97 | st = st[3:] 98 | elif st.startswith('while'): 99 | st = st[6:] 100 | if block.condition: 101 | st = 'T ' + st 102 | # Check if the block is in the path and highlight it 103 | node_attributes = {'shape': 'ellipse'} 104 | if block.bid in self.path: 105 | node_attributes['color'] = 'red' 106 | node_attributes['style'] = 'filled' 107 | 108 | self.graph.node(str(block.bid), label=st, _attributes=node_attributes) 109 | 110 | for next_bid in block.next: 111 | self._traverse(self.blocks[next_bid], visited, calls=calls) 112 | self.graph.edge(str(block.bid), str(next_bid)) 113 | 114 | def get_revert(self): 115 | code = {} 116 | for_loop = {} 117 | for i in self.blocks: 118 | if self.blocks[i].for_loop != 0: 119 | if self.blocks[i].for_loop not in for_loop: 120 | for_loop[self.blocks[i].for_loop] = [i] 121 | else: 122 | for_loop[self.blocks[i].for_loop].append(i) 123 | first = [] 124 | second = [] 125 | for i in for_loop: 126 | first.append(for_loop[i][0]+1) 127 | second.append(for_loop[i][1]) 128 | orin_node = [] 129 | track = {} 130 | track_for = {} 131 | for i in self.blocks: 132 | if self.blocks[i].stmts_to_code(): 133 | if int(i) == 1: 134 | st = 'BEGIN' 135 | elif int(i) == len(self.blocks): 136 | st = 'EXIT' 137 | else: 138 | if i in first: 139 | line = astor.to_source(self.blocks[i].for_name) 140 | st = line.split('\n')[0] 141 | st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "") 142 | else: 143 | st = self.blocks[i].stmts_to_code() 144 | st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "") 145 | orin_node.append([i, st, None]) 146 | if st not in track: 147 | track[st] = [len(orin_node)-1] 148 | else: 149 | track[st].append(len(orin_node)-1) 150 | track_for[i] = len(orin_node)-1 151 | with open(self.filename, 'r') as file: 152 | lines = file.readlines() 153 | for i in range(1, len(lines)+1): 154 | line = lines[i-1] 155 | #delete \n at the end of each line and delete all spaces 156 | line = line.strip() 157 | line = re.sub(r"\s+", "", line).replace('"', "'").replace("(", "").replace(")", "") 158 | if line.startswith('elif'): 159 | line = line[2:] 160 | if line in track: 161 | orin_node[track[line][0]][2] = i 162 | if orin_node[track[line][0]][0] in first: 163 | orin_node[track[line][0]-1][2] = i-0.4 164 | orin_node[track[line][0]+1][2] = i+0.4 165 | if len(track[line]) > 1: 166 | track[line].pop(0) 167 
| for i in second: 168 | max_val = 0 169 | for edge in self.edges: 170 | if edge[0] == i: 171 | if orin_node[track_for[edge[1]]][2] > max_val: 172 | max_val = orin_node[track_for[edge[1]]][2] 173 | if edge[1] == i: 174 | if orin_node[track_for[edge[0]]][2] > max_val: 175 | max_val = orin_node[track_for[edge[0]]][2] 176 | orin_node[track_for[i]][2] = max_val + 0.5 177 | orin_node[0][2] = 0 178 | orin_node[-1][2] = len(lines)+1 179 | # sort orin_node by the third element 180 | orin_node.sort(key=lambda x: x[2]) 181 | nodes = [] 182 | for t in orin_node: 183 | i = t[0] 184 | if self.blocks[i].stmts_to_code(): 185 | if int(i) == 1: 186 | nodes.append('BEGIN') 187 | elif int(i) == len(self.blocks): 188 | nodes.append('EXIT') 189 | else: 190 | st = self.blocks[i].stmts_to_code() 191 | st_no_space = re.sub(r"\s+", "", st) 192 | st_no_space = st_no_space.replace('"', "'") 193 | # if start with if or while, delete these keywords 194 | if st.startswith('if'): 195 | st = st[3:] 196 | elif st.startswith('while'): 197 | st = st[6:] 198 | if self.blocks[i].condition: 199 | st = 'T '+ st 200 | nodes.append(st) 201 | self.revert[len(nodes)] = i 202 | 203 | def track_execution(self): 204 | nodes = [] 205 | blocks = [] 206 | matching = {} 207 | 208 | # add 209 | code = {} 210 | for_loop = {} 211 | for i in self.blocks: 212 | if self.blocks[i].for_loop != 0: 213 | if self.blocks[i].for_loop not in for_loop: 214 | for_loop[self.blocks[i].for_loop] = [i] 215 | else: 216 | for_loop[self.blocks[i].for_loop].append(i) 217 | first = [] 218 | second = [] 219 | for i in for_loop: 220 | first.append(for_loop[i][0]+1) 221 | second.append(for_loop[i][1]) 222 | orin_node = [] 223 | track = {} 224 | track_for = {} 225 | for i in self.blocks: 226 | if self.blocks[i].stmts_to_code(): 227 | if int(i) == 1: 228 | st = 'BEGIN' 229 | elif int(i) == len(self.blocks): 230 | st = 'EXIT' 231 | else: 232 | if i in first: 233 | line = astor.to_source(self.blocks[i].for_name) 234 | st = line.split('\n')[0] 235 | st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "") 236 | else: 237 | st = self.blocks[i].stmts_to_code() 238 | st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "") 239 | orin_node.append([i, st, None]) 240 | if st not in track: 241 | track[st] = [len(orin_node)-1] 242 | else: 243 | track[st].append(len(orin_node)-1) 244 | track_for[i] = len(orin_node)-1 245 | with open(self.filename, 'r') as file: 246 | lines = file.readlines() 247 | for i in range(1, len(lines)+1): 248 | line = lines[i-1] 249 | #delete \n at the end of each line and delete all spaces 250 | line = line.strip() 251 | line = re.sub(r"\s+", "", line).replace('"', "'").replace("(", "").replace(")", "") 252 | if line.startswith('elif'): 253 | line = line[2:] 254 | if line in track: 255 | orin_node[track[line][0]][2] = i 256 | if orin_node[track[line][0]][0] in first: 257 | orin_node[track[line][0]-1][2] = i-0.4 258 | orin_node[track[line][0]+1][2] = i+0.4 259 | if len(track[line]) > 1: 260 | track[line].pop(0) 261 | orin_node[0][2] = 0 262 | orin_node[-1][2] = len(lines)+1 263 | for i in second: 264 | max_val = 0 265 | for edge in self.edges: 266 | if edge[0] == i: 267 | if orin_node[track_for[edge[1]]][2] > max_val: 268 | max_val = orin_node[track_for[edge[1]]][2] 269 | if edge[1] == i: 270 | if orin_node[track_for[edge[0]]][2] > max_val: 271 | max_val = orin_node[track_for[edge[0]]][2] 272 | orin_node[track_for[i]][2] = max_val + 0.5 273 | # sort orin_node by the third element 274 | 
orin_node.sort(key=lambda x: x[2]) 275 | #add 276 | for t in orin_node: 277 | i = t[0] 278 | if self.blocks[i].stmts_to_code(): 279 | if int(i) == 1: 280 | nodes.append('BEGIN') 281 | blocks.append('BEGIN') 282 | elif int(i) == len(self.blocks): 283 | nodes.append('EXIT') 284 | blocks.append('EXIT') 285 | else: 286 | st = self.blocks[i].stmts_to_code() 287 | st_no_space = re.sub(r"\s+", "", st) 288 | st_no_space = st_no_space.replace('"', "'") 289 | blocks.append(st_no_space) 290 | # if start with if or while, delete these keywords 291 | if st.startswith('if'): 292 | st = st[3:] 293 | elif st.startswith('while'): 294 | st = st[6:] 295 | if self.blocks[i].condition: 296 | st = 'T '+ st 297 | nodes.append(st) 298 | matching[i] = len(nodes) 299 | self.revert[len(nodes)] = i 300 | 301 | edges = {} 302 | for edge in self.edges: 303 | if matching[edge[0]] not in edges: 304 | edges[matching[edge[0]]] = [matching[edge[1]]] 305 | else: 306 | edges[matching[edge[0]]].append(matching[edge[1]]) 307 | 308 | for_loop = {} 309 | for i in self.blocks: 310 | if self.blocks[i].for_loop != 0: 311 | if self.blocks[i].for_loop not in for_loop: 312 | for_loop[self.blocks[i].for_loop] = [matching[i]] 313 | else: 314 | for_loop[self.blocks[i].for_loop].append(matching[i]) 315 | store = {} 316 | last_loop = {} 317 | for i in for_loop: 318 | lst = for_loop[i] 319 | start = lst[0] 320 | end = lst[1] 321 | if self.blocks[self.revert[start]].condition: 322 | self.blocks[self.revert[start+1]].condition = True 323 | store[start+1] = [start] 324 | for t in matching: 325 | if matching[t] == start+1: 326 | line = astor.to_source(self.blocks[t].for_name) 327 | code = line.split('\n')[0] 328 | st_no_space = re.sub(r"\s+", "", code) 329 | st_no_space = st_no_space.replace('"', "'") 330 | blocks[start] = st_no_space 331 | edges.pop(start) 332 | for j in edges: 333 | if start in edges[j]: 334 | edges[j].remove(start) 335 | edges[j].append(start + 1) 336 | for a in edges: 337 | if end in edges[a]: 338 | edges[a].remove(end) 339 | edges[a].append(start+1) 340 | last_loop[a] = end 341 | edges.pop(end) 342 | if nodes[start+1].startswith('T'): 343 | store[start+1].append(start+2) 344 | for b in edges[start+2]: 345 | edges[start+1].append(b) 346 | edges.pop(start+2) 347 | edges[start+1].remove(start+2) 348 | else: 349 | store[start+1].append(start+3) 350 | for b in edges[start+3]: 351 | edges[start+1].append(b) 352 | edges.pop(start+3) 353 | edges[start+1].remove(start+3) 354 | t = trace_execution.Trace(ignoredirs=[sys.base_prefix, sys.base_exec_prefix,], 355 | trace=0, count=1) 356 | arguments = [] 357 | sys.argv = [self.filename, arguments] 358 | sys.path[0] = os.path.dirname(self.filename) 359 | 360 | with io.open_code(self.filename) as fp: 361 | code = compile(fp.read(), self.filename, 'exec') 362 | # try to emulate __main__ namespace as much as possible 363 | globs = { 364 | '__file__': self.filename, 365 | '__name__': '__main__', 366 | '__package__': None, 367 | '__cached__': None, 368 | } 369 | terminate = True 370 | try: 371 | t.runctx(code, globs, globs) 372 | except Exception as e: 373 | terminate = False 374 | 375 | source = linecache.getlines(self.filename) 376 | code_line = [element.lstrip().replace('\n', '') for element in source] 377 | execution_path = [] 378 | for lineno in t.exe_path: 379 | no_spaces = re.sub(r"\s+", "", code_line[lineno-1]) 380 | if no_spaces.startswith("if("): 381 | no_spaces = no_spaces.replace("(", "").replace(")", "") 382 | if no_spaces.startswith("elif"): 383 | no_spaces = no_spaces[2:] 384 | 
execution_path.append(no_spaces) 385 | 386 | check_True_condition = [] 387 | for i in range(len(execution_path)-1): 388 | if execution_path[i].startswith('if') or execution_path[i].startswith('while') or execution_path[i].startswith('for'): 389 | if t.exe_path[i+1] == t.exe_path[i]+1: 390 | check_True_condition.append(i) 391 | 392 | current_node = 1 393 | path = [self.revert[current_node]] 394 | exit_flag = False 395 | for s in range(len(execution_path)): 396 | node = execution_path[s] 397 | if current_node == 1: 398 | current_node = edges[current_node][0] 399 | else: 400 | c = 0 401 | if node == "break" or node == "continue": 402 | continue 403 | if len(edges[current_node]) == 2: 404 | if blocks[edges[current_node][0]-1] == blocks[edges[current_node][1]-1]: 405 | if (s-1) in check_True_condition: 406 | if self.blocks[self.revert[edges[current_node][0]]].condition: 407 | current_node = edges[current_node][0] 408 | else: 409 | current_node = edges[current_node][1] 410 | else: 411 | if self.blocks[self.revert[edges[current_node][0]]].condition: 412 | current_node = edges[current_node][1] 413 | else: 414 | current_node = edges[current_node][0] 415 | if current_node in last_loop: 416 | if terminate: 417 | if self.revert[last_loop[current_node]] not in path: 418 | path.append(self.revert[last_loop[current_node]]) 419 | else: 420 | if s != len(execution_path) - 1: 421 | if self.revert[last_loop[current_node]] not in path: 422 | path.append(self.revert[last_loop[current_node]]) 423 | if current_node in store: 424 | for i in store[current_node]: 425 | if self.revert[i] not in path: 426 | path.append(self.revert[i]) 427 | if self.revert[current_node] not in path: 428 | path.append(self.revert[current_node]) 429 | continue 430 | for next_node in edges[current_node]: 431 | c += 1 432 | node = node.replace('"', "'") 433 | # print(node) 434 | # print(blocks[next_node-1]) 435 | # print("___________") 436 | if blocks[next_node-1] == node: 437 | if next_node in last_loop: 438 | if terminate: 439 | if self.revert[last_loop[next_node]] not in path: 440 | path.append(self.revert[last_loop[next_node]]) 441 | else: 442 | if s != len(execution_path) - 1: 443 | if self.revert[last_loop[next_node]] not in path: 444 | path.append(self.revert[last_loop[next_node]]) 445 | if next_node in store: 446 | for i in store[next_node]: 447 | if self.revert[i] not in path: 448 | path.append(self.revert[i]) 449 | current_node = next_node 450 | break 451 | if c == len(edges[current_node]): 452 | exit_flag = True 453 | raise Exception(f"Error: Cannot find the execution path in CFG in file {self.filename}") 454 | 455 | if exit_flag: 456 | break 457 | if self.revert[current_node] not in path: 458 | path.append(self.revert[current_node]) 459 | if terminate: 460 | node_max = 0 461 | for i in self.blocks: 462 | if i > node_max: 463 | node_max = i 464 | path.append(node_max) 465 | 466 | self.path = path 467 | 468 | def clean(self): 469 | des_edges = {} 470 | for edge in self.edges: 471 | if edge[0] not in des_edges: 472 | des_edges[edge[0]] = [edge[1]] 473 | else: 474 | des_edges[edge[0]].append(edge[1]) 475 | blank_node = [] 476 | for node in des_edges: 477 | check = False 478 | if self.blocks[node].stmts_to_code(): 479 | current = node 480 | for next in des_edges[node]: 481 | next_node = next 482 | while self.blocks[next_node].stmts_to_code() == '': 483 | if next_node not in blank_node: 484 | blank_node.append(next_node) 485 | current = next_node 486 | if current not in des_edges: 487 | break 488 | next_node = des_edges[current][0] 
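# If the chain of skipped empty blocks crossed a back edge, the rewired
# (node, next_node) edge recorded below is flagged as a back edge too, so
# the loop structure survives the cleanup.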
489 | if (current, next_node) in self.back_edges: 490 | check = True 491 | if (node, next_node) not in self.edges: 492 | self.edges[(node, next_node)] = None 493 | if check: 494 | self.back_edges.append((node, next_node)) 495 | if next_node != node and next_node not in self.blocks[node].next: 496 | self.blocks[node].next.append(next_node) 497 | else: 498 | if node not in blank_node: 499 | blank_node.append(node) 500 | for edge in self.edges.copy(): 501 | if edge[0] in blank_node or edge[1] in blank_node: 502 | self.edges.pop(edge) 503 | for i in self.blocks: 504 | for node in blank_node: 505 | if node in self.blocks[i].next: 506 | self.blocks[i].next.remove(node) 507 | 508 | def _show(self, fmt: str = 'pdf', calls: bool = True) -> gv.dot.Digraph: 509 | self.graph = gv.Digraph(name='cluster_' + self.name, format=fmt, graph_attr={'label': self.name}) 510 | self._traverse(self.start, calls=calls) 511 | for k, v in self.func_calls.items(): 512 | self.graph.subgraph(v._show(fmt, calls)) 513 | return self.graph 514 | 515 | 516 | def show(self, filepath: str = './output', fmt: str = 'pdf', calls: bool = True, show: bool = True) -> None: 517 | self._show(fmt, calls) 518 | self.graph.render(filepath, view=show, cleanup=True) 519 | 520 | 521 | class CFGVisitor(ast.NodeVisitor): 522 | 523 | invertComparators: Dict[Type[ast.AST], Type[ast.AST]] = {ast.Eq: ast.NotEq, ast.NotEq: ast.Eq, ast.Lt: ast.GtE, 524 | ast.LtE: ast.Gt, 525 | ast.Gt: ast.LtE, ast.GtE: ast.Lt, ast.Is: ast.IsNot, 526 | ast.IsNot: ast.Is, ast.In: ast.NotIn, ast.NotIn: ast.In} 527 | 528 | def __init__(self): 529 | super().__init__() 530 | self.count_for_loop = 0 531 | self.loop_stack: List[BasicBlock] = [] 532 | self.for_stack: List[int] = [] 533 | self.continue_stack: List[BasicBlock] = [] 534 | self.ifExp = False 535 | self.store_True: List[int] = [] 536 | 537 | def build(self, name: str, tree: Type[ast.AST]) -> CFG: 538 | self.cfg = CFG(name) 539 | self.filename = name 540 | begin_block = self.new_block() 541 | begin_block.stmts = [ast.Expr(value=ast.Str(s='BEGIN'))] 542 | self.cfg.start = begin_block 543 | self.curr_block = self.new_block() 544 | self.add_edge(begin_block.bid, self.curr_block.bid) 545 | 546 | self.visit(tree) 547 | self.add_EXIT_node() 548 | self.remove_empty_blocks(self.cfg.start) 549 | for edge in self.cfg.lst_temp: 550 | if edge in self.cfg.back_edges: 551 | self.cfg.back_edges.remove(edge) 552 | return self.cfg 553 | 554 | def add_EXIT_node(self) -> None: 555 | exit_block = self.new_block() 556 | exit_block.stmts = [ast.Expr(value=ast.Str(s='EXIT'))] 557 | self.add_edge(self.curr_block.bid, exit_block.bid) 558 | 559 | def new_block(self) -> BasicBlock: 560 | bid: int = BlockId().gen() 561 | self.cfg.blocks[bid] = BasicBlock(bid) 562 | if bid in self.store_True: 563 | self.cfg.blocks[bid].condition = True 564 | self.store_True.remove(bid) 565 | return self.cfg.blocks[bid] 566 | 567 | def add_stmt(self, block: BasicBlock, stmt: Type[ast.AST]) -> None: 568 | # if block.stmts contain 1 stmt, then create the new node and add the stmt to the new node and add edge from block to new node 569 | if len(block.stmts) == 1: 570 | new_block = self.new_block() 571 | new_block.stmts.append(stmt) 572 | self.add_edge(block.bid, new_block.bid) 573 | self.curr_block = new_block 574 | else: 575 | block.stmts.append(stmt) 576 | 577 | def add_edge(self, frm_id: int, to_id: int, condition=None) -> BasicBlock: 578 | self.cfg.blocks[frm_id].next.append(to_id) 579 | self.cfg.blocks[to_id].prev.append(frm_id) 580 | 
self.cfg.edges[(frm_id, to_id)] = condition 581 | return self.cfg.blocks[to_id] 582 | 583 | def add_loop_block(self) -> BasicBlock: 584 | if self.curr_block.is_empty() and not self.curr_block.has_next(): 585 | return self.curr_block 586 | else: 587 | loop_block = self.new_block() 588 | self.add_edge(self.curr_block.bid, loop_block.bid) 589 | return loop_block 590 | 591 | def add_subgraph(self, tree: Type[ast.AST]) -> None: 592 | self.cfg.func_calls[tree.name] = CFGVisitor().build(tree.name, ast.Module(body=tree.body)) 593 | 594 | def add_condition(self, cond1: Optional[Type[ast.AST]], cond2: Optional[Type[ast.AST]]) -> Optional[Type[ast.AST]]: 595 | if cond1 and cond2: 596 | return ast.BoolOp(ast.And(), values=[cond1, cond2]) 597 | else: 598 | return cond1 if cond1 else cond2 599 | 600 | # not tested 601 | def remove_empty_blocks(self, block: BasicBlock, visited: Set[int] = set()) -> None: 602 | if block.bid not in visited: 603 | visited.add(block.bid) 604 | if block.is_empty(): 605 | for prev_bid in block.prev: 606 | prev_block = self.cfg.blocks[prev_bid] 607 | for next_bid in block.next: 608 | next_block = self.cfg.blocks[next_bid] 609 | self.add_edge(prev_bid, next_bid, self.add_condition(self.cfg.edges.get((prev_bid, block.bid)), self.cfg.edges.get((block.bid, next_bid)))) 610 | self.cfg.edges.pop((block.bid, next_bid), None) 611 | if (block.bid, next_bid) in self.cfg.back_edges: 612 | self.cfg.back_edges.append((prev_bid, next_bid)) 613 | self.cfg.lst_temp.append((block.bid, next_bid)) 614 | next_block.remove_from_prev(block.bid) 615 | self.cfg.edges.pop((prev_bid, block.bid)) 616 | if (prev_bid, block.bid) in self.cfg.back_edges: 617 | self.cfg.back_edges.append((prev_bid, next_bid)) 618 | self.cfg.lst_temp.append((prev_bid, block.bid)) 619 | prev_block.remove_from_next(block.bid) 620 | block.prev.clear() 621 | for next_bid in block.next: 622 | self.remove_empty_blocks(self.cfg.blocks[next_bid], visited) 623 | block.next.clear() 624 | 625 | else: 626 | for next_bid in block.next: 627 | self.remove_empty_blocks(self.cfg.blocks[next_bid], visited) 628 | 629 | def invert(self, node: Type[ast.AST]) -> Type[ast.AST]: 630 | if type(node) == ast.Compare: 631 | if len(node.ops) == 1: 632 | return ast.Compare(left=node.left, ops=[self.invertComparators[type(node.ops[0])]()], comparators=node.comparators) 633 | else: 634 | tmpNode = ast.BoolOp(op=ast.And(), values = [ast.Compare(left=node.left, ops=[node.ops[0]], comparators=[node.comparators[0]])]) 635 | for i in range(0, len(node.ops) - 1): 636 | tmpNode.values.append(ast.Compare(left=node.comparators[i], ops=[node.ops[i+1]], comparators=[node.comparators[i+1]])) 637 | return self.invert(tmpNode) 638 | elif isinstance(node, ast.BinOp) and type(node.op) in self.invertComparators: 639 | return ast.BinOp(node.left, self.invertComparators[type(node.op)](), node.right) 640 | elif type(node) == ast.NameConstant and type(node.value) == bool: 641 | return ast.NameConstant(value=not node.value) 642 | elif type(node) == ast.BoolOp: 643 | return ast.BoolOp(values = [self.invert(x) for x in node.values], op = {ast.And: ast.Or(), ast.Or: ast.And()}.get(type(node.op))) 644 | elif type(node) == ast.UnaryOp: 645 | return self.UnaryopInvert(node) 646 | else: 647 | return ast.UnaryOp(op=ast.Not(), operand=node) 648 | 649 | def UnaryopInvert(self, node: Type[ast.AST]) -> Type[ast.AST]: 650 | if type(node.op) == ast.UAdd: 651 | return ast.UnaryOp(op=ast.USub(),operand = node.operand) 652 | elif type(node.op) == ast.USub: 653 | return 
ast.UnaryOp(op=ast.UAdd(),operand = node.operand) 654 | elif type(node.op) == ast.Invert: 655 | return ast.UnaryOp(op=ast.Not(), operand=node) 656 | else: 657 | return node.operand 658 | 659 | # def boolinvert(self, node:Type[ast.AST]) -> Type[ast.AST]: 660 | # value = [] 661 | # for item in node.values: 662 | # value.append(self.invert(item)) 663 | # if type(node.op) == ast.Or: 664 | # return ast.BoolOp(values = value, op = ast.And()) 665 | # elif type(node.op) == ast.And: 666 | # return ast.BoolOp(values = value, op = ast.Or()) 667 | 668 | def combine_conditions(self, node_list: List[Type[ast.AST]]) -> Type[ast.AST]: 669 | return node_list[0] if len(node_list) == 1 else ast.BoolOp(op=ast.And(), values = node_list) 670 | 671 | def generic_visit(self, node): 672 | if type(node) in [ast.Import, ast.ImportFrom]: 673 | self.add_stmt(self.curr_block, node) 674 | return 675 | if type(node) in [ast.FunctionDef, ast.AsyncFunctionDef]: 676 | self.add_stmt(self.curr_block, node) 677 | self.add_subgraph(node) 678 | return 679 | if type(node) in [ast.AnnAssign, ast.AugAssign]: 680 | self.add_stmt(self.curr_block, node) 681 | super().generic_visit(node) 682 | 683 | def get_function_name(self, node: Type[ast.AST]) -> str: 684 | if type(node) == ast.Name: 685 | return node.id 686 | elif type(node) == ast.Attribute: 687 | return self.get_function_name(node.value) + '.' + node.attr 688 | elif type(node) == ast.Str: 689 | return node.s 690 | elif type(node) == ast.Subscript: 691 | return node.value.id 692 | elif type(node) == ast.Lambda: 693 | return 'lambda function' 694 | 695 | def populate_body(self, body_list: List[Type[ast.AST]], to_bid: int, type: str) -> None: 696 | for child in body_list: 697 | self.visit(child) 698 | if not self.curr_block.next: 699 | self.add_edge(self.curr_block.bid, to_bid) 700 | if type == "While" or type == "Try": 701 | self.cfg.back_edges.append((self.curr_block.bid, to_bid)) 702 | 703 | def populate_If_body(self, body_list: List[Type[ast.AST]], node_list: List[int]): 704 | for child in body_list: 705 | self.visit(child) 706 | if not self.curr_block.next: 707 | node_list.append(self.curr_block.bid) 708 | return node_list 709 | 710 | def populate_For_body(self, body_list: List[Type[ast.AST]]) -> None: 711 | for child in body_list: 712 | self.visit(child) 713 | new_node = self.new_block() 714 | if not self.curr_block.next: 715 | self.add_edge(self.curr_block.bid, new_node.bid) 716 | return new_node 717 | 718 | # assert type check 719 | def visit_Assert(self, node): 720 | self.add_stmt(self.curr_block, node) 721 | # If the assertion fails, the current flow ends, so the fail block is a 722 | # final block of the CFG. 723 | # self.cfg.finalblocks.append(self.add_edge(self.curr_block.bid, self.new_block().bid, self.invert(node.test))) 724 | # If the assertion is True, continue the flow of the program. 725 | # success block 726 | self.curr_block = self.add_edge(self.curr_block.bid, self.new_block().bid, node.test) 727 | self.generic_visit(node) 728 | 729 | def visit_Assign(self, node): 730 | if type(node.value) in [ast.ListComp, ast.SetComp, ast.DictComp, ast.GeneratorExp, ast.Lambda] and len(node.targets) == 1 and type(node.targets[0]) == ast.Name: # is this entire statement necessary? 
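# Assignments of comprehensions are desugared: the target is first bound to an
# empty container (or to a generator call), and the generators are expanded
# into explicit For/If nodes by the matching visit_*Comp handler further down.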
731 | if type(node.value) == ast.ListComp: 732 | self.add_stmt(self.curr_block, ast.Assign(targets=[ast.Name(id=node.targets[0].id, ctx=ast.Store())], value=ast.List(elts=[], ctx=ast.Load()))) 733 | self.listCompReg = (node.targets[0].id, node.value) 734 | elif type(node.value) == ast.SetComp: 735 | self.add_stmt(self.curr_block, ast.Assign(targets=[ast.Name(id=node.targets[0].id, ctx=ast.Store())], value=ast.Call(func=ast.Name(id='set', ctx=ast.Load()), args=[], keywords=[]))) 736 | self.setCompReg = (node.targets[0].id, node.value) 737 | elif type(node.value) == ast.DictComp: 738 | self.add_stmt(self.curr_block, ast.Assign(targets=[ast.Name(id=node.targets[0].id, ctx=ast.Store())], value=ast.Dict(keys=[], values=[]))) 739 | self.dictCompReg = (node.targets[0].id, node.value) 740 | elif type(node.value) == ast.GeneratorExp: 741 | self.add_stmt(self.curr_block, ast.Assign(targets=[ast.Name(id=node.targets[0].id, ctx=ast.Store())], value=ast.Call(func=ast.Name(id='__' + node.targets[0].id + 'Generator__', ctx=ast.Load()), args=[], keywords=[]))) 742 | self.genExpReg = (node.targets[0].id, node.value) 743 | else: 744 | self.lambdaReg = (node.targets[0].id, node.value) 745 | else: 746 | self.add_stmt(self.curr_block, node) 747 | self.generic_visit(node) 748 | 749 | def visit_Await(self, node): 750 | afterawait_block = self.new_block() 751 | self.add_edge(self.curr_block.bid, afterawait_block.bid) 752 | self.generic_visit(node) 753 | self.curr_block = afterawait_block 754 | 755 | def visit_Break(self, node): 756 | assert len(self.loop_stack), "Found break not inside loop" 757 | self.add_edge(self.curr_block.bid, self.loop_stack[-1].bid, ast.Break()) 758 | 759 | def visit_Continue(self, node): 760 | assert len(self.continue_stack), "Found continue not inside loop" 761 | self.add_edge(self.curr_block.bid, self.continue_stack[-1].bid, ast.Continue()) 762 | 763 | def visit_Call(self, node): 764 | if type(node.func) == ast.Lambda: 765 | self.lambdaReg = ('Anonymous Function', node.func) 766 | self.generic_visit(node) 767 | # else: 768 | # self.curr_block.calls.append(self.get_function_name(node.func)) 769 | 770 | def visit_DictComp_Rec(self, generators: List[Type[ast.AST]]) -> List[Type[ast.AST]]: 771 | if not generators: 772 | if self.dictCompReg[0]: # bug if there is else statement in comprehension 773 | return [ast.Assign(targets=[ast.Subscript(value=ast.Name(id=self.dictCompReg[0], ctx=ast.Load()), slice=ast.Index(value=self.dictCompReg[1].key), ctx=ast.Store())], value=self.dictCompReg[1].value)] 774 | # else: # not supported yet 775 | # return [ast.Expr(value=self.dictCompReg[1].elt)] 776 | else: 777 | return [ast.For(target=generators[-1].target, iter=generators[-1].iter, body=[ast.If(test=self.combine_conditions(generators[-1].ifs), body=self.visit_DictComp_Rec(generators[:-1]), orelse=[])] if generators[-1].ifs else self.visit_DictComp_Rec(generators[:-1]), orelse=[])] 778 | 779 | def visit_DictComp(self, node): 780 | try: # try may change to checking if self.dictCompReg exists 781 | self.generic_visit(ast.Module(self.visit_DictComp_Rec(self.dictCompReg[1].generators))) 782 | except: 783 | pass 784 | finally: 785 | self.dictCompReg = None 786 | 787 | # ignore the case when using set or dict comprehension or generator expression but the result is not assigned to a variable 788 | def visit_Expr(self, node): 789 | if type(node.value) == ast.ListComp and type(node.value.elt) == ast.Call: 790 | self.listCompReg = (None, node.value) 791 | elif type(node.value) == ast.Lambda: 792 | self.lambdaReg 
= ('Anonymous Function', node.value) 793 | # elif type(node.value) == ast.Call and type(node.value.func) == ast.Lambda: 794 | # self.lambdaReg = ('Anonymous Function', node.value.func) 795 | else: 796 | self.add_stmt(self.curr_block, node) 797 | self.generic_visit(node) 798 | 799 | def visit_For(self, node): 800 | self.count_for_loop += 1 801 | assign_block = self.new_block() 802 | if self.curr_block.condition: 803 | assign_block.condition = True 804 | self.add_edge(self.curr_block.bid, assign_block.bid) 805 | self.curr_block = assign_block 806 | self.curr_block.for_loop = self.count_for_loop 807 | self.for_stack.append(self.count_for_loop) 808 | num = self.count_for_loop 809 | self.add_stmt(self.curr_block, ast.Assign(targets=[ast.Name(id=f'p{num}', ctx=ast.Store())], value=ast.Num(n=0))) 810 | if_node = ast.If(test=ast.Compare(left=ast.Name(id=f'p{num}', ctx=ast.Load()), ops=[ast.Lt()], comparators=[ast.Call(func=ast.Name(id='len', ctx=ast.Load()), args=[node.iter], keywords=[])]), body=[ast.Assign(targets=[ast.Name(id=node.target.id, ctx=ast.Store())], value=ast.Subscript(value=node.iter, slice=ast.Index(value=ast.Name(id=f'p{num}', ctx=ast.Load())), ctx=ast.Load()))] + node.body, orelse=[]) 811 | loop_guard = self.add_loop_block() 812 | self.continue_stack.append(loop_guard) 813 | self.curr_block = loop_guard 814 | self.curr_block.for_name = node 815 | self.add_stmt(self.curr_block, if_node) 816 | # New block for the body of the for-loop. 817 | for_block = self.add_edge(self.curr_block.bid, self.new_block().bid) 818 | if not node.orelse: 819 | # Block of code after the for loop. 820 | afterfor_block = self.add_edge(self.curr_block.bid, self.new_block().bid) 821 | self.loop_stack.append(afterfor_block) 822 | self.curr_block = for_block 823 | self.curr_block.condition = True 824 | self.curr_block = self.populate_For_body(if_node.body) 825 | end_node = ast.AugAssign(target=ast.Name(id=f'p{num}', ctx=ast.Store()), op=ast.Add(), value=ast.Num(n=1)) 826 | self.add_stmt(self.curr_block, end_node) 827 | self.curr_block.for_loop = self.for_stack[-1] 828 | self.add_edge(self.curr_block.bid, loop_guard.bid) 829 | self.cfg.back_edges.append((self.curr_block.bid, loop_guard.bid)) 830 | else: 831 | # Block of code after the for loop. 832 | afterfor_block = self.new_block() 833 | orelse_block = self.add_edge(self.curr_block.bid, self.new_block().bid, ast.Name(id='else', ctx=ast.Load())) 834 | self.loop_stack.append(afterfor_block) 835 | self.curr_block = for_block 836 | self.curr_block.condition = True 837 | self.curr_block = self.populate_For_body(if_node.body) 838 | end_node = ast.AugAssign(target=ast.Name(id=f'p{num}', ctx=ast.Store()), op=ast.Add(), value=ast.Num(n=1)) 839 | self.add_stmt(self.curr_block, end_node) 840 | self.curr_block.for_loop = self.for_stack[-1] 841 | self.add_edge(self.curr_block.bid, loop_guard.bid) 842 | self.cfg.back_edges.append((self.curr_block.bid, loop_guard.bid)) 843 | 844 | self.curr_block = orelse_block 845 | for child in node.orelse: 846 | self.visit(child) 847 | self.add_edge(orelse_block.bid, afterfor_block.bid, "For") 848 | 849 | # Continue building the CFG in the after-for block. 
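# visit_For (above) lowers `for x in <iter>:` into an explicit counter so that
# iteration shows up as ordinary CFG nodes with a back edge to the loop guard,
# e.g. `for i in range(4, 100): ...` becomes, schematically:
#     p1 = 0
#     if p1 < len(range(4, 100)):      # loop guard
#         i = range(4, 100)[p1]        # marked 'T' as the true branch
#         ...loop body...
#         p1 += 1                      # back edge to the guard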
850 | self.curr_block = afterfor_block 851 | self.continue_stack.pop() 852 | self.loop_stack.pop() 853 | self.for_stack.pop() 854 | 855 | def visit_GeneratorExp_Rec(self, generators: List[Type[ast.AST]]) -> List[Type[ast.AST]]: 856 | if not generators: 857 | self.generic_visit(self.genExpReg[1].elt) # the location of the node may be wrong 858 | if self.genExpReg[0]: # bug if there is else statement in comprehension 859 | return [ast.Expr(value=ast.Yield(value=self.genExpReg[1].elt))] 860 | else: 861 | return [ast.For(target=generators[-1].target, iter=generators[-1].iter, body=[ast.If(test=self.combine_conditions(generators[-1].ifs), body=self.visit_GeneratorExp_Rec(generators[:-1]), orelse=[])] if generators[-1].ifs else self.visit_GeneratorExp_Rec(generators[:-1]), orelse=[])] 862 | 863 | def visit_GeneratorExp(self, node): 864 | try: # try may change to checking if self.genExpReg exists 865 | self.generic_visit(ast.FunctionDef(name='__' + self.genExpReg[0] + 'Generator__', 866 | args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), 867 | body = self.visit_GeneratorExp_Rec(self.genExpReg[1].generators), 868 | decorator_list=[], returns=None)) 869 | except: 870 | pass 871 | finally: 872 | self.genExpReg = None 873 | 874 | def visit_If(self, node): 875 | # Add the If statement at the end of the current block. 876 | self.add_stmt(self.curr_block, node) 877 | 878 | # Create a block for the code after the if-else. 879 | body_block = self.new_block() 880 | self.store_True.append(body_block.bid+1) 881 | if_block = self.add_edge(self.curr_block.bid, body_block.bid, node.test) 882 | node_list = [] 883 | # New block for the body of the else if there is an else clause. 884 | store = self.curr_block.bid 885 | node_list = self.populate_If_body(node.body, node_list) 886 | if node.orelse: 887 | self.curr_block = self.add_edge(store, self.new_block().bid, self.invert(node.test)) 888 | 889 | # Visit the children in the body of the else to populate the block. 890 | node_list = self.populate_If_body(node.orelse, node_list) 891 | else: 892 | node_list.append(store) 893 | # Visit children to populate the if block. 894 | self.curr_block = if_block 895 | afterif_block = self.new_block() 896 | for bid in node_list: 897 | self.add_edge(bid, afterif_block.bid) 898 | # Continue building the CFG in the after-if block. 
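# store_True (appended in visit_If above) pre-registers the bid of the block
# that will hold the first statement of the true branch; new_block() then sets
# condition=True on it, which _traverse renders with a 'T ' prefix.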
899 | self.curr_block = afterif_block 900 | 901 | def visit_IfExp_Rec(self, node: Type[ast.AST]) -> List[Type[ast.AST]]: 902 | return [ast.If(test=node.test, body=[ast.Return(value=node.body)], orelse=self.visit_IfExp_Rec(node.orelse) if type(node.orelse) == ast.IfExp else [ast.Return(value=node.orelse)])] 903 | 904 | def visit_IfExp(self, node): 905 | if self.ifExp: 906 | self.generic_visit(ast.Module(self.visit_IfExp_Rec(node))) 907 | 908 | def visit_Lambda(self, node): # deprecated since there is autopep8 909 | self.add_subgraph(ast.FunctionDef(name=self.lambdaReg[0], args=node.args, body = [ast.Return(value=node.body)], decorator_list=[], returns=None)) 910 | self.lambdaReg = None 911 | 912 | def visit_ListComp_Rec(self, generators: List[Type[ast.AST]]) -> List[Type[ast.AST]]: 913 | if not generators: 914 | self.generic_visit(self.listCompReg[1].elt) # the location of the node may be wrong 915 | if self.listCompReg[0]: # bug if there is else statement in comprehension 916 | return [ast.Expr(value=ast.Call(func=ast.Attribute(value=ast.Name(id=self.listCompReg[0], ctx=ast.Load()), attr='append', ctx=ast.Load()), args=[self.listCompReg[1].elt], keywords=[]))] 917 | else: 918 | return [ast.Expr(value=self.listCompReg[1].elt)] 919 | else: 920 | return [ast.For(target=generators[-1].target, iter=generators[-1].iter, body=[ast.If(test=self.combine_conditions(generators[-1].ifs), body=self.visit_ListComp_Rec(generators[:-1]), orelse=[])] if generators[-1].ifs else self.visit_ListComp_Rec(generators[:-1]), orelse=[])] 921 | 922 | def visit_ListComp(self, node): 923 | try: # try may change to checking if self.listCompReg exists 924 | self.generic_visit(ast.Module(self.visit_ListComp_Rec(self.listCompReg[1].generators))) 925 | except: 926 | pass 927 | finally: 928 | self.listCompReg = None 929 | 930 | def visit_Pass(self, node): 931 | self.add_stmt(self.curr_block, node) 932 | 933 | def visit_Raise(self, node): 934 | self.add_stmt(self.curr_block, node) 935 | self.curr_block = self.new_block() 936 | 937 | def visit_Return(self, node): 938 | if type(node.value) == ast.IfExp: 939 | self.ifExp = True 940 | self.generic_visit(node) 941 | self.ifExp = False 942 | else: 943 | self.add_stmt(self.curr_block, node) 944 | # self.cfg.finalblocks.append(self.curr_block) 945 | # Continue in a new block but without any jump to it -> all code after 946 | # the return statement will not be included in the CFG. 
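# The fresh block created below has no incoming edge, so anything after the
# return is left disconnected and is never reached by _traverse.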
947 | self.curr_block = self.new_block() 948 | 949 | def visit_SetComp_Rec(self, generators: List[Type[ast.AST]]) -> List[Type[ast.AST]]: 950 | if not generators: 951 | self.generic_visit(self.setCompReg[1].elt) # the location of the node may be wrong 952 | if self.setCompReg[0]: 953 | return [ast.Expr(value=ast.Call(func=ast.Attribute(value=ast.Name(id=self.setCompReg[0], ctx=ast.Load()), attr='add', ctx=ast.Load()), args=[self.setCompReg[1].elt], keywords=[]))] 954 | else: # not supported yet 955 | return [ast.Expr(value=self.setCompReg[1].elt)] 956 | else: 957 | return [ast.For(target=generators[-1].target, iter=generators[-1].iter, body=[ast.If(test=self.combine_conditions(generators[-1].ifs), body=self.visit_SetComp_Rec(generators[:-1]), orelse=[])] if generators[-1].ifs else self.visit_SetComp_Rec(generators[:-1]), orelse=[])] 958 | 959 | def visit_SetComp(self, node): 960 | try: # try may change to checking if self.setCompReg exists 961 | self.generic_visit(ast.Module(self.visit_SetComp_Rec(self.setCompReg[1].generators))) 962 | except: 963 | pass 964 | finally: 965 | self.setCompReg = None 966 | 967 | def visit_Try(self, node): 968 | loop_guard = self.add_loop_block() 969 | self.curr_block = loop_guard 970 | self.add_stmt(loop_guard, ast.Try(body=[], handlers=[], orelse=[], finalbody=[])) 971 | 972 | after_try_block = self.new_block() 973 | self.add_stmt(after_try_block, ast.Name(id='handle errors', ctx=ast.Load())) 974 | self.populate_body(node.body, after_try_block.bid, "Try") 975 | 976 | self.curr_block = after_try_block 977 | 978 | if node.handlers: 979 | for handler in node.handlers: 980 | before_handler_block = self.new_block() 981 | self.curr_block = before_handler_block 982 | self.add_edge(after_try_block.bid, before_handler_block.bid, handler.type if handler.type else ast.Name(id='Error', ctx=ast.Load())) 983 | 984 | after_handler_block = self.new_block() 985 | self.add_stmt(after_handler_block, ast.Name(id='end except', ctx=ast.Load())) 986 | self.populate_body(handler.body, after_handler_block.bid, "Try") 987 | self.add_edge(after_handler_block.bid, after_try_block.bid) 988 | 989 | if node.orelse: 990 | before_else_block = self.new_block() 991 | self.curr_block = before_else_block 992 | self.add_edge(after_try_block.bid, before_else_block.bid, ast.Name(id='No Error', ctx=ast.Load())) 993 | 994 | after_else_block = self.new_block() 995 | self.add_stmt(after_else_block, ast.Name(id='end no error', ctx=ast.Load())) 996 | self.populate_body(node.orelse, after_else_block.bid, "Try") 997 | self.add_edge(after_else_block.bid, after_try_block.bid) 998 | 999 | finally_block = self.new_block() 1000 | self.curr_block = finally_block 1001 | 1002 | if node.finalbody: 1003 | self.add_edge(after_try_block.bid, finally_block.bid, ast.Name(id='Finally', ctx=ast.Load())) 1004 | after_finally_block = self.new_block() 1005 | self.populate_body(node.finalbody, after_finally_block.bid, "Try") 1006 | self.curr_block = after_finally_block 1007 | else: 1008 | self.add_edge(after_try_block.bid, finally_block.bid) 1009 | 1010 | def visit_While(self, node): 1011 | loop_guard = self.add_loop_block() 1012 | self.continue_stack.append(loop_guard) 1013 | self.curr_block = loop_guard 1014 | self.add_stmt(loop_guard, node) 1015 | # New block for the case where the test in the while is False. 
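# The exit edge to the after-while block carries the inverted loop test; for
# `while True:` loops the NameConstant check below skips the exit edge entirely.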
1016 | afterwhile_block = self.new_block() 1017 | self.loop_stack.append(afterwhile_block) 1018 | inverted_test = self.invert(node.test) 1019 | 1020 | if not node.orelse: 1021 | # Skip shortcut loop edge if while True: 1022 | if not (isinstance(inverted_test, ast.NameConstant) and inverted_test.value == False): 1023 | self.add_edge(self.curr_block.bid, afterwhile_block.bid, inverted_test) 1024 | 1025 | # New block for the case where the test in the while is True. 1026 | # Populate the while block. 1027 | body_block = self.new_block() 1028 | body_block.condition = True 1029 | self.curr_block = self.add_edge(self.curr_block.bid, body_block.bid, node.test) 1030 | self.populate_body(node.body, loop_guard.bid, "While") 1031 | else: 1032 | orelse_block = self.new_block() 1033 | if not (isinstance(inverted_test, ast.NameConstant) and inverted_test.value == False): 1034 | self.add_edge(self.curr_block.bid, orelse_block.bid, inverted_test) 1035 | self.curr_block = self.add_edge(self.curr_block.bid, self.new_block().bid, node.test) 1036 | 1037 | self.populate_body(node.body, loop_guard.bid, "While") 1038 | self.curr_block = orelse_block 1039 | for child in node.orelse: 1040 | self.visit(child) 1041 | self.add_edge(orelse_block.bid, afterwhile_block.bid) 1042 | 1043 | # Continue building the CFG in the after-while block. 1044 | self.curr_block = afterwhile_block 1045 | self.loop_stack.pop() 1046 | self.continue_stack.pop() 1047 | 1048 | def visit_Yield(self, node): 1049 | self.curr_block = self.add_edge(self.curr_block.bid, self.new_block().bid) 1050 | 1051 | 1052 | class PyParser: 1053 | 1054 | def __init__(self, script): 1055 | self.script = script 1056 | 1057 | def formatCode(self): 1058 | self.script = autopep8.fix_code(self.script) 1059 | 1060 | # https://github.com/liftoff/pyminifier/blob/master/pyminifier/minification.py 1061 | def removeCommentsAndDocstrings(self): 1062 | io_obj = io.StringIO(self.script) # ByteIO for Python2? 1063 | out = "" 1064 | prev_toktype = tokenize.INDENT 1065 | last_lineno = -1 1066 | last_col = 0 1067 | for tok in tokenize.generate_tokens(io_obj.readline): 1068 | token_type = tok[0] 1069 | token_string = tok[1] 1070 | start_line, start_col = tok[2] 1071 | end_line, end_col = tok[3] 1072 | if start_line > last_lineno: 1073 | last_col = 0 1074 | if start_col > last_col: 1075 | out += (" " * (start_col - last_col)) 1076 | # Remove comments: 1077 | if token_type == tokenize.COMMENT: 1078 | pass 1079 | # This series of conditionals removes docstrings: 1080 | elif token_type == tokenize.STRING: 1081 | if prev_toktype != tokenize.INDENT: 1082 | # This is likely a docstring; double-check we're not inside an operator: 1083 | if prev_toktype != tokenize.NEWLINE: 1084 | # Note regarding NEWLINE vs NL: The tokenize module 1085 | # differentiates between newlines that start a new statement 1086 | # and newlines inside of operators such as parens, brackets, 1087 | # and curly braces. Newlines that end a statement are 1088 | # NEWLINE and newlines inside of operators are NL. 1089 | # Catch whole-module docstrings: 1090 | if start_col > 0: 1091 | # Unlabelled indentation means we're inside an operator 1092 | out += token_string 1093 | # Note regarding the INDENT token: The tokenize module does 1094 | # not label indentation inside of an operator (parens, 1095 | # brackets, and curly braces) as actual indentation.
1096 | # For example: 1097 | # def foo(): 1098 | # "The spaces before this docstring are tokenize.INDENT" 1099 | # test = [ 1100 | # "The spaces before this string do not get a token" 1101 | # ] 1102 | else: 1103 | out += token_string 1104 | prev_toktype = token_type 1105 | last_col = end_col 1106 | last_lineno = end_line 1107 | self.script = out 1108 | 1109 | if __name__ == '__main__': 1110 | filename = sys.argv[1] 1111 | file = filename.split('/')[-1].split('.')[0] 1112 | filepath = f'./CFG/{file}_CFG' 1113 | try: 1114 | source = open(filename, 'r').read() 1115 | compile(source, filename, 'exec') 1116 | except: 1117 | print('Error in source code') 1118 | exit(1) 1119 | 1120 | parser = PyParser(source) 1121 | parser.removeCommentsAndDocstrings() 1122 | parser.formatCode() 1123 | cfg = CFGVisitor().build(filename, ast.parse(parser.script)) 1124 | cfg.clean() 1125 | cfg.track_execution() 1126 | cfg.show(filepath) 1127 | 1128 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse(): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument( 6 | "-f", "--fff", help="a dummy argument to fool ipython", default="1" 7 | ) 8 | parser.add_argument( 9 | "--batch_size", type=int, default=256, help="Input batch size for training" 10 | ) 11 | parser.add_argument( 12 | "--hidden_dim", type=int, default=128, help="Dimension of hidden states" 13 | ) 14 | parser.add_argument( 15 | "--vocab_size", type=int, default=100000, help="Vocab size for training" 16 | ) 17 | parser.add_argument( 18 | "--max_node", type=int, default=100, help="Maximum number of nodes" 19 | ) 20 | parser.add_argument( 21 | "--max_token", type=int, default=512, help="Maximum number of tokens" 22 | ) 23 | parser.add_argument( 24 | "--learning_rate", type=float, default=0.001, help="Learning rate" 25 | ) 26 | parser.add_argument("--num_epoch", type=int, default=600, help="Epochs for training") 27 | parser.add_argument( 28 | "--cpu", action="store_true", default=False, help="Disables CUDA training" 29 | ) 30 | parser.add_argument("--gpu", type=int, default=0, help="GPU ID for CUDA training") 31 | parser.add_argument( 32 | "--save-model", 33 | action="store_true", 34 | default=False, 35 | help="For saving the current model", 36 | ) 37 | parser.add_argument("--extra_aggregate", action="store_true", default=False) 38 | parser.add_argument("--delete_redundant_node", action="store_true", default=False) 39 | parser.add_argument("--output", type=str, default="output") 40 | parser.add_argument("--checkpoint", type=int, default=None) 41 | parser.add_argument("--dryrun", action="store_true", default=False) 42 | # list of epochs at which to visualize 43 | parser.add_argument( 44 | "--vis_epoch", 45 | nargs="+", 46 | type=int, 47 | default=[], 48 | help="Epochs for visualization", 49 | ) 50 | # list of code files to visualize 51 | parser.add_argument( 52 | "--vis_code", 53 | nargs="+", 54 | type=int, 55 | default=[], 56 | help="Code to visualize", 57 | ) 58 | parser.add_argument("--groundtruth", action="store_true", default=False) 59 | parser.add_argument("--name_exp", type=str, default=None) 60 | parser.add_argument("--cuda_num", type=int, default=None) 61 | parser.add_argument("--seed", type=int, default=300103) 62 | parser.add_argument("--alpha", type=float, default=0.5) 63 | parser.add_argument("--epoch", type=int, default=None) 64 | parser.add_argument("--time", type=int,
default=60) 65 | parser.add_argument("--data", type=str, default="CodeNet") 66 | parser.add_argument("--runtime_detection", action="store_true", default=False) 67 | parser.add_argument("--bug_localization", action="store_true", default=False) 68 | parser.add_argument("--claude_api_key", type=str, default=None) 69 | parser.add_argument("--model", type=str, default='claude-3-5-sonnet-20240620') 70 | parser.add_argument("--folder_path", type=str, default='fuzz_testing_dataset') 71 | args = parser.parse_args() 72 | return args -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from torchtext import data 2 | from torchtext.data import Iterator 3 | import pandas as pd 4 | import config 5 | import torch 6 | 7 | def read_data(data_path, fields): 8 | csv_data = pd.read_csv(data_path, chunksize=100) 9 | all_examples = [] 10 | for n, chunk in enumerate(csv_data): 11 | 12 | examples = chunk.apply(lambda r: data.Example.fromlist([eval(r['nodes']), eval(r['forward']), eval(r['backward']), 13 | eval(r['target'])], fields), axis=1) 14 | all_examples.extend(list(examples)) 15 | return all_examples 16 | 17 | def get_iterators(args, device): 18 | TEXT = data.Field(tokenize=lambda x: x.split()[:args.max_token]) 19 | NODE = data.NestedField(TEXT, preprocessing=lambda x: x[:args.max_node], include_lengths=True) 20 | ROW = data.Field(pad_token=1.0, use_vocab=False, 21 | preprocessing=lambda x: [1, 1] if any(i > args.max_node for i in x) else x) 22 | EDGE = data.NestedField(ROW) 23 | TARGET = data.Field(use_vocab=False, preprocessing=lambda x: x[:args.max_node], pad_token=0, batch_first=True) 24 | 25 | fields = [("nodes", NODE), ("forward", EDGE), ("backward", EDGE), ("target", TARGET)] 26 | 27 | print('Read data...') 28 | examples = read_data(f'data/{args.data}_train.csv', fields) 29 | train = data.Dataset(examples, fields) 30 | NODE.build_vocab(train, max_size=args.vocab_size) 31 | 32 | examples = read_data(f'data/{args.data}_test.csv', fields) 33 | test = data.Dataset(examples, fields) 34 | train_iter = Iterator(train, 35 | batch_size=args.batch_size, 36 | device=device, 37 | sort=False, 38 | sort_key=lambda x: len(x.nodes), 39 | sort_within_batch=False, 40 | repeat=False) 41 | test_iter = Iterator(test, batch_size=args.batch_size, device=device, train=False, 42 | sort=False, sort_key=lambda x: len(x.nodes), sort_within_batch=False, repeat=False, shuffle=False) 43 | print("Done") 44 | return train_iter, test_iter 45 | -------------------------------------------------------------------------------- /fuzz_testing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import anthropic 3 | import ast, astor 4 | from cfg import * 5 | import re 6 | import sys 7 | import trace_execution 8 | import os 9 | import io 10 | import pandas as pd 11 | from torchtext import data 12 | from torchtext.data import Iterator 13 | import pandas as pd 14 | import torch 15 | import model 16 | import time 17 | import config 18 | import argparse 19 | 20 | def generate_prompt(method_code, feedback=""): 21 | prompt = f""" 22 | \n\nHuman: You are a terminal. Analyze the following Python code and generate likely inputs for all variables that might raise errors. Add these generated inputs at the beginning of the code snippet. 
23 | 24 | Example: 25 | Python Method: 26 | if(S[0]=="A" and S[2,-1].count("C")==1): 27 | cnt=0 28 | for i in S: 29 | if(97<=ord(i) and ord(i)<=122): 30 | cnt+=1 31 | if(cnt==2): 32 | print("AC") 33 | else : 34 | print("WA") 35 | else : 36 | print("WA") 37 | 38 | Generated Input: 39 | S = 'AtCoder' 40 | 41 | Task: 42 | Given the following Python method, generate likely inputs for variables: 43 | {feedback} 44 | 45 | Python Method: 46 | {method_code} 47 | 48 | Generated Input: 49 | (No explanation needed, only one Generated Input:) 50 | \n\nAssistant: 51 | """ 52 | return prompt 53 | 54 | def get_generated_inputs(claude_api_key, model, method_code, feedback=""): 55 | client = anthropic.Anthropic(api_key=claude_api_key) 56 | prompt = generate_prompt(method_code, feedback) 57 | response = client.messages.create( 58 | model=model, 59 | max_tokens=1024, 60 | messages=[ 61 | {"role": "user", "content": prompt} 62 | ] 63 | ) 64 | return response.content[0].text 65 | 66 | def add_generated_inputs_to_code(code, inputs): 67 | lines = code.split('\n') 68 | # Find the first non-import line 69 | insert_index = 0 70 | for i, line in enumerate(lines): 71 | if not line.startswith(('import', 'from', '\n')): 72 | insert_index = i 73 | break 74 | 75 | # Insert generated inputs at the found index 76 | for input_line in inputs.split('\n'): 77 | if input_line.startswith("Generated"): 78 | continue 79 | if input_line.strip(): 80 | lines.insert(insert_index, input_line) 81 | insert_index += 1 82 | 83 | return '\n'.join(lines) 84 | 85 | def read_data(data_path, fields): 86 | csv_data = pd.read_csv(data_path, chunksize=100) 87 | all_examples = [] 88 | for n, chunk in enumerate(csv_data): 89 | examples = chunk.apply(lambda r: data.Example.fromlist([eval(r['nodes']), eval(r['forward']), eval(r['backward']), 90 | eval(r['target'])], fields), axis=1) 91 | all_examples.extend(list(examples)) 92 | return all_examples 93 | 94 | opt = config.parse() 95 | if opt.claude_api_key == None: 96 | raise Exception("Lack of CLAUDE api") 97 | if opt.cuda_num == None: 98 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 99 | else: 100 | device = torch.device(f"cuda:{opt.cuda_num}" if torch.cuda.is_available() else "cpu") 101 | 102 | TEXT = data.Field(tokenize=lambda x: x.split()[:512]) 103 | NODE = data.NestedField(TEXT, preprocessing=lambda x: x[:100], include_lengths=True) 104 | ROW = data.Field(pad_token=1.0, use_vocab=False, 105 | preprocessing=lambda x: [1, 1] if any(i > 100 for i in x) else x) 106 | EDGE = data.NestedField(ROW) 107 | TARGET = data.Field(use_vocab=False, preprocessing=lambda x: x[:100], pad_token=0, batch_first=True) 108 | 109 | fields = [("nodes", NODE), ("forward", EDGE), ("backward", EDGE), ("target", TARGET)] 110 | 111 | print('Read data...') 112 | examples = read_data(f'data/FixEval_complete_train.csv', fields) 113 | train = data.Dataset(examples, fields) 114 | NODE.build_vocab(train, max_size=100000) 115 | 116 | orin_nodes = ['BEGIN', "_in = ['2', 3]", 'cont_str = _in[0] * _in[1]', 'cont_num = int(cont_str)', 'sqrt_flag = False', 'p1 = 0', 'p1 < len(range(4, 100))', 'T i = range(4, 100)[p1]', 'sqrt_flag', 'sqrt = i * i', 'cont_num == sqrt', 'T sqrt_flag = True', 'p1 += 1', "T print('Yes')", "print('No')", 'EXIT'] 117 | orin_fwd_edges = [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (7, 9), (8, 10), (10, 11), (11, 12), (12, 9), (9, 14), (9, 15), (11, 13), (14, 16), (15, 16)] 118 | orin_back_edges = [(13, 7)] 119 | orin_exe_path = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 
1] 120 | 121 | net = model.CodeFlow(opt).to(device) 122 | checkpoint_path = f"checkpoints/checkpoints_{opt.checkpoint}/epoch-{opt.epoch}.pt" 123 | net.load_state_dict(torch.load(checkpoint_path, map_location=device)) 124 | net.eval() 125 | 126 | outpath = 'fuzz_testing_output' 127 | if not os.path.exists(outpath): 128 | os.makedirs(outpath) 129 | 130 | def extract_inputs(generated_text): 131 | # Use regular expression to match lines that are variable assignments 132 | input_lines = re.findall(r'^\s*\w+\s*=\s*.+$', generated_text, re.MULTILINE) 133 | return '\n'.join(input_lines) 134 | 135 | error_dict = {} 136 | locate = 0 137 | for root, _, files in os.walk(opt.folder_path): 138 | files = sorted(files, key=lambda x: int(x.split('.')[0][5:])) 139 | for file in files: 140 | print(f'Fuzz testing file {file}') 141 | feedback_list = [] 142 | start_time = time.time() 143 | time_limit = opt.time # time limit in seconds 144 | repeat = True 145 | while repeat: 146 | if time.time() - start_time > time_limit: 147 | print(f'Time limit exceeded for file {file}') 148 | break 149 | feedback = f"\nThese inputs did not raise runtime errors, avoid to generate the same:\n{feedback_list}" if feedback_list else "" 150 | file_path = os.path.join(opt.folder_path, file) 151 | with open(file_path, 'r') as f: 152 | code = f.read() 153 | generated_inputs = get_generated_inputs(opt.claude_api_key, opt.model, code, feedback) 154 | generated_inputs = extract_inputs(generated_inputs) 155 | print(generated_inputs) 156 | # Add generated inputs to the original code 157 | modified_code = add_generated_inputs_to_code(code, generated_inputs) 158 | filename = os.path.join(outpath, file) 159 | with open(filename, 'w') as modified_file: 160 | modified_file.write(modified_code) 161 | 162 | BlockId().counter = 0 163 | try: 164 | source = open(filename, 'r').read() 165 | compile(source, filename, 'exec') 166 | except: 167 | print('Error in source code') 168 | exit(1) 169 | parser = PyParser(source) 170 | parser.removeCommentsAndDocstrings() 171 | parser.formatCode() 172 | try: 173 | cfg = CFGVisitor().build(filename, ast.parse(parser.script)) 174 | except AttributeError: 175 | continue 176 | except IndentationError: 177 | continue 178 | except TypeError: 179 | continue 180 | except SyntaxError: 181 | continue 182 | 183 | cfg.clean() 184 | try: 185 | cfg.track_execution() 186 | except Exception: 187 | print("Generated input is not valid") 188 | continue 189 | code = {} 190 | for_loop = {} 191 | for i in cfg.blocks: 192 | if cfg.blocks[i].for_loop != 0: 193 | if cfg.blocks[i].for_loop not in for_loop: 194 | for_loop[cfg.blocks[i].for_loop] = [i] 195 | else: 196 | for_loop[cfg.blocks[i].for_loop].append(i) 197 | first = [] 198 | second = [] 199 | for i in for_loop: 200 | first.append(for_loop[i][0]+1) 201 | second.append(for_loop[i][1]) 202 | orin_node = [] 203 | track = {} 204 | track_for = {} 205 | for i in cfg.blocks: 206 | if cfg.blocks[i].stmts_to_code(): 207 | if int(i) == 1: 208 | st = 'BEGIN' 209 | elif int(i) == len(cfg.blocks): 210 | st = 'EXIT' 211 | else: 212 | if i in first: 213 | line = astor.to_source(cfg.blocks[i].for_name) 214 | st = line.split('\n')[0] 215 | st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "") 216 | else: 217 | st = cfg.blocks[i].stmts_to_code() 218 | st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "") 219 | orin_node.append([i, st, None]) 220 | if st not in track: 221 | track[st] = [len(orin_node)-1] 222 | else: 223 | 
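# This statement text has been seen before: queue this node index too, so duplicate source lines can be matched to their CFG nodes in first-come order by the line-matching pass below ('track_for' keeps the latest orin_node index per block for the for-loop fix-up further down).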
track[st].append(len(orin_node)-1) 224 | track_for[i] = len(orin_node)-1 225 | with open(filename, 'r') as file_open: 226 | lines = file_open.readlines() 227 | for i in range(1, len(lines)+1): 228 | line = lines[i-1] 229 | #delete \n at the end of each line and delete all spaces 230 | line = line.strip() 231 | line = re.sub(r"\s+", "", line).replace('"', "'").replace("(", "").replace(")", "") 232 | if line.startswith('elif'): 233 | line = line[2:] 234 | if line in track: 235 | orin_node[track[line][0]][2] = i 236 | if orin_node[track[line][0]][0] in first: 237 | orin_node[track[line][0]-1][2] = i-0.4 238 | orin_node[track[line][0]+1][2] = i+0.4 239 | if len(track[line]) > 1: 240 | track[line].pop(0) 241 | for i in second: 242 | max_val = 0 243 | for edge in cfg.edges: 244 | if edge[0] == i: 245 | if orin_node[track_for[edge[1]]][2] > max_val: 246 | max_val = orin_node[track_for[edge[1]]][2] 247 | if edge[1] == i: 248 | if orin_node[track_for[edge[0]]][2] > max_val: 249 | max_val = orin_node[track_for[edge[0]]][2] 250 | orin_node[track_for[i]][2] = max_val + 0.5 251 | orin_node[0][2] = 0 252 | orin_node[-1][2] = len(lines)+1 253 | # sort orin_node by the third element 254 | orin_node.sort(key=lambda x: x[2]) 255 | 256 | nodes = [] 257 | matching = {} 258 | for i in cfg.blocks: 259 | if cfg.blocks[i].stmts_to_code(): 260 | if int(i) == 1: 261 | nodes.append('BEGIN') 262 | elif int(i) == len(cfg.blocks): 263 | nodes.append('EXIT') 264 | else: 265 | st = cfg.blocks[i].stmts_to_code() 266 | st_no_space = re.sub(r"\s+", "", st) 267 | # if start with if or while, delete these keywords 268 | if st.startswith('if'): 269 | st = st[3:] 270 | elif st.startswith('while'): 271 | st = st[6:] 272 | if cfg.blocks[i].condition: 273 | st = 'T '+ st 274 | if st.endswith('\n'): 275 | st = st[:-1] 276 | if st.endswith(":"): 277 | st = st[:-1] 278 | nodes.append(st) 279 | matching[i] = len(nodes) 280 | 281 | fwd_edges = [] 282 | back_edges = [] 283 | edges = {} 284 | for edge in cfg.edges: 285 | if edge not in cfg.back_edges: 286 | fwd_edges.append((matching[edge[0]], matching[edge[1]])) 287 | else: 288 | back_edges.append((matching[edge[0]], matching[edge[1]])) 289 | if matching[edge[0]] not in edges: 290 | edges[matching[edge[0]]] = [matching[edge[1]]] 291 | else: 292 | edges[matching[edge[0]]].append(matching[edge[1]]) 293 | exe_path = [0 for i in range(len(nodes))] 294 | for i in range(len(cfg.path)): 295 | if cfg.path[i] == 1: 296 | exe_path[matching[i+1]-1] = 1 297 | out_nodes=[nodes, orin_nodes] 298 | out_fw_path=[fwd_edges, orin_fwd_edges] 299 | out_back_path=[back_edges, orin_back_edges] 300 | out_exe_path=[exe_path, orin_exe_path] 301 | data_example = { 302 | 'nodes': out_nodes, 303 | 'forward': out_fw_path, 304 | 'backward': out_back_path, 305 | 'target': out_exe_path, 306 | } 307 | 308 | df = pd.DataFrame(data_example) 309 | # Save to CSV 310 | df.to_csv(f'{outpath}/output.csv', index=False, quoting=1) 311 | examples = read_data(f'{outpath}/output.csv', fields) 312 | test = data.Dataset(examples, fields) 313 | test_iter = Iterator(test, batch_size=2, device=device, train=False, 314 | sort=False, sort_key=lambda x: len(x.nodes), sort_within_batch=False, repeat=False, shuffle=False) 315 | with torch.no_grad(): 316 | for batch in test_iter: 317 | x, edges, target = batch.nodes, (batch.forward, batch.backward), batch.target.float() 318 | if isinstance(x, tuple): 319 | pred = net(x[0], edges, x[1], x[2]) 320 | else: 321 | pred = net(x, edges) 322 | pred = pred[0].squeeze() 323 | pred = (pred > 
opt.alpha).float() 324 | if pred[len(nodes)-1] == 1: 325 | print("No Runtime Error") 326 | feedback_list.append(generated_inputs) 327 | else: 328 | mask_pred = pred[:len(nodes)] == 1 329 | indices_pred = torch.nonzero(mask_pred).flatten() 330 | farthest_pred = indices_pred.max().item() 331 | error_line = nodes[farthest_pred] 332 | print(f"Runtime Error in line: {error_line}") 333 | 334 | mask_target = target[0][:len(nodes)] == 1 335 | indices_target = torch.nonzero(mask_target).flatten() 336 | farthest_target = indices_target.max().item() 337 | true_error_line = nodes[farthest_target] 338 | error_dict[file] = [error_line, true_error_line] 339 | 340 | if farthest_pred == farthest_target: 341 | locate += 1 342 | repeat = False 343 | 344 | locate_true = locate/len(error_dict)*100 if error_dict else 0.0 345 | print(f'Fuzz testing within {opt.time}s') 346 | print(f'Successfully detected: {len(error_dict)}/{len(files)}') 347 | print(f'Bug Localization Acc: {locate_true:.2f}%') 348 | print(error_dict) 349 | -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_1.py: -------------------------------------------------------------------------------- 1 | ans=0 2 | cur=0 3 | ACGT=set("A","C","G","T") 4 | for ss in s: 5 | if ss in ACGT: 6 | cur+=1 7 | else: 8 | ans=max(cur,ans) 9 | cur=0 10 | print(max(ans,cur)) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_10.py: -------------------------------------------------------------------------------- 1 | S = 'nikoandsolstice' 2 | s = len(S) 3 | if (s <= K): 4 | print(S) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_11.py: -------------------------------------------------------------------------------- 1 | s = 'CSS' 2 | s=0 3 | for i in range(len(s)): 4 | if s[i]==t[i]: 5 | s+=1 6 | print(s) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_12.py: -------------------------------------------------------------------------------- 1 | col = N / 2 2 | if col == 0: 3 | print(col) 4 | else: 5 | col += 1 6 | print(col) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_13.py: -------------------------------------------------------------------------------- 1 | x=a+a*a+a*a*a 2 | print(a) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_14.py: -------------------------------------------------------------------------------- 1 | cont_str = _in[0] * _in[1] 2 | cont_num = int(cont_str) 3 | sqrt_flag = False 4 | for i in range(4, 100): 5 | sqrt = i * i 6 | if cont_num == sqrt: 7 | sqrt_flag = True 8 | break 9 | if sqrt_flag: 10 | print('Yes') 11 | else: 12 | print('No') -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_15.py: -------------------------------------------------------------------------------- 1 | if s >= 3200: 2 | print(s) 3 | else: 4 | print("red") -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_16.py: -------------------------------------------------------------------------------- 1 | if (s[:4] == 2019 and s[5] == 0 and s[6] <= 4) or (s[:4] < 2019): 2 | print("Heisei") 3 | else: 4 | print("TBD") -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_17.py: 
-------------------------------------------------------------------------------- 1 | b = a / 3 2 | print(b) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_18.py: -------------------------------------------------------------------------------- 1 | import math 2 | m=N+1 3 | for j in range(2,math.sqrt(N)): 4 | if N%j==0: 5 | a=j 6 | b=N//j 7 | if a>b: 8 | break 9 | else: 10 | m=min(m,(a+b)) 11 | print(m-2) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_19.py: -------------------------------------------------------------------------------- 1 | A = 2 2 | B = ['i', 'p', ' ', 'c', 'c'] 3 | C=[] 4 | for i in range(A): 5 | C=A[i] 6 | C=B[i] 7 | print(C) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_2.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | for i in range(sqrt(n),0,-1): 3 | if n%i==0: 4 | j=n//i 5 | print(i+j-2) 6 | break -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_20.py: -------------------------------------------------------------------------------- 1 | a = [1, 1, 2, 2, 2] 2 | t=a[0] 3 | s=0 4 | Ans=0 5 | for i in range(1,n): 6 | if t==a[i]: 7 | s+=1 8 | else: 9 | Ans+=s//2 10 | s=0 11 | t=a[i] 12 | print(Ans) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_21.py: -------------------------------------------------------------------------------- 1 | if N % 1000 == 0: 2 | print(0) 3 | else: 4 | sen = N//1000 5 | sen2 = sen*1000 6 | atai = N-sen2 7 | print(atai) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_22.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | d = defaultdict(str) 3 | no_flag = False 4 | for i in range(len(S)): 5 | if S[i] != T[i]: 6 | if d[S[i]] == T[i]: 7 | S[i] = T[i] 8 | pass 9 | elif d[S[i]] == "" and d[T[i]] == "": 10 | d[S[i]] = T[i] 11 | d[T[i]] = S[i] 12 | else: 13 | pass 14 | S = sorted(S) 15 | T = sorted(T) 16 | if S == T: 17 | print("Yes") 18 | else: 19 | print("No") -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_23.py: -------------------------------------------------------------------------------- 1 | if a == b & b == c: 2 | ans = 0 3 | else: 4 | for count in range(n): 5 | tmp = 0 6 | if a[count] == b[count]: 7 | if a[count] != c[count]: 8 | tmp += 1 9 | else: 10 | tmp += 1 11 | if (a[count] != c[count]) & (b[count] != c[count]): 12 | tmp += 1 13 | ans += tmp 14 | print(ans) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_24.py: -------------------------------------------------------------------------------- 1 | err = [] 2 | for i in range(len(s) - 3): 3 | err.append(abs(int(s[i:i+3]), 753)) 4 | print(min(err)) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_25.py: -------------------------------------------------------------------------------- 1 | if(s_input[1] == '+'): 2 | print(int(s_input[0]+int(s_input[2]))) 3 | else: 4 | print(int(s_input[0]-int(s_input[2]))) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_26.py: 
-------------------------------------------------------------------------------- 1 | for i in range(0,len(b)): 2 | s+=int(b[i]) 3 | if s-max(b)>max(b): 4 | print("Yes") 5 | else: 6 | print("No") -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_27.py: -------------------------------------------------------------------------------- 1 | S[3]='7' 2 | print(S) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_28.py: -------------------------------------------------------------------------------- 1 | for i in range(len(s)): 2 | if s[i]=='1': 3 | s[i]='9' 4 | else: 5 | s[i]='1' 6 | print(s) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_29.py: -------------------------------------------------------------------------------- 1 | if N%2==0: 2 | print(N/2) 3 | else: 4 | print((N+1)/2) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_3.py: -------------------------------------------------------------------------------- 1 | s = 'abcabc' 2 | print(s[:n/2]==s[n/2:] and n%2==0) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_30.py: -------------------------------------------------------------------------------- 1 | size=len(S) 2 | print(S[0]+int(int(size)-2)+S[-1]) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_31.py: -------------------------------------------------------------------------------- 1 | if N % 2 == 1: 2 | print('No') 3 | exit() 4 | if S[:N/2] == S[N/2:]: 5 | print('Yes') 6 | else: 7 | print('No') -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_32.py: -------------------------------------------------------------------------------- 1 | if a in 0: 2 | total = 0 3 | else: 4 | for i in a: 5 | total *= i 6 | if total > 10**18: 7 | total = -1 8 | print(total) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_33.py: -------------------------------------------------------------------------------- 1 | if n[1] == n[2] and n[0] == [1] or n[1] == n[2] and n[2] == n[3]: 2 | print('Yes') 3 | else: 4 | print('No') -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_34.py: -------------------------------------------------------------------------------- 1 | for i in range(lst): 2 | if lst[i] == 0: 3 | ans = i+1 4 | print(ans) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_35.py: -------------------------------------------------------------------------------- 1 | print(a**3) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_36.py: -------------------------------------------------------------------------------- 1 | if a[3]!=8: 2 | a[3]=8 3 | print(a) 4 | else: 5 | print(a) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_37.py: -------------------------------------------------------------------------------- 1 | print(m % 1000) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_38.py: 
-------------------------------------------------------------------------------- 1 | res = a//2 2 | if a%2 >0: 3 | res += 1 4 | print(res) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_39.py: -------------------------------------------------------------------------------- 1 | for i in A: 2 | if A%2==0: 3 | if not A%3==0 or A%5==0: 4 | print("DENIED") 5 | x=1 6 | break 7 | else: 8 | continue 9 | if x !=0: 10 | print("APPROVED") -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_4.py: -------------------------------------------------------------------------------- 1 | G = '2017' 2 | print(2*G-R) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_40.py: -------------------------------------------------------------------------------- 1 | a = ord(s) 2 | b = ord('ABC') 3 | r = ord('ARC') 4 | if a == b: 5 | print('ARC') 6 | elif a == b: 7 | print('ABC') -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_41.py: -------------------------------------------------------------------------------- 1 | if S[0]!="1": 2 | print(S[0]) 3 | else: 4 | print(S[1]) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_42.py: -------------------------------------------------------------------------------- 1 | n[3] = '8' 2 | print(str(n)) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_43.py: -------------------------------------------------------------------------------- 1 | k = '4' 2 | flag = True 3 | for i in range(k): 4 | if s[i] == '1': 5 | print(s[i]) 6 | flag = False 7 | if flag: 8 | print(s[1]) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_44.py: -------------------------------------------------------------------------------- 1 | S[4] = '8' 2 | print(S) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_45.py: -------------------------------------------------------------------------------- 1 | S = 'abcabc' 2 | n = int(N/2) 3 | s1 = S[:n] 4 | s2 = S[n:] 5 | if s1 == s2: 6 | print('Yes') 7 | else: 8 | print('No') -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_46.py: -------------------------------------------------------------------------------- 1 | d = a.split() 2 | b = d[0]*d[1] 3 | if(b%2==0): 4 | print('Even') 5 | else: 6 | print('Odd') -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_47.py: -------------------------------------------------------------------------------- 1 | import sys 2 | input = sys.stdin.readline 3 | n = 5 4 | for i in a: 5 | print(a.count(str(i+1))) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_48.py: -------------------------------------------------------------------------------- 1 | print("Christmas",["Eve "] * 25 - N) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_49.py: -------------------------------------------------------------------------------- 1 | a = [1, 2, 3, 4, 5, 6] 2 | memo = sum(a) 3 | a=a[0] 4 | b=memo-a[0] 5 | ans = abs(a-b) 6 | for i in range(1,n-1): 7 | a += a[i] 8 | b 
-= a[i] 9 | ans = min(ans,abs(a-b)) 10 | print(ans) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_5.py: -------------------------------------------------------------------------------- 1 | a = r*r 2 | print(a) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_50.py: -------------------------------------------------------------------------------- 1 | if 7 in n: 2 | print("Yes") 3 | else: 4 | print("No") -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_6.py: -------------------------------------------------------------------------------- 1 | ans = 10**12 2 | for k in range(1,(n+1)**0.5): 3 | if n%k == 0 : 4 | m = n//k + k - 2 5 | if ans > m: 6 | ans = m 7 | else: 8 | print(ans) 9 | sys.exit() 10 | print(ans) -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_7.py: -------------------------------------------------------------------------------- 1 | s = 'abcabc' 2 | if n%2==1: 3 | print('No') 4 | else: 5 | if s[:n/2]==s[n/2:]: 6 | print('Yes') 7 | else: 8 | print('No') -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_8.py: -------------------------------------------------------------------------------- 1 | num = "".join(num) 2 | if num % 4 == 0: 3 | print("YES") 4 | else: 5 | print("NO") -------------------------------------------------------------------------------- /fuzz_testing_dataset/code_9.py: -------------------------------------------------------------------------------- 1 | s=str(n) 2 | array = list(map(int,s)) 3 | if array%9==0: 4 | print("Yes") 5 | else: 6 | print("No") -------------------------------------------------------------------------------- /generate_dataset/generate_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import ast, astor, autopep8, tokenize, io, sys 3 | import graphviz as gv 4 | from typing import Dict, List, Tuple, Set, Optional, Type 5 | from cfg import * 6 | import re 7 | import sys 8 | import trace_execution 9 | import os 10 | import io 11 | import linecache 12 | import pandas as pd 13 | from sklearn.model_selection import train_test_split 14 | import argparse 15 | 16 | output_dir = 'dataset' 17 | inference_dir = 'inference' 18 | os.makedirs(inference_dir, exist_ok=True) 19 | files = os.listdir(output_dir) 20 | files.sort() 21 | out_nodes = [] 22 | out_fw_path = [] 23 | out_back_path = [] 24 | out_exe_path = [] 25 | out_file_names = [] 26 | max_node = 0 27 | max_edge = 0 28 | 29 | for file in files: 30 | BlockId().counter = 0 31 | filename = f'./{output_dir}/' + file 32 | try: 33 | source = open(filename, 'r').read() 34 | compile(source, filename, 'exec') 35 | except: 36 | print('Error in source code') 37 | continue 38 | 39 | parser = PyParser(source) 40 | parser.removeCommentsAndDocstrings() 41 | parser.formatCode() 42 | try: 43 | cfg = CFGVisitor().build(filename, ast.parse(parser.script)) 44 | except AttributeError: 45 | continue 46 | except IndentationError: 47 | continue 48 | except TypeError: 49 | continue 50 | except SyntaxError: 51 | continue 52 | cfg.clean() 53 | try: 54 | cfg.track_execution() 55 | except Exception: 56 | continue 57 | 58 | code = {} 59 | for_loop = {} 60 | for i in cfg.blocks: 61 | if cfg.blocks[i].for_loop != 0: 62 | if cfg.blocks[i].for_loop not in 
for_loop: 63 | for_loop[cfg.blocks[i].for_loop] = [i] 64 | else: 65 | for_loop[cfg.blocks[i].for_loop].append(i) 66 | first = [] 67 | second = [] 68 | for i in for_loop: 69 | first.append(for_loop[i][0]+1) 70 | second.append(for_loop[i][1]) 71 | orin_node = [] 72 | track = {} 73 | track_for = {} 74 | for i in cfg.blocks: 75 | if cfg.blocks[i].stmts_to_code(): 76 | if int(i) == 1: 77 | st = 'BEGIN' 78 | elif int(i) == len(cfg.blocks): 79 | st = 'EXIT' 80 | else: 81 | if i in first: 82 | line = astor.to_source(cfg.blocks[i].for_name) 83 | st = line.split('\n')[0] 84 | st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "") 85 | else: 86 | st = cfg.blocks[i].stmts_to_code() 87 | st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "") 88 | orin_node.append([i, st, None]) 89 | if st not in track: 90 | track[st] = [len(orin_node)-1] 91 | else: 92 | track[st].append(len(orin_node)-1) 93 | track_for[i] = len(orin_node)-1 94 | with open(filename, 'r') as file_open: 95 | lines = file_open.readlines() 96 | for i in range(1, len(lines)+1): 97 | line = lines[i-1] 98 | #delete \n at the end of each line and delete all spaces 99 | line = line.strip() 100 | line = re.sub(r"\s+", "", line).replace('"', "'").replace("(", "").replace(")", "") 101 | if line.startswith('elif'): 102 | line = line[2:] 103 | if line in track: 104 | orin_node[track[line][0]][2] = i 105 | if orin_node[track[line][0]][0] in first: 106 | orin_node[track[line][0]-1][2] = i-0.4 107 | orin_node[track[line][0]+1][2] = i+0.4 108 | if len(track[line]) > 1: 109 | track[line].pop(0) 110 | for i in second: 111 | max_val = 0 112 | for edge in cfg.edges: 113 | if edge[0] == i: 114 | if orin_node[track_for[edge[1]]][2] > max_val: 115 | max_val = orin_node[track_for[edge[1]]][2] 116 | if edge[1] == i: 117 | if orin_node[track_for[edge[0]]][2] > max_val: 118 | max_val = orin_node[track_for[edge[0]]][2] 119 | orin_node[track_for[i]][2] = max_val + 0.5 120 | orin_node[0][2] = 0 121 | orin_node[-1][2] = len(lines)+1 122 | # sort orin_node by the third element 123 | orin_node.sort(key=lambda x: x[2]) 124 | 125 | nodes = [] 126 | matching = {} 127 | for t in orin_node: 128 | i = t[0] 129 | if cfg.blocks[i].stmts_to_code(): 130 | if int(i) == 1: 131 | nodes.append('BEGIN') 132 | elif int(i) == len(cfg.blocks): 133 | nodes.append('EXIT') 134 | else: 135 | st = cfg.blocks[i].stmts_to_code() 136 | st_no_space = re.sub(r"\s+", "", st) 137 | # if start with if or while, delete these keywords 138 | if st.startswith('if'): 139 | st = st[3:] 140 | elif st.startswith('while'): 141 | st = st[6:] 142 | if cfg.blocks[i].condition: 143 | st = 'T '+ st 144 | if st.endswith('\n'): 145 | st = st[:-1] 146 | if st.endswith(":"): 147 | st = st[:-1] 148 | nodes.append(st) 149 | matching[i] = len(nodes) 150 | 151 | fwd_edges = [] 152 | back_edges = [] 153 | edges = {} 154 | for edge in cfg.edges: 155 | if edge not in cfg.back_edges: 156 | fwd_edges.append((matching[edge[0]], matching[edge[1]])) 157 | else: 158 | back_edges.append((matching[edge[0]], matching[edge[1]])) 159 | if matching[edge[0]] not in edges: 160 | edges[matching[edge[0]]] = [matching[edge[1]]] 161 | else: 162 | edges[matching[edge[0]]].append(matching[edge[1]]) 163 | exe_path = [0 for i in range(len(nodes))] 164 | for node in cfg.path: 165 | exe_path[matching[node]-1] = 1 166 | # check nodes, fwd_edges, back_edges, exe_path not exist in the list and then append 167 | print(f'Done in {file}') 168 | if nodes in out_nodes: 169 | # check the fwd_edges, 
back_edges, exe_path in the same index 170 | index = out_nodes.index(nodes) 171 | if fwd_edges != out_fw_path[index] or back_edges != out_back_path[index] or exe_path != out_exe_path[index]: 172 | out_nodes.append(nodes) 173 | out_fw_path.append(fwd_edges) 174 | out_back_path.append(back_edges) 175 | out_exe_path.append(exe_path) 176 | out_file_names.append(filename) 177 | else: 178 | out_nodes.append(nodes) 179 | out_fw_path.append(fwd_edges) 180 | out_back_path.append(back_edges) 181 | out_exe_path.append(exe_path) 182 | out_file_names.append(filename) 183 | 184 | data = { 185 | 'nodes': out_nodes, 186 | 'forward': out_fw_path, 187 | 'backward': out_back_path, 188 | 'target': out_exe_path, 189 | 'file_name': out_file_names 190 | } 191 | 192 | df = pd.DataFrame(data) 193 | 194 | # Split into training and testing sets 195 | train_df, test_df = train_test_split(df, test_size=0.2, random_state=42) 196 | 197 | # Save train set to CSV 198 | train_df.drop(columns=['file_name'], inplace=True) 199 | train_df.to_csv('train.csv', index=False, quoting=1) 200 | print("Train CSV file has been saved successfully.") 201 | 202 | # Save test set to CSV 203 | test_file_names = test_df['file_name'].tolist() 204 | test_df.drop(columns=['file_name'], inplace=True) 205 | test_df.to_csv('test.csv', index=False, quoting=1) 206 | print("Test CSV file has been saved successfully.") 207 | 208 | # Save each test file with the corresponding name in the "inference" folder 209 | for i, filename in enumerate(test_file_names): 210 | with open(filename, 'r') as f: 211 | source_code = f.read() 212 | new_filename = os.path.join(inference_dir, f"code_{i + 2}.py") 213 | with open(new_filename, 'w') as f: 214 | f.write(source_code) 215 | print("Test files have been saved successfully.") -------------------------------------------------------------------------------- /generate_dataset/trace_execution.py: -------------------------------------------------------------------------------- 1 | """program/module to trace Python program or function execution 2 | 3 | Sample use, command line: 4 | trace.py -c -f counts --ignore-dir '$prefix' spam.py eggs 5 | trace.py -t --ignore-dir '$prefix' spam.py eggs 6 | trace.py --trackcalls spam.py eggs 7 | 8 | Sample use, programmatically 9 | import sys 10 | 11 | # create a Trace object, telling it what to ignore, and whether to 12 | # do tracing or line-counting or both. 13 | tracer = trace.Trace(ignoredirs=[sys.base_prefix, sys.base_exec_prefix,], 14 | trace=0, count=1) 15 | # run the new command using the given tracer 16 | tracer.run('main()') 17 | # make a report, placing output in /tmp 18 | r = tracer.results() 19 | r.write_results(show_missing=True, coverdir="/tmp") 20 | """ 21 | __all__ = ['Trace', 'CoverageResults'] 22 | 23 | import io 24 | import linecache 25 | import os 26 | import sys 27 | import sysconfig 28 | import token 29 | import tokenize 30 | import inspect 31 | import gc 32 | import dis 33 | import pickle 34 | from time import monotonic as _time 35 | 36 | import threading 37 | 38 | PRAGMA_NOCOVER = "#pragma NO COVER" 39 | 40 | class _Ignore: 41 | def __init__(self, modules=None, dirs=None): 42 | self._mods = set() if not modules else set(modules) 43 | self._dirs = [] if not dirs else [os.path.normpath(d) 44 | for d in dirs] 45 | self._ignore = { '': 1 } 46 | 47 | def names(self, filename, modulename): 48 | if modulename in self._ignore: 49 | return self._ignore[modulename] 50 | 51 | # haven't seen this one before, so see if the module name is 52 | # on the ignore list. 
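# The verdict is memoized in self._ignore, so later lookups for the same module name return immediately without rescanning the module and directory lists.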
53 | if modulename in self._mods: # Identical names, so ignore 54 | self._ignore[modulename] = 1 55 | return 1 56 | 57 | # check if the module is a proper submodule of something on 58 | # the ignore list 59 | for mod in self._mods: 60 | # Need to take some care since ignoring 61 | # "cmp" mustn't mean ignoring "cmpcache" but ignoring 62 | # "Spam" must also mean ignoring "Spam.Eggs". 63 | if modulename.startswith(mod + '.'): 64 | self._ignore[modulename] = 1 65 | return 1 66 | 67 | # Now check that filename isn't in one of the directories 68 | if filename is None: 69 | # must be a built-in, so we must ignore 70 | self._ignore[modulename] = 1 71 | return 1 72 | 73 | # Ignore a file when it contains one of the ignorable paths 74 | for d in self._dirs: 75 | # The '+ os.sep' is to ensure that d is a parent directory, 76 | # as compared to cases like: 77 | # d = "/usr/local" 78 | # filename = "/usr/local.py" 79 | # or 80 | # d = "/usr/local.py" 81 | # filename = "/usr/local.py" 82 | if filename.startswith(d + os.sep): 83 | self._ignore[modulename] = 1 84 | return 1 85 | 86 | # Tried the different ways, so we don't ignore this module 87 | self._ignore[modulename] = 0 88 | return 0 89 | 90 | def _modname(path): 91 | """Return a plausible module name for the path.""" 92 | 93 | base = os.path.basename(path) 94 | filename, ext = os.path.splitext(base) 95 | return filename 96 | 97 | def _fullmodname(path): 98 | """Return a plausible module name for the path.""" 99 | 100 | # If the file 'path' is part of a package, then the filename isn't 101 | # enough to uniquely identify it. Try to do the right thing by 102 | # looking in sys.path for the longest matching prefix. We'll 103 | # assume that the rest is the package name. 104 | 105 | comparepath = os.path.normcase(path) 106 | longest = "" 107 | for dir in sys.path: 108 | dir = os.path.normcase(dir) 109 | if comparepath.startswith(dir) and comparepath[len(dir)] == os.sep: 110 | if len(dir) > len(longest): 111 | longest = dir 112 | 113 | if longest: 114 | base = path[len(longest) + 1:] 115 | else: 116 | base = path 117 | # the drive letter is never part of the module name 118 | drive, base = os.path.splitdrive(base) 119 | base = base.replace(os.sep, ".") 120 | if os.altsep: 121 | base = base.replace(os.altsep, ".") 122 | filename, ext = os.path.splitext(base) 123 | return filename.lstrip(".") 124 | 125 | class CoverageResults: 126 | def __init__(self, counts=None, calledfuncs=None, infile=None, 127 | callers=None, outfile=None): 128 | self.counts = counts 129 | if self.counts is None: 130 | self.counts = {} 131 | self.counter = self.counts.copy() # map (filename, lineno) to count 132 | self.calledfuncs = calledfuncs 133 | if self.calledfuncs is None: 134 | self.calledfuncs = {} 135 | self.calledfuncs = self.calledfuncs.copy() 136 | self.callers = callers 137 | if self.callers is None: 138 | self.callers = {} 139 | self.callers = self.callers.copy() 140 | self.infile = infile 141 | self.outfile = outfile 142 | if self.infile: 143 | # Try to merge existing counts file. 144 | try: 145 | with open(self.infile, 'rb') as f: 146 | counts, calledfuncs, callers = pickle.load(f) 147 | self.update(self.__class__(counts, calledfuncs, callers=callers)) 148 | except (OSError, EOFError, ValueError) as err: 149 | print(("Skipping counts file %r: %s" 150 | % (self.infile, err)), file=sys.stderr) 151 | 152 | def is_ignored_filename(self, filename): 153 | """Return True if the filename does not refer to a file 154 | we want to have reported. 
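Synthetic frames such as '<string>' or '<stdin>' have filenames wrapped in angle brackets, and those are the ones filtered out here.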
155 | """ 156 | return filename.startswith('<') and filename.endswith('>') 157 | 158 | def update(self, other): 159 | """Merge in the data from another CoverageResults""" 160 | counts = self.counts 161 | calledfuncs = self.calledfuncs 162 | callers = self.callers 163 | other_counts = other.counts 164 | other_calledfuncs = other.calledfuncs 165 | other_callers = other.callers 166 | 167 | for key in other_counts: 168 | counts[key] = counts.get(key, 0) + other_counts[key] 169 | 170 | for key in other_calledfuncs: 171 | calledfuncs[key] = 1 172 | 173 | for key in other_callers: 174 | callers[key] = 1 175 | 176 | def write_results(self, show_missing=True, summary=False, coverdir=None): 177 | """ 178 | Write the coverage results. 179 | 180 | :param show_missing: Show lines that had no hits. 181 | :param summary: Include coverage summary per module. 182 | :param coverdir: If None, the results of each module are placed in its 183 | directory, otherwise it is included in the directory 184 | specified. 185 | """ 186 | if self.calledfuncs: 187 | print() 188 | print("functions called:") 189 | calls = self.calledfuncs 190 | for filename, modulename, funcname in sorted(calls): 191 | print(("filename: %s, modulename: %s, funcname: %s" 192 | % (filename, modulename, funcname))) 193 | 194 | if self.callers: 195 | print() 196 | print("calling relationships:") 197 | lastfile = lastcfile = "" 198 | for ((pfile, pmod, pfunc), (cfile, cmod, cfunc)) \ 199 | in sorted(self.callers): 200 | if pfile != lastfile: 201 | print() 202 | print("***", pfile, "***") 203 | lastfile = pfile 204 | lastcfile = "" 205 | if cfile != pfile and lastcfile != cfile: 206 | print(" -->", cfile) 207 | lastcfile = cfile 208 | print(" %s.%s -> %s.%s" % (pmod, pfunc, cmod, cfunc)) 209 | 210 | # turn the counts data ("(filename, lineno) = count") into something 211 | # accessible on a per-file basis 212 | per_file = {} 213 | for filename, lineno in self.counts: 214 | lines_hit = per_file[filename] = per_file.get(filename, {}) 215 | lines_hit[lineno] = self.counts[(filename, lineno)] 216 | 217 | # accumulate summary info, if needed 218 | sums = {} 219 | 220 | for filename, count in per_file.items(): 221 | if self.is_ignored_filename(filename): 222 | continue 223 | 224 | if filename.endswith(".pyc"): 225 | filename = filename[:-1] 226 | 227 | if coverdir is None: 228 | dir = os.path.dirname(os.path.abspath(filename)) 229 | modulename = _modname(filename) 230 | else: 231 | dir = coverdir 232 | os.makedirs(dir, exist_ok=True) 233 | modulename = _fullmodname(filename) 234 | 235 | # If desired, get a list of the line numbers which represent 236 | # executable content (returned as a dict for better lookup speed) 237 | if show_missing: 238 | lnotab = _find_executable_linenos(filename) 239 | else: 240 | lnotab = {} 241 | source = linecache.getlines(filename) 242 | coverpath = os.path.join(dir, modulename + ".cover") 243 | with open(filename, 'rb') as fp: 244 | encoding, _ = tokenize.detect_encoding(fp.readline) 245 | n_hits, n_lines = self.write_results_file(coverpath, source, 246 | lnotab, count, encoding) 247 | if summary and n_lines: 248 | percent = int(100 * n_hits / n_lines) 249 | sums[modulename] = n_lines, percent, modulename, filename 250 | 251 | 252 | if summary and sums: 253 | print("lines cov% module (path)") 254 | for m in sorted(sums): 255 | n_lines, percent, modulename, filename = sums[m] 256 | print("%5d %3d%% %s (%s)" % sums[m]) 257 | 258 | if self.outfile: 259 | # try and store counts and module info into self.outfile 260 | try: 
261 | with open(self.outfile, 'wb') as f: 262 | pickle.dump((self.counts, self.calledfuncs, self.callers), 263 | f, 1) 264 | except OSError as err: 265 | print("Can't save counts files because %s" % err, file=sys.stderr) 266 | 267 | def write_results_file(self, path, lines, lnotab, lines_hit, encoding=None): 268 | """Return a coverage results file in path.""" 269 | # ``lnotab`` is a dict of executable lines, or a line number "table" 270 | 271 | try: 272 | outfile = open(path, "w", encoding=encoding) 273 | except OSError as err: 274 | print(("trace: Could not open %r for writing: %s " 275 | "- skipping" % (path, err)), file=sys.stderr) 276 | return 0, 0 277 | 278 | n_lines = 0 279 | n_hits = 0 280 | with outfile: 281 | for lineno, line in enumerate(lines, 1): 282 | # do the blank/comment match to try to mark more lines 283 | # (help the reader find stuff that hasn't been covered) 284 | if lineno in lines_hit: 285 | outfile.write("%5d: " % lines_hit[lineno]) 286 | n_hits += 1 287 | n_lines += 1 288 | elif lineno in lnotab and not PRAGMA_NOCOVER in line: 289 | # Highlight never-executed lines, unless the line contains 290 | # #pragma: NO COVER 291 | outfile.write(">>>>>> ") 292 | n_lines += 1 293 | else: 294 | outfile.write(" ") 295 | outfile.write(line.expandtabs(8)) 296 | 297 | return n_hits, n_lines 298 | 299 | def _find_lines_from_code(code, strs): 300 | """Return dict where keys are lines in the line number table.""" 301 | linenos = {} 302 | 303 | for _, lineno in dis.findlinestarts(code): 304 | if lineno not in strs: 305 | linenos[lineno] = 1 306 | 307 | return linenos 308 | 309 | def _find_lines(code, strs): 310 | """Return lineno dict for all code objects reachable from code.""" 311 | # get all of the lineno information from the code of this scope level 312 | linenos = _find_lines_from_code(code, strs) 313 | 314 | # and check the constants for references to other code objects 315 | for c in code.co_consts: 316 | if inspect.iscode(c): 317 | # find another code object, so recurse into it 318 | linenos.update(_find_lines(c, strs)) 319 | return linenos 320 | 321 | def _find_strings(filename, encoding=None): 322 | """Return a dict of possible docstring positions. 323 | 324 | The dict maps line numbers to strings. There is an entry for 325 | line that contains only a string or a part of a triple-quoted 326 | string. 327 | """ 328 | d = {} 329 | # If the first token is a string, then it's the module docstring. 330 | # Add this special case so that the test in the loop passes. 
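# tokenize emits an INDENT token immediately before the first statement of an indented suite, so a STRING whose previous token type is INDENT (or this seeded module-level case) is treated as a docstring -- e.g. the string directly under 'def foo():' in a function body.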
331 | prev_ttype = token.INDENT 332 | with open(filename, encoding=encoding) as f: 333 | tok = tokenize.generate_tokens(f.readline) 334 | for ttype, tstr, start, end, line in tok: 335 | if ttype == token.STRING: 336 | if prev_ttype == token.INDENT: 337 | sline, scol = start 338 | eline, ecol = end 339 | for i in range(sline, eline + 1): 340 | d[i] = 1 341 | prev_ttype = ttype 342 | return d 343 | 344 | def _find_executable_linenos(filename): 345 | """Return dict where keys are line numbers in the line number table.""" 346 | try: 347 | with tokenize.open(filename) as f: 348 | prog = f.read() 349 | encoding = f.encoding 350 | except OSError as err: 351 | print(("Not printing coverage data for %r: %s" 352 | % (filename, err)), file=sys.stderr) 353 | return {} 354 | code = compile(prog, filename, "exec") 355 | strs = _find_strings(filename, encoding) 356 | return _find_lines(code, strs) 357 | 358 | class Trace: 359 | def __init__(self, count=1, trace=1, countfuncs=0, countcallers=0, 360 | ignoremods=(), ignoredirs=(), infile=None, outfile=None, 361 | timing=False): 362 | """ 363 | @param count true iff it should count number of times each 364 | line is executed 365 | @param trace true iff it should print out each line that is 366 | being counted 367 | @param countfuncs true iff it should just output a list of 368 | (filename, modulename, funcname,) for functions 369 | that were called at least once; This overrides 370 | `count' and `trace' 371 | @param ignoremods a list of the names of modules to ignore 372 | @param ignoredirs a list of the names of directories to ignore 373 | all of the (recursive) contents of 374 | @param infile file from which to read stored counts to be 375 | added into the results 376 | @param outfile file in which to write the results 377 | @param timing true iff timing information be displayed 378 | """ 379 | self.exe_path = [] 380 | self.infile = infile 381 | self.outfile = outfile 382 | self.ignore = _Ignore(ignoremods, ignoredirs) 383 | self.counts = {} # keys are (filename, linenumber) 384 | self.pathtobasename = {} # for memoizing os.path.basename 385 | self.donothing = 0 386 | self.trace = trace 387 | self._calledfuncs = {} 388 | self._callers = {} 389 | self._caller_cache = {} 390 | self.start_time = None 391 | if timing: 392 | self.start_time = _time() 393 | if countcallers: 394 | self.globaltrace = self.globaltrace_trackcallers 395 | elif countfuncs: 396 | self.globaltrace = self.globaltrace_countfuncs 397 | elif trace and count: 398 | self.globaltrace = self.globaltrace_lt 399 | self.localtrace = self.localtrace_trace_and_count 400 | elif trace: 401 | self.globaltrace = self.globaltrace_lt 402 | self.localtrace = self.localtrace_trace 403 | elif count: 404 | self.globaltrace = self.globaltrace_lt 405 | self.localtrace = self.localtrace_count 406 | else: 407 | # Ahem -- do nothing? Okay. 
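# With count, trace, countfuncs and countcallers all disabled, no trace hooks are ever installed and run()/runctx() reduce to a plain exec().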
408 | self.donothing = 1 409 | 410 | def run(self, cmd): 411 | import __main__ 412 | dict = __main__.__dict__ 413 | self.runctx(cmd, dict, dict) 414 | 415 | def runctx(self, cmd, globals=None, locals=None): 416 | if globals is None: globals = {} 417 | if locals is None: locals = {} 418 | if not self.donothing: 419 | threading.settrace(self.globaltrace) 420 | sys.settrace(self.globaltrace) 421 | try: 422 | exec(cmd, globals, locals) 423 | finally: 424 | if not self.donothing: 425 | sys.settrace(None) 426 | threading.settrace(None) 427 | 428 | def runfunc(self, func, /, *args, **kw): 429 | result = None 430 | if not self.donothing: 431 | sys.settrace(self.globaltrace) 432 | try: 433 | result = func(*args, **kw) 434 | finally: 435 | if not self.donothing: 436 | sys.settrace(None) 437 | return result 438 | 439 | def file_module_function_of(self, frame): 440 | code = frame.f_code 441 | filename = code.co_filename 442 | if filename: 443 | modulename = _modname(filename) 444 | else: 445 | modulename = None 446 | 447 | funcname = code.co_name 448 | clsname = None 449 | if code in self._caller_cache: 450 | if self._caller_cache[code] is not None: 451 | clsname = self._caller_cache[code] 452 | else: 453 | self._caller_cache[code] = None 454 | ## use of gc.get_referrers() was suggested by Michael Hudson 455 | # all functions which refer to this code object 456 | funcs = [f for f in gc.get_referrers(code) 457 | if inspect.isfunction(f)] 458 | # require len(func) == 1 to avoid ambiguity caused by calls to 459 | # new.function(): "In the face of ambiguity, refuse the 460 | # temptation to guess." 461 | if len(funcs) == 1: 462 | dicts = [d for d in gc.get_referrers(funcs[0]) 463 | if isinstance(d, dict)] 464 | if len(dicts) == 1: 465 | classes = [c for c in gc.get_referrers(dicts[0]) 466 | if hasattr(c, "__bases__")] 467 | if len(classes) == 1: 468 | # ditto for new.classobj() 469 | clsname = classes[0].__name__ 470 | # cache the result - assumption is that new.* is 471 | # not called later to disturb this relationship 472 | # _caller_cache could be flushed if functions in 473 | # the new module get called. 474 | self._caller_cache[code] = clsname 475 | if clsname is not None: 476 | funcname = "%s.%s" % (clsname, funcname) 477 | 478 | return filename, modulename, funcname 479 | 480 | def globaltrace_trackcallers(self, frame, why, arg): 481 | """Handler for call events. 482 | 483 | Adds information about who called who to the self._callers dict. 484 | """ 485 | if why == 'call': 486 | # XXX Should do a better job of identifying methods 487 | this_func = self.file_module_function_of(frame) 488 | parent_func = self.file_module_function_of(frame.f_back) 489 | self._callers[(parent_func, this_func)] = 1 490 | 491 | def globaltrace_countfuncs(self, frame, why, arg): 492 | """Handler for call events. 493 | 494 | Adds (filename, modulename, funcname) to the self._calledfuncs dict. 495 | """ 496 | if why == 'call': 497 | this_func = self.file_module_function_of(frame) 498 | self._calledfuncs[this_func] = 1 499 | 500 | def globaltrace_lt(self, frame, why, arg): 501 | """Handler for call events. 502 | 503 | If the code block being entered is to be ignored, returns `None', 504 | else returns self.localtrace. 
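Returning self.localtrace asks the interpreter to call it for every subsequent 'line' event in the frame, which is how the per-line counts and exe_path entries are collected.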
505 | """ 506 | if why == 'call': 507 | code = frame.f_code 508 | filename = frame.f_globals.get('__file__', None) 509 | if filename: 510 | # XXX _modname() doesn't work right for packages, so 511 | # the ignore support won't work right for packages 512 | modulename = _modname(filename) 513 | if modulename is not None: 514 | ignore_it = self.ignore.names(filename, modulename) 515 | if not ignore_it: 516 | if self.trace: 517 | print((" --- modulename: %s, funcname: %s" 518 | % (modulename, code.co_name))) 519 | return self.localtrace 520 | else: 521 | return None 522 | 523 | def localtrace_trace_and_count(self, frame, why, arg): 524 | if why == "line": 525 | # record the file name and line number of every trace 526 | filename = frame.f_code.co_filename 527 | lineno = frame.f_lineno 528 | key = filename, lineno 529 | #Cuong 530 | print(key) 531 | self.counts[key] = self.counts.get(key, 0) + 1 532 | 533 | if self.start_time: 534 | print('%.2f' % (_time() - self.start_time), end=' ') 535 | bname = os.path.basename(filename) 536 | line = linecache.getline(filename, lineno) 537 | print("%s(%d)" % (bname, lineno), end='') 538 | if line: 539 | print(": ", line, end='') 540 | else: 541 | print() 542 | return self.localtrace 543 | 544 | def localtrace_trace(self, frame, why, arg): 545 | if why == "line": 546 | # record the file name and line number of every trace 547 | filename = frame.f_code.co_filename 548 | lineno = frame.f_lineno 549 | 550 | if self.start_time: 551 | print('%.2f' % (_time() - self.start_time), end=' ') 552 | bname = os.path.basename(filename) 553 | line = linecache.getline(filename, lineno) 554 | print("%s(%d)" % (bname, lineno), end='') 555 | if line: 556 | print(": ", line, end='') 557 | else: 558 | print() 559 | return self.localtrace 560 | 561 | def localtrace_count(self, frame, why, arg): 562 | if why == "line": 563 | filename = frame.f_code.co_filename 564 | lineno = frame.f_lineno 565 | key = filename, lineno 566 | self.exe_path.append(lineno) 567 | self.counts[key] = self.counts.get(key, 0) + 1 568 | return self.localtrace 569 | 570 | def results(self): 571 | return CoverageResults(self.counts, infile=self.infile, 572 | outfile=self.outfile, 573 | calledfuncs=self._calledfuncs, 574 | callers=self._callers) 575 | 576 | def main(): 577 | import argparse 578 | 579 | parser = argparse.ArgumentParser() 580 | parser.add_argument('--version', action='version', version='trace 2.0') 581 | 582 | grp = parser.add_argument_group('Main options', 583 | 'One of these (or --report) must be given') 584 | 585 | grp.add_argument('-c', '--count', action='store_true', 586 | help='Count the number of times each line is executed and write ' 587 | 'the counts to .cover for each module executed, in ' 588 | 'the module\'s directory. See also --coverdir, --file, ' 589 | '--no-report below.') 590 | grp.add_argument('-t', '--trace', action='store_true', 591 | help='Print each line to sys.stdout before it is executed') 592 | grp.add_argument('-l', '--listfuncs', action='store_true', 593 | help='Keep track of which functions are executed at least once ' 594 | 'and write the results to sys.stdout after the program exits. 
' 595 | 'Cannot be specified alongside --trace or --count.') 596 | grp.add_argument('-T', '--trackcalls', action='store_true', 597 | help='Keep track of caller/called pairs and write the results to ' 598 | 'sys.stdout after the program exits.') 599 | 600 | grp = parser.add_argument_group('Modifiers') 601 | 602 | _grp = grp.add_mutually_exclusive_group() 603 | _grp.add_argument('-r', '--report', action='store_true', 604 | help='Generate a report from a counts file; does not execute any ' 605 | 'code. --file must specify the results file to read, which ' 606 | 'must have been created in a previous run with --count ' 607 | '--file=FILE') 608 | _grp.add_argument('-R', '--no-report', action='store_true', 609 | help='Do not generate the coverage report files. ' 610 | 'Useful if you want to accumulate over several runs.') 611 | 612 | grp.add_argument('-f', '--file', 613 | help='File to accumulate counts over several runs') 614 | grp.add_argument('-C', '--coverdir', 615 | help='Directory where the report files go. The coverage report ' 616 | 'for . will be written to file ' 617 | '//.cover') 618 | grp.add_argument('-m', '--missing', action='store_true', 619 | help='Annotate executable lines that were not executed with ' 620 | '">>>>>> "') 621 | grp.add_argument('-s', '--summary', action='store_true', 622 | help='Write a brief summary for each file to sys.stdout. ' 623 | 'Can only be used with --count or --report') 624 | grp.add_argument('-g', '--timing', action='store_true', 625 | help='Prefix each line with the time since the program started. ' 626 | 'Only used while tracing') 627 | 628 | grp = parser.add_argument_group('Filters', 629 | 'Can be specified multiple times') 630 | grp.add_argument('--ignore-module', action='append', default=[], 631 | help='Ignore the given module(s) and its submodules ' 632 | '(if it is a package). Accepts comma separated list of ' 633 | 'module names.') 634 | grp.add_argument('--ignore-dir', action='append', default=[], 635 | help='Ignore files in the given directory ' 636 | '(multiple directories can be joined by os.pathsep).') 637 | 638 | parser.add_argument('--module', action='store_true', default=False, 639 | help='Trace a module. 
') 640 | parser.add_argument('progname', nargs='?', 641 | help='file to run as main program') 642 | parser.add_argument('arguments', nargs=argparse.REMAINDER, 643 | help='arguments to the program') 644 | 645 | opts = parser.parse_args() 646 | 647 | if opts.ignore_dir: 648 | _prefix = sysconfig.get_path("stdlib") 649 | _exec_prefix = sysconfig.get_path("platstdlib") 650 | 651 | def parse_ignore_dir(s): 652 | s = os.path.expanduser(os.path.expandvars(s)) 653 | s = s.replace('$prefix', _prefix).replace('$exec_prefix', _exec_prefix) 654 | return os.path.normpath(s) 655 | 656 | opts.ignore_module = [mod.strip() 657 | for i in opts.ignore_module for mod in i.split(',')] 658 | opts.ignore_dir = [parse_ignore_dir(s) 659 | for i in opts.ignore_dir for s in i.split(os.pathsep)] 660 | 661 | if opts.report: 662 | if not opts.file: 663 | parser.error('-r/--report requires -f/--file') 664 | results = CoverageResults(infile=opts.file, outfile=opts.file) 665 | return results.write_results(opts.missing, opts.summary, opts.coverdir) 666 | 667 | if not any([opts.trace, opts.count, opts.listfuncs, opts.trackcalls]): 668 | parser.error('must specify one of --trace, --count, --report, ' 669 | '--listfuncs, or --trackcalls') 670 | 671 | if opts.listfuncs and (opts.count or opts.trace): 672 | parser.error('cannot specify both --listfuncs and (--trace or --count)') 673 | 674 | if opts.summary and not opts.count: 675 | parser.error('--summary can only be used with --count or --report') 676 | 677 | if opts.progname is None: 678 | parser.error('progname is missing: required with the main options') 679 | 680 | t = Trace(opts.count, opts.trace, countfuncs=opts.listfuncs, 681 | countcallers=opts.trackcalls, ignoremods=opts.ignore_module, 682 | ignoredirs=opts.ignore_dir, infile=opts.file, 683 | outfile=opts.file, timing=opts.timing) 684 | try: 685 | if opts.module: 686 | import runpy 687 | module_name = opts.progname 688 | mod_name, mod_spec, code = runpy._get_module_details(module_name) 689 | sys.argv = [code.co_filename, *opts.arguments] 690 | globs = { 691 | '__name__': '__main__', 692 | '__file__': code.co_filename, 693 | '__package__': mod_spec.parent, 694 | '__loader__': mod_spec.loader, 695 | '__spec__': mod_spec, 696 | '__cached__': None, 697 | } 698 | else: 699 | sys.argv = [opts.progname, *opts.arguments] 700 | sys.path[0] = os.path.dirname(opts.progname) 701 | 702 | with io.open_code(opts.progname) as fp: 703 | code = compile(fp.read(), opts.progname, 'exec') 704 | # try to emulate __main__ namespace as much as possible 705 | globs = { 706 | '__file__': opts.progname, 707 | '__name__': '__main__', 708 | '__package__': None, 709 | '__cached__': None, 710 | } 711 | t.runctx(code, globs, globs) 712 | except OSError as err: 713 | sys.exit("Cannot run file %r because: %s" % (sys.argv[0], err)) 714 | except SystemExit: 715 | pass 716 | 717 | results = t.results() 718 | 719 | if not opts.no_report: 720 | results.write_results(opts.missing, opts.summary, opts.coverdir) 721 | 722 | if __name__=='__main__': 723 | main() -------------------------------------------------------------------------------- /img/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/CodeFlow/ef145c4d7271ece6abbff9f6e8d4c2d630bdb5af/img/architecture.png -------------------------------------------------------------------------------- /img/pipeline.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FSoft-AI4Code/CodeFlow/ef145c4d7271ece6abbff9f6e8d4c2d630bdb5af/img/pipeline.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import data
2 | import config
3 | import os
4 | import model
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | from tqdm import tqdm
9 | from sklearn.metrics import precision_recall_fscore_support
10 | import warnings
11 | import numpy as np
12 | from utils import write, pad_targets, accuracy_whole_list
14 | import random
15 | 
16 | warnings.simplefilter("ignore")
17 | 
18 | opt = config.parse()
19 | if torch.cuda.is_available():
20 |     torch.cuda.manual_seed(opt.seed)
21 |     torch.backends.cudnn.deterministic = True
22 |     torch.backends.cudnn.benchmark = False
23 | random.seed(opt.seed)
24 | np.random.seed(opt.seed)
25 | 
26 | def train(opt, train_iter, valid_iter, device):
27 |     net = model.CodeFlow(opt).to(device)
28 |     criterion = nn.BCELoss()
29 |     optimizer = optim.Adam(net.parameters(), lr=opt.learning_rate)
30 |     write("Start training...", opt.output)
31 |     for epoch in range(opt.num_epoch):
32 |         net.train()
33 |         total_loss, total_accuracy = 0, 0
34 |         total_train = 0
35 |         for batch in tqdm(train_iter):
36 |             x, edges, target = batch.nodes, (batch.forward, batch.backward), batch.target.float()
37 | 
38 |             if isinstance(x, tuple):
39 |                 pred = net(x[0], edges, x[1], x[2])
40 |             else:
41 |                 pred = net(x, edges)
42 |             pred = pred.squeeze()
43 | 
44 |             loss = criterion(pred, target)
45 |             optimizer.zero_grad()
46 |             loss.backward()
47 |             optimizer.step()
48 |             total_loss += loss.item()
49 |             pred = (pred > opt.alpha).float()
50 |             accuracy = accuracy_whole_list(target.cpu().numpy(), pred.cpu().numpy(), x[1].cpu().numpy())
51 |             total_train += target.shape[0]
52 |             total_accuracy += accuracy
53 |         avg_loss = total_loss / len(train_iter)
54 |         avg_accuracy = total_accuracy / total_train
55 | 
56 |         net.eval()
57 |         eval_loss, eval_accuracy, eval_error_accuracy = 0, 0, 0
58 |         y_true, y_pred = [], []
59 |         total_test = 0
60 |         total_local = 0
61 |         total_detect = 0
62 |         locate_bug = 0
63 |         detect_true = 0
64 |         with torch.no_grad():
65 |             for batch in valid_iter:
66 |                 x, edges, target = batch.nodes, (batch.forward, batch.backward), batch.target.float()
67 |                 num_nodes = x[1]
68 |                 if isinstance(x, tuple):
69 |                     pred = net(x[0], edges, x[1], x[2])
70 |                 else:
71 |                     pred = net(x, edges)
72 |                 pred = pred.squeeze()
73 | 
74 |                 loss = criterion(pred, target)
75 |                 eval_loss += loss.item()
76 |                 pred = (pred > opt.alpha).float()
77 |                 if opt.runtime_detection:
78 |                     for i in range(len(x[1])):
79 |                         total_detect += 1
80 |                         if pred[i][x[1][i]-1] == target[i][x[1][i]-1]:
81 |                             detect_true += 1
82 |                 if opt.bug_localization:
83 |                     for i in range(len(x[1])):
84 |                         target_list = []
85 |                         pred_list = []
86 |                         num_nodes_list = []
87 |                         if target[i][x[1][i]-1] == 1:
88 |                             continue
89 |                         total_local += 1
90 |                         mask_pred = pred[i] == 1
91 |                         indices_pred = torch.nonzero(mask_pred).flatten()
92 |                         farthest_pred = indices_pred.max().item()
93 | 
94 |                         mask_target = target[i] == 1
95 | 
96 |                         indices_target = torch.nonzero(mask_target).flatten()
97 |                         farthest_target = indices_target.max().item()
98 |                         if farthest_pred == farthest_target:
99 |                             locate_bug += 1
100 |                         target_list.append(target[i].cpu().numpy())
101 |                         pred_list.append(pred[i].cpu().numpy())
102 |                         num_nodes_list.append(num_nodes[i].cpu().numpy())
103 |                         error_accuracy = accuracy_whole_list(target_list, pred_list, num_nodes_list)
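The block above is CodeFlow's bug-localization heuristic: a sample counts as buggy when its exit node is not covered, and the bug is considered localized when the farthest node the model predicts as executed coincides with the farthest node that actually executed (the presumed crash site). A minimal standalone sketch of the same check, on illustrative tensors rather than the repo's batch objects:

```python
import torch

def localize_bug(pred_cov: torch.Tensor, true_cov: torch.Tensor) -> bool:
    """Compare the farthest predicted-covered node with the farthest
    actually-covered node. Assumes at least one covered node per vector
    (the entry node is always executed)."""
    farthest_pred = torch.nonzero(pred_cov == 1).flatten().max().item()
    farthest_true = torch.nonzero(true_cov == 1).flatten().max().item()
    return farthest_pred == farthest_true

# Toy case: execution stops after node 4 (index 3) in both vectors.
pred = torch.tensor([1., 1., 1., 1., 0., 0.])
true = torch.tensor([1., 1., 1., 1., 0., 0.])
assert localize_bug(pred, true)
```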
104 |                         eval_error_accuracy += error_accuracy
105 |                 accuracy = accuracy_whole_list(target.cpu().numpy(), pred.cpu().numpy(), num_nodes.cpu().numpy())
106 |                 eval_accuracy += accuracy
107 |                 total_test += target.shape[0]
108 |                 # append target to y_true and pred to y_pred, based on the number of nodes in num_nodes
109 |                 for i in range(len(num_nodes)):
110 |                     y_true.append(target[i, :num_nodes[i]].cpu().numpy())
111 |                     y_pred.append(pred[i, :num_nodes[i]].cpu().numpy())
112 |         avg_eval_loss = eval_loss / len(valid_iter)
113 |         avg_eval_accuracy = eval_accuracy / total_test
114 |         # concatenate all the targets and predictions
115 |         y_true = np.concatenate(y_true)
116 |         y_pred = np.concatenate(y_pred)
117 | 
118 |         precision, recall, fscore, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
119 |         write(f"Epoch {epoch + 1}/{opt.num_epoch}", opt.output)
120 |         write(f"Train Loss: {avg_loss:.4f}, Train Accuracy: {avg_accuracy:.4f}", opt.output)
121 |         write(f"Validation Loss: {avg_eval_loss:.4f}, Validation Accuracy: {avg_eval_accuracy:.4f}", opt.output)
122 |         if opt.runtime_detection:
123 |             detect_acc = (detect_true / total_detect) * 100
124 |             write(f"Runtime Error Detection: {detect_acc:.4f}", opt.output)
125 |         if opt.bug_localization:
126 |             locate_acc = (locate_bug / total_local) * 100
127 |             write(f"BUG Localization: {locate_acc:.4f}", opt.output)
128 |         write(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F-score: {fscore:.4f}", opt.output)
129 |         write("_________________________________________________", opt.output)
130 | 
131 |         if (epoch + 1) % 10 == 0:
132 |             os.makedirs(f'checkpoints/checkpoints_{opt.checkpoint}', exist_ok=True)
133 |             torch.save(net.state_dict(), f"checkpoints/checkpoints_{opt.checkpoint}/epoch-{epoch + 1}.pt")
134 | 
135 |     return net
136 | 
137 | def main():
138 |     if not os.path.exists('checkpoints'):
139 |         os.makedirs('checkpoints')
140 |     if opt.checkpoint is None:
141 |         files = os.listdir("checkpoints")
142 |         opt.checkpoint = len(files) + 1
143 |     if opt.name_exp is None:
144 |         opt.output = f'{opt.output}/checkpoint_{opt.checkpoint}_{opt.seed}'
145 |     else:
146 |         opt.output = f'{opt.output}/checkpoint_{opt.checkpoint}_{opt.seed}_{opt.name_exp}'
147 |     print(opt.output)
148 |     os.makedirs(os.path.dirname(opt.output), exist_ok=True)
149 |     open(opt.output, 'w').close()
150 |     if opt.cuda_num is None:
151 |         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
152 |     else:
153 |         device = torch.device(f"cuda:{opt.cuda_num}" if torch.cuda.is_available() else "cpu")
154 |     if opt.seed is not None:
155 |         random.seed(opt.seed)
156 |     print(f"Using device: {device}")
157 |     train_iter, test_iter = data.get_iterators(opt, device)
158 |     train(opt, train_iter, test_iter, device)
159 | 
160 | if __name__ == "__main__":
161 |     main()
162 | 
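Checkpoints are written every 10 epochs to `checkpoints/checkpoints_<id>/epoch-<n>.pt` (line 133 above), and the fuzz-testing step consumes them via `--checkpoint` and `--epoch`. A minimal sketch of reloading one outside `main.py`; the checkpoint id and epoch below are placeholders, and `opt` must be the same configuration used for training:

```python
import torch
import config
import model

opt = config.parse()  # must match the config the checkpoint was trained with
net = model.CodeFlow(opt)
state = torch.load("checkpoints/checkpoints_1/epoch-50.pt", map_location="cpu")
net.load_state_dict(state)
net.eval()  # inference mode for evaluation or fuzz testing
```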
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from utils import write
5 | 
6 | class CodeFlow(nn.Module):
7 |     def __init__(self, opt):
8 |         super(CodeFlow, self).__init__()
9 |         self.max_node = opt.max_node
10 |         self.hidden_dim = opt.hidden_dim
11 |         self.embedding = nn.Embedding(opt.vocab_size+2, opt.hidden_dim, padding_idx=1)
12 |         self.node_lstm = nn.LSTM(self.hidden_dim, self.hidden_dim//2, bidirectional=True, batch_first=True)
13 |         self.gate = Gate(self.hidden_dim, self.hidden_dim)
14 |         self.back_gate = Gate(self.hidden_dim, self.hidden_dim)
15 |         self.concat = nn.Linear(self.hidden_dim, self.hidden_dim)
16 |         self.fc_output = nn.Linear(self.hidden_dim, 1)  # maps each node's hidden state to a scalar coverage score
17 |         self.opt = opt
18 | 
19 |     # @profile
20 |     def forward(self, x, edges, node_lens=None, token_lens=None, target=None):
21 |         f_edges, b_edges = edges
22 |         batch_size, num_node, num_token = x.size(0), x.size(1), x.size(2)
23 | 
24 |         # token_ids [bs, num_node, num_token]
25 |         # x: ([bs, 34, 12], [bs], [bs, 34])
26 |         # node_lens: nodes per sample; token_lens: tokens per node
27 |         # f_edges, b_edges: ([bs, 38, 2], [bs, 6, 2]) -> max_edge_forward, max_edge_backward
28 |         # target: [bs, 34]
29 | 
30 |         x = self.embedding(x)  # [B, N, L, H]
31 |         if self.opt.extra_aggregate:
32 |             neighbors = [{} for _ in range(len(f_edges))]
33 |             for i in range(len(f_edges)):
34 |                 for (start, end) in f_edges[i]:
35 |                     start = start.item()
36 |                     end = end.item()
37 |                     if start == 1 and end == 1:
38 |                         continue
39 |                     if start not in neighbors[i]:
40 |                         neighbors[i][start] = [end]
41 |                     else:
42 |                         neighbors[i][start].append(end)
43 |                     if end not in neighbors[i]:
44 |                         neighbors[i][end] = [start]
45 |                     else:
46 |                         neighbors[i][end].append(start)
47 |             for i in range(len(b_edges)):
48 |                 for (start, end) in b_edges[i]:
49 |                     start = start.item()
50 |                     end = end.item()
51 |                     if start == 1 and end == 1:
52 |                         continue
53 |                     if start not in neighbors[i]:
54 |                         neighbors[i][start] = [end]
55 |                     else:
56 |                         neighbors[i][start].append(end)
57 |                     if end not in neighbors[i]:
58 |                         neighbors[i][end] = [start]
59 |                     else:
60 |                         neighbors[i][end].append(start)
61 |             max_node = max(node_lens)
62 |             matrix = torch.zeros((batch_size, max_node, max_node), dtype=torch.float, device=x.device)
63 |             for i in range(batch_size):
64 |                 if self.opt.delete_redundant_node:
65 |                     for node in neighbors[i]:
66 |                         num_neighbors = len(neighbors[i][node])
67 |                         for neighbor in neighbors[i][node]:
68 |                             matrix[i, node-1, neighbor-1] = (1 - self.opt.alpha) / num_neighbors
69 |                         matrix[i, node-1, node-1] = self.opt.alpha
70 |                 else:
71 |                     for node in range(max_node):
72 |                         if node in neighbors[i].keys():
73 |                             num_neighbors = len(neighbors[i][node])
74 |                             for neighbor in neighbors[i][node]:
75 |                                 matrix[i, node-1, neighbor-1] = (1 - self.opt.alpha) / num_neighbors
76 |                             matrix[i, node-1, node-1] = self.opt.alpha
77 |                         else:
78 |                             matrix[i, node-1, node-1] = 1
79 | 
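The `extra_aggregate` branch above builds one row-stochastic smoothing matrix per sample: every node keeps weight `alpha` on itself and spreads `1 - alpha` evenly over its CFG neighbours, and the final per-node scores are later multiplied by this matrix (the `torch.bmm(matrix, output)` at the end of `forward`). A standalone sketch of the same construction for a small illustrative graph:

```python
import torch

alpha = 0.5
# undirected neighbour lists for a 4-node chain, 1-indexed like the CFG nodes
neighbors = {1: [2], 2: [1, 3], 3: [2, 4], 4: [3]}

M = torch.zeros(4, 4)
for node, nbrs in neighbors.items():
    for nbr in nbrs:
        M[node - 1, nbr - 1] = (1 - alpha) / len(nbrs)  # share 1 - alpha among neighbours
    M[node - 1, node - 1] = alpha                       # keep alpha on the node itself

scores = torch.tensor([[0.9], [0.1], [0.8], [0.2]])     # per-node coverage scores
smoothed = M @ scores                                   # mix each score with its neighbours'
print(smoothed.squeeze())
```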
80 |         #! Node LSTM embedding, https://www.readcube.com/library/1771e2fb-bec1-4bc4-90b3-04c8786fe9dd:fd440d39-f13e-430c-b768-751878616cda, 2nd figure, Node Embedding part
81 |         if token_lens is not None:
82 |             x = x.view(batch_size*num_node, num_token, -1)
83 |             h_n = torch.zeros((2, batch_size*num_node, self.hidden_dim//2)).to(x.device)
84 |             c_n = torch.zeros((2, batch_size*num_node, self.hidden_dim//2)).to(x.device)
85 |             x, _ = self.node_lstm(x, (h_n, c_n))  # [B*N, L, H]
86 |             x = x.view(batch_size, num_node, num_token, -1)
87 |             x = self.average_pooling(x, token_lens)
88 |         else:
89 |             x = torch.mean(x, dim=2)  # [B, N, H]
90 | 
91 |         # ! Initialize hidden states to be zeros
92 |         h_f = torch.zeros(x.size()).to(x.device)
93 |         c_f = torch.zeros(x.size()).to(x.device)
94 | 
95 |         # ! Forward pass: including forward edge + backward edge, 1->K
96 |         ori_f_matrix = self.convert_to_matrix(batch_size, num_node, f_edges)
97 |         running_f_matrix = ori_f_matrix.clone()
98 |         for i in range(num_node):
99 |             f_i = running_f_matrix[:, i, :].unsqueeze(1)
100 |             f_i = f_i.clone()
101 |             x_cur = x[:, i, :].squeeze(1)  # [B, hidden_dim]
102 |             h_last, c_last = f_i.bmm(h_f), f_i.bmm(c_f)  # h = [B, max_node, H]
103 |             # h_last = [B, 1, H]
104 |             # (a branch-pruning check for nodes with two successors was prototyped below and left disabled)
105 |             # [B, 1, max_node] x [B, max_node, hidden_dim] = [B, 1, hidden_dim]
106 |             # h_last, c_last = [B, 1, hidden_dim]
107 |             h_i, c_i = self.gate(x_cur, h_last.squeeze(1), c_last.squeeze(1))
108 |             h_f[:, i, :], c_f[:, i, :] = h_i, c_i
109 |             # successor nodes j with an edge i->j read this updated state via entry (j, i) of f_matrix
110 |             h_i, c_i = h_i.squeeze(1), c_i.squeeze(1)
111 |             # for sample_id in range(batch_size):
112 |             #     next_node_ids = []
113 |             #     for j in range(num_node):
114 |             #         if running_f_matrix[sample_id, j, i] == 1:
115 |             #             next_node_ids.append(j)
116 | 
117 |             #     if len(next_node_ids) > 2:
118 |             #         print(sample_id)
119 |             #         print(torch.sum(running_f_matrix, dim=1))
120 |             #         # raise ValueError(f"Node {i+1} in sample_id: {sample_id} has more than 2 outward edges")
121 |             #     if len(next_node_ids) == 2:
122 |             #         if h_i[sample_id].sum() >= 0:
123 |             #             running_f_matrix[sample_id, next_node_ids[0], i] = 0
124 |             #         else:
125 |             #             running_f_matrix[sample_id, next_node_ids[1], i] = 0
126 | 
127 | 
128 |         b_matrix = self.convert_to_matrix(batch_size, num_node, b_edges)
129 |         for j in range(num_node):
130 |             b_j = b_matrix[:, j, :].unsqueeze(1)
131 |             h_temp = b_j.bmm(h_f)
132 |             h_f[:, j, :] += h_temp.squeeze(1)
133 | 
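`convert_to_matrix` (defined further down) turns a padded, 1-indexed `[src, dst]` edge list into a batch of adjacency matrices with `matrix[b, dst, src] = 1`, so row `i` selects the predecessors of node `i`; one `bmm` then gathers each node's predecessor state before the gate is applied. Padded edge slots are filled with `[1, 1]` pairs, which is why the method zeroes `matrix[:, 0, 0]`. A toy illustration of that indexing, with shapes chosen for the example:

```python
import torch

# one sample, 3 nodes, edges 1->2 and 2->3 (1-indexed, as in the CFG data)
edges = torch.tensor([[[1, 2], [2, 3]]])      # [B=1, E=2, 2]
A = torch.zeros(1, 3, 3)
idx = edges - 1                               # to 0-indexed
A[0, idx[0, :, 1], idx[0, :, 0]] = 1          # A[b, dst, src] = 1

h = torch.arange(6.).view(1, 3, 2)            # [B, N, H] node states
row1 = A[:, 1, :].unsqueeze(1)                # predecessor selector for node 2
h_pred = row1.bmm(h)                          # sums predecessor states -> node 1's state
print(h_pred)                                 # tensor([[[0., 1.]]])
```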
134 |         # # ! Initialize hidden states to be zeros
135 |         # h_b = torch.zeros(x.size()).to(x.device)
136 |         # c_b = torch.zeros(x.size()).to(x.device)
137 | 
138 |         # # ! Backward pass: transpose b_matrix, f_matrix, including forward edge + backward edge, K->1
139 |         # b_matrix = self.convert_to_matrix(batch_size, num_node, f_edges)
140 |         # b_matrix = b_matrix.transpose(1, 2)
141 |         # for i in reversed(range(num_node)):
142 |         #     x_cur = x[:, i, :].squeeze(1)
143 |         #     b_i = b_matrix[:, i, :].unsqueeze(1)
144 |         #     h_hat, c_hat = b_i.bmm(h_b), b_i.bmm(c_b)
145 |         #     h_b[:, i, :], c_b[:, i, :] = self.back_gate(x_cur, h_hat.squeeze(), c_hat.squeeze())
146 | 
147 |         # f_matrix = self.convert_to_matrix(batch_size, num_node, b_edges)
148 |         # f_matrix = f_matrix.transpose(1, 2)
149 |         # for j in range(num_node):
150 |         #     f_j = f_matrix[:, j, :].unsqueeze(1)
151 |         #     h_temp = f_j.bmm(h_b)
152 |         #     h_b[:, j, :] += h_temp.squeeze(1)
153 | 
154 |         # ------------ Prediction stage ------------ #
155 | 
156 |         # h = torch.cat([h_f, h_b], dim=2)
157 |         # output = torch.mean(h, dim=1)  # take the mean over the nodes within a batch -> [B, H]
158 |         # h = [B, max_node, H] -> each node is fed into fc_output
159 | 
160 |         # B, max_node, H -> B, max_node
161 |         output = torch.sigmoid(self.fc_output(h_f))
162 |         if self.opt.extra_aggregate:
163 |             output = torch.bmm(matrix, output)
164 |         return output
165 | 
166 |     @staticmethod
167 |     def average_pooling(data, input_lens):
168 |         B, N, T, H = data.size()
169 |         idx = torch.arange(T, device=data.device).unsqueeze(0).expand(B, N, -1)
170 |         idx = idx < input_lens.unsqueeze(2)
171 |         idx = idx.unsqueeze(3).expand(-1, -1, -1, H)
172 |         ret = (data.float() * idx.float()).sum(2) / (input_lens.unsqueeze(2).float() + 10**-32)
173 |         return ret
174 | 
175 |     @staticmethod
176 |     def convert_to_matrix(batch_size, max_num, m):
177 |         matrix = torch.zeros((batch_size, max_num, max_num), dtype=torch.float, device=m.device)
178 |         m -= 1
179 |         b_select = torch.arange(batch_size).unsqueeze(1).expand(batch_size, m.size(1)).contiguous().view(-1)
180 |         matrix[b_select, m[:, :, 1].contiguous().view(-1), m[:, :, 0].contiguous().view(-1)] = 1
181 |         matrix[:, 0, 0] = 0
182 |         return matrix
183 | 
184 | 
185 | class Gate(nn.Module):
186 |     def __init__(self, in_dim, mem_dim):
187 |         super(Gate, self).__init__()
188 |         self.in_dim = in_dim
189 |         self.mem_dim = mem_dim
190 |         self.ax = nn.Linear(self.in_dim, 3 * self.mem_dim)
191 |         self.ah = nn.Linear(self.mem_dim, 3 * self.mem_dim)
192 |         self.fx = nn.Linear(self.in_dim, self.mem_dim)
193 |         self.fh = nn.Linear(self.mem_dim, self.mem_dim)
194 | 
195 |     def forward(self, inputs, last_h, pred_c):
196 |         iou = self.ax(inputs) + self.ah(last_h)
197 |         i, o, u = torch.split(iou, iou.size(1) // 3, dim=1)
198 |         i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
199 |         f = torch.sigmoid(self.fh(last_h) + self.fx(inputs))
200 |         fc = torch.mul(f, pred_c)
201 |         c = torch.mul(i, u) + fc
202 |         h = torch.mul(o, torch.tanh(c))
203 |         return h, c
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | autopep8
2 | graphviz
3 | pandas
4 | torch
5 | torchtext==0.4.0
6 | tqdm
7 | scikit-learn
8 | astor
9 | anthropic
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Create a new virtual environment named codeflow with Python 3.10
4 | conda create -n codeflow python=3.10 -y
5 | 
6 | # Activate the virtual environment
7 | source activate codeflow
8 | 
9 | # Install the required Python 
libraries 10 | pip install -r requirements.txt -------------------------------------------------------------------------------- /trace_execution.py: -------------------------------------------------------------------------------- 1 | """program/module to trace Python program or function execution 2 | 3 | Sample use, command line: 4 | trace.py -c -f counts --ignore-dir '$prefix' spam.py eggs 5 | trace.py -t --ignore-dir '$prefix' spam.py eggs 6 | trace.py --trackcalls spam.py eggs 7 | 8 | Sample use, programmatically 9 | import sys 10 | 11 | # create a Trace object, telling it what to ignore, and whether to 12 | # do tracing or line-counting or both. 13 | tracer = trace.Trace(ignoredirs=[sys.base_prefix, sys.base_exec_prefix,], 14 | trace=0, count=1) 15 | # run the new command using the given tracer 16 | tracer.run('main()') 17 | # make a report, placing output in /tmp 18 | r = tracer.results() 19 | r.write_results(show_missing=True, coverdir="/tmp") 20 | """ 21 | __all__ = ['Trace', 'CoverageResults'] 22 | 23 | import io 24 | import linecache 25 | import os 26 | import sys 27 | import sysconfig 28 | import token 29 | import tokenize 30 | import inspect 31 | import gc 32 | import dis 33 | import pickle 34 | from time import monotonic as _time 35 | 36 | import threading 37 | 38 | PRAGMA_NOCOVER = "#pragma NO COVER" 39 | 40 | class _Ignore: 41 | def __init__(self, modules=None, dirs=None): 42 | self._mods = set() if not modules else set(modules) 43 | self._dirs = [] if not dirs else [os.path.normpath(d) 44 | for d in dirs] 45 | self._ignore = { '': 1 } 46 | 47 | def names(self, filename, modulename): 48 | if modulename in self._ignore: 49 | return self._ignore[modulename] 50 | 51 | # haven't seen this one before, so see if the module name is 52 | # on the ignore list. 53 | if modulename in self._mods: # Identical names, so ignore 54 | self._ignore[modulename] = 1 55 | return 1 56 | 57 | # check if the module is a proper submodule of something on 58 | # the ignore list 59 | for mod in self._mods: 60 | # Need to take some care since ignoring 61 | # "cmp" mustn't mean ignoring "cmpcache" but ignoring 62 | # "Spam" must also mean ignoring "Spam.Eggs". 63 | if modulename.startswith(mod + '.'): 64 | self._ignore[modulename] = 1 65 | return 1 66 | 67 | # Now check that filename isn't in one of the directories 68 | if filename is None: 69 | # must be a built-in, so we must ignore 70 | self._ignore[modulename] = 1 71 | return 1 72 | 73 | # Ignore a file when it contains one of the ignorable paths 74 | for d in self._dirs: 75 | # The '+ os.sep' is to ensure that d is a parent directory, 76 | # as compared to cases like: 77 | # d = "/usr/local" 78 | # filename = "/usr/local.py" 79 | # or 80 | # d = "/usr/local.py" 81 | # filename = "/usr/local.py" 82 | if filename.startswith(d + os.sep): 83 | self._ignore[modulename] = 1 84 | return 1 85 | 86 | # Tried the different ways, so we don't ignore this module 87 | self._ignore[modulename] = 0 88 | return 0 89 | 90 | def _modname(path): 91 | """Return a plausible module name for the path.""" 92 | 93 | base = os.path.basename(path) 94 | filename, ext = os.path.splitext(base) 95 | return filename 96 | 97 | def _fullmodname(path): 98 | """Return a plausible module name for the path.""" 99 | 100 | # If the file 'path' is part of a package, then the filename isn't 101 | # enough to uniquely identify it. Try to do the right thing by 102 | # looking in sys.path for the longest matching prefix. We'll 103 | # assume that the rest is the package name. 
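# For illustration (editorial example, assuming "/usr/lib/python3.10" is on sys.path):
#     _fullmodname("/usr/lib/python3.10/json/decoder.py") -> "json.decoder"
# whereas _modname() above would return just "decoder".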
104 | 105 | comparepath = os.path.normcase(path) 106 | longest = "" 107 | for dir in sys.path: 108 | dir = os.path.normcase(dir) 109 | if comparepath.startswith(dir) and comparepath[len(dir)] == os.sep: 110 | if len(dir) > len(longest): 111 | longest = dir 112 | 113 | if longest: 114 | base = path[len(longest) + 1:] 115 | else: 116 | base = path 117 | # the drive letter is never part of the module name 118 | drive, base = os.path.splitdrive(base) 119 | base = base.replace(os.sep, ".") 120 | if os.altsep: 121 | base = base.replace(os.altsep, ".") 122 | filename, ext = os.path.splitext(base) 123 | return filename.lstrip(".") 124 | 125 | class CoverageResults: 126 | def __init__(self, counts=None, calledfuncs=None, infile=None, 127 | callers=None, outfile=None): 128 | self.counts = counts 129 | if self.counts is None: 130 | self.counts = {} 131 | self.counter = self.counts.copy() # map (filename, lineno) to count 132 | self.calledfuncs = calledfuncs 133 | if self.calledfuncs is None: 134 | self.calledfuncs = {} 135 | self.calledfuncs = self.calledfuncs.copy() 136 | self.callers = callers 137 | if self.callers is None: 138 | self.callers = {} 139 | self.callers = self.callers.copy() 140 | self.infile = infile 141 | self.outfile = outfile 142 | if self.infile: 143 | # Try to merge existing counts file. 144 | try: 145 | with open(self.infile, 'rb') as f: 146 | counts, calledfuncs, callers = pickle.load(f) 147 | self.update(self.__class__(counts, calledfuncs, callers=callers)) 148 | except (OSError, EOFError, ValueError) as err: 149 | print(("Skipping counts file %r: %s" 150 | % (self.infile, err)), file=sys.stderr) 151 | 152 | def is_ignored_filename(self, filename): 153 | """Return True if the filename does not refer to a file 154 | we want to have reported. 155 | """ 156 | return filename.startswith('<') and filename.endswith('>') 157 | 158 | def update(self, other): 159 | """Merge in the data from another CoverageResults""" 160 | counts = self.counts 161 | calledfuncs = self.calledfuncs 162 | callers = self.callers 163 | other_counts = other.counts 164 | other_calledfuncs = other.calledfuncs 165 | other_callers = other.callers 166 | 167 | for key in other_counts: 168 | counts[key] = counts.get(key, 0) + other_counts[key] 169 | 170 | for key in other_calledfuncs: 171 | calledfuncs[key] = 1 172 | 173 | for key in other_callers: 174 | callers[key] = 1 175 | 176 | def write_results(self, show_missing=True, summary=False, coverdir=None): 177 | """ 178 | Write the coverage results. 179 | 180 | :param show_missing: Show lines that had no hits. 181 | :param summary: Include coverage summary per module. 182 | :param coverdir: If None, the results of each module are placed in its 183 | directory, otherwise it is included in the directory 184 | specified. 
185 | """ 186 | if self.calledfuncs: 187 | print() 188 | print("functions called:") 189 | calls = self.calledfuncs 190 | for filename, modulename, funcname in sorted(calls): 191 | print(("filename: %s, modulename: %s, funcname: %s" 192 | % (filename, modulename, funcname))) 193 | 194 | if self.callers: 195 | print() 196 | print("calling relationships:") 197 | lastfile = lastcfile = "" 198 | for ((pfile, pmod, pfunc), (cfile, cmod, cfunc)) \ 199 | in sorted(self.callers): 200 | if pfile != lastfile: 201 | print() 202 | print("***", pfile, "***") 203 | lastfile = pfile 204 | lastcfile = "" 205 | if cfile != pfile and lastcfile != cfile: 206 | print(" -->", cfile) 207 | lastcfile = cfile 208 | print(" %s.%s -> %s.%s" % (pmod, pfunc, cmod, cfunc)) 209 | 210 | # turn the counts data ("(filename, lineno) = count") into something 211 | # accessible on a per-file basis 212 | per_file = {} 213 | for filename, lineno in self.counts: 214 | lines_hit = per_file[filename] = per_file.get(filename, {}) 215 | lines_hit[lineno] = self.counts[(filename, lineno)] 216 | 217 | # accumulate summary info, if needed 218 | sums = {} 219 | 220 | for filename, count in per_file.items(): 221 | if self.is_ignored_filename(filename): 222 | continue 223 | 224 | if filename.endswith(".pyc"): 225 | filename = filename[:-1] 226 | 227 | if coverdir is None: 228 | dir = os.path.dirname(os.path.abspath(filename)) 229 | modulename = _modname(filename) 230 | else: 231 | dir = coverdir 232 | os.makedirs(dir, exist_ok=True) 233 | modulename = _fullmodname(filename) 234 | 235 | # If desired, get a list of the line numbers which represent 236 | # executable content (returned as a dict for better lookup speed) 237 | if show_missing: 238 | lnotab = _find_executable_linenos(filename) 239 | else: 240 | lnotab = {} 241 | source = linecache.getlines(filename) 242 | coverpath = os.path.join(dir, modulename + ".cover") 243 | with open(filename, 'rb') as fp: 244 | encoding, _ = tokenize.detect_encoding(fp.readline) 245 | n_hits, n_lines = self.write_results_file(coverpath, source, 246 | lnotab, count, encoding) 247 | if summary and n_lines: 248 | percent = int(100 * n_hits / n_lines) 249 | sums[modulename] = n_lines, percent, modulename, filename 250 | 251 | 252 | if summary and sums: 253 | print("lines cov% module (path)") 254 | for m in sorted(sums): 255 | n_lines, percent, modulename, filename = sums[m] 256 | print("%5d %3d%% %s (%s)" % sums[m]) 257 | 258 | if self.outfile: 259 | # try and store counts and module info into self.outfile 260 | try: 261 | with open(self.outfile, 'wb') as f: 262 | pickle.dump((self.counts, self.calledfuncs, self.callers), 263 | f, 1) 264 | except OSError as err: 265 | print("Can't save counts files because %s" % err, file=sys.stderr) 266 | 267 | def write_results_file(self, path, lines, lnotab, lines_hit, encoding=None): 268 | """Return a coverage results file in path.""" 269 | # ``lnotab`` is a dict of executable lines, or a line number "table" 270 | 271 | try: 272 | outfile = open(path, "w", encoding=encoding) 273 | except OSError as err: 274 | print(("trace: Could not open %r for writing: %s " 275 | "- skipping" % (path, err)), file=sys.stderr) 276 | return 0, 0 277 | 278 | n_lines = 0 279 | n_hits = 0 280 | with outfile: 281 | for lineno, line in enumerate(lines, 1): 282 | # do the blank/comment match to try to mark more lines 283 | # (help the reader find stuff that hasn't been covered) 284 | if lineno in lines_hit: 285 | outfile.write("%5d: " % lines_hit[lineno]) 286 | n_hits += 1 287 | n_lines += 1 
288 | elif lineno in lnotab and not PRAGMA_NOCOVER in line: 289 | # Highlight never-executed lines, unless the line contains 290 | # #pragma: NO COVER 291 | outfile.write(">>>>>> ") 292 | n_lines += 1 293 | else: 294 | outfile.write(" ") 295 | outfile.write(line.expandtabs(8)) 296 | 297 | return n_hits, n_lines 298 | 299 | def _find_lines_from_code(code, strs): 300 | """Return dict where keys are lines in the line number table.""" 301 | linenos = {} 302 | 303 | for _, lineno in dis.findlinestarts(code): 304 | if lineno not in strs: 305 | linenos[lineno] = 1 306 | 307 | return linenos 308 | 309 | def _find_lines(code, strs): 310 | """Return lineno dict for all code objects reachable from code.""" 311 | # get all of the lineno information from the code of this scope level 312 | linenos = _find_lines_from_code(code, strs) 313 | 314 | # and check the constants for references to other code objects 315 | for c in code.co_consts: 316 | if inspect.iscode(c): 317 | # find another code object, so recurse into it 318 | linenos.update(_find_lines(c, strs)) 319 | return linenos 320 | 321 | def _find_strings(filename, encoding=None): 322 | """Return a dict of possible docstring positions. 323 | 324 | The dict maps line numbers to strings. There is an entry for 325 | line that contains only a string or a part of a triple-quoted 326 | string. 327 | """ 328 | d = {} 329 | # If the first token is a string, then it's the module docstring. 330 | # Add this special case so that the test in the loop passes. 331 | prev_ttype = token.INDENT 332 | with open(filename, encoding=encoding) as f: 333 | tok = tokenize.generate_tokens(f.readline) 334 | for ttype, tstr, start, end, line in tok: 335 | if ttype == token.STRING: 336 | if prev_ttype == token.INDENT: 337 | sline, scol = start 338 | eline, ecol = end 339 | for i in range(sline, eline + 1): 340 | d[i] = 1 341 | prev_ttype = ttype 342 | return d 343 | 344 | def _find_executable_linenos(filename): 345 | """Return dict where keys are line numbers in the line number table.""" 346 | try: 347 | with tokenize.open(filename) as f: 348 | prog = f.read() 349 | encoding = f.encoding 350 | except OSError as err: 351 | print(("Not printing coverage data for %r: %s" 352 | % (filename, err)), file=sys.stderr) 353 | return {} 354 | code = compile(prog, filename, "exec") 355 | strs = _find_strings(filename, encoding) 356 | return _find_lines(code, strs) 357 | 358 | class Trace: 359 | def __init__(self, count=1, trace=1, countfuncs=0, countcallers=0, 360 | ignoremods=(), ignoredirs=(), infile=None, outfile=None, 361 | timing=False): 362 | """ 363 | @param count true iff it should count number of times each 364 | line is executed 365 | @param trace true iff it should print out each line that is 366 | being counted 367 | @param countfuncs true iff it should just output a list of 368 | (filename, modulename, funcname,) for functions 369 | that were called at least once; This overrides 370 | `count' and `trace' 371 | @param ignoremods a list of the names of modules to ignore 372 | @param ignoredirs a list of the names of directories to ignore 373 | all of the (recursive) contents of 374 | @param infile file from which to read stored counts to be 375 | added into the results 376 | @param outfile file in which to write the results 377 | @param timing true iff timing information be displayed 378 | """ 379 | self.exe_path = [] 380 | self.infile = infile 381 | self.outfile = outfile 382 | self.ignore = _Ignore(ignoremods, ignoredirs) 383 | self.counts = {} # keys are (filename, 
linenumber) 384 | self.pathtobasename = {} # for memoizing os.path.basename 385 | self.donothing = 0 386 | self.trace = trace 387 | self._calledfuncs = {} 388 | self._callers = {} 389 | self._caller_cache = {} 390 | self.start_time = None 391 | if timing: 392 | self.start_time = _time() 393 | if countcallers: 394 | self.globaltrace = self.globaltrace_trackcallers 395 | elif countfuncs: 396 | self.globaltrace = self.globaltrace_countfuncs 397 | elif trace and count: 398 | self.globaltrace = self.globaltrace_lt 399 | self.localtrace = self.localtrace_trace_and_count 400 | elif trace: 401 | self.globaltrace = self.globaltrace_lt 402 | self.localtrace = self.localtrace_trace 403 | elif count: 404 | self.globaltrace = self.globaltrace_lt 405 | self.localtrace = self.localtrace_count 406 | else: 407 | # Ahem -- do nothing? Okay. 408 | self.donothing = 1 409 | 410 | def run(self, cmd): 411 | import __main__ 412 | dict = __main__.__dict__ 413 | self.runctx(cmd, dict, dict) 414 | 415 | def runctx(self, cmd, globals=None, locals=None): 416 | if globals is None: globals = {} 417 | if locals is None: locals = {} 418 | if not self.donothing: 419 | threading.settrace(self.globaltrace) 420 | sys.settrace(self.globaltrace) 421 | try: 422 | exec(cmd, globals, locals) 423 | finally: 424 | if not self.donothing: 425 | sys.settrace(None) 426 | threading.settrace(None) 427 | 428 | def runfunc(self, func, /, *args, **kw): 429 | result = None 430 | if not self.donothing: 431 | sys.settrace(self.globaltrace) 432 | try: 433 | result = func(*args, **kw) 434 | finally: 435 | if not self.donothing: 436 | sys.settrace(None) 437 | return result 438 | 439 | def file_module_function_of(self, frame): 440 | code = frame.f_code 441 | filename = code.co_filename 442 | if filename: 443 | modulename = _modname(filename) 444 | else: 445 | modulename = None 446 | 447 | funcname = code.co_name 448 | clsname = None 449 | if code in self._caller_cache: 450 | if self._caller_cache[code] is not None: 451 | clsname = self._caller_cache[code] 452 | else: 453 | self._caller_cache[code] = None 454 | ## use of gc.get_referrers() was suggested by Michael Hudson 455 | # all functions which refer to this code object 456 | funcs = [f for f in gc.get_referrers(code) 457 | if inspect.isfunction(f)] 458 | # require len(func) == 1 to avoid ambiguity caused by calls to 459 | # new.function(): "In the face of ambiguity, refuse the 460 | # temptation to guess." 461 | if len(funcs) == 1: 462 | dicts = [d for d in gc.get_referrers(funcs[0]) 463 | if isinstance(d, dict)] 464 | if len(dicts) == 1: 465 | classes = [c for c in gc.get_referrers(dicts[0]) 466 | if hasattr(c, "__bases__")] 467 | if len(classes) == 1: 468 | # ditto for new.classobj() 469 | clsname = classes[0].__name__ 470 | # cache the result - assumption is that new.* is 471 | # not called later to disturb this relationship 472 | # _caller_cache could be flushed if functions in 473 | # the new module get called. 474 | self._caller_cache[code] = clsname 475 | if clsname is not None: 476 | funcname = "%s.%s" % (clsname, funcname) 477 | 478 | return filename, modulename, funcname 479 | 480 | def globaltrace_trackcallers(self, frame, why, arg): 481 | """Handler for call events. 482 | 483 | Adds information about who called who to the self._callers dict. 
484 | """ 485 | if why == 'call': 486 | # XXX Should do a better job of identifying methods 487 | this_func = self.file_module_function_of(frame) 488 | parent_func = self.file_module_function_of(frame.f_back) 489 | self._callers[(parent_func, this_func)] = 1 490 | 491 | def globaltrace_countfuncs(self, frame, why, arg): 492 | """Handler for call events. 493 | 494 | Adds (filename, modulename, funcname) to the self._calledfuncs dict. 495 | """ 496 | if why == 'call': 497 | this_func = self.file_module_function_of(frame) 498 | self._calledfuncs[this_func] = 1 499 | 500 | def globaltrace_lt(self, frame, why, arg): 501 | """Handler for call events. 502 | 503 | If the code block being entered is to be ignored, returns `None', 504 | else returns self.localtrace. 505 | """ 506 | if why == 'call': 507 | code = frame.f_code 508 | filename = frame.f_globals.get('__file__', None) 509 | if filename: 510 | # XXX _modname() doesn't work right for packages, so 511 | # the ignore support won't work right for packages 512 | modulename = _modname(filename) 513 | if modulename is not None: 514 | ignore_it = self.ignore.names(filename, modulename) 515 | if not ignore_it: 516 | if self.trace: 517 | print((" --- modulename: %s, funcname: %s" 518 | % (modulename, code.co_name))) 519 | return self.localtrace 520 | else: 521 | return None 522 | 523 | def localtrace_trace_and_count(self, frame, why, arg): 524 | if why == "line": 525 | # record the file name and line number of every trace 526 | filename = frame.f_code.co_filename 527 | lineno = frame.f_lineno 528 | key = filename, lineno 529 | #Cuong 530 | print(key) 531 | self.counts[key] = self.counts.get(key, 0) + 1 532 | 533 | if self.start_time: 534 | print('%.2f' % (_time() - self.start_time), end=' ') 535 | bname = os.path.basename(filename) 536 | line = linecache.getline(filename, lineno) 537 | print("%s(%d)" % (bname, lineno), end='') 538 | if line: 539 | print(": ", line, end='') 540 | else: 541 | print() 542 | return self.localtrace 543 | 544 | def localtrace_trace(self, frame, why, arg): 545 | if why == "line": 546 | # record the file name and line number of every trace 547 | filename = frame.f_code.co_filename 548 | lineno = frame.f_lineno 549 | 550 | if self.start_time: 551 | print('%.2f' % (_time() - self.start_time), end=' ') 552 | bname = os.path.basename(filename) 553 | line = linecache.getline(filename, lineno) 554 | print("%s(%d)" % (bname, lineno), end='') 555 | if line: 556 | print(": ", line, end='') 557 | else: 558 | print() 559 | return self.localtrace 560 | 561 | def localtrace_count(self, frame, why, arg): 562 | if why == "line": 563 | filename = frame.f_code.co_filename 564 | lineno = frame.f_lineno 565 | key = filename, lineno 566 | self.exe_path.append(lineno) 567 | self.counts[key] = self.counts.get(key, 0) + 1 568 | return self.localtrace 569 | 570 | def results(self): 571 | return CoverageResults(self.counts, infile=self.infile, 572 | outfile=self.outfile, 573 | calledfuncs=self._calledfuncs, 574 | callers=self._callers) 575 | 576 | def main(): 577 | import argparse 578 | 579 | parser = argparse.ArgumentParser() 580 | parser.add_argument('--version', action='version', version='trace 2.0') 581 | 582 | grp = parser.add_argument_group('Main options', 583 | 'One of these (or --report) must be given') 584 | 585 | grp.add_argument('-c', '--count', action='store_true', 586 | help='Count the number of times each line is executed and write ' 587 | 'the counts to .cover for each module executed, in ' 588 | 'the module\'s directory. 
See also --coverdir, --file, ' 589 | '--no-report below.') 590 | grp.add_argument('-t', '--trace', action='store_true', 591 | help='Print each line to sys.stdout before it is executed') 592 | grp.add_argument('-l', '--listfuncs', action='store_true', 593 | help='Keep track of which functions are executed at least once ' 594 | 'and write the results to sys.stdout after the program exits. ' 595 | 'Cannot be specified alongside --trace or --count.') 596 | grp.add_argument('-T', '--trackcalls', action='store_true', 597 | help='Keep track of caller/called pairs and write the results to ' 598 | 'sys.stdout after the program exits.') 599 | 600 | grp = parser.add_argument_group('Modifiers') 601 | 602 | _grp = grp.add_mutually_exclusive_group() 603 | _grp.add_argument('-r', '--report', action='store_true', 604 | help='Generate a report from a counts file; does not execute any ' 605 | 'code. --file must specify the results file to read, which ' 606 | 'must have been created in a previous run with --count ' 607 | '--file=FILE') 608 | _grp.add_argument('-R', '--no-report', action='store_true', 609 | help='Do not generate the coverage report files. ' 610 | 'Useful if you want to accumulate over several runs.') 611 | 612 | grp.add_argument('-f', '--file', 613 | help='File to accumulate counts over several runs') 614 | grp.add_argument('-C', '--coverdir', 615 | help='Directory where the report files go. The coverage report ' 616 | 'for . will be written to file ' 617 | '//.cover') 618 | grp.add_argument('-m', '--missing', action='store_true', 619 | help='Annotate executable lines that were not executed with ' 620 | '">>>>>> "') 621 | grp.add_argument('-s', '--summary', action='store_true', 622 | help='Write a brief summary for each file to sys.stdout. ' 623 | 'Can only be used with --count or --report') 624 | grp.add_argument('-g', '--timing', action='store_true', 625 | help='Prefix each line with the time since the program started. ' 626 | 'Only used while tracing') 627 | 628 | grp = parser.add_argument_group('Filters', 629 | 'Can be specified multiple times') 630 | grp.add_argument('--ignore-module', action='append', default=[], 631 | help='Ignore the given module(s) and its submodules ' 632 | '(if it is a package). Accepts comma separated list of ' 633 | 'module names.') 634 | grp.add_argument('--ignore-dir', action='append', default=[], 635 | help='Ignore files in the given directory ' 636 | '(multiple directories can be joined by os.pathsep).') 637 | 638 | parser.add_argument('--module', action='store_true', default=False, 639 | help='Trace a module. 
') 640 | parser.add_argument('progname', nargs='?', 641 | help='file to run as main program') 642 | parser.add_argument('arguments', nargs=argparse.REMAINDER, 643 | help='arguments to the program') 644 | 645 | opts = parser.parse_args() 646 | 647 | if opts.ignore_dir: 648 | _prefix = sysconfig.get_path("stdlib") 649 | _exec_prefix = sysconfig.get_path("platstdlib") 650 | 651 | def parse_ignore_dir(s): 652 | s = os.path.expanduser(os.path.expandvars(s)) 653 | s = s.replace('$prefix', _prefix).replace('$exec_prefix', _exec_prefix) 654 | return os.path.normpath(s) 655 | 656 | opts.ignore_module = [mod.strip() 657 | for i in opts.ignore_module for mod in i.split(',')] 658 | opts.ignore_dir = [parse_ignore_dir(s) 659 | for i in opts.ignore_dir for s in i.split(os.pathsep)] 660 | 661 | if opts.report: 662 | if not opts.file: 663 | parser.error('-r/--report requires -f/--file') 664 | results = CoverageResults(infile=opts.file, outfile=opts.file) 665 | return results.write_results(opts.missing, opts.summary, opts.coverdir) 666 | 667 | if not any([opts.trace, opts.count, opts.listfuncs, opts.trackcalls]): 668 | parser.error('must specify one of --trace, --count, --report, ' 669 | '--listfuncs, or --trackcalls') 670 | 671 | if opts.listfuncs and (opts.count or opts.trace): 672 | parser.error('cannot specify both --listfuncs and (--trace or --count)') 673 | 674 | if opts.summary and not opts.count: 675 | parser.error('--summary can only be used with --count or --report') 676 | 677 | if opts.progname is None: 678 | parser.error('progname is missing: required with the main options') 679 | 680 | t = Trace(opts.count, opts.trace, countfuncs=opts.listfuncs, 681 | countcallers=opts.trackcalls, ignoremods=opts.ignore_module, 682 | ignoredirs=opts.ignore_dir, infile=opts.file, 683 | outfile=opts.file, timing=opts.timing) 684 | try: 685 | if opts.module: 686 | import runpy 687 | module_name = opts.progname 688 | mod_name, mod_spec, code = runpy._get_module_details(module_name) 689 | sys.argv = [code.co_filename, *opts.arguments] 690 | globs = { 691 | '__name__': '__main__', 692 | '__file__': code.co_filename, 693 | '__package__': mod_spec.parent, 694 | '__loader__': mod_spec.loader, 695 | '__spec__': mod_spec, 696 | '__cached__': None, 697 | } 698 | else: 699 | sys.argv = [opts.progname, *opts.arguments] 700 | sys.path[0] = os.path.dirname(opts.progname) 701 | 702 | with io.open_code(opts.progname) as fp: 703 | code = compile(fp.read(), opts.progname, 'exec') 704 | # try to emulate __main__ namespace as much as possible 705 | globs = { 706 | '__file__': opts.progname, 707 | '__name__': '__main__', 708 | '__package__': None, 709 | '__cached__': None, 710 | } 711 | t.runctx(code, globs, globs) 712 | except OSError as err: 713 | sys.exit("Cannot run file %r because: %s" % (sys.argv[0], err)) 714 | except SystemExit: 715 | pass 716 | 717 | results = t.results() 718 | 719 | if not opts.no_report: 720 | results.write_results(opts.missing, opts.summary, opts.coverdir) 721 | 722 | if __name__=='__main__': 723 | main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sklearn.metrics import precision_recall_fscore_support 3 | import warnings 4 | import numpy as np 5 | 6 | warnings.simplefilter("ignore") 7 | 8 | def write(content, file): 9 | with open(file, 'a') as f: 10 | f.write(content + '\n') 11 | 12 | def pad_targets(target, max_node): 13 | batch_size, 
current_max_node = target.shape
14 |     padded_target = torch.zeros(batch_size, max_node, device=target.device)
15 |     padded_target[:, :current_max_node] = target
16 |     return padded_target
17 | 
18 | def calculate_metrics(y_true, y_pred):
19 |     precisions, recalls, fscores = [], [], []
20 |     for i in range(y_true.shape[1]):
21 |         precision, recall, fscore, _ = precision_recall_fscore_support(y_true[:, i], y_pred[:, i], average='binary')
22 |         precisions.append(precision)
23 |         recalls.append(recall)
24 |         fscores.append(fscore)
25 |     return sum(precisions) / len(precisions), sum(recalls) / len(recalls), sum(fscores) / len(fscores)
26 | 
27 | # def accuracy_whole_list(y_true, y_pred):
28 | #     correct = (y_true == y_pred).all(axis=1).sum().item()
29 | #     return correct / y_true.shape[0]
30 | 
31 | def accuracy_whole_list(y_true, y_pred, lengths):
32 |     # returns the number of exact-match samples; callers divide by the sample count
33 |     correct = 0
34 |     for i in range(len(lengths)):
35 |         length = lengths[i]
36 |         if np.array_equal(y_true[i][:length], y_pred[i][:length]):
37 |             correct += 1
38 |     return correct
--------------------------------------------------------------------------------
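`accuracy_whole_list` counts samples whose predicted coverage vector matches the target exactly on the first `length` nodes, ignoring padding; `train()` in main.py then divides the count by the number of samples to report accuracy. A quick illustration on toy arrays:

```python
import numpy as np
from utils import accuracy_whole_list

y_true = [np.array([1, 1, 0, 0]), np.array([1, 0, 1, 0])]
y_pred = [np.array([1, 1, 1, 0]), np.array([1, 0, 1, 1])]
lengths = [2, 3]  # only the first `length` entries of each sample are compared

correct = accuracy_whole_list(y_true, y_pred, lengths)
print(correct / len(lengths))  # 1.0: both samples match within their lengths
```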