├── torchkit
├── utils
│ ├── module_freezing.py
│ ├── dataset.py
│ ├── git.py
│ ├── timer.py
│ ├── __init__.py
│ ├── multithreading.py
│ ├── seed.py
│ ├── io.py
│ ├── pdb_fallback.py
│ ├── module_stats.py
│ └── config.py
├── __init__.py
├── version.py
├── experiment.py
├── viz.py
├── losses.py
├── logger.py
├── checkpoint.py
└── layers.py
├── docs
├── torchkit
│ ├── utils
│ │ ├── io.rst
│ │ ├── git.rst
│ │ ├── seed.rst
│ │ ├── config.rst
│ │ ├── timer.rst
│ │ ├── dataset.rst
│ │ ├── pdb_fallback.rst
│ │ ├── module_stats.rst
│ │ └── index.rst
│ ├── layers.rst
│ ├── losses.rst
│ ├── experiment.rst
│ ├── installation.rst
│ ├── checkpoint.rst
│ └── logger.rst
├── Makefile
├── index.rst
├── make.bat
└── conf.py
├── pyproject.toml
├── setup.cfg
├── .github
└── workflows
│ ├── lint.yml
│ ├── docs.yml
│ └── build.yml
├── scripts
└── lint.sh
├── .gitignore
├── LICENSE
├── setup.py
├── tests
├── test_layers.py
├── test_checkpoint.py
├── test_losses.py
└── test_logger.py
└── README.md
/torchkit/utils/module_freezing.py:
--------------------------------------------------------------------------------
1 | # TODO(kevin): Implement freezing and unfreezing of `nn.Module`.
2 | # TODO(kevin): Implement batch norm freezing.
3 |
--------------------------------------------------------------------------------
/docs/torchkit/utils/io.rst:
--------------------------------------------------------------------------------
1 | torchkit.utils.io
2 | =================
3 |
4 | .. automodule:: torchkit.utils.io
5 | :members:
6 | :member-order: alphabetical
7 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 88
3 | target-version = ["py38"]
4 |
5 | [tool.isort]
6 | profile = "black"
7 | multi_line_output = 3
8 | line_length = 88
--------------------------------------------------------------------------------
/docs/torchkit/utils/git.rst:
--------------------------------------------------------------------------------
1 | torchkit.utils.git
2 | ==================
3 |
4 | .. automodule:: torchkit.utils.git
5 | :members:
6 | :member-order: alphabetical
7 |
--------------------------------------------------------------------------------
/docs/torchkit/utils/seed.rst:
--------------------------------------------------------------------------------
1 | torchkit.utils.seed
2 | ===================
3 |
4 | .. automodule:: torchkit.utils.seed
5 | :members:
6 | :member-order: alphabetical
7 |
--------------------------------------------------------------------------------
/docs/torchkit/layers.rst:
--------------------------------------------------------------------------------
1 | torchkit.layers
2 | ===============
3 |
4 | .. automodule:: torchkit.layers
5 | :members:
6 | :member-order: bysource
7 | :exclude-members: forward
8 |
--------------------------------------------------------------------------------
/docs/torchkit/losses.rst:
--------------------------------------------------------------------------------
1 | torchkit.losses
2 | ===============
3 |
4 | .. automodule:: torchkit.losses
5 | :members:
6 | :member-order: bysource
7 | :exclude-members: forward
8 |
--------------------------------------------------------------------------------
/docs/torchkit/utils/config.rst:
--------------------------------------------------------------------------------
1 | torchkit.utils.config
2 | =====================
3 |
4 | .. automodule:: torchkit.utils.config
5 | :members:
6 | :member-order: alphabetical
7 |
--------------------------------------------------------------------------------
/docs/torchkit/utils/timer.rst:
--------------------------------------------------------------------------------
1 | torchkit.utils.timer
2 | ====================
3 |
4 | .. automodule:: torchkit.utils.timer
5 | :members:
6 | :member-order: alphabetical
7 |
--------------------------------------------------------------------------------
/docs/torchkit/utils/dataset.rst:
--------------------------------------------------------------------------------
1 | torchkit.utils.dataset
2 | ======================
3 |
4 | .. automodule:: torchkit.utils.dataset
5 | :members:
6 | :member-order: alphabetical
7 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | docstring-convention=google
3 | ignore = E203, W503
4 | max-line-length = 88
5 |
6 | [pytype]
7 | inputs =
8 | torchkit/
9 | tests/
10 | setup.py
11 |
--------------------------------------------------------------------------------
/docs/torchkit/utils/pdb_fallback.rst:
--------------------------------------------------------------------------------
1 | torchkit.utils.pdb_fallback
2 | ===========================
3 |
4 | .. automodule:: torchkit.utils.pdb_fallback
5 | :members:
6 | :member-order: alphabetical
7 |
--------------------------------------------------------------------------------
/docs/torchkit/experiment.rst:
--------------------------------------------------------------------------------
1 | torchkit.experiment
2 | ===================
3 |
4 | .. automodule:: torchkit.experiment
5 | :members:
6 | :member-order: alphabetical
7 | :exclude-members: __init__
8 |
--------------------------------------------------------------------------------
/docs/torchkit/utils/module_stats.rst:
--------------------------------------------------------------------------------
1 | torchkit.utils.module_stats
2 | ===========================
3 |
4 | .. automodule:: torchkit.utils.module_stats
5 | :members:
6 | :member-order: alphabetical
7 |
8 |
--------------------------------------------------------------------------------
/docs/torchkit/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | **torchkit** requires Python 3.8 or higher. You can install it directly from the ``master`` branch on GitHub::
5 |
6 | pip install git+https://github.com/kevinzakka/torchkit.git
7 |
--------------------------------------------------------------------------------
/docs/torchkit/utils/index.rst:
--------------------------------------------------------------------------------
1 | torchkit.utils
2 | ==============
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 |
7 | config
8 | dataset
9 | git
10 | module_stats
11 | pdb_fallback
12 | io
13 | seed
14 | timer
15 |
--------------------------------------------------------------------------------
/torchkit/__init__.py:
--------------------------------------------------------------------------------
1 | """A PyTorch toolkit for research."""
2 |
3 | from torchkit.checkpoint import CheckpointManager # noqa: F401
4 | from torchkit.logger import Logger # noqa: F401
5 | from torchkit.version import __version__ # noqa: F401
6 |
--------------------------------------------------------------------------------
/torchkit/version.py:
--------------------------------------------------------------------------------
1 | """Version file will be exec'd by setup.py and imported by __init__.py"""
2 |
3 | # Semantic versioning.
4 | _MAJOR_VERSION = "0"
5 | _MINOR_VERSION = "0"
6 | _PATCH_VERSION = "3"
7 |
8 | __version__ = ".".join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION])
9 |
--------------------------------------------------------------------------------
/docs/torchkit/checkpoint.rst:
--------------------------------------------------------------------------------
1 | torchkit.checkpoint
2 | ===================
3 |
4 | .. autoclass:: torchkit.checkpoint.Checkpoint
5 | :members:
6 | :member-order: groupwise
7 |
8 | .. autoclass:: torchkit.checkpoint.CheckpointManager
9 | :members:
10 | :member-order: groupwise
11 |
--------------------------------------------------------------------------------
/docs/torchkit/logger.rst:
--------------------------------------------------------------------------------
1 | torchkit.logger
2 | ===============
3 |
4 |
5 | The ``Logger`` class is a simple wrapper over PyTorch's ``SummaryWriter``. You can use it to log scalars, images and videos given as NumPy arrays or PyTorch tensors.
6 |
7 | .. autoclass:: torchkit.logger.Logger
8 | :members:
9 | :member-order: groupwise
10 |
--------------------------------------------------------------------------------
/torchkit/utils/dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable, Iterator
2 |
3 |
# Reference: https://github.com/unixpickle/vq-voice-swap/blob/main/vq_voice_swap/util.py
def infinite_dataset(data_loader: Iterable) -> Iterator:
    """Create an infinite loop over a `torch.utils.data.DataLoader`.

    Every pass re-iterates ``data_loader`` from the start, so it must be
    re-iterable (like a DataLoader or a list), not a one-shot iterator.

    Args:
        data_loader (Iterable): A `torch.utils.data.DataLoader` object, or any other
            re-iterable collection.

    Yields:
        The items of ``data_loader``, repeated ad infinitum.

    Raises:
        ValueError: If a full pass over ``data_loader`` yields no items, e.g. because
            it is empty or is an iterator that has already been exhausted. Without
            this check the generator would loop forever without ever yielding.
    """
    while True:
        produced_any = False
        for item in data_loader:
            produced_any = True
            yield item
        if not produced_any:
            raise ValueError(
                "infinite_dataset: data_loader produced no items in a full pass; "
                "it is empty or is an exhausted one-shot iterator."
            )
16 |
--------------------------------------------------------------------------------
/torchkit/utils/git.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 |
# Reference: https://stackoverflow.com/a/21901260
def git_revision_hash() -> str:
    """Return the git commit hash of the current directory.

    This is best-effort and does not raise for the common failure modes:

    * If ``git rev-parse HEAD`` exits with an error, git's combined stdout/stderr
      output (e.g. ``fatal: not a git repository ...``) is returned instead.
    * If the ``git`` executable cannot be run at all (e.g. it is not installed), a
      short ``git unavailable: ...`` message describing the OS error is returned.
    """
    try:
        output = subprocess.check_output(
            ["git", "rev-parse", "HEAD"], stderr=subprocess.STDOUT
        )
    except subprocess.CalledProcessError as err:
        # Non-zero exit status: surface git's own error message.
        output = err.output
    except OSError as err:
        # The command could not be started (e.g. FileNotFoundError for a missing
        # git binary); keep the function's "return a string" contract.
        return f"git unavailable: {err}".strip()
    # Decode leniently: error messages may be localized and contain non-ASCII text.
    return output.decode("utf-8", errors="replace").strip()


# Alias.
git_commit_hash = git_revision_hash
22 |
--------------------------------------------------------------------------------
/torchkit/utils/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 |
class Stopwatch:
    """A simple timer for measuring elapsed time.

    Elapsed time is measured with :func:`time.perf_counter`, a monotonic clock, so
    it cannot go negative when the system clock is adjusted. The wall-clock
    timestamp of the last reset is still exposed as the ``time`` attribute.

    Example usage::

        stopwatch = Stopwatch()
        some_func()
        print(f"some_func took: {stopwatch.elapsed()} seconds.")
        stopwatch.reset()
    """

    def __init__(self) -> None:
        self.reset()

    def elapsed(self) -> float:
        """Return the elapsed time in seconds since the stopwatch was last reset."""
        return time.perf_counter() - self._start

    def reset(self) -> None:
        """Reset the stopwatch, i.e. start the timer."""
        # Wall-clock timestamp, kept for callers that read `stopwatch.time`.
        self.time = time.time()
        # Monotonic reference point used by `elapsed()`.
        self._start = time.perf_counter()
25 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | name: lint
2 |
3 | on:
4 | push:
5 | branches: [master]
6 | pull_request:
7 | branches: [master]
8 |
9 | jobs:
10 | build:
11 | runs-on: ubuntu-latest
12 | strategy:
13 | matrix:
14 | python-version: ["3.8", "3.9"]
15 |
16 | steps:
17 | - uses: actions/checkout@v2
18 | - name: Set up Python ${{ matrix.python-version }}
19 | uses: actions/setup-python@v2
20 | with:
21 | python-version: ${{ matrix.python-version }}
22 | - name: Install dependencies
23 | run: |
24 | python -m pip install --upgrade pip setuptools wheel
25 | pip install -e ".[dev]"
26 | - name: Run lint script
27 | run: |
28 | bash scripts/lint.sh
29 |
--------------------------------------------------------------------------------
/scripts/lint.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
#
# Modified from https://raw.githubusercontent.com/HumanCompatibleAI/seals/master/ci/code_checks.sh
#
# Runs flake8 and black on every path in SRC_FILES and, unless
# skipexpensive=true is set in the environment, also builds the Sphinx docs and
# runs pytype.
set -x
set -e

SRC_FILES=(torchkit/ tests/ docs/conf.py setup.py)

# Number of parallel pytype jobs. Defaults to 1 on platforms not detected
# below so that `pytype -n` is never handed an empty value.
N_CPU=1
if [ "$(uname)" == "Darwin" ]; then
  N_CPU=$(sysctl -n hw.ncpu)
elif [ "$(expr substr "$(uname -s)" 1 5)" == "Linux" ]; then
  N_CPU=$(grep -c ^processor /proc/cpuinfo)
fi

echo "Source format checking"
flake8 "${SRC_FILES[@]}"
# Expand the whole array: a bare ${SRC_FILES} expands to its first element
# only, which previously restricted black to torchkit/.
black --check "${SRC_FILES[@]}"

if [ "$skipexpensive" != "true" ]; then
  echo "Building docs (validates docstrings)"
  pushd docs/
  make clean
  make html
  popd

  echo "Type checking"
  pytype -n "${N_CPU}" "${SRC_FILES[@]}"
fi
29 |
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: docs
2 |
3 | on:
4 | push:
5 | branches: [master]
6 |
7 | jobs:
8 | docs:
9 | runs-on: ubuntu-latest
10 | container:
11 | image: python:3.8
12 | steps:
13 |
14 | # Check out source
15 | - uses: actions/checkout@v2
16 |
17 | # Build documentation
18 | - name: Building documentation
19 | run: |
20 | apt-get update
21 | apt-get -y install libgl1-mesa-glx
22 | pip install -e .[dev]
23 | sphinx-build docs/ docs/_build -b dirhtml
24 |
25 | # Deploy
26 | - name: Deploy to GitHub Pages
27 | uses: peaceiris/actions-gh-pages@v3
28 | with:
29 | github_token: ${{ secrets.GITHUB_TOKEN }}
30 | publish_dir: ./docs/_build
31 |
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: build
2 |
3 | on:
4 | push:
5 | branches: [master]
6 | pull_request:
7 | branches: [master]
8 |
9 | jobs:
10 | build:
11 | runs-on: ubuntu-latest
12 | strategy:
13 | matrix:
14 | python-version: ["3.8", "3.9"]
15 | name: Python ${{ matrix.python-version }}
16 | steps:
17 | - uses: actions/checkout@v2
18 | - name: Set up Python ${{ matrix.python-version }}
19 | uses: actions/setup-python@v1
20 | with:
21 | python-version: ${{ matrix.python-version }}
22 | - name: Install dependencies
23 | run: |
24 | python -m pip install --upgrade pip setuptools wheel
25 | pip install -e .[test]
26 | - name: Test with pytest
27 | run: |
28 | pytest tests
29 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .pytest_cache/
2 | .pytype
3 | .vscode
4 | .ipynb_checkpoints
5 | playground.ipynb
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | bin/
16 | build/
17 | develop-eggs/
18 | dist/
19 | eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 |
29 | # Installer logs
30 | pip-log.txt
31 | pip-delete-this-directory.txt
32 |
33 | # Unit test / coverage reports
34 | .tox/
35 | .coverage
36 | .cache
37 | nosetests.xml
38 | coverage.xml
39 |
40 | # Translations
41 | *.mo
42 |
43 | # Mr Developer
44 | .mr.developer.cfg
45 | .project
46 | .pydevproject
47 |
48 | # Rope
49 | .ropeproject
50 |
51 | # Django stuff:
52 | *.log
53 | *.pot
54 |
55 | # Sphinx documentation
56 | docs/_build/
57 | docs/_templates
58 |
--------------------------------------------------------------------------------
/torchkit/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import copy_config_and_replace, dump_config, load_config, validate_config
2 | from .dataset import infinite_dataset
3 | from .git import git_commit_hash, git_revision_hash
4 | from .io import load_pickle, save_pickle
5 | from .module_stats import get_total_params
6 | from .multithreading import threaded_func
7 | from .pdb_fallback import pdb_fallback
8 | from .seed import seed_rngs, set_cudnn
9 | from .timer import Stopwatch
10 |
# Public API of torchkit.utils, grouped by the submodule each name comes from.
__all__ = [
    # config
    "copy_config_and_replace",
    "dump_config",
    "load_config",
    "validate_config",
    # dataset
    "infinite_dataset",
    # git
    "git_commit_hash",
    "git_revision_hash",
    # io
    "load_pickle",
    "save_pickle",
    # module_stats
    "get_total_params",
    # multithreading
    "threaded_func",
    # pdb_fallback
    "pdb_fallback",
    # seed
    "seed_rngs",
    "set_cudnn",
    # timer
    "Stopwatch",
]
27 | ]
28 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. raw:: html
2 |
3 |
TorchKit: A PyTorch Toolkit for Research
4 |
5 |
6 | **torchkit** is a *lightweight* library containing PyTorch utilities useful for day-to-day research.
7 | Its main goal is to abstract away a lot of the redundant boilerplate associated with research projects
8 | like experimental configurations, logging and model checkpointing.
9 |
10 | .. toctree::
11 | :maxdepth: 1
12 | :caption: User Guide
13 |
14 | torchkit/installation
15 |
16 | .. toctree::
17 | :maxdepth: 2
18 | :caption: API Reference
19 |
20 | torchkit/logger
21 | torchkit/checkpoint
22 | torchkit/experiment
23 | torchkit/layers
24 | torchkit/losses
25 | torchkit/utils/index
26 |
27 | Indices and tables
28 | ==================
29 |
30 | * :ref:`genindex`
31 | * :ref:`modindex`
32 | * :ref:`search`
33 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/torchkit/utils/multithreading.py:
--------------------------------------------------------------------------------
1 | import threading
2 | from typing import Callable, Iterable, Tuple
3 |
4 |
# Adapted from: https://github.com/facebookresearch/pytorchvideo/blob/master/pytorchvideo/data/utils.py#L99 # noqa: E501
def threaded_func(
    func: Callable,
    args_iterable: Iterable[Tuple],
    multithreaded: bool,
) -> None:
    """Applies a func on a tuple of args with optional multithreading.

    In multithreaded mode one thread is started per args tuple (there is no pool
    or upper bound on the number of threads), and the call returns only after
    every thread has finished. If any call raised an exception, the first one
    recorded is re-raised in the caller once all threads have been joined, so
    failures are not silently lost, matching the sequential mode where exceptions
    propagate directly.

    Args:
        func: The func to execute.
        args_iterable: An iterable of arg tuples to feed to func.
        multithreaded: Whether to parallelize the func across threads.

    Raises:
        Exception: Any exception raised by ``func``.
    """
    if not multithreaded:
        for args in args_iterable:
            func(*args)
        return

    errors = []

    def _run(*args):
        # Record the exception instead of letting it die with the worker thread,
        # where the default threading.excepthook would only print it.
        try:
            func(*args)
        except Exception as err:
            errors.append(err)

    threads = []
    for args in args_iterable:
        thread = threading.Thread(target=_run, args=args)
        thread.start()
        threads.append(thread)
    for thread in threads:
        thread.join()

    if errors:
        raise errors[0]
29 |
--------------------------------------------------------------------------------
/torchkit/utils/seed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | import torch
6 |
7 |
def seed_rngs(seed: int, pytorch: bool = True) -> None:
    """Seed the Python, NumPy and (optionally) PyTorch random number generators.

    ``PYTHONHASHSEED`` is also exported to the environment. Python reads that
    variable at interpreter start-up, so it only affects child processes started
    afterwards, not string hashing in the current process.

    Args:
        seed (int): The desired seed.
        pytorch (bool, optional): Whether to also seed the `torch` RNG via
            ``torch.manual_seed``. Defaults to True.
    """
    seeders = [random.seed, np.random.seed]
    if pytorch:
        seeders.append(torch.manual_seed)

    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in seeders:
        seeder(seed)
21 |
22 |
def set_cudnn(deterministic: bool = False, benchmark: bool = True) -> None:
    """Set PyTorch-related CUDNN settings.

    For reproducible runs use ``deterministic=True, benchmark=False``.

    Args:
        deterministic (bool, optional): Whether cuDNN should restrict itself to
            deterministic convolution algorithms. Defaults to False.
        benchmark (bool, optional): Whether cuDNN should benchmark candidate
            convolution algorithms and pick the fastest one. This can speed up
            training with fixed input sizes, but the selected algorithm may differ
            between runs, making results non-deterministic. Defaults to True.
    """
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = benchmark
34 |
--------------------------------------------------------------------------------
/torchkit/utils/io.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pickle
4 | from typing import Any
5 |
6 |
def save_pickle(obj: Any, path: str, name: str) -> None:
    """Save a python object as a pickle file.

    The directory is created if it does not exist yet. The object is first pickled
    into a temporary file, which is then renamed over ``path/name``. As a result,
    a failing or interrupted dump never truncates or corrupts an existing file of
    the same name.

    Args:
        obj (Any): The object to save.
        path (str): Directory wherein to save file. Created if missing.
        name (str): Name of the pickle file.
    """
    if path:
        os.makedirs(path, exist_ok=True)
    filename = os.path.join(path, name)
    tmp_filename = f"{filename}.tmp"
    try:
        with open(tmp_filename, "wb") as fp:
            pickle.dump(obj, fp)
        # os.replace swaps the file in one step (atomic on POSIX).
        os.replace(tmp_filename, filename)
    except BaseException:
        # Don't leave a half-written temporary file behind.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)
        raise
    logging.info(f"Successfully saved {filename}")
19 |
20 |
def load_pickle(path: str, name: str) -> Any:
    """Read back the object stored in the pickle file ``path/name``.

    Warning:
        Unpickling can execute arbitrary code; only load files from trusted
        sources.

    Args:
        path (str): Directory containing the pickle file.
        name (str): File name of the pickle file inside ``path``.

    Returns:
        Any: The unpickled object.
    """
    pickle_path = os.path.join(path, name)
    with open(pickle_path, "rb") as handle:
        loaded = pickle.load(handle)
    logging.info(f"Successfully loaded {pickle_path}.")
    return loaded
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Kevin Zakka
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from setuptools import find_packages, setup
4 |
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
DESCRIPTION = (
    "torchkit is a lightweight library containing PyTorch utilities useful for "
    "day-to-day research."
)
# Testing and linting tools.
TESTS_REQUIRE = ["pytest", "black", "isort", "pytype", "flake8"]
# Documentation build dependencies.
DOCS_REQUIRE = [
    "sphinx",
    "sphinx-autodoc-typehints",
    "sphinx-rtd-theme",
    "docutils==0.16",
]
20 |
21 |
def readme() -> str:
    """Load README for use as package's long description.

    The file is decoded as UTF-8 explicitly so that building/installing the package
    does not depend on the platform's default locale encoding.
    """
    with open(os.path.join(THIS_DIR, "README.md"), "r", encoding="utf-8") as fp:
        return fp.read()
26 |
27 |
def get_version() -> str:
    """Return ``__version__`` as defined in torchkit/version.py.

    The file is exec'd rather than imported because importing the ``torchkit``
    package could require runtime dependencies that are not installed yet when
    setup.py runs.
    """
    version_file = os.path.join(THIS_DIR, "torchkit", "version.py")
    namespace: dict = {}
    with open(version_file, "r", encoding="utf-8") as fp:
        # A single dict for globals and locals gives the file normal module-level
        # scoping. With separate dicts, names it defines would be invisible inside
        # any comprehension or function it contains. It also keeps setup.py's own
        # globals out of it.
        exec(fp.read(), namespace)
    return namespace["__version__"]
33 |
34 |
# Runtime dependencies installed alongside the package.
INSTALL_REQUIRES = [
    "torch>=1.3",
    "torchvision>=0.4",
    "tensorboard",
    "prettytable",
    "opencv-python",
    "moviepy",
    "ml_collections",
    "ipdb",
]

setup(
    # Package metadata.
    name="torchkit",
    version=get_version(),
    description=DESCRIPTION,
    long_description=readme(),
    long_description_content_type="text/markdown",
    url="https://github.com/kevinzakka/torchkit/",
    author="Kevin Zakka",
    license="MIT",
    # Packaging and dependencies.
    python_requires=">=3.8",
    packages=find_packages(),
    install_requires=INSTALL_REQUIRES,
    extras_require={
        "test": TESTS_REQUIRE,
        "dev": ["jupyter", *TESTS_REQUIRE, *DOCS_REQUIRE],
    },
    tests_require=TESTS_REQUIRE,
)
62 |
--------------------------------------------------------------------------------
/torchkit/utils/pdb_fallback.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from typing import Callable, TypeVar
3 |
CallableType = TypeVar("CallableType", bound=Callable)


# Reference: https://github.com/brentyi/fannypack/blob/master/fannypack/utils/_pdb_safety_net.py # noqa: E501
def pdb_fallback(f: CallableType, use_ipdb: bool = False) -> CallableType:
    """Wrap ``f`` so that a debugger opens on Ctrl+C or on an uncaught exception.

    Each call of the wrapped function installs a SIGINT handler that starts the
    debugger, and replaces ``sys.excepthook`` with a hook that prints the traceback
    and then starts a post-mortem session. This is helpful for bypassing minor
    errors, diagnosing problems, and rescuing unsaved models.

    Example usage::

        from torchkit.utils import pdb_fallback

        @pdb_fallback
        def main():
            # A very interesting function that might fail because we did something
            # stupid.
            ...

        if __name__ == "__main__":
            main()

    Args:
        f (CallableType): The function to wrap.
        use_ipdb (bool, optional): Use ipdb instead of the standard pdb. Defaults
            to False.
    """
    import signal
    import sys
    import traceback

    if use_ipdb:
        import ipdb as debugger
    else:
        import pdb as debugger

    def _on_sigint(signum, frame):
        # Ctrl+C drops into the debugger instead of raising KeyboardInterrupt.
        debugger.set_trace()

    def _on_uncaught(exc_type, exc_value, exc_tb):
        # Show the error first, then inspect the failing frames post mortem.
        traceback.print_exception(exc_type, exc_value, exc_tb, limit=100)
        debugger.post_mortem(exc_tb)

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        signal.signal(signal.SIGINT, _on_sigint)
        sys.excepthook = _on_uncaught
        return f(*args, **kwargs)

    return wrapper
61 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
"""Configuration file for the Sphinx documentation builder."""

import os

import torchkit

# -- Project information -----------------------------------------------------

project = "torchkit"
copyright = "2021, Kevin Zakka"
author = "Kevin Zakka"
release = torchkit.__version__

# -- General configuration ---------------------------------------------------

master_doc = "index"

extensions = [
    "sphinx.ext.napoleon",
    "sphinx.ext.autodoc",
    "sphinx_autodoc_typehints",
    "sphinx.ext.autosummary",
    "sphinx.ext.mathjax",
    "sphinx.ext.viewcode",
    "sphinx_rtd_theme",
]


autodoc_typehints = "description"

napoleon_google_docstring = True
napoleon_numpy_docstring = False

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

autodoc_default_options = {
    "members": True,
    "undoc-members": True,
    "special-members": "__init__",
}


# -- Options for HTML output -------------------------------------------------

html_theme = "sphinx_rtd_theme"
html_theme_options = {
    "logo_only": True,
    "style_nav_header_background": "#06203A",
}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
# Only register directories that actually exist. Git does not track empty
# directories, so "_static" may be missing from a fresh checkout, and Sphinx
# warns about html_static_path entries that do not exist.
_DOCS_DIR = os.path.dirname(os.path.abspath(__file__))
html_static_path = [
    path for path in ["_static"] if os.path.isdir(os.path.join(_DOCS_DIR, path))
]

# ---------------------------------------------------------------------------

# NOTE(review): these settings are only read by the `sphinxtogithub` extension,
# which is not listed in `extensions` above; they currently have no effect.
sphinx_to_github = True
sphinx_to_github_verbose = True
sphinx_to_github_encoding = "utf-8"
65 |
--------------------------------------------------------------------------------
/torchkit/utils/module_stats.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
# Reference: https://stackoverflow.com/a/62508086
def get_total_params(
    model: torch.nn.Module,
    trainable: bool = True,
    print_table: bool = False,
) -> int:
    """Get the total number of parameters in a PyTorch model.

    Example usage::

        class SimpleMLP(nn.Module):
            def __init__(self):
                super().__init__()

                self.fc1 = nn.Linear(3, 16)
                self.fc2 = nn.Linear(16, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                out = F.relu(self.fc1(x))
                return self.fc2(out)

        net = SimpleMLP()
        num_params = get_total_params(net, print_table=True)

        # prints the following:
        +------------+------------+
        |  Modules   | Parameters |
        +------------+------------+
        | fc1.weight |     48     |
        |  fc1.bias  |     16     |
        | fc2.weight |     32     |
        |  fc2.bias  |     2      |
        +------------+------------+
        Total Trainable Params: 98

    Args:
        model (torch.nn.Module): The pytorch model.
        trainable (bool, optional): Only consider trainable parameters (those with
            ``requires_grad`` set). Defaults to True.
        print_table (bool, optional): Print the parameters in a pretty table,
            followed by the total. Defaults to False.

    Returns:
        int: Either all model parameters or only the trainable ones.
    """
    counts = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        # When only trainable parameters are requested, skip frozen ones.
        if parameter.requires_grad or not trainable
    ]
    total_params = sum(count for _, count in counts)

    if print_table:
        # Imported lazily: prettytable is only needed for the optional table.
        from prettytable import PrettyTable

        table = PrettyTable(["Modules", "Parameters"])
        for name, count in counts:
            table.add_row([name, count])
        print(table)
        # Only call the total "trainable" when frozen parameters were excluded.
        label = "Total Trainable Params" if trainable else "Total Params"
        print(f"{label}: {total_params:,}")

    return total_params
66 |
--------------------------------------------------------------------------------
/torchkit/experiment.py:
--------------------------------------------------------------------------------
1 | """Methods useful for running experiments."""
2 |
3 | import os
4 | import uuid
5 | from typing import Any
6 |
7 | from ml_collections import config_dict
8 |
9 | from torchkit.utils import dump_config, git_revision_hash, load_config
10 |
11 | ConfigDict = config_dict.ConfigDict
12 |
13 |
def string_from_kwargs(**kwargs: Any) -> str:
    """Build an experiment name from keyword arguments.

    Every ``key=value`` pair is rendered with an f-string and the pairs are joined
    with underscores, in the order the keyword arguments were passed.

    Example::

        string_from_kwargs(lr=0.1, bs=32)  # -> "lr=0.1_bs=32"

    Returns:
        str: The underscore-separated string (empty when no kwargs are given).
    """
    rendered_pairs = (f"{key}={value}" for key, value in kwargs.items())
    return "_".join(rendered_pairs)
20 |
21 |
def unique_id() -> str:
    """Return a random (version 4) RFC 4122 UUID in its canonical string form.

    See https://docs.python.org/3/library/uuid.html
    """
    identifier = uuid.uuid4()
    return f"{identifier}"
26 |
27 |
def setup_experiment(
    exp_dir: str,
    config: ConfigDict,
    resume: bool = False,
) -> None:
    """Prepare an experiment directory, or reattach to an existing one.

    What happens depends on whether ``exp_dir`` already exists:

    * It does not exist: ``config`` is written into it as a yaml file via
      ``dump_config`` (which creates the directory).
    * It exists and ``resume`` is False: a ``ValueError`` is raised so that a
      previous run is not overwritten.
    * It exists and ``resume`` is True: ``config`` is updated in place with the
      values stored in the directory's yaml file via ``load_config``.

    In both non-raising cases, the current git revision hash returned by
    ``git_revision_hash`` is then written to ``git_hash.txt`` inside ``exp_dir``,
    replacing any previous contents of that file.

    Args:
        exp_dir (str): Path to the experiment directory.
        config (ConfigDict): The config for the experiment.
        resume (bool, optional): Whether to resume from a previously created
            experiment. Defaults to False.

    Raises:
        ValueError: If the experiment directory exists already and resume is not
            set to True.
    """
    already_exists = os.path.exists(exp_dir)
    if already_exists and not resume:
        raise ValueError(
            "Experiment already exists. Run with --resume to continue."
        )

    if already_exists:
        # Resuming: pull the saved values back into the caller's config object.
        load_config(exp_dir, config)
    else:
        dump_config(exp_dir, config)

    git_hash_path = os.path.join(exp_dir, "git_hash.txt")
    with open(git_hash_path, "w") as fp:
        fp.write(git_revision_hash())
66 |
--------------------------------------------------------------------------------
/tests/test_layers.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch.testing import assert_allclose
4 |
5 | from torchkit import layers
6 |
7 |
class TestLayers:
    """Behavioral checks for the building blocks in ``torchkit.layers``."""

    @staticmethod
    def _assert_global_pool(pool, shape, reduction, axes):
        """Compare ``pool`` with the numpy ``reduction`` (e.g. "max") over ``axes``."""
        inputs = torch.randn(*shape)
        reference = getattr(inputs.numpy(), reduction)(axis=axes)
        assert_allclose(pool(inputs), reference)

    @pytest.mark.parametrize("kernel_size", [1, 3, 5])
    def test_conv2d_same_shape(self, kernel_size):
        # "Same" padding: the spatial dims (H, W) of the output match the input.
        batch, channels, height, width = 32, 64, 16, 16
        inputs = torch.randn(batch, channels, height, width)
        conv = layers.conv2d(channels, channels * 2, kernel_size=kernel_size)
        assert conv(inputs).shape[2:] == inputs.shape[2:]

    def test_spatial_soft_argmax(self):
        batch, channels, height, width = 32, 64, 16, 16
        heatmaps = torch.zeros(batch, channels, height, width)
        peaks = torch.randint(0, 10, size=(batch, channels, 2))
        # One dominant activation per map, so the soft argmax should land on it.
        for bi in range(batch):
            for ci in range(channels):
                row, col = peaks[bi, ci]
                heatmaps[bi, ci, row, col] = 1000
        predicted = layers.SpatialSoftArgmax(normalize=False)(heatmaps)
        assert_allclose(peaks.float(), predicted.reshape(batch, channels, 2))

    def test_global_max_pool_1d(self):
        self._assert_global_pool(layers.GlobalMaxPool1d(), (4, 3, 16), "max", -1)

    def test_global_max_pool_2d(self):
        self._assert_global_pool(
            layers.GlobalMaxPool2d(), (4, 3, 16, 16), "max", (-1, -2)
        )

    def test_global_max_pool_3d(self):
        self._assert_global_pool(
            layers.GlobalMaxPool3d(), (4, 16, 5, 16, 16), "max", (-1, -2, -3)
        )

    def test_global_average_pool_1d(self):
        self._assert_global_pool(layers.GlobalAvgPool1d(), (4, 3, 16), "mean", -1)

    def test_global_average_pool_2d(self):
        self._assert_global_pool(
            layers.GlobalAvgPool2d(), (4, 3, 16, 16), "mean", (-1, -2)
        )

    def test_global_average_pool_3d(self):
        self._assert_global_pool(
            layers.GlobalAvgPool3d(), (4, 16, 5, 16, 16), "mean", (-1, -2, -3)
        )
73 |
--------------------------------------------------------------------------------
/tests/test_checkpoint.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torchkit.checkpoint import Checkpoint, CheckpointManager
7 |
8 |
@pytest.fixture
def init_model_and_optimizer():
    """Provide a fresh two-layer MLP (3 -> 16 -> 2) and an Adam optimizer for it."""

    class SimpleMLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(3, 16)
            self.fc2 = nn.Linear(16, 2)

        def forward(self, x):
            hidden = F.relu(self.fc1(x))
            return self.fc2(hidden)

    net = SimpleMLP()
    adam = torch.optim.Adam(net.parameters(), lr=1e-3)
    return net, adam
24 |
25 |
class TestCheckpoint:
    """Round-trip tests for ``Checkpoint`` and ``CheckpointManager``."""

    @staticmethod
    def _save_to_tmp(tmp_path, checkpoint):
        """Save ``checkpoint`` to ``<tmp_path>/ckpts/test.ckpt`` and return that path."""
        checkpoint_dir = tmp_path / "ckpts"
        checkpoint_dir.mkdir()
        checkpoint_path = checkpoint_dir / "test.ckpt"
        checkpoint.save(checkpoint_path)
        return checkpoint_path

    def test_checkpoint_save_restore(self, tmp_path, init_model_and_optimizer):
        model, optimizer = init_model_and_optimizer
        checkpoint = Checkpoint(model=model, optimizer=optimizer)
        checkpoint_path = self._save_to_tmp(tmp_path, checkpoint)
        assert checkpoint.restore(checkpoint_path)

    def test_checkpoint_save_partial_restore(self, tmp_path, init_model_and_optimizer):
        model, optimizer = init_model_and_optimizer
        checkpoint = Checkpoint(model=model, optimizer=optimizer)
        checkpoint_path = self._save_to_tmp(tmp_path, checkpoint)
        # Each component can be restored on its own from a full checkpoint.
        assert Checkpoint(model=model).restore(checkpoint_path)
        assert Checkpoint(optimizer=optimizer).restore(checkpoint_path)

    def test_checkpoint_save_faulty_restore(self, tmp_path, init_model_and_optimizer):
        model, optimizer = init_model_and_optimizer
        checkpoint = Checkpoint(model=model, optimizer=optimizer)
        checkpoint_path = self._save_to_tmp(tmp_path, checkpoint)
        # Purposely modify the model so the saved state no longer matches it.
        model.fc2 = nn.Linear(2, 2)
        assert not checkpoint.restore(checkpoint_path)

    def test_checkpoint_manager(self, tmp_path, init_model_and_optimizer):
        model, optimizer = init_model_and_optimizer
        ckpt_dir = tmp_path / "ckpts"
        checkpoint_manager = CheckpointManager(
            ckpt_dir,
            max_to_keep=5,
            model=model,
            optimizer=optimizer,
        )
        # Nothing has been saved yet, so the manager starts from step 0.
        assert checkpoint_manager.restore_or_initialize() == 0
        for step in range(10):
            checkpoint_manager.save(step)
        available_ckpts = checkpoint_manager.list_checkpoints(ckpt_dir)
        assert len(available_ckpts) == 5
        # Only the 5 most recent steps should survive, in order. A direct list
        # comparison (unlike a zip-based check) cannot silently truncate and
        # reports a readable diff on failure.
        assert [int(d.stem) for d in available_ckpts] == list(range(5, 10))
        assert checkpoint_manager.restore_or_initialize() == 9
        assert int(checkpoint_manager.latest_checkpoint.stem) == 9
77 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # torchkit
2 |
3 | [](https://kevinzakka.github.io/torchkit/)
4 | 
5 | 
6 |
7 | **`torchkit`** is a *lightweight* library of PyTorch utilities for day-to-day research. Its main goal is to abstract away much of the redundant boilerplate that research projects share, such as experiment configuration, logging and model checkpointing. It consists of:
8 |
9 |
10 |
11 |
12 | torchkit.Logger |
13 |
14 | A wrapper around TensorBoard's SummaryWriter for safe
15 | logging of scalars, images, videos and learning rates. Supports both NumPy arrays and torch Tensors.
16 | |
17 |
18 |
19 | torchkit.CheckpointManager |
20 |
21 | A port of TensorFlow's checkpoint manager that automatically manages multiple checkpoints in an experiment run.
22 | |
23 |
24 |
25 | torchkit.experiment |
26 |
27 | A collection of methods for setting up experiment directories.
28 | |
29 |
30 |
31 | torchkit.layers |
32 |
33 | A set of layers commonly used in research papers but not available in vanilla PyTorch, such as "same" and "causal" convolutions and SpatialSoftArgmax.
34 | |
35 |
36 |
37 | torchkit.losses |
38 |
39 | Useful loss functions that are also unavailable in vanilla PyTorch, such as cross-entropy with label smoothing and the Huber loss.
40 | |
41 |
42 |
43 | torchkit.utils |
44 |
45 | A bunch of helper functions for config manipulation, I/O, timing, debugging, etc.
46 | |
47 |
48 |
49 |
50 |
51 | For more details about each module, see the [documentation](https://kevinzakka.github.io/torchkit/).
52 |
53 | ### Installation
54 |
55 | To install the latest release, run:
56 |
57 | ```bash
58 | pip install git+https://github.com/kevinzakka/torchkit.git
59 | ```
60 |
61 | ### Contributing
62 |
63 | For development, create and activate a virtual environment for this project, then clone the source code and install it in editable mode with the development extras:
64 |
65 | ```bash
66 | git clone https://github.com/kevinzakka/torchkit.git
67 | cd torchkit
68 | pip install -e ".[dev]"
69 | ```
70 |
71 | ### Acknowledgments
72 |
73 | * Thanks to Karan Desai's [VirTex](https://github.com/kdexd/virtex), which I used to figure out the documentation setup for torchkit, and which is an excellent example of a stellar open-source research release.
74 | * Thanks to [seals](https://github.com/HumanCompatibleAI/seals) for the excellent software development
75 | practices that I've tried to emulate in this repo.
76 | * Thanks to Brent Yi for encouraging me to use type hinting and for letting me use his awesome [Bayesian filtering library](https://github.com/stanford-iprl-lab/torchfilter)'s README as a template.
77 |
--------------------------------------------------------------------------------
/torchkit/viz.py:
--------------------------------------------------------------------------------
1 | """Tools for visualizing resnet feature maps and ViT attention weights."""
2 |
3 | from typing import Optional, Tuple, Union
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torchvision
9 | from PIL import Image
10 | from torchvision import transforms as T
11 |
12 |
13 | def _load_model(
14 | model_name: str,
15 | device: torch.device,
16 | ):
17 | if "dino" in model_name:
18 | assert model_name in [
19 | "dino_vits16",
20 | "dino_vits8",
21 | "dino_vitb16",
22 | "dino_vitb8",
23 | "dino_resnet50",
24 | ], f"{model_name} is not a valid DINO pretrained model."
25 | model = torch.hub.load("facebookresearch/dino:main", model_name)
26 | normalize = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
27 | elif "resnet" in model_name:
28 | model = getattr(torchvision.models, model_name)(pretrained=True)
29 | normalize = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
30 | for p in model.parameters():
31 | p.requires_grad = False
32 | model.eval()
33 | model.to(device)
34 | return model, normalize
35 |
36 |
def visualize_attention(
    model_name: str,
    image: Union[str, np.ndarray],
    image_size: Tuple[int, int] = (480, 480),
    resnet_layer_idx: Optional[int] = -2,
) -> np.ndarray:
    """Compute a self-attention (ViT) or mean-activation (ResNet) map for an image.

    For DINO vision transformers (``model_name`` contains "dino" but not "resnet"),
    the last self-attention layer is queried and, for every head, the attention
    weights of query row 0 over the remaining tokens are returned. Row 0 is
    presumably the CLS token in DINO ViTs. For ResNets (torchvision models or
    ``dino_resnet50``), the model is truncated to
    ``list(model.children())[:resnet_layer_idx]`` and its output is averaged over
    the channel dimension. In both cases the map is upsampled with nearest-neighbor
    interpolation to the spatial size of the (possibly cropped) input.

    Args:
        model_name (str): Model to load; see ``_load_model`` for accepted names.
        image (Union[str, np.ndarray]): Path to an image file, or an array accepted
            by ``PIL.Image.fromarray``. The image is converted to RGB.
        image_size (Tuple[int, int], optional): Size passed to ``T.Resize`` before
            inference. Defaults to (480, 480).
        resnet_layer_idx (Optional[int], optional): Slice end applied to the
            ResNet's top-level children. Ignored for ViTs. Defaults to -2.

    Returns:
        np.ndarray: For ViTs, an array of shape ``(num_heads, H', W')``, where
        ``H'`` and ``W'`` are the input height and width cropped to multiples of the
        patch size. For ResNets, an array of shape ``(1, H, W)``. This assumes the
        truncated model outputs a ``(1, C, h, w)`` feature map; TODO confirm for
        non-default ``resnet_layer_idx``.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model, normalize = _load_model(model_name, device)

    preprocess = T.Compose(
        [
            T.Resize(image_size),
            T.ToTensor(),
            normalize,
        ]
    )

    # Force 3 channels: RGBA or grayscale inputs would otherwise not match the
    # 3-channel normalization. The context manager closes the file handle.
    if isinstance(image, str):
        with Image.open(image) as pil_image:
            img = pil_image.convert("RGB")
    else:
        img = Image.fromarray(image).convert("RGB")
    img_tensor = preprocess(img)

    if "dino" in model_name and "resnet" not in model_name:
        # Assumes `patch_size` is an int (DINO ViT convention) -- TODO confirm.
        patch_size = model.patch_embed.patch_size
        # Crop height and width down to multiples of the patch size.
        h = img_tensor.shape[1] - img_tensor.shape[1] % patch_size
        w = img_tensor.shape[2] - img_tensor.shape[2] % patch_size
        img_tensor = img_tensor[:, :h, :w].unsqueeze(0).to(device)
        h_featmap = img_tensor.shape[-2] // patch_size
        w_featmap = img_tensor.shape[-1] // patch_size

        with torch.no_grad():
            attentions = model.get_last_selfattention(img_tensor)
        num_heads = attentions.shape[1]

        # Keep row 0's attention over all other tokens, laid out on the patch grid.
        attentions = attentions[0, :, 0, 1:].reshape(num_heads, h_featmap, w_featmap)
        attentions = nn.functional.interpolate(
            attentions.unsqueeze(0),
            scale_factor=patch_size,
            mode="nearest",
        )[0]
    else:
        model = nn.Sequential(*list(model.children())[:resnet_layer_idx])

        img_tensor = img_tensor.unsqueeze(0).to(device)
        img_h, img_w = img_tensor.shape[-2], img_tensor.shape[-1]

        with torch.no_grad():
            features = model(img_tensor)

        # Average over all feature maps (channel dimension).
        activation = features.mean(dim=1)

        # Use an explicit target size: a float scale factor can floor to one
        # pixel short of the input resolution.
        attentions = nn.functional.interpolate(
            activation.unsqueeze(0),
            size=(img_h, img_w),
            mode="nearest",
        )[0]

    return attentions.cpu().numpy()
111 |
--------------------------------------------------------------------------------
/tests/test_losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 | from torch.testing import assert_allclose
5 |
6 | from torchkit import losses
7 |
8 |
def log_softmax(x):
    """Numerically stable log-softmax over the last axis of a numpy array.

    The per-row maximum is subtracted before exponentiating so ``np.exp`` cannot
    overflow. The shift does not change the result, because log-softmax is
    invariant to adding a constant to every entry of a row.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    log_normalizer = np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))
    return shifted - log_normalizer
13 |
14 |
class TestLayers:
    """Tests for ``torchkit.losses`` against numpy reference implementations.

    NOTE(review): the class name looks copy-pasted from test_layers.py; consider
    renaming it to ``TestLosses`` (this changes pytest node IDs).
    """

    @pytest.mark.parametrize("smooth_eps", [0, 0.5, 1])
    @pytest.mark.parametrize("K", [2, 100])
    def test_one_hot(self, smooth_eps, K):
        batch_size = 32
        y = torch.randint(K, (batch_size,))
        y_np = y.numpy()
        actual = losses.one_hot(y, K, smooth_eps)
        # Reference uses the same smoothing formula as `losses.one_hot`.
        y_np_one_hot = np.eye(K)[y_np]
        expected = y_np_one_hot * (1 - smooth_eps) + (smooth_eps / (K - 1))
        assert_allclose(actual, expected)

    def test_one_hot_not_rank_one(self):
        labels = torch.randint(5, (2, 2))
        with pytest.raises(AssertionError):
            losses.one_hot(labels, 5, 0)

    @pytest.mark.parametrize("smooth_eps", [-1, 2])
    def test_one_hot_eps_out_of_bounds(self, smooth_eps):
        # Valid rank-1 labels, so only the out-of-range `smooth_eps` can trigger the
        # assertion (rank-2 labels would make this test pass even without the
        # bounds check).
        labels = torch.randint(5, (2,))
        with pytest.raises(AssertionError):
            losses.one_hot(labels, 5, smooth_eps)

    @pytest.mark.parametrize("smooth_eps", [0, 0.5, 1])
    @pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
    def test_cross_entropy(self, smooth_eps, reduction):
        batch_size = 200
        K = 50
        labels = torch.randint(K, (batch_size,))
        logits = torch.randn(batch_size, K)
        actual = losses.cross_entropy(logits, labels, smooth_eps, reduction)
        logits_np = logits.numpy()
        labels_np = labels.numpy()
        labels_np_one_hot = np.eye(K)[labels_np] * (1 - smooth_eps) + (
            smooth_eps / (K - 1)
        )
        # Compute log softmax of logits.
        log_probs_np = log_softmax(logits_np)
        loss_np = (-labels_np_one_hot * log_probs_np).sum(axis=-1)
        if reduction == "mean":
            expected = loss_np.mean()
        elif reduction == "sum":
            expected = loss_np.sum(axis=-1)
        else:  # none
            expected = loss_np
        assert_allclose(actual, expected)

    def test_cross_entropy_unsupported_reduction(self):
        K, batch_size = 50, 200
        labels = torch.randint(K, (batch_size,))
        logits = torch.randn(batch_size, K)
        with pytest.raises(AssertionError):
            losses.cross_entropy(logits, labels, reduction="average")

    def test_cross_entropy_labels_dim(self):
        K, batch_size = 50, 200
        labels = torch.randint(K, (batch_size, 2))
        logits = torch.randn(batch_size, K)
        # A valid reduction, so only the rank-2 labels can trigger the assertion.
        with pytest.raises(AssertionError):
            losses.cross_entropy(logits, labels, reduction="mean")

    @pytest.mark.parametrize("delta", [1.0, 2.0, 10.0])
    @pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
    def test_huber_loss(self, delta, reduction):
        batch_size = 200
        num_dims = 50
        target = torch.randn(batch_size, num_dims)
        input = torch.randn(batch_size, num_dims)
        target_np = target.numpy()
        input_np = input.numpy()
        actual = losses.huber_loss(input, target, delta, reduction)
        diff_np = target_np - input_np
        diff_abs_np = np.abs(diff_np)
        cond = diff_abs_np <= delta
        out = np.where(
            cond, 0.5 * diff_np ** 2, (delta * diff_abs_np) - (0.5 * delta ** 2)
        )
        if reduction == "mean":
            expected = out.mean()
        elif reduction == "sum":
            expected = out.sum(axis=-1)
        else:  # none
            expected = out
        assert_allclose(actual, expected)
97 |
--------------------------------------------------------------------------------
/tests/test_logger.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 |
5 | from torchkit.logger import Logger
6 |
7 |
@pytest.fixture
def init_logger(tmp_path):
    """Yield a fresh ``Logger`` writing to ``<tmp_path>/logs``.

    After the test finishes, the logger is closed (``Logger.close`` closes its
    summary writer) so writer resources are not leaked across tests.
    """
    log_dir = tmp_path / "logs"
    logger = Logger(log_dir, force_write=False)
    yield logger
    logger.close()
13 |
14 |
class TestLogger:
    """Tests for ``torchkit.logger.Logger``."""

    def test_error_on_init_existing(self, tmp_path, init_logger):
        # Requesting the fixture is expected to have created `tmp_path / "logs"`.
        # A second logger on that directory without `force_write` must refuse.
        del init_logger
        with pytest.raises(ValueError):
            Logger(tmp_path / "logs")

    @pytest.mark.parametrize(
        "scalar",
        [torch.FloatTensor([5.0]), torch.randn([2, 2]).mean(), 5.0],
    )
    def test_log_scalar(self, init_logger, scalar):
        init_logger.log_scalar(scalar, 0, "loss", "training")

    def test_log_scalar_notscalar(self, init_logger):
        two_values = torch.FloatTensor([5.0, 5.0])
        with pytest.raises(ValueError):
            init_logger.log_scalar(two_values, 0, "loss", "training")

    @pytest.mark.parametrize(
        "image",
        [
            np.random.randint(0, 256, size=(224, 224, 3)),  # numpy HWC
            np.random.randint(0, 256, size=(2, 224, 224, 3)),  # numpy BHWC
            torch.randint(0, 256, size=(3, 224, 224)),  # torch CHW
            torch.randint(0, 256, size=(2, 3, 224, 224)),  # torch BCHW
        ],
    )
    def test_log_image(self, init_logger, image):
        init_logger.log_image(image, 0, "image", "validation")

    @pytest.mark.parametrize(
        "image",
        [
            np.random.randint(0, 256, size=(3, 224, 224)),  # numpy channel-first
            np.random.randint(0, 256, size=(2, 3, 224, 224)),
            torch.randint(0, 256, size=(224, 224, 3)),  # torch channel-last
            torch.randint(0, 256, size=(2, 224, 224, 3)),
        ],
    )
    def test_log_image_wrong_format(self, init_logger, image):
        with pytest.raises(TypeError):
            init_logger.log_image(image, 0, "image", "validation")

    @pytest.mark.parametrize(
        "video",
        [
            np.random.randint(0, 256, size=(5, 224, 224, 3)),  # numpy THWC
            np.random.randint(0, 256, size=(4, 5, 224, 224, 3)),  # numpy BTHWC
            torch.randint(0, 256, size=(5, 3, 224, 224)),  # torch TCHW
            torch.randint(0, 256, size=(4, 5, 3, 224, 224)),  # torch BTCHW
        ],
    )
    def test_log_video(self, init_logger, video):
        init_logger.log_video(video, 0, "video", "training")

    def test_log_video_wrongdim(self, init_logger):
        single_image = np.random.randint(0, 256, (224, 224, 3))
        with pytest.raises(ValueError):
            init_logger.log_video(single_image, 0, "video", "training")

    @pytest.mark.parametrize(
        "video",
        [
            np.random.randint(0, 256, size=(5, 3, 224, 224)),  # numpy channel-first
            np.random.randint(0, 256, size=(4, 5, 3, 224, 224)),
            torch.randint(0, 256, size=(5, 224, 224, 3)),  # torch channel-last
            torch.randint(0, 256, size=(4, 5, 224, 224, 3)),
        ],
    )
    def test_log_video_wrongformat(self, init_logger, video):
        with pytest.raises(TypeError):
            init_logger.log_video(video, 0, "video", "training")

    def test_learning_rate(self, init_logger):
        weights = torch.randn((32, 3), requires_grad=True)
        adam = torch.optim.Adam([weights], lr=1e-3)
        init_logger.log_learning_rate(adam, 0, "training")

    def test_learning_rate_notoptim(self, init_logger):
        weights = torch.randn((32, 3), requires_grad=True)
        with pytest.raises(TypeError):
            init_logger.log_learning_rate(weights, 0, "training")
105 |
--------------------------------------------------------------------------------
/torchkit/utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional, Union
3 |
4 | import yaml
5 | from ml_collections import config_dict
6 |
7 | ConfigDict = config_dict.ConfigDict
8 | FrozenConfigDict = config_dict.FrozenConfigDict
9 |
10 |
# Reference: https://github.com/deepmind/jaxline/blob/master/jaxline/base_config.py
def validate_config(
    config: ConfigDict,
    base_config: ConfigDict,
    base_filename: str,
) -> None:
    """Check that ``config`` provides every key defined by ``base_config``.

    The check recurses into values that are ``ConfigDict`` instances in the base
    config, unless the child's corresponding value is ``None``. Extra keys present
    only in ``config`` are allowed.

    Args:
        config (ConfigDict): The child config.
        base_config (ConfigDict): The base config.
        base_filename (str): Path to python file containing base config definition,
            quoted in the error message.

    Raises:
        ValueError: if the base config contains keys that are not present in config.
    """
    for key in base_config.keys():
        if key not in config:
            raise ValueError(
                f"Key {key} missing from config. This config is required to have "
                f"keys: {list(base_config.keys())}. See {base_filename} for more "
                "details."
            )
        nested_base = base_config[key]
        if not isinstance(nested_base, ConfigDict):
            continue
        if config[key] is None:
            continue
        validate_config(config[key], nested_base, base_filename)
36 |
37 |
def dump_config(exp_dir: str, config: Union[ConfigDict, FrozenConfigDict]) -> None:
    """Write ``config`` to ``<exp_dir>/config.yaml``.

    The directory, and any missing parent directories, are created if needed. An
    existing ``config.yaml`` is overwritten.

    Args:
        exp_dir (str): Path to the experiment directory.
        config (Union[ConfigDict, FrozenConfigDict]): The config to dump. It is
            serialized via ``config.to_dict()`` and ``yaml.dump``.
    """
    # `exist_ok=True` replaces a separate `os.path.exists` check, which could race
    # with another process creating the directory in between.
    os.makedirs(exp_dir, exist_ok=True)

    # Mode "w" truncates any previous config file, so no explicit delete is needed.
    with open(os.path.join(exp_dir, "config.yaml"), "w") as fp:
        yaml.dump(config.to_dict(), fp)
52 |
53 |
def load_config(
    exp_dir: str,
    config: Optional[ConfigDict] = None,
    freeze: bool = False,
) -> Optional[Union[ConfigDict, FrozenConfigDict]]:
    """Read ``<exp_dir>/config.yaml`` back into a config object.

    Args:
        exp_dir (str): Path to the experiment directory.
        config (Optional[ConfigDict], optional): If given, this object is updated
            in place with the loaded values and ``None`` is returned (``freeze`` is
            then ignored). Defaults to None.
        freeze (bool, optional): When no ``config`` is passed, return a
            ``FrozenConfigDict`` instead of a ``ConfigDict``. Defaults to False.

    Returns:
        Optional[Union[ConfigDict, FrozenConfigDict]]: The loaded config, or
        ``None`` when an existing ``config`` was updated in place.
    """
    config_path = os.path.join(exp_dir, "config.yaml")
    with open(config_path, "r") as fp:
        # NOTE(review): FullLoader (rather than SafeLoader) can construct some
        # Python-specific tags; only point this at trusted experiment directories.
        loaded = yaml.load(fp, Loader=yaml.FullLoader)

    if config is None:
        return FrozenConfigDict(loaded) if freeze else ConfigDict(loaded)

    config.update(loaded)
    return None
81 |
82 |
def copy_config_and_replace(
    config: ConfigDict,
    update_dict: Optional[ConfigDict] = None,
    freeze: bool = False,
) -> Union[ConfigDict, FrozenConfigDict]:
    """Return a copy of ``config``, optionally overriding some of its values.

    Args:
        config (ConfigDict): The config to copy. It is not modified.
        update_dict (Optional[ConfigDict], optional): Values applied to the copy
            with ``update``. Defaults to None (no overrides).
        freeze (bool, optional): Wrap the result in a ``FrozenConfigDict``.
            Defaults to False.

    Returns:
        Union[ConfigDict, FrozenConfigDict]: The (possibly updated) copy.
    """
    # The ConfigDict constructor keeps `FieldReference`s intact, unlike
    # `ConfigDict.copy_and_resolve_references`.
    duplicate = ConfigDict(config)
    if update_dict is not None:
        duplicate.update(update_dict)
    return FrozenConfigDict(duplicate) if freeze else duplicate
108 |
--------------------------------------------------------------------------------
/torchkit/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | Tensor = torch.Tensor
5 |
6 |
def one_hot(
    y: Tensor,
    K: int,
    smooth_eps: float = 0,
) -> Tensor:
    """One-hot encodes a tensor, with optional label smoothing.

    Each row is ``onehot * (1 - smooth_eps) + smooth_eps / (K - 1)``. With
    ``smooth_eps == 0`` this is a plain one-hot encoding. For ``K == 1`` the
    off-value term is taken as 0, since there are no other classes.

    Args:
        y (Tensor): A tensor containing the ground-truth labels of shape `(N,)`, i.e.
            one label for each element in the batch.
        K (int): The number of classes.
        smooth_eps (float, optional): Label smoothing factor in `[0, 1]` range.
            Defaults to 0, which corresponds to no label smoothing.

    Returns:
        Tensor: A float tensor of shape `(N, K)` on the same device as `y`.

    Raises:
        AssertionError: If `smooth_eps` is outside `[0, 1]` or `y` is not rank 1.
    """
    assert 0 <= smooth_eps <= 1
    assert y.ndim == 1, "Label tensor must be rank 1."
    # Build the identity on `y`'s device. Indexing a CPU tensor with CUDA indices
    # raises, so "index on CPU, then move" failed for GPU labels.
    y_hot = torch.eye(K, device=y.device)[y]
    # NOTE(review): with this formula each row sums to 1 + smooth_eps / (K - 1),
    # not 1, because the true class gets (1 - eps) + eps / (K - 1). It is kept
    # unchanged because the unit tests pin this exact formula.
    off_value = smooth_eps / (K - 1) if K > 1 else 0.0
    return y_hot * (1 - smooth_eps) + off_value
28 |
29 |
def cross_entropy(
    logits: Tensor,
    labels: Tensor,
    smooth_eps: float = 0,
    reduction: str = "mean",
) -> Tensor:
    """Cross-entropy loss with support for label smoothing.

    With ``smooth_eps == 0`` this defers to ``F.cross_entropy``. Otherwise the
    labels are turned into smoothed one-hot targets with ``one_hot``, and the
    per-example loss is ``-(targets * log_softmax(logits)).sum(-1)``.

    Args:
        logits (Tensor): A float32 `FloatTensor` (CPU or CUDA) of raw logits, i.e.
            no softmax has been applied to the model output, with shape `(N, K)`,
            where K is the number of classes.
        labels (Tensor): A rank-1 `LongTensor` of shape `(N,)` containing the
            ground-truth class indices (not one-hot encoded).
        smooth_eps (float, optional): The label smoothing factor in `[0, 1]` range.
            Defaults to 0.
        reduction (str, optional): One of "none", "mean" or "sum". Defaults to
            "mean".

    Returns:
        Tensor: If reduction is `none`, a 1D tensor of shape `(N,)` with the
        per-example losses. If reduction is `mean` or `sum`, a 0-dim (scalar)
        tensor.

    Raises:
        AssertionError: If `logits` is not a float32 tensor, `labels` is not an
            int64 tensor, `labels` is not rank 1, or `reduction` is unsupported.
    """
    assert isinstance(logits, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert isinstance(labels, (torch.LongTensor, torch.cuda.LongTensor))
    assert reduction in ["none", "mean", "sum"], "reduction method is not supported"

    # Ensure labels are class indices, not 1-hot encoded.
    assert labels.ndim == 1, "[!] Labels are NOT expected to be 1-hot encoded."

    if smooth_eps == 0:
        return F.cross_entropy(logits, labels, reduction=reduction)

    # One-hot encode targets.
    labels = one_hot(labels, logits.shape[1], smooth_eps)

    # Convert logits to log probabilities.
    log_probs = F.log_softmax(logits, dim=-1)

    # Per-example loss of shape (N,).
    loss = (-labels * log_probs).sum(dim=-1)

    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
    return loss.sum(dim=-1)
76 |
77 |
def huber_loss(
    input: Tensor,
    target: Tensor,
    delta: float,
    reduction: str = "mean",
) -> Tensor:
    """Huber loss with tunable margin, as defined in `1`_.

    With ``diff = target - input``, each element contributes ``0.5 * diff**2``
    when ``|diff| <= delta`` (quadratic region) and
    ``delta * |diff| - 0.5 * delta**2`` otherwise (linear region).

    Args:
        input (Tensor): A float32 `FloatTensor` representing the model output.
        target (Tensor): A float32 `FloatTensor` representing the target values;
            it is subtracted elementwise from `input` (broadcasting applies).
        delta (float): The threshold on ``|diff|``. At or below it the penalty is
            quadratic; above it the penalty is linear.
        reduction (str, optional): One of "none", "mean" or "sum". Defaults to
            "mean".

    Returns:
        Tensor: If reduction is `none`, the elementwise loss. If `mean`, a 0-dim
        tensor averaged over all elements. If `sum`, the loss summed over the
        last dimension only, e.g. shape `(N,)` for `(N, D)` inputs.

    Raises:
        AssertionError: If `input` or `target` is not a float32 tensor, or
            `reduction` is unsupported.

    .. _1: https://en.wikipedia.org/wiki/Huber_loss
    """
    assert isinstance(input, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert isinstance(target, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert reduction in ["none", "mean", "sum"], "reduction method is not supported"

    diff = target - input
    diff_abs = torch.abs(diff)
    cond = diff_abs <= delta
    loss = torch.where(cond, 0.5 * diff ** 2, (delta * diff_abs) - (0.5 * delta ** 2))
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
    # Note: unlike "mean", "sum" reduces only the last dimension.
    return loss.sum(dim=-1)
115 |
--------------------------------------------------------------------------------
/torchkit/logger.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from typing import Type, Union, cast
3 |
4 | import numpy as np
5 | import torch
6 | import torchvision
7 | from torch.utils.tensorboard import SummaryWriter
8 |
9 | Tensor = torch.Tensor
10 | ImageType = Union[Tensor, np.ndarray]
11 |
12 |
13 | class Logger:
14 | """A Tensorboard-based logger."""
15 |
def __init__(self, log_dir: str, force_write: bool = False) -> None:
    """Create the logger and its underlying ``SummaryWriter``.

    Args:
        log_dir: The directory in which to store Tensorboard logs.
        force_write: If False (the default), refuse to reuse a ``log_dir`` that
            already exists. Set to `True` if resuming training.

    Raises:
        ValueError: If ``log_dir`` exists and ``force_write`` is False.
    """
    if not force_write and osp.exists(log_dir):
        raise ValueError(
            "You might be overwriting a directory that already "
            "has train_logs. Please provide a new experiment name "
            "or set --resume to True when launching train script."
        )
    self._writer = SummaryWriter(log_dir)
32 |
def close(self) -> None:
    """Close the underlying summary writer."""
    writer = self._writer
    writer.close()
35 |
def flush(self) -> None:
    """Flush the underlying summary writer."""
    writer = self._writer
    writer.flush()
38 |
def log_scalar(
    self,
    scalar: Union[Tensor, float],
    global_step: int,
    name: str,
    prefix: str = "",
) -> None:
    """Write a single scalar value to the summary writer.

    A tensor is accepted only if it is 0-dimensional or has shape ``(1,)``; it is
    converted to a Python number with ``.item()``. The tag is ``"prefix/name"``
    when ``prefix`` is non-empty, otherwise just ``name``.

    Args:
        scalar: A scalar `torch.Tensor` or float.
        global_step: The training iteration step.
        name: The name of the logged scalar.
        prefix: A prefix to prepend to the logged scalar.

    Raises:
        ValueError: If a tensor has more than one dimension, or is 1-D with a shape
            other than ``(1,)``.
        AssertionError: If the (converted) value fails ``np.isscalar``.
    """
    if isinstance(scalar, torch.Tensor):
        tensor = cast(torch.Tensor, scalar)
        holds_single_value = tensor.ndim == 0 or (
            tensor.ndim == 1 and tensor.shape == torch.Size([1])
        )
        if not holds_single_value:
            raise ValueError("Tensor must be scalar-valued.")
        scalar = tensor.item()
    assert np.isscalar(scalar), "Not a scalar."
    tag = "/".join((prefix, name)) if prefix else name
    self._writer.add_scalar(tag, scalar, global_step)
64 |
def log_image(
    self,
    image: ImageType,
    global_step: int,
    name: str,
    prefix: str = "",
    nrow: int = 5,
) -> None:
    """Log an image or batch of images.

    Args:
        image: A numpy ndarray or a torch Tensor. If the image is 4D (i.e.
            batched), it will be converted to a 3D image using make_grid.
            The numpy array should be in channel-last format (HWC / BHWC) while
            the torch Tensor should be in channel-first format (CHW / BCHW).
        global_step: The training iteration step.
        name: The name of the logged image(s).
        prefix: A prefix to prepend to the logged image(s).
        nrow: The number of images displayed in each row of the grid if the
            input image is 4D.

    Raises:
        AssertionError: If ``image`` is not 3D or 4D.
        TypeError: If the channel axis (last for numpy, third-from-last for torch)
            does not have 1, 3 or 4 entries, which suggests the wrong layout.
    """
    msg = "/".join([prefix, name]) if prefix else name
    assert image.ndim in [3, 4], "Must be an image or batch of images."

    # Validate the layout up front, similar to `log_video`. A mis-ordered array
    # would otherwise be silently permuted into the wrong shape.
    valid_channels = (1, 3, 4)
    if isinstance(image, np.ndarray):
        if image.shape[-1] not in valid_channels:
            raise TypeError("Numpy array should have HWC or BHWC format.")
        # Channel-last -> channel-first.
        dims = (2, 0, 1) if image.ndim == 3 else (0, 3, 1, 2)
        image = torch.from_numpy(image).permute(*dims)
    elif image.shape[-3] not in valid_channels:
        raise TypeError("Torch tensor should have CHW or BCHW format.")

    if image.ndim == 4:
        # Tile the batch into a single CHW grid image.
        image = torchvision.utils.make_grid(image, nrow=nrow)
    self._writer.add_image(msg, image, global_step, dataformats="CHW")
96 |
def log_video(
    self,
    video,
    global_step: int,
    name: str,
    prefix: str = "",
    fps: int = 4,
) -> None:
    """Log a sequence of images or a batch of sequences of images.

    Args:
        video: A torch Tensor or numpy ndarray. The numpy array should be in
            channel-last format while the torch Tensor should be in
            channel-first format. Should be either a single sequence of
            images of shape (T, CHW/HWC) or a batch of sequences of shape
            (B, T, CHW/HWC). The channel dimension must have size 3. Numpy
            and torch inputs reach the writer as a (B, T, C, H, W) tensor;
            a single sequence gets a batch dimension of size 1.
        global_step: The training iteration step.
        name: The name of the logged video(s).
        prefix: A prefix to prepend to the logged video(s). When non-empty
            the tag is `"{prefix}/{name}"`, otherwise it is just `name`.
        fps: The frames per second.

    Raises:
        ValueError: If `video` is neither 4D nor 5D.
        TypeError: If the channel dimension of `video` does not have size 3.
    """
    # Same tag scheme as the other `log_*` methods (no leading "/" when the
    # prefix is empty).
    tag = "/".join([prefix, name]) if prefix else name
    if video.ndim not in [4, 5]:
        raise ValueError("Must be a video or batch of videos.")
    if video.ndim == 4:
        # A single sequence: bring it to the batched (B, T, C, H, W) layout
        # used for 5D inputs below, with B = 1. Both numpy and torch inputs
        # need the extra batch dimension.
        if isinstance(video, np.ndarray):
            if video.shape[-1] != 3:
                raise TypeError("Numpy array should have THWC format.")
            # (T, H, W, C) -> (1, T, C, H, W).
            video = torch.from_numpy(video).permute(0, 3, 1, 2).unsqueeze(0)
        elif isinstance(video, torch.Tensor):
            if video.shape[1] != 3:
                raise TypeError("Torch tensor should have TCHW format.")
            video = video.unsqueeze(0)  # (T, C, H, W) -> (1, T, C, H, W).
    else:
        if isinstance(video, np.ndarray):
            if video.shape[-1] != 3:
                raise TypeError("Numpy array should have BTHWC format.")
            # (B, T, H, W, C) -> (B, T, C, H, W).
            video = torch.from_numpy(video).permute(0, 1, 4, 2, 3)
        elif isinstance(video, torch.Tensor):
            if video.shape[2] != 3:
                raise TypeError("Torch tensor should have BTCHW format.")
    self._writer.add_video(tag, video, global_step, fps=fps)
142 |
def log_learning_rate(
    self,
    optimizer: torch.optim.Optimizer,
    global_step: int,
    prefix: str = "",
) -> None:
    """Log the learning rate of every parameter group of an optimizer.

    With a single parameter group the value is logged as "learning_rate".
    With several groups each one is logged as "learning_rate/<index>", so
    the groups do not overwrite each other at the same step.

    Args:
        optimizer: An optimizer instance.
        global_step: The training iteration step.
        prefix: A prefix to prepend to the logged learning rate(s).

    Raises:
        TypeError: If `optimizer` is not a `torch.optim.Optimizer` instance.
    """
    if not isinstance(optimizer, torch.optim.Optimizer):
        raise TypeError("Optimizer must be an instance of torch.optim.Optimizer.")
    param_groups = optimizer.param_groups
    single_group = len(param_groups) == 1
    for index, param_group in enumerate(param_groups):
        name = "learning_rate" if single_group else f"learning_rate/{index}"
        self.log_scalar(param_group["lr"], global_step, name, prefix)
160 |
--------------------------------------------------------------------------------
/torchkit/checkpoint.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import signal
4 | import tempfile
5 | from pathlib import Path
6 | from typing import Any, List, Optional, Union
7 |
8 | import torch
9 |
10 | from .experiment import unique_id
11 |
12 |
def get_files(
    d: Path,
    pattern: str,
    sort_lexicographical: bool = False,
    sort_numerical: bool = False,
) -> List[Path]:
    """Collect the entries of a directory that match a glob pattern.

    Args:
        d: The directory to search.
        pattern: The glob wildcard used to filter entries.
        sort_lexicographical: Order the result by file stem as a string.
            Takes precedence over `sort_numerical`.
        sort_numerical: Order the result by file stem parsed as an integer.

    Returns:
        A list of matching paths. It is unsorted (glob order) when neither
        sort flag is set.
    """
    matches = list(d.glob(pattern))
    if sort_lexicographical:
        matches.sort(key=lambda path: path.stem)
    elif sort_numerical:
        matches.sort(key=lambda path: int(path.stem))
    return matches
33 |
34 |
class Checkpoint:
    """Save and restore PyTorch objects implementing a `state_dict` method."""

    def __init__(self, **kwargs) -> None:
        """Constructor.

        Accepts keyword arguments whose values are objects that implement a
        `state_dict` method and thus can be serialized to disk.

        Args:
            kwargs: Keyword arguments are set as attributes of this object,
                and are saved with the checkpoint. Values must have a
                callable `state_dict` attribute.

        Raises:
            ValueError: If objects in `kwargs` do not have a callable
                `state_dict` attribute.
        """
        for k, v in sorted(kwargs.items()):
            # Look the attribute up with a default so that a missing
            # `state_dict` raises the documented ValueError rather than an
            # AttributeError.
            if not callable(getattr(v, "state_dict", None)):
                raise ValueError(f"{k} does not have a state_dict attribute.")
            setattr(self, k, v)

    def save(self, save_path: Union[str, Path]) -> None:
        """Atomically save the state of all checkpointables to disk.

        Modified from brentyi/fannypack.

        The state is serialized into a temporary directory created next to
        `save_path` and then moved into place with `os.replace`, so a
        partially-written file never appears at `save_path`. SIGINT is
        ignored while saving (when called from the main thread), and the
        previous handler is restored afterwards even if saving fails.

        Args:
            save_path: The filepath of the checkpoint to save.
        """
        save_path = Path(save_path)

        # Ignore ctrl+c while saving.
        try:
            orig_handler = signal.getsignal(signal.SIGINT)
            signal.signal(signal.SIGINT, lambda _sig, _frame: None)
        except ValueError:
            # Signal throws a ValueError if we're not in the main thread.
            orig_handler = None

        try:
            # Create a snapshot of the current state.
            save_dict = {k: v.state_dict() for k, v in self.__dict__.items()}

            # The temporary directory lives inside the destination directory
            # so the final move never crosses filesystems. A cross-device
            # rename fails and is not atomic. The directory name has no
            # `.ckpt` suffix, so a leftover is never mistaken for a
            # checkpoint, and it is removed automatically if `torch.save`
            # raises.
            with tempfile.TemporaryDirectory(dir=save_path.parent) as tmp_dir:
                tmp_path = Path(tmp_dir) / "tmp.ckpt"
                torch.save(save_dict, tmp_path)
                # `os.replace` is atomic on POSIX for same-filesystem moves
                # and, unlike `os.rename`, also overwrites on Windows.
                os.replace(tmp_path, save_path)
        finally:
            # Restore SIGINT handler.
            if orig_handler is not None:
                signal.signal(signal.SIGINT, orig_handler)

    def restore(self, save_path: Union[str, Path]) -> bool:
        """Restore a state from a saved checkpoint.

        Entries in the checkpoint with no matching attribute on this object
        are skipped with a warning.

        Args:
            save_path: The filepath to the saved checkpoint.

        Returns:
            True if restoring was successful or partially (not all
            checkpointables could be restored) successful and False
            otherwise. Failures are logged with their traceback instead of
            being raised.
        """
        try:
            # NOTE: `torch.load` is pickle-based, so only restore
            # checkpoints from trusted sources.
            state = torch.load(Path(save_path), map_location="cpu")
            for name, state_dict in state.items():
                if not hasattr(self, name):
                    logging.warning(
                        f"{name} in saved checkpoint not in checkpoint to "
                        "reload. Skipping it."
                    )
                    continue
                getattr(self, name).load_state_dict(state_dict)
            return True
        except Exception:
            # Restoring is best-effort. Report the failure through `logging`
            # and let the caller react to the return value.
            logging.exception("Could not restore checkpoint from %s.", save_path)
            return False
118 |
119 |
120 | # TODO(kevin): Add saving of best checkpoint based on specified metric.
class CheckpointManager:
    """
    Periodically save PyTorch checkpointables (any object that implements a
    `state_dict` method) to disk and restore them to resume training.

    Checkpoints are stored as `<global_step>.ckpt` files in a single
    directory.

    Note: This is a re-implementation of `2`_.

    Example usage::

        from torchkit.checkpoint import CheckpointManager

        # Create a checkpoint manager instance.
        checkpoint_manager = CheckpointManager(
            checkpoint_dir,
            max_to_keep=5,
            model=model,
            optimizer=optimizer,
        )

        # Restore last checkpoint if it exists.
        start_step = checkpoint_manager.restore_or_initialize()
        for global_step in range(start_step, 1000):
            # forward pass + loss computation

            # Save a checkpoint every N iters.
            if not global_step % N:
                checkpoint_manager.save(global_step)

    .. _2: https://www.tensorflow.org/api_docs/python/tf/train/CheckpointManager/
    """

    def __init__(
        self,
        directory: Union[str, Path],
        max_to_keep: int = 10,
        **checkpointables: Any,
    ) -> None:
        """Constructor.

        Args:
            directory: The directory in which checkpoints will be saved. It
                is created (with parents) if it does not already exist.
            max_to_keep: The maximum number of checkpoints to keep.
                Amongst all saved checkpoints, checkpoints will be deleted
                oldest first, until `max_to_keep` remain.
            checkpointables: Keyword args with checkpointable PyTorch objects.
        """
        assert max_to_keep > 0, "max_to_keep should be a positive integer."

        self.directory = Path(directory).absolute()
        self.max_to_keep = max_to_keep
        self.checkpoint = Checkpoint(**checkpointables)

        # Create checkpoint directory if it doesn't already exist.
        self.directory.mkdir(parents=True, exist_ok=True)

    def restore_or_initialize(self) -> int:
        """Restore items in checkpoint from the latest checkpoint file.

        Returns:
            The global iteration step. This is parsed from the latest
            checkpoint file if one is found and restored, else 0 is returned.
        """
        ckpts = CheckpointManager.list_checkpoints(self.directory)
        if not ckpts:
            return 0
        last_ckpt = ckpts[-1]
        status = self.checkpoint.restore(last_ckpt)
        if not status:
            logging.info("Could not restore latest checkpoint file.")
            return 0
        return int(last_ckpt.stem)

    def save(self, global_step: int) -> None:
        """Create a new checkpoint and prune old ones.

        Args:
            global_step: The iteration number which will be used to name the
                checkpoint.
        """
        save_path = self.directory / f"{global_step}.ckpt"
        self.checkpoint.save(save_path)
        self._trim_checkpoints()

    def _trim_checkpoints(self) -> None:
        """Delete the oldest checkpoints until `max_to_keep` remain."""
        # `list_checkpoints` is sorted by ascending step, so everything
        # before the last `max_to_keep` entries is stale.
        ckpts = CheckpointManager.list_checkpoints(self.directory)
        for ckpt in ckpts[: -self.max_to_keep]:
            ckpt.unlink()

    def load_latest_checkpoint(self) -> bool:
        """Load the last saved checkpoint.

        Returns:
            Whether the checkpoint could be (at least partially) restored.

        Raises:
            ValueError: If the directory contains no checkpoints.
        """
        return self.checkpoint.restore(self.latest_checkpoint)

    def load_checkpoint_at(self, global_step: int) -> bool:
        """Load a checkpoint at a given global step.

        Returns:
            Whether the checkpoint could be (at least partially) restored.

        Raises:
            ValueError: If no checkpoint exists for `global_step`.
        """
        ckpt_name = self.directory / f"{global_step}.ckpt"
        if ckpt_name not in CheckpointManager.list_checkpoints(self.directory):
            raise ValueError(f"No checkpoint found at step {global_step}.")
        return self.checkpoint.restore(ckpt_name)

    @property
    def latest_checkpoint(self) -> Optional[Path]:
        """Get the last saved checkpoint.

        Raises:
            ValueError: If the directory contains no checkpoints.
        """
        ckpts = CheckpointManager.list_checkpoints(self.directory)
        if not ckpts:
            raise ValueError(f"No checkpoints found in {self.directory}.")
        return ckpts[-1]

    @staticmethod
    def list_checkpoints(directory: Union[Path, str]) -> List[Path]:
        """List all checkpoints in a checkpoint directory, oldest first.

        Only files named `<global_step>.ckpt` are considered. Other `*.ckpt`
        files (e.g. temporary or hand-named ones) are ignored instead of
        breaking the numerical sort.
        """
        ckpts = [
            path
            for path in Path(directory).glob("*.ckpt")
            if path.stem.isdecimal()
        ]
        return sorted(ckpts, key=lambda path: int(path.stem))
235 |
--------------------------------------------------------------------------------
/torchkit/layers.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
# Return type of the `same` convolution helpers below.
ConvType = Union[nn.Conv2d, nn.Conv3d]
# Short alias used in the type annotations of this module.
Tensor = torch.Tensor
9 |
10 |
def _conv(
    dim: int,
    in_channels: int,
    out_channels: int,
    kernel_size: int = 3,
    stride: int = 1,
    dilation: int = 1,
    bias: bool = True,
) -> ConvType:
    """`same` convolution, i.e. output shape equals input shape.

    The padding is chosen so that, for `stride=1` and an odd `kernel_size`,
    every spatial output size equals the input size. Otherwise the result
    differs:

    - With an even kernel size the output is one element smaller per
      spatial dimension.
    - With `stride > 1` and an odd kernel, each spatial size `L` becomes
      `ceil(L / stride)`.

    Args:
        dim: The dimension of the convolution: 2 is conv2d, 3 is conv3d.
        in_channels: The number of input feature maps.
        out_channels: The number of output feature maps.
        kernel_size: The filter size.
        stride: The filter stride.
        dilation: The filter dilation factor.
        bias: Whether to add a bias.

    Returns:
        An `nn.Conv2d` (dim=2) or `nn.Conv3d` (dim=3) module.
    """
    assert dim in [2, 3], "[!] Only 2D and 3D convolution supported."
    conv = nn.Conv2d if dim == 2 else nn.Conv3d

    # Compute new filter size after dilation and necessary padding for `same`
    # output size: a dilated filter spans `dilation * (kernel_size - 1) + 1`
    # input elements, and half of the excess is padded on each side.
    dilated_kernel_size = (kernel_size - 1) * (dilation - 1) + kernel_size
    same_padding = (dilated_kernel_size - 1) // 2

    return conv(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=same_padding,
        dilation=dilation,
        bias=bias,
    )
48 |
49 |
def conv2d(*args, **kwargs) -> torch.nn.modules.conv.Conv2d:
    """`same` 2D convolution, i.e. output shape equals input shape.

    All arguments are forwarded to `_conv` with `dim=2`. The spatial output
    shape equals the input shape for `stride=1` and an odd `kernel_size`.

    Args:
        in_channels: The number of input feature maps.
        out_channels: The number of output feature maps.
        kernel_size: The filter size. Defaults to 3.
        stride: The filter stride. Defaults to 1.
        dilation: The filter dilation factor. Defaults to 1.
        bias: Whether to add a bias. Defaults to True.
    """
    return _conv(2, *args, **kwargs)
62 |
63 |
def conv3d(*args, **kwargs) -> torch.nn.modules.conv.Conv3d:
    """`same` 3D convolution, i.e. output shape equals input shape.

    All arguments are forwarded to `_conv` with `dim=3`. The spatial output
    shape equals the input shape for `stride=1` and an odd `kernel_size`.

    Args:
        in_channels: The number of input feature maps.
        out_channels: The number of output feature maps.
        kernel_size: The filter size. Defaults to 3.
        stride: The filter stride. Defaults to 1.
        dilation: The filter dilation factor. Defaults to 1.
        bias: Whether to add a bias. Defaults to True.
    """
    return _conv(3, *args, **kwargs)
76 |
77 |
class Flatten(nn.Module):
    """Flattens convolutional feature maps for fully-connected layers.

    Every dimension after the first (batch) dimension is collapsed, so an
    input of shape `(B, d1, d2, ...)` becomes `(B, d1 * d2 * ...)`. A 1-D
    input of shape `(B,)` becomes `(B, 1)`.

    This is a convenience module meant to be plugged into a
    `torch.nn.Sequential` model.

    Example usage::

        import torch.nn as nn
        from torchkit import layers

        # Assume an input of shape (3, 28, 28).
        net = nn.Sequential(
            layers.conv2d(3, 8, kernel_size=3),
            nn.ReLU(),
            layers.conv2d(8, 16, kernel_size=3),
            nn.ReLU(),
            layers.Flatten(),
            nn.Linear(28*28*16, 256),
            nn.ReLU(),
            nn.Linear(256, 2),
        )
    """

    def forward(self, x: Tensor) -> Tensor:
        # `reshape` (unlike `view`) also accepts non-contiguous inputs, e.g.
        # after a `permute`. Spelling out the flattened size instead of -1
        # keeps empty batches working, where -1 would be ambiguous.
        return x.reshape(x.shape[0], x.shape[1:].numel())
107 |
108 |
class SpatialSoftArgmax(nn.Module):
    """Spatial softmax as defined in `1`_.

    Concretely, the spatial softmax of each feature map is used to compute a
    weighted mean of the pixel locations, effectively performing a soft arg-max
    over the feature dimension.

    For an input of shape `(B, C, H, W)` the output has shape `(B, 2 * C)`.
    For every channel it holds the expected row coordinate followed by the
    expected column coordinate.

    .. _1: https://arxiv.org/abs/1504.00702
    """

    def __init__(self, normalize: bool = False):
        """Constructor.

        Args:
            normalize: Whether to use normalized image coordinates, i.e.
                coordinates in the range `[-1, 1]`.
        """
        super().__init__()

        self.normalize = normalize

    def _coord_grid(
        self,
        h: int,
        w: int,
        device: torch.device,
    ) -> Tensor:
        """Return a `(2, h, w)` grid of pixel coordinates.

        `grid[0, r, c]` is the coordinate of row `r` and `grid[1, r, c]` the
        coordinate of column `c`. Each channel therefore lines up
        element-wise with a row-major flattening of an `(h, w)` feature map,
        for square and non-square maps alike.
        """
        if self.normalize:
            rows = torch.linspace(-1, 1, h, device=device)
            cols = torch.linspace(-1, 1, w, device=device)
        else:
            rows = torch.arange(0, h, device=device)
            cols = torch.arange(0, w, device=device)
        return torch.stack(
            [rows.view(h, 1).expand(h, w), cols.view(1, w).expand(h, w)]
        )

    def forward(self, x: Tensor) -> Tensor:
        assert x.ndim == 4, "Expecting a tensor of shape (B, C, H, W)."

        # Compute a spatial softmax over the input:
        # Given an input of shape (B, C, H, W), reshape it to (B*C, H*W) then
        # apply the softmax operator over the last dimension. `reshape` also
        # accepts non-contiguous inputs.
        b, c, h, w = x.shape
        softmax = F.softmax(x.reshape(b * c, h * w), dim=-1)

        # Row and column coordinate of every pixel, laid out like the
        # flattened (H*W) softmax above.
        row_coords, col_coords = self._coord_grid(h, w, x.device)

        # Element-wise multiply the coordinates with the softmax, then sum
        # over the h*w dimension. This computes the expected row and column
        # location under each channel's distribution.
        row_mean = (softmax * row_coords.flatten()).sum(dim=1, keepdim=True)
        col_mean = (softmax * col_coords.flatten()).sum(dim=1, keepdim=True)

        # Concatenate and reshape the result to (B, C*2), interleaving the
        # (row, column) pair of every channel.
        return torch.cat([row_mean, col_mean], dim=1).view(b, c * 2)
171 |
172 |
class _GlobalMaxPool(nn.Module):
    """Global max pooling layer.

    Takes the maximum over every dimension after the first two (batch and
    channel), returning a tensor of shape `(B, C)`.
    """

    def __init__(self, dim):
        """Constructor.

        Args:
            dim: The number of pooled dimensions: 1, 2 or 3.

        Raises:
            ValueError: If `dim` is not 1, 2 or 3.
        """
        super().__init__()

        pools = {1: F.max_pool1d, 2: F.max_pool2d, 3: F.max_pool3d}
        if dim not in pools:
            raise ValueError(f"{dim}D is not supported.")
        self._pool = pools[dim]

    def forward(self, x: Tensor) -> Tensor:
        # A kernel spanning the full extent of dims 2.. leaves a singleton
        # in each of them; flatten those away to get (B, C).
        out = self._pool(x, kernel_size=x.size()[2:])
        return out.flatten(start_dim=1)
193 |
194 |
class GlobalMaxPool1d(_GlobalMaxPool):
    """Global max pooling for temporal (1D) inputs."""

    def __init__(self) -> None:
        # A single pooled dimension.
        super().__init__(1)
200 |
201 |
class GlobalMaxPool2d(_GlobalMaxPool):
    """Global max pooling for spatial (2D) inputs."""

    def __init__(self) -> None:
        # Two pooled dimensions.
        super().__init__(2)
207 |
208 |
class GlobalMaxPool3d(_GlobalMaxPool):
    """Global max pooling for volumetric (3D) inputs."""

    def __init__(self) -> None:
        # Three pooled dimensions.
        super().__init__(3)
214 |
215 |
class _GlobalAvgPool(nn.Module):
    """Global average pooling layer.

    Averages over every dimension after the first two (batch and channel),
    returning a tensor of shape `(B, C)`.
    """

    def __init__(self, dim):
        """Constructor.

        Args:
            dim: The number of pooled dimensions: 1, 2 or 3.

        Raises:
            ValueError: If `dim` is not 1, 2 or 3.
        """
        super().__init__()

        pools = {1: F.avg_pool1d, 2: F.avg_pool2d, 3: F.avg_pool3d}
        if dim not in pools:
            raise ValueError(f"{dim}D is not supported.")
        self._pool = pools[dim]

    def forward(self, x: Tensor) -> Tensor:
        # A kernel spanning the full extent of dims 2.. leaves a singleton
        # in each of them; flatten those away to get (B, C).
        out = self._pool(x, kernel_size=x.size()[2:])
        return out.flatten(start_dim=1)
236 |
237 |
class GlobalAvgPool1d(_GlobalAvgPool):
    """Global average pooling for temporal (1D) inputs."""

    def __init__(self) -> None:
        # A single pooled dimension.
        super().__init__(1)
243 |
244 |
class GlobalAvgPool2d(_GlobalAvgPool):
    """Global average pooling for spatial (2D) inputs."""

    def __init__(self) -> None:
        # Two pooled dimensions.
        super().__init__(2)
250 |
251 |
class GlobalAvgPool3d(_GlobalAvgPool):
    """Global average pooling for volumetric (3D) inputs."""

    def __init__(self) -> None:
        # Three pooled dimensions.
        super().__init__(3)
257 |
258 |
class CausalConv1d(nn.Conv1d):
    """A causal a.k.a. masked 1D convolution.

    Output step `t` only depends on input steps `<= t * stride`. The input is
    zero-padded by `(kernel_size - 1) * dilation` on both sides, and the
    trailing outputs whose receptive field reaches into the right padding are
    discarded. This leaves `ceil(L / stride)` outputs for an input of length
    `L`, i.e. exactly `L` outputs for stride 1.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        bias: bool = True,
    ):
        """Constructor.

        Args:
            in_channels: The number of input channels.
            out_channels: The number of output channels.
            kernel_size: The filter size.
            stride: The filter stride.
            dilation: The filter dilation factor.
            bias: Whether to add the bias term or not.
        """
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) * dilation,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x: Tensor) -> Tensor:
        res = super().forward(x)
        # With symmetric padding p = (kernel_size - 1) * dilation, output `t`
        # covers input indices [t * stride - p, t * stride]. Keep exactly the
        # outputs with t * stride <= L - 1. For stride 1 this drops the last
        # `p` outputs; for larger strides it keeps every valid causal output.
        num_causal = (x.shape[-1] - 1) // self.stride[0] + 1
        return res[..., :num_causal]
298 |
--------------------------------------------------------------------------------