├── .clang-format ├── .flake8 ├── .github ├── release-drafter.yml └── workflows │ ├── black_flake8.yml │ ├── clang-format.yml │ ├── deploy.yml │ ├── release-drafter.yml │ └── run_python_tests.yml ├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── Makefile ├── README.md ├── atari_1.png ├── atari_2.png ├── make.bat ├── run_docs.sh └── source │ ├── _templates │ ├── moolib_class_template.rst │ └── moolib_result_template.rst │ ├── conf.py │ └── index.rst ├── examples ├── README.md ├── a2c.py ├── atari │ ├── atari_preprocessing.py │ ├── environment.py │ └── models.py ├── common │ ├── __init__.py │ ├── nest.py │ ├── record.py │ └── vtrace.py ├── plot.py ├── requirements.txt ├── sbatch_experiment.py └── vtrace │ ├── __init__.py │ ├── config.yaml │ └── experiment.py ├── py └── moolib │ ├── __init__.py │ ├── broker.py │ └── examples ├── pyproject.toml ├── setup.py ├── src ├── accumulator.cc ├── accumulator.h ├── any.h ├── async.cc ├── async.h ├── batch_utils.cc ├── batch_utils.h ├── batchsizefinder.h ├── broker.h ├── env.cc ├── env.h ├── function.h ├── group.cc ├── group.h ├── intrusive_list.h ├── logging.h ├── memory │ ├── allocator.h │ ├── buffer.h │ ├── memfd.cc │ └── memfd.h ├── moolib.cc ├── pythonserialization.h ├── pytorch.h ├── pyutil.h ├── rpc.cc ├── rpc.h ├── serialization.h ├── shm.h ├── synchronization.h ├── tensor.cc ├── tensor.h ├── tensorpython.cc ├── transports │ ├── ipc.cc │ ├── ipc.h │ ├── socket.cc │ └── socket.h ├── util.h └── vector.h └── test ├── CMakeLists.txt ├── example.py ├── integration └── test_a2c.py ├── test.h ├── test.py ├── test_asyncio.py ├── test_asyncio_queue.py ├── test_batch.py ├── test_dynamic_batching_queue.py ├── test_group.py ├── test_multinode_allreduce.cc ├── test_reduce.py ├── test_reduce_asyncio.py ├── test_rpc.cc └── unit ├── test_batcher.py ├── test_broker.py ├── test_envpool.py ├── test_pickle.py ├── test_simple.py 
└── test_tensors.py /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: LLVM 2 | 3 | AccessModifierOffset: -2 4 | AlignAfterOpenBracket: AlwaysBreak 5 | AllowShortFunctionsOnASingleLine: Empty 6 | AllowShortIfStatementsOnASingleLine: true 7 | AllowShortLoopsOnASingleLine: false 8 | AlwaysBreakTemplateDeclarations: true 9 | ColumnLimit: 120 10 | ConstructorInitializerAllOnOneLineOrOnePerLine: false 11 | DerivePointerAlignment: false 12 | PointerAlignment: Left 13 | SpaceAfterTemplateKeyword: false 14 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | B001, # Do not use bare `except:` 4 | B008, # Do not perform calls in argument defaults. 5 | B950, # line too long 6 | C901, # Function is too complex 7 | E203, # Whitespace before ':' (breaks black) 8 | E266, ## Too many leading '#' for block comment 9 | E302, # Expected 2 blank lines, found 0 10 | E501, # Line too long 11 | E722, # do not use bare 'except' 12 | E731, # Do not assign a lambda expression, use a def 13 | W503 # Line break occurred before a binary operator 14 | # 80 to use as a soft test 15 | max-line-length = 80 16 | max-complexity = 18 17 | select = B,C,E,F,W,T4,B9 18 | exclude = 19 | .git 20 | __pycache__, 21 | build, 22 | dist, 23 | src, # non-python and third-party code. 24 | atari_preprocessing.py # third-party code. 
25 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name-template: 'v$NEXT_PATCH_VERSION' 3 | tag-template: 'v$NEXT_PATCH_VERSION' 4 | categories: 5 | - title: '🚀 New Features' 6 | label: 'enhancement' 7 | - title: '💣 Breaking Change' 8 | label: 'breaking change' 9 | - title: '🐛 Bug Fixes' 10 | label: 'bug' 11 | - title: '📝 Documentation' 12 | label: 'docs' 13 | - title: '🔨 Maintenance' 14 | label: 'chore' 15 | exclude-labels: 16 | - 'skip changelog' 17 | change-template: '- $TITLE (#$NUMBER, @$AUTHOR)' 18 | template: | 19 | # Installing `moolib` 20 | 21 | Install with pip: `pip install moolib==$NEXT_PATCH_VERSION`. 22 | 23 | See [README.md](https://github.com/facebookresearch/moolib/blob/v$NEXT_PATCH_VERSION/README.md) for further instructions. 24 | 25 | 26 | # New in `moolib` v$NEXT_PATCH_VERSION 27 | 28 | $CHANGES 29 | ... 30 | -------------------------------------------------------------------------------- /.github/workflows/black_flake8.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Black & flake8 3 | 4 | on: 5 | push: 6 | paths: 7 | - "**.py" 8 | - "!src/backward-cpp/**" 9 | - "!src/fmt/**" 10 | - "!src/pybind11/**" 11 | - "!src/tensorpipe/**" 12 | pull_request: 13 | paths: 14 | - "**.py" 15 | - "!src/backward-cpp/**" 16 | - "!src/fmt/**" 17 | - "!src/pybind11/**" 18 | - "!src/tensorpipe/**" 19 | schedule: 20 | - cron: "0 6,18 * * *" 21 | 22 | jobs: 23 | check_python: 24 | runs-on: ubuntu-latest 25 | 26 | steps: 27 | - name: Setup Python 3.8 env 28 | uses: actions/setup-python@v1 29 | with: 30 | python-version: "3.8" 31 | - name: Clone NLE repo 32 | uses: actions/checkout@v2 33 | - name: Ensure latest pip 34 | run: "python -m pip install -q --upgrade pip" 35 | - name: Install python linting deps 36 | run: "pip install -q black flake8 flake8-bugbear" 
37 | - name: Run black 38 | run: "black --check --diff ." 39 | - name: Run flake8 40 | run: "flake8" 41 | -------------------------------------------------------------------------------- /.github/workflows/clang-format.yml: -------------------------------------------------------------------------------- 1 | name: Clang format 2 | 3 | on: 4 | push: 5 | paths: 6 | - "**.c" 7 | - "**.cc" 8 | - "**.h" 9 | - "!src/backward-cpp/**" 10 | - "!src/fmt/**" 11 | - "!src/pybind11/**" 12 | - "!src/tensorpipe/**" 13 | pull_request: 14 | paths: 15 | - "**.c" 16 | - "**.cc" 17 | - "**.h" 18 | - "!src/backward-cpp/**" 19 | - "!src/fmt/**" 20 | - "!src/pybind11/**" 21 | - "!src/tensorpipe/**" 22 | 23 | jobs: 24 | clang_format: 25 | runs-on: ubuntu-latest 26 | 27 | steps: 28 | - uses: actions/checkout@v2 29 | - uses: DoozyX/clang-format-lint-action@v0.12 30 | with: 31 | source: '.' 32 | exclude: './src/pybind11 ./src/fmt ./src/backward-cpp' 33 | clangFormatVersion: 12 34 | -------------------------------------------------------------------------------- /.github/workflows/deploy.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Deploy 3 | 4 | on: 5 | push: 6 | branches: 7 | - main 8 | pull_request: 9 | release: 10 | types: [released] 11 | 12 | 13 | jobs: 14 | test_sdist: 15 | name: Test sdist on MacOS w/ Py3.8 16 | runs-on: macos-latest 17 | when: manual 18 | steps: 19 | - name: Setup Python 3.8 env 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: 3.8 23 | - name: Ensure latest pip & wheel 24 | run: "python -m pip install -q --upgrade pip wheel" 25 | - name: Install dependencies 26 | run: | 27 | brew install cmake 28 | python -m pip install torch numpy pytest-forked gym 29 | - uses: actions/checkout@v2 30 | with: 31 | submodules: recursive 32 | - name: Generate sdist 33 | run: | 34 | USE_CUDA=0 python setup.py sdist 35 | - name: Install from sdist 36 | run: | 37 | SDISTNAME=$(ls dist/) 38 | MODE="[all]" 39 | 
USE_CUDA=0 pip install "dist/$SDISTNAME$MODE" 40 | - name: Run tests outside repo dir 41 | run: | 42 | REPONAME=$(basename $PWD) 43 | pushd .. 44 | PYTHONPATH=$REPONAME python -um pytest -svx --forked $REPONAME/test/unit $REPONAME/test/integration 45 | popd 46 | - name: Save sdist 47 | if: ${{ always() }} 48 | uses: actions/upload-artifact@v1 49 | with: 50 | name: moolib_dist 51 | path: dist/ 52 | 53 | # TODO move to separate workflow? 54 | deploy_sdist: 55 | name: Deploy sdist to pypi 56 | needs: test_sdist 57 | if: github.event_name == 'release' && github.event.action == 'released' 58 | runs-on: ubuntu-latest 59 | steps: 60 | - uses: actions/checkout@v2 61 | with: 62 | submodules: recursive 63 | - name: Check version matches release tag 64 | run: | 65 | echo "v$(grep -Po '(?<=version=")[0-9.]+' setup.py)" 66 | echo "${{ github.event.release.tag_name }}" 67 | [[ "${{ github.event.release.tag_name }}" == "v$(grep -Po '(?<=version=")[0-9.]+' setup.py)" ]] 68 | - name: Get dist artifacts from test_sdist 69 | uses: actions/download-artifact@v2 70 | with: 71 | name: moolib_dist 72 | path: dist 73 | - name: Install from sdist 74 | run: | 75 | pwd 76 | ls -R 77 | ls -al . 78 | ls -R dist/ 79 | ls -al dist/ 80 | # NOTE: We assume that dist/ contains a built sdist (and only that). 81 | - name: Publish package to PyPI 82 | uses: pypa/gh-action-pypi-publish@master 83 | with: 84 | user: __token__ 85 | password: ${{ secrets.PYPI_TOKEN }} 86 | -------------------------------------------------------------------------------- /.github/workflows/release-drafter.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Release Drafter 3 | 4 | on: 5 | push: 6 | branches: 7 | - main 8 | 9 | jobs: 10 | update_release_draft: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Draft release notes 14 | uses: release-drafter/release-drafter@v5 15 | env: 16 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 17 | ... 
18 | -------------------------------------------------------------------------------- /.github/workflows/run_python_tests.yml: -------------------------------------------------------------------------------- 1 | name: Python tests 2 | 3 | on: 4 | push: 5 | schedule: 6 | - cron: "0 6,18 * * *" 7 | release: 8 | types: [released] 9 | 10 | jobs: 11 | run_python_tests: 12 | name: ${{ matrix.os }} w/ Py${{ matrix.python-version }} 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | matrix: 16 | python-version: ["3.7"] 17 | #os: [ubuntu-latest, macos-latest] 18 | os: [ubuntu-latest] 19 | fail-fast: false 20 | 21 | steps: 22 | - name: Setup Python ${{ matrix.python-version }} env 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Ensure latest pip & wheel 27 | run: "python -m pip install --upgrade pip" 28 | - name: Setup cmake 29 | uses: jwlawson/actions-setup-cmake@v1.9 30 | with: 31 | cmake-version: '3.16.x' 32 | - name: Install PyTorch 33 | # Need to explicitly ask for non-CUDA version on Linux. 34 | run: | 35 | if [ "$RUNNER_OS" == "Linux" ]; then 36 | pip install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html 37 | pip install numpy pytest-forked gym 38 | else 39 | pip install torch numpy pytest-forked gym 40 | fi 41 | shell: bash 42 | - uses: actions/checkout@v2 43 | with: 44 | submodules: recursive 45 | - name: Install moolib 46 | # pip installing makes site-packages not be part of sys.path, 47 | # so torch won't be found by our cmake check. TODO: Fix. 48 | # run: "USE_CUDA=0 pip install ." 
49 | run: "USE_CUDA=0 python setup.py install" 50 | - name: Run tests 51 | run: "python -um pytest -svx --forked test/unit test/integration" 52 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | build/ 3 | dist/ 4 | *.so 5 | *.egg-info 6 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/pybind11"] 2 | path = third_party/pybind11 3 | url = https://github.com/pybind/pybind11.git 4 | [submodule "third_party/backward-cpp"] 5 | path = third_party/backward-cpp 6 | url = https://github.com/bombela/backward-cpp.git 7 | [submodule "third_party/fmt"] 8 | path = third_party/fmt 9 | url = https://github.com/fmtlib/fmt 10 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12 FATAL_ERROR) 2 | 3 | project(moolib) 4 | 5 | set(CMAKE_CXX_STANDARD 17) 6 | set(CMAKE_CXX_FLAGS 7 | "${CMAKE_CXX_FLAGS} -Wfatal-errors -ftemplate-backtrace-limit=0 -Bsymbolic") 8 | 9 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 10 | set(CMAKE_CXX_VISIBILITY_PRESET hidden) 11 | 12 | execute_process( 13 | COMMAND python -c "import os, torch; print(os.path.dirname(torch.__file__), end='')" 14 | OUTPUT_VARIABLE TorchPath 15 | ) 16 | set(CMAKE_PREFIX_PATH ${TorchPath}) 17 | find_package(Torch REQUIRED) 18 | 19 | message(STATUS "PyTorch compilation flags: ${TORCH_CXX_FLAGS}") 20 | message(STATUS "PyTorch include dirs: ${TORCH_INCLUDE_DIRS}") 21 | #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") 22 | 23 | find_package(Threads REQUIRED) 24 | 25 | option(USE_CUDA "Enable CUDA support" ON) 26 | 27 | if (USE_CUDA) 28 | option(TP_USE_CUDA "" ON) 29 
| endif() 30 | 31 | add_subdirectory(third_party/fmt) 32 | add_subdirectory(third_party/pybind11) 33 | 34 | #add_subdirectory(test) 35 | 36 | add_library(tensor src/tensor.cc) 37 | target_link_libraries(tensor PRIVATE ${TORCH_LIBRARIES}) 38 | target_include_directories(tensor SYSTEM PRIVATE ${TORCH_INCLUDE_DIRS}) 39 | target_compile_options(tensor PRIVATE ${TORCH_CXX_FLAGS}) 40 | 41 | add_library(moorpc OBJECT 42 | src/rpc.cc 43 | src/async.cc 44 | src/transports/socket.cc 45 | src/transports/ipc.cc 46 | src/memory/memfd.cc 47 | ) 48 | target_link_libraries(moorpc PUBLIC tensor fmt anl) 49 | target_include_directories(moorpc PRIVATE src) 50 | 51 | if (USE_CUDA) 52 | target_compile_definitions(tensor PRIVATE -DUSE_CUDA) 53 | target_compile_definitions(moorpc PRIVATE -DUSE_CUDA) 54 | endif() 55 | 56 | pybind11_add_module(_C 57 | $ 58 | src/moolib.cc 59 | src/accumulator.cc 60 | src/batch_utils.cc 61 | src/env.cc 62 | src/tensorpython.cc 63 | ) 64 | target_link_libraries( 65 | _C 66 | PRIVATE 67 | $ 68 | ${TorchPath}/lib/libtorch_python${CMAKE_SHARED_LIBRARY_SUFFIX} 69 | ) 70 | target_include_directories(_C PRIVATE src) 71 | 72 | set_source_files_properties(src/tensorpython.cc PROPERTIES INCLUDE_DIRECTORIES "${TORCH_INCLUDE_DIRS}") 73 | set_source_files_properties(src/tensorpython.cc PROPERTIES COMPILE_FLAGS "${TORCH_CXX_FLAGS}") 74 | 75 | if (USE_BACKWARD) 76 | find_library(DW_LIB dw REQUIRED) 77 | find_library(UNWIND_LIB unwind REQUIRED) 78 | find_path(DW_INCLUDE dwarf.h REQUIRED) 79 | 80 | target_sources(_C PUBLIC third_party/backward-cpp/backward.cpp) 81 | target_include_directories(_C PRIVATE ${DW_INCLUDE}) 82 | target_link_libraries(_C PUBLIC moorpc ${DW_LIB} ${UNWIND_LIB}) 83 | target_compile_definitions(_C PRIVATE BACKWARD_HAS_DW BACKWARD_HAS_LIBUNWIND) 84 | endif() 85 | 86 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | 
# Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to moolib 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to moolib, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
-------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # -*- mode: dockerfile -*- 2 | 3 | FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu20.04 4 | 5 | ARG PYTHON_VERSION=3.8 6 | ENV DEBIAN_FRONTEND=noninteractive 7 | 8 | RUN apt-get update && apt-get install -yq \ 9 | build-essential \ 10 | cmake \ 11 | curl \ 12 | git \ 13 | ninja-build \ 14 | wget 15 | 16 | WORKDIR /opt/conda_setup 17 | 18 | RUN curl -o miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 19 | chmod +x miniconda.sh && \ 20 | ./miniconda.sh -b -p /opt/conda && \ 21 | /opt/conda/bin/conda install -y python=$PYTHON_VERSION && \ 22 | /opt/conda/bin/conda clean -ya 23 | ENV PATH /opt/conda/bin:$PATH 24 | 25 | RUN python -m pip install --upgrade pip 26 | 27 | RUN conda install pytorch numpy cudatoolkit=11.3 -c pytorch 28 | 29 | WORKDIR /opt/moolib 30 | 31 | COPY . /opt/moolib/ 32 | 33 | RUN pip install -r examples/requirements.txt 34 | 35 | RUN pip install '.[all]' 36 | 37 | WORKDIR /opt/moolib 38 | 39 | CMD ["bash", "-c", "python examples/a2c.py"] 40 | 41 | # Docker commands: 42 | # docker rm moolib -v 43 | # docker build -t moolib -f Dockerfile . 44 | # docker run --gpus all --rm --name moolib moolib 45 | # or 46 | # docker run --gpus all -it --entrypoint /bin/bash moolib -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include CMakeLists.txt 2 | recursive-include docs * 3 | recursive-include examples * 4 | recursive-include py * 5 | recursive-include src * 6 | recursive-include test * 7 | recursive-include third_party * 8 | 9 | recursive-exclude build * 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🐮
moolib 2 | 3 |

4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 |

14 | 15 | moolib - a communications library for distributed ML training 16 | 17 | moolib offers general purpose RPC with automatic transport 18 | selection (shared memory, TCP/IP, Infiniband) allowing models to 19 | data-parallelise their training and synchronize gradients 20 | and model weights across many nodes. 21 | 22 | `moolib` is an RPC library to help you perform distributed machine 23 | learning research, particularly reinforcement learning. It is designed 24 | to be *highly flexible* and *highly performant*. 25 | 26 | It is *flexible* because it allows researchers to define their own 27 | training loops and data-collection policies with minimal interference 28 | or abstractions - `moolib` gets out of the way of research code. 29 | 30 | It is *performant* because it gives researchers the power of efficient 31 | data-parallelization across GPUs with minimal overhead, in a manner 32 | that is highly scalable. 33 | 34 | `moolib` aims to provide researchers with the freedom to implement 35 | whatever experiment loop they desire, and the freedom to scale it up 36 | from single GPUs to hundreds at will (with no additional code). It 37 | ships with a reference implementation of 38 | [IMPALA](examples/vtrace/experiment.py) on 39 | [Atari](examples/atari/environment.py) that can easily be adapted to 40 | other environments or algorithms. 41 | 42 | 43 | ## Installing 44 | 45 | **To compile `moolib` without CUDA support** 46 | 47 | export USE_CUDA=0 48 | 49 | To install from GitHub: 50 | 51 | pip install git+https://github.com/facebookresearch/moolib 52 | 53 | To build from source: 54 | 55 | git clone --recursive git@github.com:facebookresearch/moolib 56 | cd moolib 57 | pip install . 58 | 59 | 60 | ## Run an Example 61 | 62 | To run the example agent on a given Atari level: 63 | 64 | First, start the broker: 65 | 66 | python -m moolib.broker 67 | 68 | It will output something like `Broker listening at 0.0.0.0:4431`. 
69 | 70 | Note that a **single broker is enough** for all your experiments. 71 | 72 | Now take the IP address of your computer. If you ssh'd into your 73 | machine, this should work (in a new shell): 74 | 75 | ``` 76 | export BROKER_IP=$(echo $SSH_CONNECTION | cut -d' ' -f3) # Should give your machine's IP. 77 | export BROKER_PORT=4431 78 | ``` 79 | 80 | To start an experiment with a single peer: 81 | 82 | python -m examples.vtrace.experiment connect=BROKER_IP:BROKER_PORT \ 83 | savedir=/tmp/moolib-atari/savedir \ 84 | project=moolib-atari \ 85 | group=Zaxxon-Breakout \ 86 | env.name=ALE/Breakout-v5 87 | 88 | To add more peers to this experiment, start more processes with the 89 | same `project` and `group` settings, using a different setting for 90 | `device` (default: `'cuda:0'`). 91 | 92 | 93 | ## Documentation 94 | 95 | * [`moolib` whitepaper](https://research.facebook.com/publications/moolib-a-platform-for-distributed-rl/). 96 | * [`moolib`'s API documentation](https://facebookresearch.github.io/moolib/). 97 | 98 | 99 | ## Benchmarks 100 | 101 |
Show results on Atari 102 | 103 | ![atari_1](./docs/atari_1.png) 104 | ![atari_2](./docs/atari_2.png) 105 |
106 | 107 | 108 | ## Citation 109 | 110 | ``` 111 | @article{moolib2022, 112 | title = {{moolib: A Platform for Distributed RL}}, 113 | author = {Vegard Mella and Eric Hambro and Danielle Rothermel and Heinrich K{\"{u}}ttler}, 114 | year = {2022}, 115 | url = {https://github.com/facebookresearch/moolib}, 116 | } 117 | ``` 118 | 119 | 120 | ## License 121 | 122 | moolib is licensed under the MIT License. See [`LICENSE`](LICENSE) for details. 123 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # moolib Documentation 2 | 3 | ### Run Docsite Locally 4 | ``` 5 | pip install sphinx==4.1.2 6 | ./run_docs.sh 7 | ``` 8 | 9 | ### To Update Jeckyll Site 10 | ``` 11 | make html 12 | mv build/html/* . 
13 | rm -r build/doctrees/ 14 | touch .nojekyll 15 | ``` 16 | And commit and push to branch `gh-pages` 17 | -------------------------------------------------------------------------------- /docs/atari_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/moolib/06e7a3e80c9f52729b4a6159f3fb4fc78986c98e/docs/atari_1.png -------------------------------------------------------------------------------- /docs/atari_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/moolib/06e7a3e80c9f52729b4a6159f3fb4fc78986c98e/docs/atari_2.png -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/run_docs.sh: -------------------------------------------------------------------------------- 1 | make html && printf "\nServing docs at: \033[37;1m http://0.0.0.0:8000/build/html \033[0m\n" && python3 -m http.server 2 | -------------------------------------------------------------------------------- /docs/source/_templates/moolib_class_template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :members: 7 | 8 | 9 | {% block methods %} 10 | .. automethod:: __init__ 11 | 12 | {% if methods %} 13 | .. rubric:: {{ _('Methods') }} 14 | 15 | .. autosummary:: 16 | :nosignatures: 17 | 18 | {% for item in methods %} 19 | ~{{ name }}.{{ item }} 20 | {%- endfor %} 21 | {% endif %} 22 | {% endblock %} 23 | 24 | {% block attributes %} 25 | {% if attributes %} 26 | .. rubric:: {{ _('Attributes') }} 27 | 28 | .. autosummary:: 29 | {% for item in attributes %} 30 | ~{{ name }}.{{ item }} 31 | {%- endfor %} 32 | {% endif %} 33 | {% endblock %} 34 | -------------------------------------------------------------------------------- /docs/source/_templates/moolib_result_template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :members: 7 | 8 | .. note:: 9 | This class is a return value of an asynchronous call. It can not be 10 | instantiatied by the user. 
11 | 12 | {% block methods %} 13 | 14 | {% if methods %} 15 | .. rubric:: {{ _('Methods') }} 16 | 17 | .. autosummary:: 18 | :nosignatures: 19 | 20 | {% for item in methods %} 21 | {% if item != "__init__" %} 22 | ~{{ name }}.{{ item }} 23 | {% endif %} 24 | {%- endfor %} 25 | {% endif %} 26 | {% endblock %} 27 | 28 | {% block attributes %} 29 | {% if attributes %} 30 | .. rubric:: {{ _('Attributes') }} 31 | 32 | .. autosummary:: 33 | {% for item in attributes %} 34 | ~{{ name }}.{{ item }} 35 | {%- endfor %} 36 | {% endif %} 37 | {% endblock %} 38 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "moolib" 21 | copyright = "2021, Facebook AI Research" 22 | author = "Facebook AI Research" 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 
30 | extensions = ["sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx.ext.napoleon"] 31 | 32 | # Add any paths that contain templates here, relative to this directory. 33 | templates_path = ["_templates"] 34 | 35 | # List of patterns, relative to source directory, that match files and 36 | # directories to ignore when looking for source files. 37 | # This pattern also affects html_static_path and html_extra_path. 38 | exclude_patterns = [] 39 | 40 | 41 | # -- Options for HTML output ------------------------------------------------- 42 | 43 | # The theme to use for HTML and HTML Help pages. See the documentation for 44 | # a list of builtin themes. 45 | # 46 | html_theme = "alabaster" 47 | 48 | # Add any paths that contain custom static files (such as style sheets) here, 49 | # relative to this directory. They are copied after the builtin static files, 50 | # so a file named "default.css" will overwrite the builtin "default.css". 51 | html_static_path = ["_static"] 52 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. moolib documentation master file, created by 2 | sphinx-quickstart on Thu Aug 19 14:02:03 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Moolib Documentation 7 | ================================== 8 | 9 | 10 | .. code-block:: html 11 | 12 | moolib - a communications library for distributed ml training 13 | 14 | moolib offers general purpose RPC with automatic transport 15 | selection (shared memory, tcp/ip, infiniband) allowing models 16 | to data-parallelise their training and synchronize gradients 17 | and model weights across many nodes. 18 | 19 | 20 | .. toctree:: 21 | :hidden: 22 | :maxdepth: 2 23 | :caption: Contents: 24 | 25 | What is Moolib? 
26 | 27 | 28 | Getting Started 29 | --------------- 30 | 31 | Install from GitHub 32 | 33 | .. code-block:: bash 34 | 35 | pip install git+https://github.com/facebookresearch/moolib 36 | 37 | Build from source: **Linux** 38 | 39 | .. code-block:: bash 40 | 41 | git clone --recursive git@github.com:facebookresearch/moolib 42 | cd moolib && pip install . 43 | 44 | Build from source: **MacOS** 45 | 46 | .. code-block:: bash 47 | 48 | git clone --recursive git@github.com:facebookresearch/moolib 49 | cd moolib && USE_CUDA=0 pip install . 50 | 51 | How to host docs: 52 | 53 | .. code-block:: bash 54 | 55 | # after installation 56 | pip install sphinx==4.1.2 57 | cd docs && ./run_docs.sh 58 | 59 | 60 | API 61 | ----------------- 62 | 63 | Classes 64 | """"""" 65 | 66 | .. currentmodule:: moolib 67 | .. autosummary:: 68 | :toctree: api 69 | :nosignatures: 70 | :template: moolib_class_template.rst 71 | 72 | Accumulator 73 | Batcher 74 | Broker 75 | EnvPool 76 | EnvStepper 77 | Group 78 | Rpc 79 | 80 | Methods 81 | """"""" 82 | 83 | .. currentmodule:: moolib 84 | .. autosummary:: 85 | :toctree: api 86 | :nosignatures: 87 | :template: 88 | 89 | create_uid 90 | set_logging 91 | set_log_level 92 | set_max_threads 93 | 94 | Futures 95 | """"""" 96 | 97 | .. currentmodule:: moolib 98 | .. autosummary:: 99 | :toctree: api 100 | :nosignatures: 101 | :template: moolib_result_template.rst 102 | 103 | AllReduce 104 | Future 105 | EnvStepperFuture 106 | 107 | Examples 108 | ----------------- 109 | 110 | Some examples are in the ``./examples`` directory. 111 | 112 | 113 | .. .. automodule:: moolib 114 | .. 
:members: 115 | 116 | 117 | Search 118 | ----------------- 119 | 120 | * :ref:`search` 121 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | 2 | # moolib examples 3 | 4 | ## Simplified example agent 5 | 6 | To try out the simplified A2C example locally, run 7 | 8 | ``` 9 | python examples/a2c.py 10 | ``` 11 | 12 | This produces a file called `logs.tsv` which can be plotted using 13 | `plot.py`. To see results during training, you can run 14 | 15 | ``` 16 | watch -n3 --color python examples/plot.py logs.tsv --ykey mean_episode_return --window 500 17 | ``` 18 | 19 | in another terminal. 20 | 21 | Here's a sample result: 22 | 23 | ``` 24 | mean_episode_return 25 | 180 +---------------------------------------------------------------------+ 26 | | + + + AAAAAAAAAAAA + | 27 | 160 |-+.........:..........:..........AAAAAAA......:..........:.........+-| 28 | | : : AAAAAAA: : : | 29 | | : AAAAAA : : : | 30 | 140 |-+.........:........AAA...........:...........:..........:.........+-| 31 | | : AAA : : : : | 32 | 120 |-+.........AAAAAAAA...:...........:...........:..........:.........+-| 33 | | AAAAA : : : : | 34 | | AA : : : : : | 35 | 100 |-+....AA...:..........:...........:...........:..........:.........+-| 36 | | AAA : : : : : | 37 | 80 |-+.AA......:..........:...........:...........:..........:.........+-| 38 | | A : : : : : | 39 | | AA : : : : : | 40 | 60 |-A.........:..........:...........:...........:..........:.........+-| 41 | | A : : : : : | 42 | 40 |AA.........:..........:...........:...........:..........:.........+-| 43 | |A : : : : : | 44 | |A : : : : : | 45 | 20 |-+.........:..........:...........:...........:..........:.........+-| 46 | | + + + + + | 47 | 0 +---------------------------------------------------------------------+ 48 | 0 50000 100000 150000 200000 250000 300000 49 | step 50 | logs.tsv +--A--+ 51 | ``` 52 | 53 
| 54 | ## Fully-fledged vtrace agent 55 | 56 | ### Running 57 | 58 | To run the example agent on a given Atari level: 59 | 60 | First, start the broker: 61 | 62 | python -m moolib.broker 63 | 64 | It will output something like `Broker listening at 0.0.0.0:4431`. 65 | 66 | Note that a **single broker is enough** for all your experiments. 67 | 68 | Now take the IP address of your computer. If you ssh'd into your 69 | machine, this should work (in a new shell): 70 | 71 | ``` 72 | export BROKER_IP=$(echo $SSH_CONNECTION | cut -d' ' -f3) # Should give your machine's IP. 73 | export BROKER_PORT=4431 74 | ``` 75 | 76 | To start an experiment with a single peer: 77 | 78 | python -m examples.vtrace.experiment connect=BROKER_IP:BROKER_PORT \ 79 | savedir=/tmp/moolib-atari/savedir \ 80 | project=moolib-atari \ 81 | group=Zaxxon-Breakout \ 82 | env.name=ALE/Breakout-v5 83 | 84 | To add more peers to this experiment, start more processes with the 85 | same `project` and `group` settings, using a different setting for 86 | `device` (default: `'cuda:0'`) if on the same machine. 87 | 88 | 89 | ### Batch sizes in example agent. 90 | 91 | In the `moolib` example agent(s), there are several different batch sizes: 92 | 93 | * The `actor_batch_size`, i.e., the second dimension of the model 94 | inputs at acting time: The `B` in `[1, B, W, H, C]`. 95 | (2x actor batch size is the number of environment instances due to 96 | 'double buffering') 97 | 98 | * The learner batch size (often just `batch_size`), i.e. the `B` in 99 | `[T, B, W, H, C]` at learning time (to produce local gradients). 100 | 101 | * The unroll length is a batch size of sorts, i.e. the `T` in `[T, 102 | B, W, H, C]` at learning time. Only when using RNNs (agents with 103 | memory) is this partially treated as a sequence length as well. 
104 | 105 | * The virtual batch size, i.e., the number of samples in the `B` 106 | dimension moolib consumes and adds to its gradient buffers before 107 | a gradient descent step is happening. This can happen in two ways: 108 | (1) A single peer could go through several backprop steps and keep 109 | adding to its "running gradient" before it applies the gradient, 110 | or (2) multiple peers go through one sample each and accumulate 111 | their gradients. The virtual batch size is then 112 | `number_of_peers * learner_batch_size`.[^1] 113 | 114 | Note that the `virtual_batch_size` setting in moolib is (currently) a 115 | _lower bound_ on the number of samples required to do a single grad 116 | descent step. When using several peers in parallel, `moolib` can 117 | overshoot. Logging the `gradient_stats["batch_size"]` entry tells you 118 | what the actual virtual batch size has been at each step. The reason 119 | `moolib` treats the `accumulator.set_virtual_batch_size` value as a lower 120 | bound (instead of as an lower and upper bound) is that it would 121 | otherwise need to do more synchronisation, which would reduce overall 122 | throughput. 123 | 124 | [^1]: Note that here `number_of_peers` isn't necessarily _all_ peers 125 | that participate in the training: Instead, it's how many peers 126 | happened to participate in "this" gradient accumulation and due to 127 | (1), a peer could participate multiple times. Thus 128 | `number_of_peers` can even be higher than the total number of 129 | peers in the training. 130 | -------------------------------------------------------------------------------- /examples/atari/environment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import gym 15 | 16 | from . import atari_preprocessing 17 | 18 | 19 | def create_env(flags): 20 | env = gym.make( # Cf. https://brosa.ca/blog/ale-release-v0.7 21 | flags.env.name, 22 | obs_type="grayscale", # "ram", "rgb", or "grayscale". 23 | frameskip=1, # Action repeats. Done in wrapper b/c of noops. 24 | repeat_action_probability=flags.env.repeat_action_probability, # Sticky actions. 25 | full_action_space=True, # Use all actions. 26 | render_mode=None, # None, "human", or "rgb_array". 27 | ) 28 | 29 | # Using wrapper from seed_rl in order to do random no-ops _before_ frameskipping. 30 | # gym.wrappers.AtariPreprocessing doesn't play well with the -v5 versions of the game. 31 | env = atari_preprocessing.AtariPreprocessing( 32 | env, 33 | frame_skip=flags.env.num_action_repeats, 34 | terminal_on_life_loss=False, 35 | screen_size=84, 36 | max_random_noops=flags.env.noop_max, # Max no-ops to apply at the beginning. 37 | ) 38 | env = gym.wrappers.FrameStack(env, num_stack=4) 39 | return env 40 | -------------------------------------------------------------------------------- /examples/atari/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | from moolib.examples.common import nest 7 | 8 | 9 | class Net(nn.Module): 10 | def __init__(self, num_actions=18, input_channels=4, use_lstm=False): 11 | super(Net, self).__init__() 12 | self.num_actions = num_actions 13 | self.use_lstm = use_lstm 14 | 15 | self.feat_convs = [] 16 | self.resnet1 = [] 17 | self.resnet2 = [] 18 | 19 | self.convs = [] 20 | 21 | for num_ch in [16, 32, 32]: 22 | feats_convs = [] 23 | feats_convs.append( 24 | nn.Conv2d( 25 | in_channels=input_channels, 26 | out_channels=num_ch, 27 | kernel_size=3, 28 | stride=1, 29 | padding=1, 30 | ) 31 | ) 32 | feats_convs.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 33 | self.feat_convs.append(nn.Sequential(*feats_convs)) 34 | 35 | input_channels = num_ch 36 | 37 | for i in range(2): 38 | resnet_block = [] 39 | resnet_block.append(nn.ReLU()) 40 | resnet_block.append( 41 | nn.Conv2d( 42 | in_channels=input_channels, 43 | out_channels=num_ch, 44 | kernel_size=3, 45 | stride=1, 46 | padding=1, 47 | ) 48 | ) 49 | resnet_block.append(nn.ReLU()) 50 | resnet_block.append( 51 | nn.Conv2d( 52 | in_channels=input_channels, 53 | out_channels=num_ch, 54 | kernel_size=3, 55 | stride=1, 56 | padding=1, 57 | ) 58 | ) 59 | if i == 0: 60 | self.resnet1.append(nn.Sequential(*resnet_block)) 61 | else: 62 | self.resnet2.append(nn.Sequential(*resnet_block)) 63 | 64 | self.feat_convs = nn.ModuleList(self.feat_convs) 65 | self.resnet1 = nn.ModuleList(self.resnet1) 66 | self.resnet2 = nn.ModuleList(self.resnet2) 67 | 68 | self.fc = nn.Linear(3872, 256) 69 | 70 | # FC output size + one-hot of last action + last reward. 
71 | core_output_size = self.fc.out_features + num_actions + 1 72 | 73 | if use_lstm: 74 | self.core = nn.LSTM(core_output_size, 256, num_layers=1) 75 | core_output_size = 256 76 | 77 | self.policy = nn.Linear(core_output_size, self.num_actions) 78 | self.baseline = nn.Linear(core_output_size, 1) 79 | 80 | def initial_state(self, batch_size=1): 81 | if not self.use_lstm: 82 | return tuple() 83 | return tuple( 84 | torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size) 85 | for _ in range(2) 86 | ) 87 | 88 | def forward(self, inputs, core_state=None): 89 | reward = inputs["reward"] 90 | x = inputs["state"] 91 | 92 | T, B, *_ = x.shape 93 | x = torch.flatten(x, 0, 1) # Merge time and batch. 94 | x = x.float() / 255.0 95 | 96 | res_input = None 97 | for i, fconv in enumerate(self.feat_convs): 98 | x = fconv(x) 99 | res_input = x 100 | x = self.resnet1[i](x) 101 | x += res_input 102 | res_input = x 103 | x = self.resnet2[i](x) 104 | x += res_input 105 | 106 | x = F.relu(x) 107 | x = x.view(T * B, -1) 108 | x = F.relu(self.fc(x)) 109 | 110 | one_hot_last_action = F.one_hot( 111 | inputs["prev_action"].view(T * B), self.num_actions 112 | ).float() 113 | clipped_reward = torch.clamp(reward, -1, 1).view(T * B, 1) 114 | core_input = torch.cat([x, clipped_reward, one_hot_last_action], dim=-1) 115 | 116 | if self.use_lstm: 117 | done = inputs["done"] 118 | core_input = core_input.view(T, B, -1) 119 | core_output_list = [] 120 | notdone = (~done).float() 121 | for input, nd in zip(core_input.unbind(), notdone.unbind()): 122 | # Reset core state to zero whenever an episode ended. 
123 | # Make `done` broadcastable with (num_layers, B, hidden_size) 124 | # states: 125 | nd = nd.view(1, -1, 1) 126 | core_state = nest.map(nd.mul, core_state) 127 | output, core_state = self.core(input.unsqueeze(0), core_state) 128 | core_output_list.append(output) 129 | core_output = torch.flatten(torch.cat(core_output_list), 0, 1) 130 | else: 131 | core_output = core_input 132 | 133 | policy_logits = self.policy(core_output) 134 | baseline = self.baseline(core_output) 135 | 136 | action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1) 137 | 138 | policy_logits = policy_logits.view(T, B, self.num_actions) 139 | baseline = baseline.view(T, B) 140 | action = action.view(T, B) 141 | 142 | output = dict( 143 | policy_logits=policy_logits, 144 | baseline=baseline, 145 | action=action, 146 | ) 147 | return output, core_state 148 | 149 | 150 | def create_model(flags): 151 | model = Net(use_lstm=flags.use_lstm) 152 | model.to(device=flags.device) 153 | return model 154 | -------------------------------------------------------------------------------- /examples/common/__init__.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import dataclasses 3 | import logging 4 | 5 | import torch 6 | import moolib 7 | 8 | 9 | @dataclasses.dataclass 10 | class StatMean: 11 | value: float = 0 12 | n: int = 0 13 | 14 | def result(self): 15 | if self.n == 0: 16 | return None 17 | return self.value / self.n 18 | 19 | def __sub__(self, other): 20 | assert isinstance(other, StatMean) 21 | return StatMean(self.value - other.value, self.n - other.n) 22 | 23 | def __iadd__(self, other): 24 | if isinstance(other, StatMean): 25 | self.value += other.value 26 | self.n += other.n 27 | else: 28 | self.value += other 29 | self.n += 1 30 | return self 31 | 32 | def reset(self): 33 | self.value = 0 34 | self.n = 0 35 | 36 | def __repr__(self): 37 | return repr(self.result()) 38 | 39 | 40 | @dataclasses.dataclass 41 | class 
StatSum: 42 | value: float = 0 43 | 44 | def result(self): 45 | return self.value 46 | 47 | def __sub__(self, other): 48 | assert isinstance(other, StatSum) 49 | return StatSum(self.value - other.value) 50 | 51 | def __iadd__(self, other): 52 | if isinstance(other, StatSum): 53 | self.value += other.value 54 | else: 55 | self.value += other 56 | return self 57 | 58 | def reset(self): 59 | pass 60 | 61 | def __repr__(self): 62 | return repr(self.result()) 63 | 64 | 65 | class GlobalStatsAccumulator: 66 | """Class for global accumulation state. add_stats gets reduced.""" 67 | 68 | def __init__(self, rpc_group, global_stats): 69 | self.rpc_group = rpc_group 70 | self.global_stats = global_stats 71 | self.reduce_future = None 72 | self.queued_global_stats = None 73 | self.sent_global_stats = None 74 | self.prev_stats = None 75 | 76 | def add_stats(self, dst, src): 77 | for k, v in dst.items(): 78 | v += src[k] 79 | return dst 80 | 81 | def enqueue_global_stats(self, stats): 82 | if self.queued_global_stats is None: 83 | self.queued_global_stats = copy.deepcopy(stats) 84 | else: 85 | # Sum pending data. 86 | self.add_stats(self.queued_global_stats, stats) 87 | 88 | def reduce(self, stats): 89 | if self.reduce_future is not None and self.reduce_future.done(): 90 | if self.reduce_future.exception() is not None: 91 | logging.info( 92 | "global stats accumulation error: %s", 93 | self.reduce_future.exception(), 94 | ) 95 | self.enqueue_global_stats(self.sent_global_stats) 96 | else: 97 | self.add_stats(self.global_stats, self.reduce_future.result()) 98 | self.reduce_future = None 99 | 100 | stats_diff = stats 101 | if self.prev_stats is not None: 102 | stats_diff = {k: v - self.prev_stats[k] for k, v in stats.items()} 103 | 104 | self.enqueue_global_stats(stats_diff) 105 | self.prev_stats = copy.deepcopy(stats) 106 | 107 | if self.reduce_future is None: 108 | # Only reduce when not currently reducing. 109 | # Otherwise, we keep queued_global_stats for next time. 
110 | self.sent_global_stats = self.queued_global_stats 111 | self.queued_global_stats = None 112 | # Additional copy to deal with potential partial reductions. 113 | self.reduce_future = self.rpc_group.all_reduce( 114 | "global stats", copy.deepcopy(self.sent_global_stats), self.add_stats 115 | ) 116 | 117 | def reset(self): 118 | if self.prev_stats is not None: 119 | for _, v in self.prev_stats.items(): 120 | v.reset() 121 | 122 | 123 | def _mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count): 124 | delta = batch_mean - mean 125 | tot_count = count + batch_count 126 | 127 | new_mean = mean + delta * batch_count / tot_count 128 | m_a = var * count 129 | m_b = batch_var * batch_count 130 | M2 = m_a + m_b + torch.square(delta) * count * batch_count / tot_count 131 | new_var = M2 / tot_count 132 | new_count = tot_count 133 | 134 | return new_mean, new_var, new_count 135 | 136 | 137 | # From https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/running_mean_std.py#L5 138 | class RunningMeanStd(object): 139 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 140 | def __init__(self, epsilon=1e-4, shape=()): 141 | self.mean = torch.zeros(shape, dtype=torch.float64) 142 | self.var = torch.ones(shape, dtype=torch.float64) 143 | self.count = epsilon 144 | 145 | def update(self, x): 146 | batch_mean = torch.mean(x, axis=0) 147 | batch_var = torch.var(x, axis=0) 148 | batch_count = x.shape[0] 149 | self.mean, self.var, self.count = _mean_var_count_from_moments( 150 | self.mean, self.var, self.count, batch_mean, batch_var, batch_count 151 | ) 152 | 153 | 154 | class EnvBatchState: 155 | def __init__(self, flags, model): 156 | batch_size = flags.actor_batch_size 157 | device = flags.device 158 | self.batch_size = batch_size 159 | self.prev_action = torch.zeros(batch_size).long().to(device) 160 | self.future = None 161 | self.core_state = 
model.initial_state(batch_size=batch_size) 162 | self.core_state = tuple(s.to(device) for s in self.core_state) 163 | self.initial_core_state = self.core_state 164 | 165 | self.running_reward = torch.zeros(batch_size) 166 | self.step_count = torch.zeros(batch_size) 167 | 168 | self.discounting = flags.discounting 169 | self.weighted_returns = torch.zeros(batch_size) 170 | self.weighted_returns_rms = RunningMeanStd() 171 | 172 | self.time_batcher = moolib.Batcher(flags.unroll_length + 1, device) 173 | 174 | def update(self, env_outputs, action, stats): 175 | self.prev_action = action 176 | self.running_reward += env_outputs["reward"] 177 | self.weighted_returns *= self.discounting 178 | self.weighted_returns += env_outputs["reward"] 179 | self.weighted_returns_rms.update(self.weighted_returns) 180 | 181 | self.scaled_reward = env_outputs["reward"] / torch.sqrt( 182 | self.weighted_returns_rms.var + 1e-8 183 | ) 184 | 185 | self.step_count += 1 186 | 187 | done = env_outputs["done"] 188 | 189 | episode_return = self.running_reward * done 190 | episode_step = self.step_count * done 191 | 192 | episodes_done = done.sum().item() 193 | if episodes_done > 0: 194 | stats["mean_episode_return"] += episode_return.sum().item() / episodes_done 195 | stats["mean_episode_step"] += episode_step.sum().item() / episodes_done 196 | stats["steps_done"] += done.numel() 197 | stats["episodes_done"] += episodes_done 198 | 199 | stats["running_reward"] += self.running_reward.mean().item() 200 | stats["running_step"] += self.step_count.mean().item() 201 | 202 | not_done = ~done 203 | 204 | self.running_reward *= not_done 205 | self.weighted_returns *= not_done 206 | self.step_count *= not_done 207 | -------------------------------------------------------------------------------- /examples/common/nest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | 4 | def map(f, n): 5 | if isinstance(n, tuple) or isinstance(n, list): 6 | return n.__class__(map(f, sn) for sn in n) 7 | elif isinstance(n, dict): 8 | return {k: map(f, v) for k, v in n.items()} 9 | else: 10 | return f(n) 11 | 12 | 13 | def flatten(n): 14 | if isinstance(n, tuple) or isinstance(n, list): 15 | for sn in n: 16 | yield from flatten(sn) 17 | elif isinstance(n, dict): 18 | for key in n: 19 | yield from flatten(n[key]) 20 | else: 21 | yield n 22 | 23 | 24 | def zip(*nests): 25 | n0, *nests = nests 26 | iters = [flatten(n) for n in nests] 27 | 28 | def f(first): 29 | return [first] + [next(i) for i in iters] 30 | 31 | return map(f, n0) 32 | 33 | 34 | def map_many(f, *nests): 35 | n0, *nests = nests 36 | iters = [flatten(n) for n in nests] 37 | 38 | def g(first): 39 | return f([first] + [next(i) for i in iters]) 40 | 41 | return map(g, n0) 42 | -------------------------------------------------------------------------------- /examples/common/record.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | """'Logging' utilities.""" 3 | 4 | import csv 5 | import json 6 | import logging 7 | import os 8 | import time 9 | 10 | 11 | def log_to_file(_state={}, **fields): # noqa: B006 12 | """Incrementally write logs.tsv into pwd.""" 13 | if "writer" not in _state: 14 | path = "logs.tsv" # Could infer FLAGS if we had them. 15 | 16 | writeheader = not os.path.exists(path) 17 | fieldnames = list(fields.keys()) 18 | 19 | _state["file"] = open(path, "a", buffering=1) # Line buffering. 
20 | _state["writer"] = csv.DictWriter(_state["file"], fieldnames, delimiter="\t") 21 | if writeheader: 22 | _state["writer"].writeheader() 23 | logging.info("Writing logs to %s", path) 24 | else: 25 | logging.info("Appending logs to %s", path) 26 | 27 | writer = _state["writer"] 28 | if writer is not None: 29 | writer.writerow(fields) 30 | 31 | 32 | def symlink_path(target, symlink): 33 | try: 34 | if os.path.islink(symlink): 35 | os.remove(symlink) 36 | if not os.path.exists(symlink): 37 | os.symlink(target, symlink) 38 | return True 39 | except OSError: 40 | # os.remove() or os.symlink() raced. Don't do anything. 41 | pass 42 | return False 43 | 44 | 45 | def write_metadata(localdir, srcdir, **kwargs): 46 | """Write meta.json file with some information on our setup.""" 47 | if not localdir: 48 | return 49 | 50 | metadata = { 51 | "env": os.environ.copy(), 52 | "date_start": time.strftime("%Y-%m-%d %H:%M:%S"), 53 | } 54 | metadata.update(kwargs) 55 | 56 | try: 57 | import git 58 | except ImportError: 59 | logging.warning( 60 | "Couldn't import gitpython module; install it with `pip install gitpython`." 
61 | ) 62 | else: 63 | try: 64 | repo = git.Repo(path=srcdir, search_parent_directories=True) 65 | metadata["git"] = { 66 | "commit": repo.commit().hexsha, 67 | "is_dirty": repo.is_dirty(), 68 | "path": repo.git_dir, 69 | } 70 | if not repo.head.is_detached: 71 | metadata["git"]["branch"] = repo.active_branch.name 72 | except git.InvalidGitRepositoryError: 73 | pass 74 | 75 | if "git" not in metadata: 76 | logging.warning("Couldn't determine git data.") 77 | 78 | symlink = os.path.join(localdir, "meta.json") 79 | filename = "%s.%s.%d" % (symlink, time.strftime("%Y%m%d-%H%M%S"), os.getpid()) 80 | 81 | with open(filename, "w") as f: 82 | json.dump(metadata, f, indent=2, sort_keys=True) 83 | 84 | symlink_path(os.path.relpath(filename, start=localdir), symlink) 85 | -------------------------------------------------------------------------------- /examples/plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | """ 6 | Script for plotting results. 7 | 8 | ``` 9 | python plot.py logs.tsv 10 | ``` 11 | """ 12 | import argparse 13 | import glob 14 | import os 15 | 16 | import gnuplotlib as gp 17 | import numpy as np 18 | import pandas # Fast CSV reading. 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | 23 | parser.add_argument("--xkey", default="step", type=str, help="x values to plot.") 24 | parser.add_argument( 25 | "--ykey", default="episode_return", type=str, help="y values to plot." 26 | ) 27 | parser.add_argument("--window", default=50, type=int, help="Smoothing window size.") 28 | parser.add_argument("--width", default=80, type=int, help="Width of plot.") 29 | parser.add_argument("--height", default=30, type=int, help="Height of plot.") 30 | parser.add_argument( 31 | "--errorbars", action="store_true", help="Whether to print error bars." 
32 | ) 33 | parser.add_argument( 34 | "--smoothing", 35 | default="pandas", 36 | choices=["pandas", "convolve", "cumsum"], 37 | help="Smoothing algorithm.", 38 | ) 39 | parser.add_argument("files", nargs="*", type=str) 40 | 41 | 42 | def moving_average_cumsum(a, n=20): 43 | # Fast, but doesn't play well with NaNs 44 | ret = np.cumsum(a, dtype=float) 45 | ret[n:] = ret[n:] - ret[:-n] 46 | return ret[n - 1 :] / n 47 | 48 | 49 | def moving_average(a, n=20): 50 | return np.convolve(a, np.ones((n,)) / n, mode="valid") 51 | 52 | 53 | def rolling_xs_ys(xs, ys, window_size=20): 54 | """Alternative to rolling() in pandas.""" 55 | ma = moving_average_cumsum if FLAGS.smoothing == "cumsum" else moving_average 56 | return xs[window_size - 1 :], ma(ys, window_size) 57 | 58 | 59 | def plot(xys, xrange=None, yrange=None, color="green"): 60 | plot_options = dict( 61 | terminal="dumb %d %d ansi" % (FLAGS.width, FLAGS.height), 62 | title=FLAGS.ykey, 63 | xlabel=FLAGS.xkey, 64 | set=("key outside bottom center",), 65 | # _with="points linecolor '%s'" % color, 66 | ) 67 | 68 | if FLAGS.errorbars: 69 | plot_options["with"] = "yerrorbars" 70 | 71 | if xrange is not None: 72 | plot_options.update(xrange=xrange) 73 | 74 | if yrange is not None: 75 | plot_options.update(yrange=yrange) 76 | 77 | gp.plot(*xys, **plot_options) 78 | 79 | 80 | def load_file(filename): 81 | delimiters = {".tsv": "\t", ".csv": ","} 82 | _, ext = os.path.splitext(filename) 83 | 84 | if ext not in delimiters: 85 | raise RuntimeError("Filetype not recognised (expected .csv or .tsv): %s" % ext) 86 | 87 | df = pandas.read_csv(filename, sep=delimiters[ext]) 88 | 89 | xs = np.array(df[FLAGS.xkey]) 90 | 91 | if FLAGS.smoothing == "pandas": 92 | window = df[FLAGS.ykey].rolling(window=FLAGS.window, min_periods=0) 93 | ys = np.array(window.mean()) 94 | else: 95 | ys = np.array(df[FLAGS.ykey]) 96 | xs, ys = rolling_xs_ys(xs, ys, window_size=FLAGS.window) 97 | 98 | return (xs, ys, {"legend": filename}) 99 | 100 | 101 | def 
main(): 102 | xys = [] 103 | 104 | for pattern in FLAGS.files: 105 | for filename in glob.glob(pattern): 106 | xys.append(load_file(filename)) 107 | 108 | plot(xys) 109 | 110 | 111 | if __name__ == "__main__": 112 | global FLAGS 113 | FLAGS = parser.parse_intermixed_args() 114 | main() 115 | -------------------------------------------------------------------------------- /examples/requirements.txt: -------------------------------------------------------------------------------- 1 | gym[atari, accept-rom-license] 2 | opencv-python 3 | pygame # Additional requirement for cartpole. -------------------------------------------------------------------------------- /examples/sbatch_experiment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # Runs a single experiment (multiple moolib peers) using sbatch. 6 | # 7 | # Test with python -m scripts/sbatch_experiment --dry 8 | # 9 | # Run w/o --dry. 10 | # 11 | import argparse 12 | import ctypes 13 | import getpass 14 | import os 15 | import socket 16 | import sys 17 | 18 | import coolname 19 | import moolib 20 | 21 | DEFAULT_PORT = 4431 22 | 23 | parser = argparse.ArgumentParser( 24 | description="Training with slurm", 25 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 26 | ) 27 | parser.add_argument( 28 | "--project", 29 | default="moolib-atari", 30 | type=str, 31 | help="Project name.", 32 | ) 33 | parser.add_argument( 34 | "--group", 35 | default=coolname.generate_slug(2), 36 | type=str, 37 | help="Group name. 
Defaults to a coolname slug.", 38 | ) 39 | parser.add_argument("--dry", action="store_true") 40 | parser.add_argument( 41 | "-n", 42 | "--num_peers", 43 | default=1, 44 | type=int, 45 | metavar="N", 46 | help="Number of peers (jobs in array) in this experiment.", 47 | ) 48 | parser.add_argument( 49 | "--time", 50 | default=60 * 24, 51 | type=int, 52 | metavar="T", 53 | help="Maximum time this experiment runs (in min).", 54 | ) 55 | parser.add_argument( 56 | "--constraint", 57 | default="", # eg: "bldg2,volta32gb" 58 | type=str, 59 | metavar="constr", 60 | help="Matching this job constraint.", 61 | ) 62 | parser.add_argument( 63 | "--partition", 64 | default="learnlab,learnfair", 65 | type=str, 66 | metavar="part", 67 | help="Request a specific partition for the resource allocation.", 68 | ) 69 | parser.add_argument( 70 | "--cmd", 71 | default=( 72 | "python -m examples.appo.experiment connect=%(broker)s savedir=%(savedir)s" 73 | " project=%(project)s group=%(group)s" 74 | ), 75 | type=str, 76 | metavar="cmd", 77 | help="The command to run.", 78 | ) 79 | parser.add_argument( 80 | "--broker", 81 | default="", 82 | type=str, 83 | metavar="addr:port", 84 | help="The address of the broker.", 85 | ) 86 | parser.add_argument( 87 | "args", 88 | nargs="*", 89 | default=[], 90 | help="Extra arguments.", 91 | type=str, 92 | ) 93 | parser.add_argument("--no-checks", action="store_true", help="Don't run checks.") 94 | 95 | 96 | def get_address(address): 97 | if address: 98 | return address 99 | 100 | ssh_connection = os.getenv("SSH_CONNECTION") 101 | if ssh_connection: 102 | try: 103 | client_ip, client_port, host_ip, host_port = ssh_connection.split() 104 | if len(host_ip.split(".")) == 4: 105 | return "%s:%i" % (host_ip, DEFAULT_PORT) 106 | except ValueError: 107 | pass 108 | 109 | return "%s:%i" % (socket.gethostbyname(socket.gethostname()), DEFAULT_PORT) 110 | 111 | 112 | def check_nfs(nfs_super_magic=0x6969): 113 | try: 114 | # See statfs(2). 
115 | libc = ctypes.CDLL("libc.so.6") 116 | Statfs = ctypes.c_uint * 32 117 | buf = Statfs() 118 | ret = libc.statfs(".", buf) 119 | if ret != 0: 120 | return 121 | if buf[0] != nfs_super_magic: 122 | raise RuntimeError( 123 | "Must run from NFS directory, but cwd (%s) isn't (0x%x)" 124 | % (os.getcwd(), buf[0]), 125 | ) 126 | except OSError: 127 | pass 128 | 129 | 130 | def check_broker_online(address): 131 | rpc = moolib.Rpc() 132 | rpc.connect(address) 133 | rpc.set_timeout(2) 134 | try: 135 | rpc.sync("broker", "") 136 | except RuntimeError as e: 137 | # TODO: Add "ping" feature to moolib so we don't need to do _this_! 138 | if "timed out" in str(e): 139 | raise 140 | 141 | 142 | def check(address): 143 | if FLAGS.no_checks: 144 | return 145 | 146 | try: 147 | check_broker_online(address) 148 | except RuntimeError as e: 149 | print("Couldn't reach broker at %s. Is it online? (Error: %s)" % (address, e)) 150 | sys.exit(1) 151 | 152 | try: 153 | check_nfs() 154 | except RuntimeError as e: 155 | print(str(e)) 156 | sys.exit(2) 157 | 158 | 159 | def cmdlist(args): 160 | return ["sbatch"] + ["%s=%s" % item for item in args.items()] 161 | 162 | 163 | def main(): 164 | global FLAGS 165 | FLAGS = parser.parse_args() 166 | 167 | address = get_address(FLAGS.broker) 168 | check(address) 169 | 170 | savedir = os.path.join("/checkpoint", getpass.getuser(), FLAGS.project, FLAGS.group) 171 | 172 | try: 173 | os.makedirs(savedir) 174 | except FileExistsError: 175 | sys.stderr.write("Warning: Savedir path '%s' already exists\n" % savedir) 176 | 177 | if not os.access(savedir, os.W_OK | os.X_OK): 178 | sys.stderr.write("No write access to '%s'\n" % savedir) 179 | sys.exit(1) 180 | 181 | cmd = FLAGS.cmd % { 182 | "savedir": savedir, 183 | "broker": address, 184 | "project": FLAGS.project, 185 | "group": FLAGS.group, 186 | } 187 | 188 | slurm_output = os.path.join(savedir, "slurm-%A_%a.out") 189 | 190 | args = { 191 | "--constraint": FLAGS.constraint, 192 | "--job-name": "%s/%s" % 
(FLAGS.project, FLAGS.group), 193 | "--array": "0-%i" % (FLAGS.num_peers - 1), 194 | "--partition": FLAGS.partition, 195 | "--cpus-per-task": 10, 196 | "--gpus-per-task": 1, 197 | "--mem-per-cpu": "8G", 198 | "--time": FLAGS.time, 199 | "--ntasks": 1, 200 | "--output": slurm_output, 201 | "--error": slurm_output, 202 | "--export": "ALL", 203 | "--wrap": " ".join(cmd.split() + FLAGS.args), 204 | } 205 | 206 | execvlist = cmdlist(args) 207 | 208 | # Can't extra-escape strings for execvp, but want that for printing. 209 | args["--wrap"] = repr(args["--wrap"]) 210 | print(" ".join(cmdlist(args))) 211 | 212 | if FLAGS.dry: 213 | return 214 | 215 | os.execvp("sbatch", execvlist) 216 | 217 | 218 | if __name__ == "__main__": 219 | main() 220 | -------------------------------------------------------------------------------- /examples/vtrace/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/moolib/06e7a3e80c9f52729b4a6159f3fb4fc78986c98e/examples/vtrace/__init__.py -------------------------------------------------------------------------------- /examples/vtrace/config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | hydra: 16 | job_logging: 17 | formatters: 18 | simple: 19 | format: ${log_fmt} 20 | run: 21 | dir: "${localdir}" 22 | 23 | actor_batch_size: 128 24 | baseline_cost: 0.5 25 | batch_size: 32 26 | connect: 127.0.0.1:4431 27 | device: cuda:0 28 | discounting: 0.99 29 | entity: null 30 | entropy_cost: 0.0006 31 | env: 32 | name: "ALE/Pong-v5" # See https://brosa.ca/blog/ale-release-v0.7 33 | repeat_action_probability: 0.0 # Sticky action probability 34 | num_action_repeats: 4 35 | noop_max: 30 36 | fixup_init: true 37 | grad_norm_clipping: 40 38 | group: group 39 | local_name: "${uid:}" 40 | log_fmt: "[%(levelname)s:${local_name} %(module)s:%(lineno)d %(asctime)s] %(message)s" 41 | log_interval: 10 42 | checkpoint_interval: 600 43 | checkpoint_history_interval: 3600 44 | num_actor_batches: 2 45 | num_actor_cpus: 10 46 | optimizer: 47 | learning_rate: 0.0006 48 | beta_1: 0.9 # PyTorch default: 0.9 49 | beta_2: 0.999 # PyTorch default: 0.999 50 | epsilon: 1e-8 # PyTorch default: 1e-08 51 | project: project 52 | # Savedir is used for storing the checkpoint(s), 53 | # including flags and any global settings/stats for the training 54 | # localdir (which is a subdirectory of savedir) should be used 55 | # for storing logs and anything local to each instance 56 | savedir: "/checkpoint/${oc.env:USER}/hackrl/${project}/${group}" 57 | localdir: "${savedir}/peers/${local_name}" 58 | state_counter: none 59 | total_steps: 50e6 # 50M agent steps = 200M env frames w/ frame skipping (num_action_repeats: 4). 60 | unroll_length: 20 61 | use_lstm: false 62 | virtual_batch_size: 32 63 | reward_clip: 1.0 64 | wandb: true 65 | warmup: 0 66 | -------------------------------------------------------------------------------- /py/moolib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | from ._C import ( 3 | Accumulator, 4 | AllReduce, 5 | Batcher, 6 | Broker, 7 | EnvPool, 8 | EnvRunner, 9 | EnvStepper, 10 | EnvStepperFuture, 11 | Future, 12 | Group, 13 | Queue, 14 | Rpc, 15 | RpcDeferredReturn, 16 | RpcError, 17 | __doc__, 18 | create_uid, 19 | set_log_level, 20 | set_logging, 21 | set_max_threads, 22 | ) 23 | 24 | 25 | __all__ = [ 26 | "Accumulator", 27 | "AllReduce", 28 | "Batcher", 29 | "Broker", 30 | "EnvPool", 31 | "EnvRunner", 32 | "EnvStepper", 33 | "EnvStepperFuture", 34 | "Future", 35 | "Group", 36 | "Queue", 37 | "Rpc", 38 | "RpcDeferredReturn", 39 | "RpcError", 40 | "__doc__", 41 | "create_uid", 42 | "set_log_level", 43 | "set_logging", 44 | "set_max_threads", 45 | ] 46 | -------------------------------------------------------------------------------- /py/moolib/broker.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import moolib 3 | import argparse 4 | import time 5 | 6 | # TODO: Can moolib choose a port for us? 
7 | DEFAULT_PORT = 4431 8 | 9 | parser = argparse.ArgumentParser(description="A script to run a moolib broker") 10 | 11 | parser.add_argument( 12 | "address", 13 | nargs="?", 14 | default="0.0.0.0:%i" % DEFAULT_PORT, 15 | type=str, 16 | metavar="addr:port", 17 | help="Broker server address to listen on.", 18 | ) 19 | 20 | 21 | def main(): 22 | FLAGS = parser.parse_args() 23 | 24 | broker_rpc = moolib.Rpc() 25 | broker_rpc.set_name("broker") 26 | broker = moolib.Broker(broker_rpc) 27 | broker_rpc.listen(FLAGS.address) 28 | 29 | print("Broker listening at %s" % FLAGS.address) 30 | 31 | try: 32 | while True: 33 | broker.update() 34 | time.sleep(0.25) 35 | except KeyboardInterrupt: 36 | pass 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /py/moolib/examples: -------------------------------------------------------------------------------- 1 | ../../examples -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 88 3 | include = '\.pyi?$' 4 | exclude = ''' 5 | ( 6 | /( 7 | \.eggs 8 | | \.git 9 | | src\/backward-cpp 10 | | src\/fmt 11 | | src\/pybind11 12 | | src\/tensorpipe 13 | )/ 14 | | atari_preprocessing\.py 15 | ) 16 | ''' 17 | [tool.isort] 18 | force_single_line = true 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # To install: pip install . 6 | # 7 | # For debug builds: python setup.py build --debug install 8 | # 9 | # The environment variable USE_CUDA can be set to "OFF" (or 0). 
10 | # 11 | 12 | import os 13 | import pathlib 14 | import subprocess 15 | import sys 16 | 17 | import setuptools 18 | from setuptools.command import build_ext 19 | from distutils import spawn 20 | 21 | 22 | class CMakeBuild(build_ext.build_ext): 23 | def run(self): # Necessary for pip install -e. 24 | for ext in self.extensions: 25 | self.build_extension(ext) 26 | 27 | def build_extension(self, ext): 28 | source_path = pathlib.Path(__file__).parent.resolve() 29 | output_path = pathlib.Path(self.get_ext_fullpath(ext.name)).parent.absolute() 30 | 31 | os.makedirs(self.build_temp, exist_ok=True) 32 | 33 | build_type = "Debug" if self.debug else "RelWithDebInfo" 34 | 35 | generator = "Ninja" if spawn.find_executable("ninja") else "Unix Makefiles" 36 | 37 | cmake_cmd = [ 38 | "cmake", 39 | str(source_path), 40 | "-G%s" % generator, 41 | "-DCMAKE_BUILD_TYPE=%s" % build_type, 42 | "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=%s" % output_path, 43 | ] 44 | 45 | # USE_CUDA defaults to on; accept "OFF"/"off"/"0"/"false"/"no"/"" to disable. 46 | # (Previously only the exact string "OFF" was handled and int("off")/int("") raised ValueError.) 47 | use_cuda = os.environ.get("USE_CUDA", True) 48 | if isinstance(use_cuda, str): 49 | use_cuda = use_cuda.strip().upper() not in ("OFF", "0", "FALSE", "NO", "") 50 | if not use_cuda: 51 | cmake_cmd.append("-DUSE_CUDA=OFF") 52 | 53 | build_cmd = ["cmake", "--build", ".", "--parallel"] 54 | 55 | # pip install (but not python setup.py install) runs with a modified PYTHONPATH. 56 | # This can prevent cmake from finding the torch libraries. 57 | env = os.environ.copy() 58 | if "PYTHONPATH" in env: 59 | del env["PYTHONPATH"] 60 | try: 61 | subprocess.check_call(cmake_cmd, cwd=self.build_temp, env=env) 62 | subprocess.check_call(build_cmd, cwd=self.build_temp, env=env) 63 | except subprocess.CalledProcessError: 64 | # Don't obscure the error with a setuptools backtrace.
63 | sys.exit(1) 64 | 65 | 66 | def main(): 67 | with open("README.md") as f: 68 | long_description = f.read() 69 | 70 | setuptools.setup( 71 | name="moolib", 72 | version="0.2.0", 73 | description=("A library for distributed ML training with PyTorch"), 74 | long_description=long_description, 75 | long_description_content_type="text/markdown", 76 | author="tscmoo & the moolib dev team", 77 | url="https://github.com/facebookresearch/moolib", 78 | classifiers=[ 79 | "Programming Language :: C++", 80 | "Programming Language :: Python :: 3", 81 | "License :: OSI Approved :: MIT License", 82 | "Operating System :: POSIX :: Linux", 83 | "Operating System :: MacOS :: MacOS X", 84 | "Environment :: GPU :: NVIDIA CUDA", 85 | ], 86 | packages=["moolib", "moolib.examples.common", "moolib.examples.vtrace"], 87 | package_dir={"": "py", "moolib.examples": "examples"}, 88 | ext_modules=[setuptools.Extension("moolib._C", sources=[])], 89 | install_requires=["torch>=1.6.0"], 90 | cmdclass={"build_ext": CMakeBuild}, 91 | zip_safe=False, 92 | ) 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /src/accumulator.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #pragma once 9 | 10 | #include 11 | #include 12 | 13 | #include "pybind11/pybind11.h" 14 | 15 | #include "rpc.h" 16 | #include "util.h" 17 | 18 | namespace moolib { 19 | 20 | namespace py = pybind11; 21 | 22 | struct Group; 23 | 24 | struct AccumulatorImpl; 25 | 26 | struct GradientStats { 27 | int numGradients = 0; 28 | int numSkipped = 0; 29 | int batchSize = 0; 30 | }; 31 | 32 | struct Accumulator { 33 | 34 | std::unique_ptr impl; 35 | 36 | Accumulator(std::string name, py::object parameters, py::object buffers, const Group* group = nullptr); 37 | ~Accumulator(); 38 | 39 | void update(); 40 | 41 | void connect(std::string address); 42 | 43 | bool connected(); 44 | bool wantsState(); 45 | bool hasNewState(); 46 | bool hasGradients(); 47 | bool wantsGradients(); 48 | void setState(py::object userState); 49 | py::object state(); 50 | 51 | void skipGradients(); 52 | void reduceGradients(int batchSize); 53 | void zeroGradients(); 54 | 55 | int64_t modelVersion() const; 56 | void setModelVersion(int64_t n); 57 | py::dict getGradientStats() const; 58 | 59 | void setVirtualBatchSize(int n); 60 | void setParallelGradients(int n); 61 | 62 | std::string getLeader(); 63 | bool isLeader(); 64 | }; 65 | 66 | } // namespace moolib 67 | -------------------------------------------------------------------------------- /src/any.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #pragma once 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | namespace moolib { 19 | 20 | template 21 | struct Any { 22 | std::aligned_storage_t buf; 23 | void (*dtor)(Any*) = nullptr; 24 | template 25 | constexpr bool embed() const noexcept { 26 | // static_assert(sizeof(T) <= embeddedSize); 27 | return sizeof(T) <= embeddedSize; 28 | } 29 | Any() = default; 30 | Any(const Any&) = delete; 31 | Any(Any&&) = delete; 32 | ~Any() { 33 | if (dtor) { 34 | dtor(this); 35 | } 36 | } 37 | Any& operator=(const Any&) = delete; 38 | Any& operator=(Any&&) = delete; 39 | operator bool() const noexcept { 40 | return dtor != nullptr; 41 | } 42 | template 43 | T* pointer() const noexcept { 44 | return embed() ? (T*)&buf : (T*&)buf; 45 | } 46 | template 47 | T& as() noexcept { 48 | return *pointer(); 49 | } 50 | template 51 | const T& as() const noexcept { 52 | return *pointer(); 53 | } 54 | template 55 | T& emplace(Args&&... args) { 56 | if (dtor) { 57 | dtor(this); 58 | } 59 | T* p; 60 | if (embed()) { 61 | p = pointer(); 62 | new (p) T(std::forward(args)...); 63 | dtor = [](Any* me) { 64 | me->as().~T(); 65 | me->dtor = nullptr; 66 | }; 67 | } else { 68 | p = new T(std::forward(args)...); 69 | new ((T**)&buf) T*(p); 70 | dtor = [](Any* me) { 71 | delete me->pointer(); 72 | me->dtor = nullptr; 73 | }; 74 | } 75 | return *p; 76 | } 77 | template 78 | void destroy() { 79 | if (dtor) { 80 | if (embed()) { 81 | as().~T(); 82 | } else { 83 | delete pointer(); 84 | } 85 | dtor = nullptr; 86 | } 87 | } 88 | }; 89 | 90 | } // namespace moolib 91 | -------------------------------------------------------------------------------- /src/async.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 
3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #pragma once 9 | 10 | #include "function.h" 11 | 12 | #include 13 | #include 14 | 15 | namespace async { 16 | 17 | void setCurrentThreadName(const std::string& name); 18 | 19 | template 20 | using Function = rpc::Function; 21 | using FunctionPointer = rpc::FunctionPointer; 22 | 23 | struct SchedulerFifoImpl; 24 | 25 | struct SchedulerFifo { 26 | 27 | std::unique_ptr impl_; 28 | 29 | SchedulerFifo(); 30 | SchedulerFifo(size_t nThreads); 31 | ~SchedulerFifo(); 32 | 33 | void run(Function f) noexcept; 34 | void setMaxThreads(size_t nThreads); 35 | bool isInThread() const noexcept; 36 | }; 37 | 38 | void stopForksFromHereOn(); 39 | 40 | } // namespace async 41 | -------------------------------------------------------------------------------- /src/batch_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #pragma once 9 | 10 | #include 11 | 12 | #include 13 | 14 | namespace moolib { 15 | namespace utils { 16 | 17 | namespace py = pybind11; 18 | 19 | // TODO: Merge these functions with Batcher. 
20 | 21 | py::object squeezeFields(const py::handle& input, int64_t dim); 22 | py::object unsqueezeFields(const py::handle& input, int64_t dim); 23 | 24 | py::object stackFields(const py::tuple& input, int64_t dim); 25 | py::tuple unstackFields(const py::handle& input, int64_t batchSize, int64_t dim); 26 | 27 | } // namespace utils 28 | } // namespace moolib 29 | -------------------------------------------------------------------------------- /src/batchsizefinder.h: -------------------------------------------------------------------------------- 1 | 2 | /* 3 | * Copyright (c) Facebook, Inc. and its affiliates. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #include 23 | 24 | namespace batchsizefinder { 25 | 26 | struct Timer { 27 | std::chrono::steady_clock::time_point start; 28 | Timer() { 29 | reset(); 30 | } 31 | void reset() { 32 | start = std::chrono::steady_clock::now(); 33 | } 34 | float elapsedAt(std::chrono::steady_clock::time_point now) { 35 | return std::chrono::duration_cast>>(now - start).count(); 36 | } 37 | float elapsed() { 38 | return elapsedAt(std::chrono::steady_clock::now()); 39 | } 40 | float elapsedReset() { 41 | auto now = std::chrono::steady_clock::now(); 42 | float r = elapsedAt(now); 43 | start = now; 44 | return r; 45 | } 46 | }; 47 | 48 | float defaultScore(float latency, int bs) { 49 | return latency / 400 - std::log(bs / latency); 50 | } 51 | 52 | template 53 | int find( 54 | torch::Device device, Prepare&& prepare, Forward&& forward, int minBatchSize, int maxBatchsize, float maxTimeMs, 55 | Score&& scoreFunction) { 56 | torch::NoGradGuard ng; 57 | bool isCuda = device.is_cuda(); 58 | std::optional g; 59 | if (isCuda) { 60 | 
g.emplace(c10::cuda::getStreamFromPool(false, device.index())); 61 | } else { 62 | // throw std::runtime_error("findBatchSize on non-cuda device is not meaningful"); 63 | } 64 | auto input = prepare(1); 65 | auto call = [&]() { 66 | forward(input); 67 | if (isCuda) { 68 | g->current_stream().synchronize(); 69 | } 70 | }; 71 | fmt::printf("Finding batch size\n"); 72 | // warm up 73 | for (int i = 0; i != 10; ++i) { 74 | call(); 75 | } 76 | Timer t; 77 | for (int i = 0; i != 10; ++i) { 78 | call(); 79 | } 80 | float call1 = t.elapsed() / 10.0f * 1000.0f; 81 | fmt::printf("Base latency: %gms\n", call1); 82 | 83 | float maxms = maxTimeMs; 84 | int maxbs = maxBatchsize; 85 | 86 | struct I { 87 | float latency = 0.0f; 88 | int size = 0; 89 | int n = 0; 90 | bool isBad = false; 91 | }; 92 | 93 | auto scorex = [&](auto& x) { return scoreFunction(x.latency / x.n, x.size); }; 94 | 95 | std::map li; 96 | 97 | int best = 0; 98 | float bestScore = std::numeric_limits::infinity(); 99 | 100 | auto eval = [&](int i) { 101 | input = prepare(i); 102 | int badcount = 0; 103 | float latency = 0.0f; 104 | int n = 2; 105 | for (int j = 0; j != n; ++j) { 106 | call(); 107 | } 108 | for (int j = 0; j != n; ++j) { 109 | t.reset(); 110 | call(); 111 | float ms = t.elapsed() * 1000; 112 | latency += ms; 113 | if (ms > maxms || i > maxbs || i < minBatchSize) { 114 | ++badcount; 115 | } 116 | } 117 | auto& x = li[i]; 118 | x.size = i; 119 | x.latency += latency; 120 | x.n += n; 121 | x.isBad = badcount >= n; 122 | float score = scorex(x); 123 | if (!x.isBad && score < bestScore) { 124 | bestScore = score; 125 | best = i; 126 | } 127 | return badcount < n; 128 | }; 129 | 130 | for (int i = std::max(minBatchSize, 1);; i += (i + 3) / 4) { 131 | if (!eval(i)) { 132 | break; 133 | } 134 | } 135 | std::minstd_rand rng(std::random_device{}()); 136 | 137 | auto expandNear = [&](int k) { 138 | int r = 0; 139 | auto i = li.find(k); 140 | if (i != li.end()) { 141 | auto search = [&](auto begin, auto end) 
{ 142 | int b = begin->first; 143 | int e; 144 | if (end == li.end()) { 145 | e = std::prev(end)->first; 146 | } else { 147 | e = end->first; 148 | } 149 | b = std::max(b, i->first - 3); 150 | e = std::max(b, i->first + 6); 151 | for (int i = b; i != e; ++i) { 152 | if (li.find(i) != li.end()) { 153 | continue; 154 | } 155 | ++r; 156 | if (!eval(i)) { 157 | break; 158 | } 159 | } 160 | }; 161 | search(i, std::next(i)); 162 | if (i != li.begin()) { 163 | search(std::prev(i), i); 164 | } 165 | } 166 | return r; 167 | }; 168 | 169 | for (int j = 0; j != 4; ++j) { 170 | int expands = 12; 171 | for (int k = 0; k != 12; ++k) { 172 | float sum = 0.0f; 173 | std::vector> list; 174 | float minweight = std::numeric_limits::infinity(); 175 | for (auto& [k, v] : li) { 176 | if (!v.isBad) { 177 | minweight = std::min(minweight, scorex(v)); 178 | } 179 | } 180 | for (auto i = li.begin();;) { 181 | auto next = std::next(i); 182 | if (next == li.end()) { 183 | break; 184 | } 185 | if (i->second.isBad && next->second.isBad) { 186 | i = next; 187 | continue; 188 | } 189 | int from = i->first + 1; 190 | int to = next->first; 191 | if (to - from > 0) { 192 | float weight = std::min(scorex(i->second), scorex(next->second)) - minweight; 193 | weight = 1.0f / std::min(std::exp(weight * 4), 1e9f); 194 | weight *= to - from; 195 | list.emplace_back(weight, from, to); 196 | sum += weight; 197 | } 198 | i = next; 199 | } 200 | if (list.size() > 0 && sum > 0.0f) { 201 | float val = std::uniform_real_distribution(0.0f, sum)(rng); 202 | for (auto& [weight, from, to] : list) { 203 | val -= weight; 204 | if (val <= 0) { 205 | int k = std::uniform_int_distribution(from, to - 1)(rng); 206 | eval(k); 207 | if (expands > 0) { 208 | expands -= expandNear(k); 209 | } 210 | break; 211 | } 212 | } 213 | } 214 | } 215 | if (best) { 216 | expandNear(best); 217 | } 218 | std::vector> sorted; 219 | for (auto& [k, v] : li) { 220 | if (!v.isBad) { 221 | sorted.emplace_back(scorex(v), k); 222 | } 223 | } 224 | 
std::sort(sorted.begin(), sorted.end()); 225 | for (size_t i = 0; i != sorted.size() && i < 10; ++i) { 226 | int k = std::get<1>(sorted[i]); 227 | if (li[k].n < 8) { 228 | eval(k); 229 | } 230 | } 231 | } 232 | 233 | for (auto& [k, v] : li) { 234 | fmt::printf( 235 | "Batch size %d, evals %d latency %fms throughput %g score %g\n", k, v.n, v.latency / v.n, 236 | v.size / (v.latency / v.n), scorex(v)); 237 | } 238 | 239 | fmt::printf( 240 | "Found best batch size of %d with evals %d latency %fms " 241 | "throughput %g score %g\n", 242 | best, li[best].n, li[best].latency / li[best].n, li[best].size / (li[best].latency / li[best].n), 243 | scorex(li[best])); 244 | return best; 245 | } 246 | 247 | template 248 | int find( 249 | torch::Device device, Prepare&& prepare, Forward&& forward, int minBatchSize, int maxBatchsize, float maxTimeMs) { 250 | return find( 251 | device, std::forward(prepare), std::forward(forward), minBatchSize, maxBatchsize, maxTimeMs, 252 | defaultScore); 253 | } 254 | 255 | } // namespace batchsizefinder 256 | -------------------------------------------------------------------------------- /src/broker.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #pragma once 9 | 10 | #include "util.h" 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | namespace moolib { 20 | 21 | struct BrokerService { 22 | 23 | struct Peer { 24 | std::string name; 25 | std::chrono::steady_clock::time_point lastPing; 26 | std::chrono::steady_clock::duration timeoutDuration; 27 | std::optional>> syncFuture; 28 | std::optional> updateFuture; 29 | int32_t sortOrder = 0; 30 | bool active = false; 31 | size_t creationOrder = 0; 32 | }; 33 | 34 | struct Group { 35 | std::mutex mutex; 36 | std::string name; 37 | std::unordered_map peers; 38 | bool needsUpdate = false; 39 | std::chrono::steady_clock::time_point lastUpdate; 40 | uint32_t syncId = 1; 41 | size_t updateCount = 0; 42 | size_t orderCounter = 0; 43 | 44 | std::vector active; 45 | 46 | Peer& getPeer(std::string name) { 47 | auto i = peers.try_emplace(name); 48 | if (i.second) { 49 | auto& p = i.first->second; 50 | p.name = name; 51 | p.creationOrder = orderCounter++; 52 | } 53 | return i.first->second; 54 | } 55 | }; 56 | 57 | std::mutex groupsMutex; 58 | std::unordered_map groups; 59 | 60 | Group& getGroup(const std::string& name) { 61 | std::lock_guard l(groupsMutex); 62 | auto i = groups.try_emplace(name); 63 | if (i.second) { 64 | auto& g = i.first->second; 65 | g.name = name; 66 | } 67 | return i.first->second; 68 | } 69 | 70 | std::unordered_set syncSet; 71 | std::chrono::steady_clock::time_point lastCheckTimeouts; 72 | 73 | uint32_t nextSyncId = random(); 74 | 75 | std::vector tmpGroups; 76 | std::vector tmpPeers; 77 | 78 | rpc::Rpc* rpc = nullptr; 79 | 80 | BrokerService(rpc::Rpc& rpc) : rpc(&rpc) { 81 | setup(); 82 | } 83 | ~BrokerService() { 84 | close(); 85 | } 86 | 87 | template 88 | Future call(std::string_view peerName, std::string_view funcName, Args&&... 
args) { 89 | return callImpl(*rpc, peerName, funcName, std::forward(args)...); 90 | } 91 | 92 | void close() { 93 | rpc->undefine("BrokerService::groupSize"); 94 | rpc->undefine("BrokerService::resync"); 95 | } 96 | 97 | void setup() { 98 | 99 | rpc->define("BrokerService::groupSize", [this](std::string group) { 100 | log.info("groupSize called!\n"); 101 | auto& g = getGroup(group); 102 | std::lock_guard l(g.mutex); 103 | return g.peers.size(); 104 | }); 105 | 106 | rpc->define( 107 | "BrokerService::ping", [this](std::string group, std::string name, uint32_t timeoutMilliseconds) { 108 | auto& g = getGroup(group); 109 | std::lock_guard l(g.mutex); 110 | auto& p = g.getPeer(name); 111 | p.lastPing = std::chrono::steady_clock::now(); 112 | p.timeoutDuration = std::chrono::milliseconds(timeoutMilliseconds); 113 | if (!p.active) { 114 | g.needsUpdate = true; 115 | } 116 | // log("got ping for %s::%s\n", group, name); 117 | return g.syncId; 118 | }); 119 | 120 | rpc->define("BrokerService::resync", [this](std::string group) { 121 | auto& g = getGroup(group); 122 | std::lock_guard l(g.mutex); 123 | if (!g.needsUpdate) { 124 | log.info("Got resync request for %s\n", group); 125 | g.needsUpdate = true; 126 | } 127 | }); 128 | } 129 | 130 | void update() { 131 | 132 | auto now = std::chrono::steady_clock::now(); 133 | 134 | if (!syncSet.empty()) { 135 | 136 | for (auto i = syncSet.begin(); i != syncSet.end();) { 137 | auto& g = **i; 138 | std::lock_guard l(g.mutex); 139 | size_t total = 0; 140 | size_t ready = 0; 141 | for ([[maybe_unused]] auto& [pname, p] : g.peers) { 142 | if (p.syncFuture) { 143 | ++total; 144 | if (*p.syncFuture) { 145 | if ((*p.syncFuture)->first == g.syncId) { 146 | ++ready; 147 | } else { 148 | log.info("bad sync id?? 
got %#x expected %#x", (*p.syncFuture)->first, g.syncId); 149 | --total; 150 | } 151 | } 152 | } 153 | } 154 | // log("Sync midway %s %d/%d in %gs\n", g.name, ready, total, seconds(now - g.lastUpdate)); 155 | if (ready >= total || now - g.lastUpdate >= std::chrono::seconds(1)) { 156 | log.info("Sync %s %d/%d in %gs\n", g.name, ready, total, seconds(now - g.lastUpdate)); 157 | 158 | tmpPeers.clear(); 159 | for ([[maybe_unused]] auto& [pname, p] : g.peers) { 160 | if (p.syncFuture && *p.syncFuture && (*p.syncFuture)->first == g.syncId) { 161 | p.sortOrder = (*p.syncFuture)->second; 162 | tmpPeers.push_back(&p); 163 | p.active = true; 164 | } else { 165 | p.active = false; 166 | } 167 | } 168 | std::sort(tmpPeers.begin(), tmpPeers.end(), [](Peer* a, Peer* b) { 169 | if (a->sortOrder == b->sortOrder) { 170 | return a->creationOrder < b->creationOrder; 171 | } 172 | return a->sortOrder < b->sortOrder; 173 | }); 174 | g.active.clear(); 175 | for (auto* p : tmpPeers) { 176 | log.info("%s with sort order %d\n", p->name, p->sortOrder); 177 | g.active.push_back(p->name); 178 | } 179 | if (!g.active.empty()) { 180 | log.info("%s is the master\n", g.active.front()); 181 | } 182 | for (auto* p : tmpPeers) { 183 | p->updateFuture = call(p->name, "GroupService::update", g.name, g.syncId, g.active); 184 | } 185 | 186 | i = syncSet.erase(i); 187 | } else { 188 | ++i; 189 | } 190 | } 191 | } 192 | 193 | if (now - lastCheckTimeouts < std::chrono::milliseconds(500)) { 194 | return; 195 | } 196 | 197 | lastCheckTimeouts = now; 198 | tmpGroups.clear(); 199 | { 200 | std::lock_guard l(groupsMutex); 201 | for (auto& [gname, g] : groups) { 202 | tmpGroups.push_back(&g); 203 | } 204 | } 205 | for (auto* pg : tmpGroups) { 206 | auto& g = *pg; 207 | std::lock_guard l2(g.mutex); 208 | for (auto i = g.peers.begin(); i != g.peers.end();) { 209 | auto& p = i->second; 210 | if (now - p.lastPing >= p.timeoutDuration) { 211 | log.info("Peer %s::%s timed out\n", g.name, p.name); 212 | if (p.active) { 
213 | g.needsUpdate = true; 214 | } 215 | i = g.peers.erase(i); 216 | } else { 217 | ++i; 218 | } 219 | } 220 | auto mintime = std::chrono::seconds(2); 221 | if (g.needsUpdate && (now - g.lastUpdate >= mintime)) { 222 | log.info("Initiating update of group %s\n", g.name); 223 | ++g.updateCount; 224 | g.lastUpdate = now; 225 | g.needsUpdate = false; 226 | uint32_t syncId = nextSyncId++; 227 | if (syncId == 0) { 228 | syncId = nextSyncId++; 229 | } 230 | g.syncId = syncId; 231 | for ([[maybe_unused]] auto& [pname, p] : g.peers) { 232 | p.syncFuture = call>(pname, "GroupService::sync", g.name, syncId); 233 | } 234 | syncSet.insert(&g); 235 | } 236 | } 237 | } 238 | }; 239 | 240 | struct Broker { 241 | 242 | std::shared_ptr rpc; 243 | BrokerService* brokerService = nullptr; 244 | 245 | Broker(std::shared_ptr rpc) : rpc(std::move(rpc)) { 246 | brokerService = this->rpc->getService("BrokerService"); 247 | } 248 | 249 | Broker() : rpc(std::make_shared()) { 250 | rpc->setName("broker"); 251 | brokerService = rpc->getService("BrokerService"); 252 | } 253 | 254 | void setName(std::string name) { 255 | rpc->setName(name); 256 | } 257 | 258 | void listen(std::string address) { 259 | rpc->listen(address); 260 | } 261 | 262 | void update() { 263 | brokerService->update(); 264 | } 265 | }; 266 | 267 | } // namespace moolib 268 | -------------------------------------------------------------------------------- /src/group.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #include "group.h" 9 | 10 | namespace moolib { 11 | 12 | struct GroupImpl { 13 | 14 | void connect(std::string brokerAddress, std::string groupName); 15 | void ping(); 16 | bool connected(); 17 | std::vector members(); 18 | }; 19 | 20 | Group::Group() { 21 | impl = std::make_unique(); 22 | } 23 | 24 | Group::~Group() {} 25 | 26 | void Group::connect(std::string brokerAddress, std::string groupName) { 27 | impl->connect(brokerAddress, groupName); 28 | } 29 | 30 | void Group::ping() { 31 | impl->ping(); 32 | } 33 | 34 | bool Group::connected() { 35 | return impl->connected(); 36 | } 37 | 38 | std::vector Group::members() { 39 | return impl->members(); 40 | } 41 | 42 | } // namespace moolib 43 | -------------------------------------------------------------------------------- /src/intrusive_list.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #pragma once 9 | 10 | #include 11 | #include 12 | 13 | namespace moolib { 14 | 15 | template 16 | struct IntrusiveListLink { 17 | T* prev = nullptr; 18 | T* next = nullptr; 19 | }; 20 | 21 | template T::*link> 22 | struct IntrusiveList { 23 | private: 24 | T head; 25 | static T*& next(T* at) noexcept { 26 | return (at->*link).next; 27 | } 28 | static T*& prev(T* at) noexcept { 29 | return (at->*link).prev; 30 | } 31 | 32 | public: 33 | using value_type = T; 34 | using reference = T&; 35 | using const_reference = const T&; 36 | using difference_type = std::ptrdiff_t; 37 | using size_type = std::size_t; 38 | 39 | struct iterator { 40 | private: 41 | T* ptr = nullptr; 42 | 43 | public: 44 | iterator() = default; 45 | iterator(T* ptr) : ptr(ptr) {} 46 | 47 | using difference_type = std::ptrdiff_t; 48 | using value_type = T; 49 | using pointer = T*; 50 | using reference = T&; 51 | using iterator_category = std::bidirectional_iterator_tag; 52 | 53 | T& operator*() const noexcept { 54 | return *ptr; 55 | } 56 | T* operator->() const noexcept { 57 | return ptr; 58 | } 59 | iterator& operator++() noexcept { 60 | ptr = next(ptr); 61 | return *this; 62 | } 63 | iterator operator++(int) noexcept { 64 | iterator r = (*this); 65 | ptr = next(ptr); 66 | return r; 67 | } 68 | iterator& operator--() noexcept { 69 | ptr = prev(ptr); 70 | return *this; 71 | } 72 | iterator operator--(int) noexcept { 73 | iterator r = (*this); 74 | ptr = prev(ptr); 75 | return r; 76 | } 77 | bool operator==(iterator n) const noexcept { 78 | return ptr == n.ptr; 79 | } 80 | bool operator!=(iterator n) const noexcept { 81 | return ptr != n.ptr; 82 | } 83 | }; 84 | 85 | IntrusiveList() noexcept { 86 | prev(&head) = &head; 87 | next(&head) = &head; 88 | } 89 | 90 | iterator begin() noexcept { 91 | return iterator(next(&head)); 92 | } 93 | iterator end() noexcept { 94 | return iterator(&head); 95 | } 96 | size_t size() = delete; 97 | constexpr size_t max_size() = delete; 98 | bool empty() 
const noexcept { 99 | return next((T*)&head) == &head; 100 | } 101 | 102 | void clear() noexcept { 103 | prev(&head) = &head; 104 | next(&head) = &head; 105 | } 106 | iterator insert(iterator at, T& item) noexcept { 107 | T* nextItem = &*at; 108 | T* prevItem = prev(&*at); 109 | prev(nextItem) = &item; 110 | next(prevItem) = &item; 111 | next(&item) = nextItem; 112 | prev(&item) = prevItem; 113 | return at; 114 | } 115 | static iterator erase(iterator at) noexcept { 116 | T* nextItem = next(&*at); 117 | T* prevItem = prev(&*at); 118 | prev(nextItem) = prevItem; 119 | next(prevItem) = nextItem; 120 | prev(&*at) = nullptr; 121 | next(&*at) = nullptr; 122 | return at; 123 | } 124 | static void erase(T& item) noexcept { 125 | erase(iterator(&item)); 126 | } 127 | iterator push_front(T& item) noexcept { 128 | return insert(begin(), item); 129 | } 130 | iterator push_back(T& item) noexcept { 131 | return insert(end(), item); 132 | } 133 | void pop_front() noexcept { 134 | erase(begin()); 135 | } 136 | void pop_back() noexcept { 137 | erase(prev(&head)); 138 | } 139 | T& front() noexcept { 140 | return *next(&head); 141 | } 142 | T& back() noexcept { 143 | return *prev(&head); 144 | } 145 | }; 146 | 147 | } // namespace moolib 148 | -------------------------------------------------------------------------------- /src/logging.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #pragma once 9 | 10 | #include "fmt/printf.h" 11 | #include "pybind11/pybind11.h" 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | namespace moolib { 20 | 21 | namespace py = pybind11; 22 | 23 | inline py::object pyLogging; 24 | 25 | inline std::mutex logMutex; 26 | 27 | enum class LogLevel { 28 | LOG_NONE, 29 | LOG_ERROR, 30 | LOG_INFO, 31 | LOG_VERBOSE, 32 | LOG_DEBUG, 33 | }; 34 | 35 | constexpr auto LOG_NONE = LogLevel::LOG_NONE; 36 | constexpr auto LOG_ERROR = LogLevel::LOG_ERROR; 37 | constexpr auto LOG_INFO = LogLevel::LOG_INFO; 38 | constexpr auto LOG_VERBOSE = LogLevel::LOG_VERBOSE; 39 | constexpr auto LOG_DEBUG = LogLevel::LOG_DEBUG; 40 | 41 | inline LogLevel currentLogLevel = LOG_ERROR; 42 | 43 | template 44 | void logat(LogLevel level, const char* fmt, Args&&... args) { 45 | if (level > currentLogLevel) { 46 | return; 47 | } 48 | if (!pyLogging || pyLogging.is_none()) { 49 | std::lock_guard l(logMutex); 50 | time_t now = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); 51 | auto* tm = std::localtime(&now); 52 | char buf[0x40]; 53 | std::strftime(buf, sizeof(buf), "%d-%m-%Y %H:%M:%S", tm); 54 | auto s = fmt::sprintf(fmt, std::forward(args)...); 55 | if (!s.empty() && s.back() == '\n') { 56 | fmt::printf("%s: %s", buf, s); 57 | } else { 58 | fmt::printf("%s: %s\n", buf, s); 59 | } 60 | fflush(stdout); 61 | fflush(stderr); 62 | } else { 63 | auto s = fmt::sprintf(fmt, std::forward(args)...); 64 | if (s.size() && s.back() == '\n') { 65 | s.pop_back(); 66 | } 67 | s = fmt::sprintf("%d: %s", getpid(), s); 68 | py::gil_scoped_acquire gil; 69 | if (level == LOG_ERROR) { 70 | pyLogging.attr("error")(s); 71 | } else if (level == LOG_DEBUG) { 72 | // pyLogging.attr("debug")(s); 73 | pyLogging.attr("info")(s); 74 | } else { 75 | pyLogging.attr("info")(s); 76 | } 77 | } 78 | } 79 | 80 | inline struct Log { 81 | template 82 | void error(const char* fmt, Args&&... 
args) { 83 | logat(LOG_ERROR, fmt, std::forward(args)...); 84 | } 85 | template 86 | void info(const char* fmt, Args&&... args) { 87 | logat(LOG_INFO, fmt, std::forward(args)...); 88 | } 89 | template 90 | void verbose(const char* fmt, Args&&... args) { 91 | logat(LOG_VERBOSE, fmt, std::forward(args)...); 92 | } 93 | template 94 | void debug(const char* fmt, Args&&... args) { 95 | logat(LOG_DEBUG, fmt, std::forward(args)...); 96 | } 97 | } log; 98 | 99 | template 100 | [[noreturn]] void fatal(const char* fmt, Args&&... args) { 101 | auto s = fmt::sprintf(fmt, std::forward(args)...); 102 | log.error(" -- FATAL ERROR --\n%s\n", s); 103 | std::abort(); 104 | } 105 | 106 | } // namespace moolib 107 | -------------------------------------------------------------------------------- /src/memory/allocator.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #pragma once 9 | 10 | #include "synchronization.h" 11 | 12 | #include "memfd.h" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | namespace rpc { 24 | 25 | inline memfd::MemfdAllocator& memfdAllocator = *new memfd::MemfdAllocator(); 26 | 27 | namespace allocimpl { 28 | 29 | template 30 | struct Storage { 31 | static constexpr size_t size = (Size + 63) / 64 * 64; 32 | Header* freelist = nullptr; 33 | size_t freelistSize = 0; 34 | 35 | struct alignas(64) GlobalList { 36 | SpinMutex mutex; 37 | std::vector> list; 38 | }; 39 | 40 | ~Storage() { 41 | for (Header* ptr = freelist; ptr;) { 42 | Header* next = ptr->next; 43 | memfdAllocator.deallocate(ptr, ptr->capacity + sizeof(Header)); 44 | ptr = next; 45 | } 46 | } 47 | 48 | inline static GlobalList global; 49 | 50 | Header* allocateFromGlobal() { 51 | std::unique_lock l(global.mutex); 52 | if (!global.list.empty()) { 53 | freelist = global.list.back().first; 54 | freelistSize = global.list.back().second; 55 | global.list.pop_back(); 56 | l.unlock(); 57 | return allocate(); 58 | } 59 | l.unlock(); 60 | auto a = memfdAllocator.allocate(size); 61 | Header* r = (Header*)a.first; 62 | new (r) Header(); 63 | r->capacity = size - sizeof(Header); 64 | return r; 65 | } 66 | 67 | Header* allocate() { 68 | static_assert(alignof(Header) <= 64 && alignof(Data) <= 64 && alignof(Data) <= sizeof(Header)); 69 | Header* r = freelist; 70 | if (r) { 71 | [[likely]]; 72 | --freelistSize; 73 | freelist = r->next; 74 | return r; 75 | } else { 76 | return allocateFromGlobal(); 77 | } 78 | } 79 | void moveFreelistToGlobal() { 80 | std::unique_lock l(global.mutex); 81 | global.list.push_back({freelist, freelistSize}); 82 | l.unlock(); 83 | freelist = nullptr; 84 | freelistSize = 0; 85 | } 86 | void deallocate(Header* ptr) { 87 | if (freelistSize >= std::min(1024 * 1024 / Size, 128)) { 88 | [[unlikely]]; 89 | moveFreelistToGlobal(); 90 | } 91 | 
++freelistSize; 92 | ptr->next = freelist; 93 | freelist = ptr; 94 | } 95 | 96 | static Storage& get() { 97 | thread_local Storage storage; 98 | return storage; 99 | } 100 | }; 101 | 102 | } // namespace allocimpl 103 | 104 | template 105 | Header* allocate(size_t n) { 106 | constexpr size_t overhead = sizeof(Header); 107 | if (n + overhead <= 64) { 108 | return allocimpl::Storage::get().allocate(); 109 | } else if (n + overhead <= 256) { 110 | return allocimpl::Storage::get().allocate(); 111 | } else if (n + overhead <= 1024) { 112 | return allocimpl::Storage::get().allocate(); 113 | } else if (n + overhead <= 4096) { 114 | return allocimpl::Storage::get().allocate(); 115 | } else { 116 | auto a = memfdAllocator.allocate((sizeof(Header) + sizeof(Data) * n + 63) / 64 * 64); 117 | Header* h = (Header*)a.first; 118 | new (h) Header(); 119 | h->capacity = a.second - sizeof(Header); 120 | return h; 121 | } 122 | } 123 | template 124 | void deallocate(Header* ptr) { 125 | const size_t n = ptr->capacity + sizeof(Header); 126 | switch (n) { 127 | case 64: 128 | allocimpl::Storage::get().deallocate(ptr); 129 | break; 130 | case 256: 131 | allocimpl::Storage::get().deallocate(ptr); 132 | break; 133 | case 1024: 134 | allocimpl::Storage::get().deallocate(ptr); 135 | break; 136 | case 4096: 137 | allocimpl::Storage::get().deallocate(ptr); 138 | break; 139 | default: 140 | memfdAllocator.deallocate(ptr, ptr->capacity + sizeof(Header)); 141 | } 142 | } 143 | template 144 | Data* dataptr(Header* ptr) { 145 | return (Data*)(ptr + 1); 146 | } 147 | 148 | } // namespace rpc 149 | -------------------------------------------------------------------------------- /src/memory/buffer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #pragma once 9 | 10 | #include "allocator.h" 11 | #include "tensor.h" 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | namespace rpc { 20 | 21 | struct TensorRef { 22 | Tensor tensor; 23 | }; 24 | 25 | struct Buffer { 26 | Buffer* next{nullptr}; 27 | size_t capacity = 0; 28 | size_t size = 0; 29 | std::atomic_int refcount = 0; 30 | uint32_t nTensors = 0; 31 | std::byte* data() { 32 | return dataptr(this); 33 | } 34 | template 35 | static RT roundUpFor(P ptr) { 36 | uintptr_t v = (uintptr_t)(std::byte*)ptr; 37 | constexpr auto alignment = alignof(T); 38 | static_assert(alignment <= 64); 39 | if (alignment <= alignof(std::remove_pointer_t

)) { 40 | return (RT)v; 41 | } 42 | return (RT)((v + alignment - 1) / alignment * alignment); 43 | } 44 | size_t* tensorMetaDataOffsets() { 45 | return roundUpFor(data() + size); 46 | } 47 | TensorRef* tensors() { 48 | return roundUpFor(tensorMetaDataOffsets() + nTensors); 49 | } 50 | static size_t getAllocSize(size_t size, size_t nTensors) { 51 | uintptr_t ptr = 0; 52 | uintptr_t offsets = roundUpFor(ptr + size); 53 | uintptr_t tensors = roundUpFor(offsets + sizeof(size_t) * nTensors); 54 | return tensors + sizeof(TensorRef) * nTensors - ptr; 55 | } 56 | }; 57 | 58 | inline void destroyBuffer(Buffer* buffer) noexcept { 59 | if (buffer->nTensors) { 60 | auto* tensors = buffer->tensors(); 61 | for (size_t i = buffer->nTensors; i;) { 62 | --i; 63 | tensors[i].~TensorRef(); 64 | } 65 | auto* offsets = buffer->tensorMetaDataOffsets(); 66 | for (size_t i = buffer->nTensors; i;) { 67 | --i; 68 | offsets[i].~size_t(); 69 | } 70 | buffer->nTensors = 0; 71 | } 72 | } 73 | 74 | inline void shrinkBuffer(Buffer* buffer, size_t size, size_t nTensors) { 75 | auto* tensors = buffer->tensors(); 76 | for (size_t i = buffer->nTensors; i != nTensors;) { 77 | --i; 78 | tensors[i].~TensorRef(); 79 | } 80 | auto* offsets = buffer->tensorMetaDataOffsets(); 81 | for (size_t i = buffer->nTensors; i != nTensors;) { 82 | --i; 83 | offsets[i].~size_t(); 84 | } 85 | buffer->nTensors = nTensors; 86 | buffer->size = size; 87 | } 88 | 89 | struct BufferHandle { 90 | Buffer* buffer_ = nullptr; 91 | BufferHandle() = default; 92 | BufferHandle(std::nullptr_t) noexcept {} 93 | explicit BufferHandle(Buffer* buffer) noexcept : buffer_(buffer) {} 94 | BufferHandle(const BufferHandle&) = delete; 95 | BufferHandle& operator=(const BufferHandle&) = delete; 96 | BufferHandle(BufferHandle&& n) noexcept { 97 | buffer_ = n.buffer_; 98 | n.buffer_ = nullptr; 99 | } 100 | BufferHandle& operator=(BufferHandle&& n) noexcept { 101 | std::swap(buffer_, n.buffer_); 102 | return *this; 103 | } 104 | ~BufferHandle() { 
105 | if (buffer_) { 106 | destroyBuffer(buffer_); 107 | deallocate(buffer_); 108 | } 109 | } 110 | explicit operator bool() const noexcept { 111 | return buffer_; 112 | } 113 | Buffer* operator->() const noexcept { 114 | return buffer_; 115 | } 116 | operator Buffer*() const noexcept { 117 | return buffer_; 118 | } 119 | Buffer* release() noexcept { 120 | Buffer* r = buffer_; 121 | buffer_ = nullptr; 122 | return r; 123 | } 124 | }; 125 | struct SharedBufferHandle { 126 | Buffer* buffer_ = nullptr; 127 | SharedBufferHandle() = default; 128 | SharedBufferHandle(std::nullptr_t) noexcept {} 129 | explicit SharedBufferHandle(Buffer* buffer) noexcept : buffer_(buffer) { 130 | if (buffer_) { 131 | if (buffer->refcount != 0) { 132 | std::abort(); 133 | } 134 | addref(); 135 | } 136 | } 137 | SharedBufferHandle(const SharedBufferHandle& n) noexcept { 138 | buffer_ = n.buffer_; 139 | if (buffer_) { 140 | addref(); 141 | } 142 | } 143 | SharedBufferHandle& operator=(const SharedBufferHandle& n) noexcept { 144 | buffer_ = n.buffer_; 145 | if (buffer_) { 146 | addref(); 147 | } 148 | return *this; 149 | } 150 | SharedBufferHandle(SharedBufferHandle&& n) noexcept { 151 | buffer_ = n.buffer_; 152 | n.buffer_ = nullptr; 153 | } 154 | SharedBufferHandle& operator=(SharedBufferHandle&& n) noexcept { 155 | std::swap(buffer_, n.buffer_); 156 | return *this; 157 | } 158 | ~SharedBufferHandle() { 159 | if (buffer_ && decref() == 0) { 160 | destroyBuffer(buffer_); 161 | deallocate(buffer_); 162 | } 163 | } 164 | explicit operator bool() const noexcept { 165 | return buffer_; 166 | } 167 | Buffer* operator->() const noexcept { 168 | return buffer_; 169 | } 170 | operator Buffer*() const noexcept { 171 | return buffer_; 172 | } 173 | int addref() noexcept { 174 | return buffer_->refcount.fetch_add(1, std::memory_order_acquire) + 1; 175 | } 176 | int decref() noexcept { 177 | return buffer_->refcount.fetch_sub(1) - 1; 178 | } 179 | Buffer* release() noexcept { 180 | Buffer* r = buffer_; 
181 | buffer_ = nullptr; 182 | return r; 183 | } 184 | void acquire(Buffer* buffer) noexcept { 185 | buffer_ = buffer; 186 | } 187 | }; 188 | 189 | inline BufferHandle makeBuffer(size_t size, size_t nTensors) noexcept { 190 | size_t allocsize = size; 191 | if (nTensors) { 192 | allocsize = Buffer::getAllocSize(size, nTensors); 193 | } 194 | BufferHandle buffer{allocate(allocsize)}; 195 | buffer->size = size; 196 | buffer->nTensors = nTensors; 197 | if (nTensors) { 198 | auto* offsets = buffer->tensorMetaDataOffsets(); 199 | for (size_t i = 0; i != nTensors; ++i) { 200 | new (offsets + i) size_t{}; 201 | } 202 | auto* tensors = buffer->tensors(); 203 | for (size_t i = 0; i != nTensors; ++i) { 204 | new (tensors + i) TensorRef{}; 205 | } 206 | } 207 | return buffer; 208 | } 209 | 210 | } // namespace rpc 211 | -------------------------------------------------------------------------------- /src/memory/memfd.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace rpc { 10 | namespace memfd { 11 | 12 | struct Memfd { 13 | int fd = -1; 14 | void* base = nullptr; 15 | size_t size = 0; 16 | 17 | Memfd() = default; 18 | ~Memfd(); 19 | Memfd(const Memfd&) = delete; 20 | Memfd(Memfd&& n) { 21 | std::swap(fd, n.fd); 22 | std::swap(base, n.base); 23 | std::swap(size, n.size); 24 | } 25 | Memfd& operator=(const Memfd&) = delete; 26 | Memfd& operator=(Memfd&& n) { 27 | std::swap(fd, n.fd); 28 | std::swap(base, n.base); 29 | std::swap(size, n.size); 30 | return *this; 31 | } 32 | 33 | static Memfd create(size_t size); 34 | static Memfd map(int fd, size_t size); 35 | }; 36 | 37 | struct AddressInfo { 38 | int fd = -1; 39 | size_t fdSize = 0; 40 | size_t offset = 0; 41 | }; 42 | 43 | struct MemfdAllocatorImpl; 44 | struct MemfdAllocator { 45 | std::unique_ptr impl; 46 | MemfdAllocator(); 47 | ~MemfdAllocator(); 48 | std::pair getMemfd(int fd); 49 | 
AddressInfo getAddressInfo(void* ptr); 50 | std::pair allocate(size_t size); 51 | void deallocate(void* ptr, size_t size); 52 | }; 53 | 54 | } // namespace memfd 55 | 56 | } // namespace rpc -------------------------------------------------------------------------------- /src/pytorch.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #pragma once 9 | 10 | #include "serialization.h" 11 | #include "tensor.h" 12 | 13 | #include 14 | 15 | namespace rpc { 16 | 17 | Tensor torchTensorToTensor(const torch::Tensor&); 18 | torch::Tensor tensorToTorchTensor(Tensor&&); 19 | 20 | template 21 | void serialize(X& x, const torch::Tensor& v) { 22 | serialize(x, torchTensorToTensor(v)); 23 | } 24 | 25 | template 26 | void serialize(X& x, torch::Tensor& v) { 27 | v = tensorToTorchTensor(x.template read()); 28 | } 29 | 30 | } // namespace rpc 31 | -------------------------------------------------------------------------------- /src/pyutil.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #pragma once 9 | 10 | #include "pybind11/pybind11.h" 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | namespace moolib { 17 | 18 | namespace py = pybind11; 19 | 20 | template 21 | struct glock { 22 | std::unique_lock ulock; 23 | glock(T& mutex) : ulock(mutex, std::try_to_lock) { 24 | if (!ulock.owns_lock()) { 25 | if (PyGILState_Check()) { 26 | py::gil_scoped_release gil; 27 | ulock.lock(); 28 | } else { 29 | ulock.lock(); 30 | } 31 | } 32 | } 33 | void lock() { 34 | ulock.lock(); 35 | } 36 | void unlock() { 37 | ulock.unlock(); 38 | } 39 | }; 40 | 41 | template 42 | struct GilWrapper { 43 | std::optional obj; 44 | GilWrapper() = default; 45 | GilWrapper(const T& n) { 46 | py::gil_scoped_acquire gil; 47 | obj = n; 48 | } 49 | GilWrapper(T&& n) { 50 | obj = std::move(n); 51 | } 52 | GilWrapper(const GilWrapper& n) { 53 | py::gil_scoped_acquire gil; 54 | obj = n.obj; 55 | } 56 | GilWrapper(GilWrapper&& n) { 57 | obj = std::move(n.obj); 58 | } 59 | ~GilWrapper() { 60 | if (obj && *obj) { 61 | py::gil_scoped_acquire gil; 62 | obj.reset(); 63 | } 64 | } 65 | T release() { 66 | T r = std::move(obj.value()); 67 | obj.reset(); 68 | return r; 69 | } 70 | void reset() { 71 | obj.reset(); 72 | } 73 | operator bool() const noexcept { 74 | return obj.has_value(); 75 | } 76 | T& operator*() & { 77 | return *obj; 78 | } 79 | T&& operator*() && { 80 | return std::move(*obj); 81 | } 82 | T* operator->() { 83 | return &*obj; 84 | } 85 | GilWrapper& operator=(const GilWrapper& n) { 86 | py::gil_scoped_acquire gil; 87 | obj = n.obj; 88 | return *this; 89 | } 90 | GilWrapper& operator=(GilWrapper&& n) { 91 | if (obj && *obj) { 92 | // acquire GIL here as existing object needs to be released 93 | py::gil_scoped_acquire gil; 94 | obj = std::move(n.obj); 95 | } else { 96 | obj = std::move(n.obj); 97 | } 98 | return *this; 99 | } 100 | template 101 | void serialize(X& x) { 102 | py::gil_scoped_acquire gil; 103 | obj.emplace(); 104 | x(*obj); 105 | } 106 | template 
107 | void serialize(X& x) const { 108 | py::gil_scoped_acquire gil; 109 | x(*obj); 110 | } 111 | }; 112 | 113 | } // namespace moolib 114 | -------------------------------------------------------------------------------- /src/shm.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #pragma once 9 | 10 | #include "logging.h" 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include 20 | 21 | #ifndef __APPLE__ 22 | #include 23 | #else 24 | #include 25 | #endif 26 | #include 27 | 28 | namespace moolib { 29 | 30 | struct SharedMemory { 31 | int fd = -1; 32 | size_t size = 1024 * 1024 * 400; 33 | std::byte* data = nullptr; 34 | std::string name; 35 | bool unlinked = false; 36 | 37 | static_assert(std::atomic_bool::is_always_lock_free, "need lock-free atomics"); 38 | struct InitBlock { 39 | std::atomic_bool initialized; 40 | std::atomic_bool initializing; 41 | }; 42 | 43 | InitBlock* initBlock = nullptr; 44 | 45 | SharedMemory(std::string_view name) : name(name) { 46 | log.verbose("creating shm %s\n", name); 47 | fd = shm_open(std::string(name).c_str(), O_RDWR | O_CREAT, ACCESSPERMS); 48 | if (fd < 0) { 49 | throw std::system_error(errno, std::system_category(), "shm_open"); 50 | } 51 | if (ftruncate(fd, size)) { 52 | /* Fails on OSX after the first time but can be ignored. 
*/ 53 | log.verbose("ftruncate failed on shm: %d\n", errno); 54 | } 55 | data = (std::byte*)mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); 56 | if (!data) { 57 | throw std::system_error(errno, std::system_category(), "mmap"); 58 | } 59 | initBlock = (InitBlock*)data; 60 | size_t n = (sizeof(InitBlock) + 63) / 64 * 64; 61 | data += n; 62 | size -= n; 63 | } 64 | ~SharedMemory() { 65 | munmap(data, size); 66 | close(fd); 67 | if (!unlinked) { 68 | shm_unlink(name.c_str()); 69 | } 70 | } 71 | void unlink() { 72 | if (!unlinked) { 73 | shm_unlink(name.c_str()); 74 | unlinked = true; 75 | } 76 | } 77 | template 78 | T& as() { 79 | if (sizeof(T) > size) { 80 | fatal("%s is too big for shm :(\n", typeid(T).name()); 81 | } 82 | T& r = *(T*)data; 83 | if (!initBlock->initialized) { 84 | if (initBlock->initializing.exchange(true)) { 85 | while (!initBlock->initialized) 86 | ; 87 | return r; 88 | } 89 | r.init(size); 90 | initBlock->initialized = true; 91 | } 92 | return r; 93 | } 94 | }; 95 | 96 | class SharedSemaphore { 97 | #ifndef __APPLE__ 98 | sem_t sem; 99 | 100 | public: 101 | SharedSemaphore() noexcept { 102 | sem_init(&sem, 1, 0); 103 | } 104 | 105 | ~SharedSemaphore() { 106 | sem_destroy(&sem); 107 | } 108 | 109 | void post() noexcept { 110 | sem_post(&sem); 111 | } 112 | 113 | void wait() noexcept { 114 | while (sem_wait(&sem)) { 115 | if (errno != EINTR) { 116 | std::abort(); 117 | } 118 | } 119 | } 120 | 121 | template 122 | void wait_for(const std::chrono::duration& duration) noexcept { 123 | struct timespec ts; 124 | SharedSemaphore::fill_ts(ts, duration); 125 | while (sem_timedwait(&sem, &ts)) { 126 | if (errno == ETIMEDOUT) { 127 | break; 128 | } 129 | if (errno != EINTR) { 130 | std::abort(); 131 | } 132 | } 133 | } 134 | 135 | #else /* __APPLE__ */ 136 | pthread_mutex_t mu; 137 | pthread_cond_t cv; 138 | unsigned value; 139 | 140 | struct Lock { 141 | Lock(pthread_mutex_t* mutex) : mu(mutex) { 142 | int rc = pthread_mutex_lock(mu); 143 | 
if (rc) { 144 | fatal("pthread_mutex_lock: %d\n", rc); 145 | } 146 | } 147 | ~Lock() { 148 | int rc = pthread_mutex_unlock(mu); 149 | if (rc) { 150 | fatal("pthread_mutex_unlock: %d\n", rc); 151 | } 152 | } 153 | pthread_mutex_t* mu; 154 | }; 155 | 156 | public: 157 | SharedSemaphore() noexcept : value(0) { 158 | pthread_mutexattr_t psharedm; 159 | pthread_condattr_t psharedc; 160 | 161 | pthread_mutexattr_init(&psharedm); 162 | pthread_mutexattr_setpshared(&psharedm, PTHREAD_PROCESS_SHARED); 163 | pthread_condattr_init(&psharedc); 164 | pthread_condattr_setpshared(&psharedc, PTHREAD_PROCESS_SHARED); 165 | 166 | pthread_mutex_init(&mu, &psharedm); 167 | pthread_cond_init(&cv, &psharedc); 168 | } 169 | 170 | ~SharedSemaphore() { 171 | pthread_cond_destroy(&cv); 172 | pthread_mutex_destroy(&mu); 173 | } 174 | 175 | void post() noexcept { 176 | Lock lock(&mu); 177 | if (value == 0) { 178 | pthread_cond_signal(&cv); 179 | } 180 | value++; 181 | } 182 | 183 | void wait() noexcept { 184 | Lock lock(&mu); 185 | while (value == 0) { 186 | if (pthread_cond_wait(&cv, &mu)) { 187 | std::abort(); 188 | } 189 | } 190 | --value; 191 | } 192 | 193 | template 194 | void wait_for(const std::chrono::duration& duration) noexcept { 195 | struct timespec ts; 196 | SharedSemaphore::fill_ts(ts, duration); 197 | 198 | Lock lock(&mu); 199 | 200 | int rc = 0; 201 | while (value == 0 && rc == 0) { 202 | rc = pthread_cond_timedwait(&cv, &mu, &ts); 203 | } 204 | 205 | if (rc == 0) { 206 | --value; 207 | } else if (rc != ETIMEDOUT) { 208 | std::abort(); 209 | } 210 | } 211 | #endif /* __APPLE__ */ 212 | 213 | template 214 | void wait_until(const std::chrono::time_point& timePoint) noexcept { 215 | wait_for(timePoint - Clock::now()); 216 | } 217 | 218 | SharedSemaphore(const SharedSemaphore&) = delete; 219 | SharedSemaphore(const SharedSemaphore&&) = delete; 220 | SharedSemaphore& operator=(const SharedSemaphore&) = delete; 221 | SharedSemaphore& operator=(const SharedSemaphore&&) = delete; 222 
| 223 | private: 224 | template 225 | static void fill_ts(TimeSpec& ts, const std::chrono::duration& duration) { 226 | auto absduration = std::chrono::system_clock::now().time_since_epoch() + duration; 227 | auto nanoseconds = std::chrono::duration_cast(absduration); 228 | auto seconds = std::chrono::duration_cast(nanoseconds); 229 | ts.tv_sec = seconds.count(); 230 | ts.tv_nsec = (nanoseconds - seconds).count(); 231 | } 232 | }; 233 | 234 | template 235 | struct SharedPointer { 236 | size_t offset = 0; 237 | template 238 | T* operator()(Shared* shared) { 239 | return (T*)((std::byte*)shared + offset); 240 | } 241 | }; 242 | 243 | template 244 | struct SharedArray { 245 | size_t size; 246 | SharedPointer data; 247 | 248 | template 249 | std::basic_string_view view(Shared* shared) { 250 | return {data(shared), size}; 251 | } 252 | 253 | template 254 | T& operator()(Shared* shared, size_t index) { 255 | return data(shared)[index]; 256 | } 257 | template 258 | T* operator()(Shared* shared) { 259 | return data(shared); 260 | } 261 | }; 262 | 263 | } // namespace moolib 264 | -------------------------------------------------------------------------------- /src/synchronization.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #pragma once 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #ifdef __linux__ 16 | #include 17 | #endif 18 | 19 | #include 20 | 21 | namespace rpc { 22 | 23 | #if 0 24 | using SpinMutex = std::mutex; 25 | #elif 0 26 | inline std::atomic_int mutexThreadIdCounter = 0; 27 | inline thread_local int mutexThreadId = mutexThreadIdCounter++; 28 | class SpinMutex { 29 | int magic = 0x42; 30 | std::atomic locked_{false}; 31 | std::atomic owner = nullptr; 32 | 33 | public: 34 | void lock() { 35 | if (owner == &mutexThreadId) { 36 | printf("recursive lock\n"); 37 | std::abort(); 38 | } 39 | if (magic != 0x42) { 40 | printf("BAD MUTEX MAGIC\n"); 41 | std::abort(); 42 | } 43 | auto start = std::chrono::steady_clock::now(); 44 | do { 45 | while (locked_.load(std::memory_order_acquire)) { 46 | _mm_pause(); 47 | if (magic != 0x42) { 48 | printf("BAD MUTEX MAGIC\n"); 49 | std::abort(); 50 | } 51 | if (std::chrono::steady_clock::now() - start >= std::chrono::seconds(10)) { 52 | int* p = owner.load(); 53 | printf("deadlock detected in thread %d! held by thread %d\n", mutexThreadId, p ? 
*p : -1); 54 | start = std::chrono::steady_clock::now(); 55 | } 56 | } 57 | } while (locked_.exchange(true, std::memory_order_acquire)); 58 | owner = &mutexThreadId; 59 | } 60 | void unlock() { 61 | if (magic != 0x42) { 62 | printf("BAD MUTEX MAGIC\n"); 63 | std::abort(); 64 | } 65 | owner = nullptr; 66 | locked_.store(false); 67 | } 68 | bool try_lock() { 69 | if (owner == &mutexThreadId) { 70 | printf("recursive try_lock\n"); 71 | std::abort(); 72 | } 73 | if (locked_.load(std::memory_order_acquire)) { 74 | return false; 75 | } 76 | bool r = !locked_.exchange(true, std::memory_order_acquire); 77 | if (r) { 78 | owner = &mutexThreadId; 79 | } 80 | return r; 81 | } 82 | }; 83 | #else 84 | class SpinMutex { 85 | std::atomic locked = false; 86 | 87 | public: 88 | void lock() { 89 | do { 90 | while (locked.load(std::memory_order_relaxed)) { 91 | _mm_pause(); 92 | } 93 | } while (locked.exchange(true, std::memory_order_acquire)); 94 | } 95 | void unlock() { 96 | locked.store(false, std::memory_order_release); 97 | } 98 | bool try_lock() { 99 | if (locked.load(std::memory_order_relaxed)) { 100 | return false; 101 | } 102 | return !locked.exchange(true, std::memory_order_acquire); 103 | } 104 | }; 105 | #endif 106 | 107 | #if 0 108 | using SharedSpinMutex = std::shared_mutex; 109 | #else 110 | class SharedSpinMutex { 111 | std::atomic_bool locked = false; 112 | std::atomic_int shareCount = 0; 113 | 114 | public: 115 | void lock() { 116 | do { 117 | while (locked.load(std::memory_order_relaxed)) { 118 | _mm_pause(); 119 | } 120 | } while (locked.exchange(true, std::memory_order_acquire)); 121 | while (shareCount.load(std::memory_order_relaxed)) { 122 | _mm_pause(); 123 | } 124 | } 125 | void unlock() { 126 | locked.store(false, std::memory_order_release); 127 | } 128 | bool try_lock() { 129 | if (locked.load(std::memory_order_relaxed)) { 130 | return false; 131 | } 132 | if (shareCount.load(std::memory_order_relaxed) == 0 && !locked.exchange(true, 
std::memory_order_acquire)) { 133 | if (shareCount.load(std::memory_order_relaxed) == 0) { 134 | return true; 135 | } else { 136 | locked.store(false, std::memory_order_relaxed); 137 | } 138 | } 139 | return false; 140 | } 141 | void lock_shared() { 142 | while (true) { 143 | while (locked.load(std::memory_order_relaxed)) { 144 | _mm_pause(); 145 | } 146 | shareCount.fetch_add(1, std::memory_order_acq_rel); 147 | if (locked.load(std::memory_order_relaxed)) { 148 | shareCount.fetch_sub(1, std::memory_order_acquire); 149 | } else { 150 | break; 151 | } 152 | } 153 | } 154 | void unlock_shared() { 155 | shareCount.fetch_sub(1, std::memory_order_release); 156 | } 157 | bool try_lock_shared() { 158 | if (locked.load(std::memory_order_relaxed)) { 159 | return false; 160 | } 161 | shareCount.fetch_add(1, std::memory_order_acq_rel); 162 | if (locked.load(std::memory_order_relaxed)) { 163 | shareCount.fetch_sub(1, std::memory_order_acquire); 164 | return false; 165 | } 166 | return true; 167 | } 168 | }; 169 | #endif 170 | 171 | #ifdef __linux__ 172 | class Semaphore { 173 | sem_t sem; 174 | 175 | public: 176 | Semaphore() noexcept { 177 | sem_init(&sem, 0, 0); 178 | } 179 | ~Semaphore() { 180 | sem_destroy(&sem); 181 | } 182 | void post() noexcept { 183 | sem_post(&sem); 184 | } 185 | void wait() noexcept { 186 | while (sem_wait(&sem)) { 187 | if (errno != EINTR) { 188 | printf("sem_wait returned errno %d", (int)errno); 189 | std::abort(); 190 | } 191 | } 192 | } 193 | template 194 | void wait_for(const std::chrono::duration& duration) noexcept { 195 | struct timespec ts; 196 | auto absduration = std::chrono::system_clock::now().time_since_epoch() + duration; 197 | auto nanoseconds = std::chrono::duration_cast(absduration); 198 | auto seconds = std::chrono::duration_cast(nanoseconds); 199 | ts.tv_sec = seconds.count(); 200 | ts.tv_nsec = (nanoseconds - seconds).count(); 201 | while (sem_timedwait(&sem, &ts)) { 202 | if (errno == ETIMEDOUT) { 203 | break; 204 | } 205 | if 
(errno != EINTR) { 206 | printf("sem_timedwait returned errno %d", (int)errno); 207 | std::abort(); 208 | } 209 | } 210 | } 211 | template 212 | void wait_until(const std::chrono::time_point& timePoint) noexcept { 213 | wait_for(timePoint - Clock::now()); 214 | } 215 | 216 | Semaphore(const Semaphore&) = delete; 217 | Semaphore(const Semaphore&&) = delete; 218 | Semaphore& operator=(const Semaphore&) = delete; 219 | Semaphore& operator=(const Semaphore&&) = delete; 220 | }; 221 | #else 222 | class Semaphore { 223 | int count_ = 0; 224 | std::mutex mut_; 225 | std::condition_variable cv_; 226 | 227 | public: 228 | void post() { 229 | std::unique_lock l(mut_); 230 | if (++count_ >= 1) { 231 | cv_.notify_one(); 232 | } 233 | } 234 | void wait() { 235 | std::unique_lock l(mut_); 236 | while (count_ == 0) { 237 | cv_.wait(l); 238 | } 239 | --count_; 240 | } 241 | template 242 | void wait_until(const std::chrono::time_point& timePoint) noexcept { 243 | std::unique_lock l(mut_); 244 | while (count_ == 0) { 245 | if (cv_.wait_until(l, timePoint) == std::cv_status::timeout) return; 246 | } 247 | --count_; 248 | } 249 | 250 | template 251 | void wait_for(const std::chrono::duration& duration) noexcept { 252 | std::unique_lock l(mut_); 253 | if (cv_.wait_for(l, duration, [this]() { return count_ > 0; })) --count_; 254 | } 255 | }; 256 | #endif 257 | 258 | } // namespace rpc 259 | -------------------------------------------------------------------------------- /src/tensor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #pragma once 9 | 10 | #include "any.h" 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | namespace rpc { 19 | 20 | template 21 | using ArrayRef = std::basic_string_view; 22 | 23 | using IntArrayRef = ArrayRef; 24 | 25 | enum class DeviceType { Cpu, Cuda, Unknown }; 26 | 27 | struct Device { 28 | DeviceType type = DeviceType::Unknown; 29 | int index = -1; 30 | Device() = default; 31 | Device(DeviceType type) : type(type) {} 32 | Device(DeviceType type, int index) : type(type), index(index) {} 33 | Device(std::string_view str); 34 | }; 35 | 36 | struct CUDAStream; 37 | 38 | struct Tensor { 39 | moolib::Any<32> impl; 40 | 41 | Tensor(); 42 | Tensor(std::nullptr_t); 43 | 44 | Tensor(const Tensor&); 45 | Tensor(Tensor&&); 46 | Tensor& operator=(const Tensor&); 47 | Tensor& operator=(Tensor&&); 48 | 49 | Device device() const; 50 | int scalar_type() const; 51 | IntArrayRef sizes() const; 52 | IntArrayRef strides() const; 53 | int64_t dim() const; 54 | int64_t size(int64_t dim) const; 55 | 56 | bool is_cuda() const; 57 | int itemsize() const; 58 | 59 | void* data_ptr(); 60 | 61 | template 62 | T* data() { 63 | return (T*)data_ptr(); 64 | } 65 | template 66 | T item() { 67 | return *data(); 68 | } 69 | 70 | // These constant values match with pytorch 71 | static constexpr int kUInt8 = 0; 72 | static constexpr int kInt8 = 1; 73 | static constexpr int kInt16 = 2; 74 | static constexpr int kInt32 = 3; 75 | static constexpr int kInt64 = 4; 76 | static constexpr int kFloat16 = 5; 77 | static constexpr int kFloat32 = 6; 78 | static constexpr int kFloat64 = 7; 79 | static constexpr int kBool = 11; 80 | 81 | Tensor pin_memory() const; 82 | bool defined() const; 83 | Tensor cpu() const; 84 | Tensor& copy_(const Tensor& n, bool non_blocking = false); 85 | Tensor sum() const; 86 | 87 | Tensor& operator+=(const Tensor&); 88 | Tensor& operator*=(const Tensor&); 89 | 90 | Tensor mutable_grad(); 91 | Tensor grad() const; 92 | void 
set_grad(Tensor); 93 | Tensor& detach_(); 94 | Tensor detach() const; 95 | Tensor& zero_(); 96 | Tensor& mul_(float n); 97 | Tensor& add_(const Tensor&); 98 | Tensor to(Device device, bool non_blocking = false, bool copy = false) const; 99 | bool requires_grad() const; 100 | int64_t numel() const; 101 | Tensor select(int64_t dim, int64_t index) const; 102 | Tensor narrow(int64_t dim, int64_t start, int64_t length) const; 103 | Tensor flatten(int64_t start_dim = 0, int64_t end_dim = -1) const; 104 | Tensor view_as(const Tensor& n) const; 105 | Tensor clone() const; 106 | Tensor view(IntArrayRef) const; 107 | Tensor view(const std::vector&) const; 108 | Tensor squeeze(int64_t dim) const; 109 | Tensor& squeeze_(int64_t dim); 110 | Tensor unsqueeze(int64_t dim) const; 111 | Tensor& unsqueeze_(int64_t dim); 112 | 113 | Tensor operator*(const Tensor& n) const; 114 | 115 | // void record_stream(CUDAStream); 116 | 117 | Tensor operator[](size_t index); 118 | }; 119 | 120 | struct Allocator { 121 | moolib::Any<40> impl; 122 | Allocator(); 123 | Allocator(Device device, size_t bytes); 124 | Allocator(Allocator&&); 125 | Allocator(void* ptr, size_t bytes, Device device, void* context, void (*deleter)(void*)); 126 | std::byte* data() const; 127 | size_t size() const; 128 | Tensor set(int dtype, IntArrayRef sizes, IntArrayRef strides); 129 | }; 130 | 131 | constexpr auto kCPU = DeviceType::Cpu; 132 | constexpr auto kCUDA = DeviceType::Cuda; 133 | 134 | Tensor zeros_like(const Tensor&); 135 | Tensor zeros_like(const Tensor&, Device); 136 | Tensor empty(IntArrayRef sizes, int dtype, Device d); 137 | 138 | Tensor from_blob(int dtype, IntArrayRef sizes, void* data); 139 | Tensor& min_out(Tensor& out, const Tensor&, const Tensor&); 140 | Tensor& max_out(Tensor& out, const Tensor&, const Tensor&); 141 | Tensor cat(const std::vector& tensors, int64_t dim); 142 | Tensor stack(const std::vector& tensors, int64_t dim); 143 | std::vector unbind(const Tensor& input, int64_t dim); 144 | 
Tensor cat(const std::vector& list, int64_t dim = 0); 145 | 146 | struct AutoGradMode { 147 | moolib::Any<8> impl; 148 | AutoGradMode(bool enabled); 149 | ~AutoGradMode(); 150 | }; 151 | 152 | template 153 | void serialize(X& x, const Tensor& v) { 154 | x.addTensor(v, x.tell()); 155 | x(v.scalar_type(), v.sizes(), v.strides()); 156 | } 157 | 158 | template 159 | void serialize(X& x, Tensor& v) { 160 | decltype(v.scalar_type()) dtype; 161 | decltype(v.sizes()) sizes; 162 | decltype(v.strides()) strides; 163 | x(dtype, sizes, strides); 164 | v = std::move(x.getTensor().tensor); 165 | } 166 | 167 | bool CudaSupported(); 168 | 169 | struct CUDAStream { 170 | moolib::Any<16> impl; 171 | CUDAStream(std::nullptr_t); 172 | CUDAStream(const CUDAStream&); 173 | ~CUDAStream(); 174 | CUDAStream& operator=(const CUDAStream&); 175 | void synchronize(); 176 | void* nativeHandle(); 177 | int deviceIndex(); 178 | }; 179 | 180 | struct CUDAStreamGuard { 181 | moolib::Any<64> impl; 182 | CUDAStreamGuard(const CUDAStream&); 183 | ~CUDAStreamGuard(); 184 | }; 185 | 186 | CUDAStream getCurrentCUDAStream(int device_index = -1); 187 | CUDAStream getStreamFromPool(bool highPriority = false, int device_index = -1); 188 | 189 | struct CUDAEvent { 190 | moolib::Any<16> impl; 191 | CUDAEvent(); 192 | CUDAEvent(std::nullptr_t); 193 | CUDAEvent(CUDAEvent&&); 194 | ~CUDAEvent(); 195 | CUDAEvent& operator=(CUDAEvent&&); 196 | void record(const CUDAStream&); 197 | void block(const CUDAStream&); 198 | void synchronize() const; 199 | }; 200 | 201 | } // namespace rpc 202 | -------------------------------------------------------------------------------- /src/tensorpython.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #include "tensor.h" 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | namespace rpc { 17 | 18 | pybind11::object toPython(const Tensor& t) { 19 | return py::reinterpret_steal(THPVariable_Wrap(t.impl.as())); 20 | } 21 | 22 | std::optional tryFromPython(pybind11::handle v) { 23 | if (THPVariable_Check(v.ptr())) { 24 | Tensor r(nullptr); 25 | r.impl.emplace(THPVariable_Unpack(v.ptr())); 26 | return r; 27 | } else { 28 | return {}; 29 | } 30 | } 31 | 32 | void setPythonTensor(pybind11::handle o, const Tensor& t) { 33 | // pytorch used to return a non-const reference here. 34 | // Now it returns a const reference, but we really would like 35 | // to be able to assign to the internal tensor, and nothing 36 | // seems to catch on fire for our use case. 37 | const_cast(THPVariable_Unpack(o.ptr())) = t.impl.as(); 38 | } 39 | 40 | } // namespace rpc 41 | -------------------------------------------------------------------------------- /src/transports/ipc.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #pragma once 9 | 10 | #include "rpc.h" 11 | #include "socket.h" 12 | #include "synchronization.h" 13 | 14 | #include "fmt/printf.h" 15 | 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | namespace rpc { 22 | 23 | namespace ipc { 24 | 25 | struct Connection; 26 | struct Listener; 27 | 28 | struct UnixContext { 29 | std::shared_ptr listen(std::string_view addr); 30 | std::shared_ptr connect(std::string_view addr); 31 | static bool isReachable(std::string_view networkKey, std::string_view address); 32 | static std::string getNetworkKey(); 33 | }; 34 | 35 | struct TcpContext { 36 | std::shared_ptr listen(std::string_view addr); 37 | std::shared_ptr connect(std::string_view addr); 38 | static bool isReachable(std::string_view networkKey, std::string_view address); 39 | static std::string getNetworkKey(); 40 | }; 41 | 42 | struct Listener { 43 | Socket socket; 44 | Listener(Socket socket) : socket(std::move(socket)) {} 45 | 46 | void close() { 47 | socket.close(); 48 | } 49 | 50 | void accept(Function)> callback); 51 | 52 | std::vector localAddresses() const; 53 | }; 54 | 55 | struct ConnectionImpl; 56 | struct Connection : std::enable_shared_from_this { 57 | Socket socket; 58 | 59 | Connection(Socket socket) : socket(std::move(socket)) {} 60 | ~Connection(); 61 | 62 | void close(); 63 | void read(Function); 64 | template 65 | void write(Buffer buffer, Function); 66 | 67 | std::string localAddress() const; 68 | std::string remoteAddress() const; 69 | }; 70 | 71 | } // namespace ipc 72 | 73 | } // namespace rpc 74 | -------------------------------------------------------------------------------- /src/transports/socket.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #pragma once 9 | 10 | #include "rpc.h" 11 | #include "vector.h" 12 | 13 | #include "fmt/printf.h" 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | namespace rpc { 25 | 26 | std::pair decodeIpAddress(std::string_view address); 27 | 28 | inline bool isLoopbackAddress(std::string_view address) { 29 | std::tie(address, std::ignore) = decodeIpAddress(address); 30 | return address == "::1" || address == "127.0.0.1"; 31 | } 32 | 33 | inline bool isAnyAddress(std::string_view address) { 34 | std::tie(address, std::ignore) = decodeIpAddress(address); 35 | return address == "" || address == "::" || address == "0.0.0.0"; 36 | } 37 | 38 | struct iovec { 39 | void* iov_base = nullptr; 40 | size_t iov_len = 0; 41 | }; 42 | 43 | struct CachedReader; 44 | 45 | struct SocketImpl; 46 | struct Socket { 47 | std::shared_ptr impl; 48 | Socket(); 49 | Socket(const Socket&) = delete; 50 | Socket(Socket&& n); 51 | ~Socket(); 52 | Socket& operator=(const Socket&) = delete; 53 | Socket& operator=(Socket&& n); 54 | static Socket Unix(); 55 | static Socket Tcp(); 56 | 57 | void close(); 58 | 59 | void listen(std::string_view address); 60 | void accept(Function callback); 61 | void connect(std::string_view address, Function callback); 62 | 63 | void writev(const iovec* vec, size_t veclen, Function callback); 64 | 65 | void setOnRead(Function*)> callback); 66 | 67 | size_t readv(const iovec* vec, size_t veclen); 68 | 69 | void sendFd(int fd, Function callback); 70 | int recvFd(CachedReader& reader); 71 | 72 | std::vector localAddresses() const; 73 | std::string localAddress() const; 74 | std::string remoteAddress() const; 75 | 76 | int nativeFd() const; 77 | }; 78 | 79 | struct CachedReader { 80 | moolib::Vector iovecs; 81 | size_t iovecsOffset; 82 | Socket* socket; 83 | size_t bufferFilled = 0; 84 | size_t bufferOffset = 0; 85 | moolib::Vector buffer; 86 | CachedReader(Socket* socket) : socket(socket) 
{} 87 | void newRead() { 88 | iovecs.clear(); 89 | } 90 | void addIovec(const iovec& v) { 91 | iovecs.push_back(v); 92 | } 93 | void addIovec(void* dst, size_t len) { 94 | iovecs.push_back(iovec{dst, len}); 95 | } 96 | void startRead() { 97 | size_t iovecsOffset = 0; 98 | size_t skip = bufferFilled - bufferOffset; 99 | if (skip) { 100 | size_t left = skip; 101 | size_t offset = bufferOffset; 102 | char* src = buffer.data() + bufferOffset; 103 | const char* end = buffer.data() + bufferFilled; 104 | for (auto& v : iovecs) { 105 | size_t n = std::min(left, v.iov_len); 106 | std::memcpy(v.iov_base, src, n); 107 | v.iov_base = (char*)v.iov_base + n; 108 | v.iov_len -= n; 109 | if (v.iov_len == 0) { 110 | ++iovecsOffset; 111 | } 112 | src += n; 113 | left -= n; 114 | if (left == 0) { 115 | break; 116 | } 117 | } 118 | bufferOffset = src - buffer.data(); 119 | if (bufferOffset == bufferFilled) { 120 | bufferOffset = 0; 121 | bufferFilled = 0; 122 | } 123 | } 124 | iovecs.push_back({buffer.data() + bufferFilled, buffer.size() - bufferFilled}); 125 | this->iovecsOffset = iovecsOffset; 126 | } 127 | static constexpr size_t maxBufferSize = 256 * 1024; 128 | bool done() { 129 | if (iovecsOffset && iovecsOffset == iovecs.size() - 1) { 130 | return true; 131 | } 132 | size_t i = iovecsOffset; 133 | size_t e = iovecs.size() - 1; 134 | size_t n = socket->readv(iovecs.data() + iovecsOffset, iovecs.size() - iovecsOffset); 135 | for (; i != e; ++i) { 136 | auto& v = iovecs[i]; 137 | if (n >= v.iov_len) { 138 | n -= v.iov_len; 139 | v.iov_len = 0; 140 | ++iovecsOffset; 141 | if (n == 0) { 142 | ++i; 143 | break; 144 | } 145 | } else { 146 | v.iov_base = (char*)v.iov_base + n; 147 | v.iov_len -= n; 148 | return false; 149 | } 150 | } 151 | if (i == e) { 152 | bufferFilled += n; 153 | if (bufferFilled == buffer.size() && buffer.size() < maxBufferSize) { 154 | buffer.resize(std::max(std::min(buffer.size() * 2, maxBufferSize), (size_t)4096)); 155 | } 156 | return true; 157 | } 158 | 
return false; 159 | } 160 | void* readBufferPointer(size_t len) { 161 | if (bufferFilled - bufferOffset >= len) { 162 | size_t o = bufferOffset; 163 | bufferOffset += len; 164 | return buffer.data() + o; 165 | } 166 | if (buffer.size() < bufferOffset + len) { 167 | buffer.resize(std::max((size_t)(bufferOffset + len), buffer.size() + buffer.size() / 2)); 168 | } 169 | newRead(); 170 | startRead(); 171 | done(); 172 | if (bufferFilled - bufferOffset >= len) { 173 | size_t o = bufferOffset; 174 | bufferOffset += len; 175 | if (bufferOffset == bufferFilled) { 176 | bufferOffset = 0; 177 | bufferFilled = 0; 178 | } 179 | return buffer.data() + o; 180 | } 181 | return nullptr; 182 | } 183 | bool readCopy(void* dst, size_t len) { 184 | void* ptr = readBufferPointer(len); 185 | if (ptr) { 186 | std::memcpy(dst, ptr, len); 187 | return true; 188 | } else { 189 | return false; 190 | } 191 | } 192 | }; 193 | 194 | } // namespace rpc 195 | -------------------------------------------------------------------------------- /src/util.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #pragma once 9 | 10 | #include "intrusive_list.h" 11 | #include "logging.h" 12 | #include "rpc.h" 13 | #include "tensor.h" 14 | 15 | #include 16 | #include 17 | #include 18 | 19 | namespace moolib { 20 | 21 | inline auto seedRng() { 22 | std::random_device dev; 23 | auto start = std::chrono::high_resolution_clock::now(); 24 | std::seed_seq ss( 25 | {(uint32_t)dev(), (uint32_t)dev(), (uint32_t)(std::chrono::high_resolution_clock::now() - start).count(), 26 | (uint32_t)std::chrono::steady_clock::now().time_since_epoch().count(), (uint32_t)dev(), 27 | (uint32_t)std::chrono::system_clock::now().time_since_epoch().count(), 28 | (uint32_t)std::chrono::high_resolution_clock::now().time_since_epoch().count(), 29 | (uint32_t)(std::chrono::high_resolution_clock::now() - start).count(), (uint32_t)dev(), 30 | (uint32_t)(std::chrono::high_resolution_clock::now() - start).count(), (uint32_t)dev(), 31 | (uint32_t)std::hash()(std::this_thread::get_id())}); 32 | return std::mt19937_64(ss); 33 | }; 34 | 35 | inline std::mt19937_64& getRng() { 36 | thread_local std::mt19937_64 rng{seedRng()}; 37 | return rng; 38 | } 39 | 40 | template>* = nullptr> 41 | T random(T min = std::numeric_limits::min(), T max = std::numeric_limits::max()) { 42 | return std::uniform_int_distribution(min, max)(getRng()); 43 | } 44 | 45 | template 46 | float seconds(Duration duration) { 47 | return std::chrono::duration_cast>>(duration).count(); 48 | } 49 | 50 | struct Timer { 51 | std::chrono::steady_clock::time_point start; 52 | Timer() { 53 | reset(); 54 | } 55 | void reset() { 56 | start = std::chrono::steady_clock::now(); 57 | } 58 | float elapsedAt(std::chrono::steady_clock::time_point now) { 59 | return std::chrono::duration_cast>>(now - start).count(); 60 | } 61 | float elapsed() { 62 | return elapsedAt(std::chrono::steady_clock::now()); 63 | } 64 | float elapsedReset() { 65 | auto now = std::chrono::steady_clock::now(); 66 | float r = elapsedAt(now); 67 | start = now; 68 | return r; 69 
| } 70 | }; 71 | 72 | inline std::string randomName() { 73 | std::string r; 74 | for (int i = 0; i != 16; ++i) { 75 | r += "0123456789abcdef"[std::uniform_int_distribution(0, 15)(getRng())]; 76 | } 77 | return r; 78 | } 79 | 80 | inline int getTensorDType(char dtype, int itemsize) { 81 | using rpc::Tensor; 82 | switch (dtype) { 83 | case 'f': 84 | if (itemsize == 2) { 85 | return Tensor::kFloat16; 86 | } else if (itemsize == 4) { 87 | return Tensor::kFloat32; 88 | } else if (itemsize == 8) { 89 | return Tensor::kFloat64; 90 | } else { 91 | throw std::runtime_error("Unexpected itemsize for float"); 92 | } 93 | break; 94 | case 'i': 95 | if (itemsize == 1) { 96 | return Tensor::kInt8; 97 | } else if (itemsize == 2) { 98 | return Tensor::kInt16; 99 | } else if (itemsize == 4) { 100 | return Tensor::kInt32; 101 | } else if (itemsize == 8) { 102 | return Tensor::kInt64; 103 | } else 104 | throw std::runtime_error("Unexpected itemsize for int"); 105 | break; 106 | case 'u': 107 | if (itemsize == 1) { 108 | return Tensor::kUInt8; 109 | } else 110 | throw std::runtime_error("Unexpected itemsize for unsigned int"); 111 | break; 112 | case 'b': 113 | if (itemsize == 1) { 114 | return Tensor::kBool; 115 | } else 116 | throw std::runtime_error("Unexpected itemsize for boolean"); 117 | break; 118 | default: 119 | throw std::runtime_error("Unsupported dtype '" + std::string(1, dtype) + "'"); 120 | } 121 | } 122 | 123 | template 124 | struct Future { 125 | private: 126 | using IT = std::conditional_t, std::nullptr_t, T>; 127 | struct S { 128 | std::optional value; 129 | std::atomic_bool hasValue = false; 130 | }; 131 | std::shared_ptr s; 132 | 133 | public: 134 | Future() { 135 | s = std::make_shared(); 136 | } 137 | void reset() { 138 | *this = Future(); 139 | } 140 | void set() { 141 | s->value.emplace(); 142 | s->hasValue = true; 143 | } 144 | template 145 | void set(T2&& val) { 146 | s->value = std::move(val); 147 | s->hasValue = true; 148 | } 149 | operator bool() const 
noexcept { 150 | return s->hasValue; 151 | } 152 | IT& operator*() { 153 | return *s->value; 154 | } 155 | IT* operator->() { 156 | return &*s->value; 157 | } 158 | }; 159 | 160 | template 161 | Future callImpl(rpc::Rpc& rpc, std::string_view peerName, std::string_view funcName, Args&&... args) { 162 | Future retval; 163 | rpc.asyncCallback( 164 | peerName, funcName, 165 | [retval](T* value, rpc::Error* err) mutable { 166 | if (value) { 167 | if constexpr (!std::is_same_v) { 168 | retval.set(*value); 169 | } else { 170 | retval.set(); 171 | } 172 | } else { 173 | log.error("RPC error: %s\n", err->what()); 174 | } 175 | }, 176 | std::forward(args)...); 177 | return retval; 178 | } 179 | 180 | template 181 | std::string sizesStr(T&& sizes) { 182 | std::string s = "{"; 183 | for (auto& v : sizes) { 184 | if (s.size() > 1) { 185 | s += ", "; 186 | } 187 | using UT = std::decay_t; 188 | if constexpr (std::is_integral_v || std::is_floating_point_v) { 189 | s += std::to_string(v); 190 | } else { 191 | s += sizesStr(v); 192 | } 193 | } 194 | s += "}"; 195 | return s; 196 | } 197 | 198 | template 199 | struct Dtor { 200 | T f; 201 | Dtor(T f) : f(std::move(f)) {} 202 | ~Dtor() { 203 | f(); 204 | } 205 | }; 206 | 207 | struct Identity { 208 | template 209 | constexpr T&& operator()(T&& v) { 210 | return std::forward(v); 211 | } 212 | }; 213 | 214 | } // namespace moolib 215 | -------------------------------------------------------------------------------- /src/vector.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #pragma once 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | namespace moolib { 17 | 18 | template 19 | struct Vector { 20 | T* storagebegin = nullptr; 21 | T* storageend = nullptr; 22 | T* beginptr = nullptr; 23 | T* endptr = nullptr; 24 | size_t msize = 0; 25 | Vector() = default; 26 | Vector(const Vector&) = delete; 27 | Vector(Vector&& n) { 28 | *this = std::move(n); 29 | } 30 | ~Vector() { 31 | if (beginptr != endptr) { 32 | clear(); 33 | } 34 | if (storagebegin) { 35 | std::free(storagebegin); 36 | } 37 | } 38 | Vector& operator=(const Vector&) = delete; 39 | Vector& operator=(Vector&& n) { 40 | std::swap(storagebegin, n.storagebegin); 41 | std::swap(storageend, n.storageend); 42 | std::swap(beginptr, n.beginptr); 43 | std::swap(endptr, n.endptr); 44 | std::swap(msize, n.msize); 45 | return *this; 46 | } 47 | size_t size() { 48 | return msize; 49 | } 50 | T* data() { 51 | return beginptr; 52 | } 53 | T* begin() { 54 | return beginptr; 55 | } 56 | T* end() { 57 | return endptr; 58 | } 59 | T& operator[](size_t index) { 60 | return beginptr[index]; 61 | } 62 | void clear() { 63 | for (auto* i = beginptr; i != endptr; ++i) { 64 | i->~T(); 65 | } 66 | beginptr = storagebegin; 67 | endptr = beginptr; 68 | msize = 0; 69 | } 70 | void move(T* dst, T* begin, T* end) { 71 | if constexpr (std::is_trivially_copyable_v) { 72 | std::memmove((void*)dst, (void*)begin, (end - begin) * sizeof(T)); 73 | } else { 74 | if (dst <= begin) { 75 | for (auto* i = begin; i != end;) { 76 | *dst = std::move(*i); 77 | ++dst; 78 | ++i; 79 | } 80 | } else { 81 | auto* dsti = dst + (end - begin); 82 | for (auto* i = end; i != begin;) { 83 | --dsti; 84 | --i; 85 | *dsti = std::move(*i); 86 | } 87 | } 88 | } 89 | } 90 | void erase(T* begin, T* end) { 91 | for (auto* i = begin; i != end; ++i) { 92 | i->~T(); 93 | } 94 | size_t n = end - begin; 95 | msize -= n; 96 | if (begin == beginptr) { 97 | for (auto* i = begin; i != end; ++i) { 98 | 
i->~T(); 99 | } 100 | beginptr = end; 101 | if (beginptr != endptr) { 102 | size_t unused = beginptr - storagebegin; 103 | if (unused > msize && unused >= 1024 * 512 / sizeof(T)) { 104 | if constexpr (std::is_trivially_copyable_v) { 105 | move(storagebegin, beginptr, endptr); 106 | } else { 107 | auto* sbi = storagebegin; 108 | auto* bi = beginptr; 109 | while (sbi != beginptr && bi != endptr) { 110 | new (sbi) T(std::move(*bi)); 111 | ++sbi; 112 | ++bi; 113 | } 114 | move(sbi, bi, endptr); 115 | for (auto* i = bi; i != endptr; ++i) { 116 | i->~T(); 117 | } 118 | } 119 | beginptr = storagebegin; 120 | endptr = beginptr + msize; 121 | } 122 | } 123 | } else { 124 | move(begin, end, endptr); 125 | for (auto* i = end; i != endptr; ++i) { 126 | i->~T(); 127 | } 128 | endptr -= n; 129 | } 130 | if (beginptr == endptr) { 131 | beginptr = storagebegin; 132 | endptr = beginptr; 133 | } 134 | } 135 | void resize(size_t n) { 136 | if (msize > n) { 137 | T* i = endptr; 138 | T* e = beginptr + n; 139 | while (i != e) { 140 | --i; 141 | i->~T(); 142 | } 143 | } else if (n > msize) { 144 | reserve(n); 145 | T* i = endptr; 146 | T* e = beginptr + n; 147 | while (i != e) { 148 | new (i) T(); 149 | ++i; 150 | } 151 | } 152 | endptr = beginptr + n; 153 | msize = n; 154 | } 155 | bool empty() const { 156 | return beginptr == endptr; 157 | } 158 | size_t capacity() { 159 | return storageend - beginptr; 160 | } 161 | void reserveImpl(size_t n) { 162 | auto* lbegin = beginptr; 163 | auto* lend = endptr; 164 | auto* prevstorage = storagebegin; 165 | size_t msize = this->msize; 166 | T* newptr = (T*)std::aligned_alloc(alignof(T), sizeof(T) * n); 167 | if (!newptr) { 168 | throw std::bad_alloc(); 169 | } 170 | if (prevstorage) { 171 | if constexpr (std::is_trivially_copyable_v) { 172 | std::memcpy(newptr, lbegin, sizeof(T) * msize); 173 | } else { 174 | T* dst = newptr; 175 | for (T* i = lbegin; i != lend; ++i) { 176 | new (dst) T(std::move(*i)); 177 | i->~T(); 178 | ++dst; 179 | } 180 | } 
181 | std::free(prevstorage); 182 | } 183 | storagebegin = newptr; 184 | storageend = newptr + n; 185 | beginptr = newptr; 186 | endptr = newptr + msize; 187 | } 188 | void reserve(size_t n) { 189 | if (n <= capacity()) { 190 | return; 191 | } 192 | reserveImpl(n); 193 | } 194 | void expand() { 195 | reserveImpl(std::max(capacity() * 2, (size_t)16)); 196 | } 197 | void push_back(const T& v) { 198 | emplace_back(v); 199 | } 200 | void push_back(T&& v) { 201 | emplace_back(std::move(v)); 202 | } 203 | template 204 | void emplace_back(Args&&... args) { 205 | if (endptr == storageend) { 206 | if (capacity() != size()) { 207 | __builtin_unreachable(); 208 | } 209 | [[unlikely]]; 210 | expand(); 211 | } 212 | new (endptr) T(std::forward(args)...); 213 | ++endptr; 214 | ++msize; 215 | } 216 | T& front() { 217 | return *beginptr; 218 | } 219 | T& back() { 220 | return endptr[-1]; 221 | } 222 | void pop_back() { 223 | --endptr; 224 | --msize; 225 | endptr->~T(); 226 | } 227 | }; 228 | 229 | } // namespace moolib 230 | -------------------------------------------------------------------------------- /test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | add_executable(test_rpc test_rpc.cc) 4 | target_include_directories(test_rpc PUBLIC "../src") 5 | target_link_libraries(test_rpc moorpc) 6 | 7 | 8 | add_executable(test_multinode_allreduce test_multinode_allreduce.cc) 9 | target_include_directories(test_multinode_allreduce PUBLIC "../src") 10 | target_link_libraries(test_multinode_allreduce moorpc) 11 | 12 | 13 | -------------------------------------------------------------------------------- /test/example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import moolib 7 | 8 | 9 | def printFunction(str): 10 | print(str) 11 | return 42 12 | 13 | 14 | host = moolib.Rpc() 15 | host.set_name("host") 16 | host.define("print", printFunction) 17 | host.listen("127.0.0.1:1234") 18 | 19 | client = moolib.Rpc() 20 | client.connect("127.0.0.1:1234") 21 | 22 | future = client.async_("host", "print", "hello world") 23 | print(future.get()) 24 | 25 | 26 | client.define("sum", sum) 27 | print(host.sync(client.get_name(), "sum", [1, 2, 3, 4], 10)) 28 | -------------------------------------------------------------------------------- /test/integration/test_a2c.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import collections 7 | import pytest 8 | from unittest import mock 9 | 10 | from examples import a2c 11 | 12 | 13 | class TestA2CExample: 14 | @pytest.mark.skip(reason="broken with gym>=0.26 !?") 15 | def test_single_node_training(self, num_steps=40000): 16 | items = collections.deque(maxlen=20) 17 | 18 | def log_func(**kwargs): 19 | items.append(kwargs) 20 | 21 | with mock.patch.object(a2c, "log_to_file", log_func): 22 | a2c.train(num_steps) 23 | 24 | low_return_items = [] 25 | 26 | for index, item in enumerate(items): 27 | max_step_offset = ( 28 | 2 * (len(items) - index) * a2c.BATCH_SIZE * a2c.ROLLOUT_LENGTH 29 | ) 30 | assert item["step"] > num_steps - max_step_offset 31 | if item["mean_episode_return"] < 100: 32 | low_return_items.append((index, item)) 33 | assert -1 < item["entropy_loss"] < 0 34 | 35 | assert len(low_return_items) / len(items) < 0.5 36 | -------------------------------------------------------------------------------- /test/test.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. 
and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include "fmt/printf.h" 14 | 15 | namespace { 16 | struct Timer { 17 | std::chrono::steady_clock::time_point start; 18 | Timer() { 19 | reset(); 20 | } 21 | void reset() { 22 | start = std::chrono::steady_clock::now(); 23 | } 24 | float elapsedAt(std::chrono::steady_clock::time_point now) { 25 | return std::chrono::duration_cast>>(now - start).count(); 26 | } 27 | float elapsed() { 28 | return elapsedAt(std::chrono::steady_clock::now()); 29 | } 30 | float elapsedReset() { 31 | auto now = std::chrono::steady_clock::now(); 32 | float r = elapsedAt(now); 33 | start = now; 34 | return r; 35 | } 36 | }; 37 | 38 | Timer mainTimer; 39 | std::atomic isDone = false; 40 | 41 | float timeoutSeconds = 300; 42 | 43 | std::atomic currentTest = ""; 44 | std::atomic_int pass = 0; 45 | std::atomic_int fail = 0; 46 | std::atomic_int skip = 0; 47 | std::mutex quitMutex; 48 | std::thread timeoutThread; 49 | std::once_flag timeoutOnce; 50 | 51 | void quit() { 52 | if (isDone) { 53 | return; 54 | } 55 | { 56 | std::lock_guard l(quitMutex); 57 | if (isDone) { 58 | return; 59 | } 60 | fmt::printf("Passed: %d\nFailed: %d\nSkipped: %d\n", pass.load(), fail.load(), skip.load()); 61 | fmt::printf("Ran %d tests in %gs\n", pass.load() + fail.load(), mainTimer.elapsed()); 62 | fflush(stdout); 63 | isDone = true; 64 | if (timeoutThread.joinable()) { 65 | if (timeoutThread.get_id() != std::this_thread::get_id()) { 66 | timeoutThread.detach(); 67 | } else { 68 | timeoutThread.join(); 69 | } 70 | } 71 | } 72 | if (fail == 0) { 73 | std::exit(0); 74 | } else { 75 | std::quick_exit(1); 76 | } 77 | } 78 | 79 | void failAt(const char* file, int line, std::string str) { 80 | ++fail; 81 | fmt::printf("\n"); 82 | fmt::printf("Test FAILED: %s at %s:%d\n", currentTest.load(), 
file, line); 83 | fmt::printf("Message: %s\n", str); 84 | fmt::printf("\n"); 85 | fflush(stdout); 86 | quit(); 87 | } 88 | 89 | void passAt(const char* file, int line, std::string str) { 90 | ++pass; 91 | fmt::printf("Passed: %s\n", currentTest.load()); 92 | } 93 | 94 | #define FAIL(x) failAt(__FILE__, __LINE__, x) 95 | #define PASS(x) passAt(__FILE__, __LINE__, x) 96 | #define ASSERT(x) \ 97 | if (!(x)) FAIL(#x) 98 | 99 | #define RUN(x, ...) runTest(#x); 100 | #define RUNARG(x, ...) runTest(#x, __VA_ARGS__); 101 | 102 | void startTimeoutThread() { 103 | std::call_once(timeoutOnce, [&] { 104 | timeoutThread = std::thread([&] { 105 | while (mainTimer.elapsed() < timeoutSeconds && !isDone) { 106 | std::this_thread::sleep_for(std::chrono::seconds(1)); 107 | } 108 | if (!isDone) { 109 | FAIL("Timed out waiting for tests to finish!"); 110 | } 111 | }); 112 | }); 113 | } 114 | 115 | struct Test { 116 | Test() {} 117 | ~Test() {} 118 | }; 119 | 120 | template 121 | void runTest(const char* name, Args&&... args) { 122 | startTimeoutThread(); 123 | currentTest = name; 124 | T obj(std::forward(args)...); 125 | PASS("Passed"); 126 | } 127 | 128 | } // namespace 129 | -------------------------------------------------------------------------------- /test/test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
import moolib
import time
import torch


def main():
    """Exercise moolib RPC: sync/async/callback calls, error paths, host
    restart, and simple throughput benchmarks."""
    localAddr = "127.0.0.1:4411"
    # localAddr = "shm://testtest"

    client = moolib.Rpc()
    host = moolib.Rpc()

    client.set_name("client")
    client.set_timeout(1)

    def client_hello(*args):
        print("client got hello:", (*args,))

    client.define("client hello", client_hello)

    def hello(message):
        print("Got hello: ", message)
        return "this is a response to message '" + message + "'"

    host.define("hello", hello)

    host.set_name("host")
    host.listen(localAddr)

    client.connect(localAddr)

    future = client.async_("host", "hello", "this is a message from client")
    response = future.result()
    print("Got response: ", response)

    print("sync: ", client.sync("host", "hello", "sync test"))

    if True:

        def helloCallback(response, error):
            print("Callback response: ", response)
            print("Callback error: ", error)
            assert error is None

        client.async_callback(
            "host", "hello", helloCallback, "this is a message through async_callback"
        )

    # BUG FIX: the negative tests below originally raised AssertionError
    # *inside* the try, so the same `except Exception` swallowed the test
    # failure itself. try/except/else keeps the failure visible.
    try:
        future = client.async_("nowhere", "hello", "this is a message to nowhere")
        response = future.result()
    except Exception as e:
        print(e)
    else:
        print("Response from nowhere: ", response)
        raise AssertionError()

    try:
        client.sync("host", "non-existant function")
    except Exception as e:
        print(e)
    else:
        raise AssertionError()

    del host

    try:
        client.sync("host", "hello", "is host dead?")
    except Exception as e:
        print(e)
    else:
        raise AssertionError()

    # Restart the host; the client should transparently reconnect.
    host = moolib.Rpc()
    host.define("hello", hello)
    host.set_name("host")
    host.listen(localAddr)

    client.set_timeout(30)
    print(client.sync("host", "hello", "is host alive?"))

    weights = torch.randn(4096, 4096)

    def linear(input):
        return (weights * input).sum(-1)

    def noop():
        pass

    host.define("linear", linear)
    host.define("noop", noop)

    client.set_timeout(60)

    input = torch.randn(16, 4096)

    client.sync("host", "linear", input.unsqueeze(1))

    # Warm up before benchmarking.
    for _ in range(128):
        client.sync("host", "noop")

    for _ in range(4):
        futures = []
        start = time.time()
        for _ in range(10000):
            futures.append(client.async_("host", "noop"))
        for i in futures:
            i.result()
        t = time.time() - start
        print("noop x%d time %g (%g/s)" % (len(futures), t, len(futures) / t))

    for _ in range(2):
        start = time.time()
        local_result = sum(linear(input[i]).sum().item() for i in range(input.size(0)))
        print("base time ", time.time() - start)

    for _ in range(4):
        futures = []
        start = time.time()
        for i in range(input.size(0)):
            futures.append(client.async_("host", "linear", input[i]))
        result = sum(i.result().sum().item() for i in futures)
        print("async time ", time.time() - start)
        assert abs(result - local_result) < 0.1

    host.debug_info()
    client.debug_info()


main()

# -------- test/test_asyncio.py --------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import moolib
import time
import torch

import asyncio


async def main():
    """Same scenario as test/test.py, but driving moolib futures with
    asyncio `await` instead of blocking .result() calls."""
    localAddr = "127.0.0.1:4411"
    # localAddr = "shm://testtest"

    client = moolib.Rpc()
    host = moolib.Rpc()

    client.set_name("client")
    client.set_timeout(1)

    def client_hello(*args):
        print("client got hello:", (*args,))

    client.define("client hello", client_hello)

    def hello(message):
        print("Got hello: ", message)
        return "this is a response to message '" + message + "'"

    host.define("hello", hello)

    host.set_name("host")
    host.listen(localAddr)

    client.connect(localAddr)

    # moolib futures are awaitable inside asyncio.
    foo = client.async_("host", "hello", "this is a message from client")
    response = await foo
    print("Got response: ", response)

    print("sync: ", client.sync("host", "hello", "sync test"))

    if True:

        def helloCallback(response, error):
            print("Callback response: ", response)
            print("Callback error: ", error)
            assert error is None

        client.async_callback(
            "host", "hello", helloCallback, "this is a message through async_callback"
        )

    # BUG FIX: these negative tests raised AssertionError inside the same
    # `try`, so `except Exception` swallowed the test failure itself.
    # try/except/else keeps the failure visible.
    try:
        response = await client.async_(
            "nowhere", "hello", "this is a message to nowhere"
        )
    except Exception as e:
        print(e)
    else:
        print("Response from nowhere: ", response)
        raise AssertionError()

    try:
        client.sync("host", "non-existant function")
    except Exception as e:
        print(e)
    else:
        raise AssertionError()

    del host

    try:
        client.sync("host", "hello", "is host dead?")
    except Exception as e:
        print(e)
    else:
        raise AssertionError()

    host = moolib.Rpc()
    host.define("hello", hello)
    host.set_name("host")
    host.listen(localAddr)

    client.set_timeout(30)
    print(client.sync("host", "hello", "is host alive?"))

    weights = torch.randn(4096, 4096)

    def linear(input):
        return (weights * input).sum(-1)

    def noop():
        pass

    host.define("linear", linear)
    host.define("noop", noop)

    client.set_timeout(60)

    input = torch.randn(16, 4096)

    client.sync("host", "linear", input.unsqueeze(1))

    for _ in range(128):
        client.sync("host", "noop")

    for _ in range(4):
        futures = []
        start = time.time()
        for _ in range(10000):
            futures.append(client.async_("host", "noop"))
        for i in futures:
            await i
        t = time.time() - start
        print("noop x%d time %g (%g/s)" % (len(futures), t, len(futures) / t))

    for _ in range(2):
        start = time.time()
        local_result = sum(linear(input[i]).sum().item() for i in range(input.size(0)))
        print("base time ", time.time() - start)

    for _ in range(4):
        futures = []
        start = time.time()
        for i in range(input.size(0)):
            futures.append(client.async_("host", "linear", input[i]))
        result = 0
        for i in futures:
            result += (await i).sum().item()
        # Answering the original "why does this not work?" note: a generator
        # expression containing `await` is an *async* generator, and builtin
        # sum() only accepts synchronous iterables, so
        # `sum((await i).sum().item() for i in futures)` raises TypeError.
        print("async time ", time.time() - start)
        assert abs(result - local_result) < 0.1

    host.debug_info()
    client.debug_info()


try:
    asyncio.run(main())
except:
    import traceback

    traceback.print_exc()

# -------- test/test_asyncio_queue.py --------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import moolib
import time
import torch

import asyncio


async def process(queue, callback):
    """Drain a moolib call queue forever, dispatching each request to
    `callback` and sending the result back via the per-call return callback.

    Stops quietly on task cancellation; any other error is logged and
    re-raised.
    """
    try:
        while True:
            return_callback, args, kwargs = await queue
            if args and kwargs:
                retval = callback(*args, **kwargs)
            elif args:
                retval = callback(*args)
            elif kwargs:
                retval = callback(**kwargs)
            else:
                retval = callback()
            return_callback(retval)
    except asyncio.CancelledError:
        print("process cancelled")
    except Exception as e:
        print(e)
        raise


async def main():
    """Exercise queue-based RPC handlers (define_queue) alongside the
    deferred-callback style, including timeout behavior and benchmarks."""
    loop = asyncio.get_running_loop()

    localAddr = "127.0.0.1:4411"
    # localAddr = "shm://testtest"

    client = moolib.Rpc()
    host = moolib.Rpc()

    client.set_name("client")
    client.set_timeout(1)

    def client_hello(*args):
        print("client got hello:", (*args,))

    client.define("client hello", client_hello)

    def hello(message):
        print("Got hello: ", message)
        return "this is a response to message '" + message + "'"

    def hello_deferred(callback, message):
        print("Got hello (deferred): ", message)
        callback("this is the deferred response to " + message)

    host.define_deferred("hello deferred", hello_deferred)

    def wrap_define(self, name, func):
        # Serve `name` from an asyncio task draining its call queue.
        loop.create_task(process(self.define_queue(name), func))

    wrap_define(host, "hello", hello)

    host.set_name("host")
    host.listen(localAddr)

    client.connect(localAddr)

    print(client.sync("host", "hello deferred", "sync test"))
    print(client.sync("host", "hello deferred", message="named argument"))

    foo = client.async_("host", "hello", "this is a message from client")
    response = await foo
    print("Got response: ", response)

    try:
        # sync blocks this event loop, so the queue-backed "hello" handler
        # can never run and the call must time out with a RuntimeError.
        # (AssertionError is not a RuntimeError, so a missing timeout is
        # correctly reported.)
        print("sync: ", client.sync("host", "hello", "sync test"))
        raise AssertionError()
    except RuntimeError as e:
        print(e)

    print("async: ", await client.async_("host", "hello", "async test"))

    weights = torch.randn(4096, 4096)

    def linear(input):
        return (weights * input).sum(-1)

    def noop():
        pass

    wrap_define(host, "linear", linear)
    wrap_define(host, "noop", noop)

    client.set_timeout(60)

    input = torch.randn(16, 4096)

    await client.async_("host", "linear", input.unsqueeze(1))

    # Warm up before benchmarking.
    for _ in range(128):
        await client.async_("host", "noop")

    for _ in range(4):
        futures = []
        start = time.time()
        for _ in range(10000):
            futures.append(client.async_("host", "noop"))
        for i in futures:
            await i
        t = time.time() - start
        print("noop x%d time %g (%g/s)" % (len(futures), t, len(futures) / t))

    for _ in range(2):
        start = time.time()
        local_result = sum(linear(input[i]).sum().item() for i in range(input.size(0)))
        print("base time ", time.time() - start)

    for _ in range(4):
        futures = []
        start = time.time()
        for i in range(input.size(0)):
            futures.append(client.async_("host", "linear", input[i]))
        result = 0
        for i in futures:
            result += (await i).sum().item()
        # sum() over a genexp with `await` would be an async generator --
        # not consumable by sum() -- hence the explicit loop.
        print("async time ", time.time() - start)
        assert abs(result - local_result) < 0.1

    host.debug_info()
    client.debug_info()


try:
    asyncio.run(main())
except:
    import traceback

    traceback.print_exc()

# -------- test/test_batch.py --------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import moolib
import torch

import asyncio


async def process(queue, callback):
    """Drain a moolib call queue, dispatching each (possibly batched)
    request to `callback` and returning the result via the return callback."""
    try:
        while True:
            return_callback, args, kwargs = await queue
            if args and kwargs:
                retval = callback(*args, **kwargs)
            elif args:
                retval = callback(*args)
            elif kwargs:
                retval = callback(**kwargs)
            else:
                retval = callback()
            return_callback(retval)
    except asyncio.CancelledError:
        print("process cancelled")
    except Exception as e:
        print(e)
        raise


async def main():
    """Verify batch_size-based request batching across the three handler
    styles: direct define, define_deferred, and define_queue."""
    loop = asyncio.get_running_loop()

    localAddr = "127.0.0.1:4411"
    # localAddr = "shm://testtest"

    host = moolib.Rpc()

    bs = 8

    def hello(message, tensor):
        # `tensor` arrives batched along dim 0 (up to `bs` requests).
        print("Got hello: ", message, tensor.shape)
        for i in range(tensor.size(0)):
            print("tensor[%d].sum() is %g" % (i, tensor[i].sum().item()))
        r = tensor.flatten(1).sum(1)
        print("returning ", r.shape)
        return "this is a response to message '" + message + "'", r

    host.define("hello", hello, batch_size=bs)

    def hello_deferred(callback, message, tensor):
        callback(hello(message, tensor))

    host.define_deferred("hello deferred", hello_deferred, batch_size=bs)

    def wrap_define(self, name, func, batch_size=None):
        return loop.create_task(
            process(self.define_queue(name, batch_size=batch_size), func)
        )

    wrap_define(host, "hello queue", hello, batch_size=bs)

    host.set_name("host")
    host.listen(localAddr)

    clients = []
    for _ in range(40):
        client = moolib.Rpc()
        client.set_timeout(10)
        client.connect(localAddr)
        clients.append(client)

    # Round-robin over the three handler styles.
    for i in range(21):
        futures = []
        fn = ["hello", "hello deferred", "hello queue"][i % 3]
        for c in clients:
            t = torch.randn(2, 3)
            print("calling with tensor sum %g" % t.sum().item())
            futures.append(c.async_("host", fn, "wee " + fn, t))

        for f in futures:
            try:
                print(await f)
            except Exception as e:
                print(e)
                raise

    # host.debug_info()


try:
    asyncio.run(main())
except:
    import traceback

    traceback.print_exc()

# -------- test/test_dynamic_batching_queue.py --------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
import time
import traceback

import torch

import moolib


async def process(que, callback):
    """Serve requests from a moolib queue by invoking `callback`."""
    try:
        while True:
            ret_cb, args, kwargs = await que
            if args and kwargs:
                ret = callback(*args, **kwargs)
            elif args:
                ret = callback(*args)
            elif kwargs:
                ret = callback(**kwargs)
            else:
                ret = callback()
            ret_cb(ret)
    except asyncio.CancelledError:
        print("[Server] process cancelled")
    except Exception as e:
        print(e)
        raise


async def main():
    """Check dynamic batching correctness against an unbatched baseline,
    then benchmark both paths."""
    addr = "127.0.0.1:4411"
    timeout = 60

    num_tests = 200
    num_benchmarks = 10000
    dim = 128
    linear = torch.nn.Linear(dim, dim)

    loop = asyncio.get_running_loop()

    server = moolib.Rpc()
    server.set_name("server")
    server.set_timeout(timeout)

    verbose = True

    def run_linear(x, info):
        if verbose:
            print(f"[Linear] batch_size = {x.size()[0]}")

        with torch.no_grad():
            return linear(x), info

    loop.create_task(process(server.define_queue("linear"), run_linear))
    loop.create_task(
        process(
            server.define_queue("batch_linear", batch_size=100, dynamic_batching=True),
            run_linear,
        )
    )
    server.listen(addr)

    client = moolib.Rpc()
    client.set_name("client")
    client.set_timeout(timeout)
    client.connect(addr)

    x_list = [torch.randn(dim) for _ in range(num_tests)]
    y_list = [linear(x) for x in x_list]

    # Correctness: batched results must match the direct computation and
    # the per-call `info` payload must round-trip intact.
    futs = []
    for i, x in enumerate(x_list):
        futs.append(
            client.async_("server", "batch_linear", x, info=dict(index=[i, i + 1]))
        )
    for i, fut in enumerate(futs):
        y, info = await fut
        assert torch.allclose(y, y_list[i], rtol=1e-5, atol=1e-6)
        assert len(info) == 1
        assert info["index"] == [i, i + 1]

    # Benchmark: unbatched vs. dynamically batched throughput.
    verbose = False
    futs1 = []
    futs2 = []

    t0 = time.time()
    for _ in range(num_benchmarks):
        futs1.append(client.async_("server", "linear", x_list[0], 0))
    for fut in futs1:
        await fut
    t1 = time.time()
    for _ in range(num_benchmarks):
        futs2.append(client.async_("server", "batch_linear", x_list[0], 0))
    for fut in futs2:
        await fut
    t2 = time.time()

    print(f"[Benchmark] without batching time: {t1 - t0} seconds")
    print(f"[Benchmark] dynamic batching time: {t2 - t1} seconds")


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except:
        traceback.print_exc()

# -------- test/test_group.py --------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import moolib
import time

moolib.set_log_level("info")


class Client:
    """A small RPC peer that joins a broker-managed group and greets the
    other members whenever the membership changes."""

    def __init__(self, broker_addr, index):
        self.rpc = moolib.Rpc()
        self.name = "client %d" % index
        self.rpc.set_name(self.name)
        self.rpc.set_timeout(2)
        self.rpc.define("hello", self.hello)
        self.rpc.connect(broker_addr)
        self.group = None
        self.index = index

    def hello(self, message):
        # Parameter was originally named `str`, shadowing the builtin;
        # renamed (callers invoke it positionally over RPC).
        print("%s received hello: %s" % (self.name, message))

    def update(self):
        if self.group is None:
            # Lazily join the group on the first update.
            self.group = moolib.Group(self.rpc, "test group")
            self.group.set_timeout(1)
            # Negative index => members listed in reverse creation order.
            self.group.set_sort_order(-self.index)
        else:
            updated = self.group.update()

            if updated:
                print(
                    "group '%s', sync id %#x, members %s"
                    % (self.group.name(), self.group.sync_id(), self.group.members())
                )
                for n in self.group.members():
                    if n != self.name:
                        self.rpc.async_(n, "hello", "hello from " + self.name)


def main():
    localAddr = "127.0.0.1:4411"
    # localAddr = "shm://testtest"

    broker_rpc = moolib.Rpc()
    broker_rpc.set_name("broker")
    broker = moolib.Broker(broker_rpc)
    broker_rpc.listen(localAddr)

    clients = []
    for i in range(4):
        clients.append(Client(localAddr, i))

    for _ in range(10):
        for client in clients:
            client.update()

        broker.update()

        time.sleep(0.1)

    # Sort order is -index, so membership is listed newest-first.
    assert clients[0].group.members() == [
        "client 3",
        "client 2",
        "client 1",
        "client 0",
    ]

    del clients[2]

    for _ in range(20):
        for client in clients:
            client.update()

        broker.update()

        time.sleep(0.1)

    # The dropped peer ("client 2") must have timed out of the group.
    assert clients[0].group.members() == ["client 3", "client 1", "client 0"]


main()

# -------- test/test_multinode_allreduce.cc --------
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #include "test.h" 9 | 10 | #include "fmt/printf.h" 11 | #include "pytorch.h" 12 | #include "rpc.h" 13 | 14 | #include 15 | 16 | template 17 | struct AllReduce { 18 | rpc::Rpc rpc; 19 | 20 | size_t myRank = 0; 21 | size_t worldSize = 0; 22 | 23 | torch::Tensor localData; 24 | std::atomic chunksDone = 0; 25 | std::atomic syncCount = 0; 26 | std::atomic_int calls = 0; 27 | void reduce(size_t sourcePeer, size_t offset, torch::Tensor data) { 28 | ++calls; 29 | size_t nextPeer = myRank == worldSize - 1 ? 0 : myRank + 1; 30 | auto ldata = localData.narrow(0, offset, data.size(0)); 31 | ldata += data; 32 | if (sourcePeer == nextPeer) { 33 | for (size_t i = 0; i != worldSize; ++i) { 34 | if (i != myRank) { 35 | rpc.async(std::to_string(i), "share", offset, ldata); 36 | } 37 | } 38 | ++chunksDone; 39 | // fmt::printf("reduce: I have all data from %d for %d + %d (chunksDone is now %d)\n", sourcePeer, offset, 40 | // data.size(0), chunksDone.load()); 41 | } else { 42 | rpc.async(std::to_string(nextPeer), "reduce", sourcePeer, offset, ldata); 43 | } 44 | } 45 | void share(size_t offset, torch::Tensor data) { 46 | ++calls; 47 | localData.narrow(0, offset, data.size(0)) = data; 48 | ++chunksDone; 49 | // fmt::printf("share: I got share data for %d + %d (chunksDone is now %d)\n", offset, data.size(0), 50 | // chunksDone.load()); 51 | } 52 | 53 | void synchronize() { 54 | static std::atomic_int32_t counter = 0; 55 | for (size_t i = 0; i != worldSize; ++i) { 56 | if (i != myRank) { 57 | rpc.async(std::to_string(i), "sync", std::string(rpc.getName()) + "-" + std::to_string(++counter)); 58 | } 59 | } 60 | while (syncCount < worldSize - 1) { 61 | 
std::this_thread::sleep_for(std::chrono::milliseconds(2)); 62 | } 63 | syncCount -= worldSize - 1; 64 | fmt::printf("Synchronized!\n"); 65 | } 66 | 67 | AllReduce(size_t worldSize, size_t rank, std::string masterAddr) : myRank(rank), worldSize(worldSize) { 68 | 69 | rpc.define( 70 | "reduce", 71 | [this](size_t sourcePeer, size_t offset, torch::Tensor data) { reduce(sourcePeer, offset, std::move(data)); }); 72 | rpc.define( 73 | "share", [this](size_t offset, torch::Tensor data) { share(offset, std::move(data)); }); 74 | rpc.define("sync", [this](std::string id) { ++syncCount; }); 75 | 76 | rpc.setName(std::to_string(rank)); 77 | 78 | if (rank == 0) { 79 | rpc.listen(masterAddr); 80 | } else { 81 | rpc.connect(masterAddr); 82 | } 83 | 84 | for (int i = 0; i < 20; i += 1) { 85 | 86 | chunksDone = 0; 87 | 88 | size_t dataSize = 400 + 1024 * 128 * i; 89 | 90 | localData = torch::randn({(int64_t)dataSize}, device); 91 | 92 | synchronize(); 93 | 94 | Timer tx; 95 | 96 | size_t remainingData = dataSize; 97 | size_t offset = 0; 98 | for (size_t i = 0; i != worldSize; ++i) { 99 | size_t div = (worldSize - i); 100 | size_t chunkSize = (remainingData + div - 1) / div; 101 | remainingData -= chunkSize; 102 | 103 | if (i == myRank) { 104 | size_t nextPeer = i == worldSize - 1 ? 
0 : i + 1; 105 | rpc.async(std::to_string(nextPeer), "reduce", i, offset, localData.narrow(0, offset, chunkSize)); 106 | } 107 | 108 | offset += chunkSize; 109 | } 110 | 111 | ASSERT(remainingData == 0); 112 | 113 | while (chunksDone != worldSize) { 114 | std::this_thread::sleep_for(std::chrono::milliseconds(2)); 115 | } 116 | fmt::printf("Local AllReduce done\n"); 117 | 118 | float time = tx.elapsed(); 119 | 120 | // for (int i = 0; i != 10; ++i) { 121 | // ASSERT(chunksDone == worldSize); 122 | // ASSERT(calls == (worldSize - 1) * 2); 123 | // std::this_thread::sleep_for(std::chrono::milliseconds(i)); 124 | // } 125 | synchronize(); 126 | 127 | // for (size_t i = 0; i != dataSize; ++i) { 128 | // uint64_t sum = 0; 129 | // for (auto& v : originalLocalData) { 130 | // sum += v[i]; 131 | // } 132 | // for (auto& v : finalData) { 133 | // ASSERT(sum == v[i]); 134 | // } 135 | // } 136 | 137 | ASSERT(chunksDone == worldSize); 138 | 139 | fmt::printf("AllReduce %d done! Sum %g\n", i, localData.sum().template item()); 140 | 141 | int thiscalls = calls; 142 | calls = 0; 143 | 144 | fmt::printf("AllReduce %gs, %gM/s, calls: %d\n", time, (dataSize / time) / 1024 / 1024, thiscalls); 145 | } 146 | 147 | rpc.debugInfo(); 148 | } 149 | }; 150 | 151 | using AllReduceCpu = AllReduce; 152 | using AllReduceCuda = AllReduce; 153 | 154 | #include 155 | 156 | int main() { 157 | struct sigaction act; 158 | act.sa_handler = SIG_IGN; 159 | sigaction(SIGPIPE, &act, NULL); 160 | 161 | auto env = [&](const char* name) { 162 | const char* value = std::getenv(name); 163 | if (!value) { 164 | fmt::printf("Required env var %s not set\n", name); 165 | std::exit(-1); 166 | } 167 | return value; 168 | }; 169 | 170 | int worldSize = std::atoi(env("WORLD_SIZE")); 171 | int rank = std::atoi(env("RANK")); 172 | std::string masterAddr = std::string(env("MASTER_ADDR")) + ":" + std::string(env("MASTER_PORT")); 173 | 174 | fmt::printf("World size: %d\nRank: %d\nMaster address: %s\n", worldSize, rank, 
masterAddr); 175 | fflush(stdout); 176 | 177 | RUNARG(AllReduceCpu, worldSize, rank, masterAddr); 178 | // RUNARG(AllReduceCuda, worldSize, rank, masterAddr); 179 | 180 | quit(); 181 | return 0; 182 | } 183 | -------------------------------------------------------------------------------- /test/test_reduce.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import moolib 7 | import time 8 | import torch 9 | import random 10 | 11 | moolib.set_log_level("debug") 12 | 13 | current_reduce_size = 0 14 | current_reduce_done = 0 15 | current_reduce_sum = 0 16 | 17 | 18 | class Client: 19 | def __init__(self, broker_addr, index): 20 | self.rpc = moolib.Rpc() 21 | self.rpc.set_name("client %d" % index) 22 | self.rpc.set_timeout(20) 23 | self.rpc.connect(broker_addr) 24 | self.group = None 25 | self.index = index 26 | self.reduce = None 27 | self.sum = None 28 | self.reduce_counter = 0 29 | 30 | def update(self): 31 | 32 | if self.group is None: 33 | self.group = moolib.Group(self.rpc, "test group") 34 | self.group.set_timeout(1) 35 | self.group.set_sort_order(self.index) 36 | elif random.random() < 0.5 or True: 37 | updated = self.group.update() 38 | 39 | if updated: 40 | print( 41 | "group '%s', sync id %#x, members %s" 42 | % (self.group.name(), self.group.sync_id(), self.group.members()) 43 | ) 44 | 45 | if self.group.active(): 46 | if self.reduce is None: 47 | self.start_reduce() 48 | else: 49 | result = None 50 | 51 | global current_reduce_size 52 | global current_reduce_done 53 | global current_reduce_sum 54 | 55 | try: 56 | if self.reduce.done(): 57 | # although the result is available in self.tensor, 58 | # calling .result() here is necessary to raise any errors 59 | result = self.reduce.result() 60 | except RuntimeError as e: 61 | 
current_reduce_size = 0 62 | current_reduce_done = 0 63 | print(e) 64 | self.reduce = None 65 | 66 | if result is not None: 67 | self.sum = result.sum().item() 68 | print("reduced to sum ", self.sum) 69 | self.reduce = None 70 | if current_reduce_size == current_reduce_done: 71 | current_reduce_done = 1 72 | current_reduce_size = len(self.group.members()) 73 | current_reduce_sum = self.sum 74 | print("New reduction of size %d" % current_reduce_size) 75 | else: 76 | current_reduce_done += 1 77 | if abs(self.sum - current_reduce_sum) > 0.01: 78 | raise RuntimeError( 79 | "Reduce sum mismatch, got %g, expected %g" 80 | % (self.sum, current_reduce_sum) 81 | ) 82 | print( 83 | "Reduce %d/%d done" 84 | % (current_reduce_done, current_reduce_size) 85 | ) 86 | 87 | def start_reduce(self): 88 | self.tensor = torch.randn(64, 64) 89 | print("input sum ", self.tensor.sum().item()) 90 | self.reduce = self.group.all_reduce("test reduce", self.tensor) 91 | 92 | 93 | def main(): 94 | 95 | localAddr = "127.0.0.1:4411" 96 | 97 | broker_rpc = moolib.Rpc() 98 | broker_rpc.set_name("broker") 99 | broker = moolib.Broker(broker_rpc) 100 | broker_rpc.listen(localAddr) 101 | 102 | clients = [] 103 | for i in range(4): 104 | clients.append(Client(localAddr, i)) 105 | 106 | for _ in range(30): 107 | 108 | for client in clients: 109 | client.update() 110 | 111 | broker.update() 112 | 113 | time.sleep(0.1) 114 | 115 | # assert clients[0].group.members() == ["client 0", "client 1", "client 2", "client 3"] 116 | # del clients[1] 117 | 118 | for _ in range(300): 119 | 120 | for client in clients: 121 | client.update() 122 | 123 | broker.update() 124 | 125 | time.sleep(0.1) 126 | 127 | print("All done") 128 | 129 | 130 | main() 131 | -------------------------------------------------------------------------------- /test/test_reduce_asyncio.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import moolib
import time
import torch

import asyncio

moolib.set_log_level("verbose")

localAddr = "127.0.0.1:4411"
# localAddr = "shm://testest"

# NOTE(review): asyncio.get_event_loop() at import time is deprecated since
# Python 3.10; kept for behavioral compatibility with run_until_complete below.
loop = asyncio.get_event_loop()

terminate = False


class Group:
    """Thin asyncio-friendly wrapper around moolib.Group."""

    def __init__(self, rpc, group_name):
        self.group = moolib.Group(rpc, group_name)
        self.group.set_timeout(2)
        self.rpc = rpc
        self.my_name = rpc.get_name()
        self.group_name = group_name
        self.members = []

    def update(self):
        if self.group.update():
            self.members = self.group.members()

    def active(self):
        return self.group.active()

    async def wait_for_active(self):
        while not self.active():
            await asyncio.sleep(0.25)

    async def all_reduce(self, name, tensor):
        return await self.group.all_reduce(name, tensor)


async def keepalive(obj):
    """Periodically call obj.update() until cancelled."""
    try:
        while True:
            obj.update()
            await asyncio.sleep(0.25)
    except asyncio.CancelledError:
        print("keepalive cancelled")


inputsum = {}
reducesum = None


async def client(index, n_clients):
    rpc = moolib.Rpc()
    rpc.set_name("client %d" % index)
    rpc.set_timeout(2)
    rpc.connect(localAddr)

    group = Group(rpc, "test group")

    group_keepalive = loop.create_task(keepalive(group))

    my_name = rpc.get_name()

    global inputsum
    global reducesum

    n = 0

    while not terminate:
        try:
            await group.wait_for_active()

            tensor = torch.randn(64, 64)
            print("input sum ", tensor.sum().item())
            inputsum[index] = tensor.sum().item()
            try:
                start = time.time()
                localsum = await group.all_reduce("test reduce", tensor)
                print("allreduce took %g" % (time.time() - start))
            except RuntimeError as e:
                print(e)
                continue

            mysum = localsum.sum().item()
            print("reduce %d done -> sum %g" % (index, mysum))

            if len(inputsum) == n_clients:
                # Last client of this round: compute the ground truth from
                # everyone's inputs and publish it.
                actualsum = sum(v for k, v in inputsum.items())
                for k, v in inputsum.items():
                    print(k, v)
                inputsum = {}

                if abs(mysum - actualsum) > 1e-2:
                    raise RuntimeError(
                        "sum mismatch: my sum is %g, real sum is %g"
                        % (mysum, actualsum)
                    )
                reducesum = actualsum
            elif abs(mysum - reducesum) > 1e-2:
                # BUG FIX: this branch formatted its message with
                # `actualsum`, which is undefined (or stale) here, so a
                # mismatch raised NameError instead of the diagnostic.
                # Compare against the published `reducesum`.
                raise RuntimeError(
                    "sum mismatch: my sum is %g, should be %g" % (mysum, reducesum)
                )

            n += 1

            # await asyncio.sleep(2)

        except asyncio.CancelledError:
            print("client cancelled!")
            break

    # BUG FIX: `del group_keepalive` only dropped the reference; the task
    # kept running and was never awaited. Cancel it explicitly so the loop
    # can shut down cleanly.
    group_keepalive.cancel()
    print(my_name, "normal exit :)")


async def broker():
    broker_rpc = moolib.Rpc()
    broker_rpc.set_name("broker")
    broker = moolib.Broker(broker_rpc)
    broker_rpc.listen(localAddr)
    t = time.time()
    while not terminate:
        now = time.time()
        print("Time since last broker update: %g" % (now - t))
        t = now
        broker.update()
        await asyncio.sleep(0.25)
    print("broker done!")


async def wait(tasks, timeout):
    """Wait until any task completes or the timeout expires; re-raise errors
    from completed tasks and return the still-pending ones."""
    done, pending = await asyncio.wait(
        tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
    )
    for i in done:
        i.result()
    return pending


async def main():
    global terminate

    broker_task = loop.create_task(broker())
    await wait([broker_task], 0.25)

    n_clients = 11

    clients = []
    for i in range(n_clients):
        clients.append(loop.create_task(client(i, n_clients)))

    await wait([broker_task, *clients], 45)
    terminate = True
    for i in clients:
        i.cancel()
    await asyncio.gather(broker_task, *clients)

    print("All done")


try:
    loop.run_until_complete(main())
except:
    import traceback

    traceback.print_exc()

# -------- test/unit/test_batcher.py --------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import random

import torch
import moolib


class TestMoolibBatcher:
    def test_batcher(self):
        """Fuzz moolib.Batcher against torch.stack / torch.cat references."""
        for _ in range(256):
            size = random.randint(1, 20)
            dim = random.randint(0, 2)
            dims = random.randint(dim + 1, dim + 2)

            n = random.randint(20, 100)
            shape = [random.randint(1, 4) for _ in range(dims)]

            # stack mode: output must equal torch.stack along `dim`.
            batcher = moolib.Batcher(size=size, dim=dim)

            inputs = []
            for _ in range(n):
                input = torch.randn(shape)
                inputs.append(input.clone())
                batcher.stack(input)
                assert input.equal(inputs[-1])  # input must not be mutated
                if not batcher.empty():
                    batched = batcher.get()
                    stacked = torch.stack(inputs, dim=dim)
                    assert batched.equal(stacked)
                    inputs = []

            # cat mode: output is the first `size` slices along `dim`; any
            # overflow carries into the next batch.
            batcher = moolib.Batcher(size=size, dim=dim)

            inputs = []
            for _ in range(n):
                input = torch.randn(shape)
                inputs.append(input.clone())
                batcher.cat(input)
                assert input.equal(inputs[-1])
                if not batcher.empty():
                    batched = batcher.get()
                    catted = torch.cat(inputs, dim=dim)
                    overflow = catted.narrow(dim, size, catted.size(dim) - size)
                    catted = catted.narrow(dim, 0, size)
                    assert batched.equal(catted)
                    inputs = []
                    if overflow.size(dim) > 0:
                        inputs.append(overflow)

# -------- test/unit/test_broker.py --------
# -----------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from unittest import mock

import moolib.broker


class TestBrokerScript:
    """Unit tests for the ``moolib.broker`` console-script module."""

    def test_has_main(self):
        # The module must expose a ``main`` entry point for the CLI wrapper.
        assert hasattr(moolib.broker, "main")

    @mock.patch("sys.argv", ["broker"])
    def test_run(self):
        # Replace moolib.Broker so main() runs without any networking.
        with mock.patch("moolib.Broker") as broker_cls:
            broker = broker_cls.return_value
            # Three successful update() calls, then break out of main's loop.
            broker.update.side_effect = [None] * 3 + [KeyboardInterrupt("Enough")]

            moolib.broker.main()

            broker_cls.assert_called_once()
            assert broker.update.call_count == 4


# -----------------------------------------------------------------------------
# test/unit/test_envpool.py
# -----------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import random
import pytest

import torch
import moolib


class Env:
    """Minimal deterministic environment used to exercise moolib.EnvPool."""

    def __init__(self):
        # Observation tensor; populated by reset().
        self.n = None

    def step(self, action):
        # Action semantics: 0 = no-op, 1 = double, 2 = halve; otherwise error.
        if action == 1:
            self.n *= 2
        elif action == 2:
            self.n /= 2
        elif action != 0:
            raise RuntimeError("bad action")
        total = self.n.sum()
        # done when the state has decayed below 1; reward is distance from 4.
        return {"n": self.n}, abs(total - 4), total < 1

    def reset(self):
        fresh = torch.ones(4, 4)
        fresh[0][0] = 4.0
        fresh[3][1] = 0.5
        fresh[1][2] = 0.25
        self.n = fresh
        return {"n": self.n}


class TestMoolibpEnvPool:
    def test_envpool(self):
        """Batched stepping through EnvPool must track a local replica of Env."""
        bs = 32
        envs = moolib.EnvPool(Env, batch_size=bs, num_batches=2, num_processes=4)

        # A float action tensor and a wrongly-sized long tensor must both fail.
        with pytest.raises(RuntimeError):
            envs.step(0, torch.zeros(bs))
        with pytest.raises(RuntimeError):
            envs.step(0, torch.zeros(bs + 1).long())

        idle = torch.zeros(bs).long()

        # The observation every env returns right after reset().
        initial = torch.ones(4, 4)
        initial[0][0] = 4.0
        initial[3][1] = 0.5
        initial[1][2] = 0.25
        initial = initial.expand(bs, 4, 4)

        obs = envs.step(batch_index=0, action=idle).result()
        assert obs["n"].equal(initial)
        fut0 = envs.step(batch_index=0, action=idle + 1)
        fut1 = envs.step(batch_index=1, action=idle)
        assert fut0.result()["n"].equal(initial * 2)
        assert fut1.result()["n"].equal(initial)
        obs = envs.step(batch_index=0, action=idle + 2).result()
        assert obs["n"].equal(initial)

        # Mirror both batches locally and check obs/reward/done each step.
        states = [initial.clone(), initial.clone()]

        for _ in range(100):
            index = random.randint(0, 1)
            action = torch.randint(0, 3, [bs])
            s = states[index]
            obs = envs.step(index, action).result()
            reward = obs["reward"]
            done = obs["done"]
            for i in range(bs):
                if action[i] == 1:
                    s[i] *= 2
                elif action[i] == 2:
                    s[i] /= 2
                r = abs(s[i].sum() - 4)
                d = s[i].sum() < 1
                if d:
                    # Pool auto-resets finished envs; mirror that locally.
                    s[i] = initial[i]
                assert d == done[i]
                assert s[i].equal(obs["n"][i])
                assert r == reward[i]


# -----------------------------------------------------------------------------
# test/unit/test_pickle.py
# -----------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import random

import pytest

import moolib


class MyValue:
    """Slotted leaf value, used to exercise pickling of __slots__ classes."""

    __slots__ = ["value"]

    def __init__(self, value):
        self.value = value


class MyClass:
    """Slotted container nesting a MyValue inside a dict."""

    __slots__ = ["foo", "bar"]

    def __init__(self):
        self.foo = 42
        self.bar = {"key": MyValue(random.getrandbits(18))}


ADDRESS = "127.0.0.1:4422"


class TestMoolibPickle:
    @pytest.fixture
    def host(self):
        # Listening endpoint; kept alive for the duration of the test.
        host = moolib.Rpc()
        host.set_name("host")
        host.listen(ADDRESS)
        yield host

    @pytest.fixture
    def client(self, host):
        # Depends on `host` so the listener exists before we connect.
        client = moolib.Rpc()
        client.set_name("client")
        client.connect(ADDRESS)
        yield client

    def test_pickle_hello(self, host, client):
        """Custom __slots__ objects must survive a pickled RPC round trip."""

        def hello(message):
            print("Got hello: ", message.foo, message.bar["key"].value)
            return (
                f"this is a response to message '{message.foo}'",
                message.bar["key"].value,
            )

        host.define("hello", hello)

        for _ in range(10):
            payload = MyClass()

            print("sending ", payload.foo, payload.bar["key"].value)

            response, value = client.sync("host", "hello", payload)
            print("sync: ", response, value)
            assert value == payload.bar["key"].value


# -----------------------------------------------------------------------------
# test/unit/test_simple.py
# -----------------------------------------------------------------------------
# Copyright (c)
# Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re

import pytest

import moolib

ADDRESS = "127.0.0.1:4411"


class TestMoolib:
    """Basic RPC behavior: sync/async calls, callbacks, timeouts, restarts."""

    def test_call_async_and_sync(self):
        """A host-defined function is reachable via both async_ and sync."""
        client = moolib.Rpc()
        host = moolib.Rpc()

        client.set_name("client")
        client.set_timeout(1)

        calls = 0

        def hello(message):
            # Counts every invocation — via RPC and when called directly below.
            nonlocal calls
            calls += 1
            print("Got hello: ", message)
            return f"this is a response to message '{message}'"

        host.define("hello", hello)
        host.set_name("host")
        host.listen(ADDRESS)
        client.connect(ADDRESS)

        request = "this is a message from client"
        response = client.async_("host", "hello", request).result()

        assert calls == 1
        assert response == hello(request)

        second = "sync test"
        assert client.sync("host", "hello", second) == hello(second)
        assert calls == 4

    def test_async_callback_and_unknown_peer(self):
        """async_callback delivers results; calls to unknown peers time out."""
        client = moolib.Rpc()
        host = moolib.Rpc()

        client.set_name("client")
        client.set_timeout(1)

        def hello(message):
            return "this is a response to message %s" % repr(message)

        host.define("hello", hello)
        host.set_name("host")
        host.listen(ADDRESS)
        client.connect(ADDRESS)

        callbacks = 0
        message = "this is a message through async_callback"

        def on_response(response, error):
            nonlocal callbacks
            callbacks += 1
            assert response == hello(message)
            assert error is None

        client.async_callback("host", "hello", on_response, message)

        future = client.async_("nowhere", "hello", "this is a message to nowhere")
        with pytest.raises(  # TODO: Should this be a RuntimeError?
            RuntimeError, match=re.escape("Call (nowhere::) timed out")
        ):
            future.result()

        assert callbacks == 1

    def test_nonexistent_function_and_dead_host(self):
        """Missing remote functions and dead hosts surface as RuntimeErrors,
        and a restarted host becomes reachable again."""
        client = moolib.Rpc()
        host = moolib.Rpc()

        client.set_name("client")
        client.set_timeout(1)

        called = False

        def client_hello(a, b, c):
            nonlocal called
            called = True
            assert (a, b, c) == (1, 2, 3)
            return (c, b, a)

        client.define("client hello", client_hello)

        def hello(message):
            return 42

        host.define("hello", hello)
        host.set_name("host")
        host.listen(ADDRESS)
        client.connect(ADDRESS)

        with pytest.raises(  # TODO: Should this be a RuntimeError?
            RuntimeError,
            match=re.escape(
                "RPC remote function host::'non-existent function' does not exist"
            ),
        ):
            client.sync("host", "non-existent function")

        # Kill the host; subsequent calls must time out.
        del host

        with pytest.raises(
            # TODO: Why is this and not "hello"?
            RuntimeError,
            match=re.escape("Call (host::) timed out"),
        ):
            client.sync("host", "hello", "is host dead?")

        # Restart the host under the same name and verify both directions work.
        host = moolib.Rpc()
        host.define("hello", hello)
        host.set_name("host")
        host.listen(ADDRESS)

        client.set_timeout(30)
        assert client.sync("host", "hello", "is host alive?") == 42
        assert host.sync("client", "client hello", 1, 2, 3) == (3, 2, 1)
        assert called


# -----------------------------------------------------------------------------
# test/unit/test_tensors.py
# -----------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time
import warnings

import torch
import pytest

import moolib

ADDRESS = "127.0.0.1:4422"


class TestMoolibSpeeds:
    """Smoke and throughput tests for tensor-bearing and no-op RPC calls."""

    @pytest.fixture
    def host(self):
        # Listening endpoint; kept alive for the duration of the test.
        host = moolib.Rpc()
        host.set_name("host")
        host.listen(ADDRESS)
        yield host

    @pytest.fixture
    def client(self, host):
        # Depends on `host` so the listener exists before we connect.
        client = moolib.Rpc()
        client.set_name("client")
        client.connect(ADDRESS)
        yield client

    def test_sync(self, host, client):
        """Ship a broadcastable tensor through a synchronous call."""
        weights = torch.randn(4096, 4096)

        def linear(inputs):
            return (weights * inputs).sum(-1)

        host.define("linear", linear)
        client.set_timeout(60)

        inputs = torch.randn(16, 4096)
        client.sync("host", "linear", inputs.unsqueeze(1))

    def test_sync_noop_speed(self, host, client):
        """Measure round trips per second for synchronous no-op calls."""

        def noop():
            pass

        host.define("noop", noop)
        client.set_timeout(60)

        iterations = 128
        start = time.time()
        for _ in range(iterations):
            client.sync("host", "noop")
        t = time.time() - start

        print(
            "%d iterations took %f sec (%.1f per sec)" % (iterations, t, iterations / t)
        )

        # Only warn — absolute speed is machine-dependent.
        if iterations / t < 1000:
            warnings.warn(f"Very slow iteration speed: {iterations / t}")

    def test_async_noop_speed(self, host, client):
        """Measure throughput of many in-flight asynchronous no-op calls."""

        def noop():
            pass

        host.define("noop", noop)

        start = time.time()
        futures = [client.async_("host", "noop") for _ in range(2000)]
        for fut in futures:
            fut.result()
        t = time.time() - start
        print("noop x%d time %g (%g/s)" % (len(futures), t, len(futures) / t))

        # Only warn — absolute speed is machine-dependent.
        if len(futures) / t < 500:
            warnings.warn(f"Very slow iteration speed: {len(futures) / t}")

    def test_async_vs_local(self, host, client):
        """Async RPC results must numerically match the local computation."""
        weights = torch.randn(4096, 4096)

        def linear(inputs):
            return (weights * inputs).sum(-1)

        host.define("linear", linear)

        inputs = torch.randn(16, 4096)
        for _ in range(2):
            start = time.time()
            local_result = sum(
                linear(inputs[i]).sum().item() for i in range(inputs.size(0))
            )
            print("base time ", time.time() - start)

        for _ in range(4):
            start = time.time()
            futures = [
                client.async_("host", "linear", inputs[i])
                for i in range(inputs.size(0))
            ]
            result = sum(fut.result().sum().item() for fut in futures)
            print("async time ", time.time() - start)
            assert abs(result - local_result) < 0.1

        host.debug_info()
        client.debug_info()
# -----------------------------------------------------------------------------