├── README.md
├── cfg.py
├── config.py
├── data.py
├── data
│   ├── CodeNet_test.csv
│   ├── CodeNet_train.csv
│   ├── FixEval_complete_test.csv
│   ├── FixEval_complete_train.csv
│   ├── FixEval_incomplete_test.csv
│   └── FixEval_incomplete_train.csv
├── fuzz_testing.py
├── fuzz_testing_dataset
│   ├── code_1.py
│   ├── code_10.py
│   ├── code_11.py
│   ├── code_12.py
│   ├── code_13.py
│   ├── code_14.py
│   ├── code_15.py
│   ├── code_16.py
│   ├── code_17.py
│   ├── code_18.py
│   ├── code_19.py
│   ├── code_2.py
│   ├── code_20.py
│   ├── code_21.py
│   ├── code_22.py
│   ├── code_23.py
│   ├── code_24.py
│   ├── code_25.py
│   ├── code_26.py
│   ├── code_27.py
│   ├── code_28.py
│   ├── code_29.py
│   ├── code_3.py
│   ├── code_30.py
│   ├── code_31.py
│   ├── code_32.py
│   ├── code_33.py
│   ├── code_34.py
│   ├── code_35.py
│   ├── code_36.py
│   ├── code_37.py
│   ├── code_38.py
│   ├── code_39.py
│   ├── code_4.py
│   ├── code_40.py
│   ├── code_41.py
│   ├── code_42.py
│   ├── code_43.py
│   ├── code_44.py
│   ├── code_45.py
│   ├── code_46.py
│   ├── code_47.py
│   ├── code_48.py
│   ├── code_49.py
│   ├── code_5.py
│   ├── code_50.py
│   ├── code_6.py
│   ├── code_7.py
│   ├── code_8.py
│   └── code_9.py
├── generate_dataset
│   ├── cfg.py
│   ├── generate_dataset.py
│   └── trace_execution.py
├── img
│   ├── architecture.png
│   └── pipeline.png
├── main.py
├── model.py
├── requirements.txt
├── setup.sh
├── trace_execution.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # CodeFlow: Predicting Program Behavior with Dynamic Dependencies Learning
4 | [arXiv:2408.02816](https://arxiv.org/abs/2408.02816)
5 |
6 |
7 |
8 | ## Introduction
9 |
10 | We introduce **CodeFlow**, a novel machine learning-based approach designed to predict program behavior by learning both static and dynamic dependencies within the code. CodeFlow constructs control flow graphs (CFGs) to represent all possible execution paths and uses these graphs to predict code coverage and detect runtime errors. Our empirical evaluation demonstrates that CodeFlow significantly improves code coverage prediction accuracy and effectively localizes runtime errors, outperforming state-of-the-art models.
11 |
12 | ### Paper: [CodeFlow: Predicting Program Behavior with Dynamic Dependencies Learning](https://arxiv.org/abs/2408.02816)
13 |
14 | ## Installation
15 |
16 | To set up the environment and install the necessary libraries, run the following command:
17 |
18 | ```sh
19 | ./setup.sh
20 | ```
21 |
22 | ## Architecture
23 |
24 | ![CodeFlow architecture](img/architecture.png)
25 | CodeFlow consists of several key components:
26 | 1. **CFG Building**: Constructs CFGs from the source code.
27 | 2. **Source Code Representation Learning**: Learns vector representations of CFG nodes.
28 | 3. **Dynamic Dependencies Learning**: Captures dynamic dependencies among statements using execution traces.
29 | 4. **Code Coverage Prediction**: Classifies nodes for code coverage using learned embeddings.
30 | 5. **Runtime Error Detection and Localization**: Detects and localizes runtime errors by analyzing code coverage continuity within CFGs.
31 |
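As a rough illustration of component 1, a control flow graph can be derived with Python's `ast` module by giving each statement a node and adding fall-through plus branch edges. The toy sketch below is not the repository's `cfg.py` (which also handles loops, function subgraphs, back edges, and Graphviz rendering); it only shows the idea:

```python
import ast

def build_cfg(source: str):
    """Toy CFG builder: one node per statement (keyed by line number),
    sequential fall-through edges, and a branch edge for each `if`."""
    tree = ast.parse(source)
    nodes, edges = {}, set()

    def walk(body, follow):
        # Link the statements in `body` in order; the last one falls
        # through to `follow` (None means the program exits there).
        for i, stmt in enumerate(body):
            nodes[stmt.lineno] = ast.get_source_segment(source, stmt).split("\n")[0]
            nxt = body[i + 1].lineno if i + 1 < len(body) else follow
            if isinstance(stmt, ast.If):
                edges.add((stmt.lineno, stmt.body[0].lineno))  # condition True
                if nxt is not None:
                    edges.add((stmt.lineno, nxt))              # condition False
                walk(stmt.body, nxt)
            elif nxt is not None:
                edges.add((stmt.lineno, nxt))

    walk(tree.body, None)
    return nodes, edges

src = "x = 1\nif x > 0:\n    y = 2\nz = 3"
nodes, edges = build_cfg(src)
print(sorted(edges))  # [(1, 2), (2, 3), (2, 4), (3, 4)]
```

The branch edges `(2, 3)` and `(2, 4)` represent the two possible outcomes of the `if` test; predicting which edges an execution actually follows is exactly the code coverage task of components 3 and 4.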
32 | ## Usage
33 |
34 | ### Running CodeFlow Model
35 |
36 | To run the CodeFlow model, use the following command:
37 |
38 | ```sh
39 | python main.py --data <dataset> [--runtime_detection] [--bug_localization]
40 | ```
41 |
42 | #### Configuration Options
43 |
44 | - `--data`: Specify the dataset to be used for training. Options:
45 | - `CodeNet`: Train with only non-buggy Python code from the CodeNet dataset.
46 |   - `FixEval_complete`: Train with both non-buggy and buggy code from the FixEval and CodeNet datasets.
47 | - `FixEval_incomplete`: Train with the incomplete version of the FixEval_complete dataset.
48 |
49 | - `--runtime_detection`: Validate the Runtime Error Detection.
50 |
51 | - `--bug_localization`: Validate the Bug Localization in buggy code.
52 |
53 | #### Example Usage
54 |
55 | 1. **Training with the CodeNet dataset (RQ1):**
56 | ```sh
57 | python main.py --data CodeNet
58 | ```
59 |
60 | 2. **Training with the complete FixEval dataset and validating Runtime Error Detection (RQ2):**
61 | ```sh
62 | python main.py --data FixEval_complete --runtime_detection
63 | ```
64 |
65 | 3. **Training with the complete and incomplete FixEval datasets and validating Bug Localization (RQ3):**
66 | ```sh
67 | python main.py --data FixEval_complete --bug_localization
68 | python main.py --data FixEval_incomplete --bug_localization
69 | ```
70 | ### Fuzz Testing with LLM Integration (RQ4)
71 |
72 | After training CodeFlow and saving the corresponding checkpoint, you can utilize it for fuzz testing by integrating it with a Large Language Model (LLM). Use the following command:
73 |
74 | ```sh
75 | python fuzz_testing.py --checkpoint <checkpoint> --epoch <epoch> --time <seconds> --claude_api_key <your_api_key> --model <model_name>
76 | ```
77 | - `--checkpoint`: The checkpoint to load.
78 | - `--epoch`: The epoch of the chosen checkpoint.
79 | - `--time`: Time in seconds to run fuzz testing on each code file.
80 | - `--claude_api_key`: Your API key for Claude.
81 | - `--model`: The Claude model to use; defaults to `claude-3-5-sonnet-20240620`.
82 | #### Example
83 | ```sh
84 | python fuzz_testing.py --checkpoint 1 --epoch 600 --time 120 --claude_api_key YOUR_API_KEY --model claude-3-5-sonnet-20240620
85 | ```
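Conceptually, the fuzzing stage runs a time-budgeted loop per code file: generate an input, execute the target, and keep any input that triggers a runtime error. The sketch below uses a naive random mutator in place of the model- and Claude-guided input generation that `fuzz_testing.py` performs, and the `buggy` target is an illustrative stand-in:

```python
import random
import time

def fuzz(target, seed_inputs, budget_seconds):
    """Run `target` on mutated inputs until the time budget expires,
    collecting inputs that raise an exception. Illustrative stand-in:
    the repository instead asks the trained CodeFlow model and Claude
    for promising inputs."""
    crashes = []
    corpus = list(seed_inputs)
    deadline = time.monotonic() + budget_seconds
    while time.monotonic() < deadline:
        value = random.choice(corpus) + random.randint(-10, 10)  # naive mutation
        try:
            target(value)
        except Exception:
            crashes.append(value)  # candidate runtime error found
        corpus.append(value)
    return crashes

def buggy(n):
    return 100 // (n - 7)  # raises ZeroDivisionError when n == 7

found = fuzz(buggy, seed_inputs=[0], budget_seconds=0.2)
print(f"{len(found)} crashing inputs found")
```

The `--time` flag above plays the role of `budget_seconds` here, bounding how long each code file is fuzzed.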
86 | ### Generating Your Own Dataset
87 |
88 | To generate your own dataset, including CFG, forward and backward edges, and the true execution trace as ground truth for your Python code, follow these steps:
89 |
90 | 1. **Navigate to the `generate_dataset` folder**:
91 | ```sh
92 | cd generate_dataset
93 | ```
94 |
95 | 2. **Place your Python code files in the `dataset` folder**.
96 |
97 | 3. **Run the dataset generation script**:
98 | ```sh
99 | python generate_dataset.py
100 | ```
101 | To build and visualize the CFG for a single Python file, use this command:
102 | ```sh
103 | python cfg.py <path_to_python_file>
104 | ```
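The ground-truth trace that `generate_dataset.py` records (via `trace_execution.py`, adapted from CPython's `trace` module) is essentially the sequence of executed line numbers. A minimal, self-contained sketch of that idea using `sys.settrace` — not the repository's implementation, which also handles programs that crash mid-run:

```python
import sys

def record_line_trace(fn, *args):
    """Return the line numbers executed inside `fn`, relative to its
    `def` line. Minimal stand-in for trace_execution.py."""
    executed = []
    code = fn.__code__

    def tracer(frame, event, arg):
        if frame.f_code is code and event == "line":
            executed.append(frame.f_lineno - code.co_firstlineno)
        return tracer

    sys.settrace(tracer)
    try:
        fn(*args)
    finally:
        sys.settrace(None)  # always uninstall the hook
    return executed

def sample(n):
    if n > 0:       # relative line 1
        n -= 1      # relative line 2
    return n        # relative line 3

print(record_line_trace(sample, 1))  # [1, 2, 3]
print(record_line_trace(sample, 0))  # [1, 3]
```

Mapping each recorded line back to its CFG node yields the execution path used as the coverage label for training.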
105 |
106 | ## License
107 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
108 |
109 | ## Acknowledgements
110 | This codebase is adapted from:
111 | - [ConditionBugs](https://github.com/zhangj111/ConditionBugs)
112 | - [CFG-Generator](https://github.com/Tiankai-Jiang/CFG-Generator)
113 | - [trace_python](https://github.com/python/cpython/blob/3.12/Lib/trace.py)
114 |
115 | ## Citation Information
116 |
117 | If you use CodeFlow, please cite it with this BibTeX entry:
118 | ```bibtex
119 | @misc{le2024learningpredictprogramexecution,
120 | title={Learning to Predict Program Execution by Modeling Dynamic Dependency on Code Graphs},
121 | author={Cuong Chi Le and Hoang Nhat Phan and Huy Nhat Phan and Tien N. Nguyen and Nghi D. Q. Bui},
122 | year={2024},
123 | eprint={2408.02816},
124 | archivePrefix={arXiv},
125 | primaryClass={cs.SE},
126 | url={https://arxiv.org/abs/2408.02816},
127 | }
128 | ```
129 |
--------------------------------------------------------------------------------
/cfg.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import ast, astor, autopep8, tokenize, io, sys
3 | import graphviz as gv
4 | from typing import Dict, List, Tuple, Set, Optional, Type
5 | import re
6 | import sys
7 | import trace_execution
8 | import os
9 | import io
10 | import linecache
11 | import random
12 |
13 | class SingletonMeta(type):
14 | _instance: Optional[BlockId] = None
15 |
16 | def __call__(self) -> BlockId:
17 | if self._instance is None:
18 | self._instance = super().__call__()
19 | return self._instance
20 |
21 |
22 | class BlockId(metaclass=SingletonMeta):
23 | counter: int = 0
24 |
25 | def gen(self) -> int:
26 | self.counter += 1
27 | return self.counter
28 |
29 |
30 | class BasicBlock:
31 |
32 | def __init__(self, bid: int):
33 | self.bid: int = bid
34 | self.stmts: List[Type[ast.AST]] = []
35 | self.calls: List[str] = []
36 | self.prev: List[int] = []
37 | self.next: List[int] = []
38 | self.condition = False
39 | self.for_loop = 0
40 |         self.for_name: Optional[ast.AST] = None  # AST of the `for` header; set when for_loop != 0
41 |
42 | def is_empty(self) -> bool:
43 | return len(self.stmts) == 0
44 |
45 | def has_next(self) -> bool:
46 | return len(self.next) != 0
47 |
48 | def has_previous(self) -> bool:
49 | return len(self.prev) != 0
50 |
51 | def remove_from_prev(self, prev_bid: int) -> None:
52 | if prev_bid in self.prev:
53 | self.prev.remove(prev_bid)
54 |
55 | def remove_from_next(self, next_bid: int) -> None:
56 | if next_bid in self.next:
57 | self.next.remove(next_bid)
58 |
59 | def stmts_to_code(self) -> str:
60 | code = ''
61 | for stmt in self.stmts:
62 | line = astor.to_source(stmt)
63 | code += line.split('\n')[0] + "\n" if type(stmt) in [ast.If, ast.For, ast.While, ast.FunctionDef,
64 | ast.AsyncFunctionDef] else line
65 | return code
66 |
67 | def calls_to_code(self) -> str:
68 | return '\n'.join(self.calls)
69 |
70 |
71 | class CFG:
72 |
73 | def __init__(self, name: str):
74 | self.name: str = name
75 | self.filename = name
76 |
77 |         # Note: the variable `asynchr` from the original code is unused here,
78 |         # as is the `finalblocks` list.
79 |
80 | self.start: Optional[BasicBlock] = None
81 | self.func_calls: Dict[str, CFG] = {}
82 | self.blocks: Dict[int, BasicBlock] = {}
83 | self.edges: Dict[Tuple[int, int], Type[ast.AST]] = {}
84 | self.back_edges: List[Tuple[int, int]] = []
85 | self.lst_temp : List[Tuple[int, int]] = []
86 | self.graph: Optional[gv.dot.Digraph] = None
87 | self.execution_path: List[int] = []
88 | self.path: List[int] = []
89 | self.revert: Dict[int, int] = {}
90 | self.store_True: List[int] = []
91 |
92 |     def _traverse(self, block: BasicBlock, visited: Set[int] = set(), calls: bool = True) -> None:  # NB: mutable default `visited` persists across calls
93 | if block.bid not in visited:
94 | visited.add(block.bid)
95 | st = block.stmts_to_code()
96 | if st.startswith('if'):
97 | st = st[3:]
98 | elif st.startswith('while'):
99 | st = st[6:]
100 | if block.condition:
101 | st = 'T ' + st
102 | # Check if the block is in the path and highlight it
103 | node_attributes = {'shape': 'ellipse'}
104 | if block.bid in self.path:
105 | node_attributes['color'] = 'red'
106 | node_attributes['style'] = 'filled'
107 |
108 | self.graph.node(str(block.bid), label=st, _attributes=node_attributes)
109 |
110 | for next_bid in block.next:
111 | self._traverse(self.blocks[next_bid], visited, calls=calls)
112 | self.graph.edge(str(block.bid), str(next_bid))
113 |
114 | def get_revert(self):
115 | code = {}
116 | for_loop = {}
117 | for i in self.blocks:
118 | if self.blocks[i].for_loop != 0:
119 | if self.blocks[i].for_loop not in for_loop:
120 | for_loop[self.blocks[i].for_loop] = [i]
121 | else:
122 | for_loop[self.blocks[i].for_loop].append(i)
123 | first = []
124 | second = []
125 | for i in for_loop:
126 | first.append(for_loop[i][0]+1)
127 | second.append(for_loop[i][1])
128 | orin_node = []
129 | track = {}
130 | track_for = {}
131 | for i in self.blocks:
132 | if self.blocks[i].stmts_to_code():
133 | if int(i) == 1:
134 | st = 'BEGIN'
135 | elif int(i) == len(self.blocks):
136 | st = 'EXIT'
137 | else:
138 | if i in first:
139 | line = astor.to_source(self.blocks[i].for_name)
140 | st = line.split('\n')[0]
141 | st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "")
142 | else:
143 | st = self.blocks[i].stmts_to_code()
144 | st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "")
145 | orin_node.append([i, st, None])
146 | if st not in track:
147 | track[st] = [len(orin_node)-1]
148 | else:
149 | track[st].append(len(orin_node)-1)
150 | track_for[i] = len(orin_node)-1
151 | with open(self.filename, 'r') as file:
152 | lines = file.readlines()
153 | for i in range(1, len(lines)+1):
154 | line = lines[i-1]
155 | #delete \n at the end of each line and delete all spaces
156 | line = line.strip()
157 | line = re.sub(r"\s+", "", line).replace('"', "'").replace("(", "").replace(")", "")
158 | if line.startswith('elif'):
159 | line = line[2:]
160 | if line in track:
161 | orin_node[track[line][0]][2] = i
162 | if orin_node[track[line][0]][0] in first:
163 | orin_node[track[line][0]-1][2] = i-0.4
164 | orin_node[track[line][0]+1][2] = i+0.4
165 | if len(track[line]) > 1:
166 | track[line].pop(0)
167 | for i in second:
168 | max_val = 0
169 | for edge in self.edges:
170 | if edge[0] == i:
171 | if orin_node[track_for[edge[1]]][2] > max_val:
172 | max_val = orin_node[track_for[edge[1]]][2]
173 | if edge[1] == i:
174 | if orin_node[track_for[edge[0]]][2] > max_val:
175 | max_val = orin_node[track_for[edge[0]]][2]
176 | orin_node[track_for[i]][2] = max_val + 0.5
177 | orin_node[0][2] = 0
178 | orin_node[-1][2] = len(lines)+1
179 | # sort orin_node by the third element
180 | orin_node.sort(key=lambda x: x[2])
181 | nodes = []
182 | for t in orin_node:
183 | i = t[0]
184 | if self.blocks[i].stmts_to_code():
185 | if int(i) == 1:
186 | nodes.append('BEGIN')
187 | elif int(i) == len(self.blocks):
188 | nodes.append('EXIT')
189 | else:
190 | st = self.blocks[i].stmts_to_code()
191 | st_no_space = re.sub(r"\s+", "", st)
192 | st_no_space = st_no_space.replace('"', "'")
193 | # if start with if or while, delete these keywords
194 | if st.startswith('if'):
195 | st = st[3:]
196 | elif st.startswith('while'):
197 | st = st[6:]
198 | if self.blocks[i].condition:
199 | st = 'T '+ st
200 | nodes.append(st)
201 | self.revert[len(nodes)] = i
202 |
203 | def track_execution(self):
204 | nodes = []
205 | blocks = []
206 | matching = {}
207 |
208 | # add
209 | code = {}
210 | for_loop = {}
211 | for i in self.blocks:
212 | if self.blocks[i].for_loop != 0:
213 | if self.blocks[i].for_loop not in for_loop:
214 | for_loop[self.blocks[i].for_loop] = [i]
215 | else:
216 | for_loop[self.blocks[i].for_loop].append(i)
217 | first = []
218 | second = []
219 | for i in for_loop:
220 | first.append(for_loop[i][0]+1)
221 | second.append(for_loop[i][1])
222 | orin_node = []
223 | track = {}
224 | track_for = {}
225 | for i in self.blocks:
226 | if self.blocks[i].stmts_to_code():
227 | if int(i) == 1:
228 | st = 'BEGIN'
229 | elif int(i) == len(self.blocks):
230 | st = 'EXIT'
231 | else:
232 | if i in first:
233 | line = astor.to_source(self.blocks[i].for_name)
234 | st = line.split('\n')[0]
235 | st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "")
236 | else:
237 | st = self.blocks[i].stmts_to_code()
238 | st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "")
239 | orin_node.append([i, st, None])
240 | if st not in track:
241 | track[st] = [len(orin_node)-1]
242 | else:
243 | track[st].append(len(orin_node)-1)
244 | track_for[i] = len(orin_node)-1
245 | with open(self.filename, 'r') as file:
246 | lines = file.readlines()
247 | for i in range(1, len(lines)+1):
248 | line = lines[i-1]
249 | #delete \n at the end of each line and delete all spaces
250 | line = line.strip()
251 | line = re.sub(r"\s+", "", line).replace('"', "'").replace("(", "").replace(")", "")
252 | if line.startswith('elif'):
253 | line = line[2:]
254 | if line in track:
255 | orin_node[track[line][0]][2] = i
256 | if orin_node[track[line][0]][0] in first:
257 | orin_node[track[line][0]-1][2] = i-0.4
258 | orin_node[track[line][0]+1][2] = i+0.4
259 | if len(track[line]) > 1:
260 | track[line].pop(0)
261 | orin_node[0][2] = 0
262 | orin_node[-1][2] = len(lines)+1
263 | for i in second:
264 | max_val = 0
265 | for edge in self.edges:
266 | if edge[0] == i:
267 | if orin_node[track_for[edge[1]]][2] > max_val:
268 | max_val = orin_node[track_for[edge[1]]][2]
269 | if edge[1] == i:
270 | if orin_node[track_for[edge[0]]][2] > max_val:
271 | max_val = orin_node[track_for[edge[0]]][2]
272 | orin_node[track_for[i]][2] = max_val + 0.5
273 | # sort orin_node by the third element
274 | orin_node.sort(key=lambda x: x[2])
275 | #add
276 | for t in orin_node:
277 | i = t[0]
278 | if self.blocks[i].stmts_to_code():
279 | if int(i) == 1:
280 | nodes.append('BEGIN')
281 | blocks.append('BEGIN')
282 | elif int(i) == len(self.blocks):
283 | nodes.append('EXIT')
284 | blocks.append('EXIT')
285 | else:
286 | st = self.blocks[i].stmts_to_code()
287 | st_no_space = re.sub(r"\s+", "", st)
288 | st_no_space = st_no_space.replace('"', "'")
289 | blocks.append(st_no_space)
290 | # if start with if or while, delete these keywords
291 | if st.startswith('if'):
292 | st = st[3:]
293 | elif st.startswith('while'):
294 | st = st[6:]
295 | if self.blocks[i].condition:
296 | st = 'T '+ st
297 | nodes.append(st)
298 | matching[i] = len(nodes)
299 | self.revert[len(nodes)] = i
300 |
301 | edges = {}
302 | for edge in self.edges:
303 | if matching[edge[0]] not in edges:
304 | edges[matching[edge[0]]] = [matching[edge[1]]]
305 | else:
306 | edges[matching[edge[0]]].append(matching[edge[1]])
307 |
308 | for_loop = {}
309 | for i in self.blocks:
310 | if self.blocks[i].for_loop != 0:
311 | if self.blocks[i].for_loop not in for_loop:
312 | for_loop[self.blocks[i].for_loop] = [matching[i]]
313 | else:
314 | for_loop[self.blocks[i].for_loop].append(matching[i])
315 | store = {}
316 | last_loop = {}
317 | for i in for_loop:
318 | lst = for_loop[i]
319 | start = lst[0]
320 | end = lst[1]
321 | if self.blocks[self.revert[start]].condition:
322 | self.blocks[self.revert[start+1]].condition = True
323 | store[start+1] = [start]
324 | for t in matching:
325 | if matching[t] == start+1:
326 | line = astor.to_source(self.blocks[t].for_name)
327 | code = line.split('\n')[0]
328 | st_no_space = re.sub(r"\s+", "", code)
329 | st_no_space = st_no_space.replace('"', "'")
330 | blocks[start] = st_no_space
331 | edges.pop(start)
332 | for j in edges:
333 | if start in edges[j]:
334 | edges[j].remove(start)
335 | edges[j].append(start + 1)
336 | for a in edges:
337 | if end in edges[a]:
338 | edges[a].remove(end)
339 | edges[a].append(start+1)
340 | last_loop[a] = end
341 | edges.pop(end)
342 | if nodes[start+1].startswith('T'):
343 | store[start+1].append(start+2)
344 | for b in edges[start+2]:
345 | edges[start+1].append(b)
346 | edges.pop(start+2)
347 | edges[start+1].remove(start+2)
348 | else:
349 | store[start+1].append(start+3)
350 | for b in edges[start+3]:
351 | edges[start+1].append(b)
352 | edges.pop(start+3)
353 | edges[start+1].remove(start+3)
354 | t = trace_execution.Trace(ignoredirs=[sys.base_prefix, sys.base_exec_prefix,],
355 | trace=0, count=1)
356 | arguments = []
357 | sys.argv = [self.filename, arguments]
358 | sys.path[0] = os.path.dirname(self.filename)
359 |
360 | with io.open_code(self.filename) as fp:
361 | code = compile(fp.read(), self.filename, 'exec')
362 | # try to emulate __main__ namespace as much as possible
363 | globs = {
364 | '__file__': self.filename,
365 | '__name__': '__main__',
366 | '__package__': None,
367 | '__cached__': None,
368 | }
369 | terminate = True
370 | try:
371 | t.runctx(code, globs, globs)
372 | except Exception as e:
373 | terminate = False
374 |
375 | source = linecache.getlines(self.filename)
376 | code_line = [element.lstrip().replace('\n', '') for element in source]
377 | execution_path = []
378 | for lineno in t.exe_path:
379 | no_spaces = re.sub(r"\s+", "", code_line[lineno-1])
380 | if no_spaces.startswith("if("):
381 | no_spaces = no_spaces.replace("(", "").replace(")", "")
382 | if no_spaces.startswith("elif"):
383 | no_spaces = no_spaces[2:]
384 | execution_path.append(no_spaces)
385 |
386 | check_True_condition = []
387 | for i in range(len(execution_path)-1):
388 | if execution_path[i].startswith('if') or execution_path[i].startswith('while') or execution_path[i].startswith('for'):
389 | if t.exe_path[i+1] == t.exe_path[i]+1:
390 | check_True_condition.append(i)
391 |
392 | current_node = 1
393 | path = [self.revert[current_node]]
394 | exit_flag = False
395 | for s in range(len(execution_path)):
396 | node = execution_path[s]
397 | if current_node == 1:
398 | current_node = edges[current_node][0]
399 | else:
400 | c = 0
401 | if node == "break" or node == "continue":
402 | continue
403 | if len(edges[current_node]) == 2:
404 | if blocks[edges[current_node][0]-1] == blocks[edges[current_node][1]-1]:
405 | if (s-1) in check_True_condition:
406 | if self.blocks[self.revert[edges[current_node][0]]].condition:
407 | current_node = edges[current_node][0]
408 | else:
409 | current_node = edges[current_node][1]
410 | else:
411 | if self.blocks[self.revert[edges[current_node][0]]].condition:
412 | current_node = edges[current_node][1]
413 | else:
414 | current_node = edges[current_node][0]
415 | if current_node in last_loop:
416 | if terminate:
417 | if self.revert[last_loop[current_node]] not in path:
418 | path.append(self.revert[last_loop[current_node]])
419 | else:
420 | if s != len(execution_path) - 1:
421 | if self.revert[last_loop[current_node]] not in path:
422 | path.append(self.revert[last_loop[current_node]])
423 | if current_node in store:
424 | for i in store[current_node]:
425 | if self.revert[i] not in path:
426 | path.append(self.revert[i])
427 | if self.revert[current_node] not in path:
428 | path.append(self.revert[current_node])
429 | continue
430 | for next_node in edges[current_node]:
431 | c += 1
432 | node = node.replace('"', "'")
433 | # print(node)
434 | # print(blocks[next_node-1])
435 | # print("___________")
436 | if blocks[next_node-1] == node:
437 | if next_node in last_loop:
438 | if terminate:
439 | if self.revert[last_loop[next_node]] not in path:
440 | path.append(self.revert[last_loop[next_node]])
441 | else:
442 | if s != len(execution_path) - 1:
443 | if self.revert[last_loop[next_node]] not in path:
444 | path.append(self.revert[last_loop[next_node]])
445 | if next_node in store:
446 | for i in store[next_node]:
447 | if self.revert[i] not in path:
448 | path.append(self.revert[i])
449 | current_node = next_node
450 | break
451 | if c == len(edges[current_node]):
452 | exit_flag = True
453 | raise Exception(f"Error: Cannot find the execution path in CFG in file {self.filename}")
454 |
455 | if exit_flag:
456 | break
457 | if self.revert[current_node] not in path:
458 | path.append(self.revert[current_node])
459 | if terminate:
460 | node_max = 0
461 | for i in self.blocks:
462 | if i > node_max:
463 | node_max = i
464 | path.append(node_max)
465 |
466 | self.path = path
467 |
468 | def clean(self):
469 | des_edges = {}
470 | for edge in self.edges:
471 | if edge[0] not in des_edges:
472 | des_edges[edge[0]] = [edge[1]]
473 | else:
474 | des_edges[edge[0]].append(edge[1])
475 | blank_node = []
476 | for node in des_edges:
477 | check = False
478 | if self.blocks[node].stmts_to_code():
479 | current = node
480 | for next in des_edges[node]:
481 | next_node = next
482 | while self.blocks[next_node].stmts_to_code() == '':
483 | if next_node not in blank_node:
484 | blank_node.append(next_node)
485 | current = next_node
486 | if current not in des_edges:
487 | break
488 | next_node = des_edges[current][0]
489 | if (current, next_node) in self.back_edges:
490 | check = True
491 | if (node, next_node) not in self.edges:
492 | self.edges[(node, next_node)] = None
493 | if check:
494 | self.back_edges.append((node, next_node))
495 | if next_node != node and next_node not in self.blocks[node].next:
496 | self.blocks[node].next.append(next_node)
497 | else:
498 | if node not in blank_node:
499 | blank_node.append(node)
500 | for edge in self.edges.copy():
501 | if edge[0] in blank_node or edge[1] in blank_node:
502 | self.edges.pop(edge)
503 | for i in self.blocks:
504 | for node in blank_node:
505 | if node in self.blocks[i].next:
506 | self.blocks[i].next.remove(node)
507 |
508 | def _show(self, fmt: str = 'pdf', calls: bool = True) -> gv.dot.Digraph:
509 | self.graph = gv.Digraph(name='cluster_' + self.name, format=fmt, graph_attr={'label': self.name})
510 | self._traverse(self.start, calls=calls)
511 | for k, v in self.func_calls.items():
512 | self.graph.subgraph(v._show(fmt, calls))
513 | return self.graph
514 |
515 |
516 | def show(self, filepath: str = './output', fmt: str = 'pdf', calls: bool = True, show: bool = True) -> None:
517 | self._show(fmt, calls)
518 | self.graph.render(filepath, view=show, cleanup=True)
519 |
520 |
521 | class CFGVisitor(ast.NodeVisitor):
522 |
523 | invertComparators: Dict[Type[ast.AST], Type[ast.AST]] = {ast.Eq: ast.NotEq, ast.NotEq: ast.Eq, ast.Lt: ast.GtE,
524 | ast.LtE: ast.Gt,
525 | ast.Gt: ast.LtE, ast.GtE: ast.Lt, ast.Is: ast.IsNot,
526 | ast.IsNot: ast.Is, ast.In: ast.NotIn, ast.NotIn: ast.In}
527 |
528 | def __init__(self):
529 | super().__init__()
530 | self.count_for_loop = 0
531 | self.loop_stack: List[BasicBlock] = []
532 | self.for_stack: List[int] = []
533 | self.continue_stack: List[BasicBlock] = []
534 | self.ifExp = False
535 | self.store_True: List[int] = []
536 |
537 | def build(self, name: str, tree: Type[ast.AST]) -> CFG:
538 | self.cfg = CFG(name)
539 | self.filename = name
540 | begin_block = self.new_block()
541 | begin_block.stmts = [ast.Expr(value=ast.Str(s='BEGIN'))]
542 | self.cfg.start = begin_block
543 | self.curr_block = self.new_block()
544 | self.add_edge(begin_block.bid, self.curr_block.bid)
545 |
546 | self.visit(tree)
547 | self.add_EXIT_node()
548 | self.remove_empty_blocks(self.cfg.start)
549 | for edge in self.cfg.lst_temp:
550 | if edge in self.cfg.back_edges:
551 | self.cfg.back_edges.remove(edge)
552 | return self.cfg
553 |
554 | def add_EXIT_node(self) -> None:
555 | exit_block = self.new_block()
556 | exit_block.stmts = [ast.Expr(value=ast.Str(s='EXIT'))]
557 | self.add_edge(self.curr_block.bid, exit_block.bid)
558 |
559 | def new_block(self) -> BasicBlock:
560 | bid: int = BlockId().gen()
561 | self.cfg.blocks[bid] = BasicBlock(bid)
562 | if bid in self.store_True:
563 | self.cfg.blocks[bid].condition = True
564 | self.store_True.remove(bid)
565 | return self.cfg.blocks[bid]
566 |
567 | def add_stmt(self, block: BasicBlock, stmt: Type[ast.AST]) -> None:
568 | # if block.stmts contain 1 stmt, then create the new node and add the stmt to the new node and add edge from block to new node
569 | if len(block.stmts) == 1:
570 | new_block = self.new_block()
571 | new_block.stmts.append(stmt)
572 | self.add_edge(block.bid, new_block.bid)
573 | self.curr_block = new_block
574 | else:
575 | block.stmts.append(stmt)
576 |
577 | def add_edge(self, frm_id: int, to_id: int, condition=None) -> BasicBlock:
578 | self.cfg.blocks[frm_id].next.append(to_id)
579 | self.cfg.blocks[to_id].prev.append(frm_id)
580 | self.cfg.edges[(frm_id, to_id)] = condition
581 | return self.cfg.blocks[to_id]
582 |
583 | def add_loop_block(self) -> BasicBlock:
584 | if self.curr_block.is_empty() and not self.curr_block.has_next():
585 | return self.curr_block
586 | else:
587 | loop_block = self.new_block()
588 | self.add_edge(self.curr_block.bid, loop_block.bid)
589 | return loop_block
590 |
591 | def add_subgraph(self, tree: Type[ast.AST]) -> None:
592 | self.cfg.func_calls[tree.name] = CFGVisitor().build(tree.name, ast.Module(body=tree.body))
593 |
594 | def add_condition(self, cond1: Optional[Type[ast.AST]], cond2: Optional[Type[ast.AST]]) -> Optional[Type[ast.AST]]:
595 | if cond1 and cond2:
596 | return ast.BoolOp(ast.And(), values=[cond1, cond2])
597 | else:
598 | return cond1 if cond1 else cond2
599 |
600 | # not tested
601 | def remove_empty_blocks(self, block: BasicBlock, visited: Set[int] = set()) -> None:
602 | if block.bid not in visited:
603 | visited.add(block.bid)
604 | if block.is_empty():
605 | for prev_bid in block.prev:
606 | prev_block = self.cfg.blocks[prev_bid]
607 | for next_bid in block.next:
608 | next_block = self.cfg.blocks[next_bid]
609 | self.add_edge(prev_bid, next_bid, self.add_condition(self.cfg.edges.get((prev_bid, block.bid)), self.cfg.edges.get((block.bid, next_bid))))
610 | self.cfg.edges.pop((block.bid, next_bid), None)
611 | if (block.bid, next_bid) in self.cfg.back_edges:
612 | self.cfg.back_edges.append((prev_bid, next_bid))
613 | self.cfg.lst_temp.append((block.bid, next_bid))
614 | next_block.remove_from_prev(block.bid)
615 | self.cfg.edges.pop((prev_bid, block.bid))
616 | if (prev_bid, block.bid) in self.cfg.back_edges:
617 | self.cfg.back_edges.append((prev_bid, next_bid))
618 | self.cfg.lst_temp.append((prev_bid, block.bid))
619 | prev_block.remove_from_next(block.bid)
620 | block.prev.clear()
621 | for next_bid in block.next:
622 | self.remove_empty_blocks(self.cfg.blocks[next_bid], visited)
623 | block.next.clear()
624 |
625 | else:
626 | for next_bid in block.next:
627 | self.remove_empty_blocks(self.cfg.blocks[next_bid], visited)
628 |
629 | def invert(self, node: Type[ast.AST]) -> Type[ast.AST]:
630 | if type(node) == ast.Compare:
631 | if len(node.ops) == 1:
632 | return ast.Compare(left=node.left, ops=[self.invertComparators[type(node.ops[0])]()], comparators=node.comparators)
633 | else:
634 | tmpNode = ast.BoolOp(op=ast.And(), values = [ast.Compare(left=node.left, ops=[node.ops[0]], comparators=[node.comparators[0]])])
635 | for i in range(0, len(node.ops) - 1):
636 | tmpNode.values.append(ast.Compare(left=node.comparators[i], ops=[node.ops[i+1]], comparators=[node.comparators[i+1]]))
637 | return self.invert(tmpNode)
638 | elif isinstance(node, ast.BinOp) and type(node.op) in self.invertComparators:
639 | return ast.BinOp(node.left, self.invertComparators[type(node.op)](), node.right)
640 | elif type(node) == ast.NameConstant and type(node.value) == bool:
641 | return ast.NameConstant(value=not node.value)
642 | elif type(node) == ast.BoolOp:
643 | return ast.BoolOp(values = [self.invert(x) for x in node.values], op = {ast.And: ast.Or(), ast.Or: ast.And()}.get(type(node.op)))
644 | elif type(node) == ast.UnaryOp:
645 | return self.UnaryopInvert(node)
646 | else:
647 | return ast.UnaryOp(op=ast.Not(), operand=node)
648 |
649 | def UnaryopInvert(self, node: Type[ast.AST]) -> Type[ast.AST]:
650 | if type(node.op) == ast.UAdd:
651 | return ast.UnaryOp(op=ast.USub(),operand = node.operand)
652 | elif type(node.op) == ast.USub:
653 | return ast.UnaryOp(op=ast.UAdd(),operand = node.operand)
654 | elif type(node.op) == ast.Invert:
655 | return ast.UnaryOp(op=ast.Not(), operand=node)
656 | else:
657 | return node.operand
658 |
659 | # def boolinvert(self, node:Type[ast.AST]) -> Type[ast.AST]:
660 | # value = []
661 | # for item in node.values:
662 | # value.append(self.invert(item))
663 | # if type(node.op) == ast.Or:
664 | # return ast.BoolOp(values = value, op = ast.And())
665 | # elif type(node.op) == ast.And:
666 | # return ast.BoolOp(values = value, op = ast.Or())
667 |
668 | def combine_conditions(self, node_list: List[Type[ast.AST]]) -> Type[ast.AST]:
669 | return node_list[0] if len(node_list) == 1 else ast.BoolOp(op=ast.And(), values = node_list)
670 |
671 | def generic_visit(self, node):
672 | if type(node) in [ast.Import, ast.ImportFrom]:
673 | self.add_stmt(self.curr_block, node)
674 | return
675 | if type(node) in [ast.FunctionDef, ast.AsyncFunctionDef]:
676 | self.add_stmt(self.curr_block, node)
677 | self.add_subgraph(node)
678 | return
679 | if type(node) in [ast.AnnAssign, ast.AugAssign]:
680 | self.add_stmt(self.curr_block, node)
681 | super().generic_visit(node)
682 |
683 | def get_function_name(self, node: Type[ast.AST]) -> str:
684 | if type(node) == ast.Name:
685 | return node.id
686 | elif type(node) == ast.Attribute:
687 | return self.get_function_name(node.value) + '.' + node.attr
688 | elif type(node) == ast.Str:
689 | return node.s
690 | elif type(node) == ast.Subscript:
691 | return node.value.id
692 | elif type(node) == ast.Lambda:
693 | return 'lambda function'
694 |
695 | def populate_body(self, body_list: List[Type[ast.AST]], to_bid: int, type: str) -> None:
696 | for child in body_list:
697 | self.visit(child)
698 | if not self.curr_block.next:
699 | self.add_edge(self.curr_block.bid, to_bid)
700 | if type == "While" or type == "Try":
701 | self.cfg.back_edges.append((self.curr_block.bid, to_bid))
702 |
703 | def populate_If_body(self, body_list: List[Type[ast.AST]], node_list: List[int]):
704 | for child in body_list:
705 | self.visit(child)
706 | if not self.curr_block.next:
707 | node_list.append(self.curr_block.bid)
708 | return node_list
709 |
710 | def populate_For_body(self, body_list: List[Type[ast.AST]]) -> None:
711 | for child in body_list:
712 | self.visit(child)
713 | new_node = self.new_block()
714 | if not self.curr_block.next:
715 | self.add_edge(self.curr_block.bid, new_node.bid)
716 | return new_node
717 |
718 |     # Assert: execution continues past this block only along the edge where the test holds
719 | def visit_Assert(self, node):
720 | self.add_stmt(self.curr_block, node)
721 | # If the assertion fails, the current flow ends, so the fail block is a
722 | # final block of the CFG.
723 | # self.cfg.finalblocks.append(self.add_edge(self.curr_block.bid, self.new_block().bid, self.invert(node.test)))
724 | # If the assertion is True, continue the flow of the program.
725 | # success block
726 | self.curr_block = self.add_edge(self.curr_block.bid, self.new_block().bid, node.test)
727 | self.generic_visit(node)
728 |
729 | def visit_Assign(self, node):
730 | if type(node.value) in [ast.ListComp, ast.SetComp, ast.DictComp, ast.GeneratorExp, ast.Lambda] and len(node.targets) == 1 and type(node.targets[0]) == ast.Name: # is this entire statement necessary?
731 | if type(node.value) == ast.ListComp:
732 | self.add_stmt(self.curr_block, ast.Assign(targets=[ast.Name(id=node.targets[0].id, ctx=ast.Store())], value=ast.List(elts=[], ctx=ast.Load())))
733 | self.listCompReg = (node.targets[0].id, node.value)
734 | elif type(node.value) == ast.SetComp:
735 | self.add_stmt(self.curr_block, ast.Assign(targets=[ast.Name(id=node.targets[0].id, ctx=ast.Store())], value=ast.Call(func=ast.Name(id='set', ctx=ast.Load()), args=[], keywords=[])))
736 | self.setCompReg = (node.targets[0].id, node.value)
737 | elif type(node.value) == ast.DictComp:
738 | self.add_stmt(self.curr_block, ast.Assign(targets=[ast.Name(id=node.targets[0].id, ctx=ast.Store())], value=ast.Dict(keys=[], values=[])))
739 | self.dictCompReg = (node.targets[0].id, node.value)
740 | elif type(node.value) == ast.GeneratorExp:
741 | self.add_stmt(self.curr_block, ast.Assign(targets=[ast.Name(id=node.targets[0].id, ctx=ast.Store())], value=ast.Call(func=ast.Name(id='__' + node.targets[0].id + 'Generator__', ctx=ast.Load()), args=[], keywords=[])))
742 | self.genExpReg = (node.targets[0].id, node.value)
743 | else:
744 | self.lambdaReg = (node.targets[0].id, node.value)
745 | else:
746 | self.add_stmt(self.curr_block, node)
747 | self.generic_visit(node)
748 |
749 | def visit_Await(self, node):
750 | afterawait_block = self.new_block()
751 | self.add_edge(self.curr_block.bid, afterawait_block.bid)
752 | self.generic_visit(node)
753 | self.curr_block = afterawait_block
754 |
755 | def visit_Break(self, node):
756 | assert len(self.loop_stack), "Found break not inside loop"
757 | self.add_edge(self.curr_block.bid, self.loop_stack[-1].bid, ast.Break())
758 |
759 | def visit_Continue(self, node):
760 | assert len(self.continue_stack), "Found continue not inside loop"
761 | self.add_edge(self.curr_block.bid, self.continue_stack[-1].bid, ast.Continue())
762 |
763 | def visit_Call(self, node):
764 | if type(node.func) == ast.Lambda:
765 | self.lambdaReg = ('Anonymous Function', node.func)
766 | self.generic_visit(node)
767 | # else:
768 | # self.curr_block.calls.append(self.get_function_name(node.func))
769 |
770 | def visit_DictComp_Rec(self, generators: List[Type[ast.AST]]) -> List[Type[ast.AST]]:
771 | if not generators:
772 | if self.dictCompReg[0]: # bug if there is else statement in comprehension
773 | return [ast.Assign(targets=[ast.Subscript(value=ast.Name(id=self.dictCompReg[0], ctx=ast.Load()), slice=ast.Index(value=self.dictCompReg[1].key), ctx=ast.Store())], value=self.dictCompReg[1].value)]
774 | # else: # not supported yet
775 | # return [ast.Expr(value=self.dictCompReg[1].elt)]
776 | else:
777 | return [ast.For(target=generators[-1].target, iter=generators[-1].iter, body=[ast.If(test=self.combine_conditions(generators[-1].ifs), body=self.visit_DictComp_Rec(generators[:-1]), orelse=[])] if generators[-1].ifs else self.visit_DictComp_Rec(generators[:-1]), orelse=[])]
778 |
779 | def visit_DictComp(self, node):
780 | try: # try may change to checking if self.dictCompReg exists
781 | self.generic_visit(ast.Module(self.visit_DictComp_Rec(self.dictCompReg[1].generators)))
782 |         except Exception:
783 | pass
784 | finally:
785 | self.dictCompReg = None
786 |
787 | # ignore the case when using set or dict comprehension or generator expression but the result is not assigned to a variable
788 | def visit_Expr(self, node):
789 | if type(node.value) == ast.ListComp and type(node.value.elt) == ast.Call:
790 | self.listCompReg = (None, node.value)
791 | elif type(node.value) == ast.Lambda:
792 | self.lambdaReg = ('Anonymous Function', node.value)
793 | # elif type(node.value) == ast.Call and type(node.value.func) == ast.Lambda:
794 | # self.lambdaReg = ('Anonymous Function', node.value.func)
795 | else:
796 | self.add_stmt(self.curr_block, node)
797 | self.generic_visit(node)
798 |
799 | def visit_For(self, node):
800 | self.count_for_loop += 1
801 | assign_block = self.new_block()
802 | if self.curr_block.condition:
803 | assign_block.condition = True
804 | self.add_edge(self.curr_block.bid, assign_block.bid)
805 | self.curr_block = assign_block
806 | self.curr_block.for_loop = self.count_for_loop
807 | self.for_stack.append(self.count_for_loop)
808 | num = self.count_for_loop
809 | self.add_stmt(self.curr_block, ast.Assign(targets=[ast.Name(id=f'p{num}', ctx=ast.Store())], value=ast.Num(n=0)))
810 | if_node = ast.If(test=ast.Compare(left=ast.Name(id=f'p{num}', ctx=ast.Load()), ops=[ast.Lt()], comparators=[ast.Call(func=ast.Name(id='len', ctx=ast.Load()), args=[node.iter], keywords=[])]), body=[ast.Assign(targets=[ast.Name(id=node.target.id, ctx=ast.Store())], value=ast.Subscript(value=node.iter, slice=ast.Index(value=ast.Name(id=f'p{num}', ctx=ast.Load())), ctx=ast.Load()))] + node.body, orelse=[])
811 | loop_guard = self.add_loop_block()
812 | self.continue_stack.append(loop_guard)
813 | self.curr_block = loop_guard
814 | self.curr_block.for_name = node
815 | self.add_stmt(self.curr_block, if_node)
816 | # New block for the body of the for-loop.
817 | for_block = self.add_edge(self.curr_block.bid, self.new_block().bid)
818 | if not node.orelse:
819 | # Block of code after the for loop.
820 | afterfor_block = self.add_edge(self.curr_block.bid, self.new_block().bid)
821 | self.loop_stack.append(afterfor_block)
822 | self.curr_block = for_block
823 | self.curr_block.condition = True
824 | self.curr_block = self.populate_For_body(if_node.body)
825 | end_node = ast.AugAssign(target=ast.Name(id=f'p{num}', ctx=ast.Store()), op=ast.Add(), value=ast.Num(n=1))
826 | self.add_stmt(self.curr_block, end_node)
827 | self.curr_block.for_loop = self.for_stack[-1]
828 | self.add_edge(self.curr_block.bid, loop_guard.bid)
829 | self.cfg.back_edges.append((self.curr_block.bid, loop_guard.bid))
830 | else:
831 | # Block of code after the for loop.
832 | afterfor_block = self.new_block()
833 | orelse_block = self.add_edge(self.curr_block.bid, self.new_block().bid, ast.Name(id='else', ctx=ast.Load()))
834 | self.loop_stack.append(afterfor_block)
835 | self.curr_block = for_block
836 | self.curr_block.condition = True
837 | self.curr_block = self.populate_For_body(if_node.body)
838 | end_node = ast.AugAssign(target=ast.Name(id=f'p{num}', ctx=ast.Store()), op=ast.Add(), value=ast.Num(n=1))
839 | self.add_stmt(self.curr_block, end_node)
840 | self.curr_block.for_loop = self.for_stack[-1]
841 | self.add_edge(self.curr_block.bid, loop_guard.bid)
842 | self.cfg.back_edges.append((self.curr_block.bid, loop_guard.bid))
843 |
844 | self.curr_block = orelse_block
845 | for child in node.orelse:
846 | self.visit(child)
847 | self.add_edge(orelse_block.bid, afterfor_block.bid, "For")
848 |
849 | # Continue building the CFG in the after-for block.
850 | self.curr_block = afterfor_block
851 | self.continue_stack.pop()
852 | self.loop_stack.pop()
853 | self.for_stack.pop()
854 |
855 | def visit_GeneratorExp_Rec(self, generators: List[Type[ast.AST]]) -> List[Type[ast.AST]]:
856 | if not generators:
857 | self.generic_visit(self.genExpReg[1].elt) # the location of the node may be wrong
858 | if self.genExpReg[0]: # bug if there is else statement in comprehension
859 | return [ast.Expr(value=ast.Yield(value=self.genExpReg[1].elt))]
860 | else:
861 | return [ast.For(target=generators[-1].target, iter=generators[-1].iter, body=[ast.If(test=self.combine_conditions(generators[-1].ifs), body=self.visit_GeneratorExp_Rec(generators[:-1]), orelse=[])] if generators[-1].ifs else self.visit_GeneratorExp_Rec(generators[:-1]), orelse=[])]
862 |
863 | def visit_GeneratorExp(self, node):
864 | try: # try may change to checking if self.genExpReg exists
865 | self.generic_visit(ast.FunctionDef(name='__' + self.genExpReg[0] + 'Generator__',
866 | args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
867 | body = self.visit_GeneratorExp_Rec(self.genExpReg[1].generators),
868 | decorator_list=[], returns=None))
869 |         except Exception:
870 | pass
871 | finally:
872 | self.genExpReg = None
873 |
874 | def visit_If(self, node):
875 | # Add the If statement at the end of the current block.
876 | self.add_stmt(self.curr_block, node)
877 |
878 | # Create a block for the code after the if-else.
879 | body_block = self.new_block()
880 | self.store_True.append(body_block.bid+1)
881 | if_block = self.add_edge(self.curr_block.bid, body_block.bid, node.test)
882 | node_list = []
883 | # New block for the body of the else if there is an else clause.
884 | store = self.curr_block.bid
885 | node_list = self.populate_If_body(node.body, node_list)
886 | if node.orelse:
887 | self.curr_block = self.add_edge(store, self.new_block().bid, self.invert(node.test))
888 |
889 | # Visit the children in the body of the else to populate the block.
890 | node_list = self.populate_If_body(node.orelse, node_list)
891 | else:
892 | node_list.append(store)
893 | # Visit children to populate the if block.
894 | self.curr_block = if_block
895 | afterif_block = self.new_block()
896 | for bid in node_list:
897 | self.add_edge(bid, afterif_block.bid)
898 | # Continue building the CFG in the after-if block.
899 | self.curr_block = afterif_block
900 |
901 | def visit_IfExp_Rec(self, node: Type[ast.AST]) -> List[Type[ast.AST]]:
902 | return [ast.If(test=node.test, body=[ast.Return(value=node.body)], orelse=self.visit_IfExp_Rec(node.orelse) if type(node.orelse) == ast.IfExp else [ast.Return(value=node.orelse)])]
903 |
904 | def visit_IfExp(self, node):
905 | if self.ifExp:
906 | self.generic_visit(ast.Module(self.visit_IfExp_Rec(node)))
907 |
908 | def visit_Lambda(self, node): # deprecated since there is autopep8
909 | self.add_subgraph(ast.FunctionDef(name=self.lambdaReg[0], args=node.args, body = [ast.Return(value=node.body)], decorator_list=[], returns=None))
910 | self.lambdaReg = None
911 |
912 | def visit_ListComp_Rec(self, generators: List[Type[ast.AST]]) -> List[Type[ast.AST]]:
913 | if not generators:
914 | self.generic_visit(self.listCompReg[1].elt) # the location of the node may be wrong
915 | if self.listCompReg[0]: # bug if there is else statement in comprehension
916 | return [ast.Expr(value=ast.Call(func=ast.Attribute(value=ast.Name(id=self.listCompReg[0], ctx=ast.Load()), attr='append', ctx=ast.Load()), args=[self.listCompReg[1].elt], keywords=[]))]
917 | else:
918 | return [ast.Expr(value=self.listCompReg[1].elt)]
919 | else:
920 | return [ast.For(target=generators[-1].target, iter=generators[-1].iter, body=[ast.If(test=self.combine_conditions(generators[-1].ifs), body=self.visit_ListComp_Rec(generators[:-1]), orelse=[])] if generators[-1].ifs else self.visit_ListComp_Rec(generators[:-1]), orelse=[])]
921 |
922 | def visit_ListComp(self, node):
923 | try: # try may change to checking if self.listCompReg exists
924 | self.generic_visit(ast.Module(self.visit_ListComp_Rec(self.listCompReg[1].generators)))
925 |         except Exception:
926 | pass
927 | finally:
928 | self.listCompReg = None
929 |
930 | def visit_Pass(self, node):
931 | self.add_stmt(self.curr_block, node)
932 |
933 | def visit_Raise(self, node):
934 | self.add_stmt(self.curr_block, node)
935 | self.curr_block = self.new_block()
936 |
937 | def visit_Return(self, node):
938 | if type(node.value) == ast.IfExp:
939 | self.ifExp = True
940 | self.generic_visit(node)
941 | self.ifExp = False
942 | else:
943 | self.add_stmt(self.curr_block, node)
944 | # self.cfg.finalblocks.append(self.curr_block)
945 | # Continue in a new block but without any jump to it -> all code after
946 | # the return statement will not be included in the CFG.
947 | self.curr_block = self.new_block()
948 |
949 | def visit_SetComp_Rec(self, generators: List[Type[ast.AST]]) -> List[Type[ast.AST]]:
950 | if not generators:
951 | self.generic_visit(self.setCompReg[1].elt) # the location of the node may be wrong
952 | if self.setCompReg[0]:
953 | return [ast.Expr(value=ast.Call(func=ast.Attribute(value=ast.Name(id=self.setCompReg[0], ctx=ast.Load()), attr='add', ctx=ast.Load()), args=[self.setCompReg[1].elt], keywords=[]))]
954 | else: # not supported yet
955 | return [ast.Expr(value=self.setCompReg[1].elt)]
956 | else:
957 | return [ast.For(target=generators[-1].target, iter=generators[-1].iter, body=[ast.If(test=self.combine_conditions(generators[-1].ifs), body=self.visit_SetComp_Rec(generators[:-1]), orelse=[])] if generators[-1].ifs else self.visit_SetComp_Rec(generators[:-1]), orelse=[])]
958 |
959 | def visit_SetComp(self, node):
960 | try: # try may change to checking if self.setCompReg exists
961 | self.generic_visit(ast.Module(self.visit_SetComp_Rec(self.setCompReg[1].generators)))
962 |         except Exception:
963 | pass
964 | finally:
965 | self.setCompReg = None
966 |
967 | def visit_Try(self, node):
968 | loop_guard = self.add_loop_block()
969 | self.curr_block = loop_guard
970 | self.add_stmt(loop_guard, ast.Try(body=[], handlers=[], orelse=[], finalbody=[]))
971 |
972 | after_try_block = self.new_block()
973 | self.add_stmt(after_try_block, ast.Name(id='handle errors', ctx=ast.Load()))
974 | self.populate_body(node.body, after_try_block.bid, "Try")
975 |
976 | self.curr_block = after_try_block
977 |
978 | if node.handlers:
979 | for handler in node.handlers:
980 | before_handler_block = self.new_block()
981 | self.curr_block = before_handler_block
982 | self.add_edge(after_try_block.bid, before_handler_block.bid, handler.type if handler.type else ast.Name(id='Error', ctx=ast.Load()))
983 |
984 | after_handler_block = self.new_block()
985 | self.add_stmt(after_handler_block, ast.Name(id='end except', ctx=ast.Load()))
986 | self.populate_body(handler.body, after_handler_block.bid, "Try")
987 | self.add_edge(after_handler_block.bid, after_try_block.bid)
988 |
989 | if node.orelse:
990 | before_else_block = self.new_block()
991 | self.curr_block = before_else_block
992 | self.add_edge(after_try_block.bid, before_else_block.bid, ast.Name(id='No Error', ctx=ast.Load()))
993 |
994 | after_else_block = self.new_block()
995 | self.add_stmt(after_else_block, ast.Name(id='end no error', ctx=ast.Load()))
996 | self.populate_body(node.orelse, after_else_block.bid, "Try")
997 | self.add_edge(after_else_block.bid, after_try_block.bid)
998 |
999 | finally_block = self.new_block()
1000 | self.curr_block = finally_block
1001 |
1002 | if node.finalbody:
1003 | self.add_edge(after_try_block.bid, finally_block.bid, ast.Name(id='Finally', ctx=ast.Load()))
1004 | after_finally_block = self.new_block()
1005 | self.populate_body(node.finalbody, after_finally_block.bid, "Try")
1006 | self.curr_block = after_finally_block
1007 | else:
1008 | self.add_edge(after_try_block.bid, finally_block.bid)
1009 |
1010 | def visit_While(self, node):
1011 | loop_guard = self.add_loop_block()
1012 | self.continue_stack.append(loop_guard)
1013 | self.curr_block = loop_guard
1014 | self.add_stmt(loop_guard, node)
1015 | # New block for the case where the test in the while is False.
1016 | afterwhile_block = self.new_block()
1017 | self.loop_stack.append(afterwhile_block)
1018 | inverted_test = self.invert(node.test)
1019 |
1020 | if not node.orelse:
1021 | # Skip shortcut loop edge if while True:
1022 |             if not (isinstance(inverted_test, ast.NameConstant) and inverted_test.value is False):
1023 | self.add_edge(self.curr_block.bid, afterwhile_block.bid, inverted_test)
1024 |
1025 | # New block for the case where the test in the while is True.
1026 | # Populate the while block.
1027 | body_block = self.new_block()
1028 | body_block.condition = True
1029 | self.curr_block = self.add_edge(self.curr_block.bid, body_block.bid, node.test)
1030 | self.populate_body(node.body, loop_guard.bid, "While")
1031 | else:
1032 | orelse_block = self.new_block()
1033 |             if not (isinstance(inverted_test, ast.NameConstant) and inverted_test.value is False):
1034 | self.add_edge(self.curr_block.bid, orelse_block.bid, inverted_test)
1035 | self.curr_block = self.add_edge(self.curr_block.bid, self.new_block().bid, node.test)
1036 |
1037 | self.populate_body(node.body, loop_guard.bid, "While")
1038 | self.curr_block = orelse_block
1039 | for child in node.orelse:
1040 | self.visit(child)
1041 | self.add_edge(orelse_block.bid, afterwhile_block.bid)
1042 |
1043 | # Continue building the CFG in the after-while block.
1044 | self.curr_block = afterwhile_block
1045 | self.loop_stack.pop()
1046 | self.continue_stack.pop()
1047 |
1048 | def visit_Yield(self, node):
1049 | self.curr_block = self.add_edge(self.curr_block.bid, self.new_block().bid)
1050 |
1051 |
1052 | class PyParser:
1053 |
1054 | def __init__(self, script):
1055 | self.script = script
1056 |
1057 | def formatCode(self):
1058 | self.script = autopep8.fix_code(self.script)
1059 |
1060 | # https://github.com/liftoff/pyminifier/blob/master/pyminifier/minification.py
1061 | def removeCommentsAndDocstrings(self):
1062 |         io_obj = io.StringIO(self.script)
1063 | out = ""
1064 | prev_toktype = tokenize.INDENT
1065 | last_lineno = -1
1066 | last_col = 0
1067 | for tok in tokenize.generate_tokens(io_obj.readline):
1068 | token_type = tok[0]
1069 | token_string = tok[1]
1070 | start_line, start_col = tok[2]
1071 | end_line, end_col = tok[3]
1072 | if start_line > last_lineno:
1073 | last_col = 0
1074 | if start_col > last_col:
1075 | out += (" " * (start_col - last_col))
1076 | # Remove comments:
1077 | if token_type == tokenize.COMMENT:
1078 | pass
1079 | # This series of conditionals removes docstrings:
1080 | elif token_type == tokenize.STRING:
1081 | if prev_toktype != tokenize.INDENT:
1082 | # This is likely a docstring; double-check we're not inside an operator:
1083 | if prev_toktype != tokenize.NEWLINE:
1084 | # Note regarding NEWLINE vs NL: The tokenize module
1085 | # differentiates between newlines that start a new statement
1086 |                         # and newlines inside of operators such as parens, brackets,
1087 |                         # and curly braces. Newlines that end a statement are
1088 |                         # NEWLINE, while newlines inside of operators are NL.
1089 | # Catch whole-module docstrings:
1090 | if start_col > 0:
1091 | # Unlabelled indentation means we're inside an operator
1092 | out += token_string
1093 | # Note regarding the INDENT token: The tokenize module does
1094 | # not label indentation inside of an operator (parens,
1095 | # brackets, and curly braces) as actual indentation.
1096 | # For example:
1097 | # def foo():
1098 | # "The spaces before this docstring are tokenize.INDENT"
1099 | # test = [
1100 | # "The spaces before this string do not get a token"
1101 | # ]
1102 | else:
1103 | out += token_string
1104 | prev_toktype = token_type
1105 | last_col = end_col
1106 | last_lineno = end_line
1107 | self.script = out
1108 |
1109 | if __name__ == '__main__':
1110 | filename = sys.argv[1]
1111 | file = filename.split('/')[-1].split('.')[0]
1112 | filepath = f'./CFG/{file}_CFG'
1113 | try:
1114 | source = open(filename, 'r').read()
1115 | compile(source, filename, 'exec')
1116 | except:
1117 | print('Error in source code')
1118 | exit(1)
1119 |
1120 | parser = PyParser(source)
1121 | parser.removeCommentsAndDocstrings()
1122 | parser.formatCode()
1123 | cfg = CFGVisitor().build(filename, ast.parse(parser.script))
1124 | cfg.clean()
1125 | cfg.track_execution()
1126 | cfg.show(filepath)
1127 |
1128 |
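To make `visit_For` concrete: it desugars `for target in iter: body` into a hidden counter `p{num}` plus a guarded `If` node, and the looping itself comes from the CFG back edge rather than from a Python loop construct. A hedged standalone sketch of that rewrite (it uses the modern `ast.Constant` and a bare subscript slice where the original builds the deprecated `ast.Num`/`ast.Index`; Python 3.9+ assumed for `ast.unparse`):

```python
import ast

src = "for i in [1, 2, 3]:\n    total += i\n"
loop = ast.parse(src).body[0]
num = 1

# p1 = 0  -- the hidden loop counter visit_For introduces
counter = ast.Assign(
    targets=[ast.Name(id=f'p{num}', ctx=ast.Store())],
    value=ast.Constant(value=0))

# if p1 < len(iter): i = iter[p1]; <body>  -- the guard placed in the loop-guard block
guard = ast.If(
    test=ast.Compare(
        left=ast.Name(id=f'p{num}', ctx=ast.Load()),
        ops=[ast.Lt()],
        comparators=[ast.Call(func=ast.Name(id='len', ctx=ast.Load()),
                              args=[loop.iter], keywords=[])]),
    body=[ast.Assign(
        targets=[ast.Name(id=loop.target.id, ctx=ast.Store())],
        value=ast.Subscript(value=loop.iter,
                            slice=ast.Name(id=f'p{num}', ctx=ast.Load()),
                            ctx=ast.Load()))] + loop.body,
    orelse=[])

module = ast.Module(body=[counter, guard], type_ignores=[])
ast.fix_missing_locations(module)
print(ast.unparse(module))
```

The `p{num} += 1` increment and the back edge to the loop guard are added separately by `visit_For`, which is why the sketch shows only a single guarded pass.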
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def parse():
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument(
6 | "-f", "--fff", help="a dummy argument to fool ipython", default="1"
7 | )
8 | parser.add_argument(
9 | "--batch_size", type=int, default=256, help="Input batch size for training"
10 | )
11 | parser.add_argument(
12 | "--hidden_dim", type=int, default=128, help="Dimension of hidden states"
13 | )
14 | parser.add_argument(
15 | "--vocab_size", type=int, default=100000, help="Vocab size for training"
16 | )
17 | parser.add_argument(
18 | "--max_node", type=int, default=100, help="Maximum number of nodes"
19 | )
20 | parser.add_argument(
21 | "--max_token", type=int, default=512, help="Maximum number of tokens"
22 | )
23 | parser.add_argument(
24 | "--learning_rate", type=float, default=0.001, help="Learning rate"
25 | )
26 | parser.add_argument("--num_epoch", type=int, default=600, help="Epochs for training")
27 | parser.add_argument(
28 | "--cpu", action="store_true", default=False, help="Disables CUDA training"
29 | )
30 | parser.add_argument("--gpu", type=int, default=0, help="GPU ID for CUDA training")
31 | parser.add_argument(
32 | "--save-model",
33 | action="store_true",
34 | default=False,
35 | help="For Saving the current Model",
36 | )
37 | parser.add_argument("--extra_aggregate", action="store_true", default=False)
38 | parser.add_argument("--delete_redundant_node", action="store_true", default=False)
39 | parser.add_argument("--output", type=str, default="output")
40 | parser.add_argument("--checkpoint", type=int, default=None)
41 | parser.add_argument("--dryrun", action="store_true", default=False)
42 |     # List of epochs at which to run visualization
43 | parser.add_argument(
44 | "--vis_epoch",
45 | nargs="+",
46 | type=int,
47 | default=[],
48 | help="Epochs for visualization",
49 | )
50 |     # List of code files to visualize
51 | parser.add_argument(
52 | "--vis_code",
53 | nargs="+",
54 | type=int,
55 | default=[],
56 | help="Code to visualize",
57 | )
58 | parser.add_argument("--groundtruth", action="store_true", default=False)
59 | parser.add_argument("--name_exp", type=str, default=None)
60 | parser.add_argument("--cuda_num", type=int, default=None)
61 | parser.add_argument("--seed", type=int, default=300103)
62 | parser.add_argument("--alpha", type=float, default=0.5)
63 | parser.add_argument("--epoch", type=int, default=None)
64 | parser.add_argument("--time", type=int, default=60)
65 | parser.add_argument("--data", type=str, default="CodeNet")
66 | parser.add_argument("--runtime_detection", action="store_true", default=False)
67 | parser.add_argument("--bug_localization", action="store_true", default=False)
68 | parser.add_argument("--claude_api_key", type=str, default=None)
69 | parser.add_argument("--model", type=str, default='claude-3-5-sonnet-20240620')
70 | parser.add_argument("--folder_path", type=str, default='fuzz_testing_dataset')
71 | args = parser.parse_args()
72 | return args
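For reference, the list-valued flags such as `--vis_epoch` and `--vis_code` rely on argparse's `nargs="+"`, which collects one or more space-separated values into a list; a minimal sketch:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--vis_epoch", nargs="+", type=int, default=[],
                    help="Epochs for visualization")

# Space-separated values are collected into a list of ints.
args = parser.parse_args(["--vis_epoch", "100", "200", "300"])
print(args.vis_epoch)  # → [100, 200, 300]
```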
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | from torchtext import data
2 | from torchtext.data import Iterator
3 | import pandas as pd
4 | from ast import literal_eval
5 | import torch
6 |
7 | def read_data(data_path, fields):
8 |     csv_data = pd.read_csv(data_path, chunksize=100)
9 |     all_examples = []
10 |     for n, chunk in enumerate(csv_data):
11 |
12 |         examples = chunk.apply(lambda r: data.Example.fromlist([literal_eval(r['nodes']), literal_eval(r['forward']), literal_eval(r['backward']),
13 |                                                                 literal_eval(r['target'])], fields), axis=1)
14 | all_examples.extend(list(examples))
15 | return all_examples
16 |
17 | def get_iterators(args, device):
18 | TEXT = data.Field(tokenize=lambda x: x.split()[:args.max_token])
19 | NODE = data.NestedField(TEXT, preprocessing=lambda x: x[:args.max_node], include_lengths=True)
20 | ROW = data.Field(pad_token=1.0, use_vocab=False,
21 | preprocessing=lambda x: [1, 1] if any(i > args.max_node for i in x) else x)
22 | EDGE = data.NestedField(ROW)
23 | TARGET = data.Field(use_vocab=False, preprocessing=lambda x: x[:args.max_node], pad_token=0, batch_first=True)
24 |
25 | fields = [("nodes", NODE), ("forward", EDGE), ("backward", EDGE), ("target", TARGET)]
26 |
27 | print('Read data...')
28 | examples = read_data(f'data/{args.data}_train.csv', fields)
29 | train = data.Dataset(examples, fields)
30 | NODE.build_vocab(train, max_size=args.vocab_size)
31 |
32 | examples = read_data(f'data/{args.data}_test.csv', fields)
33 | test = data.Dataset(examples, fields)
34 | train_iter = Iterator(train,
35 | batch_size=args.batch_size,
36 | device=device,
37 | sort=False,
38 | sort_key=lambda x: len(x.nodes),
39 | sort_within_batch=False,
40 | repeat=False)
41 | test_iter = Iterator(test, batch_size=args.batch_size, device=device, train=False,
42 | sort=False, sort_key=lambda x: len(x.nodes), sort_within_batch=False, repeat=False, shuffle=False)
43 | print("Done")
44 | return train_iter, test_iter
45 |
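The chunked-reading pattern in `read_data` can be exercised on an in-memory CSV; the column layout below is a small invented stand-in for the real `data/*.csv` files, and `ast.literal_eval` is the safe way to turn the serialized list columns back into Python objects:

```python
import io
from ast import literal_eval

import pandas as pd

# Tiny in-memory CSV; the columns mimic (in invented form) the serialized
# list layout of the dataset files.
csv = io.StringIO(
    "nodes,forward,target\n"
    "\"['BEGIN', 'EXIT']\",\"[(1, 2)]\",\"[1, 1]\"\n"
)

rows = []
for chunk in pd.read_csv(csv, chunksize=100):  # yields DataFrames of <= 100 rows
    for _, r in chunk.iterrows():
        rows.append((literal_eval(r["nodes"]),
                     literal_eval(r["forward"]),
                     literal_eval(r["target"])))
print(rows[0])
```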
--------------------------------------------------------------------------------
/fuzz_testing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import anthropic
3 | import ast, astor
4 | from cfg import *
5 | import re
6 | import sys
7 | import trace_execution
8 | import os
9 | import io
10 | import pandas as pd
11 | from torchtext import data
12 | from torchtext.data import Iterator
13 | import pandas as pd
14 | import torch
15 | import model
16 | import time
17 | import config
18 | import argparse
19 |
20 | def generate_prompt(method_code, feedback=""):
21 | prompt = f"""
22 | \n\nHuman: You are a terminal. Analyze the following Python code and generate likely inputs for all variables that might raise errors. Add these generated inputs at the beginning of the code snippet.
23 |
24 | Example:
25 | Python Method:
26 | if(S[0]=="A" and S[2,-1].count("C")==1):
27 | cnt=0
28 | for i in S:
29 | if(97<=ord(i) and ord(i)<=122):
30 | cnt+=1
31 | if(cnt==2):
32 | print("AC")
33 | else :
34 | print("WA")
35 | else :
36 | print("WA")
37 |
38 | Generated Input:
39 | S = 'AtCoder'
40 |
41 | Task:
42 | Given the following Python method, generate likely inputs for variables:
43 | {feedback}
44 |
45 | Python Method:
46 | {method_code}
47 |
48 | Generated Input:
49 | (No explanation needed, only one Generated Input:)
50 | \n\nAssistant:
51 | """
52 | return prompt
53 |
54 | def get_generated_inputs(claude_api_key, model, method_code, feedback=""):
55 | client = anthropic.Anthropic(api_key=claude_api_key)
56 | prompt = generate_prompt(method_code, feedback)
57 | response = client.messages.create(
58 | model=model,
59 | max_tokens=1024,
60 | messages=[
61 | {"role": "user", "content": prompt}
62 | ]
63 | )
64 | return response.content[0].text
65 |
66 | def add_generated_inputs_to_code(code, inputs):
67 | lines = code.split('\n')
68 | # Find the first non-import line
69 | insert_index = 0
70 | for i, line in enumerate(lines):
71 |         if line.strip() and not line.startswith(('import', 'from')):
72 | insert_index = i
73 | break
74 |
75 | # Insert generated inputs at the found index
76 | for input_line in inputs.split('\n'):
77 | if input_line.startswith("Generated"):
78 | continue
79 | if input_line.strip():
80 | lines.insert(insert_index, input_line)
81 | insert_index += 1
82 |
83 | return '\n'.join(lines)
84 |
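A standalone sketch of the insertion step (logic mirrored from `add_generated_inputs_to_code` so it can run by itself; blank-line handling simplified): generated assignments are spliced in right after the leading import block.

```python
def insert_inputs(code: str, inputs: str) -> str:
    lines = code.split('\n')
    # Find the first non-empty, non-import line.
    insert_index = 0
    for i, line in enumerate(lines):
        if line.strip() and not line.startswith(('import', 'from')):
            insert_index = i
            break
    # Splice in each generated assignment, skipping the "Generated ..." label.
    for input_line in inputs.split('\n'):
        if input_line.strip() and not input_line.startswith('Generated'):
            lines.insert(insert_index, input_line)
            insert_index += 1
    return '\n'.join(lines)

before = "import sys\nprint(S)\n"
after = insert_inputs(before, "Generated Input:\nS = 'AtCoder'")
print(after)
```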
85 | def read_data(data_path, fields):
86 | csv_data = pd.read_csv(data_path, chunksize=100)
87 | all_examples = []
88 | for n, chunk in enumerate(csv_data):
89 |         examples = chunk.apply(lambda r: data.Example.fromlist([ast.literal_eval(r['nodes']), ast.literal_eval(r['forward']), ast.literal_eval(r['backward']),
90 |                                                                 ast.literal_eval(r['target'])], fields), axis=1)
91 | all_examples.extend(list(examples))
92 | return all_examples
93 |
94 | opt = config.parse()
95 | if opt.claude_api_key is None:
96 |     raise ValueError("Missing Claude API key: pass --claude_api_key")
97 | if opt.cuda_num is None:
98 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
99 | else:
100 | device = torch.device(f"cuda:{opt.cuda_num}" if torch.cuda.is_available() else "cpu")
101 |
102 | TEXT = data.Field(tokenize=lambda x: x.split()[:512])
103 | NODE = data.NestedField(TEXT, preprocessing=lambda x: x[:100], include_lengths=True)
104 | ROW = data.Field(pad_token=1.0, use_vocab=False,
105 | preprocessing=lambda x: [1, 1] if any(i > 100 for i in x) else x)
106 | EDGE = data.NestedField(ROW)
107 | TARGET = data.Field(use_vocab=False, preprocessing=lambda x: x[:100], pad_token=0, batch_first=True)
108 |
109 | fields = [("nodes", NODE), ("forward", EDGE), ("backward", EDGE), ("target", TARGET)]
110 |
111 | print('Read data...')
112 | examples = read_data(f'data/FixEval_complete_train.csv', fields)
113 | train = data.Dataset(examples, fields)
114 | NODE.build_vocab(train, max_size=100000)
115 |
116 | orin_nodes = ['BEGIN', "_in = ['2', 3]", 'cont_str = _in[0] * _in[1]', 'cont_num = int(cont_str)', 'sqrt_flag = False', 'p1 = 0', 'p1 < len(range(4, 100))', 'T i = range(4, 100)[p1]', 'sqrt_flag', 'sqrt = i * i', 'cont_num == sqrt', 'T sqrt_flag = True', 'p1 += 1', "T print('Yes')", "print('No')", 'EXIT']
117 | orin_fwd_edges = [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (7, 9), (8, 10), (10, 11), (11, 12), (12, 9), (9, 14), (9, 15), (11, 13), (14, 16), (15, 16)]
118 | orin_back_edges = [(13, 7)]
119 | orin_exe_path = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1]
120 |
121 | net = model.CodeFlow(opt).to(device)
122 | checkpoint_path = f"checkpoints/checkpoints_{opt.checkpoint}/epoch-{opt.epoch}.pt"
123 | net.load_state_dict(torch.load(checkpoint_path, map_location=device))
124 | net.eval()
125 |
126 | outpath = 'fuzz_testing_output'
127 | if not os.path.exists(outpath):
128 | os.makedirs(outpath)
129 |
130 | def extract_inputs(generated_text):
131 | # Use regular expression to match lines that are variable assignments
132 | input_lines = re.findall(r'^\s*\w+\s*=\s*.+$', generated_text, re.MULTILINE)
133 | return '\n'.join(input_lines)
134 |
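`extract_inputs` keeps only assignment-shaped lines from the model's reply; the regex can be checked in isolation (mirrored here so the snippet runs standalone):

```python
import re

def extract_inputs(generated_text: str) -> str:
    # Keep lines that look like `name = value`; drop prose and labels.
    input_lines = re.findall(r'^\s*\w+\s*=\s*.+$', generated_text, re.MULTILINE)
    return '\n'.join(input_lines)

reply = "Generated Input:\nS = 'AtCoder'\nn = 3\nHope this helps!"
print(extract_inputs(reply))
```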
135 | error_dict = {}
136 | locate = 0
137 | for root, _, files in os.walk(opt.folder_path):
138 | files = sorted(files, key=lambda x: int(x.split('.')[0][5:]))
139 | for file in files:
140 | print(f'Fuzz testing file {file}')
141 | feedback_list = []
142 | start_time = time.time()
143 | time_limit = opt.time # time limit in seconds
144 | repeat = True
145 | while repeat:
146 | if time.time() - start_time > time_limit:
147 | print(f'Time limit exceeded for file {file}')
148 | break
149 |             feedback = f"\nThese inputs did not raise runtime errors; avoid generating the same ones again:\n{feedback_list}" if feedback_list else ""
150 | file_path = os.path.join(opt.folder_path, file)
151 | with open(file_path, 'r') as f:
152 | code = f.read()
153 | generated_inputs = get_generated_inputs(opt.claude_api_key, opt.model, code, feedback)
154 | generated_inputs = extract_inputs(generated_inputs)
155 | print(generated_inputs)
156 | # Add generated inputs to the original code
157 | modified_code = add_generated_inputs_to_code(code, generated_inputs)
158 | filename = os.path.join(outpath, file)
159 | with open(filename, 'w') as modified_file:
160 | modified_file.write(modified_code)
161 |
162 | BlockId().counter = 0
163 | try:
164 | source = open(filename, 'r').read()
165 | compile(source, filename, 'exec')
166 |             except Exception:
167 | print('Error in source code')
168 | exit(1)
169 | parser = PyParser(source)
170 | parser.removeCommentsAndDocstrings()
171 | parser.formatCode()
172 | try:
173 | cfg = CFGVisitor().build(filename, ast.parse(parser.script))
174 | except AttributeError:
175 | continue
176 | except IndentationError:
177 | continue
178 | except TypeError:
179 | continue
180 | except SyntaxError:
181 | continue
182 |
183 | cfg.clean()
184 | try:
185 | cfg.track_execution()
186 | except Exception:
187 | print("Generated input is not valid")
188 | continue
189 | code = {}
190 | for_loop = {}
191 | for i in cfg.blocks:
192 | if cfg.blocks[i].for_loop != 0:
193 | if cfg.blocks[i].for_loop not in for_loop:
194 | for_loop[cfg.blocks[i].for_loop] = [i]
195 | else:
196 | for_loop[cfg.blocks[i].for_loop].append(i)
197 | first = []
198 | second = []
199 | for i in for_loop:
200 | first.append(for_loop[i][0]+1)
201 | second.append(for_loop[i][1])
202 | orin_node = []
203 | track = {}
204 | track_for = {}
205 | for i in cfg.blocks:
206 | if cfg.blocks[i].stmts_to_code():
207 | if int(i) == 1:
208 | st = 'BEGIN'
209 | elif int(i) == len(cfg.blocks):
210 | st = 'EXIT'
211 | else:
212 | if i in first:
213 | line = astor.to_source(cfg.blocks[i].for_name)
214 | st = line.split('\n')[0]
215 | st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "")
216 | else:
217 | st = cfg.blocks[i].stmts_to_code()
218 | st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "")
219 | orin_node.append([i, st, None])
220 | if st not in track:
221 | track[st] = [len(orin_node)-1]
222 | else:
223 | track[st].append(len(orin_node)-1)
224 | track_for[i] = len(orin_node)-1
225 | with open(filename, 'r') as file_open:
226 | lines = file_open.readlines()
227 | for i in range(1, len(lines)+1):
228 | line = lines[i-1]
229 | # strip the trailing newline and remove all whitespace
230 | line = line.strip()
231 | line = re.sub(r"\s+", "", line).replace('"', "'").replace("(", "").replace(")", "")
232 | if line.startswith('elif'):
233 | line = line[2:]
234 | if line in track:
235 | orin_node[track[line][0]][2] = i
236 | if orin_node[track[line][0]][0] in first:
237 | orin_node[track[line][0]-1][2] = i-0.4
238 | orin_node[track[line][0]+1][2] = i+0.4
239 | if len(track[line]) > 1:
240 | track[line].pop(0)
241 | for i in second:
242 | max_val = 0
243 | for edge in cfg.edges:
244 | if edge[0] == i:
245 | if orin_node[track_for[edge[1]]][2] > max_val:
246 | max_val = orin_node[track_for[edge[1]]][2]
247 | if edge[1] == i:
248 | if orin_node[track_for[edge[0]]][2] > max_val:
249 | max_val = orin_node[track_for[edge[0]]][2]
250 | orin_node[track_for[i]][2] = max_val + 0.5
251 | orin_node[0][2] = 0
252 | orin_node[-1][2] = len(lines)+1
253 | # sort orin_node by the third element
254 | orin_node.sort(key=lambda x: x[2])
255 |
256 | nodes = []
257 | matching = {}
258 | for i in cfg.blocks:
259 | if cfg.blocks[i].stmts_to_code():
260 | if int(i) == 1:
261 | nodes.append('BEGIN')
262 | elif int(i) == len(cfg.blocks):
263 | nodes.append('EXIT')
264 | else:
265 | st = cfg.blocks[i].stmts_to_code()
266 | st_no_space = re.sub(r"\s+", "", st)
267 | # if the statement starts with 'if' or 'while', strip the keyword
268 | if st.startswith('if'):
269 | st = st[3:]
270 | elif st.startswith('while'):
271 | st = st[6:]
272 | if cfg.blocks[i].condition:
273 | st = 'T '+ st
274 | if st.endswith('\n'):
275 | st = st[:-1]
276 | if st.endswith(":"):
277 | st = st[:-1]
278 | nodes.append(st)
279 | matching[i] = len(nodes)
280 |
281 | fwd_edges = []
282 | back_edges = []
283 | edges = {}
284 | for edge in cfg.edges:
285 | if edge not in cfg.back_edges:
286 | fwd_edges.append((matching[edge[0]], matching[edge[1]]))
287 | else:
288 | back_edges.append((matching[edge[0]], matching[edge[1]]))
289 | if matching[edge[0]] not in edges:
290 | edges[matching[edge[0]]] = [matching[edge[1]]]
291 | else:
292 | edges[matching[edge[0]]].append(matching[edge[1]])
293 | exe_path = [0 for i in range(len(nodes))]
294 | for i in range(len(cfg.path)):
295 | if cfg.path[i] == 1:
296 | exe_path[matching[i+1]-1] = 1
297 | out_nodes=[nodes, orin_nodes]
298 | out_fw_path=[fwd_edges, orin_fwd_edges]
299 | out_back_path=[back_edges, orin_back_edges]
300 | out_exe_path=[exe_path, orin_exe_path]
301 | data_example = {
302 | 'nodes': out_nodes,
303 | 'forward': out_fw_path,
304 | 'backward': out_back_path,
305 | 'target': out_exe_path,
306 | }
307 |
308 | df = pd.DataFrame(data_example)
309 | # Save to CSV
310 | df.to_csv(f'{outpath}/output.csv', index=False, quoting=1)
311 | examples = read_data(f'{outpath}/output.csv', fields)
312 | test = data.Dataset(examples, fields)
313 | test_iter = Iterator(test, batch_size=2, device=device, train=False,
314 | sort=False, sort_key=lambda x: len(x.nodes), sort_within_batch=False, repeat=False, shuffle=False)
315 | with torch.no_grad():
316 | for batch in test_iter:
317 | x, edges, target = batch.nodes, (batch.forward, batch.backward), batch.target.float()
318 | if isinstance(x, tuple):
319 | pred = net(x[0], edges, x[1], x[2])
320 | else:
321 | pred = net(x, edges)
322 | pred = pred[0].squeeze()
323 | pred = (pred > opt.beta).float()
324 | if pred[len(nodes)-1] == 1:
325 | print("No Runtime Error")
326 | feedback_list.append(generated_inputs)
327 | else:
328 | mask_pred = pred[:len(nodes)] == 1
329 | indices_pred = torch.nonzero(mask_pred).flatten()
330 | farthest_pred = indices_pred.max().item()
331 | error_line = nodes[farthest_pred]
332 | print(f"Runtime Error in line: {error_line}")
333 |
334 | mask_target = target[0][:len(nodes)] == 1
335 | indices_target = torch.nonzero(mask_target).flatten()
336 | farthest_target = indices_target.max().item()
337 | true_error_line = nodes[farthest_target]
338 | error_dict[file] = [error_line, true_error_line]
339 |
340 | if farthest_pred == farthest_target:
341 | locate += 1
342 | repeat = False
343 |
344 | locate_true = locate/len(error_dict)*100
345 | print(f'Fuzz testing within {opt.time}s')
346 | print(f'Successfully detected: {len(error_dict)}/{len(files)}')
347 | print(f'Bug Localization Acc: {locate_true:.2f}%')
348 | print(error_dict)
349 |
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_1.py:
--------------------------------------------------------------------------------
1 | ans=0
2 | cur=0
3 | ACGT=set("A","C","G","T")
4 | for ss in s:
5 | if ss in ACGT:
6 | cur+=1
7 | else:
8 | ans=max(cur,ans)
9 | cur=0
10 | print(max(ans,cur))
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_10.py:
--------------------------------------------------------------------------------
1 | S = 'nikoandsolstice'
2 | s = len(S)
3 | if (s <= K):
4 | print(S)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_11.py:
--------------------------------------------------------------------------------
1 | s = 'CSS'
2 | s=0
3 | for i in range(len(s)):
4 | if s[i]==t[i]:
5 | s+=1
6 | print(s)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_12.py:
--------------------------------------------------------------------------------
1 | col = N / 2
2 | if col == 0:
3 | print(col)
4 | else:
5 | col += 1
6 | print(col)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_13.py:
--------------------------------------------------------------------------------
1 | x=a+a*a+a*a*a
2 | print(a)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_14.py:
--------------------------------------------------------------------------------
1 | cont_str = _in[0] * _in[1]
2 | cont_num = int(cont_str)
3 | sqrt_flag = False
4 | for i in range(4, 100):
5 | sqrt = i * i
6 | if cont_num == sqrt:
7 | sqrt_flag = True
8 | break
9 | if sqrt_flag:
10 | print('Yes')
11 | else:
12 | print('No')
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_15.py:
--------------------------------------------------------------------------------
1 | if s >= 3200:
2 | print(s)
3 | else:
4 | print("red")
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_16.py:
--------------------------------------------------------------------------------
1 | if (s[:4] == 2019 and s[5] == 0 and s[6] <= 4) or (s[:4] < 2019):
2 | print("Heisei")
3 | else:
4 | print("TBD")
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_17.py:
--------------------------------------------------------------------------------
1 | b = a / 3
2 | print(b)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_18.py:
--------------------------------------------------------------------------------
1 | import math
2 | m=N+1
3 | for j in range(2,math.sqrt(N)):
4 | if N%j==0:
5 | a=j
6 | b=N//j
7 | if a>b:
8 | break
9 | else:
10 | m=min(m,(a+b))
11 | print(m-2)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_19.py:
--------------------------------------------------------------------------------
1 | A = 2
2 | B = ['i', 'p', ' ', 'c', 'c']
3 | C=[]
4 | for i in range(A):
5 | C=A[i]
6 | C=B[i]
7 | print(C)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_2.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | for i in range(sqrt(n),0,-1):
3 | if n%i==0:
4 | j=n//i
5 | print(i+j-2)
6 | break
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_20.py:
--------------------------------------------------------------------------------
1 | a = [1, 1, 2, 2, 2]
2 | t=a[0]
3 | s=0
4 | Ans=0
5 | for i in range(1,n):
6 | if t==a[i]:
7 | s+=1
8 | else:
9 | Ans+=s//2
10 | s=0
11 | t=a[i]
12 | print(Ans)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_21.py:
--------------------------------------------------------------------------------
1 | if N % 1000 == 0:
2 | print(0)
3 | else:
4 | sen = N//1000
5 | sen2 = sen*1000
6 | atai = N-sen2
7 | print(atai)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_22.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | d = defaultdict(str)
3 | no_flag = False
4 | for i in range(len(S)):
5 | if S[i] != T[i]:
6 | if d[S[i]] == T[i]:
7 | S[i] = T[i]
8 | pass
9 | elif d[S[i]] == "" and d[T[i]] == "":
10 | d[S[i]] = T[i]
11 | d[T[i]] = S[i]
12 | else:
13 | pass
14 | S = sorted(S)
15 | T = sorted(T)
16 | if S == T:
17 | print("Yes")
18 | else:
19 | print("No")
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_23.py:
--------------------------------------------------------------------------------
1 | if a == b & b == c:
2 | ans = 0
3 | else:
4 | for count in range(n):
5 | tmp = 0
6 | if a[count] == b[count]:
7 | if a[count] != c[count]:
8 | tmp += 1
9 | else:
10 | tmp += 1
11 | if (a[count] != c[count]) & (b[count] != c[count]):
12 | tmp += 1
13 | ans += tmp
14 | print(ans)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_24.py:
--------------------------------------------------------------------------------
1 | err = []
2 | for i in range(len(s) - 3):
3 | err.append(abs(int(s[i:i+3]), 753))
4 | print(min(err))
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_25.py:
--------------------------------------------------------------------------------
1 | if(s_input[1] == '+'):
2 | print(int(s_input[0]+int(s_input[2])))
3 | else:
4 | print(int(s_input[0]-int(s_input[2])))
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_26.py:
--------------------------------------------------------------------------------
1 | for i in range(0,len(b)):
2 | s+=int(b[i])
3 | if s-max(b)>max(b):
4 | print("Yes")
5 | else:
6 | print("No")
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_27.py:
--------------------------------------------------------------------------------
1 | S[3]='7'
2 | print(S)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_28.py:
--------------------------------------------------------------------------------
1 | for i in range(len(s)):
2 | if s[i]=='1':
3 | s[i]='9'
4 | else:
5 | s[i]='1'
6 | print(s)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_29.py:
--------------------------------------------------------------------------------
1 | if N%2==0:
2 | print(N/2)
3 | else:
4 | print((N+1)/2)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_3.py:
--------------------------------------------------------------------------------
1 | s = 'abcabc'
2 | print(s[:n/2]==s[n/2:] and n%2==0)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_30.py:
--------------------------------------------------------------------------------
1 | size=len(S)
2 | print(S[0]+int(int(size)-2)+S[-1])
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_31.py:
--------------------------------------------------------------------------------
1 | if N % 2 == 1:
2 | print('No')
3 | exit()
4 | if S[:N/2] == S[N/2:]:
5 | print('Yes')
6 | else:
7 | print('No')
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_32.py:
--------------------------------------------------------------------------------
1 | if a in 0:
2 | total = 0
3 | else:
4 | for i in a:
5 | total *= i
6 | if total > 10**18:
7 | total = -1
8 | print(total)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_33.py:
--------------------------------------------------------------------------------
1 | if n[1] == n[2] and n[0] == [1] or n[1] == n[2] and n[2] == n[3]:
2 | print('Yes')
3 | else:
4 | print('No')
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_34.py:
--------------------------------------------------------------------------------
1 | for i in range(lst):
2 | if lst[i] == 0:
3 | ans = i+1
4 | print(ans)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_35.py:
--------------------------------------------------------------------------------
1 | print(a**3)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_36.py:
--------------------------------------------------------------------------------
1 | if a[3]!=8:
2 | a[3]=8
3 | print(a)
4 | else:
5 | print(a)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_37.py:
--------------------------------------------------------------------------------
1 | print(m % 1000)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_38.py:
--------------------------------------------------------------------------------
1 | res = a//2
2 | if a%2 >0:
3 | res += 1
4 | print(res)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_39.py:
--------------------------------------------------------------------------------
1 | for i in A:
2 | if A%2==0:
3 | if not A%3==0 or A%5==0:
4 | print("DENIED")
5 | x=1
6 | break
7 | else:
8 | continue
9 | if x !=0:
10 | print("APPROVED")
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_4.py:
--------------------------------------------------------------------------------
1 | G = '2017'
2 | print(2*G-R)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_40.py:
--------------------------------------------------------------------------------
1 | a = ord(s)
2 | b = ord('ABC')
3 | r = ord('ARC')
4 | if a == b:
5 | print('ARC')
6 | elif a == b:
7 | print('ABC')
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_41.py:
--------------------------------------------------------------------------------
1 | if S[0]!="1":
2 | print(S[0])
3 | else:
4 | print(S[1])
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_42.py:
--------------------------------------------------------------------------------
1 | n[3] = '8'
2 | print(str(n))
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_43.py:
--------------------------------------------------------------------------------
1 | k = '4'
2 | flag = True
3 | for i in range(k):
4 | if s[i] == '1':
5 | print(s[i])
6 | flag = False
7 | if flag:
8 | print(s[1])
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_44.py:
--------------------------------------------------------------------------------
1 | S[4] = '8'
2 | print(S)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_45.py:
--------------------------------------------------------------------------------
1 | S = 'abcabc'
2 | n = int(N/2)
3 | s1 = S[:n]
4 | s2 = S[n:]
5 | if s1 == s2:
6 | print('Yes')
7 | else:
8 | print('No')
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_46.py:
--------------------------------------------------------------------------------
1 | d = a.split()
2 | b = d[0]*d[1]
3 | if(b%2==0):
4 | print('Even')
5 | else:
6 | print('Odd')
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_47.py:
--------------------------------------------------------------------------------
1 | import sys
2 | input = sys.stdin.readline
3 | n = 5
4 | for i in a:
5 | print(a.count(str(i+1)))
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_48.py:
--------------------------------------------------------------------------------
1 | print("Christmas",["Eve "] * 25 - N)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_49.py:
--------------------------------------------------------------------------------
1 | a = [1, 2, 3, 4, 5, 6]
2 | memo = sum(a)
3 | a=a[0]
4 | b=memo-a[0]
5 | ans = abs(a-b)
6 | for i in range(1,n-1):
7 | a += a[i]
8 | b -= a[i]
9 | ans = min(ans,abs(a-b))
10 | print(ans)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_5.py:
--------------------------------------------------------------------------------
1 | a = r*r
2 | print(a)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_50.py:
--------------------------------------------------------------------------------
1 | if 7 in n:
2 | print("Yes")
3 | else:
4 | print("No")
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_6.py:
--------------------------------------------------------------------------------
1 | ans = 10**12
2 | for k in range(1,(n+1)**0.5):
3 | if n%k == 0 :
4 | m = n//k + k - 2
5 | if ans > m:
6 | ans = m
7 | else:
8 | print(ans)
9 | sys.exit()
10 | print(ans)
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_7.py:
--------------------------------------------------------------------------------
1 | s = 'abcabc'
2 | if n%2==1:
3 | print('No')
4 | else:
5 | if s[:n/2]==s[n/2:]:
6 | print('Yes')
7 | else:
8 | print('No')
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_8.py:
--------------------------------------------------------------------------------
1 | num = "".join(num)
2 | if num % 4 == 0:
3 | print("YES")
4 | else:
5 | print("NO")
--------------------------------------------------------------------------------
/fuzz_testing_dataset/code_9.py:
--------------------------------------------------------------------------------
1 | s=str(n)
2 | array = list(map(int,s))
3 | if array%9==0:
4 | print("Yes")
5 | else:
6 | print("No")
--------------------------------------------------------------------------------
/generate_dataset/generate_dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import ast, astor, autopep8, tokenize, io, sys
3 | import graphviz as gv
4 | from typing import Dict, List, Tuple, Set, Optional, Type
5 | from cfg import *
6 | import re
7 | import sys
8 | import trace_execution
9 | import os
10 | import io
11 | import linecache
12 | import pandas as pd
13 | from sklearn.model_selection import train_test_split
14 | import argparse
15 |
16 | output_dir = 'dataset'
17 | inference_dir = 'inference'
18 | os.makedirs(inference_dir, exist_ok=True)
19 | files = os.listdir(output_dir)
20 | files.sort()
21 | out_nodes = []
22 | out_fw_path = []
23 | out_back_path = []
24 | out_exe_path = []
25 | out_file_names = []
26 | max_node = 0
27 | max_edge = 0
28 |
29 | for file in files:
30 | BlockId().counter = 0
31 | filename = f'./{output_dir}/' + file
32 | try:
33 | source = open(filename, 'r').read()
34 | compile(source, filename, 'exec')
35 | except Exception:
36 | print('Error in source code')
37 | continue
38 |
39 | parser = PyParser(source)
40 | parser.removeCommentsAndDocstrings()
41 | parser.formatCode()
42 | try:
43 | cfg = CFGVisitor().build(filename, ast.parse(parser.script))
44 | except AttributeError:
45 | continue
46 | except IndentationError:
47 | continue
48 | except TypeError:
49 | continue
50 | except SyntaxError:
51 | continue
52 | cfg.clean()
53 | try:
54 | cfg.track_execution()
55 | except Exception:
56 | continue
57 |
58 | code = {}
59 | for_loop = {}
60 | for i in cfg.blocks:
61 | if cfg.blocks[i].for_loop != 0:
62 | if cfg.blocks[i].for_loop not in for_loop:
63 | for_loop[cfg.blocks[i].for_loop] = [i]
64 | else:
65 | for_loop[cfg.blocks[i].for_loop].append(i)
66 | first = []
67 | second = []
68 | for i in for_loop:
69 | first.append(for_loop[i][0]+1)
70 | second.append(for_loop[i][1])
71 | orin_node = []
72 | track = {}
73 | track_for = {}
74 | for i in cfg.blocks:
75 | if cfg.blocks[i].stmts_to_code():
76 | if int(i) == 1:
77 | st = 'BEGIN'
78 | elif int(i) == len(cfg.blocks):
79 | st = 'EXIT'
80 | else:
81 | if i in first:
82 | line = astor.to_source(cfg.blocks[i].for_name)
83 | st = line.split('\n')[0]
84 | st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "")
85 | else:
86 | st = cfg.blocks[i].stmts_to_code()
87 | st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "")
88 | orin_node.append([i, st, None])
89 | if st not in track:
90 | track[st] = [len(orin_node)-1]
91 | else:
92 | track[st].append(len(orin_node)-1)
93 | track_for[i] = len(orin_node)-1
94 | with open(filename, 'r') as file_open:
95 | lines = file_open.readlines()
96 | for i in range(1, len(lines)+1):
97 | line = lines[i-1]
98 | # strip the trailing newline and remove all whitespace
99 | line = line.strip()
100 | line = re.sub(r"\s+", "", line).replace('"', "'").replace("(", "").replace(")", "")
101 | if line.startswith('elif'):
102 | line = line[2:]
103 | if line in track:
104 | orin_node[track[line][0]][2] = i
105 | if orin_node[track[line][0]][0] in first:
106 | orin_node[track[line][0]-1][2] = i-0.4
107 | orin_node[track[line][0]+1][2] = i+0.4
108 | if len(track[line]) > 1:
109 | track[line].pop(0)
110 | for i in second:
111 | max_val = 0
112 | for edge in cfg.edges:
113 | if edge[0] == i:
114 | if orin_node[track_for[edge[1]]][2] > max_val:
115 | max_val = orin_node[track_for[edge[1]]][2]
116 | if edge[1] == i:
117 | if orin_node[track_for[edge[0]]][2] > max_val:
118 | max_val = orin_node[track_for[edge[0]]][2]
119 | orin_node[track_for[i]][2] = max_val + 0.5
120 | orin_node[0][2] = 0
121 | orin_node[-1][2] = len(lines)+1
122 | # sort orin_node by the third element
123 | orin_node.sort(key=lambda x: x[2])
124 |
125 | nodes = []
126 | matching = {}
127 | for t in orin_node:
128 | i = t[0]
129 | if cfg.blocks[i].stmts_to_code():
130 | if int(i) == 1:
131 | nodes.append('BEGIN')
132 | elif int(i) == len(cfg.blocks):
133 | nodes.append('EXIT')
134 | else:
135 | st = cfg.blocks[i].stmts_to_code()
136 | st_no_space = re.sub(r"\s+", "", st)
137 | # if the statement starts with 'if' or 'while', strip the keyword
138 | if st.startswith('if'):
139 | st = st[3:]
140 | elif st.startswith('while'):
141 | st = st[6:]
142 | if cfg.blocks[i].condition:
143 | st = 'T '+ st
144 | if st.endswith('\n'):
145 | st = st[:-1]
146 | if st.endswith(":"):
147 | st = st[:-1]
148 | nodes.append(st)
149 | matching[i] = len(nodes)
150 |
151 | fwd_edges = []
152 | back_edges = []
153 | edges = {}
154 | for edge in cfg.edges:
155 | if edge not in cfg.back_edges:
156 | fwd_edges.append((matching[edge[0]], matching[edge[1]]))
157 | else:
158 | back_edges.append((matching[edge[0]], matching[edge[1]]))
159 | if matching[edge[0]] not in edges:
160 | edges[matching[edge[0]]] = [matching[edge[1]]]
161 | else:
162 | edges[matching[edge[0]]].append(matching[edge[1]])
163 | exe_path = [0 for i in range(len(nodes))]
164 | for node in cfg.path:
165 | exe_path[matching[node]-1] = 1
166 | # append only if an identical (nodes, edges, execution path) example is not already recorded
167 | print(f'Done processing {file}')
168 | if nodes in out_nodes:
169 | # compare fwd_edges, back_edges, and exe_path at the matching index
170 | index = out_nodes.index(nodes)
171 | if fwd_edges != out_fw_path[index] or back_edges != out_back_path[index] or exe_path != out_exe_path[index]:
172 | out_nodes.append(nodes)
173 | out_fw_path.append(fwd_edges)
174 | out_back_path.append(back_edges)
175 | out_exe_path.append(exe_path)
176 | out_file_names.append(filename)
177 | else:
178 | out_nodes.append(nodes)
179 | out_fw_path.append(fwd_edges)
180 | out_back_path.append(back_edges)
181 | out_exe_path.append(exe_path)
182 | out_file_names.append(filename)
183 |
184 | data = {
185 | 'nodes': out_nodes,
186 | 'forward': out_fw_path,
187 | 'backward': out_back_path,
188 | 'target': out_exe_path,
189 | 'file_name': out_file_names
190 | }
191 |
192 | df = pd.DataFrame(data)
193 |
194 | # Split into training and testing sets
195 | train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
196 |
197 | # Save train set to CSV
198 | train_df = train_df.drop(columns=['file_name'])
199 | train_df.to_csv('train.csv', index=False, quoting=1)
200 | print("Train CSV file has been saved successfully.")
201 |
202 | # Save test set to CSV
203 | test_file_names = test_df['file_name'].tolist()
204 | test_df = test_df.drop(columns=['file_name'])
205 | test_df.to_csv('test.csv', index=False, quoting=1)
206 | print("Test CSV file has been saved successfully.")
207 |
208 | # Save each test file with the corresponding name in the "inference" folder
209 | for i, filename in enumerate(test_file_names):
210 | with open(filename, 'r') as f:
211 | source_code = f.read()
212 | new_filename = os.path.join(inference_dir, f"code_{i + 2}.py")
213 | with open(new_filename, 'w') as f:
214 | f.write(source_code)
215 | print("Test files have been saved successfully.")
--------------------------------------------------------------------------------
/generate_dataset/trace_execution.py:
--------------------------------------------------------------------------------
1 | """program/module to trace Python program or function execution
2 |
3 | Sample use, command line:
4 | trace.py -c -f counts --ignore-dir '$prefix' spam.py eggs
5 | trace.py -t --ignore-dir '$prefix' spam.py eggs
6 | trace.py --trackcalls spam.py eggs
7 |
8 | Sample use, programmatically
9 | import sys
10 |
11 | # create a Trace object, telling it what to ignore, and whether to
12 | # do tracing or line-counting or both.
13 | tracer = trace.Trace(ignoredirs=[sys.base_prefix, sys.base_exec_prefix,],
14 | trace=0, count=1)
15 | # run the new command using the given tracer
16 | tracer.run('main()')
17 | # make a report, placing output in /tmp
18 | r = tracer.results()
19 | r.write_results(show_missing=True, coverdir="/tmp")
20 | """
21 | __all__ = ['Trace', 'CoverageResults']
22 |
23 | import io
24 | import linecache
25 | import os
26 | import sys
27 | import sysconfig
28 | import token
29 | import tokenize
30 | import inspect
31 | import gc
32 | import dis
33 | import pickle
34 | from time import monotonic as _time
35 |
36 | import threading
37 |
38 | PRAGMA_NOCOVER = "#pragma NO COVER"
39 |
40 | class _Ignore:
41 | def __init__(self, modules=None, dirs=None):
42 | self._mods = set() if not modules else set(modules)
43 | self._dirs = [] if not dirs else [os.path.normpath(d)
44 | for d in dirs]
45 | self._ignore = { '': 1 }
46 |
47 | def names(self, filename, modulename):
48 | if modulename in self._ignore:
49 | return self._ignore[modulename]
50 |
51 | # haven't seen this one before, so see if the module name is
52 | # on the ignore list.
53 | if modulename in self._mods: # Identical names, so ignore
54 | self._ignore[modulename] = 1
55 | return 1
56 |
57 | # check if the module is a proper submodule of something on
58 | # the ignore list
59 | for mod in self._mods:
60 | # Need to take some care since ignoring
61 | # "cmp" mustn't mean ignoring "cmpcache" but ignoring
62 | # "Spam" must also mean ignoring "Spam.Eggs".
63 | if modulename.startswith(mod + '.'):
64 | self._ignore[modulename] = 1
65 | return 1
66 |
67 | # Now check that filename isn't in one of the directories
68 | if filename is None:
69 | # must be a built-in, so we must ignore
70 | self._ignore[modulename] = 1
71 | return 1
72 |
73 | # Ignore a file when it contains one of the ignorable paths
74 | for d in self._dirs:
75 | # The '+ os.sep' is to ensure that d is a parent directory,
76 | # as compared to cases like:
77 | # d = "/usr/local"
78 | # filename = "/usr/local.py"
79 | # or
80 | # d = "/usr/local.py"
81 | # filename = "/usr/local.py"
82 | if filename.startswith(d + os.sep):
83 | self._ignore[modulename] = 1
84 | return 1
85 |
86 | # Tried the different ways, so we don't ignore this module
87 | self._ignore[modulename] = 0
88 | return 0
89 |
90 | def _modname(path):
91 | """Return a plausible module name for the path."""
92 |
93 | base = os.path.basename(path)
94 | filename, ext = os.path.splitext(base)
95 | return filename
96 |
97 | def _fullmodname(path):
98 | """Return a plausible module name for the path."""
99 |
100 | # If the file 'path' is part of a package, then the filename isn't
101 | # enough to uniquely identify it. Try to do the right thing by
102 | # looking in sys.path for the longest matching prefix. We'll
103 | # assume that the rest is the package name.
104 |
105 | comparepath = os.path.normcase(path)
106 | longest = ""
107 | for dir in sys.path:
108 | dir = os.path.normcase(dir)
109 | if comparepath.startswith(dir) and comparepath[len(dir)] == os.sep:
110 | if len(dir) > len(longest):
111 | longest = dir
112 |
113 | if longest:
114 | base = path[len(longest) + 1:]
115 | else:
116 | base = path
117 | # the drive letter is never part of the module name
118 | drive, base = os.path.splitdrive(base)
119 | base = base.replace(os.sep, ".")
120 | if os.altsep:
121 | base = base.replace(os.altsep, ".")
122 | filename, ext = os.path.splitext(base)
123 | return filename.lstrip(".")
124 |
125 | class CoverageResults:
126 | def __init__(self, counts=None, calledfuncs=None, infile=None,
127 | callers=None, outfile=None):
128 | self.counts = counts
129 | if self.counts is None:
130 | self.counts = {}
131 | self.counter = self.counts.copy() # map (filename, lineno) to count
132 | self.calledfuncs = calledfuncs
133 | if self.calledfuncs is None:
134 | self.calledfuncs = {}
135 | self.calledfuncs = self.calledfuncs.copy()
136 | self.callers = callers
137 | if self.callers is None:
138 | self.callers = {}
139 | self.callers = self.callers.copy()
140 | self.infile = infile
141 | self.outfile = outfile
142 | if self.infile:
143 | # Try to merge existing counts file.
144 | try:
145 | with open(self.infile, 'rb') as f:
146 | counts, calledfuncs, callers = pickle.load(f)
147 | self.update(self.__class__(counts, calledfuncs, callers=callers))
148 | except (OSError, EOFError, ValueError) as err:
149 | print(("Skipping counts file %r: %s"
150 | % (self.infile, err)), file=sys.stderr)
151 |
152 | def is_ignored_filename(self, filename):
153 | """Return True if the filename does not refer to a file
154 | we want to have reported.
155 | """
156 | return filename.startswith('<') and filename.endswith('>')
157 |
158 | def update(self, other):
159 | """Merge in the data from another CoverageResults"""
160 | counts = self.counts
161 | calledfuncs = self.calledfuncs
162 | callers = self.callers
163 | other_counts = other.counts
164 | other_calledfuncs = other.calledfuncs
165 | other_callers = other.callers
166 |
167 | for key in other_counts:
168 | counts[key] = counts.get(key, 0) + other_counts[key]
169 |
170 | for key in other_calledfuncs:
171 | calledfuncs[key] = 1
172 |
173 | for key in other_callers:
174 | callers[key] = 1
175 |
176 | def write_results(self, show_missing=True, summary=False, coverdir=None):
177 | """
178 | Write the coverage results.
179 |
180 | :param show_missing: Show lines that had no hits.
181 | :param summary: Include coverage summary per module.
182 | :param coverdir: If None, the results of each module are placed in its
183 | directory, otherwise it is included in the directory
184 | specified.
185 | """
186 | if self.calledfuncs:
187 | print()
188 | print("functions called:")
189 | calls = self.calledfuncs
190 | for filename, modulename, funcname in sorted(calls):
191 | print(("filename: %s, modulename: %s, funcname: %s"
192 | % (filename, modulename, funcname)))
193 |
194 | if self.callers:
195 | print()
196 | print("calling relationships:")
197 | lastfile = lastcfile = ""
198 | for ((pfile, pmod, pfunc), (cfile, cmod, cfunc)) \
199 | in sorted(self.callers):
200 | if pfile != lastfile:
201 | print()
202 | print("***", pfile, "***")
203 | lastfile = pfile
204 | lastcfile = ""
205 | if cfile != pfile and lastcfile != cfile:
206 | print(" -->", cfile)
207 | lastcfile = cfile
208 | print(" %s.%s -> %s.%s" % (pmod, pfunc, cmod, cfunc))
209 |
210 | # turn the counts data ("(filename, lineno) = count") into something
211 | # accessible on a per-file basis
212 | per_file = {}
213 | for filename, lineno in self.counts:
214 | lines_hit = per_file[filename] = per_file.get(filename, {})
215 | lines_hit[lineno] = self.counts[(filename, lineno)]
216 |
217 | # accumulate summary info, if needed
218 | sums = {}
219 |
220 | for filename, count in per_file.items():
221 | if self.is_ignored_filename(filename):
222 | continue
223 |
224 | if filename.endswith(".pyc"):
225 | filename = filename[:-1]
226 |
227 | if coverdir is None:
228 | dir = os.path.dirname(os.path.abspath(filename))
229 | modulename = _modname(filename)
230 | else:
231 | dir = coverdir
232 | os.makedirs(dir, exist_ok=True)
233 | modulename = _fullmodname(filename)
234 |
235 | # If desired, get a list of the line numbers which represent
236 | # executable content (returned as a dict for better lookup speed)
237 | if show_missing:
238 | lnotab = _find_executable_linenos(filename)
239 | else:
240 | lnotab = {}
241 | source = linecache.getlines(filename)
242 | coverpath = os.path.join(dir, modulename + ".cover")
243 | with open(filename, 'rb') as fp:
244 | encoding, _ = tokenize.detect_encoding(fp.readline)
245 | n_hits, n_lines = self.write_results_file(coverpath, source,
246 | lnotab, count, encoding)
247 | if summary and n_lines:
248 | percent = int(100 * n_hits / n_lines)
249 | sums[modulename] = n_lines, percent, modulename, filename
250 |
251 |
252 | if summary and sums:
253 | print("lines cov% module (path)")
254 | for m in sorted(sums):
255 | n_lines, percent, modulename, filename = sums[m]
256 | print("%5d %3d%% %s (%s)" % sums[m])
257 |
258 | if self.outfile:
259 | # try and store counts and module info into self.outfile
260 | try:
261 | with open(self.outfile, 'wb') as f:
262 | pickle.dump((self.counts, self.calledfuncs, self.callers),
263 | f, 1)
264 | except OSError as err:
265 | print("Can't save counts files because %s" % err, file=sys.stderr)
266 |
267 | def write_results_file(self, path, lines, lnotab, lines_hit, encoding=None):
268 | """Return a coverage results file in path."""
269 | # ``lnotab`` is a dict of executable lines, or a line number "table"
270 |
271 | try:
272 | outfile = open(path, "w", encoding=encoding)
273 | except OSError as err:
274 | print(("trace: Could not open %r for writing: %s "
275 | "- skipping" % (path, err)), file=sys.stderr)
276 | return 0, 0
277 |
278 | n_lines = 0
279 | n_hits = 0
280 | with outfile:
281 | for lineno, line in enumerate(lines, 1):
282 | # do the blank/comment match to try to mark more lines
283 | # (help the reader find stuff that hasn't been covered)
284 | if lineno in lines_hit:
285 | outfile.write("%5d: " % lines_hit[lineno])
286 | n_hits += 1
287 | n_lines += 1
288 | elif lineno in lnotab and not PRAGMA_NOCOVER in line:
289 | # Highlight never-executed lines, unless the line contains
290 | # #pragma: NO COVER
291 | outfile.write(">>>>>> ")
292 | n_lines += 1
293 | else:
294 | outfile.write(" ")
295 | outfile.write(line.expandtabs(8))
296 |
297 | return n_hits, n_lines
298 |
299 | def _find_lines_from_code(code, strs):
300 | """Return dict where keys are lines in the line number table."""
301 | linenos = {}
302 |
303 | for _, lineno in dis.findlinestarts(code):
304 | if lineno not in strs:
305 | linenos[lineno] = 1
306 |
307 | return linenos
308 |
309 | def _find_lines(code, strs):
310 | """Return lineno dict for all code objects reachable from code."""
311 | # get all of the lineno information from the code of this scope level
312 | linenos = _find_lines_from_code(code, strs)
313 |
314 | # and check the constants for references to other code objects
315 | for c in code.co_consts:
316 | if inspect.iscode(c):
317 | # find another code object, so recurse into it
318 | linenos.update(_find_lines(c, strs))
319 | return linenos
320 |
321 | def _find_strings(filename, encoding=None):
322 | """Return a dict of possible docstring positions.
323 |
324 | The dict maps line numbers to strings. There is an entry for each
325 | line that contains only a string or a part of a triple-quoted
326 | string.
327 | """
328 | d = {}
329 | # If the first token is a string, then it's the module docstring.
330 | # Add this special case so that the test in the loop passes.
331 | prev_ttype = token.INDENT
332 | with open(filename, encoding=encoding) as f:
333 | tok = tokenize.generate_tokens(f.readline)
334 | for ttype, tstr, start, end, line in tok:
335 | if ttype == token.STRING:
336 | if prev_ttype == token.INDENT:
337 | sline, scol = start
338 | eline, ecol = end
339 | for i in range(sline, eline + 1):
340 | d[i] = 1
341 | prev_ttype = ttype
342 | return d
343 |
344 | def _find_executable_linenos(filename):
345 | """Return dict where keys are line numbers in the line number table."""
346 | try:
347 | with tokenize.open(filename) as f:
348 | prog = f.read()
349 | encoding = f.encoding
350 | except OSError as err:
351 | print(("Not printing coverage data for %r: %s"
352 | % (filename, err)), file=sys.stderr)
353 | return {}
354 | code = compile(prog, filename, "exec")
355 | strs = _find_strings(filename, encoding)
356 | return _find_lines(code, strs)
357 |
358 | class Trace:
359 | def __init__(self, count=1, trace=1, countfuncs=0, countcallers=0,
360 | ignoremods=(), ignoredirs=(), infile=None, outfile=None,
361 | timing=False):
362 | """
363 | @param count true iff it should count number of times each
364 | line is executed
365 | @param trace true iff it should print out each line that is
366 | being counted
367 | @param countfuncs true iff it should just output a list of
368 | (filename, modulename, funcname,) for functions
369 | that were called at least once; This overrides
370 | `count' and `trace'
371 | @param ignoremods a list of the names of modules to ignore
372 | @param ignoredirs a list of the names of directories to ignore
373 | all of the (recursive) contents of
374 | @param infile file from which to read stored counts to be
375 | added into the results
376 | @param outfile file in which to write the results
377 | @param timing true iff timing information should be displayed
378 | """
379 | self.exe_path = []
380 | self.infile = infile
381 | self.outfile = outfile
382 | self.ignore = _Ignore(ignoremods, ignoredirs)
383 | self.counts = {} # keys are (filename, linenumber)
384 | self.pathtobasename = {} # for memoizing os.path.basename
385 | self.donothing = 0
386 | self.trace = trace
387 | self._calledfuncs = {}
388 | self._callers = {}
389 | self._caller_cache = {}
390 | self.start_time = None
391 | if timing:
392 | self.start_time = _time()
393 | if countcallers:
394 | self.globaltrace = self.globaltrace_trackcallers
395 | elif countfuncs:
396 | self.globaltrace = self.globaltrace_countfuncs
397 | elif trace and count:
398 | self.globaltrace = self.globaltrace_lt
399 | self.localtrace = self.localtrace_trace_and_count
400 | elif trace:
401 | self.globaltrace = self.globaltrace_lt
402 | self.localtrace = self.localtrace_trace
403 | elif count:
404 | self.globaltrace = self.globaltrace_lt
405 | self.localtrace = self.localtrace_count
406 | else:
407 | # Ahem -- do nothing? Okay.
408 | self.donothing = 1
409 |
410 | def run(self, cmd):
411 | import __main__
412 | dict = __main__.__dict__
413 | self.runctx(cmd, dict, dict)
414 |
415 | def runctx(self, cmd, globals=None, locals=None):
416 | if globals is None: globals = {}
417 | if locals is None: locals = {}
418 | if not self.donothing:
419 | threading.settrace(self.globaltrace)
420 | sys.settrace(self.globaltrace)
421 | try:
422 | exec(cmd, globals, locals)
423 | finally:
424 | if not self.donothing:
425 | sys.settrace(None)
426 | threading.settrace(None)
427 |
428 | def runfunc(self, func, /, *args, **kw):
429 | result = None
430 | if not self.donothing:
431 | sys.settrace(self.globaltrace)
432 | try:
433 | result = func(*args, **kw)
434 | finally:
435 | if not self.donothing:
436 | sys.settrace(None)
437 | return result
438 |
439 | def file_module_function_of(self, frame):
440 | code = frame.f_code
441 | filename = code.co_filename
442 | if filename:
443 | modulename = _modname(filename)
444 | else:
445 | modulename = None
446 |
447 | funcname = code.co_name
448 | clsname = None
449 | if code in self._caller_cache:
450 | if self._caller_cache[code] is not None:
451 | clsname = self._caller_cache[code]
452 | else:
453 | self._caller_cache[code] = None
454 | ## use of gc.get_referrers() was suggested by Michael Hudson
455 | # all functions which refer to this code object
456 | funcs = [f for f in gc.get_referrers(code)
457 | if inspect.isfunction(f)]
458 | # require len(func) == 1 to avoid ambiguity caused by calls to
459 | # new.function(): "In the face of ambiguity, refuse the
460 | # temptation to guess."
461 | if len(funcs) == 1:
462 | dicts = [d for d in gc.get_referrers(funcs[0])
463 | if isinstance(d, dict)]
464 | if len(dicts) == 1:
465 | classes = [c for c in gc.get_referrers(dicts[0])
466 | if hasattr(c, "__bases__")]
467 | if len(classes) == 1:
468 | # ditto for new.classobj()
469 | clsname = classes[0].__name__
470 | # cache the result - assumption is that new.* is
471 | # not called later to disturb this relationship
472 | # _caller_cache could be flushed if functions in
473 | # the new module get called.
474 | self._caller_cache[code] = clsname
475 | if clsname is not None:
476 | funcname = "%s.%s" % (clsname, funcname)
477 |
478 | return filename, modulename, funcname
479 |
480 | def globaltrace_trackcallers(self, frame, why, arg):
481 | """Handler for call events.
482 |
483 | Adds information about who called who to the self._callers dict.
484 | """
485 | if why == 'call':
486 | # XXX Should do a better job of identifying methods
487 | this_func = self.file_module_function_of(frame)
488 | parent_func = self.file_module_function_of(frame.f_back)
489 | self._callers[(parent_func, this_func)] = 1
490 |
491 | def globaltrace_countfuncs(self, frame, why, arg):
492 | """Handler for call events.
493 |
494 | Adds (filename, modulename, funcname) to the self._calledfuncs dict.
495 | """
496 | if why == 'call':
497 | this_func = self.file_module_function_of(frame)
498 | self._calledfuncs[this_func] = 1
499 |
500 | def globaltrace_lt(self, frame, why, arg):
501 | """Handler for call events.
502 |
503 | If the code block being entered is to be ignored, returns `None',
504 | else returns self.localtrace.
505 | """
506 | if why == 'call':
507 | code = frame.f_code
508 | filename = frame.f_globals.get('__file__', None)
509 | if filename:
510 | # XXX _modname() doesn't work right for packages, so
511 | # the ignore support won't work right for packages
512 | modulename = _modname(filename)
513 | if modulename is not None:
514 | ignore_it = self.ignore.names(filename, modulename)
515 | if not ignore_it:
516 | if self.trace:
517 | print((" --- modulename: %s, funcname: %s"
518 | % (modulename, code.co_name)))
519 | return self.localtrace
520 | else:
521 | return None
522 |
523 | def localtrace_trace_and_count(self, frame, why, arg):
524 | if why == "line":
525 | # record the file name and line number of every trace
526 | filename = frame.f_code.co_filename
527 | lineno = frame.f_lineno
528 | key = filename, lineno
529 | # debug print left from development; disabled to keep trace output clean
530 | # print(key)
531 | self.counts[key] = self.counts.get(key, 0) + 1
532 |
533 | if self.start_time:
534 | print('%.2f' % (_time() - self.start_time), end=' ')
535 | bname = os.path.basename(filename)
536 | line = linecache.getline(filename, lineno)
537 | print("%s(%d)" % (bname, lineno), end='')
538 | if line:
539 | print(": ", line, end='')
540 | else:
541 | print()
542 | return self.localtrace
543 |
544 | def localtrace_trace(self, frame, why, arg):
545 | if why == "line":
546 | # record the file name and line number of every trace
547 | filename = frame.f_code.co_filename
548 | lineno = frame.f_lineno
549 |
550 | if self.start_time:
551 | print('%.2f' % (_time() - self.start_time), end=' ')
552 | bname = os.path.basename(filename)
553 | line = linecache.getline(filename, lineno)
554 | print("%s(%d)" % (bname, lineno), end='')
555 | if line:
556 | print(": ", line, end='')
557 | else:
558 | print()
559 | return self.localtrace
560 |
561 | def localtrace_count(self, frame, why, arg):
562 | if why == "line":
563 | filename = frame.f_code.co_filename
564 | lineno = frame.f_lineno
565 | key = filename, lineno
566 | self.exe_path.append(lineno)
567 | self.counts[key] = self.counts.get(key, 0) + 1
568 | return self.localtrace
569 |
570 | def results(self):
571 | return CoverageResults(self.counts, infile=self.infile,
572 | outfile=self.outfile,
573 | calledfuncs=self._calledfuncs,
574 | callers=self._callers)
575 |
576 | def main():
577 | import argparse
578 |
579 | parser = argparse.ArgumentParser()
580 | parser.add_argument('--version', action='version', version='trace 2.0')
581 |
582 | grp = parser.add_argument_group('Main options',
583 | 'One of these (or --report) must be given')
584 |
585 | grp.add_argument('-c', '--count', action='store_true',
586 | help='Count the number of times each line is executed and write '
587 | 'the counts to .cover for each module executed, in '
588 | 'the module\'s directory. See also --coverdir, --file, '
589 | '--no-report below.')
590 | grp.add_argument('-t', '--trace', action='store_true',
591 | help='Print each line to sys.stdout before it is executed')
592 | grp.add_argument('-l', '--listfuncs', action='store_true',
593 | help='Keep track of which functions are executed at least once '
594 | 'and write the results to sys.stdout after the program exits. '
595 | 'Cannot be specified alongside --trace or --count.')
596 | grp.add_argument('-T', '--trackcalls', action='store_true',
597 | help='Keep track of caller/called pairs and write the results to '
598 | 'sys.stdout after the program exits.')
599 |
600 | grp = parser.add_argument_group('Modifiers')
601 |
602 | _grp = grp.add_mutually_exclusive_group()
603 | _grp.add_argument('-r', '--report', action='store_true',
604 | help='Generate a report from a counts file; does not execute any '
605 | 'code. --file must specify the results file to read, which '
606 | 'must have been created in a previous run with --count '
607 | '--file=FILE')
608 | _grp.add_argument('-R', '--no-report', action='store_true',
609 | help='Do not generate the coverage report files. '
610 | 'Useful if you want to accumulate over several runs.')
611 |
612 | grp.add_argument('-f', '--file',
613 | help='File to accumulate counts over several runs')
614 | grp.add_argument('-C', '--coverdir',
615 | help='Directory where the report files go. The coverage report '
616 | 'for <package>.<module> will be written to file '
617 | '<dir>/<package>.<module>.cover')
618 | grp.add_argument('-m', '--missing', action='store_true',
619 | help='Annotate executable lines that were not executed with '
620 | '">>>>>> "')
621 | grp.add_argument('-s', '--summary', action='store_true',
622 | help='Write a brief summary for each file to sys.stdout. '
623 | 'Can only be used with --count or --report')
624 | grp.add_argument('-g', '--timing', action='store_true',
625 | help='Prefix each line with the time since the program started. '
626 | 'Only used while tracing')
627 |
628 | grp = parser.add_argument_group('Filters',
629 | 'Can be specified multiple times')
630 | grp.add_argument('--ignore-module', action='append', default=[],
631 | help='Ignore the given module(s) and its submodules '
632 | '(if it is a package). Accepts comma separated list of '
633 | 'module names.')
634 | grp.add_argument('--ignore-dir', action='append', default=[],
635 | help='Ignore files in the given directory '
636 | '(multiple directories can be joined by os.pathsep).')
637 |
638 | parser.add_argument('--module', action='store_true', default=False,
639 | help='Trace a module. ')
640 | parser.add_argument('progname', nargs='?',
641 | help='file to run as main program')
642 | parser.add_argument('arguments', nargs=argparse.REMAINDER,
643 | help='arguments to the program')
644 |
645 | opts = parser.parse_args()
646 |
647 | if opts.ignore_dir:
648 | _prefix = sysconfig.get_path("stdlib")
649 | _exec_prefix = sysconfig.get_path("platstdlib")
650 |
651 | def parse_ignore_dir(s):
652 | s = os.path.expanduser(os.path.expandvars(s))
653 | s = s.replace('$prefix', _prefix).replace('$exec_prefix', _exec_prefix)
654 | return os.path.normpath(s)
655 |
656 | opts.ignore_module = [mod.strip()
657 | for i in opts.ignore_module for mod in i.split(',')]
658 | opts.ignore_dir = [parse_ignore_dir(s)
659 | for i in opts.ignore_dir for s in i.split(os.pathsep)]
660 |
661 | if opts.report:
662 | if not opts.file:
663 | parser.error('-r/--report requires -f/--file')
664 | results = CoverageResults(infile=opts.file, outfile=opts.file)
665 | return results.write_results(opts.missing, opts.summary, opts.coverdir)
666 |
667 | if not any([opts.trace, opts.count, opts.listfuncs, opts.trackcalls]):
668 | parser.error('must specify one of --trace, --count, --report, '
669 | '--listfuncs, or --trackcalls')
670 |
671 | if opts.listfuncs and (opts.count or opts.trace):
672 | parser.error('cannot specify both --listfuncs and (--trace or --count)')
673 |
674 | if opts.summary and not opts.count:
675 | parser.error('--summary can only be used with --count or --report')
676 |
677 | if opts.progname is None:
678 | parser.error('progname is missing: required with the main options')
679 |
680 | t = Trace(opts.count, opts.trace, countfuncs=opts.listfuncs,
681 | countcallers=opts.trackcalls, ignoremods=opts.ignore_module,
682 | ignoredirs=opts.ignore_dir, infile=opts.file,
683 | outfile=opts.file, timing=opts.timing)
684 | try:
685 | if opts.module:
686 | import runpy
687 | module_name = opts.progname
688 | mod_name, mod_spec, code = runpy._get_module_details(module_name)
689 | sys.argv = [code.co_filename, *opts.arguments]
690 | globs = {
691 | '__name__': '__main__',
692 | '__file__': code.co_filename,
693 | '__package__': mod_spec.parent,
694 | '__loader__': mod_spec.loader,
695 | '__spec__': mod_spec,
696 | '__cached__': None,
697 | }
698 | else:
699 | sys.argv = [opts.progname, *opts.arguments]
700 | sys.path[0] = os.path.dirname(opts.progname)
701 |
702 | with io.open_code(opts.progname) as fp:
703 | code = compile(fp.read(), opts.progname, 'exec')
704 | # try to emulate __main__ namespace as much as possible
705 | globs = {
706 | '__file__': opts.progname,
707 | '__name__': '__main__',
708 | '__package__': None,
709 | '__cached__': None,
710 | }
711 | t.runctx(code, globs, globs)
712 | except OSError as err:
713 | sys.exit("Cannot run file %r because: %s" % (sys.argv[0], err))
714 | except SystemExit:
715 | pass
716 |
717 | results = t.results()
718 |
719 | if not opts.no_report:
720 | results.write_results(opts.missing, opts.summary, opts.coverdir)
721 |
722 | if __name__=='__main__':
723 | main()
--------------------------------------------------------------------------------
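The vendored trace module above tracks CPython's `trace` closely; the project-specific change is in `localtrace_count`, which also appends each executed line number to `self.exe_path` so the execution path can be recovered for coverage labels. The same idea can be sketched directly with `sys.settrace` (the function name `record_line_path` and the line-offset convention below are illustrative, not part of this repo):

```python
import sys

def record_line_path(func, *args, **kwargs):
    """Run func and return (result, line offsets executed inside it).

    A minimal sketch of what localtrace_count does with self.exe_path:
    every 'line' event for the traced function appends a line number.
    """
    path = []
    code = func.__code__

    def local_trace(frame, why, arg):
        if why == "line" and frame.f_code is code:
            # record the offset of the executed line within func's body
            path.append(frame.f_lineno - code.co_firstlineno)
        return local_trace

    def global_trace(frame, why, arg):
        # only install the local tracer for the function we care about
        if why == "call" and frame.f_code is code:
            return local_trace
        return None

    sys.settrace(global_trace)
    try:
        result = func(*args, **kwargs)
    finally:
        sys.settrace(None)
    return result, path

def branchy(n):        # offset 0
    if n > 0:          # offset 1
        return "pos"   # offset 2
    return "non-pos"   # offset 3

result, path = record_line_path(branchy, 5)
```

The two branches of `branchy` yield different paths ([1, 2] vs. [1, 3]), which is exactly the kind of per-input execution trace the dataset generation scripts need.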
/img/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/CodeFlow/ef145c4d7271ece6abbff9f6e8d4c2d630bdb5af/img/architecture.png
--------------------------------------------------------------------------------
/img/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/CodeFlow/ef145c4d7271ece6abbff9f6e8d4c2d630bdb5af/img/pipeline.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import data
2 | import config
3 | import os
4 | import model
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | from tqdm import tqdm
9 | from sklearn.metrics import precision_recall_fscore_support
10 | import warnings
11 | import numpy as np
12 | from utils import write, pad_targets, accuracy_whole_list
13 |
14 | import random
15 |
16 | warnings.simplefilter("ignore")
17 |
18 | opt = config.parse()
19 | if torch.cuda.is_available():
20 | torch.cuda.manual_seed(opt.seed)
21 | torch.backends.cudnn.deterministic = True
22 | torch.backends.cudnn.benchmark = False
23 | random.seed(opt.seed)
24 | np.random.seed(opt.seed)
25 |
26 | def train(opt, train_iter, valid_iter, device):
27 | net = model.CodeFlow(opt).to(device)
28 | criterion = nn.BCELoss()
29 | optimizer = optim.Adam(net.parameters(), lr=opt.learning_rate)
30 | write("Start training...", opt.output)
31 | for epoch in range(opt.num_epoch):
32 | net.train()
33 | total_loss, total_accuracy = 0, 0
34 | total_train = 0
35 | for batch in tqdm(train_iter):
36 | x, edges, target = batch.nodes, (batch.forward, batch.backward), batch.target.float()
37 |
38 | if isinstance(x, tuple):
39 | pred = net(x[0], edges, x[1], x[2])
40 | else:
41 | pred = net(x, edges)
42 | pred = pred.squeeze()
43 |
44 | loss = criterion(pred, target)
45 | optimizer.zero_grad()
46 | loss.backward()
47 | optimizer.step()
48 | total_loss += loss.item()
49 | pred = (pred > opt.alpha).float()
50 | accuracy = accuracy_whole_list(target.cpu().numpy(), pred.cpu().numpy(), x[1].cpu().numpy())
51 | total_train += target.shape[0]
52 | total_accuracy += accuracy
53 | avg_loss = total_loss / len(train_iter)
54 | avg_accuracy = total_accuracy / total_train
55 |
56 | net.eval()
57 | eval_loss, eval_accuracy, eval_error_accuracy = 0, 0, 0
58 | y_true, y_pred = [], []
59 | total_test = 0
60 | total_local = 0
61 | total_detect = 0
62 | locate_bug = 0
63 | detect_true = 0
64 | with torch.no_grad():
65 | for batch in valid_iter:
66 | x, edges, target = batch.nodes, (batch.forward, batch.backward), batch.target.float()
67 | num_nodes = x[1]
68 | if isinstance(x, tuple):
69 | pred = net(x[0], edges, x[1], x[2])
70 | else:
71 | pred = net(x, edges)
72 | pred = pred.squeeze()
73 |
74 | loss = criterion(pred, target)
75 | eval_loss += loss.item()
76 | pred = (pred > opt.alpha).float()
77 | if opt.runtime_detection:
78 | for i in range(len(x[1])):
79 | total_detect += 1
80 | if pred[i][x[1][i]-1] == target[i][x[1][i]-1]:
81 | detect_true += 1
82 | if opt.bug_localization:
83 | for i in range(len(x[1])):
84 | target_list = []
85 | pred_list = []
86 | num_nodes_list = []
87 | if target[i][x[1][i]-1] == 1:
88 | continue
89 | total_local += 1
90 | mask_pred = pred[i] == 1
91 | indices_pred = torch.nonzero(mask_pred).flatten()
92 | farthest_pred = indices_pred.max().item()
93 |
94 | mask_target = target[i] == 1
95 |
96 | indices_target = torch.nonzero(mask_target).flatten()
97 | farthest_target = indices_target.max().item()
98 | if farthest_pred == farthest_target:
99 | locate_bug += 1
100 | target_list.append(target[i].cpu().numpy())
101 | pred_list.append(pred[i].cpu().numpy())
102 | num_nodes_list.append(num_nodes[i].cpu().numpy())
103 | error_accuracy = accuracy_whole_list(target_list, pred_list, num_nodes_list)
104 | eval_error_accuracy += error_accuracy
105 | accuracy = accuracy_whole_list(target.cpu().numpy(), pred.cpu().numpy(), num_nodes.cpu().numpy())
106 | eval_accuracy += accuracy
107 | total_test += target.shape[0]
108 | # append targets to y_true and predictions to y_pred, based on the number of nodes in num_nodes
109 | for i in range(len(num_nodes)):
110 | y_true.append(target[i, :num_nodes[i]].cpu().numpy())
111 | y_pred.append(pred[i, :num_nodes[i]].cpu().numpy())
112 | avg_eval_loss = eval_loss / len(valid_iter)
113 | avg_eval_accuracy = eval_accuracy / total_test
114 | # concatenate all the target and prediction
115 | y_true = np.concatenate(y_true)
116 | y_pred = np.concatenate(y_pred)
117 |
118 | precision, recall, fscore, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
119 | write(f"Epoch {epoch + 1}/{opt.num_epoch}", opt.output)
120 | write(f"Train Loss: {avg_loss:.4f}, Train Accuracy: {avg_accuracy:.4f}", opt.output)
121 | write(f"Validation Loss: {avg_eval_loss:.4f}, Validation Accuracy: {avg_eval_accuracy:.4f}", opt.output)
122 | if opt.runtime_detection:
123 | detect_acc = (detect_true / total_detect)*100
124 | write(f"Runtime Error Detection: {detect_acc:.4f}", opt.output)
125 | if opt.bug_localization:
126 | locate_acc = (locate_bug/total_local)*100
127 | write(f"BUG Localization: {locate_acc:.4f}", opt.output)
128 | write(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F-score: {fscore:.4f}", opt.output)
129 | write("_________________________________________________", opt.output)
130 |
131 | if (epoch+1) % 10 == 0:
132 | os.makedirs(f'checkpoints/checkpoints_{opt.checkpoint}', exist_ok=True)
133 | torch.save(net.state_dict(), f"checkpoints/checkpoints_{opt.checkpoint}/epoch-{epoch + 1}.pt")
134 |
135 | return net
136 |
137 | def main():
138 | if not os.path.exists('checkpoints'):
139 | os.makedirs('checkpoints')
140 | if opt.checkpoint is None:
141 | files = os.listdir("checkpoints")
142 | opt.checkpoint = len(files)+1
143 | if opt.name_exp is None:
144 | opt.output = f'{opt.output}/checkpoint_{opt.checkpoint}_{opt.seed}'
145 | else:
146 | opt.output = f'{opt.output}/checkpoint_{opt.checkpoint}_{opt.seed}_{opt.name_exp}'
147 | print(opt.output)
148 | os.makedirs(os.path.dirname(opt.output), exist_ok=True)
149 | open(opt.output, 'w').close()
150 | if opt.cuda_num is None:
151 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
152 | else:
153 | device = torch.device(f"cuda:{opt.cuda_num}" if torch.cuda.is_available() else "cpu")
154 | if opt.seed is not None:
155 | random.seed(opt.seed)
156 | print(f"Using device: {device}")
157 | train_iter, test_iter = data.get_iterators(opt, device)
158 | train(opt, train_iter, test_iter, device)
159 |
160 | if __name__ == "__main__":
161 | main()
162 |
--------------------------------------------------------------------------------
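`main.py` relies on `accuracy_whole_list` from utils.py (not shown in this listing) and divides its return value by the number of samples, which suggests it returns a count of samples whose predicted coverage vector matches the target on every non-padding node. A pure-Python sketch under that assumption (this implementation is hypothetical, not the repo's actual `utils.accuracy_whole_list`):

```python
def accuracy_whole_list(targets, preds, num_nodes):
    """Count samples whose prediction matches the target on every real node.

    Hypothetical reimplementation: assumes "whole-list" exact match over
    the first num_nodes[i] (unpadded) positions of each sample, ignoring
    any padding positions beyond that.
    """
    correct = 0
    for target, pred, n in zip(targets, preds, num_nodes):
        n = int(n)
        # compare only the real nodes; padding may disagree freely
        if list(target[:n]) == list(pred[:n]):
            correct += 1
    return correct

targets   = [[1, 1, 0, 0], [1, 0, 1, 0]]
preds     = [[1, 1, 0, 1], [1, 0, 0, 0]]  # sample 0 differs only in padding
num_nodes = [3, 4]
acc = accuracy_whole_list(targets, preds, num_nodes)  # only sample 0 matches
```

Returning a raw count (rather than a ratio) is consistent with the training loop, which accumulates the value across batches and divides by `total_train` at the end of each epoch.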
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from utils import write
5 |
6 | class CodeFlow(nn.Module):
7 | def __init__(self, opt):
8 | super(CodeFlow, self).__init__()
9 | self.max_node = opt.max_node
10 | self.hidden_dim = opt.hidden_dim
11 | self.embedding = nn.Embedding(opt.vocab_size+2, opt.hidden_dim, padding_idx=1)
12 | self.node_lstm = nn.LSTM(self.hidden_dim, self.hidden_dim//2, bidirectional=True, batch_first=True)
13 | self.gate = Gate(self.hidden_dim, self.hidden_dim)
14 | self.back_gate = Gate(self.hidden_dim, self.hidden_dim)
15 | self.concat = nn.Linear(self.hidden_dim, self.hidden_dim)
16 | self.fc_output = nn.Linear(self.hidden_dim, 1) # Adjusted this layer to match the hidden dimensions
17 | self.opt = opt
18 |
19 | # @profile
20 | def forward(self, x, edges, node_lens=None, token_lens=None, target=None):
21 | f_edges, b_edges = edges
22 | batch_size, num_node, num_token = x.size(0), x.size(1), x.size(2)
23 |
24 | # token_ids [bs,num_node,num_token]
25 | # x: ([bs, 34, 12], [bs], [bs, 34])
26 | # num_node lengths of node (num_token)
27 | # f_edges, b_edges: ([bs, 38, 2], [bs, 6, 2]) -> max_edge_forward, max_edge_back
28 | # target: [bs, 34]
29 |
30 | x = self.embedding(x) # [B, N, L, H]
31 | if self.opt.extra_aggregate:
32 | neighbors = [{} for _ in range(len(f_edges))]
33 | for i in range(len(f_edges)):
34 | for (start, end) in f_edges[i]:
35 | start = start.item()
36 | end = end.item()
37 | if start == 1 and end == 1:
38 | continue
39 | if start not in neighbors[i]:
40 | neighbors[i][start] = [end]
41 | else:
42 | neighbors[i][start].append(end)
43 | if end not in neighbors[i]:
44 | neighbors[i][end] = [start]
45 | else:
46 | neighbors[i][end].append(start)
47 | for i in range(len(b_edges)):
48 | for (start, end) in b_edges[i]:
49 | start = start.item()
50 | end = end.item()
51 | if start == 1 and end == 1:
52 | continue
53 | if start not in neighbors[i]:
54 | neighbors[i][start] = [end]
55 | else:
56 | neighbors[i][start].append(end)
57 | if end not in neighbors[i]:
58 | neighbors[i][end] = [start]
59 | else:
60 | neighbors[i][end].append(start)
61 | max_node = max(node_lens)
62 | matrix = torch.zeros((batch_size, max_node, max_node), dtype=torch.float, device=x.device)
63 | for i in range(batch_size):
64 | if self.opt.delete_redundant_node:
65 | for node in neighbors[i]:
66 | num_neighbors = len(neighbors[i][node])
67 | for neighbor in neighbors[i][node]:
68 | matrix[i, node-1, neighbor-1] = (1-self.opt.alpha)/num_neighbors
69 | matrix[i, node-1, node-1] = self.opt.alpha
70 | else:
71 | for node in range(max_node):
72 | if node in neighbors[i].keys():
73 | num_neighbors = len(neighbors[i][node])
74 | for neighbor in neighbors[i][node]:
75 | matrix[i, node-1, neighbor-1] = (1-self.opt.alpha)/num_neighbors
76 | matrix[i, node-1, node-1] = self.opt.alpha
77 | else:
78 | matrix[i, node-1, node-1] = 1
79 |
80 | #! Node LSTM embedding, https://www.readcube.com/library/1771e2fb-bec1-4bc4-90b3-04c8786fe9dd:fd440d39-f13e-430c-b768-751878616cda, 2nd figure, Node Embedding part
81 | if token_lens is not None:
82 | x = x.view(batch_size*num_node, num_token, -1)
83 | h_n = torch.zeros((2, batch_size*num_node, self.hidden_dim//2)).to(x.device)
84 | c_n = torch.zeros((2, batch_size*num_node, self.hidden_dim//2)).to(x.device)
85 | x, _ = self.node_lstm(x, (h_n, c_n)) # [B*N, L, H]
86 | x = x.view(batch_size, num_node, num_token, -1)
87 | x = self.average_pooling(x, token_lens)
88 | else:
89 | x = torch.mean(x, dim=2) # [B, N, H]
90 |
91 | # ! Initialize hidden states to be zeros
92 | h_f = torch.zeros(x.size()).to(x.device)
93 | c_f = torch.zeros(x.size()).to(x.device)
94 |
95 | # ! Forward pass: including forward edge + backward edge, 1->K
96 | ori_f_matrix = self.convert_to_matrix(batch_size, num_node, f_edges)
97 | running_f_matrix = ori_f_matrix.clone()
98 | for i in range(num_node):
99 | f_i = running_f_matrix[:, i, :].unsqueeze(1)
100 | f_i = f_i.clone()
101 | x_cur = x[:, i, :].squeeze(1) # [B, hidden_dim]
102 | h_last, c_last = f_i.bmm(h_f), f_i.bmm(c_f)
103 | # f_i: [B, 1, max_node]; h_f, c_f: [B, max_node, hidden_dim]
104 | # bmm aggregates the hidden/cell states of node i's predecessors:
105 | # [B, 1, max_node] @ [B, max_node, hidden_dim] = [B, 1, hidden_dim]
106 | # h_last, c_last: [B, 1, hidden_dim]
107 | h_i, c_i = self.gate(x_cur, h_last.squeeze(1), c_last.squeeze(1))
108 | h_f[:, i, :], c_f[:, i, :] = h_i, c_i
109 | # successor nodes j of i (edges i->j) would get row j updated at entry i; the pruning code below is disabled
110 | h_i, c_i = h_i.squeeze(1), c_i.squeeze(1)
111 | # for sample_id in range(batch_size):
112 | # next_node_ids = []
113 | # for j in range(num_node):
114 | # if running_f_matrix[sample_id, j, i] == 1:
115 | # next_node_ids.append(j)
116 |
117 | # if len(next_node_ids) > 2:
118 | # print(sample_id)
119 | # print(torch.sum(running_f_matrix, dim=1))
120 | # # raise ValueError(f"Node {i+1} in sample_id: {sample_id} has more than 2 outward edges")
121 | # if len(next_node_ids) == 2:
122 | # if h_i[sample_id].sum() >= 0:
123 | # running_f_matrix[sample_id, next_node_ids[0], i] = 0
124 | # else:
125 | # running_f_matrix[sample_id, next_node_ids[1], i] = 0
126 |
127 |
128 | b_matrix = self.convert_to_matrix(batch_size, num_node, b_edges)
129 | for j in range(num_node):
130 | b_j = b_matrix[:, j, :].unsqueeze(1)
131 | h_temp = b_j.bmm(h_f)
132 | h_f[:, j, :] += h_temp.squeeze(1)
133 |
134 | # # ! Initialize hidden states to be zeros
135 | # h_b = torch.zeros(x.size()).to(x.device)
136 | # c_b = torch.zeros(x.size()).to(x.device)
137 |
138 | # # # ! Backward pass: transpose b_matrix, f_matrix, including forward egde + backward edge, K->1
139 | # b_matrix = self.convert_to_matrix(batch_size, num_node, f_edges)
140 | # b_matrix = b_matrix.transpose(1, 2)
141 | # for i in reversed(range(num_node)):
142 | # x_cur = x[:, i, :].squeeze(1)
143 | # b_i = b_matrix[:, i, :].unsqueeze(1)
144 | # h_hat, c_hat = b_i.bmm(h_b), b_i.bmm(c_b)
145 | # h_b[:, i, :], c_b[:, i, :] = self.back_gate(x_cur, h_hat.squeeze(), c_hat.squeeze())
146 |
147 | # f_matrix = self.convert_to_matrix(batch_size, num_node, b_edges)
148 | # f_matrix = f_matrix.transpose(1, 2)
149 | # for j in range(num_node):
150 | # f_j = f_matrix[:, j, :].unsqueeze(1)
151 | # h_temp = f_j.bmm(h_b)
152 | # h_b[:, j, :] += h_temp.squeeze(1)
153 |
154 | # ------------Prediction stage --------------#
155 |
156 | # h = torch.cat([h_f, h_b], dim=2)
157 | # output = torch.mean(h, dim=1) # take the mean over the nodes within a batch -> [B, H]
158 | # h = [B, max_node, H] -> each node is fed into the fc_output
159 |
160 | # B, max_node, H -> B, max_node
161 | output = torch.sigmoid(self.fc_output(h_f))
162 | if self.opt.extra_aggregate:
163 | output = torch.bmm(matrix, output)
164 | return output
165 |
166 | @staticmethod
167 | def average_pooling(data, input_lens):
168 | B, N, T, H = data.size()
169 | idx = torch.arange(T, device=data.device).unsqueeze(0).expand(B, N, -1)
170 | idx = idx < input_lens.unsqueeze(2)
171 | idx = idx.unsqueeze(3).expand(-1, -1, -1, H)
172 | ret = (data.float() * idx.float()).sum(2) / (input_lens.unsqueeze(2).float()+10**-32)
173 | return ret
174 |
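The masked mean computed by `average_pooling` above can be sketched in plain Python (no tensors) to make the semantics explicit: positions beyond each node's token length are excluded from the average. `seqs` and `lens` are hypothetical inputs, with one scalar per token for brevity:

```python
def masked_average(seq, length):
    """Mean over the first `length` entries of seq; 0.0 for empty rows."""
    if length == 0:
        return 0.0
    return sum(seq[:length]) / length

# Each row is one node's token embeddings; lens[i] is the number of
# real (non-padding) tokens in row i.
seqs = [[1.0, 2.0, 3.0, 0.0], [4.0, 0.0, 0.0, 0.0]]
lens = [3, 1]
pooled = [masked_average(s, n) for s, n in zip(seqs, lens)]  # [2.0, 4.0]
```

The tensor version does the same thing in one shot by multiplying with a boolean mask before summing, which is why padding positions must be zeroed rather than skipped.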
175 | @staticmethod
176 | def convert_to_matrix(batch_size, max_num, m):
177 | matrix = torch.zeros((batch_size, max_num, max_num), dtype=torch.float, device=m.device)
178 | m = m - 1  # shift to 0-indexed without mutating the caller's edge tensor
179 | b_select = torch.arange(batch_size).unsqueeze(1).expand(batch_size, m.size(1)).contiguous().view(-1)
180 | matrix[b_select, m[:, :, 1].contiguous().view(-1), m[:, :, 0].contiguous().view(-1)] = 1
181 | matrix[:, 0, 0] = 0
182 | return matrix
183 |
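A minimal list-based sketch of what `convert_to_matrix` computes per sample, assuming edges are 1-indexed `(start, end)` pairs and that `(1, 1)` pairs are padding (both assumptions follow the conventions visible in the forward pass above):

```python
def edges_to_matrix(edge_list, num_nodes):
    """Row end-1 gets a 1 at column start-1 for each 1-indexed edge,
    mirroring convert_to_matrix; the (1, 1) padding entry is cleared last."""
    m = [[0.0] * num_nodes for _ in range(num_nodes)]
    for start, end in edge_list:
        m[end - 1][start - 1] = 1.0
    m[0][0] = 0.0  # (1, 1) pairs are padding, not real self-loops
    return m

adj = edges_to_matrix([(1, 2), (2, 3), (1, 1)], 3)
```

Multiplying this matrix against the stacked hidden states (`f_i.bmm(h_f)` in the forward pass) then sums the states of each node's predecessors.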
184 |
185 | class Gate(nn.Module):
186 | def __init__(self, in_dim, mem_dim):
187 | super(Gate, self).__init__()
188 | self.in_dim = in_dim
189 | self.mem_dim = mem_dim
190 | self.ax = nn.Linear(self.in_dim, 3 * self.mem_dim)
191 | self.ah = nn.Linear(self.mem_dim, 3 * self.mem_dim)
192 | self.fx = nn.Linear(self.in_dim, self.mem_dim)
193 | self.fh = nn.Linear(self.mem_dim, self.mem_dim)
194 |
195 | def forward(self, inputs, last_h, pred_c):
196 | iou = self.ax(inputs) + self.ah(last_h)
197 | i, o, u = torch.split(iou, iou.size(1) // 3, dim=1)
198 | i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
199 | f = torch.sigmoid(self.fh(last_h) + self.fx(inputs))
200 | fc = torch.mul(f, pred_c)
201 | c = torch.mul(i, u) + fc
202 | h = torch.mul(o, torch.tanh(c))
203 | return h, c
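The `Gate` class above is an LSTM-style cell applied along CFG order. A scalar sketch of its update equations, with all weight matrices collapsed to a single hypothetical scalar `w` and biases dropped:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def gate_step(x, h_prev, c_prev, w=1.0):
    """Scalar analogue of Gate.forward:
    i, o = sigmoid(.), u = tanh(.), f = sigmoid(.)
    c = i*u + f*c_prev ;  h = o*tanh(c)"""
    pre = w * x + w * h_prev          # stands in for ax(inputs) + ah(last_h)
    i, o, u = sigmoid(pre), sigmoid(pre), math.tanh(pre)
    f = sigmoid(w * x + w * h_prev)   # stands in for fx(inputs) + fh(last_h)
    c = i * u + f * c_prev            # input gate writes, forget gate carries
    h = o * math.tanh(c)              # output gate exposes the cell state
    return h, c

h, c = gate_step(0.5, 0.0, 0.0)
```

In the real module, `ax`/`ah` produce the three pre-activations for `i`, `o`, `u` in one fused linear layer, and `fx`/`fh` produce the forget gate separately.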
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | autopep8
2 | graphviz
3 | pandas
4 | torch
5 | torchtext==0.4.0
6 | tqdm
7 | scikit-learn
8 | astor
9 | anthropic
10 |
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Create a new virtual environment named codeflow with Python 3.10
4 | conda create -n codeflow python=3.10 -y
5 |
6 | # Activate the virtual environment (use `conda activate codeflow` if your shell is conda-initialized)
7 | source activate codeflow
8 |
9 | # Install the required Python libraries
10 | pip install -r requirements.txt
--------------------------------------------------------------------------------
/trace_execution.py:
--------------------------------------------------------------------------------
1 | """program/module to trace Python program or function execution
2 |
3 | Sample use, command line:
4 | trace.py -c -f counts --ignore-dir '$prefix' spam.py eggs
5 | trace.py -t --ignore-dir '$prefix' spam.py eggs
6 | trace.py --trackcalls spam.py eggs
7 |
8 | Sample use, programmatically
9 | import sys
10 |
11 | # create a Trace object, telling it what to ignore, and whether to
12 | # do tracing or line-counting or both.
13 | tracer = trace.Trace(ignoredirs=[sys.base_prefix, sys.base_exec_prefix,],
14 | trace=0, count=1)
15 | # run the new command using the given tracer
16 | tracer.run('main()')
17 | # make a report, placing output in /tmp
18 | r = tracer.results()
19 | r.write_results(show_missing=True, coverdir="/tmp")
20 | """
21 | __all__ = ['Trace', 'CoverageResults']
22 |
23 | import io
24 | import linecache
25 | import os
26 | import sys
27 | import sysconfig
28 | import token
29 | import tokenize
30 | import inspect
31 | import gc
32 | import dis
33 | import pickle
34 | from time import monotonic as _time
35 |
36 | import threading
37 |
38 | PRAGMA_NOCOVER = "#pragma NO COVER"
39 |
40 | class _Ignore:
41 | def __init__(self, modules=None, dirs=None):
42 | self._mods = set() if not modules else set(modules)
43 | self._dirs = [] if not dirs else [os.path.normpath(d)
44 | for d in dirs]
45 | self._ignore = { '': 1 }
46 |
47 | def names(self, filename, modulename):
48 | if modulename in self._ignore:
49 | return self._ignore[modulename]
50 |
51 | # haven't seen this one before, so see if the module name is
52 | # on the ignore list.
53 | if modulename in self._mods: # Identical names, so ignore
54 | self._ignore[modulename] = 1
55 | return 1
56 |
57 | # check if the module is a proper submodule of something on
58 | # the ignore list
59 | for mod in self._mods:
60 | # Need to take some care since ignoring
61 | # "cmp" mustn't mean ignoring "cmpcache" but ignoring
62 | # "Spam" must also mean ignoring "Spam.Eggs".
63 | if modulename.startswith(mod + '.'):
64 | self._ignore[modulename] = 1
65 | return 1
66 |
67 | # Now check that filename isn't in one of the directories
68 | if filename is None:
69 | # must be a built-in, so we must ignore
70 | self._ignore[modulename] = 1
71 | return 1
72 |
73 | # Ignore a file when it contains one of the ignorable paths
74 | for d in self._dirs:
75 | # The '+ os.sep' is to ensure that d is a parent directory,
76 | # as compared to cases like:
77 | # d = "/usr/local"
78 | # filename = "/usr/local.py"
79 | # or
80 | # d = "/usr/local.py"
81 | # filename = "/usr/local.py"
82 | if filename.startswith(d + os.sep):
83 | self._ignore[modulename] = 1
84 | return 1
85 |
86 | # Tried the different ways, so we don't ignore this module
87 | self._ignore[modulename] = 0
88 | return 0
89 |
90 | def _modname(path):
91 | """Return a plausible module name for the path."""
92 |
93 | base = os.path.basename(path)
94 | filename, ext = os.path.splitext(base)
95 | return filename
96 |
97 | def _fullmodname(path):
98 | """Return a plausible module name for the path."""
99 |
100 | # If the file 'path' is part of a package, then the filename isn't
101 | # enough to uniquely identify it. Try to do the right thing by
102 | # looking in sys.path for the longest matching prefix. We'll
103 | # assume that the rest is the package name.
104 |
105 | comparepath = os.path.normcase(path)
106 | longest = ""
107 | for dir in sys.path:
108 | dir = os.path.normcase(dir)
109 | if comparepath.startswith(dir) and comparepath[len(dir)] == os.sep:
110 | if len(dir) > len(longest):
111 | longest = dir
112 |
113 | if longest:
114 | base = path[len(longest) + 1:]
115 | else:
116 | base = path
117 | # the drive letter is never part of the module name
118 | drive, base = os.path.splitdrive(base)
119 | base = base.replace(os.sep, ".")
120 | if os.altsep:
121 | base = base.replace(os.altsep, ".")
122 | filename, ext = os.path.splitext(base)
123 | return filename.lstrip(".")
124 |
125 | class CoverageResults:
126 | def __init__(self, counts=None, calledfuncs=None, infile=None,
127 | callers=None, outfile=None):
128 | self.counts = counts
129 | if self.counts is None:
130 | self.counts = {}
131 | self.counter = self.counts.copy() # map (filename, lineno) to count
132 | self.calledfuncs = calledfuncs
133 | if self.calledfuncs is None:
134 | self.calledfuncs = {}
135 | self.calledfuncs = self.calledfuncs.copy()
136 | self.callers = callers
137 | if self.callers is None:
138 | self.callers = {}
139 | self.callers = self.callers.copy()
140 | self.infile = infile
141 | self.outfile = outfile
142 | if self.infile:
143 | # Try to merge existing counts file.
144 | try:
145 | with open(self.infile, 'rb') as f:
146 | counts, calledfuncs, callers = pickle.load(f)
147 | self.update(self.__class__(counts, calledfuncs, callers=callers))
148 | except (OSError, EOFError, ValueError) as err:
149 | print(("Skipping counts file %r: %s"
150 | % (self.infile, err)), file=sys.stderr)
151 |
152 | def is_ignored_filename(self, filename):
153 | """Return True if the filename does not refer to a file
154 | we want to have reported.
155 | """
156 | return filename.startswith('<') and filename.endswith('>')
157 |
158 | def update(self, other):
159 | """Merge in the data from another CoverageResults"""
160 | counts = self.counts
161 | calledfuncs = self.calledfuncs
162 | callers = self.callers
163 | other_counts = other.counts
164 | other_calledfuncs = other.calledfuncs
165 | other_callers = other.callers
166 |
167 | for key in other_counts:
168 | counts[key] = counts.get(key, 0) + other_counts[key]
169 |
170 | for key in other_calledfuncs:
171 | calledfuncs[key] = 1
172 |
173 | for key in other_callers:
174 | callers[key] = 1
175 |
176 | def write_results(self, show_missing=True, summary=False, coverdir=None):
177 | """
178 | Write the coverage results.
179 |
180 | :param show_missing: Show lines that had no hits.
181 | :param summary: Include coverage summary per module.
182 | :param coverdir: If None, the results of each module are placed in its
183 | directory, otherwise it is included in the directory
184 | specified.
185 | """
186 | if self.calledfuncs:
187 | print()
188 | print("functions called:")
189 | calls = self.calledfuncs
190 | for filename, modulename, funcname in sorted(calls):
191 | print(("filename: %s, modulename: %s, funcname: %s"
192 | % (filename, modulename, funcname)))
193 |
194 | if self.callers:
195 | print()
196 | print("calling relationships:")
197 | lastfile = lastcfile = ""
198 | for ((pfile, pmod, pfunc), (cfile, cmod, cfunc)) \
199 | in sorted(self.callers):
200 | if pfile != lastfile:
201 | print()
202 | print("***", pfile, "***")
203 | lastfile = pfile
204 | lastcfile = ""
205 | if cfile != pfile and lastcfile != cfile:
206 | print(" -->", cfile)
207 | lastcfile = cfile
208 | print(" %s.%s -> %s.%s" % (pmod, pfunc, cmod, cfunc))
209 |
210 | # turn the counts data ("(filename, lineno) = count") into something
211 | # accessible on a per-file basis
212 | per_file = {}
213 | for filename, lineno in self.counts:
214 | lines_hit = per_file[filename] = per_file.get(filename, {})
215 | lines_hit[lineno] = self.counts[(filename, lineno)]
216 |
217 | # accumulate summary info, if needed
218 | sums = {}
219 |
220 | for filename, count in per_file.items():
221 | if self.is_ignored_filename(filename):
222 | continue
223 |
224 | if filename.endswith(".pyc"):
225 | filename = filename[:-1]
226 |
227 | if coverdir is None:
228 | dir = os.path.dirname(os.path.abspath(filename))
229 | modulename = _modname(filename)
230 | else:
231 | dir = coverdir
232 | os.makedirs(dir, exist_ok=True)
233 | modulename = _fullmodname(filename)
234 |
235 | # If desired, get a list of the line numbers which represent
236 | # executable content (returned as a dict for better lookup speed)
237 | if show_missing:
238 | lnotab = _find_executable_linenos(filename)
239 | else:
240 | lnotab = {}
241 | source = linecache.getlines(filename)
242 | coverpath = os.path.join(dir, modulename + ".cover")
243 | with open(filename, 'rb') as fp:
244 | encoding, _ = tokenize.detect_encoding(fp.readline)
245 | n_hits, n_lines = self.write_results_file(coverpath, source,
246 | lnotab, count, encoding)
247 | if summary and n_lines:
248 | percent = int(100 * n_hits / n_lines)
249 | sums[modulename] = n_lines, percent, modulename, filename
250 |
251 |
252 | if summary and sums:
253 | print("lines cov% module (path)")
254 | for m in sorted(sums):
255 | n_lines, percent, modulename, filename = sums[m]
256 | print("%5d %3d%% %s (%s)" % sums[m])
257 |
258 | if self.outfile:
259 | # try and store counts and module info into self.outfile
260 | try:
261 | with open(self.outfile, 'wb') as f:
262 | pickle.dump((self.counts, self.calledfuncs, self.callers),
263 | f, 1)
264 | except OSError as err:
265 | print("Can't save counts files because %s" % err, file=sys.stderr)
266 |
267 | def write_results_file(self, path, lines, lnotab, lines_hit, encoding=None):
268 | """Return a coverage results file in path."""
269 | # ``lnotab`` is a dict of executable lines, or a line number "table"
270 |
271 | try:
272 | outfile = open(path, "w", encoding=encoding)
273 | except OSError as err:
274 | print(("trace: Could not open %r for writing: %s "
275 | "- skipping" % (path, err)), file=sys.stderr)
276 | return 0, 0
277 |
278 | n_lines = 0
279 | n_hits = 0
280 | with outfile:
281 | for lineno, line in enumerate(lines, 1):
282 | # do the blank/comment match to try to mark more lines
283 | # (help the reader find stuff that hasn't been covered)
284 | if lineno in lines_hit:
285 | outfile.write("%5d: " % lines_hit[lineno])
286 | n_hits += 1
287 | n_lines += 1
288 | elif lineno in lnotab and not PRAGMA_NOCOVER in line:
289 | # Highlight never-executed lines, unless the line contains
290 | # #pragma: NO COVER
291 | outfile.write(">>>>>> ")
292 | n_lines += 1
293 | else:
294 | outfile.write(" ")
295 | outfile.write(line.expandtabs(8))
296 |
297 | return n_hits, n_lines
298 |
299 | def _find_lines_from_code(code, strs):
300 | """Return dict where keys are lines in the line number table."""
301 | linenos = {}
302 |
303 | for _, lineno in dis.findlinestarts(code):
304 | if lineno not in strs:
305 | linenos[lineno] = 1
306 |
307 | return linenos
308 |
309 | def _find_lines(code, strs):
310 | """Return lineno dict for all code objects reachable from code."""
311 | # get all of the lineno information from the code of this scope level
312 | linenos = _find_lines_from_code(code, strs)
313 |
314 | # and check the constants for references to other code objects
315 | for c in code.co_consts:
316 | if inspect.iscode(c):
317 | # find another code object, so recurse into it
318 | linenos.update(_find_lines(c, strs))
319 | return linenos
320 |
321 | def _find_strings(filename, encoding=None):
322 | """Return a dict of possible docstring positions.
323 |
324 | The dict maps line numbers to strings. There is an entry for
325 | line that contains only a string or a part of a triple-quoted
326 | string.
327 | """
328 | d = {}
329 | # If the first token is a string, then it's the module docstring.
330 | # Add this special case so that the test in the loop passes.
331 | prev_ttype = token.INDENT
332 | with open(filename, encoding=encoding) as f:
333 | tok = tokenize.generate_tokens(f.readline)
334 | for ttype, tstr, start, end, line in tok:
335 | if ttype == token.STRING:
336 | if prev_ttype == token.INDENT:
337 | sline, scol = start
338 | eline, ecol = end
339 | for i in range(sline, eline + 1):
340 | d[i] = 1
341 | prev_ttype = ttype
342 | return d
343 |
344 | def _find_executable_linenos(filename):
345 | """Return dict where keys are line numbers in the line number table."""
346 | try:
347 | with tokenize.open(filename) as f:
348 | prog = f.read()
349 | encoding = f.encoding
350 | except OSError as err:
351 | print(("Not printing coverage data for %r: %s"
352 | % (filename, err)), file=sys.stderr)
353 | return {}
354 | code = compile(prog, filename, "exec")
355 | strs = _find_strings(filename, encoding)
356 | return _find_lines(code, strs)
357 |
358 | class Trace:
359 | def __init__(self, count=1, trace=1, countfuncs=0, countcallers=0,
360 | ignoremods=(), ignoredirs=(), infile=None, outfile=None,
361 | timing=False):
362 | """
363 | @param count true iff it should count number of times each
364 | line is executed
365 | @param trace true iff it should print out each line that is
366 | being counted
367 | @param countfuncs true iff it should just output a list of
368 | (filename, modulename, funcname,) for functions
369 | that were called at least once; This overrides
370 | `count' and `trace'
371 | @param ignoremods a list of the names of modules to ignore
372 | @param ignoredirs a list of the names of directories to ignore
373 | all of the (recursive) contents of
374 | @param infile file from which to read stored counts to be
375 | added into the results
376 | @param outfile file in which to write the results
377 | @param timing true iff timing information be displayed
378 | """
379 | self.exe_path = []
380 | self.infile = infile
381 | self.outfile = outfile
382 | self.ignore = _Ignore(ignoremods, ignoredirs)
383 | self.counts = {} # keys are (filename, linenumber)
384 | self.pathtobasename = {} # for memoizing os.path.basename
385 | self.donothing = 0
386 | self.trace = trace
387 | self._calledfuncs = {}
388 | self._callers = {}
389 | self._caller_cache = {}
390 | self.start_time = None
391 | if timing:
392 | self.start_time = _time()
393 | if countcallers:
394 | self.globaltrace = self.globaltrace_trackcallers
395 | elif countfuncs:
396 | self.globaltrace = self.globaltrace_countfuncs
397 | elif trace and count:
398 | self.globaltrace = self.globaltrace_lt
399 | self.localtrace = self.localtrace_trace_and_count
400 | elif trace:
401 | self.globaltrace = self.globaltrace_lt
402 | self.localtrace = self.localtrace_trace
403 | elif count:
404 | self.globaltrace = self.globaltrace_lt
405 | self.localtrace = self.localtrace_count
406 | else:
407 | # Ahem -- do nothing? Okay.
408 | self.donothing = 1
409 |
410 | def run(self, cmd):
411 | import __main__
412 | dict = __main__.__dict__
413 | self.runctx(cmd, dict, dict)
414 |
415 | def runctx(self, cmd, globals=None, locals=None):
416 | if globals is None: globals = {}
417 | if locals is None: locals = {}
418 | if not self.donothing:
419 | threading.settrace(self.globaltrace)
420 | sys.settrace(self.globaltrace)
421 | try:
422 | exec(cmd, globals, locals)
423 | finally:
424 | if not self.donothing:
425 | sys.settrace(None)
426 | threading.settrace(None)
427 |
428 | def runfunc(self, func, /, *args, **kw):
429 | result = None
430 | if not self.donothing:
431 | sys.settrace(self.globaltrace)
432 | try:
433 | result = func(*args, **kw)
434 | finally:
435 | if not self.donothing:
436 | sys.settrace(None)
437 | return result
438 |
439 | def file_module_function_of(self, frame):
440 | code = frame.f_code
441 | filename = code.co_filename
442 | if filename:
443 | modulename = _modname(filename)
444 | else:
445 | modulename = None
446 |
447 | funcname = code.co_name
448 | clsname = None
449 | if code in self._caller_cache:
450 | if self._caller_cache[code] is not None:
451 | clsname = self._caller_cache[code]
452 | else:
453 | self._caller_cache[code] = None
454 | ## use of gc.get_referrers() was suggested by Michael Hudson
455 | # all functions which refer to this code object
456 | funcs = [f for f in gc.get_referrers(code)
457 | if inspect.isfunction(f)]
458 | # require len(func) == 1 to avoid ambiguity caused by calls to
459 | # new.function(): "In the face of ambiguity, refuse the
460 | # temptation to guess."
461 | if len(funcs) == 1:
462 | dicts = [d for d in gc.get_referrers(funcs[0])
463 | if isinstance(d, dict)]
464 | if len(dicts) == 1:
465 | classes = [c for c in gc.get_referrers(dicts[0])
466 | if hasattr(c, "__bases__")]
467 | if len(classes) == 1:
468 | # ditto for new.classobj()
469 | clsname = classes[0].__name__
470 | # cache the result - assumption is that new.* is
471 | # not called later to disturb this relationship
472 | # _caller_cache could be flushed if functions in
473 | # the new module get called.
474 | self._caller_cache[code] = clsname
475 | if clsname is not None:
476 | funcname = "%s.%s" % (clsname, funcname)
477 |
478 | return filename, modulename, funcname
479 |
480 | def globaltrace_trackcallers(self, frame, why, arg):
481 | """Handler for call events.
482 |
483 | Adds information about who called who to the self._callers dict.
484 | """
485 | if why == 'call':
486 | # XXX Should do a better job of identifying methods
487 | this_func = self.file_module_function_of(frame)
488 | parent_func = self.file_module_function_of(frame.f_back)
489 | self._callers[(parent_func, this_func)] = 1
490 |
491 | def globaltrace_countfuncs(self, frame, why, arg):
492 | """Handler for call events.
493 |
494 | Adds (filename, modulename, funcname) to the self._calledfuncs dict.
495 | """
496 | if why == 'call':
497 | this_func = self.file_module_function_of(frame)
498 | self._calledfuncs[this_func] = 1
499 |
500 | def globaltrace_lt(self, frame, why, arg):
501 | """Handler for call events.
502 |
503 | If the code block being entered is to be ignored, returns `None',
504 | else returns self.localtrace.
505 | """
506 | if why == 'call':
507 | code = frame.f_code
508 | filename = frame.f_globals.get('__file__', None)
509 | if filename:
510 | # XXX _modname() doesn't work right for packages, so
511 | # the ignore support won't work right for packages
512 | modulename = _modname(filename)
513 | if modulename is not None:
514 | ignore_it = self.ignore.names(filename, modulename)
515 | if not ignore_it:
516 | if self.trace:
517 | print((" --- modulename: %s, funcname: %s"
518 | % (modulename, code.co_name)))
519 | return self.localtrace
520 | else:
521 | return None
522 |
523 | def localtrace_trace_and_count(self, frame, why, arg):
524 | if why == "line":
525 | # record the file name and line number of every trace
526 | filename = frame.f_code.co_filename
527 | lineno = frame.f_lineno
528 | key = filename, lineno
529 |             # debug: print each (filename, lineno) key as it is counted
530 |             print(key)
531 | self.counts[key] = self.counts.get(key, 0) + 1
532 |
533 | if self.start_time:
534 | print('%.2f' % (_time() - self.start_time), end=' ')
535 | bname = os.path.basename(filename)
536 | line = linecache.getline(filename, lineno)
537 | print("%s(%d)" % (bname, lineno), end='')
538 | if line:
539 | print(": ", line, end='')
540 | else:
541 | print()
542 | return self.localtrace
543 |
544 | def localtrace_trace(self, frame, why, arg):
545 | if why == "line":
546 | # record the file name and line number of every trace
547 | filename = frame.f_code.co_filename
548 | lineno = frame.f_lineno
549 |
550 | if self.start_time:
551 | print('%.2f' % (_time() - self.start_time), end=' ')
552 | bname = os.path.basename(filename)
553 | line = linecache.getline(filename, lineno)
554 | print("%s(%d)" % (bname, lineno), end='')
555 | if line:
556 | print(": ", line, end='')
557 | else:
558 | print()
559 | return self.localtrace
560 |
561 | def localtrace_count(self, frame, why, arg):
562 | if why == "line":
563 | filename = frame.f_code.co_filename
564 | lineno = frame.f_lineno
565 | key = filename, lineno
566 | self.exe_path.append(lineno)
567 | self.counts[key] = self.counts.get(key, 0) + 1
568 | return self.localtrace
569 |
570 | def results(self):
571 | return CoverageResults(self.counts, infile=self.infile,
572 | outfile=self.outfile,
573 | calledfuncs=self._calledfuncs,
574 | callers=self._callers)
575 |
576 | def main():
577 | import argparse
578 |
579 | parser = argparse.ArgumentParser()
580 | parser.add_argument('--version', action='version', version='trace 2.0')
581 |
582 | grp = parser.add_argument_group('Main options',
583 | 'One of these (or --report) must be given')
584 |
585 | grp.add_argument('-c', '--count', action='store_true',
586 | help='Count the number of times each line is executed and write '
587 | 'the counts to .cover for each module executed, in '
588 | 'the module\'s directory. See also --coverdir, --file, '
589 | '--no-report below.')
590 | grp.add_argument('-t', '--trace', action='store_true',
591 | help='Print each line to sys.stdout before it is executed')
592 | grp.add_argument('-l', '--listfuncs', action='store_true',
593 | help='Keep track of which functions are executed at least once '
594 | 'and write the results to sys.stdout after the program exits. '
595 | 'Cannot be specified alongside --trace or --count.')
596 | grp.add_argument('-T', '--trackcalls', action='store_true',
597 | help='Keep track of caller/called pairs and write the results to '
598 | 'sys.stdout after the program exits.')
599 |
600 | grp = parser.add_argument_group('Modifiers')
601 |
602 | _grp = grp.add_mutually_exclusive_group()
603 | _grp.add_argument('-r', '--report', action='store_true',
604 | help='Generate a report from a counts file; does not execute any '
605 | 'code. --file must specify the results file to read, which '
606 | 'must have been created in a previous run with --count '
607 | '--file=FILE')
608 | _grp.add_argument('-R', '--no-report', action='store_true',
609 | help='Do not generate the coverage report files. '
610 | 'Useful if you want to accumulate over several runs.')
611 |
612 | grp.add_argument('-f', '--file',
613 | help='File to accumulate counts over several runs')
614 | grp.add_argument('-C', '--coverdir',
615 | help='Directory where the report files go. The coverage report '
616 | 'for . will be written to file '
617 | '//.cover')
618 | grp.add_argument('-m', '--missing', action='store_true',
619 | help='Annotate executable lines that were not executed with '
620 | '">>>>>> "')
621 | grp.add_argument('-s', '--summary', action='store_true',
622 | help='Write a brief summary for each file to sys.stdout. '
623 | 'Can only be used with --count or --report')
624 | grp.add_argument('-g', '--timing', action='store_true',
625 | help='Prefix each line with the time since the program started. '
626 | 'Only used while tracing')
627 |
628 | grp = parser.add_argument_group('Filters',
629 | 'Can be specified multiple times')
630 | grp.add_argument('--ignore-module', action='append', default=[],
631 | help='Ignore the given module(s) and its submodules '
632 | '(if it is a package). Accepts comma separated list of '
633 | 'module names.')
634 | grp.add_argument('--ignore-dir', action='append', default=[],
635 | help='Ignore files in the given directory '
636 | '(multiple directories can be joined by os.pathsep).')
637 |
638 | parser.add_argument('--module', action='store_true', default=False,
639 | help='Trace a module. ')
640 | parser.add_argument('progname', nargs='?',
641 | help='file to run as main program')
642 | parser.add_argument('arguments', nargs=argparse.REMAINDER,
643 | help='arguments to the program')
644 |
645 | opts = parser.parse_args()
646 |
647 | if opts.ignore_dir:
648 | _prefix = sysconfig.get_path("stdlib")
649 | _exec_prefix = sysconfig.get_path("platstdlib")
650 |
651 | def parse_ignore_dir(s):
652 | s = os.path.expanduser(os.path.expandvars(s))
653 | s = s.replace('$prefix', _prefix).replace('$exec_prefix', _exec_prefix)
654 | return os.path.normpath(s)
655 |
656 | opts.ignore_module = [mod.strip()
657 | for i in opts.ignore_module for mod in i.split(',')]
658 | opts.ignore_dir = [parse_ignore_dir(s)
659 | for i in opts.ignore_dir for s in i.split(os.pathsep)]
660 |
661 | if opts.report:
662 | if not opts.file:
663 | parser.error('-r/--report requires -f/--file')
664 | results = CoverageResults(infile=opts.file, outfile=opts.file)
665 | return results.write_results(opts.missing, opts.summary, opts.coverdir)
666 |
667 | if not any([opts.trace, opts.count, opts.listfuncs, opts.trackcalls]):
668 | parser.error('must specify one of --trace, --count, --report, '
669 | '--listfuncs, or --trackcalls')
670 |
671 | if opts.listfuncs and (opts.count or opts.trace):
672 | parser.error('cannot specify both --listfuncs and (--trace or --count)')
673 |
674 | if opts.summary and not opts.count:
675 | parser.error('--summary can only be used with --count or --report')
676 |
677 | if opts.progname is None:
678 | parser.error('progname is missing: required with the main options')
679 |
680 | t = Trace(opts.count, opts.trace, countfuncs=opts.listfuncs,
681 | countcallers=opts.trackcalls, ignoremods=opts.ignore_module,
682 | ignoredirs=opts.ignore_dir, infile=opts.file,
683 | outfile=opts.file, timing=opts.timing)
684 | try:
685 | if opts.module:
686 | import runpy
687 | module_name = opts.progname
688 | mod_name, mod_spec, code = runpy._get_module_details(module_name)
689 | sys.argv = [code.co_filename, *opts.arguments]
690 | globs = {
691 | '__name__': '__main__',
692 | '__file__': code.co_filename,
693 | '__package__': mod_spec.parent,
694 | '__loader__': mod_spec.loader,
695 | '__spec__': mod_spec,
696 | '__cached__': None,
697 | }
698 | else:
699 | sys.argv = [opts.progname, *opts.arguments]
700 | sys.path[0] = os.path.dirname(opts.progname)
701 |
702 | with io.open_code(opts.progname) as fp:
703 | code = compile(fp.read(), opts.progname, 'exec')
704 | # try to emulate __main__ namespace as much as possible
705 | globs = {
706 | '__file__': opts.progname,
707 | '__name__': '__main__',
708 | '__package__': None,
709 | '__cached__': None,
710 | }
711 | t.runctx(code, globs, globs)
712 | except OSError as err:
713 | sys.exit("Cannot run file %r because: %s" % (sys.argv[0], err))
714 | except SystemExit:
715 | pass
716 |
717 | results = t.results()
718 |
719 | if not opts.no_report:
720 | results.write_results(opts.missing, opts.summary, opts.coverdir)
721 |
722 | if __name__=='__main__':
723 | main()
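The core mechanism behind `Trace.localtrace_count` (and the `exe_path` list this fork adds) is `sys.settrace`: a global trace function receives `call` events and returns a local trace function that receives a `line` event for every executed line. A minimal self-contained sketch, with `demo` as a hypothetical traced function:

```python
import sys

def collect_executed_lines(func, *args):
    """Run func under sys.settrace and return the executed line offsets
    (relative to the def line), in execution order."""
    lines = []
    code = func.__code__

    def local(frame, event, arg):
        if event == "line" and frame.f_code is code:
            lines.append(frame.f_lineno - code.co_firstlineno)
        return local

    def tracer(frame, event, arg):
        # only install the local tracer for the function we care about
        if event == "call" and frame.f_code is code:
            return local
        return None

    sys.settrace(tracer)
    try:
        func(*args)
    finally:
        sys.settrace(None)  # always detach the tracer
    return lines

def demo(n):
    if n > 0:
        return "pos"
    return "non-pos"

path = collect_executed_lines(demo, 1)  # offsets of the executed lines
```

`Trace` layers counting, ignore lists, and report writing on top of exactly this hook.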
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from sklearn.metrics import precision_recall_fscore_support
3 | import warnings
4 | import numpy as np
5 |
6 | warnings.simplefilter("ignore")
7 |
8 | def write(content, file):
9 | with open(file, 'a') as f:
10 | f.write(content + '\n')
11 |
12 | def pad_targets(target, max_node):
13 | batch_size, current_max_node = target.shape
14 | padded_target = torch.zeros(batch_size, max_node, device=target.device)
15 | padded_target[:, :current_max_node] = target
16 | return padded_target
17 |
18 | def calculate_metrics(y_true, y_pred):
19 | precisions, recalls, fscores = [], [], []
20 | for i in range(y_true.shape[1]):
21 | precision, recall, fscore, _ = precision_recall_fscore_support(y_true[:, i], y_pred[:, i], average='binary')
22 | precisions.append(precision)
23 | recalls.append(recall)
24 | fscores.append(fscore)
25 | return sum(precisions) / len(precisions), sum(recalls) / len(recalls), sum(fscores) / len(fscores)
26 |
27 | # def accuracy_whole_list(y_true, y_pred):
28 | # correct = (y_true == y_pred).all(axis=1).sum().item()
29 | # return correct / y_true.shape[0]
30 |
31 | def accuracy_whole_list(y_true, y_pred, lengths):
32 |     """Count samples whose predictions match the labels exactly over
33 |     their true (unpadded) length. Returns the raw count, not a ratio."""
34 |     correct = 0
35 |     for i in range(len(lengths)):
36 |         length = lengths[i]
37 |         if np.array_equal(y_true[i][:length], y_pred[i][:length]):
38 |             correct += 1
39 |     return correct
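The exact-match counting in `accuracy_whole_list` can be shown without numpy; `y_true`, `y_pred`, and `lengths` below are hypothetical examples where padding beyond each sample's length is ignored:

```python
def exact_match_count(y_true, y_pred, lengths):
    """Count samples whose first `length` labels match exactly
    (same idea as accuracy_whole_list, which returns a count)."""
    correct = 0
    for truth, pred, n in zip(y_true, y_pred, lengths):
        if truth[:n] == pred[:n]:  # compare only the unpadded prefix
            correct += 1
    return correct

y_true = [[1, 0, 1, 0], [1, 1, 0, 0]]
y_pred = [[1, 0, 1, 1], [1, 0, 0, 0]]   # sample 0 matches on its 3 real labels
count = exact_match_count(y_true, y_pred, [3, 2])  # 1
```

Dividing the count by the number of samples would turn it into the usual whole-sequence accuracy.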
--------------------------------------------------------------------------------