├── .gitignore
├── README.md
├── addr_search.py
├── archived_scripts
│   ├── argument_analyzer.py
│   ├── class_namespace_association.py
│   ├── find_image_base.py
│   ├── find_variable_args_for_func.py
│   ├── function_finder.py
│   ├── function_renamer.py
│   ├── memory_enabled_constant_propogator.py
│   ├── name_functions_from_string_param.py
│   ├── rename_function_from_accessed_strings_guess.py
│   └── vtable_finder.py
├── call_ref_utils.py
├── coverage_highlight.py
├── coverage_visualizer
│   ├── afl_coverage_visualizer.py
│   └── gather_qemu_coverage_data.sh
├── create_vtable.py
├── datatype_utils.py
├── decomp_utils.py
├── dfg_exporter.py
├── find_base_by_refs.py
├── find_ucmp_with_sub.py
├── find_unk_periphs.py
├── find_unknown_pointers.py
├── function_signature_utils.py
├── java_reflection_utils.py
├── loopfinder.py
├── pointer_utils.py
├── print_indexing_locations.py
├── prop_dt.py
├── register_utils.py
├── renamespace.py
├── tag_callback_registration.py
├── test
│   ├── Makefile
│   ├── include
│   │   └── main.h
│   └── src
│       ├── int_under_overflow.c
│       └── main.c
├── type_pointers_to_data.py
└── type_propagator.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.class
2 | *.swp
3 | *.o
4 | *.so
5 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Ghidra Scripts
2 |
3 |
4 | ## API
5 | Ghidra's API is still pretty minimal, so a lot of these scripts just add another layer of API on top of ghidra's existing `FlatProgramAPI` for functionality that I have found to be useful but that isn't easily accessible or that requires a bit more setup than I feel it should. In general, I will try to keep a naming scheme of `*_utils` to mirror the strategy present in a lot of ghidra's code of making `static` `*Utilities` classes, which are generally the most usable and useful parts of ghidra's existing API.
6 |
7 | ### call_ref_utils.py
8 | Utilities for working with call references, mostly for following the call graph or finding callsites through thunks.
9 |
10 |
11 | ### datatype_utils.py
12 | Utilities for finding datatypes, finding datatypes that meet certain constraints, finding datatype usage within other datatypes, and finding field usage across the program as a whole.
13 |
14 | ### decomp_utils.py
15 | Utilities for interacting with ghidra's decompiler and `PCODE`, as well as making associations between disassembled instructions, pcode operations, and decompiled pseudo-C. Also includes some utilities related to forward/backward slicing.
16 |
17 | ### java_reflection_utils.py
18 | Utilities for interacting with `java`'s reflection API through python.
19 |
20 | ### loopfinder.py
21 | Utilities for interacting with loops.
22 |
23 | ### pointer_utils.py
24 | Utilities for searching for embedded addresses or address ranges.
25 |
26 | ## addr_search.py
27 | Search for an embedded address. If ghidra doesn't find a reference to a function and you think it is getting called, run this script on the address to find potential references to it.
28 |
29 | ## A few notes about weirdness in scripts
30 | I try to write just about everything in `python` for these because it is quicker for me to write, but because ghidra uses `Jython`, certain oddities are needed to improve usability or functionality in ways that would not necessarily be needed if I wrote these in `java`.
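Most of these scripts can also be driven interactively. As a rough sketch of that workflow (the address below is only a placeholder), the `pointer_utils` helper that `addr_search.py` wraps can be called straight from the ghidra python interpreter:

```python
# minimal interactive sketch; 0x401000 is a placeholder address
from pointer_utils import createPointerUtils

ptr_util = createPointerUtils()
for match in ptr_util.search_for_pointer(toAddr(0x401000)):
    print(match)
```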
31 |
32 | ### The weird import line
33 | I use the following line in almost all of the scripts, despite it being horrible practice for python:
34 | ```python
35 | from __main__ import *
36 | ```
37 |
38 | This is a hack to make script development easier, as it allows you to do something like `from call_ref_utils import *` from the ghidra python interpreter and have the import work correctly, even if you utilize things that are default imports from `ghidra.program.flatapi.FlatProgramAPI`, like the `currentProgram` variable. I might change this in the future to make the scripts less cursed.
39 |
40 | ### Using Java's Reflection API
41 | Inheriting from `java` classes in python works, but it doesn't work for everything. As I understand it, inheriting from a class in `java` would allow you to access `protected` methods, constructors, and fields. Inheriting from a `java` class in `Jython` does not immediately give you access to `protected` fields, which effectively makes overriding `protected` methods impossible, despite that being perfectly acceptable behavior for a `java` class. To work around this (and to avoid having to write code in `java`), I have utilized java's reflection API to enable this behavior. I try to limit it, but I also don't intend to rewrite ghidra's java classes in python to adjust their behavior if I don't have to.
42 |
43 |
--------------------------------------------------------------------------------
/addr_search.py:
--------------------------------------------------------------------------------
1 | # Search the address space of the current program for a pointer
2 | #@author Clifton Wolfe
3 | #@keybinding ctrl 0
4 | #@category Utils
5 |
6 | from ghidra.program.model.symbol import SourceType
7 | from pointer_utils import createPointerUtils
8 | import logging
9 |
10 | from __main__ import *
11 |
12 | log = logging.getLogger(__file__)
13 | log.addHandler(logging.StreamHandler())
14 | log.setLevel(logging.INFO)
15 |
16 |
17 | selection = state.currentSelection
18 | if selection is None:
19 |     log.debug("No selection detected, asking for address")
20 |     addr = askAddress("Address to search for",
21 |                       "Enter address to search for")
22 | else:
23 |     addr = selection.minAddress
24 |
25 | log.info("[+] Searching for %s", addr)
26 |
27 | ptr_util = createPointerUtils()
28 |
29 | match_addrs = ptr_util.search_for_pointer(addr)
30 | for addr in match_addrs:
31 |     print("%s" % addr)
32 |
--------------------------------------------------------------------------------
/archived_scripts/argument_analyzer.py:
--------------------------------------------------------------------------------
1 | # analyze arguments of function calls
2 | #@author Clifton Wolfe
3 |
4 |
5 | from ghidra.app.decompiler import DecompileOptions
6 | from ghidra.app.decompiler import DecompInterface
7 | from ghidra.util.task import ConsoleTaskMonitor
8 | from ghidra.program.model.pcode import PcodeOpAST
9 | from ghidra.program.flatapi import FlatProgramAPI
10 | from ghidra.python import PythonScript
11 | from ghidra.app.plugin.core.navigation.locationreferences import ReferenceUtils
12 | from ghidra.program.util import FunctionSignatureFieldLocation
13 | from ghidra.program.model.symbol import SourceType
14 | from ghidra.program.model.symbol import FlowType, RefType
15 | from ghidra.app.decompiler.component import DecompilerUtils
16 | from ghidra.program.model.symbol import SourceType
17 | from ghidra.program.model.listing.Function import FunctionUpdateType
18 | from ghidra.program.model.listing import ReturnParameterImpl
19 | from ghidra.program.model.listing import ParameterImpl 20 | from ghidra.program.model.data import PointerDataType 21 | from ghidra.program.database.data import StructureDB 22 | from ghidra.program.database.data import PointerDB 23 | # eventually would like to use this 24 | # from ghidra.app.cmd.function import ApplyFunctionSignatureCmd 25 | from collections import namedtuple, defaultdict 26 | import string 27 | import logging 28 | 29 | log = logging.getLogger(__file__) 30 | log.addHandler(logging.StreamHandler()) 31 | log.setLevel(logging.DEBUG) 32 | 33 | from __main__ import * 34 | 35 | 36 | IncomingCallNode = namedtuple("IncomingCallNode", ["function", "call_address"]) 37 | 38 | 39 | class FunctionArgumentAnalyzer: 40 | def __init__(self, currentProgram): 41 | self.fm = currentProgram.getFunctionManager() 42 | self.dtm = currentProgram.getDataTypeManager() 43 | self.addr_fact = currentProgram.getAddressFactory() 44 | self.addr_space = self.addr_fact.getDefaultAddressSpace() 45 | self.mem = currentProgram.getMemory() 46 | self.sym_tab = currentProgram.getSymbolTable() 47 | 48 | self._decomp_options = DecompileOptions() 49 | self._monitor = ConsoleTaskMonitor() 50 | self._ifc = DecompInterface() 51 | self._ifc.setOptions(self._decomp_options) 52 | self.refman = currentProgram.getReferenceManager() 53 | self.dropped_data_refs = [] 54 | self.dropped_callind_ops = [] 55 | 56 | def get_high_function(self, func, timeout=60): 57 | """ 58 | Get a HighFunction for a given function 59 | """ 60 | self._ifc.openProgram(func.getProgram()) 61 | res = self._ifc.decompileFunction(func, timeout, self._monitor) 62 | high_func = res.getHighFunction() 63 | return high_func 64 | 65 | def get_high_sym_for_param(self, func, param_num): 66 | """ 67 | Get the the high sym for param index 68 | """ 69 | high_func = self.get_high_function(func) 70 | prototype = high_func.getFunctionPrototype() 71 | num_params = prototype.getNumParams() 72 | if num_params == 0: 73 | return None 74 | high_sym_param = prototype.getParam(param_num) 75 | return high_sym_param 76 | 77 | def get_callsites_for_address(self, address): 78 | """ 79 | Iterate over all of the references to an address and pick out the ones 80 | that can be associated with a call to the provided address 81 | """ 82 | log.info("[+] Finding callsites for %s", str(address)) 83 | references = self.refman.getReferencesTo(address) 84 | incoming_calls = [] 85 | for ref in references: 86 | from_address = ref.getFromAddress() 87 | if ref.referenceType == RefType.DATA: 88 | self.dropped_data_refs.append(ref) 89 | log.warning("[-] Dropping a DATA ref at %s", str(from_address)) 90 | continue 91 | elif ref.referenceType == RefType.EXTERNAL_REF: 92 | continue 93 | elif ref.referenceType == FlowType.COMPUTED_CALL: 94 | # FIXME: this potentially introduces duplication of some work here 95 | incoming_calls.extend(self.get_callsites_for_address(from_address)) 96 | continue 97 | 98 | callerFunction = self.fm.getFunctionContaining(from_address) 99 | if callerFunction is None: 100 | log.warning("[-] Drop ref %s at %s" % (str(ref.referenceType), 101 | str(from_address))) 102 | continue 103 | incoming_calls.append(IncomingCallNode(callerFunction, from_address)) 104 | log.info("[+] Found %d callsites", len(incoming_calls)) 105 | return incoming_calls 106 | 107 | def get_pcode_ops_calling_func(self, func): 108 | """ 109 | Get all of the pcode ops that call the function @func 110 | """ 111 | incoming_calls = self.get_callsites_for_address(func.getEntryPoint()) 112 | 
additional_analysis_needed_funcs = set() 113 | incoming_functions = set([i.function for i in incoming_calls]) 114 | func_name = func.getName() 115 | 116 | call_ops = [] 117 | # iterate over functions that call the passed in function 118 | for calling_func_node in incoming_functions: 119 | current_function_name = calling_func_node.getName() 120 | log.info("[+] Identifying call ops in %s", str(current_function_name)) 121 | hf = self.get_high_function(calling_func_node) 122 | if hf is None: 123 | log.warning("[-] Failed to get a High function, unable to decompile") 124 | continue 125 | pcode_ops = list(hf.getPcodeOps()) 126 | func_address = func.getEntryPoint() 127 | 128 | for op in pcode_ops: 129 | if op.opcode == PcodeOpAST.CALLIND: 130 | # log.warning("[*] skipping a CALLIND at %s", str(op.seqnum.target)) 131 | self.dropped_callind_ops.append(op) 132 | continue 133 | if op.opcode != PcodeOpAST.CALL: 134 | continue 135 | 136 | # First input of CALL op is the address being called 137 | called_func_address = op.getInput(0).getAddress() 138 | called_func = getFunctionContaining(called_func_address) 139 | if called_func is None: 140 | log.warning("[-] A CALL op is calling into an undefined function (%s) from (%s)", 141 | str(called_func_address), str(op.seqnum.target)) 142 | continue 143 | 144 | # allow a little wiggle room for thunks by allowing a match by name too 145 | if called_func_address != func_address and called_func.getName() != func_name: 146 | continue 147 | call_ops.append(op) 148 | 149 | 150 | if len(call_ops) == 0: 151 | # if no call was found, it was an indirect reference 152 | log.warning("[-] No call found for %s" % current_function_name) 153 | continue 154 | return call_ops 155 | 156 | def get_pcode_calling_ops_by_func_name(self, name): 157 | """ 158 | Get all of the pcode ops that specify a call to functions named @name 159 | """ 160 | call_ops = [] 161 | for func in self.get_funcs_by_name(name): 162 | call_ops.extend(self.get_pcode_ops_calling_func(func)) 163 | return list(set(call_ops)) 164 | 165 | def get_backslice_ops_for_param_ind(self, call_op, param_ind): 166 | param_def = None 167 | param_varnode = call_op.getInput(param_index+1) 168 | backslice_ops = [] 169 | if param_varnode is None: 170 | return backslice_ops 171 | 172 | backslice_ops = DecompilerUtils.getBackwardSliceToPCodeOps(param_varnode) 173 | if backslice_ops is None: 174 | return [] 175 | return list(backslice_ops) 176 | 177 | 178 | def read_string_at(self, address, maxsize=256): 179 | """ 180 | Tries to extract strings from a binary 181 | """ 182 | while maxsize > 0: 183 | # This is supposed to handle the case of a string being very 184 | # close to the end of a memory region and the maxsize being larger 185 | # than the remainder 186 | try: 187 | string_bytearray = bytearray(getBytes(address, maxsize)) 188 | except: 189 | maxsize -= 1 190 | continue 191 | 192 | terminator_index = string_bytearray.find(b'\x00') 193 | extracted_string_bytes = string_bytearray[:terminator_index] 194 | try: 195 | decoded_extracted_string = extracted_string_bytes.decode() 196 | except: 197 | log.warning("Unable to decode as string") 198 | maxsize -= 1 199 | continue 200 | 201 | return decoded_extracted_string 202 | 203 | return "" 204 | 205 | def is_const_pcode_op(self, op): 206 | """ 207 | Check to see if all of the inputs for an operation are constants 208 | """ 209 | return any([vn for vn in op.getInputs() if not vn.isConstant()]) 210 | 211 | def resolve_pcode_call_parameter_varnode(self, call_op, param_index): 212 | raise 
NotImplementedError("Not fully implemented") 213 | param_varnode = call_op.getInput(param_index+1) 214 | if param_varnode is None: 215 | return 216 | backslice_ops = DecompilerUtils.getBackwardSliceToPCodeOps(param_varnode) 217 | if backslice_ops is None: 218 | backslice_ops = [] 219 | 220 | backslice_ops = list(backslice_ops) 221 | # check for empty list 222 | if not backslice_ops: 223 | # this means that there was a varnode created, it just wasn't 224 | # used in any ops. Happens with unmodified params 225 | # and const params (at least on i386) 226 | backslice = DecompilerUtils.getBackwardSlice(param_varnode) 227 | # FIXME: actually need to try to identify which varnode it is 228 | return backslice[0] 229 | 230 | 231 | # TODO: actually resolve the varnode 232 | return 233 | 234 | def get_pcode_for_function(self, func): 235 | """ 236 | Get Pcode ops for the function @func 237 | """ 238 | hf = self.get_high_function(func) 239 | return list(hf.getPcodeOps()) 240 | 241 | def get_data_accesses_from_function(self, func): 242 | pcode_ops = self.get_pcode_for_function(func) 243 | stackspace_id = self.addr_fact.getStackSpace().spaceID 244 | varnodes = set(sum([[op.getOutput()] + list(op.getInputs()) for op in pcode_ops], [])) 245 | # filter out the majority of nodes that are known to be out 246 | varnodes = [i for i in varnodes if i is not None and i.getSpace() != stackspace_id] 247 | # get all of the offsets that are within current addressSpace 248 | valid_data_addresses = [] 249 | for node in varnodes: 250 | addr = self.addr_space.getAddress(node.getOffset()) 251 | if self.mem.contains(addr): 252 | valid_data_addresses.append(addr) 253 | return valid_data_addresses 254 | 255 | def _get_pcode_op_copy_operand(self, pcode_ops, ptrsub_op): 256 | """ 257 | NOTE: Currently unused 258 | An initial attempt at a custom backslice that uses the stackspace of 259 | a function to identify sources and sinks for varnodes. 260 | """ 261 | if ptrsub_op.opcode == PcodeOpAST.COPY: 262 | return [self.addr_space.getAddress(ptrsub_op.getInput(0).getOffset())] 263 | 264 | non_register_varnode = [i for i in ptrsub_op.getInputs() if not i.isRegister()][0] 265 | stack_offset = non_register_varnode.offset 266 | stackspace_id = self.addr_fact.getStackSpace().spaceID 267 | copied_values = [] 268 | for op in pcode_ops: 269 | output = op.output 270 | if output is None: 271 | continue 272 | if output.offset != stack_offset: 273 | continue 274 | if output.getSpace() != stackspace_id: 275 | continue 276 | if op.opcode not in [PcodeOpAST.COPY]: 277 | continue 278 | # print("found one %s" % str(op)) 279 | string_address = self.addr_space.getAddress(op.getInput(0).getOffset()) 280 | copied_values.append(string_address) 281 | return copied_values 282 | 283 | def get_funcs_by_name(self, name): 284 | return [i for i in self.fm.getFunctions(1) if i.name == name] 285 | 286 | def has_complex_backslice(self, varnode): 287 | """ 288 | Try to determine whether or not the backslice for the given varnode 289 | is complex or not. Complex is arbitrarily defined as "whether or not 290 | this code can figure out all of the possible values for it" 291 | """ 292 | # a constant varnode should always be simple 293 | if varnode.isConstant(): 294 | return False 295 | 296 | backslice = DecompilerUtils.getBackwardSlice(varnode) 297 | if backslice is None: 298 | backslice = [] 299 | 300 | backslice = list(backslice) 301 | # check for empty list 302 | if not backslice: 303 | log.error("[!] 
There were no varnodes found for a backwards slice") 304 | return True 305 | 306 | 307 | for vn in backslice: 308 | # FIXME: This check is insufficient 309 | if vn.isRegister(): 310 | return True 311 | 312 | return False 313 | 314 | def filter_calls_with_simple_param(self, call_ops, param_index): 315 | """ 316 | Given a list of call_ops for a function, return the ones that 317 | have a sufficiently complex-enough backslice 318 | """ 319 | filtered_ops = [] 320 | for op in call_ops: 321 | varnode = op.getInput(param_index+1) 322 | if not self.has_complex_backslice(varnode): 323 | continue 324 | filtered_ops.append(op) 325 | return filtered_ops 326 | 327 | def get_descendant_called_funcs(self, func): 328 | """ 329 | Attempt to get every function that the specified function @func 330 | calls both directly and indirectly 331 | """ 332 | visited_functions = set() 333 | to_visit_stack = set([func]) 334 | call_ops_set = set([PcodeOpAST.CALLIND, PcodeOpAST.CALL]) 335 | while to_visit_stack: 336 | fun = to_visit_stack.pop() 337 | # immediately add it to visited so that it 338 | # can't be processed twice 339 | visited_functions.add(fun) 340 | pcode_ops = self.get_pcode_for_function(fun) 341 | call_ops = [i for i in pcode_ops if i.opcode in call_ops_set] 342 | for op in call_ops: 343 | if op.opcode == PcodeOpAST.CALLIND: 344 | # TODO: handle callinds 345 | continue 346 | if op.opcode != PcodeOpAST.CALL: 347 | raise NotImplementedError("Unhandled opcode encountered") 348 | 349 | # inp 0 is called addr 350 | called_addr = op.getInput(0).getAddress() 351 | called_func = getFunctionContaining(called_addr) 352 | if called_func is None: 353 | log.warning("Call to unknown function %s -> %s", 354 | op.seqnum.target, 355 | called_addr) 356 | continue 357 | if called_func not in visited_functions: 358 | to_visit_stack.add(called_func) 359 | return list(visited_functions) 360 | 361 | def get_called_funcs_with_param_as_argument(self, func, param_ind): 362 | """ 363 | Identify any function calls in the specified function @func that 364 | take the parameter at @param_ind (or an equivalent) as a parameter. 365 | """ 366 | 367 | high_func = self.get_high_function(func) 368 | proto = high_func.getFunctionPrototype() 369 | num_params = proto.getNumParams() 370 | param_high_sym = proto.getParam(param_ind) 371 | datatype = param_high_sym.getDataType() 372 | high_var = param_high_sym.getHighVariable() 373 | param_varnodes = set(high_var.getInstances()) 374 | pcode_ops = self.get_pcode_for_function(func) 375 | 376 | # TODO: this might be best as its own function 377 | # NOTE: a call to insertPtrsubZero in ActionSetCasts::castInput 378 | # NOTE: in Ghidra/Features/Decompiler/src/decompile/cpp/coreaction.cc 379 | # NOTE: can add an extra PTRSUB(vn, 0) op to the pcode when a struct 380 | # NOTE: ptr is passed to a function and the first field of the struct 381 | # NOTE: has the same "metatype" as the function's parameter, 382 | # NOTE: so long as the type of the struct doesn't match the type of 383 | # NOTE: the parameter, 384 | # NOTE: because in optimized code the two different actions look 385 | # NOTE: the same, and it clarifies something that the correct 386 | # NOTE: pcode op (a `CAST`) would not effectively communicate. 387 | # NOTE: Unfortunately, this also slightly breaks analysis here. 
388 | 389 | # collect all of the varnodes that are the param or a direct 390 | # cast/copy of it 391 | added_varnode = True 392 | while added_varnode is True: 393 | added_varnode = False 394 | for op in pcode_ops: 395 | if op.opcode not in [PcodeOpAST.CAST, PcodeOpAST.COPY, 396 | PcodeOpAST.PTRSUB]: 397 | continue 398 | # Checking for a few things that would indicate that a 399 | # PTRSUB op was added (see note above). Should that 400 | # optimization be removed upstream then this can be cut 401 | if op.opcode == PcodeOpAST.PTRSUB: 402 | offset_vn = op.getInput(1) 403 | if not offset_vn.isConstant(): 404 | continue 405 | if offset_vn.getOffset() != 0: 406 | continue 407 | if not isinstance(datatype, PointerDataType): 408 | continue 409 | pointed_to_dt = datatype.getDataType() 410 | if not isinstance(pointed_to_dt, StructureDB): 411 | continue 412 | first_field = pointed_to_dt.getComponentAt(0) 413 | first_field_dt = first_field.dataType 414 | # TODO: this comparison should actually be against the 415 | # TODO: datatype of the parameter this value is being 416 | # TODO: `CAST`'d into, but that will get very expensive 417 | # TODO: very quickly, so focusing on the specific case 418 | # TODO: that this is meant to handle 419 | if not isinstance(first_field_dt, PointerDB): 420 | continue 421 | 422 | # only one input possible for COPY and CAST, PTRSUB should 423 | # always have the varnode that is variable as the first 424 | inp = op.getInput(0) 425 | outp = op.getOutput() 426 | if inp in param_varnodes and outp not in param_varnodes: 427 | param_varnodes.add(outp) 428 | added_varnodes = True 429 | # don't break here so that all of the ops before this 430 | # in the list don't have to be re-checked until 431 | # the next pass 432 | 433 | functions_taking_param_as_argument = set() 434 | for op in pcode_ops: 435 | # TODO: probably need to handle CALLIND here 436 | if op.opcode != PcodeOpAST.CALL: 437 | continue 438 | inputs_raw = list(op.getInputs()) 439 | called_addr_varnode = inputs_raw[0] 440 | # skip the first input because it is the call address 441 | inputs = inputs_raw[1:] 442 | 443 | # TODO: determine if this actually saves any time 444 | # filter out call ops that don't have a param varnode 445 | input_set = set(inputs) 446 | if not param_varnodes.intersection(input_set): 447 | continue 448 | 449 | for ind, param_inp in enumerate(inputs): 450 | if param_inp not in param_varnodes: 451 | continue 452 | 453 | called_func = getFunctionContaining(called_addr_varnode.getAddress()) 454 | if called_func is None: 455 | log.warning("Calling a function that isn't defined") 456 | continue 457 | functions_taking_param_as_argument.add((called_func, ind)) 458 | return list(functions_taking_param_as_argument) 459 | 460 | def realize_func_sig_from_op(self, call_op): 461 | """ 462 | Initial attempt 463 | @call_op PcodeOpAST CALL op 464 | compare the number of arguments between a decompiled 465 | function and the function before decompilation. 
466 | This is kind of a discount ApplyFunctionSignatureCmd 467 | """ 468 | 469 | if call_op.opcode != PcodeOpAST.CALL: 470 | return 471 | 472 | func_addr = call_op.getInput(0).getAddress() 473 | func = getFunctionContaining(func_addr) 474 | if func is None: 475 | log.warning("Call to non-existant function %s -> %s", 476 | call_op.seqnum.target, 477 | func_addr) 478 | return 479 | log.info("Checking argument count for %s", func.getName()) 480 | 481 | # remove one input for called address 482 | op_arg_count = call_op.getNumInputs() - 1 483 | expected_param_count = func.getParameterCount() 484 | high_func = self.get_high_function(func) 485 | proto = high_func.getFunctionPrototype() 486 | proto_param_count = proto.getNumParams() 487 | # TODO: handle return type differences 488 | needs_return_fixup = False 489 | if func.returnType != proto.returnType: 490 | needs_return_fixup = True 491 | 492 | if expected_param_count >= proto_param_count: 493 | return 494 | log.info("%s expected param %d proto param %d", 495 | func.getName(), 496 | expected_param_count, 497 | proto_param_count) 498 | 499 | # TODO: handle additional arguments passed into the call 500 | # TODO: in the call_op 501 | num_params_to_use = proto_param_count 502 | 503 | params = [] 504 | for i in range(proto_param_count): 505 | high_sym = proto.getParam(i) 506 | param_def = ParameterImpl( 507 | high_sym.getName(), 508 | high_sym.getDataType(), 509 | currentProgram) 510 | params.append(param_def) 511 | 512 | # TODO: dynamically choose return type 513 | return_param = ReturnParameterImpl(proto.getReturnType(), 514 | currentProgram) 515 | func.updateFunction(func.callingConventionName, 516 | return_param, params, 517 | FunctionUpdateType.DYNAMIC_STORAGE_FORMAL_PARAMS, 518 | False, SourceType.USER_DEFINED) 519 | 520 | def _get_call_ops_for_descendant_funcs(self, func): 521 | desc_funcs = self.get_descendant_called_funcs(func) 522 | all_call_ops = [] 523 | for func in desc_funcs: 524 | try: 525 | pcode_ops = self.get_pcode_for_function(func) 526 | except: 527 | log.warning("Unable to get pcode for %s" % func.name) 528 | continue 529 | for op in pcode_ops: 530 | if op.opcode == PcodeOpAST.CALL: 531 | all_call_ops.append(op) 532 | return all_call_ops 533 | 534 | def realize_func_sig_for_descendant_funcs(self, func): 535 | call_ops = self._get_call_ops_for_descendant_funcs(func) 536 | 537 | grouped_ops = defaultdict(list) 538 | for op in call_ops: 539 | addr = op.getInput(0).getAddress() 540 | grouped_ops[addr].append(op) 541 | 542 | for ops_list in grouped_ops.values(): 543 | # only try to fix the arguments once for each function, 544 | # just in case one function is called many many times 545 | op = ops_list[0] 546 | self.realize_func_sig_from_op(op) 547 | 548 | 549 | def walk_pcode_until_handlable_op(varnode, maxcount=20): 550 | """ 551 | Naiive Backslice-like func that follows varnode definitions 552 | until a knows op is found 553 | """ 554 | param_def = varnode.getDef() 555 | # handling much more than a PTRSUB or COPY will likely require an actually intelligent traversal 556 | # of the pcode ast, if not emulation, as registers are assigned different types 557 | while param_def.opcode not in [PcodeOpAST.PTRSUB, PcodeOpAST.COPY] and maxcount > 0: 558 | if param_def.opcode == PcodeOpAST.CAST: 559 | varnode = param_def.getInput(0) 560 | else: 561 | varnode = param_def.getInput(1) 562 | param_def = varnode.getDef() 563 | maxcount -= 1 564 | 565 | return param_def 566 | 567 | 568 | 
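# Rough usage sketch from the ghidra python interpreter (not part of the
# original script): the function name "memcpy" and parameter index 2 are
# placeholders; the calls mirror how find_variable_args_for_func.py drives
# this module.
# from argument_analyzer import FunctionArgumentAnalyzer
# aa = FunctionArgumentAnalyzer(currentProgram)
# call_ops = aa.get_pcode_calling_ops_by_func_name("memcpy")
# interesting_ops = aa.filter_calls_with_simple_param(call_ops, 2)
# for op in interesting_ops:
#     print("%s in %s" % (op.seqnum.target, getFunctionContaining(op.seqnum.target).name))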
-------------------------------------------------------------------------------- /archived_scripts/class_namespace_association.py: -------------------------------------------------------------------------------- 1 | # Namespace Association and class analysis to improve C++ analysis from 2 | # RecoverClassesFromRTTIScript.java and astrelsky/Ghidra-Cpp-Class-Analyzer. 3 | # The script currently requires one of those to have already been performed, 4 | # or for vtables to have either 'vtable' or 'vftable' in their label. 5 | # 6 | # The script is also a work in progress, so there will likely be improvements 7 | # made, especially when it comes to inheritance structures, call graph 8 | # analysis, and association of functions with a given class namespace. 9 | #@author Clifton Wolfe 10 | #@category C++ 11 | 12 | from collections import defaultdict 13 | from ghidra.app.decompiler import DecompileOptions 14 | from ghidra.app.decompiler import DecompInterface 15 | from ghidra.util.task import ConsoleTaskMonitor 16 | from ghidra.program.flatapi import FlatProgramAPI 17 | from ghidra.python import PythonScript 18 | 19 | # TODO: Try to autofill class structrues based on thisptr 20 | 21 | 22 | class ClassNamespaceAssociator: 23 | def __init__(self, currentProgram): 24 | self.fm = currentProgram.getFunctionManager() 25 | self.dtm = currentProgram.getDataTypeManager() 26 | self.namespace_manager = currentProgram.getNamespaceManager() 27 | self.addr_fact = currentProgram.getAddressFactory() 28 | self.addr_space = self.addr_fact.getDefaultAddressSpace() 29 | self.mem = currentProgram.getMemory() 30 | self.sym_tab = currentProgram.getSymbolTable() 31 | 32 | self.ptr_size = self.addr_space.getPointerSize() 33 | if self.ptr_size == 4: 34 | self._get_ptr_size = self.mem.getInt 35 | elif self.ptr_size == 8: 36 | self._get_ptr_size = self.mem.getLong 37 | 38 | self._null = self.addr_space.getAddress(0) 39 | self._global_ns = currentProgram.getGlobalNamespace() 40 | 41 | self._thiscall_str = u'__thiscall' 42 | self._vftable_str = u'vftable' 43 | self._vtable_str = u'vtable' 44 | self._pure_virtual_str = u'pure_virtual' 45 | 46 | self._decomp_options = DecompileOptions() 47 | self._monitor = ConsoleTaskMonitor() 48 | self._ifc = DecompInterface() 49 | self._ifc.setOptions(self._decomp_options) 50 | 51 | # lsm = high_func.getLocalSymbolMap() 52 | # symbols = lsm.getSymbols() 53 | self.func_associations = defaultdict(set) 54 | self.multiple_inheritance_functions = defaultdict(set) 55 | # store all found vtables/vftables, regardless of duplicates 56 | self.vftable_entries = {} 57 | # store all functions associated with each namespace, 58 | # removing duplicates 59 | self.namespace_functions = defaultdict(set) 60 | 61 | self.class_syms = defaultdict(list) 62 | self._populate_namespace_associated_symbols() 63 | self.analyze_function_associations() 64 | 65 | def get_high_function(self, func, timeout=60): 66 | """ 67 | Get a HighFunction for a given function 68 | """ 69 | self._ifc.openProgram(func.getProgram()) 70 | res = self._ifc.decompileFunction(func, timeout, self._monitor) 71 | high_func = res.getHighFunction() 72 | return high_func 73 | 74 | def _populate_namespace_associated_symbols(self): 75 | """ 76 | Populate a defaultdict(list) with symbols associated 77 | with each Class namespace that is currently accessible 78 | """ 79 | for namespace in self.sym_tab.getClassNamespaces(): 80 | for s in self.sym_tab.getChildren(namespace.getSymbol()): 81 | self.class_syms[s.getParentSymbol().getName()].append(s) 82 | 83 | 
def analyze_function_associations(self): 84 | """ 85 | Iterate each available class namespace to find labels/symbols 86 | called vtable/vftable, and collect data on which 87 | class namespace each function can be associated with. 88 | """ 89 | # clean up collected vftable entries to remove duplicates 90 | for namespace in self.sym_tab.getClassNamespaces(): 91 | for s in self.sym_tab.getChildren(namespace.getSymbol()): 92 | usable_vtable_symbol = False 93 | if s.name.find(self._vftable_str) != -1: 94 | usable_vtable_symbol = True 95 | 96 | if s.name.find(self._vtable_str) != -1: 97 | usable_vtable_symbol = True 98 | 99 | if not usable_vtable_symbol: 100 | continue 101 | 102 | vftable_entries = self.get_vftable_entries(s) 103 | parent_namespace = s.getParentNamespace() 104 | self.vftable_entries[s] = vftable_entries 105 | # if multiple vtable/vftable symbols are found for the 106 | # same namespace, the association should still be made 107 | # in the odd occurrance that there are multiple separate 108 | # vftables for the same class 109 | for e in vftable_entries: 110 | self.namespace_functions[parent_namespace].add(e) 111 | 112 | for func in vftable_entries: 113 | if func == self._null: 114 | continue 115 | 116 | self.func_associations[func].add(namespace) 117 | 118 | # Find vtable functions that could fit into multiple namespaces 119 | # (this likely means they are inherited) 120 | for func, namespaces in self.func_associations.items(): 121 | if len(namespaces) == 1: 122 | continue 123 | for n in namespaces: 124 | self.multiple_inheritance_functions[func].add(n) 125 | 126 | def set_function_associations(self, skip_thiscall_association=False): 127 | """ 128 | Associate functions that are only ever in a single 129 | class namespace with that namespace, as non-virtual functions. 130 | 131 | The calling convention for functions is also set to `__thiscall` 132 | unless disabled. 
133 | """ 134 | # do the namespace association with each func 135 | for func, namespaces in self.func_associations.items(): 136 | if len(namespaces) == 1: 137 | if not self.is_external(func): 138 | self.set_parent_namespace_maybe_thunk(func, list(namespaces)[0]) 139 | else: 140 | # TODO: try to identify longer inheritance 141 | # structures to pick the 142 | # base class to associate the function with 143 | virtual_namespaces = [n for n in namespaces if self._is_class_namespace_virtual(n)] 144 | # pick off easier associations that are only inherited by 145 | # non-virtual classes 146 | if len(virtual_namespaces) == 1: 147 | if not self.is_external(func): 148 | self.set_parent_namespace_maybe_thunk(func, list(virtual_namespaces)[0]) 149 | 150 | if not skip_thiscall_association: 151 | self.set_calling_convention_maybe_thunk(func, self._thiscall_str) 152 | 153 | # search for and associate private functions with each class namespace 154 | for namespace in self.namespace_functions.keys(): 155 | priv_funcs = self._find_private_function_of_class_namespace(namespace) 156 | for func in priv_funcs: 157 | if not self.is_external(func): 158 | self.set_parent_namespace_maybe_thunk(func, namespace) 159 | 160 | if not skip_thiscall_association: 161 | self.set_calling_convention_maybe_thunk(func, self._thiscall_str) 162 | 163 | def is_external(self, func): 164 | if func == self._null: 165 | return False 166 | dethunked = func 167 | if func.thunk: 168 | dethunked = func.getThunkedFunction(True) 169 | 170 | return dethunked.external 171 | 172 | def set_calling_convention_maybe_thunk(self, func, calling_convention): 173 | while func.thunk: 174 | func.setCallingConvention(calling_convention) 175 | func = func.getThunkedFunction(False) 176 | func.setCallingConvention(calling_convention) 177 | 178 | def set_parent_namespace_maybe_thunk(self, func, namespace): 179 | while func.thunk: 180 | func.setParentNamespace(namespace) 181 | func = func.getThunkedFunction(False) 182 | 183 | if not func.external: # cant reparent external function 184 | func.setParentNamespace(namespace) 185 | 186 | def get_vftable_entries(self, vftable): 187 | """ 188 | Get a list of function pointer entries for a given vftable/vtable. 189 | The returned list may also include Address(0) (NULL) entries, 190 | as the vtable likely includes an uninitialized function poiner. 191 | """ 192 | vftable_addr = vftable.getAddress() 193 | if vftable.name.find(self._vtable_str) != -1: 194 | vftable_addr = vftable_addr.add(self.ptr_size*2) 195 | addr = vftable_addr 196 | funcs = [] 197 | while True: 198 | maybe_func_addr_val = self._get_ptr_size(addr) 199 | maybe_func_addr = self.addr_space.getAddress(maybe_func_addr_val) 200 | func = self.fm.getFunctionAt(maybe_func_addr) 201 | # if the ptr is a null ptr (uninitialized) or points 202 | # to a function, add it. Otherwise, assume the vtable is done 203 | if func is None: 204 | if maybe_func_addr_val == 0: 205 | func = self.addr_space.getAddress(0) 206 | else: 207 | break 208 | funcs.append(func) 209 | addr = addr.add(self.ptr_size) 210 | return funcs 211 | 212 | def get_datatype_of_thisptr(self, func): 213 | """ 214 | Get the datatype of a Function's `this` pointer. 
215 | """ 216 | if func.getCallingConvention().name != self._thiscall_str: 217 | return None 218 | high_func = self.get_high_function(func) 219 | prot = high_func.getFunctionPrototype() 220 | num_params = prot.getNumParams() 221 | if num_params == 0: 222 | return None 223 | maybe_this = prot.getParam(0) 224 | return maybe_this.getDataType() 225 | 226 | def _is_class_namespace_virtual(self, namespace): 227 | """ 228 | Try to guess if a class namespace is virtual by looking for 229 | 'pure_virtual' in the function names associated with its vtable 230 | """ 231 | funcs = self.namespace_functions.get(namespace, []) 232 | for func in funcs: 233 | if func == self._null: 234 | continue 235 | 236 | if func.name.find(self._pure_virtual_str) != -1: 237 | return True 238 | 239 | return False 240 | 241 | def _find_private_function_of_class_namespace(self, namespace): 242 | """ 243 | Check each function that is called by the vtable functions of 244 | the given namespace to see if it is only called by functions 245 | in that vtable or their decendant functions. 246 | 247 | TODO: The graph traversal portion of this function is likely 248 | better implemented in java, so decendant call analysis is not 249 | implemented yet. Instead it is just a single call down for now. 250 | 251 | TODO: May need to check references to the called function as well 252 | 253 | TODO: handle thunks 254 | """ 255 | namespace_functions = self.namespace_functions.get(namespace, set()) 256 | found_private_functions = set() 257 | for func in namespace_functions: 258 | if func == self._null: 259 | continue 260 | called_functions = func.getCalledFunctions(self._monitor) 261 | for called_func in called_functions: 262 | calling_functions = set(called_func.getCallingFunctions(self._monitor)) 263 | if calling_functions.issubset(namespace_functions): 264 | found_private_functions.add(called_func) 265 | 266 | return found_private_functions 267 | 268 | 269 | # from class_mapper import ClassNamespaceAssociator 270 | # ca = ClassNamespaceAssociator(currentProgram) 271 | # ca.set_function_associations() 272 | 273 | if __name__ == '__main__': 274 | ca = ClassNamespaceAssociator(currentProgram) 275 | ca.set_function_associations() 276 | print("Done Running!") 277 | 278 | # funcs = ca.get_vftable_entries(cm.class_syms[u'ActiveLoggerImpl'][2]) 279 | # func = funcs[0] 280 | # datatype = ca.get_datatype_of_thisptr(func) 281 | # base_datatype_name = datatype.displayName.replace(' *', '') 282 | # [b] = [i for i in ca.dtm.getAllStructures() if i.getName() == base_datatype_name] 283 | -------------------------------------------------------------------------------- /archived_scripts/find_image_base.py: -------------------------------------------------------------------------------- 1 | # Idempotent script to try to find the base and regions 2 | # of firmware images 3 | # 4 | #@author Clifton Wolfe 5 | #@category C++ 6 | 7 | from collections import defaultdict 8 | from ghidra.app.decompiler import DecompileOptions 9 | from ghidra.app.decompiler import DecompInterface 10 | from ghidra.util.task import ConsoleTaskMonitor 11 | from ghidra.program.flatapi import FlatProgramAPI 12 | from ghidra.python import PythonScript 13 | from ghidra.app.util import MemoryBlockUtils 14 | from ghidra.program.model.address import AddressSet 15 | from ghidra.program.model.pcode import PcodeOpAST 16 | from ghidra.program.model.address import GenericAddress, Address 17 | from ghidra.program.model.symbol import FlowType 18 | import sys 19 | 20 | 21 | class DisassemblyHelper: 22 
| def __init__(self, currentProgram): 23 | self.fm = currentProgram.getFunctionManager() 24 | self.dtm = currentProgram.getDataTypeManager() 25 | self.addr_fact = currentProgram.getAddressFactory() 26 | self.default_addr_space = self.addr_fact.getDefaultAddressSpace() 27 | self.mem = currentProgram.getMemory() 28 | self.sym_tab = currentProgram.getSymbolTable() 29 | 30 | self.ptr_size = self.default_addr_space.getPointerSize() 31 | if self.ptr_size == 4: 32 | self._get_ptr_size = self.mem.getInt 33 | elif self.ptr_size == 8: 34 | self._get_ptr_size = self.mem.getLong 35 | 36 | self._null = self.default_addr_space.getAddress(0) 37 | self._decomp_options = DecompileOptions() 38 | self._monitor = ConsoleTaskMonitor() 39 | self._ifc = DecompInterface() 40 | self._ifc.setOptions(self._decomp_options) 41 | 42 | def get_high_function(self, func, timeout=60): 43 | """ 44 | Get a HighFunction for a given function 45 | """ 46 | self._ifc.openProgram(func.getProgram()) 47 | res = self._ifc.decompileFunction(func, timeout, self._monitor) 48 | high_func = res.getHighFunction() 49 | return high_func 50 | 51 | def get_pcode_ops_for_func(self, func): 52 | """ 53 | Get a list of the PCODE ops for a function 54 | """ 55 | hf = self.get_high_function(func) 56 | if hf is None: 57 | return [] 58 | return hf.getPcodeOps() 59 | 60 | # TODO: RegionFinder 61 | # Track non-stack accesses/calls in rom to addresses that aren't mapped, 62 | # identify minimums and maximums. 63 | # group the accesses by page, 64 | # 65 | # 66 | # Identifying possible base address 67 | # While invalid memory accesses could indicate either the expected load 68 | # address of the binary or the address of mmio/ something else, 69 | # invalid call addresses indicate executable memory. Executable memory 70 | # cannot be uninitialized and work as intended, so one of the following 71 | # is likely expected to be in that location: 72 | # - this binary (or an in memory copy of it), at a different 73 | # base address. 
This might be the case if the binary 74 | # isn't PIE or has an entirely separate copy of itself 75 | # for a data segment 76 | # - a separate, uknown binary, like something placed into memory by 77 | # a different chip 78 | # - something like JIT code/self modifying code, placed there at 79 | # runtime 80 | # 81 | # This inference can be used as a data point and an attempt can be made 82 | # to identify if the called addresses actually correctly look like they 83 | # could be calls to this binary, just in a different location in memory 84 | # 85 | # 86 | 87 | 88 | class MemoryRegionFinder: 89 | def __init__(self, currentProgram, page_size=0x1000, 90 | error_thresh_pages=8): 91 | self.currentProgram = currentProgram 92 | self.dh = DisassemblyHelper(currentProgram) 93 | self.page_size = page_size 94 | # TODO: make sure jython doesn't mess with this like it 95 | # does with ctypes 96 | self._page_mask = sys.maxsize ^ (page_size - 1) 97 | self._page_addrs = [] 98 | self._invalid_calls = [] 99 | self._invalid_accesses = [] 100 | 101 | self._error_thres = self.page_size*error_thresh_pages 102 | self._find_invalid_accesses() 103 | self._consolidated_pages = self._get_addr_set_pages(self._invalid_accesses) 104 | 105 | def _find_invalid_accesses(self): 106 | """ 107 | Search for referenced addresses outside of the currently 108 | defined memory space 109 | https://gist.github.com/starfleetcadet75/cdc512db77d7f1fb7ef4611c2eda69a5 110 | """ 111 | listing = self.currentProgram.getListing() 112 | mem = self.dh.mem 113 | monitor = self.dh._monitor 114 | invalid_accesses = set() 115 | invalid_call_addrs = set() 116 | for instr in listing.getInstructions(1): 117 | if monitor.isCancelled(): 118 | break 119 | for ref in instr.getReferencesFrom(): 120 | to_addr = ref.getToAddress() 121 | if mem.contains(to_addr) or \ 122 | to_addr.isStackAddress() or \ 123 | to_addr.isRegisterAddress(): 124 | continue 125 | 126 | reftype = ref.getReferenceType() 127 | if reftype in [FlowType.UNCONDITIONAL_JUMP, 128 | FlowType.UNCONDITIONAL_CALL]: 129 | invalid_call_addrs.add(to_addr) 130 | else: 131 | invalid_accesses.add(to_addr) 132 | 133 | self._invalid_calls = list(invalid_call_addrs) 134 | self._invalid_calls.sort() 135 | self._invalid_accesses = list(invalid_accesses) 136 | self._invalid_accesses.sort() 137 | 138 | def _get_addr_set_pages(self, addrs): 139 | """ 140 | Consolidate the addresses in a list of addresses 141 | and return an address set that specifies pages instead 142 | of individual addresses 143 | Note: only uses minAddress, not maxAddress 144 | """ 145 | self._page_addrs = [i.getOffset() & self._page_mask 146 | for i in addrs] 147 | self._page_addrs = list(set(self._page_addrs)) 148 | self._page_addrs.sort() 149 | if len(self._page_addrs) == 0: 150 | return addrs 151 | 152 | # #speedhack 153 | create_new_addr_func = self.dh.default_addr_space.getAddress 154 | new_addr_set = AddressSet() 155 | region = [] 156 | last_addr = None 157 | for addr in self._page_addrs: 158 | if last_addr is None or \ 159 | (last_addr + self._error_thres) >= addr: 160 | region.append(addr) 161 | last_addr = addr 162 | continue 163 | 164 | min_addr = create_new_addr_func(region[0]) 165 | max_addr_val = region[-1] + self.page_size - 1 166 | max_addr = create_new_addr_func(max_addr_val) 167 | new_addr_set.addRange(min_addr, max_addr) 168 | region = [] 169 | last_addr = None 170 | 171 | if region: 172 | min_addr = create_new_addr_func(region[0]) 173 | max_addr_val = region[-1] + self.page_size - 1 174 | max_addr = 
create_new_addr_func(max_addr_val) 175 | new_addr_set.addRange(min_addr, max_addr) 176 | 177 | return new_addr_set 178 | 179 | 180 | # if __name__ == '__main__': 181 | # rbf = RomBaseFinder(currentProgram) 182 | # from find_image_base import * 183 | # dh = DisassemblyHelper(currentProgram) 184 | -------------------------------------------------------------------------------- /archived_scripts/find_variable_args_for_func.py: -------------------------------------------------------------------------------- 1 | 2 | from ghidra.program.model.pcode import PcodeOpAST 3 | from ghidra.program.flatapi import FlatProgramAPI 4 | from ghidra.python import PythonScript 5 | from argument_analyzer import * 6 | import logging 7 | 8 | from __main__ import * 9 | 10 | 11 | func_name = askString("Function Name", "Enter Function Name") 12 | param_ind = askInt("Parameter Index", "Enter parameter index (indexed from 0)") 13 | 14 | aa = FunctionArgumentAnalyzer(currentProgram) 15 | call_ops_for_target = aa.get_pcode_calling_ops_by_func_name(func_name) 16 | complex_call_ops = aa.filter_calls_with_simple_param(call_ops_for_target, param_ind) 17 | 18 | complex_call_ops.sort(key=lambda op: getFunctionContaining(op.seqnum.target).name) 19 | 20 | for call_op in complex_call_ops: 21 | caller_func = getFunctionContaining(call_op.seqnum.target) 22 | print("%s in %s" % (call_op.seqnum.target, caller_func.name)) 23 | -------------------------------------------------------------------------------- /archived_scripts/function_finder.py: -------------------------------------------------------------------------------- 1 | # Search for function prologues by providing a set of bytes to search 2 | # for 3 | #@author Clifton Wolfe 4 | 5 | import ghidra 6 | import re 7 | import binascii 8 | from __main__ import * 9 | 10 | def create_character_class_byte_range(start, end): 11 | """ 12 | Create a pre-escaped byte pattern that will work in re 13 | """ 14 | return b"\\%s-\\%s" % (bytearray([start]), bytearray([end])) 15 | 16 | 17 | def _gen_xtensa_entry_pattern(): 18 | """ 19 | Create a pattern for the xtensa ENTRY instruction 20 | """ 21 | character_class_inner = b'' 22 | start_add, end_add = 1, 0xf 23 | for i in range(0, 256, 0x10): 24 | new_range = create_character_class_byte_range(start_add+i, 25 | end_add+i) 26 | character_class_inner += new_range 27 | 28 | pattern = b'\x36[%s].' 
% character_class_inner 29 | return pattern 30 | 31 | 32 | def xtensa_entry_rexp_provider(): 33 | rexp = re.compile(_gen_xtensa_entry_pattern(), 34 | re.DOTALL | re.MULTILINE) 35 | return rexp 36 | 37 | 38 | def single_pattern_rexp_provider(): 39 | # using askBytes can result in an array containing a signed int, which 40 | # can't be processed correctly as a byte value 41 | byte_vals_from_user = askString("Enter bytes that mark a function entry", 42 | "search") 43 | byte_vals = binascii.unhexlify(byte_vals_from_user.replace(' ', '')) 44 | 45 | # python 2 requires bytearray to change to actual bytes 46 | byte_pattern = bytes(bytearray(list(byte_vals))) 47 | escaped_byte_pattern = re.escape(byte_pattern) 48 | 49 | byte_rexp = re.compile(escaped_byte_pattern, 50 | re.MULTILINE | re.DOTALL) 51 | 52 | 53 | def func_search(rexp_provider): 54 | byte_rexp = rexp_provider() 55 | 56 | memory_blocks = list(getMemoryBlocks()) 57 | 58 | # maybe add a filter here 59 | search_memory_blocks = memory_blocks 60 | 61 | for m_block in search_memory_blocks: 62 | if not m_block.isInitialized(): 63 | continue 64 | region_start = m_block.getStart() 65 | region_start_int = region_start.getOffset() 66 | search_bytes = getBytes(region_start, m_block.getSize()) 67 | iter_gen = re.finditer(byte_rexp, search_bytes) 68 | for m in iter_gen: 69 | addr = region_start.add(m.start()) 70 | func = getFunctionContaining(addr) 71 | if func is not None: 72 | continue 73 | disassemble(addr) 74 | createFunction(addr, "FUN_%s" % str(addr)) 75 | 76 | 77 | if __name__ == "__main__": 78 | func_search(single_pattern_rexp_provider) 79 | -------------------------------------------------------------------------------- /archived_scripts/function_renamer.py: -------------------------------------------------------------------------------- 1 | # Auto-rename functions across a file based on the string passed to a specific function. 2 | # It should be noted that the script only works for functions whose names start with `FUN_`, 3 | # to avoid overwriting user-named functions. 4 | # 5 | # It should also be noted that the script will only work if the parameter type has been 6 | # set correctly in the target function's signature. E.g. change `undefined8` to `char *`. 7 | # 8 | # The script Is meant to be a quick and easy solution, and it does not actually emulate or 9 | # interpret pcode in a meaningful way, it just tracks writes to register and stack locations 10 | # and relies on the assumption that in c and c++ a given space on the stack should only ever 11 | # be utilized for a single type E.g. a pointer on the stack that is used for a `char *` 12 | # should not ever be used to hold a `uint` unless there is a union containing the two types. 
13 | # Keeping that in mind, the script can and will rename things incorrectly 14 | #@author Clifton Wolfe 15 | #@category C++ 16 | 17 | 18 | from ghidra.app.decompiler import DecompileOptions 19 | from ghidra.app.decompiler import DecompInterface 20 | from ghidra.util.task import ConsoleTaskMonitor 21 | from ghidra.program.model.pcode import PcodeOpAST 22 | from ghidra.program.flatapi import FlatProgramAPI 23 | from ghidra.python import PythonScript 24 | from ghidra.app.plugin.core.navigation.locationreferences import ReferenceUtils 25 | from ghidra.program.util import FunctionSignatureFieldLocation 26 | from ghidra.program.model.symbol import SourceType 27 | from collections import namedtuple 28 | import string 29 | import logging 30 | 31 | log = logging.getLogger(__file__) 32 | log.addHandler(logging.StreamHandler()) 33 | log.setLevel(logging.DEBUG) 34 | 35 | from __main__ import * 36 | 37 | 38 | IncomingCallNode = namedtuple("IncomingCallNode", ["function", "call_address"]) 39 | 40 | 41 | def get_location(func): 42 | return FunctionSignatureFieldLocation(func.getProgram(), 43 | func.getEntryPoint()) 44 | 45 | 46 | class FunctionRenamer: 47 | def __init__(self, currentProgram): 48 | self.fm = currentProgram.getFunctionManager() 49 | self.dtm = currentProgram.getDataTypeManager() 50 | self.namespace_manager = currentProgram.getNamespaceManager() 51 | self.addr_fact = currentProgram.getAddressFactory() 52 | self.addr_space = self.addr_fact.getDefaultAddressSpace() 53 | self.mem = currentProgram.getMemory() 54 | self.sym_tab = currentProgram.getSymbolTable() 55 | # NOTE: A better way to find this register needs to be found 56 | # if it is even still needed 57 | # self._stack_reg_offset = currentProgram.getRegister("sp").getOffset() 58 | 59 | self.ptr_size = self.addr_space.getPointerSize() 60 | if self.ptr_size == 4: 61 | self._get_ptr_size = self.mem.getInt 62 | elif self.ptr_size == 8: 63 | self._get_ptr_size = self.mem.getLong 64 | 65 | self._null = self.addr_space.getAddress(0) 66 | self._global_ns = currentProgram.getGlobalNamespace() 67 | self._decomp_options = DecompileOptions() 68 | self._monitor = ConsoleTaskMonitor() 69 | self._ifc = DecompInterface() 70 | self._ifc.setOptions(self._decomp_options) 71 | self.refman = currentProgram.getReferenceManager() 72 | 73 | def get_high_function(self, func, timeout=60): 74 | """ 75 | Get a HighFunction for a given function 76 | """ 77 | self._ifc.openProgram(func.getProgram()) 78 | res = self._ifc.decompileFunction(func, timeout, self._monitor) 79 | high_func = res.getHighFunction() 80 | return high_func 81 | 82 | def get_high_sym_for_param(self, func, param_num): 83 | """ 84 | Get the the high sym for param index 85 | """ 86 | high_func = self.get_high_function(func) 87 | prototype = high_func.getFunctionPrototype() 88 | num_params = prototype.getNumParams() 89 | if num_params == 0: 90 | return None 91 | high_sym_param = prototype.getParam(param_num) 92 | return high_sym_param 93 | 94 | def get_callsites_for_function(self, func): 95 | location = get_location(func) 96 | references = list(ReferenceUtils.getReferenceAddresses(location, self._monitor)) 97 | incoming_calls = [] 98 | for call_address in references: 99 | self._monitor.checkCanceled() 100 | callerFunction = self.fm.getFunctionContaining(call_address) 101 | if callerFunction is None: 102 | continue 103 | incoming_calls.append(IncomingCallNode(callerFunction, call_address)) 104 | return incoming_calls 105 | 106 | def get_previous_var_stack_offset_for_calling_function(self): 107 | 
pass 108 | 109 | def rename_functions_by_function_call(self, func, param_index, function_name_filter=None): 110 | incoming_calls = self.get_callsites_for_function(func) 111 | additional_analysis_needed_funcs = set() 112 | incoming_functions = set([i.function for i in incoming_calls]) 113 | for calling_func_node in incoming_functions: 114 | current_function_name = calling_func_node.getName() 115 | # calling_func_node = incoming_calls[1] 116 | hf = self.get_high_function(calling_func_node) 117 | pcode_ops = list(hf.getPcodeOps()) 118 | func_address = func.getEntryPoint() 119 | 120 | call_ops = [i for i in pcode_ops if i.opcode == PcodeOpAST.CALL and i.getInput(0).getAddress() == func_address] 121 | if len(call_ops) == 0: 122 | print("no call found for %s" % current_function_name) 123 | continue 124 | # call_op = call_ops[0] 125 | copied_values = [] 126 | param_def = None 127 | for call_op in call_ops: 128 | param_varnode = call_op.getInput(param_index+1) 129 | # check here if param is just the raw address. if not... 130 | try: 131 | param_def = walk_pcode_until_handlable_op(param_varnode) 132 | except Exception as err: 133 | # print(err) 134 | additional_analysis_needed_funcs.add(calling_func_node) 135 | continue 136 | copied_values += self.get_pcode_op_copy_operand(pcode_ops, param_def) 137 | 138 | if param_def is None: 139 | print("skipping %s" % current_function_name) 140 | continue 141 | # print("param def '%s'" % str(param_def)) 142 | # there is a weird roundabout way of looking stuff up here because there is a varnode being compared 143 | # with an arbitrary stackpointer offset 144 | # is_stackpointer_offset = any([i for i in param_def.getInputs() if i.isRegister() and i.getOffset() == self._stack_reg_offset]) 145 | # for whatever reason, the created varnode here gets put into unique space, not stack space, 146 | if len(copied_values) == 0: 147 | print("copied values for %s was empty" % current_function_name) 148 | possible_function_names = [self.read_string_at(i) for i in copied_values] 149 | if function_name_filter is not None: 150 | best_function_name = function_name_filter(possible_function_names) 151 | else: 152 | best_function_name = self.choose_best_function_name(possible_function_names) 153 | # print("best function name %s" % best_function_name) 154 | # TODO: identify whether the `SourceType` of a function's name can be accessed so that names don't get overwritten 155 | if best_function_name is not None and current_function_name != best_function_name and \ 156 | current_function_name.startswith("FUN_"): # so that other user defined function names don't get overwritten 157 | print("changing name from %s to %s" % (current_function_name, best_function_name)) 158 | calling_func_node.setName(best_function_name, SourceType.USER_DEFINED) 159 | 160 | for i in additional_analysis_needed_funcs: 161 | print("\n** %s likely requires manual analysis or decompilation fixups" % i.getName()) 162 | 163 | 164 | def read_string_at(self, address, maxsize=256): 165 | """ 166 | Tries to extract strings from a binary 167 | """ 168 | while maxsize > 0: 169 | # This is supposed to handle the case of a string being very 170 | # close to the end of a memory region and the maxsize being larger 171 | # than the remainder 172 | try: 173 | string_bytearray = bytearray(getBytes(address, maxsize)) 174 | except: 175 | maxsize -= 1 176 | continue 177 | 178 | terminator_index = string_bytearray.find(b'\x00') 179 | extracted_string_bytes = string_bytearray[:terminator_index] 180 | try: 181 | 
decoded_extracted_string = extracted_string_bytes.decode() 182 | except: 183 | log.warning("Unable to decode as string") 184 | break 185 | return decoded_extracted_string 186 | 187 | return "" 188 | 189 | def get_pcode_op_copy_operand(self, pcode_ops, ptrsub_op): 190 | """ 191 | Somewhat naive backslice 192 | """ 193 | if ptrsub_op.opcode == PcodeOpAST.COPY: 194 | return [self.addr_space.getAddress(ptrsub_op.getInput(0).getOffset())] 195 | 196 | non_register_varnode = [i for i in ptrsub_op.getInputs() if not i.isRegister()][0] 197 | stack_offset = non_register_varnode.offset 198 | stackspace_id = self.addr_fact.getStackSpace().spaceID 199 | copied_values = [] 200 | for op in pcode_ops: 201 | output = op.output 202 | if output is None: 203 | continue 204 | if output.offset != stack_offset: 205 | continue 206 | if output.getSpace() != stackspace_id: 207 | continue 208 | if op.opcode not in [PcodeOpAST.COPY]: 209 | continue 210 | # print("found one %s" % str(op)) 211 | string_address = self.addr_space.getAddress(op.getInput(0).getOffset()) 212 | copied_values.append(string_address) 213 | return copied_values 214 | 215 | def sort_function_name_candidates(self, function_name_candidates, allow_unprintable=False): 216 | """ 217 | Return a sorted list of function name candidates, best candidate first 218 | """ 219 | possible_function_names = [i for i in function_name_candidates if i is not None and i.find(" ") == -1] 220 | if len(possible_function_names) == 0: 221 | return None 222 | if len(possible_function_names) == 1: 223 | # list is already "sorted", only one option 224 | return possible_function_names 225 | 226 | printable_set = set(string.printable[:-5]) 227 | punctuation_set = set(['[', '\\', ']', '^', '`', '!', '"', '#', "'", '+', '-', '/', ';', '{', '|', '=', '}', ]) 228 | possible_function_names = list(set(possible_function_names)) 229 | if allow_unprintable is False: 230 | possible_function_names = [i for i in possible_function_names if set(i).issubset(printable_set)] 231 | possible_function_names = [i for i in possible_function_names if len(i) >= 3] 232 | # punctuation_set = set(string.punctuation) 233 | possible_function_names.sort(key=lambda a: ( 234 | not(set(a).issubset(printable_set)), # strings that aren't fully printable get the largest priority penalty 235 | not(len(a) >= 3), # strings that are really short are lowered by a large amount 236 | len([i for i in a if i in punctuation_set]), # lower priority of strings with lots of punctuation 237 | a.find(' ') != -1, # spaces in the string are acceptable, but definitely aren't the 238 | # function name 239 | contains_path_markers(a), # lower priority of strings that are file paths, but use them as 240 | # a last resort 241 | )) 242 | # DEBUG HACK 243 | # for i in range(10 if len(possible_function_names) > 10 else len(possible_function_names)): 244 | # print("%d: %s" % (i, possible_function_names[i])) 245 | return possible_function_names 246 | 247 | def choose_best_function_name(self, function_name_candidates): 248 | function_name_candidates = self.sort_function_name_candidates(function_name_candidates) 249 | return function_name_candidates[0] 250 | 251 | def get_pcode_for_function(self, func): 252 | if isinstance(func, str): 253 | func = [i for i in self.fm.getFunctions(1) if i.getName() == func][0] 254 | hf = self.get_high_function(func) 255 | return list(hf.getPcodeOps()) 256 | 257 | def get_data_accesses_from_function(self, func): 258 | pcode_ops = self.get_pcode_for_function(func) 259 | stackspace_id = self.addr_fact.getStackSpace().spaceID 260 | varnodes = 
set(sum([[op.getOutput()] + list(op.getInputs()) for op in pcode_ops], [])) 261 | # filter out the majority of nodes that are known to be out 262 | varnodes = [i for i in varnodes if i is not None and i.getSpace() != stackspace_id] 263 | # get all of the offsets that are within current addressSpace 264 | valid_data_addresses = [] 265 | for node in varnodes: 266 | addr = self.addr_space.getAddress(node.getOffset()) 267 | if self.mem.contains(addr): 268 | valid_data_addresses.append(addr) 269 | return valid_data_addresses 270 | 271 | def rename_function_from_accessed_strings_guess(self, func): 272 | valid_data_addresses = self.get_data_accesses_from_function(func) 273 | maybe_strings = [self.read_string_at(i) for i in valid_data_addresses] 274 | maybe_strings = [i for i in maybe_strings if i != ''] 275 | chosen_function_name = self.choose_best_function_name(maybe_strings) 276 | func.setName(chosen_function_name, SourceType.USER_DEFINED) 277 | 278 | 279 | 280 | def contains_path_markers(s): 281 | return s.find("\\") != -1 or s.find("/") != -1 282 | 283 | 284 | def walk_pcode_until_handlable_op(varnode, maxcount=20): 285 | param_def = varnode.getDef() 286 | # handling much more than a PTRSUB or COPY will likely require an actually intelligent traversal 287 | # of the pcode ast, if not emulation, as registers are assigned different types 288 | while param_def.opcode not in [PcodeOpAST.PTRSUB, PcodeOpAST.COPY] and maxcount > 0: 289 | if param_def.opcode == PcodeOpAST.CAST: 290 | varnode = param_def.getInput(0) 291 | else: 292 | varnode = param_def.getInput(1) 293 | param_def = varnode.getDef() 294 | maxcount -= 1 295 | 296 | return param_def 297 | 298 | 299 | # from function_renamer import * 300 | # fr = FunctionRenamer(currentProgram) 301 | -------------------------------------------------------------------------------- /archived_scripts/memory_enabled_constant_propogator.py: -------------------------------------------------------------------------------- 1 | # this is taken from one of the issues 2 | # https://github.com/NationalSecurityAgency/ghidra/issues/3581 3 | 4 | from __main__ import * 5 | 6 | from ghidra.app.plugin.core.analysis import ConstantPropagationContextEvaluator 7 | from ghidra.app.script import GhidraScript 8 | from ghidra.program.model.address import Address 9 | from ghidra.program.model.address import AddressSpace 10 | from ghidra.program.model.listing import Function 11 | from ghidra.program.model.listing import Program 12 | from ghidra.program.model.listing import ProgramContext 13 | from ghidra.program.model.listing import Variable 14 | from ghidra.program.model.pcode import Varnode 15 | from ghidra.program.util import ContextEvaluator 16 | from ghidra.program.util import SymbolicPropogator 17 | from ghidra.program.util import VarnodeContext 18 | 19 | from java_reflection_utils import get_accessible_java_field 20 | 21 | 22 | 23 | class MemoryEnabledVarnodeContext(VarnodeContext): 24 | 25 | def __init__(self, program, programContext, spaceProgramContext): 26 | super(MemoryEnabledVarnodeContext, self).__init__(program, programContext, 27 | spaceProgramContext) 28 | # accessing protected fields is a little bit painful 29 | memoryVals_field = get_accessible_java_field(VarnodeContext, "memoryVals") 30 | self.memoryVals = memoryVals_field.get(self) 31 | addrFactory_field = get_accessible_java_field(VarnodeContext, "addrFactory") 32 | self.addrFactory = addrFactory_field.get(self) 33 | 34 | def newGetMemoryValue(self, varnode): 35 | return self.getMemoryValue(varnode) 36 | 37 | 
38 | def putMemoryValue(self, out, value): 39 | print("putMemoryValue called:") 40 | print("out: " + str(out) + " (" + str(out.getClass()) + ")") 41 | print("value: " + str(value) + " (" + str(value.getClass()) + ")") 42 | print("") 43 | super(MemoryEnabledVarnodeContext, self).putMemoryValue(out, value) 44 | 45 | def dumpMemory(self): 46 | for mem in self.memoryVals: 47 | for v in mem.keySet(): 48 | print("# memory entry") 49 | print(str(v) + ": " + mem.get(v).toString()) 50 | print("space id: " + str(v.getSpace())) 51 | print("offset: " + str(v.getOffset())) 52 | print("") 53 | 54 | def getAddressSpaceItself(self, name): 55 | return self.addrFactory.getAddressSpace(name) 56 | 57 | class MemoryEnabledSymbolicPropogator(SymbolicPropogator): 58 | def __init__(self, program=None): 59 | if program is None: 60 | program = currentProgram 61 | super(MemoryEnabledSymbolicPropogator, self).__init__(program) 62 | program_context_field = get_accessible_java_field(SymbolicPropogator, 63 | "programContext") 64 | program_context = program_context_field.get(self) 65 | space_context_field = get_accessible_java_field(SymbolicPropogator, 66 | "spaceContext") 67 | space_context = space_context_field.get(self) 68 | new_context = MemoryEnabledVarnodeContext(program, program_context, space_context) 69 | context_field = get_accessible_java_field(SymbolicPropogator, "context") 70 | context_field.set(self, new_context) 71 | self.context = new_context 72 | context_field.get(self).setDebug(True) 73 | 74 | def getMemoryValue(self, toAddr, memory): 75 | return None 76 | 77 | def getContext(self): 78 | return self.context 79 | 80 | 81 | func = currentProgram.getFunctionManager().getFunctionContaining(state.currentAddress) 82 | 83 | if func is None: 84 | print("there is no current function!") 85 | raise Exception("") 86 | 87 | start = func.getEntryPoint() 88 | 89 | evl = ConstantPropagationContextEvaluator(monitor, True) 90 | symEval = MemoryEnabledSymbolicPropogator(currentProgram) 91 | symEval.flowConstants(start, func.getBody(), evl, True, monitor) 92 | 93 | # // get the internal address space used by the propogator for ESP 94 | espSpace = symEval.getContext().getAddressSpaceItself("RSP") 95 | print("ESP address space: " + str(espSpace)) 96 | 97 | for v in func.getStackFrame().getLocals(): 98 | print("local variable: " + v.toString()) 99 | use = v.getFirstStorageVarnode() 100 | print("first use varnode: " + use.toString()) 101 | # // create the varnode the internal propogator would have used for this local 102 | translatedOffset = use.getOffset() # + 0x100000000 103 | contextVarnode = Varnode(espSpace.getTruncatedAddress(translatedOffset, True), use.getSize()) 104 | print("equivalent varnode: " + contextVarnode.toString()) 105 | 106 | # // search for it! 107 | result = symEval.getContext().newGetMemoryValue(contextVarnode) 108 | if result is None: 109 | print("no symbolic entry found") 110 | else: 111 | print("found symbolic entry: " + result.toString()) 112 | print("") 113 | -------------------------------------------------------------------------------- /archived_scripts/name_functions_from_string_param.py: -------------------------------------------------------------------------------- 1 | # Auto-rename functions across a file based on the string passed to a specific function. 2 | # It should be noted that the script only works for functions whose names start with `FUN_`, 3 | # to avoid overwriting user-named functions. 
4 | # 5 | # It should also be noted that the script will only work if the parameter type has been 6 | # set correctly in the target function's signature. E.g. change `undefined8` to `char *`. 7 | # 8 | # The script is meant to be a quick and easy solution, and it does not actually emulate or 9 | # interpret pcode in a meaningful way, it just tracks writes to register and stack locations 10 | # and relies on the assumption that in C and C++ a given space on the stack should only ever 11 | # be utilized for a single type. E.g. a pointer on the stack that is used for a `char *` 12 | # should not ever be used to hold a `uint` unless there is a union containing the two types. 13 | # Keeping that in mind, the script can and will rename things incorrectly. 14 | #@author Clifton Wolfe 15 | #@category C++ 16 | from ghidra.program.flatapi import FlatProgramAPI 17 | from ghidra.python import PythonScript 18 | 19 | from __main__ import * 20 | from function_renamer import FunctionRenamer 21 | 22 | 23 | def main(): 24 | fr = FunctionRenamer(currentProgram) 25 | funcname = askString("Which function's calls are you targeting? ", "") 26 | prompt = "Which parameter of %s? Please note that the argument types must be set correctly for this parameter" % funcname 27 | int1 = askInt(prompt, "enter parameter number") 28 | func = [i for i in fr.fm.getFunctions(1) if i.name == funcname][0] 29 | fr.rename_functions_by_function_call(func, int1) 30 | 31 | 32 | if __name__ == "__main__": 33 | main() -------------------------------------------------------------------------------- /archived_scripts/rename_function_from_accessed_strings_guess.py: -------------------------------------------------------------------------------- 1 | # Auto-rename current function based on data accesses made within that function 2 | #@author Clifton Wolfe 3 | #@category C++ 4 | from ghidra.program.flatapi import FlatProgramAPI 5 | from ghidra.python import PythonScript 6 | 7 | from __main__ import * 8 | from function_renamer import FunctionRenamer 9 | 10 | 11 | def main(): 12 | fr = FunctionRenamer(currentProgram) 13 | func = fr.fm.getFunctionAt(currentLocation.getFunctionEntryPoint()) 14 | fr.rename_function_from_accessed_strings_guess(func) 15 | 16 | 17 | if __name__ == "__main__": 18 | main() -------------------------------------------------------------------------------- /call_ref_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from __main__ import * 3 | 4 | from collections import defaultdict 5 | from ghidra.program.model.symbol import FlowType, RefType 6 | 7 | 8 | def get_calling_addresses_to_address(address, program=None): 9 | """ 10 | get the addresses that call @address 11 | """ 12 | if program is None: 13 | program = currentProgram 14 | refman = program.getReferenceManager() 15 | calling_addrs = list() 16 | references = refman.getReferencesTo(address) 17 | for ref in references: 18 | ref_type = ref.getReferenceType() 19 | if ref_type.isCall() is False: 20 | continue 21 | calling_addrs.append(ref.fromAddress) 22 | return calling_addrs 23 | 24 | 25 | def get_called_addresses_from_address(address, program=None): 26 | """ 27 | get the addresses that are called from @address 28 | """ 29 | if program is None: 30 | program = currentProgram 31 | refman = program.getReferenceManager() 32 | called_addrs = list() 33 | references = refman.getReferencesFrom(address) 34 | for ref in references: 35 | ref_type = ref.getReferenceType() 36 | if ref_type.isCall() is False: 37 | continue 38 | 
called_addrs.append(ref.toAddress) 39 | return called_addrs 40 | 41 | 42 | def get_callsites_for_func_by_name(func_name, program=None): 43 | """ 44 | Return a dictionary of {Function: [call address, ..]} 45 | of functions that call @func_name 46 | """ 47 | if program is None: 48 | program = currentProgram 49 | 50 | # get all functions (including thunks) with the same name 51 | funcs = [i for i in program.getFunctionManager().getFunctions(1) \ 52 | if i.name == func_name] 53 | 54 | callsites = defaultdict(list) 55 | for func in funcs: 56 | entry = func.getEntryPoint() 57 | calling_addresses = get_calling_addresses_to_address(entry, program) 58 | for calling_addr in calling_addresses: 59 | calling_func = getFunctionContaining(calling_addr) 60 | # ignore thunks, they should already be in the list 61 | # so they will be processed 62 | if calling_func.name == func_name: 63 | continue 64 | callsites[calling_func].append(calling_addr) 65 | return dict(callsites) 66 | 67 | 68 | def function_calls_self(func, program=None): 69 | """ 70 | Check if a function calls itself 71 | """ 72 | if program is None: 73 | program = currentProgram 74 | entry = func.getEntryPoint() 75 | calling_addrs = get_calling_addresses_to_address(entry, program) 76 | return any([func.body.contains(a) for a in calling_addrs]) 77 | 78 | 79 | def get_all_functions_leading_to(func, program=None): 80 | """ 81 | Get a list of all functions that could call into @func and 82 | any functions that call those functions, etc. 83 | """ 84 | if program is None: 85 | program = currentProgram 86 | 87 | if func is None: 88 | return set() 89 | 90 | to_visit = set([func]) 91 | visited = set() 92 | while to_visit: 93 | curr_func = to_visit.pop() 94 | entry = curr_func.getEntryPoint() 95 | calling_addrs = get_calling_addresses_to_address(entry, program) 96 | for calling_addr in calling_addrs: 97 | calling_func = getFunctionContaining(calling_addr) 98 | if calling_func in visited: 99 | continue 100 | if calling_func in to_visit: 101 | continue 102 | if calling_func == curr_func: 103 | continue 104 | to_visit.add(calling_func) 105 | visited.add(curr_func) 106 | 107 | func_calls_self = function_calls_self(func, program) 108 | # check if func calls itself to determine if it needs to be removed 109 | if func_calls_self is False: 110 | visited.remove(func) 111 | return visited 112 | 113 | 114 | def get_all_functions_called_from(func, program=None): 115 | """ 116 | Get a list of all functions called by @func and 117 | any functions that are called by those functions, etc. 
118 | """ 119 | if program is None: 120 | program = currentProgram 121 | 122 | if func is None: 123 | return set() 124 | 125 | to_visit = set([func]) 126 | visited = set() 127 | while to_visit: 128 | curr_func = to_visit.pop() 129 | called_addrs = [] 130 | for rang in curr_func.getBody(): 131 | for addr in rang: 132 | called_addrs += list(get_called_addresses_from_address(addr, program=program)) 133 | # called_addrs = curr_func.getCalledFunctions(monitor_inst) 134 | for called_addr in called_addrs: 135 | called_func = getFunctionContaining(called_addr) 136 | if called_func is None: 137 | continue 138 | if called_func in visited: 139 | continue 140 | if called_func in to_visit: 141 | continue 142 | if called_func == curr_func: 143 | continue 144 | to_visit.add(called_func) 145 | visited.add(curr_func) 146 | 147 | func_calls_self = function_calls_self(func, program) 148 | # check if func calls itself to determine if it needs to be removed 149 | if func_calls_self is False: 150 | visited.remove(func) 151 | return visited 152 | 153 | 154 | -------------------------------------------------------------------------------- /coverage_highlight.py: -------------------------------------------------------------------------------- 1 | # Visualize the coverage from a list of addresses 2 | #@author Clifton Wolfe 3 | 4 | import ghidra 5 | from ghidra.program.model.block import BasicBlockModel 6 | from ghidra.program.model.address import AddressSet 7 | from ghidra.program.model.address import AddressRangeImpl 8 | import os 9 | import re 10 | 11 | from java.awt import Color 12 | 13 | from __main__ import * 14 | 15 | COLOR_DEFAULT = Color(255, 255, 255) # white 16 | COLOR_VISITED = Color(137, 207, 240) # light blue 17 | COLOR_UNVISITED = Color(178, 34, 34) # dark red 18 | 19 | 20 | def get_instruction_addr_set_for_addresses(addresses): 21 | addr_set = AddressSet() 22 | listing = currentProgram.getListing() 23 | for addr in addresses: 24 | cu = listing.getCodeUnitAt(addr) 25 | if cu is None: 26 | continue 27 | addr_range = AddressRangeImpl(cu.minAddress, cu.maxAddress) 28 | addr_set.add(addr_range) 29 | return addr_set 30 | 31 | 32 | def get_basic_block_addr_set_for_addresses(addresses): 33 | addr_set = AddressSet() 34 | bbm = BasicBlockModel(currentProgram) 35 | code_blocks = list(bbm.getCodeBlocksContaining(addresses, 36 | monitor)) 37 | for block in code_blocks: 38 | for addr_range in block.addressRanges: 39 | addr_set.add(addr_range) 40 | return addr_set 41 | 42 | 43 | def entrypoint(): 44 | file = askFile("File", 45 | "Path to a file containing visited addresses") 46 | filepath = file.toString() 47 | with open(filepath, "r") as f: 48 | c = f.read() 49 | 50 | rexp = re.compile("(0x[a-fA-F0-9]+)") 51 | addrs = [int(m.groups()[0], 16) for m in re.finditer(rexp, c)] 52 | 53 | choices = askChoices("Coverage Highlighter", 54 | "Select how you would like to highlight coverage", 55 | ["instruction", 56 | "basic_block"], 57 | ["Instruction-level granularity", 58 | "Basic block-level granularity"]) 59 | if "instruction" in choices: 60 | addr_set = get_instruction_addr_set_for_addresses(addrs) 61 | elif "basic_block" in choices: 62 | addr_set = get_basic_block_addr_set_for_addresses(addrs) 63 | 64 | setBackgroundColor(addr_set, COLOR_VISITED) 65 | 66 | 67 | if __name__ == "__main__": 68 | entrypoint() 69 | -------------------------------------------------------------------------------- /coverage_visualizer/afl_coverage_visualizer.py: -------------------------------------------------------------------------------- 1 
| # Visualize the coverage from an afl++ bitmap in ghidra. 2 | # Generate a showmap file with 3 | # `afl-showmap -C -i <afl_output_dir> -o showmap -- <target_binary>` 4 | # NOTE: it is possible for certain blocks at the start of your binary to be 5 | # missed if afl is running in persistent mode, so a red block at the start of 6 | # main is likely to be valid 7 | #@author Clifton Wolfe 8 | 9 | import ghidra 10 | from ghidra.program.model.block import BasicBlockModel 11 | from ghidra.program.model.address import AddressSet 12 | from ghidra.program.model.symbol import FlowType 13 | from ghidra.program.model.symbol import SymbolType 14 | from collections import namedtuple, defaultdict 15 | import os 16 | import re 17 | 18 | from java.awt import Color 19 | 20 | from __main__ import * 21 | 22 | # offset is used here instead of index because the values stored 23 | # in each of the uint32_t's in the __sancov_guards section are 24 | # index values, but these indexes do not start at zero so it would 25 | # get confusing to have two different index values in the object 26 | SanCovPCGuardRef = namedtuple("SanCovPCGuardRef", ["ref", 27 | "calling_function", 28 | "offset"]) 29 | 30 | SANCOV_PC_GUARD_VALUE_SIZE = 4 31 | # this might change in different versions or with different builds, 32 | # but so far it seems to start here 33 | SANCOV_PC_GUARD_START_INDEX = 6 34 | COLOR_DEFAULT = Color(255, 255, 255) # white 35 | COLOR_VISITED = Color(137, 207, 240) # light blue 36 | COLOR_UNVISITED = Color(178, 34, 34) # dark red 37 | 38 | 39 | def get_pc_guard_refs(): 40 | """ 41 | This function finds the __sancov_guards section in memory and 42 | returns a list of SanCovPCGuardRefs that are referred to by 43 | code in the currentProgram 44 | """ 45 | nsm = currentProgram.getNamespaceManager() 46 | global_namespace = nsm.getGlobalNamespace() 47 | 48 | start_sym = getSymbol("__start___sancov_guards", global_namespace) 49 | stop_sym = getSymbol("__stop___sancov_guards", global_namespace) 50 | 51 | if start_sym is None or stop_sym is None: 52 | print("Currently only binaries built with sancov pc guards are supported") 53 | exit(1) 54 | 55 | sancov_guards_addr = start_sym.getAddress() 56 | sancov_guards_size = stop_sym.getAddress().subtract(sancov_guards_addr) 57 | pc_guard_refs = [] 58 | for offset in range(0, sancov_guards_size, 59 | SANCOV_PC_GUARD_VALUE_SIZE): 60 | curr_addr = sancov_guards_addr.add(offset) 61 | # each location should really only be referenced once, 62 | # but the first few entries might have an extra 63 | valid_ref_found = False 64 | for ref in getReferencesTo(curr_addr): 65 | calling_function = getFunctionContaining(ref.fromAddress) 66 | # ignore references that are data references only 67 | if calling_function is None: 68 | continue 69 | # ignore things related to sancov, which will probably refer 70 | # to the first and last entry for initialization 71 | if calling_function.name.startswith("sancov"): 72 | continue 73 | 74 | valid_ref_found = True 75 | break 76 | 77 | if valid_ref_found is False: 78 | continue 79 | 80 | sancov_guard_ref = SanCovPCGuardRef(ref, calling_function, 81 | offset) 82 | 83 | pc_guard_refs.append(sancov_guard_ref) 84 | return pc_guard_refs 85 | 86 | 87 | def get_code_block(address): 88 | func = getFunctionContaining(address) 89 | if not func: 90 | return None 91 | block_model = BasicBlockModel(currentProgram) 92 | addresses = func.getBody() 93 | code_blocks = list(block_model.getCodeBlocksContaining(addresses, 94 | monitor)) 95 | for code_block in code_blocks: 96 | if 
code_block.contains(address): 97 | return code_block 98 | return None 99 | 100 | 101 | def parse_indices_from_showmap_output(filepath): 102 | """ 103 | return the index value from a file output by afl-showmap 104 | """ 105 | with open(filepath, "r") as f: 106 | content = f.read() 107 | 108 | rexp = re.compile(r"(\d+):\d+") 109 | indices = [] 110 | for m in re.finditer(rexp, content): 111 | index_str = m.groups()[0] 112 | index = int(index_str) 113 | indices.append(index) 114 | 115 | return indices 116 | 117 | 118 | def showmap_index_to_pc_guard_offset(index): 119 | return (index - SANCOV_PC_GUARD_START_INDEX) * SANCOV_PC_GUARD_VALUE_SIZE 120 | 121 | 122 | def pc_guard_offset_to_showmap_index(offset): 123 | return (offset // SANCOV_PC_GUARD_VALUE_SIZE) + SANCOV_PC_GUARD_START_INDEX 124 | 125 | 126 | def get_code_block_sources(code_block): 127 | """ 128 | block.SimpleSourceReferenceIterator is not iterable 129 | """ 130 | sources = [] 131 | source_iterator = code_block.getSources(monitor) 132 | while source_iterator.hasNext(): 133 | sources.append(source_iterator.next()) 134 | return sources 135 | 136 | 137 | def get_code_block_dests(code_block): 138 | """ 139 | block.SimpleSourceReferenceIterator is not iterable 140 | """ 141 | dests = [] 142 | dest_iterator = code_block.getDestinations(monitor) 143 | while dest_iterator.hasNext(): 144 | dests.append(dest_iterator.next()) 145 | return dests 146 | 147 | 148 | class BlockHighlighter: 149 | """ 150 | Class for hightlighting basic blocks 151 | """ 152 | def __init__(self): 153 | self.hightlighted_block_record = [] 154 | self.listing = currentProgram.getListing() 155 | 156 | def highlight_code_block_of_address(self, address, 157 | color=COLOR_VISITED): 158 | bb = get_code_block(address) 159 | self.highlight_code_block(bb) 160 | 161 | def highlight_code_block(self, code_block, color=COLOR_VISITED): 162 | for address_range in code_block.addressRanges: 163 | address_set = AddressSet(address_range) 164 | setBackgroundColor(address_set, color) 165 | # save record of this highlight so it can be 166 | # undone later 167 | self.hightlighted_block_record.append(code_block) 168 | 169 | 170 | class BlockFlowTracer: 171 | def __init__(self, visited_blocks): 172 | self._traced_blocks = set() 173 | self._visited_blocks = set(visited_blocks) 174 | 175 | def find_unconditional_source_blocks(self, code_block): 176 | curr_code_block = code_block 177 | unconditional_source_blocks = set() 178 | self._visited_blocks.add(code_block) 179 | while True: 180 | block_sources = get_code_block_sources(curr_code_block) 181 | if len(block_sources) != 1 or curr_code_block in unconditional_source_blocks: 182 | break 183 | src = block_sources[0] 184 | src_block = src.getSourceBlock() 185 | unconditional_source_blocks.add(src_block) 186 | curr_code_block = src_block 187 | self._visited_blocks.update(unconditional_source_blocks) 188 | return unconditional_source_blocks 189 | 190 | def find_unconditional_dest_blocks(self, code_block): 191 | curr_code_block = code_block 192 | unconditional_dest_blocks = set() 193 | self._visited_blocks.add(code_block) 194 | while True: 195 | block_dests = get_code_block_dests(curr_code_block) 196 | if len(block_dests) != 1 or curr_code_block in unconditional_dest_blocks: 197 | break 198 | dest = block_dests[0] 199 | dest_block = dest.getDestinationBlock() 200 | unconditional_dest_blocks.add(dest_block) 201 | curr_code_block = dest_block 202 | self._visited_blocks.update(unconditional_dest_blocks) 203 | return unconditional_dest_blocks 204 | 205 | 
def get_unconditionally_visited_blocks(self, code_block): 206 | all_visited_blocks = set() 207 | usb = self.find_unconditional_source_blocks(code_block) 208 | udb = self.find_unconditional_dest_blocks(code_block) 209 | all_visited_blocks.update(usb) 210 | all_visited_blocks.update(udb) 211 | all_visited_blocks.add(code_block) 212 | return all_visited_blocks 213 | 214 | def get_all_unconditionally_visited_blocks(self): 215 | # copy so that blocks don't get processed twice 216 | # during this process, as visited_blocks gets updated 217 | visited_blocks_copy = list(self._visited_blocks) 218 | all_visited_blocks = set(visited_blocks_copy) 219 | for visited_block in visited_blocks_copy: 220 | uvb = self.get_unconditionally_visited_blocks(visited_block) 221 | all_visited_blocks.update(uvb) 222 | return all_visited_blocks 223 | 224 | 225 | def highlight_visited_and_unvisited_blocks(all_visited_blocks, unreached): 226 | """ 227 | Generic function for highlighting all of the blocks that were visited 228 | and unvisited. Also does a little bit of "analysis" to try to improve 229 | results slightly 230 | """ 231 | bbh = BlockHighlighter() 232 | # Find blocks that must be unconditionally reached 233 | v_bft = BlockFlowTracer(all_visited_blocks) 234 | all_blocks_to_highlight = v_bft.get_all_unconditionally_visited_blocks() 235 | 236 | for block in unreached: 237 | bbh.highlight_code_block(block, COLOR_UNVISITED) 238 | 239 | for block in all_blocks_to_highlight: 240 | bbh.highlight_code_block(block, COLOR_VISITED) 241 | 242 | 243 | def hightlight_sancov_visited_code_blocks(): 244 | """ 245 | Highlight blocks based on sancov pc guards. Only works for binaries 246 | built with afl 247 | """ 248 | showmap_file = askFile("showmap File", 249 | "Path to a file output from afl-showmap") 250 | showmap_filepath = showmap_file.toString() 251 | 252 | showmap_indices = parse_indices_from_showmap_output(showmap_filepath) 253 | showmap_offsets = set([showmap_index_to_pc_guard_offset(i) for i in showmap_indices]) 254 | pc_guard_refs = get_pc_guard_refs() 255 | unreached_code_blocks = [] 256 | all_visited_blocks = [] 257 | # TODO: these can likely be optimized a bit 258 | for pc_guard_ref in pc_guard_refs: 259 | code_block = get_code_block(pc_guard_ref.ref.fromAddress) 260 | # pick out the ones that are not visited 261 | if pc_guard_ref.offset not in showmap_offsets: 262 | unreached_code_blocks.append(code_block) 263 | continue 264 | 265 | all_visited_blocks.append(code_block) 266 | 267 | all_blocks_to_highlight = all_visited_blocks 268 | highlight_visited_and_unvisited_blocks(all_blocks_to_highlight, 269 | unreached_code_blocks) 270 | 271 | 272 | def all_child_filepaths_gen(dir_path): 273 | """ 274 | Get all of the files under the dir path 275 | """ 276 | for dirpath, dirnames, filenames in os.walk(dir_path): 277 | for filename in filenames: 278 | filepath = os.path.join(dirpath, filename) 279 | yield filepath 280 | 281 | 282 | class QemuAsmAnalyzer: 283 | """ 284 | A class for Assisting with the analysis of qemu asm logs. 
285 | Helps to determine if current program's base changes 286 | """ 287 | def __init__(self): 288 | self.symb_history = defaultdict(list) 289 | self.symb_rexp = re.compile('IN: ([^\n]+)\n(0x[a-f0-9]+)', 290 | re.MULTILINE | re.DOTALL) 291 | self.first_block_addr_rexp = re.compile('IN:[^\n]+\n(0x[a-f0-9]+):', 292 | re.MULTILINE | re.DOTALL) 293 | sm = currentProgram.getSymbolTable() 294 | self.useful_symbols = {i.name: i.getAddress().getOffset() 295 | for i in sm.getSymbolIterator() 296 | if i.isExternalEntryPoint() and 297 | i.symbolType == SymbolType.FUNCTION} 298 | 299 | def get_binary_base_from_symbols(self, qemu_in_asm_log): 300 | """ 301 | Get the base of the binary using the symbols that appear in 302 | qemu log 303 | """ 304 | current_base_int = currentProgram.getImageBase().getOffset() 305 | for m in re.finditer(self.symb_rexp, qemu_in_asm_log): 306 | log_symbol, address_str = m.groups() 307 | log_address = int(address_str, 16) 308 | known_addr = self.useful_symbols.get(log_symbol) 309 | # if the symbol doesn't appear in ghidra, it is likely from the 310 | # wrong binary, skip it 311 | if known_addr is None: 312 | continue 313 | 314 | # symbol matched, but is probably a dupicate symbol that exists 315 | # in multiple binaries. Only keep ones that appear to have come 316 | # from this binary 317 | if (known_addr & 0xfff) != (log_address & 0xfff): 318 | continue 319 | 320 | # ghidra will add a dummy base address if the executable is PIE, 321 | # so remove that to get the offset 322 | known_offset = known_addr - current_base_int 323 | binary_base = log_address - known_offset 324 | return binary_base 325 | 326 | return None 327 | 328 | def get_binary_base(self, qemu_in_asm_log): 329 | """ 330 | Try to get the base of the binary through a few different methods 331 | """ 332 | maybe_base = self.get_binary_base_from_symbols(qemu_in_asm_log) 333 | if maybe_base: 334 | return maybe_base 335 | 336 | def get_first_address_of_each_block(self, qemu_in_asm_log): 337 | """ 338 | yield the first address in each block. 
339 | """ 340 | for m in re.finditer(self.first_block_addr_rexp, qemu_in_asm_log): 341 | yield int(m.groups()[0], 16) 342 | 343 | def parse_qemu_asm_visited_addresses(self, file_contents): 344 | """ 345 | More extensive, but on larger binaries the output could be messive 346 | """ 347 | rexp = re.compile("(0x[a-f0-9]+):") 348 | for m in re.finditer(rexp, file_contents): 349 | address_str = m.groups()[0] 350 | address = int(address_str, 16) 351 | yield address 352 | 353 | 354 | def highlight_qemu_visited_code_blocks(): 355 | dir_obj = askDirectory("Path to output from gather_qemu_coverage_data.sh", 356 | "select") 357 | 358 | # this is set as the base address in 359 | # ASSUMED_BASE_ADDRESS = 0x1800000 360 | dir_path = dir_obj.toString() 361 | all_visited_blocks = set() 362 | current_image_base_int = currentProgram.getImageBase().getOffset() 363 | qaa = QemuAsmAnalyzer() 364 | for path in all_child_filepaths_gen(dir_path): 365 | with open(path, "r") as f: 366 | file_contents = f.read() 367 | binary_base_address_in_log = qaa.get_binary_base(file_contents) 368 | for log_addr in qaa.get_first_address_of_each_block(file_contents): 369 | # adjust the address so that it matches up with what is in ghidra 370 | log_offset = (log_addr - binary_base_address_in_log) 371 | ghidra_addr = toAddr(log_offset + current_image_base_int) 372 | code_block = get_code_block(ghidra_addr) 373 | if code_block: 374 | all_visited_blocks.add(code_block) 375 | 376 | highlight_visited_and_unvisited_blocks(list(all_visited_blocks), []) 377 | 378 | 379 | if __name__ == "__main__": 380 | choices = askChoices("Coverage Highlighter", 381 | "Select how you would like to highlight coverage", 382 | ["qemu_based_binary_only", 383 | "afl_sancov_guard"], 384 | ["Qemu-based (works for binary only)", 385 | "AFL SanCov"]) 386 | for choice in choices: 387 | if choice.find("qemu_based_binary_only") != -1: 388 | highlight_qemu_visited_code_blocks() 389 | elif choice.find("afl_sancov_guard") != -1: 390 | hightlight_sancov_visited_code_blocks() 391 | -------------------------------------------------------------------------------- /coverage_visualizer/gather_qemu_coverage_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CAN_PROCEED=1 4 | if [[ ! -z "${OUTPUT_DIR}" ]]; then 5 | TMP_DIR="${OUTPUT_DIR}" 6 | mkdir -p $TMP_DIR 7 | else 8 | TMP_DIR=$(mktemp -d) 9 | fi 10 | 11 | # for debugging 12 | # mkdir -p $TMP_DIR 13 | 14 | if [[ ! -z "${AFL_PATH}" ]]; then 15 | AFL_QEMU_TRACE="${AFL_PATH}/afl-qemu-trace" 16 | elif [[ ! -z "$(command -v afl-qemu-trace)" ]]; then 17 | AFL_QEMU_TRACE=$(command -v afl-qemu-trace) 18 | else 19 | echo "ERR: AFL_PATH must be specified or afl-qemu-trace must be in PATH" 20 | CAN_PROCEED=0 21 | fi 22 | 23 | 24 | if [[ -z "${AFL_OUTPUT_DIR}" ]]; then 25 | echo "ERR: AFL_OUTPUT_DIR must be set" 26 | # maybe use AFL_CUSTOM_INFO_OUT 27 | CAN_PROCEED=0 28 | fi 29 | 30 | if [[ $# -lt 1 ]]; then 31 | echo "$0" 32 | echo "Usage: AFL_OUTPUT_DIR= $0 " 33 | echo "" 34 | echo " AFL_OUTPUT_DIR: output directory from afl++" 35 | echo " AFL_PATH: path to built afl++ project. 
not necessary if afl++ has been installed" 36 | echo " OUTPUT_DIR: path to output traces in" 37 | echo "" 38 | echo "Example:" 39 | echo " AFL_OUTPUT_DIR=/mnt/ramdisk/output/ OUTPUT_DIR=./output $0 ./fuzz_qemu" 40 | CAN_PROCEED=0 41 | fi 42 | 43 | # only exit after all of the other warnings and errors have been printed 44 | # so that the user doesn't have to waste time 45 | if [[ $CAN_PROCEED -eq 0 ]]; then 46 | rm -rf $TMP_DIR 47 | exit 1 48 | fi 49 | 50 | QUEUE_DIR=$(find $AFL_OUTPUT_DIR -type d -name 'queue') 51 | 52 | for i in $(find $QUEUE_DIR -type f); do 53 | TMP_FILE=$(mktemp -p $TMP_DIR) 54 | echo "$AFL_QEMU_TRACE -d in_asm -D $TMP_FILE $@ < $i" 55 | $AFL_QEMU_TRACE -d in_asm -D $TMP_FILE $@ < $i 56 | done 57 | 58 | echo "$TMP_DIR" 59 | -------------------------------------------------------------------------------- /create_vtable.py: -------------------------------------------------------------------------------- 1 | # Create and update vtables (vftables) 2 | # @author Clifton Wolfe 3 | # @category C++ 4 | # @keybinding ctrl 5 5 | # @menupath Tools.Automation.Create Vtable 6 | # @toolbar 7 | 8 | from __main__ import * 9 | from ghidra.program.model.address import AddressSet 10 | from ghidra.program.model.address import AddressRangeImpl 11 | from ghidra.program.model.symbol import DataRefType 12 | from ghidra.program.database.symbol import FunctionSymbol 13 | from ghidra.program.database.symbol import CodeSymbol 14 | from ghidra.program.database.code import DataDB 15 | from ghidra.program.database.code import InstructionDB 16 | from ghidra.program.model.data import StructureDataType 17 | from datatype_utils import getVoidPointerDatatype, applyDataTypeAtAddress 18 | import struct 19 | import logging 20 | 21 | log = logging.getLogger(__file__) 22 | log.addHandler(logging.StreamHandler()) 23 | log.setLevel(logging.DEBUG) 24 | 25 | 26 | def containsInstructions(address, program=None): 27 | """ 28 | Check if the address contains instructions. 
29 | """ 30 | if program is None: 31 | program = currentProgram 32 | cu = program.getListing().getCodeUnitAt(address) 33 | return isinstance(cu, InstructionDB) 34 | 35 | 36 | def addressContainsUnownedAssembly(address, program=None): 37 | """ 38 | Check if the address contains valid instructions but is not 39 | within a function 40 | """ 41 | if program is None: 42 | program = currentProgram 43 | cu = program.getListing().getCodeUnitAt(address) 44 | if not isinstance(cu, InstructionDB): 45 | return False 46 | maybe_func = getFunctionContaining(address) 47 | return maybe_func is None 48 | 49 | 50 | def guessVtableByteSize(address, program=None, allow_null_ptrs=True): 51 | """ 52 | Make an educated guess at the valid size of a vtable 53 | """ 54 | if program is None: 55 | program = currentProgram 56 | refman = program.getReferenceManager() 57 | ptr_size = program.getDefaultPointerSize() 58 | mem = program.getMemory() 59 | ptr_pack_end = ">" if mem.isBigEndian() else "<" 60 | ptr_pack_sym = "I" if ptr_size == 4 else "Q" 61 | ptr_pack_code = ptr_pack_end + ptr_pack_sym 62 | listing = program.getListing() 63 | # skip the first entry because there are references to it 64 | curr_addr = address.add(ptr_size) 65 | last_valid_vtable_entry = address 66 | to_refs = [] 67 | addr_sym = None 68 | while len(to_refs) == 0 and addr_sym is None: 69 | to_refs = list(refman.getReferencesTo(curr_addr)) 70 | maybe_ptr_bytes = bytearray(getBytes(curr_addr, ptr_size)) 71 | maybe_func_addr_int = struct.unpack(ptr_pack_code, maybe_ptr_bytes)[0] 72 | maybe_func_addr = toAddr(maybe_func_addr_int) 73 | addr_sym = getSymbolAt(curr_addr) 74 | # check to see if it is actually valid code instead of a pointer to other data 75 | # Can't use refs for this because sometimes an address isn't identified as an address 76 | # and doesn't generate the reference 77 | is_valid_addr = mem.getRangeContaining(maybe_func_addr) is not None 78 | if is_valid_addr is True: 79 | # TODO: the structure of this loop should be untangled. for now just repeat loop exit critera 80 | if len(to_refs) == 0 and addr_sym is None: 81 | cu = listing.getCodeUnitAt(maybe_func_addr) 82 | if isinstance(cu, InstructionDB): 83 | # this means that it is a label or a func 84 | # save the last valid reference to a function or label as the end of the vtable 85 | last_valid_vtable_entry = curr_addr 86 | elif maybe_func_addr_int == 0 and allow_null_ptrs is True: 87 | # in some binaries it is entirely valid to have NULL pointers for vtable functions. 
88 | last_valid_vtable_entry = curr_addr 89 | curr_addr = curr_addr.add(ptr_size) 90 | # add ptr size so that the bounds of the vtable include the last valid pointer 91 | vtable_guessed_end = last_valid_vtable_entry.add(ptr_size) 92 | vtable_size_guess = vtable_guessed_end.subtract(address) 93 | return vtable_size_guess 94 | 95 | 96 | def extractAddressTableEntries(address, table_size, program=None): 97 | """ 98 | Extracts the bytes from the specified address as an array of addresses 99 | """ 100 | if program is None: 101 | program = currentProgram 102 | ptr_size = program.getDefaultPointerSize() 103 | mem = program.getMemory() 104 | ptr_pack_end = ">" if mem.isBigEndian() else "<" 105 | ptr_pack_sym = "I" if ptr_size == 4 else "Q" 106 | vtable_bytes = bytearray(getBytes(address, table_size)) 107 | num_ptrs = (table_size // ptr_size) 108 | pack_code = "%s%d%s" % (ptr_pack_end, num_ptrs, ptr_pack_sym) 109 | table_addrs = [toAddr(i) for i in struct.unpack_from(pack_code, vtable_bytes)] 110 | return table_addrs 111 | 112 | 113 | def createFunctionsForVtableLabels(address, vtable_size, program=None): 114 | """ 115 | Iterate through the embedded addresses at the specified address and create functions for the addresses 116 | that are valid and contains instructions, but are not already functions 117 | """ 118 | if program is None: 119 | program = currentProgram 120 | mem = program.getMemory() 121 | vtable_addrs = extractAddressTableEntries(address, vtable_size, program=program) 122 | for addr in vtable_addrs: 123 | is_valid_address = mem.getRangeContaining(addr) is not None 124 | if is_valid_address is False: 125 | continue 126 | if addressContainsUnownedAssembly(addr, program=program) is True: 127 | createFunction(addr, None) 128 | 129 | 130 | def createStringForNamespace(curr_ns): 131 | ns_strs = [] 132 | while curr_ns: 133 | ns_strs.append(curr_ns.getName()) 134 | curr_ns = curr_ns.getParentNamespace() 135 | return "_".join(ns_strs[::-1]) 136 | 137 | 138 | 139 | def createNewVtableAtAddress(address, vtable_size=None, referring_func=None, program=None): 140 | """ 141 | Create a new vtable datatype based on the data at the specified address 142 | """ 143 | if program is None: 144 | program = currentProgram 145 | if referring_func is not None: 146 | namespace = referring_func.getParentNamespace() 147 | else: 148 | namespace = program.getGlobalNamespace() 149 | 150 | vtable_prefix = "" 151 | global_namespace = program.getGlobalNamespace() 152 | if namespace != global_namespace: 153 | # if the namespace for the function is not the global namespace, try to make an 154 | # appropriate name for the new vtable 155 | vtable_prefix = createStringForNamespace(namespace) + "_" 156 | 157 | if vtable_size is None: 158 | vtable_size = guessVtableByteSize(address, program=program) 159 | # Fix up the Label pointers and make them into functions 160 | createFunctionsForVtableLabels(address, vtable_size, program=program) 161 | # now create the actual struct 162 | dtm = program.getDataTypeManager() 163 | new_struct = StructureDataType("%svftable_%s" % (vtable_prefix, str(address)), vtable_size) 164 | ptr_size = program.getDefaultPointerSize() 165 | voidp_dt = getVoidPointerDatatype() 166 | table_addrs = extractAddressTableEntries(address, vtable_size, program=program) 167 | # update the new datatype by setting the field name to something recognizable 168 | # as a function pointer and setting the type to void* 169 | for ind, addr in enumerate(table_addrs): 170 | offset = ind*ptr_size 171 | func = 
getFunctionAt(addr) 172 | field_name = None 173 | if func is not None: 174 | field_name = "%s_%#x" % (func.name, offset) 175 | new_struct.replaceAtOffset(offset, voidp_dt, ptr_size, field_name, None) 176 | dtm.addDataType(new_struct, None) 177 | # actually apply the datatype 178 | applyDataTypeAtAddress(address, new_struct, vtable_size, program=program) 179 | return new_struct 180 | 181 | 182 | def updateVtableAtAddress(address, vtable_size=None, program=None): 183 | raise NotImplementedError 184 | if program is None: 185 | program = currentProgram 186 | if vtable_size is None: 187 | vtable_size = guessVtableByteSize(address, program=program) 188 | data_db = getDataContaining(address) 189 | datatype = data_db.getDataType() 190 | table_addrs = extractAddressTableEntries(address, vtable_size, program=program) 191 | voidp_dt = getVoidPointerDatatype() 192 | ptr_size = program.getDefaultPointerSize() 193 | for ind, addr in enumerate(table_addrs): 194 | offset = ind*ptr_size 195 | func = getFunctionAt(addr) 196 | field_name = None 197 | if func is not None: 198 | field_name = "%s_%#x" % (func.name, offset) 199 | if field_name: 200 | pass 201 | datatype.replaceAtOffset(offset, voidp_dt, ptr_size, field_name, None) 202 | # dtm.addDataType(new_struct, None) 203 | 204 | 205 | 206 | def createOrUpdateVtableAtAddress(address, vtable_size=None, referring_func=None, program=None): 207 | if program is None: 208 | program = currentProgram 209 | maybe_dtdb = getDataContaining(address) 210 | if maybe_dtdb is not None and maybe_dtdb.isStructure(): 211 | vtable_dt = updateVtableAtAddress(address, vtable_size=vtable_size, program=program) 212 | else: 213 | vtable_dt = createNewVtableAtAddress(address, vtable_size=vtable_size, referring_func=referring_func, program=program) 214 | return vtable_dt 215 | 216 | def create_vtable_entrypoint(): 217 | selection = state.getCurrentSelection() 218 | currLoc = state.getCurrentLocation() 219 | addr_set = AddressSet() 220 | vtable_address = None 221 | referring_func = None 222 | vtable_size = None 223 | if selection is not None: 224 | pass 225 | elif currLoc is not None: 226 | if hasattr(currLoc, "getToken"): 227 | tok = currLoc.getToken() 228 | referring_func = getFunctionAt(currLoc.getFunctionEntryPoint()) 229 | # unless there is a better way to find where a token is referring to, 230 | # have to iterate over the addresses of the token to find where it is pointing 231 | refman = currentProgram.getReferenceManager() 232 | addr_range = AddressRangeImpl(tok.minAddress, tok.maxAddress) 233 | to_addrs = [] 234 | for addr in addr_range.iterator(): 235 | for ref in refman.getReferencesFrom(addr): 236 | # because this is only looking for vtables, drop all non-data types of 237 | # references 238 | if not isinstance(ref.referenceType, DataRefType): 239 | continue 240 | to_addrs.append(ref.toAddress) 241 | # TODO: find a better way to filter these out, if there are in fact 242 | # TODO: multiple references 243 | num_to_addrs = len(to_addrs) 244 | if num_to_addrs != 1: 245 | if num_to_addrs > 1: 246 | log.critical("A critical assumption of the script has been broken. 
There are more than two references from the same token") 247 | elif num_to_addrs < 1: 248 | log.error("No references from the token") 249 | return 250 | vtable_address = to_addrs[0] 251 | else: 252 | vtable_address = currLoc.getAddress() 253 | 254 | createOrUpdateVtableAtAddress(vtable_address, vtable_size=vtable_size, referring_func=referring_func, program=currentProgram) 255 | 256 | 257 | if __name__ == "__main__": 258 | create_vtable_entrypoint() -------------------------------------------------------------------------------- /datatype_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from __main__ import * 3 | from ghidra.program.database.data import DataTypeUtilities 4 | from ghidra.program.model.data import PointerDataType 5 | from ghidra.program.model.data import MetaDataType 6 | from ghidra.program.model.data import EnumDataType 7 | 8 | 9 | def find_datatypes_using(datatype, check_full_chains=True): 10 | initial_dt = datatype 11 | visited = set() 12 | to_visit = set([datatype]) 13 | while to_visit: 14 | curr_dt = to_visit.pop() 15 | for parent_dt in curr_dt.getParents(): 16 | base = DataTypeUtilities.getBaseDataType(parent_dt) 17 | if base != initial_dt: 18 | # not a (series)? of pointer/array to the data type 19 | if check_full_chains is False: 20 | continue 21 | if parent_dt in visited: 22 | continue 23 | if parent_dt in to_visit: 24 | continue 25 | if parent_dt == curr_dt: 26 | continue 27 | to_visit.add(parent_dt) 28 | visited.add(curr_dt) 29 | return visited 30 | 31 | 32 | def getUndefinedRegisterSizeDatatype(program=None): 33 | """ 34 | Returns an "undefined*" datatype that is the appropriate 35 | size to hold a pointer. Useful if you don't know the real datatype 36 | and expect it to have to be changed later 37 | """ 38 | if program is None: 39 | program = currentProgram 40 | dtm = program.getDataTypeManager() 41 | default_ptr_size = program.getDefaultPointerSize() 42 | return dtm.getDataType("/undefined%d" % default_ptr_size) 43 | 44 | 45 | def getGenericPointerDatatype(): 46 | return PointerDataType() 47 | 48 | 49 | def getVoidPointerDatatype(program=None): 50 | if program is None: 51 | program = currentProgram 52 | dtm = program.getDataTypeManager() 53 | void_dt = dtm.getDataType("/void") 54 | return dtm.getPointer(void_dt) 55 | 56 | 57 | def areBaseDataTypesEquallyUnique(datatype_a, datatype_b): 58 | datatype_a = DataTypeUtilities.getBaseDataType(datatype_a) 59 | datatype_b = DataTypeUtilities.getBaseDataType(datatype_b) 60 | a_meta = MetaDataType.getMeta(datatype_a) 61 | b_meta = MetaDataType.getMeta(datatype_b) 62 | return a_meta.compareTo(b_meta) == 0 63 | 64 | def applyDataTypeAtAddress(address, datatype, size=None, program=None): 65 | if program is None: 66 | program = currentProgram 67 | if size is None: 68 | size = datatype.getLength() 69 | listing = program.getListing() 70 | listing.clearCodeUnits(address, address.add(size), False) 71 | listing.createData(address, datatype, size) 72 | 73 | -------------------------------------------------------------------------------- /decomp_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from __main__ import * 3 | from ghidra.app.decompiler import DecompileOptions 4 | from ghidra.app.decompiler import DecompInterface 5 | from ghidra.util.task import ConsoleTaskMonitor 6 | from ghidra.program.model.pcode import PcodeOpAST 7 | from ghidra.app.decompiler.component import DecompilerUtils 8 | import logging 9 | 10 | log = 
logging.getLogger(__file__) 11 | log.addHandler(logging.StreamHandler()) 12 | log.setLevel(logging.WARNING) 13 | 14 | 15 | class DecompUtils: 16 | """ 17 | Utilities on top of the existing decompiler utils 18 | """ 19 | def __init__(self, program=None, monitor_inst=None, decomp_timeout=60): 20 | if program is not None: 21 | self.program = program 22 | else: 23 | self.program = currentProgram 24 | self.addr_fact = self.program.getAddressFactory() 25 | self.dtm = self.program.getDataTypeManager() 26 | self._decomp_options = DecompileOptions() 27 | if monitor_inst is None: 28 | self._monitor = monitor 29 | else: 30 | self._monitor = monitor_inst 31 | self._ifc = DecompInterface() 32 | self._ifc.setOptions(self._decomp_options) 33 | self.fm = self.program.getFunctionManager() 34 | self.decomp_timeout = decomp_timeout 35 | 36 | def get_funcs_by_name(self, name): 37 | """ 38 | Get all of the functions that match the name @name 39 | """ 40 | return [i for i in self.fm.getFunctions(1) if i.name == name] 41 | 42 | def get_high_function(self, func, timeout=None): 43 | """ 44 | Get a HighFunction for a given function 45 | """ 46 | res = self.get_decompiler_result(func, timeout) 47 | high_func = res.getHighFunction() 48 | return high_func 49 | 50 | def get_decompiler_result(self, func, timeout=None): 51 | """ 52 | Get decompiler results for a given function 53 | """ 54 | if timeout is None: 55 | timeout = self.decomp_timeout 56 | self._ifc.openProgram(func.getProgram()) 57 | res = self._ifc.decompileFunction(func, timeout, self._monitor) 58 | return res 59 | 60 | def get_function_prototype(self, func, **kwargs): 61 | """ 62 | Get the function prototype for function @func 63 | """ 64 | hf = self.get_high_function(func, **kwargs) 65 | if hf is None: 66 | return None 67 | return hf.getFunctionPrototype() 68 | 69 | def _get_high_sym_for_param_by_index(self, func, index, **kwargs): 70 | """ 71 | Get the HighSymbol for a function's parameter. Parameter is specified 72 | by @index, and indexes start at zero. @index will NOT match up 73 | with the parameter numbers from the decompiler window 74 | """ 75 | proto = self.get_function_prototype(func, **kwargs) 76 | if proto is None: 77 | log.warning("No prototype for %s" % func.name) 78 | return None 79 | num_params = proto.getNumParams() 80 | if num_params-1 < index: 81 | log.warning("Parameter index %d does not exist in %s which has %d parameters" % (index, func.name, num_params)) 82 | return None 83 | high_sym = proto.getParam(index) 84 | return high_sym 85 | 86 | def get_high_sym_for_param(self, func, param_num, **kwargs): 87 | """ 88 | Get the HighSymbol for a function. @param_num matches up with parameter 89 | numbers visible in the decompiler window 90 | """ 91 | return self._get_high_sym_for_param_by_index(func, param_num-1, **kwargs) 92 | 93 | 94 | def get_varnodes_for_param(self, func, param_num, **kwargs): 95 | """ 96 | Gets the varnodes for a function parameter. 
@param_num matches up with 97 | the parameter number from the decompiler window 98 | """ 99 | high_sym = self.get_high_sym_for_param(func, param_num, **kwargs) 100 | if high_sym is None: 101 | log.warning("No HighSymbol for %s param %d " % (func.name, param_num)) 102 | return None 103 | # it is actually legitimate for there to be no high variable or 104 | # varnodes for a high symbol, like in cases where the parameter is just 105 | # unused 106 | high_var = high_sym.getHighVariable() 107 | if high_var is None: 108 | # log.warning("No HighVariable for %s param %d " % (func.name, param_num)) 109 | return [] 110 | vn_arr = high_var.getInstances() 111 | if vn_arr is None: 112 | # log.warning("No Varnode instances for %s param %d " % (func.name, param_num)) 113 | return [] 114 | return list(vn_arr) 115 | 116 | def get_all_parameter_varnodes(self, func, **kwargs): 117 | """ 118 | Get a list of lists of varnodes for all paramters to @func 119 | """ 120 | proto = self.get_function_prototype(func, **kwargs) 121 | if proto is None: 122 | log.warning("No prototype for %s" % func.name) 123 | return None 124 | num_params = proto.getNumParams() 125 | if num_params == 0: 126 | return [] 127 | varnodes_lists = [] 128 | for param_num in range(1, num_params+1): 129 | vns = self.get_varnodes_for_param(func, param_num, **kwargs) 130 | if vns is None: 131 | log.error("unable to get varnodes for param %d" % param_num) 132 | continue 133 | varnodes_lists.append(vns) 134 | return varnodes_lists 135 | 136 | def get_pcode_for_function(self, func, **kwargs): 137 | """ 138 | Get an unsorted list of PcodeOps for the function @func 139 | """ 140 | hf = self.get_high_function(func, **kwargs) 141 | if hf is None: 142 | log.warning("couldn't get high function for %s" % func.name) 143 | return None 144 | return list(hf.getPcodeOps()) 145 | 146 | def get_pcode_blocks_for_function(self, func, **kwargs): 147 | """ 148 | Get an unsorted list of Pcode Basic Blocks for the 149 | function @func 150 | """ 151 | hf = self.get_high_function(func, **kwargs) 152 | if hf is None: 153 | log.warning("couldn't get high function for %s" % func.name) 154 | return None 155 | return list(hf.getBasicBlocks()) 156 | 157 | def varnode_is_direct_source_of(self, source_vn_cand, descendant_vn_cand): 158 | """ 159 | Check to see if the Varnode @source_vn_cand directly leads to 160 | @descendant_vn_cand 161 | """ 162 | if source_vn_cand == descendant_vn_cand: 163 | return True 164 | defining_op = descendant_vn_cand.getDef() 165 | # an op with no definition is likely a parameter, global, 166 | # uninitialized, or part of a composite struct on the stack or in ram 167 | # that hasn't been recovered 168 | if defining_op is None: 169 | return False 170 | fwd_slice_vns = list(DecompilerUtils.getForwardSlice(source_vn_cand)) 171 | # TODO: check to see if anything else weird could happen to make this 172 | # TODO: not handle all cases 173 | if descendant_vn_cand in fwd_slice_vns: 174 | return True 175 | return False 176 | 177 | def varnode_leads_to_definition_of(self, source_vn_cand, descendant_vn_cand): 178 | """ 179 | Check to see if the Varnode @source_vn_cand directly leads to 180 | inputs to the defining op of @descendant_vn_cand 181 | """ 182 | if source_vn_cand == descendant_vn_cand: 183 | return True 184 | defining_op = descendant_vn_cand.getDef() 185 | # an op with no definition is likely a parameter, global, 186 | # uninitialized, or part of a composite struct on the stack or in ram 187 | # that hasn't been recovered 188 | if defining_op is None: 
189 | return False 190 | intersecting_vns = self.get_op_inputs_fwd_from_varnode(source_vn_cand, 191 | defining_op) 192 | if len(intersecting_vns) > 0: 193 | return True 194 | return False 195 | 196 | def get_op_inputs_fwd_from_varnode(self, varnode, op): 197 | if op is None: 198 | return set() 199 | op_inputs = list(op.getInputs()) 200 | op_inputs_set = set(op_inputs) 201 | fwd_slice_vns = list(DecompilerUtils.getForwardSlice(varnode)) 202 | fwd_slice_vns_set = set(fwd_slice_vns) 203 | intersecting_vns = fwd_slice_vns_set.intersection(op_inputs_set) 204 | return intersecting_vns 205 | 206 | 207 | def find_all_pcode_op_instances(opcodes, program=None, **kwargs): 208 | if not hasattr(opcodes, "__iter__"): 209 | opcodes = [opcodes] 210 | if any([i > PcodeOpAST.PCODE_MAX for i in opcodes]): 211 | raise Exception("Invalid Pcode op") 212 | if program is None: 213 | program = currentProgram 214 | du = DecompUtils(program, **kwargs) 215 | funcs_with_matching_op = {} 216 | for func in program.getFunctionManager().getFunctions(1): 217 | if func.isThunk(): 218 | continue 219 | pcode_ops = du.get_pcode_for_function(func) 220 | matching_ops = [i for i in pcode_ops if i.opcode in opcodes] 221 | if not matching_ops: 222 | continue 223 | funcs_with_matching_op[func] = [op.getSeqnum().getTarget() for op in matching_ops] 224 | return funcs_with_matching_op 225 | 226 | 227 | -------------------------------------------------------------------------------- /dfg_exporter.py: -------------------------------------------------------------------------------- 1 | from __main__ import * 2 | from decomp_utils import DecompUtils 3 | import json 4 | from collections import defaultdict 5 | from ghidra.program.model.block import BasicBlockModel 6 | from ghidra.program.model.address import AddressSet 7 | from ghidra.program.model.address import AddressRangeImpl 8 | from ghidra.program.model.symbol import FlowType 9 | from ghidra.program.model.symbol import SymbolType 10 | import os 11 | import re 12 | 13 | 14 | def get_function_call_graph_map(): 15 | call_map = defaultdict(set) 16 | for func in currentProgram.getFunctionManager().getFunctions(1): 17 | func_key = func.getEntryPoint() 18 | for called_func in func.getCalledFunctions(monitor): 19 | called_key = called_func.getEntryPoint() 20 | call_map[func_key].add(called_key) 21 | 22 | call_map = dict(call_map) 23 | serializable_call_map = {k: list(v) for k, v in call_map.items()} 24 | return serializable_call_map 25 | 26 | 27 | def get_code_block_sources(code_block): 28 | """ 29 | block.SimpleSourceReferenceIterator is not iterable 30 | """ 31 | sources = [] 32 | source_iterator = code_block.getSources(monitor) 33 | while source_iterator.hasNext(): 34 | sources.append(source_iterator.next()) 35 | return sources 36 | 37 | 38 | def get_code_block_dests(code_block): 39 | """ 40 | block.SimpleSourceReferenceIterator is not iterable 41 | """ 42 | dests = [] 43 | dest_iterator = code_block.getDestinations(monitor) 44 | while dest_iterator.hasNext(): 45 | dests.append(dest_iterator.next()) 46 | return dests 47 | 48 | 49 | def get_code_block_graph_map(): 50 | block_map = defaultdict(set) 51 | 52 | 53 | 54 | 55 | du = DecompUtils() 56 | -------------------------------------------------------------------------------- /find_base_by_refs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import ghidra 3 | 4 | from ghidra.program.model.symbol import FlowType, RefType 5 | from ghidra.program.model.address import 
AddressSet 6 | 7 | from __main__ import * 8 | from collections import defaultdict 9 | import logging 10 | 11 | log = logging.getLogger(__file__) 12 | log.addHandler(logging.StreamHandler()) 13 | log.setLevel(logging.WARNING) 14 | 15 | 16 | def offset_list_from_address_list(addr_list): 17 | """ 18 | Get a list of difference from each address in a list 19 | to the next entry in the list 20 | """ 21 | # sorted_addr_list = sorted(addr_list) 22 | offset_list = [] 23 | for ind in range(1, len(addr_list)): 24 | prev = addr_list[ind-1] 25 | curr = addr_list[ind] 26 | diff = curr.subtract(prev) 27 | offset_list.append(int(diff)) 28 | return offset_list 29 | 30 | 31 | addr_set = AddressSet() 32 | for m_block in getMemoryBlocks(): 33 | if not (m_block.isRead() or m_block.isWrite() or m_block.isExecute()): 34 | continue 35 | addr_set.add(m_block.getAddressRange()) 36 | 37 | 38 | refman = currentProgram.getReferenceManager() 39 | current_image_base = currentProgram.getImageBase() 40 | ref_iter = refman.getReferenceIterator(current_image_base) 41 | 42 | to_external = [] 43 | computed_call_refs = [] 44 | computed_jump_refs = [] 45 | conditional_jump_refs = [] 46 | external_data_refs = [] 47 | for ref in ref_iter: 48 | to_addr = ref.toAddress 49 | # ignore all references to locations that fit within the 50 | # currently established address space 51 | if addr_set.contains(to_addr): 52 | continue 53 | # stack references are just references to an address space that 54 | # will exist at runtime, ignore 55 | if ref.isStackReference(): 56 | continue 57 | to_external.append(ref) 58 | # a computed ref is expected to utilize an absolute address 59 | if ref.referenceType.isComputed(): 60 | if ref.referenceType.isCall(): 61 | # calls to function pointers 62 | computed_call_refs.append(ref) 63 | elif ref.referenceType.isJump(): 64 | # likely trampolines or switch-case statements 65 | if ref.referenceType.isConditional(): 66 | # expected to be switch statement 67 | conditional_jump_refs.append(ref) 68 | else: 69 | computed_jump_refs.append(ref) 70 | if ref.referenceType.isData(): 71 | external_data_refs.append(ref) 72 | 73 | 74 | computed_call_to_addrs = list(set([i.toAddress for i in computed_call_refs])) 75 | computed_call_to_addrs.sort() 76 | listing = currentProgram.getListing() 77 | 78 | established_function_entrypoints = [i.getEntryPoint() for i in currentProgram.getFunctionManager().getFunctions(1)] 79 | established_function_entrypoints.sort() 80 | 81 | 82 | established_func_offsets = offset_list_from_address_list(established_function_entrypoints) 83 | 84 | computed_addr_offsets = offset_list_from_address_list(computed_call_to_addrs) 85 | 86 | computed_addr_range = sum(computed_addr_offsets) 87 | 88 | 89 | conditional_jump_refs_by_func = defaultdict(list) 90 | for ref in conditional_jump_refs: 91 | referring_func = getFunctionContaining(ref.fromAddress) 92 | if referring_func is None: 93 | log.warning("No function for %s" % str(ref)) 94 | continue 95 | conditional_jump_refs_by_func[referring_func].append(ref) 96 | 97 | conditional_jump_refs_by_func = dict(conditional_jump_refs_by_func) 98 | conditional_jump_ref_offsets_by_func = {} 99 | for k, refs in conditional_jump_refs_by_func.items(): 100 | to_addrs = sorted([i.toAddress for i in refs]) 101 | conditional_jump_ref_offsets_by_func[k] = offset_list_from_address_list(to_addrs) 102 | # for func_addr in established_function_entrypoints: 103 | # curr_base = func_addr 104 | # for offset in computed_addr_offsets: 105 | 106 | 107 | 108 | 109 | 
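# Sketch only (not part of the original script): one way the data gathered
# above could be turned into an image-base guess. For each pairing of a
# computed call target with a known function entry point, derive a candidate
# rebase delta and count how many other computed call targets would land on
# known entry points under that delta. The helper name and scoring approach
# below are illustrative assumptions, not an established API.
def score_candidate_base_deltas(call_targets, entry_points):
    entry_offsets = set([a.getOffset() for a in entry_points])
    target_offsets = [a.getOffset() for a in call_targets]
    scores = {}
    for target in target_offsets:
        for entry in entry_offsets:
            delta = target - entry
            if delta in scores:
                continue
            # number of call targets that would map onto a known entry point
            # if the program were rebased by this delta
            hits = len([t for t in target_offsets if (t - delta) in entry_offsets])
            scores[delta] = hits
    # highest-scoring deltas first
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)

# Example use against the lists gathered above (left commented so the script's
# behavior is unchanged):
# candidates = score_candidate_base_deltas(computed_call_to_addrs,
#                                          established_function_entrypoints)
# for delta, hits in candidates[:10]:
#     log.info("delta=%#x hits=%d", delta, hits)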
-------------------------------------------------------------------------------- /find_ucmp_with_sub.py: -------------------------------------------------------------------------------- 1 | from __main__ import * 2 | from decomp_utils import DecompUtils 3 | from ghidra.program.model.pcode import PcodeOpAST 4 | from ghidra.app.decompiler.component import DecompilerUtils 5 | from collections import defaultdict 6 | import logging 7 | 8 | log = logging.getLogger(__file__) 9 | log.addHandler(logging.StreamHandler()) 10 | log.setLevel(logging.WARNING) 11 | 12 | 13 | ucmp_opcodes = [PcodeOpAST.INT_LESS, PcodeOpAST.INT_LESSEQUAL] 14 | 15 | cmpsites = defaultdict(list) 16 | du = DecompUtils() 17 | 18 | for func in currentProgram.getFunctionManager().getFunctions(1): 19 | log.debug("looking at %s" % func.name) 20 | pcode_ops = du.get_pcode_for_function(func) 21 | ucmp_insts = [i for i in pcode_ops if i.opcode in ucmp_opcodes] 22 | for inst in ucmp_insts: 23 | # TODO: this might only matter if it is the second input 24 | for inp in inst.inputs: 25 | back_slice_ops = DecompilerUtils.getBackwardSliceToPCodeOps(inp) 26 | if PcodeOpAST.INT_SUB in [i.opcode for i in back_slice_ops]: 27 | cmpsites[func].append(inst.seqnum.getTarget()) 28 | # only care to find the cmp site, so skip the next one 29 | # on success 30 | break 31 | 32 | for func, addrs in cmpsites.items(): 33 | print(func) 34 | for addr in addrs: 35 | print(addr) 36 | print("") 37 | -------------------------------------------------------------------------------- /find_unk_periphs.py: -------------------------------------------------------------------------------- 1 | from __main__ import * 2 | from decomp_utils import DecompUtils 3 | from ghidra.program.model.pcode import PcodeOpAST 4 | from ghidra.app.decompiler.component import DecompilerUtils 5 | from ghidra.program.model.symbol import FlowType, RefType 6 | from ghidra.program.model.address import AddressSet 7 | from collections import defaultdict 8 | import logging 9 | 10 | log = logging.getLogger(__file__) 11 | log.addHandler(logging.StreamHandler()) 12 | log.setLevel(logging.DEBUG) 13 | 14 | 15 | def group_by_increment(iterable, group_incr, field_access=None, do_sort=True): 16 | """ 17 | Identify series of values that increment/decrement 18 | within a bounds @group_incr, grouping them into lists. 19 | The comparison to determine whether a value belongs in a group is 20 | if (prev_val + group_incr) <= curr_val: 21 | 22 | @iterable: iterable. This must be sorted for this function to work correctly. 23 | @group_incr: amount to be added to a value to determine 24 | @field_access: optional function to run on each element of the iterable to get 25 | a value to be compared. 
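@do_sort: if True (the default), sort @iterable in place before grouping.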
26 | """ 27 | if field_access is None: 28 | field_access = lambda a: a 29 | if do_sort is True: 30 | iterable.sort() 31 | grouped = [] 32 | current = [iterable[0]] 33 | for i in range(1, len(iterable)): 34 | curr_val = field_access(iterable[i]) 35 | prev_val = field_access(current[-1]) 36 | if (prev_val + group_incr) >= curr_val: 37 | current.append(iterable[i]) 38 | else: 39 | grouped.append(current) 40 | current = [iterable[i]] 41 | if current: 42 | grouped.append(current) 43 | return grouped 44 | 45 | 46 | def bnot(n, numbits=None): 47 | if numbits is None: 48 | numbits = currentProgram.getDefaultPointerSize()*8 49 | return (1 << numbits) -1 -n 50 | 51 | 52 | def align(val, align_to, numbits=None): 53 | if numbits is None: 54 | numbits = currentProgram.getDefaultPointerSize()*8 55 | return val & bnot(align_to - 1, numbits) 56 | 57 | 58 | def align_up(val, align_to, numbits=None): 59 | if numbits is None: 60 | numbits = currentProgram.getDefaultPointerSize()*8 61 | aligned = align(val, align_to, numbits) 62 | if aligned < val: 63 | aligned += align_to 64 | return aligned 65 | 66 | 67 | class PseudoMemoryRegion: 68 | def __init__(self, values=None, ref_count=0, exec_count=0, save_values=True, start=0, end=0, align_to=0x1000, pad_end_by=0): 69 | if values is None: 70 | values = [] 71 | self.ref_count = ref_count 72 | self.exec_count = exec_count 73 | # python 2 list copy 74 | values = [i for i in values] 75 | values.sort() 76 | self.start = start 77 | self.end = end 78 | if values: 79 | self.start = values[0] 80 | self.end = values[-1] 81 | aligned_start = align(self.start, align_to) 82 | # resolves a bug with constants larger than pointer size 83 | if aligned_start != 0: 84 | self.start = aligned_start 85 | if save_values: 86 | self.values = values 87 | else: 88 | self.values = [] 89 | self.length = (self.end + pad_end_by) - self.start 90 | def __repr__(self): 91 | return "PseudoMemoryRegion(%#x-%#x, length=%#x)" % (self.start, self.end, self.length) 92 | 93 | 94 | class PeriphFinder: 95 | def __init__(self, group_incr=0x1000): 96 | self.group_incr = group_incr 97 | self.const_counts = defaultdict(lambda: 0) 98 | self.exec_const_counts = defaultdict(lambda: 0) 99 | self.const_set = set() 100 | self.pmrs = [] 101 | 102 | def find_code_accessed_consts(self): 103 | """ 104 | Find constants accessed by pcode ops 105 | """ 106 | du = DecompUtils() 107 | call_ops = [PcodeOpAST.CALL, PcodeOpAST.CALLIND] 108 | 109 | # establish groups of constants 110 | for func in currentProgram.getFunctionManager().getFunctions(1): 111 | log.debug("looking at %s" % func.name) 112 | pcode_ops = du.get_pcode_for_function(func) 113 | if pcode_ops is None: 114 | continue 115 | for op in pcode_ops: 116 | # vns = sum([list(i.inputs) for i in pcode_ops], []) 117 | vns = op.inputs 118 | for i, vn in enumerate(vns): 119 | if not vn.isConstant() and not vn.isAddress(): 120 | continue 121 | # TODO: maybe fix this is offset is negative 122 | off = vn.getOffset() 123 | self.const_counts[off] += 1 124 | self.const_set.add(off) 125 | # record if the ref was a call 126 | if op.opcode in call_ops and i == 0: 127 | self.exec_const_counts[off] += 1 128 | 129 | 130 | def find_defined_consts(self): 131 | """ 132 | Iterate through defined data and identify all consts 133 | """ 134 | listing = currentProgram.getListing() 135 | for dat in listing.getDefinedData(1): 136 | if dat.valueClass is None: 137 | val = dat.value 138 | if val is None: 139 | continue 140 | int_val = val.getUnsignedValue() 141 | self.const_set.add(int_val) 142 | 
self.const_counts[int_val] += 1 143 | if dat.isPointer(): 144 | val = dat.value 145 | if val is None: 146 | continue 147 | int_val = val.getOffsetAsBigInteger() 148 | self.const_set.add(int_val) 149 | self.const_counts[int_val] += 1 150 | # TODO: maybe handle structs, unions, and arrays 151 | 152 | def find_periphs(self): 153 | self.find_code_accessed_consts() 154 | self.find_defined_consts() 155 | sorted_consts = list(self.const_set) 156 | sorted_consts.sort() 157 | # group consts by how close they are to eachother 158 | const_groups = group_by_increment(sorted_consts, self.group_incr) 159 | pmrs = [] 160 | for g in const_groups: 161 | ref_count = sum([self.const_counts.get(i, 0) for i in g]) 162 | exec_count = sum([self.exec_const_counts.get(i, 0) for i in g]) 163 | pmrs.append(PseudoMemoryRegion(g, ref_count=ref_count, 164 | exec_count=exec_count)) 165 | ptr_size = currentProgram.getDefaultPointerSize() 166 | self.pmrs = [] 167 | for i in pmrs: 168 | if i.length == 0: 169 | continue 170 | if i.start < 0: 171 | continue 172 | if i.start.bit_length() > ptr_size*8: 173 | continue 174 | if i.end.bit_length() > ptr_size*8: 175 | continue 176 | self.pmrs.append(i) 177 | return self.pmrs 178 | 179 | 180 | def print_possible_periph_regions(): 181 | # get an address set for all current memory blocks 182 | existing_mem_addr_set = AddressSet() 183 | for m_block in getMemoryBlocks(): 184 | # cut sections that are unused 185 | if not (m_block.isRead() or m_block.isWrite() or m_block.isExecute()): 186 | continue 187 | existing_mem_addr_set.add(m_block.getAddressRange()) 188 | 189 | pf = PeriphFinder() 190 | valid_pmrs = pf.find_periphs() 191 | ptr_size = currentProgram.getDefaultPointerSize() 192 | hex_ptr_size = (ptr_size*2)+2 193 | print("start end length aligned length refcount exec") 194 | fmt = "%%#0%dx-%%#0%dx %%#010x %%#010x %%d exec=%%d" % (hex_ptr_size, hex_ptr_size) 195 | for pmr in valid_pmrs: 196 | print(fmt % (pmr.start, pmr.end, pmr.length, align_up(pmr.length, 0x1000), pmr.ref_count, pmr.exec_count)) 197 | 198 | 199 | if __name__ == "__main__": 200 | print_possible_periph_regions() 201 | -------------------------------------------------------------------------------- /find_unknown_pointers.py: -------------------------------------------------------------------------------- 1 | from __main__ import * 2 | from ghidra.program.model.address import AddressSet 3 | import struct 4 | import re 5 | import logging 6 | 7 | log = logging.getLogger(__file__) 8 | log.addHandler(logging.StreamHandler()) 9 | log.setLevel(logging.DEBUG) 10 | 11 | 12 | def applyDataTypeAtAddress(address, datatype, size=None, program=None): 13 | if program is None: 14 | program = currentProgram 15 | if size is None: 16 | size = datatype.getLength() 17 | listing = program.getListing() 18 | listing.clearCodeUnits(address, address.add(size), False) 19 | listing.createData(address, datatype, size) 20 | 21 | 22 | 23 | def gen_address_range_rexp(minimum_addr, maximum_addr, program=None): 24 | if program is None: 25 | program = currentProgram 26 | 27 | ptr_size = program.getDefaultPointerSize() 28 | mem = program.getMemory() 29 | is_big_endian = mem.isBigEndian() 30 | ptr_pack_sym = "" 31 | if ptr_size == 4: 32 | ptr_pack_sym = "I" 33 | elif ptr_size == 8: 34 | ptr_pack_sym = "Q" 35 | 36 | pack_endian = "" 37 | if is_big_endian is True: 38 | pack_endian = ">" 39 | else: 40 | pack_endian = "<" 41 | ptr_pack_code = pack_endian + ptr_pack_sym 42 | 43 | diff = maximum_addr - minimum_addr 44 | val = diff 45 | # calculate the changed 
number of bytes between the minimum_addr and the maximum_addr 46 | byte_count = 0 47 | while val > 0: 48 | val = val >> 8 49 | byte_count += 1 50 | 51 | # generate a sufficient wildcard character classes for all of the bytes that could fully c 52 | wildcard_bytes = byte_count - 1 53 | wildcard_pattern = "[\\x00-\\xff]" 54 | boundary_byte_upper = (maximum_addr >> (wildcard_bytes*8)) & 0xff 55 | boundary_byte_lower = (minimum_addr >> (wildcard_bytes*8)) & 0xff 56 | if boundary_byte_upper < boundary_byte_lower: 57 | boundary_byte_upper, boundary_byte_lower = boundary_byte_lower, boundary_byte_upper 58 | # create a character class that will match the largest changing byte 59 | # lower_byte = bytearray([boundary_byte_lower]) 60 | # upper_byte = bytearray([boundary_byte_upper]) 61 | boundary_byte_pattern = "[\\x%02x-\\x%02x]" % (boundary_byte_lower, boundary_byte_upper) 62 | address_pattern = '' 63 | single_address_pattern = '' 64 | if is_big_endian is False: 65 | packed_addr = struct.pack(ptr_pack_code, minimum_addr) 66 | single_address_pattern = ''.join([wildcard_pattern*wildcard_bytes, 67 | boundary_byte_pattern]) 68 | for i in packed_addr[byte_count:]: 69 | single_address_pattern += "\\x%02x" % ord(i) 70 | else: 71 | packed_addr = struct.pack(ptr_pack_code, minimum_addr) 72 | for i in packed_addr[:byte_count]: 73 | single_address_pattern += "\\x%02x" % ord(i) 74 | single_address_pattern = ''.join([boundary_byte_pattern, 75 | wildcard_pattern*wildcard_bytes]) 76 | address_pattern = "(%s)" % single_address_pattern 77 | return address_pattern 78 | 79 | 80 | def create_full_memory_rexp(program=None): 81 | if program is None: 82 | program = currentProgram 83 | patterns = [] 84 | # get an address set for all current memory blocks 85 | for m_block in getMemoryBlocks(): 86 | start = m_block.start.getOffsetAsBigInteger() 87 | end = m_block.end.getOffsetAsBigInteger() 88 | pat = gen_address_range_rexp(start, end) 89 | log.debug("adding pattern '%s'" % pat) 90 | patterns.append(pat) 91 | 92 | full_pat = '(%s)' % '|'.join(patterns) 93 | log.debug("full pattern '%s'" % full_pat) 94 | return full_pat 95 | 96 | 97 | def create_full_mem_addr_set(): 98 | existing_mem_addr_set = AddressSet() 99 | for m_block in getMemoryBlocks(): 100 | existing_mem_addr_set.add(m_block.getAddressRange()) 101 | return existing_mem_addr_set 102 | 103 | 104 | def find_full_mem_pointers(program=None, align_to=4): 105 | if program is None: 106 | program = currentProgram 107 | existing_mem_addr_set = create_full_mem_addr_set() 108 | full_pat = create_full_memory_rexp(program=program) 109 | for addr in findBytes(existing_mem_addr_set, full_pat, 100000, align_to, True): 110 | yield addr 111 | 112 | 113 | def identify_unknown_pointers(program=None, align_to=4): 114 | if program is None: 115 | program = currentProgram 116 | dtm = program.getDataTypeManager() 117 | ptr_dt = [i for i in dtm.getAllDataTypes() if i.name == 'pointer'][0] 118 | listing = program.getListing() 119 | for addr in find_full_mem_pointers(program=program, align_to=align_to): 120 | 121 | if addr.getOffsetAsBigInteger() % align_to != 0: 122 | continue 123 | def_code = listing.getCodeUnitContaining(addr) 124 | if def_code is not None: 125 | log.warning("match in code at %s" % addr) 126 | continue 127 | 128 | def_dat = listing.getDataContaining(addr) 129 | # skip defined data 130 | if def_dat is not None: 131 | continue 132 | log.info("found data at %s" % addr) 133 | applyDataTypeAtAddress(addr, ptr_dt) 134 | 135 | 136 | if __name__ == "__main__": 137 | 
identify_unknown_pointers() 138 | 139 | 140 | -------------------------------------------------------------------------------- /function_signature_utils.py: -------------------------------------------------------------------------------- 1 | from __main__ import * 2 | 3 | from ghidra.app.cmd.function import ApplyFunctionSignatureCmd 4 | from ghidra.program.model.data import Category 5 | from ghidra.program.model.data import CategoryPath 6 | from ghidra.program.model.data import DataType 7 | from ghidra.program.model.data import FunctionDefinitionDataType 8 | from ghidra.program.model.data import ProgramBasedDataTypeManager 9 | from ghidra.program.model.data import PointerDataType 10 | from ghidra.program.model.data import ParameterDefinition 11 | from ghidra.program.model.data import ParameterDefinitionImpl 12 | from ghidra.program.model.listing import Function 13 | from ghidra.program.model.listing import FunctionIterator 14 | from ghidra.program.model.listing import FunctionManager 15 | from ghidra.program.model.listing import FunctionSignature 16 | from ghidra.program.model.listing import Program 17 | from ghidra.program.model.symbol import SourceType 18 | from ghidra.program.database.data import DataTypeUtilities 19 | from ghidra.program.model.data import DefaultDataType 20 | from ghidra.program.model.data import MetaDataType 21 | 22 | from datatype_utils import getUndefinedRegisterSizeDatatype 23 | 24 | 25 | def getDataTypeForParam(func, param_num): 26 | sig = func.getSignature() 27 | param_ind = param_num - 1 28 | if param_ind < 0: 29 | raise Exception("param_num is too low to be valid") 30 | args = list(sig.getArguments()) 31 | if len(args) <= param_ind: 32 | return None 33 | param = args[param_ind] 34 | existing_datatype = param.getDataType() 35 | return existing_datatype 36 | 37 | 38 | def set_num_params(func, num_params, widen_undef_params=True, widen_undef_return=False, default_datatype=None, var_args=False, program=None): 39 | """ 40 | Set the number of parameters for a function. 41 | """ 42 | if program is None: 43 | program = currentProgram 44 | 45 | if default_datatype is None: 46 | default_datatype = getUndefinedRegisterSizeDatatype(program) 47 | 48 | existing_sig = func.getSignature() 49 | existing_args = list(existing_sig.getArguments()) 50 | existing_args_len = len(existing_args) 51 | # create a list of parameters 52 | params = [] 53 | for i in range(num_params): 54 | if i < existing_args_len: 55 | param = existing_args[i] 56 | if widen_undef_params is True: 57 | dt = param.getDataType() 58 | if isinstance(dt, DefaultDataType): 59 | param.setDataType(default_datatype) 60 | else: 61 | param_name = "param_%d" % (i+1) 62 | param_dt = default_datatype 63 | param_comment = "" 64 | param = ParameterDefinitionImpl(param_name, param_dt, param_comment) 65 | params.append(param) 66 | 67 | existing_sig.setArguments(params) 68 | if var_args is True: 69 | existing_args.setVarArgs(True) 70 | 71 | if widen_undef_return is True: 72 | return_type = existing_sig.getReturnType() 73 | if isinstance(return_type, DefaultDataType): 74 | existing_sig.setReturnType(default_datatype) 75 | # FunctionSignature newSignature = func_def 76 | cmd = ApplyFunctionSignatureCmd(func.getEntryPoint(), existing_sig, SourceType.USER_DEFINED) 77 | return runCommand(cmd) 78 | 79 | 80 | def set_param_datatype(func, param_num, datatype, program=None): 81 | """ 82 | Sets the Datatype for a parameter 83 | param_num is indexed from 1 and matches the param_* that can be seen in the decompiler. 
84 | """ 85 | param_ind = param_num-1 86 | if param_ind < 0: 87 | raise Exception("parameter number is too low") 88 | if program is None: 89 | program = currentProgram 90 | default_datatype = getUndefinedRegisterSizeDatatype(program) 91 | existing_sig = func.getSignature() 92 | existing_args = list(existing_sig.getArguments()) 93 | existing_args_len = len(existing_args) 94 | # create a list of parameters 95 | params = [] 96 | for i in range(max(existing_args_len, param_ind)): 97 | if i >= existing_args_len: 98 | param_name = "param_%d" % (i+1) 99 | if param_ind == i: 100 | param_dt = datatype 101 | else: 102 | param_dt = default_datatype 103 | param_comment = "" 104 | param = ParameterDefinitionImpl(param_name, param_dt, param_comment) 105 | else: 106 | param = existing_args[i] 107 | if i == param_ind: 108 | param.setDataType(datatype) 109 | params.append(param) 110 | 111 | existing_sig.setArguments(params) 112 | cmd = ApplyFunctionSignatureCmd(func.getEntryPoint(), existing_sig, SourceType.USER_DEFINED) 113 | return runCommand(cmd) 114 | -------------------------------------------------------------------------------- /java_reflection_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from __main__ import * 3 | import java 4 | 5 | """ 6 | Utilities for utilizing java's reflection capabilities from python 7 | """ 8 | 9 | 10 | def get_java_field(javaclass, field_name, check_super=True): 11 | curr_javaclass = javaclass 12 | while curr_javaclass is not None: 13 | for field in curr_javaclass.getDeclaredFields(): 14 | if field.name != field_name: 15 | continue 16 | return field 17 | if check_super is True: 18 | curr_javaclass = curr_javaclass.getSuperclass() 19 | else: 20 | break 21 | return None 22 | 23 | 24 | def get_accessible_java_field(javaclass, field_name, check_super=True): 25 | """ 26 | """ 27 | field = get_java_field(javaclass, field_name, check_super=check_super) 28 | if field is not None: 29 | field.setAccessible(True) 30 | return field 31 | 32 | 33 | def satisfies_parameter_constraints(method_or_constr, constraints): 34 | if len(constraints) > 0: 35 | max_constraint_ind = max(constraints.keys()) 36 | else: 37 | max_constraint_ind = -1 38 | # if there are constraints and the constraints 39 | # are outside of the bounds of the existing parameters, 40 | # this method can't be the one being looked for 41 | param_count = method_or_constr.getParameterCount() 42 | if max_constraint_ind > param_count-1: 43 | return False 44 | # check param constraints 45 | if param_count <= 0: 46 | param_types_arr = method_or_constr.getParameterTypes() 47 | for param_ind, expected_param_type in constraints.items(): 48 | # check each parameter constraint present 49 | if param_types_arr[param_ind] != expected_param_type: 50 | return False 51 | # if there are no parameter constraints then by default 52 | # the constraints are satisfied 53 | return True 54 | 55 | 56 | def get_java_method_by_param_constraints(javaclass, method_name, constraints=None, check_super=True): 57 | if constraints is None: 58 | constraints = {} 59 | 60 | # remove return constraint value if it exists 61 | return_constraint = constraints.get(-1) 62 | if return_constraint is not None: 63 | constraints.pop(-1) 64 | curr_javaclass = javaclass 65 | while curr_javaclass is not None: 66 | for method in curr_javaclass.getDeclaredMethods(): 67 | if method.name != method_name: 68 | continue 69 | if satisfies_parameter_constraints(method, constraints) is False: 70 | continue 71 | # check return 
constraints 72 | if return_constraint is not None: 73 | return_type = method.getReturnType() 74 | # TODO: confirm that there isn't a type in java like 75 | # TODO: AlwaysNull that could be returned 76 | if return_type is None: 77 | continue 78 | if return_type != return_constraint: 79 | continue 80 | # return the first method that satisfies all constraints 81 | return method 82 | if check_super is True: 83 | curr_javaclass = curr_javaclass.getSuperclass() 84 | else: 85 | break 86 | return None 87 | 88 | def get_accessible_java_method_by_param_constraints(javaclass, method_name, constraints=None, check_super=True): 89 | method = get_java_method_by_param_constraints(javaclass, 90 | method_name, 91 | constraints, 92 | check_super=check_super) 93 | if method is not None: 94 | method.setAccessible(True) 95 | return method 96 | 97 | 98 | def get_java_constructor_by_param_constraints(javaclass, constraints=None): 99 | if constraints is None: 100 | constraints = {} 101 | 102 | # There are no returns for constructors, so no return check needed 103 | for constructor in javaclass.getDeclaredConstructors(): 104 | if satisfies_parameter_constraints(constructor, constraints) is False: 105 | continue 106 | # return the first constructor that satisfies all constraints 107 | return constructor 108 | return None 109 | 110 | 111 | def get_accessible_java_constructor_by_param_constraints(javaclass, constraints=None): 112 | constructor = get_java_constructor_by_param_constraints(javaclass, 113 | constraints) 114 | if constructor is not None: 115 | constructor.setAccessible(True) 116 | return constructor 117 | 118 | 119 | def get_all_declared_fields(javaclass, ignore_object_fields=True): 120 | curr_javaclass = javaclass 121 | all_fields = [] 122 | while curr_javaclass is not None or (curr_javaclass is not None and ignore_object_fields and curr_javaclass != java.lang.Object): 123 | all_fields += list(curr_javaclass.getDeclaredFields()) 124 | curr_javaclass = curr_javaclass.getSuperclass() 125 | return all_fields 126 | 127 | 128 | def get_all_declared_methods(javaclass, ignore_object_fields=True): 129 | curr_javaclass = javaclass 130 | all_methods = [] 131 | while curr_javaclass is not None or (curr_javaclass is not None and ignore_object_fields and curr_javaclass != java.lang.Object): 132 | all_methods += list(curr_javaclass.getDeclaredMethods()) 133 | curr_javaclass = curr_javaclass.getSuperclass() 134 | return all_methods 135 | 136 | -------------------------------------------------------------------------------- /loopfinder.py: -------------------------------------------------------------------------------- 1 | 2 | from __main__ import * 3 | from ghidra.program.model.block import BasicBlockModel, CodeBlockIterator 4 | from ghidra.program.model.symbol import FlowType 5 | import ghidra.util.exception.CancelledException 6 | import ghidra.util.task.TaskMonitor 7 | from ghidra.program.model.block.graph import CodeBlockEdge, CodeBlockVertex 8 | from ghidra.graph import GDirectedGraph, GraphFactory, GraphAlgorithms 9 | import ghidra.graph.algo 10 | from ghidra.program.model.address import AddressSet 11 | from decomp_utils import DecompUtils 12 | import java 13 | 14 | 15 | def block_loops_to_self(block, monitor_inst=None): 16 | """ 17 | Check if a block jumps back to it self 18 | """ 19 | if monitor_inst is None: 20 | monitor_inst = monitor 21 | block_iter = block.getDestinations(monitor_inst) 22 | while block_iter.hasNext(): 23 | if monitor_inst.isCancelled(): 24 | break 25 | block_ref = block_iter.next() 26 | 
flow_type = block_ref.getFlowType() 27 | # TODO: indirection might be valid here, check 28 | if flow_type.isCall() is True or flow_type.isIndirect() is True: 29 | continue 30 | next_block = block_ref.getDestinationBlock() 31 | if next_block is None: 32 | continue 33 | if next_block == block: 34 | return True 35 | return False 36 | 37 | 38 | def is_addr_in_loop(addr, program=None, monitor_inst=None): 39 | """ 40 | Check if an address is in a basic loop within the current function. 41 | Untested 42 | """ 43 | if program is None: 44 | program = currentProgram 45 | if monitor_inst is None: 46 | monitor_inst = monitor 47 | bbm = BasicBlockModel(program) 48 | start_blocks = list(bbm.getCodeBlocksContaining(addr, monitor_inst)) 49 | # leave early if any of the first blocks just jump to themselves 50 | if any([block_loops_to_self(b) for b in start_blocks]) is True: 51 | return True 52 | 53 | # do a DFS to find all blocks that lead up to the starting blocks 54 | to_visit = set(start_blocks) 55 | visited = set() 56 | while to_visit: 57 | if monitor_inst.isCancelled(): 58 | break 59 | block = to_visit.pop() 60 | block_iter = block.getSources(monitor_inst) 61 | while block_iter.hasNext(): 62 | if monitor_inst.isCancelled(): 63 | break 64 | block_ref = block_iter.next() 65 | flow_type = block_ref.getFlowType() 66 | # TODO: indirection might be valid here, check 67 | if flow_type.isCall() is True or flow_type.isIndirect() is True: 68 | continue 69 | next_block = block_ref.getSourceBlock() 70 | if next_block is None: 71 | continue 72 | if next_block in visited: 73 | continue 74 | if next_block in to_visit: 75 | continue 76 | if next_block == block: 77 | continue 78 | to_visit.add(next_block) 79 | visited.add(block) 80 | 81 | back_blocks = visited 82 | 83 | # do a second DFS to get all of the blocks forward from the starting 84 | # blocks. 
Exits early if a block that leads to the start blocks is reached 85 | to_visit = set(start_blocks) 86 | visited = set() 87 | while to_visit: 88 | if monitor_inst.isCancelled(): 89 | break 90 | block = to_visit.pop() 91 | block_iter = block.getDestinations(monitor_inst) 92 | while block_iter.hasNext(): 93 | if monitor_inst.isCancelled(): 94 | break 95 | block_ref = block_iter.next() 96 | flow_type = block_ref.getFlowType() 97 | if flow_type.isCall() is True or flow_type.isIndirect() is True: 98 | continue 99 | next_block = block_ref.getDestinationBlock() 100 | if next_block is None: 101 | continue 102 | # extra check to exit early for found loops to reduce total 103 | # iterations 104 | if next_block in back_blocks: 105 | return True 106 | if next_block in visited: 107 | continue 108 | if next_block in to_visit: 109 | continue 110 | if next_block == block: 111 | continue 112 | to_visit.add(next_block) 113 | visited.add(block) 114 | # fwd_blocks = visited 115 | return False 116 | 117 | 118 | def getCodeBlockDestinations(block, monitor_inst=None): 119 | """ 120 | Get destination code blocks for a given code block 121 | """ 122 | if monitor_inst is None: 123 | monitor_inst = monitor 124 | all_dest_blocks = set() 125 | block_iter = block.getDestinations(monitor_inst) 126 | while block_iter.hasNext(): 127 | if monitor_inst.isCancelled(): 128 | break 129 | block_ref = block_iter.next() 130 | # flow_type = block_ref.getFlowType() 131 | # if flow_type.isCall() is True or flow_type.isIndirect() is True: 132 | # continue 133 | dest_block = block_ref.getDestinationBlock() 134 | if dest_block is None: 135 | continue 136 | all_dest_blocks.add(dest_block) 137 | return all_dest_blocks 138 | 139 | 140 | class Circuit(object): 141 | """ 142 | An object representing a single loop in a function 143 | """ 144 | def __init__(self, edges, cfg): 145 | self.edges = set(edges) 146 | self.vertices = set(sum([[i.getStart(), i.getEnd()] for i in self.edges], [])) 147 | self.cfg = cfg 148 | self.addr_set = self._getAddressSet() 149 | self.exit_edges = set() 150 | self.exit_vertices = set() 151 | self._findLoopExits() 152 | self.exit_target_addr_set = self._getExitTargetAddressSet() 153 | 154 | def _getAddressSet(self): 155 | """ 156 | Get an address set for all addresses within this loop 157 | """ 158 | addr_set = AddressSet() 159 | for v in self.vertices: 160 | for rang in v.getCodeBlock().getAddressRanges(): 161 | addr_set.add(rang) 162 | return addr_set 163 | 164 | def _getExitTargetAddressSet(self): 165 | """ 166 | Get an address set for all addresses that can be jumped to 167 | when exiting the loop 168 | """ 169 | addr_set = AddressSet() 170 | for e in self.exit_edges: 171 | end_v = e.getEnd() 172 | end_cb = end_v.getCodeBlock() 173 | for rang in end_cb.getAddressRanges(): 174 | addr_set.add(rang) 175 | return addr_set 176 | 177 | def _findLoopExits(self): 178 | """ 179 | Find Vertices and edges that can exit the circuit 180 | """ 181 | self.exit_vertices = set() 182 | self.exit_edges = set() 183 | for v in self.vertices: 184 | has_exit = False 185 | # TODO: does this handle verts returning out of the function? 
186 | # TODO: if not, use getCodeBlockDestinations 187 | for e in self.cfg.getOutEdges(v): 188 | if e not in self.edges: 189 | self.exit_edges.add(e) 190 | has_exit = True 191 | if has_exit is True: 192 | self.exit_vertices.add(v) 193 | 194 | def get_function(self): 195 | """ 196 | Get the function that this loop is a part of 197 | """ 198 | # get Any Vertex from the cfg 199 | vert = list(self.cfg.getVertices())[0] 200 | return getFunctionContaining(vert.getCodeBlock().getStartAddresses()[0]) 201 | 202 | def get_loop_exiting_pcode_blocks(self): 203 | """ 204 | Returns a list of tuples containing pcode basic blocks and the 205 | exit type the loop (whether the condition in the loop has to be true or 206 | false to exit the loop) 207 | """ 208 | func = self.get_function() 209 | du = DecompUtils() 210 | pcode_blocks = du.get_pcode_blocks_for_function(func) 211 | loop_exiting_blocks = [] 212 | for exit_edge in self.exit_edges: 213 | exit_type = False 214 | exit_v = exit_edge.getStart() 215 | exit_jmp_target_v = exit_edge.getEnd() 216 | jmp_target_start_addrs = exit_jmp_target_v.getCodeBlock().getStartAddresses() 217 | v_start_addrs = exit_v.getCodeBlock().getStartAddresses() 218 | # should realistically only be one block, but these are the pcode 219 | # basic blocks that could exit the loop 220 | exit_pcode_blocks = [b for b in pcode_blocks if b.start in v_start_addrs] 221 | for b in exit_pcode_blocks: 222 | try: 223 | out = b.getFalseOut() 224 | except java.lang.IndexOutOfBoundsException: 225 | # FIXME: this is probably because a call or return 226 | # FIXME: to/from a function called recursively 227 | print("oob exception in %s" % func.name) 228 | out = None 229 | if out is not None and not self.addr_set.contains(out.start): 230 | exit_type = False 231 | loop_exiting_blocks.append((b, out, exit_type)) 232 | 233 | try: 234 | out = b.getTrueOut() 235 | except java.lang.IndexOutOfBoundsException: 236 | # FIXME: this is probably because a call or return 237 | # FIXME: to/from a function called recursively 238 | print("oob exception in %s" % func.name) 239 | out = None 240 | if out is not None and not self.addr_set.contains(out.start): 241 | exit_type = True 242 | loop_exiting_blocks.append((b, out, exit_type)) 243 | return loop_exiting_blocks 244 | 245 | 246 | class CircuitCollection(object): 247 | """ 248 | A simple class to hold loops and success status 249 | """ 250 | def __init__(self, cfg=None): 251 | # this is false when the circuit finding takes too long 252 | # private boolean complete 253 | # private Set allCircuits = new HashSet<>() 254 | # private Map> circuitsByVertex = new HashMap<>() 255 | self.complete = False 256 | # a series of sets, each set representing all of the edges in a loop 257 | self.allCircuits = set() 258 | self.circuitsByVertex = {} 259 | self.cfg = cfg 260 | # a set of Circuit objects 261 | self.circuitObjs = set() 262 | 263 | def clear(self): 264 | self.allCircuits = set() 265 | self.circuitsByVertex = {} 266 | 267 | def addCircuitEdges(self, edges): 268 | self.allCircuits.add(edges) 269 | circ = Circuit(edges, self.cfg) 270 | self.circuitObjs.add(circ) 271 | 272 | 273 | class LoopFinder(object): 274 | """ 275 | Based on AbstractModularizationCmd.java 276 | """ 277 | def __init__(self, program=None, monitor_inst=None): 278 | if program is None: 279 | program = currentProgram 280 | if monitor_inst is None: 281 | monitor_inst = monitor 282 | self.monitor = monitor_inst 283 | self.bbm = BasicBlockModel(program) 284 | 285 | def createCFGForFunc(self, func): 286 | """ 
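Build a directed graph (CodeBlockVertex nodes joined by CodeBlockEdge) of the basic blocks in @func's body.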
287 | returns GDirectedGraph 288 | """ 289 | return self.createCFGForAddressSet(func.body) 290 | 291 | def createCFGForAddressSet(self, address_set): 292 | """ 293 | returns GDirectedGraph 294 | """ 295 | self.validAddresses = address_set 296 | # Map instanceMap = new HashMap<>() 297 | instanceMap = {} 298 | # GDirectedGraph graph = GraphFactory.createDirectedGraph() 299 | graph = GraphFactory.createDirectedGraph() 300 | # CodeBlockIterator codeBlocks = getCallGraphBlocks() 301 | codeBlocks = self.bbm.getCodeBlocksContaining(self.validAddresses, self.monitor) 302 | while codeBlocks.hasNext(): 303 | block = codeBlocks.next() 304 | 305 | # CodeBlockVertex fromVertex = instanceMap.get(block) 306 | fromVertex = instanceMap.get(block) 307 | if fromVertex is None: 308 | fromVertex = CodeBlockVertex(block) 309 | instanceMap[block] = fromVertex 310 | graph.addVertex(fromVertex) 311 | 312 | # destinations section 313 | self.addEdgesForDestinations(graph, fromVertex, block, instanceMap) 314 | return graph 315 | 316 | def addEdgesForDestinations(self, graph, fromVertex, sourceBlock, instanceMap): 317 | """ 318 | GDirectedGraph graph, 319 | CodeBlockVertex fromVertex, 320 | CodeBlock sourceBlock, 321 | Map instanceMap 322 | """ 323 | 324 | # CodeBlockReferenceIterator iterator = sourceBlock.getDestinations(monitor) 325 | iterator = sourceBlock.getDestinations(self.monitor) 326 | while iterator.hasNext(): 327 | self.monitor.checkCancelled() 328 | 329 | # CodeBlockReference destination = iterator.next() 330 | destination = iterator.next() 331 | # CodeBlock targetBlock = getDestinationBlock(destination) 332 | targetBlock = self.getDestinationBlock(destination) 333 | if targetBlock is None: 334 | continue # # no block found 335 | 336 | # CodeBlockVertex targetVertex = instanceMap.get(targetBlock) 337 | targetVertex = instanceMap.get(targetBlock) 338 | if targetVertex is None: 339 | targetVertex = CodeBlockVertex(targetBlock) 340 | instanceMap[targetBlock] = targetVertex 341 | 342 | graph.addVertex(targetVertex) 343 | graph.addEdge(CodeBlockEdge(fromVertex, targetVertex)) 344 | 345 | def getDestinationBlock(self, destination): 346 | """ 347 | CodeBlockReference destination 348 | returns CodeBlock 349 | """ 350 | targetAddress = destination.getDestinationAddress() 351 | # CodeBlock targetBlock = self.bbm.getFirstCodeBlockContaining(targetAddress, monitor) 352 | targetBlock = self.bbm.getFirstCodeBlockContaining(targetAddress, self.monitor) 353 | if targetBlock is None: 354 | return None # # no code found for call external? 355 | 356 | blockAddress = targetBlock.getFirstStartAddress() 357 | if self.skipAddress(blockAddress): 358 | return None 359 | 360 | return targetBlock 361 | 362 | def skipAddress(self, address): 363 | """ 364 | Address address 365 | returns boolean 366 | """ 367 | # if (processEntireProgram): 368 | # return False 369 | return not self.validAddresses.contains(address) 370 | 371 | 372 | def getDominanceGraph(visualGraph, forward): 373 | """ 374 | VisualGraph visualGraph, 375 | boolean forward 376 | returns GDirectedGraph 377 | """ 378 | # Set sources 379 | sources = GraphAlgorithms.getSources(visualGraph) 380 | if len(sources) != 0: 381 | return visualGraph 382 | 383 | return None 384 | 385 | 386 | class GraphPathHelper(object): 387 | """ 388 | Based on VisualGraphPathHighlighter.java, but without swing. 
389 | Helps with working with graphs for things like finding loops in 390 | a cfg 391 | """ 392 | def __init__(self, graph, program=None, monitor_inst=None): 393 | self.graph = graph 394 | if program is None: 395 | program = currentProgram 396 | if monitor_inst is None: 397 | monitor_inst = monitor 398 | self.program = program 399 | self.monitor = monitor_inst 400 | # Map> 401 | self.forwardScopedFlowEdgeCache = {} 402 | self.reverseScopedFlowEdgeCache = {} 403 | self.forwardFlowEdgeCache = {} 404 | 405 | def getForwardScopedFlowEdgesForVertexAsync(self, v): 406 | """ 407 | returns Set 408 | """ 409 | if v is None: 410 | return None 411 | 412 | # Set flowEdges = self.forwardScopedFlowEdgeCache.get(v) 413 | flowEdges = self.forwardScopedFlowEdgeCache.get(v) 414 | if flowEdges is None: 415 | flowEdges = self.findForwardScopedFlowAsync(v) 416 | self.forwardScopedFlowEdgeCache[v] = flowEdges 417 | # return Collections.unmodifiableSet(flowEdges) 418 | return flowEdges 419 | 420 | def getForwardFlowEdgesForVertexAsync(self, v): 421 | """ 422 | returns Set 423 | """ 424 | return self.getFlowEdgesForVertexAsync(True, self.forwardFlowEdgeCache, v) 425 | 426 | def getReverseFlowEdgesForVertexAsync(self, v): 427 | """ 428 | returns Set 429 | """ 430 | return self.getFlowEdgesForVertexAsync(False, self.reverseFlowEdgeCache, v) 431 | 432 | def getFlowEdgesForVertexAsync(self, isForward, cache, v): 433 | """ 434 | boolean isForward, Map> cache, V v 435 | returns Set 436 | """ 437 | 438 | if v is None: 439 | return None 440 | 441 | # Set flowEdges = cache.get(v) 442 | flowEdges = cache.get(v) 443 | if flowEdges is None: 444 | flowEdges = set() 445 | # Set pathsToVertex 446 | pathsToVertex = GraphAlgorithms.getEdgesFrom(self.graph, v, isForward) 447 | # flowEdges.addAll(pathsToVertex) 448 | flowEdges = flowEdges.union(pathsToVertex) 449 | cache[v] = flowEdges 450 | # return Collections.unmodifiableSet(flowEdges) 451 | return flowEdges 452 | 453 | def getAllCircuitFlowEdgesAsync(self): 454 | """ 455 | returns Set 456 | """ 457 | # # CompletableFuture future = lazyCreateCircuitFuture() 458 | # future = self.lazyCreateCircuitFuture() 459 | # # Circuits circuits = getAsync(future) # blocking operation 460 | # circuits = getAsync(future) # blocking operation 461 | circuits = self.calculateCircuits() 462 | if circuits is None: 463 | return set() # can happen during dispose 464 | # return Collections.unmodifiableSet(circuits.allCircuits) 465 | return set(circuits.allCircuits) 466 | 467 | def getReverseScopedFlowEdgesForVertexAsync(self, v): 468 | """ 469 | returns Set 470 | """ 471 | if v is None: 472 | return None 473 | 474 | # Set flowEdges = self.reverseScopedFlowEdgeCache.get(v) 475 | flowEdges = self.reverseScopedFlowEdgeCache.get(v) 476 | if flowEdges is None: 477 | flowEdges = self.findReverseScopedFlowAsync(v) 478 | self.reverseScopedFlowEdgeCache[v] = flowEdges 479 | # return Collections.unmodifiableSet(flowEdges) 480 | return set(flowEdges) 481 | 482 | def getCircuitEdgesAsync(self, v): 483 | """ 484 | returns Set 485 | """ 486 | 487 | if v is None: 488 | return None 489 | # # CompletableFuture future 490 | # future = self.lazyCreateCircuitFuture() 491 | # # Circuits circuits = getAsync(future) # blocking operation 492 | # circuits = getAsync(future) # blocking operation 493 | 494 | circuits = self.calculateCircuits() 495 | if circuits is None: 496 | return set() # can happen during dispose 497 | # Set 498 | circ_set = circuits.circuitsByVertex.get(v) 499 | if circ_set is None: 500 | return set() 501 | 
return set(circ_set) 502 | 503 | def calculateCircuits(self): 504 | """ 505 | TaskMonitor monitor 506 | returns CircuitCollection 507 | """ 508 | 509 | # Circuits result = new Circuits() 510 | result = CircuitCollection(self.graph) 511 | 512 | self.monitor.setMessage("Finding all loops") 513 | # Set> strongs 514 | strongs = GraphAlgorithms.getStronglyConnectedComponents(self.graph) 515 | # Set vertices 516 | for vertices in strongs: 517 | if self.monitor.isCancelled(): 518 | return result 519 | 520 | # removed to allow self-looping blocks 521 | # if len(vertices) == 1: 522 | # continue 523 | # GDirectedGraph subGraph 524 | subGraph = GraphAlgorithms.createSubGraph(self.graph, vertices) 525 | # Collection edges 526 | edges = subGraph.getEdges() 527 | if edges: 528 | result.addCircuitEdges(edges) 529 | # HashSet asSet 530 | asSet = set(edges) 531 | # Collection subVertices 532 | subVertices = subGraph.getVertices() 533 | # V v 534 | for v in subVertices: 535 | if self.monitor.isCancelled(): 536 | return result 537 | result.circuitsByVertex[v] = asSet 538 | 539 | result.complete = True 540 | return result 541 | 542 | def pathToEdgesAsync(self, path): 543 | """ 544 | List path 545 | returns List 546 | """ 547 | results = [] 548 | # Iterator it 549 | it = iter(path) 550 | from_v = it.next() 551 | while it.hasNext(): 552 | to = it.next() 553 | e = self.graph.findEdge(from_v, to) 554 | results.append(e) 555 | from_v = to 556 | return results 557 | 558 | ''' 559 | def findForwardScopedFlowAsync(self, v): 560 | """ 561 | V v 562 | returns Set 563 | """ 564 | 565 | # CompletableFuture> future = lazyCreateDominaceFuture() 566 | future = self.lazyCreateDominaceFuture() 567 | 568 | # GDirectedGraph dominanceGraph = getDominanceGraph(self.graph, True) 569 | dominanceGraph = getDominanceGraph(self.graph, True) 570 | 571 | try: 572 | # ChkDominanceAlgorithm dominanceAlgorithm = getAsync(future) 573 | dominanceAlgorithm = getAsync(future) 574 | 575 | if dominanceAlgorithm is not None: # null implies timeout 576 | # Set dominated 577 | dominated = dominanceAlgorithm.getDominated(v) 578 | return GraphAlgorithms.retainEdges(self.graph, dominated) 579 | 580 | except: 581 | pass 582 | # handled below 583 | 584 | # use the empty set so we do not repeatedly attempt to calculate these paths 585 | return set() 586 | ''' 587 | 588 | ''' 589 | def findReverseScopedFlowAsync(self, v): 590 | """ 591 | V v 592 | returns Set 593 | """ 594 | # CompletableFuture> future = lazyCreatePostDominanceFuture() 595 | future = self.lazyCreatePostDominanceFuture() 596 | 597 | try: 598 | # ChkDominanceAlgorithm postDominanceAlgorithm = getAsync(future) 599 | postDominanceAlgorithm = getAsync(future) 600 | 601 | if postDominanceAlgorithm is not None: # null implies timeout 602 | # Set dominated 603 | dominated = postDominanceAlgorithm.getDominated(v) 604 | return GraphAlgorithms.retainEdges(self.graph, dominated) 605 | except: 606 | pass 607 | # handled below 608 | 609 | # use the empty set so we do not repeatedly attempt to calculate these paths 610 | return set() 611 | ''' 612 | 613 | ''' 614 | def calculatePathsBetweenVerticesAsync(self, V v1, V v2) { 615 | """ 616 | V v1, V v2 617 | """ 618 | if v1.equals(v2): 619 | return 620 | 621 | # CallbackAccumulator> accumulator = new CallbackAccumulator<>(path -> { 622 | accumulator = new CallbackAccumulator<>(path -> { 623 | 624 | Collection edges = pathToEdgesAsync(path) 625 | SystemUtilities.runSwingLater(() -> setInHoverPathOnSwing(edges)) 626 | }) 627 | 628 | TaskMonitor 
timeoutMonitor = TimeoutTaskMonitor.timeoutIn(ALGORITHM_TIMEOUT, 629 | TimeUnit.SECONDS, new TaskMonitorAdapter(true)) 630 | 631 | try { 632 | GraphAlgorithms.findPaths(self.graph, v1, v2, accumulator, timeoutMonitor) 633 | } 634 | catch (ConcurrentModificationException e) { 635 | # TODO temp fix for 8.0. 636 | # This exception can happen when the current graph is being mutated off of the 637 | # Swing thread, such as when grouping and ungrouping. For now, squash the 638 | # problem, as it is only a UI feature. Post-"big graph branch merge", update 639 | # how we schedule this task in relation to background graph jobs (maybe just make 640 | # this task a job) 641 | } 642 | catch (CancelledException e) { 643 | SystemUtilities.runSwingLater( 644 | () -> setStatusTextSwing("Path computation halted by user or timeout.\n" + 645 | "Paths shown in graph are not complete!")) 646 | } 647 | 648 | } 649 | ''' 650 | 651 | ''' 652 | def lazyCreateDominaceFuture(self): 653 | """ 654 | returns CompletableFuture> 655 | """ 656 | 657 | # lazy-load 658 | if dominanceFuture is not None: 659 | return dominanceFuture 660 | 661 | # we use an executor to restrict thread usage by the Graph API 662 | # Executor executor = getGraphExecutor() 663 | executor = getGraphExecutor() 664 | dominanceFuture = CompletableFuture.supplyAsync(() -> { 665 | 666 | # this operation is fast enough that it shouldn't timeout, but just in case... 667 | TaskMonitor timeoutMonitor = TimeoutTaskMonitor.timeoutIn(ALGORITHM_TIMEOUT, 668 | TimeUnit.SECONDS, new TaskMonitorAdapter(true)) 669 | 670 | GDirectedGraph dominanceGraph = getDominanceGraph(self.graph, true) 671 | if (dominanceGraph is None: 672 | Msg.debug(this, "No sources found for graph cannot calculate dominance: " + 673 | self.graph.getClass().getSimpleName()) 674 | return null 675 | } 676 | 677 | try { 678 | # note: calling the constructor performs the work 679 | # return new ChkDominanceAlgorithm<>(dominanceGraph, timeoutMonitor) 680 | return ghidra.graph.algo.ChkDominanceAlgorithm(dominanceGraph, timeoutMonitor) 681 | } 682 | catch (CancelledException e) { 683 | # shouldn't happen 684 | Msg.debug(VisualGraphPathHighlighter.this, 685 | "Domiance calculation timed-out for " + self.graph.getClass().getSimpleName()) 686 | } 687 | return null 688 | }, executor) 689 | return dominanceFuture 690 | } 691 | ''' 692 | 693 | 694 | ''' 695 | private CompletableFuture> lazyCreatePostDominanceFuture() { 696 | 697 | # lazy-load 698 | if (postDominanceFuture is not None: 699 | return postDominanceFuture 700 | } 701 | 702 | Executor executor = getGraphExecutor() 703 | postDominanceFuture = CompletableFuture.supplyAsync(() -> { 704 | 705 | # this operation is fast enough that it shouldn't timeout, but just in case... 
706 | TaskMonitor timeoutMonitor = TimeoutTaskMonitor.timeoutIn(ALGORITHM_TIMEOUT, 707 | TimeUnit.SECONDS, new TaskMonitorAdapter(true)) 708 | 709 | try { 710 | # note: calling the constructor performs the work 711 | return new ChkPostDominanceAlgorithm<>(self.graph, timeoutMonitor) 712 | } 713 | catch (CancelledException e) { 714 | # shouldn't happen 715 | Msg.debug(VisualGraphPathHighlighter.this, 716 | "Post-domiance calculation timed-out for " + self.graph.getClass().getSimpleName()) 717 | } 718 | return null 719 | }, executor) 720 | return postDominanceFuture 721 | } 722 | ''' 723 | -------------------------------------------------------------------------------- /pointer_utils.py: -------------------------------------------------------------------------------- 1 | # Utility Class for interacting with pointers. To get an instance of the class, 2 | # use createPointerUtils 3 | #@author Clifton Wolfe 4 | #@category Utils 5 | 6 | import ghidra 7 | from ghidra.program.model.symbol import SourceType 8 | from ghidra.program.model.address import GenericAddress, Address 9 | 10 | import string 11 | import re 12 | import struct 13 | # makes it easier for dev and testing 14 | from __main__ import * 15 | 16 | 17 | def compile_byte_rexp_pattern(pattern): 18 | """ 19 | Compile a pattern so that it can be searched in a series of bytes 20 | """ 21 | return re.compile(pattern, re.DOTALL | re.MULTILINE) 22 | 23 | 24 | def get_memory_bounds(excluded_memory_block_names=["tdb"]): 25 | """ 26 | Try to identify the bounds of memory that is currently mapped in. 27 | Some standard memory blocks (like `tdb` for microsoft binaries) 28 | are mapped in at ridiculous addresses (like 0xff00000000000000). 29 | If a memory block is mapped into this program 30 | """ 31 | minimum_addr = 0xffffffffffffffff 32 | maximum_addr = 0 33 | memory_blocks = list(getMemoryBlocks()) 34 | for m_block in memory_blocks: 35 | # tdb is placed at a very large address that is well outside 36 | # of the loaded range for most executables 37 | if m_block.name in excluded_memory_block_names: 38 | continue 39 | start = m_block.getStart().getOffset() 40 | end = m_block.getEnd().getOffset() 41 | if start < minimum_addr: 42 | minimum_addr = start 43 | if end > maximum_addr: 44 | maximum_addr = end 45 | return minimum_addr, maximum_addr 46 | 47 | 48 | def search_memory_for_rexp(rexp, save_match_objects=True): 49 | """ 50 | Given a regular expression, search through all of the program's 51 | memory blocks for it and return a list of addresses where it was found, 52 | as well as a list of the match objects. 
Set `save_match_objects` to 53 | False if you are searching for exceptionally large objects and 54 | don't want to keep the matches around 55 | """ 56 | memory_blocks = list(getMemoryBlocks()) 57 | search_memory_blocks = memory_blocks 58 | # TODO: maybe implement filters for which blocks get searched 59 | # filter out which memory blocks should actually be searched 60 | # search_memory_blocks = [i for i in search_memory_blocks 61 | # if i.getPermissions() == i.READ] 62 | # if additional_search_block_filter is not None: 63 | # search_memory_blocks = [i for i in search_memory_blocks if 64 | # additional_search_block_filter(i) is True] 65 | all_match_addrs = [] 66 | all_match_objects = [] 67 | for m_block in search_memory_blocks: 68 | if not m_block.isInitialized(): 69 | continue 70 | region_start = m_block.getStart() 71 | region_start_int = region_start.getOffset() 72 | search_bytes = getBytes(region_start, m_block.getSize()) 73 | iter_gen = re.finditer(rexp, search_bytes) 74 | match_count = 0 75 | # hacky loop over matches so that the recursion limit can be caught 76 | while True: 77 | try: 78 | m = next(iter_gen) 79 | except StopIteration: 80 | # this is where the loop is normally supposed to end 81 | break 82 | except RuntimeError: 83 | # this means that recursion went too deep 84 | print("match hit recursion limit on match %d" % match_count) 85 | break 86 | match_count += 1 87 | location_addr = region_start.add(m.start()) 88 | all_match_addrs.append(location_addr) 89 | if save_match_objects: 90 | all_match_objects.append(m) 91 | return all_match_addrs, all_match_objects 92 | 93 | 94 | def batch_pattern_memory_search(patterns, batchsize=100, save_match_objects=True): 95 | """ 96 | Works similar to search_memory_for_rexp, but supports running a list of patterns in batches 97 | so that python doesn't have to run a 500,000 character regular expression. 98 | """ 99 | def batch(it, sz): 100 | for i in range(0, len(it), sz): 101 | yield it[i:i+sz] 102 | 103 | all_match_addrs = [] 104 | all_match_objects = [] 105 | for pattern_batch in batch(patterns, batchsize): 106 | joined_pattern = b'(%s)' % b'|'.join(pattern_batch) 107 | rexp = compile_byte_rexp_pattern(joined_pattern) 108 | match_addrs, match_obj = search_memory_for_rexp(rexp, save_match_objects=save_match_objects) 109 | all_match_addrs.extend(match_addrs) 110 | all_match_objects.extend(match_obj) 111 | return all_match_addrs, all_match_objects 112 | 113 | 114 | class PointerUtils: 115 | def __init__(self, ptr_size=8, endian="little"): 116 | self.ptr_size = ptr_size 117 | if endian.lower() in ["big", "msb", "be"]: 118 | self.endian = "big" 119 | self.is_big_endian = True 120 | elif endian.lower() in ["little", "lsb", "le"]: 121 | self.endian = "little" 122 | self.is_big_endian = False 123 | 124 | self.ptr_pack_sym = "" 125 | if self.ptr_size == 4: 126 | self.ptr_pack_sym = "I" 127 | elif self.ptr_size == 8: 128 | self.ptr_pack_sym = "Q" 129 | 130 | self.pack_endian = "" 131 | if self.is_big_endian is True: 132 | self.pack_endian = ">" 133 | else: 134 | self.pack_endian = "<" 135 | self.ptr_pack_code = self.pack_endian + self.ptr_pack_sym 136 | 137 | def generate_address_range_pattern(self, minimum_addr, maximum_addr): 138 | """ 139 | Generate a regular expression pattern that can be used to match 140 | the bytes for an address between minimum_addr and maximum_addr 141 | (inclusive). 
This works best for small ranges, and will break 142 | somewhat if there are non-contiguous memory blocks, but it works 143 | well enough for most things 144 | """ 145 | diff = maximum_addr - minimum_addr 146 | val = diff 147 | # calculate the changed number of bytes between the minimum_addr and the maximum_addr 148 | byte_count = 0 149 | while val > 0: 150 | val = val >> 8 151 | byte_count += 1 152 | 153 | # generate a sufficient wildcard character classes for all of the bytes that could fully change 154 | wildcard_bytes = byte_count - 1 155 | wildcard_pattern = b"[\x00-\xff]" 156 | boundary_byte_upper = (maximum_addr >> (wildcard_bytes*8)) & 0xff 157 | boundary_byte_lower = (minimum_addr >> (wildcard_bytes*8)) & 0xff 158 | if boundary_byte_upper < boundary_byte_lower: 159 | boundary_byte_upper, boundary_byte_lower = boundary_byte_lower, boundary_byte_upper 160 | # create a character class that will match the largest changing byte 161 | lower_byte = bytearray([boundary_byte_lower]) 162 | upper_byte = bytearray([boundary_byte_upper]) 163 | # re.escape breaks depending on version of python, 164 | # converting bytes to strings. instead, manually escape 165 | # TODO: add a test case for this to make sure that python 166 | # TODO: isn't matching against the backslash for the end byte 167 | escaped_lower_byte = re.escape(lower_byte) 168 | escaped_lower_byte = bytearray(escaped_lower_byte) 169 | escaped_upper_byte = re.escape(upper_byte) 170 | escaped_upper_byte = bytearray(escaped_upper_byte) 171 | boundary_byte_pattern = b"[%s-%s]" % (escaped_lower_byte, 172 | escaped_upper_byte) 173 | address_pattern = b'' 174 | single_address_pattern = b'' 175 | if self.is_big_endian is False: 176 | packed_addr = struct.pack(self.ptr_pack_code, minimum_addr) 177 | single_address_pattern = b''.join([wildcard_pattern*wildcard_bytes, 178 | boundary_byte_pattern, 179 | packed_addr[byte_count:]]) 180 | else: 181 | packed_addr = struct.pack(self.ptr_pack_code, minimum_addr) 182 | single_address_pattern = b''.join([packed_addr[:byte_count], 183 | boundary_byte_pattern, 184 | wildcard_pattern*wildcard_bytes]) 185 | address_pattern = b"(%s)" % single_address_pattern 186 | return address_pattern 187 | 188 | def generate_address_range_rexp(self, minimum_addr, maximum_addr): 189 | """ 190 | Generate a regular expression that can match on any value between 191 | the provided minimum addr and maximum addr 192 | """ 193 | address_pattern = self.generate_address_range_pattern(minimum_addr, maximum_addr) 194 | address_rexp = compile_byte_rexp_pattern(address_pattern) 195 | return address_rexp 196 | 197 | def ptr_ints_from_bytearray(self, bytarr): 198 | """ 199 | Returns a tuple of poitner-sized ints unpacked from the provided 200 | bytearray 201 | """ 202 | bytarr = bytearray(bytarr) 203 | # truncate in case the bytarray isn't aligned to ptr size 204 | fit_len = len(bytarr) // self.ptr_size 205 | pack_code = "%s%d%s" % (self.pack_endian, fit_len, self.ptr_pack_sym) 206 | return struct.unpack_from(pack_code, bytarr) 207 | 208 | def gen_pattern_for_pointer(self, pointer): 209 | """ 210 | Generate a regular expression pattern for a pointer 211 | """ 212 | if isinstance(pointer, GenericAddress): 213 | pointer = pointer.getOffsetAsBigInteger() 214 | 215 | pointer_bytes = struct.pack(self.ptr_pack_code, pointer) 216 | pointer_pattern = re.escape(pointer_bytes) 217 | return pointer_pattern 218 | 219 | def search_for_pointer(self, pointer): 220 | """ 221 | Find all locations where a specific pointer is embedded in memory 222 | """ 
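# build an escaped byte pattern for the packed pointer value, compile it,
# then scan every initialized memory block for matches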
223 | pointer_pattern = self.gen_pattern_for_pointer(pointer) 224 | address_rexp = compile_byte_rexp_pattern(pointer_pattern) 225 | match_addrs, _ = search_memory_for_rexp(address_rexp) 226 | return match_addrs 227 | 228 | 229 | def createPointerUtils(program=None, ptr_size=None, endian=None): 230 | if program is None: 231 | program = currentProgram 232 | if ptr_size is None: 233 | ptr_size = program.getDefaultPointerSize() 234 | if endian is None: 235 | mem = program.getMemory() 236 | if mem.isBigEndian(): 237 | endian = "big" 238 | else: 239 | endian = "little" 240 | pu = PointerUtils(ptr_size, endian) 241 | return pu 242 | -------------------------------------------------------------------------------- /print_indexing_locations.py: -------------------------------------------------------------------------------- 1 | # Print all of the locations in the binary where indexing can be detected 2 | # 3 | 4 | from __main__ import * 5 | from decomp_utils import find_all_pcode_op_instances 6 | from ghidra.program.model.pcode import PcodeOpAST 7 | 8 | funcs_with_ptradd = find_all_pcode_op_instances(PcodeOpAST.PTRADD) 9 | 10 | # for func, ptradd_addrs in funcs_with_ptradd.items(): 11 | # # print("%s" % func.name) 12 | # for addr in ptradd_addrs: 13 | # print("%s" % str(addr)) 14 | # print("") 15 | all_addrs = sum([v for v in funcs_with_ptradd.values()], []) 16 | all_addrs.sort() 17 | for addr in all_addrs: 18 | print("%s" % str(addr)) 19 | -------------------------------------------------------------------------------- /prop_dt.py: -------------------------------------------------------------------------------- 1 | # naively propagate a datatype forward to all function signatures it directly 2 | # flows to 3 | from __main__ import * 4 | 5 | from type_propagator import prop_datatype_from_func_param 6 | from decomp_utils import DecompUtils 7 | 8 | 9 | selection = currentSelection 10 | 11 | if selection: 12 | addr = selection.minAddress 13 | func = getFunctionContaining(addr) 14 | else: 15 | func_name = askString("enter function name", "enter function name") 16 | func = getFunction(func_name) 17 | 18 | param_num = askInt("Parameter (indexed from 1)", "Parameter (indexed from 1)") 19 | 20 | prop_datatype_from_func_param(func, param_num, program=currentProgram) 21 | -------------------------------------------------------------------------------- /register_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from __main__ import * 3 | 4 | 5 | def getStackRegister(program=None): 6 | if program is None: 7 | program = currentProgram 8 | return program.getCompilerSpec().getStackPointer() 9 | 10 | 11 | def getAllContainedRegisters(reg): 12 | to_visit = set([reg]) 13 | visited = set() 14 | while to_visit: 15 | r = to_visit.pop() 16 | for child in r.getChildRegisters(): 17 | if child is None: 18 | continue 19 | if child in to_visit: 20 | continue 21 | if child in visited: 22 | continue 23 | if child == r: 24 | continue 25 | to_visit.add(child) 26 | visited.add(r) 27 | return visited 28 | 29 | 30 | def getGeneralPurposeRegsToParamMapForCallingConvention(cc, program=None): 31 | """ 32 | Create a map of general purpose registers to parameter number for 33 | the provided calling convention and program 34 | """ 35 | if program is None: 36 | program = currentProgram 37 | if cc is None: 38 | cc = program.getCompilerSpec().getDefaultCallingConvention() 39 | inp_storage_locs = cc.getPotentialInputRegisterStorage(program) 40 | gpr_storage_locs = [] 41 | # get all of the 
non-vector storage locations 42 | for stor in inp_storage_locs: 43 | storage_registers = stor.getRegisters() 44 | for reg in storage_registers: 45 | base_reg = reg.getBaseRegister() 46 | if base_reg.isVectorRegister() is True: 47 | continue 48 | gpr_storage_locs.append(base_reg) 49 | # map all of the general purpose registers to the parameter that they 50 | # should fit into 51 | reg_to_param = {} 52 | for ind, base_reg in enumerate(gpr_storage_locs): 53 | for r in getAllContainedRegisters(base_reg): 54 | reg_to_param[r] = ind+1 55 | return reg_to_param 56 | 57 | 58 | GPR_TO_PARAM_MAP_CACHE = {} 59 | 60 | 61 | def getRegToParamMapForFunc(func): 62 | """ 63 | Get a map of {general_purpose_register: int(param_num)} for the given 64 | function 65 | """ 66 | cc = func.getCallingConvention() 67 | program = func.getProgram() 68 | if cc is None: 69 | cc = program.getCompilerSpec().getDefaultCallingConvention() 70 | key = (cc, program) 71 | maybe_res = GPR_TO_PARAM_MAP_CACHE.get(key) 72 | if maybe_res is None: 73 | maybe_res = getGeneralPurposeRegsToParamMapForCallingConvention(cc, program=program) 74 | GPR_TO_PARAM_MAP_CACHE[key] = maybe_res 75 | return maybe_res 76 | -------------------------------------------------------------------------------- /renamespace.py: -------------------------------------------------------------------------------- 1 | # Change namespace of selected regions 2 | #@author Clifton Wolfe 3 | #@category C++ 4 | import ghidra 5 | from ghidra.program.model.symbol import SourceType 6 | import string 7 | import re 8 | import struct 9 | # makes it easier for dev and testing 10 | from __main__ import * 11 | 12 | 13 | class PointerUtils: 14 | def __init__(self): 15 | self.addr_fact = currentProgram.getAddressFactory() 16 | self.addr_space = self.addr_fact.getDefaultAddressSpace() 17 | self.ptr_size = self.addr_space.getPointerSize() 18 | self.mem = currentProgram.getMemory() 19 | self.ptr_pack_sym = "" 20 | if self.ptr_size == 4: 21 | self.ptr_pack_sym = "I" 22 | elif self.ptr_size == 8: 23 | self.ptr_pack_sym = "Q" 24 | 25 | self.pack_endian = "" 26 | if self.mem.isBigEndian(): 27 | self.pack_endian = ">" 28 | else: 29 | self.pack_endian = "<" 30 | 31 | def ptr_ints_from_bytearray(self, bytarr): 32 | bytarr = bytearray(bytarr) 33 | # truncate in case the bytarray isn't aligned to ptr size 34 | fit_len = len(bytarr) // self.ptr_size 35 | pack_code = "%s%d%s" % (self.pack_endian, fit_len, self.ptr_pack_sym) 36 | return struct.unpack_from(pack_code, bytarr) 37 | 38 | 39 | def get_or_create_namespace(name, parent=None): 40 | if parent is None: 41 | parent = currentProgram.getGlobalNamespace() 42 | sym_tab = currentProgram.getSymbolTable() 43 | maybe_ns = sym_tab.getNamespace(name, parent) 44 | if maybe_ns: 45 | return maybe_ns 46 | maybe_ns = sym_tab.createNameSpace(parent, name, 47 | SourceType.USER_DEFINED) 48 | return maybe_ns 49 | 50 | 51 | namespace_name = askString("Enter namespace", "Enter namespace") 52 | namespace = get_or_create_namespace(namespace_name) 53 | ptr_utils = PointerUtils() 54 | 55 | all_selected_ptrs = [] 56 | for addr_range in currentSelection: 57 | start_addr = addr_range.minAddress 58 | size = addr_range.maxAddress.subtract(start_addr)+1 59 | selected_bytes = bytearray(getBytes(start_addr, size)) 60 | selected_ptrs = [toAddr(i) for i in ptr_utils.ptr_ints_from_bytearray(selected_bytes)] 61 | all_selected_ptrs.extend(selected_ptrs) 62 | 63 | for ptr in all_selected_ptrs: 64 | func = getFunctionAt(ptr) 65 | # skip functions that aren't known for now 66 | if 
func is None: 67 | continue 68 | if func.getParentNamespace() != namespace: 69 | func.setParentNamespace(namespace) 70 | 71 | 72 | -------------------------------------------------------------------------------- /tag_callback_registration.py: -------------------------------------------------------------------------------- 1 | # tag functions that are registering callback functions 2 | #@author Clifton Wolfe 3 | #@category Analysis 4 | 5 | from __main__ import * 6 | from collections import defaultdict 7 | from ghidra.program.model.symbol import SourceType 8 | 9 | 10 | def get_function_data_refs_from_funcs(program=None): 11 | if program is None: 12 | program = currentProgram 13 | data_refs_from_funcs = defaultdict(list) 14 | refman = program.getReferenceManager() 15 | for func in program.getFunctionManager().getFunctions(1): 16 | refs = refman.getReferencesTo(func.getEntryPoint()) 17 | for ref in refs: 18 | if not ref.referenceType.isData(): 19 | continue 20 | referring_func = getFunctionContaining(ref.fromAddress) 21 | if referring_func is None: 22 | continue 23 | data_refs_from_funcs[referring_func].append(ref) 24 | return dict(data_refs_from_funcs) 25 | 26 | 27 | def generate_placeholder_function_name(func, prefix): 28 | entrypoint = func.getEntryPoint() 29 | return "%s_%s" % (prefix, str(entrypoint)) 30 | 31 | 32 | def tag_callback_registration(program=None, rename_unnamed_referring_funcs=True, rename_unnamed_callback_funcs=True): 33 | if program is None: 34 | program = currentProgram 35 | function_data_refs = get_function_data_refs_from_funcs(program) 36 | 37 | for referring_func, refs in function_data_refs.items(): 38 | referring_func.addTag("CALLBACK_REGISTRATION_FUNCTION") 39 | if rename_unnamed_referring_funcs is True and referring_func.name.startswith("FUN_"): 40 | generated_name = generate_placeholder_function_name(referring_func, "registerCallback") 41 | referring_func.setName(generated_name, SourceType.USER_DEFINED) 42 | 43 | for ref in refs: 44 | referenced_func = getFunctionContaining(ref.toAddress) 45 | if referenced_func is None: 46 | continue 47 | referenced_func.addTag("CALLBACK_FUNTION") 48 | if rename_unnamed_callback_funcs is True and referenced_func.name.startswith("FUN_"): 49 | generated_name = generate_placeholder_function_name(referenced_func, "callback") 50 | referenced_func.setName(generated_name, SourceType.USER_DEFINED) 51 | 52 | 53 | 54 | if __name__ == "__main__": 55 | tag_callback_registration(currentProgram) 56 | -------------------------------------------------------------------------------- /test/Makefile: -------------------------------------------------------------------------------- 1 | 2 | define allow-override 3 | $(if $(or $(findstring environment,$(origin $(1))),\ 4 | $(findstring command line,$(origin $(1)))),,\ 5 | $(eval $(1) = $(2))) 6 | endef 7 | 8 | TOOL_PREFIX= 9 | CC=$(TOOL_PREFIX)gcc 10 | LD=$(TOOL_PREFIX)ld 11 | 12 | ifdef TOOL_PREFIX 13 | $(call allow-override,CC,$(CC)) 14 | $(call allow-override,LD,$(LD)) 15 | endif 16 | 17 | CDEBUG=-g -O0 18 | CFLAGS=-Iinclude -Isrc -shared -fPIC 19 | RELEASE_FLAGS=-O2 20 | 21 | BINARY=testlib.so 22 | 23 | BUILD_DIR=build 24 | DEBUG_DIR=debug 25 | RELEASE_DIR=release 26 | OBJ=$(BUILD_DIR) 27 | 28 | OBJECTS += src/main.o 29 | OBJECTS += src/int_under_overflow.o 30 | 31 | RELEASE_OBJECT_FILES=$(addprefix $(OBJ)/release/, $(OBJECTS)) 32 | DBG_OBJECT_FILES=$(addprefix $(OBJ)/debug/, $(OBJECTS)) 33 | 34 | .PHONY: all clean tests debug release 35 | 36 | all: debug release 37 | 38 | $(OBJ)/$(DEBUG_DIR)/%.o: 
CFLAGS += $(CDEBUG) 39 | $(OBJ)/$(DEBUG_DIR)/%.o: %.c 40 | @mkdir -p $(@D) 41 | $(CC) $(CFLAGS) -c -o $@ $< 42 | 43 | $(OBJ)/$(RELEASE_DIR)/%.o: CFLAGS += $(RELEASE_FLAGS) 44 | $(OBJ)/$(RELEASE_DIR)/%.o: %.c 45 | @mkdir -p $(@D) 46 | $(CC) $(CFLAGS) -c -o $@ $< 47 | 48 | # currently unused target 49 | $(OBJ)/%.o: %.c 50 | @mkdir -p $(@D) 51 | $(CC) $(CFLAGS) -c -o $@ $< 52 | 53 | # old unused target 54 | $(BINARY): $(OBJECT_FILES) 55 | $(CC) $(CFLAGS) -o $@ $^ 56 | 57 | $(OBJ)/$(DEBUG_DIR)/$(BINARY): $(DBG_OBJECT_FILES) 58 | @mkdir -p $(@D) 59 | $(CC) $(CFLAGS) $(CDEBUG) -o $@ $^ 60 | 61 | debug: $(OBJ)/$(DEBUG_DIR)/$(BINARY) 62 | 63 | $(OBJ)/$(RELEASE_DIR)/$(BINARY): $(RELEASE_OBJECT_FILES) 64 | @mkdir -p $(@D) 65 | $(CC) $(CFLAGS) -o $@ $^ 66 | 67 | release: $(OBJ)/$(RELEASE_DIR)/$(BINARY) 68 | 69 | clean: 70 | rm -rf $(OBJ) 2>/dev/null 71 | rm -f $(BINARY) 2>/dev/null 72 | -------------------------------------------------------------------------------- /test/include/main.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flounderK/ghidra_scripts/56f63953b451c55e4daf4850657652edc8cc5ed3/test/include/main.h -------------------------------------------------------------------------------- /test/src/int_under_overflow.c: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | int __attribute__((optimize(0))) simple_ucmp_with_sub(uint16_t out_buf_len, void* out_buf, void* in_buf, uint16_t padding, uint32_t ea_block_size) { 8 | int ret = 0; 9 | 10 | // bad check because of unsigned val 11 | if (ea_block_size <= out_buf_len - padding) { 12 | memcpy(out_buf, in_buf, out_buf_len - padding); 13 | goto exit; 14 | } 15 | ret = -1; 16 | 17 | exit: 18 | return ret; 19 | } 20 | -------------------------------------------------------------------------------- /test/src/main.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | int unconstrained_stack_memcpy_overflow(char* srcbuf, size_t count){ 11 | char dstbuf[32] = {}; 12 | memcpy(dstbuf, srcbuf, count); 13 | return 0; 14 | } 15 | 16 | int incorrect_size_checked_stack_memcpy_overflow(char* srcbuf, size_t count){ 17 | char dstbuf[32] = {}; 18 | if (count > 48) { 19 | return -1; 20 | } 21 | memcpy(dstbuf, srcbuf, count); 22 | return 0; 23 | } 24 | 25 | int unconstrained_stack_read_overflow(size_t count){ 26 | char dstbuf[32] = {}; 27 | read(0, dstbuf, count); 28 | return 0; 29 | } 30 | 31 | 32 | int variable_printf_arg_1(int choice) { 33 | const char* formatstring = NULL; 34 | switch (choice) { 35 | case 1: { 36 | formatstring = "One chosen\n"; 37 | break; 38 | } 39 | case 2: { 40 | formatstring = "Two chosen\n"; 41 | break; 42 | } 43 | case 3: { 44 | formatstring = "Three chosen\n"; 45 | break; 46 | } 47 | default: { 48 | formatstring = "DEFAULT\n"; 49 | break; 50 | } 51 | } 52 | 53 | printf(formatstring); 54 | return 0; 55 | } 56 | 57 | uint8_t g_GlobalArray[8]; 58 | uint8_t inner_global_oob_func(int ind, uint8_t val) { 59 | uint8_t oldval = g_GlobalArray[ind]; 60 | g_GlobalArray[ind] = val; 61 | 62 | return oldval; 63 | } 64 | 65 | int outer_global_oob_func(void) { 66 | inner_global_oob_func(16, 0); 67 | 68 | return 0; 69 | } 70 | 71 | char* heap_overflow_from_strlen(char* str) { 72 | int length; 73 | char* newstr; 74 | if (str == NULL) { 75 | return NULL; 76 | } 77 
| length = strlen(str); 78 | newstr = malloc(length); 79 | memset(newstr, 0, length); 80 | 81 | strcat(newstr, str); 82 | // strncat(newstr, str, length); // also causes the issue 83 | return newstr; 84 | } 85 | 86 | 87 | int suspicious_malloc_size(void) { 88 | // anything that is malloc'd with a size of a single pointer or less is very suspicous, 89 | // as there is either enough space to store a single pointer with no additional data, or 90 | // enough space to store data that isn't a part of any data structure. Unless the data is for 91 | // an elasitcally sized value (like a char* that just happens to be very small) it could probably have 92 | // been a global. 93 | void* buf = malloc(sizeof(void*)); 94 | return 0; 95 | } 96 | 97 | bool dead_store_elim(char* instr) { 98 | char buf[255]; 99 | memset(buf, 0, sizeof(buf)); 100 | bool ret = false; 101 | 102 | snprintf(buf, sizeof(buf), "%s", "blah"); 103 | 104 | if (strcmp(buf, instr) == 0) { 105 | ret = true; 106 | } 107 | memset(buf, 0, sizeof(buf)); // gets optimized out because buf isn't used any more 108 | return ret; 109 | } 110 | 111 | 112 | -------------------------------------------------------------------------------- /type_pointers_to_data.py: -------------------------------------------------------------------------------- 1 | # This script identifies locations where data values embedded in a binary 2 | # can be represented as addresses, but the type of those embedded data objects 3 | # is either undefined or is defined as a generic pointer. It then appropriatly 4 | # types them 5 | from __main__ import * 6 | from ghidra.program.model.address import AddressSet 7 | from ghidra.program.database.data import PointerDB 8 | from ghidra.program.database.data import TypedefDB 9 | from ghidra.program.model.data import TypedefDataType 10 | from ghidra.program.model.data import ComponentOffsetSettingsDefinition 11 | from ghidra.program.model.scalar import Scalar 12 | import logging 13 | 14 | log = logging.getLogger(__file__) 15 | log.addHandler(logging.StreamHandler()) 16 | log.setLevel(logging.DEBUG) 17 | 18 | 19 | def applyDataTypeAtAddress(address, datatype, size=None, program=None): 20 | if program is None: 21 | program = currentProgram 22 | if size is None: 23 | size = datatype.getLength() 24 | listing = program.getListing() 25 | listing.clearCodeUnits(address, address.add(size), False) 26 | listing.createData(address, datatype, size) 27 | 28 | 29 | # get an address set for all current memory blocks 30 | existing_mem_addr_set = AddressSet() 31 | for m_block in getMemoryBlocks(): 32 | # cut sections that are unused 33 | if not (m_block.isRead() or m_block.isWrite() or m_block.isExecute()): 34 | continue 35 | existing_mem_addr_set.add(m_block.getAddressRange()) 36 | 37 | 38 | dtm = currentProgram.getDataTypeManager() 39 | useful_dat = [] 40 | listing = currentProgram.getListing() 41 | for dat in listing.getDefinedData(1): 42 | addr_repr = None 43 | # handle undefined datatype 44 | if dat.valueClass is None: 45 | val = dat.value 46 | if val is None: 47 | continue 48 | int_val = val.getUnsignedValue() 49 | addr_repr = toAddr(int_val) 50 | # handle all pointer datatypes 51 | if dat.isPointer(): 52 | val = dat.value 53 | if val is None: 54 | continue 55 | # TODO: detect type pointed to 56 | int_val = val.getOffsetAsBigInteger() 57 | addr_repr = val 58 | # handle other scalar values, (uint, int, etc.) 
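    # (hypothetical illustration: a defined dword holding 0x00402010 becomes a
    #  candidate pointer below only if 0x00402010 falls inside one of the
    #  memory blocks collected into the address set above)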
59 | if isinstance(dat.value, Scalar): 60 | addr_repr = toAddr(dat.value.value) 61 | if addr_repr is None: 62 | continue 63 | 64 | # find all of the defined data in the program that could 65 | # be a pointer to the current model of memory 66 | if not existing_mem_addr_set.contains(addr_repr): 67 | continue 68 | useful_dat.append(dat) 69 | curr_dt = dat.dataType 70 | dat_cont = listing.getDataContaining(addr_repr) 71 | if dat_cont is None: 72 | continue 73 | if dat_cont.isStructure() is False: 74 | continue 75 | 76 | # apply datatype if addresses matches to start of defined data 77 | if dat_cont.address == addr_repr: 78 | new_type = dtm.getPointer(dat_cont.dataType) 79 | if curr_dt != new_type: 80 | log.debug("setting type at %s" % dat.address) 81 | applyDataTypeAtAddress(dat.address, new_type) 82 | continue 83 | 84 | # if address didn't match, it means that a pointer offset typedef 85 | # needs to be used. 86 | # iterate through existing typedefs to see if one already exists 87 | # that will work correctly 88 | typedef_dts = [dt for dt in dtm.getAllDataTypes() if isinstance(dt, TypedefDB)] 89 | off = addr_repr.subtract(dat_cont.address) 90 | set_typedef_dt = None 91 | for dt in typedef_dts: 92 | pointed_dt = dt.dataType 93 | # only care about typedefs to pointers, so skip the rest 94 | if not isinstance(pointed_dt, PointerDB): 95 | continue 96 | if pointed_dt.dataType != dat_cont.dataType: 97 | continue 98 | comp_off = dt.defaultSettings.getValue("component_offset") 99 | if comp_off is None: 100 | continue 101 | if comp_off == off: 102 | log.debug("found matching type for %s" % dat_cont.dataType.name) 103 | set_typedef_dt = dt 104 | break 105 | 106 | log.debug("setting type to offser pointer at %s" % dat.address) 107 | if set_typedef_dt is None: 108 | log.debug("making a new datatype for %s %d" % (dat_cont.dataType.name, off)) 109 | # if there wasn't a match, make a new typedef 110 | new_typedef_name = "%s_ptr_%d" % (dat_cont.dataType.name, off) 111 | ptr_type = dtm.getPointer(dat_cont.dataType) 112 | new_typedef = TypedefDataType(new_typedef_name, ptr_type) 113 | set_typedef_dt = dtm.resolve(new_typedef, None) 114 | default_settings = set_typedef_dt.getDefaultSettings() 115 | ComponentOffsetSettingsDefinition.DEF.setValue(default_settings, off) 116 | applyDataTypeAtAddress(dat.address, set_typedef_dt) 117 | 118 | -------------------------------------------------------------------------------- /type_propagator.py: -------------------------------------------------------------------------------- 1 | from ghidra.program.model.pcode import PcodeOpAST, VarnodeAST 2 | from ghidra.program.util import FunctionSignatureFieldLocation 3 | from ghidra.program.model.symbol import FlowType, RefType, SourceType 4 | from ghidra.app.cmd.function import ApplyFunctionSignatureCmd 5 | from ghidra.program.model.data import MetaDataType 6 | from ghidra.program.model.data import DefaultDataType 7 | 8 | import logging 9 | from decomp_utils import DecompUtils 10 | from function_signature_utils import set_param_datatype, set_num_params, getDataTypeForParam 11 | from datatype_utils import getVoidPointerDatatype, areBaseDataTypesEquallyUnique, getUndefinedRegisterSizeDatatype 12 | from register_utils import getRegToParamMapForFunc 13 | 14 | log = logging.getLogger(__file__) 15 | log.addHandler(logging.StreamHandler()) 16 | log.setLevel(logging.DEBUG) 17 | 18 | from __main__ import * 19 | 20 | 21 | def guessNumParamsFromDecompilation(func): 22 | du = DecompUtils(program=func.getProgram()) 23 | hf = 
du.get_high_function(func) 24 | if hf is None: 25 | return 0 26 | reg_to_param_map = getRegToParamMapForFunc(func) 27 | sym_name_to_param_num_map = {("in_%s" % k): v for k, v in reg_to_param_map.items()} 28 | 29 | cur_num_arguments = len(func.getSignature().getArguments()) 30 | lsm = hf.getLocalSymbolMap() 31 | max_param_num = -1 32 | for sym in lsm.getSymbols(): 33 | maybe_param = sym_name_to_param_num_map.get(sym) 34 | if maybe_param is not None: 35 | max_param_num = max(max_param_num, maybe_param) 36 | 37 | return max(cur_num_arguments, max_param_num) 38 | 39 | 40 | def fix_underreported_num_params(func): 41 | """ 42 | Decompile the specified function and attempt to fix the number of 43 | parameters that the function takes based on the presence of 44 | 'in_' symbols in the local symbol map for the function. 45 | This will not work correctly if the function takes in vector registers 46 | as parameters 47 | """ 48 | cur_num_arguments = len(func.getSignature().getArguments()) 49 | max_param_num = guessNumParamsFromDecompilation(func) 50 | if max_param_num != -1 and max_param_num != cur_num_arguments: 51 | set_num_params(func, max_param_num) 52 | 53 | 54 | def dethunkIfNecessary(func): 55 | if func is None: 56 | return None 57 | ret = func 58 | if func.isThunk() is True: 59 | dethunked = func.getThunkedFunction(1) 60 | if dethunked is not None: 61 | ret = dethunked 62 | return ret 63 | 64 | 65 | class FunctionCallArgContext(object): 66 | """ 67 | A class representing the arguments for a call to a specific address 68 | """ 69 | def __init__(self, to_address=None): 70 | self.args = {} 71 | self.to_address = to_address 72 | self.called_to_func = None 73 | self._to_repr = "TO" 74 | if to_address is not None: 75 | self._to_repr = str(to_address) 76 | self.called_to_func = getFunctionContaining(to_address) 77 | self.called_to_func = dethunkIfNecessary(self.called_to_func) 78 | if self.called_to_func is not None: 79 | self._to_repr = self.called_to_func.name 80 | 81 | def add_arg(self, arg, arg_num): 82 | """ 83 | Arg num is indexed from 1 to preserve 84 | consistency with pcode call ops. 85 | arg_num is not equivalent to slot 86 | """ 87 | self.args[arg_num] = arg 88 | 89 | def merge(self, other): 90 | for k, v in other.args.items(): 91 | existing = self.args.get(k) 92 | if existing is None: 93 | self.args[k] = v 94 | 95 | def __repr__(self): 96 | return "%s(%s)" % (self._to_repr, 97 | ", ".join(["%d" % i for i in self.args.keys()])) 98 | 99 | 100 | class FunctionArgContextCollection(object): 101 | """ 102 | A class representing a related collection of FunctionCallArgContext. These 103 | FunctionCallArgContexts do not have to be from the same function. 
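    For example, after a trace the collection might map a hypothetical called
    address 0x00401000 to a FunctionCallArgContext recording that arguments
    1 and 3 of that call received the traced value.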
104 | """ 105 | def __init__(self): 106 | self.function_call_args = {} 107 | 108 | def add(self, arg_ctx, merge=True): 109 | """ 110 | FunctionCallArgContext 111 | """ 112 | key = arg_ctx.to_address 113 | if merge is True: 114 | existing = self.function_call_args.get(key) 115 | if existing is not None: 116 | existing.merge(arg_ctx) 117 | arg_ctx = existing 118 | self.function_call_args[key] = arg_ctx 119 | 120 | def get(self, called_addr): 121 | """ 122 | Get a FunctionCallArgContext for the specified address 123 | """ 124 | key = called_addr 125 | maybe_v = self.function_call_args.get(key) 126 | if maybe_v is None: 127 | maybe_v = FunctionCallArgContext(called_addr) 128 | self.function_call_args[key] = maybe_v 129 | return maybe_v 130 | 131 | 132 | def trace_struct_fwd_to_call(varnodes, func_arg_holder=None, allow_ptrsub_zero=False): 133 | """ 134 | Within a single function, trace the specified varnodes forward to any call operations. 135 | Does not follow through dereferences 136 | """ 137 | def _add_to_list(curr_vn, vn_cand, to_visit, visited): 138 | if vn_cand is None: 139 | return 140 | if vn_cand in visited: 141 | return 142 | if vn_cand in to_visit: 143 | return 144 | if vn_cand == curr_vn: 145 | return 146 | to_visit.append(vn_cand) 147 | 148 | if func_arg_holder is None: 149 | func_arg_holder = FunctionArgContextCollection() 150 | 151 | refman = currentProgram.getReferenceManager() 152 | if not hasattr(varnodes, '__iter__'): 153 | varnodes = [varnodes] 154 | 155 | to_visit = list(varnodes) 156 | visited = set() 157 | while to_visit: 158 | vn = to_visit.pop() 159 | desc_ops = vn.getDescendants() 160 | for op in desc_ops: 161 | opcode = op.opcode 162 | if opcode in [PcodeOpAST.CALL, PcodeOpAST.CALLIND]: 163 | inputs = list(op.getInputs()) 164 | call_addr = op.getSeqnum().getTarget() 165 | for ref in refman.getReferencesFrom(call_addr): 166 | if ref.referenceType.isCall() is False: 167 | continue 168 | func_call_args = func_arg_holder.get(ref.toAddress) 169 | for i in range(1, len(inputs)): 170 | if inputs[i] != vn: 171 | continue 172 | # TODO: use data type here 173 | func_call_args.add_arg(i, i) 174 | 175 | elif opcode in [PcodeOpAST.COPY, PcodeOpAST.CAST, 176 | PcodeOpAST.MULTIEQUAL, PcodeOpAST.PIECE, 177 | PcodeOpAST.SUBPIECE]: 178 | # basically just ops where the decompiler could 179 | # change the type or size of the vn. 
decompiler 180 | # will generate these if a parameter for a call 181 | # is set to the incorrect type or size 182 | vn_cand = op.getOutput() 183 | _add_to_list(vn, vn_cand, to_visit, visited) 184 | elif opcode == PcodeOpAST.PTRSUB and allow_ptrsub_zero is True: 185 | # an optimization to make the decompiler output look 186 | # closer to C code can add in a dummy ptrsub vn, 0 187 | # to allow passing field0_0x0 into function calls 188 | offset_vn = op.getInput(1) 189 | if not offset_vn.isConstant(): 190 | log.error("PTRSUB offset was non-const %s %s" % 191 | (str(op.getSeqnum().getTarget()), 192 | str(op))) 193 | continue 194 | offset = int(offset_vn.getOffset()) 195 | # this is probably not a real access to a field, 196 | # the edge case this is looking for 197 | if offset == 0: 198 | vn_cand = op.getOutput() 199 | _add_to_list(vn, vn_cand, to_visit, visited) 200 | # TODO: maybe handle multi-level propagation 201 | elif opcode == PcodeOpAST.PTRADD: 202 | # TODO: confirm that PTRADD can not occur for passing 203 | # TODO: a pointer on [0] if a structure datatype is 204 | # TODO: confused with an array 205 | 206 | # TODO: maybe handle multi-level propagation 207 | pass 208 | elif opcode in [PcodeOpAST.INT_ADD, PcodeOpAST.INT_SUB]: 209 | # even though it doesn't make sense for correctly typed 210 | # things, an ADD or SUB would be seen in situations 211 | # where the type of a struct ptr is is incorrect 212 | vn_cand = op.getOutput() 213 | _add_to_list(vn, vn_cand, to_visit, visited) 214 | elif opcode in [PcodeOpAST.STORE, PcodeOpAST.LOAD]: 215 | # If load or store is handled then the traced 216 | # value would be a different value or type 217 | continue 218 | visited.add(vn) 219 | return func_arg_holder 220 | 221 | 222 | def trace_struct_forward(varnodes, allow_ptrsub_zero=False): 223 | """ 224 | Trace forward from a varnode or varnodes to all locations in this 225 | function and in functions called by this function and identify if the 226 | specified varnodes are passed directly into other functions 227 | returns a FunctionArgContextCollection. 
This effectively traces places where 228 | a struct pointer is passed as an argument 229 | """ 230 | if not hasattr(varnodes, '__iter__'): 231 | varnodes = [varnodes] 232 | du = DecompUtils() 233 | func_arg_holder = FunctionArgContextCollection() 234 | # get the top level of FunctionCallArgContext 235 | trace_struct_fwd_to_call(varnodes, func_arg_holder, allow_ptrsub_zero=allow_ptrsub_zero) 236 | to_visit = set([i for i in func_arg_holder.function_call_args.values()]) 237 | visited = set() 238 | visited_to_addrs = set() 239 | while to_visit: 240 | curr_arg_ctx = to_visit.pop() 241 | curr_func = curr_arg_ctx.called_to_func 242 | if curr_func is None: 243 | log.error("Have to skip %s because no to func could be identified" % str(curr_arg_ctx)) 244 | continue 245 | if curr_func.hasVarArgs() is True: 246 | continue 247 | in_vns = [] 248 | was_error = False 249 | for param_num, v in curr_arg_ctx.args.items(): 250 | maybe_vns = du.get_varnodes_for_param(curr_func, param_num) 251 | if maybe_vns is None: 252 | was_error = True 253 | break 254 | in_vns += maybe_vns 255 | if was_error is True: 256 | # without tracking to_addresses separately this search will 257 | # check many many call edges, even if the called function has 258 | # already been checked before 259 | log.error("couldn't get varnodes for %s" % (curr_func.name)) 260 | visited_to_addrs.add(curr_arg_ctx.to_address) 261 | continue 262 | tmp_arg_holder = trace_struct_fwd_to_call(in_vns, allow_ptrsub_zero=allow_ptrsub_zero) 263 | for cand_arg_ctx in tmp_arg_holder.function_call_args.values(): 264 | if cand_arg_ctx in visited: 265 | continue 266 | if cand_arg_ctx in to_visit: 267 | continue 268 | if cand_arg_ctx.to_address in visited_to_addrs: 269 | continue 270 | if cand_arg_ctx == curr_arg_ctx: 271 | continue 272 | to_visit.add(cand_arg_ctx) 273 | visited.add(curr_arg_ctx) 274 | visited_to_addrs.add(curr_arg_ctx.to_address) 275 | 276 | for arg_ctx in visited: 277 | func_arg_holder.add(arg_ctx) 278 | return func_arg_holder 279 | 280 | 281 | def propagate_datatype_forward_to_function_signatures(varnodes, datatype, program=None, skip_external=True, overwrite_voidp=False, prefer_existing=True, force_new=False, undef_only=True): 282 | """ 283 | Propagate a datatype forward to all function signatures the specified varnodes are passed to 284 | """ 285 | if program is None: 286 | program = currentProgram 287 | func_arg_ctx_collection = trace_struct_forward(varnodes) 288 | voidp_dt = getVoidPointerDatatype(program) 289 | 290 | for addr, arg_ctx in func_arg_ctx_collection.function_call_args.items(): 291 | if arg_ctx.to_address is None: 292 | log.info("skipping a reference to an unknown function call because there is no `to` address") 293 | continue 294 | if arg_ctx.called_to_func is None: 295 | log.warning("skipping a reference to %s because there is no function defined there" % str(arg_ctx.to_address)) 296 | continue 297 | # resolve thunk if necessary 298 | func = arg_ctx.called_to_func 299 | if func.isThunk(): 300 | func = func.getThunkedFunction(1) 301 | 302 | # variadic arguments are never safe to handle because they can make 303 | # ghidra start hallucinating varnodes 304 | if func.hasVarArgs() is True: 305 | continue 306 | # changing types on externals can cause some issues, so avoid it by default 307 | if func.isExternal() and skip_external is True: 308 | continue 309 | set_any_params = False 310 | for param_num, param_v in arg_ctx.args.items(): 311 | existing_datatype = getDataTypeForParam(func, param_num) 312 | if existing_datatype is None: 
313 | log.warning("existing datatype was None for %s param %d" % (func.name, param_num)) 314 | existing_datatype = getUndefinedRegisterSizeDatatype() 315 | if existing_datatype == voidp_dt and overwrite_voidp is False: 316 | # don't overwrite void* because of functions like memcpy that actually use that type 317 | continue 318 | if existing_datatype == datatype: 319 | # if the parameter's datatype is already this datatype, we aren't making a change and shouldn't 320 | # do anything 321 | continue 322 | if undef_only is True and not existing_datatype.name.startswith("undefined"): 323 | continue 324 | if force_new is False: 325 | # getMostSpecificDataType prefers the first argument if the datatypes are equally specific 326 | if prefer_existing is True: 327 | preferred_dt, less_preferred_dt = existing_datatype, datatype 328 | else: 329 | preferred_dt, less_preferred_dt = datatype, existing_datatype 330 | chosen_datatype = MetaDataType.getMostSpecificDataType(preferred_dt, less_preferred_dt) 331 | if chosen_datatype != datatype: 332 | log.warning("skipping %s because the more specific datatype is not the provided one. to ignore this, run the function with force_new=True" % func.name) 333 | continue 334 | set_param_datatype(func, param_num, datatype, program) 335 | func.addTag("AUTO_PROPAGATED_PARAM_%d_DATATYPE" % param_num) 336 | set_any_params = True 337 | 338 | if set_any_params is True: 339 | # setting only one param can cause ghidra to assume that no parameters past the new one exist, 340 | # so make a best effort attempt to resolve that. 341 | fix_underreported_num_params(func) 342 | 343 | 344 | def prop_datatype_from_func_param(func, param_num, program=None): 345 | if program is None: 346 | program = currentProgram 347 | 348 | sig = func.getSignature() 349 | args = list(sig.getArguments()) 350 | dt = args[param_num-1].getDataType() 351 | du = DecompUtils(program) 352 | vns = du.get_varnodes_for_param(func, param_num) 353 | propagate_datatype_forward_to_function_signatures(vns, dt, program=program) 354 | --------------------------------------------------------------------------------
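As a rough usage sketch (the function name is hypothetical, everything else is taken from the scripts above), the propagation in `type_propagator.py` can be driven from the ghidra python interpreter in much the same way that `prop_dt.py` does:

```python
# run inside ghidra's python interpreter with this repository on the script path;
# "FUN_00401000" is a made-up function whose first parameter is assumed to
# already have the struct pointer type that should be pushed forward
from decomp_utils import DecompUtils
from type_propagator import propagate_datatype_forward_to_function_signatures

func = getFunction("FUN_00401000")
dt = func.getSignature().getArguments()[0].getDataType()
du = DecompUtils()
vns = du.get_varnodes_for_param(func, 1)
propagate_datatype_forward_to_function_signatures(vns, dt)
```

This mirrors what `prop_datatype_from_func_param` does internally, so running `prop_dt.py` from the script manager reaches the same code path.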