├── .python-version ├── src ├── min_matrix.npy ├── train_ppo_mlp_cov.py ├── recorder.py ├── train_ppo_mlp.py ├── test_trained_cov_mlp.py ├── train_ppo_cnn.py ├── train_ppo_encoded.py ├── train_ppo_cnn_comm.py ├── train_dqn_cnn.py ├── train_ppo_multi.py ├── train_dqn_multi.py ├── train_descentralized_ppo_cnn.py ├── greedy_heuristic.py ├── test_ppo_cnn.py ├── train_ppo_cnn_cov.py ├── test_trained_search.py ├── test_trained_cov.py ├── a_star_coverage.py ├── train_ppo_cnn_lstm.py └── test_trained_cnn_lstm.py ├── LICENSE ├── README.md ├── requirements.txt ├── .gitignore └── imgs └── drone.svg /.python-version: -------------------------------------------------------------------------------- 1 | 3.11.3 2 | -------------------------------------------------------------------------------- /src/min_matrix.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfeinsper/drone-swarm-search-algorithms/HEAD/src/min_matrix.npy -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 PFE Embraer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PyPI Release 🚀](https://badge.fury.io/py/DSSE.svg)](https://badge.fury.io/py/DSSE) 2 | [![License: MIT](https://img.shields.io/badge/License-MIT-brightgreen.svg?style=flat)](https://github.com/pfeinsper/drone-swarm-search/blob/main/LICENSE) 3 | [![PettingZoo version dependency](https://img.shields.io/badge/PettingZoo-v1.22.3-blue)]() 4 | ![GitHub stars](https://img.shields.io/github/stars/pfeinsper/drone-swarm-search-algorithms) 5 | 6 | # DSSE Icon Algorithms for Drone Swarm Search (DSSE) 7 | 8 | Welcome to the official GitHub repository for the Drone Swarm Search (DSSE) algorithms. These algorithms are specifically tailored for reinforcement learning environments aimed at optimizing drone swarm coordination and search efficiency. 9 | 10 | Explore a diverse range of implementations that leverage the latest advancements in machine learning to solve complex coordination tasks in dynamic and unpredictable environments. 
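For a sense of how the training scripts in `src/` interact with the environment, here is a minimal sketch of a random-policy rollout. It assumes the `DSSE` package from PyPI and the standard PettingZoo parallel-env API; the constructor parameters mirror those used in `src/greedy_heuristic.py` and `src/train_ppo_cnn.py`.

```python
# Minimal random-policy rollout (sketch): the training scripts in this repo
# replace the random action sampling below with PPO/DQN policies.
from DSSE import DroneSwarmSearch

env = DroneSwarmSearch(
    drone_amount=4,
    grid_size=40,
    dispersion_inc=0.1,
    person_initial_position=(20, 20),
)

observations, info = env.reset()
while env.agents:
    # action_space(agent) follows the PettingZoo parallel API (assumed here).
    actions = {agent: env.action_space(agent).sample() for agent in env.agents}
    observations, rewards, terminations, truncations, infos = env.step(actions)
env.close()
```

The `train_*.py` scripts wrap this same environment with RLlib's `ParallelPettingZooEnv` (and the DSSE observation wrappers) before training.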
11 | 12 | ## 📚 Documentation Links 13 | 14 | - **[Documentation Site](https://pfeinsper.github.io/drone-swarm-search/)**: Access detailed tutorials, usage examples, and comprehensive technical documentation. This resource is essential for understanding the DSSE framework and integrating these algorithms into your projects effectively. 15 | 16 | - **[DSSE Training Environment Repository](https://github.com/pfeinsper/drone-swarm-search)**: Visit the repository for the DSSE training environment, where you can access the core environment setups and configurations used for developing and testing the algorithms. 17 | 18 | - **[PyPI Repository](https://pypi.org/project/DSSE/)**: Download the latest release of DSSE, view the version history, and find installation instructions. Keep up with the latest updates and improvements to the algorithms. 19 | 20 | ## 🆘 Support and Community 21 | 22 | Run into a snag? Have a suggestion? Join our community on GitHub! Submit your queries, report bugs, or contribute to discussions by visiting our [issues page](https://github.com/pfeinsper/drone-swarm-search-algorithms/issues). Your input helps us improve and evolve. 23 | -------------------------------------------------------------------------------- /src/train_ppo_mlp_cov.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from DSSE import CoverageDroneSwarmSearch 3 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllFlattenWrapper 4 | import ray 5 | from ray import tune 6 | from ray.rllib.algorithms.ppo import PPOConfig 7 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv 8 | from ray.tune.registry import register_env 9 | from torch import nn 10 | import torch 11 | import numpy as np 12 | 13 | 14 | def env_creator(args): 15 | print("-------------------------- ENV CREATOR --------------------------") 16 | N_AGENTS = 2 17 | # 6 hours of simulation, 600 radius 18 | env = CoverageDroneSwarmSearch( 19 | timestep_limit=200, drone_amount=N_AGENTS, prob_matrix_path="min_matrix.npy" 20 | ) 21 | env = AllFlattenWrapper(env) 22 | grid_size = env.grid_size 23 | print("Grid size: ", grid_size) 24 | positions = [ 25 | (0, grid_size // 2), 26 | (grid_size - 1, grid_size // 2), 27 | ] 28 | env = RetainDronePosWrapper(env, positions) 29 | return env 30 | 31 | def position_on_diagonal(grid_size, drone_amount): 32 | positions = [] 33 | center = grid_size // 2 34 | for i in range(-drone_amount // 2, drone_amount // 2): 35 | positions.append((center + i, center + i)) 36 | return positions 37 | 38 | def position_on_circle(grid_size, drone_amount, radius): 39 | positions = [] 40 | center = grid_size // 2 41 | angle_increment = 2 * np.pi / drone_amount 42 | 43 | for i in range(drone_amount): 44 | angle = i * angle_increment 45 | x = center + int(radius * np.cos(angle)) 46 | y = center + int(radius * np.sin(angle)) 47 | positions.append((x, y)) 48 | 49 | return positions 50 | 51 | 52 | if __name__ == "__main__": 53 | ray.init() 54 | 55 | env_name = "DSSE_Coverage" 56 | 57 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config))) 58 | 59 | config = ( 60 | PPOConfig() 61 | .environment(env=env_name) 62 | .rollouts(num_rollout_workers=6, rollout_fragment_length="auto", num_envs_per_worker=4) 63 | .training( 64 | train_batch_size=8192 * 3, 65 | lr=8e-6, 66 | gamma=0.9999999, 67 | lambda_=0.9, 68 | use_gae=True, 69 | entropy_coeff=0.01, 70 | vf_clip_param=100000, 71 | sgd_minibatch_size=300, 72 | num_sgd_iter=10, 73 | 
model={ 74 | "fcnet_hiddens": [512, 256], 75 | }, 76 | ) 77 | .experimental(_disable_preprocessor_api=True) 78 | .debugging(log_level="ERROR") 79 | .framework(framework="torch") 80 | .resources(num_gpus=1) 81 | ) 82 | 83 | curr_path = pathlib.Path().resolve() 84 | tune.run( 85 | "PPO", 86 | name="PPO_" + input("Exp name: "), 87 | # resume=True, 88 | stop={"timesteps_total": 40_000_000}, 89 | checkpoint_freq=25, 90 | storage_path=f"{curr_path}/ray_res/" + env_name, 91 | config=config.to_dict(), 92 | ) 93 | -------------------------------------------------------------------------------- /src/recorder.py: -------------------------------------------------------------------------------- 1 | """ 2 | PygameRecord - A utility for recording Pygame screens as GIFS. 3 | This module provides a class, PygameRecord, which can be used to record Pygame 4 | animations and save them as GIF files. It captures frames from the Pygame display 5 | and saves them as images, then combines them into a GIF file. 6 | Credits: 7 | - Author: Ricardo Ribeiro Rodrigues 8 | - Date: 21/03/2024 9 | - source: https://gist.github.com/RicardoRibeiroRodrigues/9c40f36909112950860a410a565de667 10 | Usage: 11 | 1. Initialize PygameRecord with a filename and desired frames per second (fps). 12 | 2. Enter a Pygame event loop. 13 | 3. Add frames to the recorder at desired intervals. 14 | 4. When done recording, exit the Pygame event loop. 15 | 5. The recorded GIF will be saved automatically. 16 | """ 17 | 18 | import pygame 19 | from PIL import Image 20 | import numpy as np 21 | 22 | 23 | class PygameRecord: 24 | def __init__(self, filename: str, fps: int): 25 | self.fps = fps 26 | self.filename = filename 27 | self.frames = [] 28 | 29 | def add_frame(self): 30 | curr_surface = pygame.display.get_surface() 31 | x3 = pygame.surfarray.array3d(curr_surface) 32 | x3 = np.moveaxis(x3, 0, 1) 33 | array = Image.fromarray(np.uint8(x3)) 34 | self.frames.append(array) 35 | 36 | def save(self): 37 | self.frames[0].save( 38 | self.filename, 39 | save_all=True, 40 | optimize=False, 41 | append_images=self.frames[1:], 42 | loop=0, 43 | duration=int(1000 / self.fps), 44 | ) 45 | 46 | def __enter__(self): 47 | return self 48 | 49 | def __exit__(self, exc_type, exc_value, traceback): 50 | if exc_type is not None: 51 | print(f"An exception of type {exc_type} occurred: {exc_value}") 52 | self.save() 53 | # Return False if you want exceptions to propagate, True to suppress them 54 | return False 55 | 56 | 57 | if __name__ == "__main__": 58 | # Example usage 59 | from random import randint 60 | 61 | FPS = 30 62 | # Init the recorder with the output file and the desired FPS 63 | with PygameRecord("output.gif", FPS) as recorder: 64 | pygame.init() 65 | screen = pygame.display.set_mode((400, 400)) 66 | running = True 67 | clock = pygame.time.Clock() 68 | n_frames = 90 69 | while running: 70 | for event in pygame.event.get(): 71 | if event.type == pygame.QUIT: 72 | running = False 73 | screen.fill((0, 0, 0)) 74 | pygame.draw.circle( 75 | screen, 76 | (randint(0, 255), randint(0, 255), randint(0, 255)), 77 | (200, 200), 78 | 50, 79 | ) 80 | recorder.add_frame() # Add frame to recorder 81 | pygame.display.flip() 82 | clock.tick(FPS) 83 | # Used here to limit the size of the GIF, not necessary for normal usage. 
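            # (90 frames at 30 FPS gives a demo GIF of roughly three seconds.)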
84 | n_frames -= 1 85 | if n_frames == 0: 86 | break 87 | recorder.save() 88 | pygame.quit() -------------------------------------------------------------------------------- /src/train_ppo_mlp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from DSSE import DroneSwarmSearch 4 | from DSSE.environment.wrappers import TopNProbsWrapper 5 | import ray 6 | from ray import tune 7 | from ray.rllib.algorithms.ppo import PPOConfig 8 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv 9 | from ray.rllib.models import ModelCatalog 10 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 11 | from ray.tune.registry import register_env 12 | import torch 13 | from torch import nn 14 | 15 | 16 | class MLPModel(TorchModelV2, nn.Module): 17 | def __init__( 18 | self, 19 | obs_space, 20 | act_space, 21 | num_outputs, 22 | model_config, 23 | name, 24 | **kw, 25 | ): 26 | print("OBSSPACE: ", obs_space) 27 | TorchModelV2.__init__( 28 | self, obs_space, act_space, num_outputs, model_config, name, **kw 29 | ) 30 | nn.Module.__init__(self) 31 | 32 | self.model = nn.Sequential( 33 | nn.Linear(obs_space.shape[0], 512), 34 | nn.ReLU(), 35 | nn.Linear(512, 256), 36 | nn.ReLU(), 37 | ) 38 | self.policy_fn = nn.Linear(256, num_outputs) 39 | self.value_fn = nn.Linear(256, 1) 40 | 41 | def forward(self, input_dict, state, seq_lens): 42 | input_ = input_dict["obs"].float() 43 | value_input = self.model(input_) 44 | 45 | self._value_out = self.value_fn(value_input) 46 | return self.policy_fn(value_input), state 47 | 48 | def value_function(self): 49 | return self._value_out.flatten() 50 | 51 | 52 | def env_creator(args): 53 | env = DroneSwarmSearch( 54 | drone_amount=4, 55 | grid_size=20, 56 | dispersion_inc=0.1, 57 | person_initial_position=(10, 10), 58 | ) 59 | env = TopNProbsWrapper(env, 10) 60 | return env 61 | 62 | 63 | if __name__ == "__main__": 64 | ray.init() 65 | 66 | env_name = "DSSE" 67 | 68 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config))) 69 | ModelCatalog.register_custom_model("MLPModel", MLPModel) 70 | 71 | config = ( 72 | PPOConfig() 73 | .environment(env=env_name) 74 | .rollouts(num_rollout_workers=5, rollout_fragment_length="auto") 75 | .training( 76 | train_batch_size=8192, 77 | lr=4e-5, 78 | gamma=0.99999, 79 | lambda_=0.9, 80 | use_gae=True, 81 | clip_param=0.4, 82 | grad_clip=None, 83 | entropy_coeff=0.1, 84 | vf_loss_coeff=0.25, 85 | vf_clip_param=4200, 86 | sgd_minibatch_size=1024, 87 | num_sgd_iter=10, 88 | model={ 89 | "custom_model": "MLPModel", 90 | "_disable_preprocessor_api": True, 91 | }, 92 | ) 93 | .experimental(_disable_preprocessor_api=True) 94 | .debugging(log_level="ERROR") 95 | .framework(framework="torch") 96 | .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "1"))) 97 | ) 98 | 99 | curr_path = pathlib.Path().resolve() 100 | tune.run( 101 | "PPO", 102 | name="PPO", 103 | stop={"timesteps_total": 5000000 if not os.environ.get("CI") else 50000}, 104 | checkpoint_freq=10, 105 | storage_path=f"{curr_path}/ray_res/" + env_name, 106 | config=config.to_dict(), 107 | ) 108 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # pandas 2 | # DSSE 3 | # torch 4 | # ray[rllib] 5 | aigyminsper 6 | 7 | # 8 | 9 | aiohttp==3.9.5 10 | aiosignal==1.3.1 11 | anyio==4.6.0 12 | argon2-cffi==23.1.0 13 | 
argon2-cffi-bindings==21.2.0 14 | arrow==1.3.0 15 | asciitree==0.3.3 16 | asttokens==2.4.1 17 | async-lru==2.0.4 18 | attrs==24.2.0 19 | babel==2.16.0 20 | beautifulsoup4==4.12.3 21 | bleach==6.1.0 22 | boto3==1.35.31 23 | botocore==1.35.31 24 | cachier==3.0.1 25 | Cartopy==0.23.0 26 | certifi==2024.8.30 27 | cf_xarray==0.9.5 28 | cffi==1.17.1 29 | cftime==1.6.4 30 | charset-normalizer==3.3.2 31 | click==8.1.7 32 | cloudpickle==3.0.0 33 | cmocean==4.0.3 34 | coloredlogs==15.0.1 35 | comm==0.2.2 36 | contourpy==1.3.0 37 | copernicusmarine==1.3.3 38 | cycler==0.12.1 39 | dask==2024.9.1 40 | debugpy==1.8.6 41 | decorator==5.1.1 42 | defusedxml==0.7.1 43 | DSSE==1.1.9 44 | executing==2.1.0 45 | Farama-Notifications==0.0.4 46 | fasteners==0.19 47 | fastjsonschema==2.20.0 48 | fonttools==4.54.1 49 | fqdn==1.5.1 50 | frozenlist==1.4.1 51 | fsspec==2024.9.0 52 | GDAL==3.4.1 53 | geojson==3.1.0 54 | gymnasium==0.29.1 55 | h11==0.14.0 56 | httpcore==1.0.6 57 | httpx==0.27.2 58 | humanfriendly==10.0 59 | idna==3.10 60 | ipykernel==6.29.5 61 | ipython==8.28.0 62 | isoduration==20.11.0 63 | jedi==0.19.1 64 | Jinja2==3.1.4 65 | jmespath==1.0.1 66 | json5==0.9.25 67 | jsonpointer==3.0.0 68 | jsonschema==4.23.0 69 | jsonschema-specifications==2023.12.1 70 | jupyter-events==0.10.0 71 | jupyter-lsp==2.2.5 72 | jupyter_client==8.6.3 73 | jupyter_core==5.7.2 74 | jupyter_server==2.14.2 75 | jupyter_server_terminals==0.5.3 76 | jupyterlab==4.2.5 77 | jupyterlab_pygments==0.3.0 78 | jupyterlab_server==2.27.3 79 | kiwisolver==1.4.7 80 | llvmlite==0.43.0 81 | locket==1.0.0 82 | lxml==5.3.0 83 | MarkupSafe==2.1.5 84 | matplotlib==3.8.4 85 | matplotlib-inline==0.1.7 86 | mistune==3.0.2 87 | multidict==6.1.0 88 | nbclient==0.10.0 89 | nbconvert==7.16.4 90 | nbformat==5.10.4 91 | nc-time-axis==1.4.1 92 | nest-asyncio==1.6.0 93 | netCDF4==1.7.1.post2 94 | notebook==7.2.2 95 | notebook_shim==0.2.4 96 | numba==0.60.0 97 | numcodecs==0.13.0 98 | numpy==1.26.4 99 | opendrift==1.11.13 100 | overrides==7.7.0 101 | packaging==24.1 102 | pandas==2.2.3 103 | pandocfilters==1.5.1 104 | parso==0.8.4 105 | partd==1.4.2 106 | pettingzoo==1.24.3 107 | pexpect==4.9.0 108 | pillow==10.4.0 109 | platformdirs==4.3.6 110 | portalocker==2.10.1 111 | prometheus_client==0.21.0 112 | prompt_toolkit==3.0.48 113 | psutil==6.0.0 114 | ptyprocess==0.7.0 115 | pure_eval==0.2.3 116 | pycparser==2.22 117 | pygame==2.6.1 118 | Pygments==2.18.0 119 | pykdtree==1.3.13 120 | pynucos==3.2.2 121 | pyparsing==3.1.4 122 | pyproj==3.7.0 123 | pyshp==2.3.1 124 | pystac==1.10.1 125 | python-dateutil==2.9.0.post0 126 | python-json-logger==2.0.7 127 | pytz==2024.2 128 | PyYAML==6.0.2 129 | pyzmq==26.2.0 130 | referencing==0.35.1 131 | requests==2.32.3 132 | rfc3339-validator==0.1.4 133 | rfc3986-validator==0.1.1 134 | roaring-landmask==0.9.1 135 | rpds-py==0.20.0 136 | s3transfer==0.10.2 137 | scipy==1.14.1 138 | semver==3.0.2 139 | Send2Trash==1.8.3 140 | setuptools==75.1.0 141 | shapely==2.0.6 142 | six==1.16.0 143 | sniffio==1.3.1 144 | soupsieve==2.6 145 | stack-data==0.6.3 146 | terminado==0.18.1 147 | tinycss2==1.3.0 148 | toolz==0.12.1 149 | tornado==6.4.1 150 | tqdm==4.66.5 151 | traitlets==5.14.3 152 | trajan==0.6.3 153 | types-python-dateutil==2.9.0.20240906 154 | typing_extensions==4.12.2 155 | tzdata==2024.2 156 | uri-template==1.3.0 157 | urllib3==2.2.3 158 | utm==0.7.0 159 | watchdog==5.0.3 160 | wcwidth==0.2.13 161 | webcolors==24.8.0 162 | webencodings==0.5.1 163 | websocket-client==1.8.0 164 | xarray==2024.9.0 165 | xhistogram==0.3.2 166 | 
yarl==1.13.1 167 | zarr==2.18.3 168 | rsutils -------------------------------------------------------------------------------- /src/test_trained_cov_mlp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from recorder import PygameRecord 4 | from DSSE import CoverageDroneSwarmSearch 5 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper, AllFlattenWrapper 6 | import ray 7 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv 8 | from ray.rllib.models import ModelCatalog 9 | from ray.tune.registry import register_env 10 | from ray.rllib.algorithms.ppo import PPO 11 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 12 | import torch.nn as nn 13 | import argparse 14 | import torch 15 | 16 | 17 | 18 | argparser = argparse.ArgumentParser() 19 | argparser.add_argument("--checkpoint", type=str, required=True) 20 | argparser.add_argument("--see", action="store_true", default=False) 21 | args = argparser.parse_args() 22 | 23 | 24 | def env_creator(_): 25 | print("-------------------------- ENV CREATOR --------------------------") 26 | N_AGENTS = 2 27 | render_mode = "human" if args.see else "ansi" 28 | # 6 hours of simulation, 600 radius 29 | env = CoverageDroneSwarmSearch( 30 | timestep_limit=200, drone_amount=N_AGENTS, prob_matrix_path="min_matrix.npy", render_mode=render_mode 31 | ) 32 | env = AllFlattenWrapper(env) 33 | grid_size = env.grid_size 34 | print("Grid size: ", grid_size) 35 | positions = [ 36 | (0, grid_size // 2), 37 | (grid_size - 1, grid_size // 2), 38 | ] 39 | env = RetainDronePosWrapper(env, positions) 40 | return env 41 | def position_on_diagonal(grid_size, drone_amount): 42 | positions = [] 43 | center = grid_size // 2 44 | for i in range(-drone_amount // 2, drone_amount // 2): 45 | positions.append((center + i, center + i)) 46 | return positions 47 | 48 | env = env_creator(None) 49 | register_env("DSSE_Coverage", lambda config: ParallelPettingZooEnv(env_creator(config))) 50 | ray.init() 51 | 52 | def print_mean(values, name): 53 | print(f"Mean of {name}: ", sum(values) / len(values)) 54 | 55 | checkpoint_path = args.checkpoint 56 | PPOagent = PPO.from_checkpoint(checkpoint_path) 57 | 58 | reward_sum = 0 59 | i = 0 60 | 61 | if args.see: 62 | obs, info = env.reset() 63 | with PygameRecord("test_trained.gif", 5) as rec: 64 | while env.agents: 65 | actions = {} 66 | for k, v in obs.items(): 67 | actions[k] = PPOagent.compute_single_action(v, explore=False) 68 | # print(v) 69 | # action = PPOagent.compute_actions(obs) 70 | obs, rw, term, trunc, info = env.step(actions) 71 | reward_sum += sum(rw.values()) 72 | i += 1 73 | rec.add_frame() 74 | print(info) 75 | else: 76 | rewards = [] 77 | cov_rate = [] 78 | steps_needed = [] 79 | repeated_cov = [] 80 | 81 | N_EVALS = 500 82 | for _ in range(N_EVALS): 83 | i = 0 84 | obs, info = env.reset() 85 | reward_sum = 0 86 | while env.agents: 87 | actions = {} 88 | for k, v in obs.items(): 89 | actions[k] = PPOagent.compute_single_action(v, explore=False) 90 | obs, rw, term, trunc, info = env.step(actions) 91 | reward_sum += sum(rw.values()) 92 | i += 1 93 | rewards.append(reward_sum) 94 | steps_needed.append(i) 95 | cov_rate.append(info["drone0"]["coverage_rate"]) 96 | repeated_cov.append(info["drone0"]["repeated_coverage"]) 97 | 98 | print_mean(rewards, "rewards") 99 | print_mean(steps_needed, "steps needed") 100 | print_mean(cov_rate, "coverage rate") 101 | print_mean(repeated_cov, "repeated coverage") 102 
| 103 | print("Total reward: ", reward_sum) 104 | print("Total steps: ", i) 105 | print("Found: ", info) 106 | env.close() 107 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ -------------------------------------------------------------------------------- /imgs/drone.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/train_ppo_cnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from DSSE import DroneSwarmSearch 4 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper 5 | import ray 6 | from ray import tune 7 | from ray.rllib.algorithms.ppo import PPOConfig 8 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv 9 | from ray.rllib.models import ModelCatalog 10 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 11 | from ray.tune.registry import register_env 12 | from torch import nn 13 | import torch 14 | 15 | 16 | class CNNModel(TorchModelV2, nn.Module): 17 | def __init__( 18 | self, 19 | obs_space, 20 | act_space, 21 | num_outputs, 22 | model_config, 23 | name, 24 | **kw, 25 | ): 26 | print("OBSSPACE: ", obs_space) 27 | TorchModelV2.__init__( 28 | self, obs_space, act_space, num_outputs, model_config, name, **kw 29 | ) 30 | nn.Module.__init__(self) 31 | 32 | flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[1] - 7 - 3) 33 | self.cnn = nn.Sequential( 34 | nn.Conv2d( 35 | in_channels=1, 36 | out_channels=16, 37 | kernel_size=(8, 8), 38 | stride=(1, 1), 39 | ), 40 | nn.Tanh(), 41 | nn.Conv2d( 42 | in_channels=16, 43 | out_channels=32, 44 | kernel_size=(4, 4), 45 | stride=(1, 1), 46 | ), 47 | nn.Tanh(), 48 | nn.Flatten(), 49 | nn.Linear(flatten_size, 256), 50 | nn.Tanh(), 51 | ) 52 | 53 | self.linear = nn.Sequential( 54 | nn.Linear(obs_space[0].shape[0], 512), 55 | nn.Tanh(), 56 | nn.Linear(512, 256), 57 | nn.Tanh(), 58 | ) 59 | 60 | self.join = nn.Sequential( 61 | nn.Linear(256 * 2, 256), 62 | nn.Tanh(), 63 | ) 64 | 65 | self.policy_fn = nn.Linear(256, num_outputs) 66 | self.value_fn = nn.Linear(256, 1) 67 | 68 | def forward(self, input_dict, state, seq_lens): 69 | input_positions = input_dict["obs"][0].float() 70 | input_matrix = input_dict["obs"][1].float() 71 | 72 | input_matrix = input_matrix.unsqueeze(1) 73 | cnn_out = self.cnn(input_matrix) 74 | linear_out = self.linear(input_positions) 75 | 76 | value_input = torch.cat((cnn_out, linear_out), 
dim=1) 77 | value_input = self.join(value_input) 78 | 79 | self._value_out = self.value_fn(value_input) 80 | return self.policy_fn(value_input), state 81 | 82 | def value_function(self): 83 | return self._value_out.flatten() 84 | 85 | 86 | def env_creator(args): 87 | env = DroneSwarmSearch( 88 | drone_amount=4, 89 | grid_size=40, 90 | dispersion_inc=0.1, 91 | person_initial_position=(20, 20), 92 | ) 93 | positions = [ 94 | (20, 0), 95 | (20, 39), 96 | (0, 20), 97 | (39, 20), 98 | ] 99 | env = AllPositionsWrapper(env) 100 | env = RetainDronePosWrapper(env, positions) 101 | return env 102 | 103 | 104 | if __name__ == "__main__": 105 | ray.init() 106 | 107 | env_name = "DSSE" 108 | 109 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config))) 110 | ModelCatalog.register_custom_model("CNNModel", CNNModel) 111 | 112 | config = ( 113 | PPOConfig() 114 | .environment(env=env_name) 115 | .rollouts(num_rollout_workers=6, rollout_fragment_length="auto") 116 | .training( 117 | train_batch_size=8192, 118 | lr=1e-5, 119 | gamma=0.9999999, 120 | lambda_=0.9, 121 | use_gae=True, 122 | # clip_param=0.3, 123 | # grad_clip=None, 124 | entropy_coeff=0.01, 125 | # vf_loss_coeff=0.25, 126 | # vf_clip_param=10, 127 | sgd_minibatch_size=300, 128 | num_sgd_iter=10, 129 | model={ 130 | "custom_model": "CNNModel", 131 | "_disable_preprocessor_api": True, 132 | }, 133 | ) 134 | .experimental(_disable_preprocessor_api=True) 135 | .debugging(log_level="ERROR") 136 | .framework(framework="torch") 137 | .resources(num_gpus=1) 138 | ) 139 | 140 | curr_path = pathlib.Path().resolve() 141 | tune.run( 142 | "PPO", 143 | name="PPO", 144 | stop={"timesteps_total": 20_000_000 if not os.environ.get("CI") else 50000, "episode_reward_mean": 1.75}, 145 | checkpoint_freq=10, 146 | storage_path=f"{curr_path}/ray_res/" + env_name, 147 | config=config.to_dict(), 148 | ) 149 | -------------------------------------------------------------------------------- /src/train_ppo_encoded.py: -------------------------------------------------------------------------------- 1 | import os 2 | from DSSE import DroneSwarmSearch 3 | from DSSE.environment.wrappers.matrix_encode_wrapper import MatrixEncodeWrapper 4 | import ray 5 | from ray import tune 6 | from ray.rllib.algorithms.ppo import PPOConfig 7 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv 8 | from ray.rllib.models import ModelCatalog 9 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 10 | from ray.tune.registry import register_env 11 | import torch 12 | from torch import nn 13 | 14 | 15 | class MLPModel(TorchModelV2, nn.Module): 16 | def __init__( 17 | self, 18 | obs_space, 19 | act_space, 20 | num_outputs, 21 | model_config, 22 | name, 23 | **kw, 24 | ): 25 | print( 26 | "______________________________ MLPModel ________________________________" 27 | ) 28 | print("OBSSPACE: ", obs_space) 29 | print("ACTSPACE: ", act_space) 30 | print("NUMOUTPUTS: ", num_outputs) 31 | TorchModelV2.__init__( 32 | self, obs_space, act_space, num_outputs, model_config, name, **kw 33 | ) 34 | nn.Module.__init__(self) 35 | self.conv1 = nn.Sequential( 36 | nn.Conv2d( 37 | in_channels=1, 38 | out_channels=16, 39 | kernel_size=(8, 8), 40 | stride=(1, 1), 41 | ), 42 | nn.ReLU(), 43 | # nn.MaxPool2d(kernel_size=2), 44 | nn.Conv2d( 45 | in_channels=16, 46 | out_channels=32, 47 | kernel_size=(4, 4), 48 | stride=(1, 1) 49 | ), 50 | nn.ReLU(), 51 | nn.Conv2d( 52 | in_channels=32, 53 | out_channels=64, 54 | kernel_size=(3, 3), 55 | stride=(1, 1) 56 | ), 57 
| # nn.MaxPool2d(kernel_size=2), 58 | nn.Flatten(), 59 | ) 60 | grid_size = obs_space.shape[0] 61 | # Fully connected layers 62 | # F: (((W - K + 2P)/S) + 1) 63 | first_conv_output_size = (grid_size - 8) + 1 64 | second_conv_output_size = (first_conv_output_size - 4) + 1 65 | third_conv_output_size = (second_conv_output_size - 3) + 1 66 | self.fc1_input_size = third_conv_output_size * third_conv_output_size * 64 67 | # Apply a DENSE layer to the flattened CNN2 output 68 | self.fc1 = nn.Linear(self.fc1_input_size, 512) 69 | self.policy_fn = nn.Linear(512, num_outputs) 70 | self.value_fn = nn.Linear(512, 1) 71 | 72 | def forward(self, input_dict, state, seq_lens): 73 | input_ = input_dict["obs"] 74 | # Convert dims 75 | input_ = input_.unsqueeze(1) 76 | model_out = self.conv1(input_) 77 | model_out = self.fc1(model_out) 78 | self._value_out = self.value_fn(model_out) 79 | return self.policy_fn(model_out), state 80 | 81 | def value_function(self): 82 | return self._value_out.flatten() 83 | 84 | 85 | def env_creator(args): 86 | env = DroneSwarmSearch( 87 | drone_amount=4, 88 | grid_size=20, 89 | dispersion_inc=0.1, 90 | person_initial_position=(10, 10), 91 | ) 92 | env = MatrixEncodeWrapper(env) 93 | return env 94 | 95 | 96 | if __name__ == "__main__": 97 | ray.init() 98 | 99 | env_name = "DSSE" 100 | 101 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config))) 102 | ModelCatalog.register_custom_model("MLPModel", MLPModel) 103 | 104 | config = ( 105 | PPOConfig() 106 | .environment(env=env_name) 107 | .rollouts(num_rollout_workers=5, rollout_fragment_length='auto') 108 | .training( 109 | train_batch_size=512, 110 | lr=2e-5, 111 | gamma=0.99999, 112 | lambda_=0.9, 113 | use_gae=True, 114 | clip_param=0.3, 115 | grad_clip=None, 116 | entropy_coeff=0.1, 117 | vf_loss_coeff=0.25, 118 | vf_clip_param=420, 119 | sgd_minibatch_size=64, 120 | num_sgd_iter=10, 121 | model={ 122 | "custom_model": "MLPModel", 123 | "_disable_preprocessor_api": True, 124 | }, 125 | ) 126 | .debugging(log_level="ERROR") 127 | .framework(framework="torch") 128 | .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "1"))) 129 | ) 130 | config["_disable_preprocessor_api"] = False 131 | 132 | tune.run( 133 | "PPO", 134 | name="PPO", 135 | stop={"timesteps_total": 5000000 if not os.environ.get("CI") else 50000}, 136 | checkpoint_freq=10, 137 | # storage_path="ray_res/" + env_name, 138 | config=config.to_dict(), 139 | ) 140 | -------------------------------------------------------------------------------- /src/train_ppo_cnn_comm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from DSSE import DroneSwarmSearch 4 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper 5 | from DSSE.environment.wrappers.communication_wrapper import CommunicationWrapper 6 | import ray 7 | from ray import tune 8 | from ray.rllib.algorithms.ppo import PPOConfig 9 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv 10 | from ray.rllib.models import ModelCatalog 11 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 12 | from ray.tune.registry import register_env 13 | from torch import nn 14 | import torch 15 | 16 | 17 | class CNNModel(TorchModelV2, nn.Module): 18 | def __init__( 19 | self, 20 | obs_space, 21 | act_space, 22 | num_outputs, 23 | model_config, 24 | name, 25 | **kw, 26 | ): 27 | print("OBSSPACE: ", obs_space) 28 | TorchModelV2.__init__( 29 | self, obs_space, act_space, 
num_outputs, model_config, name, **kw 30 | ) 31 | nn.Module.__init__(self) 32 | 33 | flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[1] - 7 - 3) 34 | self.cnn = nn.Sequential( 35 | nn.Conv2d( 36 | in_channels=1, 37 | out_channels=16, 38 | kernel_size=(8, 8), 39 | stride=(1, 1), 40 | ), 41 | nn.Tanh(), 42 | nn.Conv2d( 43 | in_channels=16, 44 | out_channels=32, 45 | kernel_size=(4, 4), 46 | stride=(1, 1), 47 | ), 48 | nn.Tanh(), 49 | nn.Flatten(), 50 | nn.Linear(flatten_size, 256), 51 | nn.Tanh(), 52 | ) 53 | 54 | self.linear = nn.Sequential( 55 | nn.Linear(obs_space[0].shape[0], 512), 56 | nn.Tanh(), 57 | nn.Linear(512, 256), 58 | nn.Tanh(), 59 | ) 60 | 61 | self.join = nn.Sequential( 62 | nn.Linear(256 * 2, 256), 63 | nn.Tanh(), 64 | ) 65 | 66 | self.policy_fn = nn.Linear(256, num_outputs) 67 | self.value_fn = nn.Linear(256, 1) 68 | 69 | def forward(self, input_dict, state, seq_lens): 70 | input_positions = input_dict["obs"][0].float() 71 | input_matrix = input_dict["obs"][1].float() 72 | 73 | input_matrix = input_matrix.unsqueeze(1) 74 | cnn_out = self.cnn(input_matrix) 75 | linear_out = self.linear(input_positions) 76 | 77 | value_input = torch.cat((cnn_out, linear_out), dim=1) 78 | value_input = self.join(value_input) 79 | 80 | self._value_out = self.value_fn(value_input) 81 | return self.policy_fn(value_input), state 82 | 83 | def value_function(self): 84 | return self._value_out.flatten() 85 | 86 | 87 | def env_creator(args): 88 | env = DroneSwarmSearch( 89 | drone_amount=4, 90 | grid_size=40, 91 | dispersion_inc=0.1, 92 | person_initial_position=(20, 20), 93 | ) 94 | positions = [ 95 | (20, 0), 96 | (20, 39), 97 | (0, 20), 98 | (39, 20), 99 | ] 100 | env = AllPositionsWrapper(env) 101 | env = CommunicationWrapper(env, n_steps=12) 102 | env = RetainDronePosWrapper(env, positions) 103 | return env 104 | 105 | 106 | if __name__ == "__main__": 107 | ray.init() 108 | 109 | env_name = "DSSE" 110 | 111 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config))) 112 | ModelCatalog.register_custom_model("CNNModel", CNNModel) 113 | 114 | config = ( 115 | PPOConfig() 116 | .environment(env=env_name) 117 | .rollouts(num_rollout_workers=6, rollout_fragment_length="auto", num_envs_per_worker=2) 118 | .training( 119 | train_batch_size=8192, 120 | lr=1e-5, 121 | gamma=0.9999999, 122 | lambda_=0.9, 123 | use_gae=True, 124 | # clip_param=0.3, 125 | # grad_clip=None, 126 | entropy_coeff=0.01, 127 | # vf_loss_coeff=0.25, 128 | # vf_clip_param=10, 129 | sgd_minibatch_size=300, 130 | num_sgd_iter=10, 131 | model={ 132 | "custom_model": "CNNModel", 133 | "_disable_preprocessor_api": True, 134 | }, 135 | ) 136 | .experimental(_disable_preprocessor_api=True) 137 | .debugging(log_level="ERROR") 138 | .framework(framework="torch") 139 | .resources(num_gpus=1) 140 | ) 141 | 142 | curr_path = pathlib.Path().resolve() 143 | tune.run( 144 | "PPO", 145 | name="PPO_COMM_WRAPPER", 146 | stop={"timesteps_total": 20_000_000 if not os.environ.get("CI") else 50000, "episode_reward_mean": 1.75}, 147 | checkpoint_freq=10, 148 | storage_path=f"{curr_path}/ray_res/" + env_name, 149 | config=config.to_dict(), 150 | ) 151 | -------------------------------------------------------------------------------- /src/train_dqn_cnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from drone_swarm_search.DSSE import DroneSwarmSearch 4 | from drone_swarm_search.DSSE.environment.wrappers import AllPositionsWrapper, 
RetainDronePosWrapper, TopNProbsWrapper 5 | # from DSSE.environment.wrappers import AllPositionsWrapper 6 | import ray 7 | from ray import tune 8 | from ray.rllib.algorithms.dqn import DQNConfig 9 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv 10 | from ray.rllib.models import ModelCatalog 11 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 12 | from ray.tune.registry import register_env 13 | import torch 14 | from torch import nn 15 | import random 16 | 17 | class CNNModel(TorchModelV2, nn.Module): 18 | def __init__( 19 | self, 20 | obs_space, 21 | act_space, 22 | num_outputs, 23 | model_config, 24 | name, 25 | **kw, 26 | ): 27 | print("OBSSPACE: ", obs_space) 28 | TorchModelV2.__init__( 29 | self, obs_space, act_space, num_outputs, model_config, name, **kw 30 | ) 31 | nn.Module.__init__(self) 32 | 33 | flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[1] - 7 - 3) 34 | self.cnn = nn.Sequential( 35 | nn.Conv2d( 36 | in_channels=1, 37 | out_channels=16, 38 | kernel_size=(8, 8), 39 | stride=(1, 1), 40 | ), 41 | nn.Tanh(), 42 | nn.Conv2d( 43 | in_channels=16, 44 | out_channels=32, 45 | kernel_size=(4, 4), 46 | stride=(1, 1), 47 | ), 48 | nn.Tanh(), 49 | nn.Flatten(), 50 | nn.Linear(flatten_size, 256), 51 | nn.Tanh(), 52 | ) 53 | 54 | self.linear = nn.Sequential( 55 | nn.Linear(obs_space[0].shape[0], 512), 56 | nn.Tanh(), 57 | nn.Linear(512, 256), 58 | nn.Tanh(), 59 | ) 60 | 61 | self.join = nn.Sequential( 62 | nn.Linear(256 * 2, 256), 63 | nn.Tanh(), 64 | ) 65 | 66 | self.policy_fn = nn.Linear(256, num_outputs) 67 | self.value_fn = nn.Linear(256, 1) 68 | 69 | def forward(self, input_dict, state, seq_lens): 70 | input_positions = input_dict["obs"][0].float() 71 | input_matrix = input_dict["obs"][1].float() 72 | 73 | input_matrix = input_matrix.unsqueeze(1) 74 | cnn_out = self.cnn(input_matrix) 75 | linear_out = self.linear(input_positions) 76 | 77 | value_input = torch.cat((cnn_out, linear_out), dim=1) 78 | value_input = self.join(value_input) 79 | 80 | self._value_out = self.value_fn(value_input) 81 | return self.policy_fn(value_input), state 82 | 83 | def value_function(self): 84 | return self._value_out.flatten() 85 | 86 | def env_creator(args): 87 | env = DroneSwarmSearch( 88 | drone_amount=4, 89 | grid_size=40, 90 | dispersion_inc=0.1, 91 | person_initial_position=(20, 20), 92 | ) 93 | positions = [ 94 | (20, 0), 95 | (20, 39), 96 | (0, 20), 97 | (39, 20), 98 | ] 99 | env = AllPositionsWrapper(env) 100 | env = RetainDronePosWrapper(env, positions) 101 | return env 102 | 103 | 104 | if __name__ == "__main__": 105 | ray.init() 106 | 107 | env_name = "DSSE" 108 | 109 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config))) 110 | ModelCatalog.register_custom_model("CNNModel", CNNModel) 111 | 112 | natural_value = 512/(14*20*1) 113 | config = ( 114 | DQNConfig() 115 | .exploration( 116 | exploration_config={ 117 | "type": "EpsilonGreedy", 118 | "initial_epsilon": 1.0, 119 | "final_epsilon": 0.15, 120 | "epsilon_timesteps": 400_000, 121 | } 122 | ) 123 | .environment(env=env_name) 124 | .rollouts(num_rollout_workers=14, rollout_fragment_length=20) 125 | .framework("torch") 126 | .debugging(log_level="ERROR") 127 | .resources(num_gpus=1) 128 | .experimental(_disable_preprocessor_api=True) 129 | .training( 130 | lr=1e-4, 131 | gamma=0.9999999, 132 | tau=0.01, 133 | train_batch_size=512, 134 | model={ 135 | "custom_model": "CNNModel", 136 | "_disable_preprocessor_api": True, 137 | }, 138 | 
target_network_update_freq=500, 139 | double_q=False, 140 | training_intensity=natural_value, 141 | v_min=0, 142 | v_max=2, 143 | ) 144 | ) 145 | 146 | curr_path = pathlib.Path().resolve() 147 | tune.run( 148 | "DQN", 149 | name="DQN_DSSE", 150 | stop={"timesteps_total": 100_000_000}, 151 | checkpoint_freq=200, 152 | storage_path=f"{curr_path}/ray_res/" + env_name, 153 | config=config.to_dict(), 154 | ) 155 | 156 | # Finalize Ray to free up resources 157 | ray.shutdown() 158 | -------------------------------------------------------------------------------- /src/train_ppo_multi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from DSSE import DroneSwarmSearch 4 | from DSSE.environment.wrappers import AllPositionsWrapper 5 | import ray 6 | from ray import tune 7 | from ray.rllib.algorithms.ppo import PPOConfig 8 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv 9 | from ray.rllib.models import ModelCatalog 10 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 11 | from ray.tune.registry import register_env 12 | import torch 13 | from torch import nn 14 | 15 | 16 | class MLPModel(TorchModelV2, nn.Module): 17 | def __init__( 18 | self, 19 | obs_space, 20 | act_space, 21 | num_outputs, 22 | model_config, 23 | name, 24 | **kw, 25 | ): 26 | print("OBSSPACE: ", obs_space) 27 | TorchModelV2.__init__( 28 | self, obs_space, act_space, num_outputs, model_config, name, **kw 29 | ) 30 | nn.Module.__init__(self) 31 | 32 | # F: (((W - K + 2P)/S) + 1) 33 | grid_size = obs_space[1].shape[0] 34 | first_conv_output_size = (grid_size - 8) + 1 35 | second_conv_output_size = (first_conv_output_size - 4) + 1 36 | third_conv_output_size = (second_conv_output_size - 3) + 1 37 | self.fc1_input_size = third_conv_output_size * third_conv_output_size * 64 38 | 39 | self.conv1 = nn.Sequential( 40 | nn.Conv2d( 41 | in_channels=1, 42 | out_channels=16, 43 | kernel_size=(8, 8), 44 | stride=(1, 1), 45 | ), 46 | nn.ReLU(), 47 | # nn.MaxPool2d(kernel_size=2), 48 | nn.Conv2d( 49 | in_channels=16, 50 | out_channels=32, 51 | kernel_size=(4, 4), 52 | stride=(1, 1) 53 | ), 54 | nn.ReLU(), 55 | nn.Conv2d( 56 | in_channels=32, 57 | out_channels=64, 58 | kernel_size=(3, 3), 59 | stride=(1, 1) 60 | ), 61 | # nn.MaxPool2d(kernel_size=2), 62 | nn.Flatten(), 63 | nn.Linear(self.fc1_input_size, 512), 64 | nn.ReLU(), 65 | nn.LayerNorm(512), 66 | ) 67 | 68 | self.fc_scalar = nn.Sequential( 69 | nn.Linear(obs_space[0].shape[0], 128), 70 | nn.ReLU(), 71 | nn.Linear(128, 512), 72 | nn.ReLU(), 73 | nn.LayerNorm(512), 74 | ) 75 | 76 | self.unifier = nn.Sequential( 77 | nn.Linear(512 * 2, 512), 78 | nn.ReLU(), 79 | nn.LayerNorm(512), 80 | ) 81 | self.policy_fn = nn.Linear(512, num_outputs) 82 | self.value_fn = nn.Linear(512, 1) 83 | 84 | def forward(self, input_dict, state, seq_lens): 85 | input_ = input_dict["obs"] 86 | # Convert dims 87 | input_cnn = input_[1].unsqueeze(1) 88 | model_out = self.conv1(input_cnn) 89 | 90 | scalar_input = input_[0].float() 91 | scalar_out = self.fc_scalar(scalar_input) 92 | 93 | value_input = torch.cat((model_out, scalar_out), -1) 94 | value_input = self.unifier(value_input) 95 | 96 | self._value_out = self.value_fn(value_input) 97 | return self.policy_fn(value_input), state 98 | 99 | def value_function(self): 100 | return self._value_out.flatten() 101 | 102 | 103 | def env_creator(args): 104 | env = DroneSwarmSearch( 105 | drone_amount=4, 106 | grid_size=20, 107 | dispersion_inc=0.1, 108 | 
person_initial_position=(10, 10), 109 | ) 110 | env = AllPositionsWrapper(env) 111 | return env 112 | 113 | 114 | if __name__ == "__main__": 115 | ray.init() 116 | 117 | env_name = "DSSE" 118 | 119 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config))) 120 | ModelCatalog.register_custom_model("MLPModel", MLPModel) 121 | 122 | config = ( 123 | PPOConfig() 124 | .environment(env=env_name) 125 | .rollouts(num_rollout_workers=4, rollout_fragment_length=128) 126 | .training( 127 | train_batch_size=512, 128 | lr=4e-5, 129 | gamma=0.99999, 130 | lambda_=0.9, 131 | use_gae=True, 132 | clip_param=0.4, 133 | grad_clip=None, 134 | entropy_coeff=0.1, 135 | vf_loss_coeff=0.25, 136 | vf_clip_param=420, 137 | sgd_minibatch_size=64, 138 | num_sgd_iter=10, 139 | model={ 140 | "custom_model": "MLPModel", 141 | "_disable_preprocessor_api": True, 142 | }, 143 | ) 144 | .experimental(_disable_preprocessor_api=True) 145 | .debugging(log_level="ERROR") 146 | .framework(framework="torch") 147 | .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "1"))) 148 | ) 149 | 150 | curr_path = pathlib.Path().resolve() 151 | tune.run( 152 | "PPO", 153 | name="PPO", 154 | stop={"timesteps_total": 5000000 if not os.environ.get("CI") else 50000}, 155 | checkpoint_freq=10, 156 | # local_dir="ray_results/" + env_name, 157 | storage_path=f"{curr_path}/ray_res/" + env_name, 158 | config=config.to_dict(), 159 | ) 160 | -------------------------------------------------------------------------------- /src/train_dqn_multi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from DSSE import DroneSwarmSearch 4 | from DSSE.environment.wrappers import AllPositionsWrapper, RetainDronePosWrapper 5 | import ray 6 | from ray import air 7 | from ray import tune 8 | from ray.rllib.algorithms.dqn.dqn import DQNConfig 9 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv 10 | from ray.rllib.models import ModelCatalog 11 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 12 | from ray.tune.registry import register_env 13 | import torch 14 | from torch import nn 15 | 16 | 17 | class MLPModel(TorchModelV2, nn.Module): 18 | def __init__( 19 | self, 20 | obs_space, 21 | act_space, 22 | num_outputs, 23 | model_config, 24 | name, 25 | **kw, 26 | ): 27 | print("OBSSPACE: ", obs_space) 28 | TorchModelV2.__init__( 29 | self, obs_space, act_space, num_outputs, model_config, name, **kw 30 | ) 31 | nn.Module.__init__(self) 32 | 33 | # F: (((W - K + 2P)/S) + 1) 34 | grid_size = obs_space[1].shape[0] 35 | first_conv_output_size = (grid_size - 8) + 1 36 | second_conv_output_size = (first_conv_output_size - 4) + 1 37 | third_conv_output_size = (second_conv_output_size - 3) + 1 38 | self.fc1_input_size = third_conv_output_size * third_conv_output_size * 64 39 | 40 | self.conv1 = nn.Sequential( 41 | nn.Conv2d( 42 | in_channels=1, 43 | out_channels=16, 44 | kernel_size=(8, 8), 45 | stride=(1, 1), 46 | ), 47 | nn.ReLU(), 48 | # nn.MaxPool2d(kernel_size=2), 49 | nn.Conv2d( 50 | in_channels=16, 51 | out_channels=32, 52 | kernel_size=(4, 4), 53 | stride=(1, 1) 54 | ), 55 | nn.ReLU(), 56 | nn.Conv2d( 57 | in_channels=32, 58 | out_channels=64, 59 | kernel_size=(3, 3), 60 | stride=(1, 1) 61 | ), 62 | # nn.MaxPool2d(kernel_size=2), 63 | nn.Flatten(), 64 | nn.Linear(self.fc1_input_size, 512), 65 | nn.ReLU(), 66 | nn.LayerNorm(512), 67 | ) 68 | 69 | self.fc_scalar = nn.Sequential( 70 | nn.Linear(obs_space[0].shape[0], 128), 71 | nn.ReLU(), 
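            # Lift the scalar (position) features to 512 dims so they can be concatenated with the 512-dim CNN embedding in self.unifier.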
72 | nn.Linear(128, 512), 73 | nn.ReLU(), 74 | nn.LayerNorm(512), 75 | ) 76 | 77 | self.unifier = nn.Sequential( 78 | nn.Linear(512 * 2, 512), 79 | nn.ReLU(), 80 | nn.LayerNorm(512), 81 | ) 82 | self.policy_fn = nn.Linear(512, num_outputs) 83 | 84 | def forward(self, input_dict, state, seq_lens): 85 | input_ = input_dict["obs"] 86 | # Convert dims 87 | input_cnn = input_[1].unsqueeze(1) 88 | model_out = self.conv1(input_cnn) 89 | 90 | scalar_input = input_[0].float() 91 | scalar_out = self.fc_scalar(scalar_input) 92 | 93 | value_input = torch.cat((model_out, scalar_out), -1) 94 | value_input = self.unifier(value_input) 95 | 96 | return self.policy_fn(value_input), state 97 | 98 | 99 | def env_creator(args): 100 | env = DroneSwarmSearch( 101 | drone_amount=4, 102 | grid_size=20, 103 | dispersion_inc=0.1, 104 | person_initial_position=(10, 10), 105 | ) 106 | env = AllPositionsWrapper(env) 107 | env = RetainDronePosWrapper(env, [(0, 0), (0, 19), (19, 0), (19, 19)]) 108 | return env 109 | 110 | 111 | if __name__ == "__main__": 112 | ray.init() 113 | 114 | env_name = "DSSE" 115 | 116 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config))) 117 | ModelCatalog.register_custom_model("MLPModel", MLPModel) 118 | 119 | config = DQNConfig() 120 | config = config.environment(env=env_name) 121 | config = config.rollouts(num_rollout_workers=6, rollout_fragment_length="auto", num_envs_per_worker=2) 122 | config = config.training( 123 | train_batch_size=512, 124 | grad_clip=None, 125 | target_network_update_freq=1, 126 | tau=0.005, 127 | gamma=0.99999, 128 | n_step=1, 129 | double_q=True, 130 | dueling=False, 131 | model={"custom_model": "MLPModel", "_disable_preprocessor_api": True}, 132 | v_min=-800, 133 | v_max=800, 134 | ) 135 | config = config.exploration( 136 | exploration_config={ 137 | "type": "EpsilonGreedy", 138 | "initial_epsilon": 1.0, 139 | "final_epsilon": 0.05, 140 | "epsilon_timesteps": 350_000, 141 | } 142 | ) 143 | config = config.debugging(log_level="ERROR") 144 | config = config.framework(framework="torch") 145 | config = config.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "1"))) 146 | config = config.experimental(_disable_preprocessor_api=True) 147 | 148 | curr_path = pathlib.Path().resolve() 149 | run_config = air.RunConfig( 150 | stop={"timesteps_total": 10_000_000 if not os.environ.get("CI") else 50000}, 151 | storage_path=f"{curr_path}/ray_res/" + env_name, 152 | checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10), 153 | ) 154 | tune.Tuner( 155 | "DQN", 156 | run_config=run_config, 157 | param_space=config.to_dict() 158 | ).fit() 159 | -------------------------------------------------------------------------------- /src/train_descentralized_ppo_cnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from DSSE import DroneSwarmSearch 4 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper 5 | import ray 6 | from ray import tune 7 | from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec 8 | from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec 9 | from ray.rllib.algorithms.ppo import PPOConfig 10 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv 11 | from ray.rllib.models import ModelCatalog 12 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 13 | from ray.tune.registry import register_env 14 | import torch 15 | from torch import nn 16 | 17 | 18 | class 
CNNModel(TorchModelV2, nn.Module): 19 | def __init__( 20 | self, 21 | obs_space, 22 | act_space, 23 | num_outputs, 24 | model_config, 25 | name, 26 | **kw, 27 | ): 28 | print("OBSSPACE: ", obs_space) 29 | TorchModelV2.__init__( 30 | self, obs_space, act_space, num_outputs, model_config, name, **kw 31 | ) 32 | nn.Module.__init__(self) 33 | 34 | flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[1] - 7 - 3) 35 | self.cnn = nn.Sequential( 36 | nn.Conv2d( 37 | in_channels=1, 38 | out_channels=16, 39 | kernel_size=(8, 8), 40 | stride=(1, 1), 41 | ), 42 | nn.Tanh(), 43 | nn.Conv2d( 44 | in_channels=16, 45 | out_channels=32, 46 | kernel_size=(4, 4), 47 | stride=(1, 1), 48 | ), 49 | nn.Tanh(), 50 | nn.Flatten(), 51 | nn.Linear(flatten_size, 256), 52 | nn.Tanh(), 53 | ) 54 | 55 | self.linear = nn.Sequential( 56 | nn.Linear(obs_space[0].shape[0], 512), 57 | nn.Tanh(), 58 | nn.Linear(512, 256), 59 | nn.Tanh(), 60 | ) 61 | 62 | self.join = nn.Sequential( 63 | nn.Linear(256 * 2, 256), 64 | nn.Tanh(), 65 | ) 66 | 67 | self.policy_fn = nn.Linear(256, num_outputs) 68 | self.value_fn = nn.Linear(256, 1) 69 | 70 | def forward(self, input_dict, state, seq_lens): 71 | input_positions = input_dict["obs"][0].float() 72 | input_matrix = input_dict["obs"][1].float() 73 | 74 | input_matrix = input_matrix.unsqueeze(1) 75 | cnn_out = self.cnn(input_matrix) 76 | linear_out = self.linear(input_positions) 77 | 78 | value_input = torch.cat((cnn_out, linear_out), dim=1) 79 | value_input = self.join(value_input) 80 | 81 | self._value_out = self.value_fn(value_input) 82 | return self.policy_fn(value_input), state 83 | 84 | def value_function(self): 85 | return self._value_out.flatten() 86 | 87 | 88 | 89 | def env_creator(args): 90 | env = DroneSwarmSearch( 91 | drone_amount=4, 92 | grid_size=40, 93 | dispersion_inc=0.1, 94 | person_initial_position=(20, 20), 95 | ) 96 | positions = [ 97 | (20, 0), 98 | (20, 39), 99 | (0, 20), 100 | (39, 20), 101 | ] 102 | env = AllPositionsWrapper(env) 103 | env = RetainDronePosWrapper(env, positions) 104 | return env 105 | 106 | 107 | if __name__ == "__main__": 108 | ray.init() 109 | env_name = "DSSE" 110 | 111 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config))) 112 | ModelCatalog.register_custom_model("CNNModel", CNNModel) 113 | 114 | # Policies are called just like the agents (exact 1:1 mapping). 115 | num_agents = 4 116 | policies = {f"drone{i}" for i in range(num_agents)} 117 | # policies = {f"drone{i}": (None, obs_space, act_space, {}) for i in range(num_agents)} 118 | 119 | config = ( 120 | PPOConfig() 121 | .environment(env=env_name) 122 | .rollouts(num_rollout_workers=4, rollout_fragment_length="auto") 123 | .multi_agent( 124 | policies=policies, 125 | # Exact 1:1 mapping from AgentID to ModuleID. 
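            # e.g. agent 'drone0' maps to policy 'drone0', so each drone trains its own independent (non-shared) policy.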
126 | policy_mapping_fn=(lambda aid, *args, **kwargs: aid), 127 | ) 128 | .training( 129 | train_batch_size=8192, 130 | lr=1e-5, 131 | gamma=0.9999999, 132 | lambda_=0.9, 133 | use_gae=True, 134 | entropy_coeff=0.01, 135 | sgd_minibatch_size=300, 136 | num_sgd_iter=10, 137 | model={ 138 | "custom_model": "CNNModel", 139 | "_disable_preprocessor_api": True, 140 | }, 141 | ) 142 | .rl_module( 143 | model_config_dict={"vf_share_layers": True}, 144 | rl_module_spec=MultiAgentRLModuleSpec( 145 | module_specs={p: SingleAgentRLModuleSpec() for p in policies}, 146 | ), 147 | ) 148 | .experimental(_disable_preprocessor_api=True) 149 | .debugging(log_level="ERROR") 150 | .framework(framework="torch") 151 | .resources(num_gpus=1) 152 | ) 153 | 154 | curr_path = pathlib.Path().resolve() 155 | tune.run( 156 | "PPO", 157 | name="PPO", 158 | stop={"timesteps_total": 5_000_000}, 159 | checkpoint_freq=30, 160 | keep_checkpoints_num=200, 161 | storage_path=f"{curr_path}/ray_res/" + env_name, 162 | config=config, 163 | ) 164 | -------------------------------------------------------------------------------- /src/greedy_heuristic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from DSSE import Actions 3 | from DSSE import DroneSwarmSearch 4 | from DSSE.environment.wrappers import RetainDronePosWrapper 5 | from recorder import PygameRecord 6 | 7 | 8 | class GreedyAgent: 9 | def __call__(self, obs, agents): 10 | """ 11 | Greedy approach: Rush and search for the greatest prob. 12 | """ 13 | drone_actions = {} 14 | prob_matrix = obs["drone0"][1] 15 | n_drones = len(agents) 16 | 17 | drones_positions = {drone: obs[drone][0] for drone in agents} 18 | # Get n_drones greatest probabilities. 19 | greatest_probs = np.argsort(prob_matrix, axis=None)[-n_drones:] 20 | 21 | for index, drone in enumerate(agents): 22 | greatest_prob = np.unravel_index(greatest_probs[index], prob_matrix.shape) 23 | 24 | drone_obs = obs[drone] 25 | drone_action = self.choose_drone_action(drone_obs[0], greatest_prob) 26 | 27 | new_position = self.get_new_position(drone_obs[0], drone_action) 28 | 29 | # Avoid colision by waiting 1 timestep 30 | if self.drones_colide(drones_positions, new_position): 31 | drone_actions[drone] = Actions.SEARCH.value 32 | else: 33 | drone_actions[drone] = drone_action 34 | drones_positions[drone] = new_position 35 | return drone_actions 36 | 37 | def get_new_position(self, position: tuple, action: int) -> tuple: 38 | match action: 39 | case Actions.LEFT.value: 40 | new_position = (position[0] - 1, position[1]) 41 | case Actions.RIGHT.value: 42 | new_position = (position[0] + 1, position[1]) 43 | case Actions.UP.value: 44 | new_position = (position[0], position[1] - 1) 45 | case Actions.DOWN.value: 46 | new_position = (position[0], position[1] + 1) 47 | case Actions.UP_LEFT.value: 48 | new_position = (position[0] - 1, position[1] - 1) 49 | case Actions.UP_RIGHT.value: 50 | new_position = (position[0] + 1, position[1] - 1) 51 | case Actions.DOWN_LEFT.value: 52 | new_position = (position[0] - 1, position[1] + 1) 53 | case Actions.DOWN_RIGHT.value: 54 | new_position = (position[0] + 1, position[1] + 1) 55 | case _: 56 | new_position = position 57 | return new_position 58 | 59 | def choose_drone_action(self, drone_position: tuple, greatest_prob_position) -> int: 60 | greatest_prob_y, greatest_prob_x = greatest_prob_position 61 | drone_x, drone_y = drone_position 62 | is_x_greater = greatest_prob_x > drone_x 63 | is_y_greater = greatest_prob_y > drone_y 64 | 
is_x_lesser = greatest_prob_x < drone_x 65 | is_y_lesser = greatest_prob_y < drone_y 66 | is_x_equal = greatest_prob_x == drone_x 67 | is_y_equal = greatest_prob_y == drone_y 68 | 69 | if is_x_equal and is_y_equal: 70 | return Actions.SEARCH.value 71 | 72 | if is_x_equal: 73 | return Actions.DOWN.value if is_y_greater else Actions.UP.value 74 | elif is_y_equal: 75 | return Actions.LEFT.value if is_x_lesser else Actions.RIGHT.value 76 | 77 | # Diagonal movement 78 | if is_x_greater and is_y_greater: 79 | return Actions.DOWN_RIGHT.value 80 | elif is_x_greater and is_y_lesser: 81 | return Actions.UP_RIGHT.value 82 | elif is_x_lesser and is_y_greater: 83 | return Actions.DOWN_LEFT.value 84 | elif is_x_lesser and is_y_lesser: 85 | return Actions.UP_LEFT.value 86 | 87 | def drones_collide(self, drones_positions: dict, new_drone_position: tuple) -> bool: 88 | return new_drone_position in drones_positions.values() 89 | 90 | def __repr__(self) -> str: 91 | return "greedy" 92 | 93 | 94 | see = input("Want to see ?? (y/n): ") 95 | see = see.lower() == "y" 96 | def env_creator(_): 97 | env = DroneSwarmSearch( 98 | render_mode="human" if see else "ansi", 99 | render_grid=True, 100 | drone_amount=4, 101 | grid_size=40, 102 | dispersion_inc=0.1, 103 | person_initial_position=(20, 20), 104 | ) 105 | positions = [ 106 | (20, 0), 107 | (20, 39), 108 | (0, 20), 109 | (39, 20), 110 | ] 111 | env = RetainDronePosWrapper(env, positions) 112 | return env 113 | 114 | greedy_agent = GreedyAgent() 115 | 116 | env = env_creator(None) 117 | 118 | if see: 119 | with PygameRecord("greedy.gif", 5) as recorder: 120 | obs, info = env.reset() 121 | while env.agents: 122 | actions = greedy_agent(obs, env.agents) 123 | obs, rewards, terminations, truncations, infos = env.step(actions) 124 | recorder.add_frame() 125 | exit(1) 126 | 127 | rewards = [] 128 | actions = [] 129 | founds = 0 130 | N_EVALS = 5_000 131 | for _ in range(N_EVALS): 132 | i = 0 133 | obs, info = env.reset() 134 | reward_sum = 0 135 | while env.agents: 136 | action = greedy_agent(obs, env.agents) 137 | obs, rw, term, trunc, info = env.step(action) 138 | reward_sum += sum(rw.values()) 139 | i += 1 140 | rewards.append(reward_sum) 141 | actions.append(i) 142 | 143 | for _, v in info.items(): 144 | if v["Found"]: 145 | founds += 1 146 | break 147 | print("Average reward: ", sum(rewards) / N_EVALS) 148 | print("Mean actions: ", sum(actions) / N_EVALS) 149 | print("Found %: ", founds / N_EVALS) 150 | -------------------------------------------------------------------------------- /src/test_ppo_cnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from recorder import PygameRecord 4 | from DSSE import DroneSwarmSearch 5 | from DSSE.environment.wrappers import TopNProbsWrapper, RetainDronePosWrapper, AllFlattenWrapper, AllPositionsWrapper 6 | import ray 7 | from ray.rllib.algorithms.ppo import PPOConfig 8 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv 9 | from ray.rllib.models import ModelCatalog 10 | from ray.tune.registry import register_env 11 | from ray.rllib.algorithms.ppo import PPO 12 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 13 | import torch.nn as nn 14 | import argparse 15 | import torch 16 | 17 | 18 | 19 | argparser = argparse.ArgumentParser() 20 | argparser.add_argument("--checkpoint", type=str, required=True) 21 | argparser.add_argument("--see", action="store_true", default=False) 22 | args = argparser.parse_args() 23 | 24 | 25
| class CNNModel(TorchModelV2, nn.Module): 26 | def __init__( 27 | self, 28 | obs_space, 29 | act_space, 30 | num_outputs, 31 | model_config, 32 | name, 33 | **kw, 34 | ): 35 | print("OBSSPACE: ", obs_space) 36 | TorchModelV2.__init__( 37 | self, obs_space, act_space, num_outputs, model_config, name, **kw 38 | ) 39 | nn.Module.__init__(self) 40 | 41 | flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[1] - 7 - 3) 42 | self.cnn = nn.Sequential( 43 | nn.Conv2d( 44 | in_channels=1, 45 | out_channels=16, 46 | kernel_size=(8, 8), 47 | stride=(1, 1), 48 | ), 49 | nn.Tanh(), 50 | nn.Conv2d( 51 | in_channels=16, 52 | out_channels=32, 53 | kernel_size=(4, 4), 54 | stride=(1, 1), 55 | ), 56 | nn.Tanh(), 57 | nn.Flatten(), 58 | nn.Linear(flatten_size, 256), 59 | nn.Tanh(), 60 | ) 61 | 62 | self.linear = nn.Sequential( 63 | nn.Linear(obs_space[0].shape[0], 512), 64 | nn.Tanh(), 65 | nn.Linear(512, 256), 66 | nn.Tanh(), 67 | ) 68 | 69 | self.join = nn.Sequential( 70 | nn.Linear(256 * 2, 256), 71 | nn.Tanh(), 72 | ) 73 | 74 | self.policy_fn = nn.Linear(256, num_outputs) 75 | self.value_fn = nn.Linear(256, 1) 76 | 77 | def forward(self, input_dict, state, seq_lens): 78 | input_positions = input_dict["obs"][0].float() 79 | input_matrix = input_dict["obs"][1].float() 80 | 81 | input_matrix = input_matrix.unsqueeze(1) 82 | cnn_out = self.cnn(input_matrix) 83 | linear_out = self.linear(input_positions) 84 | 85 | value_input = torch.cat((cnn_out, linear_out), dim=1) 86 | value_input = self.join(value_input) 87 | 88 | self._value_out = self.value_fn(value_input) 89 | return self.policy_fn(value_input), state 90 | 91 | def value_function(self): 92 | return self._value_out.flatten() 93 | 94 | 95 | 96 | # ModelCatalog.register_custom_model("MLPModel", MLPModel) 97 | ModelCatalog.register_custom_model("CNNModel", CNNModel) 98 | 99 | # DEFINE HERE THE EXACT ENVIRONMENT YOU USED TO TRAIN THE AGENT 100 | def env_creator(_): 101 | render_mode = "human" if args.see else "ansi" 102 | env = DroneSwarmSearch( 103 | drone_amount=4, 104 | grid_size=40, 105 | render_mode=render_mode, 106 | render_grid=True, 107 | dispersion_inc=0.1, 108 | person_initial_position=(20, 20), 109 | ) 110 | positions = [ 111 | (20, 0), 112 | (20, 39), 113 | (0, 20), 114 | (39, 20), 115 | ] 116 | env = RetainDronePosWrapper(env, positions) 117 | env = AllPositionsWrapper(env) 118 | return env 119 | 120 | env = env_creator(None) 121 | register_env("DSSE", lambda config: ParallelPettingZooEnv(env_creator(config))) 122 | ray.init() 123 | 124 | 125 | checkpoint_path = args.checkpoint 126 | PPOagent = PPO.from_checkpoint(checkpoint_path) 127 | 128 | reward_sum = 0 129 | i = 0 130 | 131 | if args.see: 132 | obs, info = env.reset() 133 | with PygameRecord("test_trained.gif", 5) as rec: 134 | while env.agents: 135 | actions = {} 136 | for k, v in obs.items(): 137 | actions[k] = PPOagent.compute_single_action(v, explore=False) 138 | # print(v) 139 | # action = PPOagent.compute_actions(obs) 140 | obs, rw, term, trunc, info = env.step(actions) 141 | reward_sum += sum(rw.values()) 142 | i += 1 143 | rec.add_frame() 144 | else: 145 | rewards = [] 146 | founds = 0 147 | N_EVALS = 1000 148 | for _ in range(N_EVALS): 149 | obs, info = env.reset() 150 | reward_sum = 0 151 | while env.agents: 152 | actions = {} 153 | for k, v in obs.items(): 154 | actions[k] = PPOagent.compute_single_action(v, explore=False) 155 | # print(v) 156 | # action = PPOagent.compute_actions(obs, explore=False) 157 | obs, rw, term, trunc, info = env.step(actions) 158 
| reward_sum += sum(rw.values()) 159 | i += 1 160 | rewards.append(reward_sum) 161 | for _, v in info.items(): 162 | if v["Found"]: 163 | founds += 1 164 | break 165 | print("Average reward: ", sum(rewards) / N_EVALS) 166 | print("Found %: ", founds / N_EVALS) 167 | 168 | print("Total reward: ", reward_sum) 169 | print("Total steps: ", i) 170 | print("Found: ", info) 171 | env.close() 172 | -------------------------------------------------------------------------------- /src/train_ppo_cnn_cov.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from DSSE import CoverageDroneSwarmSearch 3 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper 4 | import ray 5 | from ray import tune 6 | from ray.rllib.algorithms.ppo import PPOConfig 7 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv 8 | from ray.rllib.models import ModelCatalog 9 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 10 | from ray.tune.registry import register_env 11 | from torch import nn 12 | import torch 13 | import numpy as np 14 | 15 | 16 | class CNNModel(TorchModelV2, nn.Module): 17 | def __init__( 18 | self, 19 | obs_space, 20 | act_space, 21 | num_outputs, 22 | model_config, 23 | name, 24 | **kw, 25 | ): 26 | print("OBSSPACE: ", obs_space) 27 | TorchModelV2.__init__( 28 | self, obs_space, act_space, num_outputs, model_config, name, **kw 29 | ) 30 | nn.Module.__init__(self) 31 | 32 | first_cnn_out = (obs_space[1].shape[0] - 3) + 1 33 | second_cnn_out = (first_cnn_out - 2) + 1 34 | flatten_size = 32 * second_cnn_out * second_cnn_out 35 | print("Cnn Dense layer input size: ", flatten_size) 36 | self.cnn = nn.Sequential( 37 | nn.Conv2d( 38 | in_channels=1, 39 | out_channels=16, 40 | kernel_size=(3, 3), 41 | ), 42 | nn.Tanh(), 43 | nn.Conv2d( 44 | in_channels=16, 45 | out_channels=32, 46 | kernel_size=(2, 2), 47 | ), 48 | nn.Tanh(), 49 | nn.Flatten(), 50 | nn.Linear(flatten_size, 256), 51 | nn.Tanh(), 52 | ) 53 | 54 | self.linear = nn.Sequential( 55 | nn.Linear(obs_space[0].shape[0], 512), 56 | nn.Tanh(), 57 | nn.Linear(512, 256), 58 | nn.Tanh(), 59 | ) 60 | 61 | self.join = nn.Sequential( 62 | nn.Linear(256 * 2, 256), 63 | nn.Tanh(), 64 | ) 65 | 66 | self.policy_fn = nn.Linear(256, num_outputs) 67 | self.value_fn = nn.Linear(256, 1) 68 | 69 | def forward(self, input_dict, state, seq_lens): 70 | input_positions = input_dict["obs"][0].float() 71 | input_matrix = input_dict["obs"][1].float() 72 | 73 | input_matrix = input_matrix.unsqueeze(1) 74 | cnn_out = self.cnn(input_matrix) 75 | linear_out = self.linear(input_positions) 76 | 77 | value_input = torch.cat((cnn_out, linear_out), dim=1) 78 | value_input = self.join(value_input) 79 | 80 | self._value_out = self.value_fn(value_input) 81 | return self.policy_fn(value_input), state 82 | 83 | def value_function(self): 84 | return self._value_out.flatten() 85 | 86 | def env_creator(args): 87 | print("-------------------------- ENV CREATOR --------------------------") 88 | N_AGENTS = 2 89 | # 6 hours of simulation, 600 radius 90 | env = CoverageDroneSwarmSearch( 91 | timestep_limit=200, drone_amount=N_AGENTS, prob_matrix_path="min_matrix.npy" 92 | ) 93 | env = AllPositionsWrapper(env) 94 | grid_size = env.grid_size 95 | # positions = position_on_diagonal(grid_size, N_AGENTS) 96 | # positions = position_on_circle(grid_size, N_AGENTS, 2) 97 | positions = [ 98 | (grid_size - 1, grid_size // 2), 99 | (0, grid_size // 2), 100 | ] 101 | env = RetainDronePosWrapper(env, 
positions) 102 | return env 103 | 104 | def position_on_diagonal(grid_size, drone_amount): 105 | positions = [] 106 | center = grid_size // 2 107 | for i in range(-drone_amount // 2, drone_amount // 2): 108 | positions.append((center + i, center + i)) 109 | return positions 110 | 111 | def position_on_circle(grid_size, drone_amount, radius): 112 | positions = [] 113 | center = grid_size // 2 114 | angle_increment = 2 * np.pi / drone_amount 115 | 116 | for i in range(drone_amount): 117 | angle = i * angle_increment 118 | x = center + int(radius * np.cos(angle)) 119 | y = center + int(radius * np.sin(angle)) 120 | positions.append((x, y)) 121 | 122 | return positions 123 | 124 | 125 | if __name__ == "__main__": 126 | ray.init() 127 | 128 | env_name = "DSSE_Coverage" 129 | 130 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config))) 131 | ModelCatalog.register_custom_model("CNNModel", CNNModel) 132 | 133 | config = ( 134 | PPOConfig() 135 | .environment(env=env_name) 136 | .rollouts(num_rollout_workers=6, rollout_fragment_length="auto") 137 | .training( 138 | train_batch_size=8192 * 5, 139 | lr=6e-6, 140 | gamma=0.9999999, 141 | lambda_=0.9, 142 | use_gae=True, 143 | entropy_coeff=0.01, 144 | vf_clip_param=100000, 145 | sgd_minibatch_size=300, 146 | num_sgd_iter=10, 147 | model={ 148 | "custom_model": "CNNModel", 149 | "_disable_preprocessor_api": True, 150 | }, 151 | ) 152 | .experimental(_disable_preprocessor_api=True) 153 | .debugging(log_level="ERROR") 154 | .framework(framework="torch") 155 | .resources(num_gpus=1) 156 | ) 157 | 158 | curr_path = pathlib.Path().resolve() 159 | tune.run( 160 | "PPO", 161 | name="PPO_" + input("Exp name: "), 162 | # resume=True, 163 | stop={"timesteps_total": 20_000_000}, 164 | checkpoint_freq=20, 165 | storage_path=f"{curr_path}/ray_res/" + env_name, 166 | config=config.to_dict(), 167 | ) 168 | -------------------------------------------------------------------------------- /src/test_trained_search.py: -------------------------------------------------------------------------------- 1 | from recorder import PygameRecord 2 | from DSSE import DroneSwarmSearch 3 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper 4 | from DSSE.environment.wrappers.communication_wrapper import CommunicationWrapper 5 | import ray 6 | from ray.rllib.algorithms.ppo import PPOConfig 7 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv 8 | from ray.rllib.models import ModelCatalog 9 | from ray.tune.registry import register_env 10 | from ray.rllib.algorithms.ppo import PPO 11 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 12 | import torch.nn as nn 13 | import argparse 14 | import torch 15 | import numpy as np 16 | 17 | 18 | 19 | argparser = argparse.ArgumentParser() 20 | argparser.add_argument("--checkpoint", type=str, required=True) 21 | argparser.add_argument("--see", action="store_true", default=False) 22 | args = argparser.parse_args() 23 | 24 | 25 | class CNNModel(TorchModelV2, nn.Module): 26 | def __init__( 27 | self, 28 | obs_space, 29 | act_space, 30 | num_outputs, 31 | model_config, 32 | name, 33 | **kw, 34 | ): 35 | print("OBSSPACE: ", obs_space) 36 | TorchModelV2.__init__( 37 | self, obs_space, act_space, num_outputs, model_config, name, **kw 38 | ) 39 | nn.Module.__init__(self) 40 | 41 | flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[1] - 7 - 3) 42 | self.cnn = nn.Sequential( 43 | nn.Conv2d( 44 | in_channels=1, 45 | out_channels=16, 46 | 
kernel_size=(8, 8), 47 | stride=(1, 1), 48 | ), 49 | nn.Tanh(), 50 | nn.Conv2d( 51 | in_channels=16, 52 | out_channels=32, 53 | kernel_size=(4, 4), 54 | stride=(1, 1), 55 | ), 56 | nn.Tanh(), 57 | nn.Flatten(), 58 | nn.Linear(flatten_size, 256), 59 | nn.Tanh(), 60 | ) 61 | 62 | self.linear = nn.Sequential( 63 | nn.Linear(obs_space[0].shape[0], 512), 64 | nn.Tanh(), 65 | nn.Linear(512, 256), 66 | nn.Tanh(), 67 | ) 68 | 69 | self.join = nn.Sequential( 70 | nn.Linear(256 * 2, 256), 71 | nn.Tanh(), 72 | ) 73 | 74 | self.policy_fn = nn.Linear(256, num_outputs) 75 | self.value_fn = nn.Linear(256, 1) 76 | 77 | def forward(self, input_dict, state, seq_lens): 78 | input_positions = input_dict["obs"][0].float() 79 | input_matrix = input_dict["obs"][1].float() 80 | 81 | input_matrix = input_matrix.unsqueeze(1) 82 | cnn_out = self.cnn(input_matrix) 83 | linear_out = self.linear(input_positions) 84 | 85 | value_input = torch.cat((cnn_out, linear_out), dim=1) 86 | value_input = self.join(value_input) 87 | 88 | self._value_out = self.value_fn(value_input) 89 | return self.policy_fn(value_input), state 90 | 91 | def value_function(self): 92 | return self._value_out.flatten() 93 | 94 | 95 | 96 | # ModelCatalog.register_custom_model("MLPModel", MLPModel) 97 | ModelCatalog.register_custom_model("CNNModel", CNNModel) 98 | 99 | # DEFINE HERE THE EXACT ENVIRONMENT YOU USED TO TRAIN THE AGENT 100 | def env_creator(args): 101 | env = DroneSwarmSearch( 102 | drone_amount=4, 103 | grid_size=40, 104 | dispersion_inc=0.1, 105 | person_initial_position=(20, 20), 106 | ) 107 | positions = [ 108 | (20, 0), 109 | (20, 39), 110 | (0, 20), 111 | (39, 20), 112 | ] 113 | env = AllPositionsWrapper(env) 114 | env = CommunicationWrapper(env, n_steps=12) 115 | env = RetainDronePosWrapper(env, positions) 116 | return env 117 | 118 | env = env_creator(None) 119 | register_env("DSSE", lambda config: ParallelPettingZooEnv(env_creator(config))) 120 | ray.init() 121 | 122 | 123 | checkpoint_path = args.checkpoint 124 | PPOagent = PPO.from_checkpoint(checkpoint_path) 125 | 126 | 127 | if args.see: 128 | i = 0 129 | reward_sum = 0 130 | obs, info = env.reset() 131 | with PygameRecord("test_trained.gif", 5) as rec: 132 | while env.agents: 133 | actions = {} 134 | for k, v in obs.items(): 135 | actions[k] = PPOagent.compute_single_action(v, explore=False) 136 | # print(v) 137 | # action = PPOagent.compute_actions(obs) 138 | obs, rw, term, trunc, info = env.step(actions) 139 | reward_sum += sum(rw.values()) 140 | i += 1 141 | rec.add_frame() 142 | print(info) 143 | print(reward_sum) 144 | else: 145 | rewards = [] 146 | actions_stat = [] 147 | founds = 0 148 | N_EVALS = 5_000 149 | for epoch in range(N_EVALS): 150 | print(epoch) 151 | obs, info = env.reset() 152 | i = 0 153 | reward_sum = 0 154 | while env.agents: 155 | actions = {} 156 | for k, v in obs.items(): 157 | actions[k] = PPOagent.compute_single_action(v, explore=False) 158 | obs, rw, term, trunc, info = env.step(actions) 159 | reward_sum += sum(rw.values()) 160 | i += 1 161 | rewards.append(reward_sum) 162 | actions_stat.append(i) 163 | 164 | for _, v in info.items(): 165 | if v["Found"]: 166 | founds += 1 167 | break 168 | print("Average reward: ", sum(rewards) / N_EVALS) 169 | print("Average Actions: ", sum(actions_stat) / N_EVALS) 170 | print("Median of actions: ", np.median(actions_stat)) 171 | print("Found %: ", founds / N_EVALS) 172 | 173 | env.close() 174 | -------------------------------------------------------------------------------- /src/test_trained_cov.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from recorder import PygameRecord 4 | from DSSE import CoverageDroneSwarmSearch 5 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper 6 | import ray 7 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv 8 | from ray.rllib.models import ModelCatalog 9 | from ray.tune.registry import register_env 10 | from ray.rllib.algorithms.ppo import PPO 11 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 12 | import torch.nn as nn 13 | import argparse 14 | import torch 15 | 16 | 17 | 18 | argparser = argparse.ArgumentParser() 19 | argparser.add_argument("--checkpoint", type=str, required=True) 20 | argparser.add_argument("--see", action="store_true", default=False) 21 | args = argparser.parse_args() 22 | 23 | 24 | class CNNModel(TorchModelV2, nn.Module): 25 | def __init__( 26 | self, 27 | obs_space, 28 | act_space, 29 | num_outputs, 30 | model_config, 31 | name, 32 | **kw, 33 | ): 34 | print("OBSSPACE: ", obs_space) 35 | TorchModelV2.__init__( 36 | self, obs_space, act_space, num_outputs, model_config, name, **kw 37 | ) 38 | nn.Module.__init__(self) 39 | 40 | first_cnn_out = (obs_space[1].shape[0] - 8) + 1 41 | second_cnn_out = (first_cnn_out - 4) + 1 42 | flatten_size = 32 * second_cnn_out * second_cnn_out 43 | print("Cnn Dense layer input size: ", flatten_size) 44 | self.cnn = nn.Sequential( 45 | nn.Conv2d( 46 | in_channels=1, 47 | out_channels=16, 48 | kernel_size=(8, 8), 49 | ), 50 | nn.Tanh(), 51 | nn.Conv2d( 52 | in_channels=16, 53 | out_channels=32, 54 | kernel_size=(4, 4), 55 | ), 56 | nn.Tanh(), 57 | nn.Flatten(), 58 | nn.Linear(flatten_size, 256), 59 | nn.Tanh(), 60 | ) 61 | 62 | self.linear = nn.Sequential( 63 | nn.Linear(obs_space[0].shape[0], 512), 64 | nn.Tanh(), 65 | nn.Linear(512, 256), 66 | nn.Tanh(), 67 | ) 68 | 69 | self.join = nn.Sequential( 70 | nn.Linear(256 * 2, 256), 71 | nn.Tanh(), 72 | ) 73 | 74 | self.policy_fn = nn.Linear(256, num_outputs) 75 | self.value_fn = nn.Linear(256, 1) 76 | 77 | def forward(self, input_dict, state, seq_lens): 78 | input_positions = input_dict["obs"][0].float() 79 | input_matrix = input_dict["obs"][1].float() 80 | 81 | input_matrix = input_matrix.unsqueeze(1) 82 | cnn_out = self.cnn(input_matrix) 83 | linear_out = self.linear(input_positions) 84 | 85 | value_input = torch.cat((cnn_out, linear_out), dim=1) 86 | value_input = self.join(value_input) 87 | 88 | self._value_out = self.value_fn(value_input) 89 | return self.policy_fn(value_input), state 90 | 91 | def value_function(self): 92 | return self._value_out.flatten() 93 | 94 | # Register the model 95 | ModelCatalog.register_custom_model("CNNModel", CNNModel) 96 | 97 | 98 | def env_creator(args): 99 | print("-------------------------- ENV CREATOR --------------------------") 100 | N_AGENTS = 8 101 | # 6 hours of simulation, 600 radius 102 | env = CoverageDroneSwarmSearch( 103 | timestep_limit=180, drone_amount=N_AGENTS, prob_matrix_path="presim_20.npy", render_mode="human" 104 | ) 105 | env = AllPositionsWrapper(env) 106 | grid_size = env.grid_size 107 | positions = position_on_diagonal(grid_size, N_AGENTS) 108 | env = RetainDronePosWrapper(env, positions) 109 | return env 110 | 111 | def position_on_diagonal(grid_size, drone_amount): 112 | positions = [] 113 | center = grid_size // 2 114 | for i in range(-drone_amount // 2, drone_amount // 2): 115 | positions.append((center + i, center + i)) 116 | return positions 
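# --- Editor's note: illustrative example, not part of the original script. ---
# position_on_diagonal spreads the drones' start cells along the main diagonal,
# centred on the middle of the grid. For example, on a hypothetical 40 x 40 grid
# with 8 drones (the real grid size here comes from the loaded probability matrix):
#
#     >>> position_on_diagonal(40, 8)
#     [(16, 16), (17, 17), (18, 18), (19, 19), (20, 20), (21, 21), (22, 22), (23, 23)]
#
# because range(-8 // 2, 8 // 2) == range(-4, 4) gives offsets -4..3 around the
# centre cell (20, 20). For an odd drone_amount, -drone_amount // 2 rounds toward
# negative infinity, so the line of drones is biased one cell toward the upper-left.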
117 | 118 | env = env_creator(None) 119 | register_env("DSSE_Coverage", lambda config: ParallelPettingZooEnv(env_creator(config))) 120 | ray.init() 121 | 122 | 123 | checkpoint_path = args.checkpoint 124 | PPOagent = PPO.from_checkpoint(checkpoint_path) 125 | 126 | reward_sum = 0 127 | i = 0 128 | 129 | if args.see: 130 | obs, info = env.reset() 131 | with PygameRecord("test_trained.gif", 5) as rec: 132 | while env.agents: 133 | actions = {} 134 | for k, v in obs.items(): 135 | actions[k] = PPOagent.compute_single_action(v, explore=False) 136 | # print(v) 137 | # action = PPOagent.compute_actions(obs) 138 | obs, rw, term, trunc, info = env.step(actions) 139 | reward_sum += sum(rw.values()) 140 | i += 1 141 | rec.add_frame() 142 | else: 143 | rewards = [] 144 | founds = 0 145 | N_EVALS = 1000 146 | for _ in range(N_EVALS): 147 | obs, info = env.reset() 148 | reward_sum = 0 149 | while env.agents: 150 | actions = {} 151 | for k, v in obs.items(): 152 | actions[k] = PPOagent.compute_single_action(v, explore=False) 153 | # print(v) 154 | # action = PPOagent.compute_actions(obs, explore=False) 155 | obs, rw, term, trunc, info = env.step(actions) 156 | reward_sum += sum(rw.values()) 157 | i += 1 158 | rewards.append(reward_sum) 159 | for _, v in info.items(): 160 | if v["Found"]: 161 | founds += 1 162 | break 163 | print("Average reward: ", sum(rewards) / N_EVALS) 164 | print("Found %: ", founds / N_EVALS) 165 | 166 | print("Total reward: ", reward_sum) 167 | print("Total steps: ", i) 168 | print("Found: ", info) 169 | env.close() 170 | -------------------------------------------------------------------------------- /src/a_star_coverage.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module to implement the A* algorithm for coverage path planning. 
3 | """ 4 | 5 | import argparse 6 | import numpy as np 7 | from DSSE import CoverageDroneSwarmSearch, Actions 8 | from aigyminsper.search.graph import State 9 | 10 | MOVEMENTS = { 11 | Actions.UP: (0, -1), 12 | Actions.DOWN: (0, 1), 13 | Actions.LEFT: (-1, 0), 14 | Actions.RIGHT: (1, 0), 15 | Actions.UP_LEFT: (-1, -1), 16 | Actions.UP_RIGHT: (1, -1), 17 | Actions.DOWN_LEFT: (-1, 1), 18 | Actions.DOWN_RIGHT: (1, 1), 19 | } 20 | 21 | 22 | class DroneState(State): 23 | def __init__(self, position: tuple, prob_matrix: np.ndarray, visited: set): 24 | self.position = position 25 | self.prob_matrix = prob_matrix 26 | self.visited = visited 27 | self.action = None 28 | 29 | def successors(self, allow_zeros: bool = False) -> list["DroneState"]: 30 | successors = [] 31 | x, y = self.position 32 | 33 | for action, (dx, dy) in MOVEMENTS.items(): 34 | new_x, new_y = x + dx, y + dy 35 | 36 | if 0 <= new_x < len(self.prob_matrix) and 0 <= new_y < len( 37 | self.prob_matrix[0] 38 | ): 39 | if allow_zeros or self.prob_matrix[new_y, new_x] > 0: 40 | new_state = DroneState( 41 | (new_x, new_y), 42 | self.prob_matrix, 43 | self.visited | {(new_x, new_y)}, 44 | ) 45 | new_state.set_action(action) 46 | successors.append(new_state) 47 | 48 | return successors 49 | 50 | def cost( 51 | self, 52 | prob_weight: int | float = 10, 53 | distance_weight: int | float = 0.5, 54 | revisit_penalty_value: int | float = 30, 55 | ) -> int | float: 56 | high_prob_indices = np.argwhere(self.prob_matrix > 0) 57 | number_of_high_prob = high_prob_indices.shape[0] 58 | 59 | if number_of_high_prob == 0: 60 | return float("inf") 61 | 62 | x, y = self.position 63 | distances = np.abs(high_prob_indices[:, 1] - x) + np.abs( 64 | high_prob_indices[:, 0] - y 65 | ) 66 | min_distance = np.min(distances) 67 | 68 | max_prob = np.max(self.prob_matrix) 69 | prob_value = self.prob_matrix[y, x] 70 | normalized_prob = prob_value / max_prob if max_prob > 0 else 0 71 | 72 | max_distance = len(self.prob_matrix) + len(self.prob_matrix[0]) 73 | normalized_distance = min_distance / max_distance 74 | 75 | revisit_penalty = 0 76 | if (x, y) in self.visited: 77 | revisit_penalty = revisit_penalty_value * ( 78 | len(self.visited) / (number_of_high_prob + 1) 79 | ) 80 | 81 | heuristic_value = ( 82 | -(normalized_prob * prob_weight) 83 | + (normalized_distance * distance_weight) 84 | + revisit_penalty 85 | ) 86 | 87 | return heuristic_value 88 | 89 | def description(self): 90 | return f"DroneState Position = {self.position}" 91 | 92 | def env(self): 93 | return self.position 94 | 95 | def is_goal(self): 96 | return False 97 | 98 | def set_action(self, action: Actions): 99 | self.action = action 100 | 101 | def get_action(self) -> Actions: 102 | return self.action if self.action else Actions.SEARCH 103 | 104 | def __eq__(self, other: "DroneState") -> bool: 105 | position_eq = self.position == other.position 106 | visited_eq = self.visited == other.visited 107 | action_eq = self.action == other.action 108 | return position_eq and visited_eq and action_eq 109 | 110 | 111 | def a_star( 112 | observations: dict, agents: list, prob_matrix: np.ndarray, visited: list[set] 113 | ) -> dict: 114 | actions = {} 115 | will_visit = [] 116 | for i, agent in enumerate(agents): 117 | current_x, current_y = observations[agent][0] 118 | 119 | drone_state = DroneState((current_x, current_y), prob_matrix, visited[i]) 120 | 121 | successors = drone_state.successors() 122 | 123 | if not successors: 124 | successors = drone_state.successors(allow_zeros=True) 125 | 126 | if not 
successors: 127 | continue 128 | 129 | next_state = min(successors, key=lambda state: state.cost()) 130 | while next_state.position in will_visit: 131 | successors.remove(next_state) 132 | next_state = min(successors, key=lambda state: state.cost()) 133 | 134 | will_visit.append(next_state.position) 135 | actions[agent] = next_state.get_action().value 136 | visited[i].add(next_state.position) 137 | prob_matrix[current_y, current_x] = 0 138 | 139 | return actions 140 | 141 | 142 | def main(num_drones: int): 143 | env = CoverageDroneSwarmSearch( 144 | drone_amount=num_drones, 145 | render_mode="human", 146 | timestep_limit=200, 147 | prob_matrix_path="src/min_matrix.npy", 148 | ) 149 | 150 | center = env.grid_size // 2 151 | positions = [(center, center)] 152 | 153 | for i in range(1, num_drones): 154 | offset = (i // 2) + 1 155 | x_offset = offset * (-1 if i % 2 == 0 else 1) 156 | y_offset = offset * (-1 if (i + 1) % 2 == 0 else 1) 157 | new_position = (center + x_offset, center + y_offset) 158 | positions.append(new_position) 159 | 160 | opt = {"drones_positions": positions} 161 | 162 | observations, _ = env.reset(options=opt) 163 | 164 | visited = [set([opt["drones_positions"][i]]) for i in range(num_drones)] 165 | 166 | prob_matrix = env.probability_matrix.get_matrix() 167 | while env.agents: 168 | actions = a_star(observations, env.agents, prob_matrix.copy(), visited) 169 | observations, *_ = env.step(actions) 170 | 171 | 172 | if __name__ == "__main__": 173 | argparser = argparse.ArgumentParser() 174 | argparser.add_argument("--num_drones", type=int, required=True) 175 | args = argparser.parse_args() 176 | main(num_drones=args.num_drones) 177 | -------------------------------------------------------------------------------- /src/train_ppo_cnn_lstm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from DSSE import DroneSwarmSearch 4 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper 5 | from DSSE.environment.wrappers.communication_wrapper import CommunicationWrapper 6 | import ray 7 | from ray import tune 8 | from ray.rllib.algorithms.ppo import PPOConfig 9 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv 10 | from ray.rllib.models import ModelCatalog 11 | from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN 12 | from ray.rllib.models.modelv2 import ModelV2 13 | from ray.rllib.utils.annotations import override 14 | from ray.rllib.policy.rnn_sequencing import add_time_dimension 15 | from ray.tune.registry import register_env 16 | from torch import nn 17 | import torch 18 | 19 | 20 | class CNNModel(TorchRNN, nn.Module): 21 | def __init__( 22 | self, 23 | obs_space, 24 | act_space, 25 | num_outputs, 26 | model_config, 27 | name, 28 | **kw, 29 | ): 30 | print("OBSSPACE: ", obs_space) 31 | num_outputs = act_space.n 32 | nn.Module.__init__(self) 33 | super().__init__(obs_space, act_space, num_outputs, model_config, name, **kw) 34 | 35 | flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[0] - 7 - 3) 36 | self.cnn = nn.Sequential( 37 | nn.Conv2d( 38 | in_channels=1, 39 | out_channels=16, 40 | kernel_size=(8, 8), 41 | stride=(1, 1), 42 | ), 43 | nn.Tanh(), 44 | nn.Conv2d( 45 | in_channels=16, 46 | out_channels=32, 47 | kernel_size=(4, 4), 48 | stride=(1, 1), 49 | ), 50 | nn.Tanh(), 51 | nn.Flatten(), 52 | nn.Linear(flatten_size, 256), 53 | nn.Tanh(), 54 | ) 55 | 56 | self.linear = nn.Linear(obs_space[0].shape[0], 512) 57 
| 58 | self.lstm_state_size = 256 59 | self.lstm = nn.LSTM(512, self.lstm_state_size, batch_first=True) 60 | 61 | self.join = nn.Sequential( 62 | nn.Linear(256 * 2, 256), 63 | nn.Tanh(), 64 | ) 65 | print("NUM OUTPUTS: ", num_outputs) 66 | self.policy_fn = nn.Linear(256, num_outputs) 67 | self.value_fn = nn.Linear(256, 1) 68 | # Holds the current "base" output (before logits layer). 69 | self._value_out = None 70 | 71 | @override(ModelV2) 72 | def get_initial_state(self): 73 | # TODO: (sven): Get rid of `get_initial_state` once Trajectory 74 | # View API is supported across all of RLlib. 75 | # Place hidden states on same device as model. 76 | h = [ 77 | self.linear.weight.new(1, self.lstm_state_size).zero_().squeeze(0), 78 | self.linear.weight.new(1, self.lstm_state_size).zero_().squeeze(0) 79 | ] 80 | return h 81 | 82 | def value_function(self): 83 | return self._value_out.flatten() 84 | 85 | @override(ModelV2) 86 | def forward( 87 | self, 88 | input_dict, 89 | state, 90 | seq_lens, 91 | ): 92 | """Adds time dimension to batch before sending inputs to forward_rnn(). 93 | 94 | You should implement forward_rnn() in your subclass.""" 95 | scalar_inputs = input_dict["obs"][0].float() 96 | input_matrix = input_dict["obs"][1].float() 97 | 98 | input_matrix = input_matrix.unsqueeze(1) 99 | cnn_out = self.cnn(input_matrix) 100 | 101 | flat_inputs = scalar_inputs.flatten(start_dim=1) 102 | # Note that max_seq_len != input_dict.max_seq_len != seq_lens.max() 103 | # as input_dict may have extra zero-padding beyond seq_lens.max(). 104 | # Use add_time_dimension to handle this 105 | self.time_major = self.model_config.get("_time_major", False) 106 | inputs = add_time_dimension( 107 | flat_inputs, 108 | seq_lens=seq_lens, 109 | framework="torch", 110 | time_major=self.time_major, 111 | ) 112 | lstm_out, new_state = self.forward_rnn(inputs, state, seq_lens) 113 | lstm_out = torch.reshape(lstm_out, [-1, self.lstm_state_size]) 114 | 115 | value_input = torch.cat((cnn_out, lstm_out), dim=1) 116 | value_input = self.join(value_input) 117 | 118 | self._value_out = self.value_fn(value_input) 119 | return self.policy_fn(value_input), new_state 120 | 121 | @override(TorchRNN) 122 | def forward_rnn(self, inputs, state, seq_lens): 123 | """Feeds `inputs` (B x T x ..) through the Gru Unit. 124 | 125 | Returns the resulting outputs as a sequence (B x T x ...). 126 | Values are stored in self._cur_value in simple (B) shape (where B 127 | contains both the B and T dims!). 128 | 129 | Returns: 130 | NN Outputs (B x T x ...) as sequence. 131 | The state batches as a List of two items (c- and h-states). 132 | """ 133 | linear_out = nn.functional.tanh(self.linear(inputs)) 134 | 135 | lstm_out, [h, c] = self.lstm(linear_out, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)]) 136 | 137 | return lstm_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)] 138 | 139 | 140 | def env_creator(args): 141 | """ 142 | Petting Zoo environment for search of shipwrecked people. 
143 | check it out at 144 | https://github.com/pfeinsper/drone-swarm-search 145 | or install with 146 | pip install DSSE 147 | """ 148 | env = DroneSwarmSearch( 149 | drone_amount=4, 150 | grid_size=40, 151 | dispersion_inc=0.1, 152 | person_initial_position=(20, 20), 153 | ) 154 | positions = [ 155 | (20, 0), 156 | (20, 39), 157 | (0, 20), 158 | (39, 20), 159 | ] 160 | env = AllPositionsWrapper(env) 161 | env = RetainDronePosWrapper(env, positions) 162 | return env 163 | 164 | 165 | if __name__ == "__main__": 166 | ray.init() 167 | 168 | env_name = "DSSE" 169 | 170 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config))) 171 | ModelCatalog.register_custom_model("CNNModel", CNNModel) 172 | 173 | config = ( 174 | PPOConfig() 175 | .environment(env=env_name) 176 | .rollouts(num_rollout_workers=6, rollout_fragment_length="auto") 177 | .training( 178 | train_batch_size=4096, 179 | lr=1e-5, 180 | gamma=0.9999999, 181 | lambda_=0.9, 182 | use_gae=True, 183 | entropy_coeff=0.01, 184 | sgd_minibatch_size=300, 185 | num_sgd_iter=10, 186 | model={ 187 | "custom_model": "CNNModel", 188 | "use_lstm": False, 189 | "lstm_cell_size": 256, 190 | "_disable_preprocessor_api": True, 191 | }, 192 | ) 193 | .experimental(_disable_preprocessor_api=True) 194 | .debugging(log_level="ERROR") 195 | .framework(framework="torch") 196 | .resources(num_gpus=1) 197 | ) 198 | 199 | curr_path = pathlib.Path().resolve() 200 | tune.run( 201 | "PPO", 202 | name="PPO_LSTM_M", 203 | resume=True, 204 | stop={"timesteps_total": 20_000_000, "episode_reward_mean": 1.82}, 205 | checkpoint_freq=15, 206 | storage_path=f"{curr_path}/ray_res/" + env_name, 207 | config=config.to_dict(), 208 | ) 209 | -------------------------------------------------------------------------------- /src/test_trained_cnn_lstm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from recorder import PygameRecord 4 | from DSSE import DroneSwarmSearch 5 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper 6 | import ray 7 | from ray.rllib.algorithms.ppo import PPOConfig 8 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv 9 | from ray.rllib.models import ModelCatalog 10 | from ray.tune.registry import register_env 11 | from ray.rllib.utils.annotations import override 12 | from ray.rllib.policy.rnn_sequencing import add_time_dimension 13 | from ray.rllib.algorithms.ppo import PPO 14 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 15 | from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN 16 | from ray.rllib.models.modelv2 import ModelV2 17 | import torch.nn as nn 18 | import argparse 19 | import torch 20 | import numpy as np 21 | 22 | 23 | 24 | argparser = argparse.ArgumentParser() 25 | argparser.add_argument("--checkpoint", type=str, required=True) 26 | argparser.add_argument("--see", action="store_true", default=False) 27 | args = argparser.parse_args() 28 | 29 | class CNNModel(TorchRNN, nn.Module): 30 | def __init__( 31 | self, 32 | obs_space, 33 | act_space, 34 | num_outputs, 35 | model_config, 36 | name, 37 | **kw, 38 | ): 39 | print("OBSSPACE: ", obs_space) 40 | num_outputs = act_space.n 41 | nn.Module.__init__(self) 42 | super().__init__(obs_space, act_space, num_outputs, model_config, name, **kw) 43 | 44 | flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[0] - 7 - 3) 45 | self.cnn = nn.Sequential( 46 | nn.Conv2d( 47 | in_channels=1, 48 | 
out_channels=16, 49 | kernel_size=(8, 8), 50 | stride=(1, 1), 51 | ), 52 | nn.Tanh(), 53 | nn.Conv2d( 54 | in_channels=16, 55 | out_channels=32, 56 | kernel_size=(4, 4), 57 | stride=(1, 1), 58 | ), 59 | nn.Tanh(), 60 | nn.Flatten(), 61 | nn.Linear(flatten_size, 256), 62 | nn.Tanh(), 63 | ) 64 | 65 | self.linear = nn.Linear(obs_space[0].shape[0], 512) 66 | 67 | self.lstm_state_size = 256 68 | self.lstm = nn.LSTM(512, self.lstm_state_size, batch_first=True) 69 | 70 | self.join = nn.Sequential( 71 | nn.Linear(256 * 2, 256), 72 | nn.Tanh(), 73 | ) 74 | print("NUM OUTPUTS: ", num_outputs) 75 | self.policy_fn = nn.Linear(256, num_outputs) 76 | self.value_fn = nn.Linear(256, 1) 77 | # Holds the current "base" output (before logits layer). 78 | self._value_out = None 79 | 80 | @override(ModelV2) 81 | def get_initial_state(self): 82 | # TODO: (sven): Get rid of `get_initial_state` once Trajectory 83 | # View API is supported across all of RLlib. 84 | # Place hidden states on same device as model. 85 | h = [ 86 | self.linear.weight.new(1, self.lstm_state_size).zero_().squeeze(0), 87 | self.linear.weight.new(1, self.lstm_state_size).zero_().squeeze(0) 88 | ] 89 | return h 90 | 91 | def value_function(self): 92 | return self._value_out.flatten() 93 | 94 | @override(ModelV2) 95 | def forward( 96 | self, 97 | input_dict, 98 | state, 99 | seq_lens, 100 | ): 101 | """Adds time dimension to batch before sending inputs to forward_rnn(). 102 | 103 | You should implement forward_rnn() in your subclass.""" 104 | scalar_inputs = input_dict["obs"][0].float() 105 | input_matrix = input_dict["obs"][1].float() 106 | 107 | input_matrix = input_matrix.unsqueeze(1) 108 | cnn_out = self.cnn(input_matrix) 109 | 110 | flat_inputs = scalar_inputs.flatten(start_dim=1) 111 | # Note that max_seq_len != input_dict.max_seq_len != seq_lens.max() 112 | # as input_dict may have extra zero-padding beyond seq_lens.max(). 113 | # Use add_time_dimension to handle this 114 | self.time_major = self.model_config.get("_time_major", False) 115 | inputs = add_time_dimension( 116 | flat_inputs, 117 | seq_lens=seq_lens, 118 | framework="torch", 119 | time_major=self.time_major, 120 | ) 121 | lstm_out, new_state = self.forward_rnn(inputs, state, seq_lens) 122 | lstm_out = torch.reshape(lstm_out, [-1, self.lstm_state_size]) 123 | 124 | value_input = torch.cat((cnn_out, lstm_out), dim=1) 125 | value_input = self.join(value_input) 126 | 127 | self._value_out = self.value_fn(value_input) 128 | return self.policy_fn(value_input), new_state 129 | 130 | @override(TorchRNN) 131 | def forward_rnn(self, inputs, state, seq_lens): 132 | """Feeds `inputs` (B x T x ..) through the Gru Unit. 133 | 134 | Returns the resulting outputs as a sequence (B x T x ...). 135 | Values are stored in self._cur_value in simple (B) shape (where B 136 | contains both the B and T dims!). 137 | 138 | Returns: 139 | NN Outputs (B x T x ...) as sequence. 140 | The state batches as a List of two items (c- and h-states). 141 | """ 142 | linear_out = nn.functional.tanh(self.linear(inputs)) 143 | 144 | lstm_out, [h, c] = self.lstm(linear_out, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)]) 145 | 146 | return lstm_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)] 147 | 148 | 149 | def env_creator(_): 150 | """ 151 | Petting Zoo environment for search of shipwrecked people. 
152 | check it out at 153 | https://github.com/pfeinsper/drone-swarm-search 154 | or install with 155 | pip install DSSE 156 | """ 157 | render_mode = "human" if args.see else "ansi" 158 | env = DroneSwarmSearch( 159 | drone_amount=4, 160 | grid_size=40, 161 | dispersion_inc=0.1, 162 | person_initial_position=(20, 20), 163 | render_mode=render_mode, 164 | render_grid=True 165 | ) 166 | positions = [ 167 | (20, 0), 168 | (20, 39), 169 | (0, 20), 170 | (39, 20), 171 | ] 172 | env = AllPositionsWrapper(env) 173 | env = RetainDronePosWrapper(env, positions) 174 | return env 175 | 176 | env = env_creator(None) 177 | register_env("DSSE", lambda config: ParallelPettingZooEnv(env_creator(config))) 178 | ModelCatalog.register_custom_model("CNNModel", CNNModel) 179 | ray.init() 180 | 181 | 182 | checkpoint_path = args.checkpoint 183 | PPOagent = PPO.from_checkpoint(checkpoint_path) 184 | 185 | reward_sum = 0 186 | i = 0 187 | 188 | if args.see: 189 | sample_model = PPOagent.get_policy().model 190 | with PygameRecord("test_trained.gif", 5) as rec: 191 | obs, info = env.reset() 192 | reward_sum = 0 193 | state = sample_model.get_initial_state() 194 | # state = init_hidden() 195 | i = 0 196 | # done = False 197 | while env.agents: 198 | print(obs) 199 | actions = {} 200 | # for k, v in obs.items(): 201 | # action, state, _ = PPOagent.compute_single_action(v, state, explore=False) 202 | # actions[k] = action 203 | actions = PPOagent.compute_actions(obs, state, explore=False) 204 | obs, rw, term, trunc, info = env.step(actions) 205 | # done = any(term.values()) or any(trunc.values()) 206 | reward_sum += sum(rw.values()) 207 | i += 1 208 | rec.add_frame() 209 | else: 210 | rewards = [] 211 | actions_statics = [] 212 | founds = 0 213 | N_EVALS = 200 214 | sample_model = PPOagent.get_policy().model 215 | for _ in range(N_EVALS): 216 | print(_) 217 | obs, info = env.reset() 218 | reward_sum = 0 219 | # state = init_hidden() 220 | state = sample_model.get_initial_state() 221 | i = 0 222 | while env.agents: 223 | actions = {} 224 | # for k, v in obs.items(): 225 | # action, state, _ = PPOagent.compute_single_action(v, state, explore=False) 226 | # actions[k] = action 227 | actions = PPOagent.compute_actions(obs, state, explore=False) 228 | obs, rw, term, trunc, info = env.step(actions) 229 | reward_sum += sum(rw.values()) 230 | i += 1 231 | actions_statics.append(i) 232 | rewards.append(reward_sum) 233 | for _, v in info.items(): 234 | if v["Found"]: 235 | founds += 1 236 | break 237 | 238 | print("Average reward: ", sum(rewards) / N_EVALS) 239 | print("Found %: ", founds / N_EVALS) 240 | print("Mean steps: ", sum(actions_statics) / N_EVALS) 241 | print("Median steps: ", np.median(actions_statics)) 242 | print("Found: ", info) 243 | env.close() --------------------------------------------------------------------------------
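Editor's note: a minimal, self-contained sketch (not part of the repository) of the convolution arithmetic the CNNModel classes above rely on. A stride-1, unpadded Conv2d with kernel size k shrinks each spatial dimension by k - 1, so the 8x8 and 4x4 kernels remove 7 and 3 cells respectively, which is where the recurring expression flatten_size = 32 * (H - 7 - 3) * (W - 7 - 3) comes from. The grid value of 40 below matches the grid_size used by the search scripts; any other size works the same way.

import torch
from torch import nn

grid = 40  # grid_size used by the DroneSwarmSearch training/evaluation scripts
cnn = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(8, 8)),  # 40 -> 33 (40 - 8 + 1)
    nn.Tanh(),
    nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(4, 4)),  # 33 -> 30 (33 - 4 + 1)
    nn.Tanh(),
    nn.Flatten(),
)
out = cnn(torch.zeros(1, 1, grid, grid))
flatten_size = 32 * (grid - 7 - 3) * (grid - 7 - 3)
assert out.shape == (1, flatten_size)  # 32 * 30 * 30 = 28800
print(out.shape)  # torch.Size([1, 28800])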