├── requirements.txt ├── .gitignore ├── PDG.py ├── README.md ├── CG.py ├── test.py ├── DDG.py ├── AST.py ├── CDG.py ├── utils.py ├── CFG.py └── File.py /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | tree-sitter=0.20.2 3 | inflection 4 | graphviz -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pdf 2 | build 3 | __pycache__ 4 | test.c 5 | test 6 | .vscode 7 | *.gv 8 | pdf -------------------------------------------------------------------------------- /PDG.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from CFG import * 3 | from CDG import * 4 | from DDG import * 5 | from graphviz import Digraph 6 | 7 | class PDG(CFG): 8 | def construct_pdg(self, code): 9 | cdg = CDG(self.language) 10 | ddg = DDG(self.language) 11 | cdg.construct_cdg(code) 12 | ddg.construct_ddg(code) 13 | self.pdgs = [] 14 | for cdg, ddg in zip(cdg.cdgs, ddg.ddgs): 15 | pdg = cdg 16 | for node, edges in ddg.edges.items(): 17 | pdg.edges.setdefault(node, []) 18 | for edge in edges: 19 | pdg.edges[node].append(edge) 20 | self.pdgs.append(pdg) 21 | return self.pdgs 22 | 23 | def see_pdg(self, code, filename='PDG', pdf=True, view=False): 24 | self.construct_pdg(code) 25 | dot = Digraph(comment='PDG', strict=True) 26 | for pdg in self.pdgs: 27 | for v in pdg.nodes: 28 | if v < 0: 29 | continue 30 | node = pdg.id_to_nodes[v] 31 | label = f"<({node.type}, {html.escape(node.text)}){node.line}>" 32 | if node.is_branch: 33 | dot.node(str(node.id), shape='diamond', label=label, fontname='fangsong') 34 | elif node.type == 'function_definition': 35 | dot.node(str(node.id), label=label, fontname='fangsong') 36 | else: 37 | dot.node(str(node.id), shape='rectangle', label=label, fontname='fangsong') 38 | for u in pdg.edges[v]: 39 | dot.edge(str(u.id), str(v)) 40 | for v in pdg.edges: 41 | for u in pdg.edges[v]: 42 | if u.type == 'DDG': 43 | dot.edge(str(v), str(u.id), label=', '.join(u.token), style='dotted') 44 | else: 45 | dot.edge(str(u.id), str(v)) 46 | if pdf: 47 | dot.render(filename, view=view, cleanup=True) 48 | 49 | if __name__ == '__main__': 50 | code = r'{}'.format(open('test.c', 'r', encoding='utf-8').read()) 51 | pdg = PDG('c') 52 | pdg.see_pdg(code, view=True) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 更新后的代码地址放在了:https://github.com/rebibabo/INVEST 2 | 3 | 这是一个基于tree-sitter编译工具进行静态程序分析的demo, 并使用可视化工具graphviz,能够生成抽象语法树AST、控制流图CFG、数据依赖DFG、数据依赖图CDG、程序依赖图PDG、函数调用图CG等。 4 | tree-sitter网址:https://tree-sitter.github.io/tree-sitter/ 5 | 6 | ## 环境配置 7 | 确保已经安装了graphviz,在windows上,官网https://www.graphviz.org/ 下载graphviz之后,配置环境变量为安装路径下的bin文件夹,例如D:\graphviz\bin\,注意末尾的'\\'不能省略,如果是linux上,运行下面命令安装: 8 | ``` 9 | sudo apt-get install graphviz graphviz-doc 10 | ``` 11 | 接着运行 12 | ``` 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | ## 生成AST树 17 | AST.py能够生成AST树以及tokens,首先构造类,参数为代码语言,目前tree-sitter能够编译的语言都能够生成。 18 | ``` 19 | ast = AST('c') 20 | ``` 21 | 接着运行下面代码可以显示AST树 22 | ``` 23 | ast.see_tree(code, view=True) 24 | ``` 25 | ![AST](https://github.com/rebibabo/TSA/assets/80667434/6d1aae84-3c46-4978-844e-6006e8623718) 26 | 27 | 运行完成之后,会在当前目录下生成ast_tree.pdf,为可视化的ast树,可以通过设置参数view=False在生成pdf文件的同时不查看文件,pdf=False不生成可视化的pdf文件,设置参数filename="filename"来更改输出文件的名称。 28 | 获得代码的tokens可以运行下面的代码,返回值为token的列表。 29 | ``` 30 | ast.tokenize(code) 31 | #['int', 'main', '(', ')', '{', 'int', 'abc', '=', '1', ';', 'int', 'b', '=', '2', ';', 'int', 'c', '=', 'a', '+', 'b', ';', 'while', '(', 'i', '<', '10', ')', '{', 'i', '++', ';', '}', '}'] 32 | ``` 33 | 34 | ## 生成CFG 35 | CFG.py继承自AST类,能够生成控制流图,运行下面命令可以获得代码的CFG: 36 | ``` 37 | cfg = CFG('c') 38 | cfg.see_cfg(code, view=True) 39 | ``` 40 | 生成的CFG图样例: 41 | ![CFG](https://github.com/rebibabo/static_program_analysis_by_tree_sitter/assets/80667434/e9f6a213-a523-4a51-a6dd-849970e4d6fa) 42 | see_cfg的参数和see_tree的参数一样 43 | 44 | ## 生成CDG 45 | CDG.py继承自CFG类,能够生成控制依赖图,运行下面代码能够获得CDG图: 46 | ``` 47 | cdg = CDG('c') 48 | cdg.see_cdg(code, view=True) 49 | ``` 50 | 生成的CDG图样例: 51 | ![CDG](https://github.com/rebibabo/static_program_analysis_by_tree_sitter/assets/80667434/c8a3c611-f9e1-4953-afae-64a8684e92ea) 52 | 53 | ## 生成DDG 54 | DDG.py继承自CFG类,能够生成数据依赖图,运行下面代码能够获得DDG图: 55 | ``` 56 | ddg = DDG('c') 57 | ddg.see_ddg(code, view=True) 58 | ``` 59 | 生成的DDG图样例: 60 | ![DDG](https://github.com/rebibabo/static_program_analysis_by_tree_sitter/assets/80667434/5368b90a-a9e0-48e1-9f16-0b2add0e7f7a) 61 | 62 | ## 生成PDG 63 | PDG.py将CDG和DDG的节点和边结合在一起,运行下面代码获得PDG图: 64 | ``` 65 | pdg = PDG('c') 66 | pdg.see_pdg(code, view=True) 67 | ``` 68 | 生成的PDG图样例: 69 | ![PDG](https://github.com/rebibabo/static_program_analysis_by_tree_sitter/assets/80667434/5e9b495e-97f3-45bd-b2c8-cfe220ebaaf8) 70 | 71 | ## 生成CG 72 | File.py继承自AST.py,能够生成函数调用图,运行下面代码能够生成单个项目的CG图 73 | ``` 74 | file = File("path to project") 75 | file.see_cg(code, view=True) 76 | ``` 77 | 生成项目目录所有文件的CG图: 78 | ``` 79 | dir = Dir('path to project') 80 | ``` 81 | 生成CG图样例: 82 | ![捕获](https://github.com/rebibabo/static_program_analysis_by_tree_sitter/assets/80667434/b7dd8037-984e-4bea-920d-d3bdd1b4f8fe) 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /CG.py: -------------------------------------------------------------------------------- 1 | from AST import * 2 | 3 | def get_identifiers(node): 4 | call_nodes = [] 5 | if node.child_count == 0: 6 | return call_nodes 7 | for child in node.children: 8 | if child.type == 'identifier': 9 | call_nodes.append(child) 10 | else: 11 | call_nodes.extend(get_identifiers(child)) 12 | return call_nodes 13 | 14 | class Function: 15 | def __init__(self, node): 16 | self.type = text(node.child_by_field_name('type')) 17 | while node.type != 'function_declarator': 18 | node = node.child_by_field_name('declarator') 19 | self.name = text(node.child_by_field_name('declarator')) 20 | paramters = node.child_by_field_name('parameters') 21 | self.signature = {'return': self.type, 'name': self.name, 'parameters': {}} 22 | self.id = node.start_point[0] 23 | for param in paramters.children: 24 | if text(param) not in [',', '(', ')']: 25 | type = text(param.child_by_field_name('type')) 26 | name = text(param.child_by_field_name('declarator')) 27 | self.signature['parameters'][name] = type 28 | 29 | def __eq__(self, other): 30 | return self.signature == other.signature 31 | 32 | def __str__(self): 33 | return str(self.signature) 34 | 35 | class CG(AST): 36 | def __init__(self, language): 37 | super().__init__(language) 38 | self.node_set = {} # 存放每一个节点的信息 39 | self.cg = [] # 存放每一个函数的CFG图 40 | self.funcs = {} 41 | 42 | def create_cg(self, root_node): 43 | func_def_nodes = {} 44 | for child in root_node.children: # 先找到所有的函数定义节点 45 | if child.type == 'function_definition': 46 | func_info = Function(child) 47 | func_def_nodes[func_info.name] = child 48 | self.funcs[func_info.name] = func_info 49 | for node in func_def_nodes: # 再依次遍历每一个函数中调用的函数 50 | func_node = func_def_nodes[node] 51 | ids = get_identifiers(func_node) 52 | cg_call_nodes = set() 53 | for id in ids: 54 | id_name = text(id) 55 | if id_name in func_def_nodes: 56 | cg_call_nodes.add(self.funcs[id_name].id) 57 | self.cg.append((Function(func_node), cg_call_nodes)) 58 | return self.cg 59 | 60 | def see_cg(self, code_path, filename='CG', pdf=True, view=False): 61 | code = r'{}'.format(open(code_path, 'r', encoding='utf-8').read()) 62 | tree = self.parser.parse(bytes(code, 'utf8')) 63 | root_node = tree.root_node 64 | CG = self.create_cg(root_node) 65 | dot = Digraph(comment=filename) 66 | for node in CG: 67 | dot.node(str(node[0].id), shape='rectangle', label=node[0].name, fontname='fangsong') 68 | for call_node in node[1]: 69 | if str(node[0].id) != str(call_node): 70 | dot.edge(str(node[0].id), str(call_node)) 71 | if pdf: 72 | dot.render(filename, view=view, cleanup=True) 73 | 74 | 75 | if __name__ == '__main__': 76 | cg = CG('c') 77 | cg.see_cg('test.c', view=True) 78 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | def detect_cycles(graph): 2 | visited = set() 3 | stack = [] 4 | cycles = [] 5 | 6 | def dfs(node): 7 | if node in stack: 8 | cycle = stack[stack.index(node):] 9 | cycles.append(cycle) 10 | return 11 | if node in visited: 12 | return 13 | 14 | visited.add(node) 15 | stack.append(node) 16 | 17 | for neighbor in graph.get(node, []): 18 | dfs(neighbor) 19 | 20 | stack.pop() 21 | 22 | for node in graph: 23 | dfs(node) 24 | 25 | return cycles 26 | 27 | # 字典表示的图,键是文件名,值是包含的文件列表 28 | file_graph = { 29 | '3rd/lua/lapi.c': ['3rd/lua/lapi.h', '3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/ltm.h', '3rd/lua/lvm.h'], 30 | '3rd/lua/lapi.h': ['3rd/lua/lstate.h'], 31 | '3rd/lua/lcode.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstring.h', '3rd/lua/lvm.h'], 32 | '3rd/lua/ldebug.c': ['3rd/lua/lapi.h', '3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/ltm.h', '3rd/lua/lvm.h'], 33 | '3rd/lua/ldebug.h': ['3rd/lua/lstate.h'], 34 | '3rd/lua/ldo.c': ['3rd/lua/lapi.h', '3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/ltm.h', '3rd/lua/lvm.h'], 35 | '3rd/lua/ldo.h': ['3rd/lua/lstate.h'], 36 | '3rd/lua/ldump.c': ['3rd/lua/lstate.h'], 37 | '3rd/lua/lfunc.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h'], 38 | '3rd/lua/lgc.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/ltm.h'], 39 | '3rd/lua/lgc.h': ['3rd/lua/lstate.h'], 40 | '3rd/lua/llex.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h'], 41 | '3rd/lua/lmem.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h'], 42 | '3rd/lua/lobject.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/lvm.h'], 43 | '3rd/lua/lparser.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h'], 44 | '3rd/lua/lstate.c': ['3rd/lua/lapi.h', '3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/ltm.h'], 45 | '3rd/lua/lstate.h': ['3rd/lua/ltm.h'], 46 | '3rd/lua/lstring.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h'], 47 | '3rd/lua/lstring.h': ['3rd/lua/lgc.h', '3rd/lua/lstate.h'], 48 | '3rd/lua/ltable.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/lvm.h'], 49 | '3rd/lua/ltests.c': ['3rd/lua/lapi.h', '3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h'], 50 | '3rd/lua/ltm.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/ltm.h', '3rd/lua/lvm.h'], 51 | '3rd/lua/ltm.h': ['3rd/lua/lstate.h'], 52 | '3rd/lua/luac.c': ['3rd/lua/ldebug.h', '3rd/lua/lstate.h'], 53 | '3rd/lua/lundump.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lstring.h'], 54 | '3rd/lua/lvm.c': ['3rd/lua/ldebug.h', '3rd/lua/ldo.h', '3rd/lua/lgc.h', '3rd/lua/lstate.h', '3rd/lua/lstring.h', '3rd/lua/ltm.h', '3rd/lua/lvm.h'], 55 | '3rd/lua/lvm.h': ['3rd/lua/ldo.h', '3rd/lua/ltm.h'], 56 | 'A': ['B'], 57 | 'B': ['C'], 58 | 'C': ['A'] 59 | } 60 | 61 | 62 | cycle = detect_cycles(file_graph) 63 | if cycle: 64 | print("存在环路:", cycle) 65 | else: 66 | print("不存在环路") 67 | -------------------------------------------------------------------------------- /DDG.py: -------------------------------------------------------------------------------- 1 | from CFG import * 2 | from utils import * 3 | from graphviz import Graph 4 | 5 | class DDG(CFG): 6 | def construct_ddg(self, code): 7 | # 参考https://home.cs.colorado.edu/~kena/classes/5828/s99/lectures/lecture25.pdf 中19页的算法 8 | if self.check_syntax(code): 9 | print('Syntax Error') 10 | exit(1) 11 | cfgs = self.see_cfg(code) 12 | self.ddgs = [] 13 | for cfg in cfgs: 14 | edge = {} 15 | ddg = {} 16 | defs = cfg.defs 17 | uses = cfg.uses 18 | # print(f"defs: {defs}") 19 | # print(f"uses: {uses}") 20 | # There is a data dependence from X to Y with respect to a variable v iff 21 | # there is a non-null path p from X to Y with no intervening definition of v either: 22 | # X contains a definition of v and Y a use of v; 23 | # X contains a use of v and Y a definition of v; or 24 | # X contains a definition of v and Y a definition of v 25 | 26 | # def X to use Y 27 | for X in defs: 28 | if X not in uses: 29 | continue 30 | def_ = defs[X] 31 | use_ = uses[X] 32 | for d in def_: 33 | for u in use_: 34 | paths = cfg.findAllPath(d, u) 35 | for path in paths: 36 | is_arrival = True 37 | for n in path[1:-1]: 38 | node = cfg.id_to_nodes[n] 39 | if X in node.defs: 40 | is_arrival = False 41 | break 42 | if not is_arrival: 43 | break 44 | edge.setdefault((d, u), set()) 45 | edge[(d, u)].add(X) 46 | # use X to def Y 47 | for X in uses: 48 | if X not in defs: 49 | continue 50 | use_ = uses[X] 51 | def_ = defs[X] 52 | for u in use_: 53 | for d in def_: 54 | paths = cfg.findAllPath(u, d) 55 | for path in paths: 56 | is_arrival = True 57 | for n in path[1:-1]: 58 | node = cfg.id_to_nodes[n] 59 | if X in node.defs: 60 | is_arrival = False 61 | break 62 | if not is_arrival: 63 | break 64 | edge.setdefault((u, d), set()) 65 | edge[(u, d)].add(X) 66 | # def X to def Y 67 | for X in defs: 68 | def_ = defs[X] 69 | for d1 in def_: 70 | for d2 in def_: 71 | paths = cfg.findAllPath(d1, d2) 72 | for path in paths: 73 | is_arrival = True 74 | for n in path[1:-1]: 75 | node = cfg.id_to_nodes[n] 76 | if X in node.defs: 77 | is_arrival = False 78 | break 79 | if not is_arrival: 80 | break 81 | edge.setdefault((d1, d2), set()) 82 | edge[(d1, d2)].add(X) 83 | for (u, v), Xs in edge.items(): 84 | ddg.setdefault(u, []) 85 | ddg[u].append(Edge(v, type='DDG', token=Xs)) 86 | cfg.edges = ddg 87 | self.ddgs.append(cfg) 88 | return self.ddgs 89 | 90 | def see_ddg(self, code, filename='DDG', pdf=True, view=False): 91 | self.construct_ddg(code) 92 | dot = Digraph(comment=filename, strict=True) 93 | for ddg in self.ddgs: 94 | for node in ddg.nodes: 95 | label = f"<({node.type}, {html.escape(node.text)}){node.line}>" 96 | if node.is_branch: 97 | dot.node(str(node.id), shape='diamond', label=label, fontname='fangsong') 98 | elif node.type == 'function_definition': 99 | dot.node(str(node.id), label=label, fontname='fangsong') 100 | else: 101 | dot.node(str(node.id), shape='rectangle', label=label, fontname='fangsong') 102 | for v in ddg.edges: 103 | for u in ddg.edges[v]: 104 | dot.edge(str(v), str(u.id), label=', '.join(u.token), style='dotted') 105 | if pdf: 106 | dot.render(filename, view=view, cleanup=True) 107 | 108 | if __name__ == '__main__': 109 | code = r'{}'.format(open('test.c', 'r', encoding='utf-8').read()) 110 | ddg = DDG('c') 111 | ddg.see_ddg(code ,view=True) -------------------------------------------------------------------------------- /AST.py: -------------------------------------------------------------------------------- 1 | from tree_sitter import Parser, Language 2 | from graphviz import Digraph 3 | import os 4 | 5 | text = lambda node: node.text.decode('utf-8') 6 | 7 | class Node: 8 | '''用于存储AST树的节点信息,将tree-sitter.Node类型转换为Node类型,tree-sitter不能序列化''' 9 | def __init__(self, node, id): 10 | self.type = node.type 11 | self.start_byte = node.start_byte 12 | self.end_byte = node.end_byte 13 | self.start_point = node.start_point 14 | self.end_point = node.end_point 15 | self.text = text(node) 16 | self.id = id 17 | 18 | def __eq__(self, other): 19 | return self.id == other.id 20 | 21 | class TreeNode: 22 | '''用于存储AST树''' 23 | def __init__(self, root): 24 | '''输入: root,树的根节点,为tree-sitter.Node类型''' 25 | self.nodes = [Node(root, 0)] 26 | self.id_to_node = {0: root} 27 | self.edges = {} 28 | self.traverse_tree(root) 29 | 30 | def traverse_tree(self, node, pid=0): # python递归函数想要修改函数参数的值,只能通过列表等方式 31 | '''输入: node,当前节点,为tree-sitter.Node类型''' 32 | children = [] 33 | for child in node.children: 34 | id = len(self.nodes) # 唯一标记节点的id 35 | child_node = Node(child, id) 36 | self.nodes.append(child_node) 37 | self.id_to_node[id] = child_node 38 | children.append(id) 39 | self.traverse_tree(child, id) 40 | self.edges[pid] = children 41 | 42 | def print_tree(self): 43 | '''打印AST树''' 44 | def dfs(node, depth): 45 | if self.edges[node.id] == [] and node.type != node.text: 46 | print(' ' * depth + node.type + ': ' + node.text) 47 | else: 48 | print(' ' * depth + node.type) 49 | for child in self.edges[node.id]: 50 | dfs(self.nodes[child], depth + 1) 51 | dfs(self.nodes[0], 0) 52 | 53 | def get_node(self, id): 54 | return self.id_to_node[id] 55 | 56 | class AST: 57 | def __init__(self, language): 58 | self.language = language 59 | if not os.path.exists(f'./build/{language}-languages.so'): 60 | if not os.path.exists(f'./tree-sitter-{language}'): 61 | os.system(f'git clone https://github.com/tree-sitter/tree-sitter-{language}') 62 | Language.build_library( 63 | f'./build/{language}-languages.so', 64 | [ 65 | f'./tree-sitter-{language}', 66 | ] 67 | ) 68 | LANGUAGE = Language(f'./build/{language}-languages.so', language) 69 | parser = Parser() 70 | parser.set_language(LANGUAGE) 71 | self.parser = parser 72 | 73 | def see_tree(self, code, filename='ast_tree', pdf=True, view=False): 74 | ''' 75 | 生成AST树的可视化图 76 | code: 输入的代码 77 | filename: 生成的文件名 78 | pdf: 是否生成pdf文件 79 | view: 是否打开pdf文件 80 | ''' 81 | tree = self.parser.parse(bytes(code, 'utf8')) 82 | root_node = tree.root_node 83 | tree = TreeNode(root_node) 84 | dot = Digraph(comment='AST Tree', strict=True) 85 | for edge, children in tree.edges.items(): 86 | node = tree.get_node(edge) 87 | dot.node(str(edge), shape='rectangle', label=node.type, fontname='fangsong') 88 | dot.edges([(str(edge), str(child)) for child in children]) 89 | if children == []: 90 | dot.node(str(-edge), shape='ellipse', label=node.text, fontname='fangsong') 91 | dot.edges([(str(edge), str(-edge))]) 92 | if pdf: 93 | dot.render(filename, view=view, cleanup=True) 94 | 95 | def tokenize(self, code): 96 | '''输入代码code,返回token列表''' 97 | def tokenize_help(node, tokens): 98 | # 遍历整个AST树,返回符合func的节点列表results 99 | if not node.children: 100 | tokens.append(text(node)) 101 | return 102 | for n in node.children: 103 | tokenize_help(n, tokens) 104 | tree = self.parser.parse(bytes(code, 'utf8')) 105 | root_node = tree.root_node 106 | print(type(root_node)) 107 | tokens = [] 108 | tokenize_help(root_node, tokens) 109 | return tokens 110 | 111 | def check_syntax(self, code): 112 | '''检查代码是否有语法错误''' 113 | tree = self.parser.parse(bytes(code, 'utf8')) 114 | # 找出来Error的位置 115 | root_node = tree.root_node 116 | error_nodes = [] 117 | def find_error(node): 118 | if node.type == 'ERROR': 119 | error_nodes.append(node) 120 | for child in node.children: 121 | find_error(child) 122 | find_error(root_node) 123 | for i, node in enumerate(error_nodes): 124 | print(f"error {i:>3} : line {node.start_point[0]:>3} row {node.start_point[1]:>3} -to- line {node.end_point[0]:>3} row {node.end_point[1]:>3}") 125 | return tree.root_node.has_error 126 | 127 | if __name__ == '__main__': 128 | code = r'{}'.format(open('test.c', 'r', encoding='utf-8').read()) 129 | ast = AST('cpp') 130 | print(ast.tokenize(code)) 131 | ast.see_tree(code, view=True) 132 | # node = Node(1, 0, 'A', 0, 0, 0, 0, 'A') 133 | # print(node()) -------------------------------------------------------------------------------- /CDG.py: -------------------------------------------------------------------------------- 1 | from CFG import * 2 | from utils import * 3 | from graphviz import Graph 4 | 5 | class Tree: 6 | def __init__(self, V, children, root): # 输入节点集合,children为字典,key为节点,value为子节点列表,根节点 7 | self.vertex = V 8 | self.children = children 9 | self.root = root 10 | self.parent = {} 11 | for node in children: # 初始化parent字典 12 | for each in children[node]: 13 | self.parent[each] = node 14 | self.parent[root] = root 15 | for v in V: 16 | if v not in self.children: 17 | self.children[v] = [] 18 | self.depth = self.get_nodes_depth(root, {root:0}) 19 | 20 | def get_nodes_depth(self, root, depth): 21 | # 递归计算每个节点的深度 22 | for child in self.children[root]: 23 | depth[child] = depth[root] + 1 24 | depth = self.get_nodes_depth(child, depth) 25 | return depth 26 | 27 | def get_lca(self, a, b): 28 | # 计算a,b的最近公共祖先 29 | if self.depth[a] > self.depth[b]: 30 | diff = self.depth[a] - self.depth[b] 31 | while diff > 0: 32 | a = self.parent[a] 33 | diff -= 1 34 | elif self.depth[a] < self.depth[b]: 35 | diff = self.depth[b] - self.depth[a] 36 | while diff > 0: 37 | b = self.parent[b] 38 | diff -= 1 39 | while a != b: 40 | a = self.parent[a] 41 | b = self.parent[b] 42 | return a 43 | 44 | def reset_by_parent(self): 45 | # 根据parent字典重置children字典 46 | self.children = {v:[] for v in self.vertex} 47 | for node in self.parent: 48 | if node != self.parent[node]: 49 | self.children[self.parent[node]].append(node) 50 | 51 | def see_tree(self): 52 | dot = Graph(comment='Tree') 53 | for node in self.vertex: 54 | dot.node(str(node), shape='rectangle', label=str(node), fontname='fangsong') 55 | for node in self.children: 56 | for child in self.children[node]: 57 | dot.edge(str(node), str(child)) 58 | dot.view() 59 | 60 | class CDG(CFG): 61 | def get_subTree(self, cfg): 62 | # 按照广度优先遍历,找出一个子树 63 | V, E, Exit = cfg.nodes, cfg.edges, cfg.Exit 64 | visited = {v:False for v in V} 65 | queue = [Exit] 66 | visited[Exit] = True 67 | subTree = {} 68 | while queue: 69 | node = queue.pop() 70 | if node not in E: 71 | continue 72 | for edge in E[node]: 73 | v = edge.id 74 | if not visited[v]: 75 | queue.append(v) 76 | visited[v] = True 77 | subTree.setdefault(node, []) 78 | subTree[node].append(v) 79 | return subTree 80 | 81 | def get_prev(self, cfgs): 82 | # 计算每个节点的前驱节点 83 | prev = {} 84 | for cfg in cfgs: 85 | for node, edges in cfg.edges.items(): 86 | prev.setdefault(node, []) 87 | for next_node in edges: 88 | id = next_node.id 89 | prev.setdefault(id, []) 90 | prev[id].append(node) 91 | return prev 92 | 93 | def post_dominator_tree(self, cfgs, prev): 94 | # 生成后支配树 95 | PDT = [] 96 | for cfg in cfgs: # 遍历每一个函数的CFG 97 | subTree = self.get_subTree(cfg) # 找出一个子树 98 | V, root = cfg.nodes, cfg.Exit 99 | tree = Tree(V, subTree, root) # 生成树 100 | changed = True 101 | while changed: 102 | changed = False 103 | for v in V: # dominator tree算法 104 | if v != root: 105 | for u in prev[v]: 106 | parent_v = tree.parent[v] 107 | if u not in tree.vertex: 108 | cfg.see_graph() 109 | if u != parent_v and parent_v != tree.get_lca(u, parent_v): 110 | tree.parent[v] = tree.get_lca(u, parent_v) 111 | changed = True 112 | tree.reset_by_parent() # 根据parent字典重置children字典 113 | PDT.append(tree) 114 | return PDT 115 | 116 | def dominance_frontier(self, code): 117 | # 输入代码,返回CFG和支配边界 118 | cfgs = self.see_cfg(code) 119 | reverse_cfgs = [cfg.reverse() for cfg in cfgs] # 计算逆向CFG 120 | prev = self.get_prev(reverse_cfgs) # 计算每个节点的前驱节点 121 | PDT = self.post_dominator_tree(reverse_cfgs, prev) # 输入逆向CFG,输出后支配树 122 | DF = [] 123 | for cfg, tree in zip(reverse_cfgs, PDT): 124 | V = cfg.nodes 125 | DF.append({v:[] for v in V}) 126 | for v in V: 127 | if len(prev[v]) > 1: 128 | for p in prev[v]: 129 | runner = p 130 | while runner != tree.parent[v]: 131 | DF[-1][runner].append(v) 132 | runner = tree.parent[runner] 133 | return cfgs, DF 134 | 135 | def construct_cdg(self, code): 136 | # 输入代码,返回CDG 137 | cfgs, DF = self.dominance_frontier(code) 138 | self.cdgs = [] 139 | for cfg, df in zip(cfgs, DF): 140 | for v in df: 141 | df[v] = [Edge(u, type='CDG') for u in df[v]] 142 | cfg.edges = df 143 | self.cdgs.append(cfg) 144 | return self.cdgs 145 | 146 | def see_cdg(self, code, filename='CDG', pdf=True, view=False): 147 | self.construct_cdg(code) 148 | # dot = Digraph(comment=filename, strict=True) 149 | dot = Digraph(comment=filename, strict=True, format='pdf', graph_attr={'rankdir':'LR'}) 150 | for cdg in self.cdgs: 151 | for n in cdg.edges: 152 | if n < 0: 153 | continue 154 | node = cdg.id_to_nodes[n] 155 | label = f"<({node.type}, {html.escape(node.text)}){node.line}>" 156 | if node.is_branch: 157 | dot.node(str(node.id), shape='diamond', label=label, fontname='fangsong') 158 | elif node.type == 'function_definition': 159 | dot.node(str(node.id), label=label, fontname='fangsong') 160 | else: 161 | dot.node(str(node.id), shape='rectangle', label=label, fontname='fangsong') 162 | for v in cdg.edges: 163 | for u in cdg.edges[v]: 164 | dot.edge(str(u.id), str(v)) 165 | if pdf: 166 | dot.render(filename, view=view, cleanup=True) 167 | 168 | if __name__ == '__main__': 169 | code = r'{}'.format(open('test.c', 'r', encoding='utf-8').read()) 170 | cdg = CDG('c') 171 | # cdg.see_cfg(code, view=True) 172 | cdg.see_cdg(code, view=True) 173 | 174 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | text = lambda node: node.text.decode('utf-8') 2 | from graphviz import Digraph 3 | 4 | class Node: 5 | def __init__(self, node): 6 | self.line = node.start_point[0] + 1 7 | self.type = node.type 8 | self.id = hash((node.start_point, node.end_point)) % 1000000 9 | # self.id = node.start_point[0] + 1 10 | self.is_branch = False 11 | if node.type == 'function_definition': 12 | self.text = text(node.child_by_field_name('declarator').child_by_field_name('declarator')) # 函数名 13 | elif node.type in ['if_statement', 'while_statement', 'for_statement', 'switch_statement']: 14 | if node.type == 'if_statement': 15 | body = node.child_by_field_name('consequence') 16 | else: 17 | body = node.child_by_field_name('body') 18 | node_text = '' 19 | for child in node.children: 20 | if child == body: 21 | break 22 | node_text += text(child) 23 | self.text = node_text 24 | if node.type != 'switch_statement': 25 | self.is_branch = True 26 | elif node.type == 'do_statement': 27 | self.text = f'while{text(node.child_by_field_name("condition"))}' 28 | self.is_branch = True 29 | elif node.type == 'case_statement': 30 | node_text = '' 31 | for child in node.children: 32 | if child.type == ':': 33 | break 34 | node_text += ' ' + text(child) 35 | self.text = node_text 36 | self.is_branch = True 37 | else: 38 | self.text = text(node) 39 | defs, uses = self.get_def_use_info(node) 40 | self.defs = defs 41 | self.uses = uses 42 | 43 | def get_all_identifier(self, node): 44 | ids = [] 45 | def help(node): 46 | # 获取所有的变量名 47 | if node is None: 48 | return 49 | if node.type == 'identifier' and node.parent.type not in ['call_expression']: 50 | ids.append(text(node)) 51 | for child in node.children: 52 | help(child) 53 | help(node) 54 | return ids 55 | 56 | def get_def_id(self, node): 57 | update_ids, assignment_ids = [], [] 58 | def help(node): 59 | if node.type == 'update_expression': 60 | update_ids.append(text(node.child_by_field_name('argument'))) 61 | if node.type == 'assignment_expression': 62 | assignment_ids.append(text(node.child_by_field_name('left'))) 63 | for child in node.children: 64 | help(child) 65 | help(node) 66 | return update_ids, assignment_ids 67 | 68 | def get_node_def_use(self, node): 69 | uses = self.get_all_identifier(node) 70 | update_ids, assignment_ids = self.get_def_id(node) 71 | uses = list(set(uses) - set(assignment_ids) | set(update_ids)) 72 | return update_ids + assignment_ids, uses 73 | 74 | def get_def_use_info(self, node): 75 | # 获取变量的定义信息 76 | defi, uses = [], [] 77 | if node.type == 'function_definition': 78 | defi = self.get_all_identifier(node.child_by_field_name('declarator').child_by_field_name('parameters')) 79 | elif node.type == 'expression_statement': 80 | node = node.children[0] 81 | if node.type == 'call_expression': # scanf语句 82 | if text(node.child_by_field_name('function')) == 'scanf': 83 | arguments = node.child_by_field_name('arguments') 84 | defi = self.get_all_identifier(arguments) 85 | else: 86 | d, u = self.get_def_use_info(node) 87 | defi.extend(d) 88 | uses.extend(u) 89 | elif node.type == 'assignment_expression': # a += b 90 | d = text(node.child_by_field_name('left')) 91 | op = text(node.children[1]) 92 | u = self.get_all_identifier(node.child_by_field_name('right')) 93 | if op != '=': 94 | u.append(d) 95 | defi.append(d) 96 | uses = u 97 | elif node.type == 'update_expression': # a++; 98 | d = text(node.child_by_field_name('argument')) 99 | defi.append(d) 100 | uses.append(d) 101 | else: 102 | d, u = self.get_def_use_info(node) 103 | defi.extend(d) 104 | uses.extend(u) 105 | elif 'declaration' in node.type: 106 | defi.extend(self.get_all_identifier(node)) 107 | elif 'declarator' in node.type: 108 | d = text(node.child_by_field_name('declarator')) 109 | u = self.get_all_identifier(node) 110 | defi.append(d) 111 | uses = [i for i in u if i != d] 112 | elif node.type in ['if_statement', 'while_statement', 'do_statement', 'switch_statement']: 113 | condition = node.child_by_field_name('condition') 114 | d2, u2 = self.get_node_def_use(condition) 115 | defi.extend(d2) 116 | uses.extend(u2) 117 | elif node.type == 'for_statement': 118 | initializer = node.child_by_field_name('initializer') 119 | condition = node.child_by_field_name('condition') 120 | update = node.child_by_field_name('update') 121 | d1, u1 = self.get_node_def_use(initializer) 122 | d2, u2 = self.get_node_def_use(condition) 123 | d3, u3 = self.get_node_def_use(update) 124 | defi.extend(d1 + d2 + d3) 125 | uses.extend(u1 + u2 + u3) 126 | else: 127 | d, u = self.get_node_def_use(node) 128 | defi.extend(d) 129 | uses.extend(u) 130 | defi = list(set(defi)) 131 | uses = list(set(uses)) 132 | return defi, uses 133 | 134 | class Edge: 135 | def __init__(self, id, label='', type='', token=[]): 136 | self.id = id 137 | self.label = label 138 | self.type = type # CDG/DDG 139 | self.token = token # DDG的变量名 140 | 141 | class Graph: 142 | def __init__(self): 143 | self.nodes = set() 144 | self.edges = {} 145 | self.prev_nodes = {} 146 | self.id_to_nodes = {} 147 | self.r = self.Exit = None 148 | self.defs = {} # 存放每一个函数的变量定义信息 149 | self.uses = {} # 存放每一个函数的变量使用信息 150 | 151 | def add_edge(self, edge): 152 | node = edge[0] 153 | edges = edge[1] 154 | self.nodes.add(node) 155 | self.id_to_nodes[node.id] = node 156 | self.edges.setdefault(node.id, []) 157 | for prev_node, label in edges: 158 | self.edges.setdefault(prev_node, []) 159 | self.edges[prev_node].append(Edge(node.id, label)) 160 | 161 | def see_graph(self): 162 | dot = Digraph(comment='CFG', strict=True) 163 | for node, edges in self.edges.items(): 164 | for next_node in edges: 165 | left = 'Exit' if node == self.Exit else str(node) 166 | right = 'r' if next_node.id == self.r else str(next_node.id) 167 | dot.edge(left, right) 168 | dot.view() 169 | 170 | def reverse(self): 171 | # 返回反向的CFG图,添加了Exit节点 172 | E = {} 173 | V = set() 174 | for node, edges in self.edges.items(): 175 | V.add(node) 176 | if self.id_to_nodes[node].type == 'function_definition': # 函数入口 177 | self.r = node 178 | self.Exit = -node 179 | V.add(-node) # 增加一个Exit节点,该节点的id为函数节点的相反数 180 | E[self.Exit] = [Edge(self.r)] # 将Exit节点连接到函数节点 181 | for node, edges in self.edges.items(): 182 | if edges == [] or (len(edges) == 1 and edges[0].label == 'Y'): # 没有出节点或者出节点只有一个且是Y(没有N) 183 | E[self.Exit].append(Edge(node)) # 将Exit节点连接到没有出节点的节点 184 | for edge in edges: 185 | # node -> edge.id 变成 edge.id -> node 186 | E.setdefault(edge.id, []) 187 | E[edge.id].append(Edge(node)) 188 | self.nodes = V 189 | self.edges = E 190 | return self 191 | 192 | def Adj(self, node): 193 | adjs = [] 194 | if node in self.edges: 195 | for edge in self.edges[node]: 196 | adjs.append(edge.id) 197 | return adjs 198 | 199 | def get_def_use_info(self): 200 | for node in self.nodes: 201 | for each in node.defs: 202 | self.defs.setdefault(each, []) 203 | self.defs[each].append(node.id) 204 | for each in node.uses: 205 | self.uses.setdefault(each, []) 206 | self.uses[each].append(node.id) 207 | 208 | def findAllPath(self, start, end): 209 | # 算法参考:https://zhuanlan.zhihu.com/p/84437102 210 | # 输入两点,输出所有的路径列表 211 | paths, s1, s2 = [], [], [] # 存放所有路径,主栈,辅助栈 212 | s1.append(start) 213 | s2.append(self.Adj(start)) 214 | while s1: # 主栈不为空 215 | s2_top = s2[-1] 216 | if s2_top: # 邻接节点列表不为空 217 | s1.append(s2_top[0]) # 将邻接节点列表首个元素添加到主栈 218 | s2[-1] = s2_top[1:] # 将辅助栈的邻接节点列表首个元素删除 219 | temp = [] # 建栈,需要判断邻接节点是否在主栈中 220 | for each in self.Adj(s2_top[0]): 221 | if each not in s1: 222 | temp.append(each) 223 | s2.append(temp) 224 | else: # 削栈 225 | s1.pop() 226 | s2.pop() 227 | continue 228 | if s1[-1] == end: # 找到一条路径 229 | paths.append(s1.copy()) 230 | s1.pop() # 回溯 231 | s2.pop() 232 | return paths -------------------------------------------------------------------------------- /CFG.py: -------------------------------------------------------------------------------- 1 | from AST import * 2 | from utils import * 3 | import html 4 | import copy 5 | 6 | def get_break_continue_node(node): 7 | # 找到node节点循环中的所有break和continue节点并返回 8 | break_nodes, continue_nodes = [], [] 9 | for child in node.children: 10 | if child.type == 'break_statement': 11 | break_nodes.append(child) 12 | elif child.type == 'continue_statement': 13 | continue_nodes.append(child) 14 | elif child.type not in ['for_statement', 'while_statement']: 15 | b_node, c_nodes = get_break_continue_node(child) 16 | break_nodes.extend(b_node) 17 | continue_nodes.extend(c_nodes) 18 | return break_nodes, continue_nodes 19 | 20 | def get_edge(in_nodes): 21 | # 输入入节点,返回入边的列表,边为(parent_id, label) 22 | edge = [] 23 | for in_node in in_nodes: 24 | parent, label = in_node 25 | parent_id = parent.id 26 | edge.append((parent_id, label)) 27 | return edge 28 | 29 | class CFG(AST): 30 | def __init__(self, language): 31 | super().__init__(language) 32 | self.cfgs = [] # 存放每一个函数的CFG图 33 | 34 | def create_cfg(self, node, in_nodes=[()]): 35 | # 输入当前节点,以及入节点,入节点为(node_info, edge_label)的列表,node_info['id']唯一确定一个节点,edge_label为边的标签 36 | if node.child_count == 0 or in_nodes == []: # 如果in_nodes为空,说明没有入节点,跳过 37 | return [], in_nodes 38 | if node.type == 'function_definition': # 如果节点是函数,则创建函数节点,并且递归遍历函数的compound_statement 39 | body = node.child_by_field_name('body') 40 | node_info = Node(node) 41 | CFG, _ = self.create_cfg(body, [(node_info, '')]) 42 | return CFG + [(node_info, [])], [] 43 | elif node.type == 'compound_statement': # 如果是复合语句,则递归遍历复合语句的每一条statement 44 | CFG = [] 45 | for child in node.children: 46 | cfg, out_nodes = self.create_cfg(child, in_nodes) 47 | CFG.extend(cfg) 48 | in_nodes = out_nodes 49 | return CFG, in_nodes 50 | elif node.type not in ['if_statement', 'while_statement', 'for_statement', 'switch_statement', 'case_statement', 'translation_unit', 'do_statement']: # 如果是普通的语句 51 | edge = get_edge(in_nodes) 52 | node_info = Node(node) 53 | in_nodes = [(node_info, '')] 54 | if node.type in ['return_statement', 'break_statement', 'continue_statement']: # return,break,continue语句没有出节点 55 | return [(node_info, edge)], [] 56 | else: 57 | return [(node_info, edge)], in_nodes 58 | elif node.type == 'if_statement': # if语句 59 | CFG = [] 60 | edge = get_edge(in_nodes) 61 | node_info = Node(node) 62 | CFG.append((node_info, edge)) 63 | body = node.child_by_field_name('consequence') # 获取if的主体部分 64 | cfg, out_nodes = self.create_cfg(body, [(node_info, 'Y')]) 65 | CFG.extend(cfg) 66 | alternate = node.child_by_field_name('alternative') # 获取else的主体部分,可能是else,也可能是else if 67 | if alternate: # if else 或者 if else if 68 | body = alternate.children[1] 69 | cfg, al_out_nodes = self.create_cfg(body, [(node_info, 'N')]) 70 | CFG.extend(cfg) 71 | return CFG, out_nodes + al_out_nodes 72 | else: # 只有if 73 | return CFG, out_nodes + [(node_info, 'N')] 74 | elif node.type in ['for_statement', 'while_statement']: # for和while循环 75 | CFG = [] 76 | edge = get_edge(in_nodes) 77 | node_info = Node(node) 78 | CFG.append((node_info, edge)) 79 | body = node.child_by_field_name('body') # 获取循环主体 80 | cfg, out_nodes = self.create_cfg(body, [(node_info, 'Y')]) 81 | CFG.extend(cfg) 82 | for out_node in out_nodes: # 将循环主体的出节点与循环的开始节点相连 83 | parent, label = out_node 84 | parent_id = parent.id 85 | CFG.append((node_info, [(parent_id, label)])) 86 | break_nodes, continue_nodes = get_break_continue_node(node) # 求得循环内的break和continue节点 87 | out_nodes = [(node_info, 'N')] # 循环体的出节点开始节点,条件为N 88 | for break_node in break_nodes: 89 | out_nodes.append((Node(break_node), '')) # 将break节点添加到out_nodes中 90 | for continue_node in continue_nodes: 91 | CFG.append((node_info, [(Node(continue_node).id, '')])) # 将continue节点连接到循环的开始节点 92 | return CFG, out_nodes 93 | elif node.type == 'do_statement': # do while循环 94 | CFG = [] 95 | edge = get_edge(in_nodes) 96 | node_info = Node(node) 97 | body = node.child_by_field_name('body') # 获取循环主体 98 | cfg, out_nodes = self.create_cfg(body, [(node_info, '')]) 99 | first_node = cfg[0][0] 100 | CFG.append((first_node, edge)) 101 | CFG.extend(cfg) 102 | for out_node in out_nodes: # 将循环主体的出节点与条件节点相连 103 | parent, label = out_node 104 | parent_id = parent.id 105 | CFG.append((node_info, [(parent_id, label)])) 106 | CFG.append((first_node, [(node_info.id, 'Y')])) # 将条件节点连接到循环主体的开始节点 107 | out_nodes = [(node_info, 'N')] # 循环体的出节点开始节点,条件为N 108 | break_nodes, continue_nodes = get_break_continue_node(node) # 求得循环内的break和continue节点 109 | for break_node in break_nodes: 110 | out_nodes.append((Node(break_node), '')) 111 | for continue_node in continue_nodes: 112 | CFG.append((node_info, [(Node(continue_node).id, '')])) 113 | return CFG, out_nodes 114 | elif node.type == 'switch_statement': # switch语句 115 | CFG = [] 116 | edge = get_edge(in_nodes) 117 | node_info = Node(node) 118 | CFG.append((node_info, edge)) 119 | body = node.child_by_field_name('body') # 获取switch的主体部分 120 | cfg, out_nodes = self.create_cfg(body, [(node_info, '')]) # 递归遍历case语句 121 | CFG.extend(cfg) 122 | break_nodes, _ = get_break_continue_node(node) # 将break语句添加到out_nodes当中 123 | for break_node in break_nodes: 124 | out_nodes.append((Node(break_node), '')) 125 | return CFG, out_nodes 126 | elif node.type == 'case_statement': # case语句 127 | CFG = [] 128 | edge = get_edge(in_nodes) 129 | node_info = Node(node) 130 | CFG.append((node_info, edge)) 131 | if node.children[0].type == 'case': # 如果是case语句 132 | in_nodes = [(node_info, 'Y')] 133 | for child in node.children[3:]: 134 | cfg, out_nodes = self.create_cfg(child, in_nodes) 135 | CFG.extend(cfg) 136 | in_nodes = out_nodes 137 | return CFG, in_nodes + [(node_info, 'N')] 138 | else: # default 139 | in_nodes = [(node_info, '')] 140 | for child in node.children[2:]: 141 | cfg, out_nodes = self.create_cfg(child, in_nodes) 142 | CFG.extend(cfg) 143 | in_nodes = out_nodes 144 | return CFG, in_nodes 145 | else: 146 | CFGs = [] # 存放每一个函数的CFG图 147 | for child in node.children: 148 | if child.type == 'function_definition': # 获得每一个函数的CFG图 149 | CFG, out_nodes = self.create_cfg(child, in_nodes) 150 | CFGs.append(CFG) 151 | return CFGs, in_nodes 152 | 153 | def construct_cfg(self, code): 154 | tree = self.parser.parse(bytes(code, 'utf-8')) 155 | root_node = tree.root_node 156 | CFGs, _ = self.create_cfg(root_node) 157 | for func_cfg in CFGs: 158 | cfg = Graph() 159 | for each in func_cfg: 160 | cfg.add_edge(each) 161 | # 删除break节点和continue节点 162 | # cfg_edges = copy.deepcopy(cfg.edges) 163 | # for node, edges in cfg_edges.items(): 164 | # for i, edge in enumerate(edges): 165 | # if cfg.id_to_nodes[edge.id].type in ['break_statement', 'continue_statement']: 166 | # node_id = edge.id 167 | # next_node = cfg.edges[node_id][0].id 168 | # cfg.edges[node][i].id = next_node 169 | # del cfg.edges[node_id] 170 | # cfg.nodes.remove(cfg.id_to_nodes[node_id]) 171 | cfg.get_def_use_info() 172 | self.cfgs.append(cfg) 173 | 174 | def see_cfg(self, code, filename='CFG', pdf=True, view=False): 175 | self.construct_cfg(code) 176 | dot = Digraph(comment=filename, strict=True) 177 | for cfg in self.cfgs: 178 | # for n in cfg.id_to_nodes: 179 | # input((n, cfg.id_to_nodes[n].id)) 180 | for node in cfg.nodes: 181 | label = f"<({node.type}, {html.escape(node.text)}){node.line}>" 182 | if node.is_branch: 183 | dot.node(str(node.id), shape='diamond', label=label, fontname='fangsong') 184 | elif node.type == 'function_definition': 185 | dot.node(str(node.id), label=label, fontname='fangsong') 186 | else: 187 | dot.node(str(node.id), shape='rectangle', label=label, fontname='fangsong') 188 | for node, edges in cfg.edges.items(): 189 | for edge in edges: 190 | next_node, label = edge.id, edge.label 191 | dot.edge(str(node), str(next_node), label=label) 192 | if pdf: 193 | dot.render(filename, view=view, cleanup=True) 194 | return self.cfgs 195 | 196 | if __name__ == '__main__': 197 | code = r'{}'.format(open('test.c', 'r', encoding='utf-8').read()) 198 | cfg = CFG('c') 199 | # cfg.see_tree(code, view=True) 200 | cfg.see_cfg(code, view=True) 201 | -------------------------------------------------------------------------------- /File.py: -------------------------------------------------------------------------------- 1 | from AST import * 2 | import os 3 | import re 4 | 5 | MAX_LENGTH = 1000000 # 最大行数 6 | constant_type = ['number_literal', 'string_literal', 'character_literal', 'preproc_arg', 'true', 'false', 'null'] # 常量类型 7 | 8 | def Array(node, type): 9 | ''' 10 | 目的-为了获取int a[m][n];声明中的变量名a和类型int ** 11 | 输入-节点node和类型type,例如node: a[m][n], type: int 12 | 输出-变量名和类型,例如a, int** 13 | ''' 14 | dim = 0 15 | node_type = node.type 16 | while node and node.type == node_type: 17 | dim += 1 18 | node = node.children[0] 19 | name = text(node) 20 | type = f'{type}{"*"*dim}' 21 | return name, type 22 | 23 | def Pointer(node, type): 24 | ''' 25 | 目的-为了获取int *a[m];声明中的变量名a和类型int * 26 | 输入-节点node和类型type,例如node: *a[m], type: int 27 | 输出-变量名和类型,例如a, int** 28 | ''' 29 | dim = 0 30 | if node.type not in ['array_declarator', 'pointer_declarator']: 31 | return None, type 32 | node_type = node.type 33 | while node and node.type == node_type: 34 | dim += 1 35 | node = node.children[1] 36 | if node.type == 'array_declarator': # char* argv[] 37 | while node and node.type == 'array_declarator': 38 | dim += 1 39 | node = node.children[0] 40 | if node.type == 'function_declarator': 41 | name = text(node.child_by_field_name('declarator')) 42 | else: 43 | name = text(node) 44 | type = f'{type}{"*"*dim}' 45 | return name, type 46 | 47 | class Identifier: 48 | def __init__(self, type, name, domain, structure_, class_): 49 | self.type = type # 变量名类型 50 | self.name = name # 变量名 51 | self.domain = domain # 作用域:[start line, end line] 52 | self.structure_ = structure_ # 所属结构体名称 53 | self.class_ = class_ # 所属类名称 54 | 55 | def __str__(self): 56 | str = f'name: {self.name}\ntype: {self.type}\ndomain: {self.domain}\n' 57 | if self.structure_: 58 | str += f'structure: {self.structure_}\n' 59 | if self.class_: 60 | str += f'class: {self.class_}\n' 61 | return str 62 | 63 | class Declaration: 64 | def __init__(self, node): 65 | ''' 66 | 目的-获取int a[m], b=0;声明的变量名和类型: {a: int*, b: int} 67 | 输入-节点node,例如int a[m], b=0; 68 | 输出-变量名和类型,例如{a: int*, b: int} 69 | ''' 70 | self.identifiers = {} 71 | if node.type not in ['declaration', 'field_declaration', 'parameter_declaration', 'type_definition']: 72 | return 73 | type_node = node.child_by_field_name('type') 74 | if type_node.type == 'struct_specifier': # 结构体类型 75 | if not type_node.child_by_field_name('name'): 76 | return 77 | self.type = text(type_node.child_by_field_name('name')) 78 | else: # 基本类型 79 | self.type = text(type_node) 80 | self.identifiers = {} 81 | if node.type == 'parameter_declaration': # 获取函数参数的变量名和类型 82 | name, type = self.get_name_and_type(node.child_by_field_name('declarator')) 83 | if name and type: 84 | self.identifiers[name] = type 85 | else: 86 | for child in node.children[1: -1]: 87 | name, type = self.get_name_and_type(child) # 获取声明的变量名和类型 88 | if name and type: 89 | self.identifiers[name] = type 90 | 91 | def get_name_and_type(self, node): 92 | ''' 93 | 目的-获取以node为根节点的变量名和类型 94 | 输入-节点node,例如int a[m] 95 | 输出-变量名和类型,例如a, int* 96 | ''' 97 | if node is None: 98 | return None, None 99 | if node.type == 'array_declarator': # 如果是数组 100 | return Array(node, self.type) 101 | elif node.type == 'pointer_declarator': # 如果是指针 102 | return Pointer(node, self.type) 103 | elif node.type in ['init_declarator', 'parameter_declaration']: # 如果是初始化声明,还需要遍历declarator 104 | return self.get_name_and_type(node.child_by_field_name('declarator')) 105 | elif node.type in ['identifier', 'field_identifier', 'type_identifier']: # 一般的变量的类型就是self.type 106 | return text(node), self.type 107 | else: 108 | return None, None 109 | 110 | def __call__(self): 111 | '''return {identifier: type}''' 112 | return self.identifiers 113 | 114 | class Structure: 115 | def __init__(self, node): 116 | ''' 117 | 目的-获取结构体属性的类型field_vars,以及结构体声明的变量def_vars 118 | 输入-节点node,例如struct A{int a; int b;}a; 119 | 输出-结构体属性的类型field_vars={a:int, b:int},以及结构体声明的变量def_vars{a:A} 120 | ''' 121 | self.field_vars = {} 122 | self.def_vars = {} 123 | if node.type != 'struct_specifier': 124 | print('Error: not a struct_specifier') 125 | return 126 | if not node.child_by_field_name('name') and node.type == 'struct_specifier': 127 | self.name = text(node.parent.child_by_field_name('declarator')) 128 | else: 129 | self.name = text(node.child_by_field_name('name')) 130 | declaration = Declaration(node.parent) 131 | self.def_vars = declaration() 132 | body = node.child_by_field_name('body') 133 | if body: 134 | for field_declaration in body.children: 135 | if text(field_declaration) not in ['{', '}']: 136 | declaration = Declaration(field_declaration) 137 | self.field_vars.update(declaration()) 138 | 139 | class IdType: 140 | def __init__(self, file_path): 141 | self.file_path = file_path 142 | self.vars = {} # 存放变量名的定义信息{a: [Identifier, Identifier], b: [Identifier]} 143 | self.macro = {} # 存放宏定义的类型{ll: long long, ull: unsigned long long} 144 | 145 | def add_def_var(self, identifier, domain, structure_, class_): 146 | ''' 147 | 目的-将声明的变量加入到vars中 148 | 输入-声明的变量identifier,例如{a: int*, b: int},作用域domain=[start line, end line], structure_和class_ 149 | ''' 150 | for name, type in identifier.items(): 151 | self.vars.setdefault(name, []) 152 | self.vars[name].append(Identifier(type, name, domain, structure_, class_)) 153 | 154 | def add_macro(self, name, type): 155 | self.macro[name] = type 156 | 157 | def query_type(self, id, line): 158 | ''' 159 | 目的-查询变量id在line行的类型 160 | 输入-变量id和行号line 161 | 输出-变量id在line行的类型 162 | ''' 163 | if id not in self.vars: 164 | return 'unknown' 165 | match_info = [] # 存放符合作用域的变量信息 166 | for info in self.vars[id]: 167 | if info.domain[0] <= line <= info.domain[1]: 168 | match_info.append(info) 169 | if not match_info: 170 | return 'unknown' 171 | # 将match_info按照domain的长度排序,找出来最局部变量的类型 172 | match_info.sort(key=lambda x: x.domain[1] - x.domain[0]) 173 | type = match_info[0].type 174 | dim = type.count('*') 175 | new_type = type.replace('*', '') 176 | if new_type in self.macro: 177 | type = self.macro[new_type] + '*' * dim 178 | return type 179 | 180 | def __str__(self): 181 | str = [] 182 | for var, info in self.vars.items(): 183 | str.append(f"{var:=^40}") 184 | for i in info: 185 | str.append(i.__str__()) 186 | return '\n'.join(str) 187 | 188 | class Func: 189 | def __init__(self, node, file_path): 190 | ''' 191 | 目的-获取函数的返回类型,函数名,参数类型 192 | 输入-节点node,例如int main(int argc, char* argv[]), 文件路径file_path 193 | 输出-type = int, name = main, signature = {'return': int, 'name': main, 'parameters': [{'argc': int}, {'argv': char**}]} 194 | ''' 195 | self.func_node = node 196 | self.file_path = file_path 197 | type = text(node.child_by_field_name('type')) # 这里的type有问题 198 | _, self.type = Pointer(node.child_by_field_name('declarator'), type) 199 | self.line = node.start_point[0] + 1 200 | self.id = hash((node.start_byte, node.end_byte)) % 1000000007 201 | while node.type not in ['function_declarator', 'parenthesized_declarator']: # parenthesized_declarator: 例如main(){}, void省略了 202 | node = node.child_by_field_name('declarator') 203 | if not node: 204 | self.name = None 205 | self.type = None 206 | self.signature = {'parameters':[]} 207 | return 208 | if node.type == 'function_declarator': 209 | self.name = text(node.child_by_field_name('declarator')) 210 | parameters = node.child_by_field_name('parameters') 211 | param_list = [] 212 | for param in parameters.children: 213 | if text(param) not in [',', '(', ')']: 214 | declaration = Declaration(param) 215 | for name, type in declaration().items(): 216 | param_list.append((name, type)) 217 | else: 218 | self.name = self.type 219 | self.type = 'void' 220 | param_list = [] 221 | self.signature = {'return': self.type, 'name': self.name, 'parameters': param_list} 222 | 223 | def __eq__(self, other): 224 | return self.signature == other.signature 225 | 226 | def __str__(self): 227 | param_str = '' 228 | for param in self.signature['parameters']: 229 | param_str += f"({param[0]}: {param[1]}) " 230 | return f"function name: {self.name}\nreturn type: {self.type}\nparameters: {param_str}\nfile path: {self.file_path}\n\n" 231 | 232 | class Function: 233 | def __init__(self): 234 | self.funcs = {} 235 | self.id_to_func = {} 236 | 237 | def add_func(self, node, file_path): 238 | func = Func(node, file_path) 239 | if not func.name: 240 | return 241 | if func.name not in self.funcs: 242 | self.funcs[func.name] = [func] 243 | self.id_to_func[func.id] = func 244 | else: 245 | if func not in self.funcs[func.name]: 246 | self.funcs[func.name].append(func) 247 | return func 248 | 249 | def match_func(self, func_name, param_node, expression): 250 | '''根据函数名和参数匹配函数,返回函数信息''' 251 | if len(self.funcs[func_name]) == 1: # 没有重构函数 252 | return self.funcs[func_name][0] 253 | # 如果有重构函数,如果参数的个数唯一,则返回对应函数 254 | if param_node.type != 'argument_list': 255 | print("not a param list") 256 | return None 257 | param_types = [] 258 | for param in param_node.children[1: -1]: 259 | if param.type == ',': 260 | continue 261 | param_types.append(expression.traverse(param)) 262 | match_num = 0 263 | match_func = None 264 | for func in self.funcs[func_name]: 265 | if len(func.signature['parameters']) == len(param_types): 266 | match_func = func 267 | match_num += 1 268 | if match_num == 1: 269 | return match_func 270 | else: # 根据参数类型来匹配 271 | for func in self.funcs[func_name]: 272 | func_param_types = [x[1] for x in func.signature['parameters']] 273 | if func_param_types == param_types: 274 | return func 275 | return None 276 | 277 | def __getitem__(self, func_name): 278 | return self.funcs[func_name] 279 | 280 | def __call__(self): 281 | return self.funcs 282 | 283 | def __str__(self): 284 | str = '' 285 | for func_name in self.funcs: 286 | for f in self.funcs[func_name]: 287 | str += f.__str__() 288 | return str 289 | 290 | class Constant: 291 | def __init__(self, node): 292 | ''' 293 | 目的-获取常量的类型 294 | 输入-节点node,例如1, 1.0, 'a', "abc", TRUE, FALSE, NULL 295 | 输出-常量的类型,例如int, float, char, char*, bool, null 296 | ''' 297 | if not node or node.type not in constant_type: 298 | self.type = 'unknown' 299 | return 300 | string = text(node) 301 | if node.type == 'preproc_arg': # define a const中的const 302 | # 删除string中的注释"//"和"/* */" 303 | string = re.sub(r'//.*', '', string) 304 | string = re.sub(r'/\*.*\*/', '', string) 305 | if self.is_float(string): 306 | self.type = 'float' 307 | elif self.is_int(string): 308 | self.type = 'int' 309 | elif self.is_char(string): 310 | self.type = 'char' 311 | elif self.is_string(string): 312 | self.type = 'char*' 313 | elif self.is_bool(string): 314 | self.type = 'bool' 315 | elif self.is_null(string): 316 | self.type = 'null' 317 | else: 318 | self.type = 'unknown' 319 | self.value = string 320 | 321 | def is_float(self, string): 322 | return bool(re.match(r'^[-+]?[0-9]*\.[0-9]+$', string)) 323 | 324 | def is_int(self, string): 325 | hex_, oct_, dec_ = False, False, False 326 | hex_ = bool(re.match(r'^[-+]?0[xX][0-9a-fA-F]+$', string)) 327 | oct_ = bool(re.match(r'^[-+]?0[0-7]+$', string)) 328 | dec_ = bool(re.match(r'^[-+]?[0-9]+$', string)) 329 | return hex_ or oct_ or dec_ 330 | 331 | def is_char(self, string): 332 | return bool(re.match(r'^\'[^\']\'$', string)) 333 | 334 | def is_string(self, string): 335 | return bool(re.match(r'^\"[^\"]*\"$', string)) 336 | 337 | def is_bool(self, string): 338 | return string in ['TURE', 'FALSE'] 339 | 340 | def is_null(self, string): 341 | return string == 'NULL' 342 | 343 | class Expression: 344 | def __init__(self, idType, Structure, Function): 345 | self.idType = idType 346 | self.structure_ = Structure 347 | self.function = Function 348 | 349 | def type(self, node): 350 | type = self.traverse(node) 351 | if type == 'unknown': 352 | return 'unknown' 353 | dim = type.count('*') 354 | new_type = type.replace('*', '').replace(' ', '') 355 | if new_type in self.idType.macro: 356 | type = self.idType.macro[new_type] + '*' * dim 357 | return type 358 | 359 | def traverse(self, node): 360 | ''' 361 | 目的-从上往下遍历node表达式,获取表达式的类型 362 | 输入-idType, Structure, Function,分别表示变量名类型、结构体类型、函数类型 363 | 输出-表达式的类型 364 | ''' 365 | if node.type in constant_type: # 常量类型通过Constant查询 366 | return Constant(node).type 367 | elif node.type == 'identifier': # 变量名类型通过idType查询 368 | return self.idType.query_type(text(node), node.start_point[0] + 1) 369 | elif node.type == 'sizeof_expression': # sizeof表达式返回int 370 | return 'int' 371 | elif node.type in ['unary_expression', 'update_expression']: # 一元表达式返回本身的类型 372 | return self.traverse(node.child_by_field_name('argument')) 373 | elif node.type == 'conditional_expression': # 三元表达式返回两个表达式的类型 374 | return self.traverse(node.child_by_field_name('consequence')) 375 | elif node.type == 'subscript_expression': # 数组类型,例如a[i],a的类型是int*,那么a[i]的类型是int 376 | argument_type = self.traverse(node.child_by_field_name('argument')) 377 | if argument_type == 'unknown': 378 | return 'unknown' 379 | return argument_type[:-1] # 去掉最后一个*号 380 | elif node.type == 'field_expression': # 结构体类型,例如a.b,a的类型是A,那么a.b的类型是int 381 | argument_type = self.traverse(node.child_by_field_name('argument')) 382 | if argument_type == 'unknown': 383 | return 'unknown' 384 | field = text(node.child_by_field_name('field')) 385 | op = text(node.children[1]) 386 | if op == '->': # 例如a的类型是A*,那么a->b的时候,要删掉a类型的最后一个*号 387 | argument_type = argument_type[:-1] 388 | if argument_type not in self.structure_: 389 | return 'unknown' 390 | structure = self.structure_[argument_type] 391 | if field not in structure.field_vars: 392 | return 'unknown' 393 | return structure.field_vars[field] 394 | elif node.type == 'pointer_expression': # 指针类型,例如*a, &a 395 | op = text(node.children[0]) 396 | argument_type = self.traverse(node.child_by_field_name('argument')) 397 | if argument_type == 'unknown': 398 | return 'unknown' 399 | if op == '*': # *a删除a类型的最后一个*号 400 | return argument_type[:-1] 401 | else: # &添加一个*号 402 | return argument_type + '*' 403 | elif node.type == 'call_expression': # 函数类型,例如f(a, b),f的类型是int(int, int),那么f(a, b)的类型是int 404 | callee = text(node.child_by_field_name('function')) 405 | if callee in self.function(): 406 | parameters = node.child_by_field_name('arguments') 407 | match_func = self.function.match_func(callee, parameters, self) 408 | if match_func: 409 | return match_func.type 410 | return 'unknown' 411 | elif node.type == 'parenthesized_expression': # 括号表达式返回括号内的表达式类型 412 | return self.traverse(node.children[1]) 413 | elif node.type == 'binary_expression': # 二元表达式 414 | op = text(node.children[1]) 415 | if op in ['==', '!=', '>', '>=', '<', '<=', '&&', '||']: 416 | return 'bool' 417 | else: # 除了bool类型的二元表达式,其他的二元表达式返回强制类型转换后的类型 418 | left_type = self.traverse(node.child_by_field_name('left')) 419 | right_type = self.traverse(node.child_by_field_name('right')) 420 | return self.typecasting(left_type, right_type) 421 | elif node.type == 'assignment_expression': # 赋值表达式返回左边的表达式类型 422 | return self.traverse(node.child_by_field_name('left')) 423 | elif node.type == 'cast_expression': # 强制类型转换表达式返回强制类型转换后的类型 424 | return text(node.child_by_field_name('type')) 425 | else: 426 | return 'unknown' 427 | 428 | def typecasting(self, type1, type2): 429 | '''目的-将type1和type2进行强制类型转换''' 430 | def compare(type_list1, type_list2): 431 | if type1 in type_list1 and type2 in type_list2: 432 | return True 433 | if type1 in type_list2 and type2 in type_list1: 434 | return True 435 | return False 436 | def unsigned(types, copy=True): 437 | '''增加unsigned修饰''' 438 | if copy: 439 | new_types = types.copy() 440 | else: 441 | new_types = [] 442 | for type in types: 443 | if type == 'int': 444 | new_types.append('unsigned') 445 | new_types.append(f'unsigned {type}') 446 | new_types.append(f'{type} unsigned') 447 | return new_types 448 | if type1 == 'bool': # bool强制转换成int 449 | type1 = 'int' 450 | if type2 == 'bool': 451 | type2 = 'int' 452 | if compare(unsigned(['char', 'short']), unsigned(['char', 'short'])): # char和short强制转换成int 453 | return 'int' 454 | elif compare(['float'], ['float']): # float强制转换成double 455 | return 'double' 456 | elif type1 == type2: 457 | return type1 458 | elif compare(unsigned(['char', 'short']), ['int']): 459 | return 'int' 460 | if compare(unsigned(['char', 'short', 'int']), unsigned(['int'], copy=False)): # unsigned类型优先级高 461 | return 'unsigned int' 462 | elif compare(unsigned(['char', 'short', 'int']), ['long']): 463 | return 'long' 464 | elif compare(unsigned(['char', 'short', 'int', 'long']), unsigned(['long'], copy=False)): 465 | return 'unsigned long' 466 | elif compare(unsigned(['char', 'short', 'int', 'long']), ['long long']): 467 | return 'long long' 468 | elif compare(unsigned(['char', 'short', 'int', 'long', 'long long']), unsigned(['long long'], copy=False)): 469 | return 'unsigned long long' 470 | elif compare(unsigned(['char', 'short', 'int', 'long', 'long long']), ['float']): 471 | return 'float' 472 | elif compare(unsigned(['char', 'short', 'int', 'long', 'long long']) + ['float'], ['double']): 473 | return 'double' 474 | elif '*' in type1: # 如果type1是指针类型,返回type1 475 | return type1 476 | elif '*' in type2: 477 | return type2 478 | else: 479 | return 'unknown' 480 | 481 | class File(AST): 482 | def __init__(self, language, file_path): 483 | super().__init__(language) 484 | self.structure_ = {} # {structure_name: structure} 485 | self.function = Function() # {function_name: [function]} 可能重构 486 | self.file_path = os.path.join(os.path.abspath('.'), file_path) # 文件路径 487 | self.idType = IdType(file_path) # 变量名类型 488 | self.CG = {} # 存放函数调用图 489 | self.unknown_call = {} # 存放未知函数调用 490 | self.unknown_id = {} # 存放未知函数名id 491 | code = r'{}'.format(open(self.file_path, 'r', encoding='utf-8', errors='ignore').read()) 492 | tree = self.parser.parse(bytes(code, 'utf8')) 493 | self.root_node = tree.root_node 494 | 495 | def construct_file(self, node=None): 496 | ''' 497 | 构建文件的信息,包括全局信息:结构体、全局变量类型、函数、宏定义 498 | ''' 499 | if node is not None: 500 | root_node = node 501 | else: 502 | root_node = self.root_node 503 | for child in root_node.children: 504 | if child.type in ['preproc_ifdef', 'preproc_else', 'linkage_specification', 'declaration_list', 'ERROR']: # 如果是ifdef,则遍历ifdef内部的节点 505 | self.construct_file(child) 506 | elif child.type in ['declaration', 'struct_specifier']: # 结构体定义 507 | if child.type == 'struct_specifier': # struct A{int a; int b;}; 508 | type_node = child 509 | else: # struct A{int a; int b;} a; 在定义的时候多了声明 510 | self.function.add_func(child, self.file_path) 511 | type_node = child.child_by_field_name('type') 512 | if type_node.type == 'struct_specifier' and type_node.child_by_field_name('body'): # 结构体类型,不是函数指针 513 | structure = Structure(type_node) 514 | self.structure_[structure.name] = structure 515 | domain = [child.end_point[0] + 1, MAX_LENGTH] # 作用域全局 516 | self.idType.add_def_var(structure.def_vars, domain, structure.name, None) 517 | if type_node.type in ['primitive_type', 'type_identifier']: # 全局变量类型 518 | declaration = Declaration(child) 519 | domain = [child.start_point[0] + 1, MAX_LENGTH] 520 | self.idType.add_def_var(declaration(), domain, None, None) 521 | elif child.type == 'type_definition': # 宏定义typedef 522 | type_node = child.child_by_field_name('type') 523 | if type_node.type == 'struct_specifier': 524 | structure = Structure(type_node) 525 | self.structure_[structure.name] = structure 526 | domain = [child.end_point[0] + 1, MAX_LENGTH] 527 | self.idType.add_def_var(structure.def_vars, domain, structure.name, None) 528 | for name, type in structure.def_vars.items(): 529 | self.idType.add_macro(name, type) 530 | else: 531 | type = text(type_node) 532 | declaration = Declaration(child) 533 | for name, type in declaration().items(): 534 | self.idType.add_macro(name, type) 535 | elif child.type == 'preproc_def': # 宏定义#define,根据常量确定类型 536 | name = text(child.child_by_field_name('name')) 537 | const = Constant(child.child_by_field_name('value')) 538 | domain = [child.start_point[0] + 1, MAX_LENGTH] 539 | self.idType.add_def_var({name: const.type}, domain, None, None) 540 | elif child.type == 'function_definition': # 获取函数内部的变量类型 541 | func_node = child 542 | func_info = self.function.add_func(func_node, self.file_path) 543 | if not func_info: 544 | continue 545 | for name, type in func_info.signature['parameters']: 546 | domain = [func_node.start_point[0] + 1, func_node.end_point[0] + 1] 547 | self.idType.add_def_var({name: type}, domain, None, None) 548 | body = func_node.child_by_field_name('body') 549 | self.get_local_type(body) # 获取函数复合语句内部的变量类型 550 | # self.construct_call_graph() 551 | # self.query_type(self.root_node) 552 | # print(self.idType) 553 | # print(self.structure_) 554 | 555 | def construct_call_graph(self): 556 | '''构建函数调用图''' 557 | self.expression = Expression(self.idType, self.structure_, self.function) # 定义表达式类,用来求表达式的类型 558 | def query_callee(node, caller): 559 | '''目的-查询调用者caller调用的函数callee''' 560 | if node.type == 'call_expression': 561 | callee = text(node.child_by_field_name('function')) 562 | arguments = node.child_by_field_name('arguments') 563 | if callee in self.function(): 564 | match_func = self.function.match_func(callee, arguments, self.expression) 565 | if match_func: 566 | self.CG.setdefault(match_func.id, set()) 567 | self.CG[match_func.id].add(caller) 568 | else: # 不是自己定义的函数 569 | self.unknown_call.setdefault(callee, set()) 570 | self.unknown_call[callee].add(caller) 571 | if callee not in self.unknown_id: 572 | self.unknown_id[callee] = -len(self.unknown_id) - 1 573 | for child in node.children: 574 | query_callee(child, caller) 575 | for func_name in self.function(): # 变量文件内的所有函数,查询各个函数内部调用的函数 576 | for func in self.function[func_name]: 577 | query_callee(func.func_node, func.id) 578 | 579 | def get_local_type(self, node): 580 | '''获取函数内部变量名的类型''' 581 | if node.type not in ['compound_statement', 'else_clause']: # 如果不是复合语句或者else从句,直接pass 582 | return 583 | for statement in node.children: # 遍历复合语句(花括号)内的每一个语句 584 | if statement.type == 'compound_statement': 585 | self.get_local_type(statement) 586 | if statement.type == 'declaration': # 如果是声明语句,获取声明的变量名和类型 587 | declaration = Declaration(statement) 588 | domain = [statement.start_point[0] + 1, node.end_point[0] + 1] 589 | self.idType.add_def_var(declaration(), domain, None, None) 590 | elif statement.type in ['for_statement', 'while_statement', 'do_statement']: # 如果是循环语句,则循环语句内定义的变量的作用域是循环语句内 591 | if statement.type == 'for_statement': # for循环额外有一个初始化表达式 592 | initializer = statement.child_by_field_name('initializer') 593 | if initializer: 594 | declaration = Declaration(initializer) 595 | domain = [statement.start_point[0] + 1, statement.end_point[0] + 1] 596 | self.idType.add_def_var(declaration(), domain, None, None) 597 | self.get_local_type(statement.child_by_field_name('body')) 598 | elif statement.type == 'if_statement': # 如果是if语句,获取if语句内部的变量类型 599 | self.get_local_type(statement.child_by_field_name('consequence')) 600 | alternative = statement.child_by_field_name('alternative') 601 | if alternative: 602 | self.get_local_type(alternative) 603 | 604 | def query_type(self, root_node, id_types={}): 605 | '''以root_node节点为根节点,往下遍历AST树,获取变量名的类型''' 606 | self.expression = Expression(self.idType, self.structure_, self.function) 607 | for child in root_node.children: 608 | if 'expression' in child.type and child.type != 'expression_statement': 609 | exp = self.expression.type(child) 610 | print((text(child), exp)) 611 | self.query_type(child) 612 | return id_types 613 | 614 | def see_cg(self, unknown=True, pdf=True, view=True): 615 | '''可视化函数调用图''' 616 | self.construct_call_graph() 617 | if not self.CG: 618 | return 619 | # dot = Digraph(comment=self.file_path, graph_attr={'rankdir': 'LR'}) 620 | dot = Digraph(comment=self.file_path) 621 | 622 | nodes = set(self.CG.keys()) # 不画没有出边的节点 623 | for edges in self.CG.values(): 624 | nodes |= set(edges) 625 | for edges in self.unknown_call.values(): 626 | nodes |= set(edges) 627 | 628 | for funcs in self.function().values(): 629 | for f in funcs: 630 | if f.id not in nodes: 631 | continue 632 | file_path = f.file_path.replace('\\', '/') 633 | label = f'name: {f.name}\\ntype: {f.type}\\nparameters: {list(f.signature["parameters"])}\\nfile path: {file_path}' 634 | # label = f.name 635 | dot.node(str(f.id), shape='rectangle', label=label, fontname='fangsong') 636 | for caller, callees in self.CG.items(): 637 | for callee in callees: 638 | dot.edge(str(callee), str(caller)) 639 | if unknown: 640 | for func in self.unknown_id: 641 | dot.node(str(self.unknown_id[func]), shape='rectangle', label=func, fontname='fangsong') 642 | for caller, callees in self.unknown_call.items(): 643 | for callee in callees: 644 | dot.edge(str(callee), str(self.unknown_id[caller])) 645 | if pdf: 646 | cur_path = os.path.abspath('.') 647 | root_path = os.path.commonprefix([self.file_path, cur_path]) 648 | rel_path = os.path.relpath(self.file_path, root_path).replace('\\', '/') 649 | output_path = os.path.join(root_path, "pdf", rel_path) 650 | dot.render(output_path, view=view, cleanup=True) 651 | 652 | def merge(self, other): 653 | self.structure_.update(other.structure_) 654 | self.idType.vars.update(other.idType.vars) 655 | for func_name, funcs in other.function().items(): 656 | self.function().setdefault(func_name, []) 657 | for func in funcs: 658 | if func not in self.function()[func_name]: 659 | self.function()[func_name].append(func) 660 | self.function.id_to_func.update(other.function.id_to_func) 661 | 662 | class Dir(AST): 663 | def __init__(self, path, language='cpp'): 664 | super().__init__(language) 665 | self.path = os.path.abspath(path) 666 | self.filepaths = [] 667 | for root, _, files in os.walk(self.path): 668 | for file in files: 669 | if file.split('.')[-1] in ['c', 'cpp', 'h', 'hpp', 'cc']: 670 | filepath = os.path.abspath(os.path.join(root, file)) 671 | relpath = os.path.relpath(filepath, self.path).replace('\\', '/') 672 | self.filepaths.append(relpath) 673 | 674 | self.Include = {f: [] for f in self.filepaths} 675 | self.files = {f: None for f in self.filepaths} 676 | self.Included = {f: [] for f in self.filepaths} 677 | self.Indegree = {f: 0 for f in self.filepaths} # 按照拓扑排序,先把入度为0(没有引用文件或者全部引用完的)的文件的函数加载 678 | self.load = {f: False for f in self.filepaths} 679 | for f in self.filepaths: 680 | self.files[f] = File('c', os.path.join(self.path, f)) 681 | code = r'{}'.format(open(os.path.join(self.path, f), 'r', encoding='utf-8', errors='ignore').read()) 682 | tree = self.parser.parse(bytes(code, 'utf8')) 683 | root_node = tree.root_node 684 | self.find_include(root_node, f) 685 | 686 | cycles = self.detect_cycles() 687 | temp_include = {} # 为了解决循环引用的问题,先把循环引用的文件的第一个引用关系去掉,保存在temp_include中 688 | if cycles: 689 | print("There exists cycles") 690 | for cycle in cycles: 691 | print(cycle[-1], end=' ') 692 | for f in cycle: 693 | print('-> ' + f, end=' ') 694 | self.Include[cycle[0]].remove(cycle[1]) 695 | self.Included[cycle[1]].remove(cycle[0]) 696 | self.Indegree[cycle[0]] -= 1 697 | temp_include[cycle[0]] = cycle[1] 698 | print() 699 | 700 | for f in self.filepaths: 701 | self.files[f].construct_file() 702 | 703 | while True: 704 | finish_include = True 705 | for f in self.files: 706 | if self.Indegree[f] == 0 and not self.load[f]: 707 | finish_include = False 708 | for include_file in self.Include[f]: 709 | self.files[f].merge(self.files[include_file]) 710 | self.Include[f] = [] 711 | self.files[f].construct_file() 712 | self.files[f].construct_call_graph() 713 | for included_file in self.Included[f]: 714 | self.Indegree[included_file] -= 1 715 | self.load[f] = True 716 | break 717 | if finish_include: 718 | if temp_include: 719 | for f, include in temp_include.items(): 720 | self.files[f].merge(self.files[include]) 721 | self.files[f].construct_file() 722 | self.files[f].construct_call_graph() 723 | for f, indegree in self.Indegree.items(): 724 | print(f) 725 | self.files[f].see_cg(view=False, unknown=True) 726 | if indegree != 0: 727 | print(f"\tError: not finish yet!") 728 | break 729 | 730 | def detect_cycles(self): 731 | '''检测self.Include是否有环,例如A引用了B,B引用了C,C引用了A,返回这样的环的列表[[A,B,C],[...]]''' 732 | visited = set() 733 | stack = [] 734 | cycles = [] 735 | def dfs(node): 736 | if node in stack: 737 | cycle = stack[stack.index(node):] 738 | cycles.append(cycle) 739 | return 740 | if node in visited: 741 | return 742 | visited.add(node) 743 | stack.append(node) 744 | for neighbor in self.Include.get(node, []): 745 | dfs(neighbor) 746 | stack.pop() 747 | for node in self.Include: 748 | dfs(node) 749 | return cycles 750 | 751 | def find_include(self, node, f): 752 | for child in node.children: 753 | if child.type == 'preproc_include': 754 | path = child.child_by_field_name('path') 755 | if path.type in ['string_literal', 'system_lib_string']: 756 | # if path.type == 'string_literal': 757 | include = text(path)[1:-1] 758 | if include == 'config.h': # VS的配置头文件忽略 759 | continue 760 | include_path = os.path.normpath(os.path.join(os.path.dirname(f), include)).replace('\\', '/') 761 | if include_path in self.Include and include_path != f: 762 | self.Include[f].append(include_path) 763 | else: 764 | file = os.path.basename(include_path) 765 | is_find = False 766 | for f_ in self.filepaths: 767 | if file == os.path.basename(f_): 768 | include_path = f_ 769 | if f == include_path: 770 | continue 771 | self.Include[f].append(include_path) 772 | is_find = True 773 | break 774 | if not is_find: 775 | # print(f'In file {f}, {include} not in files') 776 | continue 777 | self.Included[include_path].append(f) 778 | self.Indegree[f] += 1 779 | elif child.type == 'preproc_ifdef': # ifdef 780 | self.find_include(child, f) 781 | 782 | if __name__ == '__main__': 783 | # file = File('c', 'test/5/littlefs-master/lfs_util.h') 784 | # file.construct_file() 785 | # file.see_cg(view=True) 786 | dir = Dir('./test/1') 787 | --------------------------------------------------------------------------------