├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── binja_coverage.py
├── binja_frontend.py
├── client.py
├── comments.py
├── config.json.template
├── config.py
├── coverage.py
├── coverage_poll.py
├── example_idapythonrc.py
├── ida_frontend.py
├── plugin.json
├── redis
│   ├── LICENSE
│   ├── __init__.py
│   ├── _compat.py
│   ├── client.py
│   ├── connection.py
│   ├── exceptions.py
│   ├── lock.py
│   ├── sentinel.py
│   └── utils.py
└── viv_frontend.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | config.json
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2020 lunixbochs
2 | 
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 | 
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 | 
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | revsync
2 | =======
3 | 
4 | Realtime sync plugin for IDA Pro, Binary Ninja, and Vivisect
5 | 
6 | Syncs:
7 | 
8 | - Comments
9 | - Symbol names
10 | - Stack var names
11 | - Structs
12 | - Code coverage (how much time was spent looking at a block)
13 | 
14 | IDA Pro Installation
15 | --------------------
16 | 
17 | First, clone this repository into your IDA data directory:
18 | 
19 | - Windows: `%APPDATA%\Hex-Rays\IDA Pro`
20 | - Mac/Linux: `~/.idapro`
21 | 
22 | Now:
23 | 
24 | - Create a file named _idapythonrc.py_ in the directory above and add the line `import revsync`.
25 | - Copy _config.json.template_ to _config.json_ and fill it out.
26 | - Restart IDA and look for revsync messages in the console.
27 | - In the Python console, typing `import revsync` should work without issue.
28 | 
29 | The expected data directory layout is (Mac/Linux):
30 | 
31 | ```
32 | ~/.idapro/idapythonrc.py
33 | ~/.idapro/revsync/
34 | ```
35 | 
36 | Binary Ninja Installation
37 | -------------------------
38 | 
39 | - Install via the Plugin Manager (CMD/CTL-SHIFT-M)
40 | 
41 | or:
42 | 
43 | - Clone to [your plugin folder](https://github.com/Vector35/binaryninja-api/tree/dev/python/examples#loading-plugins).
44 | 
45 | Then:
46 | 
47 | - Restart if required.
48 | - Fill in config when prompted.
49 | - Load your binary and wait for analysis to finish.
50 | - Use the Tools menu, right-click, or the command palette (CMD/CTL-P) to trigger revsync/Load.
51 | - Done!
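
Session Keys
------------

Collaborators only end up in the same session when they load byte-identical binaries: each frontend names its sync channel after the SHA-256 of the input file (`get_fhash` in _binja_frontend.py_, `read_fhash` in _ida_frontend.py_). If syncing appears dead, compare hashes first; the check mirrors the source and is just:

```
import hashlib

def get_fhash(fname):
    # revsync channel names are the uppercase SHA-256 hex digest of the binary
    with open(fname, 'rb') as f:
        return hashlib.sha256(f.read()).hexdigest().upper()
```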
52 | 
53 | 
54 | Vivisect Installation
55 | ---------------------
56 | 
57 | - Clone to [a plugin folder in your VIV_EXT_PATH (or ~/.viv/plugins/)](https://github.com/vivisect/vivisect/#extending-vivisect--vdb).
58 | 
59 | Then:
60 | 
61 | - Restart Vivisect.
62 | - Fill in config when prompted.
63 | - Load your binary and wait for analysis to finish.
64 | - Use the Plugins -> revsync -> Load option to trigger revsync/Load.
65 | - Done!
66 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | import os
4 | sys.path.append(os.path.dirname(__file__))
5 | 
6 | good = False
7 | 
8 | def write_config(host, port, nick, password):
9 |     path = os.path.dirname(os.path.abspath(__file__))
10 |     config = {
11 |         "host": host,
12 |         "port": port,
13 |         "nick": nick,
14 |         "password": password
15 |     }
16 |     with open(os.path.join(path, "config.json"), "w") as f:
17 |         f.write(json.dumps(config))
18 | 
19 | try:
20 |     import binaryninja
21 |     # check if running in BinaryNinja:
22 |     if binaryninja.core_ui_enabled():
23 |         import binaryninja.interaction as bi
24 |         try:
25 |             import config
26 |         except ImportError:
27 |             host_f = bi.TextLineField("host")
28 |             port_f = bi.IntegerField("port")
29 |             nick_f = bi.TextLineField("nick")
30 |             password_f = bi.TextLineField("password")
31 |             success = bi.get_form_input([None, host_f, port_f, nick_f, password_f], "Configure Revsync")
32 |             if not success:
33 |                 binaryninja.interaction.show_message_box(title="Revsync error", text="Failed to configure revsync")
34 |                 raise
35 | 
36 |             write_config(host_f.result, port_f.result, nick_f.result, password_f.result)
37 |             import config
38 |         import binja_frontend
39 |         #import binja_coverage
40 |         good = True
41 | except ImportError:
42 |     pass
43 | 
44 | # check if running in Vivisect:
45 | if 'vw' in globals():
46 |     try:
47 |         import vivisect
48 |         try:
49 |             import config
50 |         except ImportError:
51 |             import vqt.common as vcmn
52 |             dynd = vcmn.DynamicDialog('RevSync Config')
53 |             dynd.addTextField("host")
54 |             dynd.addIntHexField("port", dflt=6379)
55 |             dynd.addTextField("nick")
56 |             dynd.addTextField("password")
57 |             res = dynd.prompt()
58 |             if not len(res):
59 |                 vcmn.warning("Revsync error", "Failed to configure revsync")
60 |                 raise
61 | 
62 |             write_config(res.get('host'), res.get('port'), res.get('nick'), res.get('password'))
63 |             import config
64 | 
65 |         import viv_frontend
66 |         good = True
67 |     except ImportError:
68 |         pass
69 | 
70 | 
71 | # if idaapi loads, go with it.
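# (note: there is no config prompt on this path -- ida_frontend imports
# config.py directly, and config.py raises ImportError when config.json is
# missing, so IDA users create it by hand as described in the README; the
# result is shaped like config.json.template, e.g.
#     {"host": "hostname", "port": 6379, "nick": "changeme", "password": "password"})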
72 | try:
73 |     import idaapi
74 |     import ida_frontend
75 |     good = True
76 | except ImportError:
77 |     pass
78 | 
79 | if not good:
80 |     print('Warning: Could not find an appropriate plugin environment: IDA, Binary Ninja, and Vivisect plugin API imports failed')
81 |     raise ImportError
82 | 
83 | # Vivisect looks for this in a plugin
84 | def vivExtension(vw, vwgui):
85 |     import viv_frontend
86 |     viv_frontend.vivExtension(vw, vwgui)
87 | 
--------------------------------------------------------------------------------
/binja_coverage.py:
--------------------------------------------------------------------------------
1 | from time import sleep
2 | from binaryninja import *
3 | from binaryninja.plugin import PluginCommand
4 | import logging
5 | import random
6 | 
7 | COVERAGE_FIRST_LOAD = True
8 | SHOW_VISITS = True
9 | SHOW_LENGTH = True
10 | SHOW_VISITORS = False
11 | TRACK_COVERAGE = True
12 | IDLE_ASK = 250
13 | COLOUR_PERIOD = 20
14 | bb_coverage = {}
15 | 
16 | 
17 | def get_func_by_addr(bv, addr):
18 |     bb = bv.get_basic_blocks_at(addr)
19 |     if len(bb) > 0:
20 |         return bb[0].function
21 |     return None
22 | 
23 | 
24 | def get_bb_by_addr(bv, addr):
25 |     bb = bv.get_basic_blocks_at(addr)
26 |     if len(bb) > 0:
27 |         return bb[0]
28 |     return None
29 | 
30 | 
31 | def colour_blocks(blocks, max_visits, max_length, max_visitors):
32 |     global SHOW_VISITS
33 |     global SHOW_LENGTH
34 |     global SHOW_VISITORS
35 |     for bb in blocks:
36 |         cov = blocks[bb]
37 |         R, B, G = 0, 0, 0
38 |         if SHOW_VISITS and cov["visits"] > 0:
39 |             R = (cov["visits"] * 0x96) // max_visits
40 |         if SHOW_LENGTH and cov["length"] > 0:
41 |             B = (cov["length"] * 0x96) // max_length
42 |         if SHOW_VISITORS and cov["visitors"] > 0:
43 |             G = (cov["visitors"] * 0x96) // max_visitors
44 |         if R == 0 and B == 0 and G == 0:
45 |             bb.set_user_highlight(highlight.HighlightColor(red=74, blue=74, green=74))
46 |         else:
47 |             bb.set_user_highlight(highlight.HighlightColor(red=R, blue=B, green=G))
48 | 
49 | 
50 | def colour_coverage(cur_func, coverage):
51 |     if cur_func is None:
52 |         return
53 |     blocks = {}
54 |     max_visits = 0
55 |     max_length = 0
56 |     max_visitors = 0
57 |     for bb in coverage:
58 |         if coverage[bb]["visits"] > max_visits:
59 |             max_visits = coverage[bb]["visits"]
60 |         if coverage[bb]["length"] > max_length:
61 |             max_length = coverage[bb]["length"]
62 |         if coverage[bb]["visitors"] > max_visitors:
63 |             max_visitors = coverage[bb]["visitors"]
64 |         if bb.function == cur_func:
65 |             blocks[bb] = coverage[bb]
66 |     colour_blocks(blocks, max_visits, max_length, max_visitors)
67 | 
68 | 
69 | def watch_cur_func(bv):
70 |     global TRACK_COVERAGE
71 |     global bb_coverage
72 | 
73 |     def get_cur_func():
74 |         return get_func_by_addr(bv, bv.offset)
75 | 
76 |     def get_cur_bb():
77 |         return get_bb_by_addr(bv, bv.offset)
78 | 
79 |     last_func = None
80 |     last_bb = None
81 |     last_addr = None
82 |     cur_func = None
83 |     idle = 0
84 |     colour = 0
85 |     while True:
86 |         if TRACK_COVERAGE:
87 |             if idle > IDLE_ASK:
88 |                 res = get_choice_input("Continue coverage tracking?", "Idle Detection", ["Disable", "Continue"])
89 |                 if res == 0:
90 |                     log_info('Coverage: Tracking Stopped')
91 |                     exit()
92 |                 else:
93 |                     idle = 0
94 |             if last_addr == bv.offset:
95 |                 idle += 1
96 |                 if last_bb is not None:
97 |                     bb_coverage[last_bb]["length"] += 1
98 |                 sleep(0.50)
99 |             else:
100 |                 cur_bb = get_cur_bb()
101 |                 cur_func = get_cur_func()
102 |                 last_addr = bv.offset
103 |                 idle = 0
104 |                 if cur_bb != last_bb:
105 |                     if cur_bb is not None:
106 |                         if cur_bb not in bb_coverage:
107 |                             bb_coverage[cur_bb] = {"visits": 0,
"length": 0, "visitors": random.randint(1, 50)} 108 | bb_coverage[cur_bb]["visits"] += 1 109 | last_bb = cur_bb 110 | if cur_func != last_func: 111 | colour = COLOUR_PERIOD 112 | last_func = cur_func 113 | colour += 1 114 | if colour > COLOUR_PERIOD: 115 | colour_coverage(cur_func, bb_coverage) 116 | colour = 0 117 | else: 118 | idle = 0 119 | sleep(2) 120 | 121 | 122 | def coverage_load(bv): 123 | global COVERAGE_FIRST_LOAD 124 | global SHOW_VISITS 125 | global SHOW_LENGTH 126 | global SHOW_VISITORS 127 | log_info('Coverage: Tracking Started') 128 | if COVERAGE_FIRST_LOAD: 129 | opt_visit = ChoiceField("Visualize Visits (Red)", ["Yes", "No"]) 130 | opt_length = ChoiceField("Visualize Length (Blue)", ["Yes", "No"]) 131 | opt_visitors = ChoiceField("Visualize Visitors (Green)", ["No", "Yes"]) 132 | res = get_form_input(["Visualize by colouring backgrounds?", None, opt_visit, opt_length, opt_visitors], 133 | "Visualization Options") 134 | if res: 135 | log_info('Coverage: Visualization Options Set') 136 | if opt_visit.result > 0: 137 | SHOW_VISITS = not SHOW_VISITS 138 | if opt_length.result > 0: 139 | SHOW_LENGTH = not SHOW_LENGTH 140 | if opt_visitors.result > 0: 141 | SHOW_VISITORS = not SHOW_VISITORS 142 | COVERAGE_FIRST_LOAD = False 143 | t1 = threading.Thread(target=watch_cur_func, args=(bv,)) 144 | t1.daemon = True 145 | t1.start() 146 | 147 | 148 | def toggle_visits(bv): 149 | global SHOW_VISITS 150 | SHOW_VISITS = not SHOW_VISITS 151 | 152 | 153 | def toggle_length(bv): 154 | global SHOW_LENGTH 155 | SHOW_LENGTH = not SHOW_LENGTH 156 | 157 | 158 | def toggle_visitors(bv): 159 | global SHOW_VISITORS 160 | SHOW_VISITORS = not SHOW_VISITORS 161 | 162 | 163 | PluginCommand.register('Coverage: Start Tracking', 'Track Coverage', coverage_load) 164 | PluginCommand.register('Coverage: Toggle Visits (RED)', 'Toggle Red', toggle_visits) 165 | PluginCommand.register('Coverage: Toggle Length (BLUE)', 'Toggle Blue', toggle_length) 166 | PluginCommand.register('Coverage: Toggle Visitors (GREEN)', 'Toggle Green', toggle_visitors) -------------------------------------------------------------------------------- /binja_frontend.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import math 3 | import time 4 | from collections import defaultdict 5 | 6 | from binaryninja import * 7 | from binaryninja.plugin import PluginCommand 8 | 9 | from client import Client 10 | from config import config 11 | from comments import Comments, NoChange 12 | from coverage import Coverage 13 | from threading import Lock 14 | from collections import namedtuple 15 | 16 | Struct = namedtuple('Struct', 'name typedef') 17 | 18 | class State: 19 | @staticmethod 20 | def get(bv): 21 | return bv.session_data.get('revsync') 22 | 23 | show_visits = True 24 | show_time = True 25 | show_visitors = False 26 | track_coverage = True 27 | color_now = False 28 | running = False 29 | 30 | def __init__(self, bv): 31 | self.cov = Coverage() 32 | self.comments = Comments() 33 | self.fhash = get_fhash(bv.file.filename) 34 | self.running = True 35 | self.cmt_changes = {} 36 | self.cmt_lock = Lock() 37 | #self.stackvar_changes = {} 38 | #self.stackvar_lock = Lock() 39 | self.data_syms = get_syms(bv, SymbolType.DataSymbol) 40 | self.func_syms = get_syms(bv, SymbolType.FunctionSymbol) 41 | self.syms_lock = Lock() 42 | #self.structs = get_structs(bv) 43 | #self.structs_lock = Lock() 44 | 45 | def close(self): 46 | self.running = False 47 | 48 | MIN_COLOR = 0 49 | MAX_COLOR = 200 50 | 51 | 
IDLE_ASK = 250
52 | COLOUR_PERIOD = 20
53 | BB_REPORT = 50
54 | 
55 | def get_fhash(fname):
56 |     with open(fname, 'rb') as f:
57 |         return hashlib.sha256(f.read()).hexdigest().upper()
58 | 
59 | def get_can_addr(bv, addr):
60 |     return addr - bv.start
61 | 
62 | def get_ea(bv, addr):
63 |     return addr + bv.start
64 | 
65 | def get_func_by_addr(bv, addr):
66 |     bb = bv.get_basic_blocks_at(addr)
67 |     if len(bb) > 0:
68 |         return bb[0].function
69 |     return None
70 | 
71 | def get_bb_by_addr(bv, addr):
72 |     bb = bv.get_basic_blocks_at(addr)
73 |     if len(bb) > 0:
74 |         return bb[0]
75 |     return None
76 | 
77 | # in order to map IDA type sizes <-> binja types,
78 | # take the 'size' field and attempt to divine some
79 | # kind of type in binja that's as close as possible.
80 | # right now, that means try to use uint8_t ... uint64_t;
81 | # anything bigger just becomes an array
82 | def get_type_by_size(bv, size):
83 |     typedef = None
84 |     if size <= 8:
85 |         try:
86 |             typedef, name = bv.parse_type_string('uint{}_t'.format(8*size))
87 |         except SyntaxError:
88 |             pass
89 |     else:
90 |         try:
91 |             typedef, name = bv.parse_type_string('char a[{}]'.format(size))
92 |         except SyntaxError:
93 |             pass
94 |     return typedef
95 | 
96 | #def get_structs(bv):
97 | #    d = dict()
98 | #    for name, typedef in bv.types.items():
99 | #        if typedef.structure:
100 | #            typeid = bv.get_type_id(name)
101 | #            struct = Struct(name, typedef.structure)
102 | #            d[typeid] = struct
103 | #    return d
104 | 
105 | def get_syms(bv, sym_type):
106 |     # comes as list of Symbols
107 |     syms = bv.get_symbols_of_type(sym_type)
108 |     # turn our list into dict of addr => sym name
109 |     syms_dict = dict()
110 |     for sym in syms:
111 |         # Sometimes there are duplicate symbols for a given address, and the
112 |         # order they're returned from bv.get_symbols_of_type is
113 |         # nondeterministic.
114 |         # This created a bug where revsync would think the user renamed a ton of
115 |         # symbols whenever that order happened to change.
116 |         # This is probably slow, but I don't know what else to do about it
117 |         if bv.get_symbol_at(sym.address) != sym:
118 |             continue
119 |         syms_dict[sym.address] = sym.name
120 |     return syms_dict
121 | 
122 | def stack_dict_from_list(stackvars):
123 |     d = {}
124 |     for var in stackvars:
125 |         d[var.storage] = (var.name, var.type)
126 |     return d
127 | 
128 | def member_dict_from_list(members):
129 |     d = {}
130 |     for member in members:
131 |         d[member.name] = member
132 |     return d
133 | 
134 | def rename_symbol(bv, addr, name):
135 |     sym = bv.get_symbol_at(addr)
136 |     if sym is not None:
137 |         # symbol already exists for this address
138 |         if sym.auto is True:
139 |             bv.undefine_auto_symbol(sym)
140 |         else:
141 |             bv.undefine_user_symbol(sym)
142 |     # is it a function?
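    # (Binary Ninja wants the replacement Symbol recreated with the right
    # SymbolType -- FunctionSymbol for code, DataSymbol for everything else --
    # which is what the branches below decide)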
143 | func = get_func_by_addr(bv, addr) 144 | if func is not None: 145 | # function 146 | sym = types.Symbol(SymbolType.FunctionSymbol, addr, name) 147 | else: 148 | # data 149 | sym = types.Symbol(SymbolType.DataSymbol, addr, name) 150 | bv.define_user_symbol(sym) 151 | 152 | #def rename_stackvar(bv, func_addr, offset, name): 153 | # func = get_func_by_addr(bv, func_addr) 154 | # if func is None: 155 | # log_info('revsync: bad func addr %#x during rename_stackvar' % func_addr) 156 | # return 157 | # # we need to figure out the variable type before renaming 158 | # stackvars = stack_dict_from_list(func.vars) 159 | # var = stackvars.get(offset) 160 | # if var is None: 161 | # log_info('revsync: could not locate stack var with offset %#x during rename_stackvar' % offset) 162 | # return 163 | # var_name, var_type = var 164 | # func.create_user_stack_var(offset, var_type, name) 165 | # return 166 | 167 | def publish(bv, data, **kwargs): 168 | state = State.get(bv) 169 | if state: 170 | client.publish(state.fhash, data, **kwargs) 171 | 172 | def push_cv(bv, data, **kwargs): 173 | state = State.get(bv) 174 | if state: 175 | client.push("%s_COVERAGE" % state.fhash, data, **kwargs) 176 | 177 | def onmsg(bv, key, data, replay): 178 | state = State.get(bv) 179 | if key != state.fhash: 180 | log_info('revsync: hash mismatch, dropping command') 181 | return 182 | cmd, user = data['cmd'], data['user'] 183 | ts = int(data.get('ts', 0)) 184 | if cmd == 'comment': 185 | with state.cmt_lock: 186 | log_info('revsync: <%s> %s %#x %s' % (user, cmd, data['addr'], data['text'])) 187 | addr = get_ea(bv, int(data['addr'])) 188 | func = get_func_by_addr(bv, addr) 189 | # binja does not support comments on data symbols??? IDA does. 190 | if func is not None: 191 | text = state.comments.set(addr, user, data['text'], ts) 192 | func.set_comment(addr, text) 193 | state.cmt_changes[addr] = text 194 | elif cmd == 'extra_comment': 195 | log_info('revsync: <%s> %s %#x %s' % (user, cmd, data['addr'], data['text'])) 196 | elif cmd == 'area_comment': 197 | log_info('revsync: <%s> %s %s %s' % (user, cmd, data['range'], data['text'])) 198 | elif cmd == 'rename': 199 | with state.syms_lock: 200 | log_info('revsync: <%s> %s %#x %s' % (user, cmd, data['addr'], data['text'])) 201 | addr = get_ea(bv, int(data['addr'])) 202 | rename_symbol(bv, addr, data['text']) 203 | state.data_syms = get_syms(bv, SymbolType.DataSymbol) 204 | state.func_syms = get_syms(bv, SymbolType.FunctionSymbol) 205 | # elif cmd == 'stackvar_renamed': 206 | # with state.stackvar_lock: 207 | # func_name = '???' 
208 | # func = get_func_by_addr(bv, data['addr']) 209 | # if func: 210 | # func_name = func.name 211 | # log_info('revsync: <%s> %s %s %#x %s' % (user, cmd, func_name, data['offset'], data['name'])) 212 | # rename_stackvar(bv, data['addr'], data['offset'], data['name']) 213 | # # save stackvar changes using the tuple (func_addr, offset) as key 214 | # state.stackvar_changes[(data['addr'],data['offset'])] = data['name'] 215 | # elif cmd == 'struc_created': 216 | # with state.structs_lock: 217 | # # note: binja does not seem to appreciate the encoding of strings from redis 218 | # struct_name = data['struc_name'] 219 | # struct = bv.get_type_by_name(struct_name) 220 | # # if a struct with the same name already exists, undefine it 221 | # if struct: 222 | # bv.undefine_user_type(struct_name) 223 | # struct = Structure() 224 | # bv.define_user_type(struct_name, binaryninja.types.Type.structure_type(struct)) 225 | # state.structs = get_structs(bv) 226 | # log_info('revsync: <%s> %s %s' % (user, cmd, struct_name)) 227 | # elif cmd == 'struc_deleted': 228 | # with state.structs_lock: 229 | # struct_name = data['struc_name'] 230 | # struct = bv.get_type_by_name(struct_name) 231 | # # make sure the type is defined first 232 | # if struct is None: 233 | # log_info('revsync: unknown struct name %s during struc_deleted cmd' % struct_name) 234 | # return 235 | # bv.undefine_user_type(struct_name) 236 | # state.structs = get_structs(bv) 237 | # log_info('revsync: <%s> %s %s' % (user, cmd, struct_name)) 238 | # elif cmd == 'struc_renamed': 239 | # with state.structs_lock: 240 | # old_struct_name = data['old_name'] 241 | # new_struct_name = data['new_name'] 242 | # struct = bv.get_type_by_name(old_struct_name) 243 | # # make sure the type is defined first 244 | # if struct is None: 245 | # log_info('revsync: unknown struct name %s during struc_renamed cmd' % old_struct_name) 246 | # return 247 | # bv.rename_type(old_struct_name, new_struct_name) 248 | # state.structs = get_structs(bv) 249 | # log_info('revsync: <%s> %s %s %s' % (user, cmd, old_struct_name, new_struct_name)) 250 | # elif cmd == 'struc_member_created': 251 | # with state.structs_lock: 252 | # struct_name = data['struc_name'] 253 | # struct = bv.get_type_by_name(struct_name) 254 | # if struct is None: 255 | # log_info('revsync: unknown struct name %s during struc_member_created cmd' % struct_name) 256 | # return 257 | # member_name = data['member_name'] 258 | # struct_type = get_type_by_size(bv, data['size']) 259 | # if struct_type is None: 260 | # log_info('revsync: bad struct member size %d for member %s during struc_member_created cmd' % (data['size'], member_name)) 261 | # return 262 | # # need actual Structure class, not Type 263 | # struct = struct.structure.mutable_copy() 264 | # struct.insert(data['offset'], struct_type, member_name) 265 | # # we must redefine the type 266 | # bv.define_user_type(struct_name, binaryninja.types.Type.structure_type(struct)) 267 | # state.structs = get_structs(bv) 268 | # log_info('revsync: <%s> %s %s->%s' % (user, cmd, struct_name, member_name)) 269 | # elif cmd == 'struc_member_deleted': 270 | # with state.structs_lock: 271 | # struct_name = data['struc_name'] 272 | # struct = bv.get_type_by_name(struct_name) 273 | # if struct is None: 274 | # log_info('revsync: unknown struct name %s during struc_member_deleted cmd' % struct_name) 275 | # return 276 | # offset = data['offset'] 277 | # # need actual Structure class, not Type 278 | # struct = struct.structure.mutable_copy() 279 | # # walk the list and 
find the index to delete (seriously, why by index binja and not offset?) 280 | # member_name = '???' 281 | # for i,m in enumerate(struct.members): 282 | # if m.offset == offset: 283 | # # found it 284 | # member_name = m.name 285 | # struct.remove(i) 286 | # # we must redefine the type 287 | # bv.define_user_type(struct_name, binaryninja.types.Type.structure_type(struct)) 288 | # state.structs = get_structs(bv) 289 | # log_info('revsync: <%s> %s %s->%s' % (user, cmd, struct_name, member_name)) 290 | # elif cmd == 'struc_member_renamed': 291 | # with state.structs_lock: 292 | # struct_name = data['struc_name'] 293 | # member_name = data['member_name'] 294 | # struct = bv.get_type_by_name(struct_name) 295 | # if struct is None: 296 | # log_info('revsync: unknown struct name %s during struc_member_renamed cmd' % struct_name) 297 | # return 298 | # offset = data['offset'] 299 | # # need actual Structure class, not Type 300 | # struct = struct.structure.mutable_copy() 301 | # for i,m in enumerate(struct.members): 302 | # if m.offset == offset: 303 | # struct.replace(i, m.type, member_name) 304 | # bv.define_user_type(struct_name, binaryninja.types.Type.structure_type(struct)) 305 | # log_info('revsync: <%s> %s %s->%s' % (user, cmd, struct_name, member_name)) 306 | # break 307 | # state.structs = get_structs(bv) 308 | # elif cmd == 'struc_member_changed': 309 | # with state.structs_lock: 310 | # struct_name = data['struc_name'] 311 | # struct = bv.get_type_by_name(struct_name) 312 | # if struct is None: 313 | # log_info('revsync: unknown struct name %s during struc_member_renamed cmd' % struct_name) 314 | # return 315 | # # need actual Structure class, not Type 316 | # struct = struct.structure.mutable_copy() 317 | # offset = data['offset'] 318 | # for i,m in enumerate(struct.members): 319 | # if m.offset == offset: 320 | # struct.replace(i, get_type_by_size(bv, data['size']), m.name) 321 | # bv.define_user_type(struct_name, binaryninja.types.Type.structure_type(struct)) 322 | # log_info('revsync: <%s> %s %s->%s' % (user, cmd, struct_name, m.name)) 323 | # break 324 | # state.structs = get_structs(bv) 325 | elif cmd == 'join': 326 | log_info('revsync: <%s> joined' % (user)) 327 | elif cmd == 'coverage': 328 | log_info("Updating Global Coverage") 329 | state.cov.update(json.loads(data['blocks'])) 330 | state.color_now = True 331 | else: 332 | log_info('revsync: unknown cmd %s' % data) 333 | 334 | def revsync_callback(bv): 335 | def callback(key, data, replay=False): 336 | onmsg(bv, key, data, replay) 337 | return callback 338 | 339 | def revsync_comment(bv, addr): 340 | comment = interaction.get_text_line_input('Enter comment: ', 'revsync comment') 341 | publish(bv, {'cmd': 'comment', 'addr': get_can_addr(bv, addr), 'text': comment or ''}, send_uuid=False) 342 | get_func_by_addr(bv, addr).set_comment(addr, comment) 343 | 344 | def revsync_rename(bv, addr): 345 | name = interaction.get_text_line_input('Enter symbol name: ', 'revsync rename') 346 | publish(bv, {'cmd': 'rename', 'addr': get_can_addr(bv, addr), 'text': name}) 347 | rename_symbol(bv, addr, name) 348 | 349 | def map_color(x): 350 | n = x 351 | if x == 0: return 0 352 | # x = min(max(0, (x ** 2) / (2 * (x ** 2 - x) + 1)), 1) 353 | # if x == 0: return 0 354 | return int(math.ceil((MAX_COLOR - MIN_COLOR) * x + MIN_COLOR)) 355 | 356 | def convert_color(color): 357 | r, g, b = [map_color(x) for x in color] 358 | return highlight.HighlightColor(red=r, green=g, blue=b) 359 | 360 | def colour_coverage(bv, cur_func): 361 | state = State.get(bv) 
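    # Coverage.color() (coverage.py) merges the shared, local, and pending
    # stats for each block into a single colour; blocks with no coverage at
    # all get the neutral grey fallback below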
362 | for bb in cur_func.basic_blocks: 363 | color = state.cov.color(get_can_addr(bv, bb.start), visits=state.show_visits, time=state.show_time, users=state.show_visitors) 364 | if color: 365 | bb.set_user_highlight(convert_color(color)) 366 | else: 367 | bb.set_user_highlight(highlight.HighlightColor(red=74, blue=74, green=74)) 368 | 369 | #def watch_structs(bv): 370 | # """ Check structs for changes and publish diffs""" 371 | # state = State.get(bv) 372 | # 373 | # while state.running: 374 | # with state.structs_lock: 375 | # structs = get_structs(bv) 376 | # if structs != state.structs: 377 | # for struct_id, struct in structs.items(): 378 | # last_struct = state.structs.get(struct_id) 379 | # struct_name = struct.name 380 | # if last_struct == None: 381 | # # new struct created, publish 382 | # log_info('revsync: user created struct %s' % struct_name) 383 | # # binja can't really handle unions at this time 384 | # publish(bv, {'cmd': 'struc_created', 'struc_name': str(struct_name), 'is_union': False}) 385 | # # if there are already members, publish them 386 | # members = member_dict_from_list(struct.typedef.members) 387 | # if members: 388 | # for member_name, member_def in members.items(): 389 | # publish(bv, {'cmd': 'struc_member_created', 'struc_name': str(struct_name), 'offset': member_def.offset, 'member_name': member_name, 'size': member_def.type.width, 'flag': None}) 390 | # continue 391 | # last_name = last_struct.name 392 | # if last_name != struct_name: 393 | # # struct renamed, publish 394 | # log_info('revsync: user renamed struct %s' % struct_name) 395 | # publish(bv, {'cmd': 'struc_renamed', 'old_name': str(last_name), 'new_name': str(struct_name)}) 396 | # 397 | # # check for member differences 398 | # members = member_dict_from_list(struct.typedef.members) 399 | # last_members = member_dict_from_list(last_struct.typedef.members) 400 | # 401 | # # first checks for deletions 402 | # removed_members = set(last_members.keys()) - set(members.keys()) 403 | # for member in removed_members: 404 | # log_info('revsync: user deleted struct member %s in struct %s' % (last_members[member].name, str(struct_name))) 405 | # publish(bv, {'cmd': 'struc_member_deleted', 'struc_name': str(struct_name), 'offset': last_members[member].offset}) 406 | # 407 | # # now check for additions 408 | # new_members = set(members.keys()) - set(last_members.keys()) 409 | # for member in new_members: 410 | # log_info('revsync: user added struct member %s in struct %s' % (members[member].name, str(struct_name))) 411 | # publish(bv, {'cmd': 'struc_member_created', 'struc_name': str(struct_name), 'offset': members[member].offset, 'member_name': str(member), 'size': members[member].type.width, 'flag': None}) 412 | # 413 | # # check for changes among intersection of members 414 | # intersec = set(members.keys()) & set(last_members.keys()) 415 | # for m in intersec: 416 | # if members[m].type.width != last_members[m].type.width: 417 | # # type (i.e., size) changed 418 | # log_info('revsync: user changed struct member %s in struct %s' % (members[m].name, str(struct_name))) 419 | # publish(bv, {'cmd': 'struc_member_changed', 'struc_name': str(struct_name), 'offset': members[m].offset, 'size': members[m].type.width}) 420 | # 421 | # for struct_id, struct_def in state.structs.items(): 422 | # if structs.get(struct_id) == None: 423 | # # struct deleted, publish 424 | # log_info('revsync: user deleted struct %s' % struct_def.name) 425 | # publish(bv, {'cmd': 'struc_deleted', 'struc_name': str(struct_def.name)}) 426 | 
# state.structs = get_structs(bv) 427 | # time.sleep(0.5) 428 | 429 | def watch_syms(bv, sym_type): 430 | """ Watch symbols of a given type (e.g. DataSymbol) for changes and publish diffs """ 431 | state = State.get(bv) 432 | 433 | while state.running: 434 | with state.syms_lock: 435 | # DataSymbol 436 | data_syms = get_syms(bv, SymbolType.DataSymbol) 437 | if data_syms != state.data_syms: 438 | for addr, name in data_syms.items(): 439 | if state.data_syms.get(addr) != name: 440 | # name changed, publish 441 | log_info('revsync: user renamed symbol at %#x: %s' % (addr, name)) 442 | publish(bv, {'cmd': 'rename', 'addr': get_can_addr(bv, addr), 'text': name}) 443 | 444 | # FunctionSymbol 445 | func_syms = get_syms(bv, SymbolType.FunctionSymbol) 446 | if func_syms != state.func_syms: 447 | for addr, name in func_syms.items(): 448 | if state.func_syms.get(addr) != name: 449 | # name changed, publish 450 | log_info('revsync: user renamed symbol at %#x: %s' % (addr, name)) 451 | publish(bv, {'cmd': 'rename', 'addr': get_can_addr(bv, addr), 'text': name}) 452 | 453 | state.data_syms = get_syms(bv, SymbolType.DataSymbol) 454 | state.func_syms = get_syms(bv, SymbolType.FunctionSymbol) 455 | time.sleep(0.5) 456 | 457 | def watch_cur_func(bv): 458 | """ Watch current function (if we're in code) for comment changes and publish diffs """ 459 | def get_cur_func(): 460 | return get_func_by_addr(bv, bv.offset) 461 | 462 | def get_cur_bb(): 463 | return get_bb_by_addr(bv, bv.offset) 464 | 465 | state = State.get(bv) 466 | last_func = get_cur_func() 467 | last_bb = get_cur_bb() 468 | last_time = time.time() 469 | last_bb_report = time.time() 470 | last_bb_addr = None 471 | last_addr = None 472 | while state.running: 473 | now = time.time() 474 | if state.track_coverage and now - last_bb_report >= BB_REPORT: 475 | last_bb_report = now 476 | push_cv(bv, {'b': state.cov.flush()}) 477 | 478 | if last_addr == bv.offset: 479 | time.sleep(0.25) 480 | continue 481 | else: 482 | # were we just in a function? 483 | if last_func: 484 | with state.cmt_lock: 485 | comments = last_func.comments 486 | # check for changed comments 487 | for cmt_addr, cmt in comments.items(): 488 | last_cmt = state.cmt_changes.get(cmt_addr) 489 | if last_cmt == None or last_cmt != cmt: 490 | # new/changed comment, publish 491 | try: 492 | addr = get_can_addr(bv, cmt_addr) 493 | changed = state.comments.parse_comment_update(addr, client.nick, cmt) 494 | log_info('revsync: user changed comment: %#x, %s' % (addr, changed)) 495 | publish(bv, {'cmd': 'comment', 'addr': addr, 'text': changed}) 496 | state.cmt_changes[cmt_addr] = changed 497 | except NoChange: 498 | pass 499 | continue 500 | 501 | # TODO: this needs to be fixed later 502 | """ 503 | # check for removed comments 504 | if last_comments: 505 | removed = set(last_comments.keys()) - set(comments.keys()) 506 | for addr in removed: 507 | addr = get_can_addr(bv, addr) 508 | log_info('revsync: user removed comment: %#x' % addr) 509 | publish(bv, {'cmd': 'comment', 'addr': addr, 'text': ''}) 510 | """ 511 | 512 | ## similar dance, but with stackvars 513 | #with state.stackvar_lock: 514 | # stackvars = stack_dict_from_list(last_func.vars) 515 | # for offset, data in stackvars.items(): 516 | # # stack variables are more difficult than comments to keep state on, since they 517 | # # exist from the beginning, and have a type. track each one. start by tracking the first 518 | # # time we see it. if there are changes after that, publish. 
519 | # stackvar_name, stackvar_type = data 520 | # stackvar_val = state.stackvar_changes.get((last_func.start,offset)) 521 | # if stackvar_val == None: 522 | # # never seen before, start tracking 523 | # state.stackvar_changes[(last_func.start,offset)] = stackvar_name 524 | # elif stackvar_val != stackvar_name: 525 | # # stack var name changed, publish 526 | # log_info('revsync: user changed stackvar name at offset %#x to %s' % (offset, stackvar_name)) 527 | # publish(bv, {'cmd': 'stackvar_renamed', 'addr': last_func.start, 'offset': offset, 'name': stackvar_name}) 528 | # state.stackvar_changes[(last_func.start,offset)] = stackvar_name 529 | 530 | if state.track_coverage: 531 | cur_bb = get_cur_bb() 532 | if cur_bb != last_bb: 533 | state.color_now = True 534 | now = time.time() 535 | if last_bb_addr is not None: 536 | state.cov.visit_addr(last_bb_addr, elapsed=now - last_time, visits=1) 537 | last_time = now 538 | if cur_bb is None: 539 | last_bb_addr = None 540 | else: 541 | last_bb_addr = get_can_addr(bv, cur_bb.start) 542 | 543 | # update current function/addr info 544 | last_func = get_cur_func() 545 | last_bb = get_cur_bb() 546 | last_addr = bv.offset 547 | 548 | if state.color_now and last_func != None: 549 | colour_coverage(bv, last_func) 550 | state.color_now = False 551 | 552 | def do_analysis_and_wait(bv): 553 | log_info('revsync: running analysis update...') 554 | bv.update_analysis_and_wait() 555 | log_info('revsync: analysis finished.') 556 | return 557 | 558 | def revsync_load(bv): 559 | global client 560 | 561 | # lets ensure auto-analysis is finished by forcing another analysis 562 | t0 = threading.Thread(target=do_analysis_and_wait, args=(bv,)) 563 | t0.start() 564 | t0.join() 565 | 566 | try: 567 | client 568 | except: 569 | client = Client(**config) 570 | state = bv.session_data.get('revsync') 571 | if state: 572 | # close out the previous session 573 | client.leave(state.fhash) 574 | state.close() 575 | 576 | state = bv.session_data['revsync'] = State(bv) 577 | log_info('revsync: connecting with %s' % state.fhash) 578 | client.join(state.fhash, revsync_callback(bv)) 579 | log_info('revsync: connected!') 580 | t1 = threading.Thread(target=watch_cur_func, args=(bv,)) 581 | t2 = threading.Thread(target=watch_syms, args=(bv,SymbolType.DataSymbol)) 582 | t3 = threading.Thread(target=watch_syms, args=(bv,SymbolType.FunctionSymbol)) 583 | #t4 = threading.Thread(target=watch_structs, args=(bv,)) 584 | t1.daemon = True 585 | t2.daemon = True 586 | t3.daemon = True 587 | #t4.daemon = True 588 | t1.start() 589 | t2.start() 590 | t3.start() 591 | #t4.start() 592 | 593 | def toggle_visits(bv): 594 | state = State.get(bv) 595 | state.show_visits = not state.show_visits 596 | if state.show_visits: 597 | log_info("Visit Visualization Enabled (Red)") 598 | else: 599 | log_info("Visit Visualization Disabled (Red)") 600 | state.color_now = True 601 | 602 | def toggle_time(bv): 603 | state = State.get(bv) 604 | state.show_time = not state.show_time 605 | if state.show_time: 606 | log_info("Time Visualization Enabled (Blue)") 607 | else: 608 | log_info("Time Visualization Disabled (Blue)") 609 | state.color_now = True 610 | 611 | def toggle_visitors(bv): 612 | state = State.get(bv) 613 | state.show_visitors = not state.show_visitors 614 | if state.show_visitors: 615 | log_info("Visitor Visualization Enabled (Green)") 616 | else: 617 | log_info("Visitor Visualization Disabled (Green)") 618 | state.color_now = True 619 | 620 | def toggle_track(bv): 621 | state = State.get(bv) 622 | 
state.track_coverage = not state.track_coverage 623 | if state.track_coverage: 624 | log_info("Tracking Enabled") 625 | else: 626 | log_info("Tracking Disabled") 627 | 628 | PluginCommand.register('revsync\\Coverage: Toggle Tracking', 'Toggle Tracking', toggle_track) 629 | PluginCommand.register('revsync\\Coverage: Toggle Visits (RED)', 'Toggle Red', toggle_visits) 630 | PluginCommand.register('revsync\\Coverage: Toggle Time (BLUE)', 'Toggle Blue', toggle_time) 631 | PluginCommand.register('revsync\\Coverage: Toggle Visitors (GREEN)', 'Toggle Green', toggle_visitors) 632 | PluginCommand.register('revsync\\Load', 'load revsync!!!', revsync_load) 633 | -------------------------------------------------------------------------------- /client.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import json 3 | import re 4 | import redis 5 | import threading 6 | import time 7 | import traceback 8 | import uuid 9 | import base64 10 | import binascii 11 | 12 | TTL = 2 13 | 14 | hash_keys = ('cmd', 'user') 15 | cmd_hash_keys = { 16 | 'comment': ('addr',), 17 | 'extra_comment': ('addr',), 18 | 'area_comment': ('addr',), 19 | 'rename': ('addr',), 20 | 'stackvar_renamed': ('addr', 'offset', 'name',), 21 | 'struc_created': ('struc_name', 'is_union',), 22 | 'struc_deleted': ('struc_name',), 23 | 'struc_renamed': ('old_name', 'new_name',), 24 | 'struc_member_created': ('struc_name', 'offset', 'member_name', 'size', 'flag',), 25 | 'struc_member_deleted': ('struc_name', 'offset',), 26 | 'struc_member_renamed': ('struc_name', 'offset', 'member_name',), 27 | 'struc_member_changed': ('struc_name', 'offset', 'size',), 28 | } 29 | key_dec = { 30 | 'c': 'cmd', 31 | 'a': 'addr', 32 | 'u': 'user', 33 | 't': 'text', 34 | 'i': 'uuid', 35 | 'b': 'blocks' 36 | } 37 | key_enc = dict((v, k) for k, v in key_dec.items()) 38 | nick_filter = re.compile(r'[^a-zA-Z0-9_\-]') 39 | 40 | def decode(data): 41 | d = json.loads(data) 42 | return dict((key_dec.get(k, k), v) for k, v in d.items()) 43 | 44 | def dtokey(d): 45 | return tuple(((k, v) for k, v in sorted(d.items()) if k not in ('user', 'ts', 'uuid'))) 46 | 47 | def remove_ttl(a): 48 | now = time.time() 49 | return [d for d in a if now - d[0] < TTL] 50 | 51 | class Client: 52 | def __init__(self, host, port, nick, password=None): 53 | self.r = redis.StrictRedis(host=host, port=port, password=password, socket_connect_timeout=5) 54 | self.r.info() 55 | self.nick = nick_filter.sub('_', nick) 56 | self.ps = {} 57 | self.nolock = threading.Lock() 58 | self.nosend = defaultdict(list) 59 | self.uuid = str(base64.b64encode(binascii.unhexlify(uuid.uuid4().hex)).decode('ascii')) 60 | 61 | def debounce(self, no, data): 62 | dkey = dtokey(data) 63 | now = time.time() 64 | with self.nolock: 65 | for data in no: 66 | ts = data[0] 67 | key = data[1:] 68 | if dkey == key and now - ts < TTL: 69 | no.remove(data) 70 | return True 71 | return False 72 | 73 | def _sub_thread(self, ps, cb, key): 74 | for item in ps.listen(): 75 | try: 76 | if item['type'] == 'message': 77 | data = decode(item['data']) 78 | if 'user' in data: 79 | data['user'] = nick_filter.sub('_', data['user']) 80 | # reject our own messages 81 | if data.get('uuid') == self.uuid: 82 | continue 83 | with self.nolock: 84 | self.nosend[key] = remove_ttl(self.nosend[key]) 85 | self.nosend[key].append((time.time(),) + dtokey(data)) 86 | cb(key, data) 87 | elif item['type'] == 'subscribe': 88 | decoded = [] 89 | for data in self.r.lrange(key, 0, -1): 90 | 
try: 91 | decoded.append(decode(data)) 92 | except Exception: 93 | print('error decoding history', data) 94 | traceback.print_exc() 95 | 96 | state = [] 97 | dedup = set() 98 | for data in reversed(decoded): 99 | cmd = data.get('cmd') 100 | if cmd: 101 | keys = hash_keys + cmd_hash_keys.get(cmd, ()) 102 | hashkey = tuple([str(data.get(k)) for k in keys]) 103 | if all(hashkey): 104 | if hashkey in dedup: 105 | continue 106 | dedup.add(hashkey) 107 | state.append(data) 108 | 109 | for data in reversed(state): 110 | try: 111 | with self.nolock: 112 | self.nosend[key].append((time.time(),) + dtokey(data)) 113 | cb(key, data, replay=True) 114 | except Exception: 115 | print('error replaying history', data) 116 | traceback.print_exc() 117 | else: 118 | print('unknown redis push', item) 119 | except Exception: 120 | print('error processing item', item) 121 | traceback.print_exc() 122 | 123 | def join(self, key, cb): 124 | ps = self.r.pubsub() 125 | ps.subscribe(key) 126 | t = threading.Thread(target=self._sub_thread, args=(ps, cb, key)) 127 | t.daemon = True 128 | t.start() 129 | 130 | self.ps[key] = ps 131 | self.publish(key, {'cmd': 'join'}, perm=False) 132 | 133 | def leave(self, key): 134 | ps = self.ps.pop(key, None) 135 | if ps: 136 | ps.unsubscribe(key) 137 | 138 | def publish(self, key, data, perm=True, send_uuid=True): 139 | if self.debounce(self.nosend[key], data): 140 | return 141 | 142 | data['user'] = self.nick 143 | data['ts'] = self.r.time()[0] 144 | if send_uuid: 145 | data['uuid'] = self.uuid 146 | data = dict((key_enc.get(k, k), v) for k, v in data.items()) 147 | data = json.dumps(data, separators=(',', ':'), sort_keys=True) 148 | if perm: 149 | self.r.rpush(key, data) 150 | self.r.publish(key, data) 151 | 152 | def push(self, key, data, send_uuid=True): 153 | if send_uuid: 154 | data['uuid'] = self.uuid 155 | data = dict((key_enc.get(k, k), v) for k, v in data.items()) 156 | data = json.dumps(data, separators=(',', ':'), sort_keys=True) 157 | self.r.lpush(key, data) 158 | -------------------------------------------------------------------------------- /comments.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from difflib import Differ 3 | 4 | def fmtuser(user): 5 | return '[{}] '.format(user) 6 | 7 | class NoChange(Exception): pass 8 | 9 | class Comments: 10 | def __init__(self): 11 | self.comments = defaultdict(dict) 12 | self.text = defaultdict(str) 13 | self.delimiter = '\x1f\n' 14 | 15 | def set(self, ea, user, cmt, timestamp): 16 | if cmt.strip(): 17 | self.comments[ea][user] = (timestamp, user, cmt) 18 | else: 19 | self.comments[ea].pop(user, None) 20 | result = str(self.delimiter.join( 21 | [''.join((fmtuser(user), cmt)) 22 | for _, user, cmt in 23 | sorted(self.comments[ea].values())])) 24 | self.text[ea] = result 25 | return result 26 | 27 | def get_comment_at_addr(self, ea): 28 | return self.text[ea] 29 | 30 | def parse_comment_update(self, ea, user, cmt): 31 | if not cmt: return '' 32 | if cmt == self.text[ea]: raise NoChange 33 | f = fmtuser(user) 34 | for cmt in cmt.split(self.delimiter): 35 | if cmt.startswith(f): 36 | new = cmt.split('] ', 1)[1] 37 | break 38 | else: 39 | # Assume new comments are always appended 40 | new = cmt.split(self.delimiter)[-1] 41 | old = self.comments[ea].get(user) 42 | if old: 43 | _, _, old = old 44 | if old.strip() == new.strip(): 45 | raise NoChange 46 | return new 47 | 48 | comments = Comments() 49 | comments_extra = Comments() 50 | 51 | if __name__ == 
'__main__':
52 |     ts = 1
53 |     def add(addr, user, comment):
54 |         global ts
55 |         ts += 1
56 |         print('[+] {:#x} [{}] {}'.format(addr, user, comment))
57 |         comments.set(addr, user, comment, ts)
58 |         print('Comment at {:#x}:\n{}'.format(addr, comments.get_comment_at_addr(addr)))
59 |         print()
60 | 
61 |     ea = 0x1000
62 |     add(ea, 'alice', 'hello from alice')
63 |     add(ea, 'bob', 'hello from bob')
64 |     add(ea, 'alice', 'update from alice')
65 | 
66 |     text = comments.get_comment_at_addr(ea)
67 |     print('-'*40)
68 |     split = text.split(comments.delimiter)
69 |     for i, line in enumerate(split):
70 |         if fmtuser('alice') in line:
71 |             split[i] += ' added stuff'
72 |             update = comments.delimiter.join(split)
73 |             print('[-] update:\n{}'.format(update))
74 |             changed = comments.parse_comment_update(ea, 'alice', update)
75 |             print('[-] changed text:\n{}'.format(changed))
76 |             print('[-] set:')
77 |             add(ea, 'alice', changed)
78 |             break
79 | 
80 |     print('-'*40)
81 |     changed = comments.parse_comment_update(ea, 'alice', 'replaced all text')
82 |     add(ea, 'alice', changed)
83 | 
84 |     print('-'*40)
85 |     try:
86 |         text = comments.get_comment_at_addr(ea)
87 |         comments.parse_comment_update(ea, 'alice', text)
88 |         print('[!] oh no, change detected!')
89 |     except NoChange:
90 |         print('[+] no change detected')
91 | 
92 |     print('-'*40)
93 |     print('empty update:', repr(comments.parse_comment_update(ea, 'alice', '')))
94 | 
95 |     print()
96 | 
--------------------------------------------------------------------------------
/config.json.template:
--------------------------------------------------------------------------------
1 | {
2 |     "host": "hostname",
3 |     "port": 6379,
4 |     "password": "password",
5 |     "nick": "changeme"
6 | }
7 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | import json
4 | import os
5 | 
6 | try:
7 |     # TODO: We can look in $HOME/.config or $HOME/.revsync or something
8 |     path = os.path.dirname(os.path.abspath(__file__))
9 |     with open(os.path.join(path, "config.json"), "r") as f:
10 |         config = json.loads(f.read())
11 | except Exception:
12 |     raise ImportError("revsync: missing or unreadable config.json")
13 | 
--------------------------------------------------------------------------------
/coverage.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import math
3 | 
4 | MIN = 60
5 | HOUR = 60 * MIN
6 | DAY = 24 * HOUR
7 | 
8 | # def scale_band(x): return (x ** 2) / (2 * (x ** 2 - x) + 1)
9 | scale_band = lambda x: x
10 | def log_band(n, scale=1): return math.log(n + 2) / scale
11 | 
12 | bands = [scale_band(x) for x in [0, 0.1, 0.2, 0.3, 0.4, 0.5]]
13 | 
14 | VISIT_SCALE = math.log(1000)
15 | def visit_band(n):
16 |     if 1 <= n < 2: return bands[1]
17 |     if 2 <= n < 5: return bands[2]
18 |     if 5 <= n < 10: return bands[3]
19 |     if 10 <= n < 20: return bands[4]
20 |     if 20 <= n < 50: return bands[5]
21 |     if n >= 50: return max(log_band(n, VISIT_SCALE), 0.5)
22 |     return 0
23 | 
24 | TIME_SCALE = math.log(48 * HOUR)
25 | def time_band(n):
26 |     n /= float(MIN)
27 |     if 0 <= n < 0.5: return bands[0]
28 |     if 0.5 <= n < 5: return bands[1]
29 |     if 5 <= n < 10: return bands[2]
30 |     if 10 <= n < 20: return bands[3]
31 |     if 20 <= n < 30: return bands[4]
32 |     if 30 <= n < 60: return bands[5]
33 |     if n >= 60: return max(log_band(n, TIME_SCALE), 0.5)
34 |     return bands[0]
35 | 
36 | def user_band(n):
37 |     if 0 <= n < 2: return bands[0]
38 |     if 2 <= n < 3: return bands[1]
39 |     if 3 <= n < 6: return bands[2]
40 |     if 6 <= n < 10: return bands[3]
41 |     if 10 <= n < 15: return bands[4]
42 |     if n >= 15: return bands[5]
43 |     return bands[0]
44 | 
45 | class Block:
46 |     def __init__(self):
47 |         self.time = 0
48 |         self.visits = 0
49 |         self.users = 1
50 | 
51 |     def dump(self):
52 |         return {'l': self.time, 'v': self.visits, 'u': self.users}
53 | 
54 |     def add(self, block):
55 |         self.time += block.time
56 |         self.visits += block.visits
57 |         self.users += block.users
58 | 
59 |     def update(self, b):
60 |         self.time = int(b['l'])
61 |         self.visits = int(b['v'])
62 |         self.users = int(b['u'])
63 | 
64 |     def color(self, visits, time, users):
65 |         r = g = b = 0
66 |         if visits:
67 |             r = visit_band(self.visits)
68 |         if time:
69 |             b = time_band(self.time)
70 |         if users:
71 |             g = user_band(self.users)
72 |         if r == g == b == 0:
73 |             return None
74 | 
75 |         # this semi-softmax hedges against the colors ending up too close together and making grey
76 |         m = max((r, g, b))
77 |         r, g, b = r ** 2, g ** 2, b ** 2
78 |         total = float(r + g + b)
79 |         r, g, b = r / total * m, g / total * m, b / total * m
80 |         return r, g, b
81 | 
82 | class Blocks(defaultdict):
83 |     def __init__(self):
84 |         defaultdict.__init__(self, Block)
85 | 
86 |     def merge(self, blocks):
87 |         for addr, block in blocks.items():
88 |             mine = self.get(addr, None)
89 |             if mine:
90 |                 mine.add(block)
91 |             else:
92 |                 self[addr] = block
93 | 
94 |     def update(self, blocks):
95 |         for addr, data in blocks.items():
96 |             block = self[int(addr)]
97 |             block.update(data)
98 | 
99 |     def visit(self, addr, elapsed=0, visits=0):
100 |         block = self[addr]
101 |         block.time += elapsed
102 |         block.visits += visits
103 | 
104 | class Coverage:
105 |     def __init__(self):
106 |         self.pending = Blocks()
107 |         self.local = Blocks()
108 |         self.shared = Blocks()
109 | 
110 |     def visit_addr(self, addr, elapsed=0, visits=0):
111 |         self.pending.visit(addr, elapsed, visits)
112 | 
113 |     def color(self, addr, time=True, visits=True, users=True):
114 |         # Sum blocks from global, local and pending for current colouring
115 |         sblock = self.shared.get(addr, None)
116 |         lblock = self.local.get(addr, None)
117 |         pblock = self.pending.get(addr, None)
118 |         if not sblock and not lblock and not pblock:
119 |             return None
120 |         block = Block()
121 |         if sblock: block.add(sblock)
122 |         if lblock: block.add(lblock)
123 |         if pblock: block.add(pblock)
124 |         return block.color(time=time, visits=visits, users=users)
125 | 
126 |     def update(self, blocks):
127 |         # Update Global Coverage and Reset Local Coverage
128 |         self.shared.update(blocks)
129 |         self.local = Blocks()
130 | 
131 |     def flush(self):
132 |         # Report pending coverage and merge with local
133 |         pending = {addr: block.dump() for addr, block in self.pending.items()}
134 |         self.local.merge(self.pending)
135 |         self.pending = Blocks()
136 |         return pending
137 | 
--------------------------------------------------------------------------------
/coverage_poll.py:
--------------------------------------------------------------------------------
1 | import redis
2 | from time import sleep
3 | from datetime import datetime, timedelta
4 | from config import config
5 | import json
6 | 
7 | 
8 | def rollup(rclient, key, coverage):
9 |     while True:
10 |         cov = rclient.rpop(key)
11 |         if cov is None:
12 |             return coverage
13 |         cov = json.loads(cov)
14 |         if "b" in cov:
15 |             bbs = cov["b"]
16 |             for bb in bbs.keys():
17 |                 if bb not in coverage:
18 |                     coverage[bb] = {"v": 0, "l": 0, "u": []}
19 |                 coverage[bb]["v"] += bbs[bb]["v"]
coverage[bb]["l"] += bbs[bb]["l"] 21 | if cov["i"] not in coverage[bb]["u"]: 22 | coverage[bb]["u"].append(cov["i"]) 23 | 24 | 25 | poll_interval = timedelta(minutes=1, seconds=30) 26 | r = redis.StrictRedis(host=config['host'], port=config['port'], password=config['password']) 27 | next_poll = datetime.now() 28 | coverage = {} 29 | while True: 30 | if datetime.now() > next_poll: 31 | print "Polling" 32 | next_poll = datetime.now() + poll_interval 33 | keys = r.keys(pattern="*_COVERAGE") 34 | for key in keys: 35 | k = key.split("_")[0] 36 | 37 | print "Retrieving Stored Results" 38 | cov = r.get("%s_STORE" % k) 39 | if cov is None: 40 | print "No Stored Results Found" 41 | cov = {} 42 | else: 43 | cov = json.loads(cov) 44 | 45 | print "Rolling Up: %s" % k 46 | cov = rollup(r, key, cov) 47 | print len(cov.keys()) 48 | 49 | print "Storing Results" 50 | r.set(name="%s_STORE" % k, value=json.dumps(cov)) 51 | 52 | print "Publish Results" 53 | for c in cov.keys(): 54 | cov[c]["u"] = len(cov[c]["u"]) 55 | data = {"c": "coverage", "u": "COV", "b": cov} 56 | data = json.dumps(data, separators=(',', ':'), sort_keys=True) 57 | r.publish(k, data) 58 | else: 59 | print "Sleep" 60 | sleep(5) 61 | -------------------------------------------------------------------------------- /example_idapythonrc.py: -------------------------------------------------------------------------------- 1 | import revsync 2 | -------------------------------------------------------------------------------- /ida_frontend.py: -------------------------------------------------------------------------------- 1 | import idaapi 2 | from idaapi import * 3 | from idc import * 4 | from idautils import * 5 | 6 | import hashlib 7 | import traceback 8 | 9 | from client import Client 10 | from config import config 11 | from comments import comments, comments_extra, NoChange 12 | 13 | 14 | ida_reserved_prefix = ( 15 | 'sub_', 'locret_', 'loc_', 'off_', 'seg_', 'asc_', 'byte_', 'word_', 16 | 'dword_', 'qword_', 'byte3_', 'xmmword_', 'ymmword_', 'packreal_', 17 | 'flt_', 'dbl_', 'tbyte_', 'stru_', 'custdata_', 'algn_', 'unk_', 18 | ) 19 | 20 | fhash = None 21 | auto_wait = False 22 | client = Client(**config) 23 | netnode = idaapi.netnode() 24 | NETNODE_NAME = '$ revsync-fhash' 25 | 26 | hook1 = hook2 = hook3 = None 27 | 28 | ### Helper Functions 29 | 30 | def cached_fhash(): 31 | return netnode.getblob(0, 'I').decode('ascii') 32 | 33 | def read_fhash(): 34 | filename = idaapi.get_root_filename() 35 | if filename is None: 36 | return None 37 | with open(filename, 'rb') as f: 38 | return hashlib.sha256(f.read()).hexdigest().upper() 39 | 40 | def get_can_addr(addr): 41 | """Convert an Effective Address to a canonical address.""" 42 | return addr - get_imagebase() 43 | 44 | def get_ea(addr): 45 | """Get Effective Address from a canonical address.""" 46 | return addr + get_imagebase() 47 | 48 | ### Redis Functions ### 49 | 50 | def onmsg_safe(key, data, replay=False): 51 | def tmp(): 52 | try: 53 | onmsg(key, data, replay=replay) 54 | except Exception as e: 55 | print('error during callback for %s: %s' % (data.get('cmd'), e)) 56 | traceback.print_exc() 57 | idaapi.execute_sync(tmp, MFF_WRITE) 58 | 59 | def onmsg(key, data, replay=False): 60 | if key != fhash or key != cached_fhash(): 61 | print('revsync: hash mismatch, dropping command') 62 | return 63 | 64 | if hook1: hook1.unhook() 65 | if hook2: hook2.unhook() 66 | if hook3: hook3.unhook() 67 | 68 | try: 69 | if 'addr' in data: 70 | ea = get_ea(data['addr']) 71 | ts = int(data.get('ts', 0)) 72 | 
cmd, user = data['cmd'], data['user'] 73 | if cmd == 'comment': 74 | print('revsync: <%s> %s %#x %s' % (user, cmd, ea, data['text'])) 75 | text = comments.set(ea, user, str(data['text']), ts) 76 | set_cmt(ea, text, 0) 77 | elif cmd == 'extra_comment': 78 | print('revsync: <%s> %s %#x %s' % (user, cmd, ea, data['text'])) 79 | text = comments_extra.set(ea, user, str(data['text']), ts) 80 | set_cmt(ea, text, 1) 81 | elif cmd == 'area_comment': 82 | print('revsync: <%s> %s %s %s' % (user, cmd, data['range'], data['text'])) 83 | elif cmd == 'rename': 84 | print('revsync: <%s> %s %#x %s' % (user, cmd, ea, data['text'])) 85 | set_name(ea, str(data['text']).replace(' ', '_')) 86 | elif cmd == 'join': 87 | print('revsync: <%s> joined' % (user)) 88 | elif cmd in ['stackvar_renamed', 'struc_created', 'struc_deleted', 89 | 'struc_renamed', 'struc_member_created', 'struc_member_deleted', 90 | 'struc_member_renamed', 'struc_member_changed', 'coverage']: 91 | if 'addr' in data: 92 | print('revsync: <%s> %s %#x (not supported in IDA revsync)' % (user, cmd, ea)) 93 | else: 94 | print('revsync: <%s> %s (not supported in IDA revsync)' % (user, cmd)) 95 | else: 96 | print('revsync: unknown cmd', data) 97 | finally: 98 | if hook1: hook1.hook() 99 | if hook2: hook2.hook() 100 | if hook3: hook3.hook() 101 | 102 | def publish(data, **kwargs): 103 | if not auto_is_ok(): 104 | return 105 | if fhash == netnode.getblob(0, 'I').decode('ascii'): 106 | client.publish(fhash, data, **kwargs) 107 | 108 | ### IDA Hook Classes ### 109 | 110 | def on_renamed(ea, new_name, local_name): 111 | if is_loaded(ea) and not new_name.startswith(ida_reserved_prefix): 112 | publish({'cmd': 'rename', 'addr': get_can_addr(ea), 'text': new_name}) 113 | 114 | def on_auto_empty_finally(): 115 | global auto_wait 116 | if auto_wait: 117 | auto_wait = False 118 | on_load() 119 | 120 | # These IDPHooks methods are for pre-IDA 7 121 | class IDPHooks(IDP_Hooks): 122 | def renamed(self, ea, new_name, local_name): 123 | on_renamed(ea, new_name, local_name) 124 | return IDP_Hooks.renamed(self, ea, new_name, local_name) 125 | 126 | # TODO: make sure this is on 6.1 127 | def auto_empty_finally(self): 128 | on_auto_empty_finally() 129 | return IDP_Hooks.auto_empty_finally(self) 130 | 131 | class IDBHooks(IDB_Hooks): 132 | def renamed(self, ea, new_name, local_name, old_name=None): 133 | on_renamed(ea, new_name, local_name) 134 | 135 | if (idaapi.IDA_SDK_VERSION >= 760): 136 | return IDB_Hooks.renamed(self, ea, new_name, local_name, old_name) 137 | 138 | return IDB_Hooks.renamed(self, ea, new_name, local_name) 139 | 140 | def auto_empty_finally(self): 141 | on_auto_empty_finally() 142 | return IDB_Hooks.auto_empty_finally(self) 143 | 144 | def cmt_changed(self, ea, repeatable): 145 | cmt = get_cmt(ea, repeatable) 146 | try: 147 | changed = comments.parse_comment_update(ea, client.nick, cmt) 148 | publish({'cmd': 'comment', 'addr': get_can_addr(ea), 'text': changed or ''}, send_uuid=False) 149 | except NoChange: 150 | pass 151 | return IDB_Hooks.cmt_changed(self, ea, repeatable) 152 | 153 | def extra_cmt_changed(self, ea, line_idx, repeatable): 154 | try: 155 | cmt = get_cmt(ea, repeatable) 156 | changed = comments_extra.parse_comment_update(ea, client.nick, cmt) 157 | publish({'cmd': 'extra_comment', 'addr': get_can_addr(ea), 'line': line_idx, 'text': changed or ''}, send_uuid=False) 158 | except NoChange: 159 | pass 160 | return IDB_Hooks.extra_cmt_changed(self, ea, line_idx, repeatable) 161 | 162 | def area_cmt_changed(self, cb, a, cmt, repeatable): 163 
| publish({'cmd': 'area_comment', 'range': [get_can_addr(a.startEA), get_can_addr(a.endEA)], 'text': cmt or ''}, send_uuid=False) 164 | return IDB_Hooks.area_cmt_changed(self, cb, a, cmt, repeatable) 165 | 166 | class UIHooks(UI_Hooks): 167 | pass 168 | 169 | ### Setup Events ### 170 | 171 | def on_load(): 172 | global fhash 173 | if fhash: 174 | client.leave(fhash) 175 | fhash = cached_fhash() 176 | print('revsync: connecting with', fhash) 177 | client.join(fhash, onmsg_safe) 178 | 179 | def wait_for_analysis(): 180 | global auto_wait 181 | if auto_is_ok(): 182 | auto_wait = False 183 | on_load() 184 | return -1 185 | return 1000 186 | 187 | def on_open(): 188 | global auto_wait 189 | global fhash 190 | print('revsync: file opened:', idaapi.get_root_filename()) 191 | netnode.create(NETNODE_NAME) 192 | try: fhash = netnode.getblob(0, 'I').decode('ascii') 193 | except: fhash = None 194 | if not fhash: 195 | fhash = read_fhash() 196 | try: ret = netnode.setblob(fhash.encode('ascii'), 0, 'I') 197 | except: print('saving fhash failed, this will probably break revsync') 198 | 199 | if auto_is_ok(): 200 | on_load() 201 | auto_wait = False 202 | else: 203 | auto_wait = True 204 | print('revsync: waiting for auto analysis') 205 | if not hasattr(IDP_Hooks, 'auto_empty_finally'): 206 | idaapi.register_timer(1000, wait_for_analysis) 207 | 208 | def on_close(): 209 | global fhash 210 | if fhash: 211 | client.leave(fhash) 212 | fhash = None 213 | 214 | hook1 = IDPHooks() 215 | hook2 = IDBHooks() 216 | hook3 = UIHooks() 217 | 218 | def eventhook(event, old=0): 219 | if event == idaapi.NW_OPENIDB: 220 | on_open() 221 | elif event in (idaapi.NW_CLOSEIDB, idaapi.NW_TERMIDA): 222 | on_close() 223 | if event == idaapi.NW_TERMIDA: 224 | # remove hook on way out 225 | idaapi.notify_when(idaapi.NW_OPENIDB | idaapi.NW_CLOSEIDB | idaapi.NW_TERMIDA | idaapi.NW_REMOVE, eventhook) 226 | 227 | def setup(): 228 | if idaapi.get_root_filename(): 229 | on_open() 230 | else: 231 | idaapi.notify_when(idaapi.NW_OPENIDB | idaapi.NW_CLOSEIDB | idaapi.NW_TERMIDA, eventhook) 232 | return -1 233 | 234 | hook1.hook() 235 | hook2.hook() 236 | hook3.hook() 237 | idaapi.register_timer(1000, setup) 238 | print('revsync: starting setup timer') 239 | -------------------------------------------------------------------------------- /plugin.json: -------------------------------------------------------------------------------- 1 | { 2 | "pluginmetadataversion": 2, 3 | "name": "revsync", 4 | "type": ["ui"], 5 | "api": ["python2", "python3"], 6 | "description": "Realtime IDA Pro and Binary Ninja sync plugin", 7 | "longdescription": "revsync\n=======\n\nRealtime IDA Pro and Binary Ninja sync plugin\n\nSyncs:\n\n- Comments\n- Symbol names\n- Stack var names\n- Structs\n- Code coverage (how much time was spent looking at a block)\n\n\nBinary Ninja Installation\n-------------------------\n\n- Install via the Plugin Manager (CMD/CTL-SHIFT-M)\n\nor:\n\n- Clone to [your plugin folder](https://github.com/Vector35/binaryninja-api/tree/dev/python/examples#loading-plugins).\n\nThen:\n\n- Restart if required.\n- Fill in config when prompted.\n- Load your binary, wait for analysis to finish\n- Use the Tools Menu, Right-Click or command-palette (CMD/CTL-P) to trigger revsync/Load\n-Done!\n", 8 | "license": { 9 | "name": "MIT", 10 | "text": "Copyright (c) 2019 lunixbochs\n\nPermission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the \"Software\"), to deal in the Software without 
restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE." 11 | }, 12 | "platforms" : ["Darwin", "Linux", "Windows"], 13 | "installinstructions" : { 14 | "Darwin" : "", 15 | "Linux" : "", 16 | "Windows" : "" 17 | }, 18 | "dependencies": { 19 | }, 20 | "version": "1.0", 21 | "author": "lunixbochs", 22 | "minimumbinaryninjaversion": 1528 23 | } 24 | -------------------------------------------------------------------------------- /redis/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012 Andy McCurdy 2 | 3 | Permission is hereby granted, free of charge, to any person 4 | obtaining a copy of this software and associated documentation 5 | files (the "Software"), to deal in the Software without 6 | restriction, including without limitation the rights to use, 7 | copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the 9 | Software is furnished to do so, subject to the following 10 | conditions: 11 | 12 | The above copyright notice and this permission notice shall be 13 | included in all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 16 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 17 | OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 18 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 19 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 20 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 
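The redis package that follows is a vendored copy of redis-py, bundled so the plugin has no external dependency inside IDA, Binary Ninja, or Vivisect. revsync uses it purely as a pub/sub transport: each frontend hook publishes a small dict on a channel keyed by the binary's file hash, and every connected client replays it locally. Below is a minimal sketch of that round trip, assuming a reachable Redis server; the hash value and the exact wire encoding here are illustrative (the real framing lives in client.py, which is not shown in this section), but the 'cmd'/'addr'/'text' fields mirror the hook code above.

```
# Illustrative sketch of the pub/sub round trip revsync performs.
# Assumptions: a Redis server on localhost:6379 and a channel keyed by
# the binary's hash (fhash); the real encoding is in client.py.
import json
import redis

r = redis.StrictRedis(host='localhost', port=6379)
fhash = 'd41d8cd98f00b204e9800998ecf8427e'  # hypothetical file hash

p = r.pubsub()
p.subscribe(fhash)

# Publisher side: a rename hook fires and broadcasts the change.
r.publish(fhash, json.dumps(
    {'cmd': 'rename', 'addr': 0x401000, 'text': 'parse_header'}))

# Subscriber side: every other connected client applies it locally.
for msg in p.listen():
    if msg['type'] != 'message':
        continue  # skip the subscribe confirmation
    data = json.loads(msg['data'])
    print('would apply %s at %#x' % (data['cmd'], data['addr']))
    break
```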
23 | -------------------------------------------------------------------------------- /redis/__init__.py: -------------------------------------------------------------------------------- 1 | from redis.client import Redis, StrictRedis 2 | from redis.connection import ( 3 | BlockingConnectionPool, 4 | ConnectionPool, 5 | Connection, 6 | SSLConnection, 7 | UnixDomainSocketConnection 8 | ) 9 | from redis.utils import from_url 10 | from redis.exceptions import ( 11 | AuthenticationError, 12 | BusyLoadingError, 13 | ConnectionError, 14 | DataError, 15 | InvalidResponse, 16 | PubSubError, 17 | ReadOnlyError, 18 | RedisError, 19 | ResponseError, 20 | TimeoutError, 21 | WatchError 22 | ) 23 | 24 | 25 | __version__ = '2.10.5' 26 | VERSION = tuple(map(int, __version__.split('.'))) 27 | 28 | __all__ = [ 29 | 'Redis', 'StrictRedis', 'ConnectionPool', 'BlockingConnectionPool', 30 | 'Connection', 'SSLConnection', 'UnixDomainSocketConnection', 'from_url', 31 | 'AuthenticationError', 'BusyLoadingError', 'ConnectionError', 'DataError', 32 | 'InvalidResponse', 'PubSubError', 'ReadOnlyError', 'RedisError', 33 | 'ResponseError', 'TimeoutError', 'WatchError' 34 | ] 35 | -------------------------------------------------------------------------------- /redis/_compat.py: -------------------------------------------------------------------------------- 1 | """Internal module for Python 2 backwards compatibility.""" 2 | import errno 3 | import sys 4 | 5 | try: 6 | InterruptedError = InterruptedError 7 | except: 8 | InterruptedError = OSError 9 | 10 | # For Python older than 3.5, retry EINTR. 11 | if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and 12 | sys.version_info[1] < 5): 13 | # Adapted from https://bugs.python.org/review/23863/patch/14532/54418 14 | import socket 15 | import time 16 | import errno 17 | 18 | from select import select as _select 19 | 20 | def select(rlist, wlist, xlist, timeout): 21 | while True: 22 | try: 23 | return _select(rlist, wlist, xlist, timeout) 24 | except InterruptedError as e: 25 | # Python 2 does not define InterruptedError, instead 26 | # try to catch an OSError with errno == EINTR == 4. 27 | if getattr(e, 'errno', None) == getattr(errno, 'EINTR', 4): 28 | continue 29 | raise 30 | 31 | # Wrapper for handling interruptable system calls. 32 | def _retryable_call(s, func, *args, **kwargs): 33 | # Some modules (SSL) use the _fileobject wrapper directly and 34 | # implement a smaller portion of the socket interface, thus we 35 | # need to let them continue to do so. 36 | timeout, deadline = None, 0.0 37 | attempted = False 38 | try: 39 | timeout = s.gettimeout() 40 | except AttributeError: 41 | pass 42 | 43 | if timeout: 44 | deadline = time.time() + timeout 45 | 46 | try: 47 | while True: 48 | if attempted and timeout: 49 | now = time.time() 50 | if now >= deadline: 51 | raise socket.error(errno.EWOULDBLOCK, "timed out") 52 | else: 53 | # Overwrite the timeout on the socket object 54 | # to take into account elapsed time. 55 | s.settimeout(deadline - now) 56 | try: 57 | attempted = True 58 | return func(*args, **kwargs) 59 | except socket.error as e: 60 | if e.args[0] == errno.EINTR: 61 | continue 62 | raise 63 | finally: 64 | # Set the existing timeout back for future 65 | # calls. 
66 | if timeout: 67 | s.settimeout(timeout) 68 | 69 | def recv(sock, *args, **kwargs): 70 | return _retryable_call(sock, sock.recv, *args, **kwargs) 71 | 72 | def recv_into(sock, *args, **kwargs): 73 | return _retryable_call(sock, sock.recv_into, *args, **kwargs) 74 | 75 | else: # Python 3.5 and above automatically retry EINTR 76 | from select import select 77 | 78 | def recv(sock, *args, **kwargs): 79 | return sock.recv(*args, **kwargs) 80 | 81 | def recv_into(sock, *args, **kwargs): 82 | return sock.recv_into(*args, **kwargs) 83 | 84 | if sys.version_info[0] < 3: 85 | from urllib import unquote 86 | from urlparse import parse_qs, urlparse 87 | from itertools import imap, izip 88 | from string import letters as ascii_letters 89 | from Queue import Queue 90 | try: 91 | from cStringIO import StringIO as BytesIO 92 | except ImportError: 93 | from StringIO import StringIO as BytesIO 94 | 95 | # special unicode handling for python2 to avoid UnicodeDecodeError 96 | def safe_unicode(obj, *args): 97 | """ return the unicode representation of obj """ 98 | try: 99 | return unicode(obj, *args) 100 | except UnicodeDecodeError: 101 | # obj is byte string 102 | ascii_text = str(obj).encode('string_escape') 103 | return unicode(ascii_text) 104 | 105 | def iteritems(x): 106 | return x.iteritems() 107 | 108 | def iterkeys(x): 109 | return x.iterkeys() 110 | 111 | def itervalues(x): 112 | return x.itervalues() 113 | 114 | def nativestr(x): 115 | return x if isinstance(x, str) else x.encode('utf-8', 'replace') 116 | 117 | def u(x): 118 | return x.decode() 119 | 120 | def b(x): 121 | return x 122 | 123 | def next(x): 124 | return x.next() 125 | 126 | def byte_to_chr(x): 127 | return x 128 | 129 | unichr = unichr 130 | xrange = xrange 131 | basestring = basestring 132 | unicode = unicode 133 | bytes = str 134 | long = long 135 | else: 136 | from urllib.parse import parse_qs, unquote, urlparse 137 | from io import BytesIO 138 | from string import ascii_letters 139 | from queue import Queue 140 | 141 | def iteritems(x): 142 | return iter(x.items()) 143 | 144 | def iterkeys(x): 145 | return iter(x.keys()) 146 | 147 | def itervalues(x): 148 | return iter(x.values()) 149 | 150 | def byte_to_chr(x): 151 | return chr(x) 152 | 153 | def nativestr(x): 154 | return x if isinstance(x, str) else x.decode('utf-8', 'replace') 155 | 156 | def u(x): 157 | return x 158 | 159 | def b(x): 160 | return x.encode('latin-1') if not isinstance(x, bytes) else x 161 | 162 | next = next 163 | unichr = chr 164 | imap = map 165 | izip = zip 166 | xrange = range 167 | basestring = str 168 | unicode = str 169 | safe_unicode = str 170 | bytes = bytes 171 | long = int 172 | 173 | try: # Python 3 174 | from queue import LifoQueue, Empty, Full 175 | except ImportError: 176 | from Queue import Empty, Full 177 | try: # Python 2.6 - 2.7 178 | from Queue import LifoQueue 179 | except ImportError: # Python 2.5 180 | from Queue import Queue 181 | # From the Python 2.7 lib. Python 2.5 already extracted the core 182 | # methods to aid implementing different queue organisations. 183 | 184 | class LifoQueue(Queue): 185 | "Override queue methods to implement a last-in first-out queue."
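A note on this fallback: the stdlib ``Queue`` base class owns all the locking, so the Python 2.5 shim only has to override the four storage hooks that follow (``_init``, ``_qsize``, ``_put``, ``_get``). As a quick illustration of the same pattern (not part of the vendored code), swapping the list for a heap yields a priority queue:

```
# Illustrative only: the same four-hook pattern as the LifoQueue
# fallback, with heap ordering instead of LIFO. Queue handles all
# locking; subclasses only swap the storage strategy.
import heapq
try:
    from queue import Queue      # Python 3
except ImportError:
    from Queue import Queue      # Python 2

class SimplePriorityQueue(Queue):
    "Override queue methods to retrieve items in priority order."

    def _init(self, maxsize):
        self.queue = []

    def _qsize(self, len=len):
        return len(self.queue)

    def _put(self, item):
        heapq.heappush(self.queue, item)

    def _get(self):
        return heapq.heappop(self.queue)
```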
186 | 187 | def _init(self, maxsize): 188 | self.maxsize = maxsize 189 | self.queue = [] 190 | 191 | def _qsize(self, len=len): 192 | return len(self.queue) 193 | 194 | def _put(self, item): 195 | self.queue.append(item) 196 | 197 | def _get(self): 198 | return self.queue.pop() 199 | -------------------------------------------------------------------------------- /redis/connection.py: -------------------------------------------------------------------------------- 1 | from __future__ import with_statement 2 | from distutils.version import StrictVersion 3 | from itertools import chain 4 | import os 5 | import socket 6 | import sys 7 | import threading 8 | import warnings 9 | 10 | try: 11 | import ssl 12 | ssl_available = True 13 | except ImportError: 14 | ssl_available = False 15 | 16 | from redis._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long, 17 | BytesIO, nativestr, basestring, iteritems, 18 | LifoQueue, Empty, Full, urlparse, parse_qs, 19 | recv, recv_into, select, unquote) 20 | from redis.exceptions import ( 21 | RedisError, 22 | ConnectionError, 23 | TimeoutError, 24 | BusyLoadingError, 25 | ResponseError, 26 | InvalidResponse, 27 | AuthenticationError, 28 | NoScriptError, 29 | ExecAbortError, 30 | ReadOnlyError 31 | ) 32 | from redis.utils import HIREDIS_AVAILABLE 33 | if HIREDIS_AVAILABLE: 34 | import hiredis 35 | 36 | hiredis_version = StrictVersion(hiredis.__version__) 37 | HIREDIS_SUPPORTS_CALLABLE_ERRORS = \ 38 | hiredis_version >= StrictVersion('0.1.3') 39 | HIREDIS_SUPPORTS_BYTE_BUFFER = \ 40 | hiredis_version >= StrictVersion('0.1.4') 41 | 42 | if not HIREDIS_SUPPORTS_BYTE_BUFFER: 43 | msg = ("redis-py works best with hiredis >= 0.1.4. You're running " 44 | "hiredis %s. Please consider upgrading." % hiredis.__version__) 45 | warnings.warn(msg) 46 | 47 | HIREDIS_USE_BYTE_BUFFER = True 48 | # only use byte buffer if hiredis supports it and the Python version 49 | # is >= 2.7 50 | if not HIREDIS_SUPPORTS_BYTE_BUFFER or ( 51 | sys.version_info[0] == 2 and sys.version_info[1] < 7): 52 | HIREDIS_USE_BYTE_BUFFER = False 53 | 54 | SYM_STAR = b('*') 55 | SYM_DOLLAR = b('$') 56 | SYM_CRLF = b('\r\n') 57 | SYM_EMPTY = b('') 58 | 59 | SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." 60 | 61 | 62 | class Token(object): 63 | """ 64 | Literal strings in Redis commands, such as the command names and any 65 | hard-coded arguments, are wrapped in this class so we know not to apply 66 | any encoding rules on them.
67 | """ 68 | 69 | _cache = {} 70 | 71 | @classmethod 72 | def get_token(cls, value): 73 | "Gets a cached token object or creates a new one if not already cached" 74 | 75 | # Use try/except because after running for a short time most tokens 76 | # should already be cached 77 | try: 78 | return cls._cache[value] 79 | except KeyError: 80 | token = Token(value) 81 | cls._cache[value] = token 82 | return token 83 | 84 | def __init__(self, value): 85 | if isinstance(value, Token): 86 | value = value.value 87 | self.value = value 88 | self.encoded_value = b(value) 89 | 90 | def __repr__(self): 91 | return self.value 92 | 93 | def __str__(self): 94 | return self.value 95 | 96 | 97 | class BaseParser(object): 98 | EXCEPTION_CLASSES = { 99 | 'ERR': { 100 | 'max number of clients reached': ConnectionError 101 | }, 102 | 'EXECABORT': ExecAbortError, 103 | 'LOADING': BusyLoadingError, 104 | 'NOSCRIPT': NoScriptError, 105 | 'READONLY': ReadOnlyError, 106 | } 107 | 108 | def parse_error(self, response): 109 | "Parse an error response" 110 | error_code = response.split(' ')[0] 111 | if error_code in self.EXCEPTION_CLASSES: 112 | response = response[len(error_code) + 1:] 113 | exception_class = self.EXCEPTION_CLASSES[error_code] 114 | if isinstance(exception_class, dict): 115 | exception_class = exception_class.get(response, ResponseError) 116 | return exception_class(response) 117 | return ResponseError(response) 118 | 119 | 120 | class SocketBuffer(object): 121 | def __init__(self, socket, socket_read_size): 122 | self._sock = socket 123 | self.socket_read_size = socket_read_size 124 | self._buffer = BytesIO() 125 | # number of bytes written to the buffer from the socket 126 | self.bytes_written = 0 127 | # number of bytes read from the buffer 128 | self.bytes_read = 0 129 | 130 | @property 131 | def length(self): 132 | return self.bytes_written - self.bytes_read 133 | 134 | def _read_from_socket(self, length=None): 135 | socket_read_size = self.socket_read_size 136 | buf = self._buffer 137 | buf.seek(self.bytes_written) 138 | marker = 0 139 | 140 | try: 141 | while True: 142 | data = recv(self._sock, socket_read_size) 143 | # an empty string indicates the server shutdown the socket 144 | if isinstance(data, bytes) and len(data) == 0: 145 | raise socket.error(SERVER_CLOSED_CONNECTION_ERROR) 146 | buf.write(data) 147 | data_length = len(data) 148 | self.bytes_written += data_length 149 | marker += data_length 150 | 151 | if length is not None and length > marker: 152 | continue 153 | break 154 | except socket.timeout: 155 | raise TimeoutError("Timeout reading from socket") 156 | except socket.error: 157 | e = sys.exc_info()[1] 158 | raise ConnectionError("Error while reading from socket: %s" % 159 | (e.args,)) 160 | 161 | def read(self, length): 162 | length = length + 2 # make sure to read the \r\n terminator 163 | # make sure we've read enough data from the socket 164 | if length > self.length: 165 | self._read_from_socket(length - self.length) 166 | 167 | self._buffer.seek(self.bytes_read) 168 | data = self._buffer.read(length) 169 | self.bytes_read += len(data) 170 | 171 | # purge the buffer when we've consumed it all so it doesn't 172 | # grow forever 173 | if self.bytes_read == self.bytes_written: 174 | self.purge() 175 | 176 | return data[:-2] 177 | 178 | def readline(self): 179 | buf = self._buffer 180 | buf.seek(self.bytes_read) 181 | data = buf.readline() 182 | while not data.endswith(SYM_CRLF): 183 | # there's more data in the socket that we need 184 | self._read_from_socket() 185 | 
buf.seek(self.bytes_read) 186 | data = buf.readline() 187 | 188 | self.bytes_read += len(data) 189 | 190 | # purge the buffer when we've consumed it all so it doesn't 191 | # grow forever 192 | if self.bytes_read == self.bytes_written: 193 | self.purge() 194 | 195 | return data[:-2] 196 | 197 | def purge(self): 198 | self._buffer.seek(0) 199 | self._buffer.truncate() 200 | self.bytes_written = 0 201 | self.bytes_read = 0 202 | 203 | def close(self): 204 | try: 205 | self.purge() 206 | self._buffer.close() 207 | except: 208 | # issue #633 suggests the purge/close somehow raised a 209 | # BadFileDescriptor error. Perhaps the client ran out of 210 | # memory or something else? It's probably OK to ignore 211 | # any error being raised from purge/close since we're 212 | # removing the reference to the instance below. 213 | pass 214 | self._buffer = None 215 | self._sock = None 216 | 217 | 218 | class PythonParser(BaseParser): 219 | "Plain Python parsing class" 220 | encoding = None 221 | 222 | def __init__(self, socket_read_size): 223 | self.socket_read_size = socket_read_size 224 | self._sock = None 225 | self._buffer = None 226 | 227 | def __del__(self): 228 | try: 229 | self.on_disconnect() 230 | except Exception: 231 | pass 232 | 233 | def on_connect(self, connection): 234 | "Called when the socket connects" 235 | self._sock = connection._sock 236 | self._buffer = SocketBuffer(self._sock, self.socket_read_size) 237 | if connection.decode_responses: 238 | self.encoding = connection.encoding 239 | 240 | def on_disconnect(self): 241 | "Called when the socket disconnects" 242 | if self._sock is not None: 243 | self._sock.close() 244 | self._sock = None 245 | if self._buffer is not None: 246 | self._buffer.close() 247 | self._buffer = None 248 | self.encoding = None 249 | 250 | def can_read(self): 251 | return self._buffer and bool(self._buffer.length) 252 | 253 | def read_response(self): 254 | response = self._buffer.readline() 255 | if not response: 256 | raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) 257 | 258 | byte, response = byte_to_chr(response[0]), response[1:] 259 | 260 | if byte not in ('-', '+', ':', '$', '*'): 261 | raise InvalidResponse("Protocol Error: %s, %s" % 262 | (str(byte), str(response))) 263 | 264 | # server returned an error 265 | if byte == '-': 266 | response = nativestr(response) 267 | error = self.parse_error(response) 268 | # if the error is a ConnectionError, raise immediately so the user 269 | # is notified 270 | if isinstance(error, ConnectionError): 271 | raise error 272 | # otherwise, we're dealing with a ResponseError that might belong 273 | # inside a pipeline response. the connection's read_response() 274 | # and/or the pipeline's execute() will raise this error if 275 | # necessary, so just return the exception instance here. 
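The dispatch that follows keys off RESP's one-byte type prefixes. For reference, this is the framing the parser is consuming, shown for a SET/GET exchange (byte layouts per the Redis protocol; the key and value are illustrative):

```
# RESP frames as the PythonParser sees them ('+' status, '-' error,
# ':' integer, '$' bulk string, '*' array). Values are illustrative.
request       = b"*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n"  # ['SET', 'foo', 'bar']
reply_status  = b"+OK\r\n"                   # '+' simple status -> 'OK'
reply_integer = b":1\r\n"                    # ':' -> integer 1
reply_bulk    = b"$3\r\nbar\r\n"             # '$' length-prefixed payload
reply_missing = b"$-1\r\n"                   # length -1 -> parsed as None
reply_error   = b"-ERR unknown command\r\n"  # '-' -> ResponseError
```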
276 | return error 277 | # single value 278 | elif byte == '+': 279 | pass 280 | # int value 281 | elif byte == ':': 282 | response = long(response) 283 | # bulk response 284 | elif byte == '$': 285 | length = int(response) 286 | if length == -1: 287 | return None 288 | response = self._buffer.read(length) 289 | # multi-bulk response 290 | elif byte == '*': 291 | length = int(response) 292 | if length == -1: 293 | return None 294 | response = [self.read_response() for i in xrange(length)] 295 | if isinstance(response, bytes) and self.encoding: 296 | response = response.decode(self.encoding) 297 | return response 298 | 299 | 300 | class HiredisParser(BaseParser): 301 | "Parser class for connections using Hiredis" 302 | def __init__(self, socket_read_size): 303 | if not HIREDIS_AVAILABLE: 304 | raise RedisError("Hiredis is not installed") 305 | self.socket_read_size = socket_read_size 306 | 307 | if HIREDIS_USE_BYTE_BUFFER: 308 | self._buffer = bytearray(socket_read_size) 309 | 310 | def __del__(self): 311 | try: 312 | self.on_disconnect() 313 | except Exception: 314 | pass 315 | 316 | def on_connect(self, connection): 317 | self._sock = connection._sock 318 | kwargs = { 319 | 'protocolError': InvalidResponse, 320 | 'replyError': self.parse_error, 321 | } 322 | 323 | # hiredis < 0.1.3 doesn't support functions that create exceptions 324 | if not HIREDIS_SUPPORTS_CALLABLE_ERRORS: 325 | kwargs['replyError'] = ResponseError 326 | 327 | if connection.decode_responses: 328 | kwargs['encoding'] = connection.encoding 329 | self._reader = hiredis.Reader(**kwargs) 330 | self._next_response = False 331 | 332 | def on_disconnect(self): 333 | self._sock = None 334 | self._reader = None 335 | self._next_response = False 336 | 337 | def can_read(self): 338 | if not self._reader: 339 | raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) 340 | 341 | if self._next_response is False: 342 | self._next_response = self._reader.gets() 343 | return self._next_response is not False 344 | 345 | def read_response(self): 346 | if not self._reader: 347 | raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) 348 | 349 | # _next_response might be cached from a can_read() call 350 | if self._next_response is not False: 351 | response = self._next_response 352 | self._next_response = False 353 | return response 354 | 355 | response = self._reader.gets() 356 | socket_read_size = self.socket_read_size 357 | while response is False: 358 | try: 359 | if HIREDIS_USE_BYTE_BUFFER: 360 | bufflen = recv_into(self._sock, self._buffer) 361 | if bufflen == 0: 362 | raise socket.error(SERVER_CLOSED_CONNECTION_ERROR) 363 | else: 364 | buffer = recv(self._sock, socket_read_size) 365 | # an empty string indicates the server shutdown the socket 366 | if not isinstance(buffer, bytes) or len(buffer) == 0: 367 | raise socket.error(SERVER_CLOSED_CONNECTION_ERROR) 368 | except socket.timeout: 369 | raise TimeoutError("Timeout reading from socket") 370 | except socket.error: 371 | e = sys.exc_info()[1] 372 | raise ConnectionError("Error while reading from socket: %s" % 373 | (e.args,)) 374 | if HIREDIS_USE_BYTE_BUFFER: 375 | self._reader.feed(self._buffer, 0, bufflen) 376 | else: 377 | self._reader.feed(buffer) 378 | response = self._reader.gets() 379 | # if an older version of hiredis is installed, we need to attempt 380 | # to convert ResponseErrors to their appropriate types. 
381 | if not HIREDIS_SUPPORTS_CALLABLE_ERRORS: 382 | if isinstance(response, ResponseError): 383 | response = self.parse_error(response.args[0]) 384 | elif isinstance(response, list) and response and \ 385 | isinstance(response[0], ResponseError): 386 | response[0] = self.parse_error(response[0].args[0]) 387 | # if the response is a ConnectionError or the response is a list and 388 | # the first item is a ConnectionError, raise it as something bad 389 | # happened 390 | if isinstance(response, ConnectionError): 391 | raise response 392 | elif isinstance(response, list) and response and \ 393 | isinstance(response[0], ConnectionError): 394 | raise response[0] 395 | return response 396 | 397 | if HIREDIS_AVAILABLE: 398 | DefaultParser = HiredisParser 399 | else: 400 | DefaultParser = PythonParser 401 | 402 | 403 | class Connection(object): 404 | "Manages TCP communication to and from a Redis server" 405 | description_format = "Connection" 406 | 407 | def __init__(self, host='localhost', port=6379, db=0, password=None, 408 | socket_timeout=None, socket_connect_timeout=None, 409 | socket_keepalive=False, socket_keepalive_options=None, 410 | retry_on_timeout=False, encoding='utf-8', 411 | encoding_errors='strict', decode_responses=False, 412 | parser_class=DefaultParser, socket_read_size=65536): 413 | self.pid = os.getpid() 414 | self.host = host 415 | self.port = int(port) 416 | self.db = db 417 | self.password = password 418 | self.socket_timeout = socket_timeout 419 | self.socket_connect_timeout = socket_connect_timeout or socket_timeout 420 | self.socket_keepalive = socket_keepalive 421 | self.socket_keepalive_options = socket_keepalive_options or {} 422 | self.retry_on_timeout = retry_on_timeout 423 | self.encoding = encoding 424 | self.encoding_errors = encoding_errors 425 | self.decode_responses = decode_responses 426 | self._sock = None 427 | self._parser = parser_class(socket_read_size=socket_read_size) 428 | self._description_args = { 429 | 'host': self.host, 430 | 'port': self.port, 431 | 'db': self.db, 432 | } 433 | self._connect_callbacks = [] 434 | 435 | def __repr__(self): 436 | return self.description_format % self._description_args 437 | 438 | def __del__(self): 439 | try: 440 | self.disconnect() 441 | except Exception: 442 | pass 443 | 444 | def register_connect_callback(self, callback): 445 | self._connect_callbacks.append(callback) 446 | 447 | def clear_connect_callbacks(self): 448 | self._connect_callbacks = [] 449 | 450 | def connect(self): 451 | "Connects to the Redis server if not already connected" 452 | if self._sock: 453 | return 454 | try: 455 | sock = self._connect() 456 | except socket.timeout: 457 | raise TimeoutError("Timeout connecting to server") 458 | except socket.error: 459 | e = sys.exc_info()[1] 460 | raise ConnectionError(self._error_message(e)) 461 | 462 | self._sock = sock 463 | try: 464 | self.on_connect() 465 | except RedisError: 466 | # clean up after any error in on_connect 467 | self.disconnect() 468 | raise 469 | 470 | # run any user callbacks. 
right now the only internal callback 471 | # is for pubsub channel/pattern resubscription 472 | for callback in self._connect_callbacks: 473 | callback(self) 474 | 475 | def _connect(self): 476 | "Create a TCP socket connection" 477 | # we want to mimic what socket.create_connection does to support 478 | # ipv4/ipv6, but we want to set options prior to calling 479 | # socket.connect() 480 | err = None 481 | for res in socket.getaddrinfo(self.host, self.port, 0, 482 | socket.SOCK_STREAM): 483 | family, socktype, proto, canonname, socket_address = res 484 | sock = None 485 | try: 486 | sock = socket.socket(family, socktype, proto) 487 | # TCP_NODELAY 488 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) 489 | 490 | # TCP_KEEPALIVE 491 | if self.socket_keepalive: 492 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) 493 | for k, v in iteritems(self.socket_keepalive_options): 494 | sock.setsockopt(socket.SOL_TCP, k, v) 495 | 496 | # set the socket_connect_timeout before we connect 497 | sock.settimeout(self.socket_connect_timeout) 498 | 499 | # connect 500 | sock.connect(socket_address) 501 | 502 | # set the socket_timeout now that we're connected 503 | sock.settimeout(self.socket_timeout) 504 | return sock 505 | 506 | except socket.error as _: 507 | err = _ 508 | if sock is not None: 509 | sock.close() 510 | 511 | if err is not None: 512 | raise err 513 | raise socket.error("socket.getaddrinfo returned an empty list") 514 | 515 | def _error_message(self, exception): 516 | # args for socket.error can either be (errno, "message") 517 | # or just "message" 518 | if len(exception.args) == 1: 519 | return "Error connecting to %s:%s. %s." % \ 520 | (self.host, self.port, exception.args[0]) 521 | else: 522 | return "Error %s connecting to %s:%s. %s." % \ 523 | (exception.args[0], self.host, self.port, exception.args[1]) 524 | 525 | def on_connect(self): 526 | "Initialize the connection, authenticate and select a database" 527 | self._parser.on_connect(self) 528 | 529 | # if a password is specified, authenticate 530 | if self.password: 531 | self.send_command('AUTH', self.password) 532 | if nativestr(self.read_response()) != 'OK': 533 | raise AuthenticationError('Invalid Password') 534 | 535 | # if a database is specified, switch to it 536 | if self.db: 537 | self.send_command('SELECT', self.db) 538 | if nativestr(self.read_response()) != 'OK': 539 | raise ConnectionError('Invalid Database') 540 | 541 | def disconnect(self): 542 | "Disconnects from the Redis server" 543 | self._parser.on_disconnect() 544 | if self._sock is None: 545 | return 546 | try: 547 | self._sock.shutdown(socket.SHUT_RDWR) 548 | self._sock.close() 549 | except socket.error: 550 | pass 551 | self._sock = None 552 | 553 | def send_packed_command(self, command): 554 | "Send an already packed command to the Redis server" 555 | if not self._sock: 556 | self.connect() 557 | try: 558 | if isinstance(command, str): 559 | command = [command] 560 | for item in command: 561 | self._sock.sendall(item) 562 | except socket.timeout: 563 | self.disconnect() 564 | raise TimeoutError("Timeout writing to socket") 565 | except socket.error: 566 | e = sys.exc_info()[1] 567 | self.disconnect() 568 | if len(e.args) == 1: 569 | errno, errmsg = 'UNKNOWN', e.args[0] 570 | else: 571 | errno = e.args[0] 572 | errmsg = e.args[1] 573 | raise ConnectionError("Error %s while writing to socket. %s." 
% 574 | (errno, errmsg)) 575 | except: 576 | self.disconnect() 577 | raise 578 | 579 | def send_command(self, *args): 580 | "Pack and send a command to the Redis server" 581 | self.send_packed_command(self.pack_command(*args)) 582 | 583 | def can_read(self, timeout=0): 584 | "Poll the socket to see if there's data that can be read." 585 | sock = self._sock 586 | if not sock: 587 | self.connect() 588 | sock = self._sock 589 | return self._parser.can_read() or \ 590 | bool(select([sock], [], [], timeout)[0]) 591 | 592 | def read_response(self): 593 | "Read the response from a previously sent command" 594 | try: 595 | response = self._parser.read_response() 596 | except: 597 | self.disconnect() 598 | raise 599 | if isinstance(response, ResponseError): 600 | raise response 601 | return response 602 | 603 | def encode(self, value): 604 | "Return a bytestring representation of the value" 605 | if isinstance(value, Token): 606 | return value.encoded_value 607 | elif isinstance(value, bytes): 608 | return value 609 | elif isinstance(value, (int, long)): 610 | value = b(str(value)) 611 | elif isinstance(value, float): 612 | value = b(repr(value)) 613 | elif not isinstance(value, basestring): 614 | value = unicode(value) 615 | if isinstance(value, unicode): 616 | value = value.encode(self.encoding, self.encoding_errors) 617 | return value 618 | 619 | def pack_command(self, *args): 620 | "Pack a series of arguments into the Redis protocol" 621 | output = [] 622 | # the client might have included 1 or more literal arguments in 623 | # the command name, e.g., 'CONFIG GET'. The Redis server expects these 624 | # arguments to be sent separately, so split the first argument 625 | # manually. All of these arguments get wrapped in the Token class 626 | # to prevent them from being encoded.
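Concretely, for a two-word command name the splitting performed below yields three RESP arguments, and a small command comes back from ``pack_command`` as a single chunk. A sketch of the observable output (not an additional API):

```
# Sketch of what the packing logic below produces (CRLF-delimited RESP):
#
#   conn.pack_command('CONFIG GET', 'maxmemory')
#   -> [b'*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$9\r\nmaxmemory\r\n']
#
# '*3' announces three arguments; each '$N' prefixes an N-byte argument.
# Only commands or values larger than 6000 bytes are split into multiple
# chunks by the buffering branch below.
```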
627 | command = args[0] 628 | if ' ' in command: 629 | args = tuple([Token.get_token(s) 630 | for s in command.split()]) + args[1:] 631 | else: 632 | args = (Token.get_token(command),) + args[1:] 633 | 634 | buff = SYM_EMPTY.join( 635 | (SYM_STAR, b(str(len(args))), SYM_CRLF)) 636 | 637 | for arg in imap(self.encode, args): 638 | # to avoid large string mallocs, chunk the command into the 639 | # output list if we're sending large values 640 | if len(buff) > 6000 or len(arg) > 6000: 641 | buff = SYM_EMPTY.join( 642 | (buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF)) 643 | output.append(buff) 644 | output.append(arg) 645 | buff = SYM_CRLF 646 | else: 647 | buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))), 648 | SYM_CRLF, arg, SYM_CRLF)) 649 | output.append(buff) 650 | return output 651 | 652 | def pack_commands(self, commands): 653 | "Pack multiple commands into the Redis protocol" 654 | output = [] 655 | pieces = [] 656 | buffer_length = 0 657 | 658 | for cmd in commands: 659 | for chunk in self.pack_command(*cmd): 660 | pieces.append(chunk) 661 | buffer_length += len(chunk) 662 | 663 | if buffer_length > 6000: 664 | output.append(SYM_EMPTY.join(pieces)) 665 | buffer_length = 0 666 | pieces = [] 667 | 668 | if pieces: 669 | output.append(SYM_EMPTY.join(pieces)) 670 | return output 671 | 672 | 673 | class SSLConnection(Connection): 674 | description_format = "SSLConnection" 675 | 676 | def __init__(self, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None, 677 | ssl_ca_certs=None, **kwargs): 678 | if not ssl_available: 679 | raise RedisError("Python wasn't built with SSL support") 680 | 681 | super(SSLConnection, self).__init__(**kwargs) 682 | 683 | self.keyfile = ssl_keyfile 684 | self.certfile = ssl_certfile 685 | if ssl_cert_reqs is None: 686 | ssl_cert_reqs = ssl.CERT_NONE 687 | elif isinstance(ssl_cert_reqs, basestring): 688 | CERT_REQS = { 689 | 'none': ssl.CERT_NONE, 690 | 'optional': ssl.CERT_OPTIONAL, 691 | 'required': ssl.CERT_REQUIRED 692 | } 693 | if ssl_cert_reqs not in CERT_REQS: 694 | raise RedisError( 695 | "Invalid SSL Certificate Requirements Flag: %s" % 696 | ssl_cert_reqs) 697 | ssl_cert_reqs = CERT_REQS[ssl_cert_reqs] 698 | self.cert_reqs = ssl_cert_reqs 699 | self.ca_certs = ssl_ca_certs 700 | 701 | def _connect(self): 702 | "Wrap the socket with SSL support" 703 | sock = super(SSLConnection, self)._connect() 704 | sock = ssl.wrap_socket(sock, 705 | cert_reqs=self.cert_reqs, 706 | keyfile=self.keyfile, 707 | certfile=self.certfile, 708 | ca_certs=self.ca_certs) 709 | return sock 710 | 711 | 712 | class UnixDomainSocketConnection(Connection): 713 | description_format = "UnixDomainSocketConnection" 714 | 715 | def __init__(self, path='', db=0, password=None, 716 | socket_timeout=None, encoding='utf-8', 717 | encoding_errors='strict', decode_responses=False, 718 | retry_on_timeout=False, 719 | parser_class=DefaultParser, socket_read_size=65536): 720 | self.pid = os.getpid() 721 | self.path = path 722 | self.db = db 723 | self.password = password 724 | self.socket_timeout = socket_timeout 725 | self.retry_on_timeout = retry_on_timeout 726 | self.encoding = encoding 727 | self.encoding_errors = encoding_errors 728 | self.decode_responses = decode_responses 729 | self._sock = None 730 | self._parser = parser_class(socket_read_size=socket_read_size) 731 | self._description_args = { 732 | 'path': self.path, 733 | 'db': self.db, 734 | } 735 | self._connect_callbacks = [] 736 | 737 | def _connect(self): 738 | "Create a Unix domain socket connection" 739 | sock = 
socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) 740 | sock.settimeout(self.socket_timeout) 741 | sock.connect(self.path) 742 | return sock 743 | 744 | def _error_message(self, exception): 745 | # args for socket.error can either be (errno, "message") 746 | # or just "message" 747 | if len(exception.args) == 1: 748 | return "Error connecting to unix socket: %s. %s." % \ 749 | (self.path, exception.args[0]) 750 | else: 751 | return "Error %s connecting to unix socket: %s. %s." % \ 752 | (exception.args[0], self.path, exception.args[1]) 753 | 754 | 755 | FALSE_STRINGS = ('0', 'F', 'FALSE', 'N', 'NO') 756 | 757 | 758 | def to_bool(value): 759 | if value is None or value == '': 760 | return None 761 | if isinstance(value, basestring) and value.upper() in FALSE_STRINGS: 762 | return False 763 | return bool(value) 764 | 765 | 766 | URL_QUERY_ARGUMENT_PARSERS = { 767 | 'socket_timeout': float, 768 | 'socket_connect_timeout': float, 769 | 'socket_keepalive': to_bool, 770 | 'retry_on_timeout': to_bool 771 | } 772 | 773 | 774 | class ConnectionPool(object): 775 | "Generic connection pool" 776 | @classmethod 777 | def from_url(cls, url, db=None, decode_components=False, **kwargs): 778 | """ 779 | Return a connection pool configured from the given URL. 780 | 781 | For example:: 782 | 783 | redis://[:password]@localhost:6379/0 784 | rediss://[:password]@localhost:6379/0 785 | unix://[:password]@/path/to/socket.sock?db=0 786 | 787 | Three URL schemes are supported: 788 | 789 | - ```redis://`` 790 | <http://www.iana.org/assignments/uri-schemes/prov/redis>`_ creates a 791 | normal TCP socket connection 792 | - ```rediss://`` 793 | <http://www.iana.org/assignments/uri-schemes/prov/rediss>`_ creates a 794 | SSL wrapped TCP socket connection 795 | - ``unix://`` creates a Unix Domain Socket connection 796 | 797 | There are several ways to specify a database number. The parse function 798 | will return the first specified option: 799 | 1. A ``db`` querystring option, e.g. redis://localhost?db=0 800 | 2. If using the redis:// scheme, the path argument of the url, e.g. 801 | redis://localhost/0 802 | 3. The ``db`` argument to this function. 803 | 804 | If none of these options are specified, db=0 is used. 805 | 806 | The ``decode_components`` argument allows this function to work with 807 | percent-encoded URLs. If this argument is set to ``True`` all ``%xx`` 808 | escapes will be replaced by their single-character equivalents after 809 | the URL has been parsed. This only applies to the ``hostname``, 810 | ``path``, and ``password`` components. 811 | 812 | Any additional querystring arguments and keyword arguments will be 813 | passed along to the ConnectionPool class's initializer. The querystring 814 | arguments ``socket_connect_timeout`` and ``socket_timeout``, if supplied, 815 | are parsed as float values. The arguments ``socket_keepalive`` and 816 | ``retry_on_timeout`` are parsed to boolean values that accept 817 | True/False, Yes/No values to indicate state. Invalid types cause a 818 | ``UserWarning`` to be raised. In the case of conflicting arguments, 819 | querystring arguments always win. 820 | """ 821 | url_string = url 822 | url = urlparse(url) 823 | qs = '' 824 | 825 | # in python2.6, custom URL schemes don't recognize querystring values; 826 | # they're left as part of the url.path. 827 | if '?' in url.path and not url.query: 828 | # chop the querystring including the ? off the end of the url 829 | # and reparse it.
830 | qs = url.path.split('?', 1)[1] 831 | url = urlparse(url_string[:-(len(qs) + 1)]) 832 | else: 833 | qs = url.query 834 | 835 | url_options = {} 836 | 837 | for name, value in iteritems(parse_qs(qs)): 838 | if value and len(value) > 0: 839 | parser = URL_QUERY_ARGUMENT_PARSERS.get(name) 840 | if parser: 841 | try: 842 | url_options[name] = parser(value[0]) 843 | except (TypeError, ValueError): 844 | warnings.warn(UserWarning( 845 | "Invalid value for `%s` in connection URL." % name 846 | )) 847 | else: 848 | url_options[name] = value[0] 849 | 850 | if decode_components: 851 | password = unquote(url.password) if url.password else None 852 | path = unquote(url.path) if url.path else None 853 | hostname = unquote(url.hostname) if url.hostname else None 854 | else: 855 | password = url.password 856 | path = url.path 857 | hostname = url.hostname 858 | 859 | # We only support redis:// and unix:// schemes. 860 | if url.scheme == 'unix': 861 | url_options.update({ 862 | 'password': password, 863 | 'path': path, 864 | 'connection_class': UnixDomainSocketConnection, 865 | }) 866 | 867 | else: 868 | url_options.update({ 869 | 'host': hostname, 870 | 'port': int(url.port or 6379), 871 | 'password': password, 872 | }) 873 | 874 | # If there's a path argument, use it as the db argument if a 875 | # querystring value wasn't specified 876 | if 'db' not in url_options and path: 877 | try: 878 | url_options['db'] = int(path.replace('/', '')) 879 | except (AttributeError, ValueError): 880 | pass 881 | 882 | if url.scheme == 'rediss': 883 | url_options['connection_class'] = SSLConnection 884 | 885 | # last shot at the db value 886 | url_options['db'] = int(url_options.get('db', db or 0)) 887 | 888 | # update the arguments from the URL values 889 | kwargs.update(url_options) 890 | 891 | # backwards compatibility 892 | if 'charset' in kwargs: 893 | warnings.warn(DeprecationWarning( 894 | '"charset" is deprecated. Use "encoding" instead')) 895 | kwargs['encoding'] = kwargs.pop('charset') 896 | if 'errors' in kwargs: 897 | warnings.warn(DeprecationWarning( 898 | '"errors" is deprecated. Use "encoding_errors" instead')) 899 | kwargs['encoding_errors'] = kwargs.pop('errors') 900 | 901 | return cls(**kwargs) 902 | 903 | def __init__(self, connection_class=Connection, max_connections=None, 904 | **connection_kwargs): 905 | """ 906 | Create a connection pool. If max_connections is set, then this 907 | object raises redis.ConnectionError when the pool's limit is reached. 908 | 909 | By default, TCP connections are created unless connection_class is 910 | specified. Use redis.UnixDomainSocketConnection for unix sockets. 911 | 912 | Any additional keyword arguments are passed to the constructor of 913 | connection_class.
914 | """ 915 | max_connections = max_connections or 2 ** 31 916 | if not isinstance(max_connections, (int, long)) or max_connections < 0: 917 | raise ValueError('"max_connections" must be a positive integer') 918 | 919 | self.connection_class = connection_class 920 | self.connection_kwargs = connection_kwargs 921 | self.max_connections = max_connections 922 | 923 | self.reset() 924 | 925 | def __repr__(self): 926 | return "%s<%s>" % ( 927 | type(self).__name__, 928 | self.connection_class.description_format % self.connection_kwargs, 929 | ) 930 | 931 | def reset(self): 932 | self.pid = os.getpid() 933 | self._created_connections = 0 934 | self._available_connections = [] 935 | self._in_use_connections = set() 936 | self._check_lock = threading.Lock() 937 | 938 | def _checkpid(self): 939 | if self.pid != os.getpid(): 940 | with self._check_lock: 941 | if self.pid == os.getpid(): 942 | # another thread already did the work while we waited 943 | # on the lock. 944 | return 945 | self.disconnect() 946 | self.reset() 947 | 948 | def get_connection(self, command_name, *keys, **options): 949 | "Get a connection from the pool" 950 | self._checkpid() 951 | try: 952 | connection = self._available_connections.pop() 953 | except IndexError: 954 | connection = self.make_connection() 955 | self._in_use_connections.add(connection) 956 | return connection 957 | 958 | def make_connection(self): 959 | "Create a new connection" 960 | if self._created_connections >= self.max_connections: 961 | raise ConnectionError("Too many connections") 962 | self._created_connections += 1 963 | return self.connection_class(**self.connection_kwargs) 964 | 965 | def release(self, connection): 966 | "Releases the connection back to the pool" 967 | self._checkpid() 968 | if connection.pid != self.pid: 969 | return 970 | self._in_use_connections.remove(connection) 971 | self._available_connections.append(connection) 972 | 973 | def disconnect(self): 974 | "Disconnects all connections in the pool" 975 | all_conns = chain(self._available_connections, 976 | self._in_use_connections) 977 | for connection in all_conns: 978 | connection.disconnect() 979 | 980 | 981 | class BlockingConnectionPool(ConnectionPool): 982 | """ 983 | Thread-safe blocking connection pool:: 984 | 985 | >>> from redis.client import Redis 986 | >>> client = Redis(connection_pool=BlockingConnectionPool()) 987 | 988 | It performs the same function as the default 989 | :py:class:`~redis.connection.ConnectionPool` implementation, in that 990 | it maintains a pool of reusable connections that can be shared by 991 | multiple redis clients (safely across threads if required). 992 | 993 | The difference is that, in the event that a client tries to get a 994 | connection from the pool when all of the connections are in use, rather than 995 | raising a :py:class:`~redis.exceptions.ConnectionError` (as the default 996 | :py:class:`~redis.connection.ConnectionPool` implementation does), it 997 | makes the client wait ("blocks") for a specified number of seconds until 998 | a connection becomes available. 999 | 1000 | Use ``max_connections`` to increase / decrease the pool size:: 1001 | 1002 | >>> pool = BlockingConnectionPool(max_connections=10) 1003 | 1004 | Use ``timeout`` to tell it either how many seconds to wait for a connection 1005 | to become available, or to block forever: 1006 | 1007 | # Block forever.
1008 | >>> pool = BlockingConnectionPool(timeout=None) 1009 | 1010 | # Raise a ``ConnectionError`` after five seconds if a connection is 1011 | # not available. 1012 | >>> pool = BlockingConnectionPool(timeout=5) 1013 | """ 1014 | def __init__(self, max_connections=50, timeout=20, 1015 | connection_class=Connection, queue_class=LifoQueue, 1016 | **connection_kwargs): 1017 | 1018 | self.queue_class = queue_class 1019 | self.timeout = timeout 1020 | super(BlockingConnectionPool, self).__init__( 1021 | connection_class=connection_class, 1022 | max_connections=max_connections, 1023 | **connection_kwargs) 1024 | 1025 | def reset(self): 1026 | self.pid = os.getpid() 1027 | self._check_lock = threading.Lock() 1028 | 1029 | # Create and fill up a thread safe queue with ``None`` values. 1030 | self.pool = self.queue_class(self.max_connections) 1031 | while True: 1032 | try: 1033 | self.pool.put_nowait(None) 1034 | except Full: 1035 | break 1036 | 1037 | # Keep a list of actual connection instances so that we can 1038 | # disconnect them later. 1039 | self._connections = [] 1040 | 1041 | def make_connection(self): 1042 | "Make a fresh connection." 1043 | connection = self.connection_class(**self.connection_kwargs) 1044 | self._connections.append(connection) 1045 | return connection 1046 | 1047 | def get_connection(self, command_name, *keys, **options): 1048 | """ 1049 | Get a connection, blocking for ``self.timeout`` until a connection 1050 | is available from the pool. 1051 | 1052 | If the connection returned is ``None`` then creates a new connection. 1053 | Because we use a last-in first-out queue, the existing connections 1054 | (having been returned to the pool after the initial ``None`` values 1055 | were added) will be returned before ``None`` values. This means we only 1056 | create new connections when we need to, i.e., the actual number of 1057 | connections will only increase in response to demand. 1058 | """ 1059 | # Make sure we haven't changed process. 1060 | self._checkpid() 1061 | 1062 | # Try and get a connection from the pool. If one isn't available within 1063 | # self.timeout then raise a ``ConnectionError``. 1064 | connection = None 1065 | try: 1066 | connection = self.pool.get(block=True, timeout=self.timeout) 1067 | except Empty: 1068 | # Note that this is not caught by the redis client and will be 1069 | # raised unless handled by application code (use ``timeout=None`` to block forever instead). 1070 | raise ConnectionError("No connection available.") 1071 | 1072 | # If the ``connection`` is actually ``None`` then that's a cue to make 1073 | # a new connection to add to the pool. 1074 | if connection is None: 1075 | connection = self.make_connection() 1076 | 1077 | return connection 1078 | 1079 | def release(self, connection): 1080 | "Releases the connection back to the pool." 1081 | # Make sure we haven't changed process. 1082 | self._checkpid() 1083 | if connection.pid != self.pid: 1084 | return 1085 | 1086 | # Put the connection back into the pool. 1087 | try: 1088 | self.pool.put_nowait(connection) 1089 | except Full: 1090 | # perhaps the pool has been reset() after a fork? regardless, 1091 | # we don't want this connection 1092 | pass 1093 | 1094 | def disconnect(self): 1095 | "Disconnects all connections in the pool."
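Wiring the pieces together: a capped, blocking pool that makes an extra concurrent caller wait for a slot instead of failing immediately. A usage sketch of the classes above, assuming a local Redis server:

```
# Usage sketch: a bounded, blocking pool shared by client code.
# With ten slots and timeout=5, an eleventh concurrent checkout waits up
# to five seconds for a slot, then raises ConnectionError.
from redis.client import Redis
from redis.connection import BlockingConnectionPool

pool = BlockingConnectionPool(max_connections=10, timeout=5,
                              host='localhost', port=6379)
client = Redis(connection_pool=pool)
client.set('foo', 'bar')   # checks a connection out and back in
print(client.get('foo'))   # b'bar'
```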
1096 | for connection in self._connections: 1097 | connection.disconnect() 1098 | -------------------------------------------------------------------------------- /redis/exceptions.py: -------------------------------------------------------------------------------- 1 | "Core exceptions raised by the Redis client" 2 | from redis._compat import unicode 3 | 4 | 5 | class RedisError(Exception): 6 | pass 7 | 8 | 9 | # python 2.5 doesn't implement Exception.__unicode__. Add it here to all 10 | # our exception types 11 | if not hasattr(RedisError, '__unicode__'): 12 | def __unicode__(self): 13 | if isinstance(self.args[0], unicode): 14 | return self.args[0] 15 | return unicode(self.args[0]) 16 | RedisError.__unicode__ = __unicode__ 17 | 18 | 19 | class AuthenticationError(RedisError): 20 | pass 21 | 22 | 23 | class ConnectionError(RedisError): 24 | pass 25 | 26 | 27 | class TimeoutError(RedisError): 28 | pass 29 | 30 | 31 | class BusyLoadingError(ConnectionError): 32 | pass 33 | 34 | 35 | class InvalidResponse(RedisError): 36 | pass 37 | 38 | 39 | class ResponseError(RedisError): 40 | pass 41 | 42 | 43 | class DataError(RedisError): 44 | pass 45 | 46 | 47 | class PubSubError(RedisError): 48 | pass 49 | 50 | 51 | class WatchError(RedisError): 52 | pass 53 | 54 | 55 | class NoScriptError(ResponseError): 56 | pass 57 | 58 | 59 | class ExecAbortError(ResponseError): 60 | pass 61 | 62 | 63 | class ReadOnlyError(ResponseError): 64 | pass 65 | 66 | 67 | class LockError(RedisError, ValueError): 68 | "Errors acquiring or releasing a lock" 69 | # NOTE: For backwards compatibility, this class derives from ValueError. 70 | # This was originally chosen to behave like threading.Lock. 71 | pass 72 | -------------------------------------------------------------------------------- /redis/lock.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time as mod_time 3 | import uuid 4 | from redis.exceptions import LockError, WatchError 5 | from redis.utils import dummy 6 | from redis._compat import b 7 | 8 | 9 | class Lock(object): 10 | """ 11 | A shared, distributed Lock. Using Redis for locking allows the Lock 12 | to be shared across processes and/or machines. 13 | 14 | It's left to the user to resolve deadlock issues and make sure 15 | multiple clients play nicely together. 16 | """ 17 | def __init__(self, redis, name, timeout=None, sleep=0.1, 18 | blocking=True, blocking_timeout=None, thread_local=True): 19 | """ 20 | Create a new Lock instance named ``name`` using the Redis client 21 | supplied by ``redis``. 22 | 23 | ``timeout`` indicates a maximum life for the lock. 24 | By default, it will remain locked until release() is called. 25 | ``timeout`` can be specified as a float or integer, both representing 26 | the number of seconds to wait. 27 | 28 | ``sleep`` indicates the amount of time to sleep per loop iteration 29 | when the lock is in blocking mode and another client is currently 30 | holding the lock. 31 | 32 | ``blocking`` indicates whether calling ``acquire`` should block until 33 | the lock has been acquired or to fail immediately, causing ``acquire`` 34 | to return False and the lock not being acquired. Defaults to True. 35 | Note this value can be overridden by passing a ``blocking`` 36 | argument to ``acquire``. 37 | 38 | ``blocking_timeout`` indicates the maximum amount of time in seconds to 39 | spend trying to acquire the lock. A value of ``None`` indicates 40 | continue trying forever.
``blocking_timeout`` can be specified as a 41 | float or integer, both representing the number of seconds to wait. 42 | 43 | ``thread_local`` indicates whether the lock token is placed in 44 | thread-local storage. By default, the token is placed in thread local 45 | storage so that a thread only sees its token, not a token set by 46 | another thread. Consider the following timeline: 47 | 48 | time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. 49 | thread-1 sets the token to "abc" 50 | time: 1, thread-2 blocks trying to acquire `my-lock` using the 51 | Lock instance. 52 | time: 5, thread-1 has not yet completed. redis expires the lock 53 | key. 54 | time: 5, thread-2 acquires `my-lock` now that it's available. 55 | thread-2 sets the token to "xyz" 56 | time: 6, thread-1 finishes its work and calls release(). if the 57 | token is *not* stored in thread local storage, then 58 | thread-1 would see the token value as "xyz" and would be 59 | able to successfully release thread-2's lock. 60 | 61 | In some use cases it's necessary to disable thread local storage. For 62 | example, one thread may acquire a lock and pass that lock instance to 63 | a worker thread to release later. If thread 64 | local storage isn't disabled in this case, the worker thread won't see 65 | the token set by the thread that acquired the lock. Our assumption 66 | is that these cases aren't common and, as such, we default to using 67 | thread local storage. 68 | """ 69 | self.redis = redis 70 | self.name = name 71 | self.timeout = timeout 72 | self.sleep = sleep 73 | self.blocking = blocking 74 | self.blocking_timeout = blocking_timeout 75 | self.thread_local = bool(thread_local) 76 | self.local = threading.local() if self.thread_local else dummy() 77 | self.local.token = None 78 | if self.timeout and self.sleep > self.timeout: 79 | raise LockError("'sleep' must be less than 'timeout'") 80 | 81 | def __enter__(self): 82 | # force blocking, as otherwise the user would have to check whether 83 | # the lock was actually acquired or not. 84 | self.acquire(blocking=True) 85 | return self 86 | 87 | def __exit__(self, exc_type, exc_value, traceback): 88 | self.release() 89 | 90 | def acquire(self, blocking=None, blocking_timeout=None): 91 | """ 92 | Use Redis to hold a shared, distributed lock named ``name``. 93 | Returns True once the lock is acquired. 94 | 95 | If ``blocking`` is False, always return immediately. If the lock 96 | was acquired, return True, otherwise return False. 97 | 98 | ``blocking_timeout`` specifies the maximum number of seconds to 99 | wait trying to acquire the lock.
100 | """ 101 | sleep = self.sleep 102 | token = b(uuid.uuid1().hex) 103 | if blocking is None: 104 | blocking = self.blocking 105 | if blocking_timeout is None: 106 | blocking_timeout = self.blocking_timeout 107 | stop_trying_at = None 108 | if blocking_timeout is not None: 109 | stop_trying_at = mod_time.time() + blocking_timeout 110 | while 1: 111 | if self.do_acquire(token): 112 | self.local.token = token 113 | return True 114 | if not blocking: 115 | return False 116 | if stop_trying_at is not None and mod_time.time() > stop_trying_at: 117 | return False 118 | mod_time.sleep(sleep) 119 | 120 | def do_acquire(self, token): 121 | if self.redis.setnx(self.name, token): 122 | if self.timeout: 123 | # convert to milliseconds 124 | timeout = int(self.timeout * 1000) 125 | self.redis.pexpire(self.name, timeout) 126 | return True 127 | return False 128 | 129 | def release(self): 130 | "Releases the already acquired lock" 131 | expected_token = self.local.token 132 | if expected_token is None: 133 | raise LockError("Cannot release an unlocked lock") 134 | self.local.token = None 135 | self.do_release(expected_token) 136 | 137 | def do_release(self, expected_token): 138 | name = self.name 139 | 140 | def execute_release(pipe): 141 | lock_value = pipe.get(name) 142 | if lock_value != expected_token: 143 | raise LockError("Cannot release a lock that's no longer owned") 144 | pipe.delete(name) 145 | 146 | self.redis.transaction(execute_release, name) 147 | 148 | def extend(self, additional_time): 149 | """ 150 | Adds more time to an already acquired lock. 151 | 152 | ``additional_time`` can be specified as an integer or a float, both 153 | representing the number of seconds to add. 154 | """ 155 | if self.local.token is None: 156 | raise LockError("Cannot extend an unlocked lock") 157 | if self.timeout is None: 158 | raise LockError("Cannot extend a lock with no timeout") 159 | return self.do_extend(additional_time) 160 | 161 | def do_extend(self, additional_time): 162 | pipe = self.redis.pipeline() 163 | pipe.watch(self.name) 164 | lock_value = pipe.get(self.name) 165 | if lock_value != self.local.token: 166 | raise LockError("Cannot extend a lock that's no longer owned") 167 | expiration = pipe.pttl(self.name) 168 | if expiration is None or expiration < 0: 169 | # Redis evicted the lock key between the previous get() and now 170 | # we'll handle this when we call pexpire() 171 | expiration = 0 172 | pipe.multi() 173 | pipe.pexpire(self.name, expiration + int(additional_time * 1000)) 174 | 175 | try: 176 | response = pipe.execute() 177 | except WatchError: 178 | # someone else acquired the lock 179 | raise LockError("Cannot extend a lock that's no longer owned") 180 | if not response[0]: 181 | # pexpire returns False if the key doesn't exist 182 | raise LockError("Cannot extend a lock that's no longer owned") 183 | return True 184 | 185 | 186 | class LuaLock(Lock): 187 | """ 188 | A lock implementation that uses Lua scripts rather than pipelines 189 | and watches. 
190 |     """
191 |     lua_acquire = None
192 |     lua_release = None
193 |     lua_extend = None
194 | 
195 |     # KEYS[1] - lock name
196 |     # ARGV[1] - token
197 |     # ARGV[2] - timeout in milliseconds
198 |     # return 1 if lock was acquired, otherwise 0
199 |     LUA_ACQUIRE_SCRIPT = """
200 |         if redis.call('setnx', KEYS[1], ARGV[1]) == 1 then
201 |             if ARGV[2] ~= '' then
202 |                 redis.call('pexpire', KEYS[1], ARGV[2])
203 |             end
204 |             return 1
205 |         end
206 |         return 0
207 |     """
208 | 
209 |     # KEYS[1] - lock name
210 |     # ARGV[1] - token
211 |     # return 1 if the lock was released, otherwise 0
212 |     LUA_RELEASE_SCRIPT = """
213 |         local token = redis.call('get', KEYS[1])
214 |         if not token or token ~= ARGV[1] then
215 |             return 0
216 |         end
217 |         redis.call('del', KEYS[1])
218 |         return 1
219 |     """
220 | 
221 |     # KEYS[1] - lock name
222 |     # ARGV[1] - token
223 |     # ARGV[2] - additional milliseconds
224 |     # return 1 if the lock's time was extended, otherwise 0
225 |     LUA_EXTEND_SCRIPT = """
226 |         local token = redis.call('get', KEYS[1])
227 |         if not token or token ~= ARGV[1] then
228 |             return 0
229 |         end
230 |         local expiration = redis.call('pttl', KEYS[1])
231 |         if not expiration then
232 |             expiration = 0
233 |         end
234 |         if expiration < 0 then
235 |             return 0
236 |         end
237 |         redis.call('pexpire', KEYS[1], expiration + ARGV[2])
238 |         return 1
239 |     """
240 | 
241 |     def __init__(self, *args, **kwargs):
242 |         super(LuaLock, self).__init__(*args, **kwargs)
243 |         LuaLock.register_scripts(self.redis)
244 | 
245 |     @classmethod
246 |     def register_scripts(cls, redis):
247 |         if cls.lua_acquire is None:
248 |             cls.lua_acquire = redis.register_script(cls.LUA_ACQUIRE_SCRIPT)
249 |         if cls.lua_release is None:
250 |             cls.lua_release = redis.register_script(cls.LUA_RELEASE_SCRIPT)
251 |         if cls.lua_extend is None:
252 |             cls.lua_extend = redis.register_script(cls.LUA_EXTEND_SCRIPT)
253 | 
254 |     def do_acquire(self, token):
255 |         timeout = self.timeout and int(self.timeout * 1000) or ''
256 |         return bool(self.lua_acquire(keys=[self.name],
257 |                                       args=[token, timeout],
258 |                                       client=self.redis))
259 | 
260 |     def do_release(self, expected_token):
261 |         if not bool(self.lua_release(keys=[self.name],
262 |                                      args=[expected_token],
263 |                                      client=self.redis)):
264 |             raise LockError("Cannot release a lock that's no longer owned")
265 | 
266 |     def do_extend(self, additional_time):
267 |         additional_time = int(additional_time * 1000)
268 |         if not bool(self.lua_extend(keys=[self.name],
269 |                                     args=[self.local.token, additional_time],
270 |                                     client=self.redis)):
271 |             raise LockError("Cannot extend a lock that's no longer owned")
272 |         return True
273 | 
--------------------------------------------------------------------------------
/redis/sentinel.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import weakref
4 | 
5 | from redis.client import StrictRedis
6 | from redis.connection import ConnectionPool, Connection
7 | from redis.exceptions import (ConnectionError, ResponseError, ReadOnlyError,
8 |                               TimeoutError)
9 | from redis._compat import iteritems, nativestr, xrange
10 | 
11 | 
12 | class MasterNotFoundError(ConnectionError):
13 |     pass
14 | 
15 | 
16 | class SlaveNotFoundError(ConnectionError):
17 |     pass
18 | 
19 | 
20 | class SentinelManagedConnection(Connection):
21 |     def __init__(self, **kwargs):
22 |         self.connection_pool = kwargs.pop('connection_pool')
23 |         super(SentinelManagedConnection, self).__init__(**kwargs)
24 | 
25 |     def __repr__(self):
26 |         pool = self.connection_pool
27 |         s = '%s<service=%s%%s>' % (type(self).__name__, pool.service_name)
28 |         if self.host:
29 |             host_info = ',host=%s,port=%s' % (self.host, self.port)
30 |             s = s % host_info
31 |         return s
32 | 
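    # NOTE: illustrative summary -- connect() below picks an address based
    # on the owning pool: a master pool resolves a single address through
    # get_master_address(), while a slave pool walks rotate_slaves() and
    # takes the first slave that accepts a connection (falling back to the
    # master if none do).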
33 |     def connect_to(self, address):
34 |         self.host, self.port = address
35 |         super(SentinelManagedConnection, self).connect()
36 |         if self.connection_pool.check_connection:
37 |             self.send_command('PING')
38 |             if nativestr(self.read_response()) != 'PONG':
39 |                 raise ConnectionError('PING failed')
40 | 
41 |     def connect(self):
42 |         if self._sock:
43 |             return  # already connected
44 |         if self.connection_pool.is_master:
45 |             self.connect_to(self.connection_pool.get_master_address())
46 |         else:
47 |             for slave in self.connection_pool.rotate_slaves():
48 |                 try:
49 |                     return self.connect_to(slave)
50 |                 except ConnectionError:
51 |                     continue
52 |             raise SlaveNotFoundError  # should never be reached
53 | 
54 |     def read_response(self):
55 |         try:
56 |             return super(SentinelManagedConnection, self).read_response()
57 |         except ReadOnlyError:
58 |             if self.connection_pool.is_master:
59 |                 # When talking to a master, a ReadOnlyError likely
60 |                 # indicates that the previous master that we're still connected
61 |                 # to has been demoted to a slave and there's a new master.
62 |                 # Calling disconnect will force the connection to re-query
63 |                 # sentinel during the next connect() attempt.
64 |                 self.disconnect()
65 |                 raise ConnectionError('The previous master is now a slave')
66 |             raise
67 | 
68 | 
69 | class SentinelConnectionPool(ConnectionPool):
70 |     """
71 |     Sentinel backed connection pool.
72 | 
73 |     If the ``check_connection`` flag is set to True, SentinelManagedConnection
74 |     sends a PING command right after establishing the connection.
75 |     """
76 | 
77 |     def __init__(self, service_name, sentinel_manager, **kwargs):
78 |         kwargs['connection_class'] = kwargs.get(
79 |             'connection_class', SentinelManagedConnection)
80 |         self.is_master = kwargs.pop('is_master', True)
81 |         self.check_connection = kwargs.pop('check_connection', False)
82 |         super(SentinelConnectionPool, self).__init__(**kwargs)
83 |         self.connection_kwargs['connection_pool'] = weakref.proxy(self)
84 |         self.service_name = service_name
85 |         self.sentinel_manager = sentinel_manager
86 | 
87 |     def __repr__(self):
88 |         return "%s<service=%s(%s)>" % (
89 |             type(self).__name__,
90 |             self.service_name,
91 |             self.is_master and 'master' or 'slave',
92 |         )
93 | 
94 |     def reset(self):
95 |         super(SentinelConnectionPool, self).reset()
96 |         self.master_address = None
97 |         self.slave_rr_counter = None
98 | 
99 |     def get_master_address(self):
100 |         master_address = self.sentinel_manager.discover_master(
101 |             self.service_name)
102 |         if self.is_master:
103 |             if self.master_address is None:
104 |                 self.master_address = master_address
105 |             elif master_address != self.master_address:
106 |                 # Master address changed, disconnect all clients in this pool
107 |                 self.disconnect()
108 |         return master_address
109 | 
110 |     def rotate_slaves(self):
111 |         "Round-robin slave balancer"
112 |         slaves = self.sentinel_manager.discover_slaves(self.service_name)
113 |         if slaves:
114 |             if self.slave_rr_counter is None:
115 |                 self.slave_rr_counter = random.randint(0, len(slaves) - 1)
116 |             for _ in xrange(len(slaves)):
117 |                 self.slave_rr_counter = (
118 |                     self.slave_rr_counter + 1) % len(slaves)
119 |                 slave = slaves[self.slave_rr_counter]
120 |                 yield slave
121 |         # Fallback to the master connection
122 |         try:
123 |             yield self.get_master_address()
124 |         except MasterNotFoundError:
125 |             pass
126 |         raise SlaveNotFoundError('No slave found for %r' % (self.service_name))
127 | 
128 |     def _checkpid(self):
129 |         if self.pid != os.getpid():
130 |             self.disconnect()
131 |             self.reset()
132 |             self.__init__(self.service_name, self.sentinel_manager,
133 |                           is_master=self.is_master,
134 |                           check_connection=self.check_connection,
135 |                           connection_class=self.connection_class,
136 |                           max_connections=self.max_connections,
137 |                           **self.connection_kwargs)
138 | 
139 | 
140 | class Sentinel(object):
141 |     """
142 |     Redis Sentinel cluster client
143 | 
144 |     >>> from redis.sentinel import Sentinel
145 |     >>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1)
146 |     >>> master = sentinel.master_for('mymaster', socket_timeout=0.1)
147 |     >>> master.set('foo', 'bar')
148 |     >>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1)
149 |     >>> slave.get('foo')
150 |     'bar'
151 | 
152 |     ``sentinels`` is a list of sentinel nodes. Each node is represented by
153 |     a pair (hostname, port).
154 | 
155 |     ``min_other_sentinels`` defines a minimum number of peers for a sentinel.
156 |     When querying a sentinel, if it doesn't meet this threshold, responses
157 |     from that sentinel won't be considered valid.
158 | 
159 |     ``sentinel_kwargs`` is a dictionary of connection arguments used when
160 |     connecting to sentinel instances. Any argument that can be passed to
161 |     a normal Redis connection can be specified here. If ``sentinel_kwargs`` is
162 |     not specified, any socket_timeout and socket_keepalive options specified
163 |     in ``connection_kwargs`` will be used.
164 | 
165 |     ``connection_kwargs`` are keyword arguments that will be used when
166 |     establishing a connection to a Redis server.
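
    A sketch of the remaining options (the addresses and the 'mymaster'
    service name below are hypothetical):

    >>> sentinel = Sentinel([('10.0.0.1', 26379), ('10.0.0.2', 26379)],
    ...                     min_other_sentinels=1,
    ...                     sentinel_kwargs={'socket_timeout': 0.1})
    >>> sentinel.discover_master('mymaster')
    ('10.0.0.3', 6379)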
167 |     """
168 | 
169 |     def __init__(self, sentinels, min_other_sentinels=0, sentinel_kwargs=None,
170 |                  **connection_kwargs):
171 |         # if sentinel_kwargs isn't defined, use the socket_* options from
172 |         # connection_kwargs
173 |         if sentinel_kwargs is None:
174 |             sentinel_kwargs = dict([(k, v)
175 |                                     for k, v in iteritems(connection_kwargs)
176 |                                     if k.startswith('socket_')
177 |                                     ])
178 |         self.sentinel_kwargs = sentinel_kwargs
179 | 
180 |         self.sentinels = [StrictRedis(hostname, port, **self.sentinel_kwargs)
181 |                           for hostname, port in sentinels]
182 |         self.min_other_sentinels = min_other_sentinels
183 |         self.connection_kwargs = connection_kwargs
184 | 
185 |     def __repr__(self):
186 |         sentinel_addresses = []
187 |         for sentinel in self.sentinels:
188 |             sentinel_addresses.append('%s:%s' % (
189 |                 sentinel.connection_pool.connection_kwargs['host'],
190 |                 sentinel.connection_pool.connection_kwargs['port'],
191 |             ))
192 |         return '%s<sentinels=[%s]>' % (
193 |             type(self).__name__,
194 |             ','.join(sentinel_addresses))
195 | 
196 |     def check_master_state(self, state, service_name):
197 |         if not state['is_master'] or state['is_sdown'] or state['is_odown']:
198 |             return False
199 |         # Check if our sentinel doesn't see other nodes
200 |         if state['num-other-sentinels'] < self.min_other_sentinels:
201 |             return False
202 |         return True
203 | 
204 |     def discover_master(self, service_name):
205 |         """
206 |         Asks sentinel servers for the Redis master's address corresponding
207 |         to the service labeled ``service_name``.
208 | 
209 |         Returns a pair (address, port) or raises MasterNotFoundError if no
210 |         master is found.
211 |         """
212 |         for sentinel_no, sentinel in enumerate(self.sentinels):
213 |             try:
214 |                 masters = sentinel.sentinel_masters()
215 |             except (ConnectionError, TimeoutError):
216 |                 continue
217 |             state = masters.get(service_name)
218 |             if state and self.check_master_state(state, service_name):
219 |                 # Put this sentinel at the top of the list
220 |                 self.sentinels[0], self.sentinels[sentinel_no] = (
221 |                     sentinel, self.sentinels[0])
222 |                 return state['ip'], state['port']
223 |         raise MasterNotFoundError("No master found for %r" % (service_name,))
224 | 
225 |     def filter_slaves(self, slaves):
226 |         "Remove slaves that are in an ODOWN or SDOWN state"
227 |         slaves_alive = []
228 |         for slave in slaves:
229 |             if slave['is_odown'] or slave['is_sdown']:
230 |                 continue
231 |             slaves_alive.append((slave['ip'], slave['port']))
232 |         return slaves_alive
233 | 
234 |     def discover_slaves(self, service_name):
235 |         "Returns a list of alive slaves for service ``service_name``"
236 |         for sentinel in self.sentinels:
237 |             try:
238 |                 slaves = sentinel.sentinel_slaves(service_name)
239 |             except (ConnectionError, ResponseError, TimeoutError):
240 |                 continue
241 |             slaves = self.filter_slaves(slaves)
242 |             if slaves:
243 |                 return slaves
244 |         return []
245 | 
246 |     def master_for(self, service_name, redis_class=StrictRedis,
247 |                    connection_pool_class=SentinelConnectionPool, **kwargs):
248 |         """
249 |         Returns a redis client instance for the ``service_name`` master.
250 | 
251 |         A SentinelConnectionPool class is used to retrieve the master's
252 |         address before establishing a new connection.
253 | 
254 |         NOTE: If the master's address has changed, any cached connections to
255 |         the old master are closed.
256 | 
257 |         By default clients will be a redis.StrictRedis instance. Specify a
258 |         different class to the ``redis_class`` argument if you desire
259 |         something different.
260 | 
261 |         The ``connection_pool_class`` specifies the connection pool to use.
262 |         The SentinelConnectionPool will be used by default.
263 | 
264 |         All other keyword arguments are merged with any connection_kwargs
265 |         passed to this class and passed to the connection pool as keyword
266 |         arguments to be used to initialize Redis connections.
267 |         """
268 |         kwargs['is_master'] = True
269 |         connection_kwargs = dict(self.connection_kwargs)
270 |         connection_kwargs.update(kwargs)
271 |         return redis_class(connection_pool=connection_pool_class(
272 |             service_name, self, **connection_kwargs))
273 | 
274 |     def slave_for(self, service_name, redis_class=StrictRedis,
275 |                   connection_pool_class=SentinelConnectionPool, **kwargs):
276 |         """
277 |         Returns a redis client instance for the ``service_name`` slave(s).
278 | 
279 |         A SentinelConnectionPool class is used to retrieve the slave's
280 |         address before establishing a new connection.
281 | 
282 |         By default clients will be a redis.StrictRedis instance. Specify a
283 |         different class to the ``redis_class`` argument if you desire
284 |         something different.
285 | 
286 |         The ``connection_pool_class`` specifies the connection pool to use.
287 |         The SentinelConnectionPool will be used by default.
288 | 
289 |         All other keyword arguments are merged with any connection_kwargs
290 |         passed to this class and passed to the connection pool as keyword
291 |         arguments to be used to initialize Redis connections.
292 |         """
293 |         kwargs['is_master'] = False
294 |         connection_kwargs = dict(self.connection_kwargs)
295 |         connection_kwargs.update(kwargs)
296 |         return redis_class(connection_pool=connection_pool_class(
297 |             service_name, self, **connection_kwargs))
298 | 
--------------------------------------------------------------------------------
/redis/utils.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | 
3 | 
4 | try:
5 |     import hiredis
6 |     HIREDIS_AVAILABLE = True
7 | except ImportError:
8 |     HIREDIS_AVAILABLE = False
9 | 
10 | 
11 | def from_url(url, db=None, **kwargs):
12 |     """
13 |     Returns an active Redis client generated from the given database URL.
14 | 
15 |     Will attempt to extract the database id from the URL path, if
16 |     none is provided.
17 |     """
18 |     from redis.client import Redis
19 |     return Redis.from_url(url, db, **kwargs)
20 | 
21 | 
22 | @contextmanager
23 | def pipeline(redis_obj):
24 |     p = redis_obj.pipeline()
25 |     yield p
26 |     p.execute()
27 | 
28 | 
29 | class dummy(object):
30 |     """
31 |     Instances of this class can be used as an attribute container.
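
    For example, Lock in redis/lock.py substitutes a dummy() for
    threading.local() when ``thread_local=False``:

        self.local = threading.local() if self.thread_local else dummy()
        self.local.token = None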
32 |     """
33 |     pass
34 | 
--------------------------------------------------------------------------------
/viv_frontend.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import hashlib
4 | import logging
5 | from collections import defaultdict
6 | import json
7 | from vivisect import *
8 | 
9 | from client import Client
10 | from config import config
11 | from comments import Comments, NoChange
12 | from coverage import Coverage
13 | from threading import Lock
14 | from collections import namedtuple
15 | 
16 | logger = logging.getLogger(__name__)
17 | 
18 | Struct = namedtuple('Struct', 'name typedef')
19 | 
20 | class State:
21 |     @staticmethod
22 |     def get(vw):
23 |         return vw.metadata.get('revsync')
24 | 
25 |     show_visits = True
26 |     show_time = True
27 |     show_visitors = False
28 |     track_coverage = True
29 |     color_now = False
30 |     running = False
31 | 
32 |     def __init__(self, vw):
33 |         self.cov = Coverage()
34 |         self.comments = Comments()
35 |         self.running = True
36 |         self.cmt_changes = {}
37 |         self.cmt_lock = Lock()
38 |         self.stackvar_changes = {}
39 |         self.stackvar_lock = Lock()
40 |         self.syms = get_syms(vw)
41 |         self.syms_lock = Lock()
42 |         self.structs = {}  # get_structs(vw)
43 |         self.structs_lock = Lock()
44 | 
45 |         self.filedata_by_sha = {}
46 |         self.filedata_by_fname = {}
47 |         for fname in vw.getFiles():
48 |             sha256 = vw.getFileMeta(fname, 'sha256')
49 |             mdict = dict(vw.getFileMetaDict(fname))
50 |             mdict['name'] = fname
51 |             mdict['maps'] = [mmap for mmap in vw.getMemoryMaps() if mmap[MAP_FNAME] == fname]
52 |             mdict['sha256'] = sha256
53 |             self.filedata_by_sha[sha256] = mdict
54 |             self.filedata_by_fname[fname] = mdict
55 | 
56 |     def close(self):
57 |         # close out the previous session
58 |         for shaval in self.genFhashes():
59 |             client.leave(shaval)
60 | 
61 |         self.running = False
62 | 
63 |     def getMetaBySha(self, key):
64 |         return self.filedata_by_sha.get(key)
65 | 
66 |     def getMetaByFname(self, key):
67 |         return self.filedata_by_fname.get(key)
68 | 
69 |     def getHashByAddr(self, addr):
70 |         for mdict in self.filedata_by_fname.values():
71 |             for mmap in mdict.get('maps'):
72 |                 mva, msz, mperm, mfname = mmap
73 |                 if addr >= mva and addr < mva+msz:
74 |                     fhash = mdict.get('sha256')
75 |                     return fhash, mfname
76 | 
77 |     def genFhashes(self):
78 |         for fdict in self.filedata_by_sha.values():
79 |             try:
80 |                 yield fdict['sha256']
81 |             except KeyError:
82 |                 logger.warning("no 'sha256' key in metadict: %r", fdict)
83 | 
84 | MIN_COLOR = 0
85 | MAX_COLOR = 200
86 | 
87 | IDLE_ASK = 250
88 | COLOUR_PERIOD = 20
89 | BB_REPORT = 50
90 | 
91 | 
92 | def get_can_addr(vw, addr):
93 |     '''
94 |     convert a va to a canonical (file-relative) address
95 |     '''
96 |     fname = vw.getFileByVa(addr)
97 |     if fname is None:
98 |         raise Exception("ARRG! get_can_addr(0x%x)" % addr)
99 | 
100 |     imagebase = vw.getFileMeta(fname, 'imagebase')
101 |     return addr - imagebase
102 | 
103 | def get_ea(vw, sha_key, addr):
104 |     '''
105 |     convert a canonical address back to a va for the file with this sha256
106 |     '''
107 |     imagebase = None
108 |     for fname in vw.getFiles():
109 |         fmeta = vw.getFileMetaDict(fname)
110 |         if fmeta.get('sha256') == sha_key:
111 |             imagebase = fmeta.get('imagebase')
112 |             break
113 | 
114 |     return addr + imagebase
115 | 
116 | def get_func_by_addr(vw, addr):
117 |     return vw.getFunction(addr)
118 | 
119 | def get_bb_by_addr(vw, addr):
120 |     return vw.getCodeBlock(addr)
121 | 
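# NOTE: illustrative example of the two address helpers above. revsync
# shares file-relative addresses so peers with different load addresses
# still line up. With a hypothetical imagebase of 0x400000:
#
#   get_can_addr(vw, 0x401234)   ->  0x1234     (va -> canonical offset)
#   get_ea(vw, sha_key, 0x1234)  ->  0x401234   (canonical offset -> va)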
122 | # in order to map IDA type sizes <-> viv types,
123 | # take the 'size' field and attempt to divine some
124 | # kind of type in viv that's as close as possible.
125 | # right now, that means try to use uint8_t ... uint64_t;
126 | # anything bigger just becomes a char array of that many bytes
127 | def get_type_by_size(vw, size):
128 |     typedef = None
129 |     if size <= 8:
130 |         try:
131 |             typedef, name = vw.parse_type_string('uint{}_t'.format(8*size))
132 |         except SyntaxError:
133 |             pass
134 |     else:
135 |         try:
136 |             typedef, name = vw.parse_type_string('char a[{}]'.format(size))
137 |         except SyntaxError:
138 |             pass
139 |     return typedef
140 | 
141 | def get_syms(vw):
142 |     syms = dict(vw.name_by_va)
143 |     return syms
144 | 
145 | def stack_dict_from_list(stackvars):
146 |     d = {}
147 |     for fva, offset, unknown, (vartype, varname) in stackvars:
148 |         d[offset] = (varname, vartype)
149 |     return d
150 | 
151 | def rename_symbol(vw, addr, name):  # done
152 |     vw.makeName(addr, name)
153 | 
154 | def rename_stackvar(vw, func_addr, offset, name):  # TESTME
155 |     func = vw.getFunction(func_addr)
156 |     if func is None:
157 |         vw.vprint('revsync: bad func addr %#x during rename_stackvar' % func_addr)
158 |         return
159 | 
160 |     # we need to figure out the variable type before renaming
161 |     stackvars = stack_dict_from_list(vw.getFunctionLocals(func_addr))
162 |     var = stackvars.get(offset)
163 |     if var is None:
164 |         vw.vprint('revsync: could not locate stack var with offset %#x during rename_stackvar' % offset)
165 |         return
166 |     var_name, var_type = var
167 |     # CHECKME: does this set function args too? do i need to split them?
168 |     vw.setFunctionLocal(func_addr, offset, var_type, name)
169 |     #aidx = offset // vw.getPointerSize()
170 |     #vw.setFunctionArg(func_addr, aidx, var_type, name)
171 |     return
172 | 
173 | def publish(vw, data, fhash, **kwargs):
174 |     state = State.get(vw)
175 |     if state:
176 |         client.publish(fhash, data, **kwargs)
177 | 
178 | def analyze(vw):
179 |     vw.vprint('revsync: running analysis update...')
180 |     vw.analyze()
181 |     vw.vprint('revsync: analysis finished.')
182 |     return
183 | 
184 | 
185 | 
186 | 
187 | 
188 | ##### structs are not currently supported
189 | def get_structs(vw):
190 |     d = dict()
191 |     for name, typedef in vw.types.items():
192 |         if typedef.structure:
193 |             typeid = vw.get_type_id(name)
194 |             struct = Struct(name, typedef.structure)
195 |             d[typeid] = struct
196 |     return d
197 | 
198 | def member_dict_from_list(members):
199 |     d = {}
200 |     for member in members:
201 |         d[member.name] = member
202 |     return d
203 | 
204 | 
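# NOTE: illustrative summary of the struct-sync messages used below.
# watch_structs() diffs the workspace's structs every 0.5s and publishes
# one small dict per change, e.g. (field values hypothetical):
#
#   {'cmd': 'struc_created',        'struc_name': 'FOO', 'is_union': False}
#   {'cmd': 'struc_renamed',        'old_name': 'FOO', 'new_name': 'BAR'}
#   {'cmd': 'struc_member_created', 'struc_name': 'BAR', 'offset': 0,
#    'member_name': 'field_0', 'size': 4, 'flag': None}
#   {'cmd': 'struc_member_deleted', 'struc_name': 'BAR', 'offset': 0}
#
# Nothing in this file calls watch_structs() yet, and its publish() calls
# still omit the fhash argument that publish() requires.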
205 | #### again, no struct just yet
206 | def watch_structs(vw):
207 |     """ Check structs for changes and publish diffs"""
208 |     state = State.get(vw)
209 | 
210 |     while state.running:
211 |         state.structs_lock.acquire()
212 |         structs = get_structs(vw)
213 |         if structs != state.structs:
214 |             for struct_id, struct in structs.items():
215 |                 last_struct = state.structs.get(struct_id)
216 |                 struct_name = struct.name
217 |                 if last_struct is None:
218 |                     # new struct created, publish
219 |                     vw.vprint('revsync: user created struct %s' % struct_name)
220 |                     # binja can't really handle unions at this time
221 |                     publish(vw, {'cmd': 'struc_created', 'struc_name': str(struct_name), 'is_union': False})
222 |                     # if there are already members, publish them
223 |                     members = member_dict_from_list(struct.typedef.members)
224 |                     if members:
225 |                         for member_name, member_def in members.items():
226 |                             publish(vw, {'cmd': 'struc_member_created', 'struc_name': str(struct_name), 'offset': member_def.offset, 'member_name': member_name, 'size': member_def.type.width, 'flag': None})
227 |                     continue
228 |                 last_name = last_struct.name
229 |                 if last_name != struct_name:
230 |                     # struct renamed, publish
231 |                     vw.vprint('revsync: user renamed struct %s' % struct_name)
232 |                     publish(vw, {'cmd': 'struc_renamed', 'old_name': str(last_name), 'new_name': str(struct_name)})
233 | 
234 |                 # check for member differences
235 |                 members = member_dict_from_list(struct.typedef.members)
236 |                 last_members = member_dict_from_list(last_struct.typedef.members)
237 | 
238 |                 # first check for deletions
239 |                 removed_members = set(last_members.keys()) - set(members.keys())
240 |                 for member in removed_members:
241 |                     vw.vprint('revsync: user deleted struct member %s in struct %s' % (last_members[member].name, str(struct_name)))
242 |                     publish(vw, {'cmd': 'struc_member_deleted', 'struc_name': str(struct_name), 'offset': last_members[member].offset})
243 | 
244 |                 # now check for additions
245 |                 new_members = set(members.keys()) - set(last_members.keys())
246 |                 for member in new_members:
247 |                     vw.vprint('revsync: user added struct member %s in struct %s' % (members[member].name, str(struct_name)))
248 |                     publish(vw, {'cmd': 'struc_member_created', 'struc_name': str(struct_name), 'offset': members[member].offset, 'member_name': str(member), 'size': members[member].type.width, 'flag': None})
249 | 
250 |                 # check for changes among intersection of members
251 |                 intersec = set(members.keys()) & set(last_members.keys())
252 |                 for m in intersec:
253 |                     if members[m].type.width != last_members[m].type.width:
254 |                         # type (i.e., size) changed
255 |                         vw.vprint('revsync: user changed struct member %s in struct %s' % (members[m].name, str(struct_name)))
256 |                         publish(vw, {'cmd': 'struc_member_changed', 'struc_name': str(struct_name), 'offset': members[m].offset, 'size': members[m].type.width})
257 | 
258 |         for struct_id, struct_def in state.structs.items():
259 |             if structs.get(struct_id) is None:
260 |                 # struct deleted, publish
261 |                 vw.vprint('revsync: user deleted struct %s' % struct_def.name)
262 |                 publish(vw, {'cmd': 'struc_deleted', 'struc_name': str(struct_def.name)})
263 |         state.structs = get_structs(vw)
264 |         state.structs_lock.release()
265 |         time.sleep(0.5)
266 | 
267 | ###### Coverage not yet implemented
268 | def push_cv(vw, data, **kwargs):
269 |     state = State.get(vw)
270 |     if state:
271 |         client.push("%s_COVERAGE" % state.fhash, data, **kwargs)
272 | 
273 | def map_color(x):
274 |     n = x
275 |     if x == 0: return 0
276 |     # x = min(max(0, (x ** 2) / (2 * (x ** 2 - x) + 1)), 1)
277 |     # if x == 0: return 0
278 |     return int(math.ceil((MAX_COLOR - MIN_COLOR) * x + MIN_COLOR))
279 | 
280 | def convert_color(color):
281 |     r, g, b = [map_color(x) for x in color]
282 |     return highlight.HighlightColor(red=r, green=g, blue=b)
283 | 
284 | def colour_coverage(bv, cur_func):
285 |     state = State.get(bv)
286 |     for bb in cur_func.basic_blocks:
287 |         color = state.cov.color(get_can_addr(bv, bb.start), visits=state.show_visits, time=state.show_time, users=state.show_visitors)
288 |         if color:
289 |             bb.set_user_highlight(convert_color(color))
290 |         else:
291 |             bb.set_user_highlight(highlight.HighlightColor(red=74, blue=74, green=74))
292 | 
293 | 
294 | 
295 | 
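# NOTE: worked example for map_color() above. It linearly scales a
# 0.0..1.0 intensity into the highlight range [MIN_COLOR, MAX_COLOR]
# = [0, 200], keeping 0 as "no highlight":
#
#   map_color(0.0)  -> 0
#   map_color(0.25) -> ceil(200 * 0.25) = 50
#   map_color(1.0)  -> 200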
296 | ### handle remote events:
297 | def onmsg(vw, key, data, replay):
298 |     logger.info("onmsg: %r : %r (%r)" % (key, data, replay))
299 |     try:
300 |         state = State.get(vw)
301 |         meta = state.getMetaBySha(key)
302 |         if meta is None:
303 |             vw.vprint('revsync: hash mismatch, dropping command')
304 |             return
305 | 
306 |         cmd, user = data['cmd'], data['user']
307 |         ts = int(data.get('ts', 0))
308 |         if cmd == 'comment':
309 |             state.cmt_lock.acquire()
310 |             vw.vprint('revsync: <%s> %s %#x %s' % (user, cmd, data['addr'], data['text']))
311 |             addr = get_ea(vw, key, int(data['addr']))
312 |             #func = get_func_by_addr(vw, addr)
313 |             ## binja does not support comments on data symbols??? IDA does.
314 |             #if func is not None:
315 |             text = state.comments.set(addr, user, data['text'], ts)
316 |             vw.setComment(addr, text)
317 |             state.cmt_changes[addr] = text
318 |             state.cmt_lock.release()
319 | 
320 |         elif cmd == 'extra_comment':
321 |             vw.vprint('revsync: <%s> %s %#x %s' % (user, cmd, data['addr'], data['text']))
322 | 
323 |         elif cmd == 'area_comment':
324 |             vw.vprint('revsync: <%s> %s %s %s' % (user, cmd, data['range'], data['text']))
325 | 
326 |         elif cmd == 'rename':
327 |             state.syms_lock.acquire()
328 |             vw.vprint('revsync: <%s> %s %#x %s' % (user, cmd, data['addr'], data['text']))
329 |             addr = get_ea(vw, key, int(data['addr']))
330 |             rename_symbol(vw, addr, data['text'])
331 |             state.syms = get_syms(vw)
332 |             state.syms_lock.release()
333 | 
334 |         elif cmd == 'stackvar_renamed':
335 |             state.stackvar_lock.acquire()
336 |             func_name = '???'
337 |             func = get_func_by_addr(vw, data['addr'])
338 |             if func:
339 |                 func_name = vw.getName(func)
340 |             vw.vprint('revsync: <%s> %s %s %#x %s' % (user, cmd, func_name, data['offset'], data['name']))
341 |             rename_stackvar(vw, data['addr'], data['offset'], data['name'])
342 |             # save stackvar changes using the tuple (func_addr, offset) as key
343 |             state.stackvar_changes[(data['addr'], data['offset'])] = data['name']
344 |             state.stackvar_lock.release()
345 | 
346 |         elif cmd == 'struc_created':
347 |             # note: binja does not seem to appreciate the encoding of strings from redis
348 |             struct_name = data['struc_name'].encode('ascii', 'ignore')
349 |             vw.vprint('revsync: <%s> %s %s' % (user, cmd, struct_name))
350 | 
351 |         elif cmd == 'struc_deleted':
352 |             struct_name = data['struc_name'].encode('ascii', 'ignore')
353 |             vw.vprint('revsync: <%s> %s %s' % (user, cmd, struct_name))
354 | 
355 |         elif cmd == 'struc_renamed':
356 |             old_struct_name = data['old_name'].encode('ascii', 'ignore')
357 |             new_struct_name = data['new_name'].encode('ascii', 'ignore')
358 |             vw.vprint('revsync: <%s> %s %s %s' % (user, cmd, old_struct_name, new_struct_name))
359 | 
360 |         elif cmd == 'struc_member_created':
361 |             struct_name = data['struc_name'].encode('ascii', 'ignore')
362 |             vw.vprint('revsync: <%s> %s %s->%s' % (user, cmd, struct_name, data['member_name'].encode('ascii', 'ignore')))
363 | 
364 |         elif cmd == 'struc_member_deleted':
365 |             struct_name = data['struc_name'].encode('ascii', 'ignore')
366 |             member_name = '???'
367 |             vw.vprint('revsync: <%s> %s %s->%s' % (user, cmd, struct_name, member_name))
368 | 
369 |         elif cmd == 'struc_member_renamed':
370 |             struct_name = data['struc_name'].encode('ascii', 'ignore')
371 |             member_name = data['member_name'].encode('ascii', 'ignore')
372 |             vw.vprint('revsync: <%s> %s %s->%s' % (user, cmd, struct_name, member_name))
373 | 
374 |         elif cmd == 'struc_member_changed':
375 |             struct_name = data['struc_name'].encode('ascii', 'ignore')
376 |             vw.vprint('revsync: <%s> %s %s->%s' % (user, cmd, struct_name, data['offset']))
377 | 
378 |         elif cmd == 'join':
379 |             vw.vprint('revsync: <%s> joined' % (user))
380 | 
381 |         elif cmd == 'coverage':
382 |             vw.vprint("Updating Global Coverage")
383 |             state.cov.update(json.loads(data['blocks']))
384 |             state.color_now = True
385 | 
386 |         else:
387 |             vw.vprint('revsync: unknown cmd %s' % data)
388 | 
389 |     except Exception as e:
390 |         vw.vprint('onmsg error: %r' % e)
391 | 
392 | def revsync_callback(vw):
393 |     def callback(key, data, replay=False):
394 |         onmsg(vw, key, data, replay)
395 |     return callback
396 | 
397 | 
398 | ### handle local events and hand up to REDIS
399 | import vivisect.base as viv_base
400 | class VivEventClient(viv_base.VivEventCore):
401 |     def __init__(self, vw):
402 |         viv_base.VivEventCore.__init__(self, vw)
403 |         self.mythread = self._ve_fireListener()
404 |         self.state = vw.getMeta('revsync')
405 | 
406 |     # make sure all VA's are reduced to base-addr-offsets
407 |     def VWE_COMMENT(self, vw, event, locinfo):
408 |         logger.info("%r %r %r" % (vw, event, locinfo))
409 |         cmt_addr, cmt = locinfo
410 |         # make sure something has changed (and that we're not repeating
411 |         # what we just received from revsync), then publish to revsync
412 |         last_cmt = self.state.cmt_changes.get(cmt_addr)
413 |         if last_cmt is None or last_cmt != cmt:
414 |             # new/changed comment, publish
415 |             try:
416 |                 fhash, fname = self.state.getHashByAddr(cmt_addr)
417 |                 addr = get_can_addr(vw, cmt_addr)
418 |                 changed = self.state.comments.parse_comment_update(addr, client.nick, cmt)
419 |                 vw.vprint('revsync: user changed comment: %#x, %s' % (addr, changed))
420 |                 publish(vw, {'cmd': 'comment', 'addr': addr, 'text': changed}, fhash)
421 |                 self.state.cmt_changes[cmt_addr] = changed
422 |             except NoChange:
423 |                 pass
424 | 
425 |     def VWE_SETNAME(self, vw, event, locinfo):
426 |         logger.info("%r %r %r" % (vw, event, locinfo))
427 |         name_addr, name = locinfo
428 |         addr = get_can_addr(vw, name_addr)
429 |         if self.state.syms.get(addr) != name:
430 |             # name changed, publish
431 |             fhash, fname = self.state.getHashByAddr(name_addr)
432 |             vw.vprint('revsync: user renamed symbol at %#x: %s' % (addr, name))
433 |             publish(vw, {'cmd': 'rename', 'addr': addr, 'text': name}, fhash)
434 |             self.state.syms[addr] = name
435 | 
436 |     def VWE_SETFUNCARGS(self, vw, event, locinfo):
437 |         vw.vprint("%r %r %r" % (vw, event, locinfo))
438 | 
439 |     def VWE_SETFUNCMETA(self, vw, event, locinfo):
440 |         vw.vprint("%r %r %r" % (vw, event, locinfo))
441 | 
442 | 
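    # NOTE: illustrative round trip (values hypothetical). A local rename
    # at va 0x401234 in a file based at 0x400000 is published as
    # {'cmd': 'rename', 'addr': 0x1234, 'text': 'new_name'} under that
    # file's sha256; peers rebase 0x1234 onto their own imagebase in onmsg().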
443 |     def VWE_ADDLOCATION(self, vw, event, loc):
444 |         # what kind of location?
445 |         # * LOC_STRUCT
446 |         # * LOC_STRING
447 |         # * LOC_UNICODE
448 |         # * LOC_POINTER
449 |         # * LOC_NUMBER
450 |         # * LOC_OP
451 |         if loc[L_LTYPE] == LOC_OP:
452 |             return
453 |         vw.vprint("%r %r %r" % (vw, event, loc))
454 | 
455 |     def VWE_DELLOCATION(self, vw, event, loc):
456 |         #vw.vprint("%r %r %r" % (vw, event, loc))
457 |         pass
458 | 
459 |     def VWE_SETMETA(self, vw, event, loc):
460 |         vw.vprint("%r %r %r" % (vw, event, loc))
461 | 
462 |     def VWE_ADDFILE(self, vw, event, loc):
463 |         #vw.vprint("%r %r %r" % (vw, event, loc))
464 |         pass
465 | 
466 |     def VWE_ADDFUNCTION(self, vw, event, loc):
467 |         #vw.vprint("%r %r %r" % (vw, event, loc))
468 |         pass
469 | 
470 |     def VWE_DELFUNCTION(self, vw, event, loc):
471 |         #vw.vprint("%r %r %r" % (vw, event, loc))
472 |         pass
473 | 
474 |     def VWE_ADDCOLOR(self, vw, event, loc):
475 |         vw.vprint("%r %r %r" % (vw, event, loc))
476 | 
477 |     def VWE_DELCOLOR(self, vw, event, loc):
478 |         vw.vprint("%r %r %r" % (vw, event, loc))
479 | 
480 |     def VWE_CHAT(self, vw, event, loc):
481 |         #vw.vprint("%r %r %r" % (vw, event, loc))
482 |         pass
483 | 
484 | 
485 | client = None
486 | evtdist = None
487 | 
488 | 
489 | def revsync_load(vw):
490 |     global client, evtdist
491 |     vw.vprint('Connecting to RevSync Server')
492 | 
493 |     ### hook into the viv event stream
494 | 
495 |     # let's ensure auto-analysis is finished by forcing another analysis
496 |     analyze(vw)
497 | 
498 |     if client is None:
499 |         vw.vprint("creating a new revsync connection")
500 |         client = Client(**config)
501 | 
502 |     state = vw.metadata.get('revsync')
503 |     if state:
504 |         vw.vprint("closing existing revsync state")
505 |         state.close()
506 | 
507 |     vw.vprint('revsync: working...')
508 |     state = vw.metadata['revsync'] = State(vw)
509 |     vw.vprint('revsync: connecting with hashes: %s' % repr([x for x in state.genFhashes()]))
510 | 
511 |     for fhash in state.genFhashes():
512 |         client.join(fhash, revsync_callback(vw))
513 | 
514 |     vw.vprint('revsync: connected!')
515 | 
516 |     if evtdist is None:
517 |         evtdist = VivEventClient(vw)
518 | 
519 | 
520 | def toggle_visits(vw):
521 |     state = State.get(vw)
522 |     state.show_visits = not state.show_visits
523 |     if state.show_visits:
524 |         vw.vprint("Visit Visualization Enabled (Red)")
525 |     else:
526 |         vw.vprint("Visit Visualization Disabled (Red)")
527 |     state.color_now = True
528 | 
529 | def toggle_time(vw):
530 |     state = State.get(vw)
531 |     state.show_time = not state.show_time
532 |     if state.show_time:
533 |         vw.vprint("Time Visualization Enabled (Blue)")
534 |     else:
535 |         vw.vprint("Time Visualization Disabled (Blue)")
536 |     state.color_now = True
537 | 
538 | def toggle_visitors(vw):
539 |     state = State.get(vw)
540 |     state.show_visitors = not state.show_visitors
541 |     if state.show_visitors:
542 |         vw.vprint("Visitor Visualization Enabled (Green)")
543 |     else:
544 |         vw.vprint("Visitor Visualization Disabled (Green)")
545 |     state.color_now = True
546 | 
547 | def toggle_track(vw):
548 |     state = State.get(vw)
549 |     state.track_coverage = not state.track_coverage
550 |     if state.track_coverage:
551 |         vw.vprint("Tracking Enabled")
552 |     else:
553 |         vw.vprint("Tracking Disabled")
554 | 
555 | 
556 | ######### register the plugin #########
557 | from vqt.main import idlethread
558 | 
559 | @idlethread
560 | def vivExtension(vw, vwgui):
561 |     vwgui.vqAddMenuField('&Plugins.&revsync.&Coverage: Toggle Tracking', toggle_track, args=(vw,))
562 |     vwgui.vqAddMenuField('&Plugins.&revsync.&Coverage: Toggle Visits (RED)', toggle_visits, args=(vw,))
563 |     vwgui.vqAddMenuField('&Plugins.&revsync.&Coverage: Toggle Time (BLUE)', toggle_time, args=(vw,))
564 |     vwgui.vqAddMenuField('&Plugins.&revsync.&Coverage: Toggle Visitors (GREEN)', toggle_visitors, args=(vw,))
565 |     vwgui.vqAddMenuField('&Plugins.&revsync.&load: Load revsync for binaries in this workspace', revsync_load, args=(vw,))
566 | 
--------------------------------------------------------------------------------