├── .coveragerc ├── .flexci ├── config.pbtxt ├── linux │ ├── Dockerfile │ ├── build_and_push.sh │ ├── download_mnist.sh │ ├── main-flexci.sh │ ├── script.sh │ └── unittest.sh └── windows │ ├── README.md │ ├── _error_handler.ps1 │ ├── _flexci.ps1 │ ├── download_mnist.ps1 │ ├── run.bat │ └── test.ps1 ├── .github ├── CODEOWNERS ├── release-drafter.yml └── workflows │ ├── nightly-test-cpu.yml │ ├── pretest-and-test.yml │ └── publish.yml ├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── Makefile ├── README.md ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── _example │ ├── quick_start_log.py │ ├── quick_start_progress.py │ ├── quick_start_save.py │ └── quick_start_trainer.py │ ├── _templates │ └── autosummary │ │ ├── class.rst │ │ └── module.rst │ ├── conf.py │ ├── index.rst │ ├── reference │ ├── generated │ │ └── .gitignore │ └── index.rst │ └── user_guide │ ├── codeblocks.rst │ ├── config.md │ ├── cuda.md │ ├── extensions.md │ ├── index.rst │ ├── lazy.md │ ├── logic.rst │ ├── onnx.md │ ├── quick_start.rst │ ├── reporting.md │ ├── runtimes.rst │ ├── snapshot.md │ └── trainer.rst ├── example ├── .gitignore ├── cifar10.py ├── cifar10_ddp_trainer.py ├── ignite-mnist.py ├── mnist.py ├── mnist_custom_logic.py ├── mnist_ddp.py └── mnist_trainer.py ├── example_pysen.toml ├── pyproject.toml ├── pytest.ini ├── pytorch_pfn_extras ├── __init__.py ├── _cupy │ ├── __init__.py │ └── _cupy_stub.py ├── _dynamo │ ├── __init__.py │ ├── _compile.py │ ├── _optimizer.py │ └── _splitter.py ├── _tensor.py ├── _torch_version.py ├── _version.py ├── config.py ├── config_types.py ├── cuda │ ├── __init__.py │ └── _allocator.py ├── dataloaders │ ├── __init__.py │ ├── dataloader.py │ └── utils.py ├── dataset │ ├── __init__.py │ ├── shared_dataset.py │ └── tabular │ │ ├── __init__.py │ │ ├── _asmode.py │ │ ├── _concat.py │ │ ├── _join.py │ │ ├── _slice.py │ │ ├── _transform.py │ │ ├── _utils.py │ │ ├── _with_converter.py │ │ ├── delegate_dataset.py │ │ ├── from_data.py │ │ └── tabular_dataset.py ├── distributed │ ├── __init__.py │ ├── _dataset_util.py │ ├── _distributed_validation_sampler.py │ └── _initialize.py ├── engine.py ├── handler │ ├── __init__.py │ ├── _code_block.py │ ├── _handler.py │ └── _logic.py ├── logging.py ├── nn │ ├── __init__.py │ ├── modules │ │ ├── __init__.py │ │ ├── ensure_shape.py │ │ ├── extended_sequential.py │ │ ├── lazy.py │ │ ├── lazy_batchnorm.py │ │ ├── lazy_conv.py │ │ └── lazy_linear.py │ └── parallel │ │ ├── __init__.py │ │ └── distributed.py ├── onnx │ ├── __init__.py │ ├── __init__.pyi │ ├── _as_output.py │ ├── _constants.py │ ├── _globals.py │ ├── _grad.py │ ├── _helper.py │ ├── _lax.py │ ├── annotate.py │ ├── export_testcase.py │ ├── load.py │ ├── pfto_exporter │ │ ├── __init__.py │ │ └── export.py │ ├── strip_large_tensor.py │ ├── symbolic_registry.py │ └── unstrip_tensor.py ├── ops │ ├── __init__.py │ └── register.py ├── profiler │ ├── __init__.py │ ├── _record.py │ ├── _time_summary.py │ ├── _tracing.py │ └── _util.py ├── py.typed ├── reporting.py ├── runtime │ ├── __init__.py │ ├── _autocast.py │ ├── _map.py │ ├── _registry.py │ ├── _runtime.py │ └── _to.py ├── testing.py ├── torchscript.py ├── training │ ├── __init__.py │ ├── _evaluator.py │ ├── _manager_protocol.py │ ├── _trainer.py │ ├── _transform_model.py │ ├── _trigger_util.py │ ├── _util.py │ ├── extension.py │ ├── extensions │ │ ├── __init__.py │ │ ├── _snapshot.py │ │ ├── accumulate │ │ │ ├── __init__.py │ │ │ ├── _accumulate_base.py │ │ │ ├── _accumulate_utils.py │ │ │ ├── _summary │ │ │ │ ├── 
__init__.py │ │ │ │ ├── _average_summary.py │ │ │ │ ├── _base_summary.py │ │ │ │ ├── _max_summary.py │ │ │ │ ├── _min_summary.py │ │ │ │ ├── _standard_deviation_summary.py │ │ │ │ ├── _summary_utils.py │ │ │ │ └── _unbiased_standard_deviation_summary.py │ │ │ ├── average_accumulate.py │ │ │ ├── max_accumulate.py │ │ │ ├── min_accumulate.py │ │ │ ├── standard_deviation_accumulate.py │ │ │ └── unbiased_standard_deviation_accumulate.py │ │ ├── best_value.py │ │ ├── evaluator.py │ │ ├── fail_on_non_number.py │ │ ├── log_report.py │ │ ├── lr_scheduler.py │ │ ├── micro_average.py │ │ ├── parameter_statistics.py │ │ ├── plot_report.py │ │ ├── print_report.py │ │ ├── print_report_notebook.py │ │ ├── profile_report.py │ │ ├── progress_bar.py │ │ ├── progress_bar_notebook.py │ │ ├── slack.py │ │ ├── slack_manifest.yml │ │ ├── snapshot_writers.py │ │ ├── timeline_trace.py │ │ ├── util.py │ │ ├── value_observation.py │ │ └── variable_statistics_plot.py │ ├── manager.py │ ├── metrics.py │ ├── trigger.py │ └── triggers │ │ ├── __init__.py │ │ ├── early_stopping_trigger.py │ │ ├── function_trigger.py │ │ ├── interval_trigger.py │ │ ├── manual_schedule_trigger.py │ │ ├── minmax_value_trigger.py │ │ ├── once_trigger.py │ │ └── time_trigger.py ├── utils │ ├── __init__.py │ ├── checkpoint.py │ └── comparer.py └── writing │ ├── __init__.py │ ├── _parallel_writer.py │ ├── _queue_writer.py │ ├── _simple_writer.py │ ├── _tensorboard_writer.py │ └── _writer_base.py ├── setup.cfg ├── setup.py ├── stubs └── torch │ ├── _C │ └── __init__.pyi │ ├── fx │ ├── __init__.pyi │ ├── graph.pyi │ ├── graph_module.pyi │ ├── node.pyi │ └── proxy.pyi │ └── library │ └── __init__.pyi └── tests ├── conftest.py ├── pytorch_pfn_extras_tests ├── __init__.py ├── cuda_tests │ ├── __init__.py │ └── test_allocator.py ├── dataloader_test │ ├── __init__.py │ └── test_dataloader.py ├── dataset_tests │ ├── __init__.py │ ├── tabular_tests │ │ ├── __init__.py │ │ ├── dummy_dataset.py │ │ ├── test_asmode.py │ │ ├── test_concat.py │ │ ├── test_delegate_dataset.py │ │ ├── test_from_data.py │ │ ├── test_join.py │ │ ├── test_slice.py │ │ ├── test_tabular_dataset.py │ │ ├── test_transform.py │ │ ├── test_with_converter.py │ │ └── test_with_torch_dataloader.py │ └── test_shared_dataset.py ├── distributed_tests │ ├── __init__.py │ ├── test_distributed_subset_indices.py │ └── test_distributed_validation_sampler.py ├── dynamo_tests │ └── test_compile.py ├── handler_tests │ ├── __init__.py │ ├── test_handler.py │ └── test_logic.py ├── nn_tests │ ├── __init__.py │ ├── modules_tests │ │ ├── __init__.py │ │ ├── test_ensure_shape.py │ │ ├── test_extended_sequential.py │ │ ├── test_lazy.py │ │ ├── test_lazy_batchnorm.py │ │ ├── test_lazy_conv.py │ │ └── test_lazy_linear.py │ └── parallel_tests │ │ ├── __init__.py │ │ └── test_distributed.py ├── onnx_tests │ ├── __init__.py │ ├── conftest.py │ ├── test_annotate.py │ ├── test_as_output.py │ ├── test_export.py │ ├── test_export_testcase.py │ ├── test_grad.py │ ├── test_helper.py │ ├── test_lax.py │ ├── test_load_model.py │ ├── test_torchvision.py │ └── utils.py ├── profiler_tests │ ├── __init__.py │ ├── test_record.py │ └── test_time_summary.py ├── runtime_tests │ ├── __init__.py │ ├── test_jit_runtime.py │ ├── test_registry.py │ ├── test_runtime.py │ └── test_to.py ├── test_config.py ├── test_config_types.py ├── test_logging.py ├── test_ops │ └── test_register.py ├── test_reporter.py ├── test_tensor.py ├── test_torchscript.py ├── test_writing.py ├── training_tests │ ├── __init__.py │ ├── extensions_tests │ │ 
├── __init__.py │ │ ├── test_accumulate.py │ │ ├── test_best_value.py │ │ ├── test_distributed_snapshot.py │ │ ├── test_evaluator.py │ │ ├── test_fail_on_non_number.py │ │ ├── test_log_buffer.py │ │ ├── test_log_report.py │ │ ├── test_lr_scheduler.py │ │ ├── test_micro_average.py │ │ ├── test_plot_report.py │ │ ├── test_print_report.py │ │ ├── test_print_report_notebook.py │ │ ├── test_profile_report.py │ │ ├── test_progress_bar.py │ │ ├── test_progress_bar_notebook.py │ │ ├── test_sharded_snapshot.py │ │ ├── test_slack.py │ │ ├── test_snapshot.py │ │ ├── test_snapshot_writers.py │ │ ├── test_timeline_trace.py │ │ ├── test_value_observation.py │ │ └── test_variable_statistics_plot.py │ ├── test_engine.py │ ├── test_evaluator_metrics.py │ ├── test_extension.py │ ├── test_extension_entry.py │ ├── test_manager.py │ ├── test_trainer.py │ ├── test_trigger_util.py │ └── triggers_tests │ │ ├── __init__.py │ │ ├── test_early_stopping_trigger.py │ │ ├── test_function_trigger.py │ │ ├── test_interval_trigger.py │ │ ├── test_minmax_value_trigger.py │ │ ├── test_once_trigger.py │ │ ├── test_schedule_trigger.py │ │ └── test_time_trigger.py └── utils_tests │ ├── __init__.py │ ├── test_checkpoint.py │ ├── test_comparer.py │ └── test_new_comparer.py ├── requirements.mpi.txt └── requirements.txt /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | 3 | exclude_lines = 4 | raise NotImplementedError 5 | pass 6 | -------------------------------------------------------------------------------- /.flexci/linux/Dockerfile: -------------------------------------------------------------------------------- 1 | # FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 2 | ARG base_image 3 | FROM ${base_image} 4 | 5 | # Update GPG repository key. 6 | RUN export DEBIAN_FRONTEND=noninteractive && \ 7 | apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub && \ 8 | apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub 9 | 10 | # Install pyenv requirements. 11 | # https://github.com/pyenv/pyenv/wiki/Common-build-problems#requirements 12 | RUN export DEBIAN_FRONTEND=noninteractive && \ 13 | apt-get -y update && \ 14 | apt-get -y install \ 15 | build-essential libssl-dev zlib1g-dev libbz2-dev \ 16 | libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev \ 17 | xz-utils tk-dev libffi-dev liblzma-dev git cmake protobuf-compiler libprotobuf-dev \ 18 | openmpi-bin openmpi-common libopenmpi-dev && \ 19 | rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* 20 | 21 | # Install pyenv. 22 | RUN git clone https://github.com/pyenv/pyenv.git /opt/pyenv 23 | ENV PYENV_ROOT=/opt/pyenv 24 | ENV PATH ${PYENV_ROOT}/shims:${PYENV_ROOT}/bin:${PATH} 25 | 26 | # Install Python. 27 | ARG python_version 28 | RUN pyenv install ${python_version} && \ 29 | pyenv global ${python_version} 30 | 31 | COPY ./tests/requirements.txt ./tests/requirements.mpi.txt ./ 32 | 33 | # Install test dependencies. 
34 | ARG pip_install_torch_args 35 | ARG pip_install_dep_args 36 | RUN pip install -U pip && \ 37 | pip install -U setuptools && \ 38 | pip install -r requirements.txt -r requirements.mpi.txt ${pip_install_torch_args} && \ 39 | pip install ${pip_install_dep_args} && \ 40 | pip list 41 | -------------------------------------------------------------------------------- /.flexci/linux/download_mnist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -uex 2 | 3 | # Download MNIST dataset for examples. 4 | 5 | set -uex 6 | #curl -LO http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz 7 | #curl -LO http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz 8 | #curl -LO http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz 9 | #curl -LO http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz 10 | mkdir -p mnist_raw 11 | gsutil -m cp -r "gs://chainer-artifacts-pfn-public-ci/pytorch-pfn-extras-assets/mnist/*" mnist_raw 12 | -------------------------------------------------------------------------------- /.flexci/linux/main-flexci.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Bootstrap script for FlexCI. 4 | 5 | set -ue 6 | 7 | echo "Environment Variables:" 8 | env 9 | 10 | pull_req="" 11 | if [[ "${FLEXCI_BRANCH:-}" == refs/pull/* ]]; then 12 | # Extract pull-request ID 13 | pull_req="$(echo "${FLEXCI_BRANCH}" | cut -d/ -f3)" 14 | echo "Testing Pull-Request: #${pull_req}" 15 | fi 16 | 17 | export PPE_FLEXCI_IMAGE_PUSH="0" 18 | if [[ "${pull_req}" == "" ]]; then 19 | # Push images when running on branch. 20 | export PPE_FLEXCI_IMAGE_PUSH="1" 21 | fi 22 | 23 | "$(dirname ${0})/script.sh" "${@}" 24 | -------------------------------------------------------------------------------- /.flexci/linux/unittest.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -uex 4 | 5 | TEST_MODE="${1:-}" 6 | # Install 7 | rm -rf dist 8 | python setup.py sdist 9 | pip install dist/pytorch_pfn_extras-*.tar.gz 10 | 11 | # Show packages 12 | pip list 13 | 14 | if [ "${TEST_MODE}" == "unittest" ]; then 15 | # Run unit tests 16 | pushd tests 17 | JUPYTER_PLATFORM_DIRS=1 \ 18 | python -m pytest --cov-report=html --cov pytorch_pfn_extras . 
19 | popd 20 | 21 | # Publish coverage report 22 | mv tests/htmlcov /output/htmlcov 23 | 24 | elif [ "${TEST_MODE}" == "mpitest" ]; then 25 | # Run unit tests with mpi 26 | make mpitest 27 | 28 | # Run examples 29 | if [ -d mnist_raw ]; then 30 | mkdir -p data/MNIST/raw 31 | mv mnist_raw/* data/MNIST/raw 32 | fi 33 | pushd example 34 | python mnist.py --batch-size 2048 --test-batch-size 2048 --epochs 1 --save-model 35 | python mnist_trainer.py --batch-size 2048 --test-batch-size 2048 --epochs 1 36 | python ignite-mnist.py --batch_size 2048 --val_batch_size 2048 --epochs 1 37 | python cifar10.py --batch-size 2048 --test-batch-size 2048 --epoch 1 --no-autoload 38 | MASTER_ADDR=127.0.0.1 MASTER_PORT=1236 mpirun -n 2 --allow-run-as-root python mnist_ddp.py --batch-size 2048 --test-batch-size 2048 --epochs 1 39 | MASTER_ADDR=127.0.0.1 MASTER_PORT=1236 mpirun -n 2 --allow-run-as-root python cifar10_ddp_trainer.py --batch-size 2048 --test-batch-size 2048 --epoch 1 --no-autoload 40 | popd 41 | 42 | # Trainer 43 | pushd example 44 | python mnist_custom_logic.py --batch-size 2048 --test-batch-size 2048 --epochs 1 45 | popd 46 | 47 | # Comparer 48 | pushd example 49 | mkdir -p comp_dump_cpu 50 | python mnist_trainer.py --device cpu --epochs 2 --batch-size 1024 --deterministic --compare-dump comp_dump_cpu 51 | CUBLAS_WORKSPACE_CONFIG=:4096:8 python mnist_trainer.py --device cuda --epochs 2 --batch-size 1024 --deterministic --compare-with comp_dump_cpu 52 | popd 53 | 54 | # For docs 55 | pushd docs/source/_example/ 56 | python quick_start_trainer.py 57 | python quick_start_log.py 58 | python quick_start_progress.py 59 | python quick_start_save.py 60 | popd 61 | else 62 | echo "Unexpected TEST_MODE: ${TEST_MODE}" 63 | exit 1 64 | fi 65 | -------------------------------------------------------------------------------- /.flexci/windows/README.md: -------------------------------------------------------------------------------- 1 | # Windows Tests 2 | 3 | FlexCI helper scripts are borrowed from CuPy. 4 | 5 | https://github.com/cupy/cupy/tree/master/.pfnci/windows 6 | -------------------------------------------------------------------------------- /.flexci/windows/_error_handler.ps1: -------------------------------------------------------------------------------- 1 | # https://stackoverflow.com/questions/9948517/how-to-stop-a-powershell-script-on-the-first-error 2 | 3 | Set-StrictMode -Version Latest 4 | $ErrorActionPreference = "Stop" 5 | $PSDefaultParameterValues['*:ErrorAction']='Stop' 6 | 7 | function RunOrDie { 8 | $cmd, $params = $args 9 | $params = @($params) 10 | $global:LastExitCode = 0 11 | & $cmd @params 12 | if (-not $?) { 13 | throw "Command failed (exit code = $LastExitCode): $cmd $params" 14 | } 15 | } 16 | 17 | function RunOrDieWithRetry { 18 | $retry, $cmd, $params = $args 19 | for ($i = 1; $i -le $retry; $i++) { 20 | try { 21 | RunOrDie $cmd $params 22 | return 23 | } catch { 24 | $errmsg = $error[0] 25 | Write-Host "RunOrDieWithRetry (attempt ${i}): ${errmsg}" 26 | } 27 | } 28 | throw "No more retry." 29 | } 30 | -------------------------------------------------------------------------------- /.flexci/windows/download_mnist.ps1: -------------------------------------------------------------------------------- 1 | # Download MNIST dataset for examples. 2 | 3 | $ErrorActionPreference = "Stop" 4 | . 
"$PSScriptRoot\_error_handler.ps1" 5 | 6 | New-Item ../data/MNIST/raw -ItemType Directory 7 | Push-Location ../data/MNIST/raw 8 | #curl.exe -LO http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz 9 | #curl.exe -LO http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz 10 | #curl.exe -LO http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz 11 | #curl.exe -LO http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz 12 | gsutil -m cp -r "gs://chainer-artifacts-pfn-public-ci/pytorch-pfn-extras-assets/mnist/*" . 13 | Pop-Location 14 | -------------------------------------------------------------------------------- /.flexci/windows/run.bat: -------------------------------------------------------------------------------- 1 | PowerShell .flexci\windows\test.ps1 %1 2 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @asi1024 @emcastillo @kmaehashi @linshokaku 2 | /pytorch_pfn_extras/onnx @xuzijian629 @take-cheeze @asi1024 @emcastillo @kmaehashi @linshokaku 3 | /tests/pytorch_pfn_extras_tests/onnx_tests @xuzijian629 @take-cheeze @asi1024 @emcastillo @kmaehashi @linshokaku 4 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | template: | 2 | ## Changes 3 | 4 | $CHANGES 5 | -------------------------------------------------------------------------------- /.github/workflows/nightly-test-cpu.yml: -------------------------------------------------------------------------------- 1 | name: "Nightly CPU Tests" 2 | 3 | on: 4 | schedule: 5 | - cron: 0 0 * * * 6 | workflow_dispatch: 7 | inputs: 8 | ref: 9 | description: 'Commit or branch to test (e.g., refs/pull/1234/merge)' 10 | type: string 11 | 12 | jobs: 13 | nightly-test: 14 | runs-on: ubuntu-22.04 15 | env: 16 | JUPYTER_PLATFORM_DIRS: "1" 17 | 18 | steps: 19 | - name: Checkout 20 | uses: actions/checkout@v4 21 | with: 22 | ref: ${{ inputs.ref || github.ref }} 23 | submodules: recursive 24 | 25 | - name: Setup Python 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: '3.9' 29 | 30 | - name: Install 31 | run: | 32 | pip install -U pip wheel 33 | pip install -v -e . 
-r ./tests/requirements.txt "torch>=0.0.0a1" "torchvision>=0.0.0a1" "torchaudio>=0.0.0a1" --extra-index-url https://download.pytorch.org/whl/nightly/cpu 34 | # Test PPE is importable with minimum dependency 35 | python -c 'import pytorch_pfn_extras' 36 | 37 | - name: Test CPU only 38 | run: | 39 | make cputest 40 | -------------------------------------------------------------------------------- /.github/workflows/pretest-and-test.yml: -------------------------------------------------------------------------------- 1 | name: "Pre-review Tests" 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | test: 7 | runs-on: ubuntu-22.04 8 | strategy: 9 | matrix: 10 | torch: ['1.13.*', '2.0.*', '2.1.*', '2.2.*', '2.3.*', '2.4.*', '2.5.*', '2.6.*'] 11 | 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@v2 15 | with: 16 | submodules: recursive 17 | 18 | - name: Setup Python 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: '3.9' 22 | 23 | - uses: actions/cache@v3 24 | with: 25 | path: | 26 | ~/.cache 27 | key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }}-${{ matrix.torch }} 28 | restore-keys: | 29 | ${{ runner.os }}-pip- 30 | 31 | - name: Install 32 | run: | 33 | pip install -U pip wheel 34 | pip install -e . -r ./tests/requirements.txt torch==${{ matrix.torch }} --extra-index-url https://download.pytorch.org/whl/cpu 35 | # Test PPE is importable with minimum dependency 36 | python -c 'import pytorch_pfn_extras' 37 | 38 | - name: Code Style 39 | run: | 40 | make lint 41 | 42 | - name: Code Style (Examples) 43 | run: | 44 | make example_lint 45 | 46 | - name: Test CPU only 47 | run: | 48 | make cputest 49 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: "Publish to PyPI" 2 | on: 3 | release: 4 | types: [published] 5 | workflow_dispatch: 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-22.04 10 | steps: 11 | - name: Git Checkout 12 | uses: actions/checkout@v3 13 | - name: Setup Python 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: '3.11' 17 | - name: Install Dependencies 18 | run: python3 -m pip install build --user 19 | - name: Build Packages 20 | run: python3 -m build --outdir dist/ . 21 | - name: Upload Artifacts 22 | uses: actions/upload-artifact@v4 23 | with: 24 | name: dist 25 | path: dist 26 | 27 | publish: 28 | needs: build 29 | runs-on: ubuntu-22.04 30 | permissions: 31 | id-token: write 32 | steps: 33 | - name: Download Artifacts 34 | uses: actions/download-artifact@v4 35 | with: 36 | name: dist 37 | path: dist 38 | - name: Enumerate Files 39 | run: find dist -ls 40 | - name: Publish to PyPI 41 | uses: pypa/gh-action-pypi-publish@release/v1 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | .idea/ 106 | 107 | # onnx 108 | /out 109 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 2 | 3 | version: 2 4 | formats: all 5 | build: 6 | os: ubuntu-22.04 7 | tools: 8 | python: "3.12" 9 | sphinx: 10 | configuration: docs/source/conf.py 11 | python: 12 | install: 13 | - method: pip 14 | path: . 15 | - requirements: docs/requirements.txt 16 | 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Preferred Networks, Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL := bash 2 | .SHELLFLAGS := -eu -o pipefail -c 3 | MAKEFLAGS += --warn-undefined-variables 4 | .DEFAULT_GOAL := help 5 | 6 | PWD := $(realpath $(dir $(abspath $(firstword $(MAKEFILE_LIST))))) 7 | 8 | PY := python 9 | PIP := $(PY) -m pip 10 | 11 | PROCESS_NUM = 2 12 | MPI_OUTPUT_FILE_DIR = $(realpath $(shell mktemp -d)) 13 | 14 | .PHONY: format 15 | format: ## Format the Python code. 16 | cp "$$($(PIP) show torch | awk '/^Location:/ { print $$2 }')/torch/__init__.py" stubs/torch/__init__.py 17 | trap "rm -f stubs/torch/__init__.py" EXIT; MYPYPATH="$(PWD)/stubs" $(PY) -m pysen run format lint 18 | 19 | .PHONY: lint 20 | lint: ## Lint the Python code. 21 | cp "$$($(PIP) show torch | awk '/^Location:/ { print $$2 }')/torch/__init__.py" stubs/torch/__init__.py 22 | trap "rm -f stubs/torch/__init__.py" EXIT; MYPYPATH="$(PWD)/stubs" $(PY) -m pysen run lint 23 | 24 | .PHONY: test 25 | test: ## Run all tests. 26 | $(PY) -m pytest -m "not mpi" tests 27 | 28 | .PHONY: cputest 29 | cputest: ## Run all tests except for ones requiring GPU. 30 | $(PY) -m pytest -m "not gpu and not mpi" tests 31 | 32 | .PHONY: mpitest 33 | mpitest: ## Run the tests that require MPI. 34 | mpi_output_file_dir=$(MPI_OUTPUT_FILE_DIR); \ 35 | mpirun --allow-run-as-root -n $(PROCESS_NUM) --output-filename $$mpi_output_file_dir -x TORCH_DISTRIBUTED_DEBUG=DETAIL $(PY) -m pytest -m mpi tests > /dev/null 2> /dev/null &&:; \ 36 | ret=$$?; \ 37 | for i in $$(seq 0 $$(($(PROCESS_NUM) - 1))); do echo ========= MPI process $$i =========; cat $$mpi_output_file_dir/1/rank.$$i/stdout; cat $$mpi_output_file_dir/1/rank.$$i/stderr; done; \ 38 | [ $$ret = 0 ] 39 | 40 | .PHONY: example_lint 41 | example_lint: ## Lint the example code. 42 | cp "$$($(PIP) show torch | awk '/^Location:/ { print $$2 }')/torch/__init__.py" stubs/torch/__init__.py 43 | trap "rm -f stubs/torch/__init__.py" EXIT; $(PY) -m pysen --config ./example_pysen.toml run lint 44 | 45 | .PHONY: help 46 | help: ## Display this help message. 47 | @grep -E '^[%%a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | \ 48 | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}' 49 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-pfn-extras 2 | 3 | [![PyPI](https://img.shields.io/pypi/v/pytorch-pfn-extras)](https://pypi.python.org/pypi/pytorch-pfn-extras) 4 | [![Docs](https://img.shields.io/readthedocs/pytorch-pfn-extras)](https://pytorch-pfn-extras.readthedocs.io/) 5 | [![License](https://img.shields.io/github/license/pfnet/pytorch-pfn-extras)](https://github.com/pfnet/pytorch-pfn-extras/blob/master/LICENSE) 6 | 7 | Supplementary components to accelerate research and development in PyTorch.
8 | 9 | ## Installation 10 | 11 | ```sh 12 | pip install pytorch-pfn-extras 13 | 14 | # Use `[onnx]` to use the onnx submodule: 15 | # pip install "pytorch-pfn-extras[onnx]" 16 | 17 | ### Optional dependencies 18 | # For PlotReport / VariableStatisticsPlot extensions 19 | pip install matplotlib 20 | 21 | # For IgniteExtensionsManager 22 | pip install pytorch-ignite torchvision 23 | 24 | # For CuPy interoperability (see: https://docs.cupy.dev/en/stable/install.html) 25 | pip install cupy # or cupy-cudaXXX 26 | ``` 27 | 28 | ## Requirements 29 | 30 | * Python 3.9+ 31 | * PyTorch 1.13+ 32 | 33 | Optional dependencies: 34 | 35 | * CuPy 8.0+ for PyTorch/CuPy interoperability 36 | 37 | ## Documentation 38 | 39 | Refer to [Read The Docs](https://pytorch-pfn-extras.readthedocs.io/) for the complete documentation. 40 | 41 | Below are some quick-links to the most important features of the library. 42 | 43 | * [Extensions Manager](https://pytorch-pfn-extras.readthedocs.io/en/latest/user_guide/extensions.html) 44 | * [Reporting](https://pytorch-pfn-extras.readthedocs.io/en/latest/user_guide/reporting.html) 45 | * [Lazy Modules](https://pytorch-pfn-extras.readthedocs.io/en/latest/user_guide/lazy.html) 46 | * [Distributed Snapshot](https://pytorch-pfn-extras.readthedocs.io/en/latest/user_guide/snapshot.html) 47 | * [Config System](https://pytorch-pfn-extras.readthedocs.io/en/latest/user_guide/config.html) 48 | * [ONNX Utils](https://pytorch-pfn-extras.readthedocs.io/en/latest/user_guide/onnx.html) 49 | * [CUDA Utils (CuPy Interoperability)](https://pytorch-pfn-extras.readthedocs.io/en/latest/user_guide/cuda.html) 50 | 51 | ## Examples 52 | 53 | * [Custom training loop](example/mnist.py) 54 | * [Ignite integration](example/ignite-mnist.py) 55 | 56 | ## Contribution Guide 57 | 58 | You can contribute to this project by sending a pull request. 59 | After approval, the pull request will be merged by the reviewer. 60 | 61 | Before making a contribution, please confirm that: 62 | 63 | - Code quality stays consistent across the script, module or package. 64 | - Code is covered by unit tests. 65 | - API is maintainable. 66 | 67 | ## License 68 | 69 | MIT License 70 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==5.3.0 2 | pydata_sphinx_theme==0.11.0 3 | myst-parser==1.0.0 4 | 5 | # TODO: #204 6 | onnx 7 | -------------------------------------------------------------------------------- /docs/source/_example/quick_start_log.py: -------------------------------------------------------------------------------- 1 | import pytorch_pfn_extras as ppe 2 | import torch 3 | 4 | 5 | class Model(torch.nn.Module): 6 | def __init__(self, *args, **kwargs) -> None: 7 | super().__init__(*args, **kwargs) 8 | self.linear = torch.nn.Linear(in_features=64, out_features=2) 9 | self.criterion = torch.nn.NLLLoss() 10 | 11 | def forward(self, x, target): 12 | y = self.linear.forward(x).log_softmax(dim=1) 13 | loss = self.criterion.forward(y, target) 14 | return {"loss": loss} 15 | 16 | 17 | model = Model() 18 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01) 19 | 20 | device = "cuda:0" 21 | epochs = 3 22 | trainer = ppe.engine.create_trainer( 23 | models=model, 24 | optimizers=optimizer, 25 | max_epochs=epochs, 26 | evaluator=ppe.engine.create_evaluator( 27 | models=model, 28 | device=device, 29 | options={ 30 | "eval_report_keys": [ 31 | "loss" 32 | ], # Let the value of the loss be notified to the LogReport. 33 | }, 34 | ), 35 | device=device, 36 | options={ 37 | "train_report_keys": [ 38 | "loss" 39 | ], # Let the value of the loss be notified to the LogReport. 40 | }, 41 | ) 42 | 43 | trainer.extend( 44 | ppe.training.extensions.LogReport() 45 | ) # It is an extension to collect parameters reported during training. 
46 | 47 | ppe.to(model, device=device) 48 | 49 | batch_size = 10 50 | training_data = [ 51 | { 52 | "x": torch.rand((batch_size, 64)), 53 | "target": torch.ones((batch_size,), dtype=torch.long), 54 | } 55 | for _ in range(10) 56 | ] 57 | validation_data = [ 58 | { 59 | "x": torch.rand((batch_size, 64)), 60 | "target": torch.ones((batch_size,), dtype=torch.long), 61 | } 62 | for _ in range(10) 63 | ] 64 | 65 | trainer.run(train_loader=training_data, val_loader=validation_data) 66 | 67 | print("Finish training!") 68 | -------------------------------------------------------------------------------- /docs/source/_example/quick_start_progress.py: -------------------------------------------------------------------------------- 1 | import pytorch_pfn_extras as ppe 2 | import torch 3 | 4 | 5 | class Model(torch.nn.Module): 6 | def __init__(self, *args, **kwargs) -> None: 7 | super().__init__(*args, **kwargs) 8 | self.linear = torch.nn.Linear(in_features=64, out_features=2) 9 | self.criterion = torch.nn.NLLLoss() 10 | 11 | def forward(self, x, target): 12 | y = self.linear.forward(x).log_softmax(dim=1) 13 | loss = self.criterion.forward(y, target) 14 | return {"loss": loss} 15 | 16 | 17 | model = Model() 18 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01) 19 | 20 | device = "cuda:0" 21 | epochs = 3 22 | trainer = ppe.engine.create_trainer( 23 | models=model, 24 | optimizers=optimizer, 25 | max_epochs=epochs, 26 | evaluator=ppe.engine.create_evaluator( 27 | models=model, 28 | device=device, 29 | options={ 30 | "eval_report_keys": ["loss"], 31 | }, 32 | ), 33 | device=device, 34 | options={ 35 | "train_report_keys": ["loss"], 36 | }, 37 | ) 38 | 39 | trainer.extend(ppe.training.extensions.LogReport()) 40 | trainer.extend(ppe.training.extensions.ProgressBar()) 41 | trainer.extend( 42 | ppe.training.extensions.PrintReport( # Displays the collected logs interactively. 43 | [ 44 | "epoch", # epoch, iteration, elapsed_time are automatically collected by LogReport. 45 | "iteration", 46 | "elapsed_time", 47 | "train/loss", # The parameters specified by train_report_keys are collected under keys with the 'train/' prefix. 48 | "val/loss", # The parameters specified by eval_report_keys are collected under keys with the 'val/' prefix. 
49 | ], 50 | ) 51 | ) 52 | 53 | ppe.to(model, device=device) 54 | 55 | batch_size = 10 56 | training_data = [ 57 | { 58 | "x": torch.rand((batch_size, 64)), 59 | "target": torch.ones((batch_size,), dtype=torch.long), 60 | } 61 | for _ in range(10) 62 | ] 63 | validation_data = [ 64 | { 65 | "x": torch.rand((batch_size, 64)), 66 | "target": torch.ones((batch_size,), dtype=torch.long), 67 | } 68 | for _ in range(10) 69 | ] 70 | 71 | trainer.run(train_loader=training_data, val_loader=validation_data) 72 | 73 | print("Finish training!") 74 | -------------------------------------------------------------------------------- /docs/source/_example/quick_start_save.py: -------------------------------------------------------------------------------- 1 | import pytorch_pfn_extras as ppe 2 | import torch 3 | 4 | 5 | class Model(torch.nn.Module): 6 | def __init__(self, *args, **kwargs) -> None: 7 | super().__init__(*args, **kwargs) 8 | self.linear = torch.nn.Linear(in_features=64, out_features=2) 9 | self.criterion = torch.nn.NLLLoss() 10 | 11 | def forward(self, x, target): 12 | y = self.linear.forward(x).log_softmax(dim=1) 13 | loss = self.criterion.forward(y, target) 14 | return {"loss": loss} 15 | 16 | 17 | model = Model() 18 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01) 19 | 20 | 21 | device = "cuda:0" 22 | epochs = 3 23 | trainer = ppe.engine.create_trainer( 24 | models=model, 25 | optimizers=optimizer, 26 | max_epochs=epochs, 27 | evaluator=ppe.engine.create_evaluator( 28 | models=model, 29 | device=device, 30 | options={ 31 | "eval_report_keys": ["loss"], 32 | }, 33 | ), 34 | device=device, 35 | options={ 36 | "train_report_keys": ["loss"], 37 | }, 38 | ) 39 | 40 | trainer.extend(ppe.training.extensions.LogReport()) 41 | trainer.extend(ppe.training.extensions.ProgressBar()) 42 | trainer.extend( 43 | ppe.training.extensions.PrintReport( # Displays the collected logs interactively. 44 | [ 45 | "epoch", # epoch, iteration, elapsed_time are automatically collected by LogReport. 46 | "iteration", 47 | "elapsed_time", 48 | "train/loss", # The parameters specified by train_report_keys are collected under keys with the 'train/' prefix. 49 | "val/loss", # The parameters specified by eval_report_keys are collected under keys with the 'val/' prefix. 50 | ], 51 | ) 52 | ) 53 | trainer.extend( 54 | ppe.training.extensions.snapshot(target=model) 55 | ) # Save the model parameters after each epoch. 
56 | 57 | ppe.to(model, device=device) 58 | 59 | batch_size = 10 60 | training_data = [ 61 | { 62 | "x": torch.rand((batch_size, 64)), 63 | "target": torch.ones((batch_size,), dtype=torch.long), 64 | } 65 | for _ in range(10) 66 | ] 67 | validation_data = [ 68 | { 69 | "x": torch.rand((batch_size, 64)), 70 | "target": torch.ones((batch_size,), dtype=torch.long), 71 | } 72 | for _ in range(10) 73 | ] 74 | 75 | trainer.run(train_loader=training_data, val_loader=validation_data) 76 | 77 | print("Finish training!") 78 | -------------------------------------------------------------------------------- /docs/source/_example/quick_start_trainer.py: -------------------------------------------------------------------------------- 1 | import pytorch_pfn_extras as ppe 2 | import torch 3 | 4 | 5 | class Model(torch.nn.Module): 6 | def __init__(self, *args, **kwargs) -> None: 7 | super().__init__(*args, **kwargs) 8 | self.linear = torch.nn.Linear(in_features=64, out_features=2) 9 | self.criterion = torch.nn.NLLLoss() 10 | 11 | def forward(self, x, target): 12 | y = self.linear.forward(x).log_softmax(dim=1) 13 | loss = self.criterion.forward(y, target) 14 | return {"loss": loss} 15 | 16 | 17 | model = Model() 18 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01) 19 | 20 | device = ( 21 | "cuda:0" # or any other PyTorch devices ('cpu', etc.) or PPE runtime names 22 | ) 23 | epochs = 3 24 | # Create a trainer with the defined model, optimizer, and other parameters 25 | trainer = ppe.engine.create_trainer( 26 | models=model, 27 | optimizers=optimizer, 28 | max_epochs=epochs, 29 | evaluator=ppe.engine.create_evaluator( 30 | models=model, 31 | device=device, 32 | ), 33 | device=device, 34 | ) 35 | 36 | # Send the model to device(GPU) for computation 37 | ppe.to(model, device=device) 38 | 39 | batch_size = 10 40 | # Create 10 batches of random training data with dimension (batch_size x 64) 41 | training_data = [ 42 | { 43 | "x": torch.rand((batch_size, 64)), 44 | "target": torch.ones((batch_size,), dtype=torch.long), 45 | } 46 | for _ in range(10) 47 | ] 48 | # Create 10 batches of random validation data with dimension (batch_size x 64) 49 | validation_data = [ 50 | { 51 | "x": torch.rand((batch_size, 64)), 52 | "target": torch.ones((batch_size,), dtype=torch.long), 53 | } 54 | for _ in range(10) 55 | ] 56 | 57 | # Start the training and validation of the model 58 | trainer.run(train_loader=training_data, val_loader=validation_data) 59 | 60 | print("Finish training!") 61 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :members: 7 | :special-members: __init__, __call__ 8 | :undoc-members: 9 | :show-inheritance: 10 | 11 | {% block methods %} 12 | {% if methods %} 13 | .. rubric:: {{ _('Methods') }} 14 | 15 | .. autosummary:: 16 | {% for item in methods %} 17 | ~{{ name }}.{{ item }} 18 | {%- endfor %} 19 | {% endif %} 20 | {% endblock %} 21 | 22 | {% block attributes %} 23 | {% if attributes %} 24 | .. rubric:: {{ _('Attributes') }} 25 | 26 | .. 
autosummary:: 27 | {% for item in attributes %} 28 | ~{{ name }}.{{ item }} 29 | {%- endfor %} 30 | {% endif %} 31 | {% endblock %} -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/module.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. automodule:: {{ fullname }} 4 | 5 | {% block attributes %} 6 | {% if attributes %} 7 | .. rubric:: {{ _('Module Attributes') }} 8 | 9 | .. autosummary:: 10 | :toctree: . 11 | 12 | {% for item in attributes %} 13 | {{ fullname }}.{{ item }} 14 | {%- endfor %} 15 | {% endif %} 16 | {% endblock %} 17 | 18 | {% block functions %} 19 | {% if functions %} 20 | .. rubric:: {{ _('Functions') }} 21 | 22 | .. autosummary:: 23 | :toctree: . 24 | 25 | {% for item in functions %} 26 | {{ fullname }}.{{ item }} 27 | {%- endfor %} 28 | {% endif %} 29 | {% endblock %} 30 | 31 | {% block classes %} 32 | {% if classes %} 33 | .. rubric:: {{ _('Classes') }} 34 | 35 | .. autosummary:: 36 | :toctree: . 37 | 38 | {% for item in classes %} 39 | {{ fullname }}.{{ item }} 40 | {%- endfor %} 41 | {% endif %} 42 | {% endblock %} 43 | 44 | {% block exceptions %} 45 | {% if exceptions %} 46 | .. rubric:: {{ _('Exceptions') }} 47 | 48 | .. autosummary:: 49 | :toctree: . 50 | 51 | {% for item in exceptions %} 52 | {{ fullname }}.{{ item }} 53 | {%- endfor %} 54 | {% endif %} 55 | {% endblock %} 56 | 57 | {% block modules %} 58 | {% if modules %} 59 | .. rubric:: Modules 60 | 61 | .. autosummary:: 62 | :toctree: . 63 | :recursive: 64 | {% for item in modules %} 65 | {{ item }} 66 | {%- endfor %} 67 | {% endif %} 68 | {% endblock %} -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "pytorch-pfn-extras" 21 | copyright = "2021, Preferred Networks, Inc." 22 | author = "Preferred Networks, Inc." 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | extensions = [ 31 | "sphinx.ext.autosummary", 32 | "sphinx.ext.napoleon", 33 | "myst_parser", 34 | ] 35 | 36 | autodoc_typehints = "description" 37 | 38 | # Add any paths that contain templates here, relative to this directory. 39 | templates_path = ["_templates"] 40 | 41 | # List of patterns, relative to source directory, that match files and 42 | # directories to ignore when looking for source files. 43 | # This pattern also affects html_static_path and html_extra_path. 
44 | exclude_patterns = [] 45 | 46 | # Autosummary 47 | autosummary_generate = True 48 | autosummary_imported_members = True 49 | autoclass_content = "both" 50 | 51 | # -- Options for HTML output ------------------------------------------------- 52 | 53 | # The theme to use for HTML and HTML Help pages. See the documentation for 54 | # a list of builtin themes. 55 | # 56 | html_theme = "pydata_sphinx_theme" 57 | 58 | # Add any paths that contain custom static files (such as style sheets) here, 59 | # relative to this directory. They are copied after the builtin static files, 60 | # so a file named "default.css" will overwrite the builtin "default.css". 61 | html_static_path = ["_static"] 62 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | pytorch-pfn-extras 2 | ================== 3 | 4 | **pytorch-pfn-extras** (PPE) is a collection of supplementary components to accelerate research and development in PyTorch. 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | user_guide/index 10 | reference/index 11 | -------------------------------------------------------------------------------- /docs/source/reference/generated/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | -------------------------------------------------------------------------------- /docs/source/user_guide/codeblocks.rst: -------------------------------------------------------------------------------- 1 | CodeBlocks for Abstracting Logic Steps 2 | ======================================== 3 | 4 | The :class:`ppe.handler.CodeBlock <pytorch_pfn_extras.handler.CodeBlock>` API 5 | provides a means of abstracting the actions that can be performed on a model 6 | in a device-agnostic way. 7 | 8 | Currently there is support for two different actions using ``CodeBlock``. 9 | 10 | - :func:`ppe.handler.update_parameters <pytorch_pfn_extras.handler.update_parameters>` 11 | takes a model and an optimizer and returns a ``CodeBlock`` object that performs the forward, backward and optimizer step at once. 12 | 13 | - :func:`ppe.handler.forward <pytorch_pfn_extras.handler.forward>` 14 | takes a model and returns a ``CodeBlock`` object that performs only the forward pass. 15 | 16 | Executing CodeBlocks 17 | ------------------------------- 18 | 19 | For executing ``CodeBlock`` objects we need to call the :meth:`ppe.runtime.BaseRuntime.execute <pytorch_pfn_extras.runtime.BaseRuntime.execute>` method of the ``Runtime`` associated with the device where the code will run. -------------------------------------------------------------------------------- /docs/source/user_guide/logic.rst: -------------------------------------------------------------------------------- 1 | Logic for Custom Training and Evaluation 2 | ======================================== 3 | 4 | In the training and evaluation engines, the :class:`ppe.handler.BaseLogic <pytorch_pfn_extras.handler.BaseLogic>` API is in charge of abstracting the algorithmic details of the training and evaluation loops. 5 | 6 | A logic is an object that defines multiple callbacks used 7 | throughout the training and evaluation processes. 8 | With logic, we can implement training of complex models such as GANs. 9 | 10 | Users wanting to define their own logic for training can inherit from 11 | :class:`ppe.handler.Logic <pytorch_pfn_extras.handler.Logic>`, which implements the training and evaluation steps to train 12 | a single module (see the sketch below). 13 | 14 | Logic methods are not expected to be called directly by the user. 15 | They will be invoked by the Trainer and Evaluator engines. 16 | 17 | Default Logic (:class:`ppe.handler.Logic <pytorch_pfn_extras.handler.Logic>`) 18 | ------------------------------------------------------------------------------------------ 19 | 20 | PPE provides a default logic that performs the forward/backward/optimizer loop 21 | for a single model. This logic allows using torch features such as AMP autocast 22 | and ``GradScaler``, and performs the backward pass on the outputs specified by the 23 | ``backward_outputs`` config option.
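Below is a minimal sketch of such a subclass for reference (the ``ClipGradLogic`` name and the clipping threshold are illustrative; the exact hook signatures should be checked against the installed version of :class:`ppe.handler.Logic <pytorch_pfn_extras.handler.Logic>`):

.. code-block:: python

    import torch
    import pytorch_pfn_extras as ppe


    class ClipGradLogic(ppe.handler.Logic):
        """Example logic that clips gradients before the optimizer step."""

        def train_step(self, models, optimizers, batch_idx, batch):
            # Mirrors the default flow: zero_grad, forward, backward.
            # The optimizer step itself is issued by the engine afterwards
            # (via ``train_step_optimizers`` in the default handler flow).
            optimizers[self.model_name].zero_grad()
            outs = models[self.model_name](**batch)
            outs["loss"].backward()
            # Custom step: clip gradients before the parameters are updated.
            torch.nn.utils.clip_grad_norm_(
                models[self.model_name].parameters(), max_norm=1.0
            )
            return outs


    # The custom logic is then passed to the engine at creation time:
    # trainer = ppe.engine.create_trainer(..., logic=ClipGradLogic())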
24 | 25 | CodeBlock Logic (:class:`ppe.handler.CodeBlockLogic <pytorch_pfn_extras.handler.CodeBlockLogic>`) 26 | ------------------------------------------------------------------------------------------ 27 | 28 | With the CodeBlock API, we provide a basic logic that uses it to perform the training 29 | of a single model. As with the default logic, AMP features are supported, but 30 | by means of the Runtime. For more information, check the CodeBlock documentation. 31 | -------------------------------------------------------------------------------- /docs/source/user_guide/quick_start.rst: -------------------------------------------------------------------------------- 1 | Quick Start 2 | =========== 3 | 4 | First, pytorch-pfn-extras organizes training code 5 | written in PyTorch around the Trainer/Evaluator classes. 6 | 7 | Next, it provides the following interfaces for training PyTorch models. 8 | 9 | 1. Addition of extensions for analysis and visualization 10 | 2. Runtime changes 11 | 3. Addition of custom training steps 12 | 4. Custom data handling 13 | 14 | Step 1: Use Trainer 15 | ------------------- 16 | 17 | First, pass the model and optimizer you want to train to the Trainer. 18 | 19 | .. literalinclude:: /_example/quick_start_trainer.py 20 | :language: python 21 | :caption: quick_start_trainer.py 22 | 23 | 24 | Step 2: Get Log 25 | --------------- 26 | 27 | Next, collect the logs of the training progress. 28 | 29 | .. literalinclude:: /_example/quick_start_log.py 30 | :language: python 31 | :caption: quick_start_log.py 32 | 33 | The collected training logs are output to ``./result/log``. 34 | 35 | Step 3: Display Progress 36 | --------------------------- 37 | 38 | Next, display the training progress so that it can be monitored. 39 | 40 | .. literalinclude:: /_example/quick_start_progress.py 41 | :language: python 42 | :caption: quick_start_progress.py 43 | 44 | Step 4: Save Model 45 | ------------------ 46 | 47 | Finally, save the trained model. 48 | 49 | .. literalinclude:: /_example/quick_start_save.py 50 | :language: python 51 | :caption: quick_start_save.py 52 | 53 | The model parameters are stored under ``./result`` with a file name that includes the time they were saved. 54 | 55 | Snapshots are generated using ``state_dict()``. Please refer to the official PyTorch `docs <https://pytorch.org/tutorials/beginner/saving_loading_models.html>`_ for how to load the model. 56 | -------------------------------------------------------------------------------- /docs/source/user_guide/reporting.md: -------------------------------------------------------------------------------- 1 | # Reporting 2 | 3 | `reporting.Reporter` is used to collect values that users want to watch. 4 | The reporter object holds a mapping from value names to the observed values. We call this mapping the observations. 5 | 6 | When a value is passed to the reporter, an object called an observer can optionally be attached. In this case, the name of the observer is added as a prefix of the value name. The observer name should be registered beforehand.
7 | 8 | ```python 9 | import pytorch_pfn_extras as ppe 10 | 11 | reporter = ppe.reporting.Reporter() 12 | observer = object() 13 | reporter.add_observer('my_observer', observer) 14 | observation = {} 15 | 16 | with reporter.scope(observation): 17 | reporter.report({'x': 1}, observer) 18 | 19 | print(observation) 20 | # outputs: {'my_observer/x': 1} 21 | ``` 22 | 23 | There is also a global API to add values: 24 | 25 | ```python 26 | import pytorch_pfn_extras as ppe 27 | 28 | reporter = ppe.reporting.Reporter() 29 | observer = object() 30 | reporter.add_observer('my_observer', observer) 31 | 32 | observation = {} 33 | with reporter: 34 | with ppe.reporting.report_scope(observation): 35 | ppe.reporting.report({'x': 1}, observer) 36 | 37 | print(observation) 38 | # outputs: {'my_observer/x': 1} 39 | ``` 40 | 41 | The most important application of Reporter is to report observed values from different parts of the model in the training 42 | and validation procedures. `ExtensionsManager` objects hold their own `Reporter` object with the parameters of the target 43 | module registered as observers. `report()` can be used inside the modules to report the observed values (e.g., training loss, 44 | accuracy, activation statistics, etc.). 45 | -------------------------------------------------------------------------------- /docs/source/user_guide/snapshot.md: -------------------------------------------------------------------------------- 1 | # Distributed Snapshot 2 | 3 | To take snapshots when using `torch.distributed` the only needed step is to 4 | provide the `saver_rank` keyword argument to the regular snapshot extension. 5 | 6 | ```python 7 | # saver_rank is the MPI rank which will write the actual snapshot. 8 | snapshot = extensions.snapshot(saver_rank=saver_rank) 9 | ``` 10 | 11 | To resume the training, snapshots are loaded in every worker by using the 12 | `ExtensionsManager.load_state_dict` method, or the `extensions.snapshot` 13 | `autoload` keyword argument. 
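As a minimal sketch of how these pieces fit together (`saver_rank=0` is an illustrative choice, and `model`, `optimizer`, `max_epochs`, and `iters_per_epoch` are assumed to be defined as usual):

```python
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training import extensions

manager = ppe.training.ExtensionsManager(
    model, optimizer, max_epochs, iters_per_epoch=iters_per_epoch
)

# Rank 0 writes the snapshot file; with autoload=True, each worker
# restores the latest snapshot found in the output directory at startup.
manager.extend(extensions.snapshot(saver_rank=0, autoload=True))
```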
14 | -------------------------------------------------------------------------------- /example/.gitignore: -------------------------------------------------------------------------------- 1 | result/ 2 | -------------------------------------------------------------------------------- /example_pysen.toml: -------------------------------------------------------------------------------- 1 | [tool.pysen] 2 | version = "0.10.1" 3 | 4 | [tool.pysen.lint] 5 | enable_black = false 6 | enable_flake8 = false 7 | enable_isort = false 8 | enable_mypy = true 9 | mypy_preset = "entry" 10 | line_length = 80 11 | py_version = "py38" 12 | mypy_path = ["./stubs"] 13 | 14 | [tool.pysen.lint.mypy_modules."torch.*"] 15 | ignore_errors = true 16 | 17 | [[tool.pysen.lint.mypy_targets]] 18 | paths = ["./example"] 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | 4 | [tool.pysen] 5 | version = "0.11.0" 6 | 7 | [tool.pysen.lint] 8 | enable_black = true 9 | enable_flake8 = true 10 | enable_isort = true 11 | enable_mypy = true 12 | mypy_preset = "strict" 13 | line_length = 80 14 | py_version = "py39" 15 | mypy_path = ["./stubs"] 16 | 17 | [[tool.pysen.lint.mypy_targets]] 18 | paths = ["pytorch_pfn_extras"] 19 | 20 | [tool.pysen.lint.mypy_modules."torch.*"] 21 | ignore_errors = true 22 | 23 | [tool.pysen.lint.source] 24 | includes = ["."] 25 | excludes = ["pytorch_pfn_extras/onnx/", "tests/pytorch_pfn_extras_tests/onnx_tests/"] 26 | 27 | [tool.black] # automatically generated by pysen 28 | # pysen ignores and overwrites any modifications 29 | line-length = 80 30 | target-version = ["py39"] 31 | 32 | [tool.isort] # automatically generated by pysen 33 | # pysen ignores and overwrites any modifications 34 | default_section = "THIRDPARTY" 35 | ensure_newline_before_comments = true 36 | force_grid_wrap = 0 37 | force_single_line = false 38 | include_trailing_comma = true 39 | line_length = 80 40 | multi_line_output = 3 41 | use_parentheses = true 42 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/__init__.py: -------------------------------------------------------------------------------- 1 | # Configure the logging before instantiating anything else 2 | from pytorch_pfn_extras import logging # NOQA 3 | 4 | logging._configure_logging() 5 | 6 | from pytorch_pfn_extras import config # NOQA 7 | from pytorch_pfn_extras import cuda # NOQA 8 | from pytorch_pfn_extras import dataloaders # NOQA 9 | from pytorch_pfn_extras import dataset # NOQA 10 | from pytorch_pfn_extras import distributed # NOQA 11 | from pytorch_pfn_extras import engine # NOQA 12 | from pytorch_pfn_extras import handler # NOQA 13 | from pytorch_pfn_extras import nn # NOQA 14 | from pytorch_pfn_extras import profiler # NOQA 15 | from pytorch_pfn_extras import reporting # NOQA 16 | from pytorch_pfn_extras import runtime # NOQA 17 | from pytorch_pfn_extras import training # NOQA 18 | from pytorch_pfn_extras import utils # NOQA 19 | from pytorch_pfn_extras import writing # NOQA 20 | from pytorch_pfn_extras._tensor import as_ndarray # NOQA 21 | from pytorch_pfn_extras._tensor import as_numpy_dtype # NOQA 22 | from pytorch_pfn_extras._tensor import from_ndarray # NOQA 23 | from pytorch_pfn_extras._tensor import from_numpy_dtype # NOQA 24 | from pytorch_pfn_extras._tensor import get_xp # NOQA 25 | from 
pytorch_pfn_extras._torch_version import requires # NOQA 26 | from pytorch_pfn_extras._version import __version__ # NOQA 27 | from pytorch_pfn_extras.runtime._map import map # NOQA 28 | from pytorch_pfn_extras.runtime._to import to # NOQA 29 | 30 | if requires("2.0.0"): 31 | from pytorch_pfn_extras import ops # NOQA 32 | from pytorch_pfn_extras._dynamo import compile # NOQA 33 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/_cupy/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | import cupy # NOQA 3 | 4 | _cupy_import_error = None 5 | except Exception as e: 6 | from pytorch_pfn_extras._cupy import _cupy_stub as cupy # NOQA 7 | 8 | _cupy_import_error = e 9 | 10 | 11 | def ensure_cupy() -> None: 12 | if _cupy_import_error is not None: 13 | raise RuntimeError( 14 | f"CuPy is not available. Reason:\n{_cupy_import_error}" 15 | ) 16 | 17 | 18 | def is_available() -> bool: 19 | return _cupy_import_error is None 20 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/_cupy/_cupy_stub.py: -------------------------------------------------------------------------------- 1 | class ndarray: 2 | pass 3 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/_dynamo/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras._dynamo._compile import compile # NOQA 2 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/_torch_version.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | 3 | from packaging.version import Version 4 | 5 | 6 | def requires(version: str, package: str = "torch") -> bool: 7 | pkg_ver = importlib.metadata.version(package) 8 | return Version(pkg_ver.split("+")[0].split("-")[0]) >= Version(version) 9 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.8.2" 2 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/config_types.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import TYPE_CHECKING, Any, Callable, Dict, Optional 3 | 4 | from pytorch_pfn_extras import config 5 | 6 | if TYPE_CHECKING: 7 | import optuna 8 | 9 | 10 | def optuna_types(trial: "optuna.trial.Trial") -> Dict[str, Any]: 11 | types = { 12 | "optuna_suggest_categorical": trial.suggest_categorical, 13 | "optuna_suggest_discrete_uniform": trial.suggest_discrete_uniform, 14 | "optuna_suggest_float": trial.suggest_float, 15 | "optuna_suggest_int": trial.suggest_int, 16 | "optuna_suggest_loguniform": trial.suggest_loguniform, 17 | "optuna_suggest_uniform": trial.suggest_uniform, 18 | } 19 | return types 20 | 21 | 22 | def load_path_with_optuna_types( 23 | path: str, 24 | trial: "optuna.trial.Trial", 25 | loader: Optional[config.Loader] = None, 26 | types: Optional[Dict[str, Callable[..., Any]]] = None, 27 | ) -> config.Config: 28 | if types is None: 29 | types = {} 30 | for key, value in optuna_types(trial).items(): 31 | if key in types: 32 | warnings.warn(key + " is overwritten by optuna suggest.") 33 | types[key] = value 34 | return 
config.Config.load_path(path, loader=loader, types=types) 35 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/cuda/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.cuda._allocator import stream # NOQA 2 | from pytorch_pfn_extras.cuda._allocator import use_torch_mempool_in_cupy # NOQA 3 | from pytorch_pfn_extras.cuda._allocator import ( # NOQA 4 | use_default_mempool_in_cupy, 5 | ) 6 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/cuda/_allocator.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from typing import Any, Generator, Optional 3 | 4 | import torch 5 | from pytorch_pfn_extras._cupy import cupy, ensure_cupy, is_available 6 | 7 | _allocator = None 8 | 9 | 10 | @contextlib.contextmanager 11 | def stream(stream: Optional[torch.cuda.Stream]) -> Generator[None, None, None]: 12 | """Context-manager that selects a given stream. 13 | 14 | This context manager also changes CuPy's default stream if CuPy 15 | is available. When CuPy is not available, the functionality is the same 16 | as PyTorch's counterpart, `torch.cuda.stream()`. 17 | """ 18 | 19 | if stream is None: 20 | yield 21 | return 22 | 23 | with torch.cuda.stream(stream): 24 | if is_available(): 25 | cupy_stream = cupy.cuda.ExternalStream( 26 | stream.cuda_stream, device_id=stream.device.index 27 | ) 28 | with cupy_stream: 29 | yield 30 | else: 31 | yield 32 | 33 | 34 | def use_default_mempool_in_cupy() -> None: 35 | """Use the default memory pool in CuPy.""" 36 | ensure_cupy() 37 | cupy.cuda.set_allocator(cupy.get_default_memory_pool().malloc) 38 | 39 | 40 | def use_torch_mempool_in_cupy() -> None: 41 | """Use the PyTorch memory pool in CuPy. 42 | 43 | If you want to use PyTorch's memory pool and non-default CUDA streams, 44 | streams must be created and managed using PyTorch (using 45 | `torch.cuda.Stream()` and `pytorch_pfn_extras.cuda.stream(stream)`). 46 | """ 47 | global _allocator 48 | 49 | ensure_cupy() 50 | _allocator = cupy.cuda.memory.PythonFunctionAllocator( 51 | _torch_alloc, _torch_free 52 | ) 53 | cupy.cuda.set_allocator(_allocator.malloc) 54 | 55 | 56 | def _torch_alloc(size: int, device_id: int) -> Any: 57 | torch_stream_ptr = torch.cuda.current_stream().cuda_stream 58 | cupy_stream_ptr = cupy.cuda.get_current_stream().ptr 59 | if torch_stream_ptr != cupy_stream_ptr: 60 | raise RuntimeError( 61 | "The current stream set in PyTorch and CuPy must be the same." 62 | " Use `pytorch_pfn_extras.cuda.stream` instead of" 63 | " `torch.cuda.stream`."
64 | ) 65 | return torch.cuda.caching_allocator_alloc(size, device_id, torch_stream_ptr) 66 | 67 | 68 | def _torch_free(mem_ptr: int, device_id: int) -> None: 69 | torch.cuda.caching_allocator_delete(mem_ptr) # type: ignore[no-untyped-call] 70 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.dataloaders import utils # NOQA 2 | from pytorch_pfn_extras.dataloaders.dataloader import DataLoader # NOQA 3 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/dataloaders/dataloader.py: -------------------------------------------------------------------------------- 1 | # Kept for backward compatibility 2 | from torch.utils.data.dataloader import * # NOQA 3 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Sequence 2 | 3 | import torch 4 | 5 | 6 | class CollateAsDict: 7 | """Creates a collate function that converts inputs to a dict of tensors. 8 | 9 | An instantiated callable object can be fed to 10 | :class:`torch.utils.data.DataLoader` as a ``collate_fn`` option. 11 | 12 | Args: 13 | names (list of str): Names of the keys of the output dict. 14 | collate_fn (function): A function that preprocesses the inputs. 15 | """ 16 | 17 | def __init__( 18 | self, 19 | names: Sequence[str], 20 | collate_fn: Callable[ 21 | ..., Any 22 | ] = torch.utils.data._utils.collate.default_collate, 23 | ) -> None: 24 | self.names = names 25 | self.collate_fn = collate_fn 26 | 27 | def __call__(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: 28 | """Converts the inputs generated by the dataset into a dictionary of tensors. 29 | 30 | Returns (dict of Tensor): 31 | A dictionary whose keys are those specified by the ``names`` option, 32 | and whose values are the input tensors.
33 | """ 34 | batch = self.collate_fn(*args, **kwargs) 35 | return {name: v for name, v in zip(self.names, batch)} 36 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.dataset.shared_dataset import SharedDataset # NOQA 2 | from pytorch_pfn_extras.dataset.shared_dataset import ( # NOQA 3 | ItemNotFoundException, 4 | ) 5 | from pytorch_pfn_extras.dataset.tabular.tabular_dataset import ( # NOQA 6 | TabularDataset, 7 | ) 8 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/dataset/shared_dataset.py: -------------------------------------------------------------------------------- 1 | # mypy: ignore-errors 2 | 3 | import ctypes 4 | import multiprocessing 5 | 6 | import numpy 7 | import torch 8 | 9 | 10 | class Cache: 11 | def is_cached(self, idx): 12 | raise NotImplementedError 13 | 14 | def add_to_cache(self, idx, x): 15 | raise NotImplementedError 16 | 17 | def get_value(self, idx): 18 | raise NotImplementedError 19 | 20 | 21 | class InfiniteCache(Cache): 22 | def __init__(self, sm_size): 23 | super().__init__() 24 | self.sm_size = sm_size 25 | total_size = 1 26 | for x in sm_size: 27 | total_size *= x 28 | shared_memory = multiprocessing.Array(ctypes.c_float, total_size) 29 | storage = numpy.ctypeslib.as_array(shared_memory.get_obj()) 30 | self.storage = storage.reshape(sm_size) 31 | # This requires a continuous data loader for the cached values not 32 | # to be lost 33 | cached_ids = multiprocessing.Array(ctypes.c_bool, sm_size[0]) 34 | self.cached_ids = numpy.ctypeslib.as_array(cached_ids.get_obj()) 35 | 36 | def is_cached(self, idx): 37 | return self.cached_ids[idx] == 1 38 | 39 | def get_value(self, idx): 40 | x = None 41 | if self.is_cached(idx): 42 | x = self.storage[idx] 43 | return x 44 | 45 | def add_to_cache(self, idx, x): 46 | if isinstance(x, torch.Tensor): 47 | x = x.detach().cpu().numpy() 48 | self.storage[idx] = x 49 | self.cached_ids[idx] = 1 50 | 51 | 52 | class ItemNotFoundException(Exception): 53 | pass 54 | 55 | 56 | class SharedDataset(torch.utils.data.Dataset): 57 | """Dataset that caches the load samples in shared memory 58 | 59 | Args 60 | """ 61 | 62 | def __init__(self, sm_size, cache_type=InfiniteCache): 63 | super().__init__() 64 | self.cache = cache_type(sm_size) 65 | 66 | def __getitem__(self, idx): 67 | x = self.cache.get_value(idx) 68 | if x is None: 69 | raise ItemNotFoundException( 70 | "Item {} is not in the cache".format(idx) 71 | ) 72 | return x 73 | 74 | def is_cached(self, idx): 75 | return self.cache.is_cached(idx) 76 | 77 | def cache_item(self, idx, x): 78 | self.cache.add_to_cache(idx, x) 79 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/dataset/tabular/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.dataset.tabular import _asmode # NOQA 2 | from pytorch_pfn_extras.dataset.tabular import _concat # NOQA 3 | from pytorch_pfn_extras.dataset.tabular import _join # NOQA 4 | from pytorch_pfn_extras.dataset.tabular import _slice # NOQA 5 | from pytorch_pfn_extras.dataset.tabular import _transform # NOQA 6 | from pytorch_pfn_extras.dataset.tabular import _with_converter # NOQA 7 | from pytorch_pfn_extras.dataset.tabular.delegate_dataset import ( # NOQA 8 | DelegateDataset, 9 | ) 10 | from 
pytorch_pfn_extras.dataset.tabular.from_data import from_data # NOQA 11 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/dataset/tabular/_asmode.py: -------------------------------------------------------------------------------- 1 | # mypy: ignore-errors 2 | 3 | from pytorch_pfn_extras.dataset.tabular import tabular_dataset 4 | 5 | 6 | class _Astuple(tabular_dataset.TabularDataset): 7 | def __init__(self, dataset): 8 | self._dataset = dataset 9 | 10 | def __len__(self): 11 | return len(self._dataset) 12 | 13 | @property 14 | def keys(self): 15 | return self._dataset.keys 16 | 17 | @property 18 | def mode(self): 19 | return tuple 20 | 21 | def get_examples(self, indices, key_indices): 22 | return self._dataset.get_examples(indices, key_indices) 23 | 24 | def convert(self, data): 25 | return self._dataset.convert(data) 26 | 27 | 28 | class _Asdict(tabular_dataset.TabularDataset): 29 | def __init__(self, dataset): 30 | self._dataset = dataset 31 | 32 | def __len__(self): 33 | return len(self._dataset) 34 | 35 | @property 36 | def keys(self): 37 | return self._dataset.keys 38 | 39 | @property 40 | def mode(self): 41 | return dict 42 | 43 | def get_examples(self, indices, key_indices): 44 | return self._dataset.get_examples(indices, key_indices) 45 | 46 | def convert(self, data): 47 | return self._dataset.convert(data) 48 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/dataset/tabular/_join.py: -------------------------------------------------------------------------------- 1 | # mypy: ignore-errors 2 | 3 | from pytorch_pfn_extras.dataset.tabular import tabular_dataset 4 | 5 | 6 | class _Join(tabular_dataset.TabularDataset): 7 | def __init__(self, *datasets): 8 | keys = set(datasets[0].keys) 9 | for dataset in datasets[1:]: 10 | if not len(dataset) == len(datasets[0]): 11 | raise ValueError("All datasets must have the same length") 12 | if len(keys.intersection(dataset.keys)) > 0: 13 | raise ValueError("All keys must be unique among all datasets") 14 | keys = keys.union(dataset.keys) 15 | 16 | self._datasets = datasets 17 | 18 | def __len__(self): 19 | return len(self._datasets[0]) 20 | 21 | @property 22 | def keys(self): 23 | return tuple(key for dataset in self._datasets for key in dataset.keys) 24 | 25 | @property 26 | def mode(self): 27 | for dataset in self._datasets: 28 | if dataset.mode: 29 | return dataset.mode 30 | return tuple 31 | 32 | def get_examples(self, indices, key_indices): 33 | if key_indices is None: 34 | return tuple( 35 | col 36 | for dataset in self._datasets 37 | for col in dataset.get_examples(indices, None) 38 | ) 39 | 40 | examples = {} 41 | key_offset = 0 42 | for dataset in self._datasets: 43 | sub_key_indices = [] 44 | for key_index in key_indices: 45 | sub_key_index = key_index - key_offset 46 | if sub_key_index < 0 or len(dataset.keys) <= sub_key_index: 47 | continue 48 | if sub_key_index not in sub_key_indices: 49 | sub_key_indices.append(sub_key_index) 50 | 51 | if len(sub_key_indices) > 0: 52 | sub_key_indices = tuple(sub_key_indices) 53 | sub_examples = dataset.get_examples(indices, sub_key_indices) 54 | for sub_key_index, col_example in zip( 55 | sub_key_indices, sub_examples 56 | ): 57 | examples[key_offset + sub_key_index] = col_example 58 | 59 | key_offset += len(dataset.keys) 60 | 61 | return tuple(examples[key_index] for key_index in key_indices) 62 | 63 | def convert(self, data): 64 | return self._datasets[0].convert(data) 65 | 
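A brief usage sketch of the join operation implemented above (hedged: it assumes the public `TabularDataset.join` entry point and the `from_data` helper exposed by this package; the key names are illustrative):

```python
import numpy as np
from pytorch_pfn_extras.dataset import tabular

# Datasets passed to join() must have equal lengths and disjoint keys;
# the joined dataset exposes the union of all columns.
ds_ab = tabular.from_data((('a', np.arange(10)), ('b', np.arange(10) ** 2)))
ds_c = tabular.from_data((('c', -np.arange(10)),))
joined = ds_ab.join(ds_c)
print(len(joined))  # 10
print(joined.keys)  # ('a', 'b', 'c')
```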
-------------------------------------------------------------------------------- /pytorch_pfn_extras/dataset/tabular/_slice.py: -------------------------------------------------------------------------------- 1 | # mypy: ignore-errors 2 | 3 | from pytorch_pfn_extras.dataset.tabular import _utils, tabular_dataset 4 | 5 | 6 | class _Slice(tabular_dataset.TabularDataset): 7 | def __init__(self, dataset, indices, keys): 8 | if keys is None: 9 | self._unary = None 10 | elif isinstance(keys, tuple): 11 | self._unary = False 12 | else: 13 | self._unary = True 14 | keys = (keys,) 15 | 16 | self._dataset = dataset 17 | self._indices = _utils._as_indices(indices, len(dataset)) 18 | self._key_indices = _utils._as_key_indices(keys, dataset.keys) 19 | 20 | def __len__(self): 21 | if self._indices is None: 22 | return len(self._dataset) 23 | elif isinstance(self._indices, slice): 24 | start, stop, step = self._indices.indices(len(self._dataset)) 25 | return len(range(start, stop, step)) 26 | else: 27 | return len(self._indices) 28 | 29 | @property 30 | def keys(self): 31 | if self._key_indices is None: 32 | return self._dataset.keys 33 | else: 34 | return tuple( 35 | self._dataset.keys[key_index] for key_index in self._key_indices 36 | ) 37 | 38 | @property 39 | def mode(self): 40 | if self._unary is None: 41 | return self._dataset.mode 42 | elif self._unary: 43 | return None 44 | else: 45 | return self._dataset.mode or tuple 46 | 47 | def get_examples(self, indices, key_indices): 48 | indices = _utils._merge_indices( 49 | self._indices, indices, len(self._dataset), len(self) 50 | ) 51 | key_indices = _utils._merge_key_indices(self._key_indices, key_indices) 52 | return self._dataset.get_examples(indices, key_indices) 53 | 54 | def convert(self, data): 55 | return self._dataset.convert(data) 56 | 57 | 58 | class _SliceHelper(object): 59 | def __init__(self, dataset): 60 | self._dataset = dataset 61 | 62 | def __getitem__(self, args): 63 | if isinstance(args, tuple): 64 | indices, keys = args 65 | else: 66 | indices = args 67 | keys = None 68 | 69 | return _Slice(self._dataset, indices, keys) 70 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/dataset/tabular/_with_converter.py: -------------------------------------------------------------------------------- 1 | # mypy: ignore-errors 2 | 3 | from pytorch_pfn_extras.dataset.tabular import tabular_dataset 4 | 5 | 6 | class _WithConverter(tabular_dataset.TabularDataset): 7 | def __init__(self, dataset, converter): 8 | self._dataset = dataset 9 | self._converter = converter 10 | 11 | def __len__(self): 12 | return len(self._dataset) 13 | 14 | @property 15 | def keys(self): 16 | return self._dataset.keys 17 | 18 | @property 19 | def mode(self): 20 | return self._dataset.mode 21 | 22 | def get_examples(self, indices, key_indices): 23 | return self._dataset.get_examples(indices, key_indices) 24 | 25 | def convert(self, data): 26 | if isinstance(data, tuple): 27 | return self._converter(*data) 28 | elif isinstance(data, dict): 29 | return self._converter(**data) 30 | else: 31 | return self._converter(data) 32 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/dataset/tabular/delegate_dataset.py: -------------------------------------------------------------------------------- 1 | # mypy: ignore-errors 2 | 3 | from pytorch_pfn_extras.dataset.tabular import tabular_dataset 4 | 5 | 6 | class DelegateDataset(tabular_dataset.TabularDataset): 7 | """A helper class 
to implement a TabularDataset. 8 | 9 | This class wraps an instance of 10 | :class:`~pytorch_pfn_extras.dataset.TabularDataset` 11 | and provides methods of 12 | :class:`~pytorch_pfn_extras.dataset.TabularDataset`. 13 | This class is useful for creating a custom dataset class by inheriting from it. 14 | 15 | >>> import numpy as np 16 | >>> 17 | >>> from pytorch_pfn_extras.dataset import tabular 18 | >>> 19 | >>> class MyDataset(tabular.DelegateDataset): 20 | ... 21 | ... def __init__(self): 22 | ... super().__init__(tabular.from_data(( 23 | ... ('a', np.arange(10)), 24 | ... ('b', self.get_b), 25 | ... ('c', [3, 1, 4, 5, 9, 2, 6, 8, 7, 0]), 26 | ... (('d', 'e'), self.get_de)))) 27 | ... 28 | ... def get_b(self, i): 29 | ... return 'b[{}]'.format(i) 30 | ... 31 | ... def get_de(self, i): 32 | ... return {'d': 'd[{}]'.format(i), 'e': 'e[{}]'.format(i)} 33 | ... 34 | >>> dataset = MyDataset() 35 | >>> len(dataset) 36 | 10 37 | >>> dataset.keys 38 | ('a', 'b', 'c', 'd', 'e') 39 | >>> dataset[0] 40 | (0, 'b[0]', 3, 'd[0]', 'e[0]') 41 | 42 | Args: 43 | dataset (pytorch_pfn_extras.dataset.TabularDataset): 44 | An underlying dataset. 45 | 46 | """ 47 | 48 | def __init__(self, dataset): 49 | self.dataset = dataset 50 | 51 | def __len__(self): 52 | return len(self.dataset) 53 | 54 | @property 55 | def keys(self): 56 | return self.dataset.keys 57 | 58 | @property 59 | def mode(self): 60 | return self.dataset.mode 61 | 62 | def get_examples(self, indices, key_indices): 63 | return self.dataset.get_examples(indices, key_indices) 64 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.distributed._dataset_util import ( # NOQA 2 | create_distributed_subset_indices, 3 | ) 4 | from pytorch_pfn_extras.distributed._distributed_validation_sampler import ( # NOQA 5 | DistributedValidationSampler, 6 | ) 7 | from pytorch_pfn_extras.distributed._initialize import ( # NOQA 8 | initialize_ompi_environment, 9 | ) 10 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/distributed/_dataset_util.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def _shared_random_seed() -> int: 8 | seed = torch.randint(0, 2**31, size=()) 9 | if torch.distributed.is_initialized(): # type: ignore 10 | if torch.distributed.get_backend() == "nccl": # type: ignore 11 | seed = seed.cuda() 12 | torch.distributed.broadcast(seed, 0) # type: ignore 13 | return int(seed) 14 | 15 | 16 | def create_distributed_subset_indices( 17 | num_total_samples: int, 18 | num_replicas: Optional[int] = None, 19 | rank: Optional[int] = None, 20 | shuffle: bool = True, 21 | seed: Optional[int] = None, 22 | ) -> List[int]: 23 | """Returns the indices of the dataset to be used by the current process. 24 | 25 | Args: 26 | num_total_samples: The size of the dataset. 27 | num_replicas: Number of processes participating in the training. 28 | By default, ``torch.distributed.get_world_size()`` is used. 29 | rank: Rank of the current process within `num_replicas`. 30 | By default, ``torch.distributed.get_rank()`` is used. 31 | shuffle: If ``True`` (default), shuffle the indices. 32 | seed: Random seed used to shuffle.
33 | """ 34 | if num_replicas is None: 35 | num_replicas = torch.distributed.get_world_size() # type: ignore 36 | if rank is None: 37 | rank = torch.distributed.get_rank() # type: ignore 38 | 39 | indices = list(range(num_total_samples)) 40 | if shuffle: 41 | if seed is None: 42 | seed = _shared_random_seed() 43 | rng = np.random.RandomState(seed) 44 | rng.shuffle(indices) 45 | n_sub_samples = (num_total_samples + num_replicas - 1) // num_replicas 46 | b = num_total_samples * rank // num_replicas 47 | e = b + n_sub_samples 48 | return indices[b:e] 49 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/distributed/_distributed_validation_sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Optional, Sized, TypeVar 2 | 3 | import numpy as np 4 | import torch 5 | import torch.distributed as dist 6 | 7 | T_co = TypeVar("T_co", covariant=True) 8 | 9 | 10 | class DistributedValidationSampler(torch.utils.data.Sampler): 11 | """Distributed sampler without duplication 12 | 13 | This sampler splits the input dataset to each worker process in distributed setup 14 | without allowing repetition. 15 | It is for evaluation purpose such as :class:`~DistributedEvaluator`. 16 | This does not guarantee each worker to get the same number of samples, 17 | so for training do not use this sampler (use PyTorch DistributedSampler instead). 18 | """ 19 | 20 | def __init__( 21 | self, 22 | dataset: Sized, 23 | num_replicas: Optional[int] = None, 24 | rank: Optional[int] = None, 25 | shuffle: bool = True, 26 | seed: int = 0, 27 | ) -> None: 28 | if num_replicas is None: 29 | if not dist.is_available() or not dist.is_initialized(): # type: ignore[no-untyped-call] 30 | raise RuntimeError( 31 | "Requires distributed package to be available" 32 | ) 33 | num_replicas = dist.get_world_size() # type: ignore[no-untyped-call] 34 | if rank is None: 35 | if not dist.is_available() or not dist.is_initialized(): # type: ignore[no-untyped-call] 36 | raise RuntimeError( 37 | "Requires distributed package to be available" 38 | ) 39 | rank = dist.get_rank() # type: ignore[no-untyped-call] 40 | if rank >= num_replicas or rank < 0: 41 | raise ValueError( 42 | "Invalid rank {}, rank should be in the interval" 43 | " [0, {}]".format(rank, num_replicas - 1) 44 | ) 45 | self.dataset = dataset 46 | self.num_replicas = num_replicas 47 | self.rank = rank 48 | self.shuffle = shuffle 49 | self.seed = seed 50 | 51 | self.dataset_len = len(dataset) 52 | self.num_samples = len( 53 | np.array_split(range(self.dataset_len), num_replicas)[rank] 54 | ) 55 | 56 | def __iter__(self) -> Iterator[T_co]: 57 | if self.shuffle: 58 | # deterministically shuffle based on epoch and seed 59 | g = torch.Generator() 60 | g.manual_seed(self.seed) 61 | indices = torch.randperm(self.dataset_len, generator=g).tolist() 62 | else: 63 | indices = list(range(self.dataset_len)) 64 | 65 | return iter(np.array_split(indices, self.num_replicas)[self.rank]) 66 | 67 | def __len__(self) -> int: 68 | return self.num_samples 69 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/distributed/_initialize.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import timedelta 3 | from typing import Tuple 4 | 5 | import torch.distributed 6 | 7 | 8 | def initialize_ompi_environment( 9 | *, 10 | backend: str = "gloo", 11 | init_method: str = "tcp", 12 | 
world_size: int = 1, 13 | rank: int = 0, 14 | local_rank: int = 0, 15 | addr: str = "localhost", 16 | port: str = "1234", 17 | timeout: int = 1800, 18 | ) -> Tuple[int, int, int]: 19 | """Initialize `torch.distributed` environments with values taken from 20 | OpenMPI. 21 | 22 | Args: 23 | backend: The backend to be used, only ``"gloo"`` and ``"nccl"`` are 24 | supported. Defaults to ``"gloo"``. 25 | init_method: Initialization method used by torch, only ``"tcp"`` and 26 | ``"env"`` are supported. Defaults to ``"tcp"``. 27 | world_size: The total world size to be used in case it is not specified 28 | in MPI env vars. Defaults to ``1``. 29 | rank: The process rank to be used in case it is not specified in MPI 30 | env vars. Defaults to ``0``. 31 | local_rank: The process local rank to be used in case it is not 32 | specified in MPI env vars. Defaults to ``0``. 33 | addr: The address of the master process of `torch.distributed`. 34 | Defaults to ``"localhost"`` 35 | port: The port of the master process of `torch.distributed`. 36 | Defaults to ``"1234"`` 37 | timeout: Timeout seconds for `torch.distributed` collective communication. 38 | Defaults to ``1800``. 39 | """ 40 | e = os.environ 41 | backend = backend 42 | # Ranks determined from mpirun 43 | world_size = int(e.get("OMPI_COMM_WORLD_SIZE", world_size)) 44 | rank = int(e.get("OMPI_COMM_WORLD_RANK", rank)) 45 | local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", local_rank)) 46 | addr = e.get("MASTER_ADDR", addr) 47 | port = e.get("MASTER_PORT", port) 48 | 49 | if backend not in ("gloo", "nccl"): 50 | raise ValueError( 51 | "Invalid value for backend, only 'gloo' and 'nccl' are supported" 52 | ) 53 | if init_method == "env": 54 | init_method = "env://" 55 | e["MASTER_ADDR"] = addr 56 | e["MASTER_PORT"] = port 57 | e["WORLD_SIZE"] = str(world_size) 58 | e["RANK"] = str(rank) 59 | e["LOCAL_RANK"] = str(local_rank) 60 | elif init_method == "tcp": 61 | init_method = f"tcp://{addr}:{port}" 62 | else: 63 | raise ValueError( 64 | "Invalid value for init_method, only 'env' and 'tcp' are supported" 65 | ) 66 | 67 | if world_size > 1 and not torch.distributed.is_initialized(): # type: ignore 68 | torch.distributed.init_process_group( # type: ignore 69 | backend, 70 | init_method=init_method, 71 | world_size=world_size, 72 | rank=rank, 73 | timeout=timedelta(seconds=timeout), 74 | ) 75 | torch.distributed.barrier() # type: ignore 76 | 77 | return world_size, rank, local_rank 78 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/handler/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.handler._code_block import ( # NOQA 2 | CodeBlock, 3 | forward, 4 | update_parameters, 5 | ) 6 | from pytorch_pfn_extras.handler._handler import BaseHandler, Handler # NOQA 7 | 8 | # Deprecated, only imported for backward compatibility 9 | from pytorch_pfn_extras.handler._logic import torch_autocast # NOQA 10 | from pytorch_pfn_extras.handler._logic import ( # NOQA 11 | BaseLogic, 12 | ClousureLogic, 13 | CodeBlockLogic, 14 | Logic, 15 | ) 16 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from logging import CRITICAL, DEBUG, ERROR, INFO, WARNING # NOQA 4 | from typing import Optional 5 | 6 | _logger_name = "ppe" 7 | _logger_format = "[%(name)s] 
%(asctime)s: (%(levelname)s) %(message)s" 8 | _logger = None 9 | 10 | 11 | def _configure_logging( 12 | *, 13 | filename: Optional[str] = None, 14 | level: str = "ERROR", 15 | format: str = _logger_format, 16 | ) -> None: 17 | global _logger 18 | filename = os.environ.get("PPE_LOG_FILENAME", filename) 19 | if filename is None: 20 | handler: logging.Handler = logging.StreamHandler() 21 | else: 22 | handler = logging.FileHandler(filename) 23 | handler.setFormatter(logging.Formatter(format)) 24 | # To dynamically change the level if needed 25 | # basicConfig does not allow to change the level right after 26 | _logger = logging.getLogger(_logger_name) 27 | level = os.environ.get("PPE_LOG_LEVEL", level) 28 | for lvl in ( 29 | logging.DEBUG, 30 | logging.INFO, 31 | logging.WARNING, 32 | logging.ERROR, 33 | logging.CRITICAL, 34 | ): 35 | if logging.getLevelName(lvl) == level: 36 | _logger.setLevel(lvl) 37 | break 38 | else: 39 | _logger.setLevel(logging.INFO) 40 | _logger.warning("invalid PPE_LOG_LEVEL (%s); using INFO", level) 41 | _logger.addHandler(handler) 42 | 43 | 44 | def _get_root_logger() -> logging.Logger: 45 | """Returns a logger to be used by pytorch-pfn-extras.""" 46 | assert _logger is not None 47 | return _logger 48 | 49 | 50 | def get_logger(name: str) -> logging.Logger: 51 | """Returns a child logger to be used by applications. 52 | 53 | Args: 54 | name (str): Name used to register and retrieve the logger object. 55 | 56 | Returns: 57 | A logging.Logger object used to log in the application code. 58 | """ 59 | return _get_root_logger().getChild(name) 60 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.nn import parallel # NOQA 2 | from pytorch_pfn_extras.nn.modules.ensure_shape import Ensure, ensure # NOQA 3 | from pytorch_pfn_extras.nn.modules.extended_sequential import ( # NOQA 4 | ExtendedSequential, 5 | ) 6 | from pytorch_pfn_extras.nn.modules.lazy_batchnorm import LazyBatchNorm1d # NOQA 7 | from pytorch_pfn_extras.nn.modules.lazy_batchnorm import LazyBatchNorm2d # NOQA 8 | from pytorch_pfn_extras.nn.modules.lazy_batchnorm import LazyBatchNorm3d # NOQA 9 | from pytorch_pfn_extras.nn.modules.lazy_conv import LazyConv1d # NOQA 10 | from pytorch_pfn_extras.nn.modules.lazy_conv import LazyConv2d # NOQA 11 | from pytorch_pfn_extras.nn.modules.lazy_conv import LazyConv3d # NOQA 12 | from pytorch_pfn_extras.nn.modules.lazy_linear import LazyLinear # NOQA 13 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/nn/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/pytorch_pfn_extras/nn/modules/__init__.py -------------------------------------------------------------------------------- /pytorch_pfn_extras/nn/modules/lazy_conv.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | import torch 4 | from pytorch_pfn_extras.nn.modules.lazy import ( 5 | LazyInitializationMixin, 6 | UninitializedParameter, 7 | ) 8 | 9 | 10 | class _LazyConvNd(LazyInitializationMixin): 11 | lazy_parameter_names = ("weight",) 12 | 13 | def __init__( 14 | self: Any, in_channels: Optional[int], *args: Any, **kwargs: Any 15 | ) -> None: 16 | 
super().__init__(in_channels or 0, *args, **kwargs) 17 | if in_channels is None: 18 | self.in_channels: Optional[int] = None 19 | self.weight = UninitializedParameter() 20 | 21 | def forward(self: Any, input: torch.Tensor) -> torch.Tensor: 22 | if isinstance(self.weight, UninitializedParameter): 23 | self.in_channels = input.shape[1] 24 | if self.transposed: 25 | shape = ( 26 | self.in_channels, 27 | self.out_channels // self.groups, 28 | *self.kernel_size, 29 | ) 30 | else: 31 | shape = ( 32 | self.out_channels, 33 | self.in_channels // self.groups, 34 | *self.kernel_size, 35 | ) 36 | self.weight = torch.nn.Parameter(self.weight.new_empty(*shape)) 37 | self.reset_parameters() 38 | return super().forward(input) # type: ignore 39 | 40 | def reset_parameters(self: Any) -> None: 41 | # Defer initialization of parameters until the shapes of all 42 | # parameters are ready. 43 | if self.lazy_parmeters_determined: 44 | super().reset_parameters() # type: ignore[misc] 45 | 46 | 47 | class LazyConv1d(_LazyConvNd, torch.nn.Conv1d): # type: ignore[misc] 48 | """Conv1d module with lazy weight initialization. 49 | 50 | When ``in_channels`` is ``None``, it is determined at the time of 51 | the first forward step. 52 | """ 53 | 54 | pass 55 | 56 | 57 | class LazyConv2d(_LazyConvNd, torch.nn.Conv2d): # type: ignore[misc] 58 | """Conv2d module with lazy weight initialization. 59 | 60 | When ``in_channels`` is ``None``, it is determined at the time of 61 | the first forward step. 62 | """ 63 | 64 | pass 65 | 66 | 67 | class LazyConv3d(_LazyConvNd, torch.nn.Conv3d): # type: ignore[misc] 68 | """Conv3d module with lazy weight initialization. 69 | 70 | When ``in_channels`` is ``None``, it is determined at the time of 71 | the first forward step. 72 | """ 73 | 74 | pass 75 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/nn/modules/lazy_linear.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | import torch 4 | from pytorch_pfn_extras.nn.modules.lazy import ( 5 | LazyInitializationMixin, 6 | UninitializedParameter, 7 | ) 8 | 9 | 10 | class LazyLinear(LazyInitializationMixin, torch.nn.Linear): # type: ignore[misc] 11 | """Linear module with lazy weight initialization. 12 | 13 | When ``in_features`` is ``None``, it is determined at the time of 14 | the first forward step. 15 | """ 16 | 17 | lazy_parameter_names = ("weight",) 18 | 19 | def __init__( 20 | self, in_features: Optional[int], *args: Any, **kwargs: Any 21 | ) -> None: 22 | super().__init__(in_features or 0, *args, **kwargs) 23 | if in_features is None: 24 | self.in_features = None # type: ignore[assignment] 25 | self.weight = UninitializedParameter() 26 | 27 | def forward(self, input: torch.Tensor) -> torch.Tensor: 28 | if isinstance(self.weight, UninitializedParameter): 29 | self.in_features = input.shape[-1] 30 | self.weight = torch.nn.Parameter( 31 | self.weight.new_empty(self.out_features, self.in_features) 32 | ) 33 | self.reset_parameters() 34 | return super().forward(input) 35 | 36 | def reset_parameters(self) -> None: 37 | # Defer initialization of parameters until the shape of the parameter 38 | # is determined.
39 | if self.lazy_parmeters_determined: 40 | super().reset_parameters() 41 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/nn/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.nn.parallel.distributed import ( # NOQA 2 | DistributedDataParallel, 3 | ) 4 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/onnx/__init__.py: -------------------------------------------------------------------------------- 1 | # NOTE: type stub (`__init__.pyi`) must be in sync with these public APIs. 2 | 3 | try: 4 | from pytorch_pfn_extras.onnx.export_testcase import export # NOQA 5 | from pytorch_pfn_extras.onnx.export_testcase import export_testcase # NOQA 6 | from pytorch_pfn_extras.onnx.export_testcase import is_large_tensor # NOQA 7 | from pytorch_pfn_extras.onnx.export_testcase import LARGE_TENSOR_DATA_THRESHOLD # NOQA 8 | from pytorch_pfn_extras.onnx.annotate import annotate # NOQA 9 | from pytorch_pfn_extras.onnx.annotate import apply_annotation # NOQA 10 | from pytorch_pfn_extras.onnx.annotate import scoped_anchor # NOQA 11 | from pytorch_pfn_extras.onnx._as_output import as_output # NOQA 12 | from pytorch_pfn_extras.onnx._grad import grad # NOQA 13 | from pytorch_pfn_extras.onnx.load import load_model # NOQA 14 | from pytorch_pfn_extras.onnx._helper import no_grad # NOQA 15 | import pytorch_pfn_extras.onnx._lax as lax # NOQA 16 | available = True 17 | except ImportError: 18 | available = False 19 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/onnx/__init__.pyi: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | 4 | def export(*args: Any, **kwargs: Any) -> Any: ... # NOQA 5 | def export_testcase(*args: Any, **kwargs: Any) -> Any: ... # NOQA 6 | def is_large_tensor(*args: Any, **kwargs: Any) -> Any: ... # NOQA 7 | def annotate(*args: Any, **kwargs: Any) -> Any: ... # NOQA 8 | def apply_annotation(*args: Any, **kwargs: Any) -> Any: ... # NOQA 9 | def scoped_anchor(*args: Any, **kwargs: Any) -> Any: ... # NOQA 10 | def as_output(*args: Any, **kwargs: Any) -> Any: ... # NOQA 11 | def grad(*args: Any, **kwargs: Any) -> Any: ... # NOQA 12 | def load_model(*args: Any, **kwargs: Any) -> Any: ... 
# NOQA 13 | 14 | 15 | LARGE_TENSOR_DATA_THRESHOLD: int 16 | available: bool 17 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/onnx/_constants.py: -------------------------------------------------------------------------------- 1 | import pytorch_pfn_extras 2 | import torch.onnx 3 | import torch.onnx.symbolic_helper 4 | 5 | from torch.onnx._constants import ONNX_DEFAULT_OPSET, ONNX_CONSTANT_FOLDING_MIN_OPSET, ONNX_MAX_OPSET # type: ignore[attr-defined] 6 | onnx_default_opset = ONNX_DEFAULT_OPSET 7 | onnx_constant_folding_opsets = range(ONNX_CONSTANT_FOLDING_MIN_OPSET, ONNX_MAX_OPSET) 8 | onnx_main_opset = ONNX_MAX_OPSET 9 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/onnx/_globals.py: -------------------------------------------------------------------------------- 1 | import pytorch_pfn_extras 2 | import torch 3 | from typing import Optional 4 | 5 | 6 | import torch.onnx._globals 7 | GLOBALS = torch.onnx._globals.GLOBALS 8 | 9 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/onnx/_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Callable, Any 3 | 4 | 5 | def _detach(x: Any) -> Any: 6 | if isinstance(x, torch.Tensor): 7 | return x.detach() 8 | elif isinstance(x, list): 9 | return [_detach(elem) for elem in x] 10 | elif isinstance(x, tuple): 11 | return tuple([_detach(elem) for elem in x]) 12 | elif isinstance(x, dict): 13 | return {k: _detach(v) for k, v in x.items()} 14 | else: 15 | return x 16 | 17 | 18 | def no_grad(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: 19 | with torch.no_grad(): # type: ignore[no-untyped-call] 20 | out = fn(*args, **kwargs) 21 | # torch.no_grad() does not export `detach` op when tracing 22 | return _detach(out) 23 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/onnx/load.py: -------------------------------------------------------------------------------- 1 | import onnx 2 | import json 3 | import warnings 4 | from pathlib import Path 5 | from typing import Any, IO, Text, Union 6 | 7 | 8 | def load_model( 9 | f: Union[IO, Text], 10 | format: Any = None, 11 | load_external_data: bool = True, 12 | ) -> onnx.ModelProto: 13 | """Load model from ONNX file. 14 | 15 | This is a wrapper around `onnx.load_model` that automatically falls back to 16 | `load_external_data=False` when tensors are stripped. 17 | 18 | Args: 19 | f: A file-like object or a string file path from which the model 20 | is loaded. 21 | format: A reserved argument. 22 | load_external_data: If ``True`` and the external data is under the 23 | same directory as the model, load the external data. 24 | """ 25 | try: 26 | return onnx.load_model(f, format=format, load_external_data=load_external_data) 27 | except (OSError, onnx.checker.ValidationError) as e: # The ONNX may contain stripped large tensors. 28 | if (load_external_data 29 | and isinstance(e, OSError) 30 | and json.loads(Path(e.filename).name)["type"] != "stripped"): 31 | raise 32 | warnings.warn( 33 | 'The specified ONNX contains stripped large tensors.
' 34 | 'Falling back to `load_external_data=False`.', 35 | UserWarning) 36 | return onnx.load_model(f, format=format, load_external_data=False) 37 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/onnx/pfto_exporter/__init__.py: -------------------------------------------------------------------------------- 1 | import pytorch_pfn_extras.onnx.pfto_exporter.export # NOQA 2 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/onnx/symbolic_registry.py: -------------------------------------------------------------------------------- 1 | import pytorch_pfn_extras 2 | from typing import cast, Any, Callable, Tuple, Union 3 | 4 | import torch.onnx._internal.registration as reg 5 | import torch.onnx.utils 6 | 7 | def is_registered_op(opname: str, domain: str, version: int) -> Any: 8 | return reg.registry.is_registered_op(f"{domain}::{opname}", version) 9 | 10 | Value = torch._C.Value 11 | SymbolicFunction = Callable[..., Union[Value, Tuple[Value]]] 12 | 13 | def get_registered_op(opname: str, domain: str, version: int) -> SymbolicFunction: 14 | group = reg.registry.get_function_group(f"{domain}::{opname}") 15 | assert group is not None 16 | ret = group.get(version) 17 | assert ret is not None 18 | return cast(Callable, ret) # type: ignore[redundant-cast] 19 | 20 | def register_op(op_type: str, f: Callable, domain: str, opset_version: int) -> None: 21 | if len(domain) == 0: 22 | domain = "aten" 23 | torch.onnx.utils.register_custom_op_symbolic( 24 | f"{domain}::{op_type}", f, opset_version) # type: ignore[no-untyped-call] 25 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.ops.register import OpDesc, register # NOQA 2 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/profiler/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.profiler._record import record # NOQA 2 | from pytorch_pfn_extras.profiler._record import record_function # NOQA 3 | from pytorch_pfn_extras.profiler._record import record_iterable # NOQA 4 | from pytorch_pfn_extras.profiler._time_summary import TimeSummary # NOQA 5 | from pytorch_pfn_extras.profiler._time_summary import get_time_summary # NOQA 6 | from pytorch_pfn_extras.profiler._tracing import ( # NOQA 7 | ChromeTracer, 8 | TraceableDataset, 9 | Tracer, 10 | clear_tracer, 11 | enable_global_trace, 12 | enable_thread_trace, 13 | get_tracer, 14 | load_chrome_trace_as_json, 15 | ) 16 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/profiler/_util.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import threading 3 | from typing import Any, Callable, Optional 4 | 5 | 6 | class QueueWorker: 7 | def __init__( 8 | self, 9 | add: Callable[[str, Any], None], 10 | max_queue_size: int, 11 | ) -> None: 12 | self._add = add 13 | self._max_queue_size = max_queue_size 14 | self._initialized = False 15 | self._queue: Optional[mp.JoinableQueue] = None 16 | self._thread: Optional[threading.Thread] = None 17 | self._thread_exited = False 18 | 19 | def initialize(self) -> None: 20 | if self._initialized: 21 | return 22 | self._thread = 
threading.Thread(target=self._worker, daemon=True) 23 | self._queue = mp.JoinableQueue(self._max_queue_size) 24 | self._thread.start() 25 | self._initialized = True 26 | self._thread_exited = False 27 | 28 | def finalize(self) -> None: 29 | if not self._initialized: 30 | return 31 | assert self._queue is not None 32 | assert self._thread is not None 33 | # In some situations (when this runs in a subprocess), the queue might have 34 | # been cut in the worker thread before this function is called 35 | # due to the non-deterministic shutdown process. 36 | if not self._thread_exited: 37 | self._queue.put(None) 38 | self._queue.join() 39 | self._queue.close() 40 | self._queue.join_thread() 41 | self._initialized = False 42 | 43 | def synchronize(self) -> None: 44 | assert self._queue is not None 45 | self._queue.join() 46 | 47 | def put(self, name: str, value: Any) -> None: 48 | assert self._queue is not None 49 | assert not self._thread_exited 50 | self._queue.put((name, value)) 51 | 52 | def _worker(self) -> None: 53 | assert self._queue is not None 54 | while True: 55 | try: 56 | v = self._queue.get() 57 | # If this runs in a subprocess, the cleanup may throw an EOF here 58 | # before the queue cleanup code is executed 59 | except EOFError: 60 | self._thread_exited = True 61 | break 62 | if v is None: 63 | self._queue.task_done() 64 | break 65 | name, value = v 66 | self._add(name, value) 67 | self._queue.task_done() 68 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/pytorch_pfn_extras/py.typed -------------------------------------------------------------------------------- /pytorch_pfn_extras/runtime/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.runtime._registry import _RuntimeRegistry # NOQA 2 | from pytorch_pfn_extras.runtime._runtime import BaseRuntime # NOQA 3 | from pytorch_pfn_extras.runtime._runtime import PyTorchRuntime # NOQA 4 | 5 | runtime_registry = _RuntimeRegistry(PyTorchRuntime) 6 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/runtime/_autocast.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from typing import Any, Dict, Generator 3 | 4 | _cuda_amp_available = False 5 | 6 | try: 7 | import torch.cuda.amp 8 | 9 | _cuda_amp_available = torch.cuda.is_available() and hasattr( 10 | torch.cuda.amp, "autocast" 11 | ) 12 | except ImportError: 13 | pass 14 | 15 | 16 | class _AutocastManager: 17 | def __init__( 18 | self, 19 | autocast_options: Dict[str, Any], 20 | has_grad_scaler: bool, 21 | ) -> None: 22 | autocast_options = autocast_options.copy() 23 | self._enabled = autocast_options.pop("enabled", True) 24 | self._device_type = autocast_options.pop("device_type", "cuda") 25 | self._options = autocast_options 26 | 27 | if not _cuda_amp_available: 28 | if has_grad_scaler or ( 29 | self._enabled and self._device_type == "cuda" 30 | ): 31 | raise RuntimeError( 32 | "Requested AMP features but torch.cuda.amp" 33 | " is not enabled" 34 | ) 35 | 36 | @contextlib.contextmanager 37 | def autocast(self, enabled: bool = True) -> Generator[None, None, None]: 38 | # CUDA availability was checked in the Runtime constructor 39 | with
torch.autocast(self._device_type, enabled=self._enabled, **self._options): # type: ignore[no-untyped-call,attr-defined] 40 | yield 41 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/runtime/_map.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional, Sequence, Set 2 | 3 | import pytorch_pfn_extras as ppe 4 | 5 | 6 | def map( 7 | func: Callable[[Any], Any], 8 | iterable: Sequence[Any], 9 | out_keys: Optional[Set[str]] = None, 10 | device: Any = "cpu", 11 | ) -> Sequence[Any]: 12 | codeblock = ppe.handler.forward(func) 13 | return codeblock.runtime.map(codeblock, iterable, out_keys, device) # type: ignore 14 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/runtime/_registry.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Type 2 | 3 | import torch 4 | from pytorch_pfn_extras.runtime._runtime import BaseRuntime, DeviceLike 5 | 6 | 7 | class _RuntimeRegistry: 8 | def __init__(self, fallback_class: Type[BaseRuntime]): 9 | self._runtimes: Dict[str, Type[BaseRuntime]] = {} 10 | self._fallback_class = fallback_class 11 | 12 | def register( 13 | self, 14 | device_type: str, 15 | runtime_class: Type[BaseRuntime], 16 | ) -> None: 17 | self._runtimes[device_type] = runtime_class 18 | 19 | def get_runtime_class_for_device_spec( 20 | self, device: DeviceLike 21 | ) -> Type[BaseRuntime]: 22 | if isinstance(device, torch.device): 23 | device_type = device.type 24 | else: 25 | assert isinstance(device, str) 26 | device_type = device.split(":")[0] 27 | return self._runtimes.get(device_type, self._fallback_class) 28 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/runtime/_to.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Type, TypeVar 2 | 3 | import pytorch_pfn_extras as ppe 4 | import torch 5 | from pytorch_pfn_extras.runtime._runtime import BaseRuntime, DeviceLike 6 | 7 | ModuleOrTensor = TypeVar("ModuleOrTensor", torch.nn.Module, torch.Tensor) 8 | 9 | 10 | def to( 11 | module_or_tensor: ModuleOrTensor, 12 | device: DeviceLike, 13 | *, 14 | options: Optional[Dict[str, Any]] = None, 15 | runtime_class: Optional[Type[BaseRuntime]] = None, 16 | config: Optional[Dict[str, Any]] = None, 17 | ) -> ModuleOrTensor: 18 | """A function to transfer the given object to the given device. 19 | 20 | If PyTorch's device type is given as the ``device`` argument, 21 | the behavior of this function is equivalent to 22 | ``module_or_tensor.to(device)``. 23 | 24 | Otherwise, this function uses the **Runtime** mechanism. 25 | This function looks for the Runtime for the device from the RuntimeRegistry 26 | and delegates the actual transfer operation to it. 27 | 28 | See also the documentation of ``ppe.runtime.BaseRuntime`` for details. 29 | 30 | Args: 31 | module_or_tensor (torch.nn.Module or torch.Tensor): 32 | An object to be transferred. 33 | device (torch.device or str): 34 | The device that the input object is transferred to. 35 | options (dict, optional): 36 | A dictionary of options that is passed to 37 | ``runtime_class.__init__`` as an argument. 38 | runtime_class: 39 | A runtime class inheriting from the `BaseRuntime` class.
40 | If ``None``, a runtime class is automatically selected 41 | based on the ``device`` argument from the runtime registry. 42 | config (dict, optional): 43 | DEPRECATED. Use `options`. 44 | 45 | Returns: 46 | The `torch.nn.Module` or `torch.Tensor` transferred to the specified device. 47 | """ 48 | if options is None: 49 | options = {} 50 | if config is not None: 51 | options = config 52 | elif config is not None: 53 | raise ValueError("options and config cannot be specified together") 54 | 55 | if runtime_class is None: 56 | registry = ppe.runtime.runtime_registry 57 | runtime_class = registry.get_runtime_class_for_device_spec(device) 58 | runtime = runtime_class(device, options) 59 | obj = module_or_tensor 60 | if isinstance(obj, torch.nn.Module): 61 | mod = runtime.move_module(obj) 62 | for module in mod.modules(): 63 | ppe.runtime._runtime._set_module_runtime_tag(module, runtime) 64 | return mod 65 | elif isinstance(obj, torch.Tensor): 66 | return runtime.move_tensor(obj) 67 | else: 68 | raise ValueError("Unsupported type for module_or_tensor") 69 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/testing.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple, Union 2 | 3 | import torch 4 | 5 | 6 | def _compare_states( 7 | s1: Union[Dict[Any, Any], List[Any], Tuple[Any]], 8 | s2: Union[Dict[Any, Any], List[Any], Tuple[Any]], 9 | strict: bool = False, 10 | ) -> bool: 11 | def allclose(a: torch.Tensor, b: torch.Tensor) -> bool: 12 | if strict: 13 | return bool((a == b).all()) 14 | else: 15 | return torch.allclose(a, b) 16 | 17 | if isinstance(s1, dict): 18 | keys = list(s1.keys()) 19 | assert isinstance(s2, dict) 20 | if set(keys) != set(s2.keys()): 21 | return False 22 | elif isinstance(s1, (list, tuple)): 23 | keys = list(range(len(s1))) 24 | if len(s1) != len(s2): 25 | return False 26 | 27 | all_equal = True 28 | for k in keys: 29 | if isinstance(s1[k], dict): 30 | if not isinstance(s2[k], dict): 31 | return False 32 | all_equal = all_equal and _compare_states(s1[k], s2[k], strict) 33 | elif isinstance(s1[k], (list, tuple)): 34 | if not isinstance(s2[k], (list, tuple)): 35 | return False 36 | all_equal = all_equal and _compare_states(s1[k], s2[k], strict) 37 | elif isinstance(s1[k], torch.Tensor): 38 | all_equal = all_equal and allclose(s1[k], s2[k]) 39 | else: 40 | all_equal = all_equal and s1[k] == s2[k] 41 | if not all_equal: 42 | return all_equal 43 | return all_equal 44 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/torchscript.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Tuple 2 | 3 | import torch 4 | 5 | 6 | # Run jit pass with post lint 7 | def run_jit_pass( 8 | p: Callable, g: torch._C.Graph, *args: Any, **kwargs: Any 9 | ) -> None: 10 | p(g, *args, **kwargs) 11 | torch._C._jit_pass_lint(g) 12 | 13 | 14 | def find_inplace( 15 | g: torch._C.Graph, 16 | ) -> Tuple[torch._C.Graph, List[torch._C.Node]]: 17 | g = g.copy() 18 | run_jit_pass(torch._C._jit_pass_inline, g) 19 | nodes = [] 20 | for n in g.nodes(): 21 | if n.kind().endswith("_"): 22 | nodes.append(n) 23 | return g, nodes 24 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.training import extensions # NOQA 2 | from
pytorch_pfn_extras.training._evaluator import DistributedEvaluator # NOQA 3 | from pytorch_pfn_extras.training._evaluator import Evaluator # NOQA 4 | from pytorch_pfn_extras.training._manager_protocol import ( # NOQA 5 | ExtensionsManagerProtocol, 6 | StateObjectProtocol, 7 | ) 8 | from pytorch_pfn_extras.training._trainer import Trainer # NOQA 9 | from pytorch_pfn_extras.training.extension import PRIORITY_EDITOR # NOQA 10 | from pytorch_pfn_extras.training.extension import PRIORITY_READER # NOQA 11 | from pytorch_pfn_extras.training.extension import PRIORITY_WRITER # NOQA 12 | from pytorch_pfn_extras.training.extension import Extension # NOQA 13 | from pytorch_pfn_extras.training.extension import ExtensionEntry # NOQA 14 | from pytorch_pfn_extras.training.extension import make_extension # NOQA 15 | from pytorch_pfn_extras.training.manager import ExtensionsManager # NOQA 16 | from pytorch_pfn_extras.training.manager import IgniteExtensionsManager # NOQA 17 | from pytorch_pfn_extras.training.metrics import AccuracyMetric # NOQA 18 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/_manager_protocol.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | TYPE_CHECKING, 3 | Any, 4 | Dict, 5 | Mapping, 6 | Optional, 7 | runtime_checkable, 8 | ) 9 | 10 | import torch 11 | from typing_extensions import Protocol 12 | 13 | if TYPE_CHECKING: 14 | from pytorch_pfn_extras import reporting, writing 15 | from pytorch_pfn_extras.training import trigger as trigger_module 16 | from pytorch_pfn_extras.training.extension import Extension 17 | 18 | 19 | class ExtensionsManagerProtocol(Protocol): 20 | @property 21 | def iteration(self) -> int: ... 22 | 23 | @property 24 | def epoch(self) -> int: ... 25 | 26 | @property 27 | def epoch_detail(self) -> float: ... 28 | 29 | @property 30 | def _iters_per_epoch(self) -> int: ... 31 | 32 | @property 33 | def models(self) -> Mapping[str, torch.nn.Module]: ... 34 | 35 | @property 36 | def raw_models(self) -> Mapping[str, torch.nn.Module]: ... 37 | 38 | @property 39 | def optimizers(self) -> Mapping[str, torch.optim.Optimizer]: ... 40 | 41 | @property 42 | def elapsed_time(self) -> float: ... 43 | 44 | @property 45 | def is_before_training(self) -> bool: ... 46 | 47 | @property 48 | def stop_trigger(self) -> bool: ... 49 | 50 | @property 51 | def _stop_trigger(self) -> "trigger_module.Trigger": ... 52 | 53 | @property 54 | def out(self) -> str: ... 55 | 56 | @property 57 | def writer(self) -> Optional["writing.Writer"]: ... 58 | 59 | @property 60 | def reporter(self) -> "reporting.Reporter": ... 61 | 62 | def get_extension(self, name: str) -> "Extension": ... 63 | 64 | @property 65 | def observation(self) -> "reporting.Observation": ... 66 | 67 | 68 | @runtime_checkable 69 | class StateObjectProtocol(Protocol): 70 | def state_dict(self) -> Dict[str, Any]: ... 71 | 72 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ... 
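An illustrative sketch (the class name is hypothetical): any object exposing this pair of methods structurally satisfies `StateObjectProtocol`, and thanks to `@runtime_checkable` it can also be checked with `isinstance`:

```python
from typing import Any, Dict


class StepCounter:
    # A plain object with state_dict/load_state_dict; it matches
    # StateObjectProtocol and can therefore be saved and restored
    # alongside models and optimizers.
    def __init__(self) -> None:
        self.step = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"step": self.step}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.step = state_dict["step"]
```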
73 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/_transform_model.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import torch 4 | from pytorch_pfn_extras.nn.parallel import ( 5 | DistributedDataParallel as PpeDistributedDataParallel, 6 | ) 7 | from torch.nn.parallel import DistributedDataParallel 8 | 9 | _TransformModel = typing.Callable[[str, torch.nn.Module], torch.nn.Module] 10 | 11 | 12 | def default_transform_model(n: str, x: torch.nn.Module) -> torch.nn.Module: 13 | if isinstance(x, (DistributedDataParallel, PpeDistributedDataParallel)): 14 | return x.module 15 | return x 16 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/_util.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List 3 | 4 | 5 | def _get_ignite_version(version: str) -> List[int]: 6 | # We compare up to the minor version (first two components). 7 | # This is because it is highly unlikely that these numbers 8 | # will contain characters other than digits. 9 | 10 | # Ignite's versioning system is not explicitly documented. 11 | # However, it seems to follow semver, so the 12 | # major and minor ids can only be integers. 13 | # Some examples of versions are: 14 | # 0.1.0, 0.1.1, 0.3.0.dev20191007, 0.3.0. 15 | version_regexp = r"^[0-9]+\.[0-9]+\.[0-9]+(\.[0-9a-zA-Z]+)?$" 16 | if re.search(version_regexp, version): 17 | return [int(x) for x in version.split(".")[:2]] 18 | raise ValueError("Invalid version format") 19 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/accumulate/__init__.py: -------------------------------------------------------------------------------- 1 | from .average_accumulate import AverageAccumulate # NOQA: F401 2 | from .max_accumulate import MaxAccumulate # NOQA: F401 3 | from .min_accumulate import MinAccumulate # NOQA: F401 4 | from .standard_deviation_accumulate import ( # NOQA: F401 5 | StandardDeviationAccumulate, 6 | ) 7 | from .unbiased_standard_deviation_accumulate import ( # NOQA: F401 8 | UnbiasedStandardDeviationAccumulate, 9 | ) 10 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/accumulate/_accumulate_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict, Tuple 3 | 4 | import torch.distributed 5 | from pytorch_pfn_extras import reporting 6 | from pytorch_pfn_extras.training import extension 7 | from pytorch_pfn_extras.training._manager_protocol import ( 8 | ExtensionsManagerProtocol, 9 | ) 10 | from pytorch_pfn_extras.training.extensions.accumulate._summary import ( 11 | SummaryBase, 12 | ) 13 | from pytorch_pfn_extras.training.trigger import TriggerLike, get_trigger 14 | 15 | 16 | class AccumulateBase(ABC, extension.Extension): 17 | priority = extension.PRIORITY_EDITOR 18 | 19 | def __init__( 20 | self, 21 | conversion_key_pair: Tuple[str, str], 22 | trigger: TriggerLike = (1, "epoch"), 23 | distributed: bool = False, 24 | ) -> None: 25 | self._conversion_key_pair = conversion_key_pair 26 | self._trigger = get_trigger(trigger=trigger) 27 | self._distributed = distributed 28 | if not torch.distributed.is_initialized() and self._distributed: # type: ignore[no-untyped-call] 29 |
raise RuntimeError("PyTorch distributed module is not initialized.") 30 | 31 | self._init_summary() 32 | 33 | def __call__(self, manager: ExtensionsManagerProtocol) -> None: 34 | observation = manager.observation 35 | src_key, dst_key = self._conversion_key_pair 36 | self._summary.add(observation[src_key]) 37 | 38 | if self._trigger(manager=manager): 39 | if self._distributed: 40 | summary = self._all_reduce_summaries() 41 | else: 42 | summary = self._summary 43 | reporting.report({dst_key: summary.compute_accumulate()}) 44 | self._init_summary() 45 | 46 | def state_dict(self) -> Dict[str, Any]: 47 | state: Dict[str, Any] = {} 48 | if hasattr(self._trigger, "state_dict"): 49 | state["_trigger"] = self._trigger.state_dict() 50 | state["_summary"] = self._summary.state_dict() 51 | return state 52 | 53 | def load_state_dict(self, to_load: Dict[str, Any]) -> None: 54 | if hasattr(self._trigger, "load_state_dict"): 55 | self._trigger.load_state_dict(to_load["_trigger"]) 56 | self._summary.load_state_dict(to_load["_summary"]) 57 | 58 | @property 59 | @abstractmethod 60 | def _summary(self) -> SummaryBase: ... 61 | 62 | @abstractmethod 63 | def _init_summary(self) -> None: ... 64 | 65 | @abstractmethod 66 | def _all_reduce_summaries(self) -> SummaryBase: ... 67 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/accumulate/_accumulate_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, TypeVar 2 | 3 | import torch.distributed 4 | 5 | T = TypeVar("T") 6 | 7 | 8 | def all_gather_object(obj: T) -> List[Optional[T]]: 9 | world_size = torch.distributed.get_world_size() # type: ignore 10 | object_list: List[Optional[T]] = [None for _ in range(world_size)] 11 | torch.distributed.all_gather_object(object_list=object_list, obj=obj) # type: ignore 12 | return object_list 13 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/accumulate/_summary/__init__.py: -------------------------------------------------------------------------------- 1 | from ._average_summary import AverageSummary # NOQA: F401 2 | from ._base_summary import SummaryBase # NOQA: F401 3 | from ._max_summary import MaxSummary # NOQA: F401 4 | from ._min_summary import MinSummary # NOQA: F401 5 | from ._standard_deviation_summary import StandardDeviationSummary # NOQA: F401 6 | from ._unbiased_standard_deviation_summary import ( # NOQA: F401 7 | UnbiasedStandardDeviationSummary, 8 | ) 9 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/accumulate/_summary/_average_summary.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import warnings 4 | from typing import Any, Dict 5 | 6 | from pytorch_pfn_extras.reporting import Scalar, Value 7 | from pytorch_pfn_extras.training.extensions.accumulate._summary._base_summary import ( 8 | SummaryBase, 9 | ) 10 | from pytorch_pfn_extras.training.extensions.accumulate._summary._summary_utils import ( 11 | nograd, 12 | ) 13 | 14 | 15 | class AverageSummary(SummaryBase): 16 | def __init__(self) -> None: 17 | self._x: Scalar = 0.0 18 | self._n: Scalar = 0 19 | super().__init__() 20 | 21 | def add(self, value: Value, weight: Scalar = 1) -> None: 22 | if callable(value): 23 | self._deferred.append((value, weight)) 24 | return 25 | m = self._n / 
(self._n + weight) 26 | self._x = self._x * m + value / weight * (1 - m) 27 | self._n += weight 28 | 29 | def state_dict(self) -> Dict[str, Any]: 30 | self._add_deferred_values() 31 | state = {} 32 | try: 33 | # Save the stats as python scalars in order to avoid 34 | # different device errors when loading them back 35 | state = { 36 | "_x": float(self._x), 37 | "_n": int(self._n), 38 | } 39 | except KeyError: 40 | warnings.warn("The previous statistics are not saved.") 41 | return state 42 | 43 | def load_state_dict(self, to_load: Dict[str, Any]) -> None: 44 | self._add_deferred_values() 45 | self._x = float(nograd(to_load["_x"])) 46 | self._n = int(nograd(to_load["_n"])) 47 | 48 | def compute_average(self) -> Scalar: 49 | self._add_deferred_values() 50 | return self._x 51 | 52 | def compute_accumulate(self) -> Scalar: 53 | return self.compute_average() 54 | 55 | def __add__(self, other: AverageSummary) -> AverageSummary: 56 | s = AverageSummary() 57 | m = self._n / (self._n + other._n) 58 | s._x = self._x * m + other._x * (1 - m) 59 | s._n = self._n + other._n 60 | s._deferred = self._deferred + other._deferred 61 | return s 62 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/accumulate/_summary/_base_summary.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Any, Callable, Dict, List, Tuple 5 | 6 | from pytorch_pfn_extras.reporting import Scalar, Value 7 | 8 | 9 | class SummaryBase(ABC): 10 | def __init__(self) -> None: 11 | super().__init__() 12 | self._deferred: List[Tuple[Callable[[], float], Scalar]] = [] 13 | 14 | @abstractmethod 15 | def add(self, value: Value, weight: Scalar = 1) -> None: ... 16 | 17 | @abstractmethod 18 | def compute_accumulate(self) -> Scalar: ... 19 | 20 | @abstractmethod 21 | def state_dict(self) -> Dict[str, Any]: ... 22 | 23 | @abstractmethod 24 | def load_state_dict(self, to_load: Dict[str, Any]) -> None: ... 
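    # Subclasses' ``add`` accepts a zero-argument callable in place of a
    # concrete value; it is stored in ``_deferred`` and evaluated only when
    # ``_add_deferred_values`` runs, i.e. when an aggregate is requested.
    # A minimal sketch (``expensive_metric`` is an illustrative placeholder):
    #
    #     summary.add(lambda: expensive_metric())  # nothing computed yet
    #     summary.compute_accumulate()             # evaluates, then folds in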
25 | 26 | def _add_deferred_values(self) -> None: 27 | for fn, weight in self._deferred: 28 | value = fn() 29 | self.add(value=value, weight=weight) 30 | self._deferred.clear() 31 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/accumulate/_summary/_max_summary.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import warnings 4 | from typing import Any, Dict 5 | 6 | from pytorch_pfn_extras.reporting import Scalar, Value 7 | from pytorch_pfn_extras.training.extensions.accumulate._summary._base_summary import ( 8 | SummaryBase, 9 | ) 10 | from pytorch_pfn_extras.training.extensions.accumulate._summary._summary_utils import ( 11 | nograd, 12 | ) 13 | 14 | 15 | class MaxSummary(SummaryBase): 16 | def __init__(self) -> None: 17 | self._max_x: Scalar = -float("inf") 18 | super().__init__() 19 | 20 | def add(self, value: Value, weight: Scalar = 1) -> None: 21 | if callable(value): 22 | self._deferred.append((value, weight)) 23 | return 24 | self._max_x = self._max_x if self._max_x > value else value 25 | 26 | def state_dict(self) -> Dict[str, Any]: 27 | self._add_deferred_values() 28 | state = {} 29 | try: 30 | # Save the stats as python scalars in order to avoid 31 | # different device errors when loading them back 32 | state = { 33 | "_max_x": float(self._max_x), 34 | } 35 | except KeyError: 36 | warnings.warn("The previous statistics are not saved.") 37 | return state 38 | 39 | def load_state_dict(self, to_load: Dict[str, Any]) -> None: 40 | self._add_deferred_values() 41 | self._max_x = float(nograd(to_load["_max_x"])) 42 | 43 | def compute_max(self) -> Scalar: 44 | self._add_deferred_values() 45 | return self._max_x 46 | 47 | def compute_accumulate(self) -> Scalar: 48 | return self.compute_max() 49 | 50 | def __add__(self, other: MaxSummary) -> MaxSummary: 51 | s = MaxSummary() 52 | s._max_x = self._max_x if self._max_x > other._max_x else other._max_x 53 | s._deferred = self._deferred + other._deferred 54 | return s 55 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/accumulate/_summary/_min_summary.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import warnings 4 | from typing import Any, Dict 5 | 6 | from pytorch_pfn_extras.reporting import Scalar, Value 7 | from pytorch_pfn_extras.training.extensions.accumulate._summary._base_summary import ( 8 | SummaryBase, 9 | ) 10 | from pytorch_pfn_extras.training.extensions.accumulate._summary._summary_utils import ( 11 | nograd, 12 | ) 13 | 14 | 15 | class MinSummary(SummaryBase): 16 | def __init__(self) -> None: 17 | self._min_x: Scalar = float("inf") 18 | super().__init__() 19 | 20 | def add(self, value: Value, weight: Scalar = 1) -> None: 21 | if callable(value): 22 | self._deferred.append((value, weight)) 23 | return 24 | self._min_x = self._min_x if self._min_x < value else value 25 | 26 | def state_dict(self) -> Dict[str, Any]: 27 | self._add_deferred_values() 28 | state = {} 29 | try: 30 | # Save the stats as python scalars in order to avoid 31 | # different device errors when loading them back 32 | state = { 33 | "_min_x": float(self._min_x), 34 | } 35 | except KeyError: 36 | warnings.warn("The previous statistics are not saved.") 37 | return state 38 | 39 | def load_state_dict(self, to_load: Dict[str, Any]) -> None: 40 | 
self._add_deferred_values() 41 | self._min_x = float(nograd(to_load["_min_x"])) 42 | 43 | def compute_min(self) -> Scalar: 44 | self._add_deferred_values() 45 | return self._min_x 46 | 47 | def compute_accumulate(self) -> Scalar: 48 | return self.compute_min() 49 | 50 | def __add__(self, other: MinSummary) -> MinSummary: 51 | s = MinSummary() 52 | s._min_x = self._min_x if self._min_x < other._min_x else other._min_x 53 | s._deferred = self._deferred + other._deferred 54 | return s 55 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/accumulate/_summary/_standard_deviation_summary.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import warnings 4 | from typing import Any, Dict 5 | 6 | import numpy 7 | import torch 8 | from pytorch_pfn_extras.reporting import Scalar, Value 9 | from pytorch_pfn_extras.training.extensions.accumulate._summary._base_summary import ( 10 | SummaryBase, 11 | ) 12 | from pytorch_pfn_extras.training.extensions.accumulate._summary._summary_utils import ( 13 | nograd, 14 | ) 15 | 16 | 17 | class StandardDeviationSummary(SummaryBase): 18 | def __init__(self) -> None: 19 | self._x: Scalar = 0.0 20 | self._x2: Scalar = 0.0 21 | self._n: Scalar = 0 22 | super().__init__() 23 | 24 | def add(self, value: Value, weight: Scalar = 1) -> None: 25 | if callable(value): 26 | self._deferred.append((value, weight)) 27 | return 28 | self._x += weight * value 29 | self._x2 += weight * value * value 30 | self._n += weight 31 | 32 | def state_dict(self) -> Dict[str, Any]: 33 | self._add_deferred_values() 34 | state = {} 35 | try: 36 | # Save the stats as python scalars in order to avoid 37 | # different device errors when loading them back 38 | state = { 39 | "_x": float(self._x), 40 | "_x2": float(self._x2), 41 | "_n": float(self._n), 42 | } 43 | except KeyError: 44 | warnings.warn("The previous statistics are not saved.") 45 | return state 46 | 47 | def load_state_dict(self, to_load: Dict[str, Any]) -> None: 48 | self._add_deferred_values() 49 | self._x = float(nograd(to_load["_x"])) 50 | self._x2 = float(nograd(to_load["_x2"])) 51 | self._n = float(nograd(to_load["_n"])) 52 | 53 | def compute_mean(self) -> Scalar: 54 | self._add_deferred_values() 55 | x, n = self._x, self._n 56 | return x / n 57 | 58 | def compute_standard_deviation(self) -> Scalar: 59 | self._add_deferred_values() 60 | x, n = self._x, self._n 61 | mean = x / n 62 | var = self._x2 / n - mean * mean 63 | if isinstance(var, torch.Tensor): 64 | return torch.sqrt(var) 65 | else: 66 | return numpy.sqrt(var) 67 | 68 | def compute_accumulate(self) -> Scalar: 69 | return self.compute_standard_deviation() 70 | 71 | def __add__( 72 | self, other: StandardDeviationSummary 73 | ) -> StandardDeviationSummary: 74 | s = StandardDeviationSummary() 75 | s._x = self._x + other._x 76 | s._x2 = self._x2 + other._x2 77 | s._n = self._n + other._n 78 | s._deferred = self._deferred + other._deferred 79 | return s 80 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/accumulate/_summary/_summary_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_pfn_extras.reporting import Scalar 3 | 4 | 5 | def nograd(value: Scalar) -> Scalar: 6 | if isinstance(value, torch.Tensor): 7 | return value.detach() 8 | return value 9 | 
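# StandardDeviationSummary above accumulates sum(w*x), sum(w*x*x) and sum(w),
# so the population standard deviation follows from Var[x] = E[x^2] - E[x]^2.
# A minimal numeric sketch of the same identity (values are illustrative):
#
#     import numpy
#     xs = numpy.array([1.0, 2.0, 3.0, 4.0])
#     n, s1, s2 = len(xs), xs.sum(), (xs * xs).sum()
#     mean = s1 / n
#     assert numpy.isclose(numpy.sqrt(s2 / n - mean * mean), xs.std())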
-------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/accumulate/_summary/_unbiased_standard_deviation_summary.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import warnings 4 | from typing import Any, Dict 5 | 6 | import numpy 7 | import torch 8 | from pytorch_pfn_extras.reporting import Scalar, Value 9 | from pytorch_pfn_extras.training.extensions.accumulate._summary._base_summary import ( 10 | SummaryBase, 11 | ) 12 | from pytorch_pfn_extras.training.extensions.accumulate._summary._summary_utils import ( 13 | nograd, 14 | ) 15 | 16 | 17 | class UnbiasedStandardDeviationSummary(SummaryBase): 18 | def __init__(self) -> None: 19 | self._x: Scalar = 0.0 20 | self._x2: Scalar = 0.0 21 | self._n: Scalar = 0 22 | super().__init__() 23 | 24 | def add(self, value: Value, weight: Scalar = 1) -> None: 25 | if callable(value): 26 | self._deferred.append((value, weight)) 27 | return 28 | self._x += weight * value 29 | self._x2 += weight * value * value 30 | self._n += weight 31 | 32 | def state_dict(self) -> Dict[str, Any]: 33 | self._add_deferred_values() 34 | state = {} 35 | try: 36 | # Save the stats as python scalars in order to avoid 37 | # different device errors when loading them back 38 | state = { 39 | "_x": float(self._x), 40 | "_x2": float(self._x2), 41 | "_n": float(self._n), 42 | } 43 | except KeyError: 44 | warnings.warn("The previous statistics are not saved.") 45 | return state 46 | 47 | def load_state_dict(self, to_load: Dict[str, Any]) -> None: 48 | self._add_deferred_values() 49 | self._x = float(nograd(to_load["_x"])) 50 | self._x2 = float(nograd(to_load["_x2"])) 51 | self._n = float(nograd(to_load["_n"])) 52 | 53 | def compute_mean(self) -> Scalar: 54 | self._add_deferred_values() 55 | x, n = self._x, self._n 56 | return x / n 57 | 58 | def compute_unbiased_standard_deviation(self) -> Scalar: 59 | self._add_deferred_values() 60 | x, n = self._x, self._n 61 | if n <= 1: 62 | return float("nan") 63 | mean = x / n 64 | var = self._x2 / n - mean * mean 65 | unbiased_var = var * (n / (n - 1)) 66 | if isinstance(unbiased_var, torch.Tensor): 67 | return torch.sqrt(unbiased_var) 68 | else: 69 | return numpy.sqrt(unbiased_var) 70 | 71 | def compute_accumulate(self) -> Scalar: 72 | return self.compute_unbiased_standard_deviation() 73 | 74 | def __add__( 75 | self, other: UnbiasedStandardDeviationSummary 76 | ) -> UnbiasedStandardDeviationSummary: 77 | s = UnbiasedStandardDeviationSummary() 78 | s._x = self._x + other._x 79 | s._x2 = self._x2 + other._x2 80 | s._n = self._n + other._n 81 | s._deferred = self._deferred + other._deferred 82 | return s 83 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/accumulate/average_accumulate.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.training.extensions.accumulate._accumulate_base import ( 2 | AccumulateBase, 3 | ) 4 | from pytorch_pfn_extras.training.extensions.accumulate._summary import ( 5 | AverageSummary, 6 | SummaryBase, 7 | ) 8 | 9 | from ._accumulate_utils import all_gather_object 10 | 11 | 12 | class AverageAccumulate(AccumulateBase): 13 | @property 14 | def _summary(self) -> SummaryBase: 15 | return self._average_summary 16 | 17 | def _init_summary(self) -> None: 18 | self._average_summary = AverageSummary() 19 | 20 | def _all_reduce_summaries(self) 
-> SummaryBase: 21 | summaries = all_gather_object(self._average_summary) 22 | all_reduced_summary = sum(filter(None, summaries), AverageSummary()) 23 | return all_reduced_summary 24 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/accumulate/max_accumulate.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.training.extensions.accumulate._accumulate_base import ( 2 | AccumulateBase, 3 | ) 4 | from pytorch_pfn_extras.training.extensions.accumulate._summary import ( 5 | MaxSummary, 6 | SummaryBase, 7 | ) 8 | 9 | from ._accumulate_utils import all_gather_object 10 | 11 | 12 | class MaxAccumulate(AccumulateBase): 13 | @property 14 | def _summary(self) -> SummaryBase: 15 | return self._max_summary 16 | 17 | def _init_summary(self) -> None: 18 | self._max_summary = MaxSummary() 19 | 20 | def _all_reduce_summaries(self) -> SummaryBase: 21 | summaries = all_gather_object(self._max_summary) 22 | all_reduced_summary = sum(filter(None, summaries), MaxSummary()) 23 | return all_reduced_summary 24 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/accumulate/min_accumulate.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.training.extensions.accumulate._accumulate_base import ( 2 | AccumulateBase, 3 | ) 4 | from pytorch_pfn_extras.training.extensions.accumulate._summary import ( 5 | MinSummary, 6 | SummaryBase, 7 | ) 8 | 9 | from ._accumulate_utils import all_gather_object 10 | 11 | 12 | class MinAccumulate(AccumulateBase): 13 | @property 14 | def _summary(self) -> SummaryBase: 15 | return self._min_summary 16 | 17 | def _init_summary(self) -> None: 18 | self._min_summary = MinSummary() 19 | 20 | def _all_reduce_summaries(self) -> SummaryBase: 21 | summaries = all_gather_object(self._min_summary) 22 | all_reduced_summary = sum(filter(None, summaries), MinSummary()) 23 | return all_reduced_summary 24 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/accumulate/standard_deviation_accumulate.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.training.extensions.accumulate._accumulate_base import ( 2 | AccumulateBase, 3 | ) 4 | from pytorch_pfn_extras.training.extensions.accumulate._summary import ( 5 | StandardDeviationSummary, 6 | SummaryBase, 7 | ) 8 | 9 | from ._accumulate_utils import all_gather_object 10 | 11 | 12 | class StandardDeviationAccumulate(AccumulateBase): 13 | @property 14 | def _summary(self) -> SummaryBase: 15 | return self._standard_deviation_summary 16 | 17 | def _init_summary(self) -> None: 18 | self._standard_deviation_summary = StandardDeviationSummary() 19 | 20 | def _all_reduce_summaries(self) -> SummaryBase: 21 | summaries = all_gather_object(self._standard_deviation_summary) 22 | all_reduced_summary = sum( 23 | filter(None, summaries), StandardDeviationSummary() 24 | ) 25 | return all_reduced_summary 26 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/accumulate/unbiased_standard_deviation_accumulate.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.training.extensions.accumulate._accumulate_base import ( 2 | AccumulateBase, 3 | ) 4 | from 
pytorch_pfn_extras.training.extensions.accumulate._summary import ( 5 | SummaryBase, 6 | UnbiasedStandardDeviationSummary, 7 | ) 8 | 9 | from ._accumulate_utils import all_gather_object 10 | 11 | 12 | class UnbiasedStandardDeviationAccumulate(AccumulateBase): 13 | @property 14 | def _summary(self) -> SummaryBase: 15 | return self._standard_deviation_summary 16 | 17 | def _init_summary(self) -> None: 18 | self._standard_deviation_summary = UnbiasedStandardDeviationSummary() 19 | 20 | def _all_reduce_summaries(self) -> SummaryBase: 21 | summaries = all_gather_object(self._standard_deviation_summary) 22 | all_reduced_summary = sum( 23 | filter(None, summaries), UnbiasedStandardDeviationSummary() 24 | ) 25 | return all_reduced_summary 26 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/fail_on_non_number.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_pfn_extras.training import extension 3 | from pytorch_pfn_extras.training._manager_protocol import ( 4 | ExtensionsManagerProtocol, 5 | ) 6 | 7 | 8 | class FailOnNonNumber(extension.Extension): 9 | """An extension to raise ``RuntimeError`` if parameters or their 10 | gradients contain NaN or Inf. 11 | 12 | Even if the parameters in a given optimizer diverge to non-numbers 13 | such as NaN or Inf, the training loop will continue to compute 14 | unnecessarily. This extension aims to cut such wasted computation 15 | short by raising ``RuntimeError`` as soon as the parameters 16 | contain NaN or Inf. 17 | 18 | Args: 19 | check_grad: Set to False to skip checking gradients. 20 | """ 21 | 22 | needs_model_state = True 23 | 24 | def __init__(self, *, check_grad: bool = True): 25 | self._check_grad = check_grad 26 | 27 | def __call__(self, manager: ExtensionsManagerProtocol) -> None: 28 | for name, model in manager.models.items(): 29 | for param in model.parameters(): 30 | if not torch.isfinite(param).all() or ( 31 | self._check_grad 32 | and param.grad is not None 33 | and not torch.isfinite(param.grad).all() 34 | ): 35 | raise RuntimeError( 36 | "Kill the process since parameters in optimizer" 37 | " '{}' diverge. R.I.P.".format(name) 38 | ) 39 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/print_report_notebook.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import IO, Any, List, Optional, Union 3 | 4 | from IPython.display import display 5 | from ipywidgets import HTML 6 | from pytorch_pfn_extras.training._manager_protocol import ( 7 | ExtensionsManagerProtocol, 8 | ) 9 | from pytorch_pfn_extras.training.extensions import ( 10 | log_report as log_report_module, 11 | ) 12 | from pytorch_pfn_extras.training.extensions.print_report import PrintReport 13 | 14 | 15 | class PrintReportNotebook(PrintReport): 16 | """An extension to print the accumulated results. 17 | 18 | It is aimed to work on Jupyter notebooks as a replacement for `PrintReport`. 19 | This extension uses the log accumulated by a :class:`LogReport` extension 20 | to print specified entries of the log in a human-readable format. 21 | 22 | Args: 23 | entries (list of str or None): List of keys of observations to print. 24 | If `None` is passed, automatically infer keys from reported dict. 25 | log_report (str or LogReport): Log report to accumulate the 26 | observations.
This is either the name of a LogReport extension 27 | registered to the manager, or a LogReport instance to use 28 | internally. 29 | out: This is not used; the argument is kept for consistency with 30 | `PrintReport`. 31 | 32 | """ 33 | 34 | def __init__( 35 | self, 36 | entries: Optional[List[str]] = None, 37 | log_report: Union[str, log_report_module.LogReport] = "LogReport", 38 | out: IO[Any] = sys.stdout, 39 | ) -> None: 40 | super(PrintReportNotebook, self).__init__( 41 | entries=entries, log_report=log_report, out=out 42 | ) 43 | self._widget = HTML() 44 | 45 | def initialize(self, manager: ExtensionsManagerProtocol) -> None: 46 | display(self._widget) # type: ignore[no-untyped-call] 47 | super(PrintReportNotebook, self).initialize(manager) 48 | 49 | @property 50 | def widget(self) -> HTML: 51 | return self._widget 52 | 53 | def __call__(self, manager: ExtensionsManagerProtocol) -> None: 54 | log_report = self.get_log_report(manager) 55 | df = log_report.to_dataframe() 56 | if self._infer_entries: 57 | # --- update entries --- 58 | self._update_entries(log_report) 59 | self._widget.value = df[self._entries].to_html(index=False, na_rep="") 60 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/slack_manifest.yml: -------------------------------------------------------------------------------- 1 | display_information: 2 | name: PPE Slack Extension 3 | features: 4 | bot_user: 5 | display_name: Training Report 6 | always_online: true 7 | oauth_config: 8 | scopes: 9 | bot: 10 | - chat:write 11 | - files:write 12 | settings: 13 | org_deploy_enabled: false 14 | socket_mode_enabled: false 15 | token_rotation_enabled: false 16 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/snapshot_writers.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.writing import * # NOQA 2 | 3 | # TODO(ecastill) deprecate this 4 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/extensions/value_observation.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | import torch.optim 4 | from pytorch_pfn_extras.training import extension 5 | from pytorch_pfn_extras.training._manager_protocol import ( 6 | ExtensionsManagerProtocol, 7 | ) 8 | 9 | 10 | def observe_value( 11 | observation_key: str, 12 | target_func: Callable[[ExtensionsManagerProtocol], Any], 13 | ) -> Callable[[ExtensionsManagerProtocol], None]: 14 | """Returns an extension to continuously record a value. 15 | 16 | Args: 17 | observation_key (str): Key of observation to record. 18 | target_func (function): Function that returns the value to record. 19 | It must take one argument: the 20 | :class:`~pytorch_pfn_extras.training.ExtensionsManager` object. 21 | Returns: 22 | The extension function. 23 | 24 | This extension is triggered each epoch by default. 25 | To change this, use the ``trigger`` argument with the 26 | :meth:`ExtensionsManager.extend() 27 | <pytorch_pfn_extras.training.ExtensionsManager.extend>` method.
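    A minimal sketch (the observation key and the lambda body are
    illustrative)::

        manager.extend(
            observe_value("custom_value", lambda m: m.iteration * 0.5),
            trigger=(1, "iteration"),
        )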
28 | 29 | """ 30 | 31 | @extension.make_extension( 32 | trigger=(1, "epoch"), priority=extension.PRIORITY_WRITER 33 | ) 34 | def _observe_value(manager: ExtensionsManagerProtocol) -> None: 35 | manager.observation[observation_key] = target_func(manager) 36 | 37 | return _observe_value 38 | 39 | 40 | def observe_lr( 41 | optimizer: torch.optim.Optimizer, 42 | param_group: int = 0, 43 | observation_key: str = "lr", 44 | ) -> Any: 45 | """Returns an extension to record the learning rate. 46 | 47 | Args: 48 | optimizer (Optimizer): Optimizer whose learning rate is 49 | recorded. 50 | param_group (int): Param group of the optimizer to observe 51 | observation_key (str): Key of observation to record. 52 | 53 | Returns: 54 | The extension function. 55 | 56 | This extension is triggered each epoch by default. 57 | To change this, use the ``trigger`` argument with the 58 | :meth:`ExtensionsManager.extend() ` method. 60 | 61 | """ 62 | return observe_value( 63 | observation_key, 64 | lambda manager: optimizer.param_groups[param_group]["lr"], 65 | ) 66 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Tuple 2 | 3 | import torch 4 | 5 | Batch = Dict[str, torch.Tensor] 6 | MetricType = Callable[[Batch, Batch], Batch] 7 | 8 | 9 | class AccuracyMetric: 10 | """A metric for an evaluator to report accuracy. 11 | 12 | Args: 13 | label_key: The key name of label. 14 | output_key: The key name of prediction. 15 | 16 | .. seealso: 17 | :func:`pytorch_pfn_extras.engine.create_evaluator` 18 | """ 19 | 20 | def __init__(self, label_key: str, output_key: str) -> None: 21 | self.label_key = label_key 22 | self.output_key = output_key 23 | 24 | def _preprocess_input( 25 | self, batch: Batch, out: Batch 26 | ) -> Tuple[torch.Tensor, int, torch.Tensor]: 27 | labels = batch[self.label_key].cpu() 28 | n_output = labels.shape[0] 29 | pred = out[self.output_key][:n_output].cpu() 30 | return labels, n_output, pred 31 | 32 | def __call__(self, batch: Batch, out: Batch) -> Dict[str, Any]: 33 | with torch.no_grad(): # type: ignore[no-untyped-call] 34 | labels, n_output, pred = self._preprocess_input(batch, out) 35 | correct = (labels.view_as(pred) == pred).sum().item() 36 | return {"accuracy": correct / n_output} 37 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/trigger.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.training._trigger_util import Trigger # NOQA 2 | from pytorch_pfn_extras.training._trigger_util import TriggerFunc # NOQA 3 | from pytorch_pfn_extras.training._trigger_util import TriggerLike # NOQA 4 | from pytorch_pfn_extras.training._trigger_util import get_trigger # NOQA 5 | from pytorch_pfn_extras.training._trigger_util import ( # NOQA 6 | _never_fire_trigger, 7 | ) 8 | 9 | # For backward compatibility 10 | from pytorch_pfn_extras.training.triggers.interval_trigger import ( # NOQA 11 | IntervalTrigger, 12 | ) 13 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/triggers/__init__.py: -------------------------------------------------------------------------------- 1 | # import classes and functions 2 | from pytorch_pfn_extras.training.triggers.early_stopping_trigger import ( # NOQA 3 | EarlyStoppingTrigger, 4 | ) 5 | from 
pytorch_pfn_extras.training.triggers.function_trigger import ( # NOQA 6 | FunctionTrigger, 7 | ) 8 | from pytorch_pfn_extras.training.triggers.interval_trigger import ( # NOQA 9 | IntervalTrigger, 10 | ) 11 | from pytorch_pfn_extras.training.triggers.manual_schedule_trigger import ( # NOQA 12 | ManualScheduleTrigger, 13 | ) 14 | from pytorch_pfn_extras.training.triggers.minmax_value_trigger import ( # NOQA 15 | BestValueTrigger, 16 | MaxValueTrigger, 17 | MinValueTrigger, 18 | ) 19 | from pytorch_pfn_extras.training.triggers.once_trigger import ( # NOQA 20 | OnceTrigger, 21 | ) 22 | from pytorch_pfn_extras.training.triggers.time_trigger import ( # NOQA 23 | TimeTrigger, 24 | ) 25 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/triggers/function_trigger.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | TYPE_CHECKING, 3 | Any, 4 | Callable, 5 | Dict, 6 | Mapping, 7 | Optional, 8 | Sequence, 9 | ) 10 | 11 | from pytorch_pfn_extras.training import trigger as trigger_module 12 | from pytorch_pfn_extras.training._manager_protocol import ( 13 | ExtensionsManagerProtocol, 14 | ) 15 | 16 | if TYPE_CHECKING: 17 | from pytorch_pfn_extras.training._trigger_util import TriggerLike 18 | 19 | 20 | class FunctionTrigger(trigger_module.Trigger): 21 | def __init__( 22 | self, 23 | fn: Callable[..., bool], 24 | args: Optional[Sequence[Any]] = None, 25 | kwargs: Optional[Mapping[str, Any]] = None, 26 | trigger: "TriggerLike" = (1, "iteration"), 27 | ) -> None: 28 | self._fn = fn 29 | self._args = args or [] 30 | self._kwargs = kwargs or {} 31 | self._interval_trigger = trigger_module.get_trigger(trigger) 32 | 33 | def __call__(self, manager: ExtensionsManagerProtocol) -> bool: 34 | if not self._interval_trigger(manager): 35 | return False 36 | 37 | return self._fn(*self._args, **self._kwargs) 38 | 39 | def state_dict(self) -> Dict[str, Any]: 40 | state = { 41 | "interval_trigger": self._interval_trigger.state_dict(), 42 | } 43 | return state 44 | 45 | def load_state_dict(self, to_load: Dict[str, Any]) -> None: 46 | self._interval_trigger.load_state_dict(to_load["interval_trigger"]) 47 | 48 | def may_fire(self, iteration: int, epoch_len: int) -> bool: 49 | if self._interval_trigger.may_fire( 50 | iteration=iteration, epoch_len=epoch_len 51 | ): 52 | return self._fn(*self._args, **self._kwargs) 53 | else: 54 | return False 55 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/triggers/manual_schedule_trigger.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Sequence, Union 2 | 3 | from pytorch_pfn_extras.training import trigger 4 | from pytorch_pfn_extras.training._manager_protocol import ( 5 | ExtensionsManagerProtocol, 6 | ) 7 | 8 | if TYPE_CHECKING: 9 | from pytorch_pfn_extras.training._trigger_util import UnitLiteral 10 | 11 | 12 | class ManualScheduleTrigger(trigger.Trigger): 13 | """Trigger invoked at specified point(s) of iterations or epochs. 14 | 15 | This trigger accepts iterations or epochs indicated by given point(s). 16 | There are two ways to specify the point(s): iteration and epoch. 17 | ``iteration`` means the number of updates, while ``epoch`` means the number 18 | of sweeps over the training dataset. 
Fractional values are allowed 19 | if the point is a number of epochs; the trigger uses the ``iteration`` 20 | and ``epoch_detail`` attributes defined by the manager. 21 | 22 | Args: 23 | points (int, float, or list of int or float): Time of the trigger. 24 | Must be an integer or a list of integers if unit is ``'iteration'``. 25 | unit (str): Unit of the time specified by ``points``. It must be 26 | either ``'iteration'`` or ``'epoch'``. 27 | 28 | """ 29 | 30 | def __init__( 31 | self, points: Union[float, Sequence[float]], unit: "UnitLiteral" 32 | ): 33 | if unit not in ("epoch", "iteration"): 34 | raise ValueError( 35 | "Trigger unit must be either 'epoch' or 'iteration'." 36 | ) 37 | 38 | self.points = points if isinstance(points, list) else [points] 39 | self.unit = unit 40 | 41 | def __call__(self, manager: ExtensionsManagerProtocol) -> bool: 42 | """Decides whether the extension should be called on this iteration. 43 | 44 | Args: 45 | manager (~pytorch_pfn_extras.training.ExtensionsManager): 46 | Manager object that this trigger is associated with. 47 | The iteration information in this manager is used to 48 | determine if the trigger should fire. 49 | 50 | Returns: 51 | bool: True if the corresponding extension should be invoked in this 52 | iteration. 53 | 54 | """ 55 | fire = self.may_fire(manager.iteration, manager._iters_per_epoch) 56 | return fire 57 | 58 | def may_fire(self, iteration: int, epoch_length: int) -> bool: 59 | if self.unit == "epoch": 60 | fire = any(int(p * epoch_length) == iteration for p in self.points) 61 | else: 62 | fire = any(p == iteration for p in self.points) 63 | return fire 64 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/triggers/once_trigger.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from pytorch_pfn_extras.training import trigger 4 | from pytorch_pfn_extras.training._manager_protocol import ( 5 | ExtensionsManagerProtocol, 6 | ) 7 | 8 | 9 | class OnceTrigger(trigger.Trigger): 10 | """Trigger that fires only once, at the start of training. 11 | 12 | By default this trigger fires only on the very first call in a whole 13 | training run; with ``call_on_resume`` it fires once more on the first 14 | call after training is resumed from a snapshot. 15 | 16 | Args: 17 | call_on_resume (bool): Whether the extension is called again or not 18 | when restored from a snapshot. It is set to ``False`` by default. 19 | 20 | Attributes: 21 | finished (bool): Flag that indicates whether or not this trigger will 22 | fire in the future. This flag is used to determine if the extension 23 | should be initialized after resume.
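    A minimal sketch (illustrative)::

        trigger = OnceTrigger(call_on_resume=True)
        # Fires on the first call of a fresh run, and once more on the
        # first call after resuming from a snapshot.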
24 | 25 | """ 26 | 27 | def __init__(self, call_on_resume: bool = False) -> None: 28 | self._flag_first = True 29 | self._flag_resumed = call_on_resume 30 | 31 | @property 32 | def finished(self) -> bool: 33 | return not (self._flag_first or self._flag_resumed) 34 | 35 | def __call__(self, manager: ExtensionsManagerProtocol) -> bool: 36 | fire = not self.finished 37 | self._flag_resumed = False 38 | self._flag_first = False 39 | return fire 40 | 41 | def state_dict(self) -> Dict[str, Any]: 42 | state = {"_flag_first": self._flag_first} 43 | return state 44 | 45 | def load_state_dict(self, to_load: Dict[str, Any]) -> None: 46 | self._flag_first = to_load["_flag_first"] 47 | 48 | def may_fire(self, iteration: int, epoch_length: int) -> bool: 49 | return not (self._flag_first or self._flag_resumed) 50 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/training/triggers/time_trigger.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from pytorch_pfn_extras.training import trigger 4 | from pytorch_pfn_extras.training._manager_protocol import ( 5 | ExtensionsManagerProtocol, 6 | ) 7 | 8 | 9 | class TimeTrigger(trigger.Trigger): 10 | """Trigger based on a fixed time interval. 11 | 12 | This trigger accepts iterations with a given interval time. 13 | 14 | Args: 15 | period (float): Interval time. It is given in seconds. 16 | 17 | """ 18 | 19 | def __init__(self, period: float) -> None: 20 | self._period = period 21 | self._next_time = self._period 22 | 23 | def __call__(self, manager: ExtensionsManagerProtocol) -> bool: 24 | if self._next_time < manager.elapsed_time: 25 | self._next_time += self._period 26 | return True 27 | else: 28 | return False 29 | 30 | def state_dict(self) -> Dict[str, Any]: 31 | state = {"next_time": self._next_time} 32 | return state 33 | 34 | def load_state_dict(self, to_load: Dict[str, Any]) -> None: 35 | self._next_time = to_load["next_time"] 36 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.utils import checkpoint # NOQA 2 | from pytorch_pfn_extras.utils import comparer # NOQA 3 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch.nn 4 | import torch.utils.checkpoint 5 | 6 | 7 | class _CheckpointFunction(torch.utils.checkpoint.CheckpointFunction): 8 | """Checkpoint a model or part of the model with BN support. 9 | 10 | Refer to https://pytorch.org/docs/stable/checkpoint.html 11 | for detailed information. 12 | When using checkpointing in model using BatchNormalization, the 13 | momentum is updated twice, while we only need one update to ensure 14 | correctness. 15 | Using `ppe.utils.checkpointing.checkpoint` as a drop-in replacement 16 | can help deal with incorrect values in the BatchNormalization 17 | persistent parameters. 
18 | """ 19 | 20 | @staticmethod 21 | def forward( # type: ignore[override] 22 | ctx: Any, 23 | run_function: Any, 24 | preserve_rng_state: bool, 25 | *args: Any, 26 | ) -> Any: 27 | _patch_bn_momentum(run_function) 28 | return super(_CheckpointFunction, _CheckpointFunction).forward( 29 | ctx, run_function, preserve_rng_state, *args 30 | ) 31 | 32 | 33 | def _patch_bn_momentum(module: torch.nn.Module) -> None: 34 | if not hasattr(module, "_bn_momentum_patched"): 35 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): 36 | return 37 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 38 | # Set momentum so that two forward passes will produce the same 39 | # EMA as one forward pass. 40 | if module.momentum is not None: 41 | module.momentum = 1 - (1 - module.momentum) ** 0.5 42 | else: 43 | # NOTE(linsho): 44 | # In the case of cumulative moving average mode, the operation is 45 | # equivalent to not using checkpoints even if you do nothing. 46 | pass 47 | for _, child in module.named_children(): 48 | _patch_bn_momentum(child) 49 | module._bn_momentum_patched = True # type: ignore[assignment] 50 | 51 | 52 | def checkpoint(function: torch.nn.Module, *args: Any, **kwargs: Any) -> Any: 53 | # Hack to mix *args with **kwargs in a python 2.7-compliant way 54 | preserve = kwargs.pop("preserve_rng_state", True) 55 | if kwargs: 56 | raise ValueError( 57 | "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) 58 | ) 59 | return _CheckpointFunction.apply(function, preserve, *args) # type: ignore[no-untyped-call] 60 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/writing/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.writing._parallel_writer import ProcessWriter # NOQA 2 | from pytorch_pfn_extras.writing._parallel_writer import ThreadWriter # NOQA 3 | from pytorch_pfn_extras.writing._queue_writer import ProcessQueueWriter # NOQA 4 | from pytorch_pfn_extras.writing._queue_writer import QueueWriter # NOQA 5 | from pytorch_pfn_extras.writing._queue_writer import ThreadQueueWriter # NOQA 6 | from pytorch_pfn_extras.writing._simple_writer import SimpleWriter # NOQA 7 | from pytorch_pfn_extras.writing._tensorboard_writer import ( # NOQA 8 | TensorBoardWriter, 9 | ) 10 | from pytorch_pfn_extras.writing._writer_base import StandardWriter # NOQA 11 | from pytorch_pfn_extras.writing._writer_base import Writer # NOQA 12 | -------------------------------------------------------------------------------- /pytorch_pfn_extras/writing/_simple_writer.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | import torch 4 | from pytorch_pfn_extras.writing._writer_base import ( 5 | Writer, 6 | _FileSystem, 7 | _SaveFun, 8 | _TargetType, 9 | ) 10 | 11 | 12 | class SimpleWriter(Writer): 13 | """The most simple snapshot writer. 14 | 15 | This class just passes the arguments to the actual saving function. 16 | 17 | Args: 18 | savefun: Callable object. It takes three arguments: the output file 19 | path, the serialized dictionary object, and the optional keyword 20 | arguments. 21 | fs: FileSystem abstracting interface to implement all the operations. 22 | optional, defaults to None 23 | out_dir: str. Specifies the directory this writer will use. 
24 | It takes precedence over the one specified in `__call__` 25 | optional, defaults to ``''`` 26 | kwds: Keyword arguments for the ``savefun``. 27 | 28 | .. seealso:: 29 | 30 | - :meth:`pytorch_pfn_extras.training.extensions.snapshot` 31 | """ 32 | 33 | def __init__( 34 | self, 35 | savefun: _SaveFun = torch.save, 36 | fs: _FileSystem = None, 37 | out_dir: str = "", 38 | **kwds: Any, 39 | ) -> None: 40 | super().__init__(fs=fs, out_dir=out_dir) 41 | self._savefun = savefun 42 | self._kwds = kwds 43 | 44 | def __call__( 45 | self, 46 | filename: str, 47 | out_dir: str, 48 | target: _TargetType, 49 | *, 50 | savefun: Optional[_SaveFun] = None, 51 | append: bool = False, 52 | ) -> None: 53 | if savefun is None: 54 | savefun = self._savefun 55 | self.save(filename, out_dir, target, savefun, append, **self._kwds) 56 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # automatically generated by pysen 3 | # pysen ignores and overwrites any modifications 4 | # e203: black treats : as a binary operator 5 | # e231: black doesn't put a space after , 6 | # e501: black may exceed the line-length to follow other style rules 7 | # e701: black will collapse ... only functions etc. to a single line 8 | # e704: black will collapse ... only functions etc. to a single line 9 | # w503 or w504: either one needs to be disabled to select w error codes 10 | ignore = E203,E231,E501,E701,E704,W503 11 | max-line-length = 80 12 | select = B,B950,C,E,F,W 13 | 14 | [mypy] 15 | # automatically generated by pysen 16 | # pysen ignores and overwrites any modifications 17 | check_untyped_defs = True 18 | disallow_any_decorated = False 19 | disallow_any_generics = False 20 | disallow_any_unimported = False 21 | disallow_incomplete_defs = True 22 | disallow_subclassing_any = True 23 | disallow_untyped_calls = True 24 | disallow_untyped_decorators = False 25 | disallow_untyped_defs = True 26 | ignore_errors = False 27 | ignore_missing_imports = True 28 | mypy_path = stubs 29 | no_implicit_optional = True 30 | python_version = 3.9 31 | show_error_codes = True 32 | strict_equality = True 33 | strict_optional = True 34 | warn_redundant_casts = True 35 | warn_return_any = True 36 | warn_unreachable = True 37 | warn_unused_configs = True 38 | warn_unused_ignores = False 39 | 40 | [mypy-torch.*] 41 | # automatically generated by pysen 42 | # pysen ignores and overwrites any modifications 43 | ignore_errors = True 44 | 45 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import setuptools 4 | 5 | here = os.path.abspath(os.path.dirname(__file__)) 6 | # Get __version__ variable 7 | exec(open(os.path.join(here, "pytorch_pfn_extras", "_version.py")).read()) 8 | 9 | long_description = open(os.path.join(here, "README.md")).read() 10 | 11 | setuptools.setup( 12 | name="pytorch-pfn-extras", 13 | version=__version__, # NOQA 14 | description="Supplementary components to accelerate research and " 15 | "development in PyTorch.", 16 | long_description=long_description, 17 | long_description_content_type="text/markdown", 18 | author="Preferred Networks, Inc.", 19 | license="MIT License", 20 | install_requires=["numpy", "packaging", "torch", "typing-extensions>=3.10"], 21 | extras_require={ 22 | "onnx": ["onnx"], 23 | }, 24 | python_requires=">=3.9.0", 25 | 
packages=setuptools.find_packages(exclude=["tests", "tests.*"]), 26 | package_data={"pytorch_pfn_extras": ["py.typed"]}, 27 | ) 28 | -------------------------------------------------------------------------------- /stubs/torch/fx/__init__.pyi: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .graph import Graph as Graph 3 | from .graph_module import GraphModule as GraphModule 4 | from .node import Node as Node 5 | from .proxy import GraphAppendingTracer as GraphAppendingTracer 6 | from .proxy import Proxy as Proxy 7 | -------------------------------------------------------------------------------- /stubs/torch/fx/graph.pyi: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from typing import Any, Callable, Dict, List, Optional, Tuple 3 | 4 | from .node import Node 5 | 6 | class Graph: 7 | nodes: List[Node] 8 | 9 | def call_function( 10 | self, 11 | the_function: Callable[..., Any], 12 | args: Optional[Any] = None, 13 | kwargs: Optional[Any] = None, 14 | type_expr: Optional[Any] = None, 15 | ) -> Node: ... 16 | def output(self, result: Any, type_expr: Optional[Any] = None) -> Node: ... 17 | def placeholder( 18 | self, 19 | name: str, 20 | type_expr: Optional[Any] = None, 21 | default_value: Optional[Any] = None, 22 | ) -> Node: ... 23 | def node_copy( 24 | self, node: Node, arg_transform: Callable[[Node], Any] 25 | ) -> Node: ... 26 | def inserting_after(self, n: Optional[Node] = None): ... 27 | def create_node( 28 | self, 29 | op: str, 30 | target: Any, 31 | args: Optional[Tuple[Any, ...]] = None, 32 | kwargs: Optional[Dict[str, Any]] = None, 33 | name: Optional[str] = None, 34 | type_expr: Optional[Any] = None, 35 | ) -> Node: ... 36 | def erase_node(self, to_erase: Node) -> None: ... 37 | ... 38 | -------------------------------------------------------------------------------- /stubs/torch/fx/graph_module.pyi: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from typing import Any, Dict, Union 3 | 4 | import torch 5 | 6 | from .graph import Graph 7 | 8 | class GraphModule: 9 | graph: Graph 10 | 11 | def __init__( 12 | self, 13 | root: Union[torch.nn.Module, Dict[str, Any]], 14 | graph: Graph, 15 | class_name: str = "GraphModule", 16 | ): ... 17 | ... 18 | -------------------------------------------------------------------------------- /stubs/torch/fx/node.pyi: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from typing import Any, Callable, Dict, List, Optional, Tuple 3 | 4 | class Node: 5 | op: str 6 | name: str 7 | target: Any 8 | meta: Any 9 | args: Tuple[Any, ...] 10 | kwargs: Dict[str, Any] 11 | ... 12 | -------------------------------------------------------------------------------- /stubs/torch/fx/proxy.pyi: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from typing import Any 3 | 4 | from .graph import Graph, Node 5 | 6 | class GraphAppendingTracer: 7 | def __init__(self, graph: Graph): ... 8 | 9 | class Proxy: 10 | def __init__(self, node: Node, tracer: GraphAppendingTracer): ...
11 | -------------------------------------------------------------------------------- /stubs/torch/library/__init__.pyi: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from typing import Any, Callable 3 | 4 | class Library: 5 | def __init__(self, ns: str, kind: str, dispatch_key: str = "") -> None: ... 6 | def impl( 7 | self, name: str, fn: Callable[..., Any], dispatch_key: str = "" 8 | ) -> None: ... 9 | def define(self, name: str) -> None: ... 10 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | try: 2 | # Make sure that onnx is imported before importing torch in the test run. 3 | import onnx # NOQA 4 | except Exception: 5 | pass 6 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/tests/pytorch_pfn_extras_tests/__init__.py -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/cuda_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/tests/pytorch_pfn_extras_tests/cuda_tests/__init__.py -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/dataloader_test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/tests/pytorch_pfn_extras_tests/dataloader_test/__init__.py -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/dataloader_test/test_dataloader.py: -------------------------------------------------------------------------------- 1 | import pytorch_pfn_extras as ppe 2 | import torch 3 | 4 | 5 | class DummyDataset(torch.utils.data.Dataset): 6 | def __init__(self): 7 | self.data = list(range(10)) 8 | 9 | def __len__(self): 10 | return len(self.data) 11 | 12 | def __getitem__(self, idx): 13 | if torch.is_tensor(idx): 14 | idx = idx.tolist() 15 | # The persistent workers always maintain the original 16 | # dataset through the dataloader lifetime 17 | # so the attributes will remain the same as the 18 | # first time the workers were spawned (dataloader iteration) 19 | assert self.start == 0 20 | return self.data[idx] 21 | 22 | 23 | def test_data_loader_persistent(): 24 | dataset = DummyDataset() 25 | dataloader = ppe.dataloaders.DataLoader( 26 | dataset, num_workers=1, persistent_workers=True 27 | ) 28 | dataset.start = 0 29 | for i in range(10): 30 | for _ in dataloader: 31 | pass 32 | # Changing the start value here doesn't have any effect in the dataset 33 | # cached by the workers,
since they are not recreated between epochs 34 | # and can cache values safely 35 | dataset.start = i 36 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/dataset_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/tests/pytorch_pfn_extras_tests/dataset_tests/__init__.py -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/__init__.py -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/dummy_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytorch_pfn_extras as ppe 3 | 4 | 5 | class DummyDataset(ppe.dataset.TabularDataset): 6 | def __init__( 7 | self, 8 | size=10, 9 | keys=("a", "b", "c"), 10 | mode=tuple, 11 | return_array=False, 12 | callback=None, 13 | convert=False, 14 | ): 15 | if mode is None: 16 | keys = (keys[0],) 17 | 18 | self._keys = keys 19 | self._mode = mode 20 | self._return_array = return_array 21 | self._callback = callback 22 | self._convert = convert 23 | 24 | self.data = np.random.uniform(size=(len(keys), size)) 25 | 26 | def __len__(self): 27 | return self.data.shape[1] 28 | 29 | @property 30 | def keys(self): 31 | return self._keys 32 | 33 | @property 34 | def mode(self): 35 | return self._mode 36 | 37 | def get_examples(self, indices, key_indices): 38 | if self._callback: 39 | self._callback(indices, key_indices) 40 | 41 | data = self.data 42 | if indices is not None: 43 | data = data[:, indices] 44 | if key_indices is not None: 45 | data = data[list(key_indices)] 46 | 47 | if self._return_array: 48 | return tuple(data) 49 | else: 50 | return tuple(list(d) for d in data) 51 | 52 | def convert(self, data): 53 | if self._convert: 54 | return "converted" 55 | else: 56 | return super(DummyDataset, self).convert(data) 57 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_asmode.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pytorch_pfn_extras as ppe 3 | from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import ( 4 | dummy_dataset, # NOQA 5 | ) 6 | 7 | 8 | @pytest.mark.parametrize("mode", [tuple, dict, None]) 9 | def test_astuple(mode): 10 | dataset = dummy_dataset.DummyDataset(mode=mode, convert=True) 11 | view = dataset.astuple() 12 | assert isinstance(view, ppe.dataset.TabularDataset) 13 | assert len(view) == len(dataset) 14 | assert view.keys == dataset.keys 15 | assert view.mode == tuple 16 | assert view.get_examples(None, None) == dataset.get_examples(None, None) 17 | assert view.convert(view.fetch()) == "converted" 18 | 19 | 20 | @pytest.mark.parametrize("mode", [tuple, dict, None]) 21 | def test_asdict(mode): 22 | dataset = dummy_dataset.DummyDataset(mode=mode, convert=True) 23 | view = dataset.asdict() 24 | assert isinstance(view, ppe.dataset.TabularDataset) 25 | assert len(view) ==
len(dataset) 26 | assert view.keys == dataset.keys 27 | assert view.mode == dict 28 | assert view.get_examples(None, None) == dataset.get_examples(None, None) 29 | assert view.convert(view.fetch()) == "converted" 30 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_delegate_dataset.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pytorch_pfn_extras as ppe 3 | from pytorch_pfn_extras.dataset import tabular 4 | from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import ( 5 | dummy_dataset, # NOQA 6 | ) 7 | 8 | 9 | @pytest.mark.parametrize("mode", [tuple, dict, None]) 10 | def test_delegate_dataset(mode): 11 | dataset = tabular.DelegateDataset(dummy_dataset.DummyDataset(mode=mode)) 12 | 13 | assert isinstance(dataset, ppe.dataset.TabularDataset) 14 | assert len(dataset) == len(dataset.dataset) 15 | assert dataset.keys == dataset.dataset.keys 16 | assert dataset.mode == dataset.dataset.mode 17 | assert dataset.get_example(3) == dataset.dataset.get_example(3) 18 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_converter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import pytorch_pfn_extras as ppe 4 | from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import ( 5 | dummy_dataset, # NOQA 6 | ) 7 | 8 | 9 | @pytest.mark.parametrize("mode", [tuple, dict, None]) 10 | def test_with_converter(mode): 11 | dataset = dummy_dataset.DummyDataset(mode=mode) 12 | 13 | def converter(*args, **kwargs): 14 | if mode is tuple: 15 | np.testing.assert_equal(args, tuple(dataset.data)) 16 | assert kwargs == {} 17 | elif mode is dict: 18 | assert args == () 19 | np.testing.assert_equal( 20 | kwargs, dict(zip(("a", "b", "c"), dataset.data)) 21 | ) 22 | elif mode is None: 23 | np.testing.assert_equal(args, tuple(dataset.data)) 24 | assert kwargs == {} 25 | 26 | return "converted" 27 | 28 | view = dataset.with_converter(converter) 29 | assert isinstance(view, ppe.dataset.TabularDataset) 30 | assert len(view) == len(dataset) 31 | assert view.keys == dataset.keys 32 | assert view.mode == dataset.mode 33 | assert view.get_examples(None, None) == dataset.get_examples(None, None) 34 | assert view.convert(view.fetch()) == "converted" 35 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_torch_dataloader.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import ( # NOQA 4 | dummy_dataset, 5 | ) 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "batch_size,mode", 10 | [(1, dict), (2, dict), (8, dict), (1, tuple), (2, tuple), (8, tuple)], 11 | ) 12 | def test_with_dataloader(batch_size, mode): 13 | size = 10 14 | keys = ("a", "b", "c") 15 | dataset = dummy_dataset.DummyDataset(size=size, keys=keys, mode=mode) 16 | expected = torch.tensor(dataset.data).type(torch.float64) 17 | expected_per_key = [ 18 | [ 19 | expected[i, j * batch_size : (j + 1) * batch_size] 20 | for j in range((size + batch_size - 1) // batch_size) 21 | ] 22 | for i in range(len(keys)) 23 | ] 24 | print(expected_per_key) 25 | dataloader = torch.utils.data.DataLoader(dataset, 
batch_size=batch_size) 26 | for i, example in enumerate(dataloader): 27 | for j, key in enumerate(keys): 28 | assert torch.allclose( 29 | expected_per_key[j][i], example[key if mode == dict else j] 30 | ) 31 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/dataset_tests/test_shared_dataset.py: -------------------------------------------------------------------------------- 1 | import pytorch_pfn_extras as ppe 2 | import torch 3 | 4 | 5 | class DummySharedDataset(ppe.dataset.SharedDataset): 6 | def __init__(self): 7 | self.data = torch.arange(100).reshape(100, 1) 8 | super().__init__(self.data.shape) 9 | 10 | def __getitem__(self, idx): 11 | try: 12 | x = super().__getitem__(idx) 13 | except ppe.dataset.ItemNotFoundException: 14 | x = self.data[idx] 15 | self.cache_item(idx, x) 16 | return x 17 | 18 | def __len__(self): 19 | return len(self.data) 20 | 21 | 22 | def test_empty_shared_dataset(): 23 | dataset = DummySharedDataset() 24 | for i in range(100): 25 | assert not dataset.is_cached(i) 26 | 27 | 28 | def test_shared_dataset(): 29 | dataset = DummySharedDataset() 30 | dataloader = torch.utils.data.DataLoader(dataset, num_workers=0) 31 | for _ in dataloader: 32 | pass 33 | for i in range(100): 34 | assert dataset.is_cached(i) 35 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/distributed_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/tests/pytorch_pfn_extras_tests/distributed_tests/__init__.py -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/distributed_tests/test_distributed_subset_indices.py: -------------------------------------------------------------------------------- 1 | from pytorch_pfn_extras.distributed import create_distributed_subset_indices 2 | 3 | 4 | def test_not_shuffle() -> None: 5 | indices0 = create_distributed_subset_indices( 6 | num_total_samples=10, 7 | num_replicas=3, 8 | rank=0, 9 | shuffle=False, 10 | ) 11 | assert indices0 == [0, 1, 2, 3] 12 | 13 | indices1 = create_distributed_subset_indices( 14 | num_total_samples=10, 15 | num_replicas=3, 16 | rank=1, 17 | shuffle=False, 18 | ) 19 | assert indices1 == [3, 4, 5, 6] 20 | 21 | indices2 = create_distributed_subset_indices( 22 | num_total_samples=10, 23 | num_replicas=3, 24 | rank=2, 25 | shuffle=False, 26 | ) 27 | assert indices2 == [6, 7, 8, 9] 28 | 29 | 30 | def test_shuffle() -> None: 31 | indices0 = create_distributed_subset_indices( 32 | num_total_samples=10, num_replicas=3, rank=0, shuffle=True, seed=0 33 | ) 34 | indices1 = create_distributed_subset_indices( 35 | num_total_samples=10, num_replicas=3, rank=1, shuffle=True, seed=0 36 | ) 37 | indices2 = create_distributed_subset_indices( 38 | num_total_samples=10, num_replicas=3, rank=2, shuffle=True, seed=0 39 | ) 40 | assert len(indices0 + indices1 + indices2) == 12 41 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/handler_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/tests/pytorch_pfn_extras_tests/handler_tests/__init__.py
-------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/nn_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/tests/pytorch_pfn_extras_tests/nn_tests/__init__.py -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/__init__.py -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_batchnorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_pfn_extras.nn import ( # NOQA 3 | LazyBatchNorm1d, 4 | LazyBatchNorm2d, 5 | LazyBatchNorm3d, 6 | ) 7 | from pytorch_pfn_extras_tests.nn_tests.modules_tests.test_lazy import ( 8 | LazyTestBase, 9 | ) 10 | from torch import nn 11 | 12 | 13 | class TestLazyBatchNorm1d(LazyTestBase): 14 | def get_original_module(self): 15 | return nn.BatchNorm1d(10) 16 | 17 | def get_lazy_module(self): 18 | return LazyBatchNorm1d(None) 19 | 20 | def get_input(self): 21 | return torch.rand(10, 10) 22 | 23 | 24 | class TestLazyBatchNorm2d(LazyTestBase): 25 | def get_original_module(self): 26 | return nn.BatchNorm2d(10) 27 | 28 | def get_lazy_module(self): 29 | return LazyBatchNorm2d(None) 30 | 31 | def get_input(self): 32 | return torch.rand(10, 10, 10, 10) 33 | 34 | 35 | class TestLazyBatchNorm3d(LazyTestBase): 36 | def get_original_module(self): 37 | return nn.BatchNorm3d(10) 38 | 39 | def get_lazy_module(self): 40 | return LazyBatchNorm3d(None) 41 | 42 | def get_input(self): 43 | return torch.rand(10, 10, 10, 10, 10) 44 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_pfn_extras.nn import LazyConv1d, LazyConv2d, LazyConv3d 3 | from pytorch_pfn_extras_tests.nn_tests.modules_tests.test_lazy import ( 4 | LazyTestBase, 5 | ) 6 | from torch import nn 7 | 8 | 9 | class TestLazyConv1d(LazyTestBase): 10 | def get_original_module(self): 11 | return nn.Conv1d(3, 4, 2) 12 | 13 | def get_lazy_module(self): 14 | return LazyConv1d(None, 4, 2) 15 | 16 | def get_input(self): 17 | return torch.rand(4, 3, 10) 18 | 19 | 20 | class TestLazyConv2d(LazyTestBase): 21 | def get_original_module(self): 22 | return nn.Conv2d(3, 4, 2) 23 | 24 | def get_lazy_module(self): 25 | return LazyConv2d(None, 4, 2) 26 | 27 | def get_input(self): 28 | return torch.rand(4, 3, 10, 10) 29 | 30 | 31 | class TestLazyConv3d(LazyTestBase): 32 | def get_original_module(self): 33 | return nn.Conv3d(3, 4, 2) 34 | 35 | def get_lazy_module(self): 36 | return LazyConv3d(None, 4, 2) 37 | 38 | def get_input(self): 39 | return torch.rand(4, 3, 10, 10, 10) 40 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from 
pytorch_pfn_extras.nn import LazyLinear 3 | from pytorch_pfn_extras_tests.nn_tests.modules_tests.test_lazy import ( 4 | LazyTestBase, 5 | ) 6 | from torch import nn 7 | 8 | 9 | class TestLazyLinear(LazyTestBase): 10 | def get_original_module(self): 11 | return nn.Linear(10, 20) 12 | 13 | def get_lazy_module(self): 14 | return LazyLinear(None, 20) 15 | 16 | def get_input(self): 17 | return torch.rand(20, 10) 18 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/nn_tests/parallel_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/tests/pytorch_pfn_extras_tests/nn_tests/parallel_tests/__init__.py -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/onnx_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/tests/pytorch_pfn_extras_tests/onnx_tests/__init__.py -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/onnx_tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | @pytest.fixture(scope='function', autouse=True) 5 | def init_rand_seed(): 6 | torch.manual_seed(100) 7 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/onnx_tests/test_load_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch 5 | 6 | import pytorch_pfn_extras.onnx as tou 7 | from pytorch_pfn_extras_tests.onnx_tests.test_export_testcase import Net 8 | 9 | 10 | @pytest.mark.filterwarnings("ignore:Named tensors .* experimental:UserWarning") 11 | def test_onnx_load_model(): 12 | model = Net() 13 | outdir = "out/load_model_test" 14 | tou.export_testcase(model, torch.rand(1, 1, 28, 28), outdir, 15 | training=True, do_constant_folding=False) 16 | tou.load_model(os.path.join(outdir, "model.onnx")) 17 | 18 | 19 | @pytest.mark.filterwarnings("ignore:.*ONNX contains stripped .*:UserWarning") 20 | def test_stripped_onnx_load_model(): 21 | model = Net() 22 | outdir = "out/stripped_load_model_test" 23 | tou.export_testcase(model, torch.rand(1, 1, 28, 28), outdir, 24 | strip_large_tensor_data=True, training=True, 25 | do_constant_folding=False) 26 | tou.load_model(os.path.join(outdir, "model.onnx")) 27 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/onnx_tests/test_torchvision.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torchvision 4 | 5 | import pytorch_pfn_extras 6 | from pytorch_pfn_extras_tests.onnx_tests.utils import run_model_test 7 | 8 | 9 | resnet18_kwargs = {'weights': None} 10 | 11 | @pytest.mark.filterwarnings("ignore:Converting a tensor to a Python boolean might cause the trace to be incorrect:torch.jit.TracerWarning") 12 | def test_eval_resnet18(): 13 | old_allow_tf32 = torch.backends.cudnn.allow_tf32 14 | try: 15 | torch.backends.cudnn.allow_tf32 = False 16 | run_model_test( 17 | torchvision.models.resnet.resnet18(**resnet18_kwargs), 18 | (torch.rand(1, 3, 224, 
224),), 19 | rtol=1e-03, 20 | use_gpu=True, 21 | ) 22 | finally: 23 | torch.backends.cudnn.allow_tf32 = old_allow_tf32 24 | 25 | 26 | @pytest.mark.gpu 27 | @pytest.mark.xfail 28 | def test_train_resnet18(): 29 | run_model_test( 30 | torchvision.models.resnet.resnet18(**resnet18_kwargs), 31 | (torch.rand(1, 3, 224, 224),), 32 | rtol=1e-03, 33 | use_gpu=True, 34 | mode="train", 35 | ) 36 | 37 | 38 | @pytest.mark.gpu 39 | @pytest.mark.filterwarnings("ignore:__floordiv__ is deprecated:UserWarning") 40 | def test_shufflenet(): 41 | run_model_test( 42 | torchvision.models.shufflenetv2.shufflenet_v2_x1_0(), 43 | (torch.rand(1, 3, 224, 224),), 44 | use_gpu=True, 45 | ) 46 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/profiler_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/tests/pytorch_pfn_extras_tests/profiler_tests/__init__.py -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/runtime_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/tests/pytorch_pfn_extras_tests/runtime_tests/__init__.py -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/runtime_tests/test_registry.py: -------------------------------------------------------------------------------- 1 | import pytorch_pfn_extras as ppe 2 | import torch 3 | 4 | 5 | class FallbackRuntime(ppe.runtime.BaseRuntime): 6 | pass 7 | 8 | 9 | class MyCustomRuntime(ppe.runtime.PyTorchRuntime): 10 | pass 11 | 12 | 13 | def test_registry_register(): 14 | registry = ppe.runtime._registry._RuntimeRegistry(FallbackRuntime) 15 | registry.register("dummy_device", MyCustomRuntime) 16 | assert ( 17 | registry.get_runtime_class_for_device_spec("dummy_device") 18 | == MyCustomRuntime 19 | ) 20 | 21 | 22 | def test_registry_fallback(): 23 | registry = ppe.runtime._registry._RuntimeRegistry(FallbackRuntime) 24 | registry.register("dummy_device", MyCustomRuntime) 25 | assert ( 26 | registry.get_runtime_class_for_device_spec("unknown_device") 27 | == FallbackRuntime 28 | ) 29 | 30 | 31 | def test_registry_torch_device(): 32 | registry = ppe.runtime._registry._RuntimeRegistry(FallbackRuntime) 33 | registry.register("cpu", MyCustomRuntime) 34 | assert ( 35 | registry.get_runtime_class_for_device_spec(torch.device("cpu")) 36 | == MyCustomRuntime 37 | ) 38 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/test_logging.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | from pytorch_pfn_extras import logging 4 | 5 | 6 | def test_file_output(): 7 | try: 8 | with tempfile.NamedTemporaryFile() as logfile: 9 | logfile.close() # this is needed for Windows 10 | logging._configure_logging(filename=logfile.name, level="DEBUG") 11 | logger = logging._get_root_logger() 12 | logger.info("TEST LOG MESSAGE") 13 | with open(logfile.name) as f: 14 | assert "TEST LOG MESSAGE" in f.read() 15 | finally: 16 | logging._configure_logging() 17 | 18 | 19 | def test_get_logger(): 20 | logger = logging.get_logger("app") 21 | logger.setLevel(logging.DEBUG) 22 | assert logger.name == 
"ppe.app" 23 | assert logger.level == logging.DEBUG 24 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/test_ops/test_register.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import pytest 4 | import pytorch_pfn_extras as ppe 5 | import torch 6 | 7 | 8 | def _get_function_nodes(fx_module): 9 | return [ 10 | node for node in fx_module.graph.nodes if node.op == "call_function" 11 | ] 12 | 13 | 14 | @pytest.mark.skipif( 15 | not ppe.requires("2.1.0") or sys.platform == "win32", 16 | reason="torch custom ops only works for PyTorch>=2.1 and linux", 17 | ) 18 | def test_register(): 19 | def test(a): 20 | return a * 2 21 | 22 | def test_bwd(g, a): 23 | return g 24 | 25 | def test_meta(a): 26 | return torch.empty_like(a) 27 | 28 | def test_bwd_meta(g, a): 29 | return torch.empty_like(a) 30 | 31 | fwd_op = ppe.ops.OpDesc(test, test_meta, "(Tensor a) -> Tensor") 32 | bwd_op = ppe.ops.OpDesc( 33 | test_bwd, test_bwd_meta, "(Tensor g, Tensor a) -> Tensor" 34 | ) 35 | ppe.ops.register("test", fwd_op, bwd_op) 36 | 37 | class TestModule(torch.nn.Module): 38 | def forward(self, a): 39 | # Call the custom function 40 | return torch.ops.ppe.test(a) 41 | 42 | found_fwd_op = False 43 | found_bwd_op = False 44 | 45 | from functorch.compile import make_boxed_func 46 | from torch._dynamo.backends.common import aot_autograd 47 | 48 | # Detect the custom ops 49 | def fwd_compiler_fn(fx_module: torch.fx.GraphModule, _): 50 | nonlocal found_fwd_op 51 | function_nodes = _get_function_nodes(fx_module) 52 | assert len(function_nodes) == 1 53 | found_fwd_op = ( 54 | function_nodes[0].target is torch.ops.ppe.test_fwd.default 55 | ) 56 | return make_boxed_func(fx_module) 57 | 58 | def bwd_compiler_fn(fx_module: torch.fx.GraphModule, _): 59 | nonlocal found_bwd_op 60 | function_nodes = _get_function_nodes(fx_module) 61 | assert len(function_nodes) == 1 62 | found_bwd_op = ( 63 | function_nodes[0].target is torch.ops.ppe.test_bwd.default 64 | ) 65 | return make_boxed_func(fx_module) 66 | 67 | aot_backend = aot_autograd( # type: ignore[no-untyped-call] 68 | fw_compiler=fwd_compiler_fn, 69 | bw_compiler=bwd_compiler_fn, 70 | ) 71 | m = TestModule() 72 | torch._dynamo.reset() 73 | module_opt = torch.compile(m, fullgraph=True, backend=aot_backend) 74 | shape = [1, 16, 2048, 128] 75 | x = torch.ones(shape, requires_grad=True) 76 | y = module_opt(x) 77 | y.sum().backward() 78 | assert found_fwd_op 79 | assert found_bwd_op 80 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/test_torchscript.py: -------------------------------------------------------------------------------- 1 | import pytorch_pfn_extras.torchscript as ts 2 | import torch 3 | 4 | 5 | def test_find_inplace(): 6 | def f(v: torch.Tensor) -> None: 7 | v += torch.ones((1, 2, 3)) 8 | 9 | def g(v: torch.Tensor): 10 | f(v) 11 | 12 | s = torch.jit.script(g) 13 | 14 | new_g, inplace_nodes = ts.find_inplace(s.graph) 15 | assert len(inplace_nodes) == 1 16 | 17 | 18 | def test_find_inplace_not_found(): 19 | def f(v: torch.Tensor) -> torch.Tensor: 20 | return torch.ones((1, 2, 3)) 21 | 22 | s = torch.jit.script(f) 23 | 24 | new_g, inplace_nodes = ts.find_inplace(s.graph) 25 | assert len(inplace_nodes) == 0 26 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/test_writing.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import pytest 5 | import pytorch_pfn_extras as ppe 6 | 7 | 8 | @pytest.mark.filterwarnings( 9 | "ignore:`np.bool8` is a deprecated alias for `np.bool_`:DeprecationWarning" 10 | ) 11 | @pytest.mark.filterwarnings( 12 | "ignore:distutils Version classes are deprecated. Use packaging.version instead.:DeprecationWarning" 13 | ) 14 | def test_tensorboard_writing(): 15 | pytest.importorskip("tensorboard") 16 | data = {"a": 1, "iteration": 1} 17 | with tempfile.TemporaryDirectory() as tempd: 18 | writer = ppe.writing.TensorBoardWriter( 19 | out_dir=tempd, filename_suffix="_test" 20 | ) 21 | writer(None, None, data) 22 | # Check that the file was generated 23 | for snap in os.listdir(tempd): 24 | assert "_test" in snap 25 | writer.finalize() 26 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/training_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/tests/pytorch_pfn_extras_tests/training_tests/__init__.py -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/__init__.py -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_micro_average.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import numpy 4 | import pytorch_pfn_extras as ppe 5 | 6 | 7 | def test_run(tmp_path: pathlib.Path): 8 | trigger_iters = 3 9 | data_shape = (4, trigger_iters) 10 | data_total = numpy.random.randint(7, 32, size=data_shape) 11 | 12 | # NumPy<1.17 does not support array-like inputs in `numpy.random.randint`. 
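As a workaround, the next line draws from a fixed scalar range and takes the result modulo `data_total`, which keeps each entry strictly below its per-element total; on NumPy>=1.17, passing `data_total` directly as the `high` argument should presumably work as well.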
13 | data_correct = numpy.random.randint(10000, size=data_shape) % data_total 14 | 15 | manager = ppe.training.ExtensionsManager( 16 | {}, [], 100, iters_per_epoch=5, out_dir=str(tmp_path) 17 | ) 18 | 19 | extension = ppe.training.extensions.MicroAverage( 20 | "main/correct", 21 | "main/total", 22 | "main/accuracy", 23 | (trigger_iters, "iteration"), 24 | ) 25 | manager.extend(extension, trigger=(1, "iteration")) 26 | 27 | for js in numpy.ndindex(data_shape): 28 | with manager.run_iteration(): 29 | ppe.reporting.report( 30 | { 31 | "main/correct": data_correct[js], 32 | "main/total": data_total[js], 33 | } 34 | ) 35 | assert ( 36 | # average is computed every trigger_iters 37 | ("main/accuracy" in manager.observation) 38 | == (js[1] == trigger_iters - 1) 39 | ) 40 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_plot_report.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import pytest 4 | from pytorch_pfn_extras.training import extensions 5 | 6 | 7 | @pytest.fixture(scope="module") 8 | def matplotlib_or_none(): 9 | try: 10 | import matplotlib 11 | 12 | return matplotlib 13 | except ImportError: 14 | return None 15 | 16 | 17 | @pytest.fixture(scope="module") 18 | def matplotlib(matplotlib_or_none): 19 | if matplotlib_or_none is None: 20 | pytest.skip("matplotlib is not installed") 21 | return matplotlib_or_none 22 | 23 | 24 | def test_available(matplotlib_or_none): 25 | if matplotlib_or_none is not None: 26 | assert extensions.PlotReport.available() is True 27 | else: 28 | # It shows a warning only when matplotlib is not available 29 | with pytest.warns(UserWarning): 30 | assert extensions.PlotReport.available() is False 31 | 32 | 33 | # TODO(kataoka): lazy import does not seem to be required with matplotlib v3 34 | def test_lazy_import(matplotlib): 35 | # matplotlib.pyplot should be lazily imported because matplotlib.use 36 | # has to be called earlier. 37 | 38 | with warnings.catch_warnings(): 39 | warnings.simplefilter("error") 40 | matplotlib.use("Agg") 41 | # Test again with a different backend, because the above does not 42 | # generate a warning if matplotlib.use('Agg') is called and then 43 | # matplotlib.pyplot is imported.
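Since `warnings.simplefilter("error")` escalates warnings to errors inside this block, an unexpected warning from the backend switch below would fail the test, which is what exercises the lazy-import behavior.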
44 | matplotlib.use("PS") 45 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report.py: -------------------------------------------------------------------------------- 1 | import io 2 | import pathlib 3 | 4 | import pytorch_pfn_extras as ppe 5 | from pytorch_pfn_extras.training import extensions 6 | 7 | 8 | def test_run_print_report(tmp_path: pathlib.Path): 9 | max_epochs = 5 10 | iters_per_epoch = 5 11 | manager = ppe.training.ExtensionsManager( 12 | {}, 13 | {}, 14 | max_epochs, 15 | iters_per_epoch=iters_per_epoch, 16 | out_dir=str(tmp_path), 17 | ) 18 | 19 | out = io.StringIO() 20 | log_report = extensions.LogReport() 21 | manager.extend(log_report) 22 | extension = extensions.PrintReport(out=out) 23 | manager.extend(extension) 24 | 25 | for _ in range(max_epochs): 26 | for _ in range(iters_per_epoch): 27 | with manager.run_iteration(): 28 | pass 29 | assert "epoch elapsed_time" in out.getvalue() 30 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report_notebook.py: -------------------------------------------------------------------------------- 1 | import io 2 | import pathlib 3 | 4 | import pytest 5 | import pytorch_pfn_extras as ppe 6 | from pytorch_pfn_extras.training.extensions import _ipython_module_available 7 | from pytorch_pfn_extras.training.extensions.log_report import _pandas_available 8 | 9 | 10 | @pytest.mark.skipif( 11 | not _ipython_module_available or not _pandas_available, 12 | reason="print report notebook import failed, " 13 | "maybe ipython is not installed", 14 | ) 15 | def test_run_print_report_notebook(tmp_path: pathlib.Path): 16 | max_epochs = 5 17 | iters_per_epoch = 5 18 | manager = ppe.training.ExtensionsManager( 19 | {}, 20 | {}, 21 | max_epochs, 22 | iters_per_epoch=iters_per_epoch, 23 | out_dir=str(tmp_path), 24 | ) 25 | 26 | out = io.StringIO() 27 | log_report = ppe.training.extensions.LogReport() 28 | manager.extend(log_report) 29 | extension = ppe.training.extensions.PrintReportNotebook(out=out) 30 | manager.extend(extension) 31 | 32 | for _ in range(max_epochs): 33 | for _ in range(iters_per_epoch): 34 | with manager.run_iteration(): 35 | # Only test it runs without fail 36 | # The value is not tested now... 
37 | pass 38 | 39 | 40 | if __name__ == "__main__": 41 | pytest.main([__file__, "-v", "-s"]) 42 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_profile_report.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import tempfile 4 | import time 5 | 6 | import pytest 7 | import pytorch_pfn_extras as ppe 8 | import yaml 9 | 10 | 11 | def _body(): 12 | with ppe.profiler.get_time_summary().report("iter-time"): 13 | time.sleep(0.1) 14 | 15 | 16 | @pytest.mark.parametrize( 17 | "format,append", 18 | [ 19 | ("json", False), 20 | ("json-lines", True), 21 | ("json-lines", False), 22 | ("yaml", True), 23 | ("yaml", False), 24 | ], 25 | ) 26 | def test_profile_report(format, append): 27 | ext = ppe.training.extensions.ProfileReport(format=format, append=append) 28 | max_epochs = 3 29 | iters_per_epoch = 5 30 | # ppe.profiler.time_summary.clear() 31 | with tempfile.TemporaryDirectory() as tmpdir: 32 | manager = ppe.training.ExtensionsManager( 33 | {}, 34 | {}, 35 | max_epochs=max_epochs, 36 | iters_per_epoch=iters_per_epoch, 37 | out_dir=tmpdir, 38 | ) 39 | manager.extend(ext) 40 | for _epoch_idx in range(max_epochs): 41 | for _ in range(iters_per_epoch): 42 | with manager.run_iteration(): 43 | _body() 44 | with open(os.path.join(tmpdir, "log")) as f: 45 | data = f.read() 46 | if format == "json": 47 | values = json.loads(data) 48 | elif format == "json-lines": 49 | values = [json.loads(x) for x in data.splitlines()] 50 | elif format == "yaml": 51 | values = yaml.load(data, Loader=yaml.SafeLoader) 52 | assert len(values) == _epoch_idx + 1 53 | 54 | for value in values: 55 | assert abs(value["iter-time"] - 0.1) < 2e-2 56 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar.py: -------------------------------------------------------------------------------- 1 | import io 2 | import pathlib 3 | import re 4 | import time 5 | 6 | import pytorch_pfn_extras as ppe 7 | 8 | 9 | def test_run(tmp_path: pathlib.Path): 10 | max_epochs = 5 11 | iters_per_epoch = 5 12 | manager = ppe.training.ExtensionsManager( 13 | {}, 14 | {}, 15 | max_epochs, 16 | iters_per_epoch=iters_per_epoch, 17 | out_dir=str(tmp_path), 18 | ) 19 | 20 | out = io.StringIO() 21 | extension = ppe.training.extensions.ProgressBar( 22 | training_length=None, 23 | update_interval=1, 24 | bar_length=40, 25 | out=out, 26 | ) 27 | manager.extend(extension) 28 | 29 | for epoch in range(max_epochs): 30 | for _ in range(iters_per_epoch): 31 | with manager.run_iteration(): 32 | time.sleep(0.1) 33 | if manager.iteration < 2: 34 | continue 35 | status = out.getvalue() 36 | assert ( 37 | "{} iter, {} epoch / {} epochs".format( 38 | manager.iteration, epoch, max_epochs 39 | ) 40 | in status 41 | ) 42 | iters_per_sec = float( 43 | re.findall(r"([0-9]+\.[0-9]*) iters/sec", status)[-1] 44 | ) 45 | assert 7 <= iters_per_sec <= 12 46 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_value_observation.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import pytorch_pfn_extras as ppe 4 | import torch 5 | 6 | 7 | def test_observe_value(tmp_path: pathlib.Path): 8 | lr = 0.1 9 | manager = ppe.training.ExtensionsManager( 10 | {}, [], 
1, iters_per_epoch=1, out_dir=str(tmp_path) 11 | ) 12 | extension = ppe.training.extensions.observe_value("lr", lambda x: lr) 13 | manager.extend(extension) 14 | with manager.run_iteration(): 15 | pass 16 | 17 | assert manager.observation["lr"] == lr 18 | 19 | 20 | def test_observe_lr(tmp_path: pathlib.Path): 21 | lr = 0.01 22 | manager = ppe.training.ExtensionsManager( 23 | {}, [], 1, iters_per_epoch=1, out_dir=str(tmp_path) 24 | ) 25 | optimizer = torch.optim.Adam({torch.nn.Parameter()}, lr=lr) 26 | extension = ppe.training.extensions.observe_lr(optimizer) 27 | manager.extend(extension) 28 | with manager.run_iteration(): 29 | pass 30 | 31 | assert manager.observation["lr"] == lr 32 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/training_tests/test_evaluator_metrics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pytorch_pfn_extras as ppe 3 | import torch 4 | from pytorch_pfn_extras import engine 5 | 6 | 7 | class MyModel(torch.nn.Module): 8 | def __init__(self, correct_ratio): 9 | super().__init__() 10 | self.correct_ratio = correct_ratio 11 | 12 | def forward(self, x, t): 13 | g = t.clone() 14 | to_alter = int(10 * (1 - self.correct_ratio)) 15 | g[0:to_alter][:] -= 1 16 | return {"y": g} 17 | 18 | 19 | @pytest.mark.parametrize("device", ["cpu"]) 20 | @pytest.mark.parametrize("accuracy", [0, 0.5, 1.0]) 21 | def test_evaluator_with_metric(device, accuracy): 22 | model = MyModel(accuracy) 23 | data = torch.utils.data.DataLoader( 24 | [{"x": torch.rand(20), "t": torch.rand(1)} for i in range(10)], 25 | batch_size=10, 26 | ) 27 | 28 | ppe.to(model, device) 29 | evaluator = engine.create_evaluator( 30 | model, 31 | device=device, 32 | metrics=[ppe.training.metrics.AccuracyMetric("t", "y")], 33 | options={"eval_report_keys": ["accuracy"]}, 34 | ) 35 | evaluator.handler.eval_setup(evaluator, data) 36 | reporter = ppe.reporting.Reporter() 37 | observation = {} 38 | with reporter.scope(observation): 39 | evaluator.run(data) 40 | assert pytest.approx(observation["val/accuracy"]) == accuracy 41 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/training_tests/test_extension_entry.py: -------------------------------------------------------------------------------- 1 | import pytorch_pfn_extras as ppe 2 | import torch 3 | 4 | 5 | def _get_dummy_manager(): 6 | model = torch.nn.Module() 7 | return ppe.training.ExtensionsManager( 8 | {"main": model}, 9 | [], # optimizers 10 | 10, # max_epochs 11 | iters_per_epoch=1, 12 | ) 13 | 14 | 15 | def test_default_name(): 16 | class MyExtension(ppe.training.Extension): 17 | name = None 18 | default_name = "default_name" 19 | 20 | ext = MyExtension() 21 | entry = ppe.training.ExtensionEntry(ext) 22 | assert entry.name == MyExtension.default_name 23 | entry = ppe.training.ExtensionEntry(ext, name="updated") 24 | assert entry.name == "updated" 25 | 26 | 27 | def test_name(): 28 | class MyExtension(ppe.training.Extension): 29 | name = "name" 30 | default_name = "default_name" 31 | 32 | ext = MyExtension() 33 | entry = ppe.training.ExtensionEntry(ext) 34 | assert entry.name == MyExtension.name 35 | entry = ppe.training.ExtensionEntry(ext, name="updated") 36 | assert entry.name == "updated" 37 | 38 | 39 | def test_priority(): 40 | class MyExtension(ppe.training.Extension): 41 | priority = 100 42 | 43 | ext = MyExtension() 44 | entry = ppe.training.ExtensionEntry(ext) 45 |
assert entry.priority == MyExtension.priority 46 | entry = ppe.training.ExtensionEntry(ext, priority=10) 47 | assert entry.priority == 10 48 | 49 | 50 | def test_trigger(): 51 | class MyExtension(ppe.training.Extension): 52 | trigger = (1, "iteration") 53 | 54 | ext = MyExtension() 55 | entry = ppe.training.ExtensionEntry(ext) 56 | assert isinstance(entry.trigger, ppe.training.triggers.IntervalTrigger) 57 | assert entry.trigger.period == 1 58 | assert entry.trigger.unit == "iteration" 59 | entry = ppe.training.ExtensionEntry(ext, trigger=(3, "epoch")) 60 | assert isinstance(entry.trigger, ppe.training.triggers.IntervalTrigger) 61 | assert entry.trigger.period == 3 62 | assert entry.trigger.unit == "epoch" 63 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/training_tests/test_trigger_util.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import pytest 4 | from pytorch_pfn_extras import training 5 | from pytorch_pfn_extras.training import _trigger_util, triggers 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "iters_per_epoch,trigger_args,expected", 10 | [ 11 | # Never fire trigger 12 | (2, None, [False, False, False, False, False, False, False]), 13 | # Interval trigger 14 | (2, (2, "iteration"), [False, True, False, True, False, True, False]), 15 | (2, (2, "epoch"), [False, False, False, True, False, False, False]), 16 | # Callable object 17 | ( 18 | 2, 19 | _trigger_util.get_trigger(None), 20 | [False, False, False, False, False, False, False], 21 | ), 22 | ( 23 | 2, 24 | triggers.IntervalTrigger(2, "iteration"), 25 | [False, True, False, True, False, True, False], 26 | ), 27 | ( 28 | 2, 29 | (lambda trainer: trainer.iteration == 3), 30 | [False, False, True, False, False, False, False], 31 | ), 32 | ], 33 | ) 34 | def test_get_trigger( 35 | iters_per_epoch, trigger_args, expected, tmp_path: pathlib.Path 36 | ): 37 | trainer = training.ExtensionsManager( 38 | {}, 39 | [], 40 | 100, 41 | iters_per_epoch=iters_per_epoch, 42 | out_dir=str(tmp_path), 43 | ) 44 | trigger = _trigger_util.get_trigger(trigger_args) 45 | 46 | # before the first iteration, trigger should be False 47 | for _, e in enumerate(expected): 48 | with trainer.run_iteration(): 49 | pass 50 | assert trigger(trainer) == e 51 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/__init__.py -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_early_stopping_trigger.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import numpy 4 | import pytorch_pfn_extras as ppe 5 | import torch 6 | 7 | 8 | def _test_trigger(trigger, key, accuracies, expected, tmp_path: pathlib.Path): 9 | manager = ppe.training.ExtensionsManager( 10 | {}, [], 100, iters_per_epoch=1, out_dir=str(tmp_path) 11 | ) 12 | for a, e in zip(accuracies, expected): 13 | with manager.run_iteration(): 14 | pass 15 | manager.observation = {key: a} 16 | assert trigger(manager) == e 17 | 18 | 19 | def test_early_stopping_trigger_with_accuracy(tmp_path: 
pathlib.Path): 20 | key = "main/accuracy" 21 | trigger = ppe.training.triggers.EarlyStoppingTrigger( 22 | monitor=key, patience=3, check_trigger=(1, "epoch"), verbose=False 23 | ) 24 | accuracies = [ 25 | torch.Tensor(numpy.asarray(acc, dtype=numpy.float32)) 26 | for acc in [0.5, 0.5, 0.6, 0.7, 0.6, 0.4, 0.3, 0.2] 27 | ] 28 | expected = [False, False, False, False, False, False, True, True] 29 | _test_trigger(trigger, key, accuracies, expected, tmp_path) 30 | 31 | 32 | def test_early_stopping_trigger_with_loss(tmp_path: pathlib.Path): 33 | key = "main/loss" 34 | trigger = ppe.training.triggers.EarlyStoppingTrigger( 35 | monitor=key, patience=3, check_trigger=(1, "epoch") 36 | ) 37 | accuracies = [ 38 | torch.Tensor(numpy.asarray(acc, dtype=numpy.float32)) 39 | for acc in [100, 80, 30, 10, 20, 24, 30, 35] 40 | ] 41 | expected = [False, False, False, False, False, False, True, True] 42 | _test_trigger(trigger, key, accuracies, expected, tmp_path) 43 | 44 | 45 | def test_early_stopping_trigger_with_max_epoch(tmp_path: pathlib.Path): 46 | key = "main/loss" 47 | trigger = ppe.training.triggers.EarlyStoppingTrigger( 48 | monitor=key, 49 | patience=3, 50 | check_trigger=(1, "epoch"), 51 | max_trigger=(3, "epoch"), 52 | ) 53 | accuracies = [ 54 | torch.Tensor(numpy.asarray(acc, dtype=numpy.float32)) 55 | for acc in [100, 80, 30] 56 | ] 57 | expected = [False, False, True] 58 | _test_trigger(trigger, key, accuracies, expected, tmp_path) 59 | 60 | 61 | def test_early_stopping_trigger_with_max_iteration(tmp_path: pathlib.Path): 62 | key = "main/loss" 63 | trigger = ppe.training.triggers.EarlyStoppingTrigger( 64 | monitor=key, 65 | patience=3, 66 | check_trigger=(1, "epoch"), 67 | max_trigger=(3, "iteration"), 68 | ) 69 | accuracies = [ 70 | torch.Tensor(numpy.asarray(acc, dtype=numpy.float32)) 71 | for acc in [100, 80, 30] 72 | ] 73 | 74 | expected = [False, False, True] 75 | _test_trigger(trigger, key, accuracies, expected, tmp_path) 76 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_function_trigger.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from unittest.mock import MagicMock 3 | 4 | import pytest 5 | from pytorch_pfn_extras.training import ExtensionsManager 6 | from pytorch_pfn_extras.training import trigger as trigger_module 7 | from pytorch_pfn_extras.training._trigger_util import TriggerLike 8 | from pytorch_pfn_extras.training.triggers import FunctionTrigger 9 | 10 | 11 | def test_function_is_called(tmp_path: pathlib.Path) -> None: 12 | fn = MagicMock() 13 | args = [MagicMock()] 14 | kwargs = {"a": MagicMock()} 15 | trigger = FunctionTrigger( 16 | fn=fn, args=args, kwargs=kwargs, trigger=(1, "iteration") 17 | ) 18 | fn.assert_not_called() 19 | manager = ExtensionsManager( 20 | {}, {}, 1, iters_per_epoch=10, out_dir=str(tmp_path) 21 | ) 22 | with manager.run_iteration(): 23 | pass 24 | trigger(manager) 25 | fn.assert_called_once_with(*args, **kwargs) 26 | 27 | 28 | def test_trigger_with_value(tmp_path: pathlib.Path) -> None: 29 | value = {"value": False} 30 | args = [value] 31 | trigger = FunctionTrigger( 32 | fn=lambda x: x["value"], args=args, trigger=(1, "iteration") 33 | ) 34 | manager = ExtensionsManager( 35 | {}, {}, 1, iters_per_epoch=10, out_dir=str(tmp_path) 36 | ) 37 | with manager.run_iteration(): 38 | pass 39 | assert not trigger(manager) 40 | value["value"] = True 41 | assert trigger(manager) 42 | 43 | 44 | 
@pytest.mark.parametrize( 45 | "trigger, iters_per_epoch", 46 | [((1, "iteration"), 10), ((1, "epoch"), 20), ((0.123, "epoch"), 17)], 47 | ) 48 | def test_with_interval_trigger( 49 | trigger: TriggerLike, iters_per_epoch: int, tmp_path: pathlib.Path 50 | ) -> None: 51 | trigger = trigger_module.get_trigger(trigger) 52 | manager = ExtensionsManager( 53 | {}, [], 10, iters_per_epoch=iters_per_epoch, out_dir=str(tmp_path) 54 | ) 55 | function_trigger = FunctionTrigger(fn=lambda: True, trigger=trigger) 56 | 57 | while not manager.stop_trigger: 58 | with manager.run_iteration(): 59 | pass 60 | assert trigger(manager) == function_trigger(manager) 61 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_interval_trigger.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import pytest 4 | from pytorch_pfn_extras import training 5 | from pytorch_pfn_extras.training import triggers 6 | 7 | _argvalues = [ 8 | # iteration 9 | (5, (2, "iteration"), [False, True, False, True, False, True, False], 4), 10 | # basic epoch 11 | (1, (3, "epoch"), [False, False, True, False, False, True, False], 4), 12 | # fractional epoch 13 | (2, (1.5, "epoch"), [False, False, True, False, False, True, False], 4), 14 | ( 15 | 3, 16 | (1.5, "epoch"), 17 | [False, False, False, False, True, False, False, False, True], 18 | 4, 19 | ), 20 | ] 21 | 22 | 23 | @pytest.mark.parametrize("iters_per_epoch,interval,expected,resume", _argvalues) 24 | def test_trigger( 25 | iters_per_epoch, interval, expected, resume, tmp_path: pathlib.Path 26 | ): 27 | trainer = training.ExtensionsManager( 28 | {}, [], 100, iters_per_epoch=iters_per_epoch, out_dir=str(tmp_path) 29 | ) 30 | trigger = triggers.IntervalTrigger(*interval) 31 | 32 | for e in expected: 33 | with trainer.run_iteration(): 34 | pass 35 | assert trigger.may_fire(trainer.iteration, iters_per_epoch) == e 36 | assert trigger(trainer) == e 37 | 38 | 39 | @pytest.mark.parametrize("iters_per_epoch,interval,expected,resume", _argvalues) 40 | def test_resumed_trigger( 41 | iters_per_epoch, interval, expected, resume, tmp_path: pathlib.Path 42 | ): 43 | trainer = training.ExtensionsManager( 44 | {}, 45 | [], 46 | 100, 47 | iters_per_epoch=iters_per_epoch, 48 | out_dir=str(tmp_path), 49 | ) 50 | trigger = triggers.IntervalTrigger(*interval) 51 | 52 | for e in expected[:resume]: 53 | with trainer.run_iteration(): 54 | pass 55 | assert trigger.may_fire(trainer.iteration, iters_per_epoch) == e 56 | assert trigger(trainer) == e 57 | 58 | state = trigger.state_dict() 59 | new_trigger = triggers.IntervalTrigger(*interval) 60 | new_trigger.load_state_dict(state) 61 | 62 | for e in expected[resume:]: 63 | with trainer.run_iteration(): 64 | pass 65 | assert new_trigger.may_fire(trainer.iteration, iters_per_epoch) == e 66 | assert new_trigger(trainer) == e 67 | 68 | 69 | @pytest.mark.parametrize("iters_per_epoch,interval,expected,resume", _argvalues) 70 | def test_str( 71 | iters_per_epoch, 72 | interval, 73 | expected, 74 | resume, 75 | ): 76 | trigger = triggers.IntervalTrigger(*interval) 77 | 78 | expected = "IntervalTrigger({}, '{}')".format(*interval) 79 | actual = str(trigger) 80 | 81 | assert expected == actual, 'Expected "{}" == "{}"'.format(expected, actual) 82 | 83 | 84 | def test_invalid_unit(): 85 | with pytest.raises(ValueError): 86 | triggers.IntervalTrigger(1, "day") 87 | 
-------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_time_trigger.py: -------------------------------------------------------------------------------- 1 | import pytorch_pfn_extras as ppe 2 | 3 | 4 | class DummyTrainer: 5 | def __init__(self): 6 | self.elapsed_time = 0 7 | 8 | 9 | def test_call(): 10 | trigger = ppe.training.triggers.TimeTrigger(1) 11 | trainer = DummyTrainer() 12 | 13 | assert not trigger(trainer) 14 | trainer.elapsed_time = 0.9 15 | assert not trigger(trainer) 16 | 17 | # first event is triggered on time==1.0 18 | trainer.elapsed_time = 1.2 19 | assert trigger(trainer) 20 | 21 | trainer.elapsed_time = 1.3 22 | assert not trigger(trainer) 23 | 24 | # second event is triggered on time==2.0 and would not fire again at time==2.2 25 | trainer.elapsed_time = 2.1 26 | assert trigger(trainer) 27 | 28 | 29 | def test_resume(): 30 | trigger = ppe.training.triggers.TimeTrigger(1) 31 | trainer = DummyTrainer() 32 | trainer.elapsed_time = 1.2 33 | trigger(trainer) 34 | assert trigger._next_time == 2.0 35 | 36 | state = trigger.state_dict() 37 | trigger2 = ppe.training.triggers.TimeTrigger(1) 38 | trigger2.load_state_dict(state) 39 | assert trigger2._next_time == 2.0 40 | -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/utils_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet/pytorch-pfn-extras/5c1f9da9202b58788119de35257cedac83847d9a/tests/pytorch_pfn_extras_tests/utils_tests/__init__.py -------------------------------------------------------------------------------- /tests/pytorch_pfn_extras_tests/utils_tests/test_checkpoint.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pytorch_pfn_extras as ppe 3 | import torch 4 | 5 | 6 | class SubNet(torch.nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | self.conv1 = torch.nn.Conv2d(5, 5, 3, 1, 1) 10 | self.bn1 = torch.nn.BatchNorm2d(5) 11 | 12 | def forward(self, x): 13 | return self.bn1(self.conv1(x)).relu() 14 | 15 | 16 | class Net(torch.nn.Module): 17 | def __init__(self, checkpoint_type): 18 | super().__init__() 19 | self.checkpoint_type = checkpoint_type 20 | 21 | self.conv1 = torch.nn.Conv2d(1, 5, 3, 1, 1) 22 | self.bn1 = torch.nn.BatchNorm2d(5) 23 | self.part1 = SubNet() 24 | self.part2 = SubNet() 25 | 26 | def forward(self, x): 27 | x = self.bn1(self.conv1(x)).relu() 28 | 29 | if self.checkpoint_type == "none": 30 | x = self.part1(x) 31 | elif self.checkpoint_type == "bnaware": 32 | x = ppe.utils.checkpoint.checkpoint(self.part1, x) 33 | 34 | x = self.part2(x) 35 | 36 | return x 37 | 38 | 39 | def _get_bn_stats_test_checkpoint(cp_type): 40 | torch.manual_seed(42) 41 | net = Net(cp_type) 42 | opt = torch.optim.SGD(net.parameters(), lr=0.1) 43 | h, w = 32, 32 44 | 45 | bn = net.part1.bn1 46 | 47 | for _ in range(2): 48 | x = torch.arange(2 * h * w).reshape((2, 1, h, w)).float() 49 | 50 | opt.zero_grad() 51 | y = net(x) 52 | y.sum().backward() 53 | opt.step() 54 | 55 | return (bn.weight, bn.bias, bn.running_mean, bn.running_var) 56 | 57 | 58 | @pytest.mark.gpu 59 | def test_checkpoint(): 60 | baseline = _get_bn_stats_test_checkpoint("none") 61 | ckpt = _get_bn_stats_test_checkpoint("bnaware") 62 | for p_b, p_c in zip(baseline, ckpt): 63 | assert torch.allclose(p_b, p_c) 64 |
-------------------------------------------------------------------------------- /tests/requirements.mpi.txt: -------------------------------------------------------------------------------- 1 | mpi4py 2 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | # Note: using `onnxruntime==1.22` on Windows raises the following error during `import onnxruntime`: 3 | # ImportError: DLL load failed while importing onnxruntime_pybind11_state: The specified module could not be found. 4 | onnxruntime<1.22 5 | torchvision 6 | torchaudio 7 | pysen 8 | black==24.3.0 9 | flake8==4.0.1 10 | isort==5.10.1 11 | mypy==1.3.0 12 | types-PyYAML 13 | types-setuptools 14 | matplotlib 15 | tensorboard 16 | ipython 17 | ipywidgets 18 | pandas 19 | pyarrow 20 | optuna 21 | onnx 22 | pytorch-ignite 23 | pytest-cov 24 | slack_sdk 25 | numpy<2 26 | --------------------------------------------------------------------------------