├── .git-blame-ignore-revs ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md ├── scripts │ ├── linux-post-script.sh │ ├── linux-pre-script.sh │ ├── version_script.bat │ ├── version_script.sh │ └── win-pre-script.bat ├── unittest │ ├── linux │ │ └── scripts │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ └── rl_linux_optdeps │ │ └── scripts │ │ ├── environment.yml │ │ ├── install.sh │ │ ├── post_process.sh │ │ ├── run-clang-format.py │ │ ├── run_test.sh │ │ └── setup_env.sh └── workflows │ ├── benchmarks.yml │ ├── benchmarks_pr.yml │ ├── build-wheels-aarch64-linux.yml │ ├── build-wheels-linux.yml │ ├── build-wheels-m1.yml │ ├── build-wheels-windows.yml │ ├── docs.yml │ ├── lint.yml │ ├── nightly_build.yml │ ├── selfassign.yml │ ├── test-linux.yml │ ├── test-macos.yml │ └── test-rl-gpu.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── GETTING_STARTED.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── benchmarks ├── common │ ├── common_ops_test.py │ ├── h2d_test.py │ ├── memmap_benchmarks_test.py │ └── pytree_benchmarks_test.py ├── compile │ ├── compile_td_test.py │ └── tensordict_nn_test.py ├── conftest.py ├── distributed │ ├── dataloading.py │ └── distributed_benchmark_test.py ├── fx_benchmarks.py ├── nn │ └── functional_benchmarks_test.py ├── requirements.txt └── tensorclass │ ├── test_tensorclass_speed.py │ └── test_torch_functions.py ├── docs ├── Makefile ├── build_script.sh ├── make.bat ├── requirements.txt ├── source │ ├── _static │ │ ├── css │ │ │ └── custom_torchrl.css │ │ ├── img │ │ │ ├── graph.svg │ │ │ ├── pytorch-logo-dark.png │ │ │ ├── pytorch-logo-dark.svg │ │ │ ├── pytorch-logo-flame.png │ │ │ └── pytorch-logo-flame.svg │ │ └── js │ │ │ ├── modernizr.min.js │ │ │ ├── tensordict_theme.js │ │ │ └── theme.js │ ├── _templates │ │ ├── class.rst │ │ ├── function.rst │ │ ├── layout.html │ │ ├── td_template.rst │ │ └── td_template_noinherit.rst │ ├── conf.py │ ├── content_generation.py │ ├── distributed.rst │ ├── docutils.conf │ ├── fx.rst │ ├── index.rst │ ├── overview.rst │ ├── reference │ │ ├── generated │ │ │ └── tutorials │ │ │ │ └── README.rst │ │ ├── index.rst │ │ ├── nn.rst │ │ ├── tensorclass.rst │ │ └── tensordict.rst │ └── saving.rst └── tensordict.png ├── gallery └── README.rst ├── mypy.ini ├── packaging ├── build_wheels.sh ├── pkg_helpers.bash └── wheel │ └── relocate.py ├── pyproject.toml ├── pytest.ini ├── setup.cfg ├── setup.py ├── tensordict ├── _C │ └── __init__.pyi ├── __init__.py ├── _contextlib.py ├── _lazy.py ├── _nestedkey.py ├── _nestedkey.pyi ├── _pytree.py ├── _reductions.py ├── _td.py ├── _tensordict │ └── __init__.py ├── _torch_func.py ├── _unbatched.py ├── base.py ├── csrc │ ├── CMakeLists.txt │ ├── cmake │ │ └── FindPythonPyEnv.cmake │ ├── pybind.cpp │ ├── utils.cpp │ └── utils.h ├── functional.py ├── memmap.py ├── nn │ ├── __init__.py │ ├── common.py │ ├── cudagraphs.py │ ├── distributions │ │ ├── __init__.py │ │ ├── composite.py │ │ ├── continuous.py │ │ ├── discrete.py │ │ ├── truncated_normal.py │ │ └── utils.py │ ├── ensemble.py │ ├── functional_modules.py │ ├── params.py │ ├── probabilistic.py │ ├── sequence.py │ └── utils.py ├── persistent.py ├── prototype │ ├── __init__.py │ └── fx.py ├── tensorclass.py ├── tensorclass.pyi ├── tensordict.py └── utils.py ├── test ├── _utils_internal.py ├── artifacts │ └── mmap_example │ │ ├── meta.json │ │ └── nested │ 
│ ├── bfloat16.memmap │ │ ├── int64.memmap │ │ ├── meta.json │ │ └── string │ │ └── meta.json ├── conftest.py ├── smoke_test.py ├── test_compile.py ├── test_distributed.py ├── test_functorch.py ├── test_fx.py ├── test_h5.py ├── test_memmap.py ├── test_nn.py ├── test_tensorclass.py ├── test_tensordict.py └── test_utils.py ├── tutorials ├── README.md ├── dummy.py ├── media │ ├── .gitkeep │ ├── imagenet-benchmark-speed.png │ ├── imagenet-benchmark-time.png │ └── transformer.png ├── sphinx_tuto │ ├── data_fashion.py │ ├── export.py │ ├── streamed_tensordict.py │ ├── tensorclass_fashion.py │ ├── tensorclass_imagenet.py │ ├── tensordict_keys.py │ ├── tensordict_memory.py │ ├── tensordict_module.py │ ├── tensordict_preallocation.py │ ├── tensordict_shapes.py │ └── tensordict_slicing.py └── src │ └── .gitkeep └── version.txt /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # This file keeps git blame clean. 2 | # See https://docs.github.com/en/repositories/working-with-files/using-files/viewing-a-file#ignore-commits-in-the-blame-view 3 | 4 | 545c3b5dda2bd30c5a8dceb152b71ea5af4dacc8 5 | 6 | 5fda95b8dd0f5b66993f765a6ad92dbe658581dc 7 | 8 | f302fb07ac691c0ac8144f4ce7a6648858ecb630 9 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[BUG]" 5 | labels: ["bug"] 6 | assignees: vmoens 7 | 8 | --- 9 | 10 | ## Describe the bug 11 | 12 | A clear and concise description of what the bug is. 13 | 14 | ## To Reproduce 15 | 16 | Steps to reproduce the behavior. 17 | 18 | Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful. 19 | 20 | Please use the markdown code blocks for both code and stack traces. 21 | 22 | ```python 23 | import tensordict 24 | ``` 25 | 26 | ```bash 27 | Traceback (most recent call last): 28 | File ... 29 | ``` 30 | 31 | ## Expected behavior 32 | 33 | A clear and concise description of what you expected to happen. 34 | 35 | ## Screenshots 36 | 37 | If applicable, add screenshots to help explain your problem. 38 | 39 | ## System info 40 | 41 | Describe the characteristic of your environment: 42 | * Describe how the library was installed (pip, source, ...) 43 | * Python version 44 | * Versions of any other relevant libraries 45 | 46 | ```python 47 | import tensordict, numpy, sys, torch 48 | print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__) 49 | ``` 50 | 51 | ## Additional context 52 | 53 | Add any other context about the problem here. 54 | 55 | ## Reason and Possible fixes 56 | 57 | If you know or suspect the reason for this bug, paste the code lines and suggest modifications. 
58 | 59 | ## Checklist 60 | 61 | - [ ] I have checked that there is no similar issue in the repo (**required**) 62 | - [ ] I have read the [documentation](https://github.com/pytorch/tensordict/tree/main/docs/) (**required**) 63 | - [ ] I have provided a minimal working example to reproduce the bug (**required**) 64 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[Feature Request]" 5 | labels: ["enhancement"] 6 | assignees: vmoens 7 | 8 | --- 9 | 10 | ## Motivation 11 | 12 | Please outline the motivation for the proposal. 13 | Is your feature request related to a problem? e.g., "I'm always frustrated when [...]". 14 | If this is related to another issue, please link here too. 15 | 16 | ## Solution 17 | 18 | A clear and concise description of what you want to happen. 19 | 20 | ## Alternatives 21 | 22 | A clear and concise description of any alternative solutions or features you've considered. 23 | 24 | ## Additional context 25 | 26 | Add any other context or screenshots about the feature request here. 27 | 28 | ## Checklist 29 | 30 | - [ ] I have checked that there is no similar issue in the repo (**required**) 31 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | Describe your changes in detail. 4 | 5 | ## Motivation and Context 6 | 7 | Why is this change required? What problem does it solve? 8 | If it fixes an open issue, please link to the issue here. 9 | You can use the syntax `close #15213` if this solves the issue #15213 10 | 11 | - [ ] I have raised an issue to propose this change ([required](https://github.com/pytorch/tensordict/issues) for new features and bug fixes) 12 | 13 | ## Types of changes 14 | 15 | What types of changes does your code introduce? Remove all that do not apply: 16 | 17 | - [ ] Bug fix (non-breaking change which fixes an issue) 18 | - [ ] New feature (non-breaking change which adds core functionality) 19 | - [ ] Breaking change (fix or feature that would cause existing functionality to change) 20 | - [ ] Documentation (update in the documentation) 21 | - [ ] Example (update in the folder of examples) 22 | 23 | ## Checklist 24 | 25 | Go over all the following points, and put an `x` in all the boxes that apply. 26 | If you are unsure about any of these, don't hesitate to ask. We are here to help! 27 | 28 | - [ ] I have read the [CONTRIBUTION](https://github.com/pytorch/tensordict/blob/main/CONTRIBUTING.md) guide (**required**) 29 | - [ ] My change requires a change to the documentation. 30 | - [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*). 31 | - [ ] I have updated the documentation accordingly. 
32 | -------------------------------------------------------------------------------- /.github/scripts/linux-post-script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$(uname)" != "Darwin" ]; then 4 | yum update gcc 5 | yum update libstdc++ 6 | else 7 | brew update 8 | brew upgrade gcc 9 | fi 10 | -------------------------------------------------------------------------------- /.github/scripts/linux-pre-script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ${CONDA_RUN} conda install -c conda-forge pybind11 -y 4 | -------------------------------------------------------------------------------- /.github/scripts/version_script.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | set TENSORDICT_BUILD_VERSION=0.9.0 3 | set SETUPTOOLS_SCM_PRETEND_VERSION=0.9.0 4 | echo TENSORDICT_BUILD_VERSION is set to %TENSORDICT_BUILD_VERSION% 5 | 6 | if "%CONDA_RUN%"=="" ( 7 | echo CONDA_RUN is not set. Please activate your conda environment or set CONDA_RUN. 8 | exit /b 1 9 | ) 10 | 11 | @echo on 12 | 13 | set VC_VERSION_LOWER=17 14 | set VC_VERSION_UPPER=18 15 | if "%VC_YEAR%" == "2019" ( 16 | set VC_VERSION_LOWER=16 17 | set VC_VERSION_UPPER=17 18 | ) 19 | if "%VC_YEAR%" == "2017" ( 20 | set VC_VERSION_LOWER=15 21 | set VC_VERSION_UPPER=16 22 | ) 23 | 24 | for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [%VC_VERSION_LOWER%^,%VC_VERSION_UPPER%^) -property installationPath`) do ( 25 | if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( 26 | set "VS15INSTALLDIR=%%i" 27 | set "VS15VCVARSALL=%%i\VC\Auxiliary\Build\vcvarsall.bat" 28 | goto vswhere 29 | ) 30 | ) 31 | 32 | :vswhere 33 | if "%VSDEVCMD_ARGS%" == "" ( 34 | call "%VS15VCVARSALL%" x64 || exit /b 1 35 | ) else ( 36 | call "%VS15VCVARSALL%" x64 %VSDEVCMD_ARGS% || exit /b 1 37 | ) 38 | 39 | @echo on 40 | 41 | if "%CU_VERSION%" == "xpu" call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" 42 | 43 | set DISTUTILS_USE_SDK=1 44 | 45 | set args=%1 46 | shift 47 | :start 48 | if [%1] == [] goto done 49 | set args=%args% %1 50 | shift 51 | goto start 52 | 53 | :done 54 | if "%args%" == "" ( 55 | echo Usage: vc_env_helper.bat [command] [args] 56 | echo e.g. vc_env_helper.bat cl /c test.cpp 57 | ) 58 | 59 | %args% || exit /b 1 60 | -------------------------------------------------------------------------------- /.github/scripts/version_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export TENSORDICT_BUILD_VERSION=0.9.0 4 | export SETUPTOOLS_SCM_PRETEND_VERSION=$TENSORDICT_BUILD_VERSION 5 | # TODO: consider lower this 6 | export MACOSX_DEPLOYMENT_TARGET=15.0 7 | 8 | ${CONDA_RUN} pip install --upgrade pip 9 | 10 | # for orjson 11 | export UNSAFE_PYO3_BUILD_FREE_THREADED=1 12 | 13 | ${CONDA_RUN} conda install -c conda-forge pybind11 -y 14 | -------------------------------------------------------------------------------- /.github/scripts/win-pre-script.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | :: Check if CONDA_RUN is set, if not, set it to a default value 3 | if "%CONDA_RUN%"=="" ( 4 | echo CONDA_RUN is not set. Please activate your conda environment or set CONDA_RUN. 
5 | exit /b 1 6 | ) 7 | 8 | :: Run the pip install command 9 | %CONDA_RUN% conda install conda-forge::pybind11 -y 10 | 11 | :: Check if the installation was successful 12 | if errorlevel 1 ( 13 | echo Failed to install cmake and pybind11. 14 | exit /b 1 15 | ) else ( 16 | echo Successfully installed cmake and pybind11. 17 | ) 18 | -------------------------------------------------------------------------------- /.github/unittest/linux/scripts/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | dependencies: 5 | - pip 6 | - protobuf 7 | - pip: 8 | - hypothesis 9 | - future 10 | - cloudpickle 11 | - pytest 12 | - pytest-benchmark 13 | - pytest-cov 14 | - pytest-mock 15 | - pytest-instafail 16 | - pytest-rerunfailures 17 | - pytest-timeout 18 | - expecttest 19 | - coverage 20 | - h5py 21 | - orjson 22 | - ninja 23 | - numpy<2.0.0 24 | -------------------------------------------------------------------------------- /.github/unittest/linux/scripts/install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | unset PYTORCH_VERSION 4 | # For unittest, nightly PyTorch is used as the following section, 5 | # so no need to set PYTORCH_VERSION. 6 | # In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. 7 | 8 | set -e 9 | set -v 10 | 11 | eval "$(./conda/bin/conda shell.bash hook)" 12 | conda activate ./env 13 | 14 | if [ "${CU_VERSION:-}" == cpu ] ; then 15 | echo "Using cpu build" 16 | else 17 | if [[ ${#CU_VERSION} -eq 4 ]]; then 18 | CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" 19 | elif [[ ${#CU_VERSION} -eq 5 ]]; then 20 | CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" 21 | fi 22 | echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" 23 | fi 24 | 25 | # submodules 26 | git submodule sync && git submodule update --init --recursive 27 | 28 | printf "Installing PyTorch with %s\n" "${CU_VERSION}" 29 | if [[ "$TORCH_VERSION" == "nightly" ]]; then 30 | if [ "${CU_VERSION:-}" == cpu ] ; then 31 | python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu 32 | else 33 | python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION 34 | fi 35 | elif [[ "$TORCH_VERSION" == "stable" ]]; then 36 | if [ "${CU_VERSION:-}" == cpu ] ; then 37 | python -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu 38 | else 39 | python -m pip install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION 40 | fi 41 | else 42 | printf "Failed to install pytorch" 43 | exit 1 44 | fi 45 | 46 | printf "* Installing tensordict\n" 47 | pip install -e . 
48 | 49 | # # install torchsnapshot nightly 50 | # if [[ "$TORCH_VERSION" == "nightly" ]]; then 51 | # python -m pip install git+https://github.com/pytorch/torchsnapshot --no-build-isolation 52 | # elif [[ "$TORCH_VERSION" == "stable" ]]; then 53 | # python -m pip install torchsnapshot 54 | # fi 55 | # smoke test 56 | python -c "import functorch" 57 | -------------------------------------------------------------------------------- /.github/unittest/linux/scripts/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | -------------------------------------------------------------------------------- /.github/unittest/linux/scripts/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | 8 | export PYTORCH_TEST_WITH_SLOW='1' 9 | python -m torch.utils.collect_env 10 | # Avoid error: "fatal: unsafe repository" 11 | git config --global --add safe.directory '*' 12 | 13 | root_dir="$(git rev-parse --show-toplevel)" 14 | env_dir="${root_dir}/env" 15 | lib_dir="${env_dir}/lib" 16 | 17 | # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found 18 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir 19 | export MKL_THREADING_LAYER=GNU 20 | export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 21 | export TD_GET_DEFAULTS_TO_NONE=1 22 | export LIST_TO_STACK=1 23 | 24 | coverage run -m pytest test/smoke_test.py -v --durations 20 25 | coverage run -m pytest --runslow --instafail -v --durations 20 --timeout 120 26 | coverage run -m pytest ./benchmarks --instafail -v --durations 20 27 | coverage xml -i 28 | -------------------------------------------------------------------------------- /.github/unittest/linux/scripts/setup_env.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script is for setting up environment in which unit test is ran. 4 | # To speed up the CI time, the resulting environment is cached. 5 | # 6 | # Do not install PyTorch and torchvision here, otherwise they also get cached. 7 | 8 | set -e 9 | set -v 10 | 11 | apt update -y && apt install git wget gcc -y 12 | 13 | this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 14 | # Avoid error: "fatal: unsafe repository" 15 | git config --global --add safe.directory '*' 16 | root_dir="$(git rev-parse --show-toplevel)" 17 | conda_dir="${root_dir}/conda" 18 | env_dir="${root_dir}/env" 19 | 20 | cd "${root_dir}" 21 | 22 | case "$(uname -s)" in 23 | Darwin*) os=MacOSX;; 24 | *) os=Linux 25 | esac 26 | 27 | # 1. Install conda at ./conda 28 | if [ ! -d "${conda_dir}" ]; then 29 | printf "* Installing conda\n" 30 | if [ "${os}" == "MacOSX" ]; then 31 | curl -L -o miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-${ARCH}.sh" 32 | else 33 | wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-${ARCH}.sh" 34 | fi 35 | bash ./miniconda.sh -b -f -p "${conda_dir}" 36 | fi 37 | eval "$(${conda_dir}/bin/conda shell.bash hook)" 38 | 39 | # 2. Create test environment at ./env 40 | printf "python: ${PYTHON_VERSION}\n" 41 | if [ ! 
-d "${env_dir}" ]; then 42 | printf "* Creating a test environment\n" 43 | conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" 44 | fi 45 | conda activate "${env_dir}" 46 | 47 | # 3. Install Conda dependencies 48 | printf "* Installing dependencies (except PyTorch)\n" 49 | echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" 50 | cat "${this_dir}/environment.yml" 51 | 52 | pip install pip --upgrade 53 | 54 | conda env update --file "${this_dir}/environment.yml" --prune 55 | 56 | conda install anaconda::cmake -y 57 | conda install -c conda-forge pybind11 -y 58 | 59 | #if [[ $OSTYPE == 'darwin'* ]]; then 60 | # printf "* Installing C++ for OSX\n" 61 | # conda install -c conda-forge cxx-compiler -y 62 | #fi 63 | -------------------------------------------------------------------------------- /.github/unittest/rl_linux_optdeps/scripts/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | dependencies: 5 | - pip 6 | - pip: 7 | - hypothesis 8 | - future 9 | - cloudpickle 10 | - pytest 11 | - pytest-cov 12 | - pytest-mock 13 | - pytest-instafail 14 | - pytest-rerunfailures 15 | - expecttest 16 | - pyyaml 17 | - scipy 18 | - orjson 19 | - ninja 20 | - numpy<2.0.0 21 | -------------------------------------------------------------------------------- /.github/unittest/rl_linux_optdeps/scripts/install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | unset PYTORCH_VERSION 4 | # For unittest, nightly PyTorch is used as the following section, 5 | # so no need to set PYTORCH_VERSION. 6 | # In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. 7 | apt-get update && apt-get install -y git wget gcc g++ 8 | #apt-get update && apt-get install -y git wget freeglut3 freeglut3-dev 9 | 10 | set -e 11 | 12 | eval "$(./conda/bin/conda shell.bash hook)" 13 | conda activate ./env 14 | 15 | if [ "${CU_VERSION:-}" == cpu ] ; then 16 | version="cpu" 17 | else 18 | if [[ ${#CU_VERSION} -eq 4 ]]; then 19 | CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" 20 | elif [[ ${#CU_VERSION} -eq 5 ]]; then 21 | CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" 22 | fi 23 | echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" 24 | version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" 25 | fi 26 | 27 | # submodules 28 | git submodule sync && git submodule update --init --recursive 29 | 30 | printf "Installing PyTorch with %s\n" "${CU_VERSION}" 31 | if [ "${CU_VERSION:-}" == cpu ] ; then 32 | # conda install -y pytorch torchvision cpuonly -c pytorch-nightly 33 | # use pip to install pytorch as conda can frequently pick older release 34 | # conda install -y pytorch cpuonly -c pytorch-nightly 35 | pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu 36 | else 37 | pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu118 38 | fi 39 | 40 | # install tensordict 41 | pip3 install -e . 42 | 43 | # smoke test 44 | python -c "import functorch" 45 | 46 | printf "* Installing torchrl\n" 47 | git clone https://github.com/pytorch/rl 48 | cd rl 49 | python setup.py develop 50 | cd .. 
51 | 52 | # smoke test 53 | python -c "import torchrl" 54 | -------------------------------------------------------------------------------- /.github/unittest/rl_linux_optdeps/scripts/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | -------------------------------------------------------------------------------- /.github/unittest/rl_linux_optdeps/scripts/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | apt-get update && apt-get install -y git wget freeglut3 freeglut3-dev 8 | 9 | # find libstdc 10 | STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1) 11 | 12 | export PYTORCH_TEST_WITH_SLOW='1' 13 | python -m torch.utils.collect_env 14 | # Avoid error: "fatal: unsafe repository" 15 | git config --global --add safe.directory '*' 16 | root_dir="$(git rev-parse --show-toplevel)" 17 | export MKL_THREADING_LAYER=GNU 18 | export CKPT_BACKEND=torch 19 | export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 20 | # RL should work with the new API 21 | export TD_GET_DEFAULTS_TO_NONE='1' 22 | 23 | #MUJOCO_GL=glfw pytest --cov=torchrl --junitxml=test-results/junit.xml -v --durations 20 24 | MUJOCO_GL=egl python -m pytest rl/test --instafail -v --durations 20 --ignore rl/test/test_distributed.py 25 | -------------------------------------------------------------------------------- /.github/unittest/rl_linux_optdeps/scripts/setup_env.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script is for setting up environment in which unit test is ran. 4 | # To speed up the CI time, the resulting environment is cached. 5 | # 6 | # Do not install PyTorch and torchvision here, otherwise they also get cached. 7 | 8 | set -e 9 | 10 | apt update -y && apt install git wget gcc -y 11 | 12 | this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 13 | # Avoid error: "fatal: unsafe repository" 14 | apt-get update && apt-get install -y git wget gcc g++ 15 | 16 | git config --global --add safe.directory '*' 17 | root_dir="$(git rev-parse --show-toplevel)" 18 | conda_dir="${root_dir}/conda" 19 | env_dir="${root_dir}/env" 20 | 21 | cd "${root_dir}" 22 | 23 | case "$(uname -s)" in 24 | Darwin*) os=MacOSX;; 25 | *) os=Linux 26 | esac 27 | 28 | # 1. Install conda at ./conda 29 | if [ ! -d "${conda_dir}" ]; then 30 | printf "* Installing conda\n" 31 | wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" 32 | bash ./miniconda.sh -b -f -p "${conda_dir}" 33 | fi 34 | eval "$(${conda_dir}/bin/conda shell.bash hook)" 35 | 36 | # 2. Create test environment at ./env 37 | printf "python: ${PYTHON_VERSION}\n" 38 | if [ ! -d "${env_dir}" ]; then 39 | printf "* Creating a test environment\n" 40 | conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" 41 | fi 42 | conda activate "${env_dir}" 43 | 44 | ## 3. 
Install mujoco 45 | #printf "* Installing mujoco and related\n" 46 | #mkdir $root_dir/.mujoco 47 | #cd $root_dir/.mujoco/ 48 | #wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz 49 | #tar -xf mujoco-2.1.1-linux-x86_64.tar.gz 50 | #wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz 51 | #tar -xf mujoco210-linux-x86_64.tar.gz 52 | #cd $this_dir 53 | 54 | # 4. Install Conda dependencies 55 | printf "* Installing dependencies (except PyTorch)\n" 56 | echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" 57 | cat "${this_dir}/environment.yml" 58 | 59 | pip install pip --upgrade 60 | 61 | conda env update --file "${this_dir}/environment.yml" --prune 62 | 63 | conda install anaconda::cmake -y 64 | conda install -c conda-forge pybind11 -y 65 | 66 | #yum makecache 67 | #yum -y install glfw-devel 68 | #yum -y install libGLEW 69 | #yum -y install gcc-c++ 70 | -------------------------------------------------------------------------------- /.github/workflows/build-wheels-aarch64-linux.yml: -------------------------------------------------------------------------------- 1 | name: Build Aarch64 Linux Wheels 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - nightly 8 | - main 9 | - release/* 10 | tags: 11 | # NOTE: Binary build pipelines should only get triggered on release candidate builds 12 | # Release candidate tags look like: v1.11.0-rc1 13 | - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ 14 | workflow_dispatch: 15 | 16 | permissions: 17 | id-token: write 18 | contents: read 19 | 20 | jobs: 21 | generate-matrix: 22 | uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main 23 | with: 24 | package-type: wheel 25 | os: linux-aarch64 26 | test-infra-repository: pytorch/test-infra 27 | test-infra-ref: main 28 | with-cuda: disable 29 | build: 30 | needs: generate-matrix 31 | strategy: 32 | fail-fast: false 33 | matrix: 34 | include: 35 | - repository: pytorch/tensordict 36 | package-name: tensordict 37 | name: pytorch/tensordict 38 | uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main 39 | with: 40 | repository: ${{ matrix.repository }} 41 | ref: "" 42 | test-infra-repository: pytorch/test-infra 43 | test-infra-ref: main 44 | build-matrix: ${{ needs.generate-matrix.outputs.matrix }} 45 | package-name: ${{ matrix.package-name }} 46 | smoke-test-script: ${{ matrix.smoke-test-script }} 47 | trigger-event: ${{ github.event_name }} 48 | env-var-script: .github/scripts/version_script.sh 49 | architecture: aarch64 50 | setup-miniconda: false 51 | build-command: pip wheel . 
&& mkdir -p dist && mv tensordict*.whl ./dist 52 | build-platform: python-build-package 53 | -------------------------------------------------------------------------------- /.github/workflows/build-wheels-linux.yml: -------------------------------------------------------------------------------- 1 | name: Build Linux Wheels 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - nightly 8 | - main 9 | - release/* 10 | tags: 11 | # NOTE: Binary build pipelines should only get triggered on release candidate builds 12 | # Release candidate tags look like: v1.11.0-rc1 13 | - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ 14 | workflow_dispatch: 15 | 16 | permissions: 17 | id-token: write 18 | contents: read 19 | 20 | jobs: 21 | generate-matrix: 22 | uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main 23 | with: 24 | package-type: wheel 25 | os: linux 26 | test-infra-repository: pytorch/test-infra 27 | test-infra-ref: main 28 | build: 29 | needs: generate-matrix 30 | strategy: 31 | fail-fast: false 32 | matrix: 33 | include: 34 | - repository: pytorch/tensordict 35 | smoke-test-script: test/smoke_test.py 36 | pre-script: .github/scripts/linux-pre-script.sh 37 | post-script: .github/scripts/linux-post-script.sh 38 | package-name: tensordict 39 | name: pytorch/tensordict 40 | uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main 41 | with: 42 | repository: ${{ matrix.repository }} 43 | ref: "" 44 | test-infra-repository: pytorch/test-infra 45 | test-infra-ref: main 46 | build-matrix: ${{ needs.generate-matrix.outputs.matrix }} 47 | package-name: ${{ matrix.package-name }} 48 | smoke-test-script: ${{ matrix.smoke-test-script }} 49 | trigger-event: ${{ github.event_name }} 50 | env-var-script: .github/scripts/version_script.sh 51 | post-script: ${{ matrix.post-script }} 52 | build-command: pip wheel . 
&& mkdir -p dist && mv tensordict*.whl ./dist 53 | build-platform: python-build-package 54 | -------------------------------------------------------------------------------- /.github/workflows/build-wheels-m1.yml: -------------------------------------------------------------------------------- 1 | name: Build M1 Wheels 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - nightly 8 | - main 9 | - release/* 10 | tags: 11 | # NOTE: Binary build pipelines should only get triggered on release candidate builds 12 | # Release candidate tags look like: v1.11.0-rc1 13 | - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ 14 | workflow_dispatch: 15 | 16 | permissions: 17 | id-token: write 18 | contents: read 19 | 20 | jobs: 21 | generate-matrix: 22 | uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main 23 | with: 24 | package-type: wheel 25 | os: macos-arm64 26 | test-infra-repository: pytorch/test-infra 27 | test-infra-ref: main 28 | build: 29 | needs: generate-matrix 30 | strategy: 31 | fail-fast: false 32 | matrix: 33 | include: 34 | - repository: pytorch/tensordict 35 | smoke-test-script: test/smoke_test.py 36 | pre-script: .github/scripts/linux-pre-script.sh 37 | post-script: .github/scripts/linux-post-script.sh 38 | package-name: tensordict 39 | name: ${{ matrix.repository }} 40 | uses: pytorch/test-infra/.github/workflows/build_wheels_macos.yml@main 41 | with: 42 | repository: ${{ matrix.repository }} 43 | ref: "" 44 | test-infra-repository: pytorch/test-infra 45 | test-infra-ref: main 46 | build-matrix: ${{ needs.generate-matrix.outputs.matrix }} 47 | package-name: ${{ matrix.package-name }} 48 | runner-type: macos-m2-15 49 | smoke-test-script: ${{ matrix.smoke-test-script }} 50 | trigger-event: ${{ github.event_name }} 51 | env-var-script: .github/scripts/version_script.sh 52 | build-command: pip wheel . 
&& mkdir -p dist && mv tensordict*.whl ./dist 53 | build-platform: python-build-package 54 | -------------------------------------------------------------------------------- /.github/workflows/build-wheels-windows.yml: -------------------------------------------------------------------------------- 1 | name: Build Windows Wheels 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - nightly 8 | - main 9 | - release/* 10 | tags: 11 | # NOTE: Binary build pipelines should only get triggered on release candidate builds 12 | # Release candidate tags look like: v1.11.0-rc1 13 | - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ 14 | workflow_dispatch: 15 | 16 | permissions: 17 | id-token: write 18 | contents: read 19 | 20 | jobs: 21 | generate-matrix: 22 | uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main 23 | with: 24 | package-type: wheel 25 | os: windows 26 | test-infra-repository: pytorch/test-infra 27 | test-infra-ref: main 28 | build: 29 | needs: generate-matrix 30 | strategy: 31 | fail-fast: false 32 | matrix: 33 | include: 34 | - repository: pytorch/tensordict 35 | pre-script: .github/scripts/win-pre-script.bat 36 | env-script: .github/scripts/version_script.bat 37 | post-script: "python packaging/wheel/relocate.py" 38 | smoke-test-script: test/smoke_test.py 39 | package-name: tensordict 40 | name: ${{ matrix.repository }} 41 | uses: pytorch/test-infra/.github/workflows/build_wheels_windows.yml@main 42 | with: 43 | repository: ${{ matrix.repository }} 44 | ref: "" 45 | test-infra-repository: pytorch/test-infra 46 | test-infra-ref: main 47 | build-matrix: ${{ needs.generate-matrix.outputs.matrix }} 48 | pre-script: ${{ matrix.pre-script }} 49 | env-script: ${{ matrix.env-script }} 50 | post-script: ${{ matrix.post-script }} 51 | package-name: ${{ matrix.package-name }} 52 | smoke-test-script: ${{ matrix.smoke-test-script }} 53 | trigger-event: ${{ github.event_name }} 54 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | # This workflow builds the tensordict docs and deploys them to gh-pages. 2 | name: Generate documentation 3 | on: 4 | push: 5 | branches: 6 | - nightly 7 | - main 8 | - release/* 9 | tags: 10 | - v[0-9]+.[0-9]+.[0-9] 11 | - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ 12 | pull_request: 13 | workflow_dispatch: 14 | 15 | concurrency: 16 | # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. 17 | # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. 
18 | group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} 19 | cancel-in-progress: true 20 | 21 | jobs: 22 | build-docs: 23 | strategy: 24 | matrix: 25 | python_version: ["3.10"] 26 | cuda_arch_version: ["12.8"] 27 | uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main 28 | permissions: 29 | id-token: write 30 | contents: read 31 | with: 32 | repository: pytorch/tensordict 33 | upload-artifact: docs 34 | runner: "linux.g5.4xlarge.nvidia.gpu" 35 | docker-image: "nvidia/cudagl:11.4.0-base" 36 | timeout: 120 37 | script: | 38 | set -e 39 | set -v 40 | apt-get update && apt-get install -y -f git wget gcc g++ dialog apt-utils 41 | root_dir="$(pwd)" 42 | conda_dir="${root_dir}/conda" 43 | env_dir="${root_dir}/env" 44 | os=Linux 45 | 46 | # 1. Install conda at ./conda 47 | printf "* Installing conda\n" 48 | wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" 49 | bash ./miniconda.sh -b -f -p "${conda_dir}" 50 | eval "$(${conda_dir}/bin/conda shell.bash hook)" 51 | printf "* Creating a test environment\n" 52 | conda create --prefix "${env_dir}" -y python=3.10 53 | printf "* Activating\n" 54 | conda activate "${env_dir}" 55 | 56 | # 2. upgrade pip, ninja and packaging 57 | apt-get install python3-pip unzip -y -f 58 | conda install anaconda::cmake -y 59 | python3 -m pip install --upgrade pip 60 | python3 -m pip install setuptools ninja packaging "pybind11[global]" -U 61 | 62 | # 3. check python version 63 | python3 --version 64 | 65 | # 4. Check git version 66 | git version 67 | 68 | # 5. Install PyTorch 69 | python3 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U --quiet --root-user-action=ignore 70 | 71 | # 6. Install tensordict 72 | python3 setup.py develop 73 | 74 | # 7. Install requirements 75 | export TD_GET_DEFAULTS_TO_NONE='1' 76 | python3 -m pip install -r docs/requirements.txt --quiet --root-user-action=ignore 77 | 78 | # 8. Test tensordict installation 79 | mkdir _tmp 80 | cd _tmp 81 | PYOPENGL_PLATFORM=egl MUJOCO_GL=egl python3 -c """from tensordict import *""" 82 | cd .. 83 | 84 | # 9. Set sanitize version 85 | if [[ ${{ github.event_name }} == push && (${{ github.ref_type }} == tag || (${{ github.ref_type }} == branch && ${{ github.ref_name }} == release/*)) ]]; then 86 | echo '::group::Enable version string sanitization' 87 | # This environment variable just has to exist and must not be empty. The actual value is arbitrary. 88 | # See docs/source/conf.py for details 89 | export TENSORDICT_SANITIZE_VERSION_STR_IN_DOCS=1 90 | echo '::endgroup::' 91 | fi 92 | 93 | # 10. Build doc 94 | cd ./docs 95 | make docs 96 | cd .. 
97 | 98 | cp -r docs/build/html/* "${RUNNER_ARTIFACT_DIR}" 99 | echo $(ls "${RUNNER_ARTIFACT_DIR}") 100 | if [[ ${{ github.event_name == 'pull_request' }} ]]; then 101 | cp -r docs/build/html/* "${RUNNER_DOCS_DIR}" 102 | fi 103 | 104 | upload: 105 | needs: build-docs 106 | if: github.repository == 'pytorch/tensordict' && github.event_name == 'push' && 107 | ((github.ref_type == 'branch' && github.ref_name == 'main') || github.ref_type == 'tag') 108 | permissions: 109 | contents: write 110 | uses: pytorch/test-infra/.github/workflows/linux_job.yml@main 111 | with: 112 | repository: pytorch/tensordict 113 | download-artifact: docs 114 | ref: gh-pages 115 | test-infra-ref: main 116 | script: | 117 | set -euo pipefail 118 | 119 | REF_TYPE=${{ github.ref_type }} 120 | REF_NAME=${{ github.ref_name }} 121 | 122 | if [[ "${REF_TYPE}" == branch ]]; then 123 | if [[ "${REF_NAME}" == main ]]; then 124 | TARGET_FOLDER="${REF_NAME}" 125 | # Bebug: 126 | # else 127 | # TARGET_FOLDER="release-doc" 128 | fi 129 | elif [[ "${REF_TYPE}" == tag ]]; then 130 | case "${REF_NAME}" in 131 | *-rc*) 132 | echo "Aborting upload since this is an RC tag: ${REF_NAME}" 133 | exit 0 134 | ;; 135 | *) 136 | # Strip the leading "v" as well as the trailing patch version. For example: 137 | # 'v0.15.2' -> '0.15' 138 | TARGET_FOLDER=$(echo "${REF_NAME}" | sed 's/v\([0-9]\+\)\.\([0-9]\+\)\.[0-9]\+/\1.\2/') 139 | ;; 140 | esac 141 | fi 142 | 143 | echo "Target Folder: ${TARGET_FOLDER}" 144 | 145 | mkdir -p "${TARGET_FOLDER}" 146 | rm -rf "${TARGET_FOLDER}"/* 147 | 148 | echo $(ls "${RUNNER_ARTIFACT_DIR}") 149 | rsync -a "${RUNNER_ARTIFACT_DIR}"/ "${TARGET_FOLDER}" 150 | git add "${TARGET_FOLDER}" || true 151 | 152 | if [[ "${TARGET_FOLDER}" == "main" ]] ; then 153 | mkdir -p _static 154 | rm -rf _static/* 155 | cp -r "${TARGET_FOLDER}"/_static/* _static 156 | git add _static || true 157 | fi 158 | 159 | git config user.name 'pytorchbot' 160 | git config user.email 'soumith+bot@pytorch.org' 161 | git config http.postBuffer 524288000 162 | git commit -m "auto-generating sphinx docs" || true 163 | git push 164 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - nightly 8 | - main 9 | - release/* 10 | workflow_dispatch: 11 | 12 | concurrency: 13 | # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. 14 | # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. 
15 | group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} 16 | cancel-in-progress: true 17 | 18 | jobs: 19 | python-source-and-configs: 20 | uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main 21 | permissions: 22 | id-token: write 23 | contents: read 24 | with: 25 | repository: pytorch/tensordict 26 | script: | 27 | set -euo pipefail 28 | 29 | echo '::group::Setup environment' 30 | CONDA_PATH=$(which conda) 31 | eval "$(${CONDA_PATH} shell.bash hook)" 32 | conda create --name ci --quiet --yes python=3.9 pip 33 | conda activate ci 34 | echo '::endgroup::' 35 | 36 | echo '::group::Install lint tools' 37 | pip install --progress-bar=off pre-commit 38 | echo '::endgroup::' 39 | 40 | echo '::group::Lint Python source and configs' 41 | set +e 42 | pre-commit run --all-files 43 | 44 | if [ $? -ne 0 ]; then 45 | git --no-pager diff 46 | exit 1 47 | fi 48 | echo '::endgroup::' 49 | 50 | c-source: 51 | uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main 52 | permissions: 53 | id-token: write 54 | contents: read 55 | with: 56 | repository: pytorch/tensordict 57 | script: | 58 | set -euo pipefail 59 | 60 | echo '::group::Setup environment' 61 | CONDA_PATH=$(which conda) 62 | eval "$(${CONDA_PATH} shell.bash hook)" 63 | conda create --name ci --quiet --yes -c conda-forge python=3.10 ncurses=5 libgcc 64 | conda activate ci 65 | export LD_LIBRARY_PATH="${CONDA_PREFIX}/lib:${LD_LIBRARY_PATH}" 66 | echo '::endgroup::' 67 | 68 | echo '::group::Install lint tools' 69 | curl https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64 -o ./clang-format 70 | chmod +x ./clang-format 71 | echo '::endgroup::' 72 | 73 | echo '::group::Lint C source' 74 | set +e 75 | ./.github/unittest/linux/scripts/run-clang-format.py -r tensordict/csrc --clang-format-executable ./clang-format 76 | 77 | if [ $? -ne 0 ]; then 78 | git --no-pager diff 79 | exit 1 80 | fi 81 | echo '::endgroup::' 82 | -------------------------------------------------------------------------------- /.github/workflows/selfassign.yml: -------------------------------------------------------------------------------- 1 | # Allow users to automatically tag themselves to issues 2 | # 3 | # Usage: 4 | # - a github user (a member of the repo) needs to comment 5 | # with "#self-assign" on an issue to be assigned to them. 6 | #------------------------------------------------------------ 7 | 8 | name: Self-assign 9 | on: 10 | issue_comment: 11 | types: created 12 | 13 | concurrency: 14 | # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. 15 | # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. 
16 | group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} 17 | cancel-in-progress: true 18 | 19 | jobs: 20 | one: 21 | runs-on: ubuntu-latest 22 | if: >- 23 | (github.event.comment.body == '#take' || 24 | github.event.comment.body == '#self-assign') 25 | steps: 26 | - run: | 27 | echo "Assigning issue ${{ github.event.issue.number }} to ${{ github.event.comment.user.login }}" 28 | curl -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \ 29 | -d '{"assignees": ["${{ github.event.comment.user.login }}"]}' \ 30 | https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.issue.number }}/assignees 31 | echo "Done 🔥 " 32 | -------------------------------------------------------------------------------- /.github/workflows/test-linux.yml: -------------------------------------------------------------------------------- 1 | name: Unit-tests on Linux 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - nightly 8 | - main 9 | - release/* 10 | workflow_dispatch: 11 | 12 | env: 13 | CHANNEL: "nightly" 14 | 15 | concurrency: 16 | # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. 17 | # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. 18 | group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} 19 | cancel-in-progress: true 20 | 21 | jobs: 22 | test-gpu: 23 | strategy: 24 | matrix: 25 | python_version: ["3.10"] 26 | cuda_arch_version: ["12.8"] 27 | fail-fast: false 28 | uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main 29 | permissions: 30 | id-token: write 31 | contents: read 32 | with: 33 | runner: linux.g5.4xlarge.nvidia.gpu 34 | docker-image: "nvidia/cuda:12.8.0-devel-ubuntu22.04" 35 | repository: pytorch/tensordict 36 | gpu-arch-type: cuda 37 | gpu-arch-version: ${{ matrix.cuda_arch_version }} 38 | script: | 39 | # Set env vars from matrix 40 | export PYTHON_VERSION=${{ matrix.python_version }} 41 | # Commenting these out for now because the GPU test are not working inside docker 42 | export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }} 43 | export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}" 44 | export TORCH_VERSION=nightly 45 | # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines 46 | #export CU_VERSION="cpu" 47 | export ARCH=x86_64 48 | 49 | echo "PYTHON_VERSION: $PYTHON_VERSION" 50 | echo "CU_VERSION: $CU_VERSION" 51 | 52 | ## setup_env.sh 53 | bash .github/unittest/linux/scripts/setup_env.sh 54 | bash .github/unittest/linux/scripts/install.sh 55 | bash .github/unittest/linux/scripts/run_test.sh 56 | bash .github/unittest/linux/scripts/post_process.sh 57 | 58 | test-cpu: 59 | strategy: 60 | matrix: 61 | python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 62 | fail-fast: false 63 | uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main 64 | permissions: 65 | id-token: write 66 | contents: read 67 | with: 68 | runner: linux.12xlarge 69 | docker-image: "nvidia/cuda:12.8.0-devel-ubuntu22.04" 70 | repository: pytorch/tensordict 71 | timeout: 90 72 | script: | 73 | # Set env vars from matrix 74 | export PYTHON_VERSION=${{ matrix.python_version }} 75 | export CU_VERSION="cpu" 76 | export TORCH_VERSION=nightly 77 | export 
ARCH=x86_64 78 | 79 | echo "PYTHON_VERSION: $PYTHON_VERSION" 80 | echo "CU_VERSION: $CU_VERSION" 81 | 82 | ## setup_env.sh 83 | bash .github/unittest/linux/scripts/setup_env.sh 84 | bash .github/unittest/linux/scripts/install.sh 85 | bash .github/unittest/linux/scripts/run_test.sh 86 | bash .github/unittest/linux/scripts/post_process.sh 87 | 88 | test-stable-gpu: 89 | strategy: 90 | matrix: 91 | python_version: ["3.10"] 92 | cuda_arch_version: ["12.6"] 93 | fail-fast: false 94 | uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main 95 | permissions: 96 | id-token: write 97 | contents: read 98 | with: 99 | runner: linux.g5.4xlarge.nvidia.gpu 100 | docker-image: "nvidia/cuda:12.6.0-devel-ubuntu22.04" 101 | repository: pytorch/tensordict 102 | gpu-arch-type: cuda 103 | gpu-arch-version: ${{ matrix.cuda_arch_version }} 104 | timeout: 90 105 | script: | 106 | # Set env vars from matrix 107 | export PYTHON_VERSION=${{ matrix.python_version }} 108 | # Commenting these out for now because the GPU test are not working inside docker 109 | export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }} 110 | export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}" 111 | export TORCH_VERSION=stable 112 | # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines 113 | #export CU_VERSION="cpu" 114 | export ARCH=x86_64 115 | 116 | echo "PYTHON_VERSION: $PYTHON_VERSION" 117 | echo "CU_VERSION: $CU_VERSION" 118 | 119 | ## setup_env.sh 120 | bash .github/unittest/linux/scripts/setup_env.sh 121 | bash .github/unittest/linux/scripts/install.sh 122 | bash .github/unittest/linux/scripts/run_test.sh 123 | bash .github/unittest/linux/scripts/post_process.sh 124 | 125 | test-stable-cpu: 126 | strategy: 127 | matrix: 128 | python_version: ["3.9", "3.13"] 129 | fail-fast: false 130 | uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main 131 | permissions: 132 | id-token: write 133 | contents: read 134 | with: 135 | runner: linux.12xlarge 136 | docker-image: "nvidia/cuda:12.8.0-devel-ubuntu22.04" 137 | repository: pytorch/tensordict 138 | timeout: 90 139 | script: | 140 | # Set env vars from matrix 141 | export PYTHON_VERSION=${{ matrix.python_version }} 142 | export CU_VERSION="cpu" 143 | export TORCH_VERSION=stable 144 | export ARCH=x86_64 145 | 146 | echo "PYTHON_VERSION: $PYTHON_VERSION" 147 | echo "CU_VERSION: $CU_VERSION" 148 | 149 | ## setup_env.sh 150 | bash .github/unittest/linux/scripts/setup_env.sh 151 | bash .github/unittest/linux/scripts/install.sh 152 | bash .github/unittest/linux/scripts/run_test.sh 153 | bash .github/unittest/linux/scripts/post_process.sh 154 | -------------------------------------------------------------------------------- /.github/workflows/test-macos.yml: -------------------------------------------------------------------------------- 1 | name: Unit-tests on MacOS 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - nightly 8 | - main 9 | - release/* 10 | workflow_dispatch: 11 | 12 | env: 13 | CHANNEL: "nightly" 14 | 15 | concurrency: 16 | # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. 17 | # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. 
18 | group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} 19 | cancel-in-progress: true 20 | 21 | jobs: 22 | tests-silicon: 23 | strategy: 24 | matrix: 25 | python_version: ["3.9"] 26 | fail-fast: false 27 | uses: pytorch/test-infra/.github/workflows/macos_job.yml@main 28 | with: 29 | runner: macos-m1-stable 30 | repository: pytorch/tensordict 31 | timeout: 120 32 | script: | 33 | # Set env vars from matrix 34 | set -e 35 | set -v 36 | export PYTHON_VERSION=${{ matrix.python_version }} 37 | export CU_VERSION="cpu" 38 | export SYSTEM_VERSION_COMPAT=0 39 | export TORCH_VERSION=nightly 40 | export ARCH=arm64 41 | 42 | echo "PYTHON_VERSION: $PYTHON_VERSION" 43 | echo "CU_VERSION: $CU_VERSION" 44 | 45 | ## setup_env.sh 46 | bash .github/unittest/linux/scripts/setup_env.sh 47 | bash .github/unittest/linux/scripts/install.sh 48 | bash .github/unittest/linux/scripts/run_test.sh 49 | bash .github/unittest/linux/scripts/post_process.sh 50 | -------------------------------------------------------------------------------- /.github/workflows/test-rl-gpu.yml: -------------------------------------------------------------------------------- 1 | name: Unit-tests (RL) on Linux 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - nightly 8 | - main 9 | - release/* 10 | workflow_dispatch: 11 | 12 | env: 13 | CHANNEL: "nightly" 14 | 15 | concurrency: 16 | # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. 17 | # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. 18 | group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} 19 | cancel-in-progress: true 20 | 21 | jobs: 22 | test-gpu: 23 | strategy: 24 | matrix: 25 | python_version: ["3.10"] 26 | cuda_arch_version: ["12.8"] 27 | fail-fast: false 28 | uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main 29 | permissions: 30 | id-token: write 31 | contents: read 32 | with: 33 | runner: linux.g5.4xlarge.nvidia.gpu 34 | docker-image: "nvidia/cuda:12.8.0-devel-ubuntu22.04" 35 | repository: pytorch/tensordict 36 | gpu-arch-type: cuda 37 | gpu-arch-version: ${{ matrix.cuda_arch_version }} 38 | timeout: 120 39 | script: | 40 | # Set env vars from matrix 41 | export PYTHON_VERSION=${{ matrix.python_version }} 42 | # Commenting these out for now because the GPU test are not working inside docker 43 | export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }} 44 | export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}" 45 | export TORCH_VERSION=nightly 46 | # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines 47 | #export CU_VERSION="cpu" 48 | 49 | echo "PYTHON_VERSION: $PYTHON_VERSION" 50 | echo "CU_VERSION: $CU_VERSION" 51 | 52 | ## setup_env.sh 53 | bash .github/unittest/rl_linux_optdeps/scripts/setup_env.sh 54 | bash .github/unittest/rl_linux_optdeps/scripts/install.sh 55 | bash .github/unittest/rl_linux_optdeps/scripts/run_test.sh 56 | bash .github/unittest/rl_linux_optdeps/scripts/post_process.sh 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python artifacts 2 | __pycache__/ 3 | *.egg-info/ 4 | 
.pytest_cache/ 5 | build/ 6 | dist/ 7 | tensordict/version.py 8 | tensordict/_tensordict.so 9 | 10 | # docs build 11 | docs/_data 12 | docs/source/gen_modules 13 | docs/source/reference/generated 14 | docs/source/tutorials 15 | docs/src 16 | 17 | # Pycharm 18 | .idea 19 | 20 | */_C.so 21 | tensordict/_version.py 22 | 23 | scratch/*.py 24 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.6.0 4 | hooks: 5 | - id: check-docstring-first 6 | - id: check-toml 7 | - id: check-yaml 8 | exclude: packaging/.* 9 | - id: mixed-line-ending 10 | args: [--fix=lf] 11 | - id: end-of-file-fixer 12 | - id: trailing-whitespace 13 | 14 | - repo: https://github.com/omnilib/ufmt 15 | rev: v2.7.0 16 | hooks: 17 | - id: ufmt 18 | additional_dependencies: 19 | - black == 24.4.2 20 | - usort == 1.0.3 21 | - libcst == 0.4.7 22 | 23 | - repo: https://github.com/psf/black 24 | rev: 24.4.2 25 | hooks: 26 | - id: black 27 | 28 | - repo: https://github.com/pycqa/flake8 29 | rev: 7.1.0 30 | hooks: 31 | - id: flake8 32 | args: [--config=setup.cfg] 33 | additional_dependencies: 34 | - flake8-bugbear==22.10.27 35 | - flake8-comprehensions==3.10.1 36 | - torchfix==0.5.0 37 | - flake8-print==5.0.0 38 | # - flake8-unused-arguments==0.0.13 39 | 40 | - repo: https://github.com/PyCQA/pydocstyle 41 | rev: 6.3.0 42 | hooks: 43 | - id: pydocstyle 44 | files: ^tensordict/ 45 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 | 
46 | ## Scope
47 | 
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 | 
55 | This Code of Conduct also applies outside the project spaces when there is a
56 | reasonable belief that an individual's behavior may have a negative impact on
57 | the project or its community.
58 | 
59 | ## Enforcement
60 | 
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported by contacting the project team at . All
63 | complaints will be reviewed and investigated and will result in a response that
64 | is deemed necessary and appropriate to the circumstances. The project team is
65 | obligated to maintain confidentiality with regard to the reporter of an incident.
66 | Further details of specific enforcement policies may be posted separately.
67 | 
68 | Project maintainers who do not follow or enforce the Code of Conduct in good
69 | faith may face temporary or permanent repercussions as determined by other
70 | members of the project's leadership.
71 | 
72 | ## Attribution
73 | 
74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76 | 
77 | [homepage]: https://www.contributor-covenant.org
78 | 
79 | For answers to common questions about this code of conduct, see
80 | https://www.contributor-covenant.org/faq
81 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
 1 | # Contributing to tensordict
 2 | We want to make contributing to this project as easy and transparent as
 3 | possible.
 4 | 
 5 | ## Installing the library
 6 | Install the library as suggested in the README. For advanced features,
 7 | it is preferable to install the nightly build of PyTorch.
 8 | 
 9 | You will need the following packages to be installed:
10 | ```bash
11 | pip install ninja "pybind11[global]" -U
12 | ```
13 | as well as cmake (using `apt-get`, `conda` or any other package manager).
14 | 
15 | Make sure you install tensordict in develop mode by running
16 | ```
17 | pip install -e .
18 | ```
19 | in your shell.
20 | 
21 | If building this artifact on macOS M1 doesn't work correctly, or if the message
22 | `(mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64e'))` appears at execution time, then try
23 | 
24 | ```
25 | ARCHFLAGS="-arch arm64" pip install -e .
26 | ```
27 | 
28 | ## Formatting your code
29 | **Type annotation**
30 | 
31 | tensordict is not strongly-typed, i.e. we do not enforce type hints, nor do we check that the ones that are present are valid. We rely on type hints purely for documentation purposes.
This might change in the future, but there is currently no need to enforce them.
32 | 
33 | **Linting**
34 | 
35 | Before your PR is ready, you'll probably want your code to be checked. This can be done easily by installing
36 | ```
37 | pip install pre-commit
38 | ```
39 | and running
40 | ```
41 | pre-commit run --all-files
42 | ```
43 | from within the tensordict cloned directory.
44 | 
45 | You can also install [pre-commit hooks](https://pre-commit.com/) (using `pre-commit install`).
46 | You can disable the check by appending `-n` to your commit command: `git commit -m <commit message> -n`
47 | 
48 | ## Pull Requests
49 | We actively welcome your pull requests.
50 | 
51 | 1. Fork the repo and create your branch from `main`.
52 | 2. If you've added code that should be tested, add tests.
53 | 3. If you've changed APIs, update the documentation.
54 | 4. Ensure the test suite passes.
55 | 5. Make sure your code lints.
56 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
57 | 
58 | When submitting a PR, we encourage you to link it to the related issue (if any) and add some tags to it.
59 | 
60 | ## Contributor License Agreement ("CLA")
61 | In order to accept your pull request, we need you to submit a CLA. You only need
62 | to do this once to work on any of Facebook's open source projects.
63 | 
64 | Complete your CLA here: 
65 | 
66 | ## Issues
67 | We use GitHub issues to track public bugs. Please ensure your description is
68 | clear and has sufficient instructions to be able to reproduce the issue.
69 | 
70 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
71 | disclosure of security bugs. In those cases, please go through the process
72 | outlined on that page and do not file a public issue.
73 | 
74 | ## License
75 | By contributing to tensordict, you agree that your contributions will be licensed
76 | under the LICENSE file in the root directory of this source tree.
77 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) Meta Platforms, Inc. and affiliates.
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | 4 | recursive-exclude * __pycache__ 5 | recursive-exclude * *.py[co] 6 | recursive-include tensordict *.so 7 | 8 | exclude gallery/* 9 | exclude tutorials/* 10 | exclude packaging/* 11 | exclude docs/* 12 | -------------------------------------------------------------------------------- /benchmarks/common/h2d_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from __future__ import annotations 6 | 7 | import argparse 8 | import time 9 | from typing import Any 10 | 11 | import pytest 12 | import torch 13 | from packaging import version 14 | 15 | from tensordict import tensorclass, TensorDict 16 | from tensordict.utils import logger as tensordict_logger 17 | 18 | TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) 19 | 20 | 21 | @tensorclass 22 | class NJT: 23 | _values: torch.Tensor 24 | _offsets: torch.Tensor 25 | _lengths: torch.Tensor 26 | njt_shape: Any = None 27 | 28 | @classmethod 29 | def from_njt(cls, njt_tensor): 30 | return cls( 31 | _values=njt_tensor._values, 32 | _offsets=njt_tensor._offsets, 33 | _lengths=njt_tensor._lengths, 34 | njt_shape=njt_tensor.size(0), 35 | ).clone() 36 | 37 | 38 | @pytest.fixture(autouse=True, scope="function") 39 | def empty_compiler_cache(): 40 | torch.compiler.reset() 41 | yield 42 | 43 | 44 | def _make_njt(): 45 | lengths = torch.arange(24, 1, -1) 46 | offsets = torch.cat([lengths[:1] * 0, lengths]).cumsum(0) 47 | return torch.nested.nested_tensor_from_jagged( 48 | torch.arange(78, dtype=torch.float), offsets=offsets, lengths=lengths 49 | ) 50 | 51 | 52 | def _njt_td(): 53 | return TensorDict( 54 | # {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)}, 55 | {str(i): _make_njt() for i in range(32)}, 56 | device="cpu", 57 | ) 58 | 59 | 60 | @pytest.fixture 61 | def njt_td(): 62 | return _njt_td() 63 | 64 | 65 | @pytest.fixture 66 | def td(): 67 | njtd = _njt_td() 68 | for k0, v0 in njtd.items(): 69 | njtd[k0] = NJT.from_njt(v0) 70 | # for k1, v1 in v0.items(): 71 | # njtd[k0, k1] = NJT.from_njt(v1) 72 | return njtd 73 | 74 | 75 | @pytest.fixture 76 | def default_device(): 77 | if torch.cuda.is_available(): 78 | yield torch.device("cuda:0") 79 | elif torch.backends.mps.is_available(): 80 | yield torch.device("mps:0") 81 | else: 82 | pytest.skip("CUDA/MPS is not available") 83 | 84 | 85 | @pytest.mark.parametrize( 86 | "compile_mode,num_threads", 87 | [ 88 | [False, None], 89 | # [False, 4], 90 | # [False, 16], 91 | ["default", None], 92 | ["reduce-overhead", None], 93 | ], 94 | ) 95 | @pytest.mark.skipif( 96 | TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5" 97 | ) 98 | class TestConsolidate: 99 | def test_consolidate( 100 | self, benchmark, td, compile_mode, num_threads, default_device 101 | ): 102 | tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb") 103 | 104 | # td = td.to(default_device) 105 | 106 | def consolidate(td, num_threads): 107 | return td.consolidate(num_threads=num_threads) 108 | 109 | if compile_mode: 110 | consolidate = torch.compile( 111 | consolidate, mode=compile_mode, dynamic=False, 
fullgraph=True 112 | ) 113 | 114 | t0 = time.time() 115 | consolidate(td, num_threads=num_threads) 116 | elapsed = time.time() - t0 117 | tensordict_logger.info(f"elapsed time first call: {elapsed:.2f} sec") 118 | 119 | for _ in range(3): 120 | consolidate(td, num_threads=num_threads) 121 | 122 | benchmark(consolidate, td, num_threads) 123 | 124 | def test_consolidate_njt(self, benchmark, njt_td, compile_mode, num_threads): 125 | tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb") 126 | 127 | def consolidate(td, num_threads): 128 | return td.consolidate(num_threads=num_threads) 129 | 130 | if compile_mode: 131 | pytest.skip( 132 | "Compiling NJTs consolidation currently triggers a RuntimeError." 133 | ) 134 | # consolidate = torch.compile(consolidate, mode=compile_mode, dynamic=True) 135 | 136 | for _ in range(3): 137 | consolidate(njt_td, num_threads=num_threads) 138 | 139 | benchmark(consolidate, njt_td, num_threads) 140 | 141 | 142 | @pytest.mark.parametrize( 143 | "consolidated,compile_mode,num_threads", 144 | [ 145 | [False, False, None], 146 | [True, False, None], 147 | ["within", False, None], 148 | # [True, False, 4], 149 | # [True, False, 16], 150 | [True, "default", None], 151 | ], 152 | ) 153 | @pytest.mark.skipif( 154 | TORCH_VERSION < version.parse("2.5.2"), reason="requires torch>=2.5" 155 | ) 156 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="no CUDA device found") 157 | class TestTo: 158 | def test_to( 159 | self, benchmark, consolidated, td, default_device, compile_mode, num_threads 160 | ): 161 | tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb") 162 | pin_mem = default_device.type == "cuda" 163 | if consolidated is True: 164 | td = td.consolidate(pin_memory=pin_mem) 165 | 166 | if consolidated == "within": 167 | 168 | def to(td, num_threads): 169 | return td.consolidate(pin_memory=pin_mem).to( 170 | default_device, num_threads=num_threads 171 | ) 172 | 173 | else: 174 | 175 | def to(td, num_threads): 176 | return td.to(default_device, num_threads=num_threads) 177 | 178 | if compile_mode: 179 | to = torch.compile(to, mode=compile_mode, dynamic=True) 180 | 181 | for _ in range(3): 182 | to(td, num_threads=num_threads) 183 | 184 | benchmark(to, td, num_threads) 185 | 186 | def test_to_njt( 187 | self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads 188 | ): 189 | if compile_mode: 190 | pytest.skip( 191 | "Compiling NJTs consolidation currently triggers a RuntimeError." 
192 | ) 193 | 194 | tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb") 195 | pin_mem = default_device.type == "cuda" 196 | if consolidated is True: 197 | njt_td = njt_td.consolidate(pin_memory=pin_mem) 198 | 199 | if consolidated == "within": 200 | 201 | def to(td, num_threads): 202 | return td.consolidate(pin_memory=pin_mem).to( 203 | default_device, num_threads=num_threads 204 | ) 205 | 206 | else: 207 | 208 | def to(td, num_threads): 209 | return td.to(default_device, num_threads=num_threads) 210 | 211 | if compile_mode: 212 | to = torch.compile(to, mode=compile_mode, dynamic=True) 213 | 214 | for _ in range(3): 215 | to(njt_td, num_threads=num_threads) 216 | 217 | benchmark(to, njt_td, num_threads) 218 | 219 | 220 | if __name__ == "__main__": 221 | args, unknown = argparse.ArgumentParser().parse_known_args() 222 | pytest.main( 223 | [ 224 | __file__, 225 | "--capture", 226 | "no", 227 | "--exitfirst", 228 | "--benchmark-group-by", 229 | "func", 230 | "-vvv", 231 | ] 232 | + unknown 233 | ) 234 | -------------------------------------------------------------------------------- /benchmarks/common/memmap_benchmarks_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from __future__ import annotations 6 | 7 | import argparse 8 | import pathlib 9 | import time 10 | import uuid 11 | from pathlib import Path 12 | 13 | import pytest 14 | import torch 15 | 16 | from tensordict import MemoryMappedTensor, TensorDict 17 | from torch import nn 18 | 19 | 20 | def get_available_devices(): 21 | devices = [torch.device("cpu")] 22 | n_cuda = torch.cuda.device_count() 23 | if n_cuda > 0: 24 | for i in range(n_cuda): 25 | devices += [torch.device(f"cuda:{i}")] 26 | return devices 27 | 28 | 29 | @pytest.fixture 30 | def tensor(): 31 | return torch.zeros(3, 4, 5) 32 | 33 | 34 | @pytest.fixture(params=[torch.device("cpu")]) 35 | def memmap_tensor(request): 36 | return MemoryMappedTensor.zeros((3, 4, 5)) 37 | 38 | 39 | @pytest.fixture 40 | def td_memmap(): 41 | return TensorDict( 42 | {str(i): torch.zeros(3, 40) + i for i in range(30)}, [3, 40] 43 | ).memmap_() 44 | 45 | 46 | @pytest.mark.parametrize("device", [torch.device("cpu")]) 47 | def test_creation(benchmark, device): 48 | benchmark(MemoryMappedTensor.empty, (3, 4, 5)) 49 | 50 | 51 | def test_creation_from_tensor(benchmark, tensor): 52 | benchmark( 53 | MemoryMappedTensor.from_tensor, 54 | tensor, 55 | ) 56 | 57 | 58 | def test_add_one(benchmark, memmap_tensor): 59 | benchmark(lambda: memmap_tensor + 1) 60 | 61 | 62 | def test_contiguous(benchmark, memmap_tensor): 63 | benchmark(lambda: memmap_tensor.contiguous()) 64 | 65 | 66 | def test_stack(benchmark, memmap_tensor): 67 | benchmark(torch.stack, [memmap_tensor] * 2, 0) 68 | 69 | 70 | def test_memmaptd_index(benchmark, td_memmap): 71 | benchmark( 72 | lambda td: td[0], 73 | td_memmap, 74 | ) 75 | 76 | 77 | def test_memmaptd_index_astensor(benchmark, td_memmap): 78 | benchmark( 79 | lambda td: td[0].as_tensor(), 80 | td_memmap, 81 | ) 82 | 83 | 84 | def test_memmaptd_index_op(benchmark, td_memmap): 85 | benchmark( 86 | lambda td: td[0].apply(lambda x: x + 1), 87 | td_memmap, 88 | ) 89 | 90 | 91 | @pytest.fixture(scope="function") 92 | def pause_when_exit(): 93 | yield None 94 | time.sleep(0.5) 95 | 96 | 97 | def test_serialize_model(benchmark, 
tmpdir, pause_when_exit): 98 | """Tests efficiency of saving weights as memmap tensors, including TD construction.""" 99 | has_cuda = torch.cuda.device_count() 100 | with torch.device("cuda" if has_cuda else "cpu"): 101 | t = nn.Transformer() 102 | 103 | def func(t=t, tmpdir=tmpdir): 104 | TensorDict.from_module(t).memmap(tmpdir, num_threads=32) 105 | 106 | benchmark(func) 107 | del t 108 | 109 | 110 | def test_serialize_model_pickle(benchmark, tmpdir, pause_when_exit): 111 | """Tests efficiency of pickling a model state-dict, including state-dict construction.""" 112 | has_cuda = torch.cuda.device_count() 113 | with torch.device("cuda" if has_cuda else "cpu"): 114 | t = nn.Transformer() 115 | path = Path(tmpdir) / "file.t" 116 | 117 | def func(t=t, path=path): 118 | torch.save(t.state_dict(), path) 119 | 120 | benchmark(func) 121 | del t 122 | 123 | 124 | def test_serialize_weights(benchmark, tmpdir, pause_when_exit): 125 | """Tests efficiency of saving weights as memmap tensors.""" 126 | has_cuda = torch.cuda.device_count() 127 | with torch.device("cuda" if has_cuda else "cpu"): 128 | t = nn.Transformer() 129 | 130 | weights = TensorDict.from_module(t) 131 | 132 | def func(weights=weights): 133 | weights.memmap(tmpdir, num_threads=32) 134 | 135 | benchmark(func) 136 | del t, weights 137 | 138 | 139 | def test_serialize_weights_returnearly(benchmark, tmpdir, pause_when_exit): 140 | """Tests efficiency of saving weights as memmap tensors, before writing is completed.""" 141 | has_cuda = torch.cuda.device_count() 142 | with torch.device("cuda" if has_cuda else "cpu"): 143 | t = nn.Transformer() 144 | 145 | datapath = pathlib.Path(tmpdir) 146 | weights = TensorDict.from_module(t) 147 | 148 | def func(weights=weights, datapath=datapath): 149 | weights.memmap(datapath / f"{uuid.uuid1()}", num_threads=32, return_early=True) 150 | 151 | benchmark(func) 152 | del t, weights 153 | 154 | 155 | def test_serialize_weights_pickle(benchmark, tmpdir, pause_when_exit): 156 | """Tests efficiency of pickling a model state-dict.""" 157 | has_cuda = torch.cuda.device_count() 158 | with torch.device("cuda" if has_cuda else "cpu"): 159 | t = nn.Transformer() 160 | 161 | path = Path(tmpdir) / "file.t" 162 | weights = t.state_dict() 163 | 164 | def func(path=path, weights=weights): 165 | torch.save(weights, path) 166 | 167 | benchmark(func) 168 | del t, weights 169 | 170 | 171 | def test_serialize_weights_filesystem(benchmark, pause_when_exit): 172 | """Tests efficiency of saving weights as memmap tensors.""" 173 | has_cuda = torch.cuda.device_count() 174 | if has_cuda: 175 | pytest.skip( 176 | "Multithreaded saving on filesystem with models on CUDA. " 177 | "These should be first cast on CPU for safety." 178 | ) 179 | with torch.device("cuda" if has_cuda else "cpu"): 180 | t = nn.Transformer() 181 | 182 | weights = TensorDict.from_module(t) 183 | 184 | def func(weights=weights): 185 | weights.memmap(num_threads=32) 186 | 187 | benchmark(func) 188 | del t, weights 189 | 190 | 191 | def test_serialize_model_filesystem(benchmark, pause_when_exit): 192 | """Tests efficiency of saving weights as memmap tensors in file system, including TD construction.""" 193 | has_cuda = torch.cuda.device_count() 194 | if has_cuda: 195 | pytest.skip( 196 | "Multithreaded saving on filesystem with models on CUDA. " 197 | "These should be first cast on CPU for safety." 
198 | ) 199 | with torch.device("cuda" if has_cuda else "cpu"): 200 | t = nn.Transformer() 201 | 202 | def func(t=t): 203 | TensorDict.from_module(t).memmap(num_threads=32) 204 | 205 | benchmark(func) 206 | del t 207 | 208 | 209 | if __name__ == "__main__": 210 | args, unknown = argparse.ArgumentParser().parse_known_args() 211 | pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) 212 | -------------------------------------------------------------------------------- /benchmarks/common/pytree_benchmarks_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from __future__ import annotations 6 | 7 | import pytest 8 | import torch 9 | 10 | from tensordict import TensorDict 11 | from torch.utils._pytree import tree_map 12 | 13 | 14 | @pytest.fixture 15 | def nested_dict(): 16 | return { 17 | "a": {"b": torch.randn(3, 4, 1), "c": {"d": torch.rand(3, 4, 5, 6)}}, 18 | "c": torch.rand(3, 4, 1), 19 | } 20 | 21 | 22 | @pytest.fixture 23 | def nested_td(nested_dict): 24 | return TensorDict(nested_dict, [3, 4]) 25 | 26 | 27 | # reshape 28 | def test_reshape_pytree(benchmark, nested_dict): 29 | benchmark(tree_map, lambda x: x.reshape(12, *x.shape[2:]), nested_dict) 30 | 31 | 32 | def test_reshape_td(benchmark, nested_td): 33 | benchmark( 34 | nested_td.reshape, 35 | 12, 36 | ) 37 | 38 | 39 | # view 40 | def test_view_pytree(benchmark, nested_dict): 41 | benchmark(tree_map, lambda x: x.view(12, *x.shape[2:]), nested_dict) 42 | 43 | 44 | def test_view_td(benchmark, nested_td): 45 | benchmark( 46 | nested_td.view, 47 | 12, 48 | ) 49 | 50 | 51 | # unbind 52 | def test_unbind_pytree(benchmark, nested_dict): 53 | benchmark(tree_map, lambda x: x.unbind(0), nested_dict) 54 | 55 | 56 | def test_unbind_td(benchmark, nested_td): 57 | benchmark( 58 | nested_td.unbind, 59 | 0, 60 | ) 61 | 62 | 63 | # split 64 | def test_split_pytree(benchmark, nested_dict): 65 | benchmark(tree_map, lambda x: x.split([1, 2], 0), nested_dict) 66 | 67 | 68 | def test_split_td(benchmark, nested_td): 69 | benchmark(nested_td.split, [1, 2], 0) 70 | 71 | 72 | # add 73 | def test_add_pytree(benchmark, nested_dict): 74 | benchmark(tree_map, lambda x: x + 1, nested_dict) 75 | 76 | 77 | def test_add_td(benchmark, nested_td): 78 | benchmark( 79 | nested_td.apply, 80 | lambda x: x + 1, 81 | ) 82 | -------------------------------------------------------------------------------- /benchmarks/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
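# Descriptive note (added): shared pytest configuration for the benchmark
# suite. An autouse fixture accumulates per-test wall-clock times in
# CALL_TIMES, and pytest_sessionfinish logs the slowest entries when the
# session ends. It also registers the custom `--rank` and `--runslow`
# command-line options used by the benchmarks below.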
5 | import logging 6 | import os 7 | import time 8 | from collections import defaultdict 9 | 10 | import pytest 11 | 12 | CALL_TIMES = defaultdict(lambda: 0.0) 13 | 14 | 15 | def pytest_sessionfinish(maxprint=50): 16 | out_str = """ 17 | Call times: 18 | =========== 19 | """ 20 | keys = list(CALL_TIMES.keys()) 21 | if len(keys) > 1: 22 | maxchar = max(*[len(key) for key in keys]) 23 | elif len(keys): 24 | maxchar = len(keys[0]) 25 | else: 26 | return 27 | for i, (key, item) in enumerate( 28 | sorted(CALL_TIMES.items(), key=lambda x: x[1], reverse=True) 29 | ): 30 | spaces = " " + " " * (maxchar - len(key)) 31 | out_str += f"\t{key}{spaces}{item: 4.4f}s\n" 32 | if i == maxprint - 1: 33 | break 34 | logging.info(out_str) 35 | 36 | 37 | @pytest.fixture(autouse=True) 38 | def measure_duration(request: pytest.FixtureRequest): 39 | start_time = time.time() 40 | 41 | def fin(): 42 | duration = time.time() - start_time 43 | name = request.node.name 44 | class_name = request.cls.__name__ if request.cls else None 45 | name = name.split("[")[0] 46 | if class_name is not None: 47 | name = "::".join([class_name, name]) 48 | file = os.path.basename(request.path) 49 | name = f"{file}::{name}" 50 | CALL_TIMES[name] = CALL_TIMES[name] + duration 51 | 52 | request.addfinalizer(fin) 53 | 54 | 55 | def pytest_configure(config): 56 | config.addinivalue_line("markers", "slow: mark test as slow to run") 57 | 58 | 59 | def pytest_collection_modifyitems(config, items): 60 | if config.getoption("--runslow"): 61 | # --runslow given in cli: do not skip slow tests 62 | return 63 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 64 | for item in items: 65 | if "slow" in item.keywords: 66 | item.add_marker(skip_slow) 67 | 68 | 69 | def pytest_addoption(parser): 70 | parser.addoption("--rank", action="store") 71 | parser.addoption( 72 | "--runslow", action="store_true", default=False, help="run slow tests" 73 | ) 74 | -------------------------------------------------------------------------------- /benchmarks/distributed/distributed_benchmark_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
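# Descriptive note (added): this benchmark exercises TensorDict exchange over
# torch.distributed.rpc. The main node (rank 0) creates a memory-mapped
# TensorDict on shared storage and asks the worker node (rank 1) to fill
# slices of it remotely. It is meant to be launched as two processes, with
# `--rank 0` and `--rank 1` respectively.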
5 | import os 6 | import pathlib 7 | import tempfile 8 | import time 9 | 10 | import pytest 11 | import torch 12 | 13 | from tensordict import TensorDict 14 | from torch.distributed import rpc 15 | 16 | MAIN_NODE = "Main" 17 | WORKER_NODE = "worker" 18 | 19 | 20 | @pytest.fixture 21 | def rank(pytestconfig): 22 | return pytestconfig.getoption("rank") 23 | 24 | 25 | def test_distributed(benchmark, rank): 26 | benchmark(exec_distributed_test, rank) 27 | 28 | 29 | class CloudpickleWrapper(object): 30 | def __init__(self, fn): 31 | self.fn = fn 32 | 33 | def __getstate__(self): 34 | import cloudpickle 35 | 36 | return cloudpickle.dumps(self.fn) 37 | 38 | def __setstate__(self, ob: bytes): 39 | import pickle 40 | 41 | self.fn = pickle.loads(ob) 42 | 43 | def __call__(self, *args, **kwargs): 44 | return self.fn(*args, **kwargs) 45 | 46 | 47 | def exec_distributed_test(rank_node): 48 | with tempfile.TemporaryDirectory() as tmpdir: 49 | tmpdir = pathlib.Path(tmpdir) 50 | os.environ["MASTER_ADDR"] = "localhost" 51 | os.environ["MASTER_PORT"] = "29549" 52 | os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" 53 | str_init_method = "tcp://localhost:10001" 54 | options = rpc.TensorPipeRpcBackendOptions( 55 | num_worker_threads=16, init_method=str_init_method 56 | ) 57 | rank = rank_node 58 | if rank == 0: 59 | rpc.init_rpc( 60 | MAIN_NODE, 61 | rank=rank, 62 | backend=rpc.BackendType.TENSORPIPE, 63 | rpc_backend_options=options, 64 | ) 65 | 66 | # create a tensordict is 1Gb big, stored on disk, assuming that both nodes have access to /tmp/ 67 | tensordict = TensorDict( 68 | { 69 | "memmap": torch.empty((), dtype=torch.uint8).expand( 70 | (1000, 640, 640, 3) 71 | ) 72 | }, 73 | [1000], 74 | ).memmap_(tmpdir, copy_existing=False) 75 | assert tensordict.is_memmap() 76 | 77 | while True: 78 | try: 79 | worker_info = rpc.get_worker_info("worker") 80 | break 81 | except RuntimeError: 82 | time.sleep(0.1) 83 | 84 | def fill_tensordict(tensordict, idx): 85 | tensordict[idx] = TensorDict( 86 | {"memmap": torch.ones(5, 640, 640, 3, dtype=torch.uint8)}, [5] 87 | ) 88 | return tensordict 89 | 90 | fill_tensordict_cp = CloudpickleWrapper(fill_tensordict) 91 | idx = [0, 1, 2, 3, 999] 92 | rpc.rpc_sync(worker_info, fill_tensordict_cp, args=(tensordict, idx)) 93 | 94 | idx = [4, 5, 6, 7, 998] 95 | rpc.rpc_sync(worker_info, fill_tensordict_cp, args=(tensordict, idx)) 96 | 97 | rpc.shutdown() 98 | elif rank == 1: 99 | rpc.init_rpc( 100 | WORKER_NODE, 101 | rank=rank, 102 | backend=rpc.BackendType.TENSORPIPE, 103 | rpc_backend_options=options, 104 | ) 105 | -------------------------------------------------------------------------------- /benchmarks/fx_benchmarks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
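# Descriptive note (added): compares the forward-pass latency of
# TensorDictModule/TensorDictSequential stacks against their fx-traced
# GraphModule counterparts obtained with
# tensordict.prototype.fx.symbolic_trace, for both flat and nested sequences.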
5 | 6 | import logging 7 | import timeit 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from tensordict import TensorDict 13 | from tensordict.nn import TensorDictModule, TensorDictSequential 14 | from tensordict.prototype.fx import symbolic_trace 15 | 16 | 17 | # modules for sequential benchmark 18 | class Net(nn.Module): 19 | def __init__(self, input_size=100, hidden_size=50, output_size=10): 20 | super().__init__() 21 | self.fc1 = nn.Linear(input_size, hidden_size) 22 | self.fc2 = nn.Linear(hidden_size, output_size) 23 | 24 | def forward(self, x): 25 | x = torch.relu(self.fc1(x)) 26 | return self.fc2(x) 27 | 28 | 29 | class Masker(nn.Module): 30 | def forward(self, x, mask): 31 | return torch.softmax(x * mask, dim=1) 32 | 33 | 34 | # modules for nested sequential benchmark 35 | class FCLayer(nn.Module): 36 | def __init__(self, input_size, output_size): 37 | super().__init__() 38 | self.fc = nn.Linear(input_size, output_size) 39 | 40 | def forward(self, x): 41 | return torch.relu(self.fc(x)) 42 | 43 | 44 | class Output(nn.Module): 45 | def __init__(self, input_size, output_size=10): 46 | super().__init__() 47 | self.fc = nn.Linear(input_size, output_size) 48 | 49 | def forward(self, x): 50 | return torch.softmax(self.fc(x), dim=1) 51 | 52 | 53 | if __name__ == "__main__": 54 | net = TensorDictModule( 55 | Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")] 56 | ) 57 | masker = TensorDictModule( 58 | Masker(), 59 | in_keys=[("intermediate", "x"), ("input", "mask")], 60 | out_keys=[("output", "probabilities")], 61 | ) 62 | module = TensorDictSequential(net, masker) 63 | graph_module = symbolic_trace(module) 64 | 65 | tensordict = TensorDict( 66 | { 67 | "input": TensorDict( 68 | {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))}, 69 | batch_size=[32], 70 | ) 71 | }, 72 | batch_size=[32], 73 | ) 74 | 75 | logging.info( 76 | "forward, TensorDictSequential", 77 | timeit.timeit( 78 | "module(tensordict)", 79 | globals={"tensordict": tensordict, "module": module}, 80 | number=10_000, 81 | ), 82 | ) 83 | 84 | logging.info( 85 | "forward, GraphModule", 86 | timeit.timeit( 87 | "module(tensordict)", 88 | globals={"tensordict": tensordict, "module": graph_module}, 89 | number=10_000, 90 | ), 91 | ) 92 | 93 | tdmodule1 = TensorDictModule(FCLayer(100, 50), ["input"], ["x"]) 94 | tdmodule2 = TensorDictModule(FCLayer(50, 40), ["x"], ["x"]) 95 | tdmodule3 = TensorDictModule(Output(40, 10), ["x"], ["probabilities"]) 96 | nested_tdmodule = TensorDictSequential( 97 | TensorDictSequential(tdmodule1, tdmodule2), tdmodule3 98 | ) 99 | 100 | nested_graph_module = symbolic_trace(nested_tdmodule) 101 | tensordict = TensorDict({"input": torch.rand(32, 100)}, [32]) 102 | 103 | logging.info( 104 | "nested_forward, TensorDictSequential", 105 | timeit.timeit( 106 | "module(tensordict)", 107 | globals={"tensordict": tensordict, "module": nested_tdmodule}, 108 | number=10_000, 109 | ), 110 | ) 111 | 112 | logging.info( 113 | "nested_forward, GraphModule", 114 | timeit.timeit( 115 | "module(tensordict)", 116 | globals={"tensordict": tensordict, "module": nested_graph_module}, 117 | number=10_000, 118 | ), 119 | ) 120 | -------------------------------------------------------------------------------- /benchmarks/nn/functional_benchmarks_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | # we use deepcopy as our implementation modifies the modules in-place 8 | import argparse 9 | from copy import deepcopy 10 | 11 | import pytest 12 | import torch 13 | from functorch import make_functional_with_buffers as functorch_make_functional 14 | 15 | from tensordict import TensorDict 16 | from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential 17 | 18 | from torch import nn, vmap 19 | 20 | 21 | def make_net(): 22 | return nn.Sequential( 23 | nn.Linear(2, 2), 24 | nn.Linear(2, 2), 25 | nn.Linear(2, 2), 26 | nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2), nn.Linear(2, 2)), 27 | ) 28 | 29 | 30 | @pytest.fixture 31 | def net(): 32 | return make_net() 33 | 34 | 35 | def _functorch_make_functional(net): 36 | functorch_make_functional(deepcopy(net)) 37 | 38 | 39 | def make_tdmodule(): 40 | return ( 41 | ( 42 | TensorDictModule(lambda x: x, in_keys=["x"], out_keys=["y"]), 43 | TensorDict({"x": torch.zeros(())}, []), 44 | ), 45 | {}, 46 | ) 47 | 48 | 49 | def test_tdmodule(benchmark): 50 | benchmark.pedantic( 51 | lambda net, td: net(td), 52 | setup=make_tdmodule, 53 | warmup_rounds=10, 54 | rounds=1000, 55 | iterations=1, 56 | ) 57 | 58 | 59 | def make_tdmodule_dispatch(): 60 | return ( 61 | (TensorDictModule(lambda x: x, in_keys=["x"], out_keys=["y"]), torch.zeros(())), 62 | {}, 63 | ) 64 | 65 | 66 | def test_tdmodule_dispatch(benchmark): 67 | benchmark.pedantic( 68 | lambda net, x: net(x), 69 | setup=make_tdmodule_dispatch, 70 | warmup_rounds=10, 71 | rounds=1000, 72 | iterations=1, 73 | ) 74 | 75 | 76 | def make_tdseq(): 77 | class MyModule(TensorDictModuleBase): 78 | in_keys = ["x"] 79 | out_keys = ["y"] 80 | 81 | def forward(self, tensordict): 82 | return tensordict.set("y", tensordict.get("x")) 83 | 84 | return ( 85 | (TensorDictSequential(MyModule()), TensorDict({"x": torch.zeros(())}, [])), 86 | {}, 87 | ) 88 | 89 | 90 | def test_tdseq(benchmark): 91 | benchmark.pedantic( 92 | lambda net, td: net(td), setup=make_tdseq, warmup_rounds=10, rounds=1000 93 | ) 94 | 95 | 96 | def make_tdseq_dispatch(): 97 | class MyModule(TensorDictModuleBase): 98 | in_keys = ["x"] 99 | out_keys = ["y"] 100 | 101 | def forward(self, tensordict): 102 | return tensordict.set("y", tensordict.get("x")) 103 | 104 | return ((TensorDictSequential(MyModule()), torch.zeros(())), {}) 105 | 106 | 107 | def test_tdseq_dispatch(benchmark): 108 | benchmark.pedantic( 109 | lambda net, x: net(x), setup=make_tdseq_dispatch, warmup_rounds=10, rounds=1000 110 | ) 111 | 112 | 113 | # Creation 114 | def test_instantiation_functorch(benchmark, net): 115 | benchmark(_functorch_make_functional, net) 116 | 117 | 118 | # Execution 119 | def test_exec_functorch(benchmark, net): 120 | x = torch.randn(2, 2) 121 | sd = net.state_dict() 122 | 123 | def fun(x, sd): 124 | torch.func.functional_call(net, sd, x) 125 | 126 | benchmark(fun, x, sd) 127 | 128 | 129 | def test_exec_functional_call(benchmark, net): 130 | x = torch.randn(2, 2) 131 | fmodule, params, buffers = functorch_make_functional(net) 132 | benchmark(fmodule, params, buffers, x) 133 | 134 | 135 | def test_exec_td_decorator(benchmark, net): 136 | x = torch.randn(2, 2) 137 | fmodule = net 138 | params = TensorDict.from_module(fmodule) 139 | 140 | def fun(x, params): 141 | with params.to_module(net): 142 | net(x) 143 | 144 | benchmark(fun, x, params) 145 | 146 | 147 | @torch.no_grad() 148 | 
@pytest.mark.parametrize("stack", [True, False]) 149 | @pytest.mark.parametrize("tdmodule", [True, False]) 150 | def test_vmap_mlp_speed_decorator(benchmark, stack, tdmodule): 151 | # tests speed of vmapping over a transformer 152 | device = "cuda" if torch.cuda.device_count() else "cpu" 153 | t = nn.Sequential( 154 | nn.Linear(64, 64, device=device), 155 | nn.ReLU(), 156 | nn.Linear(64, 64, device=device), 157 | nn.ReLU(), 158 | nn.Linear(64, 64, device=device), 159 | nn.ReLU(), 160 | nn.Linear(64, 64, device=device), 161 | nn.ReLU(), 162 | ) 163 | if tdmodule: 164 | t = TensorDictModule(t, in_keys=["x"], out_keys=["y"]) 165 | 166 | x = torch.randn(1, 1, 64, device=device) 167 | t.eval() 168 | params = TensorDict.from_module(t) 169 | if not stack: 170 | params = params.expand(2).to_tensordict().lock_() 171 | else: 172 | params = torch.stack([params, params.clone()], 0).lock_() 173 | 174 | def fun(x, params): 175 | with params.to_module(t): 176 | return t(x) 177 | 178 | vfun = vmap(fun, (None, 0)) 179 | 180 | if tdmodule: 181 | data = TensorDict({"x": x}, []) 182 | vfun(data, params) 183 | benchmark(vfun, data, params) 184 | else: 185 | vfun(x, params) 186 | benchmark(vfun, x, params) 187 | 188 | 189 | @torch.no_grad() 190 | @pytest.mark.skipif( 191 | not torch.cuda.device_count(), reason="cuda device required for test" 192 | ) 193 | @pytest.mark.parametrize("stack", [True, False]) 194 | @pytest.mark.parametrize("tdmodule", [True, False]) 195 | def test_vmap_transformer_speed_decorator(benchmark, stack, tdmodule): 196 | # tests speed of vmapping over a transformer 197 | device = "cuda" if torch.cuda.device_count() else "cpu" 198 | t = torch.nn.Transformer( 199 | 8, 200 | dim_feedforward=8, 201 | device=device, 202 | batch_first=False, 203 | ) 204 | if tdmodule: 205 | t = TensorDictModule(t, in_keys=["x", "x"], out_keys=["y"]) 206 | 207 | x = torch.randn(2, 2, 8, device=device) 208 | t.eval() 209 | params = TensorDict.from_module(t) 210 | if not stack: 211 | params = params.expand(2).to_tensordict().lock_() 212 | else: 213 | params = torch.stack([params, params.clone()], 0).lock_() 214 | 215 | if tdmodule: 216 | 217 | def fun(x, params): 218 | with params.to_module(t): 219 | return t(x) 220 | 221 | vfun = vmap(fun, (None, 0)) 222 | data = TensorDict({"x": x}, []) 223 | vfun(data, params) 224 | benchmark(vfun, data, params) 225 | else: 226 | 227 | def fun(x, params): 228 | with params.to_module(t): 229 | return t(x, x) 230 | 231 | vfun = vmap(fun, (None, 0)) 232 | vfun(x, params) 233 | benchmark(vfun, x, params) 234 | 235 | 236 | @pytest.mark.parametrize("tdparams", [True, False]) 237 | def test_to_module_speed(benchmark, tdparams): 238 | module = torch.nn.Transformer() 239 | params = TensorDict.from_module(module, as_module=tdparams) 240 | 241 | def func(params=params, module=module): 242 | with params.to_module(module): 243 | pass 244 | return 245 | 246 | benchmark(func) 247 | 248 | 249 | if __name__ == "__main__": 250 | args, unknown = argparse.ArgumentParser().parse_known_args() 251 | pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) 252 | -------------------------------------------------------------------------------- /benchmarks/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest-benchmark 2 | tenacity 3 | -------------------------------------------------------------------------------- /benchmarks/tensorclass/test_tensorclass_speed.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from __future__ import annotations 6 | 7 | import argparse 8 | 9 | import pytest 10 | import torch 11 | 12 | from tensordict import tensorclass, TensorClass 13 | 14 | 15 | @tensorclass 16 | class MyData: 17 | a: torch.Tensor 18 | b: torch.Tensor 19 | c: str 20 | d: "MyData" = None 21 | 22 | 23 | class MyDataTensorOnly(TensorClass["tensor_only"]): 24 | a: torch.Tensor 25 | b: torch.Tensor 26 | c: torch.Tensor | None 27 | 28 | 29 | def test_tc_init(benchmark): 30 | z = torch.zeros(()) 31 | o = torch.ones(()) 32 | benchmark(lambda: MyData(a=z, b=o, c="a string", d=None)) 33 | 34 | 35 | def test_tc_init_tensor_only(benchmark): 36 | z = torch.zeros(()) 37 | o = torch.ones(()) 38 | benchmark(lambda: MyDataTensorOnly(a=z, b=o, c=None)) 39 | 40 | 41 | def test_tc_init_nested(benchmark): 42 | z = torch.zeros(()) 43 | o = torch.ones(()) 44 | benchmark( 45 | lambda: MyData(a=z, b=o, c="a string", d=MyData(a=z, b=o, c="a string", d=None)) 46 | ) 47 | 48 | 49 | def test_tc_first_layer_tensor(benchmark): 50 | d = MyData(a=0, b=1, c="a string", d=MyData(None, None, None)) 51 | 52 | def get(): 53 | return d.a 54 | 55 | benchmark(get) 56 | 57 | 58 | def test_tc_first_layer_tensor_only(benchmark): 59 | z = torch.zeros(()) 60 | o = torch.ones(()) 61 | d = MyDataTensorOnly(a=z, b=o, c=None) 62 | 63 | def get(): 64 | return d.a 65 | 66 | benchmark(get) 67 | 68 | 69 | def test_tc_first_layer_tensor_set(benchmark): 70 | d = MyData(a=0, b=1, c="a string", d=MyData(None, None, None)) 71 | z = torch.zeros(()) 72 | 73 | def set(d=d, z=z): 74 | d.a = z 75 | 76 | benchmark(set) 77 | 78 | 79 | def test_tc_first_layer_tensor_only_set(benchmark): 80 | z = torch.zeros(()) 81 | o = torch.ones(()) 82 | d = MyDataTensorOnly(a=z, b=o, c=None) 83 | 84 | def set(d=d, z=z): 85 | d.a = z 86 | 87 | benchmark(set) 88 | 89 | 90 | def test_tc_first_layer_nontensor(benchmark): 91 | d = MyData(a=0, b=1, c="a string", d=MyData(None, None, None)) 92 | benchmark(lambda: d.c) 93 | 94 | 95 | def test_tc_second_layer_tensor(benchmark): 96 | d = MyData(a=0, b=1, c="a string", d=MyData(torch.zeros(()), None, None)) 97 | benchmark(lambda: d.d.a) 98 | 99 | 100 | def test_tc_second_layer_nontensor(benchmark): 101 | d = MyData(a=0, b=1, c="a string", d=MyData(torch.zeros(()), None, "a string")) 102 | benchmark(lambda: d.d.c) 103 | 104 | 105 | if __name__ == "__main__": 106 | args, unknown = argparse.ArgumentParser().parse_known_args() 107 | pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) 108 | -------------------------------------------------------------------------------- /benchmarks/tensorclass/test_torch_functions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
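# Descriptive note (added): benchmarks torch.* operations (unbind, full_like,
# stack, cat, permute, ...) applied to @tensorclass instances.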
5 | from __future__ import annotations 6 | 7 | import argparse 8 | 9 | import pytest 10 | import torch 11 | 12 | from tensordict import tensorclass 13 | 14 | 15 | @tensorclass 16 | class MyData: 17 | a: torch.Tensor 18 | b: torch.Tensor 19 | other: str 20 | nested: "MyData" = None 21 | 22 | 23 | @pytest.fixture 24 | def a(): 25 | return torch.zeros(300, 400, 50) 26 | 27 | 28 | @pytest.fixture 29 | def b(): 30 | return torch.zeros(300, 400, 50) 31 | 32 | 33 | @pytest.fixture 34 | def tc(a, b): 35 | return MyData( 36 | a=a, 37 | b=b, 38 | other="hello", 39 | nested=MyData( 40 | a=a.clone(), b=b.clone(), other="goodbye", batch_size=[300, 400, 50] 41 | ), 42 | batch_size=[300, 400], 43 | ) 44 | 45 | 46 | def test_unbind(benchmark, tc): 47 | benchmark(torch.unbind, tc, 0) 48 | 49 | 50 | def test_full_like(benchmark, tc): 51 | benchmark(torch.full_like, tc, 2.0) 52 | 53 | 54 | def test_zeros_like(benchmark, tc): 55 | benchmark( 56 | torch.zeros_like, 57 | tc, 58 | ) 59 | 60 | 61 | def test_ones_like(benchmark, tc): 62 | benchmark( 63 | torch.ones_like, 64 | tc, 65 | ) 66 | 67 | 68 | def test_clone(benchmark, tc): 69 | benchmark( 70 | torch.clone, 71 | tc, 72 | ) 73 | 74 | 75 | def test_squeeze(benchmark, tc): 76 | benchmark( 77 | torch.squeeze, 78 | tc, 79 | ) 80 | 81 | 82 | def test_unsqueeze(benchmark, tc): 83 | benchmark(torch.unsqueeze, tc, 0) 84 | 85 | 86 | def test_split(benchmark, tc): 87 | benchmark(torch.split, tc, [200, 100]) 88 | 89 | 90 | def test_permute(benchmark, tc): 91 | benchmark(torch.permute, tc, [1, 0]) 92 | 93 | 94 | def test_stack(benchmark, tc): 95 | benchmark(torch.stack, [tc] * 3, 0) 96 | 97 | 98 | def test_cat(benchmark, tc): 99 | benchmark(torch.cat, [tc] * 3, 0) 100 | 101 | 102 | if __name__ == "__main__": 103 | args, unknown = argparse.ArgumentParser().parse_known_args() 104 | pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) 105 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | DATADIR = _data 11 | 12 | ZIPOPTS ?= -qo 13 | 14 | # Put it first so that "make" without argument is like "make help". 15 | help: 16 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 17 | 18 | .PHONY: help Makefile 19 | 20 | # Catch-all target: route all unknown targets to Sphinx using the new 21 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
22 | %: Makefile
23 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
24 | 
25 | download:
26 | 	# imagenet tutorial data
27 | 	wget -nv -N https://download.pytorch.org/tutorial/hymenoptera_data.zip -P $(DATADIR)
28 | 	mkdir -p $(SOURCEDIR)/reference/generated/tutorials/data
29 | 	unzip $(ZIPOPTS) $(DATADIR)/hymenoptera_data.zip -d $(SOURCEDIR)/reference/generated/tutorials/data
30 | 
31 | docs:
32 | 	make download
33 | 	make html
--------------------------------------------------------------------------------
/docs/build_script.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | rm -rf _local_build build generated
4 | #sphinx-autogen -o generated source/reference/*.rst && sphinx-build ./source _local_build &&
5 | make docs
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
 1 | @ECHO OFF
 2 | 
 3 | pushd %~dp0
 4 | 
 5 | REM Command file for Sphinx documentation
 6 | 
 7 | if "%SPHINXBUILD%" == "" (
 8 | 	set SPHINXBUILD=sphinx-build
 9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 | 
13 | if "%1" == "" goto help
14 | 
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | 	echo.
18 | 	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | 	echo.installed, then set the SPHINXBUILD environment variable to point
20 | 	echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | 	echo.may add the Sphinx directory to PATH.
22 | 	echo.
23 | 	echo.If you don't have Sphinx installed, grab it from
24 | 	echo.https://www.sphinx-doc.org/
25 | 	exit /b 1
26 | )
27 | 
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 | 
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 | 
34 | :end
35 | popd
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
 1 | matplotlib
 2 | numpy
 3 | sphinx-copybutton
 4 | sphinx-gallery
 5 | sphinx==5.0.0
 6 | Jinja2==3.1.4
 7 | sphinx-autodoc-typehints
 8 | sphinx-serve
 9 | git+https://github.com/vmoens/aafig@4319769eae88fff8e3464858f3cf8c277f35335d
10 | sphinxcontrib-htmlhelp
11 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
12 | myst-parser
13 | docutils
14 | 
15 | torchvision
16 | tqdm
--------------------------------------------------------------------------------
/docs/source/_static/css/custom_torchrl.css:
--------------------------------------------------------------------------------
1 | article.pytorch-article .sphx-glr-download-link-note.admonition.note,
2 | article.pytorch-article .reference.download.internal, article.pytorch-article .sphx-glr-signature {
3 |     display: block;
4 | }
--------------------------------------------------------------------------------
/docs/source/_static/img/pytorch-logo-dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/tensordict/73fe89bc067b40219dc0d1245b655bcec85464f3/docs/source/_static/img/pytorch-logo-dark.png
--------------------------------------------------------------------------------
/docs/source/_static/img/pytorch-logo-dark.svg:
--------------------------------------------------------------------------------
(SVG markup not rendered in this dump.)
--------------------------------------------------------------------------------
/docs/source/_static/img/pytorch-logo-flame.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/tensordict/73fe89bc067b40219dc0d1245b655bcec85464f3/docs/source/_static/img/pytorch-logo-flame.png
--------------------------------------------------------------------------------
/docs/source/_static/img/pytorch-logo-flame.svg:
--------------------------------------------------------------------------------
(SVG markup not rendered in this dump; media type image/svg+xml.)
--------------------------------------------------------------------------------
/docs/source/_templates/class.rst:
--------------------------------------------------------------------------------
1 | .. currentmodule:: {{ module }}
2 | 
3 | 
4 | {{ name | underline}}
5 | 
6 | .. autoclass:: {{ name }}
7 |    :members:
--------------------------------------------------------------------------------
/docs/source/_templates/function.rst:
--------------------------------------------------------------------------------
1 | .. currentmodule:: {{ module }}
2 | 
3 | 
4 | {{ name | underline}}
5 | 
6 | .. autofunction:: {{ name }}
--------------------------------------------------------------------------------
/docs/source/_templates/td_template.rst:
--------------------------------------------------------------------------------
1 | .. currentmodule:: {{ module }}
2 | 
3 | 
4 | {{ name | underline}}
5 | 
6 | .. autoclass:: {{ name }}
7 |    :members:
8 |    :inherited-members:
--------------------------------------------------------------------------------
/docs/source/_templates/td_template_noinherit.rst:
--------------------------------------------------------------------------------
1 | .. currentmodule:: {{ module }}
2 | 
3 | 
4 | {{ name | underline}}
5 | 
6 | .. autoclass:: {{ name }}
7 |    :members:
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | # Configuration file for the Sphinx documentation builder.
 7 | #
 8 | # This file only contains a selection of the most common options. For a full
 9 | # list see the documentation:
10 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
11 | 
12 | # -- Path setup --------------------------------------------------------------
13 | 
14 | # If extensions (or modules to document with autodoc) are in another directory,
15 | # add these directories to sys.path here. If the directory is relative to the
16 | # documentation root, use os.path.abspath to make it absolute, like shown here.
17 | #
18 | # import os
19 | # import sys
20 | # sys.path.insert(0, os.path.abspath('.'))
21 | 
22 | 
23 | # -- Project information -----------------------------------------------------
24 | import os
25 | import sys
26 | 
27 | import pytorch_sphinx_theme
28 | 
29 | import tensordict
30 | 
31 | project = "tensordict"
32 | copyright = "2022, Meta"
33 | author = "Torch Contributors"
34 | 
35 | # The version info for the project you're documenting, acts as replacement for
36 | # |version| and |release|, also used in various other places throughout the
37 | # built documents.
38 | # version: The short X.Y version.
39 | # release: The full version, including alpha/beta/rc tags. 40 | if os.environ.get("TENSORDICT_SANITIZE_VERSION_STR_IN_DOCS", None): 41 | # Turn 1.11.0aHASH into 1.11 (major.minor only) 42 | version = release = ".".join(tensordict.__version__.split(".")[:2]) 43 | html_title = " ".join((project, version, "documentation")) 44 | else: 45 | version = f"main ({tensordict.__version__})" 46 | release = "main" 47 | 48 | # The language for content autogenerated by Sphinx. Refer to documentation 49 | # for a list of supported languages. 50 | # 51 | # This is also used if you do content translation via gettext catalogs. 52 | # Usually you set "language" from the command line for these cases. 53 | language = "en" 54 | 55 | # -- General configuration --------------------------------------------------- 56 | 57 | # Add any Sphinx extension module names here, as strings. They can be 58 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 59 | # ones. 60 | extensions = [ 61 | "sphinx.ext.napoleon", 62 | "sphinx.ext.autodoc", 63 | "sphinx.ext.autosummary", 64 | "sphinx.ext.doctest", 65 | "sphinx.ext.intersphinx", 66 | "sphinx.ext.mathjax", 67 | "sphinx_gallery.gen_gallery", 68 | "sphinxcontrib.aafig", 69 | "myst_parser", 70 | ] 71 | 72 | sphinx_gallery_conf = { 73 | "examples_dirs": "reference/generated/tutorials/", # path to your example scripts 74 | "gallery_dirs": "tutorials", # path to where to save gallery generated output 75 | "backreferences_dir": "gen_modules/backreferences", 76 | "doc_module": ("tensordict",), 77 | "filename_pattern": "reference/generated/tutorials/", # files to parse 78 | "notebook_images": "reference/generated/tutorials/media/", # images to parse 79 | "download_all_examples": True, 80 | } 81 | 82 | # sphinx_gallery_conf = { 83 | # "examples_dirs": "../../gallery/", # path to your example scripts 84 | # "gallery_dirs": "auto_examples", # path to where to save gallery generated output 85 | # "backreferences_dir": "gen_modules/backreferences", 86 | # "doc_module": ("tensordict",), 87 | # } 88 | 89 | napoleon_use_ivar = True 90 | napoleon_numpy_docstring = False 91 | napoleon_google_docstring = True 92 | autosectionlabel_prefix_document = True 93 | 94 | # Add any paths that contain templates here, relative to this directory. 95 | templates_path = ["_templates"] 96 | 97 | # The suffix(es) of source filenames. 98 | # You can specify multiple suffix as a list of string: 99 | # 100 | source_suffix = { 101 | ".rst": "restructuredtext", 102 | } 103 | 104 | # The master toctree document. 105 | master_doc = "index" 106 | 107 | # List of patterns, relative to source directory, that match files and 108 | # directories to ignore when looking for source files. 109 | # This pattern also affects html_static_path and html_extra_path. 110 | exclude_patterns = ["reference/generated/tutorials/README.rst"] 111 | 112 | # -- Options for HTML output ------------------------------------------------- 113 | 114 | # The theme to use for HTML and HTML Help pages. See the documentation for 115 | # a list of builtin themes. 116 | # 117 | html_theme = "pytorch_sphinx_theme" 118 | html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] 119 | html_theme_options = { 120 | # "pytorch_project": "tensordict", 121 | "collapse_navigation": False, 122 | "display_version": True, 123 | "logo_only": True, 124 | "pytorch_project": "docs", 125 | "navigation_with_keys": True, 126 | "analytics_id": "GTM-T8XT4PS", 127 | } 128 | 129 | # Output file base name for HTML help builder. 
130 | htmlhelp_basename = "PyTorchdoc"
131 | 
132 | autosummary_generate = True
133 | 
134 | # Add any paths that contain custom static files (such as style sheets) here,
135 | # relative to this directory. They are copied after the builtin static files,
136 | # so a file named "default.css" will overwrite the builtin "default.css".
137 | html_static_path = ["_static"]
138 | 
139 | # -- Options for LaTeX output ---------------------------------------------
140 | latex_elements = {}
141 | 
142 | 
143 | # -- Options for manual page output ---------------------------------------
144 | 
145 | # One entry per manual page. List of tuples
146 | # (source start file, name, description, authors, manual section).
147 | man_pages = [(master_doc, "tensordict", "tensordict Documentation", [author], 1)]
148 | 
149 | 
150 | # -- Options for Texinfo output -------------------------------------------
151 | 
152 | # Grouping the document tree into Texinfo files. List of tuples
153 | # (source start file, target name, title, author,
154 | #  dir menu entry, description, category)
155 | texinfo_documents = [
156 |     (
157 |         master_doc,
158 |         "tensordict",
159 |         "tensordict Documentation",
160 |         author,
161 |         "tensordict",
162 |         "TensorDict doc.",
163 |         "Miscellaneous",
164 |     ),
165 | ]
166 | 
167 | 
168 | # Example configuration for intersphinx: refer to the Python standard library.
169 | intersphinx_mapping = {
170 |     "python": ("https://docs.python.org/3/", None),
171 |     "torch": ("https://pytorch.org/docs/stable/", None),
172 |     "numpy": ("https://numpy.org/doc/stable/", None),
173 | }
174 | 
175 | 
176 | aafig_default_options = {"scale": 1.5, "aspect": 1.0, "proportional": True}
177 | 
178 | current_path = os.path.dirname(os.path.realpath(__file__))
179 | sys.path.append(current_path)
180 | from content_generation import generate_tutorial_references
181 | 
182 | generate_tutorial_references("../../tutorials/sphinx_tuto", "tutorial")
183 | generate_tutorial_references("../../tutorials/src/", "src")
184 | generate_tutorial_references("../../tutorials/media/", "media")
--------------------------------------------------------------------------------
/docs/source/content_generation.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import shutil
 3 | from pathlib import Path
 4 | from typing import List
 5 | 
 6 | FILE_DIR = os.path.dirname(__file__)
 7 | KNOWLEDGE_GEN_DIR = "reference/generated/knowledge_base"
 8 | TUTORIALS_GEN_DIR = "reference/generated/tutorials"
 9 | TUTORIALS_SRC_GEN_DIR = "reference/generated/tutorials/src"
10 | TUTORIALS_MEDIA_GEN_DIR = "reference/generated/tutorials/media"
11 | 
12 | 
13 | def _get_file_content(name: str) -> List[str]:
14 |     """Return the content of a reference file.
15 | 
16 |     Given the name of a knowledge base file, populates a file template. The result can be used to link a knowledge base
17 |     entry to the Sphinx docs.
18 | 
19 |     Args:
20 |         name (str): name of the file to be referenced (without extension).
21 | 
22 |     Returns: List of strings
23 | 
24 |     """
25 |     return [
26 |         "..\n",
27 |         "    This file is generated by knowledge_base.py, manual changes will be overwritten.\n",
28 |         "\n",
29 |         f".. include:: ../../../../../knowledge_base/{name}.md\n",
30 |         "    :parser: myst_parser.sphinx_\n",
31 |         "\n",
32 |     ]
33 | 
34 | 
35 | def generate_knowledge_base_references(knowledge_base_path: str) -> None:
36 |     """Creates a reference file per knowledge base entry.
37 | 38 | Sphinx doesn't natively support adding files from outside its root directory. To include the knowledge base in 39 | our docs (https://pytorch.org/rl/), each entry is linked using an auto-generated file that references the original. 40 | 41 | Args: 42 | knowledge_base_path (str): path to the knowledge base. 43 | """ 44 | # Create target dir 45 | target_path = os.path.join(FILE_DIR, KNOWLEDGE_GEN_DIR) 46 | Path(target_path).mkdir(parents=True, exist_ok=True) 47 | 48 | # Iterate knowledge base files 49 | file_paths = os.listdir(os.path.join(FILE_DIR, knowledge_base_path)) 50 | for file_path in file_paths: 51 | name = Path(file_path).stem 52 | 53 | # Skip README, it is already included in `knowledge_base.rst` 54 | if name == "README": 55 | continue 56 | 57 | # Existing files will be overwritten in 'w' mode 58 | with open(os.path.join(target_path, f"{name}.rst"), "w") as file: 59 | file.writelines(_get_file_content(name)) 60 | 61 | 62 | def generate_tutorial_references(tutorial_path: str, file_type: str) -> None: 63 | """Creates a local copy of each tutorial file (.py, .rst and .png). 64 | 65 | Sphinx doesn't natively support adding files from outside its root directory. To include the tutorials in 66 | our docs (https://pytorch.org/rl/), each entry is locally copied. 67 | 68 | Args: 69 | tutorial_path (str): path to the tutorial scripts. file_type (str): one of "tutorial", "src" or "media"; selects the target directory. 70 | """ 71 | # Create target dir 72 | if file_type == "tutorial": 73 | target_path = os.path.join(FILE_DIR, TUTORIALS_GEN_DIR) 74 | elif file_type == "src": 75 | target_path = os.path.join(FILE_DIR, TUTORIALS_SRC_GEN_DIR) 76 | else: 77 | target_path = os.path.join(FILE_DIR, TUTORIALS_MEDIA_GEN_DIR) 78 | Path(target_path).mkdir(parents=True, exist_ok=True) 79 | 80 | # Iterate tutorial files and copy 81 | file_paths = [ 82 | os.path.join(tutorial_path, f) 83 | for f in os.listdir(tutorial_path) 84 | if f.endswith((".py", ".rst", ".png")) 85 | ] 86 | 87 | for file_path in file_paths: 88 | shutil.copyfile(file_path, os.path.join(target_path, Path(file_path).name)) 89 | -------------------------------------------------------------------------------- /docs/source/distributed.rst: -------------------------------------------------------------------------------- 1 | .. _distributed: 2 | 3 | TensorDict in distributed settings 4 | ================================== 5 | 6 | TensorDict can be used in distributed settings to pass tensors from one node 7 | to another. 8 | If two nodes have access to a shared physical storage, a memory-mapped tensor can 9 | be used to efficiently pass data from one running process to another. 10 | Here, we provide some details on how this can be achieved in a distributed RPC setting. 11 | For more details on distributed RPC, check the 12 | `official pytorch documentation <https://pytorch.org/docs/stable/rpc.html>`_. 13 | 14 | Creating a memory-mapped TensorDict 15 | ----------------------------------- 16 | 17 | Memory-mapped tensors (and arrays) have the great advantage that they can store 18 | large amounts of data and allow slices of data to be accessed readily without 19 | reading the whole file into memory. 20 | TensorDict offers an interface between memory-mapped 21 | arrays and the :obj:`torch.Tensor` class named :obj:`MemmapTensor`. 22 | :obj:`MemmapTensor` instances can be stored in :obj:`TensorDict` objects, allowing a 23 | tensordict to represent a big dataset, stored on disk, easily accessible in a 24 | batched way across nodes.
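As a minimal sketch of the idea (assuming ``TensorDict`` and ``MemmapTensor`` have been imported from ``tensordict``), indexing a memory-mapped tensordict reads only the requested slice from disk:

>>> td = TensorDict({"data": MemmapTensor(1000, 64, prefix="/tmp")}, batch_size=[1000])
>>> chunk = td[100:200]  # only these 100 rows are loaded in memory

The sections below detail how such objects are created and shared across nodes.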
25 | 26 | A memory-mapped tensordict is created either by (1) populating a TensorDict with 27 | memory-mapped tensors or (2) calling :obj:`tensordict.memmap_()` to put it on 28 | physical storage. 29 | One can easily check that a tensordict is stored on physical storage by querying 30 | `tensordict.is_memmap()`. 31 | 32 | Creating a memory-mapped tensor can itself be done in several ways. 33 | Firstly, one can simply create an empty tensor: 34 | 35 | >>> shape = torch.Size([3, 4, 5]) 36 | >>> tensor = MemmapTensor(*shape, prefix="/tmp") 37 | >>> tensor[:2] = torch.randn(2, 4, 5) 38 | 39 | The :obj:`prefix` keyword argument indicates where the temporary file should be stored. 40 | It is crucial that the tensor is stored in a directory that is accessible to every 41 | node! 42 | 43 | Another option is to represent an existing tensor on disk: 44 | 45 | >>> tensor = torch.randn(3) 46 | >>> tensor = MemmapTensor(tensor, prefix="/tmp") 47 | 48 | The former method is preferred when tensors are big or do not fit in memory: 49 | it is suitable for tensors that are extremely big and serve as common storage 50 | across nodes. For instance, one could create a dataset that is easily accessed 51 | by one or several nodes, much faster than it would be if each file had to be 52 | loaded independently in memory: 53 | 54 | .. code-block:: 55 | :caption: Creating an empty dataset on disk 56 | 57 | >>> dataset = TensorDict({ 58 | ... "images": MemmapTensor(50000, 480, 480, 3), 59 | ... "masks": MemmapTensor(50000, 480, 480, 3, dtype=torch.bool), 60 | ... "labels": MemmapTensor(50000, 1, dtype=torch.uint8), 61 | ... }, batch_size=[50000], device="cpu") 62 | >>> idx = [1, 5020, 34572, 11200] 63 | >>> batch = dataset[idx].clone() 64 | TensorDict( 65 | fields={ 66 | images: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.float32), 67 | labels: Tensor(torch.Size([4, 1]), dtype=torch.uint8), 68 | masks: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.bool)}, 69 | batch_size=torch.Size([4]), 70 | device=cpu, 71 | is_shared=False) 72 | 73 | Notice that we have indicated the device of the :obj:`MemmapTensor`. 74 | This syntactic sugar allows the tensors that are queried to be directly loaded 75 | on device if needed. 76 | 77 | Note also that :obj:`MemmapTensor` is currently 78 | not compatible with autograd operations. 79 | 80 | Operating on Memory-mapped tensors across nodes 81 | ----------------------------------------------- 82 | 83 | We provide a simple example of a distributed script where one process creates a 84 | memory-mapped tensor, and sends its reference to another worker that is responsible for 85 | updating it. You will find this example in the 86 | `benchmark directory `_. 87 | 88 | In short, our goal is to show how to handle read and write operations on big 89 | tensors when nodes have access to a shared physical storage. The steps involve: 90 | 91 | - Creating the empty tensor on disk; 92 | 93 | - Setting the local and remote operations to be executed; 94 | 95 | - Passing commands from worker to worker using RPC to read and write the 96 | shared data. 97 | 98 | This example first writes a function that updates a TensorDict instance 99 | at specific indices with a one-filled tensor: 100 | 101 | >>> def fill_tensordict(tensordict, idx): 102 | ... tensordict[idx] = TensorDict( 103 | ... {"memmap": torch.ones(5, 640, 640, 3, dtype=torch.uint8)}, [5] 104 | ... ) 105 | ... 
return tensordict 106 | >>> fill_tensordict_cp = CloudpickleWrapper(fill_tensordict) 107 | 108 | The :obj:`CloudpickleWrapper` ensures that the function is serializable. 109 | Next, we create a tensordict of considerable size, to make the point that 110 | it would be hard to pass from worker to worker through 111 | a regular tensorpipe: 112 | 113 | >>> tensordict = TensorDict( 114 | ... {"memmap": MemmapTensor(1000, 640, 640, 3, dtype=torch.uint8, prefix="/tmp/")}, [1000] 115 | ... ) 116 | 117 | Finally, still on the main node, we call the function *on the remote node* and then 118 | check that the data has been written where needed: 119 | 120 | >>> idx = [4, 5, 6, 7, 998] 121 | >>> t0 = time.time() 122 | >>> out = rpc.rpc_sync( 123 | ... worker_info, 124 | ... fill_tensordict_cp, 125 | ... args=(tensordict, idx), 126 | ... ) 127 | >>> print("time elapsed:", time.time() - t0) 128 | >>> print("check all ones", out["memmap"][idx, :1, :1, :1].clone()) 129 | 130 | Although the call to :obj:`rpc.rpc_sync` involved passing the entire tensordict, 131 | updating specific indices of this object and returning it to the original worker, 132 | the execution of this snippet is extremely fast (even more so if the reference 133 | to the memory location is already passed beforehand, see `torchrl's distributed 134 | replay buffer documentation `_ to learn more). 135 | 136 | The script contains additional RPC configuration steps that are beyond the 137 | scope of this document. 138 | -------------------------------------------------------------------------------- /docs/source/docutils.conf: -------------------------------------------------------------------------------- 1 | [html writers] 2 | table_style: colwidths-auto  # Necessary for the table generated by autosummary to look decent 3 | -------------------------------------------------------------------------------- /docs/source/fx.rst: -------------------------------------------------------------------------------- 1 | Tracing TensorDictModule 2 | ======================== 3 | 4 | We support tracing execution of :obj:`TensorDictModule` to create FX graphs. Simply import :obj:`symbolic_trace` from ``tensordict.prototype.fx`` instead of ``torch.fx``. 5 | 6 | .. note:: Support for ``torch.fx`` is highly experimental and subject to change. Use with caution, and raise an issue if you try it out and encounter problems. 7 | 8 | Tracing a :obj:`TensorDictModule` 9 | --------------------------------- 10 | 11 | We'll illustrate with an example from the overview. We create a :obj:`TensorDictModule`, trace it, and inspect the graph and generated code. 12 | 13 | .. code-block:: 14 | :caption: Tracing a TensorDictModule 15 | 16 | >>> import torch 17 | >>> import torch.nn as nn 18 | >>> from tensordict import TensorDict 19 | >>> from tensordict.nn import TensorDictModule 20 | >>> from tensordict.prototype.fx import symbolic_trace 21 | 22 | >>> class Net(nn.Module): 23 | ... def __init__(self): 24 | ... super().__init__() 25 | ... self.linear = nn.LazyLinear(1) 26 | ... 27 | ... def forward(self, x): 28 | ... logits = self.linear(x) 29 | ... return logits, torch.sigmoid(logits) 30 | >>> module = TensorDictModule( 31 | ... Net(), 32 | ... in_keys=["input"], 33 | ... out_keys=[("outputs", "logits"), ("outputs", "probabilities")], 34 | ... 
) 35 | >>> graph_module = symbolic_trace(module) 36 | >>> print(graph_module.graph) 37 | graph(): 38 | %tensordict : [#users=1] = placeholder[target=tensordict] 39 | %getitem : [#users=1] = call_function[target=operator.getitem](args = (%tensordict, input), kwargs = {}) 40 | %linear : [#users=2] = call_module[target=linear](args = (%getitem,), kwargs = {}) 41 | %sigmoid : [#users=1] = call_function[target=torch.sigmoid](args = (%linear,), kwargs = {}) 42 | return (linear, sigmoid) 43 | >>> print(graph_module.code) 44 | 45 | def forward(self, tensordict): 46 | getitem = tensordict['input']; tensordict = None 47 | linear = self.linear(getitem); getitem = None 48 | sigmoid = torch.sigmoid(linear) 49 | return (linear, sigmoid) 50 | 51 | We can check that a forward pass with each module results in the same outputs. 52 | 53 | >>> tensordict = TensorDict({"input": torch.randn(32, 100)}, [32]) 54 | >>> module_out = module(tensordict, tensordict_out=TensorDict()) 55 | >>> graph_module_out = graph_module(tensordict, tensordict_out=TensorDict()) 56 | >>> assert ( 57 | ... module_out["outputs", "logits"] == graph_module_out["outputs", "logits"] 58 | ... ).all() 59 | >>> assert ( 60 | ... module_out["outputs", "probabilities"] 61 | ... == graph_module_out["outputs", "probabilities"] 62 | ... ).all() 63 | 64 | Tracing a :obj:`TensorDictSequential` 65 | ------------------------------------- 66 | 67 | We can also trace :obj:`TensorDictSequential`. In this case the entire execution of the module is traced into a single graph, eliminating intermediate reads and writes on the input :obj:`TensorDict`. 68 | 69 | We demonstrate by tracing the sequential example from the overview. 70 | 71 | .. code-block:: 72 | :caption: Tracing TensorDictSequential 73 | 74 | >>> import torch 75 | >>> import torch.nn as nn 76 | >>> from tensordict import TensorDict 77 | >>> from tensordict.nn import TensorDictModule, TensorDictSequential 78 | >>> from tensordict.prototype.fx import symbolic_trace 79 | 80 | >>> class Net(nn.Module): 81 | ... def __init__(self, input_size=100, hidden_size=50, output_size=10): 82 | ... super().__init__() 83 | ... self.fc1 = nn.Linear(input_size, hidden_size) 84 | ... self.fc2 = nn.Linear(hidden_size, output_size) 85 | ... 86 | ... def forward(self, x): 87 | ... x = torch.relu(self.fc1(x)) 88 | ... return self.fc2(x) 89 | ... 90 | ... class Masker(nn.Module): 91 | ... def forward(self, x, mask): 92 | ... return torch.softmax(x * mask, dim=1) 93 | >>> net = TensorDictModule( 94 | ... Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")] 95 | ... ) 96 | >>> masker = TensorDictModule( 97 | ... Masker(), 98 | ... in_keys=[("intermediate", "x"), ("input", "mask")], 99 | ... out_keys=[("output", "probabilities")], 100 | ... ) 101 | >>> module = TensorDictSequential(net, masker) 102 | >>> graph_module = symbolic_trace(module) 103 | >>> print(graph_module.code) 104 | 105 | def forward(self, tensordict): 106 | getitem = tensordict[('input', 'x')] 107 | _0_fc1 = getattr(self, "0").module.fc1(getitem); getitem = None 108 | relu = torch.relu(_0_fc1); _0_fc1 = None 109 | _0_fc2 = getattr(self, "0").module.fc2(relu); relu = None 110 | getitem_1 = tensordict[('input', 'mask')]; tensordict = None 111 | mul = _0_fc2 * getitem_1; getitem_1 = None 112 | softmax = torch.softmax(mul, dim = 1); mul = None 113 | return (_0_fc2, softmax) 114 | 115 | In this case the generated graph and code are a bit more complicated. We can visualize it as follows (requires ``pydot``): 116 | 117 | .. 
code-block:: 118 | :caption: Visualising the graph 119 | 120 | >>> from torch.fx.passes.graph_drawer import FxGraphDrawer 121 | >>> g = FxGraphDrawer(graph_module, "sequential") 122 | >>> with open("graph.svg", "wb") as f: 123 | ... f.write(g.get_dot_graph().create_svg()) 124 | 125 | This results in the following visualisation: 126 | 127 | .. image:: _static/img/graph.svg 128 | :alt: Visualization of the traced graph. 129 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. tensordict documentation master file, created by 2 | sphinx-quickstart on Mon Mar 7 13:23:20 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to the TensorDict Documentation! 7 | ======================================== 8 | 9 | `TensorDict` is a dictionary-like class that inherits properties from tensors, 10 | such as indexing, shape operations, casting to device etc. 11 | 12 | You can install tensordict directly from PyPI (see the installation 13 | instructions in the dedicated section below): 14 | 15 | .. code-block:: 16 | 17 | $ pip install tensordict 18 | 19 | 20 | The main purpose of TensorDict is to make code-bases more *readable* and *modular* 21 | by abstracting away tailored operations: 22 | 23 | >>> for i, tensordict in enumerate(dataset): 24 | ... # the model reads and writes tensordicts 25 | ... tensordict = model(tensordict) 26 | ... loss = loss_module(tensordict) 27 | ... loss.backward() 28 | ... optimizer.step() 29 | ... optimizer.zero_grad() 30 | 31 | With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. 32 | Each individual step of the training loop (data collection and transform, model 33 | prediction, loss computation etc.) 34 | can be tailored to the use case at hand without impacting the others. 35 | For instance, the above example can be easily used across classification and segmentation tasks, among many others. 36 | 37 | 38 | Installation 39 | ============ 40 | 41 | Tensordict releases are synced with PyTorch, so make sure you always enjoy the latest 42 | features of the library with the `most recent version of PyTorch <https://pytorch.org/get-started/locally/>`__ (although core features 43 | are guaranteed to be backward compatible with pytorch>=1.13). 44 | Nightly releases can be installed via 45 | 46 | .. code-block:: 47 | 48 | $ pip install tensordict-nightly 49 | 50 | or via a `git clone` if you're willing to contribute to the library: 51 | 52 | .. code-block:: 53 | 54 | $ cd path/to/root 55 | $ git clone https://github.com/pytorch/tensordict 56 | $ cd tensordict 57 | $ python setup.py develop 58 | 59 | Tutorials 60 | ========= 61 | 62 | Basics 63 | ------ 64 | 65 | .. toctree:: 66 | :maxdepth: 1 67 | 68 | tutorials/tensordict_shapes 69 | tutorials/tensordict_slicing 70 | tutorials/tensordict_keys 71 | tutorials/tensordict_preallocation 72 | tutorials/tensordict_memory 73 | tutorials/streamed_tensordict 74 | 75 | tensordict.nn 76 | ------------- 77 | 78 | .. toctree:: 79 | :maxdepth: 1 80 | 81 | tutorials/tensordict_module 82 | tutorials/export 83 | 84 | Dataloading 85 | ----------- 86 | 87 | .. toctree:: 88 | :maxdepth: 1 89 | 90 | tutorials/data_fashion 91 | tutorials/tensorclass_fashion 92 | tutorials/tensorclass_imagenet 93 | 94 | Contents 95 | ======== 96 | 97 | .. 
toctree:: 98 | :maxdepth: 3 99 | 100 | overview 101 | distributed 102 | fx 103 | saving 104 | reference/index 105 | 106 | Indices and tables 107 | ================== 108 | 109 | * :ref:`genindex` 110 | * :ref:`modindex` 111 | * :ref:`search` 112 | -------------------------------------------------------------------------------- /docs/source/reference/generated/tutorials/README.rst: -------------------------------------------------------------------------------- 1 | README Tutos 2 | ============ 3 | -------------------------------------------------------------------------------- /docs/source/reference/index.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | tensordict 8 | nn 9 | tensorclass 10 | -------------------------------------------------------------------------------- /docs/tensordict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/tensordict/73fe89bc067b40219dc0d1245b655bcec85464f3/docs/tensordict.png -------------------------------------------------------------------------------- /gallery/README.rst: -------------------------------------------------------------------------------- 1 | Example gallery 2 | =============== 3 | 4 | Below is a gallery of examples 5 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | 3 | files = tensordict 4 | show_error_codes = True 5 | pretty = True 6 | allow_redefinition = True 7 | warn_redundant_casts = True 8 | 9 | [mypy-torchvision.*] 10 | 11 | ignore_errors = True 12 | ignore_missing_imports = True 13 | 14 | [mypy-numpy.*] 15 | 16 | ignore_missing_imports = True 17 | 18 | [mypy-scipy.*] 19 | 20 | ignore_missing_imports = True 21 | 22 | [mypy-pycocotools.*] 23 | 24 | ignore_missing_imports = True 25 | 26 | [mypy-lmdb.*] 27 | 28 | ignore_missing_imports = True 29 | 30 | [mypy-tqdm.*] 31 | 32 | ignore_missing_imports = True 33 | 34 | [mypy-moviepy.*] 35 | 36 | ignore_missing_imports = True 37 | 38 | [mypy-dm_control.*] 39 | 40 | ignore_missing_imports = True 41 | 42 | [mypy-dm_env.*] 43 | 44 | ignore_missing_imports = True 45 | 46 | [mypy-retro.*] 47 | 48 | ignore_missing_imports = True 49 | 50 | [mypy-gym.*] 51 | 52 | ignore_missing_imports = True 53 | -------------------------------------------------------------------------------- /packaging/build_wheels.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 5 | . "$script_dir/pkg_helpers.bash" 6 | 7 | export BUILD_TYPE=wheel 8 | setup_env 9 | setup_wheel_python 10 | pip_install numpy pyyaml future ninja 11 | pip_install --upgrade setuptools 12 | setup_pip_pytorch_version 13 | python setup.py clean 14 | 15 | # Copy binaries to be included in the wheel distribution 16 | if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then 17 | python_exec="$(which python)" 18 | bin_path=$(dirname $python_exec) 19 | else 20 | # Point to custom libraries 21 | export LD_LIBRARY_PATH=$(pwd)/ext_libraries/lib:$LD_LIBRARY_PATH 22 | fi 23 | 24 | if [[ "$OSTYPE" == "msys" ]]; then 25 | echo "ERROR: Windows installation is not supported yet." 
&& exit 100 26 | else 27 | python setup.py bdist_wheel 28 | if [[ "$(uname)" != Darwin ]]; then 29 | rename "linux_x86_64" "manylinux1_x86_64" dist/*.whl 30 | fi 31 | fi 32 | -------------------------------------------------------------------------------- /packaging/wheel/relocate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """Helper script to package wheels and relocate binaries.""" 7 | 8 | import glob 9 | import hashlib 10 | 11 | # Standard library imports 12 | import os 13 | import os.path as osp 14 | import shutil 15 | import sys 16 | import zipfile 17 | from base64 import urlsafe_b64encode 18 | 19 | HERE = osp.dirname(osp.abspath(__file__)) 20 | PACKAGE_ROOT = osp.dirname(osp.dirname(HERE)) 21 | 22 | 23 | def rehash(path, blocksize=1 << 20): 24 | """Return (hash, length) for path using hashlib.sha256()""" 25 | h = hashlib.sha256() 26 | length = 0 27 | with open(path, "rb") as f: 28 | while block := f.read(blocksize): 29 | length += len(block) 30 | h.update(block) 31 | digest = "sha256=" + urlsafe_b64encode(h.digest()).decode("latin1").rstrip("=") 32 | # unicode/str python2 issues 33 | return (digest, str(length)) # type: ignore 34 | 35 | 36 | def unzip_file(file, dest): 37 | """Decompress zip `file` into directory `dest`.""" 38 | with zipfile.ZipFile(file, "r") as zip_ref: 39 | zip_ref.extractall(dest) 40 | 41 | 42 | def is_program_installed(basename): 43 | """ 44 | Return program absolute path if installed in PATH. 45 | Otherwise, return None 46 | On macOS systems, a .app is considered installed if 47 | it exists. 
48 | """ 49 | if sys.platform == "darwin" and basename.endswith(".app") and osp.exists(basename): 50 | return basename 51 | 52 | for path in os.environ["PATH"].split(os.pathsep): 53 | abspath = osp.join(path, basename) 54 | if osp.isfile(abspath): 55 | return abspath 56 | 57 | 58 | def find_program(basename): 59 | """ 60 | Find program in PATH and return absolute path 61 | Try adding .exe or .bat to basename on Windows platforms 62 | (return None if not found) 63 | """ 64 | names = [basename] 65 | if os.name == "nt": 66 | # Windows platforms 67 | extensions = (".exe", ".bat", ".cmd", ".dll") 68 | if not basename.endswith(extensions): 69 | names = [basename + ext for ext in extensions] + [basename] 70 | for name in names: 71 | path = is_program_installed(name) 72 | if path: 73 | return path 74 | 75 | 76 | def compress_wheel(output_dir, wheel, wheel_dir, wheel_name): 77 | """Create RECORD file and compress wheel distribution.""" 78 | # ("Update RECORD file in wheel") 79 | dist_info = glob.glob(osp.join(output_dir, "*.dist-info"))[0] 80 | record_file = osp.join(dist_info, "RECORD") 81 | 82 | with open(record_file, "w") as f: 83 | for root, _, files in os.walk(output_dir): 84 | for this_file in files: 85 | full_file = osp.join(root, this_file) 86 | rel_file = osp.relpath(full_file, output_dir) 87 | if full_file == record_file: 88 | f.write(f"{rel_file},,\n") 89 | else: 90 | digest, size = rehash(full_file) 91 | f.write(f"{rel_file},{digest},{size}\n") 92 | 93 | # ("Compressing wheel") 94 | base_wheel_name = osp.join(wheel_dir, wheel_name) 95 | shutil.make_archive(base_wheel_name, "zip", output_dir) 96 | os.remove(wheel) 97 | shutil.move(f"{base_wheel_name}.zip", wheel) 98 | shutil.rmtree(output_dir) 99 | 100 | 101 | def patch_win(): 102 | # # Get dumpbin location 103 | # dumpbin = find_program("dumpbin") 104 | # if dumpbin is None: 105 | # raise FileNotFoundError( 106 | # "Dumpbin was not found in the system, please make sure that is available on the PATH." 
107 | # ) 108 | 109 | # Find wheel 110 | # ("Finding wheels...") 111 | wheels = glob.glob(osp.join(PACKAGE_ROOT, "dist", "*.whl")) 112 | if not wheels: 113 | raise FileNotFoundError( 114 | "Did not find any wheels in {}".format(osp.join(PACKAGE_ROOT, "dist")) 115 | ) 116 | output_dir = osp.join(PACKAGE_ROOT, "dist", ".wheel-process") 117 | 118 | for wheel in wheels: 119 | print(f"processing {wheel}") 120 | if osp.exists(output_dir): 121 | shutil.rmtree(output_dir) 122 | print(f"creating output directory {output_dir}") 123 | os.makedirs(output_dir) 124 | 125 | # ("Unzipping wheel...") 126 | wheel_file = osp.basename(wheel) 127 | wheel_dir = osp.dirname(wheel) 128 | # (f"{wheel_file}") 129 | wheel_name, _ = osp.splitext(wheel_file) 130 | print(f"unzipping {wheel} in {output_dir}") 131 | unzip_file(wheel, output_dir) 132 | print("compressing wheel") 133 | compress_wheel(output_dir, wheel, wheel_dir, wheel_name) 134 | 135 | 136 | if __name__ == "__main__": 137 | if sys.platform == "linux": 138 | pass 139 | elif sys.platform == "win32": 140 | patch_win() 141 | else: 142 | raise NotImplementedError 143 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "pybind11", "setuptools_scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.usort] 6 | first_party_detection = false 7 | target-version = ["py39"] 8 | excludes = [ 9 | "gallery", 10 | "tutorials", 11 | ] 12 | 13 | [tool.black] 14 | line-length = 88 15 | target-version = ["py39"] 16 | 17 | [project] 18 | name = "tensordict" 19 | version = "0.9.0" 20 | description = "TensorDict is a pytorch dedicated tensor container." 
21 | authors = [ 22 | { name="Vincent Moens", email="vincentmoens@gmail.com" } 23 | ] 24 | readme = "README.md" 25 | license = { text = "BSD" } 26 | requires-python = ">=3.9" 27 | classifiers = [ 28 | "Programming Language :: Python :: 3.9", 29 | "Programming Language :: Python :: 3.10", 30 | "Programming Language :: Python :: 3.11", 31 | "Programming Language :: Python :: 3.12", 32 | "Programming Language :: Python :: 3.13", 33 | "Development Status :: 4 - Beta" 34 | ] 35 | dependencies = [ 36 | "torch", 37 | "numpy", 38 | "cloudpickle", 39 | "packaging", 40 | "importlib_metadata", 41 | # orjson fails to be installed in python 3.13t 42 | 'orjson ; python_version < "3.13"', 43 | ] 44 | 45 | [project.urls] 46 | homepage = "https://github.com/pytorch/tensordict" 47 | 48 | [project.optional-dependencies] 49 | tests = [ 50 | "pytest", 51 | "pyyaml", 52 | "pytest-instafail", 53 | "pytest-rerunfailures", 54 | "pytest-benchmark" 55 | ] 56 | checkpointing = ["torchsnapshot-nightly"] 57 | h5 = ["h5py>=3.8"] 58 | dev = ["pybind11", "cmake", "ninja"] 59 | 60 | [tool.setuptools] 61 | include-package-data = false 62 | 63 | [tool.setuptools.packages.find] 64 | exclude = ["test*", "tutorials*", "packaging*", "gallery*", "docs*", "benchmarks*"] 65 | 66 | #[tool.setuptools.extension] 67 | #my_extension = { sources = ["tensordict/csrc/pybind.cpp", "tensordict/csrc/utils.cpp"] } 68 | 69 | [tool.setuptools.package-data] 70 | "tensordict" = ["*.so", "*.pyd", "*.dll"] 71 | 72 | [tool.setuptools_scm] 73 | version_scheme = "post-release" 74 | write_to = "tensordict/_version.py" 75 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = 3 | # show summary of all tests that did not pass 4 | -ra 5 | # Make tracebacks shorter 6 | --tb=native 7 | testpaths = 8 | test 9 | xfail_strict = True 10 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_file = LICENSE 3 | 4 | [pep8] 5 | max-line-length = 120 6 | 7 | [flake8] 8 | # note: we ignore all 501s (line too long) anyway as they're taken care of by black 9 | max-line-length = 79 10 | ignore = E203, E402, W503, W504, E501, E701, E704 11 | per-file-ignores = 12 | __init__.py: F401, F403, F405 13 | ./hubconf.py: F401 14 | test/smoke_test.py: F401 15 | test/smoke_test_deps.py: F401 16 | test_*.py: E731, E266, TOR101 17 | tutorials/*/**.py: T201 18 | packaging/*/**.py: T201 19 | exclude = venv 20 | extend-select = B901, C401, C408, C409 21 | 22 | [pydocstyle] 23 | ;select = D417 # Missing argument descriptions in the docstring 24 | ;inherit = false 25 | match = .*\.py 26 | ;match_dir = ^(?!(.circlecli|test)).* 27 | convention = google 28 | add-ignore = D100, D104, D105, D107, D102 29 | ignore-decorators = 30 | test_* 31 | ; test/*.py 32 | ; .github/* 33 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import distutils.command.clean 7 | import logging 8 | import os 9 | import shutil 10 | import subprocess 11 | import sys 12 | from pathlib import Path 13 | 14 | from setuptools import Extension, find_packages, setup 15 | from setuptools.command.build_ext import build_ext 16 | 17 | ROOT_DIR = Path(__file__).parent.resolve() 18 | 19 | 20 | def get_python_executable(): 21 | # Check if we're running in a virtual environment 22 | if "VIRTUAL_ENV" in os.environ: 23 | # Get the virtual environment's Python executable 24 | python_executable = os.path.join(os.environ["VIRTUAL_ENV"], "bin", "python") 25 | else: 26 | # Fall back to sys.executable 27 | python_executable = sys.executable 28 | return python_executable 29 | 30 | 31 | class clean(distutils.command.clean.clean): 32 | def run(self): 33 | # Run default behavior first 34 | distutils.command.clean.clean.run(self) 35 | 36 | # Remove tensordict extension 37 | for path in (ROOT_DIR / "tensordict").glob("**/*.so"): 38 | logging.info(f"removing '{path}'") 39 | path.unlink() 40 | # Remove build directory 41 | build_dirs = [ROOT_DIR / "build"] 42 | for path in build_dirs: 43 | if path.exists(): 44 | logging.info(f"removing '{path}' (and everything under it)") 45 | shutil.rmtree(str(path), ignore_errors=True) 46 | 47 | 48 | class CMakeExtension(Extension): 49 | def __init__(self, name, sourcedir=""): 50 | super().__init__(name, sources=[]) 51 | self.sourcedir = os.path.abspath(sourcedir) 52 | 53 | 54 | class CMakeBuild(build_ext): 55 | def run(self): 56 | for ext in self.extensions: 57 | self.build_extension(ext) 58 | 59 | def build_extension(self, ext): 60 | is_editable = self.inplace 61 | if is_editable: 62 | # For editable installs, place the extension in the source directory 63 | extdir = os.path.abspath(os.path.join(ROOT_DIR, "tensordict")) 64 | else: 65 | # For regular installs, place the extension in the build directory 66 | extdir = os.path.abspath(os.path.join(self.build_lib, "tensordict")) 67 | cmake_args = [ 68 | f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", 69 | f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY={extdir}", 70 | f"-DPYTHON_EXECUTABLE={get_python_executable()}", 71 | f"-DPython3_EXECUTABLE={get_python_executable()}", 72 | # for windows 73 | "-DCMAKE_BUILD_TYPE=Release", 74 | ] 75 | 76 | build_args = [] 77 | if not os.path.exists(self.build_temp): 78 | os.makedirs(self.build_temp) 79 | if sys.platform == "win32": 80 | build_args += ["--config", "Release"] 81 | subprocess.check_call( 82 | ["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp 83 | ) 84 | subprocess.check_call( 85 | ["cmake", "--build", ".", "--verbose"] + build_args, cwd=self.build_temp 86 | ) 87 | 88 | 89 | def get_extensions(): 90 | extensions_dir = os.path.join(ROOT_DIR, "tensordict", "csrc") 91 | return [CMakeExtension("tensordict._C", sourcedir=extensions_dir)] 92 | 93 | 94 | def version(): 95 | return { 96 | "write_to": "tensordict/_version.py", # Specify the path where the version file should be written 97 | } 98 | 99 | 100 | setup( 101 | ext_modules=get_extensions(), 102 | cmdclass={ 103 | "build_ext": CMakeBuild, 104 | "clean": clean, 105 | }, 106 | packages=find_packages( 107 | exclude=("test", "tutorials", "packaging", "gallery", "docs") 108 | ), 109 | setup_requires=["setuptools_scm"], 110 | use_scm_version=version(), 111 | ) 112 | -------------------------------------------------------------------------------- /tensordict/_C/__init__.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta 
Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | def unravel_key_list(list_of_keys): ... 7 | def unravel_keys(keys): ... 8 | def unravel_key(key): ... 9 | def _unravel_key_to_tuple(key): ... 10 | -------------------------------------------------------------------------------- /tensordict/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import tensordict._reductions 7 | from tensordict._lazy import LazyStackedTensorDict 8 | from tensordict._nestedkey import NestedKey 9 | from tensordict._td import ( 10 | cat, 11 | from_consolidated, 12 | from_module, 13 | from_modules, 14 | from_pytree, 15 | fromkeys, 16 | is_tensor_collection, 17 | lazy_stack, 18 | load, 19 | load_memmap, 20 | maybe_dense_stack, 21 | memmap, 22 | save, 23 | stack, 24 | TensorDict, 25 | ) 26 | from tensordict._unbatched import UnbatchedTensor 27 | 28 | from tensordict.base import ( 29 | _default_is_leaf as default_is_leaf, 30 | _is_leaf_nontensor as is_leaf_nontensor, 31 | from_any, 32 | from_dict, 33 | from_h5, 34 | from_namedtuple, 35 | from_struct_array, 36 | from_tuple, 37 | get_defaults_to_none, 38 | set_get_defaults_to_none, 39 | TensorDictBase, 40 | ) 41 | from tensordict.functional import ( 42 | dense_stack_tds, 43 | make_tensordict, 44 | merge_tensordicts, 45 | pad, 46 | pad_sequence, 47 | ) 48 | from tensordict.memmap import MemoryMappedTensor 49 | from tensordict.persistent import PersistentTensorDict 50 | from tensordict.tensorclass import ( 51 | from_dataclass, 52 | MetaData, 53 | NonTensorData, 54 | NonTensorDataBase, 55 | NonTensorStack, 56 | tensorclass, 57 | TensorClass, 58 | ) 59 | from tensordict.utils import ( 60 | assert_allclose_td, 61 | assert_close, 62 | capture_non_tensor_stack, 63 | is_batchedtensor, 64 | is_non_tensor, 65 | is_tensorclass, 66 | lazy_legacy, 67 | list_to_stack, 68 | parse_tensor_dict_string, 69 | set_capture_non_tensor_stack, 70 | set_lazy_legacy, 71 | set_list_to_stack, 72 | unravel_key, 73 | unravel_key_list, 74 | ) 75 | from tensordict._pytree import * 76 | from tensordict.nn import as_tensordict_module, TensorDictParams 77 | 78 | try: 79 | from tensordict._version import __version__ # @manual=//pytorch/tensordict:version 80 | except ImportError: 81 | __version__ = None 82 | -------------------------------------------------------------------------------- /tensordict/_nestedkey.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import abc 6 | 7 | 8 | class _NestedKeyMeta(abc.ABCMeta): 9 | def __instancecheck__(self, instance): 10 | return isinstance(instance, str) or ( 11 | isinstance(instance, tuple) 12 | and len(instance) 13 | and all(isinstance(subkey, NestedKey) for subkey in instance) 14 | ) 15 | 16 | 17 | class NestedKey(metaclass=_NestedKeyMeta): 18 | """An abstract class for nested keys. 19 | 20 | Nested keys are the generic key type accepted by TensorDict. 21 | 22 | A nested key is either a string or a non-empty tuple of NestedKey instances. 
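For example, based on the instance check defined above:

    >>> isinstance("a", NestedKey)
    True
    >>> isinstance(("a", ("b", "c")), NestedKey)
    True
    >>> isinstance(("a", 0), NestedKey)  # 0 is neither a string nor a tuple
    False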
23 | 24 | The NestedKey class supports instance checks. 25 | 26 | """ 27 | 28 | pass 29 | -------------------------------------------------------------------------------- /tensordict/_nestedkey.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from __future__ import annotations 6 | 7 | from typing import Tuple, type_check_only 8 | 9 | NestedKey = type_check_only(str | Tuple["NestedKey", ...]) 10 | -------------------------------------------------------------------------------- /tensordict/_reductions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from __future__ import annotations 6 | 7 | import copyreg 8 | from multiprocessing import reduction 9 | 10 | import torch 11 | from tensordict._lazy import LazyStackedTensorDict 12 | from tensordict._td import TensorDict 13 | 14 | from tensordict.tensorclass import NonTensorData, NonTensorStack 15 | from tensordict.utils import _is_tensorclass, _STR_DTYPE_TO_DTYPE 16 | 17 | CLS_MAP = { 18 | "TensorDict": TensorDict, 19 | "LazyStackedTensorDict": LazyStackedTensorDict, 20 | "NonTensorData": NonTensorData, 21 | "NonTensorStack": NonTensorStack, 22 | } 23 | 24 | 25 | def _rebuild_tensordict_files(flat_key_values, metadata_dict, is_shared: bool = False): 26 | def from_metadata(metadata=metadata_dict, prefix=None): 27 | non_tensor = metadata.pop("non_tensors") 28 | leaves = metadata.pop("leaves") 29 | cls = metadata.pop("cls") 30 | cls_metadata = metadata.pop("cls_metadata") 31 | is_locked = cls_metadata.pop("is_locked", False) 32 | 33 | d = { 34 | key: NonTensorData(data, batch_size=batch_size) 35 | for (key, (data, batch_size)) in non_tensor.items() 36 | } 37 | for key in leaves.keys(): 38 | total_key = (key,) if prefix is None else prefix + (key,) 39 | if total_key[-1].startswith("<NJT>"): 40 | nested_values = flat_key_values[total_key] 41 | nested_lengths = None 42 | continue 43 | if total_key[-1].startswith("<NJT_LENGTHS>"): 44 | nested_lengths = flat_key_values[total_key] 45 | continue 46 | elif total_key[-1].startswith("<NJT_OFFSETS>"): 47 | offsets = flat_key_values[total_key] 48 | key = key.replace("<NJT_OFFSETS>", "") 49 | value = torch.nested.nested_tensor_from_jagged( 50 | nested_values, offsets=offsets, lengths=nested_lengths 51 | ) 52 | del nested_values 53 | del nested_lengths 54 | else: 55 | value = flat_key_values[total_key] 56 | d[key] = value 57 | for k, v in metadata.items(): 58 | # Each remaining key is a tuple pointing to a sub-tensordict 59 | d[k] = from_metadata( 60 | v, prefix=prefix + (k,) if prefix is not None else (k,) 61 | ) 62 | if isinstance(cls, str): 63 | cls = CLS_MAP[cls] 64 | result = cls._from_dict_validated(d, **cls_metadata) 65 | if is_locked: 66 | result.lock_() 67 | # if is_shared: 68 | # result._is_shared = is_shared 69 | return result 70 | 71 | return from_metadata() 72 | 73 | 74 | def _rebuild_tensordict_files_shared(flat_key_values, metadata_dict): 75 | return _rebuild_tensordict_files(flat_key_values, metadata_dict, is_shared=True) 76 | 77 | 78 | def _rebuild_tensordict_files_consolidated( 79 | metadata, 80 | storage, 81 | ): 82 | def from_metadata(metadata=metadata, prefix=None): 83 | consolidated = {"storage": storage, "metadata": metadata} 84 | 
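# NOTE: copy the metadata mapping first so the pops below don't mutate the caller's dict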
metadata = dict(metadata) 85 | non_tensor = metadata.pop("non_tensors") 86 | leaves = metadata.pop("leaves") 87 | cls = metadata.pop("cls") 88 | cls_metadata = dict(metadata.pop("cls_metadata")) 89 | is_locked = cls_metadata.pop("is_locked", False) 90 | # an optional "size" entry records the size of the file 91 | _ = metadata.pop("size", None) 92 | 93 | d = { 94 | key: NonTensorData( 95 | data, 96 | batch_size=batch_size, 97 | device=torch.device(device) if device is not None else None, 98 | ) 99 | for (key, (data, batch_size, device)) in non_tensor.items() 100 | } 101 | for key, (dtype, local_shape, start, stop, pad) in leaves.items(): 102 | dtype = _STR_DTYPE_TO_DTYPE[dtype] 103 | # device = torch.device(device) 104 | local_shape = torch.Size(local_shape) 105 | value = storage[start:stop].view(dtype) 106 | if pad: 107 | value = value[: local_shape.numel()] 108 | value = value.view(local_shape) 109 | if key.startswith("<NJT>"): 110 | raise RuntimeError 111 | elif key.startswith("<NJT_VALUES>"): 112 | nested_values = value 113 | nested_lengths = None 114 | continue 115 | elif key.startswith("<NJT_LENGTHS>"): 116 | nested_lengths = value 117 | continue 118 | elif key.startswith("<NJT_OFFSETS>"): 119 | from torch.nested._internal.nested_tensor import NestedTensor 120 | 121 | offsets = value 122 | value = NestedTensor( 123 | nested_values, offsets=offsets, lengths=nested_lengths 124 | ) 125 | key = key.replace("<NJT_OFFSETS>", "") 126 | d[key] = value 127 | for k, v in metadata.items(): 128 | # Each remaining key is a tuple pointing to a sub-tensordict 129 | d[k] = from_metadata( 130 | v, prefix=prefix + (k,) if prefix is not None else (k,) 131 | ) 132 | if isinstance(cls, str): 133 | cls = CLS_MAP[cls] 134 | result = cls._from_dict_validated(d, **cls_metadata) 135 | if is_locked: 136 | result = result.lock_() 137 | if _is_tensorclass(cls): 138 | result._tensordict._consolidated = consolidated 139 | else: 140 | result._consolidated = consolidated 141 | return result 142 | 143 | return from_metadata() 144 | 145 | 146 | def _make_td(cls, state): 147 | td = cls.__new__(cls) 148 | td.__setstate__(state) 149 | return td 150 | 151 | 152 | def _reduce_td(data: TensorDict): 153 | consolidated = getattr(data, "_consolidated", None) 154 | if consolidated and consolidated["metadata"] is not None: 155 | storage = consolidated["storage"] 156 | storage_metadata = consolidated["metadata"] 157 | return ( 158 | _rebuild_tensordict_files_consolidated, 159 | (storage_metadata, storage), 160 | ) 161 | 162 | # This is faster than the solution below. 163 | return ( 164 | _make_td, 165 | ( 166 | type(data), 167 | data.__getstate__(), 168 | ), 169 | ) 170 | # metadata_dict, flat_key_values, _, _ = data._reduce_vals_and_metadata( 171 | # requires_metadata=True 172 | # ) 173 | # return (_rebuild_tensordict_files, (flat_key_values, metadata_dict)) 174 | 175 | 176 | reduction.register(TensorDict, _reduce_td) 177 | 178 | copyreg.pickle(TensorDict, _reduce_td) 179 | 180 | reduction.register(LazyStackedTensorDict, _reduce_td) 181 | 182 | copyreg.pickle(LazyStackedTensorDict, _reduce_td) 183 | -------------------------------------------------------------------------------- /tensordict/_tensordict/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | import warnings 6 | 7 | from tensordict.utils import ( # noqa 8 | _unravel_key_to_tuple, 9 | unravel_key, 10 | unravel_key_list, 11 | unravel_keys, 12 | ) 13 | 14 | warnings.warn( 15 | "tensordict._tensordict will soon be removed in favour of tensordict._C.", 16 | category=DeprecationWarning, 17 | ) 18 | -------------------------------------------------------------------------------- /tensordict/csrc/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.22) 2 | project(tensordict) 3 | 4 | set(CMAKE_CXX_STANDARD 20) 5 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 6 | 7 | # Set the Python executable to the one from your virtual environment 8 | if(APPLE) # Check if the target OS is OSX/macOS 9 | list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") 10 | include(FindPythonPyEnv) 11 | endif() 12 | 13 | find_package(Python3 REQUIRED COMPONENTS Interpreter Development) 14 | find_package(pybind11 2.13 REQUIRED) 15 | 16 | file(GLOB SOURCES "*.cpp") 17 | 18 | add_library(_C MODULE ${SOURCES}) 19 | 20 | if(WIN32) 21 | set_target_properties(_C PROPERTIES 22 | OUTPUT_NAME "_C" 23 | PREFIX "" # Remove 'lib' prefix 24 | SUFFIX ".pyd" 25 | LIBRARY_OUTPUT_DIRECTORY "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}" 26 | RUNTIME_OUTPUT_DIRECTORY "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}" 27 | RUNTIME_OUTPUT_DIRECTORY_DEBUG "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}" 28 | RUNTIME_OUTPUT_DIRECTORY_RELEASE "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}" 29 | ) 30 | else() 31 | set_target_properties(_C PROPERTIES 32 | OUTPUT_NAME "_C" 33 | PREFIX "" # Remove 'lib' prefix 34 | SUFFIX ".so" # Ensure correct suffix for macOS/Linux (consider using CMAKE_SHARED_LIBRARY_SUFFIX instead for cross-platform compatibility) 35 | LIBRARY_OUTPUT_DIRECTORY "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}" 36 | ) 37 | endif() 38 | 39 | find_package(Python COMPONENTS Development.Static) 40 | target_link_libraries(_C ${Python_STATIC_LIBRARIES}) 41 | 42 | target_include_directories(_C PRIVATE ${PROJECT_SOURCE_DIR}) 43 | 44 | #if(APPLE OR WIN32) # Check if the target OS is OSX/macOS 45 | target_link_libraries(_C PRIVATE pybind11::module) 46 | #else() 47 | # target_link_libraries(_C PRIVATE Python3::Python pybind11::module) 48 | #endif() 49 | 50 | if(CMAKE_BUILD_TYPE STREQUAL "Debug") 51 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0 -fsanitize=address") 52 | else() 53 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") 54 | endif() 55 | set(CMAKE_OSX_DEPLOYMENT_TARGET "15.0" CACHE STRING "Minimum OS X deployment version") 56 | set(CMAKE_VERBOSE_MAKEFILE ON) 57 | 58 | if(WIN32) 59 | add_custom_command(TARGET _C POST_BUILD 60 | COMMAND ${CMAKE_COMMAND} -E copy $<TARGET_FILE:_C> "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/$<TARGET_FILE_NAME:_C>" 61 | ) 62 | endif() 63 | -------------------------------------------------------------------------------- /tensordict/csrc/cmake/FindPythonPyEnv.cmake: -------------------------------------------------------------------------------- 1 | # Find information about the current Python environment. 2 | # by melMass 3 | # 4 | # Finds the following: 5 | # 6 | # - PYTHON_EXECUTABLE 7 | # - PYTHON_INCLUDE_DIR 8 | # - PYTHON_LIBRARY 9 | # - PYTHON_SITE 10 | # - PYTHON_NUMPY_INCLUDE_DIR 11 | # 12 | # - PYTHONLIBS_VERSION_STRING (The full version id. 
i.e. "3.7.4") 13 | # - PYTHON_VERSION_MAJOR 14 | # - PYTHON_VERSION_MINOR 15 | # - PYTHON_VERSION_PATCH 16 | # 17 | # 18 | 19 | function(debug_message messages) 20 | # message(STATUS "") 21 | message(STATUS "🐍 ${messages}") 22 | message(STATUS "\n") 23 | endfunction() 24 | 25 | if (NOT DEFINED PYTHON_EXECUTABLE) 26 | execute_process( 27 | COMMAND which python 28 | OUTPUT_VARIABLE PYTHON_EXECUTABLE OUTPUT_STRIP_TRAILING_WHITESPACE 29 | ) 30 | endif() 31 | 32 | execute_process( 33 | COMMAND ${PYTHON_EXECUTABLE} -c "from __future__ import print_function; from distutils.sysconfig import get_python_inc; print(get_python_inc())" 34 | OUTPUT_VARIABLE PYTHON_INCLUDE_DIR OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET 35 | ) 36 | 37 | if (NOT EXISTS ${PYTHON_INCLUDE_DIR}) 38 | message(FATAL_ERROR "Python include directory not found.") 39 | endif() 40 | 41 | execute_process( 42 | COMMAND ${PYTHON_EXECUTABLE} -c "from __future__ import print_function; import os, numpy.distutils; print(os.pathsep.join(numpy.distutils.misc_util.get_numpy_include_dirs()))" 43 | OUTPUT_VARIABLE PYTHON_NUMPY_INCLUDE_DIR OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET 44 | ) 45 | 46 | execute_process( 47 | COMMAND ${PYTHON_EXECUTABLE} -c "from __future__ import print_function; import distutils.sysconfig as sysconfig; print('-L' + sysconfig.get_config_var('LIBDIR') + '/' + sysconfig.get_config_var('LDLIBRARY'))" 48 | OUTPUT_VARIABLE PYTHON_LIBRARY OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET 49 | ) 50 | 51 | execute_process( 52 | COMMAND ${PYTHON_EXECUTABLE} -c "from __future__ import print_function; import platform; print(platform.python_version())" 53 | OUTPUT_VARIABLE PYTHONLIBS_VERSION_STRING OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET 54 | ) 55 | 56 | execute_process( 57 | COMMAND ${PYTHON_EXECUTABLE} -c "from __future__ import print_function; from distutils.sysconfig import get_python_lib; print(get_python_lib())" 58 | OUTPUT_VARIABLE PYTHON_SITE OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET 59 | ) 60 | 61 | set(PYTHON_VIRTUAL_ENV $ENV{VIRTUAL_ENV}) 62 | string(REPLACE "." ";" _VERSION_LIST ${PYTHONLIBS_VERSION_STRING}) 63 | 64 | list(GET _VERSION_LIST 0 PYTHON_VERSION_MAJOR) 65 | list(GET _VERSION_LIST 1 PYTHON_VERSION_MINOR) 66 | list(GET _VERSION_LIST 2 PYTHON_VERSION_PATCH) 67 | 68 | 69 | 70 | debug_message("Found Python ${PYTHON_VERSION_MAJOR} (${PYTHONLIBS_VERSION_STRING})") 71 | debug_message("PYTHON_EXECUTABLE: ${PYTHON_EXECUTABLE}") 72 | debug_message("PYTHON_INCLUDE_DIR: ${PYTHON_INCLUDE_DIR}") 73 | debug_message("PYTHON_LIBRARY: ${PYTHON_LIBRARY}") 74 | debug_message("PYTHON_NUMPY_INCLUDE_DIR: ${PYTHON_NUMPY_INCLUDE_DIR}") 75 | -------------------------------------------------------------------------------- /tensordict/csrc/pybind.cpp: -------------------------------------------------------------------------------- 1 | /* @nolint */ 2 | // Copyright (c) Meta Platforms, Inc. and affiliates. 3 | // 4 | // This source code is licensed under the MIT license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include <pybind11/pybind11.h> 8 | #include <pybind11/stl.h> 9 | 10 | #include <string> 11 | 12 | #include "utils.h" 13 | 14 | namespace py = pybind11; 15 | 16 | PYBIND11_MODULE(_C, m) { 17 | m.def("unravel_keys", &unravel_key, py::arg("key")); // for bc compat 18 | m.def("unravel_key", &unravel_key, py::arg("key")); 19 | m.def("_unravel_key_to_tuple", &_unravel_key_to_tuple, py::arg("key")); 20 | m.def("unravel_key_list", 21 | py::overload_cast<const py::list &>(&unravel_key_list), 22 | py::arg("keys")); 23 | m.def("unravel_key_list", 24 | py::overload_cast<const py::tuple &>(&unravel_key_list), 25 | py::arg("keys")); 26 | } 27 | -------------------------------------------------------------------------------- /tensordict/csrc/utils.cpp: -------------------------------------------------------------------------------- 1 | /* @nolint */ 2 | // Copyright (c) Meta Platforms, Inc. and affiliates. 3 | // 4 | // This source code is licensed under the MIT license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include "utils.h" 8 | 9 | namespace py = pybind11; 10 | 11 | py::tuple _unravel_key_to_tuple(const py::object &key) { 12 | bool is_tuple = py::isinstance<py::tuple>(key); 13 | bool is_str = py::isinstance<py::str>(key); 14 | 15 | if (is_tuple) { 16 | py::list newkey; 17 | for (const auto &subkey : key) { 18 | if (py::isinstance<py::str>(subkey)) { 19 | newkey.append(subkey); 20 | } else { 21 | auto _key = _unravel_key_to_tuple(subkey.cast<py::object>()); 22 | if (_key.size() == 0) { 23 | return py::make_tuple(); 24 | } 25 | newkey += _key; 26 | } 27 | } 28 | return py::tuple(newkey); 29 | } 30 | if (is_str) { 31 | return py::make_tuple(key); 32 | } else { 33 | return py::make_tuple(); 34 | } 35 | } 36 | 37 | py::object unravel_key(const py::object &key) { 38 | bool is_tuple = py::isinstance<py::tuple>(key); 39 | bool is_str = py::isinstance<py::str>(key); 40 | 41 | if (is_tuple) { 42 | py::list newkey; 43 | int count = 0; 44 | for (const auto &subkey : key) { 45 | if (py::isinstance<py::str>(subkey)) { 46 | newkey.append(subkey); 47 | count++; 48 | } else { 49 | auto _key = _unravel_key_to_tuple(subkey.cast<py::object>()); 50 | count += _key.size(); 51 | newkey += _key; 52 | } 53 | } 54 | if (count == 1) { 55 | return newkey[0]; 56 | } 57 | return py::tuple(newkey); 58 | } 59 | if (is_str) { 60 | return key; 61 | } else { 62 | throw std::runtime_error("key should be a Sequence"); 63 | } 64 | } 65 | 66 | py::list unravel_key_list(const py::list &keys) { 67 | py::list newkeys; 68 | for (const auto &key : keys) { 69 | auto _key = unravel_key(key.cast<py::object>()); 70 | newkeys.append(_key); 71 | } 72 | return newkeys; 73 | } 74 | 75 | py::list unravel_key_list(const py::tuple &keys) { 76 | return unravel_key_list(py::list(keys)); 77 | } 78 | -------------------------------------------------------------------------------- /tensordict/csrc/utils.h: -------------------------------------------------------------------------------- 1 | /* @nolint */ 2 | // Copyright (c) Meta Platforms, Inc. and affiliates. 3 | // 4 | // This source code is licensed under the MIT license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include <pybind11/pybind11.h> 8 | #include <pybind11/stl.h> 9 | 10 | namespace py = pybind11; 11 | 12 | py::tuple _unravel_key_to_tuple(const py::object &key); 13 | 14 | py::object unravel_key(const py::object &key); 15 | 16 | py::list unravel_key_list(const py::list &keys); 17 | 18 | py::list unravel_key_list(const py::tuple &keys); 19 | -------------------------------------------------------------------------------- /tensordict/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from tensordict.nn.common import ( 7 | dispatch, 8 | make_tensordict, 9 | TensorDictModule, 10 | TensorDictModuleBase, 11 | TensorDictModuleWrapper, 12 | WrapModule, 13 | ) 14 | from tensordict.nn.distributions import ( 15 | AddStateIndependentNormalScale, 16 | CompositeDistribution, 17 | NormalParamExtractor, 18 | OneHotCategorical, 19 | rand_one_hot, 20 | TruncatedNormal, 21 | ) 22 | from tensordict.nn.ensemble import EnsembleModule 23 | from tensordict.nn.functional_modules import ( 24 | get_functional, 25 | is_functional, 26 | make_functional, 27 | repopulate_module, 28 | ) 29 | from tensordict.nn.params import TensorDictParams 30 | from tensordict.nn.probabilistic import ( 31 | InteractionType, 32 | ProbabilisticTensorDictModule, 33 | ProbabilisticTensorDictSequential, 34 | set_interaction_type, 35 | ) 36 | from tensordict.nn.sequence import TensorDictSequential 37 | from tensordict.nn.utils import ( 38 | add_custom_mapping, 39 | biased_softplus, 40 | inv_softplus, 41 | mappings, 42 | set_skip_existing, 43 | skip_existing, 44 | ) 45 | 46 | from .common import as_tensordict_module 47 | 48 | from .cudagraphs import CudaGraphModule 49 | from .utils import composite_lp_aggregate, set_composite_lp_aggregate 50 | -------------------------------------------------------------------------------- /tensordict/nn/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from tensordict.nn.distributions import continuous, discrete 7 | 8 | from tensordict.nn.distributions.composite import CompositeDistribution 9 | from tensordict.nn.distributions.continuous import ( 10 | AddStateIndependentNormalScale, 11 | Delta, 12 | NormalParamExtractor, 13 | ) 14 | from tensordict.nn.distributions.discrete import OneHotCategorical, rand_one_hot 15 | from tensordict.nn.distributions.truncated_normal import TruncatedNormal 16 | 17 | distributions_maps = { 18 | distribution_class.lower(): eval(distribution_class) 19 | for distribution_class in (*continuous.__all__, *discrete.__all__) 20 | } 21 | -------------------------------------------------------------------------------- /tensordict/nn/distributions/discrete.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from __future__ import annotations 7 | 8 | from typing import Sequence 9 | 10 | import torch 11 | from torch import distributions as D 12 | 13 | # We need this to build the distribution maps 14 | __all__ = [ 15 | "OneHotCategorical", 16 | ] 17 | 18 | 19 | def _treat_categorical_params( 20 | params: torch.Tensor | None = None, 21 | ) -> torch.Tensor | None: 22 | if params is None: 23 | return None 24 | if params.shape[-1] == 1: 25 | params = params[..., 0] 26 | return params 27 | 28 | 29 | def rand_one_hot(values: torch.Tensor, do_softmax: bool = True) -> torch.Tensor: 30 | """Draws a one-hot sample along the last dimension of ``values``, applying a softmax first when ``do_softmax`` is ``True``.""" if do_softmax: 31 | values = values.softmax(-1) 32 | out = values.cumsum(-1) > torch.rand_like(values[..., :1]) 33 | out = (out.cumsum(-1) == 1).to(torch.long) 34 | return out 35 | 36 | 37 | class OneHotCategorical(D.Categorical): 38 | """One-hot categorical distribution. 39 | 40 | This class behaves exactly like torch.distributions.Categorical, except that it reads and produces one-hot encodings 41 | of the discrete tensors. 42 | 43 | """ 44 | 45 | num_params: int = 1 46 | 47 | def __init__( 48 | self, 49 | logits: torch.Tensor | None = None, 50 | probs: torch.Tensor | None = None, 51 | **kwargs, 52 | ) -> None: 53 | logits = _treat_categorical_params(logits) 54 | probs = _treat_categorical_params(probs) 55 | super().__init__(probs=probs, logits=logits, **kwargs) 56 | 57 | def log_prob(self, value: torch.Tensor) -> torch.Tensor: 58 | return super().log_prob(value.argmax(dim=-1)) 59 | 60 | @property 61 | def mode(self) -> torch.Tensor: 62 | if hasattr(self, "logits"): 63 | return (self.logits == self.logits.max(-1, True)[0]).to(torch.long) 64 | else: 65 | return (self.probs == self.probs.max(-1, True)[0]).to(torch.long) 66 | 67 | deterministic_sample = mode 68 | 69 | def sample( 70 | self, 71 | sample_shape: torch.Size | Sequence[int] | None = None, 72 | ) -> torch.Tensor: 73 | if sample_shape is None: 74 | sample_shape = torch.Size([]) 75 | out = super().sample(sample_shape=sample_shape) 76 | out = torch.nn.functional.one_hot(out, self.logits.shape[-1]).to(torch.long) 77 | return out 78 | 79 | def rsample( 80 | self, 81 | sample_shape: torch.Size | Sequence[int] | None = None, 82 | ) -> torch.Tensor: 83 | if sample_shape is None: 84 | sample_shape = torch.Size([]) 85 | d = D.relaxed_categorical.RelaxedOneHotCategorical( 86 | 1.0, logits=self.logits  # torch distributions accept either probs or logits, not both 87 | ) 88 | out = d.rsample(sample_shape) 89 | out.data.copy_((out == out.max(-1)[0].unsqueeze(-1)).to(out.dtype)) 90 | return out 91 | 92 | 93 | D.RelaxedBernoulli.deterministic_sample = D.Bernoulli.mode 94 | 95 | 96 | @property 97 | def _relaxed_onehot_mode(self) -> torch.Tensor: 98 | probs = self.base_dist.probs 99 | mode = probs.argmax(dim=-1) 100 | return torch.nn.functional.one_hot(mode, num_classes=probs.shape[-1]).to(probs) 101 | 102 | 103 | D.RelaxedOneHotCategorical.deterministic_sample = _relaxed_onehot_mode 104 | -------------------------------------------------------------------------------- /tensordict/nn/distributions/truncated_normal.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # from https://github.com/toshas/torch_truncnorm 7 | 8 | from __future__ import annotations 9 | 10 | import math 11 | from numbers import Number 12 | from typing import Sequence 13 | 14 | import torch 15 | from torch.distributions import constraints, Distribution 16 | from torch.distributions.utils import broadcast_all 17 | 18 | 19 | CONST_SQRT_2 = math.sqrt(2) 20 | CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi) 21 | CONST_INV_SQRT_2 = 1 / math.sqrt(2) 22 | CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI) 23 | CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e) 24 | 25 | 26 | class TruncatedStandardNormal(Distribution): 27 | """Truncated Standard Normal distribution. 28 | 29 | Source: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 30 | """ 31 | 32 | arg_constraints = { 33 | "a": constraints.real, 34 | "b": constraints.real, 35 | } 36 | has_rsample = True 37 | eps = 1e-6 38 | 39 | def __init__( 40 | self, 41 | a: Number | torch.Tensor, 42 | b: Number | torch.Tensor, 43 | validate_args: bool | None = None, 44 | ) -> None: 45 | self.a, self.b = broadcast_all(a, b) 46 | if isinstance(a, Number) and isinstance(b, Number): 47 | batch_shape = torch.Size() 48 | else: 49 | batch_shape = self.a.size() 50 | super().__init__(batch_shape, validate_args=validate_args) 51 | if self.a.dtype != self.b.dtype: 52 | raise ValueError("Truncation bounds types are different") 53 | if any((self.a >= self.b).view(-1).tolist()): 54 | raise ValueError("Incorrect truncation range") 55 | # eps = torch.finfo(self.a.dtype).eps * 10 56 | eps = self.eps 57 | self._dtype_min_gt_0 = eps 58 | self._dtype_max_lt_1 = 1 - eps 59 | self._little_phi_a = self._little_phi(self.a) 60 | self._little_phi_b = self._little_phi(self.b) 61 | self._big_phi_a = self._big_phi(self.a) 62 | self._big_phi_b = self._big_phi(self.b) 63 | self._Z = (self._big_phi_b - self._big_phi_a).clamp(eps, 1 - eps) 64 | self._log_Z = self._Z.log() 65 | little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan) 66 | little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan) 67 | self._lpbb_m_lpaa_d_Z = ( 68 | self._little_phi_b * little_phi_coeff_b 69 | - self._little_phi_a * little_phi_coeff_a 70 | ) / self._Z 71 | self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z 72 | self._variance = ( 73 | 1 74 | - self._lpbb_m_lpaa_d_Z 75 | - ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2 76 | ) 77 | self._entropy = CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z 78 | 79 | @constraints.dependent_property 80 | def support(self) -> constraints.Constraints: 81 | return constraints.interval(self.a, self.b) 82 | 83 | @property 84 | def mean(self) -> torch.Tensor: 85 | return self._mean 86 | 87 | @property 88 | def variance(self) -> torch.Tensor: 89 | return self._variance 90 | 91 | @property 92 | def entropy(self) -> torch.Tensor: 93 | return self._entropy 94 | 95 | @property 96 | def auc(self) -> torch.Tensor: 97 | return self._Z 98 | 99 | @staticmethod 100 | def _little_phi(x: torch.Tensor) -> torch.Tensor: 101 | return (-(x**2) * 0.5).exp() * CONST_INV_SQRT_2PI 102 | 103 | def _big_phi(self, x: torch.Tensor) -> torch.Tensor: 104 | phi = 0.5 * (1 + (x * CONST_INV_SQRT_2).erf()) 105 | return phi.clamp(self.eps, 1 - self.eps) 106 | 107 | @staticmethod 108 | def _inv_big_phi(x: torch.Tensor) -> torch.Tensor: 109 | return CONST_SQRT_2 * (2 * x - 1).erfinv() 110 | 111 | def cdf(self, value: torch.Tensor) -> torch.Tensor: 112 | if self._validate_args: 113 | self._validate_sample(value) 
114 |         return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1)
115 | 
116 |     def icdf(self, value: torch.Tensor) -> torch.Tensor:
117 |         y = self._big_phi_a + value * self._Z
118 |         y = y.clamp(self.eps, 1 - self.eps)
119 |         return self._inv_big_phi(y)
120 | 
121 |     def log_prob(self, value: torch.Tensor) -> torch.Tensor:
122 |         if self._validate_args:
123 |             self._validate_sample(value)
124 |         return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5
125 | 
126 |     def rsample(
127 |         self,
128 |         sample_shape: torch.Size | Sequence[int] | None = None,
129 |     ) -> torch.Tensor:
130 |         if sample_shape is None:
131 |             sample_shape = torch.Size([])
132 |         shape = self._extended_shape(sample_shape)
133 |         p = torch.empty(shape, device=self.a.device).uniform_(
134 |             self._dtype_min_gt_0, self._dtype_max_lt_1
135 |         )
136 |         return self.icdf(p)
137 | 
138 | 
139 | class TruncatedNormal(TruncatedStandardNormal):
140 |     """Truncated Normal distribution.
141 | 
142 |     https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
143 |     """
144 | 
145 |     has_rsample = True
146 | 
147 |     def __init__(
148 |         self,
149 |         loc: Number | torch.Tensor,
150 |         scale: Number | torch.Tensor,
151 |         a: Number | torch.Tensor,
152 |         b: Number | torch.Tensor,
153 |         validate_args: bool | None = None,
154 |     ) -> None:
155 |         self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b)
156 |         # clamp after broadcasting so that scalar ``scale`` arguments are supported
157 |         self.scale = self.scale.clamp_min(self.eps)
158 |         self._non_std_a = a
159 |         self._non_std_b = b
160 |         a = (a - self.loc) / self.scale
161 |         b = (b - self.loc) / self.scale
162 |         super().__init__(a, b, validate_args=validate_args)
163 |         self._log_scale = self.scale.log()
164 |         self._mean = self._mean * self.scale + self.loc
165 |         self._variance = self._variance * self.scale**2
166 |         self._entropy += self._log_scale
167 | 
168 |     def _to_std_rv(self, value: torch.Tensor) -> torch.Tensor:
169 |         return (value - self.loc) / self.scale
170 | 
171 |     def _from_std_rv(self, value: torch.Tensor) -> torch.Tensor:
172 |         return value * self.scale + self.loc
173 | 
174 |     def cdf(self, value: torch.Tensor) -> torch.Tensor:
175 |         return super().cdf(self._to_std_rv(value))
176 | 
177 |     def icdf(self, value: torch.Tensor) -> torch.Tensor:
178 |         sample = self._from_std_rv(super().icdf(value))
179 | 
180 |         # clamp data but keep gradients
181 |         sample_clip = torch.stack(
182 |             [sample.detach(), self._non_std_a.detach().expand_as(sample)], 0
183 |         ).max(0)[0]
184 |         sample_clip = torch.stack(
185 |             [sample_clip, self._non_std_b.detach().expand_as(sample)], 0
186 |         ).min(0)[0]
187 |         sample.data.copy_(sample_clip)
188 |         return sample
189 | 
190 |     def log_prob(self, value: torch.Tensor) -> torch.Tensor:
191 |         value = self._to_std_rv(value)
192 |         return super().log_prob(value) - self._log_scale
193 | 
--------------------------------------------------------------------------------
/tensordict/nn/distributions/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
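# A minimal sketch of what the private helpers below do (illustrative only;
# both functions are internal to tensordict):
#
# >>> import torch
# >>> _cast_device(torch.zeros(3), "cpu")   # tensors are moved to the device
# >>> _cast_device(1.0, "cpu")              # non-tensors pass through unchanged
# 1.0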
5 | 
6 | from __future__ import annotations
7 | 
8 | import torch
9 | from tensordict.utils import DeviceType
10 | from torch import distributions as D
11 | 
12 | 
13 | def _cast_device(
14 |     elt: torch.Tensor | float,
15 |     device: DeviceType,
16 | ) -> torch.Tensor | float:
17 |     if isinstance(elt, torch.Tensor):
18 |         return elt.to(device)
19 |     return elt
20 | 
21 | 
22 | def _cast_transform_device(
23 |     transform: D.Transform | None,
24 |     device: DeviceType,
25 | ) -> D.Transform | None:
26 |     if transform is None:
27 |         return transform
28 |     elif isinstance(transform, D.ComposeTransform):
29 |         for i, t in enumerate(transform.parts):
30 |             transform.parts[i] = _cast_transform_device(t, device)
31 |         # return the modified transform instead of falling through and returning None
32 |         return transform
33 |     elif isinstance(transform, D.Transform):
34 |         for attribute in dir(transform):
35 |             value = getattr(transform, attribute)
36 |             if isinstance(value, torch.Tensor):
37 |                 setattr(transform, attribute, value.to(device))
38 |         return transform
39 |     else:
40 |         raise TypeError(
41 |             f"Cannot perform device casting for transform of type {type(transform)}"
42 |         )
43 | 
--------------------------------------------------------------------------------
/tensordict/nn/ensemble.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 
6 | import warnings
7 | 
8 | import torch
9 | from tensordict import LazyStackedTensorDict, TensorDict
10 | from tensordict.nn.common import TensorDictBase, TensorDictModuleBase
11 | 
12 | from tensordict.nn.params import TensorDictParams
13 | 
14 | 
15 | class EnsembleModule(TensorDictModuleBase):
16 |     """Module that wraps a module and repeats it to form an ensemble.
17 | 
18 |     Args:
19 |         module (nn.Module): The nn.Module to duplicate and wrap.
20 |         num_copies (int): The number of copies of module to make.
21 |         expand_input (bool): Whether to expand the input TensorDict to match the number of copies.
22 |             This should be True unless you are chaining ensemble modules together, e.g.
23 |             EnsembleModule(cnn) -> EnsembleModule(mlp). If False, EnsembleModule(mlp) will expect
24 |             the previous module(s) to have already expanded the input.
25 | 
26 |     Examples:
27 |         >>> import torch
28 |         >>> from torch import nn
29 |         >>> from tensordict.nn import TensorDictModule, EnsembleModule
30 |         >>> from tensordict import TensorDict
31 |         >>> net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
32 |         >>> mod = TensorDictModule(net, in_keys=['a'], out_keys=['b'])
33 |         >>> ensemble = EnsembleModule(mod, num_copies=3)
34 |         >>> data = TensorDict({'a': torch.randn(10, 4)}, batch_size=[10])
35 |         >>> ensemble(data)
36 |         TensorDict(
37 |             fields={
38 |                 a: Tensor(shape=torch.Size([3, 10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
39 |                 b: Tensor(shape=torch.Size([3, 10, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
40 |             batch_size=torch.Size([3, 10]),
41 |             device=None,
42 |             is_shared=False)
43 | 
44 |     To stack EnsembleModules together, we should be mindful of turning off `expand_input` from the second module onwards.
45 | 46 | Examples: 47 | >>> import torch 48 | >>> from tensordict.nn import TensorDictModule, TensorDictSequential, EnsembleModule 49 | >>> from tensordict import TensorDict 50 | >>> module = TensorDictModule(torch.nn.Linear(2,3), in_keys=['bork'], out_keys=['dork']) 51 | >>> next_module = TensorDictModule(torch.nn.Linear(3,1), in_keys=['dork'], out_keys=['spork']) 52 | >>> e0 = EnsembleModule(module, num_copies=4, expand_input=True) 53 | >>> e1 = EnsembleModule(next_module, num_copies=4, expand_input=False) 54 | >>> seq = TensorDictSequential(e0, e1) 55 | >>> data = TensorDict({'bork': torch.randn(5,2)}, batch_size=[5]) 56 | >>> seq(data) 57 | TensorDict( 58 | fields={ 59 | bork: Tensor(shape=torch.Size([4, 5, 2]), device=cpu, dtype=torch.float32, is_shared=False), 60 | dork: Tensor(shape=torch.Size([4, 5, 3]), device=cpu, dtype=torch.float32, is_shared=False), 61 | spork: Tensor(shape=torch.Size([4, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, 62 | batch_size=torch.Size([4, 5]), 63 | device=None, 64 | is_shared=False) 65 | """ 66 | 67 | def __init__( 68 | self, 69 | module: TensorDictModuleBase, 70 | num_copies: int, 71 | expand_input: bool = True, 72 | ): 73 | super().__init__() 74 | self.in_keys = module.in_keys 75 | self.out_keys = module.out_keys 76 | params_td = TensorDict.from_module(module).expand(num_copies).to_tensordict() 77 | 78 | self.module = module 79 | if expand_input: 80 | self.vmapped_forward = torch.vmap(self._func_module_call, (None, 0)) 81 | else: 82 | self.vmapped_forward = torch.vmap(self._func_module_call, 0) 83 | 84 | self.reset_parameters_recursive(params_td) 85 | self.params_td = TensorDictParams(params_td) 86 | 87 | def _func_module_call(self, input, params): 88 | with params.to_module(self.module): 89 | return self.module(input) 90 | 91 | def forward(self, tensordict: TensorDict) -> TensorDict: 92 | return self.vmapped_forward(tensordict, self.params_td) 93 | 94 | def reset_parameters_recursive( 95 | self, parameters: TensorDictBase = None 96 | ) -> TensorDictBase: 97 | """Resets the parameters of all the copies of the module. 98 | 99 | Args: 100 | parameters (TensorDict): A TensorDict of parameters for self.module. The batch dimension(s) of the tensordict 101 | denote the number of module copies to reset. 102 | 103 | Returns: 104 | A TensorDict of pointers to the reset parameters. 105 | """ 106 | if parameters is None: 107 | raise ValueError( 108 | "Ensembles are functional and require passing a TensorDict of parameters to reset_parameters_recursive" 109 | ) 110 | if parameters.ndim: 111 | params_pointers = [] 112 | for params_copy in parameters.unbind(0): 113 | self.reset_parameters_recursive(params_copy) 114 | params_pointers.append(params_copy) 115 | return LazyStackedTensorDict.lazy_stack(params_pointers, -1) 116 | else: 117 | # In case the user has added other neural networks to the EnsembleModule 118 | # besides those in self.module 119 | child_mods = [ 120 | mod 121 | for name, mod in self.named_children() 122 | if name != "module" and name != "ensemble_parameters" 123 | ] 124 | if child_mods: 125 | warnings.warn( 126 | "EnsembleModule.reset_parameters_recursive() only resets parameters of self.module, but other parameters were detected. These parameters will not be reset." 
127 | ) 128 | # Reset all self.module descendant parameters 129 | return self.module.reset_parameters_recursive(parameters) 130 | -------------------------------------------------------------------------------- /tensordict/prototype/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from tensordict.prototype.fx import symbolic_trace 6 | 7 | __all__ = [ 8 | "symbolic_trace", 9 | ] 10 | -------------------------------------------------------------------------------- /tensordict/tensordict.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from tensordict._lazy import LazyStackedTensorDict # noqa: F401 7 | from tensordict._td import TensorDict # noqa: F401 8 | from tensordict.base import ( # noqa: F401 9 | is_tensor_collection, 10 | NO_DEFAULT, 11 | TensorDictBase, 12 | ) 13 | from tensordict.functional import ( # noqa: F401 14 | dense_stack_tds, 15 | make_tensordict, 16 | merge_tensordicts, 17 | pad, 18 | pad_sequence, 19 | ) 20 | from tensordict.memmap import MemoryMappedTensor # noqa: F401 21 | from tensordict.utils import ( # noqa: F401 22 | assert_allclose_td, 23 | cache, 24 | convert_ellipsis_to_idx, 25 | erase_cache, 26 | expand_as_right, 27 | expand_right, 28 | implement_for, 29 | infer_size_impl, 30 | int_generator, 31 | is_nested_key, 32 | is_seq_of_nested_key, 33 | is_tensorclass, 34 | lock_blocked, 35 | NestedKey, 36 | ) 37 | -------------------------------------------------------------------------------- /test/artifacts/mmap_example/meta.json: -------------------------------------------------------------------------------- 1 | {"nested": {"type": "TensorDict"}, "shape": [2, 1], "device": "cpu", "_type": ""} 2 | -------------------------------------------------------------------------------- /test/artifacts/mmap_example/nested/bfloat16.memmap: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/tensordict/73fe89bc067b40219dc0d1245b655bcec85464f3/test/artifacts/mmap_example/nested/bfloat16.memmap -------------------------------------------------------------------------------- /test/artifacts/mmap_example/nested/int64.memmap: -------------------------------------------------------------------------------- 1 |  -------------------------------------------------------------------------------- /test/artifacts/mmap_example/nested/meta.json: -------------------------------------------------------------------------------- 1 | {"int64": {"device": "cpu", "shape": [2, 1], "dtype": "torch.int64", "is_nested": false}, "string": {"type": "NonTensorData"}, "bfloat16": {"device": "cpu", "shape": [2, 1], "dtype": "torch.bfloat16", "is_nested": false}, "shape": [2, 1], "device": "cpu", "_type": ""} 2 | -------------------------------------------------------------------------------- /test/artifacts/mmap_example/nested/string/meta.json: -------------------------------------------------------------------------------- 1 | {"_type": "", "data": ["a string!"], "_metadata": null} 2 | -------------------------------------------------------------------------------- 
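# A minimal sketch of how an on-disk layout such as test/artifacts/mmap_example
# can be loaded back into a TensorDict (the path is illustrative):
#
# >>> from tensordict import TensorDict
# >>> td = TensorDict.load_memmap("test/artifacts/mmap_example")
# >>> td["nested", "int64"].shape           # torch.Size([2, 1])
# >>> td["nested", "string"]                # the NonTensorData payload, "a string!"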
/test/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 
6 | import multiprocessing
7 | import os
8 | import time
9 | from collections import defaultdict
10 | 
11 | import pytest
12 | 
13 | try:
14 |     multiprocessing.set_start_method("spawn")
15 | except Exception:
16 |     assert multiprocessing.get_start_method() == "spawn"
17 | 
18 | CALL_TIMES = defaultdict(lambda: 0.0)
19 | 
20 | 
21 | def pytest_sessionfinish(maxprint=50):
22 |     out_str = """
23 | Call times:
24 | ===========
25 | """
26 |     keys = list(CALL_TIMES.keys())
27 |     if len(keys) > 1:
28 |         maxchar = max(*[len(key) for key in keys])
29 |     elif len(keys):
30 |         maxchar = len(keys[0])
31 |     else:
32 |         return
33 |     for i, (key, item) in enumerate(
34 |         sorted(CALL_TIMES.items(), key=lambda x: x[1], reverse=True)
35 |     ):
36 |         spaces = " " + " " * (maxchar - len(key))
37 |         out_str += f"\t{key}{spaces}{item: 4.4f}s\n"
38 |         if i == maxprint - 1:
39 |             break
40 |     # emit the aggregated timings; otherwise ``out_str`` is built and then discarded
41 |     print(out_str)  # noqa
42 | 
43 | 
44 | def pytest_addoption(parser):
45 |     parser.addoption(
46 |         "--runslow", action="store_true", default=False, help="run slow tests"
47 |     )
48 | 
49 | 
50 | def pytest_configure(config):
51 |     config.addinivalue_line("markers", "slow: mark test as slow to run")
52 | 
53 | 
54 | def pytest_collection_modifyitems(config, items):
55 |     if config.getoption("--runslow"):
56 |         # --runslow given in cli: do not skip slow tests
57 |         return
58 |     skip_slow = pytest.mark.skip(reason="need --runslow option to run")
59 |     for item in items:
60 |         if "slow" in item.keywords:
61 |             item.add_marker(skip_slow)
62 | 
63 | 
64 | @pytest.fixture(autouse=True)
65 | def measure_duration(request: pytest.FixtureRequest):
66 |     start_time = time.time()
67 | 
68 |     def fin():
69 |         duration = time.time() - start_time
70 |         name = request.node.name
71 |         class_name = request.cls.__name__ if request.cls else None
72 |         name = name.split("[")[0]
73 |         if class_name is not None:
74 |             name = "::".join([class_name, name])
75 |         file = os.path.basename(request.path)
76 |         name = f"{file}::{name}"
77 |         CALL_TIMES[name] = CALL_TIMES[name] + duration
78 | 
79 |     request.addfinalizer(fin)
80 | 
--------------------------------------------------------------------------------
/test/smoke_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
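# A minimal sketch of how the ``slow`` marker configured in test/conftest.py is
# consumed (a hypothetical test; it is skipped unless pytest runs with --runslow):
#
# >>> import pytest
# >>> @pytest.mark.slow
# ... def test_expensive_roundtrip():
# ...     pass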
5 | import subprocess 6 | import sys 7 | from pathlib import Path 8 | 9 | _IS_LINUX = sys.platform.startswith("linux") 10 | 11 | 12 | def test_imports_deps(): 13 | print("Importing numpy") # noqa 14 | import numpy # noqa 15 | 16 | print("Importing torch") # noqa 17 | import torch # noqa 18 | 19 | 20 | def test_imports(): 21 | print("Importing tensordict") # noqa 22 | from tensordict import TensorDict # noqa: F401 23 | 24 | print("Importing tensordict nn") # noqa 25 | import tensordict # noqa 26 | from tensordict.nn import TensorDictModule # noqa: F401 27 | 28 | print("version", tensordict.__version__) # noqa 29 | 30 | 31 | def test_static_linking(): 32 | if not _IS_LINUX: 33 | return 34 | # Locate _C.so 35 | try: 36 | import tensordict._C 37 | except ImportError as e: 38 | raise RuntimeError(f"Failed to import tensordict._C: {e}") 39 | # Get the path to _C.so 40 | _C_path = Path(tensordict._C.__file__) 41 | if not _C_path.exists(): 42 | raise RuntimeError(f"_C.so not found at {_C_path}") 43 | # Run ldd on _C.so 44 | try: 45 | output = subprocess.check_output(["ldd", str(_C_path)]).decode("utf-8") 46 | except subprocess.CalledProcessError as e: 47 | raise RuntimeError(f"Failed to run ldd on {_C_path}: {e}") 48 | # Check if libpython is dynamically linked 49 | for line in output.splitlines(): 50 | if "libpython" in line and "=>" in line and "not found" not in line: 51 | raise RuntimeError( 52 | f"tensordict/_C.so is dynamically linked against {line.strip()}" 53 | ) 54 | print( # noqa 55 | "Test passed: tensordict/_C.so does not show dynamic linkage to libpython." 56 | ) 57 | 58 | 59 | if __name__ == "__main__": 60 | test_imports_deps() 61 | test_imports() 62 | -------------------------------------------------------------------------------- /test/test_fx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
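# A minimal sketch of the pattern exercised by the tests below: trace a
# TensorDictModule and check the traced graph against the eager module
# (illustrative shapes):
#
# >>> import torch
# >>> from torch import nn
# >>> from tensordict import TensorDict
# >>> from tensordict.nn import TensorDictModule
# >>> from tensordict.prototype.fx import symbolic_trace
# >>> mod = TensorDictModule(nn.Linear(4, 2), in_keys=["x"], out_keys=["y"])
# >>> graph_mod = symbolic_trace(mod)
# >>> td = TensorDict({"x": torch.randn(8, 4)}, [8])
# >>> assert (mod(td.clone())["y"] == graph_mod(td.clone())["y"]).all()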
5 | 6 | import argparse 7 | 8 | import pytest 9 | import torch 10 | import torch.nn as nn 11 | 12 | from tensordict import TensorDict 13 | from tensordict.nn import TensorDictModule, TensorDictSequential 14 | from tensordict.prototype.fx import symbolic_trace 15 | 16 | 17 | def test_tensordictmodule_trace_consistency(): 18 | class Net(nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | self.linear = nn.LazyLinear(1) 22 | 23 | def forward(self, x): 24 | logits = self.linear(x) 25 | return logits, torch.sigmoid(logits) 26 | 27 | module = TensorDictModule( 28 | Net(), 29 | in_keys=["input"], 30 | out_keys=[("outputs", "logits"), ("outputs", "probabilities")], 31 | ) 32 | graph_module = symbolic_trace(module) 33 | 34 | tensordict = TensorDict({"input": torch.randn(32, 100)}, [32]) 35 | 36 | module_out = TensorDict() 37 | graph_module_out = TensorDict() 38 | 39 | module(tensordict, tensordict_out=module_out) 40 | graph_module(tensordict, tensordict_out=graph_module_out) 41 | 42 | assert ( 43 | module_out["outputs", "logits"] == graph_module_out["outputs", "logits"] 44 | ).all() 45 | assert ( 46 | module_out["outputs", "probabilities"] 47 | == graph_module_out["outputs", "probabilities"] 48 | ).all() 49 | 50 | 51 | def test_tensordictsequential_trace_consistency(): 52 | class Net(nn.Module): 53 | def __init__(self, input_size=100, hidden_size=50, output_size=10): 54 | super().__init__() 55 | self.fc1 = nn.Linear(input_size, hidden_size) 56 | self.fc2 = nn.Linear(hidden_size, output_size) 57 | 58 | def forward(self, x): 59 | x = torch.relu(self.fc1(x)) 60 | return self.fc2(x) 61 | 62 | class Masker(nn.Module): 63 | def forward(self, x, mask): 64 | return torch.softmax(x * mask, dim=1) 65 | 66 | net = TensorDictModule( 67 | Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")] 68 | ) 69 | masker = TensorDictModule( 70 | Masker(), 71 | in_keys=[("intermediate", "x"), ("input", "mask")], 72 | out_keys=[("output", "probabilities")], 73 | ) 74 | module = TensorDictSequential(net, masker) 75 | graph_module = symbolic_trace(module) 76 | 77 | tensordict = TensorDict( 78 | { 79 | "input": TensorDict( 80 | {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))}, 81 | batch_size=[32], 82 | ) 83 | }, 84 | batch_size=[32], 85 | ) 86 | 87 | module_out = TensorDict() 88 | graph_module_out = TensorDict() 89 | 90 | module(tensordict, tensordict_out=module_out) 91 | graph_module(tensordict, tensordict_out=graph_module_out) 92 | 93 | assert ( 94 | graph_module_out["intermediate", "x"] == module_out["intermediate", "x"] 95 | ).all() 96 | assert ( 97 | graph_module_out["output", "probabilities"] 98 | == module_out["output", "probabilities"] 99 | ).all() 100 | 101 | 102 | def test_nested_tensordictsequential_trace_consistency(): 103 | class Net(nn.Module): 104 | def __init__(self, input_size, output_size): 105 | super().__init__() 106 | self.fc = nn.Linear(input_size, output_size) 107 | 108 | def forward(self, x): 109 | return torch.relu(self.fc(x)) 110 | 111 | class Output(nn.Module): 112 | def __init__(self, input_size, output_size=10): 113 | super().__init__() 114 | self.fc = nn.Linear(input_size, output_size) 115 | 116 | def forward(self, x): 117 | return torch.softmax(self.fc(x), dim=1) 118 | 119 | module1 = Net(100, 50) 120 | module2 = Net(50, 40) 121 | module3 = Output(40, 10) 122 | 123 | tdmodule1 = TensorDictModule(module1, ["input"], ["x"]) 124 | tdmodule2 = TensorDictModule(module2, ["x"], ["x"]) 125 | tdmodule3 = TensorDictModule(module3, ["x"], ["probabilities"]) 126 
| 127 | tdmodule = TensorDictSequential( 128 | TensorDictSequential(tdmodule1, tdmodule2), tdmodule3 129 | ) 130 | graph_module = symbolic_trace(tdmodule) 131 | 132 | tensordict = TensorDict({"input": torch.rand(32, 100)}, [32]) 133 | 134 | module_out = TensorDict() 135 | graph_module_out = TensorDict() 136 | 137 | tdmodule(tensordict, tensordict_out=module_out) 138 | graph_module(tensordict, tensordict_out=graph_module_out) 139 | 140 | assert (module_out["x"] == graph_module_out["x"]).all() 141 | assert (module_out["probabilities"] == graph_module_out["probabilities"]).all() 142 | 143 | 144 | if __name__ == "__main__": 145 | args, unknown = argparse.ArgumentParser().parse_known_args() 146 | pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) 147 | -------------------------------------------------------------------------------- /test/test_h5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import argparse 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import pytest 10 | import torch 11 | from tensordict import NonTensorData, PersistentTensorDict, TensorDict 12 | from tensordict.base import _is_leaf_nontensor 13 | from tensordict.utils import is_non_tensor 14 | from torch import multiprocessing as mp 15 | from torch.utils._pytree import tree_map 16 | 17 | TIMEOUT = 100 18 | 19 | try: 20 | import h5py 21 | 22 | _has_h5py = True 23 | except ImportError: 24 | _has_h5py = False 25 | 26 | 27 | @pytest.mark.skipif(not _has_h5py, reason="h5py not found.") 28 | class TestH5Serialization: 29 | @classmethod 30 | def worker(cls, cyberbliptronics, q1, q2): 31 | assert isinstance(cyberbliptronics, PersistentTensorDict) 32 | assert cyberbliptronics.file.filename.endswith("groups.hdf5") 33 | q1.put(cyberbliptronics["Base_Group"]["Sub_Group"]) 34 | assert q2.get(timeout=TIMEOUT) == "checked" 35 | val = cyberbliptronics["Base_Group", "Sub_Group", "default"] + 1 36 | q1.put(val) 37 | assert q2.get(timeout=TIMEOUT) == "checked" 38 | q1.close() 39 | q2.close() 40 | 41 | def test_h5_serialization(self, tmp_path): 42 | arr = np.random.randn(1000) 43 | fn = tmp_path / "groups.hdf5" 44 | with h5py.File(fn, "w") as f: 45 | g = f.create_group("Base_Group") 46 | gg = g.create_group("Sub_Group") 47 | 48 | _ = g.create_dataset("default", data=arr) 49 | _ = gg.create_dataset("default", data=arr) 50 | 51 | persistent_td = PersistentTensorDict(filename=fn, batch_size=[]) 52 | q1 = mp.Queue(1) 53 | q2 = mp.Queue(1) 54 | p = mp.Process(target=self.worker, args=(persistent_td, q1, q2)) 55 | p.start() 56 | try: 57 | val = q1.get(timeout=TIMEOUT) 58 | assert (torch.tensor(arr) == val["default"]).all() 59 | q2.put("checked") 60 | val = q1.get(timeout=TIMEOUT) 61 | assert (torch.tensor(arr) + 1 == val).all() 62 | q2.put("checked") 63 | q1.close() 64 | q2.close() 65 | finally: 66 | p.join() 67 | 68 | def test_h5_nontensor(self, tmpdir): 69 | file = Path(tmpdir) / "file.h5" 70 | td = TensorDict( 71 | { 72 | "a": 0, 73 | "b": 1, 74 | "c": "a string!", 75 | ("d", "e"): "another string!", 76 | }, 77 | [], 78 | ) 79 | td = td.expand(10) 80 | h5td = PersistentTensorDict.from_dict(td, filename=file) 81 | assert "c" in h5td.keys(is_leaf=_is_leaf_nontensor) 82 | assert "c" in h5td.keys() 83 | assert "c" in h5td 84 | assert h5td["c"] == b"a string!" 
85 | assert h5td.get("c").batch_size == (10,) 86 | assert ("d", "e") in h5td.keys(True, True, is_leaf=_is_leaf_nontensor) 87 | assert ("d", "e") in h5td 88 | assert h5td["d", "e"] == b"another string!" 89 | assert h5td.get(("d", "e")).batch_size == (10,) 90 | 91 | h5td.set("f", NonTensorData(1, batch_size=[10])) 92 | assert h5td["f"] == 1 93 | h5td.set(("g", "h"), NonTensorData(1, batch_size=[10])) 94 | assert h5td["g", "h"] == 1 95 | 96 | td_recover = h5td.to_tensordict() 97 | assert is_non_tensor(td_recover.get("c")) 98 | assert is_non_tensor(td_recover.get(("d", "e"))) 99 | assert is_non_tensor(td_recover.get("f")) 100 | assert is_non_tensor(td_recover.get(("g", "h"))) 101 | 102 | 103 | def test_auto_batch_size(tmpdir): 104 | tmpdir = Path(tmpdir) 105 | td = TensorDict( 106 | { 107 | "a": torch.arange(12).view((3, 4)), 108 | "b": TensorDict( 109 | { 110 | "c": torch.arange(60).view(3, 4, 5), 111 | "d": "a string!", 112 | }, 113 | batch_size=[3, 4, 5], 114 | ), 115 | "e": "another string!", 116 | }, 117 | batch_size=[3, 4], 118 | ) 119 | td.to_h5(tmpdir / "file.h5") 120 | td_recon = TensorDict.from_h5(tmpdir / "file.h5") 121 | assert td_recon.batch_size == torch.Size([3, 4]) 122 | assert td_recon["b"].batch_size == torch.Size([3, 4, 5]) 123 | 124 | assert (td_recon["a"] == td["a"]).all() 125 | assert (td_recon["b", "c"] == td["b", "c"]).all() 126 | # This breaks because str are loaded as bytes 127 | # assert (td_recon == td).all(), (td == td_recon).to_dict() 128 | 129 | td_dict = td.to_dict() 130 | td_recon_dict = td_recon.to_dict() 131 | 132 | # Checks that all items match 133 | def check(x, y): 134 | if isinstance(x, torch.Tensor): 135 | assert (x == y).all() 136 | return 137 | assert str(x) == y.decode("utf-8") 138 | 139 | tree_map(check, td_dict, td_recon_dict) 140 | 141 | 142 | if __name__ == "__main__": 143 | args, unknown = argparse.ArgumentParser().parse_known_args() 144 | pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) 145 | -------------------------------------------------------------------------------- /tutorials/README.md: -------------------------------------------------------------------------------- 1 | README Tutos 2 | ============ 3 | 4 | Check a rendered version of the tutorials on tensordict doc: https://pytorch.org/tensordict 5 | -------------------------------------------------------------------------------- /tutorials/dummy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | The early bird gets the worm - which is what he deserves 4 | ======================================================== 5 | """ 6 | 7 | ############################################################################## 8 | # Style comes in all shapes and sizes. Therefore, the bigger you are, the more style you have. 
9 | 10 | import tensordict 11 | 12 | td = tensordict.TensorDict({}, [100]) 13 | -------------------------------------------------------------------------------- /tutorials/media/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/tensordict/73fe89bc067b40219dc0d1245b655bcec85464f3/tutorials/media/.gitkeep -------------------------------------------------------------------------------- /tutorials/media/imagenet-benchmark-speed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/tensordict/73fe89bc067b40219dc0d1245b655bcec85464f3/tutorials/media/imagenet-benchmark-speed.png -------------------------------------------------------------------------------- /tutorials/media/imagenet-benchmark-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/tensordict/73fe89bc067b40219dc0d1245b655bcec85464f3/tutorials/media/imagenet-benchmark-time.png -------------------------------------------------------------------------------- /tutorials/media/transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/tensordict/73fe89bc067b40219dc0d1245b655bcec85464f3/tutorials/media/transformer.png -------------------------------------------------------------------------------- /tutorials/sphinx_tuto/tensordict_memory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simplifying PyTorch Memory Management with TensorDict 3 | ===================================================== 4 | **Author**: `Tom Begley `_ 5 | 6 | In this tutorial you will learn how to control where the contents of a 7 | :class:`TensorDict` are stored in memory, either by sending those contents to a device, 8 | or by utilizing memory maps. 9 | """ 10 | 11 | ############################################################################## 12 | # Devices 13 | # ------- 14 | # When you create a :class:`TensorDict`, you can specify a device with the ``device`` 15 | # keyword argument. If the ``device`` is set, then all entries of the 16 | # :class:`TensorDict` will be placed on that device. If the ``device`` is not set, then 17 | # there is no requirement that entries in the :class:`TensorDict` must be on the same 18 | # device. 19 | # 20 | # In this example we instantiate a :class:`TensorDict` with ``device="cuda:0"``. When 21 | # we print the contents we can see that they have been moved onto the device. 22 | # 23 | # .. code-block:: 24 | # 25 | # >>> import torch 26 | # >>> from tensordict import TensorDict 27 | # >>> tensordict = TensorDict({"a": torch.rand(10)}, [10], device="cuda:0") 28 | # >>> print(tensordict) 29 | # TensorDict( 30 | # fields={ 31 | # a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True)}, 32 | # batch_size=torch.Size([10]), 33 | # device=cuda:0, 34 | # is_shared=True) 35 | # 36 | # If the device of the :class:`TensorDict` is not ``None``, new entries are also moved 37 | # onto the device. 38 | # 39 | # .. 
code-block::
40 | #
41 | #    >>> tensordict["b"] = torch.rand(10, 10)
42 | #    >>> print(tensordict)
43 | #    TensorDict(
44 | #        fields={
45 | #            a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
46 | #            b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
47 | #        batch_size=torch.Size([10]),
48 | #        device=cuda:0,
49 | #        is_shared=True)
50 | #
51 | # You can check the current device of the :class:`TensorDict` with the ``device``
52 | # attribute.
53 | #
54 | # .. code-block::
55 | #
56 | #    >>> print(tensordict.device)
57 | #    cuda:0
58 | #
59 | # The contents of the :class:`TensorDict` can be sent to a device like a PyTorch tensor
60 | # with :meth:`TensorDict.cuda() <tensordict.TensorDict.cuda>` or
61 | # :meth:`TensorDict.to(device) <tensordict.TensorDict.to>` with ``device``
62 | # being the desired device.
63 | #
64 | # .. code-block::
65 | #
66 | #    >>> tensordict.to(torch.device("cpu"))
67 | #    >>> print(tensordict)
68 | #    TensorDict(
69 | #        fields={
70 | #            a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
71 | #            b: Tensor(shape=torch.Size([10, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
72 | #        batch_size=torch.Size([10]),
73 | #        device=cpu,
74 | #        is_shared=False)
75 | #    >>> tensordict.cuda()
76 | #    >>> print(tensordict)
77 | #    TensorDict(
78 | #        fields={
79 | #            a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
80 | #            b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
81 | #        batch_size=torch.Size([10]),
82 | #        device=cuda:0,
83 | #        is_shared=True)
84 | #
85 | # The :meth:`TensorDict.to <tensordict.TensorDict.to>` method requires a valid
86 | # device to be passed as the argument. If you want to remove the device from the
87 | # :class:`TensorDict` to allow values with different devices, you should use the
88 | # :meth:`TensorDict.clear_device <tensordict.TensorDict.clear_device>` method.
89 | #
90 | # .. code-block::
91 | #
92 | #    >>> tensordict.clear_device()
93 | #    >>> print(tensordict)
94 | #    TensorDict(
95 | #        fields={
96 | #            a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
97 | #            b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
98 | #        batch_size=torch.Size([10]),
99 | #        device=None,
100 | #        is_shared=False)
101 | #
102 | # Memory-mapped Tensors
103 | # ---------------------
104 | # ``tensordict`` provides a class :class:`~tensordict.MemoryMappedTensor`
105 | # which allows us to store the contents of a tensor on disk, while still
106 | # supporting fast indexing and loading of the contents in batches.
107 | # See the `ImageNet Tutorial <./tensorclass_imagenet.html>`_ for an
108 | # example of this in action.
109 | #
110 | # To convert the :class:`TensorDict` to a collection of memory-mapped tensors, use the
111 | # :meth:`TensorDict.memmap_ <tensordict.TensorDict.memmap_>` method.
112 | 
113 | # sphinx_gallery_start_ignore
114 | import warnings
115 | 
116 | import torch
117 | from tensordict import TensorDict
118 | 
119 | warnings.filterwarnings("ignore")
120 | # sphinx_gallery_end_ignore
121 | 
122 | tensordict = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
123 | tensordict.memmap_()
124 | 
125 | print(tensordict)
126 | 
127 | ##############################################################################
128 | # Alternatively one can use the
129 | # :meth:`TensorDict.memmap_like <tensordict.TensorDict.memmap_like>` method. This will
130 | # create a new :class:`~.TensorDict` of the same structure with
131 | # :class:`~tensordict.MemoryMappedTensor` values; however, it will not copy the
132 | # contents of the original tensors to the
133 | # memory-mapped tensors. This allows you to create the memory-mapped
134 | # :class:`~.TensorDict` and then populate it slowly, and hence it should generally be
135 | # preferred to ``memmap_``.
136 | 
137 | tensordict = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
138 | mm_tensordict = tensordict.memmap_like()
139 | 
140 | print(mm_tensordict["a"].contiguous())
141 | 
142 | ##############################################################################
143 | # By default the contents of the :class:`TensorDict` will be saved to a temporary
144 | # location on disk; however, if you would like to control where they are saved you can
145 | # use the keyword argument ``prefix="/path/to/root"``.
146 | #
147 | # The contents of the :class:`TensorDict` are saved in a directory structure that mimics
148 | # the structure of the :class:`TensorDict` itself. The contents of each tensor are saved
149 | # in a NumPy memmap, and the metadata in an associated PyTorch save file. For example,
150 | # the above :class:`TensorDict` is saved as follows:
151 | #
152 | # ::
153 | #
154 | #     ├── a.memmap
155 | #     ├── a.meta.pt
156 | #     ├── b
157 | #     │   ├── c.memmap
158 | #     │   ├── c.meta.pt
159 | #     │   └── meta.pt
160 | #     └── meta.pt
--------------------------------------------------------------------------------
/tutorials/sphinx_tuto/tensordict_preallocation.py:
--------------------------------------------------------------------------------
1 | """
2 | Pre-allocating memory with TensorDict
3 | =====================================
4 | **Author**: `Tom Begley `_
5 | 
6 | In this tutorial you will learn how to take advantage of memory pre-allocation in
7 | :class:`~.TensorDict`.
8 | """
9 | 
10 | ##############################################################################
11 | # Suppose that we have a function that returns a :class:`~.TensorDict`
12 | 
13 | 
14 | # sphinx_gallery_start_ignore
15 | import warnings
16 | 
17 | warnings.filterwarnings("ignore")
18 | # sphinx_gallery_end_ignore
19 | import torch
20 | from tensordict.tensordict import TensorDict
21 | 
22 | 
23 | def make_tensordict():
24 |     return TensorDict({"a": torch.rand(3), "b": torch.rand(3, 4)}, [3])
25 | 
26 | 
27 | ###############################################################################
28 | # Perhaps we want to call this function multiple times and use the results to populate
29 | # a single :class:`~.TensorDict`.
30 | 
31 | N = 10
32 | tensordict = TensorDict({}, batch_size=[N, 3])
33 | 
34 | for i in range(N):
35 |     tensordict[i] = make_tensordict()
36 | 
37 | print(tensordict)
38 | 
39 | ###############################################################################
40 | # Because we have specified the ``batch_size`` of ``tensordict``, during the first
41 | # iteration of the loop we populate ``tensordict`` with empty tensors whose first
42 | # dimension is size ``N``, and whose remaining dimensions are determined by the return
43 | # value of ``make_tensordict``. In the above example, we pre-allocate an array of zeros
44 | # of size ``torch.Size([10, 3])`` for the key ``"a"``, and an array of size
45 | # ``torch.Size([10, 3, 4])`` for the key ``"b"``. Subsequent iterations of the loop are
46 | # written in place. As a result, if not all values are filled, they get the default
47 | # value of zero.
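# For intuition, the first loop iteration above behaves roughly like this
# explicit pre-allocation (a sketch, not what TensorDict literally executes):
#
# >>> out = TensorDict(
# ...     {"a": torch.zeros(10, 3), "b": torch.zeros(10, 3, 4)},
# ...     batch_size=[10, 3],
# ... )
# >>> out[0] = make_tensordict()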
48 | #
49 | # Let us demonstrate what is going on by stepping through the above loop. We first
50 | # initialise an empty :class:`~.TensorDict`.
51 | 
52 | N = 10
53 | tensordict = TensorDict({}, batch_size=[N, 3])
54 | print(tensordict)
55 | 
56 | ##############################################################################
57 | # After the first iteration, ``tensordict`` has been prepopulated with tensors for both
58 | # ``"a"`` and ``"b"``. These tensors contain zeros except for the first row which we
59 | # have assigned random values to.
60 | 
61 | random_tensordict = make_tensordict()
62 | tensordict[0] = random_tensordict
63 | 
64 | assert (tensordict[1:] == 0).all()
65 | assert (tensordict[0] == random_tensordict).all()
66 | 
67 | print(tensordict)
68 | 
69 | ##############################################################################
70 | # In subsequent iterations, we update the pre-allocated tensors in-place.
71 | 
72 | a = tensordict["a"]
73 | random_tensordict = make_tensordict()
74 | tensordict[1] = random_tensordict
75 | 
76 | # the same tensor is stored under "a", but the values have been updated
77 | assert tensordict["a"] is a
78 | assert (tensordict[:2] != 0).all()
--------------------------------------------------------------------------------
/tutorials/sphinx_tuto/tensordict_slicing.py:
--------------------------------------------------------------------------------
1 | """
2 | Slicing, Indexing, and Masking
3 | ==============================
4 | **Author**: `Tom Begley `_
5 | 
6 | In this tutorial you will learn how to slice, index, and mask a :class:`~.TensorDict`.
7 | """
8 | 
9 | ##############################################################################
10 | # As discussed in the tutorial
11 | # `Manipulating the shape of a TensorDict <./tensordict_shapes.html>`_, when we create a
12 | # :class:`~.TensorDict` we specify a ``batch_size``, which must agree
13 | # with the leading dimensions of all entries in the :class:`~.TensorDict`. Since we have
14 | # a guarantee that all entries share those dimensions in common, we are able to index
15 | # and mask the batch dimensions in the same way that we would index a
16 | # :class:`torch.Tensor`. The indices are applied along the batch dimensions to all of
17 | # the entries in the :class:`~.TensorDict`.
18 | #
19 | # For example, given a :class:`~.TensorDict` with two batch dimensions,
20 | # ``tensordict[0]`` returns a new :class:`~.TensorDict` with the same structure, and
21 | # whose values correspond to the first "row" of each entry in the original
22 | # :class:`~.TensorDict`.
23 | 
24 | import torch
25 | from tensordict import TensorDict
26 | 
27 | tensordict = TensorDict(
28 |     {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
29 | )
30 | 
31 | print(tensordict[0])
32 | 
33 | ##############################################################################
34 | # The same syntax applies as for regular tensors. For example, if we wanted to drop the
35 | # first row of each entry, we could index as follows:
36 | 
37 | print(tensordict[1:])
38 | 
39 | ##############################################################################
40 | # We can index multiple dimensions simultaneously:
41 | 
42 | print(tensordict[:, 2:])
43 | 
44 | ##############################################################################
45 | # We can also use ``Ellipsis`` to represent as many ``:`` as would be needed to make
46 | # the selection tuple the same length as ``tensordict.batch_dims``.
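# For this ``batch_size=[3, 4]`` tensordict the ellipsis below expands to a
# single ``:``, so the two spellings are equivalent (a quick check):
#
# >>> assert (tensordict[..., 2:] == tensordict[:, 2:]).all()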
47 | 
48 | print(tensordict[..., 2:])
49 | 
50 | ##############################################################################
51 | # .. note::
52 | #
53 | #    Remember that all indexing is applied relative to the batch dimensions. In the
54 | #    above example there is a difference between ``tensordict["a"][..., 2:]`` and
55 | #    ``tensordict[..., 2:]["a"]``. The first retrieves the three-dimensional tensor
56 | #    stored under the key ``"a"`` and applies the index ``2:`` to the final dimension.
57 | #    The second applies the index ``2:`` to the final *batch dimension*, which is the
58 | #    second dimension, before retrieving the result.
59 | #
60 | # Setting Values with Indexing
61 | # ----------------------------
62 | # In general, ``tensordict[index] = new_tensordict`` will work as long as the batch
63 | # sizes are compatible.
64 | 
65 | tensordict = TensorDict(
66 |     {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
67 | )
68 | 
69 | td2 = TensorDict({"a": torch.ones(2, 4, 5), "b": torch.ones(2, 4)}, batch_size=[2, 4])
70 | tensordict[:-1] = td2
71 | print(tensordict["a"], tensordict["b"])
72 | 
73 | ##############################################################################
74 | # Masking
75 | # -------
76 | # We mask :class:`TensorDict` as we mask tensors.
77 | 
78 | mask = torch.BoolTensor([[1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]])
79 | tensordict[mask]
80 | 
--------------------------------------------------------------------------------
/tutorials/src/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/tensordict/73fe89bc067b40219dc0d1245b655bcec85464f3/tutorials/src/.gitkeep
--------------------------------------------------------------------------------
/version.txt:
--------------------------------------------------------------------------------
1 | 0.9.0
2 | 
--------------------------------------------------------------------------------