├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── pull-request.md ├── .gitignore ├── .idea ├── codeStyles │ ├── Project.xml │ └── codeStyleConfig.xml └── dictionaries │ └── jlaivins.xml ├── Dockerfile ├── LICENSE.txt ├── README.md ├── ROADMAP.md ├── azure-pipelines.yml ├── build └── azure_pipeline_helper.sh ├── docs_src ├── data │ ├── ant_ddpg │ │ ├── ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle │ │ └── ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle │ ├── cartpole_dddqn │ │ ├── dddqn_ExperienceReplay_FEED_TYPE_STATE.pickle │ │ ├── dddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle │ │ ├── dddqn_er_rms.pickle │ │ └── dddqn_per_rms.pickle │ ├── cartpole_ddqn │ │ ├── ddqn_ExperienceReplay_FEED_TYPE_STATE.pickle │ │ ├── ddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle │ │ ├── ddqn_er_rms.pickle │ │ └── ddqn_per_rms.pickle │ ├── cartpole_dqn fixed targeting │ │ ├── dqn fixed targeting_er_rms.pickle │ │ └── dqn fixed targeting_per_rms.pickle │ ├── cartpole_dqn │ │ ├── dqn_ExperienceReplay_FEED_TYPE_STATE.pickle │ │ └── dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle │ ├── cartpole_dueling dqn │ │ ├── dueling dqn_ExperienceReplay_FEED_TYPE_STATE.pickle │ │ ├── dueling dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle │ │ ├── dueling dqn_er_rms.pickle │ │ └── dueling dqn_per_rms.pickle │ ├── cartpole_fixed target dqn │ │ ├── fixed target dqn_ExperienceReplay_FEED_TYPE_STATE.pickle │ │ └── fixed target dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle │ ├── halfcheetah_ddpg │ │ ├── ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle │ │ └── ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle │ ├── lunarlander_dddqn │ │ ├── dddqn_ExperienceReplay_FEED_TYPE_STATE.pickle │ │ └── dddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle │ ├── lunarlander_ddqn │ │ ├── ddqn_ExperienceReplay_FEED_TYPE_STATE.pickle │ │ └── ddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle │ ├── lunarlander_dqn fixed targeting │ │ ├── dqn fixed targeting_ExperienceReplay_FEED_TYPE_STATE.pickle │ │ └── dqn fixed targeting_PriorityExperienceReplay_FEED_TYPE_STATE.pickle │ ├── lunarlander_dqn │ │ ├── dqn_ExperienceReplay_FEED_TYPE_STATE.pickle │ │ └── dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle │ ├── lunarlander_dueling dqn │ │ ├── dueling dqn_ExperienceReplay_FEED_TYPE_STATE.pickle │ │ └── dueling dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle │ ├── lunarlander_fixed target dqn │ │ ├── fixed target dqn_ExperienceReplay_FEED_TYPE_STATE.pickle │ │ └── fixed target dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle │ └── pendulum_ddpg │ │ ├── ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle │ │ └── ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle ├── rl.agents.dddqn.ipynb ├── rl.agents.ddpg.ipynb ├── rl.agents.doubledqn.ipynb ├── rl.agents.dqn.ipynb ├── rl.agents.dqnfixedtarget.ipynb ├── rl.agents.duelingdqn.ipynb ├── rl.core.train.interpretation.ipynb └── util.gif_handling.ipynb ├── environment.yaml ├── fast_rl ├── __init__.py ├── agents │ ├── __init__.py │ ├── ddpg.py │ ├── ddpg_models.py │ ├── dqn.py │ └── dqn_models.py ├── core │ ├── Interpreter.py │ ├── __init__.py │ ├── agent_core.py │ ├── basic_train.py │ ├── data_block.py │ ├── data_structures.py │ ├── layers.py │ ├── metrics.py │ └── train.py └── util │ ├── __init__.py │ ├── exceptions.py │ └── misc.py ├── res ├── RELEASE_BLOG.md ├── ddpg_balancing.gif ├── dqn_q_estimate_1.jpg ├── dqn_q_estimate_2.jpg ├── dqn_q_estimate_3.jpg ├── fit_func_out.jpg ├── heat_map_1.png ├── heat_map_2.png ├── heat_map_3.png ├── 
heat_map_4.png ├── pre_interpretation_maze_dqn.gif ├── reward_plot_1.png ├── reward_plot_2.png ├── reward_plots │ ├── ant_ddpg.png │ ├── cartpole_dddqn.png │ ├── cartpole_double.png │ ├── cartpole_dqn.png │ ├── cartpole_dueling.png │ ├── cartpole_fixedtarget.png │ ├── halfcheetah_ddpg.png │ ├── lunarlander_all_targetbased.png │ ├── lunarlander_dddqn.png │ ├── lunarlander_double.png │ ├── lunarlander_dqn.png │ ├── lunarlander_dueling.png │ ├── lunarlander_fixedtarget.png │ └── pendulum_ddpg.png └── run_gifs │ ├── __init__.py │ ├── acrobot_ExperienceReplay_DDPGModule_1_episode_197.gif │ ├── acrobot_ExperienceReplay_DDPGModule_1_episode_438.gif │ ├── acrobot_ExperienceReplay_DDPGModule_1_episode_69.gif │ ├── acrobot_PriorityExperienceReplay_DDPGModule_1_episode_267.gif │ ├── acrobot_PriorityExperienceReplay_DDPGModule_1_episode_422.gif │ ├── acrobot_PriorityExperienceReplay_DDPGModule_1_episode_55.gif │ ├── ant_ExperienceReplay_DDPGModule_1_episode_54.gif │ ├── ant_ExperienceReplay_DDPGModule_1_episode_614.gif │ ├── ant_ExperienceReplay_DDPGModule_1_episode_999.gif │ ├── ant_PriorityExperienceReplay_DDPGModule_1_episode_52.gif │ ├── ant_PriorityExperienceReplay_DDPGModule_1_episode_596.gif │ ├── ant_PriorityExperienceReplay_DDPGModule_1_episode_984.gif │ ├── cartpole_ExperienceReplay_DQNModule_1_episode_207.gif │ ├── cartpole_ExperienceReplay_DQNModule_1_episode_31.gif │ ├── cartpole_ExperienceReplay_DQNModule_1_episode_447.gif │ ├── cartpole_ExperienceReplay_DoubleDQNModule_1_episode_268.gif │ ├── cartpole_ExperienceReplay_DoubleDQNModule_1_episode_438.gif │ ├── cartpole_ExperienceReplay_DoubleDQNModule_1_episode_60.gif │ ├── cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_287.gif │ ├── cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_43.gif │ ├── cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_447.gif │ ├── cartpole_ExperienceReplay_DuelingDQNModule_1_episode_209.gif │ ├── cartpole_ExperienceReplay_DuelingDQNModule_1_episode_432.gif │ ├── cartpole_ExperienceReplay_DuelingDQNModule_1_episode_62.gif │ ├── cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_309.gif │ ├── cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_438.gif │ ├── cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_57.gif │ ├── cartpole_PriorityExperienceReplay_DQNModule_1_episode_216.gif │ ├── cartpole_PriorityExperienceReplay_DQNModule_1_episode_413.gif │ ├── cartpole_PriorityExperienceReplay_DQNModule_1_episode_44.gif │ ├── cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_269.gif │ ├── cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_35.gif │ ├── cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_444.gif │ ├── cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_2.gif │ ├── cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_260.gif │ ├── cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_438.gif │ ├── cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_272.gif │ ├── cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_438.gif │ ├── cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_69.gif │ ├── cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_13.gif │ ├── cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_265.gif │ ├── cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_449.gif │ ├── lunarlander_ExperienceReplay_DQNModule_1_episode_541.gif │ ├── lunarlander_ExperienceReplay_DQNModule_1_episode_93.gif │ ├── 
lunarlander_ExperienceReplay_DQNModule_1_episode_999.gif │ ├── lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_613.gif │ ├── lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_88.gif │ ├── lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_999.gif │ ├── lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_114.gif │ ├── lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_346.gif │ ├── lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_925.gif │ ├── lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_112.gif │ ├── lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_431.gif │ ├── lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_980.gif │ ├── lunarlander_PriorityExperienceReplay_DQNModule_1_episode_382.gif │ ├── lunarlander_PriorityExperienceReplay_DQNModule_1_episode_949.gif │ ├── lunarlander_PriorityExperienceReplay_DQNModule_1_episode_99.gif │ ├── lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_514.gif │ ├── lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_7.gif │ ├── lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_999.gif │ ├── lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_151.gif │ ├── lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_341.gif │ ├── lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_999.gif │ ├── lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_21.gif │ ├── lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_442.gif │ ├── lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_998.gif │ ├── pendulum_ExperienceReplay_DDPGModule_1_episode_238.gif │ ├── pendulum_ExperienceReplay_DDPGModule_1_episode_447.gif │ ├── pendulum_ExperienceReplay_DDPGModule_1_episode_9.gif │ ├── pendulum_PriorityExperienceReplay_DDPGModule_1_episode_222.gif │ ├── pendulum_PriorityExperienceReplay_DDPGModule_1_episode_35.gif │ └── pendulum_PriorityExperienceReplay_DDPGModule_1_episode_431.gif ├── setup.py └── tests ├── __init__.py ├── conftest.py ├── data ├── cartpole_dqn │ └── dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle ├── cat │ ├── cat1.jpeg │ └── cat2.jpeg └── dog │ ├── dog1.jpeg │ └── dog2.jpeg ├── test_agent_core.py ├── test_basic_train.py ├── test_data_block.py ├── test_data_structures.py ├── test_ddpg.py ├── test_dqn.py ├── test_metrics.py └── test_train.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **System (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Version [e.g. 22] 29 | - Python Version [e.g. 3.6] 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 
33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **This feature is officially added when** 20 | - [ ] These items are accomplished . . . 21 | . . . 22 | 23 | **Additional context** 24 | Add any other context or screenshots about the feature request here. 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/pull-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Pull Request 3 | about: Guidelines for making pull requests 4 | title: "[Added or Changed or Fixed] . . . something in your PR [FIX or FEATURE]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **REMOVE ME START** 11 | IMPORTANT: Please do not create a Pull Request without creating an issue first. Template from [stevemao templates](https://github.com/stevemao/github-issue-templates). 12 | 13 | Any change needs to be discussed before proceeding. Failure to do so may result in the rejection of the pull request. 14 | **END** 15 | 16 | **What this PR does** 17 | Please provide enough information so that others can review your pull request. Start with a quick TL;DR or a short set of statements on what you did. 18 | 19 | **Why this PR exists** 20 | Explain the reasoning behind this change. What existing problem does the pull request solve? 21 | 22 | **This is how I plan to test this** 23 | Test plan (required) 24 | 25 | Demonstrate the code is solid. Example: the exact commands you ran and their output, screenshots / videos if the pull request changes UI. 26 | 27 | Code formatting 28 | 29 | **Closing issues** 30 | 31 | Put `closes #XXXX` in your comment to auto-close the issue that your PR fixes (if any).
32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IntelliJ project files 2 | *.iml 3 | .idea/inspectionProfiles/* 4 | misc.xml 5 | modules.xml 6 | other.xml 7 | vcs.xml 8 | workspace.xml 9 | out 10 | gen 11 | .DS_Store 12 | 13 | # Jupyter Notebook 14 | */.ipynb_checkpoints/* 15 | 16 | # Secure Files 17 | .pypirc 18 | 19 | # Data Files 20 | #/docs_src/data/* 21 | 22 | # Build Files / Directories 23 | build/* 24 | dist/* 25 | fast_rl.egg-info/* -------------------------------------------------------------------------------- /.idea/codeStyles/Project.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | 22 | 23 | -------------------------------------------------------------------------------- /.idea/codeStyles/codeStyleConfig.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | -------------------------------------------------------------------------------- /.idea/dictionaries/jlaivins.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | argmax 5 | conv 6 | convolutional 7 | cumulate 8 | darknet 9 | datasets 10 | ddpg 11 | deallocated 12 | dtype 13 | figsize 14 | fixedtargetdqn 15 | gifs 16 | groupable 17 | moviepy 18 | ornstein 19 | pybullet 20 | uhlenbeck 21 | 22 | 23 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.0-base-ubuntu18.04 2 | # See http://bugs.python.org/issue19846 3 | ENV LANG C.UTF-8 4 | LABEL com.nvidia.volumes.needed="nvidia_driver" 5 | 6 | RUN apt-get update && apt-get install -y --no-install-recommends \ 7 | build-essential cmake git curl vim ca-certificates python-qt4 libjpeg-dev \ 8 | zip nano unzip libpng-dev strace python-opengl xvfb && \ 9 | rm -rf /var/lib/apt/lists/* 10 | 11 | ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64 12 | ENV PYTHON_VERSION=3.6 13 | 14 | RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 15 | chmod +x ~/miniconda.sh && \ 16 | ~/miniconda.sh -b -p /opt/conda && \ 17 | rm ~/miniconda.sh && \ 18 | /opt/conda/bin/conda install conda-build && \ 19 | apt-get update && apt-get upgrade -y --no-install-recommends 20 | 21 | ENV PATH=$PATH:/opt/conda/bin/ 22 | ENV USER fastrl_user 23 | # Create environment 24 | COPY environment.yaml /environment.yaml 25 | RUN conda env create -f environment.yaml 26 | 27 | # Cleanup 28 | RUN rm -rf /var/lib/apt/lists/* \ 29 | && apt-get -y autoremove 30 | 31 | EXPOSE 8888 32 | ENV CONDA_DEFAULT_ENV fastrl 33 | 34 | CMD ["/bin/bash"] 35 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License.
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://dev.azure.com/jokellum/jokellum/_apis/build/status/josiahls.fast-reinforcement-learning?branchName=master)](https://dev.azure.com/jokellum/jokellum/_build/latest?definitionId=1&branchName=master) 2 | [![pypi fast_rl version](https://img.shields.io/pypi/v/fast_rl)](https://pypi.python.org/pypi/fast_rl) 3 | [![github_master version](https://img.shields.io/github/v/release/josiahls/fast-reinforcement-learning?include_prereleases)](https://github.com/josiahls/fast-reinforcement-learning/releases) 4 | 5 | **Important Note** 6 | fastrl==2.* is being developed at [fastrl](https://github.com/josiahls/fastrl), which is the permanent home for all fastai version 2.0 changes as well as faster/refactored/more stable models. Please go there for new information and code. 7 | 8 | 9 | 10 | # Fast_rl 11 | This repo is not affiliated with Jeremy Howard or his course, which can be found [here](https://www.fast.ai/about/). 12 | We use components from the Fastai library for building and training our reinforcement learning (RL) 13 | agents. 14 | 15 | Our goal is for fast_rl to make benchmarking easier, inference more efficient, and environment compatibility 16 | as decoupled as possible. This being version 1.0, we still have a lot of work to do to make RL training itself faster 17 | and more efficient. The goals for this repo can be seen in the [RoadMap](#roadmap). 18 | 19 | **An important note: training can use up a lot of RAM. This will likely be addressed in the next few versions by offloading episode data to storage as more models are added.** 20 | 21 | A simple example: 22 | ```python 23 | from fast_rl.agents.dqn import create_dqn_model, dqn_learner 24 | from fast_rl.agents.dqn_models import * 25 | from fast_rl.core.agent_core import ExperienceReplay, GreedyEpsilon 26 | from fast_rl.core.data_block import MDPDataBunch 27 | from fast_rl.core.metrics import RewardMetric, EpsilonMetric 28 | 29 | memory = ExperienceReplay(memory_size=1000000, reduce_ram=True) 30 | explore = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001) 31 | data = MDPDataBunch.from_env('CartPole-v1', render='human', bs=64, add_valid=False) 32 | model = create_dqn_model(data=data, base_arch=FixedTargetDQNModule, lr=0.001, layers=[32,32]) 33 | learn = dqn_learner(data, model, memory=memory, exploration_method=explore, copy_over_frequency=300, 34 | callback_fns=[RewardMetric, EpsilonMetric]) 35 | learn.fit(450) 36 | ``` 37 |
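The replay memory and exploration strategy are swappable components. Below is a minimal sketch of the same agent trained with prioritized replay instead; it assumes `PriorityExperienceReplay` in `fast_rl.core.agent_core` accepts the same `memory_size`/`reduce_ram` arguments as `ExperienceReplay` (check the module if your version differs).

```python
from fast_rl.agents.dqn import create_dqn_model, dqn_learner
from fast_rl.agents.dqn_models import *
from fast_rl.core.agent_core import PriorityExperienceReplay, GreedyEpsilon
from fast_rl.core.data_block import MDPDataBunch
from fast_rl.core.metrics import RewardMetric, EpsilonMetric

# Swap in prioritized replay; the constructor arguments mirror the ExperienceReplay call above (assumed).
memory = PriorityExperienceReplay(memory_size=1000000, reduce_ram=True)
explore = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
data = MDPDataBunch.from_env('CartPole-v1', render='human', bs=64, add_valid=False)
model = create_dqn_model(data=data, base_arch=FixedTargetDQNModule, lr=0.001, layers=[32,32])
learn = dqn_learner(data, model, memory=memory, exploration_method=explore, copy_over_frequency=300,
                    callback_fns=[RewardMetric, EpsilonMetric])
learn.fit(450)
```

This ER vs. PER pairing is the one reflected in the file names under `res/run_gifs/` and `docs_src/data/`.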
38 | More complex examples might involve running an RL agent multiple times, generating episode snapshots as gifs, grouping 39 | reward plots, and finally showing the best and worst runs in a single graph. 40 | ```python 41 | from fastai.basic_data import DatasetType 42 | from fast_rl.agents.dqn import create_dqn_model, dqn_learner 43 | from fast_rl.agents.dqn_models import * 44 | from fast_rl.core.agent_core import ExperienceReplay, GreedyEpsilon 45 | from fast_rl.core.data_block import MDPDataBunch 46 | from fast_rl.core.metrics import RewardMetric, EpsilonMetric 47 | from fast_rl.core.train import GroupAgentInterpretation, AgentInterpretation 48 | 49 | group_interp = GroupAgentInterpretation() 50 | for i in range(5): 51 | memory = ExperienceReplay(memory_size=1000000, reduce_ram=True) 52 | explore = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001) 53 | data = MDPDataBunch.from_env('CartPole-v1', render='human', bs=64, add_valid=False) 54 | model = create_dqn_model(data=data, base_arch=FixedTargetDQNModule, lr=0.001, layers=[32,32]) 55 | learn = dqn_learner(data, model, memory=memory, exploration_method=explore, copy_over_frequency=300, 56 | callback_fns=[RewardMetric, EpsilonMetric]) 57 | learn.fit(450) 58 | 59 | interp = AgentInterpretation(learn, ds_type=DatasetType.Train) 60 | interp.plot_rewards(cumulative=True, per_episode=True, group_name='cartpole_experience_example') 61 | group_interp.add_interpretation(interp) 62 | group_interp.to_pickle(f'{learn.model.name.lower()}/', f'{learn.model.name.lower()}') 63 | for g in interp.generate_gif(): g.write(f'{learn.model.name.lower()}') 64 | group_interp.plot_reward_bounds(per_episode=True, smooth_groups=10) 65 | ``` 66 | More examples can be found in `docs_src`, and the actual code used to generate the gifs can be found in `tests`, in 67 | either `test_dqn.py` or `test_ddpg.py`. 68 | 69 | As a note, here is a rundown of existing RL frameworks: 70 | - [Intel Coach](https://github.com/NervanaSystems/coach) 71 | - [Tensor Force](https://github.com/tensorforce/tensorforce) 72 | - [OpenAI Baselines](https://github.com/openai/baselines) 73 | - [Tensorflow Agents](https://github.com/tensorflow/agents) 74 | - [KerasRL](https://github.com/keras-rl/keras-rl) 75 | 76 | There are also frameworks built on PyTorch: 77 | - [Horizon](https://github.com/facebookresearch/Horizon) 78 | - [DeepRL](https://github.com/ShangtongZhang/DeepRL) 79 | - [Spinning Up](https://spinningup.openai.com/en/latest/user/introduction.html) 80 | 81 | ## Installation 82 | 83 | **fastai (semi-optional)**\ 84 | [Install Fastai](https://github.com/fastai/fastai/blob/master/README.md#installation) 85 | or, if you are using Anaconda (recommended), you can do: \ 86 | `conda install -c pytorch -c fastai fastai` 87 | 88 | **fast_rl**\ 89 | Fastai will be installed if it is not already present. If it is, the versioning should be repaired by setup.py. 90 | `pip install fast_rl` 91 | 92 | ## Installation (Optional) 93 | OpenAI all gyms: \ 94 | `pip install gym[all]` 95 | 96 | Mazes: \ 97 | `git clone https://github.com/MattChanTK/gym-maze.git` \ 98 | `cd gym-maze` \ 99 | `python setup.py install` 100 | 101 | 102 | ## Installation Dev (Optional) 103 | `git clone https://github.com/josiahls/fast-reinforcement-learning.git` \ 104 | `cd fast-reinforcement-learning` \ 105 | `python setup.py install` 106 | 107 | ## Installation Issues 108 | Many issues will likely fall under [fastai installation issues](https://github.com/fastai/fastai/blob/master/README.md#installation-issues). 109 | 110 | Any other issues are likely environment related. It is important to note that Python 3.7 is not being tested due to 111 | an issue where Pyglet and gym do not work together. This will not stop you from training models, but it might impact using 112 | OpenAI environments. 113 |
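After installing, a quick smoke test can confirm that the package imports and an environment can be wrapped. This is only a sketch: it reuses the exact `MDPDataBunch` call from the examples above and assumes gym's `CartPole-v1` is available (a render window may open because `render='human'`).

```python
from fast_rl.core.agent_core import ExperienceReplay, GreedyEpsilon
from fast_rl.core.data_block import MDPDataBunch

# Wrap CartPole exactly as in the README examples.
data = MDPDataBunch.from_env('CartPole-v1', render='human', bs=64, add_valid=False)
print(data)

# The replay memory and exploration objects should also construct without error.
memory = ExperienceReplay(memory_size=1000, reduce_ram=True)
explore = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
```

If this fails with a Pyglet/OpenGL error, see the rendering and Python 3.7 notes above.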
114 | ## RoadMap 115 | 116 | - [X] **1.0.0** Base version is completed with working model visualizations proving performance / expected failure. At 117 | this point, all models should have guaranteed environments they should succeed in. 118 | - [ ] **Working on** 1.1.0 More Traditional RL models 119 | - [ ] **Working on** Add PPO 120 | - [ ] **Working on** Add TRPO 121 | - [ ] Add D4PG 122 | - [ ] Add A2C 123 | - [ ] Add A3C 124 | - [ ] 1.2.0 HRL models *Possibly might change version to 2.0 depending on SMDP issues* 125 | - [ ] Add SMDP 126 | - [ ] Add Goal oriented MDPs. Will require a new "Step" 127 | - [ ] Add FeUdal Network 128 | - [ ] Add storage based DataBunch memory management. This can prevent RAM from being used up by episode image frames 129 | that may or may not serve any use to the agent, but only for logging. 130 | - [ ] 1.3.0 131 | - [ ] Add HAC 132 | - [ ] Add MAXQ 133 | - [ ] Add HIRO 134 | - [ ] 1.4.0 135 | - [ ] Add h-DQN 136 | - [ ] Add Modulated Policy Hierarchies 137 | - [ ] Add Meta Learning Shared Hierarchies 138 | - [ ] 1.5.0 139 | - [ ] Add STRategic Attentive Writer (STRAW) 140 | - [ ] Add H-DRLN 141 | - [ ] Add Abstract Markov Decision Process (AMDP) 142 | - [ ] Add conda integration so that installation can be truly one step. 143 | - [ ] 1.6.0 HRL Options models *Possibly will already be implemented in a previous model* 144 | - [ ] Options augmentation to DQN based models 145 | - [ ] Options augmentation to actor critic models 146 | - [ ] Options augmentation to async actor critic models 147 | - [ ] 1.8.0 HRL Skills 148 | - [ ] Skills augmentation to DQN based models 149 | - [ ] Skills augmentation to actor critic models 150 | - [ ] Skills augmentation to async actor critic models 151 | - [ ] 1.9.0 152 | - [ ] 2.0.0 Add PyBullet Fetch Environments 153 | - [ ] 2.0.0 Not part of this repo, however the envs need to subclass the OpenAI `gym.GoalEnv` 154 | - [ ] 2.0.0 Add HER 155 | 156 | 157 | ## Contribution 158 | Following fastai's guidelines would be desirable: [Guidelines](https://github.com/fastai/fastai/blob/master/README.md#contribution-guidelines) 159 | 160 | We hope that model additions can be made smoothly; all models will only depend on `core.layers.py`. 161 | As time goes on, the model architecture will improve overall (we are, and will continue to be, figuring things out). 162 | 163 | 164 | ## Style 165 | Since fastai uses a different style from traditional PEP-8, we will be following its [Style](https://docs.fast.ai/dev/style.html) 166 | and [Abbreviations](https://docs.fast.ai/dev/abbr.html) guides. We also use RL-specific abbreviations: 167 | 168 | | | Concept | Abbr.
| Combination Examples | 169 | |:------:|:-------:|:-----:|:--------------------:| 170 | | **RL** | State | st | | 171 | | | Action | acn | | 172 | | | Bounds | bb | Same as Bounding Box | 173 | 174 | ## Examples 175 | 176 | ### Reward Graphs 177 | 178 | | | Model | 179 | |:------------------------------------------:|:---------------:| 180 | | ![01](./res/reward_plots/cartpole_dqn.png) | DQN | 181 | | ![01](./res/reward_plots/cartpole_dueling.png) | Dueling DQN | 182 | | ![01](./res/reward_plots/cartpole_double.png) | Double DQN | 183 | | ![01](./res/reward_plots/cartpole_dddqn.png) | DDDQN | 184 | | ![01](./res/reward_plots/cartpole_fixedtarget.png) | Fixed Target DQN | 185 | | ![01](./res/reward_plots/lunarlander_dqn.png) | DQN | 186 | | ![01](./res/reward_plots/lunarlander_dueling.png) | Dueling DQN | 187 | | ![01](./res/reward_plots/lunarlander_double.png) | Double DQN | 188 | | ![01](./res/reward_plots/lunarlander_dddqn.png) | DDDQN | 189 | | ![01](./res/reward_plots/lunarlander_fixedtarget.png) | Fixed Target DQN | 190 | | ![01](./res/reward_plots/ant_ddpg.png) | DDPG | 191 | | ![01](./res/reward_plots/pendulum_ddpg.png) | DDPG | 192 | | ![01](./res/reward_plots/halfcheetah_ddpg.png) | DDPG | 193 | 194 | 195 | ### Agent Stages 196 | 197 | | Model | Gif(Early) | Gif(Mid) | Gif(Late) | 198 | |:------------:|:------------:|:------------:|:------------:| 199 | | DDPG+PER | ![](./res/run_gifs/pendulum_PriorityExperienceReplay_DDPGModule_1_episode_35.gif) | ![](./res/run_gifs/pendulum_PriorityExperienceReplay_DDPGModule_1_episode_222.gif) | ![](./res/run_gifs/pendulum_PriorityExperienceReplay_DDPGModule_1_episode_431.gif)| 200 | | DoubleDueling+ER | ![](./res/run_gifs/lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_114.gif) | ![](./res/run_gifs/lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_346.gif) | ![](./res/run_gifs/lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_925.gif)| 201 | | DoubleDQN+ER | ![](./res/run_gifs/lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_88.gif) | ![](./res/run_gifs/lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_613.gif) | ![](./res/run_gifs/lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_999.gif)| 202 | | DuelingDQN+ER | ![](./res/run_gifs/lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_112.gif) | ![](./res/run_gifs/lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_431.gif) | ![](./res/run_gifs/lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_980.gif)| 203 | | DoubleDueling+PER | ![](./res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_151.gif) | ![](./res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_341.gif) | ![](./res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_999.gif)| 204 | | DQN+ER | ![](./res/run_gifs/lunarlander_ExperienceReplay_DQNModule_1_episode_93.gif) | ![](./res/run_gifs/lunarlander_ExperienceReplay_DQNModule_1_episode_541.gif) | ![](./res/run_gifs/lunarlander_ExperienceReplay_DQNModule_1_episode_999.gif)| 205 | | DuelingDQN+PER | ![](./res/run_gifs/lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_21.gif) | ![](./res/run_gifs/lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_442.gif) | ![](./res/run_gifs/lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_998.gif)| 206 | | DQN+PER | ![](./res/run_gifs/lunarlander_PriorityExperienceReplay_DQNModule_1_episode_99.gif) | 
![](./res/run_gifs/lunarlander_PriorityExperienceReplay_DQNModule_1_episode_382.gif) | ![](./res/run_gifs/lunarlander_PriorityExperienceReplay_DQNModule_1_episode_949.gif)| 207 | | DoubleDQN+PER | ![](./res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_7.gif) | ![](./res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_514.gif) | ![](./res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_999.gif)| 208 | | DDPG+PER | ![](./res/run_gifs/ant_PriorityExperienceReplay_DDPGModule_1_episode_52.gif) | ![](./res/run_gifs/ant_PriorityExperienceReplay_DDPGModule_1_episode_596.gif) | ![](./res/run_gifs/ant_PriorityExperienceReplay_DDPGModule_1_episode_984.gif)| 209 | | DDPG+ER | ![](./res/run_gifs/ant_ExperienceReplay_DDPGModule_1_episode_54.gif) | ![](./res/run_gifs/ant_ExperienceReplay_DDPGModule_1_episode_614.gif) | ![](./res/run_gifs/ant_ExperienceReplay_DDPGModule_1_episode_999.gif)| 210 | | DQN+PER | ![](./res/run_gifs/cartpole_PriorityExperienceReplay_DQNModule_1_episode_44.gif) | ![](./res/run_gifs/cartpole_PriorityExperienceReplay_DQNModule_1_episode_216.gif) | ![](./res/run_gifs/cartpole_PriorityExperienceReplay_DQNModule_1_episode_413.gif)| 211 | | FixedTargetDQN+ER | ![](./res/run_gifs/cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_57.gif) | ![](./res/run_gifs/cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_309.gif) | ![](./res/run_gifs/cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_438.gif)| 212 | | DQN+ER | ![](./res/run_gifs/cartpole_ExperienceReplay_DQNModule_1_episode_31.gif) | ![](./res/run_gifs/cartpole_ExperienceReplay_DQNModule_1_episode_207.gif) | ![](./res/run_gifs/cartpole_ExperienceReplay_DQNModule_1_episode_447.gif)| 213 | | FixedTargetDQN+PER | ![](./res/run_gifs/cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_13.gif) | ![](./res/run_gifs/cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_265.gif) | ![](./res/run_gifs/cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_449.gif)| 214 | | DoubleDQN+ER | ![](./res/run_gifs/cartpole_ExperienceReplay_DoubleDQNModule_1_episode_60.gif) | ![](./res/run_gifs/cartpole_ExperienceReplay_DoubleDQNModule_1_episode_268.gif) | ![](./res/run_gifs/cartpole_ExperienceReplay_DoubleDQNModule_1_episode_438.gif)| 215 | | DoubleDQN+PER | ![](./res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_35.gif) | ![](./res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_269.gif) | ![](./res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_444.gif)| 216 | | DuelingDQN+ER | ![](./res/run_gifs/cartpole_ExperienceReplay_DuelingDQNModule_1_episode_62.gif) | ![](./res/run_gifs/cartpole_ExperienceReplay_DuelingDQNModule_1_episode_209.gif) | ![](./res/run_gifs/cartpole_ExperienceReplay_DuelingDQNModule_1_episode_432.gif)| 217 | | DoubleDueling+PER | ![](./res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_2.gif) | ![](./res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_260.gif) | ![](./res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_438.gif)| 218 | | DuelingDQN+PER | ![](./res/run_gifs/cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_69.gif) | ![](./res/run_gifs/cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_272.gif) | ![](./res/run_gifs/cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_438.gif)| 219 | | DoubleDueling+ER | 
![](./res/run_gifs/cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_43.gif) | ![](./res/run_gifs/cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_287.gif) | ![](./res/run_gifs/cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_447.gif)| 220 | | DDPG+ER | ![](./res/run_gifs/acrobot_ExperienceReplay_DDPGModule_1_episode_69.gif) | ![](./res/run_gifs/acrobot_ExperienceReplay_DDPGModule_1_episode_197.gif) | ![](./res/run_gifs/acrobot_ExperienceReplay_DDPGModule_1_episode_438.gif)| 221 | | DDPG+PER | ![](./res/run_gifs/acrobot_PriorityExperienceReplay_DDPGModule_1_episode_55.gif) | ![](./res/run_gifs/acrobot_PriorityExperienceReplay_DDPGModule_1_episode_267.gif) | ![](./res/run_gifs/acrobot_PriorityExperienceReplay_DDPGModule_1_episode_422.gif)| 222 | -------------------------------------------------------------------------------- /ROADMAP.md: -------------------------------------------------------------------------------- 1 | - [X] 0.7.0 Full test suite using multi-processing. Connect to CI. 2 | - [X] 0.8.0 Comprehensive model eval **debug/verify**. Each model should succeed at at least a few known environments. Also, massive refactoring will be needed. 3 | - [X] 0.9.0 Notebook demonstrations of basic model usage. 4 | - [X] **1.0.0** Base version is completed with working model visualizations proving performance / expected failure. At 5 | this point, all models should have guaranteed environments they should succeed in. 6 | - [ ] **Working on** 1.1.0 More Traditional RL models 7 | - [ ] **Working on** Add PPO 8 | - [ ] **Working on** Add TRPO 9 | - [ ] Add D4PG 10 | - [ ] Add A2C 11 | - [ ] Add A3C 12 | - [ ] 1.2.0 HRL models *Possibly might change version to 2.0 depending on SMDP issues* 13 | - [ ] Add SMDP 14 | - [ ] Add Goal oriented MDPs. Will Require a new "Step" 15 | - [ ] Add FeUdal Network 16 | - [ ] Add storage based DataBunch memory management. This can prevent RAM from being used up by episode image frames 17 | that may or may not serve any use to the agent, but only for logging. 18 | - [ ] 1.3.0 19 | - [ ] Add HAC 20 | - [ ] Add MAXQ 21 | - [ ] Add HIRO 22 | - [ ] 1.4.0 23 | - [ ] Add h-DQN 24 | - [ ] Add Modulated Policy Hierarchies 25 | - [ ] Add Meta Learning Shared Hierarchies 26 | - [ ] 1.5.0 27 | - [ ] Add STRategic Attentive Writer (STRAW) 28 | - [ ] Add H-DRLN 29 | - [ ] Add Abstract Markov Decision Process (AMDP) 30 | - [ ] 1.6.0 HRL Options models *Possibly will already be implemented in a previous model* 31 | - [ ] Options augmentation to DQN based models 32 | - [ ] Options augmentation to actor critic models 33 | - [ ] Options augmentation to async actor critic models 34 | - [ ] 1.8.0 HRL Skills 35 | - [ ] Skills augmentation to DQN based models 36 | - [ ] Skills augmentation to actor critic models 37 | - [ ] Skills augmentation to async actor critic models 38 | - [ ] 1.9.0 39 | - [ ] 2.0.0 Add PyBullet Fetch Environments 40 | - [ ] 2.0.0 Not part of this repo, however the envs need to subclass the OpenAI `gym.GoalEnv` 41 | - [ ] 2.0.0 Add HER -------------------------------------------------------------------------------- /azure-pipelines.yml: -------------------------------------------------------------------------------- 1 | # Starter pipeline 2 | # Start with a minimal pipeline that you can customize to build and deploy your code. 
3 | # Add steps that build, run tests, deploy, and more: 4 | # https://aka.ms/yaml 5 | 6 | # - bash: "sudo apt-get install -y xvfb freeglut3-dev python-opengl --fix-missing" 7 | # displayName: 'Install ffmpeg, freeglut3-dev, and xvfb' 8 | 9 | stages: 10 | - stage: Test 11 | condition: and(always(), eq(variables['Build.Reason'], 'PullRequest')) 12 | jobs: 13 | - job: 'Test' 14 | pool: 15 | vmImage: 'ubuntu-16.04' # other options: 'macOS-10.13', 'vs2017-win2016' 16 | strategy: 17 | matrix: 18 | Python36: 19 | python.version: '3.6' 20 | # Python37: # Pyglet and gym do not working here 21 | # python.version: '3.7.5' 22 | maxParallel: 4 23 | 24 | steps: 25 | - task: UsePythonVersion@0 26 | inputs: 27 | versionSpec: '$(python.version)' 28 | 29 | - bash: "sudo apt-get install -y freeglut3-dev python-opengl" 30 | displayName: 'Install freeglut3-dev' 31 | 32 | - script: | 33 | python -m pip install --upgrade pip setuptools wheel pytest pytest-cov python-xlib -e . 34 | python setup.py install 35 | displayName: 'Install dependencies' 36 | 37 | - script: sh ./build/azure_pipeline_helper.sh 38 | displayName: 'Complex Installs' 39 | 40 | - script: | 41 | xvfb-run -s "-screen 0 1400x900x24" py.test tests --cov fast_rl --cov-report html --doctest-modules --junitxml=junit/test-results.xml --cov=./ --cov-report=xml --cov-report=html 42 | displayName: 'Test with pytest' 43 | 44 | - task: PublishTestResults@2 45 | condition: succeededOrFailed() 46 | inputs: 47 | testResultsFiles: '**/test-*.xml' 48 | testRunTitle: 'Publish test results for Python $(python.version)' 49 | 50 | - stage: Deploy 51 | condition: and(always(), eq(variables['Build.SourceBranch'], 'refs/heads/master')) 52 | jobs: 53 | - job: "TwineDeploy" 54 | pool: 55 | vmImage: 'ubuntu-16.04' # other options: 'macOS-10.13', 'vs2017-win2016' 56 | strategy: 57 | matrix: 58 | Python36: 59 | python.version: '3.6' 60 | steps: 61 | - task: UsePythonVersion@0 62 | inputs: 63 | versionSpec: '$(python.version)' 64 | # Install python distributions like wheel, twine etc 65 | - task: Bash@3 66 | inputs: 67 | targetType: 'inline' 68 | script: | 69 | echo $TWINE_USERNAME 70 | pip install wheel setuptools twine 71 | python setup.py sdist bdist_wheel 72 | python -m twine upload -u $TWINE_USERNAME -p $TWINE_PASSWORD --repository-url 'https://upload.pypi.org/legacy/' dist/* 73 | env: 74 | TWINE_PASSWORD: $(SECRET_TWINE_PASSWORD) 75 | TWINE_USERNAME: $(SECRET_TWINE_USERNAME) -------------------------------------------------------------------------------- /build/azure_pipeline_helper.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## Install pybullet 4 | #git clone https://github.com/benelot/pybullet-gym.git 5 | #cd pybullet-gym 6 | #pip install -e . 
7 | #cd ../ 8 | 9 | ## Install gym_maze 10 | #git clone https://github.com/MattChanTK/gym-maze.git 11 | #cd gym-maze 12 | #python setup.py install 13 | #cd ../ 14 | 15 | -------------------------------------------------------------------------------- /docs_src/data/ant_ddpg/ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/ant_ddpg/ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/ant_ddpg/ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/ant_ddpg/ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/cartpole_dddqn/dddqn_ExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dddqn/dddqn_ExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/cartpole_dddqn/dddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dddqn/dddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/cartpole_dddqn/dddqn_er_rms.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dddqn/dddqn_er_rms.pickle -------------------------------------------------------------------------------- /docs_src/data/cartpole_dddqn/dddqn_per_rms.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dddqn/dddqn_per_rms.pickle -------------------------------------------------------------------------------- /docs_src/data/cartpole_ddqn/ddqn_ExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_ddqn/ddqn_ExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/cartpole_ddqn/ddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_ddqn/ddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle 
-------------------------------------------------------------------------------- /docs_src/data/cartpole_ddqn/ddqn_er_rms.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_ddqn/ddqn_er_rms.pickle -------------------------------------------------------------------------------- /docs_src/data/cartpole_ddqn/ddqn_per_rms.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_ddqn/ddqn_per_rms.pickle -------------------------------------------------------------------------------- /docs_src/data/cartpole_dqn fixed targeting/dqn fixed targeting_er_rms.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dqn fixed targeting/dqn fixed targeting_er_rms.pickle -------------------------------------------------------------------------------- /docs_src/data/cartpole_dqn fixed targeting/dqn fixed targeting_per_rms.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dqn fixed targeting/dqn fixed targeting_per_rms.pickle -------------------------------------------------------------------------------- /docs_src/data/cartpole_dqn/dqn_ExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dqn/dqn_ExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/cartpole_dqn/dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dqn/dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/cartpole_dueling dqn/dueling dqn_ExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dueling dqn/dueling dqn_ExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/cartpole_dueling dqn/dueling dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dueling dqn/dueling dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/cartpole_dueling dqn/dueling dqn_er_rms.pickle: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dueling dqn/dueling dqn_er_rms.pickle -------------------------------------------------------------------------------- /docs_src/data/cartpole_dueling dqn/dueling dqn_per_rms.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dueling dqn/dueling dqn_per_rms.pickle -------------------------------------------------------------------------------- /docs_src/data/cartpole_fixed target dqn/fixed target dqn_ExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_fixed target dqn/fixed target dqn_ExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/cartpole_fixed target dqn/fixed target dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_fixed target dqn/fixed target dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/halfcheetah_ddpg/ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/halfcheetah_ddpg/ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/halfcheetah_ddpg/ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/halfcheetah_ddpg/ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/lunarlander_dddqn/dddqn_ExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_dddqn/dddqn_ExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/lunarlander_dddqn/dddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_dddqn/dddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/lunarlander_ddqn/ddqn_ExperienceReplay_FEED_TYPE_STATE.pickle: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_ddqn/ddqn_ExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/lunarlander_ddqn/ddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_ddqn/ddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/lunarlander_dqn fixed targeting/dqn fixed targeting_ExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_dqn fixed targeting/dqn fixed targeting_ExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/lunarlander_dqn fixed targeting/dqn fixed targeting_PriorityExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_dqn fixed targeting/dqn fixed targeting_PriorityExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/lunarlander_dqn/dqn_ExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_dqn/dqn_ExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/lunarlander_dqn/dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_dqn/dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/lunarlander_dueling dqn/dueling dqn_ExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_dueling dqn/dueling dqn_ExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/lunarlander_dueling dqn/dueling dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_dueling dqn/dueling dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle 
-------------------------------------------------------------------------------- /docs_src/data/lunarlander_fixed target dqn/fixed target dqn_ExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_fixed target dqn/fixed target dqn_ExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/lunarlander_fixed target dqn/fixed target dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_fixed target dqn/fixed target dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/pendulum_ddpg/ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/pendulum_ddpg/ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/data/pendulum_ddpg/ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/pendulum_ddpg/ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /docs_src/rl.core.train.interpretation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "pycharm": { 8 | "is_executing": false 9 | } 10 | }, 11 | "outputs": [ 12 | { 13 | "name": "stdout", 14 | "output_type": "stream", 15 | "text": [ 16 | "Can't import one of these: No module named 'pybullet'\n", 17 | "Can't import one of these: No module named 'gym_maze'\n", 18 | "Can't import one of these: No module named 'gym_minigrid'\n" 19 | ] 20 | } 21 | ], 22 | "source": [ 23 | "from fast_rl.agents.dqn import *\n", 24 | "from fast_rl.agents.dqn_models import FixedTargetDQNModule\n", 25 | "from fast_rl.core.agent_core import *\n", 26 | "from fast_rl.core.data_block import *\n", 27 | "from fast_rl.core.train import *" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": { 34 | "pycharm": { 35 | "is_executing": true 36 | } 37 | }, 38 | "outputs": [ 39 | { 40 | "data": { 41 | "text/html": [ 42 | "\n", 43 | " \n", 44 | " \n", 45 | " \n", 46 | " \n", 47 | " \n", 48 | " \n", 49 | " \n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 
| " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | "
epoch  train_loss  valid_loss  time
0      1.095179    #na#        00:00
1      1.026340    #na#        00:00
2      1.007764    #na#        00:00
3      1.001356    #na#        00:00
4      0.996845    #na#        00:00
5      0.993165    #na#        00:00
6      0.988180    #na#        00:00
7      0.986040    #na#        00:00
8      0.982307    #na#        00:00
9      0.976414    #na#        00:00
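Note: valid_loss reads #na# for every epoch because the data bunch in the source cell below is created with add_valid=False, so no validation pass is run.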
" 114 | ], 115 | "text/plain": [ 116 | "" 117 | ] 118 | }, 119 | "metadata": {}, 120 | "output_type": "display_data" 121 | } 122 | ], 123 | "source": [ 124 | "data = MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=32, add_valid=False, \n", 125 | " memory_management_strategy='k_partitions_top', k=3)\n", 126 | "model = create_dqn_model(data, FixedTargetDQNModule, opt=torch.optim.RMSprop)\n", 127 | "memory = ExperienceReplay(10000)\n", 128 | "exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)\n", 129 | "learner = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method)\n", 130 | "learner.fit(10)" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 7, 136 | "metadata": { 137 | "pycharm": { 138 | "is_executing": true, 139 | "name": "#%%\n" 140 | } 141 | }, 142 | "outputs": [ 143 | { 144 | "name": "stderr", 145 | "output_type": "stream", 146 | "text": [ 147 | "\r", 148 | "t: 0%| | 0/10 [00:00" 179 | ], 180 | "text/plain": [ 181 | "" 182 | ] 183 | }, 184 | "execution_count": 7, 185 | "metadata": {}, 186 | "output_type": "execute_result" 187 | } 188 | ], 189 | "source": [ 190 | "interp = AgentInterpretation(learner, ds_type=DatasetType.Train)\n", 191 | "interp.generate_gif(2).plot()" 192 | ] 193 | } 194 | ], 195 | "metadata": { 196 | "kernelspec": { 197 | "display_name": "Python 3", 198 | "language": "python", 199 | "name": "python3" 200 | }, 201 | "language_info": { 202 | "codemirror_mode": { 203 | "name": "ipython", 204 | "version": 3 205 | }, 206 | "file_extension": ".py", 207 | "mimetype": "text/x-python", 208 | "name": "python", 209 | "nbconvert_exporter": "python", 210 | "pygments_lexer": "ipython3", 211 | "version": "3.6.7" 212 | }, 213 | "pycharm": { 214 | "stem_cell": { 215 | "cell_type": "raw", 216 | "metadata": { 217 | "collapsed": false 218 | }, 219 | "source": [] 220 | } 221 | } 222 | }, 223 | "nbformat": 4, 224 | "nbformat_minor": 1 225 | } 226 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: fastrl 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - fastai 6 | dependencies: 7 | - python=3.6 8 | - pip 9 | - cuda100 10 | - fastprogress 11 | - fastai=1.0.58 12 | - jupyter 13 | - notebook 14 | - setuptools 15 | - pip: 16 | - pytest 17 | - nvidia-ml-py3 18 | - dataclasses 19 | - numpy==1.17.4 20 | - pandas 21 | - pyyaml 22 | - opencv-contrib-python 23 | - requests 24 | - sklearn 25 | - cython 26 | - gym[box2d, atari] 27 | - easydict 28 | - matplotlib 29 | - jupyter_console 30 | - moviepy 31 | - pygifsicle -------------------------------------------------------------------------------- /fast_rl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/fast_rl/__init__.py -------------------------------------------------------------------------------- /fast_rl/agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/fast_rl/agents/__init__.py -------------------------------------------------------------------------------- /fast_rl/agents/ddpg.py: -------------------------------------------------------------------------------- 1 | from 
fastai.basic_train import LearnerCallback 2 | import fastai.tabular.data 3 | from fastai.torch_core import * 4 | 5 | from fast_rl.agents.ddpg_models import DDPGModule 6 | from fast_rl.core.agent_core import ExperienceReplay, ExplorationStrategy, Experience 7 | from fast_rl.core.basic_train import AgentLearner 8 | from fast_rl.core.data_block import MDPDataBunch, MDPStep, FEED_TYPE_STATE, FEED_TYPE_IMAGE 9 | 10 | 11 | class DDPGLearner(AgentLearner): 12 | def __init__(self, data: MDPDataBunch, model, memory, exploration_method, trainers, opt=optim.Adam, 13 | **kwargs): 14 | self.memory: Experience = memory 15 | self.exploration_method: ExplorationStrategy = exploration_method 16 | super().__init__(data=data, model=model, opt=opt, **kwargs) 17 | self.ddpg_trainers = listify(trainers) 18 | for t in self.ddpg_trainers: self.callbacks.append(t(self)) 19 | 20 | def predict(self, element, **kwargs): 21 | with torch.no_grad(): 22 | training = self.model.training 23 | if element.shape[0] == 1: self.model.eval() 24 | pred = self.model(element) 25 | if training: self.model.train() 26 | return self.exploration_method.perturb(pred.detach().cpu().numpy(), self.data.action.action_space) 27 | 28 | def interpret_q(self, item): 29 | with torch.no_grad(): 30 | return self.model.interpret_q(item).cpu().numpy().item() 31 | 32 | 33 | class BaseDDPGTrainer(LearnerCallback): 34 | def __init__(self, learn): 35 | super().__init__(learn) 36 | self.max_episodes = 0 37 | self.episode = 0 38 | self.iteration = 0 39 | self.copy_over_frequency = 3 40 | 41 | @property 42 | def learn(self) -> DDPGLearner: 43 | return self._learn() 44 | 45 | def on_train_begin(self, n_epochs, **kwargs: Any): 46 | self.max_episodes = n_epochs 47 | 48 | def on_epoch_begin(self, epoch, **kwargs: Any): 49 | self.episode = epoch 50 | self.iteration = 0 51 | 52 | def on_loss_begin(self, **kwargs: Any): 53 | """Performs tree updates, exploration updates, and model optimization.""" 54 | if self.learn.model.training: self.learn.memory.update(item=self.learn.data.x.items[-1]) 55 | self.learn.exploration_method.update(self.episode, max_episodes=self.max_episodes, explore=self.learn.model.training) 56 | if not self.learn.warming_up: 57 | samples: List[MDPStep] = self.memory.sample(self.learn.data.bs) 58 | post_optimize = self.learn.model.optimize(samples) 59 | if self.learn.model.training: 60 | self.learn.memory.refresh(post_optimize=post_optimize) 61 | self.learn.model.target_copy_over() 62 | self.iteration += 1 63 | 64 | 65 | def create_ddpg_model(data: MDPDataBunch, base_arch: DDPGModule, layers=None, ignore_embed=False, channels=None, 66 | opt=torch.optim.RMSprop, loss_func=None, **kwargs): 67 | bs, state, action = data.bs, data.state, data.action 68 | nc, w, h, n_conv_blocks = -1, -1, -1, [] if state.mode == FEED_TYPE_STATE else ifnone(channels, [32, 32, 32]) 69 | if state.mode == FEED_TYPE_IMAGE: nc, w, h = state.s.shape[3], state.s.shape[2], state.s.shape[1] 70 | _layers = ifnone(layers, [400, 200] if len(n_conv_blocks) == 0 else [200, 200]) 71 | if ignore_embed or np.any(state.n_possible_values == np.inf) or state.mode == FEED_TYPE_IMAGE: emb_szs = [] 72 | else: emb_szs = [(d+1, int(fastai.tabular.data.emb_sz_rule(d))) for d in state.n_possible_values.reshape(-1, )] 73 | ao = int(action.taken_action.shape[1]) 74 | model = base_arch(ni=state.s.shape[1], ao=ao, layers=_layers, emb_szs=emb_szs, n_conv_blocks=n_conv_blocks, 75 | nc=nc, w=w, h=h, opt=opt, loss_func=loss_func, **kwargs) 76 | return model 77 | 78 | 79 | ddpg_config = { 80 | 
DDPGModule: BaseDDPGTrainer 81 | } 82 | 83 | 84 | def ddpg_learner(data: MDPDataBunch, model, memory: ExperienceReplay, exploration_method: ExplorationStrategy, 85 | trainers=None, **kwargs): 86 | trainers = ifnone(trainers, ddpg_config[model.__class__]) 87 | return DDPGLearner(data, model, memory, exploration_method, trainers, **kwargs) 88 | -------------------------------------------------------------------------------- /fast_rl/agents/ddpg_models.py: -------------------------------------------------------------------------------- 1 | from fastai.callback import OptimWrapper 2 | 3 | from fast_rl.core.layers import * 4 | 5 | 6 | class CriticModule(nn.Sequential): 7 | def __init__(self, ni: int, ao: int, layers: List[int], batch_norm=False, 8 | n_conv_blocks: Collection[int] = 0, nc=3, emb_szs: ListSizes = None, 9 | w=-1, h=-1, ks=None, stride=None, conv_kern_proportion=0.1, stride_proportion=0.1, pad=False): 10 | super().__init__() 11 | self.switched, self.batch_norm = False, batch_norm 12 | # self.ks, self.stride = ([], []) if len(n_conv_blocks) == 0 else ks_stride(ks, stride, w, h, n_conv_blocks, conv_kern_proportion, stride_proportion) 13 | self.ks, self.stride=([], []) if len(n_conv_blocks)==0 else (ifnone(ks,[10,10,10]),ifnone(stride,[5,5,5])) 14 | self.action_model = nn.Sequential() 15 | _layers = [conv_bn_lrelu(ch, self.nf, ks=ks, stride=stride, pad=pad, bn=self.batch_norm) for ch, self.nf, ks, stride in zip([nc]+n_conv_blocks[:-1],n_conv_blocks, self.ks, self.stride)] 16 | if _layers: ni = self.setup_conv_block(_layers=_layers, ni=ni, nc=nc, w=w, h=h) 17 | else: 18 | self.add_module('lin_state_block', StateActionPassThrough(nn.Linear(ni, layers[0]))) 19 | ni, layers = layers[0], layers[1:] 20 | self.setup_linear_block(_layers=_layers, ni=ni, nc=nc, w=w, h=h, emb_szs=emb_szs, layers=layers, ao=ao) 21 | self.init_weights(self) 22 | 23 | def setup_conv_block(self, _layers, ni, nc, w, h): 24 | self.add_module('conv_block', StateActionPassThrough(nn.Sequential(*(self.fix_switched_channels(ni, nc, _layers) + [Flatten()])))) 25 | return int(self(torch.zeros((2, 1, w, h, nc) if self.switched else (2, 1, nc, w, h)))[0].view(-1, ).shape[0]) 26 | 27 | def setup_linear_block(self, _layers, ni, nc, w, h, emb_szs, layers, ao): 28 | tabular_model = TabularModel(emb_szs=emb_szs, n_cont=ni+ao if not emb_szs else ao, layers=layers, out_sz=1, 29 | use_bn=self.batch_norm) 30 | if not emb_szs: tabular_model.embeds = None 31 | if not self.batch_norm: tabular_model.bn_cont = FakeBatchNorm() 32 | self.add_module('lin_block', CriticTabularEmbedWrapper(tabular_model, exclude_cat=not emb_szs)) 33 | 34 | def fix_switched_channels(self, current_channels, expected_channels, layers: list): 35 | if current_channels == expected_channels: 36 | return layers 37 | else: 38 | self.switched = True 39 | return [ChannelTranspose()] + layers 40 | 41 | def init_weights(self, m): 42 | if type(m) == nn.Linear: 43 | torch.nn.init.xavier_uniform_(m.weight) 44 | m.bias.data.fill_(0.01) 45 | 46 | 47 | class ActorModule(nn.Sequential): 48 | def __init__(self, ni: int, ao: int, layers: Collection[int],batch_norm = False, 49 | n_conv_blocks: Collection[int] = 0, nc=3, emb_szs: ListSizes = None, 50 | w=-1, h=-1, ks=None, stride=None, conv_kern_proportion=0.1, stride_proportion=0.1, pad=False): 51 | super().__init__() 52 | self.switched, self.batch_norm = False, batch_norm 53 | # self.ks, self.stride = ([], []) if len(n_conv_blocks) == 0 else ks_stride(ks, stride, w, h, n_conv_blocks, conv_kern_proportion, stride_proportion) 54 | 
self.ks, self.stride=([], []) if len(n_conv_blocks)==0 else (ifnone(ks,[10,10,10]),ifnone(stride,[5,5,5])) 55 | self.action_model = nn.Sequential() 56 | _layers = [conv_bn_lrelu(ch, self.nf, ks=ks, stride=stride, pad=pad, bn=self.batch_norm) for ch, self.nf, ks, stride in zip([nc]+n_conv_blocks[:-1],n_conv_blocks, self.ks, self.stride)] 57 | if _layers: ni = self.setup_conv_block(_layers=_layers, ni=ni, nc=nc, w=w, h=h) 58 | self.setup_linear_block(_layers=_layers, ni=ni, nc=nc, w=w, h=h, emb_szs=emb_szs, layers=layers, ao=ao) 59 | self.init_weights(self) 60 | 61 | def setup_conv_block(self, _layers, ni, nc, w, h): 62 | self.add_module('conv_block', nn.Sequential(*(self.fix_switched_channels(ni, nc, _layers) + [Flatten()]))) 63 | return int(self(torch.zeros((1, w, h, nc) if self.switched else (1, nc, w, h))).view(-1, ).shape[0]) 64 | 65 | def setup_linear_block(self, _layers, ni, nc, w, h, emb_szs, layers, ao): 66 | tabular_model = TabularModel(emb_szs=emb_szs, n_cont=ni if not emb_szs else 0, layers=layers, out_sz=ao, use_bn=self.batch_norm) 67 | 68 | if not emb_szs: tabular_model.embeds = None 69 | if not self.batch_norm: tabular_model.bn_cont = FakeBatchNorm() 70 | self.add_module('lin_block', TabularEmbedWrapper(tabular_model)) 71 | self.add_module('tanh', nn.Tanh()) 72 | 73 | def fix_switched_channels(self, current_channels, expected_channels, layers: list): 74 | if current_channels == expected_channels: 75 | return layers 76 | else: 77 | self.switched = True 78 | return [ChannelTranspose()] + layers 79 | 80 | def init_weights(self, m): 81 | if type(m) == nn.Linear: 82 | torch.nn.init.xavier_uniform_(m.weight) 83 | m.bias.data.fill_(0.01) 84 | 85 | class DDPGModule(Module): 86 | def __init__(self, ni: int, ao: int, layers: Collection[int], discount: float = 0.99, 87 | n_conv_blocks: Collection[int] = 0, nc=3, opt=None, emb_szs: ListSizes = None, loss_func=None, 88 | w=-1, h=-1, ks=None, stride=None, grad_clip=5, tau=1e-3, lr=1e-3, actor_lr=1e-4, 89 | batch_norm=False, **kwargs): 90 | r""" 91 | Implementation of a discrete control algorithm using an actor/critic architecture. 92 | 93 | Notes: 94 | Uses 4 networks, 2 actors, 2 critics. 95 | All models use batch norm for feature invariance. 96 | NNCritic simply predicts Q while the Actor proposes the actions to take given a s s. 97 | 98 | References: 99 | [1] Lillicrap, Timothy P., et al. "Continuous control with deep reinforcement learning." 100 | arXiv preprint arXiv:1509.02971 (2015). 101 | 102 | Args: 103 | data: Primary data object to use. 104 | memory: How big the tree buffer will be for offline training. 105 | tau: Defines how "soft/hard" we will copy the target networks over to the primary networks. 106 | discount: Determines the amount of discounting the existing Q reward. 107 | lr: Rate that the opt will learn parameter gradients. 
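        Example: DDPG is an actor/critic method for continuous action spaces. A minimal usage sketch follows
        (illustrative only: the environment name, batch size, and epoch count are assumptions, and GreedyEpsilon
        is used here simply because it is the exploration strategy shown in the docs notebooks; a noise-based
        strategy may suit continuous control better):

            data = MDPDataBunch.from_env('Pendulum-v0', render='rgb_array', bs=64, add_valid=False)
            model = create_ddpg_model(data, DDPGModule)
            memory = ExperienceReplay(10000)
            exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
            learner = ddpg_learner(data=data, model=model, memory=memory, exploration_method=exploration_method)
            learner.fit(450)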
108 | """ 109 | super().__init__() 110 | self.name = 'DDPG' 111 | self.lr = lr 112 | self.discount = discount 113 | self.tau = tau 114 | self.loss_func = None 115 | self.loss = None 116 | self.opt = None 117 | self.critic_optimizer = None 118 | self.batch_norm = batch_norm 119 | self.actor_lr = actor_lr 120 | 121 | self.action_model = ActorModule(ni=ni, ao=ao, layers=layers, nc=nc, emb_szs=emb_szs,batch_norm = batch_norm, 122 | w=w, h=h, ks=ks, n_conv_blocks=n_conv_blocks, stride=stride) 123 | self.critic_model = CriticModule(ni=ni, ao=ao, layers=layers, nc=nc, emb_szs=emb_szs, batch_norm = batch_norm, 124 | w=w, h=h, ks=ks, n_conv_blocks=n_conv_blocks, stride=stride) 125 | 126 | self.set_opt(opt) 127 | 128 | self.t_action_model = deepcopy(self.action_model) 129 | self.t_critic_model = deepcopy(self.critic_model) 130 | 131 | self.target_copy_over() 132 | self.tau = tau 133 | 134 | def set_opt(self, opt): 135 | self.opt=OptimWrapper.create(ifnone(optim.Adam, opt), lr=self.actor_lr, layer_groups=[self.action_model]) 136 | self.critic_optimizer=OptimWrapper.create(ifnone(optim.Adam, opt), lr=self.lr, layer_groups=[self.critic_model]) 137 | 138 | def optimize(self, sampled): 139 | r""" 140 | Performs separate updates to the actor and critic models. 141 | 142 | Get the predicted yi for optimizing the actor: 143 | 144 | .. math:: 145 | y_i = r_i + \lambda Q^'(s_{i+1}, \; \mu^'(s_{i+1} \;|\; \Theta^{\mu'}}\;|\; \Theta^{Q'}) 146 | 147 | On actor optimization, use the actor as the sample policy gradient. 148 | 149 | Returns: 150 | 151 | """ 152 | with torch.no_grad(): 153 | r = torch.cat([item.reward.float() for item in sampled]) 154 | s_prime = torch.cat([item.s_prime for item in sampled]) 155 | s = torch.cat([item.s for item in sampled]) 156 | a = torch.cat([item.a.float() for item in sampled]) 157 | 158 | y = r + self.discount * self.t_critic_model((s_prime, self.t_action_model(s_prime))) 159 | 160 | y_hat = self.critic_model((s, a)) 161 | 162 | critic_loss = self.loss_func(y_hat, y) 163 | 164 | if self.training: 165 | self.critic_optimizer.zero_grad() 166 | critic_loss.backward() 167 | self.critic_optimizer.step() 168 | 169 | actor_loss = -self.critic_model((s, self.action_model(s))).mean() 170 | 171 | self.loss = critic_loss.cpu().detach() 172 | 173 | if self.training: 174 | self.opt.zero_grad() 175 | actor_loss.backward() 176 | self.opt.step() 177 | 178 | with torch.no_grad(): 179 | post_info = {'td_error': (y - y_hat).cpu().numpy()} 180 | return post_info 181 | 182 | def forward(self, xi): 183 | training = self.training 184 | if xi.shape[0] == 1: self.eval() 185 | pred = self.action_model(xi) 186 | if training: self.train() 187 | return pred 188 | 189 | def target_copy_over(self): 190 | """ Soft target updates the actor and critic models..""" 191 | self.soft_target_copy_over(self.t_action_model, self.action_model, self.tau) 192 | self.soft_target_copy_over(self.t_critic_model, self.critic_model, self.tau) 193 | 194 | def soft_target_copy_over(self, t_m, f_m, tau): 195 | for target_param, local_param in zip(t_m.parameters(), f_m.parameters()): 196 | target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data) 197 | 198 | def interpret_q(self, item): 199 | with torch.no_grad(): 200 | return self.critic_model(torch.cat((item.s, item.a), 1)) -------------------------------------------------------------------------------- /fast_rl/agents/dqn.py: -------------------------------------------------------------------------------- 1 | from fastai.basic_train import LearnerCallback 
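For reference, the targets that DDPGModule.optimize (above) computes, written in the notation of Lillicrap et al. [1]: the critic regresses onto the bootstrapped target with the default mean-squared-error loss, while the actor ascends the critic's value of its own proposed actions.

.. math::
    y_i = r_i + \gamma \, Q'\big(s_{i+1}, \; \mu'(s_{i+1} \mid \theta^{\mu'}) \mid \theta^{Q'}\big)

    L_{critic} = \frac{1}{N} \sum_i \big(y_i - Q(s_i, a_i \mid \theta^{Q})\big)^2

    L_{actor} = -\frac{1}{N} \sum_i Q\big(s_i, \; \mu(s_i \mid \theta^{\mu}) \mid \theta^{Q}\big)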
2 | from fastai.tabular.data import emb_sz_rule 3 | 4 | from fast_rl.agents.dqn_models import * 5 | from fast_rl.core.agent_core import ExperienceReplay, ExplorationStrategy, Experience 6 | from fast_rl.core.basic_train import AgentLearner 7 | from fast_rl.core.data_block import MDPDataBunch, FEED_TYPE_STATE, FEED_TYPE_IMAGE, MDPStep 8 | 9 | 10 | class DQNLearner(AgentLearner): 11 | def __init__(self, data: MDPDataBunch, model, memory, exploration_method, trainers, opt=torch.optim.RMSprop, 12 | **learn_kwargs): 13 | self.memory: Experience = memory 14 | self.exploration_method: ExplorationStrategy = exploration_method 15 | super().__init__(data=data, model=model, opt=opt, **learn_kwargs) 16 | self.trainers = listify(trainers) 17 | for t in self.trainers: self.callbacks.append(t(self)) 18 | 19 | def predict(self, element, **kwargs): 20 | training = self.model.training 21 | if element.shape[0] == 1: self.model.eval() 22 | pred = self.model(element) 23 | if training: self.model.train() 24 | return self.exploration_method.perturb(torch.argmax(pred, 1), self.data.action.action_space) 25 | 26 | def interpret_q(self, item): 27 | with torch.no_grad(): 28 | return torch.sum(self.model(item.s)).cpu().numpy().item() 29 | 30 | 31 | class FixedTargetDQNTrainer(LearnerCallback): 32 | def __init__(self, learn, copy_over_frequency=3): 33 | r"""Handles updating the target model in a fixed target DQN. 34 | 35 | Args: 36 | learn: Basic Learner. 37 | copy_over_frequency: For every N iterations we want to update the target model. 38 | """ 39 | super().__init__(learn) 40 | self._order = 1 41 | self.iteration = 0 42 | self.copy_over_frequency = copy_over_frequency 43 | 44 | def on_step_end(self, **kwargs: Any): 45 | self.iteration += 1 46 | if self.iteration % self.copy_over_frequency == 0 and self.learn.model.training: 47 | self.learn.model.target_copy_over() 48 | 49 | 50 | class BaseDQNTrainer(LearnerCallback): 51 | def __init__(self, learn: DQNLearner, max_episodes=None): 52 | r"""Handles basic DQN end of step model optimization.""" 53 | super().__init__(learn) 54 | self.n_skipped = 0 55 | self._persist = max_episodes is not None 56 | self.max_episodes = max_episodes 57 | self.episode = -1 58 | self.iteration = 0 59 | # For the callback handler 60 | self._order = 0 61 | self.previous_item = None 62 | 63 | @property 64 | def learn(self) -> DQNLearner: 65 | return self._learn() 66 | 67 | def on_train_begin(self, n_epochs, **kwargs: Any): 68 | self.max_episodes = n_epochs if not self._persist else self.max_episodes 69 | 70 | def on_epoch_begin(self, epoch, **kwargs: Any): 71 | self.episode = epoch if not self._persist else self.episode + 1 72 | self.iteration = 0 73 | 74 | def on_loss_begin(self, **kwargs: Any): 75 | r"""Performs tree updates, exploration updates, and model optimization.""" 76 | if self.learn.model.training: self.learn.memory.update(item=self.learn.data.x.items[-1]) 77 | self.learn.exploration_method.update(self.episode, max_episodes=self.max_episodes, explore=self.learn.model.training) 78 | if not self.learn.warming_up: 79 | samples: List[MDPStep] = self.memory.sample(self.learn.data.bs) 80 | post_optimize = self.learn.model.optimize(samples) 81 | if self.learn.model.training: self.learn.memory.refresh(post_optimize=post_optimize) 82 | self.iteration += 1 83 | 84 | 85 | def create_dqn_model(data: MDPDataBunch, base_arch: DQNModule, layers=None, ignore_embed=False, channels=None, 86 | opt=torch.optim.RMSprop, loss_func=None, lr=0.001, **kwargs): 87 | 
bs,state,action=data.bs,data.state,data.action 88 | nc, w, h, n_conv_blocks = -1, -1, -1, [] if state.mode == FEED_TYPE_STATE else ifnone(channels, [16, 16, 16]) 89 | if state.mode == FEED_TYPE_IMAGE: nc, w, h = state.s.shape[3], state.s.shape[2], state.s.shape[1] 90 | _layers = ifnone(layers, [64, 64]) 91 | if ignore_embed or np.any(state.n_possible_values == np.inf) or state.mode == FEED_TYPE_IMAGE: emb_szs = [] 92 | else: emb_szs = [(d+1, int(emb_sz_rule(d))) for d in state.n_possible_values.reshape(-1, )] 93 | ao = int(action.n_possible_values[0]) 94 | model = base_arch(ni=state.s.shape[1], ao=ao, layers=_layers, emb_szs=emb_szs, n_conv_blocks=n_conv_blocks, 95 | nc=nc, w=w, h=h, opt=opt, loss_func=loss_func, lr=lr, **kwargs) 96 | return model 97 | 98 | 99 | dqn_config = { 100 | DQNModule: [BaseDQNTrainer], 101 | DoubleDQNModule: [BaseDQNTrainer, FixedTargetDQNTrainer], 102 | DuelingDQNModule: [BaseDQNTrainer, FixedTargetDQNTrainer], 103 | DoubleDuelingModule: [BaseDQNTrainer, FixedTargetDQNTrainer], 104 | FixedTargetDQNModule: [BaseDQNTrainer, FixedTargetDQNTrainer] 105 | } 106 | 107 | 108 | def dqn_learner(data: MDPDataBunch, model: DQNModule, memory: ExperienceReplay, exploration_method: ExplorationStrategy, 109 | trainers=None, copy_over_frequency=300, **kwargs): 110 | trainers = ifnone(trainers, [c if c != FixedTargetDQNTrainer else partial(c, copy_over_frequency=copy_over_frequency) 111 | for c in dqn_config[model.__class__]]) 112 | return DQNLearner(data, model, memory, exploration_method, trainers, **kwargs) 113 | -------------------------------------------------------------------------------- /fast_rl/agents/dqn_models.py: -------------------------------------------------------------------------------- 1 | from fastai.callback import OptimWrapper 2 | 3 | from fast_rl.core.layers import * 4 | 5 | 6 | class DQNModule(Module): 7 | 8 | def __init__(self, ni: int, ao: int, layers: Collection[int], discount: float = 0.99, lr=0.001, 9 | n_conv_blocks: Collection[int] = 0, nc=3, opt=None, emb_szs: ListSizes = None, loss_func=None, 10 | w=-1, h=-1, ks: Union[None, list]=None, stride: Union[None, list]=None, grad_clip=5, 11 | conv_kern_proportion=0.1, stride_proportion=0.1, pad=False, batch_norm=False): 12 | r""" 13 | Basic DQN Module. 14 | 15 | Args: 16 | ni: Number of inputs. Expecting a flat state `[1 x ni]` 17 | ao: Number of actions to output. 18 | layers: Number of layers where is determined per element. 19 | n_conv_blocks: If `n_conv_blocks` is not 0, then convolutional blocks will be added 20 | to the head on top of existing linear layers. 21 | nc: Number of channels that will be expected by the convolutional blocks. 
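        Example (taken from the docs notebook rl.core.train.interpretation.ipynb; FixedTargetDQNModule and these
        hyperparameters are just one workable combination):

            data = MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=32, add_valid=False,
                                         memory_management_strategy='k_partitions_top', k=3)
            model = create_dqn_model(data, FixedTargetDQNModule, opt=torch.optim.RMSprop)
            memory = ExperienceReplay(10000)
            exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
            learner = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method)
            learner.fit(10)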
22 | """ 23 | super().__init__() 24 | self.name = 'DQN' 25 | self.loss = None 26 | self.loss_func = loss_func 27 | self.discount = discount 28 | self.gradient_clipping_norm = grad_clip 29 | self.lr = lr 30 | self.batch_norm = batch_norm 31 | self.switched = False 32 | # self.ks, self.stride = ([], []) if len(n_conv_blocks) == 0 else ks_stride(ks, stride, w, h, n_conv_blocks, conv_kern_proportion, stride_proportion) 33 | self.ks, self.stride=([], []) if len(n_conv_blocks)==0 else (ifnone(ks, [10, 10, 10]), ifnone(stride, [5, 5, 5])) 34 | self.action_model = nn.Sequential() 35 | _layers = [conv_bn_lrelu(ch, self.nf, ks=ks, stride=stride, pad=pad, bn=self.batch_norm) for ch, self.nf, ks, stride in zip([nc]+n_conv_blocks[:-1],n_conv_blocks, self.ks, self.stride)] 36 | 37 | if _layers: ni = self.setup_conv_block(_layers=_layers, ni=ni, nc=nc, w=w, h=h) 38 | self.setup_linear_block(_layers=_layers, ni=ni, nc=nc, w=w, h=h, emb_szs=emb_szs, layers=layers, ao=ao) 39 | self.init_weights(self.action_model) 40 | self.opt = None 41 | self.set_opt(opt) 42 | 43 | def set_opt(self, opt): 44 | self.opt=OptimWrapper.create(ifnone(optim.Adam, opt), lr=self.lr, layer_groups=[self.action_model]) 45 | 46 | def setup_conv_block(self, _layers, ni, nc, w, h): 47 | self.action_model.add_module('conv_block', nn.Sequential(*(self.fix_switched_channels(ni, nc, _layers) + [Flatten()]))) 48 | training = self.action_model.training 49 | self.action_model.eval() 50 | ni = int(self.action_model(torch.zeros((1, w, h, nc) if self.switched else (1, nc, w, h))).view(-1, ).shape[0]) 51 | self.action_model.train(training) 52 | return ni 53 | 54 | def setup_linear_block(self, _layers, ni, nc, w, h, emb_szs, layers, ao): 55 | tabular_model = TabularModel(emb_szs=emb_szs, n_cont=ni if not emb_szs else 0, layers=layers, out_sz=ao, use_bn=self.batch_norm) 56 | if not emb_szs: tabular_model.embeds = None 57 | if not self.batch_norm: tabular_model.bn_cont = FakeBatchNorm() 58 | self.action_model.add_module('lin_block', TabularEmbedWrapper(tabular_model)) 59 | 60 | def fix_switched_channels(self, current_channels, expected_channels, layers: list): 61 | if current_channels == expected_channels: 62 | return layers 63 | else: 64 | self.switched = True 65 | return [ChannelTranspose()] + layers 66 | 67 | def forward(self, xi: Tensor): 68 | training = self.training 69 | if xi.shape[0] == 1: self.eval() 70 | pred = self.action_model(xi) 71 | if training: self.train() 72 | return pred 73 | 74 | def init_weights(self, m): 75 | if type(m) == nn.Linear: 76 | torch.nn.init.xavier_uniform_(m.weight) 77 | m.bias.data.fill_(0.01) 78 | 79 | def sample_mask(self, d): 80 | return torch.sub(1.0, d) 81 | 82 | def optimize(self, sampled): 83 | r"""Uses ER to optimize the Q-net (without fixed targets). 84 | 85 | Uses the equation: 86 | 87 | .. 
math:: 88 | Q^{*}(s, a) = \mathbb{E}_{s'∼ \Big\epsilon} \Big[r + \lambda \displaystyle\max_{a'}(Q^{*}(s' , a')) 89 | \;|\; s, a \Big] 90 | 91 | 92 | Returns (dict): Optimization information 93 | 94 | """ 95 | with torch.no_grad(): 96 | r = torch.cat([item.reward.float() for item in sampled]) 97 | s_prime = torch.cat([item.s_prime for item in sampled]) 98 | s = torch.cat([item.s for item in sampled]) 99 | a = torch.cat([item.a.long() for item in sampled]) 100 | d = torch.cat([item.done.float() for item in sampled]) 101 | masking = self.sample_mask(d) 102 | 103 | y_hat = self.y_hat(s, a) 104 | y = self.y(s_prime, masking, r, y_hat) 105 | 106 | loss = self.loss_func(y, y_hat) 107 | 108 | if self.training: 109 | self.opt.zero_grad() 110 | loss.backward() 111 | torch.nn.utils.clip_grad_norm_(self.action_model.parameters(), self.gradient_clipping_norm) 112 | for param in self.action_model.parameters(): 113 | if param.grad is not None: param.grad.data.clamp_(-1, 1) 114 | self.opt.step() 115 | 116 | with torch.no_grad(): 117 | self.loss = loss 118 | post_info = {'td_error': to_detach(y - y_hat).cpu().numpy()} 119 | return post_info 120 | 121 | def y_hat(self, s, a): 122 | return self.action_model(s).gather(1, a) 123 | 124 | def y(self, s_prime, masking, r, y_hat): 125 | return self.discount * self.action_model(s_prime).max(1)[0].unsqueeze(1) * masking + r.expand_as(y_hat) 126 | 127 | 128 | class FixedTargetDQNModule(DQNModule): 129 | def __init__(self, ni: int, ao: int, layers: Collection[int], tau=1, **kwargs): 130 | super().__init__(ni, ao, layers, **kwargs) 131 | self.name = 'Fixed Target DQN' 132 | self.tau = tau 133 | self.target_model = copy(self.action_model) 134 | 135 | def target_copy_over(self): 136 | r""" Updates the target network from calls in the FixedTargetDQNTrainer callback.""" 137 | # self.target_net.load_state_dict(self.action_model.state_dict()) 138 | for target_param, local_param in zip(self.target_model.parameters(), self.action_model.parameters()): 139 | target_param.data.copy_(self.tau * local_param.data + (1.0 - self.tau) * target_param.data) 140 | 141 | def y(self, s_prime, masking, r, y_hat): 142 | r""" 143 | Uses the equation: 144 | 145 | .. math:: 146 | 147 | Q^{*}(s, a) = \mathbb{E}_{s'∼ \Big\epsilon} \Big[r + \lambda \displaystyle\max_{a'}(Q^{*}(s' , a')) 148 | \;|\; s, a \Big] 149 | 150 | """ 151 | return self.discount * self.target_model(s_prime).max(1)[0].unsqueeze(1) * masking + r.expand_as(y_hat) 152 | 153 | 154 | class DoubleDQNModule(FixedTargetDQNModule): 155 | def __init__(self, ni: int, ao: int, layers: Collection[int], **kwargs): 156 | super().__init__(ni, ao, layers, **kwargs) 157 | self.name = 'DDQN' 158 | 159 | def calc_y(self, s_prime, masking, r, y_hat): 160 | return self.discount * self.target_model(s_prime).gather(1, self.action_model(s_prime).argmax(1).unsqueeze( 161 | 1)) * masking + r.expand_as(y_hat) 162 | 163 | 164 | class DuelingBlock(nn.Module): 165 | def __init__(self, ao, stream_input_size): 166 | super().__init__() 167 | 168 | self.val = nn.Linear(stream_input_size, 1) 169 | self.adv = nn.Linear(stream_input_size, ao) 170 | 171 | def forward(self, xi): 172 | r"""Splits the base neural net output into 2 streams to evaluate the advantage and v of the s space and 173 | corresponding actions. 174 | 175 | .. 
math:: 176 | Q(s,a;\; \Theta, \\alpha, \\beta) = V(s;\; \Theta, \\beta) + A(s, a;\; \Theta, \\alpha) - \\frac{1}{|A|} 177 | \\Big\\sum_{a'} A(s, a';\; \Theta, \\alpha) 178 | 179 | """ 180 | val, adv = self.val(xi), self.adv(xi) 181 | xi = val.expand_as(adv) + (adv - adv.mean()).squeeze(0) 182 | return xi 183 | 184 | 185 | class DuelingDQNModule(FixedTargetDQNModule): 186 | def __init__(self, **kwargs): 187 | super().__init__(**kwargs) 188 | self.name = 'Dueling DQN' 189 | 190 | def setup_linear_block(self, _layers, ni, nc, w, h, emb_szs, layers, ao): 191 | tabular_model = TabularModel(emb_szs=emb_szs, n_cont=ni if not emb_szs else 0, layers=layers, out_sz=ao, 192 | use_bn=self.batch_norm) 193 | if not emb_szs: tabular_model.embeds = None 194 | if not self.batch_norm: tabular_model.bn_cont = FakeBatchNorm() 195 | tabular_model.layers, removed_layer = split_model(tabular_model.layers, [last_layer(tabular_model)]) 196 | ni = removed_layer[0].in_features 197 | self.action_model.add_module('lin_block', TabularEmbedWrapper(tabular_model)) 198 | self.action_model.add_module('dueling_block', DuelingBlock(ao, ni)) 199 | 200 | 201 | class DoubleDuelingModule(DuelingDQNModule, DoubleDQNModule): 202 | def __init__(self, **kwargs): 203 | super().__init__(**kwargs) 204 | self.name = 'DDDQN' 205 | -------------------------------------------------------------------------------- /fast_rl/core/Interpreter.py: -------------------------------------------------------------------------------- 1 | import io 2 | from functools import partial 3 | from typing import List, Tuple, Dict 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import scipy.stats as st 8 | import torch 9 | from PIL import Image 10 | from fastai.train import Interpretation, DatasetType, copy 11 | from gym.spaces import Box 12 | from itertools import product 13 | from matplotlib.axes import Axes 14 | from matplotlib.figure import Figure 15 | from moviepy.video.VideoClip import VideoClip 16 | from moviepy.video.io.bindings import mplfig_to_npimage 17 | from torch import nn 18 | 19 | from fast_rl.core import Learner 20 | from fast_rl.core.data_block import MarkovDecisionProcessSliceAlpha, FEED_TYPE_IMAGE 21 | 22 | 23 | class AgentInterpretationAlpha(Interpretation): 24 | def __init__(self, learn: Learner, ds_type: DatasetType = DatasetType.Valid, base_chart_size=(20, 10)): 25 | """ 26 | Handles converting a learner, and it's runs into useful human interpretable information. 27 | 28 | Notes: 29 | This class is called AgentInterpretationAlpha because it will overall get deprecated. 30 | The final working version will be called AgentInterpretation. 
31 | 32 | Args: 33 | learn: 34 | """ 35 | super().__init__(learn, None, None, None, ds_type=ds_type) 36 | self.current_animation = None 37 | plt.rcParams["figure.figsize"] = base_chart_size 38 | 39 | def _get_items(self, ignore=True): 40 | episodes = list(self.ds.x.info.keys()) 41 | if ignore or len(episodes) == 0: return self.ds.x.items 42 | return [item for item in self.ds.x.items if item.episode in episodes] 43 | 44 | @classmethod 45 | def from_learner(cls, learn: Learner, ds_type: DatasetType = DatasetType.Valid, activ: nn.Module = None): 46 | raise NotImplementedError 47 | 48 | def normalize(self, item: np.array): 49 | if np.max(item) - np.min(item) != 0: 50 | return np.divide(item + np.min(item), np.max(item) - np.min(item)) 51 | else: 52 | item.fill(1) 53 | return item 54 | 55 | def top_losses(self, k: int = None, largest=True): 56 | raise NotImplementedError 57 | 58 | def reward_heatmap(self, episode_slices: List[MarkovDecisionProcessSliceAlpha], action=None): 59 | """ 60 | Takes a state_space and uses the agent to heat map rewards over the space. 61 | 62 | We first need to determine if the s space is discrete or discrete. 63 | 64 | Args: 65 | state_space: 66 | 67 | Returns: 68 | 69 | """ 70 | if action is not None: action = torch.tensor(action).long() 71 | current_state_slice = [p for p in product( 72 | np.arange(min(self.ds.env.observation_space.low), max(self.ds.env.observation_space.high) + 1), 73 | repeat=len(self.ds.env.observation_space.high))] 74 | heat_map = np.zeros(np.add(self.ds.env.observation_space.high, 1)) 75 | with torch.no_grad(): 76 | for state in current_state_slice: 77 | if action is not None: 78 | heat_map[state] = self.learn.model(torch.from_numpy(np.array(state)).unsqueeze(0))[0].gather(0, action) 79 | else: 80 | self.learn.model.eval() 81 | if self.learn.model.name == 'DDPG': 82 | heat_map[state] = self.learn.model.critic_model(torch.cat((torch.from_numpy(np.array(state)).unsqueeze(0).float(), self.learn.model.action_model(torch.from_numpy(np.array(state)).unsqueeze(0).float())), 1)) 83 | else: 84 | heat_map[state] = self.learn.model(torch.from_numpy(np.array(state)).unsqueeze(0))[0].max().numpy() 85 | return heat_map 86 | 87 | def plot_heatmapped_episode(self, episode, fig_size=(13, 5), action_index=None, return_heat_maps=False): 88 | """ 89 | Generates plots of heatmapped s spaces for analyzing reward distribution. 90 | 91 | Currently only makes sense for grid based envs. Will be expecting gym_maze environments that are discrete. 
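        Illustrative call (hypothetical values; as noted above, this requires a discrete, maze-style env such as
        one from gym_maze):

            interp = AgentInterpretationAlpha(learn, ds_type=DatasetType.Train)
            heat_maps = interp.plot_heatmapped_episode(-1, action_index=None, return_heat_maps=True)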
92 | 93 | Returns: 94 | 95 | """ 96 | if not str(self.ds.env.spec).__contains__('maze'): 97 | raise NotImplementedError('Currently only supports gym_maze envs that have discrete s spaces') 98 | if not isinstance(self.ds.state_size, Box): 99 | raise NotImplementedError('Currently only supports Box based s spaces with 2 dimensions') 100 | 101 | items = self._get_items() 102 | heat_maps = [] 103 | 104 | # For each episode 105 | buffer = [] 106 | episode = episode if episode != -1 else list(set([i.episode for i in items]))[-1] 107 | for item in [i for i in items if i.episode == episode]: 108 | buffer.append(item) 109 | heat_map = self.reward_heatmap(buffer, action=action_index) 110 | heat_maps.append((copy(heat_map), copy(buffer[-1]), copy(episode))) 111 | 112 | plots = [] 113 | for single_heatmap in [heat_maps[-1]]: 114 | fig, ax = plt.subplots(1, 2, figsize=fig_size) 115 | fig.suptitle(f'Episode {episode}') 116 | ax[0].imshow(single_heatmap[1].to_one().data) 117 | im = ax[1].imshow(single_heatmap[0]) 118 | ax[0].grid(False) 119 | ax[1].grid(False) 120 | ax[0].set_title('Final State Snapshot') 121 | ax[1].set_title('State Space Heatmap') 122 | fig.colorbar(im, ax=ax[1]) 123 | 124 | buf = io.BytesIO() 125 | fig.savefig(buf, format='png') 126 | # Closing the figure prevents it from being displayed directly inside 127 | # the notebook. 128 | plt.close(fig) 129 | buf.seek(0) 130 | # Create Image object 131 | plots.append(np.array(Image.open(buf))[:, :, :3]) 132 | 133 | for plot in plots: 134 | plt.grid(False) 135 | plt.xticks([]) 136 | plt.yticks([]) 137 | plt.tight_layout() 138 | plt.imshow(plot) 139 | plt.show() 140 | 141 | if return_heat_maps: return heat_maps 142 | 143 | def plot_episode(self, episode): 144 | items = self._get_items(False) # type: List[MarkovDecisionProcessSliceAlpha] 145 | 146 | episode_counter = 0 147 | # For each episode 148 | buffer = [] 149 | for item in items: 150 | buffer.append(item) 151 | if item.done: 152 | if episode_counter == episode: 153 | break 154 | episode_counter += 1 155 | buffer = [] 156 | 157 | plots = [] 158 | with torch.no_grad(): 159 | agent_reward_plots = [self.learn.model(torch.from_numpy(np.array(i.current_state))).max().numpy() for i in 160 | buffer] 161 | fig, ax = plt.subplots(1, 1, figsize=(5, 5)) 162 | fig.suptitle(f'Episode {episode}') 163 | ax.plot(agent_reward_plots) 164 | ax.set_xlabel('Time Steps') 165 | ax.set_ylabel('Max Expected Reward from Agent') 166 | 167 | buf = io.BytesIO() 168 | fig.savefig(buf, format='png') 169 | # Closing the figure prevents it from being displayed directly inside 170 | # the notebook. 
171 | plt.close(fig) 172 | buf.seek(0) 173 | # Create Image object 174 | plots.append(np.array(Image.open(buf))[:, :, :3]) 175 | 176 | for plot in plots: 177 | plt.grid(False) 178 | plt.xticks([]) 179 | plt.yticks([]) 180 | plt.tight_layout() 181 | plt.imshow(plot) 182 | plt.show() 183 | 184 | def get_agent_accuracy_density(self, items, episode_num=None): 185 | x = None 186 | y = None 187 | 188 | for episode in [_ for _ in list(set(mdp.episode for mdp in items)) if episode_num is None or episode_num == _]: 189 | subset = [item for item in items if item.episode == episode] 190 | state = np.array([_.current_state for _ in subset]) 191 | result_state = np.array([_.result_state for _ in subset]) 192 | 193 | prim_q_pred = self.learn.model(torch.from_numpy(state)) 194 | target_q_pred = self.learn.model.target_net(torch.from_numpy(state).float()) 195 | state_difference = (prim_q_pred - target_q_pred).sum(1) 196 | prim_q_pred = self.learn.model(torch.from_numpy(result_state)) 197 | target_q_pred = self.learn.model.target_net(torch.from_numpy(result_state).float()) 198 | result_state_difference = (prim_q_pred - target_q_pred).sum(1) 199 | 200 | x = state_difference if x is None else np.hstack((x, state_difference)) 201 | y = result_state_difference if y is None else np.hstack((y, result_state_difference)) 202 | 203 | return x, y 204 | 205 | def plot_agent_accuracy_density(self, episode_num=None): 206 | """ 207 | Heat maps the density of actual vs estimated q v. Good reference for this is at [1]. 208 | 209 | References: 210 | [1] "Simple Example Of 2D Density Plots In Python." Medium. N. p., 2019. Web. 31 Aug. 2019. 211 | https://towardsdatascience.com/simple-example-of-2d-density-plots-in-python-83b83b934f67 212 | 213 | Returns: 214 | 215 | """ 216 | items = self._get_items(False) # type: List[MarkovDecisionProcessSliceAlpha] 217 | x, y = self.get_agent_accuracy_density(items, episode_num) 218 | 219 | fig = plt.figure(figsize=(8, 8)) 220 | ax = fig.gca() 221 | fig.suptitle(f'{self.learn.model.name} for {self.ds.env.spec._env_name}') 222 | ax.set_ylabel('State / State Prime Q Value Deviation') 223 | ax.set_xlabel('Iterations') 224 | ax.plot(np.hstack([x, y])) 225 | plt.show() 226 | 227 | def get_q_density(self, items, episode_num=None): 228 | x = None 229 | y = None 230 | 231 | for episode in [_ for _ in list(set(mdp.episode for mdp in items)) if episode_num is None or episode_num == _]: 232 | subset = [item for item in items if item.episode == episode] 233 | r = np.array([_.reward for _ in subset]) 234 | # Gets the total accumulated r over a single markov chain 235 | actual_returns = np.flip([np.cumsum(r)[i:][0] for i in np.flip(np.arange(len(r)))]).reshape(1, -1) 236 | estimated_returns = self.learn.model.interpret_q(subset).view(1, -1).numpy() 237 | x = actual_returns if x is None else np.hstack((x, actual_returns)) 238 | y = estimated_returns if y is None else np.hstack((y, estimated_returns)) 239 | 240 | return self.normalize(x), self.normalize(y) 241 | 242 | def plot_q_density(self, episode_num=None): 243 | """ 244 | Heat maps the density of actual vs estimated q v. Good reference for this is at [1]. 245 | 246 | References: 247 | [1] "Simple Example Of 2D Density Plots In Python." Medium. N. p., 2019. Web. 31 Aug. 2019. 
248 | https://towardsdatascience.com/simple-example-of-2d-density-plots-in-python-83b83b934f67 249 | 250 | Returns: 251 | 252 | """ 253 | items = self._get_items(False) # type: List[MarkovDecisionProcessSliceAlpha] 254 | x, y = self.get_q_density(items, episode_num) 255 | 256 | # Define the borders 257 | deltaX = (np.max(x) - np.min(x)) / 10 258 | deltaY = (np.max(y) - np.min(y)) / 10 259 | xmin = np.min(x) - deltaX 260 | xmax = np.max(x) + deltaX 261 | ymin = np.min(y) - deltaY 262 | ymax = np.max(y) + deltaY 263 | print(xmin, xmax, ymin, ymax) 264 | # Create meshgrid 265 | xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j] 266 | 267 | positions = np.vstack([xx.ravel(), yy.ravel()]) 268 | values = np.vstack([x, y]) 269 | 270 | kernel = st.gaussian_kde(values) 271 | 272 | f = np.reshape(kernel(positions).T, xx.shape) 273 | 274 | fig = plt.figure(figsize=(8, 8)) 275 | ax = fig.gca() 276 | ax.set_xlim(xmin, xmax) 277 | ax.set_ylim(ymin, ymax) 278 | cfset = ax.contourf(xx, yy, f, cmap='coolwarm') 279 | ax.imshow(np.rot90(f), cmap='coolwarm', extent=[xmin, xmax, ymin, ymax]) 280 | cset = ax.contour(xx, yy, f, colors='k') 281 | ax.clabel(cset, inline=1, fontsize=10) 282 | ax.set_xlabel('Actual Returns') 283 | ax.set_ylabel('Estimated Q') 284 | if episode_num is None: 285 | plt.title('2D Gaussian Kernel Q Density Estimation') 286 | else: 287 | plt.title(f'2D Gaussian Kernel Q Density Estimation for episode {episode_num}') 288 | plt.show() 289 | 290 | def plot_rewards_over_iterations(self, cumulative=False, return_rewards=False): 291 | items = self._get_items() 292 | r_iter = [el.reward[0] if np.ndim(el.reward) == 0 else np.average(el.reward) for el in items] 293 | if cumulative: r_iter = np.cumsum(r_iter) 294 | fig = plt.figure(figsize=(8, 8)) 295 | ax = fig.gca() 296 | fig.suptitle(f'{self.learn.model.name} for {self.ds.env.spec._env_name}') 297 | ax.set_ylabel('Rewards' if not cumulative else 'Cumulative Rewards') 298 | ax.set_xlabel('Iterations') 299 | ax.plot(r_iter) 300 | plt.show() 301 | if return_rewards: return r_iter 302 | 303 | def plot_rewards_over_episodes(self, cumulative=False, fig_size=(8, 8)): 304 | items = self._get_items() 305 | r_iter = [(el.reward[0] if np.ndim(el.reward) == 0 else np.average(el.reward), el.episode) for el in items] 306 | rewards, episodes = zip(*r_iter) 307 | if cumulative: rewards = np.cumsum(rewards) 308 | fig = plt.figure(figsize=(8, 8)) 309 | ax = fig.gca() 310 | fig.suptitle(f'{self.learn.model.name} for {self.ds.env.spec._env_name}') 311 | ax.set_ylabel('Rewards' if not cumulative else 'Cumulative Rewards') 312 | ax.set_xlabel('Episodes') 313 | ax.xaxis.set_ticks([i for i, el in enumerate(episodes) if episodes[i - 1] != el or i == 0]) 314 | ax.xaxis.set_ticklabels([el for i, el in enumerate(episodes) if episodes[i - 1] != el or i == 0]) 315 | ax.plot(rewards) 316 | plt.show() 317 | 318 | def episode_video_frames(self, episode=None) -> Dict[str, np.array]: 319 | """ Returns numpy arrays representing purely episode frames. 
""" 320 | items = self._get_items(False) 321 | if episode is None: episode_frames = {key: None for key in list(set([_.episode for _ in items]))} 322 | else: episode_frames = {episode: None} 323 | 324 | for key in episode_frames: 325 | if self.ds.feed_type == FEED_TYPE_IMAGE: 326 | episode_frames[key] = np.array([_.current_state for _ in items if key == _.episode]) 327 | else: 328 | episode_frames[key] = np.array([_.alternate_state for _ in items if key == _.episode]) 329 | 330 | return episode_frames 331 | 332 | def episode_to_gif(self, episode=None, path='', fps=30): 333 | frames = self.episode_video_frames(episode) 334 | 335 | for ep in frames: 336 | fig, ax = plt.subplots() 337 | animation = VideoClip(partial(self._make_frame, frames=frames[ep], axes=ax, fig=fig, title=f'Episode {ep}'), 338 | duration=frames[ep].shape[0]) 339 | animation.write_gif(path + f'episode_{ep}.gif', fps=fps) 340 | 341 | def _make_frame(self, t, frames, axes, fig, title): 342 | axes.clear() 343 | fig.suptitle(title) 344 | axes.imshow(frames[int(t)]) 345 | return mplfig_to_npimage(fig) 346 | 347 | def iplot_episode(self, episode, fps=30): 348 | if episode is None: raise ValueError('The episode cannot be None for jupyter display') 349 | x = self.episode_video_frames(episode)[episode] 350 | fig, ax = plt.subplots() 351 | 352 | self.current_animation = VideoClip(partial(self._make_frame, frames=x, axes=ax, fig=fig, 353 | title=f'Episode {episode}'), duration=x.shape[0]) 354 | self.current_animation.ipython_display(fps=fps, loop=True, autoplay=True) 355 | 356 | def get_memory_samples(self, batch_size=None, key='reward'): 357 | samples = self.learn.model.memory.sample(self.learn.model.batch_size if batch_size is None else batch_size) 358 | if not samples: raise IndexError('Your tree seems empty.') 359 | if batch_size is not None and batch_size > len(self.learn.model.memory): 360 | raise IndexError(f'Your batch size {batch_size} > the tree\'s batch size {len(self.learn.model.memory)}') 361 | if key not in samples[0].obj.keys(): raise ValueError(f'Key {key} not in {samples[0].obj.keys()}') 362 | return [s.obj[key] for s in samples] 363 | 364 | def plot_memory_samples(self, batch_size=None, key='reward', fig_size=(8, 8)): 365 | values_of_interest = self.get_memory_samples(batch_size, key) 366 | fig = plt.figure(figsize=fig_size) 367 | ax = fig.gca() 368 | fig.suptitle(f'{self.learn.model.name} for {self.ds.env.spec._env_name}') 369 | ax.set_ylabel(key) 370 | ax.set_xlabel('Values') 371 | ax.plot(values_of_interest) 372 | plt.show() -------------------------------------------------------------------------------- /fast_rl/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/fast_rl/core/__init__.py -------------------------------------------------------------------------------- /fast_rl/core/agent_core.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from math import ceil 3 | 4 | import gym 5 | from fastai.basic_train import * 6 | from fastai.torch_core import * 7 | 8 | from fast_rl.core.data_structures import SumTree 9 | 10 | 11 | class ExplorationStrategy: 12 | def __init__(self, explore: bool = True): self.explore=explore 13 | def update(self, max_episodes, explore, **kwargs): self.explore=explore 14 | def perturb(self, action, action_space) -> np.ndarray: 15 | """ 16 | Base method just 
returns the action. Subclass, and change to return randomly / augmented actions. 17 | 18 | Should use `do_exploration` field. It is recommended that when you subclass / overload, you allow this field 19 | to completely bypass these actions. 20 | 21 | Args: 22 | action (np.array): Action input as a regular numpy array. 23 | action_space (gym.Space): The original gym space. Should contain information on the action type, and 24 | possible convenience methods for random action selection. 25 | """ 26 | _=action_space 27 | return action 28 | 29 | 30 | class GreedyEpsilon(ExplorationStrategy): 31 | def __init__(self, epsilon_start, epsilon_end, decay, start_episode=0, end_episode=0, **kwargs): 32 | super().__init__(**kwargs) 33 | self.end_episode=end_episode 34 | self.start_episode=start_episode 35 | self.decay=decay 36 | self.e_end=epsilon_end 37 | self.e_start=epsilon_start 38 | self.epsilon=self.e_start 39 | self.steps=0 40 | 41 | def perturb(self, action, action_space: gym.Space): 42 | return action_space.sample() if np.random.random() 0 (uniform) 148 | epsilon (float): Keeps the probabilities of items from being 0 149 | memory_size (int): Max N samples to store 150 | """ 151 | super().__init__(memory_size, **kwargs) 152 | self.batch_size=batch_size 153 | self.alpha=alpha 154 | self.beta=beta 155 | self.b_inc=b_inc 156 | self.p_weights=None # np.zeros(self.batch_size, dtype=float) 157 | self.epsilon=epsilon 158 | self.tree=SumTree(self.max_size) 159 | self.callbacks=[PriorityExperienceReplayCallback] 160 | # When sampled, store the sample indices for refresh. 161 | self._indices=None # np.zeros(self.batch_size, dtype=int) 162 | 163 | @property 164 | def memory(self): 165 | return self.tree.data 166 | 167 | def __len__(self): 168 | return self.tree.n_entries 169 | 170 | def refresh(self, post_optimize, **kwargs): 171 | if post_optimize is not None: 172 | self.tree.update(self._indices.astype(int), np.abs(post_optimize['td_error'])+self.epsilon) 173 | 174 | def sample(self, batch, **kwargs): 175 | self.beta=np.min([1., self.beta+self.b_inc]) 176 | ranges=np.linspace(0, ceil(self.tree.total()/batch), num=batch+1) 177 | uniform_ranges=[np.random.uniform(ranges[i], ranges[i+1]) for i in range(len(ranges)-1)] 178 | try: 179 | self._indices, weights, samples=self.tree.batch_get(uniform_ranges) 180 | except ValueError: 181 | warn('Too few values to unpack. Your batch size is too small, when PER queries tree, all 0 values get' 182 | ' ignored. We will retry until we can return at least one sample.') 183 | samples=self.sample(batch) 184 | return samples 185 | 186 | self.p_weights=self.tree.anneal_weights(weights, self.beta) 187 | return samples 188 | 189 | def update(self, item, **kwargs): 190 | """ 191 | Updates the tree of PER. 192 | 193 | Assigns maximal priority per [1] Alg:1, thus guaranteeing that sample being visited once. 194 | 195 | Args: 196 | item: 197 | 198 | Returns: 199 | 200 | """ 201 | item=deepcopy(item) 202 | super().update(item, **kwargs) 203 | maximal_priority=self.alpha 204 | if self.reduce_ram: item.clean() 205 | self.tree.add(np.abs(maximal_priority)+self.epsilon, item) 206 | 207 | 208 | # class HindsightExperienceReplay(Experience): 209 | # def __init__(self, memory_size): 210 | # """ 211 | # 212 | # References: 213 | # [1] Andrychowicz, Marcin, et al. "Hindsight experience replay." 214 | # Advances in Neural Information Processing Systems. 2017. 
215 | # 216 | # Args: 217 | # memory_size: 218 | # """ 219 | # super().__init__(memory_size) 220 | -------------------------------------------------------------------------------- /fast_rl/core/basic_train.py: -------------------------------------------------------------------------------- 1 | from multiprocessing.pool import Pool 2 | 3 | from fastai.basic_train import Learner, load_callback 4 | from fastai.torch_core import * 5 | 6 | from fast_rl.core.data_block import MDPDataBunch 7 | 8 | 9 | class WrapperLossFunc(object): 10 | def __init__(self, learn): 11 | self.learn = learn 12 | 13 | def __call__(self, *args, **kwargs): 14 | return self.learn.model.loss 15 | 16 | 17 | 18 | def load_learner(path:PathOrStr, file:PathLikeOrBinaryStream='export.pkl', **db_kwargs): 19 | r""" Similar to fastai `load_learner`, handles load_state for data differently. """ 20 | source = Path(path)/file if is_pathlike(file) else file 21 | state = torch.load(source, map_location='cpu') if defaults.device == torch.device('cpu') else torch.load(source) 22 | model = state.pop('model') 23 | data = MDPDataBunch.load_state(path, state.pop('data')) 24 | # if test is not None: src.add_test(test) 25 | # data = src.databunch(**db_kwargs) 26 | cb_state = state.pop('cb_state') 27 | clas_func = state.pop('cls') 28 | res = clas_func(data, model, **state) 29 | res.callback_fns = state['callback_fns'] #to avoid duplicates 30 | res.callbacks = [load_callback(c,s, res) for c,s in cb_state.items()] 31 | return res 32 | 33 | 34 | class AgentLearner(Learner): 35 | 36 | def __init__(self, data, loss_func=None, callback_fns=None, opt=torch.optim.Adam, **kwargs): 37 | super().__init__(data=data, callback_fns=ifnone(callback_fns, []) + data.callback, **kwargs) 38 | self.model.loss_func = ifnone(loss_func, F.mse_loss) 39 | self.model.set_opt(opt) 40 | self.loss_func = None 41 | self.trainers = None 42 | self._loss_func = WrapperLossFunc(self) 43 | 44 | @property 45 | def warming_up(self): 46 | return self.data.bs > len(self.data.x) 47 | 48 | def init_loss_func(self): 49 | r""" 50 | Initializes the loss function wrapper for logging loss. 51 | 52 | Since most RL models have a period of warming up such as filling tree buffers, we cannot log any loss. 53 | By default, the learner will have a `None` loss function, and so the fit function will not try to log that 54 | loss. 55 | """ 56 | self.loss_func = WrapperLossFunc(self) 57 | 58 | def export(self, file:PathLikeOrBinaryStream='export.pkl', destroy=False, pickle_data=False): 59 | "Export the state of the `Learner` in `self.path/file`. `file` can be file-like (file or buffer)" 60 | if rank_distrib(): return # don't save if slave proc 61 | # For now we exclude the 'loss_func' since it is pointing to a model loss. 62 | args = ['opt_func', 'metrics', 'true_wd', 'bn_wd', 'wd', 'train_bn', 'model_dir', 'callback_fns', 'memory', 63 | 'exploration_method', 'trainers'] 64 | state = {a:getattr(self,a) for a in args} 65 | state['cb_state'] = {cb.__class__:cb.get_state() for cb in self.callbacks} 66 | #layer_groups -> need to find a way 67 | #TO SEE: do we save model structure and weights separately? 
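# Move the model to the CPU and detach its optimizer before pickling, so the exported file loads on CPU-only machines and carries no optimizer state.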
68 | with ModelOnCPU(self.model) as m: 69 | m.opt = None 70 | state['model'] = m 71 | xtra = dict(normalize=self.data.norm.keywords) if getattr(self.data, 'norm', False) else {} 72 | state['data'] = self.data.train_ds.get_state(**xtra) if self.data.valid_dl is None else self.data.valid_ds.get_state(**xtra) 73 | state['data']['add_valid'] = not self.data.empty_val 74 | if pickle_data: self.data.to_pickle(self.data.path) 75 | state['cls'] = self.__class__ 76 | try_save(state, self.path, file) 77 | if destroy: self.destroy() 78 | 79 | def interpret_q(self, xi): 80 | raise NotImplementedError 81 | 82 | 83 | class PipeLine(object): 84 | def __init__(self, n_threads, pipe_line_function): 85 | warn(Warning('Currently not super useful. Seems to have issues with running a single env in multiple threads.')) 86 | self.pipe_line_function = pipe_line_function 87 | self.n_threads = n_threads 88 | self.pool = Pool(self.n_threads) 89 | 90 | def start(self, n_runs): 91 | return self.pool.map(self.pipe_line_function, range(n_runs)) -------------------------------------------------------------------------------- /fast_rl/core/data_structures.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | 4 | Notices: 5 | [1] SumTree implementation belongs to: https://github.com/rlcode/per 6 | As of 8/23/2019, no license is provided. Note that this code has been modified. 7 | 8 | 9 | """ 10 | 11 | import numpy as np 12 | 13 | 14 | class SumTree(object): 15 | write = 0 16 | 17 | def __init__(self, capacity): 18 | """ 19 | Used for PER. 20 | 21 | References: 22 | [1] SumTree implementation belongs to: https://github.com/rlcode/per 23 | 24 | Notes: 25 | As of 8/23/2019, no license is provided. Note that this code has been modified.
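A sum tree is a complete binary tree stored in a flat array: the leaves hold the priorities of the stored transitions and every internal node holds the sum of its two children, so the root returned by `total()` is the sum of all priorities. `get(s)` walks from the root toward the leaf whose cumulative-priority interval contains `s`, and `update` propagates any priority change back up to the root, both in O(log capacity) time.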
26 | 27 | 28 | Args: 29 | capacity: 30 | """ 31 | 32 | self.capacity = capacity 33 | self.tree = np.zeros(2 * capacity - 1) 34 | self.data = np.zeros(capacity, dtype=object) 35 | self.n_entries = 0 36 | 37 | def _propagate(self, idx, change): 38 | """ Update to the root node """ 39 | parent = (idx - 1) // 2 40 | 41 | self.tree[parent] += change 42 | 43 | if (np.isscalar(parent) and parent != 0) or (not np.isscalar(parent) and all(parent != 0)): 44 | if not np.isscalar(parent): change[parent == 0] = 0 45 | self._propagate(parent, change) 46 | 47 | def get_left(self, index): 48 | return 2 * index + 1 49 | 50 | def get_right(self, index): 51 | return self.get_left(index) + 1 52 | 53 | def _retrieve(self, idx, s): 54 | """ Finds sample on leaf node """ 55 | left = self.get_left(idx) 56 | right = self.get_right(idx) 57 | 58 | if left >= len(self.tree): 59 | return idx 60 | 61 | if s <= self.tree[left]: 62 | return self._retrieve(left, s) 63 | else: 64 | return self._retrieve(right, s - self.tree[left]) 65 | 66 | def total(self): 67 | return self.tree[0] 68 | 69 | def add(self, p, data): 70 | """ Store priority and sample """ 71 | idx = self.write + self.capacity - 1 72 | 73 | self.data[self.write] = data 74 | self.update(idx, p) 75 | 76 | self.write += 1 77 | if self.write >= self.capacity: 78 | self.write = 0 79 | 80 | if self.n_entries < self.capacity: 81 | self.n_entries += 1 82 | 83 | def update(self, idx, p): 84 | """ Update priority """ 85 | p = p.flatten() if not np.isscalar(p) else p 86 | change = p - self.tree[idx] 87 | 88 | self.tree[idx] = p 89 | self._propagate(idx, change) 90 | 91 | def get(self, s): 92 | """ Get priority and sample """ 93 | idx = self._retrieve(0, s) 94 | data_index = idx - self.capacity + 1 95 | 96 | return idx, self.tree[idx], self.data[data_index] 97 | 98 | def anneal_weights(self, priorities, beta): 99 | sampling_probabilities = priorities / self.total() 100 | is_weight = np.power(self.n_entries * sampling_probabilities, -beta) 101 | is_weight /= is_weight.max() 102 | return is_weight.astype(float) 103 | 104 | def batch_get(self, ss): 105 | return np.array(list(zip(*list([self.get(s) for s in ss if self.get(s)[2] != 0])))) 106 | 107 | 108 | def print_tree(tree: SumTree): 109 | print('\n') 110 | if tree.n_entries == 0: 111 | print('empty') 112 | return 113 | 114 | max_d = int(np.log2(len(tree.tree))) 115 | string_len_max = len(str(tree.tree[-1])) 116 | 117 | tree_strings = [] 118 | display_values = None 119 | display_indexes = None 120 | for layer in range(max_d + 1): 121 | # Get the indexes in the current layer d 122 | if display_indexes is None: 123 | display_indexes = [[0]] 124 | else: 125 | local_list = [] 126 | for i in [_ for _ in display_indexes[-1] if _ < len(tree.tree)]: 127 | if tree.get_left(i) < len(tree.tree): local_list.append(tree.get_left(i)) 128 | if tree.get_right(i) < len(tree.tree): local_list.append(tree.get_right(i)) 129 | display_indexes.append(local_list) 130 | 131 | for layer in display_indexes: 132 | # Get the v contained in current layer d 133 | if display_values is None: 134 | display_values = [[tree.tree[i] for i in layer]] 135 | else: 136 | display_values.append([tree.tree[i] for i in layer]) 137 | 138 | tab_sizes = [] 139 | spacings = [] 140 | for i, layer in enumerate(display_values): 141 | # for now ignore string length 142 | tab_sizes.append(0 if i == 0 else (tab_sizes[-1] + 1) * 2) 143 | spacings.append(3 if i == 0 else (spacings[-1] * 2 + 1)) 144 | 145 | for i, layer in enumerate(display_values): 146 | # 
tree_strings.append('*' * list(reversed(tab_sizes))[i]) 147 | values = ''.join(str(v) + ' ' * (string_len_max * list(reversed(spacings))[i]) for v in layer) 148 | tree_strings.append(' ' * (string_len_max * list(reversed(tab_sizes))[i]) + values) 149 | 150 | for tree_string in tree_strings: 151 | print(tree_string) 152 | -------------------------------------------------------------------------------- /fast_rl/core/layers.py: -------------------------------------------------------------------------------- 1 | r"""`fast_rl.layers` provides essential functions to building and modifying `model` architectures""" 2 | from math import ceil 3 | 4 | from fastai.torch_core import * 5 | from fastai.tabular import TabularModel 6 | 7 | 8 | def init_cnn(mod: Any): 9 | r""" Utility for initializing cnn Modules. """ 10 | if getattr(mod, 'bias', None) is not None: nn.init.constant_(mod.bias, 0) 11 | if isinstance(mod, (nn.Conv2d, nn.Linear)): nn.init.kaiming_normal_(mod.weight) 12 | for sub_mod in mod.children(): init_cnn(sub_mod) 13 | 14 | 15 | def ks_stride(ks, stride, w, h, n_blocks, kern_proportion=.1, stride_proportion=0.3): 16 | r""" Utility for determing the the kernel size and stride. """ 17 | kernels, strides, max_dim = [], [], max((w, h)) 18 | for i in range(len(n_blocks)): 19 | kernels.append(max_dim * kern_proportion) 20 | strides.append(kernels[-1] * stride_proportion) 21 | max_dim = (max_dim - kernels[-1]) / strides[-1] 22 | assert max_dim > 1 23 | 24 | return ifnone(ks, map(ceil, kernels)), ifnone(stride, map(ceil, strides)) 25 | 26 | 27 | class Flatten(nn.Module): 28 | def forward(self, y): return y.view(y.size(0), -1) 29 | 30 | 31 | class FakeBatchNorm(Module): 32 | r""" If we want all the batch norm layers gone, then we will replace the tabular batch norm with this. """ 33 | def forward(self, xi: Tensor, *args): return xi 34 | 35 | 36 | def conv_bn_lrelu(ni: int, nf: int, ks: int = 3, stride: int = 1, pad=True, bn=True) -> nn.Sequential: 37 | r""" Create a sequence Conv2d->BatchNorm2d->LeakyReLu layer. (from darknet.py). Allows excluding BatchNorm2d Layer.""" 38 | return nn.Sequential( 39 | nn.Conv2d(ni, nf, kernel_size=ks, bias=False, stride=stride, padding=(ks // 2) if pad else 0), 40 | nn.BatchNorm2d(nf) if bn else FakeBatchNorm(), 41 | nn.LeakyReLU(negative_slope=0.1, inplace=True)) 42 | 43 | 44 | class ChannelTranspose(Module): 45 | r""" Runtime image input channel changing. Useful for handling different image channel outputs from different envs. """ 46 | def forward(self, xi: Tensor): 47 | return xi.transpose(3, 1).transpose(3, 2) 48 | 49 | 50 | class StateActionSplitter(Module): 51 | r""" `Actor / Critic` models require breaking the state and action into 2 streams. """ 52 | 53 | def forward(self, s_a_tuple: Tuple[Tensor]): 54 | r""" Returns tensors as -> (State Tensor, Action Tensor) """ 55 | return s_a_tuple[0], s_a_tuple[1] 56 | 57 | 58 | class StateActionPassThrough(nn.Module): 59 | r""" Passes action input untouched, but runs the state tensors through a sub module. """ 60 | def __init__(self, layers): 61 | super().__init__() 62 | self.layers = layers 63 | 64 | def forward(self, state_action): 65 | return self.layers(state_action[0]), state_action[1] 66 | 67 | 68 | class TabularEmbedWrapper(Module): 69 | r""" Basic `TabularModel` compatibility wrapper. Typically, state inputs will be either categorical or continuous. 
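The wrapper simply passes the same state tensor to the wrapped `TabularModel` as both its categorical and continuous inputs (see `forward`).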
""" 70 | def __init__(self, tabular_model: TabularModel): 71 | super().__init__() 72 | self.tabular_model = tabular_model 73 | 74 | def forward(self, xi: Tensor, *args): 75 | return self.tabular_model(xi, xi) 76 | 77 | 78 | class CriticTabularEmbedWrapper(Module): 79 | r""" Similar to `TabularEmbedWrapper` but assumes input is state / action and requires concatenation. """ 80 | def __init__(self, tabular_model: TabularModel, exclude_cat): 81 | super().__init__() 82 | self.tabular_model = tabular_model 83 | self.exclude_cat = exclude_cat 84 | 85 | def forward(self, args): 86 | if not self.exclude_cat: 87 | return self.tabular_model(*args) 88 | else: 89 | return self.tabular_model(0, torch.cat(args, 1)) 90 | -------------------------------------------------------------------------------- /fast_rl/core/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fastai.basic_train import LearnerCallback, Any 3 | from fastai.callback import Callback, is_listy, add_metrics 4 | 5 | 6 | class EpsilonMetric(LearnerCallback): 7 | _order = -20 # Needs to run before the recorder 8 | 9 | def __init__(self, learn): 10 | super().__init__(learn) 11 | self.epsilon = 0 12 | if not hasattr(self.learn, 'exploration_method'): 13 | raise ValueError('Your model is not using an exploration strategy! Please use epsilon based exploration') 14 | if not hasattr(self.learn.exploration_method, 'epsilon'): 15 | raise ValueError('Please use epsilon based exploration (should have an epsilon field)') 16 | 17 | # noinspection PyUnresolvedReferences 18 | def on_train_begin(self, **kwargs): 19 | self.learn.recorder.add_metric_names(['epsilon']) 20 | 21 | def on_epoch_end(self, last_metrics, **kwargs): 22 | self.epsilon = self.learn.exploration_method.epsilon 23 | if last_metrics and last_metrics[-1] is None: del last_metrics[-1] 24 | return add_metrics(last_metrics, [float(self.epsilon)]) 25 | 26 | class RewardMetric(LearnerCallback): 27 | _order = -20 28 | 29 | def __init__(self, learn): 30 | super().__init__(learn) 31 | self.train_reward, self.valid_reward = [], [] 32 | 33 | def on_epoch_begin(self, **kwargs:Any): 34 | self.train_reward, self.valid_reward = [], [] 35 | 36 | def on_batch_end(self, **kwargs: Any): 37 | if self.learn.model.training: self.train_reward.append(self.learn.data.train_ds.item.reward.cpu().numpy()[0][0]) 38 | elif not self.learn.recorder.no_val: self.valid_reward.append(self.learn.data.valid_ds.item.reward.cpu().numpy()[0][0]) 39 | 40 | def on_train_begin(self, **kwargs): 41 | metric_names = ['train_reward'] if self.learn.recorder.no_val else ['train_reward', 'valid_reward'] 42 | self.learn.recorder.add_metric_names(metric_names) 43 | 44 | def on_epoch_end(self, last_metrics, **kwargs: Any): 45 | return add_metrics(last_metrics, [sum(self.train_reward), sum(self.valid_reward)]) 46 | -------------------------------------------------------------------------------- /fast_rl/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/fast_rl/util/__init__.py -------------------------------------------------------------------------------- /fast_rl/util/exceptions.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class MaxEpisodeStepsMissingError(Exception): 4 | pass -------------------------------------------------------------------------------- 
/fast_rl/util/misc.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import error 3 | 4 | 5 | class b_colors: 6 | HEADER = '\033[95m' 7 | OKBLUE = '\033[94m' 8 | OKGREEN = '\033[92m' 9 | WARNING = '\033[93m' 10 | FAIL = '\033[91m' 11 | ENDC = '\033[0m' 12 | BOLD = '\033[1m' 13 | UNDERLINE = '\033[4m' 14 | 15 | 16 | def list_in_str(s: str, str_list: list, make_lower=True): 17 | if make_lower: s = s.lower() 18 | return any([s.__contains__(el) for el in str_list]) 19 | 20 | 21 | def is_goal_env(env, suppress_errors=True): 22 | msg = 'GoalEnv requires the "{}" key to be part of the observation d.' 23 | # Enforce that each GoalEnv uses a Goal-compatible observation space. 24 | if not isinstance(env.observation_space, gym.spaces.Dict): 25 | if not suppress_errors: 26 | raise error.Error('GoalEnv requires an observation space of type gym.spaces.Dict') 27 | else: 28 | return False 29 | for key in ['observation', 'achieved_goal', 'desired_goal']: 30 | if key not in env.observation_space.spaces: 31 | if not suppress_errors: 32 | raise error.Error(msg.format(key)) 33 | else: 34 | return False 35 | return True 36 | -------------------------------------------------------------------------------- /res/RELEASE_BLOG.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/RELEASE_BLOG.md -------------------------------------------------------------------------------- /res/ddpg_balancing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/ddpg_balancing.gif -------------------------------------------------------------------------------- /res/dqn_q_estimate_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/dqn_q_estimate_1.jpg -------------------------------------------------------------------------------- /res/dqn_q_estimate_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/dqn_q_estimate_2.jpg -------------------------------------------------------------------------------- /res/dqn_q_estimate_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/dqn_q_estimate_3.jpg -------------------------------------------------------------------------------- /res/fit_func_out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/fit_func_out.jpg -------------------------------------------------------------------------------- /res/heat_map_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/heat_map_1.png -------------------------------------------------------------------------------- /res/heat_map_2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/heat_map_2.png -------------------------------------------------------------------------------- /res/heat_map_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/heat_map_3.png -------------------------------------------------------------------------------- /res/heat_map_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/heat_map_4.png -------------------------------------------------------------------------------- /res/pre_interpretation_maze_dqn.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/pre_interpretation_maze_dqn.gif -------------------------------------------------------------------------------- /res/reward_plot_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plot_1.png -------------------------------------------------------------------------------- /res/reward_plot_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plot_2.png -------------------------------------------------------------------------------- /res/reward_plots/ant_ddpg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/ant_ddpg.png -------------------------------------------------------------------------------- /res/reward_plots/cartpole_dddqn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/cartpole_dddqn.png -------------------------------------------------------------------------------- /res/reward_plots/cartpole_double.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/cartpole_double.png -------------------------------------------------------------------------------- /res/reward_plots/cartpole_dqn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/cartpole_dqn.png -------------------------------------------------------------------------------- /res/reward_plots/cartpole_dueling.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/cartpole_dueling.png -------------------------------------------------------------------------------- /res/reward_plots/cartpole_fixedtarget.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/cartpole_fixedtarget.png -------------------------------------------------------------------------------- /res/reward_plots/halfcheetah_ddpg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/halfcheetah_ddpg.png -------------------------------------------------------------------------------- /res/reward_plots/lunarlander_all_targetbased.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/lunarlander_all_targetbased.png -------------------------------------------------------------------------------- /res/reward_plots/lunarlander_dddqn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/lunarlander_dddqn.png -------------------------------------------------------------------------------- /res/reward_plots/lunarlander_double.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/lunarlander_double.png -------------------------------------------------------------------------------- /res/reward_plots/lunarlander_dqn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/lunarlander_dqn.png -------------------------------------------------------------------------------- /res/reward_plots/lunarlander_dueling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/lunarlander_dueling.png -------------------------------------------------------------------------------- /res/reward_plots/lunarlander_fixedtarget.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/lunarlander_fixedtarget.png -------------------------------------------------------------------------------- /res/reward_plots/pendulum_ddpg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/pendulum_ddpg.png -------------------------------------------------------------------------------- /res/run_gifs/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/__init__.py -------------------------------------------------------------------------------- /res/run_gifs/acrobot_ExperienceReplay_DDPGModule_1_episode_197.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/acrobot_ExperienceReplay_DDPGModule_1_episode_197.gif -------------------------------------------------------------------------------- /res/run_gifs/acrobot_ExperienceReplay_DDPGModule_1_episode_438.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/acrobot_ExperienceReplay_DDPGModule_1_episode_438.gif -------------------------------------------------------------------------------- /res/run_gifs/acrobot_ExperienceReplay_DDPGModule_1_episode_69.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/acrobot_ExperienceReplay_DDPGModule_1_episode_69.gif -------------------------------------------------------------------------------- /res/run_gifs/acrobot_PriorityExperienceReplay_DDPGModule_1_episode_267.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/acrobot_PriorityExperienceReplay_DDPGModule_1_episode_267.gif -------------------------------------------------------------------------------- /res/run_gifs/acrobot_PriorityExperienceReplay_DDPGModule_1_episode_422.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/acrobot_PriorityExperienceReplay_DDPGModule_1_episode_422.gif -------------------------------------------------------------------------------- /res/run_gifs/acrobot_PriorityExperienceReplay_DDPGModule_1_episode_55.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/acrobot_PriorityExperienceReplay_DDPGModule_1_episode_55.gif -------------------------------------------------------------------------------- /res/run_gifs/ant_ExperienceReplay_DDPGModule_1_episode_54.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/ant_ExperienceReplay_DDPGModule_1_episode_54.gif -------------------------------------------------------------------------------- /res/run_gifs/ant_ExperienceReplay_DDPGModule_1_episode_614.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/ant_ExperienceReplay_DDPGModule_1_episode_614.gif -------------------------------------------------------------------------------- 
/res/run_gifs/ant_ExperienceReplay_DDPGModule_1_episode_999.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/ant_ExperienceReplay_DDPGModule_1_episode_999.gif -------------------------------------------------------------------------------- /res/run_gifs/ant_PriorityExperienceReplay_DDPGModule_1_episode_52.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/ant_PriorityExperienceReplay_DDPGModule_1_episode_52.gif -------------------------------------------------------------------------------- /res/run_gifs/ant_PriorityExperienceReplay_DDPGModule_1_episode_596.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/ant_PriorityExperienceReplay_DDPGModule_1_episode_596.gif -------------------------------------------------------------------------------- /res/run_gifs/ant_PriorityExperienceReplay_DDPGModule_1_episode_984.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/ant_PriorityExperienceReplay_DDPGModule_1_episode_984.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_ExperienceReplay_DQNModule_1_episode_207.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DQNModule_1_episode_207.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_ExperienceReplay_DQNModule_1_episode_31.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DQNModule_1_episode_31.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_ExperienceReplay_DQNModule_1_episode_447.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DQNModule_1_episode_447.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_ExperienceReplay_DoubleDQNModule_1_episode_268.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DoubleDQNModule_1_episode_268.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_ExperienceReplay_DoubleDQNModule_1_episode_438.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DoubleDQNModule_1_episode_438.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_ExperienceReplay_DoubleDQNModule_1_episode_60.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DoubleDQNModule_1_episode_60.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_287.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_287.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_43.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_43.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_447.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_447.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_ExperienceReplay_DuelingDQNModule_1_episode_209.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DuelingDQNModule_1_episode_209.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_ExperienceReplay_DuelingDQNModule_1_episode_432.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DuelingDQNModule_1_episode_432.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_ExperienceReplay_DuelingDQNModule_1_episode_62.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DuelingDQNModule_1_episode_62.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_309.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_309.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_438.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_438.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_57.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_57.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_PriorityExperienceReplay_DQNModule_1_episode_216.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DQNModule_1_episode_216.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_PriorityExperienceReplay_DQNModule_1_episode_413.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DQNModule_1_episode_413.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_PriorityExperienceReplay_DQNModule_1_episode_44.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DQNModule_1_episode_44.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_269.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_269.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_35.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_35.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_444.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_444.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_2.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_260.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_260.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_438.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_438.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_272.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_272.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_438.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_438.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_69.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_69.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_13.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_13.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_265.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_265.gif -------------------------------------------------------------------------------- /res/run_gifs/cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_449.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_449.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_ExperienceReplay_DQNModule_1_episode_541.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DQNModule_1_episode_541.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_ExperienceReplay_DQNModule_1_episode_93.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DQNModule_1_episode_93.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_ExperienceReplay_DQNModule_1_episode_999.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DQNModule_1_episode_999.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_613.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_613.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_88.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_88.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_999.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_999.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_114.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_114.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_346.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_346.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_925.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_925.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_112.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_112.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_431.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_431.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_980.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_980.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_PriorityExperienceReplay_DQNModule_1_episode_382.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DQNModule_1_episode_382.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_PriorityExperienceReplay_DQNModule_1_episode_949.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DQNModule_1_episode_949.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_PriorityExperienceReplay_DQNModule_1_episode_99.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DQNModule_1_episode_99.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_514.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_514.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_7.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_7.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_999.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_999.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_151.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_151.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_341.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_341.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_999.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_999.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_21.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_21.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_442.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_442.gif -------------------------------------------------------------------------------- /res/run_gifs/lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_998.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_998.gif -------------------------------------------------------------------------------- /res/run_gifs/pendulum_ExperienceReplay_DDPGModule_1_episode_238.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/pendulum_ExperienceReplay_DDPGModule_1_episode_238.gif -------------------------------------------------------------------------------- /res/run_gifs/pendulum_ExperienceReplay_DDPGModule_1_episode_447.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/pendulum_ExperienceReplay_DDPGModule_1_episode_447.gif -------------------------------------------------------------------------------- /res/run_gifs/pendulum_ExperienceReplay_DDPGModule_1_episode_9.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/pendulum_ExperienceReplay_DDPGModule_1_episode_9.gif -------------------------------------------------------------------------------- /res/run_gifs/pendulum_PriorityExperienceReplay_DDPGModule_1_episode_222.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/pendulum_PriorityExperienceReplay_DDPGModule_1_episode_222.gif -------------------------------------------------------------------------------- /res/run_gifs/pendulum_PriorityExperienceReplay_DDPGModule_1_episode_35.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/pendulum_PriorityExperienceReplay_DDPGModule_1_episode_35.gif -------------------------------------------------------------------------------- /res/run_gifs/pendulum_PriorityExperienceReplay_DDPGModule_1_episode_431.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/pendulum_PriorityExperienceReplay_DDPGModule_1_episode_431.gif -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | VERSION = 
"1.0.1" 7 | 8 | 9 | setup(name='fast_rl', 10 | version=VERSION, 11 | description='Fastai for computer vision and tabular learning has been amazing. One would wish that this would ' 12 | 'be the same for RL. The purpose of this repo is to have a framework that is as easy as possible to ' 13 | 'start, but also designed for testing new agents. ', 14 | url='https://github.com/josiahls/fast-reinforcement-learning', 15 | author='Josiah Laivins', 16 | author_email='jlaivins@uncc.edu', 17 | python_requires='>=3.6', 18 | long_description=long_description, 19 | long_description_content_type="text/markdown", 20 | license='', 21 | packages=find_packages(), 22 | zip_safe=False, 23 | install_requires=['fastai>=1.0.59', 'gym[box2d, atari]', 'jupyter'], 24 | extras_require={'all': [ 25 | 'gym-minigrid', 26 | 'moviepy' 27 | # 'gym_maze @ git+https://github.com/MattChanTK/gym-maze.git', 28 | # 'pybullet-gym @ git+https://github.com/benelot/pybullet-gym.git' 29 | ]}, 30 | classifiers=[ 31 | "Development Status :: 3 - Alpha", 32 | "Programming Language :: Python :: 3", 33 | "License :: OSI Approved :: Apache Software License", 34 | "Operating System :: OS Independent", 35 | ], 36 | ) 37 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def pytest_addoption(parser): 5 | parser.addoption("--include_performance_tests", action="store_true", 6 | help="Will run the performance tests which do full model testing. This could take a few" 7 | "days to fully accomplish.") 8 | 9 | @pytest.fixture() 10 | def include_performance_tests(pytestconfig): 11 | return pytestconfig.getoption("include_performance_tests") 12 | 13 | 14 | @pytest.fixture() 15 | def skip_performance_check(include_performance_tests): 16 | if not include_performance_tests: 17 | pytest.skip('Skipping due to performance argument not specified. 
Add --include_performance_tests to not skip') 18 | -------------------------------------------------------------------------------- /tests/data/cartpole_dqn/dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/tests/data/cartpole_dqn/dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle -------------------------------------------------------------------------------- /tests/data/cat/cat1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/tests/data/cat/cat1.jpeg -------------------------------------------------------------------------------- /tests/data/cat/cat2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/tests/data/cat/cat2.jpeg -------------------------------------------------------------------------------- /tests/data/dog/dog1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/tests/data/dog/dog1.jpeg -------------------------------------------------------------------------------- /tests/data/dog/dog2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/tests/data/dog/dog2.jpeg -------------------------------------------------------------------------------- /tests/test_agent_core.py: -------------------------------------------------------------------------------- 1 | 2 | # import pytest 3 | # 4 | # from fast_rl.core.data_block import MDPDataBunch 5 | # from fast_rl.core.agent_core import PriorityExperienceReplay 6 | # from fast_rl.core.basic_train import AgentLearner 7 | # 8 | 9 | -------------------------------------------------------------------------------- /tests/test_basic_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | 5 | from fast_rl.agents.dqn import create_dqn_model, FixedTargetDQNModule, dqn_learner 6 | from fast_rl.core.agent_core import ExperienceReplay, torch, GreedyEpsilon 7 | from fast_rl.core.basic_train import load_learner 8 | from fast_rl.core.data_block import MDPDataBunch 9 | 10 | 11 | def test_fit(): 12 | data=MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=5, max_steps=20, add_valid=False) 13 | model=create_dqn_model(data, FixedTargetDQNModule, opt=torch.optim.RMSprop) 14 | memory=ExperienceReplay(memory_size=1000, reduce_ram=True) 15 | exploration_method=GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001) 16 | learner=dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method) 17 | learner.fit(2) 18 | learner.fit(2) 19 | learner.fit(2) 20 | 21 | assert len(data.x.info)==6 22 | assert 0 in data.x.info 23 | assert 5 in data.x.info 24 | 25 | 26 | def test_to_pickle(): 27 | data=MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=5, max_steps=20, add_valid=False) 28 | model=create_dqn_model(data, FixedTargetDQNModule, opt=torch.optim.RMSprop) 29 | 
memory=ExperienceReplay(memory_size=1000, reduce_ram=True) 30 | exploration_method=GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001) 31 | learner=dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method) 32 | learner.fit(2) 33 | 34 | assert len(data.x.info)==2 35 | assert 0 in data.x.info 36 | assert 1 in data.x.info 37 | 38 | data.to_pickle('./data/test_to_pickle') 39 | assert os.path.exists('./data/test_to_pickle_CartPole-v0') 40 | 41 | 42 | def test_from_pickle(): 43 | data=MDPDataBunch.from_pickle('./data/test_to_pickle_CartPole-v0') 44 | model=create_dqn_model(data, FixedTargetDQNModule, opt=torch.optim.RMSprop) 45 | memory=ExperienceReplay(memory_size=1000, reduce_ram=True) 46 | exploration_method=GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001) 47 | learner=dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method) 48 | learner.fit(2) 49 | 50 | assert len(data.x.info)==4 51 | assert 0 in data.x.info 52 | assert 3 in data.x.info 53 | 54 | 55 | def test_export_learner(): 56 | data=MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=5, max_steps=20, add_valid=False) 57 | model=create_dqn_model(data, FixedTargetDQNModule, opt=torch.optim.RMSprop) 58 | memory=ExperienceReplay(memory_size=1000, reduce_ram=True) 59 | exploration_method=GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001) 60 | learner=dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method) 61 | learner.fit(2) 62 | 63 | learner.export('test_export.pkl')#, pickle_data=True) 64 | learner = load_learner(learner.path, 'test_export.pkl') 65 | learner.fit(2) 66 | -------------------------------------------------------------------------------- /tests/test_data_block.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | from itertools import product 4 | 5 | import gym 6 | import pytest 7 | import numpy as np 8 | import torch 9 | from fastai.basic_train import ItemLists 10 | 11 | from fast_rl.agents.dqn import create_dqn_model, dqn_learner 12 | from fast_rl.agents.dqn_models import DQNModule 13 | from fast_rl.core.agent_core import GreedyEpsilon, ExperienceReplay 14 | from fast_rl.core.data_block import MDPDataBunch, ResolutionWrapper, FEED_TYPE_IMAGE 15 | from fast_rl.core.metrics import RewardMetric, EpsilonMetric 16 | 17 | 18 | def validate_item_list(item_list: ItemLists): 19 | # Check items 20 | for i, item in enumerate(item_list.items): 21 | if item.done: assert not item_list.items[ 22 | i - 1].done, f'The dataset has duplicate "done\'s" that are consecutive.' 
23 | assert item.state.s is not None, f'The item: {item}\'s state is None' 24 | assert item.state.s_prime is not None, f'The item: {item}\'s state prime is None' 25 | 26 | 27 | @pytest.mark.parametrize(["memory_strategy", "k"], list(product(['k_top', 'k_partitions_top'], [1, 3, 5]))) 28 | def test_dataset_memory_manager(memory_strategy, k): 29 | data = MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=5, max_steps=20, add_valid=False, 30 | memory_management_strategy=memory_strategy, k=k) 31 | model = create_dqn_model(data, DQNModule, opt=torch.optim.RMSprop, lr=0.1) 32 | memory = ExperienceReplay(memory_size=1000, reduce_ram=True) 33 | exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001) 34 | learner = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method, 35 | callback_fns=[RewardMetric, EpsilonMetric]) 36 | learner.fit(10) 37 | 38 | data_info = {episode: data.train_ds.x.info[episode] for episode in data.train_ds.x.info if episode != -1} 39 | full_episodes = [episode for episode in data_info if not data_info[episode][1]] 40 | 41 | assert sum([not _[1] for _ in data_info.values()]) == k, 'There should be k episodes but there is not.' 42 | if memory_strategy.__contains__('top') and not memory_strategy.__contains__('both'): 43 | assert (np.argmax([_[0] for _ in data_info.values()])) in full_episodes 44 | 45 | 46 | def test_databunch_to_pickle(): 47 | data = MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=5, max_steps=20, add_valid=False, 48 | memory_management_strategy='k_partitions_top', k=3) 49 | model = create_dqn_model(data, DQNModule, opt=torch.optim.RMSprop, lr=0.1) 50 | memory = ExperienceReplay(memory_size=1000, reduce_ram=True) 51 | exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001) 52 | learner = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method, 53 | callback_fns=[RewardMetric, EpsilonMetric]) 54 | learner.fit(10) 55 | data.to_pickle('./data/cartpole_10_epoch') 56 | MDPDataBunch.from_pickle(env_name='CartPole-v0', path='./data/cartpole_10_epoch') 57 | 58 | 59 | def test_resolution_wrapper(): 60 | data = MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=5, max_steps=10, add_valid=False, 61 | memory_management_strategy='k_top', k=1, feed_type=FEED_TYPE_IMAGE, 62 | res_wrap=partial(ResolutionWrapper, w_step=2, h_step=2)) 63 | model = create_dqn_model(data, DQNModule, opt=torch.optim.RMSprop, lr=0.1,channels=[32,32,32],ks=[5,5,5],stride=[2,2,2]) 64 | memory = ExperienceReplay(memory_size=1000, reduce_ram=True) 65 | exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001) 66 | learner = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method, 67 | callback_fns=[RewardMetric, EpsilonMetric]) 68 | learner.fit(2) 69 | temp = gym.make('CartPole-v0') 70 | temp.reset() 71 | original_shape = temp.render(mode='rgb_array').shape 72 | assert data.env.render(mode='rgb_array').shape == (original_shape[0] // 2, original_shape[1] // 2, 3) 73 | -------------------------------------------------------------------------------- /tests/test_data_structures.py: -------------------------------------------------------------------------------- 1 | from fast_rl.core.data_structures import print_tree, SumTree 2 | 3 | 4 | def test_sum_tree_with_max_size(): 5 | memory = SumTree(10) 6 | 7 | values = [1, 1, 1, 1, 1, 1] 8 | data = [f'data with priority: {i}' for i in values] 9 | 10 | for element, 
value in zip(data, values): 11 | memory.add(value, element) 12 | 13 | print_tree(memory) 14 | 15 | 16 | -------------------------------------------------------------------------------- /tests/test_ddpg.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from itertools import product 3 | 4 | import pytest 5 | from fastai.basic_train import torch, DatasetType 6 | from fastai.core import ifnone 7 | 8 | from fast_rl.agents.ddpg import create_ddpg_model, ddpg_learner, DDPGLearner 9 | from fast_rl.agents.ddpg_models import DDPGModule 10 | from fast_rl.core.agent_core import ExperienceReplay, PriorityExperienceReplay, OrnsteinUhlenbeck 11 | from fast_rl.core.data_block import FEED_TYPE_STATE, MDPDataBunch, ResolutionWrapper 12 | from fast_rl.core.metrics import RewardMetric, EpsilonMetric 13 | from fast_rl.core.train import GroupAgentInterpretation, AgentInterpretation 14 | 15 | p_model=[DDPGModule] 16 | p_exp=[ExperienceReplay, PriorityExperienceReplay] 17 | p_format=[FEED_TYPE_STATE] # , FEED_TYPE_IMAGE] 18 | p_full_format=[FEED_TYPE_STATE] 19 | p_envs=['Pendulum-v0'] 20 | 21 | config_env_expectations={ 22 | 'Pendulum-v0': {'action_shape': (1, 1), 'state_shape': (1, 3)} 23 | } 24 | 25 | 26 | def trained_learner(model_cls, env, s_format, experience, bs=64,layers=None,render='rgb_array', memory_size=1000000, 27 | decay=0.0001,lr=None,actor_lr=None,epochs=450,opt=torch.optim.RMSprop, **kwargs): 28 | lr,actor_lr=ifnone(lr,1e-3),ifnone(actor_lr,1e-4) 29 | data=MDPDataBunch.from_env(env,render=render,bs=bs,add_valid=False,keep_env_open=False,feed_type=s_format, 30 | memory_management_strategy='k_partitions_top',k=3,**kwargs) 31 | exploration_method=OrnsteinUhlenbeck(size=data.action.taken_action.shape,epsilon_start=1,epsilon_end=0.1, 32 | decay=decay) 33 | memory=experience(memory_size=memory_size, reduce_ram=True) 34 | model=create_ddpg_model(data=data,base_arch=model_cls,lr=lr,actor_lr=actor_lr,layers=layers,opt=opt) 35 | learner=ddpg_learner(data=data, model=model, memory=memory, exploration_method=exploration_method, 36 | callback_fns=[RewardMetric, EpsilonMetric]) 37 | learner.fit(epochs) 38 | return learner 39 | 40 | def check_shape(env,data,s_format): 41 | assert config_env_expectations[env]['action_shape']==(1, data.action.taken_action.shape[1]) 42 | if s_format==FEED_TYPE_STATE: 43 | assert config_env_expectations[env]['state_shape']==data.state.s.shape 44 | 45 | def learner2gif(lnr:DDPGLearner,s_format,group_interp:GroupAgentInterpretation,name:str,extra:str): 46 | meta=f'{lnr.memory.__class__.__name__}_{"FEED_TYPE_STATE" if s_format==FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}' 47 | interp=AgentInterpretation(lnr, ds_type=DatasetType.Train) 48 | interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta) 49 | group_interp.add_interpretation(interp) 50 | group_interp.to_pickle(f'../docs_src/data/{name}_{lnr.model.name.lower()}/', f'{lnr.model.name.lower()}_{meta}') 51 | [g.write(f'../res/run_gifs/{name}_{extra}') for g in interp.generate_gif()] 52 | 53 | 54 | @pytest.mark.parametrize(["model_cls", "s_format", "env"], list(product(p_model, p_format, p_envs))) 55 | def test_ddpg_create_ddpg_model(model_cls, s_format, env): 56 | data=MDPDataBunch.from_env(env, render='rgb_array', bs=32, add_valid=False, feed_type=s_format) 57 | model=create_ddpg_model(data, model_cls) 58 | model.eval() 59 | model(data.state.s.float()) 60 | check_shape(env,data,s_format) 61 | data.close() 62 | 63 | 64 | 
@pytest.mark.parametrize(["model_cls", "s_format", "mem", "env"], list(product(p_model, p_format, p_exp, p_envs))) 65 | def test_ddpg_ddpglearner(model_cls, s_format, mem, env): 66 | data=MDPDataBunch.from_env(env, render='rgb_array', bs=32, add_valid=False, feed_type=s_format) 67 | model=create_ddpg_model(data, model_cls) 68 | memory=mem(memory_size=1000, reduce_ram=True) 69 | exploration_method=OrnsteinUhlenbeck(size=data.action.taken_action.shape, epsilon_start=1, epsilon_end=0.1, 70 | decay=0.001) 71 | ddpg_learner(data=data, model=model, memory=memory, exploration_method=exploration_method) 72 | check_shape(env,data,s_format) 73 | data.close() 74 | 75 | 76 | @pytest.mark.parametrize(["model_cls", "s_format", "mem", "env"], list(product(p_model, p_format, p_exp, p_envs))) 77 | def test_ddpg_fit(model_cls, s_format, mem, env): 78 | learner=trained_learner(env=env, bs=10,opt=torch.optim.RMSprop,model_cls=model_cls,layers=[20, 20],memory_size=100, 79 | max_steps=20,render='rgb_array',decay=0.001,s_format=s_format,experience=mem,epochs=2) 80 | 81 | check_shape(env,learner.data,s_format) 82 | del learner 83 | 84 | 85 | @pytest.mark.usefixtures('skip_performance_check') 86 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'], 87 | list(product(p_model, p_format, p_exp))) 88 | def test_ddpg_models_pendulum(model_cls, s_format, experience): 89 | group_interp=GroupAgentInterpretation() 90 | extra_s=f'{experience.__name__}_{model_cls.__name__}_{s_format}' 91 | for i in range(5): 92 | print('\n') 93 | learner=trained_learner(model_cls,'Pendulum-v0',s_format,experience,decay=0.0001,render='rgb_array') 94 | learner2gif(learner,s_format,group_interp,'pendulum',extra_s) 95 | del learner 96 | 97 | 98 | @pytest.mark.usefixtures('skip_performance_check') 99 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'], 100 | list(product(p_model, p_format, p_exp))) 101 | def test_ddpg_models_acrobot(model_cls, s_format, experience): 102 | group_interp=GroupAgentInterpretation() 103 | extra_s=f'{experience.__name__}_{model_cls.__name__}_{s_format}' 104 | for i in range(5): 105 | print('\n') 106 | learner=trained_learner(model_cls,'Acrobot-v1',s_format,experience,decay=0.0001,render='rgb_array') 107 | learner2gif(learner,s_format,group_interp,'acrobot',extra_s) 108 | del learner 109 | 110 | 111 | @pytest.mark.usefixtures('skip_performance_check') 112 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'], 113 | list(product(p_model, p_full_format, p_exp))) 114 | def test_ddpg_models_mountain_car_continuous(model_cls, s_format, experience): 115 | group_interp=GroupAgentInterpretation() 116 | for i in range(5): 117 | print('\n') 118 | data=MDPDataBunch.from_env('MountainCarContinuous-v0', render='rgb_array', bs=40, add_valid=False, keep_env_open=False, 119 | feed_type=s_format, memory_management_strategy='k_partitions_top', k=3, res_wrap=partial(ResolutionWrapper, w_step=2, h_step=2)) 120 | exploration_method=OrnsteinUhlenbeck(size=data.action.taken_action.shape, epsilon_start=1, epsilon_end=0.1, 121 | decay=0.0001) 122 | memory=experience(memory_size=1000000, reduce_ram=True) 123 | model=create_ddpg_model(data=data, base_arch=model_cls) 124 | learner=ddpg_learner(data=data, model=model, memory=memory, exploration_method=exploration_method, 125 | callback_fns=[RewardMetric, EpsilonMetric]) 126 | learner.fit(450) 127 | 128 | meta=f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format==FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}' 129 | interp=AgentInterpretation(learner, 
ds_type=DatasetType.Train) 130 | interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta) 131 | group_interp.add_interpretation(interp) 132 | group_interp.to_pickle(f'../docs_src/data/mountaincarcontinuous_{model.name.lower()}/', 133 | f'{model.name.lower()}_{meta}') 134 | [g.write('../res/run_gifs/mountaincarcontinuous') for g in interp.generate_gif()] 135 | data.close() 136 | del learner 137 | del model 138 | del data 139 | 140 | 141 | @pytest.mark.usefixtures('skip_performance_check') 142 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'], 143 | list(product(p_model, p_full_format, p_exp))) 144 | def test_ddpg_models_reach(model_cls, s_format, experience): 145 | group_interp=GroupAgentInterpretation() 146 | for i in range(5): 147 | print('\n') 148 | data=MDPDataBunch.from_env('ReacherPyBulletEnv-v0', render='rgb_array', bs=40, add_valid=False, keep_env_open=False, feed_type=s_format, 149 | memory_management_strategy='k_partitions_top', k=3, res_wrap=partial(ResolutionWrapper, w_step=2, h_step=2)) 150 | exploration_method=OrnsteinUhlenbeck(size=data.action.taken_action.shape, epsilon_start=1, epsilon_end=0.1, 151 | decay=0.00001) 152 | memory=experience(memory_size=1000000, reduce_ram=True) 153 | model=create_ddpg_model(data=data, base_arch=model_cls) 154 | learner=ddpg_learner(data=data, model=model, memory=memory, exploration_method=exploration_method, 155 | callback_fns=[RewardMetric, EpsilonMetric]) 156 | learner.fit(450) 157 | 158 | meta=f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format==FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}' 159 | interp=AgentInterpretation(learner, ds_type=DatasetType.Train) 160 | interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta) 161 | group_interp.add_interpretation(interp) 162 | group_interp.to_pickle(f'../docs_src/data/reacher_{model.name.lower()}/', 163 | f'{model.name.lower()}_{meta}') 164 | [g.write('../res/run_gifs/reacher') for g in interp.generate_gif()] 165 | data.close() 166 | del learner 167 | del model 168 | del data 169 | 170 | 171 | @pytest.mark.usefixtures('skip_performance_check') 172 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'], 173 | list(product(p_model, p_full_format, p_exp))) 174 | def test_ddpg_models_walker(model_cls, s_format, experience): 175 | group_interp=GroupAgentInterpretation() 176 | extra_s=f'{experience.__name__}_{model_cls.__name__}_{s_format}' 177 | for i in range(5): 178 | print('\n') 179 | # data=MDPDataBunch.from_env('Walker2DPyBulletEnv-v0', render='human', bs=64, add_valid=False, keep_env_open=False, 180 | # feed_type=s_format, memory_management_strategy='k_partitions_top', k=3) 181 | # exploration_method=OrnsteinUhlenbeck(size=data.action.taken_action.shape, epsilon_start=1, epsilon_end=0.1, 182 | # decay=0.0001) 183 | # memory=experience(memory_size=1000000, reduce_ram=True) 184 | # model=create_ddpg_model(data=data, base_arch=model_cls) 185 | # learner=ddpg_learner(data=data, model=model, memory=memory, exploration_method=exploration_method, 186 | # callback_fns=[RewardMetric, EpsilonMetric]) 187 | # learner.fit(1500) 188 | learner=trained_learner(model_cls,'Walker2DPyBulletEnv-v0',s_format,experience,decay=0.0001,render='rgb_array') 189 | learner2gif(learner,s_format,group_interp,'walker2d',extra_s) 190 | 191 | # meta=f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format==FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}' 192 | # interp=AgentInterpretation(learner, ds_type=DatasetType.Train) 193 | # interp.plot_rewards(cumulative=True, 
per_episode=True, group_name=meta) 194 | # group_interp.add_interpretation(interp) 195 | # group_interp.to_pickle(f'../docs_src/data/walker2d_{model.name.lower()}/', 196 | # f'{model.name.lower()}_{meta}') 197 | # [g.write('../res/run_gifs/walker2d') for g in interp.generate_gif()] 198 | # data.close() 199 | del learner 200 | # del model 201 | # del data 202 | 203 | 204 | @pytest.mark.usefixtures('skip_performance_check') 205 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'], 206 | list(product(p_model, p_full_format, p_exp))) 207 | def test_ddpg_models_ant(model_cls, s_format, experience): 208 | group_interp=GroupAgentInterpretation() 209 | extra_s = f'{experience.__name__}_{model_cls.__name__}_{s_format}' 210 | for i in range(5): 211 | print('\n') 212 | # data=MDPDataBunch.from_env('AntPyBulletEnv-v0', render='human', bs=64, add_valid=False, keep_env_open=False, 213 | # feed_type=s_format, memory_management_strategy='k_partitions_top', k=3) 214 | # exploration_method=OrnsteinUhlenbeck(size=data.action.taken_action.shape, epsilon_start=1, epsilon_end=0.1, 215 | # decay=0.00001) 216 | # memory=experience(memory_size=1000000, reduce_ram=True) 217 | # model=create_ddpg_model(data=data, base_arch=model_cls, lr=1e-3, actor_lr=1e-4) 218 | # learner=ddpg_learner(data=data, model=model, memory=memory, exploration_method=exploration_method, 219 | # opt_func=torch.optim.Adam, callback_fns=[RewardMetric, EpsilonMetric]) 220 | # learner.fit(4) 221 | learner=trained_learner(model_cls,'AntPyBulletEnv-v0',s_format,experience,decay=0.0001,render='rgb_array',epochs=1000) 222 | learner2gif(learner,s_format,group_interp,'ant',extra_s) 223 | # meta=f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format==FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}' 224 | # interp=AgentInterpretation(learner, ds_type=DatasetType.Train) 225 | # interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta) 226 | # group_interp.add_interpretation(interp) 227 | # group_interp.to_pickle(f'../docs_src/data/ant_{model.name.lower()}/', 228 | # f'{model.name.lower()}_{meta}') 229 | # [g.write('../res/run_gifs/ant', frame_skip=5) for g in interp.generate_gif()] 230 | del learner 231 | # del model 232 | # del data 233 | 234 | 235 | @pytest.mark.usefixtures('skip_performance_check') 236 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'], 237 | list(product(p_model, p_full_format, p_exp))) 238 | def test_ddpg_models_halfcheetah(model_cls, s_format, experience): 239 | group_interp=GroupAgentInterpretation() 240 | for i in range(5): 241 | print('\n') 242 | data=MDPDataBunch.from_env('HalfCheetahPyBulletEnv-v0', render='rgb_array', bs=64, add_valid=False, keep_env_open=False, 243 | feed_type=s_format, memory_management_strategy='k_partitions_top', k=3, res_wrap=partial(ResolutionWrapper, w_step=2, h_step=2)) 244 | exploration_method=OrnsteinUhlenbeck(size=data.action.taken_action.shape, epsilon_start=1, epsilon_end=0.1, 245 | decay=0.000001) 246 | memory=experience(memory_size=1000000, reduce_ram=True) 247 | model=create_ddpg_model(data=data, base_arch=model_cls) 248 | learner=ddpg_learner(data=data, model=model, memory=memory, exploration_method=exploration_method, 249 | callback_fns=[RewardMetric, EpsilonMetric]) 250 | learner.fit(1000) 251 | 252 | meta=f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format==FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}' 253 | interp=AgentInterpretation(learner, ds_type=DatasetType.Train) 254 | interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta) 255 | 
group_interp.add_interpretation(interp) 256 | group_interp.to_pickle(f'../docs_src/data/halfcheetah_{model.name.lower()}/', 257 | f'{model.name.lower()}_{meta}') 258 | [g.write('../res/run_gifs/halfcheetah') for g in interp.generate_gif()] 259 | del learner 260 | del model 261 | del data 262 | -------------------------------------------------------------------------------- /tests/test_dqn.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | from time import sleep 3 | 4 | import pytest 5 | from fastai.basic_data import DatasetType 6 | 7 | from fast_rl.agents.dqn import create_dqn_model, dqn_learner, DQNLearner 8 | from fast_rl.agents.dqn_models import * 9 | from fast_rl.core.agent_core import ExperienceReplay, PriorityExperienceReplay, GreedyEpsilon 10 | from fast_rl.core.data_block import MDPDataBunch, FEED_TYPE_STATE, FEED_TYPE_IMAGE, ResolutionWrapper 11 | from fast_rl.core.metrics import RewardMetric, EpsilonMetric 12 | from fast_rl.core.train import GroupAgentInterpretation, AgentInterpretation 13 | from torch import optim 14 | 15 | p_model = [DQNModule, FixedTargetDQNModule,DoubleDuelingModule,DuelingDQNModule,DoubleDQNModule] 16 | p_exp = [ExperienceReplay, 17 | PriorityExperienceReplay] 18 | p_format = [FEED_TYPE_STATE]#, FEED_TYPE_IMAGE] 19 | p_envs = ['CartPole-v1'] 20 | 21 | config_env_expectations = { 22 | 'CartPole-v1': {'action_shape': (1, 2), 'state_shape': (1, 4)}, 23 | 'maze-random-5x5-v0': {'action_shape': (1, 4), 'state_shape': (1, 2)} 24 | } 25 | 26 | 27 | def learner2gif(lnr:DQNLearner,s_format,group_interp:GroupAgentInterpretation,name:str,extra:str): 28 | meta=f'{lnr.memory.__class__.__name__}_{"FEED_TYPE_STATE" if s_format==FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}' 29 | interp=AgentInterpretation(lnr, ds_type=DatasetType.Train) 30 | interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta) 31 | group_interp.add_interpretation(interp) 32 | group_interp.to_pickle(f'../docs_src/data/{name}_{lnr.model.name.lower()}/', f'{lnr.model.name.lower()}_{meta}') 33 | temp=[g.write(f'../res/run_gifs/{name}_{extra}') for g in interp.generate_gif()] 34 | del temp 35 | gc.collect() 36 | 37 | 38 | 39 | def trained_learner(model_cls, env, s_format, experience, bs, layers, memory_size=1000000, decay=0.001, 40 | copy_over_frequency=300, lr=None, epochs=450,**kwargs): 41 | if lr is None: lr = [0.001, 0.00025] 42 | memory = experience(memory_size=memory_size, reduce_ram=True) 43 | explore = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=decay) 44 | if type(lr) == list: lr = lr[0] if model_cls == DQNModule else lr[1] 45 | data = MDPDataBunch.from_env(env, render='human', bs=bs, add_valid=False, keep_env_open=False, feed_type=s_format, 46 | memory_management_strategy='k_partitions_top', k=3,**kwargs) 47 | if model_cls == DQNModule: model = create_dqn_model(data=data, base_arch=model_cls, lr=lr, layers=layers, opt=optim.RMSprop) 48 | else: model = create_dqn_model(data=data, base_arch=model_cls, lr=lr, layers=layers) 49 | learn = dqn_learner(data, model, memory=memory, exploration_method=explore, copy_over_frequency=copy_over_frequency, 50 | callback_fns=[RewardMetric, EpsilonMetric]) 51 | learn.fit(epochs) 52 | return learn 53 | 54 | # @pytest.mark.usefixtures('skip_performance_check') 55 | @pytest.mark.parametrize(["model_cls", "s_format", "env"], list(product(p_model, p_format, p_envs))) 56 | def test_dqn_create_dqn_model(model_cls, s_format, env): 57 | data = MDPDataBunch.from_env(env, 
render='rgb_array', bs=32, add_valid=False, feed_type=s_format) 58 | model = create_dqn_model(data, model_cls) 59 | model.eval() 60 | model(data.state.s) 61 | 62 | assert config_env_expectations[env]['action_shape'] == (1, data.action.n_possible_values.item()) 63 | if s_format == FEED_TYPE_STATE: 64 | assert config_env_expectations[env]['state_shape'] == data.state.s.shape 65 | 66 | 67 | # @pytest.mark.usefixtures('skip_performance_check') 68 | @pytest.mark.parametrize(["model_cls", "s_format", "mem", "env"], list(product(p_model, p_format, p_exp, p_envs))) 69 | def test_dqn_dqn_learner(model_cls, s_format, mem, env): 70 | data = MDPDataBunch.from_env(env, render='rgb_array', bs=32, add_valid=False, keep_env_open=False, feed_type=s_format) 71 | model = create_dqn_model(data, model_cls) 72 | memory = mem(memory_size=1000, reduce_ram=True) 73 | exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001) 74 | dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method) 75 | 76 | assert config_env_expectations[env]['action_shape'] == (1, data.action.n_possible_values.item()) 77 | if s_format == FEED_TYPE_STATE: 78 | assert config_env_expectations[env]['state_shape'] == data.state.s.shape 79 | 80 | 81 | # @pytest.mark.usefixtures('skip_performance_check') 82 | @pytest.mark.parametrize(["model_cls", "s_format", "mem", "env"], list(product(p_model, p_format, p_exp, p_envs))) 83 | def test_dqn_fit(model_cls, s_format, mem, env): 84 | data = MDPDataBunch.from_env(env, render='rgb_array', bs=5, max_steps=20, add_valid=False, keep_env_open=False, feed_type=s_format) 85 | model = create_dqn_model(data, model_cls, opt=torch.optim.RMSprop) 86 | memory = mem(memory_size=1000, reduce_ram=True) 87 | exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001) 88 | learner = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method) 89 | learner.fit(2) 90 | 91 | assert config_env_expectations[env]['action_shape'] == (1, data.action.n_possible_values.item()) 92 | if s_format == FEED_TYPE_STATE: 93 | assert config_env_expectations[env]['state_shape'] == data.state.s.shape 94 | 95 | 96 | @pytest.mark.usefixtures('skip_performance_check') 97 | @pytest.mark.parametrize(["model_cls", "s_format", "mem"], list(product(p_model, p_format, p_exp))) 98 | def test_dqn_fit_maze_env(model_cls, s_format, mem): 99 | group_interp = GroupAgentInterpretation() 100 | extra_s=f'{mem.__name__}_{model_cls.__name__}_{s_format}' 101 | for i in range(5): 102 | learn = trained_learner(model_cls, 'maze-random-5x5-v0', s_format, mem, bs=32, layers=[32, 32], 103 | memory_size=1000000, decay=0.00001, res_wrap=partial(ResolutionWrapper, w_step=3, h_step=3)) 104 | 105 | learner2gif(learn,s_format,group_interp,'maze_5x5',extra_s) 106 | # success = False 107 | # while not success: 108 | # try: 109 | # data = MDPDataBunch.from_env('maze-random-5x5-v0', render='rgb_array', bs=5, max_steps=20, 110 | # add_valid=False, keep_env_open=False, feed_type=s_format) 111 | # model = create_dqn_model(data, model_cls, opt=torch.optim.RMSprop) 112 | # memory = ExperienceReplay(10000) 113 | # exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001) 114 | # learner = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method, 115 | # callback_fns=[RewardMetric, EpsilonMetric]) 116 | # learner.fit(2) 117 | # 118 | # assert config_env_expectations['maze-random-5x5-v0']['action_shape'] == ( 119 | # 1, 
data.action.n_possible_values.item()) 120 | # if s_format == FEED_TYPE_STATE: 121 | # assert config_env_expectations['maze-random-5x5-v0']['state_shape'] == data.state.s.shape 122 | # sleep(1) 123 | # success = True 124 | # except Exception as e: 125 | # if not str(e).__contains__('Surface'): 126 | # raise Exception 127 | 128 | 129 | @pytest.mark.usefixtures('skip_performance_check') 130 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'], list(product(p_model, p_format, p_exp))) 131 | def test_dqn_models_minigrids(model_cls, s_format, experience): 132 | group_interp = GroupAgentInterpretation() 133 | for i in range(5): 134 | learn = trained_learner(model_cls, 'MiniGrid-FourRooms-v0', s_format, experience, bs=32, layers=[64, 64], 135 | memory_size=1000000, decay=0.00001, epochs=1000) 136 | 137 | meta = f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format == FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}' 138 | interp = AgentInterpretation(learn, ds_type=DatasetType.Train) 139 | interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta) 140 | group_interp.add_interpretation(interp) 141 | filename = f'{learn.model.name.lower()}_{meta}' 142 | group_interp.to_pickle(f'../docs_src/data/minigrid_{learn.model.name.lower()}/', filename) 143 | [g.write('../res/run_gifs/minigrid') for g in interp.generate_gif()] 144 | del learn 145 | 146 | 147 | @pytest.mark.usefixtures('skip_performance_check') 148 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'], 149 | list(product(p_model, p_format, p_exp))) 150 | def test_dqn_models_cartpole(model_cls, s_format, experience): 151 | group_interp = GroupAgentInterpretation() 152 | extra_s=f'{experience.__name__}_{model_cls.__name__}_{s_format}' 153 | for i in range(5): 154 | learn = trained_learner(model_cls, 'CartPole-v1', s_format, experience, bs=32, layers=[64, 64], 155 | memory_size=1000000, decay=0.001) 156 | 157 | learner2gif(learn,s_format,group_interp,'cartpole',extra_s) 158 | # meta = f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format == FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}' 159 | # interp = AgentInterpretation(learn, ds_type=DatasetType.Train) 160 | # interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta) 161 | # group_interp.add_interpretation(interp) 162 | # filename = f'{learn.model.name.lower()}_{meta}' 163 | # group_interp.to_pickle(f'../docs_src/data/cartpole_{learn.model.name.lower()}/', filename) 164 | # [g.write('../res/run_gifs/cartpole') for g in interp.generate_gif()] 165 | # del learn 166 | 167 | 168 | @pytest.mark.usefixtures('skip_performance_check') 169 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'], list(product(p_model, p_format, p_exp))) 170 | def test_dqn_models_lunarlander(model_cls, s_format, experience): 171 | group_interp = GroupAgentInterpretation() 172 | extra_s=f'{experience.__name__}_{model_cls.__name__}_{s_format}' 173 | for i in range(5): 174 | learn = trained_learner(model_cls, 'LunarLander-v2', s_format, experience, bs=32, layers=[128, 64], 175 | memory_size=1000000, decay=0.00001, copy_over_frequency=600, lr=[0.001, 0.00025], 176 | epochs=1000) 177 | learner2gif(learn, s_format, group_interp, 'lunarlander', extra_s) 178 | del learn 179 | gc.collect() 180 | # meta = f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format == FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}' 181 | # interp = AgentInterpretation(learn, ds_type=DatasetType.Train) 182 | # interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta) 183 | # 
group_interp.add_interpretation(interp) 184 | # filename = f'{learn.model.name.lower()}_{meta}' 185 | # group_interp.to_pickle(f'../docs_src/data/lunarlander_{learn.model.name.lower()}/', filename) 186 | # [g.write('../res/run_gifs/lunarlander') for g in interp.generate_gif()] 187 | # del learn 188 | 189 | 190 | @pytest.mark.usefixtures('skip_performance_check') 191 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'], list(product(p_model, p_format, p_exp))) 192 | def test_dqn_models_mountaincar(model_cls, s_format, experience): 193 | group_interp = GroupAgentInterpretation() 194 | for i in range(5): 195 | learn = trained_learner(model_cls, 'MountainCar-v0', s_format, experience, bs=32, layers=[24, 12], 196 | memory_size=1000000, decay=0.00001, copy_over_frequency=1000) 197 | meta = f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format == FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}' 198 | interp = AgentInterpretation(learn, ds_type=DatasetType.Train) 199 | interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta) 200 | group_interp.add_interpretation(interp) 201 | filename = f'{learn.model.name.lower()}_{meta}' 202 | group_interp.to_pickle(f'../docs_src/data/mountaincar_{learn.model.name.lower()}/', filename) 203 | [g.write('../res/run_gifs/mountaincar') for g in interp.generate_gif()] 204 | del learn 205 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | from fastai.imports import torch 2 | 3 | from fast_rl.agents.dqn import create_dqn_model, dqn_learner 4 | from fast_rl.agents.dqn_models import DQNModule 5 | from fast_rl.core.agent_core import ExperienceReplay, GreedyEpsilon 6 | from fast_rl.core.data_block import MDPDataBunch 7 | from fast_rl.core.metrics import RewardMetric, EpsilonMetric 8 | 9 | 10 | def test_metrics_reward_init(): 11 | data=MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=5, max_steps=20) 12 | model=create_dqn_model(data, DQNModule, opt=torch.optim.RMSprop) 13 | memory=ExperienceReplay(memory_size=1000, reduce_ram=True) 14 | exploration_method=GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001) 15 | learner=dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method, 16 | callback_fns=[RewardMetric]) 17 | learner.fit(2) 18 | 19 | 20 | def test_metrics_epsilon_init(): 21 | data=MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=5, max_steps=20) 22 | model=create_dqn_model(data, DQNModule, opt=torch.optim.RMSprop) 23 | memory=ExperienceReplay(memory_size=1000, reduce_ram=True) 24 | exploration_method=GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001) 25 | learner=dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method, 26 | callback_fns=[EpsilonMetric]) 27 | learner.fit(2) 28 | 29 | -------------------------------------------------------------------------------- /tests/test_train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pytest 4 | 5 | from fast_rl.agents.dqn import * 6 | from fast_rl.core.agent_core import * 7 | from fast_rl.core.data_block import * 8 | from fast_rl.core.train import * 9 | 10 | 11 | 12 | 13 | @pytest.mark.usefixtures('skip_performance_check') 14 | def test_interpretation_gif(): 15 | logger = logging.getLogger('root') 16 | logger.setLevel('DEBUG') 17 | 18 | data = MDPDataBunch.from_env('CartPole-v0', 
render='rgb_array', bs=32, add_valid=False, 19 | memory_management_strategy='k_partitions_top', k=3) 20 | model = create_dqn_model(data, DQNModule, opt=torch.optim.RMSprop, lr=0.1) 21 | memory = ExperienceReplay(memory_size=1000, reduce_ram=True) 22 | exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001) 23 | learner = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method) 24 | learner.fit(10) 25 | interp = AgentInterpretation(learner, ds_type=DatasetType.Train) 26 | interp.generate_gif(-1).write('last_episode') 27 | # 28 | # 29 | # @pytest.mark.parametrize(["model_cls", "s_format", "mem"], list(product(p_model, p_format, p_exp))) 30 | # def test_train_gym_maze_interpretation(model_cls, s_format, mem): 31 | # success = False 32 | # while not success: 33 | # try: 34 | # data = MDPDataBunch.from_env('maze-random-5x5-v0', render='rgb_array', bs=5, max_steps=50, 35 | # add_valid=False, feed_type=s_format) 36 | # model = create_dqn_model(data, model_cls, opt=torch.optim.RMSprop) 37 | # memory = mem(10000) 38 | # exploration_method = GreedyEpsilon(e_start=1, e_end=0.1, decay=0.001) 39 | # lnr = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method) 40 | # lnr.fit(1) 41 | # 42 | # interp = GymMazeInterpretation(lnr, ds_type=DatasetType.Train) 43 | # for i in range(-1, 4): interp.plot_heat_map(action=i) 44 | # 45 | # success = True 46 | # except Exception as e: 47 | # if not str(e).__contains__('Surface'): 48 | # raise Exception 49 | # 50 | # 51 | # @pytest.mark.parametrize(["model_cls", "s_format", "mem"], list(product(p_model, p_format, p_exp))) 52 | # def test_train_q_value_interpretation(model_cls, s_format, mem): 53 | # success = False 54 | # while not success: 55 | # try: 56 | # data = MDPDataBunch.from_env('maze-random-5x5-v0', render='rgb_array', bs=5, max_steps=50, 57 | # add_valid=False, feed_type=s_format) 58 | # model = create_dqn_model(data, model_cls, opt=torch.optim.RMSprop) 59 | # memory = mem(10000) 60 | # exploration_method = GreedyEpsilon(e_start=1, e_end=0.1, decay=0.001) 61 | # lnr = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method) 62 | # lnr.fit(1) 63 | # 64 | # interp = QValueInterpretation(lnr, ds_type=DatasetType.Train) 65 | # interp.plot_q() 66 | # 67 | # success = True 68 | # except Exception as e: 69 | # if not str(e).__contains__('Surface'): 70 | # raise Exception(e) 71 | # 72 | # # 73 | # # def test_groupagentinterpretation_from_pickle(): 74 | # # group_interp = GroupAgentInterpretation.from_pickle('./data/cartpole_dqn', 75 | # # 'dqn_PriorityExperienceReplay_FEED_TYPE_STATE') 76 | # # group_interp.plot_reward_bounds(return_fig=True, per_episode=True, smooth_groups=5).show() 77 | # # 78 | # # 79 | # # def test_groupagentinterpretation_analysis(): 80 | # # group_interp = GroupAgentInterpretation.from_pickle('./data/cartpole_dqn', 81 | # # 'dqn_PriorityExperienceReplay_FEED_TYPE_STATE') 82 | # # assert isinstance(group_interp.analysis, list) 83 | # # group_interp.in_notebook = True 84 | # # assert isinstance(group_interp.analysis, pd.DataFrame) 85 | # 86 | # 87 | # 88 | # 89 | # # 90 | # # def test_interpretation_reward_group_plot(): 91 | # # group_interp = GroupAgentInterpretation() 92 | # # group_interp2 = GroupAgentInterpretation() 93 | # # 94 | # # for i in range(2): 95 | # # data = MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=4, add_valid=False) 96 | # # model = DQN(data) 97 | # # learn = AgentLearner(data, model) 
98 | # # learn.fit(2) 99 | # # 100 | # # interp = AgentInterpretation(learn=learn, ds_type=DatasetType.Train) 101 | # # interp.plot_rewards(cumulative=True, per_episode=True, group_name='run1') 102 | # # group_interp.add_interpretation(interp) 103 | # # group_interp2.add_interpretation(interp) 104 | # # 105 | # # group_interp.plot_reward_bounds(return_fig=True, per_episode=True).show() 106 | # # group_interp2.plot_reward_bounds(return_fig=True, per_episode=True).show() 107 | # # 108 | # # new_interp = group_interp.merge(group_interp2) 109 | # # assert len(new_interp.groups) == len(group_interp.groups) + len(group_interp2.groups), 'Lengths do not match' 110 | # # new_interp.plot_reward_bounds(return_fig=True, per_episode=True).show() 111 | --------------------------------------------------------------------------------
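Note on running the performance suites dumped above: every test marked with @pytest.mark.usefixtures('skip_performance_check') is skipped unless the --include_performance_tests option defined in tests/conftest.py is passed to pytest. Below is a minimal sketch of an opt-in run, assuming pytest and the package are installed and the script is executed from the repository root; the "tests" path refers to the directory shown above.

import sys

import pytest

# Opt in to the long-running performance tests gated by the
# 'skip_performance_check' fixture in tests/conftest.py; without the
# flag, pytest collects these tests but skips them.
exit_code = pytest.main(["tests", "--include_performance_tests"])
sys.exit(exit_code)

The same run can be started from a shell with: pytest tests --include_performance_tests. Either way, expect these suites to train full agents and, per the option's help text, to take days rather than minutes.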