├── .github
└── ISSUE_TEMPLATE
│ ├── bug_report.md
│ ├── feature_request.md
│ └── pull-request.md
├── .gitignore
├── .idea
├── codeStyles
│ ├── Project.xml
│ └── codeStyleConfig.xml
└── dictionaries
│ └── jlaivins.xml
├── Dockerfile
├── LICENSE.txt
├── README.md
├── ROADMAP.md
├── azure-pipelines.yml
├── build
└── azure_pipeline_helper.sh
├── docs_src
├── data
│ ├── ant_ddpg
│ │ ├── ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle
│ │ └── ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
│ ├── cartpole_dddqn
│ │ ├── dddqn_ExperienceReplay_FEED_TYPE_STATE.pickle
│ │ ├── dddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
│ │ ├── dddqn_er_rms.pickle
│ │ └── dddqn_per_rms.pickle
│ ├── cartpole_ddqn
│ │ ├── ddqn_ExperienceReplay_FEED_TYPE_STATE.pickle
│ │ ├── ddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
│ │ ├── ddqn_er_rms.pickle
│ │ └── ddqn_per_rms.pickle
│ ├── cartpole_dqn fixed targeting
│ │ ├── dqn fixed targeting_er_rms.pickle
│ │ └── dqn fixed targeting_per_rms.pickle
│ ├── cartpole_dqn
│ │ ├── dqn_ExperienceReplay_FEED_TYPE_STATE.pickle
│ │ └── dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
│ ├── cartpole_dueling dqn
│ │ ├── dueling dqn_ExperienceReplay_FEED_TYPE_STATE.pickle
│ │ ├── dueling dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
│ │ ├── dueling dqn_er_rms.pickle
│ │ └── dueling dqn_per_rms.pickle
│ ├── cartpole_fixed target dqn
│ │ ├── fixed target dqn_ExperienceReplay_FEED_TYPE_STATE.pickle
│ │ └── fixed target dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
│ ├── halfcheetah_ddpg
│ │ ├── ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle
│ │ └── ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
│ ├── lunarlander_dddqn
│ │ ├── dddqn_ExperienceReplay_FEED_TYPE_STATE.pickle
│ │ └── dddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
│ ├── lunarlander_ddqn
│ │ ├── ddqn_ExperienceReplay_FEED_TYPE_STATE.pickle
│ │ └── ddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
│ ├── lunarlander_dqn fixed targeting
│ │ ├── dqn fixed targeting_ExperienceReplay_FEED_TYPE_STATE.pickle
│ │ └── dqn fixed targeting_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
│ ├── lunarlander_dqn
│ │ ├── dqn_ExperienceReplay_FEED_TYPE_STATE.pickle
│ │ └── dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
│ ├── lunarlander_dueling dqn
│ │ ├── dueling dqn_ExperienceReplay_FEED_TYPE_STATE.pickle
│ │ └── dueling dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
│ ├── lunarlander_fixed target dqn
│ │ ├── fixed target dqn_ExperienceReplay_FEED_TYPE_STATE.pickle
│ │ └── fixed target dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
│ └── pendulum_ddpg
│ │ ├── ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle
│ │ └── ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
├── rl.agents.dddqn.ipynb
├── rl.agents.ddpg.ipynb
├── rl.agents.doubledqn.ipynb
├── rl.agents.dqn.ipynb
├── rl.agents.dqnfixedtarget.ipynb
├── rl.agents.duelingdqn.ipynb
├── rl.core.train.interpretation.ipynb
└── util.gif_handling.ipynb
├── environment.yaml
├── fast_rl
├── __init__.py
├── agents
│ ├── __init__.py
│ ├── ddpg.py
│ ├── ddpg_models.py
│ ├── dqn.py
│ └── dqn_models.py
├── core
│ ├── Interpreter.py
│ ├── __init__.py
│ ├── agent_core.py
│ ├── basic_train.py
│ ├── data_block.py
│ ├── data_structures.py
│ ├── layers.py
│ ├── metrics.py
│ └── train.py
└── util
│ ├── __init__.py
│ ├── exceptions.py
│ └── misc.py
├── res
├── RELEASE_BLOG.md
├── ddpg_balancing.gif
├── dqn_q_estimate_1.jpg
├── dqn_q_estimate_2.jpg
├── dqn_q_estimate_3.jpg
├── fit_func_out.jpg
├── heat_map_1.png
├── heat_map_2.png
├── heat_map_3.png
├── heat_map_4.png
├── pre_interpretation_maze_dqn.gif
├── reward_plot_1.png
├── reward_plot_2.png
├── reward_plots
│ ├── ant_ddpg.png
│ ├── cartpole_dddqn.png
│ ├── cartpole_double.png
│ ├── cartpole_dqn.png
│ ├── cartpole_dueling.png
│ ├── cartpole_fixedtarget.png
│ ├── halfcheetah_ddpg.png
│ ├── lunarlander_all_targetbased.png
│ ├── lunarlander_dddqn.png
│ ├── lunarlander_double.png
│ ├── lunarlander_dqn.png
│ ├── lunarlander_dueling.png
│ ├── lunarlander_fixedtarget.png
│ └── pendulum_ddpg.png
└── run_gifs
│ ├── __init__.py
│ ├── acrobot_ExperienceReplay_DDPGModule_1_episode_197.gif
│ ├── acrobot_ExperienceReplay_DDPGModule_1_episode_438.gif
│ ├── acrobot_ExperienceReplay_DDPGModule_1_episode_69.gif
│ ├── acrobot_PriorityExperienceReplay_DDPGModule_1_episode_267.gif
│ ├── acrobot_PriorityExperienceReplay_DDPGModule_1_episode_422.gif
│ ├── acrobot_PriorityExperienceReplay_DDPGModule_1_episode_55.gif
│ ├── ant_ExperienceReplay_DDPGModule_1_episode_54.gif
│ ├── ant_ExperienceReplay_DDPGModule_1_episode_614.gif
│ ├── ant_ExperienceReplay_DDPGModule_1_episode_999.gif
│ ├── ant_PriorityExperienceReplay_DDPGModule_1_episode_52.gif
│ ├── ant_PriorityExperienceReplay_DDPGModule_1_episode_596.gif
│ ├── ant_PriorityExperienceReplay_DDPGModule_1_episode_984.gif
│ ├── cartpole_ExperienceReplay_DQNModule_1_episode_207.gif
│ ├── cartpole_ExperienceReplay_DQNModule_1_episode_31.gif
│ ├── cartpole_ExperienceReplay_DQNModule_1_episode_447.gif
│ ├── cartpole_ExperienceReplay_DoubleDQNModule_1_episode_268.gif
│ ├── cartpole_ExperienceReplay_DoubleDQNModule_1_episode_438.gif
│ ├── cartpole_ExperienceReplay_DoubleDQNModule_1_episode_60.gif
│ ├── cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_287.gif
│ ├── cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_43.gif
│ ├── cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_447.gif
│ ├── cartpole_ExperienceReplay_DuelingDQNModule_1_episode_209.gif
│ ├── cartpole_ExperienceReplay_DuelingDQNModule_1_episode_432.gif
│ ├── cartpole_ExperienceReplay_DuelingDQNModule_1_episode_62.gif
│ ├── cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_309.gif
│ ├── cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_438.gif
│ ├── cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_57.gif
│ ├── cartpole_PriorityExperienceReplay_DQNModule_1_episode_216.gif
│ ├── cartpole_PriorityExperienceReplay_DQNModule_1_episode_413.gif
│ ├── cartpole_PriorityExperienceReplay_DQNModule_1_episode_44.gif
│ ├── cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_269.gif
│ ├── cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_35.gif
│ ├── cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_444.gif
│ ├── cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_2.gif
│ ├── cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_260.gif
│ ├── cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_438.gif
│ ├── cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_272.gif
│ ├── cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_438.gif
│ ├── cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_69.gif
│ ├── cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_13.gif
│ ├── cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_265.gif
│ ├── cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_449.gif
│ ├── lunarlander_ExperienceReplay_DQNModule_1_episode_541.gif
│ ├── lunarlander_ExperienceReplay_DQNModule_1_episode_93.gif
│ ├── lunarlander_ExperienceReplay_DQNModule_1_episode_999.gif
│ ├── lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_613.gif
│ ├── lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_88.gif
│ ├── lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_999.gif
│ ├── lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_114.gif
│ ├── lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_346.gif
│ ├── lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_925.gif
│ ├── lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_112.gif
│ ├── lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_431.gif
│ ├── lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_980.gif
│ ├── lunarlander_PriorityExperienceReplay_DQNModule_1_episode_382.gif
│ ├── lunarlander_PriorityExperienceReplay_DQNModule_1_episode_949.gif
│ ├── lunarlander_PriorityExperienceReplay_DQNModule_1_episode_99.gif
│ ├── lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_514.gif
│ ├── lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_7.gif
│ ├── lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_999.gif
│ ├── lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_151.gif
│ ├── lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_341.gif
│ ├── lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_999.gif
│ ├── lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_21.gif
│ ├── lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_442.gif
│ ├── lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_998.gif
│ ├── pendulum_ExperienceReplay_DDPGModule_1_episode_238.gif
│ ├── pendulum_ExperienceReplay_DDPGModule_1_episode_447.gif
│ ├── pendulum_ExperienceReplay_DDPGModule_1_episode_9.gif
│ ├── pendulum_PriorityExperienceReplay_DDPGModule_1_episode_222.gif
│ ├── pendulum_PriorityExperienceReplay_DDPGModule_1_episode_35.gif
│ └── pendulum_PriorityExperienceReplay_DDPGModule_1_episode_431.gif
├── setup.py
└── tests
├── __init__.py
├── conftest.py
├── data
├── cartpole_dqn
│ └── dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
├── cat
│ ├── cat1.jpeg
│ └── cat2.jpeg
└── dog
│ ├── dog1.jpeg
│ └── dog2.jpeg
├── test_agent_core.py
├── test_basic_train.py
├── test_data_block.py
├── test_data_structures.py
├── test_ddpg.py
├── test_dqn.py
├── test_metrics.py
└── test_train.py
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Click on '....'
17 | 3. Scroll down to '....'
18 | 4. See error
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **System (please complete the following information):**
27 | - OS: [e.g. iOS]
28 | - Version [e.g. 22]
29 | - Python Version [e.g. 3.6]
30 |
31 | **Additional context**
32 | Add any other context about the problem here.
33 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **This feature is officially added when**
20 | - [ ] These items are accomplished . . .
21 | . . .
22 |
23 | **Additional context**
24 | Add any other context or screenshots about the feature request here.
25 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/pull-request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Pull Request
3 | about: Guidelines for making pull requests
4 | title: "[Added or Changed or Fixed] . . . something in your PR [FIX or FEATURE]"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **REMOVE ME START**
11 | IMPORTANT: Please do not create a Pull Request without creating an issue first. Template from [stevemao templates](https://github.com/stevemao/github-issue-templates).
12 |
13 | Any change needs to be discussed before proceeding. Failure to do so may result in the rejection of the pull request.
14 | **END**
15 |
16 | **What this PR does**
17 | Please provide enough information so that others can review your pull request. Start with a quick TL;DR or a short set of statements on what you did.
18 |
19 | **Why this PR exists**
20 | Explain the details for making this change. What existing problem does the pull request solve?
21 |
22 | **This is how I plan to test this**
23 | Test plan (required)
24 |
25 | Demonstrate the code is solid. Example: The exact commands you ran and their output, screenshots / videos if the pull request changes UI.
26 |
27 | Code formatting
28 |
29 | **Closing issues**
30 |
31 | Put `closes #XXXX` in your comment to auto-close the issue that your PR fixes (if any).
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # IntelliJ project files
2 | *.iml
3 | .idea/inspectionProfiles/*
4 | misc.xml
5 | modules.xml
6 | other.xml
7 | vcs.xml
8 | workspace.xml
9 | out
10 | gen
11 | .DS_Store
12 |
13 | # Jupyter Notebook
14 | */.ipynb_checkpoints/*
15 |
16 | # Secure Files
17 | .pypirc
18 |
19 | # Data Files
20 | #/docs_src/data/*
21 |
22 | # Build Files / Directories
23 | build/*
24 | dist/*
25 | fast_rl.egg-info/*
--------------------------------------------------------------------------------
/.idea/codeStyles/Project.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
--------------------------------------------------------------------------------
/.idea/codeStyles/codeStyleConfig.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/.idea/dictionaries/jlaivins.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | argmax
5 | conv
6 | convolutional
7 | cumulate
8 | darknet
9 | datasets
10 | ddpg
11 | deallocated
12 | dtype
13 | figsize
14 | fixedtargetdqn
15 | gifs
16 | groupable
17 | moviepy
18 | ornstein
19 | pybullet
20 | uhlenbeck
21 |
22 |
23 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:10.0-base-ubuntu18.04
2 | # See http://bugs.python.org/issue19846
3 | ENV LANG C.UTF-8
4 | LABEL com.nvidia.volumes.needed="nvidia_driver"
5 |
6 | RUN apt-get update && apt-get install -y --no-install-recommends \
7 | build-essential cmake git curl vim ca-certificates python-qt4 libjpeg-dev \
8 | zip nano unzip libpng-dev strace python-opengl xvfb && \
9 | rm -rf /var/lib/apt/lists/*
10 |
11 | ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64
12 | ENV PYTHON_VERSION=3.6
13 |
14 | RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
15 | chmod +x ~/miniconda.sh && \
16 | ~/miniconda.sh -b -p /opt/conda && \
17 | rm ~/miniconda.sh && \
18 | /opt/conda/bin/conda install conda-build && \
19 | apt-get update && apt-get upgrade -y --no-install-recommends
20 |
21 | ENV PATH=$PATH:/opt/conda/bin/
22 | ENV USER fastrl_user
23 | # Create Environment
24 | COPY environment.yaml /environment.yaml
25 | RUN conda env create -f environment.yaml
26 |
27 | # Cleanup
28 | RUN rm -rf /var/lib/apt/lists/* \
29 | && apt-get -y autoremove
30 |
31 | EXPOSE 8888
32 | ENV CONDA_DEFAULT_ENV fastrl
33 |
34 | CMD ["/bin/bash -c"]
35 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://dev.azure.com/jokellum/jokellum/_build/latest?definitionId=1&branchName=master)
2 | [](https://pypi.python.org/pypi/fast_rl)
3 | [](https://github.com/josiahls/fast-reinforcement-learning/releases)
4 |
5 | **Important Note**
6 | fastrl==2.* is being developed at [fastrl](https://github.com/josiahls/fastrl), which is the permanent place for all fastai version 2.0 changes as well as faster/refactored/more stable models. Please go there instead for new information/code.
7 |
8 |
9 |
10 | # Fast_rl
11 | This repo is not affiliated with Jeremy Howard or his course which can be found [here](https://www.fast.ai/about/).
12 | We will be using components from the Fastai library for building and training our reinforcement learning (RL)
13 | agents.
14 |
15 | Our goal is for fast_rl to make benchmarking easier, inference more efficient, and environment compatibility
16 | as decoupled as possible. This being version 1.0, we still have a lot of work to do to make RL training itself
17 | faster and more efficient. The goals for this repo can be seen in the [RoadMap](#roadmap).
18 |
19 | **An important note: training can use up a lot of RAM. This will likely be resolved in the next few versions by offloading data to storage as more models are added.**
20 |
21 | A simple example:
22 | ```python
23 | from fast_rl.agents.dqn import create_dqn_model, dqn_learner
24 | from fast_rl.agents.dqn_models import *
25 | from fast_rl.core.agent_core import ExperienceReplay, GreedyEpsilon
26 | from fast_rl.core.data_block import MDPDataBunch
27 | from fast_rl.core.metrics import RewardMetric, EpsilonMetric
28 |
29 | memory = ExperienceReplay(memory_size=1000000, reduce_ram=True)
30 | explore = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
31 | data = MDPDataBunch.from_env('CartPole-v1', render='human', bs=64, add_valid=False)
32 | model = create_dqn_model(data=data, base_arch=FixedTargetDQNModule, lr=0.001, layers=[32,32])
33 | learn = dqn_learner(data, model, memory=memory, exploration_method=explore, copy_over_frequency=300,
34 | callback_fns=[RewardMetric, EpsilonMetric])
35 | learn.fit(450)
36 | ```
37 |
38 | More complex examples might involve running an RL agent multiple times, generating episode snapshots as gifs, grouping
39 | reward plots, and finally showing the best and worst runs in a single graph.
40 | ```python
41 | from fastai.basic_data import DatasetType
42 | from fast_rl.agents.dqn import create_dqn_model, dqn_learner
43 | from fast_rl.agents.dqn_models import *
44 | from fast_rl.core.agent_core import ExperienceReplay, GreedyEpsilon
45 | from fast_rl.core.data_block import MDPDataBunch
46 | from fast_rl.core.metrics import RewardMetric, EpsilonMetric
47 | from fast_rl.core.train import GroupAgentInterpretation, AgentInterpretation
48 |
49 | group_interp = GroupAgentInterpretation()
50 | for i in range(5):
51 | memory = ExperienceReplay(memory_size=1000000, reduce_ram=True)
52 | explore = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
53 | data = MDPDataBunch.from_env('CartPole-v1', render='human', bs=64, add_valid=False)
54 | model = create_dqn_model(data=data, base_arch=FixedTargetDQNModule, lr=0.001, layers=[32,32])
55 | learn = dqn_learner(data, model, memory=memory, exploration_method=explore, copy_over_frequency=300,
56 | callback_fns=[RewardMetric, EpsilonMetric])
57 | learn.fit(450)
58 |
59 | interp=AgentInterpretation(learn, ds_type=DatasetType.Train)
60 | interp.plot_rewards(cumulative=True, per_episode=True, group_name='cartpole_experience_example')
61 | group_interp.add_interpretation(interp)
62 | group_interp.to_pickle(f'{learn.model.name.lower()}/', f'{learn.model.name.lower()}')
63 | for g in interp.generate_gif(): g.write(f'{learn.model.name.lower()}')
64 | group_interp.plot_reward_bounds(per_episode=True, smooth_groups=10)
65 | ```
66 | More examples can be found in `docs_src` and the actual code being run for generating gifs can be found in `tests` in
67 | either `test_dqn.py` or `test_ddpg.py`.
68 |
69 | As a note, here is a rundown of existing RL frameworks:
70 | - [Intel Coach](https://github.com/NervanaSystems/coach)
71 | - [Tensor Force](https://github.com/tensorforce/tensorforce)
72 | - [OpenAI Baselines](https://github.com/openai/baselines)
73 | - [Tensorflow Agents](https://github.com/tensorflow/agents)
74 | - [KerasRL](https://github.com/keras-rl/keras-rl)
75 |
76 | However, there are also frameworks in PyTorch:
77 | - [Horizon](https://github.com/facebookresearch/Horizon)
78 | - [DeepRL](https://github.com/ShangtongZhang/DeepRL)
79 | - [Spinning Up](https://spinningup.openai.com/en/latest/user/introduction.html)
80 |
81 | ## Installation
82 |
83 | **fastai (semi-optional)**\
84 | [Install Fastai](https://github.com/fastai/fastai/blob/master/README.md#installation)
85 | or, if you are using Anaconda (recommended), you can do: \
86 | `conda install -c pytorch -c fastai fastai`
87 |
88 | **fast_rl**\
89 | Fastai will be installed if it does not exist. If it does exist, the versioning should be repaired by setup.py.
90 | `pip install fast_rl`
91 |
92 | ## Installation (Optional)
93 | OpenAI all gyms: \
94 | `pip install gym[all]`
95 |
96 | Mazes: \
97 | `git clone https://github.com/MattChanTK/gym-maze.git` \
98 | `cd gym-maze` \
99 | `python setup.py install`
100 |
101 |
102 | ## Installation Dev (Optional)
103 | `git clone https://github.com/josiahls/fast-reinforcement-learning.git` \
104 | `cd fast-reinforcement-learning` \
105 | `python setup.py install`
106 |
107 | ## Installation Issues
108 | Many issues will likely fall under [fastai installation issues](https://github.com/fastai/fastai/blob/master/README.md#installation-issues).
109 |
110 | Any other issues are likely environment related. It is important to note that Python 3.7 is not being tested due to
111 | an issue where Pyglet and gym do not work together. This will not stop you from training models, however it might impact using
112 | OpenAI Gym environments.
113 |
114 | ## RoadMap
115 |
116 | - [ ] **Working on** **1.0.0** Base version is completed with working model visualizations proving performance / expected failure. At
117 | this point, all models should have guaranteed environments they should succeed in.
118 | - [ ] **Working on** 1.1.0 More Traditional RL models
119 | - [ ] **Working on** Add PPO
120 | - [ ] **Working on** Add TRPO
121 | - [ ] Add D4PG
122 | - [ ] Add A2C
123 | - [ ] Add A3C
124 | - [ ] 1.2.0 HRL models *Might change version to 2.0 depending on SMDP issues*
125 | - [ ] Add SMDP
126 | - [ ] Add Goal-oriented MDPs. Will require a new "Step"
127 | - [ ] Add FeUdal Network
128 | - [ ] Add storage based DataBunch memory management. This can prevent RAM from being used up by episode image frames
129 | that may or may not serve any use to the agent, but only for logging.
130 | - [ ] 1.3.0
131 | - [ ] Add HAC
132 | - [ ] Add MAXQ
133 | - [ ] Add HIRO
134 | - [ ] 1.4.0
135 | - [ ] Add h-DQN
136 | - [ ] Add Modulated Policy Hierarchies
137 | - [ ] Add Meta Learning Shared Hierarchies
138 | - [ ] 1.5.0
139 | - [ ] Add STRategic Attentive Writer (STRAW)
140 | - [ ] Add H-DRLN
141 | - [ ] Add Abstract Markov Decision Process (AMDP)
142 | - [ ] Add conda integration so that installation can be truly one step.
143 | - [ ] 1.6.0 HRL Options models *Possibly will already be implemented in a previous model*
144 | - [ ] Options augmentation to DQN based models
145 | - [ ] Options augmentation to actor critic models
146 | - [ ] Options augmentation to async actor critic models
147 | - [ ] 1.8.0 HRL Skills
148 | - [ ] Skills augmentation to DQN based models
149 | - [ ] Skills augmentation to actor critic models
150 | - [ ] Skills augmentation to async actor critic models
151 | - [ ] 1.9.0
152 | - [ ] 2.0.0 Add PyBullet Fetch Environments
153 | - [ ] 2.0.0 Not part of this repo, however the envs need to subclass the OpenAI `gym.GoalEnv`
154 | - [ ] 2.0.0 Add HER
155 |
156 |
157 | ## Contribution
158 | Following fastai's guidelines would be desirable: [Guidelines](https://github.com/fastai/fastai/blob/master/README.md#contribution-guidelines)
159 |
160 | We hope that model additions will go smoothly. All models will depend only on `core.layers.py`.
161 | As time goes on, the overall model architecture will improve (we are, and will continue to be, figuring things out).
162 |
163 |
164 | ## Style
165 | Since fastai uses a different style from traditional PEP-8, we will be following [Style](https://docs.fast.ai/dev/style.html)
166 | and [Abbreviations](https://docs.fast.ai/dev/abbr.html). We will also use RL-specific abbreviations:
167 |
168 | | | Concept | Abbr. | Combination Examples |
169 | |:------:|:-------:|:-----:|:--------------------:|
170 | | **RL** | State | st | |
171 | | | Action | acn | |
172 | | | Bounds | bb | Same as Bounding Box |
173 |
174 | ## Examples
175 |
176 | ### Reward Graphs
177 |
178 | | | Model |
179 | |:------------------------------------------:|:---------------:|
180 | |  | DQN |
181 | |  | Dueling DQN |
182 | |  | Double DQN |
183 | |  | DDDQN |
184 | |  | Fixed Target DQN |
185 | |  | DQN |
186 | |  | Dueling DQN |
187 | |  | Double DQN |
188 | |  | DDDQN |
189 | |  | Fixed Target DQN |
190 | |  | DDPG |
191 | |  | DDPG |
192 | |  | DDPG |
193 |
194 |
195 | ### Agent Stages
196 |
197 | | Model | Gif(Early) | Gif(Mid) | Gif(Late) |
198 | |:------------:|:------------:|:------------:|:------------:|
199 | | DDPG+PER |  |  | |
200 | | DoubleDueling+ER |  |  | |
201 | | DoubleDQN+ER |  |  | |
202 | | DuelingDQN+ER |  |  | |
203 | | DoubleDueling+PER |  |  | |
204 | | DQN+ER |  |  | |
205 | | DuelingDQN+PER |  |  | |
206 | | DQN+PER |  |  | |
207 | | DoubleDQN+PER |  |  | |
208 | | DDPG+PER |  |  | |
209 | | DDPG+ER |  |  | |
210 | | DQN+PER |  |  | |
211 | | FixedTargetDQN+ER |  |  | |
212 | | DQN+ER |  |  | |
213 | | FixedTargetDQN+PER |  |  | |
214 | | DoubleDQN+ER |  |  | |
215 | | DoubleDQN+PER |  |  | |
216 | | DuelingDQN+ER |  |  | |
217 | | DoubleDueling+PER |  |  | |
218 | | DuelingDQN+PER |  |  | |
219 | | DoubleDueling+ER |  |  | |
220 | | DDPG+ER |  |  | |
221 | | DDPG+PER |  |  | |
222 |
--------------------------------------------------------------------------------
/ROADMAP.md:
--------------------------------------------------------------------------------
1 | - [X] 0.7.0 Full test suite using multi-processing. Connect to CI.
2 | - [X] 0.8.0 Comprehensive model eval **debug/verify**. Each model should succeed in at least a few known environments. Also, massive refactoring will be needed.
3 | - [X] 0.9.0 Notebook demonstrations of basic model usage.
4 | - [X] **1.0.0** Base version is completed with working model visualizations proving performance / expected failure. At
5 | this point, all models should have guaranteed environments they should succeed in.
6 | - [ ] **Working on** 1.1.0 More Traditional RL models
7 | - [ ] **Working on** Add PPO
8 | - [ ] **Working on** Add TRPO
9 | - [ ] Add D4PG
10 | - [ ] Add A2C
11 | - [ ] Add A3C
12 | - [ ] 1.2.0 HRL models *Might change version to 2.0 depending on SMDP issues*
13 | - [ ] Add SMDP
14 | - [ ] Add Goal-oriented MDPs. Will require a new "Step"
15 | - [ ] Add FeUdal Network
16 | - [ ] Add storage based DataBunch memory management. This can prevent RAM from being used up by episode image frames
17 | that may or may not serve any use to the agent, but only for logging.
18 | - [ ] 1.3.0
19 | - [ ] Add HAC
20 | - [ ] Add MAXQ
21 | - [ ] Add HIRO
22 | - [ ] 1.4.0
23 | - [ ] Add h-DQN
24 | - [ ] Add Modulated Policy Hierarchies
25 | - [ ] Add Meta Learning Shared Hierarchies
26 | - [ ] 1.5.0
27 | - [ ] Add STRategic Attentive Writer (STRAW)
28 | - [ ] Add H-DRLN
29 | - [ ] Add Abstract Markov Decision Process (AMDP)
30 | - [ ] 1.6.0 HRL Options models *Possibly will already be implemented in a previous model*
31 | - [ ] Options augmentation to DQN based models
32 | - [ ] Options augmentation to actor critic models
33 | - [ ] Options augmentation to async actor critic models
34 | - [ ] 1.8.0 HRL Skills
35 | - [ ] Skills augmentation to DQN based models
36 | - [ ] Skills augmentation to actor critic models
37 | - [ ] Skills augmentation to async actor critic models
38 | - [ ] 1.9.0
39 | - [ ] 2.0.0 Add PyBullet Fetch Environments
40 | - [ ] 2.0.0 Not part of this repo, however the envs need to subclass the OpenAI `gym.GoalEnv`
41 | - [ ] 2.0.0 Add HER
--------------------------------------------------------------------------------
/azure-pipelines.yml:
--------------------------------------------------------------------------------
1 | # Starter pipeline
2 | # Start with a minimal pipeline that you can customize to build and deploy your code.
3 | # Add steps that build, run tests, deploy, and more:
4 | # https://aka.ms/yaml
5 |
6 | # - bash: "sudo apt-get install -y xvfb freeglut3-dev python-opengl --fix-missing"
7 | # displayName: 'Install ffmpeg, freeglut3-dev, and xvfb'
8 |
9 | stages:
10 | - stage: Test
11 | condition: and(always(), eq(variables['Build.Reason'], 'PullRequest'))
12 | jobs:
13 | - job: 'Test'
14 | pool:
15 | vmImage: 'ubuntu-16.04' # other options: 'macOS-10.13', 'vs2017-win2016'
16 | strategy:
17 | matrix:
18 | Python36:
19 | python.version: '3.6'
20 | # Python37: # Pyglet and gym do not work here
21 | # python.version: '3.7.5'
22 | maxParallel: 4
23 |
24 | steps:
25 | - task: UsePythonVersion@0
26 | inputs:
27 | versionSpec: '$(python.version)'
28 |
29 | - bash: "sudo apt-get install -y freeglut3-dev python-opengl"
30 | displayName: 'Install freeglut3-dev'
31 |
32 | - script: |
33 | python -m pip install --upgrade pip setuptools wheel pytest pytest-cov python-xlib -e .
34 | python setup.py install
35 | displayName: 'Install dependencies'
36 |
37 | - script: sh ./build/azure_pipeline_helper.sh
38 | displayName: 'Complex Installs'
39 |
40 | - script: |
41 | xvfb-run -s "-screen 0 1400x900x24" py.test tests --cov fast_rl --cov-report html --doctest-modules --junitxml=junit/test-results.xml --cov=./ --cov-report=xml --cov-report=html
42 | displayName: 'Test with pytest'
43 |
44 | - task: PublishTestResults@2
45 | condition: succeededOrFailed()
46 | inputs:
47 | testResultsFiles: '**/test-*.xml'
48 | testRunTitle: 'Publish test results for Python $(python.version)'
49 |
50 | - stage: Deploy
51 | condition: and(always(), eq(variables['Build.SourceBranch'], 'refs/heads/master'))
52 | jobs:
53 | - job: "TwineDeploy"
54 | pool:
55 | vmImage: 'ubuntu-16.04' # other options: 'macOS-10.13', 'vs2017-win2016'
56 | strategy:
57 | matrix:
58 | Python36:
59 | python.version: '3.6'
60 | steps:
61 | - task: UsePythonVersion@0
62 | inputs:
63 | versionSpec: '$(python.version)'
64 | # Install python distributions like wheel, twine etc
65 | - task: Bash@3
66 | inputs:
67 | targetType: 'inline'
68 | script: |
69 | echo $TWINE_USERNAME
70 | pip install wheel setuptools twine
71 | python setup.py sdist bdist_wheel
72 | python -m twine upload -u $TWINE_USERNAME -p $TWINE_PASSWORD --repository-url 'https://upload.pypi.org/legacy/' dist/*
73 | env:
74 | TWINE_PASSWORD: $(SECRET_TWINE_PASSWORD)
75 | TWINE_USERNAME: $(SECRET_TWINE_USERNAME)
--------------------------------------------------------------------------------
/build/azure_pipeline_helper.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ## Install pybullet
4 | #git clone https://github.com/benelot/pybullet-gym.git
5 | #cd pybullet-gym
6 | #pip install -e .
7 | #cd ../
8 |
9 | ## Install gym_maze
10 | #git clone https://github.com/MattChanTK/gym-maze.git
11 | #cd gym-maze
12 | #python setup.py install
13 | #cd ../
14 |
15 |
--------------------------------------------------------------------------------
/docs_src/data/ant_ddpg/ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/ant_ddpg/ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/ant_ddpg/ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/ant_ddpg/ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_dddqn/dddqn_ExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dddqn/dddqn_ExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_dddqn/dddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dddqn/dddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_dddqn/dddqn_er_rms.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dddqn/dddqn_er_rms.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_dddqn/dddqn_per_rms.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dddqn/dddqn_per_rms.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_ddqn/ddqn_ExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_ddqn/ddqn_ExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_ddqn/ddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_ddqn/ddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_ddqn/ddqn_er_rms.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_ddqn/ddqn_er_rms.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_ddqn/ddqn_per_rms.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_ddqn/ddqn_per_rms.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_dqn fixed targeting/dqn fixed targeting_er_rms.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dqn fixed targeting/dqn fixed targeting_er_rms.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_dqn fixed targeting/dqn fixed targeting_per_rms.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dqn fixed targeting/dqn fixed targeting_per_rms.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_dqn/dqn_ExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dqn/dqn_ExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_dqn/dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dqn/dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_dueling dqn/dueling dqn_ExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dueling dqn/dueling dqn_ExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_dueling dqn/dueling dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dueling dqn/dueling dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_dueling dqn/dueling dqn_er_rms.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dueling dqn/dueling dqn_er_rms.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_dueling dqn/dueling dqn_per_rms.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_dueling dqn/dueling dqn_per_rms.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_fixed target dqn/fixed target dqn_ExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_fixed target dqn/fixed target dqn_ExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/cartpole_fixed target dqn/fixed target dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/cartpole_fixed target dqn/fixed target dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/halfcheetah_ddpg/ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/halfcheetah_ddpg/ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/halfcheetah_ddpg/ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/halfcheetah_ddpg/ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/lunarlander_dddqn/dddqn_ExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_dddqn/dddqn_ExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/lunarlander_dddqn/dddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_dddqn/dddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/lunarlander_ddqn/ddqn_ExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_ddqn/ddqn_ExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/lunarlander_ddqn/ddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_ddqn/ddqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/lunarlander_dqn fixed targeting/dqn fixed targeting_ExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_dqn fixed targeting/dqn fixed targeting_ExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/lunarlander_dqn fixed targeting/dqn fixed targeting_PriorityExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_dqn fixed targeting/dqn fixed targeting_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/lunarlander_dqn/dqn_ExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_dqn/dqn_ExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/lunarlander_dqn/dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_dqn/dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/lunarlander_dueling dqn/dueling dqn_ExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_dueling dqn/dueling dqn_ExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/lunarlander_dueling dqn/dueling dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_dueling dqn/dueling dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/lunarlander_fixed target dqn/fixed target dqn_ExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_fixed target dqn/fixed target dqn_ExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/lunarlander_fixed target dqn/fixed target dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/lunarlander_fixed target dqn/fixed target dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/pendulum_ddpg/ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/pendulum_ddpg/ddpg_ExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/data/pendulum_ddpg/ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/docs_src/data/pendulum_ddpg/ddpg_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/docs_src/rl.core.train.interpretation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "pycharm": {
8 | "is_executing": false
9 | }
10 | },
11 | "outputs": [
12 | {
13 | "name": "stdout",
14 | "output_type": "stream",
15 | "text": [
16 | "Can't import one of these: No module named 'pybullet'\n",
17 | "Can't import one of these: No module named 'gym_maze'\n",
18 | "Can't import one of these: No module named 'gym_minigrid'\n"
19 | ]
20 | }
21 | ],
22 | "source": [
23 | "from fast_rl.agents.dqn import *\n",
24 | "from fast_rl.agents.dqn_models import FixedTargetDQNModule\n",
25 | "from fast_rl.core.agent_core import *\n",
26 | "from fast_rl.core.data_block import *\n",
27 | "from fast_rl.core.train import *"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 2,
33 | "metadata": {
34 | "pycharm": {
35 | "is_executing": true
36 | }
37 | },
38 | "outputs": [
39 | {
40 | "data": {
41 | "text/html": [
42 | "
\n",
43 | " \n",
44 | "
\n",
45 | "
epoch
\n",
46 | "
train_loss
\n",
47 | "
valid_loss
\n",
48 | "
time
\n",
49 | "
\n",
50 | " \n",
51 | " \n",
52 | "
\n",
53 | "
0
\n",
54 | "
1.095179
\n",
55 | "
#na#
\n",
56 | "
00:00
\n",
57 | "
\n",
58 | "
\n",
59 | "
1
\n",
60 | "
1.026340
\n",
61 | "
#na#
\n",
62 | "
00:00
\n",
63 | "
\n",
64 | "
\n",
65 | "
2
\n",
66 | "
1.007764
\n",
67 | "
#na#
\n",
68 | "
00:00
\n",
69 | "
\n",
70 | "
\n",
71 | "
3
\n",
72 | "
1.001356
\n",
73 | "
#na#
\n",
74 | "
00:00
\n",
75 | "
\n",
76 | "
\n",
77 | "
4
\n",
78 | "
0.996845
\n",
79 | "
#na#
\n",
80 | "
00:00
\n",
81 | "
\n",
82 | "
\n",
83 | "
5
\n",
84 | "
0.993165
\n",
85 | "
#na#
\n",
86 | "
00:00
\n",
87 | "
\n",
88 | "
\n",
89 | "
6
\n",
90 | "
0.988180
\n",
91 | "
#na#
\n",
92 | "
00:00
\n",
93 | "
\n",
94 | "
\n",
95 | "
7
\n",
96 | "
0.986040
\n",
97 | "
#na#
\n",
98 | "
00:00
\n",
99 | "
\n",
100 | "
\n",
101 | "
8
\n",
102 | "
0.982307
\n",
103 | "
#na#
\n",
104 | "
00:00
\n",
105 | "
\n",
106 | "
\n",
107 | "
9
\n",
108 | "
0.976414
\n",
109 | "
#na#
\n",
110 | "
00:00
\n",
111 | "
\n",
112 | " \n",
113 | "
"
114 | ],
115 | "text/plain": [
116 | ""
117 | ]
118 | },
119 | "metadata": {},
120 | "output_type": "display_data"
121 | }
122 | ],
123 | "source": [
124 | "data = MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=32, add_valid=False, \n",
125 | " memory_management_strategy='k_partitions_top', k=3)\n",
126 | "model = create_dqn_model(data, FixedTargetDQNModule, opt=torch.optim.RMSprop)\n",
127 | "memory = ExperienceReplay(10000)\n",
128 | "exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)\n",
129 | "learner = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method)\n",
130 | "learner.fit(10)"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": 7,
136 | "metadata": {
137 | "pycharm": {
138 | "is_executing": true,
139 | "name": "#%%\n"
140 | }
141 | },
142 | "outputs": [
143 | {
144 | "name": "stderr",
145 | "output_type": "stream",
146 | "text": [
147 | "\r",
148 | "t: 0%| | 0/10 [00:00, ?it/s, now=None]"
149 | ]
150 | },
151 | {
152 | "name": "stdout",
153 | "output_type": "stream",
154 | "text": [
155 | "Moviepy - Building video __temp__.mp4.\n",
156 | "Moviepy - Writing video __temp__.mp4\n",
157 | "\n"
158 | ]
159 | },
160 | {
161 | "name": "stderr",
162 | "output_type": "stream",
163 | "text": [
164 | " \r"
165 | ]
166 | },
167 | {
168 | "name": "stdout",
169 | "output_type": "stream",
170 | "text": [
171 | "Moviepy - Done !\n",
172 | "Moviepy - video ready __temp__.mp4\n"
173 | ]
174 | },
175 | {
176 | "data": {
177 | "text/html": [
178 | ""
179 | ],
180 | "text/plain": [
181 | ""
182 | ]
183 | },
184 | "execution_count": 7,
185 | "metadata": {},
186 | "output_type": "execute_result"
187 | }
188 | ],
189 | "source": [
190 | "interp = AgentInterpretation(learner, ds_type=DatasetType.Train)\n",
191 | "interp.generate_gif(2).plot()"
192 | ]
193 | }
194 | ],
195 | "metadata": {
196 | "kernelspec": {
197 | "display_name": "Python 3",
198 | "language": "python",
199 | "name": "python3"
200 | },
201 | "language_info": {
202 | "codemirror_mode": {
203 | "name": "ipython",
204 | "version": 3
205 | },
206 | "file_extension": ".py",
207 | "mimetype": "text/x-python",
208 | "name": "python",
209 | "nbconvert_exporter": "python",
210 | "pygments_lexer": "ipython3",
211 | "version": "3.6.7"
212 | },
213 | "pycharm": {
214 | "stem_cell": {
215 | "cell_type": "raw",
216 | "metadata": {
217 | "collapsed": false
218 | },
219 | "source": []
220 | }
221 | }
222 | },
223 | "nbformat": 4,
224 | "nbformat_minor": 1
225 | }
226 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: fastrl
2 | channels:
3 | - conda-forge
4 | - pytorch
5 | - fastai
6 | dependencies:
7 | - python=3.6
8 | - pip
9 | - cuda100
10 | - fastprogress
11 | - fastai=1.0.58
12 | - jupyter
13 | - notebook
14 | - setuptools
15 | - pip:
16 | - pytest
17 | - nvidia-ml-py3
18 | - dataclasses
19 | - numpy==1.17.4
20 | - pandas
21 | - pyyaml
22 | - opencv-contrib-python
23 | - requests
24 | - sklearn
25 | - cython
26 | - gym[box2d, atari]
27 | - easydict
28 | - matplotlib
29 | - jupyter_console
30 | - moviepy
31 | - pygifsicle
--------------------------------------------------------------------------------
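The environment above pins fastai 1.0.58 and numpy 1.17.4 on Python 3.6 with CUDA 10. A minimal sketch (not part of the repository) for confirming that an activated environment matches those pins before running the notebooks; other versions are simply untested here:

import fastai
import numpy

# Version strings taken directly from environment.yaml.
assert fastai.__version__ == '1.0.58', fastai.__version__
assert numpy.__version__ == '1.17.4', numpy.__version__
print('Environment matches the environment.yaml pins.')
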
/fast_rl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/fast_rl/__init__.py
--------------------------------------------------------------------------------
/fast_rl/agents/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/fast_rl/agents/__init__.py
--------------------------------------------------------------------------------
/fast_rl/agents/ddpg.py:
--------------------------------------------------------------------------------
1 | from fastai.basic_train import LearnerCallback
2 | import fastai.tabular.data
3 | from fastai.torch_core import *
4 |
5 | from fast_rl.agents.ddpg_models import DDPGModule
6 | from fast_rl.core.agent_core import ExperienceReplay, ExplorationStrategy, Experience
7 | from fast_rl.core.basic_train import AgentLearner
8 | from fast_rl.core.data_block import MDPDataBunch, MDPStep, FEED_TYPE_STATE, FEED_TYPE_IMAGE
9 |
10 |
11 | class DDPGLearner(AgentLearner):
12 | def __init__(self, data: MDPDataBunch, model, memory, exploration_method, trainers, opt=optim.Adam,
13 | **kwargs):
14 | self.memory: Experience = memory
15 | self.exploration_method: ExplorationStrategy = exploration_method
16 | super().__init__(data=data, model=model, opt=opt, **kwargs)
17 | self.ddpg_trainers = listify(trainers)
18 | for t in self.ddpg_trainers: self.callbacks.append(t(self))
19 |
20 | def predict(self, element, **kwargs):
21 | with torch.no_grad():
22 | training = self.model.training
23 | if element.shape[0] == 1: self.model.eval()
24 | pred = self.model(element)
25 | if training: self.model.train()
26 | return self.exploration_method.perturb(pred.detach().cpu().numpy(), self.data.action.action_space)
27 |
28 | def interpret_q(self, item):
29 | with torch.no_grad():
30 | return self.model.interpret_q(item).cpu().numpy().item()
31 |
32 |
33 | class BaseDDPGTrainer(LearnerCallback):
34 | def __init__(self, learn):
35 | super().__init__(learn)
36 | self.max_episodes = 0
37 | self.episode = 0
38 | self.iteration = 0
39 | self.copy_over_frequency = 3
40 |
41 | @property
42 | def learn(self) -> DDPGLearner:
43 | return self._learn()
44 |
45 | def on_train_begin(self, n_epochs, **kwargs: Any):
46 | self.max_episodes = n_epochs
47 |
48 | def on_epoch_begin(self, epoch, **kwargs: Any):
49 | self.episode = epoch
50 | self.iteration = 0
51 |
52 | def on_loss_begin(self, **kwargs: Any):
53 | """Performs tree updates, exploration updates, and model optimization."""
54 | if self.learn.model.training: self.learn.memory.update(item=self.learn.data.x.items[-1])
55 | self.learn.exploration_method.update(self.episode, max_episodes=self.max_episodes, explore=self.learn.model.training)
56 | if not self.learn.warming_up:
57 | samples: List[MDPStep] = self.memory.sample(self.learn.data.bs)
58 | post_optimize = self.learn.model.optimize(samples)
59 | if self.learn.model.training:
60 | self.learn.memory.refresh(post_optimize=post_optimize)
61 | self.learn.model.target_copy_over()
62 | self.iteration += 1
63 |
64 |
65 | def create_ddpg_model(data: MDPDataBunch, base_arch: DDPGModule, layers=None, ignore_embed=False, channels=None,
66 | opt=torch.optim.RMSprop, loss_func=None, **kwargs):
67 | bs, state, action = data.bs, data.state, data.action
68 | nc, w, h, n_conv_blocks = -1, -1, -1, [] if state.mode == FEED_TYPE_STATE else ifnone(channels, [32, 32, 32])
69 | if state.mode == FEED_TYPE_IMAGE: nc, w, h = state.s.shape[3], state.s.shape[2], state.s.shape[1]
70 | _layers = ifnone(layers, [400, 200] if len(n_conv_blocks) == 0 else [200, 200])
71 | if ignore_embed or np.any(state.n_possible_values == np.inf) or state.mode == FEED_TYPE_IMAGE: emb_szs = []
72 | else: emb_szs = [(d+1, int(fastai.tabular.data.emb_sz_rule(d))) for d in state.n_possible_values.reshape(-1, )]
73 | ao = int(action.taken_action.shape[1])
74 | model = base_arch(ni=state.s.shape[1], ao=ao, layers=_layers, emb_szs=emb_szs, n_conv_blocks=n_conv_blocks,
75 | nc=nc, w=w, h=h, opt=opt, loss_func=loss_func, **kwargs)
76 | return model
77 |
78 |
79 | ddpg_config = {
80 | DDPGModule: BaseDDPGTrainer
81 | }
82 |
83 |
84 | def ddpg_learner(data: MDPDataBunch, model, memory: ExperienceReplay, exploration_method: ExplorationStrategy,
85 | trainers=None, **kwargs):
86 | trainers = ifnone(trainers, ddpg_config[model.__class__])
87 | return DDPGLearner(data, model, memory, exploration_method, trainers, **kwargs)
88 |
--------------------------------------------------------------------------------
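For symmetry with the DQN notebook earlier in the docs, here is a hedged sketch of wiring the pieces in this file together. The environment name, buffer size, and epsilon schedule are illustrative assumptions rather than values taken from the repository; DDPG is conventionally paired with action-space noise (e.g. Ornstein-Uhlenbeck), and GreedyEpsilon stands in here only because its constructor is visible in agent_core.py below.

from fast_rl.agents.ddpg import create_ddpg_model, ddpg_learner
from fast_rl.agents.ddpg_models import DDPGModule
from fast_rl.core.agent_core import ExperienceReplay, GreedyEpsilon
from fast_rl.core.data_block import MDPDataBunch

# Illustrative hyperparameters; Pendulum-v0 is a standard continuous-control gym env.
data = MDPDataBunch.from_env('Pendulum-v0', render='rgb_array', bs=64, add_valid=False)
model = create_ddpg_model(data=data, base_arch=DDPGModule)
memory = ExperienceReplay(10000)
exploration = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
learner = ddpg_learner(data=data, model=model, memory=memory,
                       exploration_method=exploration)
learner.fit(10)
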
/fast_rl/agents/ddpg_models.py:
--------------------------------------------------------------------------------
1 | from fastai.callback import OptimWrapper
2 |
3 | from fast_rl.core.layers import *
4 |
5 |
6 | class CriticModule(nn.Sequential):
7 | def __init__(self, ni: int, ao: int, layers: List[int], batch_norm=False,
8 | n_conv_blocks: Collection[int] = 0, nc=3, emb_szs: ListSizes = None,
9 | w=-1, h=-1, ks=None, stride=None, conv_kern_proportion=0.1, stride_proportion=0.1, pad=False):
10 | super().__init__()
11 | self.switched, self.batch_norm = False, batch_norm
12 | # self.ks, self.stride = ([], []) if len(n_conv_blocks) == 0 else ks_stride(ks, stride, w, h, n_conv_blocks, conv_kern_proportion, stride_proportion)
13 | self.ks, self.stride=([], []) if len(n_conv_blocks)==0 else (ifnone(ks,[10,10,10]),ifnone(stride,[5,5,5]))
14 | self.action_model = nn.Sequential()
15 | _layers = [conv_bn_lrelu(ch, self.nf, ks=ks, stride=stride, pad=pad, bn=self.batch_norm) for ch, self.nf, ks, stride in zip([nc]+n_conv_blocks[:-1],n_conv_blocks, self.ks, self.stride)]
16 | if _layers: ni = self.setup_conv_block(_layers=_layers, ni=ni, nc=nc, w=w, h=h)
17 | else:
18 | self.add_module('lin_state_block', StateActionPassThrough(nn.Linear(ni, layers[0])))
19 | ni, layers = layers[0], layers[1:]
20 | self.setup_linear_block(_layers=_layers, ni=ni, nc=nc, w=w, h=h, emb_szs=emb_szs, layers=layers, ao=ao)
21 | self.init_weights(self)
22 |
23 | def setup_conv_block(self, _layers, ni, nc, w, h):
24 | self.add_module('conv_block', StateActionPassThrough(nn.Sequential(*(self.fix_switched_channels(ni, nc, _layers) + [Flatten()]))))
25 | return int(self(torch.zeros((2, 1, w, h, nc) if self.switched else (2, 1, nc, w, h)))[0].view(-1, ).shape[0])
26 |
27 | def setup_linear_block(self, _layers, ni, nc, w, h, emb_szs, layers, ao):
28 | tabular_model = TabularModel(emb_szs=emb_szs, n_cont=ni+ao if not emb_szs else ao, layers=layers, out_sz=1,
29 | use_bn=self.batch_norm)
30 | if not emb_szs: tabular_model.embeds = None
31 | if not self.batch_norm: tabular_model.bn_cont = FakeBatchNorm()
32 | self.add_module('lin_block', CriticTabularEmbedWrapper(tabular_model, exclude_cat=not emb_szs))
33 |
34 | def fix_switched_channels(self, current_channels, expected_channels, layers: list):
35 | if current_channels == expected_channels:
36 | return layers
37 | else:
38 | self.switched = True
39 | return [ChannelTranspose()] + layers
40 |
41 | def init_weights(self, m):
42 | if type(m) == nn.Linear:
43 | torch.nn.init.xavier_uniform_(m.weight)
44 | m.bias.data.fill_(0.01)
45 |
46 |
47 | class ActorModule(nn.Sequential):
48 | def __init__(self, ni: int, ao: int, layers: Collection[int],batch_norm = False,
49 | n_conv_blocks: Collection[int] = 0, nc=3, emb_szs: ListSizes = None,
50 | w=-1, h=-1, ks=None, stride=None, conv_kern_proportion=0.1, stride_proportion=0.1, pad=False):
51 | super().__init__()
52 | self.switched, self.batch_norm = False, batch_norm
53 | # self.ks, self.stride = ([], []) if len(n_conv_blocks) == 0 else ks_stride(ks, stride, w, h, n_conv_blocks, conv_kern_proportion, stride_proportion)
54 | self.ks, self.stride=([], []) if len(n_conv_blocks)==0 else (ifnone(ks,[10,10,10]),ifnone(stride,[5,5,5]))
55 | self.action_model = nn.Sequential()
56 | _layers = [conv_bn_lrelu(ch, self.nf, ks=ks, stride=stride, pad=pad, bn=self.batch_norm) for ch, self.nf, ks, stride in zip([nc]+n_conv_blocks[:-1],n_conv_blocks, self.ks, self.stride)]
57 | if _layers: ni = self.setup_conv_block(_layers=_layers, ni=ni, nc=nc, w=w, h=h)
58 | self.setup_linear_block(_layers=_layers, ni=ni, nc=nc, w=w, h=h, emb_szs=emb_szs, layers=layers, ao=ao)
59 | self.init_weights(self)
60 |
61 | def setup_conv_block(self, _layers, ni, nc, w, h):
62 | self.add_module('conv_block', nn.Sequential(*(self.fix_switched_channels(ni, nc, _layers) + [Flatten()])))
63 | return int(self(torch.zeros((1, w, h, nc) if self.switched else (1, nc, w, h))).view(-1, ).shape[0])
64 |
65 | def setup_linear_block(self, _layers, ni, nc, w, h, emb_szs, layers, ao):
66 | tabular_model = TabularModel(emb_szs=emb_szs, n_cont=ni if not emb_szs else 0, layers=layers, out_sz=ao, use_bn=self.batch_norm)
67 |
68 | if not emb_szs: tabular_model.embeds = None
69 | if not self.batch_norm: tabular_model.bn_cont = FakeBatchNorm()
70 | self.add_module('lin_block', TabularEmbedWrapper(tabular_model))
71 | self.add_module('tanh', nn.Tanh())
72 |
73 | def fix_switched_channels(self, current_channels, expected_channels, layers: list):
74 | if current_channels == expected_channels:
75 | return layers
76 | else:
77 | self.switched = True
78 | return [ChannelTranspose()] + layers
79 |
80 | def init_weights(self, m):
81 | if type(m) == nn.Linear:
82 | torch.nn.init.xavier_uniform_(m.weight)
83 | m.bias.data.fill_(0.01)
84 |
85 | class DDPGModule(Module):
86 | def __init__(self, ni: int, ao: int, layers: Collection[int], discount: float = 0.99,
87 | n_conv_blocks: Collection[int] = 0, nc=3, opt=None, emb_szs: ListSizes = None, loss_func=None,
88 | w=-1, h=-1, ks=None, stride=None, grad_clip=5, tau=1e-3, lr=1e-3, actor_lr=1e-4,
89 | batch_norm=False, **kwargs):
90 | r"""
91 | Implementation of a continuous control algorithm using an actor/critic architecture.
92 |
93 | Notes:
94 | Uses 4 networks, 2 actors, 2 critics.
95 | All models use batch norm for feature invariance.
96 | The critic predicts Q values, while the actor proposes the actions to take given a state s.
97 |
98 | References:
99 | [1] Lillicrap, Timothy P., et al. "Continuous control with deep reinforcement learning."
100 | arXiv preprint arXiv:1509.02971 (2015).
101 |
102 | Args:
103 | ni: Number of flat state inputs.
104 | ao: Number of continuous action outputs.
105 | tau: Defines how "soft/hard" the primary networks are copied over to the target networks.
106 | discount: Discount factor applied to the bootstrapped future Q value.
107 | lr: Learning rate for the critic optimizer (`actor_lr` is used for the actor).
108 | """
109 | super().__init__()
110 | self.name = 'DDPG'
111 | self.lr = lr
112 | self.discount = discount
113 | self.tau = tau
114 | self.loss_func = None
115 | self.loss = None
116 | self.opt = None
117 | self.critic_optimizer = None
118 | self.batch_norm = batch_norm
119 | self.actor_lr = actor_lr
120 |
121 | self.action_model = ActorModule(ni=ni, ao=ao, layers=layers, nc=nc, emb_szs=emb_szs,batch_norm = batch_norm,
122 | w=w, h=h, ks=ks, n_conv_blocks=n_conv_blocks, stride=stride)
123 | self.critic_model = CriticModule(ni=ni, ao=ao, layers=layers, nc=nc, emb_szs=emb_szs, batch_norm = batch_norm,
124 | w=w, h=h, ks=ks, n_conv_blocks=n_conv_blocks, stride=stride)
125 |
126 | self.set_opt(opt)
127 |
128 | self.t_action_model = deepcopy(self.action_model)
129 | self.t_critic_model = deepcopy(self.critic_model)
130 |
131 | self.target_copy_over()
132 | self.tau = tau
133 |
134 | def set_opt(self, opt):
135 | self.opt=OptimWrapper.create(ifnone(opt, optim.Adam), lr=self.actor_lr, layer_groups=[self.action_model])
136 | self.critic_optimizer=OptimWrapper.create(ifnone(opt, optim.Adam), lr=self.lr, layer_groups=[self.critic_model])
137 |
138 | def optimize(self, sampled):
139 | r"""
140 | Performs separate updates to the actor and critic models.
141 |
142 | Get the predicted yi for optimizing the actor:
143 |
144 | .. math::
145 | y_i = r_i + \lambda Q^{'}(s_{i+1}, \; \mu^{'}(s_{i+1} \;|\; \Theta^{\mu'}) \;|\; \Theta^{Q'})
146 |
147 | The actor is then updated by following the sampled policy gradient through the critic.
148 |
149 | Returns:
150 |
151 | """
152 | with torch.no_grad():
153 | r = torch.cat([item.reward.float() for item in sampled])
154 | s_prime = torch.cat([item.s_prime for item in sampled])
155 | s = torch.cat([item.s for item in sampled])
156 | a = torch.cat([item.a.float() for item in sampled])
157 |
158 | y = r + self.discount * self.t_critic_model((s_prime, self.t_action_model(s_prime)))
159 |
160 | y_hat = self.critic_model((s, a))
161 |
162 | critic_loss = self.loss_func(y_hat, y)
163 |
164 | if self.training:
165 | self.critic_optimizer.zero_grad()
166 | critic_loss.backward()
167 | self.critic_optimizer.step()
168 |
169 | actor_loss = -self.critic_model((s, self.action_model(s))).mean()
170 |
171 | self.loss = critic_loss.cpu().detach()
172 |
173 | if self.training:
174 | self.opt.zero_grad()
175 | actor_loss.backward()
176 | self.opt.step()
177 |
178 | with torch.no_grad():
179 | post_info = {'td_error': (y - y_hat).cpu().numpy()}
180 | return post_info
181 |
182 | def forward(self, xi):
183 | training = self.training
184 | if xi.shape[0] == 1: self.eval()
185 | pred = self.action_model(xi)
186 | if training: self.train()
187 | return pred
188 |
189 | def target_copy_over(self):
190 | """ Soft target updates the actor and critic models.."""
191 | self.soft_target_copy_over(self.t_action_model, self.action_model, self.tau)
192 | self.soft_target_copy_over(self.t_critic_model, self.critic_model, self.tau)
193 |
194 | def soft_target_copy_over(self, t_m, f_m, tau):
195 | for target_param, local_param in zip(t_m.parameters(), f_m.parameters()):
196 | target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
197 |
198 | def interpret_q(self, item):
199 | with torch.no_grad():
200 | return self.critic_model(torch.cat((item.s, item.a), 1))
--------------------------------------------------------------------------------
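The `tau` parameter above controls the Polyak averaging in `soft_target_copy_over`: each call moves the target weights a fraction `tau` of the way toward the live weights, and `tau=1` degenerates to a hard copy. A tiny numeric sketch of that update rule (illustrative only, not repository code):

import torch

# theta_target <- tau * theta_local + (1 - tau) * theta_target
target, local = torch.tensor([0.0]), torch.tensor([1.0])

tau = 1e-3                                   # DDPGModule's default soft update
target = tau * local + (1 - tau) * target    # target becomes 0.001
tau = 1.0                                    # tau = 1 is a hard copy
target = tau * local + (1 - tau) * target    # target becomes 1.0, i.e. equal to local
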
/fast_rl/agents/dqn.py:
--------------------------------------------------------------------------------
1 | from fastai.basic_train import LearnerCallback
2 | from fastai.tabular.data import emb_sz_rule
3 |
4 | from fast_rl.agents.dqn_models import *
5 | from fast_rl.core.agent_core import ExperienceReplay, ExplorationStrategy, Experience
6 | from fast_rl.core.basic_train import AgentLearner
7 | from fast_rl.core.data_block import MDPDataBunch, FEED_TYPE_STATE, FEED_TYPE_IMAGE, MDPStep
8 |
9 |
10 | class DQNLearner(AgentLearner):
11 | def __init__(self, data: MDPDataBunch, model, memory, exploration_method, trainers, opt=torch.optim.RMSprop,
12 | **learn_kwargs):
13 | self.memory: Experience = memory
14 | self.exploration_method: ExplorationStrategy = exploration_method
15 | super().__init__(data=data, model=model, opt=opt, **learn_kwargs)
16 | self.trainers = listify(trainers)
17 | for t in self.trainers: self.callbacks.append(t(self))
18 |
19 | def predict(self, element, **kwargs):
20 | training = self.model.training
21 | if element.shape[0] == 1: self.model.eval()
22 | pred = self.model(element)
23 | if training: self.model.train()
24 | return self.exploration_method.perturb(torch.argmax(pred, 1), self.data.action.action_space)
25 |
26 | def interpret_q(self, item):
27 | with torch.no_grad():
28 | return torch.sum(self.model(item.s)).cpu().numpy().item()
29 |
30 |
31 | class FixedTargetDQNTrainer(LearnerCallback):
32 | def __init__(self, learn, copy_over_frequency=3):
33 | r"""Handles updating the target model in a fixed target DQN.
34 |
35 | Args:
36 | learn: Basic Learner.
37 | copy_over_frequency: For every N iterations we want to update the target model.
38 | """
39 | super().__init__(learn)
40 | self._order = 1
41 | self.iteration = 0
42 | self.copy_over_frequency = copy_over_frequency
43 |
44 | def on_step_end(self, **kwargs: Any):
45 | self.iteration += 1
46 | if self.iteration % self.copy_over_frequency == 0 and self.learn.model.training:
47 | self.learn.model.target_copy_over()
48 |
49 |
50 | class BaseDQNTrainer(LearnerCallback):
51 | def __init__(self, learn: DQNLearner, max_episodes=None):
52 | r"""Handles basic DQN end of step model optimization."""
53 | super().__init__(learn)
54 | self.n_skipped = 0
55 | self._persist = max_episodes is not None
56 | self.max_episodes = max_episodes
57 | self.episode = -1
58 | self.iteration = 0
59 | # For the callback handler
60 | self._order = 0
61 | self.previous_item = None
62 |
63 | @property
64 | def learn(self) -> DQNLearner:
65 | return self._learn()
66 |
67 | def on_train_begin(self, n_epochs, **kwargs: Any):
68 | self.max_episodes = n_epochs if not self._persist else self.max_episodes
69 |
70 | def on_epoch_begin(self, epoch, **kwargs: Any):
71 | self.episode = epoch if not self._persist else self.episode + 1
72 | self.iteration = 0
73 |
74 | def on_loss_begin(self, **kwargs: Any):
75 | r"""Performs tree updates, exploration updates, and model optimization."""
76 | if self.learn.model.training: self.learn.memory.update(item=self.learn.data.x.items[-1])
77 | self.learn.exploration_method.update(self.episode, max_episodes=self.max_episodes, explore=self.learn.model.training)
78 | if not self.learn.warming_up:
79 | samples: List[MDPStep] = self.memory.sample(self.learn.data.bs)
80 | post_optimize = self.learn.model.optimize(samples)
81 | if self.learn.model.training: self.learn.memory.refresh(post_optimize=post_optimize)
82 | self.iteration += 1
83 |
84 |
85 | def create_dqn_model(data: MDPDataBunch, base_arch: DQNModule, layers=None, ignore_embed=False, channels=None,
86 | opt=torch.optim.RMSprop, loss_func=None, lr=0.001, **kwargs):
87 | bs,state,action=data.bs,data.state,data.action
88 | nc, w, h, n_conv_blocks = -1, -1, -1, [] if state.mode == FEED_TYPE_STATE else ifnone(channels, [16, 16, 16])
89 | if state.mode == FEED_TYPE_IMAGE: nc, w, h = state.s.shape[3], state.s.shape[2], state.s.shape[1]
90 | _layers = ifnone(layers, [64, 64])
91 | if ignore_embed or np.any(state.n_possible_values == np.inf) or state.mode == FEED_TYPE_IMAGE: emb_szs = []
92 | else: emb_szs = [(d+1, int(emb_sz_rule(d))) for d in state.n_possible_values.reshape(-1, )]
93 | ao = int(action.n_possible_values[0])
94 | model = base_arch(ni=state.s.shape[1], ao=ao, layers=_layers, emb_szs=emb_szs, n_conv_blocks=n_conv_blocks,
95 | nc=nc, w=w, h=h, opt=opt, loss_func=loss_func, lr=lr, **kwargs)
96 | return model
97 |
98 |
99 | dqn_config = {
100 | DQNModule: [BaseDQNTrainer],
101 | DoubleDQNModule: [BaseDQNTrainer, FixedTargetDQNTrainer],
102 | DuelingDQNModule: [BaseDQNTrainer, FixedTargetDQNTrainer],
103 | DoubleDuelingModule: [BaseDQNTrainer, FixedTargetDQNTrainer],
104 | FixedTargetDQNModule: [BaseDQNTrainer, FixedTargetDQNTrainer]
105 | }
106 |
107 |
108 | def dqn_learner(data: MDPDataBunch, model: DQNModule, memory: ExperienceReplay, exploration_method: ExplorationStrategy,
109 | trainers=None, copy_over_frequency=300, **kwargs):
110 | trainers = ifnone(trainers, [c if c != FixedTargetDQNTrainer else partial(c, copy_over_frequency=copy_over_frequency)
111 | for c in dqn_config[model.__class__]])
112 | return DQNLearner(data, model, memory, exploration_method, trainers, **kwargs)
113 |
--------------------------------------------------------------------------------
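`dqn_config` above is the only piece that changes between DQN variants: it maps a module class to the trainer callbacks it needs, and `dqn_learner` forwards `copy_over_frequency` to `FixedTargetDQNTrainer` whenever one is required. A hedged sketch of a Dueling DQN learner built the same way as the fixed-target example in the notebook; the hyperparameters are illustrative.

from fast_rl.agents.dqn import create_dqn_model, dqn_learner
from fast_rl.agents.dqn_models import DuelingDQNModule
from fast_rl.core.agent_core import ExperienceReplay, GreedyEpsilon
from fast_rl.core.data_block import MDPDataBunch

data = MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=32, add_valid=False)
model = create_dqn_model(data=data, base_arch=DuelingDQNModule)
memory = ExperienceReplay(10000)
exploration = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)

# DuelingDQNModule maps to [BaseDQNTrainer, FixedTargetDQNTrainer], so the target
# network is refreshed every `copy_over_frequency` training steps.
learner = dqn_learner(data=data, model=model, memory=memory,
                      exploration_method=exploration, copy_over_frequency=300)
learner.fit(10)
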
/fast_rl/agents/dqn_models.py:
--------------------------------------------------------------------------------
1 | from fastai.callback import OptimWrapper
2 |
3 | from fast_rl.core.layers import *
4 |
5 |
6 | class DQNModule(Module):
7 |
8 | def __init__(self, ni: int, ao: int, layers: Collection[int], discount: float = 0.99, lr=0.001,
9 | n_conv_blocks: Collection[int] = 0, nc=3, opt=None, emb_szs: ListSizes = None, loss_func=None,
10 | w=-1, h=-1, ks: Union[None, list]=None, stride: Union[None, list]=None, grad_clip=5,
11 | conv_kern_proportion=0.1, stride_proportion=0.1, pad=False, batch_norm=False):
12 | r"""
13 | Basic DQN Module.
14 |
15 | Args:
16 | ni: Number of inputs. Expecting a flat state `[1 x ni]`
17 | ao: Number of actions to output.
18 | layers: Sizes of the hidden linear layers; one layer is created per element.
19 | n_conv_blocks: If `n_conv_blocks` is not 0, then convolutional blocks will be added
20 | to the head on top of existing linear layers.
21 | nc: Number of channels that will be expected by the convolutional blocks.
22 | """
23 | super().__init__()
24 | self.name = 'DQN'
25 | self.loss = None
26 | self.loss_func = loss_func
27 | self.discount = discount
28 | self.gradient_clipping_norm = grad_clip
29 | self.lr = lr
30 | self.batch_norm = batch_norm
31 | self.switched = False
32 | # self.ks, self.stride = ([], []) if len(n_conv_blocks) == 0 else ks_stride(ks, stride, w, h, n_conv_blocks, conv_kern_proportion, stride_proportion)
33 | self.ks, self.stride=([], []) if len(n_conv_blocks)==0 else (ifnone(ks, [10, 10, 10]), ifnone(stride, [5, 5, 5]))
34 | self.action_model = nn.Sequential()
35 | _layers = [conv_bn_lrelu(ch, self.nf, ks=ks, stride=stride, pad=pad, bn=self.batch_norm) for ch, self.nf, ks, stride in zip([nc]+n_conv_blocks[:-1],n_conv_blocks, self.ks, self.stride)]
36 |
37 | if _layers: ni = self.setup_conv_block(_layers=_layers, ni=ni, nc=nc, w=w, h=h)
38 | self.setup_linear_block(_layers=_layers, ni=ni, nc=nc, w=w, h=h, emb_szs=emb_szs, layers=layers, ao=ao)
39 | self.init_weights(self.action_model)
40 | self.opt = None
41 | self.set_opt(opt)
42 |
43 | def set_opt(self, opt):
44 | self.opt=OptimWrapper.create(ifnone(opt, optim.Adam), lr=self.lr, layer_groups=[self.action_model])
45 |
46 | def setup_conv_block(self, _layers, ni, nc, w, h):
47 | self.action_model.add_module('conv_block', nn.Sequential(*(self.fix_switched_channels(ni, nc, _layers) + [Flatten()])))
48 | training = self.action_model.training
49 | self.action_model.eval()
50 | ni = int(self.action_model(torch.zeros((1, w, h, nc) if self.switched else (1, nc, w, h))).view(-1, ).shape[0])
51 | self.action_model.train(training)
52 | return ni
53 |
54 | def setup_linear_block(self, _layers, ni, nc, w, h, emb_szs, layers, ao):
55 | tabular_model = TabularModel(emb_szs=emb_szs, n_cont=ni if not emb_szs else 0, layers=layers, out_sz=ao, use_bn=self.batch_norm)
56 | if not emb_szs: tabular_model.embeds = None
57 | if not self.batch_norm: tabular_model.bn_cont = FakeBatchNorm()
58 | self.action_model.add_module('lin_block', TabularEmbedWrapper(tabular_model))
59 |
60 | def fix_switched_channels(self, current_channels, expected_channels, layers: list):
61 | if current_channels == expected_channels:
62 | return layers
63 | else:
64 | self.switched = True
65 | return [ChannelTranspose()] + layers
66 |
67 | def forward(self, xi: Tensor):
68 | training = self.training
69 | if xi.shape[0] == 1: self.eval()
70 | pred = self.action_model(xi)
71 | if training: self.train()
72 | return pred
73 |
74 | def init_weights(self, m):
75 | if type(m) == nn.Linear:
76 | torch.nn.init.xavier_uniform_(m.weight)
77 | m.bias.data.fill_(0.01)
78 |
79 | def sample_mask(self, d):
80 | return torch.sub(1.0, d)
81 |
82 | def optimize(self, sampled):
83 | r"""Uses ER to optimize the Q-net (without fixed targets).
84 |
85 | Uses the equation:
86 |
87 | .. math::
88 | Q^{*}(s, a) = \mathbb{E}_{s' \sim \epsilon} \Big[r + \lambda \displaystyle\max_{a'}(Q^{*}(s' , a'))
89 | \;|\; s, a \Big]
90 |
91 |
92 | Returns (dict): Optimization information
93 |
94 | """
95 | with torch.no_grad():
96 | r = torch.cat([item.reward.float() for item in sampled])
97 | s_prime = torch.cat([item.s_prime for item in sampled])
98 | s = torch.cat([item.s for item in sampled])
99 | a = torch.cat([item.a.long() for item in sampled])
100 | d = torch.cat([item.done.float() for item in sampled])
101 | masking = self.sample_mask(d)
102 |
103 | y_hat = self.y_hat(s, a)
104 | y = self.y(s_prime, masking, r, y_hat)
105 |
106 | loss = self.loss_func(y, y_hat)
107 |
108 | if self.training:
109 | self.opt.zero_grad()
110 | loss.backward()
111 | torch.nn.utils.clip_grad_norm_(self.action_model.parameters(), self.gradient_clipping_norm)
112 | for param in self.action_model.parameters():
113 | if param.grad is not None: param.grad.data.clamp_(-1, 1)
114 | self.opt.step()
115 |
116 | with torch.no_grad():
117 | self.loss = loss
118 | post_info = {'td_error': to_detach(y - y_hat).cpu().numpy()}
119 | return post_info
120 |
121 | def y_hat(self, s, a):
122 | return self.action_model(s).gather(1, a)
123 |
124 | def y(self, s_prime, masking, r, y_hat):
125 | return self.discount * self.action_model(s_prime).max(1)[0].unsqueeze(1) * masking + r.expand_as(y_hat)
126 |
127 |
128 | class FixedTargetDQNModule(DQNModule):
129 | def __init__(self, ni: int, ao: int, layers: Collection[int], tau=1, **kwargs):
130 | super().__init__(ni, ao, layers, **kwargs)
131 | self.name = 'Fixed Target DQN'
132 | self.tau = tau
133 | self.target_model = copy(self.action_model)
134 |
135 | def target_copy_over(self):
136 | r""" Updates the target network from calls in the FixedTargetDQNTrainer callback."""
137 | # self.target_net.load_state_dict(self.action_model.state_dict())
138 | for target_param, local_param in zip(self.target_model.parameters(), self.action_model.parameters()):
139 | target_param.data.copy_(self.tau * local_param.data + (1.0 - self.tau) * target_param.data)
140 |
141 | def y(self, s_prime, masking, r, y_hat):
142 | r"""
143 | Uses the equation:
144 |
145 | .. math::
146 |
147 | Q^{*}(s, a) = \mathbb{E}_{s' \sim \epsilon} \Big[r + \lambda \displaystyle\max_{a'}(Q^{*}(s' , a'))
148 | \;|\; s, a \Big]
149 |
150 | """
151 | return self.discount * self.target_model(s_prime).max(1)[0].unsqueeze(1) * masking + r.expand_as(y_hat)
152 |
153 |
154 | class DoubleDQNModule(FixedTargetDQNModule):
155 | def __init__(self, ni: int, ao: int, layers: Collection[int], **kwargs):
156 | super().__init__(ni, ao, layers, **kwargs)
157 | self.name = 'DDQN'
158 |
159 | def y(self, s_prime, masking, r, y_hat):
160 | return self.discount * self.target_model(s_prime).gather(1, self.action_model(s_prime).argmax(1).unsqueeze(
161 | 1)) * masking + r.expand_as(y_hat)
162 |
163 |
164 | class DuelingBlock(nn.Module):
165 | def __init__(self, ao, stream_input_size):
166 | super().__init__()
167 |
168 | self.val = nn.Linear(stream_input_size, 1)
169 | self.adv = nn.Linear(stream_input_size, ao)
170 |
171 | def forward(self, xi):
172 | r"""Splits the base neural net output into 2 streams to evaluate the advantage and v of the s space and
173 | corresponding actions.
174 |
175 | .. math::
176 | Q(s,a;\; \Theta, \\alpha, \\beta) = V(s;\; \Theta, \\beta) + A(s, a;\; \Theta, \\alpha) - \\frac{1}{|A|}
177 | \\sum_{a'} A(s, a';\; \Theta, \\alpha)
178 |
179 | """
180 | val, adv = self.val(xi), self.adv(xi)
181 | xi = val.expand_as(adv) + (adv - adv.mean()).squeeze(0)
182 | return xi
183 |
184 |
185 | class DuelingDQNModule(FixedTargetDQNModule):
186 | def __init__(self, **kwargs):
187 | super().__init__(**kwargs)
188 | self.name = 'Dueling DQN'
189 |
190 | def setup_linear_block(self, _layers, ni, nc, w, h, emb_szs, layers, ao):
191 | tabular_model = TabularModel(emb_szs=emb_szs, n_cont=ni if not emb_szs else 0, layers=layers, out_sz=ao,
192 | use_bn=self.batch_norm)
193 | if not emb_szs: tabular_model.embeds = None
194 | if not self.batch_norm: tabular_model.bn_cont = FakeBatchNorm()
195 | tabular_model.layers, removed_layer = split_model(tabular_model.layers, [last_layer(tabular_model)])
196 | ni = removed_layer[0].in_features
197 | self.action_model.add_module('lin_block', TabularEmbedWrapper(tabular_model))
198 | self.action_model.add_module('dueling_block', DuelingBlock(ao, ni))
199 |
200 |
201 | class DoubleDuelingModule(DuelingDQNModule, DoubleDQNModule):
202 | def __init__(self, **kwargs):
203 | super().__init__(**kwargs)
204 | self.name = 'DDDQN'
205 |
--------------------------------------------------------------------------------
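`DuelingBlock` implements the aggregation from the dueling-DQN docstring: the shared features are split into a scalar state value and per-action advantages, and the advantages are re-centered before being added back. A small sketch of calling it directly; the shapes are illustrative.

import torch
from fast_rl.agents.dqn_models import DuelingBlock

block = DuelingBlock(ao=2, stream_input_size=8)   # 2 actions, 8 shared features
features = torch.randn(4, 8)                      # a batch of 4 feature vectors from the shared trunk
q_values = block(features)                        # V(s) broadcast over actions + centered advantages
print(q_values.shape)                             # torch.Size([4, 2])
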
/fast_rl/core/Interpreter.py:
--------------------------------------------------------------------------------
1 | import io
2 | from functools import partial
3 | from typing import List, Tuple, Dict
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import scipy.stats as st
8 | import torch
9 | from PIL import Image
10 | from fastai.train import Interpretation, DatasetType, copy
11 | from gym.spaces import Box
12 | from itertools import product
13 | from matplotlib.axes import Axes
14 | from matplotlib.figure import Figure
15 | from moviepy.video.VideoClip import VideoClip
16 | from moviepy.video.io.bindings import mplfig_to_npimage
17 | from torch import nn
18 |
19 | from fast_rl.core import Learner
20 | from fast_rl.core.data_block import MarkovDecisionProcessSliceAlpha, FEED_TYPE_IMAGE
21 |
22 |
23 | class AgentInterpretationAlpha(Interpretation):
24 | def __init__(self, learn: Learner, ds_type: DatasetType = DatasetType.Valid, base_chart_size=(20, 10)):
25 | """
26 | Handles converting a learner and its runs into useful, human-interpretable information.
27 |
28 | Notes:
29 | This class is called AgentInterpretationAlpha because it will overall get deprecated.
30 | The final working version will be called AgentInterpretation.
31 |
32 | Args:
33 | learn:
34 | """
35 | super().__init__(learn, None, None, None, ds_type=ds_type)
36 | self.current_animation = None
37 | plt.rcParams["figure.figsize"] = base_chart_size
38 |
39 | def _get_items(self, ignore=True):
40 | episodes = list(self.ds.x.info.keys())
41 | if ignore or len(episodes) == 0: return self.ds.x.items
42 | return [item for item in self.ds.x.items if item.episode in episodes]
43 |
44 | @classmethod
45 | def from_learner(cls, learn: Learner, ds_type: DatasetType = DatasetType.Valid, activ: nn.Module = None):
46 | raise NotImplementedError
47 |
48 | def normalize(self, item: np.array):
49 | if np.max(item) - np.min(item) != 0:
50 | return np.divide(item - np.min(item), np.max(item) - np.min(item))
51 | else:
52 | item.fill(1)
53 | return item
54 |
55 | def top_losses(self, k: int = None, largest=True):
56 | raise NotImplementedError
57 |
58 | def reward_heatmap(self, episode_slices: List[MarkovDecisionProcessSliceAlpha], action=None):
59 | """
60 | Takes a state_space and uses the agent to heat map rewards over the space.
61 |
62 | We first need to determine whether the state space is discrete or continuous.
63 |
64 | Args:
65 | state_space:
66 |
67 | Returns:
68 |
69 | """
70 | if action is not None: action = torch.tensor(action).long()
71 | current_state_slice = [p for p in product(
72 | np.arange(min(self.ds.env.observation_space.low), max(self.ds.env.observation_space.high) + 1),
73 | repeat=len(self.ds.env.observation_space.high))]
74 | heat_map = np.zeros(np.add(self.ds.env.observation_space.high, 1))
75 | with torch.no_grad():
76 | for state in current_state_slice:
77 | if action is not None:
78 | heat_map[state] = self.learn.model(torch.from_numpy(np.array(state)).unsqueeze(0))[0].gather(0, action)
79 | else:
80 | self.learn.model.eval()
81 | if self.learn.model.name == 'DDPG':
82 | heat_map[state] = self.learn.model.critic_model(torch.cat((torch.from_numpy(np.array(state)).unsqueeze(0).float(), self.learn.model.action_model(torch.from_numpy(np.array(state)).unsqueeze(0).float())), 1))
83 | else:
84 | heat_map[state] = self.learn.model(torch.from_numpy(np.array(state)).unsqueeze(0))[0].max().numpy()
85 | return heat_map
86 |
87 | def plot_heatmapped_episode(self, episode, fig_size=(13, 5), action_index=None, return_heat_maps=False):
88 | """
89 | Generates plots of heatmapped s spaces for analyzing reward distribution.
90 |
91 | Currently only makes sense for grid based envs. Will be expecting gym_maze environments that are discrete.
92 |
93 | Returns:
94 |
95 | """
96 | if not str(self.ds.env.spec).__contains__('maze'):
97 | raise NotImplementedError('Currently only supports gym_maze envs that have discrete s spaces')
98 | if not isinstance(self.ds.state_size, Box):
99 | raise NotImplementedError('Currently only supports Box based s spaces with 2 dimensions')
100 |
101 | items = self._get_items()
102 | heat_maps = []
103 |
104 | # For each episode
105 | buffer = []
106 | episode = episode if episode != -1 else list(set([i.episode for i in items]))[-1]
107 | for item in [i for i in items if i.episode == episode]:
108 | buffer.append(item)
109 | heat_map = self.reward_heatmap(buffer, action=action_index)
110 | heat_maps.append((copy(heat_map), copy(buffer[-1]), copy(episode)))
111 |
112 | plots = []
113 | for single_heatmap in [heat_maps[-1]]:
114 | fig, ax = plt.subplots(1, 2, figsize=fig_size)
115 | fig.suptitle(f'Episode {episode}')
116 | ax[0].imshow(single_heatmap[1].to_one().data)
117 | im = ax[1].imshow(single_heatmap[0])
118 | ax[0].grid(False)
119 | ax[1].grid(False)
120 | ax[0].set_title('Final State Snapshot')
121 | ax[1].set_title('State Space Heatmap')
122 | fig.colorbar(im, ax=ax[1])
123 |
124 | buf = io.BytesIO()
125 | fig.savefig(buf, format='png')
126 | # Closing the figure prevents it from being displayed directly inside
127 | # the notebook.
128 | plt.close(fig)
129 | buf.seek(0)
130 | # Create Image object
131 | plots.append(np.array(Image.open(buf))[:, :, :3])
132 |
133 | for plot in plots:
134 | plt.grid(False)
135 | plt.xticks([])
136 | plt.yticks([])
137 | plt.tight_layout()
138 | plt.imshow(plot)
139 | plt.show()
140 |
141 | if return_heat_maps: return heat_maps
142 |
143 | def plot_episode(self, episode):
144 | items = self._get_items(False) # type: List[MarkovDecisionProcessSliceAlpha]
145 |
146 | episode_counter = 0
147 | # For each episode
148 | buffer = []
149 | for item in items:
150 | buffer.append(item)
151 | if item.done:
152 | if episode_counter == episode:
153 | break
154 | episode_counter += 1
155 | buffer = []
156 |
157 | plots = []
158 | with torch.no_grad():
159 | agent_reward_plots = [self.learn.model(torch.from_numpy(np.array(i.current_state))).max().numpy() for i in
160 | buffer]
161 | fig, ax = plt.subplots(1, 1, figsize=(5, 5))
162 | fig.suptitle(f'Episode {episode}')
163 | ax.plot(agent_reward_plots)
164 | ax.set_xlabel('Time Steps')
165 | ax.set_ylabel('Max Expected Reward from Agent')
166 |
167 | buf = io.BytesIO()
168 | fig.savefig(buf, format='png')
169 | # Closing the figure prevents it from being displayed directly inside
170 | # the notebook.
171 | plt.close(fig)
172 | buf.seek(0)
173 | # Create Image object
174 | plots.append(np.array(Image.open(buf))[:, :, :3])
175 |
176 | for plot in plots:
177 | plt.grid(False)
178 | plt.xticks([])
179 | plt.yticks([])
180 | plt.tight_layout()
181 | plt.imshow(plot)
182 | plt.show()
183 |
184 | def get_agent_accuracy_density(self, items, episode_num=None):
185 | x = None
186 | y = None
187 |
188 | for episode in [_ for _ in list(set(mdp.episode for mdp in items)) if episode_num is None or episode_num == _]:
189 | subset = [item for item in items if item.episode == episode]
190 | state = np.array([_.current_state for _ in subset])
191 | result_state = np.array([_.result_state for _ in subset])
192 |
193 | prim_q_pred = self.learn.model(torch.from_numpy(state))
194 | target_q_pred = self.learn.model.target_net(torch.from_numpy(state).float())
195 | state_difference = (prim_q_pred - target_q_pred).sum(1)
196 | prim_q_pred = self.learn.model(torch.from_numpy(result_state))
197 | target_q_pred = self.learn.model.target_net(torch.from_numpy(result_state).float())
198 | result_state_difference = (prim_q_pred - target_q_pred).sum(1)
199 |
200 | x = state_difference if x is None else np.hstack((x, state_difference))
201 | y = result_state_difference if y is None else np.hstack((y, result_state_difference))
202 |
203 | return x, y
204 |
205 | def plot_agent_accuracy_density(self, episode_num=None):
206 | """
207 | Heat maps the density of actual vs estimated Q values. A good reference for this is [1].
208 |
209 | References:
210 | [1] "Simple Example Of 2D Density Plots In Python." Medium. N. p., 2019. Web. 31 Aug. 2019.
211 | https://towardsdatascience.com/simple-example-of-2d-density-plots-in-python-83b83b934f67
212 |
213 | Returns:
214 |
215 | """
216 | items = self._get_items(False) # type: List[MarkovDecisionProcessSliceAlpha]
217 | x, y = self.get_agent_accuracy_density(items, episode_num)
218 |
219 | fig = plt.figure(figsize=(8, 8))
220 | ax = fig.gca()
221 | fig.suptitle(f'{self.learn.model.name} for {self.ds.env.spec._env_name}')
222 | ax.set_ylabel('State / State Prime Q Value Deviation')
223 | ax.set_xlabel('Iterations')
224 | ax.plot(np.hstack([x, y]))
225 | plt.show()
226 |
227 | def get_q_density(self, items, episode_num=None):
228 | x = None
229 | y = None
230 |
231 | for episode in [_ for _ in list(set(mdp.episode for mdp in items)) if episode_num is None or episode_num == _]:
232 | subset = [item for item in items if item.episode == episode]
233 | r = np.array([_.reward for _ in subset])
234 | # Gets the total accumulated r over a single markov chain
235 | actual_returns = np.flip([np.cumsum(r)[i:][0] for i in np.flip(np.arange(len(r)))]).reshape(1, -1)
236 | estimated_returns = self.learn.model.interpret_q(subset).view(1, -1).numpy()
237 | x = actual_returns if x is None else np.hstack((x, actual_returns))
238 | y = estimated_returns if y is None else np.hstack((y, estimated_returns))
239 |
240 | return self.normalize(x), self.normalize(y)
241 |
242 | def plot_q_density(self, episode_num=None):
243 | """
244 | Heat maps the density of actual vs estimated Q values. A good reference for this is [1].
245 |
246 | References:
247 | [1] "Simple Example Of 2D Density Plots In Python." Medium. N. p., 2019. Web. 31 Aug. 2019.
248 | https://towardsdatascience.com/simple-example-of-2d-density-plots-in-python-83b83b934f67
249 |
250 | Returns:
251 |
252 | """
253 | items = self._get_items(False) # type: List[MarkovDecisionProcessSliceAlpha]
254 | x, y = self.get_q_density(items, episode_num)
255 |
256 | # Define the borders
257 | deltaX = (np.max(x) - np.min(x)) / 10
258 | deltaY = (np.max(y) - np.min(y)) / 10
259 | xmin = np.min(x) - deltaX
260 | xmax = np.max(x) + deltaX
261 | ymin = np.min(y) - deltaY
262 | ymax = np.max(y) + deltaY
263 | print(xmin, xmax, ymin, ymax)
264 | # Create meshgrid
265 | xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
266 |
267 | positions = np.vstack([xx.ravel(), yy.ravel()])
268 | values = np.vstack([x, y])
269 |
270 | kernel = st.gaussian_kde(values)
271 |
272 | f = np.reshape(kernel(positions).T, xx.shape)
273 |
274 | fig = plt.figure(figsize=(8, 8))
275 | ax = fig.gca()
276 | ax.set_xlim(xmin, xmax)
277 | ax.set_ylim(ymin, ymax)
278 | cfset = ax.contourf(xx, yy, f, cmap='coolwarm')
279 | ax.imshow(np.rot90(f), cmap='coolwarm', extent=[xmin, xmax, ymin, ymax])
280 | cset = ax.contour(xx, yy, f, colors='k')
281 | ax.clabel(cset, inline=1, fontsize=10)
282 | ax.set_xlabel('Actual Returns')
283 | ax.set_ylabel('Estimated Q')
284 | if episode_num is None:
285 | plt.title('2D Gaussian Kernel Q Density Estimation')
286 | else:
287 | plt.title(f'2D Gaussian Kernel Q Density Estimation for episode {episode_num}')
288 | plt.show()
289 |
290 | def plot_rewards_over_iterations(self, cumulative=False, return_rewards=False):
291 | items = self._get_items()
292 | r_iter = [el.reward[0] if np.ndim(el.reward) == 0 else np.average(el.reward) for el in items]
293 | if cumulative: r_iter = np.cumsum(r_iter)
294 | fig = plt.figure(figsize=(8, 8))
295 | ax = fig.gca()
296 | fig.suptitle(f'{self.learn.model.name} for {self.ds.env.spec._env_name}')
297 | ax.set_ylabel('Rewards' if not cumulative else 'Cumulative Rewards')
298 | ax.set_xlabel('Iterations')
299 | ax.plot(r_iter)
300 | plt.show()
301 | if return_rewards: return r_iter
302 |
303 | def plot_rewards_over_episodes(self, cumulative=False, fig_size=(8, 8)):
304 | items = self._get_items()
305 | r_iter = [(el.reward[0] if np.ndim(el.reward) == 0 else np.average(el.reward), el.episode) for el in items]
306 | rewards, episodes = zip(*r_iter)
307 | if cumulative: rewards = np.cumsum(rewards)
308 | fig = plt.figure(figsize=(8, 8))
309 | ax = fig.gca()
310 | fig.suptitle(f'{self.learn.model.name} for {self.ds.env.spec._env_name}')
311 | ax.set_ylabel('Rewards' if not cumulative else 'Cumulative Rewards')
312 | ax.set_xlabel('Episodes')
313 | ax.xaxis.set_ticks([i for i, el in enumerate(episodes) if episodes[i - 1] != el or i == 0])
314 | ax.xaxis.set_ticklabels([el for i, el in enumerate(episodes) if episodes[i - 1] != el or i == 0])
315 | ax.plot(rewards)
316 | plt.show()
317 |
318 | def episode_video_frames(self, episode=None) -> Dict[str, np.array]:
319 | """ Returns numpy arrays representing purely episode frames. """
320 | items = self._get_items(False)
321 | if episode is None: episode_frames = {key: None for key in list(set([_.episode for _ in items]))}
322 | else: episode_frames = {episode: None}
323 |
324 | for key in episode_frames:
325 | if self.ds.feed_type == FEED_TYPE_IMAGE:
326 | episode_frames[key] = np.array([_.current_state for _ in items if key == _.episode])
327 | else:
328 | episode_frames[key] = np.array([_.alternate_state for _ in items if key == _.episode])
329 |
330 | return episode_frames
331 |
332 | def episode_to_gif(self, episode=None, path='', fps=30):
333 | frames = self.episode_video_frames(episode)
334 |
335 | for ep in frames:
336 | fig, ax = plt.subplots()
337 | animation = VideoClip(partial(self._make_frame, frames=frames[ep], axes=ax, fig=fig, title=f'Episode {ep}'),
338 | duration=frames[ep].shape[0])
339 | animation.write_gif(path + f'episode_{ep}.gif', fps=fps)
340 |
341 | def _make_frame(self, t, frames, axes, fig, title):
342 | axes.clear()
343 | fig.suptitle(title)
344 | axes.imshow(frames[int(t)])
345 | return mplfig_to_npimage(fig)
346 |
347 | def iplot_episode(self, episode, fps=30):
348 | if episode is None: raise ValueError('The episode cannot be None for jupyter display')
349 | x = self.episode_video_frames(episode)[episode]
350 | fig, ax = plt.subplots()
351 |
352 | self.current_animation = VideoClip(partial(self._make_frame, frames=x, axes=ax, fig=fig,
353 | title=f'Episode {episode}'), duration=x.shape[0])
354 | self.current_animation.ipython_display(fps=fps, loop=True, autoplay=True)
355 |
356 | def get_memory_samples(self, batch_size=None, key='reward'):
357 | samples = self.learn.model.memory.sample(self.learn.model.batch_size if batch_size is None else batch_size)
358 | if not samples: raise IndexError('Your tree seems empty.')
359 | if batch_size is not None and batch_size > len(self.learn.model.memory):
360 | raise IndexError(f'Your batch size {batch_size} > the tree\'s batch size {len(self.learn.model.memory)}')
361 | if key not in samples[0].obj.keys(): raise ValueError(f'Key {key} not in {samples[0].obj.keys()}')
362 | return [s.obj[key] for s in samples]
363 |
364 | def plot_memory_samples(self, batch_size=None, key='reward', fig_size=(8, 8)):
365 | values_of_interest = self.get_memory_samples(batch_size, key)
366 | fig = plt.figure(figsize=fig_size)
367 | ax = fig.gca()
368 | fig.suptitle(f'{self.learn.model.name} for {self.ds.env.spec._env_name}')
369 | ax.set_ylabel(key)
370 | ax.set_xlabel('Values')
371 | ax.plot(values_of_interest)
372 | plt.show()
--------------------------------------------------------------------------------
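The notebook earlier uses the newer `AgentInterpretation` from `fast_rl.core.train`; the `AgentInterpretationAlpha` class above exposes a similar plotting API and, per its own docstring, is slated for deprecation. A hedged sketch of driving the reward plots it defines, assuming `learner` is a trained `AgentLearner` such as the DQN example in the notebook:

from fastai.train import DatasetType
from fast_rl.core.Interpreter import AgentInterpretationAlpha

interp = AgentInterpretationAlpha(learner, ds_type=DatasetType.Train)
interp.plot_rewards_over_iterations(cumulative=True)
interp.plot_rewards_over_episodes(cumulative=True)
# episode_to_gif writes episode_<n>.gif files to the given path.
interp.episode_to_gif(episode=0, fps=30)
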
/fast_rl/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/fast_rl/core/__init__.py
--------------------------------------------------------------------------------
/fast_rl/core/agent_core.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | from math import ceil
3 |
4 | import gym
5 | from fastai.basic_train import *
6 | from fastai.torch_core import *
7 |
8 | from fast_rl.core.data_structures import SumTree
9 |
10 |
11 | class ExplorationStrategy:
12 | def __init__(self, explore: bool = True): self.explore=explore
13 | def update(self, max_episodes, explore, **kwargs): self.explore=explore
14 | def perturb(self, action, action_space) -> np.ndarray:
15 | """
16 | Base method just returns the action. Subclass, and change to return randomly / augmented actions.
17 |
18 | Should use the `explore` field. It is recommended that when you subclass / overload, you allow this field
19 | to completely bypass these actions.
20 |
21 | Args:
22 | action (np.array): Action input as a regular numpy array.
23 | action_space (gym.Space): The original gym space. Should contain information on the action type, and
24 | possible convenience methods for random action selection.
25 | """
26 | _=action_space
27 | return action
28 |
29 |
30 | class GreedyEpsilon(ExplorationStrategy):
31 | def __init__(self, epsilon_start, epsilon_end, decay, start_episode=0, end_episode=0, **kwargs):
32 | super().__init__(**kwargs)
33 | self.end_episode=end_episode
34 | self.start_episode=start_episode
35 | self.decay=decay
36 | self.e_end=epsilon_end
37 | self.e_start=epsilon_start
38 | self.epsilon=self.e_start
39 | self.steps=0
40 |
41 | def perturb(self, action, action_space: gym.Space):
42 | return action_space.sample() if np.random.random() < self.epsilon else action
148 | epsilon (float): Keeps the probabilities of items from being 0
149 | memory_size (int): Max N samples to store
150 | """
151 | super().__init__(memory_size, **kwargs)
152 | self.batch_size=batch_size
153 | self.alpha=alpha
154 | self.beta=beta
155 | self.b_inc=b_inc
156 | self.p_weights=None # np.zeros(self.batch_size, dtype=float)
157 | self.epsilon=epsilon
158 | self.tree=SumTree(self.max_size)
159 | self.callbacks=[PriorityExperienceReplayCallback]
160 | # When sampled, store the sample indices for refresh.
161 | self._indices=None # np.zeros(self.batch_size, dtype=int)
162 |
163 | @property
164 | def memory(self):
165 | return self.tree.data
166 |
167 | def __len__(self):
168 | return self.tree.n_entries
169 |
170 | def refresh(self, post_optimize, **kwargs):
171 | if post_optimize is not None:
172 | self.tree.update(self._indices.astype(int), np.abs(post_optimize['td_error'])+self.epsilon)
173 |
174 | def sample(self, batch, **kwargs):
175 | self.beta=np.min([1., self.beta+self.b_inc])
176 | ranges=np.linspace(0, ceil(self.tree.total()/batch), num=batch+1)
177 | uniform_ranges=[np.random.uniform(ranges[i], ranges[i+1]) for i in range(len(ranges)-1)]
178 | try:
179 | self._indices, weights, samples=self.tree.batch_get(uniform_ranges)
180 | except ValueError:
181 | warn('Too few values to unpack. Your batch size is too small: when PER queries the tree, all 0 values get'
182 | ' ignored. We will retry until we can return at least one sample.')
183 | samples=self.sample(batch)
184 | return samples
185 |
186 | self.p_weights=self.tree.anneal_weights(weights, self.beta)
187 | return samples
188 |
189 | def update(self, item, **kwargs):
190 | """
191 | Updates the tree of PER.
192 |
193 | Assigns maximal priority per [1] Alg. 1, guaranteeing that the sample will be visited at least once.
194 |
195 | Args:
196 | item:
197 |
198 | Returns:
199 |
200 | """
201 | item=deepcopy(item)
202 | super().update(item, **kwargs)
203 | maximal_priority=self.alpha
204 | if self.reduce_ram: item.clean()
205 | self.tree.add(np.abs(maximal_priority)+self.epsilon, item)
206 |
207 |
208 | # class HindsightExperienceReplay(Experience):
209 | # def __init__(self, memory_size):
210 | # """
211 | #
212 | # References:
213 | # [1] Andrychowicz, Marcin, et al. "Hindsight experience replay."
214 | # Advances in Neural Information Processing Systems. 2017.
215 | #
216 | # Args:
217 | # memory_size:
218 | # """
219 | # super().__init__(memory_size)
220 |
--------------------------------------------------------------------------------
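The trainer callbacks in dqn.py and ddpg.py drive an `ExplorationStrategy` with the same two calls each step: `update(...)` to advance the schedule and `perturb(...)` to possibly replace the greedy action. A stand-alone sketch of that loop using `GreedyEpsilon`; the greedy action is a placeholder, and the `update` call mirrors exactly how `BaseDQNTrainer` invokes it above.

import gym
from fast_rl.core.agent_core import GreedyEpsilon

env = gym.make('CartPole-v0')
explore = GreedyEpsilon(epsilon_start=1.0, epsilon_end=0.1, decay=0.001)

greedy_action = 0                                   # placeholder for argmax_a Q(s, a)
for episode in range(10):
    explore.update(episode, max_episodes=10, explore=True)
    action = explore.perturb(greedy_action, env.action_space)
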
/fast_rl/core/basic_train.py:
--------------------------------------------------------------------------------
1 | from multiprocessing.pool import Pool
2 |
3 | from fastai.basic_train import Learner, load_callback
4 | from fastai.torch_core import *
5 |
6 | from fast_rl.core.data_block import MDPDataBunch
7 |
8 |
9 | class WrapperLossFunc(object):
10 | def __init__(self, learn):
11 | self.learn = learn
12 |
13 | def __call__(self, *args, **kwargs):
14 | return self.learn.model.loss
15 |
16 |
17 |
18 | def load_learner(path:PathOrStr, file:PathLikeOrBinaryStream='export.pkl', **db_kwargs):
19 | r""" Similar to fastai `load_learner`, handles load_state for data differently. """
20 | source = Path(path)/file if is_pathlike(file) else file
21 | state = torch.load(source, map_location='cpu') if defaults.device == torch.device('cpu') else torch.load(source)
22 | model = state.pop('model')
23 | data = MDPDataBunch.load_state(path, state.pop('data'))
24 | # if test is not None: src.add_test(test)
25 | # data = src.databunch(**db_kwargs)
26 | cb_state = state.pop('cb_state')
27 | clas_func = state.pop('cls')
28 | res = clas_func(data, model, **state)
29 | res.callback_fns = state['callback_fns'] #to avoid duplicates
30 | res.callbacks = [load_callback(c,s, res) for c,s in cb_state.items()]
31 | return res
32 |
33 |
34 | class AgentLearner(Learner):
35 |
36 | def __init__(self, data, loss_func=None, callback_fns=None, opt=torch.optim.Adam, **kwargs):
37 | super().__init__(data=data, callback_fns=ifnone(callback_fns, []) + data.callback, **kwargs)
38 | self.model.loss_func = ifnone(loss_func, F.mse_loss)
39 | self.model.set_opt(opt)
40 | self.loss_func = None
41 | self.trainers = None
42 | self._loss_func = WrapperLossFunc(self)
43 |
44 | @property
45 | def warming_up(self):
46 | return self.data.bs > len(self.data.x)
47 |
48 | def init_loss_func(self):
49 | r"""
50 | Initializes the loss function wrapper for logging loss.
51 |
52 | Since most RL models have a period of warming up such as filling tree buffers, we cannot log any loss.
53 | By default, the learner will have a `None` loss function, and so the fit function will not try to log that
54 | loss.
55 | """
56 | self.loss_func = WrapperLossFunc(self)
57 |
58 | def export(self, file:PathLikeOrBinaryStream='export.pkl', destroy=False, pickle_data=False):
59 | "Export the state of the `Learner` in `self.path/file`. `file` can be file-like (file or buffer)"
60 | if rank_distrib(): return # don't save if slave proc
61 | # For now we exclude the 'loss_func' since it is pointing to a model loss.
62 | args = ['opt_func', 'metrics', 'true_wd', 'bn_wd', 'wd', 'train_bn', 'model_dir', 'callback_fns', 'memory',
63 | 'exploration_method', 'trainers']
64 | state = {a:getattr(self,a) for a in args}
65 | state['cb_state'] = {cb.__class__:cb.get_state() for cb in self.callbacks}
66 | #layer_groups -> need to find a way
67 | #TO SEE: do we save model structure and weights separately?
68 | with ModelOnCPU(self.model) as m:
69 | m.opt = None
70 | state['model'] = m
71 | xtra = dict(normalize=self.data.norm.keywords) if getattr(self.data, 'norm', False) else {}
72 | state['data'] = self.data.train_ds.get_state(**xtra) if self.data.valid_dl is None else self.data.valid_ds.get_state(**xtra)
73 | state['data']['add_valid'] = not self.data.empty_val
74 | if pickle_data: self.data.to_pickle(self.data.path)
75 | state['cls'] = self.__class__
76 | try_save(state, self.path, file)
77 | if destroy: self.destroy()
78 |
79 | def interpret_q(self, xi):
 80 |         raise NotImplementedError
81 |
82 |
83 | class PipeLine(object):
84 | def __init__(self, n_threads, pipe_line_function):
 85 |         warn(Warning('Currently of limited use: running a single env across multiple threads appears to cause issues.'))
86 | self.pipe_line_function = pipe_line_function
87 | self.n_threads = n_threads
88 | self.pool = Pool(self.n_threads)
89 |
90 | def start(self, n_runs):
91 | return self.pool.map(self.pipe_line_function, range(n_runs))
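Hedged usage sketch (not part of the file above): the round trip below only uses the `export` / `load_learner` signatures shown here. Constructing `learn` is the job of the agent factories in `fast_rl.agents`, which are not reproduced in this excerpt, so the learner is taken as a parameter.

    from fast_rl.core.basic_train import AgentLearner, load_learner

    def save_and_restore(learn: AgentLearner) -> AgentLearner:
        # `learn` is assumed to come from one of the agent factories in fast_rl.agents (not shown here).
        learn.export(file='export.pkl')                   # saves model, callback state and data state under learn.path
        restored = load_learner(learn.path, file='export.pkl')
        print(restored.warming_up)                        # True while fewer items than one batch have been collected
        return restored
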
--------------------------------------------------------------------------------
/fast_rl/core/data_structures.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 |
4 | Notices:
5 | [1] SumTree implementation belongs to: https://github.com/rlcode/per
  6 |         As of 8/23/2019, no license is provided. Note that this code has been modified.
7 |
8 |
9 | """
10 |
11 | import numpy as np
12 |
13 |
14 | class SumTree(object):
15 | write = 0
16 |
17 | def __init__(self, capacity):
18 | """
19 | Used for PER.
20 |
21 | References:
22 | [1] SumTree implementation belongs to: https://github.com/rlcode/per
23 |
24 | Notes:
 25 |             As of 8/23/2019, no license is provided. Note that this code has been modified.
26 |
27 |
28 | Args:
29 | capacity:
30 | """
31 |
32 | self.capacity = capacity
33 | self.tree = np.zeros(2 * capacity - 1)
34 | self.data = np.zeros(capacity, dtype=object)
35 | self.n_entries = 0
36 |
37 | def _propagate(self, idx, change):
38 | """ Update to the root node """
39 | parent = (idx - 1) // 2
40 |
41 | self.tree[parent] += change
42 |
43 | if (np.isscalar(parent) and parent != 0) or (not np.isscalar(parent) and all(parent != 0)):
44 | if not np.isscalar(parent): change[parent == 0] = 0
45 | self._propagate(parent, change)
46 |
47 | def get_left(self, index):
48 | return 2 * index + 1
49 |
50 | def get_right(self, index):
51 | return self.get_left(index) + 1
52 |
53 | def _retrieve(self, idx, s):
54 | """ Finds sample on leaf node """
55 | left = self.get_left(idx)
56 | right = self.get_right(idx)
57 |
58 | if left >= len(self.tree):
59 | return idx
60 |
61 | if s <= self.tree[left]:
62 | return self._retrieve(left, s)
63 | else:
64 | return self._retrieve(right, s - self.tree[left])
65 |
66 | def total(self):
67 | return self.tree[0]
68 |
69 | def add(self, p, data):
70 | """ Store priority and sample """
71 | idx = self.write + self.capacity - 1
72 |
73 | self.data[self.write] = data
74 | self.update(idx, p)
75 |
76 | self.write += 1
77 | if self.write >= self.capacity:
78 | self.write = 0
79 |
80 | if self.n_entries < self.capacity:
81 | self.n_entries += 1
82 |
83 | def update(self, idx, p):
84 | """ Update priority """
85 | p = p.flatten() if not np.isscalar(p) else p
86 | change = p - self.tree[idx]
87 |
88 | self.tree[idx] = p
89 | self._propagate(idx, change)
90 |
91 | def get(self, s):
92 | """ Get priority and sample """
93 | idx = self._retrieve(0, s)
94 | data_index = idx - self.capacity + 1
95 |
96 | return idx, self.tree[idx], self.data[data_index]
97 |
98 | def anneal_weights(self, priorities, beta):
99 | sampling_probabilities = priorities / self.total()
100 | is_weight = np.power(self.n_entries * sampling_probabilities, -beta)
101 | is_weight /= is_weight.max()
102 | return is_weight.astype(float)
103 |
104 | def batch_get(self, ss):
105 |         return np.array(list(zip(*[g for g in (self.get(s) for s in ss) if g[2] != 0])))
106 |
107 |
108 | def print_tree(tree: SumTree):
109 | print('\n')
110 | if tree.n_entries == 0:
111 | print('empty')
112 | return
113 |
114 | max_d = int(np.log2(len(tree.tree)))
115 | string_len_max = len(str(tree.tree[-1]))
116 |
117 | tree_strings = []
118 | display_values = None
119 | display_indexes = None
120 | for layer in range(max_d + 1):
121 | # Get the indexes in the current layer d
122 | if display_indexes is None:
123 | display_indexes = [[0]]
124 | else:
125 | local_list = []
126 | for i in [_ for _ in display_indexes[-1] if _ < len(tree.tree)]:
127 | if tree.get_left(i) < len(tree.tree): local_list.append(tree.get_left(i))
128 | if tree.get_right(i) < len(tree.tree): local_list.append(tree.get_right(i))
129 | display_indexes.append(local_list)
130 |
131 | for layer in display_indexes:
132 | # Get the v contained in current layer d
133 | if display_values is None:
134 | display_values = [[tree.tree[i] for i in layer]]
135 | else:
136 | display_values.append([tree.tree[i] for i in layer])
137 |
138 | tab_sizes = []
139 | spacings = []
140 | for i, layer in enumerate(display_values):
141 | # for now ignore string length
142 | tab_sizes.append(0 if i == 0 else (tab_sizes[-1] + 1) * 2)
143 | spacings.append(3 if i == 0 else (spacings[-1] * 2 + 1))
144 |
145 | for i, layer in enumerate(display_values):
146 | # tree_strings.append('*' * list(reversed(tab_sizes))[i])
147 | values = ''.join(str(v) + ' ' * (string_len_max * list(reversed(spacings))[i]) for v in layer)
148 | tree_strings.append(' ' * (string_len_max * list(reversed(tab_sizes))[i]) + values)
149 |
150 | for tree_string in tree_strings:
151 | print(tree_string)
152 |
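Illustrative usage of the `SumTree` above (not part of the source file). Items with larger priorities occupy a larger share of the cumulative total, so uniform draws over [0, total()) retrieve them proportionally more often; `update()` rewrites a leaf and re-propagates the internal sums. The priorities and capacity below are arbitrary example values.

    import numpy as np
    from fast_rl.core.data_structures import SumTree, print_tree

    tree = SumTree(capacity=4)
    for name, priority in [('a', 1.0), ('b', 3.0), ('c', 0.5), ('d', 1.5)]:
        tree.add(priority, name)

    print_tree(tree)                 # leaves hold priorities, internal nodes hold their sums
    print(tree.total())              # 6.0

    # Sample proportionally to priority by drawing uniformly over the cumulative total.
    counts = {}
    for s in np.random.uniform(0, tree.total(), size=1000):
        idx, priority, data = tree.get(s)
        counts[data] = counts.get(data, 0) + 1
    print(counts)                    # 'b' should dominate, 'c' should be rarest

    # Rewriting a leaf priority re-propagates the change up to the root.
    leaf_idx, _, _ = tree.get(0.1)   # the leaf currently holding 'a'
    tree.update(leaf_idx, 10.0)
    print(tree.total())              # 15.0
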
--------------------------------------------------------------------------------
/fast_rl/core/layers.py:
--------------------------------------------------------------------------------
1 | r"""`fast_rl.layers` provides essential functions for building and modifying `model` architectures"""
2 | from math import ceil
3 |
4 | from fastai.torch_core import *
5 | from fastai.tabular import TabularModel
6 |
7 |
8 | def init_cnn(mod: Any):
9 | r""" Utility for initializing cnn Modules. """
10 | if getattr(mod, 'bias', None) is not None: nn.init.constant_(mod.bias, 0)
11 | if isinstance(mod, (nn.Conv2d, nn.Linear)): nn.init.kaiming_normal_(mod.weight)
12 | for sub_mod in mod.children(): init_cnn(sub_mod)
13 |
14 |
15 | def ks_stride(ks, stride, w, h, n_blocks, kern_proportion=.1, stride_proportion=0.3):
 16 |     r""" Utility for determining the kernel size and stride. """
17 | kernels, strides, max_dim = [], [], max((w, h))
18 | for i in range(len(n_blocks)):
19 | kernels.append(max_dim * kern_proportion)
20 | strides.append(kernels[-1] * stride_proportion)
21 | max_dim = (max_dim - kernels[-1]) / strides[-1]
22 | assert max_dim > 1
23 |
24 | return ifnone(ks, map(ceil, kernels)), ifnone(stride, map(ceil, strides))
25 |
26 |
27 | class Flatten(nn.Module):
28 | def forward(self, y): return y.view(y.size(0), -1)
29 |
30 |
31 | class FakeBatchNorm(Module):
 32 |     r""" Identity stand-in: if the batch norm layers should be removed, the tabular batch norm is replaced with this. """
33 | def forward(self, xi: Tensor, *args): return xi
34 |
35 |
36 | def conv_bn_lrelu(ni: int, nf: int, ks: int = 3, stride: int = 1, pad=True, bn=True) -> nn.Sequential:
 37 |     r""" Create a Conv2d->BatchNorm2d->LeakyReLU sequence (from darknet.py). Allows excluding the BatchNorm2d layer."""
38 | return nn.Sequential(
39 | nn.Conv2d(ni, nf, kernel_size=ks, bias=False, stride=stride, padding=(ks // 2) if pad else 0),
40 | nn.BatchNorm2d(nf) if bn else FakeBatchNorm(),
41 | nn.LeakyReLU(negative_slope=0.1, inplace=True))
42 |
43 |
44 | class ChannelTranspose(Module):
 45 |     r""" Transposes image channels at runtime. Useful for handling the different channel orderings returned by different envs. """
46 | def forward(self, xi: Tensor):
47 | return xi.transpose(3, 1).transpose(3, 2)
48 |
49 |
50 | class StateActionSplitter(Module):
51 | r""" `Actor / Critic` models require breaking the state and action into 2 streams. """
52 |
53 | def forward(self, s_a_tuple: Tuple[Tensor]):
54 | r""" Returns tensors as -> (State Tensor, Action Tensor) """
55 | return s_a_tuple[0], s_a_tuple[1]
56 |
57 |
58 | class StateActionPassThrough(nn.Module):
59 | r""" Passes action input untouched, but runs the state tensors through a sub module. """
60 | def __init__(self, layers):
61 | super().__init__()
62 | self.layers = layers
63 |
64 | def forward(self, state_action):
65 | return self.layers(state_action[0]), state_action[1]
66 |
67 |
68 | class TabularEmbedWrapper(Module):
69 | r""" Basic `TabularModel` compatibility wrapper. Typically, state inputs will be either categorical or continuous. """
70 | def __init__(self, tabular_model: TabularModel):
71 | super().__init__()
72 | self.tabular_model = tabular_model
73 |
74 | def forward(self, xi: Tensor, *args):
75 | return self.tabular_model(xi, xi)
76 |
77 |
78 | class CriticTabularEmbedWrapper(Module):
79 | r""" Similar to `TabularEmbedWrapper` but assumes input is state / action and requires concatenation. """
80 | def __init__(self, tabular_model: TabularModel, exclude_cat):
81 | super().__init__()
82 | self.tabular_model = tabular_model
83 | self.exclude_cat = exclude_cat
84 |
85 | def forward(self, args):
86 | if not self.exclude_cat:
87 | return self.tabular_model(*args)
88 | else:
89 | return self.tabular_model(0, torch.cat(args, 1))
90 |
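Illustrative sketch (not part of the source file): a quick shape check of the layer helpers above on fake HWC image frames. The 84x84x3 input size and the channel widths are assumptions for the example only.

    import torch
    from fast_rl.core.layers import ChannelTranspose, Flatten, StateActionPassThrough, conv_bn_lrelu

    frames = torch.rand(2, 84, 84, 3)                     # (batch, H, W, C), as an env might return
    body = torch.nn.Sequential(
        ChannelTranspose(),                               # -> (batch, C, H, W)
        conv_bn_lrelu(3, 16, ks=3, stride=2),             # Conv2d -> BatchNorm2d -> LeakyReLU
        conv_bn_lrelu(16, 32, ks=3, stride=2, bn=False),  # BatchNorm2d swapped for FakeBatchNorm
        Flatten(),
    )
    print(body(frames).shape)                             # torch.Size([2, 14112]) for 84x84 inputs

    # Actor/critic style models only transform the state; the action passes through untouched.
    action = torch.rand(2, 4)
    out_state, out_action = StateActionPassThrough(body)((frames, action))
    print(out_state.shape, out_action.shape)
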
--------------------------------------------------------------------------------
/fast_rl/core/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from fastai.basic_train import LearnerCallback, Any
3 | from fastai.callback import Callback, is_listy, add_metrics
4 |
5 |
6 | class EpsilonMetric(LearnerCallback):
7 | _order = -20 # Needs to run before the recorder
8 |
9 | def __init__(self, learn):
10 | super().__init__(learn)
11 | self.epsilon = 0
12 | if not hasattr(self.learn, 'exploration_method'):
13 | raise ValueError('Your model is not using an exploration strategy! Please use epsilon based exploration')
14 | if not hasattr(self.learn.exploration_method, 'epsilon'):
15 | raise ValueError('Please use epsilon based exploration (should have an epsilon field)')
16 |
17 | # noinspection PyUnresolvedReferences
18 | def on_train_begin(self, **kwargs):
19 | self.learn.recorder.add_metric_names(['epsilon'])
20 |
21 | def on_epoch_end(self, last_metrics, **kwargs):
22 | self.epsilon = self.learn.exploration_method.epsilon
23 | if last_metrics and last_metrics[-1] is None: del last_metrics[-1]
24 | return add_metrics(last_metrics, [float(self.epsilon)])
25 |
26 | class RewardMetric(LearnerCallback):
27 | _order = -20
28 |
29 | def __init__(self, learn):
30 | super().__init__(learn)
31 | self.train_reward, self.valid_reward = [], []
32 |
33 | def on_epoch_begin(self, **kwargs:Any):
34 | self.train_reward, self.valid_reward = [], []
35 |
36 | def on_batch_end(self, **kwargs: Any):
37 | if self.learn.model.training: self.train_reward.append(self.learn.data.train_ds.item.reward.cpu().numpy()[0][0])
38 | elif not self.learn.recorder.no_val: self.valid_reward.append(self.learn.data.valid_ds.item.reward.cpu().numpy()[0][0])
39 |
40 | def on_train_begin(self, **kwargs):
41 | metric_names = ['train_reward'] if self.learn.recorder.no_val else ['train_reward', 'valid_reward']
42 | self.learn.recorder.add_metric_names(metric_names)
43 |
44 | def on_epoch_end(self, last_metrics, **kwargs: Any):
45 | return add_metrics(last_metrics, [sum(self.train_reward), sum(self.valid_reward)])
46 |
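Hedged sketch (not part of the source file): both metrics are `LearnerCallback`s, so they are meant to be handed to the learner through `callback_fns`, which `AgentLearner` accepts (see `basic_train.py` above). The learner construction itself is assumed and left commented out; `EpsilonMetric` additionally requires an `exploration_method` with an `epsilon` field, as the checks above enforce.

    from fast_rl.core.metrics import EpsilonMetric, RewardMetric

    callback_fns = [RewardMetric, EpsilonMetric]
    # learn = ...      # hypothetical: any AgentLearner built with callback_fns=callback_fns
    # learn.fit(10)    # the recorder then shows 'train_reward' (plus 'valid_reward' and 'epsilon') columns
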
--------------------------------------------------------------------------------
/fast_rl/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/fast_rl/util/__init__.py
--------------------------------------------------------------------------------
/fast_rl/util/exceptions.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | class MaxEpisodeStepsMissingError(Exception):
4 | pass
--------------------------------------------------------------------------------
/fast_rl/util/misc.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from gym import error
3 |
4 |
5 | class b_colors:
6 | HEADER = '\033[95m'
7 | OKBLUE = '\033[94m'
8 | OKGREEN = '\033[92m'
9 | WARNING = '\033[93m'
10 | FAIL = '\033[91m'
11 | ENDC = '\033[0m'
12 | BOLD = '\033[1m'
13 | UNDERLINE = '\033[4m'
14 |
15 |
16 | def list_in_str(s: str, str_list: list, make_lower=True):
17 | if make_lower: s = s.lower()
 18 |     return any(el in s for el in str_list)
19 |
20 |
21 | def is_goal_env(env, suppress_errors=True):
 22 |     msg = 'GoalEnv requires the "{}" key to be part of the observation dictionary.'
23 | # Enforce that each GoalEnv uses a Goal-compatible observation space.
24 | if not isinstance(env.observation_space, gym.spaces.Dict):
25 | if not suppress_errors:
26 | raise error.Error('GoalEnv requires an observation space of type gym.spaces.Dict')
27 | else:
28 | return False
29 | for key in ['observation', 'achieved_goal', 'desired_goal']:
30 | if key not in env.observation_space.spaces:
31 | if not suppress_errors:
32 | raise error.Error(msg.format(key))
33 | else:
34 | return False
35 | return True
36 |
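Illustrative checks of the helpers above (requires `gym`); not part of the source file.

    import gym
    from fast_rl.util.misc import b_colors, is_goal_env, list_in_str

    print(list_in_str('CartPole-v0', ['cartpole', 'maze']))   # True: `s` is lower-cased before matching
    print(list_in_str('CartPole-v0', ['maze']))               # False

    env = gym.make('CartPole-v0')
    print(is_goal_env(env))                                   # False: observation_space is a Box, not a gym.spaces.Dict

    print(f'{b_colors.OKGREEN}ok{b_colors.ENDC}')             # ANSI color codes for terminal output
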
--------------------------------------------------------------------------------
/res/RELEASE_BLOG.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/RELEASE_BLOG.md
--------------------------------------------------------------------------------
/res/ddpg_balancing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/ddpg_balancing.gif
--------------------------------------------------------------------------------
/res/dqn_q_estimate_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/dqn_q_estimate_1.jpg
--------------------------------------------------------------------------------
/res/dqn_q_estimate_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/dqn_q_estimate_2.jpg
--------------------------------------------------------------------------------
/res/dqn_q_estimate_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/dqn_q_estimate_3.jpg
--------------------------------------------------------------------------------
/res/fit_func_out.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/fit_func_out.jpg
--------------------------------------------------------------------------------
/res/heat_map_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/heat_map_1.png
--------------------------------------------------------------------------------
/res/heat_map_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/heat_map_2.png
--------------------------------------------------------------------------------
/res/heat_map_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/heat_map_3.png
--------------------------------------------------------------------------------
/res/heat_map_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/heat_map_4.png
--------------------------------------------------------------------------------
/res/pre_interpretation_maze_dqn.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/pre_interpretation_maze_dqn.gif
--------------------------------------------------------------------------------
/res/reward_plot_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plot_1.png
--------------------------------------------------------------------------------
/res/reward_plot_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plot_2.png
--------------------------------------------------------------------------------
/res/reward_plots/ant_ddpg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/ant_ddpg.png
--------------------------------------------------------------------------------
/res/reward_plots/cartpole_dddqn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/cartpole_dddqn.png
--------------------------------------------------------------------------------
/res/reward_plots/cartpole_double.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/cartpole_double.png
--------------------------------------------------------------------------------
/res/reward_plots/cartpole_dqn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/cartpole_dqn.png
--------------------------------------------------------------------------------
/res/reward_plots/cartpole_dueling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/cartpole_dueling.png
--------------------------------------------------------------------------------
/res/reward_plots/cartpole_fixedtarget.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/cartpole_fixedtarget.png
--------------------------------------------------------------------------------
/res/reward_plots/halfcheetah_ddpg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/halfcheetah_ddpg.png
--------------------------------------------------------------------------------
/res/reward_plots/lunarlander_all_targetbased.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/lunarlander_all_targetbased.png
--------------------------------------------------------------------------------
/res/reward_plots/lunarlander_dddqn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/lunarlander_dddqn.png
--------------------------------------------------------------------------------
/res/reward_plots/lunarlander_double.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/lunarlander_double.png
--------------------------------------------------------------------------------
/res/reward_plots/lunarlander_dqn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/lunarlander_dqn.png
--------------------------------------------------------------------------------
/res/reward_plots/lunarlander_dueling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/lunarlander_dueling.png
--------------------------------------------------------------------------------
/res/reward_plots/lunarlander_fixedtarget.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/lunarlander_fixedtarget.png
--------------------------------------------------------------------------------
/res/reward_plots/pendulum_ddpg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/reward_plots/pendulum_ddpg.png
--------------------------------------------------------------------------------
/res/run_gifs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/__init__.py
--------------------------------------------------------------------------------
/res/run_gifs/acrobot_ExperienceReplay_DDPGModule_1_episode_197.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/acrobot_ExperienceReplay_DDPGModule_1_episode_197.gif
--------------------------------------------------------------------------------
/res/run_gifs/acrobot_ExperienceReplay_DDPGModule_1_episode_438.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/acrobot_ExperienceReplay_DDPGModule_1_episode_438.gif
--------------------------------------------------------------------------------
/res/run_gifs/acrobot_ExperienceReplay_DDPGModule_1_episode_69.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/acrobot_ExperienceReplay_DDPGModule_1_episode_69.gif
--------------------------------------------------------------------------------
/res/run_gifs/acrobot_PriorityExperienceReplay_DDPGModule_1_episode_267.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/acrobot_PriorityExperienceReplay_DDPGModule_1_episode_267.gif
--------------------------------------------------------------------------------
/res/run_gifs/acrobot_PriorityExperienceReplay_DDPGModule_1_episode_422.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/acrobot_PriorityExperienceReplay_DDPGModule_1_episode_422.gif
--------------------------------------------------------------------------------
/res/run_gifs/acrobot_PriorityExperienceReplay_DDPGModule_1_episode_55.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/acrobot_PriorityExperienceReplay_DDPGModule_1_episode_55.gif
--------------------------------------------------------------------------------
/res/run_gifs/ant_ExperienceReplay_DDPGModule_1_episode_54.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/ant_ExperienceReplay_DDPGModule_1_episode_54.gif
--------------------------------------------------------------------------------
/res/run_gifs/ant_ExperienceReplay_DDPGModule_1_episode_614.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/ant_ExperienceReplay_DDPGModule_1_episode_614.gif
--------------------------------------------------------------------------------
/res/run_gifs/ant_ExperienceReplay_DDPGModule_1_episode_999.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/ant_ExperienceReplay_DDPGModule_1_episode_999.gif
--------------------------------------------------------------------------------
/res/run_gifs/ant_PriorityExperienceReplay_DDPGModule_1_episode_52.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/ant_PriorityExperienceReplay_DDPGModule_1_episode_52.gif
--------------------------------------------------------------------------------
/res/run_gifs/ant_PriorityExperienceReplay_DDPGModule_1_episode_596.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/ant_PriorityExperienceReplay_DDPGModule_1_episode_596.gif
--------------------------------------------------------------------------------
/res/run_gifs/ant_PriorityExperienceReplay_DDPGModule_1_episode_984.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/ant_PriorityExperienceReplay_DDPGModule_1_episode_984.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_ExperienceReplay_DQNModule_1_episode_207.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DQNModule_1_episode_207.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_ExperienceReplay_DQNModule_1_episode_31.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DQNModule_1_episode_31.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_ExperienceReplay_DQNModule_1_episode_447.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DQNModule_1_episode_447.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_ExperienceReplay_DoubleDQNModule_1_episode_268.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DoubleDQNModule_1_episode_268.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_ExperienceReplay_DoubleDQNModule_1_episode_438.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DoubleDQNModule_1_episode_438.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_ExperienceReplay_DoubleDQNModule_1_episode_60.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DoubleDQNModule_1_episode_60.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_287.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_287.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_43.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_43.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_447.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DoubleDuelingModule_1_episode_447.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_ExperienceReplay_DuelingDQNModule_1_episode_209.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DuelingDQNModule_1_episode_209.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_ExperienceReplay_DuelingDQNModule_1_episode_432.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DuelingDQNModule_1_episode_432.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_ExperienceReplay_DuelingDQNModule_1_episode_62.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_DuelingDQNModule_1_episode_62.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_309.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_309.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_438.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_438.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_57.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_ExperienceReplay_FixedTargetDQNModule_1_episode_57.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_PriorityExperienceReplay_DQNModule_1_episode_216.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DQNModule_1_episode_216.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_PriorityExperienceReplay_DQNModule_1_episode_413.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DQNModule_1_episode_413.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_PriorityExperienceReplay_DQNModule_1_episode_44.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DQNModule_1_episode_44.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_269.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_269.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_35.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_35.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_444.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDQNModule_1_episode_444.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_2.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_260.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_260.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_438.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DoubleDuelingModule_1_episode_438.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_272.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_272.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_438.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_438.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_69.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_DuelingDQNModule_1_episode_69.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_13.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_13.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_265.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_265.gif
--------------------------------------------------------------------------------
/res/run_gifs/cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_449.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/cartpole_PriorityExperienceReplay_FixedTargetDQNModule_1_episode_449.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_ExperienceReplay_DQNModule_1_episode_541.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DQNModule_1_episode_541.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_ExperienceReplay_DQNModule_1_episode_93.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DQNModule_1_episode_93.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_ExperienceReplay_DQNModule_1_episode_999.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DQNModule_1_episode_999.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_613.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_613.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_88.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_88.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_999.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DoubleDQNModule_1_episode_999.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_114.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_114.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_346.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_346.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_925.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DoubleDuelingModule_1_episode_925.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_112.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_112.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_431.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_431.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_980.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_ExperienceReplay_DuelingDQNModule_1_episode_980.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_PriorityExperienceReplay_DQNModule_1_episode_382.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DQNModule_1_episode_382.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_PriorityExperienceReplay_DQNModule_1_episode_949.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DQNModule_1_episode_949.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_PriorityExperienceReplay_DQNModule_1_episode_99.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DQNModule_1_episode_99.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_514.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_514.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_7.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_7.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_999.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDQNModule_1_episode_999.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_151.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_151.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_341.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_341.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_999.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DoubleDuelingModule_1_episode_999.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_21.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_21.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_442.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_442.gif
--------------------------------------------------------------------------------
/res/run_gifs/lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_998.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/lunarlander_PriorityExperienceReplay_DuelingDQNModule_1_episode_998.gif
--------------------------------------------------------------------------------
/res/run_gifs/pendulum_ExperienceReplay_DDPGModule_1_episode_238.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/pendulum_ExperienceReplay_DDPGModule_1_episode_238.gif
--------------------------------------------------------------------------------
/res/run_gifs/pendulum_ExperienceReplay_DDPGModule_1_episode_447.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/pendulum_ExperienceReplay_DDPGModule_1_episode_447.gif
--------------------------------------------------------------------------------
/res/run_gifs/pendulum_ExperienceReplay_DDPGModule_1_episode_9.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/pendulum_ExperienceReplay_DDPGModule_1_episode_9.gif
--------------------------------------------------------------------------------
/res/run_gifs/pendulum_PriorityExperienceReplay_DDPGModule_1_episode_222.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/pendulum_PriorityExperienceReplay_DDPGModule_1_episode_222.gif
--------------------------------------------------------------------------------
/res/run_gifs/pendulum_PriorityExperienceReplay_DDPGModule_1_episode_35.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/pendulum_PriorityExperienceReplay_DDPGModule_1_episode_35.gif
--------------------------------------------------------------------------------
/res/run_gifs/pendulum_PriorityExperienceReplay_DDPGModule_1_episode_431.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/res/run_gifs/pendulum_PriorityExperienceReplay_DDPGModule_1_episode_431.gif
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | with open("README.md", "r") as fh:
4 | long_description = fh.read()
5 |
6 | VERSION = "1.0.1"
7 |
8 |
9 | setup(name='fast_rl',
10 | version=VERSION,
11 |       description='Fastai has made computer vision and tabular learning remarkably approachable; this package aims '
12 |                   'to do the same for RL. The goal is a framework that is as easy as possible to get started with, '
13 |                   'while also being designed for testing new agents.',
14 | url='https://github.com/josiahls/fast-reinforcement-learning',
15 | author='Josiah Laivins',
16 | author_email='jlaivins@uncc.edu',
17 | python_requires='>=3.6',
18 | long_description=long_description,
19 | long_description_content_type="text/markdown",
20 | license='',
21 | packages=find_packages(),
22 | zip_safe=False,
23 | install_requires=['fastai>=1.0.59', 'gym[box2d, atari]', 'jupyter'],
24 | extras_require={'all': [
25 | 'gym-minigrid',
26 | 'moviepy'
27 | # 'gym_maze @ git+https://github.com/MattChanTK/gym-maze.git',
28 | # 'pybullet-gym @ git+https://github.com/benelot/pybullet-gym.git'
29 | ]},
30 | classifiers=[
31 | "Development Status :: 3 - Alpha",
32 | "Programming Language :: Python :: 3",
33 | "License :: OSI Approved :: Apache Software License",
34 | "Operating System :: OS Independent",
35 | ],
36 | )
37 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
4 | def pytest_addoption(parser):
5 | parser.addoption("--include_performance_tests", action="store_true",
6 |                      help="Will run the performance tests, which do full model training. These could take a few "
7 |                           "days to complete.")
8 |
9 | @pytest.fixture()
10 | def include_performance_tests(pytestconfig):
11 | return pytestconfig.getoption("include_performance_tests")
12 |
13 |
14 | @pytest.fixture()
15 | def skip_performance_check(include_performance_tests):
16 | if not include_performance_tests:
17 | pytest.skip('Skipping due to performance argument not specified. Add --include_performance_tests to not skip')
18 |
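
The two fixtures above implement an opt-in gate for the long-running performance tests. A minimal sketch of how a test uses it (the test name here is hypothetical; the repository's own performance tests in tests/test_dqn.py and tests/test_ddpg.py apply the same decorator):

    import pytest

    @pytest.mark.usefixtures('skip_performance_check')
    def test_full_training_run():
        # Skipped by default; runs only when pytest is invoked with
        # --include_performance_tests, which makes skip_performance_check a no-op.
        assert True

Running `pytest tests --include_performance_tests` includes these tests; omitting the flag skips them.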
--------------------------------------------------------------------------------
/tests/data/cartpole_dqn/dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/tests/data/cartpole_dqn/dqn_PriorityExperienceReplay_FEED_TYPE_STATE.pickle
--------------------------------------------------------------------------------
/tests/data/cat/cat1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/tests/data/cat/cat1.jpeg
--------------------------------------------------------------------------------
/tests/data/cat/cat2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/tests/data/cat/cat2.jpeg
--------------------------------------------------------------------------------
/tests/data/dog/dog1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/tests/data/dog/dog1.jpeg
--------------------------------------------------------------------------------
/tests/data/dog/dog2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fast-reinforcement-learning/66136009dd7052d4a9c07631d5c170c9aeba67f3/tests/data/dog/dog2.jpeg
--------------------------------------------------------------------------------
/tests/test_agent_core.py:
--------------------------------------------------------------------------------
1 |
2 | # import pytest
3 | #
4 | # from fast_rl.core.data_block import MDPDataBunch
5 | # from fast_rl.core.agent_core import PriorityExperienceReplay
6 | # from fast_rl.core.basic_train import AgentLearner
7 | #
8 |
9 |
--------------------------------------------------------------------------------
/tests/test_basic_train.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 |
5 | from fast_rl.agents.dqn import create_dqn_model, FixedTargetDQNModule, dqn_learner
6 | from fast_rl.core.agent_core import ExperienceReplay, torch, GreedyEpsilon
7 | from fast_rl.core.basic_train import load_learner
8 | from fast_rl.core.data_block import MDPDataBunch
9 |
10 |
11 | def test_fit():
12 | data=MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=5, max_steps=20, add_valid=False)
13 | model=create_dqn_model(data, FixedTargetDQNModule, opt=torch.optim.RMSprop)
14 | memory=ExperienceReplay(memory_size=1000, reduce_ram=True)
15 | exploration_method=GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
16 | learner=dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method)
17 | learner.fit(2)
18 | learner.fit(2)
19 | learner.fit(2)
20 |
21 | assert len(data.x.info)==6
22 | assert 0 in data.x.info
23 | assert 5 in data.x.info
24 |
25 |
26 | def test_to_pickle():
27 | data=MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=5, max_steps=20, add_valid=False)
28 | model=create_dqn_model(data, FixedTargetDQNModule, opt=torch.optim.RMSprop)
29 | memory=ExperienceReplay(memory_size=1000, reduce_ram=True)
30 | exploration_method=GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
31 | learner=dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method)
32 | learner.fit(2)
33 |
34 | assert len(data.x.info)==2
35 | assert 0 in data.x.info
36 | assert 1 in data.x.info
37 |
38 | data.to_pickle('./data/test_to_pickle')
39 | assert os.path.exists('./data/test_to_pickle_CartPole-v0')
40 |
41 |
42 | def test_from_pickle():
43 | data=MDPDataBunch.from_pickle('./data/test_to_pickle_CartPole-v0')
44 | model=create_dqn_model(data, FixedTargetDQNModule, opt=torch.optim.RMSprop)
45 | memory=ExperienceReplay(memory_size=1000, reduce_ram=True)
46 | exploration_method=GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
47 | learner=dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method)
48 | learner.fit(2)
49 |
50 | assert len(data.x.info)==4
51 | assert 0 in data.x.info
52 | assert 3 in data.x.info
53 |
54 |
55 | def test_export_learner():
56 | data=MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=5, max_steps=20, add_valid=False)
57 | model=create_dqn_model(data, FixedTargetDQNModule, opt=torch.optim.RMSprop)
58 | memory=ExperienceReplay(memory_size=1000, reduce_ram=True)
59 | exploration_method=GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
60 | learner=dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method)
61 | learner.fit(2)
62 |
63 | learner.export('test_export.pkl')#, pickle_data=True)
64 | learner = load_learner(learner.path, 'test_export.pkl')
65 | learner.fit(2)
66 |
--------------------------------------------------------------------------------
/tests/test_data_block.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import partial
3 | from itertools import product
4 |
5 | import gym
6 | import pytest
7 | import numpy as np
8 | import torch
9 | from fastai.basic_train import ItemLists
10 |
11 | from fast_rl.agents.dqn import create_dqn_model, dqn_learner
12 | from fast_rl.agents.dqn_models import DQNModule
13 | from fast_rl.core.agent_core import GreedyEpsilon, ExperienceReplay
14 | from fast_rl.core.data_block import MDPDataBunch, ResolutionWrapper, FEED_TYPE_IMAGE
15 | from fast_rl.core.metrics import RewardMetric, EpsilonMetric
16 |
17 |
18 | def validate_item_list(item_list: ItemLists):
19 | # Check items
20 | for i, item in enumerate(item_list.items):
21 | if item.done: assert not item_list.items[
22 | i - 1].done, f'The dataset has duplicate "done\'s" that are consecutive.'
23 | assert item.state.s is not None, f'The item: {item}\'s state is None'
24 | assert item.state.s_prime is not None, f'The item: {item}\'s state prime is None'
25 |
26 |
27 | @pytest.mark.parametrize(["memory_strategy", "k"], list(product(['k_top', 'k_partitions_top'], [1, 3, 5])))
28 | def test_dataset_memory_manager(memory_strategy, k):
29 | data = MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=5, max_steps=20, add_valid=False,
30 | memory_management_strategy=memory_strategy, k=k)
31 | model = create_dqn_model(data, DQNModule, opt=torch.optim.RMSprop, lr=0.1)
32 | memory = ExperienceReplay(memory_size=1000, reduce_ram=True)
33 | exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
34 | learner = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method,
35 | callback_fns=[RewardMetric, EpsilonMetric])
36 | learner.fit(10)
37 |
38 | data_info = {episode: data.train_ds.x.info[episode] for episode in data.train_ds.x.info if episode != -1}
39 | full_episodes = [episode for episode in data_info if not data_info[episode][1]]
40 |
41 | assert sum([not _[1] for _ in data_info.values()]) == k, 'There should be k episodes but there is not.'
42 |     if 'top' in memory_strategy and 'both' not in memory_strategy:
43 | assert (np.argmax([_[0] for _ in data_info.values()])) in full_episodes
44 |
45 |
46 | def test_databunch_to_pickle():
47 | data = MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=5, max_steps=20, add_valid=False,
48 | memory_management_strategy='k_partitions_top', k=3)
49 | model = create_dqn_model(data, DQNModule, opt=torch.optim.RMSprop, lr=0.1)
50 | memory = ExperienceReplay(memory_size=1000, reduce_ram=True)
51 | exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
52 | learner = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method,
53 | callback_fns=[RewardMetric, EpsilonMetric])
54 | learner.fit(10)
55 | data.to_pickle('./data/cartpole_10_epoch')
56 | MDPDataBunch.from_pickle(env_name='CartPole-v0', path='./data/cartpole_10_epoch')
57 |
58 |
59 | def test_resolution_wrapper():
60 | data = MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=5, max_steps=10, add_valid=False,
61 | memory_management_strategy='k_top', k=1, feed_type=FEED_TYPE_IMAGE,
62 | res_wrap=partial(ResolutionWrapper, w_step=2, h_step=2))
63 | model = create_dqn_model(data, DQNModule, opt=torch.optim.RMSprop, lr=0.1,channels=[32,32,32],ks=[5,5,5],stride=[2,2,2])
64 | memory = ExperienceReplay(memory_size=1000, reduce_ram=True)
65 | exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
66 | learner = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method,
67 | callback_fns=[RewardMetric, EpsilonMetric])
68 | learner.fit(2)
69 | temp = gym.make('CartPole-v0')
70 | temp.reset()
71 | original_shape = temp.render(mode='rgb_array').shape
72 | assert data.env.render(mode='rgb_array').shape == (original_shape[0] // 2, original_shape[1] // 2, 3)
73 |
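
test_resolution_wrapper above checks that wrapping the environment with ResolutionWrapper(w_step=2, h_step=2) halves the height and width of the rendered frame. As a rough standalone illustration of that kind of stride-based downsampling (plain numpy slicing, not the repository's ResolutionWrapper; the slicing interpretation of w_step/h_step is an assumption consistent with the assertion above):

    import numpy as np

    def downsample(frame: np.ndarray, w_step: int = 2, h_step: int = 2) -> np.ndarray:
        # Keep every h_step-th row and every w_step-th column of an HxWxC frame.
        return frame[::h_step, ::w_step, :]

    frame = np.zeros((400, 600, 3), dtype=np.uint8)   # stand-in for an rgb_array render
    assert downsample(frame).shape == (200, 300, 3)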
--------------------------------------------------------------------------------
/tests/test_data_structures.py:
--------------------------------------------------------------------------------
1 | from fast_rl.core.data_structures import print_tree, SumTree
2 |
3 |
4 | def test_sum_tree_with_max_size():
5 | memory = SumTree(10)
6 |
7 | values = [1, 1, 1, 1, 1, 1]
8 | data = [f'data with priority: {i}' for i in values]
9 |
10 | for element, value in zip(data, values):
11 | memory.add(value, element)
12 |
13 | print_tree(memory)
14 |
15 |
16 |
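
The SumTree exercised above is the data structure behind PriorityExperienceReplay: it stores transition priorities so that sampling proportional to priority, and updating a priority, are both logarithmic in the number of leaves. A self-contained sketch of the sampling rule such a tree accelerates, written here with a plain cumulative sum rather than the repository's SumTree API:

    import numpy as np

    priorities = np.array([1., 1., 1., 1., 1., 1.])        # mirrors the values in the test above

    def sample_index(priorities: np.ndarray, rng: np.random.Generator) -> int:
        # Pick an index with probability proportional to its priority.
        cumulative = np.cumsum(priorities)                  # O(n) here; a sum tree does the lookup in O(log n)
        r = rng.uniform(0, cumulative[-1])
        return int(np.searchsorted(cumulative, r))

    rng = np.random.default_rng(0)
    counts = np.bincount([sample_index(priorities, rng) for _ in range(6000)],
                         minlength=len(priorities))
    print(counts)   # roughly uniform, since all priorities are equal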
--------------------------------------------------------------------------------
/tests/test_ddpg.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from itertools import product
3 |
4 | import pytest
5 | from fastai.basic_train import torch, DatasetType
6 | from fastai.core import ifnone
7 |
8 | from fast_rl.agents.ddpg import create_ddpg_model, ddpg_learner, DDPGLearner
9 | from fast_rl.agents.ddpg_models import DDPGModule
10 | from fast_rl.core.agent_core import ExperienceReplay, PriorityExperienceReplay, OrnsteinUhlenbeck
11 | from fast_rl.core.data_block import FEED_TYPE_STATE, MDPDataBunch, ResolutionWrapper
12 | from fast_rl.core.metrics import RewardMetric, EpsilonMetric
13 | from fast_rl.core.train import GroupAgentInterpretation, AgentInterpretation
14 |
15 | p_model=[DDPGModule]
16 | p_exp=[ExperienceReplay, PriorityExperienceReplay]
17 | p_format=[FEED_TYPE_STATE] # , FEED_TYPE_IMAGE]
18 | p_full_format=[FEED_TYPE_STATE]
19 | p_envs=['Pendulum-v0']
20 |
21 | config_env_expectations={
22 | 'Pendulum-v0': {'action_shape': (1, 1), 'state_shape': (1, 3)}
23 | }
24 |
25 |
26 | def trained_learner(model_cls, env, s_format, experience, bs=64,layers=None,render='rgb_array', memory_size=1000000,
27 | decay=0.0001,lr=None,actor_lr=None,epochs=450,opt=torch.optim.RMSprop, **kwargs):
28 | lr,actor_lr=ifnone(lr,1e-3),ifnone(actor_lr,1e-4)
29 | data=MDPDataBunch.from_env(env,render=render,bs=bs,add_valid=False,keep_env_open=False,feed_type=s_format,
30 | memory_management_strategy='k_partitions_top',k=3,**kwargs)
31 | exploration_method=OrnsteinUhlenbeck(size=data.action.taken_action.shape,epsilon_start=1,epsilon_end=0.1,
32 | decay=decay)
33 | memory=experience(memory_size=memory_size, reduce_ram=True)
34 | model=create_ddpg_model(data=data,base_arch=model_cls,lr=lr,actor_lr=actor_lr,layers=layers,opt=opt)
35 | learner=ddpg_learner(data=data, model=model, memory=memory, exploration_method=exploration_method,
36 | callback_fns=[RewardMetric, EpsilonMetric])
37 | learner.fit(epochs)
38 | return learner
39 |
40 | def check_shape(env,data,s_format):
41 | assert config_env_expectations[env]['action_shape']==(1, data.action.taken_action.shape[1])
42 | if s_format==FEED_TYPE_STATE:
43 | assert config_env_expectations[env]['state_shape']==data.state.s.shape
44 |
45 | def learner2gif(lnr:DDPGLearner,s_format,group_interp:GroupAgentInterpretation,name:str,extra:str):
46 | meta=f'{lnr.memory.__class__.__name__}_{"FEED_TYPE_STATE" if s_format==FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}'
47 | interp=AgentInterpretation(lnr, ds_type=DatasetType.Train)
48 | interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta)
49 | group_interp.add_interpretation(interp)
50 | group_interp.to_pickle(f'../docs_src/data/{name}_{lnr.model.name.lower()}/', f'{lnr.model.name.lower()}_{meta}')
51 | [g.write(f'../res/run_gifs/{name}_{extra}') for g in interp.generate_gif()]
52 |
53 |
54 | @pytest.mark.parametrize(["model_cls", "s_format", "env"], list(product(p_model, p_format, p_envs)))
55 | def test_ddpg_create_ddpg_model(model_cls, s_format, env):
56 | data=MDPDataBunch.from_env(env, render='rgb_array', bs=32, add_valid=False, feed_type=s_format)
57 | model=create_ddpg_model(data, model_cls)
58 | model.eval()
59 | model(data.state.s.float())
60 | check_shape(env,data,s_format)
61 | data.close()
62 |
63 |
64 | @pytest.mark.parametrize(["model_cls", "s_format", "mem", "env"], list(product(p_model, p_format, p_exp, p_envs)))
65 | def test_ddpg_ddpglearner(model_cls, s_format, mem, env):
66 | data=MDPDataBunch.from_env(env, render='rgb_array', bs=32, add_valid=False, feed_type=s_format)
67 | model=create_ddpg_model(data, model_cls)
68 | memory=mem(memory_size=1000, reduce_ram=True)
69 | exploration_method=OrnsteinUhlenbeck(size=data.action.taken_action.shape, epsilon_start=1, epsilon_end=0.1,
70 | decay=0.001)
71 | ddpg_learner(data=data, model=model, memory=memory, exploration_method=exploration_method)
72 | check_shape(env,data,s_format)
73 | data.close()
74 |
75 |
76 | @pytest.mark.parametrize(["model_cls", "s_format", "mem", "env"], list(product(p_model, p_format, p_exp, p_envs)))
77 | def test_ddpg_fit(model_cls, s_format, mem, env):
78 | learner=trained_learner(env=env, bs=10,opt=torch.optim.RMSprop,model_cls=model_cls,layers=[20, 20],memory_size=100,
79 | max_steps=20,render='rgb_array',decay=0.001,s_format=s_format,experience=mem,epochs=2)
80 |
81 | check_shape(env,learner.data,s_format)
82 | del learner
83 |
84 |
85 | @pytest.mark.usefixtures('skip_performance_check')
86 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'],
87 | list(product(p_model, p_format, p_exp)))
88 | def test_ddpg_models_pendulum(model_cls, s_format, experience):
89 | group_interp=GroupAgentInterpretation()
90 | extra_s=f'{experience.__name__}_{model_cls.__name__}_{s_format}'
91 | for i in range(5):
92 | print('\n')
93 | learner=trained_learner(model_cls,'Pendulum-v0',s_format,experience,decay=0.0001,render='rgb_array')
94 | learner2gif(learner,s_format,group_interp,'pendulum',extra_s)
95 | del learner
96 |
97 |
98 | @pytest.mark.usefixtures('skip_performance_check')
99 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'],
100 | list(product(p_model, p_format, p_exp)))
101 | def test_ddpg_models_acrobot(model_cls, s_format, experience):
102 | group_interp=GroupAgentInterpretation()
103 | extra_s=f'{experience.__name__}_{model_cls.__name__}_{s_format}'
104 | for i in range(5):
105 | print('\n')
106 | learner=trained_learner(model_cls,'Acrobot-v1',s_format,experience,decay=0.0001,render='rgb_array')
107 | learner2gif(learner,s_format,group_interp,'acrobot',extra_s)
108 | del learner
109 |
110 |
111 | @pytest.mark.usefixtures('skip_performance_check')
112 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'],
113 | list(product(p_model, p_full_format, p_exp)))
114 | def test_ddpg_models_mountain_car_continuous(model_cls, s_format, experience):
115 | group_interp=GroupAgentInterpretation()
116 | for i in range(5):
117 | print('\n')
118 | data=MDPDataBunch.from_env('MountainCarContinuous-v0', render='rgb_array', bs=40, add_valid=False, keep_env_open=False,
119 | feed_type=s_format, memory_management_strategy='k_partitions_top', k=3, res_wrap=partial(ResolutionWrapper, w_step=2, h_step=2))
120 | exploration_method=OrnsteinUhlenbeck(size=data.action.taken_action.shape, epsilon_start=1, epsilon_end=0.1,
121 | decay=0.0001)
122 | memory=experience(memory_size=1000000, reduce_ram=True)
123 | model=create_ddpg_model(data=data, base_arch=model_cls)
124 | learner=ddpg_learner(data=data, model=model, memory=memory, exploration_method=exploration_method,
125 | callback_fns=[RewardMetric, EpsilonMetric])
126 | learner.fit(450)
127 |
128 | meta=f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format==FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}'
129 | interp=AgentInterpretation(learner, ds_type=DatasetType.Train)
130 | interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta)
131 | group_interp.add_interpretation(interp)
132 | group_interp.to_pickle(f'../docs_src/data/mountaincarcontinuous_{model.name.lower()}/',
133 | f'{model.name.lower()}_{meta}')
134 | [g.write('../res/run_gifs/mountaincarcontinuous') for g in interp.generate_gif()]
135 | data.close()
136 | del learner
137 | del model
138 | del data
139 |
140 |
141 | @pytest.mark.usefixtures('skip_performance_check')
142 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'],
143 | list(product(p_model, p_full_format, p_exp)))
144 | def test_ddpg_models_reach(model_cls, s_format, experience):
145 | group_interp=GroupAgentInterpretation()
146 | for i in range(5):
147 | print('\n')
148 | data=MDPDataBunch.from_env('ReacherPyBulletEnv-v0', render='rgb_array', bs=40, add_valid=False, keep_env_open=False, feed_type=s_format,
149 | memory_management_strategy='k_partitions_top', k=3, res_wrap=partial(ResolutionWrapper, w_step=2, h_step=2))
150 | exploration_method=OrnsteinUhlenbeck(size=data.action.taken_action.shape, epsilon_start=1, epsilon_end=0.1,
151 | decay=0.00001)
152 | memory=experience(memory_size=1000000, reduce_ram=True)
153 | model=create_ddpg_model(data=data, base_arch=model_cls)
154 | learner=ddpg_learner(data=data, model=model, memory=memory, exploration_method=exploration_method,
155 | callback_fns=[RewardMetric, EpsilonMetric])
156 | learner.fit(450)
157 |
158 | meta=f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format==FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}'
159 | interp=AgentInterpretation(learner, ds_type=DatasetType.Train)
160 | interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta)
161 | group_interp.add_interpretation(interp)
162 | group_interp.to_pickle(f'../docs_src/data/reacher_{model.name.lower()}/',
163 | f'{model.name.lower()}_{meta}')
164 | [g.write('../res/run_gifs/reacher') for g in interp.generate_gif()]
165 | data.close()
166 | del learner
167 | del model
168 | del data
169 |
170 |
171 | @pytest.mark.usefixtures('skip_performance_check')
172 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'],
173 | list(product(p_model, p_full_format, p_exp)))
174 | def test_ddpg_models_walker(model_cls, s_format, experience):
175 | group_interp=GroupAgentInterpretation()
176 | extra_s=f'{experience.__name__}_{model_cls.__name__}_{s_format}'
177 | for i in range(5):
178 | print('\n')
179 | # data=MDPDataBunch.from_env('Walker2DPyBulletEnv-v0', render='human', bs=64, add_valid=False, keep_env_open=False,
180 | # feed_type=s_format, memory_management_strategy='k_partitions_top', k=3)
181 | # exploration_method=OrnsteinUhlenbeck(size=data.action.taken_action.shape, epsilon_start=1, epsilon_end=0.1,
182 | # decay=0.0001)
183 | # memory=experience(memory_size=1000000, reduce_ram=True)
184 | # model=create_ddpg_model(data=data, base_arch=model_cls)
185 | # learner=ddpg_learner(data=data, model=model, memory=memory, exploration_method=exploration_method,
186 | # callback_fns=[RewardMetric, EpsilonMetric])
187 | # learner.fit(1500)
188 | learner=trained_learner(model_cls,'Walker2DPyBulletEnv-v0',s_format,experience,decay=0.0001,render='rgb_array')
189 | learner2gif(learner,s_format,group_interp,'walker2d',extra_s)
190 |
191 | # meta=f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format==FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}'
192 | # interp=AgentInterpretation(learner, ds_type=DatasetType.Train)
193 | # interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta)
194 | # group_interp.add_interpretation(interp)
195 | # group_interp.to_pickle(f'../docs_src/data/walker2d_{model.name.lower()}/',
196 | # f'{model.name.lower()}_{meta}')
197 | # [g.write('../res/run_gifs/walker2d') for g in interp.generate_gif()]
198 | # data.close()
199 | del learner
200 | # del model
201 | # del data
202 |
203 |
204 | @pytest.mark.usefixtures('skip_performance_check')
205 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'],
206 | list(product(p_model, p_full_format, p_exp)))
207 | def test_ddpg_models_ant(model_cls, s_format, experience):
208 | group_interp=GroupAgentInterpretation()
209 | extra_s = f'{experience.__name__}_{model_cls.__name__}_{s_format}'
210 | for i in range(5):
211 | print('\n')
212 | # data=MDPDataBunch.from_env('AntPyBulletEnv-v0', render='human', bs=64, add_valid=False, keep_env_open=False,
213 | # feed_type=s_format, memory_management_strategy='k_partitions_top', k=3)
214 | # exploration_method=OrnsteinUhlenbeck(size=data.action.taken_action.shape, epsilon_start=1, epsilon_end=0.1,
215 | # decay=0.00001)
216 | # memory=experience(memory_size=1000000, reduce_ram=True)
217 | # model=create_ddpg_model(data=data, base_arch=model_cls, lr=1e-3, actor_lr=1e-4)
218 | # learner=ddpg_learner(data=data, model=model, memory=memory, exploration_method=exploration_method,
219 | # opt_func=torch.optim.Adam, callback_fns=[RewardMetric, EpsilonMetric])
220 | # learner.fit(4)
221 | learner=trained_learner(model_cls,'AntPyBulletEnv-v0',s_format,experience,decay=0.0001,render='rgb_array',epochs=1000)
222 | learner2gif(learner,s_format,group_interp,'ant',extra_s)
223 | # meta=f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format==FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}'
224 | # interp=AgentInterpretation(learner, ds_type=DatasetType.Train)
225 | # interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta)
226 | # group_interp.add_interpretation(interp)
227 | # group_interp.to_pickle(f'../docs_src/data/ant_{model.name.lower()}/',
228 | # f'{model.name.lower()}_{meta}')
229 | # [g.write('../res/run_gifs/ant', frame_skip=5) for g in interp.generate_gif()]
230 | del learner
231 | # del model
232 | # del data
233 |
234 |
235 | @pytest.mark.usefixtures('skip_performance_check')
236 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'],
237 | list(product(p_model, p_full_format, p_exp)))
238 | def test_ddpg_models_halfcheetah(model_cls, s_format, experience):
239 | group_interp=GroupAgentInterpretation()
240 | for i in range(5):
241 | print('\n')
242 | data=MDPDataBunch.from_env('HalfCheetahPyBulletEnv-v0', render='rgb_array', bs=64, add_valid=False, keep_env_open=False,
243 | feed_type=s_format, memory_management_strategy='k_partitions_top', k=3, res_wrap=partial(ResolutionWrapper, w_step=2, h_step=2))
244 | exploration_method=OrnsteinUhlenbeck(size=data.action.taken_action.shape, epsilon_start=1, epsilon_end=0.1,
245 | decay=0.000001)
246 | memory=experience(memory_size=1000000, reduce_ram=True)
247 | model=create_ddpg_model(data=data, base_arch=model_cls)
248 | learner=ddpg_learner(data=data, model=model, memory=memory, exploration_method=exploration_method,
249 | callback_fns=[RewardMetric, EpsilonMetric])
250 | learner.fit(1000)
251 |
252 | meta=f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format==FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}'
253 | interp=AgentInterpretation(learner, ds_type=DatasetType.Train)
254 | interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta)
255 | group_interp.add_interpretation(interp)
256 | group_interp.to_pickle(f'../docs_src/data/halfcheetah_{model.name.lower()}/',
257 | f'{model.name.lower()}_{meta}')
258 | [g.write('../res/run_gifs/halfcheetah') for g in interp.generate_gif()]
259 | del learner
260 | del model
261 | del data
262 |
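
All of the DDPG tests above drive exploration with OrnsteinUhlenbeck noise rather than epsilon-greedy, since the action space is continuous. For background, here is a generic, self-contained discretization of the Ornstein-Uhlenbeck process that such an exploration method is typically built on; it is only a sketch and does not reproduce the constructor or internals of fast_rl.core.agent_core.OrnsteinUhlenbeck:

    import numpy as np

    class OUNoise:
        """Temporally correlated noise: dx = theta * (mu - x) * dt + sigma * sqrt(dt) * N(0, 1)."""
        def __init__(self, size, mu=0.0, theta=0.15, sigma=0.2, dt=1e-2, seed=0):
            self.size, self.mu, self.theta, self.sigma, self.dt = size, mu, theta, sigma, dt
            self.rng = np.random.default_rng(seed)
            self.reset()

        def reset(self):
            self.x = np.full(self.size, self.mu, dtype=np.float64)

        def sample(self):
            dx = (self.theta * (self.mu - self.x) * self.dt
                  + self.sigma * np.sqrt(self.dt) * self.rng.standard_normal(self.size))
            self.x = self.x + dx
            return self.x

    noise = OUNoise(size=(1, 1))              # Pendulum-v0 has a single continuous action
    noisy_action = 0.3 + noise.sample()       # added to the actor's output while training

Because consecutive samples are correlated, the perturbed actions drift smoothly instead of jittering, which tends to explore continuous-control problems like Pendulum-v0 more effectively.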
--------------------------------------------------------------------------------
/tests/test_dqn.py:
--------------------------------------------------------------------------------
1 | import gc
2 | from functools import partial
3 | from itertools import product
4 | import pytest
5 | from fastai.basic_data import DatasetType
6 |
7 | from fast_rl.agents.dqn import create_dqn_model, dqn_learner, DQNLearner
8 | from fast_rl.agents.dqn_models import *
9 | from fast_rl.core.agent_core import ExperienceReplay, PriorityExperienceReplay, GreedyEpsilon
10 | from fast_rl.core.data_block import MDPDataBunch, FEED_TYPE_STATE, FEED_TYPE_IMAGE, ResolutionWrapper
11 | from fast_rl.core.metrics import RewardMetric, EpsilonMetric
12 | from fast_rl.core.train import GroupAgentInterpretation, AgentInterpretation
13 | from torch import optim
14 |
15 | p_model = [DQNModule, FixedTargetDQNModule,DoubleDuelingModule,DuelingDQNModule,DoubleDQNModule]
16 | p_exp = [ExperienceReplay,
17 | PriorityExperienceReplay]
18 | p_format = [FEED_TYPE_STATE]#, FEED_TYPE_IMAGE]
19 | p_envs = ['CartPole-v1']
20 |
21 | config_env_expectations = {
22 | 'CartPole-v1': {'action_shape': (1, 2), 'state_shape': (1, 4)},
23 | 'maze-random-5x5-v0': {'action_shape': (1, 4), 'state_shape': (1, 2)}
24 | }
25 |
26 |
27 | def learner2gif(lnr:DQNLearner,s_format,group_interp:GroupAgentInterpretation,name:str,extra:str):
28 | meta=f'{lnr.memory.__class__.__name__}_{"FEED_TYPE_STATE" if s_format==FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}'
29 | interp=AgentInterpretation(lnr, ds_type=DatasetType.Train)
30 | interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta)
31 | group_interp.add_interpretation(interp)
32 | group_interp.to_pickle(f'../docs_src/data/{name}_{lnr.model.name.lower()}/', f'{lnr.model.name.lower()}_{meta}')
33 | temp=[g.write(f'../res/run_gifs/{name}_{extra}') for g in interp.generate_gif()]
34 | del temp
35 | gc.collect()
36 |
37 |
38 |
39 | def trained_learner(model_cls, env, s_format, experience, bs, layers, memory_size=1000000, decay=0.001,
40 | copy_over_frequency=300, lr=None, epochs=450,**kwargs):
41 | if lr is None: lr = [0.001, 0.00025]
42 | memory = experience(memory_size=memory_size, reduce_ram=True)
43 | explore = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=decay)
44 | if type(lr) == list: lr = lr[0] if model_cls == DQNModule else lr[1]
45 | data = MDPDataBunch.from_env(env, render='human', bs=bs, add_valid=False, keep_env_open=False, feed_type=s_format,
46 | memory_management_strategy='k_partitions_top', k=3,**kwargs)
47 | if model_cls == DQNModule: model = create_dqn_model(data=data, base_arch=model_cls, lr=lr, layers=layers, opt=optim.RMSprop)
48 | else: model = create_dqn_model(data=data, base_arch=model_cls, lr=lr, layers=layers)
49 | learn = dqn_learner(data, model, memory=memory, exploration_method=explore, copy_over_frequency=copy_over_frequency,
50 | callback_fns=[RewardMetric, EpsilonMetric])
51 | learn.fit(epochs)
52 | return learn
53 |
54 | # @pytest.mark.usefixtures('skip_performance_check')
55 | @pytest.mark.parametrize(["model_cls", "s_format", "env"], list(product(p_model, p_format, p_envs)))
56 | def test_dqn_create_dqn_model(model_cls, s_format, env):
57 | data = MDPDataBunch.from_env(env, render='rgb_array', bs=32, add_valid=False, feed_type=s_format)
58 | model = create_dqn_model(data, model_cls)
59 | model.eval()
60 | model(data.state.s)
61 |
62 | assert config_env_expectations[env]['action_shape'] == (1, data.action.n_possible_values.item())
63 | if s_format == FEED_TYPE_STATE:
64 | assert config_env_expectations[env]['state_shape'] == data.state.s.shape
65 |
66 |
67 | # @pytest.mark.usefixtures('skip_performance_check')
68 | @pytest.mark.parametrize(["model_cls", "s_format", "mem", "env"], list(product(p_model, p_format, p_exp, p_envs)))
69 | def test_dqn_dqn_learner(model_cls, s_format, mem, env):
70 | data = MDPDataBunch.from_env(env, render='rgb_array', bs=32, add_valid=False, keep_env_open=False, feed_type=s_format)
71 | model = create_dqn_model(data, model_cls)
72 | memory = mem(memory_size=1000, reduce_ram=True)
73 | exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
74 | dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method)
75 |
76 | assert config_env_expectations[env]['action_shape'] == (1, data.action.n_possible_values.item())
77 | if s_format == FEED_TYPE_STATE:
78 | assert config_env_expectations[env]['state_shape'] == data.state.s.shape
79 |
80 |
81 | # @pytest.mark.usefixtures('skip_performance_check')
82 | @pytest.mark.parametrize(["model_cls", "s_format", "mem", "env"], list(product(p_model, p_format, p_exp, p_envs)))
83 | def test_dqn_fit(model_cls, s_format, mem, env):
84 | data = MDPDataBunch.from_env(env, render='rgb_array', bs=5, max_steps=20, add_valid=False, keep_env_open=False, feed_type=s_format)
85 | model = create_dqn_model(data, model_cls, opt=torch.optim.RMSprop)
86 | memory = mem(memory_size=1000, reduce_ram=True)
87 | exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
88 | learner = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method)
89 | learner.fit(2)
90 |
91 | assert config_env_expectations[env]['action_shape'] == (1, data.action.n_possible_values.item())
92 | if s_format == FEED_TYPE_STATE:
93 | assert config_env_expectations[env]['state_shape'] == data.state.s.shape
94 |
95 |
96 | @pytest.mark.usefixtures('skip_performance_check')
97 | @pytest.mark.parametrize(["model_cls", "s_format", "mem"], list(product(p_model, p_format, p_exp)))
98 | def test_dqn_fit_maze_env(model_cls, s_format, mem):
99 | group_interp = GroupAgentInterpretation()
100 | extra_s=f'{mem.__name__}_{model_cls.__name__}_{s_format}'
101 | for i in range(5):
102 | learn = trained_learner(model_cls, 'maze-random-5x5-v0', s_format, mem, bs=32, layers=[32, 32],
103 | memory_size=1000000, decay=0.00001, res_wrap=partial(ResolutionWrapper, w_step=3, h_step=3))
104 |
105 | learner2gif(learn,s_format,group_interp,'maze_5x5',extra_s)
106 | # success = False
107 | # while not success:
108 | # try:
109 | # data = MDPDataBunch.from_env('maze-random-5x5-v0', render='rgb_array', bs=5, max_steps=20,
110 | # add_valid=False, keep_env_open=False, feed_type=s_format)
111 | # model = create_dqn_model(data, model_cls, opt=torch.optim.RMSprop)
112 | # memory = ExperienceReplay(10000)
113 | # exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
114 | # learner = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method,
115 | # callback_fns=[RewardMetric, EpsilonMetric])
116 | # learner.fit(2)
117 | #
118 | # assert config_env_expectations['maze-random-5x5-v0']['action_shape'] == (
119 | # 1, data.action.n_possible_values.item())
120 | # if s_format == FEED_TYPE_STATE:
121 | # assert config_env_expectations['maze-random-5x5-v0']['state_shape'] == data.state.s.shape
122 | # sleep(1)
123 | # success = True
124 | # except Exception as e:
125 | # if not str(e).__contains__('Surface'):
126 | # raise Exception
127 |
128 |
129 | @pytest.mark.usefixtures('skip_performance_check')
130 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'], list(product(p_model, p_format, p_exp)))
131 | def test_dqn_models_minigrids(model_cls, s_format, experience):
132 | group_interp = GroupAgentInterpretation()
133 | for i in range(5):
134 | learn = trained_learner(model_cls, 'MiniGrid-FourRooms-v0', s_format, experience, bs=32, layers=[64, 64],
135 | memory_size=1000000, decay=0.00001, epochs=1000)
136 |
137 | meta = f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format == FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}'
138 | interp = AgentInterpretation(learn, ds_type=DatasetType.Train)
139 | interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta)
140 | group_interp.add_interpretation(interp)
141 | filename = f'{learn.model.name.lower()}_{meta}'
142 | group_interp.to_pickle(f'../docs_src/data/minigrid_{learn.model.name.lower()}/', filename)
143 | [g.write('../res/run_gifs/minigrid') for g in interp.generate_gif()]
144 | del learn
145 |
146 |
147 | @pytest.mark.usefixtures('skip_performance_check')
148 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'],
149 | list(product(p_model, p_format, p_exp)))
150 | def test_dqn_models_cartpole(model_cls, s_format, experience):
151 | group_interp = GroupAgentInterpretation()
152 | extra_s=f'{experience.__name__}_{model_cls.__name__}_{s_format}'
153 | for i in range(5):
154 | learn = trained_learner(model_cls, 'CartPole-v1', s_format, experience, bs=32, layers=[64, 64],
155 | memory_size=1000000, decay=0.001)
156 |
157 | learner2gif(learn,s_format,group_interp,'cartpole',extra_s)
158 | # meta = f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format == FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}'
159 | # interp = AgentInterpretation(learn, ds_type=DatasetType.Train)
160 | # interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta)
161 | # group_interp.add_interpretation(interp)
162 | # filename = f'{learn.model.name.lower()}_{meta}'
163 | # group_interp.to_pickle(f'../docs_src/data/cartpole_{learn.model.name.lower()}/', filename)
164 | # [g.write('../res/run_gifs/cartpole') for g in interp.generate_gif()]
165 | # del learn
166 |
167 |
168 | @pytest.mark.usefixtures('skip_performance_check')
169 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'], list(product(p_model, p_format, p_exp)))
170 | def test_dqn_models_lunarlander(model_cls, s_format, experience):
171 | group_interp = GroupAgentInterpretation()
172 | extra_s=f'{experience.__name__}_{model_cls.__name__}_{s_format}'
173 | for i in range(5):
174 | learn = trained_learner(model_cls, 'LunarLander-v2', s_format, experience, bs=32, layers=[128, 64],
175 | memory_size=1000000, decay=0.00001, copy_over_frequency=600, lr=[0.001, 0.00025],
176 | epochs=1000)
177 | learner2gif(learn, s_format, group_interp, 'lunarlander', extra_s)
178 | del learn
179 | gc.collect()
180 | # meta = f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format == FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}'
181 | # interp = AgentInterpretation(learn, ds_type=DatasetType.Train)
182 | # interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta)
183 | # group_interp.add_interpretation(interp)
184 | # filename = f'{learn.model.name.lower()}_{meta}'
185 | # group_interp.to_pickle(f'../docs_src/data/lunarlander_{learn.model.name.lower()}/', filename)
186 | # [g.write('../res/run_gifs/lunarlander') for g in interp.generate_gif()]
187 | # del learn
188 |
189 |
190 | @pytest.mark.usefixtures('skip_performance_check')
191 | @pytest.mark.parametrize(["model_cls", "s_format", 'experience'], list(product(p_model, p_format, p_exp)))
192 | def test_dqn_models_mountaincar(model_cls, s_format, experience):
193 | group_interp = GroupAgentInterpretation()
194 | for i in range(5):
195 | learn = trained_learner(model_cls, 'MountainCar-v0', s_format, experience, bs=32, layers=[24, 12],
196 | memory_size=1000000, decay=0.00001, copy_over_frequency=1000)
197 | meta = f'{experience.__name__}_{"FEED_TYPE_STATE" if s_format == FEED_TYPE_STATE else "FEED_TYPE_IMAGE"}'
198 | interp = AgentInterpretation(learn, ds_type=DatasetType.Train)
199 | interp.plot_rewards(cumulative=True, per_episode=True, group_name=meta)
200 | group_interp.add_interpretation(interp)
201 | filename = f'{learn.model.name.lower()}_{meta}'
202 | group_interp.to_pickle(f'../docs_src/data/mountaincar_{learn.model.name.lower()}/', filename)
203 | [g.write('../res/run_gifs/mountaincar') for g in interp.generate_gif()]
204 | del learn
205 |
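
Every DQN test above builds its exploration with GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=...). A common way such a schedule is formulated is exponential annealing; the snippet below is only an illustrative assumption about what those parameters usually mean, not the exact formula inside fast_rl's GreedyEpsilon:

    import math

    def epsilon(step: int, start: float = 1.0, end: float = 0.1, decay: float = 0.001) -> float:
        # Anneal the exploration rate from `start` toward `end` as training steps accumulate.
        return end + (start - end) * math.exp(-decay * step)

    print(epsilon(0))       # 1.0    -> act almost entirely at random early on
    print(epsilon(5000))    # ~0.106 -> mostly greedy after many steps

Smaller decay values (e.g. 0.00001 in the LunarLander and maze tests) stretch the schedule out, keeping the agent exploring for far more steps before it turns greedy.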
--------------------------------------------------------------------------------
/tests/test_metrics.py:
--------------------------------------------------------------------------------
1 | from fastai.imports import torch
2 |
3 | from fast_rl.agents.dqn import create_dqn_model, dqn_learner
4 | from fast_rl.agents.dqn_models import DQNModule
5 | from fast_rl.core.agent_core import ExperienceReplay, GreedyEpsilon
6 | from fast_rl.core.data_block import MDPDataBunch
7 | from fast_rl.core.metrics import RewardMetric, EpsilonMetric
8 |
9 |
10 | def test_metrics_reward_init():
11 | data=MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=5, max_steps=20)
12 | model=create_dqn_model(data, DQNModule, opt=torch.optim.RMSprop)
13 | memory=ExperienceReplay(memory_size=1000, reduce_ram=True)
14 | exploration_method=GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
15 | learner=dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method,
16 | callback_fns=[RewardMetric])
17 | learner.fit(2)
18 |
19 |
20 | def test_metrics_epsilon_init():
21 | data=MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=5, max_steps=20)
22 | model=create_dqn_model(data, DQNModule, opt=torch.optim.RMSprop)
23 | memory=ExperienceReplay(memory_size=1000, reduce_ram=True)
24 | exploration_method=GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
25 | learner=dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method,
26 | callback_fns=[EpsilonMetric])
27 | learner.fit(2)
28 |
29 |
--------------------------------------------------------------------------------
/tests/test_train.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import pytest
4 |
5 | from fast_rl.agents.dqn import *
6 | from fast_rl.core.agent_core import *
7 | from fast_rl.core.data_block import *
8 | from fast_rl.core.train import *
9 |
10 |
11 |
12 |
13 | @pytest.mark.usefixtures('skip_performance_check')
14 | def test_interpretation_gif():
15 | logger = logging.getLogger('root')
16 | logger.setLevel('DEBUG')
17 |
18 | data = MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=32, add_valid=False,
19 | memory_management_strategy='k_partitions_top', k=3)
20 | model = create_dqn_model(data, DQNModule, opt=torch.optim.RMSprop, lr=0.1)
21 | memory = ExperienceReplay(memory_size=1000, reduce_ram=True)
22 | exploration_method = GreedyEpsilon(epsilon_start=1, epsilon_end=0.1, decay=0.001)
23 | learner = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method)
24 | learner.fit(10)
25 | interp = AgentInterpretation(learner, ds_type=DatasetType.Train)
26 | interp.generate_gif(-1).write('last_episode')
27 | #
28 | #
29 | # @pytest.mark.parametrize(["model_cls", "s_format", "mem"], list(product(p_model, p_format, p_exp)))
30 | # def test_train_gym_maze_interpretation(model_cls, s_format, mem):
31 | # success = False
32 | # while not success:
33 | # try:
34 | # data = MDPDataBunch.from_env('maze-random-5x5-v0', render='rgb_array', bs=5, max_steps=50,
35 | # add_valid=False, feed_type=s_format)
36 | # model = create_dqn_model(data, model_cls, opt=torch.optim.RMSprop)
37 | # memory = mem(10000)
38 | # exploration_method = GreedyEpsilon(e_start=1, e_end=0.1, decay=0.001)
39 | # lnr = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method)
40 | # lnr.fit(1)
41 | #
42 | # interp = GymMazeInterpretation(lnr, ds_type=DatasetType.Train)
43 | # for i in range(-1, 4): interp.plot_heat_map(action=i)
44 | #
45 | # success = True
46 | # except Exception as e:
47 | # if not str(e).__contains__('Surface'):
48 | # raise Exception
49 | #
50 | #
51 | # @pytest.mark.parametrize(["model_cls", "s_format", "mem"], list(product(p_model, p_format, p_exp)))
52 | # def test_train_q_value_interpretation(model_cls, s_format, mem):
53 | # success = False
54 | # while not success:
55 | # try:
56 | # data = MDPDataBunch.from_env('maze-random-5x5-v0', render='rgb_array', bs=5, max_steps=50,
57 | # add_valid=False, feed_type=s_format)
58 | # model = create_dqn_model(data, model_cls, opt=torch.optim.RMSprop)
59 | # memory = mem(10000)
60 | # exploration_method = GreedyEpsilon(e_start=1, e_end=0.1, decay=0.001)
61 | # lnr = dqn_learner(data=data, model=model, memory=memory, exploration_method=exploration_method)
62 | # lnr.fit(1)
63 | #
64 | # interp = QValueInterpretation(lnr, ds_type=DatasetType.Train)
65 | # interp.plot_q()
66 | #
67 | # success = True
68 | # except Exception as e:
69 | # if not str(e).__contains__('Surface'):
70 | # raise Exception(e)
71 | #
72 | # #
73 | # # def test_groupagentinterpretation_from_pickle():
74 | # # group_interp = GroupAgentInterpretation.from_pickle('./data/cartpole_dqn',
75 | # # 'dqn_PriorityExperienceReplay_FEED_TYPE_STATE')
76 | # # group_interp.plot_reward_bounds(return_fig=True, per_episode=True, smooth_groups=5).show()
77 | # #
78 | # #
79 | # # def test_groupagentinterpretation_analysis():
80 | # # group_interp = GroupAgentInterpretation.from_pickle('./data/cartpole_dqn',
81 | # # 'dqn_PriorityExperienceReplay_FEED_TYPE_STATE')
82 | # # assert isinstance(group_interp.analysis, list)
83 | # # group_interp.in_notebook = True
84 | # # assert isinstance(group_interp.analysis, pd.DataFrame)
85 | #
86 | #
87 | #
88 | #
89 | # #
90 | # # def test_interpretation_reward_group_plot():
91 | # # group_interp = GroupAgentInterpretation()
92 | # # group_interp2 = GroupAgentInterpretation()
93 | # #
94 | # # for i in range(2):
95 | # # data = MDPDataBunch.from_env('CartPole-v0', render='rgb_array', bs=4, add_valid=False)
96 | # # model = DQN(data)
97 | # # learn = AgentLearner(data, model)
98 | # # learn.fit(2)
99 | # #
100 | # # interp = AgentInterpretation(learn=learn, ds_type=DatasetType.Train)
101 | # # interp.plot_rewards(cumulative=True, per_episode=True, group_name='run1')
102 | # # group_interp.add_interpretation(interp)
103 | # # group_interp2.add_interpretation(interp)
104 | # #
105 | # # group_interp.plot_reward_bounds(return_fig=True, per_episode=True).show()
106 | # # group_interp2.plot_reward_bounds(return_fig=True, per_episode=True).show()
107 | # #
108 | # # new_interp = group_interp.merge(group_interp2)
109 | # # assert len(new_interp.groups) == len(group_interp.groups) + len(group_interp2.groups), 'Lengths do not match'
110 | # # new_interp.plot_reward_bounds(return_fig=True, per_episode=True).show()
111 |
--------------------------------------------------------------------------------