├── CMakeLists.txt
├── Models.h
├── ProximalPolicyOptimization.h
├── README.md
├── TestEnvironment.h
├── TestPPO.cpp
├── TrainPPO.cpp
├── data
│   ├── data.csv
│   └── data_test.csv
├── img
│   ├── epoch_1.gif
│   ├── epoch_10.gif
│   └── test_mode.gif
└── plot.py

/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(ppo)

find_package(Eigen3 REQUIRED)
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

include_directories(${CMAKE_CURRENT_SOURCE_DIR})

add_executable(train_ppo TrainPPO.cpp)
target_link_libraries(train_ppo ${TORCH_LIBRARIES} Eigen3::Eigen)
target_include_directories(train_ppo PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
set_property(TARGET train_ppo PROPERTY CXX_STANDARD 14)

add_executable(test_ppo TestPPO.cpp)
target_link_libraries(test_ppo ${TORCH_LIBRARIES} Eigen3::Eigen)
target_include_directories(test_ppo PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
set_property(TARGET test_ppo PROPERTY CXX_STANDARD 14)

--------------------------------------------------------------------------------
/Models.h:
--------------------------------------------------------------------------------
#pragma once

#include <torch/torch.h>
#include <math.h>

// Network model for Proximal Policy Optimization on Incy Wincy.
struct ActorCriticImpl : public torch::nn::Module
{
    // Actor.
    torch::nn::Linear a_lin1_, a_lin2_, a_lin3_;
    torch::Tensor mu_;
    torch::Tensor log_std_;

    // Critic.
    torch::nn::Linear c_lin1_, c_lin2_, c_lin3_, c_val_;

    ActorCriticImpl(int64_t n_in, int64_t n_out, double std)
        : // Actor.
          a_lin1_(torch::nn::Linear(n_in, 16)),
          a_lin2_(torch::nn::Linear(16, 32)),
          a_lin3_(torch::nn::Linear(32, n_out)),
          mu_(torch::full(n_out, 0.)),
          log_std_(torch::full(n_out, std)),

          // Critic.
          c_lin1_(torch::nn::Linear(n_in, 16)),
          c_lin2_(torch::nn::Linear(16, 32)),
          c_lin3_(torch::nn::Linear(32, n_out)),
          c_val_(torch::nn::Linear(n_out, 1))
    {
        // Register the modules.
        register_module("a_lin1", a_lin1_);
        register_module("a_lin2", a_lin2_);
        register_module("a_lin3", a_lin3_);
        register_parameter("log_std", log_std_);

        register_module("c_lin1", c_lin1_);
        register_module("c_lin2", c_lin2_);
        register_module("c_lin3", c_lin3_);
        register_module("c_val", c_val_);
    }

    // Forward pass.
    auto forward(torch::Tensor x) -> std::tuple<torch::Tensor, torch::Tensor>
    {
        // Actor.
        mu_ = torch::relu(a_lin1_->forward(x));
        mu_ = torch::relu(a_lin2_->forward(mu_));
        mu_ = torch::tanh(a_lin3_->forward(mu_));

        // Critic.
        torch::Tensor val = torch::relu(c_lin1_->forward(x));
        val = torch::relu(c_lin2_->forward(val));
        val = torch::tanh(c_lin3_->forward(val));
        val = c_val_->forward(val);

        if (this->is_training())
        {
            torch::NoGradGuard no_grad;

            torch::Tensor action = at::normal(mu_, log_std_.exp().expand_as(mu_));
            return std::make_tuple(action, val);
        }
        else
        {
            return std::make_tuple(mu_, val);
        }
    }

    // Initialize network.
    void normal(double mu, double std)
    {
        torch::NoGradGuard no_grad;

        for (auto& p: this->parameters())
        {
            p.normal_(mu, std);
        }
    }

    auto entropy() -> torch::Tensor
    {
        // Differential entropy of normal distribution.
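        // For a Gaussian with standard deviation sigma this is
        // 0.5*log(2*pi*e*sigma^2) = 0.5 + 0.5*log(2*pi) + log(sigma),
        // evaluated element-wise on the learned log standard deviation.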
        // For reference https://pytorch.org/docs/stable/_modules/torch/distributions/normal.html#Normal
        return 0.5 + 0.5*log(2*M_PI) + log_std_;
    }

    auto log_prob(torch::Tensor action) -> torch::Tensor
    {
        // Logarithmic probability of taken action, given the current distribution.
        torch::Tensor var = (log_std_+log_std_).exp();

        return -((action - mu_)*(action - mu_))/(2*var) - log_std_ - log(sqrt(2*M_PI));
    }
};

TORCH_MODULE(ActorCritic);

--------------------------------------------------------------------------------
/ProximalPolicyOptimization.h:
--------------------------------------------------------------------------------
#pragma once

#include <torch/torch.h>
#include <random>

#include "Models.h"

// Vector of tensors.
using VT = std::vector<torch::Tensor>;

// Optimizer.
using OPT = torch::optim::Optimizer;

// Random engine for shuffling memory.
std::random_device rd;
std::mt19937 re(rd());

// Proximal policy optimization, https://arxiv.org/abs/1707.06347
class PPO
{
public:
    static auto returns(VT& rewards, VT& dones, VT& vals, double gamma, double lambda) -> VT; // Generalized advantage estimate, https://arxiv.org/abs/1506.02438
    static auto update(ActorCritic& ac,
                       torch::Tensor& states,
                       torch::Tensor& actions,
                       torch::Tensor& log_probs,
                       torch::Tensor& returns,
                       torch::Tensor& advantages,
                       OPT& opt,
                       uint steps, uint epochs, uint mini_batch_size, double beta, double clip_param=.2) -> void;
};

auto PPO::returns(VT& rewards, VT& dones, VT& vals, double gamma, double lambda) -> VT
{
    // Compute the returns.
    torch::Tensor gae = torch::zeros({1}, torch::kFloat64);
    VT returns(rewards.size(), torch::zeros({1}, torch::kFloat64));

    for (uint i=rewards.size();i-- >0;) // inverse for loops over unsigned: https://stackoverflow.com/questions/665745/whats-the-best-way-to-do-a-reverse-for-loop-with-an-unsigned-index/665773
    {
        // Advantage.
        auto delta = rewards[i] + gamma*vals[i+1]*(1-dones[i]) - vals[i];
        gae = delta + gamma*lambda*(1-dones[i])*gae;

        returns[i] = gae + vals[i];
    }

    return returns;
}

auto PPO::update(ActorCritic& ac,
                 torch::Tensor& states,
                 torch::Tensor& actions,
                 torch::Tensor& log_probs,
                 torch::Tensor& returns,
                 torch::Tensor& advantages,
                 OPT& opt,
                 uint steps, uint epochs, uint mini_batch_size, double beta, double clip_param) -> void
{
    for (uint e=0;e<epochs;e++)
    {
        // Sample a random mini-batch of the recorded steps.
        torch::Tensor cpy_sta = torch::zeros({mini_batch_size, states.size(1)}, states.options());
        torch::Tensor cpy_act = torch::zeros({mini_batch_size, actions.size(1)}, actions.options());
        torch::Tensor cpy_log = torch::zeros({mini_batch_size, log_probs.size(1)}, log_probs.options());
        torch::Tensor cpy_ret = torch::zeros({mini_batch_size, returns.size(1)}, returns.options());
        torch::Tensor cpy_adv = torch::zeros({mini_batch_size, advantages.size(1)}, advantages.options());

        for (uint b=0;b<mini_batch_size;b++)
        {
            uint idx = std::uniform_int_distribution<uint>(0, steps-1)(re);
            cpy_sta[b] = states[idx];
            cpy_act[b] = actions[idx];
            cpy_log[b] = log_probs[idx];
            cpy_ret[b] = returns[idx];
            cpy_adv[b] = advantages[idx];
        }

        auto av = ac->forward(cpy_sta); // action value pairs
        auto action = std::get<0>(av);
        auto entropy = ac->entropy().mean();
        auto new_log_prob = ac->log_prob(cpy_act);

        auto old_log_prob = cpy_log;
        auto ratio = (new_log_prob - old_log_prob).exp();
        auto surr1 = ratio*cpy_adv;
        auto surr2 = torch::clamp(ratio, 1. - clip_param, 1. + clip_param)*cpy_adv;

        auto val = std::get<1>(av);
        auto actor_loss = -torch::min(surr1, surr2).mean();
        auto critic_loss = (cpy_ret-val).pow(2).mean();

        auto loss = 0.5*critic_loss+actor_loss-beta*entropy;

        opt.zero_grad();
        loss.backward();
        opt.step();
    }
}

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PPO Pytorch C++

This is an implementation of the [proximal policy optimization algorithm](https://arxiv.org/abs/1707.06347) for the C++ API of PyTorch. It uses a simple `TestEnvironment` to test the algorithm. Below is a small visualization of the environment the algorithm is tested in.
<figure>
  <p align="center"><img src="img/test_mode.gif"></p>
  <figcaption>Fig. 1: The agent in testing mode.</figcaption>
</figure>

## Build
You first need to install PyTorch. For a clean installation from Anaconda, check out this short [tutorial](https://gist.github.com/mhubii/1c1049fb5043b8be262259efac4b89d5), or this [tutorial](https://pytorch.org/cppdocs/installing.html) to install only the pre-built binaries.

Then build the project with
```shell
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/absolute/path/to/libtorch ..
make
```

## Run
Run the training executable with
```shell
cd build
./train_ppo
```
To plot the results, run
```shell
cd ..
python plot.py --online_view --csv_file data/data.csv --epochs 1 10
```
It should produce something like the animation shown below.
<figure>
  <p align="center">
    <img src="img/epoch_1.gif">
    <img src="img/epoch_10.gif">
  </p>
  <figcaption>Fig. 2: From left to right, the agent for successive epochs in training mode as it takes actions in the environment to reach the goal.</figcaption>
</figure>
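The gifs above can be regenerated with `plot.py` once training has written `data/data.csv`; leaving out `--online_view` makes the script save gifs to the `img` folder instead of opening a window. The `--output_file epoch` prefix below is an assumption based on the file names in `img/`:
```shell
cd ..
python plot.py --csv_file data/data.csv --epochs 1 10 --output_path img --output_file epoch
```
Saving requires ImageMagick, since the script uses the `imagemagick` animation writer.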

Once trained, the algorithm can also be run in test mode. To do so, run
```shell
cd build
./test_ppo
```
To plot the results, run
```shell
cd ..
python plot.py --online_view --csv_file data/data_test.csv --epochs 1
```

## Visualization
The results are saved to `data/data.csv` (training) and `data/data_test.csv` (testing) and can be visualized by running `python plot.py`. Run
```shell
python plot.py --help
```
for help.

--------------------------------------------------------------------------------
/TestEnvironment.h:
--------------------------------------------------------------------------------
#include <torch/torch.h>
#include <Eigen/Dense>
#include <cstring>

enum STATUS {
    PLAYING,
    WON,
    LOST,
    RESETTING
};

struct TestEnvironment
{
    Eigen::Vector2d pos_;
    Eigen::Vector2d goal_;
    Eigen::VectorXd state_;

    double old_dist_;

    TestEnvironment(double x, double y) : pos_(2), goal_(2), state_(4)
    {
        goal_ << x, y;
        pos_.setZero();
        state_ << pos_, goal_;

        old_dist_ = GoalDist(pos_);
    };

    auto Act(double act_x, double act_y) -> std::tuple<torch::Tensor, STATUS, torch::Tensor>
    {
        old_dist_ = GoalDist(pos_);

        double max_step = 0.1;
        pos_(0) += max_step*act_x;
        pos_(1) += max_step*act_y;

        state_ << pos_, goal_;

        torch::Tensor state = State();
        torch::Tensor done = torch::zeros({1, 1}, torch::kF64);
        STATUS status;

        if (GoalDist(pos_) < 6e-1) {
            status = WON;
            done[0][0] = 1.;
        }
        else if (GoalDist(pos_) > 1e1) {
            status = LOST;
            done[0][0] = 1.;
        }
        else {
            status = PLAYING;
            done[0][0] = 0.;
        }

        return std::make_tuple(state, status, done);
    }
    auto State() -> torch::Tensor
    {
        torch::Tensor state = torch::zeros({1, state_.size()}, torch::kF64);
        std::memcpy(state.data_ptr(), state_.data(), state_.size()*sizeof(double));
        return state;
    }
    auto Reward(int status) -> torch::Tensor
    {
        torch::Tensor reward = torch::full({1, 1}, old_dist_ - GoalDist(pos_), torch::kF64);

        switch (status)
        {
            case PLAYING:
                break;
            case WON:
                reward[0][0] += 10.;
                printf("won, reward: %f\n", reward[0][0].item<double>());
                break;
            case LOST:
                reward[0][0] -= 10.;
                printf("lost, reward: %f\n", reward[0][0].item<double>());
                break;
        }

        return reward;
    }
    double GoalDist(Eigen::Vector2d& x)
    {
        return (goal_ - x).norm();
    }
    void Reset()
    {
        pos_.setZero();
        state_ << pos_, goal_;
    }
    void SetGoal(double x, double y)
    {
        goal_(0) = x;
        goal_(1) = y;

        old_dist_ = GoalDist(pos_);
        state_ << pos_, goal_;
    }
};

--------------------------------------------------------------------------------
/TestPPO.cpp:
--------------------------------------------------------------------------------
#include <torch/torch.h>
#include <random>
#include <fstream>
#include "ProximalPolicyOptimization.h"
#include "Models.h"
#include "TestEnvironment.h"

int main() {

    // Random engine.
    std::random_device rd;
    std::mt19937 re(rd());
    std::uniform_int_distribution<> dist(-5, 5);

    // Environment.
    double x = double(dist(re)); // goal x pos
    double y = double(dist(re)); // goal y pos
    TestEnvironment env(x, y);

    // Model.
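    // n_in matches the 4-dimensional state of TestEnvironment (agent x, y and goal x, y),
    // n_out the 2-dimensional action; std is both the initial log standard deviation of
    // the policy and the scale used by normal() below.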
    uint n_in = 4;
    uint n_out = 2;
    double std = 1e-2;

    ActorCritic ac(n_in, n_out, std);
    ac->to(torch::kF64);
    ac->normal(0., std);
    ac->eval();
    torch::load(ac, "best_model.pt");

    // Testing loop.
    uint n_iter = 10000;

    // Output.
    std::ofstream out;
    out.open("../data/data_test.csv");

    // episode, agent_x, agent_y, goal_x, goal_y, STATUS=(PLAYING, WON, LOST, RESETTING)
    out << 1 << ", " << env.pos_(0) << ", " << env.pos_(1) << ", " << env.goal_(0) << ", " << env.goal_(1) << ", " << RESETTING << "\n";

    // Counter.
    uint c = 0;

    for (uint i=0;i<n_iter;i++)
    {
        // Play.
        auto av = ac->forward(env.State());
        auto action = std::get<0>(av);

        double x_act = action[0][0].item<double>();
        double y_act = action[0][1].item<double>();
        auto sd = env.Act(x_act, y_act);

        // Check for done state.
        auto done = std::get<2>(sd);

        // episode, agent_x, agent_y, goal_x, goal_y, STATUS=(PLAYING, WON, LOST, RESETTING)
        out << 1 << ", " << env.pos_(0) << ", " << env.pos_(1) << ", " << env.goal_(0) << ", " << env.goal_(1) << ", " << std::get<1>(sd) << "\n";

        if (done[0][0].item<double>() == 1.)
        {
            // Set new goal.
            double x_new = double(dist(re));
            double y_new = double(dist(re));
            env.SetGoal(x_new, y_new);

            // Reset the position of the agent.
            env.Reset();

            // episode, agent_x, agent_y, goal_x, goal_y, STATUS=(PLAYING, WON, LOST, RESETTING)
            out << 1 << ", " << env.pos_(0) << ", " << env.pos_(1) << ", " << env.goal_(0) << ", " << env.goal_(1) << ", " << RESETTING << "\n";
        }
    }

    out.close();

    return 0;
}

--------------------------------------------------------------------------------
/TrainPPO.cpp:
--------------------------------------------------------------------------------
#include <torch/torch.h>
#include <vector>
#include <random>
#include <fstream>
#include "ProximalPolicyOptimization.h"
#include "Models.h"
#include "TestEnvironment.h"

int main() {

    // Random engine.
    std::random_device rd;
    std::mt19937 re(rd());
    std::uniform_int_distribution<> dist(-5, 5);

    // Environment.
    double x = double(dist(re)); // goal x pos
    double y = double(dist(re)); // goal y pos
    TestEnvironment env(x, y);

    // Model.
    uint n_in = 4;
    uint n_out = 2;
    double std = 2e-2;

    ActorCritic ac(n_in, n_out, std);
    ac->to(torch::kF64);
    ac->normal(0., std);
    torch::optim::Adam opt(ac->parameters(), 1e-3);

    // Training loop.
    uint n_iter = 10000;
    uint n_steps = 2048;
    uint n_epochs = 15;
    uint mini_batch_size = 512;
    uint ppo_epochs = 4;
    double beta = 1e-3;

    VT states;
    VT actions;
    VT rewards;
    VT dones;

    VT log_probs;
    VT returns;
    VT values;

    // Output.
    std::ofstream out;
    out.open("../data/data.csv");

    // episode, agent_x, agent_y, goal_x, goal_y, STATUS=(PLAYING, WON, LOST, RESETTING)
    out << 1 << ", " << env.pos_(0) << ", " << env.pos_(1) << ", " << env.goal_(0) << ", " << env.goal_(1) << ", " << RESETTING << "\n";

    // Counter.
    uint c = 0;

    // Average reward.
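    // avg_reward accumulates the per-step reward over one epoch (normalized by n_iter);
    // whenever it exceeds the best value seen so far, the current weights are written to
    // best_model.pt, which TestPPO.cpp loads.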
    double best_avg_reward = 0.;
    double avg_reward = 0.;

    for (uint e=1;e<=n_epochs;e++)
    {
        printf("epoch %u/%u\n", e, n_epochs);

        for (uint i=0;i<n_iter;i++)
        {
            // State of the environment.
            states.push_back(env.State());

            // Play.
            auto av = ac->forward(states[c]);
            actions.push_back(std::get<0>(av));
            values.push_back(std::get<1>(av));
            log_probs.push_back(ac->log_prob(actions[c]));

            double x_act = actions[c][0][0].item<double>();
            double y_act = actions[c][0][1].item<double>();
            auto sd = env.Act(x_act, y_act);

            // New state.
            rewards.push_back(env.Reward(std::get<1>(sd)));
            dones.push_back(std::get<2>(sd));

            avg_reward += rewards[c][0][0].item<double>()/n_iter;

            // episode, agent_x, agent_y, goal_x, goal_y, STATUS=(PLAYING, WON, LOST, RESETTING)
            out << e << ", " << env.pos_(0) << ", " << env.pos_(1) << ", " << env.goal_(0) << ", " << env.goal_(1) << ", " << std::get<1>(sd) << "\n";

            if (dones[c][0][0].item<double>() == 1.)
            {
                // Set new goal.
                double x_new = double(dist(re));
                double y_new = double(dist(re));
                env.SetGoal(x_new, y_new);

                // Reset the position of the agent.
                env.Reset();

                // episode, agent_x, agent_y, goal_x, goal_y, STATUS=(PLAYING, WON, LOST, RESETTING)
                out << e << ", " << env.pos_(0) << ", " << env.pos_(1) << ", " << env.goal_(0) << ", " << env.goal_(1) << ", " << RESETTING << "\n";
            }

            c++;

            // Update.
            if (c%n_steps == 0)
            {
                printf("Updating the network.\n");
                values.push_back(std::get<1>(ac->forward(states[c-1])));

                returns = PPO::returns(rewards, dones, values, .99, .95);

                torch::Tensor t_log_probs = torch::cat(log_probs).detach();
                torch::Tensor t_returns = torch::cat(returns).detach();
                torch::Tensor t_values = torch::cat(values).detach();
                torch::Tensor t_states = torch::cat(states);
                torch::Tensor t_actions = torch::cat(actions);
                torch::Tensor t_advantages = t_returns - t_values.slice(0, 0, n_steps);

                PPO::update(ac, t_states, t_actions, t_log_probs, t_returns, t_advantages, opt, n_steps, ppo_epochs, mini_batch_size, beta);

                c = 0;

                states.clear();
                actions.clear();
                rewards.clear();
                dones.clear();

                log_probs.clear();
                returns.clear();
                values.clear();
            }
        }

        // Save the best net.
        if (avg_reward > best_avg_reward) {

            best_avg_reward = avg_reward;
            printf("Best average reward: %f\n", best_avg_reward);
            torch::save(ac, "best_model.pt");
        }

        avg_reward = 0.;

        // Reset at the end of an epoch.
        double x_new = double(dist(re));
        double y_new = double(dist(re));
        env.SetGoal(x_new, y_new);

        // Reset the position of the agent.
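        // Reset() only zeroes the agent's position; the goal assigned by SetGoal() above is kept.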
        env.Reset();

        // episode, agent_x, agent_y, goal_x, goal_y, STATUS=(PLAYING, WON, LOST, RESETTING)
        out << e << ", " << env.pos_(0) << ", " << env.pos_(1) << ", " << env.goal_(0) << ", " << env.goal_(1) << ", " << RESETTING << "\n";
    }

    out.close();

    return 0;
}

--------------------------------------------------------------------------------
/img/epoch_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mhubii/ppo_libtorch/7e40b47da56dbd94eaf03f669740a8b5bb138f95/img/epoch_1.gif

--------------------------------------------------------------------------------
/img/epoch_10.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mhubii/ppo_libtorch/7e40b47da56dbd94eaf03f669740a8b5bb138f95/img/epoch_10.gif

--------------------------------------------------------------------------------
/img/test_mode.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mhubii/ppo_libtorch/7e40b47da56dbd94eaf03f669740a8b5bb138f95/img/test_mode.gif

--------------------------------------------------------------------------------
/plot.py:
--------------------------------------------------------------------------------
import argparse
import pathlib

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--csv_file", type=str, default="data/data_test.csv",
                        help="Path to the generated trajectories.")
    parser.add_argument("--epochs", nargs="+", default="0",
                        help="Epochs to be plotted.")
    parser.add_argument("--online_view", action="store_true",
                        help="Whether to show online view or generate gif.")
    parser.add_argument("--output_path", type=str, default="img",
                        help="The path to write generated gifs to.")
    parser.add_argument("--output_file", type=str,
                        default="test", help="The prefix of the gif.")
    args = parser.parse_args()

    # create output path
    path = pathlib.Path(args.output_path)
    if not args.online_view:
        if not path.exists():
            path.mkdir(parents=True)

    # get data
    data = np.genfromtxt(args.csv_file, delimiter=",")

    fig, ax = plt.subplots()

    # setup all plots
    # spawn of the agent
    ax.plot(0, 0, "x", c="black", label="Spawn")

    # adding a circle around the goal that indicates the maximum distance to the goal before the environment gets reset
    circle = plt.Circle((data[0, 3], data[0, 4]), 10, linestyle="--",
                        color="gray", fill=False, label="Maximum Goal Distance")
    ax.add_patch(circle)

    agent, = ax.plot(data[0, 1], data[0, 2], "o",
                     c="b", label="Agent")  # agent
    # small tail following the agent
    agent_line, = ax.plot(data[0, 1], data[0, 2], "-", c="b")
    goal, = ax.plot(data[0, 3], data[0, 4], "o", c="r", label="Goal")  # goal

    # plot settings
    ax.set_xlabel("x / a.u.")
    ax.set_ylabel("y / a.u.")
    ax.set_xlim(-10, 10)
    ax.set_ylim(-10, 10)
    ax.set_title("Agent in Test Environment")
    ax.legend()
    title = ax.text(0.15, 0.85, "", bbox={"facecolor": "w", "alpha": 0.5, "pad": 5},
                    transform=ax.transAxes, ha="center")

    # plot everything
    for e in args.epochs:
        e = int(e)

        epoch_data = data[np.where(data[:, 0] == e)]
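        # column 0 holds the episode/epoch index written by TrainPPO.cpp and TestPPO.cpp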
        # tail for the agent
        global tail, frame
        tail, frame = 0, 0

        def animate(i):
            global tail, frame
            agent.set_data(epoch_data[frame, 1], epoch_data[frame, 2])
            # STATUS enum in TestEnvironment.h, 1, 2, 3 = WON, LOST, RESETTING
            if (epoch_data[frame, 5] in [1, 2, 3]):
                tail = 0
            agent_line.set_data(
                epoch_data[frame-tail:frame, 1], epoch_data[frame-tail:frame, 2])
            if (tail < 50):
                tail += 1
            goal.set_data(epoch_data[frame, 3], epoch_data[frame, 4])
            circle.center = (epoch_data[frame, 3], epoch_data[frame, 4])
            title.set_text("Epoch {:1.0f}".format(epoch_data[frame, 0]))
            frame += 1
            return agent, agent_line, goal, circle, title

        ani = animation.FuncAnimation(
            fig, animate, blit=True, interval=5, frames=1000)
        if args.online_view:
            plt.show()
        else:
            ani.save(f"{path.absolute()}/{args.output_file}_{e}.gif",
                     writer="imagemagick", fps=100)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------