├── .python-version
├── src
│   ├── min_matrix.npy
│   ├── train_ppo_mlp_cov.py
│   ├── recorder.py
│   ├── train_ppo_mlp.py
│   ├── test_trained_cov_mlp.py
│   ├── train_ppo_cnn.py
│   ├── train_ppo_encoded.py
│   ├── train_ppo_cnn_comm.py
│   ├── train_dqn_cnn.py
│   ├── train_ppo_multi.py
│   ├── train_dqn_multi.py
│   ├── train_descentralized_ppo_cnn.py
│   ├── greedy_heuristic.py
│   ├── test_ppo_cnn.py
│   ├── train_ppo_cnn_cov.py
│   ├── test_trained_search.py
│   ├── test_trained_cov.py
│   ├── a_star_coverage.py
│   ├── train_ppo_cnn_lstm.py
│   └── test_trained_cnn_lstm.py
├── LICENSE
├── README.md
├── requirements.txt
├── .gitignore
└── imgs
    └── drone.svg
/.python-version:
--------------------------------------------------------------------------------
1 | 3.11.3
2 |
--------------------------------------------------------------------------------
/src/min_matrix.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfeinsper/drone-swarm-search-algorithms/HEAD/src/min_matrix.npy
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 PFE Embraer
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [PyPI](https://badge.fury.io/py/DSSE)
2 | [License](https://github.com/pfeinsper/drone-swarm-search/blob/main/LICENSE)
3 |
4 | 
5 |
6 | # Algorithms for Drone Swarm Search (DSSE)
7 |
8 | Welcome to the official GitHub repository for the Drone Swarm Search (DSSE) algorithms. These algorithms are specifically tailored for reinforcement learning environments aimed at optimizing drone swarm coordination and search efficiency.
9 |
10 | Explore a diverse range of implementations that leverage the latest advancements in machine learning to solve complex coordination tasks in dynamic and unpredictable environments.
11 |
12 | ## 📚 Documentation Links
13 |
14 | - **[Documentation Site](https://pfeinsper.github.io/drone-swarm-search/)**: Access detailed tutorials, usage examples, and comprehensive technical documentation. This resource is essential for understanding the DSSE framework and integrating these algorithms into your projects effectively.
15 |
16 | - **[DSSE Training Environment Repository](https://github.com/pfeinsper/drone-swarm-search)**: Visit the repository for the DSSE training environment, where you can access the core environment setups and configurations used for developing and testing the algorithms.
17 |
18 | - **[PyPI Repository](https://pypi.org/project/DSSE/)**: Download the latest release of DSSE, view the version history, and find installation instructions. Keep up with the latest updates and improvements to the algorithms.
19 |
20 | ## 🆘 Support and Community
21 |
22 | Run into a snag? Have a suggestion? Join our community on GitHub! Submit your queries, report bugs, or contribute to discussions by visiting our [issues page](https://github.com/pfeinsper/drone-swarm-search-algorithms/issues). Your input helps us improve and evolve.
23 |
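24 | ## 🚀 Quick Start
25 |
26 | The training scripts in `src/` all construct the environment through a small `env_creator` function. The sketch below mirrors the setup used in `src/train_ppo_cnn.py`, driving the environment with random actions in place of a trained policy; it assumes the `DSSE` package is installed (e.g. `pip install DSSE`).
27 |
28 | ```python
29 | from DSSE import DroneSwarmSearch
30 | from DSSE.environment.wrappers import AllPositionsWrapper, RetainDronePosWrapper
31 |
32 | # 4 drones on a 40x40 grid, pinned to fixed starting positions on the grid edges.
33 | env = DroneSwarmSearch(
34 |     drone_amount=4, grid_size=40, dispersion_inc=0.1, person_initial_position=(20, 20)
35 | )
36 | env = AllPositionsWrapper(env)
37 | env = RetainDronePosWrapper(env, [(20, 0), (20, 39), (0, 20), (39, 20)])
38 |
39 | obs, info = env.reset()
40 | while env.agents:
41 |     # Random actions sampled via the PettingZoo parallel API; swap in a trained policy here.
42 |     actions = {agent: env.action_space(agent).sample() for agent in env.agents}
43 |     obs, rewards, terminations, truncations, infos = env.step(actions)
44 | env.close()
45 | ```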
--------------------------------------------------------------------------------
/src/train_ppo_mlp_cov.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from DSSE import CoverageDroneSwarmSearch
3 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllFlattenWrapper
4 | import ray
5 | from ray import tune
6 | from ray.rllib.algorithms.ppo import PPOConfig
7 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
8 | from ray.tune.registry import register_env
9 | from torch import nn
10 | import torch
11 | import numpy as np
12 |
13 |
14 | def env_creator(args):
15 | print("-------------------------- ENV CREATOR --------------------------")
16 | N_AGENTS = 2
17 | # 6 hours of simulation, 600 radius
18 | env = CoverageDroneSwarmSearch(
19 | timestep_limit=200, drone_amount=N_AGENTS, prob_matrix_path="min_matrix.npy"
20 | )
21 | env = AllFlattenWrapper(env)
22 | grid_size = env.grid_size
23 | print("Grid size: ", grid_size)
24 | positions = [
25 | (0, grid_size // 2),
26 | (grid_size - 1, grid_size // 2),
27 | ]
28 | env = RetainDronePosWrapper(env, positions)
29 | return env
30 |
31 | def position_on_diagonal(grid_size, drone_amount):
32 | positions = []
33 | center = grid_size // 2
34 | for i in range(-drone_amount // 2, drone_amount // 2):
35 | positions.append((center + i, center + i))
36 | return positions
37 |
38 | def position_on_circle(grid_size, drone_amount, radius):
39 | positions = []
40 | center = grid_size // 2
41 | angle_increment = 2 * np.pi / drone_amount
42 |
43 | for i in range(drone_amount):
44 | angle = i * angle_increment
45 | x = center + int(radius * np.cos(angle))
46 | y = center + int(radius * np.sin(angle))
47 | positions.append((x, y))
48 |
49 | return positions
50 |
51 |
52 | if __name__ == "__main__":
53 | ray.init()
54 |
55 | env_name = "DSSE_Coverage"
56 |
57 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
58 |
59 | config = (
60 | PPOConfig()
61 | .environment(env=env_name)
62 | .rollouts(num_rollout_workers=6, rollout_fragment_length="auto", num_envs_per_worker=4)
63 | .training(
64 | train_batch_size=8192 * 3,
65 | lr=8e-6,
66 | gamma=0.9999999,
67 | lambda_=0.9,
68 | use_gae=True,
69 | entropy_coeff=0.01,
70 | vf_clip_param=100000,
71 | sgd_minibatch_size=300,
72 | num_sgd_iter=10,
73 | model={
74 | "fcnet_hiddens": [512, 256],
75 | },
76 | )
77 | .experimental(_disable_preprocessor_api=True)
78 | .debugging(log_level="ERROR")
79 | .framework(framework="torch")
80 | .resources(num_gpus=1)
81 | )
82 |
83 | curr_path = pathlib.Path().resolve()
84 | tune.run(
85 | "PPO",
86 | name="PPO_" + input("Exp name: "),
87 | # resume=True,
88 | stop={"timesteps_total": 40_000_000},
89 | checkpoint_freq=25,
90 | storage_path=f"{curr_path}/ray_res/" + env_name,
91 | config=config.to_dict(),
92 | )
93 |
--------------------------------------------------------------------------------
/src/recorder.py:
--------------------------------------------------------------------------------
1 | """
2 | PygameRecord - A utility for recording Pygame screens as GIFs.
3 | This module provides a class, PygameRecord, which can be used to record Pygame
4 | animations and save them as GIF files. It captures frames from the Pygame display
5 | and saves them as images, then combines them into a GIF file.
6 | Credits:
7 | - Author: Ricardo Ribeiro Rodrigues
8 | - Date: 21/03/2024
9 | - source: https://gist.github.com/RicardoRibeiroRodrigues/9c40f36909112950860a410a565de667
10 | Usage:
11 | 1. Initialize PygameRecord with a filename and desired frames per second (fps).
12 | 2. Enter a Pygame event loop.
13 | 3. Add frames to the recorder at desired intervals.
14 | 4. When done recording, exit the Pygame event loop.
15 | 5. The recorded GIF will be saved automatically.
16 | """
17 |
18 | import pygame
19 | from PIL import Image
20 | import numpy as np
21 |
22 |
23 | class PygameRecord:
24 | def __init__(self, filename: str, fps: int):
25 | self.fps = fps
26 | self.filename = filename
27 | self.frames = []
28 |
29 | def add_frame(self):
30 | curr_surface = pygame.display.get_surface()
31 | x3 = pygame.surfarray.array3d(curr_surface)
32 | x3 = np.moveaxis(x3, 0, 1)
33 | array = Image.fromarray(np.uint8(x3))
34 | self.frames.append(array)
35 |
36 | def save(self):
37 | self.frames[0].save(
38 | self.filename,
39 | save_all=True,
40 | optimize=False,
41 | append_images=self.frames[1:],
42 | loop=0,
43 | duration=int(1000 / self.fps),
44 | )
45 |
46 | def __enter__(self):
47 | return self
48 |
49 | def __exit__(self, exc_type, exc_value, traceback):
50 | if exc_type is not None:
51 | print(f"An exception of type {exc_type} occurred: {exc_value}")
52 | self.save()
53 | # Return False if you want exceptions to propagate, True to suppress them
54 | return False
55 |
56 |
57 | if __name__ == "__main__":
58 | # Example usage
59 | from random import randint
60 |
61 | FPS = 30
62 | # Init the recorder with the output file and the desired FPS
63 | with PygameRecord("output.gif", FPS) as recorder:
64 | pygame.init()
65 | screen = pygame.display.set_mode((400, 400))
66 | running = True
67 | clock = pygame.time.Clock()
68 | n_frames = 90
69 | while running:
70 | for event in pygame.event.get():
71 | if event.type == pygame.QUIT:
72 | running = False
73 | screen.fill((0, 0, 0))
74 | pygame.draw.circle(
75 | screen,
76 | (randint(0, 255), randint(0, 255), randint(0, 255)),
77 | (200, 200),
78 | 50,
79 | )
80 | recorder.add_frame() # Add frame to recorder
81 | pygame.display.flip()
82 | clock.tick(FPS)
83 | # Used here to limit the size of the GIF, not necessary for normal usage.
84 | n_frames -= 1
85 | if n_frames == 0:
86 | break
87 | recorder.save()
88 | pygame.quit()
--------------------------------------------------------------------------------
/src/train_ppo_mlp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | from DSSE import DroneSwarmSearch
4 | from DSSE.environment.wrappers import TopNProbsWrapper
5 | import ray
6 | from ray import tune
7 | from ray.rllib.algorithms.ppo import PPOConfig
8 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
9 | from ray.rllib.models import ModelCatalog
10 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
11 | from ray.tune.registry import register_env
12 | import torch
13 | from torch import nn
14 |
15 |
16 | class MLPModel(TorchModelV2, nn.Module):
17 | def __init__(
18 | self,
19 | obs_space,
20 | act_space,
21 | num_outputs,
22 | model_config,
23 | name,
24 | **kw,
25 | ):
26 | print("OBSSPACE: ", obs_space)
27 | TorchModelV2.__init__(
28 | self, obs_space, act_space, num_outputs, model_config, name, **kw
29 | )
30 | nn.Module.__init__(self)
31 |
32 | self.model = nn.Sequential(
33 | nn.Linear(obs_space.shape[0], 512),
34 | nn.ReLU(),
35 | nn.Linear(512, 256),
36 | nn.ReLU(),
37 | )
38 | self.policy_fn = nn.Linear(256, num_outputs)
39 | self.value_fn = nn.Linear(256, 1)
40 |
41 | def forward(self, input_dict, state, seq_lens):
42 | input_ = input_dict["obs"].float()
43 | value_input = self.model(input_)
44 |
45 | self._value_out = self.value_fn(value_input)
46 | return self.policy_fn(value_input), state
47 |
48 | def value_function(self):
49 | return self._value_out.flatten()
50 |
51 |
52 | def env_creator(args):
53 | env = DroneSwarmSearch(
54 | drone_amount=4,
55 | grid_size=20,
56 | dispersion_inc=0.1,
57 | person_initial_position=(10, 10),
58 | )
59 | env = TopNProbsWrapper(env, 10)
60 | return env
61 |
62 |
63 | if __name__ == "__main__":
64 | ray.init()
65 |
66 | env_name = "DSSE"
67 |
68 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
69 | ModelCatalog.register_custom_model("MLPModel", MLPModel)
70 |
71 | config = (
72 | PPOConfig()
73 | .environment(env=env_name)
74 | .rollouts(num_rollout_workers=5, rollout_fragment_length="auto")
75 | .training(
76 | train_batch_size=8192,
77 | lr=4e-5,
78 | gamma=0.99999,
79 | lambda_=0.9,
80 | use_gae=True,
81 | clip_param=0.4,
82 | grad_clip=None,
83 | entropy_coeff=0.1,
84 | vf_loss_coeff=0.25,
85 | vf_clip_param=4200,
86 | sgd_minibatch_size=1024,
87 | num_sgd_iter=10,
88 | model={
89 | "custom_model": "MLPModel",
90 | "_disable_preprocessor_api": True,
91 | },
92 | )
93 | .experimental(_disable_preprocessor_api=True)
94 | .debugging(log_level="ERROR")
95 | .framework(framework="torch")
96 | .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "1")))
97 | )
98 |
99 | curr_path = pathlib.Path().resolve()
100 | tune.run(
101 | "PPO",
102 | name="PPO",
103 | stop={"timesteps_total": 5000000 if not os.environ.get("CI") else 50000},
104 | checkpoint_freq=10,
105 | storage_path=f"{curr_path}/ray_res/" + env_name,
106 | config=config.to_dict(),
107 | )
108 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # pandas
2 | # DSSE
3 | # torch
4 | # ray[rllib]
5 | aigyminsper
6 |
7 | #
8 |
9 | aiohttp==3.9.5
10 | aiosignal==1.3.1
11 | anyio==4.6.0
12 | argon2-cffi==23.1.0
13 | argon2-cffi-bindings==21.2.0
14 | arrow==1.3.0
15 | asciitree==0.3.3
16 | asttokens==2.4.1
17 | async-lru==2.0.4
18 | attrs==24.2.0
19 | babel==2.16.0
20 | beautifulsoup4==4.12.3
21 | bleach==6.1.0
22 | boto3==1.35.31
23 | botocore==1.35.31
24 | cachier==3.0.1
25 | Cartopy==0.23.0
26 | certifi==2024.8.30
27 | cf_xarray==0.9.5
28 | cffi==1.17.1
29 | cftime==1.6.4
30 | charset-normalizer==3.3.2
31 | click==8.1.7
32 | cloudpickle==3.0.0
33 | cmocean==4.0.3
34 | coloredlogs==15.0.1
35 | comm==0.2.2
36 | contourpy==1.3.0
37 | copernicusmarine==1.3.3
38 | cycler==0.12.1
39 | dask==2024.9.1
40 | debugpy==1.8.6
41 | decorator==5.1.1
42 | defusedxml==0.7.1
43 | DSSE==1.1.9
44 | executing==2.1.0
45 | Farama-Notifications==0.0.4
46 | fasteners==0.19
47 | fastjsonschema==2.20.0
48 | fonttools==4.54.1
49 | fqdn==1.5.1
50 | frozenlist==1.4.1
51 | fsspec==2024.9.0
52 | GDAL==3.4.1
53 | geojson==3.1.0
54 | gymnasium==0.29.1
55 | h11==0.14.0
56 | httpcore==1.0.6
57 | httpx==0.27.2
58 | humanfriendly==10.0
59 | idna==3.10
60 | ipykernel==6.29.5
61 | ipython==8.28.0
62 | isoduration==20.11.0
63 | jedi==0.19.1
64 | Jinja2==3.1.4
65 | jmespath==1.0.1
66 | json5==0.9.25
67 | jsonpointer==3.0.0
68 | jsonschema==4.23.0
69 | jsonschema-specifications==2023.12.1
70 | jupyter-events==0.10.0
71 | jupyter-lsp==2.2.5
72 | jupyter_client==8.6.3
73 | jupyter_core==5.7.2
74 | jupyter_server==2.14.2
75 | jupyter_server_terminals==0.5.3
76 | jupyterlab==4.2.5
77 | jupyterlab_pygments==0.3.0
78 | jupyterlab_server==2.27.3
79 | kiwisolver==1.4.7
80 | llvmlite==0.43.0
81 | locket==1.0.0
82 | lxml==5.3.0
83 | MarkupSafe==2.1.5
84 | matplotlib==3.8.4
85 | matplotlib-inline==0.1.7
86 | mistune==3.0.2
87 | multidict==6.1.0
88 | nbclient==0.10.0
89 | nbconvert==7.16.4
90 | nbformat==5.10.4
91 | nc-time-axis==1.4.1
92 | nest-asyncio==1.6.0
93 | netCDF4==1.7.1.post2
94 | notebook==7.2.2
95 | notebook_shim==0.2.4
96 | numba==0.60.0
97 | numcodecs==0.13.0
98 | numpy==1.26.4
99 | opendrift==1.11.13
100 | overrides==7.7.0
101 | packaging==24.1
102 | pandas==2.2.3
103 | pandocfilters==1.5.1
104 | parso==0.8.4
105 | partd==1.4.2
106 | pettingzoo==1.24.3
107 | pexpect==4.9.0
108 | pillow==10.4.0
109 | platformdirs==4.3.6
110 | portalocker==2.10.1
111 | prometheus_client==0.21.0
112 | prompt_toolkit==3.0.48
113 | psutil==6.0.0
114 | ptyprocess==0.7.0
115 | pure_eval==0.2.3
116 | pycparser==2.22
117 | pygame==2.6.1
118 | Pygments==2.18.0
119 | pykdtree==1.3.13
120 | pynucos==3.2.2
121 | pyparsing==3.1.4
122 | pyproj==3.7.0
123 | pyshp==2.3.1
124 | pystac==1.10.1
125 | python-dateutil==2.9.0.post0
126 | python-json-logger==2.0.7
127 | pytz==2024.2
128 | PyYAML==6.0.2
129 | pyzmq==26.2.0
130 | referencing==0.35.1
131 | requests==2.32.3
132 | rfc3339-validator==0.1.4
133 | rfc3986-validator==0.1.1
134 | roaring-landmask==0.9.1
135 | rpds-py==0.20.0
136 | s3transfer==0.10.2
137 | scipy==1.14.1
138 | semver==3.0.2
139 | Send2Trash==1.8.3
140 | setuptools==75.1.0
141 | shapely==2.0.6
142 | six==1.16.0
143 | sniffio==1.3.1
144 | soupsieve==2.6
145 | stack-data==0.6.3
146 | terminado==0.18.1
147 | tinycss2==1.3.0
148 | toolz==0.12.1
149 | tornado==6.4.1
150 | tqdm==4.66.5
151 | traitlets==5.14.3
152 | trajan==0.6.3
153 | types-python-dateutil==2.9.0.20240906
154 | typing_extensions==4.12.2
155 | tzdata==2024.2
156 | uri-template==1.3.0
157 | urllib3==2.2.3
158 | utm==0.7.0
159 | watchdog==5.0.3
160 | wcwidth==0.2.13
161 | webcolors==24.8.0
162 | webencodings==0.5.1
163 | websocket-client==1.8.0
164 | xarray==2024.9.0
165 | xhistogram==0.3.2
166 | yarl==1.13.1
167 | zarr==2.18.3
168 | rsutils
--------------------------------------------------------------------------------
/src/test_trained_cov_mlp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | from recorder import PygameRecord
4 | from DSSE import CoverageDroneSwarmSearch
5 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper, AllFlattenWrapper
6 | import ray
7 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
8 | from ray.rllib.models import ModelCatalog
9 | from ray.tune.registry import register_env
10 | from ray.rllib.algorithms.ppo import PPO
11 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
12 | import torch.nn as nn
13 | import argparse
14 | import torch
15 |
16 |
17 |
18 | argparser = argparse.ArgumentParser()
19 | argparser.add_argument("--checkpoint", type=str, required=True)
20 | argparser.add_argument("--see", action="store_true", default=False)
21 | args = argparser.parse_args()
22 |
23 |
24 | def env_creator(_):
25 | print("-------------------------- ENV CREATOR --------------------------")
26 | N_AGENTS = 2
27 | render_mode = "human" if args.see else "ansi"
28 | # 6 hours of simulation, 600 radius
29 | env = CoverageDroneSwarmSearch(
30 | timestep_limit=200, drone_amount=N_AGENTS, prob_matrix_path="min_matrix.npy", render_mode=render_mode
31 | )
32 | env = AllFlattenWrapper(env)
33 | grid_size = env.grid_size
34 | print("Grid size: ", grid_size)
35 | positions = [
36 | (0, grid_size // 2),
37 | (grid_size - 1, grid_size // 2),
38 | ]
39 | env = RetainDronePosWrapper(env, positions)
40 | return env
41 | def position_on_diagonal(grid_size, drone_amount):
42 | positions = []
43 | center = grid_size // 2
44 | for i in range(-drone_amount // 2, drone_amount // 2):
45 | positions.append((center + i, center + i))
46 | return positions
47 |
48 | env = env_creator(None)
49 | register_env("DSSE_Coverage", lambda config: ParallelPettingZooEnv(env_creator(config)))
50 | ray.init()
51 |
52 | def print_mean(values, name):
53 | print(f"Mean of {name}: ", sum(values) / len(values))
54 |
55 | checkpoint_path = args.checkpoint
56 | PPOagent = PPO.from_checkpoint(checkpoint_path)
57 |
58 | reward_sum = 0
59 | i = 0
60 |
61 | if args.see:
62 | obs, info = env.reset()
63 | with PygameRecord("test_trained.gif", 5) as rec:
64 | while env.agents:
65 | actions = {}
66 | for k, v in obs.items():
67 | actions[k] = PPOagent.compute_single_action(v, explore=False)
68 | # print(v)
69 | # action = PPOagent.compute_actions(obs)
70 | obs, rw, term, trunc, info = env.step(actions)
71 | reward_sum += sum(rw.values())
72 | i += 1
73 | rec.add_frame()
74 | print(info)
75 | else:
76 | rewards = []
77 | cov_rate = []
78 | steps_needed = []
79 | repeated_cov = []
80 |
81 | N_EVALS = 500
82 | for _ in range(N_EVALS):
83 | i = 0
84 | obs, info = env.reset()
85 | reward_sum = 0
86 | while env.agents:
87 | actions = {}
88 | for k, v in obs.items():
89 | actions[k] = PPOagent.compute_single_action(v, explore=False)
90 | obs, rw, term, trunc, info = env.step(actions)
91 | reward_sum += sum(rw.values())
92 | i += 1
93 | rewards.append(reward_sum)
94 | steps_needed.append(i)
95 | cov_rate.append(info["drone0"]["coverage_rate"])
96 | repeated_cov.append(info["drone0"]["repeated_coverage"])
97 |
98 | print_mean(rewards, "rewards")
99 | print_mean(steps_needed, "steps needed")
100 | print_mean(cov_rate, "coverage rate")
101 | print_mean(repeated_cov, "repeated coverage")
102 |
103 | print("Total reward: ", reward_sum)
104 | print("Total steps: ", i)
105 | print("Found: ", info)
106 | env.close()
107 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
--------------------------------------------------------------------------------
/imgs/drone.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/train_ppo_cnn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | from DSSE import DroneSwarmSearch
4 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper
5 | import ray
6 | from ray import tune
7 | from ray.rllib.algorithms.ppo import PPOConfig
8 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
9 | from ray.rllib.models import ModelCatalog
10 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
11 | from ray.tune.registry import register_env
12 | from torch import nn
13 | import torch
14 |
15 |
16 | class CNNModel(TorchModelV2, nn.Module):
17 | def __init__(
18 | self,
19 | obs_space,
20 | act_space,
21 | num_outputs,
22 | model_config,
23 | name,
24 | **kw,
25 | ):
26 | print("OBSSPACE: ", obs_space)
27 | TorchModelV2.__init__(
28 | self, obs_space, act_space, num_outputs, model_config, name, **kw
29 | )
30 | nn.Module.__init__(self)
31 |
32 | flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[1] - 7 - 3)
33 | self.cnn = nn.Sequential(
34 | nn.Conv2d(
35 | in_channels=1,
36 | out_channels=16,
37 | kernel_size=(8, 8),
38 | stride=(1, 1),
39 | ),
40 | nn.Tanh(),
41 | nn.Conv2d(
42 | in_channels=16,
43 | out_channels=32,
44 | kernel_size=(4, 4),
45 | stride=(1, 1),
46 | ),
47 | nn.Tanh(),
48 | nn.Flatten(),
49 | nn.Linear(flatten_size, 256),
50 | nn.Tanh(),
51 | )
52 |
53 | self.linear = nn.Sequential(
54 | nn.Linear(obs_space[0].shape[0], 512),
55 | nn.Tanh(),
56 | nn.Linear(512, 256),
57 | nn.Tanh(),
58 | )
59 |
60 | self.join = nn.Sequential(
61 | nn.Linear(256 * 2, 256),
62 | nn.Tanh(),
63 | )
64 |
65 | self.policy_fn = nn.Linear(256, num_outputs)
66 | self.value_fn = nn.Linear(256, 1)
67 |
68 | def forward(self, input_dict, state, seq_lens):
69 | input_positions = input_dict["obs"][0].float()
70 | input_matrix = input_dict["obs"][1].float()
71 |
72 | input_matrix = input_matrix.unsqueeze(1)
73 | cnn_out = self.cnn(input_matrix)
74 | linear_out = self.linear(input_positions)
75 |
76 | value_input = torch.cat((cnn_out, linear_out), dim=1)
77 | value_input = self.join(value_input)
78 |
79 | self._value_out = self.value_fn(value_input)
80 | return self.policy_fn(value_input), state
81 |
82 | def value_function(self):
83 | return self._value_out.flatten()
84 |
85 |
86 | def env_creator(args):
87 | env = DroneSwarmSearch(
88 | drone_amount=4,
89 | grid_size=40,
90 | dispersion_inc=0.1,
91 | person_initial_position=(20, 20),
92 | )
93 | positions = [
94 | (20, 0),
95 | (20, 39),
96 | (0, 20),
97 | (39, 20),
98 | ]
99 | env = AllPositionsWrapper(env)
100 | env = RetainDronePosWrapper(env, positions)
101 | return env
102 |
103 |
104 | if __name__ == "__main__":
105 | ray.init()
106 |
107 | env_name = "DSSE"
108 |
109 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
110 | ModelCatalog.register_custom_model("CNNModel", CNNModel)
111 |
112 | config = (
113 | PPOConfig()
114 | .environment(env=env_name)
115 | .rollouts(num_rollout_workers=6, rollout_fragment_length="auto")
116 | .training(
117 | train_batch_size=8192,
118 | lr=1e-5,
119 | gamma=0.9999999,
120 | lambda_=0.9,
121 | use_gae=True,
122 | # clip_param=0.3,
123 | # grad_clip=None,
124 | entropy_coeff=0.01,
125 | # vf_loss_coeff=0.25,
126 | # vf_clip_param=10,
127 | sgd_minibatch_size=300,
128 | num_sgd_iter=10,
129 | model={
130 | "custom_model": "CNNModel",
131 | "_disable_preprocessor_api": True,
132 | },
133 | )
134 | .experimental(_disable_preprocessor_api=True)
135 | .debugging(log_level="ERROR")
136 | .framework(framework="torch")
137 | .resources(num_gpus=1)
138 | )
139 |
140 | curr_path = pathlib.Path().resolve()
141 | tune.run(
142 | "PPO",
143 | name="PPO",
144 | stop={"timesteps_total": 20_000_000 if not os.environ.get("CI") else 50000, "episode_reward_mean": 1.75},
145 | checkpoint_freq=10,
146 | storage_path=f"{curr_path}/ray_res/" + env_name,
147 | config=config.to_dict(),
148 | )
149 |
--------------------------------------------------------------------------------
/src/train_ppo_encoded.py:
--------------------------------------------------------------------------------
1 | import os
2 | from DSSE import DroneSwarmSearch
3 | from DSSE.environment.wrappers.matrix_encode_wrapper import MatrixEncodeWrapper
4 | import ray
5 | from ray import tune
6 | from ray.rllib.algorithms.ppo import PPOConfig
7 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
8 | from ray.rllib.models import ModelCatalog
9 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
10 | from ray.tune.registry import register_env
11 | import torch
12 | from torch import nn
13 |
14 |
15 | class MLPModel(TorchModelV2, nn.Module):
16 | def __init__(
17 | self,
18 | obs_space,
19 | act_space,
20 | num_outputs,
21 | model_config,
22 | name,
23 | **kw,
24 | ):
25 | print(
26 | "______________________________ MLPModel ________________________________"
27 | )
28 | print("OBSSPACE: ", obs_space)
29 | print("ACTSPACE: ", act_space)
30 | print("NUMOUTPUTS: ", num_outputs)
31 | TorchModelV2.__init__(
32 | self, obs_space, act_space, num_outputs, model_config, name, **kw
33 | )
34 | nn.Module.__init__(self)
35 | self.conv1 = nn.Sequential(
36 | nn.Conv2d(
37 | in_channels=1,
38 | out_channels=16,
39 | kernel_size=(8, 8),
40 | stride=(1, 1),
41 | ),
42 | nn.ReLU(),
43 | # nn.MaxPool2d(kernel_size=2),
44 | nn.Conv2d(
45 | in_channels=16,
46 | out_channels=32,
47 | kernel_size=(4, 4),
48 | stride=(1, 1)
49 | ),
50 | nn.ReLU(),
51 | nn.Conv2d(
52 | in_channels=32,
53 | out_channels=64,
54 | kernel_size=(3, 3),
55 | stride=(1, 1)
56 | ),
57 | # nn.MaxPool2d(kernel_size=2),
58 | nn.Flatten(),
59 | )
60 | grid_size = obs_space.shape[0]
61 | # Fully connected layers
62 | # F: (((W - K + 2P)/S) + 1)
63 | first_conv_output_size = (grid_size - 8) + 1
64 | second_conv_output_size = (first_conv_output_size - 4) + 1
65 | third_conv_output_size = (second_conv_output_size - 3) + 1
66 | self.fc1_input_size = third_conv_output_size * third_conv_output_size * 64
67 | # Apply a DENSE layer to the flattened CNN2 output
68 | self.fc1 = nn.Linear(self.fc1_input_size, 512)
69 | self.policy_fn = nn.Linear(512, num_outputs)
70 | self.value_fn = nn.Linear(512, 1)
71 |
72 | def forward(self, input_dict, state, seq_lens):
73 | input_ = input_dict["obs"]
74 | # Convert dims
75 | input_ = input_.unsqueeze(1)
76 | model_out = self.conv1(input_)
77 | model_out = self.fc1(model_out)
78 | self._value_out = self.value_fn(model_out)
79 | return self.policy_fn(model_out), state
80 |
81 | def value_function(self):
82 | return self._value_out.flatten()
83 |
84 |
85 | def env_creator(args):
86 | env = DroneSwarmSearch(
87 | drone_amount=4,
88 | grid_size=20,
89 | dispersion_inc=0.1,
90 | person_initial_position=(10, 10),
91 | )
92 | env = MatrixEncodeWrapper(env)
93 | return env
94 |
95 |
96 | if __name__ == "__main__":
97 | ray.init()
98 |
99 | env_name = "DSSE"
100 |
101 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
102 | ModelCatalog.register_custom_model("MLPModel", MLPModel)
103 |
104 | config = (
105 | PPOConfig()
106 | .environment(env=env_name)
107 | .rollouts(num_rollout_workers=5, rollout_fragment_length='auto')
108 | .training(
109 | train_batch_size=512,
110 | lr=2e-5,
111 | gamma=0.99999,
112 | lambda_=0.9,
113 | use_gae=True,
114 | clip_param=0.3,
115 | grad_clip=None,
116 | entropy_coeff=0.1,
117 | vf_loss_coeff=0.25,
118 | vf_clip_param=420,
119 | sgd_minibatch_size=64,
120 | num_sgd_iter=10,
121 | model={
122 | "custom_model": "MLPModel",
123 | "_disable_preprocessor_api": True,
124 | },
125 | )
126 | .debugging(log_level="ERROR")
127 | .framework(framework="torch")
128 | .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "1")))
129 | )
130 | config["_disable_preprocessor_api"] = False
131 |
132 | tune.run(
133 | "PPO",
134 | name="PPO",
135 | stop={"timesteps_total": 5000000 if not os.environ.get("CI") else 50000},
136 | checkpoint_freq=10,
137 | # storage_path="ray_res/" + env_name,
138 | config=config.to_dict(),
139 | )
140 |
--------------------------------------------------------------------------------
/src/train_ppo_cnn_comm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | from DSSE import DroneSwarmSearch
4 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper
5 | from DSSE.environment.wrappers.communication_wrapper import CommunicationWrapper
6 | import ray
7 | from ray import tune
8 | from ray.rllib.algorithms.ppo import PPOConfig
9 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
10 | from ray.rllib.models import ModelCatalog
11 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
12 | from ray.tune.registry import register_env
13 | from torch import nn
14 | import torch
15 |
16 |
17 | class CNNModel(TorchModelV2, nn.Module):
18 | def __init__(
19 | self,
20 | obs_space,
21 | act_space,
22 | num_outputs,
23 | model_config,
24 | name,
25 | **kw,
26 | ):
27 | print("OBSSPACE: ", obs_space)
28 | TorchModelV2.__init__(
29 | self, obs_space, act_space, num_outputs, model_config, name, **kw
30 | )
31 | nn.Module.__init__(self)
32 |
33 | flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[1] - 7 - 3)
34 | self.cnn = nn.Sequential(
35 | nn.Conv2d(
36 | in_channels=1,
37 | out_channels=16,
38 | kernel_size=(8, 8),
39 | stride=(1, 1),
40 | ),
41 | nn.Tanh(),
42 | nn.Conv2d(
43 | in_channels=16,
44 | out_channels=32,
45 | kernel_size=(4, 4),
46 | stride=(1, 1),
47 | ),
48 | nn.Tanh(),
49 | nn.Flatten(),
50 | nn.Linear(flatten_size, 256),
51 | nn.Tanh(),
52 | )
53 |
54 | self.linear = nn.Sequential(
55 | nn.Linear(obs_space[0].shape[0], 512),
56 | nn.Tanh(),
57 | nn.Linear(512, 256),
58 | nn.Tanh(),
59 | )
60 |
61 | self.join = nn.Sequential(
62 | nn.Linear(256 * 2, 256),
63 | nn.Tanh(),
64 | )
65 |
66 | self.policy_fn = nn.Linear(256, num_outputs)
67 | self.value_fn = nn.Linear(256, 1)
68 |
69 | def forward(self, input_dict, state, seq_lens):
70 | input_positions = input_dict["obs"][0].float()
71 | input_matrix = input_dict["obs"][1].float()
72 |
73 | input_matrix = input_matrix.unsqueeze(1)
74 | cnn_out = self.cnn(input_matrix)
75 | linear_out = self.linear(input_positions)
76 |
77 | value_input = torch.cat((cnn_out, linear_out), dim=1)
78 | value_input = self.join(value_input)
79 |
80 | self._value_out = self.value_fn(value_input)
81 | return self.policy_fn(value_input), state
82 |
83 | def value_function(self):
84 | return self._value_out.flatten()
85 |
86 |
87 | def env_creator(args):
88 | env = DroneSwarmSearch(
89 | drone_amount=4,
90 | grid_size=40,
91 | dispersion_inc=0.1,
92 | person_initial_position=(20, 20),
93 | )
94 | positions = [
95 | (20, 0),
96 | (20, 39),
97 | (0, 20),
98 | (39, 20),
99 | ]
100 | env = AllPositionsWrapper(env)
101 | env = CommunicationWrapper(env, n_steps=12)
102 | env = RetainDronePosWrapper(env, positions)
103 | return env
104 |
105 |
106 | if __name__ == "__main__":
107 | ray.init()
108 |
109 | env_name = "DSSE"
110 |
111 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
112 | ModelCatalog.register_custom_model("CNNModel", CNNModel)
113 |
114 | config = (
115 | PPOConfig()
116 | .environment(env=env_name)
117 | .rollouts(num_rollout_workers=6, rollout_fragment_length="auto", num_envs_per_worker=2)
118 | .training(
119 | train_batch_size=8192,
120 | lr=1e-5,
121 | gamma=0.9999999,
122 | lambda_=0.9,
123 | use_gae=True,
124 | # clip_param=0.3,
125 | # grad_clip=None,
126 | entropy_coeff=0.01,
127 | # vf_loss_coeff=0.25,
128 | # vf_clip_param=10,
129 | sgd_minibatch_size=300,
130 | num_sgd_iter=10,
131 | model={
132 | "custom_model": "CNNModel",
133 | "_disable_preprocessor_api": True,
134 | },
135 | )
136 | .experimental(_disable_preprocessor_api=True)
137 | .debugging(log_level="ERROR")
138 | .framework(framework="torch")
139 | .resources(num_gpus=1)
140 | )
141 |
142 | curr_path = pathlib.Path().resolve()
143 | tune.run(
144 | "PPO",
145 | name="PPO_COMM_WRAPPER",
146 | stop={"timesteps_total": 20_000_000 if not os.environ.get("CI") else 50000, "episode_reward_mean": 1.75},
147 | checkpoint_freq=10,
148 | storage_path=f"{curr_path}/ray_res/" + env_name,
149 | config=config.to_dict(),
150 | )
151 |
--------------------------------------------------------------------------------
/src/train_dqn_cnn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | from DSSE import DroneSwarmSearch
4 | from DSSE.environment.wrappers import AllPositionsWrapper, RetainDronePosWrapper, TopNProbsWrapper
5 | # from DSSE.environment.wrappers import AllPositionsWrapper
6 | import ray
7 | from ray import tune
8 | from ray.rllib.algorithms.dqn import DQNConfig
9 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
10 | from ray.rllib.models import ModelCatalog
11 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
12 | from ray.tune.registry import register_env
13 | import torch
14 | from torch import nn
15 | import random
16 |
17 | class CNNModel(TorchModelV2, nn.Module):
18 | def __init__(
19 | self,
20 | obs_space,
21 | act_space,
22 | num_outputs,
23 | model_config,
24 | name,
25 | **kw,
26 | ):
27 | print("OBSSPACE: ", obs_space)
28 | TorchModelV2.__init__(
29 | self, obs_space, act_space, num_outputs, model_config, name, **kw
30 | )
31 | nn.Module.__init__(self)
32 |
33 | flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[1] - 7 - 3)
34 | self.cnn = nn.Sequential(
35 | nn.Conv2d(
36 | in_channels=1,
37 | out_channels=16,
38 | kernel_size=(8, 8),
39 | stride=(1, 1),
40 | ),
41 | nn.Tanh(),
42 | nn.Conv2d(
43 | in_channels=16,
44 | out_channels=32,
45 | kernel_size=(4, 4),
46 | stride=(1, 1),
47 | ),
48 | nn.Tanh(),
49 | nn.Flatten(),
50 | nn.Linear(flatten_size, 256),
51 | nn.Tanh(),
52 | )
53 |
54 | self.linear = nn.Sequential(
55 | nn.Linear(obs_space[0].shape[0], 512),
56 | nn.Tanh(),
57 | nn.Linear(512, 256),
58 | nn.Tanh(),
59 | )
60 |
61 | self.join = nn.Sequential(
62 | nn.Linear(256 * 2, 256),
63 | nn.Tanh(),
64 | )
65 |
66 | self.policy_fn = nn.Linear(256, num_outputs)
67 | self.value_fn = nn.Linear(256, 1)
68 |
69 | def forward(self, input_dict, state, seq_lens):
70 | input_positions = input_dict["obs"][0].float()
71 | input_matrix = input_dict["obs"][1].float()
72 |
73 | input_matrix = input_matrix.unsqueeze(1)
74 | cnn_out = self.cnn(input_matrix)
75 | linear_out = self.linear(input_positions)
76 |
77 | value_input = torch.cat((cnn_out, linear_out), dim=1)
78 | value_input = self.join(value_input)
79 |
80 | self._value_out = self.value_fn(value_input)
81 | return self.policy_fn(value_input), state
82 |
83 | def value_function(self):
84 | return self._value_out.flatten()
85 |
86 | def env_creator(args):
87 | env = DroneSwarmSearch(
88 | drone_amount=4,
89 | grid_size=40,
90 | dispersion_inc=0.1,
91 | person_initial_position=(20, 20),
92 | )
93 | positions = [
94 | (20, 0),
95 | (20, 39),
96 | (0, 20),
97 | (39, 20),
98 | ]
99 | env = AllPositionsWrapper(env)
100 | env = RetainDronePosWrapper(env, positions)
101 | return env
102 |
103 |
104 | if __name__ == "__main__":
105 | ray.init()
106 |
107 | env_name = "DSSE"
108 |
109 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
110 | ModelCatalog.register_custom_model("CNNModel", CNNModel)
111 |
112 |     natural_value = 512 / (14 * 20 * 1)  # train_batch_size / (num_workers * rollout_fragment_length * envs_per_worker)
113 | config = (
114 | DQNConfig()
115 | .exploration(
116 | exploration_config={
117 | "type": "EpsilonGreedy",
118 | "initial_epsilon": 1.0,
119 | "final_epsilon": 0.15,
120 | "epsilon_timesteps": 400_000,
121 | }
122 | )
123 | .environment(env=env_name)
124 | .rollouts(num_rollout_workers=14, rollout_fragment_length=20)
125 | .framework("torch")
126 | .debugging(log_level="ERROR")
127 | .resources(num_gpus=1)
128 | .experimental(_disable_preprocessor_api=True)
129 | .training(
130 | lr=1e-4,
131 | gamma=0.9999999,
132 | tau=0.01,
133 | train_batch_size=512,
134 | model={
135 | "custom_model": "CNNModel",
136 | "_disable_preprocessor_api": True,
137 | },
138 | target_network_update_freq=500,
139 | double_q=False,
140 | training_intensity=natural_value,
141 | v_min=0,
142 | v_max=2,
143 | )
144 | )
145 |
146 | curr_path = pathlib.Path().resolve()
147 | tune.run(
148 | "DQN",
149 | name="DQN_DSSE",
150 | stop={"timesteps_total": 100_000_000},
151 | checkpoint_freq=200,
152 | storage_path=f"{curr_path}/ray_res/" + env_name,
153 | config=config.to_dict(),
154 | )
155 |
156 | # Finalize Ray to free up resources
157 | ray.shutdown()
158 |
--------------------------------------------------------------------------------
/src/train_ppo_multi.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | from DSSE import DroneSwarmSearch
4 | from DSSE.environment.wrappers import AllPositionsWrapper
5 | import ray
6 | from ray import tune
7 | from ray.rllib.algorithms.ppo import PPOConfig
8 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
9 | from ray.rllib.models import ModelCatalog
10 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
11 | from ray.tune.registry import register_env
12 | import torch
13 | from torch import nn
14 |
15 |
16 | class MLPModel(TorchModelV2, nn.Module):
17 | def __init__(
18 | self,
19 | obs_space,
20 | act_space,
21 | num_outputs,
22 | model_config,
23 | name,
24 | **kw,
25 | ):
26 | print("OBSSPACE: ", obs_space)
27 | TorchModelV2.__init__(
28 | self, obs_space, act_space, num_outputs, model_config, name, **kw
29 | )
30 | nn.Module.__init__(self)
31 |
32 | # F: (((W - K + 2P)/S) + 1)
33 | grid_size = obs_space[1].shape[0]
34 | first_conv_output_size = (grid_size - 8) + 1
35 | second_conv_output_size = (first_conv_output_size - 4) + 1
36 | third_conv_output_size = (second_conv_output_size - 3) + 1
37 | self.fc1_input_size = third_conv_output_size * third_conv_output_size * 64
38 |
39 | self.conv1 = nn.Sequential(
40 | nn.Conv2d(
41 | in_channels=1,
42 | out_channels=16,
43 | kernel_size=(8, 8),
44 | stride=(1, 1),
45 | ),
46 | nn.ReLU(),
47 | # nn.MaxPool2d(kernel_size=2),
48 | nn.Conv2d(
49 | in_channels=16,
50 | out_channels=32,
51 | kernel_size=(4, 4),
52 | stride=(1, 1)
53 | ),
54 | nn.ReLU(),
55 | nn.Conv2d(
56 | in_channels=32,
57 | out_channels=64,
58 | kernel_size=(3, 3),
59 | stride=(1, 1)
60 | ),
61 | # nn.MaxPool2d(kernel_size=2),
62 | nn.Flatten(),
63 | nn.Linear(self.fc1_input_size, 512),
64 | nn.ReLU(),
65 | nn.LayerNorm(512),
66 | )
67 |
68 | self.fc_scalar = nn.Sequential(
69 | nn.Linear(obs_space[0].shape[0], 128),
70 | nn.ReLU(),
71 | nn.Linear(128, 512),
72 | nn.ReLU(),
73 | nn.LayerNorm(512),
74 | )
75 |
76 | self.unifier = nn.Sequential(
77 | nn.Linear(512 * 2, 512),
78 | nn.ReLU(),
79 | nn.LayerNorm(512),
80 | )
81 | self.policy_fn = nn.Linear(512, num_outputs)
82 | self.value_fn = nn.Linear(512, 1)
83 |
84 | def forward(self, input_dict, state, seq_lens):
85 | input_ = input_dict["obs"]
86 | # Convert dims
87 | input_cnn = input_[1].unsqueeze(1)
88 | model_out = self.conv1(input_cnn)
89 |
90 | scalar_input = input_[0].float()
91 | scalar_out = self.fc_scalar(scalar_input)
92 |
93 | value_input = torch.cat((model_out, scalar_out), -1)
94 | value_input = self.unifier(value_input)
95 |
96 | self._value_out = self.value_fn(value_input)
97 | return self.policy_fn(value_input), state
98 |
99 | def value_function(self):
100 | return self._value_out.flatten()
101 |
102 |
103 | def env_creator(args):
104 | env = DroneSwarmSearch(
105 | drone_amount=4,
106 | grid_size=20,
107 | dispersion_inc=0.1,
108 | person_initial_position=(10, 10),
109 | )
110 | env = AllPositionsWrapper(env)
111 | return env
112 |
113 |
114 | if __name__ == "__main__":
115 | ray.init()
116 |
117 | env_name = "DSSE"
118 |
119 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
120 | ModelCatalog.register_custom_model("MLPModel", MLPModel)
121 |
122 | config = (
123 | PPOConfig()
124 | .environment(env=env_name)
125 | .rollouts(num_rollout_workers=4, rollout_fragment_length=128)
126 | .training(
127 | train_batch_size=512,
128 | lr=4e-5,
129 | gamma=0.99999,
130 | lambda_=0.9,
131 | use_gae=True,
132 | clip_param=0.4,
133 | grad_clip=None,
134 | entropy_coeff=0.1,
135 | vf_loss_coeff=0.25,
136 | vf_clip_param=420,
137 | sgd_minibatch_size=64,
138 | num_sgd_iter=10,
139 | model={
140 | "custom_model": "MLPModel",
141 | "_disable_preprocessor_api": True,
142 | },
143 | )
144 | .experimental(_disable_preprocessor_api=True)
145 | .debugging(log_level="ERROR")
146 | .framework(framework="torch")
147 | .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "1")))
148 | )
149 |
150 | curr_path = pathlib.Path().resolve()
151 | tune.run(
152 | "PPO",
153 | name="PPO",
154 | stop={"timesteps_total": 5000000 if not os.environ.get("CI") else 50000},
155 | checkpoint_freq=10,
156 | # local_dir="ray_results/" + env_name,
157 | storage_path=f"{curr_path}/ray_res/" + env_name,
158 | config=config.to_dict(),
159 | )
160 |
--------------------------------------------------------------------------------
/src/train_dqn_multi.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | from DSSE import DroneSwarmSearch
4 | from DSSE.environment.wrappers import AllPositionsWrapper, RetainDronePosWrapper
5 | import ray
6 | from ray import air
7 | from ray import tune
8 | from ray.rllib.algorithms.dqn.dqn import DQNConfig
9 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
10 | from ray.rllib.models import ModelCatalog
11 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
12 | from ray.tune.registry import register_env
13 | import torch
14 | from torch import nn
15 |
16 |
17 | class MLPModel(TorchModelV2, nn.Module):
18 | def __init__(
19 | self,
20 | obs_space,
21 | act_space,
22 | num_outputs,
23 | model_config,
24 | name,
25 | **kw,
26 | ):
27 | print("OBSSPACE: ", obs_space)
28 | TorchModelV2.__init__(
29 | self, obs_space, act_space, num_outputs, model_config, name, **kw
30 | )
31 | nn.Module.__init__(self)
32 |
33 | # F: (((W - K + 2P)/S) + 1)
34 | grid_size = obs_space[1].shape[0]
35 | first_conv_output_size = (grid_size - 8) + 1
36 | second_conv_output_size = (first_conv_output_size - 4) + 1
37 | third_conv_output_size = (second_conv_output_size - 3) + 1
38 | self.fc1_input_size = third_conv_output_size * third_conv_output_size * 64
39 |
40 | self.conv1 = nn.Sequential(
41 | nn.Conv2d(
42 | in_channels=1,
43 | out_channels=16,
44 | kernel_size=(8, 8),
45 | stride=(1, 1),
46 | ),
47 | nn.ReLU(),
48 | # nn.MaxPool2d(kernel_size=2),
49 | nn.Conv2d(
50 | in_channels=16,
51 | out_channels=32,
52 | kernel_size=(4, 4),
53 | stride=(1, 1)
54 | ),
55 | nn.ReLU(),
56 | nn.Conv2d(
57 | in_channels=32,
58 | out_channels=64,
59 | kernel_size=(3, 3),
60 | stride=(1, 1)
61 | ),
62 | # nn.MaxPool2d(kernel_size=2),
63 | nn.Flatten(),
64 | nn.Linear(self.fc1_input_size, 512),
65 | nn.ReLU(),
66 | nn.LayerNorm(512),
67 | )
68 |
69 | self.fc_scalar = nn.Sequential(
70 | nn.Linear(obs_space[0].shape[0], 128),
71 | nn.ReLU(),
72 | nn.Linear(128, 512),
73 | nn.ReLU(),
74 | nn.LayerNorm(512),
75 | )
76 |
77 | self.unifier = nn.Sequential(
78 | nn.Linear(512 * 2, 512),
79 | nn.ReLU(),
80 | nn.LayerNorm(512),
81 | )
82 | self.policy_fn = nn.Linear(512, num_outputs)
83 |
84 | def forward(self, input_dict, state, seq_lens):
85 | input_ = input_dict["obs"]
86 | # Convert dims
87 | input_cnn = input_[1].unsqueeze(1)
88 | model_out = self.conv1(input_cnn)
89 |
90 | scalar_input = input_[0].float()
91 | scalar_out = self.fc_scalar(scalar_input)
92 |
93 | value_input = torch.cat((model_out, scalar_out), -1)
94 | value_input = self.unifier(value_input)
95 |
96 | return self.policy_fn(value_input), state
97 |
98 |
99 | def env_creator(args):
100 | env = DroneSwarmSearch(
101 | drone_amount=4,
102 | grid_size=20,
103 | dispersion_inc=0.1,
104 | person_initial_position=(10, 10),
105 | )
106 | env = AllPositionsWrapper(env)
107 | env = RetainDronePosWrapper(env, [(0, 0), (0, 19), (19, 0), (19, 19)])
108 | return env
109 |
110 |
111 | if __name__ == "__main__":
112 | ray.init()
113 |
114 | env_name = "DSSE"
115 |
116 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
117 | ModelCatalog.register_custom_model("MLPModel", MLPModel)
118 |
119 | config = DQNConfig()
120 | config = config.environment(env=env_name)
121 | config = config.rollouts(num_rollout_workers=6, rollout_fragment_length="auto", num_envs_per_worker=2)
122 | config = config.training(
123 | train_batch_size=512,
124 | grad_clip=None,
125 | target_network_update_freq=1,
126 | tau=0.005,
127 | gamma=0.99999,
128 | n_step=1,
129 | double_q=True,
130 | dueling=False,
131 | model={"custom_model": "MLPModel", "_disable_preprocessor_api": True},
132 | v_min=-800,
133 | v_max=800,
134 | )
135 | config = config.exploration(
136 | exploration_config={
137 | "type": "EpsilonGreedy",
138 | "initial_epsilon": 1.0,
139 | "final_epsilon": 0.05,
140 | "epsilon_timesteps": 350_000,
141 | }
142 | )
143 | config = config.debugging(log_level="ERROR")
144 | config = config.framework(framework="torch")
145 | config = config.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "1")))
146 | config = config.experimental(_disable_preprocessor_api=True)
147 |
148 | curr_path = pathlib.Path().resolve()
149 | run_config = air.RunConfig(
150 | stop={"timesteps_total": 10_000_000 if not os.environ.get("CI") else 50000},
151 | storage_path=f"{curr_path}/ray_res/" + env_name,
152 | checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10),
153 | )
154 | tune.Tuner(
155 | "DQN",
156 | run_config=run_config,
157 | param_space=config.to_dict()
158 | ).fit()
159 |
--------------------------------------------------------------------------------
/src/train_descentralized_ppo_cnn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | from DSSE import DroneSwarmSearch
4 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper
5 | import ray
6 | from ray import tune
7 | from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
8 | from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
9 | from ray.rllib.algorithms.ppo import PPOConfig
10 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
11 | from ray.rllib.models import ModelCatalog
12 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
13 | from ray.tune.registry import register_env
14 | import torch
15 | from torch import nn
16 |
17 |
18 | class CNNModel(TorchModelV2, nn.Module):
19 | def __init__(
20 | self,
21 | obs_space,
22 | act_space,
23 | num_outputs,
24 | model_config,
25 | name,
26 | **kw,
27 | ):
28 | print("OBSSPACE: ", obs_space)
29 | TorchModelV2.__init__(
30 | self, obs_space, act_space, num_outputs, model_config, name, **kw
31 | )
32 | nn.Module.__init__(self)
33 |
34 | flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[1] - 7 - 3)
35 | self.cnn = nn.Sequential(
36 | nn.Conv2d(
37 | in_channels=1,
38 | out_channels=16,
39 | kernel_size=(8, 8),
40 | stride=(1, 1),
41 | ),
42 | nn.Tanh(),
43 | nn.Conv2d(
44 | in_channels=16,
45 | out_channels=32,
46 | kernel_size=(4, 4),
47 | stride=(1, 1),
48 | ),
49 | nn.Tanh(),
50 | nn.Flatten(),
51 | nn.Linear(flatten_size, 256),
52 | nn.Tanh(),
53 | )
54 |
55 | self.linear = nn.Sequential(
56 | nn.Linear(obs_space[0].shape[0], 512),
57 | nn.Tanh(),
58 | nn.Linear(512, 256),
59 | nn.Tanh(),
60 | )
61 |
62 | self.join = nn.Sequential(
63 | nn.Linear(256 * 2, 256),
64 | nn.Tanh(),
65 | )
66 |
67 | self.policy_fn = nn.Linear(256, num_outputs)
68 | self.value_fn = nn.Linear(256, 1)
69 |
70 | def forward(self, input_dict, state, seq_lens):
71 | input_positions = input_dict["obs"][0].float()
72 | input_matrix = input_dict["obs"][1].float()
73 |
74 | input_matrix = input_matrix.unsqueeze(1)
75 | cnn_out = self.cnn(input_matrix)
76 | linear_out = self.linear(input_positions)
77 |
78 | value_input = torch.cat((cnn_out, linear_out), dim=1)
79 | value_input = self.join(value_input)
80 |
81 | self._value_out = self.value_fn(value_input)
82 | return self.policy_fn(value_input), state
83 |
84 | def value_function(self):
85 | return self._value_out.flatten()
86 |
87 |
88 |
89 | def env_creator(args):
90 | env = DroneSwarmSearch(
91 | drone_amount=4,
92 | grid_size=40,
93 | dispersion_inc=0.1,
94 | person_initial_position=(20, 20),
95 | )
96 | positions = [
97 | (20, 0),
98 | (20, 39),
99 | (0, 20),
100 | (39, 20),
101 | ]
102 | env = AllPositionsWrapper(env)
103 | env = RetainDronePosWrapper(env, positions)
104 | return env
105 |
106 |
107 | if __name__ == "__main__":
108 | ray.init()
109 | env_name = "DSSE"
110 |
111 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
112 | ModelCatalog.register_custom_model("CNNModel", CNNModel)
113 |
114 |     # Policies are named exactly like the agents (exact 1:1 mapping).
115 | num_agents = 4
116 | policies = {f"drone{i}" for i in range(num_agents)}
117 | # policies = {f"drone{i}": (None, obs_space, act_space, {}) for i in range(num_agents)}
118 |
119 | config = (
120 | PPOConfig()
121 | .environment(env=env_name)
122 | .rollouts(num_rollout_workers=4, rollout_fragment_length="auto")
123 | .multi_agent(
124 | policies=policies,
125 | # Exact 1:1 mapping from AgentID to ModuleID.
126 | policy_mapping_fn=(lambda aid, *args, **kwargs: aid),
127 | )
128 | .training(
129 | train_batch_size=8192,
130 | lr=1e-5,
131 | gamma=0.9999999,
132 | lambda_=0.9,
133 | use_gae=True,
134 | entropy_coeff=0.01,
135 | sgd_minibatch_size=300,
136 | num_sgd_iter=10,
137 | model={
138 | "custom_model": "CNNModel",
139 | "_disable_preprocessor_api": True,
140 | },
141 | )
142 | .rl_module(
143 | model_config_dict={"vf_share_layers": True},
144 | rl_module_spec=MultiAgentRLModuleSpec(
145 | module_specs={p: SingleAgentRLModuleSpec() for p in policies},
146 | ),
147 | )
148 | .experimental(_disable_preprocessor_api=True)
149 | .debugging(log_level="ERROR")
150 | .framework(framework="torch")
151 | .resources(num_gpus=1)
152 | )
153 |
154 | curr_path = pathlib.Path().resolve()
155 | tune.run(
156 | "PPO",
157 | name="PPO",
158 | stop={"timesteps_total": 5_000_000},
159 | checkpoint_freq=30,
160 | keep_checkpoints_num=200,
161 | storage_path=f"{curr_path}/ray_res/" + env_name,
162 |         config=config.to_dict(),
163 | )
164 |
--------------------------------------------------------------------------------
/src/greedy_heuristic.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from DSSE import Actions
3 | from DSSE import DroneSwarmSearch
4 | from DSSE.environment.wrappers import RetainDronePosWrapper
5 | from recorder import PygameRecord
6 |
7 |
8 | class GreedyAgent:
9 | def __call__(self, obs, agents):
10 | """
11 | Greedy approach: Rush and search for the greatest prob.
12 | """
13 | drone_actions = {}
14 | prob_matrix = obs["drone0"][1]
15 | n_drones = len(agents)
16 |
17 | drones_positions = {drone: obs[drone][0] for drone in agents}
18 | # Get n_drones greatest probabilities.
19 | greatest_probs = np.argsort(prob_matrix, axis=None)[-n_drones:]
20 |
21 | for index, drone in enumerate(agents):
22 | greatest_prob = np.unravel_index(greatest_probs[index], prob_matrix.shape)
23 |
24 | drone_obs = obs[drone]
25 | drone_action = self.choose_drone_action(drone_obs[0], greatest_prob)
26 |
27 | new_position = self.get_new_position(drone_obs[0], drone_action)
28 |
29 |             # Avoid collision by searching in place for one timestep
30 | if self.drones_colide(drones_positions, new_position):
31 | drone_actions[drone] = Actions.SEARCH.value
32 | else:
33 | drone_actions[drone] = drone_action
34 | drones_positions[drone] = new_position
35 | return drone_actions
36 |
37 | def get_new_position(self, position: tuple, action: int) -> tuple:
38 | match action:
39 | case Actions.LEFT.value:
40 | new_position = (position[0] - 1, position[1])
41 | case Actions.RIGHT.value:
42 | new_position = (position[0] + 1, position[1])
43 | case Actions.UP.value:
44 | new_position = (position[0], position[1] - 1)
45 | case Actions.DOWN.value:
46 | new_position = (position[0], position[1] + 1)
47 | case Actions.UP_LEFT.value:
48 | new_position = (position[0] - 1, position[1] - 1)
49 | case Actions.UP_RIGHT.value:
50 | new_position = (position[0] + 1, position[1] - 1)
51 | case Actions.DOWN_LEFT.value:
52 | new_position = (position[0] - 1, position[1] + 1)
53 | case Actions.DOWN_RIGHT.value:
54 | new_position = (position[0] + 1, position[1] + 1)
55 | case _:
56 | new_position = position
57 | return new_position
58 |
59 | def choose_drone_action(self, drone_position: tuple, greatest_prob_position) -> int:
60 | greatest_prob_y, greatest_prob_x = greatest_prob_position
61 | drone_x, drone_y = drone_position
62 | is_x_greater = greatest_prob_x > drone_x
63 | is_y_greater = greatest_prob_y > drone_y
64 | is_x_lesser = greatest_prob_x < drone_x
65 | is_y_lesser = greatest_prob_y < drone_y
66 | is_x_equal = greatest_prob_x == drone_x
67 | is_y_equal = greatest_prob_y == drone_y
68 |
69 | if is_x_equal and is_y_equal:
70 | return Actions.SEARCH.value
71 |
72 | if is_x_equal:
73 | return Actions.DOWN.value if is_y_greater else Actions.UP.value
74 | elif is_y_equal:
75 | return Actions.LEFT.value if is_x_lesser else Actions.RIGHT.value
76 |
77 |         # Diagonal movement
78 | if is_x_greater and is_y_greater:
79 | return Actions.DOWN_RIGHT.value
80 | elif is_x_greater and is_y_lesser:
81 | return Actions.UP_RIGHT.value
82 | elif is_x_lesser and is_y_greater:
83 | return Actions.DOWN_LEFT.value
84 | elif is_x_lesser and is_y_lesser:
85 | return Actions.UP_LEFT.value
86 |
87 | def drones_colide(self, drones_positions: dict, new_drone_position: tuple) -> bool:
88 | return new_drone_position in drones_positions.values()
89 |
90 | def __repr__(self) -> str:
91 | return "greedy"
92 |
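# Minimal sanity check of the movement primitives (illustrative only; the
# drone position is (x, y) and the target comes from np.unravel_index as
# (row, col), exactly as in __call__ above):
#   agent = GreedyAgent()
#   agent.choose_drone_action((0, 0), (5, 5))                  # Actions.DOWN_RIGHT.value
#   agent.get_new_position((0, 0), Actions.DOWN_RIGHT.value)   # (1, 1)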
93 |
94 | see = input("Watch the simulation? (y/n): ")
95 | see = see.lower() == "y"
96 | def env_creator(_):
97 | env = DroneSwarmSearch(
98 | render_mode="human" if see else "ansi",
99 | render_grid=True,
100 | drone_amount=4,
101 | grid_size=40,
102 | dispersion_inc=0.1,
103 | person_initial_position=(20, 20),
104 | )
105 | positions = [
106 | (20, 0),
107 | (20, 39),
108 | (0, 20),
109 | (39, 20),
110 | ]
111 | env = RetainDronePosWrapper(env, positions)
112 | return env
113 |
114 | greedy_agent = GreedyAgent()
115 |
116 | env = env_creator(None)
117 |
118 | if see:
119 | with PygameRecord("greedy.gif", 5) as recorder:
120 | obs, info = env.reset()
121 | while env.agents:
122 | actions = greedy_agent(obs, env.agents)
123 | obs, rewards, terminations, truncations, infos = env.step(actions)
124 | recorder.add_frame()
125 |     exit(0)
126 |
127 | rewards = []
128 | actions = []
129 | founds = 0
130 | N_EVALS = 5_000
131 | for _ in range(N_EVALS):
132 | i = 0
133 | obs, info = env.reset()
134 | reward_sum = 0
135 | while env.agents:
136 | action = greedy_agent(obs, env.agents)
137 | obs, rw, term, trunc, info = env.step(action)
138 | reward_sum += sum(rw.values())
139 | i += 1
140 | rewards.append(reward_sum)
141 | actions.append(i)
142 |
143 | for _, v in info.items():
144 | if v["Found"]:
145 | founds += 1
146 | break
147 | print("Average reward: ", sum(rewards) / N_EVALS)
148 | print("Mean actions: ", sum(actions) / N_EVALS)
149 | print("Found %: ", founds / N_EVALS)
150 |
--------------------------------------------------------------------------------
/src/test_ppo_cnn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | from recorder import PygameRecord
4 | from DSSE import DroneSwarmSearch
5 | from DSSE.environment.wrappers import TopNProbsWrapper, RetainDronePosWrapper, AllFlattenWrapper, AllPositionsWrapper
6 | import ray
7 | from ray.rllib.algorithms.ppo import PPOConfig
8 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
9 | from ray.rllib.models import ModelCatalog
10 | from ray.tune.registry import register_env
11 | from ray.rllib.algorithms.ppo import PPO
12 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
13 | import torch.nn as nn
14 | import argparse
15 | import torch
16 |
17 |
18 |
19 | argparser = argparse.ArgumentParser()
20 | argparser.add_argument("--checkpoint", type=str, required=True)
21 | argparser.add_argument("--see", action="store_true", default=False)
22 | args = argparser.parse_args()
23 |
24 |
25 | class CNNModel(TorchModelV2, nn.Module):
26 | def __init__(
27 | self,
28 | obs_space,
29 | act_space,
30 | num_outputs,
31 | model_config,
32 | name,
33 | **kw,
34 | ):
35 | print("OBSSPACE: ", obs_space)
36 | TorchModelV2.__init__(
37 | self, obs_space, act_space, num_outputs, model_config, name, **kw
38 | )
39 | nn.Module.__init__(self)
40 |
41 | flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[1] - 7 - 3)
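        # Sketch of the arithmetic, assuming the probability matrix matches the
        # 40x40 grid built below: the 8x8 conv yields 33x33, the 4x4 conv yields
        # 30x30, so the flattened CNN features have 32 * 30 * 30 = 28800 values.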
42 | self.cnn = nn.Sequential(
43 | nn.Conv2d(
44 | in_channels=1,
45 | out_channels=16,
46 | kernel_size=(8, 8),
47 | stride=(1, 1),
48 | ),
49 | nn.Tanh(),
50 | nn.Conv2d(
51 | in_channels=16,
52 | out_channels=32,
53 | kernel_size=(4, 4),
54 | stride=(1, 1),
55 | ),
56 | nn.Tanh(),
57 | nn.Flatten(),
58 | nn.Linear(flatten_size, 256),
59 | nn.Tanh(),
60 | )
61 |
62 | self.linear = nn.Sequential(
63 | nn.Linear(obs_space[0].shape[0], 512),
64 | nn.Tanh(),
65 | nn.Linear(512, 256),
66 | nn.Tanh(),
67 | )
68 |
69 | self.join = nn.Sequential(
70 | nn.Linear(256 * 2, 256),
71 | nn.Tanh(),
72 | )
73 |
74 | self.policy_fn = nn.Linear(256, num_outputs)
75 | self.value_fn = nn.Linear(256, 1)
76 |
77 | def forward(self, input_dict, state, seq_lens):
78 | input_positions = input_dict["obs"][0].float()
79 | input_matrix = input_dict["obs"][1].float()
80 |
81 | input_matrix = input_matrix.unsqueeze(1)
82 | cnn_out = self.cnn(input_matrix)
83 | linear_out = self.linear(input_positions)
84 |
85 | value_input = torch.cat((cnn_out, linear_out), dim=1)
86 | value_input = self.join(value_input)
87 |
88 | self._value_out = self.value_fn(value_input)
89 | return self.policy_fn(value_input), state
90 |
91 | def value_function(self):
92 | return self._value_out.flatten()
93 |
94 |
95 |
96 | # ModelCatalog.register_custom_model("MLPModel", MLPModel)
97 | ModelCatalog.register_custom_model("CNNModel", CNNModel)
98 |
99 | # DEFINE HERE THE EXACT ENVIRONMENT YOU USED TO TRAIN THE AGENT
100 | def env_creator(_):
101 | render_mode = "human" if args.see else "ansi"
102 | env = DroneSwarmSearch(
103 | drone_amount=4,
104 | grid_size=40,
105 | render_mode=render_mode,
106 | render_grid=True,
107 | dispersion_inc=0.1,
108 | person_initial_position=(20, 20),
109 | )
110 | positions = [
111 | (20, 0),
112 | (20, 39),
113 | (0, 20),
114 | (39, 20),
115 | ]
116 | env = RetainDronePosWrapper(env, positions)
117 | env = AllPositionsWrapper(env)
118 | return env
119 |
120 | env = env_creator(None)
121 | register_env("DSSE", lambda config: ParallelPettingZooEnv(env_creator(config)))
122 | ray.init()
123 |
124 |
125 | checkpoint_path = args.checkpoint
126 | PPOagent = PPO.from_checkpoint(checkpoint_path)
127 |
128 | reward_sum = 0
129 | i = 0
130 |
131 | if args.see:
132 | obs, info = env.reset()
133 | with PygameRecord("test_trained.gif", 5) as rec:
134 | while env.agents:
135 | actions = {}
136 | for k, v in obs.items():
137 | actions[k] = PPOagent.compute_single_action(v, explore=False)
138 | # print(v)
139 | # action = PPOagent.compute_actions(obs)
140 | obs, rw, term, trunc, info = env.step(actions)
141 | reward_sum += sum(rw.values())
142 | i += 1
143 | rec.add_frame()
144 | else:
145 | rewards = []
146 | founds = 0
147 | N_EVALS = 1000
148 | for _ in range(N_EVALS):
149 | obs, info = env.reset()
150 | reward_sum = 0
151 | while env.agents:
152 | actions = {}
153 | for k, v in obs.items():
154 | actions[k] = PPOagent.compute_single_action(v, explore=False)
155 | # print(v)
156 | # action = PPOagent.compute_actions(obs, explore=False)
157 | obs, rw, term, trunc, info = env.step(actions)
158 | reward_sum += sum(rw.values())
159 | i += 1
160 | rewards.append(reward_sum)
161 | for _, v in info.items():
162 | if v["Found"]:
163 | founds += 1
164 | break
165 | print("Average reward: ", sum(rewards) / N_EVALS)
166 | print("Found %: ", founds / N_EVALS)
167 |
168 | print("Total reward: ", reward_sum)
169 | print("Total steps: ", i)
170 | print("Found: ", info)
171 | env.close()
172 |
--------------------------------------------------------------------------------
/src/train_ppo_cnn_cov.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from DSSE import CoverageDroneSwarmSearch
3 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper
4 | import ray
5 | from ray import tune
6 | from ray.rllib.algorithms.ppo import PPOConfig
7 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
8 | from ray.rllib.models import ModelCatalog
9 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
10 | from ray.tune.registry import register_env
11 | from torch import nn
12 | import torch
13 | import numpy as np
14 |
15 |
16 | class CNNModel(TorchModelV2, nn.Module):
17 | def __init__(
18 | self,
19 | obs_space,
20 | act_space,
21 | num_outputs,
22 | model_config,
23 | name,
24 | **kw,
25 | ):
26 | print("OBSSPACE: ", obs_space)
27 | TorchModelV2.__init__(
28 | self, obs_space, act_space, num_outputs, model_config, name, **kw
29 | )
30 | nn.Module.__init__(self)
31 |
32 | first_cnn_out = (obs_space[1].shape[0] - 3) + 1
33 | second_cnn_out = (first_cnn_out - 2) + 1
34 | flatten_size = 32 * second_cnn_out * second_cnn_out
35 | print("Cnn Dense layer input size: ", flatten_size)
36 | self.cnn = nn.Sequential(
37 | nn.Conv2d(
38 | in_channels=1,
39 | out_channels=16,
40 | kernel_size=(3, 3),
41 | ),
42 | nn.Tanh(),
43 | nn.Conv2d(
44 | in_channels=16,
45 | out_channels=32,
46 | kernel_size=(2, 2),
47 | ),
48 | nn.Tanh(),
49 | nn.Flatten(),
50 | nn.Linear(flatten_size, 256),
51 | nn.Tanh(),
52 | )
53 |
54 | self.linear = nn.Sequential(
55 | nn.Linear(obs_space[0].shape[0], 512),
56 | nn.Tanh(),
57 | nn.Linear(512, 256),
58 | nn.Tanh(),
59 | )
60 |
61 | self.join = nn.Sequential(
62 | nn.Linear(256 * 2, 256),
63 | nn.Tanh(),
64 | )
65 |
66 | self.policy_fn = nn.Linear(256, num_outputs)
67 | self.value_fn = nn.Linear(256, 1)
68 |
69 | def forward(self, input_dict, state, seq_lens):
70 | input_positions = input_dict["obs"][0].float()
71 | input_matrix = input_dict["obs"][1].float()
72 |
73 | input_matrix = input_matrix.unsqueeze(1)
74 | cnn_out = self.cnn(input_matrix)
75 | linear_out = self.linear(input_positions)
76 |
77 | value_input = torch.cat((cnn_out, linear_out), dim=1)
78 | value_input = self.join(value_input)
79 |
80 | self._value_out = self.value_fn(value_input)
81 | return self.policy_fn(value_input), state
82 |
83 | def value_function(self):
84 | return self._value_out.flatten()
85 |
86 | def env_creator(args):
87 | print("-------------------------- ENV CREATOR --------------------------")
88 | N_AGENTS = 2
89 | # 6 hours of simulation, 600 radius
90 | env = CoverageDroneSwarmSearch(
91 | timestep_limit=200, drone_amount=N_AGENTS, prob_matrix_path="min_matrix.npy"
92 | )
93 | env = AllPositionsWrapper(env)
94 | grid_size = env.grid_size
95 | # positions = position_on_diagonal(grid_size, N_AGENTS)
96 | # positions = position_on_circle(grid_size, N_AGENTS, 2)
97 | positions = [
98 | (grid_size - 1, grid_size // 2),
99 | (0, grid_size // 2),
100 | ]
101 | env = RetainDronePosWrapper(env, positions)
102 | return env
103 |
104 | def position_on_diagonal(grid_size, drone_amount):
105 | positions = []
106 | center = grid_size // 2
107 | for i in range(-drone_amount // 2, drone_amount // 2):
108 | positions.append((center + i, center + i))
109 | return positions
110 |
111 | def position_on_circle(grid_size, drone_amount, radius):
112 | positions = []
113 | center = grid_size // 2
114 | angle_increment = 2 * np.pi / drone_amount
115 |
116 | for i in range(drone_amount):
117 | angle = i * angle_increment
118 | x = center + int(radius * np.cos(angle))
119 | y = center + int(radius * np.sin(angle))
120 | positions.append((x, y))
121 |
122 | return positions
123 |
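# For reference, position_on_circle(31, 4, 2) (grid_size 31 picked purely for
# illustration) returns (17, 15), (15, 17), (13, 15) and (15, 13): four drones
# spaced 90 degrees apart at a 2-cell radius around the center cell (15, 15).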
124 |
125 | if __name__ == "__main__":
126 | ray.init()
127 |
128 | env_name = "DSSE_Coverage"
129 |
130 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
131 | ModelCatalog.register_custom_model("CNNModel", CNNModel)
132 |
133 | config = (
134 | PPOConfig()
135 | .environment(env=env_name)
136 | .rollouts(num_rollout_workers=6, rollout_fragment_length="auto")
137 | .training(
138 | train_batch_size=8192 * 5,
139 | lr=6e-6,
140 | gamma=0.9999999,
141 | lambda_=0.9,
142 | use_gae=True,
143 | entropy_coeff=0.01,
144 | vf_clip_param=100000,
145 | sgd_minibatch_size=300,
146 | num_sgd_iter=10,
147 | model={
148 | "custom_model": "CNNModel",
149 | "_disable_preprocessor_api": True,
150 | },
151 | )
152 | .experimental(_disable_preprocessor_api=True)
153 | .debugging(log_level="ERROR")
154 | .framework(framework="torch")
155 | .resources(num_gpus=1)
156 | )
157 |
158 | curr_path = pathlib.Path().resolve()
159 | tune.run(
160 | "PPO",
161 | name="PPO_" + input("Exp name: "),
162 | # resume=True,
163 | stop={"timesteps_total": 20_000_000},
164 | checkpoint_freq=20,
165 | storage_path=f"{curr_path}/ray_res/" + env_name,
166 | config=config.to_dict(),
167 | )
168 |
--------------------------------------------------------------------------------
/src/test_trained_search.py:
--------------------------------------------------------------------------------
1 | from recorder import PygameRecord
2 | from DSSE import DroneSwarmSearch
3 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper
4 | from DSSE.environment.wrappers.communication_wrapper import CommunicationWrapper
5 | import ray
6 | from ray.rllib.algorithms.ppo import PPOConfig
7 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
8 | from ray.rllib.models import ModelCatalog
9 | from ray.tune.registry import register_env
10 | from ray.rllib.algorithms.ppo import PPO
11 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
12 | import torch.nn as nn
13 | import argparse
14 | import torch
15 | import numpy as np
16 |
17 |
18 |
19 | argparser = argparse.ArgumentParser()
20 | argparser.add_argument("--checkpoint", type=str, required=True)
21 | argparser.add_argument("--see", action="store_true", default=False)
22 | args = argparser.parse_args()
23 |
24 |
25 | class CNNModel(TorchModelV2, nn.Module):
26 | def __init__(
27 | self,
28 | obs_space,
29 | act_space,
30 | num_outputs,
31 | model_config,
32 | name,
33 | **kw,
34 | ):
35 | print("OBSSPACE: ", obs_space)
36 | TorchModelV2.__init__(
37 | self, obs_space, act_space, num_outputs, model_config, name, **kw
38 | )
39 | nn.Module.__init__(self)
40 |
41 | flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[1] - 7 - 3)
42 | self.cnn = nn.Sequential(
43 | nn.Conv2d(
44 | in_channels=1,
45 | out_channels=16,
46 | kernel_size=(8, 8),
47 | stride=(1, 1),
48 | ),
49 | nn.Tanh(),
50 | nn.Conv2d(
51 | in_channels=16,
52 | out_channels=32,
53 | kernel_size=(4, 4),
54 | stride=(1, 1),
55 | ),
56 | nn.Tanh(),
57 | nn.Flatten(),
58 | nn.Linear(flatten_size, 256),
59 | nn.Tanh(),
60 | )
61 |
62 | self.linear = nn.Sequential(
63 | nn.Linear(obs_space[0].shape[0], 512),
64 | nn.Tanh(),
65 | nn.Linear(512, 256),
66 | nn.Tanh(),
67 | )
68 |
69 | self.join = nn.Sequential(
70 | nn.Linear(256 * 2, 256),
71 | nn.Tanh(),
72 | )
73 |
74 | self.policy_fn = nn.Linear(256, num_outputs)
75 | self.value_fn = nn.Linear(256, 1)
76 |
77 | def forward(self, input_dict, state, seq_lens):
78 | input_positions = input_dict["obs"][0].float()
79 | input_matrix = input_dict["obs"][1].float()
80 |
81 | input_matrix = input_matrix.unsqueeze(1)
82 | cnn_out = self.cnn(input_matrix)
83 | linear_out = self.linear(input_positions)
84 |
85 | value_input = torch.cat((cnn_out, linear_out), dim=1)
86 | value_input = self.join(value_input)
87 |
88 | self._value_out = self.value_fn(value_input)
89 | return self.policy_fn(value_input), state
90 |
91 | def value_function(self):
92 | return self._value_out.flatten()
93 |
94 |
95 |
96 | # ModelCatalog.register_custom_model("MLPModel", MLPModel)
97 | ModelCatalog.register_custom_model("CNNModel", CNNModel)
98 |
99 | # DEFINE HERE THE EXACT ENVIRONMENT YOU USED TO TRAIN THE AGENT
100 | def env_creator(args):
101 | env = DroneSwarmSearch(
102 | drone_amount=4,
103 | grid_size=40,
104 | dispersion_inc=0.1,
105 | person_initial_position=(20, 20),
106 | )
107 | positions = [
108 | (20, 0),
109 | (20, 39),
110 | (0, 20),
111 | (39, 20),
112 | ]
113 | env = AllPositionsWrapper(env)
114 | env = CommunicationWrapper(env, n_steps=12)
115 | env = RetainDronePosWrapper(env, positions)
116 | return env
117 |
118 | env = env_creator(None)
119 | register_env("DSSE", lambda config: ParallelPettingZooEnv(env_creator(config)))
120 | ray.init()
121 |
122 |
123 | checkpoint_path = args.checkpoint
124 | PPOagent = PPO.from_checkpoint(checkpoint_path)
125 |
126 |
127 | if args.see:
128 | i = 0
129 | reward_sum = 0
130 | obs, info = env.reset()
131 | with PygameRecord("test_trained.gif", 5) as rec:
132 | while env.agents:
133 | actions = {}
134 | for k, v in obs.items():
135 | actions[k] = PPOagent.compute_single_action(v, explore=False)
136 | # print(v)
137 | # action = PPOagent.compute_actions(obs)
138 | obs, rw, term, trunc, info = env.step(actions)
139 | reward_sum += sum(rw.values())
140 | i += 1
141 | rec.add_frame()
142 | print(info)
143 | print(reward_sum)
144 | else:
145 | rewards = []
146 | actions_stat = []
147 | founds = 0
148 | N_EVALS = 5_000
149 | for epoch in range(N_EVALS):
150 | print(epoch)
151 | obs, info = env.reset()
152 | i = 0
153 | reward_sum = 0
154 | while env.agents:
155 | actions = {}
156 | for k, v in obs.items():
157 | actions[k] = PPOagent.compute_single_action(v, explore=False)
158 | obs, rw, term, trunc, info = env.step(actions)
159 | reward_sum += sum(rw.values())
160 | i += 1
161 | rewards.append(reward_sum)
162 | actions_stat.append(i)
163 |
164 | for _, v in info.items():
165 | if v["Found"]:
166 | founds += 1
167 | break
168 | print("Average reward: ", sum(rewards) / N_EVALS)
169 | print("Average Actions: ", sum(actions_stat) / N_EVALS)
170 | print("Median of actions: ", np.median(actions_stat))
171 | print("Found %: ", founds / N_EVALS)
172 |
173 | env.close()
174 |
--------------------------------------------------------------------------------
/src/test_trained_cov.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | from recorder import PygameRecord
4 | from DSSE import CoverageDroneSwarmSearch
5 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper
6 | import ray
7 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
8 | from ray.rllib.models import ModelCatalog
9 | from ray.tune.registry import register_env
10 | from ray.rllib.algorithms.ppo import PPO
11 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
12 | import torch.nn as nn
13 | import argparse
14 | import torch
15 |
16 |
17 |
18 | argparser = argparse.ArgumentParser()
19 | argparser.add_argument("--checkpoint", type=str, required=True)
20 | argparser.add_argument("--see", action="store_true", default=False)
21 | args = argparser.parse_args()
22 |
23 |
24 | class CNNModel(TorchModelV2, nn.Module):
25 | def __init__(
26 | self,
27 | obs_space,
28 | act_space,
29 | num_outputs,
30 | model_config,
31 | name,
32 | **kw,
33 | ):
34 | print("OBSSPACE: ", obs_space)
35 | TorchModelV2.__init__(
36 | self, obs_space, act_space, num_outputs, model_config, name, **kw
37 | )
38 | nn.Module.__init__(self)
39 |
40 | first_cnn_out = (obs_space[1].shape[0] - 8) + 1
41 | second_cnn_out = (first_cnn_out - 4) + 1
42 | flatten_size = 32 * second_cnn_out * second_cnn_out
43 | print("Cnn Dense layer input size: ", flatten_size)
44 | self.cnn = nn.Sequential(
45 | nn.Conv2d(
46 | in_channels=1,
47 | out_channels=16,
48 | kernel_size=(8, 8),
49 | ),
50 | nn.Tanh(),
51 | nn.Conv2d(
52 | in_channels=16,
53 | out_channels=32,
54 | kernel_size=(4, 4),
55 | ),
56 | nn.Tanh(),
57 | nn.Flatten(),
58 | nn.Linear(flatten_size, 256),
59 | nn.Tanh(),
60 | )
61 |
62 | self.linear = nn.Sequential(
63 | nn.Linear(obs_space[0].shape[0], 512),
64 | nn.Tanh(),
65 | nn.Linear(512, 256),
66 | nn.Tanh(),
67 | )
68 |
69 | self.join = nn.Sequential(
70 | nn.Linear(256 * 2, 256),
71 | nn.Tanh(),
72 | )
73 |
74 | self.policy_fn = nn.Linear(256, num_outputs)
75 | self.value_fn = nn.Linear(256, 1)
76 |
77 | def forward(self, input_dict, state, seq_lens):
78 | input_positions = input_dict["obs"][0].float()
79 | input_matrix = input_dict["obs"][1].float()
80 |
81 | input_matrix = input_matrix.unsqueeze(1)
82 | cnn_out = self.cnn(input_matrix)
83 | linear_out = self.linear(input_positions)
84 |
85 | value_input = torch.cat((cnn_out, linear_out), dim=1)
86 | value_input = self.join(value_input)
87 |
88 | self._value_out = self.value_fn(value_input)
89 | return self.policy_fn(value_input), state
90 |
91 | def value_function(self):
92 | return self._value_out.flatten()
93 |
94 | # Register the model
95 | ModelCatalog.register_custom_model("CNNModel", CNNModel)
96 |
97 |
98 | def env_creator(args):
99 | print("-------------------------- ENV CREATOR --------------------------")
100 | N_AGENTS = 8
101 | # 6 hours of simulation, 600 radius
102 | env = CoverageDroneSwarmSearch(
103 | timestep_limit=180, drone_amount=N_AGENTS, prob_matrix_path="presim_20.npy", render_mode="human"
104 | )
105 | env = AllPositionsWrapper(env)
106 | grid_size = env.grid_size
107 | positions = position_on_diagonal(grid_size, N_AGENTS)
108 | env = RetainDronePosWrapper(env, positions)
109 | return env
110 |
111 | def position_on_diagonal(grid_size, drone_amount):
112 | positions = []
113 | center = grid_size // 2
114 | for i in range(-drone_amount // 2, drone_amount // 2):
115 | positions.append((center + i, center + i))
116 | return positions
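# E.g. position_on_diagonal(grid_size, 8) lines the 8 drones up on the main
# diagonal, from (center - 4, center - 4) through (center + 3, center + 3).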
117 |
118 | env = env_creator(None)
119 | register_env("DSSE_Coverage", lambda config: ParallelPettingZooEnv(env_creator(config)))
120 | ray.init()
121 |
122 |
123 | checkpoint_path = args.checkpoint
124 | PPOagent = PPO.from_checkpoint(checkpoint_path)
125 |
126 | reward_sum = 0
127 | i = 0
128 |
129 | if args.see:
130 | obs, info = env.reset()
131 | with PygameRecord("test_trained.gif", 5) as rec:
132 | while env.agents:
133 | actions = {}
134 | for k, v in obs.items():
135 | actions[k] = PPOagent.compute_single_action(v, explore=False)
136 | # print(v)
137 | # action = PPOagent.compute_actions(obs)
138 | obs, rw, term, trunc, info = env.step(actions)
139 | reward_sum += sum(rw.values())
140 | i += 1
141 | rec.add_frame()
142 | else:
143 | rewards = []
144 | founds = 0
145 | N_EVALS = 1000
146 | for _ in range(N_EVALS):
147 | obs, info = env.reset()
148 | reward_sum = 0
149 | while env.agents:
150 | actions = {}
151 | for k, v in obs.items():
152 | actions[k] = PPOagent.compute_single_action(v, explore=False)
153 | # print(v)
154 | # action = PPOagent.compute_actions(obs, explore=False)
155 | obs, rw, term, trunc, info = env.step(actions)
156 | reward_sum += sum(rw.values())
157 | i += 1
158 | rewards.append(reward_sum)
159 | for _, v in info.items():
160 | if v["Found"]:
161 | founds += 1
162 | break
163 | print("Average reward: ", sum(rewards) / N_EVALS)
164 | print("Found %: ", founds / N_EVALS)
165 |
166 | print("Total reward: ", reward_sum)
167 | print("Total steps: ", i)
168 | print("Found: ", info)
169 | env.close()
170 |
--------------------------------------------------------------------------------
/src/a_star_coverage.py:
--------------------------------------------------------------------------------
1 | """
2 | Module to implement the A* algorithm for coverage path planning.
3 | """
4 |
5 | import argparse
6 | import numpy as np
7 | from DSSE import CoverageDroneSwarmSearch, Actions
8 | from aigyminsper.search.graph import State
9 |
10 | MOVEMENTS = {
11 | Actions.UP: (0, -1),
12 | Actions.DOWN: (0, 1),
13 | Actions.LEFT: (-1, 0),
14 | Actions.RIGHT: (1, 0),
15 | Actions.UP_LEFT: (-1, -1),
16 | Actions.UP_RIGHT: (1, -1),
17 | Actions.DOWN_LEFT: (-1, 1),
18 | Actions.DOWN_RIGHT: (1, 1),
19 | }
20 |
21 |
22 | class DroneState(State):
23 | def __init__(self, position: tuple, prob_matrix: np.ndarray, visited: set):
24 | self.position = position
25 | self.prob_matrix = prob_matrix
26 | self.visited = visited
27 | self.action = None
28 |
29 | def successors(self, allow_zeros: bool = False) -> list["DroneState"]:
30 | successors = []
31 | x, y = self.position
32 |
33 | for action, (dx, dy) in MOVEMENTS.items():
34 | new_x, new_y = x + dx, y + dy
35 |
36 | if 0 <= new_x < len(self.prob_matrix) and 0 <= new_y < len(
37 | self.prob_matrix[0]
38 | ):
39 | if allow_zeros or self.prob_matrix[new_y, new_x] > 0:
40 | new_state = DroneState(
41 | (new_x, new_y),
42 | self.prob_matrix,
43 | self.visited | {(new_x, new_y)},
44 | )
45 | new_state.set_action(action)
46 | successors.append(new_state)
47 |
48 | return successors
49 |
50 | def cost(
51 | self,
52 | prob_weight: int | float = 10,
53 | distance_weight: int | float = 0.5,
54 | revisit_penalty_value: int | float = 30,
55 | ) -> int | float:
56 | high_prob_indices = np.argwhere(self.prob_matrix > 0)
57 | number_of_high_prob = high_prob_indices.shape[0]
58 |
59 | if number_of_high_prob == 0:
60 | return float("inf")
61 |
62 | x, y = self.position
63 | distances = np.abs(high_prob_indices[:, 1] - x) + np.abs(
64 | high_prob_indices[:, 0] - y
65 | )
66 | min_distance = np.min(distances)
67 |
68 | max_prob = np.max(self.prob_matrix)
69 | prob_value = self.prob_matrix[y, x]
70 | normalized_prob = prob_value / max_prob if max_prob > 0 else 0
71 |
72 | max_distance = len(self.prob_matrix) + len(self.prob_matrix[0])
73 | normalized_distance = min_distance / max_distance
74 |
75 | revisit_penalty = 0
76 | if (x, y) in self.visited:
77 | revisit_penalty = revisit_penalty_value * (
78 | len(self.visited) / (number_of_high_prob + 1)
79 | )
80 |
81 | heuristic_value = (
82 | -(normalized_prob * prob_weight)
83 | + (normalized_distance * distance_weight)
84 | + revisit_penalty
85 | )
86 |
87 | return heuristic_value
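        # Worked example with the default weights: an unvisited cell whose
        # probability is half the matrix maximum (normalized_prob = 0.5) and
        # whose nearest positive-probability cell is adjacent (normalized
        # distance ~ 0) scores about -(0.5 * 10) + 0 + 0 = -5.0; higher
        # probability, shorter distance and no revisits all lower the cost.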
88 |
89 | def description(self):
90 | return f"DroneState Position = {self.position}"
91 |
92 | def env(self):
93 | return self.position
94 |
95 | def is_goal(self):
96 | return False
97 |
98 | def set_action(self, action: Actions):
99 | self.action = action
100 |
101 | def get_action(self) -> Actions:
102 | return self.action if self.action else Actions.SEARCH
103 |
104 | def __eq__(self, other: "DroneState") -> bool:
105 | position_eq = self.position == other.position
106 | visited_eq = self.visited == other.visited
107 | action_eq = self.action == other.action
108 | return position_eq and visited_eq and action_eq
109 |
110 |
111 | def a_star(
112 | observations: dict, agents: list, prob_matrix: np.ndarray, visited: list[set]
113 | ) -> dict:
114 | actions = {}
115 | will_visit = []
116 | for i, agent in enumerate(agents):
117 | current_x, current_y = observations[agent][0]
118 |
119 | drone_state = DroneState((current_x, current_y), prob_matrix, visited[i])
120 |
121 | successors = drone_state.successors()
122 |
123 | if not successors:
124 | successors = drone_state.successors(allow_zeros=True)
125 |
126 | if not successors:
127 | continue
128 |
129 | next_state = min(successors, key=lambda state: state.cost())
130 | while next_state.position in will_visit:
131 | successors.remove(next_state)
132 | next_state = min(successors, key=lambda state: state.cost())
133 |
134 | will_visit.append(next_state.position)
135 | actions[agent] = next_state.get_action().value
136 | visited[i].add(next_state.position)
137 | prob_matrix[current_y, current_x] = 0
138 |
139 | return actions
140 |
141 |
142 | def main(num_drones: int):
143 | env = CoverageDroneSwarmSearch(
144 | drone_amount=num_drones,
145 | render_mode="human",
146 | timestep_limit=200,
147 | prob_matrix_path="src/min_matrix.npy",
148 | )
149 |
150 | center = env.grid_size // 2
151 | positions = [(center, center)]
152 |
153 | for i in range(1, num_drones):
154 | offset = (i // 2) + 1
155 | x_offset = offset * (-1 if i % 2 == 0 else 1)
156 | y_offset = offset * (-1 if (i + 1) % 2 == 0 else 1)
157 | new_position = (center + x_offset, center + y_offset)
158 | positions.append(new_position)
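    # With num_drones=4, for example, the loop above places the drones at
    # [(c, c), (c + 1, c - 1), (c - 2, c + 2), (c + 2, c - 2)], where c is the
    # center cell index.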
159 |
160 | opt = {"drones_positions": positions}
161 |
162 | observations, _ = env.reset(options=opt)
163 |
164 | visited = [set([opt["drones_positions"][i]]) for i in range(num_drones)]
165 |
166 | prob_matrix = env.probability_matrix.get_matrix()
167 | while env.agents:
168 | actions = a_star(observations, env.agents, prob_matrix.copy(), visited)
169 | observations, *_ = env.step(actions)
170 |
171 |
172 | if __name__ == "__main__":
173 | argparser = argparse.ArgumentParser()
174 | argparser.add_argument("--num_drones", type=int, required=True)
175 | args = argparser.parse_args()
176 | main(num_drones=args.num_drones)
177 |
--------------------------------------------------------------------------------
/src/train_ppo_cnn_lstm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | from DSSE import DroneSwarmSearch
4 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper
5 | from DSSE.environment.wrappers.communication_wrapper import CommunicationWrapper
6 | import ray
7 | from ray import tune
8 | from ray.rllib.algorithms.ppo import PPOConfig
9 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
10 | from ray.rllib.models import ModelCatalog
11 | from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
12 | from ray.rllib.models.modelv2 import ModelV2
13 | from ray.rllib.utils.annotations import override
14 | from ray.rllib.policy.rnn_sequencing import add_time_dimension
15 | from ray.tune.registry import register_env
16 | from torch import nn
17 | import torch
18 |
19 |
20 | class CNNModel(TorchRNN, nn.Module):
21 | def __init__(
22 | self,
23 | obs_space,
24 | act_space,
25 | num_outputs,
26 | model_config,
27 | name,
28 | **kw,
29 | ):
30 | print("OBSSPACE: ", obs_space)
31 | num_outputs = act_space.n
32 | nn.Module.__init__(self)
33 | super().__init__(obs_space, act_space, num_outputs, model_config, name, **kw)
34 |
35 |         flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[1] - 7 - 3)
36 | self.cnn = nn.Sequential(
37 | nn.Conv2d(
38 | in_channels=1,
39 | out_channels=16,
40 | kernel_size=(8, 8),
41 | stride=(1, 1),
42 | ),
43 | nn.Tanh(),
44 | nn.Conv2d(
45 | in_channels=16,
46 | out_channels=32,
47 | kernel_size=(4, 4),
48 | stride=(1, 1),
49 | ),
50 | nn.Tanh(),
51 | nn.Flatten(),
52 | nn.Linear(flatten_size, 256),
53 | nn.Tanh(),
54 | )
55 |
56 | self.linear = nn.Linear(obs_space[0].shape[0], 512)
57 |
58 | self.lstm_state_size = 256
59 | self.lstm = nn.LSTM(512, self.lstm_state_size, batch_first=True)
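        # Only the position/scalar branch is recurrent: forward_rnn() feeds the
        # 512-dim linear features through this LSTM, and forward() concatenates
        # the 256-dim LSTM output with the 256-dim CNN output before `join`.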
60 |
61 | self.join = nn.Sequential(
62 | nn.Linear(256 * 2, 256),
63 | nn.Tanh(),
64 | )
65 | print("NUM OUTPUTS: ", num_outputs)
66 | self.policy_fn = nn.Linear(256, num_outputs)
67 | self.value_fn = nn.Linear(256, 1)
68 | # Holds the current "base" output (before logits layer).
69 | self._value_out = None
70 |
71 | @override(ModelV2)
72 | def get_initial_state(self):
73 | # TODO: (sven): Get rid of `get_initial_state` once Trajectory
74 | # View API is supported across all of RLlib.
75 | # Place hidden states on same device as model.
76 | h = [
77 | self.linear.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
78 | self.linear.weight.new(1, self.lstm_state_size).zero_().squeeze(0)
79 | ]
80 | return h
81 |
82 | def value_function(self):
83 | return self._value_out.flatten()
84 |
85 | @override(ModelV2)
86 | def forward(
87 | self,
88 | input_dict,
89 | state,
90 | seq_lens,
91 | ):
92 | """Adds time dimension to batch before sending inputs to forward_rnn().
93 |
94 | You should implement forward_rnn() in your subclass."""
95 | scalar_inputs = input_dict["obs"][0].float()
96 | input_matrix = input_dict["obs"][1].float()
97 |
98 | input_matrix = input_matrix.unsqueeze(1)
99 | cnn_out = self.cnn(input_matrix)
100 |
101 | flat_inputs = scalar_inputs.flatten(start_dim=1)
102 | # Note that max_seq_len != input_dict.max_seq_len != seq_lens.max()
103 | # as input_dict may have extra zero-padding beyond seq_lens.max().
104 | # Use add_time_dimension to handle this
105 | self.time_major = self.model_config.get("_time_major", False)
106 | inputs = add_time_dimension(
107 | flat_inputs,
108 | seq_lens=seq_lens,
109 | framework="torch",
110 | time_major=self.time_major,
111 | )
112 | lstm_out, new_state = self.forward_rnn(inputs, state, seq_lens)
113 | lstm_out = torch.reshape(lstm_out, [-1, self.lstm_state_size])
114 |
115 | value_input = torch.cat((cnn_out, lstm_out), dim=1)
116 | value_input = self.join(value_input)
117 |
118 | self._value_out = self.value_fn(value_input)
119 | return self.policy_fn(value_input), new_state
120 |
121 | @override(TorchRNN)
122 | def forward_rnn(self, inputs, state, seq_lens):
123 |         """Feeds `inputs` (B x T x ..) through the LSTM.
124 |
125 |         Returns the resulting outputs as a sequence (B x T x ...).
126 |         The value branch is not computed here; forward() concatenates the
127 |         LSTM output with the CNN features and derives the value from that.
128 |
129 |         Returns:
130 |             NN Outputs (B x T x ...) as sequence.
131 |             The state batches as a List of two items (h- and c-states).
132 |         """
133 | linear_out = nn.functional.tanh(self.linear(inputs))
134 |
135 | lstm_out, [h, c] = self.lstm(linear_out, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)])
136 |
137 | return lstm_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]
138 |
139 |
140 | def env_creator(args):
141 | """
142 | Petting Zoo environment for search of shipwrecked people.
143 | check it out at
144 | https://github.com/pfeinsper/drone-swarm-search
145 | or install with
146 | pip install DSSE
147 | """
148 | env = DroneSwarmSearch(
149 | drone_amount=4,
150 | grid_size=40,
151 | dispersion_inc=0.1,
152 | person_initial_position=(20, 20),
153 | )
154 | positions = [
155 | (20, 0),
156 | (20, 39),
157 | (0, 20),
158 | (39, 20),
159 | ]
160 | env = AllPositionsWrapper(env)
161 | env = RetainDronePosWrapper(env, positions)
162 | return env
163 |
164 |
165 | if __name__ == "__main__":
166 | ray.init()
167 |
168 | env_name = "DSSE"
169 |
170 | register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
171 | ModelCatalog.register_custom_model("CNNModel", CNNModel)
172 |
173 | config = (
174 | PPOConfig()
175 | .environment(env=env_name)
176 | .rollouts(num_rollout_workers=6, rollout_fragment_length="auto")
177 | .training(
178 | train_batch_size=4096,
179 | lr=1e-5,
180 | gamma=0.9999999,
181 | lambda_=0.9,
182 | use_gae=True,
183 | entropy_coeff=0.01,
184 | sgd_minibatch_size=300,
185 | num_sgd_iter=10,
186 | model={
187 | "custom_model": "CNNModel",
188 | "use_lstm": False,
189 | "lstm_cell_size": 256,
190 | "_disable_preprocessor_api": True,
191 | },
192 | )
193 | .experimental(_disable_preprocessor_api=True)
194 | .debugging(log_level="ERROR")
195 | .framework(framework="torch")
196 | .resources(num_gpus=1)
197 | )
198 |
199 | curr_path = pathlib.Path().resolve()
200 | tune.run(
201 | "PPO",
202 | name="PPO_LSTM_M",
203 | resume=True,
204 | stop={"timesteps_total": 20_000_000, "episode_reward_mean": 1.82},
205 | checkpoint_freq=15,
206 | storage_path=f"{curr_path}/ray_res/" + env_name,
207 | config=config.to_dict(),
208 | )
209 |
--------------------------------------------------------------------------------
/src/test_trained_cnn_lstm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | from recorder import PygameRecord
4 | from DSSE import DroneSwarmSearch
5 | from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper
6 | import ray
7 | from ray.rllib.algorithms.ppo import PPOConfig
8 | from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
9 | from ray.rllib.models import ModelCatalog
10 | from ray.tune.registry import register_env
11 | from ray.rllib.utils.annotations import override
12 | from ray.rllib.policy.rnn_sequencing import add_time_dimension
13 | from ray.rllib.algorithms.ppo import PPO
14 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
15 | from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
16 | from ray.rllib.models.modelv2 import ModelV2
17 | import torch.nn as nn
18 | import argparse
19 | import torch
20 | import numpy as np
21 |
22 |
23 |
24 | argparser = argparse.ArgumentParser()
25 | argparser.add_argument("--checkpoint", type=str, required=True)
26 | argparser.add_argument("--see", action="store_true", default=False)
27 | args = argparser.parse_args()
28 |
29 | class CNNModel(TorchRNN, nn.Module):
30 | def __init__(
31 | self,
32 | obs_space,
33 | act_space,
34 | num_outputs,
35 | model_config,
36 | name,
37 | **kw,
38 | ):
39 | print("OBSSPACE: ", obs_space)
40 | num_outputs = act_space.n
41 | nn.Module.__init__(self)
42 | super().__init__(obs_space, act_space, num_outputs, model_config, name, **kw)
43 |
44 |         flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[1] - 7 - 3)
45 | self.cnn = nn.Sequential(
46 | nn.Conv2d(
47 | in_channels=1,
48 | out_channels=16,
49 | kernel_size=(8, 8),
50 | stride=(1, 1),
51 | ),
52 | nn.Tanh(),
53 | nn.Conv2d(
54 | in_channels=16,
55 | out_channels=32,
56 | kernel_size=(4, 4),
57 | stride=(1, 1),
58 | ),
59 | nn.Tanh(),
60 | nn.Flatten(),
61 | nn.Linear(flatten_size, 256),
62 | nn.Tanh(),
63 | )
64 |
65 | self.linear = nn.Linear(obs_space[0].shape[0], 512)
66 |
67 | self.lstm_state_size = 256
68 | self.lstm = nn.LSTM(512, self.lstm_state_size, batch_first=True)
69 |
70 | self.join = nn.Sequential(
71 | nn.Linear(256 * 2, 256),
72 | nn.Tanh(),
73 | )
74 | print("NUM OUTPUTS: ", num_outputs)
75 | self.policy_fn = nn.Linear(256, num_outputs)
76 | self.value_fn = nn.Linear(256, 1)
77 | # Holds the current "base" output (before logits layer).
78 | self._value_out = None
79 |
80 | @override(ModelV2)
81 | def get_initial_state(self):
82 | # TODO: (sven): Get rid of `get_initial_state` once Trajectory
83 | # View API is supported across all of RLlib.
84 | # Place hidden states on same device as model.
85 | h = [
86 | self.linear.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
87 | self.linear.weight.new(1, self.lstm_state_size).zero_().squeeze(0)
88 | ]
89 | return h
90 |
91 | def value_function(self):
92 | return self._value_out.flatten()
93 |
94 | @override(ModelV2)
95 | def forward(
96 | self,
97 | input_dict,
98 | state,
99 | seq_lens,
100 | ):
101 | """Adds time dimension to batch before sending inputs to forward_rnn().
102 |
103 | You should implement forward_rnn() in your subclass."""
104 | scalar_inputs = input_dict["obs"][0].float()
105 | input_matrix = input_dict["obs"][1].float()
106 |
107 | input_matrix = input_matrix.unsqueeze(1)
108 | cnn_out = self.cnn(input_matrix)
109 |
110 | flat_inputs = scalar_inputs.flatten(start_dim=1)
111 | # Note that max_seq_len != input_dict.max_seq_len != seq_lens.max()
112 | # as input_dict may have extra zero-padding beyond seq_lens.max().
113 | # Use add_time_dimension to handle this
114 | self.time_major = self.model_config.get("_time_major", False)
115 | inputs = add_time_dimension(
116 | flat_inputs,
117 | seq_lens=seq_lens,
118 | framework="torch",
119 | time_major=self.time_major,
120 | )
121 | lstm_out, new_state = self.forward_rnn(inputs, state, seq_lens)
122 | lstm_out = torch.reshape(lstm_out, [-1, self.lstm_state_size])
123 |
124 | value_input = torch.cat((cnn_out, lstm_out), dim=1)
125 | value_input = self.join(value_input)
126 |
127 | self._value_out = self.value_fn(value_input)
128 | return self.policy_fn(value_input), new_state
129 |
130 | @override(TorchRNN)
131 | def forward_rnn(self, inputs, state, seq_lens):
132 |         """Feeds `inputs` (B x T x ..) through the LSTM.
133 |
134 |         Returns the resulting outputs as a sequence (B x T x ...).
135 |         The value branch is not computed here; forward() concatenates the
136 |         LSTM output with the CNN features and derives the value from that.
137 |
138 |         Returns:
139 |             NN Outputs (B x T x ...) as sequence.
140 |             The state batches as a List of two items (h- and c-states).
141 |         """
142 | linear_out = nn.functional.tanh(self.linear(inputs))
143 |
144 | lstm_out, [h, c] = self.lstm(linear_out, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)])
145 |
146 | return lstm_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]
147 |
148 |
149 | def env_creator(_):
150 | """
151 | Petting Zoo environment for search of shipwrecked people.
152 | check it out at
153 | https://github.com/pfeinsper/drone-swarm-search
154 | or install with
155 | pip install DSSE
156 | """
157 | render_mode = "human" if args.see else "ansi"
158 | env = DroneSwarmSearch(
159 | drone_amount=4,
160 | grid_size=40,
161 | dispersion_inc=0.1,
162 | person_initial_position=(20, 20),
163 | render_mode=render_mode,
164 | render_grid=True
165 | )
166 | positions = [
167 | (20, 0),
168 | (20, 39),
169 | (0, 20),
170 | (39, 20),
171 | ]
172 | env = AllPositionsWrapper(env)
173 | env = RetainDronePosWrapper(env, positions)
174 | return env
175 |
176 | env = env_creator(None)
177 | register_env("DSSE", lambda config: ParallelPettingZooEnv(env_creator(config)))
178 | ModelCatalog.register_custom_model("CNNModel", CNNModel)
179 | ray.init()
180 |
181 |
182 | checkpoint_path = args.checkpoint
183 | PPOagent = PPO.from_checkpoint(checkpoint_path)
184 |
185 | reward_sum = 0
186 | i = 0
187 |
188 | if args.see:
189 | sample_model = PPOagent.get_policy().model
190 | with PygameRecord("test_trained.gif", 5) as rec:
191 | obs, info = env.reset()
192 | reward_sum = 0
193 | state = sample_model.get_initial_state()
194 | # state = init_hidden()
195 | i = 0
196 | # done = False
197 | while env.agents:
198 | print(obs)
199 | actions = {}
200 | # for k, v in obs.items():
201 | # action, state, _ = PPOagent.compute_single_action(v, state, explore=False)
202 | # actions[k] = action
203 | actions = PPOagent.compute_actions(obs, state, explore=False)
204 | obs, rw, term, trunc, info = env.step(actions)
205 | # done = any(term.values()) or any(trunc.values())
206 | reward_sum += sum(rw.values())
207 | i += 1
208 | rec.add_frame()
209 | else:
210 | rewards = []
211 | actions_statics = []
212 | founds = 0
213 | N_EVALS = 200
214 | sample_model = PPOagent.get_policy().model
215 | for _ in range(N_EVALS):
216 | print(_)
217 | obs, info = env.reset()
218 | reward_sum = 0
219 | # state = init_hidden()
220 | state = sample_model.get_initial_state()
221 | i = 0
222 | while env.agents:
223 | actions = {}
224 | # for k, v in obs.items():
225 | # action, state, _ = PPOagent.compute_single_action(v, state, explore=False)
226 | # actions[k] = action
227 | actions = PPOagent.compute_actions(obs, state, explore=False)
228 | obs, rw, term, trunc, info = env.step(actions)
229 | reward_sum += sum(rw.values())
230 | i += 1
231 | actions_statics.append(i)
232 | rewards.append(reward_sum)
233 | for _, v in info.items():
234 | if v["Found"]:
235 | founds += 1
236 | break
237 |
238 | print("Average reward: ", sum(rewards) / N_EVALS)
239 | print("Found %: ", founds / N_EVALS)
240 | print("Mean steps: ", sum(actions_statics) / N_EVALS)
241 | print("Median steps: ", np.median(actions_statics))
242 | print("Found: ", info)
243 | env.close()
--------------------------------------------------------------------------------