├── .gitignore ├── Dockerfile ├── LICENSE ├── LICENSE.md ├── Makefile ├── README.md ├── assets ├── initial_positions.png └── tikz_setup.png ├── common ├── __init__.py └── utils.py ├── evaluate.py ├── evaluate.sh ├── evaluate_container.py ├── evaluate_container.sh ├── main.py ├── networks ├── __init__.py └── structures.py ├── notebooks ├── 3D_trajectories.ipynb ├── README.md ├── Results_pictures.ipynb ├── Robustness.ipynb └── figures │ ├── 4propellers_pwm.png │ ├── SAC_Line_3DPosition4.pdf │ ├── SAC_Line_Position_x.pdf │ ├── SAC_Line_Position_y.pdf │ ├── SAC_Line_Position_z.pdf │ ├── SAC_Senoid_3DPosition4.pdf │ ├── SAC_Senoid_Position_x.pdf │ ├── SAC_Senoid_Position_y.pdf │ ├── SAC_Senoid_Position_z.pdf │ ├── SAC_Square_3DPosition4.pdf │ ├── SAC_Square_Position_x.pdf │ ├── SAC_Square_Position_y.pdf │ ├── SAC_Square_Position_z.pdf │ ├── SAC_newaction_False_Re24_angx.png │ ├── SAC_newaction_False_Re24_angy.png │ ├── SAC_newaction_False_Re24_angz.png │ ├── SAC_newaction_False_Re24_posx.png │ ├── SAC_newaction_False_Re24_posy.png │ └── SAC_newaction_False_Re24_posz.png ├── requirements.txt ├── saved_policies ├── sac_optimal_policy.pt └── sac_optimal_policy_2.pt └── training.sh /.gitignore: -------------------------------------------------------------------------------- 1 | ## Pip3 in the container is installing this src directory TODO: fix it 2 | src/ 3 | 4 | ## Random code 5 | idea/ 6 | idea 7 | ## jupyter checkpoints 8 | .ipynb_checkpoints 9 | 10 | ## saved models 11 | checkpoint/ 12 | 13 | 14 | *.swp 15 | # Directory of ideas in Jupyter Notebook 16 | # Byte-compiled / optimized / DLL files 17 | __pycache__/ 18 | **/__pycache__ 19 | 20 | *.py[cod] 21 | .idea 22 | *.rdb 23 | # Byte-compiled / optimized / DLL files 24 | __pycache__/ 25 | *.py[cod] 26 | *.pyc 27 | __pycache__/ 28 | # C extensions 29 | *.so 30 | 31 | # Distribution / packaging 32 | .Python 33 | env/ 34 | build/ 35 | develop-eggs/ 36 | dist/ 37 | downloads/ 38 | eggs/ 39 | .eggs/ 40 | lib/ 41 | lib64/ 42 | parts/ 43 | sdist/ 44 | var/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.1-base-ubuntu18.04 2 | 3 | LABEL mantainer="gabrielmoraesbarros@gmail.com" \ 4 | version="0.1" 5 | 6 | RUN apt-get update && apt-get install -y \ 7 | wget \ 8 | libglib2.0-0 \ 9 | libgl1-mesa-glx \ 10 | xcb \ 11 | "^libxcb.*" \ 12 | libx11-xcb-dev \ 13 | libglu1-mesa-dev \ 14 | libxrender-dev \ 15 | libxi6 \ 16 | libdbus-1-3 \ 17 | libfontconfig1 \ 18 | xvfb \ 19 | && rm -rf /var/lib/apt/lists/* 20 | 21 | 22 | RUN apt-get update && apt-get install -y \ 23 | # flex \ 24 | # dh-make \ 25 | # debhelper \ 26 | checkinstall \ 27 | # fuse \ 28 | # snapcraft \ 29 | bison \ 30 | libxcursor-dev \ 31 | libxcomposite-dev \ 32 | software-properties-common \ 33 | build-essential \ 34 | libssl-dev \ 35 | libxcb1-dev \ 36 | libx11-dev \ 37 | libgl1-mesa-dev \ 38 | libudev-dev \ 39 | qt5-default \ 40 | qttools5-dev \ 41 | qtdeclarative5-dev \ 42 | qtpositioning5-dev \ 43 | qtbase5-dev \ 44 | python3-pip \ 45 | git \ 46 | vim \ 47 | wget 48 | 49 | WORKDIR "/" 50 | 51 | RUN wget -q http://coppeliarobotics.com/files/CoppeliaSim_Edu_V4_0_0_Ubuntu18_04.tar.xz 52 | RUN tar -xf CoppeliaSim_Edu_V4_0_0_Ubuntu18_04.tar.xz 53 | RUN rm -rf CoppeliaSim_Edu_V4_0_0_Ubuntu18_04.tar.xz 54 | 55 | 56 | RUN echo 'export QT_DEBUG_PLUGINS=1' >> ~/.bashrc 57 | RUN echo 'export 
PATH=/CoppeliaSim_Edu_V4_0_0_Ubuntu18_04/:$PATH' >> ~/.bashrc 58 | 59 | 60 | 61 | WORKDIR "/" 62 | 63 | RUN git clone https://github.com/stepjam/PyRep.git 64 | 65 | WORKDIR "/PyRep" 66 | 67 | 68 | RUN pip3 install -r requirements.txt 69 | 70 | RUN echo 'export COPPELIASIM_ROOT=/CoppeliaSim_Edu_V4_0_0_Ubuntu18_04/' >> ~/.bashrc 71 | RUN echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT' >> ~/.bashrc 72 | RUN echo 'export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT' >> ~/.bashrc 73 | 74 | ARG COPPELIASIM_ROOT=/CoppeliaSim_Edu_V4_0_0_Ubuntu18_04/ 75 | 76 | RUN echo "$COPPELIASIM_ROOT" 77 | RUN python3 setup.py install 78 | 79 | 80 | 81 | # # RUN wget https://files.pythonhosted.org/packages/24/19/4804aea17cd136f1705a5e98a00618cb8f6ccc375ad8bfa437408e09d058/torch-1.4.0-cp36-cp36m-manylinux1_x86_64.whl 82 | RUN wget -v https://files.pythonhosted.org/packages/38/53/914885a93a44b96c0dd1c36f36ff10afe341f091230aad68f7228d61db1e/torch-1.6.0-cp36-cp36m-manylinux1_x86_64.whl 83 | 84 | 85 | 86 | RUN pip3 install torch-1.6.0-cp36-cp36m-manylinux1_x86_64.whl 87 | 88 | 89 | 90 | ENV LANG C.UTF-8 91 | 92 | WORKDIR '/home/' 93 | 94 | # COPY "dockerrun.sh" /home/ 95 | RUN git clone https://github.com/larocs/Drone_RL.git 96 | 97 | WORKDIR '/home/Drone_RL' 98 | 99 | # RUN python3 setup.py install 100 | RUN pip3 install -e . 101 | WORKDIR '/home/' 102 | 103 | COPY requirements.txt /home/ 104 | 105 | RUN pip3 install -r requirements.txt 106 | WORKDIR '/home/sac_uav' 107 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 LaRoCS 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: help lint clean init 2 | 3 | 4 | ################################################################################# 5 | # GLOBALS # 6 | ################################################################################# 7 | export DOCKER=docker 8 | export BASE_IMAGE_NAME=sac_uav 9 | export BASE_DOCKERFILE=Dockerfile 10 | export JUPYTER_HOST_PORT=8888 11 | export JUPYTER_CONTAINER_PORT=8888 12 | export CONTAINER_NAME=sac_uav-container 13 | ################################################################################# 14 | # COMMANDS # 15 | ################################################################################# 16 | 17 | ## Build docker image 18 | create-image: 19 | sudo $(DOCKER) build -t $(BASE_IMAGE_NAME) -f $(BASE_DOCKERFILE) --force-rm --build-arg UID=$(shell id -u) . 20 | 21 | ## Run docker container 22 | create-container: 23 | sudo $(DOCKER) run -it -v $(shell pwd):/home/sac_uav -p $(JUPYTER_HOST_PORT):$(JUPYTER_CONTAINER_PORT) --name $(CONTAINER_NAME) $(BASE_IMAGE_NAME) 24 | 25 | ## Start docker container. Attach if already started 26 | start-container: ## start docker container 27 | @echo "$$START_DOCKER_CONTAINER" | $(SHELL) 28 | sudo $(DOCKER) start $(CONTAINER_NAME) 29 | @echo "Launched $(CONTAINER_NAME)..." 30 | sudo $(DOCKER) attach $(CONTAINER_NAME) 31 | 32 | ## Stop active containers 33 | stop-container: ## Spin down active containers 34 | sudo $(DOCKER) container stop $(CONTAINER_NAME) 35 | 36 | ## Rm containers 37 | clean-container: ## remove Docker container 38 | sudo $(DOCKER) rm $(CONTAINER_NAME) 39 | 40 | ## Train your agent 41 | training: 42 | xvfb-run ./training.sh 43 | 44 | 45 | ## Evaluate with headless 46 | evaluate: 47 | ./evaluate.sh 48 | ## Evaluate with headless 49 | evaluate-container: 50 | xvfb-run ./evaluate_container.sh 51 | 52 | ## Rm images 53 | clean-image: ## remove Docker image 54 | sudo $(DOCKER) image rm $(IMAGE_NAME) 55 | 56 | ## Start Jupyter Notebook server. 
Inside the container 57 | jupyter: 58 | sudo docker exec -i $(CONTAINER_NAME) jupyter notebook --ip=0.0.0.0 --port=${JUPYTER_CONTAINER_PORT} 59 | 60 | ################################################################################# 61 | # Self Documenting Commands # 62 | ################################################################################# 63 | 64 | .DEFAULT_GOAL := help 65 | 66 | # Inspired by 67 | # sed script explained: 68 | # /^##/: 69 | # * save line in hold space 70 | # * purge line 71 | # * Loop: 72 | # * append newline + line to hold space 73 | # * go to next line 74 | # * if line starts with doc comment, strip comment character off and loop 75 | # * remove target prerequisites 76 | # * append hold space (+ newline) to line 77 | # * replace newline plus comments by `---` 78 | # * print line 79 | # Separate expressions are necessary because labels cannot be delimited by 80 | # semicolon; see 81 | # .PHONY: help 82 | help: 83 | @echo "$$(tput bold)Available rules:$$(tput sgr0)" 84 | @echo 85 | @sed -n -e "/^## / { \ 86 | h; \ 87 | s/.*//; \ 88 | :doc" \ 89 | -e "H; \ 90 | n; \ 91 | s/^## //; \ 92 | t doc" \ 93 | -e "s/:.*//; \ 94 | G; \ 95 | s/\\n## /---/; \ 96 | s/\\n/ /g; \ 97 | p; \ 98 | }" ${MAKEFILE_LIST} \ 99 | | LC_ALL='C' sort --ignore-case \ 100 | | awk -F '---' \ 101 | -v ncol=$$(tput cols) \ 102 | -v indent=19 \ 103 | -v col_on="$$(tput setaf 6)" \ 104 | -v col_off="$$(tput sgr0)" \ 105 | '{ \ 106 | printf "%s%*s%s ", col_on, -indent, $$1, col_off; \ 107 | n = split($$2, words, " "); \ 108 | line_length = ncol - indent; \ 109 | for (i = 1; i <= n; i++) { \ 110 | line_length -= length(words[i]) + 1; \ 111 | if (line_length <= 0) { \ 112 | line_length = ncol - indent - length(words[i]) - 1; \ 113 | printf "\n%*s ", -indent, " "; \ 114 | } \ 115 | printf "%s ", words[i]; \ 116 | } \ 117 | printf "\n"; \ 118 | }' \ 119 | | more $(shell test $(shell uname) = Darwin && echo '--no-init --raw-control-chars') -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Using Soft Actor-Critic for Low-Level UAV Control 2 | 3 | This repository is the official implementation of [Using Soft Actor-Critic for Low-Level UAV Control](https://arxiv.org/abs/2010.02293). This work will be presented in the IROS 2020 Workshop - "Perception, Learning, and Control for Autonomous Agile Vehicles". 4 | 5 | We train a policy using Soft Actor-Critic to control a UAV. This agent is dropped in the air, with a sampled distance and inclination from the target (the green sphere in the [0,0,0] position), and has to get as close as possible to the target. In our experiments the target always has the position = [0,0,0] and angular velocity = [0,0,0]. 6 | 7 | **Watch the video** 8 | 9 | 10 | 11 | 12 | 13 |

14 | 15 | *(demo video link: "homepage")* 16 | 17 |

18 | 19 | 20 | **Framework** 21 | The environment is a traditional RL env built on top of the PyRep plugin, which in turn accesses the Coppelia Simulator API. This is much faster than using the Coppelia Simulator Remote API, and it also provides a simpler API for manipulating and creating objects inside the running simulation. 22 | 23 | 24 | 25 |
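The scripts in this repository (`main.py`, `evaluate.py`) drive the simulator through a gym-style loop. A minimal sketch of that loop is shown below; the `DroneEnv` constructor arguments and the action scaling are the ones those scripts use (values taken from `evaluate.sh`), so treat it as an illustration rather than the definitive API:

```python
import numpy as np

from sim_framework.envs.drone_env import DroneEnv

# Constructor arguments mirror main.py / evaluate.py; the values are the ones
# set in evaluate.sh (reward function, state-space and initial-position
# distribution rho_0 used in the paper's experiments).
env = DroneEnv(random='Discretized_Uniform',
               headless=True,                     # no GUI; use xvfb-run when rendering is needed
               seed=42,
               reward_function_name='Reward_24',
               state='New_action')

action_dim = env.action_space.shape[0]
action_high = env.agent.action_space.high.max()   # action scaling, as in main.py

obs = env.reset()
for _ in range(250):                               # 250-step time horizon used by the scripts
    action = np.random.sample(action_dim)          # placeholder random policy (warm-up style)
    obs, reward, done, info = env.step(action * action_high)
    if done:
        break

env.shutdown()                                     # shuts down the CoppeliaSim instance
```

In `main.py` the random placeholder is replaced by `PolicyNetwork.get_action()` once the warm-up phase ends, while `evaluate.py` uses the policy's `deterministic_action()`.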

26 | *(figure: simulation framework setup — see assets/tikz_setup.png)* 27 |

28 | 29 | **Initial positions for the UAV agent** 30 | 31 | 32 | 33 |

34 | *(figure: initial positions of the UAV agent — see assets/initial_positions.png)* 35 |

36 | 37 | 38 | 39 | ## Requirements/Installing 40 | 41 | ### Docker 42 | 43 | One of the safest ways to reproduce our environment is to use a Docker container. This approach is better for training on a cluster with a stable environment, although forwarding the display server out of Docker is always tricky (we leave that to the reader). 44 | 45 | Adjust the container variables and then use the Makefile targets to work with our Docker container. The commands are self-explanatory. 46 | 47 | **create-image** 48 | 49 | ```creating-image 50 | make create-image 51 | ``` 52 | 53 | **create-container** 54 | 55 | ```creating-container 56 | make create-container 57 | ``` 58 | **training** 59 | 60 | ```training-an-agent 61 | make training 62 | ``` 63 | 64 | **evaluate-container** 65 | 66 | ```evaluate 67 | make evaluate-container 68 | ``` 69 | 70 | 71 | 72 | ### Without-Docker 73 | 74 | 1) Install the [Coppelia Simulator](https://www.coppeliarobotics.com/) 75 | 2) Install [PyRep](https://github.com/stepjam/PyRep) 76 | 3) Install [Drone_RL](https://github.com/larocs/Drone_RL) 77 | 4) To install the requirements: 78 | 79 | ```setup 80 | pip install -r requirements.txt 81 | ``` 82 | 83 | 5) To install this repo: 84 | 85 | ```setup 86 | python setup.py install 87 | ``` 88 | 89 | 90 | 92 | 93 | ## Training 94 | 95 | To train the model(s) in the paper, run this command: 96 | 97 | ```train 98 | ./training.sh 99 | ``` 100 | 101 | 102 | *It is somewhat tricky to reproduce an exact policy, because of the variability inherent to off-policy methods and to the reward shaping needed to obtain optimal control policies in robotics.* 103 | 104 | *One hack that alleviates this problem is to save something like a moving window of the last 5-10 policies and pick the best one (qualitatively) once the reward has stabilized. More research is needed to remove the need for this qualitative assessment of the trained policies.* 105 | 106 | 107 | ## Evaluation 108 | 109 | To evaluate the trained model with the optimal policy, run: 110 | 111 | ```eval 112 | ./evaluate.sh 113 | ``` 114 | 115 | 116 | 117 | ## Pre-trained Models 118 | 119 | You can check the saved trained policies in: 120 | 121 | - [saved_policies/](saved_policies/) 122 | 123 | 124 | 125 | 126 | 127 | ## Results 128 | 129 | Run the notebooks in [notebooks/](notebooks/) to reproduce the figures presented in the paper. 130 | 131 | [results](notebooks/README.md) 132 | 133 | ## Credits 134 | 135 | Code heavily based on [RL-Adventure-2](https://github.com/higgsfield/RL-Adventure-2) 136 | 137 | The environment is a continuation of the work in: 138 | 139 | G. Lopes, M. Ferreira, A. Simões, and E. Colombini, "Intelligent Control of a Quadrotor with Proximal Policy Optimization," *Latin American Robotic Symposium*, pp. 503–508, Nov. 2018 140 | 141 | 149 | 150 | ## License 151 | 152 | [MIT LICENSE](LICENSE.md) 153 | 154 | ## Cite us 155 | 156 | Barros, Gabriel M.; Colombini, Esther L., "Using Soft Actor-Critic for Low-Level UAV Control", *IROS - workshop Perception, Learning, and Control for Autonomous Agile Vehicles*, 2020. 
157 | 158 | @misc{barros2020using, 159 | title={Using Soft Actor-Critic for Low-Level UAV Control}, 160 | author={Gabriel Moraes Barros and Esther Luna Colombini}, 161 | year={2020}, 162 | eprint={2010.02293}, 163 | archivePrefix={arXiv}, 164 | primaryClass={cs.RO}, 165 | journal={IROS - workshop "Perception, Learning, and Control for Autonomous Agile Vehicles"}, 166 | } 167 | 168 | -------------------------------------------------------------------------------- /assets/initial_positions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/assets/initial_positions.png -------------------------------------------------------------------------------- /assets/tikz_setup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/assets/tikz_setup.png -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/common/__init__.py -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gym 3 | import random 4 | import numpy as np 5 | # import matplotlib.pyplot as plt 6 | import os 7 | 8 | from networks.structures import PolicyNetwork, ValueNetwork, SoftQNetwork 9 | 10 | 11 | use_cuda = torch.cuda.is_available() 12 | device = torch.device("cuda" if use_cuda else "cpu") 13 | 14 | 15 | class ReplayBuffer: 16 | """ 17 | A experience buffer used to store and replay data 18 | 19 | Parameters 20 | ---------- 21 | capacity : [int] 22 | The max size of the buffer 23 | """ 24 | 25 | def __init__(self, capacity): 26 | self.capacity = capacity 27 | self.buffer = [] 28 | self.position = 0 29 | 30 | def push(self, state, action, reward, next_state, done): 31 | """ 32 | Add data to the buffer 33 | """ 34 | if len(self.buffer) < self.capacity: 35 | self.buffer.append(None) 36 | self.buffer[self.position] = (state, action, reward, next_state, done) 37 | self.position = (self.position + 1) % self.capacity 38 | 39 | # def sample(self, batch_size): 40 | # """ 41 | # Sample a random batch of memmory data from the buffer 42 | # ---------- 43 | # batch_size : [int] 44 | # The size of the batch 45 | 46 | # Returns 47 | # ------- 48 | # [list] 49 | # A batch of rollout experience 50 | # """ 51 | # batch = random.sample(self.buffer, batch_size) 52 | # state, action, reward, next_state, done = map(np.stack, zip(*batch)) 53 | # return state, action, reward, next_state, done 54 | 55 | def sample(self, batch_size): # this version is significantly faster 56 | """ 57 | Sample a random batch of memmory data from the buffer 58 | ---------- 59 | batch_size : [int] 60 | The size of the batch 61 | 62 | Returns 63 | ------- 64 | [list] 65 | A batch of rollout experience 66 | """ 67 | 68 | batch = random.sample(self.buffer, batch_size) 69 | state = np.array([elem[0] for elem in batch], dtype=np.double) 70 | action = np.array([elem[1] for elem in batch], dtype=np.double) 71 | reward = np.array([elem[2] for elem in batch], dtype=np.double) 72 | next_state = np.array([elem[3] for elem in batch], dtype=np.double) 73 | done = np.array([elem[4] for 
elem in batch], dtype=np.double) 74 | 75 | return state, action, reward, next_state, done 76 | 77 | def __len__(self): 78 | """ 79 | Size of the buffer 80 | """ 81 | return len(self.buffer) 82 | 83 | 84 | def check_dir(file_name): 85 | """ 86 | Checking if directory path exists 87 | """ 88 | directory = os.path.dirname(file_name) 89 | if not os.path.exists(directory): 90 | os.makedirs(directory) 91 | 92 | 93 | def restore_data(restore_path): 94 | """ 95 | Restore data to re-load training 96 | 97 | Parameters 98 | ---------- 99 | restore_path : [str] 100 | File path of the saved data 101 | """ 102 | try: 103 | checkpoint = torch.load(restore_path + '/state.pt') 104 | # checkpoint = torch.load(restore_path) 105 | 106 | # Episode and frames 107 | episode = checkpoint['episode'] 108 | frame_count = checkpoint['frame_count'] 109 | # Models 110 | value_net.load_state_dict(checkpoint['value_net']) 111 | target_value_net.load_state_dict(checkpoint['target_value_net']) 112 | soft_q_net.load_state_dict(checkpoint['soft_q_net']) 113 | policy_net.load_state_dict(checkpoint['policy_net']) 114 | # Optimizers 115 | value_optimizer.load_state_dict(checkpoint['value_optimizer']) 116 | soft_q_optimizer.load_state_dict(checkpoint['soft_q_optimizer']) 117 | policy_optimizer.load_state_dict(checkpoint['policy_optimizer']) 118 | replay_buffer = checkpoint['replay_buffer'] 119 | except BaseException: 120 | print('Não foi possível carregar um modelo pré-existente') 121 | 122 | 123 | def terminate(): 124 | """ 125 | Helper function to proper close the process and the Coppelia Simulator 126 | """ 127 | try: 128 | env.shutdown() 129 | import sys 130 | sys.exit(0) 131 | except BaseException: 132 | import sys 133 | sys.exit(0) 134 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dill 3 | import numpy as np 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import pandas as pd 8 | 9 | from sim_framework.envs.drone_env import DroneEnv 10 | 11 | from networks.structures import PolicyNetwork, ValueNetwork, SoftQNetwork 12 | 13 | import pyrep.backend.sim as sim 14 | 15 | 16 | def rollouts( 17 | env, 18 | policy, 19 | action_range, 20 | device, 21 | max_timesteps=1000, 22 | time_horizon=250): 23 | """ 24 | Perform policy rollouts until a max given number of steps 25 | 26 | Parameters 27 | ---------- 28 | env : 29 | A larocs_sim environment 30 | policy : 31 | An actor-policy for the agent act in the environment 32 | action_range : list 33 | Range of possible float values for the action 34 | max_timesteps : int, optional 35 | Number of timesteps to perform while interacting with the environment, by default 1000 36 | time_horizon : int, optional 37 | The number of steps for each episode, by default 250 38 | 39 | """ 40 | count = 0 41 | dones = False 42 | set_of_obs, set_of_next_obs, set_of_rewards, set_of_actions, set_of_dones, set_of_infos = [], [], [], [], [], [] 43 | 44 | rollout = -1 45 | 46 | while True: 47 | mb_obs, mb_next_obs, mb_rewards, mb_actions, mb_dones, mb_infos = [], [], [], [], [], [] 48 | sim.simRemoveBanner(sim.sim_handle_all) 49 | rollout += 1 50 | 51 | obs0 = env.reset() 52 | 53 | sim.simAddBanner(label="Rollout = {0}".format(rollout).encode('ascii'), 54 | size=0.2, 55 | options=1, 56 | positionAndEulerAngles=[0, 0, 2.5, 1.3, 0, 0], 57 | parentObjectHandle=-1) 58 | 59 | for j in range(time_horizon): 60 | dones = False 61 | if count == 
max_timesteps: 62 | set_tau = {'obs': set_of_obs, 63 | 'next_obs': set_of_next_obs, 64 | 'rewards': set_of_rewards, 65 | 'actions': set_of_actions, 66 | 'dones': set_of_dones, 67 | 'infos': set_of_infos} 68 | return set_tau 69 | try: 70 | actions, agent_info = policy.deterministic_action( 71 | state_to_tensor(obs0, device)) 72 | except BaseException: 73 | actions = policy.deterministic_action( 74 | state_to_tensor(obs0, device)) 75 | 76 | # Take actions in env and look the results 77 | obs1, rewards, dones, infos = env.step(actions * action_range[1]) 78 | # Append on the experience buffers 79 | mb_obs.append(obs0.copy()) 80 | # mb_obs.append(obs0) 81 | mb_next_obs.append(obs1) 82 | mb_actions.append(actions) 83 | mb_dones.append(dones) 84 | mb_rewards.append(rewards) 85 | mb_infos.append(infos) 86 | 87 | count += 1 88 | 89 | if dones: 90 | break 91 | 92 | obs0 = obs1 93 | 94 | print() 95 | print('rewards: mean = {0}'.format(np.mean(mb_rewards))) 96 | print('rewards: sum = {0}'.format(np.sum(mb_rewards))) 97 | 98 | set_of_obs.append(mb_obs) 99 | set_of_next_obs.append(mb_next_obs) 100 | set_of_rewards.append(mb_rewards) 101 | set_of_actions.append(mb_actions) 102 | set_of_dones.append(mb_dones) 103 | set_of_infos.append(mb_infos) 104 | 105 | 106 | def run_policy(args): 107 | """ 108 | Loads a and evaluates a trained policy 109 | 110 | Parameters 111 | ---------- 112 | args : [dict] 113 | Users arguments with the options for the framework 114 | """ 115 | 116 | use_cuda = torch.cuda.is_available() 117 | if use_cuda and (args.use_cuda): 118 | device = torch.device("cuda") 119 | else: 120 | device = torch.device("cpu") 121 | 122 | # Set environment 123 | env = DroneEnv( 124 | random=args.env_reset_mode, 125 | headless=args.headless, 126 | seed=args.seed, 127 | reward_function_name=args.reward_function, 128 | state=args.state) 129 | 130 | restore_path = args.file 131 | print('Loading') 132 | # Load parameters if necessary 133 | try: 134 | checkpoint = torch.load(restore_path, map_location='cpu') 135 | except BaseException: 136 | checkpoint = torch.load(restore_path, map_location=torch.device('cpu')) 137 | print('Finished Loading') 138 | 139 | # Neural network parameters 140 | try: 141 | state_dim = env.observation_space.shape[0] 142 | except BaseException: 143 | state_dim = env.observation_space 144 | action_dim = env.action_space.shape[0] 145 | hidden_dim = checkpoint['linear1.weight'].data.shape[0] 146 | action_range = [env.agent.action_space.low.min( 147 | ), env.agent.action_space.high.max()] 148 | size_obs = checkpoint['linear1.weight'].data.shape[1] 149 | 150 | assert size_obs == state_dim, 'Checkpoint state must be the same as the env' 151 | 152 | # Networks instantiation 153 | policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim).to(device) 154 | 155 | # Loading Models 156 | policy_net.load_state_dict(checkpoint) 157 | print('Finished Loading the weights') 158 | 159 | print("Running the policy...") 160 | set_tau = rollouts( 161 | env, 162 | policy_net, 163 | action_range, 164 | device, 165 | max_timesteps=args.max_timesteps, 166 | time_horizon=args.H) 167 | 168 | print('Closing env') 169 | env.shutdown() 170 | 171 | 172 | def state_to_tensor(state, device): 173 | """Transform numpy array to torch tensor""" 174 | if args.use_double: 175 | return torch.DoubleTensor(state).unsqueeze(0).to(device) 176 | else: 177 | return torch.FloatTensor(state).unsqueeze(0).to(device) 178 | 179 | 180 | if __name__ == "__main__": 181 | 182 | parser = argparse.ArgumentParser() 183 | 
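    # NOTE: run_policy() reads `args.use_cuda`, but this parser only defines
    # `--gpu`. On a machine where torch.cuda.is_available() is True, the
    # `use_cuda and (args.use_cuda)` check raises AttributeError; exposing the
    # parsed flag as `use_cuda` (or checking `args.gpu` instead) avoids that.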
parser.add_argument('--file', type=str, 184 | help='path to the snapshot file') 185 | parser.add_argument('--H', type=int, default=250, 186 | help='Max length of rollout') 187 | parser.add_argument('--max_timesteps', type=int, default=1000, 188 | help='Max number of timesteps') 189 | parser.add_argument('--gpu', action='store_true', default=False) 190 | parser.add_argument( 191 | '--headless', 192 | help='To render or not the environment', 193 | choices=( 194 | 'True', 195 | 'False'), 196 | default='True') 197 | parser.add_argument( 198 | '--env_reset_mode', 199 | help='How to sample the starting position of the agent', 200 | choices=( 201 | 'Uniform', 202 | 'Gaussian', 203 | 'False', 204 | 'Discretized_Uniform'), 205 | default='False') 206 | parser.add_argument( 207 | '--seed', help='Global seed', default=42, type=int) 208 | parser.add_argument( 209 | '--reward_function', 210 | help='What reward function to use', 211 | default='Normal', 212 | type=str) 213 | parser.add_argument( 214 | '--state', help='State to be used', default='Old', type=str) 215 | parser.add_argument( 216 | '--use_double', help='Flag to use float64', type=str, default=None) 217 | 218 | args = parser.parse_args() 219 | 220 | if (args.headless) == 'False': 221 | args.headless = False 222 | else: 223 | args.headless = True 224 | if (args.env_reset_mode) == 'False': 225 | args.env_reset_mode = False 226 | 227 | run_policy(args) 228 | print("Done") 229 | -------------------------------------------------------------------------------- /evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## Path to the trained policy 4 | # SAVED_POLICY=saved_policies/sac_optimal_policy.pt 5 | SAVED_POLICY=saved_policies/sac_optimal_policy_2.pt 6 | 7 | ## Reward function 8 | REWARD_FUNCTION=Reward_24 9 | ## Initial position probability distribution (\rho_{o}) 10 | env_reset_mode=Discretized_Uniform 11 | ## State-Space 12 | STATE=New_action 13 | 14 | 15 | python3 evaluate.py --headless=False --reward_function=${REWARD_FUNCTION} --state=${STATE} --file=${SAVED_POLICY} \ 16 | --env_reset_mode=${env_reset_mode} 17 | 18 | -------------------------------------------------------------------------------- /evaluate_container.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dill 3 | import numpy as np 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import pandas as pd 8 | 9 | from sim_framework.envs.drone_env import DroneEnv 10 | 11 | from networks.structures import PolicyNetwork, ValueNetwork, SoftQNetwork 12 | 13 | import pyrep.backend.sim as sim 14 | 15 | 16 | def rollouts( 17 | env, 18 | policy, 19 | action_range, 20 | device, 21 | max_timesteps=1000, 22 | time_horizon=250): 23 | """ 24 | Perform policy rollouts until a max given number of steps 25 | 26 | Parameters 27 | ---------- 28 | env : 29 | A larocs_sim environment 30 | policy : 31 | An actor-policy for the agent act in the environment 32 | action_range : list 33 | Range of possible float values for the action 34 | max_timesteps : int, optional 35 | Number of timesteps to perform while interacting with the environment, by default 1000 36 | time_horizon : int, optional 37 | The number of steps for each episode, by default 250 38 | 39 | """ 40 | count = 0 41 | dones = False 42 | set_of_obs, set_of_next_obs, set_of_rewards, set_of_actions, set_of_dones, set_of_infos = [], [], [], [], [], [] 43 | 44 | rollout = -1 45 | 46 | while True: 47 | mb_obs, 
mb_next_obs, mb_rewards, mb_actions, mb_dones, mb_infos = [], [], [], [], [], [] 48 | # sim.simRemoveBanner(sim.sim_handle_all) 49 | rollout += 1 50 | 51 | obs0 = env.reset() 52 | 53 | # sim.simAddBanner(label = "Rollout = {0}".format(rollout).encode('ascii'),\ 54 | # size = 0.2,\ 55 | # options = 1, 56 | # positionAndEulerAngles=[0,0,2.5,1.3,0,0], 57 | # parentObjectHandle = -1) 58 | 59 | for j in range(time_horizon): 60 | dones = False 61 | if count == max_timesteps: 62 | set_tau = {'obs': set_of_obs, 63 | 'next_obs': set_of_next_obs, 64 | 'rewards': set_of_rewards, 65 | 'actions': set_of_actions, 66 | 'dones': set_of_dones, 67 | 'infos': set_of_infos} 68 | return set_tau 69 | try: 70 | actions, agent_info = policy.deterministic_action( 71 | state_to_tensor(obs0, device)) 72 | except BaseException: 73 | actions = policy.deterministic_action( 74 | state_to_tensor(obs0, device)) 75 | 76 | # Take actions in env and look the results 77 | obs1, rewards, dones, infos = env.step(actions * action_range[1]) 78 | # Append on the experience buffers 79 | mb_obs.append(obs0.copy()) 80 | # mb_obs.append(obs0) 81 | mb_next_obs.append(obs1) 82 | mb_actions.append(actions) 83 | mb_dones.append(dones) 84 | mb_rewards.append(rewards) 85 | mb_infos.append(infos) 86 | 87 | count += 1 88 | 89 | if dones: 90 | break 91 | 92 | obs0 = obs1 93 | 94 | print() 95 | print('rewards: mean = {0}'.format(np.mean(mb_rewards))) 96 | print('rewards: sum = {0}'.format(np.sum(mb_rewards))) 97 | 98 | set_of_obs.append(mb_obs) 99 | set_of_next_obs.append(mb_next_obs) 100 | set_of_rewards.append(mb_rewards) 101 | set_of_actions.append(mb_actions) 102 | set_of_dones.append(mb_dones) 103 | set_of_infos.append(mb_infos) 104 | 105 | 106 | def run_policy(args): 107 | """ 108 | Loads a and evaluates a trained policy 109 | 110 | Parameters 111 | ---------- 112 | args : [dict] 113 | Users arguments with the options for the framework 114 | """ 115 | 116 | use_cuda = torch.cuda.is_available() 117 | if use_cuda and (args.use_cuda): 118 | device = torch.device("cuda") 119 | else: 120 | device = torch.device("cpu") 121 | 122 | # Set environment 123 | env = DroneEnv( 124 | random=args.env_reset_mode, 125 | headless=args.headless, 126 | seed=args.seed, 127 | reward_function_name=args.reward_function, 128 | state=args.state) 129 | 130 | restore_path = args.file 131 | print('Loading') 132 | # Load parameters if necessary 133 | try: 134 | checkpoint = torch.load(restore_path, map_location='cpu') 135 | except BaseException: 136 | checkpoint = torch.load(restore_path, map_location=torch.device('cpu')) 137 | print('Finished Loading') 138 | 139 | # Neural network parameters 140 | try: 141 | state_dim = env.observation_space.shape[0] 142 | except BaseException: 143 | state_dim = env.observation_space 144 | action_dim = env.action_space.shape[0] 145 | hidden_dim = checkpoint['linear1.weight'].data.shape[0] 146 | action_range = [env.agent.action_space.low.min( 147 | ), env.agent.action_space.high.max()] 148 | size_obs = checkpoint['linear1.weight'].data.shape[1] 149 | 150 | assert size_obs == state_dim, 'Checkpoint state must be the same as the env' 151 | 152 | # Networks instantiation 153 | policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim).to(device) 154 | 155 | # Loading Models 156 | policy_net.load_state_dict(checkpoint) 157 | print('Finished Loading the weights') 158 | 159 | print("Running the policy...") 160 | set_tau = rollouts( 161 | env, 162 | policy_net, 163 | action_range, 164 | device, 165 | max_timesteps=args.max_timesteps, 166 
| time_horizon=args.H) 167 | 168 | print('Closing env') 169 | env.shutdown() 170 | 171 | 172 | def state_to_tensor(state, device): 173 | """Transform numpy array to torch tensor""" 174 | if args.use_double: 175 | return torch.DoubleTensor(state).unsqueeze(0).to(device) 176 | else: 177 | return torch.FloatTensor(state).unsqueeze(0).to(device) 178 | 179 | 180 | if __name__ == "__main__": 181 | 182 | parser = argparse.ArgumentParser() 183 | parser.add_argument('--file', type=str, 184 | help='path to the snapshot file') 185 | parser.add_argument('--H', type=int, default=250, 186 | help='Max length of rollout') 187 | parser.add_argument('--max_timesteps', type=int, default=1000, 188 | help='Max number of timesteps') 189 | parser.add_argument('--gpu', action='store_true', default=False) 190 | parser.add_argument( 191 | '--headless', 192 | help='To render or not the environment', 193 | choices=( 194 | 'True', 195 | 'False'), 196 | default='True') 197 | parser.add_argument( 198 | '--env_reset_mode', 199 | help='How to sample the starting position of the agent', 200 | choices=( 201 | 'Uniform', 202 | 'Gaussian', 203 | 'False', 204 | 'Discretized_Uniform'), 205 | default='False') 206 | parser.add_argument( 207 | '--seed', help='Global seed', default=42, type=int) 208 | parser.add_argument( 209 | '--reward_function', 210 | help='What reward function to use', 211 | default='Normal', 212 | type=str) 213 | parser.add_argument( 214 | '--state', help='State to be used', default='Old', type=str) 215 | parser.add_argument( 216 | '--use_double', help='Flag to use float64', type=str, default=None) 217 | 218 | args = parser.parse_args() 219 | 220 | if (args.headless) == 'False': 221 | args.headless = False 222 | else: 223 | args.headless = True 224 | if (args.env_reset_mode) == 'False': 225 | args.env_reset_mode = False 226 | 227 | run_policy(args) 228 | print("Done") 229 | -------------------------------------------------------------------------------- /evaluate_container.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## Path to the trained policy 4 | # SAVED_POLICY=saved_policies/sac_optimal_policy.pt 5 | SAVED_POLICY=saved_policies/sac_optimal_policy_2.pt 6 | 7 | ## Reward function 8 | REWARD_FUNCTION=Reward_24 9 | ## Initial position probability distribution (\rho_{o}) 10 | env_reset_mode=Discretized_Uniform 11 | ## State-Space 12 | STATE=New_action 13 | 14 | 15 | python3 evaluate_container.py --headless=True --reward_function=${REWARD_FUNCTION} --state=${STATE} --file=${SAVED_POLICY} \ 16 | --env_reset_mode=${env_reset_mode} 17 | 18 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import copy 5 | import csv 6 | 7 | import numpy as np 8 | 9 | from torch import nn 10 | from torch import optim 11 | import torch 12 | 13 | from sim_framework.envs.drone_env import DroneEnv 14 | 15 | from common import utils 16 | from networks.structures import PolicyNetwork, ValueNetwork, SoftQNetwork 17 | 18 | 19 | def argparser(): 20 | 21 | parser = argparse.ArgumentParser("Benching envs") 22 | parser.add_argument( 23 | '--net_size_policy', 24 | help='Size of the neural networks hidden layers', 25 | type=int, 26 | default=64) 27 | parser.add_argument( 28 | '--activation_function', 29 | help='activation function for policy', 30 | type=str, 31 | default='relu', 32 | choices=[ 33 | 'relu', 34 
| 'tanh']) 35 | 36 | parser.add_argument( 37 | '--net_size_value', 38 | help='Size of the neural networks hidden layers', 39 | type=int, 40 | default=256) 41 | parser.add_argument( 42 | '--max_episodes', 43 | help='Number of epochs in the training', 44 | type=int, 45 | default=int(10000)) 46 | parser.add_argument( 47 | '--replay_buffer_size', 48 | help='Size of the replay buffer', 49 | type=int, 50 | default=int(1e6)) 51 | parser.add_argument( 52 | '--num_steps_until_train', 53 | help='How many steps we sample with current policy', 54 | type=int, 55 | default=1) 56 | parser.add_argument( 57 | '--num_trains_per_step', 58 | help='Number of timesteps in each step', 59 | type=int, 60 | default=1) 61 | parser.add_argument( 62 | '--min_num_steps_before_training', 63 | help='Number of timesteps using random (uniform) policy to fill \ 64 | the Replay buffer in the beggining of the training', 65 | type=int, 66 | default=100) 67 | parser.add_argument( 68 | '--batch-size', help='Batch size in each epoch', type=int, default=256) 69 | # parser.add_argument( 70 | # '--use_automatic_entropy_tuning', help='Set True to automatically discover best alpha (temperature)', type=bool, default=True) 71 | parser.add_argument( 72 | '--env_reset_mode', 73 | help='How to sample the starting position of the agent', 74 | choices=( 75 | 'Uniform', 76 | 'Gaussian', 77 | 'False', 78 | 'Discretized_Uniform'), 79 | default='False') 80 | parser.add_argument( 81 | '--use_cuda', 82 | help='If the device is a GPU or CPU', 83 | default='False', 84 | choices=[ 85 | 'True', 86 | 'False']) 87 | parser.add_argument( 88 | '--restore_path', 89 | help='Filename of the policy being loaded', 90 | type=str, 91 | default=None) 92 | parser.add_argument( 93 | '--save_path', help='Filename to save the current training', type=str) 94 | parser.add_argument( 95 | '--log_interval', help='Frequency of logging', type=int, default=100) 96 | parser.add_argument( 97 | '--save-interval', help='Frequency of saving', type=int, default=100) 98 | parser.add_argument( 99 | '--eval_interval', 100 | help='Frequency for evaluating deterministic policy', 101 | type=int, 102 | default=None) 103 | parser.add_argument( 104 | '--use_double', help='Flag to use float64', type=str, default=None) 105 | parser.add_argument( 106 | '--learning_rate', help='Learning rate', type=float, default=3e-4) 107 | parser.add_argument( 108 | '--reward_function', 109 | help='What reward function to use', 110 | default='Normal', 111 | type=str) 112 | parser.add_argument( 113 | '--seed', help='Global seed', default=42, type=int) 114 | parser.add_argument( 115 | '--state', help='Global seed', default='Old', type=str) 116 | parser.add_argument( 117 | '--same-norm', help='same_norm', default=False) 118 | parser.add_argument( 119 | '--threshold', 120 | help='Clipping the difference between action vectors', 121 | default=4.0, 122 | type=float) 123 | parser.add_argument( 124 | '--clip-action', 125 | help='Clipping the difference between action vectors', 126 | default=100, 127 | type=int) 128 | parser.add_argument( 129 | '--save_interval', help='Frequency of saving', type=int, default=100) 130 | return parser.parse_args() 131 | 132 | if (args.env_reset_mode) == 'False': 133 | args.env_reset_mode = False 134 | 135 | if (args.use_cuda) == 'False': 136 | args.use_cuda = False 137 | else: 138 | args.use_cuda = True 139 | if args.activation_function == 'relu': 140 | args.activation_function = F.relu 141 | else: 142 | args.activation_function = F.tanh 143 | 144 | 145 | class SAC(): 146 | 147 | def 
__init__( 148 | self, 149 | env, 150 | replay_buffer_size, 151 | hidden_dim, 152 | restore_path, 153 | device, 154 | max_episodes, 155 | save_path, 156 | learning_rate=3e-4, 157 | use_double=True, 158 | min_num_steps_before_training=0, 159 | save_interval=100, 160 | ): 161 | """ 162 | Soft Actor-Critic algorithm 163 | 164 | 165 | Parameters 166 | ---------- 167 | env : 168 | The environment to be used 169 | replay_buffer_size : [int] 170 | Replay-buffer size 171 | hidden_dim : [int] 172 | Size of the hidden-layers in Q and V functions[description] 173 | restore_path : [str] 174 | File path to restore training 175 | device : [str or torch.device] 176 | 'cpu' or 'gpu 177 | max_episodes : [int] 178 | Max number of episodes to train the policy 179 | save_path : [str] 180 | File path to save the networks 181 | learning_rate : [float], optional 182 | The learning rate for gradient based optimization, by default 3e-4 183 | use_double : bool, optional 184 | Use Float Tensor or Double Tensor, by default True 185 | min_num_steps_before_training : int, optional 186 | Number of steps to randomly sample before start acting, by default 0 187 | save_interval : int, optional 188 | The interval in epochs to save the models, by default 100 189 | """ 190 | self.env = env 191 | 192 | self.save_path = save_path 193 | self.save_interval = save_interval 194 | self.min_num_steps_before_training = min_num_steps_before_training 195 | self.restore_path = restore_path 196 | self.device = device 197 | self.hidden_dim = hidden_dim 198 | self.use_double = use_double 199 | # Network and env parameters 200 | self.action_dim = self.env.action_space.shape[0] 201 | try: 202 | self.state_dim = self.env.observation_space.shape[0] 203 | except BaseException: 204 | self.state_dim = self.env.observation_space 205 | # hidden_dim = args.net_size_value 206 | self.action_range = [self.env.agent.action_space.low.min( 207 | ), self.env.agent.action_space.high.max()] 208 | 209 | self._creating_models(replay_buffer_size, self.state_dim, 210 | self.action_dim, self.device, self.hidden_dim) 211 | 212 | # Copying the data to the target networks 213 | for target_param, param in zip( 214 | self.target_value_net.parameters(), self.value_net.parameters()): 215 | target_param.data.copy_(param.data) 216 | 217 | # Types of losses 218 | self.value_criterion = nn.MSELoss() 219 | self.soft_q_criterion = nn.MSELoss() 220 | 221 | # Learning rates 222 | self.value_lr = learning_rate 223 | self.soft_q_lr = learning_rate 224 | self.policy_lr = learning_rate 225 | # Optimizers 226 | self.value_optimizer = optim.Adam( 227 | self.value_net.parameters(), lr=self.value_lr) 228 | self.soft_q_optimizer = optim.Adam( 229 | self.soft_q_net.parameters(), lr=self.soft_q_lr) 230 | self.policy_optimizer = optim.Adam( 231 | self.policy_net.parameters(), lr=self.policy_lr) 232 | 233 | if self.use_double: 234 | self.state_to_tensor = lambda x: torch.DoubleTensor( 235 | x).unsqueeze(0).to(device) 236 | else: 237 | self.state_to_tensor = lambda x: torch.FloatTensor( 238 | x).unsqueeze(0).to(device) 239 | 240 | def soft_q_update(self, 241 | batch_size, 242 | gamma=0.99, 243 | mean_lambda=1e-3, 244 | std_lambda=1e-3, 245 | z_lambda=0.0, 246 | soft_tau=1e-2, 247 | ): 248 | """ 249 | SAC train update (Soft-Q) 250 | 251 | Parameters 252 | ---------- 253 | batch_size : [int] 254 | Batch size 255 | gamma : float, optional 256 | Discount factor, by default 0.99 257 | mean_lambda : [float], optional 258 | coefficient for penalty on policy mean magnitude, by default 1e-3 259 | 
std_lambda : [float], optional 260 | coefficient for penalty on policy variance, by default 1e-3 261 | z_lambda : float, optional 262 | coefficient for penalty on policy mean before been squashed by tanh, by default 0 263 | soft_tau : [float], optional 264 | Soft coefficient to update target networks, by default 1e-2 265 | """ 266 | # Sampling memmory batch 267 | state, action, reward, next_state, done = self.replay_buffer.sample( 268 | batch_size) 269 | 270 | # Broadcast 271 | if self.use_double: 272 | state = torch.DoubleTensor(state).to(self.device) 273 | next_state = torch.DoubleTensor(next_state).to(self.device) 274 | action = torch.DoubleTensor(action).to(self.device) 275 | reward = torch.DoubleTensor(reward).unsqueeze(1).to(self.device) 276 | done = torch.DoubleTensor(np.float64( 277 | done)).unsqueeze(1).to(self.device) 278 | 279 | else: 280 | state = torch.FloatTensor(state).to(self.device) 281 | next_state = torch.FloatTensor(next_state).to(self.device) 282 | action = torch.FloatTensor(action).to(self.device) 283 | reward = torch.FloatTensor(reward).unsqueeze(1).to(self.device) 284 | done = torch.FloatTensor(np.float32( 285 | done)).unsqueeze(1).to(self.device) 286 | 287 | # Net forward-passes 288 | expected_q_value = self.soft_q_net(state, action) 289 | expected_value = self.value_net(state) 290 | new_action, log_prob, z, mean, log_std = self.policy_net.evaluate( 291 | state) 292 | 293 | ## Qf - loss 294 | target_value = self.target_value_net(next_state) 295 | next_q_value = reward + (1 - done) * gamma * target_value 296 | q_value_loss = self.soft_q_criterion( 297 | expected_q_value, next_q_value.detach()) 298 | 299 | ## Vf - loss 300 | expected_new_q_value = self.soft_q_net(state, new_action) 301 | next_value = expected_new_q_value - log_prob 302 | value_loss = self.value_criterion(expected_value, next_value.detach()) 303 | 304 | # Policy Loss 305 | log_prob_target = expected_new_q_value - expected_value 306 | policy_loss = (log_prob * (log_prob - log_prob_target).detach()).mean() 307 | mean_loss = mean_lambda * mean.pow(2).mean() 308 | std_loss = std_lambda * log_std.pow(2).mean() 309 | z_loss = z_lambda * z.pow(2).sum(1).mean() 310 | policy_loss += mean_loss + std_loss + z_loss 311 | 312 | # NN updates 313 | self.soft_q_optimizer.zero_grad() 314 | q_value_loss.backward() 315 | torch.nn.utils.clip_grad_norm_(self.value_net.parameters(), 0.5) 316 | self.soft_q_optimizer.step() 317 | 318 | self.value_optimizer.zero_grad() 319 | value_loss.backward() 320 | torch.nn.utils.clip_grad_norm_(self.value_net.parameters(), 0.5) 321 | self.value_optimizer.step() 322 | 323 | self.policy_optimizer.zero_grad() 324 | policy_loss.backward() 325 | torch.nn.utils.clip_grad_norm_(self.value_net.parameters(), 0.5) 326 | self.policy_optimizer.step() 327 | 328 | # Updating the target networks 329 | for target_param, param in zip( 330 | self.target_value_net.parameters(), self.value_net.parameters()): 331 | target_param.data.copy_( 332 | target_param.data * (1.0 - soft_tau) + param.data * soft_tau 333 | ) 334 | 335 | return(q_value_loss.item(), policy_loss.item(), value_loss.item()) 336 | 337 | def __write_csv( 338 | self, 339 | episode, 340 | time_elapsed, 341 | frame_count, 342 | len_replay_buffer, 343 | episode_reward, 344 | value_loss, 345 | q_value_loss, 346 | policy_loss, 347 | step): 348 | """ 349 | Writes data in csv 350 | """ 351 | 352 | with open(os.path.join(self.save_path, 'progress.csv'), 'a') as csvfile: 353 | rew_writer = csv.writer(csvfile, delimiter=';', 354 | quotechar='|', 
quoting=csv.QUOTE_MINIMAL) 355 | 356 | rew_writer.writerow([episode, 357 | time_elapsed, 358 | frame_count, 359 | len(self.replay_buffer), 360 | episode_reward, 361 | value_loss, 362 | q_value_loss, 363 | policy_loss, 364 | step]) 365 | 366 | def __save_model(self, episode, frame_count): 367 | """ 368 | Saves model pickles 369 | 370 | Parameters 371 | ---------- 372 | episode : [int] 373 | Current episode 374 | frame_count : [int] 375 | Current timestep 376 | """ 377 | save_state = { 378 | 'episode': episode, 379 | 'frame_count': frame_count, 380 | 'value_net': self.value_net.state_dict(), 381 | 'target_value_net': self.target_value_net.state_dict(), 382 | 'soft_q_net': self.soft_q_net.state_dict(), 383 | 'policy_net': self.policy_net.state_dict(), 384 | 'value_optimizer': self.value_optimizer.state_dict(), 385 | 'soft_q_optimizer': self.soft_q_optimizer.state_dict(), 386 | 'policy_optimizer': self.policy_optimizer.state_dict(), 387 | 'replay_buffer': self.replay_buffer 388 | } 389 | torch.save(save_state, self.save_path + '/state.pt') 390 | torch.save(save_state['policy_net'], 391 | self.save_path + '/state.pt'[:-3] + '_policy.pt') 392 | print('saving model at = ', self.save_path) 393 | 394 | def _creating_models( 395 | self, 396 | buffer_size, 397 | state_dim, 398 | action_dim, 399 | device, 400 | hidden_dim): 401 | """ 402 | Istantiating the networks and buffer 403 | """ 404 | 405 | self.policy_net = PolicyNetwork( 406 | state_dim, 407 | action_dim, 408 | args.net_size_policy).to(device).type( 409 | torch.double) 410 | self.eval_policy = copy.deepcopy(self.policy_net) 411 | self.replay_buffer = utils.ReplayBuffer(buffer_size) 412 | self.value_net = ValueNetwork( 413 | state_dim, hidden_dim).to(device).type(torch.double) 414 | self.target_value_net = ValueNetwork( 415 | state_dim, hidden_dim).to(device).type(torch.double) 416 | self.soft_q_net = SoftQNetwork( 417 | state_dim, action_dim, hidden_dim).to(device).type(torch.double) 418 | 419 | if self.use_double: 420 | self.policy_net = self.policy_net.type(torch.double) 421 | self.target_value_net = self.target_value_net.type(torch.double) 422 | self.value_net = self.value_net.type(torch.double) 423 | self.replay_buffer.buffer = np.asarray( 424 | self.replay_buffer.buffer).astype(np.float64).tolist() 425 | self.soft_q_net = self.soft_q_net.type(torch.double) 426 | 427 | def train(self): 428 | """ 429 | Trains a continuous policy by means of SAC 430 | """ 431 | # Starting data 432 | episode = 0 433 | frame_count = 0 434 | max_episodes = args.max_episodes 435 | time_horizon = 250 436 | batch_size = args.batch_size 437 | # Load parameters from previous training if available 438 | utils.restore_data(self.restore_path) 439 | 440 | begin = time.time() 441 | while episode < max_episodes: 442 | 443 | if (episode % 50 == 0): # Hack because of PyRep set position bug 444 | self.env.restart = True 445 | self.env.reset() 446 | self.env.restart = False 447 | 448 | state = self.env.reset() 449 | episode_reward = 0 450 | 451 | for step in range(time_horizon): 452 | if frame_count > self.min_num_steps_before_training: 453 | 454 | action = self.policy_net.get_action( 455 | self.state_to_tensor(state)) # .detach() 456 | next_state, reward, done, self.env_info = self.env.step( 457 | action * self.action_range[1]) 458 | 459 | else: 460 | action = np.random.sample(self.action_dim) 461 | next_state, reward, done, self.env_info = self.env.step( 462 | action * self.action_range[1]) 463 | 464 | self.replay_buffer.push( 465 | state, action, reward, next_state, done) 
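                # Update cadence: gradient steps only start once the replay
                # buffer holds more than one batch of transitions; from then on,
                # every `num_steps_until_train`-th episode runs
                # `num_trains_per_step` SAC updates (soft_q_update) per
                # environment step.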
466 | 467 | if len(self.replay_buffer) > batch_size: 468 | if (episode % args.num_steps_until_train) == 0: 469 | for i in range(args.num_trains_per_step): 470 | q_value_loss, policy_loss, value_loss = self.soft_q_update( 471 | batch_size) 472 | 473 | state = next_state 474 | episode_reward += reward 475 | frame_count += 1 476 | 477 | if done: 478 | break 479 | 480 | print("Episode = {0} | Reward = {1:.2f} | Lenght = {2:.2f}".format( 481 | episode, episode_reward, step)) 482 | episode += 1 483 | 484 | # Saving 485 | if (episode % self.save_interval == 0) and (episode > 0): 486 | self.__save_model(episode, frame_count) 487 | 488 | time_elapsed = time.time() - begin 489 | if (episode % 100 == 0) and (episode > 0): 490 | print('Time elapsed so far = {0:.2f} seconds.'.format( 491 | time_elapsed)) 492 | 493 | # Logging 494 | if len(self.replay_buffer) > batch_size: 495 | self.__write_csv(episode, 496 | time_elapsed, 497 | frame_count, 498 | len(self.replay_buffer), 499 | episode_reward, 500 | value_loss, 501 | q_value_loss, 502 | policy_loss, 503 | step) 504 | 505 | 506 | def main(args): 507 | 508 | env = DroneEnv(random=args.env_reset_mode, headless=True, seed=args.seed, 509 | reward_function_name=args.reward_function, state=args.state) 510 | 511 | use_cuda = torch.cuda.is_available() 512 | 513 | if use_cuda and (args.use_cuda): 514 | device = torch.device("cuda") 515 | else: 516 | device = torch.device("cpu") 517 | 518 | # Set save/restore paths 519 | save_path = os.path.join('./checkpoint/', args.save_path) + '/' 520 | 521 | restore_path = args.restore_path or save_path 522 | report_folder = save_path # save_path.split('/')[0] + '/' 523 | 524 | # Check if they exist 525 | utils.check_dir(save_path) 526 | if restore_path: 527 | utils.check_dir(restore_path) 528 | utils.check_dir(report_folder) 529 | 530 | # Preparing log csv 531 | if not os.path.isfile(os.path.join(report_folder, 'progress.csv')): 532 | print('There is no csv there') 533 | with open(os.path.join(report_folder, 'progress.csv'), 'w') as outcsv: 534 | writer = csv.writer(outcsv, delimiter=';', 535 | quotechar='|', quoting=csv.QUOTE_MINIMAL) 536 | writer.writerow(["Episode", "Total time (s)", "Frame", 537 | "Buffer_size", 538 | "Mean_Reward", 539 | "value_loss", "q_value_loss", "policy_loss", 540 | "episode_lenght"]) 541 | 542 | # Network and env parameters 543 | action_dim = env.action_space.shape[0] 544 | try: 545 | state_dim = env.observation_space.shape[0] 546 | except BaseException: 547 | state_dim = env.observation_space 548 | hidden_dim = args.net_size_value 549 | action_range = [env.agent.action_space.low.min( 550 | ), env.agent.action_space.high.max()] 551 | 552 | sac = SAC( 553 | env=env, 554 | replay_buffer_size=args.replay_buffer_size, 555 | hidden_dim=hidden_dim, 556 | restore_path=restore_path, 557 | device=device, 558 | save_path=save_path, 559 | learning_rate=args.learning_rate, 560 | max_episodes=args.max_episodes, 561 | use_double=args.use_double, 562 | save_interval=args.save_interval) 563 | 564 | sac.train() 565 | 566 | 567 | if __name__ == "__main__": 568 | args = argparser() 569 | 570 | # Setting seed 571 | torch.manual_seed(args.seed) 572 | # random.seed(a = seed) 573 | np.random.seed(seed=args.seed) 574 | 575 | main(args) 576 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/networks/__init__.py -------------------------------------------------------------------------------- /networks/structures.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from torch.distributions import Normal, Beta 6 | 7 | # from common.utils import * 8 | 9 | 10 | class ValueNetwork(nn.Module): 11 | """ 12 | A Value V(s) network 13 | 14 | Parameters 15 | ---------- 16 | state_dim : [int] 17 | The observation_space of the environment 18 | hidden_dim : [int] 19 | The latent dimension in the hidden-layers 20 | init_w : [float], optional 21 | Initial weights for the neural network, by default 3e-3 22 | """ 23 | 24 | def __init__(self, state_dim, hidden_dim, init_w=3e-3): 25 | super(ValueNetwork, self).__init__() 26 | 27 | self.linear1 = nn.Linear(state_dim, hidden_dim) 28 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 29 | self.linear3 = nn.Linear(hidden_dim, 1) 30 | 31 | self.linear3.weight.data.uniform_(-init_w, init_w) 32 | self.linear3.bias.data.uniform_(-init_w, init_w) 33 | 34 | def forward(self, state): 35 | """ 36 | Forward-pass of the value net 37 | 38 | Parameters 39 | ---------- 40 | state : [torch.Tensor] 41 | The input state 42 | 43 | Returns 44 | ------- 45 | 46 | The value_hat from being in each state 47 | """ 48 | x = F.relu(self.linear1(state)) 49 | x = F.relu(self.linear2(x)) 50 | x = self.linear3(x) 51 | return x 52 | 53 | 54 | class SoftQNetwork(nn.Module): 55 | """ 56 | A Q(s,a)-function network 57 | 58 | Parameters 59 | ---------- 60 | state_dim : [int] 61 | The observation_space of the environment 62 | action_dim : [int] 63 | The action of the environment 64 | hidden_dim : [int] 65 | The latent dimension in the hidden-layers 66 | init_w : [float], optional 67 | Initial weights for the neural network, by default 3e-3 68 | """ 69 | 70 | def __init__(self, state_dim, action_dim, hidden_size, init_w=3e-3): 71 | super(SoftQNetwork, self).__init__() 72 | 73 | self.linear1 = nn.Linear(state_dim + action_dim, hidden_size) 74 | self.linear2 = nn.Linear(hidden_size, hidden_size) 75 | self.linear3 = nn.Linear(hidden_size, 1) 76 | 77 | self.linear3.weight.data.uniform_(-init_w, init_w) 78 | self.linear3.bias.data.uniform_(-init_w, init_w) 79 | 80 | def forward(self, state, action): 81 | """ 82 | Forward-pass of the q-value net 83 | 84 | Parameters 85 | ---------- 86 | action : [torch.Tensor] 87 | The input action 88 | state : [torch.Tensor] 89 | The input state 90 | Returns 91 | ------- 92 | Q(s,v) q-value 93 | """ 94 | x = torch.cat([state, action], 1) 95 | x = F.relu(self.linear1(x)) 96 | x = F.relu(self.linear2(x)) 97 | x = self.linear3(x) 98 | return x 99 | 100 | 101 | class PolicyNetwork(nn.Module): 102 | """ 103 | The policy network for implementing SAC 104 | 105 | Parameters 106 | ---------- 107 | state_dim : [int] 108 | The observation_space of the environment 109 | action_dim : [int] 110 | The action of the environment 111 | hidden_dim : [int] 112 | The latent dimension in the hidden-layers 113 | init_w : [float], optional 114 | Initial weights for the neural network, by default 3e-3 115 | log_std_min : int, optional 116 | Min possible value for policy log_std, by default -20 117 | log_std_max : int, optional 118 | Max possible value for policy log_std, by default 2 119 | activation_function : , optional 120 | Name of the 
activation function 121 | 122 | """ 123 | 124 | def __init__( 125 | self, 126 | state_dim, 127 | num_actions, 128 | hidden_size, 129 | init_w=3e-3, 130 | log_std_min=-20, 131 | log_std_max=2, 132 | activation_function=F.relu): 133 | 134 | super(PolicyNetwork, self).__init__() 135 | 136 | self.log_std_min = log_std_min 137 | self.log_std_max = log_std_max 138 | 139 | self.linear1 = nn.Linear(state_dim, hidden_size) 140 | self.linear2 = nn.Linear(hidden_size, hidden_size) 141 | 142 | self.mean_linear = nn.Linear(hidden_size, num_actions) 143 | self.mean_linear.weight.data.uniform_(-init_w, init_w) 144 | self.mean_linear.bias.data.uniform_(-init_w, init_w) 145 | 146 | self.log_std_linear = nn.Linear(hidden_size, num_actions) 147 | self.log_std_linear.weight.data.uniform_(-init_w, init_w) 148 | self.log_std_linear.bias.data.uniform_(-init_w, init_w) 149 | 150 | self.activation_function = activation_function 151 | 152 | def forward(self, state,): 153 | """ 154 | Policy forward-pass 155 | 156 | Parameters 157 | ---------- 158 | state : [torch.Tensor] 159 | The input state 160 | 161 | Returns 162 | [torch.Tensor] - action to be taken 163 | ------- 164 | """ 165 | x = self.activation_function(self.linear1(state)) 166 | 167 | x = self.activation_function(self.linear2(x)) 168 | 169 | mean = self.mean_linear(x) 170 | log_std = self.log_std_linear(x) 171 | log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) 172 | 173 | return mean, log_std 174 | 175 | def evaluate(self, state, epsilon=1e-6): 176 | """ 177 | Calculates log_prob and squashes the action 178 | 179 | Parameters 180 | ---------- 181 | state : [torch.Tensor] 182 | The input state 183 | 184 | Returns 185 | ------- 186 | squashed_action, log_prob, raw_action, policy_mean, policy_log_std 187 | """ 188 | mean, log_std = self.forward(state) 189 | std = log_std.exp() 190 | 191 | normal = Normal(mean, std) 192 | z = normal.sample() # Add reparam trick? 
193 | action = torch.tanh(z) 194 | 195 | # - np.log(self.action_range) See gist https://github.com/quantumiracle/SOTA-RL-Algorithms/blob/master/sac_v2.py 196 | log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon) 197 | # log_prob = normal.log_prob(z) - torch.log(torch.clamp(1 - action.pow(2), 198 | # min=0,max=1) + epsilon) # nao precisa por causa do squase tanh 199 | 200 | log_prob = log_prob.sum(-1, keepdim=True) 201 | 202 | return action, log_prob, z, mean, log_std 203 | 204 | def get_action(self, state): 205 | """ 206 | Return stochastic action, without calculating log_prob 207 | 208 | Parameters 209 | ---------- 210 | state : [torch.Tensor] 211 | The input state 212 | 213 | Returns 214 | ------- 215 | squashed_action: 216 | Action after tanh 217 | """ 218 | mean, log_std = self.forward(state) 219 | std = log_std.exp() 220 | 221 | normal = Normal(mean, std) 222 | z = normal.sample() 223 | 224 | action = torch.tanh(z) 225 | 226 | action = action.detach().cpu().numpy() 227 | return action[0] 228 | 229 | def deterministic_action(self, state): 230 | """ 231 | Return deterministic action, without calculating log_prob 232 | 233 | Parameters 234 | ---------- 235 | state : [torch.Tensor] 236 | The input state 237 | 238 | Returns 239 | ------- 240 | squashed_action: 241 | Action after tanh 242 | """ 243 | mean, log_std = self.forward(state) 244 | action = torch.tanh(mean) 245 | 246 | action = action.detach().cpu().numpy() 247 | return action[0] 248 | -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | ## Figures 2 | 3 | **Initial positions for the UAV agent** 4 | 5 | 6 | 7 | ### Pos x 8 | 9 |

10 | 11 |

12 | 13 | ### Pos y 14 | 15 |

16 | 17 |

18 | 19 | 20 | ### Pos z 21 | 22 | 23 |

24 | 25 |

26 | 27 | ### Ang x 28 | 29 | 30 |

31 | 32 |

33 | 34 | ### Ang y 35 | 36 | 37 |

38 | 39 |

40 | 41 | ### Ang z 42 | 43 | 44 |

45 | 46 |

47 | 48 | 49 | ### Agent actions (Propeller inputs) 50 | 51 | 

53 | 54 |

55 | 56 | -------------------------------------------------------------------------------- /notebooks/Robustness.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Imports" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import sys\n", 17 | "sys.path.append('../')\n", 18 | "\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,\n", 21 | " AutoMinorLocator, LinearLocator)\n", 22 | "from matplotlib.ticker import ScalarFormatter\n", 23 | "import matplotlib.ticker as mtick\n", 24 | "from collections import OrderedDict\n", 25 | "import pandas as pd\n", 26 | "import numpy as np\n", 27 | "# import dill\n", 28 | "import argparse\n", 29 | "import pyrep.backend.sim as sim\n", 30 | "from networks.structures import PolicyNetwork, ValueNetwork, SoftQNetwork\n", 31 | "import torch\n", 32 | "\n", 33 | "from sim_framework.envs.drone_env import DroneEnv\n", 34 | "import time\n", 35 | "import itertools\n", 36 | "import random\n", 37 | "%matplotlib inline" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "# Utils" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "def custom_reset(env, variant):\n", 54 | "# env.pr.stop()\n", 55 | " env.current_action = np.array([0,0,0,0])\n", 56 | " \n", 57 | "# env.agent.set_position(np.round([0,0,1.7],2).tolist())\n", 58 | "# env.agent.set_orientation(np.round([0,0,0],2).tolist())\n", 59 | "\n", 60 | " env.agent.set_orientation(np.round(variant[3:],2).tolist())\n", 61 | " env.agent.set_position(np.round(variant[:3],2).tolist())\n", 62 | "\n", 63 | " env.agent.set_thrust_and_torque(np.asarray([0.] 
* 4), force_zero=True)\n", 64 | " env.agent.set_joint_positions(env.initial_joint_positions)\n", 65 | " env.agent.set_joint_target_velocities(env.initial_joint_velocities)\n", 66 | " env.agent.set_joint_target_positions(env.initial_joint_target_positions)\n", 67 | " \n", 68 | "# env.agent.set_orientation(np.round([0,0,0],2).tolist())\n", 69 | " \n", 70 | "\n", 71 | " env.first_obs=True\n", 72 | " env._make_observation()\n", 73 | " env.last_state = env.observation[:18] \n", 74 | "# env.pr.start()\n", 75 | " \n", 76 | " return env.observation" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "\n", 86 | "def state_to_tensor(state, device):\n", 87 | " \"\"\"Transform numpy array to torch tensor\"\"\"\n", 88 | " if args.use_double:\n", 89 | " return torch.DoubleTensor(state).unsqueeze(0).to(device)\n", 90 | " else:\n", 91 | " return torch.FloatTensor(state).unsqueeze(0).to(device)\n", 92 | "\n" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 17, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "\n", 102 | "\n", 103 | "def rollouts(\n", 104 | " variant,\n", 105 | " env,\n", 106 | " policy,\n", 107 | " action_range,\n", 108 | " device,\n", 109 | " max_timesteps=1000,\n", 110 | " time_horizon=250):\n", 111 | " \"\"\"\n", 112 | " Perform policy rollouts until a max given number of steps\n", 113 | "\n", 114 | " Parameters\n", 115 | " ----------\n", 116 | " env :\n", 117 | " A larocs_sim environment\n", 118 | " policy :\n", 119 | " An actor-policy for the agent act in the environment\n", 120 | " action_range : list\n", 121 | " Range of possible float values for the action\n", 122 | " max_timesteps : int, optional\n", 123 | " Number of timesteps to perform while interacting with the environment, by default 1000\n", 124 | " time_horizon : int, optional\n", 125 | " The number of steps for each episode, by default 250\n", 126 | "\n", 127 | " \"\"\"\n", 128 | " count = 0\n", 129 | " dones = False\n", 130 | " set_of_obs, set_of_next_obs, set_of_rewards, set_of_actions, set_of_dones, set_of_infos = [], [], [], [], [], []\n", 131 | "\n", 132 | " rollout = -1\n", 133 | "\n", 134 | " mb_obs, mb_next_obs, mb_rewards, mb_actions, mb_dones, mb_infos = [], [], [], [], [], []\n", 135 | " rollout += 1\n", 136 | "\n", 137 | "# obs0 = env.reset()\n", 138 | " obs0 = custom_reset(env,variant)\n", 139 | "\n", 140 | "\n", 141 | " for j in range(time_horizon):\n", 142 | " dones = False\n", 143 | " try:\n", 144 | " actions, agent_info = policy.deterministic_action(\n", 145 | " state_to_tensor(obs0, device))\n", 146 | " except:\n", 147 | " actions = policy.deterministic_action(\n", 148 | " state_to_tensor(obs0, device))\n", 149 | "\n", 150 | " # Take actions in env and look the results\n", 151 | " obs1, rewards, dones, infos = env.step(actions * action_range[1])\n", 152 | " # Append on the experience buffers\n", 153 | " mb_obs.append(obs0)\n", 154 | " # mb_obs.append(obs0)\n", 155 | " mb_next_obs.append(obs1)\n", 156 | " mb_actions.append(actions)\n", 157 | " mb_dones.append(dones)\n", 158 | " mb_rewards.append(rewards)\n", 159 | " mb_infos.append(infos)\n", 160 | "\n", 161 | " count += 1\n", 162 | " if dones:\n", 163 | " break\n", 164 | "\n", 165 | " obs0 = obs1\n", 166 | "# print()\n", 167 | "# print('rewards: mean = {0}'.format(np.mean(mb_rewards)))\n", 168 | " print('rewards: sum = {0}'.format(np.sum(mb_rewards)))\n", 169 | "\n", 170 | " set_of_obs.append(mb_obs)\n", 171 | " 
set_of_next_obs.append(mb_next_obs)\n", 172 | " set_of_rewards.append(mb_rewards)\n", 173 | " set_of_actions.append(mb_actions)\n", 174 | " set_of_dones.append(mb_dones)\n", 175 | " set_of_infos.append(mb_infos)\n", 176 | "\n", 177 | " set_tau = {'obs': set_of_obs,\n", 178 | " 'next_obs': set_of_next_obs,\n", 179 | " 'rewards': set_of_rewards,\n", 180 | " 'actions': set_of_actions,\n", 181 | " 'dones': set_of_dones,\n", 182 | " 'infos': set_of_infos}\n", 183 | " return set_tau" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "# Variables" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 28, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "class Args():\n", 200 | " def __init__(self):\n", 201 | " pass\n", 202 | "\n", 203 | "\n", 204 | "args = Args()\n", 205 | "\n", 206 | "args.H = 250\n", 207 | "\n", 208 | "args.max_timesteps=250\n", 209 | "\n", 210 | "\n", 211 | "env_reset_mode = \"Discretized_Uniform\"\n", 212 | "seed = 42\n", 213 | "headless = True\n", 214 | "# headless = False\n", 215 | "\n", 216 | "state='New_action'\n", 217 | "reward='Normal'\n", 218 | "try:\n", 219 | " env.shutdown()\n", 220 | "except:\n", 221 | " pass\n", 222 | "env = DroneEnv(random=env_reset_mode,seed=seed, headless = headless, state=state,\\\n", 223 | " reward_function_name=reward)\n" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": {}, 229 | "source": [ 230 | "# Permutation List" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 11, 236 | "metadata": {}, 237 | "outputs": [ 238 | { 239 | "name": "stdout", 240 | "output_type": "stream", 241 | "text": [ 242 | "Number of different variants\n", 243 | "216\n" 244 | ] 245 | } 246 | ], 247 | "source": [ 248 | "## Ticks for Discretized_Uniform initialization setup\n", 249 | "xy_ticks = env.x_y_ticks\n", 250 | "z_ticks = env.z_ticks\n", 251 | "ang_ticks = env.ang_ticks\n", 252 | "\n", 253 | "extreme_angles = ang_ticks[[0,-1]]\n", 254 | "extreme_angles = np.append(extreme_angles, 0)\n", 255 | "extreme_xyticks = xy_ticks[[0,-1]]\n", 256 | "extreme_zticks = z_ticks[[0,-1]]\n", 257 | "\n", 258 | "\n", 259 | "\n", 260 | "all_list = [extreme_xyticks,extreme_xyticks, np.round( extreme_zticks,2), extreme_angles,extreme_angles,extreme_angles]\n", 261 | "\n", 262 | "\n", 263 | "res = list(itertools.product(*all_list)) \n", 264 | "random.shuffle(res)\n", 265 | "print('Number of different variants')\n", 266 | "print(len(res))" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": {}, 272 | "source": [ 273 | "# Running" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 24, 279 | "metadata": {}, 280 | "outputs": [ 281 | { 282 | "name": "stdout", 283 | "output_type": "stream", 284 | "text": [ 285 | "Finished Loading\n" 286 | ] 287 | } 288 | ], 289 | "source": [ 290 | "args.use_double=False\n", 291 | "args.use_cuda=False\n", 292 | "\n", 293 | "use_cuda = torch.cuda.is_available()\n", 294 | "\n", 295 | "if use_cuda and (args.use_cuda):\n", 296 | " device = torch.device(\"cuda\")\n", 297 | "else:\n", 298 | " device = torch.device(\"cpu\")\n", 299 | "\n", 300 | "\n", 301 | "restore_path = '../saved_policies/sac_optimal_policy_2.pt'\n", 302 | "\n", 303 | "try:\n", 304 | " checkpoint = torch.load(restore_path, map_location='cpu')\n", 305 | "except BaseException:\n", 306 | " checkpoint = torch.load(restore_path, map_location=torch.device('cpu'))\n", 307 | "print('Finished Loading')\n", 308 | "\n", 
309 | "# Neural network parameters\n", 310 | "try:\n", 311 | " state_dim = env.observation_space.shape[0]\n", 312 | "except BaseException:\n", 313 | " state_dim = env.observation_space\n", 314 | "action_dim = env.action_space.shape[0]\n", 315 | "hidden_dim = checkpoint['linear1.weight'].data.shape[0]\n", 316 | "action_range = [env.agent.action_space.low.min(\n", 317 | "), env.agent.action_space.high.max()]\n", 318 | "size_obs = checkpoint['linear1.weight'].data.shape[1]\n", 319 | "\n", 320 | "assert size_obs == state_dim, 'Checkpoint state must be the same as the env'\n", 321 | "\n", 322 | "\n", 323 | "policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim).to(device)\n" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 25, 329 | "metadata": {}, 330 | "outputs": [ 331 | { 332 | "name": "stdout", 333 | "output_type": "stream", 334 | "text": [ 335 | "Finished Loading the weights\n" 336 | ] 337 | } 338 | ], 339 | "source": [ 340 | "# Loading Models\n", 341 | "policy_net.load_state_dict(checkpoint)\n", 342 | "print('Finished Loading the weights')\n" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": 19, 348 | "metadata": {}, 349 | "outputs": [ 350 | { 351 | "name": "stdout", 352 | "output_type": "stream", 353 | "text": [ 354 | "0\n", 355 | "rewards: sum = 877.1200845534663\n", 356 | "1\n", 357 | "rewards: sum = 869.9280922522692\n", 358 | "2\n", 359 | "rewards: sum = 891.6576752246817\n", 360 | "3\n", 361 | "rewards: sum = 883.0101529152173\n", 362 | "4\n", 363 | "rewards: sum = 896.3264992912492\n", 364 | "5\n", 365 | "rewards: sum = 874.0152384834564\n", 366 | "6\n", 367 | "rewards: sum = 900.5373236322093\n", 368 | "7\n", 369 | "rewards: sum = 899.9096966483784\n", 370 | "8\n", 371 | "rewards: sum = 894.9760864179252\n", 372 | "9\n", 373 | "rewards: sum = 886.492229917679\n", 374 | "10\n", 375 | "rewards: sum = 901.0729021115104\n", 376 | "11\n", 377 | "rewards: sum = 887.2421559703803\n", 378 | "12\n", 379 | "rewards: sum = 889.0111172262613\n", 380 | "13\n", 381 | "rewards: sum = 860.6523525027959\n", 382 | "14\n", 383 | "rewards: sum = 862.3917533988092\n", 384 | "15\n", 385 | "rewards: sum = 888.6664420551124\n", 386 | "16\n", 387 | "rewards: sum = 904.3786044462072\n", 388 | "17\n", 389 | "rewards: sum = 897.1422578799277\n", 390 | "18\n", 391 | "rewards: sum = 904.8975850526614\n", 392 | "19\n", 393 | "rewards: sum = 878.4370341487543\n", 394 | "20\n", 395 | "rewards: sum = 896.7519630918687\n", 396 | "21\n", 397 | "rewards: sum = 871.8436541283222\n", 398 | "22\n", 399 | "rewards: sum = 878.558593293236\n", 400 | "23\n", 401 | "rewards: sum = 904.699504927172\n", 402 | "24\n", 403 | "rewards: sum = 880.1194983128975\n", 404 | "25\n", 405 | "rewards: sum = 878.523675280956\n", 406 | "26\n", 407 | "rewards: sum = 882.9461452790712\n", 408 | "27\n", 409 | "rewards: sum = 863.4715386745926\n", 410 | "28\n", 411 | "rewards: sum = 899.7397506765387\n", 412 | "29\n", 413 | "rewards: sum = 875.8037820009756\n", 414 | "30\n", 415 | "rewards: sum = 898.0817699817233\n", 416 | "31\n", 417 | "rewards: sum = 885.0298235842963\n", 418 | "32\n", 419 | "rewards: sum = 860.603878460849\n", 420 | "33\n", 421 | "rewards: sum = 899.1610794915347\n", 422 | "34\n", 423 | "rewards: sum = 907.2296450942133\n", 424 | "35\n", 425 | "rewards: sum = 884.4922844950675\n", 426 | "36\n", 427 | "rewards: sum = 887.118100912087\n", 428 | "37\n", 429 | "rewards: sum = 899.6497850689009\n", 430 | "38\n", 431 | "rewards: sum = 874.040763531865\n", 432 | "39\n", 
433 | "rewards: sum = 877.4780777957269\n", 434 | "40\n", 435 | "rewards: sum = 865.3159603220156\n", 436 | "41\n", 437 | "rewards: sum = 890.8162485466335\n", 438 | "42\n", 439 | "rewards: sum = 894.398685175023\n", 440 | "43\n", 441 | "rewards: sum = 896.7292923725229\n", 442 | "44\n", 443 | "rewards: sum = 874.3684217745333\n", 444 | "45\n", 445 | "rewards: sum = 890.0497513913326\n", 446 | "46\n", 447 | "rewards: sum = 888.3588113783471\n", 448 | "47\n", 449 | "rewards: sum = 897.3101298101333\n", 450 | "48\n", 451 | "rewards: sum = 903.720273688921\n", 452 | "49\n", 453 | "rewards: sum = 878.1758214536835\n", 454 | "50\n", 455 | "rewards: sum = 891.54218835386\n", 456 | "51\n", 457 | "rewards: sum = 887.3333606750956\n", 458 | "52\n", 459 | "rewards: sum = 896.4733332860587\n", 460 | "53\n", 461 | "rewards: sum = 896.2428940994746\n", 462 | "54\n", 463 | "rewards: sum = 872.0980191229746\n", 464 | "55\n", 465 | "rewards: sum = 895.2377687164615\n", 466 | "56\n", 467 | "rewards: sum = 900.2512609809758\n", 468 | "57\n", 469 | "rewards: sum = 883.9459429163135\n", 470 | "58\n", 471 | "rewards: sum = 881.1159313309353\n", 472 | "59\n", 473 | "rewards: sum = 905.6991006539256\n", 474 | "60\n", 475 | "rewards: sum = 892.1217530806359\n", 476 | "61\n", 477 | "rewards: sum = 897.8970661355372\n", 478 | "62\n", 479 | "rewards: sum = 898.0981787532521\n", 480 | "63\n", 481 | "rewards: sum = 880.1850117657287\n", 482 | "64\n", 483 | "rewards: sum = 877.8151583108748\n", 484 | "65\n", 485 | "rewards: sum = 888.6726378395499\n", 486 | "66\n", 487 | "rewards: sum = 889.7889131081004\n", 488 | "67\n", 489 | "rewards: sum = 875.4889509668368\n", 490 | "68\n", 491 | "rewards: sum = 874.8845105642113\n", 492 | "69\n", 493 | "rewards: sum = 908.5999890390235\n", 494 | "70\n", 495 | "rewards: sum = 880.1802331306792\n", 496 | "71\n", 497 | "rewards: sum = 864.0250776596363\n", 498 | "72\n", 499 | "rewards: sum = 882.3160873384968\n", 500 | "73\n", 501 | "rewards: sum = 900.0099752908798\n", 502 | "74\n", 503 | "rewards: sum = 859.1669384718923\n", 504 | "75\n", 505 | "rewards: sum = 870.4063612971073\n", 506 | "76\n", 507 | "rewards: sum = 892.0235208870704\n", 508 | "77\n", 509 | "rewards: sum = 864.4731059839928\n", 510 | "78\n", 511 | "rewards: sum = 892.9096124655371\n", 512 | "79\n", 513 | "rewards: sum = 898.7552933317618\n", 514 | "80\n", 515 | "rewards: sum = 879.8477111625366\n", 516 | "81\n", 517 | "rewards: sum = 874.9483665867407\n", 518 | "82\n", 519 | "rewards: sum = 890.1741272239042\n", 520 | "83\n", 521 | "rewards: sum = 888.5680966989888\n", 522 | "84\n", 523 | "rewards: sum = 887.394871781398\n", 524 | "85\n", 525 | "rewards: sum = 869.6367722934792\n", 526 | "86\n", 527 | "rewards: sum = 907.0537672727554\n", 528 | "87\n", 529 | "rewards: sum = 903.4620615204284\n", 530 | "88\n", 531 | "rewards: sum = 895.4471779789939\n", 532 | "89\n", 533 | "rewards: sum = 905.2972526552121\n", 534 | "90\n", 535 | "rewards: sum = 888.7805656164828\n", 536 | "91\n", 537 | "rewards: sum = 889.2913439626157\n", 538 | "92\n", 539 | "rewards: sum = 889.3228466167005\n", 540 | "93\n", 541 | "rewards: sum = 891.6299172011553\n", 542 | "94\n", 543 | "rewards: sum = 877.0072429058387\n", 544 | "95\n", 545 | "rewards: sum = 886.1487450326917\n", 546 | "96\n", 547 | "rewards: sum = 904.2764247031805\n", 548 | "97\n", 549 | "rewards: sum = 867.9149661258944\n", 550 | "98\n", 551 | "rewards: sum = 861.3714723129583\n", 552 | "99\n", 553 | "rewards: sum = 886.2550257673836\n", 554 | "100\n", 555 | "rewards: sum 
= 905.1098989330309\n", 556 | "101\n", 557 | "rewards: sum = 868.7850588059113\n", 558 | "102\n", 559 | "rewards: sum = 896.1657750931473\n", 560 | "103\n", 561 | "rewards: sum = 901.9382654078967\n", 562 | "104\n", 563 | "rewards: sum = 905.9356892658475\n", 564 | "105\n", 565 | "rewards: sum = 892.8438508915797\n", 566 | "106\n", 567 | "rewards: sum = 865.3958650772488\n", 568 | "107\n", 569 | "rewards: sum = 880.1571674100767\n", 570 | "108\n", 571 | "rewards: sum = 887.6338052695636\n", 572 | "109\n", 573 | "rewards: sum = 869.3132076379745\n", 574 | "110\n", 575 | "rewards: sum = 885.605521167952\n", 576 | "111\n", 577 | "rewards: sum = 904.648925306919\n", 578 | "112\n", 579 | "rewards: sum = 876.8297294798848\n", 580 | "113\n", 581 | "rewards: sum = 875.9425116335061\n", 582 | "114\n", 583 | "rewards: sum = 859.3138985533194\n", 584 | "115\n", 585 | "rewards: sum = 896.0121033288492\n", 586 | "116\n", 587 | "rewards: sum = 883.3270857664729\n", 588 | "117\n", 589 | "rewards: sum = 885.2384869879169\n", 590 | "118\n", 591 | "rewards: sum = 905.4902121118921\n", 592 | "119\n", 593 | "rewards: sum = 874.5613911275354\n", 594 | "120\n", 595 | "rewards: sum = 900.173096402854\n", 596 | "121\n", 597 | "rewards: sum = 906.6125687178132\n", 598 | "122\n", 599 | "rewards: sum = 869.8399541657919\n", 600 | "123\n", 601 | "rewards: sum = 897.3525655404424\n", 602 | "124\n", 603 | "rewards: sum = 882.7358544631047\n", 604 | "125\n", 605 | "rewards: sum = 880.1140476267295\n", 606 | "126\n", 607 | "rewards: sum = 867.8751980722742\n", 608 | "127\n", 609 | "rewards: sum = 886.3557792963463\n", 610 | "128\n", 611 | "rewards: sum = 888.3169191519439\n", 612 | "129\n", 613 | "rewards: sum = 895.0608496324099\n", 614 | "130\n", 615 | "rewards: sum = 874.2996137935713\n", 616 | "131\n", 617 | "rewards: sum = 866.5415474741133\n", 618 | "132\n", 619 | "rewards: sum = 896.9772065911468\n", 620 | "133\n", 621 | "rewards: sum = 908.1953438831422\n", 622 | "134\n", 623 | "rewards: sum = 877.4911076870802\n", 624 | "135\n", 625 | "rewards: sum = 871.8277994381117\n", 626 | "136\n", 627 | "rewards: sum = 897.7919084211632\n", 628 | "137\n", 629 | "rewards: sum = 882.0196374989864\n", 630 | "138\n", 631 | "rewards: sum = 901.1230401159547\n", 632 | "139\n", 633 | "rewards: sum = 894.7401410247085\n", 634 | "140\n", 635 | "rewards: sum = 875.4262999174857\n", 636 | "141\n", 637 | "rewards: sum = 868.940856199061\n", 638 | "142\n", 639 | "rewards: sum = 873.8985469960013\n", 640 | "143\n", 641 | "rewards: sum = 877.176272228336\n", 642 | "144\n", 643 | "rewards: sum = 874.6807737344345\n", 644 | "145\n", 645 | "rewards: sum = 880.6095389342806\n", 646 | "146\n", 647 | "rewards: sum = 899.5998276260364\n", 648 | "147\n", 649 | "rewards: sum = 904.6259297211284\n", 650 | "148\n", 651 | "rewards: sum = 891.537499132633\n", 652 | "149\n", 653 | "rewards: sum = 885.4084181786037\n", 654 | "150\n", 655 | "rewards: sum = 878.3407669313056\n", 656 | "151\n", 657 | "rewards: sum = 884.2035327490833\n", 658 | "152\n", 659 | "rewards: sum = 880.3855661686064\n", 660 | "153\n", 661 | "rewards: sum = 873.3097975933238\n", 662 | "154\n", 663 | "rewards: sum = 886.8109559181175\n", 664 | "155\n", 665 | "rewards: sum = 880.4762756123328\n", 666 | "156\n", 667 | "rewards: sum = 882.7655801814049\n", 668 | "157\n", 669 | "rewards: sum = 877.2979047255722\n", 670 | "158\n", 671 | "rewards: sum = 892.6267489700444\n", 672 | "159\n", 673 | "rewards: sum = 894.4956217047243\n", 674 | "160\n", 675 | "rewards: sum = 
899.271683824795\n", 676 | "161\n", 677 | "rewards: sum = 903.3672494783788\n", 678 | "162\n", 679 | "rewards: sum = 900.2672554117383\n", 680 | "163\n", 681 | "rewards: sum = 896.8892139879791\n", 682 | "164\n", 683 | "rewards: sum = 880.6624677497364\n", 684 | "165\n", 685 | "rewards: sum = 880.0700071427483\n", 686 | "166\n", 687 | "rewards: sum = 889.988926414548\n", 688 | "167\n", 689 | "rewards: sum = 897.9832567074013\n", 690 | "168\n", 691 | "rewards: sum = 891.6839192903972\n", 692 | "169\n", 693 | "rewards: sum = 889.0628301191881\n", 694 | "170\n", 695 | "rewards: sum = 871.0577941564181\n", 696 | "171\n", 697 | "rewards: sum = 871.8166329747819\n", 698 | "172\n", 699 | "rewards: sum = 907.5437227034581\n", 700 | "173\n", 701 | "rewards: sum = 881.4289658824617\n", 702 | "174\n", 703 | "rewards: sum = 879.0560142258798\n", 704 | "175\n", 705 | "rewards: sum = 889.5663122836002\n", 706 | "176\n", 707 | "rewards: sum = 903.6944064497113\n", 708 | "177\n", 709 | "rewards: sum = 885.3211531556096\n", 710 | "178\n", 711 | "rewards: sum = 905.7183268704994\n", 712 | "179\n", 713 | "rewards: sum = 880.6149473825969\n", 714 | "180\n", 715 | "rewards: sum = 896.3333512810068\n", 716 | "181\n", 717 | "rewards: sum = 886.8225148284394\n", 718 | "182\n", 719 | "rewards: sum = 897.9333795506757\n", 720 | "183\n", 721 | "rewards: sum = 896.8336241145724\n", 722 | "184\n", 723 | "rewards: sum = 892.4037890477107\n", 724 | "185\n", 725 | "rewards: sum = 873.0821395156672\n", 726 | "186\n", 727 | "rewards: sum = 858.7041658727724\n", 728 | "187\n", 729 | "rewards: sum = 864.8203602415979\n", 730 | "188\n", 731 | "rewards: sum = 884.2697529649645\n", 732 | "189\n", 733 | "rewards: sum = 881.4868743841479\n", 734 | "190\n", 735 | "rewards: sum = 894.3075156396611\n", 736 | "191\n", 737 | "rewards: sum = 882.8933127460814\n", 738 | "192\n", 739 | "rewards: sum = 857.7496610625499\n", 740 | "193\n", 741 | "rewards: sum = 870.3293932168983\n", 742 | "194\n", 743 | "rewards: sum = 908.3577337424588\n", 744 | "195\n", 745 | "rewards: sum = 899.3688506135675\n", 746 | "196\n", 747 | "rewards: sum = 873.5455914068026\n", 748 | "197\n", 749 | "rewards: sum = 869.7594461471712\n", 750 | "198\n", 751 | "rewards: sum = 897.6000464546166\n", 752 | "199\n", 753 | "rewards: sum = 882.9467139045704\n", 754 | "200\n", 755 | "rewards: sum = 895.7135742065373\n", 756 | "201\n", 757 | "rewards: sum = 874.6181235762783\n", 758 | "202\n", 759 | "rewards: sum = 882.1051917323938\n", 760 | "203\n", 761 | "rewards: sum = 871.6068109441206\n", 762 | "204\n", 763 | "rewards: sum = 887.7014972571806\n", 764 | "205\n", 765 | "rewards: sum = 898.7713085530722\n", 766 | "206\n", 767 | "rewards: sum = 865.0594020336572\n", 768 | "207\n", 769 | "rewards: sum = 879.7226768549124\n", 770 | "208\n", 771 | "rewards: sum = 858.2652145960723\n", 772 | "209\n", 773 | "rewards: sum = 883.5465473106078\n", 774 | "210\n", 775 | "rewards: sum = 889.4173072025488\n", 776 | "211\n", 777 | "rewards: sum = 897.0758127364392\n", 778 | "212\n", 779 | "rewards: sum = 874.7689052086174\n", 780 | "213\n", 781 | "rewards: sum = 906.4591840207333\n", 782 | "214\n", 783 | "rewards: sum = 888.1710793771435\n", 784 | "215\n", 785 | "rewards: sum = 875.3350440186033\n", 786 | "Time = 325.42\n" 787 | ] 788 | } 789 | ], 790 | "source": [ 791 | "list_of_numsteps=[]\n", 792 | "list_of_rewards = []\n", 793 | "list_of_variants = []\n", 794 | "begin = time.time()\n", 795 | "for k, variant in enumerate(res):\n", 796 | " print(k)\n", 797 | " set_tau = 
rollouts(\n", 798 | " variant,\n", 799 | " env,\n", 800 | " policy_net,\n", 801 | " action_range,\n", 802 | " device,\n", 803 | " max_timesteps=args.max_timesteps,\n", 804 | " time_horizon=args.H)\n", 805 | " list_of_numsteps.append(len(set_tau['obs'][0]))\n", 806 | " list_of_rewards.append(np.sum(set_tau['rewards'][0]))\n", 807 | " list_of_variants.append(variant)\n", 808 | "end = time.time()\n", 809 | "print(\"Time = {0:.2f}\".format(end-begin))" 810 | ] 811 | }, 812 | { 813 | "cell_type": "code", 814 | "execution_count": 20, 815 | "metadata": {}, 816 | "outputs": [ 817 | { 818 | "name": "stdout", 819 | "output_type": "stream", 820 | "text": [ 821 | "Percentage of successful runs\n", 822 | "100.0\n" 823 | ] 824 | }, 825 | { 826 | "data": { 827 | "text/html": [ 828 | "
\n", 829 | "\n", 842 | "\n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | " \n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 | " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " \n", 864 | " \n", 865 | " \n", 866 | " \n", 867 | " \n", 868 | " \n", 869 | "
countmeanstdmin25%50%75%max
Reward216.0886.13564512.723838857.749661876.962865886.816735896.847522908.599989
\n", 870 | "
" 871 | ], 872 | "text/plain": [ 873 | " count mean std min 25% 50% \\\n", 874 | "Reward 216.0 886.135645 12.723838 857.749661 876.962865 886.816735 \n", 875 | "\n", 876 | " 75% max \n", 877 | "Reward 896.847522 908.599989 " 878 | ] 879 | }, 880 | "metadata": {}, 881 | "output_type": "display_data" 882 | }, 883 | { 884 | "name": "stdout", 885 | "output_type": "stream", 886 | "text": [ 887 | "Reward 886.816735\n", 888 | "dtype: float64\n" 889 | ] 890 | } 891 | ], 892 | "source": [ 893 | "df = pd.DataFrame({'Variant': list_of_variants, \"Reward\" : list_of_rewards, 'Len' : list_of_numsteps})\n", 894 | "\n", 895 | "print('Percentage of successful runs')\n", 896 | "print((1 - len(df[df['Len'] < 250])/len(df))*100)\n", 897 | "\n", 898 | "display(df[['Reward']].describe().T)\n", 899 | "\n", 900 | "print(df[['Reward']].median())" 901 | ] 902 | }, 903 | { 904 | "cell_type": "markdown", 905 | "metadata": {}, 906 | "source": [ 907 | "# Anedoctal Run" 908 | ] 909 | }, 910 | { 911 | "cell_type": "code", 912 | "execution_count": 27, 913 | "metadata": {}, 914 | "outputs": [ 915 | { 916 | "name": "stdout", 917 | "output_type": "stream", 918 | "text": [ 919 | "rewards: sum = 842.9171796620125\n" 920 | ] 921 | } 922 | ], 923 | "source": [ 924 | "## Remember to chance headless flag!\n", 925 | "env.pr.stop()\n", 926 | "env.pr.start()\n", 927 | "\n", 928 | "for variant in [[0,0,6,0,1.7,0]]:\n", 929 | " set_tau = rollouts(\n", 930 | " variant,\n", 931 | " env,\n", 932 | " policy_net,\n", 933 | " action_range,\n", 934 | " device,\n", 935 | " max_timesteps=args.max_timesteps,\n", 936 | " time_horizon=args.H)" 937 | ] 938 | } 939 | ], 940 | "metadata": { 941 | "kernelspec": { 942 | "display_name": "Python 3", 943 | "language": "python", 944 | "name": "python3" 945 | }, 946 | "language_info": { 947 | "codemirror_mode": { 948 | "name": "ipython", 949 | "version": 3 950 | }, 951 | "file_extension": ".py", 952 | "mimetype": "text/x-python", 953 | "name": "python", 954 | "nbconvert_exporter": "python", 955 | "pygments_lexer": "ipython3", 956 | "version": "3.7.5" 957 | }, 958 | "latex_envs": { 959 | "LaTeX_envs_menu_present": true, 960 | "autoclose": false, 961 | "autocomplete": true, 962 | "bibliofile": "biblio.bib", 963 | "cite_by": "apalike", 964 | "current_citInitial": 1, 965 | "eqLabelWithNumbers": true, 966 | "eqNumInitial": 1, 967 | "hotkeys": { 968 | "equation": "Ctrl-E", 969 | "itemize": "Ctrl-I" 970 | }, 971 | "labels_anchors": false, 972 | "latex_user_defs": false, 973 | "report_style_numbering": false, 974 | "user_envs_cfg": false 975 | }, 976 | "toc": { 977 | "base_numbering": 1, 978 | "nav_menu": {}, 979 | "number_sections": true, 980 | "sideBar": true, 981 | "skip_h1_title": false, 982 | "title_cell": "Table of Contents", 983 | "title_sidebar": "Contents", 984 | "toc_cell": false, 985 | "toc_position": { 986 | "height": "calc(100% - 180px)", 987 | "left": "10px", 988 | "top": "150px", 989 | "width": "289.097px" 990 | }, 991 | "toc_section_display": true, 992 | "toc_window_display": true 993 | } 994 | }, 995 | "nbformat": 4, 996 | "nbformat_minor": 2 997 | } 998 | -------------------------------------------------------------------------------- /notebooks/figures/4propellers_pwm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/4propellers_pwm.png -------------------------------------------------------------------------------- 
/notebooks/figures/SAC_Line_3DPosition4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_Line_3DPosition4.pdf -------------------------------------------------------------------------------- /notebooks/figures/SAC_Line_Position_x.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_Line_Position_x.pdf -------------------------------------------------------------------------------- /notebooks/figures/SAC_Line_Position_y.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_Line_Position_y.pdf -------------------------------------------------------------------------------- /notebooks/figures/SAC_Line_Position_z.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_Line_Position_z.pdf -------------------------------------------------------------------------------- /notebooks/figures/SAC_Senoid_3DPosition4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_Senoid_3DPosition4.pdf -------------------------------------------------------------------------------- /notebooks/figures/SAC_Senoid_Position_x.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_Senoid_Position_x.pdf -------------------------------------------------------------------------------- /notebooks/figures/SAC_Senoid_Position_y.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_Senoid_Position_y.pdf -------------------------------------------------------------------------------- /notebooks/figures/SAC_Senoid_Position_z.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_Senoid_Position_z.pdf -------------------------------------------------------------------------------- /notebooks/figures/SAC_Square_3DPosition4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_Square_3DPosition4.pdf -------------------------------------------------------------------------------- /notebooks/figures/SAC_Square_Position_x.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_Square_Position_x.pdf -------------------------------------------------------------------------------- /notebooks/figures/SAC_Square_Position_y.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_Square_Position_y.pdf -------------------------------------------------------------------------------- /notebooks/figures/SAC_Square_Position_z.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_Square_Position_z.pdf -------------------------------------------------------------------------------- /notebooks/figures/SAC_newaction_False_Re24_angx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_newaction_False_Re24_angx.png -------------------------------------------------------------------------------- /notebooks/figures/SAC_newaction_False_Re24_angy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_newaction_False_Re24_angy.png -------------------------------------------------------------------------------- /notebooks/figures/SAC_newaction_False_Re24_angz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_newaction_False_Re24_angz.png -------------------------------------------------------------------------------- /notebooks/figures/SAC_newaction_False_Re24_posx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_newaction_False_Re24_posx.png -------------------------------------------------------------------------------- /notebooks/figures/SAC_newaction_False_Re24_posy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_newaction_False_Re24_posy.png -------------------------------------------------------------------------------- /notebooks/figures/SAC_newaction_False_Re24_posz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/notebooks/figures/SAC_newaction_False_Re24_posz.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | asn1crypto==0.24.0 2 | cffi==1.11.5 3 | cloudpickle==1.3.0 4 | cryptography==2.1.4 5 | dill==0.3.2 6 | # -e git+https://github.com/larocs/Drone_RL.git@72d5809c8fbf5b52b1bc7a14f04d6c2b5b4618e6#egg=Drone_RL 7 | future==0.18.2 8 | gym==0.17.2 9 | idna==2.6 10 | joblib==0.16.0 11 | jupyter==1.0.0 12 | keyring==10.6.0 13 | keyrings.alt==3.0 14 | numpy==1.19.2 15 | pandas==1.1.2 16 | pycparser==2.20 17 | pycrypto==2.6.1 18 | pyglet==1.5.0 19 | pygobject==3.26.1 20 | # PyRep==1.2 21 | python-apt==1.6.5+ubuntu0.3 22 | python-dateutil==2.8.1 23 | pytz==2020.1 24 | pyxdg==0.25 25 | scipy==1.5.2 26 | SecretStorage==2.3.1 27 | six==1.11.0 28 | torch==1.6.0 29 | unattended-upgrades==0.1 
-------------------------------------------------------------------------------- /saved_policies/sac_optimal_policy.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/saved_policies/sac_optimal_policy.pt -------------------------------------------------------------------------------- /saved_policies/sac_optimal_policy_2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larocs/SAC_uav/1eb42154927843c0aff56b4281b6ba5c9512b325/saved_policies/sac_optimal_policy_2.pt -------------------------------------------------------------------------------- /training.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## Reward functions 4 | REWARD_FUNCTION=Reward_24 5 | # Random seed 6 | SEED=42 7 | # Cuda or cpu 8 | CUDA=True 9 | # Prefix for experiment name 10 | PREFIX=Experiment_1 11 | ## Initial position probability distribution (\rho_{o}) 12 | env_reset_mode=Discretized_Uniform 13 | ## State-Space 14 | STATE=New_action 15 | 16 | # Training settings 17 | log_interval=10 18 | max_episodes=100000 19 | eval_interval=100 20 | SAVE_INTERVAL=250 21 | 22 | ## SAC hyperparameters 23 | buffer_size=1000000 24 | BATCH_SIZE=4000 25 | net_size_value=64 26 | net_size_policy=64 27 | num_steps_until_train=1 28 | num_trains_per_step=1 29 | min_num_steps_before_training=0 30 | learning_rate=0.0001 31 | ACT_FUNCTION=tanh 32 | 33 | ## Experiment name 34 | experiment_name=${PREFIX}_state_${STATE}_Reward_${REWARD_FUNCTION}_clipaction_${CLIP_ACTION}_lr_${learning_rate}_bat_${BATCH_SIZE}_net_${net_size_value}_netpol_${net_size_policy}_ati_${ACT_FUNCTION}_buff_${buffer_size}_numsteps_${num_steps_until_train}_numtrainperstep_${num_trains_per_step}_before_${min_num_steps_before_training}_reset_${env_reset_mode} 35 | # experiment_name=test 36 | 37 | ## SAVING MODEL - default = experiment_name 38 | SAVED_POLICY= 39 | 40 | 41 | python3 main.py --save_path=${experiment_name} --replay_buffer_size=${buffer_size} --restore_path=${SAVED_POLICY} \ 42 | --log_interval=${log_interval} --env_reset_mode=${env_reset_mode} --batch-size=${BATCH_SIZE} \ 43 | --net_size_value=${net_size_value} --net_size_policy=${net_size_policy} --num_steps_until_train=${num_steps_until_train} --num_trains_per_step=${num_trains_per_step} --min_num_steps_before_training=${min_num_steps_before_training} \ 44 | --use_cuda=${CUDA} --seed=${SEED} --eval_interval=100 --reward_function=${REWARD_FUNCTION} --max_episodes=${max_episodes} \ 45 | --learning_rate=${learning_rate} --use_double=True --state=${STATE} \ 46 | --activation_function=${ACT_FUNCTION} --save_interval=${SAVE_INTERVAL} 47 | --------------------------------------------------------------------------------
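
Usage sketch (illustrative, not a repository file): the snippet below shows one way the components above fit together at evaluation time, following notebooks/Robustness.ipynb and networks/structures.py. It loads a checkpoint from saved_policies/, rebuilds a PolicyNetwork whose hidden size is read from the checkpoint weights, and runs a single deterministic 250-step episode in DroneEnv. It assumes the sim_framework / Drone_RL package and CoppeliaSim are installed (as in the provided Docker image) and that DroneEnv exposes a Gym-style reset(); the checkpoint path and episode horizon are taken from the notebook, everything else is an assumption, not the authors' own evaluation script.

import torch

from networks.structures import PolicyNetwork
from sim_framework.envs.drone_env import DroneEnv

# Environment configured as in notebooks/Robustness.ipynb
env = DroneEnv(random="Discretized_Uniform", headless=True, seed=42,
               state="New_action", reward_function_name="Normal")

device = torch.device("cpu")
checkpoint = torch.load("saved_policies/sac_optimal_policy_2.pt", map_location=device)

try:
    state_dim = env.observation_space.shape[0]
except BaseException:
    state_dim = env.observation_space
action_dim = env.action_space.shape[0]
hidden_dim = checkpoint["linear1.weight"].shape[0]   # hidden size is stored in the weights
action_range = [env.agent.action_space.low.min(), env.agent.action_space.high.max()]

policy = PolicyNetwork(state_dim, action_dim, hidden_dim).to(device)
policy.load_state_dict(checkpoint)

obs = env.reset()                                    # assumed Gym-style reset
episode_reward = 0.0
for _ in range(250):                                 # same time horizon as the notebook
    state = torch.FloatTensor(obs).unsqueeze(0).to(device)
    action = policy.deterministic_action(state)      # tanh-squashed mean action in [-1, 1]
    obs, reward, done, info = env.step(action * action_range[1])
    episode_reward += reward
    if done:
        break

print("episode return = {0:.2f}".format(episode_reward))
env.shutdown()

The stochastic counterpart is policy.get_action(state), which samples from the squashed Gaussian instead of taking its mean; the training entry point itself is main.py, driven by training.sh above.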