├── .github └── workflows │ ├── python-package.yml │ └── python-publish.yml ├── .gitignore ├── .pylintrc ├── .readthedocs.yml ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── all ├── .DS_Store ├── __init__.py ├── agents │ ├── __init__.py │ ├── _agent.py │ ├── _multiagent.py │ ├── _parallel_agent.py │ ├── a2c.py │ ├── c51.py │ ├── ddpg.py │ ├── ddqn.py │ ├── dqn.py │ ├── independent.py │ ├── ppo.py │ ├── rainbow.py │ ├── sac.py │ ├── vac.py │ ├── vpg.py │ ├── vqn.py │ └── vsarsa.py ├── approximation │ ├── .DS_Store │ ├── __init__.py │ ├── approximation.py │ ├── checkpointer │ │ └── __init__.py │ ├── feature_network.py │ ├── feature_network_test.py │ ├── identity.py │ ├── identity_test.py │ ├── q_continuous.py │ ├── q_dist.py │ ├── q_dist_test.py │ ├── q_network.py │ ├── q_network_test.py │ ├── target │ │ ├── __init__.py │ │ ├── abstract.py │ │ ├── fixed.py │ │ ├── polyak.py │ │ └── trivial.py │ ├── v_network.py │ └── v_network_test.py ├── bodies │ ├── __init__.py │ ├── _body.py │ ├── atari.py │ ├── rewards.py │ ├── time.py │ ├── time_test.py │ └── vision.py ├── core │ ├── __init__.py │ ├── state.py │ └── state_test.py ├── environments │ ├── __init__.py │ ├── _environment.py │ ├── _multiagent_environment.py │ ├── _vector_environment.py │ ├── atari.py │ ├── atari_test.py │ ├── atari_wrappers.py │ ├── duplicate_env.py │ ├── duplicate_env_test.py │ ├── gym.py │ ├── gym_test.py │ ├── gym_wrappers.py │ ├── gym_wrappers_test.py │ ├── mujoco.py │ ├── mujoco_test.py │ ├── multiagent_atari.py │ ├── multiagent_atari_test.py │ ├── multiagent_pettingzoo.py │ ├── multiagent_pettingzoo_test.py │ ├── pybullet.py │ ├── pybullet_test.py │ ├── vector_env.py │ └── vector_env_test.py ├── experiments │ ├── __init__.py │ ├── experiment.py │ ├── multiagent_env_experiment.py │ ├── multiagent_env_experiment_test.py │ ├── parallel_env_experiment.py │ ├── parallel_env_experiment_test.py │ ├── plots.py │ ├── run_experiment.py │ ├── single_env_experiment.py │ ├── single_env_experiment_test.py │ ├── slurm.py │ ├── watch.py │ └── watch_test.py ├── logging │ ├── __init__.py │ ├── _logger.py │ ├── dummy.py │ └── experiment.py ├── memory │ ├── __init__.py │ ├── advantage.py │ ├── advantage_test.py │ ├── generalized_advantage.py │ ├── generalized_advantage_test.py │ ├── replay_buffer.py │ ├── replay_buffer_test.py │ └── segment_tree.py ├── nn │ ├── __init__.py │ └── nn_test.py ├── optim │ ├── __init__.py │ ├── scheduler.py │ └── scheduler_test.py ├── policies │ ├── __init__.py │ ├── deterministic.py │ ├── deterministic_test.py │ ├── gaussian.py │ ├── gaussian_test.py │ ├── greedy.py │ ├── soft_deterministic.py │ ├── soft_deterministic_test.py │ ├── softmax.py │ └── softmax_test.py ├── presets │ ├── .DS_Store │ ├── __init__.py │ ├── atari │ │ ├── .DS_Store │ │ ├── __init__.py │ │ ├── a2c.py │ │ ├── c51.py │ │ ├── ddqn.py │ │ ├── dqn.py │ │ ├── models │ │ │ └── __init__.py │ │ ├── ppo.py │ │ ├── rainbow.py │ │ ├── vac.py │ │ ├── vpg.py │ │ ├── vqn.py │ │ └── vsarsa.py │ ├── atari_test.py │ ├── builder.py │ ├── builder_test.py │ ├── classic_control │ │ ├── .DS_Store │ │ ├── __init__.py │ │ ├── a2c.py │ │ ├── c51.py │ │ ├── ddqn.py │ │ ├── dqn.py │ │ ├── models │ │ │ └── __init__.py │ │ ├── ppo.py │ │ ├── rainbow.py │ │ ├── vac.py │ │ ├── vpg.py │ │ ├── vqn.py │ │ └── vsarsa.py │ ├── classic_control_test.py │ ├── continuous │ │ ├── .DS_Store │ │ ├── __init__.py │ │ ├── ddpg.py │ │ ├── models │ │ │ └── __init__.py │ │ ├── ppo.py │ │ └── sac.py │ ├── continuous_test.py │ ├── independent_multiagent.py │ ├── 
multiagent_atari_test.py │ └── preset.py └── scripts │ ├── __init__.py │ ├── plot.py │ ├── release.py │ ├── train.py │ ├── train_atari.py │ ├── train_classic.py │ ├── train_continuous.py │ ├── train_mujoco.py │ ├── train_multiagent_atari.py │ ├── train_pybullet.py │ ├── watch_atari.py │ ├── watch_classic.py │ ├── watch_continuous.py │ ├── watch_mujoco.py │ ├── watch_multiagent_atari.py │ └── watch_pybullet.py ├── benchmarks ├── atari_40m.png ├── atari_40m.py ├── mujoco_v4.png ├── mujoco_v4.py ├── pybullet_v0.png └── pybullet_v0.py ├── docs ├── .gitignore ├── Makefile ├── make.bat └── source │ ├── conf.py │ ├── environments.png │ ├── guide │ ├── ale.png │ ├── approximation.jpeg │ ├── basic_concepts.rst │ ├── benchmark_performance.rst │ ├── creating_agent.rst │ ├── getting_started.rst │ ├── plot.png │ ├── rainbow.png │ ├── rl.jpg │ └── tensorboard.png │ ├── index.rst │ └── modules │ ├── agents.rst │ ├── approximation.rst │ ├── bodies.rst │ ├── core.rst │ ├── environments.rst │ ├── experiments.rst │ ├── logging.rst │ ├── memory.rst │ ├── nn.rst │ ├── optim.rst │ ├── policies.rst │ ├── presets.rst │ └── presets │ ├── atari.rst │ ├── classic.rst │ └── continuous.rst ├── examples ├── __init__.py ├── experiment.py └── slurm_experiment.py ├── integration ├── atari_test.py ├── classic_control_test.py ├── continuous_test.py ├── multiagent_atari_test.py └── validate_agent.py └── setup.py /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ master, develop ] 9 | pull_request: 10 | branches: [ master, develop ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [3.8, 3.11] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install torch~=2.0 --extra-index-url https://download.pytorch.org/whl/cpu 30 | make install 31 | - name: Lint code 32 | run: | 33 | make lint 34 | - name: Run tests 35 | run: | 36 | make test 37 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [published] 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | deploy: 15 | runs-on: ubuntu-latest 16 | environment: publish 17 | permissions: 18 | id-token: write 19 | steps: 20 | - uses: actions/checkout@v3 21 | - name: Set up Python 22 | uses: actions/setup-python@v3 23 | with: 24 | python-version: 3.11 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install torch~=2.0 --extra-index-url https://download.pytorch.org/whl/cpu 29 | pip install setuptools wheel 30 | 
make install 31 | - name: Build package 32 | run: make build 33 | - name: Publish package 34 | uses: pypa/gh-action-pypi-publish@release/v1 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # python 2 | *.pyc 3 | __pycache__ 4 | autonomous_learning_library.egg-info 5 | 6 | # build directories 7 | /build 8 | /dist 9 | 10 | # editor 11 | .vscode 12 | .idea 13 | *.code-workspace 14 | 15 | # non-committed code 16 | local 17 | legacy 18 | /runs 19 | /out 20 | 21 | # notebooks 22 | *.ipynb 23 | *.ipynb_checkpoints 24 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: "ubuntu-22.04" 5 | tools: 6 | python: "3.11" 7 | 8 | python: 9 | install: 10 | - method: pip 11 | path: . 12 | extra_requirements: 13 | - docs 14 | 15 | sphinx: 16 | configuration: docs/source/conf.py 17 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions and suggestions are welcome! 4 | If you are interested in contributing either bug fixes or new features, open an issue and we can talk about it! 5 | New PRs will require: 6 | 7 | 1. New unit tests for any new or changed common module, and all unit tests should pass. 8 | 2. All code should follow a similar style to the rest of the repository and the linter should pass. 9 | 3. Documentation of new features. 10 | 4. Manual approval. 11 | 12 | 13 | We use the [GitFlow](https://datasift.github.io/gitflow/IntroducingGitFlow.html) model, meaning that all PRs should be opened against the `develop` branch! 14 | To begin, you can run the following commands: 15 | 16 | ``` 17 | git clone https://github.com/cpnota/autonomous-learning-library.git 18 | cd autonomous-learning-library 19 | git checkout develop 20 | pip install -e .[docs] 21 | ``` 22 | 23 | The unit tests may be run using: 24 | 25 | ``` 26 | make test 27 | ``` 28 | 29 | You can automatically format your code to match our code style using: 30 | 31 | ``` 32 | make format 33 | ``` 34 | 35 | Finally, you rebuild the documentation using: 36 | 37 | ``` 38 | cd docs 39 | make clean && make html 40 | ``` 41 | 42 | Happy hacking! 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Chris Nota 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | install: 2 | pip install -e .[dev] 3 | AutoROM -y --quiet 4 | 5 | test: unit-test integration-test 6 | 7 | unit-test: 8 | python -m unittest discover -s all -p "*test.py" -t . 9 | 10 | integration-test: 11 | python -m unittest discover -s integration -p "*test.py" 12 | 13 | lint: 14 | black --check all benchmarks examples integration setup.py 15 | isort --profile black --check all benchmarks examples integration setup.py 16 | flake8 --select "F401" all benchmarks examples integration setup.py 17 | 18 | format: 19 | black all benchmarks examples integration setup.py 20 | isort --profile black all benchmarks examples integration setup.py 21 | 22 | tensorboard: 23 | tensorboard --logdir runs 24 | 25 | benchmark: 26 | tensorboard --logdir benchmarks/runs --port=6007 27 | 28 | clean: 29 | rm -rf dist 30 | rm -rf build 31 | 32 | build: clean 33 | python setup.py sdist bdist_wheel 34 | 35 | deploy: lint test build 36 | twine upload dist/* 37 | -------------------------------------------------------------------------------- /all/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/all/.DS_Store -------------------------------------------------------------------------------- /all/__init__.py: -------------------------------------------------------------------------------- 1 | from all.core import State, StateArray 2 | 3 | __all__ = [ 4 | "agents", 5 | "approximation", 6 | "core", 7 | "environments", 8 | "logging", 9 | "memory", 10 | "nn", 11 | "optim", 12 | "policies", 13 | "presets", 14 | "State", 15 | "StateArray", 16 | ] 17 | -------------------------------------------------------------------------------- /all/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from ._agent import Agent 2 | from ._multiagent import Multiagent 3 | from ._parallel_agent import ParallelAgent 4 | from .a2c import A2C, A2CTestAgent 5 | from .c51 import C51, C51TestAgent 6 | from .ddpg import DDPG, DDPGTestAgent 7 | from .ddqn import DDQN, DDQNTestAgent 8 | from .dqn import DQN, DQNTestAgent 9 | from .independent import IndependentMultiagent 10 | from .ppo import PPO, PPOTestAgent 11 | from .rainbow import Rainbow, RainbowTestAgent 12 | from .sac import SAC, SACTestAgent 13 | from .vac import VAC, VACTestAgent 14 | from .vpg import VPG, VPGTestAgent 15 | from .vqn import VQN, VQNTestAgent 16 | from .vsarsa import VSarsa, VSarsaTestAgent 17 | 18 | __all__ = [ 19 | # Agent interfaces 20 | "Agent", 21 | "Multiagent", 22 | "ParallelAgent", 23 | # Agent implementations 24 | "A2C", 25 | "A2CTestAgent", 26 | "C51", 27 | "C51TestAgent", 28 | "DDPG", 29 | "DDPGTestAgent", 30 | "DDQN", 31 | "DDQNTestAgent", 32 | "DQN", 33 | "DQNTestAgent", 34 | "PPO", 35 | "PPOTestAgent", 36 | "Rainbow", 37 | "RainbowTestAgent", 38 | "SAC", 39 | "SACTestAgent", 40 | "VAC", 41 | "VACTestAgent", 42 | "VPG", 43 | "VPGTestAgent", 44 | "VQN", 45 | "VQNTestAgent", 46 | "VSarsa", 47 | 
"VSarsaTestAgent", 48 | "IndependentMultiagent", 49 | ] 50 | -------------------------------------------------------------------------------- /all/agents/_agent.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from all.optim import Schedulable 4 | 5 | 6 | class Agent(ABC, Schedulable): 7 | """ 8 | A reinforcement learning agent. 9 | 10 | In reinforcement learning, an Agent learns by interacting with an Environment. 11 | Usually, an Agent tries to maximize a reward signal. 12 | It does this by observing environment "states", taking "actions", receiving "rewards", 13 | and learning which state-action pairs correlate with high rewards. 14 | An Agent implementation should encapsulate some particular reinforcement learning algorithm. 15 | """ 16 | 17 | @abstractmethod 18 | def act(self, state): 19 | """ 20 | Select an action for the current timestep and update internal parameters. 21 | 22 | In general, a reinforcement learning agent does several things during a timestep: 23 | 1. Choose an action, 24 | 2. Compute the TD error from the previous time step 25 | 3. Update the value function and/or policy 26 | The order of these steps differs depending on the agent. 27 | This method allows the agent to do whatever is necessary for itself on a given timestep. 28 | However, the agent must ultimately return an action. 29 | 30 | Args: 31 | state (all.environment.State): The environment state at the current timestep. 32 | 33 | Returns: 34 | torch.Tensor: The action to take at the current timestep. 35 | """ 36 | -------------------------------------------------------------------------------- /all/agents/_multiagent.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from all.optim import Schedulable 4 | 5 | 6 | class Multiagent(ABC, Schedulable): 7 | """ 8 | A multiagent RL agent. Differs from standard agents in that it accepts a multiagent state. 9 | 10 | In reinforcement learning, an Agent learns by interacting with an Environment. 11 | Usually, an agent tries to maximize a reward signal. 12 | It does this by observing environment "states", taking "actions", receiving "rewards", 13 | and learning which state-action pairs correlate with high rewards. 14 | An Agent implementation should encapsulate some particular reinforcement learning algorithm. 15 | """ 16 | 17 | @abstractmethod 18 | def act(self, multiagent_state): 19 | """ 20 | Select an action for the current timestep and update internal parameters. 21 | 22 | In general, a reinforcement learning agent does several things during a timestep: 23 | 1. Choose an action, 24 | 2. Compute the TD error from the previous time step 25 | 3. Update the value function and/or policy 26 | The order of these steps differs depending on the agent. 27 | This method allows the agent to do whatever is necessary for itself on a given timestep. 28 | However, the agent must ultimately return an action. 29 | 30 | Args: 31 | multiagent_state (all.core.MultiagentState): The environment state at the current timestep. 32 | 33 | Returns: 34 | torch.Tensor: The action for the current agent to take at the current timestep. 
35 | """ 36 | -------------------------------------------------------------------------------- /all/agents/_parallel_agent.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from all.optim import Schedulable 4 | 5 | 6 | class ParallelAgent(ABC, Schedulable): 7 | """ 8 | A reinforcement learning agent that chooses actions for multiple states simultaneously. 9 | Differs from SingleAgent in that it accepts a StateArray instead of a State to process 10 | input from multiple environments in parallel. 11 | 12 | In reinforcement learning, an Agent learns by interacting with an Environment. 13 | Usually, an Agent tries to maximize a reward signal. 14 | It does this by observing environment "states", taking "actions", receiving "rewards", 15 | and learning which state-action pairs correlate with high rewards. 16 | An Agent implementation should encapsulate some particular reinforcement learning algorithm. 17 | """ 18 | 19 | @abstractmethod 20 | def act(self, state_array): 21 | """ 22 | Select an action for the current timestep and update internal parameters. 23 | 24 | In general, a reinforcement learning agent does several things during a timestep: 25 | 1. Choose an action, 26 | 2. Compute the TD error from the previous time step 27 | 3. Update the value function and/or policy 28 | The order of these steps differs depending on the agent. 29 | This method allows the agent to do whatever is necessary for itself on a given timestep. 30 | However, the agent must ultimately return an action. 31 | 32 | Args: 33 | state_array (all.environment.StateArray): An array of states for each parallel environment. 34 | 35 | Returns: 36 | torch.Tensor: The actions to take for each parallel environment. 37 | """ 38 | -------------------------------------------------------------------------------- /all/agents/ddqn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from all.nn import weighted_mse_loss 4 | 5 | from ._agent import Agent 6 | from .dqn import DQNTestAgent 7 | 8 | 9 | class DDQN(Agent): 10 | """ 11 | Double Deep Q-Network (DDQN). 12 | DDQN is an enhancement to DQN that uses a "double Q-style" update, 13 | wherein the online network is used to select target actions 14 | and the target network is used to evaluate these actions. 15 | https://arxiv.org/abs/1509.06461 16 | This agent also adds support for weighted replay buffers, such 17 | as prioritized experience replay (PER). 18 | https://arxiv.org/abs/1511.05952 19 | 20 | Args: 21 | q (QNetwork): An Approximation of the Q function. 22 | policy (GreedyPolicy): A policy derived from the Q-function. 23 | replay_buffer (ReplayBuffer): The experience replay buffer. 24 | discount_factor (float): Discount factor for future rewards. 25 | loss (function): The weighted loss function to use. 26 | minibatch_size (int): The number of experiences to sample in each training update. 27 | replay_start_size (int): Number of experiences in replay buffer when training begins. 28 | update_frequency (int): Number of timesteps per training update.
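Example (a minimal sketch of the "double Q-style" target described above; it mirrors the _train method below, with q and discount_factor as in the arguments listed here):

    next_actions = torch.argmax(q.no_grad(next_states), dim=1)                 # online network selects the actions
    targets = rewards + discount_factor * q.target(next_states, next_actions)  # target network evaluates them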
29 | """ 30 | 31 | def __init__( 32 | self, 33 | q, 34 | policy, 35 | replay_buffer, 36 | discount_factor=0.99, 37 | loss=weighted_mse_loss, 38 | minibatch_size=32, 39 | replay_start_size=5000, 40 | update_frequency=1, 41 | ): 42 | # objects 43 | self.q = q 44 | self.policy = policy 45 | self.replay_buffer = replay_buffer 46 | self.loss = loss 47 | # hyperparameters 48 | self.replay_start_size = replay_start_size 49 | self.update_frequency = update_frequency 50 | self.minibatch_size = minibatch_size 51 | self.discount_factor = discount_factor 52 | # private 53 | self._state = None 54 | self._action = None 55 | self._frames_seen = 0 56 | 57 | def act(self, state): 58 | self.replay_buffer.store(self._state, self._action, state) 59 | self._train() 60 | self._state = state 61 | self._action = self.policy.no_grad(state) 62 | return self._action 63 | 64 | def eval(self, state): 65 | return self.policy.eval(state) 66 | 67 | def _train(self): 68 | if self._should_train(): 69 | # sample transitions from buffer 70 | (states, actions, rewards, next_states, weights) = ( 71 | self.replay_buffer.sample(self.minibatch_size) 72 | ) 73 | # forward pass 74 | values = self.q(states, actions) 75 | # compute targets 76 | next_actions = torch.argmax(self.q.no_grad(next_states), dim=1) 77 | targets = rewards + self.discount_factor * self.q.target( 78 | next_states, next_actions 79 | ) 80 | # compute loss 81 | loss = self.loss(values, targets, weights) 82 | # backward pass 83 | self.q.reinforce(loss) 84 | # update replay buffer priorities 85 | td_errors = targets - values 86 | self.replay_buffer.update_priorities(td_errors.abs()) 87 | 88 | def _should_train(self): 89 | self._frames_seen += 1 90 | return ( 91 | self._frames_seen > self.replay_start_size 92 | and self._frames_seen % self.update_frequency == 0 93 | ) 94 | 95 | 96 | DDQNTestAgent = DQNTestAgent 97 | -------------------------------------------------------------------------------- /all/agents/dqn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.functional import mse_loss 3 | 4 | from ._agent import Agent 5 | 6 | 7 | class DQN(Agent): 8 | """ 9 | Deep Q-Network (DQN). 10 | DQN was one of the original deep reinforcement learning algorithms. 11 | It extends the ideas behind Q-learning to work well with modern convolution networks. 12 | The core innovation is the use of a replay buffer, which allows the use of batch-style 13 | updates with decorrelated samples. It also uses a "target" network in order to 14 | improve the stability of updates. 15 | https://www.nature.com/articles/nature14236 16 | 17 | Args: 18 | q (QNetwork): An Approximation of the Q function. 19 | policy (GreedyPolicy): A policy derived from the Q-function. 20 | replay_buffer (ReplayBuffer): The experience replay buffer. 21 | discount_factor (float): Discount factor for future rewards. 22 | exploration (float): The probability of choosing a random action. 23 | loss (function): The weighted loss function to use. 24 | minibatch_size (int): The number of experiences to sample in each training update. 25 | n_actions (int): The number of available actions. 26 | replay_start_size (int): Number of experiences in replay buffer when training begins. 27 | update_frequency (int): Number of timesteps per training update. 
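Example (a minimal sketch of the Q-learning target computed in the _train method below; names follow the constructor arguments listed above):

    # y = r + discount_factor * max_a Q_target(s', a), compared against Q(s, a) by the loss
    targets = rewards + discount_factor * torch.max(q.target(next_states), dim=1)[0]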
28 | """ 29 | 30 | def __init__( 31 | self, 32 | q, 33 | policy, 34 | replay_buffer, 35 | discount_factor=0.99, 36 | loss=mse_loss, 37 | minibatch_size=32, 38 | replay_start_size=5000, 39 | update_frequency=1, 40 | ): 41 | # objects 42 | self.q = q 43 | self.policy = policy 44 | self.replay_buffer = replay_buffer 45 | self.loss = loss 46 | # hyperparameters 47 | self.discount_factor = discount_factor 48 | self.minibatch_size = minibatch_size 49 | self.replay_start_size = replay_start_size 50 | self.update_frequency = update_frequency 51 | # private 52 | self._state = None 53 | self._action = None 54 | self._frames_seen = 0 55 | 56 | def act(self, state): 57 | self.replay_buffer.store(self._state, self._action, state) 58 | self._train() 59 | self._state = state 60 | self._action = self.policy.no_grad(state) 61 | return self._action 62 | 63 | def eval(self, state): 64 | return self.policy.eval(state) 65 | 66 | def _train(self): 67 | if self._should_train(): 68 | # sample transitions from buffer 69 | (states, actions, rewards, next_states, _) = self.replay_buffer.sample( 70 | self.minibatch_size 71 | ) 72 | # forward pass 73 | values = self.q(states, actions) 74 | # compute targets 75 | targets = ( 76 | rewards 77 | + self.discount_factor * torch.max(self.q.target(next_states), dim=1)[0] 78 | ) 79 | # compute loss 80 | loss = self.loss(values, targets) 81 | # backward pass 82 | self.q.reinforce(loss) 83 | 84 | def _should_train(self): 85 | self._frames_seen += 1 86 | return ( 87 | self._frames_seen > self.replay_start_size 88 | and self._frames_seen % self.update_frequency == 0 89 | ) 90 | 91 | 92 | class DQNTestAgent(Agent): 93 | def __init__(self, policy): 94 | self.policy = policy 95 | 96 | def act(self, state): 97 | return self.policy.eval(state) 98 | -------------------------------------------------------------------------------- /all/agents/independent.py: -------------------------------------------------------------------------------- 1 | from ._multiagent import Multiagent 2 | 3 | 4 | class IndependentMultiagent(Multiagent): 5 | def __init__(self, agents): 6 | self.agents = agents 7 | 8 | def act(self, state): 9 | return self.agents[state["agent"]].act(state) 10 | -------------------------------------------------------------------------------- /all/agents/rainbow.py: -------------------------------------------------------------------------------- 1 | from .c51 import C51, C51TestAgent 2 | 3 | 4 | class Rainbow(C51): 5 | """ 6 | Rainbow: Combining Improvements in Deep Reinforcement Learning. 7 | Rainbow combines C51 with 5 other "enhancements" to 8 | DQN: double Q-learning, dueling networks, noisy networks, 9 | prioritized replay, and n-step rollouts. 10 | https://arxiv.org/abs/1710.02298 11 | 12 | Whether this agent is Rainbow or C51 depends 13 | on the objects that are passed into it. 14 | Dueling networks and noisy networks are part 15 | of the model used for q_dist, while 16 | prioritized replay and n-step rollouts are handled 17 | by the replay buffer. 18 | Double Q-learning is always used. 19 | 20 | Args: 21 | q_dist (QDist): Approximation of the Q distribution. 22 | replay_buffer (ReplayBuffer): The experience replay buffer. 23 | discount_factor (float): Discount factor for future rewards. 24 | eps (float): Stability parameter for computing the loss function. 25 | exploration (float): The probability of choosing a random action. 26 | minibatch_size (int): The number of experiences to sample in 27 | each training update.
28 | replay_start_size (int): Number of experiences in replay buffer 29 | when training begins. 30 | update_frequency (int): Number of timesteps per training update. 31 | """ 32 | 33 | 34 | RainbowTestAgent = C51TestAgent 35 | -------------------------------------------------------------------------------- /all/agents/vac.py: -------------------------------------------------------------------------------- 1 | from torch.nn.functional import mse_loss 2 | 3 | from ._parallel_agent import ParallelAgent 4 | from .a2c import A2CTestAgent 5 | 6 | 7 | class VAC(ParallelAgent): 8 | """ 9 | Vanilla Actor-Critic (VAC). 10 | VAC is an implementation of the actor-critic algorithm found in the Sutton and Barto (2018) textbook. 11 | This implementation tweaks the algorithm slightly by using a shared feature layer. 12 | It is also compatible with the use of parallel environments. 13 | https://papers.nips.cc/paper/1786-actor-critic-algorithms.pdf 14 | 15 | Args: 16 | features (FeatureNetwork): Shared feature layers. 17 | v (VNetwork): Value head which approximates the state-value function. 18 | policy (StochasticPolicy): Policy head which outputs an action distribution. 19 | discount_factor (float): Discount factor for future rewards. 20 | n_envs (int): Number of parallel actors/environments 21 | n_steps (int): Number of timesteps per rollout. Updates are performed once per rollout. 22 | logger (Logger): Used for logging. 23 | """ 24 | 25 | def __init__(self, features, v, policy, discount_factor=1): 26 | self.features = features 27 | self.v = v 28 | self.policy = policy 29 | self.discount_factor = discount_factor 30 | self._features = None 31 | self._distribution = None 32 | self._action = None 33 | 34 | def act(self, state): 35 | self._train(state, state.reward) 36 | self._features = self.features(state) 37 | self._distribution = self.policy(self._features) 38 | self._action = self._distribution.sample() 39 | return self._action 40 | 41 | def eval(self, state): 42 | return self.policy.eval(self.features.eval(state)) 43 | 44 | def _train(self, state, reward): 45 | if self._features: 46 | # forward pass 47 | values = self.v(self._features) 48 | 49 | # compute targets 50 | targets = reward + self.discount_factor * self.v.target( 51 | self.features.target(state) 52 | ) 53 | advantages = targets - values.detach() 54 | 55 | # compute losses 56 | value_loss = mse_loss(values, targets) 57 | policy_loss = -( 58 | advantages * self._distribution.log_prob(self._action) 59 | ).mean() 60 | loss = value_loss + policy_loss 61 | 62 | # backward pass 63 | loss.backward() 64 | self.v.step(loss=value_loss) 65 | self.policy.step(loss=policy_loss) 66 | self.features.step() 67 | 68 | 69 | VACTestAgent = A2CTestAgent 70 | -------------------------------------------------------------------------------- /all/agents/vqn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.functional import mse_loss 3 | 4 | from ._agent import Agent 5 | from ._parallel_agent import ParallelAgent 6 | 7 | 8 | class VQN(ParallelAgent): 9 | """ 10 | Vanilla Q-Network (VQN). 11 | VQN is an implementation of the Q-learning algorithm found in the Sutton and Barto (2018) textbook. 12 | Q-learning algorithms attempt to learn the optimal policy while executing a (generally) 13 | suboptimal policy (typically epsilon-greedy). In theory, this allows the agent to gain the benefits 14 | of exploration without sacrificing the performance of the final policy.
However, the cost of this 15 | is that Q-learning is generally less stable than its on-policy brethren, SARSA. 16 | http://www.cs.rhul.ac.uk/~chrisw/new_thesis.pdf 17 | 18 | Args: 19 | q (QNetwork): An Approximation of the Q function. 20 | policy (GreedyPolicy): A policy derived from the Q-function. 21 | discount_factor (float): Discount factor for future rewards. 22 | """ 23 | 24 | def __init__(self, q, policy, discount_factor=0.99): 25 | self.q = q 26 | self.policy = policy 27 | self.discount_factor = discount_factor 28 | self._state = None 29 | self._action = None 30 | 31 | def act(self, state): 32 | self._train(state.reward, state) 33 | action = self.policy.no_grad(state) 34 | self._state = state 35 | self._action = action 36 | return action 37 | 38 | def eval(self, state): 39 | return self.policy.eval(state) 40 | 41 | def _train(self, reward, next_state): 42 | if self._state: 43 | # forward pass 44 | value = self.q(self._state, self._action) 45 | # compute target 46 | target = ( 47 | reward 48 | + self.discount_factor * torch.max(self.q.target(next_state), dim=1)[0] 49 | ) 50 | # compute loss 51 | loss = mse_loss(value, target) 52 | # backward pass 53 | self.q.reinforce(loss) 54 | 55 | 56 | class VQNTestAgent(Agent, ParallelAgent): 57 | def __init__(self, policy): 58 | self.policy = policy 59 | 60 | def act(self, state): 61 | return self.policy.eval(state) 62 | -------------------------------------------------------------------------------- /all/agents/vsarsa.py: -------------------------------------------------------------------------------- 1 | from torch.nn.functional import mse_loss 2 | 3 | from ._parallel_agent import ParallelAgent 4 | from .vqn import VQNTestAgent 5 | 6 | 7 | class VSarsa(ParallelAgent): 8 | """ 9 | Vanilla SARSA (VSarsa). 10 | SARSA (State-Action-Reward-State-Action) is an on-policy alternative to Q-learning. Unlike Q-learning, 11 | SARSA attempts to learn the Q-function for the current policy rather than the optimal policy. This 12 | approach is more stable but may not result in the optimal policy. However, this problem can be mitigated 13 | by decaying the exploration rate over time. 14 | 15 | Args: 16 | q (QNetwork): An Approximation of the Q function. 17 | policy (GreedyPolicy): A policy derived from the Q-function. 18 | discount_factor (float): Discount factor for future rewards.
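Example (a minimal sketch contrasting the on-policy SARSA target used in the _train method below with the Q-learning target used by VQN above; q and discount_factor are the arguments listed here):

    sarsa_target = reward + discount_factor * q.target(next_state, next_action)               # bootstraps from the action actually taken
    q_learning_target = reward + discount_factor * torch.max(q.target(next_state), dim=1)[0]  # bootstraps from the greedy action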
19 | """ 20 | 21 | def __init__(self, q, policy, discount_factor=0.99): 22 | self.q = q 23 | self.policy = policy 24 | self.discount_factor = discount_factor 25 | self._state = None 26 | self._action = None 27 | 28 | def act(self, state): 29 | action = self.policy.no_grad(state) 30 | self._train(state.reward, state, action) 31 | self._state = state 32 | self._action = action 33 | return action 34 | 35 | def eval(self, state): 36 | return self.policy.eval(state) 37 | 38 | def _train(self, reward, next_state, next_action): 39 | if self._state: 40 | # forward pass 41 | value = self.q(self._state, self._action) 42 | # compute target 43 | target = reward + self.discount_factor * self.q.target( 44 | next_state, next_action 45 | ) 46 | # compute loss 47 | loss = mse_loss(value, target) 48 | # backward pass 49 | self.q.reinforce(loss) 50 | 51 | 52 | VSarsaTestAgent = VQNTestAgent 53 | -------------------------------------------------------------------------------- /all/approximation/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/all/approximation/.DS_Store -------------------------------------------------------------------------------- /all/approximation/__init__.py: -------------------------------------------------------------------------------- 1 | from .approximation import Approximation 2 | from .checkpointer import Checkpointer, DummyCheckpointer, PeriodicCheckpointer 3 | from .feature_network import FeatureNetwork 4 | from .identity import Identity 5 | from .q_continuous import QContinuous 6 | from .q_dist import QDist 7 | from .q_network import QNetwork 8 | from .target import FixedTarget, PolyakTarget, TargetNetwork, TrivialTarget 9 | from .v_network import VNetwork 10 | 11 | __all__ = [ 12 | "Approximation", 13 | "QContinuous", 14 | "QDist", 15 | "QNetwork", 16 | "VNetwork", 17 | "FeatureNetwork", 18 | "TargetNetwork", 19 | "Identity", 20 | "FixedTarget", 21 | "PolyakTarget", 22 | "TrivialTarget", 23 | "Checkpointer", 24 | "DummyCheckpointer", 25 | "PeriodicCheckpointer", 26 | ] 27 | -------------------------------------------------------------------------------- /all/approximation/checkpointer/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from abc import ABC, abstractmethod 3 | 4 | import torch 5 | 6 | 7 | class Checkpointer(ABC): 8 | @abstractmethod 9 | def init(self, model, filename): 10 | pass 11 | 12 | @abstractmethod 13 | def __call__(self): 14 | pass 15 | 16 | 17 | class DummyCheckpointer(Checkpointer): 18 | def init(self, *inputs): 19 | pass 20 | 21 | def __call__(self): 22 | pass 23 | 24 | 25 | class PeriodicCheckpointer(Checkpointer): 26 | def __init__(self, frequency): 27 | self.frequency = frequency 28 | self._updates = 1 29 | self._filename = None 30 | self._model = None 31 | 32 | def init(self, model, filename): 33 | self._model = model 34 | self._filename = filename 35 | # Some builds of pytorch throw this unhelpful warning. 36 | # We can safely disable it. 
37 | # https://discuss.pytorch.org/t/got-warning-couldnt-retrieve-source-code-for-container/7689/7 38 | warnings.filterwarnings("ignore", message="Couldn't retrieve source code") 39 | 40 | def __call__(self): 41 | if self._updates % self.frequency == 0: 42 | torch.save(self._model, self._filename) 43 | self._updates += 1 44 | -------------------------------------------------------------------------------- /all/approximation/feature_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .approximation import Approximation 4 | 5 | 6 | class FeatureNetwork(Approximation): 7 | """ 8 | An Approximation that accepts a state updates the observation key 9 | based on the given model. 10 | """ 11 | 12 | def __init__(self, model, optimizer=None, name="feature", **kwargs): 13 | model = FeatureModule(model) 14 | super().__init__(model, optimizer, name=name, **kwargs) 15 | 16 | 17 | class FeatureModule(torch.nn.Module): 18 | def __init__(self, model): 19 | super().__init__() 20 | self.model = model 21 | 22 | def forward(self, states): 23 | features = states.as_output(self.model(states.as_input("observation"))) 24 | return states.update("observation", features) 25 | -------------------------------------------------------------------------------- /all/approximation/feature_network_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import torch_testing as tt 5 | from torch import nn 6 | 7 | from all.approximation.feature_network import FeatureNetwork 8 | from all.core import State 9 | 10 | STATE_DIM = 2 11 | 12 | 13 | class TestFeatureNetwork(unittest.TestCase): 14 | def setUp(self): 15 | torch.manual_seed(2) 16 | self.model = nn.Sequential(nn.Linear(STATE_DIM, 3)) 17 | 18 | optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1) 19 | self.features = FeatureNetwork(self.model, optimizer) 20 | self.states = State( 21 | {"observation": torch.randn(3, STATE_DIM), "mask": torch.tensor([1, 0, 1])} 22 | ) 23 | self.expected_features = State( 24 | { 25 | "observation": torch.tensor( 26 | [ 27 | [-0.2385, -0.7263, -0.0340], 28 | [-0.3569, -0.6612, 0.3485], 29 | [-0.0296, -0.7566, -0.4624], 30 | ] 31 | ), 32 | "mask": torch.tensor([1, 0, 1]), 33 | } 34 | ) 35 | 36 | def test_forward(self): 37 | features = self.features(self.states) 38 | self.assert_state_equal(features, self.expected_features) 39 | 40 | def test_backward(self): 41 | states = self.features(self.states) 42 | loss = torch.tensor(0) 43 | loss = torch.sum(states.observation) 44 | self.features.reinforce(loss) 45 | features = self.features(self.states) 46 | expected = State( 47 | { 48 | "observation": torch.tensor( 49 | [[-0.71, -1.2, -0.5], [-0.72, -1.03, -0.02], [-0.57, -1.3, -1.01]] 50 | ), 51 | "mask": torch.tensor([1, 0, 1]), 52 | } 53 | ) 54 | self.assert_state_equal(features, expected) 55 | 56 | def test_eval(self): 57 | features = self.features.eval(self.states) 58 | self.assert_state_equal(features, self.expected_features) 59 | self.assertFalse(features.observation[0].requires_grad) 60 | 61 | def assert_state_equal(self, actual, expected): 62 | tt.assert_almost_equal(actual.observation, expected.observation, decimal=2) 63 | tt.assert_equal(actual.mask, expected.mask) 64 | 65 | 66 | if __name__ == "__main__": 67 | unittest.main() 68 | -------------------------------------------------------------------------------- /all/approximation/identity.py: 
-------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from .approximation import Approximation 4 | 5 | 6 | class Identity(Approximation): 7 | """ 8 | An Approximation that represents the identity function. 9 | 10 | Because the model has no parameters, reinforce and step do nothing. 11 | """ 12 | 13 | def __init__(self, device, name="identity", **kwargs): 14 | super().__init__(nn.Identity(), None, device=device, name=name, **kwargs) 15 | 16 | def reinforce(self): 17 | return self 18 | 19 | def step(self): 20 | return self 21 | -------------------------------------------------------------------------------- /all/approximation/identity_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import torch_testing as tt 5 | 6 | from all.approximation import FixedTarget, Identity 7 | from all.core import State 8 | 9 | 10 | class TestIdentityNetwork(unittest.TestCase): 11 | def setUp(self): 12 | self.model = Identity("cpu", target=FixedTarget(10)) 13 | 14 | def test_forward_tensor(self): 15 | inputs = torch.tensor([1, 2, 3]) 16 | outputs = self.model(inputs) 17 | tt.assert_equal(inputs, outputs) 18 | 19 | def test_forward_state(self): 20 | inputs = State({"observation": torch.tensor([1, 2, 3])}) 21 | outputs = self.model(inputs) 22 | self.assertEqual(inputs, outputs) 23 | 24 | def test_eval(self): 25 | inputs = torch.tensor([1, 2, 3]) 26 | outputs = self.model.target(inputs) 27 | tt.assert_equal(inputs, outputs) 28 | 29 | def test_target(self): 30 | inputs = torch.tensor([1, 2, 3]) 31 | outputs = self.model.target(inputs) 32 | tt.assert_equal(inputs, outputs) 33 | 34 | def test_reinforce(self): 35 | self.model.reinforce() 36 | 37 | def test_step(self): 38 | self.model.step() 39 | 40 | 41 | if __name__ == "__main__": 42 | unittest.main() 43 | -------------------------------------------------------------------------------- /all/approximation/q_continuous.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from all.nn import RLNetwork 4 | 5 | from .approximation import Approximation 6 | 7 | 8 | class QContinuous(Approximation): 9 | def __init__(self, model, optimizer, name="q", **kwargs): 10 | model = QContinuousModule(model) 11 | super().__init__(model, optimizer, name=name, **kwargs) 12 | 13 | 14 | class QContinuousModule(RLNetwork): 15 | def forward(self, states, actions): 16 | x = torch.cat((states.observation.float(), actions), dim=1) 17 | return self.model(x).squeeze(-1) * states.mask.float() 18 | -------------------------------------------------------------------------------- /all/approximation/q_dist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | from all import nn 5 | 6 | from .approximation import Approximation 7 | 8 | 9 | class QDist(Approximation): 10 | def __init__( 11 | self, 12 | model, 13 | optimizer, 14 | n_actions, 15 | n_atoms, 16 | v_min, 17 | v_max, 18 | name="q_dist", 19 | **kwargs 20 | ): 21 | device = next(model.parameters()).device 22 | self.n_actions = n_actions 23 | self.atoms = torch.linspace(v_min, v_max, steps=n_atoms).to(device) 24 | model = QDistModule(model, n_actions, self.atoms) 25 | super().__init__(model, optimizer, name=name, **kwargs) 26 | 27 | def project(self, dist, support): 28 | target_dist = dist * 0 29 | atoms = self.atoms 30 | v_min = atoms[0] 31 | v_max = 
atoms[-1] 32 | delta_z = atoms[1] - atoms[0] 33 | batch_size = len(dist) 34 | n_atoms = len(atoms) 35 | # vectorized implementation of Algorithm 1 36 | tz_j = support.clamp(v_min, v_max) 37 | bj = (tz_j - v_min) / delta_z 38 | l = bj.floor().clamp(0, len(atoms) - 1) 39 | u = bj.ceil().clamp(0, len(atoms) - 1) 40 | # This part is a little tricky: 41 | # We have to flatten the matrix first and use index_add. 42 | # This approach is taken from Curt Park (under the MIT license): 43 | # https://github.com/Curt-Park/rainbow-is-all-you-need/blob/master/08.rainbow.ipynb 44 | offset = ( 45 | torch.linspace(0, (batch_size - 1) * n_atoms, batch_size) 46 | .long() 47 | .unsqueeze(1) 48 | .expand(batch_size, n_atoms) 49 | .to(self.device) 50 | ) 51 | target_dist.view(-1).index_add_( 52 | 0, (l.long() + offset).view(-1), (dist * (u - bj)).view(-1) 53 | ) 54 | target_dist.view(-1).index_add_( 55 | 0, (u.long() + offset).view(-1), (dist * (bj - l)).view(-1) 56 | ) 57 | return target_dist 58 | 59 | 60 | class QDistModule(torch.nn.Module): 61 | def __init__(self, model, n_actions, atoms): 62 | super().__init__() 63 | self.atoms = atoms 64 | self.n_actions = n_actions 65 | self.n_atoms = len(atoms) 66 | self.device = next(model.parameters()).device 67 | self.terminal = torch.zeros((self.n_atoms)).to(self.device) 68 | self.terminal[(self.n_atoms // 2)] = 1.0 69 | self.model = nn.RLNetwork(model) 70 | self.count = 0 71 | 72 | def forward(self, states, actions=None): 73 | values = self.model(states).view((len(states), self.n_actions, self.n_atoms)) 74 | values = F.softmax(values, dim=2) 75 | mask = states.mask 76 | 77 | # trick to convert to terminal without manually looping 78 | if torch.is_tensor(mask): 79 | values = (values - self.terminal) * states.mask.view( 80 | (-1, 1, 1) 81 | ).float() + self.terminal 82 | else: 83 | values = (values - self.terminal) * mask + self.terminal 84 | 85 | if actions is None: 86 | return values 87 | if isinstance(actions, list): 88 | actions = torch.cat(actions) 89 | return values[torch.arange(len(states)), actions] 90 | 91 | def to(self, device): 92 | self.device = device 93 | self.atoms = self.atoms.to(device) 94 | self.terminal = self.terminal.to(device) 95 | return super().to(device) 96 | -------------------------------------------------------------------------------- /all/approximation/q_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from all.nn import RLNetwork 4 | 5 | from .approximation import Approximation 6 | 7 | 8 | class QNetwork(Approximation): 9 | def __init__(self, model, optimizer=None, name="q", **kwargs): 10 | model = QModule(model) 11 | super().__init__(model, optimizer, name=name, **kwargs) 12 | 13 | 14 | class QModule(RLNetwork): 15 | def forward(self, states, actions=None): 16 | values = super().forward(states) 17 | if actions is None: 18 | return values 19 | if isinstance(actions, list): 20 | actions = torch.tensor(actions, device=self.device) 21 | return values.gather(1, actions.view(-1, 1)).squeeze(1) 22 | -------------------------------------------------------------------------------- /all/approximation/q_network_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import torch 5 | import torch_testing as tt 6 | from torch import nn 7 | from torch.nn.functional import smooth_l1_loss 8 | 9 | from all.approximation import FixedTarget, QNetwork 10 | from all.core import State, StateArray 11 | 12 | 
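# For intuition, QModule (defined in q_network.py above) reduces a (batch, n_actions) value
# matrix to the value of the chosen action in each row via gather. A minimal sketch using
# plain torch (illustrative only, not used by the tests below):
#
#   values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
#   actions = torch.tensor([2, 0])
#   values.gather(1, actions.view(-1, 1)).squeeze(1)  # -> tensor([3., 4.])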
STATE_DIM = 2 13 | ACTIONS = 3 14 | 15 | 16 | class TestQNetwork(unittest.TestCase): 17 | def setUp(self): 18 | torch.manual_seed(2) 19 | self.model = nn.Sequential(nn.Linear(STATE_DIM, ACTIONS)) 20 | 21 | def optimizer(params): 22 | return torch.optim.SGD(params, lr=0.1) 23 | 24 | self.q = QNetwork(self.model, optimizer) 25 | 26 | def test_eval_list(self): 27 | states = StateArray( 28 | torch.randn(5, STATE_DIM), (5,), mask=torch.tensor([1, 1, 0, 1, 0]) 29 | ) 30 | result = self.q.eval(states) 31 | tt.assert_almost_equal( 32 | result, 33 | torch.tensor( 34 | [ 35 | [-0.238509, -0.726287, -0.034026], 36 | [-0.35688755, -0.6612102, 0.34849477], 37 | [0.0, 0.0, 0.0], 38 | [0.1944, -0.5536, -0.2345], 39 | [0.0, 0.0, 0.0], 40 | ] 41 | ), 42 | decimal=2, 43 | ) 44 | 45 | def test_eval_actions(self): 46 | states = StateArray(torch.randn(3, STATE_DIM), (3,)) 47 | actions = [1, 2, 0] 48 | result = self.q.eval(states, actions) 49 | self.assertEqual(result.shape, torch.Size([3])) 50 | tt.assert_almost_equal( 51 | result, torch.tensor([-0.7262873, 0.3484948, -0.0296164]) 52 | ) 53 | 54 | def test_target_net(self): 55 | torch.manual_seed(2) 56 | model = nn.Sequential(nn.Linear(1, 1)) 57 | optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 58 | q = QNetwork(model, optimizer, target=FixedTarget(3)) 59 | inputs = State(torch.tensor([1.0])) 60 | 61 | def loss(policy_value): 62 | target = policy_value - 1 63 | return smooth_l1_loss(policy_value, target.detach()) 64 | 65 | policy_value = q(inputs) 66 | target_value = q.target(inputs).item() 67 | np.testing.assert_equal(policy_value.item(), -0.008584141731262207) 68 | np.testing.assert_equal(target_value, -0.008584141731262207) 69 | 70 | q.reinforce(loss(policy_value)) 71 | policy_value = q(inputs) 72 | target_value = q.target(inputs).item() 73 | np.testing.assert_equal(policy_value.item(), -0.20858412981033325) 74 | np.testing.assert_equal(target_value, -0.008584141731262207) 75 | 76 | q.reinforce(loss(policy_value)) 77 | policy_value = q(inputs) 78 | target_value = q.target(inputs).item() 79 | np.testing.assert_equal(policy_value.item(), -0.4085841178894043) 80 | np.testing.assert_equal(target_value, -0.008584141731262207) 81 | 82 | q.reinforce(loss(policy_value)) 83 | policy_value = q(inputs) 84 | target_value = q.target(inputs).item() 85 | np.testing.assert_equal(policy_value.item(), -0.6085841655731201) 86 | np.testing.assert_equal(target_value, -0.6085841655731201) 87 | 88 | q.reinforce(loss(policy_value)) 89 | policy_value = q(inputs) 90 | target_value = q.target(inputs).item() 91 | np.testing.assert_equal(policy_value.item(), -0.8085841536521912) 92 | np.testing.assert_equal(target_value, -0.6085841655731201) 93 | 94 | 95 | if __name__ == "__main__": 96 | unittest.main() 97 | -------------------------------------------------------------------------------- /all/approximation/target/__init__.py: -------------------------------------------------------------------------------- 1 | from .abstract import TargetNetwork 2 | from .fixed import FixedTarget 3 | from .polyak import PolyakTarget 4 | from .trivial import TrivialTarget 5 | 6 | __all__ = ["TargetNetwork", "FixedTarget", "PolyakTarget", "TrivialTarget"] 7 | -------------------------------------------------------------------------------- /all/approximation/target/abstract.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class TargetNetwork(ABC): 5 | @abstractmethod 6 | def __call__(self, *inputs): 7 | pass 8 | 
9 | @abstractmethod 10 | def init(self, model): 11 | pass 12 | 13 | @abstractmethod 14 | def update(self): 15 | pass 16 | -------------------------------------------------------------------------------- /all/approximation/target/fixed.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | 5 | from .abstract import TargetNetwork 6 | 7 | 8 | class FixedTarget(TargetNetwork): 9 | def __init__(self, update_frequency): 10 | self._source = None 11 | self._target = None 12 | self._updates = 0 13 | self._update_frequency = update_frequency 14 | 15 | def __call__(self, *inputs): 16 | with torch.no_grad(): 17 | return self._target(*inputs) 18 | 19 | def init(self, model): 20 | self._source = model 21 | self._target = copy.deepcopy(model) 22 | 23 | def update(self): 24 | self._updates += 1 25 | if self._should_update(): 26 | self._target.load_state_dict(self._source.state_dict()) 27 | 28 | def _should_update(self): 29 | return self._updates % self._update_frequency == 0 30 | -------------------------------------------------------------------------------- /all/approximation/target/polyak.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | 5 | from .abstract import TargetNetwork 6 | 7 | 8 | class PolyakTarget(TargetNetwork): 9 | """TargetNetwork that updates using polyak averaging""" 10 | 11 | def __init__(self, rate): 12 | self._source = None 13 | self._target = None 14 | self._rate = rate 15 | 16 | def __call__(self, *inputs): 17 | with torch.no_grad(): 18 | return self._target(*inputs) 19 | 20 | def init(self, model): 21 | self._source = model 22 | self._target = copy.deepcopy(model) 23 | 24 | def update(self): 25 | for target_param, source_param in zip( 26 | self._target.parameters(), self._source.parameters() 27 | ): 28 | target_param.data.copy_( 29 | target_param.data * (1.0 - self._rate) + source_param.data * self._rate 30 | ) 31 | -------------------------------------------------------------------------------- /all/approximation/target/trivial.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .abstract import TargetNetwork 4 | 5 | 6 | class TrivialTarget(TargetNetwork): 7 | def __init__(self): 8 | self._model = None 9 | 10 | def __call__(self, *inputs): 11 | with torch.no_grad(): 12 | return self._model(*inputs) 13 | 14 | def init(self, model): 15 | self._model = model 16 | 17 | def update(self): 18 | pass 19 | -------------------------------------------------------------------------------- /all/approximation/v_network.py: -------------------------------------------------------------------------------- 1 | from all.nn import RLNetwork 2 | 3 | from .approximation import Approximation 4 | 5 | 6 | class VNetwork(Approximation): 7 | def __init__(self, model, optimizer, name="v", **kwargs): 8 | model = VModule(model) 9 | super().__init__(model, optimizer, name=name, **kwargs) 10 | 11 | 12 | class VModule(RLNetwork): 13 | def forward(self, states): 14 | return super().forward(states).squeeze(-1) 15 | -------------------------------------------------------------------------------- /all/approximation/v_network_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import torch_testing as tt 5 | from torch import nn 6 | 7 | from all.approximation.v_network import VNetwork 8 | from all.core import StateArray 9 | 10 | 
STATE_DIM = 2 11 | 12 | 13 | def loss(value, error): 14 | target = value + error 15 | return ((target.detach() - value) ** 2).mean() 16 | 17 | 18 | class TestVNetwork(unittest.TestCase): 19 | def setUp(self): 20 | torch.manual_seed(2) 21 | self.model = nn.Sequential(nn.Linear(STATE_DIM, 1)) 22 | 23 | optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1) 24 | self.v = VNetwork(self.model, optimizer) 25 | 26 | def test_reinforce_list(self): 27 | states = StateArray( 28 | torch.randn(5, STATE_DIM), (5,), mask=torch.tensor([1, 1, 0, 1, 0]) 29 | ) 30 | result = self.v(states) 31 | tt.assert_almost_equal( 32 | result, torch.tensor([0.7053187, 0.3975691, 0.0, 0.2701665, 0.0]) 33 | ) 34 | 35 | self.v.reinforce(loss(result, torch.tensor([1, -1, 1, 1, 1])).float()) 36 | result = self.v(states) 37 | tt.assert_almost_equal( 38 | result, torch.tensor([0.9732854, 0.5453826, 0.0, 0.4344811, 0.0]) 39 | ) 40 | 41 | def test_multi_reinforce(self): 42 | states = StateArray( 43 | torch.randn(6, STATE_DIM), (6,), mask=torch.tensor([1, 1, 0, 1, 0, 0, 0]) 44 | ) 45 | result1 = self.v(states[0:2]) 46 | self.v.reinforce(loss(result1, torch.tensor([1, 2])).float()) 47 | result2 = self.v(states[2:4]) 48 | self.v.reinforce(loss(result2, torch.tensor([1, 1])).float()) 49 | result3 = self.v(states[4:6]) 50 | self.v.reinforce(loss(result3, torch.tensor([1, 2])).float()) 51 | with self.assertRaises(Exception): 52 | self.v.reinforce(loss(result3, torch.tensor([1, 2])).float()) 53 | 54 | 55 | if __name__ == "__main__": 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /all/bodies/__init__.py: -------------------------------------------------------------------------------- 1 | from ._body import Body 2 | from .atari import DeepmindAtariBody 3 | from .rewards import ClipRewards 4 | from .time import TimeFeature 5 | from .vision import FrameStack 6 | 7 | __all__ = ["Body", "ClipRewards", "DeepmindAtariBody", "FrameStack", "TimeFeature"] 8 | -------------------------------------------------------------------------------- /all/bodies/_body.py: -------------------------------------------------------------------------------- 1 | from all.agents import Agent 2 | 3 | 4 | class Body(Agent): 5 | """ 6 | A Body wraps a reinforcement learning Agent, altering its inputs and outputs. 7 | 8 | The Body API is identical to the Agent API from the perspective of the 9 | rest of the system. This base class is provided only for semantic clarity. 
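Example (a minimal sketch of a custom Body; RewardScaling is a hypothetical subclass, not part of this package):

    class RewardScaling(Body):
        def process_state(self, state):
            return state.update("reward", state.reward * 0.1)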
10 | """ 11 | 12 | def __init__(self, agent): 13 | self._agent = agent 14 | 15 | @property 16 | def agent(self): 17 | return self._agent 18 | 19 | @agent.setter 20 | def agent(self, agent): 21 | self._agent = agent 22 | 23 | def act(self, state): 24 | return self.process_action(self.agent.act(self.process_state(state))) 25 | 26 | def eval(self, state): 27 | return self.process_action(self.agent.eval(self.process_state(state))) 28 | 29 | def process_state(self, state): 30 | return state 31 | 32 | def process_action(self, action): 33 | return action 34 | -------------------------------------------------------------------------------- /all/bodies/atari.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ._body import Body 4 | from .rewards import ClipRewards 5 | from .vision import FrameStack 6 | 7 | 8 | class DeepmindAtariBody(Body): 9 | def __init__( 10 | self, 11 | agent, 12 | lazy_frames=False, 13 | episodic_lives=True, 14 | frame_stack=4, 15 | clip_rewards=True, 16 | ): 17 | if frame_stack > 1: 18 | agent = FrameStack(agent, lazy=lazy_frames, size=frame_stack) 19 | if clip_rewards: 20 | agent = ClipRewards(agent) 21 | if episodic_lives: 22 | agent = EpisodicLives(agent) 23 | super().__init__(agent) 24 | 25 | 26 | class EpisodicLives(Body): 27 | def process_state(self, state): 28 | if "life_lost" not in state: 29 | return state 30 | 31 | if len(state.shape) == 0: 32 | if state["life_lost"]: 33 | return state.update("mask", 0.0) 34 | return state 35 | 36 | masks = [None] * len(state) 37 | life_lost = state["life_lost"] 38 | for i, old_mask in enumerate(state.mask): 39 | if life_lost[i]: 40 | masks[i] = 0.0 41 | else: 42 | masks[i] = old_mask 43 | return state.update("mask", torch.tensor(masks, device=state.device)) 44 | -------------------------------------------------------------------------------- /all/bodies/rewards.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ._body import Body 5 | 6 | 7 | class ClipRewards(Body): 8 | def process_state(self, state): 9 | return state.update("reward", self._clip(state.reward)) 10 | 11 | def _clip(self, reward): 12 | if torch.is_tensor(reward): 13 | return torch.sign(reward) 14 | return float(np.sign(reward)) 15 | -------------------------------------------------------------------------------- /all/bodies/time.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from all.core import StateArray 4 | 5 | from ._body import Body 6 | 7 | 8 | class TimeFeature(Body): 9 | def __init__(self, agent, scale=0.001): 10 | self.timestep = None 11 | self.scale = scale 12 | super().__init__(agent) 13 | 14 | def process_state(self, state): 15 | if isinstance(state, StateArray): 16 | if self.timestep is None: 17 | self.timestep = torch.zeros(state.shape, device=state.device) 18 | observation = torch.cat( 19 | (state.observation, self.scale * self.timestep.view(-1, 1)), dim=1 20 | ) 21 | state = state.update("observation", observation) 22 | self.timestep = state.mask.float() * (self.timestep + 1) 23 | return state 24 | 25 | if self.timestep is None: 26 | self.timestep = 0 27 | state.update("timestep", self.timestep) 28 | observation = torch.cat( 29 | ( 30 | state.observation, 31 | torch.tensor(self.scale * self.timestep, device=state.device).view(-1), 32 | ), 33 | dim=0, 34 | ) 35 | state = state.update("observation", observation) 36 | self.timestep = state.mask * 
(self.timestep + 1) 37 | return state 38 | -------------------------------------------------------------------------------- /all/bodies/vision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from all.core import State, StateArray 4 | 5 | from ._body import Body 6 | 7 | 8 | class FrameStack(Body): 9 | def __init__(self, agent, size=4, lazy=False): 10 | super().__init__(agent) 11 | self._frames = [] 12 | self._size = size 13 | self._lazy = lazy 14 | self._to_cache = TensorDeviceCache() 15 | 16 | def process_state(self, state): 17 | if not self._frames: 18 | self._frames = [state.observation] * self._size 19 | else: 20 | self._frames = self._frames[1:] + [state.observation] 21 | if self._lazy: 22 | return LazyState.from_state(state, self._frames, self._to_cache) 23 | if isinstance(state, StateArray): 24 | return state.update("observation", torch.cat(self._frames, dim=1)) 25 | return state.update("observation", torch.cat(self._frames, dim=0)) 26 | 27 | 28 | class TensorDeviceCache: 29 | """ 30 | To efficiently implement device trasfer of lazy states, this class 31 | caches the transfered tensor so that it is not copied multiple times. 32 | """ 33 | 34 | def __init__(self, max_size=16): 35 | self.max_size = max_size 36 | self.cache_data = [] 37 | 38 | def convert(self, value, device): 39 | cached = None 40 | for el in self.cache_data: 41 | if el[0] is value: 42 | cached = el[1] 43 | break 44 | if cached is not None and cached.device == torch.device(device): 45 | new_v = cached 46 | else: 47 | new_v = value.to(device) 48 | self.cache_data.append((value, new_v)) 49 | if len(self.cache_data) > self.max_size: 50 | self.cache_data.pop(0) 51 | return new_v 52 | 53 | 54 | class LazyState(State): 55 | @classmethod 56 | def from_state(cls, state, frames, to_cache): 57 | state = LazyState(state, device=frames[0].device) 58 | state.to_cache = to_cache 59 | state["observation"] = frames 60 | return state 61 | 62 | def __getitem__(self, key): 63 | if key == "observation": 64 | v = dict.__getitem__(self, key) 65 | if torch.is_tensor(v): 66 | return v 67 | return torch.cat(dict.__getitem__(self, key), dim=0) 68 | return super().__getitem__(key) 69 | 70 | def update(self, key, value): 71 | x = {} 72 | for k in self.keys(): 73 | if not k == key: 74 | x[k] = dict.__getitem__(self, k) 75 | x[key] = value 76 | state = LazyState.from_state(x, x["observation"], self.to_cache) 77 | return state 78 | 79 | def to(self, device): 80 | if device == self.device: 81 | return self 82 | x = {} 83 | for key, value in self.items(): 84 | if key == "observation": 85 | x[key] = [self.to_cache.convert(v, device) for v in value] 86 | # x[key] = [v.to(device) for v in value]#torch.cat(value,axis=0).to(device) 87 | elif torch.is_tensor(value): 88 | x[key] = value.to(device) 89 | else: 90 | x[key] = value 91 | state = LazyState.from_state(x, x["observation"], self.to_cache) 92 | return state 93 | -------------------------------------------------------------------------------- /all/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .state import MultiagentState, State, StateArray 2 | 3 | __all__ = ["State", "StateArray", "MultiagentState"] 4 | -------------------------------------------------------------------------------- /all/environments/__init__.py: -------------------------------------------------------------------------------- 1 | from ._environment import Environment 2 | from ._multiagent_environment import 
MultiagentEnvironment 3 | from ._vector_environment import VectorEnvironment 4 | from .atari import AtariEnvironment 5 | from .duplicate_env import DuplicateEnvironment 6 | from .gym import GymEnvironment 7 | from .mujoco import MujocoEnvironment 8 | from .multiagent_atari import MultiagentAtariEnv 9 | from .multiagent_pettingzoo import MultiagentPettingZooEnv 10 | from .pybullet import PybulletEnvironment 11 | from .vector_env import GymVectorEnvironment 12 | 13 | __all__ = [ 14 | "AtariEnvironment", 15 | "DuplicateEnvironment", 16 | "Environment", 17 | "GymEnvironment", 18 | "GymVectorEnvironment", 19 | "MultiagentAtariEnv", 20 | "MultiagentEnvironment", 21 | "MultiagentPettingZooEnv", 22 | "MujocoEnvironment", 23 | "PybulletEnvironment", 24 | "VectorEnvironment", 25 | ] 26 | -------------------------------------------------------------------------------- /all/environments/_environment.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class Environment(ABC): 5 | """ 6 | A reinforcement learning Environment. 7 | 8 | In reinforcement learning, an Agent learns by interacting with an Environment. 9 | An Environment defines the dynamics of a particular problem: 10 | the states, the actions, the transitions between states, and the rewards given to the agent. 11 | Environments are often used to benchmark reinforcement learning agents, 12 | or to define real problems that the user hopes to solve using reinforcement learning. 13 | """ 14 | 15 | @property 16 | @abstractmethod 17 | def name(self): 18 | """ 19 | The name of the environment. 20 | """ 21 | 22 | @abstractmethod 23 | def reset(self): 24 | """ 25 | Reset the environment and return a new initial state. 26 | 27 | Returns 28 | ------- 29 | State 30 | The initial state for the next episode. 31 | """ 32 | 33 | @abstractmethod 34 | def step(self, action): 35 | """ 36 | Apply an action and get the next state. 37 | 38 | Parameters 39 | ---------- 40 | action : Action 41 | The action to apply at the current time step. 42 | 43 | Returns 44 | ------- 45 | all.environments.State 46 | The State of the environment after the action is applied. 47 | This State object includes both the done flag and any additional "info" 48 | float 49 | The reward achieved by the previous action 50 | """ 51 | 52 | @abstractmethod 53 | def render(self, **kwargs): 54 | """ 55 | Render the current environment state. 56 | """ 57 | 58 | @abstractmethod 59 | def close(self): 60 | """ 61 | Clean up any extraneous environment objects. 62 | """ 63 | 64 | @property 65 | @abstractmethod 66 | def state(self): 67 | """ 68 | The State of the Environment at the current timestep. 69 | """ 70 | 71 | @property 72 | @abstractmethod 73 | def state_space(self): 74 | """ 75 | The Space representing the range of observable states. 76 | 77 | Returns 78 | ------- 79 | Space 80 | An object of type Space that represents possible states the agent may observe 81 | """ 82 | 83 | @property 84 | def observation_space(self): 85 | """ 86 | Alias for Environment.state_space. 87 | 88 | Returns 89 | ------- 90 | Space 91 | An object of type Space that represents possible states the agent may observe 92 | """ 93 | return self.state_space 94 | 95 | @property 96 | @abstractmethod 97 | def action_space(self): 98 | """ 99 | The Space representing the range of possible actions. 
100 | 101 | Returns 102 | ------- 103 | Space 104 | An object of type Space that represents possible actions the agent may take 105 | """ 106 | 107 | @abstractmethod 108 | def duplicate(self, n): 109 | """ 110 | Create n copies of this environment. 111 | """ 112 | 113 | @property 114 | @abstractmethod 115 | def device(self): 116 | """ 117 | The torch device the environment lives on. 118 | """ 119 | -------------------------------------------------------------------------------- /all/environments/_multiagent_environment.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class MultiagentEnvironment(ABC): 5 | """ 6 | A multiagent reinforcement learning Environment. 7 | 8 | The Multiagent variant of the Environment object. 9 | An Environment defines the dynamics of a particular problem: 10 | the states, the actions, the transitions between states, and the rewards given to the agent. 11 | Environments are often used to benchmark reinforcement learning agents, 12 | or to define real problems that the user hopes to solve using reinforcement learning. 13 | """ 14 | 15 | @abstractmethod 16 | def reset(self): 17 | """ 18 | Reset the environment and return a new initial state for the first agent. 19 | 20 | Returns 21 | all.core.MultiagentState: The initial state for the next episode. 22 | """ 23 | 24 | @abstractmethod 25 | def step(self, action): 26 | """ 27 | Apply an action for the current agent and get the multiagent state for the next agent. 28 | 29 | Parameters: 30 | action: The Action for the current agent and timestep. 31 | 32 | Returns: 33 | all.core.MultiagentState: The state for the next agent. 34 | """ 35 | 36 | @abstractmethod 37 | def render(self, **kwargs): 38 | """Render the current environment state.""" 39 | 40 | @abstractmethod 41 | def close(self): 42 | """Clean up any extraneous environment objects.""" 43 | 44 | @abstractmethod 45 | def agent_iter(self): 46 | """ 47 | Create an iterable which that the next element is always the name of the agent whose turn it is to act. 48 | 49 | Returns: 50 | An Iterable over Agent strings. 51 | """ 52 | 53 | @abstractmethod 54 | def last(self): 55 | """ 56 | Get the MultiagentState object for the current agent. 57 | 58 | Returns: 59 | The all.core.MultiagentState object for the current agent. 60 | """ 61 | 62 | @abstractmethod 63 | def is_done(self, agent): 64 | """ 65 | Determine whether a given agent is done. 66 | 67 | Args: 68 | agent (str): The name of the agent. 69 | 70 | Returns: 71 | A boolean representing whether the given agent is done. 72 | """ 73 | 74 | @property 75 | def state(self): 76 | """The State for the current agent.""" 77 | return self.last() 78 | 79 | @property 80 | @abstractmethod 81 | def name(self): 82 | """str: The name of the environment.""" 83 | 84 | @abstractmethod 85 | def state_space(self, agent_id): 86 | """The state space for the given agent.""" 87 | 88 | def observation_space(self, agent_id): 89 | """Alias for MultiagentEnvironment.state_space(agent_id).""" 90 | return self.state_space(agent_id) 91 | 92 | @abstractmethod 93 | def action_space(self): 94 | """The action space for the given agent.""" 95 | 96 | @property 97 | @abstractmethod 98 | def device(self): 99 | """ 100 | The torch device the environment lives on. 
101 | """ 102 | -------------------------------------------------------------------------------- /all/environments/_vector_environment.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class VectorEnvironment(ABC): 5 | """ 6 | A reinforcement learning vector Environment. 7 | 8 | Similar to a regular RL environment except many environments are stacked together 9 | in the observations, rewards, and dones, and the vector environment expects 10 | an action to be given for each environment in step. 11 | 12 | Also, since sub-environments are done at different times, you do not need to 13 | manually reset the environments when they are done, rather the vector environment 14 | automatically resets environments when they are complete. 15 | """ 16 | 17 | @property 18 | @abstractmethod 19 | def name(self): 20 | """ 21 | The name of the environment. 22 | """ 23 | 24 | @abstractmethod 25 | def reset(self): 26 | """ 27 | Reset the environment and return a new initial state. 28 | 29 | Returns 30 | ------- 31 | State 32 | The initial state for the next episode. 33 | """ 34 | 35 | @abstractmethod 36 | def step(self, action): 37 | """ 38 | Apply an action and get the next state. 39 | 40 | Parameters 41 | ---------- 42 | action : Action 43 | The action to apply at the current time step. 44 | 45 | Returns 46 | ------- 47 | all.environments.State 48 | The State of the environment after the action is applied. 49 | This State object includes both the done flag and any additional "info" 50 | float 51 | The reward achieved by the previous action 52 | """ 53 | 54 | @abstractmethod 55 | def close(self): 56 | """ 57 | Clean up any extraneous environment objects. 58 | """ 59 | 60 | @property 61 | @abstractmethod 62 | def state_array(self): 63 | """ 64 | A StateArray of the Environments at the current timestep. 65 | """ 66 | 67 | @property 68 | @abstractmethod 69 | def state_space(self): 70 | """ 71 | The Space representing the range of observable states for each environment. 72 | 73 | Returns 74 | ------- 75 | Space 76 | An object of type Space that represents possible states the agent may observe 77 | """ 78 | 79 | @property 80 | def observation_space(self): 81 | """ 82 | Alias for Environment.state_space. 83 | 84 | Returns 85 | ------- 86 | Space 87 | An object of type Space that represents possible states the agent may observe 88 | """ 89 | return self.state_space 90 | 91 | @property 92 | @abstractmethod 93 | def action_space(self): 94 | """ 95 | The Space representing the range of possible actions for each environment. 96 | 97 | Returns 98 | ------- 99 | Space 100 | An object of type Space that represents possible actions the agent may take 101 | """ 102 | 103 | @property 104 | @abstractmethod 105 | def device(self): 106 | """ 107 | The torch device the environment lives on. 108 | """ 109 | 110 | @property 111 | @abstractmethod 112 | def num_envs(self): 113 | """ 114 | Number of environments in vector. This is the number of actions step() expects as input 115 | and the number of observations, dones, etc returned by the environment. 
116 | """ 117 | -------------------------------------------------------------------------------- /all/environments/atari.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | import torch 3 | 4 | from all.core import State 5 | 6 | from ._environment import Environment 7 | from .atari_wrappers import ( 8 | FireResetEnv, 9 | LifeLostEnv, 10 | MaxAndSkipEnv, 11 | NoopResetEnv, 12 | WarpFrame, 13 | ) 14 | from .duplicate_env import DuplicateEnvironment 15 | 16 | 17 | class AtariEnvironment(Environment): 18 | def __init__(self, name, device="cpu", **gym_make_kwargs): 19 | 20 | # construct the environment 21 | env = gymnasium.make(name + "NoFrameskip-v4", **gym_make_kwargs) 22 | 23 | # apply a subset of wrappers 24 | env = NoopResetEnv(env, noop_max=30) 25 | env = MaxAndSkipEnv(env) 26 | if "FIRE" in env.unwrapped.get_action_meanings(): 27 | env = FireResetEnv(env) 28 | env = WarpFrame(env) 29 | env = LifeLostEnv(env) 30 | 31 | # initialize member variables 32 | self._env = env 33 | self._name = name 34 | self._state = None 35 | self._action = None 36 | self._reward = None 37 | self._done = True 38 | self._info = None 39 | self._device = device 40 | 41 | def reset(self): 42 | self._state = State.from_gym( 43 | self._env.reset(), 44 | dtype=self._env.observation_space.dtype, 45 | device=self._device, 46 | ) 47 | return self._state 48 | 49 | def step(self, action): 50 | self._state = State.from_gym( 51 | self._env.step(self._convert(action)), 52 | dtype=self._env.observation_space.dtype, 53 | device=self._device, 54 | ) 55 | return self._state 56 | 57 | def render(self, **kwargs): 58 | return self._env.render(**kwargs) 59 | 60 | def close(self): 61 | return self._env.close() 62 | 63 | def seed(self, seed): 64 | self._env.seed(seed) 65 | 66 | def duplicate(self, n): 67 | return DuplicateEnvironment( 68 | [AtariEnvironment(self._name, device=self._device) for _ in range(n)] 69 | ) 70 | 71 | @property 72 | def name(self): 73 | return self._name 74 | 75 | @property 76 | def state_space(self): 77 | return self._env.observation_space 78 | 79 | @property 80 | def action_space(self): 81 | return self._env.action_space 82 | 83 | @property 84 | def state(self): 85 | return self._state 86 | 87 | @property 88 | def env(self): 89 | return self._env 90 | 91 | @property 92 | def device(self): 93 | return self._device 94 | 95 | def _convert(self, action): 96 | if torch.is_tensor(action): 97 | return action.item() 98 | return action 99 | -------------------------------------------------------------------------------- /all/environments/atari_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from all.environments import AtariEnvironment 4 | 5 | 6 | class AtariEnvironmentTest(unittest.TestCase): 7 | def test_reset(self): 8 | env = AtariEnvironment("Breakout") 9 | state = env.reset() 10 | self.assertEqual(state.observation.shape, (1, 84, 84)) 11 | self.assertEqual(state.reward, 0) 12 | self.assertFalse(state.done) 13 | self.assertEqual(state.mask, 1) 14 | self.assertEqual(state["life_lost"], False) 15 | 16 | def test_step(self): 17 | env = AtariEnvironment("Breakout") 18 | env.reset() 19 | state = env.step(1) 20 | self.assertEqual(state.observation.shape, (1, 84, 84)) 21 | self.assertEqual(state.reward, 0) 22 | self.assertFalse(state.done) 23 | self.assertEqual(state.mask, 1) 24 | self.assertEqual(state["life_lost"], False) 25 | 26 | def test_step_until_life_lost(self): 27 | env = 
AtariEnvironment("Breakout") 28 | env.reset() 29 | for _ in range(100): 30 | state = env.step(1) 31 | if state["life_lost"]: 32 | break 33 | self.assertEqual(state.observation.shape, (1, 84, 84)) 34 | self.assertEqual(state.reward, 0) 35 | self.assertFalse(state.done) 36 | self.assertEqual(state.mask, 1) 37 | self.assertEqual(state["life_lost"], True) 38 | 39 | def test_step_until_done(self): 40 | env = AtariEnvironment("Breakout") 41 | env.reset() 42 | for _ in range(1000): 43 | state = env.step(1) 44 | if state.done: 45 | break 46 | self.assertEqual(state.observation.shape, (1, 84, 84)) 47 | self.assertEqual(state.reward, 0) 48 | self.assertTrue(state.done) 49 | self.assertEqual(state.mask, 0) 50 | self.assertEqual(state["life_lost"], False) 51 | -------------------------------------------------------------------------------- /all/environments/duplicate_env.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from all.core import State 4 | 5 | from ._vector_environment import VectorEnvironment 6 | 7 | 8 | class DuplicateEnvironment(VectorEnvironment): 9 | """ 10 | Turns a list of ALL Environment objects into a VectorEnvironment object 11 | 12 | This wrapper just takes the list of States the environments generate and outputs 13 | a StateArray object containing all of the environment states. Like all vector 14 | environments, the sub environments are automatically reset when done. 15 | 16 | Args: 17 | envs: A list of ALL environments 18 | device (optional): the device on which tensors will be stored 19 | """ 20 | 21 | def __init__(self, envs, device=torch.device("cpu")): 22 | self._name = envs[0].name 23 | self._envs = envs 24 | self._state = None 25 | self._action = None 26 | self._reward = None 27 | self._done = True 28 | self._info = None 29 | self._device = device 30 | 31 | @property 32 | def name(self): 33 | return self._name 34 | 35 | def reset(self, seed=None, **kwargs): 36 | if seed is not None: 37 | self._state = State.array( 38 | [ 39 | sub_env.reset(seed=(seed + i), **kwargs) 40 | for i, sub_env in enumerate(self._envs) 41 | ] 42 | ) 43 | else: 44 | self._state = State.array( 45 | [sub_env.reset(**kwargs) for sub_env in self._envs] 46 | ) 47 | return self._state 48 | 49 | def step(self, actions): 50 | states = [] 51 | actions = actions.cpu().detach().numpy() 52 | for sub_env, action in zip(self._envs, actions): 53 | state = sub_env.reset() if sub_env.state.done else sub_env.step(action) 54 | states.append(state) 55 | self._state = State.array(states) 56 | return self._state 57 | 58 | def close(self): 59 | return self._env.close() 60 | 61 | @property 62 | def state_space(self): 63 | return self._envs[0].observation_space 64 | 65 | @property 66 | def action_space(self): 67 | return self._envs[0].action_space 68 | 69 | @property 70 | def state_array(self): 71 | return self._state 72 | 73 | @property 74 | def device(self): 75 | return self._device 76 | 77 | @property 78 | def num_envs(self): 79 | return len(self._envs) 80 | -------------------------------------------------------------------------------- /all/environments/duplicate_env_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from all.environments import DuplicateEnvironment, GymEnvironment 6 | 7 | 8 | def make_vec_env(num_envs=3): 9 | env = [GymEnvironment("CartPole-v0") for i in range(num_envs)] 10 | return env 11 | 12 | 13 | class DuplicateEnvironmentTest(unittest.TestCase): 14 
| def test_env_name(self): 15 | env = DuplicateEnvironment(make_vec_env()) 16 | self.assertEqual(env.name, "CartPole-v0") 17 | 18 | def test_num_envs(self): 19 | num_envs = 5 20 | env = DuplicateEnvironment(make_vec_env(num_envs)) 21 | self.assertEqual(env.num_envs, num_envs) 22 | self.assertEqual((num_envs,), env.reset().shape) 23 | 24 | def test_reset(self): 25 | num_envs = 5 26 | env = DuplicateEnvironment(make_vec_env(num_envs)) 27 | state = env.reset() 28 | self.assertEqual(state.observation.shape, (num_envs, 4)) 29 | self.assertTrue( 30 | ( 31 | state.reward 32 | == torch.zeros( 33 | num_envs, 34 | ) 35 | ).all() 36 | ) 37 | self.assertTrue( 38 | ( 39 | state.done 40 | == torch.zeros( 41 | num_envs, 42 | ) 43 | ).all() 44 | ) 45 | self.assertTrue( 46 | ( 47 | state.mask 48 | == torch.ones( 49 | num_envs, 50 | ) 51 | ).all() 52 | ) 53 | 54 | def test_step(self): 55 | num_envs = 5 56 | env = DuplicateEnvironment(make_vec_env(num_envs)) 57 | env.reset() 58 | state = env.step(torch.ones(num_envs, dtype=torch.int32)) 59 | self.assertEqual(state.observation.shape, (num_envs, 4)) 60 | self.assertTrue( 61 | ( 62 | state.reward 63 | == torch.ones( 64 | num_envs, 65 | ) 66 | ).all() 67 | ) 68 | self.assertTrue( 69 | ( 70 | state.done 71 | == torch.zeros( 72 | num_envs, 73 | ) 74 | ).all() 75 | ) 76 | self.assertTrue( 77 | ( 78 | state.mask 79 | == torch.ones( 80 | num_envs, 81 | ) 82 | ).all() 83 | ) 84 | 85 | def test_step_until_done(self): 86 | num_envs = 3 87 | env = DuplicateEnvironment(make_vec_env(num_envs)) 88 | env.reset() 89 | for _ in range(100): 90 | state = env.step(torch.ones(num_envs, dtype=torch.int32)) 91 | if state.done[0]: 92 | break 93 | self.assertEqual(state[0].observation.shape, (4,)) 94 | self.assertEqual(state[0].reward, 1.0) 95 | self.assertTrue(state[0].done) 96 | self.assertEqual(state[0].mask, 0) 97 | -------------------------------------------------------------------------------- /all/environments/gym_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import gym 4 | import gymnasium 5 | import torch 6 | 7 | from all.environments import GymEnvironment 8 | 9 | 10 | class GymEnvironmentTest(unittest.TestCase): 11 | def test_env_name(self): 12 | env = GymEnvironment("CartPole-v0") 13 | self.assertEqual(env.name, "CartPole-v0") 14 | 15 | def test_reset(self): 16 | env = GymEnvironment("CartPole-v0") 17 | state = env.reset() 18 | self.assertEqual(state.observation.shape, (4,)) 19 | self.assertEqual(state.reward, 0) 20 | self.assertFalse(state.done) 21 | self.assertEqual(state.mask, 1) 22 | 23 | def test_step(self): 24 | env = GymEnvironment("CartPole-v0") 25 | env.reset() 26 | state = env.step(1) 27 | self.assertEqual(state.observation.shape, (4,)) 28 | self.assertEqual(state.reward, 1.0) 29 | self.assertFalse(state.done) 30 | self.assertEqual(state.mask, 1) 31 | 32 | def test_step_until_done(self): 33 | env = GymEnvironment("CartPole-v0") 34 | env.reset() 35 | for _ in range(100): 36 | state = env.step(1) 37 | if state.done: 38 | break 39 | self.assertEqual(state.observation.shape, (4,)) 40 | self.assertEqual(state.reward, 1.0) 41 | self.assertTrue(state.done) 42 | self.assertEqual(state.mask, 0) 43 | 44 | def test_duplicate_default_params(self): 45 | env = GymEnvironment("CartPole-v0") 46 | duplicates = env.duplicate(5) 47 | for duplicate in duplicates._envs: 48 | self.assertEqual(duplicate._id, "CartPole-v0") 49 | self.assertEqual(duplicate._name, "CartPole-v0") 50 | self.assertEqual(env._device, 
torch.device("cpu")) 51 | self.assertEqual(env._gym, gymnasium) 52 | 53 | def test_duplicate_custom_params(self): 54 | class MyWrapper: 55 | def __init__(self, env): 56 | self._env = env 57 | 58 | env = GymEnvironment( 59 | "CartPole-v0", 60 | legacy_gym=True, 61 | name="legacy_cartpole", 62 | device="my_device", 63 | wrap_env=MyWrapper, 64 | ) 65 | duplicates = env.duplicate(5) 66 | for duplicate in duplicates._envs: 67 | self.assertEqual(duplicate._id, "CartPole-v0") 68 | self.assertEqual(duplicate._name, "legacy_cartpole") 69 | self.assertEqual(env._device, "my_device") 70 | self.assertEqual(env._gym, gym) 71 | -------------------------------------------------------------------------------- /all/environments/gym_wrappers.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | 3 | 4 | class NoInfoWrapper(gymnasium.Wrapper): 5 | """ 6 | Wrapper to suppress info and simply return a dict. 7 | This prevents State.from_gym() from create keys. 8 | """ 9 | 10 | def reset(self, seed=None, options=None): 11 | obs, _ = self.env.reset(seed=seed, options=options) 12 | return obs, {} 13 | 14 | def step(self, action): 15 | *obs, info = self.env.step(action) 16 | return *obs, {} 17 | -------------------------------------------------------------------------------- /all/environments/gym_wrappers_test.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/all/environments/gym_wrappers_test.py -------------------------------------------------------------------------------- /all/environments/mujoco.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .gym import GymEnvironment 4 | from .gym_wrappers import NoInfoWrapper 5 | 6 | 7 | class MujocoEnvironment(GymEnvironment): 8 | """A Mujoco Environment""" 9 | 10 | def __init__( 11 | self, id, device=torch.device("cpu"), name=None, no_info=True, **gym_make_kwargs 12 | ): 13 | wrap_env = NoInfoWrapper if no_info else None 14 | super().__init__( 15 | id, device=device, name=name, wrap_env=wrap_env, **gym_make_kwargs 16 | ) 17 | -------------------------------------------------------------------------------- /all/environments/mujoco_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from all.environments import MujocoEnvironment 4 | 5 | 6 | class MujocoEnvironmentTest(unittest.TestCase): 7 | def test_load_env(self): 8 | env = MujocoEnvironment("Ant-v4") 9 | self.assertEqual(env.name, "Ant-v4") 10 | 11 | def test_observation_space(self): 12 | env = MujocoEnvironment("Ant-v4") 13 | self.assertEqual(env.observation_space.shape, (27,)) 14 | 15 | def test_action_space(self): 16 | env = MujocoEnvironment("Ant-v4") 17 | self.assertEqual(env.action_space.shape, (8,)) 18 | 19 | def test_reset(self): 20 | env = MujocoEnvironment("Ant-v4") 21 | state = env.reset(seed=0) 22 | self.assertEqual(state.observation.shape, (27,)) 23 | self.assertEqual(state.reward, 0.0) 24 | self.assertFalse(state.done) 25 | self.assertEqual(state.mask, 1) 26 | 27 | def test_step(self): 28 | env = MujocoEnvironment("Ant-v4") 29 | state = env.reset(seed=0) 30 | state = env.step(env.action_space.sample()) 31 | self.assertEqual(state.observation.shape, (27,)) 32 | self.assertGreater(state.reward, -2.0) 33 | self.assertLess(state.reward, 2) 34 | self.assertNotEqual(state.reward, 0.0) 
35 | self.assertFalse(state.done) 36 | self.assertEqual(state.mask, 1) 37 | 38 | def test_no_info_wrapper(self): 39 | env = MujocoEnvironment("Ant-v4") 40 | state = env.reset(seed=0) 41 | self.assertFalse("reward_forward" in state) 42 | state = env.step(env.action_space.sample()) 43 | self.assertFalse("reward_forward" in state) 44 | 45 | def test_with_info(self): 46 | env = MujocoEnvironment("Ant-v4", no_info=False) 47 | state = env.reset(seed=0) 48 | state = env.step(env.action_space.sample()) 49 | self.assertTrue("reward_forward" in state) 50 | 51 | def test_duplicate(self): 52 | env = MujocoEnvironment("Ant-v4") 53 | duplicates = env.duplicate(2) 54 | for duplicate in duplicates._envs: 55 | state = duplicate.reset() 56 | self.assertFalse("reward_forward" in state) 57 | state = duplicate.step(env.action_space.sample()) 58 | self.assertFalse("reward_forward" in state) 59 | -------------------------------------------------------------------------------- /all/environments/multiagent_atari.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | from .multiagent_pettingzoo import MultiagentPettingZooEnv 4 | 5 | 6 | class MultiagentAtariEnv(MultiagentPettingZooEnv): 7 | """ 8 | A wrapper for PettingZoo Atari environments (see: https://www.pettingzoo.ml/atari). 9 | 10 | This wrapper converts the output of the PettingZoo environment to PyTorch tensors, 11 | and wraps them in a State object that can be passed to an Agent. 12 | 13 | Args: 14 | env_name (string): A string representing the name of the environment (e.g. pong-v1) 15 | device (optional): the device on which tensors will be stored 16 | """ 17 | 18 | def __init__(self, env_name, device="cuda", **pettingzoo_params): 19 | env = self._load_env(env_name, pettingzoo_params) 20 | super().__init__(env, name=env_name, device=device) 21 | 22 | def _load_env(self, env_name, pettingzoo_params): 23 | from supersuit import frame_skip_v0, max_observation_v0, reshape_v0, resize_v1 24 | 25 | env = importlib.import_module("pettingzoo.atari.{}".format(env_name)).env( 26 | obs_type="grayscale_image", **pettingzoo_params 27 | ) 28 | env = max_observation_v0(env, 2) 29 | env = frame_skip_v0(env, 4) 30 | env = resize_v1(env, 84, 84) 31 | env = reshape_v0(env, (1, 84, 84)) 32 | return env 33 | -------------------------------------------------------------------------------- /all/environments/multiagent_atari_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from all.environments import MultiagentAtariEnv 6 | 7 | 8 | class MultiagentAtariEnvTest(unittest.TestCase): 9 | def test_init(self): 10 | MultiagentAtariEnv("pong_v3", device="cpu") 11 | MultiagentAtariEnv("mario_bros_v3", device="cpu") 12 | MultiagentAtariEnv("entombed_cooperative_v3", device="cpu") 13 | 14 | def test_reset(self): 15 | env = MultiagentAtariEnv("pong_v3", device="cpu") 16 | state = env.reset() 17 | self.assertEqual(state.observation.shape, (1, 84, 84)) 18 | self.assertEqual(state.reward, 0) 19 | self.assertEqual(state.done, False) 20 | self.assertEqual(state.mask, 1.0) 21 | self.assertEqual(state["agent"], "first_0") 22 | 23 | def test_step(self): 24 | env = MultiagentAtariEnv("pong_v3", device="cpu") 25 | env.reset() 26 | state = env.step(0) 27 | self.assertEqual(state.observation.shape, (1, 84, 84)) 28 | self.assertEqual(state.reward, 0) 29 | self.assertEqual(state.done, False) 30 | self.assertEqual(state.mask, 1.0) 31 | 
self.assertEqual(state["agent"], "second_0") 32 | 33 | def test_step_tensor(self): 34 | env = MultiagentAtariEnv("pong_v3", device="cpu") 35 | env.reset() 36 | state = env.step(torch.tensor([0])) 37 | self.assertEqual(state.observation.shape, (1, 84, 84)) 38 | self.assertEqual(state.reward, 0) 39 | self.assertEqual(state.done, False) 40 | self.assertEqual(state.mask, 1.0) 41 | self.assertEqual(state["agent"], "second_0") 42 | 43 | def test_name(self): 44 | env = MultiagentAtariEnv("pong_v3", device="cpu") 45 | self.assertEqual(env.name, "pong_v3") 46 | 47 | def test_agent_iter(self): 48 | env = MultiagentAtariEnv("pong_v3", device="cpu") 49 | env.reset() 50 | it = iter(env.agent_iter()) 51 | self.assertEqual(next(it), "first_0") 52 | 53 | def test_state_spaces(self): 54 | env = MultiagentAtariEnv("pong_v3", device="cpu") 55 | self.assertEqual(env.state_space("first_0").shape, (1, 84, 84)) 56 | self.assertEqual(env.state_space("second_0").shape, (1, 84, 84)) 57 | 58 | def test_action_spaces(self): 59 | env = MultiagentAtariEnv("pong_v3", device="cpu") 60 | self.assertEqual(env.action_space("first_0").n, 6) 61 | self.assertEqual(env.action_space("second_0").n, 6) 62 | 63 | def test_list_agents(self): 64 | env = MultiagentAtariEnv("pong_v3", device="cpu") 65 | self.assertEqual(env.agents, ["first_0", "second_0"]) 66 | 67 | def test_is_done(self): 68 | env = MultiagentAtariEnv("pong_v3", device="cpu") 69 | env.reset() 70 | self.assertFalse(env.is_done("first_0")) 71 | self.assertFalse(env.is_done("second_0")) 72 | 73 | def test_last(self): 74 | env = MultiagentAtariEnv("pong_v3", device="cpu") 75 | env.reset() 76 | state = env.last() 77 | self.assertEqual(state.observation.shape, (1, 84, 84)) 78 | self.assertEqual(state.reward, 0) 79 | self.assertEqual(state.done, False) 80 | self.assertEqual(state.mask, 1.0) 81 | self.assertEqual(state["agent"], "first_0") 82 | 83 | 84 | if __name__ == "__main__": 85 | unittest.main() 86 | -------------------------------------------------------------------------------- /all/environments/pybullet.py: -------------------------------------------------------------------------------- 1 | from .gym import GymEnvironment 2 | 3 | 4 | class PybulletEnvironment(GymEnvironment): 5 | short_names = { 6 | "ant": "AntBulletEnv-v0", 7 | "cheetah": "HalfCheetahBulletEnv-v0", 8 | "hopper": "HopperBulletEnv-v0", 9 | "humanoid": "HumanoidBulletEnv-v0", 10 | "walker": "Walker2DBulletEnv-v0", 11 | } 12 | 13 | def __init__(self, name, **kwargs): 14 | import pybullet_envs # noqa: F401 15 | 16 | if name in self.short_names: 17 | name = self.short_names[name] 18 | super().__init__(name, legacy_gym=True, **kwargs) 19 | -------------------------------------------------------------------------------- /all/environments/pybullet_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from all.environments import PybulletEnvironment 6 | 7 | 8 | class PybulletEnvironmentTest(unittest.TestCase): 9 | def test_env_short_name(self): 10 | for short_name, long_name in PybulletEnvironment.short_names.items(): 11 | env = PybulletEnvironment(short_name) 12 | self.assertEqual(env.name, long_name) 13 | 14 | def test_env_full_name(self): 15 | env = PybulletEnvironment("HalfCheetahBulletEnv-v0") 16 | self.assertEqual(env.name, "HalfCheetahBulletEnv-v0") 17 | 18 | def test_reset(self): 19 | env = PybulletEnvironment("cheetah") 20 | state = env.reset() 21 | self.assertEqual(state.observation.shape, (26,)) 22 | 
self.assertEqual(state.reward, 0.0) 23 | self.assertFalse(state.done) 24 | self.assertEqual(state.mask, 1) 25 | 26 | def test_step(self): 27 | env = PybulletEnvironment("cheetah") 28 | env.seed(0) 29 | state = env.reset() 30 | state = env.step(env.action_space.sample()) 31 | self.assertEqual(state.observation.shape, (26,)) 32 | self.assertGreater(state.reward, -1.0) 33 | self.assertLess(state.reward, 1) 34 | self.assertNotEqual(state.reward, 0.0) 35 | self.assertFalse(state.done) 36 | self.assertEqual(state.mask, 1) 37 | 38 | def test_duplicate(self): 39 | env = PybulletEnvironment("cheetah") 40 | duplicates = env.duplicate(3) 41 | state = duplicates.reset() 42 | self.assertEqual(state.shape, (3,)) 43 | state = duplicates.step(torch.zeros(3, env.action_space.shape[0])) 44 | self.assertEqual(state.shape, (3,)) 45 | -------------------------------------------------------------------------------- /all/environments/vector_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from all.core import StateArray 5 | 6 | from ._vector_environment import VectorEnvironment 7 | 8 | 9 | class GymVectorEnvironment(VectorEnvironment): 10 | """ 11 | A wrapper for Gym's vector environments 12 | (see: https://github.com/openai/gym/blob/master/gym/vector/vector_env.py). 13 | 14 | This wrapper converts the output of the vector environment to PyTorch tensors, 15 | and wraps them in a StateArray object that can be passed to a Parallel Agent. 16 | This constructor accepts a preconstructed gym vector environment. Note that 17 | in the latter case, the name property is set to be the whatever the name 18 | of the outermost wrapper on the environment is. 19 | 20 | Args: 21 | vec_env: An OpenAI gym vector environment 22 | device (optional): the device on which tensors will be stored 23 | """ 24 | 25 | def __init__(self, vec_env, name, device=torch.device("cpu")): 26 | self._name = name 27 | self._env = vec_env 28 | self._state = None 29 | self._action = None 30 | self._reward = None 31 | self._done = True 32 | self._info = None 33 | self._device = device 34 | 35 | @property 36 | def name(self): 37 | return self._name 38 | 39 | def reset(self, **kwargs): 40 | obs, info = self._env.reset(**kwargs) 41 | self._state = self._to_state( 42 | obs, 43 | np.zeros(self._env.num_envs), 44 | np.zeros(self._env.num_envs), 45 | np.zeros(self._env.num_envs), 46 | info, 47 | ) 48 | return self._state 49 | 50 | def _to_state(self, obs, rew, terminated, truncated, info): 51 | obs = obs.astype(self.observation_space.dtype) 52 | rew = rew.astype("float32") 53 | done = (terminated + truncated).astype("bool") 54 | mask = (1 - terminated).astype("float32") 55 | return StateArray( 56 | { 57 | "observation": torch.tensor(obs, device=self._device), 58 | "reward": torch.tensor(rew, device=self._device), 59 | "done": torch.tensor(done, device=self._device), 60 | "mask": torch.tensor(mask, device=self._device), 61 | }, 62 | shape=(self._env.num_envs,), 63 | ) 64 | 65 | def step(self, action): 66 | state_tuple = self._env.step(action.cpu().detach().numpy()) 67 | self._state = self._to_state(*state_tuple) 68 | return self._state 69 | 70 | def close(self): 71 | return self._env.close() 72 | 73 | @property 74 | def state_space(self): 75 | return getattr( 76 | self._env, 77 | "single_observation_space", 78 | getattr(self._env, "observation_space"), 79 | ) 80 | 81 | @property 82 | def action_space(self): 83 | return getattr( 84 | self._env, "single_action_space", 
getattr(self._env, "action_space") 85 | ) 86 | 87 | @property 88 | def state_array(self): 89 | return self._state 90 | 91 | @property 92 | def device(self): 93 | return self._device 94 | 95 | @property 96 | def num_envs(self): 97 | return self._env.num_envs 98 | -------------------------------------------------------------------------------- /all/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | from .experiment import Experiment 2 | from .multiagent_env_experiment import MultiagentEnvExperiment 3 | from .parallel_env_experiment import ParallelEnvExperiment 4 | from .plots import plot_returns_100 5 | from .run_experiment import run_experiment 6 | from .single_env_experiment import SingleEnvExperiment 7 | from .slurm import SlurmExperiment 8 | from .watch import load_and_watch, watch 9 | 10 | __all__ = [ 11 | "run_experiment", 12 | "Experiment", 13 | "SingleEnvExperiment", 14 | "ParallelEnvExperiment", 15 | "MultiagentEnvExperiment", 16 | "SlurmExperiment", 17 | "watch", 18 | "load_and_watch", 19 | "plot_returns_100", 20 | ] 21 | -------------------------------------------------------------------------------- /all/experiments/plots.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | 7 | def plot_returns_100(runs_dir, timesteps=-1): 8 | data = load_returns_100_data(runs_dir) 9 | lines = {} 10 | fig, axes = plt.subplots(1, len(data)) 11 | if len(data) == 1: 12 | axes = [axes] 13 | for i, env in enumerate(sorted(data.keys())): 14 | ax = axes[i] 15 | subplot_returns_100(ax, env, data[env], lines, timesteps=timesteps) 16 | fig.legend(list(lines.values()), list(lines.keys()), loc="center right") 17 | plt.show() 18 | 19 | 20 | def load_returns_100_data(runs_dir): 21 | data = {} 22 | 23 | def add_data(agent, env, file): 24 | if env not in data: 25 | data[env] = {} 26 | data[env][agent] = np.genfromtxt(file, delimiter=",").reshape((-1, 5)) 27 | 28 | for agent_dir in os.listdir(runs_dir): 29 | agent, env, *_ = agent_dir.split("_") 30 | agent_path = os.path.join(runs_dir, agent_dir) 31 | if os.path.isdir(agent_path): 32 | returns100path = os.path.join(agent_path, "returns100.csv") 33 | if os.path.exists(returns100path): 34 | add_data(agent, env, returns100path) 35 | 36 | return data 37 | 38 | 39 | def subplot_returns_100(ax, env, data, lines, timesteps=-1): 40 | for agent in data: 41 | agent_data = data[agent] 42 | x = agent_data[:, 0] 43 | mean = agent_data[:, 1] 44 | std = agent_data[:, 2] 45 | 46 | if timesteps > 0: 47 | x[-1] = timesteps 48 | 49 | if agent in lines: 50 | ax.plot(x, mean, label=agent, color=lines[agent].get_color()) 51 | else: 52 | (line,) = ax.plot(x, mean, label=agent) 53 | lines[agent] = line 54 | ax.fill_between( 55 | x, mean + std, mean - std, alpha=0.2, color=lines[agent].get_color() 56 | ) 57 | ax.set_title(env) 58 | ax.set_xlabel("timesteps") 59 | ax.ticklabel_format(style="sci", axis="x", scilimits=(0, 5)) 60 | -------------------------------------------------------------------------------- /all/experiments/run_experiment.py: -------------------------------------------------------------------------------- 1 | from all.presets import ParallelPreset 2 | 3 | from .parallel_env_experiment import ParallelEnvExperiment 4 | from .single_env_experiment import SingleEnvExperiment 5 | 6 | 7 | def run_experiment( 8 | agents, 9 | envs, 10 | frames, 11 | logdir="runs", 12 | quiet=False, 13 | render=False, 14 | 
save_freq=100, 15 | test_episodes=100, 16 | verbose=True, 17 | ): 18 | if not isinstance(agents, list): 19 | agents = [agents] 20 | 21 | if not isinstance(envs, list): 22 | envs = [envs] 23 | 24 | for env in envs: 25 | for preset_builder in agents: 26 | preset = preset_builder.env(env).build() 27 | make_experiment = get_experiment_type(preset) 28 | experiment = make_experiment( 29 | preset, 30 | env, 31 | train_steps=frames, 32 | logdir=logdir, 33 | quiet=quiet, 34 | render=render, 35 | save_freq=save_freq, 36 | verbose=verbose, 37 | ) 38 | experiment.save() 39 | experiment.train(frames=frames) 40 | experiment.save() 41 | experiment.test(episodes=test_episodes) 42 | experiment.close() 43 | 44 | 45 | def get_experiment_type(preset): 46 | if isinstance(preset, ParallelPreset): 47 | return ParallelEnvExperiment 48 | return SingleEnvExperiment 49 | -------------------------------------------------------------------------------- /all/experiments/watch.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | 4 | import torch 5 | 6 | 7 | def watch(agent, env, fps=60, n_episodes=sys.maxsize): 8 | action = None 9 | returns = 0 10 | env.reset() 11 | 12 | for _ in range(n_episodes): 13 | env.render() 14 | action = agent.act(env.state) 15 | env.step(action) 16 | returns += env.state.reward 17 | 18 | if env.state.done: 19 | print("returns:", returns) 20 | env.reset() 21 | returns = 0 22 | 23 | time.sleep(1 / fps) 24 | 25 | 26 | def load_and_watch(filename, env, fps=60, n_episodes=sys.maxsize): 27 | agent = torch.load(filename).test_agent() 28 | watch(agent, env, fps=fps, n_episodes=n_episodes) 29 | -------------------------------------------------------------------------------- /all/experiments/watch_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import mock 3 | 4 | import torch 5 | 6 | from all.environments import GymEnvironment 7 | from all.experiments.watch import load_and_watch 8 | 9 | 10 | class MockAgent: 11 | def act(self): 12 | # sample from cartpole action space 13 | return torch.randint(0, 2, []) 14 | 15 | 16 | class MockPreset: 17 | def __init__(self, filename): 18 | self.filename = filename 19 | 20 | def test_agent(self): 21 | return MockAgent 22 | 23 | 24 | class WatchTest(unittest.TestCase): 25 | @mock.patch("torch.load", lambda filename: MockPreset(filename)) 26 | @mock.patch("time.sleep", mock.MagicMock()) 27 | @mock.patch("sys.stdout", mock.MagicMock()) 28 | def test_load_and_watch(self): 29 | env = mock.MagicMock(GymEnvironment("CartPole-v0", render_mode="rgb_array")) 30 | load_and_watch("file.name", env, n_episodes=3) 31 | self.assertEqual(env.reset.call_count, 4) 32 | 33 | 34 | if __name__ == "__main__": 35 | unittest.main() 36 | -------------------------------------------------------------------------------- /all/logging/__init__.py: -------------------------------------------------------------------------------- 1 | from ._logger import Logger 2 | from .dummy import DummyLogger 3 | from .experiment import ExperimentLogger 4 | 5 | __all__ = ["Logger", "DummyLogger", "ExperimentLogger"] 6 | -------------------------------------------------------------------------------- /all/logging/_logger.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class Logger(ABC): 5 | log_dir = "runs" 6 | 7 | @abstractmethod 8 | def add_summary(self, name, mean, std, 
step="frame"): 9 | """ 10 | Log a summary statistic. 11 | 12 | Args: 13 | name (str): The tag to associate with the summary statistic 14 | mean (float): The mean of the statistic at the current step 15 | std (float): The standard deviation of the statistic at the current step 16 | step (str, optional): Which step to use (e.g., "frame" or "episode") 17 | """ 18 | 19 | @abstractmethod 20 | def add_loss(self, name, value, step="frame"): 21 | """ 22 | Log the given loss metric at the current step. 23 | 24 | Args: 25 | name (str): The tag to associate with the loss 26 | value (number): The value of the loss at the current step 27 | step (str, optional): Which step to use (e.g., "frame" or "episode") 28 | """ 29 | 30 | @abstractmethod 31 | def add_eval(self, name, value, step="frame"): 32 | """ 33 | Log the given evaluation metric at the current step. 34 | 35 | Args: 36 | name (str): The tag to associate with the loss 37 | value (number): The evaluation metric at the current step 38 | step (str, optional): Which step to use (e.g., "frame" or "episode") 39 | """ 40 | 41 | @abstractmethod 42 | def add_info(self, name, value, step="frame"): 43 | """ 44 | Log the given informational metric at the current step. 45 | 46 | Args: 47 | name (str): The tag to associate with the loss 48 | value (number): The evaluation metric at the current step 49 | step (str, optional): Which step to use (e.g., "frame" or "episode") 50 | """ 51 | 52 | @abstractmethod 53 | def add_schedule(self, name, value, step="frame"): 54 | """ 55 | Log the current value of a hyperparameter according to some schedule. 56 | 57 | Args: 58 | name (str): The tag to associate with the hyperparameter schedule 59 | value (number): The value of the hyperparameter at the current step 60 | step (str, optional): Which step to use (e.g., "frame" or "episode") 61 | """ 62 | 63 | @abstractmethod 64 | def add_hparams(self, hparam_dict, metric_dict, step="frame"): 65 | """ 66 | Logs metrics for a given set of hyperparameters. 67 | Usually this should be called once at the end of a run in order to 68 | log the final results for hyperparameters, though it can be called 69 | multiple times throughout training. However, it should be called infrequently. 70 | 71 | Args: 72 | hparam_dict (dict): A dictionary of hyperparameters. 73 | Only parameters of type (int, float, str, bool, torch.Tensor) 74 | will be logged. 75 | metric_dict (dict): A dictionary of metrics to record. 76 | step (str, optional): Which step to use (e.g., "frame" or "episode") 77 | """ 78 | pass 79 | 80 | @abstractmethod 81 | def close(self): 82 | """ 83 | Close the logger and perform any necessary cleanup. 
84 | """ 85 | -------------------------------------------------------------------------------- /all/logging/dummy.py: -------------------------------------------------------------------------------- 1 | from ._logger import Logger 2 | 3 | 4 | class DummyLogger(Logger): 5 | """A default Logger object that performs no logging and has no side effects.""" 6 | 7 | def add_eval(self, name, value, step="frame"): 8 | pass 9 | 10 | def add_info(self, name, value, step="frame"): 11 | pass 12 | 13 | def add_loss(self, name, value, step="frame"): 14 | pass 15 | 16 | def add_schedule(self, name, value, step="frame"): 17 | pass 18 | 19 | def add_summary(self, name, values, step="frame"): 20 | pass 21 | 22 | def add_hparams(self, hparam_dict, metric_dict, step="frame"): 23 | pass 24 | 25 | def close(self): 26 | pass 27 | -------------------------------------------------------------------------------- /all/logging/experiment.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from datetime import datetime 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | from ._logger import Logger 10 | 11 | 12 | class ExperimentLogger(SummaryWriter, Logger): 13 | """ 14 | The default Logger object used by all.experiments.Experiment. 15 | Writes logs using tensorboard into the current logdir directory ('runs' by default), 16 | tagging the run with a combination of the agent name, the commit hash of the 17 | current git repo of the working directory (if any), and the current time. 18 | Also writes summary statistics into CSV files. 19 | Args: 20 | experiment (all.experiments.Experiment): The Experiment associated with the Logger object. 21 | agent_name (str): The name of the Agent the Experiment is being performed on 22 | env_name (str): The name of the environment the Experiment is being performed in 23 | verbose (bool, optional): Whether or not to log all data or only summary metrics. 
24 | """ 25 | 26 | def __init__(self, experiment, agent_name, env_name, verbose=True, logdir="runs"): 27 | self.env_name = env_name 28 | current_time = datetime.now().strftime("%Y-%m-%d_%H:%M:%S_%f") 29 | dir_name = f"{agent_name}_{env_name}_{current_time}" 30 | os.makedirs(os.path.join(logdir, dir_name)) 31 | self.log_dir = os.path.join(logdir, dir_name) 32 | self._experiment = experiment 33 | self._verbose = verbose 34 | super().__init__(log_dir=self.log_dir) 35 | 36 | def add_summary(self, name, values, step="frame"): 37 | aggregators = ["mean", "std", "max", "min"] 38 | metrics = { 39 | aggregator: getattr(np, aggregator)(values) for aggregator in aggregators 40 | } 41 | for aggregator, value in metrics.items(): 42 | super().add_scalar( 43 | f"summary/{name}/{aggregator}", value, self._get_step(step) 44 | ) 45 | 46 | # log summary statistics to file 47 | with open(os.path.join(self.log_dir, name + ".csv"), "a") as csvfile: 48 | csv.writer(csvfile).writerow([self._get_step(step), *metrics.values()]) 49 | 50 | def add_loss(self, name, value, step="frame"): 51 | self._add_scalar("loss/" + name, value, step) 52 | 53 | def add_eval(self, name, value, step="frame"): 54 | self._add_scalar("eval/" + name, value, step) 55 | 56 | def add_info(self, name, value, step="frame"): 57 | self._add_scalar("info/" + name, value, step) 58 | 59 | def add_schedule(self, name, value, step="frame"): 60 | self._add_scalar("schedule/" + name, value, step) 61 | 62 | def add_hparams(self, hparam_dict, metric_dict, step="frame"): 63 | allowed_types = (int, float, str, bool, torch.Tensor) 64 | hparams = {k: v for k, v in hparam_dict.items() if isinstance(v, allowed_types)} 65 | super().add_hparams( 66 | hparams, metric_dict, run_name=".", global_step=self._get_step("frame") 67 | ) 68 | 69 | def _add_scalar(self, name, value, step="frame"): 70 | if self._verbose: 71 | super().add_scalar(name, value, self._get_step(step)) 72 | 73 | def _get_step(self, _type): 74 | if _type == "frame": 75 | return self._experiment.frame 76 | if _type == "episode": 77 | return self._experiment.episode 78 | return _type 79 | 80 | def close(self): 81 | pass 82 | -------------------------------------------------------------------------------- /all/memory/__init__.py: -------------------------------------------------------------------------------- 1 | from .advantage import NStepAdvantageBuffer 2 | from .generalized_advantage import GeneralizedAdvantageBuffer 3 | from .replay_buffer import ( 4 | ExperienceReplayBuffer, 5 | NStepReplayBuffer, 6 | PrioritizedReplayBuffer, 7 | ReplayBuffer, 8 | ) 9 | 10 | __all__ = [ 11 | "ReplayBuffer", 12 | "ExperienceReplayBuffer", 13 | "PrioritizedReplayBuffer", 14 | "NStepAdvantageBuffer", 15 | "NStepReplayBuffer", 16 | "GeneralizedAdvantageBuffer", 17 | ] 18 | -------------------------------------------------------------------------------- /all/memory/generalized_advantage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from all.core import State 4 | from all.optim import Schedulable 5 | 6 | 7 | class GeneralizedAdvantageBuffer(Schedulable): 8 | def __init__( 9 | self, 10 | v, 11 | features, 12 | n_steps, 13 | n_envs, 14 | discount_factor=1, 15 | lam=1, 16 | compute_batch_size=256, 17 | ): 18 | self.v = v 19 | self.features = features 20 | self.n_steps = n_steps 21 | self.n_envs = n_envs 22 | self.gamma = discount_factor 23 | self.lam = lam 24 | self._batch_size = self.n_steps * self.n_envs 25 | self.compute_batch_size = compute_batch_size 26 
| self._states = [] 27 | self._actions = [] 28 | self._rewards = [] 29 | 30 | def __len__(self): 31 | return len(self._states) * self.n_envs 32 | 33 | def store(self, states, actions, rewards): 34 | if states is None: 35 | return 36 | if not self._states: 37 | self._states = [states] 38 | self._actions = [actions] 39 | self._rewards = [rewards] 40 | elif len(self._states) <= self.n_steps: 41 | self._states.append(states) 42 | self._actions.append(actions) 43 | self._rewards.append(rewards) 44 | else: 45 | raise Exception("Buffer length exceeded: " + str(self.n_steps)) 46 | 47 | def advantages(self, next_states): 48 | if len(self) < self._batch_size: 49 | raise Exception("Not enough states received!") 50 | 51 | self._states.append(next_states) 52 | states = State.array(self._states[0 : self.n_steps + 1]) 53 | actions = torch.cat(self._actions[: self.n_steps], dim=0) 54 | rewards = torch.stack(self._rewards[: self.n_steps]) 55 | 56 | _values = ( 57 | states.flatten() 58 | .batch_execute( 59 | self.compute_batch_size, 60 | lambda s: self.v.target(self.features.target(s)), 61 | ) 62 | .view(states.shape) 63 | ) 64 | values = _values[0 : self.n_steps] 65 | next_values = _values[1:] 66 | 67 | td_errors = rewards + self.gamma * next_values - values 68 | advantages = self._compute_advantages(td_errors) 69 | self._clear_buffers() 70 | 71 | return (states[0:-1].flatten(), actions, advantages.view(-1)) 72 | 73 | def _compute_advantages(self, td_errors): 74 | advantages = td_errors.clone() 75 | current_advantages = advantages[0] * 0 76 | 77 | # the final advantage is always 0 78 | advantages[-1] = current_advantages 79 | for i in range(self.n_steps): 80 | t = self.n_steps - 1 - i 81 | mask = self._states[t + 1].mask.float() 82 | current_advantages = ( 83 | td_errors[t] + self.gamma * self.lam * current_advantages * mask 84 | ) 85 | advantages[t] = current_advantages 86 | 87 | return advantages 88 | 89 | def _clear_buffers(self): 90 | self._states = [] 91 | self._actions = [] 92 | self._rewards = [] 93 | -------------------------------------------------------------------------------- /all/nn/nn_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import gymnasium 4 | import numpy as np 5 | import torch 6 | import torch_testing as tt 7 | 8 | from all import nn 9 | from all.core import StateArray 10 | 11 | 12 | class TestNN(unittest.TestCase): 13 | def setUp(self): 14 | torch.manual_seed(2) 15 | 16 | def test_dueling(self): 17 | torch.random.manual_seed(0) 18 | value_model = nn.Linear(2, 1) 19 | advantage_model = nn.Linear(2, 3) 20 | model = nn.Dueling(value_model, advantage_model) 21 | states = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) 22 | result = model(states).detach().numpy() 23 | np.testing.assert_array_almost_equal( 24 | result, 25 | np.array( 26 | [[-0.495295, 0.330573, 0.678836], [-1.253222, 1.509323, 2.502186]], 27 | dtype=np.float32, 28 | ), 29 | ) 30 | 31 | def test_linear0(self): 32 | model = nn.Linear0(3, 3) 33 | result = model(torch.tensor([[3.0, -2.0, 10]])) 34 | tt.assert_equal(result, torch.tensor([[0.0, 0.0, 0.0]])) 35 | 36 | def test_list(self): 37 | model = nn.Linear(2, 2) 38 | net = nn.RLNetwork(model, (2,)) 39 | features = torch.randn((4, 2)) 40 | done = torch.tensor([False, False, True, False]) 41 | out = net(StateArray(features, (4,), done=done)) 42 | tt.assert_almost_equal( 43 | out, 44 | torch.tensor( 45 | [ 46 | [0.0479387, -0.2268031], 47 | [0.2346841, 0.0743403], 48 | [0.0, 0.0], 49 | [0.2204496, 0.086818], 50 | ] 
51 | ), 52 | ) 53 | 54 | features = torch.randn(3, 2) 55 | done = torch.tensor([False, False, False]) 56 | out = net(StateArray(features, (3,), done=done)) 57 | tt.assert_almost_equal( 58 | out, 59 | torch.tensor( 60 | [ 61 | [0.4234636, 0.1039939], 62 | [0.6514298, 0.3354351], 63 | [-0.2543002, -0.2041451], 64 | ] 65 | ), 66 | ) 67 | 68 | def test_tanh_action_bound(self): 69 | space = gymnasium.spaces.Box(np.array([-1.0, 10.0]), np.array([1, 20])) 70 | model = nn.TanhActionBound(space) 71 | x = torch.tensor([[100.0, 100], [-100, -100], [-100, 100], [0, 0]]) 72 | tt.assert_almost_equal( 73 | model(x), torch.tensor([[1.0, 20], [-1, 10], [-1, 20], [0.0, 15]]) 74 | ) 75 | 76 | def test_categorical_dueling(self): 77 | n_actions = 2 78 | n_atoms = 3 79 | value_model = nn.Linear(2, n_atoms) 80 | advantage_model = nn.Linear(2, n_actions * n_atoms) 81 | model = nn.CategoricalDueling(value_model, advantage_model) 82 | x = torch.randn((2, 2)) 83 | out = model(x) 84 | self.assertEqual(out.shape, (2, 6)) 85 | tt.assert_almost_equal( 86 | out, 87 | torch.tensor( 88 | [ 89 | [0.014, -0.691, 0.251, -0.055, -0.419, -0.03], 90 | [0.057, -1.172, 0.568, -0.868, -0.482, -0.679], 91 | ] 92 | ), 93 | decimal=3, 94 | ) 95 | 96 | def assert_array_equal(self, actual, expected): 97 | for first, second in zip(actual, expected): 98 | if second is None: 99 | self.assertIsNone(first) 100 | else: 101 | tt.assert_almost_equal(first, second, decimal=3) 102 | 103 | 104 | if __name__ == "__main__": 105 | unittest.main() 106 | -------------------------------------------------------------------------------- /all/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .scheduler import LinearScheduler, Schedulable 2 | 3 | __all__ = ["Schedulable", "LinearScheduler"] 4 | -------------------------------------------------------------------------------- /all/optim/scheduler.py: -------------------------------------------------------------------------------- 1 | from all.logging import DummyLogger 2 | 3 | 4 | class Schedulable: 5 | """Allow "instance" descriptors to implement parameter scheduling.""" 6 | 7 | def __getattribute__(self, name): 8 | value = object.__getattribute__(self, name) 9 | if isinstance(value, Scheduler): 10 | value = value.__get__(self, self.__class__) 11 | return value 12 | 13 | 14 | class Scheduler: 15 | pass 16 | 17 | 18 | class LinearScheduler(Scheduler): 19 | def __init__( 20 | self, 21 | initial_value, 22 | final_value, 23 | decay_start, 24 | decay_end, 25 | name="variable", 26 | logger=DummyLogger(), 27 | ): 28 | self._initial_value = initial_value 29 | self._final_value = final_value 30 | self._decay_start = decay_start 31 | self._decay_end = decay_end 32 | self._i = -1 33 | self._name = name 34 | self._logger = logger 35 | 36 | def __get__(self, instance, owner=None): 37 | result = self._get_value() 38 | self._logger.add_schedule(self._name, result) 39 | return result 40 | 41 | def _get_value(self): 42 | self._i += 1 43 | if self._i < self._decay_start: 44 | return self._initial_value 45 | if self._i >= self._decay_end: 46 | return self._final_value 47 | alpha = (self._i - self._decay_start) / (self._decay_end - self._decay_start) 48 | return alpha * self._final_value + (1 - alpha) * self._initial_value 49 | -------------------------------------------------------------------------------- /all/optim/scheduler_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | 
from all.optim import LinearScheduler, Schedulable 6 | 7 | 8 | class Obj(Schedulable): 9 | def __init__(self): 10 | self.attr = 0 11 | 12 | 13 | class TestScheduler(unittest.TestCase): 14 | def test_linear_scheduler(self): 15 | obj = Obj() 16 | obj.attr = LinearScheduler(10, 0, 3, 13) 17 | expected = [10, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 0, 0] 18 | actual = [obj.attr for _ in expected] 19 | np.testing.assert_allclose(actual, expected) 20 | 21 | 22 | if __name__ == "__main__": 23 | unittest.main() 24 | -------------------------------------------------------------------------------- /all/policies/__init__.py: -------------------------------------------------------------------------------- 1 | from .deterministic import DeterministicPolicy 2 | from .gaussian import GaussianPolicy 3 | from .greedy import GreedyPolicy, ParallelGreedyPolicy 4 | from .soft_deterministic import SoftDeterministicPolicy 5 | from .softmax import SoftmaxPolicy 6 | 7 | __all__ = [ 8 | "GaussianPolicy", 9 | "GreedyPolicy", 10 | "ParallelGreedyPolicy", 11 | "SoftmaxPolicy", 12 | "DeterministicPolicy", 13 | "SoftDeterministicPolicy", 14 | ] 15 | -------------------------------------------------------------------------------- /all/policies/deterministic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from all.approximation import Approximation 4 | from all.nn import RLNetwork 5 | 6 | 7 | class DeterministicPolicy(Approximation): 8 | """ 9 | A DDPG-style deterministic policy. 10 | 11 | Args: 12 | model (torch.nn.Module): A PyTorch module representing the policy network. 13 | The input shape should be the same as the shape of the state space, 14 | and the output shape should be the same as the shape of the action space. 15 | optimizer (torch.optim.Optimizer): An optimizer initialized with the 16 | model parameters, e.g. SGD, Adam, RMSprop, etc. 17 | action_space (gymnasium.spaces.Box): The Box representing the action space.
18 | kwargs (optional): Any other arguments accepted by all.approximation.Approximation 19 | """ 20 | 21 | def __init__(self, model, optimizer=None, space=None, name="policy", **kwargs): 22 | model = DeterministicPolicyNetwork(model, space) 23 | super().__init__(model, optimizer, name=name, **kwargs) 24 | 25 | 26 | class DeterministicPolicyNetwork(RLNetwork): 27 | def __init__(self, model, space): 28 | super().__init__(model) 29 | self._action_dim = space.shape[0] 30 | self._tanh_scale = torch.tensor((space.high - space.low) / 2).to(self.device) 31 | self._tanh_mean = torch.tensor((space.high + space.low) / 2).to(self.device) 32 | 33 | def forward(self, state): 34 | return self._squash(super().forward(state)) 35 | 36 | def _squash(self, x): 37 | return torch.tanh(x) * self._tanh_scale + self._tanh_mean 38 | 39 | def to(self, device): 40 | self._tanh_mean = self._tanh_mean.to(device) 41 | self._tanh_scale = self._tanh_scale.to(device) 42 | return super().to(device) 43 | -------------------------------------------------------------------------------- /all/policies/deterministic_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import torch 5 | import torch_testing as tt 6 | from gymnasium.spaces import Box 7 | 8 | from all import nn 9 | from all.approximation import DummyCheckpointer, FixedTarget 10 | from all.core import State 11 | from all.policies import DeterministicPolicy 12 | 13 | STATE_DIM = 2 14 | ACTION_DIM = 3 15 | 16 | 17 | class TestDeterministic(unittest.TestCase): 18 | def setUp(self): 19 | torch.manual_seed(2) 20 | self.model = nn.Sequential(nn.Linear0(STATE_DIM, ACTION_DIM)) 21 | self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=0.01) 22 | self.space = Box(np.array([-1, -1, -1]), np.array([1, 1, 1])) 23 | self.policy = DeterministicPolicy( 24 | self.model, self.optimizer, self.space, checkpointer=DummyCheckpointer() 25 | ) 26 | 27 | def test_output_shape(self): 28 | state = State(torch.randn(1, STATE_DIM)) 29 | action = self.policy(state) 30 | self.assertEqual(action.shape, (1, ACTION_DIM)) 31 | state = State(torch.randn(5, STATE_DIM)) 32 | action = self.policy(state) 33 | self.assertEqual(action.shape, (5, ACTION_DIM)) 34 | 35 | def test_step_one(self): 36 | state = State(torch.randn(1, STATE_DIM)) 37 | self.policy(state) 38 | self.policy.step() 39 | 40 | def test_converge(self): 41 | state = State(torch.randn(1, STATE_DIM)) 42 | target = torch.tensor([0.25, 0.5, -0.5]) 43 | 44 | for _ in range(0, 200): 45 | action = self.policy(state) 46 | loss = ((target - action) ** 2).mean() 47 | loss.backward() 48 | self.policy.step() 49 | 50 | self.assertLess(loss, 0.001) 51 | 52 | def test_target(self): 53 | self.policy = DeterministicPolicy( 54 | self.model, self.optimizer, self.space, target=FixedTarget(3) 55 | ) 56 | state = State(torch.ones(1, STATE_DIM)) 57 | 58 | # run update step, make sure target network doesn't change 59 | self.policy(state).sum().backward() 60 | self.policy.step() 61 | tt.assert_equal(self.policy.target(state), torch.zeros(1, ACTION_DIM)) 62 | 63 | # again... 
64 | self.policy(state).sum().backward() 65 | self.policy.step() 66 | tt.assert_equal(self.policy.target(state), torch.zeros(1, ACTION_DIM)) 67 | 68 | # third time, target should be updated 69 | self.policy(state).sum().backward() 70 | self.policy.step() 71 | tt.assert_allclose( 72 | self.policy.target(state), 73 | torch.tensor([[-0.574482, -0.574482, -0.574482]]), 74 | atol=1e-4, 75 | ) 76 | 77 | 78 | if __name__ == "__main__": 79 | unittest.main() 80 | -------------------------------------------------------------------------------- /all/policies/gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions.independent import Independent 3 | from torch.distributions.normal import Normal 4 | 5 | from all.approximation import Approximation 6 | from all.nn import RLNetwork 7 | 8 | 9 | class GaussianPolicy(Approximation): 10 | """ 11 | A Gaussian stochastic policy. 12 | 13 | This policy will choose actions from a distribution represented by a spherical Gaussian. 14 | The first n outputs of the model are the mean of the distribution and the last n outputs are the log variance. 15 | The output will be centered and scaled to the size of the given space, but the output will not be clipped. 16 | For example, for an output range of [-1, 1], the center is 0 and the scale is 1. 17 | 18 | Args: 19 | model (torch.nn.Module): A PyTorch module representing the policy network. 20 | The input shape should be the same as the shape of the state (or feature) space, 21 | and the output shape should be double the size of the action space. 22 | The first n outputs will be the unscaled mean of the action for each dimension, 23 | and the last n outputs will be the logarithm of the variance. 24 | optimizer (torch.optim.Optimizer): An optimizer initialized with the 25 | model parameters, e.g. SGD, Adam, RMSprop, etc. 26 | action_space (gymnasium.spaces.Box): The Box representing the action space.
27 | kwargs (optional): Any other arguments accepted by all.approximation.Approximation 28 | """ 29 | 30 | def __init__(self, model, optimizer=None, space=None, name="policy", **kwargs): 31 | super().__init__( 32 | GaussianPolicyNetwork(model, space), optimizer, name=name, **kwargs 33 | ) 34 | 35 | 36 | class GaussianPolicyNetwork(RLNetwork): 37 | def __init__(self, model, space): 38 | super().__init__(model) 39 | self._center = torch.tensor((space.high + space.low) / 2).to(self.device) 40 | self._scale = torch.tensor((space.high - space.low) / 2).to(self.device) 41 | 42 | def forward(self, state): 43 | outputs = super().forward(state) 44 | action_dim = outputs.shape[-1] // 2 45 | means = outputs[..., 0:action_dim] 46 | logvars = outputs[..., action_dim:] 47 | std = (0.5 * logvars).exp_() 48 | return Independent(Normal(means + self._center, std * self._scale), 1) 49 | 50 | def to(self, device): 51 | self._center = self._center.to(device) 52 | self._scale = self._scale.to(device) 53 | return super().to(device) 54 | -------------------------------------------------------------------------------- /all/policies/gaussian_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import torch 5 | import torch_testing as tt 6 | from gymnasium.spaces import Box 7 | from torch import nn 8 | 9 | from all.approximation import DummyCheckpointer 10 | from all.core import State 11 | from all.policies import GaussianPolicy 12 | 13 | STATE_DIM = 2 14 | ACTION_DIM = 3 15 | 16 | 17 | class TestGaussian(unittest.TestCase): 18 | def setUp(self): 19 | torch.manual_seed(2) 20 | self.space = Box(np.array([-1, -1, -1]), np.array([1, 1, 1])) 21 | self.model = nn.Sequential(nn.Linear(STATE_DIM, ACTION_DIM * 2)) 22 | optimizer = torch.optim.RMSprop(self.model.parameters(), lr=0.01) 23 | self.policy = GaussianPolicy( 24 | self.model, optimizer, self.space, checkpointer=DummyCheckpointer() 25 | ) 26 | 27 | def test_output_shape(self): 28 | state = State(torch.randn(1, STATE_DIM)) 29 | action = self.policy(state).sample() 30 | self.assertEqual(action.shape, (1, ACTION_DIM)) 31 | state = State(torch.randn(5, STATE_DIM)) 32 | action = self.policy(state).sample() 33 | self.assertEqual(action.shape, (5, ACTION_DIM)) 34 | 35 | def test_reinforce_one(self): 36 | state = State(torch.randn(1, STATE_DIM)) 37 | dist = self.policy(state) 38 | action = dist.sample() 39 | log_prob1 = dist.log_prob(action) 40 | loss = -log_prob1.mean() 41 | self.policy.reinforce(loss) 42 | 43 | dist = self.policy(state) 44 | log_prob2 = dist.log_prob(action) 45 | 46 | self.assertGreater(log_prob2.item(), log_prob1.item()) 47 | 48 | def test_converge(self): 49 | state = State(torch.randn(1, STATE_DIM)) 50 | target = torch.tensor([1.0, 2.0, -1.0]) 51 | 52 | for _ in range(0, 1000): 53 | dist = self.policy(state) 54 | action = dist.sample() 55 | log_prob = dist.log_prob(action) 56 | error = ((target - action) ** 2).mean() 57 | loss = (error * log_prob).mean() 58 | self.policy.reinforce(loss) 59 | 60 | self.assertTrue(error < 1) 61 | 62 | def test_eval(self): 63 | state = State(torch.randn(1, STATE_DIM)) 64 | dist = self.policy.no_grad(state) 65 | tt.assert_almost_equal( 66 | dist.mean, torch.tensor([[-0.237, 0.497, -0.058]]), decimal=3 67 | ) 68 | tt.assert_almost_equal(dist.entropy(), torch.tensor([4.254]), decimal=3) 69 | best = self.policy.eval(state).sample() 70 | tt.assert_almost_equal(best, torch.tensor([[-0.888, -0.887, 0.404]]), decimal=3) 71 | 72 | 73 | if 
__name__ == "__main__": 74 | unittest.main() 75 | -------------------------------------------------------------------------------- /all/policies/greedy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from all.optim import Schedulable 5 | 6 | 7 | class GreedyPolicy(Schedulable): 8 | """ 9 | An "epsilon-greedy" action selection policy for discrete action spaces. 10 | 11 | This policy will usually choose the optimal action according to an approximation 12 | of the action value function (the "q-function"), but with probability epsilon will 13 | choose a random action instead. GreedyPolicy is a Schedulable, meaning that 14 | epsilon can be varied over time by passing a Scheduler object. 15 | 16 | Args: 17 | q (all.approximation.QNetwork): The action-value or "q-function" 18 | num_actions (int): The number of available actions. 19 | epsilon (float, optional): The probability of selecting a random action. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | q, 25 | num_actions, 26 | epsilon=0.0, 27 | ): 28 | self.q = q 29 | self.num_actions = num_actions 30 | self.epsilon = epsilon 31 | 32 | def __call__(self, state): 33 | if np.random.rand() < self.epsilon: 34 | return np.random.randint(0, self.num_actions) 35 | return torch.argmax(self.q(state)).item() 36 | 37 | def no_grad(self, state): 38 | if np.random.rand() < self.epsilon: 39 | return np.random.randint(0, self.num_actions) 40 | return torch.argmax(self.q.no_grad(state)).item() 41 | 42 | def eval(self, state): 43 | if np.random.rand() < self.epsilon: 44 | return np.random.randint(0, self.num_actions) 45 | return torch.argmax(self.q.eval(state)).item() 46 | 47 | 48 | class ParallelGreedyPolicy(Schedulable): 49 | """ 50 | A parallel version of the "epsilon-greedy" action selection policy for discrete action spaces. 51 | 52 | This policy will usually choose the optimal action according to an approximation 53 | of the action value function (the "q-function"), but with probability epsilon will 54 | choose a random action instead. GreedyPolicy is a Schedulable, meaning that 55 | epsilon can be varied over time by passing a Scheduler object. 56 | 57 | Args: 58 | q (all.approximation.QNetwork): The action-value or "q-function" 59 | num_actions (int): The number of available actions. 60 | epsilon (float, optional): The probability of selecting a random action. 
61 | """ 62 | 63 | def __init__( 64 | self, 65 | q, 66 | num_actions, 67 | epsilon=0.0, 68 | ): 69 | self.q = q 70 | self.num_actions = num_actions 71 | self.epsilon = epsilon 72 | 73 | def __call__(self, state): 74 | return self._choose_action(self.q(state)) 75 | 76 | def no_grad(self, state): 77 | return self._choose_action(self.q.no_grad(state)) 78 | 79 | def eval(self, state): 80 | return self._choose_action(self.q.eval(state)) 81 | 82 | def _choose_action(self, action_values): 83 | best_actions = torch.argmax(action_values, dim=-1) 84 | random_actions = torch.randint( 85 | 0, self.num_actions, best_actions.shape, device=best_actions.device 86 | ) 87 | choices = ( 88 | torch.rand(best_actions.shape, device=best_actions.device) < self.epsilon 89 | ).int() 90 | return choices * random_actions + (1 - choices) * best_actions 91 | -------------------------------------------------------------------------------- /all/policies/soft_deterministic_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import torch 5 | import torch_testing as tt 6 | from gymnasium.spaces import Box 7 | 8 | from all import nn 9 | from all.approximation import DummyCheckpointer 10 | from all.core import State 11 | from all.policies import SoftDeterministicPolicy 12 | 13 | STATE_DIM = 2 14 | ACTION_DIM = 3 15 | 16 | 17 | class TestSoftDeterministic(unittest.TestCase): 18 | def setUp(self): 19 | torch.manual_seed(2) 20 | self.model = nn.Sequential(nn.Linear0(STATE_DIM, ACTION_DIM * 2)) 21 | self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=0.01) 22 | self.space = Box(np.array([-1, -1, -1]), np.array([1, 1, 1])) 23 | self.policy = SoftDeterministicPolicy( 24 | self.model, self.optimizer, self.space, checkpointer=DummyCheckpointer() 25 | ) 26 | 27 | def test_output_shape(self): 28 | state = State(torch.randn(1, STATE_DIM)) 29 | action, log_prob = self.policy(state) 30 | self.assertEqual(action.shape, (1, ACTION_DIM)) 31 | self.assertEqual(log_prob.shape, torch.Size([1])) 32 | 33 | state = State(torch.randn(5, STATE_DIM)) 34 | action, log_prob = self.policy(state) 35 | self.assertEqual(action.shape, (5, ACTION_DIM)) 36 | self.assertEqual(log_prob.shape, torch.Size([5])) 37 | 38 | def test_step_one(self): 39 | state = State(torch.randn(1, STATE_DIM)) 40 | self.policy(state) 41 | self.policy.step() 42 | 43 | def test_converge(self): 44 | state = State(torch.randn(1, STATE_DIM)) 45 | target = torch.tensor([0.25, 0.5, -0.5]) 46 | 47 | for _ in range(0, 200): 48 | action, _ = self.policy(state) 49 | loss = ((target - action) ** 2).mean() 50 | loss.backward() 51 | self.policy.step() 52 | 53 | self.assertLess(loss, 0.2) 54 | 55 | def test_scaling(self): 56 | torch.manual_seed(0) 57 | state = State(torch.randn(1, STATE_DIM)) 58 | policy1 = SoftDeterministicPolicy( 59 | self.model, 60 | self.optimizer, 61 | Box(np.array([-1.0, -1.0, -1.0]), np.array([1.0, 1.0, 1.0])), 62 | ) 63 | action1, log_prob1 = policy1(state) 64 | 65 | # reset seed and sample same thing, but with different scaling 66 | torch.manual_seed(0) 67 | state = State(torch.randn(1, STATE_DIM)) 68 | policy2 = SoftDeterministicPolicy( 69 | self.model, 70 | self.optimizer, 71 | Box(np.array([-2.0, -1.0, -1.0]), np.array([2.0, 1.0, 1.0])), 72 | ) 73 | action2, log_prob2 = policy2(state) 74 | 75 | # check scaling was correct 76 | tt.assert_allclose(action1 * torch.tensor([2, 1, 1]), action2) 77 | tt.assert_allclose(log_prob1 - np.log(2), log_prob2) 78 | 79 | 80 | if 
__name__ == "__main__": 81 | unittest.main() 82 | -------------------------------------------------------------------------------- /all/policies/softmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional 3 | 4 | from all.approximation import Approximation 5 | from all.nn import RLNetwork 6 | 7 | 8 | class SoftmaxPolicy(Approximation): 9 | """ 10 | A softmax (or Boltzmann) stochastic policy for discrete actions. 11 | 12 | Args: 13 | model (torch.nn.Module): A Pytorch module representing the policy network. 14 | The input shape should be the same as the shape of the state (or feature) space, 15 | and the output should be a vector the size of the action set. 16 | optimizer (torch.optim.Optimizer): A optimizer initialized with the 17 | model parameters, e.g. SGD, Adam, RMSprop, etc. 18 | kwargs (optional): Any other arguments accepted by all.approximation.Approximation 19 | """ 20 | 21 | def __init__(self, model, optimizer=None, name="policy", **kwargs): 22 | model = SoftmaxPolicyNetwork(model) 23 | super().__init__(model, optimizer, name=name, **kwargs) 24 | 25 | 26 | class SoftmaxPolicyNetwork(RLNetwork): 27 | def __init__(self, model): 28 | super().__init__(model) 29 | 30 | def forward(self, state): 31 | outputs = super().forward(state) 32 | probs = functional.softmax(outputs, dim=-1) 33 | return torch.distributions.Categorical(probs) 34 | -------------------------------------------------------------------------------- /all/policies/softmax_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import torch_testing as tt 5 | from torch import nn 6 | 7 | from all.core import State 8 | from all.policies import SoftmaxPolicy 9 | 10 | STATE_DIM = 2 11 | ACTIONS = 3 12 | 13 | 14 | class TestSoftmax(unittest.TestCase): 15 | def setUp(self): 16 | torch.manual_seed(2) 17 | self.model = nn.Sequential(nn.Linear(STATE_DIM, ACTIONS)) 18 | optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1) 19 | self.policy = SoftmaxPolicy(self.model, optimizer) 20 | 21 | def test_run(self): 22 | state1 = State(torch.randn(1, STATE_DIM)) 23 | dist1 = self.policy(state1) 24 | action1 = dist1.sample() 25 | log_prob1 = dist1.log_prob(action1) 26 | self.assertEqual(action1.item(), 2) 27 | 28 | state2 = State(torch.randn(1, STATE_DIM)) 29 | dist2 = self.policy(state2) 30 | action2 = dist2.sample() 31 | log_prob2 = dist2.log_prob(action2) 32 | self.assertEqual(action2.item(), 2) 33 | 34 | loss = -(torch.tensor([-1, 1000000]) * torch.cat((log_prob1, log_prob2))).mean() 35 | self.policy.reinforce(loss) 36 | 37 | state3 = State(torch.randn(1, STATE_DIM)) 38 | dist3 = self.policy(state3) 39 | action3 = dist3.sample() 40 | self.assertEqual(action3.item(), 0) 41 | 42 | def test_multi_action(self): 43 | states = State(torch.randn(3, STATE_DIM)) 44 | actions = self.policy(states).sample() 45 | tt.assert_equal(actions, torch.tensor([2, 0, 0])) 46 | 47 | def test_list(self): 48 | torch.manual_seed(1) 49 | states = State(torch.randn(3, STATE_DIM), torch.tensor([1, 0, 1])) 50 | dist = self.policy(states) 51 | actions = dist.sample() 52 | log_probs = dist.log_prob(actions) 53 | tt.assert_equal(actions, torch.tensor([0, 0, 2])) 54 | loss = -(torch.tensor([[1, 2, 3]]) * log_probs).mean() 55 | self.policy.reinforce(loss) 56 | 57 | def test_reinforce(self): 58 | def loss(log_probs): 59 | return -log_probs.mean() 60 | 61 | states = State(torch.randn(3, STATE_DIM), 
torch.tensor([1, 1, 1])) 62 | actions = self.policy.no_grad(states).sample() 63 | 64 | # notice the values increase with each successive reinforce 65 | log_probs = self.policy(states).log_prob(actions) 66 | tt.assert_almost_equal( 67 | log_probs, torch.tensor([-0.84, -1.325, -0.757]), decimal=3 68 | ) 69 | self.policy.reinforce(loss(log_probs)) 70 | log_probs = self.policy(states).log_prob(actions) 71 | tt.assert_almost_equal( 72 | log_probs, torch.tensor([-0.855, -1.278, -0.726]), decimal=3 73 | ) 74 | self.policy.reinforce(loss(log_probs)) 75 | log_probs = self.policy(states).log_prob(actions) 76 | tt.assert_almost_equal( 77 | log_probs, torch.tensor([-0.871, -1.234, -0.698]), decimal=3 78 | ) 79 | 80 | def test_eval(self): 81 | states = State(torch.randn(3, STATE_DIM), torch.tensor([1, 1, 1])) 82 | dist = self.policy.no_grad(states) 83 | tt.assert_almost_equal( 84 | dist.probs, 85 | torch.tensor( 86 | [[0.352, 0.216, 0.432], [0.266, 0.196, 0.538], [0.469, 0.227, 0.304]] 87 | ), 88 | decimal=3, 89 | ) 90 | best = self.policy.eval(states).sample() 91 | tt.assert_equal(best, torch.tensor([2, 0, 0])) 92 | 93 | 94 | if __name__ == "__main__": 95 | unittest.main() 96 | -------------------------------------------------------------------------------- /all/presets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/all/presets/.DS_Store -------------------------------------------------------------------------------- /all/presets/__init__.py: -------------------------------------------------------------------------------- 1 | from all.presets import atari, classic_control, continuous 2 | 3 | from .builder import ParallelPresetBuilder, PresetBuilder 4 | from .independent_multiagent import IndependentMultiagentPreset 5 | from .preset import ParallelPreset, Preset 6 | 7 | __all__ = [ 8 | "Preset", 9 | "ParallelPreset", 10 | "PresetBuilder", 11 | "ParallelPresetBuilder", 12 | "atari", 13 | "classic_control", 14 | "continuous", 15 | "IndependentMultiagentPreset", 16 | ] 17 | -------------------------------------------------------------------------------- /all/presets/atari/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/all/presets/atari/.DS_Store -------------------------------------------------------------------------------- /all/presets/atari/__init__.py: -------------------------------------------------------------------------------- 1 | from .a2c import a2c 2 | from .c51 import c51 3 | from .ddqn import ddqn 4 | from .dqn import dqn 5 | from .ppo import ppo 6 | from .rainbow import rainbow 7 | from .vac import vac 8 | from .vpg import vpg 9 | from .vqn import vqn 10 | from .vsarsa import vsarsa 11 | 12 | __all__ = [ 13 | "a2c", 14 | "c51", 15 | "ddqn", 16 | "dqn", 17 | "ppo", 18 | "rainbow", 19 | "vac", 20 | "vpg", 21 | "vqn", 22 | "vsarsa", 23 | ] 24 | -------------------------------------------------------------------------------- /all/presets/atari/models/__init__.py: -------------------------------------------------------------------------------- 1 | from all import nn 2 | 3 | 4 | def nature_dqn(env, frames=4): 5 | return nn.Sequential( 6 | nn.Scale(1 / 255), 7 | nn.Conv2d(frames, 32, 8, stride=4), 8 | nn.ReLU(), 9 | nn.Conv2d(32, 64, 4, stride=2), 10 | nn.ReLU(), 11 | nn.Conv2d(64, 64, 3, 
stride=1), 12 | nn.ReLU(), 13 | nn.Flatten(), 14 | nn.Linear(3136, 512), 15 | nn.ReLU(), 16 | nn.Linear0(512, env.action_space.n), 17 | ) 18 | 19 | 20 | def nature_ddqn(env, frames=4): 21 | return nn.Sequential( 22 | nn.Scale(1 / 255), 23 | nn.Conv2d(frames, 32, 8, stride=4), 24 | nn.ReLU(), 25 | nn.Conv2d(32, 64, 4, stride=2), 26 | nn.ReLU(), 27 | nn.Conv2d(64, 64, 3, stride=1), 28 | nn.ReLU(), 29 | nn.Flatten(), 30 | nn.Dueling( 31 | nn.Sequential(nn.Linear(3136, 512), nn.ReLU(), nn.Linear0(512, 1)), 32 | nn.Sequential( 33 | nn.Linear(3136, 512), nn.ReLU(), nn.Linear0(512, env.action_space.n) 34 | ), 35 | ), 36 | ) 37 | 38 | 39 | def nature_features(frames=4): 40 | return nn.Sequential( 41 | nn.Scale(1 / 255), 42 | nn.Conv2d(frames, 32, 8, stride=4), 43 | nn.ReLU(), 44 | nn.Conv2d(32, 64, 4, stride=2), 45 | nn.ReLU(), 46 | nn.Conv2d(64, 64, 3, stride=1), 47 | nn.ReLU(), 48 | nn.Flatten(), 49 | nn.Linear(3136, 512), 50 | nn.ReLU(), 51 | ) 52 | 53 | 54 | def nature_value_head(): 55 | return nn.Linear(512, 1) 56 | 57 | 58 | def nature_policy_head(env): 59 | return nn.Linear0(512, env.action_space.n) 60 | 61 | 62 | def nature_c51(env, frames=4, atoms=51): 63 | return nn.Sequential( 64 | nn.Scale(1 / 255), 65 | nn.Conv2d(frames, 32, 8, stride=4), 66 | nn.ReLU(), 67 | nn.Conv2d(32, 64, 4, stride=2), 68 | nn.ReLU(), 69 | nn.Conv2d(64, 64, 3, stride=1), 70 | nn.ReLU(), 71 | nn.Flatten(), 72 | nn.Linear(3136, 512), 73 | nn.ReLU(), 74 | nn.Linear0(512, env.action_space.n * atoms), 75 | ) 76 | 77 | 78 | def nature_rainbow(env, frames=4, hidden=512, atoms=51, sigma=0.5): 79 | return nn.Sequential( 80 | nn.Scale(1 / 255), 81 | nn.Conv2d(frames, 32, 8, stride=4), 82 | nn.ReLU(), 83 | nn.Conv2d(32, 64, 4, stride=2), 84 | nn.ReLU(), 85 | nn.Conv2d(64, 64, 3, stride=1), 86 | nn.ReLU(), 87 | nn.Flatten(), 88 | nn.CategoricalDueling( 89 | nn.Sequential( 90 | nn.NoisyFactorizedLinear(3136, hidden, sigma_init=sigma), 91 | nn.ReLU(), 92 | nn.NoisyFactorizedLinear(hidden, atoms, init_scale=0, sigma_init=sigma), 93 | ), 94 | nn.Sequential( 95 | nn.NoisyFactorizedLinear(3136, hidden, sigma_init=sigma), 96 | nn.ReLU(), 97 | nn.NoisyFactorizedLinear( 98 | hidden, env.action_space.n * atoms, init_scale=0, sigma_init=sigma 99 | ), 100 | ), 101 | ), 102 | ) 103 | -------------------------------------------------------------------------------- /all/presets/atari_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import torch 5 | 6 | from all.environments import AtariEnvironment, DuplicateEnvironment 7 | from all.logging import DummyLogger 8 | from all.presets import ParallelPreset 9 | from all.presets.atari import a2c, c51, ddqn, dqn, ppo, rainbow, vac, vpg, vqn, vsarsa 10 | 11 | 12 | class TestAtariPresets(unittest.TestCase): 13 | def setUp(self): 14 | self.env = AtariEnvironment("Breakout") 15 | self.env.reset() 16 | self.parallel_env = DuplicateEnvironment( 17 | [AtariEnvironment("Breakout"), AtariEnvironment("Breakout")] 18 | ) 19 | self.parallel_env.reset() 20 | 21 | def tearDown(self): 22 | if os.path.exists("test_preset.pt"): 23 | os.remove("test_preset.pt") 24 | 25 | def test_a2c(self): 26 | self.validate_preset(a2c) 27 | 28 | def test_c51(self): 29 | self.validate_preset(c51) 30 | 31 | def test_ddqn(self): 32 | self.validate_preset(ddqn) 33 | 34 | def test_dqn(self): 35 | self.validate_preset(dqn) 36 | 37 | def test_ppo(self): 38 | self.validate_preset(ppo) 39 | 40 | def test_rainbow(self): 41 | self.validate_preset(rainbow) 42 | 
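# --- Illustrative aside (not part of atari_test.py): a minimal, hypothetical sketch of the
# --- shape contract of the nature_dqn model defined in all/presets/atari/models above.
# --- Assumptions: the raw Sequential accepts a plain tensor, and a stand-in object exposing
# --- only an action_space attribute is enough for the model constructor.
import torch
import gymnasium
from types import SimpleNamespace
from all.presets.atari.models import nature_dqn

fake_env = SimpleNamespace(action_space=gymnasium.spaces.Discrete(4))  # stand-in, not a real AtariEnvironment
model = nature_dqn(fake_env)
frames = torch.zeros(2, 4, 84, 84)  # batch of 2 observations, each 4 stacked 84x84 frames
q_values = model(frames)            # conv stack flattens to 64 * 7 * 7 = 3136 features, matching nn.Linear(3136, 512)
assert q_values.shape == (2, 4)     # one action value per available action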
43 | def test_vac(self): 44 | self.validate_preset(vac) 45 | 46 | def test_vpg(self): 47 | self.validate_preset(vpg) 48 | 49 | def test_vsarsa(self): 50 | self.validate_preset(vsarsa) 51 | 52 | def test_vqn(self): 53 | self.validate_preset(vqn) 54 | 55 | def validate_preset(self, builder): 56 | preset = builder.device("cpu").env(self.env).build() 57 | if isinstance(preset, ParallelPreset): 58 | return self.validate_parallel_preset(preset) 59 | return self.validate_standard_preset(preset) 60 | 61 | def validate_standard_preset(self, preset): 62 | # train agent 63 | agent = preset.agent(logger=DummyLogger(), train_steps=100000) 64 | agent.act(self.env.state) 65 | # test agent 66 | test_agent = preset.test_agent() 67 | test_agent.act(self.env.state) 68 | # test save/load 69 | preset.save("test_preset.pt") 70 | preset = torch.load("test_preset.pt") 71 | test_agent = preset.test_agent() 72 | test_agent.act(self.env.state) 73 | 74 | def validate_parallel_preset(self, preset): 75 | # train agent 76 | agent = preset.agent(logger=DummyLogger(), train_steps=100000) 77 | agent.act(self.parallel_env.state_array) 78 | # test agent 79 | test_agent = preset.test_agent() 80 | test_agent.act(self.env.state) 81 | # parallel test_agent 82 | parallel_test_agent = preset.test_agent() 83 | parallel_test_agent.act(self.parallel_env.state_array) 84 | # test save/load 85 | preset.save("test_preset.pt") 86 | preset = torch.load("test_preset.pt") 87 | test_agent = preset.test_agent() 88 | test_agent.act(self.env.state) 89 | 90 | 91 | if __name__ == "__main__": 92 | unittest.main() 93 | -------------------------------------------------------------------------------- /all/presets/builder.py: -------------------------------------------------------------------------------- 1 | class PresetBuilder: 2 | def __init__( 3 | self, 4 | default_name, 5 | default_hyperparameters, 6 | constructor, 7 | device="cuda", 8 | env=None, 9 | hyperparameters=None, 10 | name=None, 11 | ): 12 | self.default_name = default_name 13 | self.default_hyperparameters = default_hyperparameters 14 | self.constructor = constructor 15 | self._device = device 16 | self._env = env 17 | self._hyperparameters = self._merge_hyperparameters( 18 | default_hyperparameters, hyperparameters 19 | ) 20 | self._name = name or default_name 21 | 22 | def __call__(self, **kwargs): 23 | return self._preset_builder(**kwargs) 24 | 25 | def device(self, device): 26 | return self._preset_builder(device=device) 27 | 28 | def env(self, env): 29 | return self._preset_builder(env=env) 30 | 31 | def hyperparameters(self, **hyperparameters): 32 | return self._preset_builder( 33 | hyperparameters=self._merge_hyperparameters( 34 | self._hyperparameters, hyperparameters 35 | ) 36 | ) 37 | 38 | def name(self, name): 39 | return self._preset_builder(name=name) 40 | 41 | def build(self): 42 | if not self._env: 43 | raise Exception("Env is required") 44 | 45 | return self.constructor( 46 | self._env, device=self._device, name=self._name, **self._hyperparameters 47 | ) 48 | 49 | def _merge_hyperparameters(self, h1, h2): 50 | if h2 is None: 51 | return h1 52 | for key in h2.keys(): 53 | if key not in h1: 54 | raise KeyError("Invalid hyperparameter: {}".format(key)) 55 | return {**h1, **h2} 56 | 57 | def _preset_builder(self, **kwargs): 58 | old_kwargs = { 59 | "device": self._device, 60 | "env": self._env, 61 | "hyperparameters": self._hyperparameters, 62 | "name": self._name, 63 | } 64 | return PresetBuilder( 65 | self.default_name, 66 | self.default_hyperparameters, 67 |
self.constructor, 68 | **{**old_kwargs, **kwargs} 69 | ) 70 | 71 | 72 | class ParallelPresetBuilder(PresetBuilder): 73 | def __init__( 74 | self, 75 | default_name, 76 | default_hyperparameters, 77 | constructor, 78 | device="cuda", 79 | env=None, 80 | hyperparameters=None, 81 | name=None, 82 | ): 83 | if "n_envs" not in default_hyperparameters: 84 | raise Exception("ParallelPreset hyperparameters must include n_envs") 85 | super().__init__( 86 | default_name, 87 | default_hyperparameters, 88 | constructor, 89 | device=device, 90 | env=env, 91 | hyperparameters=hyperparameters, 92 | name=name, 93 | ) 94 | 95 | def build(self): 96 | return super().build() 97 | -------------------------------------------------------------------------------- /all/presets/builder_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import Mock 3 | 4 | from all.presets import PresetBuilder 5 | 6 | 7 | class TestPresetBuilder(unittest.TestCase): 8 | def setUp(self): 9 | self.name = "my_preset" 10 | self.default_hyperparameters = {"lr": 1e-4, "gamma": 0.99} 11 | 12 | class MockPreset: 13 | def __init__(self, env, name, device, **hyperparameters): 14 | self.env = env 15 | self.name = name 16 | self.device = device 17 | self.hyperparameters = hyperparameters 18 | 19 | self.builder = PresetBuilder( 20 | self.name, self.default_hyperparameters, MockPreset 21 | ) 22 | 23 | def test_default_name(self): 24 | agent = self.builder.env(Mock).build() 25 | self.assertEqual(agent.name, self.name) 26 | 27 | def test_override_name(self): 28 | agent = self.builder.name("cool_name").env(Mock).build() 29 | self.assertEqual(agent.name, "cool_name") 30 | 31 | def test_default_hyperparameters(self): 32 | agent = self.builder.env(Mock).build() 33 | self.assertEqual(agent.hyperparameters, self.default_hyperparameters) 34 | 35 | def test_override_hyperparameters(self): 36 | agent = self.builder.hyperparameters(lr=0.01).env(Mock).build() 37 | self.assertEqual( 38 | agent.hyperparameters, {**self.default_hyperparameters, "lr": 0.01} 39 | ) 40 | 41 | def test_bad_hyperparameters(self): 42 | with self.assertRaises(KeyError): 43 | self.builder.hyperparameters(foo=0.01).env(Mock).build() 44 | 45 | def test_default_device(self): 46 | agent = self.builder.env(Mock).build() 47 | self.assertEqual(agent.device, "cuda") 48 | 49 | def test_override_device(self): 50 | agent = self.builder.device("cpu").env(Mock).build() 51 | self.assertEqual(agent.device, "cpu") 52 | 53 | def test_no_side_effects(self): 54 | self.builder.device("cpu").hyperparameters(lr=0.01).device("cpu").env( 55 | Mock 56 | ).build() 57 | my_env = Mock 58 | agent = self.builder.env(Mock).build() 59 | self.assertEqual(agent.name, self.name) 60 | self.assertEqual(agent.hyperparameters, self.default_hyperparameters) 61 | self.assertEqual(agent.device, "cuda") 62 | self.assertEqual(agent.env, my_env) 63 | 64 | def test_call_api(self): 65 | agent = ( 66 | self.builder(device="cpu", hyperparameters={"lr": 0.01}, name="cool_name") 67 | .env(Mock) 68 | .build() 69 | ) 70 | self.assertEqual(agent.name, "cool_name") 71 | self.assertEqual( 72 | agent.hyperparameters, {**self.default_hyperparameters, "lr": 0.01} 73 | ) 74 | self.assertEqual(agent.device, "cpu") 75 | 76 | 77 | if __name__ == "__main__": 78 | unittest.main() 79 | -------------------------------------------------------------------------------- /all/presets/classic_control/.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/all/presets/classic_control/.DS_Store -------------------------------------------------------------------------------- /all/presets/classic_control/__init__.py: -------------------------------------------------------------------------------- 1 | from .a2c import a2c 2 | from .c51 import c51 3 | from .ddqn import ddqn 4 | from .dqn import dqn 5 | from .ppo import ppo 6 | from .rainbow import rainbow 7 | from .vac import vac 8 | from .vpg import vpg 9 | from .vqn import vqn 10 | from .vsarsa import vsarsa 11 | 12 | __all__ = [ 13 | "a2c", 14 | "c51", 15 | "ddqn", 16 | "dqn", 17 | "ppo", 18 | "rainbow", 19 | "vac", 20 | "vpg", 21 | "vqn", 22 | "vsarsa", 23 | ] 24 | -------------------------------------------------------------------------------- /all/presets/classic_control/models/__init__.py: -------------------------------------------------------------------------------- 1 | from all import nn 2 | 3 | 4 | def fc_relu_q(env, hidden=64): 5 | return nn.Sequential( 6 | nn.Flatten(), 7 | nn.Linear(env.state_space.shape[0], hidden), 8 | nn.ReLU(), 9 | nn.Linear(hidden, env.action_space.n), 10 | ) 11 | 12 | 13 | def dueling_fc_relu_q(env): 14 | return nn.Sequential( 15 | nn.Flatten(), 16 | nn.Dueling( 17 | nn.Sequential( 18 | nn.Linear(env.state_space.shape[0], 256), nn.ReLU(), nn.Linear(256, 1) 19 | ), 20 | nn.Sequential( 21 | nn.Linear(env.state_space.shape[0], 256), 22 | nn.ReLU(), 23 | nn.Linear(256, env.action_space.n), 24 | ), 25 | ), 26 | ) 27 | 28 | 29 | def fc_relu_features(env, hidden=64): 30 | return nn.Sequential( 31 | nn.Flatten(), nn.Linear(env.state_space.shape[0], hidden), nn.ReLU() 32 | ) 33 | 34 | 35 | def fc_value_head(hidden=64): 36 | return nn.Linear0(hidden, 1) 37 | 38 | 39 | def fc_policy_head(env, hidden=64): 40 | return nn.Linear0(hidden, env.action_space.n) 41 | 42 | 43 | def fc_relu_dist_q(env, hidden=64, atoms=51): 44 | return nn.Sequential( 45 | nn.Flatten(), 46 | nn.Linear(env.state_space.shape[0], hidden), 47 | nn.ReLU(), 48 | nn.Linear0(hidden, env.action_space.n * atoms), 49 | ) 50 | 51 | 52 | def fc_relu_rainbow(env, hidden=64, atoms=51, sigma=0.5): 53 | return nn.Sequential( 54 | nn.Flatten(), 55 | nn.Linear(env.state_space.shape[0], hidden), 56 | nn.ReLU(), 57 | nn.CategoricalDueling( 58 | nn.NoisyFactorizedLinear(hidden, atoms, sigma_init=sigma), 59 | nn.NoisyFactorizedLinear( 60 | hidden, env.action_space.n * atoms, init_scale=0.0, sigma_init=sigma 61 | ), 62 | ), 63 | ) 64 | -------------------------------------------------------------------------------- /all/presets/classic_control_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import torch 5 | 6 | from all.environments import DuplicateEnvironment, GymEnvironment 7 | from all.logging import DummyLogger 8 | from all.presets import ParallelPreset 9 | from all.presets.classic_control import ( 10 | a2c, 11 | c51, 12 | ddqn, 13 | dqn, 14 | ppo, 15 | rainbow, 16 | vac, 17 | vpg, 18 | vqn, 19 | vsarsa, 20 | ) 21 | 22 | 23 | class TestClassicControlPresets(unittest.TestCase): 24 | def setUp(self): 25 | self.env = GymEnvironment("CartPole-v0") 26 | self.env.reset() 27 | self.parallel_env = DuplicateEnvironment( 28 | [GymEnvironment("CartPole-v0"), GymEnvironment("CartPole-v0")] 29 | ) 30 | self.parallel_env.reset() 31 | 32 | def tearDown(self): 33 
| if os.path.exists("test_preset.pt"): 34 | os.remove("test_preset.pt") 35 | 36 | def test_a2c(self): 37 | self.validate(a2c) 38 | 39 | def test_c51(self): 40 | self.validate(c51) 41 | 42 | def test_ddqn(self): 43 | self.validate(ddqn) 44 | 45 | def test_dqn(self): 46 | self.validate(dqn) 47 | 48 | def test_ppo(self): 49 | self.validate(ppo) 50 | 51 | def test_rainbow(self): 52 | self.validate(rainbow) 53 | 54 | def test_vac(self): 55 | self.validate(vac) 56 | 57 | def test_vpg(self): 58 | self.validate(vpg) 59 | 60 | def test_vsarsa(self): 61 | self.validate(vsarsa) 62 | 63 | def test_vqn(self): 64 | self.validate(vqn) 65 | 66 | def validate(self, builder): 67 | preset = builder.device("cpu").env(self.env).build() 68 | if isinstance(preset, ParallelPreset): 69 | return self.validate_parallel_preset(preset) 70 | return self.validate_standard_preset(preset) 71 | 72 | def validate_standard_preset(self, preset): 73 | # train agent 74 | agent = preset.agent(logger=DummyLogger(), train_steps=100000) 75 | agent.act(self.env.state) 76 | # test agent 77 | test_agent = preset.test_agent() 78 | test_agent.act(self.env.state) 79 | # test save/load 80 | preset.save("test_preset.pt") 81 | preset = torch.load("test_preset.pt") 82 | test_agent = preset.test_agent() 83 | test_agent.act(self.env.state) 84 | 85 | def validate_parallel_preset(self, preset): 86 | # train agent 87 | agent = preset.agent(logger=DummyLogger(), train_steps=100000) 88 | agent.act(self.parallel_env.state_array) 89 | # test agent 90 | test_agent = preset.test_agent() 91 | test_agent.act(self.env.state) 92 | # parallel test_agent 93 | parallel_test_agent = preset.test_agent() 94 | parallel_test_agent.act(self.parallel_env.state_array) 95 | # test save/load 96 | preset.save("test_preset.pt") 97 | preset = torch.load("test_preset.pt") 98 | test_agent = preset.test_agent() 99 | test_agent.act(self.env.state) 100 | 101 | 102 | if __name__ == "__main__": 103 | unittest.main() 104 | -------------------------------------------------------------------------------- /all/presets/continuous/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/all/presets/continuous/.DS_Store -------------------------------------------------------------------------------- /all/presets/continuous/__init__.py: -------------------------------------------------------------------------------- 1 | from .ddpg import ddpg 2 | from .ppo import ppo 3 | from .sac import sac 4 | 5 | __all__ = [ 6 | "ddpg", 7 | "ppo", 8 | "sac", 9 | ] 10 | -------------------------------------------------------------------------------- /all/presets/continuous/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pytorch models for continuous control. 3 | 4 | All models assume that a feature representing the 5 | current timestep is used in addition to the features 6 | received from the environment. 
7 | """ 8 | 9 | import torch 10 | 11 | from all import nn 12 | 13 | 14 | def fc_q(env, hidden1=400, hidden2=300): 15 | return nn.Sequential( 16 | nn.Float(), 17 | nn.Linear(env.state_space.shape[0] + env.action_space.shape[0], hidden1), 18 | nn.ReLU(), 19 | nn.Linear(hidden1, hidden2), 20 | nn.ReLU(), 21 | nn.Linear0(hidden2, 1), 22 | ) 23 | 24 | 25 | def fc_v(env, hidden1=400, hidden2=300): 26 | return nn.Sequential( 27 | nn.Float(), 28 | nn.Linear(env.state_space.shape[0], hidden1), 29 | nn.ReLU(), 30 | nn.Linear(hidden1, hidden2), 31 | nn.ReLU(), 32 | nn.Linear0(hidden2, 1), 33 | ) 34 | 35 | 36 | def fc_deterministic_policy(env, hidden1=400, hidden2=300): 37 | return nn.Sequential( 38 | nn.Float(), 39 | nn.Linear(env.state_space.shape[0], hidden1), 40 | nn.ReLU(), 41 | nn.Linear(hidden1, hidden2), 42 | nn.ReLU(), 43 | nn.Linear0(hidden2, env.action_space.shape[0]), 44 | ) 45 | 46 | 47 | def fc_soft_policy(env, hidden1=400, hidden2=300): 48 | return nn.Sequential( 49 | nn.Float(), 50 | nn.Linear(env.state_space.shape[0], hidden1), 51 | nn.ReLU(), 52 | nn.Linear(hidden1, hidden2), 53 | nn.ReLU(), 54 | nn.Linear0(hidden2, env.action_space.shape[0] * 2), 55 | ) 56 | 57 | 58 | class fc_policy(nn.Module): 59 | def __init__(self, env, hidden1=400, hidden2=300): 60 | super().__init__() 61 | self.model = nn.Sequential( 62 | nn.Float(), 63 | nn.Linear(env.state_space.shape[0], hidden1), 64 | nn.Tanh(), 65 | nn.Linear(hidden1, hidden2), 66 | nn.Tanh(), 67 | nn.Linear(hidden2, env.action_space.shape[0]), 68 | ) 69 | self.log_stds = nn.Parameter(torch.zeros(env.action_space.shape[0])) 70 | 71 | def forward(self, x): 72 | means = self.model(x) 73 | stds = self.log_stds.expand(*means.shape) 74 | return torch.cat((means, stds), 1) 75 | -------------------------------------------------------------------------------- /all/presets/continuous_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import torch 5 | 6 | from all.environments import DuplicateEnvironment, GymEnvironment 7 | from all.logging import DummyLogger 8 | from all.presets import ParallelPreset 9 | from all.presets.continuous import ddpg, ppo, sac 10 | 11 | 12 | class TestContinuousPresets(unittest.TestCase): 13 | def setUp(self): 14 | self.env = GymEnvironment("MountainCarContinuous-v0") 15 | self.env.reset() 16 | self.parallel_env = DuplicateEnvironment( 17 | [ 18 | GymEnvironment("MountainCarContinuous-v0"), 19 | GymEnvironment("MountainCarContinuous-v0"), 20 | ] 21 | ) 22 | self.parallel_env.reset() 23 | 24 | def tearDown(self): 25 | if os.path.exists("test_preset.pt"): 26 | os.remove("test_preset.pt") 27 | 28 | def test_ddpg(self): 29 | self.validate(ddpg) 30 | 31 | def test_ppo(self): 32 | self.validate(ppo) 33 | 34 | def test_sac(self): 35 | self.validate(sac) 36 | 37 | def validate(self, builder): 38 | preset = builder.device("cpu").env(self.env).build() 39 | if isinstance(preset, ParallelPreset): 40 | return self.validate_parallel_preset(preset) 41 | return self.validate_standard_preset(preset) 42 | 43 | def validate_standard_preset(self, preset): 44 | # train agent 45 | agent = preset.agent(logger=DummyLogger(), train_steps=100000) 46 | agent.act(self.env.state) 47 | # test agent 48 | test_agent = preset.test_agent() 49 | test_agent.act(self.env.state) 50 | # test save/load 51 | preset.save("test_preset.pt") 52 | preset = torch.load("test_preset.pt") 53 | test_agent = preset.test_agent() 54 | test_agent.act(self.env.state) 55 | 56 | def 
validate_parallel_preset(self, preset): 57 | # train agent 58 | agent = preset.agent(logger=DummyLogger(), train_steps=100000) 59 | agent.act(self.parallel_env.state_array) 60 | # test agent 61 | test_agent = preset.test_agent() 62 | test_agent.act(self.env.state) 63 | # parallel test_agent 64 | parallel_test_agent = preset.test_agent() 65 | parallel_test_agent.act(self.parallel_env.state_array) 66 | # test save/load 67 | preset.save("test_preset.pt") 68 | preset = torch.load("test_preset.pt") 69 | test_agent = preset.test_agent() 70 | test_agent.act(self.env.state) 71 | 72 | 73 | if __name__ == "__main__": 74 | unittest.main() 75 | -------------------------------------------------------------------------------- /all/presets/independent_multiagent.py: -------------------------------------------------------------------------------- 1 | from all.agents import IndependentMultiagent 2 | from all.logging import DummyLogger 3 | 4 | from .preset import Preset 5 | 6 | 7 | class IndependentMultiagentPreset(Preset): 8 | def __init__(self, name, device, presets): 9 | super().__init__(name, device, presets) 10 | 11 | def agent(self, logger=DummyLogger(), train_steps=float("inf")): 12 | return IndependentMultiagent( 13 | { 14 | agent_id: preset.agent(logger=logger, train_steps=train_steps) 15 | for agent_id, preset in self.hyperparameters.items() 16 | } 17 | ) 18 | 19 | def test_agent(self): 20 | return IndependentMultiagent( 21 | { 22 | agent_id: preset.test_agent() 23 | for agent_id, preset in self.hyperparameters.items() 24 | } 25 | ) 26 | -------------------------------------------------------------------------------- /all/presets/multiagent_atari_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import torch 5 | 6 | from all.environments import MultiagentAtariEnv 7 | from all.logging import DummyLogger 8 | from all.presets import IndependentMultiagentPreset 9 | from all.presets.atari import dqn 10 | 11 | 12 | class TestMultiagentAtariPresets(unittest.TestCase): 13 | def setUp(self): 14 | self.env = MultiagentAtariEnv("pong_v3", device="cpu") 15 | self.env.reset() 16 | 17 | def tearDown(self): 18 | if os.path.exists("test_preset.pt"): 19 | os.remove("test_preset.pt") 20 | 21 | def test_independent(self): 22 | presets = { 23 | agent_id: dqn.device("cpu").env(self.env.subenvs[agent_id]).build() 24 | for agent_id in self.env.agents 25 | } 26 | self.validate_preset( 27 | IndependentMultiagentPreset("independent", "cpu", presets), self.env 28 | ) 29 | 30 | def validate_preset(self, preset, env): 31 | # normal agent 32 | agent = preset.agent(logger=DummyLogger(), train_steps=100000) 33 | agent.act(self.env.last()) 34 | # test agent 35 | test_agent = preset.test_agent() 36 | test_agent.act(self.env.last()) 37 | # test save/load 38 | preset.save("test_preset.pt") 39 | preset = torch.load("test_preset.pt") 40 | test_agent = preset.test_agent() 41 | test_agent.act(self.env.last()) 42 | 43 | 44 | if __name__ == "__main__": 45 | unittest.main() 46 | -------------------------------------------------------------------------------- /all/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/all/scripts/__init__.py -------------------------------------------------------------------------------- /all/scripts/plot.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from all.experiments import plot_returns_100 4 | 5 | 6 | def main(): 7 | parser = argparse.ArgumentParser(description="Plots the results of experiments.") 8 | parser.add_argument("--logdir", help="Output directory", default="runs") 9 | parser.add_argument( 10 | "--timesteps", 11 | type=int, 12 | default=-1, 13 | help="The final point will be fixed to this x-value", 14 | ) 15 | args = parser.parse_args() 16 | plot_returns_100(args.logdir, timesteps=args.timesteps) 17 | 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /all/scripts/release.py: -------------------------------------------------------------------------------- 1 | """Create slurm tasks to run release test suite""" 2 | 3 | from all.environments import AtariEnvironment, GymEnvironment 4 | from all.experiments import SlurmExperiment 5 | from all.presets import atari, classic_control, continuous 6 | 7 | 8 | def main(): 9 | # run on gpu 10 | device = "cuda" 11 | 12 | def get_agents(preset): 13 | agents = [getattr(preset, agent_name) for agent_name in preset.__all__] 14 | return [agent(device=device) for agent in agents] 15 | 16 | SlurmExperiment( 17 | get_agents(atari), 18 | AtariEnvironment("Breakout", device=device), 19 | 10e7, 20 | sbatch_args={"partition": "1080ti-long"}, 21 | ) 22 | 23 | SlurmExperiment( 24 | get_agents(classic_control), 25 | GymEnvironment("CartPole-v0", device=device), 26 | 100000, 27 | sbatch_args={"partition": "1080ti-short"}, 28 | ) 29 | 30 | SlurmExperiment( 31 | get_agents(continuous), 32 | GymEnvironment("LunarLanderContinuous-v2", device=device), 33 | 500000, 34 | sbatch_args={"partition": "1080ti-short"}, 35 | ) 36 | 37 | 38 | if __name__ == "__main__": 39 | main() 40 | -------------------------------------------------------------------------------- /all/scripts/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from all.experiments import run_experiment 4 | 5 | 6 | def train( 7 | presets, 8 | env_constructor, 9 | description="Train an RL agent", 10 | env_help="Name of the environment (e.g., 'CartPole-v0')", 11 | default_frames=40e6, 12 | ): 13 | # parse command line args 14 | parser = argparse.ArgumentParser(description=description) 15 | parser.add_argument("env", help=env_help) 16 | parser.add_argument( 17 | "agent", 18 | help="Name of the agent (e.g. 'dqn'). See presets for available agents.", 19 | ) 20 | parser.add_argument( 21 | "--device", 22 | default="cuda", 23 | help="The name of the device to run the agent on (e.g. cpu, cuda, cuda:0).", 24 | ) 25 | parser.add_argument( 26 | "--frames", 27 | type=int, 28 | default=default_frames, 29 | help="The number of training frames.", 30 | ) 31 | parser.add_argument( 32 | "--render", action="store_true", default=False, help="Render the environment." 33 | ) 34 | parser.add_argument("--logdir", default="runs", help="The base logging directory.") 35 | parser.add_argument( 36 | "--save_freq", default=100, help="How often to save the model, in episodes." 
37 | ) 38 | parser.add_argument("--hyperparameters", default=[], nargs="*") 39 | args = parser.parse_args() 40 | 41 | # construct the environment 42 | env = env_constructor(args.env, device=args.device) 43 | 44 | # construct the agents 45 | agent_name = args.agent 46 | agent = getattr(presets, agent_name) 47 | agent = agent.device(args.device) 48 | 49 | # parse hyperparameters 50 | hyperparameters = {} 51 | for hp in args.hyperparameters: 52 | key, value = hp.split("=") 53 | hyperparameters[key] = type(agent.default_hyperparameters[key])(value) 54 | agent = agent.hyperparameters(**hyperparameters) 55 | 56 | # run the experiment 57 | run_experiment( 58 | agent, 59 | env, 60 | args.frames, 61 | render=args.render, 62 | logdir=args.logdir, 63 | save_freq=args.save_freq, 64 | ) 65 | -------------------------------------------------------------------------------- /all/scripts/train_atari.py: -------------------------------------------------------------------------------- 1 | from all.environments import AtariEnvironment 2 | from all.presets import atari 3 | 4 | from .train import train 5 | 6 | 7 | def main(): 8 | train( 9 | atari, 10 | AtariEnvironment, 11 | description="Train an agent on an Atari environment.", 12 | env_help="The name of the environment (e.g., 'Pong').", 13 | default_frames=40e6, 14 | ) 15 | 16 | 17 | if __name__ == "__main__": 18 | main() 19 | -------------------------------------------------------------------------------- /all/scripts/train_classic.py: -------------------------------------------------------------------------------- 1 | from all.environments import GymEnvironment 2 | from all.presets import classic_control 3 | 4 | from .train import train 5 | 6 | 7 | def main(): 8 | train( 9 | classic_control, 10 | GymEnvironment, 11 | description="Train an agent on a classic control environment.", 12 | env_help="The name of the environment (e.g., CartPole-v0).", 13 | default_frames=50000, 14 | ) 15 | 16 | 17 | if __name__ == "__main__": 18 | main() 19 | -------------------------------------------------------------------------------- /all/scripts/train_continuous.py: -------------------------------------------------------------------------------- 1 | from all.environments import GymEnvironment 2 | from all.presets import continuous 3 | 4 | from .train import train 5 | 6 | 7 | def main(): 8 | train( 9 | continuous, 10 | GymEnvironment, 11 | description="Train an agent on a continuous control environment.", 12 | env_help="The name of the environment (e.g., MountainCarContinuous-v0).", 13 | default_frames=10e6, 14 | ) 15 | 16 | 17 | if __name__ == "__main__": 18 | main() 19 | -------------------------------------------------------------------------------- /all/scripts/train_mujoco.py: -------------------------------------------------------------------------------- 1 | from all.environments import MujocoEnvironment 2 | from all.presets import continuous 3 | 4 | from .train import train 5 | 6 | 7 | def main(): 8 | train( 9 | continuous, 10 | MujocoEnvironment, 11 | description="Train an agent on a Mujoco environment.", 12 | env_help="The name of the environment (e.g., Ant-v4).", 13 | default_frames=10e6, 14 | ) 15 | 16 | 17 | if __name__ == "__main__": 18 | main() 19 | -------------------------------------------------------------------------------- /all/scripts/train_multiagent_atari.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from all.environments import MultiagentAtariEnv 4 | from
all.experiments.multiagent_env_experiment import MultiagentEnvExperiment 5 | from all.presets import IndependentMultiagentPreset, atari 6 | 7 | 8 | class DummyEnv: 9 | def __init__(self, state_space, action_space): 10 | self.state_space = state_space 11 | self.action_space = action_space 12 | 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser(description="Run a multiagent Atari benchmark.") 16 | parser.add_argument("env", help="Name of the Atari game (e.g. pong_v3).") 17 | parser.add_argument("agents", nargs="*", help="List of agents.") 18 | parser.add_argument( 19 | "--device", 20 | default="cuda", 21 | help="The name of the device to run the agent on (e.g. cpu, cuda, cuda:0).", 22 | ) 23 | parser.add_argument( 24 | "--replay_buffer_size", 25 | default=100000, 26 | help="The size of the replay buffer, if applicable", 27 | ) 28 | parser.add_argument( 29 | "--frames", type=int, default=40e6, help="The number of training frames." 30 | ) 31 | parser.add_argument( 32 | "--save_freq", default=100, help="How often to save the model, in episodes." 33 | ) 34 | parser.add_argument( 35 | "--render", action="store_true", default=False, help="Render the environment." 36 | ) 37 | args = parser.parse_args() 38 | 39 | env = MultiagentAtariEnv(args.env, device=args.device) 40 | 41 | assert len(env.agents) == len( 42 | args.agents 43 | ), f"Must specify {len(env.agents)} agents for this environment." 44 | 45 | presets = { 46 | agent_id: getattr(atari, agent_type) 47 | .hyperparameters(replay_buffer_size=args.replay_buffer_size) 48 | .device(args.device) 49 | .env(env.subenvs[agent_id]) 50 | .build() 51 | for agent_id, agent_type in zip(env.agents, args.agents) 52 | } 53 | 54 | experiment = MultiagentEnvExperiment( 55 | IndependentMultiagentPreset("Independent", args.device, presets), 56 | env, 57 | save_freq=args.save_freq, 58 | render=args.render, 59 | ) 60 | experiment.save() 61 | experiment.train(frames=args.frames) 62 | experiment.save() 63 | experiment.test(episodes=100) 64 | experiment.close() 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /all/scripts/train_pybullet.py: -------------------------------------------------------------------------------- 1 | from all.environments import PybulletEnvironment 2 | from all.presets import continuous 3 | 4 | from .train import train 5 | 6 | 7 | def main(): 8 | train( 9 | continuous, 10 | PybulletEnvironment, 11 | description="Train an agent on a PyBullet environment.", 12 | env_help="The name of the environment (e.g., AntBulletEnv-v0).", 13 | default_frames=10e6, 14 | ) 15 | 16 | 17 | if __name__ == "__main__": 18 | main() 19 | -------------------------------------------------------------------------------- /all/scripts/watch_atari.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from all.environments import AtariEnvironment 4 | from all.experiments import load_and_watch 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser(description="Watch an Atari agent.") 9 | parser.add_argument("env", help="Name of the Atari game (e.g. Pong)") 10 | parser.add_argument("filename", help="File where the model was saved.") 11 | parser.add_argument( 12 | "--device", 13 | default="cuda", 14 | help="The name of the device to run the agent on (e.g. 
cpu, cuda, cuda:0)", 15 | ) 16 | parser.add_argument( 17 | "--fps", 18 | default=60, 19 | help="Playback speed", 20 | ) 21 | args = parser.parse_args() 22 | env = AtariEnvironment(args.env, device=args.device, render_mode="human") 23 | load_and_watch(args.filename, env, fps=args.fps) 24 | 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /all/scripts/watch_classic.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from all.environments import GymEnvironment 4 | from all.experiments import load_and_watch 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser(description="Watch a classic control agent.") 9 | parser.add_argument("env", help="Name of the environment (e.g. CartPole-v0)") 10 | parser.add_argument("filename", help="File where the model was saved.") 11 | parser.add_argument( 12 | "--device", 13 | default="cuda", 14 | help="The name of the device to run the agent on (e.g. cpu, cuda, cuda:0)", 15 | ) 16 | parser.add_argument( 17 | "--fps", 18 | default=60, 19 | help="Playback speed", 20 | ) 21 | args = parser.parse_args() 22 | env = GymEnvironment(args.env, device=args.device, render_mode="human") 23 | load_and_watch(args.filename, env, fps=args.fps) 24 | 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /all/scripts/watch_continuous.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from all.environments import GymEnvironment 4 | from all.experiments import load_and_watch 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser(description="Watch a continuous agent.") 9 | parser.add_argument( 10 | "env", help="Name of the environment (e.g., LunarLanderContinuous-v2)" 11 | ) 12 | parser.add_argument("filename", help="File where the model was saved.") 13 | parser.add_argument( 14 | "--device", 15 | default="cuda", 16 | help="The name of the device to run the agent on (e.g. cpu, cuda, cuda:0)", 17 | ) 18 | parser.add_argument( 19 | "--fps", 20 | default=120, 21 | help="Playback speed", 22 | ) 23 | args = parser.parse_args() 24 | env = GymEnvironment(args.env, device=args.device, render_mode="human") 25 | load_and_watch(args.filename, env, fps=args.fps) 26 | 27 | 28 | if __name__ == "__main__": 29 | main() 30 | -------------------------------------------------------------------------------- /all/scripts/watch_mujoco.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from all.environments import MujocoEnvironment 4 | from all.experiments import load_and_watch 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser(description="Watch a mujoco agent.") 9 | parser.add_argument("env", help="ID of the Environment") 10 | parser.add_argument("filename", help="File where the model was saved.") 11 | parser.add_argument( 12 | "--device", 13 | default="cuda", 14 | help="The name of the device to run the agent on (e.g. 
cpu, cuda, cuda:0)", 15 | ) 16 | parser.add_argument( 17 | "--fps", 18 | default=120, 19 | help="Playback speed", 20 | ) 21 | args = parser.parse_args() 22 | env = MujocoEnvironment(args.env, device=args.device, render_mode="human") 23 | load_and_watch(args.filename, env, fps=args.fps) 24 | 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /all/scripts/watch_multiagent_atari.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import torch 5 | 6 | from all.environments import MultiagentAtariEnv 7 | 8 | 9 | def watch(env, filename, fps, reload): 10 | agent = torch.load(filename).test_agent() 11 | 12 | while True: 13 | watch_episode(env, agent, fps) 14 | if reload: 15 | try: 16 | agent = torch.load(filename).test_agent() 17 | except Exception as e: 18 | print("Warning: error reloading model: {}".format(filename)) 19 | print(e) 20 | 21 | 22 | def watch_episode(env, agent, fps): 23 | env.reset() 24 | for _ in env.agent_iter(): 25 | env.render() 26 | state = env.last() 27 | action = agent.act(state) 28 | if state.done: 29 | env.step(None) 30 | else: 31 | env.step(action) 32 | time.sleep(1 / fps) 33 | 34 | 35 | def main(): 36 | parser = argparse.ArgumentParser(description="Watch pretrained multiagent atari") 37 | parser.add_argument("env", help="Name of the Atari game (e.g. pong_v3)") 38 | parser.add_argument("filename", help="File where the model was saved.") 39 | parser.add_argument( 40 | "--device", 41 | default="cuda", 42 | help="The name of the device to run the agent on (e.g. cpu, cuda, cuda:0)", 43 | ) 44 | parser.add_argument( 45 | "--fps", 46 | default=30, 47 | type=int, 48 | help="Playback speed", 49 | ) 50 | parser.add_argument( 51 | "--reload", 52 | action="store_true", 53 | default=False, 54 | help="Reload the model from disk after every episode", 55 | ) 56 | args = parser.parse_args() 57 | env = MultiagentAtariEnv(args.env, device=args.device, render_mode="human") 58 | watch(env, args.filename, args.fps, args.reload) 59 | 60 | 61 | if __name__ == "__main__": 62 | main() 63 | -------------------------------------------------------------------------------- /all/scripts/watch_pybullet.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=unused-import 2 | import argparse 3 | 4 | from all.environments import PybulletEnvironment 5 | from all.experiments import load_and_watch 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser(description="Watch a PyBullet agent.") 10 | parser.add_argument("env", help="Name of the environment (e.g., AntBulletEnv-v0)") 11 | parser.add_argument("filename", help="File where the model was saved.") 12 | parser.add_argument( 13 | "--device", 14 | default="cuda", 15 | help="The name of the device to run the agent on (e.g. 
cpu, cuda, cuda:0)", 16 | ) 17 | parser.add_argument( 18 | "--fps", 19 | default=120, 20 | help="Playback speed", 21 | ) 22 | args = parser.parse_args() 23 | env = PybulletEnvironment(args.env, device=args.device) 24 | env.render(mode="human") # needed for pybullet envs 25 | load_and_watch(args.filename, env, fps=args.fps) 26 | 27 | 28 | if __name__ == "__main__": 29 | main() 30 | -------------------------------------------------------------------------------- /benchmarks/atari_40m.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/benchmarks/atari_40m.png -------------------------------------------------------------------------------- /benchmarks/atari_40m.py: -------------------------------------------------------------------------------- 1 | from all.environments import AtariEnvironment 2 | from all.experiments import SlurmExperiment 3 | from all.presets import atari 4 | 5 | 6 | def main(): 7 | agents = [ 8 | atari.a2c, 9 | atari.c51, 10 | atari.dqn, 11 | atari.ddqn, 12 | atari.ppo, 13 | atari.rainbow, 14 | ] 15 | envs = [ 16 | AtariEnvironment(env, device="cuda") 17 | for env in ["BeamRider", "Breakout", "Pong", "Qbert", "SpaceInvaders"] 18 | ] 19 | SlurmExperiment( 20 | agents, 21 | envs, 22 | 10e6, 23 | logdir="benchmarks/atari_40m", 24 | sbatch_args={"partition": "gypsum-1080ti"}, 25 | ) 26 | 27 | 28 | if __name__ == "__main__": 29 | main() 30 | -------------------------------------------------------------------------------- /benchmarks/mujoco_v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/benchmarks/mujoco_v4.png -------------------------------------------------------------------------------- /benchmarks/mujoco_v4.py: -------------------------------------------------------------------------------- 1 | from all.environments import MujocoEnvironment 2 | from all.experiments import SlurmExperiment 3 | from all.presets.continuous import ddpg, ppo, sac 4 | 5 | 6 | def main(): 7 | frames = int(5e6) 8 | 9 | agents = [ddpg, ppo, sac] 10 | 11 | envs = [ 12 | MujocoEnvironment(env, device="cuda") 13 | for env in [ 14 | "Ant-v4", 15 | "HalfCheetah-v4", 16 | "Hopper-v4", 17 | "Humanoid-v4", 18 | "Walker2d-v4", 19 | ] 20 | ] 21 | 22 | SlurmExperiment( 23 | agents, 24 | envs, 25 | frames, 26 | logdir="benchmarks/mujoco_v4", 27 | sbatch_args={ 28 | "partition": "gpu-long", 29 | }, 30 | ) 31 | 32 | 33 | if __name__ == "__main__": 34 | main() 35 | -------------------------------------------------------------------------------- /benchmarks/pybullet_v0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/benchmarks/pybullet_v0.png -------------------------------------------------------------------------------- /benchmarks/pybullet_v0.py: -------------------------------------------------------------------------------- 1 | from all.environments import PybulletEnvironment 2 | from all.experiments import SlurmExperiment 3 | from all.presets.continuous import ddpg, ppo, sac 4 | 5 | 6 | def main(): 7 | frames = int(5e6) 8 | 9 | agents = [ddpg, ppo, sac] 10 | 11 | envs = [ 12 | PybulletEnvironment(env, device="cuda") 13 | for env in [ 14 | "AntBulletEnv-v0", 15 | 
"HalfCheetahBulletEnv-v0", 16 | "HopperBulletEnv-v0", 17 | "HumanoidBulletEnv-v0", 18 | "Walker2DBulletEnv-v0", 19 | ] 20 | ] 21 | 22 | SlurmExperiment( 23 | agents, 24 | envs, 25 | frames, 26 | logdir="benchmarks/pybullet_v0", 27 | sbatch_args={ 28 | "partition": "gpu-long", 29 | }, 30 | ) 31 | 32 | 33 | if __name__ == "__main__": 34 | main() 35 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('..')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'autonomous-learning-library' 21 | copyright = '2024, Chris Nota' 22 | author = 'Chris Nota' 23 | 24 | # The full version, including alpha/beta/rc tags 25 | release = '0.9.1' 26 | 27 | # -- General configuration --------------------------------------------------- 28 | 29 | # Add any Sphinx extension module names here, as strings. They can be 30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 31 | # ones. 32 | extensions = [ 33 | 'sphinx.ext.autodoc', 34 | 'sphinx.ext.napoleon', 35 | 'sphinx.ext.autosummary', 36 | 'sphinx.ext.autosectionlabel', 37 | 'sphinx_automodapi.automodapi', 38 | 'sphinx_rtd_theme' 39 | ] 40 | 41 | # Autosummary settings 42 | autodoc_default_options = { 43 | 'members': True, 44 | 'undoc-members': True, 45 | 'show-inheritance': True 46 | } 47 | autosummary_generate = True 48 | autodoc_inherit_docstrings = True 49 | 50 | # Mock requirements to save resources during doc build machine setup 51 | autodoc_mock_imports = [ 52 | 'torch', 53 | 'torchvision', 54 | ] 55 | 56 | # Add any paths that contain templates here, relative to this directory. 57 | templates_path = ['_templates'] 58 | 59 | # List of patterns, relative to source directory, that match files and 60 | # directories to ignore when looking for source files. 61 | # This pattern also affects html_static_path and html_extra_path. 62 | exclude_patterns = [] 63 | 64 | 65 | # -- Options for HTML output ------------------------------------------------- 66 | 67 | # The theme to use for HTML and HTML Help pages. See the documentation for 68 | # a list of builtin themes. 69 | # 70 | html_theme = 'sphinx_rtd_theme' 71 | 72 | # Add any paths that contain custom static files (such as style sheets) here, 73 | # relative to this directory. They are copied after the builtin static files, 74 | # so a file named "default.css" will overwrite the builtin "default.css". 75 | # html_static_path = ['_static'] 76 | -------------------------------------------------------------------------------- /docs/source/environments.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/docs/source/environments.png -------------------------------------------------------------------------------- /docs/source/guide/ale.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/docs/source/guide/ale.png -------------------------------------------------------------------------------- /docs/source/guide/approximation.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/docs/source/guide/approximation.jpeg -------------------------------------------------------------------------------- /docs/source/guide/creating_agent.rst: -------------------------------------------------------------------------------- 1 | Building Your Own Agent 2 | ======================= 3 | 4 | In the previous section, we discussed the basic components of the ``autonomous-learning-library``. 
5 | While the library contains a selection of preset agents, the primary goal of the library is to be a tool to build *your own* agents. 6 | To this end, we have provided an `example project `_ containing a new *model predictive control* variant of DQN to demonstrate the flexibility of the library. 7 | Briefly, when creating your own agent, you will generally have the following components: 8 | 9 | 1. An ``agent.py`` file containing the high-level implementation of the ``Agent``. 10 | 2. A ``model.py`` file containing the PyTorch models appropriate for your chosen domain. 11 | 3. A ``preset.py`` file that composes your ``Agent`` using the appropriate model and other objects. 12 | 4. A ``main.py`` or similar file that runs your agent and any ``autonomous-learning-library`` presets you wish to compare against. 13 | 14 | While it is not necessary to follow this structure, we believe it will generally guide you towards using the ``autonomous-learning-library`` in the intended manner and ensure that your code is understandable to other users of the library. 15 | -------------------------------------------------------------------------------- /docs/source/guide/getting_started.rst: -------------------------------------------------------------------------------- 1 | Getting Started 2 | =============== 3 | 4 | Prerequisites 5 | ------------- 6 | 7 | The Autonomous Learning Library requires a recent version of PyTorch (at least v2.2.0 is recommended). 8 | Additionally, Tensorboard is required in order to enable logging. 9 | We also strongly recommend using a machine with a fast GPU with at least 11 GB of VRAM (a GTX 1080ti or better is preferred). 10 | 11 | Installation 12 | ------------ 13 | 14 | The ``autonomous-learning-library`` can be installed from PyPI using ``pip``: 15 | 16 | .. code-block:: bash 17 | 18 | pip install autonomous-learning-library 19 | 20 | This will only install the core library. 21 | If you want to install all included environments, run: 22 | 23 | .. code-block:: bash 24 | 25 | pip install autonomous-learning-library[all] 26 | 27 | You can also install only a subset of the environments. 28 | For the list of optional dependencies, take a look at the `setup.py `_. 29 | 30 | An alternative approach, which may be useful when following this tutorial, is to instead install by cloning the GitHub repository: 31 | 32 | .. code-block:: bash 33 | 34 | git clone https://github.com/cpnota/autonomous-learning-library.git 35 | cd autonomous-learning-library 36 | pip install -e .[dev] 37 | 38 | ``dev`` will install all of the optional dependencies for developers of the repo, such as unit test dependencies, as well as all environments. 39 | If you chose to clone the repository, you can test your installation by running the unit test suite: 40 | 41 | .. code-block:: bash 42 | 43 | make test 44 | 45 | This should also tell you if CUDA (GPU support) is available. 46 | 47 | Running a Preset Agent 48 | ---------------------- 49 | 50 | The goal of the Autonomous Learning Library is to provide components for building new agents. 51 | However, the library also includes a number of "preset" agent configurations for easy benchmarking and comparison, 52 | as well as some useful scripts. 53 | For example, an a2c agent can be run on CartPole as follows: 54 | 55 | .. code-block:: bash 56 | 57 | all-classic CartPole-v0 a2c 58 | 59 | The results will be written to ``runs/a2c_CartPole-v0_<id>``, where ``<id>`` is generated by the library. 
60 | You can view these results and other information through `tensorboard`: 61 | 62 | .. code-block:: bash 63 | 64 | tensorboard --logdir runs 65 | 66 | By opening your browser to `http://localhost:6006`_, you should see a dashboard that looks something like the following (you may need to adjust the "smoothing" parameter): 67 | 68 | .. image:: tensorboard.png 69 | 70 | If you want to compare agents in a nicer format, you can use the `plot` script: 71 | 72 | .. code-block:: bash 73 | 74 | all-plot --logdir runs 75 | 76 | This should give you a plot similar to the following: 77 | 78 | .. image:: plot.png 79 | 80 | In this plot, each point represents the average of the episodic returns over the previous 100 episodes, with one point plotted every 100 episodes. 81 | The shaded region represents the standard deviation over that interval. 82 | 83 | Finally, to watch the trained model in action, we provide a `watch` script for each preset module: 84 | 85 | .. code-block:: bash 86 | 87 | all-watch-classic CartPole-v0 runs/a2c_CartPole-v0_<id>/preset.pt 88 | 89 | You need to find the ``<id>`` by checking the ``runs`` directory. 90 | 91 | Each of these scripts can be found in the ``scripts`` directory of the main repository. 92 | Be sure to check out the ``all-atari`` and ``all-mujoco`` scripts for more fun! 93 | -------------------------------------------------------------------------------- /docs/source/guide/plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/docs/source/guide/plot.png -------------------------------------------------------------------------------- /docs/source/guide/rainbow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/docs/source/guide/rainbow.png -------------------------------------------------------------------------------- /docs/source/guide/rl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/docs/source/guide/rl.jpg -------------------------------------------------------------------------------- /docs/source/guide/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/docs/source/guide/tensorboard.png -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | The Autonomous Learning Library 2 | =============================== 3 | 4 | The `Autonomous Learning Library `_ is a PyTorch-based toolkit for building and evaluating reinforcement learning agents. 5 | 6 | .. image:: environments.png 7 | :align: center 8 | 9 | Here are some common links: 10 | 11 | * The `GitHub `_ repository. 12 | * The :ref:`Getting Started` guide. 13 | * An `example project `_ to help you get started building your own agents. 14 | * The :ref:`Benchmark Performance` for our preset agents. 15 | * The :ref:`all.presets` documentation, including default hyperparameters. 16 | 17 | Enjoy! 18 | 19 | .. 
toctree:: 20 | :maxdepth: 2 21 | :caption: User Guide: 22 | 23 | guide/getting_started 24 | guide/basic_concepts 25 | guide/creating_agent 26 | guide/benchmark_performance 27 | 28 | .. toctree:: 29 | :maxdepth: 1 30 | :caption: Modules: 31 | 32 | modules/agents 33 | modules/approximation 34 | modules/bodies 35 | modules/core 36 | modules/environments 37 | modules/experiments 38 | modules/logging 39 | modules/memory 40 | modules/nn 41 | modules/optim 42 | modules/policies 43 | modules/presets 44 | 45 | Indices and tables 46 | ================== 47 | 48 | * :ref:`genindex` 49 | * :ref:`modindex` 50 | * :ref:`search` 51 | -------------------------------------------------------------------------------- /docs/source/modules/agents.rst: -------------------------------------------------------------------------------- 1 | .. _agents: 2 | 3 | 4 | all.agents 5 | ================= 6 | 7 | .. automodsumm:: all.agents 8 | 9 | .. automodule:: all.agents 10 | :members: 11 | -------------------------------------------------------------------------------- /docs/source/modules/approximation.rst: -------------------------------------------------------------------------------- 1 | .. _approximation: 2 | 3 | 4 | all.approximation 5 | ================= 6 | 7 | .. automodule:: all.approximation 8 | :members: 9 | -------------------------------------------------------------------------------- /docs/source/modules/bodies.rst: -------------------------------------------------------------------------------- 1 | .. _bodies: 2 | 3 | 4 | all.bodies 5 | ================= 6 | 7 | .. automodule:: all.bodies 8 | :members: 9 | -------------------------------------------------------------------------------- /docs/source/modules/core.rst: -------------------------------------------------------------------------------- 1 | .. _core: 2 | 3 | 4 | all.core 5 | ================= 6 | 7 | .. automodsumm:: all.core 8 | 9 | .. automodule:: all.core 10 | :members: 11 | -------------------------------------------------------------------------------- /docs/source/modules/environments.rst: -------------------------------------------------------------------------------- 1 | .. _environments: 2 | 3 | 4 | all.environments 5 | ================= 6 | 7 | .. automodsumm:: all.environments 8 | 9 | .. automodule:: all.environments 10 | :members: 11 | :inherited-members: 12 | -------------------------------------------------------------------------------- /docs/source/modules/experiments.rst: -------------------------------------------------------------------------------- 1 | .. _experiments: 2 | 3 | 4 | all.experiments 5 | ================= 6 | 7 | .. automodule:: all.experiments 8 | :members: 9 | -------------------------------------------------------------------------------- /docs/source/modules/logging.rst: -------------------------------------------------------------------------------- 1 | .. _logging: 2 | 3 | 4 | all.logging 5 | ================= 6 | 7 | .. automodule:: all.logging 8 | :members: 9 | -------------------------------------------------------------------------------- /docs/source/modules/memory.rst: -------------------------------------------------------------------------------- 1 | .. _memory: 2 | 3 | 4 | all.memory 5 | ================= 6 | 7 | .. automodule:: all.memory 8 | :members: 9 | -------------------------------------------------------------------------------- /docs/source/modules/nn.rst: -------------------------------------------------------------------------------- 1 | .. 
_nn: 2 | 3 | 4 | all.nn 5 | ================= 6 | 7 | .. automodule:: all.nn 8 | :ignore-module-all: 9 | :members: 10 | -------------------------------------------------------------------------------- /docs/source/modules/optim.rst: -------------------------------------------------------------------------------- 1 | .. _optim: 2 | 3 | 4 | all.optim 5 | ================= 6 | 7 | .. automodule:: all.optim 8 | :members: 9 | -------------------------------------------------------------------------------- /docs/source/modules/policies.rst: -------------------------------------------------------------------------------- 1 | .. _policies: 2 | 3 | 4 | all.policies 5 | ================= 6 | 7 | .. automodule:: all.policies 8 | :members: 9 | -------------------------------------------------------------------------------- /docs/source/modules/presets.rst: -------------------------------------------------------------------------------- 1 | all.presets 2 | =========== 3 | 4 | .. toctree:: 5 | :maxdepth: 3 6 | :caption: all.presets 7 | 8 | presets/atari 9 | presets/classic 10 | presets/continuous 11 | 12 | .. automodule:: all.presets 13 | :members: 14 | -------------------------------------------------------------------------------- /docs/source/modules/presets/atari.rst: -------------------------------------------------------------------------------- 1 | .. _atari: 2 | 3 | 4 | all.presets.atari 5 | ================= 6 | 7 | .. automodsumm:: all.presets.atari 8 | 9 | .. automodule:: all.presets.atari 10 | :members: 11 | :inherited-members: 12 | :show-inheritance: 13 | -------------------------------------------------------------------------------- /docs/source/modules/presets/classic.rst: -------------------------------------------------------------------------------- 1 | .. _classic: 2 | 3 | 4 | all.presets.classic_control 5 | =========================== 6 | 7 | .. automodsumm:: all.presets.classic_control 8 | 9 | .. automodule:: all.presets.classic_control 10 | :members: 11 | :inherited-members: 12 | -------------------------------------------------------------------------------- /docs/source/modules/presets/continuous.rst: -------------------------------------------------------------------------------- 1 | .. _continuous: 2 | 3 | 4 | all.presets.continuous 5 | ====================== 6 | 7 | .. automodsumm:: all.presets.continuous 8 | 9 | .. automodule:: all.presets.continuous 10 | :members: 11 | :inherited-members: 12 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpnota/autonomous-learning-library/f8073e51eb2462b8425dcb740f06e97ff10e2b1b/examples/__init__.py -------------------------------------------------------------------------------- /examples/experiment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Quick example of usage of the run_experiment API. 3 | """ 4 | 5 | from all.environments import GymEnvironment 6 | from all.experiments import plot_returns_100, run_experiment 7 | from all.presets.classic_control import a2c, dqn 8 | 9 | 10 | def main(): 11 | DEVICE = "cpu" 12 | # DEVICE = 'cuda' # uncomment for gpu support 13 | timesteps = 40000 14 | run_experiment( 15 | [ 16 | # DQN with default hyperparameters 17 | dqn.device(DEVICE), 18 | # DQN with custom hyperparameters and a custom name. 
19 | dqn.device(DEVICE) 20 | .hyperparameters(replay_buffer_size=100) 21 | .name("dqn-small-buffer"), 22 | # A2C with a custom name 23 | a2c.device(DEVICE).name("not-dqn"), 24 | ], 25 | [GymEnvironment("CartPole-v0", DEVICE), GymEnvironment("Acrobot-v1", DEVICE)], 26 | timesteps, 27 | ) 28 | plot_returns_100("runs", timesteps=timesteps) 29 | 30 | 31 | if __name__ == "__main__": 32 | main() 33 | -------------------------------------------------------------------------------- /examples/slurm_experiment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Quick example of a2c running on slurm, a distributed cluster. 3 | Note that it only runs for 1 million frames. 4 | For real experiments, you will surely need a modified version of this script. 5 | """ 6 | 7 | from all.environments import AtariEnvironment 8 | from all.experiments import SlurmExperiment 9 | from all.presets.atari import a2c, dqn 10 | 11 | 12 | def main(): 13 | device = "cuda" 14 | envs = [ 15 | AtariEnvironment(env, device) for env in ["Pong", "Breakout", "SpaceInvaders"] 16 | ] 17 | SlurmExperiment( 18 | [a2c.device(device), dqn.device(device)], 19 | envs, 20 | 1e6, 21 | sbatch_args={"partition": "1080ti-short"}, 22 | ) 23 | 24 | 25 | if __name__ == "__main__": 26 | main() 27 | -------------------------------------------------------------------------------- /integration/classic_control_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from validate_agent import validate_agent 4 | 5 | from all.environments import GymEnvironment 6 | from all.presets.classic_control import ( 7 | a2c, 8 | c51, 9 | ddqn, 10 | dqn, 11 | ppo, 12 | rainbow, 13 | vac, 14 | vpg, 15 | vqn, 16 | vsarsa, 17 | ) 18 | 19 | 20 | class TestClassicControlPresets(unittest.TestCase): 21 | def test_a2c(self): 22 | self.validate(a2c) 23 | 24 | def test_c51(self): 25 | self.validate(c51) 26 | 27 | def test_ddqn(self): 28 | self.validate(ddqn) 29 | 30 | def test_dqn(self): 31 | self.validate(dqn) 32 | 33 | def test_ppo(self): 34 | self.validate(ppo) 35 | 36 | def test_rainbow(self): 37 | self.validate(rainbow) 38 | 39 | def test_vac(self): 40 | self.validate(vac) 41 | 42 | def test_vpg(self): 43 | self.validate(vpg) 44 | 45 | def test_vsarsa(self): 46 | self.validate(vsarsa) 47 | 48 | def test_vqn(self): 49 | self.validate(vqn) 50 | 51 | def validate(self, builder): 52 | validate_agent(builder.device("cpu"), GymEnvironment("CartPole-v0")) 53 | 54 | 55 | if __name__ == "__main__": 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /integration/continuous_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from validate_agent import validate_agent 4 | 5 | from all.environments import GymEnvironment, MujocoEnvironment, PybulletEnvironment 6 | from all.presets.continuous import ddpg, ppo, sac 7 | 8 | 9 | class TestContinuousPresets(unittest.TestCase): 10 | def test_ddpg(self): 11 | validate_agent( 12 | ddpg.device("cpu").hyperparameters(replay_start_size=50), 13 | GymEnvironment("MountainCarContinuous-v0"), 14 | ) 15 | 16 | def test_ppo(self): 17 | validate_agent(ppo.device("cpu"), GymEnvironment("MountainCarContinuous-v0")) 18 | 19 | def test_sac(self): 20 | validate_agent( 21 | sac.device("cpu").hyperparameters(replay_start_size=50), 22 | GymEnvironment("MountainCarContinuous-v0"), 23 | ) 24 | 25 | def test_mujoco(self): 26 | 
validate_agent( 27 | sac.device("cpu").hyperparameters(replay_start_size=50), 28 | MujocoEnvironment("HalfCheetah-v4"), 29 | ) 30 | 31 | def test_pybullet(self): 32 | validate_agent( 33 | sac.device("cpu").hyperparameters(replay_start_size=50), 34 | PybulletEnvironment("cheetah"), 35 | ) 36 | 37 | 38 | if __name__ == "__main__": 39 | unittest.main() 40 | -------------------------------------------------------------------------------- /integration/multiagent_atari_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | from validate_agent import validate_multiagent 5 | 6 | from all.environments import MultiagentAtariEnv 7 | from all.presets import IndependentMultiagentPreset 8 | from all.presets.atari import dqn 9 | 10 | CPU = torch.device("cpu") 11 | if torch.cuda.is_available(): 12 | CUDA = torch.device("cuda") 13 | else: 14 | print( 15 | "WARNING: CUDA is not available!", 16 | "Running presets in cpu mode.", 17 | "Enable CUDA for full test coverage!", 18 | ) 19 | CUDA = torch.device("cpu") 20 | 21 | 22 | class TestMultiagentAtariPresets(unittest.TestCase): 23 | def test_independent(self): 24 | env = MultiagentAtariEnv("pong_v3", max_cycles=1000, device=CPU) 25 | presets = { 26 | agent_id: dqn.device(CPU).env(env.subenvs[agent_id]).build() 27 | for agent_id in env.agents 28 | } 29 | validate_multiagent( 30 | IndependentMultiagentPreset("independent", CPU, presets), env 31 | ) 32 | 33 | def test_independent_cuda(self): 34 | env = MultiagentAtariEnv("pong_v3", max_cycles=1000, device=CUDA) 35 | presets = { 36 | agent_id: dqn.device(CUDA).env(env.subenvs[agent_id]).build() 37 | for agent_id in env.agents 38 | } 39 | validate_multiagent( 40 | IndependentMultiagentPreset("independent", CUDA, presets), env 41 | ) 42 | 43 | 44 | if __name__ == "__main__": 45 | unittest.main() 46 | -------------------------------------------------------------------------------- /integration/validate_agent.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from all.experiments import ( 4 | MultiagentEnvExperiment, 5 | ParallelEnvExperiment, 6 | SingleEnvExperiment, 7 | ) 8 | from all.logging import DummyLogger 9 | from all.presets import ParallelPreset 10 | 11 | 12 | class TestSingleEnvExperiment(SingleEnvExperiment): 13 | def _make_logger(self, logdir, agent_name, env_name, verbose): 14 | os.makedirs(logdir, exist_ok=True) 15 | return DummyLogger() 16 | 17 | 18 | class TestParallelEnvExperiment(ParallelEnvExperiment): 19 | def _make_logger(self, logdir, agent_name, env_name, verbose): 20 | os.makedirs(logdir, exist_ok=True) 21 | return DummyLogger() 22 | 23 | 24 | class TestMultiagentEnvExperiment(MultiagentEnvExperiment): 25 | def _make_logger(self, logdir, agent_name, env_name, verbose): 26 | os.makedirs(logdir, exist_ok=True) 27 | return DummyLogger() 28 | 29 | 30 | def validate_agent(agent, env): 31 | preset = agent.env(env).build() 32 | if isinstance(preset, ParallelPreset): 33 | experiment = TestParallelEnvExperiment(preset, env, quiet=True) 34 | else: 35 | experiment = TestSingleEnvExperiment(preset, env, quiet=True) 36 | experiment.train(episodes=2) 37 | experiment.test(episodes=2) 38 | 39 | 40 | def validate_multiagent(preset, env): 41 | experiment = TestMultiagentEnvExperiment(preset, env, quiet=True) 42 | experiment.train(episodes=2) 43 | experiment.test(episodes=2) 44 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | GYMNASIUM_VERSION = "0.29.1" 4 | PETTINGZOO_VERSION = "1.24.3" 5 | 6 | 7 | extras = { 8 | "atari": [ 9 | f"gymnasium[atari, accept-rom-license]~={GYMNASIUM_VERSION}", 10 | ], 11 | "pybullet": [ 12 | "pybullet>=3.2.2,<4", 13 | "gym>=0.10.0,<0.26.0", 14 | ], 15 | "mujoco": [ 16 | f"gymnasium[mujoco]~={GYMNASIUM_VERSION}", 17 | ], 18 | "ma-atari": [ 19 | f"PettingZoo[atari, accept-rom-license]~={PETTINGZOO_VERSION}", 20 | "supersuit~=3.9.2", 21 | ], 22 | "test": [ 23 | "black~=24.2.0", # linting/formatting 24 | "isort~=5.13.2", # sort imports 25 | "flake8~=7.0.0", # more linting 26 | "torch-testing==0.0.2", # pytorch assertion library 27 | ], 28 | "docs": [ 29 | "sphinx~=7.2.6", # documentation library 30 | "sphinx-autobuild~=2024.2.4", # documentation live reload 31 | "sphinx-rtd-theme~=2.0.0", # documentation theme 32 | "sphinx-automodapi~=0.17.0", # autogenerate docs for modules 33 | ], 34 | } 35 | 36 | extras["all"] = ( 37 | extras["atari"] + extras["mujoco"] + extras["pybullet"] + extras["ma-atari"] 38 | ) 39 | extras["dev"] = extras["all"] + extras["test"] 40 | 41 | setup( 42 | name="autonomous-learning-library", 43 | version="0.9.1", 44 | description=("A library for building reinforcement learning agents in Pytorch"), 45 | packages=find_packages(), 46 | url="https://github.com/cpnota/autonomous-learning-library.git", 47 | author="Chris Nota", 48 | author_email="cnota@cs.umass.edu", 49 | entry_points={ 50 | "console_scripts": [ 51 | "all-plot=all.scripts.plot:main", 52 | "all-atari=all.scripts.train_atari:main", 53 | "all-classic=all.scripts.train_classic:main", 54 | "all-continuous=all.scripts.train_continuous:main", 55 | "all-mujoco=all.scripts.train_mujoco:main", 56 | "all-multiagent-atari=all.scripts.train_multiagent_atari:main", 57 | "all-pybullet=all.scripts.train_pybullet:main", 58 | "all-watch-atari=all.scripts.watch_atari:main", 59 | "all-watch-classic=all.scripts.watch_classic:main", 60 | "all-watch-continuous=all.scripts.watch_continuous:main", 61 | "all-watch-mujoco=all.scripts.watch_mujoco:main", 62 | "all-watch-multiagent-atari=all.scripts.watch_multiagent_atari:main", 63 | "all-watch-pybullet=all.scripts.watch_pybullet:main", 64 | ], 65 | }, 66 | install_requires=[ 67 | f"gymnasium~={GYMNASIUM_VERSION}", # common environment interface 68 | "numpy~=1.22", # math library 69 | "matplotlib~=3.7", # plotting library 70 | "opencv-python-headless~=4.0", # used by atari wrappers 71 | "torch~=2.2", # core deep learning library 72 | "tensorboard~=2.8", # logging and visualization 73 | "cloudpickle~=2.0", # used to copy environments 74 | ], 75 | extras_require=extras, 76 | ) 77 | --------------------------------------------------------------------------------
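To tie the pieces of this snapshot together, the sketch below reproduces in plain Python roughly what the ``all-classic`` console script (declared in setup.py and implemented by all/scripts/train_classic.py and all/scripts/train.py) does when run with a hyperparameter override. It is a minimal illustration that uses only APIs appearing in the files above (``GymEnvironment``, the preset builder's ``device``/``hyperparameters``/``name`` methods, ``run_experiment``, and ``plot_returns_100``); the environment, frame budget, and ``replay_buffer_size`` value are arbitrary choices for the example, not recommended settings.

# sketch.py -- illustrative only; mirrors examples/experiment.py and all/scripts/train.py above
from all.environments import GymEnvironment
from all.experiments import plot_returns_100, run_experiment
from all.presets.classic_control import dqn


def main():
    device = "cpu"  # assumption: run on CPU for portability; pass "cuda" if a GPU is available
    frames = 20000  # small, arbitrary training budget for demonstration purposes

    # construct the environment, as all/scripts/train.py does from its positional "env" argument
    env = GymEnvironment("CartPole-v0", device=device)

    # configure the preset builder: device, a single hyperparameter override, and a run name
    agent = (
        dqn.device(device)
        .hyperparameters(replay_buffer_size=1000)
        .name("dqn-sketch")
    )

    # run the experiment and plot the smoothed returns, as in examples/experiment.py
    run_experiment([agent], [env], frames)
    plot_returns_100("runs", timesteps=frames)


if __name__ == "__main__":
    main()

The closest command-line equivalent, using the entry points declared in setup.py, would be roughly ``all-classic CartPole-v0 dqn --device cpu --frames 20000 --hyperparameters replay_buffer_size=1000``, since all/scripts/train.py splits each ``key=value`` pair and casts the value to the type of the preset's corresponding default hyperparameter.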