├── fastrl ├── envs │ ├── __init__.py │ ├── debug_env.py │ └── continuous_debug_env.py ├── agents │ ├── __init__.py │ ├── dqn │ │ ├── __init__.py │ │ ├── dueling.py │ │ ├── rainbow.py │ │ ├── double.py │ │ └── target.py │ ├── discrete.py │ └── core.py ├── funcs │ ├── __init__.py │ └── conjugation.py ├── learner │ ├── __init__.py │ └── core.py ├── loggers │ ├── __init__.py │ ├── tensorboard.py │ ├── jupyter_visualizers.py │ └── vscode_visualizers.py ├── memory │ ├── __init__.py │ ├── experience_replay.py │ └── memory_visualizer.py ├── pipes │ ├── __init__.py │ ├── map │ │ ├── __init__.py │ │ ├── transforms.py │ │ ├── mux.py │ │ └── demux.py │ ├── iter │ │ ├── __init__.py │ │ ├── firstlast.py │ │ ├── nskip.py │ │ ├── nstep.py │ │ └── cacheholder.py │ └── core.py ├── dataloading │ ├── __init__.py │ └── core.py ├── __init__.py ├── nbdev_extensions.py ├── test_utils.py └── cli.py ├── extra ├── requirements.txt ├── pip_requirements.txt └── dev_requirements.txt ├── nbs ├── .gitignore ├── 11_FAQ │ ├── header_dates.json │ ├── 99_template.ipynb │ ├── 99_notes.multi_proc.ipynb │ └── 99_notes.speed.ipynb ├── images │ ├── cat.jpg │ ├── half.png │ ├── mnist3.png │ ├── puppy.jpg │ ├── sample.dcm │ ├── ulmfit.png │ ├── favicon.ico │ ├── layered.png │ ├── pets │ │ ├── cat.jpg │ │ └── puppy.jpg │ ├── att_00000.png │ ├── att_00001.png │ ├── att_00002.png │ ├── hf_hub_fastai.png │ ├── hf_model_card.png │ ├── pixelshuffle.png │ ├── Mixed_precision.jpeg │ ├── half_representation.png │ ├── siim_folder_structure.jpeg │ ├── 10e_agents.dqn.categorical_algorithm1.png │ ├── (Lillicrap et al., 2016) DDPG Algorithm 1.png │ └── (Schulman et al., 2017) [TRPO] Trust Region Policy Optimization Algorithm 1.png ├── 01_DataPipes │ └── sidebar.yml ├── testing │ ├── fastrl_test_env.yaml │ ├── test_settings.ini │ └── fastrl_test_dev_env.yaml ├── nbdev.yml ├── _quarto.yml ├── styles.css ├── external_run_scripts │ ├── spawn_multiproc.py │ ├── agents_dqn_async_35.py │ └── notes_multi_proc_82.py ├── 12_Blog │ ├── 99_blog.from_xxxx_xx_to_xx.ipynb │ └── 99_blog.from_2023_05_to_now.ipynb ├── 00_template.ipynb ├── sidebar.yml ├── 05_Logging │ ├── 09e_loggers.tensorboard.ipynb │ ├── 09d_loggers.jupyter_visualizers.ipynb │ └── 09f_loggers.vscode_visualizers.ipynb ├── README.md ├── 07_Agents │ └── 01_Discrete │ │ └── 12n_agents.dqn.dueling.ipynb └── 02_DataLoading │ └── 00_core.ipynb ├── .gitmodules ├── MANIFEST.in ├── fastrl.code-workspace ├── .pre-commit-config.yaml ├── Makefile ├── .devcontainer.json ├── docker-compose.yml ├── .github └── workflows │ ├── fastrl-test.yml │ ├── quarto_deploy.yml │ ├── python-publish.yml │ └── fastrl-docker.yml ├── settings.ini ├── .gitignore ├── CONTRIBUTING.md ├── setup.py ├── README.md └── fastrl.Dockerfile /fastrl/envs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fastrl/agents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fastrl/funcs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fastrl/learner/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fastrl/loggers/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fastrl/memory/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fastrl/pipes/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fastrl/pipes/map/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fastrl/agents/dqn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fastrl/dataloading/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fastrl/pipes/iter/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /extra/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.0.0 -------------------------------------------------------------------------------- /nbs/.gitignore: -------------------------------------------------------------------------------- 1 | /.quarto/ 2 | /_docs/ 3 | -------------------------------------------------------------------------------- /fastrl/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.47" 2 | -------------------------------------------------------------------------------- /nbs/11_FAQ/header_dates.json: -------------------------------------------------------------------------------- 1 | {"1": "2023-05-29"} -------------------------------------------------------------------------------- /nbs/images/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/cat.jpg -------------------------------------------------------------------------------- /nbs/images/half.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/half.png -------------------------------------------------------------------------------- /nbs/images/mnist3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/mnist3.png -------------------------------------------------------------------------------- /nbs/images/puppy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/puppy.jpg -------------------------------------------------------------------------------- /nbs/images/sample.dcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/sample.dcm -------------------------------------------------------------------------------- /nbs/images/ulmfit.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/ulmfit.png -------------------------------------------------------------------------------- /nbs/images/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/favicon.ico -------------------------------------------------------------------------------- /nbs/images/layered.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/layered.png -------------------------------------------------------------------------------- /nbs/images/pets/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/pets/cat.jpg -------------------------------------------------------------------------------- /nbs/images/att_00000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/att_00000.png -------------------------------------------------------------------------------- /nbs/images/att_00001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/att_00001.png -------------------------------------------------------------------------------- /nbs/images/att_00002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/att_00002.png -------------------------------------------------------------------------------- /nbs/images/pets/puppy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/pets/puppy.jpg -------------------------------------------------------------------------------- /nbs/images/hf_hub_fastai.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/hf_hub_fastai.png -------------------------------------------------------------------------------- /nbs/images/hf_model_card.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/hf_model_card.png -------------------------------------------------------------------------------- /nbs/images/pixelshuffle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/pixelshuffle.png -------------------------------------------------------------------------------- /nbs/images/Mixed_precision.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/Mixed_precision.jpeg -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "data"] 2 | path = data 3 | url = https://github.com/josiahls/data.git 4 | ignore = dirty 5 | -------------------------------------------------------------------------------- /nbs/images/half_representation.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/half_representation.png -------------------------------------------------------------------------------- /nbs/images/siim_folder_structure.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/siim_folder_structure.jpeg -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include settings.ini 2 | include LICENSE 3 | include CONTRIBUTING.md 4 | include README.md 5 | recursive-exclude * __pycache__ 6 | -------------------------------------------------------------------------------- /fastrl.code-workspace: -------------------------------------------------------------------------------- 1 | { 2 | "folders": [ 3 | { 4 | "path": "." 5 | } 6 | ], 7 | "settings": { 8 | "editor.rulers": [80,120] 9 | } 10 | } -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/fastai/nbdev 3 | rev: 2.2.10 4 | hooks: 5 | - id: nbdev_clean 6 | - id: nbdev_export -------------------------------------------------------------------------------- /nbs/images/10e_agents.dqn.categorical_algorithm1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/10e_agents.dqn.categorical_algorithm1.png -------------------------------------------------------------------------------- /nbs/images/(Lillicrap et al., 2016) DDPG Algorithm 1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/(Lillicrap et al., 2016) DDPG Algorithm 1.png -------------------------------------------------------------------------------- /extra/pip_requirements.txt: -------------------------------------------------------------------------------- 1 | gymnasium>=0.27.1 2 | tensordict 3 | pyopengl 4 | pyglet 5 | tensorboard 6 | pygame 7 | pandas 8 | scipy 9 | scikit-learn 10 | fastcore 11 | fastprogress -------------------------------------------------------------------------------- /nbs/images/(Schulman et al., 2017) [TRPO] Trust Region Policy Optimization Algorithm 1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josiahls/fastrl/HEAD/nbs/images/(Schulman et al., 2017) [TRPO] Trust Region Policy Optimization Algorithm 1.png -------------------------------------------------------------------------------- /nbs/01_DataPipes/sidebar.yml: -------------------------------------------------------------------------------- 1 | website: 2 | sidebar: 3 | contents: 4 | - 01a_pipes.core.ipynb 5 | - 01b_pipes.map.demux.ipynb 6 | - 01c_pipes.map.mux.ipynb 7 | - 01d_pipes.iter.nskip.ipynb 8 | - 01e_pipes.iter.nstep.ipynb -------------------------------------------------------------------------------- /fastrl/funcs/conjugation.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/02_Funcs/00_funcs.conjugation.ipynb.
2 | 3 | # %% auto 0 4 | __all__ = [] 5 | 6 | # %% ../../nbs/02_Funcs/00_funcs.conjugation.ipynb 3 7 | # Python native modules 8 | 9 | # Third party libs 10 | import torch 11 | import numpy as np 12 | # Local modules 13 | -------------------------------------------------------------------------------- /extra/dev_requirements.txt: -------------------------------------------------------------------------------- 1 | jupyter 2 | gymnasium[all]>=0.27.1 3 | nbdev>=2.3.1 4 | pre-commit 5 | ipywidgets 6 | moviepy 7 | pygifsicle 8 | aquirdturtle_collapsible_headings 9 | plotly 10 | matplotlib_inline 11 | wheel 12 | twine 13 | fastdownload 14 | watchdog[watchmedo] 15 | graphviz 16 | typing-extensions>=4.3.0 17 | spacy<4 18 | mypy 19 | pyvirtualdisplay -------------------------------------------------------------------------------- /nbs/testing/fastrl_test_env.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | - pytorch 4 | - fastai 5 | dependencies: 6 | - python=3.6 7 | - pip 8 | - setuptools 9 | - fastai>=2.0.0 10 | - moviepy 11 | - jupyter 12 | - notebook 13 | - setuptools 14 | - pip: 15 | - pytest 16 | - nvidia-ml-py3 17 | - dataclasses 18 | - pandas 19 | - pyyaml 20 | name: fastrl_test 21 | -------------------------------------------------------------------------------- /nbs/nbdev.yml: -------------------------------------------------------------------------------- 1 | project: 2 | output-dir: _docs 3 | 4 | website: 5 | title: "fastrl" 6 | site-url: "https://josiahls.github.io/fastrl/" 7 | description: "fastrl is a reinforcement learning library that extends Fastai. This project is not affiliated with fastai or Jeremy Howard." 8 | repo-branch: main 9 | repo-url: "https://github.com/josiahls/fastrl/tree/main/" 10 | -------------------------------------------------------------------------------- /nbs/_quarto.yml: -------------------------------------------------------------------------------- 1 | project: 2 | type: website 3 | 4 | format: 5 | html: 6 | theme: cosmo 7 | css: styles.css 8 | toc: true 9 | 10 | website: 11 | twitter-card: true 12 | open-graph: true 13 | repo-actions: [issue] 14 | navbar: 15 | background: primary 16 | search: true 17 | sidebar: 18 | style: floating 19 | 20 | metadata-files: [nbdev.yml, sidebar.yml] -------------------------------------------------------------------------------- /nbs/styles.css: -------------------------------------------------------------------------------- 1 | .cell-output pre { 2 | margin-left: 0.8rem; 3 | margin-top: 0; 4 | background: none; 5 | border-left: 2px solid lightsalmon; 6 | border-top-left-radius: 0; 7 | border-top-right-radius: 0; 8 | } 9 | 10 | .cell-output .sourceCode { 11 | background: none; 12 | margin-top: 0; 13 | } 14 | 15 | .cell > .sourceCode { 16 | margin-bottom: 0; 17 | } -------------------------------------------------------------------------------- /nbs/testing/test_settings.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | lib_name = fastrl_test 3 | user = josiahls 4 | branch = master 5 | version = 0.0.1 6 | min_python = 3.6 7 | requirements = fastai>=2.0.0 moviepy 8 | pip_requirements = pytest nvidia-ml-py3 dataclasses pandas pyyaml 9 | conda_requirements = jupyter notebook setuptools 10 | dev_requirements = jupyterlab nbdev ipywidgets moviepy pygifsicle aquirdturtle_collapsible_headings 11 | -------------------------------------------------------------------------------- 
/nbs/testing/fastrl_test_dev_env.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | - pytorch 4 | - fastai 5 | dependencies: 6 | - python=3.6 7 | - pip 8 | - setuptools 9 | - fastai>=2.0.0 10 | - moviepy 11 | - jupyter 12 | - notebook 13 | - setuptools 14 | - jupyterlab 15 | - nbdev 16 | - ipywidgets 17 | - moviepy 18 | - pygifsicle 19 | - aquirdturtle_collapsible_headings 20 | - pip: 21 | - pytest 22 | - nvidia-ml-py3 23 | - dataclasses 24 | - pandas 25 | - pyyaml 26 | name: fastrl_test_dev 27 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .ONESHELL: 2 | SHELL := /bin/bash 3 | SRC = $(wildcard nbs/*.ipynb) 4 | 5 | all: fastrl docs 6 | 7 | fastrl: $(SRC) 8 | nbdev_build_lib 9 | touch fastrl 10 | 11 | sync: 12 | nbdev_update_lib 13 | 14 | docs_serve: docs 15 | cd docs && bundle exec jekyll serve 16 | 17 | docs: $(SRC) 18 | nbdev_build_docs 19 | touch docs 20 | 21 | test: 22 | nbdev_test_nbs 23 | 24 | release: pypi conda_release 25 | nbdev_bump_version 26 | 27 | conda_release: 28 | fastrelease_conda_package 29 | 30 | pypi: dist 31 | twine upload --repository pypi dist/* 32 | 33 | dist: clean 34 | python setup.py sdist bdist_wheel 35 | 36 | clean: 37 | rm -rf dist -------------------------------------------------------------------------------- /.devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "fastrl_development", 3 | "dockerComposeFile": "docker-compose.yml", 4 | "service": "fastrl", 5 | "customizations":{ 6 | "vscode": { 7 | "settings": {"terminal.integrated.shell.linux": "/bin/bash"} 8 | } 9 | }, 10 | // "mounts": [ "source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind", 11 | // "source=/usr/bin/docker,target=/usr/bin/docker,type=bind" ], 12 | "workspaceFolder": "/home/fastrl_user/fastrl", 13 | // "forwardPorts": [4000, 8080], 14 | // "appPort": [4000, 8080], 15 | // "extensions": [ 16 | // "ms-python.python", 17 | // "ms-azuretools.vscode-docker", 18 | // "ms-toolsai.jupyter-renderers" 19 | // ], 20 | "runServices": ["dep_watcher", "quarto"] //, 21 | // "postStartCommand": "pip install -e .[dev]" 22 | } -------------------------------------------------------------------------------- /nbs/external_run_scripts/spawn_multiproc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchdata.datapipes as dp 3 | from torchdata.dataloader2 import DataLoader2,MultiProcessingReadingService 4 | 5 | class PointlessLoop(dp.iter.IterDataPipe): 6 | def __init__(self,datapipe=None): 7 | self.datapipe = datapipe 8 | 9 | def __iter__(self): 10 | while True: 11 | yield torch.LongTensor(4).detach().clone() 12 | 13 | 14 | if __name__=='__main__': 15 | from torch.multiprocessing import Pool, Process, set_start_method 16 | try: 17 | set_start_method('spawn') 18 | except RuntimeError: 19 | pass 20 | 21 | 22 | pipe = PointlessLoop() 23 | pipe = pipe.header(limit=10) 24 | dls = [DataLoader2(pipe, 25 | reading_service=MultiProcessingReadingService( 26 | num_workers = 2 27 | ))] 28 | # Setup the Learner 29 | print('type: ',type(dls[0])) 30 | for o in dls[0]: 31 | print(o) 32 | -------------------------------------------------------------------------------- /nbs/12_Blog/99_blog.from_xxxx_xx_to_xx.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "assisted-contract", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# Python native modules\n", 11 | "# Third party libs\n", 12 | "\n", 13 | "# Local modules\n", 14 | "from fastrl.nbdev_extensions import header" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "id": "40e740d6", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "#|echo: false\n", 25 | "header(\n", 26 | "    \"Test run\",\n", 27 | "    \"subtitle\",\n", 28 | "    freeze=False\n", 29 | ")" 30 | ] 31 | }, 32 | { 33 | "attachments": {}, 34 | "cell_type": "markdown", 35 | "id": "fec5e005", 36 | "metadata": {}, 37 | "source": [] 38 | } 39 | ], 40 | "metadata": { 41 | "kernelspec": { 42 | "display_name": "python3", 43 | "language": "python", 44 | "name": "python3" 45 | } 46 | }, 47 | "nbformat": 4, 48 | "nbformat_minor": 5 49 | } 50 | -------------------------------------------------------------------------------- /fastrl/agents/dqn/dueling.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['DuelingHead'] 5 | 6 | # %% ../../../nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb 2 7 | # Python native modules 8 | # Third party libs 9 | import torch 10 | from torch import nn 11 | # Local modules 12 | from fastrl.agents.dqn.basic import ( 13 | DQN, 14 | DQNAgent 15 | ) 16 | from .target import DQNTargetLearner 17 | 18 | # %% ../../../nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb 5 19 | class DuelingHead(nn.Module): 20 | def __init__( 21 | self, 22 | hidden: int, # Input into the DuelingHead, likely a hidden layer input 23 | n_actions: int, # Number/dim of actions to output 24 | lin_cls = nn.Linear 25 | ): 26 | super().__init__() 27 | self.val = lin_cls(hidden,1) 28 | self.adv = lin_cls(hidden,n_actions) 29 | 30 | def forward(self,xi): 31 | val,adv = self.val(xi),self.adv(xi) 32 | # Zero-center the advantage stream per sample over the action dim 33 | xi = val.expand_as(adv)+(adv-adv.mean(dim=-1,keepdim=True)) 34 | return xi 35 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.8" 2 | services: 3 | fastrl_build: 4 | build: 5 | dockerfile: fastrl.Dockerfile 6 | context: .
7 | image: josiahls/fastrl-dev:latest 8 | profiles: ["build"] 9 | 10 | fastrl: &fastrl 11 | restart: unless-stopped 12 | working_dir: /home/fastrl_user/fastrl 13 | image: josiahls/fastrl-dev:latest 14 | deploy: 15 | resources: 16 | reservations: 17 | devices: 18 | - capabilities: 19 | - gpu 20 | logging: 21 | driver: json-file 22 | options: 23 | max-size: 50m 24 | stdin_open: true 25 | tty: true 26 | shm_size: 16000000000 27 | volumes: 28 | - .:/home/fastrl_user/fastrl/ 29 | - ~/.ssh:/home/fastrl_user/.ssh:rw 30 | network_mode: host # for GitHub Codespaces https://github.com/features/codespaces/ 31 | 32 | dep_watcher: 33 | <<: *fastrl 34 | command: watchmedo shell-command --command fastrl_make_requirements --pattern *.ini --recursive --drop 35 | 36 | quarto: 37 | <<: *fastrl 38 | restart: unless-stopped 39 | working_dir: /home/fastrl_user/fastrl/nbs/_docs 40 | volumes: 41 | - .:/home/fastrl_user/fastrl/ 42 | command: nbdev_preview 43 | -------------------------------------------------------------------------------- /nbs/11_FAQ/99_template.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "durable-dialogue", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|hide\n", 11 | "from fastrl.test_utils import initialize_notebook\n", 12 | "initialize_notebook()" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "offshore-stuart", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "# #|default_exp template" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "id": "assisted-contract", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "# #|export\n", 33 | "# Python native modules\n", 34 | "import os\n", 35 | "# Third party libs\n", 36 | "from fastcore.all import *\n", 37 | "# Local modules" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "40e740d6", 43 | "metadata": {}, 44 | "source": [ 45 | "# Template\n", 46 | "> notebook" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "id": "current-pilot", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "#|hide\n", 57 | "#|eval: false\n", 58 | "!nbdev_export" 59 | ] 60 | } 61 | ], 62 | "metadata": { 63 | "kernelspec": { 64 | "display_name": "python3", 65 | "language": "python", 66 | "name": "python3" 67 | } 68 | }, 69 | "nbformat": 4, 70 | "nbformat_minor": 5 71 | } 72 | -------------------------------------------------------------------------------- /fastrl/loggers/tensorboard.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/05_Logging/09e_loggers.tensorboard.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['run_tensorboard'] 5 | 6 | # %% ../../nbs/05_Logging/09e_loggers.tensorboard.ipynb 1 7 | # Python native modules 8 | import os 9 | from pathlib import Path 10 | # Third party libs 11 | import torchdata.datapipes as dp 12 | # Local modules 13 | 14 | # %% ../../nbs/05_Logging/09e_loggers.tensorboard.ipynb 4 15 | def run_tensorboard( 16 | port:int=6006, # The port to run tensorboard on/connect on 17 | start_tag:str=None, # Starting regex e.g.: experience_replay/1 18 | samples_per_plugin:str=None, # Sampling freq such as images=0 (keep all) 19 | extra_args:str=None, # Any additional arguments in the `--arg value` format 20 | rm_glob:str=None # Remove old logs via a pattern e.g.: '*' will remove all files: runs/* 21 | ): 22 | if rm_glob is not None: 23 | for p in Path('runs').glob(rm_glob): p.unlink() 24 | import socket 25 | from tensorboard import notebook 26 | a_socket=socket.socket(socket.AF_INET, socket.SOCK_STREAM) 27 | cmd=None 28 | if not a_socket.connect_ex(('127.0.0.1',port)): 29 | notebook.display(port=port,height=1000) 30 | else: 31 | cmd=f'--logdir runs --port {port} --host=0.0.0.0' 32 | if samples_per_plugin is not None: cmd+=f' --samples_per_plugin {samples_per_plugin}' 33 | if start_tag is not None: cmd+=f' --tag {start_tag}' 34 | if extra_args is not None: cmd+=f' {extra_args}' 35 | notebook.start(cmd) 36 | return cmd 37 | -------------------------------------------------------------------------------- /nbs/00_template.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "durable-dialogue", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|hide\n", 11 | "from fastrl.test_utils import initialize_notebook\n", 12 | "initialize_notebook()" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "offshore-stuart", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "# #|default_exp template" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "id": "assisted-contract", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "# #|export\n", 33 | "# Python native modules\n", 34 | "\n", 35 | "# Third party libs\n", 36 | "\n", 37 | "# Local modules" 38 | ] 39 | }, 40 | { 41 | "attachments": {}, 42 | "cell_type": "markdown", 43 | "id": "a258abcf", 44 | "metadata": {}, 45 | "source": [ 46 | "# Template" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "id": "current-pilot", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "#|hide\n", 57 | "#|eval: false\n", 58 | "!nbdev_export" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "ed71a089", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [] 68 | } 69 | ], 70 | "metadata": { 71 | "kernelspec": { 72 | "display_name": "python3", 73 | "language": "python", 74 | "name": "python3" 75 | } 76 | }, 77 | "nbformat": 4, 78 | "nbformat_minor": 5 79 | } 80 | -------------------------------------------------------------------------------- /fastrl/pipes/map/transforms.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/01_DataPipes/01h_pipes.map.transforms.ipynb.
2 | 3 | # %% auto 0 4 | __all__ = ['T_co', 'TypeTransformer'] 5 | 6 | # %% ../../../nbs/01_DataPipes/01h_pipes.map.transforms.ipynb 3 7 | # Python native modules 8 | from typing import Callable,Union,TypeVar 9 | # Third party libs 10 | from fastcore.all import * 11 | import torchdata.datapipes as dp 12 | from torchdata.datapipes.map import MapDataPipe 13 | from torchdata.dataloader2.graph import find_dps,DataPipeGraph,Type,DataPipe 14 | # Local modules 15 | 16 | # %% ../../../nbs/01_DataPipes/01h_pipes.map.transforms.ipynb 5 17 | T_co = TypeVar("T_co", covariant=True) 18 | 19 | class TypeTransformer(dp.map.MapDataPipe): 20 | def __init__( 21 | self, 22 | # Should support `__getitem__` and produce elements to be ingested by `type_tfms` 23 | source_datapipe:MapDataPipe[T_co], 24 | # A list of Callables that accept an input, and return an output 25 | type_tfms:List[Callable] 26 | ) -> None: 27 | self.type_tfms:Pipeline[Callable] = Pipeline(type_tfms) 28 | self.source_datapipe:MapDataPipe[T_co] = source_datapipe 29 | 30 | def __getitem__(self, index) -> T_co: 31 | data = self.source_datapipe[index] 32 | return self.type_tfms(data) 33 | 34 | def __len__(self) -> int: return len(self.source_datapipe) 35 | 36 | TypeTransformer.__doc__ = """On `__getitem__`, functions in `self.type_tfms` get called over each element. 37 | Generally `TypeTransformer`, as the name suggests, is intended to convert elements from one type to another. 38 | Reference the documentation on how to combine this with `InMemoryCacheHolder`.""" 39 | -------------------------------------------------------------------------------- /.github/workflows/fastrl-test.yml: -------------------------------------------------------------------------------- 1 | name: Fastrl Testing 2 | on: [push, pull_request] 3 | 4 | # env: 5 | #   PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/nightly/cu113 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | container: 11 | image: 'josiahls/fastrl-dev:latest' 12 | 13 | steps: 14 | - uses: actions/checkout@v1 15 | # - uses: actions/setup-python@v1 16 | #   with: 17 | #     python-version: '3.7' 18 | #     architecture: 'x64' 19 | - name: Install the library 20 | run: | 21 | sudo mkdir -p /github/home 22 | sudo pip install -e .["dev"] 23 | # - name: Read all notebooks 24 | #   run: | 25 | #     nbdev_read_nbs 26 | - name: Check if all notebooks are cleaned 27 | run: | 28 | sudo git config --global --add safe.directory /__w/fastrl/fastrl 29 | echo "Check we are starting with clean git checkout" 30 | if [ -n "$(git status -uno -s)" ]; then echo "git status is not clean"; false; fi 31 | echo "Trying to strip out notebooks" 32 | sudo nbdev_clean 33 | echo "Check that strip out was unnecessary" 34 | git status -s # display the status to see which nbs need cleaning up 35 | if [ -n "$(git status -uno -s)" ]; then echo -e "!!! Detected unstripped out notebooks\n!!!Remember to run nbdev_install_hooks"; false; fi 36 | 37 | # - name: Check if there is no diff library/notebooks 38 | #   run: | 39 | #     if [ -n "$(nbdev_diff_nbs)" ]; then echo -e "!!! Detected difference between the notebooks and the library"; false; fi 40 | - name: Run tests 41 | run: | 42 | pip3 show torchdata torch 43 | sudo pip3 install -e .
44 | cd nbs 45 | xvfb-run -s "-screen 0 1400x900x24" fastrl_nbdev_test --n_workers 12 --one2one 46 | - name: Run Doc Build Test 47 | run: | 48 | fastrl_nbdev_docs --one2one 49 | -------------------------------------------------------------------------------- /fastrl/dataloading/core.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/02_DataLoading/00_core.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['dataloaders'] 5 | 6 | # %% ../../nbs/02_DataLoading/00_core.ipynb 2 7 | # Python native modules 8 | from typing import Tuple,Union,List 9 | # Third party libs 10 | import torchdata.datapipes as dp 11 | from torchdata.dataloader2 import MultiProcessingReadingService,DataLoader2 12 | from fastcore.all import delegates 13 | # Local modules 14 | 15 | # %% ../../nbs/02_DataLoading/00_core.ipynb 4 16 | @delegates(MultiProcessingReadingService) 17 | def dataloaders( 18 | # A tuple of iterable datapipes to generate dataloaders from. 19 | pipes:Union[Tuple[dp.iter.IterDataPipe],dp.iter.IterDataPipe], 20 | # Concat the dataloaders together 21 | do_concat:bool = False, 22 | # Multiplex the dataloaders 23 | do_multiplex:bool = False, 24 | # Number of workers the dataloaders should run in 25 | num_workers: int = 0, 26 | **kwargs 27 | ) -> Union[dp.iter.IterDataPipe,List[dp.iter.IterDataPipe]]: 28 | "Function that creates dataloaders based on `pipes` with different ways of combining them." 29 | if not isinstance(pipes,tuple): 30 | pipes = (pipes,) 31 | 32 | dls = [] 33 | for pipe in pipes: 34 | dl = DataLoader2( 35 | datapipe=pipe, 36 | reading_service=MultiProcessingReadingService( 37 | num_workers = num_workers, 38 | **kwargs 39 | ) if num_workers > 0 else None 40 | ) 41 | dl = dp.iter.IterableWrapper(dl,deepcopy=False) 42 | dls.append(dl) 43 | #TODO(josiahls): Not sure if this is needed tbh.. Might be better to just 44 | # return dls, and have the user wrap them if they want. Then they can do more complex stuff. 45 | if do_concat: 46 | return dp.iter.Concater(*dls) 47 | elif do_multiplex: 48 | return dp.iter.Multiplexer(*dls) 49 | else: 50 | return dls 51 | 52 | -------------------------------------------------------------------------------- /settings.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | lib_name = fastrl 3 | description = fastrl is a reinforcement learning library that extends Fastai. This project is not affiliated with fastai or Jeremy Howard.
4 | copyright = 2021 onwards, Josiah Laivins 5 | keywords = fastrl reinforcement learning rl robotics fastai deep learning machine learning 6 | user = josiahls 7 | author = Josiah Laivins, and contributors 8 | author_email = 9 | branch = main 10 | min_python = 3.7 11 | version = 0.0.47 12 | audience = Developers 13 | language = English 14 | custom_sidebar = True 15 | license = apache2 16 | status = 2 17 | dep_links = https://download.pytorch.org/whl/nightly/cu117 18 | requirements = torch>=2.0.0 19 | data_requirements = git+https://github.com/josiahls/data.git@main#egg=torchdata 20 | pip_requirements = gymnasium>=0.27.1 tensordict pyopengl pyglet tensorboard pygame pandas scipy scikit-learn fastcore fastprogress 21 | #nbformat>=4.2.0,<5.* 22 | conda_requirements = 23 | dev_requirements = jupyter gymnasium[all]>=0.27.1 nbdev>=2.3.1 pre-commit ipywidgets moviepy pygifsicle aquirdturtle_collapsible_headings plotly matplotlib_inline wheel twine fastdownload watchdog[watchmedo] graphviz typing-extensions>=4.3.0 spacy<4 mypy pyvirtualdisplay 24 | console_scripts = fastrl_make_requirements=fastrl.cli:fastrl_make_requirements 25 | fastrl_nbdev_docs=fastrl.cli:fastrl_nbdev_docs 26 | fastrl_nbdev_test=fastrl.cli:fastrl_nbdev_test 27 | fastrl_nbdev_create_blog_notebook=fastrl.nbdev_extensions:create_blog_notebook 28 | tst_flags = slow cpp cuda multicuda 29 | nbs_path = nbs 30 | doc_path = _docs 31 | recursive = True 32 | doc_host = https://josiahls.github.io 33 | doc_baseurl = /fastrl/ 34 | git_url = https://github.com/josiahls/fastrl/tree/main/ 35 | lib_path = fastrl 36 | title = fastrl 37 | black_formatting = False 38 | readme_nb = index.ipynb 39 | allowed_metadata_keys = 40 | allowed_cell_metadata_keys = 41 | jupyter_hooks = True 42 | clean_ids = True 43 | clear_all = False 44 | put_version_in_init = True 45 | 46 | -------------------------------------------------------------------------------- /.github/workflows/quarto_deploy.yml: -------------------------------------------------------------------------------- 1 | name: Deploy to GitHub Pages 2 | on: 3 | push: 4 | branches: [ "main", "master", "update_to_nbdev_2", "*docs" ] 5 | workflow_dispatch: 6 | 7 | env: 8 | PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/nightly/cu113 9 | 10 | jobs: 11 | deploy: 12 | runs-on: ubuntu-latest 13 | container: 14 | image: 'josiahls/fastrl-dev:latest' 15 | 16 | steps: 17 | - uses: actions/checkout@v3 18 | - name: Install Dependencies 19 | shell: bash 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install -e ".[dev]" --upgrade 23 | fastrl_nbdev_docs --one2one 24 | # - uses: fastai/workflows/quarto-ghp@master 25 | - name: Enable GitHub Pages 26 | shell: python 27 | run: | 28 | import ghapi.core,nbdev.config,sys 29 | msg="Please enable GitHub Pages to publish from the root of the `gh-pages` branch per these instructions - https://docs.github.com/en/pages/getting-started-with-github-pages/configuring-a-publishing-source-for-your-github-pages-site#publishing-from-a-branch" 30 | try: 31 | api = ghapi.core.GhApi(owner=nbdev.config.get_config().user, repo=nbdev.config.get_config().repo, token="${{secrets.GITHUB_TOKEN}}") 32 | api.enable_pages(branch='gh-pages') 33 | except Exception as e: 34 | print(f'::error title="Could not enable GitHub Pages Automatically":: {msg}\n{e}') 35 | sys.exit(1) 36 | - name: Deploy to GitHub Pages 37 | uses: peaceiris/actions-gh-pages@v3 38 | with: 39 | github_token: ${{secrets.GITHUB_TOKEN}} 40 | force_orphan: true 41 | publish_dir: ./_docs 42 | # The following lines 
assign commit authorship to the official GH-Actions bot for deploys to `gh-pages` branch. 43 | # You can swap them out with your own user credentials. 44 | user_name: github-actions[bot] 45 | user_email: 41898282+github-actions[bot]@users.noreply.github.com 46 | 47 | -------------------------------------------------------------------------------- /nbs/external_run_scripts/agents_dqn_async_35.py: -------------------------------------------------------------------------------- 1 | # %%python 2 | 3 | if __name__=='__main__': 4 | from torch.multiprocessing import Pool, Process, set_start_method 5 | 6 | try: 7 | set_start_method('spawn') 8 | except RuntimeError: 9 | pass 10 | 11 | from fastcore.all import * 12 | import torch 13 | from torch.nn import * 14 | import torch.nn.functional as F 15 | from fastrl.loggers.core import * 16 | from fastrl.loggers.jupyter_visualizers import * 17 | from fastrl.learner.core import * 18 | from fastrl.data.block import * 19 | from fastrl.envs.gym import * 20 | from fastrl.agents.core import * 21 | from fastrl.agents.discrete import * 22 | from fastrl.agents.dqn.basic import * 23 | from fastrl.agents.dqn.asynchronous import * 24 | 25 | from torchdata.dataloader2 import DataLoader2 26 | from torchdata.dataloader2.graph import traverse 27 | from fastrl.data.dataloader2 import * 28 | 29 | logger_base = ProgressBarLogger(epoch_on_pipe=EpocherCollector, 30 | batch_on_pipe=BatchCollector) 31 | 32 | # Setup up the core NN 33 | torch.manual_seed(0) 34 | model = DQN(4,2).cuda() 35 | # model.share_memory() # This will not work in spawn 36 | # Setup the Agent 37 | agent = DQNAgent(model,max_steps=4000,device='cuda', 38 | dp_augmentation_fns=[ModelSubscriber.insert_dp()]) 39 | # Setup the DataBlock 40 | block = DataBlock( 41 | GymTransformBlock(agent=agent,nsteps=1,nskips=1,firstlast=False), 42 | GymTransformBlock(agent=agent,nsteps=1,nskips=1,firstlast=False,include_images=True) 43 | ) 44 | dls = L(block.dataloaders(['CartPole-v1']*1,num_workers=1)) 45 | # # Setup the Learner 46 | learner = DQNLearner(model,dls,batches=1000,logger_bases=[logger_base],bs=128,max_sz=100_000,device='cuda', 47 | dp_augmentation_fns=[ModelPublisher.insert_dp(publish_freq=10)]) 48 | # print(traverse(learner)) 49 | learner.fit(20) 50 | -------------------------------------------------------------------------------- /nbs/external_run_scripts/notes_multi_proc_82.py: -------------------------------------------------------------------------------- 1 | import torchdata.datapipes as dp 2 | from torch.utils.data import IterableDataset 3 | 4 | class AddABunch1(dp.iter.IterDataPipe): 5 | def __init__(self,q): 6 | super().__init__() 7 | self.q = [q] 8 | 9 | def __iter__(self): 10 | for o in range(10): 11 | self.q[0].put(o) 12 | yield o 13 | 14 | class AddABunch2(dp.iter.IterDataPipe): 15 | def __init__(self,source_datapipe,q): 16 | super().__init__() 17 | self.q = q 18 | print(id(self.q)) 19 | self.source_datapipe = source_datapipe 20 | 21 | def __iter__(self): 22 | for o in self.source_datapipe: 23 | print(id(self.q)) 24 | self.q.put(o) 25 | yield o 26 | 27 | class AddABunch3(IterableDataset): 28 | def __init__(self,q): 29 | self.q = q 30 | 31 | def __iter__(self): 32 | for o in range(10): 33 | print(id(self.q)) 34 | self.q.put(o) 35 | yield o 36 | 37 | if __name__=='__main__': 38 | from torch.multiprocessing import Pool,Process,set_start_method,Manager,get_start_method 39 | import torch 40 | 41 | try: set_start_method('spawn') 42 | except RuntimeError: pass 43 | # from 
torch.utils.data.dataloader_experimental import DataLoader2 44 | from torchdata.dataloader2 import DataLoader2 45 | from torchdata.dataloader2.reading_service import MultiProcessingReadingService 46 | 47 | m = Manager() 48 | q = m.Queue() 49 | 50 | pipe = AddABunch2(list(range(10)),q) 51 | print(type(pipe)) 52 | dl = DataLoader2(pipe, 53 | reading_service=MultiProcessingReadingService(num_workers=1) 54 | ) # Will fail if num_workers>0 55 | 56 | # dl = DataLoader2(AddABunch1(q),num_workers=1) # Will fail if num_workers>0 57 | # dl = DataLoader2(AddABunch2(q),num_workers=1) # Will fail if num_workers>0 58 | # dl = DataLoader2(AddABunch3(q),num_workers=1) # Will succeed if num_workers>0 59 | list(dl) 60 | 61 | while not q.empty(): 62 | print(q.get()) 63 | -------------------------------------------------------------------------------- /fastrl/envs/debug_env.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/03_Environment/06_envs.debug_env.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['SimpleContinuousEnv'] 5 | 6 | # %% ../../nbs/03_Environment/06_envs.debug_env.ipynb 2 7 | # Python native modules 8 | import os 9 | # Third party libs 10 | import gymnasium as gym 11 | from gymnasium import spaces 12 | import numpy as np 13 | # Local modules 14 | 15 | # %% ../../nbs/03_Environment/06_envs.debug_env.ipynb 4 16 | class SimpleContinuousEnv(gym.Env): 17 | metadata = {'render.modes': ['console']} 18 | 19 | def __init__(self, goal_position=None, proximity_threshold=0.5): 20 | super(SimpleContinuousEnv, self).__init__() 21 | 22 | self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32) 23 | self.observation_space = spaces.Box(low=-10, high=10, shape=(1,), dtype=np.float32) 24 | 25 | self.goal_position = goal_position if goal_position is not None else np.random.uniform(-10, 10) 26 | self.proximity_threshold = proximity_threshold 27 | self.state = None 28 | 29 | def step(self, action): 30 | self.state += action 31 | 32 | # Calculate the distance to the goal 33 | distance_to_goal = np.abs(self.state - self.goal_position) 34 | 35 | # Calculate reward: higher for being closer to the goal 36 | reward = -distance_to_goal 37 | 38 | # Check if the agent is within the proximity threshold of the goal 39 | done = distance_to_goal <= self.proximity_threshold 40 | 41 | info = {} 42 | 43 | return self.state, reward, done, info 44 | 45 | def reset(self): 46 | self.state = np.array([0.0], dtype=np.float32) 47 | if self.goal_position is None: 48 | self.goal_position = np.random.uniform(-10, 10) 49 | return self.state 50 | 51 | 52 | def render(self, mode='console'): 53 | if mode != 'console': 54 | raise NotImplementedError("Only console mode is supported.") 55 | print(f"Position: {self.state} Goal: {self.goal_position}") 56 | 57 | 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.bak 2 | .gitattributes 3 | .last_checked 4 | .gitconfig 5 | *.bak 6 | *.log 7 | *~ 8 | ~* 9 | _tmp* 10 | tags 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | env/ 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | 
39 | # PyInstaller 40 | #  Usually these files are written by a python script from a template 41 | #  before PyInstaller builds the exe, so as to inject date/other info into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | .hypothesis/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # dotenv 94 | .env 95 | 96 | # virtualenv 97 | .venv 98 | venv/ 99 | ENV/ 100 | 101 | # Spyder project settings 102 | .spyderproject 103 | .spyproject 104 | 105 | # Rope project settings 106 | .ropeproject 107 | 108 | # mkdocs documentation 109 | /site 110 | 111 | # mypy 112 | .mypy_cache/ 113 | 114 | .vscode 115 | *.swp 116 | 117 | # osx generated files 118 | .DS_Store 119 | .DS_Store? 120 | .Trashes 121 | ehthumbs.db 122 | Thumbs.db 123 | .idea 124 | 125 | # pytest 126 | .pytest_cache 127 | 128 | # tools/trust-doc-nbs 129 | docs_src/.last_checked 130 | 131 | # symlinks to fastai 132 | docs_src/fastai 133 | tools/fastai 134 | 135 | # link checker 136 | checklink/cookies.txt 137 | 138 | # .gitconfig is now autogenerated 139 | .gitconfig 140 | 141 | /bfg-1.13.0.jar 142 | /sensitive_info.txt 143 | /nbs/runs/ 144 | /quarto-linux-amd64.deb 145 | **/**/_docs 146 | /_proc/ 147 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | ## How to get started 4 | 5 | Before anything else, please install the git hooks that run automatic scripts during each commit and merge to strip the notebooks of superfluous metadata (and avoid merge conflicts). After cloning the repository, run the following command inside it: 6 | ``` 7 | nbdev_install_git_hooks 8 | ``` 9 | 10 | ## Did you find a bug? 11 | 12 | * Ensure the bug was not already reported by searching on GitHub under Issues. 13 | * If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring. 14 | * Be sure to add the complete error messages. 15 | 16 | #### Did you write a patch that fixes a bug? 17 | 18 | * Open a new GitHub pull request with the patch. 19 | * Ensure that your PR includes a test that fails without your patch, and passes with it. 20 | * Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable. 21 | 22 | ## PR submission guidelines 23 | 24 | * Keep each PR focused. While it's more convenient, do not combine several unrelated fixes together. Create as many branches as needed to keep each PR focused. 25 | * Do not mix style changes/fixes with "functional" changes. It's very difficult to review such PRs and they will most likely be rejected.
26 | * Do not add/remove vertical whitespace. Preserve the original style of the file you edit as much as you can. 27 | * Do not turn an already submitted PR into your development playground. If, after you submit a PR, you discover that more work is needed, close the PR, do the required work, and then submit a new PR. Otherwise each of your commits requires attention from the maintainers of the project. 28 | * If, however, you submitted a PR and received a request for changes, you should proceed with commits inside that PR, so that the maintainer can see the incremental fixes and won't need to review the whole PR again. In the exceptional case where you realize it will take many commits to complete the requests, it's probably best to close the PR, do the work, and then submit it again. Use common sense when choosing one way over the other. 29 | 30 | ## Do you want to contribute to the documentation? 31 | 32 | * Docs are automatically created from the notebooks in the nbs folder. 33 | 34 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python PyPi and Conda 5 | 6 | on: 7 | release: 8 | types: [created] 9 | #   schedule: 10 | #     - cron:  '1 6 * * *' 11 | workflow_dispatch: #allows you to trigger manually 12 | push: 13 | branches: 14 | - main 15 | 16 | jobs: 17 | deploy_python_pypi: 18 | name: "Deploy Python PyPi" 19 | runs-on: ubuntu-latest 20 | 21 | steps: 22 | - uses: actions/checkout@v2 23 | - name: Set up Python 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: '3.x' 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | pip install setuptools wheel twine 31 | - name: Build and publish 32 | env: 33 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 34 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 35 | run: | 36 | python setup.py sdist bdist_wheel 37 | twine upload dist/* || echo "File Already exists"; exit 0 38 | 39 | # Disable usage of conda since it never installs correctly.
40 | # deploy_python_conda: 41 | # runs-on: ubuntu-latest 42 | # name: "Deploy Python Conda" 43 | # needs: deploy_python_pypi 44 | # strategy: 45 | # max-parallel: 5 46 | # timeout-minutes: 500 47 | 48 | # steps: 49 | # - uses: actions/checkout@v2 50 | # - name: Set up Python 3.6 51 | # uses: actions/setup-python@v2 52 | # with: 53 | # python-version: 3.6 54 | # - name: Add conda to system path 55 | # run: | 56 | # # $CONDA is an environment variable pointing to the root of the miniconda directory 57 | # echo $CONDA/bin >> $GITHUB_PATH 58 | # - name: Install dependencies 59 | # run: | 60 | # conda install -y -c fastchan conda-build anaconda-client fastrelease 61 | # - name: Build Conda package 62 | # timeout-minutes: 360 63 | # run: | 64 | # fastrelease_conda_package --do_build false 65 | # cd conda 66 | # conda build --no-anaconda-upload --output-folder build fastrl -c fastchan 67 | # anaconda upload build/noarch/fastrl-*-*.tar.bz2 68 | # env: 69 | # ANACONDA_API_TOKEN: ${{ secrets.ANACONDA_TOKEN }} 70 | # - name: Test package import 71 | # run: | 72 | # conda install -c josiahls fastrl 73 | # conda update -c josiahls fastrl 74 | # python -c "import fastrl" 75 | -------------------------------------------------------------------------------- /nbs/sidebar.yml: -------------------------------------------------------------------------------- 1 | website: 2 | sidebar: 3 | contents: 4 | - index.ipynb 5 | - 00_core.ipynb 6 | - 00_nbdev_extension.ipynb 7 | - 19_cli.ipynb 8 | - 20_test_utils.ipynb 9 | - contents: 10 | - 00_Fastai/00_torch_core.ipynb 11 | section: Fastai 12 | - contents: 13 | - 01_DataPipes/01a_pipes.core.ipynb 14 | - 01_DataPipes/01b_pipes.map.demux.ipynb 15 | - 01_DataPipes/01c_pipes.map.mux.ipynb 16 | - 01_DataPipes/01d_pipes.iter.nskip.ipynb 17 | - 01_DataPipes/01e_pipes.iter.nstep.ipynb 18 | - 01_DataPipes/01f_pipes.iter.firstlast.ipynb 19 | - 01_DataPipes/01g_pipes.iter.transforms.ipynb 20 | - 01_DataPipes/01h_pipes.map.transforms.ipynb 21 | section: DataPipes 22 | - contents: 23 | - 02_DataLoading/02f_data.dataloader2.ipynb 24 | - 02_DataLoading/02g_data.block.ipynb 25 | section: DataLoading 26 | - contents: 27 | - 03_Environment/05b_envs.gym.ipynb 28 | section: Environment 29 | - contents: 30 | - 02_Funcs/00_funcs.conjugation.ipynb 31 | section: Funcs 32 | - contents: 33 | - 04_Memory/06a_memory.experience_replay.ipynb 34 | section: Memory 35 | - contents: 36 | - 05_Logging/09a_loggers.core.ipynb 37 | - 05_Logging/09d_loggers.jupyter_visualizers.ipynb 38 | - 05_Logging/09e_loggers.tensorboard.ipynb 39 | - 05_Logging/09f_loggers.vscode_visualizers.ipynb 40 | section: Logging 41 | - contents: 42 | - 06_Learning/10a_learner.core.ipynb 43 | section: Learning 44 | - contents: 45 | - 07_Agents/12a_agents.core.ipynb 46 | - contents: 47 | - 07_Agents/01_Discrete/12b_agents.discrete.ipynb 48 | - 07_Agents/01_Discrete/12g_agents.dqn.basic.ipynb 49 | - 07_Agents/01_Discrete/12h_agents.dqn.target.ipynb 50 | - 07_Agents/01_Discrete/12l_agents.dqn.asynchronous.ipynb 51 | - 07_Agents/01_Discrete/12m_agents.dqn.double.ipynb 52 | - 07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb 53 | - 07_Agents/01_Discrete/12o_agents.dqn.categorical.ipynb 54 | - 07_Agents/01_Discrete/12r_agents.dqn.rainbow.ipynb 55 | section: Discrete 56 | - contents: 57 | - 07_Agents/02_Continuous/12s_agents.ddpg.ipynb 58 | - 07_Agents/02_Continuous/12t_agents.trpo.ipynb 59 | section: Continuous 60 | section: Agents 61 | - contents: 62 | - 11_FAQ/99_notes.multi_proc.ipynb 63 | - 11_FAQ/99_notes.pipe_insertion.ipynb 64 
| - 11_FAQ/99_notes.speed.ipynb 65 | - 11_FAQ/99_template.ipynb 66 | section: FAQ 67 | - contents: 68 | - 12_Blog/99_blog.from_2023_05_to_now.ipynb 69 | section: Blog 70 | -------------------------------------------------------------------------------- /fastrl/envs/continuous_debug_env.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/03_Environment/06_envs.continuous_debug_env.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['ContinuousDebugEnv'] 5 | 6 | # %% ../../nbs/03_Environment/06_envs.continuous_debug_env.ipynb 2 7 | # Python native modules 8 | import os 9 | # Third party libs 10 | import gymnasium as gym 11 | from gymnasium import spaces 12 | from gymnasium.envs.registration import register 13 | import numpy as np 14 | # Local modules 15 | 16 | # %% ../../nbs/03_Environment/06_envs.continuous_debug_env.ipynb 4 17 | class ContinuousDebugEnv(gym.Env): 18 | metadata = {'render_modes': ['console']} # Corrected metadata key 19 | 20 | def __init__(self, goal_position=None, proximity_threshold=0.5): 21 | super(ContinuousDebugEnv, self).__init__() 22 | 23 | self.goal_position = goal_position if goal_position is not None else np.random.uniform(-10, 10) 24 | 25 | if goal_position is not None: 26 | self.observation_space = spaces.Box(low=-goal_position, high=goal_position, shape=(1,), dtype=np.float32) 27 | else: 28 | self.observation_space = spaces.Box(low=-10, high=10, shape=(1,), dtype=np.float32) 29 | 30 | self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32) 31 | 32 | 33 | self.proximity_threshold = proximity_threshold 34 | self.state = None 35 | 36 | def step(self, action): 37 | self.state[0] += action[0] # Assuming action is a NumPy array, use the first element 38 | 39 | distance_to_goal = np.abs(self.state[0] - self.goal_position) 40 | reward = -distance_to_goal.item() # Ensure reward is a float 41 | 42 | done = distance_to_goal <= self.proximity_threshold 43 | done = bool(done.item()) # Ensure done is a boolean 44 | 45 | info = {} 46 | 47 | return self.state, reward, done,done, info 48 | 49 | def reset(self, seed=None, options=None): 50 | super().reset(seed=seed) # Call the superclass reset, which handles the seeding 51 | if self.goal_position is None: 52 | self.goal_position = np.random.uniform(-10, 10) 53 | # The state is {current position, goal position} 54 | self.state = np.array([0.0, self.goal_position], dtype=np.float32) 55 | 56 | return self.state, {} # Return observation and an empty info dictionary 57 | 58 | 59 | def render(self, mode='console'): 60 | if mode != 'console': 61 | raise NotImplementedError("Only console mode is supported.") 62 | print(f"Position: {self.state} Goal: {self.goal_position}") 63 | 64 | 65 | register( 66 | id="fastrl/ContinuousDebugEnv-v0", 67 | entry_point="fastrl.envs.continuous_debug_env:ContinuousDebugEnv", 68 | max_episode_steps=300, 69 | ) 70 | 71 | -------------------------------------------------------------------------------- /nbs/12_Blog/99_blog.from_2023_05_to_now.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "assisted-contract", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# Python native modules\n", 11 | "# Third party libs\n", 12 | "\n", 13 | "# Local modules\n", 14 | "from fastrl.nbdev_extensions import header" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | 
"execution_count": null, 20 | "id": "40e740d6", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "#|echo: false\n", 25 | "header(\n", 26 | " \"Updating torchdata to 0.7.0 and torch 2.0\",\n", 27 | " \"Gettign back up to speed with things\",\n", 28 | " freeze=False\n", 29 | ")" 30 | ] 31 | }, 32 | { 33 | "attachments": {}, 34 | "cell_type": "markdown", 35 | "id": "fec5e005", 36 | "metadata": {}, 37 | "source": [ 38 | "First change is `traverse` to `traverse_dps`. It seems that the traverse function has been\n", 39 | "moved to torch? Not sure what the benefit of this this. I like the improvement of the \n", 40 | "utlities for understanding datapipes though.\n", 41 | "\n", 42 | "The original dataloading stuff is found at [dataloader2](https://github.com/josiahls/fastrl/blob/c2ec68b9092f616e51bc9d574873a13e06ca3a26/fastrl/data/dataloader2.py)\n", 43 | "\n", 44 | "The goal of that dataloading module is:\n", 45 | "- Allowing getting data back from the main thread\n", 46 | "- Pushing an additional stream of data for logging.\n", 47 | "\n", 48 | "My main issue is that this is maybe overly complicated. I feel like \"getting data back from the main thread\"\n", 49 | "could be done through the file system.\n", 50 | "\n", 51 | "I know that I've re-worked this a bunch of times, but I just feel like this is too complicated, and \n", 52 | "I'm hoping the new changes to pytorch make this easier." 53 | ] 54 | }, 55 | { 56 | "attachments": {}, 57 | "cell_type": "markdown", 58 | "id": "aeb4f889", 59 | "metadata": {}, 60 | "source": [ 61 | "Ok so some additional changes:\n", 62 | "\n", 63 | "- removed the map iter pipe mux and demux. RL likely wont use this, and even if I need something like this, I can just\n", 64 | "iterate through them I think.\n", 65 | "- `find_dps` would be nice to have include subclasses. You can argue there should never be\n", 66 | "subclasses though? It would maybe be better to have a `find_dps` that just took a\n", 67 | "filter function. On the other hand... you could just filter after the fact... hmm" 68 | ] 69 | }, 70 | { 71 | "attachments": {}, 72 | "cell_type": "markdown", 73 | "id": "ab16ee37", 74 | "metadata": {}, 75 | "source": [] 76 | } 77 | ], 78 | "metadata": { 79 | "kernelspec": { 80 | "display_name": "python3", 81 | "language": "python", 82 | "name": "python3" 83 | } 84 | }, 85 | "nbformat": 4, 86 | "nbformat_minor": 5 87 | } 88 | -------------------------------------------------------------------------------- /fastrl/loggers/jupyter_visualizers.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/05_Logging/09d_loggers.jupyter_visualizers.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['SimpleJupyterVideoPlayer', 'ImageCollector'] 5 | 6 | # %% ../../nbs/05_Logging/09d_loggers.jupyter_visualizers.ipynb 1 7 | # Python native modules 8 | import os 9 | from torch.multiprocessing import Queue 10 | from typing import Tuple,NamedTuple 11 | # Third party libs 12 | from fastcore.all import add_docs 13 | import matplotlib.pyplot as plt 14 | import torchdata.datapipes as dp 15 | from IPython.display import display, clear_output 16 | import torch 17 | import numpy as np 18 | # Local modules 19 | from ..core import Record 20 | from .core import LoggerBase,LogCollector,is_record 21 | # from fastrl.torch_core import * 22 | 23 | # %% ../../nbs/05_Logging/09d_loggers.jupyter_visualizers.ipynb 4 24 | class SimpleJupyterVideoPlayer(dp.iter.IterDataPipe): 25 | def __init__(self, 26 | source_datapipe=None, 27 | between_frame_wait_seconds:float=0.1 28 | ): 29 | self.source_datapipe = source_datapipe 30 | self.between_frame_wait_seconds = between_frame_wait_seconds # honor the argument instead of hardcoding 0.1 31 | 32 | def dequeue(self): 33 | while self.buffer: yield self.buffer.pop(0) 34 | 35 | 36 | def __iter__(self) -> Tuple[NamedTuple]: 37 | img = None 38 | for record in self.source_datapipe: 39 | # for o in self.dequeue(): 40 | if is_record(record): 41 | if record.value is None: continue 42 | if img is None: img = plt.imshow(record.value) 43 | img.set_data(record.value) 44 | plt.axis('off') 45 | display(plt.gcf()) 46 | clear_output(wait=True) 47 | yield record 48 | add_docs( 49 | SimpleJupyterVideoPlayer, 50 | """Displays video from a `source_datapipe` that produces `typing.NamedTuples` that contain an `image` field. 51 | This can only handle 1 env input.""", 52 | dequeue="Grabs records from the `main_queue` and attempts to display them" 53 | ) 54 | 55 | # %% ../../nbs/05_Logging/09d_loggers.jupyter_visualizers.ipynb 5 56 | class ImageCollector(dp.iter.IterDataPipe): 57 | title:str='image' 58 | 59 | def __init__(self,source_datapipe): 60 | self.source_datapipe = source_datapipe 61 | 62 | def convert_np(self,o): 63 | if isinstance(o,torch.Tensor): return o.detach().numpy() 64 | elif isinstance(o,np.ndarray): return o 65 | else: raise ValueError(f'Expects Tensor or np.ndarray not {type(o)}') 66 | 67 | def __iter__(self): 68 | # for q in self.main_buffers: q.append(Record('image',None)) 69 | yield Record(self.title,None) 70 | for steps in self.source_datapipe: 71 | if isinstance(steps,dp.DataChunk): 72 | for step in steps: 73 | yield Record(self.title,self.convert_np(step.image)) 74 | else: 75 | yield Record(self.title,self.convert_np(steps.image)) 76 | yield steps 77 | -------------------------------------------------------------------------------- /fastrl/pipes/iter/firstlast.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/01_DataPipes/01f_pipes.iter.firstlast.ipynb.
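# A quick, standalone check of the reward accumulation in
# `simple_step_first_last_merge` below. Note that the forward loop discounts
# the *earliest* reward the most: for rewards [r0, r1, r2] it produces
# r0*gamma**2 + r1*gamma + r2, whereas the conventional n-step return
# r0 + gamma*r1 + gamma**2*r2 would iterate the steps in reverse.
def _merged_reward(rewards, gamma=0.99):
    reward = rewards[0]
    for r in rewards[1:]:
        reward *= gamma
        reward += r
    return reward

assert abs(_merged_reward([1.0, 2.0, 3.0], gamma=0.5) - (1.0*0.5**2 + 2.0*0.5 + 3.0)) < 1e-8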
2 | 3 | # %% auto 0 4 | __all__ = ['simple_step_first_last_merge', 'FirstLastMerger', 'n_first_last_steps_expected'] 5 | 6 | # %% ../../../nbs/01_DataPipes/01f_pipes.iter.firstlast.ipynb 2 7 | # Python native modules 8 | import warnings 9 | from typing import Callable,List,Union 10 | # Third party libs 11 | from fastcore.all import add_docs 12 | import torchdata.datapipes as dp 13 | 14 | import torch 15 | # Local modules 16 | from ...core import StepTypes,SimpleStep 17 | 18 | # %% ../../../nbs/01_DataPipes/01f_pipes.iter.firstlast.ipynb 4 19 | def simple_step_first_last_merge(steps:List[SimpleStep],gamma): 20 | fstep,lstep = steps[0],steps[-1] 21 | 22 | reward = fstep.reward 23 | for step in steps[1:]: 24 | reward *= gamma 25 | reward += step.reward 26 | 27 | yield SimpleStep( 28 | state=fstep.state.clone().detach(), 29 | next_state=lstep.next_state.clone().detach(), 30 | action=fstep.action, 31 | episode_n=fstep.episode_n, 32 | image=fstep.image, 33 | reward=reward, 34 | raw_action=fstep.raw_action, 35 | terminated=lstep.terminated, 36 | truncated=lstep.truncated, 37 | total_reward=lstep.total_reward, 38 | env_id=lstep.env_id, 39 | proc_id=lstep.proc_id, 40 | step_n=lstep.step_n, 41 | batch_size=[] 42 | ) 43 | 44 | class FirstLastMerger(dp.iter.IterDataPipe): 45 | def __init__(self, 46 | source_datapipe, 47 | merge_behavior:Callable[[List[Union[StepTypes.types]],float],Union[StepTypes.types]]=simple_step_first_last_merge, 48 | gamma:float=0.99 49 | ): 50 | self.source_datapipe = source_datapipe 51 | self.gamma = gamma 52 | self.merge_behavior = merge_behavior 53 | 54 | def __iter__(self) -> StepTypes.types: 55 | self.env_buffer = {} 56 | for steps in self.source_datapipe: 57 | if not isinstance(steps,(list,tuple)): 58 | raise ValueError(f'Expected {self.source_datapipe} to return a list/tuple of steps, however got {type(steps)}') 59 | 60 | if len(steps)==1: 61 | yield steps[0] 62 | continue 63 | 64 | yield from self.merge_behavior(steps,gamma=self.gamma) 65 | 66 | add_docs( 67 | FirstLastMerger, 68 | """Takes multiple steps and converts them into a single step consisting of properties 69 | from the first and last steps. Reward is recalculated to factor in the multiple steps.""", 70 | ) 71 | 72 | # %% ../../../nbs/01_DataPipes/01f_pipes.iter.firstlast.ipynb 15 73 | def n_first_last_steps_expected( 74 | default_steps:int, # The number of steps the episode would run without n_steps 75 | ): 76 | return default_steps 77 | 78 | n_first_last_steps_expected.__doc__=r""" 79 | This function doesn't do much for now: `FirstLastMerger` essentially undoes the extra steps that `nsteps` produces. 80 | """ 81 | -------------------------------------------------------------------------------- /fastrl/loggers/vscode_visualizers.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb.
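# `FirstLastMerger` (above) delegates all step-specific work to
# `merge_behavior`, so a toy merge function over plain floats is enough to
# exercise the pipe without constructing full `SimpleStep` objects -- a
# minimal sketch for illustration, not how it is wired up in training:
import torchdata.datapipes as dp
from fastrl.pipes.iter.firstlast import FirstLastMerger

def _toy_merge(steps, gamma):
    reward = steps[0]
    for s in steps[1:]:
        reward = reward * gamma + s  # same discounting loop as the real merger
    yield reward

pipe = dp.iter.IterableWrapper([[1.0, 1.0], [2.0]])  # one 2-step tuple, one single step
pipe = FirstLastMerger(pipe, merge_behavior=_toy_merge, gamma=0.5)
assert list(pipe) == [1.5, 2.0]  # single-step lists are passed through untouched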
2 | 3 | # %% auto 0 4 | __all__ = ['SimpleVSCodeVideoPlayer', 'VSCodeDataPipe'] 5 | 6 | # %% ../../nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb 1 7 | # Python native modules 8 | import os 9 | import io 10 | from typing import Tuple,Any,Optional,NamedTuple,Iterable 11 | # Third party libs 12 | import imageio 13 | from fastcore.all import add_docs,ifnone 14 | import matplotlib.pyplot as plt 15 | import torchdata.datapipes as dp 16 | from torchdata.datapipes import functional_datapipe 17 | from IPython.core.display import Video,Image 18 | from torchdata.dataloader2 import DataLoader2,MultiProcessingReadingService 19 | # Local modules 20 | from .core import LoggerBase,is_record 21 | from ..pipes.core import DataPipeAugmentationFn,apply_dp_augmentation_fns 22 | from .jupyter_visualizers import ImageCollector 23 | 24 | # %% ../../nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb 5 25 | class SimpleVSCodeVideoPlayer(dp.iter.IterDataPipe): 26 | def __init__(self, 27 | source_datapipe=None, 28 | skip_frames:int=1, 29 | fps:int=30, 30 | downsize_res=(2,2) 31 | ): 32 | self.source_datapipe = source_datapipe 33 | self.fps = fps 34 | self.skip_frames = skip_frames 35 | self.downsize_res = downsize_res 36 | self._bytes_object = None 37 | self.frames = [] 38 | 39 | def reset(self): 40 | super().reset() 41 | self._bytes_object = io.BytesIO() 42 | 43 | def show(self,start:int=0,end:Optional[int]=None,step:int=1): 44 | print(f'Creating gif from {len(self.frames)} frames') 45 | imageio.mimwrite( 46 | self._bytes_object, 47 | self.frames[start:end:step], 48 | format='GIF', 49 | duration=self.fps 50 | ) 51 | return Image(self._bytes_object.getvalue()) 52 | 53 | def __iter__(self) -> Tuple[NamedTuple]: 54 | n_frame = 0 55 | for record in self.source_datapipe: 56 | # for o in self.dequeue(): 57 | if is_record(record) and record.name=='image': 58 | if record.value is None: continue 59 | n_frame += 1 60 | if n_frame%self.skip_frames!=0: continue 61 | self.frames.append( 62 | record.value[::self.downsize_res[0],::self.downsize_res[1]] 63 | ) 64 | yield record 65 | add_docs( 66 | SimpleVSCodeVideoPlayer, 67 | """Displays video from a `source_datapipe` that produces `typing.NamedTuples` that contain an `image` field. 68 | This only can handle 1 env input.""", 69 | show="In order to show the video, this must be called in a notebook cell.", 70 | reset="Will reset the bytes object that is used to store file data." 71 | ) 72 | 73 | # %% ../../nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb 6 74 | @functional_datapipe('visualize_vscode') 75 | class VSCodeDataPipe(dp.iter.IterDataPipe): 76 | def __new__(self,source:Iterable): 77 | "This is the function that is actually run by `DataBlock`" 78 | pipe = ImageCollector(source).dump_records() 79 | pipe = SimpleVSCodeVideoPlayer(pipe) 80 | return pipe 81 | 82 | -------------------------------------------------------------------------------- /fastrl/agents/dqn/rainbow.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/07_Agents/01_Discrete/12r_agents.dqn.rainbow.ipynb. 
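# A hedged sketch of wiring the visualizer above into a pipeline. The
# `.visualize_vscode()` functional form is registered by the
# `@functional_datapipe('visualize_vscode')` decorator, and since
# `VSCodeDataPipe.__new__` returns the `SimpleVSCodeVideoPlayer` head, the
# resulting pipe is also the player (`step_pipeline` is a placeholder for any
# pipe yielding steps with an `image` field):
#
# pipe = step_pipeline.visualize_vscode()  # == VSCodeDataPipe(step_pipeline)
# for _ in pipe:
#     pass                                 # drain the pipe, collecting frames
# pipe.show(step=2)                        # in a notebook cell: render every 2nd frame as a GIF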
2 | 3 | # %% auto 0 4 | __all__ = ['DQNRainbowLearner'] 5 | 6 | # %% ../../../nbs/07_Agents/01_Discrete/12r_agents.dqn.rainbow.ipynb 2 7 | # Python native modules 8 | from copy import deepcopy 9 | from typing import Optional,Callable,Tuple 10 | # Third party libs 11 | import torchdata.datapipes as dp 12 | from torchdata.dataloader2.graph import traverse_dps,DataPipe 13 | import torch 14 | from torch import nn,optim 15 | from fastcore.all import store_attr,ifnone 16 | import numpy as np 17 | import torch.nn.functional as F 18 | # Local modulesf 19 | from ...torch_core import default_device,to_detach,evaluating 20 | from ...pipes.core import find_dp 21 | from ..core import StepFieldSelector,SimpleModelRunner,NumpyConverter 22 | from ..discrete import EpsilonCollector,PyPrimativeConverter,ArgMaxer,EpsilonSelector 23 | from ...memory.experience_replay import ExperienceReplay 24 | from ...loggers.core import BatchCollector,EpochCollector 25 | from ...learner.core import LearnerBase,LearnerHead 26 | from ..core import AgentHead,AgentBase 27 | from ...loggers.vscode_visualizers import VSCodeDataPipe 28 | from ...loggers.core import ProgressBarLogger 29 | from fastrl.agents.dqn.basic import ( 30 | LossCollector, 31 | RollingTerminatedRewardCollector, 32 | EpisodeCollector, 33 | StepBatcher, 34 | TargetCalc, 35 | LossCalc, 36 | ModelLearnCalc, 37 | DQN, 38 | DQNAgent 39 | ) 40 | from fastrl.agents.dqn.target import ( 41 | TargetModelUpdater, 42 | TargetModelQCalc 43 | ) 44 | from .dueling import DuelingHead 45 | from fastrl.agents.dqn.categorical import ( 46 | CategoricalDQNAgent, 47 | CategoricalDQN, 48 | CategoricalTargetQCalc, 49 | PartialCrossEntropy 50 | ) 51 | 52 | # %% ../../../nbs/07_Agents/01_Discrete/12r_agents.dqn.rainbow.ipynb 4 53 | def DQNRainbowLearner( 54 | model, 55 | dls, 56 | do_logging:bool=True, 57 | loss_func=PartialCrossEntropy, 58 | opt=optim.AdamW, 59 | lr=0.005, 60 | bs=128, 61 | max_sz=10000, 62 | nsteps=1, 63 | device=None, 64 | batches=None, 65 | target_sync=300, 66 | # Use DoubleDQN target strategy 67 | double_dqn_strategy=True 68 | ) -> LearnerHead: 69 | learner = LearnerBase(model,dls=dls[0]) 70 | learner = BatchCollector(learner,batches=batches) 71 | learner = EpochCollector(learner) 72 | if do_logging: 73 | learner = learner.dump_records() 74 | learner = ProgressBarLogger(learner) 75 | learner = RollingTerminatedRewardCollector(learner) 76 | learner = EpisodeCollector(learner).catch_records() 77 | learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz) 78 | learner = StepBatcher(learner,device=device) 79 | learner = CategoricalTargetQCalc(learner,nsteps=nsteps,double_dqn_strategy=double_dqn_strategy).to(device=device) 80 | learner = LossCalc(learner,loss_func=loss_func) 81 | learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr)) 82 | learner = TargetModelUpdater(learner,target_sync=target_sync) 83 | if do_logging: 84 | learner = LossCollector(learner).catch_records() 85 | 86 | if len(dls)==2: 87 | val_learner = LearnerBase(model,dls[1]).visualize_vscode() 88 | val_learner = BatchCollector(val_learner,batches=batches) 89 | val_learner = EpochCollector(val_learner).catch_records(drop=True) 90 | return LearnerHead((learner,val_learner)) 91 | else: 92 | return LearnerHead(learner) 93 | -------------------------------------------------------------------------------- /fastrl/pipes/core.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! 
File to edit: ../../nbs/01_DataPipes/01a_pipes.core.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['find_dps', 'find_dp', 'DataPipeAugmentationFn', 'apply_dp_augmentation_fns'] 5 | 6 | # %% ../../nbs/01_DataPipes/01a_pipes.core.ipynb 2 7 | # Python native modules 8 | import os 9 | import logging 10 | import inspect 11 | from typing import Callable,Union,TypeVar,Optional,Type,List,Tuple 12 | # Third party libs 13 | import torchdata.datapipes as dp 14 | from torchdata.datapipes import functional_datapipe 15 | from torchdata.datapipes.iter import IterDataPipe 16 | from torchdata.datapipes.map import MapDataPipe 17 | from torchdata.dataloader2.graph import DataPipe, DataPipeGraph,find_dps,traverse_dps,list_dps 18 | # Local modules 19 | 20 | # %% ../../nbs/01_DataPipes/01a_pipes.core.ipynb 5 21 | def find_dps( 22 | graph: DataPipeGraph, 23 | dp_type: Type[DataPipe], 24 | include_subclasses:bool=False 25 | ) -> List[DataPipe]: 26 | r""" 27 | Given the graph of DataPipe generated by the ``traverse_dps`` function, return DataPipe 28 | instances with the provided DataPipe type. 29 | """ 30 | dps: List[DataPipe] = [] 31 | 32 | def helper(g) -> None: # pyre-ignore 33 | for _, (dp, src_graph) in g.items(): 34 | if include_subclasses and issubclass(type(dp),dp_type): 35 | dps.append(dp) 36 | elif type(dp) is dp_type: # Please do not use `isinstance`; there is a bug. 37 | dps.append(dp) 38 | helper(src_graph) 39 | 40 | helper(graph) 41 | 42 | return dps 43 | 44 | # %% ../../nbs/01_DataPipes/01a_pipes.core.ipynb 6 45 | def find_dp( 46 | # A graph created from the `traverse_dps` function 47 | graph: DataPipeGraph, 48 | # The DataPipe type to find in the graph 49 | dp_type: Type[DataPipe], 50 | include_subclasses:bool=False 51 | ) -> DataPipe: 52 | pipes = find_dps(graph,dp_type,include_subclasses) 53 | if len(pipes)==1: return pipes[0] 54 | elif len(pipes)>1: 55 | found_ids = set([id(pipe) for pipe in pipes]) 56 | if len(found_ids)>1: 57 | logging.warning("""There are %s pipes of type %s. If this is intended, 58 | please use `find_dps` directly. Returning first instance.""",len(pipes),dp_type) 59 | return pipes[0] 60 | else: 61 | raise LookupError(f'Unable to find {dp_type} starting at {graph}') 62 | 63 | find_dp.__doc__ = "Returns a single `DataPipe` as opposed to `find_dps`.\n"+find_dps.__doc__ 64 | 65 | # %% ../../nbs/01_DataPipes/01a_pipes.core.ipynb 19 66 | class DataPipeAugmentationFn(Callable[[DataPipe],Optional[DataPipe]]):... 67 | 68 | DataPipeAugmentationFn.__doc__ = f"""`DataPipeAugmentationFn` must take in a `DataPipe` and either output a `DataPipe` or `None`. This function should perform some operation on the graph 69 | such as replacing, removing, inserting `DataPipe`s and `DataPipeGraph`s. Below is an example that replaces a `dp.iter.Batcher` datapipe with a `dp.iter.Filter`""" 70 | 71 | # %% ../../nbs/01_DataPipes/01a_pipes.core.ipynb 23 72 | def apply_dp_augmentation_fns( 73 | pipe:DataPipe, 74 | dp_augmentation_fns:Optional[Tuple[DataPipeAugmentationFn]], 75 | debug:bool=False 76 | ) -> DataPipe: 77 | "Given a `pipe`, run `dp_augmentation_fns` over the pipeline" 78 | if dp_augmentation_fns is None: return pipe 79 | for fn in dp_augmentation_fns: 80 | if debug: print(f'Running fn: {fn} given current pipe: \n\t{traverse_dps(pipe)}') 81 | result = fn(pipe) 82 | if result is not None: pipe = result 83 | return pipe 84 | -------------------------------------------------------------------------------- /fastrl/pipes/map/mux.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT!
File to edit: ../../../nbs/01_DataPipes/01c_pipes.map.mux.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['T_co', 'MultiplexerMapDataPipe'] 5 | 6 | # %% ../../../nbs/01_DataPipes/01c_pipes.map.mux.ipynb 3 7 | # Python native modules 8 | import os 9 | from inspect import isfunction,ismethod 10 | from itertools import chain, zip_longest 11 | from typing import Callable, Dict, Iterable, Optional, TypeVar 12 | # Third party libs 13 | from fastcore.all import * 14 | from ...torch_core import * 15 | # from torch.utils.data.dataloader import DataLoader as OrgDataLoader 16 | import torchdata.datapipes as dp 17 | from torchdata.datapipes import functional_datapipe 18 | from torchdata.dataloader2.graph import find_dps,DataPipeGraph,Type,DataPipe,MapDataPipe,IterDataPipe 19 | from torchdata.dataloader2.dataloader2 import DataLoader2 20 | # Local modules 21 | 22 | # %% ../../../nbs/01_DataPipes/01c_pipes.map.mux.ipynb 5 23 | T_co = TypeVar("T_co", covariant=True) 24 | 25 | @functional_datapipe("mux") 26 | class MultiplexerMapDataPipe(MapDataPipe[T_co]): 27 | def __init__(self, *datapipes, dp_index_map: Optional[Dict[MapDataPipe, Iterable]] = None): 28 | self.datapipes = datapipes 29 | self.dp_index_map = dp_index_map if dp_index_map else {} 30 | self.length: Optional[int] = None 31 | self.index_map = {} 32 | # Create a generator that yields (index, (dp_num, old_index)) in sequential order. 33 | indices = (self._add_dp_num(i, dp) for i, dp in enumerate(datapipes)) 34 | dp_id_and_key_tuples = chain.from_iterable(zip_longest(*indices)) 35 | self.key_gen = enumerate(e for e in dp_id_and_key_tuples if e is not None) 36 | 37 | def _add_dp_num(self, dp_num: int, dp: MapDataPipe): 38 | # Assume 0-index for all DataPipes unless alternate indices are defined in `self.dp_index_map` 39 | dp_indices = self.dp_index_map[dp] if dp in self.dp_index_map else range(len(dp)) 40 | for idx in dp_indices: 41 | yield dp_num, idx 42 | 43 | def __getitem__(self, index): 44 | if 0 <= index < len(self): 45 | if index in self.index_map: 46 | dp_num, old_key = self.index_map[index] 47 | else: 48 | curr_key = -1 49 | while curr_key < index: 50 | curr_key, dp_num_key_tuple = next(self.key_gen) 51 | dp_num, old_key = dp_num_key_tuple 52 | self.index_map[index] = dp_num, old_key 53 | try: 54 | return self.datapipes[dp_num][old_key] 55 | except KeyError: 56 | raise RuntimeError( 57 | f"Incorrect key is given to MapDataPipe {dp_num} in Multiplexer, likely because " 58 | f"that DataPipe is not 0-indexed but alternate indices are not given." 59 | ) 60 | raise RuntimeError(f"Index {index} is out of bound for Multiplexer.") 61 | 62 | def __iter__(self): 63 | for i in range(len(self)): 64 | yield self[i] 65 | 66 | def __len__(self): 67 | if self.length is None: 68 | self.length = 0 69 | for dp in self.datapipes: 70 | self.length += len(dp) 71 | return self.length 72 | 73 | MultiplexerMapDataPipe.__doc__ = """Yields one element at a time from each of the input Map DataPipes (functional name: ``mux``). As in, 74 | one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration, 75 | and so on. Unlike the iter-style ``mux``, it does not stop at the shortest input: once a shorter DataPipe is exhausted, the remaining elements come from the longer ones, so every element is yielded exactly once. 76 | """ 77 | -------------------------------------------------------------------------------- /fastrl/agents/dqn/double.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/07_Agents/01_Discrete/12m_agents.dqn.double.ipynb.
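# A runnable sketch of the map-style `mux` above, interleaving two 0-indexed
# SequenceWrapper map datapipes (so no `dp_index_map` is needed); the import
# path assumes fastrl's module layout:
import torchdata.datapipes as dp
from fastrl.pipes.map.mux import MultiplexerMapDataPipe

a = dp.map.SequenceWrapper(['a0', 'a1', 'a2'])
b = dp.map.SequenceWrapper(['b0', 'b1'])
muxed = MultiplexerMapDataPipe(a, b)
assert len(muxed) == 5                                # lengths are summed
assert list(muxed) == ['a0', 'b0', 'a1', 'b1', 'a2']  # round-robin, then the remainder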
2 | 3 | # %% auto 0 4 | __all__ = ['DoubleQCalc', 'DoubleDQNLearner'] 5 | 6 | # %% ../../../nbs/07_Agents/01_Discrete/12m_agents.dqn.double.ipynb 2 7 | # Python native modules 8 | from copy import deepcopy 9 | from typing import Optional,Callable,Tuple 10 | # Third party libs 11 | import torchdata.datapipes as dp 12 | from torchdata.dataloader2.graph import traverse_dps,DataPipe 13 | import torch 14 | from torch import nn,optim 15 | # Local modulesf 16 | from ...pipes.core import find_dp 17 | from ...memory.experience_replay import ExperienceReplay 18 | from ...loggers.core import BatchCollector,EpochCollector 19 | from ...learner.core import LearnerBase,LearnerHead 20 | from ...loggers.vscode_visualizers import VSCodeDataPipe 21 | from ...loggers.core import ProgressBarLogger 22 | from fastrl.agents.dqn.basic import ( 23 | LossCollector, 24 | RollingTerminatedRewardCollector, 25 | EpisodeCollector, 26 | StepBatcher, 27 | TargetCalc, 28 | LossCalc, 29 | ModelLearnCalc, 30 | DQN, 31 | DQNAgent 32 | ) 33 | from fastrl.agents.dqn.target import ( 34 | TargetModelUpdater, 35 | TargetModelQCalc 36 | ) 37 | 38 | # %% ../../../nbs/07_Agents/01_Discrete/12m_agents.dqn.double.ipynb 5 39 | class DoubleQCalc(dp.iter.IterDataPipe): 40 | def __init__(self,source_datapipe): 41 | self.source_datapipe = source_datapipe 42 | 43 | def __iter__(self): 44 | self.learner = find_dp(traverse_dps(self),LearnerBase) 45 | for batch in self.source_datapipe: 46 | self.learner.done_mask = batch.terminated.reshape(-1,) 47 | with torch.no_grad(): 48 | chosen_actions = self.learner.model(batch.next_state).argmax(dim=1).reshape(-1,1) 49 | self.learner.next_q = self.learner.target_model(batch.next_state).gather(1,chosen_actions) 50 | self.learner.next_q[self.learner.done_mask] = 0 51 | yield batch 52 | 53 | # %% ../../../nbs/07_Agents/01_Discrete/12m_agents.dqn.double.ipynb 6 54 | def DoubleDQNLearner( 55 | model, 56 | dls, 57 | do_logging:bool=True, 58 | loss_func=nn.MSELoss(), 59 | opt=optim.AdamW, 60 | lr=0.005, 61 | bs=128, 62 | max_sz=10000, 63 | nsteps=1, 64 | device=None, 65 | batches=None, 66 | target_sync=300 67 | ) -> LearnerHead: 68 | learner = LearnerBase(model,dls=dls[0]) 69 | learner = BatchCollector(learner,batches=batches) 70 | learner = EpochCollector(learner) 71 | if do_logging: 72 | learner = learner.dump_records() 73 | learner = ProgressBarLogger(learner) 74 | learner = RollingTerminatedRewardCollector(learner) 75 | learner = EpisodeCollector(learner).catch_records() 76 | learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz) 77 | learner = StepBatcher(learner,device=device) 78 | learner = DoubleQCalc(learner) 79 | learner = TargetCalc(learner,nsteps=nsteps) 80 | learner = LossCalc(learner,loss_func=loss_func) 81 | learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr)) 82 | learner = TargetModelUpdater(learner,target_sync=target_sync) 83 | if do_logging: 84 | learner = LossCollector(learner).catch_records() 85 | 86 | if len(dls)==2: 87 | val_learner = LearnerBase(model,dls[1]).visualize_vscode() 88 | val_learner = BatchCollector(val_learner,batches=batches) 89 | val_learner = EpochCollector(val_learner).catch_records(drop=True) 90 | return LearnerHead((learner,val_learner)) 91 | else: 92 | return LearnerHead(learner) 93 | -------------------------------------------------------------------------------- /fastrl/pipes/iter/nskip.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! 
File to edit: ../../../nbs/01_DataPipes/01d_pipes.iter.nskip.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['NSkipper', 'n_skips_expected'] 5 | 6 | # %% ../../../nbs/01_DataPipes/01d_pipes.iter.nskip.ipynb 2 7 | # Python native modules 8 | import os 9 | import warnings 10 | from typing import Callable, Dict, Iterable, Optional, TypeVar, Type,Union 11 | # Third party libs 12 | import torchdata.datapipes as dp 13 | from torchdata.datapipes.iter import IterDataPipe 14 | from fastcore.all import add_docs 15 | # Local modules 16 | from ...core import StepTypes 17 | from .nstep import NStepper 18 | 19 | # %% ../../../nbs/01_DataPipes/01d_pipes.iter.nskip.ipynb 4 20 | _msg = """ 21 | NSkipper should not go after NStepper. Please make the order: 22 | 23 | ```python 24 | ... 25 | pipe = NSkipper(pipe,n=3) 26 | pipe = NStepper(pipe,n=3) 27 | ... 28 | ``` 29 | 30 | """ 31 | 32 | class NSkipper(IterDataPipe[Union[StepTypes.types]]): 33 | def __init__( 34 | self, 35 | # The datapipe we are extracting from must produce `StepType` 36 | source_datapipe:IterDataPipe[Union[StepTypes.types]], 37 | # Number of steps to skip per env. Default will not skip at all. 38 | n:int=1 39 | ) -> None: 40 | if isinstance(source_datapipe,NStepper): raise Exception(_msg) 41 | self.source_datapipe = source_datapipe 42 | self.n = n 43 | self.env_buffer = {} 44 | 45 | def __iter__(self) -> StepTypes.types: 46 | self.env_buffer = {} 47 | for step in self.source_datapipe: 48 | if not issubclass(step.__class__,StepTypes.types): 49 | raise Exception(f'Expected {StepTypes.types} object got {type(step)}\n{step}') 50 | 51 | env_id,terminated,step_n = int(step.env_id),bool(step.terminated),int(step.step_n) 52 | 53 | if env_id in self.env_buffer: self.env_buffer[env_id] += 1 54 | else: self.env_buffer[env_id] = 1 55 | 56 | if self.env_buffer[env_id]%self.n==0: yield step 57 | elif terminated: yield step 58 | elif step_n==1: yield step 59 | 60 | if terminated: self.env_buffer[env_id] = 1 61 | 62 | add_docs( 63 | NSkipper, 64 | """Accepts a `source_datapipe` or iterable whose `next()` produces a `StepType`, and skips N steps for 65 | individual environments *while always producing 1st steps and terminated steps.* 66 | """ 67 | ) 68 | 69 | # %% ../../../nbs/01_DataPipes/01d_pipes.iter.nskip.ipynb 17 70 | def n_skips_expected( 71 | default_steps:int, # The number of steps the episode would run without n_skips 72 | n:int # The n-skip value that we are planning to use 73 | ): 74 | if n==1: return default_steps # All the steps will be retained including the 1st step. No offset needed 75 | # If n goes into default_steps evenly, then the final "done" will be technically an "extra" step 76 | elif default_steps%n==0: return (default_steps // n) + 1 # first step will be kept 77 | else: 78 | # If the steps don't divide evenly then it will attempt to skip the done step, but of course we don't 79 | # let that happen 80 | return (default_steps // n) + 2 # first step and done will be kept 81 | 82 | n_skips_expected.__doc__=r""" 83 | Produces the expected number of steps, assuming a fully deterministic episode based on `default_steps` and `n`. 84 | 85 | Mainly used for testing.
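For instance, with `default_steps=18`: `n=3` gives `18 // 3 + 1 = 7` since 18
divides evenly, while `n=4` gives `18 // 4 + 2 = 6` because the terminal step
would otherwise be skipped and is kept as an extra.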
86 | 87 | Given `n=2` and 1 env, knowing that `CartPole-v1` with `seed=0` will always run 18 steps, the total 88 | steps will be: 89 | 90 | $$ 91 | \lfloor 18 / n \rfloor + 1 \quad \text{(the 1st and last steps are always kept)} 92 | $$ 93 | """ 94 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pkg_resources import parse_version 2 | from configparser import ConfigParser 3 | import subprocess 4 | from setuptools.command.develop import develop 5 | from setuptools.command.install import install 6 | import setuptools 7 | import os 8 | assert parse_version(setuptools.__version__)>=parse_version('36.2') 9 | 10 | # note: all settings are in settings.ini; edit there, not here 11 | config = ConfigParser(delimiters=['=']) 12 | config.read('settings.ini') 13 | cfg = config['DEFAULT'] 14 | 15 | cfg_keys = 'version description keywords author author_email'.split() 16 | expected = cfg_keys + "lib_name user branch license status min_python audience language".split() 17 | for o in expected: assert o in cfg, "missing expected setting: {}".format(o) 18 | setup_cfg = {o:cfg[o] for o in cfg_keys} 19 | 20 | licenses = { 21 | 'apache2': ('Apache Software License 2.0','OSI Approved :: Apache Software License'), 22 | } 23 | statuses = [ '1 - Planning', '2 - Pre-Alpha', '3 - Alpha', 24 | '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive' ] 25 | py_versions = '3.7 3.8 3.9 3.10'.split() 26 | 27 | requirements = ['pip', 'packaging'] 28 | if cfg.get('requirements'): requirements += cfg.get('requirements','').split() 29 | if cfg.get('pip_requirements'): requirements += cfg.get('pip_requirements','').split() 30 | dev_requirements = (cfg.get('dev_requirements') or '').split() 31 | 32 | lic = licenses[cfg['license']] 33 | min_python = cfg['min_python'] 34 | 35 | # Define the repository and branch or commit you want to install from 36 | TORCHDATA_GIT_REPO = "https://github.com/josiahls/data.git" 37 | TORCHDATA_COMMIT = "main" # or replace with a specific commit hash 38 | 39 | class CustomInstall(install): 40 | def run(self): 41 | # Ensure that torchdata is cloned and installed before proceeding 42 | print('Cloning torchdata') 43 | subprocess.check_call(["git", "clone", TORCHDATA_GIT_REPO]) 44 | print('Installing torchdata') 45 | subprocess.check_call(["pip", "install","-vvv", "./data"]) 46 | # Call the standard install. 47 | install.run(self) 48 | 49 | class CustomDevelop(develop): 50 | def run(self): 51 | # Ensure that torchdata is cloned, and install it in editable mode only if it isn't importable 52 | if not os.path.exists('data'): 53 | print('Cloning torchdata') 54 | subprocess.check_call(["git", "clone", TORCHDATA_GIT_REPO]) 55 | try: 56 | import torchdata 57 | except ImportError: 58 | print('Installing torchdata') 59 | subprocess.check_call(["pip", "install","-vvv", "-e", "./data"]) 60 | # Call the standard develop.
61 | develop.run(self) 62 | 63 | setuptools.setup( 64 | name = cfg['lib_name'], 65 | license = lic[0], 66 | classifiers = [ 67 | 'Development Status :: ' + statuses[int(cfg['status'])], 68 | 'Intended Audience :: ' + cfg['audience'].title(), 69 | 'License :: ' + lic[1], 70 | 'Natural Language :: ' + cfg['language'].title(), 71 | ] + ['Programming Language :: Python :: '+o for o in py_versions[py_versions.index(min_python):]], 72 | url = cfg['git_url'], 73 | packages = setuptools.find_packages(), 74 | include_package_data = True, 75 | install_requires = requirements, 76 | extras_require={'dev': dev_requirements }, 77 | dependency_links = cfg.get('dep_links','').split(), 78 | python_requires = '>=' + cfg['min_python'], 79 | long_description = open('README.md').read(), 80 | long_description_content_type = 'text/markdown', 81 | zip_safe = False, 82 | entry_points = { 'console_scripts': cfg.get('console_scripts','').split() }, 83 | cmdclass={ 84 | 'install': CustomInstall, 85 | 'develop': CustomDevelop, 86 | }, 87 | **setup_cfg) 88 | 89 | -------------------------------------------------------------------------------- /nbs/05_Logging/09e_loggers.tensorboard.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "durable-dialogue", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|hide\n", 11 | "from fastrl.test_utils import initialize_notebook\n", 12 | "initialize_notebook()" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "assisted-contract", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "#|export\n", 23 | "# Python native modules\n", 24 | "import os\n", 25 | "from pathlib import Path\n", 26 | "# Third party libs\n", 27 | "import torchdata.datapipes as dp\n", 28 | "# Local modules" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "offshore-stuart", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "#|default_exp loggers.tensorboard" 39 | ] 40 | }, 41 | { 42 | "attachments": {}, 43 | "cell_type": "markdown", 44 | "id": "lesser-innocent", 45 | "metadata": {}, 46 | "source": [ 47 | "# Tensorboard \n", 48 | "> Iterable pipes for exporting to tensorboard" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "id": "fea383b7-e004-4ce1-8007-7b6d29248677", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "#|export\n", 59 | "def run_tensorboard(\n", 60 | " port:int=6006, # The port to run tensorboard on/connect on\n", 61 | " start_tag:str=None, # Starting regex e.g.: experience_replay/1\n", 62 | " samples_per_plugin:str=None, # Sampling freq such as images=0 (keep all)\n", 63 | " extra_args:str=None, # Any additional arguments in the `--arg value` format\n", 64 | " rm_glob:bool=None # Remove old logs via a parttern e.g.: '*' will remove all files: runs/* \n", 65 | " ):\n", 66 | " if rm_glob is not None:\n", 67 | " for p in Path('runs').glob(rm_glob): p.delete()\n", 68 | " import socket\n", 69 | " from tensorboard import notebook\n", 70 | " a_socket=socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n", 71 | " cmd=None\n", 72 | " if not a_socket.connect_ex(('127.0.0.1',6006)):\n", 73 | " notebook.display(port=port,height=1000)\n", 74 | " else:\n", 75 | " cmd=f'--logdir runs --port {port} --host=0.0.0.0'\n", 76 | " if samples_per_plugin is not None: cmd+=f' --samples_per_plugin {samples_per_plugin}'\n", 77 | " if 
start_tag is not None: cmd+=f' --tag {start_tag}'\n", 78 | " if extra_args is not None: cmd+=f' {extra_args}'\n", 79 | " notebook.start(cmd)\n", 80 | " return cmd" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "id": "530f998c-250a-4005-8abc-65ca89d8ae7d", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "#|hide\n", 91 | "SHOW_TENSOR_BOARD=False\n", 92 | "if not os.environ.get(\"IN_TEST\", None) and SHOW_TENSOR_BOARD:\n", 93 | " run_tensorboard()" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "id": "current-pilot", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "#|hide\n", 104 | "#|eval: false\n", 105 | "!nbdev_export" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "id": "e0d82468-a2bf-4bfd-9ac7-e56db49b8476", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [] 115 | } 116 | ], 117 | "metadata": { 118 | "kernelspec": { 119 | "display_name": "python3", 120 | "language": "python", 121 | "name": "python3" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 5 126 | } 127 | -------------------------------------------------------------------------------- /.github/workflows/fastrl-docker.yml: -------------------------------------------------------------------------------- 1 | name: Build fastrl images 2 | on: 3 | schedule: 4 | - cron: '1 6 * * *' 5 | workflow_dispatch: #allows you to trigger manually 6 | push: 7 | branches: 8 | - main 9 | - update_nbdev_docs 10 | - refactor/advantage-buffer 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | max-parallel: 1 17 | matrix: 18 | build_type: [dev] 19 | # build_type: [prod, dev] 20 | steps: 21 | - name: Maximize build space 22 | uses: easimon/maximize-build-space@master 23 | with: 24 | root-reserve-mb: 35000 25 | swap-size-mb: 1024 26 | remove-dotnet: 'true' 27 | remove-android: 'true' 28 | - name: Copy This Repository Contents 29 | uses: actions/checkout@v2 30 | # with: 31 | # submodules: recursive 32 | - name: Build 33 | run: | 34 | echo "Free space:" 35 | df -h 36 | 37 | - uses: actions/setup-python@v2 38 | with: 39 | python-version: '3.8' 40 | architecture: 'x64' 41 | 42 | - name: Copy settings.ini file 43 | run: | 44 | wget https://raw.githubusercontent.com/josiahls/fastrl/master/settings.ini 45 | - name: get version from settings.ini and create image name 46 | id: get_variables 47 | run: | 48 | from configparser import ConfigParser 49 | import os 50 | from pathlib import Path 51 | config = ConfigParser() 52 | settings = Path('settings.ini') 53 | assert settings.exists(), 'Not able to read or download settings.ini file.' 
54 | config.read(settings) 55 | cfg = config['DEFAULT'] 56 | print(f"::set-output name=version::{cfg['version']}") 57 | btype = os.getenv('BUILD_TYPE') 58 | assert btype in ['prod', 'dev'], "BUILD_TYPE must be either prod, dev or course" 59 | if btype != 'prod': 60 | image_name = f'josiahls/fastrl-{btype}' 61 | else: 62 | image_name = 'josiahls/fastrl' 63 | print(f"::set-output name=image_name::{image_name}") 64 | shell: python 65 | env: 66 | BUILD_TYPE: ${{ matrix.build_type }} 67 | 68 | # - name: Cache Docker layers 69 | # if: always() 70 | # uses: actions/cache@v3 71 | # with: 72 | # path: /tmp/.buildx-cache 73 | # key: ${{ runner.os }}-buildx-${{ github.sha }} 74 | # restore-keys: | 75 | # ${{ runner.os }}-buildx- 76 | 77 | # - name: Set up Docker Buildx 78 | # uses: docker/setup-buildx-action@v3 79 | 80 | - name: build and tag container 81 | run: | 82 | export DOCKER_BUILDKIT=1 83 | # We need to clear the previous docker images 84 | docker system prune -fa 85 | docker pull ${IMAGE_NAME}:latest || true 86 | # docker build --build-arg BUILD=${BUILD_TYPE} \ 87 | # docker buildx create --use 88 | # docker buildx build --cache-from=type=local,src=/tmp/.buildx-cache --cache-to=type=local,dest=/tmp/.buildx-cache --build-arg BUILD=${BUILD_TYPE} \ 89 | # -t ${IMAGE_NAME}:latest \ 90 | # -t ${IMAGE_NAME}:${VERSION} \ 91 | # -t ${IMAGE_NAME}:$(date +%F) \ 92 | # -f fastrl.Dockerfile . 93 | docker buildx build --build-arg BUILD=${BUILD_TYPE} \ 94 | -t ${IMAGE_NAME}:latest \ 95 | -t ${IMAGE_NAME}:${VERSION} \ 96 | -t ${IMAGE_NAME}:$(date +%F) \ 97 | -f fastrl.Dockerfile . 98 | env: 99 | VERSION: ${{ steps.get_variables.outputs.version }} 100 | IMAGE_NAME: ${{ steps.get_variables.outputs.image_name }} 101 | BUILD_TYPE: ${{ matrix.build_type }} 102 | 103 | - name: push images 104 | run: | 105 | echo ${PASSWORD} | docker login -u $USERNAME --password-stdin 106 | docker push ${IMAGE_NAME} 107 | env: 108 | USERNAME: ${{ secrets.DOCKER_USERNAME }} 109 | PASSWORD: ${{ secrets.DOCKER_PASSWORD }} 110 | IMAGE_NAME: ${{ steps.get_variables.outputs.image_name }} 111 | -------------------------------------------------------------------------------- /nbs/README.md: -------------------------------------------------------------------------------- 1 | fastrl 2 | ================ 3 | 4 | 5 | 6 | [![CI 7 | Status](https://github.com/josiahls/fastrl/workflows/Fastrl%20Testing/badge.svg)](https://github.com/josiahls/fastrl/actions?query=workflow%3A%22Fastrl+Testing%22) 8 | [![pypi fastrl 9 | version](https://img.shields.io/pypi/v/fastrl.svg)](https://pypi.python.org/pypi/fastrl) 10 | [![Conda fastrl 11 | version](https://img.shields.io/conda/v/josiahls/fastrl.svg)](https://anaconda.org/josiahls/fastrl) 12 | [![Docker Image 13 | Latest](https://img.shields.io/docker/v/josiahls/fastrl?label=Docker&sort=date)](https://hub.docker.com/repository/docker/josiahls/fastrl) 14 | [![Docker Image-Dev 15 | Latest](https://img.shields.io/docker/v/josiahls/fastrl-dev?label=Docker%20Dev&sort=date)](https://hub.docker.com/repository/docker/josiahls/fastrl-dev) 16 | 17 | [![Anaconda-Server 18 | Badge](https://anaconda.org/josiahls/fastrl/badges/platforms.svg)](https://anaconda.org/josiahls/fastrl) 19 | [![fastrl python 20 | compatibility](https://img.shields.io/pypi/pyversions/fastrl.svg)](https://pypi.python.org/pypi/fastrl) 21 | [![fastrl 22 | license](https://img.shields.io/pypi/l/fastrl.svg)](https://pypi.python.org/pypi/fastrl) 23 | 24 | > Warning: Even before fastrl==2.0.0, all Models should converge 25 | > reasonably fast, 
however HRL models `DADS` and `DIAYN` will need 26 | > re-balancing and some extra features that the respective authors used. 27 | 28 | # Overview 29 | 30 | 31 | 32 | Fastai for computer vision and tabular learning has been amazing. One 33 | would wish that this would be the same for RL. The purpose of this repo 34 | is to have a framework that is as easy as possible to start, but also 35 | designed for testing new agents. 36 | 37 | Documentation is being served at https://josiahls.github.io/fastrl/ from 38 | documentation directly generated via `nbdev` in this repo. 39 | 40 | # Current Issues of Interest 41 | 42 | ## Data Issues 43 | 44 | - [ ] data and async_data are still buggy. We need to verify that the 45 | order in which the data is returned is the best it can be for our 46 | models. We need to make sure that “terminateds” are returned and that 47 | there are no duplicates (unless intended) 48 | - [ ] Better data debugging. Do environments skip steps correctly? Does 49 | n_steps work correctly? 50 | 51 | # What's new? 52 | 53 | As we have learned how to support as many RL agents as possible, we 54 | found that `fastrl==1.*` was vastly limited in the models that it could 55 | support. `fastrl==2.*` will leverage the `nbdev` library for better 56 | documentation and more relevant testing. We also will be building on the 57 | work of the `ptan` [1] library as a close reference for pytorch 58 | based reinforcement learning APIs. 59 | 60 | [1] “Shmuma/Ptan”. GitHub, 2020, 61 | https://github.com/Shmuma/ptan. Accessed 13 June 2020. 62 | 63 | ## Install 64 | 65 | ## PyPI (Not implemented yet) 66 | 67 | Placeholder here; there is no PyPI package yet. It is recommended to do 68 | traditional forking. 69 | 70 | (For the future; currently there is no PyPI 71 | version) `pip install fastrl==2.0.0 --pre` 72 | 73 | ## Conda (Not implemented yet) 74 | 75 | `conda install -c fastchan -c josiahls fastrl` 76 | 77 | `source activate fastrl && python setup.py develop` 78 | 79 | ## Docker (highly recommended) 80 | 81 | Install: 82 | [Nvidia-Docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker) 83 | 84 | Install: [docker-compose](https://docs.docker.com/compose/install/) 85 | 86 | ``` bash 87 | docker-compose pull && docker-compose up 88 | ``` 89 | 90 | ## Contributing 91 | 92 | After you clone this repository, please run `nbdev_install_git_hooks` in 93 | your terminal. This sets up git hooks, which clean up the notebooks to 94 | remove the extraneous stuff stored in the notebooks (e.g. which cells 95 | you ran) which causes unnecessary merge conflicts. 96 | 97 | Before submitting a PR, check that the local library and notebooks 98 | match. The script `nbdev_diff_nbs` can let you know if there is a 99 | difference between the local library and the notebooks. \* If you made a 100 | change to the notebooks in one of the exported cells, you can export it 101 | to the library with `nbdev_build_lib` or `make fastai2`. \* If you made 102 | a change to the library, you can export it back to the notebooks with 103 | `nbdev_update_lib`. 104 | -------------------------------------------------------------------------------- /fastrl/pipes/iter/nstep.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/01_DataPipes/01c_pipes.iter.nstep.ipynb.
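# A quick check of `n_steps_expected` (defined at the bottom of this file):
# with n-step tuples each non-terminal step appears in up to `n` yields, so an
# 18-step episode produces 18*n - sum(range(n)) step entries in total:
from fastrl.pipes.iter.nstep import n_steps_expected

assert n_steps_expected(default_steps=18, n=2) == 18 * 2 - (0 + 1)      # == 35
assert n_steps_expected(default_steps=18, n=3) == 18 * 3 - (0 + 1 + 2)  # == 51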
2 | 3 | # %% auto 0 4 | __all__ = ['NStepper', 'NStepFlattener', 'n_steps_expected'] 5 | 6 | # %% ../../../nbs/01_DataPipes/01c_pipes.iter.nstep.ipynb 2 7 | # Python native modules 8 | import os 9 | from typing import Type, Dict, Union, Tuple 10 | import typing 11 | import warnings 12 | # Third party libs 13 | from fastcore.all import add_docs 14 | import torchdata.datapipes as dp 15 | from torchdata.dataloader2.graph import find_dps,DataPipeGraph,DataPipe 16 | from torchdata.datapipes.iter import IterDataPipe 17 | from torchdata.datapipes.map import MapDataPipe 18 | # Local modules 19 | from ...core import StepTypes 20 | 21 | # %% ../../../nbs/01_DataPipes/01c_pipes.iter.nstep.ipynb 4 22 | class NStepper(IterDataPipe): 23 | def __init__( 24 | self, 25 | # The datapipe we are extracting from must produce `StepType.types` 26 | source_datapipe:IterDataPipe[Union[StepTypes.types]], 27 | # Maximum number of steps to produce per yield as a tuple. This is the *max* number 28 | # and may be less if for example we are yielding terminal states. 29 | # Default produces single steps 30 | n:int=1 31 | ) -> None: 32 | self.source_datapipe:IterDataPipe[StepTypes.types] = source_datapipe 33 | self.n:int = n 34 | self.env_buffer:Dict = {} 35 | 36 | def __iter__(self) -> StepTypes.types: 37 | self.env_buffer = {} 38 | for step in self.source_datapipe: 39 | if not issubclass(step.__class__,StepTypes.types): 40 | raise Exception(f'Expected typing.NamedTuple object got {type(step)}\n{step}') 41 | 42 | env_id,terminated = int(step.env_id),bool(step.terminated) 43 | 44 | if env_id in self.env_buffer: 45 | self.env_buffer[env_id].append(step) 46 | else: 47 | self.env_buffer[env_id] = [step] 48 | 49 | if not terminated and len(self.env_buffer[env_id]) None: 72 | self.source_datapipe:IterDataPipe[[StepTypes.types]] = source_datapipe 73 | 74 | def __iter__(self) -> StepTypes.types: 75 | for step in self.source_datapipe: 76 | if issubclass(step.__class__,StepTypes.types): 77 | # print(step) 78 | yield step 79 | elif isinstance(step,tuple): 80 | # print('got step: ',step) 81 | yield from step 82 | else: 83 | raise Exception(f'Expected {StepTypes.types} or tuple object got {type(step)}\n{step}') 84 | 85 | 86 | add_docs( 87 | NStepFlattener, 88 | """Handles unwrapping `StepType.typess` in tuples better than `dp.iter.UnBatcher` and `dp.iter.Flattener`""", 89 | ) 90 | 91 | # %% ../../../nbs/01_DataPipes/01c_pipes.iter.nstep.ipynb 20 92 | def n_steps_expected( 93 | default_steps:int, # The number of steps the episode would run without n_steps 94 | n:int # The n-step value that we are planning ot use 95 | ): 96 | return (default_steps * n) - sum(range(n)) 97 | 98 | n_steps_expected.__doc__=r""" 99 | Produces the expected number of steps, assuming a fully deterministic episode based on `default_steps` and `n` 100 | 101 | Given `n=2`, given 1 envs, knowing that `CartPole-v1` when `seed=0` will always run 18 steps, the total 102 | steps will be: 103 | 104 | $$ 105 | 18 * n - \sum_{0}^{n - 1}(i) 106 | $$ 107 | """ 108 | -------------------------------------------------------------------------------- /fastrl/pipes/iter/cacheholder.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/01_DataPipes/02a_pipes.iter.cacheholder.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['T_co', 'PickleableInMemoryCacheHolderIterDataPipe'] 5 | 6 | # %% ../../../nbs/01_DataPipes/02a_pipes.iter.cacheholder.ipynb 2 7 | # Copyright (c) Meta Platforms, Inc. 
and affiliates. 8 | # All rights reserved. 9 | # 10 | # This source code is licensed under the BSD-style license found in the 11 | # LICENSE file in the root directory of this source tree. 12 | 13 | # Python native modules 14 | import hashlib 15 | import inspect 16 | import os.path 17 | import sys 18 | import time 19 | import uuid 20 | import warnings 21 | from enum import IntEnum 22 | 23 | from collections import deque 24 | from functools import partial 25 | from typing import Any, Callable, Deque, Dict, Iterator, List, Optional, Tuple, TypeVar 26 | # Third party libs 27 | try: 28 | import portalocker 29 | except ImportError: 30 | portalocker = None 31 | 32 | from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE 33 | 34 | from torch.utils.data.graph import traverse_dps 35 | from torchdata.datapipes import functional_datapipe 36 | from torchdata.datapipes.iter import IterableWrapper, IterDataPipe 37 | # Local modules 38 | 39 | 40 | # %% ../../../nbs/01_DataPipes/02a_pipes.iter.cacheholder.ipynb 4 41 | if DILL_AVAILABLE: 42 | import dill 43 | 44 | dill.extend(use_dill=False) 45 | 46 | T_co = TypeVar("T_co", covariant=True) 47 | 48 | @functional_datapipe("pickleable_in_memory_cache") 49 | class PickleableInMemoryCacheHolderIterDataPipe(IterDataPipe[T_co]): 50 | r""" 51 | Stores elements from the source DataPipe in memory, up to a size limit 52 | if specified (functional name: ``pickleable_in_memory_cache``). This cache is FIFO - once the cache is full, 53 | further elements will not be added to the cache until the previous ones are yielded and popped off from the cache. 54 | 55 | Args: 56 | source_dp: source DataPipe from which elements are read and stored in memory 57 | size: The maximum size (in megabytes) that this DataPipe can hold in memory. This defaults to unlimited.
58 | 59 | Example: 60 | >>> from torchdata.datapipes.iter import IterableWrapper 61 | >>> source_dp = IterableWrapper(range(10)) 62 | >>> cache_dp = source_dp.pickleable_in_memory_cache(size=5) 63 | >>> list(cache_dp) 64 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 65 | """ 66 | size: Optional[int] = None 67 | idx: int 68 | 69 | def __init__(self, source_dp: IterDataPipe[T_co], size: Optional[int] = None) -> None: 70 | self.source_dp: IterDataPipe[T_co] = source_dp 71 | # cache size in MB 72 | if size is not None: 73 | self.size = size * 1024 * 1024 74 | self.cache: Optional[Deque] = None 75 | self.idx: int = 0 76 | 77 | def __getstate__(self): 78 | state = ( 79 | self.source_dp, 80 | self.size 81 | ) 82 | if IterDataPipe.getstate_hook is not None: 83 | return IterDataPipe.getstate_hook(state) 84 | return state 85 | 86 | def __setstate__(self, state): 87 | ( 88 | self.source_dp, 89 | self.size 90 | ) = state 91 | self.cache: Optional[Deque] = None 92 | self.idx: int = 0 93 | 94 | def __iter__(self) -> Iterator[T_co]: 95 | if self.cache: 96 | if self.idx > 0: 97 | for idx, data in enumerate(self.source_dp): 98 | if idx < self.idx: 99 | yield data 100 | else: 101 | break 102 | yield from self.cache 103 | else: 104 | # Local cache 105 | cache: Deque = deque() 106 | idx = 0 107 | for data in self.source_dp: 108 | cache.append(data) 109 | # Cache reaches limit 110 | if self.size is not None and sys.getsizeof(cache) > self.size: 111 | cache.popleft() 112 | idx += 1 113 | yield data 114 | self.cache = cache 115 | self.idx = idx 116 | 117 | def __len__(self) -> int: 118 | try: 119 | return len(self.source_dp) 120 | except TypeError: 121 | if self.cache: 122 | return self.idx + len(self.cache) 123 | else: 124 | raise TypeError(f"{type(self).__name__} instance doesn't have valid length until the cache is loaded.") 125 | 126 | -------------------------------------------------------------------------------- /fastrl/pipes/map/demux.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/01_DataPipes/01b_pipes.map.demux.ipynb. 
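# A runnable sketch of why the cache holder above defines
# `__getstate__`/`__setstate__`: the deque cache is deliberately excluded from
# the pickled state, so a pickled-and-restored pipe rebuilds its cache on the
# next pass instead of serializing (potentially large) cached elements.
# Importing the module is assumed to register the functional name:
import pickle

import fastrl.pipes.iter.cacheholder  # registers `pickleable_in_memory_cache`
from torchdata.datapipes.iter import IterableWrapper

cache_dp = IterableWrapper(range(5)).pickleable_in_memory_cache()
assert list(cache_dp) == [0, 1, 2, 3, 4]    # first pass fills the cache
restored = pickle.loads(pickle.dumps(cache_dp))
assert list(restored) == [0, 1, 2, 3, 4]    # cache is rebuilt, not unpickled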
2 | 3 | # %% auto 0 4 | __all__ = ['T_co', 'DemultiplexerMapDataPipe'] 5 | 6 | # %% ../../../nbs/01_DataPipes/01b_pipes.map.demux.ipynb 3 7 | # Python native modules 8 | import os 9 | from inspect import isfunction,ismethod 10 | from typing import Callable, Dict, Iterable, Optional, TypeVar 11 | # Third party libs 12 | from fastcore.all import * 13 | from ...torch_core import * 14 | from torch.utils.data.datapipes.utils.common import _check_unpickable_fn 15 | import torchdata.datapipes as dp 16 | from torchdata.datapipes import functional_datapipe 17 | from torchdata.dataloader2.graph import find_dps,DataPipeGraph,Type,DataPipe,MapDataPipe,IterDataPipe 18 | from torchdata.dataloader2.dataloader2 import DataLoader2 19 | # Local modules 20 | 21 | # %% ../../../nbs/01_DataPipes/01b_pipes.map.demux.ipynb 5 22 | T_co = TypeVar("T_co", covariant=True) 23 | 24 | @functional_datapipe("demux") 25 | class DemultiplexerMapDataPipe(MapDataPipe[T_co]): 26 | def __new__(cls, datapipe: MapDataPipe, num_instances: int, classifier_fn: Callable, drop_none: bool = False, 27 | source_index: Optional[Iterable] = None): 28 | if num_instances < 1: 29 | raise ValueError(f"Expected `num_instances` larger than 0, but {num_instances} is found") 30 | _check_unpickable_fn(classifier_fn) 31 | container = _DemultiplexerMapDataPipe(datapipe, num_instances, classifier_fn, drop_none, source_index) 32 | return [_DemultiplexerChildMapDataPipe(container, i) for i in range(num_instances)] 33 | 34 | class _DemultiplexerMapDataPipe(MapDataPipe[T_co]): 35 | def __init__( 36 | self, 37 | datapipe: MapDataPipe[T_co], 38 | num_instances: int, 39 | classifier_fn: Callable[[T_co], Optional[int]], 40 | drop_none: bool, 41 | source_index: Optional[Iterable], 42 | ): 43 | self.main_datapipe = datapipe 44 | self.num_instances = num_instances 45 | self.classifier_fn = classifier_fn 46 | self.drop_none = drop_none 47 | self.iterator = None 48 | self.exhausted = False # Once we iterate through `main_datapipe` once, we know all the index mapping 49 | self.index_mapping = [[] for _ in range(num_instances)] 50 | self.source_index = source_index # if None, assume `main_datapipe` 0-index 51 | 52 | def _classify_next(self): 53 | if self.source_index is None: 54 | self.source_index = range(len(self.main_datapipe)) 55 | if self.iterator is None: 56 | self.iterator = iter(self.source_index) 57 | try: 58 | next_source_idx = next(self.iterator) 59 | except StopIteration: 60 | self.exhausted = True 61 | return 62 | value = self.main_datapipe[next_source_idx] 63 | classification = self.classifier_fn(value) 64 | if classification is None and self.drop_none: 65 | self._classify_next() 66 | else: 67 | self.index_mapping[classification].append(value) 68 | 69 | def classify_all(self): 70 | while not self.exhausted: 71 | self._classify_next() 72 | 73 | def get_value(self, instance_id: int, index: int) -> T_co: 74 | while not self.exhausted and len(self.index_mapping[instance_id]) <= index: 75 | self._classify_next() 76 | if len(self.index_mapping[instance_id]) > index: 77 | return self.index_mapping[instance_id][index] 78 | raise RuntimeError("Index is out of bound.") 79 | 80 | def __len__(self): 81 | return len(self.main_datapipe) 82 | 83 | class _DemultiplexerChildMapDataPipe(MapDataPipe[T_co]): 84 | def __init__(self, main_datapipe: _DemultiplexerMapDataPipe, instance_id: int): 85 | self.main_datapipe: _DemultiplexerMapDataPipe = main_datapipe 86 | self.instance_id = instance_id 87 | 88 | def __getitem__(self, index: int): 89 | return 
self.main_datapipe.get_value(self.instance_id, index) 90 | 91 | def __len__(self): 92 | self.main_datapipe.classify_all() # You have to read through the entirety of main_datapipe to know `len` 93 | return len(self.main_datapipe.index_mapping[self.instance_id]) 94 | 95 | def __iter__(self): 96 | for i in range(len(self)): 97 | yield self[i] 98 | 99 | DemultiplexerMapDataPipe.__doc__ = """Splits the input DataPipe into multiple child DataPipes, using the given 100 | classification function (functional name: ``demux``). A list of the child DataPipes is returned from this operation. 101 | """ 102 | -------------------------------------------------------------------------------- /fastrl/nbdev_extensions.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_nbdev_extensions.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['header', 'create_blog_notebook'] 5 | 6 | # %% ../nbs/00_nbdev_extensions.ipynb 1 7 | # Python native modules 8 | import os 9 | from datetime import datetime 10 | import os 11 | import shutil 12 | import json 13 | # Third party libs 14 | from fastcore.all import * 15 | from nbdev.config import get_config 16 | import yaml 17 | from IPython.display import display, Markdown 18 | from fastcore.all import call_parse 19 | # Local modules 20 | 21 | # %% ../nbs/00_nbdev_extensions.ipynb 5 22 | def header( 23 | # The main header title to display. 24 | title: str, 25 | # The subtitle to display underneath the title. If None, no subtitle will be displayed. 26 | subtitle: Optional[str] = None, 27 | # If True, the date associated with the header will be frozen, 28 | # meaning it won't change in subsequent runs. 29 | # If False, a new date will be generated each time the function is run, 30 | # and the date will not be saved to file. 31 | freeze: bool = False 32 | ): 33 | """ 34 | Function to generate a Markdown formatted header with an associated date. 35 | Dates are auto-incremented and can be frozen. This function also controls the persistent storage of dates. 
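A typical call (a sketch; the strings are illustrative) at the top of a blog
notebook cell:

    header("Updating torchdata", subtitle="Getting back up to speed", freeze=True)

renders a Markdown header of the form ``# `YYYY-MM-DD` **title**`` with the
subtitle block-quoted underneath, and `freeze=True` persists the date to
`header_dates.json` so re-running the cell keeps the same date.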
36 |     """
37 |     filename = 'header_dates.json'
38 |     date = None
39 |     id: Optional[int] = None
40 | 
41 |     # Load or initialize date dictionary
42 |     if os.path.exists(filename):
43 |         with open(filename, 'r') as file:
44 |             dates = {int(k): v for k, v in json.load(file).items()} # JSON keys are strings; convert back to ints for the id math below
45 |     else:
46 |         dates = {}
47 | 
48 |     # Determine the id for the new entry
49 |     if freeze:
50 |         # If frozen, use the maximum id from the file, or 0 if the file is empty
51 |         id = max(dates.keys(), default=0)
52 |     else:
53 |         # If not frozen, increment the maximum id from the file, or use 0 if the file is empty
54 |         id = max(dates.keys(), default=-1) + 1
55 | 
56 |     # Get or create the date
57 |     date = dates.get(id)
58 |     if date is None:
59 |         date = datetime.now().strftime('%Y-%m-%d')
60 |         dates[id] = date
61 | 
62 |     # Only write to file if the date is frozen
63 |     if freeze:
64 |         with open(filename, 'w') as file:
65 |             json.dump(dates, file) # json serializes the int keys back to strings
66 | 
67 |     # Display the markdown
68 |     if subtitle is None:
69 |         display(Markdown(f"# `{date}` **{title}**"))
70 |     else:
71 |         display(Markdown(f"# `{date}` **{title}**\n> {subtitle}"))
72 | 
73 | 
74 | # %% ../nbs/00_nbdev_extensions.ipynb 8
75 | @call_parse
76 | def create_blog_notebook() -> None: # Creates a new blog notebook from template
77 |     template = '99_blog.from_xxxx_xx_to_xx.ipynb'
78 |     new_name = datetime.now().strftime('99_blog.from_%Y_%m_to_now.ipynb')
79 | 
80 |     # Check if the template file exists
81 |     if not os.path.exists(template):
82 |         raise FileNotFoundError(f"Template file '{template}' not found in current directory.")
83 | 
84 |     # Rename old notebooks and update sidebar.yml
85 |     sidebar_file = '../sidebar.yml'
86 |     with open(sidebar_file, 'r') as f:
87 |         sidebar = yaml.safe_load(f)
88 | 
89 |     blog_section = None
90 |     for section in sidebar['website']['sidebar']['contents']:
91 |         if 'section' in section and section['section'] == 'Blog':
92 |             blog_section = section['contents']
93 |             break
94 | 
95 |     # Rename old notebooks
96 |     for filename in os.listdir():
97 |         if filename.startswith('99_blog.from_') and filename.endswith('_to_now.ipynb'):
98 |             date_from = filename[13:20] # the 'YYYY_MM' start date that follows '99_blog.from_'
99 |             date_to = datetime.now().strftime('%Y_%m')
100 |             new_filename = f'99_blog.from_{date_from}_to_{date_to}.ipynb'
101 |             os.rename(filename, new_filename)
102 | 
103 |             if blog_section is not None:
104 |                 # Update sidebar.yml
105 |                 old_entry = f'12_Blog/{filename}'
106 |                 new_entry = f'12_Blog/{new_filename}'
107 |                 if old_entry in blog_section:
108 |                     blog_section.remove(old_entry)
109 |                     blog_section.append(new_entry)
110 | 
111 |     # Add new notebook to sidebar.yml
112 |     if blog_section is not None and f'12_Blog/{new_name}' not in blog_section:
113 |         blog_section.append(f'12_Blog/{new_name}')
114 | 
115 |     with open(sidebar_file, 'w') as f:
116 |         yaml.safe_dump(sidebar, f)
117 | 
118 |     # Create new notebook from template
119 |     shutil.copy(template, new_name)
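120 | 
121 | # Illustrative usage notes (assumptions, not exported behavior): `header` is meant to
122 | # run inside a notebook cell (it calls IPython `display`), and `create_blog_notebook`
123 | # from within `nbs/12_Blog/`, where the template notebook and `../sidebar.yml` exist.
124 | #
125 | #     header("DQN Notes", subtitle="Target networks", freeze=True)
126 | #     # renders e.g. "# `2023-05-29` **DQN Notes**" and persists the date in
127 | #     # header_dates.json so later runs reuse it
128 | #
129 | #     create_blog_notebook()
130 | #     # copies 99_blog.from_xxxx_xx_to_xx.ipynb to a dated '..._to_now' name and
131 | #     # registers it under the Blog section of ../sidebar.yml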
--------------------------------------------------------------------------------
/fastrl/agents/dqn/target.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/07_Agents/01_Discrete/12h_agents.dqn.target.ipynb.
2 | 
3 | # %% auto 0
4 | __all__ = ['TargetModelUpdater', 'TargetModelQCalc', 'DQNTargetLearner']
5 | 
6 | # %% ../../../nbs/07_Agents/01_Discrete/12h_agents.dqn.target.ipynb 2
7 | # Python native modules
8 | from copy import deepcopy
9 | from typing import Optional,Callable,Tuple
10 | # Third party libs
11 | import torchdata.datapipes as dp
12 | from torchdata.dataloader2.graph import traverse_dps,DataPipe
13 | import torch
14 | from torch import nn,optim
15 | # Local modules
16 | from ...pipes.core import find_dp
17 | from ...memory.experience_replay import ExperienceReplay
18 | from ...loggers.core import BatchCollector,EpochCollector
19 | from ...learner.core import LearnerBase,LearnerHead
20 | from ...loggers.vscode_visualizers import VSCodeDataPipe
21 | from ...loggers.core import ProgressBarLogger
22 | from fastrl.agents.dqn.basic import (
23 |     LossCollector,
24 |     RollingTerminatedRewardCollector,
25 |     EpisodeCollector,
26 |     StepBatcher,
27 |     TargetCalc,
28 |     LossCalc,
29 |     ModelLearnCalc,
30 |     DQN,
31 |     DQNAgent
32 | )
33 | 
34 | # %% ../../../nbs/07_Agents/01_Discrete/12h_agents.dqn.target.ipynb 7
35 | class TargetModelUpdater(dp.iter.IterDataPipe):
36 |     def __init__(self,source_datapipe,target_sync=300):
37 |         self.source_datapipe = source_datapipe
38 |         self.target_sync = target_sync
39 |         self.n_batch = 0
40 |         self.learner = find_dp(traverse_dps(self),LearnerBase)
41 |         with torch.no_grad():
42 |             self.learner.target_model = deepcopy(self.learner.model)
43 | 
44 |     def reset(self):
45 |         self.learner = find_dp(traverse_dps(self),LearnerBase)
46 |         with torch.no_grad():
47 |             self.learner.target_model = deepcopy(self.learner.model)
48 | 
49 |     def __iter__(self):
50 |         if self._snapshot_state == type(self._snapshot_state).NotStarted: # compare to the enum member itself; a bare `.NotStarted` attribute lookup is always truthy
51 |             self.reset()
52 |         for batch in self.source_datapipe:
53 |             if self.n_batch%self.target_sync==0:
54 |                 with torch.no_grad():
55 |                     self.learner.target_model.load_state_dict(self.learner.model.state_dict())
56 |             self.n_batch+=1
57 |             yield batch
58 | 
59 | # %% ../../../nbs/07_Agents/01_Discrete/12h_agents.dqn.target.ipynb 8
60 | class TargetModelQCalc(dp.iter.IterDataPipe):
61 |     def __init__(self,source_datapipe=None):
62 |         self.source_datapipe = source_datapipe
63 | 
64 |     def __iter__(self):
65 |         self.learner = find_dp(traverse_dps(self),LearnerBase)
66 |         for batch in self.source_datapipe:
67 |             self.learner.done_mask = batch.terminated.reshape(-1,)
68 |             with torch.no_grad():
69 |                 self.learner.next_q = self.learner.target_model(batch.next_state)
70 |                 self.learner.next_q = self.learner.next_q.max(dim=1).values.reshape(-1,1)
71 |                 self.learner.next_q[self.learner.done_mask] = 0
72 |             yield batch
73 | 
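74 | # Putting the pieces above together (an illustrative note, not exported code):
75 | # `TargetModelQCalc` bootstraps `next_q` from the periodically synced target copy
76 | # and zeroes it for terminated steps, and `TargetCalc` (from
77 | # `fastrl.agents.dqn.basic`) folds in the reward, so the regression target per
78 | # n-step transition is, roughly,
79 | #
80 | #     target = reward + discount**nsteps * max_a Q_target(next_state, a) * (1 - terminated)
81 | #
82 | # `TargetModelUpdater` then re-syncs Q_target every `target_sync` batches.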
83 | # %% ../../../nbs/07_Agents/01_Discrete/12h_agents.dqn.target.ipynb 9
84 | def DQNTargetLearner(
85 |     model,
86 |     dls,
87 |     do_logging:bool=True,
88 |     loss_func=nn.MSELoss(),
89 |     opt=optim.AdamW,
90 |     lr=0.005,
91 |     bs=128,
92 |     max_sz=10000,
93 |     nsteps=1,
94 |     device=None,
95 |     batches=None,
96 |     target_sync=300
97 | ) -> LearnerHead:
98 |     learner = LearnerBase(model,dls=dls[0])
99 |     learner = BatchCollector(learner,batches=batches)
100 |     learner = EpochCollector(learner)
101 |     if do_logging:
102 |         learner = learner.dump_records()
103 |         learner = ProgressBarLogger(learner)
104 |         learner = RollingTerminatedRewardCollector(learner)
105 |         learner = EpisodeCollector(learner).catch_records()
106 |     learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)
107 |     learner = StepBatcher(learner,device=device)
108 |     learner = TargetModelQCalc(learner)
109 |     learner = TargetCalc(learner,nsteps=nsteps)
110 |     learner = LossCalc(learner,loss_func=loss_func)
111 |     learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))
112 |     learner = TargetModelUpdater(learner,target_sync=target_sync)
113 |     if do_logging:
114 |         learner = LossCollector(learner).catch_records()
115 | 
116 |     if len(dls)==2:
117 |         val_learner = LearnerBase(model,dls[1]).visualize_vscode()
118 |         val_learner = BatchCollector(val_learner,batches=batches)
119 |         val_learner = EpochCollector(val_learner).catch_records(drop=True)
120 |         return LearnerHead((learner,val_learner))
121 |     else:
122 |         return LearnerHead(learner)
--------------------------------------------------------------------------------
/nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "id": "durable-dialogue",
7 |    "metadata": {},
8 |    "outputs": [],
9 |    "source": [
10 |     "#|hide\n",
11 |     "from fastrl.test_utils import initialize_notebook\n",
12 |     "initialize_notebook()"
13 |    ]
14 |   },
15 |   {
16 |    "cell_type": "code",
17 |    "execution_count": null,
18 |    "id": "offshore-stuart",
19 |    "metadata": {},
20 |    "outputs": [],
21 |    "source": [
22 |     "#|default_exp agents.dqn.dueling"
23 |    ]
24 |   },
25 |   {
26 |    "cell_type": "code",
27 |    "execution_count": null,
28 |    "id": "assisted-contract",
29 |    "metadata": {},
30 |    "outputs": [],
31 |    "source": [
32 |     "#|export\n",
33 |     "# Python native modules\n",
34 |     "# Third party libs\n",
35 |     "import torch\n",
36 |     "from torch import nn\n",
37 |     "# Local modules\n",
38 |     "from fastrl.agents.dqn.basic import (\n",
39 |     "    DQN,\n",
40 |     "    DQNAgent\n",
41 |     ")\n",
42 |     "from fastrl.agents.dqn.target import DQNTargetLearner"
43 |    ]
44 |   },
45 |   {
46 |    "cell_type": "markdown",
47 |    "id": "lesser-innocent",
48 |    "metadata": {},
49 |    "source": [
50 |     "# DQN Dueling\n",
51 |     "> DQN using a split head for comparing the advantage of different actions"
52 |    ]
53 |   },
54 |   {
55 |    "cell_type": "markdown",
56 |    "id": "30c98be0-6288-443a-b4ab-9390fbe3081c",
57 |    "metadata": {},
58 |    "source": [
59 |     "\n",
60 |     "\n",
61 |     "## Training DataPipes"
62 |    ]
63 |   },
64 |   {
65 |    "cell_type": "code",
66 |    "execution_count": null,
67 |    "id": "d921085e-7d53-40ac-9b37-56a31b15d47c",
68 |    "metadata": {},
69 |    "outputs": [],
70 |    "source": [
71 |     "#|export\n",
72 |     "class DuelingHead(nn.Module):\n",
73 |     "    def __init__(\n",
74 |     "        self,\n",
75 |     "        hidden: int, # Input into the DuelingHead, likely a hidden layer input\n",
76 |     "        n_actions: int, # Number/dim of actions to output\n",
77 |     "        lin_cls = nn.Linear\n",
78 |     "    ):\n",
79 |     "        super().__init__()\n",
80 |     "        self.val = lin_cls(hidden,1)\n",
81 |     "        self.adv = lin_cls(hidden,n_actions)\n",
82 |     "\n",
83 |     "    def forward(self,xi):\n",
84 |     "        val,adv = self.val(xi),self.adv(xi)\n",
85 |     "        xi = val.expand_as(adv)+(adv-adv.mean(dim=-1,keepdim=True)) # center advantages per sample; a global mean would mix batch elements\n",
86 |     "        return xi"
87 |    ]
88 |   },
89 |   {
90 |    "cell_type": "markdown",
91 |    "id": "2b8f9ed8-fb05-40a1-ac0d-d4cafee8fa07",
92 |    "metadata": {},
93 |    "source": [
94 |     "Try training with basic defaults..."
95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "id": "56acb9ff", 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "from fastrl.envs.gym import GymDataPipe\n", 105 | "from fastrl.dataloading.core import dataloaders" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "id": "63d9b481-5998-472a-a2df-18d79bf07ae2", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "#|eval:false\n", 116 | "# Setup up the core NN\n", 117 | "torch.manual_seed(0)\n", 118 | "model = DQN(4,2,head_layer=DuelingHead)\n", 119 | "# Setup the Agent\n", 120 | "model.train()\n", 121 | "agent = DQNAgent(model,do_logging=True,min_epsilon=0.02,max_epsilon=1,max_steps=5000)\n", 122 | "# Setup the Dataloaders\n", 123 | "params = dict(\n", 124 | " source=['CartPole-v1']*1,\n", 125 | " agent=agent,\n", 126 | " nsteps=2,\n", 127 | " nskips=2,\n", 128 | " firstlast=True\n", 129 | ")\n", 130 | "dls = dataloaders((GymDataPipe(**params),GymDataPipe(**params,include_images=True).unbatch()))\n", 131 | "# Setup the Learner\n", 132 | "learner = DQNTargetLearner(\n", 133 | " model,\n", 134 | " dls,\n", 135 | " bs=128,\n", 136 | " max_sz=100_000,\n", 137 | " nsteps=2,\n", 138 | " lr=0.01,\n", 139 | " batches=1000,\n", 140 | " target_sync=300\n", 141 | ")\n", 142 | "learner.fit(7)" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "id": "cf3c74e6", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "#|eval:false\n", 153 | "learner.validate()" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "id": "current-pilot", 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "#|hide\n", 164 | "#|eval: false\n", 165 | "!nbdev_export" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "ab3d3626-1702-4f22-ae15-bb93a75bec68", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [] 175 | } 176 | ], 177 | "metadata": { 178 | "kernelspec": { 179 | "display_name": "python3", 180 | "language": "python", 181 | "name": "python3" 182 | } 183 | }, 184 | "nbformat": 4, 185 | "nbformat_minor": 5 186 | } 187 | -------------------------------------------------------------------------------- /fastrl/memory/experience_replay.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/04_Memory/06a_memory.experience_replay.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['ExperienceReplay'] 5 | 6 | # %% ../../nbs/04_Memory/06a_memory.experience_replay.ipynb 2 7 | # Python native modules 8 | from copy import copy 9 | # Third party libs 10 | from fastcore.all import add_docs,ifnone 11 | import torchdata.datapipes as dp 12 | import numpy as np 13 | import torch 14 | # Local modules 15 | from ..core import StepTypes 16 | 17 | # %% ../../nbs/04_Memory/06a_memory.experience_replay.ipynb 4 18 | class ExperienceReplay(dp.iter.IterDataPipe): 19 | debug=False 20 | def __init__(self, 21 | source_datapipe, 22 | learner=None, 23 | bs=1, 24 | max_sz=100, 25 | return_idxs=False, 26 | # If the `self.device` is not cpu, and `store_as_cpu=True`, then 27 | # calls to `sample()` will dynamically move them to `self.device`, and 28 | # next `sample()` will move them back to cpu before producing new samples. 29 | # This can be slower, but can save vram. 
30 |         # If `store_as_cpu=False`, then samples stay on `self.device`
31 |         #
32 |         # If being run with n_workers>0, shared_memory, and fork, this MUST be true. This is needed because
33 |         # otherwise the tensors in the memory will remain shared with the tensors created in the
34 |         # dataloader.
35 |         store_as_cpu:bool=True,
36 |         # When `max_sz` is reached, no new records will be added to the memory.
37 |         # This is useful for debugging since a model should be able to
38 |         # reach a loss of 0 learning on a static set.
39 |         freeze_memory:bool=False
40 |     ):
41 |         self.memory = np.array([None]*max_sz)
42 |         self.source_datapipe = source_datapipe
43 |         self.learner = learner
44 |         if learner is not None:
45 |             self.learner.experience_replay = self
46 |         self.bs = bs
47 |         self.freeze_memory = freeze_memory
48 |         self.max_sz = max_sz
49 |         self._sz_tracker = 0
50 |         self._idx_tracker = 0
51 |         self._cycle_tracker = 0
52 |         self.return_idxs = return_idxs
53 |         self.store_as_cpu = store_as_cpu
54 |         self._last_idx = None
55 |         self.device = None
56 | 
57 |     def to(self,*args,**kwargs):
58 |         self.device = kwargs.get('device',None)
59 | 
60 |     def sample(self,bs=None):
61 |         idxs = np.random.choice(range(self._sz_tracker),size=(ifnone(bs,self.bs),),replace=False)
62 |         if self.return_idxs: return self.memory[idxs],idxs
63 |         self._last_idx = idxs
64 |         return [o.to(device=self.device) for o in self.memory[idxs]]
65 | 
66 |     def __repr__(self):
67 |         return str({k:v if k!='memory' else f'{len(self)} elements' for k,v in self.__dict__.items()})
68 | 
69 |     def __len__(self): return self._sz_tracker
70 | 
71 |     def show(self, agent=None):
72 |         from fastrl.memory.memory_visualizer import MemoryBufferViewer
73 |         return MemoryBufferViewer(self.memory,agent=agent)
74 | 
75 |     def __iter__(self):
76 |         for i,b in enumerate(self.source_datapipe):
77 |             if self.debug: print('Experience Replay Adding: ',b)
78 | 
79 |             if not issubclass(b.__class__,(*StepTypes.types,list,tuple)):
80 |                 raise Exception(f'Expected typing.NamedTuple,list,tuple object got {type(b)}\n{b}')
81 | 
82 |             if issubclass(b.__class__,StepTypes.types): self.add(b)
83 |             elif issubclass(b.__class__,(list,tuple)):
84 |                 for step in b: self.add(step)
85 |             else:
86 |                 raise Exception(f'This should not have occurred: {self.__dict__}')
87 | 
88 |             if self._sz_tracker<self.bs: continue
89 |             yield self.sample()
90 | 
91 |     def add(self,step):
92 |         # Keep stored steps on the cpu when requested (see the `store_as_cpu` notes above)
93 |         if self.store_as_cpu:
94 |             step = step.detach().cpu()
95 |         if self._sz_tracker==0:
96 |             self.memory[self._idx_tracker] = step
97 |             self._sz_tracker += 1
98 |             self._idx_tracker = 1
99 |         elif 0<self._sz_tracker<self.max_sz:
100 |             self.memory[self._idx_tracker] = step
101 |             self._sz_tracker += 1
102 |             self._idx_tracker += 1
103 |         elif self._sz_tracker>=self.max_sz:
104 |             if not self.freeze_memory:
105 |                 if self._idx_tracker>=self.max_sz:
106 |                     self._idx_tracker = 0
107 |                     self._cycle_tracker += 1
108 |                 self.memory[self._idx_tracker] = step
109 |                 self._idx_tracker += 1
110 |         else:
111 |             raise Exception(f'This should not have occurred: {self.__dict__}')
112 | 
113 | add_docs(
114 |     ExperienceReplay,
115 |     """Simplest form of memory. Takes steps from `source_datapipe` and stores them in `memory`.
116 |     It outputs `bs` steps.""",
117 |     sample="Returns `bs` steps from `memory` in a uniform distribution.",
118 |     add="Adds new steps to `memory`. Once `memory` reaches size `max_sz`, new steps overwrite the oldest entries (unless `freeze_memory=True`).",
119 |     to=torch.Tensor.to.__doc__,
120 |     show="Displays an ipywidget to look at the steps in `self.memory`"
121 | )
122 | 
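123 | # A runnable sketch of the buffer mechanics in isolation. Illustrative only:
124 | # plain tensors stand in for the `StepType` objects a real pipeline produces.
125 | if __name__ == "__main__":
126 |     src = dp.iter.IterableWrapper([[torch.full((1,), float(i))] for i in range(10)])
127 |     er = ExperienceReplay(src, bs=4, max_sz=8)
128 |     for batch in er: # yields a uniform sample of `bs` steps once enough are stored
129 |         pass
130 |     print(len(er)) # 8: capped at `max_sz`; older steps are overwritten in a ring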
--------------------------------------------------------------------------------
/nbs/02_DataLoading/00_core.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "id": "durable-dialogue",
7 |    "metadata": {},
8 |    "outputs": [],
9 |    "source": [
10 |     "#|hide\n",
11 |     "from fastrl.test_utils import initialize_notebook\n",
12 |     "initialize_notebook()"
13 |    ]
14 |   },
15 |   {
16 |    "cell_type": "code",
17 |    "execution_count": null,
18 |    "id": "offshore-stuart",
19 |    "metadata": {},
20 |    "outputs": [],
21 |    "source": [
22 |     "#|default_exp dataloading.core"
23 |    ]
24 |   },
25 |   {
26 |    "cell_type": "code",
27 |    "execution_count": null,
28 |    "id": "assisted-contract",
29 |    "metadata": {},
30 |    "outputs": [],
31 |    "source": [
32 |     "#|export\n",
33 |     "# Python native modules\n",
34 |     "from typing import Tuple,Union,List\n",
35 |     "# Third party libs\n",
36 |     "import torchdata.datapipes as dp\n",
37 |     "from torchdata.dataloader2 import MultiProcessingReadingService,DataLoader2\n",
38 |     "from fastcore.all import delegates\n",
39 |     "# Local modules"
40 |    ]
41 |   },
42 |   {
43 |    "attachments": {},
44 |    "cell_type": "markdown",
45 |    "id": "a258abcf",
46 |    "metadata": {},
47 |    "source": [
48 |     "# Dataloading Core\n",
49 |     "> Basic utils for creating dataloaders from rl datapipes."
50 |    ]
51 |   },
52 |   {
53 |    "cell_type": "code",
54 |    "execution_count": null,
55 |    "id": "0c4e1268",
56 |    "metadata": {},
57 |    "outputs": [],
58 |    "source": [
59 |     "#|export\n",
60 |     "@delegates(MultiProcessingReadingService)\n",
61 |     "def dataloaders(\n",
62 |     "    # A tuple of iterable datapipes to generate dataloaders from.\n",
63 |     "    pipes:Union[Tuple[dp.iter.IterDataPipe],dp.iter.IterDataPipe],\n",
64 |     "    # Concat the dataloaders together\n",
65 |     "    do_concat:bool = False,\n",
66 |     "    # Multiplex the dataloaders\n",
67 |     "    do_multiplex:bool = False,\n",
68 |     "    # Number of workers the dataloaders should run in\n",
69 |     "    num_workers: int = 0,\n",
70 |     "    **kwargs\n",
71 |     ") -> Union[dp.iter.IterDataPipe,List[dp.iter.IterDataPipe]]:\n",
72 |     "    \"Function that creates dataloaders based on `pipes` with different ways of combining them.\"\n",
73 |     "    if not isinstance(pipes,tuple):\n",
74 |     "        pipes = (pipes,)\n",
75 |     "\n",
76 |     "    dls = []\n",
77 |     "    for pipe in pipes:\n",
78 |     "        dl = DataLoader2(\n",
79 |     "            datapipe=pipe,\n",
80 |     "            reading_service=MultiProcessingReadingService(\n",
81 |     "                num_workers = num_workers,\n",
82 |     "                **kwargs\n",
83 |     "            ) if num_workers > 0 else None\n",
84 |     "        )\n",
85 |     "        dl = dp.iter.IterableWrapper(dl,deepcopy=False)\n",
86 |     "        dls.append(dl)\n",
87 |     "    #TODO(josiahls): Not sure if this is needed tbh.. Might be better to just\n",
88 |     "    # return dls, and have the user wrap them if they want. Then they can do more complex stuff.\n",
89 |     "    if do_concat:\n",
90 |     "        return dp.iter.Concater(*dls)\n",
91 |     "    elif do_multiplex:\n",
92 |     "        return dp.iter.Multiplexer(*dls)\n",
93 |     "    else:\n",
94 |     "        return dls\n"
95 |    ]
96 |   },
97 |   {
98 |    "cell_type": "code",
99 |    "execution_count": null,
100 |    "id": "91b2f9c2",
101 |    "metadata": {},
102 |    "outputs": [],
103 |    "source": [
104 |     "from fastcore.test import test_eq"
105 |    ]
106 |   },
107 |   {
108 |    "cell_type": "code",
109 |    "execution_count": null,
110 |    "id": "fc24a6bb",
111 |    "metadata": {},
112 |    "outputs": [],
113 |    "source": [
114 |     "# Sample Data\n",
115 |     "pipe1 = dp.iter.IterableWrapper([1, 2, 3])\n",
116 |     "pipe2 = dp.iter.IterableWrapper([4, 5, 6])\n",
117 |     "\n",
118 |     "# Test for a single IterDataPipe\n",
119 |     "dls = dataloaders(pipe1)\n",
120 |     "assert len(dls) == 1\n",
121 |     "assert isinstance(dls[0], dp.iter.IterableWrapper)\n",
122 |     "test_eq(list(dls[0]), [1, 2, 3])\n",
123 |     "\n",
124 |     "# Test for a tuple of IterDataPipes without concatenation or multiplexing\n",
125 |     "dls = dataloaders((pipe1, pipe2))\n",
126 |     "test_eq(len(dls),2)\n",
127 |     "test_eq(list(dls[0]), [1, 2, 3])\n",
128 |     "test_eq(list(dls[1]), [4, 5, 6])\n",
129 |     "\n",
130 |     "# Test for concatenation\n",
131 |     "dl = dataloaders((pipe1, pipe2), do_concat=True)\n",
132 |     "assert isinstance(dl, dp.iter.Concater)\n",
133 |     "test_eq(list(dl), [1, 2, 3, 4, 5, 6])\n",
134 |     "\n",
135 |     "# Test for multiplexing\n",
136 |     "dl = dataloaders((pipe1, pipe2), do_multiplex=True)\n",
137 |     "assert isinstance(dl, dp.iter.Multiplexer)\n",
138 |     "test_eq(list(dl), [1, 4, 2, 5, 3, 6])"
139 |    ]
140 |   },
141 |   {
142 |    "cell_type": "code",
143 |    "execution_count": null,
144 |    "id": "current-pilot",
145 |    "metadata": {},
146 |    "outputs": [],
147 |    "source": [
148 |     "#|hide\n",
149 |     "#|eval: false\n",
150 |     "!nbdev_export"
151 |    ]
152 |   },
153 |   {
154 |    "cell_type": "code",
155 |    "execution_count": null,
156 |    "id": "ed71a089",
157 |    "metadata": {},
158 |    "outputs": [],
159 |    "source": []
160 |   }
161 |  ],
162 |  "metadata": {
163 |   "kernelspec": {
164 |    "display_name": "python3",
165 |    "language": "python",
166 |    "name": "python3"
167 |   }
168 |  },
169 |  "nbformat": 4,
170 |  "nbformat_minor": 5
171 | }
172 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | fastrl
2 | ================
3 | 
4 | 
5 | 
6 | [![CI
7 | Status](https://github.com/josiahls/fastrl/workflows/Fastrl%20Testing/badge.svg)](https://github.com/josiahls/fastrl/actions?query=workflow%3A%22Fastrl+Testing%22)
8 | [![pypi fastrl
9 | version](https://img.shields.io/pypi/v/fastrl.svg)](https://pypi.python.org/pypi/fastrl)
10 | [![Docker Image
11 | Latest](https://img.shields.io/docker/v/josiahls/fastrl?label=Docker&sort=date.png)](https://hub.docker.com/repository/docker/josiahls/fastrl)
12 | [![Docker Image-Dev
13 | Latest](https://img.shields.io/docker/v/josiahls/fastrl-dev?label=Docker%20Dev&sort=date.png)](https://hub.docker.com/repository/docker/josiahls/fastrl-dev)
14 | 
15 | [![fastrl python
16 | compatibility](https://img.shields.io/pypi/pyversions/fastrl.svg)](https://pypi.python.org/pypi/fastrl)
17 | [![fastrl
18 | license](https://img.shields.io/pypi/l/fastrl.svg)](https://pypi.python.org/pypi/fastrl)
19 | 
20 | > Warning: This is in alpha, and so uses the latest torch and torchdata,
21 | > very importantly torchdata. 
The base API, while at the point of
22 | > semi-stability, might be changed in future versions, and so there will
23 | > be no promises of backward compatibility. For the time being, it is
24 | > best to hard-pin versions of the library.
25 | 
26 | > Warning: Even before fastrl==2.0.0, all Models should converge
27 | > reasonably fast, however HRL models `DADS` and `DIAYN` will need
28 | > re-balancing and some extra features that the respective authors used.
29 | 
30 | # Overview
31 | 
32 | Fastai for computer vision and tabular learning has been amazing. One
33 | would wish that the same were true for RL. The purpose of this repo
34 | is to have a framework that is as easy as possible to get started with,
35 | but is also designed for testing new agents.
36 | 
37 | This version of fastrl is basically a wrapper around
38 | [torchdata](https://github.com/pytorch/data).
39 | 
40 | It is built around 4 pipeline concepts (half of which come from fastai):
41 | 
42 | - DataLoading/DataBlock pipelines
43 | - Agent pipelines
44 | - Learner pipelines
45 | - Logger plugins
46 | 
47 | Documentation is being served at https://josiahls.github.io/fastrl/ from
48 | documentation directly generated via `nbdev` in this repo.
49 | 
50 | Basic DQN example:
51 | 
52 | ``` python
53 | from fastrl.loggers.core import *
54 | from fastrl.loggers.vscode_visualizers import *
55 | from fastrl.agents.dqn.basic import *
56 | from fastrl.agents.dqn.target import *
57 | from fastrl.data.block import *
58 | from fastrl.envs.gym import *
59 | import torch
60 | ```
61 | 
62 | ``` python
63 | # Setup Loggers
64 | logger_base = ProgressBarLogger(epoch_on_pipe=EpocherCollector,
65 |                  batch_on_pipe=BatchCollector)
66 | 
67 | # Setup up the core NN
68 | torch.manual_seed(0)
69 | model = DQN(4,2)
70 | # Setup the Agent
71 | agent = DQNAgent(model,[logger_base],max_steps=10000)
72 | # Setup the DataBlock
73 | block = DataBlock(
74 |     GymTransformBlock(agent=agent,nsteps=2,nskips=2,firstlast=True), # We basically merge 2 steps into 1 and skip.
75 |     (GymTransformBlock(agent=agent,nsteps=2,nskips=2,firstlast=True,n=100,include_images=True),VSCodeTransformBlock())
76 | )
77 | dls = L(block.dataloaders(['CartPole-v1']*1))
78 | # Setup the Learner
79 | learner = DQNLearner(model,dls,logger_bases=[logger_base],bs=128,max_sz=20_000,nsteps=2,lr=0.001,
80 |                      batches=1000,
81 |                      dp_augmentation_fns=[
82 |                         # Plugin TargetDQN code
83 |                         TargetModelUpdater.insert_dp(),
84 |                         TargetModelQCalc.replace_dp()
85 |                      ])
86 | learner.fit(10)
87 | #learner.validate()
88 | ```
89 | 
90 | # What's new?
91 | 
92 | As we have learned how to support as many RL agents as possible, we
93 | found that `fastrl==1.*` was vastly limited in the models that it can
94 | support. `fastrl==2.*` will leverage the `nbdev` library for better
95 | documentation and more relevant testing, and `torchdata` is the base
96 | lib. We also will be building on the work of the `ptan`[^1]
97 | library as a close reference for pytorch based reinforcement learning
98 | APIs.
99 | 
100 | [^1]: “Shmuma/Ptan”. GitHub, 2020,
101 | https://github.com/Shmuma/ptan. Accessed 13 June 2020.
102 | 
103 | ## Install
104 | 
105 | ### PyPI
106 | 
107 | Below will install the alpha build of fastrl.
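108 | 
109 | Because the alpha moves quickly, it is worth confirming which versions actually
110 | resolved after installing. A quick, illustrative check:
111 | 
112 | ``` python
113 | import fastrl, torch
114 | # fastrl is pre-1.0; pin the exact versions you validated against
115 | print(fastrl.__version__, torch.__version__)
116 | ```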
117 | 
118 | **Cuda Install**
119 | 
120 | `pip install fastrl==0.0.* --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu113`
121 | 
122 | **Cpu Install**
123 | 
124 | `pip install fastrl==0.0.* --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu`
125 | 
126 | ### Docker (highly recommended)
127 | 
128 | Install:
129 | [Nvidia-Docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker)
130 | 
131 | Install: [docker-compose](https://docs.docker.com/compose/install/)
132 | 
133 | ``` bash
134 | docker-compose pull && docker-compose up
135 | ```
136 | 
137 | ## Contributing
138 | 
139 | After you clone this repository, please run `nbdev_install_hooks` in
140 | your terminal. This sets up git hooks, which clean up the notebooks to
141 | remove the extraneous stuff stored in the notebooks (e.g. which cells
142 | you ran) which causes unnecessary merge conflicts.
143 | 
144 | Before submitting a PR, check that the local library and notebooks
145 | match:
146 | 
147 | * The script `nbdev_clean` can let you know if there is a difference
148 |   between the local library and the notebooks.
149 | * If you made a change to the notebooks in one of the exported cells, you
150 |   can export it to the library with `nbdev_build_lib` or `make fastai2`.
151 | * If you made a change to the library, you can export it back to the
152 |   notebooks with `nbdev_update_lib`.
153 | 
--------------------------------------------------------------------------------
/fastrl.Dockerfile:
--------------------------------------------------------------------------------
1 | # FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime
2 | # FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime
3 | FROM nvidia/cuda:12.0.0-runtime-ubuntu20.04
4 | # RUN conda install python=3.8
5 | 
6 | ENV CONTAINER_USER fastrl_user
7 | ENV CONTAINER_GROUP fastrl_group
8 | ENV CONTAINER_UID 1000
9 | # Add user to conda
10 | RUN addgroup --gid $CONTAINER_UID $CONTAINER_GROUP && \
11 |     adduser --uid $CONTAINER_UID --gid $CONTAINER_UID $CONTAINER_USER --disabled-password
12 | # && \
13 | #     mkdir -p /opt/conda && chown $CONTAINER_USER /opt/conda
14 | 
15 | RUN apt-get update && apt-get install -y software-properties-common rsync curl gcc g++
16 | #RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-key C99B11DEB97541F0 && apt-add-repository https://cli.github.com/packages
17 | 
18 | RUN curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg | gpg --dearmor -o /usr/share/keyrings/githubcli-archive-keyring.gpg
19 | RUN echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null
20 | 
21 | RUN apt-get install -y build-essential python3.8-dev python3.8-distutils
22 | RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.8 2 && update-alternatives --config python
23 | RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3.8 get-pip.py
24 | 
25 | RUN apt-get update && apt-get install -y git libglib2.0-dev graphviz libxext6 \
26 |     libsm6 libxrender1 python-opengl xvfb nano gh tree wget libosmesa6-dev \
27 |     libgl1-mesa-glx libglfw3 && apt-get update
28 | 
29 | 
30 | WORKDIR /home/$CONTAINER_USER
31 | # Install Primary Pip Reqs
32 | # ENV PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/nightly/cu117
33 | COPY --chown=$CONTAINER_USER:$CONTAINER_GROUP extra/requirements.txt /home/$CONTAINER_USER/extra/requirements.txt
34 | 
# Since we are using a custom fork of torchdata, we install torchdata as part of a submodule. 35 | # RUN pip3 install -r extra/requirements.txt --pre --upgrade 36 | RUN pip3 install torch>=2.0.0 37 | # --pre --upgrade 38 | RUN pip3 show torch 39 | 40 | COPY --chown=$CONTAINER_USER:$CONTAINER_GROUP extra/pip_requirements.txt /home/$CONTAINER_USER/extra/pip_requirements.txt 41 | RUN pip3 install -r extra/pip_requirements.txt 42 | 43 | WORKDIR /home/$CONTAINER_USER/fastrl 44 | RUN git clone https://github.com/josiahls/data.git \ 45 | && cd data && pip3 install -e . 46 | WORKDIR /home/$CONTAINER_USER 47 | 48 | # Install Dev Reqs 49 | COPY --chown=$CONTAINER_USER:$CONTAINER_GROUP extra/dev_requirements.txt /home/$CONTAINER_USER/extra/dev_requirements.txt 50 | ARG BUILD=dev 51 | # Needed for gymnasium[all] when installing Box2d 52 | RUN apt-get install swig3.0 && ln -s /usr/bin/swig3.0 /usr/bin/swig 53 | RUN /bin/bash -c "if [[ $BUILD == 'dev' ]] ; then echo \"Development Build\" && pip3 install -r extra/dev_requirements.txt ; fi" 54 | # RUN /bin/bash -c "if [[ $BUILD == 'dev' ]] ; then wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz; fi" 55 | # RUN /bin/bash -c "if [[ $BUILD == 'dev' ]] ; then echo \"Development Build\" && conda install -c conda-forge nodejs==15.14.0 line_profiler && jupyter labextension install jupyterlab-plotly; fi" 56 | 57 | # RUN chown $CONTAINER_USER:$CONTAINER_GROUP -R /opt/conda/bin 58 | RUN chown $CONTAINER_USER:$CONTAINER_GROUP -R /usr/local/lib/python3.8/dist-packages/torch/utils/data/datapipes 59 | # RUN chown $CONTAINER_USER:$CONTAINER_GROUP -R /usr/local/lib/python3.8/dist-packagess/mujoco_py 60 | # RUN chown $CONTAINER_USER:$CONTAINER_GROUP -R /usr/local/lib/python3.8/dist-packages 61 | 62 | RUN pip3 show torch 63 | 64 | RUN chown $CONTAINER_USER:$CONTAINER_GROUP -R /home/$CONTAINER_USER 65 | 66 | RUN apt-get install sudo 67 | # Give user password-less sudo access 68 | RUN echo "$CONTAINER_USER ALL=(ALL) NOPASSWD: ALL" > /etc/sudoers.d/$CONTAINER_USER && \ 69 | chmod 0440 /etc/sudoers.d/$CONTAINER_USER 70 | 71 | RUN /bin/bash -c "if [[ $BUILD == 'dev' ]] ; then nbdev_install_quarto ; fi" 72 | 73 | # RUN mkdir -p /home/$CONTAINER_USER/.mujoco \ 74 | # && wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz \ 75 | # && tar -xf mujoco.tar.gz -C /home/$CONTAINER_USER/.mujoco \ 76 | # && rm mujoco.tar.gz 77 | 78 | # ENV LD_LIBRARY_PATH /home/$CONTAINER_USER/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH} 79 | # ENV LD_LIBRARY_PATH /usr/local/nvidia/lib64:${LD_LIBRARY_PATH} 80 | 81 | # RUN ln -s /usr/lib/x86_64-linux-gnu/libGL.so.1 /usr/lib/x86_64-linux-gnu/libGL.so 82 | 83 | USER $CONTAINER_USER 84 | WORKDIR /home/$CONTAINER_USER 85 | ENV PATH="/home/$CONTAINER_USER/.local/bin:${PATH}" 86 | 87 | # RUN git clone https://github.com/josiahls/fastrl.git --depth 1 88 | RUN pip install setuptools==60.7.0 89 | COPY --chown=$CONTAINER_USER:$CONTAINER_GROUP . 
fastrl 90 | 91 | RUN sudo apt-get -y install cmake python3.8-venv 92 | 93 | # RUN curl https://get.modular.com | sh - && \ 94 | # modular auth mut_9b52dfea7b05427385fdeddc85dd3a64 && \ 95 | # modular install mojo 96 | 97 | RUN BASHRC=$( [ -f "$HOME/.bash_profile" ] && echo "$HOME/.bash_profile" || echo "$HOME/.bashrc" ) && \ 98 | echo 'export MODULAR_HOME="/home/fastrl_user/.modular"' >> "$BASHRC" && \ 99 | echo 'export PATH="/home/fastrl_user/.modular/pkg/packages.modular.com_mojo/bin:$PATH"' >> "$BASHRC" && \ 100 | source "$BASHRC" 101 | 102 | # RUN /bin/bash -c "if [[ $BUILD == 'dev' ]] ; then echo \"Development Build\" && cd fastrl/data && mv pyproject.toml pyproject.toml_tmp && pip install -e . --no-dependencies && mv pyproject.toml_tmp pyproject.toml && cd ../; fi" 103 | 104 | RUN /bin/bash -c "if [[ $BUILD == 'prod' ]] ; then echo \"Production Build\" && cd fastrl && pip install . --no-dependencies; fi" 105 | RUN /bin/bash -c "if [[ $BUILD == 'dev' ]] ; then echo \"Development Build\" && cd fastrl && pip install -e \".[dev]\" --no-dependencies ; fi" 106 | 107 | # RUN echo '#!/bin/bash\npip install -e .[dev] --no-dependencies && xvfb-run -s "-screen 0 1400x900x24" jupyter lab --ip=0.0.0.0 --port=8080 --allow-root --no-browser --NotebookApp.token='' --NotebookApp.password=''' >> run_jupyter.sh 108 | 109 | RUN /bin/bash -c "cd fastrl && pip install -e . --no-dependencies" 110 | 111 | -------------------------------------------------------------------------------- /nbs/05_Logging/09d_loggers.jupyter_visualizers.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "durable-dialogue", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|hide\n", 11 | "from fastrl.test_utils import initialize_notebook\n", 12 | "initialize_notebook()" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "assisted-contract", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "#|export\n", 23 | "# Python native modules\n", 24 | "import os\n", 25 | "from torch.multiprocessing import Queue\n", 26 | "from typing import Tuple,NamedTuple\n", 27 | "# Third party libs\n", 28 | "from fastcore.all import add_docs\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "import torchdata.datapipes as dp\n", 31 | "from IPython.core.display import clear_output\n", 32 | "import torch\n", 33 | "import numpy as np\n", 34 | "# Local modules\n", 35 | "from fastrl.core import Record\n", 36 | "from fastrl.loggers.core import LoggerBase,LogCollector,is_record\n", 37 | "# from fastrl.torch_core import *" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "id": "offshore-stuart", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "#|default_exp loggers.jupyter_visualizers" 48 | ] 49 | }, 50 | { 51 | "attachments": {}, 52 | "cell_type": "markdown", 53 | "id": "lesser-innocent", 54 | "metadata": {}, 55 | "source": [ 56 | "# Visualizers \n", 57 | "> Iterable pipes for displaying environments as they run using `typing.NamedTuples` with `image` fields" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "id": "9cb488d2-41f1-4160-a938-e39003f1a06a", 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "#|export\n", 68 | "class SimpleJupyterVideoPlayer(dp.iter.IterDataPipe):\n", 69 | " def __init__(self, \n", 70 | " source_datapipe=None, \n", 71 | " 
between_frame_wait_seconds:float=0.1 # intended delay between displayed frames\n",
72 |     "    ):\n",
73 |     "        self.source_datapipe = source_datapipe\n",
74 |     "        self.between_frame_wait_seconds = between_frame_wait_seconds # stored but currently unused; no frame delay is applied yet\n",
75 |     "        self.buffer = [] # backing store for `dequeue`, which referenced it without initializing it\n",
76 |     "\n",
77 |     "    def dequeue(self):\n",
78 |     "        while self.buffer: yield self.buffer.pop(0)\n",
79 |     "    \n",
80 |     "    def __iter__(self) -> Tuple[NamedTuple]:\n",
81 |     "        img = None\n",
82 |     "        for record in self.source_datapipe:\n",
83 |     "            # for o in self.dequeue():\n",
84 |     "            if is_record(record):\n",
85 |     "                if record.value is None: continue\n",
86 |     "                if img is None: img = plt.imshow(record.value)\n",
87 |     "                img.set_data(record.value)\n",
88 |     "                plt.axis('off')\n",
89 |     "                display(plt.gcf())\n",
90 |     "                clear_output(wait=True)\n",
91 |     "            yield record\n",
92 |     "add_docs(\n",
93 |     "    SimpleJupyterVideoPlayer,\n",
94 |     "    \"\"\"Displays video from a `source_datapipe` that produces `typing.NamedTuples` that contain an `image` field.\n",
95 |     "    This can only handle 1 env input.\"\"\",\n",
96 |     "    dequeue=\"Yields and clears any records buffered in `self.buffer`\"\n",
97 |     ")"
98 |    ]
99 |   },
100 |   {
101 |    "cell_type": "code",
102 |    "execution_count": null,
103 |    "id": "d4e46c41-f7f9-4168-b453-c43ec80377f0",
104 |    "metadata": {},
105 |    "outputs": [],
106 |    "source": [
107 |     "#|export\n",
108 |     "class ImageCollector(dp.iter.IterDataPipe):\n",
109 |     "    title:str='image'\n",
110 |     "\n",
111 |     "    def __init__(self,source_datapipe):\n",
112 |     "        self.source_datapipe = source_datapipe\n",
113 |     "\n",
114 |     "    def convert_np(self,o):\n",
115 |     "        if isinstance(o,torch.Tensor): return o.detach().numpy()\n",
116 |     "        elif isinstance(o,np.ndarray): return o\n",
117 |     "        else: raise ValueError(f'Expects Tensor or np.ndarray not {type(o)}')\n",
118 |     "    \n",
119 |     "    def __iter__(self):\n",
120 |     "        # for q in self.main_buffers: q.append(Record('image',None))\n",
121 |     "        yield Record(self.title,None)\n",
122 |     "        for steps in self.source_datapipe:\n",
123 |     "            if isinstance(steps,dp.DataChunk):\n",
124 |     "                for step in steps:\n",
125 |     "                    yield Record(self.title,self.convert_np(step.image))\n",
126 |     "            else:\n",
127 |     "                yield Record(self.title,self.convert_np(steps.image))\n",
128 |     "            yield steps"
129 |    ]
130 |   },
131 |   {
132 |    "cell_type": "code",
133 |    "execution_count": null,
134 |    "id": "1ade24c0",
135 |    "metadata": {},
136 |    "outputs": [],
137 |    "source": [
138 |     "from fastrl.envs.gym import GymDataPipe"
139 |    ]
140 |   },
141 |   {
142 |    "cell_type": "code",
143 |    "execution_count": null,
144 |    "id": "6a4b9ca9-2027-40a1-ac97-4516d60479a2",
145 |    "metadata": {},
146 |    "outputs": [],
147 |    "source": [
148 |     "#|eval:false\n",
149 |     "%matplotlib inline"
150 |    ]
151 |   },
152 |   {
153 |    "cell_type": "code",
154 |    "execution_count": null,
155 |    "id": "ead27188-7322-46c2-9300-59842af2386d",
156 |    "metadata": {},
157 |    "outputs": [],
158 |    "source": [
159 |     "\n",
160 |     "pipe = GymDataPipe(['CartPole-v1'],None,n=100,seed=0,include_images=True)\n",
161 |     "pipe = ImageCollector(pipe)\n",
162 |     "pipe = SimpleJupyterVideoPlayer(pipe)\n",
163 |     "\n",
164 |     "for o in pipe: pass"
165 |    ]
166 |   },
167 |   {
168 |    "cell_type": "code",
169 |    "execution_count": null,
170 |    "id": "current-pilot",
171 |    "metadata": {},
172 |    "outputs": [],
173 |    "source": [
174 |     "#|hide\n",
175 |     "#|eval: false\n",
176 |     "!nbdev_export"
177 |    ]
178 |   },
179 |   {
180 |    "cell_type": "code",
181 |    "execution_count": null,
182 |    "id": "e0d82468-a2bf-4bfd-9ac7-e56db49b8476",
183 |    "metadata": {},
184 |    "outputs": [],
185 |    "source": []
186 |   }
187 |  ],
188 |  "metadata": {
189 |   "kernelspec": {
190 |    
"display_name": "python3", 191 | "language": "python", 192 | "name": "python3" 193 | } 194 | }, 195 | "nbformat": 4, 196 | "nbformat_minor": 5 197 | } 198 | -------------------------------------------------------------------------------- /fastrl/test_utils.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/20_test_utils.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['get_env', 'try_import', 'nvidia_mem', 'nvidia_smi', 'initialize_notebook', 'show_install'] 5 | 6 | # %% ../nbs/20_test_utils.ipynb 1 7 | # Python native modules 8 | import os 9 | import re 10 | import sys 11 | import importlib 12 | # Third party libs 13 | 14 | # Local modules 15 | 16 | # %% ../nbs/20_test_utils.ipynb 4 17 | def get_env(name): 18 | "Return env var value if it's defined and not an empty string, or return Unknown" 19 | res = os.environ.get(name,'') 20 | return res if len(res) else "Unknown" 21 | 22 | # %% ../nbs/20_test_utils.ipynb 5 23 | def try_import(module): 24 | "Try to import `module`. Returns module's object on success, None on failure" 25 | try: return importlib.import_module(module) 26 | except: return None 27 | 28 | # %% ../nbs/20_test_utils.ipynb 6 29 | def nvidia_mem(): 30 | from fastcore.all import run 31 | try: mem = run("nvidia-smi --query-gpu=memory.total --format=csv,nounits,noheader") 32 | except: return None 33 | return mem.strip().split('\n') 34 | 35 | # %% ../nbs/20_test_utils.ipynb 7 36 | def nvidia_smi(cmd = "nvidia-smi"): 37 | from fastcore.all import run 38 | try: res = run(cmd) 39 | except OSError as e: return None 40 | return res 41 | 42 | # %% ../nbs/20_test_utils.ipynb 8 43 | def initialize_notebook(): 44 | """ 45 | Function to initialize the notebook environment considering whether it is in Colab or not. 46 | It handles installation of necessary packages and setting up the environment variables. 
47 | """ 48 | 49 | # Checking if the environment is Google Colab 50 | if os.path.exists("/content"): 51 | # Installing necessary packages 52 | os.system("pip install -Uqq fastrl['dev'] pyvirtualdisplay") 53 | os.system("apt-get install -y xvfb python-opengl > /dev/null 2>&1") 54 | 55 | # Starting a virtual display 56 | from pyvirtualdisplay import Display 57 | display = Display(visible=0, size=(400, 300)) 58 | display.start() 59 | 60 | else: 61 | # If not in Colab, importing necessary packages and checking environment variables 62 | from nbdev.showdoc import show_doc 63 | from nbdev.imports import IN_NOTEBOOK, IN_COLAB, IN_IPYTHON 64 | 65 | # Asserting the environment variables 66 | if not os.environ.get("IN_TEST", None): 67 | assert IN_NOTEBOOK 68 | assert not IN_COLAB 69 | assert IN_IPYTHON 70 | 71 | 72 | # %% ../nbs/20_test_utils.ipynb 9 73 | def show_install(show_nvidia_smi:bool=False): 74 | "Print user's setup information" 75 | 76 | # import fastai 77 | import platform 78 | import fastprogress 79 | import fastcore 80 | import fastrl 81 | import torch 82 | from fastcore.all import ifnone 83 | 84 | 85 | rep = [] 86 | opt_mods = [] 87 | 88 | rep.append(["=== Software ===", None]) 89 | rep.append(["python", platform.python_version()]) 90 | rep.append(["fastrl", fastrl.__version__]) 91 | # rep.append(["fastai", fastai.__version__]) 92 | rep.append(["fastcore", fastcore.__version__]) 93 | rep.append(["fastprogress", fastprogress.__version__]) 94 | rep.append(["torch", torch.__version__]) 95 | 96 | # nvidia-smi 97 | smi = nvidia_smi() 98 | if smi: 99 | match = re.findall(r'Driver Version: +(\d+\.\d+)', smi) 100 | if match: rep.append(["nvidia driver", match[0]]) 101 | 102 | available = "available" if torch.cuda.is_available() else "**Not available** " 103 | rep.append(["torch cuda", f"{torch.version.cuda} / is {available}"]) 104 | 105 | # no point reporting on cudnn if cuda is not available, as it 106 | # seems to be enabled at times even on cpu-only setups 107 | if torch.cuda.is_available(): 108 | enabled = "enabled" if torch.backends.cudnn.enabled else "**Not enabled** " 109 | rep.append(["torch cudnn", f"{torch.backends.cudnn.version()} / is {enabled}"]) 110 | 111 | rep.append(["\n=== Hardware ===", None]) 112 | 113 | gpu_total_mem = [] 114 | nvidia_gpu_cnt = 0 115 | if smi: 116 | mem = nvidia_mem() 117 | nvidia_gpu_cnt = len(ifnone(mem, [])) 118 | 119 | if nvidia_gpu_cnt: rep.append(["nvidia gpus", nvidia_gpu_cnt]) 120 | 121 | torch_gpu_cnt = torch.cuda.device_count() 122 | if torch_gpu_cnt: 123 | rep.append(["torch devices", torch_gpu_cnt]) 124 | # information for each gpu 125 | for i in range(torch_gpu_cnt): 126 | rep.append([f" - gpu{i}", (f"{gpu_total_mem[i]}MB | " if gpu_total_mem else "") + torch.cuda.get_device_name(i)]) 127 | else: 128 | if nvidia_gpu_cnt: 129 | rep.append([f"Have {nvidia_gpu_cnt} GPU(s), but torch can't use them (check nvidia driver)", None]) 130 | else: 131 | rep.append([f"No GPUs available", None]) 132 | 133 | 134 | rep.append(["\n=== Environment ===", None]) 135 | 136 | rep.append(["platform", platform.platform()]) 137 | 138 | if platform.system() == 'Linux': 139 | distro = try_import('distro') 140 | if distro: 141 | # full distro info 142 | rep.append(["distro", ' '.join(distro.linux_distribution())]) 143 | else: 144 | opt_mods.append('distro'); 145 | # partial distro info 146 | rep.append(["distro", platform.uname().version]) 147 | 148 | rep.append(["conda env", get_env('CONDA_DEFAULT_ENV')]) 149 | rep.append(["python", sys.executable]) 150 | 
rep.append(["sys.path", "\n".join(sys.path)]) 151 | 152 | print("\n\n```text") 153 | 154 | keylen = max([len(e[0]) for e in rep if e[1] is not None]) 155 | for e in rep: 156 | print(f"{e[0]:{keylen}}", (f": {e[1]}" if e[1] is not None else "")) 157 | 158 | if smi: 159 | if show_nvidia_smi: print(f"\n{smi}") 160 | else: 161 | if torch_gpu_cnt: print("no nvidia-smi is found") 162 | else: print("no supported gpus found on this system") 163 | 164 | print("```\n") 165 | 166 | print("Please make sure to include opening/closing ``` when you paste into forums/github to make the reports appear formatted as code sections.\n") 167 | 168 | if opt_mods: 169 | print("Optional package(s) to enhance the diagnostics can be installed with:") 170 | print(f"pip install {' '.join(opt_mods)}") 171 | print("Once installed, re-run this utility to get the additional information") 172 | -------------------------------------------------------------------------------- /nbs/11_FAQ/99_notes.multi_proc.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "durable-dialogue", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|hide\n", 11 | "#|eval: false\n", 12 | "! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \\\n", 13 | " apt-get install -y xvfb python-opengl > /dev/null 2>&1 \n", 14 | "# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "id": "viral-cambridge", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "#|hide\n", 25 | "#|eval: false\n", 26 | "from fastcore.imports import in_colab\n", 27 | "# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to\n", 28 | "if not in_colab():\n", 29 | " from nbdev.showdoc import *\n", 30 | " from nbdev.imports import *\n", 31 | " if not os.environ.get(\"IN_TEST\", None):\n", 32 | " assert IN_NOTEBOOK\n", 33 | " assert not IN_COLAB\n", 34 | " assert IN_IPYTHON\n", 35 | "else:\n", 36 | " # Virutual display is needed for colab\n", 37 | " from pyvirtualdisplay import Display\n", 38 | " display = Display(visible=0, size=(400, 300))\n", 39 | " display.start()" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "assisted-contract", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "# Python native modules\n", 50 | "import os\n", 51 | "from copy import deepcopy\n", 52 | "# Third party libs\n", 53 | "from fastcore.all import *\n", 54 | "import numpy as np\n", 55 | "# Local modules\n" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "id": "8681dfae-2d75-4936-8ab3-eb1bf1727312", 61 | "metadata": {}, 62 | "source": [ 63 | "## MultiProcessing Notes" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "id": "2efbf0be-7a0c-4d7f-b876-df06e410dae7", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "import torchdata.datapipes as dp" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "id": "139af462-4cee-42a0-9b8c-037f413ecf10", 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "%%writefile ../external_run_scripts/notes_multi_proc_82.py\n", 84 | "import torchdata.datapipes as dp\n", 85 | "from torch.utils.data import IterableDataset\n", 86 | "\n", 87 | "class AddABunch1(dp.iter.IterDataPipe):\n", 88 | " def 
__init__(self,q):\n", 89 | " super().__init__()\n", 90 | " self.q = [q]\n", 91 | "\n", 92 | " def __iter__(self):\n", 93 | " for o in range(10): \n", 94 | " self.q[0].put(o)\n", 95 | " yield o\n", 96 | " \n", 97 | "class AddABunch2(dp.iter.IterDataPipe):\n", 98 | " def __init__(self,source_datapipe,q):\n", 99 | " super().__init__()\n", 100 | " self.q = q\n", 101 | " print(id(self.q))\n", 102 | " self.source_datapipe = source_datapipe\n", 103 | "\n", 104 | " def __iter__(self):\n", 105 | " for o in self.source_datapipe: \n", 106 | " print(id(self.q))\n", 107 | " self.q.put(o)\n", 108 | " yield o\n", 109 | " \n", 110 | "class AddABunch3(IterableDataset):\n", 111 | " def __init__(self,q):\n", 112 | " self.q = q\n", 113 | "\n", 114 | " def __iter__(self):\n", 115 | " for o in range(10): \n", 116 | " print(id(self.q))\n", 117 | " self.q.put(o)\n", 118 | " yield o\n", 119 | "\n", 120 | "if __name__=='__main__':\n", 121 | " from torch.multiprocessing import Pool,Process,set_start_method,Manager,get_start_method\n", 122 | " import torch\n", 123 | " \n", 124 | " try: set_start_method('spawn')\n", 125 | " except RuntimeError: pass\n", 126 | " # from torch.utils.data.dataloader_experimental import DataLoader2\n", 127 | " from torchdata.dataloader2 import DataLoader2\n", 128 | " from torchdata.dataloader2.reading_service import MultiProcessingReadingService\n", 129 | "\n", 130 | " m = Manager()\n", 131 | " q = m.Queue()\n", 132 | " \n", 133 | " pipe = AddABunch2(list(range(10)),q)\n", 134 | " print(type(pipe))\n", 135 | " dl = DataLoader2(pipe,\n", 136 | " reading_service=MultiProcessingReadingService(num_workers=1)\n", 137 | " ) # Will fail if num_workers>0\n", 138 | " \n", 139 | " # dl = DataLoader2(AddABunch1(q),num_workers=1) # Will fail if num_workers>0\n", 140 | " # dl = DataLoader2(AddABunch2(q),num_workers=1) # Will fail if num_workers>0\n", 141 | " # dl = DataLoader2(AddABunch3(q),num_workers=1) # Will succeed if num_workers>0\n", 142 | " list(dl)\n", 143 | " \n", 144 | " while not q.empty():\n", 145 | " print(q.get())" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "id": "a3b223ee-40ab-4eea-a7cc-35d659301b14", 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "from torch.multiprocessing import Pool,Process,set_start_method,Manager,get_start_method" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "988390db-b284-494d-a413-2fc12b6fa032", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "get_start_method()" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "current-pilot", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "#|hide\n", 176 | "#|eval: false\n", 177 | "from fastcore.imports import in_colab\n", 178 | "\n", 179 | "# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to\n", 180 | "if not in_colab():\n", 181 | " from nbdev import nbdev_export\n", 182 | " nbdev_export()" 183 | ] 184 | } 185 | ], 186 | "metadata": { 187 | "kernelspec": { 188 | "display_name": "python3", 189 | "language": "python", 190 | "name": "python3" 191 | } 192 | }, 193 | "nbformat": 4, 194 | "nbformat_minor": 5 195 | } 196 | -------------------------------------------------------------------------------- /fastrl/learner/core.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/06_Learning/10a_learner.core.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['LearnerBase', 'LearnerHead', 'StepBatcher'] 5 | 6 | # %% ../../nbs/06_Learning/10a_learner.core.ipynb 2 7 | # Python native modules 8 | import os 9 | from contextlib import contextmanager 10 | from typing import List,Union,Dict,Optional,Iterable,Tuple 11 | # Third party libs 12 | from fastcore.all import add_docs 13 | import torchdata.datapipes as dp 14 | from torchdata.dataloader2.graph import list_dps 15 | import torch 16 | from torch import nn 17 | from torchdata.dataloader2 import DataLoader2 18 | from torchdata.dataloader2.graph import traverse_dps,DataPipeGraph,DataPipe 19 | # Local modules 20 | from ..torch_core import evaluating 21 | from ..pipes.core import find_dp 22 | from ..loggers.core import Record,EpochCollector,BatchCollector 23 | 24 | # %% ../../nbs/06_Learning/10a_learner.core.ipynb 4 25 | class LearnerBase(dp.iter.IterDataPipe): 26 | def __init__(self, 27 | # The base NN that we getting raw action values out of. 28 | # This can either be a `nn.Module` or a dict of multiple `nn.Module`s 29 | # For multimodel training 30 | model:Union[nn.Module,Dict[str,nn.Module]], 31 | # The dataloaders to read data from for training. This can be a single 32 | # DataLoader2 or an iterable that yields from a DataLoader2. 33 | dls:Union[DataLoader2,Iterable], 34 | # By default for reinforcement learning, we want to keep the workers 35 | # alive so that simluations are not being shutdown / restarted. 36 | # Epochs are expected to be handled semantically via tracking the number 37 | # of batches. 38 | infinite_dls:bool=True 39 | ): 40 | self.model = model 41 | self.iterable = dls 42 | self.learner_base = self 43 | self.infinite_dls = infinite_dls 44 | self._dls = None 45 | self._ended = False 46 | 47 | def __getstate__(self): 48 | state = {k:v for k,v in self.__dict__.items() if k not in ['_dls']} 49 | # TODO: Needs a better way to serialize / deserialize states. 50 | # state['iterable'] = [d.state_dict() for d in state['iterable']] 51 | if dp.iter.IterDataPipe.getstate_hook is not None: 52 | return dp.iter.IterDataPipe.getstate_hook(state) 53 | return state 54 | 55 | def __setstate__(self, state): 56 | # state['iterable'] = [d.from_state_dict() for d in state['iterable']] 57 | for k,v in state.items(): 58 | setattr(self,k,v) 59 | 60 | def end(self): 61 | self._ended = True 62 | 63 | def __iter__(self): 64 | self._ended = False 65 | for data in self.iterable: 66 | if self._ended: 67 | break 68 | yield data 69 | 70 | add_docs( 71 | LearnerBase, 72 | "Combines models,dataloaders, and optimizers together for running a training pipeline.", 73 | reset="""If `infinite_dls` is false, then all dls will be reset, otherwise they will be 74 | kept alive.""", 75 | end="When called, will cause the Learner to stop iterating and cleanup." 
76 | ) 77 | 78 | # %% ../../nbs/06_Learning/10a_learner.core.ipynb 5 79 | class LearnerHead(dp.iter.IterDataPipe): 80 | def __init__( 81 | self, 82 | source_datapipes:Tuple[dp.iter.IterDataPipe] 83 | ): 84 | if not isinstance(source_datapipes,tuple): 85 | self.source_datapipes = (source_datapipes,) 86 | else: 87 | self.source_datapipes = source_datapipes 88 | self.dp_idx = 0 89 | 90 | def __iter__(self): yield from self.source_datapipes[self.dp_idx] 91 | 92 | def fit(self,epochs): 93 | self.dp_idx = 0 94 | epocher = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),EpochCollector) 95 | learner = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),LearnerBase) 96 | epocher.epochs = epochs 97 | if isinstance(learner.model,dict): 98 | for m in learner.model.values(): 99 | m.train() 100 | else: 101 | learner.model.train() 102 | for _ in self: pass 103 | 104 | def validate(self,epochs=1,batches=100,show=True,return_outputs=False) -> DataPipe: 105 | self.dp_idx = 1 106 | epocher = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),EpochCollector) 107 | epocher.epochs = epochs 108 | batcher = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),BatchCollector) 109 | batcher.batches = batches 110 | learner = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),LearnerBase) 111 | model = learner.model 112 | model = tuple(model.values()) if isinstance(model,dict) else model 113 | with evaluating(model): 114 | if return_outputs: 115 | return list(self) 116 | else: 117 | for _ in self: pass 118 | if show: 119 | pipes = list_dps(traverse_dps(self.source_datapipes[self.dp_idx])) 120 | for pipe in pipes: 121 | if hasattr(pipe,'show'): 122 | return pipe.show() 123 | 124 | add_docs( 125 | LearnerHead, 126 | """LearnerHead can connect to multiple `LearnerBase`s and handles training 127 | and validation execution. 
128 | """, 129 | fit="Runs the `LearnerHead` pipeline for `epochs`", 130 | validate="""If there is more than 1 dl, then run 1 epoch of that dl based on 131 | `dl_idx` and returns the original datapipe for displaying.""" 132 | ) 133 | 134 | # %% ../../nbs/06_Learning/10a_learner.core.ipynb 18 135 | class StepBatcher(dp.iter.IterDataPipe): 136 | def __init__(self, 137 | source_datapipe, 138 | device=None 139 | ): 140 | self.source_datapipe = source_datapipe 141 | self.device = device 142 | 143 | def vstack_by_fld(self,batch,fld): 144 | try: 145 | t = torch.vstack(tuple(getattr(step,fld) for step in batch)) 146 | # if self.device is not None: 147 | # t = t.to(torch.device(self.device)) 148 | t.requires_grad = False 149 | return t 150 | except RuntimeError as e: 151 | print(f'Failed to stack {fld} given batch: {batch}') 152 | raise 153 | 154 | def __iter__(self): 155 | for batch in self.source_datapipe: 156 | cls = batch[0].__class__ 157 | batched_step = cls(**{fld:self.vstack_by_fld(batch,fld) for fld in cls.__dataclass_fields__},batch_size=[len(batch)]) 158 | if self.device is not None: 159 | batched_step = batched_step.to(self.device) 160 | yield batched_step 161 | 162 | add_docs( 163 | StepBatcher, 164 | "Converts multiple `StepType` into a single `StepType` with the fields concated.", 165 | vstack_by_fld="vstacks a `fld` in `batch`" 166 | ) 167 | -------------------------------------------------------------------------------- /nbs/11_FAQ/99_notes.speed.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "durable-dialogue", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|hide\n", 11 | "#|eval: false\n", 12 | "! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \\\n", 13 | " apt-get install -y xvfb python-opengl > /dev/null 2>&1 \n", 14 | "# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. 
COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS"
 15 |    ]
 16 |   },
 17 |   {
 18 |    "cell_type": "code",
 19 |    "execution_count": null,
 20 |    "id": "viral-cambridge",
 21 |    "metadata": {},
 22 |    "outputs": [],
 23 |    "source": [
 24 |     "#|hide\n",
 25 |     "#|eval: false\n",
 26 |     "from fastcore.imports import in_colab\n",
 27 |     "# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to\n",
 28 |     "if not in_colab():\n",
 29 |     "    from nbdev.showdoc import *\n",
 30 |     "    from nbdev.imports import *\n",
 31 |     "    if not os.environ.get(\"IN_TEST\", None):\n",
 32 |     "        assert IN_NOTEBOOK\n",
 33 |     "        assert not IN_COLAB\n",
 34 |     "        assert IN_IPYTHON\n",
 35 |     "else:\n",
 36 |     "    # Virtual display is needed for colab\n",
 37 |     "    from pyvirtualdisplay import Display\n",
 38 |     "    display = Display(visible=0, size=(400, 300))\n",
 39 |     "    display.start()"
 40 |    ]
 41 |   },
 42 |   {
 43 |    "cell_type": "code",
 44 |    "execution_count": null,
 45 |    "id": "assisted-contract",
 46 |    "metadata": {},
 47 |    "outputs": [],
 48 |    "source": [
 49 |     "# Python native modules\n",
 50 |     "import os\n",
 51 |     "from copy import deepcopy\n",
 52 |     "# Third party libs\n",
 53 |     "from fastcore.all import *\n",
 54 |     "import numpy as np\n",
    "from torch import Tensor\n",
 55 |     "# Local modules\n"
 56 |    ]
 57 |   },
 58 |   {
 59 |    "cell_type": "markdown",
 60 |    "id": "lesser-innocent",
 61 |    "metadata": {},
 62 |    "source": [
 63 |     "# Speed\n",
 64 |     "> Some obvious / not so obvious notes on speed"
 65 |    ]
 66 |   },
 67 |   {
 68 |    "cell_type": "markdown",
 69 |    "id": "4cce6de1-8643-4338-b65e-60406a35cb3a",
 70 |    "metadata": {},
 71 |    "source": [
 72 |     "## Numpy to Tensor Performance"
 73 |    ]
 74 |   },
 75 |   {
 76 |    "cell_type": "code",
 77 |    "execution_count": null,
 78 |    "id": "f51e6709-ae53-474b-b975-914bd36159b2",
 79 |    "metadata": {},
 80 |    "outputs": [],
 81 |    "source": [
 82 |     "img=np.random.randint(0,255,size=(240, 320, 3))"
 83 |    ]
 84 |   },
 85 |   {
 86 |    "cell_type": "code",
 87 |    "execution_count": null,
 88 |    "id": "202749e2-72dd-4baa-b5e5-8217f297c0a3",
 89 |    "metadata": {},
 90 |    "outputs": [],
 91 |    "source": [
 92 |     "%%timeit\n",
 93 |     "#|eval: false\n",
 94 |     "img=np.random.randint(0,255,size=(240, 320, 3))"
 95 |    ]
 96 |   },
 97 |   {
 98 |    "cell_type": "code",
 99 |    "execution_count": null,
 100 |    "id": "00356113-ef79-42aa-a746-9eca3ffe65bc",
 101 |    "metadata": {},
 102 |    "outputs": [],
 103 |    "source": [
 104 |     "%%timeit\n",
 105 |     "#|eval: false\n",
 106 |     "deepcopy(img)"
 107 |    ]
 108 |   },
 109 |   {
 110 |    "cell_type": "code",
 111 |    "execution_count": null,
 112 |    "id": "45a4e034-9a2b-4e4a-b1b9-71d23b12155b",
 113 |    "metadata": {},
 114 |    "outputs": [],
 115 |    "source": [
 116 |     "%%timeit\n",
 117 |     "#|eval: false\n",
 118 |     "Tensor(img)"
 119 |    ]
 120 |   },
 121 |   {
 122 |    "cell_type": "code",
 123 |    "execution_count": null,
 124 |    "id": "1acadadb-94e3-4be6-841a-2045b1da8767",
 125 |    "metadata": {},
 126 |    "outputs": [],
 127 |    "source": [
 128 |     "%%timeit\n",
 129 |     "#|eval: false\n",
 130 |     "Tensor([img])"
 131 |    ]
 132 |   },
 133 |   {
 134 |    "cell_type": "markdown",
 135 |    "id": "9d5040b2-6670-495b-abf7-fe5a26c9eda0",
 136 |    "metadata": {},
 137 |    "source": [
 138 |     "You will notice that wrapping a numpy array in a list completely kills performance. The solution is to\n",
 139 |     "just add a batch dim to the existing array and pass it directly."
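,
    "\n",
    "For example, `Tensor(np.expand_dims(img,0))` (timed below) keeps the data as one contiguous, typed buffer that torch can copy wholesale, whereas `Tensor([img])` forces an element-by-element conversion of the wrapped list."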
 140 |    ]
 141 |   },
 142 |   {
 143 |    "cell_type": "code",
 144 |    "execution_count": null,
 145 |    "id": "f1007f5b-db0d-4152-b474-e6e319bd81e5",
 146 |    "metadata": {},
 147 |    "outputs": [],
 148 |    "source": [
 149 |     "%%timeit\n",
 150 |     "#|eval: false\n",
 151 |     "Tensor(np.expand_dims(img,0))"
 152 |    ]
 153 |   },
 154 |   {
 155 |    "cell_type": "markdown",
 156 |    "id": "ac025532-f917-4ded-9ec7-903373712c16",
 157 |    "metadata": {},
 158 |    "source": [
 159 |     "In fact, we can test this with plain python lists..."
 160 |    ]
 161 |   },
 162 |   {
 163 |    "cell_type": "code",
 164 |    "execution_count": null,
 165 |    "id": "0a03815e-205e-48ee-8d37-be8ffa2f7a20",
 166 |    "metadata": {},
 167 |    "outputs": [],
 168 |    "source": [
 169 |     "%%timeit\n",
 170 |     "#|eval: false\n",
 171 |     "Tensor([[1]])"
 172 |    ]
 173 |   },
 174 |   {
 175 |    "cell_type": "code",
 176 |    "execution_count": null,
 177 |    "id": "e8001456-bb86-44c2-b68f-59c4e88018bd",
 178 |    "metadata": {},
 179 |    "outputs": [],
 180 |    "source": [
 181 |     "test_arr=[[1]*270000]"
 182 |    ]
 183 |   },
 184 |   {
 185 |    "cell_type": "code",
 186 |    "execution_count": null,
 187 |    "id": "3ba7e414-d53f-4865-832c-c16aa098963d",
 188 |    "metadata": {},
 189 |    "outputs": [],
 190 |    "source": [
 191 |     "%%timeit\n",
 192 |     "#|eval: false\n",
 193 |     "Tensor(test_arr)"
 194 |    ]
 195 |   },
 196 |   {
 197 |    "cell_type": "code",
 198 |    "execution_count": null,
 199 |    "id": "dc756fe2-c890-4736-9e2a-426e9d646b3c",
 200 |    "metadata": {},
 201 |    "outputs": [],
 202 |    "source": [
 203 |     "test_arr=np.array([[1]*270000])"
 204 |    ]
 205 |   },
 206 |   {
 207 |    "cell_type": "code",
 208 |    "execution_count": null,
 209 |    "id": "0627b489-d4b7-475c-bf73-de263ea5c5e4",
 210 |    "metadata": {},
 211 |    "outputs": [],
 212 |    "source": [
 213 |     "%%timeit\n",
 214 |     "#|eval: false\n",
 215 |     "Tensor(test_arr)"
 216 |    ]
 217 |   },
 218 |   {
 219 |    "cell_type": "markdown",
 220 |    "id": "c43811a0-ca7d-4d20-b4a8-51aaccfe1b6d",
 221 |    "metadata": {},
 222 |    "source": [
 223 |     "It is horrifying just how much of a performance hit this causes... So we will be avoiding python list inputs \n",
 224 |     "to Tensors from now on..."
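,
    "\n",
    "The cause is the same hand-off problem as above: `np.array([[1]*270000])` gives torch a single contiguous, typed buffer to copy, while the raw Python list of 270000 ints must be walked and converted one element at a time."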
 225 |    ]
 226 |   },
 227 |   {
 228 |    "cell_type": "code",
 229 |    "execution_count": null,
 230 |    "id": "current-pilot",
 231 |    "metadata": {},
 232 |    "outputs": [],
 233 |    "source": [
 234 |     "#|hide\n",
 235 |     "#|eval: false\n",
 236 |     "from fastcore.imports import in_colab\n",
 237 |     "\n",
 238 |     "# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to\n",
 239 |     "if not in_colab():\n",
 240 |     "    from nbdev import nbdev_export\n",
 241 |     "    nbdev_export()"
 242 |    ]
 243 |   }
 244 |  ],
 245 |  "metadata": {
 246 |   "kernelspec": {
 247 |    "display_name": "python3",
 248 |    "language": "python",
 249 |    "name": "python3"
 250 |   }
 251 |  },
 252 |  "nbformat": 4,
 253 |  "nbformat_minor": 5
 254 | }
 255 | 
--------------------------------------------------------------------------------
/nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb:
--------------------------------------------------------------------------------
 1 | {
 2 |  "cells": [
 3 |   {
 4 |    "cell_type": "code",
 5 |    "execution_count": null,
 6 |    "id": "durable-dialogue",
 7 |    "metadata": {},
 8 |    "outputs": [],
 9 |    "source": [
 10 |     "#|hide\n",
 11 |     "from fastrl.test_utils import initialize_notebook\n",
 12 |     "initialize_notebook()"
 13 |    ]
 14 |   },
 15 |   {
 16 |    "cell_type": "code",
 17 |    "execution_count": null,
 18 |    "id": "assisted-contract",
 19 |    "metadata": {},
 20 |    "outputs": [],
 21 |    "source": [
 22 |     "#|export\n",
 23 |     "# Python native modules\n",
 24 |     "import os\n",
 25 |     "import io\n",
 26 |     "from typing import Tuple,Any,Optional,NamedTuple,Iterable\n",
 27 |     "# Third party libs\n",
 28 |     "import imageio\n",
 29 |     "from fastcore.all import add_docs,ifnone\n",
 30 |     "import matplotlib.pyplot as plt\n",
 31 |     "import torchdata.datapipes as dp\n",
 32 |     "from torchdata.datapipes import functional_datapipe\n",
 33 |     "from IPython.core.display import Video,Image\n",
 34 |     "from torchdata.dataloader2 import DataLoader2,MultiProcessingReadingService\n",
 35 |     "# Local modules\n",
 36 |     "from fastrl.loggers.core import LoggerBase,is_record\n",
 37 |     "from fastrl.pipes.core import DataPipeAugmentationFn,apply_dp_augmentation_fns\n",
 38 |     "from fastrl.loggers.jupyter_visualizers import ImageCollector"
 39 |    ]
 40 |   },
 41 |   {
 42 |    "cell_type": "code",
 43 |    "execution_count": null,
 44 |    "id": "offshore-stuart",
 45 |    "metadata": {},
 46 |    "outputs": [],
 47 |    "source": [
 48 |     "#|default_exp loggers.vscode_visualizers"
 49 |    ]
 50 |   },
 51 |   {
 52 |    "attachments": {},
 53 |    "cell_type": "markdown",
 54 |    "id": "lesser-innocent",
 55 |    "metadata": {},
 56 |    "source": [
 57 |     "# Visualizers - VS-Code\n",
 58 |     "> Iterable pipes for displaying environments as they run using `typing.NamedTuples` with `image` fields for VS-Code\n",
 59 |     "\n",
 60 |     "`fastrl.jupyter_visualizers` can be used in vscode; however, you will likely notice flickering for video\n",
 61 |     "based outputs. For vscode, we can generate a gif instead."
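,
    "\n",
    "The `visualize_vscode` functional form registered at the bottom of this page chains `ImageCollector` and `SimpleVSCodeVideoPlayer` together, so `pipe.visualize_vscode()` followed by `pipe.show()` is all that is needed."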
 62 |    ]
 63 |   },
 64 |   {
 65 |    "cell_type": "code",
 66 |    "execution_count": null,
 67 |    "id": "fde1fa58",
 68 |    "metadata": {},
 69 |    "outputs": [],
 70 |    "source": [
 71 |     "dp.iter.Repeater"
 72 |    ]
 73 |   },
 74 |   {
 75 |    "cell_type": "code",
 76 |    "execution_count": null,
 77 |    "id": "9cb488d2-41f1-4160-a938-e39003f1a06a",
 78 |    "metadata": {},
 79 |    "outputs": [],
 80 |    "source": [
 81 |     "#|export\n",
 82 |     "class SimpleVSCodeVideoPlayer(dp.iter.IterDataPipe):\n",
 83 |     "    def __init__(self, \n",
 84 |     "            source_datapipe=None, \n",
 85 |     "            skip_frames:int=1,\n",
 86 |     "            fps:int=30,\n",
 87 |     "            downsize_res=(2,2)\n",
 88 |     "        ):\n",
 89 |     "        self.source_datapipe = source_datapipe\n",
 90 |     "        self.fps = fps\n",
 91 |     "        self.skip_frames = skip_frames\n",
 92 |     "        self.downsize_res = downsize_res\n",
 93 |     "        self._bytes_object = None\n",
 94 |     "        self.frames = [] \n",
 95 |     "\n",
 96 |     "    def reset(self):\n",
 97 |     "        super().reset()\n",
 98 |     "        self._bytes_object = io.BytesIO()\n",
 99 |     "\n",
 100 |     "    def show(self,start:int=0,end:Optional[int]=None,step:int=1):\n",
 101 |     "        print(f'Creating gif from {len(self.frames)} frames')\n",
 102 |     "        imageio.mimwrite(\n",
 103 |     "            self._bytes_object,\n",
 104 |     "            self.frames[start:end:step],\n",
 105 |     "            format='GIF',\n",
 106 |     "            duration=1/self.fps # `duration` is the per-frame delay in seconds, not a frame rate\n",
 107 |     "        )\n",
 108 |     "        return Image(self._bytes_object.getvalue())\n",
 109 |     "    \n",
 110 |     "    def __iter__(self) -> Tuple[NamedTuple]:\n",
 111 |     "        n_frame = 0\n",
 112 |     "        for record in self.source_datapipe:\n",
 113 |     "            # Only records named 'image' contribute frames\n",
 114 |     "            if is_record(record) and record.name=='image':\n",
 115 |     "                if record.value is None: continue\n",
 116 |     "                n_frame += 1\n",
 117 |     "                if n_frame%self.skip_frames!=0: continue\n",
 118 |     "                self.frames.append(\n",
 119 |     "                    record.value[::self.downsize_res[0],::self.downsize_res[1]]\n",
 120 |     "                )\n",
 121 |     "            yield record\n",
 122 |     "add_docs(\n",
 123 |     "SimpleVSCodeVideoPlayer,\n",
 124 |     "\"\"\"Displays video from a `source_datapipe` that produces `typing.NamedTuples` that contain an `image` field.\n",
 125 |     "This can only handle 1 env input.\"\"\",\n",
 126 |     "show=\"In order to show the video, this must be called in a notebook cell.\",\n",
 127 |     "reset=\"Will reset the bytes object that is used to store file data.\"\n",
 128 |     ")"
 129 |    ]
 130 |   },
 131 |   {
 132 |    "cell_type": "code",
 133 |    "execution_count": null,
 134 |    "id": "d486d06a",
 135 |    "metadata": {},
 136 |    "outputs": [],
 137 |    "source": [
 138 |     "#|export\n",
 139 |     "@functional_datapipe('visualize_vscode')\n",
 140 |     "class VSCodeDataPipe(dp.iter.IterDataPipe):\n",
 141 |     "    def __new__(cls,source:Iterable):\n",
 142 |     "        \"Builds the image-collection + gif-player pipeline when the datapipe is constructed.\"\n",
 143 |     "        pipe = ImageCollector(source).dump_records()\n",
 144 |     "        pipe = SimpleVSCodeVideoPlayer(pipe)\n",
 145 |     "        return pipe \n",
 146 |     "    "
 147 |    ]
 148 |   },
 149 |   {
 150 |    "cell_type": "code",
 151 |    "execution_count": null,
 152 |    "id": "55130404",
 153 |    "metadata": {},
 154 |    "outputs": [],
 155 |    "source": [
 156 |     "from fastrl.envs.gym import GymDataPipe"
 157 |    ]
 158 |   },
 159 |   {
 160 |    "cell_type": "code",
 161 |    "execution_count": null,
 162 |    "id": "8b58a84d",
 163 |    "metadata": {},
 164 |    "outputs": [],
 165 |    "source": [
 166 |     "#|hide\n",
 167 |     "pipe = GymDataPipe(['CartPole-v1'],None,n=100,seed=0,include_images=True).visualize_vscode()\n",
 168 |     "\n",
 169 |     "list(pipe);\n",
 170 |     "pipe.show()"
 171 |    ]
 172 |   },
 173 |   {
 174 |    "cell_type": "code",
 175 |    "execution_count": null,
 176 |    "id": "ead27188-7322-46c2-9300-59842af2386d",
 177 |    "metadata": {},
 178 |    "outputs": [],
"source": [ 180 | "#|hide\n", 181 | "pipe = GymDataPipe(['CartPole-v1'],None,n=100,seed=0,include_images=True)\n", 182 | "pipe = VSCodeDataPipe(pipe)\n", 183 | "\n", 184 | "list(pipe);\n", 185 | "pipe.show()" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "id": "current-pilot", 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "#|hide\n", 196 | "#|eval: false\n", 197 | "!nbdev_export" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "id": "e0d82468-a2bf-4bfd-9ac7-e56db49b8476", 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [] 207 | } 208 | ], 209 | "metadata": { 210 | "kernelspec": { 211 | "display_name": "python3", 212 | "language": "python", 213 | "name": "python3" 214 | } 215 | }, 216 | "nbformat": 4, 217 | "nbformat_minor": 5 218 | } 219 | -------------------------------------------------------------------------------- /fastrl/memory/memory_visualizer.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/04_Memory/01_memory_visualizer.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['MemoryBufferViewer'] 5 | 6 | # %% ../../nbs/04_Memory/01_memory_visualizer.ipynb 2 7 | # Python native modules 8 | import io 9 | from typing import List 10 | # Third party libs 11 | from PIL import Image 12 | from ipywidgets import Button, HBox, VBox, Output, IntText, Label 13 | from ipywidgets import widgets 14 | from IPython.display import display 15 | import numpy as np 16 | import torch 17 | # Local modules 18 | from ..core import StepTypes 19 | 20 | # %% ../../nbs/04_Memory/01_memory_visualizer.ipynb 4 21 | class MemoryBufferViewer: 22 | def __init__(self, memory:List[StepTypes.types], agent=None, ignore_image:bool=False): 23 | # Assuming memory contains SimpleStep instances or None 24 | self.memory = memory 25 | self.agent = agent 26 | self.current_index = 0 27 | self.ignore_image = ignore_image 28 | # Add a label for displaying the number of elements in memory 29 | self.memory_size_label = Label(value=f"Number of Elements in Memory: {len([x for x in memory if x is not None])}") 30 | 31 | # Create the widgets 32 | self.out = Output() 33 | self.next_button = Button(description="Next") 34 | self.prev_button = Button(description="Previous") 35 | self.goto_text = IntText(value=0, description='Index:') 36 | # Button to jump to the desired index 37 | self.goto_button = Button(description="Go") 38 | self.goto_button.on_click(self.goto_step) 39 | self.action_value_label = Label() 40 | 41 | # Setup event handlers 42 | self.next_button.on_click(self.next_step) 43 | self.prev_button.on_click(self.prev_step) 44 | self.manual_navigation = False 45 | self.goto_text.observe(self.jump_to_index, names='value') 46 | 47 | # Display the widgets 48 | # Update the widget layout 49 | self.display_content_placeholder = VBox([]) 50 | self.layout = VBox([ 51 | self.memory_size_label, 52 | HBox([self.prev_button, self.next_button, self.goto_text, self.goto_button]), 53 | self.action_value_label, 54 | self.out, 55 | self.display_content_placeholder 56 | ]) 57 | self.show_current() 58 | display(self.layout) 59 | 60 | def jump_to_index(self, change): 61 | if not self.manual_navigation: 62 | idx = change['new'] 63 | if 0 <= idx < len(self.memory): 64 | self.current_index = idx 65 | self.show_current() 66 | else: 67 | self.manual_navigation = False 68 | 69 | def next_step(self, change): 70 | self._toggle_buttons_state(False) # Disable 
buttons 71 | self.current_index = min(len(self.memory) - 1, self.current_index + 1) 72 | self.manual_navigation = True 73 | self.goto_text.value = self.current_index 74 | self.show_current() 75 | self._toggle_buttons_state(True) # Enable buttons 76 | 77 | def prev_step(self, change): 78 | self._toggle_buttons_state(False) # Disable buttons 79 | self.current_index = max(0, self.current_index - 1) 80 | self.manual_navigation = True 81 | self.goto_text.value = self.current_index 82 | self.show_current() 83 | self._toggle_buttons_state(True) # Enable buttons 84 | 85 | def goto_step(self, change): 86 | self._toggle_buttons_state(False) # Disable buttons 87 | target_idx = self.goto_text.value 88 | if 0 <= target_idx < len(self.memory): 89 | self.current_index = target_idx 90 | self.show_current() 91 | self._toggle_buttons_state(True) # Enable buttons 92 | 93 | def _toggle_buttons_state(self, state): 94 | """Helper function to toggle button states.""" 95 | self.prev_button.disabled = not state 96 | self.next_button.disabled = not state 97 | self.goto_button.disabled = not state 98 | 99 | def tensor_to_pil(self, tensor_image): 100 | """Convert a tensor to a PIL Image.""" 101 | # Convert the tensor to numpy 102 | img_np = tensor_image.numpy() 103 | 104 | # Check if the tensor was in C, H, W format and convert it to H, W, C for PIL 105 | if img_np.ndim == 3 and img_np.shape[2] != 3: 106 | img_np = np.transpose(img_np, (1, 2, 0)) 107 | 108 | # Make sure the data type is right 109 | if img_np.dtype in (np.float32,np.float64): 110 | img_np = (img_np * 255).astype(np.uint8) 111 | 112 | return Image.fromarray(img_np) 113 | 114 | def pil_image_to_byte_array(self, pil_image): 115 | """Helper function to convert PIL image to byte array.""" 116 | buffer = io.BytesIO() 117 | pil_image.save(buffer, format='JPEG') 118 | return buffer.getvalue() 119 | 120 | def show_current(self): 121 | self.out.clear_output(wait=True) 122 | with self.out: 123 | if self.memory[self.current_index] is not None: 124 | step = self.memory[self.current_index] 125 | 126 | # Prepare the right-side content (step details) 127 | details_list = [] 128 | details_list.append(Label(f"Action Value: {step.action.item()}")) 129 | # If agent is provided, predict the action based on step.state 130 | if self.agent is not None: 131 | with torch.no_grad(): 132 | for predicted_action in self.agent([step]):pass 133 | details_list.append(Label(f"Agent Predicted Action: {predicted_action}")) 134 | 135 | for field, value in step.to_tensordict().items(): 136 | if field not in ['state', 'next_state', 'image']: 137 | details_list.append(Label(f"{field.capitalize()}: {value}")) 138 | 139 | details_display = VBox(details_list) 140 | 141 | # If the image is present, prepare left-side content 142 | if torch.is_tensor(step.image) and step.image.nelement() > 1 and not self.ignore_image: 143 | pil_image = self.tensor_to_pil(step.image) 144 | img_display = widgets.Image(value=self.pil_image_to_byte_array(pil_image), format='jpeg') 145 | display_content = HBox([img_display, details_display]) 146 | else: 147 | # If image is not present, use the entire space for details 148 | # You can expand this to include 'state' and 'next_state' as desired 149 | # If image is not present, display 'state' and 'next_state' along with other details 150 | state_label = Label(f"State: {step.state}") 151 | next_state_label = Label(f"Next State: {step.next_state}") 152 | display_content = VBox([details_display, state_label, next_state_label]) 153 | 154 | 
self.display_content_placeholder.children = [display_content] 155 | else: 156 | print(f"Step {self.current_index}: Empty") 157 | self.action_value_label.value = "" 158 | -------------------------------------------------------------------------------- /fastrl/cli.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/19_cli.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['fastrl_make_requirements', 'proc_nbs', 'fastrl_nbdev_docs', 'fastrl_nbdev_test'] 5 | 6 | # %% ../nbs/19_cli.ipynb 3 7 | # Python native modules 8 | import os 9 | import shutil 10 | # Third party libs 11 | from fastcore.all import * 12 | # Local modules 13 | from nbdev.quarto import _nbglob_docs,_sprun,_pre_docs,nbdev_readme,move,proc_nbs 14 | from nbdev.test import test_nb,_keep_file 15 | 16 | # %% ../nbs/19_cli.ipynb 6 17 | @call_parse 18 | def fastrl_make_requirements( 19 | path:Path=None, # The path to a dir with the settings.ini, if none, cwd. 20 | project_file:str='settings.ini', # The file to load for reading the requirements 21 | out_path:Path=None, # The output path (can be relative to `path`) 22 | verbose:bool=False # Output to stdout 23 | ): 24 | requirement_types = ['','dev_','pip_'] 25 | path = ifnone(path, Path.cwd())/project_file 26 | 27 | if not path.exists(): raise OSError(f'File {path} does not exist') 28 | 29 | out_path = ifnone(out_path, Path('extra')) 30 | out_path = out_path if out_path.is_absolute() else path.parent/out_path 31 | out_path.mkdir(parents=True, exist_ok=True) 32 | if verbose: print('Outputting to path: ',out_path) 33 | config = Config(path.parent,path.name) 34 | 35 | for req in requirement_types: 36 | requirements = config[req+'requirements'] 37 | requirements = requirements.replace(' ','\n') 38 | Path(out_path/(req+'requirements.txt')).write_text(requirements) 39 | 40 | # %% ../nbs/19_cli.ipynb 7 41 | from nbdev.config import * 42 | from nbdev.doclinks import * 43 | 44 | from fastcore.utils import * 45 | from fastcore.script import call_parse 46 | from fastcore.shutil import rmtree,move,copytree 47 | from fastcore.meta import delegates 48 | from nbdev.serve import proc_nbs,_proc_file 49 | from nbdev import serve_drv 50 | from nbdev.quarto import _ensure_quarto 51 | from nbdev.quarto import * 52 | import nbdev 53 | 54 | # %% ../nbs/19_cli.ipynb 8 55 | @call_parse 56 | @delegates(nbglob_cli) 57 | def proc_nbs( 58 | path:str='', # Path to notebooks 59 | n_workers:int=defaults.cpus, # Number of workers 60 | force:bool=False, # Ignore cache and build all 61 | file_glob:str='', # Only include files matching glob 62 | verbose:bool=False, # verbose outputs 63 | one2one:bool=True, # Run 1 notebook per process instance. 
 64 |         **kwargs):
 65 |     "Process notebooks in `path` for docs rendering"
 66 |     cfg = get_config()
 67 |     cache = cfg.config_path/'_proc'
 68 |     path = Path(path or cfg.nbs_path)
 69 |     files = nbglob(path, func=Path, file_glob=file_glob, **kwargs)
 70 |     if (path/'_quarto.yml').exists(): files.append(path/'_quarto.yml')
 71 | 
 72 |     # If settings.ini or filter script newer than cache folder modified, delete cache
 73 |     chk_mtime = max(cfg.config_file.stat().st_mtime, Path(__file__).stat().st_mtime)
 74 |     cache.mkdir(parents=True, exist_ok=True)
 75 |     cache_mtime = cache.stat().st_mtime
 76 |     if force or (cache.exists() and cache_mtime<chk_mtime): rmtree(cache)
[... extraction gap: the remainder of `proc_nbs` and the rest of fastrl/cli.py were lost from the source dump ...]
--------------------------------------------------------------------------------
/fastrl/agents/discrete.py:
--------------------------------------------------------------------------------
[... extraction gap: lines 1-32 (file header, imports, and the `ArgMaxer` class definition) were lost; the fragment below resumes at `ArgMaxer.__iter__` ...]
 33 |     def __iter__(self) -> torch.LongTensor:
 34 |         for step in self.source_datapipe:
 35 |             if not issubclass(step.__class__,torch.Tensor):
 36 |                 raise Exception(f'Expected Tensor to take the argmax, got {type(step)}\n{step}')
 37 |             # Might want to support simple tuples also depending on if we are processing multiple fields.
 38 |             idx = torch.argmax(step,axis=self.axis).reshape(-1,1)
 39 |             if self.only_idx:
 40 |                 yield idx.long()
 41 |                 continue
 42 |             step[:] = 0
 43 |             step.scatter_(1,idx,1)
 44 |             yield step.long()
 45 | 
 46 | 
 47 | # %% ../../nbs/07_Agents/01_Discrete/12b_agents.discrete.ipynb 7
 48 | class EpsilonSelector(dp.iter.IterDataPipe):
 49 |     debug=False
 50 |     "Given input `Tensor` from `source_datapipe`, epsilon-greedily selects which actions should be replaced with random ones."
 51 |     def __init__(self,
 52 |             source_datapipe, # a datapipe whose next(source_datapipe) -> `Tensor`
 53 |             min_epsilon:float=0.2, # The minimum epsilon to drop to
 54 |             # The max/starting epsilon if `epsilon` is None; also used for calculating the epsilon decrease speed.
 55 |             max_epsilon:float=1,
 56 |             # Determines how fast the epsilon should drop to `min_epsilon`. This should be the number
 57 |             # of steps that the agent was run through.
 58 |             max_steps:int=100,
 59 |             # The starting epsilon
 60 |             epsilon:float=None,
 61 |             # Based on the `base_agent.model.training`, by default no decrement or step tracking will
 62 |             # occur during validation steps.
 63 |             decrement_on_val:bool=False,
 64 |             # Based on the `base_agent.model.training`, by default random actions will not be attempted
 65 |             select_on_val:bool=False,
 66 |             # Also return the mask that, where True, the action should be randomly selected.
 67 |             ret_mask:bool=False,
 68 |             # The device to create the masks on
 69 |             device='cpu'
 70 |         ):
 71 |         self.source_datapipe = source_datapipe
 72 |         self.min_epsilon = min_epsilon
 73 |         self.max_epsilon = max_epsilon
 74 |         self.max_steps = max_steps
 75 |         self.epsilon = epsilon
 76 |         self.decrement_on_val = decrement_on_val
 77 |         self.select_on_val = select_on_val
 78 |         self.ret_mask = ret_mask
 79 |         self.agent_base = find_dp(traverse_dps(self.source_datapipe),AgentBase)
 80 |         self.step = 0
 81 |         self.device = torch.device(device)
 82 | 
 83 |     def __iter__(self):
 84 |         for action in self.source_datapipe:
 85 |             # TODO: Support tuples of actions also
 86 |             if not issubclass(action.__class__,torch.Tensor):
 87 |                 raise Exception(f'Expected Tensor, got {type(action)}\n{action}')
 88 |             if action.dtype!=torch.int64:
 89 |                 raise ValueError(f'Expected Tensor of dtype int64, got: {action.dtype} from {self.source_datapipe}')
 90 | 
 91 |             if self.agent_base.model.training or self.decrement_on_val:
 92 |                 self.step+=1
 93 | 
 94 |             self.epsilon = max(self.min_epsilon,self.max_epsilon-self.step/self.max_steps)
 95 |             # Add a batch dim if missing
 96 |             if len(action.shape)==1: action.unsqueeze_(0)
 97 |             mask = None
 98 |             if self.agent_base.model.training or self.select_on_val:
 99 |                 # Given N(action.shape[0]) actions, select the ones we want to randomly assign...
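                # Illustrative note (an added explanatory comment): this is the
                # epsilon-greedy coin flip. Each of the N actions draws a uniform
                # sample in [0,1); entries whose sample falls below `self.epsilon`
                # are flagged in `mask` for replacement with a random action, so
                # with epsilon=0.2 roughly 20% of the batch is re-sampled randomly.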
 100 |                 mask = torch.rand(action.shape[0],).to(self.device)<self.epsilon
[... extraction gap: discrete.py lines 101-141 were lost (the rest of `EpsilonSelector.__iter__`, its `add_docs`, and the start of a numpy-to-python-primitive converter class); the fragment below resumes at that converter's `__iter__` ...]
 142 |     def __iter__(self) -> Union[float,bool,int]:
 143 |         for step in self.source_datapipe:
 144 |             if not issubclass(step.__class__,(np.ndarray)):
 145 |                 raise Exception(f'Expected np.ndarray to convert to a python primitive, got {type(step)}\n{step}')
 146 | 
 147 |             if len(step)!=1:
 148 |                 raise Exception(f'`step` from {self.source_datapipe} needs to be len 1, not {len(step)}')
 149 |             else:
 150 |                 step = step[0]
 151 | 
 152 |             if np.issubdtype(step.dtype,np.integer):
 153 |                 yield int(step)
 154 |             elif np.issubdtype(step.dtype,np.floating):
 155 |                 yield float(step)
 156 |             elif np.issubdtype(step.dtype,np.bool_): # np.bool8 is deprecated/removed in newer numpy; np.bool_ is the canonical alias
 157 |                 yield bool(step)
 158 |             else:
 159 |                 raise Exception(f'`step` from {self.source_datapipe} must be one of the 3 python types: bool,int,float, not {step.dtype}')
 160 | 
--------------------------------------------------------------------------------
/fastrl/agents/core.py:
--------------------------------------------------------------------------------
 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/07_Agents/12a_agents.core.ipynb.
 2 | 
 3 | # %% auto 0
 4 | __all__ = ['shared_model_dict', 'share_model', 'get_shared_model', 'AgentBase', 'AgentHead', 'SimpleModelRunner',
 5 |            'StepFieldSelector', 'NumpyConverter']
 6 | 
 7 | # %% ../../nbs/07_Agents/12a_agents.core.ipynb 2
 8 | # Python native modules
 9 | import os
 10 | from typing import List,Optional
 11 | # Third party libs
 12 | from fastcore.all import add_docs,ifnone
 13 | import torchdata.datapipes as dp
 14 | import torch
 15 | from torch import nn
 16 | from torchdata.dataloader2.graph import traverse_dps
 17 | import torch.multiprocessing as mp
 18 | # Local modules
 19 | from ..core import StepTypes,SimpleStep
 20 | from ..torch_core import evaluating,Module
 21 | from ..pipes.core import find_dps,find_dp
 22 | 
 23 | # %% ../../nbs/07_Agents/12a_agents.core.ipynb 4
 24 | # Create a manager for shared objects
 25 | # manager = mp.Manager()
 26 | # shared_model_dict = manager.dict()
 27 | shared_model_dict = {}
 28 | 
 29 | def share_model(model: nn.Module, name="default"):
 30 |     """Move model's parameters to shared memory and store in the module-level dictionary."""
 31 |     # TODO(josiahls): This will not survive multiprocessing. We will need to use something
 32 |     # like ray to better sync models.
 33 |     model.share_memory()
 34 |     shared_model_dict[name] = model
 35 | 
 36 | def get_shared_model(name="default"):
 37 |     """Retrieve model from shared memory using the module-level dictionary."""
 38 |     return shared_model_dict[name]
 39 | 
 40 | class AgentBase(dp.iter.IterDataPipe):
 41 |     def __init__(self,
 42 |             model:Optional[nn.Module], # The base NN that we get raw action values out of.
 43 |             action_iterator:list=None, # A reference to an iterator that contains actions to process.
 44 |             logger_bases=None # Optional logger bases, stored for downstream pipes to reference.
 45 |         ):
 46 |         self.model = model
 47 |         self.iterable = ifnone(action_iterator,[])
 48 |         self.agent_base = self
 49 |         self.logger_bases = logger_bases
 50 |         self._mem_name = 'agent_model'
 51 | 
 52 |     def to(self,*args,**kwargs):
 53 |         if self.model is not None:
 54 |             self.model.to(*args,**kwargs)
 55 | 
 56 |     def __iter__(self):
 57 |         while self.iterable:
 58 |             yield self.iterable.pop(0)
 59 | 
 60 |     def __getstate__(self):
 61 |         if self.model is not None:
 62 |             share_model(self.model,self._mem_name)
 63 |         # Store the non-model state
 64 |         state = self.__dict__.copy()
 65 |         return state
 66 | 
 67 |     def __setstate__(self, state):
 68 |         self.__dict__.update(state)
 69 |         # Assume a globally shared model instance or a reference method to retrieve it
 70 |         if self.model is not None:
 71 |             self.model = get_shared_model(self._mem_name)
 72 | 
 73 | add_docs(
 74 |     AgentBase,
 75 |     """Acts as the footer of the Agent pipeline.
 76 |     Maintains important state such as the `model` being used to get actions from.
 77 |     Also optionally allows passing a reference list of `action_iterator` which is a
 78 |     persistent list of actions for the entire agent pipeline to process through.
 79 | 
 80 |     > Important: Must be at the start of the pipeline, and be used with AgentHead at the end.
 81 | 
 82 |     > Important: `action_iterator` is stored in the `iterable` field. However, the recommended
 83 |     way of passing actions to the pipeline is to call an `AgentHead` instance.
 84 |     """,
 85 |     to=torch.Tensor.to.__doc__
 86 | )
 87 | 
 88 | # %% ../../nbs/07_Agents/12a_agents.core.ipynb 5
 89 | class AgentHead(dp.iter.IterDataPipe):
 90 |     def __init__(self,source_datapipe):
 91 |         self.source_datapipe = source_datapipe
 92 |         self.agent_base = find_dp(traverse_dps(self.source_datapipe),AgentBase)
 93 | 
 94 |     def __call__(self,steps:list):
 95 |         if issubclass(steps.__class__,StepTypes.types):
 96 |             raise Exception(f'Expected List[{StepTypes.types}] object, got {type(steps)}\n{steps}')
 97 |         self.agent_base.iterable.extend(steps)
 98 |         return self
 99 | 
 100 |     def __iter__(self): yield from self.source_datapipe
 101 | 
 102 |     def augment_actions(self,actions): return actions
 103 | 
 104 |     def create_step(self,**kwargs): return SimpleStep(**kwargs,batch_size=[])
 105 | 
 106 | add_docs(
 107 |     AgentHead,
 108 |     """Acts as the head of the Agent pipeline.
 109 |     Used for conveniently adding actions to the pipeline to process.
 110 | 
 111 |     > Important: Must be paired with `AgentBase`
 112 |     """,
 113 |     augment_actions="""Called right before the actions are fed into the env.
 114 | 
 115 |     > Important: The results of this function will not be kept / used in the step or forwarded to
 116 |     any training code.
 117 | 
 118 |     There are cases where either the entire action shouldn't be fed into the env,
 119 |     or the version of the action that we want to train on wouldn't be compatible with the env.
 120 | 
 121 |     This is also useful if we want to train on the original raw values of the action prior to argmax being run on it, for example.
 122 |     """,
 123 |     create_step="Creates the step used by the env for running, and used by the model for training."
 124 | )
 125 | 
 126 | # %% ../../nbs/07_Agents/12a_agents.core.ipynb 6
 127 | class SimpleModelRunner(dp.iter.IterDataPipe):
 128 |     "Takes input from `source_datapipe` and pushes it through the agent base's model, assuming there is only one model field."
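    # Illustrative sketch (an assumed minimal pipeline following the AgentBase/AgentHead
    # docs above; the names here are for demonstration only, not a fixed fastrl recipe):
    #
    #   agent = AgentBase(model)
    #   agent = StepFieldSelector(agent, field='state')
    #   agent = SimpleModelRunner(agent)
    #   agent = AgentHead(agent)
    #   for action in agent([step]): ...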
 129 |     def __init__(self,
 130 |             source_datapipe
 131 |         ):
 132 |         self.source_datapipe = source_datapipe
 133 |         self.agent_base = find_dp(traverse_dps(self.source_datapipe),AgentBase)
 134 |         self.device = None
 135 | 
 136 |     def to(self,*args,**kwargs):
 137 |         if 'device' in kwargs: self.device = kwargs.get('device',None)
 138 |         return self
 139 | 
 140 |     def __iter__(self):
 141 |         for x in self.source_datapipe:
 142 |             if self.device is not None: x = x.to(self.device)
 143 |             if len(x.shape)==1: x = x.unsqueeze(0)
 144 |             with torch.no_grad():
 145 |                 with evaluating(self.agent_base.model):
 146 |                     res = self.agent_base.model(x)
 147 |             yield res
 148 | 
 149 | # %% ../../nbs/07_Agents/12a_agents.core.ipynb 12
 150 | class StepFieldSelector(dp.iter.IterDataPipe):
 151 |     "Grabs `field` from `source_datapipe` to push to the rest of the pipeline."
 152 |     def __init__(self,
 153 |             source_datapipe, # datapipe whose next(source_datapipe) -> `StepTypes`
 154 |             field='state' # A field in `StepTypes` to grab
 155 |         ):
 156 |         # TODO: support multi-fields
 157 |         self.source_datapipe = source_datapipe
 158 |         self.field = field
 159 | 
 160 |     def __iter__(self):
 161 |         for step in self.source_datapipe:
 162 |             if not issubclass(step.__class__,StepTypes.types):
 163 |                 raise Exception(f'Expected {StepTypes.types} object, got {type(step)}\n{step}')
 164 |             yield getattr(step,self.field)
 165 | 
 166 | # %% ../../nbs/07_Agents/12a_agents.core.ipynb 22
 167 | class NumpyConverter(dp.iter.IterDataPipe):
 168 |     debug=False
 169 | 
 170 |     def __init__(self,source_datapipe):
 171 |         self.source_datapipe = source_datapipe
 172 | 
 173 |     def debug_display(self,step):
 174 |         print(f'Step: {step}')
 175 | 
 176 |     def __iter__(self):
 177 |         for step in self.source_datapipe:
 178 |             if not issubclass(step.__class__,torch.Tensor):
 179 |                 raise Exception(f'Expected Tensor to convert to numpy, got {type(step)}\n{step}')
 180 |             if self.debug: self.debug_display(step)
 181 |             yield step.detach().cpu().numpy()
 182 | 
 183 | add_docs(
 184 |     NumpyConverter,
 185 |     """Given input `Tensor` from `source_datapipe`, yields an equivalent detached numpy array of the same shape.""",
 186 |     debug_display="Display the step being processed"
 187 | )
 188 | 
--------------------------------------------------------------------------------