├── .gitignore ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── FAQ.md ├── LICENSE ├── README.md ├── docs ├── .DS_Store ├── doctrees │ ├── algorithms │ │ └── index.doctree │ ├── api │ │ ├── index.doctree │ │ ├── rlstructures.doctree │ │ └── rlstructures.env_wrappers.doctree │ ├── autoapi │ │ ├── conf │ │ │ └── index.doctree │ │ └── index.doctree │ ├── deprecated │ │ ├── deprecated.doctree │ │ ├── index.doctree │ │ └── tutorial │ │ │ ├── a2c.doctree │ │ │ ├── hierarchical_policy.doctree │ │ │ ├── index.doctree │ │ │ ├── recurrent_policy.doctree │ │ │ ├── reinforce.doctree │ │ │ ├── reinforce_with_evaluation.doctree │ │ │ └── transformer_policy.doctree │ ├── environment.pickle │ ├── foireaq │ │ └── foireaq.doctree │ ├── gettingstarted │ │ ├── BatcherExamples.doctree │ │ ├── DataStructures.doctree │ │ ├── Environments.doctree │ │ ├── PlayingWithRLStructures.doctree │ │ ├── RLAgentAndBatcher.doctree │ │ └── index.doctree │ ├── index.doctree │ ├── migrating_v0.1_v0.2.doctree │ └── overview.doctree ├── html │ ├── .buildinfo │ ├── .nojekyll │ ├── _modules │ │ ├── index.html │ │ └── rlstructures │ │ │ ├── core.html │ │ │ ├── env.html │ │ │ ├── env_wrappers │ │ │ └── gymenv.html │ │ │ └── rl_batchers │ │ │ ├── agent.html │ │ │ └── batcher.html │ ├── _sources │ │ ├── algorithms │ │ │ └── index.rst.txt │ │ ├── api │ │ │ ├── index.rst.txt │ │ │ ├── rlstructures.env_wrappers.rst.txt │ │ │ └── rlstructures.rst.txt │ │ ├── autoapi │ │ │ ├── conf │ │ │ │ └── index.rst.txt │ │ │ └── index.rst.txt │ │ ├── deprecated │ │ │ ├── deprecated.rst.txt │ │ │ ├── index.rst.txt │ │ │ └── tutorial │ │ │ │ ├── a2c.rst.txt │ │ │ │ ├── hierarchical_policy.rst.txt │ │ │ │ ├── index.rst.txt │ │ │ │ ├── recurrent_policy.rst.txt │ │ │ │ ├── reinforce.rst.txt │ │ │ │ ├── reinforce_with_evaluation.rst.txt │ │ │ │ └── transformer_policy.rst.txt │ │ ├── foireaq │ │ │ └── foireaq.rst.txt │ │ ├── gettingstarted │ │ │ ├── BatcherExamples.rst.txt │ │ │ ├── DataStructures.rst.txt │ │ │ ├── Environments.rst.txt │ │ │ ├── PlayingWithRLStructures.rst.txt │ │ │ ├── RLAgentAndBatcher.rst.txt │ │ │ └── index.rst.txt │ │ ├── index.rst.txt │ │ ├── migrating_v0.1_v0.2.rst.txt │ │ └── overview.rst.txt │ ├── _static │ │ ├── basic.css │ │ ├── css │ │ │ ├── badge_only.css │ │ │ ├── fonts │ │ │ │ ├── Roboto-Slab-Bold.woff │ │ │ │ ├── Roboto-Slab-Bold.woff2 │ │ │ │ ├── Roboto-Slab-Regular.woff │ │ │ │ ├── Roboto-Slab-Regular.woff2 │ │ │ │ ├── fontawesome-webfont.eot │ │ │ │ ├── fontawesome-webfont.svg │ │ │ │ ├── fontawesome-webfont.ttf │ │ │ │ ├── fontawesome-webfont.woff │ │ │ │ ├── fontawesome-webfont.woff2 │ │ │ │ ├── lato-bold-italic.woff │ │ │ │ ├── lato-bold-italic.woff2 │ │ │ │ ├── lato-bold.woff │ │ │ │ ├── lato-bold.woff2 │ │ │ │ ├── lato-normal-italic.woff │ │ │ │ ├── lato-normal-italic.woff2 │ │ │ │ ├── lato-normal.woff │ │ │ │ └── lato-normal.woff2 │ │ │ └── theme.css │ │ ├── doctools.js │ │ ├── documentation_options.js │ │ ├── file.png │ │ ├── fonts │ │ │ ├── FontAwesome.otf │ │ │ ├── Lato │ │ │ │ ├── lato-bold.eot │ │ │ │ ├── lato-bold.ttf │ │ │ │ ├── lato-bold.woff │ │ │ │ ├── lato-bold.woff2 │ │ │ │ ├── lato-bolditalic.eot │ │ │ │ ├── lato-bolditalic.ttf │ │ │ │ ├── lato-bolditalic.woff │ │ │ │ ├── lato-bolditalic.woff2 │ │ │ │ ├── lato-italic.eot │ │ │ │ ├── lato-italic.ttf │ │ │ │ ├── lato-italic.woff │ │ │ │ ├── lato-italic.woff2 │ │ │ │ ├── lato-regular.eot │ │ │ │ ├── lato-regular.ttf │ │ │ │ ├── lato-regular.woff │ │ │ │ └── lato-regular.woff2 │ │ │ ├── Roboto-Slab-Bold.woff │ │ │ ├── Roboto-Slab-Bold.woff2 │ │ 
│ ├── Roboto-Slab-Light.woff │ │ │ ├── Roboto-Slab-Light.woff2 │ │ │ ├── Roboto-Slab-Regular.woff │ │ │ ├── Roboto-Slab-Regular.woff2 │ │ │ ├── Roboto-Slab-Thin.woff │ │ │ ├── Roboto-Slab-Thin.woff2 │ │ │ ├── RobotoSlab │ │ │ │ ├── roboto-slab-v7-bold.eot │ │ │ │ ├── roboto-slab-v7-bold.ttf │ │ │ │ ├── roboto-slab-v7-bold.woff │ │ │ │ ├── roboto-slab-v7-bold.woff2 │ │ │ │ ├── roboto-slab-v7-regular.eot │ │ │ │ ├── roboto-slab-v7-regular.ttf │ │ │ │ ├── roboto-slab-v7-regular.woff │ │ │ │ └── roboto-slab-v7-regular.woff2 │ │ │ ├── fontawesome-webfont.eot │ │ │ ├── fontawesome-webfont.svg │ │ │ ├── fontawesome-webfont.ttf │ │ │ ├── fontawesome-webfont.woff │ │ │ ├── fontawesome-webfont.woff2 │ │ │ ├── lato-bold-italic.woff │ │ │ ├── lato-bold-italic.woff2 │ │ │ ├── lato-bold.woff │ │ │ ├── lato-bold.woff2 │ │ │ ├── lato-normal-italic.woff │ │ │ ├── lato-normal-italic.woff2 │ │ │ ├── lato-normal.woff │ │ │ └── lato-normal.woff2 │ │ ├── graphviz.css │ │ ├── jquery-3.5.1.js │ │ ├── jquery.js │ │ ├── js │ │ │ ├── badge_only.js │ │ │ ├── html5shiv-printshiv.min.js │ │ │ ├── html5shiv.min.js │ │ │ ├── modernizr.min.js │ │ │ └── theme.js │ │ ├── language_data.js │ │ ├── minus.png │ │ ├── plus.png │ │ ├── pygments.css │ │ ├── searchtools.js │ │ ├── underscore-1.3.1.js │ │ └── underscore.js │ ├── algorithms │ │ └── index.html │ ├── api │ │ ├── index.html │ │ ├── rlstructures.env_wrappers.html │ │ └── rlstructures.html │ ├── autoapi │ │ ├── conf │ │ │ └── index.html │ │ └── index.html │ ├── deprecated │ │ ├── deprecated.html │ │ ├── index.html │ │ └── tutorial │ │ │ ├── a2c.html │ │ │ ├── hierarchical_policy.html │ │ │ ├── index.html │ │ │ ├── recurrent_policy.html │ │ │ ├── reinforce.html │ │ │ ├── reinforce_with_evaluation.html │ │ │ └── transformer_policy.html │ ├── foireaq │ │ └── foireaq.html │ ├── genindex.html │ ├── gettingstarted │ │ ├── BatcherExamples.html │ │ ├── DataStructures.html │ │ ├── Environments.html │ │ ├── PlayingWithRLStructures.html │ │ ├── RLAgentAndBatcher.html │ │ └── index.html │ ├── index.html │ ├── migrating_v0.1_v0.2.html │ ├── objects.inv │ ├── overview.html │ ├── py-modindex.html │ ├── search.html │ └── searchindex.js └── images │ └── batchers.jpg ├── requirements.txt ├── rlalgos ├── README.md ├── __init__.py ├── a2c_gae │ ├── README.md │ ├── __init__.py │ ├── a2c.py │ ├── agent.py │ ├── atari_agent.py │ ├── main_atari.py │ └── main_cartpole.py ├── atari_wrappers.py ├── deprecated │ ├── README.md │ ├── __init__.py │ ├── a2c │ │ ├── __init__.py │ │ ├── a2c_episodes.py │ │ ├── agent.py │ │ ├── run_cartpole.py │ │ └── run_cartpole_pomdp.py │ ├── dqn │ │ ├── __init__.py │ │ ├── agent.py │ │ ├── duelling_dqn.py │ │ └── run_q_cartpole.py │ ├── envs │ │ ├── __init__.py │ │ └── continuouscartopole.py │ ├── ppo │ │ ├── __init__.py │ │ ├── discrete_ppo.py │ │ └── run_cartpole.py │ ├── sac │ │ ├── __init__.py │ │ ├── agent.py │ │ ├── run_cartpole.py │ │ └── sac.py │ └── template_exp.py ├── dqn │ ├── README.md │ ├── __init__.py │ ├── agent.py │ ├── duelling_dqn.py │ ├── run_atari.py │ └── run_cartpole.py ├── logger.py ├── ppo │ ├── __init__.py │ ├── discrete_ppo.py │ └── run_cartpole.py ├── reinforce │ ├── agent.py │ ├── reinforce.py │ └── run_reinforce.py ├── reinforce_device │ ├── README.md │ ├── reinforce.py │ └── run_reinforce.py ├── reinforce_diayn │ ├── README.md │ ├── agent.py │ ├── reinforce_diayn.py │ └── run_diayn.py ├── sac │ ├── __init__.py │ ├── agent.py │ ├── continuouscartopole.py │ ├── run_cartpole.py │ └── sac.py ├── simple_ddqn │ ├── README.md │ ├── __init__.py │ 
├── agent.py │ ├── ddqn.py │ └── run_cartpole.py └── tools.py ├── rlstructures ├── __init__.py ├── core.py ├── deprecated │ ├── __init__.py │ ├── agent.py │ ├── batchers │ │ ├── __init__.py │ │ ├── agent_fxn.py │ │ ├── buffers.py │ │ ├── episodebatchers.py │ │ ├── threadworker.py │ │ └── trajectorybatchers.py │ └── logging.py ├── env.py ├── env_wrappers │ ├── __init__.py │ ├── devicewrapper.py │ └── gymenv.py └── rl_batchers │ ├── __init__.py │ ├── agent.py │ ├── batcher.py │ └── tools.py ├── setup.py ├── sphinx_docs ├── Makefile └── source │ ├── algorithms │ └── index.rst │ ├── api │ ├── index.rst │ ├── rlstructures.env_wrappers.rst │ └── rlstructures.rst │ ├── conf.py │ ├── deprecated │ ├── deprecated.rst │ ├── index.rst │ └── tutorial │ │ ├── a2c.rst │ │ ├── hierarchical_policy.rst │ │ ├── index.rst │ │ ├── recurrent_policy.rst │ │ ├── reinforce.rst │ │ ├── reinforce_with_evaluation.rst │ │ └── transformer_policy.rst │ ├── foireaq │ └── foireaq.rst │ ├── gettingstarted │ ├── BatcherExamples.rst │ ├── DataStructures.rst │ ├── Environments.rst │ ├── PlayingWithRLStructures.rst │ ├── RLAgentAndBatcher.rst │ └── index.rst │ ├── index.rst │ ├── migrating_v0.1_v0.2.rst │ └── overview.rst └── tutorial ├── __init__.py ├── deprecated ├── deprecated_tutorial_agent.py ├── deprecated_tutorial_multiprocess_episode_batcher.py ├── deprecated_tutorial_multiprocess_trajectory_batcher.py ├── tutorial_a2c_with_infinite_env │ ├── __init__.py │ ├── a2c.py │ └── main_a2c.py ├── tutorial_from_reinforce_to_a2c │ ├── __init__.py │ ├── a2c.py │ └── main_a2c.py ├── tutorial_from_reinforce_to_a2c_s │ ├── __init__.py │ ├── a2c.py │ ├── agent.py │ └── main_a2c.py ├── tutorial_recurrent_a2c_gae_s │ ├── __init__.py │ ├── a2c.py │ └── main_a2c.py ├── tutorial_recurrent_a2c_s │ ├── __init__.py │ ├── a2c.py │ ├── agent.py │ ├── main_a2c.py │ └── test.py ├── tutorial_recurrent_policy │ ├── __init__.py │ ├── a2c.py │ ├── agent.py │ └── main_a2c.py ├── tutorial_reinforce │ ├── __init__.py │ ├── agent.py │ ├── main_reinforce.py │ └── reinforce.py ├── tutorial_reinforce_s │ ├── __init__.py │ ├── agent.py │ ├── main_reinforce.py │ └── reinforce.py ├── tutorial_reinforce_with_evaluation │ ├── __init__.py │ ├── main_reinforce.py │ └── reinforce.py └── tutorial_reinforce_with_evaluation_s │ ├── __init__.py │ ├── agent.py │ ├── main_reinforce.py │ └── reinforce.py ├── playing_with_envs.py ├── playing_with_rlstructures.py ├── tutorial_datastructures.py ├── tutorial_environments.py └── tutorial_rlagent.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | outputs/ 3 | results/ 4 | **/.ipynb_checkpoints 5 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # (January, 2021) 2 | 3 | * Initial Release 4 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Open Source Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, 
race, religion, or sexual identity and orientation. 6 | 7 | ## Our Standards 8 | 9 | Examples of behavior that contributes to creating a positive environment include: 10 | 11 | Using welcoming and inclusive language 12 | Being respectful of differing viewpoints and experiences 13 | Gracefully accepting constructive criticism 14 | Focusing on what is best for the community 15 | Showing empathy towards other community members 16 | Examples of unacceptable behavior by participants include: 17 | 18 | The use of sexualized language or imagery and unwelcome sexual attention or advances 19 | Trolling, insulting/derogatory comments, and personal or political attacks 20 | Public or private harassment 21 | Publishing others’ private information, such as a physical or electronic address, without explicit permission 22 | Other conduct which could reasonably be considered inappropriate in a professional setting 23 | 24 | ## Our Responsibilities 25 | 26 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 27 | 28 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 29 | 30 | ## Scope 31 | 32 | This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 33 | 34 | ## Enforcement 35 | 36 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at opensource-conduct@fb.com. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 37 | 38 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project’s leadership. 39 | 40 | ## Attribution 41 | 42 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 43 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 44 | 45 | [homepage]: https://www.contributor-covenant.org 46 | 47 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to RLStructures 2 | 3 | We want to make contributing to this project as easy and transparent as 4 | possible. 5 | 6 | ## Pull Requests 7 | 8 | We actively welcome your pull requests. 9 | 10 | 1. Fork the repo and create your branch from `main`. 11 | 2. If you've added code that should be tested, add tests. 12 | 3. If you've changed APIs, update the documentation. 13 | 4. 
Ensure the test suite passes. 14 | 5. Make sure your code lints. 15 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 16 | 17 | ## Contributor License Agreement ("CLA") 18 | 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Facebook's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | 26 | We use GitHub issues to track public bugs. Please ensure your description is 27 | clear and has sufficient instructions to be able to reproduce the issue. 28 | 29 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 30 | disclosure of security bugs. In those cases, please go through the process 31 | outlined on that page and do not file a public issue. 32 | 33 | ## License 34 | By contributing to RLStructures, you agree that your contributions will be licensed 35 | under the LICENSE file in the root directory of this source tree. 36 | -------------------------------------------------------------------------------- /FAQ.md: -------------------------------------------------------------------------------- 1 | Overview 2 | ======== 3 | 4 | Classical Deep Learning (DL) algorithms optimize a loss function defined over a training dataset. The implementation of such a training loop is quite simple. 5 | Moreover, the computation can easily be sped up by (i) using GPUs for the loss and gradient computations, and (ii) using multi-process DataLoaders to deal with large amounts of data. 6 | 7 | *Can we do the same in Reinforcement Learning? This is the objective of RLStructures.* 8 | 9 | The main difference between classical DL approaches and Reinforcement Learning ones is in the way data is acquired. While in DL batches come from a dataset, in RL the learning data comes from the interaction between policies and an environment. Moreover, the nature of the collected information is usually more complex, structured as sequences, and involving multiple variables (e.g. observations, actions, internal states of the policies, etc.), particularly when considering complex policies like hierarchical ones, mixtures, etc. 10 | 11 | RLStructures is a library focused on making the implementation of RL algorithms **as simple as possible**, providing tools that allow users to write RL learning loops easily. Indeed, RLStructures takes care of simulating the interactions between multiple policies and multiple environments at scale and returns a simple data structure on which loss computation is easy. **RLStructures is not a library of RL algorithms** (even if some algorithms are provided to illustrate how it can be used in practical cases) and can be used in any setting where policies interact with environments, and where **the user needs to do it at scale**, including unsupervised RL, meta-RL, multitask RL, etc. 12 | 13 | RLStructures is based on three components that have been made as easy as possible to facilitate the use of the library: 14 | 15 | * A generic data structure (DictTensor) encoding (batches of) information exchanged between the agent and the environment. In addition, we provide a temporal version of this structure (TemporalDictTensor) encoding (batches of) sequences of DictTensor, the sequences being of various lengths. 16 | * An Agent API allowing the implementation of complex agents (with or without PyTorch models) with complex outputs and internal states (e.g. hierarchical agents, transformer-based agents, etc.). The API also allows the user to specify which information is stored to allow future computations. 17 | * A set of Batchers, where a batcher handles the execution of multiple agents over multiple environments and produces as an output a data structure (TemporalDictTensor) that is easily usable for computation. By using multi-process batchers, the user can easily scale his/her algorithm to multiple environments and multiple policies on multiple cores. Moreover, since batchers can be executed asynchronously, the user can collect interactions between agents and environments while executing other computations in parallel. 18 |
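As a quick, concrete illustration of the first component, here is a minimal sketch (mirroring the examples from the Getting Started pages further below) of how a DictTensor and a TemporalDictTensor are built; the tensor names and shapes are arbitrary:

```python
import torch
from rlstructures import DictTensor, TemporalDictTensor

# A batch of 3 elements; the first dimension of every tensor is the batch dimension
d = DictTensor({"x": torch.randn(3, 5), "y": torch.randn(3, 8)})
print(d.n_elems())  # number of elements in the batch

# A batch of 3 sequences (max length 10) with individual lengths 6, 10 and 3
t = TemporalDictTensor(
    {"x": torch.randn(3, 10, 5), "y": torch.randn(3, 10, 8)},
    lengths=torch.tensor([6, 10, 3]),
)
print(t.lengths)  # lengths of the sequences
print(t.mask())   # mask over the valid time steps of each sequence
```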
19 | Note that multiple GPUs and CPUs can be used (the provided examples run the batchers on CPUs, and the learning algorithm on CPU or GPU). 20 | 21 | Moreover, RLStructures provides: 22 | 23 | * A set of classical algorithms (A2C, PPO, DDQN and SAC) 24 | * A simple logger allowing users to monitor results in TensorBoard, but also to generate CSV files for future analysis 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | 23 | -------------------------------------------------------------------------------- /docs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/.DS_Store -------------------------------------------------------------------------------- /docs/doctrees/algorithms/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/algorithms/index.doctree -------------------------------------------------------------------------------- /docs/doctrees/api/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/api/index.doctree -------------------------------------------------------------------------------- /docs/doctrees/api/rlstructures.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/api/rlstructures.doctree -------------------------------------------------------------------------------- /docs/doctrees/api/rlstructures.env_wrappers.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/api/rlstructures.env_wrappers.doctree -------------------------------------------------------------------------------- /docs/doctrees/autoapi/conf/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/autoapi/conf/index.doctree -------------------------------------------------------------------------------- /docs/doctrees/autoapi/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/autoapi/index.doctree -------------------------------------------------------------------------------- /docs/doctrees/deprecated/deprecated.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/deprecated/deprecated.doctree -------------------------------------------------------------------------------- /docs/doctrees/deprecated/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/deprecated/index.doctree -------------------------------------------------------------------------------- /docs/doctrees/deprecated/tutorial/a2c.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/deprecated/tutorial/a2c.doctree -------------------------------------------------------------------------------- 
/docs/doctrees/deprecated/tutorial/hierarchical_policy.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/deprecated/tutorial/hierarchical_policy.doctree -------------------------------------------------------------------------------- /docs/doctrees/deprecated/tutorial/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/deprecated/tutorial/index.doctree -------------------------------------------------------------------------------- /docs/doctrees/deprecated/tutorial/recurrent_policy.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/deprecated/tutorial/recurrent_policy.doctree -------------------------------------------------------------------------------- /docs/doctrees/deprecated/tutorial/reinforce.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/deprecated/tutorial/reinforce.doctree -------------------------------------------------------------------------------- /docs/doctrees/deprecated/tutorial/reinforce_with_evaluation.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/deprecated/tutorial/reinforce_with_evaluation.doctree -------------------------------------------------------------------------------- /docs/doctrees/deprecated/tutorial/transformer_policy.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/deprecated/tutorial/transformer_policy.doctree -------------------------------------------------------------------------------- /docs/doctrees/environment.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/environment.pickle -------------------------------------------------------------------------------- /docs/doctrees/foireaq/foireaq.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/foireaq/foireaq.doctree -------------------------------------------------------------------------------- /docs/doctrees/gettingstarted/BatcherExamples.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/gettingstarted/BatcherExamples.doctree -------------------------------------------------------------------------------- /docs/doctrees/gettingstarted/DataStructures.doctree: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/gettingstarted/DataStructures.doctree -------------------------------------------------------------------------------- /docs/doctrees/gettingstarted/Environments.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/gettingstarted/Environments.doctree -------------------------------------------------------------------------------- /docs/doctrees/gettingstarted/PlayingWithRLStructures.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/gettingstarted/PlayingWithRLStructures.doctree -------------------------------------------------------------------------------- /docs/doctrees/gettingstarted/RLAgentAndBatcher.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/gettingstarted/RLAgentAndBatcher.doctree -------------------------------------------------------------------------------- /docs/doctrees/gettingstarted/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/gettingstarted/index.doctree -------------------------------------------------------------------------------- /docs/doctrees/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/index.doctree -------------------------------------------------------------------------------- /docs/doctrees/migrating_v0.1_v0.2.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/migrating_v0.1_v0.2.doctree -------------------------------------------------------------------------------- /docs/doctrees/overview.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/doctrees/overview.doctree -------------------------------------------------------------------------------- /docs/html/.buildinfo: -------------------------------------------------------------------------------- 1 | # Sphinx build info version 1 2 | # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. 
3 | config: 154ca7e76289922148a5f8e72a97093e 4 | tags: 645f666f9bcd5a90fca523b33c5a78b7 5 | -------------------------------------------------------------------------------- /docs/html/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/.nojekyll -------------------------------------------------------------------------------- /docs/html/_sources/algorithms/index.rst.txt: -------------------------------------------------------------------------------- 1 | Provided Algorithms 2 | =================== 3 | 4 | We provide multiple RL algorithms as examples. 5 | 6 | 1) A2C with Generalized Advantage Estimation 7 | 2) PPO with discrete actions 8 | 3) Double Duelling Q-Learning + Prioritized Experience Replay 9 | 3bis) A simpler DQN implementation (as an example) 10 | 4) SAC for continuous actions 11 | 5) REINFORCE 12 | 6) REINFORCE DIAYN (see https://arxiv.org/abs/1802.06070) 13 | 14 | The algorithms can be used as examples to implement your own algorithms. 15 | 16 | Typical execution is `OMP_NUM_THREADS=1 PYTHONPATH=rlstructures python rlstructures/rlalgos/reinforce/main_reinforce.py` 17 | 18 | Note that all algorithms produce a TensorBoard and a CSV output (see `config["logdir"]` in the main file) 19 | -------------------------------------------------------------------------------- /docs/html/_sources/api/index.rst.txt: -------------------------------------------------------------------------------- 1 | RLStructures API 2 | ================ 3 | .. toctree:: 4 | :maxdepth: 1 5 | :caption: API 6 | 7 | rlstructures 8 | rlstructures.env_wrappers 9 | -------------------------------------------------------------------------------- /docs/html/_sources/api/rlstructures.env_wrappers.rst.txt: -------------------------------------------------------------------------------- 1 | rlstructures.env\_wrappers package 2 | ================================== 3 | 4 | OpenAI Gym Wrappers 5 | ---------------------------------------- 6 | 7 | .. automodule:: rlstructures.env_wrappers.gymenv 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/html/_sources/api/rlstructures.rst.txt: -------------------------------------------------------------------------------- 1 | rlstructures API 2 | ==================== 3 | 4 | DictTensor, TemporalDictTensor, Trajectories 5 | -------------------------------------------- 6 | 7 | .. automodule:: rlstructures.core 8 | :members: 9 | :undoc-members: 10 | 11 | VecEnv 12 | ------ 13 | 14 | .. automodule:: rlstructures.env 15 | :members: 16 | :undoc-members: 17 | 18 | RL_Agent 19 | -------- 20 | 21 | .. automodule:: rlstructures.rl_batchers.agent 22 | :members: 23 | :undoc-members: 24 | 25 | RL_Batcher 26 | ---------- 27 | 28 | .. automodule:: rlstructures.rl_batchers.batcher 29 | :members: 30 | :undoc-members: 31 | -------------------------------------------------------------------------------- /docs/html/_sources/autoapi/conf/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`conf` 2 | =========== 3 | 4 | .. py:module:: conf 5 | 6 | 7 | Module Contents 8 | --------------- 9 | 10 | .. data:: project 11 | :annotation: = RLStructures 12 | 13 | 14 | 15 | .. data:: copyright 16 | :annotation: = 2021, Facebook AI Research 17 | 18 | 19 | 20 | ..
data:: author 21 | :annotation: = Facebook AI Research 22 | 23 | 24 | 25 | .. data:: extensions 26 | :annotation: = ['autoapi.extension', 'sphinx.ext.autodoc', 'sphinx.ext.githubpages', 'sphinx.ext.coverage', 'sphinx.ext.napoleon', 'sphinx.ext.autosummary', 'recommonmark', 'sphinx.ext.viewcode'] 27 | 28 | 29 | 30 | .. data:: autoapi_type 31 | :annotation: = python 32 | 33 | 34 | 35 | .. data:: autoapi_dirs 36 | :annotation: = ['..'] 37 | 38 | 39 | 40 | .. data:: source_suffix 41 | 42 | 43 | 44 | 45 | .. data:: master_doc 46 | :annotation: = index 47 | 48 | 49 | 50 | .. data:: templates_path 51 | :annotation: = [] 52 | 53 | 54 | 55 | .. data:: exclude_patterns 56 | :annotation: = ['_build', 'Thumbs.db', '.DS_Store'] 57 | 58 | 59 | 60 | .. data:: html_theme 61 | :annotation: = sphinx_rtd_theme 62 | 63 | 64 | 65 | .. data:: html_sidebars 66 | 67 | 68 | 69 | 70 | .. data:: html_static_path 71 | :annotation: = [] 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /docs/html/_sources/autoapi/index.rst.txt: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | This page contains auto-generated API reference documentation [#f1]_. 5 | 6 | .. toctree:: 7 | :titlesonly: 8 | 9 | /autoapi/conf/index 10 | 11 | .. [#f1] Created with `sphinx-autoapi `_ -------------------------------------------------------------------------------- /docs/html/_sources/deprecated/index.rst.txt: -------------------------------------------------------------------------------- 1 | Deprecated API (v0.1) 2 | ===================== 3 | 4 | We provide the documentation over deprecated functions. 5 | 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | :caption: Overview 10 | 11 | deprecated 12 | tutorial/ 13 | -------------------------------------------------------------------------------- /docs/html/_sources/deprecated/tutorial/hierarchical_policy.rst.txt: -------------------------------------------------------------------------------- 1 | Hierarchical Policies 2 | ===================== 3 | 4 | (Soon....) 5 | -------------------------------------------------------------------------------- /docs/html/_sources/deprecated/tutorial/index.rst.txt: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========= 3 | 4 | We propose different tutorials to learn `rlstructures`. All the tutorials are available as python files in the repository. 5 | 6 | * `Tutorial Files: ` 7 | 8 | Note that other algorithms are provided in the `rlalgos` package (PPO, DQN and SAC). 9 | 10 | .. 
toctree:: 11 | :maxdepth: 1 12 | :caption: Tutorials 13 | 14 | reinforce 15 | reinforce_with_evaluation 16 | a2c 17 | recurrent_policy 18 | hierarchical_policy 19 | transformer_policy 20 | -------------------------------------------------------------------------------- /docs/html/_sources/deprecated/tutorial/reinforce_with_evaluation.rst.txt: -------------------------------------------------------------------------------- 1 | Evaluation of RL models in other processes 2 | ========================================== 3 | 4 | * https://github.com/facebookresearch/rlstructures/tree/main/tutorial/tutorial_reinforce_with_evaluation 5 | 6 | 7 | Regarding the REINFORCE implementation, one missing aspect is a proper evaluation of the policy: 8 | * the evaluation has to be done with the `deterministic` policy (while learning is done with the stochastic policy) 9 | * the evaluation over N episodes may be long, and we would like to avoid slowing down the learning 10 | 11 | To solve this issue, we will use another batcher in `asynchronous` mode. 12 | 13 | Creation of the evaluation batcher 14 | ---------------------------------- 15 | 16 | The evaluation batcher can be created like the training batcher (but with a different number of threads and slots): 17 | 18 | .. code-block:: python 19 | 20 | model=copy.deepcopy(self.learning_model) 21 | self.evaluation_batcher=EpisodeBatcher( 22 | n_timesteps=self.config["max_episode_steps"], 23 | n_slots=self.config["n_evaluation_episodes"], 24 | create_agent=self._create_agent, 25 | create_env=self._create_env, 26 | env_args={ 27 | "n_envs": self.config["n_envs"], 28 | "max_episode_steps": self.config["max_episode_steps"], 29 | "env_name":self.config["env_name"] 30 | }, 31 | agent_args={"n_actions": self.n_actions, "model": model}, 32 | n_threads=self.config["n_evaluation_threads"], 33 | seeds=[self.config["env_seed"]+k*10 for k in range(self.config["n_evaluation_threads"])], 34 | ) 35 | 36 | Running the evaluation batcher 37 | ------------------------------ 38 | 39 | Running the evaluation batcher is done through `execute`: 40 | 41 | .. code-block:: python 42 | 43 | n_episodes=self.config["n_evaluation_episodes"] 44 | agent_info=DictTensor({"stochastic":torch.tensor([False]).repeat(n_episodes)}) 45 | self.evaluation_batcher.execute(n_episodes=n_episodes,agent_info=agent_info) 46 | self.evaluation_iteration=self.iteration 47 | 48 | Note that we store the iteration at which the evaluation batcher has been executed. 49 | 50 | Getting trajectories without blocking the learning 51 | -------------------------------------------------- 52 | 53 | Now we can get episodes, but in non-blocking mode: the batcher will return `None` if the process of computing episodes is not finished. 54 | If the process is finished, we can 1) compute the reward, 2) update the batcher's model, and 3) relaunch the acquisition process. We thus have an evaluation process that runs without blocking the learning, and at maximum speed. 55 | 56 | ..
code-block:: python 57 | 58 | evaluation_trajectories=self.evaluation_batcher.get(blocking=False) 59 | if not evaluation_trajectories is None: #trajectories are available 60 | #Compute the cumulated reward 61 | cumulated_reward=(evaluation_trajectories["_reward"]*evaluation_trajectories.mask()).sum(1).mean() 62 | self.logger.add_scalar("evaluation_reward",cumulated_reward.item(),self.evaluation_iteration) 63 | #We reexecute the evaluation batcher (with same value of agent_info and same number of episodes) 64 | self.evaluation_batcher.update(self.learning_model.state_dict()) 65 | self.evaluation_iteration=self.iteration 66 | self.evaluation_batcher.reexecute() 67 | -------------------------------------------------------------------------------- /docs/html/_sources/deprecated/tutorial/transformer_policy.rst.txt: -------------------------------------------------------------------------------- 1 | Transformer Policy 2 | ================== 3 | 4 | (Soon...) 5 | -------------------------------------------------------------------------------- /docs/html/_sources/gettingstarted/BatcherExamples.rst.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_sources/gettingstarted/BatcherExamples.rst.txt -------------------------------------------------------------------------------- /docs/html/_sources/gettingstarted/DataStructures.rst.txt: -------------------------------------------------------------------------------- 1 | 2 | Data Structures 3 | =============== 4 | 5 | * https://github.com/facebookresearch/rlstructures/blob/main/tutorial/tutorial_datastructures.py 6 | 7 | 8 | DictTensor 9 | ---------- 10 | 11 | A DictTensor is a dictionary of PyTorch tensors. It assumes that the first dimension of each tensor contained in the DictTensor is the batch dimension. The easiest way to build a DictTensor is to use a dictionary of tensors as input. 12 | 13 | .. code-block:: python 14 | 15 | from rlstructures import DictTensor 16 | import torch 17 | d=DictTensor({"x":torch.randn(3,5),"y":torch.randn(3,8)}) 18 | 19 | The number of elements in the batch is accessible through `n_elems()`: 20 | 21 | .. code-block:: python 22 | 23 | print(d.n_elems()," <- number of elements in the batch") 24 | 25 | An empty DictTensor can be defined as follows: 26 | 27 | .. code-block:: python 28 | 29 | d=DictTensor({}) 30 | 31 | Many methods can be used over DictTensor (see the DictTensor documentation): 32 | 33 | .. code-block:: python 34 | 35 | d["x"] # Returns the tensor 'x' in the DictTensor 36 | d.keys() # Returns the names of the variables of the DictTensor 37 | 38 | Tensors can be organized in a tree structure: 39 | 40 | .. code-block:: python 41 | d=DictTensor({}) 42 | d.set("observation/x",torch.randn(5,3)) 43 | d.set("observation/y",torch.randn(5,8,2)) 44 | d.set("agent_state/z",torch.randn(5,4)) 45 | 46 | observation=d.truncate_key("observation/") #returns a DictTensor with 'x' and 'y' 47 | print(observation) 48 | 49 | 50 | TemporalDictTensor 51 | ------------------ 52 | 53 | A `TemporalDictTensor` is a packed sequence of `DictTensors`. In memory, it is stored as a dictionary of tensors, where the first dimension is the batch dimension and the second dimension is the time index. Each element in the batch is a sequence, and two sequences can have different lengths. 54 | 55 | ..
code-block:: python 56 | 57 | from rlstructures import TemporalDictTensor 58 | 59 | #Create three sequences of variables x and y, where the length of the first sequence is 6, the length of the second is 10 and the length of the last sequence is 3 60 | d=TemporalDictTensor({"x":torch.randn(3,10,5),"y":torch.randn(3,10,8)},lengths=torch.tensor([6,10,3])) 61 | 62 | print(d.n_elems()," <- number of elements in the batch") 63 | print(d.lengths,"<- Lengths of the sequences") 64 | print(d["x"].size(),"<- access to the tensor 'x'") 65 | 66 | print("Masking: ") 67 | print(d.mask()) 68 | 69 | print("Slicing (restricting the sequence to some particular temporal indexes) ") 70 | d_slice=d.temporal_slice(0,4) 71 | print(d_slice.lengths) 72 | print(d_slice.mask()) 73 | 74 | `DictTensor` and `TemporalDictTensor` can be moved to CPU/GPU using the *xxx.to(...)* method. 75 | 76 | Trajectories 77 | ------------ 78 | 79 | We recently introduced the `Trajectories` structure, a pair of one DictTensor and one TemporalDictTensor, to represent trajectories: 80 | 81 | .. code-block:: python 82 | trajectories.info #A DictTensor 83 | trajectories.trajectories #A TemporalDictTensor of transitions 84 | 85 | See the `Agent and Batcher` documentation for more details. 86 | -------------------------------------------------------------------------------- /docs/html/_sources/gettingstarted/PlayingWithRLStructures.rst.txt: -------------------------------------------------------------------------------- 1 | Playing with rlstructures 2 | ========================= 3 | 4 | We propose some examples of Batcher use to better understand how it works. The Python file is https://github.com/facebookresearch/rlstructures/blob/main/tutorial/playing_with_rlstructures.py 5 | 6 | 7 | Blocking / non-Blocking batcher execution 8 | ----------------------------------------- 9 | 10 | The `batcher.get` function can be executed in `batcher.get(blocking=True)` or `batcher.get(blocking=False)` modes. 11 | 12 | * In the first mode, `blocking=True`, the program will wait for the batcher to finish its acquisition and will return trajectories. 13 | 14 | * In the second mode, `blocking=False`, the batcher will return `None,None` if the acquisition is not finished. It thus allows other computations to be performed without waiting for the batcher to finish (see the short sketch at the end of this page). 15 | 16 | Replaying an agent over an acquired trajectory 17 | ---------------------------------------------- 18 | 19 | When trajectories have been acquired, the autograd graph is not available (i.e. batchers are launched in `require_grad=False` mode). 20 | It is important to be able to recompute the agent steps on these trajectories. 21 | 22 | We provide the `replay_agent` function to facilitate this `replay`. An example is given in https://github.com/facebookresearch/rlstructures/blob/main/rlalgos/reinforce 23 | 24 | Some other examples of use are given in the A2C and DQN implementations. 25 |
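As a small illustration of the non-blocking mode described above, here is a minimal, hypothetical polling sketch; it assumes an already-constructed batcher named `batcher` whose acquisition has been launched, a `do_other_computation` placeholder for any work performed while waiting (e.g. a learning step on previously acquired data), and the two-value return convention of `batcher.get` mentioned above:

.. code-block:: python

    # Poll the batcher without blocking; do useful work while waiting.
    trajectories, n = batcher.get(blocking=False)
    while trajectories is None:
        do_other_computation()
        trajectories, n = batcher.get(blocking=False)
    # The acquisition is finished: the trajectories are now available for computation.
    print(trajectories.trajectories.lengths)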
-------------------------------------------------------------------------------- /docs/html/_sources/gettingstarted/index.rst.txt: -------------------------------------------------------------------------------- 1 | Getting Started 2 | =============== 3 | 4 | We will explain the main concepts of rlstructures. Each HTML file is associated with a corresponding Python file in the repository. 5 | 6 | * `Tutorial Files: ` 7 | 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | :caption: Overview 12 | 13 | DataStructures 14 | Environments 15 | RLAgentAndBatcher 16 | PlayingWithRLStructures 17 | -------------------------------------------------------------------------------- /docs/html/_sources/index.rst.txt: -------------------------------------------------------------------------------- 1 | rlstructures 2 | ============ 3 | 4 | TL;DR 5 | ----- 6 | `rlstructures` is a lightweight Python library that provides simple APIs as well as data structures that make as few assumptions as possible on the structure of your agent or your task, while allowing the transparent execution of multiple policies on multiple environments in parallel (incl. multiple GPUs). 7 | 8 | Important Note (Feb 2021) 9 | ------------------------- 10 | 11 | Following feedback, we have made changes to the API. The old API still works, but we encourage you to move to the new one. The modifications are: 12 | 13 | * There is now only one Batcher class (called `RL_Batcher`) 14 | 15 | * The format of the trajectories returned by the batcher is different (see the `Getting Started` section and the short sketch at the end of this page) 16 | * The Agent API (`RL_Agent`) is different and simplified 17 | 18 | * We also include a `replay` function to facilitate loss computation 19 | * The principles are exactly the same, and adaptation is easy (and we can help!) 20 | * The API will not change again during the next months. 21 | 22 | Why/What? 23 | --------- 24 | RL research addresses multiple aspects of RL like hierarchical policies, option-based policies, goal-oriented policies, structured input/output spaces, transformer-based policies, etc., and there are currently few tools to handle this diversity of research projects. 25 | 26 | We propose `rlstructures` as a way to: 27 | 28 | * Simulate multiple policies, multiple models and multiple environments simultaneously at scale 29 | 30 | * Define complex loss functions 31 | 32 | * Quickly implement various policy architectures. 33 | 34 | The main RLStructures principle is that the user delegates the sampling of trajectories and episodes to the library, so they can spend most of their time on the interesting part of RL research: developing new models and algorithms. 35 | 36 | `rlstructures` is easy to use: it has very few simple interfaces that can be learned in one hour by reading the tutorials. 37 | 38 | It comes with multiple RL algorithms as examples, including A2C, PPO, DDQN and SAC. 39 | 40 | Please reach out to us if you intend to use it. We will be happy to help, and potentially to implement missing functionalities. 41 | 42 | Targeted users 43 | -------------- 44 | 45 | RLStructures comes with a set of implemented RL algorithms. But rlstructures does not aim at being a repository of benchmarked RL algorithms (other RL libraries do that very well). If your objective is to apply state-of-the-art methods to particular environments, then rlstructures is not the best fit. If your objective is to implement new algorithms, then rlstructures is a good fit. 46 | 47 | Where? 48 | ------ 49 | 50 | * GitHub: http://github.com/facebookresearch/rlstructures 51 | * Tutorials: https://medium.com/@ludovic.den 52 | * Discussion Group: https://www.facebook.com/groups/834804787067021 53 | 54 |
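To give a feel for the new trajectory format mentioned in the note above, here is a minimal, hypothetical sketch of how the output of an `RL_Batcher` can be consumed; the `batcher` variable is assumed to have been constructed and its acquisition launched beforehand, the two-value return of `get` follows the convention used in the Getting Started pages, and the `_reward` key and masking idiom are the ones used in the provided REINFORCE/A2C examples:

.. code-block:: python

    # `batcher` is an already-running RL_Batcher (construction not shown here).
    trajectories, n = batcher.get()        # a Trajectories object (plus an auxiliary value)
    info = trajectories.info               # DictTensor: agent_info, env_info, initial agent state
    data = trajectories.trajectories       # TemporalDictTensor of transitions

    # Cumulated reward per trajectory, masked by the true sequence lengths
    cumulated_reward = (data["_reward"] * data.mask()).sum(1).mean()
    print(cumulated_reward.item())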
.. toctree:: 55 | :maxdepth: 1 56 | :caption: Getting Started 57 | 58 | overview 59 | gettingstarted/index 60 | algorithms/index 61 | api/index 62 | foireaq/foireaq.rst 63 | migrating_v0.1_v0.2 64 | deprecated/index.rst 65 | -------------------------------------------------------------------------------- /docs/html/_sources/migrating_v0.1_v0.2.rst.txt: -------------------------------------------------------------------------------- 1 | rlstructures -- migrating from v0.1 to v0.2 2 | =========================================== 3 | 4 | Version 0.2 of rlstructures has some critical changes: 5 | 6 | From Agent to RL_Agent 7 | ---------------------- 8 | 9 | Policies are now implemented through the RL_Agent class. The two differences are: 10 | 11 | * The RL_Agent class has an `initial_state` method that initializes the state of the agent at reset time (i.e. when you call Batcher.reset). It saves you from handling the state initialization in the `__call__` function. 12 | 13 | * The RL_Agent does not return its `old state` anymore, and just provides the `agent_do` and `new_state` as an output 14 | 15 | From EpisodeBatcher/Batcher to RL_Batcher 16 | ----------------------------------------- 17 | 18 | RL_Batcher is the batcher class that works with RL_Agent: 19 | 20 | * At construction time: 21 | 22 | * There is no need to specify the `n_slots` argument anymore 23 | 24 | * One has to provide examples (with n_elems()==1) of `agent_info` and `env_info` that will be sent to the batcher at construction time 25 | 26 | * You can specify the device of the batcher (default is CPU -- see the CPU/GPU tutorial) 27 | 28 | * At use time: 29 | 30 | * Only three functions are available: `reset`, `execute` and `get` 31 | 32 | * Outputs: 33 | 34 | * The RL_Batcher now outputs a `Trajectories` object composed of `trajectories.info:DictTensor` and `trajectories.trajectories:TemporalDictTensor` 35 | 36 | * `trajectories.info` contains information that is fixed during the trajectory: agent_info, env_info and the initial agent state 37 | 38 | * `trajectories.trajectories` contains information generated by the environment (observations), as well as the actions produced by the Agent 39 | 40 | Replay functions 41 | ---------------- 42 | 43 | We now propose a `replay_agent` function that allows you to easily replay an agent over trajectories (e.g. for loss computation) 44 | -------------------------------------------------------------------------------- /docs/html/_static/css/badge_only.css: -------------------------------------------------------------------------------- 1 | .fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:after,.clearfix:before{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-style:normal;font-weight:400;src:url(fonts/fontawesome-webfont.eot?674f50d287a8c48dc19ba404d20fe713?#iefix) format("embedded-opentype"),url(fonts/fontawesome-webfont.woff2?af7ae505a9eed503f8b8e6982036873e) format("woff2"),url(fonts/fontawesome-webfont.woff?fee66e712a8a08eef5805a46892932ad) format("woff"),url(fonts/fontawesome-webfont.ttf?b06871f281fee6b241d60582ae9369b9) format("truetype"),url(fonts/fontawesome-webfont.svg?912ec66d7572ff821749319396470bde#FontAwesome) format("svg")}.fa:before{font-family:FontAwesome;font-style:normal;font-weight:400;line-height:1}.fa:before,a .fa{text-decoration:inherit}.fa:before,a .fa,li .fa{display:inline-block}li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-.8em}ul.fas li .fa{width:.8em}ul.fas li
.fa-large:before{vertical-align:baseline}.fa-book:before,.icon-book:before{content:"\f02d"}.fa-caret-down:before,.icon-caret-down:before{content:"\f0d7"}.fa-caret-up:before,.icon-caret-up:before{content:"\f0d8"}.fa-caret-left:before,.icon-caret-left:before{content:"\f0d9"}.fa-caret-right:before,.icon-caret-right:before{content:"\f0da"}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;z-index:400}.rst-versions a{color:#2980b9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27ae60}.rst-versions .rst-current-version:after{clear:both;content:"";display:block}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#e74c3c;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#f1c40f;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:grey;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:1px solid #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none;line-height:30px}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book,.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge>.rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width:768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} -------------------------------------------------------------------------------- /docs/html/_static/css/fonts/Roboto-Slab-Bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/css/fonts/Roboto-Slab-Bold.woff -------------------------------------------------------------------------------- /docs/html/_static/css/fonts/Roboto-Slab-Bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/css/fonts/Roboto-Slab-Bold.woff2 -------------------------------------------------------------------------------- /docs/html/_static/css/fonts/Roboto-Slab-Regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/css/fonts/Roboto-Slab-Regular.woff -------------------------------------------------------------------------------- /docs/html/_static/css/fonts/Roboto-Slab-Regular.woff2: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/css/fonts/Roboto-Slab-Regular.woff2 -------------------------------------------------------------------------------- /docs/html/_static/css/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/css/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/html/_static/css/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/css/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /docs/html/_static/css/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/css/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/html/_static/css/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/css/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/html/_static/css/fonts/lato-bold-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/css/fonts/lato-bold-italic.woff -------------------------------------------------------------------------------- /docs/html/_static/css/fonts/lato-bold-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/css/fonts/lato-bold-italic.woff2 -------------------------------------------------------------------------------- /docs/html/_static/css/fonts/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/css/fonts/lato-bold.woff -------------------------------------------------------------------------------- /docs/html/_static/css/fonts/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/css/fonts/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/html/_static/css/fonts/lato-normal-italic.woff: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/css/fonts/lato-normal-italic.woff -------------------------------------------------------------------------------- /docs/html/_static/css/fonts/lato-normal-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/css/fonts/lato-normal-italic.woff2 -------------------------------------------------------------------------------- /docs/html/_static/css/fonts/lato-normal.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/css/fonts/lato-normal.woff -------------------------------------------------------------------------------- /docs/html/_static/css/fonts/lato-normal.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/css/fonts/lato-normal.woff2 -------------------------------------------------------------------------------- /docs/html/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | var DOCUMENTATION_OPTIONS = { 2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), 3 | VERSION: '', 4 | LANGUAGE: 'None', 5 | COLLAPSE_INDEX: false, 6 | BUILDER: 'html', 7 | FILE_SUFFIX: '.html', 8 | LINK_SUFFIX: '.html', 9 | HAS_SOURCE: true, 10 | SOURCELINK_SUFFIX: '.txt', 11 | NAVIGATION_WITH_KEYS: false 12 | }; -------------------------------------------------------------------------------- /docs/html/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/file.png -------------------------------------------------------------------------------- /docs/html/_static/fonts/FontAwesome.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/FontAwesome.otf -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Lato/lato-bold.eot -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Lato/lato-bold.ttf -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-bold.woff: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Lato/lato-bold.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Lato/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-bolditalic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Lato/lato-bolditalic.eot -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-bolditalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Lato/lato-bolditalic.ttf -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-bolditalic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Lato/lato-bolditalic.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-bolditalic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Lato/lato-bolditalic.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-italic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Lato/lato-italic.eot -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-italic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Lato/lato-italic.ttf -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Lato/lato-italic.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Lato/lato-italic.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-regular.eot: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Lato/lato-regular.eot -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Lato/lato-regular.ttf -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Lato/lato-regular.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Lato/lato-regular.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/Roboto-Slab-Bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Roboto-Slab-Bold.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/Roboto-Slab-Bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Roboto-Slab-Bold.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/Roboto-Slab-Light.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Roboto-Slab-Light.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/Roboto-Slab-Light.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Roboto-Slab-Light.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/Roboto-Slab-Regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Roboto-Slab-Regular.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/Roboto-Slab-Regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Roboto-Slab-Regular.woff2 
-------------------------------------------------------------------------------- /docs/html/_static/fonts/Roboto-Slab-Thin.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Roboto-Slab-Thin.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/Roboto-Slab-Thin.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/Roboto-Slab-Thin.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot -------------------------------------------------------------------------------- /docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf -------------------------------------------------------------------------------- /docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot -------------------------------------------------------------------------------- /docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf -------------------------------------------------------------------------------- /docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff -------------------------------------------------------------------------------- 
/docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/html/_static/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /docs/html/_static/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/lato-bold-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/lato-bold-italic.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/lato-bold-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/lato-bold-italic.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/lato-bold.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/lato-normal-italic.woff: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/lato-normal-italic.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/lato-normal-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/lato-normal-italic.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/lato-normal.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/lato-normal.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/lato-normal.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/fonts/lato-normal.woff2 -------------------------------------------------------------------------------- /docs/html/_static/graphviz.css: -------------------------------------------------------------------------------- 1 | /* 2 | * graphviz.css 3 | * ~~~~~~~~~~~~ 4 | * 5 | * Sphinx stylesheet -- graphviz extension. 6 | * 7 | * :copyright: Copyright 2007-2020 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 9 | * 10 | */ 11 | 12 | img.graphviz { 13 | border: 0; 14 | max-width: 100%; 15 | } 16 | 17 | object.graphviz { 18 | max-width: 100%; 19 | } 20 | -------------------------------------------------------------------------------- /docs/html/_static/js/badge_only.js: -------------------------------------------------------------------------------- 1 | !function(e){var t={};function r(n){if(t[n])return t[n].exports;var o=t[n]={i:n,l:!1,exports:{}};return e[n].call(o.exports,o,o.exports,r),o.l=!0,o.exports}r.m=e,r.c=t,r.d=function(e,t,n){r.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:n})},r.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},r.t=function(e,t){if(1&t&&(e=r(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var n=Object.create(null);if(r.r(n),Object.defineProperty(n,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var o in e)r.d(n,o,function(t){return e[t]}.bind(null,o));return n},r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,"a",t),t},r.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r.p="",r(r.s=4)}({4:function(e,t,r){}}); -------------------------------------------------------------------------------- /docs/html/_static/js/html5shiv.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @preserve HTML5 Shiv 3.7.3 | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed 3 | */ 4 | !function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=t.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=t.elements;"string"!=typeof c&&(c=c.join(" 
")),"string"!=typeof a&&(a=a.join(" ")),t.elements=c+" "+a,j(b)}function f(a){var b=s[a[q]];return b||(b={},r++,a[q]=r,s[r]=b),b}function g(a,c,d){if(c||(c=b),l)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():p.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||o.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),l)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return t.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(t,b.frag)}function j(a){a||(a=b);var d=f(a);return!t.shivCSS||k||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),l||i(a,d),a}var k,l,m="3.7.3-pre",n=a.html5||{},o=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,p=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,q="_html5shiv",r=0,s={};!function(){try{var a=b.createElement("a");a.innerHTML="",k="hidden"in a,l=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){k=!0,l=!0}}();var t={elements:n.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:m,shivCSS:n.shivCSS!==!1,supportsUnknownElements:l,shivMethods:n.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=t,j(b),"object"==typeof module&&module.exports&&(module.exports=t)}("undefined"!=typeof window?window:this,document); -------------------------------------------------------------------------------- /docs/html/_static/minus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/minus.png -------------------------------------------------------------------------------- /docs/html/_static/plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/_static/plus.png -------------------------------------------------------------------------------- /docs/html/objects.inv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/html/objects.inv -------------------------------------------------------------------------------- /docs/images/batchers.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/docs/images/batchers.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | gym 4 | tensorboard 5 | -------------------------------------------------------------------------------- /rlalgos/README.md: -------------------------------------------------------------------------------- 1 | # RLStructures: rlalgos 2 | 3 | The *rlalgos* library is a collection of classical RL algorithms coded using *rlstructures* 4 | -------------------------------------------------------------------------------- /rlalgos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/rlalgos/__init__.py -------------------------------------------------------------------------------- /rlalgos/a2c_gae/README.md: -------------------------------------------------------------------------------- 1 | An implementation of A2C with GAE that works both on CPU and GPU for loss computation 2 | -------------------------------------------------------------------------------- /rlalgos/a2c_gae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/rlalgos/a2c_gae/__init__.py -------------------------------------------------------------------------------- /rlalgos/a2c_gae/main_atari.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | from rlstructures.env_wrappers import GymEnv, GymEnvInf 8 | from rlstructures.tools import weight_init 9 | import torch.nn as nn 10 | import copy 11 | import torch 12 | import time 13 | import numpy as np 14 | import torch.nn.functional as F 15 | from rlalgos.a2c_gae.atari_agent import AtariAgent, ActionModel, CriticModel, Model 16 | from rlalgos.a2c_gae.a2c import A2C 17 | import gym 18 | from gym.wrappers import TimeLimit 19 | from gym import ObservationWrapper 20 | from rlalgos.atari_wrappers import make_atari, wrap_deepmind, wrap_pytorch 21 | 22 | 23 | def create_env(n_envs, env_name, max_episode_steps=None, seed=None, **args): 24 | envs = [] 25 | for k in range(n_envs): 26 | e = make_atari(env_name) 27 | e = wrap_deepmind(e) 28 | e = wrap_pytorch(e) 29 | envs.append(e) 30 | return GymEnv(envs, seed) 31 | 32 | 33 | def create_train_env(n_envs, env_name, max_episode_steps=None, seed=None, **args): 34 | envs = [] 35 | for k in range(n_envs): 36 | e = make_atari(env_name) 37 | e = wrap_deepmind(e) 38 | e = wrap_pytorch(e) 39 | envs.append(e) 40 | return GymEnvInf(envs, seed) 41 | 42 | 43 | def create_agent(model, n_actions=1): 44 | return AtariAgent(model=model, n_actions=n_actions) 45 | 46 | 47 | class Experiment(A2C): 48 | def __init__(self, config, create_train_env, create_env, create_agent): 49 | super().__init__(config, create_train_env, create_env, create_agent) 50 | 51 | def _create_model(self): 52 | am = ActionModel(self.obs_shape, self.n_actions) 53 | cm = CriticModel(self.obs_shape) 54 | model = Model(am, cm) 55 | # model.apply(weight_init) 56 | return model 57 | 58 | 59 | if __name__ == "__main__": 60 | # We use spawn mode such that most of the environment will run in multiple processes 61 | import torch.multiprocessing as mp 62 | 63 | mp.set_start_method("spawn") 64 | 65 | config = { 66 | "env_name": "PongNoFrameskip-v4", 67 | "a2c_timesteps": 1, 68 | "n_envs": 4, 69 | "max_episode_steps": 15000, 70 | "env_seed": 42, 71 | "n_processes": 4, 72 | "n_evaluation_processes": 4, 73 | "n_evaluation_envs": 1, 74 | "time_limit": 3600, 75 | "lr": 0.0001, 76 | "discount_factor": 0.95, 77 | "critic_coef": 1.0, 78 | "entropy_coef": 0.01, 79 | "a2c_coef": 1.0, 80 | "gae_coef": 0.3, 81 | "logdir": "./results", 82 | "clip_grad": 0, 83 | "learner_device": "cpu", 84 | "save_every": 1, 85 | "optim": "RMSprop", 86 | } 87 | exp = Experiment(config, create_train_env, create_env, create_agent) 88 | exp.run() 89 | -------------------------------------------------------------------------------- /rlalgos/a2c_gae/main_cartpole.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | from rlstructures.env_wrappers import GymEnv, GymEnvInf 9 | from rlstructures.tools import weight_init 10 | import torch.nn as nn 11 | import copy 12 | import torch 13 | import time 14 | import numpy as np 15 | import torch.nn.functional as F 16 | from rlalgos.a2c_gae.agent import RecurrentAgent, Model, ActionModel, CriticModel 17 | from rlalgos.a2c_gae.a2c import A2C 18 | import gym 19 | from gym.wrappers import TimeLimit 20 | from gym import ObservationWrapper 21 | 22 | 23 | class MyWrapper(ObservationWrapper): 24 | """Observation wrapper that flattens the observation.""" 25 | 26 | def __init__(self, env): 27 | super(MyWrapper, self).__init__(env) 28 | self.observation_space = None # spaces.flatten_space(env.observation_space) 29 | 30 | def observation(self, observation): 31 | return [observation[0], observation[2]] 32 | 33 | 34 | # We write the 'create_env' and 'create_agent' function in the main file to allow these functions to be used with pickle when creating the batcher processes 35 | def create_gym_env(env_name): 36 | return gym.make(env_name) 37 | 38 | 39 | def create_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 40 | envs = [] 41 | for k in range(n_envs): 42 | e = create_gym_env(env_name) 43 | # e = MyWrapper(e) 44 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 45 | envs.append(e) 46 | return GymEnv(envs, seed) 47 | 48 | 49 | def create_train_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 50 | envs = [] 51 | for k in range(n_envs): 52 | e = create_gym_env(env_name) 53 | # e = MyWrapper(e) 54 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 55 | envs.append(e) 56 | return GymEnvInf(envs, seed) 57 | 58 | 59 | def create_agent(model, n_actions=1): 60 | return RecurrentAgent(model=model, n_actions=n_actions) 61 | 62 | 63 | class Experiment(A2C): 64 | def __init__(self, config, create_train_env, create_env, create_agent): 65 | super().__init__(config, create_train_env, create_env, create_agent) 66 | 67 | def _create_model(self): 68 | am = ActionModel(self.obs_shape[1], self.n_actions, self.config["hidden_size"]) 69 | cm = CriticModel(self.obs_shape[1], self.config["hidden_size"]) 70 | model = Model(am, cm) 71 | model.apply(weight_init) 72 | return model 73 | 74 | 75 | if __name__ == "__main__": 76 | # We use spawn mode such that most of the environment will run in multiple processes 77 | import torch.multiprocessing as mp 78 | 79 | mp.set_start_method("spawn") 80 | 81 | config = { 82 | "env_name": "CartPole-v0", 83 | "a2c_timesteps": 20, 84 | "n_envs": 4, 85 | "max_episode_steps": 100, 86 | "env_seed": 42, 87 | "n_processes": 4, 88 | "n_evaluation_processes": 2, 89 | "n_evaluation_envs": 128, 90 | "time_limit": 3600, 91 | "lr": 0.001, 92 | "hidden_size": 32, 93 | "discount_factor": 0.9, 94 | "critic_coef": 1.0, 95 | "entropy_coef": 0.01, 96 | "a2c_coef": 0.1, 97 | "gae_coef": 0.3, 98 | "logdir": "./results", 99 | "clip_grad": 40, 100 | "learner_device": "cpu", 101 | "save_every": 100, 102 | "optim":"Adam" 103 | } 104 | exp = Experiment(config, create_train_env, create_env, create_agent) 105 | exp.run() 106 | -------------------------------------------------------------------------------- /rlalgos/deprecated/README.md: -------------------------------------------------------------------------------- 1 | # Deprecated Algorithms 2 | 3 | These algorithms make use of the deprecated rlstructures API - v0.1 (but works) 4 | -------------------------------------------------------------------------------- 
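The deprecated runners that follow are launched through Hydra, and each script carries its own copy of a small `flatten` helper that turns the nested OmegaConf configuration into the flat `section/key` names the experiments read (for example `environment/env_name` and `model/hidden_size`). Note that the copy in rlalgos/dqn/run_cartpole.py further below references `DictConfig` without importing it, so it would raise a NameError if it were ever called (it is unused in that script). Below is a minimal standalone sketch of the helper in action; it is not a repository file, it assumes only that `omegaconf` is installed, and the `cfg` values are invented for illustration.

    # Standalone sketch (not a repository file): the `flatten` helper used by the
    # deprecated Hydra-based runners, applied to a made-up nested config.
    from omegaconf import OmegaConf, DictConfig

    def flatten(d, parent_key="", sep="/"):
        # Recursively walk the (possibly nested) DictConfig and join keys with "/".
        items = []
        for k, v in d.items():
            new_key = parent_key + sep + k if parent_key else k
            if isinstance(v, DictConfig):
                items.extend(flatten(v, new_key, sep=sep).items())
            else:
                items.append((new_key, v))
        return dict(items)

    cfg = OmegaConf.create({"environment": {"env_name": "CartPole-v0"}, "model": {"hidden_size": 64}})
    print(flatten(cfg))  # {'environment/env_name': 'CartPole-v0', 'model/hidden_size': 64}

Flattening lets every experiment read its configuration as one flat dictionary, whether it comes from a nested Hydra config (the deprecated scripts) or is written inline as a flat dict, as in the non-deprecated entry points such as rlalgos/a2c_gae/main_cartpole.py above.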
/rlalgos/deprecated/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/rlalgos/deprecated/__init__.py -------------------------------------------------------------------------------- /rlalgos/deprecated/a2c/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/rlalgos/deprecated/a2c/__init__.py -------------------------------------------------------------------------------- /rlalgos/deprecated/a2c/run_cartpole.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | 9 | from rlstructures import logging 10 | from rlstructures.env_wrappers import GymEnv 11 | from rlstructures.tools import weight_init 12 | import torch.nn as nn 13 | import copy 14 | import torch 15 | import time 16 | import numpy as np 17 | import torch.nn.functional as F 18 | from rlalgos.a2c.agent import NNAgent, MLPAgentModel 19 | import gym 20 | from gym.wrappers import TimeLimit 21 | from rlalgos.a2c.a2c_episodes import A2CGAE 22 | 23 | 24 | import hydra 25 | from omegaconf import DictConfig, OmegaConf 26 | 27 | 28 | def create_gym_env(args): 29 | return gym.make(args["environment/env_name"]) 30 | 31 | 32 | def create_env(n_envs, max_episode_steps=None, seed=None, **args): 33 | envs = [] 34 | for k in range(n_envs): 35 | e = create_gym_env(args) 36 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 37 | envs.append(e) 38 | return GymEnv(envs, seed) 39 | 40 | 41 | def create_agent(model, n_actions=1): 42 | return NNAgent(model=model, n_actions=n_actions) 43 | 44 | 45 | class Experiment(A2CGAE): 46 | def __init__(self, config, create_env, create_agent): 47 | super().__init__(config, create_env, create_agent) 48 | 49 | def _create_model(self): 50 | module = MLPAgentModel( 51 | self.obs_dim, self.n_actions, self.config["model/hidden_size"] 52 | ) 53 | module.apply(weight_init) 54 | return module 55 | 56 | 57 | def flatten(d, parent_key="", sep="/"): 58 | items = [] 59 | for k, v in d.items(): 60 | new_key = parent_key + sep + k if parent_key else k 61 | if isinstance(v, DictConfig): 62 | items.extend(flatten(v, new_key, sep=sep).items()) 63 | else: 64 | items.append((new_key, v)) 65 | return dict(items) 66 | 67 | 68 | @hydra.main() 69 | def my_app(cfg: DictConfig) -> None: 70 | f = flatten(cfg) 71 | print(f) 72 | exp = Experiment(f, create_env, create_agent) 73 | exp.go() 74 | 75 | 76 | if __name__ == "__main__": 77 | import torch.multiprocessing as mp 78 | 79 | mp.set_start_method("spawn") 80 | 81 | my_app() 82 | -------------------------------------------------------------------------------- /rlalgos/deprecated/a2c/run_cartpole_pomdp.py: -------------------------------------------------------------------------------- 1 | from rlstructures import logging 2 | from rlstructures.env_wrappers import GymEnv 3 | from rlstructures.tools import weight_init 4 | 5 | # 6 | # Copyright (c) Facebook, Inc. and its affiliates. 7 | # 8 | # This source code is licensed under the MIT license found in the 9 | # LICENSE file in the root directory of this source tree. 
10 | # 11 | 12 | 13 | import torch.nn as nn 14 | import copy 15 | import torch 16 | import time 17 | import numpy as np 18 | import torch.nn.functional as F 19 | import gym 20 | from gym.wrappers import TimeLimit 21 | from rlalgos.a2c.a2c_episodes import A2CGAE 22 | from rlalgos.a2c.agent import NNAgent, GRUAgentModel 23 | 24 | import gym.spaces as spaces 25 | from gym import ObservationWrapper 26 | 27 | import hydra 28 | from omegaconf import DictConfig, OmegaConf 29 | 30 | 31 | class MyWrapper(ObservationWrapper): 32 | r"""Observation wrapper that flattens the observation.""" 33 | 34 | def __init__(self, env): 35 | super(MyWrapper, self).__init__(env) 36 | self.observation_space = None # spaces.flatten_space(env.observation_space) 37 | 38 | def observation(self, observation): 39 | return [observation[0], observation[2]] 40 | 41 | 42 | def create_gym_env(args): 43 | return gym.make(args["environment/env_name"]) 44 | 45 | 46 | def create_env(n_envs, max_episode_steps=None, seed=None, **args): 47 | envs = [] 48 | for k in range(n_envs): 49 | e = create_gym_env(args) 50 | e = MyWrapper(e) 51 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 52 | envs.append(e) 53 | return GymEnv(envs, seed) 54 | 55 | 56 | def create_agent(model, n_actions=1): 57 | return NNAgent(model=model, n_actions=n_actions) 58 | 59 | 60 | class Experiment(A2CGAE): 61 | def __init__(self, config, create_env, create_agent): 62 | super().__init__(config, create_env, create_agent) 63 | 64 | def _create_model(self): 65 | module = GRUAgentModel( 66 | self.obs_dim, self.n_actions, self.config["model/hidden_size"] 67 | ) 68 | module.apply(weight_init) 69 | return module 70 | 71 | 72 | def flatten(d, parent_key="", sep="/"): 73 | items = [] 74 | for k, v in d.items(): 75 | new_key = parent_key + sep + k if parent_key else k 76 | if isinstance(v, DictConfig): 77 | items.extend(flatten(v, new_key, sep=sep).items()) 78 | else: 79 | items.append((new_key, v)) 80 | return dict(items) 81 | 82 | 83 | @hydra.main() 84 | def my_app(cfg: DictConfig) -> None: 85 | f = flatten(cfg) 86 | print(f) 87 | exp = Experiment(f, create_env, create_agent) 88 | exp.go() 89 | 90 | 91 | if __name__ == "__main__": 92 | import torch.multiprocessing as mp 93 | 94 | mp.set_start_method("spawn") 95 | 96 | my_app() 97 | -------------------------------------------------------------------------------- /rlalgos/deprecated/dqn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/rlalgos/deprecated/dqn/__init__.py -------------------------------------------------------------------------------- /rlalgos/deprecated/dqn/agent.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | # import rlstructures.logging as logging 13 | from rlstructures import DictTensor 14 | from rlstructures import Agent 15 | import time 16 | import numpy as np 17 | 18 | 19 | class QAgent(Agent): 20 | """ 21 | Describes a discrete agent based on a model that produces a score for each 22 | possible action, and an estimation of the value function in the current 23 | state. 
24 | """ 25 | 26 | def __init__(self, model=None, n_actions=None): 27 | """ 28 | Args: 29 | model (nn.Module): a module producing a tuple: (actions scores, value) 30 | n_actions (int): the number of possible actions 31 | """ 32 | super().__init__() 33 | self.model = model 34 | self.n_actions = n_actions 35 | 36 | def update(self, sd): 37 | self.model.load_state_dict(sd) 38 | 39 | def __call__(self, state, observation, agent_info=None, history=None): 40 | """ 41 | Executing one step of the agent 42 | """ 43 | # Verify that the batch size is 1 44 | 45 | initial_state = observation["initial_state"] 46 | B = observation.n_elems() 47 | 48 | if agent_info is None: 49 | agent_info = DictTensor({"epsilon": torch.zeros(B)}) 50 | 51 | agent_step = None 52 | if state is None: 53 | assert initial_state.all() 54 | agent_step = torch.zeros(B).long() 55 | else: 56 | agent_step = ( 57 | initial_state.float() * torch.zeros(B) 58 | + (1 - initial_state.float()) * state["agent_step"] 59 | ).long() 60 | 61 | q = self.model(observation["frame"]) 62 | 63 | qs, action = q.max(1) 64 | raction = torch.tensor( 65 | np.random.randint(low=0, high=self.n_actions, size=(action.size()[0])) 66 | ) 67 | epsilon = agent_info["epsilon"] 68 | mask = torch.rand(action.size()[0]).lt(epsilon).float() 69 | action = mask * raction + (1 - mask) * action 70 | action = action.long() 71 | 72 | new_state = DictTensor({"agent_step": agent_step + 1}) 73 | 74 | agent_do = DictTensor({"action": action, "q": q}) 75 | 76 | state = DictTensor({"agent_step": agent_step}) 77 | 78 | return state, agent_do, new_state 79 | 80 | 81 | class QMLP(nn.Module): 82 | def __init__(self, n_observations, n_actions, n_hidden): 83 | super().__init__() 84 | self.linear = nn.Linear(n_observations, n_hidden) 85 | self.linear2 = nn.Linear(n_hidden, n_actions) 86 | 87 | def forward(self, frame): 88 | z = torch.tanh(self.linear(frame)) 89 | score_actions = self.linear2(z) 90 | return score_actions 91 | 92 | 93 | class DQMLP(nn.Module): 94 | def __init__(self, n_observations, n_actions, n_hidden): 95 | super().__init__() 96 | self.linear = nn.Linear(n_observations, n_hidden) 97 | self.linear_adv = nn.Linear(n_hidden, n_actions) 98 | self.linear_value = nn.Linear(n_hidden, 1) 99 | self.n_actions = n_actions 100 | 101 | def forward_common(self, frame): 102 | z = torch.tanh(self.linear(frame)) 103 | return z 104 | 105 | def forward_value(self, z): 106 | return self.linear_value(z) 107 | 108 | def forward_advantage(self, z): 109 | adv = self.linear_adv(z) 110 | advm = adv.mean(1).unsqueeze(-1).repeat(1, self.n_actions) 111 | return adv - advm 112 | 113 | def forward(self, state): 114 | z = self.forward_common(state) 115 | v = self.forward_value(z) 116 | adv = self.forward_advantage(z) 117 | return v + adv 118 | -------------------------------------------------------------------------------- /rlalgos/deprecated/dqn/run_q_cartpole.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | 9 | from rlstructures import logging 10 | from rlstructures.env_wrappers import GymEnv, GymEnvInf 11 | from rlstructures.tools import weight_init 12 | from rlstructures.batchers import Batcher, EpisodeBatcher 13 | import torch.nn as nn 14 | import copy 15 | import torch 16 | import time 17 | import numpy as np 18 | import torch.nn.functional as F 19 | from rlalgos.dqn.agent import QAgent, QMLP, DQMLP 20 | import gym 21 | from gym.wrappers import TimeLimit 22 | from rlalgos.dqn.duelling_dqn import DQN 23 | 24 | import hydra 25 | from omegaconf import DictConfig, OmegaConf 26 | 27 | 28 | def create_gym_env(args): 29 | return gym.make(args["environment/env_name"]) 30 | 31 | 32 | def create_env(n_envs, mode="train", max_episode_steps=None, seed=None, **args): 33 | envs = [] 34 | for k in range(n_envs): 35 | e = create_gym_env(args) 36 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 37 | envs.append(e) 38 | 39 | if mode == "train": 40 | return GymEnvInf(envs, seed) 41 | else: 42 | return GymEnv(envs, seed) 43 | 44 | 45 | def create_agent( 46 | n_actions=None, 47 | model=None, 48 | ): 49 | return QAgent(model=model, n_actions=n_actions) 50 | 51 | 52 | class Experiment(DQN): 53 | def __init__(self, config, create_env, create_agent): 54 | super().__init__(config, create_env, create_agent) 55 | 56 | def _create_model(self): 57 | module = None 58 | if not self.config["use_duelling"]: 59 | module = QMLP(self.obs_dim, self.n_actions, 64) 60 | else: 61 | module = DQMLP(self.obs_dim, self.n_actions, 64) 62 | 63 | module.apply(weight_init) 64 | return module 65 | 66 | 67 | def flatten(d, parent_key="", sep="/"): 68 | items = [] 69 | for k, v in d.items(): 70 | new_key = parent_key + sep + k if parent_key else k 71 | if isinstance(v, DictConfig): 72 | items.extend(flatten(v, new_key, sep=sep).items()) 73 | else: 74 | items.append((new_key, v)) 75 | return dict(items) 76 | 77 | 78 | @hydra.main() 79 | def my_app(cfg: DictConfig) -> None: 80 | f = flatten(cfg) 81 | print(f) 82 | exp = Experiment(f, create_env, create_agent) 83 | exp.go() 84 | 85 | 86 | if __name__ == "__main__": 87 | import torch.multiprocessing as mp 88 | 89 | mp.set_start_method("spawn") 90 | 91 | my_app() 92 | -------------------------------------------------------------------------------- /rlalgos/deprecated/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/rlalgos/deprecated/envs/__init__.py -------------------------------------------------------------------------------- /rlalgos/deprecated/ppo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/rlalgos/deprecated/ppo/__init__.py -------------------------------------------------------------------------------- /rlalgos/deprecated/ppo/run_cartpole.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | 9 | from rlstructures import logging 10 | from rlstructures.env_wrappers import GymEnv, GymEnvInf 11 | from rlstructures.tools import weight_init 12 | import torch.nn as nn 13 | import copy 14 | import torch 15 | import time 16 | import numpy as np 17 | import torch.nn.functional as F 18 | from rlalgos.a2c.agent import NNAgent, MLPAgentModel 19 | import gym 20 | from gym.wrappers import TimeLimit 21 | from rlalgos.ppo.discrete_ppo import PPO 22 | 23 | import hydra 24 | from omegaconf import DictConfig, OmegaConf 25 | 26 | 27 | def create_gym_env(args): 28 | return gym.make(args["environment/env_name"]) 29 | 30 | 31 | def create_env(n_envs, mode="train", max_episode_steps=None, seed=None, **args): 32 | envs = [] 33 | for k in range(n_envs): 34 | e = create_gym_env(args) 35 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 36 | envs.append(e) 37 | if mode == "train": 38 | return GymEnvInf(envs, seed) 39 | else: 40 | return GymEnv(envs, seed) 41 | 42 | 43 | def create_agent(n_actions, model): 44 | return NNAgent(model=model, n_actions=n_actions) 45 | 46 | 47 | class Experiment(PPO): 48 | def __init__(self, config, create_env, create_agent): 49 | super().__init__(config, create_env, create_agent) 50 | 51 | def _create_model(self): 52 | module = MLPAgentModel( 53 | self.obs_dim, self.n_actions, self.config["model/hidden_size"] 54 | ) 55 | module.apply(weight_init) 56 | return module 57 | 58 | 59 | def flatten(d, parent_key="", sep="/"): 60 | items = [] 61 | for k, v in d.items(): 62 | new_key = parent_key + sep + k if parent_key else k 63 | if isinstance(v, DictConfig): 64 | items.extend(flatten(v, new_key, sep=sep).items()) 65 | else: 66 | items.append((new_key, v)) 67 | return dict(items) 68 | 69 | 70 | @hydra.main() 71 | def my_app(cfg: DictConfig) -> None: 72 | f = flatten(cfg) 73 | print(f) 74 | exp = Experiment(f, create_env, create_agent) 75 | exp.go() 76 | 77 | 78 | if __name__ == "__main__": 79 | import torch.multiprocessing as mp 80 | 81 | mp.set_start_method("spawn") 82 | 83 | my_app() 84 | -------------------------------------------------------------------------------- /rlalgos/deprecated/sac/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/rlalgos/deprecated/sac/__init__.py -------------------------------------------------------------------------------- /rlalgos/deprecated/sac/run_cartpole.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | 9 | from rlstructures import logging 10 | from rlstructures.env_wrappers import GymEnv, GymEnvInf 11 | from rlstructures.tools import weight_init 12 | import torch.nn as nn 13 | import copy 14 | import torch 15 | import time 16 | import numpy as np 17 | import torch.nn.functional as F 18 | from rlalgos.sac.agent import SACAgent, SACPolicy, SACQ 19 | import gym 20 | from gym.wrappers import TimeLimit 21 | from rlalgos.sac.sac import SAC 22 | from rlalgos.envs.continuouscartopole import ContinuousCartPoleEnv 23 | 24 | import hydra 25 | from omegaconf import DictConfig, OmegaConf 26 | 27 | 28 | def create_gym_env(args): 29 | return ContinuousCartPoleEnv() 30 | 31 | 32 | def create_env(n_envs, mode="train", max_episode_steps=None, seed=None, **args): 33 | envs = [] 34 | for k in range(n_envs): 35 | e = create_gym_env(args) 36 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 37 | envs.append(e) 38 | 39 | if mode == "train": 40 | return GymEnvInf(envs, seed) 41 | else: 42 | return GymEnv(envs, seed) 43 | 44 | 45 | def create_agent(policy, action_dim=None): 46 | return SACAgent(policy=policy, action_dim=action_dim) 47 | 48 | 49 | class Experiment(SAC): 50 | def __init__(self, config, create_env, create_agent): 51 | super().__init__(config, create_env, create_agent) 52 | 53 | def _create_model(self): 54 | module = SACPolicy(self.obs_dim, self.action_dim, 16) 55 | module.apply(weight_init) 56 | return module 57 | 58 | def _create_q(self): 59 | module = SACQ(self.obs_dim, self.action_dim, 16) 60 | module.apply(weight_init) 61 | return module 62 | 63 | 64 | def flatten(d, parent_key="", sep="/"): 65 | items = [] 66 | for k, v in d.items(): 67 | new_key = parent_key + sep + k if parent_key else k 68 | if isinstance(v, DictConfig): 69 | items.extend(flatten(v, new_key, sep=sep).items()) 70 | else: 71 | items.append((new_key, v)) 72 | return dict(items) 73 | 74 | 75 | @hydra.main() 76 | def my_app(cfg: DictConfig) -> None: 77 | f = flatten(cfg) 78 | print(f) 79 | exp = Experiment(f, create_env, create_agent) 80 | exp.go() 81 | 82 | 83 | if __name__ == "__main__": 84 | import torch.multiprocessing as mp 85 | 86 | mp.set_start_method("spawn") 87 | 88 | my_app() 89 | -------------------------------------------------------------------------------- /rlalgos/deprecated/template_exp.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | 9 | from rlalgos.logger import Logger, TFLogger 10 | from rlstructures import DictTensor, TemporalDictTensor 11 | from rlalgos.tools import weight_init 12 | import torch.nn as nn 13 | import copy 14 | import torch 15 | import time 16 | import numpy as np 17 | import torch.nn.functional as F 18 | 19 | 20 | class BaseExperiment: 21 | def __init__(self, config, create_env, create_agent): 22 | assert self.check_arguments(config) 23 | self.config = config 24 | self.logger = TFLogger(log_dir=self.config["logdir"], hps=self.config) 25 | self.batchers = [] 26 | self._create_env = create_env 27 | self._create_agent = create_agent 28 | 29 | def check_arguments(self, arguments): 30 | """ 31 | The function aims at checking that the arguments (provided in config) are the good ones 32 | """ 33 | return True 34 | 35 | def register_batcher(self, batcher): 36 | """ 37 | Register a new batcher when you create one, to ensure a correct closing of the experiment 38 | """ 39 | self.batchers.append(batcher) 40 | 41 | def _create_model(self): 42 | # self.learning_model = ...... 43 | raise NotImplementedError 44 | 45 | def create_model(self): 46 | self.learning_model = self._create_model() 47 | self.iteration = 0 48 | 49 | def reset(self): 50 | raise NotImplementedError 51 | 52 | def run(self): 53 | raise NotImplementedError 54 | 55 | def terminate(self): 56 | for b in self.batchers: 57 | b.close() 58 | self.logger.close() 59 | 60 | def go(self): 61 | self.create_model() 62 | self.reset() 63 | self.run() 64 | self.terminate() 65 | -------------------------------------------------------------------------------- /rlalgos/dqn/README.md: -------------------------------------------------------------------------------- 1 | Implementation of DQN with multiple versions that works both on CPU and GPU for loss computation only. Data acquisition is made on CPU 2 | -------------------------------------------------------------------------------- /rlalgos/dqn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/rlalgos/dqn/__init__.py -------------------------------------------------------------------------------- /rlalgos/dqn/run_atari.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | 9 | from rlstructures import logging 10 | from rlstructures.env_wrappers import GymEnv, GymEnvInf 11 | from rlstructures.tools import weight_init 12 | from rlstructures.batchers import Batcher, EpisodeBatcher 13 | import torch.nn as nn 14 | import copy 15 | import torch 16 | import time 17 | import numpy as np 18 | import torch.nn.functional as F 19 | from rlalgos.dqn.agent import QAgent, QMLP, DQMLP, DuelingCnnDQN, CnnDQN 20 | import gym 21 | from gym.wrappers import TimeLimit 22 | from rlalgos.dqn.duelling_dqn import DQN 23 | from rlalgos.atari_wrappers import make_atari, wrap_deepmind, wrap_pytorch 24 | import math 25 | 26 | 27 | def create_env(n_envs, mode="train", max_episode_steps=None, seed=None, **args): 28 | if mode == "train": 29 | envs = [] 30 | for k in range(n_envs): 31 | e = make_atari(args["environment/env_name"]) 32 | e = wrap_deepmind(e) 33 | e = wrap_pytorch(e) 34 | envs.append(e) 35 | return GymEnvInf(envs, seed) 36 | else: 37 | envs = [] 38 | for k in range(n_envs): 39 | e = make_atari(args["environment/env_name"]) 40 | e = wrap_deepmind(e) 41 | e = wrap_pytorch(e) 42 | envs.append(e) 43 | return GymEnv(envs, seed) 44 | 45 | 46 | def create_agent( 47 | n_actions=None, 48 | model=None, 49 | ): 50 | return QAgent(model=model, n_actions=n_actions) 51 | 52 | 53 | class Experiment(DQN): 54 | def __init__(self, config, create_env, create_agent): 55 | super().__init__(config, create_env, create_agent) 56 | 57 | def _create_model(self): 58 | if self.config["use_duelling"]: 59 | module = DuelingCnnDQN(self.obs_shape, self.n_actions) 60 | else: 61 | module = CnnDQN(self.obs_shape, self.n_actions) 62 | # module.apply(weight_init) 63 | return module 64 | 65 | 66 | if __name__ == "__main__": 67 | # We use spawn mode such that most of the environment will run in multiple processes 68 | import torch.multiprocessing as mp 69 | 70 | mp.set_start_method("spawn") 71 | 72 | config = { 73 | "environment/env_name": "PongNoFrameskip-v4", 74 | "n_envs": 1, 75 | "max_episode_steps": 100, 76 | "discount_factor": 0.99, 77 | "epsilon_greedy_max": 0.5, 78 | "epsilon_greedy_min": 0.1, 79 | "epsilon_min_epoch": 1000, 80 | "replay_buffer_size": 10000, 81 | "n_batches": 32, 82 | "initial_buffer_epochs": 10, 83 | "qvalue_epochs": 1, 84 | "batch_timesteps": 1, 85 | "use_duelling": False, 86 | "use_double": False, 87 | "lr": 0.00001, 88 | "n_processes": 1, 89 | "n_evaluation_processes": 4, 90 | "verbose": True, 91 | "n_evaluation_envs": 4, 92 | "time_limit": 28800, 93 | "env_seed": 42, 94 | "clip_grad": 0.0, 95 | "learner_device": "cpu", 96 | "as_fast_as_possible": True, 97 | "optim": "AdamW", 98 | "update_target_hard": True, 99 | "update_target_epoch": 1000, 100 | "update_target_tau": 0.005, 101 | "buffer/alpha": 0.0, 102 | "buffer/beta": 0.0, 103 | "logdir": "./results", 104 | "save_every": 100, 105 | } 106 | exp = Experiment(config, create_env, create_agent) 107 | exp.run() 108 | -------------------------------------------------------------------------------- /rlalgos/dqn/run_cartpole.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | 9 | from rlstructures import logging 10 | from rlstructures.env_wrappers import GymEnv, GymEnvInf 11 | from rlstructures.tools import weight_init 12 | from rlstructures.batchers import Batcher, EpisodeBatcher 13 | import torch.nn as nn 14 | import copy 15 | import torch 16 | import time 17 | import numpy as np 18 | import torch.nn.functional as F 19 | from rlalgos.dqn.agent import QAgent, QMLP, DQMLP 20 | import gym 21 | from gym.wrappers import TimeLimit 22 | from rlalgos.dqn.duelling_dqn import DQN 23 | 24 | 25 | def create_gym_env(args): 26 | return gym.make(args["environment/env_name"]) 27 | 28 | 29 | def create_env(n_envs, mode="train", max_episode_steps=None, seed=None, **args): 30 | envs = [] 31 | for k in range(n_envs): 32 | e = create_gym_env(args) 33 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 34 | envs.append(e) 35 | 36 | if mode == "train": 37 | return GymEnvInf(envs, seed) 38 | else: 39 | return GymEnv(envs, seed) 40 | 41 | 42 | def create_agent( 43 | n_actions=None, 44 | model=None, 45 | ): 46 | return QAgent(model=model, n_actions=n_actions) 47 | 48 | 49 | class Experiment(DQN): 50 | def __init__(self, config, create_env, create_agent): 51 | super().__init__(config, create_env, create_agent) 52 | 53 | def _create_model(self): 54 | module = None 55 | if not self.config["use_duelling"]: 56 | module = QMLP(self.obs_shape[0], self.n_actions, 64) 57 | else: 58 | module = DQMLP(self.obs_shape[0], self.n_actions, 64) 59 | 60 | module.apply(weight_init) 61 | return module 62 | 63 | 64 | def flatten(d, parent_key="", sep="/"): 65 | items = [] 66 | for k, v in d.items(): 67 | new_key = parent_key + sep + k if parent_key else k 68 | if isinstance(v, DictConfig): 69 | items.extend(flatten(v, new_key, sep=sep).items()) 70 | else: 71 | items.append((new_key, v)) 72 | return dict(items) 73 | 74 | 75 | if __name__ == "__main__": 76 | # We use spawn mode such that most of the environment will run in multiple processes 77 | import torch.multiprocessing as mp 78 | 79 | mp.set_start_method("spawn") 80 | 81 | config = { 82 | "environment/env_name": "CartPole-v0", 83 | "n_envs": 4, 84 | "max_episode_steps": 10000, 85 | "discount_factor": 0.99, 86 | "epsilon_greedy_max": 0.9, 87 | "epsilon_greedy_min": 0.01, 88 | "epsilon_min_epoch": 2000, 89 | "replay_buffer_size": 10000, 90 | "n_batches": 32, 91 | "initial_buffer_epochs": 1, 92 | "qvalue_epochs": 1, 93 | "batch_timesteps": 4, 94 | "use_duelling": True, 95 | "use_double": True, 96 | "lr": 0.001, 97 | "n_processes": 4, 98 | "n_evaluation_processes": 4, 99 | "verbose": True, 100 | "n_evaluation_envs": 32, 101 | "time_limit": 28800, 102 | "env_seed": 42, 103 | "clip_grad": 0.0, 104 | "learner_device": "cpu", 105 | "as_fast_as_possible": True, 106 | "optim": "AdamW", 107 | "update_target_hard": False, 108 | "update_target_epoch": 1000, 109 | "update_target_tau": 0.005, 110 | "buffer/alpha": 0.0, 111 | "buffer/beta": 0.0, 112 | "logdir": "./results", 113 | "save_every": 1, 114 | } 115 | exp = Experiment(config, create_env, create_agent) 116 | exp.run() 117 | -------------------------------------------------------------------------------- /rlalgos/ppo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/rlalgos/ppo/__init__.py -------------------------------------------------------------------------------- /rlalgos/ppo/run_cartpole.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | 9 | from rlstructures import logging 10 | from rlstructures.env_wrappers import GymEnv, GymEnvInf 11 | from rlstructures.tools import weight_init 12 | import torch.nn as nn 13 | import copy 14 | import torch 15 | import time 16 | import numpy as np 17 | import torch.nn.functional as F 18 | from rlalgos.a2c_gae.agent import RecurrentAgent, ActionModel 19 | import gym 20 | from gym.wrappers import TimeLimit 21 | from rlalgos.ppo.discrete_ppo import PPO 22 | from rlalgos.a2c_gae.agent import ActionModel, CriticModel, Model 23 | 24 | 25 | def create_gym_env(env_name): 26 | return gym.make(env_name) 27 | 28 | 29 | def create_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 30 | envs = [] 31 | for k in range(n_envs): 32 | e = create_gym_env(env_name) 33 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 34 | envs.append(e) 35 | return GymEnv(envs, seed) 36 | 37 | 38 | def create_train_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 39 | envs = [] 40 | for k in range(n_envs): 41 | e = create_gym_env(env_name) 42 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 43 | envs.append(e) 44 | return GymEnvInf(envs, seed) 45 | 46 | 47 | def create_agent(model, n_actions=1): 48 | return RecurrentAgent(model=model, n_actions=n_actions) 49 | 50 | 51 | class Experiment(PPO): 52 | def __init__(self, config, create_train_env, create_env, create_agent): 53 | super().__init__(config, create_train_env, create_env, create_agent) 54 | 55 | def _create_model(self): 56 | action_model = ActionModel( 57 | self.obs_dim, self.n_actions, self.config["model/hidden_size"] 58 | ) 59 | critic_model = CriticModel(self.obs_dim, self.config["model/hidden_size"]) 60 | module = Model(action_model, critic_model) 61 | module.apply(weight_init) 62 | return module 63 | 64 | 65 | if __name__ == "__main__": 66 | import torch.multiprocessing as mp 67 | 68 | mp.set_start_method("spawn") 69 | 70 | config = { 71 | "env_name": "CartPole-v0", 72 | "n_envs": 4, 73 | "max_episode_steps": 100, 74 | "discount_factor": 0.9, 75 | "logdir": "./results", 76 | "lr": 0.001, 77 | "n_processes": 4, 78 | "n_evaluation_processes": 4, 79 | "n_evaluation_envs": 64, 80 | "time_limit": 360, 81 | "coef_critic": 1.0, 82 | "coef_entropy": 0.01, 83 | "coef_ppo": 1.0, 84 | "env_seed": 42, 85 | "ppo_timesteps": 20, 86 | "k_epochs": 4, 87 | "eps_clip": 0.2, 88 | "gae_coef": 0.3, 89 | "clip_grad": 2, 90 | "learner_device": "cpu", 91 | "evaluation_mode": "stochastic", 92 | "verbose": True, 93 | "model/hidden_size": 16, 94 | } 95 | exp = Experiment(config, create_train_env, create_env, create_agent) 96 | exp.run() 97 | -------------------------------------------------------------------------------- /rlalgos/reinforce/agent.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
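# The PPO experiment above is configured with eps_clip=0.2 and k_epochs=4; the loss itself
# lives in rlalgos/ppo/discrete_ppo.py. The snippet below is only a generic, self-contained
# sketch of the clipped surrogate objective those hyperparameters refer to, written with
# plain PyTorch and dummy data.

import torch


def clipped_surrogate_loss(new_logp, old_logp, advantage, eps_clip=0.2):
    """PPO clipped objective, returned as a loss (the quantity to minimize)."""
    ratio = torch.exp(new_logp - old_logp)
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) * advantage
    return -torch.min(unclipped, clipped).mean()


new_logp = torch.randn(8)
old_logp = new_logp.detach() + 0.1 * torch.randn(8)
advantage = torch.randn(8)
print(clipped_surrogate_loss(new_logp, old_logp, advantage))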
6 | # 7 | 8 | 9 | import torch 10 | import time 11 | import torch.nn as nn 12 | from rlstructures import DictTensor,masked_tensor,masked_dicttensor 13 | from rlstructures import RL_Agent 14 | 15 | class ReinforceAgent(RL_Agent): 16 | def __init__(self,model=None, n_actions=None): 17 | super().__init__() 18 | self.model = model 19 | self.n_actions = n_actions 20 | 21 | def update(self, state_dict): 22 | self.model.load_state_dict(state_dict) 23 | 24 | def seed(self,seed): 25 | print("Agent seeding (to illustrate seed mechanism): ",seed) 26 | 27 | def require_history(self): 28 | return False 29 | 30 | def initial_state(self,agent_info,B): 31 | return DictTensor({}) 32 | 33 | def __call__(self, state,observation,agent_info=None,history=None): 34 | """ 35 | Executing one step of the agent 36 | """ 37 | assert state.empty() 38 | 39 | B = observation.n_elems() 40 | action_proba = self.model.action_model(observation["frame"]) 41 | baseline = self.model.baseline_model(observation["frame"]) 42 | dist = torch.distributions.Categorical(action_proba) 43 | action_sampled = dist.sample() 44 | 45 | action_max = action_proba.max(1)[1] 46 | smask=agent_info["stochastic"].float().to(action_max.device) 47 | action=masked_tensor(action_max,action_sampled,smask) 48 | 49 | new_state = DictTensor({}) 50 | 51 | agent_do = DictTensor( 52 | {"action": action, "action_probabilities": action_proba, "baseline":baseline} 53 | ) 54 | 55 | return agent_do, new_state 56 | 57 | class Model(nn.Module): 58 | def __init__(self,action_model,baseline_model): 59 | super().__init__() 60 | self.action_model=action_model 61 | self.baseline_model=baseline_model 62 | 63 | class ActionModel(nn.Module): 64 | """ The model that computes one score per action 65 | """ 66 | def __init__(self, n_observations, n_actions, n_hidden): 67 | super().__init__() 68 | self.linear = nn.Linear(n_observations, n_hidden) 69 | self.linear2 = nn.Linear(n_hidden, n_actions) 70 | 71 | 72 | def forward(self, frame): 73 | z = torch.tanh(self.linear(frame)) 74 | score_actions = self.linear2(z) 75 | probabilities_actions = torch.softmax(score_actions,dim=-1) 76 | return probabilities_actions 77 | 78 | class BaselineModel(nn.Module): 79 | """ The model that computes V(s) 80 | """ 81 | def __init__(self, n_observations, n_hidden): 82 | super().__init__() 83 | self.linear = nn.Linear(n_observations, n_hidden) 84 | self.linear2 = nn.Linear(n_hidden, 1) 85 | 86 | 87 | def forward(self, frame): 88 | z = torch.tanh(self.linear(frame)) 89 | critic = self.linear2(z) 90 | return critic 91 | -------------------------------------------------------------------------------- /rlalgos/reinforce/run_reinforce.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
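# In ReinforceAgent above, masked_tensor(action_max, action_sampled, smask) blends the
# greedy action and the sampled action according to the per-environment "stochastic" flag
# carried by agent_info. The same per-element switch can be written with plain PyTorch, as
# in the sketch below (illustration only, it does not use the rlstructures helper).

import torch

action_proba = torch.softmax(torch.randn(4, 2), dim=-1)      # B=4 environments, 2 actions
dist = torch.distributions.Categorical(action_proba)
action_sampled = dist.sample()
action_max = action_proba.max(1)[1]
stochastic = torch.tensor([True, False, True, False])

# Sampled action where 'stochastic' is True, greedy action elsewhere
action = torch.where(stochastic, action_sampled, action_max)
print(action)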
6 | # 7 | 8 | 9 | from rlstructures.env_wrappers import GymEnv 10 | from rlalgos.tools import weight_init 11 | import torch.nn as nn 12 | import copy 13 | import torch 14 | import time 15 | import numpy as np 16 | import torch.nn.functional as F 17 | from rlalgos.reinforce.agent import ReinforceAgent, ActionModel, BaselineModel, Model 18 | from rlalgos.reinforce.reinforce import Reinforce 19 | import gym 20 | from gym.wrappers import TimeLimit 21 | import copy 22 | 23 | # We write the 'create_env' and 'create_agent' function in the main file to allow these functions to be used with pickle when creating the batcher processes 24 | def create_gym_env(env_name): 25 | return gym.make(env_name) 26 | 27 | 28 | # Create a rlstructures.VecEnv from multiple gym.Env, limiting the number of steps 29 | def create_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 30 | envs = [] 31 | for k in range(n_envs): 32 | e = create_gym_env(env_name) 33 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 34 | envs.append(e) 35 | return GymEnv(envs, seed) 36 | 37 | 38 | # Create a rlstructures.Agent 39 | def create_agent(model, n_actions=1): 40 | return ReinforceAgent(model=model, n_actions=n_actions) 41 | 42 | 43 | class Experiment(Reinforce): 44 | def __init__(self, config, create_env, create_agent): 45 | super().__init__(config, create_env, create_agent) 46 | 47 | def _create_model(self): 48 | action_model = ActionModel(self.obs_dim, self.n_actions, 16) 49 | baseline_model = BaselineModel(self.obs_dim, 16) 50 | return Model(action_model, baseline_model) 51 | 52 | 53 | if __name__ == "__main__": 54 | # We use spawn mode such that most of the environment will run in multiple processes 55 | import torch.multiprocessing as mp 56 | 57 | mp.set_start_method("spawn") 58 | 59 | config = { 60 | "env_name": "CartPole-v0", 61 | "n_envs": 4, 62 | "max_episode_steps": 100, 63 | "env_seed": 42, 64 | "n_processes": 4, 65 | "n_evaluation_processes": 2, 66 | "n_evaluation_envs": 128, 67 | "time_limit": 3600, 68 | "lr": 0.01, 69 | "discount_factor": 0.9, 70 | "baseline_coef": 0.1, 71 | "entropy_coef": 0.01, 72 | "reinforce_coef": 1.0, 73 | "evaluation_mode": "stochastic", 74 | "logdir": "./results", 75 | } 76 | exp = Experiment(config, create_env, create_agent) 77 | exp.run() 78 | -------------------------------------------------------------------------------- /rlalgos/reinforce_device/README.md: -------------------------------------------------------------------------------- 1 | A illustration about how to use multiple CPUUs and GPUs together at both the loss computation and batcher levels 2 | -------------------------------------------------------------------------------- /rlalgos/reinforce_device/run_reinforce.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
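# The REINFORCE experiment above uses discount_factor=0.9; the loss itself is implemented
# in rlalgos/reinforce/reinforce.py. The helper below is only a generic, self-contained
# sketch of the discounted-return computation that REINFORCE relies on, for a single
# episode given as a 1D tensor of rewards.

import torch


def discounted_returns(rewards, discount_factor=0.9):
    """Compute R_t = r_t + gamma * R_{t+1} backwards over one episode."""
    returns = torch.zeros_like(rewards)
    running = 0.0
    for t in range(rewards.size(0) - 1, -1, -1):
        running = rewards[t] + discount_factor * running
        returns[t] = running
    return returns


print(discounted_returns(torch.tensor([1.0, 1.0, 1.0])))  # tensor([2.7100, 1.9000, 1.0000])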
6 | # 7 | 8 | 9 | from rlstructures.env_wrappers import GymEnv,DeviceEnv 10 | from rlalgos.tools import weight_init 11 | import torch.nn as nn 12 | import copy 13 | import torch 14 | import time 15 | import numpy as np 16 | import torch.nn.functional as F 17 | from rlalgos.reinforce.agent import ReinforceAgent, ActionModel, BaselineModel, Model 18 | from rlalgos.reinforce_device.reinforce import Reinforce 19 | import gym 20 | from gym.wrappers import TimeLimit 21 | from rlstructures import RL_Agent_CheckDevice 22 | # We write the 'create_env' and 'create_agent' function in the main file to allow these functions to be used with pickle when creating the batcher processes 23 | 24 | def create_gym_env(env_name): 25 | return gym.make(env_name) 26 | 27 | # Create a rlstructures.VecEnv from multiple gym.Env, limiting the number of steps 28 | def create_env(n_envs, env_name=None, max_episode_steps=None, device=None,seed=None): 29 | envs = [] 30 | for k in range(n_envs): 31 | e = create_gym_env(env_name) 32 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 33 | envs.append(e) 34 | return DeviceEnv(GymEnv(envs, seed),from_device=torch.device("cpu"),to_device=device) 35 | 36 | # Create a rlstructures.Agent 37 | def create_agent(model, n_actions=1,device=None,copy_model=True): 38 | print("create agent on ",device," with model copy==",copy_model) 39 | if copy_model: 40 | model=copy.deepcopy(model) 41 | 42 | agent=ReinforceAgent(model=model.to(device), n_actions=n_actions) 43 | return RL_Agent_CheckDevice(agent,device) 44 | 45 | class Experiment(Reinforce): 46 | def __init__(self, config, create_env, create_agent): 47 | super().__init__(config, create_env, create_agent) 48 | 49 | def _create_model(self): 50 | action_model = ActionModel(self.obs_dim, self.n_actions, 16) 51 | baseline_model = BaselineModel(self.obs_dim, 16) 52 | return Model(action_model, baseline_model) 53 | 54 | 55 | if __name__ == "__main__": 56 | # We use spawn mode such that most of the environment will run in multiple processes 57 | import torch.multiprocessing as mp 58 | mp.set_start_method("spawn") 59 | 60 | config = { 61 | "env_name": "CartPole-v0", 62 | "n_envs": 4, 63 | "max_episode_steps": 100, 64 | "env_seed": 42, 65 | "n_processes": 4, 66 | "n_evaluation_processes": 2, 67 | "n_evaluation_envs": 128, 68 | "time_limit": 20, 69 | "lr": 0.01, 70 | "discount_factor": 0.9, 71 | "baseline_coef": 0.1, 72 | "entropy_coef": 0.01, 73 | "reinforce_coef": 1.0, 74 | "evaluation_mode": "stochastic", 75 | "logdir": "./results", 76 | "learner_device":torch.device("cuda:0"), 77 | "batcher_device":torch.device("cuda:1"), 78 | "evaluation_device":torch.device("cpu") 79 | } 80 | exp = Experiment(config, create_env, create_agent) 81 | exp.run() 82 | -------------------------------------------------------------------------------- /rlalgos/reinforce_diayn/README.md: -------------------------------------------------------------------------------- 1 | An implementation of the DIAYN model https://arxiv.org/abs/1802.06070 based on REINFORCE 2 | -------------------------------------------------------------------------------- /rlalgos/reinforce_diayn/agent.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
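# The reinforce_device example above hard-codes learner_device=cuda:0 and
# batcher_device=cuda:1. On machines with fewer GPUs, a small guard such as the
# hypothetical helper below (not part of rlstructures) keeps the same configuration
# runnable by falling back to a shared GPU or to the CPU.

import torch


def pick_devices():
    n_gpus = torch.cuda.device_count()
    if n_gpus >= 2:
        return torch.device("cuda:0"), torch.device("cuda:1"), torch.device("cpu")
    if n_gpus == 1:
        # Learner and batcher share the single available GPU
        return torch.device("cuda:0"), torch.device("cuda:0"), torch.device("cpu")
    return torch.device("cpu"), torch.device("cpu"), torch.device("cpu")


learner_device, batcher_device, evaluation_device = pick_devices()
print(learner_device, batcher_device, evaluation_device)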
6 | # 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | from rlstructures import DictTensor, masked_tensor, masked_dicttensor 12 | from rlstructures import RL_Agent 13 | import time 14 | 15 | 16 | class DIAYNAgent(RL_Agent): 17 | def __init__(self, model=None, n_actions=None): 18 | super().__init__() 19 | self.model = model 20 | self.n_actions = n_actions 21 | 22 | def update(self, state_dict): 23 | self.model.load_state_dict(state_dict) 24 | 25 | def require_history(self): 26 | return False 27 | 28 | def initial_state(self, agent_info, B): 29 | return DictTensor({}) 30 | 31 | def __call__(self, state, observation, agent_info=None, history=None): 32 | """ 33 | Executing one step of the agent 34 | """ 35 | assert state.empty() 36 | 37 | B = observation.n_elems() 38 | 39 | idx_policy = agent_info["idx_policy"] 40 | action_proba = self.model.action_model(observation["frame"], idx_policy) 41 | baseline = self.model.baseline_model(observation["frame"], idx_policy) 42 | 43 | dist = torch.distributions.Categorical(action_proba) 44 | action_sampled = dist.sample() 45 | 46 | action_max = action_proba.max(1)[1] 47 | smask = agent_info["stochastic"].float() 48 | action = masked_tensor(action_max, action_sampled, agent_info["stochastic"]) 49 | 50 | new_state = DictTensor({}) 51 | 52 | agent_do = DictTensor( 53 | { 54 | "action": action, 55 | "action_probabilities": action_proba, 56 | "baseline": baseline, 57 | } 58 | ) 59 | 60 | return agent_do, new_state 61 | 62 | 63 | class DIAYNModel(nn.Module): 64 | def __init__(self, action_model, baseline_model): 65 | super().__init__() 66 | self.action_model = action_model 67 | self.baseline_model = baseline_model 68 | 69 | 70 | class DIAYNActionModel(nn.Module): 71 | """The model that computes one score per action""" 72 | 73 | def __init__(self, n_observations, n_actions, n_hidden, n_policies): 74 | super().__init__() 75 | self.linear = nn.Linear(n_observations, n_hidden) 76 | self.linear2 = nn.Linear(n_hidden, n_actions * n_policies) 77 | self.n_policies = n_policies 78 | self.n_actions = n_actions 79 | 80 | def forward(self, frame, idx_policy): 81 | z = torch.tanh(self.linear(frame)) 82 | score_actions = self.linear2(z) 83 | s = score_actions.size() 84 | score_actions = score_actions.reshape(s[0], self.n_policies, self.n_actions) 85 | score_actions = score_actions[torch.arange(s[0]), idx_policy] 86 | probabilities_actions = torch.softmax(score_actions, dim=-1) 87 | return probabilities_actions 88 | 89 | 90 | class DIAYNBaselineModel(nn.Module): 91 | """The model that computes V(s)""" 92 | 93 | def __init__(self, n_observations, n_hidden, n_policies): 94 | super().__init__() 95 | self.linear = nn.Linear(n_observations, n_hidden) 96 | self.linear2 = nn.Linear(n_hidden, n_policies) 97 | self.n_policies = n_policies 98 | 99 | def forward(self, frame, idx_policy): 100 | z = torch.tanh(self.linear(frame)) 101 | critic = self.linear2(z) 102 | critic = critic.reshape(critic.size()[0], self.n_policies, 1) 103 | critic = critic[torch.arange(critic.size()[0]), idx_policy] 104 | return critic 105 | -------------------------------------------------------------------------------- /rlalgos/reinforce_diayn/run_diayn.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
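# DIAYNActionModel above scores every (policy, action) pair with one shared linear layer,
# then keeps, for each sample of the batch, the row matching agent_info["idx_policy"].
# The standalone snippet below reproduces just that reshape + torch.arange indexing step
# with dummy data, so the trick is easier to follow in isolation.

import torch

B, n_policies, n_actions = 5, 3, 2
scores = torch.randn(B, n_policies * n_actions)        # output of the shared linear layer
scores = scores.reshape(B, n_policies, n_actions)      # one row of action scores per policy
idx_policy = torch.tensor([0, 2, 1, 0, 2])             # the skill each environment is executing
per_sample = scores[torch.arange(B), idx_policy]       # shape (B, n_actions)
print(per_sample.shape)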
6 | # 7 | 8 | 9 | from rlstructures.env_wrappers import GymEnv 10 | from rlalgos.tools import weight_init 11 | import torch.nn as nn 12 | import copy 13 | import torch 14 | import time 15 | import numpy as np 16 | import torch.nn.functional as F 17 | from rlalgos.reinforce_diayn.agent import ( 18 | DIAYNAgent, 19 | DIAYNActionModel, 20 | DIAYNBaselineModel, 21 | DIAYNModel, 22 | ) 23 | from rlalgos.reinforce_diayn.reinforce_diayn import Reinforce 24 | import gym 25 | from gym.wrappers import TimeLimit 26 | 27 | # We write the 'create_env' and 'create_agent' function in the main file to allow these functions to be used with pickle when creating the batcher processes 28 | def create_gym_env(env_name): 29 | return gym.make(env_name) 30 | 31 | 32 | # Create a rlstructures.VecEnv from multiple gym.Env, limiting the number of steps 33 | def create_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 34 | envs = [] 35 | for k in range(n_envs): 36 | e = create_gym_env(env_name) 37 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 38 | envs.append(e) 39 | return GymEnv(envs, seed) 40 | 41 | 42 | # Create a rlstructures.Agent 43 | def create_agent(model, n_actions=1): 44 | return DIAYNAgent(model=model, n_actions=n_actions) 45 | 46 | 47 | class Experiment(Reinforce): 48 | def __init__(self, config, create_env, create_agent): 49 | super().__init__(config, create_env, create_agent) 50 | 51 | def _create_model(self): 52 | action_model = DIAYNActionModel( 53 | self.obs_dim, self.n_actions, 16, self.config["n_policies"] 54 | ) 55 | baseline_model = DIAYNBaselineModel(self.obs_dim, 16, self.config["n_policies"]) 56 | return DIAYNModel(action_model, baseline_model) 57 | 58 | def _create_discriminator(self): 59 | classifier = nn.Linear(self.obs_dim, self.config["n_policies"]) 60 | classifier.apply(weight_init) 61 | return classifier 62 | 63 | 64 | if __name__ == "__main__": 65 | print( 66 | "DISCLAIMER: DIAYN is just provided as an example. It has not been tested deeply !!" 67 | ) 68 | # We use spawn mode such that most of the environment will run in multiple processes 69 | import torch.multiprocessing as mp 70 | 71 | mp.set_start_method("spawn") 72 | 73 | config = { 74 | "env_name": "CartPole-v0", 75 | "n_envs": 4, 76 | "max_episode_steps": 100, 77 | "env_seed": 42, 78 | "n_processes": 4, 79 | "n_evaluation_processes": 2, 80 | "n_evaluation_envs": 128, 81 | "time_limit": 3600, 82 | "lr": 0.01, 83 | "lr_discriminator": 0.01, 84 | "discount_factor": 0.9, 85 | "baseline_coef": 0.1, 86 | "discriminator_coef": 1.0, 87 | "entropy_coef": 0.01, 88 | "reinforce_coef": 1.0, 89 | "evaluation_mode": "stochastic", 90 | "logdir": "./results", 91 | "n_policies": 5, 92 | } 93 | exp = Experiment(config, create_env, create_agent) 94 | exp.run() 95 | -------------------------------------------------------------------------------- /rlalgos/sac/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/rlalgos/sac/__init__.py -------------------------------------------------------------------------------- /rlalgos/sac/run_cartpole.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
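# In DIAYN, the discriminator built by _create_discriminator above is trained to recover
# the policy index from the observation, and the policy is rewarded when it succeeds. The
# reward actually used by this example is computed in rlalgos/reinforce_diayn/reinforce_diayn.py;
# the snippet below is only a generic sketch of the usual DIAYN intrinsic reward,
# log q(z|s) - log p(z), with a uniform prior over the n_policies skills and dummy data.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

n_policies, obs_dim, B = 5, 4, 8
discriminator = nn.Linear(obs_dim, n_policies)
obs = torch.randn(B, obs_dim)
idx_policy = torch.randint(0, n_policies, (B,))

log_q_z = F.log_softmax(discriminator(obs), dim=-1)[torch.arange(B), idx_policy]
intrinsic_reward = log_q_z - math.log(1.0 / n_policies)
print(intrinsic_reward.shape)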
6 | # 7 | 8 | 9 | from rlstructures import logging 10 | from rlstructures.env_wrappers import GymEnv, GymEnvInf 11 | from rlstructures.tools import weight_init 12 | import torch.nn as nn 13 | import copy 14 | import torch 15 | import time 16 | import numpy as np 17 | import torch.nn.functional as F 18 | from rlalgos.sac.agent import SACAgent, SACPolicy, SACQ 19 | import gym 20 | from gym.wrappers import TimeLimit 21 | from rlalgos.sac.sac import SAC 22 | from rlalgos.sac.continuouscartopole import ContinuousCartPoleEnv 23 | 24 | 25 | def create_gym_env(env_name): 26 | assert env_name == "ContinousCartPole" 27 | return ContinuousCartPoleEnv() 28 | 29 | 30 | def create_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 31 | envs = [] 32 | for k in range(n_envs): 33 | e = create_gym_env(env_name) 34 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 35 | envs.append(e) 36 | return GymEnv(envs, seed) 37 | 38 | 39 | def create_train_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 40 | envs = [] 41 | for k in range(n_envs): 42 | e = create_gym_env(env_name) 43 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 44 | envs.append(e) 45 | return GymEnvInf(envs, seed) 46 | 47 | 48 | def create_agent(policy, action_dim=None): 49 | return SACAgent(policy=policy, action_dim=action_dim) 50 | 51 | 52 | class Experiment(SAC): 53 | def __init__(self, config, create_train_env, create_env, create_agent): 54 | super().__init__(config, create_train_env, create_env, create_agent) 55 | 56 | def _create_model(self): 57 | module = SACPolicy(self.obs_dim, self.action_dim, 16) 58 | module.apply(weight_init) 59 | return module 60 | 61 | def _create_q(self): 62 | module = SACQ(self.obs_dim, self.action_dim, 16) 63 | module.apply(weight_init) 64 | return module 65 | 66 | 67 | if __name__ == "__main__": 68 | import torch.multiprocessing as mp 69 | 70 | mp.set_start_method("spawn") 71 | 72 | config = { 73 | "env_name": "ContinousCartPole", 74 | "n_envs": 4, 75 | "n_processes": 4, 76 | "n_starting_transitions": 1600, 77 | "batch_timesteps": 1, 78 | "n_batches_per_epochs": 1, 79 | "size_batches": 1024, 80 | "max_episode_steps": 100, 81 | "tau": 0.005, 82 | "discount_factor": 0.95, 83 | "logdir": "./results", 84 | "replay_buffer_size": 1000000, 85 | "lr": 0.0003, 86 | "lambda_entropy": 0.01, 87 | "n_evaluation_processes": 4, 88 | "n_evaluation_envs": 64, 89 | "time_limit": 600, 90 | "env_seed": 42, 91 | "clip_grad": 40, 92 | "learner_device": "cpu", 93 | "evaluation_mode": "stochastic", 94 | "verbose": True, 95 | } 96 | exp = Experiment(config, create_train_env, create_env, create_agent) 97 | exp.run() 98 | -------------------------------------------------------------------------------- /rlalgos/simple_ddqn/README.md: -------------------------------------------------------------------------------- 1 | A simple implementation of DQN -- for a more complete implemetation, see the ../dqn repository 2 | -------------------------------------------------------------------------------- /rlalgos/simple_ddqn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/rlalgos/simple_ddqn/__init__.py -------------------------------------------------------------------------------- /rlalgos/simple_ddqn/agent.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
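# The SAC configuration above uses tau=0.005 to track target networks (the DQN configs
# expose the same mechanism as update_target_tau). The update itself is implemented in the
# corresponding algorithm files; the helper below is only a standard, self-contained sketch
# of such a Polyak (soft) target update.

import torch
import torch.nn as nn


def soft_update(target, source, tau=0.005):
    """target <- (1 - tau) * target + tau * source, parameter by parameter."""
    with torch.no_grad():
        for t_param, s_param in zip(target.parameters(), source.parameters()):
            t_param.mul_(1.0 - tau).add_(tau * s_param)


q, target_q = nn.Linear(4, 2), nn.Linear(4, 2)
target_q.load_state_dict(q.state_dict())   # hard copy at initialization
soft_update(target_q, q, tau=0.005)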
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | # import rlstructures.logging as logging 13 | from rlstructures import DictTensor 14 | from rlstructures import RL_Agent 15 | import time 16 | import numpy as np 17 | 18 | 19 | class QAgent(RL_Agent): 20 | def __init__(self, model=None, n_actions=None): 21 | super().__init__() 22 | self.model = model 23 | self.n_actions = n_actions 24 | 25 | def update(self, sd): 26 | self.model.load_state_dict(sd) 27 | 28 | def initial_state(self, agent_info, B): 29 | return DictTensor({}) 30 | 31 | def __call__(self, state, observation, agent_info=None, history=None): 32 | B = observation.n_elems() 33 | 34 | agent_step = None 35 | q = self.model(observation["frame"]) 36 | qs, action = q.max(1) 37 | raction = torch.tensor( 38 | np.random.randint(low=0, high=self.n_actions, size=(action.size()[0])) 39 | ) 40 | epsilon = agent_info["epsilon"] 41 | r = torch.rand(action.size()[0]) 42 | mask = r.lt(epsilon).float() 43 | action = mask * raction + (1 - mask) * action 44 | action = action.long() 45 | 46 | agent_do = DictTensor({"action": action, "q": q}) 47 | return agent_do, DictTensor({}) 48 | 49 | 50 | class DQMLP(nn.Module): 51 | def __init__(self, n_observations, n_actions, n_hidden): 52 | super().__init__() 53 | self.linear = nn.Linear(n_observations, n_hidden) 54 | self.linear_adv = nn.Linear(n_hidden, n_actions) 55 | self.linear_value = nn.Linear(n_hidden, 1) 56 | self.n_actions = n_actions 57 | 58 | def forward_common(self, frame): 59 | z = torch.tanh(self.linear(frame)) 60 | return z 61 | 62 | def forward_value(self, z): 63 | return self.linear_value(z) 64 | 65 | def forward_advantage(self, z): 66 | adv = self.linear_adv(z) 67 | advm = adv.mean(1).unsqueeze(-1).repeat(1, self.n_actions) 68 | return adv - advm 69 | 70 | def forward(self, state): 71 | z = self.forward_common(state) 72 | v = self.forward_value(z) 73 | adv = self.forward_advantage(z) 74 | return v + adv 75 | -------------------------------------------------------------------------------- /rlalgos/simple_ddqn/run_cartpole.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
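# The QAgent above performs epsilon-greedy action selection on top of the dueling DQMLP
# head, and the package name suggests the training code pairs it with a double-DQN style
# target (see rlalgos/simple_ddqn/ddqn.py). The snippet below is only a generic sketch of
# such a target, independent of this code base: the online network selects the argmax
# action, the target network evaluates it.

import torch
import torch.nn as nn

obs_dim, n_actions, B, gamma = 4, 2, 8, 0.99
online, target = nn.Linear(obs_dim, n_actions), nn.Linear(obs_dim, n_actions)
next_obs = torch.randn(B, obs_dim)
reward = torch.randn(B)
done = torch.zeros(B)                      # 1.0 where the episode ended

with torch.no_grad():
    best_action = online(next_obs).argmax(dim=1)                 # selection: online network
    next_q = target(next_obs)[torch.arange(B), best_action]      # evaluation: target network
    td_target = reward + gamma * (1.0 - done) * next_q
print(td_target.shape)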
6 | # 7 | 8 | 9 | from rlstructures import logging 10 | from rlstructures.env_wrappers import GymEnv, GymEnvInf 11 | from rlstructures.tools import weight_init 12 | from rlstructures.batchers import Batcher, EpisodeBatcher 13 | import torch.nn as nn 14 | import copy 15 | import torch 16 | import time 17 | import numpy as np 18 | import torch.nn.functional as F 19 | from rlalgos.simple_ddqn.agent import QAgent, DQMLP 20 | import gym 21 | from gym.wrappers import TimeLimit 22 | from rlalgos.simple_ddqn.ddqn import DQN 23 | 24 | 25 | def create_gym_env(args): 26 | return gym.make(args["environment/env_name"]) 27 | 28 | 29 | def create_env(n_envs, mode="train", max_episode_steps=None, seed=None, **args): 30 | envs = [] 31 | for k in range(n_envs): 32 | e = create_gym_env(args) 33 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 34 | envs.append(e) 35 | 36 | if mode == "train": 37 | return GymEnvInf(envs, seed) 38 | else: 39 | return GymEnv(envs, seed) 40 | 41 | 42 | def create_agent( 43 | n_actions=None, 44 | model=None, 45 | ): 46 | return QAgent(model=model, n_actions=n_actions) 47 | 48 | 49 | class Experiment(DQN): 50 | def __init__(self, config, create_env, create_agent): 51 | super().__init__(config, create_env, create_agent) 52 | 53 | def _create_model(self): 54 | module = DQMLP(self.obs_shape[0], self.n_actions, 64) 55 | 56 | module.apply(weight_init) 57 | return module 58 | 59 | 60 | def flatten(d, parent_key="", sep="/"): 61 | items = [] 62 | for k, v in d.items(): 63 | new_key = parent_key + sep + k if parent_key else k 64 | if isinstance(v, DictConfig): 65 | items.extend(flatten(v, new_key, sep=sep).items()) 66 | else: 67 | items.append((new_key, v)) 68 | return dict(items) 69 | 70 | 71 | if __name__ == "__main__": 72 | # We use spawn mode such that most of the environment will run in multiple processes 73 | import torch.multiprocessing as mp 74 | 75 | mp.set_start_method("spawn") 76 | 77 | config = { 78 | "environment/env_name": "CartPole-v0", 79 | "n_envs": 4, 80 | "max_episode_steps": 10000, 81 | "discount_factor": 0.99, 82 | "epsilon_greedy": 0.1, 83 | "replay_buffer_size": 10000, 84 | "n_batches": 32, 85 | "initial_buffer_epochs": 1, 86 | "qvalue_epochs": 1, 87 | "batch_timesteps": 4, 88 | "lr": 0.01, 89 | "n_processes": 4, 90 | "n_evaluation_processes": 4, 91 | "verbose": True, 92 | "n_evaluation_envs": 32, 93 | "time_limit": 28800, 94 | "env_seed": 42, 95 | "clip_grad": 0.0, 96 | "learner_device": "cpu", 97 | "optim": "Adam", 98 | "update_target_tau": 0.005, 99 | "logdir": "./results", 100 | "save_every": 1, 101 | } 102 | exp = Experiment(config, create_env, create_agent) 103 | exp.run() 104 | -------------------------------------------------------------------------------- /rlalgos/tools.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
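# The flatten() helpers defined in the two run_cartpole.py scripts above test
# isinstance(v, DictConfig) without importing DictConfig; they assume an omegaconf config
# and raise a NameError if actually called as written (adding
# "from omegaconf import DictConfig" would fix them). A dependency-free variant over plain
# nested dicts might look like this:

def flatten_dict(d, parent_key="", sep="/"):
    """Flatten nested dicts into 'a/b'-style keys, matching the config naming used above."""
    items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


print(flatten_dict({"environment": {"env_name": "CartPole-v0"}, "lr": 0.01}))
# {'environment/env_name': 'CartPole-v0', 'lr': 0.01}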
6 | # 7 | 8 | 9 | import torch 10 | import time 11 | import numpy as np 12 | import torch.nn.init as init 13 | import torch 14 | import torch.nn as nn 15 | 16 | 17 | def weight_init(m): 18 | """ 19 | Usage: 20 | model = Model() 21 | model.apply(weight_init) 22 | """ 23 | if isinstance(m, nn.Conv1d): 24 | init.normal_(m.weight.data) 25 | if m.bias is not None: 26 | init.normal_(m.bias.data) 27 | elif isinstance(m, nn.Conv2d): 28 | init.xavier_normal_(m.weight.data) 29 | if m.bias is not None: 30 | init.normal_(m.bias.data) 31 | elif isinstance(m, nn.Conv3d): 32 | init.xavier_normal_(m.weight.data) 33 | if m.bias is not None: 34 | init.normal_(m.bias.data) 35 | elif isinstance(m, nn.ConvTranspose1d): 36 | init.normal_(m.weight.data) 37 | if m.bias is not None: 38 | init.normal_(m.bias.data) 39 | elif isinstance(m, nn.ConvTranspose2d): 40 | init.xavier_normal_(m.weight.data) 41 | if m.bias is not None: 42 | init.normal_(m.bias.data) 43 | elif isinstance(m, nn.ConvTranspose3d): 44 | init.xavier_normal_(m.weight.data) 45 | if m.bias is not None: 46 | init.normal_(m.bias.data) 47 | elif isinstance(m, nn.BatchNorm1d): 48 | init.normal_(m.weight.data, mean=1, std=0.02) 49 | init.constant_(m.bias.data, 0) 50 | elif isinstance(m, nn.BatchNorm2d): 51 | init.normal_(m.weight.data, mean=1, std=0.02) 52 | init.constant_(m.bias.data, 0) 53 | elif isinstance(m, nn.BatchNorm3d): 54 | init.normal_(m.weight.data, mean=1, std=0.02) 55 | init.constant_(m.bias.data, 0) 56 | elif isinstance(m, nn.Linear): 57 | init.xavier_normal_(m.weight.data) 58 | init.normal_(m.bias.data) 59 | elif isinstance(m, nn.Embedding): 60 | init.xavier_normal_(m.weight.data) 61 | elif isinstance(m, nn.LSTM): 62 | for param in m.parameters(): 63 | if len(param.shape) >= 2: 64 | init.orthogonal_(param.data) 65 | else: 66 | init.normal_(param.data) 67 | elif isinstance(m, nn.LSTMCell): 68 | for param in m.parameters(): 69 | if len(param.shape) >= 2: 70 | init.orthogonal_(param.data) 71 | else: 72 | init.normal_(param.data) 73 | elif isinstance(m, nn.GRU): 74 | for param in m.parameters(): 75 | if len(param.shape) >= 2: 76 | init.orthogonal_(param.data) 77 | else: 78 | init.normal_(param.data) 79 | elif isinstance(m, nn.GRUCell): 80 | for param in m.parameters(): 81 | if len(param.shape) >= 2: 82 | init.orthogonal_(param.data) 83 | else: 84 | init.normal_(param.data) 85 | -------------------------------------------------------------------------------- /rlstructures/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2" 2 | __deprecated_message__ = False 3 | import sys 4 | from rlstructures.core import ( 5 | masked_tensor, 6 | masked_dicttensor, 7 | DictTensor, 8 | TemporalDictTensor, 9 | Trajectories, 10 | ) 11 | from rlstructures.rl_batchers.agent import ( 12 | RL_Agent, 13 | RL_Agent_CheckDevice, 14 | replay_agent_stateless, 15 | replay_agent, 16 | ) 17 | from rlstructures.env import VecEnv 18 | from rlstructures.rl_batchers import RL_Batcher 19 | 20 | 21 | # Deprecated import == Old version of rlstructures 22 | from rlstructures.deprecated.agent import Agent 23 | 24 | import rlalgos.logger 25 | 26 | sys.modules["rlstructures.logger"] = rlalgos.logger 27 | 28 | import rlalgos.tools 29 | 30 | sys.modules["rlstructures.tools"] = rlalgos.tools 31 | 32 | import rlstructures.core 33 | 34 | sys.modules["rlstructures.dicttensor"] = rlstructures.core 35 | 36 | import rlstructures.deprecated.logging 37 | 38 | sys.modules["rlstructures.logging"] = rlstructures.deprecated.logging 39 | 
40 | import rlstructures.deprecated.batchers 41 | 42 | sys.modules["rlstructures.batchers"] = rlstructures.deprecated.batchers 43 | 44 | import rlalgos.deprecated.template_exp 45 | 46 | sys.modules["rlalgos.template_exp"] = rlalgos.deprecated.template_exp 47 | -------------------------------------------------------------------------------- /rlstructures/deprecated/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/rlstructures/deprecated/__init__.py -------------------------------------------------------------------------------- /rlstructures/deprecated/agent.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from rlstructures import DictTensor, TemporalDictTensor 9 | import torch 10 | 11 | 12 | class Agent: 13 | """ 14 | Describes an agent responsible for producing actions when receiving 15 | observations and agent states. 16 | 17 | At each time step, and agent receives observations (DictTensor of size B), 18 | and agent states (DictTensor of size B) that reflect the agent's internal 19 | state. 20 | 21 | It then returns a triplet: 22 | agent state when receiving the observation (DictTensor): it is the 23 | agent state before computing anything. It is mainly used to 24 | initialize the state of the agent when facing initial states from the environment. 25 | action (DictTensor): the action + and additional outputs produced by 26 | the agent 27 | next agent state (DictTensor): the new state of the agent after all 28 | the computation. This value will then be provided to the agent at 29 | the next timestep. 30 | """ 31 | 32 | def __init__(self): 33 | pass 34 | 35 | def require_history(self): 36 | """if True, then the 'history' argument in the __call__ method will contain the set of previous transitions (e.g for transformers based policies)""" 37 | return False 38 | 39 | def __call__( 40 | self, 41 | state: DictTensor, 42 | input: DictTensor, 43 | user_info: DictTensor, 44 | history: TemporalDictTensor = None, 45 | ): 46 | """Execute one step of the agent 47 | 48 | :param state: the previous state of the agent, or None if the agent needs to be initialized 49 | :type state: DictTensor 50 | :param input: The observation coming from the environment 51 | :type input: DictTensor 52 | :param user_info: An additional DictTensor (provided by the user such that the epsilon value in epsilon-greedy policies) 53 | :type user_info: DictTensor 54 | :param history: [description], None if require_history()==False or a set of previous transitions (as a TemporalDictTensor) if True 55 | :type history: TemporalDictTensor, optional 56 | """ 57 | raise NotImplementedError 58 | 59 | def update(self, info): 60 | """ 61 | Update the agent. 
For instance, may update the pytorch model of this agent 62 | """ 63 | raise NotImplementedError 64 | 65 | def close(self): 66 | """ 67 | Terminate the agent 68 | """ 69 | pass 70 | -------------------------------------------------------------------------------- /rlstructures/deprecated/batchers/__init__.py: -------------------------------------------------------------------------------- 1 | from .episodebatchers import EpisodeBatcher, MonoThreadEpisodeBatcher 2 | from .trajectorybatchers import Batcher 3 | -------------------------------------------------------------------------------- /rlstructures/deprecated/logging.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | """A simple logging interface to display message in the Console 8 | """ 9 | 10 | import rlstructures.deprecated.logging 11 | 12 | DEBUG = 0 13 | INFO = 1 14 | NO = 2 15 | 16 | __LOGGING_LEVEL = 0 17 | 18 | 19 | def error(str): 20 | print("[ERROR] " + str, flush=True) 21 | assert False 22 | 23 | 24 | def debug(str): 25 | global __LOGGING_LEVEL 26 | if __LOGGING_LEVEL <= 0: 27 | print("[DEBUG] " + str, flush=True) 28 | 29 | 30 | def info(str): 31 | global __LOGGING_LEVEL 32 | if __LOGGING_LEVEL <= 1: 33 | print("[INFO] " + str, flush=True) 34 | 35 | 36 | def basicConfig(**args): 37 | global __LOGGING_LEVEL 38 | __LOGGING_LEVEL = args["level"] 39 | 40 | 41 | def getLogger(str): 42 | return rlstructures.logging 43 | -------------------------------------------------------------------------------- /rlstructures/env.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | 9 | import torch 10 | import time 11 | import torch 12 | from rlstructures import DictTensor 13 | 14 | 15 | class VecEnv: 16 | """ 17 | An VecEnvironment corresponds to multiple 'gym' environments (i.e a batch) 18 | that are running simultaneously. 19 | 20 | At each timestep, upon the B environments, a subset B' of envs are running 21 | (since some envs may have stopped). 22 | 23 | So each observation returned by the VecEnv is a DictTensor of size B'. To 24 | mark which environments that are still running, the observation is returned 25 | with a mapping vector of size B'. e.g [0,2,5] means that the observation 0 26 | corresponds to the env 0, the observation 1 corresponds to env 2, and the 27 | observation 3 corresponds to env 5. 28 | 29 | Finally, when running a step (at time t) method (over B' running envs), the 30 | agent has to provide an action (DictTensor) of size B'. The VecEnv will return 31 | the next observation (time t+1) (size B'). But some of the B' envs may have 32 | stopped at t+1, such that actually only B'' envs are still running. The 33 | step method will thus also return a B'' observation (and corresponding 34 | mapping). 
35 | 36 | The return of the step function is thus: 37 | ((DictTensor of size B', tensor of size B'), 38 | (Dicttensor of size B'', mapping vector if size B'')) 39 | """ 40 | 41 | def __init__(self): 42 | pass 43 | 44 | def reset(self, env_info: DictTensor = None): 45 | """reset the environments instances 46 | 47 | :param env_info: a DictTensor of size n_envs, such that each value will be transmitted to each environment instance 48 | :type env_info: DictTensor, optional 49 | """ 50 | pass 51 | 52 | def step( 53 | self, policy_output: DictTensor 54 | ) -> [[DictTensor, torch.Tensor], [DictTensor, torch.Tensor]]: 55 | """Execute one step over alll the running environment instances 56 | 57 | :param policy_output: the output given by the policy 58 | :type policy_output: DictTensor 59 | :return: see general description 60 | :rtype: [[DictTensor,torch.Tensor],[DictTensor,torch.Tensor]] 61 | """ 62 | raise NotImplementedError 63 | 64 | def close(self): 65 | """Terminate the environment""" 66 | raise NotImplementedError 67 | 68 | def n_envs(self) -> int: 69 | """Returns the number of environment instances contained in this env 70 | :rtype: int 71 | """ 72 | return self.reset()[0].n_elems() 73 | -------------------------------------------------------------------------------- /rlstructures/env_wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from .gymenv import GymEnv, GymEnvInf 2 | from .devicewrapper import DeviceEnv 3 | -------------------------------------------------------------------------------- /rlstructures/env_wrappers/devicewrapper.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from rlstructures.env import VecEnv 8 | from rlstructures import DictTensor 9 | import torch 10 | 11 | class DeviceEnv: 12 | def __init__(self,env,from_device,to_device): 13 | self.env=env 14 | self.from_device=from_device 15 | self.to_device=to_device 16 | self.action_space=self.env.action_space 17 | 18 | def reset(self, env_info=DictTensor({})): 19 | assert env_info.empty() or env_info.device()==torch.device("cpu"),"env_info must be on CPU" 20 | o,e=self.env.reset(env_info) 21 | return o.to(self.to_device),e.to(self.to_device) 22 | 23 | def step(self, policy_output): 24 | policy_output=policy_output.to(self.from_device) 25 | (a,b),(c,d)=self.env.step(policy_output) 26 | return (a.to(self.to_device),b.to(self.to_device)),(c.to(self.to_device),d.to(self.to_device)) 27 | 28 | def close(self): 29 | self.env.close() 30 | 31 | def n_envs(self): 32 | return self.env.n_envs() 33 | -------------------------------------------------------------------------------- /rlstructures/rl_batchers/__init__.py: -------------------------------------------------------------------------------- 1 | from .batcher import RL_Batcher 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
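# The VecEnv docstring above describes observations of size B' returned together with a
# mapping vector indicating which of the B underlying environments each row belongs to
# (e.g. [0, 2, 5]). The tiny, rlstructures-free example below only illustrates that
# bookkeeping: scattering a batch of B'=3 rows back into B=6 per-environment slots.

import torch

B = 6
mapping = torch.tensor([0, 2, 5])      # rows 0, 1, 2 of the batch belong to envs 0, 2 and 5
observation = torch.randn(3, 4)        # B'=3 running environments, observation dimension 4

full = torch.zeros(B, 4)
full[mapping] = observation            # slots 1, 3 and 4 stay at zero: those envs have stopped
print(full)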
6 | # 7 | from setuptools import setup, find_packages 8 | 9 | with open("requirements.txt") as f: 10 | reqs = [line.strip() for line in f] 11 | 12 | setup( 13 | name="rlstructures", 14 | version="1.0", 15 | python_requires=">=3.7", 16 | packages=find_packages(), 17 | install_requires=reqs, 18 | ) 19 | -------------------------------------------------------------------------------- /sphinx_docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = ../docs 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /sphinx_docs/source/algorithms/index.rst: -------------------------------------------------------------------------------- 1 | Provided Algorithms 2 | =================== 3 | 4 | We provide multiple RL algorithms as examples. 5 | 6 | 1) A2C with General Advantage Estimator 7 | 2) PPO with discrete actions 8 | 3) Double Duelling Q-Learning + Prioritized Experience Replay 9 | 3bis) A simpler DQN implementation (as an example) 10 | 4) SAC for continuous actions 11 | 5) REINFORCE 12 | 6) REINFORCE DIAYN (see https://arxiv.org/abs/1802.06070) 13 | 14 | The algorithms can be used as examples to implement your own algorithms. 15 | 16 | Typical execution is `OMP_NUM_THREADS=1 PYTHONPATH=rlstructures python rlstructures/rlalgos/reinforce/main_reinforce.py` 17 | 18 | Note that all algorithms produced a tensorboard and a CSV output (see `config["logdir"]` in the main file) 19 | -------------------------------------------------------------------------------- /sphinx_docs/source/api/index.rst: -------------------------------------------------------------------------------- 1 | RLStructures API 2 | ================ 3 | .. toctree:: 4 | :maxdepth: 1 5 | :caption: API 6 | 7 | rlstructures 8 | rlstructures.env_wrappers 9 | -------------------------------------------------------------------------------- /sphinx_docs/source/api/rlstructures.env_wrappers.rst: -------------------------------------------------------------------------------- 1 | rlstructures.env\_wrappers package 2 | ================================== 3 | 4 | OpenAI Gym Wrappers 5 | ---------------------------------------- 6 | 7 | .. automodule:: rlstructures.env_wrappers.gymenv 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /sphinx_docs/source/api/rlstructures.rst: -------------------------------------------------------------------------------- 1 | rlstructures API 2 | ==================== 3 | 4 | DictTensor, TemporalDictTensor, Trajectories 5 | -------------------------------------------- 6 | 7 | .. automodule:: rlstructures.core 8 | :members: 9 | :undoc-members: 10 | 11 | VecEnv 12 | ------ 13 | 14 | .. 
automodule:: rlstructures.env 15 | :members: 16 | :undoc-members: 17 | 18 | RL_Agent 19 | -------- 20 | 21 | .. automodule:: rlstructures.rl_batchers.agent 22 | :members: 23 | :undoc-members: 24 | 25 | RL_Batcher 26 | ---------- 27 | 28 | .. automodule:: rlstructures.rl_batchers.batcher 29 | :members: 30 | :undoc-members: 31 | -------------------------------------------------------------------------------- /sphinx_docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Configuration file for the Sphinx documentation builder. 7 | # 8 | # This file only contains a selection of the most common options. For a full 9 | # list see the documentation: 10 | # http://www.sphinx-doc.org/en/master/config 11 | 12 | # -- Path setup -------------------------------------------------------------- 13 | 14 | # If extensions (or modules to document with autodoc) are in another directory, 15 | # add these directories to sys.path here. If the directory is relative to the 16 | # documentation root, use os.path.abspath to make it absolute, like shown here. 17 | 18 | import os 19 | import sys 20 | 21 | sys.path.insert(0, os.path.abspath("../../")) 22 | 23 | 24 | # -- Project information ----------------------------------------------------- 25 | 26 | project = "RLStructures" 27 | copyright = "2021, Facebook AI Research" # pylint: disable=redefined-builtin 28 | author = "Facebook AI Research" 29 | 30 | 31 | # -- General configuration --------------------------------------------------- 32 | 33 | # Add any Sphinx extension module names here, as strings. They can be 34 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 35 | # ones. 36 | extensions = [ 37 | "autoapi.extension", 38 | "sphinx.ext.autodoc", 39 | "sphinx.ext.githubpages", 40 | "sphinx.ext.coverage", 41 | "sphinx.ext.napoleon", 42 | "sphinx.ext.autosummary", 43 | "recommonmark", 44 | "sphinx.ext.viewcode", 45 | ] 46 | 47 | 48 | autoapi_type = "python" 49 | autoapi_dirs = [".."] 50 | 51 | source_suffix = { 52 | ".rst": "restructuredtext", 53 | ".txt": "markdown", 54 | ".md": "markdown", 55 | } 56 | 57 | master_doc = "index" 58 | 59 | # Add any paths that contain templates here, relative to this directory. 60 | templates_path = [] 61 | 62 | # List of patterns, relative to source directory, that match files and 63 | # directories to ignore when looking for source files. 64 | # This pattern also affects html_static_path and html_extra_path. 65 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 66 | 67 | 68 | # -- Options for HTML output ------------------------------------------------- 69 | 70 | # The theme to use for HTML and HTML Help pages. See the documentation for 71 | # a list of builtin themes. 72 | # 73 | html_theme = "sphinx_rtd_theme" 74 | html_sidebars = { 75 | "**": ["globaltoc.html", "relations.html", "sourcelink.html", "searchbox.html"] 76 | } 77 | # Add any paths that contain custom static files (such as style sheets) here, 78 | # relative to this directory. They are copied after the builtin static files, 79 | # so a file named "default.css" will overwrite the builtin "default.css". 
80 | html_static_path = [] 81 | -------------------------------------------------------------------------------- /sphinx_docs/source/deprecated/index.rst: -------------------------------------------------------------------------------- 1 | Deprecated API (v0.1) 2 | ===================== 3 | 4 | We provide the documentation over deprecated functions. 5 | 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | :caption: Overview 10 | 11 | deprecated 12 | tutorial/ 13 | -------------------------------------------------------------------------------- /sphinx_docs/source/deprecated/tutorial/hierarchical_policy.rst: -------------------------------------------------------------------------------- 1 | Hierarchical Policies 2 | ===================== 3 | 4 | (Soon....) 5 | -------------------------------------------------------------------------------- /sphinx_docs/source/deprecated/tutorial/index.rst: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========= 3 | 4 | We propose different tutorials to learn `rlstructures`. All the tutorials are available as python files in the repository. 5 | 6 | * `Tutorial Files: ` 7 | 8 | Note that other algorithms are provided in the `rlalgos` package (PPO, DQN and SAC). 9 | 10 | .. toctree:: 11 | :maxdepth: 1 12 | :caption: Tutorials 13 | 14 | reinforce 15 | reinforce_with_evaluation 16 | a2c 17 | recurrent_policy 18 | hierarchical_policy 19 | transformer_policy 20 | -------------------------------------------------------------------------------- /sphinx_docs/source/deprecated/tutorial/reinforce_with_evaluation.rst: -------------------------------------------------------------------------------- 1 | Evaluation of RL models in other processes 2 | ========================================== 3 | 4 | * https://github.com/facebookresearch/rlstructures/tree/main/tutorial/tutorial_reinforce_with_evaluation 5 | 6 | 7 | Regarding the REINFORCE implementation, one missing aspect is a good evaluation of the policy: 8 | * the evaluation has to be done with the `deterministic` policy (while learning is made with the stochastic policy) 9 | * the evaluation over N episodes may be long, and we would like to avoid to slow down the learning 10 | 11 | To solve this issue, we will use another batcher in `asynchronous` mode. 12 | 13 | Creation of the evaluation batcher 14 | ---------------------------------- 15 | 16 | The evaluation batcher can be created like the trainig batcher (but with a different number of threads and slots) 17 | 18 | .. code-block:: python 19 | 20 | model=copy.deepcopy(self.learning_model) 21 | self.evaluation_batcher=EpisodeBatcher( 22 | n_timesteps=self.config["max_episode_steps"], 23 | n_slots=self.config["n_evaluation_episodes"], 24 | create_agent=self._create_agent, 25 | create_env=self._create_env, 26 | env_args={ 27 | "n_envs": self.config["n_envs"], 28 | "max_episode_steps": self.config["max_episode_steps"], 29 | "env_name":self.config["env_name"] 30 | }, 31 | agent_args={"n_actions": self.n_actions, "model": model}, 32 | n_threads=self.config["n_evaluation_threads"], 33 | seeds=[self.config["env_seed"]+k*10 for k in range(self.config["n_evaluation_threads"])], 34 | ) 35 | 36 | Running the evaluation batcher 37 | ------------------------------ 38 | 39 | Running the evaluation batcher is made through `execute`: 40 | 41 | .. 
code-block:: python 42 | 43 | n_episodes=self.config["n_evaluation_episodes"] 44 | agent_info=DictTensor({"stochastic":torch.tensor([False]).repeat(n_episodes)}) 45 | self.evaluation_batcher.execute(n_episodes=n_episodes,agent_info=agent_info) 46 | self.evaluation_iteration=self.iteration 47 | 48 | Note that we store the iteration at which the evaluation batcher has been executed 49 | 50 | Getting trajectories without blocking the learning 51 | -------------------------------------------------- 52 | 53 | Not we can get episodes, but in non blocking mode: the batcher will return `None` if the process of computing episodes is not finished. 54 | If the process is finished, we can 1) compute the reward 2) update the batchers models 3) relaunch the acquisition process. We thus have an evaluation process that runs without blocking the learning, and at maximum speed. 55 | 56 | .. code-block:: python 57 | 58 | evaluation_trajectories=self.evaluation_batcher.get(blocking=False) 59 | if not evaluation_trajectories is None: #trajectories are available 60 | #Compute the cumulated reward 61 | cumulated_reward=(evaluation_trajectories["_reward"]*evaluation_trajectories.mask()).sum(1).mean() 62 | self.logger.add_scalar("evaluation_reward",cumulated_reward.item(),self.evaluation_iteration) 63 | #We reexecute the evaluation batcher (with same value of agent_info and same number of episodes) 64 | self.evaluation_batcher.update(self.learning_model.state_dict()) 65 | self.evaluation_iteration=self.iteration 66 | self.evaluation_batcher.reexecute() 67 | -------------------------------------------------------------------------------- /sphinx_docs/source/deprecated/tutorial/transformer_policy.rst: -------------------------------------------------------------------------------- 1 | Transformer Policy 2 | ================== 3 | 4 | (Soon...) 5 | -------------------------------------------------------------------------------- /sphinx_docs/source/gettingstarted/BatcherExamples.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/sphinx_docs/source/gettingstarted/BatcherExamples.rst -------------------------------------------------------------------------------- /sphinx_docs/source/gettingstarted/DataStructures.rst: -------------------------------------------------------------------------------- 1 | 2 | Data Structures 3 | =============== 4 | 5 | * https://github.com/facebookresearch/rlstructures/blob/main/tutorial/tutorial_datastructures.py 6 | 7 | 8 | DictTensor 9 | ---------- 10 | 11 | A DictTensor is dictionary of pytorch tensors. It assumes that the first dimension of each tensor contained in the DictTensor is the batch dimension. The easiest way to build a DictTensor is to use a ditcionary of tensors as input 12 | 13 | .. code-block:: python 14 | 15 | from rlstructures import DictTensor 16 | import torch 17 | d=DictTensor({"x":torch.randn(3,5),"y":torch.randn(3,8)}) 18 | 19 | The number of elements in the batch is accessible through `n_elems()`: 20 | 21 | .. code-block:: python 22 | 23 | print(d.n_elems()," <- number of elements in the batch") 24 | 25 | An empty DictTensor can be defined as follows: 26 | 27 | .. code-block:: python 28 | 29 | d=DictTensor({}) 30 | 31 | Many methods can be used over DictTensor (see DictTensor documentation): 32 | 33 | .. 
34 | 35 | d["x"] # Returns the tensor 'x' in the DictTensor 36 | d.keys() # Returns the names of the variables of the DictTensor 37 | 38 | Tensors can be organized in a tree structure: 39 | 40 | .. code-block:: python 41 | d=DictTensor({}) 42 | d.set("observation/x",torch.randn(5,3)) 43 | d.set("observation/y",torch.randn(5,8,2)) 44 | d.set("agent_state/z",torch.randn(5,4)) 45 | 46 | observation=d.truncate_key("observation/") #returns a DictTensor with 'x' and 'y' 47 | print(observation) 48 | 49 | 50 | TemporalDictTensor 51 | ------------------ 52 | 53 | A `TemporalDictTensor` is a packed sequence of `DictTensors`. In memory, it is stored as a dictionary of tensors, where the first dimension is the batch dimension and the second dimension is the time index. Each element in the batch is a sequence, and two sequences can have different lengths. 54 | 55 | .. code-block:: python 56 | 57 | from rlstructures import TemporalDictTensor 58 | 59 | #Create three sequences of variables x and y, where the length of the first sequence is 6, the length of the second is 10 and the length of the last sequence is 3 60 | d=TemporalDictTensor({"x":torch.randn(3,10,5),"y":torch.randn(3,10,8)},lengths=torch.tensor([6,10,3])) 61 | 62 | print(d.n_elems()," <- number of elements in the batch") 63 | print(d.lengths,"<- Lengths of the sequences") 64 | print(d["x"].size(),"<- access to the tensor 'x'") 65 | 66 | print("Masking: ") 67 | print(d.mask()) 68 | 69 | print("Slicing (restricting the sequence to some particular temporal indexes) ") 70 | d_slice=d.temporal_slice(0,4) 71 | print(d_slice.lengths) 72 | print(d_slice.mask()) 73 | 74 | `DictTensor` and `TemporalDictTensor` can be moved to CPU/GPU using the *xxx.to(...)* method. 75 | 76 | Trajectories 77 | ------------ 78 | 79 | We recently introduced the `Trajectories` structure, a pair of one DictTensor and one TemporalDictTensor, to represent trajectories: 80 | 81 | .. code-block:: python 82 | trajectories.info #A DictTensor 83 | trajectories.trajectories #A TemporalDictTensor of transitions 84 | 85 | See the `Agent and Batcher` documentation for more details. 86 | -------------------------------------------------------------------------------- /sphinx_docs/source/gettingstarted/PlayingWithRLStructures.rst: -------------------------------------------------------------------------------- 1 | Playing with rlstructures 2 | ========================= 3 | 4 | We propose some examples of Batcher use to better understand how it works. The Python file is https://github.com/facebookresearch/rlstructures/blob/main/tutorial/playing_with_rlstructures.py 5 | 6 | 7 | Blocking / non-Blocking batcher execution 8 | ----------------------------------------- 9 | 10 | The `batcher.get` function can be executed in `batcher.get(blocking=True)` or `batcher.get(blocking=False)` modes. 11 | 12 | * In the first mode `blocking=True`, the program will wait for the batcher to end its acquisition and will return trajectories 13 | 14 | * In the second mode `blocking=False`, the batcher will return `None,None` if the acquisition is not finished. This allows other computations to be performed without waiting for the batcher to finish 15 | 16 | Replaying an agent over an acquired trajectory 17 | ---------------------------------------------- 18 | 19 | When trajectories have been acquired, the autograd graph is not available (i.e. batchers are launched in `require_grad=False` mode). 20 | It is important to be able to recompute the agent steps on these trajectories. 
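As a rough illustration of the idea, here is a minimal sketch (this is not the `replay_agent` API, just the underlying mechanism); it assumes `trajectories` is the output of `batcher.get(...)`, that observations were stored under the key `frame` as in the CartPole tutorials, and that `model` is your policy network:

.. code-block:: python

    # trajectories.trajectories is a TemporalDictTensor acquired without gradients,
    # so we re-run the model on the stored observations to rebuild an autograd graph.
    frames = trajectories.trajectories["frame"]   # shape (B, T, ...)
    mask = trajectories.trajectories.mask()       # 1.0 for valid timesteps
    B, T = frames.size(0), frames.size(1)
    action_proba = model(frames.reshape(B * T, -1)).reshape(B, T, -1)
    # action_proba now carries gradients and can be combined with the stored
    # actions and rewards (masked by `mask`) to compute a differentiable loss.
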
21 | 22 | We provide the `replay_agent` function to facilitate this `replay`. An example is given in https://github.com/facebookresearch/rlstructures/blob/main/rlalgos/reinforce 23 | 24 | Some other examples of use are given in the A2C and DQN implementations. 25 | -------------------------------------------------------------------------------- /sphinx_docs/source/gettingstarted/index.rst: -------------------------------------------------------------------------------- 1 | Getting Started 2 | =============== 3 | 4 | We will explain the main concepts of rlstructures. Each HTML file is associated with a corresponding Python file in the repository. 5 | 6 | * `Tutorial Files: ` 7 | 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | :caption: Overview 12 | 13 | DataStructures 14 | Environments 15 | RLAgentAndBatcher 16 | PlayingWithRLStructures 17 | -------------------------------------------------------------------------------- /sphinx_docs/source/index.rst: -------------------------------------------------------------------------------- 1 | rlstructures 2 | ============ 3 | 4 | TL;DR 5 | ----- 6 | `rlstructures` is a lightweight Python library that provides simple APIs as well as data structures that make as few assumptions as possible about the structure of your agent or your task, while allowing the transparent execution of multiple policies on multiple environments in parallel (incl. multiple GPUs). 7 | 8 | Important Note (Feb 2021) -- version 0.2 9 | ---------------------------------------- 10 | 11 | Based on feedback, we have made changes to the API (v0.2): 12 | 13 | * The new API is not compatible with the old one 14 | 15 | * The old API still works (but prints a deprecation message) 16 | 17 | * The v0.1 version of rlstructures is available in the v0.1 GitHub branch 18 | 19 | * We encourage users to switch to the v0.2 API, which does not require many modifications to your current code 20 | 21 | Main changes: 22 | 23 | * A single Batcher class (instead of two) 24 | 25 | * A clearer organization of the information computed by the batcher 26 | 27 | * Agents can use seeds for reproducibility 28 | 29 | * Agents and batchers can work on GPU to speed up the rollouts 30 | 31 | * A replay function has been added to allow replaying an Agent over acquired trajectories 32 | 33 | * It greatly facilitates the implementation of loss functions 34 | 35 | * More RL algorithms as examples (including PPO, SAC, REINFORCE, A2C, DQN, DIAYN) 36 | 37 | * A growing series of tutorials at https://ludovic-denoyer.medium.com/ 38 | 39 | All these changes are documented in the HTML documentation at http://facebookresearch.github.io/rlstructures 40 | 41 | Why/What? 42 | --------- 43 | RL research addresses multiple aspects of RL like hierarchical policies, option-based policies, goal-oriented policies, structured input/output spaces, transformer-based policies, etc., and there are currently few tools to handle this diversity of research projects. 44 | 45 | We propose `rlstructures` as a way to: 46 | 47 | * Simulate multiple policies, multiple models and multiple environments simultaneously at scale 48 | 49 | * Define complex loss functions 50 | 51 | * Quickly implement various policy architectures. 52 | 53 | The main RLStructures principle is that the user delegates the sampling of trajectories and episodes to the library so they can spend most of their time on the interesting part of RL research: developing new models and algorithms. 
54 | 55 | `rlstructures` is easy to use: it has a few simple interfaces that can be learned in one hour by reading the tutorials. 56 | 57 | It comes with multiple RL algorithms as examples, including A2C, PPO, DDQN and SAC. 58 | 59 | Please reach out to us if you intend to use it. We will be happy to help, and potentially to implement missing functionalities. 60 | 61 | Targeted users 62 | -------------- 63 | 64 | RLStructures comes with a set of implemented RL algorithms, but rlstructures does not aim at being a repository of benchmarked RL algorithms (other RL libraries do that very well). If your objective is to apply state-of-the-art methods to particular environments, then rlstructures is not the best fit. If your objective is to implement new algorithms, then rlstructures is a good fit. 65 | 66 | Where? 67 | ------ 68 | 69 | * Github: http://github.com/facebookresearch/rlstructures 70 | * Tutorials: https://ludovic-denoyer.medium.com/ 71 | * Discussion Group: https://www.facebook.com/groups/834804787067021 72 | 73 | .. toctree:: 74 | :maxdepth: 1 75 | :caption: Getting Started 76 | 77 | overview 78 | gettingstarted/index 79 | algorithms/index 80 | api/index 81 | foireaq/foireaq.rst 82 | migrating_v0.1_v0.2 83 | deprecated/index.rst 84 | -------------------------------------------------------------------------------- /sphinx_docs/source/migrating_v0.1_v0.2.rst: -------------------------------------------------------------------------------- 1 | rlstructures -- migrating from v0.1 to v0.2 2 | =========================================== 3 | 4 | Version 0.2 of rlstructures has some critical changes: 5 | 6 | From Agent to RL_Agent 7 | ---------------------- 8 | 9 | Policies are now implemented through the RL_Agent class. The two differences are: 10 | 11 | * The RL_Agent class has an `initial_state` method that initializes the state of the agent at reset time (i.e. when you call Batcher.reset). This saves you from handling the state initialization in the `__call__` function. 
12 | 13 | * The RL_Agent does not return its `old state` anymore, and just provides `agent_do` and `new_state` as output 14 | 15 | From EpisodeBatcher/Batcher to RL_Batcher 16 | ----------------------------------------- 17 | 18 | RL_Batcher is the batcher class that works with RL_Agent: 19 | 20 | * At construction time: 21 | 22 | * There is no need to specify the `n_slots` argument anymore 23 | 24 | * One has to provide examples (with n_elems()==1) of `agent_info` and `env_info` that will be sent to the batcher at construction time 25 | 26 | * You can specify the device of the batcher (default is CPU -- see the CPU/GPU tutorial) 27 | 28 | * At use time: 29 | 30 | * Only three functions are available: `reset`, `execute` and `get` 31 | 32 | * Outputs: 33 | 34 | * The RL_Batcher now outputs a `Trajectories` object composed of `trajectories.info:DictTensor` and `trajectories.trajectories:TemporalDictTensor` 35 | 36 | * `trajectories.info` contains information that is fixed during the trajectory: agent_info, env_info and the initial agent state 37 | 38 | * `trajectories.trajectories` contains information generated by the environment (observations), as well as the actions produced by the Agent 39 | 40 | Replay functions 41 | ---------------- 42 | 43 | We now propose a `replay_agent` function that allows you to easily replay an agent over trajectories (e.g. for loss computation) 44 | -------------------------------------------------------------------------------- /tutorial/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/tutorial/__init__.py -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_a2c_with_infinite_env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/tutorial/deprecated/tutorial_a2c_with_infinite_env/__init__.py -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_a2c_with_infinite_env/main_a2c.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | 9 | from rlstructures import logging 10 | from rlstructures.env_wrappers import GymEnv, GymEnvInf 11 | from rlstructures.tools import weight_init 12 | import torch.nn as nn 13 | import copy 14 | import torch 15 | import time 16 | import numpy as np 17 | import torch.nn.functional as F 18 | from tutorial.tutorial_reinforce.agent import ReinforceAgent 19 | from tutorial.tutorial_a2c_with_infinite_env.a2c import A2C 20 | import gym 21 | from gym.wrappers import TimeLimit 22 | 23 | 24 | # We write the 'create_env' and 'create_agent' function in the main file to allow these functions to be used with pickle when creating the batcher processes 25 | def create_gym_env(env_name): 26 | return gym.make(env_name) 27 | 28 | 29 | def create_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 30 | envs = [] 31 | for k in range(n_envs): 32 | e = create_gym_env(env_name) 33 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 34 | envs.append(e) 35 | return GymEnv(envs, seed) 36 | 37 | 38 | def create_train_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 39 | envs = [] 40 | for k in range(n_envs): 41 | e = create_gym_env(env_name) 42 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 43 | envs.append(e) 44 | return GymEnvInf(envs, seed) 45 | 46 | 47 | def create_agent(model, n_actions=1): 48 | return ReinforceAgent(model=model, n_actions=n_actions) 49 | 50 | 51 | class Experiment(A2C): 52 | def __init__(self, config, create_env, create_train_env, create_agent): 53 | super().__init__(config, create_env, create_train_env, create_agent) 54 | 55 | 56 | if __name__ == "__main__": 57 | # We use spawn mode such that most of the environment will run in multiple processes 58 | import torch.multiprocessing as mp 59 | 60 | mp.set_start_method("spawn") 61 | 62 | config = { 63 | "env_name": "CartPole-v0", 64 | "a2c_timesteps": 20, 65 | "n_envs": 4, 66 | "max_episode_steps": 100, 67 | "env_seed": 42, 68 | "n_threads": 4, 69 | "n_evaluation_threads": 2, 70 | "n_evaluation_episodes": 256, 71 | "time_limit": 3600, 72 | "lr": 0.01, 73 | "discount_factor": 0.95, 74 | "critic_coef": 1.0, 75 | "entropy_coef": 0.01, 76 | "a2c_coef": 1.0, 77 | "logdir": "./results", 78 | } 79 | exp = Experiment(config, create_env, create_train_env, create_agent) 80 | exp.run() 81 | -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_from_reinforce_to_a2c/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/tutorial/deprecated/tutorial_from_reinforce_to_a2c/__init__.py -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_from_reinforce_to_a2c/main_a2c.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | 9 | from rlstructures import logging 10 | from rlstructures.env_wrappers import GymEnv 11 | from rlstructures.tools import weight_init 12 | import torch.nn as nn 13 | import copy 14 | import torch 15 | import time 16 | import numpy as np 17 | import torch.nn.functional as F 18 | from tutorial.tutorial_reinforce.agent import ReinforceAgent 19 | from tutorial.tutorial_from_reinforce_to_a2c.a2c import A2C 20 | import gym 21 | from gym.wrappers import TimeLimit 22 | 23 | 24 | # We write the 'create_env' and 'create_agent' function in the main file to allow these functions to be used with pickle when creating the batcher processes 25 | def create_gym_env(env_name): 26 | return gym.make(env_name) 27 | 28 | 29 | def create_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 30 | envs = [] 31 | for k in range(n_envs): 32 | e = create_gym_env(env_name) 33 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 34 | envs.append(e) 35 | return GymEnv(envs, seed) 36 | 37 | 38 | def create_agent(model, n_actions=1): 39 | return ReinforceAgent(model=model, n_actions=n_actions) 40 | 41 | 42 | class Experiment(A2C): 43 | def __init__(self, config, create_env, create_agent): 44 | super().__init__(config, create_env, create_agent) 45 | 46 | 47 | if __name__ == "__main__": 48 | # We use spawn mode such that most of the environment will run in multiple processes 49 | import torch.multiprocessing as mp 50 | 51 | mp.set_start_method("spawn") 52 | 53 | config = { 54 | "env_name": "CartPole-v0", 55 | "a2c_timesteps": 20, 56 | "n_envs": 4, 57 | "max_episode_steps": 100, 58 | "env_seed": 42, 59 | "n_threads": 4, 60 | "n_evaluation_threads": 2, 61 | "n_evaluation_episodes": 256, 62 | "time_limit": 3600, 63 | "lr": 0.01, 64 | "discount_factor": 0.95, 65 | "critic_coef": 1.0, 66 | "entropy_coef": 0.01, 67 | "a2c_coef": 1.0, 68 | "logdir": "./results", 69 | } 70 | exp = Experiment(config, create_env, create_agent) 71 | exp.run() 72 | -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_from_reinforce_to_a2c_s/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/tutorial/deprecated/tutorial_from_reinforce_to_a2c_s/__init__.py -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_from_reinforce_to_a2c_s/main_a2c.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | 9 | from rlstructures import logging 10 | from rlstructures.env_wrappers import GymEnv 11 | from rlstructures.tools import weight_init 12 | import torch.nn as nn 13 | import copy 14 | import torch 15 | import time 16 | import numpy as np 17 | import torch.nn.functional as F 18 | from tutorial.tutorial_from_reinforce_to_a2c_s.agent import ReinforceAgent 19 | from tutorial.tutorial_from_reinforce_to_a2c_s.a2c import A2C 20 | import gym 21 | from gym.wrappers import TimeLimit 22 | 23 | 24 | # We write the 'create_env' and 'create_agent' function in the main file to allow these functions to be used with pickle when creating the batcher processes 25 | def create_gym_env(env_name): 26 | return gym.make(env_name) 27 | 28 | 29 | def create_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 30 | envs = [] 31 | for k in range(n_envs): 32 | e = create_gym_env(env_name) 33 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 34 | envs.append(e) 35 | return GymEnv(envs, seed) 36 | 37 | 38 | def create_agent(model, n_actions=1): 39 | return ReinforceAgent(model=model, n_actions=n_actions) 40 | 41 | 42 | class Experiment(A2C): 43 | def __init__(self, config, create_env, create_agent): 44 | super().__init__(config, create_env, create_agent) 45 | 46 | 47 | if __name__ == "__main__": 48 | # We use spawn mode such that most of the environment will run in multiple processes 49 | import torch.multiprocessing as mp 50 | 51 | mp.set_start_method("spawn") 52 | 53 | config = { 54 | "env_name": "CartPole-v0", 55 | "a2c_timesteps": 20, 56 | "n_envs": 4, 57 | "max_episode_steps": 100, 58 | "env_seed": 42, 59 | "n_threads": 4, 60 | "n_evaluation_threads": 2, 61 | "n_evaluation_episodes": 256, 62 | "time_limit": 3600, 63 | "lr": 0.01, 64 | "discount_factor": 0.95, 65 | "critic_coef": 1.0, 66 | "entropy_coef": 0.01, 67 | "a2c_coef": 1.0, 68 | "logdir": "./results", 69 | } 70 | exp = Experiment(config, create_env, create_agent) 71 | exp.run() 72 | -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_recurrent_a2c_gae_s/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/tutorial/deprecated/tutorial_recurrent_a2c_gae_s/__init__.py -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_recurrent_a2c_gae_s/main_a2c.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | 9 | from rlstructures import logging 10 | from rlstructures.env_wrappers import GymEnv, GymEnvInf 11 | from rlstructures.tools import weight_init 12 | import torch.nn as nn 13 | import copy 14 | import torch 15 | import time 16 | import numpy as np 17 | import torch.nn.functional as F 18 | from tutorial.tutorial_recurrent_a2c_s.agent import RecurrentAgent 19 | from tutorial.tutorial_recurrent_a2c_gae_s.a2c import A2C 20 | import gym 21 | from gym.wrappers import TimeLimit 22 | from gym import ObservationWrapper 23 | 24 | 25 | class MyWrapper(ObservationWrapper): 26 | """Observation wrapper that flattens the observation.""" 27 | 28 | def __init__(self, env): 29 | super(MyWrapper, self).__init__(env) 30 | self.observation_space = None # spaces.flatten_space(env.observation_space) 31 | 32 | def observation(self, observation): 33 | return [observation[0], observation[2]] 34 | 35 | 36 | # We write the 'create_env' and 'create_agent' function in the main file to allow these functions to be used with pickle when creating the batcher processes 37 | def create_gym_env(env_name): 38 | return gym.make(env_name) 39 | 40 | 41 | def create_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 42 | envs = [] 43 | for k in range(n_envs): 44 | e = create_gym_env(env_name) 45 | e = MyWrapper(e) 46 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 47 | envs.append(e) 48 | return GymEnv(envs, seed) 49 | 50 | 51 | def create_train_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 52 | envs = [] 53 | for k in range(n_envs): 54 | e = create_gym_env(env_name) 55 | e = MyWrapper(e) 56 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 57 | envs.append(e) 58 | return GymEnvInf(envs, seed) 59 | 60 | 61 | def create_agent(model, n_actions=1): 62 | return RecurrentAgent(model=model, n_actions=n_actions) 63 | 64 | 65 | class Experiment(A2C): 66 | def __init__(self, config, create_train_env, create_env, create_agent): 67 | super().__init__(config, create_train_env, create_env, create_agent) 68 | 69 | 70 | if __name__ == "__main__": 71 | # We use spawn mode such that most of the environment will run in multiple processes 72 | import torch.multiprocessing as mp 73 | 74 | mp.set_start_method("spawn") 75 | 76 | config = { 77 | "env_name": "CartPole-v0", 78 | "a2c_timesteps": 10, 79 | "n_envs": 4, 80 | "max_episode_steps": 100, 81 | "env_seed": 42, 82 | "n_threads": 4, 83 | "n_evaluation_threads": 2, 84 | "n_evaluation_episodes": 256, 85 | "time_limit": 3600, 86 | "lr": 0.001, 87 | "discount_factor": 0.95, 88 | "critic_coef": 1.0, 89 | "entropy_coef": 0.001, 90 | "a2c_coef": 0.1, 91 | "gae_coef": 0.3, 92 | "logdir": "./results", 93 | "clip_grad": 1, 94 | } 95 | exp = Experiment(config, create_train_env, create_env, create_agent) 96 | exp.run() 97 | -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_recurrent_a2c_s/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/tutorial/deprecated/tutorial_recurrent_a2c_s/__init__.py -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_recurrent_a2c_s/main_a2c.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | 9 | from rlstructures import logging 10 | from rlstructures.env_wrappers import GymEnv, GymEnvInf 11 | from rlstructures.tools import weight_init 12 | import torch.nn as nn 13 | import copy 14 | import torch 15 | import time 16 | import numpy as np 17 | import torch.nn.functional as F 18 | from tutorial.tutorial_recurrent_a2c_s.agent import RecurrentAgent 19 | from tutorial.tutorial_recurrent_a2c_s.a2c import A2C 20 | import gym 21 | from gym.wrappers import TimeLimit 22 | from gym import ObservationWrapper 23 | 24 | 25 | class MyWrapper(ObservationWrapper): 26 | """Observation wrapper that flattens the observation.""" 27 | 28 | def __init__(self, env): 29 | super(MyWrapper, self).__init__(env) 30 | self.observation_space = None # spaces.flatten_space(env.observation_space) 31 | 32 | def observation(self, observation): 33 | return [ 34 | observation[0] * 100, 35 | observation[1] * 100, 36 | observation[2] * 100, 37 | observation[3] * 100, 38 | ] 39 | # return [observation[0],observation[2]] 40 | 41 | 42 | # We write the 'create_env' and 'create_agent' function in the main file to allow these functions to be used with pickle when creating the batcher processes 43 | def create_gym_env(env_name): 44 | return gym.make(env_name) 45 | 46 | 47 | def create_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 48 | envs = [] 49 | for k in range(n_envs): 50 | e = create_gym_env(env_name) 51 | # e = MyWrapper(e) 52 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 53 | envs.append(e) 54 | return GymEnv(envs, seed) 55 | 56 | 57 | def create_train_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 58 | envs = [] 59 | for k in range(n_envs): 60 | e = create_gym_env(env_name) 61 | # e = MyWrapper(e) 62 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 63 | envs.append(e) 64 | return GymEnvInf(envs, seed) 65 | 66 | 67 | def create_agent(model, n_actions=1): 68 | return RecurrentAgent(model=model, n_actions=n_actions) 69 | 70 | 71 | class Experiment(A2C): 72 | def __init__(self, config, create_train_env, create_env, create_agent): 73 | super().__init__(config, create_train_env, create_env, create_agent) 74 | 75 | 76 | if __name__ == "__main__": 77 | # We use spawn mode such that most of the environment will run in multiple processes 78 | import torch.multiprocessing as mp 79 | 80 | mp.set_start_method("spawn") 81 | 82 | config = { 83 | "env_name": "CartPole-v0", 84 | "a2c_timesteps": 20, 85 | "n_envs": 4, 86 | "max_episode_steps": 100, 87 | "env_seed": 42, 88 | "n_threads": 4, 89 | "n_evaluation_threads": 2, 90 | "n_evaluation_episodes": 256, 91 | "time_limit": 3600, 92 | "lr": 0.001, 93 | "discount_factor": 0.95, 94 | "critic_coef": 1.0, 95 | "entropy_coef": 0.001, 96 | "a2c_coef": 1.0, 97 | "logdir": "./results", 98 | "clip_grad": 2, 99 | } 100 | exp = Experiment(config, create_train_env, create_env, create_agent) 101 | exp.run() 102 | -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_recurrent_a2c_s/test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | import torch 8 | import torch.nn as nn 9 | 10 | l = nn.Sequential(nn.Linear(64, 32)) 11 | 12 | while True: 13 | a = torch.randn(1024, 64, dtype=torch.float32) 14 | y = l(a) 15 | yy = l(a[:10]) 16 | print((y[:10] - yy).sum()) 17 | -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_recurrent_policy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/tutorial/deprecated/tutorial_recurrent_policy/__init__.py -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_recurrent_policy/main_a2c.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | 9 | from rlstructures import logging 10 | from rlstructures.env_wrappers import GymEnv, GymEnvInf 11 | from rlstructures.tools import weight_init 12 | import torch.nn as nn 13 | import copy 14 | import torch 15 | import time 16 | import numpy as np 17 | import torch.nn.functional as F 18 | from tutorial.tutorial_recurrent_policy.agent import RecurrentAgent 19 | from tutorial.tutorial_recurrent_policy.a2c import A2C 20 | import gym 21 | from gym.wrappers import TimeLimit 22 | 23 | # We write the 'create_env' and 'create_agent' function in the main file to allow these functions to be used with pickle when creating the batcher processes 24 | def create_gym_env(env_name): 25 | return gym.make(env_name) 26 | 27 | 28 | def create_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 29 | envs = [] 30 | for k in range(n_envs): 31 | e = create_gym_env(env_name) 32 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 33 | envs.append(e) 34 | return GymEnv(envs, seed) 35 | 36 | 37 | def create_train_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 38 | envs = [] 39 | for k in range(n_envs): 40 | e = create_gym_env(env_name) 41 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 42 | envs.append(e) 43 | return GymEnvInf(envs, seed) 44 | 45 | 46 | def create_agent(model, n_actions=1): 47 | return RecurrentAgent(model=model, n_actions=n_actions) 48 | 49 | 50 | class Experiment(A2C): 51 | def __init__(self, config, create_env, create_train_env, create_agent): 52 | super().__init__(config, create_env, create_train_env, create_agent) 53 | 54 | 55 | if __name__ == "__main__": 56 | # We use spawn mode such that most of the environment will run in multiple processes 57 | import torch.multiprocessing as mp 58 | 59 | mp.set_start_method("spawn") 60 | 61 | config = { 62 | "env_name": "CartPole-v0", 63 | "a2c_timesteps": 3, 64 | "n_envs": 4, 65 | "max_episode_steps": 100, 66 | "env_seed": 42, 67 | "n_threads": 4, 68 | "n_evaluation_threads": 2, 69 | "n_evaluation_episodes": 256, 70 | "time_limit": 3600, 71 | "lr": 0.001, 72 | "discount_factor": 0.95, 73 | "critic_coef": 1.0, 74 | "entropy_coef": 0.01, 75 | "a2c_coef": 1.0, 76 | "logdir": "./results", 77 | } 78 | exp = Experiment(config, create_env, create_train_env, create_agent) 79 | exp.run() 80 | -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_reinforce/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/tutorial/deprecated/tutorial_reinforce/__init__.py -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_reinforce/agent.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | # import rlstructures.logging as logging 13 | from rlstructures import DictTensor 14 | from rlstructures import Agent 15 | import time 16 | from rlstructures.dicttensor import masked_tensor, masked_dicttensor 17 | 18 | 19 | class ReinforceAgent(Agent): 20 | def __init__(self, model=None, n_actions=None): 21 | super().__init__() 22 | self.model = model 23 | self.n_actions = n_actions 24 | 25 | def update(self, state_dict): 26 | self.model.load_state_dict(state_dict) 27 | 28 | def __call__(self, state, observation, agent_info=None, history=None): 29 | """ 30 | Executing one step of the agent 31 | """ 32 | # Verify that the batch size is 1 33 | initial_state = observation["initial_state"] 34 | B = observation.n_elems() 35 | 36 | if agent_info is None: 37 | agent_info = DictTensor({"stochastic": torch.tensor([True]).repeat(B)}) 38 | 39 | # We will store the agent step in the trajectories to illustrate how information can be propagated among multiple timesteps 40 | zero_step = DictTensor({"agent_step": torch.zeros(B).long()}) 41 | if state is None: 42 | # if state is None, it means that the agent does not have any internal state. 
The internal state thus has to be initialized 43 | state = zero_step 44 | else: 45 | # We initialize the agent_step only for trajectory where an initial_state is met 46 | state = masked_dicttensor(state, zero_step, observation["initial_state"]) 47 | # We compute one score per possible action 48 | action_proba = self.model(observation["frame"]) 49 | 50 | # We sample an action following the distribution 51 | dist = torch.distributions.Categorical(action_proba) 52 | action_sampled = dist.sample() 53 | 54 | # Depending on the agent_info variable that tells us if we are in 'stochastic' or 'deterministic' mode, we keep the sampled action, or compute the action with the max score 55 | action_max = action_proba.max(1)[1] 56 | smask = agent_info["stochastic"].float() 57 | action = masked_tensor(action_max, action_sampled, agent_info["stochastic"]) 58 | 59 | new_state = DictTensor({"agent_step": state["agent_step"] + 1}) 60 | 61 | agent_do = DictTensor({"action": action, "action_probabilities": action_proba}) 62 | 63 | return state, agent_do, new_state 64 | 65 | 66 | class AgentModel(nn.Module): 67 | """The model that computes one score per action""" 68 | 69 | def __init__(self, n_observations, n_actions, n_hidden): 70 | super().__init__() 71 | self.linear = nn.Linear(n_observations, n_hidden) 72 | self.linear2 = nn.Linear(n_hidden, n_actions) 73 | 74 | def forward(self, frame): 75 | z = torch.tanh(self.linear(frame)) 76 | score_actions = self.linear2(z) 77 | probabilities_actions = torch.softmax(score_actions, dim=-1) 78 | return probabilities_actions 79 | 80 | 81 | class BaselineModel(nn.Module): 82 | """The model that computes V(s)""" 83 | 84 | def __init__(self, n_observations, n_hidden): 85 | super().__init__() 86 | self.linear = nn.Linear(n_observations, n_hidden) 87 | self.linear2 = nn.Linear(n_hidden, 1) 88 | 89 | def forward(self, frame): 90 | z = torch.tanh(self.linear(frame)) 91 | critic = self.linear2(z) 92 | return critic 93 | -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_reinforce/main_reinforce.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | 9 | from rlstructures import logging 10 | from rlstructures.env_wrappers import GymEnv 11 | from rlstructures.tools import weight_init 12 | import torch.nn as nn 13 | import copy 14 | import torch 15 | import time 16 | import numpy as np 17 | import torch.nn.functional as F 18 | from tutorial.tutorial_reinforce.agent import ReinforceAgent 19 | from tutorial.tutorial_reinforce.reinforce import Reinforce 20 | import gym 21 | from gym.wrappers import TimeLimit 22 | 23 | 24 | # We write the 'create_env' and 'create_agent' function in the main file to allow these functions to be used with pickle when creating the batcher processes 25 | def create_gym_env(env_name): 26 | return gym.make(env_name) 27 | 28 | 29 | def create_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 30 | envs = [] 31 | for k in range(n_envs): 32 | e = create_gym_env(env_name) 33 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 34 | envs.append(e) 35 | return GymEnv(envs, seed) 36 | 37 | 38 | def create_agent(model, n_actions=1): 39 | return ReinforceAgent(model=model, n_actions=n_actions) 40 | 41 | 42 | class Experiment(Reinforce): 43 | def __init__(self, config, create_env, create_agent): 44 | super().__init__(config, create_env, create_agent) 45 | 46 | 47 | if __name__ == "__main__": 48 | # We use spawn mode such that most of the environment will run in multiple processes 49 | import torch.multiprocessing as mp 50 | 51 | mp.set_start_method("spawn") 52 | 53 | config = { 54 | "env_name": "CartPole-v0", 55 | "n_envs": 4, 56 | "max_episode_steps": 100, 57 | "env_seed": 42, 58 | "n_threads": 4, 59 | "time_limit": 3600, 60 | "lr": 0.001, 61 | "discount_factor": 0.9, 62 | "baseline_coef": 0.1, 63 | "entropy_coef": 0.01, 64 | "reinforce_coef": 1.0, 65 | "logdir": "./results", 66 | } 67 | exp = Experiment(config, create_env, create_agent) 68 | exp.run() 69 | -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_reinforce_s/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/tutorial/deprecated/tutorial_reinforce_s/__init__.py -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_reinforce_s/agent.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | from rlstructures import DictTensor 12 | from rlstructures import S_Agent 13 | import time 14 | from rlstructures.dicttensor import masked_tensor, masked_dicttensor 15 | 16 | 17 | class ReinforceAgent(S_Agent): 18 | def __init__(self, model=None, n_actions=None): 19 | super().__init__() 20 | self.model = model 21 | self.n_actions = n_actions 22 | 23 | def update(self, state_dict): 24 | self.model.load_state_dict(state_dict) 25 | 26 | def initial_state(selfl, agent_info, B): 27 | return DictTensor({}) 28 | 29 | def __call__(self, state, observation, agent_info=None, history=None): 30 | """ 31 | Executing one step of the agent 32 | """ 33 | # Verify that the batch size is 1 34 | B = observation.n_elems() 35 | # We will store the agent step in the trajectories to illustrate how information can be propagated among multiple timesteps 36 | # We compute one score per possible action 37 | action_proba = self.model.action_model(observation["frame"]) 38 | baseline = self.model.baseline_model(observation["frame"]) 39 | # We sample an action following the distribution 40 | dist = torch.distributions.Categorical(action_proba) 41 | action_sampled = dist.sample() 42 | 43 | # Depending on the agent_info variable that tells us if we are in 'stochastic' or 'deterministic' mode, we keep the sampled action, or compute the action with the max score 44 | action_max = action_proba.max(1)[1] 45 | smask = agent_info["stochastic"].float() 46 | action = masked_tensor(action_max, action_sampled, agent_info["stochastic"]) 47 | 48 | new_state = DictTensor({}) 49 | 50 | agent_do = DictTensor( 51 | { 52 | "action": action, 53 | "action_probabilities": action_proba, 54 | "baseline": baseline, 55 | } 56 | ) 57 | 58 | return agent_do, new_state 59 | 60 | 61 | class Model(nn.Module): 62 | def __init__(self, action_model, baseline_model): 63 | super().__init__() 64 | self.action_model = action_model 65 | self.baseline_model = baseline_model 66 | 67 | 68 | class ActionModel(nn.Module): 69 | """The model that computes one score per action""" 70 | 71 | def __init__(self, n_observations, n_actions, n_hidden): 72 | super().__init__() 73 | self.linear = nn.Linear(n_observations, n_hidden) 74 | self.linear2 = nn.Linear(n_hidden, n_actions) 75 | 76 | def forward(self, frame): 77 | z = torch.tanh(self.linear(frame)) 78 | score_actions = self.linear2(z) 79 | probabilities_actions = torch.softmax(score_actions, dim=-1) 80 | return probabilities_actions 81 | 82 | 83 | class BaselineModel(nn.Module): 84 | """The model that computes V(s)""" 85 | 86 | def __init__(self, n_observations, n_hidden): 87 | super().__init__() 88 | self.linear = nn.Linear(n_observations, n_hidden) 89 | self.linear2 = nn.Linear(n_hidden, 1) 90 | 91 | def forward(self, frame): 92 | z = torch.tanh(self.linear(frame)) 93 | critic = self.linear2(z) 94 | return critic 95 | -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_reinforce_s/main_reinforce.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | 9 | from rlstructures import logging, DictTensor 10 | from rlstructures.env_wrappers import GymEnv 11 | from rlstructures.tools import weight_init 12 | import torch.nn as nn 13 | import copy 14 | import torch 15 | import time 16 | import numpy as np 17 | import torch.nn.functional as F 18 | from tutorial.tutorial_reinforce_s.agent import ReinforceAgent 19 | from tutorial.tutorial_reinforce_s.reinforce import Reinforce 20 | import gym 21 | from gym.wrappers import TimeLimit 22 | 23 | 24 | # We write the 'create_env' and 'create_agent' function in the main file to allow these functions to be used with pickle when creating the batcher processes 25 | def create_gym_env(env_name): 26 | return gym.make(env_name) 27 | 28 | 29 | def create_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 30 | envs = [] 31 | for k in range(n_envs): 32 | e = create_gym_env(env_name) 33 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 34 | envs.append(e) 35 | gym_env = GymEnv( 36 | envs, seed 37 | ) # ,default_env_info=DictTensor({"test":torch.ones(n_envs)})) 38 | return gym_env 39 | 40 | 41 | def create_agent(model, n_actions=1): 42 | return ReinforceAgent(model=model, n_actions=n_actions) 43 | 44 | 45 | class Experiment(Reinforce): 46 | def __init__(self, config, create_env, create_agent): 47 | super().__init__(config, create_env, create_agent) 48 | 49 | 50 | if __name__ == "__main__": 51 | # We use spawn mode such that most of the environment will run in multiple processes 52 | import torch.multiprocessing as mp 53 | 54 | mp.set_start_method("spawn") 55 | 56 | config = { 57 | "env_name": "CartPole-v0", 58 | "n_envs": 4, 59 | "max_episode_steps": 100, 60 | "env_seed": 42, 61 | "n_threads": 4, 62 | "time_limit": 3600, 63 | "lr": 0.01, 64 | "discount_factor": 0.9, 65 | "baseline_coef": 0.1, 66 | "entropy_coef": 0.01, 67 | "reinforce_coef": 1.0, 68 | "logdir": "./results", 69 | } 70 | exp = Experiment(config, create_env, create_agent) 71 | exp.run() 72 | -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_reinforce_with_evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/tutorial/deprecated/tutorial_reinforce_with_evaluation/__init__.py -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_reinforce_with_evaluation/main_reinforce.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | 9 | from rlstructures import logging 10 | from rlstructures.env_wrappers import GymEnv 11 | from rlstructures.tools import weight_init 12 | import torch.nn as nn 13 | import copy 14 | import torch 15 | import time 16 | import numpy as np 17 | import torch.nn.functional as F 18 | from tutorial.tutorial_reinforce.agent import ReinforceAgent 19 | from tutorial.tutorial_reinforce_with_evaluation.reinforce import Reinforce 20 | import gym 21 | from gym.wrappers import TimeLimit 22 | 23 | 24 | # We write the 'create_env' and 'create_agent' function in the main file to allow these functions to be used with pickle when creating the batcher processes 25 | def create_gym_env(env_name): 26 | return gym.make(env_name) 27 | 28 | 29 | def create_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 30 | envs = [] 31 | for k in range(n_envs): 32 | e = create_gym_env(env_name) 33 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 34 | envs.append(e) 35 | return GymEnv(envs, seed) 36 | 37 | 38 | def create_agent(model, n_actions=1): 39 | return ReinforceAgent(model=model, n_actions=n_actions) 40 | 41 | 42 | class Experiment(Reinforce): 43 | def __init__(self, config, create_env, create_agent): 44 | super().__init__(config, create_env, create_agent) 45 | 46 | 47 | if __name__ == "__main__": 48 | # We use spawn mode such that most of the environment will run in multiple processes 49 | import torch.multiprocessing as mp 50 | 51 | mp.set_start_method("spawn") 52 | 53 | config = { 54 | "env_name": "CartPole-v0", 55 | "n_envs": 4, 56 | "max_episode_steps": 100, 57 | "env_seed": 42, 58 | "n_threads": 4, 59 | "n_evaluation_threads": 2, 60 | "n_evaluation_episodes": 256, 61 | "time_limit": 3600, 62 | "lr": 0.01, 63 | "discount_factor": 0.9, 64 | "baseline_coef": 0.1, 65 | "entropy_coef": 0.01, 66 | "reinforce_coef": 1.0, 67 | "logdir": "./results", 68 | } 69 | exp = Experiment(config, create_env, create_agent) 70 | exp.run() 71 | -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_reinforce_with_evaluation_s/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rlstructures/b5e824fed7a5f347a2aada2e94bd4f7c69487793/tutorial/deprecated/tutorial_reinforce_with_evaluation_s/__init__.py -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_reinforce_with_evaluation_s/agent.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | from rlstructures import DictTensor, masked_tensor, masked_dicttensor 12 | from rlstructures import Agent 13 | import time 14 | 15 | 16 | class ReinforceAgent(Agent): 17 | def __init__(self, model=None, n_actions=None): 18 | super().__init__() 19 | self.model = model 20 | self.n_actions = n_actions 21 | 22 | def update(self, state_dict): 23 | self.model.load_state_dict(state_dict) 24 | 25 | def initial_state(selfl, agent_info, B): 26 | return DictTensor({}) 27 | 28 | def __call__(self, state, observation, agent_info=None, history=None): 29 | """ 30 | Executing one step of the agent 31 | """ 32 | # Verify that the batch size is 1 33 | B = observation.n_elems() 34 | # We will store the agent step in the trajectories to illustrate how information can be propagated among multiple timesteps 35 | # We compute one score per possible action 36 | action_proba = self.model.action_model(observation["frame"]) 37 | baseline = self.model.baseline_model(observation["frame"]) 38 | # We sample an action following the distribution 39 | dist = torch.distributions.Categorical(action_proba) 40 | action_sampled = dist.sample() 41 | 42 | # Depending on the agent_info variable that tells us if we are in 'stochastic' or 'deterministic' mode, we keep the sampled action, or compute the action with the max score 43 | action_max = action_proba.max(1)[1] 44 | smask = agent_info["stochastic"].float() 45 | action = masked_tensor(action_max, action_sampled, agent_info["stochastic"]) 46 | 47 | new_state = DictTensor({}) 48 | 49 | agent_do = DictTensor( 50 | { 51 | "action": action, 52 | "action_probabilities": action_proba, 53 | "baseline": baseline, 54 | } 55 | ) 56 | 57 | return agent_do, new_state 58 | 59 | 60 | class Model(nn.Module): 61 | def __init__(self, action_model, baseline_model): 62 | super().__init__() 63 | self.action_model = action_model 64 | self.baseline_model = baseline_model 65 | 66 | 67 | class ActionModel(nn.Module): 68 | """The model that computes one score per action""" 69 | 70 | def __init__(self, n_observations, n_actions, n_hidden): 71 | super().__init__() 72 | self.linear = nn.Linear(n_observations, n_hidden) 73 | self.linear2 = nn.Linear(n_hidden, n_actions) 74 | 75 | def forward(self, frame): 76 | z = torch.tanh(self.linear(frame)) 77 | score_actions = self.linear2(z) 78 | probabilities_actions = torch.softmax(score_actions, dim=-1) 79 | return probabilities_actions 80 | 81 | 82 | class BaselineModel(nn.Module): 83 | """The model that computes V(s)""" 84 | 85 | def __init__(self, n_observations, n_hidden): 86 | super().__init__() 87 | self.linear = nn.Linear(n_observations, n_hidden) 88 | self.linear2 = nn.Linear(n_hidden, 1) 89 | 90 | def forward(self, frame): 91 | z = torch.tanh(self.linear(frame)) 92 | critic = self.linear2(z) 93 | return critic 94 | -------------------------------------------------------------------------------- /tutorial/deprecated/tutorial_reinforce_with_evaluation_s/main_reinforce.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | 9 | from rlstructures.env_wrappers import GymEnv 10 | from rlalgos.tools import weight_init 11 | import torch.nn as nn 12 | import copy 13 | import torch 14 | import time 15 | import numpy as np 16 | import torch.nn.functional as F 17 | from tutorial.tutorial_reinforce_with_evaluation_s.agent import ReinforceAgent 18 | from tutorial.tutorial_reinforce_with_evaluation_s.reinforce import Reinforce 19 | import gym 20 | from gym.wrappers import TimeLimit 21 | 22 | 23 | # We write the 'create_env' and 'create_agent' function in the main file to allow these functions to be used with pickle when creating the batcher processes 24 | def create_gym_env(env_name): 25 | return gym.make(env_name) 26 | 27 | 28 | def create_env(n_envs, env_name=None, max_episode_steps=None, seed=None): 29 | envs = [] 30 | for k in range(n_envs): 31 | e = create_gym_env(env_name) 32 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 33 | envs.append(e) 34 | return GymEnv(envs, seed) 35 | 36 | 37 | def create_agent(model, n_actions=1): 38 | return ReinforceAgent(model=model, n_actions=n_actions) 39 | 40 | 41 | class Experiment(Reinforce): 42 | def __init__(self, config, create_env, create_agent): 43 | super().__init__(config, create_env, create_agent) 44 | 45 | 46 | if __name__ == "__main__": 47 | # We use spawn mode such that most of the environment will run in multiple processes 48 | import torch.multiprocessing as mp 49 | 50 | mp.set_start_method("spawn") 51 | 52 | config = { 53 | "env_name": "CartPole-v0", 54 | "n_envs": 4, 55 | "max_episode_steps": 100, 56 | "env_seed": 42, 57 | "n_threads": 4, 58 | "n_evaluation_threads": 2, 59 | "n_evaluation_envs": 128, 60 | "time_limit": 3600, 61 | "lr": 0.01, 62 | "discount_factor": 0.9, 63 | "baseline_coef": 0.1, 64 | "entropy_coef": 0.01, 65 | "reinforce_coef": 1.0, 66 | "logdir": "./results", 67 | } 68 | exp = Experiment(config, create_env, create_agent) 69 | exp.run() 70 | -------------------------------------------------------------------------------- /tutorial/playing_with_envs.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from rlstructures.env_wrappers import GymEnv 8 | from rlstructures import DictTensor 9 | import torch 10 | import gym 11 | 12 | envs = [gym.make("CartPole-v0") for k in range(4)] 13 | env = GymEnv(envs, seed=80) 14 | 15 | obs, who_is_still_running = env.reset() 16 | print(obs) 17 | n_running = who_is_still_running.size()[0] 18 | while n_running > 0: # While some envs are still running 19 | action = DictTensor({"action": torch.tensor([0]).repeat(n_running)}) 20 | (obs, who_was_running), (obs2, who_is_still_running) = env.step(action) 21 | n_running = who_is_still_running.size()[0] 22 | print(obs) 23 | -------------------------------------------------------------------------------- /tutorial/tutorial_datastructures.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | 9 | ###### DictTensor 10 | 11 | # A DictTensor is dictionary of pytorch tensors. 
It assumes that the 12 | # first dimension of each tensor contained in the DictTensor is the batch dimension 13 | # The easiest way to build a DictTensor is to use a ditcionnary of tensors as input 14 | 15 | from rlstructures import DictTensor 16 | import torch 17 | 18 | d = DictTensor({"x": torch.randn(3, 5), "y": torch.randn(3, 8)}) 19 | 20 | # The number of batches is accessible through n_elems() 21 | print(d.n_elems(), " <- number of elements in the batch") 22 | 23 | 24 | # Many methods can be used over DictTensor (see DictTensor documentation) 25 | 26 | d["x"] # Returns the tensor 'x' in the DictTensor 27 | d.keys() # Returns the names of the variables of the DictTensor 28 | 29 | # An empty DictTensor can be defined as follows: 30 | d = DictTensor({}) 31 | 32 | 33 | ###### TemporalDictTensor 34 | 35 | # A TemporalDictTensor is a sequence of DictTensors. In memory, it is stored as a dictionary of tensors, 36 | # where the first dimesion is the batch dimension, and the second dimension is the time index. 37 | # Each element in the batch is a sequence, and two sequences can have a different length.etc...") 38 | 39 | from rlstructures import TemporalDictTensor 40 | 41 | # Create three sequences of variables x and y, where the length of the first sequence is 6, the length of the second is 10 and the length of the last sequence is 3 42 | d = TemporalDictTensor( 43 | {"x": torch.randn(3, 10, 5), "y": torch.randn(3, 10, 8)}, 44 | lengths=torch.tensor([6, 10, 3]), 45 | ) 46 | 47 | print(d.n_elems(), " <- number of elements in the batch") 48 | print(d.lengths, "<- Lengths of the sequences") 49 | print(d["x"].size(), "<- access to the tensor 'x'") 50 | 51 | print("Masking: ") 52 | print(d.mask()) 53 | 54 | print("Slicing (restricting the sequence to some particular temporal indexes) ") 55 | d_slice = d.temporal_slice(0, 4) 56 | print(d_slice.lengths) 57 | print(d_slice.mask()) 58 | 59 | # IMPORTANT: DictTensor and TemporalDictTensor can be moved to cpu/gpu using the .to method 60 | -------------------------------------------------------------------------------- /tutorial/tutorial_environments.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import gym 9 | from gym.utils import seeding 10 | from gym.spaces import Discrete 11 | 12 | # Defining a custom environment. 13 | # 1. It is a gym.Env environment 14 | # 2. (but observation_space can be empty) 15 | # 3. The observation can be a list/np.array or a dictionnary of list/np.array 16 | # 4. 
The reset function may receive a env_info arguments as a dictionnary of list/np.array 17 | class MyEnv(gym.Env): 18 | def __init__(self): 19 | super().__init__() 20 | self.action_space = Discrete(2) 21 | 22 | def seed(self, seed=None): 23 | print("Seed = %d" % seed) 24 | self.np_random, seed = seeding.np_random(seed) 25 | 26 | def reset(self, env_info={}): 27 | self.x = self.np_random.rand() * 2.0 - 1.0 28 | self.identifier = self.np_random.rand() 29 | return {"x": self.x, "identifier": self.identifier} 30 | 31 | def step(self, action): 32 | if action == 0: 33 | self.x -= 0.3 34 | else: 35 | self.x += 0.3 36 | 37 | return ( 38 | {"x": self.x, "identifier": self.identifier}, 39 | self.x, 40 | self.x < -1 or self.x > 1, 41 | {}, 42 | ) 43 | 44 | 45 | # Now, one can use a wrapper to transform this gym.Env to a rlstructures.VecEnv 46 | # 1. A VecEnv corresponds to env.n_envs() environnements that are running simultaneously 47 | # 2. VecEnv receives a dictTensor d as action where d.n_elems()<=env.n_envs() 48 | # 3. VecEnv returns a DictTensor obs as an observation. This observation contains multiple "field". 49 | # e.g obs["reward"] is the reward signal, obs["initial_state"] tells if it is the first state of a new episode, ... 50 | # Actually, since N <= env.n_envs() environments are still running at timestep t, the VecEnv.reset() and VecEnv.step(...) methods also returns the list of envs that are still running 51 | # 52 | # Example: (obs,who_was_running),(obs2,who_is_still_running) = env.step(action) 53 | # * obs is the observation (at t) coming from the environments that were running at t-1 54 | # * who_was_running is the list of environnments still running at time t-1. Note that who_was_running.size()[0]=obs.n_elems() 55 | # * obs2 is the observation (at t) from the environments that are still running at time t (i.e obs2 is a subset of obs) 56 | # * who_is_still_running is the list of environments running at time t 57 | 58 | from rlstructures.env_wrappers import GymEnv 59 | from rlstructures import DictTensor 60 | import torch 61 | 62 | envs = [MyEnv() for k in range(4)] 63 | env = GymEnv(envs, seed=80) 64 | 65 | # Each instance of the gym.Env will be initialized with seed+i such that the multiple instances will have different seeds 66 | 67 | # Interaction with the environment is easy, but made by using DictTensor 68 | 69 | obs, who_is_still_running = env.reset() 70 | print(obs) 71 | n_running = who_is_still_running.size()[0] 72 | while n_running > 0: # While some envs are still running 73 | action = DictTensor({"action": torch.tensor([0]).repeat(n_running)}) 74 | (obs, who_was_running), (obs2, who_is_still_running) = env.step(action) 75 | n_running = who_is_still_running.size()[0] 76 | print(obs2) 77 | 78 | # Note that gym wrappers work with continuous and discrete action spaces, but may not with environments where the action space is more complicated. 79 | # If you are facing gym envs with a complex action space, you may develop your own wrapper 80 | # A good starting point is the rlstructures.GymEnv code which is very simple can be used to define a new wrapper 81 | # All the other rlstuctures components will work with complex action spaces without modifications 82 | -------------------------------------------------------------------------------- /tutorial/tutorial_rlagent.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from rlstructures import RL_Agent, DictTensor 9 | import torch 10 | 11 | 12 | class UniformAgent(RL_Agent): 13 | def __init__(self, n_actions): 14 | super().__init__() 15 | self.n_actions = n_actions 16 | 17 | def initial_state(self, agent_info, B): 18 | return DictTensor({"timestep": torch.zeros(B).long()}) 19 | 20 | def __call__(self, state, observation, agent_info=None, history=None): 21 | B = observation.n_elems() 22 | 23 | scores = torch.randn(B, self.n_actions) 24 | probabilities = torch.softmax(scores, dim=1) 25 | actions = torch.distributions.Categorical(probabilities).sample() 26 | new_state = DictTensor({"timestep": state["timestep"] + 1}) 27 | return DictTensor({"action": actions}), new_state 28 | 29 | 30 | # Agent and Batcher 31 | # 32 | # An *Agent* and a *VecEnv* are used together in a **Batcher** to collect episodes or trajectories (a trajectory is a piece of an episode) 33 | # The simplest Batcher is the **MonoThreadEpisodeBatcher**, which runs in the main process. Other batchers are available in RLStructures: 34 | # * The *EpisodeBatcher* is a multi-process batcher sampling full episodes 35 | # * The *Batcher* is a multi-process batcher sampling N timesteps 36 | # The complex batchers are explained later 37 | 38 | # For creating a batcher, one has to provide **(picklable) functions and arguments** and not built objects. Indeed, the batchers take charge of creating the objects. 39 | 40 | import gym 41 | from gym.wrappers import TimeLimit 42 | from rlstructures.env_wrappers import GymEnv 43 | 44 | 45 | def create_env(max_episode_steps=100, seed=None): 46 | envs = [] 47 | for k in range(4): 48 | e = gym.make("CartPole-v0") 49 | e.seed(seed) 50 | e = TimeLimit(e, max_episode_steps=max_episode_steps) 51 | envs.append(e) 52 | return GymEnv(envs, seed=10) 53 | 54 | 55 | def create_agent(n_actions): 56 | return UniformAgent(n_actions) 57 | 58 | 59 | if __name__ == "__main__": 60 | # We use spawn mode such that most of the environment will run in multiple processes 61 | import torch.multiprocessing as mp 62 | 63 | mp.set_start_method("spawn") 64 | from rlstructures import RL_Batcher 65 | 66 | batcher = RL_Batcher( 67 | n_timesteps=100, 68 | create_agent=create_agent, 69 | create_env=create_env, 70 | agent_args={"n_actions": 2}, 71 | env_args={"max_episode_steps": 100}, 72 | n_processes=1, 73 | seeds=[42], 74 | agent_info=DictTensor({}), 75 | env_info=DictTensor({}), 76 | ) 77 | 78 | batcher.reset() 79 | batcher.execute() 80 | trajectories, n_still_running_envs = batcher.get() 81 | 82 | print("Informations: ") 83 | print(trajectories, trajectories.info) 84 | print("Lengths of trajectories: ") 85 | print(trajectories.trajectories.lengths) 86 | --------------------------------------------------------------------------------