├── fastrl
│ ├── envs
│ │ ├── __init__.py
│ │ ├── debug_env.py
│ │ └── continuous_debug_env.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── dqn
│ │ │ ├── __init__.py
│ │ │ ├── dueling.py
│ │ │ ├── rainbow.py
│ │ │ ├── double.py
│ │ │ └── target.py
│ │ ├── discrete.py
│ │ └── core.py
│ ├── funcs
│ │ ├── __init__.py
│ │ └── conjugation.py
│ ├── learner
│ │ ├── __init__.py
│ │ └── core.py
│ ├── loggers
│ │ ├── __init__.py
│ │ ├── tensorboard.py
│ │ ├── jupyter_visualizers.py
│ │ └── vscode_visualizers.py
│ ├── memory
│ │ ├── __init__.py
│ │ ├── experience_replay.py
│ │ └── memory_visualizer.py
│ ├── pipes
│ │ ├── __init__.py
│ │ ├── map
│ │ │ ├── __init__.py
│ │ │ ├── transforms.py
│ │ │ ├── mux.py
│ │ │ └── demux.py
│ │ ├── iter
│ │ │ ├── __init__.py
│ │ │ ├── firstlast.py
│ │ │ ├── nskip.py
│ │ │ ├── nstep.py
│ │ │ └── cacheholder.py
│ │ └── core.py
│ ├── dataloading
│ │ ├── __init__.py
│ │ └── core.py
│ ├── __init__.py
│ ├── nbdev_extensions.py
│ ├── test_utils.py
│ └── cli.py
├── extra
│ ├── requirements.txt
│ ├── pip_requirements.txt
│ └── dev_requirements.txt
├── nbs
│ ├── .gitignore
│ ├── 11_FAQ
│ │ ├── header_dates.json
│ │ ├── 99_template.ipynb
│ │ ├── 99_notes.multi_proc.ipynb
│ │ └── 99_notes.speed.ipynb
│ ├── images
│ │ ├── cat.jpg
│ │ ├── half.png
│ │ ├── mnist3.png
│ │ ├── puppy.jpg
│ │ ├── sample.dcm
│ │ ├── ulmfit.png
│ │ ├── favicon.ico
│ │ ├── layered.png
│ │ ├── pets
│ │ │ ├── cat.jpg
│ │ │ └── puppy.jpg
│ │ ├── att_00000.png
│ │ ├── att_00001.png
│ │ ├── att_00002.png
│ │ ├── hf_hub_fastai.png
│ │ ├── hf_model_card.png
│ │ ├── pixelshuffle.png
│ │ ├── Mixed_precision.jpeg
│ │ ├── half_representation.png
│ │ ├── siim_folder_structure.jpeg
│ │ ├── 10e_agents.dqn.categorical_algorithm1.png
│ │ ├── (Lillicrap et al., 2016) DDPG Algorithm 1.png
│ │ └── (Schulman et al., 2017) [TRPO] Trust Region Policy Optimization Algorithm 1.png
│ ├── 01_DataPipes
│ │ └── sidebar.yml
│ ├── testing
│ │ ├── fastrl_test_env.yaml
│ │ ├── test_settings.ini
│ │ └── fastrl_test_dev_env.yaml
│ ├── nbdev.yml
│ ├── _quarto.yml
│ ├── styles.css
│ ├── external_run_scripts
│ │ ├── spawn_multiproc.py
│ │ ├── agents_dqn_async_35.py
│ │ └── notes_multi_proc_82.py
│ ├── 12_Blog
│ │ ├── 99_blog.from_xxxx_xx_to_xx.ipynb
│ │ └── 99_blog.from_2023_05_to_now.ipynb
│ ├── 00_template.ipynb
│ ├── sidebar.yml
│ ├── 05_Logging
│ │ ├── 09e_loggers.tensorboard.ipynb
│ │ ├── 09d_loggers.jupyter_visualizers.ipynb
│ │ └── 09f_loggers.vscode_visualizers.ipynb
│ ├── README.md
│ ├── 07_Agents
│ │ └── 01_Discrete
│ │ │ └── 12n_agents.dqn.dueling.ipynb
│ └── 02_DataLoading
│ │ └── 00_core.ipynb
├── .gitmodules
├── MANIFEST.in
├── fastrl.code-workspace
├── .pre-commit-config.yaml
├── Makefile
├── .devcontainer.json
├── docker-compose.yml
├── .github
│ └── workflows
│ │ ├── fastrl-test.yml
│ │ ├── quarto_deploy.yml
│ │ ├── python-publish.yml
│ │ └── fastrl-docker.yml
├── settings.ini
├── .gitignore
├── CONTRIBUTING.md
├── setup.py
├── README.md
└── fastrl.Dockerfile
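
The tree above maps one-to-one onto import paths. As a quick orientation sketch (each of these symbols is confirmed by the `__all__` lists in the file contents below):

    from fastrl.agents.dqn.dueling import DuelingHead                # fastrl/agents/dqn/dueling.py
    from fastrl.dataloading.core import dataloaders                  # fastrl/dataloading/core.py
    from fastrl.envs.continuous_debug_env import ContinuousDebugEnv  # fastrl/envs/continuous_debug_env.py
    from fastrl.loggers.tensorboard import run_tensorboard           # fastrl/loggers/tensorboard.py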
/fastrl/envs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fastrl/agents/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fastrl/funcs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fastrl/learner/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fastrl/loggers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fastrl/memory/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fastrl/pipes/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fastrl/pipes/map/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fastrl/agents/dqn/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fastrl/dataloading/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fastrl/pipes/iter/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/extra/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.0.0
--------------------------------------------------------------------------------
/nbs/.gitignore:
--------------------------------------------------------------------------------
1 | /.quarto/
2 | /_docs/
3 |
--------------------------------------------------------------------------------
/fastrl/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.47"
2 |
--------------------------------------------------------------------------------
/nbs/11_FAQ/header_dates.json:
--------------------------------------------------------------------------------
1 | {"1": "2023-05-29"}
--------------------------------------------------------------------------------
/nbs/images/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/cat.jpg
--------------------------------------------------------------------------------
/nbs/images/half.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/half.png
--------------------------------------------------------------------------------
/nbs/images/mnist3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/mnist3.png
--------------------------------------------------------------------------------
/nbs/images/puppy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/puppy.jpg
--------------------------------------------------------------------------------
/nbs/images/sample.dcm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/sample.dcm
--------------------------------------------------------------------------------
/nbs/images/ulmfit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/ulmfit.png
--------------------------------------------------------------------------------
/nbs/images/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/favicon.ico
--------------------------------------------------------------------------------
/nbs/images/layered.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/layered.png
--------------------------------------------------------------------------------
/nbs/images/pets/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/pets/cat.jpg
--------------------------------------------------------------------------------
/nbs/images/att_00000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/att_00000.png
--------------------------------------------------------------------------------
/nbs/images/att_00001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/att_00001.png
--------------------------------------------------------------------------------
/nbs/images/att_00002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/att_00002.png
--------------------------------------------------------------------------------
/nbs/images/pets/puppy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/pets/puppy.jpg
--------------------------------------------------------------------------------
/nbs/images/hf_hub_fastai.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/hf_hub_fastai.png
--------------------------------------------------------------------------------
/nbs/images/hf_model_card.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/hf_model_card.png
--------------------------------------------------------------------------------
/nbs/images/pixelshuffle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/pixelshuffle.png
--------------------------------------------------------------------------------
/nbs/images/Mixed_precision.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/Mixed_precision.jpeg
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "data"]
2 | path = data
3 | url = https://github.com/josiahls/data.git
4 | ignore = dirty
5 |
--------------------------------------------------------------------------------
/nbs/images/half_representation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/half_representation.png
--------------------------------------------------------------------------------
/nbs/images/siim_folder_structure.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/siim_folder_structure.jpeg
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include settings.ini
2 | include LICENSE
3 | include CONTRIBUTING.md
4 | include README.md
5 | recursive-exclude * __pycache__
6 |
--------------------------------------------------------------------------------
/fastrl.code-workspace:
--------------------------------------------------------------------------------
1 | {
2 |     "folders": [
3 |         {
4 |             "path": "."
5 |         }
6 |     ],
7 |     "settings": {
8 |         "editor.rulers": [80,120]
9 |     }
10 | }
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/fastai/nbdev
3 | rev: 2.2.10
4 | hooks:
5 | - id: nbdev_clean
6 | - id: nbdev_export
--------------------------------------------------------------------------------
/nbs/images/10e_agents.dqn.categorical_algorithm1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/10e_agents.dqn.categorical_algorithm1.png
--------------------------------------------------------------------------------
/nbs/images/(Lillicrap et al., 2016) DDPG Algorithm 1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/(Lillicrap et al., 2016) DDPG Algorithm 1.png
--------------------------------------------------------------------------------
/extra/pip_requirements.txt:
--------------------------------------------------------------------------------
1 | gymnasium>=0.27.1
2 | tensordict
3 | pyopengl
4 | pyglet
5 | tensorboard
6 | pygame
7 | pandas
8 | scipy
9 | scikit-learn
10 | fastcore
11 | fastprogress
--------------------------------------------------------------------------------
/nbs/images/(Schulman et al., 2017) [TRPO] Trust Region Policy Optimization Algorithm 1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/(Schulman et al., 2017) [TRPO] Trust Region Policy Optimization Algorithm 1.png
--------------------------------------------------------------------------------
/nbs/01_DataPipes/sidebar.yml:
--------------------------------------------------------------------------------
1 | website:
2 | sidebar:
3 | contents:
4 | - 01a_pipes.core.ipynb
5 | - 01b_pipes.map.demux.ipynb
6 | - 01c_pipes.map.mux.ipynb
7 | - 01d_pipes.iter.nskip.ipynb
8 | - 01e_pipes.iter.nstep.ipynb
--------------------------------------------------------------------------------
/fastrl/funcs/conjugation.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/02_Funcs/00_funcs.conjugation.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = []
5 |
6 | # %% ../../nbs/02_Funcs/00_funcs.conjugation.ipynb 3
7 | # Python native modules
8 |
9 | # Third party libs
10 | import torch
11 | import numpy as np
12 | # Local modules
13 |
--------------------------------------------------------------------------------
/extra/dev_requirements.txt:
--------------------------------------------------------------------------------
1 | jupyter
2 | gymnasium[all]>=0.27.1
3 | nbdev>=2.3.1
4 | pre-commit
5 | ipywidgets
6 | moviepy
7 | pygifsicle
8 | aquirdturtle_collapsible_headings
9 | plotly
10 | matplotlib_inline
11 | wheel
12 | twine
13 | fastdownload
14 | watchdog[watchmedo]
15 | graphviz
16 | typing-extensions>=4.3.0
17 | spacy<4
18 | mypy
19 | pyvirtualdisplay
--------------------------------------------------------------------------------
/nbs/testing/fastrl_test_env.yaml:
--------------------------------------------------------------------------------
1 | channels:
2 | - conda-forge
3 | - pytorch
4 | - fastai
5 | dependencies:
6 | - python=3.6
7 | - pip
8 | - setuptools
9 | - fastai>=2.0.0
10 | - moviepy
11 | - jupyter
12 | - notebook
13 | - setuptools
14 | - pip:
15 | - pytest
16 | - nvidia-ml-py3
17 | - dataclasses
18 | - pandas
19 | - pyyaml
20 | name: fastrl_test
21 |
--------------------------------------------------------------------------------
/nbs/nbdev.yml:
--------------------------------------------------------------------------------
1 | project:
2 | output-dir: _docs
3 |
4 | website:
5 | title: "fastrl"
6 | site-url: "https://josiahls.github.io/fastrl/"
7 | description: "fastrl is a reinforcement learning library that extends Fastai. This project is not affiliated with fastai or Jeremy Howard."
8 | repo-branch: main
9 | repo-url: "https://github.com/josiahls/fastrl/tree/main/"
10 |
--------------------------------------------------------------------------------
/nbs/_quarto.yml:
--------------------------------------------------------------------------------
1 | project:
2 | type: website
3 |
4 | format:
5 | html:
6 | theme: cosmo
7 | css: styles.css
8 | toc: true
9 |
10 | website:
11 | twitter-card: true
12 | open-graph: true
13 | repo-actions: [issue]
14 | navbar:
15 | background: primary
16 | search: true
17 | sidebar:
18 | style: floating
19 |
20 | metadata-files: [nbdev.yml, sidebar.yml]
--------------------------------------------------------------------------------
/nbs/styles.css:
--------------------------------------------------------------------------------
1 | .cell-output pre {
2 | margin-left: 0.8rem;
3 | margin-top: 0;
4 | background: none;
5 | border-left: 2px solid lightsalmon;
6 | border-top-left-radius: 0;
7 | border-top-right-radius: 0;
8 | }
9 |
10 | .cell-output .sourceCode {
11 | background: none;
12 | margin-top: 0;
13 | }
14 |
15 | .cell > .sourceCode {
16 | margin-bottom: 0;
17 | }
--------------------------------------------------------------------------------
/nbs/testing/test_settings.ini:
--------------------------------------------------------------------------------
1 | [DEFAULT]
2 | lib_name = fastrl_test
3 | user = josiahls
4 | branch = master
5 | version = 0.0.1
6 | min_python = 3.6
7 | requirements = fastai>=2.0.0 moviepy
8 | pip_requirements = pytest nvidia-ml-py3 dataclasses pandas pyyaml
9 | conda_requirements = jupyter notebook setuptools
10 | dev_requirements = jupyterlab nbdev ipywidgets moviepy pygifsicle aquirdturtle_collapsible_headings
11 |
--------------------------------------------------------------------------------
/nbs/testing/fastrl_test_dev_env.yaml:
--------------------------------------------------------------------------------
1 | channels:
2 | - conda-forge
3 | - pytorch
4 | - fastai
5 | dependencies:
6 | - python=3.6
7 | - pip
8 | - setuptools
9 | - fastai>=2.0.0
10 | - moviepy
11 | - jupyter
12 | - notebook
13 | - setuptools
14 | - jupyterlab
15 | - nbdev
16 | - ipywidgets
17 | - moviepy
18 | - pygifsicle
19 | - aquirdturtle_collapsible_headings
20 | - pip:
21 | - pytest
22 | - nvidia-ml-py3
23 | - dataclasses
24 | - pandas
25 | - pyyaml
26 | name: fastrl_test_dev
27 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .ONESHELL:
2 | SHELL := /bin/bash
3 | SRC = $(wildcard nbs/*.ipynb)
4 |
5 | all: fastrl docs
6 |
7 | fastrl: $(SRC)
8 | nbdev_build_lib
9 | touch fastrl
10 |
11 | sync:
12 | nbdev_update_lib
13 |
14 | docs_serve: docs
15 | cd docs && bundle exec jekyll serve
16 |
17 | docs: $(SRC)
18 | nbdev_build_docs
19 | touch docs
20 |
21 | test:
22 | nbdev_test_nbs
23 |
24 | release: pypi conda_release
25 | nbdev_bump_version
26 |
27 | conda_release:
28 | fastrelease_conda_package
29 |
30 | pypi: dist
31 | twine upload --repository pypi dist/*
32 |
33 | dist: clean
34 | python setup.py sdist bdist_wheel
35 |
36 | clean:
37 | rm -rf dist
--------------------------------------------------------------------------------
/.devcontainer.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "fastrl_development",
3 | "dockerComposeFile": "docker-compose.yml",
4 | "service": "fastrl",
5 | "customizations":{
6 | "vscode": {
7 | "settings": {"terminal.integrated.shell.linux": "/bin/bash"}
8 | }
9 | },
10 | // "mounts": [ "source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind",
11 | // "source=/usr/bin/docker,target=/usr/bin/docker,type=bind" ],
12 | "workspaceFolder": "/home/fastrl_user/fastrl",
13 | // "forwardPorts": [4000, 8080],
14 | // "appPort": [4000, 8080],
15 | // "extensions": [
16 | // "ms-python.python",
17 | // "ms-azuretools.vscode-docker",
18 | // "ms-toolsai.jupyter-renderers"
19 | // ],
20 | "runServices": ["dep_watcher", "quarto"] //,
21 | // "postStartCommand": "pip install -e .[dev]"
22 | }
--------------------------------------------------------------------------------
/nbs/external_run_scripts/spawn_multiproc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchdata.datapipes as dp
3 | from torchdata.dataloader2 import DataLoader2,MultiProcessingReadingService
4 |
5 | class PointlessLoop(dp.iter.IterDataPipe):
6 | def __init__(self,datapipe=None):
7 | self.datapipe = datapipe
8 |
9 | def __iter__(self):
10 | while True:
11 | yield torch.LongTensor(4).detach().clone()
12 |
13 |
14 | if __name__=='__main__':
15 | from torch.multiprocessing import Pool, Process, set_start_method
16 | try:
17 | set_start_method('spawn')
18 | except RuntimeError:
19 | pass
20 |
21 |
22 | pipe = PointlessLoop()
23 | pipe = pipe.header(limit=10)
24 | dls = [DataLoader2(pipe,
25 | reading_service=MultiProcessingReadingService(
26 | num_workers = 2
27 | ))]
28 | # Setup the Learner
29 | print('type: ',type(dls[0]))
30 | for o in dls[0]:
31 | print(o)
32 |
--------------------------------------------------------------------------------
/nbs/12_Blog/99_blog.from_xxxx_xx_to_xx.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "assisted-contract",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "# Python native modules\n",
11 | "# Third party libs\n",
12 | "\n",
13 | "# Local modules\n",
14 | "from fastrl.nbdev_extensions import header"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "id": "40e740d6",
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "#|echo: false\n",
25 | "header(\n",
26 | " \"Test run\",\n",
27 | " \"subtitle\",\n",
28 | " freeze=False\n",
29 | ")"
30 | ]
31 | },
32 | {
33 | "attachments": {},
34 | "cell_type": "markdown",
35 | "id": "fec5e005",
36 | "metadata": {},
37 | "source": []
38 | }
39 | ],
40 | "metadata": {
41 | "kernelspec": {
42 | "display_name": "python3",
43 | "language": "python",
44 | "name": "python3"
45 | }
46 | },
47 | "nbformat": 4,
48 | "nbformat_minor": 5
49 | }
50 |
--------------------------------------------------------------------------------
/fastrl/agents/dqn/dueling.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['DuelingHead']
5 |
6 | # %% ../../../nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb 2
7 | # Python native modules
8 | # Third party libs
9 | import torch
10 | from torch import nn
11 | # Local modules
12 | from fastrl.agents.dqn.basic import (
13 | DQN,
14 | DQNAgent
15 | )
16 | from .target import DQNTargetLearner
17 |
18 | # %% ../../../nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb 5
19 | class DuelingHead(nn.Module):
20 | def __init__(
21 | self,
22 | hidden: int, # Input into the DuelingHead, likely a hidden layer input
23 | n_actions: int, # Number/dim of actions to output
24 | lin_cls = nn.Linear
25 | ):
26 | super().__init__()
27 | self.val = lin_cls(hidden,1)
28 | self.adv = lin_cls(hidden,n_actions)
29 |
30 | def forward(self,xi):
31 | val,adv = self.val(xi),self.adv(xi)
32 |         xi = val.expand_as(adv)+(adv-adv.mean(dim=-1,keepdim=True)) # Q = V + (A - mean(A)); mean over the action dim so batches aggregate correctly
33 | return xi
34 |
--------------------------------------------------------------------------------
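
A minimal usage sketch for `DuelingHead`; the hidden size, batch size, and action count here are arbitrary choices, not values from the source:

    import torch
    from fastrl.agents.dqn.dueling import DuelingHead

    head = DuelingHead(hidden=64, n_actions=2)  # value head -> (B,1), advantage head -> (B,2)
    x = torch.randn(8, 64)                      # a batch of 8 hidden activations
    q = head(x)                                 # Q(s,a) = V(s) + (A(s,a) - mean_a A(s,a))
    assert q.shape == (8, 2)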
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: "3.8"
2 | services:
3 | fastrl_build:
4 | build:
5 | dockerfile: fastrl.Dockerfile
6 | context: .
7 | image: josiahls/fastrl-dev:latest
8 | profiles: ["build"]
9 |
10 | fastrl: &fastrl
11 | restart: unless-stopped
12 | working_dir: /home/fastrl_user/fastrl
13 | image: josiahls/fastrl-dev:latest
14 | deploy:
15 | resources:
16 | reservations:
17 | devices:
18 | - capabilities:
19 | - gpu
20 | logging:
21 | driver: json-file
22 | options:
23 | max-size: 50m
24 | stdin_open: true
25 | tty: true
26 | shm_size: 16000000000
27 | volumes:
28 | - .:/home/fastrl_user/fastrl/
29 | - ~/.ssh:/home/fastrl_user/.ssh:rw
30 | network_mode: host # for GitHub Codespaces https://github.com/features/codespaces/
31 |
32 | dep_watcher:
33 | <<: *fastrl
34 | command: watchmedo shell-command --command fastrl_make_requirements --pattern *.ini --recursive --drop
35 |
36 | quarto:
37 | <<: *fastrl
38 | restart: unless-stopped
39 | working_dir: /home/fastrl_user/fastrl/nbs/_docs
40 | volumes:
41 | - .:/home/fastrl_user/fastrl/
42 | command: nbdev_preview
43 |
--------------------------------------------------------------------------------
/nbs/11_FAQ/99_template.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "durable-dialogue",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "#|hide\n",
11 | "from fastrl.test_utils import initialize_notebook\n",
12 | "initialize_notebook()"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "id": "offshore-stuart",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "# #|default_exp template"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "id": "assisted-contract",
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "# #|export\n",
33 | "# Python native modules\n",
34 | "import os\n",
35 | "# Third party libs\n",
36 | "from fastcore.all import *\n",
37 | "# Local modules"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "id": "40e740d6",
43 | "metadata": {},
44 | "source": [
45 | "# Template\n",
46 | "> notebook"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": null,
52 | "id": "current-pilot",
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "#|hide\n",
57 | "#|eval: false\n",
58 | "!nbdev_export"
59 | ]
60 | }
61 | ],
62 | "metadata": {
63 | "kernelspec": {
64 | "display_name": "python3",
65 | "language": "python",
66 | "name": "python3"
67 | }
68 | },
69 | "nbformat": 4,
70 | "nbformat_minor": 5
71 | }
72 |
--------------------------------------------------------------------------------
/fastrl/loggers/tensorboard.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/05_Logging/09e_loggers.tensorboard.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['run_tensorboard']
5 |
6 | # %% ../../nbs/05_Logging/09e_loggers.tensorboard.ipynb 1
7 | # Python native modules
8 | import os
9 | from pathlib import Path
10 | # Third party libs
11 | import torchdata.datapipes as dp
12 | # Local modules
13 |
14 | # %% ../../nbs/05_Logging/09e_loggers.tensorboard.ipynb 4
15 | def run_tensorboard(
16 | port:int=6006, # The port to run tensorboard on/connect on
17 | start_tag:str=None, # Starting regex e.g.: experience_replay/1
18 | samples_per_plugin:str=None, # Sampling freq such as images=0 (keep all)
19 | extra_args:str=None, # Any additional arguments in the `--arg value` format
20 |     rm_glob:str=None # Remove old logs matching a glob pattern, e.g.: '*' will remove all files: runs/*
21 | ):
22 | if rm_glob is not None:
23 |         for p in Path('runs').glob(rm_glob): p.unlink() # plain pathlib unlink; the glob is expected to match files
24 | import socket
25 | from tensorboard import notebook
26 | a_socket=socket.socket(socket.AF_INET, socket.SOCK_STREAM)
27 | cmd=None
28 |     if not a_socket.connect_ex(('127.0.0.1',port)): # connect_ex returns 0 when something is already serving on `port`
29 | notebook.display(port=port,height=1000)
30 | else:
31 | cmd=f'--logdir runs --port {port} --host=0.0.0.0'
32 | if samples_per_plugin is not None: cmd+=f' --samples_per_plugin {samples_per_plugin}'
33 | if start_tag is not None: cmd+=f' --tag {start_tag}'
34 | if extra_args is not None: cmd+=f' {extra_args}'
35 | notebook.start(cmd)
36 | return cmd
37 |
--------------------------------------------------------------------------------
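
A minimal usage sketch for `run_tensorboard`, intended for use inside a notebook and assuming logs have been written under `runs/` (the argument values are illustrative):

    from fastrl.loggers.tensorboard import run_tensorboard

    # Attaches to an existing server on the port if one is running, otherwise starts one.
    run_tensorboard(port=6006, samples_per_plugin='images=0')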
/nbs/00_template.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "durable-dialogue",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "#|hide\n",
11 | "from fastrl.test_utils import initialize_notebook\n",
12 | "initialize_notebook()"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "id": "offshore-stuart",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "# #|default_exp template"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "id": "assisted-contract",
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "# #|export\n",
33 | "# Python native modules\n",
34 | "\n",
35 | "# Third party libs\n",
36 | "\n",
37 | "# Local modules"
38 | ]
39 | },
40 | {
41 | "attachments": {},
42 | "cell_type": "markdown",
43 | "id": "a258abcf",
44 | "metadata": {},
45 | "source": [
46 | "# Template"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": null,
52 | "id": "current-pilot",
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "#|hide\n",
57 | "#|eval: false\n",
58 | "!nbdev_export"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "id": "ed71a089",
65 | "metadata": {},
66 | "outputs": [],
67 | "source": []
68 | }
69 | ],
70 | "metadata": {
71 | "kernelspec": {
72 | "display_name": "python3",
73 | "language": "python",
74 | "name": "python3"
75 | }
76 | },
77 | "nbformat": 4,
78 | "nbformat_minor": 5
79 | }
80 |
--------------------------------------------------------------------------------
/fastrl/pipes/map/transforms.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/01_DataPipes/01h_pipes.map.transforms.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['T_co', 'TypeTransformer']
5 |
6 | # %% ../../../nbs/01_DataPipes/01h_pipes.map.transforms.ipynb 3
7 | # Python native modules
8 | from typing import Callable,Union,TypeVar
9 | # Third party libs
10 | from fastcore.all import *
11 | import torchdata.datapipes as dp
12 | from torchdata.datapipes.map import MapDataPipe
13 | from torchdata.dataloader2.graph import find_dps,DataPipeGraph,Type,DataPipe
14 | # Local modules
15 |
16 | # %% ../../../nbs/01_DataPipes/01h_pipes.map.transforms.ipynb 5
17 | T_co = TypeVar("T_co", covariant=True)
18 |
19 | class TypeTransformer(dp.map.MapDataPipe):
20 | def __init__(
21 | self,
22 |         # Should allow `__getitem__` and produce elements to be ingested by `type_tfms`
23 | source_datapipe:MapDataPipe[T_co],
24 | # A list of Callables that accept an input, and return an output
25 | type_tfms:List[Callable]
26 | ) -> None:
27 | self.type_tfms:Pipeline[Callable] = Pipeline(type_tfms)
28 | self.source_datapipe:MapDataPipe[T_co] = source_datapipe
29 |
30 | def __getitem__(self, index) -> T_co:
31 | data = self.source_datapipe[index]
32 | return self.type_tfms(data)
33 |
34 | def __len__(self) -> int: return len(self.source_datapipe)
35 |
36 | TypeTransformer.__doc__ = """On `__getitem__`, the functions in `self.type_tfms` get called over each element.
37 | Generally, `TypeTransformer`, as the name suggests, is intended to convert elements from one type to another.
38 | See the reference documentation on how to combine this with `InMemoryCacheHolder`."""
39 |
--------------------------------------------------------------------------------
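
A minimal usage sketch for `TypeTransformer`: wrap a map-style datapipe and convert each element on access. `SequenceWrapper` is just a convenient source here:

    import torch
    import torchdata.datapipes as dp
    from fastrl.pipes.map.transforms import TypeTransformer

    source = dp.map.SequenceWrapper([0, 1, 2, 3])
    pipe = TypeTransformer(source, type_tfms=[torch.tensor])
    print(pipe[2], len(pipe))  # tensor(2) 4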
/.github/workflows/fastrl-test.yml:
--------------------------------------------------------------------------------
1 | name: Fastrl Testing
2 | on: [push, pull_request]
3 |
4 | # env:
5 | # PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/nightly/cu113
6 |
7 | jobs:
8 | test:
9 | runs-on: ubuntu-latest
10 | container:
11 | image: 'josiahls/fastrl-dev:latest'
12 |
13 | steps:
14 | - uses: actions/checkout@v1
15 | # - uses: actions/setup-python@v1
16 | # with:
17 | # python-version: '3.7'
18 | # architecture: 'x64'
19 | - name: Install the library
20 | run: |
21 | sudo mkdir -p /github/home
22 | sudo pip install -e .["dev"]
23 | # - name: Read all notebooks
24 | # run: |
25 | # nbdev_read_nbs
26 | - name: Check if all notebooks are cleaned
27 | run: |
28 | sudo git config --global --add safe.directory /__w/fastrl/fastrl
29 | echo "Check we are starting with clean git checkout"
30 | if [ -n "$(git status -uno -s)" ]; then echo "git status is not clean"; false; fi
31 | echo "Trying to strip out notebooks"
32 | sudo nbdev_clean
33 | echo "Check that strip out was unnecessary"
34 | git status -s # display the status to see which nbs need cleaning up
35 |         if [ -n "$(git status -uno -s)" ]; then echo -e "!!! Detected unstripped notebooks\n!!! Remember to run nbdev_install_hooks"; false; fi
36 |
37 | # - name: Check if there is no diff library/notebooks
38 | # run: |
39 | # if [ -n "$(nbdev_diff_nbs)" ]; then echo -e "!!! Detected difference between the notebooks and the library"; false; fi
40 | - name: Run tests
41 | run: |
42 | pip3 show torchdata torch
43 | sudo pip3 install -e .
44 | cd nbs
45 | xvfb-run -s "-screen 0 1400x900x24" fastrl_nbdev_test --n_workers 12 --one2one
46 | - name: Run Doc Build Test
47 | run: |
48 | fastrl_nbdev_docs --one2one
49 |
--------------------------------------------------------------------------------
/fastrl/dataloading/core.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/02_DataLoading/00_core.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['dataloaders']
5 |
6 | # %% ../../nbs/02_DataLoading/00_core.ipynb 2
7 | # Python native modules
8 | from typing import Tuple,Union,List
9 | # Third party libs
10 | import torchdata.datapipes as dp
11 | from torchdata.dataloader2 import MultiProcessingReadingService,DataLoader2
12 | from fastcore.all import delegates
13 | # Local modules
14 |
15 | # %% ../../nbs/02_DataLoading/00_core.ipynb 4
16 | @delegates(MultiProcessingReadingService)
17 | def dataloaders(
18 | # A tuple of iterable datapipes to generate dataloaders from.
19 | pipes:Union[Tuple[dp.iter.IterDataPipe],dp.iter.IterDataPipe],
20 | # Concat the dataloaders together
21 | do_concat:bool = False,
22 | # Multiplex the dataloaders
23 | do_multiplex:bool = False,
24 | # Number of workers the dataloaders should run in
25 | num_workers: int = 0,
26 | **kwargs
27 | ) -> Union[dp.iter.IterDataPipe,List[dp.iter.IterDataPipe]]:
28 |     "Function that creates dataloaders based on `pipes` with different ways of combining them."
29 | if not isinstance(pipes,tuple):
30 | pipes = (pipes,)
31 |
32 | dls = []
33 | for pipe in pipes:
34 | dl = DataLoader2(
35 | datapipe=pipe,
36 | reading_service=MultiProcessingReadingService(
37 | num_workers = num_workers,
38 | **kwargs
39 | ) if num_workers > 0 else None
40 | )
41 | dl = dp.iter.IterableWrapper(dl,deepcopy=False)
42 | dls.append(dl)
43 |     #TODO(josiahls): Not sure if this is needed tbh.. Might be better to just
44 |     # return dls, and have the user wrap them if they want. Then they can do more complex stuff.
45 | if do_concat:
46 | return dp.iter.Concater(*dls)
47 | elif do_multiplex:
48 | return dp.iter.Multiplexer(*dls)
49 | else:
50 | return dls
51 |
52 |
--------------------------------------------------------------------------------
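
A minimal usage sketch for `dataloaders` with a single pipe and no worker processes (so no reading service is attached):

    import torchdata.datapipes as dp
    from fastrl.dataloading.core import dataloaders

    pipe = dp.iter.IterableWrapper(range(10)).batch(4)
    dls = dataloaders(pipe, num_workers=0)  # a list with one wrapped DataLoader2
    for batch in dls[0]:
        print(batch)  # [0, 1, 2, 3], [4, 5, 6, 7], [8, 9]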
/settings.ini:
--------------------------------------------------------------------------------
1 | [DEFAULT]
2 | lib_name = fastrl
3 | description = fastrl is a reinforcement learning library that extends Fastai. This project is not affiliated with fastai or Jeremy Howard.
4 | copyright = 2021 onwards, Josiah Laivins
5 | keywords = fastrl reinforcement learning rl robotics fastai deep learning machine learning
6 | user = josiahls
7 | author = Josiah Laivins, and contributors
8 | author_email =
9 | branch = main
10 | min_python = 3.7
11 | version = 0.0.47
12 | audience = Developers
13 | language = English
14 | custom_sidebar = True
15 | license = apache2
16 | status = 2
17 | dep_links = https://download.pytorch.org/whl/nightly/cu117
18 | requirements = torch>=2.0.0
19 | data_requirements = git+https://github.com/josiahls/data.git@main#egg=torchdata
20 | pip_requirements = gymnasium>=0.27.1 tensordict pyopengl pyglet tensorboard pygame pandas scipy scikit-learn fastcore fastprogress
21 | #nbformat>=4.2.0,<5.*
22 | conda_requirements =
23 | dev_requirements = jupyter gymnasium[all]>=0.27.1 nbdev>=2.3.1 pre-commit ipywidgets moviepy pygifsicle aquirdturtle_collapsible_headings plotly matplotlib_inline wheel twine fastdownload watchdog[watchmedo] graphviz typing-extensions>=4.3.0 spacy<4 mypy pyvirtualdisplay
24 | console_scripts = fastrl_make_requirements=fastrl.cli:fastrl_make_requirements
25 | fastrl_nbdev_docs=fastrl.cli:fastrl_nbdev_docs
26 | fastrl_nbdev_test=fastrl.cli:fastrl_nbdev_test
27 | fastrl_nbdev_create_blog_notebook=fastrl.nbdev_extensions:create_blog_notebook
28 | tst_flags = slow cpp cuda multicuda
29 | nbs_path = nbs
30 | doc_path = _docs
31 | recursive = True
32 | doc_host = https://josiahls.github.io
33 | doc_baseurl = /fastrl/
34 | git_url = https://github.com/josiahls/fastrl/tree/main/
35 | lib_path = fastrl
36 | title = fastrl
37 | black_formatting = False
38 | readme_nb = index.ipynb
39 | allowed_metadata_keys =
40 | allowed_cell_metadata_keys =
41 | jupyter_hooks = True
42 | clean_ids = True
43 | clear_all = False
44 | put_version_in_init = True
45 |
46 |
--------------------------------------------------------------------------------
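
The `console_scripts` entries above are what provide the `fastrl_*` shell commands used elsewhere in this repo (CI calls `fastrl_nbdev_test` and `fastrl_nbdev_docs`; the `dep_watcher` service runs `fastrl_make_requirements`). A minimal sketch of the equivalence; only the module:function mapping comes from settings.ini:

    # `fastrl_nbdev_test ...` on the shell resolves to this callable:
    from fastrl.cli import fastrl_nbdev_test
    from fastrl.cli import fastrl_make_requirements  # target of the dep_watcher service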
/.github/workflows/quarto_deploy.yml:
--------------------------------------------------------------------------------
1 | name: Deploy to GitHub Pages
2 | on:
3 | push:
4 | branches: [ "main", "master", "update_to_nbdev_2", "*docs" ]
5 | workflow_dispatch:
6 |
7 | env:
8 | PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/nightly/cu113
9 |
10 | jobs:
11 | deploy:
12 | runs-on: ubuntu-latest
13 | container:
14 | image: 'josiahls/fastrl-dev:latest'
15 |
16 | steps:
17 | - uses: actions/checkout@v3
18 | - name: Install Dependencies
19 | shell: bash
20 | run: |
21 | python -m pip install --upgrade pip
22 | pip install -e ".[dev]" --upgrade
23 | fastrl_nbdev_docs --one2one
24 | # - uses: fastai/workflows/quarto-ghp@master
25 | - name: Enable GitHub Pages
26 | shell: python
27 | run: |
28 | import ghapi.core,nbdev.config,sys
29 | msg="Please enable GitHub Pages to publish from the root of the `gh-pages` branch per these instructions - https://docs.github.com/en/pages/getting-started-with-github-pages/configuring-a-publishing-source-for-your-github-pages-site#publishing-from-a-branch"
30 | try:
31 | api = ghapi.core.GhApi(owner=nbdev.config.get_config().user, repo=nbdev.config.get_config().repo, token="${{secrets.GITHUB_TOKEN}}")
32 | api.enable_pages(branch='gh-pages')
33 | except Exception as e:
34 | print(f'::error title="Could not enable GitHub Pages Automatically":: {msg}\n{e}')
35 | sys.exit(1)
36 | - name: Deploy to GitHub Pages
37 | uses: peaceiris/actions-gh-pages@v3
38 | with:
39 | github_token: ${{secrets.GITHUB_TOKEN}}
40 | force_orphan: true
41 | publish_dir: ./_docs
42 | # The following lines assign commit authorship to the official GH-Actions bot for deploys to `gh-pages` branch.
43 | # You can swap them out with your own user credentials.
44 | user_name: github-actions[bot]
45 | user_email: 41898282+github-actions[bot]@users.noreply.github.com
46 |
47 |
--------------------------------------------------------------------------------
/nbs/external_run_scripts/agents_dqn_async_35.py:
--------------------------------------------------------------------------------
1 | # %%python
2 |
3 | if __name__=='__main__':
4 | from torch.multiprocessing import Pool, Process, set_start_method
5 |
6 | try:
7 | set_start_method('spawn')
8 | except RuntimeError:
9 | pass
10 |
11 | from fastcore.all import *
12 | import torch
13 | from torch.nn import *
14 | import torch.nn.functional as F
15 | from fastrl.loggers.core import *
16 | from fastrl.loggers.jupyter_visualizers import *
17 | from fastrl.learner.core import *
18 | from fastrl.data.block import *
19 | from fastrl.envs.gym import *
20 | from fastrl.agents.core import *
21 | from fastrl.agents.discrete import *
22 | from fastrl.agents.dqn.basic import *
23 | from fastrl.agents.dqn.asynchronous import *
24 |
25 | from torchdata.dataloader2 import DataLoader2
26 | from torchdata.dataloader2.graph import traverse
27 | from fastrl.data.dataloader2 import *
28 |
29 | logger_base = ProgressBarLogger(epoch_on_pipe=EpocherCollector,
30 | batch_on_pipe=BatchCollector)
31 |
32 | # Setup up the core NN
33 | torch.manual_seed(0)
34 | model = DQN(4,2).cuda()
35 | # model.share_memory() # This will not work in spawn
36 | # Setup the Agent
37 | agent = DQNAgent(model,max_steps=4000,device='cuda',
38 | dp_augmentation_fns=[ModelSubscriber.insert_dp()])
39 | # Setup the DataBlock
40 | block = DataBlock(
41 | GymTransformBlock(agent=agent,nsteps=1,nskips=1,firstlast=False),
42 | GymTransformBlock(agent=agent,nsteps=1,nskips=1,firstlast=False,include_images=True)
43 | )
44 | dls = L(block.dataloaders(['CartPole-v1']*1,num_workers=1))
45 | # # Setup the Learner
46 | learner = DQNLearner(model,dls,batches=1000,logger_bases=[logger_base],bs=128,max_sz=100_000,device='cuda',
47 | dp_augmentation_fns=[ModelPublisher.insert_dp(publish_freq=10)])
48 | # print(traverse(learner))
49 | learner.fit(20)
50 |
--------------------------------------------------------------------------------
/nbs/external_run_scripts/notes_multi_proc_82.py:
--------------------------------------------------------------------------------
1 | import torchdata.datapipes as dp
2 | from torch.utils.data import IterableDataset
3 |
4 | class AddABunch1(dp.iter.IterDataPipe):
5 | def __init__(self,q):
6 | super().__init__()
7 | self.q = [q]
8 |
9 | def __iter__(self):
10 | for o in range(10):
11 | self.q[0].put(o)
12 | yield o
13 |
14 | class AddABunch2(dp.iter.IterDataPipe):
15 | def __init__(self,source_datapipe,q):
16 | super().__init__()
17 | self.q = q
18 | print(id(self.q))
19 | self.source_datapipe = source_datapipe
20 |
21 | def __iter__(self):
22 | for o in self.source_datapipe:
23 | print(id(self.q))
24 | self.q.put(o)
25 | yield o
26 |
27 | class AddABunch3(IterableDataset):
28 | def __init__(self,q):
29 | self.q = q
30 |
31 | def __iter__(self):
32 | for o in range(10):
33 | print(id(self.q))
34 | self.q.put(o)
35 | yield o
36 |
37 | if __name__=='__main__':
38 | from torch.multiprocessing import Pool,Process,set_start_method,Manager,get_start_method
39 | import torch
40 |
41 | try: set_start_method('spawn')
42 | except RuntimeError: pass
43 | # from torch.utils.data.dataloader_experimental import DataLoader2
44 | from torchdata.dataloader2 import DataLoader2
45 | from torchdata.dataloader2.reading_service import MultiProcessingReadingService
46 |
47 | m = Manager()
48 | q = m.Queue()
49 |
50 | pipe = AddABunch2(list(range(10)),q)
51 | print(type(pipe))
52 | dl = DataLoader2(pipe,
53 | reading_service=MultiProcessingReadingService(num_workers=1)
54 | ) # Will fail if num_workers>0
55 |
56 | # dl = DataLoader2(AddABunch1(q),num_workers=1) # Will fail if num_workers>0
57 | # dl = DataLoader2(AddABunch2(q),num_workers=1) # Will fail if num_workers>0
58 | # dl = DataLoader2(AddABunch3(q),num_workers=1) # Will succeed if num_workers>0
59 | list(dl)
60 |
61 | while not q.empty():
62 | print(q.get())
63 |
--------------------------------------------------------------------------------
/fastrl/envs/debug_env.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/03_Environment/06_envs.debug_env.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['SimpleContinuousEnv']
5 |
6 | # %% ../../nbs/03_Environment/06_envs.debug_env.ipynb 2
7 | # Python native modules
8 | import os
9 | # Third party libs
10 | import gymnasium as gym
11 | from gymnasium import spaces
12 | import numpy as np
13 | # Local modules
14 |
15 | # %% ../../nbs/03_Environment/06_envs.debug_env.ipynb 4
16 | class SimpleContinuousEnv(gym.Env):
17 |     metadata = {'render_modes': ['console']}
18 |
19 | def __init__(self, goal_position=None, proximity_threshold=0.5):
20 | super(SimpleContinuousEnv, self).__init__()
21 |
22 | self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
23 | self.observation_space = spaces.Box(low=-10, high=10, shape=(1,), dtype=np.float32)
24 |
25 | self.goal_position = goal_position if goal_position is not None else np.random.uniform(-10, 10)
26 | self.proximity_threshold = proximity_threshold
27 | self.state = None
28 |
29 | def step(self, action):
30 | self.state += action
31 |
32 | # Calculate the distance to the goal
33 | distance_to_goal = np.abs(self.state - self.goal_position)
34 |
35 | # Calculate reward: higher for being closer to the goal
36 |         reward = -distance_to_goal.item() # Ensure reward is a float, not a 1-element array
37 |
38 |         # Check if the agent is within the proximity threshold of the goal
39 |         done = bool(distance_to_goal.item() <= self.proximity_threshold) # Ensure done is a boolean
40 |
41 | info = {}
42 |
43 | return self.state, reward, done, info
44 |
45 | def reset(self):
46 | self.state = np.array([0.0], dtype=np.float32)
47 | if self.goal_position is None:
48 | self.goal_position = np.random.uniform(-10, 10)
49 | return self.state
50 |
51 |
52 | def render(self, mode='console'):
53 | if mode != 'console':
54 | raise NotImplementedError("Only console mode is supported.")
55 | print(f"Position: {self.state} Goal: {self.goal_position}")
56 |
57 |
58 |
--------------------------------------------------------------------------------
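
A minimal rollout sketch for `SimpleContinuousEnv`. Note it uses the old 4-tuple `step` contract, unlike `ContinuousDebugEnv` later in this dump:

    import numpy as np
    from fastrl.envs.debug_env import SimpleContinuousEnv

    env = SimpleContinuousEnv(goal_position=2.0)
    state = env.reset()
    done = False
    while not done:
        action = np.clip(env.goal_position - state, -1.0, 1.0)  # greedy: step toward the goal
        state, reward, done, info = env.step(action)
    env.render()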
/.gitignore:
--------------------------------------------------------------------------------
1 | *.bak
2 | .gitattributes
3 | .last_checked
4 | .gitconfig
5 | *.bak
6 | *.log
7 | *~
8 | ~*
9 | _tmp*
10 | tags
11 |
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 |
17 | # C extensions
18 | *.so
19 |
20 | # Distribution / packaging
21 | .Python
22 | env/
23 | build/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | wheels/
35 | *.egg-info/
36 | .installed.cfg
37 | *.egg
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | .hypothesis/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # celery beat schedule file
88 | celerybeat-schedule
89 |
90 | # SageMath parsed files
91 | *.sage.py
92 |
93 | # dotenv
94 | .env
95 |
96 | # virtualenv
97 | .venv
98 | venv/
99 | ENV/
100 |
101 | # Spyder project settings
102 | .spyderproject
103 | .spyproject
104 |
105 | # Rope project settings
106 | .ropeproject
107 |
108 | # mkdocs documentation
109 | /site
110 |
111 | # mypy
112 | .mypy_cache/
113 |
114 | .vscode
115 | *.swp
116 |
117 | # osx generated files
118 | .DS_Store
119 | .DS_Store?
120 | .Trashes
121 | ehthumbs.db
122 | Thumbs.db
123 | .idea
124 |
125 | # pytest
126 | .pytest_cache
127 |
128 | # tools/trust-doc-nbs
129 | docs_src/.last_checked
130 |
131 | # symlinks to fastai
132 | docs_src/fastai
133 | tools/fastai
134 |
135 | # link checker
136 | checklink/cookies.txt
137 |
138 | # .gitconfig is now autogenerated
139 | .gitconfig
140 |
141 | /bfg-1.13.0.jar
142 | /sensitive_info.txt
143 | /nbs/runs/
144 | /quarto-linux-amd64.deb
145 | **/**/_docs
146 | /_proc/
147 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to contribute
2 |
3 | ## How to get started
4 |
5 | Before anything else, please install the git hooks that run automatic scripts during each commit and merge to strip the notebooks of superfluous metadata (and avoid merge conflicts). After cloning the repository, run the following command inside it:
6 | ```
7 | nbdev_install_hooks
8 | ```
9 |
10 | ## Did you find a bug?
11 |
12 | * Ensure the bug was not already reported by searching on GitHub under Issues.
13 | * If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring.
14 | * Be sure to add the complete error messages.
15 |
16 | #### Did you write a patch that fixes a bug?
17 |
18 | * Open a new GitHub pull request with the patch.
19 | * Ensure that your PR includes a test that fails without your patch, and passes with it.
20 | * Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable.
21 |
22 | ## PR submission guidelines
23 |
24 | * Keep each PR focused. While it's more convenient, do not combine several unrelated fixes together. Create as many branches as needed to keep each PR focused.
25 | * Do not mix style changes/fixes with "functional" changes. Such PRs are very difficult to review and will most likely get rejected.
26 | * Do not add/remove vertical whitespace. Preserve the original style of the file you edit as much as you can.
27 | * Do not turn an already submitted PR into your development playground. If, after you submit a PR, you discover that more work is needed, close the PR, do the required work, and then submit a new PR. Otherwise each of your commits requires attention from the maintainers of the project.
28 | * If, however, you submitted a PR and received a request for changes, you should proceed with commits inside that PR, so that the maintainer can see the incremental fixes and won't need to review the whole PR again. In the exceptional case where you realize it will take many commits to complete the requests, it's probably best to close the PR, do the work, and then submit it again. Use common sense where you'd choose one way over another.
29 |
30 | ## Do you want to contribute to the documentation?
31 |
32 | * Docs are automatically created from the notebooks in the nbs folder.
33 |
34 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python PyPi and Conda
5 |
6 | on:
7 | release:
8 | types: [created]
9 | # schedule:
10 | # - cron: '1 6 * * *'
11 | workflow_dispatch: #allows you to trigger manually
12 | push:
13 | branches:
14 | - main
15 |
16 | jobs:
17 | deploy_python_pypi:
18 | name: "Deploy Python PyPi"
19 | runs-on: ubuntu-latest
20 |
21 | steps:
22 | - uses: actions/checkout@v2
23 | - name: Set up Python
24 | uses: actions/setup-python@v2
25 | with:
26 | python-version: '3.x'
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | pip install setuptools wheel twine
31 | - name: Build and publish
32 | env:
33 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
34 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
35 | run: |
36 | python setup.py sdist bdist_wheel
37 |         twine upload dist/* || echo "File already exists"
38 |
39 | # Disable usage of conda since it never installs correctly.
40 | # deploy_python_conda:
41 | # runs-on: ubuntu-latest
42 | # name: "Deploy Python Conda"
43 | # needs: deploy_python_pypi
44 | # strategy:
45 | # max-parallel: 5
46 | # timeout-minutes: 500
47 |
48 | # steps:
49 | # - uses: actions/checkout@v2
50 | # - name: Set up Python 3.6
51 | # uses: actions/setup-python@v2
52 | # with:
53 | # python-version: 3.6
54 | # - name: Add conda to system path
55 | # run: |
56 | # # $CONDA is an environment variable pointing to the root of the miniconda directory
57 | # echo $CONDA/bin >> $GITHUB_PATH
58 | # - name: Install dependencies
59 | # run: |
60 | # conda install -y -c fastchan conda-build anaconda-client fastrelease
61 | # - name: Build Conda package
62 | # timeout-minutes: 360
63 | # run: |
64 | # fastrelease_conda_package --do_build false
65 | # cd conda
66 | # conda build --no-anaconda-upload --output-folder build fastrl -c fastchan
67 | # anaconda upload build/noarch/fastrl-*-*.tar.bz2
68 | # env:
69 | # ANACONDA_API_TOKEN: ${{ secrets.ANACONDA_TOKEN }}
70 | # - name: Test package import
71 | # run: |
72 | # conda install -c josiahls fastrl
73 | # conda update -c josiahls fastrl
74 | # python -c "import fastrl"
75 |
--------------------------------------------------------------------------------
/nbs/sidebar.yml:
--------------------------------------------------------------------------------
1 | website:
2 | sidebar:
3 | contents:
4 | - index.ipynb
5 | - 00_core.ipynb
6 | - 00_nbdev_extension.ipynb
7 | - 19_cli.ipynb
8 | - 20_test_utils.ipynb
9 | - contents:
10 | - 00_Fastai/00_torch_core.ipynb
11 | section: Fastai
12 | - contents:
13 | - 01_DataPipes/01a_pipes.core.ipynb
14 | - 01_DataPipes/01b_pipes.map.demux.ipynb
15 | - 01_DataPipes/01c_pipes.map.mux.ipynb
16 | - 01_DataPipes/01d_pipes.iter.nskip.ipynb
17 | - 01_DataPipes/01e_pipes.iter.nstep.ipynb
18 | - 01_DataPipes/01f_pipes.iter.firstlast.ipynb
19 | - 01_DataPipes/01g_pipes.iter.transforms.ipynb
20 | - 01_DataPipes/01h_pipes.map.transforms.ipynb
21 | section: DataPipes
22 | - contents:
23 | - 02_DataLoading/02f_data.dataloader2.ipynb
24 | - 02_DataLoading/02g_data.block.ipynb
25 | section: DataLoading
26 | - contents:
27 | - 03_Environment/05b_envs.gym.ipynb
28 | section: Environment
29 | - contents:
30 | - 02_Funcs/00_funcs.conjugation.ipynb
31 | section: Funcs
32 | - contents:
33 | - 04_Memory/06a_memory.experience_replay.ipynb
34 | section: Memory
35 | - contents:
36 | - 05_Logging/09a_loggers.core.ipynb
37 | - 05_Logging/09d_loggers.jupyter_visualizers.ipynb
38 | - 05_Logging/09e_loggers.tensorboard.ipynb
39 | - 05_Logging/09f_loggers.vscode_visualizers.ipynb
40 | section: Logging
41 | - contents:
42 | - 06_Learning/10a_learner.core.ipynb
43 | section: Learning
44 | - contents:
45 | - 07_Agents/12a_agents.core.ipynb
46 | - contents:
47 | - 07_Agents/01_Discrete/12b_agents.discrete.ipynb
48 | - 07_Agents/01_Discrete/12g_agents.dqn.basic.ipynb
49 | - 07_Agents/01_Discrete/12h_agents.dqn.target.ipynb
50 | - 07_Agents/01_Discrete/12l_agents.dqn.asynchronous.ipynb
51 | - 07_Agents/01_Discrete/12m_agents.dqn.double.ipynb
52 | - 07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb
53 | - 07_Agents/01_Discrete/12o_agents.dqn.categorical.ipynb
54 | - 07_Agents/01_Discrete/12r_agents.dqn.rainbow.ipynb
55 | section: Discrete
56 | - contents:
57 | - 07_Agents/02_Continuous/12s_agents.ddpg.ipynb
58 | - 07_Agents/02_Continuous/12t_agents.trpo.ipynb
59 | section: Continuous
60 | section: Agents
61 | - contents:
62 | - 11_FAQ/99_notes.multi_proc.ipynb
63 | - 11_FAQ/99_notes.pipe_insertion.ipynb
64 | - 11_FAQ/99_notes.speed.ipynb
65 | - 11_FAQ/99_template.ipynb
66 | section: FAQ
67 | - contents:
68 | - 12_Blog/99_blog.from_2023_05_to_now.ipynb
69 | section: Blog
70 |
--------------------------------------------------------------------------------
/fastrl/envs/continuous_debug_env.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/03_Environment/06_envs.continuous_debug_env.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['ContinuousDebugEnv']
5 |
6 | # %% ../../nbs/03_Environment/06_envs.continuous_debug_env.ipynb 2
7 | # Python native modules
8 | import os
9 | # Third party libs
10 | import gymnasium as gym
11 | from gymnasium import spaces
12 | from gymnasium.envs.registration import register
13 | import numpy as np
14 | # Local modules
15 |
16 | # %% ../../nbs/03_Environment/06_envs.continuous_debug_env.ipynb 4
17 | class ContinuousDebugEnv(gym.Env):
18 | metadata = {'render_modes': ['console']} # Corrected metadata key
19 |
20 | def __init__(self, goal_position=None, proximity_threshold=0.5):
21 | super(ContinuousDebugEnv, self).__init__()
22 |
23 | self.goal_position = goal_position if goal_position is not None else np.random.uniform(-10, 10)
24 |
25 |         # The observation is [current position, goal position], so the space is 2D.
26 |         # abs() keeps low < high even when a negative goal_position is passed.
27 |         bound = abs(goal_position) if goal_position is not None else 10.0
28 |         self.observation_space = spaces.Box(low=-bound, high=bound, shape=(2,), dtype=np.float32)
29 |
30 | self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
31 |
32 |
33 | self.proximity_threshold = proximity_threshold
34 | self.state = None
35 |
36 | def step(self, action):
37 | self.state[0] += action[0] # Assuming action is a NumPy array, use the first element
38 |
39 | distance_to_goal = np.abs(self.state[0] - self.goal_position)
40 | reward = -distance_to_goal.item() # Ensure reward is a float
41 |
42 | done = distance_to_goal <= self.proximity_threshold
43 | done = bool(done.item()) # Ensure done is a boolean
44 |
45 | info = {}
46 |
47 |         return self.state, reward, done, False, info  # truncation is left to the TimeLimit wrapper (max_episode_steps)
48 |
49 | def reset(self, seed=None, options=None):
50 | super().reset(seed=seed) # Call the superclass reset, which handles the seeding
51 | if self.goal_position is None:
52 | self.goal_position = np.random.uniform(-10, 10)
53 | # The state is {current position, goal position}
54 | self.state = np.array([0.0, self.goal_position], dtype=np.float32)
55 |
56 | return self.state, {} # Return observation and an empty info dictionary
57 |
58 |
59 | def render(self, mode='console'):
60 | if mode != 'console':
61 | raise NotImplementedError("Only console mode is supported.")
62 | print(f"Position: {self.state} Goal: {self.goal_position}")
63 |
64 |
65 | register(
66 | id="fastrl/ContinuousDebugEnv-v0",
67 | entry_point="fastrl.envs.continuous_debug_env:ContinuousDebugEnv",
68 | max_episode_steps=300,
69 | )
70 |
71 |
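72 | # Example usage (a minimal sketch; assumes gymnasium's standard reset/step API):
73 | # >>> import gymnasium as gym
74 | # >>> import fastrl.envs.continuous_debug_env  # runs the register() call above
75 | # >>> env = gym.make("fastrl/ContinuousDebugEnv-v0")
76 | # >>> obs, info = env.reset(seed=0)
77 | # >>> obs, reward, terminated, truncated, info = env.step(env.action_space.sample())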
--------------------------------------------------------------------------------
/nbs/12_Blog/99_blog.from_2023_05_to_now.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "assisted-contract",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "# Python native modules\n",
11 | "# Third party libs\n",
12 | "\n",
13 | "# Local modules\n",
14 | "from fastrl.nbdev_extensions import header"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "id": "40e740d6",
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "#|echo: false\n",
25 | "header(\n",
26 | " \"Updating torchdata to 0.7.0 and torch 2.0\",\n",
27 |     "Getting back up to speed with things\",\n",
28 | " freeze=False\n",
29 | ")"
30 | ]
31 | },
32 | {
33 | "attachments": {},
34 | "cell_type": "markdown",
35 | "id": "fec5e005",
36 | "metadata": {},
37 | "source": [
38 | "First change is `traverse` to `traverse_dps`. It seems that the traverse function has been\n",
39 |     "moved to torch? Not sure what the benefit of this is. I like the improvement of the\n",
40 |     "utilities for understanding datapipes though.\n",
41 | "\n",
42 | "The original dataloading stuff is found at [dataloader2](https://github.com/josiahls/fastrl/blob/c2ec68b9092f616e51bc9d574873a13e06ca3a26/fastrl/data/dataloader2.py)\n",
43 | "\n",
44 | "The goal of that dataloading module is:\n",
45 | "- Allowing getting data back from the main thread\n",
46 | "- Pushing an additional stream of data for logging.\n",
47 | "\n",
48 | "My main issue is that this is maybe overly complicated. I feel like \"getting data back from the main thread\"\n",
49 | "could be done through the file system.\n",
50 | "\n",
51 | "I know that I've re-worked this a bunch of times, but I just feel like this is too complicated, and \n",
52 | "I'm hoping the new changes to pytorch make this easier."
53 | ]
54 | },
55 | {
56 | "attachments": {},
57 | "cell_type": "markdown",
58 | "id": "aeb4f889",
59 | "metadata": {},
60 | "source": [
61 | "Ok so some additional changes:\n",
62 | "\n",
63 |     "- removed the map pipe mux and demux. RL likely won't use this, and even if I need something like this, I can just\n",
64 | "iterate through them I think.\n",
65 |     "- It would be nice to have `find_dps` include subclasses. You can argue there should never be\n",
66 | "subclasses though? It would maybe be better to have a `find_dps` that just took a\n",
67 | "filter function. On the other hand... you could just filter after the fact... hmm"
68 | ]
69 | },
70 | {
71 | "attachments": {},
72 | "cell_type": "markdown",
73 | "id": "ab16ee37",
74 | "metadata": {},
75 | "source": []
76 | }
77 | ],
78 | "metadata": {
79 | "kernelspec": {
80 | "display_name": "python3",
81 | "language": "python",
82 | "name": "python3"
83 | }
84 | },
85 | "nbformat": 4,
86 | "nbformat_minor": 5
87 | }
88 |
--------------------------------------------------------------------------------
/fastrl/loggers/jupyter_visualizers.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/05_Logging/09d_loggers.jupyter_visualizers.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['SimpleJupyterVideoPlayer', 'ImageCollector']
5 |
6 | # %% ../../nbs/05_Logging/09d_loggers.jupyter_visualizers.ipynb 1
7 | # Python native modules
8 | import os
9 | from torch.multiprocessing import Queue
10 | from typing import Tuple,NamedTuple
11 | # Third party libs
12 | from fastcore.all import add_docs
13 | import matplotlib.pyplot as plt
14 | import torchdata.datapipes as dp
15 | from IPython.display import display, clear_output  # `display` must be imported explicitly outside a live IPython namespace
16 | import torch
17 | import numpy as np
18 | # Local modules
19 | from ..core import Record
20 | from .core import LoggerBase,LogCollector,is_record
21 | # from fastrl.torch_core import *
22 |
23 | # %% ../../nbs/05_Logging/09d_loggers.jupyter_visualizers.ipynb 4
24 | class SimpleJupyterVideoPlayer(dp.iter.IterDataPipe):
25 | def __init__(self,
26 | source_datapipe=None,
27 | between_frame_wait_seconds:float=0.1
28 | ):
29 | self.source_datapipe = source_datapipe
30 |         self.between_frame_wait_seconds = between_frame_wait_seconds
31 |
32 | def dequeue(self):
33 |         while getattr(self, 'buffer', None): yield self.buffer.pop(0)  # legacy path: `buffer` is never set by __init__
34 |
35 |
36 | def __iter__(self) -> Tuple[NamedTuple]:
37 | img = None
38 | for record in self.source_datapipe:
39 | # for o in self.dequeue():
40 | if is_record(record):
41 | if record.value is None: continue
42 | if img is None: img = plt.imshow(record.value)
43 | img.set_data(record.value)
44 | plt.axis('off')
45 | display(plt.gcf())
46 | clear_output(wait=True)
47 | yield record
48 | add_docs(
49 | SimpleJupyterVideoPlayer,
50 | """Displays video from a `source_datapipe` that produces `typing.NamedTuples` that contain an `image` field.
51 | This only can handle 1 env input.""",
52 | dequeue="Grabs records from the `main_queue` and attempts to display them"
53 | )
54 |
55 | # %% ../../nbs/05_Logging/09d_loggers.jupyter_visualizers.ipynb 5
56 | class ImageCollector(dp.iter.IterDataPipe):
57 | title:str='image'
58 |
59 | def __init__(self,source_datapipe):
60 | self.source_datapipe = source_datapipe
61 |
62 | def convert_np(self,o):
63 | if isinstance(o,torch.Tensor): return o.detach().numpy()
64 | elif isinstance(o,np.ndarray): return o
65 | else: raise ValueError(f'Expects Tensor or np.ndarray not {type(o)}')
66 |
67 | def __iter__(self):
68 | # for q in self.main_buffers: q.append(Record('image',None))
69 | yield Record(self.title,None)
70 | for steps in self.source_datapipe:
71 | if isinstance(steps,dp.DataChunk):
72 | for step in steps:
73 | yield Record(self.title,self.convert_np(step.image))
74 | else:
75 | yield Record(self.title,self.convert_np(steps.image))
76 | yield steps
77 |
--------------------------------------------------------------------------------
/fastrl/pipes/iter/firstlast.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/01_DataPipes/01f_pipes.iter.firstlast.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['simple_step_first_last_merge', 'FirstLastMerger', 'n_first_last_steps_expected']
5 |
6 | # %% ../../../nbs/01_DataPipes/01f_pipes.iter.firstlast.ipynb 2
7 | # Python native modules
8 | import warnings
9 | from typing import Callable,List,Union
10 | # Third party libs
11 | from fastcore.all import add_docs
12 | import torchdata.datapipes as dp
13 |
14 | import torch
15 | # Local modules
16 | from ...core import StepTypes,SimpleStep
17 |
18 | # %% ../../../nbs/01_DataPipes/01f_pipes.iter.firstlast.ipynb 4
19 | def simple_step_first_last_merge(steps:List[SimpleStep],gamma):
20 | fstep,lstep = steps[0],steps[-1]
21 |
22 |     # n-step discounted return: r_0 + gamma*r_1 + ... + gamma^k*r_k,
23 |     # accumulated in reverse so the *later* rewards are discounted the most.
24 |     reward = lstep.reward
25 |     for step in reversed(steps[:-1]): reward = step.reward + gamma * reward
26 |
27 | yield SimpleStep(
28 | state=fstep.state.clone().detach(),
29 | next_state=lstep.next_state.clone().detach(),
30 | action=fstep.action,
31 | episode_n=fstep.episode_n,
32 | image=fstep.image,
33 | reward=reward,
34 | raw_action=fstep.raw_action,
35 | terminated=lstep.terminated,
36 | truncated=lstep.truncated,
37 | total_reward=lstep.total_reward,
38 | env_id=lstep.env_id,
39 | proc_id=lstep.proc_id,
40 | step_n=lstep.step_n,
41 | batch_size=[]
42 | )
43 |
44 | class FirstLastMerger(dp.iter.IterDataPipe):
45 | def __init__(self,
46 | source_datapipe,
47 | merge_behavior:Callable[[List[Union[StepTypes.types]],float],Union[StepTypes.types]]=simple_step_first_last_merge,
48 | gamma:float=0.99
49 | ):
50 | self.source_datapipe = source_datapipe
51 | self.gamma = gamma
52 | self.merge_behavior = merge_behavior
53 |
54 | def __iter__(self) -> StepTypes.types:
55 | self.env_buffer = {}
56 | for steps in self.source_datapipe:
57 | if not isinstance(steps,(list,tuple)):
58 | raise ValueError(f'Expected {self.source_datapipe} to return a list/tuple of steps, however got {type(steps)}')
59 |
60 | if len(steps)==1:
61 | yield steps[0]
62 | continue
63 |
64 | yield from self.merge_behavior(steps,gamma=self.gamma)
65 |
66 | add_docs(
67 | FirstLastMerger,
68 | """Takes multiple steps and converts them into a single step consisting of properties
69 | from the first and last steps. Reward is recalculated to factor in the multiple steps.""",
70 | )
71 |
72 | # %% ../../../nbs/01_DataPipes/01f_pipes.iter.firstlast.ipynb 15
73 | def n_first_last_steps_expected(
74 | default_steps:int, # The number of steps the episode would run without n_steps
75 | ):
76 | return default_steps
77 |
78 | n_first_last_steps_expected.__doc__=r"""
79 | This function doesn't do much for now: `FirstLastMerger` essentially undoes the step expansion that `NStepper` performs.
80 | """
81 |
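82 | # Example usage (a sketch; `pipe` yields steps, and `NStepper` comes from fastrl.pipes.iter.nstep):
83 | # pipe = NStepper(pipe, n=3)
84 | # pipe = FirstLastMerger(pipe, gamma=0.99)
85 | # Each yielded step now merges the first/last of a 3-step window, with the
86 | # intermediate rewards folded into a single discounted reward.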
--------------------------------------------------------------------------------
/fastrl/loggers/vscode_visualizers.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['SimpleVSCodeVideoPlayer', 'VSCodeDataPipe']
5 |
6 | # %% ../../nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb 1
7 | # Python native modules
8 | import os
9 | import io
10 | from typing import Tuple,Any,Optional,NamedTuple,Iterable
11 | # Third party libs
12 | import imageio
13 | from fastcore.all import add_docs,ifnone
14 | import matplotlib.pyplot as plt
15 | import torchdata.datapipes as dp
16 | from torchdata.datapipes import functional_datapipe
17 | from IPython.core.display import Video,Image
18 | from torchdata.dataloader2 import DataLoader2,MultiProcessingReadingService
19 | # Local modules
20 | from .core import LoggerBase,is_record
21 | from ..pipes.core import DataPipeAugmentationFn,apply_dp_augmentation_fns
22 | from .jupyter_visualizers import ImageCollector
23 |
24 | # %% ../../nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb 5
25 | class SimpleVSCodeVideoPlayer(dp.iter.IterDataPipe):
26 | def __init__(self,
27 | source_datapipe=None,
28 | skip_frames:int=1,
29 | fps:int=30,
30 | downsize_res=(2,2)
31 | ):
32 | self.source_datapipe = source_datapipe
33 | self.fps = fps
34 | self.skip_frames = skip_frames
35 | self.downsize_res = downsize_res
36 | self._bytes_object = None
37 | self.frames = []
38 |
39 | def reset(self):
40 | super().reset()
41 | self._bytes_object = io.BytesIO()
42 |
43 | def show(self,start:int=0,end:Optional[int]=None,step:int=1):
44 | print(f'Creating gif from {len(self.frames)} frames')
45 | imageio.mimwrite(
46 | self._bytes_object,
47 | self.frames[start:end:step],
48 | format='GIF',
49 | duration=self.fps
50 | )
51 | return Image(self._bytes_object.getvalue())
52 |
53 | def __iter__(self) -> Tuple[NamedTuple]:
54 | n_frame = 0
55 | for record in self.source_datapipe:
56 | # for o in self.dequeue():
57 | if is_record(record) and record.name=='image':
58 | if record.value is None: continue
59 | n_frame += 1
60 | if n_frame%self.skip_frames!=0: continue
61 | self.frames.append(
62 | record.value[::self.downsize_res[0],::self.downsize_res[1]]
63 | )
64 | yield record
65 | add_docs(
66 | SimpleVSCodeVideoPlayer,
67 | """Displays video from a `source_datapipe` that produces `typing.NamedTuples` that contain an `image` field.
68 | This only can handle 1 env input.""",
69 | show="In order to show the video, this must be called in a notebook cell.",
70 | reset="Will reset the bytes object that is used to store file data."
71 | )
72 |
73 | # %% ../../nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb 6
74 | @functional_datapipe('visualize_vscode')
75 | class VSCodeDataPipe(dp.iter.IterDataPipe):
76 |     def __new__(cls,source:Iterable):
77 | "This is the function that is actually run by `DataBlock`"
78 | pipe = ImageCollector(source).dump_records()
79 | pipe = SimpleVSCodeVideoPlayer(pipe)
80 | return pipe
81 |
82 |
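83 | # Example usage (a sketch; `source` is assumed to yield steps carrying an `image` field):
84 | # pipe = VSCodeDataPipe(source)   # or equivalently: source.visualize_vscode()
85 | # for record in pipe: pass        # iterate to collect (downsized) frames
86 | # pipe.show()                     # render the collected frames as a gif in the cell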
--------------------------------------------------------------------------------
/fastrl/agents/dqn/rainbow.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/07_Agents/01_Discrete/12r_agents.dqn.rainbow.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['DQNRainbowLearner']
5 |
6 | # %% ../../../nbs/07_Agents/01_Discrete/12r_agents.dqn.rainbow.ipynb 2
7 | # Python native modules
8 | from copy import deepcopy
9 | from typing import Optional,Callable,Tuple
10 | # Third party libs
11 | import torchdata.datapipes as dp
12 | from torchdata.dataloader2.graph import traverse_dps,DataPipe
13 | import torch
14 | from torch import nn,optim
15 | from fastcore.all import store_attr,ifnone
16 | import numpy as np
17 | import torch.nn.functional as F
18 | # Local modules
19 | from ...torch_core import default_device,to_detach,evaluating
20 | from ...pipes.core import find_dp
21 | from ..core import StepFieldSelector,SimpleModelRunner,NumpyConverter
22 | from ..discrete import EpsilonCollector,PyPrimativeConverter,ArgMaxer,EpsilonSelector
23 | from ...memory.experience_replay import ExperienceReplay
24 | from ...loggers.core import BatchCollector,EpochCollector
25 | from ...learner.core import LearnerBase,LearnerHead
26 | from ..core import AgentHead,AgentBase
27 | from ...loggers.vscode_visualizers import VSCodeDataPipe
28 | from ...loggers.core import ProgressBarLogger
29 | from fastrl.agents.dqn.basic import (
30 | LossCollector,
31 | RollingTerminatedRewardCollector,
32 | EpisodeCollector,
33 | StepBatcher,
34 | TargetCalc,
35 | LossCalc,
36 | ModelLearnCalc,
37 | DQN,
38 | DQNAgent
39 | )
40 | from fastrl.agents.dqn.target import (
41 | TargetModelUpdater,
42 | TargetModelQCalc
43 | )
44 | from .dueling import DuelingHead
45 | from fastrl.agents.dqn.categorical import (
46 | CategoricalDQNAgent,
47 | CategoricalDQN,
48 | CategoricalTargetQCalc,
49 | PartialCrossEntropy
50 | )
51 |
52 | # %% ../../../nbs/07_Agents/01_Discrete/12r_agents.dqn.rainbow.ipynb 4
53 | def DQNRainbowLearner(
54 | model,
55 | dls,
56 | do_logging:bool=True,
57 | loss_func=PartialCrossEntropy,
58 | opt=optim.AdamW,
59 | lr=0.005,
60 | bs=128,
61 | max_sz=10000,
62 | nsteps=1,
63 | device=None,
64 | batches=None,
65 | target_sync=300,
66 | # Use DoubleDQN target strategy
67 | double_dqn_strategy=True
68 | ) -> LearnerHead:
69 | learner = LearnerBase(model,dls=dls[0])
70 | learner = BatchCollector(learner,batches=batches)
71 | learner = EpochCollector(learner)
72 | if do_logging:
73 | learner = learner.dump_records()
74 | learner = ProgressBarLogger(learner)
75 | learner = RollingTerminatedRewardCollector(learner)
76 | learner = EpisodeCollector(learner).catch_records()
77 | learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)
78 | learner = StepBatcher(learner,device=device)
79 | learner = CategoricalTargetQCalc(learner,nsteps=nsteps,double_dqn_strategy=double_dqn_strategy).to(device=device)
80 | learner = LossCalc(learner,loss_func=loss_func)
81 | learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))
82 | learner = TargetModelUpdater(learner,target_sync=target_sync)
83 | if do_logging:
84 | learner = LossCollector(learner).catch_records()
85 |
86 | if len(dls)==2:
87 | val_learner = LearnerBase(model,dls[1]).visualize_vscode()
88 | val_learner = BatchCollector(val_learner,batches=batches)
89 | val_learner = EpochCollector(val_learner).catch_records(drop=True)
90 | return LearnerHead((learner,val_learner))
91 | else:
92 | return LearnerHead(learner)
93 |
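94 | # Example usage (a minimal sketch; the model dims and `dls` construction are assumptions):
95 | # model = CategoricalDQN(4, 2)    # hypothetical CartPole dims: 4 obs features, 2 actions
96 | # learner = DQNRainbowLearner(model, dls, bs=128, max_sz=100_000, batches=1000)
97 | # learner.fit(3)                  # assumes LearnerHead's fit() entry point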
--------------------------------------------------------------------------------
/fastrl/pipes/core.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/01_DataPipes/01a_pipes.core.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['find_dps', 'find_dp', 'DataPipeAugmentationFn', 'apply_dp_augmentation_fns']
5 |
6 | # %% ../../nbs/01_DataPipes/01a_pipes.core.ipynb 2
7 | # Python native modules
8 | import os
9 | import logging
10 | import inspect
11 | from typing import Callable,Union,TypeVar,Optional,Type,List,Tuple
12 | # Third party libs
13 | import torchdata.datapipes as dp
14 | from torchdata.datapipes import functional_datapipe
15 | from torchdata.datapipes.iter import IterDataPipe
16 | from torchdata.datapipes.map import MapDataPipe
17 | from torchdata.dataloader2.graph import DataPipe, DataPipeGraph,find_dps,traverse_dps,list_dps
18 | # Local modules
19 |
20 | # %% ../../nbs/01_DataPipes/01a_pipes.core.ipynb 5
21 | def find_dps(
22 | graph: DataPipeGraph,
23 | dp_type: Type[DataPipe],
24 | include_subclasses:bool=False
25 | ) -> List[DataPipe]:
26 | r"""
27 |     Given the graph of DataPipe generated by the ``traverse_dps`` function, return DataPipe
28 | instances with the provided DataPipe type.
29 | """
30 | dps: List[DataPipe] = []
31 |
32 | def helper(g) -> None: # pyre-ignore
33 | for _, (dp, src_graph) in g.items():
34 | if include_subclasses and issubclass(type(dp),dp_type):
35 | dps.append(dp)
36 |             elif type(dp) is dp_type: # Do not use `isinstance` here; there is a known bug.
37 | dps.append(dp)
38 | helper(src_graph)
39 |
40 | helper(graph)
41 |
42 | return dps
43 |
44 | # %% ../../nbs/01_DataPipes/01a_pipes.core.ipynb 6
45 | def find_dp(
46 |     # A graph created from the `traverse_dps` function
47 | graph: DataPipeGraph,
48 |     # The `DataPipe` type to find in the graph
49 | dp_type: Type[DataPipe],
50 | include_subclasses:bool=False
51 | ) -> DataPipe:
52 | pipes = find_dps(graph,dp_type,include_subclasses)
53 | if len(pipes)==1: return pipes[0]
54 | elif len(pipes)>1:
55 | found_ids = set([id(pipe) for pipe in pipes])
56 | if len(found_ids)>1:
57 |             logging.warning("""There are %s pipes of type %s. If this is intended,
58 | please use `find_dps` directly. Returning first instance.""",len(pipes),dp_type)
59 | return pipes[0]
60 | else:
61 | raise LookupError(f'Unable to find {dp_type} starting at {graph}')
62 |
63 | find_dp.__doc__ = "Returns a single `DataPipe` as opposed to `find_dps`.\n"+find_dps.__doc__
64 |
65 | # %% ../../nbs/01_DataPipes/01a_pipes.core.ipynb 19
66 | class DataPipeAugmentationFn(Callable[[DataPipe],Optional[DataPipe]]):...
67 |
68 | DataPipeAugmentationFn.__doc__ = f"""`DataPipeAugmentationFn` must take in a `DataPipe` and either output a `DataPipe` or `None`. This function should perform some operation on the graph
69 | such as replacing, removing, inserting `DataPipe`'s and `DataGraph`s. Below is an example that replaces a `dp.iter.Batcher` datapipe with a `dp.iter.Filter`"""
70 |
71 | # %% ../../nbs/01_DataPipes/01a_pipes.core.ipynb 23
72 | def apply_dp_augmentation_fns(
73 | pipe:DataPipe,
74 | dp_augmentation_fns:Optional[Tuple[DataPipeAugmentationFn]],
75 | debug:bool=False
76 | ) -> DataPipe:
77 |     "Given a `pipe`, run `dp_augmentation_fns` over the pipeline"
78 | if dp_augmentation_fns is None: return pipe
79 | for fn in dp_augmentation_fns:
80 | if debug: print(f'Running fn: {fn} given current pipe: \n\t{traverse_dps(pipe)}')
81 | result = fn(pipe)
82 | if result is not None: pipe = result
83 | return pipe
84 |
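85 | # Example usage (a minimal sketch over a toy pipeline):
86 | # pipe = dp.iter.IterableWrapper(range(10)).batch(2)
87 | # batcher = find_dp(traverse_dps(pipe), dp.iter.Batcher)  # locate the Batcher in the graph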
--------------------------------------------------------------------------------
/fastrl/pipes/map/mux.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/01_DataPipes/01c_pipes.map.mux.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['T_co', 'MultiplexerMapDataPipe']
5 |
6 | # %% ../../../nbs/01_DataPipes/01c_pipes.map.mux.ipynb 3
7 | # Python native modules
8 | import os
9 | from inspect import isfunction,ismethod
10 | from itertools import chain, zip_longest
11 | from typing import Callable, Dict, Iterable, Optional, TypeVar
12 | # Third party libs
13 | from fastcore.all import *
14 | from ...torch_core import *
15 | # from torch.utils.data.dataloader import DataLoader as OrgDataLoader
16 | import torchdata.datapipes as dp
17 | from torchdata.datapipes import functional_datapipe
18 | from torchdata.dataloader2.graph import find_dps,DataPipeGraph,Type,DataPipe,MapDataPipe,IterDataPipe
19 | from torchdata.dataloader2.dataloader2 import DataLoader2
20 | # Local modules
21 |
22 | # %% ../../../nbs/01_DataPipes/01c_pipes.map.mux.ipynb 5
23 | T_co = TypeVar("T_co", covariant=True)
24 |
25 | @functional_datapipe("mux")
26 | class MultiplexerMapDataPipe(MapDataPipe[T_co]):
27 | def __init__(self, *datapipes, dp_index_map: Optional[Dict[MapDataPipe, Iterable]] = None):
28 | self.datapipes = datapipes
29 | self.dp_index_map = dp_index_map if dp_index_map else {}
30 | self.length: Optional[int] = None
31 | self.index_map = {}
32 |         # Create a generator that yields (index, (dp_num, old_index)) in sequential order.
33 | indices = (self._add_dp_num(i, dp) for i, dp in enumerate(datapipes))
34 | dp_id_and_key_tuples = chain.from_iterable(zip_longest(*indices))
35 | self.key_gen = enumerate(e for e in dp_id_and_key_tuples if e is not None)
36 |
37 | def _add_dp_num(self, dp_num: int, dp: MapDataPipe):
38 | # Assume 0-index for all DataPipes unless alternate indices are defined in `self.dp_index_map`
39 | dp_indices = self.dp_index_map[dp] if dp in self.dp_index_map else range(len(dp))
40 | for idx in dp_indices:
41 | yield dp_num, idx
42 |
43 | def __getitem__(self, index):
44 | if 0 <= index < len(self):
45 | if index in self.index_map:
46 | dp_num, old_key = self.index_map[index]
47 | else:
48 | curr_key = -1
49 | while curr_key < index:
50 | curr_key, dp_num_key_tuple = next(self.key_gen)
51 | dp_num, old_key = dp_num_key_tuple
52 | self.index_map[index] = dp_num, old_key
53 | try:
54 | return self.datapipes[dp_num][old_key]
55 | except KeyError:
56 | raise RuntimeError(
57 |                 f"Incorrect key is given to MapDataPipe {dp_num} in Multiplexer, likely because "
58 |                 f"that DataPipe is not 0-indexed and alternate indices were not given."
59 | )
60 | raise RuntimeError(f"Index {index} is out of bound for Multiplexer.")
61 |
62 | def __iter__(self):
63 | for i in range(len(self)):
64 | yield self[i]
65 |
66 | def __len__(self):
67 | if self.length is None:
68 | self.length = 0
69 | for dp in self.datapipes:
70 | self.length += len(dp)
71 | return self.length
72 |
73 | MultiplexerMapDataPipe.__doc__ = """Yields one element at a time from each of the input Map DataPipes (functional name: ``mux``). As in,
74 | one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration,
75 | and so on. Because indices are interleaved with ``zip_longest``, it ends only when *all* input DataPipes are exhausted.
76 | """
77 |
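78 | # Example usage (a sketch; both inputs are 0-indexed MapDataPipes):
79 | # a = dp.map.SequenceWrapper(['a', 'b', 'c'])
80 | # b = dp.map.SequenceWrapper([1, 2, 3])
81 | # list(MultiplexerMapDataPipe(a, b))  # -> ['a', 1, 'b', 2, 'c', 3]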
--------------------------------------------------------------------------------
/fastrl/agents/dqn/double.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/07_Agents/01_Discrete/12m_agents.dqn.double.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['DoubleQCalc', 'DoubleDQNLearner']
5 |
6 | # %% ../../../nbs/07_Agents/01_Discrete/12m_agents.dqn.double.ipynb 2
7 | # Python native modules
8 | from copy import deepcopy
9 | from typing import Optional,Callable,Tuple
10 | # Third party libs
11 | import torchdata.datapipes as dp
12 | from torchdata.dataloader2.graph import traverse_dps,DataPipe
13 | import torch
14 | from torch import nn,optim
15 | # Local modules
16 | from ...pipes.core import find_dp
17 | from ...memory.experience_replay import ExperienceReplay
18 | from ...loggers.core import BatchCollector,EpochCollector
19 | from ...learner.core import LearnerBase,LearnerHead
20 | from ...loggers.vscode_visualizers import VSCodeDataPipe
21 | from ...loggers.core import ProgressBarLogger
22 | from fastrl.agents.dqn.basic import (
23 | LossCollector,
24 | RollingTerminatedRewardCollector,
25 | EpisodeCollector,
26 | StepBatcher,
27 | TargetCalc,
28 | LossCalc,
29 | ModelLearnCalc,
30 | DQN,
31 | DQNAgent
32 | )
33 | from fastrl.agents.dqn.target import (
34 | TargetModelUpdater,
35 | TargetModelQCalc
36 | )
37 |
38 | # %% ../../../nbs/07_Agents/01_Discrete/12m_agents.dqn.double.ipynb 5
39 | class DoubleQCalc(dp.iter.IterDataPipe):
40 | def __init__(self,source_datapipe):
41 | self.source_datapipe = source_datapipe
42 |
43 | def __iter__(self):
44 | self.learner = find_dp(traverse_dps(self),LearnerBase)
45 | for batch in self.source_datapipe:
46 | self.learner.done_mask = batch.terminated.reshape(-1,)
47 | with torch.no_grad():
48 | chosen_actions = self.learner.model(batch.next_state).argmax(dim=1).reshape(-1,1)
49 | self.learner.next_q = self.learner.target_model(batch.next_state).gather(1,chosen_actions)
50 | self.learner.next_q[self.learner.done_mask] = 0
51 | yield batch
52 |
53 | # %% ../../../nbs/07_Agents/01_Discrete/12m_agents.dqn.double.ipynb 6
54 | def DoubleDQNLearner(
55 | model,
56 | dls,
57 | do_logging:bool=True,
58 | loss_func=nn.MSELoss(),
59 | opt=optim.AdamW,
60 | lr=0.005,
61 | bs=128,
62 | max_sz=10000,
63 | nsteps=1,
64 | device=None,
65 | batches=None,
66 | target_sync=300
67 | ) -> LearnerHead:
68 | learner = LearnerBase(model,dls=dls[0])
69 | learner = BatchCollector(learner,batches=batches)
70 | learner = EpochCollector(learner)
71 | if do_logging:
72 | learner = learner.dump_records()
73 | learner = ProgressBarLogger(learner)
74 | learner = RollingTerminatedRewardCollector(learner)
75 | learner = EpisodeCollector(learner).catch_records()
76 | learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)
77 | learner = StepBatcher(learner,device=device)
78 | learner = DoubleQCalc(learner)
79 | learner = TargetCalc(learner,nsteps=nsteps)
80 | learner = LossCalc(learner,loss_func=loss_func)
81 | learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))
82 | learner = TargetModelUpdater(learner,target_sync=target_sync)
83 | if do_logging:
84 | learner = LossCollector(learner).catch_records()
85 |
86 | if len(dls)==2:
87 | val_learner = LearnerBase(model,dls[1]).visualize_vscode()
88 | val_learner = BatchCollector(val_learner,batches=batches)
89 | val_learner = EpochCollector(val_learner).catch_records(drop=True)
90 | return LearnerHead((learner,val_learner))
91 | else:
92 | return LearnerHead(learner)
93 |
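94 | # Example usage (a minimal sketch; the model dims and `dls` construction are assumptions):
95 | # model = DQN(4, 2)               # hypothetical CartPole dims: 4 obs features, 2 actions
96 | # learner = DoubleDQNLearner(model, dls, batches=1000)
97 | # learner.fit(3)                  # assumes LearnerHead's fit() entry point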
--------------------------------------------------------------------------------
/fastrl/pipes/iter/nskip.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/01_DataPipes/01d_pipes.iter.nskip.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['NSkipper', 'n_skips_expected']
5 |
6 | # %% ../../../nbs/01_DataPipes/01d_pipes.iter.nskip.ipynb 2
7 | # Python native modules
8 | import os
9 | import warnings
10 | from typing import Callable, Dict, Iterable, Optional, TypeVar, Type,Union
11 | # Third party libs
12 | import torchdata.datapipes as dp
13 | from torchdata.datapipes.iter import IterDataPipe
14 | from fastcore.all import add_docs
15 | # Local modules
16 | from ...core import StepTypes
17 | from .nstep import NStepper
18 |
19 | # %% ../../../nbs/01_DataPipes/01d_pipes.iter.nskip.ipynb 4
20 | _msg = """
21 | NSkipper should not go after NStepper. Please make the order:
22 |
23 | ```python
24 | ...
25 | pipe = NSkipper(pipe,n=3)
26 | pipe = NStepper(pipe,n=3)
27 | ...
28 | ```
29 |
30 | """
31 |
32 | class NSkipper(IterDataPipe[Union[StepTypes.types]]):
33 | def __init__(
34 | self,
35 | # The datapipe we are extracting from must produce `StepType`
36 | source_datapipe:IterDataPipe[Union[StepTypes.types]],
37 | # Number of steps to skip per env. Default will not skip at all.
38 | n:int=1
39 | ) -> None:
40 | if isinstance(source_datapipe,NStepper): raise Exception(_msg)
41 | self.source_datapipe = source_datapipe
42 | self.n = n
43 | self.env_buffer = {}
44 |
45 | def __iter__(self) -> StepTypes.types:
46 | self.env_buffer = {}
47 | for step in self.source_datapipe:
48 | if not issubclass(step.__class__,StepTypes.types):
49 | raise Exception(f'Expected {StepTypes.types} object got {type(step)}\n{step}')
50 |
51 | env_id,terminated,step_n = int(step.env_id),bool(step.terminated),int(step.step_n)
52 |
53 | if env_id in self.env_buffer: self.env_buffer[env_id] += 1
54 | else: self.env_buffer[env_id] = 1
55 |
56 | if self.env_buffer[env_id]%self.n==0: yield step
57 | elif terminated: yield step
58 | elif step_n==1: yield step
59 |
60 | if terminated: self.env_buffer[env_id] = 1
61 |
62 | add_docs(
63 | NSkipper,
64 | """Accepts a `source_datapipe` or iterable whose `next()` produces a `StepType` that
65 | skips N steps for individual environments *while always producing 1st steps and terminated steps.*
66 | """
67 | )
68 |
69 | # %% ../../../nbs/01_DataPipes/01d_pipes.iter.nskip.ipynb 17
70 | def n_skips_expected(
71 | default_steps:int, # The number of steps the episode would run without n_skips
72 | n:int # The n-skip value that we are planning to use
73 | ):
74 |     if n==1: return default_steps # All the steps will be retained including the 1st step. No offset needed
75 | # If n goes into default_steps evenly, then the final "done" will be technically an "extra" step
76 | elif default_steps%n==0: return (default_steps // n) + 1 # first step will be kept
77 | else:
78 |         # If the steps don't divide evenly, it would try to skip the done step, but of course
79 |         # we don't let that happen
80 | return (default_steps // n) + 2 # first step and done will be kept
81 |
82 | n_skips_expected.__doc__=r"""
83 | Produces the expected number of steps, assuming a fully deterministic episode based on `default_steps` and `n`.
84 |
85 | Mainly used for testing.
86 |
87 | Given `n=2` and 1 env, knowing that `CartPole-v1` with `seed=0` will always run 18 steps, the total
88 | steps will be:
89 |
90 | $$
91 | \lfloor 18 / n \rfloor + 1 \quad \text{(first and last steps always kept)}
92 | $$
93 | """
94 |
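95 | # Example usage (a sketch; `pipe` yields StepTypes from a single env):
96 | # pipe = NSkipper(pipe, n=4)   # keep every 4th step, plus first and terminal steps
97 | # n_skips_expected(default_steps=18, n=4)  # -> 6, i.e. 18 // 4 + 2 since 18 % 4 != 0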
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from pkg_resources import parse_version
2 | from configparser import ConfigParser
3 | import subprocess
4 | from setuptools.command.develop import develop
5 | from setuptools.command.install import install
6 | import setuptools
7 | import os
8 | assert parse_version(setuptools.__version__)>=parse_version('36.2')
9 |
10 | # note: all settings are in settings.ini; edit there, not here
11 | config = ConfigParser(delimiters=['='])
12 | config.read('settings.ini')
13 | cfg = config['DEFAULT']
14 |
15 | cfg_keys = 'version description keywords author author_email'.split()
16 | expected = cfg_keys + "lib_name user branch license status min_python audience language".split()
17 | for o in expected: assert o in cfg, "missing expected setting: {}".format(o)
18 | setup_cfg = {o:cfg[o] for o in cfg_keys}
19 |
20 | licenses = {
21 | 'apache2': ('Apache Software License 2.0','OSI Approved :: Apache Software License'),
22 | }
23 | statuses = [ '1 - Planning', '2 - Pre-Alpha', '3 - Alpha',
24 | '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive' ]
25 | py_versions = '3.7 3.8 3.9 3.10'.split()
26 |
27 | requirements = ['pip', 'packaging']
28 | if cfg.get('requirements'): requirements += cfg.get('requirements','').split()
29 | if cfg.get('pip_requirements'): requirements += cfg.get('pip_requirements','').split()
30 | dev_requirements = (cfg.get('dev_requirements') or '').split()
31 |
32 | lic = licenses[cfg['license']]
33 | min_python = cfg['min_python']
34 |
35 | # Define the repository and branch or commit you want to install from
36 | TORCHDATA_GIT_REPO = "https://github.com/josiahls/data.git"
37 | TORCHDATA_COMMIT = "main" # or replace with a specific commit hash
38 |
39 | class CustomInstall(install):
40 | def run(self):
41 | # Ensure that torchdata is cloned and installed before proceeding
42 | print('Cloning torchdata')
43 | subprocess.check_call(["git", "clone", TORCHDATA_GIT_REPO])
44 | print('Installing torchdata')
45 | subprocess.check_call(["pip", "install","-vvv", "./data"])
46 | # Call the standard install.
47 | install.run(self)
48 |
49 | class CustomDevelop(develop):
50 | def run(self):
51 | # Ensure that torchdata is cloned but not installed
52 | if not os.path.exists('data'):
53 | print('Cloning torchdata')
54 | subprocess.check_call(["git", "clone", TORCHDATA_GIT_REPO])
55 | try:
56 | import torchdata
57 | except ImportError:
58 | print('Installing torchdata')
59 | subprocess.check_call(["pip", "install","-vvv", "-e", "./data"])
60 | # Call the standard develop.
61 | develop.run(self)
62 |
63 | setuptools.setup(
64 | name = cfg['lib_name'],
65 | license = lic[0],
66 | classifiers = [
67 | 'Development Status :: ' + statuses[int(cfg['status'])],
68 | 'Intended Audience :: ' + cfg['audience'].title(),
69 | 'License :: ' + lic[1],
70 | 'Natural Language :: ' + cfg['language'].title(),
71 | ] + ['Programming Language :: Python :: '+o for o in py_versions[py_versions.index(min_python):]],
72 | url = cfg['git_url'],
73 | packages = setuptools.find_packages(),
74 | include_package_data = True,
75 | install_requires = requirements,
76 | extras_require={'dev': dev_requirements },
77 | dependency_links = cfg.get('dep_links','').split(),
78 | python_requires = '>=' + cfg['min_python'],
79 | long_description = open('README.md').read(),
80 | long_description_content_type = 'text/markdown',
81 | zip_safe = False,
82 | entry_points = { 'console_scripts': cfg.get('console_scripts','').split() },
83 | cmdclass={
84 | 'install': CustomInstall,
85 | 'develop': CustomDevelop,
86 | },
87 | **setup_cfg)
88 |
89 |
--------------------------------------------------------------------------------
/nbs/05_Logging/09e_loggers.tensorboard.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "durable-dialogue",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "#|hide\n",
11 | "from fastrl.test_utils import initialize_notebook\n",
12 | "initialize_notebook()"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "id": "assisted-contract",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "#|export\n",
23 | "# Python native modules\n",
24 |     "import os, shutil\n",
25 | "from pathlib import Path\n",
26 | "# Third party libs\n",
27 | "import torchdata.datapipes as dp\n",
28 | "# Local modules"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "id": "offshore-stuart",
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "#|default_exp loggers.tensorboard"
39 | ]
40 | },
41 | {
42 | "attachments": {},
43 | "cell_type": "markdown",
44 | "id": "lesser-innocent",
45 | "metadata": {},
46 | "source": [
47 | "# Tensorboard \n",
48 | "> Iterable pipes for exporting to tensorboard"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "id": "fea383b7-e004-4ce1-8007-7b6d29248677",
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "#|export\n",
59 | "def run_tensorboard(\n",
60 | " port:int=6006, # The port to run tensorboard on/connect on\n",
61 | " start_tag:str=None, # Starting regex e.g.: experience_replay/1\n",
62 | " samples_per_plugin:str=None, # Sampling freq such as images=0 (keep all)\n",
63 | " extra_args:str=None, # Any additional arguments in the `--arg value` format\n",
64 |     "    rm_glob:str=None # Remove old logs via a pattern e.g.: '*' will remove all files: runs/* \n",
65 | " ):\n",
66 | " if rm_glob is not None:\n",
67 |     "        for p in Path('runs').glob(rm_glob): shutil.rmtree(p) if p.is_dir() else p.unlink()\n",
68 | " import socket\n",
69 | " from tensorboard import notebook\n",
70 | " a_socket=socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
71 | " cmd=None\n",
72 |     "    if not a_socket.connect_ex(('127.0.0.1',port)):\n",
73 | " notebook.display(port=port,height=1000)\n",
74 | " else:\n",
75 | " cmd=f'--logdir runs --port {port} --host=0.0.0.0'\n",
76 | " if samples_per_plugin is not None: cmd+=f' --samples_per_plugin {samples_per_plugin}'\n",
77 | " if start_tag is not None: cmd+=f' --tag {start_tag}'\n",
78 | " if extra_args is not None: cmd+=f' {extra_args}'\n",
79 | " notebook.start(cmd)\n",
80 | " return cmd"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "id": "530f998c-250a-4005-8abc-65ca89d8ae7d",
87 | "metadata": {},
88 | "outputs": [],
89 | "source": [
90 | "#|hide\n",
91 | "SHOW_TENSOR_BOARD=False\n",
92 | "if not os.environ.get(\"IN_TEST\", None) and SHOW_TENSOR_BOARD:\n",
93 | " run_tensorboard()"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "id": "current-pilot",
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "#|hide\n",
104 | "#|eval: false\n",
105 | "!nbdev_export"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": null,
111 | "id": "e0d82468-a2bf-4bfd-9ac7-e56db49b8476",
112 | "metadata": {},
113 | "outputs": [],
114 | "source": []
115 | }
116 | ],
117 | "metadata": {
118 | "kernelspec": {
119 | "display_name": "python3",
120 | "language": "python",
121 | "name": "python3"
122 | }
123 | },
124 | "nbformat": 4,
125 | "nbformat_minor": 5
126 | }
127 |
--------------------------------------------------------------------------------
/.github/workflows/fastrl-docker.yml:
--------------------------------------------------------------------------------
1 | name: Build fastrl images
2 | on:
3 | schedule:
4 | - cron: '1 6 * * *'
5 | workflow_dispatch: #allows you to trigger manually
6 | push:
7 | branches:
8 | - main
9 | - update_nbdev_docs
10 | - refactor/advantage-buffer
11 |
12 | jobs:
13 | build:
14 | runs-on: ubuntu-latest
15 | strategy:
16 | max-parallel: 1
17 | matrix:
18 | build_type: [dev]
19 | # build_type: [prod, dev]
20 | steps:
21 | - name: Maximize build space
22 | uses: easimon/maximize-build-space@master
23 | with:
24 | root-reserve-mb: 35000
25 | swap-size-mb: 1024
26 | remove-dotnet: 'true'
27 | remove-android: 'true'
28 | - name: Copy This Repository Contents
29 | uses: actions/checkout@v2
30 | # with:
31 | # submodules: recursive
32 | - name: Build
33 | run: |
34 | echo "Free space:"
35 | df -h
36 |
37 | - uses: actions/setup-python@v2
38 | with:
39 | python-version: '3.8'
40 | architecture: 'x64'
41 |
42 | - name: Copy settings.ini file
43 | run: |
44 | wget https://raw.githubusercontent.com/josiahls/fastrl/master/settings.ini
45 | - name: get version from settings.ini and create image name
46 | id: get_variables
47 | run: |
48 | from configparser import ConfigParser
49 | import os
50 | from pathlib import Path
51 | config = ConfigParser()
52 | settings = Path('settings.ini')
53 | assert settings.exists(), 'Not able to read or download settings.ini file.'
54 | config.read(settings)
55 | cfg = config['DEFAULT']
56 | print(f"::set-output name=version::{cfg['version']}")
57 | btype = os.getenv('BUILD_TYPE')
58 |           assert btype in ['prod', 'dev'], "BUILD_TYPE must be either prod or dev"
59 | if btype != 'prod':
60 | image_name = f'josiahls/fastrl-{btype}'
61 | else:
62 | image_name = 'josiahls/fastrl'
63 | print(f"::set-output name=image_name::{image_name}")
64 | shell: python
65 | env:
66 | BUILD_TYPE: ${{ matrix.build_type }}
67 |
68 | # - name: Cache Docker layers
69 | # if: always()
70 | # uses: actions/cache@v3
71 | # with:
72 | # path: /tmp/.buildx-cache
73 | # key: ${{ runner.os }}-buildx-${{ github.sha }}
74 | # restore-keys: |
75 | # ${{ runner.os }}-buildx-
76 |
77 | # - name: Set up Docker Buildx
78 | # uses: docker/setup-buildx-action@v3
79 |
80 | - name: build and tag container
81 | run: |
82 | export DOCKER_BUILDKIT=1
83 | # We need to clear the previous docker images
84 | docker system prune -fa
85 | docker pull ${IMAGE_NAME}:latest || true
86 | # docker build --build-arg BUILD=${BUILD_TYPE} \
87 | # docker buildx create --use
88 | # docker buildx build --cache-from=type=local,src=/tmp/.buildx-cache --cache-to=type=local,dest=/tmp/.buildx-cache --build-arg BUILD=${BUILD_TYPE} \
89 | # -t ${IMAGE_NAME}:latest \
90 | # -t ${IMAGE_NAME}:${VERSION} \
91 | # -t ${IMAGE_NAME}:$(date +%F) \
92 | # -f fastrl.Dockerfile .
93 | docker buildx build --build-arg BUILD=${BUILD_TYPE} \
94 | -t ${IMAGE_NAME}:latest \
95 | -t ${IMAGE_NAME}:${VERSION} \
96 | -t ${IMAGE_NAME}:$(date +%F) \
97 | -f fastrl.Dockerfile .
98 | env:
99 | VERSION: ${{ steps.get_variables.outputs.version }}
100 | IMAGE_NAME: ${{ steps.get_variables.outputs.image_name }}
101 | BUILD_TYPE: ${{ matrix.build_type }}
102 |
103 | - name: push images
104 | run: |
105 | echo ${PASSWORD} | docker login -u $USERNAME --password-stdin
106 | docker push ${IMAGE_NAME}
107 | env:
108 | USERNAME: ${{ secrets.DOCKER_USERNAME }}
109 | PASSWORD: ${{ secrets.DOCKER_PASSWORD }}
110 | IMAGE_NAME: ${{ steps.get_variables.outputs.image_name }}
111 |
--------------------------------------------------------------------------------
/nbs/README.md:
--------------------------------------------------------------------------------
1 | fastrl
2 | ================
3 |
4 |
5 |
6 | [](https://github.com/josiahls/fastrl/actions?query=workflow%3A%22Fastrl+Testing%22)
8 | [](https://pypi.python.org/pypi/fastrl)
10 | [](https://anaconda.org/josiahls/fastrl)
12 | [](https://hub.docker.com/repository/docker/josiahls/fastrl)
14 | [](https://hub.docker.com/repository/docker/josiahls/fastrl-dev)
16 |
17 | [](https://anaconda.org/josiahls/fastrl)
19 | [](https://pypi.python.org/pypi/fastrl)
21 | [](https://pypi.python.org/pypi/fastrl)
23 |
24 | > Warning: Even before fastrl==2.0.0, all models should converge
25 | > reasonably fast; however, the HRL models `DADS` and `DIAYN` will need
26 | > re-balancing and some extra features that the respective authors used.
27 |
28 | # Overview
29 |
32 | Fastai has been amazing for computer vision and tabular learning. One
33 | would wish the same were true for RL. The purpose of this repo is to
34 | have a framework that is as easy as possible to start with, but also
35 | designed for testing new agents.
36 |
37 | Documentation is served at https://josiahls.github.io/fastrl/ and is
38 | generated directly from this repo via `nbdev`.
39 |
40 | # Current Issues of Interest
41 |
42 | ## Data Issues
43 |
44 | - [ ] data and async_data are still buggy. We need to verify that the
45 |   order of the data being returned is the best it can be for our
46 |   models. We need to make sure that "terminateds" are returned and that
47 |   there are no duplicates (unless intended).
48 | - [ ] Better data debugging. Do environments skip steps correctly? Does
49 |   n_steps work correctly?
50 |
51 | # Whats new?
52 |
53 | As we have learned how to support as many RL agents as possible, we
54 | found that `fastrl==1.*` was vastly limited in the models it could
55 | support. `fastrl==2.*` will leverage the `nbdev` library for better
56 | documentation and more relevant testing. We will also be building on the
57 | work of the `ptan`[1] library as a close reference for PyTorch-based
58 | reinforcement learning APIs.
59 |
60 | [1] "Shmuma/Ptan". GitHub, 2020,
61 | https://github.com/Shmuma/ptan. Accessed 13 June 2020.
62 |
63 | ## Install
64 |
65 | ## PyPI (Not implemented yet)
66 |
67 | Placeholder: there is no PyPI package yet. It is recommended to do
68 | traditional forking for now.
69 |
70 | (For the future; there is currently no PyPI
71 | version) `pip install fastrl==2.0.0 --pre`
72 |
73 | ## Conda (Not implemented yet)
74 |
75 | `conda install -c fastchan -c josiahls fastrl`
76 |
77 | `source activate fastrl && python setup.py develop`
78 |
79 | ## Docker (highly recommended)
80 |
81 | Install:
82 | [Nvidia-Docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker)
83 |
84 | Install: [docker-compose](https://docs.docker.com/compose/install/)
85 |
86 | ``` bash
87 | docker-compose pull && docker-compose up
88 | ```
89 |
90 | ## Contributing
91 |
92 | After you clone this repository, please run `nbdev_install_git_hooks` in
93 | your terminal. This sets up git hooks, which clean up the notebooks to
94 | remove the extraneous stuff stored in the notebooks (e.g. which cells
95 | you ran) which causes unnecessary merge conflicts.
96 |
97 | Before submitting a PR, check that the local library and notebooks
98 | match. The script `nbdev_diff_nbs` can let you know if there is a
99 | difference between the local library and the notebooks.
100 |
101 | * If you made a change to the notebooks in one of the exported cells,
102 | you can export it to the library with `nbdev_build_lib` or `make fastai2`.
103 | * If you made a change to the library, you can export it back to the
104 | notebooks with `nbdev_update_lib`.
104 |
--------------------------------------------------------------------------------
/fastrl/pipes/iter/nstep.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/01_DataPipes/01c_pipes.iter.nstep.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['NStepper', 'NStepFlattener', 'n_steps_expected']
5 |
6 | # %% ../../../nbs/01_DataPipes/01c_pipes.iter.nstep.ipynb 2
7 | # Python native modules
8 | import os
9 | from typing import Type, Dict, Union, Tuple
10 | import typing
11 | import warnings
12 | # Third party libs
13 | from fastcore.all import add_docs
14 | import torchdata.datapipes as dp
15 | from torchdata.dataloader2.graph import find_dps,DataPipeGraph,DataPipe
16 | from torchdata.datapipes.iter import IterDataPipe
17 | from torchdata.datapipes.map import MapDataPipe
18 | # Local modules
19 | from ...core import StepTypes
20 |
21 | # %% ../../../nbs/01_DataPipes/01c_pipes.iter.nstep.ipynb 4
22 | class NStepper(IterDataPipe):
23 | def __init__(
24 | self,
25 | # The datapipe we are extracting from must produce `StepType.types`
26 | source_datapipe:IterDataPipe[Union[StepTypes.types]],
27 | # Maximum number of steps to produce per yield as a tuple. This is the *max* number
28 | # and may be less if for example we are yielding terminal states.
29 | # Default produces single steps
30 | n:int=1
31 | ) -> None:
32 | self.source_datapipe:IterDataPipe[StepTypes.types] = source_datapipe
33 | self.n:int = n
34 | self.env_buffer:Dict = {}
35 |
36 | def __iter__(self) -> StepTypes.types:
37 | self.env_buffer = {}
38 | for step in self.source_datapipe:
39 | if not issubclass(step.__class__,StepTypes.types):
40 |                 raise Exception(f'Expected {StepTypes.types} object got {type(step)}\n{step}')
41 |
42 | env_id,terminated = int(step.env_id),bool(step.terminated)
43 |
44 | if env_id in self.env_buffer:
45 | self.env_buffer[env_id].append(step)
46 | else:
47 | self.env_buffer[env_id] = [step]
48 |
49 |             if not terminated and len(self.env_buffer[env_id])<self.n: continue
50 |
51 |             yield tuple(self.env_buffer[env_id])
52 |             self.env_buffer[env_id].pop(0)
53 |
54 |             # On termination, flush the remaining (progressively shorter) windows.
55 |             if terminated:
56 |                 while len(self.env_buffer[env_id])>0:
57 |                     yield tuple(self.env_buffer[env_id])
58 |                     self.env_buffer[env_id].pop(0)
59 |
60 | add_docs(
61 |     NStepper,
62 |     """Accepts a `source_datapipe` or iterable whose `next()` produces a `StepType` and
63 |     yields tuples of up to `n` consecutive steps per environment.""",
64 | )
65 |
66 | # %% ../../../nbs/01_DataPipes/01c_pipes.iter.nstep.ipynb 14
67 | class NStepFlattener(IterDataPipe):
68 |     def __init__(
69 |         self,
70 |         # The datapipe we are extracting from must produce `StepType.types`
71 |         source_datapipe:IterDataPipe[Union[StepTypes.types]]) -> None:
72 | self.source_datapipe:IterDataPipe[[StepTypes.types]] = source_datapipe
73 |
74 | def __iter__(self) -> StepTypes.types:
75 | for step in self.source_datapipe:
76 | if issubclass(step.__class__,StepTypes.types):
77 | # print(step)
78 | yield step
79 | elif isinstance(step,tuple):
80 | # print('got step: ',step)
81 | yield from step
82 | else:
83 | raise Exception(f'Expected {StepTypes.types} or tuple object got {type(step)}\n{step}')
84 |
85 |
86 | add_docs(
87 | NStepFlattener,
88 | """Handles unwrapping `StepType.typess` in tuples better than `dp.iter.UnBatcher` and `dp.iter.Flattener`""",
89 | )
90 |
91 | # %% ../../../nbs/01_DataPipes/01c_pipes.iter.nstep.ipynb 20
92 | def n_steps_expected(
93 | default_steps:int, # The number of steps the episode would run without n_steps
94 |     n:int # The n-step value that we are planning to use
95 | ):
96 | return (default_steps * n) - sum(range(n))
97 |
98 | n_steps_expected.__doc__=r"""
99 | Produces the expected number of steps, assuming a fully deterministic episode based on `default_steps` and `n`
100 |
101 | Given `n=2` and 1 env, knowing that `CartPole-v1` with `seed=0` will always run 18 steps, the total
102 | steps will be:
103 |
104 | $$
105 | 18 * n - \sum_{0}^{n - 1}(i)
106 | $$
107 | """
108 |
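109 | # Example usage (a sketch; `pipe` yields StepTypes from a single CartPole env):
110 | # pipe = NStepper(pipe, n=2)     # tuples of up to 2 consecutive steps
111 | # pipe = NStepFlattener(pipe)    # unpack the tuples back into single steps
112 | # n_steps_expected(default_steps=18, n=2)  # -> 35 flattened steps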
--------------------------------------------------------------------------------
/fastrl/pipes/iter/cacheholder.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/01_DataPipes/02a_pipes.iter.cacheholder.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['T_co', 'PickleableInMemoryCacheHolderIterDataPipe']
5 |
6 | # %% ../../../nbs/01_DataPipes/02a_pipes.iter.cacheholder.ipynb 2
7 | # Copyright (c) Meta Platforms, Inc. and affiliates.
8 | # All rights reserved.
9 | #
10 | # This source code is licensed under the BSD-style license found in the
11 | # LICENSE file in the root directory of this source tree.
12 |
13 | # Python native modules
14 | import hashlib
15 | import inspect
16 | import os.path
17 | import sys
18 | import time
19 | import uuid
20 | import warnings
21 | from enum import IntEnum
22 |
23 | from collections import deque
24 | from functools import partial
25 | from typing import Any, Callable, Deque, Dict, Iterator, List, Optional, Tuple, TypeVar
26 | # Third party libs
27 | try:
28 | import portalocker
29 | except ImportError:
30 | portalocker = None
31 |
32 | from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE
33 |
34 | from torch.utils.data.graph import traverse_dps
35 | from torchdata.datapipes import functional_datapipe
36 | from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
37 | # Local modules
38 |
39 |
40 | # %% ../../../nbs/01_DataPipes/02a_pipes.iter.cacheholder.ipynb 4
41 | if DILL_AVAILABLE:
42 | import dill
43 |
44 | dill.extend(use_dill=False)
45 |
46 | T_co = TypeVar("T_co", covariant=True)
47 |
48 | @functional_datapipe("pickleable_in_memory_cache")
49 | class PickleableInMemoryCacheHolderIterDataPipe(IterDataPipe[T_co]):
50 | r"""
51 | Stores elements from the source DataPipe in memory, up to a size limit
52 |     if specified (functional name: ``pickleable_in_memory_cache``). This cache is FIFO - once the cache is full,
53 | further elements will not be added to the cache until the previous ones are yielded and popped off from the cache.
54 |
55 | Args:
56 | source_dp: source DataPipe from which elements are read and stored in memory
57 | size: The maximum size (in megabytes) that this DataPipe can hold in memory. This defaults to unlimited.
58 |
59 | Example:
60 | >>> from torchdata.datapipes.iter import IterableWrapper
61 | >>> source_dp = IterableWrapper(range(10))
62 | >>> cache_dp = source_dp.pickleable_in_memory_cache(size=5)
63 | >>> list(cache_dp)
64 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
65 | """
66 | size: Optional[int] = None
67 | idx: int
68 |
69 | def __init__(self, source_dp: IterDataPipe[T_co], size: Optional[int] = None) -> None:
70 | self.source_dp: IterDataPipe[T_co] = source_dp
71 | # cache size in MB
72 | if size is not None:
73 | self.size = size * 1024 * 1024
74 | self.cache: Optional[Deque] = None
75 | self.idx: int = 0
76 |
77 | def __getstate__(self):
78 | state = (
79 | self.source_dp,
80 | self.size
81 | )
82 | if IterDataPipe.getstate_hook is not None:
83 | return IterDataPipe.getstate_hook(state)
84 | return state
85 |
86 | def __setstate__(self, state):
87 | (
88 | self.source_dp,
89 | self.size
90 | ) = state
91 | self.cache: Optional[Deque] = None
92 | self.idx: int = 0
93 |
94 | def __iter__(self) -> Iterator[T_co]:
95 | if self.cache:
96 | if self.idx > 0:
97 | for idx, data in enumerate(self.source_dp):
98 | if idx < self.idx:
99 | yield data
100 | else:
101 | break
102 | yield from self.cache
103 | else:
104 | # Local cache
105 | cache: Deque = deque()
106 | idx = 0
107 | for data in self.source_dp:
108 | cache.append(data)
109 | # Cache reaches limit
110 | if self.size is not None and sys.getsizeof(cache) > self.size:
111 | cache.popleft()
112 | idx += 1
113 | yield data
114 | self.cache = cache
115 | self.idx = idx
116 |
117 | def __len__(self) -> int:
118 | try:
119 | return len(self.source_dp)
120 | except TypeError:
121 | if self.cache:
122 | return self.idx + len(self.cache)
123 | else:
124 | raise TypeError(f"{type(self).__name__} instance doesn't have valid length until the cache is loaded.")
125 |
126 |
--------------------------------------------------------------------------------
/fastrl/pipes/map/demux.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/01_DataPipes/01b_pipes.map.demux.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['T_co', 'DemultiplexerMapDataPipe']
5 |
6 | # %% ../../../nbs/01_DataPipes/01b_pipes.map.demux.ipynb 3
7 | # Python native modules
8 | import os
9 | from inspect import isfunction,ismethod
10 | from typing import Callable, Dict, Iterable, Optional, TypeVar
11 | # Third party libs
12 | from fastcore.all import *
13 | from ...torch_core import *
14 | from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
15 | import torchdata.datapipes as dp
16 | from torchdata.datapipes import functional_datapipe
17 | from torchdata.dataloader2.graph import find_dps,DataPipeGraph,Type,DataPipe,MapDataPipe,IterDataPipe
18 | from torchdata.dataloader2.dataloader2 import DataLoader2
19 | # Local modules
20 |
21 | # %% ../../../nbs/01_DataPipes/01b_pipes.map.demux.ipynb 5
22 | T_co = TypeVar("T_co", covariant=True)
23 |
24 | @functional_datapipe("demux")
25 | class DemultiplexerMapDataPipe(MapDataPipe[T_co]):
26 | def __new__(cls, datapipe: MapDataPipe, num_instances: int, classifier_fn: Callable, drop_none: bool = False,
27 | source_index: Optional[Iterable] = None):
28 | if num_instances < 1:
29 | raise ValueError(f"Expected `num_instances` larger than 0, but {num_instances} is found")
30 | _check_unpickable_fn(classifier_fn)
31 | container = _DemultiplexerMapDataPipe(datapipe, num_instances, classifier_fn, drop_none, source_index)
32 | return [_DemultiplexerChildMapDataPipe(container, i) for i in range(num_instances)]
33 |
34 | class _DemultiplexerMapDataPipe(MapDataPipe[T_co]):
35 | def __init__(
36 | self,
37 | datapipe: MapDataPipe[T_co],
38 | num_instances: int,
39 | classifier_fn: Callable[[T_co], Optional[int]],
40 | drop_none: bool,
41 | source_index: Optional[Iterable],
42 | ):
43 | self.main_datapipe = datapipe
44 | self.num_instances = num_instances
45 | self.classifier_fn = classifier_fn
46 | self.drop_none = drop_none
47 | self.iterator = None
48 | self.exhausted = False # Once we iterate through `main_datapipe` once, we know all the index mapping
49 | self.index_mapping = [[] for _ in range(num_instances)]
50 | self.source_index = source_index # if None, assume `main_datapipe` 0-index
51 |
52 | def _classify_next(self):
53 | if self.source_index is None:
54 | self.source_index = range(len(self.main_datapipe))
55 | if self.iterator is None:
56 | self.iterator = iter(self.source_index)
57 | try:
58 | next_source_idx = next(self.iterator)
59 | except StopIteration:
60 | self.exhausted = True
61 | return
62 | value = self.main_datapipe[next_source_idx]
63 | classification = self.classifier_fn(value)
64 | if classification is None and self.drop_none:
65 | self._classify_next()
66 | else:
67 | self.index_mapping[classification].append(value)
68 |
69 | def classify_all(self):
70 | while not self.exhausted:
71 | self._classify_next()
72 |
73 | def get_value(self, instance_id: int, index: int) -> T_co:
74 | while not self.exhausted and len(self.index_mapping[instance_id]) <= index:
75 | self._classify_next()
76 | if len(self.index_mapping[instance_id]) > index:
77 | return self.index_mapping[instance_id][index]
78 |         raise RuntimeError("Index out of bounds.")
79 |
80 | def __len__(self):
81 | return len(self.main_datapipe)
82 |
83 | class _DemultiplexerChildMapDataPipe(MapDataPipe[T_co]):
84 | def __init__(self, main_datapipe: _DemultiplexerMapDataPipe, instance_id: int):
85 | self.main_datapipe: _DemultiplexerMapDataPipe = main_datapipe
86 | self.instance_id = instance_id
87 |
88 | def __getitem__(self, index: int):
89 | return self.main_datapipe.get_value(self.instance_id, index)
90 |
91 | def __len__(self):
92 | self.main_datapipe.classify_all() # You have to read through the entirety of main_datapipe to know `len`
93 | return len(self.main_datapipe.index_mapping[self.instance_id])
94 |
95 | def __iter__(self):
96 | for i in range(len(self)):
97 | yield self[i]
98 |
99 | DemultiplexerMapDataPipe.__doc__ = """Splits the input DataPipe into multiple child DataPipes, using the given
100 | classification function (functional name: ``demux``). A list of the child DataPipes is returned from this operation.
101 | """
102 |
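A minimal usage sketch of the `demux` functional form registered above; the even/odd classifier and the toy data are illustrative, not from the source:

``` python
import torchdata.datapipes as dp
from fastrl.pipes.map.demux import DemultiplexerMapDataPipe  # registers `.demux` on MapDataPipe

def child_index(x):
    # classifier_fn returns the index of the child pipe each element belongs to
    return x % 2

source = dp.map.SequenceWrapper(list(range(10)))
evens, odds = source.demux(num_instances=2, classifier_fn=child_index)
assert list(evens) == [0, 2, 4, 6, 8]
assert list(odds) == [1, 3, 5, 7, 9]
```

Note that, unlike the iter-based demux shipped with torchdata, the children here are `MapDataPipe`s: indexing a child lazily pulls from `main_datapipe` until enough elements have been classified.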
--------------------------------------------------------------------------------
/fastrl/nbdev_extensions.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_nbdev_extensions.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['header', 'create_blog_notebook']
5 |
6 | # %% ../nbs/00_nbdev_extensions.ipynb 1
7 | # Python native modules
8 | import os
9 | import shutil
10 | import json
11 | from datetime import datetime
12 | from typing import Optional
13 | # Third party libs
14 | from fastcore.all import *
15 | from fastcore.all import call_parse
16 | from nbdev.config import get_config
17 | import yaml
18 | from IPython.display import display, Markdown
19 | # Local modules
20 |
21 | # %% ../nbs/00_nbdev_extensions.ipynb 5
22 | def header(
23 | # The main header title to display.
24 | title: str,
25 | # The subtitle to display underneath the title. If None, no subtitle will be displayed.
26 | subtitle: Optional[str] = None,
27 | # If True, the date associated with the header will be frozen,
28 | # meaning it won't change in subsequent runs.
29 | # If False, a new date will be generated each time the function is run,
30 | # and the date will not be saved to file.
31 | freeze: bool = False
32 | ):
33 | """
34 | Function to generate a Markdown formatted header with an associated date.
35 | Dates are auto-incremented and can be frozen. This function also controls the persistent storage of dates.
36 | """
37 | filename = 'header_dates.json'
38 | date = None
39 |     id: Optional[int] = None
40 |
41 |     # Load or initialize dates; JSON keys load as strings, so cast back to int
42 |     if os.path.exists(filename):
43 |         with open(filename, 'r') as file:
44 |             dates = {int(k): v for k, v in json.load(file).items()}
45 |     else:
46 |         dates = {}
47 |
48 | # Determine the id for the new entry
49 | if freeze:
50 | # If frozen, use the maximum id from the file, or 0 if the file is empty
51 | id = max(dates.keys(), default=0)
52 | else:
53 | # If not frozen, increment the maximum id from the file, or use 0 if the file is empty
54 | id = max(dates.keys(), default=-1) + 1
55 |
56 | # Get or create the date
57 | date = dates.get(id)
58 | if date is None:
59 | date = datetime.now().strftime('%Y-%m-%d')
60 | dates[id] = date
61 |
62 | # Only write to file if the date is frozen
63 | if freeze:
64 | with open(filename, 'w') as file:
65 | json.dump(dates, file)
66 |
67 | # Display the markdown
68 | if subtitle is None:
69 | display(Markdown(f"# `{date}` **{title}**"))
70 | else:
71 | display(Markdown(f"# `{date}` **{title}**\n> {subtitle}"))
72 |
73 |
74 | # %% ../nbs/00_nbdev_extensions.ipynb 8
75 | @call_parse
76 | def create_blog_notebook() -> None: # Creates a new blog notebook from template
77 | template = '99_blog.from_xxxx_xx_to_xx.ipynb'
78 | new_name = datetime.now().strftime('99_blog.from_%Y_%m_to_now.ipynb')
79 |
80 | # Check if the template file exists
81 | if not os.path.exists(template):
82 | raise FileNotFoundError(f"Template file '{template}' not found in current directory.")
83 |
84 | # Rename old notebooks and update sidebar.yml
85 | sidebar_file = '../sidebar.yml'
86 | with open(sidebar_file, 'r') as f:
87 | sidebar = yaml.safe_load(f)
88 |
89 | blog_section = None
90 |     # Find the Blog section in the sidebar contents
91 |     for section in sidebar['website']['sidebar']['contents']:
92 |         if 'section' in section and section['section'] == 'Blog':
93 | blog_section = section['contents']
94 | break
95 |
96 | # Rename old notebooks
97 | for filename in os.listdir():
98 | if filename.startswith('99_blog.from_') and filename.endswith('_to_now.ipynb'):
99 | date_from = filename[13:20] # corrected substring indexing
100 | date_to = datetime.now().strftime('%Y_%m')
101 | new_filename = f'99_blog.from_{date_from}_to_{date_to}.ipynb'
102 | os.rename(filename, new_filename)
103 |
104 | if blog_section is not None:
105 | # Update sidebar.yml
106 | old_entry = f'12_Blog/{filename}'
107 | new_entry = f'12_Blog/{new_filename}'
108 | if old_entry in blog_section:
109 | blog_section.remove(old_entry)
110 | blog_section.append(new_entry)
111 |
112 | # Add new notebook to sidebar.yml
113 |     if blog_section is not None and f'12_Blog/{new_name}' not in blog_section:
114 | blog_section.append(f'12_Blog/{new_name}')
115 |
116 | with open(sidebar_file, 'w') as f:
117 | yaml.safe_dump(sidebar, f)
118 |
119 | # Create new notebook from template
120 | shutil.copy(template, new_name)
121 |
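As a usage sketch, `header` is intended to be called at the top of a blog-notebook cell; the title and subtitle below are made up:

``` python
from fastrl.nbdev_extensions import header

# Renders markdown like "# `2023-06-01` **My Post**" with the subtitle quoted below it.
# freeze=False generates a fresh date each run and writes nothing to header_dates.json;
# freeze=True pins the date and persists it to that file.
header("My Post", subtitle="Notes on fastrl", freeze=False)
```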
--------------------------------------------------------------------------------
/fastrl/agents/dqn/target.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/07_Agents/01_Discrete/12h_agents.dqn.target.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['TargetModelUpdater', 'TargetModelQCalc', 'DQNTargetLearner']
5 |
6 | # %% ../../../nbs/07_Agents/01_Discrete/12h_agents.dqn.target.ipynb 2
7 | # Python native modules
8 | from copy import deepcopy
9 | from typing import Optional,Callable,Tuple
10 | # Third party libs
11 | import torchdata.datapipes as dp
12 | from torchdata.dataloader2.graph import traverse_dps,DataPipe
13 | import torch
14 | from torch import nn,optim
15 | # Local modules
16 | from ...pipes.core import find_dp
17 | from ...memory.experience_replay import ExperienceReplay
18 | from ...loggers.core import BatchCollector,EpochCollector
19 | from ...learner.core import LearnerBase,LearnerHead
20 | from ...loggers.vscode_visualizers import VSCodeDataPipe
21 | from ...loggers.core import ProgressBarLogger
22 | from fastrl.agents.dqn.basic import (
23 | LossCollector,
24 | RollingTerminatedRewardCollector,
25 | EpisodeCollector,
26 | StepBatcher,
27 | TargetCalc,
28 | LossCalc,
29 | ModelLearnCalc,
30 | DQN,
31 | DQNAgent
32 | )
33 |
34 | # %% ../../../nbs/07_Agents/01_Discrete/12h_agents.dqn.target.ipynb 7
35 | class TargetModelUpdater(dp.iter.IterDataPipe):
36 | def __init__(self,source_datapipe,target_sync=300):
37 | self.source_datapipe = source_datapipe
38 | self.target_sync = target_sync
39 | self.n_batch = 0
40 | self.learner = find_dp(traverse_dps(self),LearnerBase)
41 | with torch.no_grad():
42 | self.learner.target_model = deepcopy(self.learner.model)
43 |
44 | def reset(self):
45 | self.learner = find_dp(traverse_dps(self),LearnerBase)
46 | with torch.no_grad():
47 | self.learner.target_model = deepcopy(self.learner.model)
48 |
49 |     def __iter__(self):
50 |         # Compare to the enum member; a bare `.NotStarted` attribute access is always truthy
51 |         if self._snapshot_state == self._snapshot_state.NotStarted: self.reset()
52 | for batch in self.source_datapipe:
53 | if self.n_batch%self.target_sync==0:
54 | with torch.no_grad():
55 | self.learner.target_model.load_state_dict(self.learner.model.state_dict())
56 | self.n_batch+=1
57 | yield batch
58 |
59 | # %% ../../../nbs/07_Agents/01_Discrete/12h_agents.dqn.target.ipynb 8
60 | class TargetModelQCalc(dp.iter.IterDataPipe):
61 | def __init__(self,source_datapipe=None):
62 | self.source_datapipe = source_datapipe
63 |
64 | def __iter__(self):
65 | self.learner = find_dp(traverse_dps(self),LearnerBase)
66 | for batch in self.source_datapipe:
67 | self.learner.done_mask = batch.terminated.reshape(-1,)
68 | with torch.no_grad():
69 | self.learner.next_q = self.learner.target_model(batch.next_state)
70 | self.learner.next_q = self.learner.next_q.max(dim=1).values.reshape(-1,1)
71 | self.learner.next_q[self.learner.done_mask] = 0
72 | yield batch
73 |
74 | # %% ../../../nbs/07_Agents/01_Discrete/12h_agents.dqn.target.ipynb 9
75 | def DQNTargetLearner(
76 | model,
77 | dls,
78 | do_logging:bool=True,
79 | loss_func=nn.MSELoss(),
80 | opt=optim.AdamW,
81 | lr=0.005,
82 | bs=128,
83 | max_sz=10000,
84 | nsteps=1,
85 | device=None,
86 | batches=None,
87 | target_sync=300
88 | ) -> LearnerHead:
89 | learner = LearnerBase(model,dls=dls[0])
90 | learner = BatchCollector(learner,batches=batches)
91 | learner = EpochCollector(learner)
92 | if do_logging:
93 | learner = learner.dump_records()
94 | learner = ProgressBarLogger(learner)
95 | learner = RollingTerminatedRewardCollector(learner)
96 | learner = EpisodeCollector(learner).catch_records()
97 | learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)
98 | learner = StepBatcher(learner,device=device)
99 | learner = TargetModelQCalc(learner)
100 | learner = TargetCalc(learner,nsteps=nsteps)
101 | learner = LossCalc(learner,loss_func=loss_func)
102 | learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))
103 | learner = TargetModelUpdater(learner,target_sync=target_sync)
104 | if do_logging:
105 | learner = LossCollector(learner).catch_records()
106 |
107 | if len(dls)==2:
108 | val_learner = LearnerBase(model,dls[1]).visualize_vscode()
109 | val_learner = BatchCollector(val_learner,batches=batches)
110 | val_learner = EpochCollector(val_learner).catch_records(drop=True)
111 | return LearnerHead((learner,val_learner))
112 | else:
113 | return LearnerHead(learner)
114 |
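The target-network logic above boils down to a counter-driven weight copy. A self-contained sketch of the same sync rule, with a toy model and constants assumed:

``` python
from copy import deepcopy

import torch
from torch import nn

model = nn.Linear(4, 2)    # stand-in for the online DQN
target = deepcopy(model)   # frozen copy used for next-state Q values
target_sync = 300          # copy weights every `target_sync` batches

for n_batch in range(1000):  # stand-in for the stream of training batches
    if n_batch % target_sync == 0:
        with torch.no_grad():
            target.load_state_dict(model.state_dict())
    # ...train `model` on the batch, computing next_q with `target`...
```

Keeping `next_q` on a stale copy decouples the regression target from the weights being updated, which is what stabilizes plain DQN training.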
--------------------------------------------------------------------------------
/nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "durable-dialogue",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "#|hide\n",
11 | "from fastrl.test_utils import initialize_notebook\n",
12 | "initialize_notebook()"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "id": "offshore-stuart",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "#|default_exp agents.dqn.dueling"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "id": "assisted-contract",
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "#|export\n",
33 | "# Python native modules\n",
34 | "# Third party libs\n",
35 | "import torch\n",
36 | "from torch import nn\n",
37 | "# Local modules\n",
38 | "from fastrl.agents.dqn.basic import (\n",
39 | " DQN,\n",
40 | " DQNAgent\n",
41 | ")\n",
42 | "from fastrl.agents.dqn.target import DQNTargetLearner"
43 | ]
44 | },
45 | {
46 | "cell_type": "markdown",
47 | "id": "lesser-innocent",
48 | "metadata": {},
49 | "source": [
50 | "# DQN Dueling\n",
51 | "> DQN using a split head that separates state value from the advantage of each action"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "id": "30c98be0-6288-443a-b4ab-9390fbe3081c",
57 | "metadata": {},
58 | "source": [
59 | "\n",
60 | "\n",
61 | "## Training DataPipes"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": null,
67 | "id": "d921085e-7d53-40ac-9b37-56a31b15d47c",
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "#|export\n",
72 | "class DuelingHead(nn.Module):\n",
73 | " def __init__(\n",
74 | " self,\n",
75 | " hidden: int, # Input into the DuelingHead, likely a hidden layer input\n",
76 | " n_actions: int, # Number/dim of actions to output\n",
77 | " lin_cls = nn.Linear\n",
78 | " ):\n",
79 | " super().__init__()\n",
80 | " self.val = lin_cls(hidden,1)\n",
81 | " self.adv = lin_cls(hidden,n_actions)\n",
82 | "\n",
83 | " def forward(self,xi):\n",
84 | " val,adv = self.val(xi),self.adv(xi)\n",
85 | "        xi = val.expand_as(adv)+(adv-adv.mean(dim=1,keepdim=True))\n",
86 | " return xi"
87 | ]
88 | },
89 | {
90 | "cell_type": "markdown",
91 | "id": "2b8f9ed8-fb05-40a1-ac0d-d4cafee8fa07",
92 | "metadata": {},
93 | "source": [
94 | "Try training with basic defaults..."
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "id": "56acb9ff",
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "from fastrl.envs.gym import GymDataPipe\n",
105 | "from fastrl.dataloading.core import dataloaders"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": null,
111 | "id": "63d9b481-5998-472a-a2df-18d79bf07ae2",
112 | "metadata": {},
113 | "outputs": [],
114 | "source": [
115 | "#|eval: false\n",
116 | "# Setup up the core NN\n",
117 | "torch.manual_seed(0)\n",
118 | "model = DQN(4,2,head_layer=DuelingHead)\n",
119 | "# Setup the Agent\n",
120 | "model.train()\n",
121 | "agent = DQNAgent(model,do_logging=True,min_epsilon=0.02,max_epsilon=1,max_steps=5000)\n",
122 | "# Setup the Dataloaders\n",
123 | "params = dict(\n",
124 | " source=['CartPole-v1']*1,\n",
125 | " agent=agent,\n",
126 | " nsteps=2,\n",
127 | " nskips=2,\n",
128 | " firstlast=True\n",
129 | ")\n",
130 | "dls = dataloaders((GymDataPipe(**params),GymDataPipe(**params,include_images=True).unbatch()))\n",
131 | "# Setup the Learner\n",
132 | "learner = DQNTargetLearner(\n",
133 | " model,\n",
134 | " dls,\n",
135 | " bs=128,\n",
136 | " max_sz=100_000,\n",
137 | " nsteps=2,\n",
138 | " lr=0.01,\n",
139 | " batches=1000,\n",
140 | " target_sync=300\n",
141 | ")\n",
142 | "learner.fit(7)"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": null,
148 | "id": "cf3c74e6",
149 | "metadata": {},
150 | "outputs": [],
151 | "source": [
152 | "#|eval: false\n",
153 | "learner.validate()"
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "execution_count": null,
159 | "id": "current-pilot",
160 | "metadata": {},
161 | "outputs": [],
162 | "source": [
163 | "#|hide\n",
164 | "#|eval: false\n",
165 | "!nbdev_export"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": null,
171 | "id": "ab3d3626-1702-4f22-ae15-bb93a75bec68",
172 | "metadata": {},
173 | "outputs": [],
174 | "source": []
175 | }
176 | ],
177 | "metadata": {
178 | "kernelspec": {
179 | "display_name": "python3",
180 | "language": "python",
181 | "name": "python3"
182 | }
183 | },
184 | "nbformat": 4,
185 | "nbformat_minor": 5
186 | }
187 |
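For reference, `DuelingHead` implements the aggregation from Wang et al. (2016): Q(s,a) = V(s) + (A(s,a) − mean over a′ of A(s,a′)). A tiny shape check of the exported module, with arbitrary sizes:

``` python
import torch
from fastrl.agents.dqn.dueling import DuelingHead

head = DuelingHead(hidden=8, n_actions=2)
q = head(torch.rand(5, 8))   # a batch of 5 hidden vectors
assert q.shape == (5, 2)     # one Q value per action
```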
--------------------------------------------------------------------------------
/fastrl/memory/experience_replay.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/04_Memory/06a_memory.experience_replay.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['ExperienceReplay']
5 |
6 | # %% ../../nbs/04_Memory/06a_memory.experience_replay.ipynb 2
7 | # Python native modules
8 | from copy import copy
9 | # Third party libs
10 | from fastcore.all import add_docs,ifnone
11 | import torchdata.datapipes as dp
12 | import numpy as np
13 | import torch
14 | # Local modules
15 | from ..core import StepTypes
16 |
17 | # %% ../../nbs/04_Memory/06a_memory.experience_replay.ipynb 4
18 | class ExperienceReplay(dp.iter.IterDataPipe):
19 | debug=False
20 | def __init__(self,
21 | source_datapipe,
22 | learner=None,
23 | bs=1,
24 | max_sz=100,
25 | return_idxs=False,
26 | # If the `self.device` is not cpu, and `store_as_cpu=True`, then
27 | # calls to `sample()` will dynamically move them to `self.device`, and
28 | # next `sample()` will move them back to cpu before producing new samples.
29 | # This can be slower, but can save vram.
30 | # If `store_as_cpu=False`, then samples stay on `self.device`
31 | #
32 | # If being run with n_workers>0, shared_memory, and fork, this MUST be true. This is needed because
33 | # otherwise the tensors in the memory will remain shared with the tensors created in the
34 | # dataloader.
35 | store_as_cpu:bool=True,
36 | # When `max_sz` is reached, no new records will be added to the memory.
37 | # This is useful for debugging since a model should be able to
38 | # reach a loss of 0 learning on a static set.
39 | freeze_memory:bool=False
40 | ):
41 | self.memory = np.array([None]*max_sz)
42 | self.source_datapipe = source_datapipe
43 | self.learner = learner
44 | if learner is not None:
45 | self.learner.experience_replay = self
46 | self.bs = bs
47 | self.freeze_memory = freeze_memory
48 | self.max_sz = max_sz
49 | self._sz_tracker = 0
50 | self._idx_tracker = 0
51 | self._cycle_tracker = 0
52 | self.return_idxs = return_idxs
53 | self.store_as_cpu = store_as_cpu
54 | self._last_idx = None
55 | self.device = None
56 |
57 | def to(self,*args,**kwargs):
58 | self.device = kwargs.get('device',None)
59 |
60 | def sample(self,bs=None):
61 | idxs = np.random.choice(range(self._sz_tracker),size=(ifnone(bs,self.bs),),replace=False)
62 | if self.return_idxs: return self.memory[idxs],idxs
63 | self._last_idx = idxs
64 | return [o.to(device=self.device) for o in self.memory[idxs]]
65 |
66 | def __repr__(self):
67 | return str({k:v if k!='memory' else f'{len(self)} elements' for k,v in self.__dict__.items()})
68 |
69 | def __len__(self): return self._sz_tracker
70 |
71 | def show(self, agent=None):
72 | from fastrl.memory.memory_visualizer import MemoryBufferViewer
73 | return MemoryBufferViewer(self.memory,agent=agent)
74 |
75 | def __iter__(self):
76 | for i,b in enumerate(self.source_datapipe):
77 | if self.debug: print('Experience Replay Adding: ',b)
78 |
79 | if not issubclass(b.__class__,(*StepTypes.types,list,tuple)):
80 |                 raise Exception(f'Expected typing.NamedTuple,list,tuple object got {type(b)}\n{b}')
81 |
82 | if issubclass(b.__class__,StepTypes.types): self.add(b)
83 | elif issubclass(b.__class__,(list,tuple)):
84 | for step in b: self.add(step)
85 | else:
86 | raise Exception(f'This should not have occured: {self.__dict__}')
87 |
88 |             if self._sz_tracker<self.bs: continue
89 |             yield self.sample()
90 | 
91 |     def add(self,step):
92 |         # Steps are moved to cpu on add by default (see `store_as_cpu` above).
93 |         if self.store_as_cpu:
94 |             step = step.detach().to(device='cpu')
95 |         if self._sz_tracker==0:
96 |             self.memory[self._idx_tracker] = step
97 |             self._sz_tracker += 1
98 |             self._idx_tracker += 1
99 |         elif 0<self._sz_tracker<self.max_sz:
100 |             self.memory[self._idx_tracker] = step
101 |             self._sz_tracker += 1
102 |             self._idx_tracker += 1
103 |         elif self._sz_tracker>=self.max_sz:
104 | if not self.freeze_memory:
105 | if self._idx_tracker>=self.max_sz:
106 | self._idx_tracker = 0
107 | self._cycle_tracker += 1
108 | self.memory[self._idx_tracker] = step
109 | self._idx_tracker += 1
110 | else:
111 | raise Exception(f'This should not have occured: {self.__dict__}')
112 |
113 | add_docs(
114 | ExperienceReplay,
115 |     """Simplest form of memory. Takes steps from `source_datapipe` and stores them in `memory`.
116 | It outputs `bs` steps.""",
117 | sample="Returns `bs` steps from `memory` in a uniform distribution.",
118 |     add="Adds new steps to `memory`. Once `memory` reaches size `max_sz`, new steps overwrite the earliest entries (unless `freeze_memory=True`).",
119 | to=torch.Tensor.to.__doc__,
120 | show="Displays a ipywidget to look at the steps in `self.memory`"
121 | )
122 |
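A minimal sketch of the buffer's `add`/`sample` API, using plain tensors as stand-ins for the step objects a real pipeline would feed through `__iter__`:

``` python
import torch
import torchdata.datapipes as dp
from fastrl.memory.experience_replay import ExperienceReplay

er = ExperienceReplay(dp.iter.IterableWrapper([]), bs=4, max_sz=100)
for i in range(10):
    er.add(torch.tensor([[float(i)]]))  # stored on cpu by default (store_as_cpu=True)
batch = er.sample()                      # 4 entries drawn uniformly without replacement
assert len(er) == 10 and len(batch) == 4
```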
--------------------------------------------------------------------------------
/nbs/02_DataLoading/00_core.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "durable-dialogue",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "#|hide\n",
11 | "from fastrl.test_utils import initialize_notebook\n",
12 | "initialize_notebook()"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "id": "offshore-stuart",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "#|default_exp dataloading.core"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "id": "assisted-contract",
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "#|export\n",
33 | "# Python native modules\n",
34 | "from typing import Tuple,Union,List\n",
35 | "# Third party libs\n",
36 | "import torchdata.datapipes as dp\n",
37 | "from torchdata.dataloader2 import MultiProcessingReadingService,DataLoader2\n",
38 | "from fastcore.all import delegates\n",
39 | "# Local modules"
40 | ]
41 | },
42 | {
43 | "attachments": {},
44 | "cell_type": "markdown",
45 | "id": "a258abcf",
46 | "metadata": {},
47 | "source": [
48 | "# Dataloading Core\n",
49 | "> Basic utils for creating dataloaders from rl datapipes."
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "id": "0c4e1268",
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "#|export\n",
60 | "@delegates(MultiProcessingReadingService)\n",
61 | "def dataloaders(\n",
62 | " # A tuple of iterable datapipes to generate dataloaders from.\n",
63 | " pipes:Union[Tuple[dp.iter.IterDataPipe],dp.iter.IterDataPipe],\n",
64 | " # Concat the dataloaders together\n",
65 | " do_concat:bool = False,\n",
66 | " # Multiplex the dataloaders\n",
67 | " do_multiplex:bool = False,\n",
68 | " # Number of workers the dataloaders should run in\n",
69 | " num_workers: int = 0,\n",
70 | " **kwargs\n",
71 | ") -> Union[dp.iter.IterDataPipe,List[dp.iter.IterDataPipe]]:\n",
72 | "    \"Function that creates dataloaders based on `pipes` with different ways of combining them.\"\n",
73 | " if not isinstance(pipes,tuple):\n",
74 | " pipes = (pipes,)\n",
75 | "\n",
76 | " dls = []\n",
77 | " for pipe in pipes:\n",
78 | " dl = DataLoader2(\n",
79 | " datapipe=pipe,\n",
80 | " reading_service=MultiProcessingReadingService(\n",
81 | " num_workers = num_workers,\n",
82 | " **kwargs\n",
83 | " ) if num_workers > 0 else None\n",
84 | " )\n",
85 | " dl = dp.iter.IterableWrapper(dl,deepcopy=False)\n",
86 | " dls.append(dl)\n",
87 | "    #TODO(josiahls): Not sure if this is needed tbh.. Might be better to just\n",
88 | "    # return dls, and have the user wrap them if they want. Then they can do more complex stuff.\n",
89 | " if do_concat:\n",
90 | " return dp.iter.Concater(*dls)\n",
91 | " elif do_multiplex:\n",
92 | " return dp.iter.Multiplexer(*dls)\n",
93 | " else:\n",
94 | " return dls\n"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "id": "91b2f9c2",
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "from fastcore.test import test_eq"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": null,
110 | "id": "fc24a6bb",
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "# Sample Data\n",
115 | "pipe1 = dp.iter.IterableWrapper([1, 2, 3])\n",
116 | "pipe2 = dp.iter.IterableWrapper([4, 5, 6])\n",
117 | "\n",
118 | "# Test for a single IterDataPipe\n",
119 | "dls = dataloaders(pipe1)\n",
120 | "assert len(dls) == 1\n",
121 | "assert isinstance(dls[0], dp.iter.IterableWrapper)\n",
122 | "test_eq(list(dls[0]), [1, 2, 3])\n",
123 | "\n",
124 | "# Test for a tuple of IterDataPipes without concatenation or multiplexing\n",
125 | "dls = dataloaders((pipe1, pipe2))\n",
126 | "test_eq(len(dls),2)\n",
127 | "test_eq(list(dls[0]), [1, 2, 3])\n",
128 | "test_eq(list(dls[1]), [4, 5, 6])\n",
129 | "\n",
130 | "# Test for concatenation\n",
131 | "dl = dataloaders((pipe1, pipe2), do_concat=True)\n",
132 | "assert isinstance(dl, dp.iter.Concater)\n",
133 | "test_eq(list(dl), [1, 2, 3, 4, 5, 6])\n",
134 | "\n",
135 | "# Test for multiplexing\n",
136 | "dl = dataloaders((pipe1, pipe2), do_multiplex=True)\n",
137 | "assert isinstance(dl, dp.iter.Multiplexer)\n",
138 | "test_eq(list(dl), [1, 4, 2, 5, 3, 6])"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": null,
144 | "id": "current-pilot",
145 | "metadata": {},
146 | "outputs": [],
147 | "source": [
148 | "#|hide\n",
149 | "#|eval: false\n",
150 | "!nbdev_export"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "id": "ed71a089",
157 | "metadata": {},
158 | "outputs": [],
159 | "source": []
160 | }
161 | ],
162 | "metadata": {
163 | "kernelspec": {
164 | "display_name": "python3",
165 | "language": "python",
166 | "name": "python3"
167 | }
168 | },
169 | "nbformat": 4,
170 | "nbformat_minor": 5
171 | }
172 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | fastrl
2 | ================
3 |
4 |
5 |
6 | [](https://github.com/josiahls/fastrl/actions?query=workflow%3A%22Fastrl+Testing%22)
8 | [](https://pypi.python.org/pypi/fastrl)
10 | [](https://hub.docker.com/repository/docker/josiahls/fastrl)
12 | [](https://hub.docker.com/repository/docker/josiahls/fastrl-dev)
14 |
15 | [](https://pypi.python.org/pypi/fastrl)
17 | [](https://pypi.python.org/pypi/fastrl)
19 |
20 | > Warning: This is in alpha, and so tracks the latest torch and, most
21 | > importantly, the latest torchdata. The base API, while semi-stable,
22 | > might change in future versions, and so there are no promises of
23 | > backward compatibility. For the time being, it is best to hard-pin
24 | > versions of the library.
25 |
26 | > Warning: Even before fastrl==2.0.0, all Models should converge
27 | > reasonably fast, however HRL models `DADS` and `DIAYN` will need
28 | > re-balancing and some extra features that the respective authors used.
29 |
30 | # Overview
31 |
32 | Fastai has been amazing for computer vision and tabular learning. One
33 | would wish the same for RL. The purpose of this repo is to have a
34 | framework that is as easy as possible to start with, but is also
35 | designed for testing new agents.
36 | 
37 | This version of fastrl is essentially a wrapper around
38 | [torchdata](https://github.com/pytorch/data).
39 |
40 | It is built around 4 pipeline concepts (half of which come from fastai):
41 |
42 | - DataLoading/DataBlock pipelines
43 | - Agent pipelines
44 | - Learner pipelines
45 | - Logger plugins
46 |
47 | Documentation is served at https://josiahls.github.io/fastrl/,
48 | generated directly from the notebooks in this repo via `nbdev`.
49 |
50 | Basic DQN example (using a target network):
51 |
52 | ``` python
53 | import torch
54 | from fastrl.agents.dqn.basic import DQN,DQNAgent
55 | from fastrl.agents.dqn.target import DQNTargetLearner
56 | from fastrl.envs.gym import GymDataPipe
57 | from fastrl.dataloading.core import dataloaders
58 | ```
59 | 
60 | ``` python
61 | # Set up the core NN
62 | torch.manual_seed(0)
63 | model = DQN(4,2)
64 | # Set up the Agent
65 | agent = DQNAgent(model,do_logging=True,min_epsilon=0.02,max_epsilon=1,max_steps=5000)
66 | # Set up the DataLoaders; nsteps/nskips basically merge 2 env steps into 1 and skip.
67 | params = dict(
68 |     source=['CartPole-v1']*1,
69 |     agent=agent,
70 |     nsteps=2,
71 |     nskips=2,
72 |     firstlast=True
73 | )
74 | dls = dataloaders((GymDataPipe(**params),
75 |                    GymDataPipe(**params,include_images=True).unbatch()))
76 | # Set up the Learner; DQNTargetLearner wires the target-model pipes into the graph.
77 | learner = DQNTargetLearner(
78 |     model,
79 |     dls,
80 |     bs=128,
81 |     max_sz=100_000,
82 |     nsteps=2,
83 |     lr=0.01,
84 |     batches=1000,
85 |     target_sync=300
86 | )
87 | learner.fit(7)
88 | ```
89 |
90 | # What's new?
91 |
92 | While learning how to support as many RL agents as possible, we found
93 | that `fastrl==1.*` was vastly limited in the models it could support.
94 | `fastrl==2.*` leverages the `nbdev` library for better documentation
95 | and more relevant testing, and `torchdata` is the base lib. We will
96 | also be building on the work of the `ptan` [1]
97 | library as a close reference for pytorch-based reinforcement learning
98 | APIs.
99 |
100 | [1] “Shmuma/Ptan”. GitHub, 2020,
101 | https://github.com/Shmuma/ptan. Accessed 13 June 2020.
102 |
103 | ## Install
104 |
105 | ## PyPI
106 |
107 | The following commands install the alpha build of fastrl.
108 |
109 | **Cuda Install**
110 |
111 | `pip install fastrl==0.0.* --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu113`
112 |
113 | **Cpu Install**
114 |
115 | `pip install fastrl==0.0.* --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu`
116 |
117 | ## Docker (highly recommended)
118 |
119 | Install:
120 | [Nvidia-Docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker)
121 |
122 | Install: [docker-compose](https://docs.docker.com/compose/install/)
123 |
124 | ``` bash
125 | docker-compose pull && docker-compose up
126 | ```
127 |
128 | ## Contributing
129 |
130 | After you clone this repository, please run `nbdev_install_hooks` in
131 | your terminal. This sets up git hooks, which clean up the notebooks to
132 | remove the extraneous stuff stored in the notebooks (e.g. which cells
133 | you ran) which causes unnecessary merge conflicts.
134 |
135 | Before submitting a PR, check that the local library and notebooks
136 | match. The script `nbdev_clean` can let you know if there is a
137 | difference between the local library and the notebooks.
138 | 
139 | - If you made a change to the notebooks in one of the exported cells,
140 |   you can export it to the library with `nbdev_export`.
141 | - If you made a change to the library, you can sync it back to the notebooks with `nbdev_update`.
142 |
--------------------------------------------------------------------------------
/fastrl.Dockerfile:
--------------------------------------------------------------------------------
1 | # FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime
2 | # FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime
3 | FROM nvidia/cuda:12.0.0-runtime-ubuntu20.04
4 | # RUN conda install python=3.8
5 |
6 | ENV CONTAINER_USER fastrl_user
7 | ENV CONTAINER_GROUP fastrl_group
8 | ENV CONTAINER_UID 1000
9 | # Create a non-root user and group
10 | RUN addgroup --gid $CONTAINER_UID $CONTAINER_GROUP && \
11 | adduser --uid $CONTAINER_UID --gid $CONTAINER_UID $CONTAINER_USER --disabled-password
12 | # && \
13 | # mkdir -p /opt/conda && chown $CONTAINER_USER /opt/conda
14 |
15 | RUN apt-get update && apt-get install -y software-properties-common rsync curl gcc g++
16 | #RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-key C99B11DEB97541F0 && apt-add-repository https://cli.github.com/packages
17 |
18 | RUN curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg | gpg --dearmor -o /usr/share/keyrings/githubcli-archive-keyring.gpg
19 | RUN echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null
20 |
21 | RUN apt-get install -y build-essential python3.8-dev python3.8-distutils
22 | RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.8 2 && update-alternatives --config python
23 | RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3.8 get-pip.py
24 |
25 | RUN apt-get update && apt-get install -y git libglib2.0-dev graphviz libxext6 \
26 | libsm6 libxrender1 python-opengl xvfb nano gh tree wget libosmesa6-dev \
27 | libgl1-mesa-glx libglfw3 && apt-get update
28 |
29 |
30 | WORKDIR /home/$CONTAINER_USER
31 | # Install Primary Pip Reqs
32 | # ENV PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/nightly/cu117
33 | COPY --chown=$CONTAINER_USER:$CONTAINER_GROUP extra/requirements.txt /home/$CONTAINER_USER/extra/requirements.txt
34 | # Since we are using a custom fork of torchdata, we install torchdata as part of a submodule.
35 | # RUN pip3 install -r extra/requirements.txt --pre --upgrade
36 | RUN pip3 install torch>=2.0.0
37 | # --pre --upgrade
38 | RUN pip3 show torch
39 |
40 | COPY --chown=$CONTAINER_USER:$CONTAINER_GROUP extra/pip_requirements.txt /home/$CONTAINER_USER/extra/pip_requirements.txt
41 | RUN pip3 install -r extra/pip_requirements.txt
42 |
43 | WORKDIR /home/$CONTAINER_USER/fastrl
44 | RUN git clone https://github.com/josiahls/data.git \
45 | && cd data && pip3 install -e .
46 | WORKDIR /home/$CONTAINER_USER
47 |
48 | # Install Dev Reqs
49 | COPY --chown=$CONTAINER_USER:$CONTAINER_GROUP extra/dev_requirements.txt /home/$CONTAINER_USER/extra/dev_requirements.txt
50 | ARG BUILD=dev
51 | # Needed for gymnasium[all] when installing Box2d
52 | RUN apt-get install -y swig3.0 && ln -s /usr/bin/swig3.0 /usr/bin/swig
53 | RUN /bin/bash -c "if [[ $BUILD == 'dev' ]] ; then echo \"Development Build\" && pip3 install -r extra/dev_requirements.txt ; fi"
54 | # RUN /bin/bash -c "if [[ $BUILD == 'dev' ]] ; then wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz; fi"
55 | # RUN /bin/bash -c "if [[ $BUILD == 'dev' ]] ; then echo \"Development Build\" && conda install -c conda-forge nodejs==15.14.0 line_profiler && jupyter labextension install jupyterlab-plotly; fi"
56 |
57 | # RUN chown $CONTAINER_USER:$CONTAINER_GROUP -R /opt/conda/bin
58 | RUN chown $CONTAINER_USER:$CONTAINER_GROUP -R /usr/local/lib/python3.8/dist-packages/torch/utils/data/datapipes
59 | # RUN chown $CONTAINER_USER:$CONTAINER_GROUP -R /usr/local/lib/python3.8/dist-packagess/mujoco_py
60 | # RUN chown $CONTAINER_USER:$CONTAINER_GROUP -R /usr/local/lib/python3.8/dist-packages
61 |
62 | RUN pip3 show torch
63 |
64 | RUN chown $CONTAINER_USER:$CONTAINER_GROUP -R /home/$CONTAINER_USER
65 |
66 | RUN apt-get install -y sudo
67 | # Give user password-less sudo access
68 | RUN echo "$CONTAINER_USER ALL=(ALL) NOPASSWD: ALL" > /etc/sudoers.d/$CONTAINER_USER && \
69 | chmod 0440 /etc/sudoers.d/$CONTAINER_USER
70 |
71 | RUN /bin/bash -c "if [[ $BUILD == 'dev' ]] ; then nbdev_install_quarto ; fi"
72 |
73 | # RUN mkdir -p /home/$CONTAINER_USER/.mujoco \
74 | # && wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz \
75 | # && tar -xf mujoco.tar.gz -C /home/$CONTAINER_USER/.mujoco \
76 | # && rm mujoco.tar.gz
77 |
78 | # ENV LD_LIBRARY_PATH /home/$CONTAINER_USER/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH}
79 | # ENV LD_LIBRARY_PATH /usr/local/nvidia/lib64:${LD_LIBRARY_PATH}
80 |
81 | # RUN ln -s /usr/lib/x86_64-linux-gnu/libGL.so.1 /usr/lib/x86_64-linux-gnu/libGL.so
82 |
83 | USER $CONTAINER_USER
84 | WORKDIR /home/$CONTAINER_USER
85 | ENV PATH="/home/$CONTAINER_USER/.local/bin:${PATH}"
86 |
87 | # RUN git clone https://github.com/josiahls/fastrl.git --depth 1
88 | RUN pip install setuptools==60.7.0
89 | COPY --chown=$CONTAINER_USER:$CONTAINER_GROUP . fastrl
90 |
91 | RUN sudo apt-get -y install cmake python3.8-venv
92 |
93 | # RUN curl https://get.modular.com | sh - && \
94 | # modular auth mut_9b52dfea7b05427385fdeddc85dd3a64 && \
95 | # modular install mojo
96 |
97 | RUN BASHRC=$( [ -f "$HOME/.bash_profile" ] && echo "$HOME/.bash_profile" || echo "$HOME/.bashrc" ) && \
98 | echo 'export MODULAR_HOME="/home/fastrl_user/.modular"' >> "$BASHRC" && \
99 | echo 'export PATH="/home/fastrl_user/.modular/pkg/packages.modular.com_mojo/bin:$PATH"' >> "$BASHRC" && \
100 |     . "$BASHRC"  # POSIX form of 'source'; docker RUN uses /bin/sh by default
101 |
102 | # RUN /bin/bash -c "if [[ $BUILD == 'dev' ]] ; then echo \"Development Build\" && cd fastrl/data && mv pyproject.toml pyproject.toml_tmp && pip install -e . --no-dependencies && mv pyproject.toml_tmp pyproject.toml && cd ../; fi"
103 |
104 | RUN /bin/bash -c "if [[ $BUILD == 'prod' ]] ; then echo \"Production Build\" && cd fastrl && pip install . --no-dependencies; fi"
105 | RUN /bin/bash -c "if [[ $BUILD == 'dev' ]] ; then echo \"Development Build\" && cd fastrl && pip install -e \".[dev]\" --no-dependencies ; fi"
106 |
107 | # RUN echo '#!/bin/bash\npip install -e .[dev] --no-dependencies && xvfb-run -s "-screen 0 1400x900x24" jupyter lab --ip=0.0.0.0 --port=8080 --allow-root --no-browser --NotebookApp.token='' --NotebookApp.password=''' >> run_jupyter.sh
108 |
109 | RUN /bin/bash -c "cd fastrl && pip install -e . --no-dependencies"
110 |
111 |
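For reference, the image is normally built and run through `docker-compose` (see the README), but it can also be built directly with something like `docker build -f fastrl.Dockerfile --build-arg BUILD=dev -t fastrl:dev .`, where the tag is arbitrary and `BUILD` (`dev` or `prod`) selects which of the conditional install steps above run.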
--------------------------------------------------------------------------------
/nbs/05_Logging/09d_loggers.jupyter_visualizers.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "durable-dialogue",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "#|hide\n",
11 | "from fastrl.test_utils import initialize_notebook\n",
12 | "initialize_notebook()"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "id": "assisted-contract",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "#|export\n",
23 | "# Python native modules\n",
24 | "import os\n",
25 | "from torch.multiprocessing import Queue\n",
26 | "from typing import Tuple,NamedTuple\n",
27 | "# Third party libs\n",
28 | "from fastcore.all import add_docs\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "import torchdata.datapipes as dp\n",
31 | "from IPython.display import display, clear_output\n",
32 | "import torch\n",
33 | "import numpy as np\n",
34 | "# Local modules\n",
35 | "from fastrl.core import Record\n",
36 | "from fastrl.loggers.core import LoggerBase,LogCollector,is_record\n",
37 | "# from fastrl.torch_core import *"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": null,
43 | "id": "offshore-stuart",
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "#|default_exp loggers.jupyter_visualizers"
48 | ]
49 | },
50 | {
51 | "attachments": {},
52 | "cell_type": "markdown",
53 | "id": "lesser-innocent",
54 | "metadata": {},
55 | "source": [
56 | "# Visualizers \n",
57 | "> Iterable pipes for displaying environments as they run using `typing.NamedTuples` with `image` fields"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "id": "9cb488d2-41f1-4160-a938-e39003f1a06a",
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "#|export\n",
68 | "class SimpleJupyterVideoPlayer(dp.iter.IterDataPipe):\n",
69 | " def __init__(self, \n",
70 | " source_datapipe=None, \n",
71 | " between_frame_wait_seconds:float=0.1\n",
72 | " ):\n",
73 | " self.source_datapipe = source_datapipe\n",
74 | "        self.between_frame_wait_seconds = between_frame_wait_seconds\n",
75 | "\n",
76 | " def dequeue(self): \n",
77 | " while self.buffer: yield self.buffer.pop(0)\n",
78 | "\n",
79 | " \n",
80 | " def __iter__(self) -> Tuple[NamedTuple]:\n",
81 | " img = None\n",
82 | " for record in self.source_datapipe:\n",
83 | " # for o in self.dequeue():\n",
84 | " if is_record(record):\n",
85 | " if record.value is None: continue\n",
86 | " if img is None: img = plt.imshow(record.value)\n",
87 | " img.set_data(record.value) \n",
88 | " plt.axis('off')\n",
89 | " display(plt.gcf())\n",
90 | " clear_output(wait=True)\n",
91 | " yield record\n",
92 | "add_docs(\n",
93 | " SimpleJupyterVideoPlayer,\n",
94 | " \"\"\"Displays video from a `source_datapipe` that produces `typing.NamedTuples` that contain an `image` field.\n",
95 | " This only can handle 1 env input.\"\"\",\n",
96 | " dequeue=\"Grabs records from the `main_queue` and attempts to display them\"\n",
97 | ")"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "id": "d4e46c41-f7f9-4168-b453-c43ec80377f0",
104 | "metadata": {},
105 | "outputs": [],
106 | "source": [
107 | "#|export\n",
108 | "class ImageCollector(dp.iter.IterDataPipe):\n",
109 | " title:str='image'\n",
110 | "\n",
111 | " def __init__(self,source_datapipe):\n",
112 | " self.source_datapipe = source_datapipe\n",
113 | "\n",
114 | " def convert_np(self,o):\n",
115 | " if isinstance(o,torch.Tensor): return o.detach().numpy()\n",
116 | " elif isinstance(o,np.ndarray): return o\n",
117 | " else: raise ValueError(f'Expects Tensor or np.ndarray not {type(o)}')\n",
118 | " \n",
119 | " def __iter__(self):\n",
120 | " # for q in self.main_buffers: q.append(Record('image',None))\n",
121 | " yield Record(self.title,None)\n",
122 | " for steps in self.source_datapipe:\n",
123 | " if isinstance(steps,dp.DataChunk):\n",
124 | " for step in steps:\n",
125 | " yield Record(self.title,self.convert_np(step.image))\n",
126 | " else:\n",
127 | " yield Record(self.title,self.convert_np(steps.image))\n",
128 | " yield steps"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": null,
134 | "id": "1ade24c0",
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "from fastrl.envs.gym import GymDataPipe"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": null,
144 | "id": "6a4b9ca9-2027-40a1-ac97-4516d60479a2",
145 | "metadata": {},
146 | "outputs": [],
147 | "source": [
148 | "#|eval: false\n",
149 | "%matplotlib inline"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": null,
155 | "id": "ead27188-7322-46c2-9300-59842af2386d",
156 | "metadata": {},
157 | "outputs": [],
158 | "source": [
159 | "\n",
160 | "pipe = GymDataPipe(['CartPole-v1'],None,n=100,seed=0,include_images=True)\n",
161 | "pipe = ImageCollector(pipe)\n",
162 | "pipe = SimpleJupyterVideoPlayer(pipe)\n",
163 | "\n",
164 | "for o in pipe: pass"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": null,
170 | "id": "current-pilot",
171 | "metadata": {},
172 | "outputs": [],
173 | "source": [
174 | "#|hide\n",
175 | "#|eval: false\n",
176 | "!nbdev_export"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": null,
182 | "id": "e0d82468-a2bf-4bfd-9ac7-e56db49b8476",
183 | "metadata": {},
184 | "outputs": [],
185 | "source": []
186 | }
187 | ],
188 | "metadata": {
189 | "kernelspec": {
190 | "display_name": "python3",
191 | "language": "python",
192 | "name": "python3"
193 | }
194 | },
195 | "nbformat": 4,
196 | "nbformat_minor": 5
197 | }
198 |
--------------------------------------------------------------------------------
/fastrl/test_utils.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/20_test_utils.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['get_env', 'try_import', 'nvidia_mem', 'nvidia_smi', 'initialize_notebook', 'show_install']
5 |
6 | # %% ../nbs/20_test_utils.ipynb 1
7 | # Python native modules
8 | import os
9 | import re
10 | import sys
11 | import importlib
12 | # Third party libs
13 |
14 | # Local modules
15 |
16 | # %% ../nbs/20_test_utils.ipynb 4
17 | def get_env(name):
18 | "Return env var value if it's defined and not an empty string, or return Unknown"
19 | res = os.environ.get(name,'')
20 | return res if len(res) else "Unknown"
21 |
22 | # %% ../nbs/20_test_utils.ipynb 5
23 | def try_import(module):
24 | "Try to import `module`. Returns module's object on success, None on failure"
25 | try: return importlib.import_module(module)
26 | except: return None
27 |
28 | # %% ../nbs/20_test_utils.ipynb 6
29 | def nvidia_mem():
30 | from fastcore.all import run
31 | try: mem = run("nvidia-smi --query-gpu=memory.total --format=csv,nounits,noheader")
32 | except: return None
33 | return mem.strip().split('\n')
34 |
35 | # %% ../nbs/20_test_utils.ipynb 7
36 | def nvidia_smi(cmd = "nvidia-smi"):
37 | from fastcore.all import run
38 | try: res = run(cmd)
39 | except OSError as e: return None
40 | return res
41 |
42 | # %% ../nbs/20_test_utils.ipynb 8
43 | def initialize_notebook():
44 | """
45 | Function to initialize the notebook environment considering whether it is in Colab or not.
46 | It handles installation of necessary packages and setting up the environment variables.
47 | """
48 |
49 | # Checking if the environment is Google Colab
50 | if os.path.exists("/content"):
51 | # Installing necessary packages
52 | os.system("pip install -Uqq fastrl['dev'] pyvirtualdisplay")
53 | os.system("apt-get install -y xvfb python-opengl > /dev/null 2>&1")
54 |
55 | # Starting a virtual display
56 | from pyvirtualdisplay import Display
57 | display = Display(visible=0, size=(400, 300))
58 | display.start()
59 |
60 | else:
61 | # If not in Colab, importing necessary packages and checking environment variables
62 | from nbdev.showdoc import show_doc
63 | from nbdev.imports import IN_NOTEBOOK, IN_COLAB, IN_IPYTHON
64 |
65 | # Asserting the environment variables
66 | if not os.environ.get("IN_TEST", None):
67 | assert IN_NOTEBOOK
68 | assert not IN_COLAB
69 | assert IN_IPYTHON
70 |
71 |
72 | # %% ../nbs/20_test_utils.ipynb 9
73 | def show_install(show_nvidia_smi:bool=False):
74 | "Print user's setup information"
75 |
76 | # import fastai
77 | import platform
78 | import fastprogress
79 | import fastcore
80 | import fastrl
81 | import torch
82 | from fastcore.all import ifnone
83 |
84 |
85 | rep = []
86 | opt_mods = []
87 |
88 | rep.append(["=== Software ===", None])
89 | rep.append(["python", platform.python_version()])
90 | rep.append(["fastrl", fastrl.__version__])
91 | # rep.append(["fastai", fastai.__version__])
92 | rep.append(["fastcore", fastcore.__version__])
93 | rep.append(["fastprogress", fastprogress.__version__])
94 | rep.append(["torch", torch.__version__])
95 |
96 | # nvidia-smi
97 | smi = nvidia_smi()
98 | if smi:
99 | match = re.findall(r'Driver Version: +(\d+\.\d+)', smi)
100 | if match: rep.append(["nvidia driver", match[0]])
101 |
102 | available = "available" if torch.cuda.is_available() else "**Not available** "
103 | rep.append(["torch cuda", f"{torch.version.cuda} / is {available}"])
104 |
105 | # no point reporting on cudnn if cuda is not available, as it
106 | # seems to be enabled at times even on cpu-only setups
107 | if torch.cuda.is_available():
108 | enabled = "enabled" if torch.backends.cudnn.enabled else "**Not enabled** "
109 | rep.append(["torch cudnn", f"{torch.backends.cudnn.version()} / is {enabled}"])
110 |
111 | rep.append(["\n=== Hardware ===", None])
112 |
113 | gpu_total_mem = []
114 | nvidia_gpu_cnt = 0
115 | if smi:
116 | mem = nvidia_mem()
117 | nvidia_gpu_cnt = len(ifnone(mem, []))
118 |
119 | if nvidia_gpu_cnt: rep.append(["nvidia gpus", nvidia_gpu_cnt])
120 |
121 | torch_gpu_cnt = torch.cuda.device_count()
122 | if torch_gpu_cnt:
123 | rep.append(["torch devices", torch_gpu_cnt])
124 | # information for each gpu
125 | for i in range(torch_gpu_cnt):
126 | rep.append([f" - gpu{i}", (f"{gpu_total_mem[i]}MB | " if gpu_total_mem else "") + torch.cuda.get_device_name(i)])
127 | else:
128 | if nvidia_gpu_cnt:
129 | rep.append([f"Have {nvidia_gpu_cnt} GPU(s), but torch can't use them (check nvidia driver)", None])
130 | else:
131 |             rep.append(["No GPUs available", None])
132 |
133 |
134 | rep.append(["\n=== Environment ===", None])
135 |
136 | rep.append(["platform", platform.platform()])
137 |
138 | if platform.system() == 'Linux':
139 | distro = try_import('distro')
140 | if distro:
141 | # full distro info
142 | rep.append(["distro", ' '.join(distro.linux_distribution())])
143 | else:
144 |             opt_mods.append('distro')
145 | # partial distro info
146 | rep.append(["distro", platform.uname().version])
147 |
148 | rep.append(["conda env", get_env('CONDA_DEFAULT_ENV')])
149 | rep.append(["python", sys.executable])
150 | rep.append(["sys.path", "\n".join(sys.path)])
151 |
152 | print("\n\n```text")
153 |
154 | keylen = max([len(e[0]) for e in rep if e[1] is not None])
155 | for e in rep:
156 | print(f"{e[0]:{keylen}}", (f": {e[1]}" if e[1] is not None else ""))
157 |
158 | if smi:
159 | if show_nvidia_smi: print(f"\n{smi}")
160 | else:
161 | if torch_gpu_cnt: print("no nvidia-smi is found")
162 | else: print("no supported gpus found on this system")
163 |
164 | print("```\n")
165 |
166 | print("Please make sure to include opening/closing ``` when you paste into forums/github to make the reports appear formatted as code sections.\n")
167 |
168 | if opt_mods:
169 | print("Optional package(s) to enhance the diagnostics can be installed with:")
170 | print(f"pip install {' '.join(opt_mods)}")
171 | print("Once installed, re-run this utility to get the additional information")
172 |
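A usage sketch; the report contents naturally depend on the machine it runs on:

``` python
from fastrl.test_utils import show_install

# Prints a fenced text block of software/hardware/environment info suitable for
# pasting into bug reports; pass show_nvidia_smi=True to append the full
# nvidia-smi output when a GPU is present.
show_install()
```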
--------------------------------------------------------------------------------
/nbs/11_FAQ/99_notes.multi_proc.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "durable-dialogue",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "#|hide\n",
11 | "#|eval: false\n",
12 | "! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \\\n",
13 | " apt-get install -y xvfb python-opengl > /dev/null 2>&1 \n",
14 | "# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "id": "viral-cambridge",
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "#|hide\n",
25 | "#|eval: false\n",
26 | "from fastcore.imports import in_colab\n",
27 | "# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to\n",
28 | "if not in_colab():\n",
29 | " from nbdev.showdoc import *\n",
30 | " from nbdev.imports import *\n",
31 | " if not os.environ.get(\"IN_TEST\", None):\n",
32 | " assert IN_NOTEBOOK\n",
33 | " assert not IN_COLAB\n",
34 | " assert IN_IPYTHON\n",
35 | "else:\n",
36 | "    # Virtual display is needed for colab\n",
37 | " from pyvirtualdisplay import Display\n",
38 | " display = Display(visible=0, size=(400, 300))\n",
39 | " display.start()"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "id": "assisted-contract",
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "# Python native modules\n",
50 | "import os\n",
51 | "from copy import deepcopy\n",
52 | "# Third party libs\n",
53 | "from fastcore.all import *\n",
54 | "import numpy as np\n",
55 | "# Local modules\n"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "id": "8681dfae-2d75-4936-8ab3-eb1bf1727312",
61 | "metadata": {},
62 | "source": [
63 | "## MultiProcessing Notes"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "id": "2efbf0be-7a0c-4d7f-b876-df06e410dae7",
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "import torchdata.datapipes as dp"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "id": "139af462-4cee-42a0-9b8c-037f413ecf10",
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "%%writefile ../external_run_scripts/notes_multi_proc_82.py\n",
84 | "import torchdata.datapipes as dp\n",
85 | "from torch.utils.data import IterableDataset\n",
86 | "\n",
87 | "class AddABunch1(dp.iter.IterDataPipe):\n",
88 | " def __init__(self,q):\n",
89 | " super().__init__()\n",
90 | " self.q = [q]\n",
91 | "\n",
92 | " def __iter__(self):\n",
93 | " for o in range(10): \n",
94 | " self.q[0].put(o)\n",
95 | " yield o\n",
96 | " \n",
97 | "class AddABunch2(dp.iter.IterDataPipe):\n",
98 | " def __init__(self,source_datapipe,q):\n",
99 | " super().__init__()\n",
100 | " self.q = q\n",
101 | " print(id(self.q))\n",
102 | " self.source_datapipe = source_datapipe\n",
103 | "\n",
104 | " def __iter__(self):\n",
105 | " for o in self.source_datapipe: \n",
106 | " print(id(self.q))\n",
107 | " self.q.put(o)\n",
108 | " yield o\n",
109 | " \n",
110 | "class AddABunch3(IterableDataset):\n",
111 | " def __init__(self,q):\n",
112 | " self.q = q\n",
113 | "\n",
114 | " def __iter__(self):\n",
115 | " for o in range(10): \n",
116 | " print(id(self.q))\n",
117 | " self.q.put(o)\n",
118 | " yield o\n",
119 | "\n",
120 | "if __name__=='__main__':\n",
121 | " from torch.multiprocessing import Pool,Process,set_start_method,Manager,get_start_method\n",
122 | " import torch\n",
123 | " \n",
124 | " try: set_start_method('spawn')\n",
125 | " except RuntimeError: pass\n",
126 | " # from torch.utils.data.dataloader_experimental import DataLoader2\n",
127 | " from torchdata.dataloader2 import DataLoader2\n",
128 | " from torchdata.dataloader2.reading_service import MultiProcessingReadingService\n",
129 | "\n",
130 | " m = Manager()\n",
131 | " q = m.Queue()\n",
132 | " \n",
133 | " pipe = AddABunch2(list(range(10)),q)\n",
134 | " print(type(pipe))\n",
135 | " dl = DataLoader2(pipe,\n",
136 | " reading_service=MultiProcessingReadingService(num_workers=1)\n",
137 | " ) # Will fail if num_workers>0\n",
138 | " \n",
139 | " # dl = DataLoader2(AddABunch1(q),num_workers=1) # Will fail if num_workers>0\n",
140 | " # dl = DataLoader2(AddABunch2(q),num_workers=1) # Will fail if num_workers>0\n",
141 | " # dl = DataLoader2(AddABunch3(q),num_workers=1) # Will succeed if num_workers>0\n",
142 | " list(dl)\n",
143 | " \n",
144 | " while not q.empty():\n",
145 | " print(q.get())"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": null,
151 | "id": "a3b223ee-40ab-4eea-a7cc-35d659301b14",
152 | "metadata": {},
153 | "outputs": [],
154 | "source": [
155 | "from torch.multiprocessing import Pool,Process,set_start_method,Manager,get_start_method"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": null,
161 | "id": "988390db-b284-494d-a413-2fc12b6fa032",
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "get_start_method()"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": null,
171 | "id": "current-pilot",
172 | "metadata": {},
173 | "outputs": [],
174 | "source": [
175 | "#|hide\n",
176 | "#|eval: false\n",
177 | "from fastcore.imports import in_colab\n",
178 | "\n",
179 | "# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to\n",
180 | "if not in_colab():\n",
181 | " from nbdev import nbdev_export\n",
182 | " nbdev_export()"
183 | ]
184 | }
185 | ],
186 | "metadata": {
187 | "kernelspec": {
188 | "display_name": "python3",
189 | "language": "python",
190 | "name": "python3"
191 | }
192 | },
193 | "nbformat": 4,
194 | "nbformat_minor": 5
195 | }
196 |
--------------------------------------------------------------------------------
/fastrl/learner/core.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/06_Learning/10a_learner.core.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['LearnerBase', 'LearnerHead', 'StepBatcher']
5 |
6 | # %% ../../nbs/06_Learning/10a_learner.core.ipynb 2
7 | # Python native modules
8 | import os
9 | from contextlib import contextmanager
10 | from typing import List,Union,Dict,Optional,Iterable,Tuple
11 | # Third party libs
12 | from fastcore.all import add_docs
13 | import torchdata.datapipes as dp
14 | from torchdata.dataloader2.graph import list_dps
15 | import torch
16 | from torch import nn
17 | from torchdata.dataloader2 import DataLoader2
18 | from torchdata.dataloader2.graph import traverse_dps,DataPipeGraph,DataPipe
19 | # Local modules
20 | from ..torch_core import evaluating
21 | from ..pipes.core import find_dp
22 | from ..loggers.core import Record,EpochCollector,BatchCollector
23 |
24 | # %% ../../nbs/06_Learning/10a_learner.core.ipynb 4
25 | class LearnerBase(dp.iter.IterDataPipe):
26 | def __init__(self,
27 |                  # The base NN that we get raw action values out of.
28 |                  # This can either be a `nn.Module` or a dict of multiple `nn.Module`s
29 |                  # for multi-model training.
30 | model:Union[nn.Module,Dict[str,nn.Module]],
31 | # The dataloaders to read data from for training. This can be a single
32 | # DataLoader2 or an iterable that yields from a DataLoader2.
33 | dls:Union[DataLoader2,Iterable],
34 | # By default for reinforcement learning, we want to keep the workers
35 |                  # alive so that simulations are not shut down and restarted.
36 | # Epochs are expected to be handled semantically via tracking the number
37 | # of batches.
38 | infinite_dls:bool=True
39 | ):
40 | self.model = model
41 | self.iterable = dls
42 | self.learner_base = self
43 | self.infinite_dls = infinite_dls
44 | self._dls = None
45 | self._ended = False
46 |
47 | def __getstate__(self):
48 | state = {k:v for k,v in self.__dict__.items() if k not in ['_dls']}
49 | # TODO: Needs a better way to serialize / deserialize states.
50 | # state['iterable'] = [d.state_dict() for d in state['iterable']]
51 | if dp.iter.IterDataPipe.getstate_hook is not None:
52 | return dp.iter.IterDataPipe.getstate_hook(state)
53 | return state
54 |
55 | def __setstate__(self, state):
56 | # state['iterable'] = [d.from_state_dict() for d in state['iterable']]
57 | for k,v in state.items():
58 | setattr(self,k,v)
59 |
60 | def end(self):
61 | self._ended = True
62 |
63 | def __iter__(self):
64 | self._ended = False
65 | for data in self.iterable:
66 | if self._ended:
67 | break
68 | yield data
69 |
70 | add_docs(
71 | LearnerBase,
72 | "Combines models,dataloaders, and optimizers together for running a training pipeline.",
73 | reset="""If `infinite_dls` is false, then all dls will be reset, otherwise they will be
74 | kept alive.""",
75 | end="When called, will cause the Learner to stop iterating and cleanup."
76 | )
77 |
78 | # %% ../../nbs/06_Learning/10a_learner.core.ipynb 5
79 | class LearnerHead(dp.iter.IterDataPipe):
80 | def __init__(
81 | self,
82 | source_datapipes:Tuple[dp.iter.IterDataPipe,...]
83 | ):
84 | if not isinstance(source_datapipes,tuple):
85 | self.source_datapipes = (source_datapipes,)
86 | else:
87 | self.source_datapipes = source_datapipes
88 | self.dp_idx = 0
89 |
90 | def __iter__(self): yield from self.source_datapipes[self.dp_idx]
91 |
92 | def fit(self,epochs):
93 | self.dp_idx = 0
94 | epocher = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),EpochCollector)
95 | learner = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),LearnerBase)
96 | epocher.epochs = epochs
97 | if isinstance(learner.model,dict):
98 | for m in learner.model.values():
99 | m.train()
100 | else:
101 | learner.model.train()
102 | for _ in self: pass
103 |
104 | def validate(self,epochs=1,batches=100,show=True,return_outputs=False) -> DataPipe:
105 | self.dp_idx = 1
106 | epocher = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),EpochCollector)
107 | epocher.epochs = epochs
108 | batcher = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),BatchCollector)
109 | batcher.batches = batches
110 | learner = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),LearnerBase)
111 | model = learner.model
112 | model = tuple(model.values()) if isinstance(model,dict) else model
113 | with evaluating(model):
114 | if return_outputs:
115 | return list(self)
116 | else:
117 | for _ in self: pass
118 | if show:
119 | pipes = list_dps(traverse_dps(self.source_datapipes[self.dp_idx]))
120 | for pipe in pipes:
121 | if hasattr(pipe,'show'):
122 | return pipe.show()
123 |
124 | add_docs(
125 | LearnerHead,
126 | """LearnerHead can connect to multiple `LearnerBase`s and handles training
127 | and validation execution.
128 | """,
129 | fit="Runs the `LearnerHead` pipeline for `epochs`",
130 | validate="""If there is more than 1 dl, then run 1 epoch of that dl based on
131 | `dl_idx` and returns the original datapipe for displaying."""
132 | )
133 |
134 | # %% ../../nbs/06_Learning/10a_learner.core.ipynb 18
135 | class StepBatcher(dp.iter.IterDataPipe):
136 | def __init__(self,
137 | source_datapipe,
138 | device=None
139 | ):
140 | self.source_datapipe = source_datapipe
141 | self.device = device
142 |
143 | def vstack_by_fld(self,batch,fld):
144 | try:
145 | t = torch.vstack(tuple(getattr(step,fld) for step in batch))
146 | # if self.device is not None:
147 | # t = t.to(torch.device(self.device))
148 | t.requires_grad = False
149 | return t
150 | except RuntimeError as e:
151 | print(f'Failed to stack {fld} given batch: {batch}')
152 | raise
153 |
154 | def __iter__(self):
155 | for batch in self.source_datapipe:
156 | cls = batch[0].__class__
157 | batched_step = cls(**{fld:self.vstack_by_fld(batch,fld) for fld in cls.__dataclass_fields__},batch_size=[len(batch)])
158 | if self.device is not None:
159 | batched_step = batched_step.to(self.device)
160 | yield batched_step
161 |
162 | add_docs(
163 | StepBatcher,
164 | "Converts multiple `StepType` into a single `StepType` with the fields concated.",
165 | vstack_by_fld="vstacks a `fld` in `batch`"
166 | )
167 |
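168 | # Example usage (a sketch; `model`, the dataloaders, and the loss/optimizer pipes are
169 | # hypothetical stand-ins, not defined in this module):
170 | #
171 | #   learner = LearnerBase(model, dls=train_dl)       # footer: owns the model and dls
172 | #   pipe = StepBatcher(learner, device='cpu')        # collate steps into a single batch
173 | #   ...                                              # loss / optimizer pipes + collectors
174 | #   head = LearnerHead((train_pipe, valid_pipe))     # index 0 trains, index 1 validates
175 | #   head.fit(10)
176 | #   head.validate(epochs=1, batches=100)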
--------------------------------------------------------------------------------
/nbs/11_FAQ/99_notes.speed.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "durable-dialogue",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "#|hide\n",
11 | "#|eval: false\n",
12 | "! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \\\n",
13 | " apt-get install -y xvfb python-opengl > /dev/null 2>&1 \n",
14 | "# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "id": "viral-cambridge",
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "#|hide\n",
25 | "#|eval: false\n",
26 | "from fastcore.imports import in_colab\n",
27 | "# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to\n",
28 | "if not in_colab():\n",
29 | " from nbdev.showdoc import *\n",
30 | " from nbdev.imports import *\n",
31 | " if not os.environ.get(\"IN_TEST\", None):\n",
32 | " assert IN_NOTEBOOK\n",
33 | " assert not IN_COLAB\n",
34 | " assert IN_IPYTHON\n",
35 | "else:\n",
36 | " # Virutual display is needed for colab\n",
37 | " from pyvirtualdisplay import Display\n",
38 | " display = Display(visible=0, size=(400, 300))\n",
39 | " display.start()"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "id": "assisted-contract",
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "# Python native modules\n",
50 | "import os\n",
51 | "from copy import deepcopy\n",
52 | "# Third party libs\n",
53 | "from fastcore.all import *\n",
54 | "import numpy as np\n",
55 | "# Local modules\n"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "id": "lesser-innocent",
61 | "metadata": {},
62 | "source": [
63 | "# Speed\n",
64 | "> Some obvious / not so obvious notes on speed"
65 | ]
66 | },
67 | {
68 | "cell_type": "markdown",
69 | "id": "4cce6de1-8643-4338-b65e-60406a35cb3a",
70 | "metadata": {},
71 | "source": [
72 | "## Numpy to Tensor Performance"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "id": "f51e6709-ae53-474b-b975-914bd36159b2",
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "img=np.random.randint(0,255,size=(240, 320, 3))"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": null,
88 | "id": "202749e2-72dd-4baa-b5e5-8217f297c0a3",
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "%%timeit\n",
93 | "#|eval: false\n",
94 | "img=np.random.randint(0,255,size=(240, 320, 3))"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "id": "00356113-ef79-42aa-a746-9eca3ffe65bc",
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "%%timeit\n",
105 | "#|eval: false\n",
106 | "deepcopy(img)"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "id": "45a4e034-9a2b-4e4a-b1b9-71d23b12155b",
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "%%timeit\n",
117 | "#|eval: false\n",
118 | "Tensor(img)"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": null,
124 | "id": "1acadadb-94e3-4be6-841a-2045b1da8767",
125 | "metadata": {},
126 | "outputs": [],
127 | "source": [
128 | "%%timeit\n",
129 | "#|eval: false\n",
130 | "Tensor([img])"
131 | ]
132 | },
133 | {
134 | "cell_type": "markdown",
135 | "id": "9d5040b2-6670-495b-abf7-fe5a26c9eda0",
136 | "metadata": {},
137 | "source": [
138 | "You will notice that if you wrap a numpy in a list, it completely kills the performance. The solution is to\n",
139 | "just add a batch dim to the existing array and pass it directly."
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": null,
145 | "id": "f1007f5b-db0d-4152-b474-e6e319bd81e5",
146 | "metadata": {},
147 | "outputs": [],
148 | "source": [
149 | "%%timeit\n",
150 | "#|eval: false\n",
151 | "Tensor(np.expand_dims(img,0))"
152 | ]
153 | },
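154 | {
155 | "cell_type": "markdown",
156 | "id": "from-numpy-note",
157 | "metadata": {},
158 | "source": [
159 | "An even cheaper alternative (a quick sketch, not part of the original comparison): `torch.from_numpy`\n",
160 | "shares memory with the source array instead of copying, and indexing with `None` adds the batch dim\n",
161 | "without a copy. Note that unlike `Tensor(...)`, which casts to float32, `from_numpy` preserves the\n",
162 | "numpy dtype."
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": null,
168 | "id": "from-numpy-timeit",
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "%%timeit\n",
173 | "#|eval: false\n",
174 | "torch.from_numpy(img)[None]"
175 | ]
176 | },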
154 | {
155 | "cell_type": "markdown",
156 | "id": "ac025532-f917-4ded-9ec7-903373712c16",
157 | "metadata": {},
158 | "source": [
159 | "In fact we can just test this with python lists..."
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": null,
165 | "id": "0a03815e-205e-48ee-8d37-be8ffa2f7a20",
166 | "metadata": {},
167 | "outputs": [],
168 | "source": [
169 | "%%timeit\n",
170 | "#|eval: false\n",
171 | "Tensor([[1]])"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": null,
177 | "id": "e8001456-bb86-44c2-b68f-59c4e88018bd",
178 | "metadata": {},
179 | "outputs": [],
180 | "source": [
181 | "test_arr=[[1]*270000]"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": null,
187 | "id": "3ba7e414-d53f-4865-832c-c16aa098963d",
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "%%timeit\n",
192 | "#|eval: false\n",
193 | "Tensor(test_arr)"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": null,
199 | "id": "dc756fe2-c890-4736-9e2a-426e9d646b3c",
200 | "metadata": {},
201 | "outputs": [],
202 | "source": [
203 | "test_arr=np.array([[1]*270000])"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "id": "0627b489-d4b7-475c-bf73-de263ea5c5e4",
210 | "metadata": {},
211 | "outputs": [],
212 | "source": [
213 | "%%timeit\n",
214 | "#|eval: false\n",
215 | "Tensor(test_arr)"
216 | ]
217 | },
218 | {
219 | "cell_type": "markdown",
220 | "id": "c43811a0-ca7d-4d20-b4a8-51aaccfe1b6d",
221 | "metadata": {},
222 | "source": [
223 | "This is horrifying just how made of a performance hit this causes... So we will be avoiding python list inputs \n",
224 | "to Tensors for now on..."
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": null,
230 | "id": "current-pilot",
231 | "metadata": {},
232 | "outputs": [],
233 | "source": [
234 | "#|hide\n",
235 | "#|eval: false\n",
236 | "from fastcore.imports import in_colab\n",
237 | "\n",
238 | "# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to\n",
239 | "if not in_colab():\n",
240 | " from nbdev import nbdev_export\n",
241 | " nbdev_export()"
242 | ]
243 | }
244 | ],
245 | "metadata": {
246 | "kernelspec": {
247 | "display_name": "python3",
248 | "language": "python",
249 | "name": "python3"
250 | }
251 | },
252 | "nbformat": 4,
253 | "nbformat_minor": 5
254 | }
255 |
--------------------------------------------------------------------------------
/nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "durable-dialogue",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "#|hide\n",
11 | "from fastrl.test_utils import initialize_notebook\n",
12 | "initialize_notebook()"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "id": "assisted-contract",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "#|export\n",
23 | "# Python native modules\n",
24 | "import os\n",
25 | "import io\n",
26 | "from typing import Tuple,Any,Optional,NamedTuple,Iterable\n",
27 | "# Third party libs\n",
28 | "import imageio\n",
29 | "from fastcore.all import add_docs,ifnone\n",
30 | "import matplotlib.pyplot as plt\n",
31 | "import torchdata.datapipes as dp\n",
32 | "from torchdata.datapipes import functional_datapipe\n",
33 | "from IPython.core.display import Video,Image\n",
34 | "from torchdata.dataloader2 import DataLoader2,MultiProcessingReadingService\n",
35 | "# Local modules\n",
36 | "from fastrl.loggers.core import LoggerBase,is_record\n",
37 | "from fastrl.pipes.core import DataPipeAugmentationFn,apply_dp_augmentation_fns\n",
38 | "from fastrl.loggers.jupyter_visualizers import ImageCollector"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "id": "offshore-stuart",
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "#|default_exp loggers.vscode_visualizers"
49 | ]
50 | },
51 | {
52 | "attachments": {},
53 | "cell_type": "markdown",
54 | "id": "lesser-innocent",
55 | "metadata": {},
56 | "source": [
57 | "# Visualizers - VS-Code\n",
58 | "> Iterable pipes for displaying environments as they run using `typing.NamedTuples` with `image` fields for VS-Code\n",
59 | "\n",
60 | "`fastrl.jupyter_visualizers` can be used in vscode, however you likely will notice flickering for video\n",
61 | "based outputs. For vscode, we can generate a gif instead."
62 | ]
63 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": null,
77 | "id": "9cb488d2-41f1-4160-a938-e39003f1a06a",
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "#|export\n",
82 | "class SimpleVSCodeVideoPlayer(dp.iter.IterDataPipe):\n",
83 | " def __init__(self, \n",
84 | " source_datapipe=None, \n",
85 | " skip_frames:int=1,\n",
86 | " fps:int=30,\n",
87 | " downsize_res=(2,2)\n",
88 | " ):\n",
89 | " self.source_datapipe = source_datapipe\n",
90 | " self.fps = fps\n",
91 | " self.skip_frames = skip_frames\n",
92 | " self.downsize_res = downsize_res\n",
93 | " self._bytes_object = None\n",
94 | " self.frames = [] \n",
95 | "\n",
96 | " def reset(self):\n",
97 | " super().reset()\n",
98 | " self._bytes_object = io.BytesIO()\n",
99 | "\n",
100 | " def show(self,start:int=0,end:Optional[int]=None,step:int=1):\n",
101 | " print(f'Creating gif from {len(self.frames)} frames')\n",
102 | " imageio.mimwrite(\n",
103 | " self._bytes_object,\n",
104 | " self.frames[start:end:step],\n",
105 | " format='GIF',\n",
106 | " duration=self.fps\n",
107 | " )\n",
108 | " return Image(self._bytes_object.getvalue())\n",
109 | " \n",
110 | " def __iter__(self) -> Tuple[NamedTuple]:\n",
111 | " n_frame = 0\n",
112 | " for record in self.source_datapipe:\n",
113 | " # for o in self.dequeue():\n",
114 | " if is_record(record) and record.name=='image':\n",
115 | " if record.value is None: continue\n",
116 | " n_frame += 1\n",
117 | " if n_frame%self.skip_frames!=0: continue\n",
118 | " self.frames.append(\n",
119 | " record.value[::self.downsize_res[0],::self.downsize_res[1]]\n",
120 | " )\n",
121 | " yield record\n",
122 | "add_docs(\n",
123 | "SimpleVSCodeVideoPlayer,\n",
124 | "\"\"\"Displays video from a `source_datapipe` that produces `typing.NamedTuples` that contain an `image` field.\n",
125 | "This only can handle 1 env input.\"\"\",\n",
126 | "show=\"In order to show the video, this must be called in a notebook cell.\",\n",
127 | "reset=\"Will reset the bytes object that is used to store file data.\"\n",
128 | ")"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": null,
134 | "id": "d486d06a",
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "#|export\n",
139 | "@functional_datapipe('visualize_vscode')\n",
140 | "class VSCodeDataPipe(dp.iter.IterDataPipe):\n",
141 | " def __new__(self,source:Iterable):\n",
142 | " \"This is the function that is actually run by `DataBlock`\"\n",
143 | " pipe = ImageCollector(source).dump_records()\n",
144 | " pipe = SimpleVSCodeVideoPlayer(pipe)\n",
145 | " return pipe \n",
146 | " "
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": null,
152 | "id": "55130404",
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 | "from fastrl.envs.gym import GymDataPipe"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": null,
162 | "id": "8b58a84d",
163 | "metadata": {},
164 | "outputs": [],
165 | "source": [
166 | "#|hide\n",
167 | "pipe = GymDataPipe(['CartPole-v1'],None,n=100,seed=0,include_images=True).visualize_vscode()\n",
168 | "\n",
169 | "list(pipe);\n",
170 | "pipe.show()"
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": null,
176 | "id": "ead27188-7322-46c2-9300-59842af2386d",
177 | "metadata": {},
178 | "outputs": [],
179 | "source": [
180 | "#|hide\n",
181 | "pipe = GymDataPipe(['CartPole-v1'],None,n=100,seed=0,include_images=True)\n",
182 | "pipe = VSCodeDataPipe(pipe)\n",
183 | "\n",
184 | "list(pipe);\n",
185 | "pipe.show()"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": null,
191 | "id": "current-pilot",
192 | "metadata": {},
193 | "outputs": [],
194 | "source": [
195 | "#|hide\n",
196 | "#|eval: false\n",
197 | "!nbdev_export"
198 | ]
199 | }
208 | ],
209 | "metadata": {
210 | "kernelspec": {
211 | "display_name": "python3",
212 | "language": "python",
213 | "name": "python3"
214 | }
215 | },
216 | "nbformat": 4,
217 | "nbformat_minor": 5
218 | }
219 |
--------------------------------------------------------------------------------
/fastrl/memory/memory_visualizer.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/04_Memory/01_memory_visualizer.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['MemoryBufferViewer']
5 |
6 | # %% ../../nbs/04_Memory/01_memory_visualizer.ipynb 2
7 | # Python native modules
8 | import io
9 | from typing import List
10 | # Third party libs
11 | from PIL import Image
12 | from ipywidgets import Button, HBox, VBox, Output, IntText, Label
13 | from ipywidgets import widgets
14 | from IPython.display import display
15 | import numpy as np
16 | import torch
17 | # Local modules
18 | from ..core import StepTypes
19 |
20 | # %% ../../nbs/04_Memory/01_memory_visualizer.ipynb 4
21 | class MemoryBufferViewer:
22 | def __init__(self, memory:List[StepTypes.types], agent=None, ignore_image:bool=False):
23 | # Assuming memory contains SimpleStep instances or None
24 | self.memory = memory
25 | self.agent = agent
26 | self.current_index = 0
27 | self.ignore_image = ignore_image
28 | # Add a label for displaying the number of elements in memory
29 | self.memory_size_label = Label(value=f"Number of Elements in Memory: {len([x for x in memory if x is not None])}")
30 |
31 | # Create the widgets
32 | self.out = Output()
33 | self.next_button = Button(description="Next")
34 | self.prev_button = Button(description="Previous")
35 | self.goto_text = IntText(value=0, description='Index:')
36 | # Button to jump to the desired index
37 | self.goto_button = Button(description="Go")
38 | self.goto_button.on_click(self.goto_step)
39 | self.action_value_label = Label()
40 |
41 | # Setup event handlers
42 | self.next_button.on_click(self.next_step)
43 | self.prev_button.on_click(self.prev_step)
44 | self.manual_navigation = False
45 | self.goto_text.observe(self.jump_to_index, names='value')
46 |
47 | # Display the widgets
48 | # Update the widget layout
49 | self.display_content_placeholder = VBox([])
50 | self.layout = VBox([
51 | self.memory_size_label,
52 | HBox([self.prev_button, self.next_button, self.goto_text, self.goto_button]),
53 | self.action_value_label,
54 | self.out,
55 | self.display_content_placeholder
56 | ])
57 | self.show_current()
58 | display(self.layout)
59 |
60 | def jump_to_index(self, change):
61 | if not self.manual_navigation:
62 | idx = change['new']
63 | if 0 <= idx < len(self.memory):
64 | self.current_index = idx
65 | self.show_current()
66 | else:
67 | self.manual_navigation = False
68 |
69 | def next_step(self, change):
70 | self._toggle_buttons_state(False) # Disable buttons
71 | self.current_index = min(len(self.memory) - 1, self.current_index + 1)
72 | self.manual_navigation = True
73 | self.goto_text.value = self.current_index
74 | self.show_current()
75 | self._toggle_buttons_state(True) # Enable buttons
76 |
77 | def prev_step(self, change):
78 | self._toggle_buttons_state(False) # Disable buttons
79 | self.current_index = max(0, self.current_index - 1)
80 | self.manual_navigation = True
81 | self.goto_text.value = self.current_index
82 | self.show_current()
83 | self._toggle_buttons_state(True) # Enable buttons
84 |
85 | def goto_step(self, change):
86 | self._toggle_buttons_state(False) # Disable buttons
87 | target_idx = self.goto_text.value
88 | if 0 <= target_idx < len(self.memory):
89 | self.current_index = target_idx
90 | self.show_current()
91 | self._toggle_buttons_state(True) # Enable buttons
92 |
93 | def _toggle_buttons_state(self, state):
94 | """Helper function to toggle button states."""
95 | self.prev_button.disabled = not state
96 | self.next_button.disabled = not state
97 | self.goto_button.disabled = not state
98 |
99 | def tensor_to_pil(self, tensor_image):
100 | """Convert a tensor to a PIL Image."""
101 | # Convert the tensor to numpy
102 | img_np = tensor_image.numpy()
103 |
104 | # Check if the tensor was in C, H, W format and convert it to H, W, C for PIL
105 | if img_np.ndim == 3 and img_np.shape[2] != 3:
106 | img_np = np.transpose(img_np, (1, 2, 0))
107 |
108 | # Make sure the data type is right
109 | if img_np.dtype in (np.float32,np.float64):
110 | img_np = (img_np * 255).astype(np.uint8)
111 |
112 | return Image.fromarray(img_np)
113 |
114 | def pil_image_to_byte_array(self, pil_image):
115 | """Helper function to convert PIL image to byte array."""
116 | buffer = io.BytesIO()
117 | pil_image.save(buffer, format='JPEG')
118 | return buffer.getvalue()
119 |
120 | def show_current(self):
121 | self.out.clear_output(wait=True)
122 | with self.out:
123 | if self.memory[self.current_index] is not None:
124 | step = self.memory[self.current_index]
125 |
126 | # Prepare the right-side content (step details)
127 | details_list = []
128 | details_list.append(Label(f"Action Value: {step.action.item()}"))
129 | # If agent is provided, predict the action based on step.state
130 | if self.agent is not None:
131 | with torch.no_grad():
132 | for predicted_action in self.agent([step]):pass
133 | details_list.append(Label(f"Agent Predicted Action: {predicted_action}"))
134 |
135 | for field, value in step.to_tensordict().items():
136 | if field not in ['state', 'next_state', 'image']:
137 | details_list.append(Label(f"{field.capitalize()}: {value}"))
138 |
139 | details_display = VBox(details_list)
140 |
141 | # If the image is present, prepare left-side content
142 | if torch.is_tensor(step.image) and step.image.nelement() > 1 and not self.ignore_image:
143 | pil_image = self.tensor_to_pil(step.image)
144 | img_display = widgets.Image(value=self.pil_image_to_byte_array(pil_image), format='jpeg')
145 | display_content = HBox([img_display, details_display])
146 | else:
147 | # If no image is present, use the entire space for details,
148 | # displaying 'state' and 'next_state' along with the other
149 | # step fields.
150 | state_label = Label(f"State: {step.state}")
151 | next_state_label = Label(f"Next State: {step.next_state}")
152 | display_content = VBox([details_display, state_label, next_state_label])
153 |
154 | self.display_content_placeholder.children = [display_content]
155 | else:
156 | print(f"Step {self.current_index}: Empty")
157 | self.action_value_label.value = ""
158 |
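159 | # Example usage in a notebook (a sketch; `memory` and `agent` are hypothetical stand-ins,
160 | # e.g. an experience replay buffer and an agent pipeline ending in an `AgentHead`):
161 | #
162 | #   MemoryBufferViewer(memory, agent=agent)  # renders the browsing widget on construction
163 | #
164 | # Use Previous/Next or the Index box + Go to move between steps; when `agent` is given,
165 | # the agent's predicted action for the displayed state is shown next to the stored one.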
--------------------------------------------------------------------------------
/fastrl/cli.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/19_cli.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['fastrl_make_requirements', 'proc_nbs', 'fastrl_nbdev_docs', 'fastrl_nbdev_test']
5 |
6 | # %% ../nbs/19_cli.ipynb 3
7 | # Python native modules
8 | import os
9 | import shutil
10 | # Third party libs
11 | from fastcore.all import *
12 | # Local modules
13 | from nbdev.quarto import _nbglob_docs,_sprun,_pre_docs,nbdev_readme,move,proc_nbs
14 | from nbdev.test import test_nb,_keep_file
15 |
16 | # %% ../nbs/19_cli.ipynb 6
17 | @call_parse
18 | def fastrl_make_requirements(
19 | path:Path=None, # The path to a dir with the settings.ini, if none, cwd.
20 | project_file:str='settings.ini', # The file to load for reading the requirements
21 | out_path:Path=None, # The output path (can be relative to `path`)
22 | verbose:bool=False # Output to stdout
23 | ):
24 | requirement_types = ['','dev_','pip_']
25 | path = ifnone(path, Path.cwd())/project_file
26 |
27 | if not path.exists(): raise OSError(f'File {path} does not exist')
28 |
29 | out_path = ifnone(out_path, Path('extra'))
30 | out_path = out_path if out_path.is_absolute() else path.parent/out_path
31 | out_path.mkdir(parents=True, exist_ok=True)
32 | if verbose: print('Outputting to path: ',out_path)
33 | config = Config(path.parent,path.name)
34 |
35 | for req in requirement_types:
36 | requirements = config[req+'requirements']
37 | requirements = requirements.replace(' ','\n')
38 | Path(out_path/(req+'requirements.txt')).write_text(requirements)
39 |
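40 | # Example invocation (a sketch; assumes the `fastrl_make_requirements` console script has
41 | # been installed from settings.ini by nbdev):
42 | #
43 | #   fastrl_make_requirements --verbose
44 | #
45 | # This writes requirements.txt, dev_requirements.txt, and pip_requirements.txt into ./extra,
46 | # one file per `requirements`, `dev_requirements`, and `pip_requirements` key in settings.ini.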
40 | # %% ../nbs/19_cli.ipynb 7
41 | from nbdev.config import *
42 | from nbdev.doclinks import *
43 |
44 | from fastcore.utils import *
45 | from fastcore.script import call_parse
46 | from fastcore.shutil import rmtree,move,copytree
47 | from fastcore.meta import delegates
48 | from nbdev.serve import proc_nbs,_proc_file
49 | from nbdev import serve_drv
50 | from nbdev.quarto import _ensure_quarto
51 | from nbdev.quarto import *
52 | import nbdev
53 |
54 | # %% ../nbs/19_cli.ipynb 8
55 | @call_parse
56 | @delegates(nbglob_cli)
57 | def proc_nbs(
58 | path:str='', # Path to notebooks
59 | n_workers:int=defaults.cpus, # Number of workers
60 | force:bool=False, # Ignore cache and build all
61 | file_glob:str='', # Only include files matching glob
62 | verbose:bool=False, # verbose outputs
63 | one2one:bool=True, # Run 1 notebook per process instance.
64 | **kwargs):
65 | "Process notebooks in `path` for docs rendering"
66 | cfg = get_config()
67 | cache = cfg.config_path/'_proc'
68 | path = Path(path or cfg.nbs_path)
69 | files = nbglob(path, func=Path, file_glob=file_glob, **kwargs)
70 | if (path/'_quarto.yml').exists(): files.append(path/'_quarto.yml')
71 |
72 | # If settings.ini or filter script newer than cache folder modified, delete cache
73 | chk_mtime = max(cfg.config_file.stat().st_mtime, Path(__file__).stat().st_mtime)
74 | cache.mkdir(parents=True, exist_ok=True)
75 | cache_mtime = cache.stat().st_mtime
76 | if force or (cache.exists() and cache_mtime<chk_mtime): rmtree(cache)
--------------------------------------------------------------------------------
/fastrl/agents/discrete.py:
--------------------------------------------------------------------------------
33 | def __iter__(self) -> torch.LongTensor:
34 | for step in self.source_datapipe:
35 | if not issubclass(step.__class__,torch.Tensor):
36 | raise Exception(f'Expected Tensor to take the argmax, got {type(step)}\n{step}')
37 | # Might want to support simple tuples also depending on if we are processing multiple fields.
38 | idx = torch.argmax(step,axis=self.axis).reshape(-1,1)
39 | if self.only_idx:
40 | yield idx.long()
41 | continue
42 | step[:] = 0
43 | step.scatter_(1,idx,1)
44 | yield step.long()
45 |
46 |
47 | # %% ../../nbs/07_Agents/01_Discrete/12b_agents.discrete.ipynb 7
48 | class EpsilonSelector(dp.iter.IterDataPipe):
49 | debug=False
50 | "Given input `Tensor` from `source_datapipe`."
51 | def __init__(self,
52 | source_datapipe, # a datapipe whose next(source_datapipe) -> `Tensor`
53 | min_epsilon:float=0.2, # The minimum epsilon to drop to
54 | # The max/starting epsilon if `epsilon` is None; used for calculating the epsilon decrease speed.
55 | max_epsilon:float=1,
56 | # Determines how fast the epsilon should drop to `min_epsilon`. This should be the number
57 | # of steps that the agent is expected to run through.
58 | max_steps:int=100,
59 | # The starting epsilon
60 | epsilon:float=None,
61 | # Based on the `base_agent.model.training`, by default no decrement or step tracking will
62 | # occur during validation steps.
63 | decrement_on_val:bool=False,
64 | # Based on the `base_agent.model.training`, by default random actions will not be attempted during validation steps.
65 | select_on_val:bool=False,
66 | # Also return the mask where, if True, the action should be randomly selected.
67 | ret_mask:bool=False,
68 | # The device to create the masks on.
69 | device='cpu'
70 | ):
71 | self.source_datapipe = source_datapipe
72 | self.min_epsilon = min_epsilon
73 | self.max_epsilon = max_epsilon
74 | self.max_steps = max_steps
75 | self.epsilon = epsilon
76 | self.decrement_on_val = decrement_on_val
77 | self.select_on_val = select_on_val
78 | self.ret_mask = ret_mask
79 | self.agent_base = find_dp(traverse_dps(self.source_datapipe),AgentBase)
80 | self.step = 0
81 | self.device = torch.device(device)
82 |
83 | def __iter__(self):
84 | for action in self.source_datapipe:
85 | # TODO: Support tuples of actions also
86 | if not issubclass(action.__class__,torch.Tensor):
87 | raise Exception(f'Expected Tensor, got {type(action)}\n{action}')
88 | if action.dtype!=torch.int64:
89 | raise ValueError(f'Expected Tensor of dtype int64, got: {action.dtype} from {self.source_datapipe}')
90 |
91 | if self.agent_base.model.training or self.decrement_on_val:
92 | self.step+=1
93 |
94 | self.epsilon = max(self.min_epsilon,self.max_epsilon-self.step/self.max_steps)
95 | # Add a batch dim if missing
96 | if len(action.shape)==1: action.unsqueeze_(0)
97 | mask = None
98 | if self.agent_base.model.training or self.select_on_val:
99 | # Given N(action.shape[0]) actions, select the ones we want to randomly assign...
100 | mask = torch.rand(action.shape[0],).to(self.device)<self.epsilon
142 | def __iter__(self) -> Union[float,bool,int]:
143 | for step in self.source_datapipe:
144 | if not issubclass(step.__class__,np.ndarray):
145 | raise Exception(f'Expected np.ndarray to convert to a python primitive, got {type(step)}\n{step}')
146 |
147 | if len(step)>1 or len(step)==0:
148 | raise Exception(f'`step` from {self.source_datapipe} needs to be len 1, not {len(step)}')
149 | else:
150 | step = step[0]
151 |
152 | if np.issubdtype(step.dtype,np.integer):
153 | yield int(step)
154 | elif np.issubdtype(step.dtype,np.floating):
155 | yield float(step)
156 | elif np.issubdtype(step.dtype,np.bool_):
157 | yield bool(step)
158 | else:
159 | raise Exception(f'`step` from {self.source_datapipe} must be one of the 3 python types: bool,int,float, not {step.dtype}')
160 |
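161 | # Worked example of the epsilon schedule in `EpsilonSelector` above (a sketch using the
162 | # defaults max_epsilon=1, min_epsilon=0.2, max_steps=100):
163 | #
164 | #   eps = lambda step: max(0.2, 1 - step/100)
165 | #   [eps(s) for s in (0, 50, 80, 100)]  # -> [1.0, 0.5, 0.2, 0.2]
166 | #
167 | # Actions start fully random and settle at a 20% random rate from step 80 onward.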
--------------------------------------------------------------------------------
/fastrl/agents/core.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/07_Agents/12a_agents.core.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['shared_model_dict', 'share_model', 'get_shared_model', 'AgentBase', 'AgentHead', 'SimpleModelRunner',
5 | 'StepFieldSelector', 'NumpyConverter']
6 |
7 | # %% ../../nbs/07_Agents/12a_agents.core.ipynb 2
8 | # Python native modules
9 | import os
10 | from typing import List,Optional
11 | # Third party libs
12 | from fastcore.all import add_docs,ifnone
13 | import torchdata.datapipes as dp
14 | import torch
15 | from torch import nn
16 | from torchdata.dataloader2.graph import traverse_dps
17 | import torch.multiprocessing as mp
18 | # Local modules
19 | from ..core import StepTypes,SimpleStep
20 | from ..torch_core import evaluating,Module
21 | from ..pipes.core import find_dps,find_dp
22 |
23 | # %% ../../nbs/07_Agents/12a_agents.core.ipynb 4
24 | # Create a manager for shared objects
25 | # manager = mp.Manager()
26 | # shared_model_dict = manager.dict()
27 | shared_model_dict = {}
28 |
29 | def share_model(model: nn.Module, name="default"):
30 | """Move model's parameters to shared memory and store in manager's dictionary."""
31 | # TODO(josiahls): This will not survive multiprocessing. We will need to use something
32 | # like ray to better sync models.
33 | model.share_memory()
34 | shared_model_dict[name] = model
35 |
36 | def get_shared_model(name="default"):
37 | """Retrieve model from shared memory using the manager's dictionary."""
38 | return shared_model_dict[name]
39 |
40 | class AgentBase(dp.iter.IterDataPipe):
41 | def __init__(self,
42 | model:Optional[nn.Module], # The base NN that we are getting raw action values out of.
43 | action_iterator:list=None, # A reference to an iterator that contains actions to process.
44 | logger_bases=None
45 | ):
46 | self.model = model
47 | self.iterable = ifnone(action_iterator,[])
48 | self.agent_base = self
49 | self.logger_bases = logger_bases
50 | self._mem_name = 'agent_model'
51 |
52 | def to(self,*args,**kwargs):
53 | if self.model is not None:
54 | self.model.to(*args,**kwargs)
55 |
56 | def __iter__(self):
57 | while self.iterable:
58 | yield self.iterable.pop(0)
59 |
60 | def __getstate__(self):
61 | if self.model is not None:
62 | share_model(self.model,self._mem_name)
63 | # Store the non-model state
64 | state = self.__dict__.copy()
65 | return state
66 |
67 | def __setstate__(self, state):
68 | self.__dict__.update(state)
69 | # Assume a globally shared model instance or a reference method to retrieve it
70 | if self.model is not None:
71 | self.model = get_shared_model(self._mem_name)
72 |
73 | add_docs(
74 | AgentBase,
75 | """Acts as the footer of the Agent pipeline.
76 | Maintains important state such as the `model` being used to get actions from.
77 | Also optionally allows passing a reference list of `action_iterator` which is a
78 | persistent list of actions for the entire agent pipeline to process through.
79 |
80 | > Important: Must be at the start of the pipeline, and be used with AgentHead at the end.
81 |
82 | > Important: `action_iterator` is stored in the `iterable` field. However the recommended
83 | way of passing actions to the pipeline is to call an `AgentHead` instance.
84 | """,
85 | to=torch.Tensor.to.__doc__
86 | )
87 |
88 | # %% ../../nbs/07_Agents/12a_agents.core.ipynb 5
89 | class AgentHead(dp.iter.IterDataPipe):
90 | def __init__(self,source_datapipe):
91 | self.source_datapipe = source_datapipe
92 | self.agent_base = find_dp(traverse_dps(self.source_datapipe),AgentBase)
93 |
94 | def __call__(self,steps:list):
95 | if issubclass(steps.__class__,StepTypes.types):
96 | raise Exception(f'Expected List[{StepTypes.types}] object got {type(steps)}\n{steps}')
97 | self.agent_base.iterable.extend(steps)
98 | return self
99 |
100 | def __iter__(self): yield from self.source_datapipe
101 |
102 | def augment_actions(self,actions): return actions
103 |
104 | def create_step(self,**kwargs): return SimpleStep(**kwargs,batch_size=[])
105 |
106 | add_docs(
107 | AgentHead,
108 | """Acts as the head of the Agent pipeline.
109 | Used for conveniently adding actions to the pipeline to process.
110 |
111 | > Important: Must be paired with `AgentBase`
112 | """,
113 | augment_actions="""Called right before being fed into the env.
114 |
115 | > Important: The results of this function will not be kept / used in the step or forwarded to
116 | any training code.
117 |
118 | There are cases where either the entire action shouldn't be fed into the env,
119 | or the version of the action that we want to train on wouldn't be compatible with the env.
120 |
121 | This is also useful if we want to train on the original raw values of the action prior to argmax being run on it for example.
122 | """,
123 | create_step="Creates the step used by the env for running, and used by the model for training."
124 | )
125 |
126 | # %% ../../nbs/07_Agents/12a_agents.core.ipynb 6
127 | class SimpleModelRunner(dp.iter.IterDataPipe):
128 | "Takes input from `source_datapipe` and pushes through the agent bases model assuming there is only one model field."
129 | def __init__(self,
130 | source_datapipe
131 | ):
132 | self.source_datapipe = source_datapipe
133 | self.agent_base = find_dp(traverse_dps(self.source_datapipe),AgentBase)
134 | self.device = None
135 |
136 | def to(self,*args,**kwargs):
137 | if 'device' in kwargs: self.device = kwargs.get('device',None)
138 | return self
139 |
140 | def __iter__(self):
141 | for x in self.source_datapipe:
142 | if self.device is not None: x = x.to(self.device)
143 | if len(x.shape)==1: x = x.unsqueeze(0)
144 | with torch.no_grad():
145 | with evaluating(self.agent_base.model):
146 | res = self.agent_base.model(x)
147 | yield res
148 |
149 | # %% ../../nbs/07_Agents/12a_agents.core.ipynb 12
150 | class StepFieldSelector(dp.iter.IterDataPipe):
151 | "Grabs `field` from `source_datapipe` to push to the rest of the pipeline."
152 | def __init__(self,
153 | source_datapipe, # datapipe whose next(source_datapipe) -> `StepTypes`
154 | field='state' # A field in `StepTypes` to grab
155 | ):
156 | # TODO: support multi-fields
157 | self.source_datapipe = source_datapipe
158 | self.field = field
159 |
160 | def __iter__(self):
161 | for step in self.source_datapipe:
162 | if not issubclass(step.__class__,StepTypes.types):
163 | raise Exception(f'Expected typing.NamedTuple object got {type(step)}\n{step}')
164 | yield getattr(step,self.field)
165 |
166 | # %% ../../nbs/07_Agents/12a_agents.core.ipynb 22
167 | class NumpyConverter(dp.iter.IterDataPipe):
168 | debug=False
169 |
170 | def __init__(self,source_datapipe):
171 | self.source_datapipe = source_datapipe
172 |
173 | def debug_display(self,step):
174 | print(f'Step: {step}')
175 |
176 | def __iter__(self) -> torch.LongTensor:
177 | for step in self.source_datapipe:
178 | if not issubclass(step.__class__,torch.Tensor):
179 | raise Exception(f'Expected Tensor to convert to numpy, got {type(step)}\n{step}')
180 | if self.debug: self.debug_display(step)
181 | yield step.detach().cpu().numpy()
182 |
183 | add_docs(
184 | NumpyConverter,
185 | """Given input `Tensor` from `source_datapipe` returns a numpy array of same shape with argmax set to 1.""",
186 | debug_display="Display the step being processed"
187 | )
188 |
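189 | # Example usage (a sketch; `DQN` and `step` are hypothetical stand-ins, not defined here):
190 | #
191 | #   agent = AgentBase(DQN())                         # footer: owns the model
192 | #   agent = StepFieldSelector(agent, field='state')  # pull the state out of each step
193 | #   agent = SimpleModelRunner(agent)                 # state -> raw action values
194 | #   agent = AgentHead(agent)                         # head: feed steps in by calling it
195 | #   for raw_action_values in agent([step]):
196 | #       ...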
--------------------------------------------------------------------------------