├── .gitignore ├── LICENSE ├── README-rsrc ├── doorkey.png ├── evaluate-terminal-logs.png ├── model.png ├── model.xml ├── train-tensorboard.png ├── train-terminal-logs.png ├── visualize-doorkey.gif ├── visualize-gotodoor.gif ├── visualize-keycorridor.gif └── visualize-redbluedoors.gif ├── README.md ├── model.py ├── requirements.txt ├── scripts ├── evaluate.py ├── train.py └── visualize.py ├── storage └── .gitignore └── utils ├── __init__.py ├── agent.py ├── env.py ├── format.py ├── other.py └── storage.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__ 2 | *egg-info 3 | .vscode -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Lucas Willems 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README-rsrc/doorkey.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/doorkey.png -------------------------------------------------------------------------------- /README-rsrc/evaluate-terminal-logs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/evaluate-terminal-logs.png -------------------------------------------------------------------------------- /README-rsrc/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/model.png -------------------------------------------------------------------------------- /README-rsrc/model.xml: -------------------------------------------------------------------------------- 1 | 
7Vtbj5s4FP41SLsPW8UYSHhMssnsQytVnV3t9tEBh9ACjohz66+vCXYwNpkyEweS0Y40Gvv4fs53bsZjwWl6eMrRevWJhDix7EF4sOCflm0De2izPwXlWFJcb1ASojwOeaeK8Bz/wJwoum3jEG9qHSkhCY3XdWJAsgwHtEZDeU729W5LktRXXaOIrzioCM8BSrDW7d84pKuSOnKl3n/hOFqJlcGAtyxQ8D3KyTbj61k2XJ5+yuYUibl4/80KhWQvkeDMgtOcEFqW0sMUJwVvBdvKcfMLred95zijbQaIERt6FGfHIWMFr5KcrkhEMpTMKurkdD5czDBgtRVNE1YErIgPMf2vIH9wee0r78S2kx+lpqL6lY/6hik9cgygLSWMVK37kZA176efTWyebPOA754Dj6I8wmfplLTiYNI4zpAnTFLMdsM65DhBNN7VMYA4lKJzv4qdrMA5eoG75RQ7lGz5pFOS7RhF53qdp/tVTPHzGp1OtWc6VufzRU7scE7x4cVD8lZbIJZrKBSI3Fd4B4K2krAOB9fzxdH4oqMwC8eFIrNakKDNJg4UqL0AqIvskVFhtwSFxA+3gR2C1ho7fIXPJGabq8ThKeIYKWwuMc5HyQqtTjT8xUQlD7SJTiI7H7uVFF1NinFa2FXbSxhzJpvtghWjokgFjU0pkQV1kVcUBQYMzbQu+A3NyXc8JQnJGSUjWWGRlnGSKCSUxFFWoIeBADP6pNCNmJn4MW9I4zA8mbMmZaurowl9cxS5+Lq+OQ34sg2om/Cqj2DkZUky5xkiPFoGmthZy9ifOrNJW7fg6QbA78sreJrezNIFDsM4i65zCwZ51x7YfwAF2KClI3EMIHuom6CMnfYdmqAlyagkQW84nvhzMxKE/Vmm0W0DAYMKIVsO765CB9dQ5KA4KMe9WeDgdJ902PfmkHwdVmBo2iOdhjLtQUepw7qQ4KY1noA3eBE2aiKh9GeFcgNvBYuvmYinL/+8Az955nMXfrLJrt5rBNhKfUCDWQZtUzrjEZ3YjZ7o63zvO9E/J4idJPqPm3ksRwEOGq3CHPhT32+N1KYbKac3pF68koL3h9Qur6SAaxSpZzzKaCzB2RaPAu1AxvpphlshFXaB1OYoAqhRh+PXp7gQxb4hoBDHlFRg9ZqE8RHSw9nYm3ieGbX0FckMXU0rb5Ue2jdxH7pCtVXJb9t0LRaPTlJ4SdFvpaaOrqZ2R2oKlWtq4Cpqai5JhH7nocPw7mKHBlGbzxJbu0j9s83H578/XRc3GOTc2yMN0DYmtv3r2WhfD2zdDt1ZdjZsMFFtb82M4xZe70UekuFub4ZCv5ZnSUaAKM7Y73X2woD2A+UTKWibZxjRfj3P6gCMv7oo7/kytDHzgH3Bt+F5woUEYQL+TxIuBIYNWcL5kZTpNAFen7vfu4EXxlzWENjbfWfD04/5NkmOvZt29eOF16Flh977R2HDMwrY3+s6/SHFfaDQ7xGE+pcIazq3im/C3IGtUaY+fpA913RmTUashtKCL9lis5Y6Nr+dKKjSrL+x5T9YQ7bVAXnNvVox4BUXcb/fqZM1EaCq6anfoS+1NQDtOsTOjRHzII99THhCNSBrwBAwgyFWrR6slzd91X8FwNlP -------------------------------------------------------------------------------- /README-rsrc/train-tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/train-tensorboard.png 
-------------------------------------------------------------------------------- /README-rsrc/train-terminal-logs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/train-terminal-logs.png -------------------------------------------------------------------------------- /README-rsrc/visualize-doorkey.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/visualize-doorkey.gif -------------------------------------------------------------------------------- /README-rsrc/visualize-gotodoor.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/visualize-gotodoor.gif -------------------------------------------------------------------------------- /README-rsrc/visualize-keycorridor.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/visualize-keycorridor.gif -------------------------------------------------------------------------------- /README-rsrc/visualize-redbluedoors.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/visualize-redbluedoors.gif -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RL Starter Files 2 | 3 | RL starter files to immediately train, visualize and evaluate an agent **without
writing any line of code**. 4 | 5 |

6 | 7 |

8 | 9 | These files are suited for [`minigrid`](https://github.com/Farama-Foundation/Minigrid) environments and [`torch-ac`](https://github.com/lcswillems/torch-ac) RL algorithms. They are easy to adapt to other environments and RL algorithms. 10 | 11 | ## Features 12 | 13 | - **Script to train**, including: 14 | - Log in txt, CSV and Tensorboard 15 | - Save model 16 | - Stop and restart training 17 | - Use A2C or PPO algorithms 18 | - **Script to visualize**, including: 19 | - Act by sampling or argmax 20 | - Save as Gif 21 | - **Script to evaluate**, including: 22 | - Act by sampling or argmax 23 | - List the worst-performing episodes 24 | 25 | ## Installation 26 | 27 | 1. Clone this repository. 28 | 29 | 2. Install `minigrid` environments and `torch-ac` RL algorithms: 30 | 31 | ``` 32 | pip3 install -r requirements.txt 33 | ``` 34 | 35 | **Note:** If you want to modify `torch-ac` algorithms, you will need to install a cloned version instead, i.e.: 36 | ``` 37 | git clone https://github.com/lcswillems/torch-ac.git 38 | cd torch-ac 39 | pip3 install -e . 40 | ``` 41 | 42 | ## Example of use 43 | 44 | Train, visualize and evaluate an agent on the `MiniGrid-DoorKey-5x5-v0` environment: 45 | 46 |

47 | 48 | 1. Train the agent on the `MiniGrid-DoorKey-5x5-v0` environment with PPO algorithm: 49 | 50 | ``` 51 | python3 -m scripts.train --algo ppo --env MiniGrid-DoorKey-5x5-v0 --model DoorKey --save-interval 10 --frames 80000 52 | ``` 53 | 54 |

55 | 56 | 2. Visualize agent's behavior: 57 | 58 | ``` 59 | python3 -m scripts.visualize --env MiniGrid-DoorKey-5x5-v0 --model DoorKey 60 | ``` 61 | 62 |

63 | 64 | 3. Evaluate agent's performance: 65 | 66 | ``` 67 | python3 -m scripts.evaluate --env MiniGrid-DoorKey-5x5-v0 --model DoorKey 68 | ``` 69 | 70 |

71 | 72 | **Note:** More details on the commands are given below. 73 | 74 | ## Other examples 75 | 76 | ### Handle textual instructions 77 | 78 | In the `GoToDoor` environment, the agent receives an image along with a textual instruction. To handle the latter, add `--text` to the command: 79 | 80 | ``` 81 | python3 -m scripts.train --algo ppo --env MiniGrid-GoToDoor-5x5-v0 --model GoToDoor --text --save-interval 10 --frames 1000000 82 | ``` 83 | 84 |

85 | 86 | ### Add memory 87 | 88 | In the `RedBlueDoors` environment, the agent has to open the red door then the blue one. To solve it efficiently, when it opens the red door, it has to remember it. To add memory to the agent, add `--recurrence X` to the command: 89 | 90 | ``` 91 | python3 -m scripts.train --algo ppo --env MiniGrid-RedBlueDoors-6x6-v0 --model RedBlueDoors --recurrence 4 --save-interval 10 --frames 1000000 92 | ``` 93 | 94 |

95 | 96 | ## Files 97 | 98 | This package contains: 99 | - scripts to: 100 | - train an agent \ 101 | in `scripts/train.py` ([more details](#scripts-train)) 102 | - visualize agent's behavior \ 103 | in `scripts/visualize.py` ([more details](#scripts-visualize)) 104 | - evaluate agent's performance \ 105 | in `scripts/evaluate.py` ([more details](#scripts-evaluate)) 106 | - a default agent model \ 107 | in `model.py` ([more details](#model)) 108 | - utility classes and functions used by the scripts \ 109 | in `utils` 110 | 111 | These files are suited for [`minigrid`](https://github.com/Farama-Foundation/Minigrid) environments and [`torch-ac`](https://github.com/lcswillems/torch-ac) RL algorithms. They are easy to adapt to other environments and RL algorithms by modifying: 112 | - `model.py` 113 | - `utils/format.py` 114 | 115 |

scripts/train.py

116 | 117 | An example of use: 118 | 119 | ```bash 120 | python3 -m scripts.train --algo ppo --env MiniGrid-DoorKey-5x5-v0 --model DoorKey --save-interval 10 --frames 80000 121 | ``` 122 | 123 | The script loads the model in `storage/DoorKey` or creates it if it doesn't exist, then trains it with the PPO algorithm on the MiniGrid DoorKey environment, and saves it every 10 updates in `storage/DoorKey`. It stops after 80 000 frames. 124 | 125 | **Note:** You can define a different storage location in the environment variable `PROJECT_STORAGE`. 126 | 127 | More generally, the script has 2 required arguments: 128 | - `--algo ALGO`: name of the RL algorithm used to train 129 | - `--env ENV`: name of the environment to train on 130 | 131 | and a bunch of optional arguments among which: 132 | - `--recurrence N`: gradient will be backpropagated over N timesteps. By default, N = 1. If N > 1, an LSTM is added to the model to give it memory. 133 | - `--text`: a GRU is added to the model to handle text input. 134 | - ... (see more using `--help`) 135 | 136 | During training, logs are printed in your terminal (and saved in text and CSV format): 137 | 138 |

139 | 140 | **Note:** `U` gives the update number, `F` the total number of frames, `FPS` the number of frames per second, `D` the total duration, `rR:μσmM` the mean, std, min and max reshaped return per episode, `F:μσmM` the mean, std, min and max number of frames per episode, `H` the entropy, `V` the value, `pL` the policy loss, `vL` the value loss and `∇` the gradient norm. 141 | 142 | During training, logs are also plotted in Tensorboard: 143 | 144 |

145 | 146 |

scripts/visualize.py

147 | 148 | An example of use: 149 | 150 | ``` 151 | python3 -m scripts.visualize --env MiniGrid-DoorKey-5x5-v0 --model DoorKey 152 | ``` 153 | 154 |

155 | 156 | In this use case, the script displays how the model in `storage/DoorKey` behaves on the MiniGrid DoorKey environment. 157 | 158 | More generally, the script has 2 required arguments: 159 | - `--env ENV`: name of the environment to act on. 160 | - `--model MODEL`: name of the trained model. 161 | 162 | and a bunch of optional arguments among which: 163 | - `--argmax`: select the action with highest probability 164 | - ... (see more using `--help`) 165 | 166 |

scripts/evaluate.py

167 | 168 | An example of use: 169 | 170 | ``` 171 | python3 -m scripts.evaluate --env MiniGrid-DoorKey-5x5-v0 --model DoorKey 172 | ``` 173 | 174 |

175 | 176 | In this use case, the script prints in the terminal the performance of the model in `storage/DoorKey` over 100 episodes. 177 | 178 | More generally, the script has 2 required arguments: 179 | - `--env ENV`: name of the environment to act on. 180 | - `--model MODEL`: name of the trained model. 181 | 182 | and a bunch of optional arguments among which: 183 | - `--episodes N`: number of episodes of evaluation. By default, N = 100. 184 | - ... (see more using `--help`) 185 | 186 |

model.py

187 | 188 | The default model is described by the following schema: 189 | 190 |

191 | 192 | By default, the memory part (in red) and the langage part (in blue) are disabled. They can be enabled by setting to `True` the `use_memory` and `use_text` parameters of the model constructor. 193 | 194 | This model can be easily adapted to your needs. 195 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions.categorical import Categorical 5 | import torch_ac 6 | 7 | 8 | # Function from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py 9 | def init_params(m): 10 | classname = m.__class__.__name__ 11 | if classname.find("Linear") != -1: 12 | m.weight.data.normal_(0, 1) 13 | m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True)) 14 | if m.bias is not None: 15 | m.bias.data.fill_(0) 16 | 17 | 18 | class ACModel(nn.Module, torch_ac.RecurrentACModel): 19 | def __init__(self, obs_space, action_space, use_memory=False, use_text=False): 20 | super().__init__() 21 | 22 | # Decide which components are enabled 23 | self.use_text = use_text 24 | self.use_memory = use_memory 25 | 26 | # Define image embedding 27 | self.image_conv = nn.Sequential( 28 | nn.Conv2d(3, 16, (2, 2)), 29 | nn.ReLU(), 30 | nn.MaxPool2d((2, 2)), 31 | nn.Conv2d(16, 32, (2, 2)), 32 | nn.ReLU(), 33 | nn.Conv2d(32, 64, (2, 2)), 34 | nn.ReLU() 35 | ) 36 | n = obs_space["image"][0] 37 | m = obs_space["image"][1] 38 | self.image_embedding_size = ((n-1)//2-2)*((m-1)//2-2)*64 39 | 40 | # Define memory 41 | if self.use_memory: 42 | self.memory_rnn = nn.LSTMCell(self.image_embedding_size, self.semi_memory_size) 43 | 44 | # Define text embedding 45 | if self.use_text: 46 | self.word_embedding_size = 32 47 | self.word_embedding = nn.Embedding(obs_space["text"], self.word_embedding_size) 48 | self.text_embedding_size = 128 49 | 
self.text_rnn = nn.GRU(self.word_embedding_size, self.text_embedding_size, batch_first=True) 50 | 51 | # Resize image embedding 52 | self.embedding_size = self.semi_memory_size 53 | if self.use_text: 54 | self.embedding_size += self.text_embedding_size 55 | 56 | # Define actor's model 57 | self.actor = nn.Sequential( 58 | nn.Linear(self.embedding_size, 64), 59 | nn.Tanh(), 60 | nn.Linear(64, action_space.n) 61 | ) 62 | 63 | # Define critic's model 64 | self.critic = nn.Sequential( 65 | nn.Linear(self.embedding_size, 64), 66 | nn.Tanh(), 67 | nn.Linear(64, 1) 68 | ) 69 | 70 | # Initialize parameters correctly 71 | self.apply(init_params) 72 | 73 | @property 74 | def memory_size(self): 75 | return 2*self.semi_memory_size 76 | 77 | @property 78 | def semi_memory_size(self): 79 | return self.image_embedding_size 80 | 81 | def forward(self, obs, memory): 82 | x = obs.image.transpose(1, 3).transpose(2, 3) 83 | x = self.image_conv(x) 84 | x = x.reshape(x.shape[0], -1) 85 | 86 | if self.use_memory: 87 | hidden = (memory[:, :self.semi_memory_size], memory[:, self.semi_memory_size:]) 88 | hidden = self.memory_rnn(x, hidden) 89 | embedding = hidden[0] 90 | memory = torch.cat(hidden, dim=1) 91 | else: 92 | embedding = x 93 | 94 | if self.use_text: 95 | embed_text = self._get_embed_text(obs.text) 96 | embedding = torch.cat((embedding, embed_text), dim=1) 97 | 98 | x = self.actor(embedding) 99 | dist = Categorical(logits=F.log_softmax(x, dim=1)) 100 | 101 | x = self.critic(embedding) 102 | value = x.squeeze(1) 103 | 104 | return dist, value, memory 105 | 106 | def _get_embed_text(self, text): 107 | _, hidden = self.text_rnn(self.word_embedding(text)) 108 | return hidden[-1] 109 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch-ac>=1.4.0 2 | minigrid>=2.2.0 3 | tensorboardX>=1.6 4 | numpy>=1.3 5 | gymnasium>=0.26 6 | 
-------------------------------------------------------------------------------- /scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import torch 4 | from torch_ac.utils.penv import ParallelEnv 5 | 6 | import utils 7 | from utils import device 8 | 9 | 10 | # Parse arguments 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--env", required=True, 14 | help="name of the environment (REQUIRED)") 15 | parser.add_argument("--model", required=True, 16 | help="name of the trained model (REQUIRED)") 17 | parser.add_argument("--episodes", type=int, default=100, 18 | help="number of episodes of evaluation (default: 100)") 19 | parser.add_argument("--seed", type=int, default=0, 20 | help="random seed (default: 0)") 21 | parser.add_argument("--procs", type=int, default=16, 22 | help="number of processes (default: 16)") 23 | parser.add_argument("--argmax", action="store_true", default=False, 24 | help="action with highest probability is selected") 25 | parser.add_argument("--worst-episodes-to-show", type=int, default=10, 26 | help="how many worst episodes to show") 27 | parser.add_argument("--memory", action="store_true", default=False, 28 | help="add a LSTM to the model") 29 | parser.add_argument("--text", action="store_true", default=False, 30 | help="add a GRU to the model") 31 | 32 | if __name__ == "__main__": 33 | args = parser.parse_args() 34 | 35 | # Set seed for all randomness sources 36 | 37 | utils.seed(args.seed) 38 | 39 | # Set device 40 | 41 | print(f"Device: {device}\n") 42 | 43 | # Load environments 44 | 45 | envs = [] 46 | for i in range(args.procs): 47 | env = utils.make_env(args.env, args.seed + 10000 * i) 48 | envs.append(env) 49 | env = ParallelEnv(envs) 50 | print("Environments loaded\n") 51 | 52 | # Load agent 53 | 54 | model_dir = utils.get_model_dir(args.model) 55 | agent = utils.Agent(env.observation_space, env.action_space, model_dir, 56 | 
argmax=args.argmax, num_envs=args.procs, 57 | use_memory=args.memory, use_text=args.text) 58 | print("Agent loaded\n") 59 | 60 | # Initialize logs 61 | 62 | logs = {"num_frames_per_episode": [], "return_per_episode": []} 63 | 64 | # Run agent 65 | 66 | start_time = time.time() 67 | 68 | obss = env.reset() 69 | 70 | log_done_counter = 0 71 | log_episode_return = torch.zeros(args.procs, device=device) 72 | log_episode_num_frames = torch.zeros(args.procs, device=device) 73 | 74 | while log_done_counter < args.episodes: 75 | actions = agent.get_actions(obss) 76 | obss, rewards, terminateds, truncateds, _ = env.step(actions) 77 | dones = tuple(a | b for a, b in zip(terminateds, truncateds)) 78 | agent.analyze_feedbacks(rewards, dones) 79 | 80 | log_episode_return += torch.tensor(rewards, device=device, dtype=torch.float) 81 | log_episode_num_frames += torch.ones(args.procs, device=device) 82 | 83 | for i, done in enumerate(dones): 84 | if done: 85 | log_done_counter += 1 86 | logs["return_per_episode"].append(log_episode_return[i].item()) 87 | logs["num_frames_per_episode"].append(log_episode_num_frames[i].item()) 88 | 89 | mask = 1 - torch.tensor(dones, device=device, dtype=torch.float) 90 | log_episode_return *= mask 91 | log_episode_num_frames *= mask 92 | 93 | end_time = time.time() 94 | 95 | # Print logs 96 | 97 | num_frames = sum(logs["num_frames_per_episode"]) 98 | fps = num_frames / (end_time - start_time) 99 | duration = int(end_time - start_time) 100 | return_per_episode = utils.synthesize(logs["return_per_episode"]) 101 | num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"]) 102 | 103 | print("F {} | FPS {:.0f} | D {} | R:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {}" 104 | .format(num_frames, fps, duration, 105 | *return_per_episode.values(), 106 | *num_frames_per_episode.values())) 107 | 108 | # Print worst episodes 109 | 110 | n = args.worst_episodes_to_show 111 | if n > 0: 112 | print("\n{} worst episodes:".format(n)) 
113 | 114 | indexes = sorted(range(len(logs["return_per_episode"])), key=lambda k: logs["return_per_episode"][k]) 115 | for i in indexes[:n]: 116 | print("- episode {}: R={}, F={}".format(i, logs["return_per_episode"][i], logs["num_frames_per_episode"][i])) 117 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import datetime 4 | import torch_ac 5 | import tensorboardX 6 | import sys 7 | 8 | import utils 9 | from utils import device 10 | from model import ACModel 11 | 12 | 13 | # Parse arguments 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | # General parameters 18 | parser.add_argument("--algo", required=True, 19 | help="algorithm to use: a2c | ppo (REQUIRED)") 20 | parser.add_argument("--env", required=True, 21 | help="name of the environment to train on (REQUIRED)") 22 | parser.add_argument("--model", default=None, 23 | help="name of the model (default: {ENV}_{ALGO}_{TIME})") 24 | parser.add_argument("--seed", type=int, default=1, 25 | help="random seed (default: 1)") 26 | parser.add_argument("--log-interval", type=int, default=1, 27 | help="number of updates between two logs (default: 1)") 28 | parser.add_argument("--save-interval", type=int, default=10, 29 | help="number of updates between two saves (default: 10, 0 means no saving)") 30 | parser.add_argument("--procs", type=int, default=16, 31 | help="number of processes (default: 16)") 32 | parser.add_argument("--frames", type=int, default=10**7, 33 | help="number of frames of training (default: 1e7)") 34 | 35 | # Parameters for main algorithm 36 | parser.add_argument("--epochs", type=int, default=4, 37 | help="number of epochs for PPO (default: 4)") 38 | parser.add_argument("--batch-size", type=int, default=256, 39 | help="batch size for PPO (default: 256)") 40 | parser.add_argument("--frames-per-proc", type=int, default=None, 41 
| help="number of frames per process before update (default: 5 for A2C and 128 for PPO)") 42 | parser.add_argument("--discount", type=float, default=0.99, 43 | help="discount factor (default: 0.99)") 44 | parser.add_argument("--lr", type=float, default=0.001, 45 | help="learning rate (default: 0.001)") 46 | parser.add_argument("--gae-lambda", type=float, default=0.95, 47 | help="lambda coefficient in GAE formula (default: 0.95, 1 means no gae)") 48 | parser.add_argument("--entropy-coef", type=float, default=0.01, 49 | help="entropy term coefficient (default: 0.01)") 50 | parser.add_argument("--value-loss-coef", type=float, default=0.5, 51 | help="value loss term coefficient (default: 0.5)") 52 | parser.add_argument("--max-grad-norm", type=float, default=0.5, 53 | help="maximum norm of gradient (default: 0.5)") 54 | parser.add_argument("--optim-eps", type=float, default=1e-8, 55 | help="Adam and RMSprop optimizer epsilon (default: 1e-8)") 56 | parser.add_argument("--optim-alpha", type=float, default=0.99, 57 | help="RMSprop optimizer alpha (default: 0.99)") 58 | parser.add_argument("--clip-eps", type=float, default=0.2, 59 | help="clipping epsilon for PPO (default: 0.2)") 60 | parser.add_argument("--recurrence", type=int, default=1, 61 | help="number of time-steps gradient is backpropagated (default: 1). 
If > 1, a LSTM is added to the model to have memory.") 62 | parser.add_argument("--text", action="store_true", default=False, 63 | help="add a GRU to the model to handle text input") 64 | 65 | if __name__ == "__main__": 66 | args = parser.parse_args() 67 | 68 | args.mem = args.recurrence > 1 69 | 70 | # Set run dir 71 | 72 | date = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S") 73 | default_model_name = f"{args.env}_{args.algo}_seed{args.seed}_{date}" 74 | 75 | model_name = args.model or default_model_name 76 | model_dir = utils.get_model_dir(model_name) 77 | 78 | # Load loggers and Tensorboard writer 79 | 80 | txt_logger = utils.get_txt_logger(model_dir) 81 | csv_file, csv_logger = utils.get_csv_logger(model_dir) 82 | tb_writer = tensorboardX.SummaryWriter(model_dir) 83 | 84 | # Log command and all script arguments 85 | 86 | txt_logger.info("{}\n".format(" ".join(sys.argv))) 87 | txt_logger.info("{}\n".format(args)) 88 | 89 | # Set seed for all randomness sources 90 | 91 | utils.seed(args.seed) 92 | 93 | # Set device 94 | 95 | txt_logger.info(f"Device: {device}\n") 96 | 97 | # Load environments 98 | 99 | envs = [] 100 | for i in range(args.procs): 101 | envs.append(utils.make_env(args.env, args.seed + 10000 * i)) 102 | txt_logger.info("Environments loaded\n") 103 | 104 | # Load training status 105 | 106 | try: 107 | status = utils.get_status(model_dir) 108 | except OSError: 109 | status = {"num_frames": 0, "update": 0} 110 | txt_logger.info("Training status loaded\n") 111 | 112 | # Load observations preprocessor 113 | 114 | obs_space, preprocess_obss = utils.get_obss_preprocessor(envs[0].observation_space) 115 | if "vocab" in status: 116 | preprocess_obss.vocab.load_vocab(status["vocab"]) 117 | txt_logger.info("Observations preprocessor loaded") 118 | 119 | # Load model 120 | 121 | acmodel = ACModel(obs_space, envs[0].action_space, args.mem, args.text) 122 | if "model_state" in status: 123 | acmodel.load_state_dict(status["model_state"]) 124 | 
acmodel.to(device) 125 | txt_logger.info("Model loaded\n") 126 | txt_logger.info("{}\n".format(acmodel)) 127 | 128 | # Load algo 129 | 130 | if args.algo == "a2c": 131 | algo = torch_ac.A2CAlgo(envs, acmodel, device, args.frames_per_proc, args.discount, args.lr, args.gae_lambda, 132 | args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence, 133 | args.optim_alpha, args.optim_eps, preprocess_obss) 134 | elif args.algo == "ppo": 135 | algo = torch_ac.PPOAlgo(envs, acmodel, device, args.frames_per_proc, args.discount, args.lr, args.gae_lambda, 136 | args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence, 137 | args.optim_eps, args.clip_eps, args.epochs, args.batch_size, preprocess_obss) 138 | else: 139 | raise ValueError("Incorrect algorithm name: {}".format(args.algo)) 140 | 141 | if "optimizer_state" in status: 142 | algo.optimizer.load_state_dict(status["optimizer_state"]) 143 | txt_logger.info("Optimizer loaded\n") 144 | 145 | # Train model 146 | 147 | num_frames = status["num_frames"] 148 | update = status["update"] 149 | start_time = time.time() 150 | 151 | while num_frames < args.frames: 152 | # Update model parameters 153 | update_start_time = time.time() 154 | exps, logs1 = algo.collect_experiences() 155 | logs2 = algo.update_parameters(exps) 156 | logs = {**logs1, **logs2} 157 | update_end_time = time.time() 158 | 159 | num_frames += logs["num_frames"] 160 | update += 1 161 | 162 | # Print logs 163 | 164 | if update % args.log_interval == 0: 165 | fps = logs["num_frames"] / (update_end_time - update_start_time) 166 | duration = int(time.time() - start_time) 167 | return_per_episode = utils.synthesize(logs["return_per_episode"]) 168 | rreturn_per_episode = utils.synthesize(logs["reshaped_return_per_episode"]) 169 | num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"]) 170 | 171 | header = ["update", "frames", "FPS", "duration"] 172 | data = [update, num_frames, fps, duration] 173 | header += 
["rreturn_" + key for key in rreturn_per_episode.keys()] 174 | data += rreturn_per_episode.values() 175 | header += ["num_frames_" + key for key in num_frames_per_episode.keys()] 176 | data += num_frames_per_episode.values() 177 | header += ["entropy", "value", "policy_loss", "value_loss", "grad_norm"] 178 | data += [logs["entropy"], logs["value"], logs["policy_loss"], logs["value_loss"], logs["grad_norm"]] 179 | 180 | txt_logger.info( 181 | "U {} | F {:06} | FPS {:04.0f} | D {} | rR:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {} | H {:.3f} | V {:.3f} | pL {:.3f} | vL {:.3f} | ∇ {:.3f}" 182 | .format(*data)) 183 | 184 | header += ["return_" + key for key in return_per_episode.keys()] 185 | data += return_per_episode.values() 186 | 187 | if status["num_frames"] == 0: 188 | csv_logger.writerow(header) 189 | csv_logger.writerow(data) 190 | csv_file.flush() 191 | 192 | for field, value in zip(header, data): 193 | tb_writer.add_scalar(field, value, num_frames) 194 | 195 | # Save status 196 | 197 | if args.save_interval > 0 and update % args.save_interval == 0: 198 | status = {"num_frames": num_frames, "update": update, 199 | "model_state": acmodel.state_dict(), "optimizer_state": algo.optimizer.state_dict()} 200 | if hasattr(preprocess_obss, "vocab"): 201 | status["vocab"] = preprocess_obss.vocab.vocab 202 | utils.save_status(status, model_dir) 203 | txt_logger.info("Status saved") 204 | -------------------------------------------------------------------------------- /scripts/visualize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy 3 | 4 | import utils 5 | from utils import device 6 | 7 | 8 | # Parse arguments 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--env", required=True, 12 | help="name of the environment to be run (REQUIRED)") 13 | parser.add_argument("--model", required=True, 14 | help="name of the trained model (REQUIRED)") 15 | 
parser.add_argument("--seed", type=int, default=0, 16 | help="random seed (default: 0)") 17 | parser.add_argument("--shift", type=int, default=0, 18 | help="number of times the environment is reset at the beginning (default: 0)") 19 | parser.add_argument("--argmax", action="store_true", default=False, 20 | help="select the action with highest probability (default: False)") 21 | parser.add_argument("--pause", type=float, default=0.1, 22 | help="pause duration between two consequent actions of the agent (default: 0.1)") 23 | parser.add_argument("--gif", type=str, default=None, 24 | help="store output as gif with the given filename") 25 | parser.add_argument("--episodes", type=int, default=1000000, 26 | help="number of episodes to visualize") 27 | parser.add_argument("--memory", action="store_true", default=False, 28 | help="add a LSTM to the model") 29 | parser.add_argument("--text", action="store_true", default=False, 30 | help="add a GRU to the model") 31 | 32 | args = parser.parse_args() 33 | 34 | # Set seed for all randomness sources 35 | 36 | utils.seed(args.seed) 37 | 38 | # Set device 39 | 40 | print(f"Device: {device}\n") 41 | 42 | # Load environment 43 | 44 | env = utils.make_env(args.env, args.seed, render_mode="human") 45 | for _ in range(args.shift): 46 | env.reset() 47 | print("Environment loaded\n") 48 | 49 | # Load agent 50 | 51 | model_dir = utils.get_model_dir(args.model) 52 | agent = utils.Agent(env.observation_space, env.action_space, model_dir, 53 | argmax=args.argmax, use_memory=args.memory, use_text=args.text) 54 | print("Agent loaded\n") 55 | 56 | # Run the agent 57 | 58 | if args.gif: 59 | from array2gif import write_gif 60 | 61 | frames = [] 62 | 63 | # Create a window to view the environment 64 | env.render() 65 | 66 | for episode in range(args.episodes): 67 | obs, _ = env.reset() 68 | 69 | while True: 70 | env.render() 71 | if args.gif: 72 | frames.append(numpy.moveaxis(env.get_frame(), 2, 0)) 73 | 74 | action = agent.get_action(obs) 75 | 
obs, reward, terminated, truncated, _ = env.step(action) 76 | done = terminated | truncated 77 | agent.analyze_feedback(reward, done) 78 | 79 | if done: 80 | break 81 | 82 | if args.gif: 83 | print("Saving gif... ", end="") 84 | write_gif(numpy.array(frames), args.gif+".gif", fps=1/args.pause) 85 | print("Done.") 86 | -------------------------------------------------------------------------------- /storage/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent import * 2 | from .env import * 3 | from .format import * 4 | from .other import * 5 | from .storage import * 6 | -------------------------------------------------------------------------------- /utils/agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import utils 4 | from .other import device 5 | from model import ACModel 6 | 7 | 8 | class Agent: 9 | """An agent. 10 | 11 | It is able: 12 | - to choose an action given an observation, 13 | - to analyze the feedback (i.e. 
reward and done state) of its action.""" 14 | 15 | def __init__(self, obs_space, action_space, model_dir, 16 | argmax=False, num_envs=1, use_memory=False, use_text=False): 17 | obs_space, self.preprocess_obss = utils.get_obss_preprocessor(obs_space) 18 | self.acmodel = ACModel(obs_space, action_space, use_memory=use_memory, use_text=use_text) 19 | self.argmax = argmax 20 | self.num_envs = num_envs 21 | 22 | if self.acmodel.recurrent: 23 | self.memories = torch.zeros(self.num_envs, self.acmodel.memory_size, device=device) 24 | 25 | self.acmodel.load_state_dict(utils.get_model_state(model_dir)) 26 | self.acmodel.to(device) 27 | self.acmodel.eval() 28 | if hasattr(self.preprocess_obss, "vocab"): 29 | self.preprocess_obss.vocab.load_vocab(utils.get_vocab(model_dir)) 30 | 31 | def get_actions(self, obss): 32 | preprocessed_obss = self.preprocess_obss(obss, device=device) 33 | 34 | with torch.no_grad(): 35 | if self.acmodel.recurrent: 36 | dist, _, self.memories = self.acmodel(preprocessed_obss, self.memories) 37 | else: 38 | dist, _ = self.acmodel(preprocessed_obss) 39 | 40 | if self.argmax: 41 | actions = dist.probs.max(1, keepdim=True)[1] 42 | else: 43 | actions = dist.sample() 44 | 45 | return actions.cpu().numpy() 46 | 47 | def get_action(self, obs): 48 | return self.get_actions([obs])[0] 49 | 50 | def analyze_feedbacks(self, rewards, dones): 51 | if self.acmodel.recurrent: 52 | masks = 1 - torch.tensor(dones, dtype=torch.float, device=device).unsqueeze(1) 53 | self.memories *= masks 54 | 55 | def analyze_feedback(self, reward, done): 56 | return self.analyze_feedbacks([reward], [done]) 57 | -------------------------------------------------------------------------------- /utils/env.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | 3 | 4 | def make_env(env_key, seed=None, render_mode=None): 5 | env = gym.make(env_key, render_mode=render_mode) 6 | env.reset(seed=seed) 7 | return env 8 | 
-------------------------------------------------------------------------------- /utils/format.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy 4 | import re 5 | import torch 6 | import torch_ac 7 | import gymnasium as gym 8 | 9 | import utils 10 | 11 | 12 | def get_obss_preprocessor(obs_space): 13 | # Check if obs_space is an image space 14 | if isinstance(obs_space, gym.spaces.Box): 15 | obs_space = {"image": obs_space.shape} 16 | 17 | def preprocess_obss(obss, device=None): 18 | return torch_ac.DictList({ 19 | "image": preprocess_images(obss, device=device) 20 | }) 21 | 22 | # Check if it is a MiniGrid observation space 23 | elif isinstance(obs_space, gym.spaces.Dict) and "image" in obs_space.spaces.keys(): 24 | obs_space = {"image": obs_space.spaces["image"].shape, "text": 100} 25 | 26 | vocab = Vocabulary(obs_space["text"]) 27 | 28 | def preprocess_obss(obss, device=None): 29 | return torch_ac.DictList({ 30 | "image": preprocess_images([obs["image"] for obs in obss], device=device), 31 | "text": preprocess_texts([obs["mission"] for obs in obss], vocab, device=device) 32 | }) 33 | 34 | preprocess_obss.vocab = vocab 35 | 36 | else: 37 | raise ValueError("Unknown observation space: " + str(obs_space)) 38 | 39 | return obs_space, preprocess_obss 40 | 41 | 42 | def preprocess_images(images, device=None): 43 | # Bug of Pytorch: very slow if not first converted to numpy array 44 | images = numpy.array(images) 45 | return torch.tensor(images, device=device, dtype=torch.float) 46 | 47 | 48 | def preprocess_texts(texts, vocab, device=None): 49 | var_indexed_texts = [] 50 | max_text_len = 0 51 | 52 | for text in texts: 53 | tokens = re.findall("([a-z]+)", text.lower()) 54 | var_indexed_text = numpy.array([vocab[token] for token in tokens]) 55 | var_indexed_texts.append(var_indexed_text) 56 | max_text_len = max(len(var_indexed_text), max_text_len) 57 | 58 | indexed_texts = 
numpy.zeros((len(texts), max_text_len)) 59 | 60 | for i, indexed_text in enumerate(var_indexed_texts): 61 | indexed_texts[i, :len(indexed_text)] = indexed_text 62 | 63 | return torch.tensor(indexed_texts, device=device, dtype=torch.long) 64 | 65 | 66 | class Vocabulary: 67 | """A mapping from tokens to ids with a capacity of `max_size` words. 68 | It can be saved in a `vocab.json` file.""" 69 | 70 | def __init__(self, max_size): 71 | self.max_size = max_size 72 | self.vocab = {} 73 | 74 | def load_vocab(self, vocab): 75 | self.vocab = vocab 76 | 77 | def __getitem__(self, token): 78 | if not token in self.vocab.keys(): 79 | if len(self.vocab) >= self.max_size: 80 | raise ValueError("Maximum vocabulary capacity reached") 81 | self.vocab[token] = len(self.vocab) + 1 82 | return self.vocab[token] 83 | -------------------------------------------------------------------------------- /utils/other.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy 3 | import torch 4 | import collections 5 | 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | 9 | 10 | def seed(seed): 11 | random.seed(seed) 12 | numpy.random.seed(seed) 13 | torch.manual_seed(seed) 14 | if torch.cuda.is_available(): 15 | torch.cuda.manual_seed_all(seed) 16 | 17 | 18 | def synthesize(array): 19 | d = collections.OrderedDict() 20 | d["mean"] = numpy.mean(array) 21 | d["std"] = numpy.std(array) 22 | d["min"] = numpy.amin(array) 23 | d["max"] = numpy.amax(array) 24 | return d 25 | -------------------------------------------------------------------------------- /utils/storage.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import torch 4 | import logging 5 | import sys 6 | 7 | import utils 8 | from .other import device 9 | 10 | 11 | def create_folders_if_necessary(path): 12 | dirname = os.path.dirname(path) 13 | if not os.path.isdir(dirname): 14 | 
os.makedirs(dirname) 15 | 16 | 17 | def get_storage_dir(): 18 | if "RL_STORAGE" in os.environ: 19 | return os.environ["RL_STORAGE"] 20 | return "storage" 21 | 22 | 23 | def get_model_dir(model_name): 24 | return os.path.join(get_storage_dir(), model_name) 25 | 26 | 27 | def get_status_path(model_dir): 28 | return os.path.join(model_dir, "status.pt") 29 | 30 | 31 | def get_status(model_dir): 32 | path = get_status_path(model_dir) 33 | return torch.load(path, map_location=device) 34 | 35 | 36 | def save_status(status, model_dir): 37 | path = get_status_path(model_dir) 38 | utils.create_folders_if_necessary(path) 39 | torch.save(status, path) 40 | 41 | 42 | def get_vocab(model_dir): 43 | return get_status(model_dir)["vocab"] 44 | 45 | 46 | def get_model_state(model_dir): 47 | return get_status(model_dir)["model_state"] 48 | 49 | 50 | def get_txt_logger(model_dir): 51 | path = os.path.join(model_dir, "log.txt") 52 | utils.create_folders_if_necessary(path) 53 | 54 | logging.basicConfig( 55 | level=logging.INFO, 56 | format="%(message)s", 57 | handlers=[ 58 | logging.FileHandler(filename=path), 59 | logging.StreamHandler(sys.stdout) 60 | ] 61 | ) 62 | 63 | return logging.getLogger() 64 | 65 | 66 | def get_csv_logger(model_dir): 67 | csv_path = os.path.join(model_dir, "log.csv") 68 | utils.create_folders_if_necessary(csv_path) 69 | csv_file = open(csv_path, "a") 70 | return csv_file, csv.writer(csv_file) 71 | --------------------------------------------------------------------------------