├── .gitignore
├── LICENSE
├── README-rsrc
├── doorkey.png
├── evaluate-terminal-logs.png
├── model.png
├── model.xml
├── train-tensorboard.png
├── train-terminal-logs.png
├── visualize-doorkey.gif
├── visualize-gotodoor.gif
├── visualize-keycorridor.gif
└── visualize-redbluedoors.gif
├── README.md
├── model.py
├── requirements.txt
├── scripts
├── evaluate.py
├── train.py
└── visualize.py
├── storage
└── .gitignore
└── utils
├── __init__.py
├── agent.py
├── env.py
├── format.py
├── other.py
└── storage.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__
2 | *egg-info
3 | .vscode
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Lucas Willems
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README-rsrc/doorkey.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/doorkey.png
--------------------------------------------------------------------------------
/README-rsrc/evaluate-terminal-logs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/evaluate-terminal-logs.png
--------------------------------------------------------------------------------
/README-rsrc/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/model.png
--------------------------------------------------------------------------------
/README-rsrc/model.xml:
--------------------------------------------------------------------------------
1 | 7Vtbj5s4FP41SLsPW8UYSHhMssnsQytVnV3t9tEBh9ACjohz66+vCXYwNpkyEweS0Y40Gvv4fs53bsZjwWl6eMrRevWJhDix7EF4sOCflm0De2izPwXlWFJcb1ASojwOeaeK8Bz/wJwoum3jEG9qHSkhCY3XdWJAsgwHtEZDeU729W5LktRXXaOIrzioCM8BSrDW7d84pKuSOnKl3n/hOFqJlcGAtyxQ8D3KyTbj61k2XJ5+yuYUibl4/80KhWQvkeDMgtOcEFqW0sMUJwVvBdvKcfMLred95zijbQaIERt6FGfHIWMFr5KcrkhEMpTMKurkdD5czDBgtRVNE1YErIgPMf2vIH9wee0r78S2kx+lpqL6lY/6hik9cgygLSWMVK37kZA176efTWyebPOA754Dj6I8wmfplLTiYNI4zpAnTFLMdsM65DhBNN7VMYA4lKJzv4qdrMA5eoG75RQ7lGz5pFOS7RhF53qdp/tVTPHzGp1OtWc6VufzRU7scE7x4cVD8lZbIJZrKBSI3Fd4B4K2krAOB9fzxdH4oqMwC8eFIrNakKDNJg4UqL0AqIvskVFhtwSFxA+3gR2C1ho7fIXPJGabq8ThKeIYKWwuMc5HyQqtTjT8xUQlD7SJTiI7H7uVFF1NinFa2FXbSxhzJpvtghWjokgFjU0pkQV1kVcUBQYMzbQu+A3NyXc8JQnJGSUjWWGRlnGSKCSUxFFWoIeBADP6pNCNmJn4MW9I4zA8mbMmZaurowl9cxS5+Lq+OQ34sg2om/Cqj2DkZUky5xkiPFoGmthZy9ifOrNJW7fg6QbA78sreJrezNIFDsM4i65zCwZ51x7YfwAF2KClI3EMIHuom6CMnfYdmqAlyagkQW84nvhzMxKE/Vmm0W0DAYMKIVsO765CB9dQ5KA4KMe9WeDgdJ902PfmkHwdVmBo2iOdhjLtQUepw7qQ4KY1noA3eBE2aiKh9GeFcgNvBYuvmYinL/+8Az955nMXfrLJrt5rBNhKfUCDWQZtUzrjEZ3YjZ7o63zvO9E/J4idJPqPm3ksRwEOGq3CHPhT32+N1KYbKac3pF68koL3h9Qur6SAaxSpZzzKaCzB2RaPAu1AxvpphlshFXaB1OYoAqhRh+PXp7gQxb4hoBDHlFRg9ZqE8RHSw9nYm3ieGbX0FckMXU0rb5Ue2jdxH7pCtVXJb9t0LRaPTlJ4SdFvpaaOrqZ2R2oKlWtq4Cpqai5JhH7nocPw7mKHBlGbzxJbu0j9s83H578/XRc3GOTc2yMN0DYmtv3r2WhfD2zdDt1ZdjZsMFFtb82M4xZe70UekuFub4ZCv5ZnSUaAKM7Y73X2woD2A+UTKWibZxjRfj3P6gCMv7oo7/kytDHzgH3Bt+F5woUEYQL+TxIuBIYNWcL5kZTpNAFen7vfu4EXxlzWENjbfWfD04/5NkmOvZt29eOF16Flh977R2HDMwrY3+s6/SHFfaDQ7xGE+pcIazq3im/C3IGtUaY+fpA913RmTUashtKCL9lis5Y6Nr+dKKjSrL+x5T9YQ7bVAXnNvVox4BUXcb/fqZM1EaCq6anfoS+1NQDtOsTOjRHzII99THhCNSBrwBAwgyFWrR6slzd91X8FwNlP
--------------------------------------------------------------------------------
/README-rsrc/train-tensorboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/train-tensorboard.png
--------------------------------------------------------------------------------
/README-rsrc/train-terminal-logs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/train-terminal-logs.png
--------------------------------------------------------------------------------
/README-rsrc/visualize-doorkey.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/visualize-doorkey.gif
--------------------------------------------------------------------------------
/README-rsrc/visualize-gotodoor.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/visualize-gotodoor.gif
--------------------------------------------------------------------------------
/README-rsrc/visualize-keycorridor.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/visualize-keycorridor.gif
--------------------------------------------------------------------------------
/README-rsrc/visualize-redbluedoors.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lcswillems/rl-starter-files/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/README-rsrc/visualize-redbluedoors.gif
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RL Starter Files
2 |
3 | RL starter files in order to immediately train, visualize and evaluate an agent **without writing a single line of code**.
4 |
5 |
6 |
7 |
8 |
9 | These files are suited for [`minigrid`](https://github.com/Farama-Foundation/Minigrid) environments and [`torch-ac`](https://github.com/lcswillems/torch-ac) RL algorithms. They are easy to adapt to other environments and RL algorithms.
10 |
11 | ## Features
12 |
13 | - **Script to train**, including:
14 | - Log in txt, CSV and Tensorboard
15 | - Save model
16 | - Stop and restart training
17 | - Use A2C or PPO algorithms
18 | - **Script to visualize**, including:
19 | - Act by sampling or argmax
20 | - Save as Gif
21 | - **Script to evaluate**, including:
22 | - Act by sampling or argmax
23 | - List the worst-performing episodes
24 |
25 | ## Installation
26 |
27 | 1. Clone this repository.
28 |
29 | 2. Install `minigrid` environments and `torch-ac` RL algorithms:
30 |
31 | ```
32 | pip3 install -r requirements.txt
33 | ```
34 |
35 | **Note:** If you want to modify `torch-ac` algorithms, you will need to rather install a cloned version, i.e.:
36 | ```
37 | git clone https://github.com/lcswillems/torch-ac.git
38 | cd torch-ac
39 | pip3 install -e .
40 | ```
41 |
42 | ## Example of use
43 |
44 | Train, visualize and evaluate an agent on the `MiniGrid-DoorKey-5x5-v0` environment:
45 |
46 | 
47 |
48 | 1. Train the agent on the `MiniGrid-DoorKey-5x5-v0` environment with PPO algorithm:
49 |
50 | ```
51 | python3 -m scripts.train --algo ppo --env MiniGrid-DoorKey-5x5-v0 --model DoorKey --save-interval 10 --frames 80000
52 | ```
53 |
54 | 
55 |
56 | 2. Visualize agent's behavior:
57 |
58 | ```
59 | python3 -m scripts.visualize --env MiniGrid-DoorKey-5x5-v0 --model DoorKey
60 | ```
61 |
62 | 
63 |
64 | 3. Evaluate agent's performance:
65 |
66 | ```
67 | python3 -m scripts.evaluate --env MiniGrid-DoorKey-5x5-v0 --model DoorKey
68 | ```
69 |
70 | 
71 |
72 | **Note:** More details on the commands are given below.
73 |
74 | ## Other examples
75 |
76 | ### Handle textual instructions
77 |
78 | In the `GoToDoor` environment, the agent receives an image along with a textual instruction. To handle the latter, add `--text` to the command:
79 |
80 | ```
81 | python3 -m scripts.train --algo ppo --env MiniGrid-GoToDoor-5x5-v0 --model GoToDoor --text --save-interval 10 --frames 1000000
82 | ```
83 |
84 | 
85 |
86 | ### Add memory
87 |
88 | In the `RedBlueDoors` environment, the agent has to open the red door then the blue one. To solve it efficiently, when it opens the red door, it has to remember it. To add memory to the agent, add `--recurrence X` to the command:
89 |
90 | ```
91 | python3 -m scripts.train --algo ppo --env MiniGrid-RedBlueDoors-6x6-v0 --model RedBlueDoors --recurrence 4 --save-interval 10 --frames 1000000
92 | ```
93 |
94 | 
95 |
96 | ## Files
97 |
98 | This package contains:
99 | - scripts to:
100 | - train an agent \
101 | in `scripts/train.py` ([more details](#scripts-train))
102 | - visualize agent's behavior \
103 | in `scripts/visualize.py` ([more details](#scripts-visualize))
104 | - evaluate agent's performances \
105 | in `scripts/evaluate.py` ([more details](#scripts-evaluate))
106 | - a default agent's model \
107 | in `model.py` ([more details](#model))
108 | - utilitarian classes and functions used by the scripts \
109 | in `utils`
110 |
111 | These files are suited for [`minigrid`](https://github.com/Farama-Foundation/Minigrid) environments and [`torch-ac`](https://github.com/lcswillems/torch-ac) RL algorithms. They are easy to adapt to other environments and RL algorithms by modifying:
112 | - `model.py`
113 | - `utils/format.py`
114 |
115 | scripts/train.py
116 |
117 | An example of use:
118 |
119 | ```bash
120 | python3 -m scripts.train --algo ppo --env MiniGrid-DoorKey-5x5-v0 --model DoorKey --save-interval 10 --frames 80000
121 | ```
122 |
123 | The script loads the model in `storage/DoorKey` or creates it if it doesn't exist, then trains it with the PPO algorithm on the MiniGrid DoorKey environment, and saves it every 10 updates in `storage/DoorKey`. It stops after 80 000 frames.
124 |
125 | **Note:** You can define a different storage location in the environment variable `PROJECT_STORAGE`.
126 |
127 | More generally, the script has 2 required arguments:
128 | - `--algo ALGO`: name of the RL algorithm used to train
129 | - `--env ENV`: name of the environment to train on
130 |
131 | and a bunch of optional arguments among which:
132 | - `--recurrence N`: gradient will be backpropagated over N timesteps. By default, N = 1. If N > 1, an LSTM is added to the model to have memory.
133 | - `--text`: a GRU is added to the model to handle text input.
134 | - ... (see more using `--help`)
135 |
136 | During training, logs are printed in your terminal (and saved in text and CSV format):
137 |
138 | 
139 |
140 | **Note:** `U` gives the update number, `F` the total number of frames, `FPS` the number of frames per second, `D` the total duration, `rR:μσmM` the mean, std, min and max reshaped return per episode, `F:μσmM` the mean, std, min and max number of frames per episode, `H` the entropy, `V` the value, `pL` the policy loss, `vL` the value loss and `∇` the gradient norm.
141 |
142 | During training, logs are also plotted in Tensorboard:
143 |
144 | 
145 |
146 | scripts/visualize.py
147 |
148 | An example of use:
149 |
150 | ```
151 | python3 -m scripts.visualize --env MiniGrid-DoorKey-5x5-v0 --model DoorKey
152 | ```
153 |
154 | 
155 |
156 | In this use case, the script displays how the model in `storage/DoorKey` behaves on the MiniGrid DoorKey environment.
157 |
158 | More generally, the script has 2 required arguments:
159 | - `--env ENV`: name of the environment to act on.
160 | - `--model MODEL`: name of the trained model.
161 |
162 | and a bunch of optional arguments among which:
163 | - `--argmax`: select the action with highest probability
164 | - ... (see more using `--help`)
165 |
166 | scripts/evaluate.py
167 |
168 | An example of use:
169 |
170 | ```
171 | python3 -m scripts.evaluate --env MiniGrid-DoorKey-5x5-v0 --model DoorKey
172 | ```
173 |
174 | 
175 |
176 | In this use case, the script prints in the terminal the performance among 100 episodes of the model in `storage/DoorKey`.
177 |
178 | More generally, the script has 2 required arguments:
179 | - `--env ENV`: name of the environment to act on.
180 | - `--model MODEL`: name of the trained model.
181 |
182 | and a bunch of optional arguments among which:
183 | - `--episodes N`: number of episodes of evaluation. By default, N = 100.
184 | - ... (see more using `--help`)
185 |
186 | model.py
187 |
188 | The default model is described by the following schema:
189 |
190 | 
191 |
192 | By default, the memory part (in red) and the language part (in blue) are disabled. They can be enabled by setting to `True` the `use_memory` and `use_text` parameters of the model constructor.
193 |
194 | This model can be easily adapted to your needs.
195 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions.categorical import Categorical
5 | import torch_ac
6 |
7 |
# Function from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py
def init_params(m):
    """Initialize a module in place: Linear layers get Gaussian weights
    rescaled to unit-norm rows, and zero biases; other modules are untouched."""
    if "Linear" not in m.__class__.__name__:
        return
    weight = m.weight.data
    weight.normal_(0, 1)
    # Rescale each output row to unit L2 norm.
    weight *= 1 / torch.sqrt(weight.pow(2).sum(1, keepdim=True))
    if m.bias is not None:
        m.bias.data.fill_(0)
16 |
17 |
class ACModel(nn.Module, torch_ac.RecurrentACModel):
    """Actor-critic model: a CNN image encoder plus optional LSTM memory
    and an optional GRU encoder for textual instructions."""

    def __init__(self, obs_space, action_space, use_memory=False, use_text=False):
        super().__init__()

        # Which optional components are enabled.
        self.use_text = use_text
        self.use_memory = use_memory

        # CNN over the (H, W, 3) image observation.
        # NOTE: submodules are created in the same order as the reference
        # implementation so that random initialization is reproducible.
        self.image_conv = nn.Sequential(
            nn.Conv2d(3, 16, (2, 2)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(16, 32, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(32, 64, (2, 2)),
            nn.ReLU()
        )
        height, width = obs_space["image"][0], obs_space["image"][1]
        # Spatial size after the conv/pool stack, times 64 channels.
        self.image_embedding_size = ((height - 1) // 2 - 2) * ((width - 1) // 2 - 2) * 64

        # Optional LSTM memory over image embeddings.
        if self.use_memory:
            self.memory_rnn = nn.LSTMCell(self.image_embedding_size, self.semi_memory_size)

        # Optional GRU encoder for the textual instruction.
        if self.use_text:
            self.word_embedding_size = 32
            self.word_embedding = nn.Embedding(obs_space["text"], self.word_embedding_size)
            self.text_embedding_size = 128
            self.text_rnn = nn.GRU(self.word_embedding_size, self.text_embedding_size, batch_first=True)

        # Size of the embedding fed to the actor/critic heads.
        self.embedding_size = self.semi_memory_size
        if self.use_text:
            self.embedding_size += self.text_embedding_size

        # Policy head.
        self.actor = nn.Sequential(
            nn.Linear(self.embedding_size, 64),
            nn.Tanh(),
            nn.Linear(64, action_space.n)
        )

        # Value head.
        self.critic = nn.Sequential(
            nn.Linear(self.embedding_size, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

        # Custom weight initialization (unit-norm Linear rows, zero biases).
        self.apply(init_params)

    @property
    def memory_size(self):
        # Hidden and cell state are stored side by side in one tensor.
        return 2 * self.semi_memory_size

    @property
    def semi_memory_size(self):
        return self.image_embedding_size

    def forward(self, obs, memory):
        # NHWC -> NCHW for the convolutions.
        img = obs.image.permute(0, 3, 1, 2)
        img = self.image_conv(img)
        img = img.reshape(img.shape[0], -1)

        if self.use_memory:
            half = self.semi_memory_size
            state = self.memory_rnn(img, (memory[:, :half], memory[:, half:]))
            embedding = state[0]
            memory = torch.cat(state, dim=1)
        else:
            embedding = img

        if self.use_text:
            embedding = torch.cat((embedding, self._get_embed_text(obs.text)), dim=1)

        logits = self.actor(embedding)
        dist = Categorical(logits=F.log_softmax(logits, dim=1))

        value = self.critic(embedding).squeeze(1)

        return dist, value, memory

    def _get_embed_text(self, text):
        # The GRU's last hidden state summarizes the instruction.
        _, hidden = self.text_rnn(self.word_embedding(text))
        return hidden[-1]
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch-ac>=1.4.0
2 | minigrid>=2.2.0
3 | tensorboardX>=1.6
4 | numpy>=1.3
5 | gymnasium>=0.26
6 |
--------------------------------------------------------------------------------
/scripts/evaluate.py:
--------------------------------------------------------------------------------
"""Evaluate a trained agent over several parallel environments and report
aggregate return/frame statistics plus the worst episodes."""

import argparse
import time
import torch
from torch_ac.utils.penv import ParallelEnv

import utils
from utils import device


# Parse arguments

parser = argparse.ArgumentParser()
parser.add_argument("--env", required=True,
                    help="name of the environment (REQUIRED)")
parser.add_argument("--model", required=True,
                    help="name of the trained model (REQUIRED)")
parser.add_argument("--episodes", type=int, default=100,
                    help="number of episodes of evaluation (default: 100)")
parser.add_argument("--seed", type=int, default=0,
                    help="random seed (default: 0)")
parser.add_argument("--procs", type=int, default=16,
                    help="number of processes (default: 16)")
parser.add_argument("--argmax", action="store_true", default=False,
                    help="action with highest probability is selected")
parser.add_argument("--worst-episodes-to-show", type=int, default=10,
                    help="how many worst episodes to show")
parser.add_argument("--memory", action="store_true", default=False,
                    help="add a LSTM to the model")
parser.add_argument("--text", action="store_true", default=False,
                    help="add a GRU to the model")

if __name__ == "__main__":
    args = parser.parse_args()

    # Seed every source of randomness.
    utils.seed(args.seed)

    print(f"Device: {device}\n")

    # One environment per process, each with a distinct seed, wrapped
    # into a single vectorized environment.
    env = ParallelEnv([utils.make_env(args.env, args.seed + 10000 * rank)
                       for rank in range(args.procs)])
    print("Environments loaded\n")

    # Restore the trained agent.
    model_dir = utils.get_model_dir(args.model)
    agent = utils.Agent(env.observation_space, env.action_space, model_dir,
                        argmax=args.argmax, num_envs=args.procs,
                        use_memory=args.memory, use_text=args.text)
    print("Agent loaded\n")

    # Per-episode statistics collected during the run.
    logs = {"num_frames_per_episode": [], "return_per_episode": []}

    start_time = time.time()

    obss = env.reset()

    finished = 0
    episode_return = torch.zeros(args.procs, device=device)
    episode_frames = torch.zeros(args.procs, device=device)

    while finished < args.episodes:
        actions = agent.get_actions(obss)
        obss, rewards, terminateds, truncateds, _ = env.step(actions)
        dones = tuple(term | trunc for term, trunc in zip(terminateds, truncateds))
        agent.analyze_feedbacks(rewards, dones)

        episode_return += torch.tensor(rewards, device=device, dtype=torch.float)
        episode_frames += torch.ones(args.procs, device=device)

        # Record finished episodes and zero their running counters.
        for rank, done in enumerate(dones):
            if done:
                finished += 1
                logs["return_per_episode"].append(episode_return[rank].item())
                logs["num_frames_per_episode"].append(episode_frames[rank].item())

        keep = 1 - torch.tensor(dones, device=device, dtype=torch.float)
        episode_return *= keep
        episode_frames *= keep

    end_time = time.time()

    # Aggregate and print evaluation statistics.
    num_frames = sum(logs["num_frames_per_episode"])
    fps = num_frames / (end_time - start_time)
    duration = int(end_time - start_time)
    return_per_episode = utils.synthesize(logs["return_per_episode"])
    num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"])

    print("F {} | FPS {:.0f} | D {} | R:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {}"
          .format(num_frames, fps, duration,
                  *return_per_episode.values(),
                  *num_frames_per_episode.values()))

    # Show the episodes with the lowest returns.
    n = args.worst_episodes_to_show
    if n > 0:
        print("\n{} worst episodes:".format(n))

        ranked = sorted(range(len(logs["return_per_episode"])),
                        key=lambda idx: logs["return_per_episode"][idx])
        for idx in ranked[:n]:
            print("- episode {}: R={}, F={}".format(idx, logs["return_per_episode"][idx], logs["num_frames_per_episode"][idx]))
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
"""Train an agent with A2C or PPO on a Minigrid environment.

Example:
    python3 -m scripts.train --algo ppo --env MiniGrid-DoorKey-5x5-v0 --model DoorKey
"""

import argparse
import time
import datetime
import torch_ac
import tensorboardX
import sys

import utils
from utils import device
from model import ACModel


# Parse arguments

parser = argparse.ArgumentParser()

# General parameters
parser.add_argument("--algo", required=True,
                    help="algorithm to use: a2c | ppo (REQUIRED)")
parser.add_argument("--env", required=True,
                    help="name of the environment to train on (REQUIRED)")
parser.add_argument("--model", default=None,
                    help="name of the model (default: {ENV}_{ALGO}_{TIME})")
parser.add_argument("--seed", type=int, default=1,
                    help="random seed (default: 1)")
parser.add_argument("--log-interval", type=int, default=1,
                    help="number of updates between two logs (default: 1)")
parser.add_argument("--save-interval", type=int, default=10,
                    help="number of updates between two saves (default: 10, 0 means no saving)")
parser.add_argument("--procs", type=int, default=16,
                    help="number of processes (default: 16)")
parser.add_argument("--frames", type=int, default=10**7,
                    help="number of frames of training (default: 1e7)")

# Parameters for main algorithm
parser.add_argument("--epochs", type=int, default=4,
                    help="number of epochs for PPO (default: 4)")
parser.add_argument("--batch-size", type=int, default=256,
                    help="batch size for PPO (default: 256)")
parser.add_argument("--frames-per-proc", type=int, default=None,
                    help="number of frames per process before update (default: 5 for A2C and 128 for PPO)")
parser.add_argument("--discount", type=float, default=0.99,
                    help="discount factor (default: 0.99)")
parser.add_argument("--lr", type=float, default=0.001,
                    help="learning rate (default: 0.001)")
parser.add_argument("--gae-lambda", type=float, default=0.95,
                    help="lambda coefficient in GAE formula (default: 0.95, 1 means no gae)")
parser.add_argument("--entropy-coef", type=float, default=0.01,
                    help="entropy term coefficient (default: 0.01)")
parser.add_argument("--value-loss-coef", type=float, default=0.5,
                    help="value loss term coefficient (default: 0.5)")
parser.add_argument("--max-grad-norm", type=float, default=0.5,
                    help="maximum norm of gradient (default: 0.5)")
parser.add_argument("--optim-eps", type=float, default=1e-8,
                    help="Adam and RMSprop optimizer epsilon (default: 1e-8)")
parser.add_argument("--optim-alpha", type=float, default=0.99,
                    help="RMSprop optimizer alpha (default: 0.99)")
parser.add_argument("--clip-eps", type=float, default=0.2,
                    help="clipping epsilon for PPO (default: 0.2)")
parser.add_argument("--recurrence", type=int, default=1,
                    help="number of time-steps gradient is backpropagated (default: 1). If > 1, a LSTM is added to the model to have memory.")
parser.add_argument("--text", action="store_true", default=False,
                    help="add a GRU to the model to handle text input")

if __name__ == "__main__":
    args = parser.parse_args()

    # Memory (LSTM) is only useful when the gradient is backpropagated
    # over more than one time-step.
    args.mem = args.recurrence > 1

    # Set run dir

    date = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
    default_model_name = f"{args.env}_{args.algo}_seed{args.seed}_{date}"

    model_name = args.model or default_model_name
    model_dir = utils.get_model_dir(model_name)

    # Load loggers and Tensorboard writer

    txt_logger = utils.get_txt_logger(model_dir)
    csv_file, csv_logger = utils.get_csv_logger(model_dir)
    tb_writer = tensorboardX.SummaryWriter(model_dir)

    # Log command and all script arguments

    txt_logger.info("{}\n".format(" ".join(sys.argv)))
    txt_logger.info("{}\n".format(args))

    # Set seed for all randomness sources

    utils.seed(args.seed)

    # Set device

    txt_logger.info(f"Device: {device}\n")

    # Load environments; each process gets a distinct seed so the
    # parallel environments produce different episodes.

    envs = []
    for i in range(args.procs):
        envs.append(utils.make_env(args.env, args.seed + 10000 * i))
    txt_logger.info("Environments loaded\n")

    # Load training status; start from scratch if there is no saved one,
    # which makes training resumable.

    try:
        status = utils.get_status(model_dir)
    except OSError:
        status = {"num_frames": 0, "update": 0}
    txt_logger.info("Training status loaded\n")

    # Load observations preprocessor

    obs_space, preprocess_obss = utils.get_obss_preprocessor(envs[0].observation_space)
    if "vocab" in status:
        # Restore the text vocabulary so word indices stay stable across restarts.
        preprocess_obss.vocab.load_vocab(status["vocab"])
    txt_logger.info("Observations preprocessor loaded")

    # Load model

    acmodel = ACModel(obs_space, envs[0].action_space, args.mem, args.text)
    if "model_state" in status:
        acmodel.load_state_dict(status["model_state"])
    acmodel.to(device)
    txt_logger.info("Model loaded\n")
    txt_logger.info("{}\n".format(acmodel))

    # Load algo

    if args.algo == "a2c":
        algo = torch_ac.A2CAlgo(envs, acmodel, device, args.frames_per_proc, args.discount, args.lr, args.gae_lambda,
                                args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence,
                                args.optim_alpha, args.optim_eps, preprocess_obss)
    elif args.algo == "ppo":
        algo = torch_ac.PPOAlgo(envs, acmodel, device, args.frames_per_proc, args.discount, args.lr, args.gae_lambda,
                                args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence,
                                args.optim_eps, args.clip_eps, args.epochs, args.batch_size, preprocess_obss)
    else:
        raise ValueError("Incorrect algorithm name: {}".format(args.algo))

    if "optimizer_state" in status:
        algo.optimizer.load_state_dict(status["optimizer_state"])
    txt_logger.info("Optimizer loaded\n")

    # Train model

    num_frames = status["num_frames"]
    update = status["update"]
    start_time = time.time()

    while num_frames < args.frames:
        # Update model parameters
        update_start_time = time.time()
        exps, logs1 = algo.collect_experiences()
        logs2 = algo.update_parameters(exps)
        logs = {**logs1, **logs2}
        update_end_time = time.time()

        num_frames += logs["num_frames"]
        update += 1

        # Print logs

        if update % args.log_interval == 0:
            fps = logs["num_frames"] / (update_end_time - update_start_time)
            duration = int(time.time() - start_time)
            return_per_episode = utils.synthesize(logs["return_per_episode"])
            rreturn_per_episode = utils.synthesize(logs["reshaped_return_per_episode"])
            num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"])

            header = ["update", "frames", "FPS", "duration"]
            data = [update, num_frames, fps, duration]
            header += ["rreturn_" + key for key in rreturn_per_episode.keys()]
            data += rreturn_per_episode.values()
            header += ["num_frames_" + key for key in num_frames_per_episode.keys()]
            data += num_frames_per_episode.values()
            header += ["entropy", "value", "policy_loss", "value_loss", "grad_norm"]
            data += [logs["entropy"], logs["value"], logs["policy_loss"], logs["value_loss"], logs["grad_norm"]]

            txt_logger.info(
                "U {} | F {:06} | FPS {:04.0f} | D {} | rR:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {} | H {:.3f} | V {:.3f} | pL {:.3f} | vL {:.3f} | ∇ {:.3f}"
                .format(*data))

            # The raw (non-reshaped) return columns are appended after the
            # terminal line is printed: they only go to CSV/Tensorboard.
            header += ["return_" + key for key in return_per_episode.keys()]
            data += return_per_episode.values()

            # Write the CSV header only once, at the very start of a fresh run.
            if status["num_frames"] == 0:
                csv_logger.writerow(header)
            csv_logger.writerow(data)
            csv_file.flush()

            for field, value in zip(header, data):
                tb_writer.add_scalar(field, value, num_frames)

        # Save status

        if args.save_interval > 0 and update % args.save_interval == 0:
            status = {"num_frames": num_frames, "update": update,
                      "model_state": acmodel.state_dict(), "optimizer_state": algo.optimizer.state_dict()}
            if hasattr(preprocess_obss, "vocab"):
                # Persist the text vocabulary alongside the model weights.
                status["vocab"] = preprocess_obss.vocab.vocab
            utils.save_status(status, model_dir)
            txt_logger.info("Status saved")
--------------------------------------------------------------------------------
/scripts/visualize.py:
--------------------------------------------------------------------------------
"""Watch a trained agent act in an environment, optionally saving a gif."""

import argparse
import numpy

import utils
from utils import device


# Parse arguments

parser = argparse.ArgumentParser()
parser.add_argument("--env", required=True,
                    help="name of the environment to be run (REQUIRED)")
parser.add_argument("--model", required=True,
                    help="name of the trained model (REQUIRED)")
parser.add_argument("--seed", type=int, default=0,
                    help="random seed (default: 0)")
parser.add_argument("--shift", type=int, default=0,
                    help="number of times the environment is reset at the beginning (default: 0)")
parser.add_argument("--argmax", action="store_true", default=False,
                    help="select the action with highest probability (default: False)")
parser.add_argument("--pause", type=float, default=0.1,
                    help="pause duration between two consequent actions of the agent (default: 0.1)")
parser.add_argument("--gif", type=str, default=None,
                    help="store output as gif with the given filename")
parser.add_argument("--episodes", type=int, default=1000000,
                    help="number of episodes to visualize")
parser.add_argument("--memory", action="store_true", default=False,
                    help="add a LSTM to the model")
parser.add_argument("--text", action="store_true", default=False,
                    help="add a GRU to the model")

args = parser.parse_args()

# Seed every source of randomness.

utils.seed(args.seed)

print(f"Device: {device}\n")

# Build the environment and skip ahead `--shift` resets.

env = utils.make_env(args.env, args.seed, render_mode="human")
for _ in range(args.shift):
    env.reset()
print("Environment loaded\n")

# Restore the trained agent.

model_dir = utils.get_model_dir(args.model)
agent = utils.Agent(env.observation_space, env.action_space, model_dir,
                    argmax=args.argmax, use_memory=args.memory, use_text=args.text)
print("Agent loaded\n")

# Run the agent, collecting rendered frames when a gif was requested.

if args.gif:
    from array2gif import write_gif

    frames = []

# Create a window to view the environment
env.render()

for episode in range(args.episodes):
    obs, _ = env.reset()

    done = False
    while not done:
        env.render()
        if args.gif:
            # array2gif expects channel-first frames.
            frames.append(numpy.moveaxis(env.get_frame(), 2, 0))

        action = agent.get_action(obs)
        obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated | truncated
        agent.analyze_feedback(reward, done)

if args.gif:
    print("Saving gif... ", end="")
    write_gif(numpy.array(frames), args.gif+".gif", fps=1/args.pause)
    print("Done.")
--------------------------------------------------------------------------------
/storage/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .agent import *
2 | from .env import *
3 | from .format import *
4 | from .other import *
5 | from .storage import *
6 |
--------------------------------------------------------------------------------
/utils/agent.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import utils
4 | from .other import device
5 | from model import ACModel
6 |
7 |
class Agent:
    """An agent.

    It is able:
    - to choose an action given an observation,
    - to analyze the feedback (i.e. reward and done state) of its action."""

    def __init__(self, obs_space, action_space, model_dir,
                 argmax=False, num_envs=1, use_memory=False, use_text=False):
        obs_space, self.preprocess_obss = utils.get_obss_preprocessor(obs_space)
        self.acmodel = ACModel(obs_space, action_space, use_memory=use_memory, use_text=use_text)
        self.argmax = argmax
        self.num_envs = num_envs

        # One memory row per parallel environment; only recurrent models use it.
        if self.acmodel.recurrent:
            self.memories = torch.zeros(self.num_envs, self.acmodel.memory_size, device=device)

        self.acmodel.load_state_dict(utils.get_model_state(model_dir))
        self.acmodel.to(device)
        self.acmodel.eval()
        if hasattr(self.preprocess_obss, "vocab"):
            self.preprocess_obss.vocab.load_vocab(utils.get_vocab(model_dir))

    def get_actions(self, obss):
        """Return one action per observation in `obss` as a numpy array of shape (n,)."""
        preprocessed_obss = self.preprocess_obss(obss, device=device)

        with torch.no_grad():
            if self.acmodel.recurrent:
                dist, _, self.memories = self.acmodel(preprocessed_obss, self.memories)
            else:
                dist, _ = self.acmodel(preprocessed_obss)

        if self.argmax:
            # Fix: the previous `dist.probs.max(1, keepdim=True)[1]` produced
            # shape (n, 1) while `dist.sample()` produces shape (n,), so
            # `get_action` returned a 1-element array instead of a scalar in
            # argmax mode. `argmax(dim=1)` makes both paths consistent.
            actions = dist.probs.argmax(dim=1)
        else:
            actions = dist.sample()

        return actions.cpu().numpy()

    def get_action(self, obs):
        """Return the action for a single observation."""
        return self.get_actions([obs])[0]

    def analyze_feedbacks(self, rewards, dones):
        """Reset the memories of the environments whose episode just ended.

        `rewards` is currently unused; it is kept for interface symmetry."""
        if self.acmodel.recurrent:
            # Zero out memory rows where done is True.
            masks = 1 - torch.tensor(dones, dtype=torch.float, device=device).unsqueeze(1)
            self.memories *= masks

    def analyze_feedback(self, reward, done):
        """Single-environment variant of `analyze_feedbacks`."""
        return self.analyze_feedbacks([reward], [done])
57 |
--------------------------------------------------------------------------------
/utils/env.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 |
3 |
def make_env(env_key, seed=None, render_mode=None):
    """Create the gymnasium environment `env_key` and seed it via an initial reset."""
    environment = gym.make(env_key, render_mode=render_mode)
    environment.reset(seed=seed)
    return environment
8 |
--------------------------------------------------------------------------------
/utils/format.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy
4 | import re
5 | import torch
6 | import torch_ac
7 | import gymnasium as gym
8 |
9 | import utils
10 |
11 |
def get_obss_preprocessor(obs_space):
    """Build a preprocessing function for `obs_space`.

    Returns a pair `(obs_space_descr, preprocess_obss)`:
    `obs_space_descr` is a dict describing the processed observation shapes,
    and `preprocess_obss` turns a list of raw observations into a
    `torch_ac.DictList` of tensors."""
    if isinstance(obs_space, gym.spaces.Box):
        # Plain image observations.
        obs_space = {"image": obs_space.shape}

        def preprocess_obss(obss, device=None):
            images = preprocess_images(obss, device=device)
            return torch_ac.DictList({"image": images})

    elif isinstance(obs_space, gym.spaces.Dict) and "image" in obs_space.spaces:
        # MiniGrid-style observations: an image plus a "mission" string.
        # The text vocabulary is capped at 100 distinct tokens.
        obs_space = {"image": obs_space.spaces["image"].shape, "text": 100}
        vocab = Vocabulary(obs_space["text"])

        def preprocess_obss(obss, device=None):
            images = preprocess_images([obs["image"] for obs in obss], device=device)
            texts = preprocess_texts([obs["mission"] for obs in obss], vocab, device=device)
            return torch_ac.DictList({"image": images, "text": texts})

        # Expose the vocabulary so callers can save/restore it with the model.
        preprocess_obss.vocab = vocab

    else:
        raise ValueError("Unknown observation space: " + str(obs_space))

    return obs_space, preprocess_obss
40 |
41 |
def preprocess_images(images, device=None):
    """Convert a batch of images into a float tensor on `device`."""
    # Converting through numpy first works around a very slow torch
    # conversion from nested Python lists.
    return torch.tensor(numpy.array(images), device=device, dtype=torch.float)
46 |
47 |
def preprocess_texts(texts, vocab, device=None):
    """Tokenize each text, map tokens through `vocab`, and return a padded long tensor.

    Rows are zero-padded on the right up to the longest tokenized text."""
    tokenized = [
        numpy.array([vocab[token] for token in re.findall("([a-z]+)", text.lower())])
        for text in texts
    ]
    max_len = max((len(ids) for ids in tokenized), default=0)

    padded = numpy.zeros((len(texts), max_len))
    for row, ids in zip(padded, tokenized):
        row[:len(ids)] = ids

    return torch.tensor(padded, device=device, dtype=torch.long)
64 |
65 |
class Vocabulary:
    """A mapping from tokens to ids with a capacity of `max_size` words.
    It can be saved in a `vocab.json` file."""

    def __init__(self, max_size):
        self.max_size = max_size  # maximum number of distinct tokens
        self.vocab = {}  # token -> id; ids start at 1, leaving 0 for padding

    def load_vocab(self, vocab):
        """Replace the current token->id mapping with `vocab`."""
        self.vocab = vocab

    def __getitem__(self, token):
        """Return the id of `token`, registering it first if it is unseen.

        Raises ValueError when a new token would exceed `max_size`."""
        if token not in self.vocab:
            if len(self.vocab) >= self.max_size:
                raise ValueError("Maximum vocabulary capacity reached")
            self.vocab[token] = len(self.vocab) + 1
        return self.vocab[token]
83 |
--------------------------------------------------------------------------------
/utils/other.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy
3 | import torch
4 | import collections
5 |
6 |
# Default compute device for the whole project: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8 |
9 |
def seed(seed):
    """Seed every randomness source (random, numpy, torch, CUDA) with `seed`."""
    for seeder in (random.seed, numpy.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
16 |
17 |
def synthesize(array):
    """Summarize `array` as an ordered dict of mean, std, min and max."""
    stats = (
        ("mean", numpy.mean),
        ("std", numpy.std),
        ("min", numpy.amin),
        ("max", numpy.amax),
    )
    return collections.OrderedDict((name, fn(array)) for name, fn in stats)
25 |
--------------------------------------------------------------------------------
/utils/storage.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | import torch
4 | import logging
5 | import sys
6 |
7 | import utils
8 | from .other import device
9 |
10 |
def create_folders_if_necessary(path):
    """Ensure that the parent directory of file path `path` exists.

    A bare filename (empty dirname) needs no directory and is left alone —
    the previous `makedirs("")` call would have raised for that case."""
    dirname = os.path.dirname(path)
    if dirname:
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(dirname, exist_ok=True)
15 |
16 |
def get_storage_dir():
    """Return the storage root: $RL_STORAGE when set, otherwise "storage"."""
    return os.environ.get("RL_STORAGE", "storage")
21 |
22 |
def get_model_dir(model_name):
    """Return the directory under the storage root that holds `model_name`."""
    storage = get_storage_dir()
    return os.path.join(storage, model_name)
25 |
26 |
def get_status_path(model_dir):
    """Return the path of the serialized training status inside `model_dir`."""
    filename = "status.pt"
    return os.path.join(model_dir, filename)
29 |
30 |
def get_status(model_dir):
    """Load the training status saved in `model_dir`, mapped onto the current device."""
    return torch.load(get_status_path(model_dir), map_location=device)
34 |
35 |
def save_status(status, model_dir):
    """Serialize `status` into `model_dir`, creating the folder if needed."""
    status_path = get_status_path(model_dir)
    utils.create_folders_if_necessary(status_path)
    torch.save(status, status_path)
40 |
41 |
def get_vocab(model_dir):
    """Return the vocabulary stored in the model's saved status."""
    status = get_status(model_dir)
    return status["vocab"]
44 |
45 |
def get_model_state(model_dir):
    """Return the model state dict stored in the model's saved status."""
    status = get_status(model_dir)
    return status["model_state"]
48 |
49 |
def get_txt_logger(model_dir):
    """Configure root logging to write to `model_dir`/log.txt and stdout,
    then return the root logger.

    NOTE(review): logging.basicConfig is a no-op once the root logger has
    handlers, so only the first call in a process takes effect."""
    path = os.path.join(model_dir, "log.txt")
    utils.create_folders_if_necessary(path)

    handlers = [
        logging.FileHandler(filename=path),
        logging.StreamHandler(sys.stdout),
    ]
    logging.basicConfig(level=logging.INFO, format="%(message)s", handlers=handlers)

    return logging.getLogger()
64 |
65 |
def get_csv_logger(model_dir):
    """Return `(file, csv.writer)` appending to `model_dir`/log.csv.

    The caller owns the returned file handle and must close it."""
    csv_path = os.path.join(model_dir, "log.csv")
    utils.create_folders_if_necessary(csv_path)
    # newline="" is the documented requirement for files passed to csv.writer;
    # without it, extra blank lines appear between rows on Windows.
    csv_file = open(csv_path, "a", newline="")
    return csv_file, csv.writer(csv_file)
71 |
--------------------------------------------------------------------------------