├── torchkit ├── utils │ ├── module_freezing.py │ ├── dataset.py │ ├── git.py │ ├── timer.py │ ├── __init__.py │ ├── multithreading.py │ ├── seed.py │ ├── io.py │ ├── pdb_fallback.py │ ├── module_stats.py │ └── config.py ├── __init__.py ├── version.py ├── experiment.py ├── viz.py ├── losses.py ├── logger.py ├── checkpoint.py └── layers.py ├── docs ├── torchkit │ ├── utils │ │ ├── io.rst │ │ ├── git.rst │ │ ├── seed.rst │ │ ├── config.rst │ │ ├── timer.rst │ │ ├── dataset.rst │ │ ├── pdb_fallback.rst │ │ ├── module_stats.rst │ │ └── index.rst │ ├── layers.rst │ ├── losses.rst │ ├── experiment.rst │ ├── installation.rst │ ├── checkpoint.rst │ └── logger.rst ├── Makefile ├── index.rst ├── make.bat └── conf.py ├── pyproject.toml ├── setup.cfg ├── .github └── workflows │ ├── lint.yml │ ├── docs.yml │ └── build.yml ├── scripts └── lint.sh ├── .gitignore ├── LICENSE ├── setup.py ├── tests ├── test_layers.py ├── test_checkpoint.py ├── test_losses.py └── test_logger.py └── README.md /torchkit/utils/module_freezing.py: -------------------------------------------------------------------------------- 1 | # TODO(kevin): Implement freezing and unfreezing of `nn.Module`. 2 | # TODO(kevin): Implement batch norm freezing. 3 | -------------------------------------------------------------------------------- /docs/torchkit/utils/io.rst: -------------------------------------------------------------------------------- 1 | torchkit.utils.io 2 | ================= 3 | 4 | .. automodule:: torchkit.utils.io 5 | :members: 6 | :member-order: alphabetical 7 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 88 3 | target-version = ["py38"] 4 | 5 | [tool.isort] 6 | profile = "black" 7 | multi_line_output = 3 8 | line_length = 88 -------------------------------------------------------------------------------- /docs/torchkit/utils/git.rst: -------------------------------------------------------------------------------- 1 | torchkit.utils.git 2 | ================== 3 | 4 | .. automodule:: torchkit.utils.git 5 | :members: 6 | :member-order: alphabetical 7 | -------------------------------------------------------------------------------- /docs/torchkit/utils/seed.rst: -------------------------------------------------------------------------------- 1 | torchkit.utils.seed 2 | =================== 3 | 4 | .. automodule:: torchkit.utils.seed 5 | :members: 6 | :member-order: alphabetical 7 | -------------------------------------------------------------------------------- /docs/torchkit/layers.rst: -------------------------------------------------------------------------------- 1 | torchkit.layers 2 | =============== 3 | 4 | .. automodule:: torchkit.layers 5 | :members: 6 | :member-order: bysource 7 | :exclude-members: forward 8 | -------------------------------------------------------------------------------- /docs/torchkit/losses.rst: -------------------------------------------------------------------------------- 1 | torchkit.losses 2 | =============== 3 | 4 | .. automodule:: torchkit.losses 5 | :members: 6 | :member-order: bysource 7 | :exclude-members: forward 8 | -------------------------------------------------------------------------------- /docs/torchkit/utils/config.rst: -------------------------------------------------------------------------------- 1 | torchkit.utils.config 2 | ===================== 3 | 4 | .. 
automodule:: torchkit.utils.config 5 | :members: 6 | :member-order: alphabetical 7 | -------------------------------------------------------------------------------- /docs/torchkit/utils/timer.rst: -------------------------------------------------------------------------------- 1 | torchkit.utils.timer 2 | ==================== 3 | 4 | .. automodule:: torchkit.utils.timer 5 | :members: 6 | :member-order: alphabetical 7 | -------------------------------------------------------------------------------- /docs/torchkit/utils/dataset.rst: -------------------------------------------------------------------------------- 1 | torchkit.utils.dataset 2 | ====================== 3 | 4 | .. automodule:: torchkit.utils.dataset 5 | :members: 6 | :member-order: alphabetical 7 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | docstring-convention=google 3 | ignore = E203, W503 4 | max-line-length = 88 5 | 6 | [pytype] 7 | inputs = 8 | torchkit/ 9 | tests/ 10 | setup.py 11 | -------------------------------------------------------------------------------- /docs/torchkit/utils/pdb_fallback.rst: -------------------------------------------------------------------------------- 1 | torchkit.utils.pdb_fallback 2 | =========================== 3 | 4 | .. automodule:: torchkit.utils.pdb_fallback 5 | :members: 6 | :member-order: alphabetical 7 | -------------------------------------------------------------------------------- /docs/torchkit/experiment.rst: -------------------------------------------------------------------------------- 1 | torchkit.experiment 2 | =================== 3 | 4 | .. automodule:: torchkit.experiment 5 | :members: 6 | :member-order: alphabetical 7 | :exclude-members: __init__ 8 | -------------------------------------------------------------------------------- /docs/torchkit/utils/module_stats.rst: -------------------------------------------------------------------------------- 1 | torchkit.utils.module_stats 2 | =========================== 3 | 4 | .. automodule:: torchkit.utils.module_stats 5 | :members: 6 | :member-order: alphabetical 7 | 8 | -------------------------------------------------------------------------------- /docs/torchkit/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | **torchkit** requires Python 3.8 or higher. You can install directly from Git master:: 5 | 6 | pip install git+https://github.com/kevinzakka/torchkit.git 7 | -------------------------------------------------------------------------------- /docs/torchkit/utils/index.rst: -------------------------------------------------------------------------------- 1 | torchkit.utils 2 | ============== 3 | 4 | ..
toctree:: 5 | :maxdepth: 2 6 | 7 | config 8 | dataset 9 | git 10 | io 11 | module_stats 12 | pdb_fallback 13 | seed 14 | timer 15 | -------------------------------------------------------------------------------- /torchkit/__init__.py: -------------------------------------------------------------------------------- 1 | """A PyTorch toolkit for research.""" 2 | 3 | from torchkit.checkpoint import CheckpointManager # noqa: F401 4 | from torchkit.logger import Logger # noqa: F401 5 | from torchkit.version import __version__ # noqa: F401 6 | -------------------------------------------------------------------------------- /torchkit/version.py: -------------------------------------------------------------------------------- 1 | """Version file will be exec'd by setup.py and imported by __init__.py""" 2 | 3 | # Semantic versioning. 4 | _MAJOR_VERSION = "0" 5 | _MINOR_VERSION = "0" 6 | _PATCH_VERSION = "3" 7 | 8 | __version__ = ".".join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION]) 9 | -------------------------------------------------------------------------------- /docs/torchkit/checkpoint.rst: -------------------------------------------------------------------------------- 1 | torchkit.checkpoint 2 | =================== 3 | 4 | .. autoclass:: torchkit.checkpoint.Checkpoint 5 | :members: 6 | :member-order: groupwise 7 | 8 | .. autoclass:: torchkit.checkpoint.CheckpointManager 9 | :members: 10 | :member-order: groupwise 11 | -------------------------------------------------------------------------------- /docs/torchkit/logger.rst: -------------------------------------------------------------------------------- 1 | torchkit.logger 2 | =============== 3 | 4 | 5 | The ``Logger`` class is a simple wrapper over PyTorch's ``SummaryWriter``. You can use it to log scalars, images and videos in numpy ndarray or torch Tensor format. 6 | 7 | .. autoclass:: torchkit.logger.Logger 8 | :members: 9 | :member-order: groupwise 10 | -------------------------------------------------------------------------------- /torchkit/utils/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Iterator 2 | 3 | 4 | # Reference: https://github.com/unixpickle/vq-voice-swap/blob/main/vq_voice_swap/util.py 5 | def infinite_dataset(data_loader: Iterable) -> Iterator: 6 | """Create an infinite loop over a `torch.utils.data.DataLoader`. 7 | 8 | Args: 9 | data_loader (Iterable): A `torch.utils.data.DataLoader` object. 10 | 11 | Yields: 12 | Iterator: An iterator over the dataloader that repeats ad infinitum. 13 | """ 14 | while True: 15 | yield from data_loader 16 | -------------------------------------------------------------------------------- /torchkit/utils/git.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | # Reference: https://stackoverflow.com/a/21901260 5 | def git_revision_hash() -> str: 6 | """Return the git commit hash of the current directory. 7 | 8 | Note: 9 | Will return a `fatal: not a git repository` string if the command fails. 10 | """ 11 | try: 12 | string = subprocess.check_output( 13 | ["git", "rev-parse", "HEAD"], stderr=subprocess.STDOUT 14 | ) 15 | except subprocess.CalledProcessError as err: 16 | string = err.output 17 | return string.decode("ascii").strip() 18 | 19 | 20 | # Alias.
21 | git_commit_hash = git_revision_hash 22 | -------------------------------------------------------------------------------- /torchkit/utils/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class Stopwatch: 5 | """A simple timer for measuring elapsed time. 6 | 7 | Example usage:: 8 | 9 | stopwatch = Stopwatch() 10 | some_func() 11 | print(f"some_func took: {stopwatch.elapsed()} seconds.") 12 | stopwatch.reset() 13 | """ 14 | 15 | def __init__(self) -> None: 16 | self.reset() 17 | 18 | def elapsed(self) -> float: 19 | """Return the elapsed time since the stopwatch was reset.""" 20 | return time.time() - self.time 21 | 22 | def reset(self) -> None: 23 | """Reset the stopwatch, i.e. start the timer.""" 24 | self.time = time.time() 25 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.8", "3.9"] 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip setuptools wheel 25 | pip install -e ".[dev]" 26 | - name: Run lint script 27 | run: | 28 | bash scripts/lint.sh 29 | -------------------------------------------------------------------------------- /scripts/lint.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Modified from https://raw.githubusercontent.com/HumanCompatibleAI/seals/master/ci/code_checks.sh 4 | set -x 5 | set -e 6 | 7 | SRC_FILES=(torchkit/ tests/ docs/conf.py setup.py) 8 | 9 | if [ "$(uname)" == "Darwin" ]; then 10 | N_CPU=$(sysctl -n hw.ncpu) 11 | elif [ "$(expr substr $(uname -s) 1 5)" == "Linux" ]; then 12 | N_CPU=$(grep -c ^processor /proc/cpuinfo) 13 | fi 14 | 15 | echo "Source format checking" 16 | flake8 ${SRC_FILES[@]} 17 | black --check ${SRC_FILES[@]} 18 | 19 | if [ "$skipexpensive" != "true" ]; then 20 | echo "Building docs (validates docstrings)" 21 | pushd docs/ 22 | make clean 23 | make html 24 | popd 25 | 26 | echo "Type checking" 27 | pytype -n "${N_CPU}" ${SRC_FILES[@]} 28 | fi 29 | --------------------------------------------------------------------------------
/.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | 7 | jobs: 8 | docs: 9 | runs-on: ubuntu-latest 10 | container: 11 | image: python:3.8 12 | steps: 13 | 14 | # Check out source 15 | - uses: actions/checkout@v2 16 | 17 | # Build documentation 18 | - name: Building documentation 19 | run: | 20 | apt-get update 21 | apt-get -y install libgl1-mesa-glx 22 | pip install -e .[dev] 23 | sphinx-build docs/ docs/_build -b dirhtml 24 | 25 | # Deploy 26 | - name: Deploy to GitHub Pages 27 | uses: peaceiris/actions-gh-pages@v3 28 | with: 29 | github_token: ${{ secrets.GITHUB_TOKEN }} 30 | publish_dir: ./docs/_build 31 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.8", "3.9"] 15 | name: Python ${{ matrix.python-version }} 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v1 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip setuptools wheel 25 | pip install -e .[test] 26 | - name: Test with pytest 27 | run: | 28 | pytest tests 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .pytest_cache/ 2 | .pytype 3 | .vscode 4 | .ipynb_checkpoints 5 | playground.ipynb 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | bin/ 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # Installer logs 30 | pip-log.txt 31 | pip-delete-this-directory.txt 32 | 33 | # Unit test / coverage reports 34 | .tox/ 35 | .coverage 36 | .cache 37 | nosetests.xml 38 | coverage.xml 39 | 40 | # Translations 41 | *.mo 42 | 43 | # Mr Developer 44 | .mr.developer.cfg 45 | .project 46 | .pydevproject 47 | 48 | # Rope 49 | .ropeproject 50 | 51 | # Django stuff: 52 | *.log 53 | *.pot 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | docs/_templates 58 | -------------------------------------------------------------------------------- /torchkit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import copy_config_and_replace, dump_config, load_config, validate_config 2 | from .dataset import infinite_dataset 3 | from .git import git_commit_hash, git_revision_hash 4 | from .io import load_pickle, save_pickle 5 | from .module_stats import get_total_params 6 | from .multithreading import threaded_func 7 | from .pdb_fallback import pdb_fallback 8 | from .seed import seed_rngs, set_cudnn 9 | from .timer import Stopwatch 10 | 11 | __all__ = [ 12 | "threaded_func", 13 | "Stopwatch", 14 | "pdb_fallback", 15 | "get_total_params", 16 | "git_revision_hash", 17 | "git_commit_hash", 18 | "seed_rngs", 19 | "set_cudnn", 20 | "validate_config", 21 | "dump_config", 22 | "load_config", 23 | 
"copy_config_and_replace", 24 | "infinite_dataset", 25 | "save_pickle", 26 | "load_pickle", 27 | ] 28 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. raw:: html 2 | 3 |

TorchKit: A PyTorch Toolkit for Research

4 |
5 | 6 | **torchkit** is a *lightweight* library containing PyTorch utilities useful for day-to-day research. 7 | Its main goal is to abstract away a lot of the redundant boilerplate associated with research projects 8 | like experimental configurations, logging and model checkpointing. 9 | 10 | .. toctree:: 11 | :maxdepth: 1 12 | :caption: User Guide 13 | 14 | torchkit/installation 15 | 16 | .. toctree:: 17 | :maxdepth: 2 18 | :caption: API Reference 19 | 20 | torchkit/logger 21 | torchkit/checkpoint 22 | torchkit/experiment 23 | torchkit/layers 24 | torchkit/losses 25 | torchkit/utils/index 26 | 27 | Indices and tables 28 | ================== 29 | 30 | * :ref:`genindex` 31 | * :ref:`modindex` 32 | * :ref:`search` 33 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /torchkit/utils/multithreading.py: -------------------------------------------------------------------------------- 1 | import threading 2 | from typing import Callable, Iterable, Tuple 3 | 4 | 5 | # Adapted from: https://github.com/facebookresearch/pytorchvideo/blob/master/pytorchvideo/data/utils.py#L99 # noqa: E501 6 | def threaded_func( 7 | func: Callable, 8 | args_iterable: Iterable[Tuple], 9 | multithreaded: bool, 10 | ) -> None: 11 | """Applies a func on a tuple of args with optional multithreading. 12 | 13 | Args: 14 | func: The func to execute. 15 | args_iterable: An iterable of arg tuples to feed to func. 16 | multithreaded: Whether to parallelize the func across threads. 17 | """ 18 | if multithreaded: 19 | threads = [] 20 | for args in args_iterable: 21 | thread = threading.Thread(target=func, args=args) 22 | thread.start() 23 | threads.append(thread) 24 | for t in threads: 25 | t.join() 26 | else: 27 | for args in args_iterable: 28 | func(*args) 29 | -------------------------------------------------------------------------------- /torchkit/utils/seed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def seed_rngs(seed: int, pytorch: bool = True) -> None: 9 | """Seed system RNGs. 10 | 11 | Args: 12 | seed (int): The desired seed. 13 | pytorch (bool, optional): Whether to seed the `torch` RNG as well. Defaults to 14 | True. 
15 | """ 16 | os.environ["PYTHONHASHSEED"] = str(seed) 17 | random.seed(seed) 18 | np.random.seed(seed) 19 | if pytorch: 20 | torch.manual_seed(seed) 21 | 22 | 23 | def set_cudnn(deterministic: bool = False, benchmark: bool = True) -> None: 24 | """Set PyTorch-related CUDNN settings. 25 | 26 | Args: 27 | deterministic (bool, optional): Make CUDA algorithms deterministic. Defaults to 28 | False. 29 | benchmark (bool, optional): Make CUDA arlgorithm selection deterministic. 30 | Defaults to True. 31 | """ 32 | torch.backends.cudnn.deterministic = deterministic 33 | torch.backends.cudnn.benchmark = benchmark 34 | -------------------------------------------------------------------------------- /torchkit/utils/io.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | from typing import Any 5 | 6 | 7 | def save_pickle(obj: Any, path: str, name: str) -> None: 8 | """Save a python object as a pickle file. 9 | 10 | Args: 11 | obj (Any): The object to save. 12 | path (str): Directory wherein to save file. 13 | name (str): Name of the pickle file. 14 | """ 15 | filename = os.path.join(path, name) 16 | with open(filename, "wb") as fp: 17 | pickle.dump(obj, fp) 18 | logging.info(f"Successfully saved {filename}") 19 | 20 | 21 | def load_pickle(path: str, name: str) -> Any: 22 | """Load a pickled file. 23 | 24 | Args: 25 | path (str): The directory where the pickle file is stored. 26 | name (str): The name of the pickle file. 27 | 28 | Returns: 29 | Any: The object in the pickle file. 30 | """ 31 | filename = os.path.join(path, name) 32 | with open(filename, "rb") as fp: 33 | obj = pickle.load(fp) 34 | logging.info(f"Successfully loaded {filename}.") 35 | return obj 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Kevin Zakka 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import find_packages, setup 4 | 5 | THIS_DIR = os.path.dirname(os.path.abspath(__file__)) 6 | DESCRIPTION = "torchkit is a lightweight library containing PyTorch utilities useful for day-to-day research." 
# noqa: E501 7 | TESTS_REQUIRE = [ 8 | "pytest", 9 | "black", 10 | "isort", 11 | "pytype", 12 | "flake8", 13 | ] 14 | DOCS_REQUIRE = [ 15 | "sphinx", 16 | "sphinx-autodoc-typehints", 17 | "sphinx-rtd-theme", 18 | "docutils==0.16", 19 | ] 20 | 21 | 22 | def readme() -> str: 23 | """Load README for use as package's long description.""" 24 | with open(os.path.join(THIS_DIR, "README.md"), "r") as fp: 25 | return fp.read() 26 | 27 | 28 | def get_version() -> str: 29 | locals_dict = {} 30 | with open(os.path.join(THIS_DIR, "torchkit", "version.py"), "r") as fp: 31 | exec(fp.read(), globals(), locals_dict) 32 | return locals_dict["__version__"] # pytype: disable=invalid-directive 33 | 34 | 35 | setup( 36 | name="torchkit", 37 | version=get_version(), 38 | author="Kevin Zakka", 39 | license="MIT", 40 | description=DESCRIPTION, 41 | python_requires=">=3.8", 42 | long_description=readme(), 43 | long_description_content_type="text/markdown", 44 | packages=find_packages(), 45 | install_requires=[ 46 | "torch>=1.3", 47 | "torchvision>=0.4", 48 | "tensorboard", 49 | "prettytable", 50 | "opencv-python", 51 | "moviepy", 52 | "ml_collections", 53 | "ipdb", 54 | ], 55 | extras_require={ 56 | "dev": ["jupyter", *TESTS_REQUIRE, *DOCS_REQUIRE], 57 | "test": TESTS_REQUIRE, 58 | }, 59 | tests_require=TESTS_REQUIRE, 60 | url="https://github.com/kevinzakka/torchkit/", 61 | ) 62 | -------------------------------------------------------------------------------- /torchkit/utils/pdb_fallback.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Callable, TypeVar 3 | 4 | CallableType = TypeVar("CallableType", bound=Callable) 5 | 6 | 7 | # Reference: https://github.com/brentyi/fannypack/blob/master/fannypack/utils/_pdb_safety_net.py # noqa: E501 8 | def pdb_fallback(f: CallableType, use_ipdb: bool = False) -> CallableType: 9 | """Wraps a function in a pdb safety net for unexpected errors in a Python script. 10 | 11 | When called, pdb will be automatically opened when either (a) the user hits Ctrl+C 12 | or (b) we encounter an uncaught exception. Helpful for bypassing minor errors, 13 | diagnosing problems, and rescuing unsaved models. 14 | 15 | Example usage:: 16 | 17 | from torchkit.utils import pdb_fallback 18 | 19 | @pdb_fallback 20 | def main(): 21 | # A very interesting function that might fail because we did something 22 | # stupid. 23 | ... 24 | 25 | if __name__ == "__main__": 26 | main() 27 | 28 | Args: 29 | f (CallableType): The function to wrap. 30 | use_ipdb (bool, optional): Whether to use ipdb instead of pdb. Defaults to 31 | False. 32 | """ 33 | 34 | import signal 35 | import sys 36 | import traceback as tb 37 | 38 | if use_ipdb: 39 | import ipdb as pdb 40 | else: 41 | import pdb 42 | 43 | @functools.wraps(f) 44 | def inner_wrapper(*args, **kwargs): 45 | # Open pdb on Ctrl-C. 46 | def handler(sig, frame): 47 | pdb.set_trace() 48 | 49 | signal.signal(signal.SIGINT, handler) 50 | 51 | # Open pdb when we encounter an uncaught exception. 
52 | def excepthook(type_, value, traceback): 53 | tb.print_exception(type_, value, traceback, limit=100) 54 | pdb.post_mortem(traceback) 55 | 56 | sys.excepthook = excepthook 57 | 58 | return f(*args, **kwargs) 59 | 60 | return inner_wrapper 61 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | """Configuration file for the Sphinx documentation builder.""" 2 | 3 | import torchkit 4 | 5 | # -- Project information ----------------------------------------------------- 6 | 7 | project = "torchkit" 8 | copyright = "2021, Kevin Zakka" 9 | author = "Kevin Zakka" 10 | release = torchkit.__version__ 11 | 12 | # -- General configuration --------------------------------------------------- 13 | 14 | master_doc = "index" 15 | 16 | extensions = [ 17 | "sphinx.ext.napoleon", 18 | "sphinx.ext.autodoc", 19 | "sphinx_autodoc_typehints", 20 | "sphinx.ext.autosummary", 21 | "sphinx.ext.mathjax", 22 | "sphinx.ext.viewcode", 23 | "sphinx_rtd_theme", 24 | ] 25 | 26 | 27 | autodoc_typehints = "description" 28 | 29 | napoleon_google_docstring = True 30 | napoleon_numpy_docstring = False 31 | 32 | # Add any paths that contain templates here, relative to this directory. 33 | templates_path = ["_templates"] 34 | 35 | # List of patterns, relative to source directory, that match files and 36 | # directories to ignore when looking for source files. 37 | # This pattern also affects html_static_path and html_extra_path. 38 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 39 | 40 | autodoc_default_options = { 41 | "members": True, 42 | "undoc-members": True, 43 | "special-members": "__init__", 44 | } 45 | 46 | 47 | # -- Options for HTML output ------------------------------------------------- 48 | 49 | html_theme = "sphinx_rtd_theme" 50 | html_theme_options = { 51 | "logo_only": True, 52 | "style_nav_header_background": "#06203A", 53 | } 54 | 55 | # Add any paths that contain custom static files (such as style sheets) here, 56 | # relative to this directory. They are copied after the builtin static files, 57 | # so a file named "default.css" will overwrite the builtin "default.css". 58 | html_static_path = ["_static"] 59 | 60 | # --------------------------------------------------------------------------- 61 | 62 | sphinx_to_github = True 63 | sphinx_to_github_verbose = True 64 | sphinx_to_github_encoding = "utf-8" 65 | -------------------------------------------------------------------------------- /torchkit/utils/module_stats.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | # Reference: https://stackoverflow.com/a/62508086 5 | def get_total_params( 6 | model: torch.nn.Module, 7 | trainable: bool = True, 8 | print_table: bool = False, 9 | ) -> int: 10 | """Get the total number of parameters in a PyTorch model. 
11 | 12 | Example usage:: 13 | 14 | class SimpleMLP(nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | self.fc1 = nn.Linear(3, 16) 19 | self.fc2 = nn.Linear(16, 2) 20 | 21 | def forward(self, x: torch.Tensor) -> torch.Tensor: 22 | out = F.relu(self.fc1(x)) 23 | return self.fc2(out) 24 | 25 | net = SimpleMLP() 26 | num_params = get_total_params(net, print_table=True) 27 | 28 | # prints the following: 29 | +------------+------------+ 30 | | Modules | Parameters | 31 | +------------+------------+ 32 | | fc1.weight | 48 | 33 | | fc1.bias | 16 | 34 | | fc2.weight | 32 | 35 | | fc2.bias | 2 | 36 | +------------+------------+ 37 | Total Trainable Params: 98 38 | 39 | Args: 40 | model (torch.nn.Module): The PyTorch model. 41 | trainable (bool, optional): Only consider trainable parameters. Defaults to 42 | True. 43 | print_table (bool, optional): Print the parameters in a pretty table. Defaults 44 | to False. 45 | 46 | Returns: 47 | int: The total number of parameters, restricted to trainable ones if `trainable` is True. 48 | """ 49 | from prettytable import PrettyTable 50 | 51 | table = PrettyTable(["Modules", "Parameters"]) 52 | 53 | total_params = 0 54 | for name, parameter in model.named_parameters(): 55 | if not parameter.requires_grad and trainable: 56 | continue 57 | param = parameter.numel() 58 | table.add_row([name, param]) 59 | total_params += param 60 | 61 | if print_table: 62 | print(table) 63 | print("Total Trainable Params: {:,}".format(total_params)) 64 | 65 | return total_params 66 | -------------------------------------------------------------------------------- /torchkit/experiment.py: -------------------------------------------------------------------------------- 1 | """Methods useful for running experiments.""" 2 | 3 | import os 4 | import uuid 5 | from typing import Any 6 | 7 | from ml_collections import config_dict 8 | 9 | from torchkit.utils import dump_config, git_revision_hash, load_config 10 | 11 | ConfigDict = config_dict.ConfigDict 12 | 13 | 14 | def string_from_kwargs(**kwargs: Any) -> str: 15 | """Concatenate kwargs into an underscore-separated string. 16 | 17 | Used to generate an experiment name based on supplied config kwargs. 18 | """ 19 | return "_".join([f"{k}={v}" for k, v in kwargs.items()]) 20 | 21 | 22 | def unique_id() -> str: 23 | """Generate a unique ID as specified in RFC 4122.""" 24 | # See https://docs.python.org/3/library/uuid.html 25 | return str(uuid.uuid4()) 26 | 27 | 28 | def setup_experiment( 29 | exp_dir: str, 30 | config: ConfigDict, 31 | resume: bool = False, 32 | ) -> None: 33 | """Initializes an experiment. 34 | 35 | If the experiment directory doesn't exist yet, creates it and dumps the config 36 | dict as a yaml file and git hash as a text file. 37 | If it exists already, raises a ValueError to prevent overwriting unless resume is 38 | set to True. 39 | If it exists already and resume is set to True, inplace updates the config with the 40 | values in the saved yaml file. 41 | 42 | Args: 43 | exp_dir (str): Path to the experiment directory. 44 | config (ConfigDict): The config for the experiment. 45 | resume (bool, optional): Whether to resume from a previously created experiment. 46 | Defaults to False. 47 | 48 | Raises: 49 | ValueError: If the experiment directory exists already and resume is not set to 50 | True. 51 | """ 52 | if os.path.exists(exp_dir): 53 | if not resume: 54 | raise ValueError( 55 | "Experiment already exists. Run with --resume to continue."
56 | ) 57 | # Inplace-update the config using the values in the saved yaml file. 58 | load_config(exp_dir, config) 59 | else: 60 | # Dump config as a yaml file. 61 | dump_config(exp_dir, config) 62 | 63 | # Dump git hash as a text file. 64 | with open(os.path.join(exp_dir, "git_hash.txt"), "w") as fp: 65 | fp.write(git_revision_hash()) 66 | -------------------------------------------------------------------------------- /tests/test_layers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.testing import assert_allclose 4 | 5 | from torchkit import layers 6 | 7 | 8 | class TestLayers: 9 | @pytest.mark.parametrize("kernel_size", [1, 3, 5]) 10 | def test_conv2d_same_shape(self, kernel_size): 11 | b, c, h, w = 32, 64, 16, 16 12 | x = torch.randn(b, c, h, w) 13 | out = layers.conv2d(c, c * 2, kernel_size=kernel_size)(x) 14 | assert out.shape[2:] == x.shape[2:] 15 | 16 | def test_spatial_soft_argmax(self): 17 | b, c, h, w = 32, 64, 16, 16 18 | x = torch.zeros(b, c, h, w) 19 | true_max = torch.randint(0, 10, size=(b, c, 2)) 20 | for i in range(b): 21 | for j in range(c): 22 | x[i, j, true_max[i, j, 0], true_max[i, j, 1]] = 1000 23 | soft_max = layers.SpatialSoftArgmax(normalize=False)(x).reshape(b, c, 2) 24 | assert_allclose(true_max.float(), soft_max) 25 | 26 | def test_global_max_pool_1d(self): 27 | b, c, t = 4, 3, 16 28 | x = torch.randn(b, c, t) 29 | x_np = x.numpy() 30 | actual = layers.GlobalMaxPool1d()(x) 31 | expected = x_np.max(axis=(-1)) 32 | assert_allclose(actual, expected) 33 | 34 | def test_global_max_pool_2d(self): 35 | b, c, h, w = 4, 3, 16, 16 36 | x = torch.randn(b, c, h, w) 37 | x_np = x.numpy() 38 | actual = layers.GlobalMaxPool2d()(x) 39 | expected = x_np.max(axis=(-1, -2)) 40 | assert_allclose(actual, expected) 41 | 42 | def test_global_max_pool_3d(self): 43 | b, c, t, h, w = 4, 16, 5, 16, 16 44 | x = torch.randn(b, c, t, h, w) 45 | x_np = x.numpy() 46 | actual = layers.GlobalMaxPool3d()(x) 47 | expected = x_np.max(axis=(-1, -2, -3)) 48 | assert_allclose(actual, expected) 49 | 50 | def test_global_average_pool_1d(self): 51 | b, c, t = 4, 3, 16 52 | x = torch.randn(b, c, t) 53 | x_np = x.numpy() 54 | actual = layers.GlobalAvgPool1d()(x) 55 | expected = x_np.mean(axis=(-1)) 56 | assert_allclose(actual, expected) 57 | 58 | def test_global_average_pool_2d(self): 59 | b, c, h, w = 4, 3, 16, 16 60 | x = torch.randn(b, c, h, w) 61 | x_np = x.numpy() 62 | actual = layers.GlobalAvgPool2d()(x) 63 | expected = x_np.mean(axis=(-1, -2)) 64 | assert_allclose(actual, expected) 65 | 66 | def test_global_average_pool_3d(self): 67 | b, c, t, h, w = 4, 16, 5, 16, 16 68 | x = torch.randn(b, c, t, h, w) 69 | x_np = x.numpy() 70 | actual = layers.GlobalAvgPool3d()(x) 71 | expected = x_np.mean(axis=(-1, -2, -3)) 72 | assert_allclose(actual, expected) 73 | -------------------------------------------------------------------------------- /tests/test_checkpoint.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torchkit.checkpoint import Checkpoint, CheckpointManager 7 | 8 | 9 | @pytest.fixture 10 | def init_model_and_optimizer(): 11 | class SimpleMLP(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | self.fc1 = nn.Linear(3, 16) 15 | self.fc2 = nn.Linear(16, 2) 16 | 17 | def forward(self, x): 18 | out = F.relu(self.fc1(x)) 19 | return self.fc2(out) 20 | 21 | 
model = SimpleMLP() 22 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) 23 | return model, optimizer 24 | 25 | 26 | class TestCheckpoint: 27 | def test_checkpoint_save_restore(self, tmp_path, init_model_and_optimizer): 28 | model, optimizer = init_model_and_optimizer 29 | checkpoint = Checkpoint(model=model, optimizer=optimizer) 30 | checkpoint_dir = tmp_path / "ckpts" 31 | checkpoint_dir.mkdir() 32 | checkpoint_path = checkpoint_dir / "test.ckpt" 33 | checkpoint.save(checkpoint_path) 34 | assert checkpoint.restore(checkpoint_path) 35 | 36 | def test_checkpoint_save_partial_restore(self, tmp_path, init_model_and_optimizer): 37 | model, optimizer = init_model_and_optimizer 38 | checkpoint = Checkpoint(model=model, optimizer=optimizer) 39 | checkpoint_dir = tmp_path / "ckpts" 40 | checkpoint_dir.mkdir() 41 | checkpoint_path = checkpoint_dir / "test.ckpt" 42 | checkpoint.save(checkpoint_path) 43 | assert Checkpoint(model=model).restore(checkpoint_path) 44 | assert Checkpoint(optimizer=optimizer).restore(checkpoint_path) 45 | 46 | def test_checkpoint_save_faulty_restore(self, tmp_path, init_model_and_optimizer): 47 | model, optimizer = init_model_and_optimizer 48 | checkpoint = Checkpoint(model=model, optimizer=optimizer) 49 | checkpoint_dir = tmp_path / "ckpts" 50 | checkpoint_dir.mkdir() 51 | checkpoint_path = checkpoint_dir / "test.ckpt" 52 | checkpoint.save(checkpoint_path) 53 | model.fc2 = nn.Linear(2, 2) # Purposely modify model. 54 | assert not checkpoint.restore(checkpoint_path) 55 | 56 | def test_checkpoint_manager(self, tmp_path, init_model_and_optimizer): 57 | model, optimizer = init_model_and_optimizer 58 | ckpt_dir = tmp_path / "ckpts" 59 | checkpoint_manager = CheckpointManager( 60 | ckpt_dir, 61 | max_to_keep=5, 62 | model=model, 63 | optimizer=optimizer, 64 | ) 65 | global_step = checkpoint_manager.restore_or_initialize() 66 | assert global_step == 0 67 | for i in range(10): 68 | checkpoint_manager.save(i) 69 | available_ckpts = checkpoint_manager.list_checkpoints(ckpt_dir) 70 | assert len(available_ckpts) == 5 71 | ckpts = [int(d.stem) for d in available_ckpts] 72 | expected = list(range(5, 10)) 73 | assert all([a == b for a, b in zip(ckpts, expected)]) 74 | global_step = checkpoint_manager.restore_or_initialize() 75 | assert global_step == 9 76 | assert int(checkpoint_manager.latest_checkpoint.stem) == 9 77 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torchkit 2 | 3 | [![documentation](https://github.com/kevinzakka/torchkit/workflows/docs/badge.svg)](https://kevinzakka.github.io/torchkit/) 4 | ![build](https://github.com/kevinzakka/torchkit/workflows/build/badge.svg) 5 | ![license](https://img.shields.io/github/license/kevinzakka/torchkit?color=blue) 6 | 7 | **`torchkit`** is a *lightweight* library containing PyTorch utilities useful for day-to-day research. Its main goal is to abstract away a lot of the redundant boilerplate associated with research projects like experimental configurations, logging and model checkpointing. It consists of: 8 | 9 | 10 | 11 | 12 | 13 | 17 | 18 | 19 | 20 | 23 | 24 | 25 | 26 | 29 | 30 | 31 | 32 | 35 | 36 | 37 | 38 | 41 | 42 | 43 | 44 | 47 | 48 | 49 |
14 | | Module | Description |
15 | | ------ | ----------- |
16 | | `torchkit.Logger` | A wrapper around Tensorboard's `SummaryWriter` for safe logging of scalars, images, videos and learning rates. Supports both numpy arrays and torch Tensors. |
17 | | `torchkit.CheckpointManager` | A port of Tensorflow's checkpoint manager that automatically manages multiple checkpoints in an experimental run. |
18 | | `torchkit.experiment` | A collection of methods for setting up experiment directories. |
19 | | `torchkit.layers` | A set of commonly used layers in research papers not available in vanilla PyTorch like "same" and "causal" convolution and `SpatialSoftArgmax`. |
20 | | `torchkit.losses` | Some useful loss functions also unavailable in vanilla PyTorch like cross entropy with label smoothing and Huber loss. |
21 | | `torchkit.utils` | A bunch of helper functions for config manipulation, I/O, timing, debugging, etc. |
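22 |
23 | As a quick taste, here is a minimal training-loop sketch wiring `Logger` and `CheckpointManager` together (the path, toy model and loop below are hypothetical, not part of the library):
24 |
25 | ```python
26 | import os
27 | import torch
28 | from torchkit import CheckpointManager, Logger
29 |
30 | exp_dir = "/tmp/torchkit_demo"  # Hypothetical experiment directory.
31 | model = torch.nn.Linear(3, 2)  # Toy model standing in for a real network.
32 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
33 |
34 | logger = Logger(os.path.join(exp_dir, "logs"), force_write=True)
35 | checkpoint_manager = CheckpointManager(
36 |     os.path.join(exp_dir, "ckpts"), max_to_keep=5, model=model, optimizer=optimizer
37 | )
38 |
39 | # Resume from the latest checkpoint if one exists, otherwise start at step 0.
40 | global_step = checkpoint_manager.restore_or_initialize()
41 | for step in range(global_step, 100):
42 |     loss = model(torch.randn(8, 3)).pow(2).mean()  # Dummy objective.
43 |     optimizer.zero_grad()
44 |     loss.backward()
45 |     optimizer.step()
46 |     logger.log_scalar(loss, step, "loss", "train")  # Logged under train/loss.
47 |     checkpoint_manager.save(step)  # Keeps only the 5 most recent checkpoints.
48 | logger.close()
49 | ```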
50 | 51 | For more details about each module, see the [documentation](https://kevinzakka.github.io/torchkit/). 52 | 53 | ### Installation 54 | 55 | To install the latest release, run: 56 | 57 | ```bash 58 | pip install git+https://github.com/kevinzakka/torchkit.git 59 | ``` 60 | 61 | ### Contributing 62 | 63 | For development, clone the source code and create a virtual environment for this project: 64 | 65 | ```bash 66 | git clone https://github.com/kevinzakka/torchkit.git 67 | cd torchkit 68 | pip install -e .[dev] 69 | ``` 70 | 71 | ### Acknowledgments 72 | 73 | * Thanks to Karan Desai's [VirTex](https://github.com/kdexd/virtex) which I used to figure out documentation-related setup for torchkit and for just being an excellent example of stellar open-source research release. 74 | * Thanks to [seals](https://github.com/HumanCompatibleAI/seals) for the excellent software development 75 | practices that I've tried to emulate in this repo. 76 | * Thanks to Brent Yi for encouraging me to use type hinting and for letting me use his awesome [Bayesian filtering library](https://github.com/stanford-iprl-lab/torchfilter)'s README as a template. 77 | -------------------------------------------------------------------------------- /torchkit/viz.py: -------------------------------------------------------------------------------- 1 | """Tools for visualizing resnet feature maps and ViT attention weights.""" 2 | 3 | from typing import Optional, Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torchvision 9 | from PIL import Image 10 | from torchvision import transforms as T 11 | 12 | 13 | def _load_model( 14 | model_name: str, 15 | device: torch.device, 16 | ): 17 | if "dino" in model_name: 18 | assert model_name in [ 19 | "dino_vits16", 20 | "dino_vits8", 21 | "dino_vitb16", 22 | "dino_vitb8", 23 | "dino_resnet50", 24 | ], f"{model_name} is not a valid DINO pretrained model." 25 | model = torch.hub.load("facebookresearch/dino:main", model_name) 26 | normalize = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 27 | elif "resnet" in model_name: 28 | model = getattr(torchvision.models, model_name)(pretrained=True) 29 | normalize = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 30 | for p in model.parameters(): 31 | p.requires_grad = False 32 | model.eval() 33 | model.to(device) 34 | return model, normalize 35 | 36 | 37 | def visualize_attention( 38 | model_name: str, 39 | image: Union[str, np.ndarray], 40 | image_size: Tuple[int, int] = (480, 480), 41 | resnet_layer_idx: Optional[int] = -2, 42 | ) -> np.ndarray: 43 | # Load the model. 44 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 45 | model, normalize = _load_model(model_name, device) 46 | 47 | # Create pre-processing pipeline. 
48 | preprocess = T.Compose( 49 | [ 50 | T.Resize(image_size), 51 | T.ToTensor(), 52 | normalize, 53 | ] 54 | ) 55 | 56 | if isinstance(image, str): 57 | img = Image.open(image) 58 | else: 59 | img = Image.fromarray(image) 60 | img_tensor = preprocess(img) 61 | 62 | if "dino" in model_name and "resnet" not in model_name: 63 | patch_size = model.patch_embed.patch_size 64 | w = img_tensor.shape[1] - img_tensor.shape[1] % patch_size 65 | h = img_tensor.shape[2] - img_tensor.shape[2] % patch_size 66 | img_tensor = img_tensor[:, :w, :h].unsqueeze(0).to(device) 67 | w_featmap = img_tensor.shape[-2] // patch_size 68 | h_featmap = img_tensor.shape[-1] // patch_size 69 | 70 | attentions = model.get_last_selfattention(img_tensor) 71 | nh = attentions.shape[1] 72 | 73 | attentions = attentions[0, :, 0, 1:].reshape(nh, -1) 74 | attentions = attentions.reshape(nh, w_featmap, h_featmap) 75 | 76 | attentions = ( 77 | nn.functional.interpolate( 78 | attentions.unsqueeze(0), 79 | scale_factor=patch_size, 80 | mode="nearest", 81 | )[0] 82 | .cpu() 83 | .numpy() 84 | ) 85 | else: 86 | layers = list(model.children())[:resnet_layer_idx] 87 | model = nn.Sequential(*layers) 88 | 89 | img_tensor = img_tensor.unsqueeze(0).to(device) 90 | img_h, img_w = img_tensor.shape[-2], img_tensor.shape[-1] 91 | 92 | with torch.no_grad(): 93 | attentions = model(img_tensor) 94 | 95 | # Average over all feature maps. 96 | attentions = attentions.mean(dim=1) 97 | 98 | scale_factor_h = img_h / attentions.shape[-2] 99 | scale_factor_w = img_w / attentions.shape[-1] 100 | attentions = ( 101 | nn.functional.interpolate( 102 | attentions.unsqueeze(0), 103 | scale_factor=(scale_factor_h, scale_factor_w), 104 | mode="nearest", 105 | )[0] 106 | .cpu() 107 | .numpy() 108 | ) 109 | 110 | return attentions 111 | -------------------------------------------------------------------------------- /tests/test_losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from torch.testing import assert_allclose 5 | 6 | from torchkit import losses 7 | 8 | 9 | def log_softmax(x): 10 | """A numerically stable log softmax implementation.""" 11 | x = x - np.max(x, axis=-1, keepdims=True) 12 | return x - np.log(np.exp(x).sum(axis=-1, keepdims=True)) 13 | 14 | 15 | class TestLayers: 16 | @pytest.mark.parametrize("smooth_eps", [0, 0.5, 1]) 17 | @pytest.mark.parametrize("K", [2, 100]) 18 | def test_one_hot(self, smooth_eps, K): 19 | batch_size = 32 20 | y = torch.randint(K, (batch_size,)) 21 | y_np = y.numpy() 22 | actual = losses.one_hot(y, K, smooth_eps) 23 | y_np_one_hot = np.eye(K)[y_np] 24 | expected = y_np_one_hot * (1 - smooth_eps) + (smooth_eps / (K - 1)) 25 | assert_allclose(actual, expected) 26 | 27 | def test_one_hot_not_rank_one(self): 28 | with pytest.raises(AssertionError): 29 | losses.one_hot(torch.randint(5, (2, 2)), 5, 0) 30 | 31 | @pytest.mark.parametrize("smooth_eps", [-1, 2]) 32 | def test_one_hot_eps_out_of_bounds(self, smooth_eps): 33 | with pytest.raises(AssertionError): 34 | losses.one_hot(torch.randint(5, (2, 2)), 5, smooth_eps) 35 | 36 | @pytest.mark.parametrize("smooth_eps", [0, 0.5, 1]) 37 | @pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) 38 | def test_cross_entropy(self, smooth_eps, reduction): 39 | batch_size = 200 40 | K = 50 41 | labels = torch.randint(K, (batch_size,)) 42 | logits = torch.randn(batch_size, K) 43 | actual = losses.cross_entropy(logits, labels, smooth_eps, reduction) 44 | logits_np = logits.numpy() 45 
| labels_np = labels.numpy() 46 | labels_np_one_hot = np.eye(K)[labels_np] * (1 - smooth_eps) + ( 47 | smooth_eps / (K - 1) 48 | ) 49 | # Compute log softmax of logits. 50 | log_probs_np = log_softmax(logits_np) 51 | loss_np = (-labels_np_one_hot * log_probs_np).sum(axis=-1) 52 | if reduction == "mean": 53 | expected = loss_np.mean() 54 | elif reduction == "sum": 55 | expected = loss_np.sum(axis=-1) 56 | else: # none 57 | expected = loss_np 58 | assert_allclose(actual, expected) 59 | 60 | def test_cross_entropy_unsupported_reduction(self): 61 | K, batch_size = 50, 200 62 | with pytest.raises(AssertionError): 63 | labels = torch.randint(K, (batch_size,)) 64 | logits = torch.randn(batch_size, K) 65 | losses.cross_entropy(logits, labels, reduction="average") 66 | 67 | def test_cross_entropy_labels_dim(self): 68 | K, batch_size = 50, 200 69 | with pytest.raises(AssertionError): 70 | labels = torch.randint(K, (batch_size, 2)) 71 | logits = torch.randn(batch_size, K) 72 | losses.cross_entropy(logits, labels, reduction="average") 73 | 74 | @pytest.mark.parametrize("delta", [1.0, 2.0, 10.0]) 75 | @pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) 76 | def test_huber_loss(self, delta, reduction): 77 | batch_size = 200 78 | num_dims = 50 79 | target = torch.randn(batch_size, num_dims) 80 | input = torch.randn(batch_size, num_dims) 81 | target_np = target.numpy() 82 | input_np = input.numpy() 83 | actual = losses.huber_loss(input, target, delta, reduction) 84 | diff_np = target_np - input_np 85 | diff_abs_np = np.abs(diff_np) 86 | cond = diff_abs_np <= delta 87 | out = np.where( 88 | cond, 0.5 * diff_np ** 2, (delta * diff_abs_np) - (0.5 * delta ** 2) 89 | ) 90 | if reduction == "mean": 91 | expected = out.mean() 92 | elif reduction == "sum": 93 | expected = out.sum(axis=-1) 94 | else: # none 95 | expected = out 96 | assert_allclose(actual, expected) 97 | -------------------------------------------------------------------------------- /tests/test_logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from torchkit.logger import Logger 6 | 7 | 8 | @pytest.fixture 9 | def init_logger(tmp_path): 10 | log_dir = tmp_path / "logs" 11 | logger = Logger(log_dir, force_write=False) 12 | return logger 13 | 14 | 15 | class TestLogger: 16 | def test_error_on_init_existing(self, tmp_path, init_logger): 17 | _ = init_logger 18 | with pytest.raises(ValueError): 19 | _ = Logger(tmp_path / "logs") 20 | 21 | @pytest.mark.parametrize( 22 | "scalar", [torch.FloatTensor([5.0]), torch.randn([2, 2]).mean(), 5.0] 23 | ) 24 | def test_log_scalar(self, init_logger, scalar): 25 | logger = init_logger 26 | logger.log_scalar(scalar, 0, "loss", "training") 27 | 28 | def test_log_scalar_notscalar(self, init_logger): 29 | logger = init_logger 30 | scalar = torch.FloatTensor([5.0, 5.0]) 31 | with pytest.raises(ValueError): 32 | logger.log_scalar(scalar, 0, "loss", "training") 33 | 34 | @pytest.mark.parametrize( 35 | "image", 36 | [ 37 | np.random.randint(0, 256, size=(224, 224, 3)), 38 | np.random.randint(0, 256, size=(2, 224, 224, 3)), 39 | torch.randint(0, 256, size=(3, 224, 224)), 40 | torch.randint(0, 256, size=(2, 3, 224, 224)), 41 | ], 42 | ) 43 | def test_log_image(self, init_logger, image): 44 | logger = init_logger 45 | logger.log_image(image, 0, "image", "validation") 46 | 47 | @pytest.mark.parametrize( 48 | "image", 49 | [ 50 | np.random.randint(0, 256, size=(3, 224, 224)), 51 | np.random.randint(0, 256, 
size=(2, 3, 224, 224)), 52 | torch.randint(0, 256, size=(224, 224, 3)), 53 | torch.randint(0, 256, size=(2, 224, 224, 3)), 54 | ], 55 | ) 56 | def test_log_image_wrong_format(self, init_logger, image): 57 | logger = init_logger 58 | with pytest.raises(TypeError): 59 | logger.log_image(image, 0, "image", "validation") 60 | 61 | @pytest.mark.parametrize( 62 | "video", 63 | [ 64 | np.random.randint(0, 256, size=(5, 224, 224, 3)), 65 | np.random.randint(0, 256, size=(4, 5, 224, 224, 3)), 66 | torch.randint(0, 256, size=(5, 3, 224, 224)), 67 | torch.randint(0, 256, size=(4, 5, 3, 224, 224)), 68 | ], 69 | ) 70 | def test_log_video(self, init_logger, video): 71 | logger = init_logger 72 | logger.log_video(video, 0, "video", "training") 73 | 74 | def test_log_video_wrongdim(self, init_logger): 75 | logger = init_logger 76 | image = np.random.randint(0, 256, (224, 224, 3)) 77 | with pytest.raises(ValueError): 78 | logger.log_video(image, 0, "video", "training") 79 | 80 | @pytest.mark.parametrize( 81 | "video", 82 | [ 83 | np.random.randint(0, 256, size=(5, 3, 224, 224)), 84 | np.random.randint(0, 256, size=(4, 5, 3, 224, 224)), 85 | torch.randint(0, 256, size=(5, 224, 224, 3)), 86 | torch.randint(0, 256, size=(4, 5, 224, 224, 3)), 87 | ], 88 | ) 89 | def test_log_video_wrongformat(self, init_logger, video): 90 | logger = init_logger 91 | with pytest.raises(TypeError): 92 | logger.log_video(video, 0, "video", "training") 93 | 94 | def test_learning_rate(self, init_logger): 95 | logger = init_logger 96 | param = torch.randn((32, 3), requires_grad=True) 97 | optim = torch.optim.Adam([param], lr=1e-3) 98 | logger.log_learning_rate(optim, 0, "training") 99 | 100 | def test_learning_rate_notoptim(self, init_logger): 101 | logger = init_logger 102 | param = torch.randn((32, 3), requires_grad=True) 103 | with pytest.raises(TypeError): 104 | logger.log_learning_rate(param, 0, "training") 105 | -------------------------------------------------------------------------------- /torchkit/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, Union 3 | 4 | import yaml 5 | from ml_collections import config_dict 6 | 7 | ConfigDict = config_dict.ConfigDict 8 | FrozenConfigDict = config_dict.FrozenConfigDict 9 | 10 | 11 | # Reference: https://github.com/deepmind/jaxline/blob/master/jaxline/base_config.py 12 | def validate_config( 13 | config: ConfigDict, 14 | base_config: ConfigDict, 15 | base_filename: str, 16 | ) -> None: 17 | """Ensures a config inherits from a base config. 18 | 19 | Args: 20 | config (ConfigDict): The child config. 21 | base_config (ConfigDict): The base config. 22 | base_filename (str): Path to python file containing base config definition. 23 | 24 | Raises: 25 | ValueError: if the base config contains keys that are not present in config. 26 | """ 27 | for key in base_config.keys(): 28 | if key not in config: 29 | raise ValueError( 30 | f"Key {key} missing from config. This config is required to have " 31 | f"keys: {list(base_config.keys())}. See {base_filename} for more " 32 | "details." 33 | ) 34 | if isinstance(base_config[key], ConfigDict) and config[key] is not None: 35 | validate_config(config[key], base_config[key], base_filename) 36 | 37 | 38 | def dump_config(exp_dir: str, config: Union[ConfigDict, FrozenConfigDict]) -> None: 39 | """Dump a config to disk. 40 | 41 | Args: 42 | exp_dir (str): Path to the experiment directory. 43 | config (Union[ConfigDict, FrozenConfigDict]): The config to dump. 
44 | """ 45 | if not os.path.exists(exp_dir): 46 | os.makedirs(exp_dir) 47 | 48 | # Note: No need to explicitly delete the previous config file as "w" will overwrite 49 | # the file if it already exists. 50 | with open(os.path.join(exp_dir, "config.yaml"), "w") as fp: 51 | yaml.dump(config.to_dict(), fp) 52 | 53 | 54 | def load_config( 55 | exp_dir: str, 56 | config: Optional[ConfigDict] = None, 57 | freeze: bool = False, 58 | ) -> Optional[Union[ConfigDict, FrozenConfigDict]]: 59 | """Load a config from an experiment directory. 60 | 61 | Args: 62 | exp_dir (str): Path to the experiment directory. 63 | config (Optional[ConfigDict], optional): An optional config object to inplace 64 | update. If one isn't provided, a new config object is returned. Defaults to 65 | None. 66 | freeze (bool, optional): Whether to freeze the config. Defaults to False. 67 | 68 | Returns: 69 | Optional[Union[ConfigDict, FrozenConfigDict]]: The config file that was stored 70 | in the experiment directory. 71 | """ 72 | with open(os.path.join(exp_dir, "config.yaml"), "r") as fp: 73 | cfg = yaml.load(fp, Loader=yaml.FullLoader) 74 | # Inplace update the config if one is provided. 75 | if config is not None: 76 | config.update(cfg) 77 | return 78 | if freeze: 79 | return FrozenConfigDict(cfg) 80 | return ConfigDict(cfg) 81 | 82 | 83 | def copy_config_and_replace( 84 | config: ConfigDict, 85 | update_dict: Optional[ConfigDict] = None, 86 | freeze: bool = False, 87 | ) -> Union[ConfigDict, FrozenConfigDict]: 88 | """Makes a copy of a config and optionally updates its values. 89 | 90 | Args: 91 | config (ConfigDict): The config to copy. 92 | update_dict (Optional[ConfigDict], optional): A config that will optionally 93 | update the copy. Defaults to None. 94 | freeze (bool, optional): Whether to freeze the config after the copy. Defaults 95 | to False. 96 | 97 | Returns: 98 | Union[ConfigDict, FrozenConfigDict]: A copy of the config. 99 | """ 100 | # Using the ConfigDict constructor leaves the `FieldReferences` untouched unlike 101 | # `ConfigDict.copy_and_resolve_references`. 102 | new_config = ConfigDict(config) 103 | if update_dict is not None: 104 | new_config.update(update_dict) 105 | if freeze: 106 | return FrozenConfigDict(new_config) 107 | return new_config 108 | -------------------------------------------------------------------------------- /torchkit/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | Tensor = torch.Tensor 5 | 6 | 7 | def one_hot( 8 | y: Tensor, 9 | K: int, 10 | smooth_eps: float = 0, 11 | ) -> Tensor: 12 | """One-hot encodes a tensor, with optional label smoothing. 13 | 14 | Args: 15 | y (Tensor): A tensor containing the ground-truth labels of shape `(N,)`, i.e. 16 | one label for each element in the batch. 17 | K (int): The number of classes. 18 | smooth_eps (float, optional): Label smoothing factor in `[0, 1]` range. Defaults 19 | to 0, which corresponds to no label smoothing. 20 | 21 | Returns: 22 | Tensor: The one-hot encoded tensor. 23 | """ 24 | assert 0 <= smooth_eps <= 1 25 | assert y.ndim == 1, "Label tensor must be rank 1." 26 | y_hot = torch.eye(K)[y] * (1 - smooth_eps) + (smooth_eps / (K - 1)) 27 | return y_hot.to(y.device) 28 | 29 | 30 | def cross_entropy( 31 | logits: Tensor, 32 | labels: Tensor, 33 | smooth_eps: float = 0, 34 | reduction: str = "mean", 35 | ) -> Tensor: 36 | """Cross-entropy loss with support for label smoothing. 
38 | 
39 |     Args:
40 |         logits (Tensor): A `FloatTensor` containing the raw logits, i.e. no softmax has
41 |             been applied to the model output. The tensor should be of shape
42 |             `(N, K)` where K is the number of classes.
43 |         labels (Tensor): A rank-1 `LongTensor` containing the ground truth labels.
44 |         smooth_eps (float, optional): The label smoothing factor in `[0, 1]` range.
45 |             Defaults to 0.
46 |         reduction (str, optional): The reduction strategy on the final loss tensor.
47 |             Defaults to "mean".
48 | 
49 |     Returns:
50 |         If reduction is `none`, a rank-1 Tensor of shape `(N,)`.
51 |         If reduction is `sum`, a scalar Tensor.
52 |         If reduction is `mean`, a scalar Tensor.
53 |     """
54 |     assert isinstance(logits, (torch.FloatTensor, torch.cuda.FloatTensor))
55 |     assert isinstance(labels, (torch.LongTensor, torch.cuda.LongTensor))
56 |     assert reduction in ["none", "mean", "sum"], "reduction method is not supported"
57 | 
58 |     # Ensure labels are not one-hot encoded.
59 |     assert labels.ndim == 1, "Labels must not be one-hot encoded."
60 | 
61 |     if smooth_eps == 0:
62 |         return F.cross_entropy(logits, labels, reduction=reduction)
63 | 
64 |     # One-hot encode targets.
65 |     labels = one_hot(labels, logits.shape[1], smooth_eps)
66 | 
67 |     # Convert logits to log probabilities.
68 |     log_probs = F.log_softmax(logits, dim=-1)
69 | 
70 |     loss = (-labels * log_probs).sum(dim=-1)
71 | 
72 |     if reduction == "none":
73 |         return loss
74 |     elif reduction == "mean":
75 |         return loss.mean()
76 |     return loss.sum()
77 | 
78 | 
79 | def huber_loss(
80 |     input: Tensor,
81 |     target: Tensor,
82 |     delta: float,
83 |     reduction: str = "mean",
84 | ) -> Tensor:
85 |     """Huber loss with tunable margin, as defined in `1`_.
86 | 
87 |     Args:
88 |         input (Tensor): A FloatTensor representing the model output.
89 |         target (Tensor): A FloatTensor representing the target values.
90 |         delta (float): The threshold at or below which the absolute difference `diff`
91 |             incurs a quadratic penalty, and above which it incurs a linear
92 |             penalty.
93 |         reduction (str, optional): The reduction strategy on the final loss tensor.
94 |             Defaults to "mean".
95 | 
96 |     Returns:
97 |         If reduction is `none`, a Tensor with the same shape as the input.
98 |         If reduction is `sum`, the loss summed over the last dimension.
99 |         If reduction is `mean`, a scalar Tensor.
100 | 
101 |     .. _1: https://en.wikipedia.org/wiki/Huber_loss
102 |     """
103 |     assert isinstance(input, (torch.FloatTensor, torch.cuda.FloatTensor))
104 |     assert isinstance(target, (torch.FloatTensor, torch.cuda.FloatTensor))
105 |     assert reduction in ["none", "mean", "sum"], "reduction method is not supported"
106 | 
107 |     diff = target - input
108 |     diff_abs = torch.abs(diff)
109 |     cond = diff_abs <= delta
110 |     loss = torch.where(cond, 0.5 * diff ** 2, (delta * diff_abs) - (0.5 * delta ** 2))
111 |     if reduction == "none":
112 |         return loss
113 |     elif reduction == "mean":
114 |         return loss.mean()
115 |     return loss.sum(dim=-1)
116 | 
--------------------------------------------------------------------------------
/torchkit/logger.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from typing import Union, cast
3 | 
4 | import numpy as np
5 | import torch
6 | import torchvision
7 | from torch.utils.tensorboard import SummaryWriter
8 | 
9 | Tensor = torch.Tensor
10 | ImageType = Union[Tensor, np.ndarray]
11 | 
12 | 
13 | class Logger:
14 |     """A Tensorboard-based logger."""
15 | 
16 |     def __init__(self, log_dir: str, force_write: bool = False) -> None:
17 |         """Constructor.
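
        A minimal usage sketch (the log directory is illustrative)::

            logger = Logger("/tmp/exp/tb", force_write=True)
            logger.log_scalar(0.5, 0, "loss", "train")
            logger.close()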
18 | 19 | Args: 20 | log_dir: The directory in which to store Tensorboard logs. 21 | force_write: Whether to force write to an already existing log dir. 22 | Set to `True` if resuming training. 23 | """ 24 | # Setup the summary writer. 25 | if osp.exists(log_dir) and not force_write: 26 | raise ValueError( 27 | "You might be overwriting a directory that already " 28 | "has train_logs. Please provide a new experiment name " 29 | "or set --resume to True when launching train script." 30 | ) 31 | self._writer = SummaryWriter(log_dir) 32 | 33 | def close(self) -> None: 34 | self._writer.close() 35 | 36 | def flush(self) -> None: 37 | self._writer.flush() 38 | 39 | def log_scalar( 40 | self, 41 | scalar: Union[Tensor, float], 42 | global_step: int, 43 | name: str, 44 | prefix: str = "", 45 | ) -> None: 46 | """Log a scalar value. 47 | 48 | Args: 49 | scalar: A scalar `torch.Tensor` or float. 50 | global_step: The training iteration step. 51 | name: The name of the logged scalar. 52 | prefix: A prefix to prepend to the logged scalar. 53 | """ 54 | if isinstance(scalar, torch.Tensor): 55 | if cast(torch.Tensor, scalar).ndim > 1: 56 | raise ValueError("Tensor must be scalar-valued.") 57 | if cast(torch.Tensor, scalar).ndim == 1: 58 | if cast(torch.Tensor, scalar).shape != torch.Size([1]): 59 | raise ValueError("Tensor must be scalar-valued.") 60 | scalar = cast(torch.Tensor, scalar).item() 61 | assert np.isscalar(scalar), "Not a scalar." 62 | msg = "/".join([prefix, name]) if prefix else name 63 | self._writer.add_scalar(msg, scalar, global_step) 64 | 65 | def log_image( 66 | self, 67 | image: ImageType, 68 | global_step: int, 69 | name: str, 70 | prefix: str = "", 71 | nrow: int = 5, 72 | ) -> None: 73 | """Log an image or batch of images. 74 | 75 | Args: 76 | image: A numpy ndarray or a torch Tensor. If the image is 4D (i.e. 77 | batched), it will be converted to a 3D image using make_grid. 78 | The numpy array should be in channel-last format while the torch 79 | Tensor should be in channel-first format. 80 | global_step: The training iteration step. 81 | name: The name of the logged image(s). 82 | prefix: A prefix to prepend to the logged image(s). 83 | nrow: The number of images displayed in each row of the grid if the 84 | input image is 4D. 85 | """ 86 | msg = "/".join([prefix, name]) if prefix else name 87 | assert image.ndim in [3, 4], "Must be an image or batch of images." 88 | if image.ndim == 4: 89 | if isinstance(image, np.ndarray): 90 | image = torch.from_numpy(image).permute(0, 3, 1, 2) 91 | image = torchvision.utils.make_grid(image, nrow=nrow) 92 | else: 93 | if isinstance(image, np.ndarray): 94 | image = torch.from_numpy(image).permute(2, 0, 1) 95 | self._writer.add_image(msg, image, global_step, dataformats="CHW") 96 | 97 | def log_video( 98 | self, 99 | video, 100 | global_step: int, 101 | name: str, 102 | prefix: str = "", 103 | fps: int = 4, 104 | ) -> None: 105 | """Log a sequence of images or a batch of sequence of images. 106 | 107 | Args: 108 | video: A torch Tensor or numpy ndarray. The numpy array should be in 109 | channel-last format while the torch Tensor should be in 110 | channel-first format. Should be either a single sequence of 111 | images of shape (T, CHW/HWC) or a batch of sequences of shape 112 | (B, T, CHW/HWC). The batch of sequences will get converted to 113 | one grid sequence of images. 114 | global_step: The training iteration step. 115 | name: The name of the logged video(s). 116 | prefix: A prefix to prepend to the logged video(s). 
117 |             fps: The frames per second.
118 |         """
119 |         msg = f"{prefix}/video/{name}"
120 |         if video.ndim not in [4, 5]:
121 |             raise ValueError("Must be a video or batch of videos.")
122 |         if video.ndim == 4:
123 |             if isinstance(video, np.ndarray):
124 |                 if video.shape[-1] != 3:
125 |                     raise TypeError("Numpy array should have THWC format.")
126 |                 # (T, H, W, C) -> (T, C, H, W).
127 |                 video = torch.from_numpy(video).permute(0, 3, 1, 2)
128 |             elif isinstance(video, torch.Tensor):
129 |                 if video.shape[1] != 3:
130 |                     raise TypeError("Torch tensor should have TCHW format.")
131 |             video = video.unsqueeze(0)  # (T, C, H, W) -> (1, T, C, H, W).
132 |         else:
133 |             if isinstance(video, np.ndarray):
134 |                 if video.shape[-1] != 3:
135 |                     raise TypeError("Numpy array should have BTHWC format.")
136 |                 # (B, T, H, W, C) -> (B, T, C, H, W).
137 |                 video = torch.from_numpy(video).permute(0, 1, 4, 2, 3)
138 |             elif isinstance(video, torch.Tensor):
139 |                 if video.shape[2] != 3:
140 |                     raise TypeError("Torch tensor should have BTCHW format.")
141 |         self._writer.add_video(msg, video, global_step, fps=fps)
142 | 
143 |     def log_learning_rate(
144 |         self,
145 |         optimizer: torch.optim.Optimizer,
146 |         global_step: int,
147 |         prefix: str = "",
148 |     ) -> None:
149 |         """Log the learning rate of every parameter group.
150 | 
151 |         Args:
152 |             optimizer: An instance of `torch.optim.Optimizer`.
153 |             global_step: The training iteration step.
154 |             prefix: A prefix to prepend to the logged scalar.
155 |         """
156 |         if not isinstance(optimizer, torch.optim.Optimizer):
157 |             raise TypeError("Optimizer must be an instance of torch.optim.Optimizer.")
158 |         for param_group in optimizer.param_groups:
159 |             lr = param_group["lr"]
160 |             self.log_scalar(lr, global_step, "learning_rate", prefix)
--------------------------------------------------------------------------------
/torchkit/checkpoint.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import signal
4 | from pathlib import Path
5 | from typing import Any, List, Optional, Union
6 | 
7 | import torch
8 | 
9 | from .experiment import unique_id
10 | 
11 | 
12 | def get_files(
13 |     d: Path,
14 |     pattern: str,
15 |     sort_lexicographical: bool = False,
16 |     sort_numerical: bool = False,
17 | ) -> List[Path]:
18 |     """Return a list of files in a given directory.
19 | 
20 |     Args:
21 |         d: The path to the directory.
22 |         pattern: The wildcard to filter files with.
23 |         sort_lexicographical: Sort the files lexicographically by stem.
24 |         sort_numerical: Sort the files numerically by stem.
25 |     """
26 |     files = d.glob(pattern)
27 |     if sort_lexicographical:
28 |         return sorted(files, key=lambda x: x.stem)
29 |     if sort_numerical:
30 |         return sorted(files, key=lambda x: int(x.stem))
31 |     return list(files)
32 | 
33 | 
34 | class Checkpoint:
35 |     """Save and restore PyTorch objects implementing a `state_dict` method."""
36 | 
37 |     def __init__(self, **kwargs) -> None:
38 |         """Constructor.
39 | 
40 |         Accepts keyword arguments whose values are objects that contain a
41 |         `state_dict` attribute and thus can be serialized to disk.
42 | 
43 |         Args:
44 |             kwargs: Keyword arguments are set as attributes of this object,
45 |                 and are saved with the checkpoint. Values must have a
46 |                 `state_dict` attribute.
47 | 
48 |         Raises:
49 |             ValueError: If objects in `kwargs` do not have a `state_dict`
50 |                 attribute.
51 |         """
52 |         for k, v in sorted(kwargs.items()):
53 |             if not hasattr(v, "state_dict"):
54 |                 raise ValueError(f"{k} does not have a state_dict attribute.")
55 |             setattr(self, k, v)
56 | 
57 |     def save(self, save_path: Path) -> None:
58 |         """Save a state to disk.
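
        The snapshot is written to a temporary file in the destination
        directory and then atomically renamed into place, so an interrupted
        save cannot corrupt a previously written checkpoint.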
59 | 
60 |         Modified from brentyi/fannypack.
61 | 
62 |         Args:
63 |             save_path: The path at which to save the checkpoint.
64 |         """
65 |         # Ignore ctrl+c while saving.
66 |         try:
67 |             orig_handler = signal.getsignal(signal.SIGINT)
68 |             signal.signal(signal.SIGINT, lambda _sig, _frame: None)
69 |         except ValueError:
70 |             # Signal throws a ValueError if we're not in the main thread.
71 |             orig_handler = None
72 | 
73 |         # Create a snapshot of the current state.
74 |         save_dict = dict()
75 |         for k, v in self.__dict__.items():
76 |             save_dict[k] = v.state_dict()
77 | 
78 |         # Write the snapshot to a temporary file in the destination directory,
79 |         # then rename it into place. `os.rename` is atomic on POSIX when source
80 |         # and destination live on the same filesystem.
81 |         # Ref: https://docs.python.org/3/library/os.html#os.rename
82 |         tmp_path = save_path.parent / f"tmp-{unique_id()}.ckpt"
83 |         torch.save(save_dict, tmp_path)
84 |         os.rename(tmp_path, save_path)
85 | 
86 |         # Restore SIGINT handler.
87 |         if orig_handler is not None:
88 |             signal.signal(signal.SIGINT, orig_handler)
89 | 
90 |     def restore(self, save_path: Union[str, Path]) -> bool:
91 |         """Restore a state from a saved checkpoint.
92 | 
93 |         Args:
94 |             save_path: The filepath to the saved checkpoint.
95 | 
96 |         Returns:
97 |             True if restoring was fully or partially successful (some
98 |             checkpointables may have been skipped), False otherwise.
99 |         """
100 |         try:
101 |             state = torch.load(Path(save_path), map_location="cpu")
102 |             for name, state_dict in state.items():
103 |                 if not hasattr(self, name):
104 |                     logging.warning(
105 |                         f"{name} in saved checkpoint not in checkpoint to "
106 |                         "reload. Skipping it."
107 |                     )
108 |                     continue
109 |                 getattr(self, name).load_state_dict(state_dict)
110 |             return True
111 |         except Exception as e:
112 |             logging.error(e)
113 |             return False
114 | 
115 | 
116 | # TODO(kevin): Add saving of best checkpoint based on specified metric.
117 | class CheckpointManager:
118 |     """
119 |     Periodically save PyTorch checkpointables (any object that implements a
120 |     `state_dict` method) to disk and restore them to resume training.
121 | 
122 |     Note: This is a re-implementation of `2`_.
123 | 
124 |     Example usage::
125 | 
126 |         from torchkit.checkpoint import CheckpointManager
127 | 
128 |         # Create a checkpoint manager instance.
129 |         checkpoint_manager = CheckpointManager(
130 |             checkpoint_dir,
131 |             model=model,
132 |             optimizer=optimizer,
133 |         )
134 | 
135 |         # Restore the last checkpoint if it exists.
136 |         global_step = checkpoint_manager.restore_or_initialize()
137 |         for global_step in range(1000):
138 |             # forward pass + loss computation
139 | 
140 |             # Save a checkpoint every N iters.
141 |             if not global_step % N:
142 |                 checkpoint_manager.save(global_step)
143 | 
144 |     .. _2: https://www.tensorflow.org/api_docs/python/tf/train/CheckpointManager/
145 |     """
146 | 
147 |     def __init__(
148 |         self,
149 |         directory: str,
150 |         max_to_keep: int = 10,
151 |         **checkpointables: Any,
152 |     ) -> None:
153 |         """Constructor.
154 | 
155 |         Args:
156 |             directory: The directory in which checkpoints will be saved.
157 |             max_to_keep: The maximum number of checkpoints to keep.
158 |                 Amongst all saved checkpoints, checkpoints will be deleted
159 |                 oldest first, until `max_to_keep` remain.
160 |             checkpointables: Keyword args with checkpointable PyTorch objects.
161 |         """
162 |         assert max_to_keep > 0, "max_to_keep should be a positive integer."
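
        # Note: after every call to `save()`, checkpoints are pruned oldest-first
        # until at most `max_to_keep` remain (see `_trim_checkpoints` below).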
163 | 
164 |         self.directory = Path(directory).absolute()
165 |         self.max_to_keep = max_to_keep
166 |         self.checkpoint = Checkpoint(**checkpointables)
167 | 
168 |         # Create the checkpoint directory if it doesn't already exist.
169 |         self.directory.mkdir(parents=True, exist_ok=True)
170 | 
171 |     def restore_or_initialize(self) -> int:
172 |         """Restore items in checkpoint from the latest checkpoint file.
173 | 
174 |         Returns:
175 |             The global iteration step. This is parsed from the latest checkpoint
176 |             file if one is found, else 0 is returned.
177 |         """
178 |         ckpts = CheckpointManager.list_checkpoints(self.directory)
179 |         if not ckpts:
180 |             return 0
181 |         last_ckpt = ckpts[-1]
182 |         status = self.checkpoint.restore(last_ckpt)
183 |         if not status:
184 |             logging.info("Could not restore latest checkpoint file.")
185 |             return 0
186 |         return int(last_ckpt.stem)
187 | 
188 |     def save(self, global_step: int) -> None:
189 |         """Create a new checkpoint.
190 | 
191 |         Args:
192 |             global_step: The iteration number which will be used to name the
193 |                 checkpoint.
194 |         """
195 |         save_path = self.directory / f"{global_step}.ckpt"
196 |         self.checkpoint.save(save_path)
197 |         self._trim_checkpoints()
198 | 
199 |     def _trim_checkpoints(self) -> None:
200 |         """Trim older checkpoints until `max_to_keep` remain."""
201 |         # Get a list of checkpoints in reverse global_step order.
202 |         ckpts = CheckpointManager.list_checkpoints(self.directory)[::-1]
203 |         # Remove the oldest checkpoints until `max_to_keep` remain.
204 |         while len(ckpts) > self.max_to_keep:
205 |             ckpts.pop().unlink()
206 | 
207 |     def load_latest_checkpoint(self) -> None:
208 |         """Load the last saved checkpoint."""
209 |         self.checkpoint.restore(self.latest_checkpoint)
210 | 
211 |     def load_checkpoint_at(self, global_step: int) -> None:
212 |         """Load the checkpoint saved at a given global step."""
213 |         ckpt_name = self.directory / f"{global_step}.ckpt"
214 |         if ckpt_name not in CheckpointManager.list_checkpoints(self.directory):
215 |             raise ValueError(f"No checkpoint found at step {global_step}.")
216 |         self.checkpoint.restore(ckpt_name)
217 | 
218 |     @property
219 |     def latest_checkpoint(self) -> Optional[Path]:
220 |         """The path of the last saved checkpoint.
221 | 
222 |         Raises:
223 |             ValueError: If no checkpoints have been saved yet.
224 |         """
225 |         ckpts = CheckpointManager.list_checkpoints(self.directory)
226 |         if not ckpts:
227 |             raise ValueError(f"No checkpoints found in {self.directory}.")
228 |         return ckpts[-1]
229 | 
230 |     @staticmethod
231 |     def list_checkpoints(directory: Union[Path, str]) -> List[Path]:
232 |         """List all checkpoints in a checkpoint directory, sorted by step."""
233 |         return get_files(Path(directory), "*.ckpt", sort_numerical=True)
234 | 
--------------------------------------------------------------------------------
/torchkit/layers.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | ConvType = Union[torch.nn.modules.conv.Conv2d, torch.nn.modules.conv.Conv3d]
8 | Tensor = torch.Tensor
9 | 
10 | 
11 | def _conv(
12 |     dim: int,
13 |     in_channels: int,
14 |     out_channels: int,
15 |     kernel_size: int = 3,
16 |     stride: int = 1,
17 |     dilation: int = 1,
18 |     bias: bool = True,
19 | ) -> ConvType:
20 |     """`same` convolution, i.e. output shape equals input shape when `stride=1`.
21 | 
22 |     Args:
23 |         dim: The dimension of the convolution: 2 is conv2d, 3 is conv3d.
24 |         in_channels: The number of input feature maps.
25 |         out_channels: The number of output feature maps.
26 |         kernel_size: The filter size.
27 |         stride: The filter stride.
28 |         dilation: The filter dilation factor.
29 |         bias: Whether to add a bias.
30 |     """
31 |     assert dim in [2, 3], "[!] Only 2D and 3D convolution supported."
32 |     conv = nn.Conv2d if dim == 2 else nn.Conv3d
33 | 
34 |     # Compute new filter size after dilation and necessary padding for `same`
35 |     # output size.
36 |     dilated_kernel_size = (kernel_size - 1) * (dilation - 1) + kernel_size
37 |     same_padding = (dilated_kernel_size - 1) // 2
38 | 
39 |     return conv(
40 |         in_channels,
41 |         out_channels,
42 |         kernel_size=kernel_size,
43 |         stride=stride,
44 |         padding=same_padding,
45 |         dilation=dilation,
46 |         bias=bias,
47 |     )
48 | 
49 | 
50 | def conv2d(*args, **kwargs) -> torch.nn.modules.conv.Conv2d:
51 |     """`same` 2D convolution, i.e. output shape equals input shape when `stride=1`.
52 | 
53 |     Args:
54 |         in_channels: The number of input feature maps.
55 |         out_channels: The number of output feature maps.
56 |         kernel_size: The filter size.
57 |         stride: The filter stride.
58 |         dilation: The filter dilation factor.
59 |         bias: Whether to add a bias.
60 |     """
61 |     return _conv(2, *args, **kwargs)
62 | 
63 | 
64 | def conv3d(*args, **kwargs) -> torch.nn.modules.conv.Conv3d:
65 |     """`same` 3D convolution, i.e. output shape equals input shape when `stride=1`.
66 | 
67 |     Args:
68 |         in_channels: The number of input feature maps.
69 |         out_channels: The number of output feature maps.
70 |         kernel_size: The filter size.
71 |         stride: The filter stride.
72 |         dilation: The filter dilation factor.
73 |         bias: Whether to add a bias.
74 |     """
75 |     return _conv(3, *args, **kwargs)
76 | 
77 | 
78 | class Flatten(nn.Module):
79 |     """Flattens convolutional feature maps for fully-connected layers.
80 | 
81 |     This is a convenience module meant to be plugged into a
82 |     `torch.nn.Sequential` model.
83 | 
84 |     Example usage::
85 | 
86 |         import torch.nn as nn
87 |         from torchkit import layers
88 | 
89 |         # Assume an input of shape (3, 28, 28).
90 |         net = nn.Sequential(
91 |             layers.conv2d(3, 8, kernel_size=3),
92 |             nn.ReLU(),
93 |             layers.conv2d(8, 16, kernel_size=3),
94 |             nn.ReLU(),
95 |             layers.Flatten(),
96 |             nn.Linear(28*28*16, 256),
97 |             nn.ReLU(),
98 |             nn.Linear(256, 2),
99 |         )
100 |     """
101 | 
102 |     def __init__(self):
103 |         super().__init__()
104 | 
105 |     def forward(self, x: Tensor) -> Tensor:
106 |         return x.view(x.shape[0], -1)
107 | 
108 | 
109 | class SpatialSoftArgmax(nn.Module):
110 |     """Spatial softmax as defined in `1`_.
111 | 
112 |     Concretely, the spatial softmax of each feature map is used to compute a
113 |     weighted mean of the pixel locations, effectively performing a soft arg-max
114 |     over the spatial dimensions.
115 | 
116 |     .. _1: https://arxiv.org/abs/1504.00702
117 |     """
118 | 
119 |     def __init__(self, normalize: bool = False):
120 |         """Constructor.
121 | 
122 |         Args:
123 |             normalize: Whether to use normalized image coordinates, i.e.
124 |                 coordinates in the range `[-1, 1]`.
125 |         """
126 |         super().__init__()
127 | 
128 |         self.normalize = normalize
129 | 
130 |     def _coord_grid(
131 |         self,
132 |         h: int,
133 |         w: int,
134 |         device: torch.device,
135 |     ) -> Tensor:
136 |         if self.normalize:
137 |             return torch.stack(
138 |                 torch.meshgrid(
139 |                     torch.linspace(-1, 1, w, device=device),
140 |                     torch.linspace(-1, 1, h, device=device),
141 |                 )
142 |             )
143 |         return torch.stack(
144 |             torch.meshgrid(
145 |                 torch.arange(0, w, device=device),
146 |                 torch.arange(0, h, device=device),
147 |             )
148 |         )
149 | 
150 |     def forward(self, x: Tensor) -> Tensor:
151 |         assert x.ndim == 4, "Expecting a tensor of shape (B, C, H, W)."
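
        # In expectation form, for each feature map k this computes
        #   E[c] = sum_i softmax(x_k)_i * c_i, for c in {x, y},
        # i.e. the softmax-weighted mean pixel coordinate, which acts as a
        # differentiable (soft) arg-max over spatial locations.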
152 | 
153 |         # Compute a spatial softmax over the input:
154 |         # Given an input of shape (B, C, H, W), reshape it to (B*C, H*W) then
155 |         # apply the softmax operator over the last dimension.
156 |         b, c, h, w = x.shape
157 |         softmax = F.softmax(x.view(-1, h * w), dim=-1)
158 | 
159 |         # Create a meshgrid of normalized pixel coordinates.
160 |         xc, yc = self._coord_grid(h, w, x.device)
161 | 
162 |         # Element-wise multiply the x and y coordinates with the softmax, then
163 |         # sum over the h*w dimension. This effectively computes the weighted
164 |         # mean x and y locations.
165 |         x_mean = (softmax * xc.flatten()).sum(dim=1, keepdim=True)
166 |         y_mean = (softmax * yc.flatten()).sum(dim=1, keepdim=True)
167 | 
168 |         # Concatenate and reshape the result to (B, C*2) where for every feature
169 |         # we have the expected x and y pixel locations.
170 |         return torch.cat([x_mean, y_mean], dim=1).view(-1, c * 2)
171 | 
172 | 
173 | class _GlobalMaxPool(nn.Module):
174 |     """Global max pooling layer."""
175 | 
176 |     def __init__(self, dim):
177 |         super().__init__()
178 | 
179 |         if dim == 1:
180 |             self._pool = F.max_pool1d
181 |         elif dim == 2:
182 |             self._pool = F.max_pool2d
183 |         elif dim == 3:
184 |             self._pool = F.max_pool3d
185 |         else:
186 |             raise ValueError(f"{dim}D is not supported.")
187 | 
188 |     def forward(self, x: Tensor) -> Tensor:
189 |         out = self._pool(x, kernel_size=x.size()[2:])
190 |         for _ in range(len(out.shape[2:])):
191 |             out.squeeze_(dim=-1)
192 |         return out
193 | 
194 | 
195 | class GlobalMaxPool1d(_GlobalMaxPool):
196 |     """Global max pooling operation for temporal or 1D data."""
197 | 
198 |     def __init__(self):
199 |         super().__init__(dim=1)
200 | 
201 | 
202 | class GlobalMaxPool2d(_GlobalMaxPool):
203 |     """Global max pooling operation for spatial or 2D data."""
204 | 
205 |     def __init__(self):
206 |         super().__init__(dim=2)
207 | 
208 | 
209 | class GlobalMaxPool3d(_GlobalMaxPool):
210 |     """Global max pooling operation for 3D data."""
211 | 
212 |     def __init__(self):
213 |         super().__init__(dim=3)
214 | 
215 | 
216 | class _GlobalAvgPool(nn.Module):
217 |     """Global average pooling layer."""
218 | 
219 |     def __init__(self, dim):
220 |         super().__init__()
221 | 
222 |         if dim == 1:
223 |             self._pool = F.avg_pool1d
224 |         elif dim == 2:
225 |             self._pool = F.avg_pool2d
226 |         elif dim == 3:
227 |             self._pool = F.avg_pool3d
228 |         else:
229 |             raise ValueError(f"{dim}D is not supported.")
230 | 
231 |     def forward(self, x: Tensor) -> Tensor:
232 |         out = self._pool(x, kernel_size=x.size()[2:])
233 |         for _ in range(len(out.shape[2:])):
234 |             out.squeeze_(dim=-1)
235 |         return out
236 | 
237 | 
238 | class GlobalAvgPool1d(_GlobalAvgPool):
239 |     """Global average pooling operation for temporal or 1D data."""
240 | 
241 |     def __init__(self):
242 |         super().__init__(dim=1)
243 | 
244 | 
245 | class GlobalAvgPool2d(_GlobalAvgPool):
246 |     """Global average pooling operation for spatial or 2D data."""
247 | 
248 |     def __init__(self):
249 |         super().__init__(dim=2)
250 | 
251 | 
252 | class GlobalAvgPool3d(_GlobalAvgPool):
253 |     """Global average pooling operation for 3D data."""
254 | 
255 |     def __init__(self):
256 |         super().__init__(dim=3)
257 | 
258 | 
259 | class CausalConv1d(nn.Conv1d):
260 |     """A causal a.k.a. masked 1D convolution."""
261 | 
262 |     def __init__(
263 |         self,
264 |         in_channels: int,
265 |         out_channels: int,
266 |         kernel_size: int,
267 |         stride: int = 1,
268 |         dilation: int = 1,
269 |         bias: bool = True,
270 |     ):
271 |         """Constructor.
272 | 
273 |         Args:
274 |             in_channels: The number of input channels.
275 | out_channels: The number of output channels. 276 | kernel_size: The filter size. 277 | stride: The filter stride. 278 | dilation: The filter dilation factor. 279 | bias: Whether to add the bias term or not. 280 | """ 281 | self.__padding = (kernel_size - 1) * dilation 282 | 283 | super().__init__( 284 | in_channels, 285 | out_channels, 286 | kernel_size=kernel_size, 287 | stride=stride, 288 | padding=self.__padding, 289 | dilation=dilation, 290 | bias=bias, 291 | ) 292 | 293 | def forward(self, x: Tensor) -> Tensor: 294 | res = super().forward(x) 295 | if self.__padding != 0: 296 | return res[:, :, : -self.__padding] 297 | return res 298 | --------------------------------------------------------------------------------