├── requirements.txt
├── .gitignore
├── PDG.py
├── README.md
├── CG.py
├── test.py
├── DDG.py
├── AST.py
├── CDG.py
├── utils.py
├── CFG.py
└── File.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | tree-sitter==0.20.2
3 | inflection
4 | graphviz
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pdf
2 | build
3 | __pycache__
4 | test.c
5 | test
6 | .vscode
7 | *.gv
8 | pdf
--------------------------------------------------------------------------------
/PDG.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | from CFG import *
3 | from CDG import *
4 | from DDG import *
5 | from graphviz import Digraph
6 |
class PDG(CFG):
    """Program dependence graph: the union of each function's CDG and DDG."""

    def construct_pdg(self, code):
        """Build one PDG per function in ``code`` and store them in ``self.pdgs``.

        The CDG and DDG are built independently from the same source and are
        paired up by position. Each PDG is the CDG graph object itself, with
        the DDG's edges appended to its ``edges`` dict. Edge lists therefore
        mix ``type='CDG'`` and ``type='DDG'`` ``Edge`` objects.

        Returns:
            list: the PDG graph objects (also stored in ``self.pdgs``).
        """
        cdg_builder = CDG(self.language)
        ddg_builder = DDG(self.language)
        cdg_builder.construct_cdg(code)
        ddg_builder.construct_ddg(code)
        self.pdgs = []
        for cdg, ddg in zip(cdg_builder.cdgs, ddg_builder.ddgs):
            pdg = cdg  # the CDG object is reused (and extended) as the PDG
            for node, edges in ddg.edges.items():
                pdg.edges.setdefault(node, []).extend(edges)
            self.pdgs.append(pdg)
        return self.pdgs

    def see_pdg(self, code, filename='PDG', pdf=True, view=False):
        """Render the PDGs of ``code`` with graphviz.

        Control dependences are drawn as solid edges ``u.id -> v``. Data
        dependences are drawn as dotted edges ``v -> u.id``, labelled with the
        variable names involved.

        Args:
            code: source code to analyse.
            filename: output file name passed to ``Digraph.render``.
            pdf: whether to render the graph to a file.
            view: whether to open the rendered file.
        """
        self.construct_pdg(code)
        dot = Digraph(comment='PDG', strict=True)
        for pdg in self.pdgs:
            for v in pdg.nodes:
                if v < 0:  # synthetic Exit node (negative id) added by reverse()
                    continue
                node = pdg.id_to_nodes[v]
                label = f"<({node.type}, {html.escape(node.text)}){node.line}>"
                if node.is_branch:
                    dot.node(str(node.id), shape='diamond', label=label, fontname='fangsong')
                elif node.type == 'function_definition':
                    dot.node(str(node.id), label=label, fontname='fangsong')
                else:
                    dot.node(str(node.id), shape='rectangle', label=label, fontname='fangsong')
            # Each edge is drawn exactly once, in the direction its type requires.
            # Drawing every edge as `u -> v` here as well would add a reversed
            # solid copy of each DDG edge.
            for v, edges in pdg.edges.items():
                for u in edges:
                    if u.type == 'DDG':
                        # Sort the variable names so the label is deterministic.
                        dot.edge(str(v), str(u.id), label=', '.join(sorted(u.token)), style='dotted')
                    else:
                        dot.edge(str(u.id), str(v))
        if pdf:
            dot.render(filename, view=view, cleanup=True)
48 |
if __name__ == '__main__':
    # Read the sample C file and close the handle deterministically.
    with open('test.c', 'r', encoding='utf-8') as f:
        code = f.read()
    pdg = PDG('c')
    pdg.see_pdg(code, view=True)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 更新后的代码地址放在了:https://github.com/rebibabo/INVEST
2 |
3 | 这是一个基于tree-sitter编译工具进行静态程序分析的demo, 并使用可视化工具graphviz,能够生成抽象语法树AST、控制流图CFG、数据依赖图DDG、控制依赖图CDG、程序依赖图PDG、函数调用图CG等。
4 | tree-sitter网址:https://tree-sitter.github.io/tree-sitter/
5 |
6 | ## 环境配置
7 | 确保已经安装了graphviz,在windows上,官网https://www.graphviz.org/ 下载graphviz之后,配置环境变量为安装路径下的bin文件夹,例如D:\graphviz\bin\,注意末尾的'\\'不能省略,如果是linux上,运行下面命令安装:
8 | ```
9 | sudo apt-get install graphviz graphviz-doc
10 | ```
11 | 接着运行
12 | ```
13 | pip install -r requirements.txt
14 | ```
15 |
16 | ## 生成AST树
17 | AST.py能够生成AST树以及tokens,首先构造类,参数为代码语言,目前tree-sitter能够编译的语言都能够生成。
18 | ```
19 | ast = AST('c')
20 | ```
21 | 接着运行下面代码可以显示AST树
22 | ```
23 | ast.see_tree(code, view=True)
24 | ```
25 | 
26 |
27 | 运行完成之后,会在当前目录下生成ast_tree.pdf,为可视化的ast树,可以通过设置参数view=False在生成pdf文件的同时不查看文件,pdf=False不生成可视化的pdf文件,设置参数filename="filename"来更改输出文件的名称。
28 | 获得代码的tokens可以运行下面的代码,返回值为token的列表。
29 | ```
30 | ast.tokenize(code)
31 | #['int', 'main', '(', ')', '{', 'int', 'abc', '=', '1', ';', 'int', 'b', '=', '2', ';', 'int', 'c', '=', 'a', '+', 'b', ';', 'while', '(', 'i', '<', '10', ')', '{', 'i', '++', ';', '}', '}']
32 | ```
33 |
34 | ## 生成CFG
35 | CFG.py继承自AST类,能够生成控制流图,运行下面命令可以获得代码的CFG:
36 | ```
37 | cfg = CFG('c')
38 | cfg.see_cfg(code, view=True)
39 | ```
40 | 生成的CFG图样例:
41 | 
42 | see_cfg的参数和see_tree的参数一样
43 |
44 | ## 生成CDG
45 | CDG.py继承自CFG类,能够生成控制依赖图,运行下面代码能够获得CDG图:
46 | ```
47 | cdg = CDG('c')
48 | cdg.see_cdg(code, view=True)
49 | ```
50 | 生成的CDG图样例:
51 | 
52 |
53 | ## 生成DDG
54 | DDG.py继承自CFG类,能够生成数据依赖图,运行下面代码能够获得DDG图:
55 | ```
56 | ddg = DDG('c')
57 | ddg.see_ddg(code, view=True)
58 | ```
59 | 生成的DDG图样例:
60 | 
61 |
62 | ## 生成PDG
63 | PDG.py将CDG和DDG的节点和边结合在一起,运行下面代码获得PDG图:
64 | ```
65 | pdg = PDG('c')
66 | pdg.see_pdg(code, view=True)
67 | ```
68 | 生成的PDG图样例:
69 | 
70 |
71 | ## 生成CG
72 | File.py继承自AST.py,能够生成函数调用图,运行下面代码能够生成单个项目的CG图
73 | ```
74 | file = File("path to project")
75 | file.see_cg(code, view=True)
76 | ```
77 | 生成项目目录所有文件的CG图:
78 | ```
79 | dir = Dir('path to project')
80 | ```
81 | 生成CG图样例:
82 | 
83 |
84 |
85 |
86 |
--------------------------------------------------------------------------------
/CG.py:
--------------------------------------------------------------------------------
1 | from AST import *
2 |
def get_identifiers(node):
    """Collect every ``identifier`` node below ``node``, in pre-order.

    ``node`` itself is never included. The walk does not descend into an
    identifier once it has been collected.
    """
    found = []
    pending = list(reversed(node.children))
    while pending:
        current = pending.pop()
        if current.type == 'identifier':
            found.append(current)
            continue
        # Push children reversed so they are popped in source order.
        pending.extend(reversed(current.children))
    return found
13 |
class Function:
    """Signature summary of a function definition (tree-sitter node)."""

    def __init__(self, node):
        """Extract return type, name, parameters and an id from ``node``.

        Args:
            node: a tree-sitter ``function_definition`` node.

        Attributes set:
            type: return type text.
            name: function name text.
            signature: dict with keys 'return', 'name' and 'parameters'
                (parameter declarator text -> type text).
            id: 0-based start row of the function declarator; used by ``CG``
                as the graph node id.
        """
        self.type = text(node.child_by_field_name('type'))
        # Unwrap pointer/other declarators until the function declarator is reached.
        while node.type != 'function_declarator':
            node = node.child_by_field_name('declarator')
        self.name = text(node.child_by_field_name('declarator'))
        parameters = node.child_by_field_name('parameters')
        self.signature = {'return': self.type, 'name': self.name, 'parameters': {}}
        self.id = node.start_point[0]
        for param in parameters.children:
            if text(param) in [',', '(', ')'] or param.type == 'comment':
                continue
            param_type = param.child_by_field_name('type')
            declarator = param.child_by_field_name('declarator')
            if declarator is None:
                # `f(void)` and unnamed parameters have a type but no declarator.
                # A variadic `...` has neither. Record `...`; skip the others
                # instead of crashing on text(None).
                if param_type is None:
                    self.signature['parameters'][text(param)] = text(param)
                continue
            self.signature['parameters'][text(declarator)] = text(param_type)

    def __eq__(self, other):
        """Two functions are equal when their signatures are equal."""
        if not isinstance(other, Function):
            return NotImplemented
        return self.signature == other.signature

    __hash__ = None  # equality is by a mutable dict, so instances stay unhashable

    def __str__(self):
        return str(self.signature)
34 |
class CG(AST):
    """Call graph builder for a single source file."""

    def __init__(self, language):
        super().__init__(language)
        self.node_set = {}  # per-node information (not populated by this class)
        self.cg = []  # list of (Function, set of callee ids), one per function definition
        self.funcs = {}  # function name -> Function

    def create_cg(self, root_node):
        """Compute call edges between the top-level functions under ``root_node``.

        Only ``function_definition`` nodes that are direct children of
        ``root_node`` are considered. Any identifier inside a function that
        matches one of their names counts as a call to that function.

        Results are appended to ``self.cg``, which is also returned, so
        repeated calls accumulate.
        """
        func_def_nodes = {}
        # First pass: find every top-level function definition.
        for child in root_node.children:
            if child.type == 'function_definition':
                func_info = Function(child)
                func_def_nodes[func_info.name] = child
                self.funcs[func_info.name] = func_info
        # Second pass: collect the defined functions referenced in each body.
        for name, func_node in func_def_nodes.items():
            callee_ids = set()
            for ident in get_identifiers(func_node):
                ident_name = text(ident)
                if ident_name in func_def_nodes:
                    callee_ids.add(self.funcs[ident_name].id)
            self.cg.append((self.funcs[name], callee_ids))
        return self.cg

    def see_cg(self, code_path, filename='CG', pdf=True, view=False):
        """Parse the file at ``code_path`` and render its call graph.

        Args:
            code_path: path of the source file (read as UTF-8).
            filename: output file name passed to ``Digraph.render``.
            pdf: whether to render the graph to a file.
            view: whether to open the rendered file.
        """
        with open(code_path, 'r', encoding='utf-8') as f:
            code = f.read()
        tree = self.parser.parse(bytes(code, 'utf8'))
        call_graph = self.create_cg(tree.root_node)
        dot = Digraph(comment=filename)
        for func, callee_ids in call_graph:
            dot.node(str(func.id), shape='rectangle', label=func.name, fontname='fangsong')
            for callee in callee_ids:
                if str(func.id) != str(callee):  # recursive self-calls are not drawn
                    dot.edge(str(func.id), str(callee))
        if pdf:
            dot.render(filename, view=view, cleanup=True)
73 |
74 |
if __name__ == '__main__':
    # Build and display the call graph of the local sample file test.c.
    cg = CG(language='c')
    cg.see_cg(code_path='test.c', view=True)
78 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
def detect_cycles(graph):
    """Return the cycles found by a depth-first search over ``graph``.

    ``graph`` maps a vertex to an iterable of successor vertices. Vertices
    that only appear as successors are treated as having no successors.
    Each time the search meets a vertex already on the current DFS path, the
    path suffix starting at that vertex is recorded as a cycle.
    """
    seen = set()
    trail = []
    found = []

    def visit(vertex):
        if vertex in trail:
            found.append(trail[trail.index(vertex):])
        elif vertex not in seen:
            seen.add(vertex)
            trail.append(vertex)
            for succ in graph.get(vertex, ()):
                visit(succ)
            trail.pop()

    for start in graph:
        visit(start)
    return found
26 |
# Graph represented as a dict: each key is a file name and each value is the
# list of files it includes. The final 'A' -> 'B' -> 'C' -> 'A' entries form a
# three-node cycle.
file_graph = {
    '3rd/lua/lapi.c': ['3rd/lua/lapi.h', '3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/ltm.h', '3rd/lua/lvm.h'],
    '3rd/lua/lapi.h': ['3rd/lua/lstate.h'],
    '3rd/lua/lcode.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstring.h', '3rd/lua/lvm.h'],
    '3rd/lua/ldebug.c': ['3rd/lua/lapi.h', '3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/ltm.h', '3rd/lua/lvm.h'],
    '3rd/lua/ldebug.h': ['3rd/lua/lstate.h'],
    '3rd/lua/ldo.c': ['3rd/lua/lapi.h', '3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/ltm.h', '3rd/lua/lvm.h'],
    '3rd/lua/ldo.h': ['3rd/lua/lstate.h'],
    '3rd/lua/ldump.c': ['3rd/lua/lstate.h'],
    '3rd/lua/lfunc.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h'],
    '3rd/lua/lgc.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/ltm.h'],
    '3rd/lua/lgc.h': ['3rd/lua/lstate.h'],
    '3rd/lua/llex.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h'],
    '3rd/lua/lmem.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h'],
    '3rd/lua/lobject.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/lvm.h'],
    '3rd/lua/lparser.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h'],
    '3rd/lua/lstate.c': ['3rd/lua/lapi.h', '3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/ltm.h'],
    '3rd/lua/lstate.h': ['3rd/lua/ltm.h'],
    '3rd/lua/lstring.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h'],
    '3rd/lua/lstring.h': ['3rd/lua/lgc.h', '3rd/lua/lstate.h'],
    '3rd/lua/ltable.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/lvm.h'],
    '3rd/lua/ltests.c': ['3rd/lua/lapi.h', '3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h'],
    '3rd/lua/ltm.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/ltm.h', '3rd/lua/lvm.h'],
    '3rd/lua/ltm.h': ['3rd/lua/lstate.h'],
    '3rd/lua/luac.c': ['3rd/lua/ldebug.h', '3rd/lua/lstate.h'],
    '3rd/lua/lundump.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lstring.h'],
    '3rd/lua/lvm.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/ltm.h', '3rd/lua/lvm.h'],
    '3rd/lua/lvm.h': ['3rd/lua/ldo.h', '3rd/lua/ltm.h'],
    'A': ['B'],
    'B': ['C'],
    'C': ['A']
}


# Run the detector at import/run time and print the result. The messages
# mean "cycles found:" and "no cycles" respectively. Despite its name,
# `cycle` holds the list of all detected cycles.
cycle = detect_cycles(file_graph)
if cycle:
    print("存在环路:", cycle)
else:
    print("不存在环路")
67 |
--------------------------------------------------------------------------------
/DDG.py:
--------------------------------------------------------------------------------
1 | from CFG import *
2 | from utils import *
3 | from graphviz import Graph
4 |
class DDG(CFG):
    """Data dependence graph builder on top of the per-function CFGs."""

    def construct_ddg(self, code):
        """Build one data dependence graph per function in ``code``.

        Follows the algorithm on page 19 of
        https://home.cs.colorado.edu/~kena/classes/5828/s99/lectures/lecture25.pdf :
        there is a data dependence from X to Y with respect to a variable v iff
        there is a non-null path p from X to Y with no intervening definition
        of v, and either
          * X contains a definition of v and Y a use of v;
          * X contains a use of v and Y a definition of v; or
          * X contains a definition of v and Y a definition of v.

        Each resulting graph is the CFG object returned by ``see_cfg``, with
        its ``edges`` replaced by a dict that maps a source node id to a list
        of ``Edge(target_id, type='DDG', token={variable names})``.

        Prints 'Syntax Error' and exits with status 1 if ``code`` does not parse.

        Returns:
            list: the DDG graph objects (also stored in ``self.ddgs``).
        """
        if self.check_syntax(code):
            print('Syntax Error')
            raise SystemExit(1)  # what exit(1) raises, without relying on the `site` module
        cfgs = self.see_cfg(code)
        self.ddgs = []
        for cfg in cfgs:
            defs = cfg.defs
            uses = cfg.uses
            edge = {}  # (src_id, dst_id) -> set of variable names
            for var, var_defs in defs.items():
                var_uses = uses.get(var, [])
                self._add_dependences(cfg, edge, var_defs, var_uses, var)  # def X -> use Y
                self._add_dependences(cfg, edge, var_uses, var_defs, var)  # use X -> def Y
                self._add_dependences(cfg, edge, var_defs, var_defs, var)  # def X -> def Y
            ddg = {}
            for (u, v), names in edge.items():
                ddg.setdefault(u, []).append(Edge(v, type='DDG', token=names))
            cfg.edges = ddg
            self.ddgs.append(cfg)
        return self.ddgs

    def _add_dependences(self, cfg, edge, sources, targets, var):
        # Record `var` on edge[(s, t)] for every source/target pair joined by a
        # path on which no intermediate node redefines `var`.
        for s in sources:
            for t in targets:
                if self._def_clear_path_exists(cfg, s, t, var):
                    edge.setdefault((s, t), set()).add(var)

    @staticmethod
    def _def_clear_path_exists(cfg, start, end, var):
        """Return True if ``cfg`` has a non-null path ``start -> ... -> end``
        whose intermediate nodes do not define ``var``.

        This gives the same answer as checking every path returned by
        ``cfg.findAllPath(start, end)``, but uses one graph search instead of
        enumerating all paths, which can be exponential in the number of
        branches. Only *some* clear path is required; a killed path does not
        hide another clear one.
        """
        if start == end:
            # findAllPath only reports a start -> start path through a self-loop.
            return start in cfg.Adj(start)
        seen = {start}
        stack = [start]
        while stack:
            current = stack.pop()
            for nxt in cfg.Adj(current):
                if nxt == end:
                    return True
                if nxt in seen:
                    continue
                seen.add(nxt)
                node = cfg.id_to_nodes.get(nxt)
                if node is not None and var in node.defs:
                    continue  # an intervening definition kills `var` along this route
                stack.append(nxt)
        return False

    def see_ddg(self, code, filename='DDG', pdf=True, view=False):
        """Render the DDGs of ``code`` with graphviz.

        Dependences are drawn as dotted edges ``source -> target``, labelled
        with the (sorted) variable names involved.

        Args:
            code: source code to analyse.
            filename: output file name passed to ``Digraph.render``.
            pdf: whether to render the graph to a file.
            view: whether to open the rendered file.
        """
        self.construct_ddg(code)
        dot = Digraph(comment=filename, strict=True)
        for ddg in self.ddgs:
            for node in ddg.nodes:
                label = f"<({node.type}, {html.escape(node.text)}){node.line}>"
                if node.is_branch:
                    dot.node(str(node.id), shape='diamond', label=label, fontname='fangsong')
                elif node.type == 'function_definition':
                    dot.node(str(node.id), label=label, fontname='fangsong')
                else:
                    dot.node(str(node.id), shape='rectangle', label=label, fontname='fangsong')
            for v, edges in ddg.edges.items():
                for u in edges:
                    dot.edge(str(v), str(u.id), label=', '.join(sorted(u.token)), style='dotted')
        if pdf:
            dot.render(filename, view=view, cleanup=True)
107 |
if __name__ == '__main__':
    # Read the sample C file and close the handle deterministically.
    with open('test.c', 'r', encoding='utf-8') as f:
        code = f.read()
    ddg = DDG('c')
    ddg.see_ddg(code, view=True)
--------------------------------------------------------------------------------
/AST.py:
--------------------------------------------------------------------------------
1 | from tree_sitter import Parser, Language
2 | from graphviz import Digraph
3 | import os
4 |
def text(node):
    """Return the source text of a tree-sitter node, decoded as UTF-8."""
    return node.text.decode('utf-8')
6 |
class Node:
    """Serializable snapshot of a tree-sitter node.

    tree-sitter ``Node`` objects cannot be serialized, so the fields needed
    by ``TreeNode`` are copied into plain attributes. Nodes compare and hash
    by ``id``.
    """

    def __init__(self, node, id):
        self.type = node.type
        self.start_byte = node.start_byte
        self.end_byte = node.end_byte
        self.start_point = node.start_point
        self.end_point = node.end_point
        self.text = text(node)
        self.id = id

    def __eq__(self, other):
        # Compare by id. Objects without an `id` attribute are not equal,
        # rather than raising AttributeError.
        try:
            return self.id == other.id
        except AttributeError:
            return NotImplemented

    def __hash__(self):
        # Consistent with __eq__; defining __eq__ alone leaves the class unhashable.
        return hash(self.id)
20 |
class TreeNode:
    """AST stored as serializable ``Node`` objects plus a child-id adjacency dict."""

    def __init__(self, root):
        """Build the tree from ``root``, a tree-sitter node.

        Attributes:
            nodes: list of ``Node`` in pre-order; a node's index is its id.
            id_to_node: id -> ``Node``.
            edges: id -> list of child ids (an empty list for leaves).
        """
        root_node = Node(root, 0)
        self.nodes = [root_node]
        # Keep the serializable wrapper, not the raw tree-sitter root, so every
        # entry in id_to_node has the same type.
        self.id_to_node = {0: root_node}
        self.edges = {}
        self.traverse_tree(root)

    def traverse_tree(self, node, pid=0):
        """Recursively number the children of ``node``, whose own id is ``pid``.

        Ids come from ``len(self.nodes)``, so they are unique and follow
        pre-order.
        """
        children = []
        for child in node.children:
            node_id = len(self.nodes)  # unique id for this node
            child_node = Node(child, node_id)
            self.nodes.append(child_node)
            self.id_to_node[node_id] = child_node
            children.append(node_id)
            self.traverse_tree(child, node_id)
        self.edges[pid] = children

    def print_tree(self):
        """Print the tree, one node per line, indented by depth.

        Leaves whose text differs from their type are printed as ``type: text``.
        """
        def dfs(node, depth):
            if self.edges[node.id] == [] and node.type != node.text:
                print(' ' * depth + node.type + ': ' + node.text)
            else:
                print(' ' * depth + node.type)
            for child in self.edges[node.id]:
                dfs(self.nodes[child], depth + 1)
        dfs(self.nodes[0], 0)

    def get_node(self, id):
        """Return the ``Node`` with the given id."""
        return self.id_to_node[id]
55 |
class AST:
    """tree-sitter based parser front end for one language."""

    def __init__(self, language):
        """Load the tree-sitter grammar for ``language``, building it if needed.

        If ``./build/{language}-languages.so`` does not exist, the grammar
        repository ``tree-sitter-{language}`` is cloned from GitHub (unless it
        is already in the working directory) and compiled into that file.
        """
        self.language = language
        library_path = f'./build/{language}-languages.so'
        if not os.path.exists(library_path):
            if not os.path.exists(f'./tree-sitter-{language}'):
                os.system(f'git clone https://github.com/tree-sitter/tree-sitter-{language}')
            Language.build_library(library_path, [f'./tree-sitter-{language}'])
        parser = Parser()
        parser.set_language(Language(library_path, language))
        self.parser = parser

    def see_tree(self, code, filename='ast_tree', pdf=True, view=False):
        """Render the AST of ``code`` with graphviz.

        Args:
            code: source code to parse.
            filename: output file name.
            pdf: whether to render the graph to a file.
            view: whether to open the rendered file.
        """
        tree = self.parser.parse(bytes(code, 'utf8'))
        tree = TreeNode(tree.root_node)
        dot = Digraph(comment='AST Tree', strict=True)
        for edge, children in tree.edges.items():
            node = tree.get_node(edge)
            dot.node(str(edge), shape='rectangle', label=node.type, fontname='fangsong')
            dot.edges([(str(edge), str(child)) for child in children])
            if children == []:
                # Leaf: add an extra ellipse node (negated id) showing its text.
                dot.node(str(-edge), shape='ellipse', label=node.text, fontname='fangsong')
                dot.edges([(str(edge), str(-edge))])
        if pdf:
            dot.render(filename, view=view, cleanup=True)

    def tokenize(self, code):
        """Return the text of every leaf node of ``code``'s parse tree, in source order."""
        def tokenize_help(node, tokens):
            # Depth-first walk that appends each leaf's text to `tokens`.
            if not node.children:
                tokens.append(text(node))
                return
            for n in node.children:
                tokenize_help(n, tokens)
        tree = self.parser.parse(bytes(code, 'utf8'))
        tokens = []
        tokenize_help(tree.root_node, tokens)
        return tokens

    def check_syntax(self, code):
        """Report syntax errors in ``code``.

        Prints the 0-based start and end positions of every ``ERROR`` node
        ("row" in the message is the column) and returns the root node's
        ``has_error`` flag.
        """
        tree = self.parser.parse(bytes(code, 'utf8'))
        root_node = tree.root_node
        error_nodes = []
        # Collect every ERROR node in the tree.
        def find_error(node):
            if node.type == 'ERROR':
                error_nodes.append(node)
            for child in node.children:
                find_error(child)
        find_error(root_node)
        for i, node in enumerate(error_nodes):
            print(f"error {i:>3} : line {node.start_point[0]:>3} row {node.start_point[1]:>3} -to- line {node.end_point[0]:>3} row {node.end_point[1]:>3}")
        return tree.root_node.has_error
126 |
if __name__ == '__main__':
    # Read the sample file and close the handle deterministically.
    with open('test.c', 'r', encoding='utf-8') as f:
        code = f.read()
    ast = AST('cpp')  # NOTE(review): uses the C++ grammar on test.c; the other scripts use 'c'
    print(ast.tokenize(code))
    ast.see_tree(code, view=True)
--------------------------------------------------------------------------------
/CDG.py:
--------------------------------------------------------------------------------
1 | from CFG import *
2 | from utils import *
3 | from graphviz import Graph
4 |
class Tree:
    """Rooted tree over graph vertices, stored as parent/children dicts.

    ``CDG.post_dominator_tree`` rewrites ``parent`` entries in place and then
    calls ``reset_by_parent`` to rebuild ``children``.
    """

    def __init__(self, V, children, root):
        """
        Args:
            V: iterable of vertices.
            children: dict vertex -> list of child vertices. Mutated in place:
                vertices of ``V`` without an entry get an empty list.
            root: the root vertex; its parent is itself.
        """
        self.vertex = V
        self.children = children
        self.root = root
        self.parent = {}
        for node in children:  # initialise the parent dict
            for each in children[node]:
                self.parent[each] = node
        self.parent[root] = root
        for v in V:
            if v not in self.children:
                self.children[v] = []
        # Depths in the initial tree. They are NOT updated when `parent` is
        # rewritten, so get_lca must not rely on them.
        self.depth = self.get_nodes_depth(root, {root: 0})

    def get_nodes_depth(self, root, depth):
        """Recursively fill ``depth`` (vertex -> depth) for the subtree of ``root``."""
        for child in self.children[root]:
            depth[child] = depth[root] + 1
            depth = self.get_nodes_depth(child, depth)
        return depth

    def get_lca(self, a, b):
        """Return the lowest common ancestor of ``a`` and ``b``.

        Uses only the current ``parent`` links, so the result stays correct
        after callers re-parent vertices (which leaves ``self.depth`` stale).
        """
        ancestors = {a}
        while self.parent[a] != a:
            a = self.parent[a]
            ancestors.add(a)
        while b not in ancestors:
            b = self.parent[b]
        return b

    def reset_by_parent(self):
        """Rebuild ``children`` from the ``parent`` dict."""
        self.children = {v: [] for v in self.vertex}
        for node in self.parent:
            if node != self.parent[node]:
                self.children[self.parent[node]].append(node)

    def see_tree(self):
        """Display the tree with graphviz (debugging aid)."""
        dot = Graph(comment='Tree')
        for node in self.vertex:
            dot.node(str(node), shape='rectangle', label=str(node), fontname='fangsong')
        for node in self.children:
            for child in self.children[node]:
                dot.edge(str(node), str(child))
        dot.view()
59 |
class CDG(CFG):
    """Control dependence graph builder based on post-dominance frontiers."""

    def get_subTree(self, cfg):
        """Return a spanning tree of ``cfg`` rooted at ``cfg.Exit``.

        Graph search from Exit with a LIFO worklist. Each vertex hangs under
        the vertex from which it was first discovered.

        Returns:
            dict: vertex -> list of child vertices.
        """
        V, E, Exit = cfg.nodes, cfg.edges, cfg.Exit
        visited = {v: False for v in V}
        worklist = [Exit]
        visited[Exit] = True
        subTree = {}
        while worklist:
            node = worklist.pop()
            for edge in E.get(node, []):
                v = edge.id
                if not visited[v]:
                    worklist.append(v)
                    visited[v] = True
                    subTree.setdefault(node, []).append(v)
        return subTree

    def get_prev(self, cfgs):
        """Return the predecessor map ``id -> [predecessor ids]`` over all ``cfgs``."""
        prev = {}
        for cfg in cfgs:
            for node, edges in cfg.edges.items():
                prev.setdefault(node, [])
                for next_node in edges:
                    prev.setdefault(next_node.id, []).append(node)
        return prev

    def post_dominator_tree(self, cfgs, prev):
        """Compute one post-dominator tree per reversed CFG in ``cfgs``.

        Starting from a spanning tree rooted at Exit, each vertex's parent is
        repeatedly replaced by the LCA of its current parent and each of its
        predecessors, until nothing changes.

        Vertices missing from the spanning tree (unreachable from Exit in the
        reversed graph) have no post-dominator here. They are skipped, both as
        vertices and as predecessors, rather than raising KeyError.
        """
        PDT = []
        for cfg in cfgs:
            subTree = self.get_subTree(cfg)
            V, root = cfg.nodes, cfg.Exit
            tree = Tree(V, subTree, root)
            changed = True
            while changed:
                changed = False
                for v in V:
                    if v == root or v not in tree.parent:
                        continue
                    for u in prev.get(v, []):
                        if u not in tree.parent:
                            continue
                        parent_v = tree.parent[v]
                        if u == parent_v:
                            continue
                        lca = tree.get_lca(u, parent_v)
                        if lca != parent_v:
                            tree.parent[v] = lca
                            changed = True
            tree.reset_by_parent()  # rebuild children from the final parent links
            PDT.append(tree)
        return PDT

    def dominance_frontier(self, code):
        """Return ``(cfgs, DF)`` for ``code``.

        ``DF`` holds one dict per function: vertex -> list of vertices in its
        dominance frontier on the reversed CFG (its post-dominance frontier).
        NOTE: ``Graph.reverse()`` mutates and returns the same object, so the
        returned ``cfgs`` are the reversed graphs.
        """
        cfgs = self.see_cfg(code)
        reverse_cfgs = [cfg.reverse() for cfg in cfgs]
        prev = self.get_prev(reverse_cfgs)
        PDT = self.post_dominator_tree(reverse_cfgs, prev)
        DF = []
        for cfg, tree in zip(reverse_cfgs, PDT):
            df = {v: [] for v in cfg.nodes}
            for v in cfg.nodes:
                preds = prev.get(v, [])
                if len(preds) <= 1 or v not in tree.parent:
                    continue
                idom = tree.parent[v]
                for p in preds:
                    runner = p
                    while runner in tree.parent and runner != idom:
                        df[runner].append(v)
                        if tree.parent[runner] == runner:
                            break  # reached the root; stop instead of looping forever
                        runner = tree.parent[runner]
            DF.append(df)
        return cfgs, DF

    def construct_cdg(self, code):
        """Build one CDG per function and store them in ``self.cdgs``.

        Each graph's ``edges`` maps a node id to ``Edge(..., type='CDG')``
        objects for the vertices in that node's post-dominance frontier.
        """
        cfgs, DF = self.dominance_frontier(code)
        self.cdgs = []
        for cfg, df in zip(cfgs, DF):
            for v in df:
                df[v] = [Edge(u, type='CDG') for u in df[v]]
            cfg.edges = df
            self.cdgs.append(cfg)
        return self.cdgs

    def see_cdg(self, code, filename='CDG', pdf=True, view=False):
        """Render the CDGs of ``code`` with graphviz (left-to-right layout).

        Args:
            code: source code to analyse.
            filename: output file name passed to ``Digraph.render``.
            pdf: whether to render the graph to a file.
            view: whether to open the rendered file.
        """
        self.construct_cdg(code)
        dot = Digraph(comment=filename, strict=True, format='pdf', graph_attr={'rankdir': 'LR'})
        for cdg in self.cdgs:
            for n in cdg.edges:
                if n < 0:  # synthetic Exit node
                    continue
                node = cdg.id_to_nodes[n]
                label = f"<({node.type}, {html.escape(node.text)}){node.line}>"
                if node.is_branch:
                    dot.node(str(node.id), shape='diamond', label=label, fontname='fangsong')
                elif node.type == 'function_definition':
                    dot.node(str(node.id), label=label, fontname='fangsong')
                else:
                    dot.node(str(node.id), shape='rectangle', label=label, fontname='fangsong')
            for v in cdg.edges:
                for u in cdg.edges[v]:
                    dot.edge(str(u.id), str(v))
        if pdf:
            dot.render(filename, view=view, cleanup=True)
167 |
if __name__ == '__main__':
    # Read the sample C file and close the handle deterministically.
    with open('test.c', 'r', encoding='utf-8') as f:
        code = f.read()
    cdg = CDG('c')
    cdg.see_cdg(code, view=True)
173 |
174 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
def text(node):
    """Return the source text of a tree-sitter node, decoded as UTF-8."""
    return node.text.decode('utf-8')
2 | from graphviz import Digraph
3 |
class Node:
    """CFG/PDG node built from a tree-sitter statement node.

    Attributes:
        line: 1-based start line.
        type: tree-sitter node type.
        id: integer id derived from the node's start/end points.
        is_branch: True for if/while/for/do/case nodes.
        text: display text (function name, statement header, case label,
            or the full statement text).
        defs / uses: names of variables defined / used by the statement.
    """

    def __init__(self, node):
        self.line = node.start_point[0] + 1
        self.type = node.type
        # Same span -> same id. NOTE(review): distinct spans can collide
        # modulo 1000000, which would merge two graph nodes.
        self.id = hash((node.start_point, node.end_point)) % 1000000
        self.is_branch = False
        if node.type == 'function_definition':
            # Function name: the declarator of the function declarator.
            self.text = text(node.child_by_field_name('declarator').child_by_field_name('declarator'))
        elif node.type in ['if_statement', 'while_statement', 'for_statement', 'switch_statement']:
            # Header only: concatenate the children that precede the body.
            if node.type == 'if_statement':
                body = node.child_by_field_name('consequence')
            else:
                body = node.child_by_field_name('body')
            node_text = ''
            for child in node.children:
                if child == body:
                    break
                node_text += text(child)
            self.text = node_text
            if node.type != 'switch_statement':
                self.is_branch = True
        elif node.type == 'do_statement':
            self.text = f'while{text(node.child_by_field_name("condition"))}'
            self.is_branch = True
        elif node.type == 'case_statement':
            # Label only: the children before ':'.
            node_text = ''
            for child in node.children:
                if child.type == ':':
                    break
                node_text += ' ' + text(child)
            self.text = node_text
            self.is_branch = True
        else:
            self.text = text(node)
        defs, uses = self.get_def_use_info(node)
        self.defs = defs
        self.uses = uses

    def get_all_identifier(self, node):
        """Return the names of all ``identifier`` nodes under ``node``, in pre-order.

        Identifiers whose direct parent is a ``call_expression`` (the callee
        name) are skipped. ``node`` may be None, which gives an empty list.
        """
        ids = []

        def visit(current):
            if current is None:
                return
            if current.type == 'identifier' and current.parent.type not in ['call_expression']:
                ids.append(text(current))
            for child in current.children:
                visit(child)

        visit(node)
        return ids

    def get_def_id(self, node):
        """Return ``(update_ids, assignment_ids)`` found anywhere under ``node``.

        Update targets come from ``update_expression`` (e.g. ``i++``);
        assignment targets are the ``left`` of ``assignment_expression``.
        ``node`` may be None (e.g. an omitted ``for`` clause), which gives two
        empty lists.
        """
        update_ids, assignment_ids = [], []

        def visit(current):
            if current is None:
                return
            if current.type == 'update_expression':
                update_ids.append(text(current.child_by_field_name('argument')))
            if current.type == 'assignment_expression':
                assignment_ids.append(text(current.child_by_field_name('left')))
            for child in current.children:
                visit(child)

        visit(node)
        return update_ids, assignment_ids

    def get_node_def_use(self, node):
        """Generic def/use split for an expression subtree.

        Defs are update and assignment targets. Uses are all identifiers
        minus assignment targets, plus update targets.
        """
        uses = self.get_all_identifier(node)
        update_ids, assignment_ids = self.get_def_id(node)
        uses = list(set(uses) - set(assignment_ids) | set(update_ids))
        return update_ids + assignment_ids, uses

    def get_def_use_info(self, node):
        """Return ``(defs, uses)``: deduplicated variable names defined/used by ``node``."""
        defi, uses = [], []
        if node.type == 'function_definition':
            # Parameters are definitions at function entry.
            defi = self.get_all_identifier(node.child_by_field_name('declarator').child_by_field_name('parameters'))
        elif node.type == 'expression_statement':
            node = node.children[0]
            if node.type == 'call_expression':  # scanf statement
                if text(node.child_by_field_name('function')) == 'scanf':
                    arguments = node.child_by_field_name('arguments')
                    defi = self.get_all_identifier(arguments)
                else:
                    d, u = self.get_def_use_info(node)
                    defi.extend(d)
                    uses.extend(u)
            elif node.type == 'assignment_expression':  # a += b
                d = text(node.child_by_field_name('left'))
                op = text(node.children[1])
                u = self.get_all_identifier(node.child_by_field_name('right'))
                if op != '=':  # compound assignment also reads the target
                    u.append(d)
                defi.append(d)
                uses = u
            elif node.type == 'update_expression':  # a++;
                d = text(node.child_by_field_name('argument'))
                defi.append(d)
                uses.append(d)
            else:
                d, u = self.get_def_use_info(node)
                defi.extend(d)
                uses.extend(u)
        elif 'declaration' in node.type:
            for child in node.children:
                if child.type == 'init_declarator':
                    # `int c = a + b;` defines c and uses a and b (handled by
                    # the 'declarator' branch below), instead of treating
                    # every identifier as a definition.
                    d, u = self.get_def_use_info(child)
                    defi.extend(d)
                    uses.extend(u)
                else:
                    defi.extend(self.get_all_identifier(child))
        elif 'declarator' in node.type:
            d = text(node.child_by_field_name('declarator'))
            u = self.get_all_identifier(node)
            defi.append(d)
            uses = [i for i in u if i != d]
        elif node.type in ['if_statement', 'while_statement', 'do_statement', 'switch_statement']:
            condition = node.child_by_field_name('condition')
            d2, u2 = self.get_node_def_use(condition)
            defi.extend(d2)
            uses.extend(u2)
        elif node.type == 'for_statement':
            initializer = node.child_by_field_name('initializer')
            condition = node.child_by_field_name('condition')
            update = node.child_by_field_name('update')
            if initializer is not None and 'declaration' in initializer.type:
                # `for (int i = 0; ...)`: the declaration defines i.
                d1, u1 = self.get_def_use_info(initializer)
            else:
                d1, u1 = self.get_node_def_use(initializer)
            d2, u2 = self.get_node_def_use(condition)
            d3, u3 = self.get_node_def_use(update)
            defi.extend(d1 + d2 + d3)
            uses.extend(u1 + u2 + u3)
        else:
            d, u = self.get_node_def_use(node)
            defi.extend(d)
            uses.extend(u)
        defi = list(set(defi))
        uses = list(set(uses))
        return defi, uses
133 |
class Edge:
    """Directed edge to node ``id`` in a CFG/CDG/DDG adjacency list."""

    def __init__(self, id, label='', type='', token=None):
        self.id = id  # target node id
        self.label = label  # CFG branch label (e.g. 'Y'); empty otherwise
        self.type = type  # 'CDG' / 'DDG', empty for plain CFG edges
        # DDG: variable names carried by the dependence. Use a fresh list per
        # instance instead of one mutable default shared by every Edge.
        self.token = [] if token is None else token
140 |
class Graph:
    """Per-function control-flow graph.

    ``edges`` maps a node id to its list of outgoing ``Edge`` objects, and
    ``id_to_nodes`` maps id -> ``Node``. ``nodes`` holds ``Node`` objects as
    filled by ``add_edge``; after ``reverse()`` it holds node ids instead.
    """
    def __init__(self):
        self.nodes = set()
        self.edges = {}
        self.prev_nodes = {}
        self.id_to_nodes = {}
        self.r = self.Exit = None  # function-entry id and synthetic Exit id, set by reverse()
        self.defs = {} # variable name -> ids of nodes in this function that define it
        self.uses = {} # variable name -> ids of nodes in this function that use it

    def add_edge(self, edge):
        """Add a node together with its incoming edges.

        ``edge`` is ``(node, [(prev_node_id, label), ...])``. For each
        predecessor id, ``Edge(node.id, label)`` is appended to that
        predecessor's outgoing list.
        """
        node = edge[0]
        edges = edge[1]
        self.nodes.add(node)
        self.id_to_nodes[node.id] = node
        self.edges.setdefault(node.id, [])
        for prev_node, label in edges:
            self.edges.setdefault(prev_node, [])
            self.edges[prev_node].append(Edge(node.id, label))

    def see_graph(self):
        """Debugging aid: display the edges, naming the Exit node 'Exit' and the entry 'r'."""
        dot = Digraph(comment='CFG', strict=True)
        for node, edges in self.edges.items():
            for next_node in edges:
                left = 'Exit' if node == self.Exit else str(node)
                right = 'r' if next_node.id == self.r else str(next_node.id)
                dot.edge(left, right)
        dot.view()

    def reverse(self):
        """Reverse the CFG in place, add a synthetic Exit node, and return ``self``.

        The Exit node's id is the negated id of the function_definition node.
        In the reversed graph, Exit has an edge to the function node and to
        every node that originally had no successors, or whose only successor
        edge is labelled 'Y' (no 'N' edge). Afterwards ``nodes`` is a set of
        ids. NOTE(review): if no function_definition node is present, Exit
        stays None and ``E[self.Exit]`` raises KeyError.
        """
        E = {}
        V = set()
        for node, edges in self.edges.items():
            V.add(node)
            if self.id_to_nodes[node].type == 'function_definition': # function entry
                self.r = node
                self.Exit = -node
                V.add(-node) # add an Exit node whose id is the negated function-node id
                E[self.Exit] = [Edge(self.r)] # connect Exit to the function node
        for node, edges in self.edges.items():
            if edges == [] or (len(edges) == 1 and edges[0].label == 'Y'): # no successor, or a single 'Y' successor (no 'N')
                E[self.Exit].append(Edge(node)) # connect Exit to nodes without a full set of successors
            for edge in edges:
                # node -> edge.id becomes edge.id -> node
                E.setdefault(edge.id, [])
                E[edge.id].append(Edge(node))
        self.nodes = V
        self.edges = E
        return self

    def Adj(self, node):
        """Return the ids of ``node``'s successors (empty if it has no entry)."""
        adjs = []
        if node in self.edges:
            for edge in self.edges[node]:
                adjs.append(edge.id)
        return adjs

    def get_def_use_info(self):
        """Fill ``self.defs`` / ``self.uses`` from each node's ``defs`` / ``uses``.

        Reads ``node.defs``, so ``self.nodes`` must still hold ``Node``
        objects (i.e. this must be called before ``reverse()``).
        """
        for node in self.nodes:
            for each in node.defs:
                self.defs.setdefault(each, [])
                self.defs[each].append(node.id)
            for each in node.uses:
                self.uses.setdefault(each, [])
                self.uses[each].append(node.id)

    def findAllPath(self, start, end):
        """Return every path from ``start`` to ``end`` as a list of id lists.

        Iterative DFS with two stacks; see https://zhuanlan.zhihu.com/p/84437102.
        A path stops as soon as it reaches ``end``. Successors already on the
        current path are filtered out when a node is pushed, except for
        ``start``'s own successor list, which is left unfiltered.
        """
        paths, s1, s2 = [], [], [] # all paths, main stack (current path), auxiliary stack (remaining successors)
        s1.append(start)
        s2.append(self.Adj(start))
        while s1: # while the main stack is not empty
            s2_top = s2[-1]
            if s2_top: # remaining successor list is non-empty
                s1.append(s2_top[0]) # push the first remaining successor onto the main stack
                s2[-1] = s2_top[1:] # and drop it from the auxiliary stack's list
                temp = [] # successors of the new node that are not already on the path
                for each in self.Adj(s2_top[0]):
                    if each not in s1:
                        temp.append(each)
                s2.append(temp)
            else: # no successors left: pop both stacks
                s1.pop()
                s2.pop()
                continue
            if s1[-1] == end: # found a path
                paths.append(s1.copy())
                s1.pop() # backtrack
                s2.pop()
        return paths
--------------------------------------------------------------------------------
/CFG.py:
--------------------------------------------------------------------------------
1 | from AST import *
2 | from utils import *
3 | import html
4 | import copy
5 |
def get_break_continue_node(node):
    """Collect the break and continue statements that belong to the construct ``node``.

    The subtree of ``node`` is searched, but the search stops at constructs
    that capture these statements themselves:

    * a nested loop (for / while / do) owns both its breaks and continues;
    * a nested switch owns its breaks, while continues inside it still
      refer to the enclosing loop.

    Args:
        node: tree-sitter node, typically a loop or switch statement.

    Returns:
        (break_nodes, continue_nodes): two lists of tree-sitter nodes in
        source order.
    """
    break_nodes, continue_nodes = [], []
    for child in node.children:
        if child.type == 'break_statement':
            break_nodes.append(child)
        elif child.type == 'continue_statement':
            continue_nodes.append(child)
        elif child.type in ('for_statement', 'while_statement', 'do_statement'):
            continue  # the inner loop handles its own break/continue
        elif child.type == 'switch_statement':
            # break inside a switch only exits the switch; continue still targets the loop
            _, inner_continues = get_break_continue_node(child)
            continue_nodes.extend(inner_continues)
        else:
            inner_breaks, inner_continues = get_break_continue_node(child)
            break_nodes.extend(inner_breaks)
            continue_nodes.extend(inner_continues)
    return break_nodes, continue_nodes
19 |
def get_edge(in_nodes):
    """Turn incoming ``(node, label)`` pairs into ``(node.id, label)`` edge tuples.

    Args:
        in_nodes: list of ``(node, label)`` pairs, where ``node`` has an ``id``.

    Returns:
        A list of ``(parent_id, label)`` tuples in the same order.
    """
    return [(parent.id, label) for parent, label in in_nodes]
28 |
class CFG(AST):
    """Build per-function control-flow graphs from source code parsed by tree-sitter."""

    def __init__(self, language):
        super().__init__(language)
        self.cfgs = []  # one CFG (Graph) per function definition

    def create_cfg(self, node, in_nodes=None):
        """Recursively build the CFG fragment for ``node``.

        Args:
            node: tree-sitter node to translate.
            in_nodes: list of ``(Node, edge_label)`` pairs that flow into
                ``node``. An empty list means ``node`` is unreachable and is
                skipped. When omitted, ``[()]`` is used (the translation-unit
                entry).

        Returns:
            ``(cfg, out_nodes)``.

            ``cfg`` is a list of ``(Node, [(parent_id, label), ...])``
            entries, each node paired with its incoming edges. For the
            translation unit it is instead a list of such lists, one per
            top-level function.

            ``out_nodes`` are the ``(Node, label)`` pairs that flow out of
            this fragment.
        """
        if in_nodes is None:  # avoid a shared mutable default
            in_nodes = [()]
        if node.child_count == 0 or in_nodes == []:  # leaf token, or unreachable code
            return [], in_nodes
        if node.type == 'function_definition':
            # The function node is the entry; its body hangs off it.
            body = node.child_by_field_name('body')
            func_info = Node(node)
            graph, _ = self.create_cfg(body, [(func_info, '')])
            return graph + [(func_info, [])], []
        elif node.type == 'compound_statement':
            # Chain the statements: the outputs of one feed the next.
            graph = []
            for child in node.children:
                sub_graph, out_nodes = self.create_cfg(child, in_nodes)
                graph.extend(sub_graph)
                in_nodes = out_nodes
            return graph, in_nodes
        elif node.type not in ['if_statement', 'while_statement', 'for_statement', 'switch_statement',
                               'case_statement', 'translation_unit', 'do_statement']:
            # Plain statement: a single node with its incoming edges.
            edge = get_edge(in_nodes)
            node_info = Node(node)
            if node.type in ['return_statement', 'break_statement', 'continue_statement']:
                # Control does not fall through; break/continue targets are
                # wired by the enclosing loop or switch.
                return [(node_info, edge)], []
            return [(node_info, edge)], [(node_info, '')]
        elif node.type == 'if_statement':
            graph = []
            edge = get_edge(in_nodes)
            node_info = Node(node)
            graph.append((node_info, edge))
            consequence = node.child_by_field_name('consequence')
            sub_graph, out_nodes = self.create_cfg(consequence, [(node_info, 'Y')])
            graph.extend(sub_graph)
            alternative = node.child_by_field_name('alternative')  # else or else-if
            if alternative:
                # Assumes the alternative is an else clause whose children[1] is the body.
                else_body = alternative.children[1]
                sub_graph, else_out_nodes = self.create_cfg(else_body, [(node_info, 'N')])
                graph.extend(sub_graph)
                return graph, out_nodes + else_out_nodes
            return graph, out_nodes + [(node_info, 'N')]  # no else: 'N' falls through
        elif node.type in ['for_statement', 'while_statement']:
            graph = []
            edge = get_edge(in_nodes)
            node_info = Node(node)
            graph.append((node_info, edge))
            body = node.child_by_field_name('body')
            sub_graph, out_nodes = self.create_cfg(body, [(node_info, 'Y')])
            graph.extend(sub_graph)
            for parent, label in out_nodes:  # back edges: end of body -> loop head
                graph.append((node_info, [(parent.id, label)]))
            break_nodes, continue_nodes = get_break_continue_node(node)
            out_nodes = [(node_info, 'N')]  # the loop exits when the condition is false...
            for break_node in break_nodes:  # ...or through a break
                out_nodes.append((Node(break_node), ''))
            for continue_node in continue_nodes:  # continue jumps back to the loop head
                graph.append((node_info, [(Node(continue_node).id, '')]))
            return graph, out_nodes
        elif node.type == 'do_statement':
            graph = []
            edge = get_edge(in_nodes)
            node_info = Node(node)  # the do-while condition
            body = node.child_by_field_name('body')
            sub_graph, out_nodes = self.create_cfg(body, [(node_info, '')])
            if sub_graph:
                first_node = sub_graph[0][0]  # first statement of the body
            else:
                # Empty body (e.g. `do {} while (c);`): control enters the
                # condition directly and loops on it.
                first_node = node_info
                out_nodes = []
            graph.append((first_node, edge))
            graph.extend(sub_graph)
            for parent, label in out_nodes:  # end of body -> condition
                graph.append((node_info, [(parent.id, label)]))
            graph.append((first_node, [(node_info.id, 'Y')]))  # condition true -> body again
            out_nodes = [(node_info, 'N')]  # condition false leaves the loop
            break_nodes, continue_nodes = get_break_continue_node(node)
            for break_node in break_nodes:
                out_nodes.append((Node(break_node), ''))
            for continue_node in continue_nodes:  # continue jumps to the condition
                graph.append((node_info, [(Node(continue_node).id, '')]))
            return graph, out_nodes
        elif node.type == 'switch_statement':
            graph = []
            edge = get_edge(in_nodes)
            node_info = Node(node)
            graph.append((node_info, edge))
            body = node.child_by_field_name('body')
            sub_graph, out_nodes = self.create_cfg(body, [(node_info, '')])  # the chain of cases
            graph.extend(sub_graph)
            break_nodes, _ = get_break_continue_node(node)
            for break_node in break_nodes:  # break leaves the switch
                out_nodes.append((Node(break_node), ''))
            return graph, out_nodes
        elif node.type == 'case_statement':
            graph = []
            edge = get_edge(in_nodes)
            node_info = Node(node)
            graph.append((node_info, edge))
            if node.children[0].type == 'case':
                # `case X:` -- statements start after 'case', the value and ':'.
                in_nodes = [(node_info, 'Y')]
                for child in node.children[3:]:
                    sub_graph, out_nodes = self.create_cfg(child, in_nodes)
                    graph.extend(sub_graph)
                    in_nodes = out_nodes
                # The 'N' exit (value did not match) flows to whatever follows.
                return graph, in_nodes + [(node_info, 'N')]
            # `default:` -- statements start after 'default' and ':'.
            in_nodes = [(node_info, '')]
            for child in node.children[2:]:
                sub_graph, out_nodes = self.create_cfg(child, in_nodes)
                graph.extend(sub_graph)
                in_nodes = out_nodes
            return graph, in_nodes
        else:
            # translation_unit: one CFG per top-level function definition.
            cfgs = []
            for child in node.children:
                if child.type == 'function_definition':
                    func_cfg, _ = self.create_cfg(child, in_nodes)
                    cfgs.append(func_cfg)
            return cfgs, in_nodes

    def construct_cfg(self, code):
        """Parse ``code`` and store one Graph per function in ``self.cfgs``.

        ``self.cfgs`` is reset first, so repeated calls (e.g. ``see_cfg``
        after ``construct_cfg``) do not accumulate duplicate graphs.
        """
        self.cfgs = []
        tree = self.parser.parse(bytes(code, 'utf-8'))
        func_cfgs, _ = self.create_cfg(tree.root_node)
        for func_cfg in func_cfgs:
            cfg = Graph()
            for entry in func_cfg:
                cfg.add_edge(entry)
            cfg.get_def_use_info()
            self.cfgs.append(cfg)

    def see_cfg(self, code, filename='CFG', pdf=True, view=False):
        """Build the CFGs of ``code`` and draw them with graphviz.

        Branch nodes are diamonds, function nodes use the default shape, and
        all other statements are rectangles. Each label shows the node type,
        the escaped source text and the line number.

        Args:
            code: source text.
            filename: graph comment and output name for ``Digraph.render``.
            pdf: render to a file when True.
            view: open the rendered file.

        Returns:
            ``self.cfgs``.
        """
        self.construct_cfg(code)
        dot = Digraph(comment=filename, strict=True)
        for cfg in self.cfgs:
            for node in cfg.nodes:
                label = f"<({node.type}, {html.escape(node.text)}){node.line}>"
                if node.is_branch:
                    dot.node(str(node.id), shape='diamond', label=label, fontname='fangsong')
                elif node.type == 'function_definition':
                    dot.node(str(node.id), label=label, fontname='fangsong')
                else:
                    dot.node(str(node.id), shape='rectangle', label=label, fontname='fangsong')
            for node, edges in cfg.edges.items():
                for edge in edges:
                    dot.edge(str(node), str(edge.id), label=edge.label)
        if pdf:
            dot.render(filename, view=view, cleanup=True)
        return self.cfgs
195 |
if __name__ == '__main__':
    # Demo: build and display the CFGs of the functions in ./test.c.
    with open('test.c', 'r', encoding='utf-8') as source_file:
        code = source_file.read()
    cfg = CFG('c')
    cfg.see_cfg(code, view=True)
201 |
--------------------------------------------------------------------------------
/File.py:
--------------------------------------------------------------------------------
1 | from AST import *
2 | import os
3 | import re
4 |
# Upper bound on line numbers; used as the end line of scopes that last
# until the end of the file (global scope).
MAX_LENGTH = 1000000
# tree-sitter node types treated as literal constants.
constant_type = [
    'number_literal',
    'string_literal',
    'character_literal',
    'preproc_arg',
    'true',
    'false',
    'null',
]
7 |
def Array(node, type):
    """Resolve the variable name and pointer-style type of an array declarator.

    Each nested declarator of the same kind as ``node`` adds one '*'. For
    example, ``int a[m][n]`` (node ``a[m][n]``, ``type='int'``) gives
    ``('a', 'int**')``.

    Args:
        node: tree-sitter declarator node, e.g. an ``array_declarator``.
        type: base type text, e.g. ``'int'``.

    Returns:
        (name, type) tuple.
    """
    kind = node.type
    depth = 0
    current = node
    while current and current.type == kind:
        current = current.children[0]  # descend to the inner declarator
        depth += 1
    var_name = text(current)
    return var_name, f'{type}{"*" * depth}'
22 |
def Pointer(node, type):
    """Resolve the name and type of a pointer declarator such as ``*a`` or ``*argv[]``.

    Every pointer level and every trailing array dimension adds one '*'. For
    example, ``char *argv[]`` with ``type='char'`` gives ``('argv', 'char**')``.
    For a function declarator such as ``int *f(void)``, the function's name
    is returned.

    Args:
        node: tree-sitter declarator node, or None.
        type: base type text, e.g. ``'int'``.

    Returns:
        (name, type). ``name`` is None in three cases:

        * ``node`` is None;
        * ``node`` is not a pointer/array declarator (``type`` is then
          returned unchanged);
        * the declarator has no inner declarator (``type`` then carries the
          '*'s counted so far).
    """
    if node is None or node.type not in ['array_declarator', 'pointer_declarator']:
        return None, type
    dim = 0
    kind = node.type
    while node is not None and node.type == kind:
        dim += 1
        # Use the 'declarator' field instead of children[1]: a qualifier such
        # as `const` in `char *const p` would otherwise be taken as the name.
        node = node.child_by_field_name('declarator')
    while node is not None and node.type == 'array_declarator':  # e.g. char *argv[]
        dim += 1
        node = node.child_by_field_name('declarator')
    full_type = f'{type}{"*" * dim}'
    if node is None:
        return None, full_type
    if node.type == 'function_declarator':
        name = text(node.child_by_field_name('declarator'))
    else:
        name = text(node)
    return name, full_type
46 |
class Identifier:
    """Declaration record of one variable: its type, name and line scope."""

    def __init__(self, type, name, domain, structure_, class_):
        self.type = type  # declared type, e.g. 'int*'
        self.name = name  # variable name
        self.domain = domain  # scope as [start line, end line]
        self.structure_ = structure_  # name of the associated struct, if any
        self.class_ = class_  # name of the associated class, if any

    def __str__(self):
        lines = [f'name: {self.name}', f'type: {self.type}', f'domain: {self.domain}']
        if self.structure_:
            lines.append(f'structure: {self.structure_}')
        if self.class_:
            lines.append(f'class: {self.class_}')
        return ''.join(line + '\n' for line in lines)
62 |
class Declaration:
    """Extract ``{name: type}`` from a C declaration node.

    Example: ``int a[m], b = 0;`` yields ``{'a': 'int*', 'b': 'int'}``.
    """

    def __init__(self, node):
        """
        Args:
            node: tree-sitter node of kind declaration, field_declaration,
                parameter_declaration or type_definition. Any other kind, a
                node without a type, or an anonymous struct type yields no
                identifiers.
        """
        self.identifiers = {}
        if node.type not in ['declaration', 'field_declaration', 'parameter_declaration', 'type_definition']:
            return
        type_node = node.child_by_field_name('type')
        if type_node is None:
            return
        if type_node.type == 'struct_specifier':  # struct type: use the struct's name
            name_node = type_node.child_by_field_name('name')
            if not name_node:
                return
            self.type = text(name_node)
        else:  # basic or named type
            self.type = text(type_node)
        if node.type == 'parameter_declaration':  # function parameter
            candidates = [node.child_by_field_name('declarator')]
        else:
            # Skip the first child and the trailing ';'. The type node can still
            # appear in this slice when something precedes it (`static foo_t x;`,
            # `typedef foo_t bar_t;`); it must not be taken as a variable name.
            candidates = [
                child for child in node.children[1:-1]
                if not (child.start_byte == type_node.start_byte and child.end_byte == type_node.end_byte)
            ]
        for child in candidates:
            name, var_type = self.get_name_and_type(child)
            if name and var_type:
                self.identifiers[name] = var_type

    def get_name_and_type(self, node):
        """Return ``(name, type)`` for the declarator ``node``, or ``(None, None)``.

        Example: for ``a[m]`` with base type int, returns ``('a', 'int*')``.
        """
        if node is None:
            return None, None
        if node.type == 'array_declarator':
            return Array(node, self.type)
        elif node.type == 'pointer_declarator':
            return Pointer(node, self.type)
        elif node.type in ['init_declarator', 'parameter_declaration']:  # look through to the declarator
            return self.get_name_and_type(node.child_by_field_name('declarator'))
        elif node.type in ['identifier', 'field_identifier', 'type_identifier']:  # plain name: base type
            return text(node), self.type
        else:
            return None, None

    def __call__(self):
        """Return ``{identifier: type}``."""
        return self.identifiers
113 |
class Structure:
    """Field types of a struct and the variables declared together with it.

    Example: ``struct A {int a; int b;} x;`` gives
    ``field_vars == {'a': 'int', 'b': 'int'}`` and ``def_vars == {'x': 'A'}``.
    """

    def __init__(self, node):
        self.field_vars = {}  # field name -> field type
        self.def_vars = {}  # variable declared with the struct -> its type
        if node.type != 'struct_specifier':
            print('Error: not a struct_specifier')
            return
        name_node = node.child_by_field_name('name')
        if name_node:
            self.name = text(name_node)
        else:
            # Anonymous struct: use the parent's declarator (e.g. the typedef name).
            self.name = text(node.parent.child_by_field_name('declarator'))
        self.def_vars = Declaration(node.parent)()
        body = node.child_by_field_name('body')
        if not body:
            return
        for field in body.children:
            if text(field) in ['{', '}']:
                continue
            self.field_vars.update(Declaration(field)())
138 |
class IdType:
    """Scoped symbol table: variable name -> every declaration of that name."""

    def __init__(self, file_path):
        self.file_path = file_path
        self.vars = {}  # {name: [Identifier, ...]}, one entry per declaration
        self.macro = {}  # typedef aliases, e.g. {'ll': 'long long'}

    def add_def_var(self, identifier, domain, structure_, class_):
        """Record every ``{name: type}`` pair of ``identifier`` with scope ``domain``.

        Args:
            identifier: mapping such as ``{'a': 'int*', 'b': 'int'}``.
            domain: ``[start line, end line]`` of the scope.
            structure_, class_: associated struct/class names (or None).
        """
        for name, var_type in identifier.items():
            self.vars.setdefault(name, []).append(Identifier(var_type, name, domain, structure_, class_))

    def add_macro(self, name, type):
        self.macro[name] = type

    def query_type(self, id, line):
        """Return the type of variable ``id`` at ``line``, or 'unknown'.

        Among the declarations whose scope contains ``line``, the one with the
        narrowest scope (the most local) wins. A typedef alias is resolved
        while keeping the pointer depth.
        """
        in_scope = [info for info in self.vars.get(id, []) if info.domain[0] <= line <= info.domain[1]]
        if not in_scope:
            return 'unknown'
        innermost = min(in_scope, key=lambda info: info.domain[1] - info.domain[0])
        found = innermost.type
        base = found.replace('*', '')
        if base not in self.macro:
            return found
        return self.macro[base] + '*' * found.count('*')

    def __str__(self):
        chunks = []
        for var, infos in self.vars.items():
            chunks.append(f"{var:=^40}")
            chunks.extend(info.__str__() for info in infos)
        return '\n'.join(chunks)
187 |
class Func:
    """Signature of one function definition or prototype."""

    def __init__(self, node, file_path):
        """Extract the return type, name and parameters from ``node``.

        Example: ``int main(int argc, char* argv[])`` gives
        ``type == 'int'``, ``name == 'main'`` and
        ``signature == {'return': 'int', 'name': 'main',
        'parameters': [('argc', 'int'), ('argv', 'char**')]}``.

        When no function declarator is found, ``name`` and ``type`` are None
        and ``signature`` is ``{'parameters': []}``.
        """
        self.func_node = node
        self.file_path = file_path
        # NOTE(review): the raw 'type' text is flagged by the original author as imprecise.
        base_type = text(node.child_by_field_name('type'))
        _, self.type = Pointer(node.child_by_field_name('declarator'), base_type)
        self.line = node.start_point[0] + 1
        # Derived from the byte span only (not the file path).
        self.id = hash((node.start_byte, node.end_byte)) % 1000000007
        declarator = node
        while declarator.type not in ['function_declarator', 'parenthesized_declarator']:
            declarator = declarator.child_by_field_name('declarator')
            if not declarator:  # no function declarator below this node
                self.name = None
                self.type = None
                self.signature = {'parameters': []}
                return
        if declarator.type == 'parenthesized_declarator':
            # e.g. `main() {...}` with the return type omitted: the text parsed
            # as the type is used as the function name.
            self.name = self.type
            self.type = 'void'
            param_list = []
        else:
            self.name = text(declarator.child_by_field_name('declarator'))
            param_list = []
            for param in declarator.child_by_field_name('parameters').children:
                if text(param) in [',', '(', ')']:
                    continue
                param_list.extend(Declaration(param)().items())
        self.signature = {'return': self.type, 'name': self.name, 'parameters': param_list}

    def __eq__(self, other):
        return self.signature == other.signature

    def __str__(self):
        param_str = ''.join(f"({param[0]}: {param[1]}) " for param in self.signature['parameters'])
        return f"function name: {self.name}\nreturn type: {self.type}\nparameters: {param_str}\nfile path: {self.file_path}\n\n"
231 |
class Function:
    """Registry of functions grouped by name (several overloads per name are allowed)."""

    def __init__(self):
        self.funcs = {}  # name -> [Func, ...]
        self.id_to_func = {}  # Func.id -> Func, for every stored Func

    def add_func(self, node, file_path):
        """Register the function declared or defined by ``node``.

        Returns:
            The new Func, or None when no function name could be extracted.
            A Func whose signature equals an already stored one is not stored
            again, but is still returned.
        """
        func = Func(node, file_path)
        if not func.name:
            return None
        overloads = self.funcs.setdefault(func.name, [])
        if func not in overloads:
            overloads.append(func)
            # Index every stored overload, not only the first one of each name.
            self.id_to_func[func.id] = func
        return func

    def match_func(self, func_name, param_node, expression):
        """Pick the Func named ``func_name`` that matches a call's argument list.

        The candidates are tried in this order:

        1. the only candidate;
        2. the unique candidate with the same number of parameters;
        3. the first candidate whose parameter types equal the argument types
           computed by ``expression.traverse``.

        Returns:
            The matching Func, or None (also when ``func_name`` is unknown).
        """
        candidates = self.funcs.get(func_name)
        if not candidates:
            return None
        if len(candidates) == 1:  # no overloads
            return candidates[0]
        if param_node.type != 'argument_list':
            print("not a param list")
            return None
        arg_types = [expression.traverse(arg) for arg in param_node.children[1:-1] if arg.type != ',']
        same_arity = [func for func in candidates if len(func.signature['parameters']) == len(arg_types)]
        if len(same_arity) == 1:
            return same_arity[0]
        for func in candidates:  # disambiguate by parameter types
            if [param[1] for param in func.signature['parameters']] == arg_types:
                return func
        return None

    def __getitem__(self, func_name):
        return self.funcs[func_name]

    def __call__(self):
        return self.funcs

    def __str__(self):
        return ''.join(f.__str__() for funcs in self.funcs.values() for f in funcs)
289 |
class Constant:
    """Infer the C type of a literal node (number, char, string, bool, null, #define value)."""

    def __init__(self, node):
        """
        Args:
            node: tree-sitter node, e.g. ``1``, ``1.0``, ``'a'``, ``"abc"``,
                ``TRUE``, ``NULL``, or None.

        Sets ``self.type`` to one of 'int', 'float', 'char', 'char*', 'bool',
        'null' or 'unknown'. Sets ``self.value`` to the literal text (None
        when the node is not a constant).
        """
        self.value = None
        if not node or node.type not in constant_type:
            self.type = 'unknown'
            return
        string = text(node)
        if node.type == 'preproc_arg':  # the value in `#define NAME value`
            # Drop // and /* */ comments, then surrounding whitespace, which
            # would otherwise defeat the anchored patterns below.
            string = re.sub(r'//.*', '', string)
            string = re.sub(r'/\*.*\*/', '', string)
            string = string.strip()
        if self.is_float(string):
            self.type = 'float'
        elif self.is_int(string):
            self.type = 'int'
        elif self.is_char(string):
            self.type = 'char'
        elif self.is_string(string):
            self.type = 'char*'
        elif self.is_bool(string):
            self.type = 'bool'
        elif self.is_null(string):
            self.type = 'null'
        else:
            self.type = 'unknown'
        self.value = string

    def is_float(self, string):
        """Decimal floating literal: 1.5, .5, 1., 1e-3, 2.0f, 1.5L."""
        return bool(re.match(
            r'^[-+]?(?:(?:[0-9]+\.[0-9]*|\.[0-9]+)(?:[eE][-+]?[0-9]+)?|[0-9]+[eE][-+]?[0-9]+)[fFlL]?$',
            string))

    def is_int(self, string):
        """Hexadecimal, octal or decimal integer literal without suffix."""
        hex_ = bool(re.match(r'^[-+]?0[xX][0-9a-fA-F]+$', string))
        oct_ = bool(re.match(r'^[-+]?0[0-7]+$', string))
        dec_ = bool(re.match(r'^[-+]?[0-9]+$', string))
        return hex_ or oct_ or dec_

    def is_char(self, string):
        """Character literal, including escapes such as '\\n', '\\0', '\\x41'."""
        return bool(re.match(r"^(?:L|u8|u|U)?'(?:[^'\\]|\\.|\\[0-7]{1,3}|\\x[0-9a-fA-F]+)'$", string))

    def is_string(self, string):
        """String literal, allowing escaped characters such as \\"."""
        return bool(re.match(r'^(?:L|u8|u|U)?"(?:[^"\\]|\\.)*"$', string))

    def is_bool(self, string):
        return string in ['TRUE', 'FALSE', 'true', 'false']

    def is_null(self, string):
        return string in ['NULL', 'nullptr']
342 |
class Expression:
    """Best-effort static type inference for C expressions given as tree-sitter nodes."""

    def __init__(self, idType, Structure, Function):
        self.idType = idType  # IdType: scoped variable types and typedef aliases
        self.structure_ = Structure  # {struct name: Structure}
        self.function = Function  # Function registry

    def type(self, node):
        """Return the type of ``node`` with a typedef alias resolved, or 'unknown'."""
        inferred = self.traverse(node)
        if inferred == 'unknown':
            return 'unknown'
        dim = inferred.count('*')
        base = inferred.replace('*', '').replace(' ', '')
        if base in self.idType.macro:
            inferred = self.idType.macro[base] + '*' * dim
        return inferred

    @staticmethod
    def _strip_pointer(type_name):
        """Remove one pointer level ('int*' -> 'int'); 'unknown' if ``type_name`` is not a pointer."""
        return type_name[:-1] if type_name.endswith('*') else 'unknown'

    def traverse(self, node):
        """Infer the type of expression ``node`` top-down.

        Returns:
            The type as a string (e.g. 'int', 'char*', a struct name), or
            'unknown' when it cannot be determined (including ``node is None``).
        """
        if node is None:  # e.g. a missing optional child
            return 'unknown'
        if node.type in constant_type:
            return Constant(node).type
        elif node.type == 'identifier':  # scoped lookup by name and line
            return self.idType.query_type(text(node), node.start_point[0] + 1)
        elif node.type == 'sizeof_expression':
            return 'int'
        elif node.type in ['unary_expression', 'update_expression']:
            if node.type == 'unary_expression' and node.children[0].type == '!':
                return 'bool'  # logical negation yields a truth value
            return self.traverse(node.child_by_field_name('argument'))
        elif node.type == 'conditional_expression':  # approximated by the 'true' branch
            return self.traverse(node.child_by_field_name('consequence'))
        elif node.type == 'subscript_expression':  # a[i]: 'int*' -> 'int'
            argument_type = self.traverse(node.child_by_field_name('argument'))
            if argument_type == 'unknown':
                return 'unknown'
            return self._strip_pointer(argument_type)
        elif node.type == 'field_expression':  # a.b / a->b: look up the field in the struct
            argument_type = self.traverse(node.child_by_field_name('argument'))
            if argument_type == 'unknown':
                return 'unknown'
            field = text(node.child_by_field_name('field'))
            op = text(node.children[1])
            if op == '->':  # a is 'A*': dereference before the struct lookup
                argument_type = self._strip_pointer(argument_type)
                if argument_type == 'unknown':
                    return 'unknown'
            if argument_type not in self.structure_:
                return 'unknown'
            structure = self.structure_[argument_type]
            if field not in structure.field_vars:
                return 'unknown'
            return structure.field_vars[field]
        elif node.type == 'pointer_expression':  # *a or &a
            op = text(node.children[0])
            argument_type = self.traverse(node.child_by_field_name('argument'))
            if argument_type == 'unknown':
                return 'unknown'
            if op == '*':
                return self._strip_pointer(argument_type)
            return argument_type + '*'
        elif node.type == 'call_expression':  # the return type of a known function
            callee = text(node.child_by_field_name('function'))
            if callee in self.function():
                parameters = node.child_by_field_name('arguments')
                match_func = self.function.match_func(callee, parameters, self)
                if match_func:
                    return match_func.type
            return 'unknown'
        elif node.type == 'parenthesized_expression':
            return self.traverse(node.children[1])
        elif node.type == 'binary_expression':
            op = text(node.children[1])
            if op in ['==', '!=', '>', '>=', '<', '<=', '&&', '||']:
                return 'bool'
            # Arithmetic: the result type after the usual conversions.
            left_type = self.traverse(node.child_by_field_name('left'))
            right_type = self.traverse(node.child_by_field_name('right'))
            return self.typecasting(left_type, right_type)
        elif node.type == 'assignment_expression':  # the type of the assigned target
            return self.traverse(node.child_by_field_name('left'))
        elif node.type == 'cast_expression':
            return text(node.child_by_field_name('type'))
        else:
            return 'unknown'

    def typecasting(self, type1, type2):
        """Return the result type of a binary arithmetic expression on ``type1`` and ``type2``.

        This approximates C's usual arithmetic conversions:

        * bool, char and short promote to int;
        * unsigned and wider integer types win;
        * float op float gives 'double';
        * a pointer operand makes the result that pointer type.

        Returns 'unknown' when no rule applies.
        """
        def compare(type_list1, type_list2):
            # True when one operand is in type_list1 and the other in type_list2.
            if type1 in type_list1 and type2 in type_list2:
                return True
            return type1 in type_list2 and type2 in type_list1

        def unsigned(types, copy=True):
            """Return ``types`` (when ``copy``) plus their unsigned spellings."""
            new_types = types.copy() if copy else []
            for base in types:
                if base == 'int':
                    new_types.append('unsigned')
                new_types.append(f'unsigned {base}')
                new_types.append(f'{base} unsigned')
            return new_types

        if type1 == 'bool':  # bool behaves as int
            type1 = 'int'
        if type2 == 'bool':
            type2 = 'int'
        if compare(unsigned(['char', 'short']), unsigned(['char', 'short'])):  # promoted to int
            return 'int'
        elif compare(['float'], ['float']):  # promoted to double
            return 'double'
        elif type1 == type2:
            return type1
        elif compare(unsigned(['char', 'short']), ['int']):
            return 'int'
        elif compare(unsigned(['char', 'short', 'int']), unsigned(['int'], copy=False)):  # unsigned wins
            return 'unsigned int'
        elif compare(unsigned(['char', 'short', 'int']), ['long']):
            return 'long'
        elif compare(unsigned(['char', 'short', 'int', 'long']), unsigned(['long'], copy=False)):
            return 'unsigned long'
        elif compare(unsigned(['char', 'short', 'int', 'long']), ['long long']):
            return 'long long'
        elif compare(unsigned(['char', 'short', 'int', 'long', 'long long']), unsigned(['long long'], copy=False)):
            return 'unsigned long long'
        elif compare(unsigned(['char', 'short', 'int', 'long', 'long long']), ['float']):
            return 'float'
        elif compare(unsigned(['char', 'short', 'int', 'long', 'long long']) + ['float'], ['double']):
            return 'double'
        elif '*' in type1:  # pointer arithmetic keeps the pointer type
            return type1
        elif '*' in type2:
            return type2
        else:
            return 'unknown'
480 |
481 | class File(AST):
def __init__(self, language, file_path):
    """Parse ``file_path`` and prepare empty per-file symbol tables.

    Args:
        language: language name passed to the AST base class (e.g. 'c').
        file_path: path of the source file, resolved against the current
            working directory.
    """
    super().__init__(language)
    self.structure_ = {}  # {struct name: Structure}
    self.function = Function()  # {function name: [Func]}, overloads possible
    self.file_path = os.path.join(os.path.abspath('.'), file_path)
    self.idType = IdType(file_path)  # scoped variable types
    self.CG = {}  # call graph: {callee id: {caller ids}}
    self.unknown_call = {}  # calls to functions not found here: {callee name: {caller ids}}
    self.unknown_id = {}  # {unknown callee name: negative synthetic id}
    # errors='ignore' tolerates bytes that are not valid UTF-8.
    with open(self.file_path, 'r', encoding='utf-8', errors='ignore') as source_file:
        code = source_file.read()
    tree = self.parser.parse(bytes(code, 'utf8'))
    self.root_node = tree.root_node
494 |
def construct_file(self, node=None):
    """Collect file-level information.

    This covers structs, global variables, typedefs, #define constants and
    functions (with their parameters and local variables).

    Args:
        node: node whose children are scanned; defaults to the file's root.
            Preprocessor conditionals, linkage specifications, declaration
            lists and ERROR nodes are scanned recursively.
    """
    root_node = self.root_node if node is None else node
    for child in root_node.children:
        if child.type in ['preproc_if', 'preproc_ifdef', 'preproc_elif', 'preproc_else',
                          'linkage_specification', 'declaration_list', 'ERROR']:
            self.construct_file(child)
        elif child.type in ['declaration', 'struct_specifier']:
            if child.type == 'struct_specifier':  # struct A {int a; int b;};
                type_node = child
            else:  # e.g. struct A {...} a;  a prototype;  or a global variable
                self.function.add_func(child, self.file_path)
                type_node = child.child_by_field_name('type')
            if type_node.type == 'struct_specifier' and type_node.child_by_field_name('body'):
                # struct definition (not a function pointer or a bare use of the type)
                structure = Structure(type_node)
                self.structure_[structure.name] = structure
                domain = [child.end_point[0] + 1, MAX_LENGTH]  # global scope
                self.idType.add_def_var(structure.def_vars, domain, structure.name, None)
            elif type_node.type in ['primitive_type', 'type_identifier', 'sized_type_specifier', 'struct_specifier']:
                # global variables, e.g. `int g;`, `unsigned long n;`, `struct A a;`
                declaration = Declaration(child)
                domain = [child.start_point[0] + 1, MAX_LENGTH]
                self.idType.add_def_var(declaration(), domain, None, None)
        elif child.type == 'type_definition':  # typedef
            type_node = child.child_by_field_name('type')
            if type_node.type == 'struct_specifier':
                structure = Structure(type_node)
                self.structure_[structure.name] = structure
                domain = [child.end_point[0] + 1, MAX_LENGTH]
                self.idType.add_def_var(structure.def_vars, domain, structure.name, None)
                for alias, target in structure.def_vars.items():
                    self.idType.add_macro(alias, target)
            else:
                for alias, target in Declaration(child)().items():
                    self.idType.add_macro(alias, target)
        elif child.type == 'preproc_def':  # #define NAME value: typed by the literal
            name = text(child.child_by_field_name('name'))
            const = Constant(child.child_by_field_name('value'))
            domain = [child.start_point[0] + 1, MAX_LENGTH]
            self.idType.add_def_var({name: const.type}, domain, None, None)
        elif child.type == 'function_definition':
            func_info = self.function.add_func(child, self.file_path)
            if not func_info:
                continue
            for name, param_type in func_info.signature['parameters']:
                # parameters are visible throughout the function
                domain = [child.start_point[0] + 1, child.end_point[0] + 1]
                self.idType.add_def_var({name: param_type}, domain, None, None)
            self.get_local_type(child.child_by_field_name('body'))
554 |
def construct_call_graph(self):
    """Build the call graph of the functions registered for this file.

    Calls resolved through ``self.function`` go into ``self.CG``
    ({callee id: {caller ids}}). Calls to other names go into
    ``self.unknown_call`` ({callee name: {caller ids}}), and each such name
    gets a negative id in ``self.unknown_id``.
    """
    self.expression = Expression(self.idType, self.structure_, self.function)
    registry = self.function()
    for func_name in registry:
        for func in self.function[func_name]:
            caller = func.id
            stack = [func.func_node]
            while stack:  # pre-order, left-to-right walk of the function's subtree
                current = stack.pop()
                if current.type == 'call_expression':
                    callee = text(current.child_by_field_name('function'))
                    if callee in registry:
                        arguments = current.child_by_field_name('arguments')
                        match_func = self.function.match_func(callee, arguments, self.expression)
                        if match_func:
                            self.CG.setdefault(match_func.id, set()).add(caller)
                    else:  # not defined in this file
                        self.unknown_call.setdefault(callee, set()).add(caller)
                        self.unknown_id.setdefault(callee, -len(self.unknown_id) - 1)
                stack.extend(reversed(current.children))
578 |
def get_local_type(self, node):
    """Record the types of variables declared inside a function body.

    Walks a compound statement, else clause or case body. Each declaration
    is registered with a scope from its own line to the end of the
    enclosing block (for a case body, the end of that case). Variables
    declared in a for-initializer are scoped to the loop.

    Nested blocks, loop bodies, if/else branches and switch cases are
    visited recursively. Any other node, or None, is ignored.
    """
    if node is None or node.type not in ['compound_statement', 'else_clause', 'case_statement']:
        return
    for statement in node.children:
        if statement.type in ['compound_statement', 'case_statement']:
            self.get_local_type(statement)
        elif statement.type == 'declaration':
            domain = [statement.start_point[0] + 1, node.end_point[0] + 1]
            self.idType.add_def_var(Declaration(statement)(), domain, None, None)
        elif statement.type in ['for_statement', 'while_statement', 'do_statement']:
            if statement.type == 'for_statement':  # for (int i = 0; ...) scopes i to the loop
                initializer = statement.child_by_field_name('initializer')
                if initializer:
                    domain = [statement.start_point[0] + 1, statement.end_point[0] + 1]
                    self.idType.add_def_var(Declaration(initializer)(), domain, None, None)
            self.get_local_type(statement.child_by_field_name('body'))
        elif statement.type == 'if_statement':
            self.get_local_type(statement.child_by_field_name('consequence'))
            self.get_local_type(statement.child_by_field_name('alternative'))
        elif statement.type == 'switch_statement':  # its body holds the case statements
            self.get_local_type(statement.child_by_field_name('body'))
603 |
def query_type(self, root_node, id_types=None):
    """Debug helper: print the inferred type of every expression below ``root_node``.

    Each expression is printed as ``(text, type)`` in pre-order. Expression
    statements themselves are skipped, but the expressions inside them are
    printed.

    Returns:
        ``id_types`` unchanged (a fresh empty dict when omitted); nothing is
        collected into it, and it is kept for backward compatibility.
    """
    if id_types is None:  # avoid a shared mutable default
        id_types = {}
    self.expression = Expression(self.idType, self.structure_, self.function)

    def visit(node):
        for child in node.children:
            if 'expression' in child.type and child.type != 'expression_statement':
                print((text(child), self.expression.type(child)))
            visit(child)

    visit(root_node)
    return id_types
613 |
def see_cg(self, unknown=True, pdf=True, view=True):
    '''Visualise this file's function call graph with graphviz.

    The call graph is (re)built first; nothing is drawn when ``self.CG`` is
    empty. Only functions that appear in some call relation get a node.

    :param unknown: also draw functions that are called but whose definition
        was not found (``self.unknown_id`` / ``self.unknown_call``).
    :param pdf: render the graph under a ``pdf`` directory located at the
        common ancestor of this file and the current working directory,
        mirroring the file's relative path.
    :param view: open the rendered output in the system viewer.
    '''
    self.construct_call_graph()
    if not self.CG:
        return
    # dot = Digraph(comment=self.file_path, graph_attr={'rankdir': 'LR'})
    dot = Digraph(comment=self.file_path)

    # Skip functions that take part in no call at all.
    nodes = set(self.CG)
    for edges in self.CG.values():
        nodes.update(edges)
    for edges in self.unknown_call.values():
        nodes.update(edges)

    for funcs in self.function().values():
        for f in funcs:
            if f.id not in nodes:
                continue
            file_path = f.file_path.replace('\\', '/')
            label = f'name: {f.name}\\ntype: {f.type}\\nparameters: {list(f.signature["parameters"])}\\nfile path: {file_path}'
            dot.node(str(f.id), shape='rectangle', label=label, fontname='fangsong')
    for caller, callees in self.CG.items():
        for callee in callees:
            dot.edge(str(callee), str(caller))
    if unknown:
        for func, func_id in self.unknown_id.items():
            dot.node(str(func_id), shape='rectangle', label=func, fontname='fangsong')
        # unknown_call maps an undefined callee name to the ids of its callers.
        for callee_name, caller_ids in self.unknown_call.items():
            for caller_id in caller_ids:
                dot.edge(str(caller_id), str(self.unknown_id[callee_name]))
    if pdf:
        # commonpath compares whole path components; commonprefix compares
        # characters, so '/a/proj' vs '/a/project/x.c' would give '/a/proj' and
        # the output would be written outside the intended pdf directory.
        file_path = os.path.abspath(self.file_path)
        root_path = os.path.commonpath([file_path, os.path.abspath('.')])
        rel_path = os.path.relpath(file_path, root_path).replace('\\', '/')
        output_path = os.path.join(root_path, "pdf", rel_path)
        dot.render(output_path, view=view, cleanup=True)
651 |
def merge(self, other):
    '''Fold the definitions of ``other`` into this file.

    Structure and variable-type tables are updated from ``other`` (its
    entries win on key clashes). Functions are merged per name, appending
    each of ``other``'s functions that is not already present, and the
    id-to-function index is updated as well.
    '''
    self.structure_.update(other.structure_)
    self.idType.vars.update(other.idType.vars)
    for name, incoming in other.function().items():
        existing = self.function().setdefault(name, [])
        for candidate in incoming:
            if candidate in existing:  # keep one copy of each function
                continue
            existing.append(candidate)
    self.function.id_to_func.update(other.function.id_to_func)
661 |
class Dir(AST):
    '''Analyse every C/C++ file under a directory.

    All source/header files are parsed and their ``#include`` relations to
    other project files are resolved. Each file is then merged with the
    definitions of the files it includes and its call graph is built, with
    included files processed first (topological order of the include graph).
    Include cycles are broken by temporarily cutting one edge per detected
    cycle; the cut includes are merged after every other file has been
    loaded. Finally each file's call graph is rendered via ``see_cg``.

    Per-file tables, keyed by the '/'-separated path relative to ``self.path``:
        Include:  project files the file includes.
        Included: project files that include the file.
        Indegree: number of the file's includes that are not loaded yet.
        load:     whether the file has been merged and its call graph built.
        files:    the ``File`` object for the path.
    '''

    # File extensions treated as C/C++ sources or headers.
    _SOURCE_EXTENSIONS = ('c', 'cpp', 'h', 'hpp', 'cc')
    # Preprocessor conditional nodes whose bodies may contain #include lines.
    _PREPROC_CONDITIONALS = ('preproc_ifdef', 'preproc_if', 'preproc_else', 'preproc_elif')

    def __init__(self, path, language='cpp'):
        super().__init__(language)
        self.path = os.path.abspath(path)
        self.filepaths = []
        for root, _, files in os.walk(self.path):
            for file in files:
                if file.split('.')[-1] in self._SOURCE_EXTENSIONS:
                    filepath = os.path.abspath(os.path.join(root, file))
                    relpath = os.path.relpath(filepath, self.path).replace('\\', '/')
                    self.filepaths.append(relpath)

        self.Include = {f: [] for f in self.filepaths}
        self.files = {f: None for f in self.filepaths}
        self.Included = {f: [] for f in self.filepaths}
        # Topological loading: a file is processed once all files it includes
        # have been loaded, i.e. when its in-degree drops to 0.
        self.Indegree = {f: 0 for f in self.filepaths}
        self.load = {f: False for f in self.filepaths}
        for f in self.filepaths:
            abs_path = os.path.join(self.path, f)
            self.files[f] = File('c', abs_path)
            with open(abs_path, 'r', encoding='utf-8', errors='ignore') as source:
                code = source.read()
            tree = self.parser.parse(bytes(code, 'utf8'))
            self.find_include(tree.root_node, f)

        temp_include = self._break_cycles()
        for f in self.filepaths:
            self.files[f].construct_file()
        self._load_in_include_order(temp_include)

    def _break_cycles(self):
        '''Cut the first include edge of every detected cycle.

        All duplicate copies of a cut edge are removed so the cycle is really
        broken; an edge already cut for an earlier cycle is skipped instead of
        raising ``ValueError``.

        :return: ``{file: [included files whose include was cut]}``.
        '''
        temp_include = {}
        cycles = self.detect_cycles()
        if not cycles:
            return temp_include
        print("There exists cycles")
        for cycle in cycles:
            print(cycle[-1], end=' ')
            for f in cycle:
                print('-> ' + f, end=' ')
            src = cycle[0]
            dst = cycle[1] if len(cycle) > 1 else cycle[0]
            removed = self.Include[src].count(dst)
            if not removed:
                continue
            self.Include[src] = [g for g in self.Include[src] if g != dst]
            self.Included[dst] = [g for g in self.Included[dst] if g != src]
            self.Indegree[src] -= removed
            temp_include.setdefault(src, []).append(dst)
        print()
        return temp_include

    def _load_file(self, f):
        '''Merge ``f`` with its includes, build its call graph, and mark it loaded.'''
        for include_file in self.Include[f]:
            self.files[f].merge(self.files[include_file])
        self.Include[f] = []
        self.files[f].construct_file()
        self.files[f].construct_call_graph()
        for included_file in self.Included[f]:
            self.Indegree[included_file] -= 1
        self.load[f] = True

    def _load_in_include_order(self, temp_include):
        '''Load files in include order, then merge the cut includes and render.

        Repeatedly loads the first file (in ``self.files`` order) whose
        includes are all loaded. Afterwards the includes cut by
        ``_break_cycles`` are merged and every file's call graph is rendered;
        files that could not be loaded are reported.
        '''
        while True:
            ready = next((f for f in self.files
                          if self.Indegree[f] == 0 and not self.load[f]), None)
            if ready is None:
                break
            self._load_file(ready)
        for f, includes in temp_include.items():
            for include in includes:
                self.files[f].merge(self.files[include])
            self.files[f].construct_file()
            self.files[f].construct_call_graph()
        for f, indegree in self.Indegree.items():
            print(f)
            self.files[f].see_cg(view=False, unknown=True)
            if indegree != 0:
                print("\tError: not finish yet!")

    def detect_cycles(self):
        '''Detect cycles in ``self.Include``.

        For example, if A includes B, B includes C and C includes A, the
        result contains ``[A, B, C]``; each cycle starts with a file that
        includes the next one. The DFS shares one ``visited`` set across
        start nodes, so not every elementary cycle is necessarily reported.

        :return: list of cycles, each a list of file paths.
        '''
        visited = set()
        stack = []
        cycles = []

        def dfs(node):
            if node in stack:
                cycles.append(stack[stack.index(node):])
                return
            if node in visited:
                return
            visited.add(node)
            stack.append(node)
            for neighbor in self.Include.get(node, []):
                dfs(neighbor)
            stack.pop()

        for node in self.Include:
            dfs(node)
        return cycles

    def find_include(self, node, f):
        '''Record the project files that file ``f`` includes.

        Scans the children of ``node`` for ``#include "..."`` / ``#include <...>``
        directives, descending into preprocessor conditionals
        (``#ifdef``/``#if``/``#elif``/``#else``). Each directive is resolved to
        a project file and recorded in ``Include``, ``Included`` and
        ``Indegree``. Includes that do not resolve to a project file, and
        ``config.h``, are ignored.
        '''
        for child in node.children:
            if child.type in self._PREPROC_CONDITIONALS:
                self.find_include(child, f)
                continue
            if child.type != 'preproc_include':
                continue
            path = child.child_by_field_name('path')
            if path is None or path.type not in ('string_literal', 'system_lib_string'):
                continue
            include = text(path)[1:-1]  # strip the quotes / angle brackets
            if include == 'config.h':  # Visual Studio configuration header, ignored
                continue
            include_path = self._resolve_include(include, f)
            if include_path is None:
                continue
            self.Include[f].append(include_path)
            self.Included[include_path].append(f)
            self.Indegree[f] += 1

    def _resolve_include(self, include, f):
        '''Map an include string written in file ``f`` to a project path, or None.

        The include is first resolved relative to ``f``'s directory. If that
        is not a project file, or is ``f`` itself, the first other project
        file with the same base name is used.
        '''
        include_path = os.path.normpath(os.path.join(os.path.dirname(f), include)).replace('\\', '/')
        if include_path in self.Include and include_path != f:
            return include_path
        name = os.path.basename(include_path)
        for candidate in self.filepaths:
            if candidate != f and os.path.basename(candidate) == name:
                return candidate
        return None
781 |
if __name__ == '__main__':
    # To analyse a single file instead of a whole directory:
    #   file = File('c', 'test/5/littlefs-master/lfs_util.h')
    #   file.construct_file()
    #   file.see_cg(view=True)
    # Named `directory` so the builtin `dir` is not shadowed.
    directory = Dir('./test/1')
787 |
--------------------------------------------------------------------------------