├── .clang-format ├── .flake8 ├── .gitattributes ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── README.md ├── libtorchbeast ├── actorpool.cc ├── nest_serialize.h ├── rpcenv.cc └── rpcenv.proto ├── main.py ├── models ├── AirRaid │ ├── meta.json │ └── model.tar ├── Carnival │ ├── meta.json │ └── model.tar ├── DemonAttack │ ├── meta.json │ └── model.tar ├── MultiTask │ ├── meta.json │ └── model.tar ├── MultiTaskPopArt │ ├── meta.json │ └── model.tar ├── NameThisGame │ ├── meta.json │ └── model.tar ├── Pong │ ├── meta.json │ └── model.tar └── SpaceInvaders │ ├── meta.json │ └── model.tar ├── movies ├── .DS_Store ├── AirRaid_050009600_AirRaidNoFrameskip-v4.gif ├── Carnival_050002560_CarnivalNoFrameskip-v4.gif ├── DemonAttack_050001280_DemonAttackNoFrameskip-v4.gif ├── MultiTaskPopart_300010240_AirRaidNoFrameskip-v4.gif ├── MultiTaskPopart_300010240_CarnivalNoFrameskip-v4.gif ├── MultiTaskPopart_300010240_DemonAttackNoFrameskip-v4.gif ├── MultiTaskPopart_300010240_NameThisGameNoFrameskip-v4.gif ├── MultiTaskPopart_300010240_PongNoFrameskip-v4.gif ├── MultiTaskPopart_300010240_SpaceInvadersNoFrameskip-v4.gif ├── MultiTask_300014720_AirRaidNoFrameskip-v4.gif ├── MultiTask_300014720_CarnivalNoFrameskip-v4.gif ├── MultiTask_300014720_DemonAttackNoFrameskip-v4.gif ├── MultiTask_300014720_NameThisGameNoFrameskip-v4.gif ├── MultiTask_300014720_PongNoFrameskip-v4.gif ├── MultiTask_300014720_SpaceInvadersNoFrameskip-v4.gif ├── NameThisGame_050016000_NameThisGameNoFrameskip-v4.gif ├── Pong_050013440_PongNoFrameskip-v4.gif ├── Saliency_AirRaidNoFrameskip-v4.gif ├── Saliency_CarnivalNoFrameskip-v4.gif ├── Saliency_DemonAttackNoFrameskip-v4.gif ├── Saliency_NameThisGameNoFrameskip-v4.gif ├── Saliency_PongNoFrameskip-v4.gif ├── Saliency_SpaceInvadersNoFrameskip-v4.gif └── SpaceInvaders_050001280_SpaceInvadersNoFrameskip-v4.gif ├── nest ├── README.md ├── nest │ ├── nest.h │ ├── nest_pybind.cc │ └── 
nest_pybind.h ├── nest_test.py └── setup.py ├── plot.png ├── pyproject.toml ├── requirements.txt ├── results ├── 50_games_actions.pkl ├── Detrained.txt ├── MultiTask.txt ├── Popart.txt ├── Pretrained.txt ├── SingleTask.txt ├── action_analysis.ipynb ├── figures.ipynb ├── figures │ ├── computational_graph.afdesign │ ├── fig_action_distributions_aggregated.png │ ├── fig_action_distributions_all.png │ ├── fig_comparison_with_paper.png │ ├── fig_computational_graph.png │ ├── fig_detraining.png │ ├── fig_mean_episode_return.png │ ├── fig_model.png │ ├── fig_mu_sigma.png │ ├── fig_multi_multipop_default.png │ ├── fig_multi_multipop_optimal.png │ ├── fig_pre_and_detraining.png │ ├── fig_saliency.png │ ├── fig_saliency_pong.png │ ├── fig_single_comp_carnival_default.png │ ├── fig_single_comp_carnival_optimal.png │ ├── fig_single_multi_default.png │ ├── fig_single_multi_optimal.png │ ├── fig_single_multipop_default.png │ ├── fig_single_multipop_optimal.png │ └── model.pptx ├── model.ipynb ├── movies.ipynb ├── paper │ ├── paper.data.processed.pkl │ └── paper.ipynb ├── report.pdf └── report_SOURCE.zip ├── scripts └── install_grpc.sh ├── setup.py ├── tests ├── batching_queue_test.py ├── contiguous_arrays_env.py ├── contiguous_arrays_test.py ├── core_agent_state_env.py ├── core_agent_state_test.py ├── dynamic_batcher_test.py ├── inference_speed_profiling.py ├── lint_changed.sh ├── polybeast_inference_test.py ├── polybeast_learn_function_test.py ├── polybeast_loss_functions_test.py ├── polybeast_net_test.py └── vtrace_test.py ├── torchbeast.yml └── torchbeast ├── analysis ├── analyze_resnet.py ├── gradient_tracking.py └── visualize_aaa.py ├── atari_wrappers.py ├── core ├── environment.py ├── file_writer.py ├── popart.py ├── prof.py └── vtrace.py ├── models ├── atari_net_monobeast.py ├── attention_augmented_agent.py └── resnet_monobeast.py ├── monobeast.py ├── polybeast.py ├── polybeast_env.py └── saliency.py /.clang-format: 
-------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | # BasedOnStyle: Google 4 | AccessModifierOffset: -1 5 | AlignAfterOpenBracket: Align 6 | AlignConsecutiveAssignments: false 7 | AlignConsecutiveDeclarations: false 8 | AlignEscapedNewlines: Left 9 | AlignOperands: true 10 | AlignTrailingComments: true 11 | AllowAllParametersOfDeclarationOnNextLine: true 12 | AllowShortBlocksOnASingleLine: false 13 | AllowShortCaseLabelsOnASingleLine: false 14 | AllowShortFunctionsOnASingleLine: All 15 | AllowShortIfStatementsOnASingleLine: true 16 | AllowShortLoopsOnASingleLine: true 17 | AlwaysBreakAfterDefinitionReturnType: None 18 | AlwaysBreakAfterReturnType: None 19 | AlwaysBreakBeforeMultilineStrings: true 20 | AlwaysBreakTemplateDeclarations: Yes 21 | BinPackArguments: true 22 | BinPackParameters: true 23 | BraceWrapping: 24 | AfterClass: false 25 | AfterControlStatement: false 26 | AfterEnum: false 27 | AfterFunction: false 28 | AfterNamespace: false 29 | AfterObjCDeclaration: false 30 | AfterStruct: false 31 | AfterUnion: false 32 | AfterExternBlock: false 33 | BeforeCatch: false 34 | BeforeElse: false 35 | IndentBraces: false 36 | SplitEmptyFunction: true 37 | SplitEmptyRecord: true 38 | SplitEmptyNamespace: true 39 | BreakBeforeBinaryOperators: None 40 | BreakBeforeBraces: Attach 41 | BreakBeforeInheritanceComma: false 42 | BreakInheritanceList: BeforeColon 43 | BreakBeforeTernaryOperators: true 44 | BreakConstructorInitializersBeforeComma: false 45 | BreakConstructorInitializers: BeforeColon 46 | BreakAfterJavaFieldAnnotations: false 47 | BreakStringLiterals: true 48 | ColumnLimit: 80 49 | CommentPragmas: '^ IWYU pragma:' 50 | CompactNamespaces: false 51 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 52 | ConstructorInitializerIndentWidth: 4 53 | ContinuationIndentWidth: 4 54 | Cpp11BracedListStyle: true 55 | DerivePointerAlignment: true 56 | DisableFormat: false 57 | ExperimentalAutoDetectBinPacking: 
false 58 | FixNamespaceComments: true 59 | ForEachMacros: 60 | - foreach 61 | - Q_FOREACH 62 | - BOOST_FOREACH 63 | IncludeBlocks: Preserve 64 | IncludeCategories: 65 | - Regex: '^' 66 | Priority: 2 67 | - Regex: '^<.*\.h>' 68 | Priority: 1 69 | - Regex: '^<.*' 70 | Priority: 2 71 | - Regex: '.*' 72 | Priority: 3 73 | IncludeIsMainRegex: '([-_](test|unittest))?$' 74 | IndentCaseLabels: true 75 | IndentPPDirectives: None 76 | IndentWidth: 2 77 | IndentWrappedFunctionNames: false 78 | JavaScriptQuotes: Leave 79 | JavaScriptWrapImports: true 80 | KeepEmptyLinesAtTheStartOfBlocks: false 81 | MacroBlockBegin: '' 82 | MacroBlockEnd: '' 83 | MaxEmptyLinesToKeep: 1 84 | NamespaceIndentation: None 85 | ObjCBinPackProtocolList: Never 86 | ObjCBlockIndentWidth: 2 87 | ObjCSpaceAfterProperty: false 88 | ObjCSpaceBeforeProtocolList: true 89 | PenaltyBreakAssignment: 2 90 | PenaltyBreakBeforeFirstCallParameter: 1 91 | PenaltyBreakComment: 300 92 | PenaltyBreakFirstLessLess: 120 93 | PenaltyBreakString: 1000 94 | PenaltyBreakTemplateDeclaration: 10 95 | PenaltyExcessCharacter: 1000000 96 | PenaltyReturnTypeOnItsOwnLine: 200 97 | PointerAlignment: Left 98 | RawStringFormats: 99 | - Language: Cpp 100 | Delimiters: 101 | - cc 102 | - CC 103 | - cpp 104 | - Cpp 105 | - CPP 106 | - 'c++' 107 | - 'C++' 108 | CanonicalDelimiter: '' 109 | BasedOnStyle: google 110 | - Language: TextProto 111 | Delimiters: 112 | - pb 113 | - PB 114 | - proto 115 | - PROTO 116 | EnclosingFunctions: 117 | - EqualsProto 118 | - EquivToProto 119 | - PARSE_PARTIAL_TEXT_PROTO 120 | - PARSE_TEST_PROTO 121 | - PARSE_TEXT_PROTO 122 | - ParseTextOrDie 123 | - ParseTextProtoOrDie 124 | CanonicalDelimiter: '' 125 | BasedOnStyle: google 126 | ReflowComments: true 127 | SortIncludes: true 128 | SortUsingDeclarations: true 129 | SpaceAfterCStyleCast: false 130 | SpaceAfterTemplateKeyword: true 131 | SpaceBeforeAssignmentOperators: true 132 | SpaceBeforeCpp11BracedList: false 133 | SpaceBeforeCtorInitializerColon: true 
134 | SpaceBeforeInheritanceColon: true 135 | SpaceBeforeParens: ControlStatements 136 | SpaceBeforeRangeBasedForLoopColon: true 137 | SpaceInEmptyParentheses: false 138 | SpacesBeforeTrailingComments: 2 139 | SpacesInAngles: false 140 | SpacesInContainerLiterals: true 141 | SpacesInCStyleCastParentheses: false 142 | SpacesInParentheses: false 143 | SpacesInSquareBrackets: false 144 | Standard: Auto 145 | StatementMacros: 146 | - Q_UNUSED 147 | - QT_REQUIRE_VERSION 148 | TabWidth: 8 149 | UseTab: Never 150 | ... 151 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203, E266, E501, E731, W503, C901, B008 3 | # 80 to use as a soft test 4 | max-line-length = 80 5 | max-complexity = 18 6 | select = B,C,E,F,W,T4,B9 7 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.gif filter=lfs diff=lfs merge=lfs -text 2 | *.tar filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | .dmypy.json 113 | dmypy.json 114 | 115 | # Pyre type checker 116 | .pyre/ 117 | 118 | # Backup files 119 | *~ 120 | 121 | # Output files 122 | *.tsv 123 | 124 | # PyTorch checkpoint files (also GNU tar files ...) 
125 | #*.tar 126 | 127 | # Compiled protobuf files 128 | *.pb.h 129 | *.pb.cc 130 | 131 | # PyCharm 132 | .idea 133 | 134 | # Mac 135 | .DS_Store 136 | 137 | # Swap files 138 | *.sw* 139 | 140 | # VSCode 141 | .vscode 142 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/grpc"] 2 | path = third_party/grpc 3 | url = https://github.com/grpc/grpc.git 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: stable 4 | hooks: 5 | - id: black 6 | language_version: python3.7 7 | exclude: torchbeast/atari_wrappers.py 8 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to TorchBeast 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. 
If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: <https://code.facebook.com/cla> 20 | 21 | ## Coding Style 22 | * Run `black` on Python files. 23 | * Run `clang-format` on C++ files. 24 | 25 | ## License 26 | By contributing to TorchBeast, you agree that your contributions will be 27 | licensed under the LICENSE file in the root directory of this source tree. 28 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:experimental 2 | FROM ubuntu:18.04 3 | 4 | SHELL ["/bin/bash", "-c"] 5 | 6 | RUN apt-get update && apt-get install -y \ 7 | python3-setuptools \ 8 | python3-pip \ 9 | git \ 10 | libsm6 \ 11 | libxext6 \ 12 | libxrender-dev \ 13 | wget \ 14 | pkg-config 15 | 16 | WORKDIR /src 17 | 18 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 19 | 20 | RUN bash Miniconda3-latest-Linux-x86_64.sh -b 21 | 22 | ENV PATH /root/miniconda3/bin:$PATH 23 | 24 | ENV CONDA_PREFIX /root/miniconda3/envs/torchbeast 25 | 26 | # Clear .bashrc (it refuses to run non-interactively otherwise). 27 | RUN echo > ~/.bashrc 28 | 29 | # Add conda logic to .bashrc. 30 | RUN conda init bash 31 | 32 | # Create new environment and install some dependencies. 33 | RUN conda create -y -n torchbeast python=3.7 \ 34 | protobuf \ 35 | numpy \ 36 | ninja \ 37 | pyyaml \ 38 | mkl \ 39 | mkl-include \ 40 | setuptools \ 41 | cmake \ 42 | cffi \ 43 | typing 44 | 45 | # Activate environment in .bashrc. 46 | RUN echo "conda activate torchbeast" >> /root/.bashrc 47 | 48 | # Make bash execute .bashrc even when running non-interactively. 
49 | ENV BASH_ENV /root/.bashrc 50 | 51 | # Install PyTorch. 52 | 53 | # Would like to install PyTorch via pip. Unfortunately, there are binary 54 | # incompatibility issues (https://github.com/pytorch/pytorch/issues/18128). 55 | # Otherwise, this would work: 56 | # # # Install PyTorch. This needs increased Docker memory. 57 | # # # (https://github.com/pytorch/pytorch/issues/1022) 58 | # # RUN pip download torch 59 | # # RUN pip install torch*.whl 60 | 61 | RUN git clone --single-branch --branch v1.2.0 --recursive https://github.com/pytorch/pytorch 62 | 63 | WORKDIR /src/pytorch 64 | 65 | ENV CMAKE_PREFIX_PATH ${CONDA_PREFIX} 66 | 67 | RUN python setup.py install 68 | 69 | # Clone TorchBeast. 70 | WORKDIR /src/torchbeast 71 | 72 | COPY .git /src/torchbeast/.git 73 | 74 | RUN git reset --hard 75 | 76 | # Collect and install grpc. 77 | RUN git submodule update --init --recursive 78 | 79 | RUN ./scripts/install_grpc.sh 80 | 81 | # Install nest. 82 | RUN pip install nest/ 83 | 84 | # Install PolyBeast's requirements. 85 | RUN pip install -r requirements.txt 86 | 87 | # Compile libtorchbeast. 88 | ENV LD_LIBRARY_PATH ${CONDA_PREFIX}/lib:${LD_LIBRARY_PATH} 89 | 90 | RUN python setup.py install 91 | 92 | ENV OMP_NUM_THREADS 1 93 | 94 | # Run. 95 | CMD ["bash", "-c", "python -m torchbeast.polybeast \ 96 | --num_actors 10 \ 97 | --total_steps 200000000 \ 98 | --unroll_length 60 --batch_size 32"] 99 | 100 | 101 | # Docker commands: 102 | # docker rm torchbeast -v 103 | # docker build -t torchbeast . 104 | # docker run --name torchbeast torchbeast 105 | # or 106 | # docker run --name torchbeast -it torchbeast /bin/bash 107 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TorchBeastPopArt 2 | [PopArt](https://arxiv.org/abs/1809.04474) extension to [TorchBeast](https://github.com/facebookresearch/torchbeast), the PyTorch implementation of [IMPALA](https://github.com/deepmind/scalable_agent). 3 | 4 | # Experiments 5 | The PopArt extension was used to train a multi-task agent for six Atari games (AirRaid, Carnival, DemonAttack, NameThisGame, Pong, SpaceInvaders, all with the NoFrameskip-v4 variant) and compared to the corresponding single-task agents and to a simpler multi-task agent without PopArt normalisation. More details on these experiments can be found in the [report](results/report.pdf). 
6 | 7 | ## Movies 8 | 9 | Single-task: 10 | ![AirRaid (Single-task clipped)](movies/AirRaid_050009600_AirRaidNoFrameskip-v4.gif) 11 | ![Carnival (Single-task clipped)](movies/Carnival_050002560_CarnivalNoFrameskip-v4.gif) 12 | ![DemonAttack (Single-task clipped)](movies/DemonAttack_050001280_DemonAttackNoFrameskip-v4.gif) 13 | ![Pong (Single-task clipped)](movies/Pong_050013440_PongNoFrameskip-v4.gif) 14 | ![SpaceInvaders (Single-task clipped)](movies/SpaceInvaders_050001280_SpaceInvadersNoFrameskip-v4.gif) 15 | 16 | Multi-task (clipped): 17 | ![AirRaid (Multi-task clipped)](movies/MultiTask_300014720_AirRaidNoFrameskip-v4.gif) 18 | ![Carnival (Multi-task clipped)](movies/MultiTask_300014720_CarnivalNoFrameskip-v4.gif) 19 | ![DemonAttack (Multi-task clipped)](movies/MultiTask_300014720_DemonAttackNoFrameskip-v4.gif) 20 | ![Pong (Multi-task clipped)](movies/MultiTask_300014720_PongNoFrameskip-v4.gif) 21 | ![SpaceInvaders (Multi-task clipped)](movies/MultiTask_300014720_SpaceInvadersNoFrameskip-v4.gif) 22 | 23 | Multi-task PopArt: 24 | ![AirRaid (Multi-task PopArt)](movies/MultiTaskPopart_300010240_AirRaidNoFrameskip-v4.gif) 25 | ![Carnival (Multi-task PopArt)](movies/MultiTaskPopart_300010240_CarnivalNoFrameskip-v4.gif) 26 | ![DemonAttack (Multi-task PopArt)](movies/MultiTaskPopart_300010240_DemonAttackNoFrameskip-v4.gif) 27 | ![Pong (Multi-task PopArt)](movies/MultiTaskPopart_300010240_PongNoFrameskip-v4.gif) 28 | ![SpaceInvaders (Multi-task PopArt)](movies/MultiTaskPopart_300010240_SpaceInvadersNoFrameskip-v4.gif) 29 | 30 | The different game plans learned by these three models can be illustrated with the help of saliency maps (here red is the policy saliency and green is the baseline saliency). More details on these experiments can be found in the [report](results/report.pdf). 
31 | 32 | Saliency: 33 | ![AirRaid](movies/Saliency_AirRaidNoFrameskip-v4.gif) 34 | ![Carnival](movies/Saliency_CarnivalNoFrameskip-v4.gif) 35 | ![DemonAttack](movies/Saliency_DemonAttackNoFrameskip-v4.gif) 36 | ![Pong](movies/Saliency_PongNoFrameskip-v4.gif) 37 | ![SpaceInvaders](movies/Saliency_SpaceInvadersNoFrameskip-v4.gif) 38 | 39 | 40 | ## Trained models 41 | The following trained models can be downloaded from the [models](models/) directory: 42 | 43 | | Name | Environments (NoFrameskip-v4) | Steps (millions) | 44 | | ---- |------------- | ---------------- | 45 | | [AirRaid](models/AirRaid) | AirRaid | 50 | 46 | | [Carnival](models/Carnival) | Carnival | 50 | 47 | | [DemonAttack](models/DemonAttack) | DemonAttack | 50 | 48 | | [NameThisGame](models/NameThisGame) | NameThisGame | 50 | 49 | | [Pong](models/Pong) | Pong | 50 | 50 | | [SpaceInvaders](models/SpaceInvaders) | SpaceInvaders | 50 | 51 | | [MultiTask](models/MultiTask) | AirRaid, Carnival, DemonAttack, NameThisGame, Pong, SpaceInvaders | 300 | 52 | | [MultiTaskPopArt](models/MultiTaskPopArt) | AirRaid, Carnival, DemonAttack, NameThisGame, Pong, SpaceInvaders | 300 | 53 | 54 | 55 | # Running the code 56 | ## Preparation 57 | For our experiments we used the faster [PolyBeast](https://github.com/facebookresearch/torchbeast#faster-version-polybeast) implementation of TorchBeast and refer the reader to the installation instructions in the original repository. However, since we have encountered problems getting this version to work, we also added multi-task training functionality and PopArt to the [MonoBeast](https://github.com/facebookresearch/torchbeast#getting-started-monobeast) implementation of TorchBeast. Note that some of the testing functionality is not implemented for this version, but PolyBeast can be used for this if the imports for `nest` and `libtorchbeast` are commented out. 
58 | 59 | Since it is more convenient to get PolyBeast to run, these are the platforms on which we managed to install and use it: 60 | - Ubuntu 18.04 61 | - MacOS (CPU only) 62 | - Google Cloud Platform (Standard machine with NVIDIA Tesla P100 GPUs) 63 | 64 | ## Training a model 65 | ```bash 66 | python -m torchbeast.polybeast --mode train --xpid MultiTaskPopArt --env AirRaidNoFrameskip-v4,CarnivalNoFrameskip-v4,DemonAttackNoFrameskip-v4,NameThisGameNoFrameskip-v4,PongNoFrameskip-v4,SpaceInvadersNoFrameskip-v4 --total_steps 300000000 --use_popart 67 | ``` 68 | There are the following additional flags, as compared to the original TorchBeast implementation: 69 | - `use_popart`, to enable the PopArt extension 70 | - `save_model_every_nsteps`, to save intermediate models during training 71 | 72 | ### With MonoBeast 73 | ```bash 74 | python -m torchbeast.monobeast --mode train --xpid MultiTaskPopArt --env AirRaidNoFrameskip-v4,CarnivalNoFrameskip-v4,DemonAttackNoFrameskip-v4,NameThisGameNoFrameskip-v4,PongNoFrameskip-v4,SpaceInvadersNoFrameskip-v4 --total_steps 300000000 --use_popart 75 | ``` 76 | 77 | In addition MonoBeast can also be used to run two other models: a small CNN (optionally with an LSTM) and an [Attention-Augmented Agent](https://arxiv.org/abs/1906.02500) (models selected with the flag `agent_type`). 
Unfortunately we did not get this model to train properly, but for the sake of completeness and possible future reference, here are the additional flags that can be used with this model: 78 | - `frame_height` and `frame_width`, which set the dimensions to which frames are rescaled (in the original paper the original size is used as opposed to the rescaling done in TorchBeast) 79 | - `aaa_input_format` (with choices `gray_stack`, `rgb_last`, `rgb_stack`), which decides how frames are formatted as input for the network (where `rgb_last` only feeds one of every four frames in RGB, as is done in the original paper) 80 | 81 | ## Testing a model 82 | ```bash 83 | python -m torchbeast.polybeast --mode test --xpid MultiTaskPopArt --env PongNoFrameskip-v4 --savedir=./models 84 | python -m torchbeast.polybeast --mode test_render --xpid MultiTaskPopArt --env PongNoFrameskip-v4 --savedir=./models 85 | ``` 86 | 87 | ## Saliency 88 | ```bash 89 | python -m torchbeast.saliency --xpid MultiTaskPopArt --env PongNoFrameskip-v4 --first_frame 0 --num_frames 100 --savedir=./models 90 | ``` 91 | Note that compared to the original [saliency code](https://github.com/greydanus/visualize_atari), the extension does not produce a movie directly, but saves the frames as individual images. Animated gifs can subsequently be produced with a [Jupyter notebook](results/movies.ipynb). 92 | 93 | ## CNN filter comparisons 94 | **NOTE:** it is assumed that a) intermediate model checkpoints have been saved (flag `save_model_every_nsteps`) and b) the results for all models are saved in the same parent directory and have the exact names used in our experiments (see in the [table](https://github.com/aluscher/torchbeastpopart#trained-models)) 95 | ```bash 96 | python -m torchbeast.analysis.analyze_resnet --model_load_path /path/to/directory --mode filter_comp --comp_num_models 10 97 | ``` 98 | The different comparisons presented in the [report](results/report.pdf) can be set with the flag `comp_between`. 
By default the only comparisons done are between the vanilla multi-task model and the multi-task PopArt model, as well as between each of these models and all single-task models. 99 | 100 | For plotting the following command can be used (saving the figures in the same directory that the data generated by the previous command was loaded from): 101 | ```bash 102 | python -m torchbeast.analysis.analyze_resnet --load_path /path/to/directory --mode filter_comp_plot --save_figures 103 | ``` 104 | For more options to the data generation and plotting, the help texts can be consulted. 105 | 106 | # References 107 | TorchBeast 108 | ``` 109 | @article{torchbeast2019, 110 | title={{TorchBeast: A PyTorch Platform for Distributed RL}}, 111 | author={Heinrich K\"{u}ttler and Nantas Nardelli and Thibaut Lavril and Marco Selvatici and Viswanath Sivakumar and Tim Rockt\"{a}schel and Edward Grefenstette}, 112 | year={2019}, 113 | journal={arXiv preprint arXiv:1910.03552}, 114 | url={https://github.com/facebookresearch/torchbeast}, 115 | } 116 | ``` 117 | 118 | PopArt 119 | ``` 120 | @inproceedings{hessel2019, 121 | title={Multi-task deep reinforcement learning with popart}, 122 | author={Hessel, Matteo and Soyer, Hubert and Espeholt, Lasse and Czarnecki, Wojciech and Schmitt, Simon and van Hasselt, Hado}, 123 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 124 | volume={33}, 125 | pages={3796--3803}, 126 | year={2019} 127 | } 128 | ``` 129 | 130 | Saliency 131 | ``` 132 | @article{greydanus2017visualizing, 133 | title={Visualizing and Understanding Atari Agents}, 134 | author={Greydanus, Sam and Koul, Anurag and Dodge, Jonathan and Fern, Alan}, 135 | journal={arXiv preprint arXiv:1711.00138}, 136 | year={2017}, 137 | url={https://github.com/greydanus/visualize_atari}, 138 | } 139 | ``` 140 | -------------------------------------------------------------------------------- /libtorchbeast/nest_serialize.h: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | 19 | #include "../nest/nest/nest.h" 20 | #include "rpcenv.pb.h" 21 | 22 | template 23 | void fill_nest_pb(rpcenv::ArrayNest* nest_pb, nest::Nest nest, 24 | Function fill_ndarray_pb) { 25 | using Nest = nest::Nest; 26 | std::visit( 27 | nest::overloaded{ 28 | [nest_pb, &fill_ndarray_pb](const T t) { 29 | fill_ndarray_pb(nest_pb->mutable_array(), t); 30 | }, 31 | [nest_pb, &fill_ndarray_pb](const std::vector& v) { 32 | for (const Nest& n : v) { 33 | rpcenv::ArrayNest* subnest = nest_pb->add_vector(); 34 | fill_nest_pb(subnest, n, fill_ndarray_pb); 35 | } 36 | }, 37 | [nest_pb, &fill_ndarray_pb](const std::map& m) { 38 | auto* map_pb = nest_pb->mutable_map(); 39 | for (const auto& p : m) { 40 | rpcenv::ArrayNest& subnest_pb = (*map_pb)[p.first]; 41 | fill_nest_pb(&subnest_pb, p.second, fill_ndarray_pb); 42 | } 43 | }}, 44 | nest.value); 45 | } 46 | 47 | template 48 | std::invoke_result_t nest_pb_to_nest( 49 | rpcenv::ArrayNest* nest_pb, Function array_to_nest) { 50 | using Nest = std::invoke_result_t; 51 | if (nest_pb->has_array()) { 52 | return array_to_nest(nest_pb->mutable_array()); 53 | } 54 | if (nest_pb->vector_size() > 0) { 55 | std::vector v; 56 | for (int i = 0, length = nest_pb->vector_size(); i < 
length; ++i) { 57 | v.push_back(nest_pb_to_nest(nest_pb->mutable_vector(i), array_to_nest)); 58 | } 59 | return Nest(std::move(v)); 60 | } 61 | if (nest_pb->map_size() > 0) { 62 | std::map m; 63 | for (auto& p : *nest_pb->mutable_map()) { 64 | m[p.first] = nest_pb_to_nest(&p.second, array_to_nest); 65 | } 66 | return Nest(std::move(m)); 67 | } 68 | throw std::invalid_argument("ArrayNest proto contained no data."); 69 | } 70 | -------------------------------------------------------------------------------- /libtorchbeast/rpcenv.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #include 18 | #include 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | #include "nest_serialize.h" 26 | #include "rpcenv.grpc.pb.h" 27 | #include "rpcenv.pb.h" 28 | 29 | #include "../nest/nest/nest.h" 30 | #include "../nest/nest/nest_pybind.h" 31 | 32 | namespace py = pybind11; 33 | 34 | typedef nest::Nest PyArrayNest; 35 | 36 | namespace rpcenv { 37 | class EnvServer { 38 | private: 39 | class ServiceImpl final : public RPCEnvServer::Service { 40 | public: 41 | ServiceImpl(py::object env_init) : env_init_(env_init) {} 42 | 43 | private: 44 | virtual grpc::Status StreamingEnv( 45 | grpc::ServerContext *context, 46 | grpc::ServerReaderWriter *stream) override { 47 | py::gil_scoped_acquire acquire; // Destroy after pyenv. 48 | py::object pyenv; 49 | py::object stepfunc; 50 | py::object resetfunc; 51 | 52 | PyArrayNest observation; 53 | float reward = 0.0; 54 | bool done = true; 55 | int task = 0; 56 | int episode_step = 0; 57 | float episode_return = 0.0; 58 | 59 | auto set_observation = py::cpp_function( 60 | [&observation](PyArrayNest o) { observation = std::move(o); }, 61 | py::arg("observation")); 62 | 63 | auto set_observation_reward_done_task = py::cpp_function( 64 | [&observation, &reward, &done, &task](PyArrayNest o, float r, bool d, int t, 65 | py::args) { 66 | observation = std::move(o); 67 | reward = r; 68 | done = d; 69 | task = t; 70 | }, 71 | py::arg("observation"), py::arg("reward"), py::arg("done"), py::arg("task")); 72 | 73 | try { 74 | pyenv = env_init_(); 75 | stepfunc = pyenv.attr("step"); 76 | resetfunc = pyenv.attr("reset"); 77 | set_observation(resetfunc()); 78 | } catch (const pybind11::error_already_set &e) { 79 | // Needs to be caught and not re-raised, as this isn't in a Python 80 | // thread. 
81 | std::cerr << e.what() << std::endl; 82 | return grpc::Status(grpc::INTERNAL, e.what()); 83 | } 84 | 85 | Step step_pb; 86 | fill_nest_pb(step_pb.mutable_observation(), std::move(observation), 87 | fill_ndarray_pb); 88 | 89 | step_pb.set_reward(reward); 90 | step_pb.set_done(done); 91 | step_pb.set_task(task); 92 | step_pb.set_episode_step(episode_step); 93 | step_pb.set_episode_return(episode_return); 94 | 95 | Action action_pb; 96 | while (true) { 97 | { 98 | py::gil_scoped_release release; // Release while doing transfer. 99 | stream->Write(step_pb); 100 | if (!stream->Read(&action_pb)) { 101 | break; 102 | } 103 | } 104 | try { 105 | // I'm not sure if this is fast, but it's convenient. 106 | set_observation_reward_done_task(*stepfunc(nest_pb_to_nest( 107 | action_pb.mutable_nest_action(), array_pb_to_nest))); 108 | 109 | episode_step += 1; 110 | episode_return += reward; 111 | 112 | step_pb.Clear(); 113 | step_pb.set_reward(reward); 114 | step_pb.set_done(done); 115 | step_pb.set_task(task); 116 | step_pb.set_episode_step(episode_step); 117 | step_pb.set_episode_return(episode_return); 118 | if (done) { 119 | set_observation(resetfunc()); 120 | // Reset episode_* for the _next_ step. 121 | episode_step = 0; 122 | episode_return = 0.0; 123 | } 124 | } catch (const pybind11::error_already_set &e) { 125 | std::cerr << e.what() << std::endl; 126 | return grpc::Status(grpc::INTERNAL, e.what()); 127 | } 128 | 129 | fill_nest_pb(step_pb.mutable_observation(), std::move(observation), 130 | fill_ndarray_pb); 131 | } 132 | return grpc::Status::OK; 133 | } 134 | 135 | py::object env_init_; // TODO: Make sure GIL is held when destroyed. 
136 | 137 | // TODO: Add observation and action size functions (pre-load env) 138 | }; 139 | 140 | public: 141 | EnvServer(py::object env_class, const std::string &server_address) 142 | : server_address_(server_address), 143 | service_(env_class), 144 | server_(nullptr) {} 145 | 146 | void run() { 147 | if (server_) { 148 | throw std::runtime_error("Server already running"); 149 | } 150 | py::gil_scoped_release release; 151 | 152 | grpc::ServerBuilder builder; 153 | builder.AddListeningPort(server_address_, 154 | grpc::InsecureServerCredentials()); 155 | builder.RegisterService(&service_); 156 | server_ = builder.BuildAndStart(); 157 | std::cerr << "Server listening on " << server_address_ << std::endl; 158 | 159 | server_->Wait(); 160 | } 161 | 162 | void stop() { 163 | if (!server_) { 164 | throw std::runtime_error("Server not running"); 165 | } 166 | server_->Shutdown(); 167 | } 168 | 169 | static void fill_ndarray_pb(rpcenv::NDArray *array, py::array pyarray) { 170 | // Make sure array is C-style contiguous. If it isn't, this creates 171 | // another memcopy that is not strictly necessary. 172 | if ((pyarray.flags() & py::array::c_style) == 0) { 173 | pyarray = py::array::ensure(pyarray, py::array::c_style); 174 | } 175 | 176 | // This seems surprisingly involved. An alternative would be to include 177 | // numpy/arrayobject.h and use PyArray_TYPE. 178 | int type_num = 179 | py::detail::array_descriptor_proxy(pyarray.dtype().ptr())->type_num; 180 | 181 | array->set_dtype(type_num); 182 | for (size_t i = 0, ndim = pyarray.ndim(); i < ndim; ++i) { 183 | array->add_shape(pyarray.shape(i)); 184 | } 185 | 186 | // TODO: Consider set_allocated_data. 187 | // TODO: consider [ctype = STRING_VIEW] in proto file. 
188 | py::buffer_info info = pyarray.request(); 189 | array->set_data(info.ptr, info.itemsize * info.size); 190 | } 191 | 192 | static PyArrayNest array_pb_to_nest(rpcenv::NDArray *array_pb) { 193 | std::vector shape; 194 | for (int i = 0, length = array_pb->shape_size(); i < length; ++i) { 195 | shape.push_back(array_pb->shape(i)); 196 | } 197 | 198 | // Somewhat complex way of turning an type_num into a py::dtype. 199 | py::dtype dtype = py::reinterpret_borrow( 200 | py::detail::npy_api::get().PyArray_DescrFromType_(array_pb->dtype())); 201 | 202 | std::string *data = array_pb->release_data(); 203 | 204 | // Attach capsule as base in order to free data. 205 | return PyArrayNest(py::array(dtype, shape, {}, data->data(), 206 | py::capsule(data, [](void *ptr) { 207 | delete reinterpret_cast(ptr); 208 | }))); 209 | } 210 | 211 | private: 212 | const std::string server_address_; 213 | ServiceImpl service_; 214 | std::unique_ptr server_; 215 | }; 216 | 217 | } // namespace rpcenv 218 | 219 | PYBIND11_MODULE(rpcenv, m) { 220 | py::class_(m, "Server") 221 | .def(py::init(), py::arg("env_class"), 222 | py::arg("server_address") = "unix:/tmp/polybeast") 223 | .def("run", &rpcenv::EnvServer::run) 224 | .def("stop", &rpcenv::EnvServer::stop); 225 | } 226 | -------------------------------------------------------------------------------- /libtorchbeast/rpcenv.proto: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | syntax = "proto2"; 18 | 19 | package rpcenv; 20 | 21 | message Action { 22 | optional int32 action = 1; 23 | optional ArrayNest nest_action = 2; 24 | } 25 | 26 | message NDArray { 27 | optional int32 dtype = 1; 28 | repeated int64 shape = 2 [packed = true]; 29 | optional bytes data = 3; 30 | }; 31 | 32 | message ArrayNest { 33 | optional NDArray array = 1; 34 | repeated ArrayNest vector = 2; 35 | map map = 3; 36 | }; 37 | 38 | message Step { 39 | optional ArrayNest observation = 1; 40 | optional float reward = 2; 41 | optional bool done = 3; 42 | optional int32 task = 4; 43 | optional int32 episode_step = 5; 44 | optional float episode_return = 6; 45 | } 46 | 47 | service RPCEnvServer { 48 | rpc StreamingEnv(stream Action) returns (stream Step) {} 49 | } 50 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import datetime 4 | 5 | parser = argparse.ArgumentParser() 6 | 7 | parser.add_argument("--env", type=str, default="PongNoFrameskip-v4", 8 | help="Gym environments") 9 | parser.add_argument("--steps", type=int, default=100000, help="Number of steps") 10 | 11 | def main(flags): 12 | cwd = os.getcwd() 13 | directory = os.path.join(cwd, "logs", datetime.datetime.now().strftime("%Y%m%d%H%M%S")) 14 | os.makedirs(directory) 15 | 16 | arguments = ( 17 | f'--env {flags.env} ' 18 | f'--savedir {directory} ' 19 | f'--total_steps {flags.steps} ' 20 | 
'--batch_size 32 ' 21 | ) 22 | print(f'python -m torchbeast.monobeast {arguments}') 23 | os.system(f'python -m torchbeast.monobeast {arguments}') 24 | 25 | 26 | if __name__ == "__main__": 27 | flags = parser.parse_args() 28 | main(flags) 29 | -------------------------------------------------------------------------------- /models/AirRaid/meta.json: -------------------------------------------------------------------------------- 1 | { 2 | "args": { 3 | "alpha": 0.99, 4 | "baseline_cost": 0.5, 5 | "batch_size": 8, 6 | "disable_checkpoint": false, 7 | "disable_cuda": false, 8 | "discounting": 0.99, 9 | "entropy_cost": 0.0006, 10 | "env": "AirRaidNoFrameskip-v4", 11 | "epsilon": 0.01, 12 | "grad_norm_clipping": 40.0, 13 | "learning_rate": 0.00048, 14 | "max_learner_queue_size": null, 15 | "mode": "train", 16 | "momentum": 0, 17 | "num_actions": 6, 18 | "num_actors": 48, 19 | "num_episodes": 100, 20 | "num_inference_threads": 2, 21 | "num_learner_threads": 2, 22 | "pipes_basename": "unix:/tmp/polybeast", 23 | "reward_clipping": "abs_one", 24 | "save_model_every_nsteps": 1000000, 25 | "savedir": "~/logs/torchbeast", 26 | "start_servers": true, 27 | "total_steps": 50000000, 28 | "unroll_length": 80, 29 | "use_lstm": false, 30 | "write_profiler_trace": false, 31 | "xpid": "AirRaid" 32 | }, 33 | "date_end": null, 34 | "date_start": "2019-12-15 22:02:21.601356", 35 | "xpid": "AirRaid" 36 | } -------------------------------------------------------------------------------- /models/AirRaid/model.tar: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:25aec8c63e5dd953ba2dbb387c88ddac0efd5f0a4157e8b595fc94157c3866fb 3 | size 8743338 4 | -------------------------------------------------------------------------------- /models/Carnival/meta.json: -------------------------------------------------------------------------------- 1 | { 2 | "args": { 3 | "alpha": 0.99, 4 | "baseline_cost": 0.5, 5 
| "batch_size": 8, 6 | "disable_checkpoint": false, 7 | "disable_cuda": false, 8 | "discounting": 0.99, 9 | "entropy_cost": 0.0006, 10 | "env": "CarnivalNoFrameskip-v4", 11 | "epsilon": 0.01, 12 | "grad_norm_clipping": 40.0, 13 | "learning_rate": 0.00048, 14 | "max_learner_queue_size": null, 15 | "mode": "train", 16 | "momentum": 0, 17 | "num_actions": 6, 18 | "num_actors": 48, 19 | "num_episodes": 100, 20 | "num_inference_threads": 2, 21 | "num_learner_threads": 2, 22 | "pipes_basename": "unix:/tmp/polybeast", 23 | "reward_clipping": "abs_one", 24 | "save_model_every_nsteps": 1000000, 25 | "savedir": "~/logs/torchbeast", 26 | "start_servers": true, 27 | "total_steps": 50000000, 28 | "unroll_length": 80, 29 | "use_lstm": false, 30 | "write_profiler_trace": false, 31 | "xpid": "Carnival" 32 | }, 33 | "date_end": null, 34 | "date_start": "2019-12-14 16:19:26.111307", 35 | "xpid": "Carnival" 36 | } -------------------------------------------------------------------------------- /models/Carnival/model.tar: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:795bfa85e540e2ff4850e686cd702e2854863ec7a560b4db79a78c786b754b52 3 | size 8743666 4 | -------------------------------------------------------------------------------- /models/DemonAttack/meta.json: -------------------------------------------------------------------------------- 1 | { 2 | "args": { 3 | "alpha": 0.99, 4 | "baseline_cost": 0.5, 5 | "batch_size": 8, 6 | "disable_checkpoint": false, 7 | "disable_cuda": false, 8 | "discounting": 0.99, 9 | "entropy_cost": 0.0006, 10 | "env": "DemonAttackNoFrameskip-v4", 11 | "epsilon": 0.01, 12 | "grad_norm_clipping": 40.0, 13 | "learning_rate": 0.00048, 14 | "max_learner_queue_size": null, 15 | "mode": "train", 16 | "momentum": 0, 17 | "num_actions": 6, 18 | "num_actors": 48, 19 | "num_episodes": 100, 20 | "num_inference_threads": 2, 21 | "num_learner_threads": 2, 22 | 
"pipes_basename": "unix:/tmp/polybeast", 23 | "reward_clipping": "abs_one", 24 | "save_model_every_nsteps": 1000000, 25 | "savedir": "~/logs/torchbeast", 26 | "start_servers": true, 27 | "total_steps": 50000000, 28 | "unroll_length": 80, 29 | "use_lstm": false, 30 | "write_profiler_trace": false, 31 | "xpid": "DemonAttack" 32 | }, 33 | "date_end": null, 34 | "date_start": "2019-12-14 20:57:48.511132", 35 | "xpid": "DemonAttack" 36 | } -------------------------------------------------------------------------------- /models/DemonAttack/model.tar: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:925f478cc15b0dd10a4c0952a265d4a84d7b71961d3d854797e2dd334ffac48c 3 | size 8743552 4 | -------------------------------------------------------------------------------- /models/MultiTask/meta.json: -------------------------------------------------------------------------------- 1 | { 2 | "args": { 3 | "alpha": 0.99, 4 | "baseline_cost": 0.5, 5 | "batch_size": 8, 6 | "disable_checkpoint": false, 7 | "disable_cuda": false, 8 | "discounting": 0.99, 9 | "entropy_cost": 0.0006, 10 | "env": "AirRaidNoFrameskip-v4,CarnivalNoFrameskip-v4,DemonAttackNoFrameskip-v4,NameThisGameNoFrameskip-v4,PongNoFrameskip-v4,SpaceInvadersNoFrameskip-v4", 11 | "epsilon": 0.01, 12 | "grad_norm_clipping": 40.0, 13 | "learning_rate": 0.00048, 14 | "max_learner_queue_size": null, 15 | "mode": "train", 16 | "momentum": 0, 17 | "num_actions": 6, 18 | "num_actors": 48, 19 | "num_episodes": 100, 20 | "num_inference_threads": 2, 21 | "num_learner_threads": 2, 22 | "pipes_basename": "unix:/tmp/polybeast", 23 | "reward_clipping": "abs_one", 24 | "save_model_every_nsteps": 1000000, 25 | "savedir": "~/logs/torchbeast", 26 | "start_servers": true, 27 | "total_steps": 200000000, 28 | "unroll_length": 80, 29 | "use_lstm": false, 30 | "write_profiler_trace": false, 31 | "xpid": "MultiTask" 32 | }, 33 | "date_end": null, 34 | 
"date_start": "2019-12-01 22:39:18.416263", 35 | "xpid": "MultiTask" 36 | } -------------------------------------------------------------------------------- /models/MultiTask/model.tar: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8a893ebb439f6edcea972df3654172fb21714ef5a387f72e9b77070a9c647148 3 | size 8743413 4 | -------------------------------------------------------------------------------- /models/MultiTaskPopArt/meta.json: -------------------------------------------------------------------------------- 1 | { 2 | "args": { 3 | "alpha": 0.99, 4 | "baseline_cost": 0.5, 5 | "batch_size": 8, 6 | "beta": 0.0001, 7 | "disable_checkpoint": false, 8 | "disable_cuda": false, 9 | "discounting": 0.99, 10 | "entropy_cost": 0.0006, 11 | "env": "AirRaidNoFrameskip-v4,CarnivalNoFrameskip-v4,DemonAttackNoFrameskip-v4,NameThisGameNoFrameskip-v4,PongNoFrameskip-v4,SpaceInvadersNoFrameskip-v4", 12 | "epsilon": 0.01, 13 | "grad_norm_clipping": 40.0, 14 | "intermediate_model_id": null, 15 | "learning_rate": 0.00048, 16 | "max_learner_queue_size": null, 17 | "mode": "train", 18 | "momentum": 0, 19 | "num_actions": 6, 20 | "num_actors": 48, 21 | "num_episodes": 100, 22 | "num_inference_threads": 2, 23 | "num_learner_threads": 2, 24 | "num_tasks": 6, 25 | "pipes_basename": "unix:/tmp/polybeast", 26 | "reward_clipping": "none", 27 | "save_model_every_nsteps": 1000000, 28 | "savedir": "/home/andi/logs/torchbeast", 29 | "start_servers": true, 30 | "total_steps": 300000000, 31 | "unroll_length": 80, 32 | "use_lstm": false, 33 | "use_popart": true, 34 | "write_profiler_trace": false, 35 | "xpid": "MultiTaskPopart" 36 | }, 37 | "date_end": null, 38 | "date_start": "2019-12-26 12:53:50.139174", 39 | "xpid": "MultiTaskPopart" 40 | } -------------------------------------------------------------------------------- /models/MultiTaskPopArt/model.tar: 
-------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:92c6c88683aec345da4d5ed672f17c2be0543cffc3988a963a7be65a4d6def8a 3 | size 8755003 4 | -------------------------------------------------------------------------------- /models/NameThisGame/meta.json: -------------------------------------------------------------------------------- 1 | { 2 | "args": { 3 | "alpha": 0.99, 4 | "baseline_cost": 0.5, 5 | "batch_size": 8, 6 | "disable_checkpoint": false, 7 | "disable_cuda": false, 8 | "discounting": 0.99, 9 | "entropy_cost": 0.0006, 10 | "env": "NameThisGameNoFrameskip-v4", 11 | "epsilon": 0.01, 12 | "grad_norm_clipping": 40.0, 13 | "learning_rate": 0.00048, 14 | "max_learner_queue_size": null, 15 | "mode": "train", 16 | "momentum": 0, 17 | "num_actions": 6, 18 | "num_actors": 48, 19 | "num_episodes": 100, 20 | "num_inference_threads": 2, 21 | "num_learner_threads": 2, 22 | "pipes_basename": "unix:/tmp/polybeast", 23 | "reward_clipping": "abs_one", 24 | "save_model_every_nsteps": 1000000, 25 | "savedir": "~/logs/torchbeast", 26 | "start_servers": true, 27 | "total_steps": 50000000, 28 | "unroll_length": 80, 29 | "use_lstm": false, 30 | "write_profiler_trace": false, 31 | "xpid": "NameThisGame" 32 | }, 33 | "date_end": null, 34 | "date_start": "2019-12-15 01:17:56.704669", 35 | "xpid": "NameThisGame" 36 | } -------------------------------------------------------------------------------- /models/NameThisGame/model.tar: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e259777dff3d1fca083f134210b409799270230d027073b8621dfc9d862b20d1 3 | size 8743554 4 | -------------------------------------------------------------------------------- /models/Pong/meta.json: -------------------------------------------------------------------------------- 1 | { 2 | "args": { 3 | "alpha": 0.99, 4 | "baseline_cost": 
0.5, 5 | "batch_size": 8, 6 | "disable_checkpoint": false, 7 | "disable_cuda": false, 8 | "discounting": 0.99, 9 | "entropy_cost": 0.0006, 10 | "env": "PongNoFrameskip-v4", 11 | "epsilon": 0.01, 12 | "grad_norm_clipping": 40.0, 13 | "learning_rate": 0.00048, 14 | "max_learner_queue_size": null, 15 | "mode": "train", 16 | "momentum": 0, 17 | "num_actions": 6, 18 | "num_actors": 48, 19 | "num_episodes": 100, 20 | "num_inference_threads": 2, 21 | "num_learner_threads": 2, 22 | "pipes_basename": "unix:/tmp/polybeast", 23 | "reward_clipping": "abs_one", 24 | "save_model_every_nsteps": 1000000, 25 | "savedir": "~/logs/torchbeast", 26 | "start_servers": true, 27 | "total_steps": 50000000, 28 | "unroll_length": 80, 29 | "use_lstm": false, 30 | "write_profiler_trace": false, 31 | "xpid": "Pong" 32 | }, 33 | "date_end": null, 34 | "date_start": "2019-12-15 17:44:40.505797", 35 | "xpid": "Pong" 36 | } -------------------------------------------------------------------------------- /models/Pong/model.tar: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:77934665c8ced6d663f67d52fc0f5f6e0436e19787796096cbd4923a1c89b524 3 | size 8743332 4 | -------------------------------------------------------------------------------- /models/SpaceInvaders/meta.json: -------------------------------------------------------------------------------- 1 | { 2 | "args": { 3 | "alpha": 0.99, 4 | "baseline_cost": 0.5, 5 | "batch_size": 8, 6 | "disable_checkpoint": false, 7 | "disable_cuda": false, 8 | "discounting": 0.99, 9 | "entropy_cost": 0.0006, 10 | "env": "SpaceInvadersNoFrameskip-v4", 11 | "epsilon": 0.01, 12 | "grad_norm_clipping": 40.0, 13 | "learning_rate": 0.00048, 14 | "max_learner_queue_size": null, 15 | "mode": "train", 16 | "momentum": 0, 17 | "num_actions": 6, 18 | "num_actors": 48, 19 | "num_episodes": 100, 20 | "num_inference_threads": 2, 21 | "num_learner_threads": 2, 22 | "pipes_basename": 
"unix:/tmp/polybeast", 23 | "reward_clipping": "abs_one", 24 | "save_model_every_nsteps": 1000000, 25 | "savedir": "~/logs/torchbeast", 26 | "start_servers": true, 27 | "total_steps": 50000000, 28 | "unroll_length": 80, 29 | "use_lstm": false, 30 | "write_profiler_trace": false, 31 | "xpid": "SpaceInvaders" 32 | }, 33 | "date_end": null, 34 | "date_start": "2019-12-15 06:08:37.314705", 35 | "xpid": "SpaceInvaders" 36 | } -------------------------------------------------------------------------------- /models/SpaceInvaders/model.tar: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:fb2c9b5e32324edc2c7a2e3da0c27e90af6a8b8756db2ca41eaa6d1e2d0c05fe 3 | size 8743350 4 | -------------------------------------------------------------------------------- /movies/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/movies/.DS_Store -------------------------------------------------------------------------------- /movies/AirRaid_050009600_AirRaidNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:20d6e5882928c1ea7e1fccdf8b1c02d1047805baee376f7d522a81ad97c56ced 3 | size 5248811 4 | -------------------------------------------------------------------------------- /movies/Carnival_050002560_CarnivalNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:9e925f1da639347d22136e4ea4b46054a97aea3d50f9c39faa8d892706202eda 3 | size 791399 4 | -------------------------------------------------------------------------------- /movies/DemonAttack_050001280_DemonAttackNoFrameskip-v4.gif: 
-------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:44b083720b91cac63661587c472dc9e84a46d2bfba05c9358102dc0f69523332 3 | size 15148370 4 | -------------------------------------------------------------------------------- /movies/MultiTaskPopart_300010240_AirRaidNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:0ad8d381e64c9f25b2b63d248cd808d18b3ad45c5af7042de297ea71dceaa858 3 | size 1959116 4 | -------------------------------------------------------------------------------- /movies/MultiTaskPopart_300010240_CarnivalNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:df23f5afef6e665c207dcd208405687bd3ed8335b27f72df47929f34a057a60c 3 | size 1220370 4 | -------------------------------------------------------------------------------- /movies/MultiTaskPopart_300010240_DemonAttackNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6cbeec62e5e68c2d6b131d324fec7f4fe380fe437ee660c00ece41e02b903cdd 3 | size 10661977 4 | -------------------------------------------------------------------------------- /movies/MultiTaskPopart_300010240_NameThisGameNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:477f6383d81a1b7dcd09b90eb12248d91790f8767298560371d2a201a6e966cb 3 | size 15082873 4 | -------------------------------------------------------------------------------- /movies/MultiTaskPopart_300010240_PongNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | 
oid sha256:ae1a89efb002e429fbdd16f0d284799feae4565c1af6e052f83cf03e318f0376 3 | size 4683088 4 | -------------------------------------------------------------------------------- /movies/MultiTaskPopart_300010240_SpaceInvadersNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a84823a6fab3372b1e056c58207b3b8f62e92fba4b491f3ff60bc4424d26654f 3 | size 14551533 4 | -------------------------------------------------------------------------------- /movies/MultiTask_300014720_AirRaidNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:eb54f0448824b922c30282b47d07417a6b50af17b124ed13ed0bfcdbd1083dd2 3 | size 4610796 4 | -------------------------------------------------------------------------------- /movies/MultiTask_300014720_CarnivalNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:abbe2e3fc85d9adc33a848bbfa10148d539ee501b3d2d4748c6ea1dff8609118 3 | size 985159 4 | -------------------------------------------------------------------------------- /movies/MultiTask_300014720_DemonAttackNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5080bb24c3d2cd58fdc2b1f2d36d40b9dccbab64da89df2fe3c530871813e01d 3 | size 9312607 4 | -------------------------------------------------------------------------------- /movies/MultiTask_300014720_NameThisGameNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:bd876fa81499e941d02e14ca5c78e60a8389d4225f80be91d1a5b67498fa6031 3 | size 5272279 4 | 
-------------------------------------------------------------------------------- /movies/MultiTask_300014720_PongNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e7ed79a88188701f3fc140f24655b2b22653065f67aec3b2a3152f4e7282e79d 3 | size 5136986 4 | -------------------------------------------------------------------------------- /movies/MultiTask_300014720_SpaceInvadersNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:9271fab2f8dd0bec45d720d4c86c8521a03acae789d33de699f1f3a5dcbf0103 3 | size 6030426 4 | -------------------------------------------------------------------------------- /movies/NameThisGame_050016000_NameThisGameNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d70c0e95d4c41687376ddbb4f5a549c895ae2c34bde71ada67055229466ce69c 3 | size 31711322 4 | -------------------------------------------------------------------------------- /movies/Pong_050013440_PongNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5a5dfb5bee2b02a7b5399fa4c90e25dc22d20577206df32f717703d073c9ff5f 3 | size 4477437 4 | -------------------------------------------------------------------------------- /movies/Saliency_AirRaidNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6f81df34cdc9391a43a1cc51bd5546c776baa75c157560d1939a806194fb0375 3 | size 95215101 4 | -------------------------------------------------------------------------------- /movies/Saliency_CarnivalNoFrameskip-v4.gif: 
-------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5da399bd4a84fb78e7f13960eed9141bfd78d5926be4a301433d8f6ab5f64656 3 | size 14760148 4 | -------------------------------------------------------------------------------- /movies/Saliency_DemonAttackNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:dcc1c9d9d2fb497b8ac15b6bef4b4df2072bc36cd69a501570ac32eef9569055 3 | size 220934216 4 | -------------------------------------------------------------------------------- /movies/Saliency_NameThisGameNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e4c8804c7175f026e193cc1a01a52da53602ebc9eb76ee9eaaf95cf970ddc604 3 | size 252520562 4 | -------------------------------------------------------------------------------- /movies/Saliency_PongNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:1a981f8ccabc440c692d16f1c9ce5936764471c7d682437846f2272eb2dd58a7 3 | size 96148651 4 | -------------------------------------------------------------------------------- /movies/Saliency_SpaceInvadersNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ebe9baa0edae788b3691d54790df77d9e553ff0d3c679cbd9c78c1514ae29c5d 3 | size 143983549 4 | -------------------------------------------------------------------------------- /movies/SpaceInvaders_050001280_SpaceInvadersNoFrameskip-v4.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid 
sha256:c5aa9209db9a17948bb47c6b2849d3991423efa79df1bbca99cd33a4dc2044e2 3 | size 9412632 4 | -------------------------------------------------------------------------------- /nest/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Nest library 3 | 4 | ```shell 5 | CXX=c++ pip install . -vv 6 | ``` 7 | 8 | Usage in Python: 9 | 10 | ```python 11 | import torch 12 | import nest 13 | 14 | t1 = torch.tensor(0) 15 | t2 = torch.tensor(1) 16 | d = {'hey': torch.tensor(2)} 17 | 18 | print(nest.map(lambda t: t + 42, (t1, t2, d))) 19 | # --> (tensor(42), tensor(43), {'hey': tensor(44)}) 20 | ``` 21 | -------------------------------------------------------------------------------- /nest/nest/nest_pybind.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | #include "nest.h" 22 | #include "nest_pybind.h" 23 | 24 | namespace py = pybind11; 25 | 26 | typedef nest::Nest PyNest; 27 | 28 | class py_list_back_inserter { 29 | public: 30 | py_list_back_inserter(py::list &l) : list_(&l) {} 31 | py_list_back_inserter &operator=(const py::object &value) { 32 | list_->append(value); 33 | return *this; 34 | }; 35 | constexpr py_list_back_inserter &operator*() { return *this; }; 36 | constexpr py_list_back_inserter &operator++() { return *this; } 37 | constexpr py_list_back_inserter &operator++(int) { return *this; } 38 | 39 | private: 40 | py::list *list_; 41 | }; 42 | 43 | PYBIND11_MODULE(nest, m) { 44 | m.def("map", [](py::function f, const PyNest &n) { 45 | // This says const py::object, but f can actually modify it! 46 | std::function cppf = 47 | [&f](const py::object &arg) { return f(arg); }; 48 | return n.map(cppf); 49 | }); 50 | m.def("map_many", 51 | [](const std::function &)> &f, 52 | py::args args) { 53 | std::vector nests = args.cast>(); 54 | return PyNest::zip(nests).map(f); 55 | }); 56 | m.def("map_many2", [](const std::function &f, 58 | const PyNest &n1, const PyNest &n2) { 59 | try { 60 | return PyNest::map2(f, n1, n2); 61 | } catch (const std::invalid_argument &e) { 62 | // IDK why I have to do this manually. 63 | throw py::value_error(e.what()); 64 | } 65 | }); 66 | m.def("flatten", [](const PyNest &n) { 67 | py::list result; 68 | n.flatten(py_list_back_inserter(result)); 69 | return result; 70 | }); 71 | m.def("pack_as", [](const PyNest &n, const py::sequence &sequence) { 72 | try { 73 | return n.pack_as(sequence.begin(), sequence.end()); 74 | } catch (const std::exception &e) { 75 | // PyTorch pybind11 doesn't seem to translate exceptions? 
76 | throw py::value_error(e.what()); 77 | } 78 | }); 79 | m.def("front", [](const PyNest &n) { return n.front(); }); 80 | } 81 | -------------------------------------------------------------------------------- /nest/nest/nest_pybind.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | 22 | #include "nest.h" 23 | 24 | namespace pybind11 { 25 | namespace detail { 26 | template 27 | struct type_caster> { 28 | using ValueNest = nest::Nest; 29 | using value_conv = make_caster; 30 | 31 | public: 32 | PYBIND11_TYPE_CASTER(ValueNest, _("Nest[") + value_conv::name + _("]")); 33 | 34 | bool load(handle src, bool convert) { 35 | if (!src.ptr()) { 36 | return false; 37 | } 38 | if (isinstance(src) || isinstance(src)) { 39 | value.value = std::move(src).cast>(); 40 | return true; 41 | } 42 | if (isinstance(src)) { 43 | value.value = std::move(src).cast>(); 44 | return true; 45 | } 46 | 47 | value_conv conv; 48 | if (!conv.load(src, convert)) return false; 49 | 50 | value.value = cast_op(std::move(conv)); 51 | return true; 52 | } 53 | 54 | static handle cast(ValueNest&& src, return_value_policy policy, 55 | handle parent) { 56 | return std::visit( 57 | nest::overloaded{ 58 | [&policy, &parent](Value&& t) { 59 | return 
value_conv::cast(std::move(t), policy, parent); 60 | }, 61 | [&policy, &parent](std::vector&& v) { 62 | object py_list = reinterpret_steal( 63 | list_caster, ValueNest>::cast( 64 | std::move(v), policy, parent)); 65 | 66 | return handle(PyList_AsTuple(py_list.ptr())); 67 | }, 68 | [&policy, &parent](std::map&& m) { 69 | return map_caster::cast( 70 | std::move(m), policy, parent); 71 | }}, 72 | std::move(src.value)); 73 | } 74 | 75 | static handle cast(const ValueNest& src, return_value_policy policy, 76 | handle parent) { 77 | return std::visit( 78 | nest::overloaded{ 79 | [&policy, &parent](const Value& t) { 80 | return value_conv::cast(t, policy, parent); 81 | }, 82 | [&policy, &parent](const std::vector& v) { 83 | object py_list = reinterpret_steal( 84 | list_caster, ValueNest>::cast( 85 | v, policy, parent)); 86 | 87 | return handle(PyList_AsTuple(py_list.ptr())); 88 | }, 89 | [&policy, &parent](const std::map& m) { 90 | return map_caster::cast( 91 | m, policy, parent); 92 | }}, 93 | src.value); 94 | } 95 | }; 96 | } // namespace detail 97 | } // namespace pybind11 98 | -------------------------------------------------------------------------------- /nest/nest_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import sys 16 | import unittest 17 | 18 | import nest 19 | import torch 20 | 21 | 22 | class NestTest(unittest.TestCase): 23 | def setUp(self): 24 | self.n1 = ("Test", ["More", 32], {"h": 4}) 25 | self.n2 = ("Test", ("More", 32, (None, 43, ())), {"h": 4}) 26 | 27 | def test_nest_flatten_no_asserts(self): 28 | t = torch.tensor(1) 29 | t2 = torch.tensor(2) 30 | n = (t, t2) 31 | d = {"hey": t} 32 | 33 | nest.flatten((t, t2)) 34 | nest.flatten(d) 35 | nest.flatten((d, t)) 36 | nest.flatten((d, n, t)) 37 | 38 | nest.flatten(((t, t2), (t, t2))) 39 | 40 | nest.flatten(self.n1) 41 | nest.flatten(self.n2) 42 | 43 | d2 = {"hey": t2, "there": d, "more": t2} 44 | nest.flatten(d2) # Careful here, order not necessarily as above. 45 | 46 | def test_nest_map(self): 47 | t1 = torch.tensor(0) 48 | t2 = torch.tensor(1) 49 | d = {"hey": t2} 50 | 51 | n = nest.map(lambda t: t + 42, (t1, t2)) 52 | 53 | self.assertSequenceEqual(n, [t1 + 42, t2 + 42]) 54 | self.assertSequenceEqual(n, nest.flatten(n)) 55 | 56 | n1 = (d, n, t1) 57 | n2 = nest.map(lambda t: t * 2, n1) 58 | 59 | self.assertEqual(n2[0], {"hey": torch.tensor(2)}) 60 | self.assertEqual(n2[1], (torch.tensor(84), torch.tensor(86))) 61 | self.assertEqual(n2[2], torch.tensor(0)) 62 | 63 | t = torch.tensor(42) 64 | 65 | # Doesn't work with pybind11/functional.h, but does with py::function. 
66 | self.assertEqual(nest.map(t.add, t2), torch.tensor(43)) 67 | 68 | def test_nest_flatten(self): 69 | self.assertEqual(nest.flatten(None), [None]) 70 | self.assertEqual(nest.flatten(self.n1), ["Test", "More", 32, 4]) 71 | 72 | def test_nest_pack_as(self): 73 | self.assertEqual(self.n2, nest.pack_as(self.n2, nest.flatten(self.n2))) 74 | 75 | with self.assertRaisesRegex(ValueError, "didn't exhaust sequence"): 76 | nest.pack_as(self.n2, nest.flatten(self.n2) + [None]) 77 | with self.assertRaisesRegex(ValueError, "Too few elements"): 78 | nest.pack_as(self.n2, nest.flatten(self.n2)[1:]) 79 | 80 | def test_nest_map_many2(self): 81 | def f(a, b): 82 | return (b, a) 83 | 84 | self.assertEqual(nest.map_many2(f, (1, 2), (3, 4)), ((3, 1), (4, 2))) 85 | 86 | with self.assertRaisesRegex(ValueError, "got 2 vs 1"): 87 | nest.map_many2(f, (1, 2), (3,)) 88 | 89 | self.assertEqual(nest.map_many2(f, {"a": 1}, {"a": 2}), {"a": (2, 1)}) 90 | 91 | with self.assertRaisesRegex(ValueError, "same keys"): 92 | nest.map_many2(f, {"a": 1}, {"b": 2}) 93 | 94 | with self.assertRaisesRegex(ValueError, "1 vs 0"): 95 | nest.map_many2(f, {"a": 1}, {}) 96 | 97 | with self.assertRaisesRegex(ValueError, "nests don't match"): 98 | nest.map_many2(f, {"a": 1}, ()) 99 | 100 | def test_nest_map_many(self): 101 | def f(a): 102 | return (a[1], a[0]) 103 | 104 | self.assertEqual(nest.map_many(f, (1, 2), (3, 4)), ((3, 1), (4, 2))) 105 | 106 | return 107 | with self.assertRaisesRegex(ValueError, "got 2 vs 1"): 108 | nest.map_many(f, (1, 2), (3,)) 109 | 110 | self.assertEqual(nest.map_many(f, {"a": 1}, {"a": 2}), {"a": (2, 1)}) 111 | 112 | with self.assertRaisesRegex(ValueError, "same keys"): 113 | nest.map_many(f, {"a": 1}, {"b": 2}) 114 | 115 | with self.assertRaisesRegex(ValueError, "1 vs 0"): 116 | nest.map_many(f, {"a": 1}, {}) 117 | 118 | with self.assertRaisesRegex(ValueError, "nests don't match"): 119 | nest.map_many(f, {"a": 1}, ()) 120 | 121 | def test_front(self): 122 | 
self.assertEqual(nest.front((1, 2, 3)), 1) 123 | self.assertEqual(nest.front((2, 3)), 2) 124 | self.assertEqual(nest.front((3,)), 3) 125 | 126 | def test_refcount(self): 127 | obj = "my very large and random string with numbers 1234" 128 | 129 | rc = sys.getrefcount(obj) 130 | 131 | # Test nest.front. This doesn't involve returning nests 132 | # from C++ to Python. 133 | nest.front((None, obj)) 134 | self.assertEqual(rc, sys.getrefcount(obj)) 135 | 136 | nest.front(obj) 137 | self.assertEqual(rc, sys.getrefcount(obj)) 138 | 139 | nest.front((obj,)) 140 | self.assertEqual(rc, sys.getrefcount(obj)) 141 | 142 | nest.front((obj, obj, [obj, {"obj": obj}, obj])) 143 | self.assertEqual(rc, sys.getrefcount(obj)) 144 | 145 | # Test returning nests of Nones. 146 | nest.map(lambda x: None, (obj, obj, [obj, {"obj": obj}, obj])) 147 | self.assertEqual(rc, sys.getrefcount(obj)) 148 | 149 | # Test returning actual nests. 150 | nest.map(lambda s: s, obj) 151 | self.assertEqual(rc, sys.getrefcount(obj)) 152 | 153 | nest.map(lambda x: x, {"obj": obj}) 154 | self.assertEqual(rc, sys.getrefcount(obj)) 155 | 156 | nest.map(lambda x: x, (obj,)) 157 | self.assertEqual(rc, sys.getrefcount(obj)) 158 | 159 | nest.map(lambda s: s, (obj, obj)) 160 | nest.map(lambda s: s, (obj, obj)) 161 | self.assertEqual(rc, sys.getrefcount(obj)) 162 | 163 | n = nest.map(lambda s: s, (obj,)) 164 | self.assertEqual(rc + 1, sys.getrefcount(obj)) 165 | del n 166 | self.assertEqual(rc, sys.getrefcount(obj)) 167 | 168 | 169 | if __name__ == "__main__": 170 | unittest.main() 171 | -------------------------------------------------------------------------------- /nest/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # 16 | # CXX=c++ python3 setup.py build develop 17 | # or 18 | # CXX=c++ pip install . -vv 19 | # 20 | 21 | import sys 22 | 23 | import setuptools 24 | import setuptools.command.build_ext 25 | 26 | 27 | class get_pybind_include(object): 28 | """Helper class to determine the pybind11 include path 29 | 30 | The purpose of this class is to postpone importing pybind11 31 | until it is actually installed, so that the ``get_include()`` 32 | method can be invoked. """ 33 | 34 | def __init__(self, user=False): 35 | self.user = user 36 | 37 | def __str__(self): 38 | import pybind11 39 | 40 | return pybind11.get_include(self.user) 41 | 42 | 43 | ext_modules = [ 44 | setuptools.Extension( 45 | "nest", 46 | ["nest/nest_pybind.cc"], 47 | include_dirs=[ 48 | # Path to pybind11 headers 49 | get_pybind_include(), 50 | get_pybind_include(user=True), 51 | ], 52 | depends=["nest/nest.h", "nest/nest_pybind.h"], 53 | language="c++", 54 | extra_compile_args=["-std=c++17"], 55 | ) 56 | ] 57 | 58 | 59 | class BuildExt(setuptools.command.build_ext.build_ext): 60 | """A custom build extension for adding compiler-specific options.""" 61 | 62 | c_opts = {"msvc": ["/EHsc"], "unix": []} 63 | 64 | if sys.platform == "darwin": 65 | c_opts["unix"] += ["-stdlib=libc++", "-mmacosx-version-min=10.14"] 66 | 67 | def build_extensions(self): 68 | ct = self.compiler.compiler_type 69 | opts = self.c_opts.get(ct, []) 70 | if ct == "unix": 71 | opts.append('-DVERSION_INFO="%s"' % self.distribution.get_version()) 72 | opts.append("-std=c++17") 73 | 
opts.append("-fvisibility=hidden") 74 | elif ct == "msvc": 75 | opts.append('/DVERSION_INFO=\\"%s\\"' % self.distribution.get_version()) 76 | for ext in self.extensions: 77 | ext.extra_compile_args += opts 78 | if sys.platform == "darwin": 79 | ext.extra_link_args = ["-stdlib=libc++"] 80 | 81 | super().build_extensions() 82 | 83 | 84 | setuptools.setup( 85 | name="nest", 86 | version="0.0.3", 87 | author="TorchBeast team", 88 | ext_modules=ext_modules, 89 | headers=["nest/nest.h", "nest/nest_pybind.h"], 90 | cmdclass={"build_ext": BuildExt}, 91 | install_requires=["pybind11>=2.3"], 92 | setup_requires=["pybind11>=2.3"], 93 | ) 94 | -------------------------------------------------------------------------------- /plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/plot.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 88 3 | target-version = ['py37'] 4 | include = '\.pyi?$' 5 | exclude = ''' 6 | /( 7 | \.eggs 8 | | \.git 9 | )/ 10 | ''' -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pillow 2 | tabulate 3 | tqdm 4 | scipy 5 | matplotlib 6 | numpy 7 | gym[atari]>=0.14.0 # Installs gym and atari. Needs to happen at once. 8 | gitpython>=2.1 # For logging metadata. 
9 | ## Wrappers 10 | opencv-python # for atari 11 | ## dev packages 12 | flake8 13 | black 14 | pre-commit 15 | ## saliency 16 | ffmpeg 17 | -------------------------------------------------------------------------------- /results/50_games_actions.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/50_games_actions.pkl -------------------------------------------------------------------------------- /results/figures/computational_graph.afdesign: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/computational_graph.afdesign -------------------------------------------------------------------------------- /results/figures/fig_action_distributions_aggregated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_action_distributions_aggregated.png -------------------------------------------------------------------------------- /results/figures/fig_action_distributions_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_action_distributions_all.png -------------------------------------------------------------------------------- /results/figures/fig_comparison_with_paper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_comparison_with_paper.png -------------------------------------------------------------------------------- 
/results/figures/fig_computational_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_computational_graph.png -------------------------------------------------------------------------------- /results/figures/fig_detraining.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_detraining.png -------------------------------------------------------------------------------- /results/figures/fig_mean_episode_return.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_mean_episode_return.png -------------------------------------------------------------------------------- /results/figures/fig_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_model.png -------------------------------------------------------------------------------- /results/figures/fig_mu_sigma.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_mu_sigma.png -------------------------------------------------------------------------------- /results/figures/fig_multi_multipop_default.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_multi_multipop_default.png 
-------------------------------------------------------------------------------- /results/figures/fig_multi_multipop_optimal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_multi_multipop_optimal.png -------------------------------------------------------------------------------- /results/figures/fig_pre_and_detraining.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_pre_and_detraining.png -------------------------------------------------------------------------------- /results/figures/fig_saliency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_saliency.png -------------------------------------------------------------------------------- /results/figures/fig_saliency_pong.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_saliency_pong.png -------------------------------------------------------------------------------- /results/figures/fig_single_comp_carnival_default.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_single_comp_carnival_default.png -------------------------------------------------------------------------------- /results/figures/fig_single_comp_carnival_optimal.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_single_comp_carnival_optimal.png -------------------------------------------------------------------------------- /results/figures/fig_single_multi_default.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_single_multi_default.png -------------------------------------------------------------------------------- /results/figures/fig_single_multi_optimal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_single_multi_optimal.png -------------------------------------------------------------------------------- /results/figures/fig_single_multipop_default.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_single_multipop_default.png -------------------------------------------------------------------------------- /results/figures/fig_single_multipop_optimal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/fig_single_multipop_optimal.png -------------------------------------------------------------------------------- /results/figures/model.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/figures/model.pptx -------------------------------------------------------------------------------- 
/results/paper/paper.data.processed.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/paper/paper.data.processed.pkl -------------------------------------------------------------------------------- /results/report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/report.pdf -------------------------------------------------------------------------------- /results/report_SOURCE.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aluscher/torchbeastpopart/1c710dd4c78d24ed73a5732ad7ba14ce578143f2/results/report_SOURCE.zip -------------------------------------------------------------------------------- /scripts/install_grpc.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | set -e 18 | set -x 19 | 20 | if [ -z ${GRPC_DIR+x} ]; then 21 | GRPC_DIR=$(pwd)/third_party/grpc; 22 | fi 23 | 24 | PREFIX=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} 25 | 26 | NPROCS=$(getconf _NPROCESSORS_ONLN) 27 | 28 | pushd ${GRPC_DIR} 29 | 30 | ## This requires libprotobuf to be installed in the conda env. 31 | ## Otherwise, we could also do this: 32 | # cd ${GRPC_DIR}/third_party/grpc/third_party/protobuf 33 | # ./autogen.sh && ./configure --prefix=${PREFIX} 34 | # make && make install && ldconfig 35 | 36 | # Make make find libprotobuf 37 | export CPATH=${PREFIX}/include:${CPATH} 38 | export LIBRARY_PATH=${PREFIX}/lib:${LIBRARY_PATH} 39 | export LD_LIBRARY_PATH=${PREFIX}/lib:${LD_LIBRARY_PATH} 40 | 41 | make -j ${NPROCS} prefix=${PREFIX} \ 42 | HAS_SYSTEM_PROTOBUF=true HAS_SYSTEM_CARES=false 43 | make prefix=${PREFIX} \ 44 | HAS_SYSTEM_PROTOBUF=true HAS_SYSTEM_CARES=false install 45 | 46 | popd 47 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # 16 | # CXX=c++ python3 setup.py build develop 17 | # or 18 | # CXX=c++ pip install . -vv 19 | # 20 | # Potentially also set TORCHBEAST_LIBS_PREFIX. 
21 | 22 | import os 23 | import subprocess 24 | import sys 25 | import unittest 26 | 27 | import numpy as np 28 | import setuptools 29 | from torch.utils import cpp_extension 30 | 31 | 32 | PREFIX = os.getenv("CONDA_PREFIX") 33 | 34 | if os.getenv("TORCHBEAST_LIBS_PREFIX"): 35 | PREFIX = os.getenv("TORCHBEAST_LIBS_PREFIX") 36 | if not PREFIX: 37 | PREFIX = "/usr/local" 38 | 39 | 40 | extra_compile_args = [] 41 | extra_link_args = [] 42 | 43 | protoc = f"{PREFIX}/bin/protoc" 44 | 45 | grpc_objects = [ 46 | f"{PREFIX}/lib/libgrpc++.a", 47 | f"{PREFIX}/lib/libgrpc.a", 48 | f"{PREFIX}/lib/libgpr.a", 49 | f"{PREFIX}/lib/libaddress_sorting.a", 50 | ] 51 | 52 | include_dirs = cpp_extension.include_paths() + [np.get_include(), f"{PREFIX}/include"] 53 | libraries = [] 54 | 55 | if sys.platform == "darwin": 56 | extra_compile_args += ["-stdlib=libc++", "-mmacosx-version-min=10.14", "-I/usr/local/opt/openssl@1.1/include"] 57 | extra_link_args += ["-stdlib=libc++", "-mmacosx-version-min=10.14", "-L/usr/local/opt/openssl@1.1/lib"] 58 | 59 | # Relevant only when c-cares is not embedded in grpc, e.g. when 60 | # installing grpc via homebrew. 
61 | libraries.append("cares") 62 | libraries.append("ssl") 63 | elif sys.platform == "linux": 64 | libraries.append("z") 65 | 66 | grpc_objects.append(f"{PREFIX}/lib/libprotobuf.a") 67 | 68 | 69 | actorpool = cpp_extension.CppExtension( 70 | name="libtorchbeast.actorpool", 71 | sources=[ 72 | "libtorchbeast/actorpool.cc", 73 | "libtorchbeast/rpcenv.pb.cc", 74 | "libtorchbeast/rpcenv.grpc.pb.cc", 75 | ], 76 | include_dirs=include_dirs, 77 | libraries=libraries, 78 | language="c++", 79 | extra_compile_args=["-std=c++17"] + extra_compile_args, 80 | extra_link_args=extra_link_args, 81 | extra_objects=grpc_objects, 82 | ) 83 | 84 | rpcenv = cpp_extension.CppExtension( 85 | name="libtorchbeast.rpcenv", 86 | sources=[ 87 | "libtorchbeast/rpcenv.cc", 88 | "libtorchbeast/rpcenv.pb.cc", 89 | "libtorchbeast/rpcenv.grpc.pb.cc", 90 | ], 91 | include_dirs=include_dirs, 92 | libraries=libraries, 93 | language="c++", 94 | extra_compile_args=["-std=c++17"] + extra_compile_args, 95 | extra_link_args=extra_link_args, 96 | extra_objects=grpc_objects, 97 | ) 98 | 99 | 100 | def build_pb(): 101 | # Hard-code rpcenv.proto for now. 
102 | source = os.path.join(os.path.dirname(__file__), "libtorchbeast", "rpcenv.proto") 103 | output = source.replace(".proto", ".pb.cc") 104 | 105 | if os.path.exists(output) and ( 106 | os.path.exists(source) and os.path.getmtime(source) < os.path.getmtime(output) 107 | ): 108 | return 109 | 110 | print("calling protoc") 111 | if ( 112 | subprocess.call( 113 | [protoc, "--cpp_out=libtorchbeast", "-Ilibtorchbeast", "rpcenv.proto"] 114 | ) 115 | != 0 116 | ): 117 | sys.exit(-1) 118 | if ( 119 | subprocess.call( 120 | protoc + " --grpc_out=libtorchbeast -Ilibtorchbeast" 121 | " --plugin=protoc-gen-grpc=`which grpc_cpp_plugin`" 122 | " rpcenv.proto", 123 | shell=True, 124 | ) 125 | != 0 126 | ): 127 | sys.exit(-1) 128 | 129 | 130 | def test_suite(): 131 | test_loader = unittest.TestLoader() 132 | test_suite = test_loader.discover("tests", pattern="*_test.py") 133 | return test_suite 134 | 135 | 136 | class build_ext(cpp_extension.BuildExtension): 137 | def run(self): 138 | build_pb() 139 | 140 | cpp_extension.BuildExtension.run(self) 141 | 142 | 143 | setuptools.setup( 144 | name="libtorchbeast", 145 | packages=["libtorchbeast"], 146 | version="0.0.13", 147 | author="TorchBeast team", 148 | ext_modules=[actorpool, rpcenv], 149 | cmdclass={"build_ext": build_ext}, 150 | test_suite="setup.test_suite", 151 | ) 152 | -------------------------------------------------------------------------------- /tests/batching_queue_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for actorpool.BatchingQueue. 15 | Basic functionalities actorpool.BatchingQueue are tested 16 | in libtorchbeast/actorpool_test.cc. 17 | """ 18 | 19 | import threading 20 | import time 21 | import unittest 22 | 23 | import numpy as np 24 | import torch 25 | from libtorchbeast import actorpool 26 | 27 | 28 | class BatchingQueueTest(unittest.TestCase): 29 | def test_bad_construct(self): 30 | with self.assertRaisesRegex(ValueError, "Min batch size must be >= 1"): 31 | actorpool.BatchingQueue( 32 | batch_dim=3, minimum_batch_size=0, maximum_batch_size=1 33 | ) 34 | 35 | with self.assertRaisesRegex( 36 | ValueError, "Max batch size must be >= min batch size" 37 | ): 38 | actorpool.BatchingQueue( 39 | batch_dim=3, minimum_batch_size=1, maximum_batch_size=0 40 | ) 41 | 42 | def test_multiple_close_calls(self): 43 | queue = actorpool.BatchingQueue() 44 | queue.close() 45 | with self.assertRaisesRegex(RuntimeError, "Queue was closed already"): 46 | queue.close() 47 | 48 | def test_check_inputs(self): 49 | queue = actorpool.BatchingQueue(batch_dim=2) 50 | with self.assertRaisesRegex( 51 | ValueError, "Enqueued tensors must have more than batch_dim ==" 52 | ): 53 | queue.enqueue(torch.ones(5)) 54 | with self.assertRaisesRegex( 55 | ValueError, "Cannot enqueue empty vector of tensors" 56 | ): 57 | queue.enqueue([]) 58 | with self.assertRaisesRegex( 59 | actorpool.ClosedBatchingQueue, "Enqueue to closed queue" 60 | ): 61 | queue.close() 62 | queue.enqueue(torch.ones(1, 1, 1)) 63 | 64 | def test_simple_run(self): 65 | queue = 
actorpool.BatchingQueue( 66 | batch_dim=0, minimum_batch_size=1, maximum_batch_size=1 67 | ) 68 | 69 | inputs = torch.zeros(1, 2, 3) 70 | queue.enqueue(inputs) 71 | batch = next(queue) 72 | np.testing.assert_array_equal(batch, inputs) 73 | 74 | def test_batched_run(self, batch_size=2): 75 | queue = actorpool.BatchingQueue( 76 | batch_dim=0, minimum_batch_size=batch_size, maximum_batch_size=batch_size 77 | ) 78 | 79 | inputs = [torch.full((1, 2, 3), i) for i in range(batch_size)] 80 | 81 | def enqueue_target(i): 82 | while queue.size() < i: 83 | # Make sure thread i calls enqueue before thread i + 1. 84 | time.sleep(0.05) 85 | queue.enqueue(inputs[i]) 86 | 87 | enqueue_threads = [] 88 | for i in range(batch_size): 89 | enqueue_threads.append( 90 | threading.Thread( 91 | target=enqueue_target, name=f"enqueue-thread-{i}", args=(i,) 92 | ) 93 | ) 94 | 95 | for t in enqueue_threads: 96 | t.start() 97 | 98 | batch = next(queue) 99 | np.testing.assert_array_equal(batch, torch.cat(inputs)) 100 | 101 | for t in enqueue_threads: 102 | t.join() 103 | 104 | 105 | class BatchingQueueProducerConsumerTest(unittest.TestCase): 106 | def test_many_consumers( 107 | self, enqueue_threads_number=16, repeats=100, dequeue_threads_number=64 108 | ): 109 | queue = actorpool.BatchingQueue(batch_dim=0) 110 | 111 | lock = threading.Lock() 112 | total_batches_consumed = 0 113 | 114 | def enqueue_target(i): 115 | for _ in range(repeats): 116 | queue.enqueue(torch.full((1, 2, 3), i)) 117 | 118 | def dequeue_target(): 119 | nonlocal total_batches_consumed 120 | for batch in queue: 121 | batch_size, *_ = batch.shape 122 | with lock: 123 | total_batches_consumed += batch_size 124 | 125 | enqueue_threads = [] 126 | for i in range(enqueue_threads_number): 127 | enqueue_threads.append( 128 | threading.Thread( 129 | target=enqueue_target, name=f"enqueue-thread-{i}", args=(i,) 130 | ) 131 | ) 132 | 133 | dequeue_threads = [] 134 | for i in range(dequeue_threads_number): 135 | dequeue_threads.append( 136 
| threading.Thread(target=dequeue_target, name=f"dequeue-thread-{i}") 137 | ) 138 | 139 | for t in enqueue_threads + dequeue_threads: 140 | t.start() 141 | 142 | for t in enqueue_threads: 143 | t.join() 144 | 145 | queue.close() 146 | 147 | for t in dequeue_threads: 148 | t.join() 149 | 150 | self.assertEqual(total_batches_consumed, repeats * enqueue_threads_number) 151 | 152 | 153 | if __name__ == "__main__": 154 | unittest.main() 155 | -------------------------------------------------------------------------------- /tests/contiguous_arrays_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Mock environment for the test contiguous_arrays_test.py.""" 15 | 16 | import numpy as np 17 | from libtorchbeast import rpcenv 18 | 19 | 20 | class Env: 21 | def __init__(self): 22 | self.frame = np.arange(3 * 4 * 5) 23 | self.frame = self.frame.reshape(3, 4, 5) 24 | self.frame = self.frame.transpose(2, 1, 0) 25 | assert not self.frame.flags.c_contiguous 26 | 27 | def reset(self): 28 | return self.frame 29 | 30 | def step(self, action): 31 | return self.frame, 0.0, False, {} 32 | 33 | 34 | if __name__ == "__main__": 35 | server_address = "unix:/tmp/contiguous_arrays_test" 36 | server = rpcenv.Server(Env, server_address=server_address) 37 | server.run() 38 | -------------------------------------------------------------------------------- /tests/contiguous_arrays_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
"""Test that non-contiguous arrays are handled properly."""

import subprocess
import threading
import unittest

import numpy as np
import torch  # Was missing: check_inference_inputs uses torch.ones below.
from libtorchbeast import actorpool


class ContiguousArraysTest(unittest.TestCase):
    """Checks that a non-C-contiguous observation survives the actor pipeline."""

    def setUp(self):
        # Start the mock environment server; it serves a transposed (and
        # therefore non-C-contiguous) frame -- see tests/contiguous_arrays_env.py.
        self.server_proc = subprocess.Popen(
            ["python", "tests/contiguous_arrays_env.py"]
        )

        server_address = ["unix:/tmp/contiguous_arrays_test"]
        self.learner_queue = actorpool.BatchingQueue(
            batch_dim=1, minimum_batch_size=1, maximum_batch_size=10, check_inputs=True
        )
        self.inference_batcher = actorpool.DynamicBatcher(
            batch_dim=1,
            minimum_batch_size=1,
            maximum_batch_size=10,
            timeout_ms=100,
            check_outputs=True,
        )
        actor = actorpool.ActorPool(
            unroll_length=1,
            learner_queue=self.learner_queue,
            inference_batcher=self.inference_batcher,
            env_server_addresses=server_address,
            initial_agent_state=(),
        )

        def run():
            actor.run()

        self.actor_thread = threading.Thread(target=run)
        self.actor_thread.start()

        # Same construction as the frame produced by the mock environment.
        self.target = np.arange(3 * 4 * 5)
        self.target = self.target.reshape(3, 4, 5)
        self.target = self.target.transpose(2, 1, 0)

    def check_inference_inputs(self):
        batch = next(self.inference_batcher)
        batched_env_outputs, _ = batch.get_inputs()
        frame, *_ = batched_env_outputs
        # Leading (1, 1) are the time and batch dimensions added by the actor.
        self.assertTrue(np.array_equal(frame.shape, (1, 1, 5, 4, 3)))
        frame = frame.reshape(5, 4, 3)
        self.assertTrue(np.array_equal(frame, self.target))
        # Set an arbitrary output.
        batch.set_outputs(((torch.ones(1, 1),), ()))

    def test_contiguous_arrays(self):
        self.check_inference_inputs()
        # Stop actor thread.
72 | self.inference_batcher.close() 73 | self.learner_queue.close() 74 | self.actor_thread.join() 75 | 76 | def tearDown(self): 77 | self.server_proc.terminate() 78 | -------------------------------------------------------------------------------- /tests/core_agent_state_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Mock environment for the test core_agent_state_test.py.""" 15 | 16 | import numpy as np 17 | 18 | from libtorchbeast import rpcenv 19 | 20 | 21 | class Env: 22 | def __init__(self): 23 | self.frame = np.zeros((1, 1)) 24 | self.count = 0 25 | self.done_after = 5 26 | 27 | def reset(self): 28 | self.frame = np.zeros((1, 1)) 29 | return self.frame 30 | 31 | def step(self, action): 32 | self.frame += 1 33 | done = self.frame.item() == self.done_after 34 | return self.frame, 0.0, done, {} 35 | 36 | 37 | if __name__ == "__main__": 38 | server_address = "unix:/tmp/core_agent_state_test" 39 | server = rpcenv.Server(Env, server_address=server_address) 40 | server.run() 41 | -------------------------------------------------------------------------------- /tests/core_agent_state_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test that the core state is handled correctly by the batching mechanism."""

import unittest
import threading
import subprocess

import torch
from torch import nn

from libtorchbeast import actorpool


class Net(nn.Module):
    """Minimal agent whose core state counts forward() calls since the last done."""

    def __init__(self):
        super(Net, self).__init__()

    def initial_state(self):
        return torch.zeros(1, 1)

    def forward(self, inputs, core_state):
        x = inputs["frame"]
        notdone = (~inputs["done"]).float()
        T, B, *_ = x.shape

        for nd in notdone.unbind():
            # Bind the reshaped tensor. The original code discarded the result
            # of view() (views are not in-place), leaving nd with shape (B,);
            # broadcasting made both forms numerically identical, but (1, B)
            # is the intended shape.
            nd = nd.view(1, -1)
            # Zero the state where an episode ended, then count this step.
            core_state = nd * core_state
            core_state = core_state + 1
        # Arbitrarily return action 1.
        action = torch.ones((T, B), dtype=torch.int32)
        return (action,), core_state


class CoreAgentStateTest(unittest.TestCase):
    def setUp(self):
        self.server_proc = subprocess.Popen(["python", "tests/core_agent_state_env.py"])

        self.B = 2
        self.T = 3
        self.model = Net()
        server_address = ["unix:/tmp/core_agent_state_test"]
        self.learner_queue = actorpool.BatchingQueue(
            batch_dim=1,
            minimum_batch_size=self.B,
            maximum_batch_size=self.B,
            check_inputs=True,
        )
        self.inference_batcher = actorpool.DynamicBatcher(
            batch_dim=1,
            minimum_batch_size=1,
            maximum_batch_size=1,
            timeout_ms=100,
            check_outputs=True,
        )
        self.actor = actorpool.ActorPool(
            unroll_length=self.T,
            learner_queue=self.learner_queue,
            inference_batcher=self.inference_batcher,
            env_server_addresses=server_address,
            initial_agent_state=self.model.initial_state(),
        )

    def inference(self):
        for batch in self.inference_batcher:
            batched_env_outputs, agent_state = batch.get_inputs()
            frame, _, done, *_ = batched_env_outputs
            # Check that when done is set we reset the environment.
            # Since we only have one actor producing experience we will always
            # have batch_size == 1, hence we can safely use item().
            if done.item():
                self.assertEqual(frame.item(), 0.0)
            outputs = self.model(dict(frame=frame, done=done), agent_state)
            batch.set_outputs(outputs)

    def learn(self):
        for i, tensors in enumerate(self.learner_queue):
            batch, initial_agent_state = tensors
            env_outputs, actor_outputs = batch
            frame, _, done, *_ = env_outputs
            # Make sure the last env_outputs of a rollout equals the first of the
            # following one.
            # This is guaranteed to be true if there is only one actor filling up
            # the learner queue.
97 | self.assertEqual(frame[self.T][0].item(), frame[0][1].item()) 98 | self.assertEqual(done[self.T][0].item(), done[0][1].item()) 99 | 100 | # Make sure the initial state equals the value of the frame at the beginning 101 | # of the rollout. This has to be the case in our test since: 102 | # - every call to forward increments the core state by one. 103 | # - every call to step increments the value in the frame by one (modulo 5). 104 | env_done_after = 5 # Matches self.done_after in core_agent_state_env.py. 105 | self.assertEqual( 106 | frame[0][0].item(), initial_agent_state[0][0].item() % env_done_after 107 | ) 108 | self.assertEqual( 109 | frame[0][1].item(), initial_agent_state[0][1].item() % env_done_after 110 | ) 111 | 112 | if i >= 10: 113 | # Stop execution. 114 | self.learner_queue.close() 115 | self.inference_batcher.close() 116 | 117 | def test_core_agent_state(self): 118 | def run(): 119 | self.actor.run() 120 | 121 | threads = [ 122 | threading.Thread(target=self.inference), 123 | threading.Thread(target=run), 124 | ] 125 | 126 | # Start actor and inference thread. 127 | for thread in threads: 128 | thread.start() 129 | 130 | self.learn() 131 | 132 | for thread in threads: 133 | thread.join() 134 | 135 | def tearDown(self): 136 | self.server_proc.terminate() 137 | self.server_proc.wait() 138 | 139 | 140 | if __name__ == "__main__": 141 | unittest.main() 142 | -------------------------------------------------------------------------------- /tests/dynamic_batcher_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for actorpool.DynamicBatcher.""" 15 | 16 | import threading 17 | import time 18 | import unittest 19 | 20 | import numpy as np 21 | import torch 22 | from libtorchbeast import actorpool 23 | 24 | 25 | _BROKEN_PROMISE_MESSAGE = ( 26 | "The associated promise has been destructed prior" 27 | " to the associated state becoming ready." 28 | ) 29 | 30 | 31 | class DynamicBatcherTest(unittest.TestCase): 32 | def test_simple_run(self): 33 | batcher = actorpool.DynamicBatcher( 34 | batch_dim=0, minimum_batch_size=1, maximum_batch_size=1 35 | ) 36 | 37 | inputs = torch.zeros(1, 2, 3) 38 | outputs = torch.ones(1, 42, 3) 39 | 40 | def target(): 41 | np.testing.assert_array_equal(batcher.compute(inputs), outputs) 42 | 43 | t = threading.Thread(target=target, name="compute-thread") 44 | t.start() 45 | 46 | batch = next(batcher) 47 | np.testing.assert_array_equal(batch.get_inputs(), inputs) 48 | batch.set_outputs(outputs) 49 | 50 | t.join() 51 | 52 | def test_timeout(self): 53 | timeout_ms = 300 54 | batcher = actorpool.DynamicBatcher( 55 | batch_dim=0, 56 | minimum_batch_size=5, 57 | maximum_batch_size=5, 58 | timeout_ms=timeout_ms, 59 | ) 60 | 61 | inputs = torch.zeros(1, 2, 3) 62 | outputs = torch.ones(1, 42, 3) 63 | 64 | def compute_target(): 65 | batcher.compute(inputs) 66 | 67 | compute_thread = threading.Thread(target=compute_target, name="compute-thread") 68 | compute_thread.start() 69 | 70 | start_waiting_time = time.time() 71 | # Wait until approximately timeout_ms. 
72 | batch = next(batcher) 73 | waiting_time_ms = (time.time() - start_waiting_time) * 1000 74 | # Timeout has expired and the batch of size 1 (< minimum_batch_size) 75 | # has been consumed. 76 | batch.set_outputs(outputs) 77 | 78 | compute_thread.join() 79 | 80 | self.assertTrue(timeout_ms <= waiting_time_ms <= timeout_ms + timeout_ms / 10) 81 | 82 | def test_batched_run(self, batch_size=10): 83 | batcher = actorpool.DynamicBatcher( 84 | batch_dim=0, minimum_batch_size=batch_size, maximum_batch_size=batch_size 85 | ) 86 | 87 | inputs = [torch.full((1, 2, 3), i) for i in range(batch_size)] 88 | outputs = torch.ones(batch_size, 42, 3) 89 | 90 | def target(i): 91 | while batcher.size() < i: 92 | # Make sure thread i calls compute before thread i + 1. 93 | time.sleep(0.05) 94 | 95 | np.testing.assert_array_equal( 96 | batcher.compute(inputs[i]), outputs[i : i + 1] 97 | ) 98 | 99 | threads = [] 100 | for i in range(batch_size): 101 | threads.append( 102 | threading.Thread(target=target, name=f"compute-thread-{i}", args=(i,)) 103 | ) 104 | 105 | for t in threads: 106 | t.start() 107 | 108 | batch = next(batcher) 109 | 110 | batched_inputs = batch.get_inputs() 111 | np.testing.assert_array_equal(batched_inputs, torch.cat(inputs)) 112 | batch.set_outputs(outputs) 113 | 114 | for t in threads: 115 | t.join() 116 | 117 | def test_dropped_batch(self): 118 | batcher = actorpool.DynamicBatcher( 119 | batch_dim=0, minimum_batch_size=1, maximum_batch_size=1 120 | ) 121 | 122 | inputs = torch.zeros(1, 2, 3) 123 | 124 | def target(): 125 | with self.assertRaisesRegex(actorpool.AsyncError, _BROKEN_PROMISE_MESSAGE): 126 | batcher.compute(inputs) 127 | 128 | t = threading.Thread(target=target, name="compute-thread") 129 | t.start() 130 | 131 | next(batcher) # Retrieves but doesn't keep the batch object. 
132 | t.join() 133 | 134 | def test_check_outputs1(self): 135 | batcher = actorpool.DynamicBatcher( 136 | batch_dim=2, minimum_batch_size=1, maximum_batch_size=1 137 | ) 138 | 139 | inputs = torch.zeros(1, 2, 3) 140 | 141 | def target(): 142 | batcher.compute(inputs) 143 | 144 | t = threading.Thread(target=target, name="compute-thread") 145 | t.start() 146 | 147 | batch = next(batcher) 148 | 149 | with self.assertRaisesRegex(ValueError, "output shape must have at least"): 150 | outputs = torch.ones(1) 151 | batch.set_outputs(outputs) 152 | 153 | # Set correct outputs so the thread can join. 154 | batch.set_outputs(torch.ones(1, 1, 1)) 155 | t.join() 156 | 157 | def test_check_outputs2(self): 158 | batcher = actorpool.DynamicBatcher( 159 | batch_dim=2, minimum_batch_size=1, maximum_batch_size=1 160 | ) 161 | 162 | inputs = torch.zeros(1, 2, 3) 163 | 164 | def target(): 165 | batcher.compute(inputs) 166 | 167 | t = threading.Thread(target=target, name="compute-thread") 168 | t.start() 169 | 170 | batch = next(batcher) 171 | 172 | with self.assertRaisesRegex( 173 | ValueError, 174 | "Output shape must have the same batch dimension as the input batch size.", 175 | ): 176 | # Dimenstion two of the outputs is != from the size of the batch (3 != 1). 177 | batch.set_outputs(torch.ones(1, 42, 3)) 178 | 179 | # Set correct outputs so the thread can join. 
180 | batch.set_outputs(torch.ones(1, 1, 1)) 181 | t.join() 182 | 183 | def test_multiple_set_outputs_calls(self): 184 | batcher = actorpool.DynamicBatcher( 185 | batch_dim=0, minimum_batch_size=1, maximum_batch_size=1 186 | ) 187 | 188 | inputs = torch.zeros(1, 2, 3) 189 | outputs = torch.ones(1, 42, 3) 190 | 191 | def target(): 192 | batcher.compute(inputs) 193 | 194 | t = threading.Thread(target=target, name="compute-thread") 195 | t.start() 196 | 197 | batch = next(batcher) 198 | batch.set_outputs(outputs) 199 | with self.assertRaisesRegex(RuntimeError, "set_outputs called twice"): 200 | batch.set_outputs(outputs) 201 | 202 | t.join() 203 | 204 | 205 | class DynamicBatcherProducerConsumerTest(unittest.TestCase): 206 | def test_many_consumers( 207 | self, 208 | minimum_batch_size=1, 209 | compute_thread_number=64, 210 | repeats=100, 211 | consume_thread_number=16, 212 | ): 213 | batcher = actorpool.DynamicBatcher( 214 | batch_dim=0, minimum_batch_size=minimum_batch_size 215 | ) 216 | 217 | lock = threading.Lock() 218 | total_batches_consumed = 0 219 | 220 | def compute_thread_target(i): 221 | for _ in range(repeats): 222 | inputs = torch.full((1, 2, 3), i) 223 | batcher.compute(inputs) 224 | 225 | def consume_thread_target(): 226 | nonlocal total_batches_consumed 227 | for batch in batcher: 228 | inputs = batch.get_inputs() 229 | batch_size, *_ = inputs.shape 230 | batch.set_outputs(torch.ones_like(inputs)) 231 | with lock: 232 | total_batches_consumed += batch_size 233 | 234 | compute_threads = [] 235 | for i in range(compute_thread_number): 236 | compute_threads.append( 237 | threading.Thread( 238 | target=compute_thread_target, name=f"compute-thread-{i}", args=(i,) 239 | ) 240 | ) 241 | 242 | consume_threads = [] 243 | for i in range(consume_thread_number): 244 | consume_threads.append( 245 | threading.Thread( 246 | target=consume_thread_target, name=f"consume-thread-{i}" 247 | ) 248 | ) 249 | 250 | for t in compute_threads + consume_threads: 251 | t.start() 
252 | 253 | for t in compute_threads: 254 | t.join() 255 | 256 | # Stop iteration in all consume_threads. 257 | batcher.close() 258 | 259 | for t in consume_threads: 260 | t.join() 261 | 262 | self.assertEqual(total_batches_consumed, compute_thread_number * repeats) 263 | 264 | 265 | if __name__ == "__main__": 266 | unittest.main() 267 | -------------------------------------------------------------------------------- /tests/inference_speed_profiling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import logging 16 | import os 17 | import sys 18 | import threading 19 | import time 20 | import timeit 21 | 22 | import torch 23 | 24 | sys.path.append("..") 25 | import experiment # noqa: E402 26 | 27 | logging.basicConfig( 28 | format=( 29 | "[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s" 30 | ), 31 | level=0, 32 | ) 33 | 34 | batch_size = int(sys.argv[1]) if len(sys.argv) > 1 else 4 35 | num_inference_threads = int(sys.argv[2]) if len(sys.argv) > 2 else 2 36 | 37 | 38 | def main(): 39 | filename = "inference_speed_test.json" 40 | with torch.autograd.profiler.profile() as prof: 41 | run() 42 | logging.info("Collecting trace and writing to '%s.gz'", filename) 43 | prof.export_chrome_trace(filename) 44 | os.system("gzip %s" % filename) 45 | 46 | 47 | def run(): 48 | size = (4, 84, 84) 49 | num_actions = 6 50 | 51 | if torch.cuda.is_available(): 52 | device = torch.device("cuda:0") 53 | else: 54 | device = torch.device("cpu") 55 | 56 | model = experiment.Net(observation_size=size, num_actions=num_actions) 57 | model = model.to(device=device) 58 | 59 | should_stop = threading.Event() 60 | 61 | step = 0 62 | 63 | def stream_inference(frame): 64 | nonlocal step 65 | 66 | T, B, *_ = frame.shape 67 | stream = torch.cuda.Stream() 68 | 69 | with torch.no_grad(): 70 | with torch.cuda.stream(stream): 71 | while not should_stop.is_set(): 72 | input = frame.pin_memory() 73 | input = frame.to(device, non_blocking=True) 74 | outputs = model(input) 75 | outputs = [t.cpu() for t in outputs] 76 | stream.synchronize() 77 | step += B 78 | 79 | def inference(frame, lock=threading.Lock()): # noqa: B008 80 | nonlocal step 81 | 82 | T, B, *_ = frame.shape 83 | with torch.no_grad(): 84 | while not should_stop.is_set(): 85 | input = frame.to(device) 86 | with lock: 87 | outputs = model(input) 88 | step += B 89 | outputs = [t.cpu() for t in outputs] 90 | 91 | def direct_inference(frame): 92 | nonlocal step 93 | frame = frame.to(device) 94 | 95 | T, 
B, *_ = frame.shape 96 | with torch.no_grad(): 97 | while not should_stop.is_set(): 98 | model(frame) 99 | step += B 100 | 101 | frame = 255 * torch.rand((1, batch_size) + size) 102 | 103 | work_threads = [ 104 | threading.Thread(target=stream_inference, args=(frame,)) 105 | for _ in range(num_inference_threads) 106 | ] 107 | for thread in work_threads: 108 | thread.start() 109 | 110 | try: 111 | while step < 10000: 112 | start_time = timeit.default_timer() 113 | start_step = step 114 | time.sleep(3) 115 | end_step = step 116 | 117 | logging.info( 118 | "Step %i @ %.1f SPS.", 119 | end_step, 120 | (end_step - start_step) / (timeit.default_timer() - start_time), 121 | ) 122 | except KeyboardInterrupt: 123 | pass 124 | 125 | should_stop.set() 126 | for thread in work_threads: 127 | thread.join() 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /tests/lint_changed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # This shell script lints only the things that changed in the most recent change. 8 | # It also ignores deleted files, so that black and flake8 don't explode. 9 | 10 | set -e 11 | 12 | CMD="flake8" 13 | CHANGED_FILES="$(git diff --diff-filter=d --name-only master... | grep '\.py$' | grep -v "torchbeast/atari_wrappers.py" | tr '\n' ' ')" 14 | while getopts bi opt; do 15 | case $opt in 16 | b) 17 | CMD="black" 18 | esac 19 | 20 | done 21 | 22 | if [ "$CHANGED_FILES" != "" ] 23 | then 24 | if [[ "$CMD" == "black" ]] 25 | then 26 | command -v black >/dev/null || \ 27 | ( echo "Please install black." && false ) 28 | # Only output if something needs to change. 
29 | black --check $CHANGED_FILES 30 | else 31 | flake8 --version | grep '^3\.[6-9]\.' >/dev/null || \ 32 | ( echo "Please install flake8 >=3.6.0." && false ) 33 | 34 | # Soft complaint on too-long-lines. 35 | flake8 --select=E501 --show-source $CHANGED_FILES 36 | # Hard complaint on really long lines. 37 | exec flake8 --max-line-length=127 --show-source $CHANGED_FILES 38 | fi 39 | fi 40 | -------------------------------------------------------------------------------- /tests/polybeast_inference_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for polybeast inference implementation.""" 15 | 16 | import unittest 17 | import warnings 18 | from unittest import mock 19 | 20 | import torch 21 | from torchbeast import polybeast 22 | 23 | 24 | class InferenceTest(unittest.TestCase): 25 | def setUp(self): 26 | self.unroll_length = 1 # Inference called for every step. 27 | self.batch_size = 4 # Arbitrary. 28 | self.frame_dimension = 84 # Has to match what expected by the model. 29 | self.num_actions = 6 # Specific to each environment. 30 | self.num_channels = 4 # Has to match with the first conv layer of the net. 31 | self.core_output_size = 256 # Has to match what expected by the model. 32 | self.num_lstm_layers = 1 # As in the model. 

        self.frame = torch.ones(
            self.unroll_length,
            self.batch_size,
            self.num_channels,
            self.frame_dimension,
            self.frame_dimension,
        )
        self.rewards = torch.ones(self.unroll_length, self.batch_size)
        self.done = torch.zeros(self.unroll_length, self.batch_size, dtype=torch.uint8)
        self.episode_return = torch.ones(
            self.unroll_length, self.batch_size
        )  # Not used in the current implementation of inference.
        self.episode_step = torch.ones(
            self.unroll_length, self.batch_size
        )  # Not used in the current implementation of inference.

        self.mock_batch = mock.Mock()
        # Set the mock inference batcher to be iterable and return a mock_batch.
        self.mock_inference_batcher = mock.MagicMock()
        self.mock_inference_batcher.__iter__.return_value = iter([self.mock_batch])

    def _test_inference(self, use_lstm, device):
        # Shared driver: run polybeast.inference once against the mocked
        # batcher and verify the shapes/devices of what it hands back.
        model = polybeast.Net(num_actions=self.num_actions, use_lstm=use_lstm)
        model.to(device)
        agent_state = model.initial_state()

        inputs = (
            (
                self.frame,
                self.rewards,
                self.done,
                self.episode_return,
                # NOTE(review): episode_return is passed twice; presumably
                # self.episode_step was intended here. Both are unused by
                # inference, so this looks harmless -- confirm.
                self.episode_return,
            ),
            agent_state,
        )
        # Set the behaviour of the methods of the mock batch.
        self.mock_batch.get_inputs = mock.Mock(return_value=inputs)
        self.mock_batch.set_outputs = mock.Mock()

        # Preparing the mock flags. Could do with just a dict but using
        # a Mock object for consistency.
        mock_flags = mock.Mock()
        mock_flags.actor_device = device
        mock_flags.use_lstm = use_lstm

        polybeast.inference(mock_flags, self.mock_inference_batcher, model)

        # Assert the batch is used only once.
        self.mock_batch.get_inputs.assert_called_once()
        self.mock_batch.set_outputs.assert_called_once()
        # Check that set_outputs has been called with parameters with the expected shape.
        batch_args, batch_kwargs = self.mock_batch.set_outputs.call_args
        self.assertEqual(batch_kwargs, {})
        model_outputs, *other_args = batch_args
        self.assertEqual(other_args, [])

        (action, policy_logits, baseline), core_state = model_outputs
        self.assertSequenceEqual(action.shape, (self.unroll_length, self.batch_size))
        self.assertSequenceEqual(
            policy_logits.shape, (self.unroll_length, self.batch_size, self.num_actions)
        )
        self.assertSequenceEqual(baseline.shape, (self.unroll_length, self.batch_size))

        # Outputs must land on the CPU regardless of the actor device.
        for tensor in (action, policy_logits, baseline) + core_state:
            self.assertEqual(tensor.device, torch.device("cpu"))

        self.assertEqual(len(core_state), 2 if use_lstm else 0)
        for core_state_element in core_state:
            self.assertSequenceEqual(
                core_state_element.shape,
                (self.num_lstm_layers, self.batch_size, self.core_output_size),
            )

    def test_inference_cpu_no_lstm(self):
        self._test_inference(use_lstm=False, device=torch.device("cpu"))

    def test_inference_cuda_no_lstm(self):
        if not torch.cuda.is_available():
            warnings.warn("Not testing cuda as it's not available")
            return
        self._test_inference(use_lstm=False, device=torch.device("cuda"))

    def test_inference_cpu_with_lstm(self):
        self._test_inference(use_lstm=True, device=torch.device("cpu"))

    def test_inference_cuda_with_lstm(self):
        if not torch.cuda.is_available():
            warnings.warn("Not testing cuda as it's not available")
            return
        self._test_inference(use_lstm=True, device=torch.device("cuda"))


if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/tests/polybeast_learn_function_test.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for polybeast learn function implementation."""

import copy
import unittest
from unittest import mock

import numpy as np
import torch
from torchbeast import polybeast


def _state_dict_to_numpy(state_dict):
    # Convert every tensor in a state dict to numpy so whole model states
    # can be compared with np.testing helpers.
    return {key: value.numpy() for key, value in state_dict.items()}


class LearnTest(unittest.TestCase):
    def setUp(self):
        unroll_length = 2  # Arbitrary.
        batch_size = 4  # Arbitrary.
        frame_dimension = 84  # Has to match what expected by the model.
        num_actions = 6  # Specific to each environment.
        num_channels = 4  # Has to match with the first conv layer of the net.

        # The following hyperparameters are arbitrary.
        self.lr = 0.1
        total_steps = 100000

        # Set the random seed manually to get reproducible results.
        torch.manual_seed(0)

        self.model = polybeast.Net(num_actions=num_actions, use_lstm=False)
        self.actor_model = polybeast.Net(num_actions=num_actions, use_lstm=False)
        # Deep-copied initial states let each test reset both models.
        self.initial_model_dict = copy.deepcopy(self.model.state_dict())
        self.initial_actor_model_dict = copy.deepcopy(self.actor_model.state_dict())

        # Plain SGD so the expected update is initial - lr * grad exactly.
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)

        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=total_steps // 10
        )

        self.stats = {}

        # The call to plogger.log will not perform any action.
        plogger = mock.Mock()
        plogger.log = mock.Mock()

        # Mock flags.
        mock_flags = mock.Mock()
        mock_flags.learner_device = torch.device("cpu")
        mock_flags.reward_clipping = "abs_one"  # Default value from cmd.
        mock_flags.discounting = 0.99  # Default value from cmd.
        mock_flags.baseline_cost = 0.5  # Default value from cmd.
        mock_flags.entropy_cost = 0.0006  # Default value from cmd.
        mock_flags.unroll_length = unroll_length
        mock_flags.batch_size = batch_size
        mock_flags.grad_norm_clipping = 40

        # Prepare content for mock_learner_queue.
        frame = torch.ones(
            unroll_length, batch_size, num_channels, frame_dimension, frame_dimension
        )
        rewards = torch.ones(unroll_length, batch_size)
        done = torch.zeros(unroll_length, batch_size, dtype=torch.uint8)
        episode_step = torch.ones(unroll_length, batch_size)
        episode_return = torch.ones(unroll_length, batch_size)

        env_outputs = (frame, rewards, done, episode_step, episode_return)
        actor_outputs = (
            # Actions taken.
            torch.randint(low=0, high=num_actions, size=(unroll_length, batch_size)),
            # Logits.
            torch.randn(unroll_length, batch_size, num_actions),
            # Baseline.
            torch.rand(unroll_length, batch_size),
        )
        initial_agent_state = ()  # No lstm.
        tensors = ((env_outputs, actor_outputs), initial_agent_state)

        # Mock learner_queue: yields exactly one batch, so learn() performs
        # a single optimization step and returns.
        mock_learner_queue = mock.MagicMock()
        mock_learner_queue.__iter__.return_value = iter([tensors])

        self.learn_args = (
            mock_flags,
            mock_learner_queue,
            self.model,
            self.actor_model,
            optimizer,
            scheduler,
            self.stats,
            plogger,
        )

    def test_parameters_copied_to_actor_model(self):
        """Check that the learner model copies the parameters to the actor model."""
        # Reset models.
        self.model.load_state_dict(self.initial_model_dict)
        self.actor_model.load_state_dict(self.initial_actor_model_dict)

        polybeast.learn(*self.learn_args)

        np.testing.assert_equal(
            _state_dict_to_numpy(self.actor_model.state_dict()),
            _state_dict_to_numpy(self.model.state_dict()),
        )

    def test_weights_update(self):
        """Check that trainable parameters get updated after one iteration."""
        # Reset models.
        self.model.load_state_dict(self.initial_model_dict)
        self.actor_model.load_state_dict(self.initial_actor_model_dict)

        polybeast.learn(*self.learn_args)

        model_state_dict = self.model.state_dict(keep_vars=True)
        actor_model_state_dict = self.actor_model.state_dict(keep_vars=True)
        for key, initial_tensor in self.initial_model_dict.items():
            model_tensor = model_state_dict[key]
            actor_model_tensor = actor_model_state_dict[key]
            # Assert that the gradient is not zero for the learner.
            self.assertGreater(torch.norm(model_tensor.grad), 0.0)
            # Assert actor has no gradient.
            # Note that even though actor model tensors have no gradient,
            # they have requires_grad == True. No gradients are ever calculated
            # for these tensors because the inference function in polybeast.py
            # (that performs forward passes with the actor_model) uses torch.no_grad
            # context manager.
            self.assertIsNone(actor_model_tensor.grad)
            # Assert that the weights are updated in the expected way.
            # We manually perform a gradient descent step,
            # and check that they are the same as the calculated ones
            # (ignoring floating point errors).
            expected_tensor = (
                initial_tensor.detach().numpy() - self.lr * model_tensor.grad.numpy()
            )
            np.testing.assert_almost_equal(
                model_tensor.detach().numpy(), expected_tensor
            )
            np.testing.assert_almost_equal(
                actor_model_tensor.detach().numpy(), expected_tensor
            )

    def test_gradients_update(self):
        """Check that gradients get updated after one iteration."""
        # Reset models.
        self.model.load_state_dict(self.initial_model_dict)
        self.actor_model.load_state_dict(self.initial_actor_model_dict)

        # There should be no calculated gradient yet.
        for p in self.model.parameters():
            self.assertIsNone(p.grad)
        for p in self.actor_model.parameters():
            self.assertIsNone(p.grad)

        polybeast.learn(*self.learn_args)

        # Check that every parameter for the learner model has a gradient, and that
        # there is at least some non-zero gradient for each set of parameters.
        for p in self.model.parameters():
            self.assertIsNotNone(p.grad)
            self.assertFalse(torch.equal(p.grad, torch.zeros_like(p.grad)))

        # Check that the actor model has no gradients associated with it.
        for p in self.actor_model.parameters():
            self.assertIsNone(p.grad)

    def test_non_zero_loss(self):
        """Check that the loss is not zero after one iteration."""
        # Reset models.
        self.model.load_state_dict(self.initial_model_dict)
        self.actor_model.load_state_dict(self.initial_actor_model_dict)

        polybeast.learn(*self.learn_args)

        self.assertNotEqual(self.stats["total_loss"], 0.0)
        self.assertNotEqual(self.stats["pg_loss"], 0.0)
        self.assertNotEqual(self.stats["baseline_loss"], 0.0)
        self.assertNotEqual(self.stats["entropy_loss"], 0.0)


if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/tests/polybeast_loss_functions_test.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for polybeast loss functions implementation."""

import unittest

import numpy as np
import torch
from torch.nn import functional as F
from torchbeast import polybeast


def _softmax(logits):
    """Applies softmax non-linearity on inputs."""
    return np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)


def _softmax_grad(logits):
    """Compute the gradient of softmax function."""
    # Jacobian of softmax for a 1-D logits vector: diag(s) - s s^T,
    # built here via broadcasting.
    s = np.expand_dims(_softmax(logits), 0)
    return s.T * (np.eye(s.size) - s)


def assert_allclose(actual, desired):
    # Shared tolerance for all numeric comparisons in this file.
    return np.testing.assert_allclose(actual, desired, rtol=1e-06, atol=1e-05)


class ComputeBaselineLossTest(unittest.TestCase):
    def setUp(self):
        # Floating point constants are randomly generated.
        self.advantages = np.array([1.4, 3.43, 5.2, 0.33])

    def test_compute_baseline_loss(self):
        # Baseline loss is 0.5 * sum of squared advantages.
        ground_truth_value = 0.5 * np.sum(self.advantages ** 2)
        assert_allclose(
            ground_truth_value,
            polybeast.compute_baseline_loss(torch.from_numpy(self.advantages)),
        )

    def test_compute_baseline_loss_grad(self):
        advantages_tensor = torch.from_numpy(self.advantages)
        advantages_tensor.requires_grad_()
        calculated_value = polybeast.compute_baseline_loss(advantages_tensor)
        calculated_value.backward()

        # Manually computed gradients:
        # 0.5 * d(xˆ2)/dx == x
        # hence the expected gradient is the same as self.advantages.
        assert_allclose(advantages_tensor.grad, self.advantages)


class ComputeEntropyLossTest(unittest.TestCase):
    def setUp(self):
        # Floating point constants are randomly generated.
        self.logits = np.array([0.0012, 0.321, 0.523, 0.109, 0.416])

    def test_compute_entropy_loss(self):
        # Calculate entropy with:
        # H(s) = - sum(prob(x) * ln(prob(x)) for each x in s)
        # (the loss is the negative entropy, hence no leading minus below).
        softmax_logits = _softmax(self.logits)
        ground_truth_value = np.sum(softmax_logits * np.log(softmax_logits))
        calculated_value = polybeast.compute_entropy_loss(torch.from_numpy(self.logits))

        assert_allclose(ground_truth_value, calculated_value)

    def test_compute_entropy_loss_grad(self):
        logits_tensor = torch.from_numpy(self.logits)
        logits_tensor.requires_grad_()
        calculated_value = polybeast.compute_entropy_loss(logits_tensor)
        calculated_value.backward()

        # Chain rule through the entropy expression: d/dlogits of
        # sum(p * log p) = (1 + log p) @ dsoftmax/dlogits.
        expected_grad = np.matmul(
            np.ones_like(self.logits),
            np.matmul(
                np.diag(1 + np.log(_softmax(self.logits))), _softmax_grad(self.logits)
            ),
        )

        assert_allclose(logits_tensor.grad, expected_grad)


class ComputePolicyGradientLossTest(unittest.TestCase):
    def setUp(self):
        # Floating point constants are randomly generated.
        # Shapes are (T, B, num_actions) for logits, (T, B) for the rest.
        self.logits = np.array(
            [
                [
                    [0.206, 0.738, 0.125, 0.484, 0.332],
                    [0.168, 0.504, 0.523, 0.496, 0.626],
                    [0.236, 0.186, 0.627, 0.441, 0.533],
                ],
                [
                    [0.015, 0.904, 0.583, 0.651, 0.855],
                    [0.811, 0.292, 0.061, 0.597, 0.590],
                    [0.999, 0.504, 0.464, 0.077, 0.143],
                ],
            ]
        )
        self.actions = np.array([[3, 0, 1], [4, 2, 2]])
        self.advantages = np.array([[1.4, 0.31, 0.75], [2.1, 1.5, 0.03]])

    def test_compute_policy_gradient_loss(self):
        T, B, N = self.logits.shape

        # Calculate the cross entropy loss, with the formula:
        # loss = -sum_over_j(y_j * log(p_j))
        # Where:
        # - `y_j` is whether the action corresponding to index j has been taken or not,
        #   (hence y is a one-hot-array of size == number of actions).
        # - `p_j` is the value of the softmax logit corresponding to the jth action.
        # In our implementation, we also multiply for the advantages.
        labels = F.one_hot(torch.from_numpy(self.actions), num_classes=N).numpy()
        cross_entropy_loss = -labels * np.log(_softmax(self.logits))
        ground_truth_value = np.sum(
            cross_entropy_loss * self.advantages.reshape(T, B, 1)
        )

        calculated_value = polybeast.compute_policy_gradient_loss(
            torch.from_numpy(self.logits),
            torch.from_numpy(self.actions),
            torch.from_numpy(self.advantages),
        )
        assert_allclose(ground_truth_value, calculated_value.item())

    def test_compute_policy_gradient_loss_grad(self):
        T, B, N = self.logits.shape

        logits_tensor = torch.from_numpy(self.logits)
        logits_tensor.requires_grad_()

        calculated_value = polybeast.compute_policy_gradient_loss(
            logits_tensor,
            torch.from_numpy(self.actions),
            torch.from_numpy(self.advantages),
        )

        self.assertSequenceEqual(calculated_value.shape, [])
        calculated_value.backward()

        # The gradient of the cross entropy loss function for the jth logit
        # can be expressed as:
        #     p_j - y_j
        # where:
        # - `p_j` is the value of the softmax logit corresponding to the jth action.
        # - `y_j` is whether the action corresponding to index j has been taken,
        #   (hence y is a one-hot-array of size == number of actions).
        # In our implementation, we also multiply for the advantages.
        softmax = _softmax(self.logits)
        labels = F.one_hot(torch.from_numpy(self.actions), num_classes=N).numpy()
        expected_grad = (softmax - labels) * self.advantages.reshape(T, B, 1)

        assert_allclose(logits_tensor.grad, expected_grad)

    def test_compute_policy_gradient_loss_grad_flow(self):
        # Gradients should flow back to the logits but be blocked from the
        # advantages, which the loss treats as constants.
        logits_tensor = torch.from_numpy(self.logits)
        logits_tensor.requires_grad_()
        advantages_tensor = torch.from_numpy(self.advantages)
        advantages_tensor.requires_grad_()

        loss = polybeast.compute_policy_gradient_loss(
            logits_tensor, torch.from_numpy(self.actions), advantages_tensor
        )
        loss.backward()

        self.assertIsNotNone(logits_tensor.grad)
        self.assertIsNone(advantages_tensor.grad)


if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/tests/polybeast_net_test.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for polybeast Net class implementation."""

import unittest

import torch
from torchbeast import polybeast


class NetTest(unittest.TestCase):
    def setUp(self):
        self.unroll_length = 4  # Arbitrary.
        self.batch_size = 4  # Arbitrary.
        self.frame_dimension = 84  # Has to match what expected by the model.
        self.num_actions = 6  # Specific to each environment.
        self.num_channels = 4  # Has to match with the first conv layer of the net.
        self.core_output_size = 256  # Has to match what expected by the model.
        self.num_lstm_layers = 1  # As in the model.

        # NOTE(review): frame is built (unroll_length, batch_size, ...) while
        # reward/done use (batch_size, unroll_length); both dims are 4 here so
        # the mismatch is invisible -- confirm the intended order.
        self.inputs = dict(
            frame=torch.ones(
                self.unroll_length,
                self.batch_size,
                self.num_channels,
                self.frame_dimension,
                self.frame_dimension,
            ),
            reward=torch.ones(self.batch_size, self.unroll_length),
            done=torch.zeros(self.batch_size, self.unroll_length, dtype=torch.uint8),
        )

    def test_forward_return_signature_no_lstm(self):
        model = polybeast.Net(num_actions=self.num_actions, use_lstm=False)
        core_state = ()

        (action, policy_logits, baseline), core_state = model(self.inputs, core_state)
        self.assertSequenceEqual(action.shape, (self.batch_size, self.unroll_length))
        self.assertSequenceEqual(
            policy_logits.shape, (self.batch_size, self.unroll_length, self.num_actions)
        )
        self.assertSequenceEqual(baseline.shape, (self.batch_size, self.unroll_length))
        self.assertSequenceEqual(core_state, ())

    def test_forward_return_signature_with_lstm(self):
        model = polybeast.Net(num_actions=self.num_actions, use_lstm=True)
        core_state = model.initial_state(self.batch_size)

        (action, policy_logits, baseline), core_state = model(self.inputs, core_state)
        self.assertSequenceEqual(action.shape, (self.batch_size, self.unroll_length))
        self.assertSequenceEqual(
            policy_logits.shape, (self.batch_size, self.unroll_length, self.num_actions)
        )
        self.assertSequenceEqual(baseline.shape, (self.batch_size, self.unroll_length))
        self.assertEqual(len(core_state), 2)
        for core_state_element in core_state:
            self.assertSequenceEqual(
                core_state_element.shape,
                (self.num_lstm_layers, self.batch_size, self.core_output_size),
            )

    def test_initial_state(self):
        # Without an LSTM the initial state is empty; with one it is an
        # (h, c) pair of zeros shaped (layers, batch, core_output_size).
        model_no_lstm = polybeast.Net(num_actions=self.num_actions, use_lstm=False)
        initial_state_no_lstm = model_no_lstm.initial_state(self.batch_size)
        self.assertSequenceEqual(initial_state_no_lstm, ())

        model_with_lstm = polybeast.Net(num_actions=self.num_actions, use_lstm=True)
        initial_state_with_lstm = model_with_lstm.initial_state(self.batch_size)
        self.assertEqual(len(initial_state_with_lstm), 2)
        for core_state_element in initial_state_with_lstm:
            self.assertSequenceEqual(
                core_state_element.shape,
                (self.num_lstm_layers, self.batch_size, self.core_output_size),
            )


if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/tests/vtrace_test.py:
--------------------------------------------------------------------------------
# This file taken from
#     https://github.com/deepmind/scalable_agent/blob/
#     d24bd74bd53d454b7222b7f0bea57a358e4ca33e/vtrace_test.py
# and modified.

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for V-trace.

For details and theory see:

"IMPALA: Scalable Distributed Deep-RL with
Importance Weighted Actor-Learner Architectures"
by Espeholt, Soyer, Munos et al.
"""

import unittest

import numpy as np
import torch
from torchbeast.core import vtrace


def _shaped_arange(*shape):
    """Runs np.arange, converts to float and reshapes."""
    return np.arange(np.prod(shape), dtype=np.float32).reshape(*shape)


def _softmax(logits):
    """Applies softmax non-linearity on inputs."""
    return np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)


def _ground_truth_calculation(
    discounts,
    log_rhos,
    rewards,
    values,
    bootstrap_value,
    clip_rho_threshold,
    clip_pg_rho_threshold,
):
    """Calculates the ground truth for V-trace in Python/Numpy."""
    vs = []
    seq_len = len(discounts)
    rhos = np.exp(log_rhos)
    cs = np.minimum(rhos, 1.0)
    clipped_rhos = rhos
    if clip_rho_threshold:
        clipped_rhos = np.minimum(rhos, clip_rho_threshold)
    clipped_pg_rhos = rhos
    if clip_pg_rho_threshold:
        clipped_pg_rhos = np.minimum(rhos, clip_pg_rho_threshold)

    # This is a very inefficient way to calculate the V-trace ground truth.
    # We calculate it this way because it is close to the mathematical notation
    # of V-trace.
    # v_s = V(x_s)
    #       + \sum^{T-1}_{t=s} \gamma^{t-s}
    #         * \prod_{i=s}^{t-1} c_i
    #         * \rho_t (r_t + \gamma V(x_{t+1}) - V(x_t))
    # Note that when we take the product over c_i, we write `s:t` as the
    # notation of the paper is inclusive of the `t-1`, but Python is exclusive.
    # Also note that np.prod([]) == 1.
    values_t_plus_1 = np.concatenate([values, bootstrap_value[None, :]], axis=0)
    for s in range(seq_len):
        v_s = np.copy(values[s])  # Very important copy.
        for t in range(s, seq_len):
            v_s += (
                np.prod(discounts[s:t], axis=0)
                * np.prod(cs[s:t], axis=0)
                * clipped_rhos[t]
                * (rewards[t] + discounts[t] * values_t_plus_1[t + 1] - values[t])
            )
        vs.append(v_s)
    vs = np.stack(vs, axis=0)
    pg_advantages = clipped_pg_rhos * (
        rewards
        + discounts * np.concatenate([vs[1:], bootstrap_value[None, :]], axis=0)
        - values
    )

    return vtrace.VTraceReturns(vs=vs, pg_advantages=pg_advantages)


def assert_allclose(actual, desired):
    # Shared tolerance for all numeric comparisons in this file.
    return np.testing.assert_allclose(actual, desired, rtol=1e-06, atol=1e-05)


class ActionLogProbsTest(unittest.TestCase):
    def test_action_log_probs(self, batch_size=2):
        seq_len = 7
        num_actions = 3

        policy_logits = _shaped_arange(seq_len, batch_size, num_actions) + 10
        actions = np.random.randint(
            0, num_actions, size=(seq_len, batch_size), dtype=np.int64
        )

        action_log_probs_tensor = vtrace.action_log_probs(
            torch.from_numpy(policy_logits), torch.from_numpy(actions)
        )

        # Ground Truth
        # Using broadcasting to create a mask that indexes action logits
        action_index_mask = actions[..., None] == np.arange(num_actions)

        def index_with_mask(array, mask):
            return array[mask].reshape(*array.shape[:-1])

        # Note: Normally log(softmax) is not a good idea because it's not
        # numerically stable. However, in this test we have well-behaved values.
        ground_truth_v = index_with_mask(
            np.log(_softmax(policy_logits)), action_index_mask
        )

        assert_allclose(ground_truth_v, action_log_probs_tensor)

    def test_action_log_probs_batch_1(self):
        self.test_action_log_probs(1)


class VtraceTest(unittest.TestCase):
    def test_vtrace(self, batch_size=5):
        """Tests V-trace against ground truth data calculated in python."""
        seq_len = 5

        # Create log_rhos such that rho will span from near-zero to above the
        # clipping thresholds. In particular, calculate log_rhos in [-2.5, 2.5),
        # so that rho is in approx [0.08, 12.2).
        log_rhos = _shaped_arange(seq_len, batch_size) / (batch_size * seq_len)
        log_rhos = 5 * (log_rhos - 0.5)  # [0.0, 1.0) -> [-2.5, 2.5).
        values = {
            "log_rhos": log_rhos,
            # T, B where B_i: [0.9 / (i+1)] * T
            "discounts": np.array(
                [[0.9 / (b + 1) for b in range(batch_size)] for _ in range(seq_len)],
                dtype=np.float32,
            ),
            "rewards": _shaped_arange(seq_len, batch_size),
            "values": _shaped_arange(seq_len, batch_size) / batch_size,
            "bootstrap_value": _shaped_arange(batch_size) + 1.0,
            "clip_rho_threshold": 3.7,
            "clip_pg_rho_threshold": 2.2,
        }

        ground_truth = _ground_truth_calculation(**values)

        values = {key: torch.tensor(value) for key, value in values.items()}
        output = vtrace.from_importance_weights(**values)

        for a, b in zip(ground_truth, output):
            assert_allclose(a, b)

    def test_vtrace_batch_1(self):
        self.test_vtrace(1)

    def test_vtrace_from_logits(self, batch_size=2):
        """Tests V-trace calculated from logits."""
        seq_len = 5
        num_actions = 3
        clip_rho_threshold = None  # No clipping.
        clip_pg_rho_threshold = None  # No clipping.

        values = {
            "behavior_policy_logits": _shaped_arange(seq_len, batch_size, num_actions),
            "target_policy_logits": _shaped_arange(seq_len, batch_size, num_actions),
            "actions": np.random.randint(
                0, num_actions - 1, size=(seq_len, batch_size)
            ),
            "discounts": np.array(  # T, B where B_i: [0.9 / (i+1)] * T
                [[0.9 / (b + 1) for b in range(batch_size)] for _ in range(seq_len)],
                dtype=np.float32,
            ),
            "rewards": _shaped_arange(seq_len, batch_size),
            "values": _shaped_arange(seq_len, batch_size) / batch_size,
            "bootstrap_value": _shaped_arange(batch_size) + 1.0,  # B
        }
        values = {k: torch.from_numpy(v) for k, v in values.items()}

        from_logits_output = vtrace.from_logits(
            clip_rho_threshold=clip_rho_threshold,
            clip_pg_rho_threshold=clip_pg_rho_threshold,
            **values,
        )

        target_log_probs = vtrace.action_log_probs(
            values["target_policy_logits"], values["actions"]
        )
        behavior_log_probs = vtrace.action_log_probs(
            values["behavior_policy_logits"], values["actions"]
        )
        log_rhos = target_log_probs - behavior_log_probs

        # Calculate V-trace using the ground truth logits.
        from_iw = vtrace.from_importance_weights(
            log_rhos=log_rhos,
            discounts=values["discounts"],
            rewards=values["rewards"],
            values=values["values"],
            bootstrap_value=values["bootstrap_value"],
            clip_rho_threshold=clip_rho_threshold,
            clip_pg_rho_threshold=clip_pg_rho_threshold,
        )

        assert_allclose(from_iw.vs, from_logits_output.vs)
        assert_allclose(from_iw.pg_advantages, from_logits_output.pg_advantages)
        assert_allclose(
            behavior_log_probs, from_logits_output.behavior_action_log_probs
        )
        assert_allclose(target_log_probs, from_logits_output.target_action_log_probs)
        assert_allclose(log_rhos, from_logits_output.log_rhos)

    def test_vtrace_from_logits_batch_1(self):
        self.test_vtrace_from_logits(1)

    def test_higher_rank_inputs_for_importance_weights(self):
        """Checks support for additional dimensions in inputs."""
        T = 3  # pylint: disable=invalid-name
        B = 2  # pylint: disable=invalid-name
        values = {
            "log_rhos": torch.zeros(T, B, 1),
            "discounts": torch.zeros(T, B, 1),
            "rewards": torch.zeros(T, B, 42),
            "values": torch.zeros(T, B, 42),
            "bootstrap_value": torch.zeros(B, 42),
        }
        output = vtrace.from_importance_weights(**values)
        self.assertSequenceEqual(output.vs.shape, (T, B, 42))

    def test_inconsistent_rank_inputs_for_importance_weights(self):
        """Test one of many possible errors in shape of inputs."""
        T = 3  # pylint: disable=invalid-name
        B = 2  # pylint: disable=invalid-name

        values = {
            "log_rhos": torch.zeros(T, B, 1),
            "discounts": torch.zeros(T, B, 1),
            "rewards": torch.zeros(T, B, 42),
            "values": torch.zeros(T, B, 42),
            # Should be [B, 42].
            "bootstrap_value": torch.zeros(B),
        }

        with self.assertRaisesRegex(
            RuntimeError, "same number of dimensions: got 3 and 2"
        ):
            vtrace.from_importance_weights(**values)


if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/torchbeast.yml:
--------------------------------------------------------------------------------
name: torchbeast
dependencies:
  - pytorch
  - torchvision
  - scikit-image
  - flake8
  - black
  - pre-commit
  - gitpython
  - ffmpeg
  - pip
  - pip:
    - gym[atari]
    - opencv-python
--------------------------------------------------------------------------------
/torchbeast/analysis/gradient_tracking.py:
--------------------------------------------------------------------------------
import os
import tempfile
import itertools as it
import tabulate
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D


def uniquify(path, sep=''):
    # Return a unique variant of `path` by creating the file via mkstemp.
    # NOTE(review): this relies on tempfile private internals
    # (_name_sequence, _once_lock) and never closes the file descriptor
    # returned by mkstemp -- fd leak; consider os.close(fd). Confirm.
    def name_sequence():
        count = it.count()
        yield ''
        while True:
            yield '{s}{n:d}'.format(s = sep, n = next(count))
    orig = tempfile._name_sequence
    with tempfile._once_lock:
        tempfile._name_sequence = name_sequence()
        path = os.path.normpath(path)
        dirname, basename = os.path.split(path)
        filename, ext = os.path.splitext(basename)
        fd, filename = tempfile.mkstemp(dir = dirname, prefix = filename, suffix = ext)
        tempfile._name_sequence = orig
    return filename


def plot_grad_flow(named_parameters, flags):
    '''Plots the gradients flowing through different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.
30 | 31 | Usage: Plug this function in Trainer class after loss.backwards() as 32 | "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow''' 33 | ave_grads = [] 34 | max_grads = [] 35 | layers = [] 36 | for n, p in named_parameters: 37 | if (p.requires_grad) and ("bias" not in n): 38 | layers.append(n) 39 | ave_grads.append(p.grad.abs().mean()) 40 | max_grads.append(p.grad.abs().max()) 41 | fig = plt.figure(figsize=(5, 10)) 42 | plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.3, lw=1, color="c") 43 | plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.9, lw=1, color="b") 44 | plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k") 45 | plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical") 46 | plt.xlim(left=0, right=len(ave_grads)) 47 | plt.ylim(bottom=-0.001, top=0.02) # zoom in on the lower gradient regions 48 | plt.xlabel("Layers") 49 | plt.ylabel("average gradient") 50 | plt.title("Gradient flow") 51 | plt.grid(True) 52 | plt.legend([Line2D([0], [0], color="c", lw=4), 53 | Line2D([0], [0], color="b", lw=4), 54 | Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient']) 55 | plt.tight_layout() 56 | 57 | path = os.path.expandvars(os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "gradients.png"))) 58 | plt.savefig(uniquify(path)) 59 | 60 | 61 | class GradientTracker: 62 | 63 | def __init__(self): 64 | self.avg_grad = {} 65 | self.max_grad = {} 66 | 67 | self.learning_step_count = 0 68 | 69 | def process_backward_pass(self, named_parameters, verbose=False): 70 | current_grad = [] 71 | for n, p in named_parameters: 72 | if p.requires_grad and "bias" not in n: 73 | current_grad.append([n]) 74 | 75 | if n not in self.avg_grad: 76 | self.avg_grad[n] = [] 77 | if n not in self.max_grad: 78 | self.max_grad[n] = [] 79 | 80 | if p.grad is None and verbose: 81 | print("Layer '{}' has gradient None!".format(n)) 82 | 83 | self.avg_grad[n].append(p.grad.abs().mean()) 84 | 
self.max_grad[n].append(p.grad.abs().max()) 85 | 86 | current_grad[-1].append(self.avg_grad[n][-1]) 87 | current_grad[-1].append(p.grad.abs().std()) 88 | current_grad[-1].append(self.max_grad[n][-1]) 89 | 90 | if verbose: 91 | print("\nCurrent gradients at learning step {:d}:".format(self.learning_step_count)) 92 | print(tabulate.tabulate(current_grad, headers=["layer", "mean", "std", "max"], tablefmt="presto"), "\n") 93 | 94 | self.learning_step_count += 1 95 | 96 | def print_total(self): 97 | grad = [] 98 | for n in self.avg_grad: 99 | grad.append([n]) 100 | grad[-1].append(np.mean(self.avg_grad[n])) 101 | grad[-1].append(np.max(self.max_grad[n])) 102 | 103 | print("\nTotal gradients at learning step {:d}:".format(self.learning_step_count)) 104 | print(tabulate.tabulate(grad, headers=["layer", "mean", "max"], tablefmt="presto"), "\n") -------------------------------------------------------------------------------- /torchbeast/analysis/visualize_aaa.py: -------------------------------------------------------------------------------- 1 | # Visualizing and Understanding Atari Agents | Sam Greydanus | 2017 | MIT License 2 | 3 | import argparse 4 | import logging 5 | import os 6 | import time 7 | import warnings 8 | 9 | warnings.filterwarnings("ignore") # mute warnings, live dangerously ;) 10 | 11 | import torch 12 | import matplotlib.pyplot as plt 13 | import matplotlib.animation as manimation 14 | import cv2 15 | 16 | from torchbeast.monobeast import create_env 17 | from torchbeast.core.environment import Environment 18 | from torchbeast.models.attention_augmented_agent import AttentionAugmentedAgent 19 | 20 | cv2.ocl.setUseOpenCL(False) 21 | 22 | parser = argparse.ArgumentParser(description="Visualizations for the Attention-Augmented Agent") 23 | 24 | parser.add_argument("--model_load_path", default="./logs/torchbeast", 25 | help="Path to the model that should be used for the visualizations.") 26 | parser.add_argument("--env", type=str, 
default="PongNoFrameskip-v4", 27 | help="Gym environment.") 28 | parser.add_argument("--frame_height", type=int, default=84, 29 | help="Height to which frames are rescaled.") 30 | parser.add_argument("--frame_width", type=int, default=84, 31 | help="Width to which frames are rescaled.") 32 | parser.add_argument("--aaa_input_format", type=str, default="gray_stack", choices=["gray_stack, rgb_last, rgb_stack"], 33 | help="Color format of the frames as input for the AAA.") 34 | parser.add_argument("--num_frames", default=50, type=int, 35 | help=".") 36 | parser.add_argument("--first_frame", default=200, type=int, 37 | help=".") 38 | parser.add_argument("--resolution", default=75, type=int, 39 | help=".") 40 | parser.add_argument("--density", default=2, type=int, 41 | help=".") 42 | parser.add_argument("--radius", default=2, type=int, 43 | help=".") 44 | parser.add_argument("--save_dir", default="~/logs/aaa-vis", 45 | help=".") 46 | 47 | logging.basicConfig( 48 | format=( 49 | "[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s" 50 | ), 51 | level=0, 52 | ) 53 | 54 | 55 | def extract_data(data): 56 | if type(data) is tuple: 57 | return tuple(extract_data(d) for d in data) 58 | elif type(data) is dict: 59 | return None 60 | return data.detach() 61 | 62 | 63 | def rollout(model, env, max_ep_len=3e3, render=False): 64 | history = { 65 | "observation": [], 66 | "policy_logits": [], 67 | "baseline": [], 68 | "agent_state": [], 69 | "image": [], 70 | "attention_maps": [] 71 | } 72 | episode_length, epr, eploss, done = 0, 0, 0, False # bookkeeping 73 | 74 | activations = {} 75 | 76 | def get_activation(layer_name): 77 | if layer_name not in activations: 78 | activations[layer_name] = [] 79 | 80 | def hook(model, input, output): 81 | activations[layer_name].append(extract_data(output)) 82 | 83 | return hook 84 | 85 | for name, module in model.named_modules(): 86 | module.register_forward_hook(get_activation(name)) 87 | 88 | observation = env.initial() 
89 | with torch.no_grad(): 90 | agent_state = model.initial_state(batch_size=1) 91 | while not done and episode_length <= max_ep_len: 92 | episode_length += 1 93 | agent_output, agent_state, attention_maps = model(observation, agent_state, return_attention_maps=True) 94 | observation = env.step(agent_output["action"]) 95 | done = observation["done"] 96 | 97 | history["observation"].append(observation) 98 | history["policy_logits"].append(agent_output["policy_logits"].detach().numpy()[0]) 99 | history["baseline"].append(agent_output["baseline"].detach().numpy()[0]) 100 | history["agent_state"].append(tuple(s.data.numpy()[0] for s in agent_state)) 101 | history["attention_maps"].append(attention_maps.detach().numpy()[0]) 102 | history["image"].append(env.gym_env.render(mode='rgb_array')) 103 | history["activations"] = activations 104 | 105 | return history 106 | 107 | 108 | def visualize_aaa(model, env, flags): 109 | video_title = "{}_{}_{}_{}.mp4".format("aaa-vis", flags.env, flags.first_frame, flags.num_frames) 110 | max_ep_len = flags.first_frame + flags.num_frames + 1 111 | torch.manual_seed(0) 112 | history = rollout(model, env, max_ep_len=max_ep_len) 113 | 114 | start = time.time() 115 | ffmpeg_writer = manimation.writers["ffmpeg"] 116 | metadata = dict(title=video_title, artist="", comment="atari-attention-augmented-agent-video") 117 | writer = ffmpeg_writer(fps=8, metadata=metadata) 118 | 119 | total_frames = len(history["observation"]) 120 | f = plt.figure(figsize=[(4 / 1.3) * 2, 4], dpi=flags.resolution) 121 | axis_f = f.add_subplot(1, 2, 1) 122 | axis_a = f.add_subplot(1, 2, 2) 123 | axis_f.axis("off") 124 | axis_a.axis("off") 125 | 126 | video_path = os.path.expandvars(os.path.expanduser(flags.save_dir)) 127 | if not os.path.exists(video_path): 128 | os.makedirs(video_path) 129 | with writer.saving(f, video_path + "/" + video_title, flags.resolution): 130 | for i in range(flags.num_frames): 131 | ix = flags.first_frame + i 132 | if ix < total_frames: # 
prevent loop from trying to process a frame ix greater than rollout length 133 | frame = history["image"][ix] 134 | attention_maps = history["attention_maps"][ix] 135 | attention_map = attention_maps[:, :, 0] 136 | attention_map = cv2.resize(attention_map, 137 | (frame.shape[1], frame.shape[0]), 138 | interpolation=cv2.INTER_NEAREST) 139 | 140 | axis_f.imshow(frame) 141 | axis_a.imshow(attention_map, cmap="gray") 142 | f.suptitle(flags.env, fontsize=15, fontname="DejaVuSans") 143 | 144 | writer.grab_frame() 145 | f.clear() 146 | axis_f = f.add_subplot(1, 2, 1) 147 | axis_a = f.add_subplot(1, 2, 2) 148 | axis_f.axis("off") 149 | axis_a.axis("off") 150 | 151 | time_str = time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - start)) 152 | print("\ttime: {} | progress: {:.1f}%".format( 153 | time_str, 100 * i / min(flags.num_frames, total_frames)), end="\r") 154 | print("\nFinished.") 155 | 156 | 157 | if __name__ == "__main__": 158 | flags = parser.parse_args() 159 | 160 | gym_env = create_env(flags.env, frame_height=flags.frame_height, frame_width=flags.frame_width, 161 | gray_scale=(flags.aaa_input_format == "gray_stack")) 162 | env = Environment(gym_env) 163 | model = AttentionAugmentedAgent(gym_env.observation_space.shape, gym_env.action_space.n, 164 | rgb_last=(flags.aaa_input_format == "rgb_last")) 165 | model.eval() 166 | checkpoint = torch.load(flags.model_load_path, map_location="cpu") 167 | model.load_state_dict(checkpoint["model_state_dict"]) 168 | 169 | logging.info("Visualizing AAA using checkpoint at %s.", flags.model_load_path) 170 | visualize_aaa(model, env, flags) 171 | -------------------------------------------------------------------------------- /torchbeast/core/environment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The environment class for MonoBeast."""

import torch


def _format_frame(frame):
    # Convert a numpy frame to a tensor with leading (T=1, B=1) axes.
    frame = torch.from_numpy(frame)
    return frame.view((1, 1) + frame.shape)  # (...) -> (T,B,...).


class Environment:
    """Wraps a gym environment so observations come out as (T=1, B=1) tensor dicts."""

    def __init__(self, gym_env):
        self.gym_env = gym_env
        # Both become (1, 1) tensors once initial() has run; None before that.
        self.episode_return = None
        self.episode_step = None

    def initial(self):
        """Reset the wrapped environment and return the first observation dict."""
        initial_reward = torch.zeros(1, 1)
        # This supports only single-tensor actions ATM.
        initial_last_action = torch.zeros(1, 1, dtype=torch.int64)
        self.episode_return = torch.zeros(1, 1)
        self.episode_step = torch.zeros(1, 1, dtype=torch.int32)
        initial_done = torch.ones(1, 1, dtype=torch.uint8)
        initial_frame = _format_frame(self.gym_env.reset())
        return dict(
            frame=initial_frame,
            reward=initial_reward,
            done=initial_done,
            episode_return=self.episode_return,
            episode_step=self.episode_step,
            last_action=initial_last_action,
        )

    def _post_step(self, frame, reward, done, action):
        """Shared bookkeeping for step()/step_no_task().

        Updates the episode counters, resets the wrapped env when the episode
        ended, and tensorizes everything into the (T=1, B=1) output format.
        """
        self.episode_step += 1
        self.episode_return += reward
        # Hand out the *current* episode tensors; fresh ones are allocated on
        # reset so the returned values are not clobbered by the next episode.
        episode_step = self.episode_step
        episode_return = self.episode_return
        if done:
            frame = self.gym_env.reset()
            self.episode_return = torch.zeros(1, 1)
            self.episode_step = torch.zeros(1, 1, dtype=torch.int32)

        return dict(
            frame=_format_frame(frame),
            reward=torch.tensor(reward).view(1, 1),
            done=torch.tensor(done).view(1, 1),
            episode_return=episode_return,
            episode_step=episode_step,
            last_action=action,
        )

    def step(self, action):
        """Step a multi-task env whose step() returns (frame, reward, done, task, info)."""
        frame, reward, done, task, unused_info = self.gym_env.step(action.item())
        output = self._post_step(frame, reward, done, action)
        output["task"] = torch.tensor(task).view(1, 1)
        return output

    def step_no_task(self, action):
        """Step a standard gym env (4-tuple step()); output has no "task" key."""
        frame, reward, done, unused_info = self.gym_env.step(action.item())
        return self._post_step(frame, reward, done, action)

98 | 99 | def close(self): 100 | self.gym_env.close() 101 | -------------------------------------------------------------------------------- /torchbeast/core/file_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import copy 17 | import csv 18 | import datetime 19 | import json 20 | import logging 21 | import os 22 | import time 23 | from typing import Dict 24 | 25 | 26 | def gather_metadata() -> Dict: 27 | date_start = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") 28 | # Gathering git metadata. 29 | try: 30 | import git 31 | 32 | try: 33 | repo = git.Repo(search_parent_directories=True) 34 | git_sha = repo.commit().hexsha 35 | git_data = dict( 36 | commit=git_sha, 37 | branch=None if repo.head.is_detached else repo.active_branch.name, 38 | is_dirty=repo.is_dirty(), 39 | path=repo.git_dir, 40 | ) 41 | except git.InvalidGitRepositoryError: 42 | git_data = None 43 | except ImportError: 44 | git_data = None 45 | # Gathering slurm metadata. 
46 | if "SLURM_JOB_ID" in os.environ: 47 | slurm_env_keys = [k for k in os.environ if k.startswith("SLURM")] 48 | slurm_data = {} 49 | for k in slurm_env_keys: 50 | d_key = k.replace("SLURM_", "").replace("SLURMD_", "").lower() 51 | slurm_data[d_key] = os.environ[k] 52 | else: 53 | slurm_data = None 54 | return dict( 55 | date_start=date_start, 56 | date_end=None, 57 | successful=False, 58 | git=git_data, 59 | slurm=slurm_data, 60 | env=os.environ.copy(), 61 | ) 62 | 63 | 64 | class FileWriter: 65 | def __init__( 66 | self, 67 | xpid: str = None, 68 | xp_args: dict = None, 69 | rootdir: str = "~/logs", 70 | symlink_to_latest: bool = True, 71 | ): 72 | if not xpid: 73 | # Make unique id. 74 | xpid = "{proc}_{unixtime}".format( 75 | proc=os.getpid(), unixtime=int(time.time()) 76 | ) 77 | self.xpid = xpid 78 | self._tick = 0 79 | 80 | # Metadata gathering. 81 | if xp_args is None: 82 | xp_args = {} 83 | self.metadata = gather_metadata() 84 | # We need to copy the args, otherwise when we close the file writer 85 | # (and rewrite the args) we might have non-serializable objects (or 86 | # other unwanted side-effects). 87 | self.metadata["args"] = copy.deepcopy(xp_args) 88 | self.metadata["xpid"] = self.xpid 89 | 90 | formatter = logging.Formatter("%(message)s") 91 | self._logger = logging.getLogger("logs/out") 92 | 93 | # To stdout handler. 94 | shandle = logging.StreamHandler() 95 | shandle.setFormatter(formatter) 96 | self._logger.addHandler(shandle) 97 | self._logger.setLevel(logging.INFO) 98 | 99 | rootdir = os.path.expandvars(os.path.expanduser(rootdir)) 100 | # To file handler. 
101 | self.basepath = os.path.join(rootdir, self.xpid) 102 | if not os.path.exists(self.basepath): 103 | self._logger.info("Creating log directory: %s", self.basepath) 104 | os.makedirs(self.basepath, exist_ok=True) 105 | else: 106 | self._logger.info("Found log directory: %s", self.basepath) 107 | 108 | if symlink_to_latest: 109 | # Add 'latest' as symlink unless it exists and is no symlink. 110 | symlink = os.path.join(rootdir, "latest") 111 | try: 112 | if os.path.islink(symlink): 113 | os.remove(symlink) 114 | if not os.path.exists(symlink): 115 | os.symlink(self.basepath, symlink) 116 | self._logger.info("Symlinked log directory: %s", symlink) 117 | except OSError: 118 | # os.remove() or os.symlink() raced. Don't do anything. 119 | pass 120 | 121 | self.paths = dict( 122 | msg="{base}/out.log".format(base=self.basepath), 123 | logs="{base}/logs.csv".format(base=self.basepath), 124 | fields="{base}/fields.csv".format(base=self.basepath), 125 | meta="{base}/meta.json".format(base=self.basepath), 126 | ) 127 | 128 | self._logger.info("Saving arguments to %s", self.paths["meta"]) 129 | if os.path.exists(self.paths["meta"]): 130 | self._logger.warning( 131 | "Path to meta file already exists. " "Not overriding meta." 132 | ) 133 | else: 134 | self._save_metadata() 135 | 136 | self._logger.info("Saving messages to %s", self.paths["msg"]) 137 | if os.path.exists(self.paths["msg"]): 138 | self._logger.warning( 139 | "Path to message file already exists. " "New data will be appended." 140 | ) 141 | 142 | fhandle = logging.FileHandler(self.paths["msg"]) 143 | fhandle.setFormatter(formatter) 144 | self._logger.addHandler(fhandle) 145 | 146 | self._logger.info("Saving logs data to %s", self.paths["logs"]) 147 | self._logger.info("Saving logs' fields to %s", self.paths["fields"]) 148 | self.fieldnames = ["_tick", "_time"] 149 | if os.path.exists(self.paths["logs"]): 150 | self._logger.warning( 151 | "Path to log file already exists. " "New data will be appended." 
152 | ) 153 | # Override default fieldnames. 154 | with open(self.paths["fields"], "r") as csvfile: 155 | reader = csv.reader(csvfile) 156 | lines = list(reader) 157 | if len(lines) > 0: 158 | self.fieldnames = lines[-1] 159 | # Override default tick: use the last tick from the logs file plus 1. 160 | with open(self.paths["logs"], "r") as csvfile: 161 | reader = csv.reader(csvfile) 162 | lines = list(reader) 163 | # Need at least two lines in order to read the last tick: 164 | # the first is the csv header and the second is the first line 165 | # of data. 166 | if len(lines) > 1: 167 | self._tick = int(lines[-1][0]) + 1 168 | 169 | self._fieldfile = open(self.paths["fields"], "a") 170 | self._fieldwriter = csv.writer(self._fieldfile) 171 | self._logfile = open(self.paths["logs"], "a") 172 | self._logwriter = csv.DictWriter(self._logfile, fieldnames=self.fieldnames) 173 | 174 | def log(self, to_log: Dict, tick: int = None, verbose: bool = False) -> None: 175 | if tick is not None: 176 | raise NotImplementedError 177 | else: 178 | to_log["_tick"] = self._tick 179 | self._tick += 1 180 | to_log["_time"] = time.time() 181 | 182 | old_len = len(self.fieldnames) 183 | for k in to_log: 184 | if k not in self.fieldnames: 185 | self.fieldnames.append(k) 186 | if old_len != len(self.fieldnames): 187 | self._fieldwriter.writerow(self.fieldnames) 188 | self._logger.info("Updated log fields: %s", self.fieldnames) 189 | 190 | if to_log["_tick"] == 0: 191 | self._logfile.write("# %s\n" % ",".join(self.fieldnames)) 192 | 193 | if verbose: 194 | self._logger.info( 195 | "LOG | %s", 196 | ", ".join(["{}: {}".format(k, to_log[k]) for k in sorted(to_log)]), 197 | ) 198 | 199 | self._logwriter.writerow(to_log) 200 | self._logfile.flush() 201 | 202 | def close(self, successful: bool = True) -> None: 203 | self.metadata["date_end"] = datetime.datetime.now().strftime( 204 | "%Y-%m-%d %H:%M:%S.%f" 205 | ) 206 | self.metadata["successful"] = successful 207 | self._save_metadata() 208 | 209 
        for f in [self._logfile, self._fieldfile]:
            f.close()

    def _save_metadata(self) -> None:
        # Persist the metadata dict (args, git/slurm info, timestamps) as pretty JSON.
        with open(self.paths["meta"], "w") as jsonfile:
            json.dump(self.metadata, jsonfile, indent=4, sort_keys=True)


def read_metadata(filename):
    """Load a meta.json file written by FileWriter and return it as a dict."""
    with open(filename) as json_file:
        data = json.load(json_file)

    return data
--------------------------------------------------------------------------------
/torchbeast/core/popart.py:
--------------------------------------------------------------------------------
# https://github.com/steffenvan/attentive-multi-tasking
# https://github.com/ysr-plus-ultra/keras_popart_impala

import math

import torch


class PopArtLayer(torch.nn.Module):
    """PopArt value head: a linear layer whose raw outputs stay normalized while
    the (mu, sigma) buffers track running statistics of the per-task value
    targets ("Preserving Outputs Precisely while Adaptively Rescaling Targets").
    """

    def __init__(self, input_features, output_features, beta=4e-4):
        # beta: step size of the exponential moving average over mu/sigma.
        self.beta = beta

        super(PopArtLayer, self).__init__()

        self.input_features = input_features
        self.output_features = output_features

        self.weight = torch.nn.Parameter(torch.Tensor(output_features, input_features))
        self.bias = torch.nn.Parameter(torch.Tensor(output_features))

        # Running statistics; buffers so they are checkpointed but never
        # touched by the optimizer.
        self.register_buffer('mu', torch.zeros(output_features, requires_grad=False))
        self.register_buffer('sigma', torch.ones(output_features, requires_grad=False))

        self.reset_parameters()

    def reset_parameters(self):
        # Same initialization scheme as torch.nn.Linear.
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, inputs):
        # Returns [unnormalized output, normalized output]. The unnormalized
        # version is computed under no_grad, so gradients only flow through
        # the normalized head.

        normalized_output = inputs.mm(self.weight.t())
        normalized_output += self.bias.unsqueeze(0).expand_as(normalized_output)

        with torch.no_grad():
            output = normalized_output * self.sigma + self.mu

        return [output, normalized_output]

    def update_parameters(self, vs, task):
        # Update (mu, sigma) from the v-trace targets `vs`, masked per task,
        # then rescale weight/bias so the layer's *unnormalized* outputs are
        # unchanged (the "preserving outputs precisely" part of PopArt).
        # NOTE(review): assumes `task` is a one-hot mask broadcastable with
        # `vs` over dims (0, 1) — TODO confirm against the learner code.

        oldmu = self.mu
        oldsigma = self.sigma

        vs = vs * task
        n = task.sum((0, 1))  # per-task sample count in this batch
        mu = vs.sum((0, 1)) / n
        nu = torch.sum(vs**2, (0, 1)) / n  # per-task second moment
        sigma = torch.sqrt(nu - mu**2)
        sigma = torch.clamp(sigma, min=1e-4, max=1e+6)

        # Tasks absent from the batch give n == 0 -> NaN statistics; keep the
        # previous running values for those entries.
        mu[torch.isnan(mu)] = self.mu[torch.isnan(mu)]
        sigma[torch.isnan(sigma)] = self.sigma[torch.isnan(sigma)]

        self.mu = (1 - self.beta) * self.mu + self.beta * mu
        self.sigma = (1 - self.beta) * self.sigma + self.beta * sigma

        # Rescale so sigma_new * (W_new h + b_new) + mu_new equals
        # sigma_old * (W_old h + b_old) + mu_old for every input h.
        self.weight.data = (self.weight.t() * oldsigma / self.sigma).t()
        self.bias.data = (oldsigma * self.bias + oldmu - self.mu) / self.sigma

--------------------------------------------------------------------------------
/torchbeast/core/prof.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Naive profiling using timeit.
(Used in MonoBeast.)""" 15 | 16 | import collections 17 | import timeit 18 | 19 | 20 | class Timings: 21 | """Not thread-safe.""" 22 | 23 | def __init__(self): 24 | self._means = collections.defaultdict(int) 25 | self._vars = collections.defaultdict(int) 26 | self._counts = collections.defaultdict(int) 27 | self.reset() 28 | 29 | def reset(self): 30 | self.last_time = timeit.default_timer() 31 | 32 | def time(self, name): 33 | """Save an update for event `name`. 34 | 35 | Nerd alarm: We could just store a 36 | collections.defaultdict(list) 37 | and compute means and standard deviations at the end. But thanks to the 38 | clever math in Sutton-Barto 39 | (http://www.incompleteideas.net/book/first/ebook/node19.html) and 40 | https://math.stackexchange.com/a/103025/5051 we can update both the 41 | means and the stds online. O(1) FTW! 42 | """ 43 | now = timeit.default_timer() 44 | x = now - self.last_time 45 | self.last_time = now 46 | 47 | n = self._counts[name] 48 | 49 | mean = self._means[name] + (x - self._means[name]) / (n + 1) 50 | var = ( 51 | n * self._vars[name] + n * (self._means[name] - mean) ** 2 + (x - mean) ** 2 52 | ) / (n + 1) 53 | 54 | self._means[name] = mean 55 | self._vars[name] = var 56 | self._counts[name] += 1 57 | 58 | def means(self): 59 | return self._means 60 | 61 | def vars(self): 62 | return self._vars 63 | 64 | def stds(self): 65 | return {k: v ** 0.5 for k, v in self._vars.items()} 66 | 67 | def summary(self, prefix=""): 68 | means = self.means() 69 | stds = self.stds() 70 | total = sum(means.values()) 71 | 72 | result = prefix 73 | for k in sorted(means, key=means.get, reverse=True): 74 | result += f"\n %s: %.6fms +- %.6fms (%.2f%%) " % ( 75 | k, 76 | 1000 * means[k], 77 | 1000 * stds[k], 78 | 100 * means[k] / total, 79 | ) 80 | result += "\nTotal: %.6fms" % (1000 * total) 81 | return result 82 | -------------------------------------------------------------------------------- /torchbeast/core/vtrace.py: 
-------------------------------------------------------------------------------- 1 | # This file taken from 2 | # https://github.com/deepmind/scalable_agent/blob/ 3 | # cd66d00914d56c8ba2f0615d9cdeefcb169a8d70/vtrace.py 4 | # and modified. 5 | 6 | # Copyright 2018 Google LLC 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # https://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | """Functions to compute V-trace off-policy actor critic targets. 20 | 21 | For details and theory see: 22 | 23 | "IMPALA: Scalable Distributed Deep-RL with 24 | Importance Weighted Actor-Learner Architectures" 25 | by Espeholt, Soyer, Munos et al. 26 | 27 | See https://arxiv.org/abs/1802.01561 for the full paper. 
28 | """ 29 | 30 | import collections 31 | 32 | import torch 33 | import torch.nn.functional as F 34 | 35 | 36 | VTraceFromLogitsReturns = collections.namedtuple( 37 | "VTraceFromLogitsReturns", 38 | [ 39 | "vs", 40 | "pg_advantages", 41 | "log_rhos", 42 | "behavior_action_log_probs", 43 | "target_action_log_probs", 44 | ], 45 | ) 46 | 47 | VTraceReturns = collections.namedtuple("VTraceReturns", "vs pg_advantages") 48 | 49 | 50 | def action_log_probs(policy_logits, actions): 51 | return -F.nll_loss( 52 | F.log_softmax(torch.flatten(policy_logits, 0, -2), dim=-1), 53 | torch.flatten(actions), 54 | reduction="none", 55 | ).view_as(actions) 56 | 57 | 58 | def from_logits( 59 | behavior_policy_logits, 60 | target_policy_logits, 61 | actions, 62 | discounts, 63 | rewards, 64 | values, 65 | bootstrap_value, 66 | normalized_values=None, 67 | mu=None, 68 | sigma=None, 69 | clip_rho_threshold=1.0, 70 | clip_pg_rho_threshold=1.0, 71 | ): 72 | """V-trace for softmax policies.""" 73 | device = target_policy_logits.device 74 | target_action_log_probs = action_log_probs(target_policy_logits, actions) 75 | behavior_action_log_probs = action_log_probs(behavior_policy_logits, actions) 76 | log_rhos = target_action_log_probs - behavior_action_log_probs 77 | if mu is None: 78 | mu = torch.zeros(1).to(device) 79 | sigma = torch.ones(1).to(device) 80 | normalized_values = values 81 | vtrace_returns = from_importance_weights( 82 | log_rhos=log_rhos, 83 | discounts=discounts, 84 | rewards=rewards, 85 | values=values, 86 | normalized_values=normalized_values, 87 | bootstrap_value=bootstrap_value, 88 | mu=mu, 89 | sigma=sigma, 90 | clip_rho_threshold=clip_rho_threshold, 91 | clip_pg_rho_threshold=clip_pg_rho_threshold, 92 | ) 93 | return VTraceFromLogitsReturns( 94 | log_rhos=log_rhos, 95 | behavior_action_log_probs=behavior_action_log_probs, 96 | target_action_log_probs=target_action_log_probs, 97 | **vtrace_returns._asdict(), 98 | ) 99 | 100 | 101 | @torch.no_grad() 102 | def 
from_importance_weights(
    log_rhos,
    discounts,
    rewards,
    values,
    normalized_values,
    bootstrap_value,
    mu,
    sigma,
    clip_rho_threshold=1.0,
    clip_pg_rho_threshold=1.0,
):
    """V-trace from log importance weights.

    All of log_rhos/discounts/rewards/values/normalized_values are time-major
    ([T, B, ...]); bootstrap_value has the trailing shape without the time
    axis. (mu, sigma) are the PopArt statistics used to normalize the policy
    gradient advantage; from_logits() passes mu=0, sigma=1 and
    normalized_values=values when PopArt is disabled.
    """
    with torch.no_grad():
        rhos = torch.exp(log_rhos)
        # rho_bar: cap the importance weights used in the TD-error term.
        if clip_rho_threshold is not None:
            clipped_rhos = torch.clamp(rhos, max=clip_rho_threshold)
        else:
            clipped_rhos = rhos

        # c_bar = 1: trace-cutting coefficients of the recursion.
        cs = torch.clamp(rhos, max=1.0)
        # Append bootstrapped value to get [v1, ..., v_t+1]
        values_t_plus_1 = torch.cat(
            [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0
        )

        # Importance-weighted one-step TD errors.
        deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)

        # Backward recursion: acc_t = delta_t + gamma_t * c_t * acc_{t+1},
        # accumulated from the last timestep to the first.
        acc = torch.zeros_like(bootstrap_value)
        result = []
        for t in range(discounts.shape[0] - 1, -1, -1):
            acc = deltas[t] + discounts[t] * cs[t] * acc
            result.append(acc)
        result.reverse()
        vs_minus_v_xs = torch.stack(result)

        # Add V(x_s) to get v_s.
        vs = torch.add(vs_minus_v_xs, values)

        # Advantage for policy gradient: targets use v_{s+1} shifted by one
        # step, bootstrapped with bootstrap_value at the end.
        broadcasted_bootstrap_values = torch.ones_like(vs[0]) * bootstrap_value
        vs_t_plus_1 = torch.cat(
            [vs[1:], broadcasted_bootstrap_values.unsqueeze(0)], dim=0
        )
        if clip_pg_rho_threshold is not None:
            clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold)
        else:
            clipped_pg_rhos = rhos

        # Original (unnormalized) advantage:
        # pg_advantages = clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values)
        # PopArt variant: normalize the v-trace target with (mu, sigma) and
        # compare against the *normalized* baseline. With mu=0, sigma=1 and
        # normalized_values=values this reduces to the line above.
        pg_advantages = clipped_pg_rhos * (((rewards + discounts * vs_t_plus_1) - mu) / sigma - normalized_values)

        # Make sure no gradients backpropagated through the returned values.
        return VTraceReturns(vs=vs, pg_advantages=pg_advantages)
--------------------------------------------------------------------------------
/torchbeast/models/atari_net_monobeast.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as F

from torchbeast.core.popart import PopArtLayer


class AtariNet(nn.Module):
    """Shallow IMPALA-style Atari network: DQN conv torso, optional LSTM core,
    policy head and a PopArt baseline head.

    The forward pass consumes a dict of time-major [T, B, ...] tensors and
    returns policy logits, (un)normalized baselines and sampled actions,
    together with the updated LSTM core state.
    """

    def __init__(self, observation_shape, num_actions, num_tasks=1, use_lstm=False, use_popart=False, **kwargs):
        # observation_shape: (C, H, W) of one (stacked) observation frame.
        # num_tasks: number of PopArt heads; only used when use_popart is set.
        super(AtariNet, self).__init__()
        self.observation_shape = observation_shape
        self.num_actions = num_actions
        self.num_tasks = num_tasks
        self.use_lstm = use_lstm
        self.use_popart = use_popart

        # Feature extraction.
        self.conv1 = nn.Conv2d(
            in_channels=self.observation_shape[0],
            out_channels=32,
            kernel_size=8,
            stride=4,
        )
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        # Fully connected layer.  3136 = 64 channels * 7 * 7 spatial output
        # (assumes 84x84 input frames -- TODO confirm against the env wrappers).
        self.fc = nn.Linear(3136, 512)

        # FC output size + one-hot of last action + last reward.
        core_output_size = self.fc.out_features + num_actions + 1

        if use_lstm:
            # Two-layer LSTM whose hidden size matches its input size.
            self.core = nn.LSTM(core_output_size, core_output_size, 2)

        self.policy = nn.Linear(core_output_size, self.num_actions)
        # With PopArt disabled this degenerates to a single plain value head.
        self.baseline = PopArtLayer(core_output_size, num_tasks if self.use_popart else 1)

    def initial_state(self, batch_size):
        """Return zeroed LSTM state (h, c); an empty tuple when no LSTM is used."""
        if not self.use_lstm:
            return tuple()
        return tuple(
            torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size)
            for _ in range(2)
        )

    def forward(self, inputs, core_state=()):
        """Run the network on a [T, B, ...] batch of environment outputs.

        Args:
            inputs: dict with keys "frame", "last_action", "reward", "done".
            core_state: LSTM (h, c) tuple from the previous call, or ().

        Returns:
            (output_dict, core_state) where output_dict contains
            policy_logits, baseline, normalized_baseline and action, each
            reshaped back to [T, B, ...].
        """
        x = inputs["frame"]  # [T, B, C, H, W].
        T, B, *_ = x.shape
        x = torch.flatten(x, 0, 1)  # Merge time and batch.
        x = x.float() / 255.0  # Rescale uint8 pixels to [0, 1].
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(T * B, -1)
        x = F.relu(self.fc(x))

        # Condition the core on the previous action (one-hot) and reward.
        one_hot_last_action = F.one_hot(
            inputs["last_action"].view(T * B), self.num_actions
        ).float()
        clipped_reward = torch.clamp(inputs["reward"], -1, 1).view(T * B, 1)
        core_input = torch.cat([x, clipped_reward, one_hot_last_action], dim=-1)

        if self.use_lstm:
            core_input = core_input.view(T, B, -1)
            core_output_list = []
            # notdone has shape (time_steps, batch_size)
            notdone = (~inputs["done"]).float()
            for input, nd in zip(core_input.unbind(), notdone.unbind()):
                # Reset core state to zero whenever an episode ended.
                # Make `done` broadcastable with (num_layers, B, hidden_size)
                # states:
                nd = nd.view(1, -1, 1)
                core_state = tuple(nd * s for s in core_state)
                output, core_state = self.core(input.unsqueeze(0), core_state)
                core_output_list.append(output)
            core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
            # pretty sure flatten() is just used to merge time and batch again
        else:
            core_output = core_input
            core_state = tuple()

        # core_output should have shape (T * B, hidden_size) now?
        policy_logits = self.policy(core_output)
        baseline, normalized_baseline = self.baseline(core_output)

        if self.training:
            # Sample an action from the softmax policy during training.
            action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
        else:
            # Don't sample when testing.
            action = torch.argmax(policy_logits, dim=1)

        policy_logits = policy_logits.view(T, B, self.num_actions)

        baseline = baseline.view(T, B, self.num_tasks)
        normalized_baseline = normalized_baseline.view(T, B, self.num_tasks)
        action = action.view(T, B, 1)

        return (
            dict(policy_logits=policy_logits, baseline=baseline, action=action,
                 normalized_baseline=normalized_baseline),
            core_state,
        )
--------------------------------------------------------------------------------
/torchbeast/models/resnet_monobeast.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as F

from torchbeast.core.popart import PopArtLayer


class ResNet(nn.Module):
    """Deep IMPALA residual network with optional LSTM core and PopArt head."""

    def __init__(
        self,
        observation_shape,  # not used in this architecture
        num_actions,
        num_tasks=1,
        use_lstm=False,
        use_popart=False,
        reward_clipping="abs_one",
        **kwargs
    ):

        super(ResNet, self).__init__()
        self.num_actions = num_actions
        self.num_tasks = num_tasks
        self.use_lstm = use_lstm
        self.use_popart = use_popart
        self.reward_clipping = reward_clipping

        # One conv+pool stem and two residual blocks per channel section.
        self.feat_convs = []
        self.resnet1 = []
        self.resnet2 = []

        self.convs = []

        # Input is a stack of 4 frames.
        input_channels = 4
        for num_ch in [16, 32, 32]:
            feats_convs = [nn.Conv2d(
                in_channels=input_channels,
                out_channels=num_ch,
                kernel_size=3,
                stride=1,
                padding=1,
            ), nn.MaxPool2d(kernel_size=3, stride=2, padding=1)]
            self.feat_convs.append(nn.Sequential(*feats_convs))

            input_channels = num_ch

            for i in range(2):
                resnet_block = [nn.ReLU(), nn.Conv2d(
                    in_channels=input_channels,
                    out_channels=num_ch,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ), nn.ReLU(), nn.Conv2d(
                    in_channels=input_channels,
                    out_channels=num_ch,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )]
                if i == 0:
                    self.resnet1.append(nn.Sequential(*resnet_block))
                else:
                    self.resnet2.append(nn.Sequential(*resnet_block))

        self.feat_convs = nn.ModuleList(self.feat_convs)
        self.resnet1 = nn.ModuleList(self.resnet1)
        self.resnet2 = nn.ModuleList(self.resnet2)

        # 3872 = 32 channels * 11 * 11 spatial output after three /2 pools
        # (assumes 84x84 input frames -- TODO confirm against the env wrappers).
        self.fc = nn.Linear(3872, 256)

        # FC output size + last reward.
        core_output_size = self.fc.out_features + 1

        if use_lstm:
            self.core = nn.LSTM(core_output_size, 256, num_layers=1)
            core_output_size = 256

        self.policy = nn.Linear(core_output_size, self.num_actions)
        # With PopArt disabled this degenerates to a single plain value head.
        self.baseline = PopArtLayer(core_output_size, num_tasks if self.use_popart else 1)

    def initial_state(self, batch_size=1):
        """Return zeroed LSTM state (h, c); an empty tuple when no LSTM is used."""
        if not self.use_lstm:
            return tuple()
        return tuple(
            torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size)
            for _ in range(2)
        )

    def forward(self, inputs, core_state=(), run_to_conv=-1):
        """Run the network on a [T, B, ...] batch of environment outputs.

        Args:
            inputs: dict with keys "frame", "reward", "done" -- or, when
                run_to_conv >= 0, a raw frame tensor used for probing
                intermediate conv activations.
            core_state: LSTM (h, c) tuple from the previous call, or ().
            run_to_conv: if >= 0, return the activations right after this
                many convolutions have been applied (early exit).

        Returns:
            (output_dict, core_state), or a bare activation tensor when
            run_to_conv >= 0.
        """
        if run_to_conv >= 0:
            # Activation-probing mode: the caller hands in a raw tensor.
            x = inputs
        else:
            x = inputs["frame"]

        T, B, *_ = x.shape
        x = torch.flatten(x, 0, 1)  # Merge time and batch.
        x = x.float() / 255.0  # Rescale uint8 pixels to [0, 1].

        # Count convolutions applied so far to support the early exit.
        conv_counter = 0
        for i, f_conv in enumerate(self.feat_convs):
            x = f_conv(x)
            conv_counter += 1
            if 0 <= run_to_conv < conv_counter:
                return x

            res_input = x
            x = self.resnet1[i](x)
            conv_counter += 2
            if 0 <= run_to_conv < conv_counter:
                return x
            x += res_input  # Residual skip connection.

            res_input = x
            x = self.resnet2[i](x)
            conv_counter += 2
            if 0 <= run_to_conv < conv_counter:
                return x
            x += res_input  # Residual skip connection.

        x = F.relu(x)
        x = x.view(T * B, -1)
        x = F.relu(self.fc(x))

        # NOTE: clipped_reward stays None for any other reward_clipping value,
        # which would make the torch.cat below raise.
        clipped_reward = None
        if self.reward_clipping == "abs_one":
            clipped_reward = torch.clamp(inputs["reward"], -1, 1).view(T * B, 1)
        elif self.reward_clipping == "none":
            clipped_reward = inputs["reward"].view(T * B, 1)

        core_input = torch.cat([x, clipped_reward], dim=-1)

        if self.use_lstm:
            core_input = core_input.view(T, B, -1)
            core_output_list = []
            not_done = (~inputs["done"]).float()
            for input, nd in zip(core_input.unbind(), not_done.unbind()):
                # Reset core state to zero whenever an episode ended.
                # Make `done` broadcastable with (num_layers, B, hidden_size)
                # states:
                nd = nd.view(1, -1, 1)
                # core_state = nest.map(nd.mul, core_state)
                core_state = tuple(nd * s for s in core_state)
                output, core_state = self.core(input.unsqueeze(0), core_state)
                core_output_list.append(output)
            core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
        else:
            core_output = core_input

        policy_logits = self.policy(core_output)
        baseline, normalized_baseline = self.baseline(core_output)

        if self.training:
            # Sample an action from the softmax policy during training.
            action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
        else:
            # Don't sample when testing.
            action = torch.argmax(policy_logits, dim=1)

        policy_logits = policy_logits.view(T, B, self.num_actions)

        baseline = baseline.view(T, B, self.num_tasks)
        normalized_baseline = normalized_baseline.view(T, B, self.num_tasks)
        action = action.view(T, B, 1)

        return (
            dict(policy_logits=policy_logits, baseline=baseline, action=action,
                 normalized_baseline=normalized_baseline),
            core_state,
        )
--------------------------------------------------------------------------------
/torchbeast/polybeast_env.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import multiprocessing as mp
import threading
import time

import numpy as np
from libtorchbeast import rpcenv
from torchbeast import atari_wrappers


# yapf: disable
parser = argparse.ArgumentParser(description='Remote Environment Server')

parser.add_argument("--pipes_basename", default="unix:/tmp/polybeast",
                    help="Basename for the pipes for inter-process communication. "
" 31 | "Has to be of the type unix:/some/path.") 32 | parser.add_argument('--num_servers', default=4, type=int, metavar='N', 33 | help='Number of environment servers.') 34 | parser.add_argument('--env', type=str, default='PongNoFrameskip-v4', 35 | help='Gym environment.') 36 | parser.add_argument("--multitask", action="store_true", 37 | help="Broadcast task id") 38 | 39 | # yapf: enable 40 | 41 | 42 | class Env: 43 | def reset(self): 44 | print("reset called") 45 | return np.ones((4, 84, 84), dtype=np.uint8) 46 | 47 | def step(self, action): 48 | frame = np.zeros((4, 84, 84), dtype=np.uint8) 49 | return frame, 0.0, False, 0, {} # First four mandatory. 50 | 51 | 52 | def create_env(env_name, task=0, full_action_space=False, lock=threading.Lock()): 53 | with lock: # Atari isn't threadsafe at construction time. 54 | return atari_wrappers.wrap_pytorch_task( 55 | atari_wrappers.wrap_deepmind( 56 | atari_wrappers.make_atari(env_name, full_action_space=full_action_space), 57 | clip_rewards=False, 58 | frame_stack=True, 59 | scale=False, 60 | ), 61 | task=task 62 | ) 63 | 64 | 65 | def serve(env_name, task, full_action_space, server_address): 66 | init = Env if env_name == "Mock" else lambda: create_env(env_name, task=task, full_action_space=full_action_space) 67 | server = rpcenv.Server(init, server_address=server_address) 68 | server.run() 69 | 70 | 71 | if __name__ == "__main__": 72 | flags = parser.parse_args() 73 | 74 | if not flags.pipes_basename.startswith("unix:"): 75 | raise Exception("--pipes_basename has to be of the form unix:/some/path.") 76 | 77 | processes = [] 78 | envs = flags.env.split(",") 79 | 80 | # determine if action spaces are compatible, otherwise use full action space 81 | full_action_space = True 82 | if flags.env != "Mock": 83 | action_spaces = [] 84 | for i in range(len(envs)): 85 | env = create_env(envs[i]) 86 | action_spaces.append(env.action_space) 87 | env.close() 88 | if all(x == action_spaces[0] for x in action_spaces): 89 | 
full_action_space = False 90 | 91 | if len(envs) <= flags.num_servers: 92 | for i in range(flags.num_servers): 93 | task = i % len(envs) if flags.multitask else 0 94 | p = mp.Process( 95 | target=serve, args=(envs[i % len(envs)], task, full_action_space, f"{flags.pipes_basename}.{i}"), daemon=True 96 | ) 97 | p.start() 98 | processes.append(p) 99 | print("Starting environment", i, "(", task, ", ", envs[task], ").") 100 | else: 101 | raise Exception("Wrong number of servers for environments.") 102 | 103 | 104 | try: 105 | # We are only here to listen to the interrupt. 106 | while True: 107 | time.sleep(10) 108 | except KeyboardInterrupt: 109 | pass 110 | --------------------------------------------------------------------------------