├── examples ├── __init__.py ├── pixelcnn.py ├── conditional.py ├── gan.py └── cifar.py ├── .gitignore ├── requirements-full.txt ├── logo.png ├── torchelie ├── callbacks │ ├── __init__.py │ └── avg.py ├── loss │ ├── gan │ │ ├── __init__.py │ │ ├── ls.py │ │ ├── hinge.py │ │ └── standard.py │ ├── focal.py │ ├── __init__.py │ ├── deepdreamloss.py │ ├── perceptualloss.py │ ├── functional │ │ └── __init__.py │ ├── face_rec.py │ └── neuralstyleloss.py ├── __init__.py ├── models │ ├── __init__.py │ ├── registry.py │ ├── vit.py │ ├── perceptualnet.py │ ├── convnext.py │ ├── alexnet.py │ ├── pix2pix.py │ ├── vgg.py │ ├── hourglass.py │ ├── autogan.py │ ├── mlpmixer.py │ ├── unet.py │ ├── efficient.py │ ├── patchgan.py │ └── attention.py ├── nn │ ├── imagenetinputnorm.py │ ├── __init__.py │ ├── pixelnorm.py │ ├── noise.py │ ├── functional │ │ ├── transformer.py │ │ ├── vq.py │ │ └── __init__.py │ ├── reshape.py │ ├── condseq.py │ ├── debug.py │ ├── interpolate.py │ ├── withsavedactivations.py │ ├── graph.py │ ├── maskedconv.py │ ├── llm.py │ ├── adain.py │ └── transformer.py ├── serving_utils.py ├── recipes │ ├── __init__.py │ ├── trainandcall.py │ ├── algorithm.py │ ├── trainandtest.py │ ├── gan.py │ └── deepdream.py ├── datasets │ ├── ms1m.py │ ├── concat.py │ └── debug.py └── distributions.py ├── tests ├── style.jpg ├── dream_me.jpg ├── test_sched.py ├── test_utils.py ├── test_hyper.py ├── test_optims.py ├── test_tranforms.py ├── test_datasets.py ├── test_models.py ├── test_datalearning.py ├── test_loss.py ├── test_tensorboard_callback.py ├── test_nn.py └── test_recipes.py ├── docs ├── _static │ ├── vis_example.jpg │ ├── dream_example.jpg │ ├── style_example.png │ └── css │ │ └── custom.css ├── hyper.rst ├── utils.rst ├── data_learning.rst ├── distributions.rst ├── _templates │ └── klass.rst ├── optimizers.rst ├── Makefile ├── make.bat ├── datasets.rst ├── transforms.rst ├── index.rst ├── loss.rst ├── recipes.rst ├── callbacks.rst ├── conf.py └── nn.rst ├── requirements.txt ├── .style.yapf ├── requirements-docs.txt ├── pyproject.toml ├── setup.py ├── .flake8 ├── .readthedocs.yml ├── .github └── workflows │ ├── lint.yml │ └── tests.yml ├── run_tests.sh ├── LICENSE └── scripts └── stylevgg.py /examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.swp 3 | __pycache__ 4 | -------------------------------------------------------------------------------- /requirements-full.txt: -------------------------------------------------------------------------------- 1 | scipy>=1.5 2 | scikit-learn>=0.22 3 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vermeille/Torchelie/HEAD/logo.png -------------------------------------------------------------------------------- /torchelie/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from torchelie.callbacks.callbacks import * 2 | -------------------------------------------------------------------------------- /tests/style.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vermeille/Torchelie/HEAD/tests/style.jpg 
-------------------------------------------------------------------------------- /tests/dream_me.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vermeille/Torchelie/HEAD/tests/dream_me.jpg -------------------------------------------------------------------------------- /docs/_static/vis_example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vermeille/Torchelie/HEAD/docs/_static/vis_example.jpg -------------------------------------------------------------------------------- /docs/hyper.rst: -------------------------------------------------------------------------------- 1 | torchelie.hyper 2 | =============== 3 | 4 | .. automodule:: torchelie.hyper 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/_static/dream_example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vermeille/Torchelie/HEAD/docs/_static/dream_example.jpg -------------------------------------------------------------------------------- /docs/_static/style_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vermeille/Torchelie/HEAD/docs/_static/style_example.png -------------------------------------------------------------------------------- /docs/utils.rst: -------------------------------------------------------------------------------- 1 | torchelie.utils 2 | =============== 3 | 4 | .. automodule:: torchelie.utils 5 | :members: 6 | 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-fid==0.2.0 2 | crayons>=0.4.0 3 | numpy>=1.22.0 4 | torch>=2.0.1 5 | torchvision>=1.13 6 | visdom>=0.1.8 7 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = google 3 | split_before_logical_operator = true 4 | split_before_arithmetic_operator = true 5 | -------------------------------------------------------------------------------- /docs/data_learning.rst: -------------------------------------------------------------------------------- 1 | torchelie.data_learning 2 | ======================= 3 | 4 | .. automodule:: torchelie.data_learning 5 | :members: 6 | 7 | -------------------------------------------------------------------------------- /requirements-docs.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | torchvision==0.9.1 3 | sphinx==3.5.4 4 | visdom==0.1.8.8 5 | crayons==0.2.0 6 | Pillow==10.0.1 7 | sphinx-rtd-theme 8 | -------------------------------------------------------------------------------- /docs/distributions.rst: -------------------------------------------------------------------------------- 1 | torchelie.distributions 2 | ======================= 3 | 4 | .. 
automodule:: torchelie.distributions 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /torchelie/loss/gan/__init__.py: -------------------------------------------------------------------------------- 1 | import torchelie.loss.gan.hinge 2 | import torchelie.loss.gan.standard 3 | import torchelie.loss.gan.penalty 4 | import torchelie.loss.gan.ls 5 | -------------------------------------------------------------------------------- /docs/_templates/klass.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline}} 7 | 8 | .. autoclass:: {{ name }} 9 | :members: 10 | :undoc-members: 11 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "torchelie" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Guillaume Sanchez "] 6 | 7 | [tool.poetry.dependencies] 8 | python = ">3.8" 9 | torch = ">2" 10 | 11 | [tool.poetry.dev-dependencies] 12 | pytest = "^3.0" 13 | 14 | [build-system] 15 | requires = ["setuptools", "poetry>=0.12"] 16 | build-backend = "poetry.masonry.api" 17 | -------------------------------------------------------------------------------- /torchelie/__init__.py: -------------------------------------------------------------------------------- 1 | import torchelie.nn 2 | import torchelie.utils 3 | import torchelie.loss 4 | import torchelie.transforms 5 | import torchelie.models 6 | import torchelie.callbacks 7 | import torchelie.datasets 8 | import torchelie.data_learning 9 | import torchelie.optim 10 | import torchelie.distributions 11 | import torchelie.lr_scheduler 12 | import torchelie.hyper 13 | import torchelie.recipes 14 | 15 | __version__ = '0.1.0' 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='Torchelie', 4 | version='0.1dev', 5 | packages=find_packages(), 6 | classifiers=[ 7 | "License :: OSI Approved :: MIT License", 8 | ], 9 | install_requires=[ 10 | 'visdom>=0.1.8', 11 | 'crayons>=0.2', 12 | 'torchvision>=0.13', 13 | 'torch>=2', 14 | 'numpy>=1.16', 15 | 'Pillow>=6', 16 | ]) 17 | -------------------------------------------------------------------------------- /docs/optimizers.rst: -------------------------------------------------------------------------------- 1 | torchelie.optim 2 | =============== 3 | 4 | .. autoclass:: torchelie.optim.Lookahead 5 | :members: 6 | :undoc-members: 7 | 8 | .. autoclass:: torchelie.optim.RAdamW 9 | :members: 10 | 11 | .. autoclass:: torchelie.optim.DeepDreamOptim 12 | :members: 13 | 14 | .. autoclass:: torchelie.optim.AddSign 15 | :members: 16 | 17 | torchelie.lr_scheduler 18 | ====================== 19 | 20 | .. 
automodule:: torchelie.lr_scheduler 21 | :members: 22 | :undoc-members: 23 | 24 | -------------------------------------------------------------------------------- /tests/test_sched.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import SGD 3 | from torchelie.lr_scheduler import * 4 | 5 | 6 | def test_curriculum(): 7 | x = torch.randn(3, requires_grad=True) 8 | opt = SGD([x], 1) 9 | sched = CurriculumScheduler(opt, [[0, 1, 1], [10, 0, 0]]) 10 | opt.step() 11 | sched.step() 12 | 13 | 14 | def test_onecycle(): 15 | x = torch.randn(3, requires_grad=True) 16 | opt = SGD([x], 1) 17 | sched = OneCycle(opt, (1e-3, 1e-2), 10) 18 | opt.step() 19 | sched.step() 20 | -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Newlines (\a) and spaces (\20) before each parameter */ 2 | .sig-param::before { 3 | content: "\a\20\20\20\20\20\20\20\20\20\20\20\20\20\20\20\20"; 4 | white-space: pre; 5 | } 6 | 7 | /* Newline after the last parameter (so the closing bracket is on a new line) */ 8 | dt em.sig-param:last-of-type::after { 9 | content: "\a"; 10 | white-space: pre; 11 | } 12 | 13 | /* To have blue background of width of the block (instead of width of content) */ 14 | dl.class > dt:first-of-type { 15 | display: block !important; 16 | } 17 | -------------------------------------------------------------------------------- /torchelie/loss/focal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .functional import focal_loss 4 | 5 | 6 | class FocalLoss(nn.Module): 7 | """ 8 | The focal loss 9 | 10 | https://arxiv.org/abs/1708.02002 11 | 12 | See :func:`torchelie.loss.focal_loss` for details. 
13 | """ 14 | def __init__(self, gamma: float = 0): 15 | super(FocalLoss, self).__init__() 16 | self.gamma = gamma 17 | 18 | def forward(self, input: torch.Tensor, 19 | target: torch.Tensor) -> torch.Tensor: 20 | return focal_loss(input, target, self.gamma) 21 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 80 3 | 4 | extend-ignore = 5 | # W503 - because it's outdated and replaced by W504 while conflicting with it 6 | W503, 7 | # F401 - This is temporary, see https://github.com/Vermeille/Torchelie/issues/47 8 | F401, 9 | # E722 - This is temporary, see https://github.com/Vermeille/Torchelie/issues/48 10 | E722, 11 | # F405 - This is temporary, see https://github.com/Vermeille/Torchelie/issues/49 12 | F405, 13 | # F403 - This is temporary, see https://github.com/Vermeille/Torchelie/issues/49 14 | F403, 15 | # E501 - This is temporary, see https://github.com/Vermeille/Torchelie/issues/50 16 | E501, 17 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | sphinx: 10 | configuration: docs/conf.py 11 | 12 | # Build documentation with MkDocs 13 | #mkdocs: 14 | # configuration: mkdocs.yml 15 | 16 | # Optionally build your docs in additional formats such as PDF and ePub 17 | # formats: all 18 | 19 | # Optionally set the version of Python and requirements required to build your docs 20 | python: 21 | version: 3.7 22 | install: 23 | - requirements: ./requirements-docs.txt 24 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Lint Code Base 3 | on: [push] 4 | 5 | jobs: 6 | lint: 7 | name: Lint Code Base 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Checkout Code 11 | uses: actions/checkout@v2 12 | with: 13 | fetch-depth: 0 14 | - name: Lint with GitHub Super Linter 15 | uses: github/super-linter@v3 16 | env: 17 | LINTER_RULES_PATH: . 
# Avoid having to create .github/linters/ dir 18 | VALIDATE_PYTHON_FLAKE8: true # Code lint 19 | # GITHUB_TOKEN allows the GitHub Super Linter to mark the status of 20 | # each individual linter runs in the Checks sections of a pull request 21 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 22 | -------------------------------------------------------------------------------- /torchelie/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .classifier import ClassificationHead, ProjectionDiscr 2 | from .vgg import * 3 | from .resnet import * 4 | from .patchgan import * 5 | from .pixcnn import * 6 | from .perceptualnet import * 7 | from .unet import * 8 | from .pix2pix import * 9 | from .pix2pixhd import * 10 | from .autogan import AutoGAN, autogan_32, autogan_64, autogan_128 11 | from .snres_discr import * 12 | from .hourglass import Hourglass 13 | from .attention import Attention56Bone, attention56 14 | from .stylegan2 import StyleGAN2Generator, StyleGAN2Discriminator 15 | from .efficient import EfficientNet 16 | from .registry import * 17 | from .alexnet import * 18 | from .mlpmixer import * 19 | from .convnext import * 20 | from .vit import ViTTrunk 21 | # from .poolformer import * 22 | -------------------------------------------------------------------------------- /torchelie/nn/imagenetinputnorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ImageNetInputNorm(nn.Module): 6 | """ 7 | Normalize images channels as torchvision models expects, in a 8 | differentiable way 9 | """ 10 | 11 | def __init__(self): 12 | super(ImageNetInputNorm, self).__init__() 13 | self.register_buffer('norm_mean', 14 | torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)) 15 | 16 | self.register_buffer('norm_std', 17 | torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)) 18 | 19 | def forward(self, input): 20 | return (input - self.norm_mean) / self.norm_std 21 | 22 | def inverse(self, input): 23 | return input * self.norm_std + self.norm_mean 24 | -------------------------------------------------------------------------------- /torchelie/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .reshape import Reshape, Lambda, Permute 2 | from .conv import * 3 | from .debug import Debug, Dummy 4 | from .noise import Noise 5 | from .vq import VQ, MultiVQ, RVQ 6 | from .imagenetinputnorm import ImageNetInputNorm 7 | from .withsavedactivations import WithSavedActivations 8 | from .maskedconv import MaskedConv2d, TopLeftConv2d 9 | from .batchnorm import * 10 | from .adain import AdaIN2d, FiLM2d, FiLM 11 | from .pixelnorm import PixelNorm, ChannelNorm 12 | from .blocks import * 13 | from .layers import * 14 | from .condseq import CondSeq 15 | from .graph import ModuleGraph 16 | from .encdec import * 17 | from .interpolate import * 18 | from .resblock import * 19 | import torchelie.nn.utils 20 | from .transformer import LocalSelfAttention2d 21 | from .llm import * 22 | -------------------------------------------------------------------------------- /run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | if false; then 6 | RUN="--device cuda --iters 1000 --visdom-env torch-test" 7 | else 8 | RUN="--device cpu --iters 1" 9 | fi 10 | 11 | pytest tests 12 | python3 -m torchelie.recipes.deepdream \ 13 | --input tests/dream_me.jpg \ 14 | --out 
tests/dreamed.png \ 15 | $RUN 16 | 17 | python3 -m torchelie.recipes.neural_style \ 18 | --content tests/dream_me.jpg \ 19 | --style tests/style.jpg \ 20 | --out tests/styled.png \ 21 | --size 512 \ 22 | --ratio 300 \ 23 | $RUN 24 | 25 | python3 -m torchelie.recipes.feature_vis \ 26 | --model resnet \ 27 | --layer layer4 \ 28 | --neuron 0 \ 29 | $RUN 30 | 31 | python3 examples/mnist.py 32 | python3 examples/conditional.py 33 | python3 examples/gan.py 34 | python3 examples/pixelcnn.py 35 | 36 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchelie.utils import * 4 | 5 | 6 | def test_net(): 7 | m = torch.nn.Linear(10, 4) 8 | freeze(m) 9 | unfreeze(m) 10 | kaiming(m) 11 | xavier(m) 12 | normal_init(m, 0.02) 13 | nb_parameters(m) 14 | assert m is layer_by_name(torch.nn.Sequential(m), '0') 15 | assert layer_by_name(torch.nn.Sequential(m), 'test') is None 16 | send_to_device([{'a': [m]}], 'cpu') 17 | 18 | fm = FrozenModule(m) 19 | fm.train() 20 | assert not fm.weight.requires_grad 21 | fm.weight 22 | 23 | fm = DetachedModule(m) 24 | fm.weight 25 | 26 | 27 | def test_utils(): 28 | entropy(torch.randn(1, 10)) 29 | gram(torch.randn(4, 10)) 30 | bgram(torch.randn(3, 4, 10)) 31 | assert dict_by_key({'a': [{'b': 42}]}, 'a.0.b') == 42 32 | assert lerp(0, 2, 0.5) == 1 33 | assert ilerp(0, 2, 1) == 0.5 34 | -------------------------------------------------------------------------------- /tests/test_hyper.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import suppress 3 | 4 | from torchelie.hyper import HyperparamSearch, UniformSampler 5 | 6 | 7 | def beale(x, y): 8 | return (1.5 - x + x * y)**2 + (2.25 - x + x * y**2)**2 + (2.625 - x + 9 | x * y**3)**2 10 | 11 | 12 | def sphere(x, y): 13 | return x**2 + y**2 14 | 15 | 16 | def rosen(x, y): 17 | return 100 * (y - x**2)**2 + (1 - x)**2 18 | 19 | 20 | hpsearch = HyperparamSearch(x=UniformSampler(-4.5, 4.5), 21 | y=UniformSampler(-4.5, 4.5)) 22 | 23 | with suppress(FileNotFoundError): 24 | os.remove('hpsearch.json') 25 | 26 | print(beale(3, 0.5)) 27 | for _ in range(30): 28 | hps = hpsearch.sample(algorithm='gp', target='out') 29 | out = -beale(**hps.params) 30 | print(hps, '\t', out) 31 | hpsearch.log_result(hps, {'out': out}) 32 | -------------------------------------------------------------------------------- /tests/test_optims.py: -------------------------------------------------------------------------------- 1 | from torchelie.optim import * 2 | 3 | 4 | def test_deepdream(): 5 | a = torch.randn(5, requires_grad=True) 6 | opt = DeepDreamOptim([a]) 7 | a.mean().backward() 8 | opt.step() 9 | 10 | 11 | def test_addsign(): 12 | a = torch.randn(5, requires_grad=True) 13 | opt = AddSign([a]) 14 | a.mean().backward() 15 | opt.step() 16 | 17 | 18 | def test_radamw(): 19 | a = torch.randn(5, requires_grad=True) 20 | opt = RAdamW([a]) 21 | a.mean().backward() 22 | opt.step() 23 | 24 | 25 | def test_adabelief(): 26 | a = torch.randn(5, requires_grad=True) 27 | opt = AdaBelief([a]) 28 | a.mean().backward() 29 | opt.step() 30 | 31 | 32 | def test_lookahead(): 33 | a = torch.randn(5, requires_grad=True) 34 | b = torch.randn(5, requires_grad=True) 35 | opt = Lookahead(AdaBelief([a, b])) 36 | for _ in range(10): 37 | a.mean().backward() 38 | opt.step() 39 | 
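# --- Illustrative sketch (editor's addition, not a test shipped with the
# repository). It mirrors the patterns above and only uses torchelie calls
# already shown in this repo (Lookahead wrapping a base optimizer, the `lr`
# keyword, `.step()`); the toy quadratic objective and the loose threshold in
# the assert are assumptions made purely for illustration.
def _example_lookahead_radamw():
    x = torch.zeros(5, requires_grad=True)
    target = torch.ones(5)
    opt = Lookahead(RAdamW([x], lr=1e-1))
    for _ in range(100):
        loss = (x - target).pow(2).mean()
        loss.backward()
        opt.step()
        x.grad = None  # reset gradients between steps
    assert loss.item() < 1.0  # loss should have decreased from its initial 1.0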
-------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /torchelie/serving_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | SUPER ALPHA 3 | """ 4 | 5 | import random 6 | import os 7 | 8 | from typing import List 9 | import PIL 10 | 11 | 12 | class ImageLogger: 13 | """ 14 | Logs images and decisions up to a capacity per class 15 | """ 16 | def __init__(self, root: str, classes: List[str], capacity: int = 200): 17 | self.root = root 18 | self.classes = classes 19 | self.capacity = capacity 20 | for kls in classes: 21 | os.makedirs(root + '/' + kls, exist_ok=True) 22 | 23 | def __call__(self, images, klass: List[int]) -> None: 24 | for i, k in zip(images, klass): 25 | k_name = self.classes[k] 26 | n = random.randint(-int(self.capacity * 0.25), self.capacity) 27 | if n < 0: 28 | continue 29 | path = f'{self.root}/{k_name}/{n}.jpg' 30 | if isinstance(i, PIL.Image.Image): 31 | i.save(path) 32 | -------------------------------------------------------------------------------- /torchelie/recipes/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Recipes are a way to provide off the shelf algorithms that ce be either used 3 | directly from the command line or easily imported in a python script if more 4 | flexibility is needed. 5 | 6 | A recipe should, as much a possible, be agnostic of the data and the underlying 7 | model so that it can be used as a way to quickly try an algorithm on new data 8 | or be easily experimented on by changing the model 9 | """ 10 | from torchelie.recipes.recipebase import Recipe 11 | from torchelie.recipes.classification import Classification 12 | from torchelie.recipes.classification import CrossEntropyClassification 13 | from torchelie.recipes.classification import MixupClassification 14 | from torchelie.recipes.deepdream import DeepDream 15 | from torchelie.recipes.feature_vis import FeatureVis 16 | from torchelie.recipes.neural_style import NeuralStyle 17 | from torchelie.recipes.trainandtest import TrainAndTest 18 | from torchelie.recipes.trainandcall import TrainAndCall 19 | from torchelie.recipes.gan import GANRecipe 20 | -------------------------------------------------------------------------------- /docs/datasets.rst: -------------------------------------------------------------------------------- 1 | torchelie.datasets 2 | ================== 3 | 4 | .. 
currentmodule:: torchelie.datasets 5 | 6 | Debug datasets 7 | ~~~~~~~~~~~~~~ 8 | 9 | .. autosummary:: 10 | :toctree: generated 11 | :template: klass.rst 12 | :nosignatures: 13 | 14 | ColoredColumns 15 | ColoredRows 16 | MS1M 17 | Pix2PixDataset 18 | Imagenette 19 | Imagewoof 20 | 21 | Loaders 22 | ~~~~~~~ 23 | 24 | .. autosummary:: 25 | :toctree: generated 26 | :template: klass.rst 27 | :nosignatures: 28 | 29 | FastImageFolder 30 | UnlabeledImages 31 | ImagesPaths 32 | SideBySideImagePairsDataset 33 | 34 | Datasets wrappers 35 | ~~~~~~~~~~~~~~~~~ 36 | 37 | .. autosummary:: 38 | :toctree: generated 39 | :template: klass.rst 40 | :nosignatures: 41 | 42 | PairedDataset 43 | RandomPairsDataset 44 | MixUpDataset 45 | Subset 46 | NoexceptDataset 47 | WithIndexDataset 48 | CachedDataset 49 | HorizontalConcatDataset 50 | MergedDataset 51 | 52 | Functions 53 | ~~~~~~~~~ 54 | 55 | .. autofunction:: torchelie.datasets.mixup 56 | -------------------------------------------------------------------------------- /docs/transforms.rst: -------------------------------------------------------------------------------- 1 | torchelie.transforms 2 | ==================== 3 | 4 | .. autoclass:: torchelie.transforms.ResizeNoCrop 5 | :members: 6 | 7 | .. autoclass:: torchelie.transforms.ResizedCrop 8 | :members: 9 | 10 | .. autoclass:: torchelie.transforms.AdaptPad 11 | :members: 12 | 13 | .. autoclass:: torchelie.transforms.MultiBranch 14 | :members: 15 | 16 | .. autoclass:: torchelie.transforms.Canny 17 | :members: 18 | 19 | .. autoclass:: torchelie.transforms.RandAugment 20 | :members: 21 | 22 | .. autoclass:: torchelie.transforms.Posterize 23 | :members: 24 | 25 | .. autoclass:: torchelie.transforms.Solarize 26 | :members: 27 | 28 | .. autoclass:: torchelie.transforms.Cutout 29 | :members: 30 | 31 | .. autoclass:: torchelie.transforms.Identity 32 | :members: 33 | 34 | .. autoclass:: torchelie.transforms.Subsample 35 | :members: 36 | 37 | .. autoclass:: torchelie.transforms.JPEGArtifacts 38 | :members: 39 | 40 | 41 | 42 | torchelie.transforms.differentiable 43 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 44 | 45 | .. automodule:: torchelie.transforms.differentiable 46 | :members: 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Guillaume "Vermeille" Sanchez 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /torchelie/nn/pixelnorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class PixelNorm(torch.nn.Module): 5 | """ 6 | PixelNorm from ProgressiveGAN 7 | """ 8 | 9 | def forward(self, x): 10 | return x / (x.pow(2).mean(dim=1, keepdim=True).sqrt() + 1e-8) 11 | 12 | 13 | class ChannelNorm(torch.nn.Module): 14 | 15 | def __init__(self, dim=1, affine=True, channels=1): 16 | super().__init__() 17 | self.affine = affine 18 | if affine: 19 | self.weight = torch.nn.Parameter(torch.ones(channels)) 20 | #self.bias = torch.nn.Parameter(torch.zeros(channels)) 21 | 22 | if isinstance(dim, int): 23 | dim = [dim] 24 | self.dim = dim 25 | 26 | if isinstance(channels, int): 27 | channels = [channels] 28 | self.channels = channels 29 | 30 | def forward(self, x): 31 | var = x.var(dim=self.dim, keepdim=True, unbiased=False) 32 | if not self.affine: 33 | return (x) * var.rsqrt() 34 | expand = [(self.channels[self.dim.index(i)] if i in self.dim else 1) 35 | for i in range(x.dim())] 36 | 37 | w = self.weight.view(expand) 38 | return x * var.rsqrt() * w 39 | -------------------------------------------------------------------------------- /torchelie/nn/noise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Optional 4 | 5 | 6 | class Noise(nn.Module): 7 | """ 8 | Add gaussian noise to the input, with a per channel or global learnable std. 9 | 10 | Args: 11 | ch (int): number of input channels for a different std on each channel, 12 | or 1 13 | """ 14 | 15 | def __init__(self, ch: int, inplace: bool = False, bias: bool = False): 16 | super(Noise, self).__init__() 17 | self.a = nn.Parameter(torch.zeros(ch, 1, 1)) 18 | self.inplace = inplace 19 | self.bias = nn.Parameter(torch.zeros_like(self.a)) if bias else None 20 | 21 | def forward(self, 22 | x: torch.Tensor, 23 | z: Optional[torch.Tensor] = None) -> torch.Tensor: 24 | N, C, H, W = x.shape 25 | if z is None: 26 | z = torch.randn(N, 1, H, W, device=x.device, dtype=x.dtype) 27 | else: 28 | assert z.shape == [N, 1, H, W] 29 | z = z * self.a 30 | if self.bias is not None: 31 | z.add_(self.bias) 32 | 33 | if self.inplace: 34 | return x.add_(z) 35 | else: 36 | return x + z 37 | -------------------------------------------------------------------------------- /torchelie/nn/functional/transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch import Tensor 5 | 6 | import torch.nn.functional as F 7 | 8 | 9 | def local_attention_2d(x: Tensor, conv_kqv: nn.Conv2d, posenc: Tensor, 10 | num_heads: int, patch_size: int) -> Tensor: 11 | B, inC, fullH, fullW = x.shape 12 | N = num_heads 13 | P = patch_size 14 | H, W = fullH // P, fullW // P 15 | 16 | x = x.view(B, inC, H, P, W, P).permute(0, 2, 4, 1, 3, 17 | 5).reshape(B * H * W, inC, P, P) 18 | k, q, v = torch.chunk(F.conv2d(x, conv_kqv.weight / math.sqrt(inC // N), 19 | conv_kqv.bias), 20 | 3, 21 | dim=1) 22 | hidC = k.shape[1] // N 23 | k = k.view(B, H * W, N, hidC, P * P) 24 | q = q.view(B, H * W, N, hidC, P * P) 25 | 26 | kq = torch.softmax(torch.matmul(q.transpose(-1, -2), k) + posenc, dim=-1) 27 | v = v.view(B, H * W, N, hidC, P * P) 28 | kqv = torch.matmul(v, kq.transpose(-1, -2)).view(B, H, W, N, hidC, P, P) 29 | kqv = kqv.permute(0, 3, 4, 1, 5, 2, 6).reshape(B, 
hidC * N, fullH, fullW) 30 | return kqv, kq 31 | -------------------------------------------------------------------------------- /torchelie/nn/reshape.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Lambda(nn.Module): 5 | """ 6 | Applies a lambda function on forward() 7 | 8 | Args: 9 | lamb (fn): the lambda function 10 | """ 11 | lam: nn.Module 12 | 13 | def __init__(self, lam: nn.Module): 14 | super(Lambda, self).__init__() 15 | self.lam = lam 16 | 17 | def forward(self, x): 18 | return self.lam(x) 19 | 20 | 21 | class Reshape(nn.Module): 22 | """ 23 | Reshape the input volume 24 | 25 | Args: 26 | *shape (ints): new shape, WITHOUT specifying batch size as first 27 | dimension, as it will remain unchanged. 28 | """ 29 | 30 | def __init__(self, *shape): 31 | super(Reshape, self).__init__() 32 | self.shape = shape 33 | 34 | def forward(self, x): 35 | return x.view(x.shape[0], *self.shape) 36 | 37 | 38 | class Permute(nn.Module): 39 | """ 40 | Permute the dimensions of the input tensor 41 | 42 | Args: 43 | *dims (ints): new permutation of the dimensions 44 | """ 45 | 46 | def __init__(self, *dims): 47 | super(Permute, self).__init__() 48 | self.dims = dims 49 | 50 | def forward(self, x): 51 | return x.permute(*self.dims) 52 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: "Torchélie tests" 2 | 3 | on: [push] 4 | 5 | jobs: 6 | install-dependencies: 7 | name: "Dependencies + Tests" 8 | runs-on: ubuntu-20.04 9 | steps: 10 | - name: "Checkout code" 11 | uses: actions/checkout@v2 12 | - name: "Cache pip" 13 | uses: actions/cache@v2 14 | with: 15 | # This path is specific to Ubuntu 16 | path: ~/.cache/pip 17 | # Look to see if there is a cache hit for the corresponding requirements file 18 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} 19 | restore-keys: | 20 | ${{ runner.os }}-pip- 21 | ${{ runner.os }}- 22 | - name: "Set up Python" 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: 3.9 26 | - name: "Install Python tooling" 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install pytest 30 | - name: "Install dependencies" 31 | run: | 32 | pip install -r requirements.txt 33 | pip install -r requirements-full.txt 34 | - name: "Run tests with pytest (without OpenCV)" 35 | run: | 36 | python -m pytest -m "not require_opencv and not require_tensorboard" 37 | -------------------------------------------------------------------------------- /torchelie/nn/condseq.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import Any, Optional, Callable, cast 3 | 4 | 5 | class CondSeq(nn.Sequential): 6 | """ 7 | An extension to torch's Sequential that allows conditioning either as a 8 | second forward argument or `condition()` 9 | """ 10 | 11 | def condition(self, z: Any) -> None: 12 | """ 13 | Conditions all the layers on z 14 | 15 | Args: 16 | z: conditioning 17 | """ 18 | for m in self: 19 | if hasattr(m, 'condition') and m is not self: 20 | cast(Callable, m.condition)(z) 21 | 22 | def forward(self, x: Any, z: Optional[Any] = None) -> Any: 23 | """ 24 | Forward pass 25 | 26 | Args: 27 | x: input 28 | z (optional): conditioning. 
condition() must be called first if 29 | left None 30 | """ 31 | for nm, m in self.named_children(): 32 | try: 33 | if hasattr(m, 'condition') and z is not None: 34 | x = m(x, z) 35 | else: 36 | x = m(x) 37 | except Exception as e: 38 | raise Exception( 39 | f'Exception during forward pass of {nm}') from e 40 | return x 41 | -------------------------------------------------------------------------------- /torchelie/loss/gan/ls.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Least Square GAN 3 | """ 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def real(x: torch.Tensor, reduce: str = 'mean') -> torch.Tensor: 10 | out = F.mse_loss(x, torch.ones_like(x), reduce='none') 11 | if reduce == 'none': 12 | return out 13 | if reduce == 'mean': 14 | return out.mean() 15 | if reduce == 'batch_mean': 16 | return out.view(out.shape[0], -1).sum(1).mean() 17 | assert False, f'reduction {reduce} invalid' 18 | 19 | 20 | def fake(x: torch.Tensor, reduce: str = 'mean') -> torch.Tensor: 21 | out = F.mse_loss(x, -torch.ones_like(x), reduce='none') 22 | if reduce == 'none': 23 | return out 24 | if reduce == 'mean': 25 | return out.mean() 26 | if reduce == 'batch_mean': 27 | return out.view(out.shape[0], -1).sum(1).mean() 28 | assert False, f'reduction {reduce} invalid' 29 | 30 | 31 | def generated(x: torch.Tensor, reduce: str = 'mean') -> torch.Tensor: 32 | out = F.mse_loss(x, torch.ones_like(x), reduce='none') 33 | if reduce == 'none': 34 | return out 35 | if reduce == 'mean': 36 | return out.mean() 37 | if reduce == 'batch_mean': 38 | return out.view(out.shape[0], -1).sum(1).mean() 39 | assert False, f'reduction {reduce} invalid' 40 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to Torchélie's documentation! 2 | ===================================== 3 | 4 | .. toctree:: 5 | :caption: Pytorch Utils 6 | :maxdepth: 1 7 | 8 | nn 9 | optimizers 10 | utils 11 | data_learning 12 | loss 13 | models 14 | distributions 15 | datasets 16 | transforms 17 | recipes 18 | 19 | .. toctree:: 20 | :caption: Algorithms and training 21 | :maxdepth: 1 22 | 23 | recipe_tuto 24 | callbacks 25 | hyper 26 | 27 | Torchélie aims to provide new utilities, layers, tools to pytorch users, 28 | as an extension library. It aims for minimal codependence between components, 29 | so that you can use only what you need with a strong pytorch flavour, without 30 | learning a new tool. 31 | 32 | It provides: 33 | 34 | - layers 35 | - models (untrained) 36 | - new optimizers and schedulers 37 | - utility functions 38 | - datasets and dataset utilities 39 | - losses 40 | - and some more specific things 41 | 42 | Torchélie also tries to meet the likes of Ignite and Pytorch-Lightning by 43 | providing training loops with automatic logging, averaging, checkpointing, and 44 | visualisation. Unlike those however, instead of providing a "one size fits all" 45 | training loop, Torchélie aims to make writing them easy. 
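For instance, a training loop built from the ``torchelie.recipes`` helpers
looks roughly like the sketch below (adapted from ``examples/pixelcnn.py`` in
this repository; the ``model``, ``train_step``, ``after_train``,
``train_loader`` and ``opt`` objects are assumed to be defined by the user)::

    from torchelie.recipes import TrainAndCall
    import torchelie.callbacks as tcb

    trainer = TrainAndCall(model, train_step, after_train, train_loader,
                           test_every=500, visdom_env='my_experiment')
    trainer.callbacks.add_callbacks([
        tcb.WindowedMetricAvg('loss'),
        tcb.Optimizer(opt, log_lr=True),
    ])
    trainer.to('cuda')
    trainer.run(10)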
46 | 47 | * :ref:`modindex` 48 | * :ref:`search` 49 | 50 | 51 | -------------------------------------------------------------------------------- /torchelie/nn/debug.py: -------------------------------------------------------------------------------- 1 | import crayons 2 | 3 | import torch.nn as nn 4 | from torchelie.utils import experimental 5 | 6 | 7 | class Debug(nn.Module): 8 | """ 9 | An pass-through layer that prints some debug info during forward pass. 10 | It prints its name, the input's shape, mean of channels means, mean, 11 | mean of channels std, and std. 12 | 13 | Args: 14 | name (str): this layer's name 15 | """ 16 | @experimental 17 | def __init__(self, name): 18 | super(Debug, self).__init__() 19 | self.name = name 20 | 21 | def forward(self, x): 22 | print(crayons.yellow(self.name)) 23 | print(crayons.yellow('----')) 24 | print('Shape {}'.format(x.shape)) 25 | if x.ndim == 2: 26 | print("Stats mean {:.2f} var {:.2f}".format(x.mean().item(), 27 | x.std().item())) 28 | if x.ndim == 4: 29 | print("Stats mean {:.2f} {:.2f} var s{:.2f} {:.2f}".format( 30 | x.mean(dim=[0, 2, 3]).mean().item(), 31 | x.mean().item(), 32 | x.std(dim=[0, 2, 3]).mean().item(), 33 | x.std().item())) 34 | print() 35 | return x 36 | 37 | 38 | class Dummy(nn.Module): 39 | """ 40 | A pure pass-through layer 41 | """ 42 | def forward(self, x): 43 | return x 44 | -------------------------------------------------------------------------------- /torchelie/loss/gan/hinge.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Hinge loss from Spectral Normalization GAN. 3 | 4 | https://arxiv.org/abs/1802.05957 5 | 6 | :math:`L_D(x_r, x_f) = \text{max}(0, 1 - D(x_r)) + \text{max}(0, 1 + D(x_f))` 7 | 8 | :math:`L_G(x_f) = -D(x_f)` 9 | """ 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | 15 | def real(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor: 16 | out = F.relu(1 - x) 17 | if reduction == 'none': 18 | return out 19 | if reduction == 'mean': 20 | return out.mean() 21 | if reduction == 'sum': 22 | return out.sum() 23 | assert False, f'{reduction} is not a valid reduction method' 24 | 25 | 26 | def fake(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor: 27 | out = F.relu(1 + x) 28 | if reduction == 'none': 29 | return out 30 | if reduction == 'mean': 31 | return out.mean() 32 | if reduction == 'sum': 33 | return out.sum() 34 | assert False, f'{reduction} is not a valid reduction method' 35 | 36 | 37 | def generated(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor: 38 | out = -x 39 | if reduction == 'none': 40 | return out 41 | if reduction == 'mean': 42 | return out.mean() 43 | if reduction == 'sum': 44 | return out.sum() 45 | assert False, f'{reduction} is not a valid reduction method' 46 | -------------------------------------------------------------------------------- /docs/loss.rst: -------------------------------------------------------------------------------- 1 | torchelie.loss 2 | ============== 3 | 4 | Functions 5 | ~~~~~~~~~ 6 | 7 | .. autofunction:: torchelie.loss.tempered_cross_entropy 8 | .. autofunction:: torchelie.loss.tempered_nll_loss 9 | .. autofunction:: torchelie.loss.tempered_softmax 10 | .. autofunction:: torchelie.loss.tempered_log_softmax 11 | .. autofunction:: torchelie.loss.ortho 12 | .. autofunction:: torchelie.loss.total_variation 13 | .. autofunction:: torchelie.loss.continuous_cross_entropy 14 | .. autofunction:: torchelie.loss.focal_loss 15 | 16 | Modules 17 | ~~~~~~~ 18 | 19 | .. 
autoclass:: torchelie.loss.TemperedCrossEntropyLoss 20 | :members: 21 | :undoc-members: 22 | .. autoclass:: torchelie.loss.OrthoLoss 23 | :members: 24 | :undoc-members: 25 | .. autoclass:: torchelie.loss.TotalVariationLoss 26 | :members: 27 | :undoc-members: 28 | .. autoclass:: torchelie.loss.ContinuousCEWithLogits 29 | :members: 30 | :undoc-members: 31 | .. autoclass:: torchelie.loss.FocalLoss 32 | :members: 33 | :undoc-members: 34 | .. autoclass:: torchelie.loss.PerceptualLoss 35 | :members: 36 | :undoc-members: 37 | .. autoclass:: torchelie.loss.NeuralStyleLoss 38 | :members: 39 | :undoc-members: 40 | .. autoclass:: torchelie.loss.DeepDreamLoss 41 | :members: 42 | :undoc-members: 43 | 44 | GAN losses 45 | ~~~~~~~~~~ 46 | 47 | .. automodule:: torchelie.loss.gan.hinge 48 | :members: 49 | :undoc-members: 50 | 51 | .. automodule:: torchelie.loss.gan.standard 52 | :members: 53 | :undoc-members: 54 | -------------------------------------------------------------------------------- /torchelie/loss/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import torchelie.loss.gan 5 | from .functional import ortho, total_variation, continuous_cross_entropy 6 | from .functional import focal_loss 7 | 8 | from .perceptualloss import PerceptualLoss 9 | from .neuralstyleloss import NeuralStyleLoss 10 | from .deepdreamloss import DeepDreamLoss 11 | from .focal import FocalLoss 12 | from .bitempered import tempered_cross_entropy, TemperedCrossEntropyLoss 13 | from .bitempered import tempered_softmax, tempered_nll_loss 14 | from .bitempered import tempered_log_softmax 15 | 16 | class OrthoLoss(nn.Module): 17 | """ 18 | Orthogonal loss 19 | 20 | See :func:`torchelie.loss.ortho` for details. 21 | """ 22 | def forward(self, w): 23 | return ortho(w) 24 | 25 | 26 | class TotalVariationLoss(nn.Module): 27 | """ 28 | Total Variation loss 29 | 30 | See :func:`torchelie.loss.total_variation` for details. 31 | """ 32 | def forward(self, x): 33 | return total_variation(x) 34 | 35 | 36 | class ContinuousCEWithLogits(nn.Module): 37 | """ 38 | Cross Entropy loss accepting continuous target values 39 | 40 | See :func:`torchelie.loss.continuous_cross_entropy` for details. 
41 | """ 42 | def forward(self, pred, soft_targets): 43 | return continuous_cross_entropy(pred, soft_targets) 44 | 45 | 46 | def binary_hinge(x, y): 47 | p = y * torch.clamp(1 - x, min=0) + (1 - y) * torch.clamp(1 + x, min=0) 48 | return p.mean() 49 | -------------------------------------------------------------------------------- /tests/test_tranforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | from torchvision.transforms import ToPILImage 4 | from torchelie.transforms import * 5 | import torchelie.transforms.differentiable as dtf 6 | 7 | 8 | def test_resizenocrop(): 9 | img = ToPILImage()(torch.clamp(torch.randn(3, 32, 16) + 1, min=0, max=1)) 10 | tf = ResizeNoCrop(16) 11 | assert tf(img).width == 8 12 | assert tf(img).height == 16 13 | 14 | 15 | def test_adaptpad(): 16 | img = ToPILImage()(torch.clamp(torch.randn(3, 30, 16) + 1, min=0, max=1)) 17 | tf = AdaptPad((32, 32)) 18 | assert tf(img).width == 32 19 | assert tf(img).height == 32 20 | 21 | 22 | def test_multibranch(): 23 | tf = MultiBranch([ 24 | lambda x: x + 1, 25 | lambda x: x * 3, 26 | lambda x: x - 1, 27 | ]) 28 | assert tf(1) == (2, 3, 0) 29 | 30 | 31 | @pytest.mark.require_opencv 32 | def test_canny(): 33 | img = ToPILImage()(torch.clamp(torch.randn(3, 30, 16) + 1, min=0, max=1)) 34 | tf = Canny() 35 | tf(img) 36 | 37 | 38 | def test_resizedcrop(): 39 | img = ToPILImage()(torch.clamp(torch.randn(3, 30, 16) + 1, min=0, max=1)) 40 | tf = ResizedCrop(48) 41 | tf(img) 42 | 43 | 44 | def test_diff(): 45 | dtf.roll(torch.randn(3, 16, 16), 3, 3) 46 | dtf.roll(torch.randn(1, 3, 16, 16), 3, 3) 47 | dtf.center_crop(torch.randn(3, 16, 16), (4, 4)) 48 | dtf.crop(torch.randn(1, 3, 16, 16)) 49 | dtf.gblur(torch.randn(1, 3, 16, 16)) 50 | dtf.mblur(torch.randn(1, 3, 16, 16)) 51 | -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | from torchelie.datasets import * 2 | 3 | 4 | class StupidDataset: 5 | def __init__(self): 6 | self.classes = [0, 1] 7 | self.imgs = [(i, 0 if i < 5 else 1) for i in range(10)] 8 | self.samples = self.imgs 9 | 10 | def __len__(self): 11 | return 10 12 | 13 | def __getitem__(self, i): 14 | return torch.FloatTensor([i]), torch.LongTensor([self.imgs[i][1]]) 15 | 16 | 17 | def test_colored(): 18 | cc = ColoredColumns(64, 64) 19 | cr = ColoredRows(64, 64) 20 | 21 | cc[0] 22 | cr[0] 23 | 24 | 25 | def test_paired(): 26 | ds = StupidDataset() 27 | ds2 = StupidDataset() 28 | 29 | pd = PairedDataset(ds, ds2) 30 | assert len(pd) == 100 31 | assert pd[0] == list(zip(ds[0], ds2[0])) 32 | 33 | 34 | def test_cat(): 35 | ds = StupidDataset() 36 | ds2 = StupidDataset() 37 | 38 | cated = HorizontalConcatDataset([ds, ds2]) 39 | assert len(cated) == len(ds) * 2 40 | print([cated[i][1] for i in range(20)]) 41 | 42 | for i in range(len(ds)): 43 | assert cated[i][1] == (0 if i < 5 else 1) 44 | assert cated[len(ds) + i][1] == (2 if i < 5 else 3) 45 | 46 | 47 | def test_mixup(): 48 | ds = StupidDataset() 49 | md = MixUpDataset(ds) 50 | md[0] 51 | 52 | 53 | def test_cached(): 54 | ds = StupidDataset() 55 | md = CachedDataset(ds) 56 | md[0] 57 | 58 | 59 | def test_withindex(): 60 | ds = StupidDataset() 61 | md = WithIndexDataset(ds) 62 | md[0] 63 | -------------------------------------------------------------------------------- /docs/recipes.rst: 
-------------------------------------------------------------------------------- 1 | torchelie.recipes 2 | ================= 3 | 4 | .. automodule:: torchelie.recipes 5 | 6 | .. autoclass:: torchelie.recipes.recipebase.Recipe 7 | :members: 8 | :inherited-members: 9 | 10 | .. autofunction:: torchelie.recipes.TrainAndCall 11 | 12 | .. autofunction:: torchelie.recipes.TrainAndTest 13 | 14 | .. autofunction:: torchelie.recipes.Classification 15 | 16 | .. autofunction:: torchelie.recipes.CrossEntropyClassification 17 | 18 | .. autofunction:: torchelie.recipes.gan.GANRecipe 19 | 20 | Model Training 21 | ~~~~~~~~~~~~~~ 22 | 23 | .. autofunction:: torchelie.recipes.trainandtest.TrainAndTest 24 | 25 | .. autofunction:: torchelie.recipes.classification.Classification 26 | 27 | .. autofunction:: torchelie.recipes.classification.CrossEntropyClassification 28 | 29 | Deep Dream 30 | ~~~~~~~~~~ 31 | 32 | .. image:: _static/dream_example.jpg 33 | 34 | .. automodule:: torchelie.recipes.deepdream 35 | 36 | .. autoclass:: torchelie.recipes.deepdream.DeepDream 37 | :members: 38 | :special-members: __call__ 39 | 40 | Feature visualization 41 | ~~~~~~~~~~~~~~~~~~~~~ 42 | 43 | .. image:: _static/vis_example.jpg 44 | 45 | .. automodule:: torchelie.recipes.feature_vis 46 | 47 | .. autoclass:: torchelie.recipes.feature_vis.FeatureVis 48 | :members: 49 | :special-members: __call__ 50 | 51 | Neural Style 52 | ~~~~~~~~~~~~ 53 | 54 | .. image:: _static/style_example.png 55 | 56 | .. automodule:: torchelie.recipes.neural_style 57 | 58 | .. autoclass:: torchelie.recipes.neural_style.NeuralStyle 59 | :members: 60 | :special-members: __call__ 61 | 62 | Deep Image Prior 63 | ~~~~~~~~~~~~~~~~ 64 | 65 | .. automodule:: torchelie.recipes.image_prior 66 | -------------------------------------------------------------------------------- /torchelie/loss/deepdreamloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | import torchelie.nn as tnn 6 | 7 | 8 | class DeepDreamLoss(nn.Module): 9 | """ 10 | The Deep Dream loss 11 | 12 | Args: 13 | model (nn.Module): a pretrained network on which to compute the 14 | activations 15 | dream_layer (str): the name of the layer on which the activations are 16 | to be maximized 17 | max_reduction (int): the maximum factor of reduction of the image, for 18 | multiscale generation 19 | """ 20 | def __init__(self, 21 | model: nn.Module, 22 | dream_layer: str, 23 | max_reduction: int = 3) -> None: 24 | super(DeepDreamLoss, self).__init__() 25 | self.dream_layer = dream_layer 26 | self.octaves = max_reduction 27 | model = model.eval() 28 | self.net = tnn.WithSavedActivations(model, names=[self.dream_layer]) 29 | self.i = 0 30 | 31 | def get_acts_(self, img: torch.Tensor, detach: bool) -> torch.Tensor: 32 | octave = (self.i % (self.octaves * 2)) / 2 + 1 33 | this_sz_img = F.interpolate(img, scale_factor=1 / octave) 34 | _, activations = self.net(this_sz_img, detach=detach) 35 | return activations[self.dream_layer] 36 | 37 | def forward(self, input_img: torch.Tensor) -> torch.Tensor: 38 | """ 39 | Compute the Deep Dream loss on `input_img` 40 | """ 41 | dream = self.get_acts_(input_img, detach=False) 42 | self.i += 1 43 | 44 | dream_loss = -dream.pow(2).sum() 45 | 46 | return dream_loss 47 | -------------------------------------------------------------------------------- /docs/callbacks.rst: -------------------------------------------------------------------------------- 1 | 
torchelie.callbacks 2 | =================== 3 | 4 | Loggers 5 | ------- 6 | 7 | .. autoclass:: torchelie.callbacks.StdoutLogger 8 | :members: 9 | 10 | .. autoclass:: torchelie.callbacks.VisdomLogger 11 | :members: 12 | 13 | Logging 14 | ------- 15 | 16 | .. autoclass:: torchelie.callbacks.TopkAccAvg 17 | :members: 18 | 19 | .. autoclass:: torchelie.callbacks.AccAvg 20 | :members: 21 | 22 | .. autoclass:: torchelie.callbacks.EpochMetricAvg 23 | :members: 24 | 25 | .. autoclass:: torchelie.callbacks.WindowedMetricAvg 26 | :members: 27 | 28 | .. autoclass:: torchelie.callbacks.Log 29 | :members: 30 | 31 | .. autoclass:: torchelie.callbacks.Counter 32 | :members: 33 | 34 | .. autoclass:: torchelie.callbacks.Throughput 35 | :members: 36 | 37 | Visualization 38 | ------------- 39 | 40 | .. autoclass:: torchelie.callbacks.ConfusionMatrix 41 | :members: 42 | 43 | .. autoclass:: torchelie.callbacks.MetricsTable 44 | :members: 45 | 46 | .. autoclass:: torchelie.callbacks.ClassificationInspector 47 | :members: 48 | 49 | .. autoclass:: torchelie.callbacks.ImageGradientVis 50 | :members: 51 | 52 | .. autoclass:: torchelie.callbacks.GANMetrics 53 | :members: 54 | 55 | .. autoclass:: torchelie.callbacks.SegmentationInspector 56 | :members: 57 | 58 | Training 59 | -------- 60 | 61 | .. autoclass:: torchelie.callbacks.Optimizer 62 | :members: 63 | 64 | .. autoclass:: torchelie.callbacks.LRSched 65 | :members: 66 | 67 | Model 68 | ----- 69 | 70 | .. autoclass:: torchelie.callbacks.Polyak 71 | :members: 72 | 73 | .. autoclass:: torchelie.callbacks.Checkpoint 74 | :members: 75 | 76 | Misc 77 | ---- 78 | 79 | .. autoclass:: torchelie.callbacks.CallRecipe 80 | :members: 81 | 82 | -------------------------------------------------------------------------------- /torchelie/loss/gan/standard.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Standard, non saturating, GAN loss from the original GAN paper 3 | 4 | https://arxiv.org/abs/1406.2661 5 | 6 | :math:`L_D(x_r, x_f) = - \log(1 - D(x_f)) - \log D(x_r)` 7 | 8 | :math:`L_G(x_f) = -\log D(x_f)` 9 | """ 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | 15 | def real(x: torch.Tensor, reduce: str = 'mean') -> torch.Tensor: 16 | if isinstance(x, (tuple, list)): 17 | return sum(real(xx, reduce) for xx in x) / (len(x) if reduce == 'mean' else 1) 18 | out = F.softplus(-x) 19 | if reduce == 'none': 20 | return out 21 | if reduce == 'mean': 22 | return out.mean() 23 | if reduce == 'batch_mean': 24 | return out.view(out.shape[0], -1).sum(1).mean() 25 | assert False, f'reduction {reduce} invalid' 26 | 27 | 28 | def fake(x: torch.Tensor, reduce: str = 'mean') -> torch.Tensor: 29 | if isinstance(x, (tuple, list)): 30 | return sum(fake(xx, reduce) for xx in x) / (len(x) if reduce == 'mean' else 1) 31 | 32 | out = F.softplus(x) 33 | if reduce == 'none': 34 | return out 35 | if reduce == 'mean': 36 | return out.mean() 37 | if reduce == 'batch_mean': 38 | return out.view(out.shape[0], -1).sum(1).mean() 39 | assert False, f'reduction {reduce} invalid' 40 | 41 | 42 | def generated(x: torch.Tensor, reduce: str = 'mean') -> torch.Tensor: 43 | if isinstance(x, (tuple, list)): 44 | return sum(generated(xx, reduce) for xx in x) / (len(x) if reduce == 'mean' else 1) 45 | 46 | out = F.softplus(-x) 47 | if reduce == 'none': 48 | return out 49 | if reduce == 'mean': 50 | return out.mean() 51 | if reduce == 'batch_mean': 52 | return out.view(out.shape[0], -1).sum(1).mean() 53 | assert False, f'reduction {reduce} invalid' 54 | 
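# --- Illustrative sketch (editor's addition, not part of standard.py) ---
# Shows how the reductions above combine into the non-saturating objectives
# given in the module docstring. The tiny linear discriminator and the random
# "generated" batch are placeholder assumptions, not models from this
# repository.
def _example_nonsaturating_losses():
    import torch.nn as nn
    D = nn.Linear(8, 1)               # stand-in discriminator
    x_real = torch.randn(4, 8)
    x_fake = torch.randn(4, 8)        # pretend generator output
    # L_D = softplus(-D(x_r)) + softplus(D(x_f))
    d_loss = real(D(x_real)) + fake(D(x_fake.detach()))
    # L_G = softplus(-D(x_f))
    g_loss = generated(D(x_fake))
    return d_loss, g_loss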
-------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchelie.models import * 3 | 4 | 5 | def test_patchgan(): 6 | for M in [patch286, patch70, patch34, patch16]: 7 | m = M() 8 | m(torch.randn(1, 3, 128, 128)) 9 | 10 | 11 | def test_pnet(): 12 | pnet = PerceptualNet(['conv5_2']) 13 | pnet(torch.randn(1, 3, 128, 128), detach=True) 14 | 15 | 16 | def test_factored_predictor(): 17 | fp = PixelPredictor(10) 18 | fp(torch.randn(5, 10), torch.randn(5, 3)) 19 | fp.sample(torch.randn(3, 10), 1) 20 | 21 | 22 | def test_pixelcnn(): 23 | pc = PixelCNN(10, (8, 8), 1) 24 | pc(torch.randn(1, 1, 8, 8)) 25 | pc.sample(1, 1) 26 | 27 | pc = PixelCNN(10, (8, 8), 3) 28 | pc(torch.randn(1, 3, 8, 8)) 29 | pc.sample(1, 1) 30 | 31 | 32 | def test_resnet(): 33 | m = ResidualDiscriminator([2, 'D', 3]) 34 | m(torch.randn(1, 3, 8, 8)) 35 | 36 | m = ResidualDiscriminator([2, 'D', 3]) 37 | m.to_projection_discr(3) 38 | m(torch.randn(1, 3, 8, 8), torch.LongTensor([1])) 39 | 40 | def run(M): 41 | m = M(4) 42 | out = m(torch.randn(2, 3, 32, 32)) 43 | out.mean().backward() 44 | 45 | run(resnet18) 46 | run(preact_resnet18) 47 | run(resnet50) 48 | run(preact_resnet50) 49 | run(resnext50_32x4d) 50 | run(preact_resnext50_32x4d) 51 | 52 | 53 | def test_unet(): 54 | m = UNet([3, 6, 12], 1) 55 | m(torch.randn(1, 3, 128, 128)) 56 | 57 | 58 | def test_vgg(): 59 | m = vgg11(2) 60 | m(torch.randn(1, 3, 32, 32)) 61 | 62 | 63 | def test_attention(): 64 | m = attention56(2) 65 | m(torch.randn(2, 3, 32, 32)) 66 | 67 | 68 | def test_hourglass(): 69 | m = Hourglass() 70 | out = m(torch.randn(2, 32, 128, 128)) 71 | assert out.shape == (2, 3, 128, 128) 72 | 73 | 74 | def test_autogan(): 75 | m = AutoGAN([3, 4, 5], in_noise=4) 76 | m(torch.randn(16, 4)) 77 | -------------------------------------------------------------------------------- /torchelie/models/registry.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | import torchelie.utils as tu 4 | 5 | 6 | class Registry: 7 | 8 | def __init__(self): 9 | self.sources = ['https://s3.eu-west-3.amazonaws.com/torchelie.models'] 10 | self.known_models = {} 11 | 12 | def from_source(self, src: str, model: str) -> dict: 13 | uri = f'{src}/{model}' 14 | if uri.lower().startswith('http'): 15 | return torch.hub.load_state_dict_from_url(uri, 16 | map_location='cpu', 17 | file_name=model.replace( 18 | '/', '.')) 19 | else: 20 | return torch.load(uri, map_location='cpu') 21 | 22 | def fetch(self, model: str) -> dict: 23 | for source in reversed(self.sources): 24 | try: 25 | return self.from_source(source, model) 26 | except Exception as e: 27 | print(f'{model} not found in source {source}, next', str(e)) 28 | raise Exception(f'No source contains pretrained model {model}') 29 | 30 | def register_decorator(self, f): 31 | 32 | def _f(*args, pretrained: Optional[str] = None, **kwargs): 33 | model = f(*args, **kwargs) 34 | if pretrained: 35 | ckpt = self.fetch(f'{pretrained}/{f.__name__}.pth') 36 | tu.load_state_dict_forgiving(model, ckpt) 37 | return model 38 | 39 | self.known_models[f.__name__] = _f 40 | 41 | return _f 42 | 43 | def get_model(self, name, *args, **kwargs): 44 | return self.known_models[name](*args, **kwargs) 45 | 46 | 47 | registry = Registry() 48 | register = registry.register_decorator 49 | get_model = registry.get_model 50 | 51 | __all__ = 
['Registry', 'register', 'get_model'] 52 | -------------------------------------------------------------------------------- /torchelie/nn/interpolate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import List, Optional 5 | 6 | 7 | class Interpolate2d(nn.Module): 8 | """ 9 | A wrapper around :func:`pytorch.nn.functional.interpolate` 10 | """ 11 | def __init__(self, 12 | mode: str, 13 | size: Optional[List[int]] = None, 14 | scale_factor: Optional[float] = None) -> None: 15 | super().__init__() 16 | self.size = size 17 | self.scale_factor = scale_factor 18 | self.mode = mode 19 | 20 | def forward(self, 21 | x: torch.Tensor, 22 | size: Optional[List[int]] = None) -> torch.Tensor: 23 | rsf = True if self.scale_factor is not None else None 24 | align = False if self.mode != 'nearest' else None 25 | if not size: 26 | return F.interpolate(x, 27 | mode=self.mode, 28 | size=self.size, 29 | scale_factor=self.scale_factor, 30 | recompute_scale_factor=rsf, 31 | align_corners=align) 32 | else: 33 | return F.interpolate(x, 34 | mode=self.mode, 35 | size=size, 36 | recompute_scale_factor=rsf, 37 | align_corners=align) 38 | 39 | def extra_repr(self) -> str: 40 | return f'scale_factor={self.scale_factor} size={self.size}' 41 | 42 | 43 | class InterpolateBilinear2d(Interpolate2d): 44 | """ 45 | A wrapper around :func:`pytorch.nn.functional.interpolate` with bilinear 46 | mode. 47 | """ 48 | def __init__( 49 | self, 50 | size: Optional[List[int]] = None, 51 | scale_factor: Optional[float] = None, 52 | ) -> None: 53 | super().__init__(size=size, scale_factor=scale_factor, mode='bilinear') 54 | -------------------------------------------------------------------------------- /examples/pixelcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torchelie.models as tmodels 4 | from torchelie.metrics.avg import WindowAvg 5 | 6 | import torchvision.transforms as TF 7 | from torchvision.datasets import CIFAR10, MNIST, FashionMNIST, SVHN 8 | from torch.utils.data import DataLoader 9 | from torchelie.optim import RAdamW 10 | from torchelie.recipes import TrainAndCall 11 | import torchelie.metrics as tcb 12 | 13 | 14 | def train(model, loader): 15 | def train_step(batch): 16 | x = batch[0] 17 | x = x.expand(-1, 3, -1, -1) 18 | 19 | x2 = model(x * 2 - 1) 20 | loss = F.cross_entropy(x2, (x * 255).long()) 21 | loss.backward() 22 | reconstruction = x2.argmax(dim=1).float() / 255.0 23 | return {'loss': loss, 'reconstruction': reconstruction} 24 | 25 | def after_train(): 26 | imgs = model.sample(1, 4).expand(-1, 3, -1, -1) 27 | return {'imgs': imgs} 28 | 29 | opt = RAdamW(model.parameters(), lr=3e-3) 30 | trainer = TrainAndCall(model, 31 | train_step, 32 | after_train, 33 | dl, 34 | test_every=500, 35 | visdom_env='pixelcnn') 36 | trainer.callbacks.add_callbacks([ 37 | tcb.WindowedMetricAvg('loss'), 38 | tcb.Log('reconstruction', 'reconstruction'), 39 | tcb.Optimizer(opt, log_lr=True), 40 | tcb.LRSched(torch.optim.lr_scheduler.ReduceLROnPlateau(opt)) 41 | ]) 42 | trainer.test_loop.callbacks.add_callbacks([ 43 | tcb.Log('imgs', 'imgs'), 44 | ]) 45 | 46 | trainer.to('cuda') 47 | trainer.run(10) 48 | 49 | 50 | tfms = TF.Compose([ 51 | TF.Resize(32), 52 | TF.ToTensor(), 53 | ]) 54 | 55 | dl = DataLoader(FashionMNIST('~/.cache/torch/fashionmnist', 56 | transform=tfms, 57 | download=True), 58 | 
batch_size=32, 59 | shuffle=True, 60 | num_workers=4) 61 | 62 | model = tmodels.PixelCNN(64, (32, 32), channels=3) 63 | train(model, dl) 64 | -------------------------------------------------------------------------------- /torchelie/models/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..nn.transformer import ViTBlock 4 | 5 | 6 | class ViTTrunk(nn.Module): 7 | """ 8 | Vision Transformer (ViT) trunk that processes a sequence of patch embeddings with positional encoding 9 | and optional learnable registers, using a stack of ViTBlock layers. 10 | 11 | Args: 12 | seq_len (int): Length of the input sequence (number of patches). 13 | d_model (int): Dimension of the model. 14 | num_layers (int): Number of transformer blocks. 15 | num_heads (int): Number of attention heads. 16 | num_registers (int, optional): Number of learnable registers to prepend to the sequence. Default: 10. 17 | 18 | Forward Args: 19 | x (Tensor): Input tensor of shape [B, C, H/P, W/P], where P is the patch size. 20 | 21 | Returns: 22 | Tensor: Output tensor of shape [B, C, H/P, W/P]. 23 | """ 24 | 25 | def __init__(self, seq_len, d_model, num_layers, num_heads, num_registers=10): 26 | super().__init__() 27 | self.trunk = nn.ModuleList( 28 | [ViTBlock(d_model, num_heads) for _ in range(num_layers)] 29 | ) 30 | self.pos_enc = nn.Parameter(torch.zeros(seq_len, d_model)) 31 | self.registers = nn.Parameter( 32 | torch.randn(num_registers, d_model) / (d_model**0.5) 33 | ) 34 | 35 | def forward(self, x): 36 | """ 37 | Forward pass for the ViTTrunk. 38 | 39 | Args: 40 | x (Tensor): Input tensor of shape [B, C, H/P, W/P]. 41 | 42 | Returns: 43 | Tensor: Output tensor of shape [B, C, H/P, W/P]. 44 | """ 45 | # x: [B,C,H/P,W/P] 46 | B, C, Hp, Wp = x.shape 47 | x = x.view(B, C, Hp * Wp).permute(0, 2, 1) 48 | x = x + self.pos_enc 49 | # x: [B, L, C] 50 | x = torch.cat([self.registers.unsqueeze(0).expand(B, -1, -1), x], dim=1) 51 | for block in self.trunk: 52 | x = block(x) 53 | 54 | x = x[:, len(self.registers) :, :] 55 | # x = F.gelu(x) 56 | x = x.permute(0, 2, 1).reshape(B, C, Hp, Wp) 57 | # x: [B,C,H/P,W/P] 58 | return x 59 | -------------------------------------------------------------------------------- /torchelie/nn/withsavedactivations.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn as nn 4 | from torchelie.utils import layer_by_name 5 | 6 | 7 | class WithSavedActivations(nn.Module): 8 | """ 9 | Hook :code:`model` in order to get intermediate activations. The 10 | activations to save can be either specified by module type or layer name. 
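    A minimal sketch (the wrapped network below is purely illustrative):

    >>> import torch
    >>> import torch.nn as nn
    >>> net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
    >>> saved = WithSavedActivations(net, types=(nn.Conv2d,))
    >>> out, acts = saved(torch.randn(1, 3, 32, 32), detach=True)
    >>> sorted(acts.keys())
    ['0', '2']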
11 | """ 12 | def __init__(self, model, types=(nn.Conv2d, nn.Linear), names=None): 13 | super(WithSavedActivations, self).__init__() 14 | self.model = model 15 | self.activations = {} 16 | self.detach = True 17 | self.handles = [] 18 | 19 | self.set_keep_layers(types, names) 20 | 21 | def set_keep_layers(self, types=(nn.Conv2d, nn.Linear), names=None): 22 | for h in self.handles: 23 | h.remove() 24 | 25 | if names is None: 26 | for name, layer in self.model.named_modules(): 27 | if isinstance(layer, types): 28 | h = layer.register_forward_hook(functools.partial( 29 | self._save, name)) 30 | self.handles.append(h) 31 | else: 32 | for name in names: 33 | layer = layer_by_name(self.model, name) 34 | h = layer.register_forward_hook(functools.partial( 35 | self._save, name)) 36 | self.handles.append(h) 37 | 38 | def _save(self, name, module, input, output): 39 | if self.detach: 40 | self.activations[name] = output.detach().clone() 41 | else: 42 | self.activations[name] = output.clone() 43 | 44 | def forward(self, input, detach: bool): 45 | """ 46 | Call :code:`self.model(input)`. 47 | 48 | Args: 49 | input: input to the model 50 | detach (bool): if True, intermediate activations will be 51 | :code:`.detach()`d. 52 | 53 | Returns 54 | model output, a name => activation dict with saved intermediate 55 | activations. 56 | """ 57 | self.detach = detach 58 | self.activations = {} 59 | out = self.model(input) 60 | acts = self.activations 61 | self.activations = {} 62 | return out, acts 63 | -------------------------------------------------------------------------------- /torchelie/nn/functional/vq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | 5 | class VectorQuantization(Function): 6 | 7 | @staticmethod 8 | def compute_indices(inputs_orig, codebook): 9 | bi = [] 10 | SZ = 10000 11 | for i in range(0, inputs_orig.size(0), SZ): 12 | inputs = inputs_orig[i:i + SZ] 13 | # NxK 14 | distances_matrix = torch.cdist(inputs, codebook) 15 | # Nx1 16 | indic = torch.min(distances_matrix, dim=-1)[1].unsqueeze(1) 17 | bi.append(indic) 18 | return torch.cat(bi, dim=0) 19 | 20 | @staticmethod 21 | def flatten(x): 22 | code_dim = x.size(-1) 23 | return x.view(-1, code_dim) 24 | 25 | @staticmethod 26 | def restore_shapes(codes, indices, target_shape): 27 | idx_shape = list(target_shape) 28 | idx_shape[-1] = 1 29 | return codes.view(*target_shape), indices.view(*idx_shape) 30 | 31 | @staticmethod 32 | def forward(ctx, inputs, codebook, commitment=0.25, dim=1): 33 | inputs_flat = VectorQuantization.flatten(inputs) 34 | indices = VectorQuantization.compute_indices(inputs_flat, codebook) 35 | codes = codebook[indices.view(-1), :] 36 | codes, indices = VectorQuantization.restore_shapes( 37 | codes, indices, inputs.shape) 38 | 39 | ctx.save_for_backward(codes, inputs, torch.tensor([float(commitment)]), 40 | codebook, indices) 41 | ctx.mark_non_differentiable(indices) 42 | return codes, indices 43 | 44 | @staticmethod 45 | def backward(ctx, straight_through, unused_indices): 46 | codes, inputs, beta, codebook, indices = ctx.saved_tensors 47 | 48 | # TODO: figure out proper vq loss reduction 49 | # vq_loss = F.mse_loss(inputs, codes).detach() 50 | 51 | # gradient of vq_loss 52 | diff = 2 * (inputs - codes) / inputs.numel() 53 | 54 | commitment = beta.item() * diff 55 | 56 | code_disp = VectorQuantization.flatten(-diff) 57 | indices = VectorQuantization.flatten(indices) 58 | code_disp = 
(torch.zeros_like(codebook).index_add_( 59 | 0, indices.view(-1), code_disp)) 60 | return straight_through + commitment, code_disp, None, None 61 | 62 | 63 | quantize = VectorQuantization.apply 64 | -------------------------------------------------------------------------------- /tests/test_datalearning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchelie.data_learning import * 3 | 4 | 5 | def test_pixel_image(): 6 | pi = PixelImage((1, 3, 128, 128), 0.01) 7 | pi() 8 | 9 | start = torch.randn(3, 128, 128) 10 | pi = PixelImage((1, 3, 128, 128), init_img=start) 11 | 12 | assert start.allclose(pi() + 0.5, atol=1e-7) 13 | 14 | 15 | def test_spectral_image(): 16 | pi = SpectralImage((1, 3, 128, 128), 0.01) 17 | pi() 18 | 19 | start = torch.randn(1, 3, 128, 128) 20 | pi = SpectralImage((1, 3, 128, 128), init_img=start) 21 | 22 | 23 | def test_correlate_colors(): 24 | corr = CorrelateColors() 25 | start = torch.randn(1, 3, 64, 64) 26 | assert start.allclose(corr.invert(corr(start)), atol=1e-5) 27 | 28 | 29 | def test_parameterized_img(): 30 | start = torch.clamp(torch.randn(1, 3, 128, 128) + 0.5, min=0, max=1) 31 | 32 | ParameterizedImg(1, 3, 128, 128, space='spectral', colors='uncorr')() 33 | ParameterizedImg(1, 3, 34 | 128, 35 | 128, 36 | space='spectral', 37 | colors='uncorr', 38 | init_img=start)() 39 | 40 | ParameterizedImg(1, 3, 128, 128, space='spectral', colors='uncorr')() 41 | 42 | start = torch.clamp(torch.randn(1, 3, 128, 129) + 0.5, min=0, max=1) 43 | ParameterizedImg(1, 3, 44 | 128, 45 | 129, 46 | space='spectral', 47 | colors='uncorr', 48 | init_img=start)() 49 | start = torch.clamp(torch.randn(1, 3, 128, 128) + 0.5, min=0, max=1) 50 | ParameterizedImg(1, 3, 128, 128, space='pixel', colors='uncorr')() 51 | ParameterizedImg(1, 3, 52 | 128, 53 | 128, 54 | space='pixel', 55 | colors='uncorr', 56 | init_img=start)() 57 | 58 | ParameterizedImg(1, 3, 128, 128, space='spectral', colors='corr')() 59 | ParameterizedImg(1, 3, 60 | 128, 61 | 128, 62 | space='spectral', 63 | colors='corr', 64 | init_img=start)() 65 | 66 | ParameterizedImg(1, 3, 128, 128, space='pixel', colors='corr')() 67 | ParameterizedImg(1, 3, 128, 128, space='pixel', colors='corr', 68 | init_img=start)() 69 | -------------------------------------------------------------------------------- /torchelie/models/perceptualnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.nn as nn 4 | from .vgg import vgg19 5 | from torchelie.nn.utils import edit_model 6 | from typing import List 7 | from torchelie.nn.withsavedactivations import WithSavedActivations 8 | 9 | class PerceptualNet(WithSavedActivations): 10 | """ 11 | Make a VGG16 with appropriately named layers that records intermediate 12 | activations. 13 | 14 | Args: 15 | layers (list of str): the names of the layers for which to save the 16 | activations. 
17 | use_avg_pool (bool): Whether to replace max pooling with averange 18 | pooling (default: True) 19 | remove_unused_layers (bool): whether to remove layers past the last one 20 | used (default: True) 21 | """ 22 | def __init__(self, 23 | layers: List[str], 24 | use_avg_pool: bool = True, 25 | remove_unused_layers: bool = True) -> None: 26 | # yapf: disable 27 | layer_names = [ 28 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'maxpool1', 29 | 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'maxpool2', 30 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 31 | 'conv3_4', 'relu3_4', 'maxpool3', # noqa: E131 32 | 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 33 | 'conv4_4', 'relu4_4', 'maxpool4', # noqa: E131 34 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 35 | 'conv5_4', 'relu5_4', # 'maxpool5' 36 | ] 37 | # yapf: enable 38 | 39 | m = vgg19(1, pretrained='perceptual/imagenet').features 40 | flat_vgg = [ 41 | layer for layer in m.modules() 42 | if isinstance(layer, (nn.Conv2d, nn.ReLU, nn.MaxPool2d)) 43 | ] 44 | m = nn.Sequential( 45 | OrderedDict([(n, mod) for n, mod in zip(layer_names, flat_vgg)])) 46 | for nm, mod in m.named_modules(): 47 | if 'relu' in nm: 48 | setattr(m, nm, nn.ReLU(False)) 49 | elif 'pool' in nm and use_avg_pool: 50 | setattr(m, nm, nn.AvgPool2d(2, 2)) 51 | 52 | if remove_unused_layers: 53 | m = m[:max([layer_names.index(layer) for layer in layers]) + 1] 54 | 55 | super().__init__(m, names=layers) 56 | -------------------------------------------------------------------------------- /torchelie/recipes/trainandcall.py: -------------------------------------------------------------------------------- 1 | from .trainandtest import TrainAndTest 2 | 3 | 4 | def TrainAndCall(model, 5 | train_fun, 6 | test_fun, 7 | train_loader, 8 | test_every=100, 9 | visdom_env='main', 10 | checkpoint='model', 11 | log_every=10, 12 | key_best=None): 13 | """ 14 | Train a model and evaluate it with a custom function. The model is 15 | automatically registered and checkpointed as :code:`checkpoint['model']`, 16 | and put in eval mode when testing. 17 | 18 | Training callbacks: 19 | 20 | - Counter for counting iterations, connected to the testing loop as well 21 | - VisdomLogger 22 | - StdoutLogger 23 | 24 | Testing: 25 | 26 | Testing loop is in :code:`.test_loop`. 27 | 28 | Testing callbacks: 29 | 30 | - VisdomLogger 31 | - StdoutLogger 32 | - Checkpoint 33 | 34 | Args: 35 | model (nn.Model): a model 36 | train_fun (Callabble): a function that takes a batch as a single 37 | argument, performs a training step and return a dict of values to 38 | populate the recipe's state. 39 | test_fun (Callable): a function taking no argument that performs 40 | something to evaluate your model and returns a dict to populate the 41 | state. 42 | train_loader (DataLoader): Training set dataloader 43 | test_every (int): testing frequency, in number of iterations (default: 44 | 100) 45 | visdom_env (str): name of the visdom environment to use, or None for 46 | not using Visdom (default: None) 47 | checkpoint (str): checkpointing path or None for no checkpointing 48 | log_every (int): logging frequency, in number of iterations (default: 49 | 10) 50 | key_best (function or None): a key function for comparing states. 51 | Checkpointing the greatest. 
52 | 53 | Returns: 54 | a configured Recipe 55 | """ 56 | 57 | def test_fun_wrap(_): 58 | return test_fun() 59 | 60 | return TrainAndTest(model, 61 | train_fun, 62 | test_fun_wrap, 63 | train_loader=train_loader, 64 | test_loader=range(1), 65 | test_every=test_every, 66 | visdom_env=visdom_env, 67 | checkpoint=checkpoint, 68 | log_every=log_every, 69 | key_best=key_best) 70 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | import torch # For some reasons, this is absolutely mandatory 16 | sys.path.insert(0, os.path.abspath(__file__) + '/..') 17 | sys.path.insert(0, os.path.abspath(__file__)) 18 | sys.path.insert(0, os.path.abspath('.')) 19 | sys.path.insert(0, os.path.abspath('..')) 20 | autodoc_mock_imports = ["crayons", 'torchvision', 'numpy', 'caffe2'] 21 | 22 | # -- Project information ----------------------------------------------------- 23 | 24 | project = 'Torchélie' 25 | copyright = '2019, Guillaume "Vermeille" Sanchez' 26 | author = 'Guillaume "Vermeille" Sanchez' 27 | 28 | # please ReadTheDocs 29 | master_doc = 'index' 30 | 31 | # -- General configuration --------------------------------------------------- 32 | 33 | # Add any Sphinx extension module names here, as strings. They can be 34 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 35 | # ones. 36 | extensions = [ 37 | 'sphinx.ext.autodoc', 'sphinx.ext.mathjax', 'sphinx.ext.napoleon', 38 | 'sphinx_rtd_theme', 'sphinx.ext.autosummary' 39 | ] 40 | 41 | # Add any paths that contain templates here, relative to this directory. 42 | templates_path = ['_templates'] 43 | 44 | # List of patterns, relative to source directory, that match files and 45 | # directories to ignore when looking for source files. 46 | # This pattern also affects html_static_path and html_extra_path. 47 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 48 | 49 | # -- Options for HTML output ------------------------------------------------- 50 | 51 | # The theme to use for HTML and HTML Help pages. See the documentation for 52 | # a list of builtin themes. 53 | # 54 | html_theme = "sphinx_rtd_theme" 55 | autodoc_inherit_docstrings = False 56 | autodoc_member_order = 'groupwise' 57 | 58 | # Add any paths that contain custom static files (such as style sheets) here, 59 | # relative to this directory. They are copied after the builtin static files, 60 | # so a file named "default.css" will overwrite the builtin "default.css". 
61 | html_static_path = ['_static'] 62 | html_css_files = [ 63 | 'css/custom.css', 64 | ] 65 | 66 | autosummary_generate = True 67 | -------------------------------------------------------------------------------- /tests/test_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchelie.loss import * 4 | import torchelie.loss.gan as gan 5 | import torchelie.loss.functional as tlf 6 | 7 | 8 | def test_bitempered(): 9 | x = torch.randn(3, 5) 10 | y = torch.arange(3) 11 | 12 | tsm = tempered_softmax(x, 1) 13 | sm = torch.nn.functional.softmax(x, dim=1) 14 | assert sm.allclose(tsm) 15 | 16 | tsm = tempered_log_softmax(x, 1) 17 | sm = torch.nn.functional.log_softmax(x, dim=1) 18 | assert sm.allclose(tsm) 19 | 20 | tnll = tempered_nll_loss(sm, y, 1, 1) 21 | nll = torch.nn.functional.nll_loss(sm, y) 22 | assert nll.allclose(tnll) 23 | 24 | temp_loss = tempered_cross_entropy(x, y, 1, 1) 25 | ce_loss = torch.nn.functional.cross_entropy(x, y) 26 | assert ce_loss.allclose(temp_loss) 27 | 28 | tce = TemperedCrossEntropyLoss(1, 1) 29 | assert ce_loss.allclose(tce(x, y)) 30 | 31 | 32 | def test_deepdream(): 33 | m = nn.Sequential(nn.Conv2d(1, 1, 3)) 34 | dd = DeepDreamLoss(m, '0', max_reduction=1) 35 | dd(m(torch.randn(1, 1, 10, 10))) 36 | 37 | 38 | def test_focal(): 39 | y = (torch.randn(10, 1) < 0).float() 40 | x = torch.randn(10, 1) 41 | 42 | foc = FocalLoss() 43 | fl = foc(x, y) 44 | loss = torch.nn.functional.binary_cross_entropy_with_logits(x, y) 45 | assert torch.allclose(fl, loss) 46 | 47 | y = torch.randint(4, (10,)) 48 | x = torch.randn(10, 5) 49 | 50 | fl = foc(x, y) 51 | loss = torch.nn.functional.cross_entropy(x, y) 52 | assert torch.allclose(fl, loss) 53 | 54 | focal_loss(torch.randn(10, 5), torch.randint(4, (10,))) 55 | 56 | 57 | def test_funcs(): 58 | f = OrthoLoss() 59 | f(torch.randn(10, 10)) 60 | ortho(torch.randn(10, 10)) 61 | 62 | f = TotalVariationLoss() 63 | f(torch.randn(1, 1, 10, 10)) 64 | total_variation(torch.randn(1, 1, 10, 10)) 65 | 66 | f = ContinuousCEWithLogits() 67 | continuous_cross_entropy(torch.randn(10, 5), 68 | torch.nn.functional.softmax(torch.randn(10, 5), 1)) 69 | f(torch.randn(10, 5), torch.nn.functional.softmax(torch.randn(10, 5), 1)) 70 | 71 | 72 | def test_neural_style(): 73 | ns = NeuralStyleLoss() 74 | ns.set_content(torch.randn(1, 3, 128, 128)) 75 | ns.set_style(torch.randn(1, 3, 128, 128), 1) 76 | ns(torch.randn(1, 3, 128, 128)) 77 | 78 | 79 | def test_perceptual(): 80 | pl = PerceptualLoss(['conv1_1'], rescale=True) 81 | pl(torch.randn(1, 3, 64, 64), torch.randn(1, 3, 64, 64)) 82 | 83 | 84 | def test_gan(): 85 | x = torch.randn(5, 5) 86 | 87 | gan.hinge.real(x) 88 | gan.hinge.fake(x) 89 | gan.hinge.generated(x) 90 | 91 | gan.standard.real(x) 92 | gan.standard.fake(x) 93 | gan.standard.generated(x) 94 | -------------------------------------------------------------------------------- /torchelie/callbacks/avg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Those classes are different ways of averaging metrics. 3 | """ 4 | 5 | import torchelie.utils as tu 6 | from typing import List, Optional 7 | 8 | 9 | class RunningAvg(tu.AutoStateDict): 10 | """ 11 | Average by keeping the whole sum and number of elements of the data logged. 12 | Useful when the metrics come per batch and an accurate number for the whole 13 | epoch is needed. 
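    For instance (illustrative numbers):

    >>> avg = RunningAvg()
    >>> avg.log(8, total=10)  # e.g. 8 correct predictions in a batch of 10
    >>> avg.log(5, total=10)
    >>> avg.get()
    0.65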
14 | """ 15 | def __init__(self) -> None: 16 | super(RunningAvg, self).__init__() 17 | self.count = 0.0 18 | self.val = 0.0 19 | 20 | def log(self, x: float, total: int = 1): 21 | """ 22 | Log metric 23 | 24 | Args: 25 | x: the metric 26 | total: how many element does this represent 27 | """ 28 | self.count += total 29 | self.val += x 30 | 31 | def get(self) -> float: 32 | """ 33 | Get the average so far 34 | """ 35 | if self.count == 0: 36 | return float('nan') 37 | return self.val / self.count 38 | 39 | 40 | class WindowAvg(tu.AutoStateDict): 41 | """ 42 | Average a window containing the `k` previous logged values 43 | 44 | Args: 45 | k (int): the window's length 46 | """ 47 | def __init__(self, k: int = 100) -> None: 48 | super(WindowAvg, self).__init__() 49 | self.vals: List[float] = [] 50 | self.k = k 51 | 52 | def log(self, x: float) -> None: 53 | """ 54 | Log `x` 55 | """ 56 | if len(self.vals) == self.k: 57 | self.vals.pop(0) 58 | self.vals.append(x) 59 | 60 | def get(self) -> float: 61 | """ 62 | Return the value averaged over the window 63 | """ 64 | if len(self.vals) == 0: 65 | return float("nan") 66 | return sum(self.vals) / len(self.vals) 67 | 68 | 69 | class ExponentialAvg(tu.AutoStateDict): 70 | r""" 71 | Keep an exponentially decaying average of the values according to 72 | 73 | :math:`y := \beta y + (1 - \beta) x` 74 | 75 | Args: 76 | beta (float): the decay rate 77 | """ 78 | def __init__(self, beta: float = 0.6): 79 | super(ExponentialAvg, self).__init__(['beta']) 80 | self.beta = beta 81 | self.val: Optional[float] = None 82 | 83 | def log(self, x: float) -> None: 84 | """Log `x`""" 85 | if self.val is None: 86 | self.val = x 87 | else: 88 | self.val = self.beta * self.val + (1 - self.beta) * x 89 | 90 | def get(self) -> float: 91 | """Return the exponential average at this time step""" 92 | assert self.val is not None, 'no value yet' 93 | return self.val 94 | -------------------------------------------------------------------------------- /torchelie/models/convnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchelie.nn as tnn 4 | import torchelie.utils as tu 5 | from .registry import register 6 | from .classifier import ClassificationHead 7 | 8 | 9 | class LayerScale(nn.Module): 10 | def __init__(self, num_features): 11 | super().__init__() 12 | self.scales = nn.Parameter(torch.empty(num_features)) 13 | nn.init.normal_(self.scales, 0, 1e-5) 14 | 15 | def forward(self, w): 16 | s = self.scales.view(-1, *([1] * (w.ndim - 1))) 17 | return s * w 18 | 19 | 20 | class ConvNeXtBlock(nn.Module): 21 | def __init__(self, ch): 22 | super().__init__() 23 | e = 4 24 | self.branch = nn.Sequential( 25 | nn.Conv2d(ch, ch, kernel_size=7, padding=3, groups=ch, bias=False), 26 | nn.GroupNorm(1, ch), tu.kaiming(tnn.Conv1x1(ch, ch * e)), 27 | nn.SiLU(True), tu.constant_init(tnn.Conv1x1(ch * e, ch), 0)) 28 | nn.utils.parametrize.register_parametrization(self.branch[-1], 29 | 'weight', LayerScale(ch)) 30 | 31 | def forward(self, x): 32 | return self.branch(x).add_(x) 33 | 34 | 35 | class ConvNeXt(nn.Sequential): 36 | def __init__(self, num_classes, arch): 37 | super().__init__() 38 | self.add_module('input', tu.xavier(nn.Conv2d(3, arch[0], 4, 1, 0))) 39 | 40 | prev_ch = arch[0] 41 | ch = arch[0] 42 | self.add_module(f'norm0', nn.GroupNorm(1, ch)) 43 | self.add_module(f'act0', nn.SiLU(True)) 44 | for i in range(len(arch)): 45 | if isinstance(arch[i], int): 46 | ch = arch[i] 47 | 
self.add_module(f'layer{i}', ConvNeXtBlock(ch)) 48 | prev_ch = ch 49 | else: 50 | assert arch[i] == 'D' 51 | self.add_module(f'act{i}', nn.SiLU(True)) 52 | self.add_module(f'norm{i}', nn.GroupNorm(1, ch)) 53 | self.add_module( 54 | f'layer{i}', tu.kaiming(nn.Conv2d(ch, arch[i + 1], 2, 2, 55 | 0))) 56 | self.add_module(f'norm{i}', nn.GroupNorm(1, ch)) 57 | 58 | self.add_module('classifier', 59 | ClassificationHead(arch[-1], num_classes)) 60 | 61 | 62 | @register 63 | def convnext_xxt(num_classes): 64 | return ConvNeXt(num_classes, [64, 'D', 128, 'D', 256, 256, 256, 'D', 512]) 65 | 66 | 67 | @register 68 | def convnext_xt(num_classes): 69 | return ConvNeXt(num_classes, [96, 'D', 192, 'D', 384, 384, 384, 'D', 768]) 70 | 71 | 72 | @register 73 | def convnext_t(num_classes): 74 | return ConvNeXt(num_classes, [96] * 3 + ['D'] + [192] * 3 + ['D'] + 75 | [384] * 9 + ['D'] + [768] * 3) 76 | -------------------------------------------------------------------------------- /torchelie/loss/perceptualloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import List, Callable, cast, Union, Tuple 5 | 6 | from torchelie.nn.imagenetinputnorm import ImageNetInputNorm 7 | from torchelie.models.perceptualnet import PerceptualNet 8 | 9 | 10 | class PerceptualLoss(nn.Module): 11 | r""" 12 | Perceptual loss: the distance between a two images deep representation 13 | 14 | :math:`\text{Percept}(\text{input}, \text{target})=\sum_l^{layers} 15 | \text{loss_fn}(\text{Vgg}(\text{input})_l, \text{Vgg}(\text{target})_l)` 16 | 17 | Args: 18 | l (list of str): the layers on which to compare the representations 19 | rescale (bool): whether to scale images smaller side to 224 as 20 | expected by the underlying vgg net 21 | loss_fn (distance function): a distance function to compare the 22 | representations, like mse_loss or l1_loss 23 | """ 24 | def __init__(self, 25 | layers: Union[List[str], List[Tuple[str, float]]], 26 | rescale: bool = False, 27 | loss_fn: Callable[[torch.Tensor, torch.Tensor], 28 | torch.Tensor] = F.mse_loss, 29 | use_avg_pool: bool = True, 30 | remove_unused_layers: bool = True): 31 | super(PerceptualLoss, self).__init__() 32 | 33 | def key(l): 34 | if isinstance(l, (tuple, list)): 35 | return l[0] 36 | else: 37 | return l 38 | 39 | def weight(l): 40 | if isinstance(l, (tuple, list)): 41 | return l[1] 42 | else: 43 | return 1 44 | 45 | self.weight = {key(l): weight(l) for l in layers} 46 | self.m = PerceptualNet([key(l) for l in layers], 47 | use_avg_pool=use_avg_pool, 48 | remove_unused_layers=remove_unused_layers) 49 | self.norm = ImageNetInputNorm() 50 | self.rescale = rescale 51 | self.loss_fn = loss_fn 52 | 53 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 54 | """ 55 | Return the perceptual loss between batch of images `x` and `y` 56 | """ 57 | if self.rescale: 58 | s = 224 / min(y.shape[-2:]) 59 | y = F.interpolate(y, scale_factor=s, mode='area') 60 | s = 224 / min(x.shape[-2:]) 61 | x = F.interpolate(x, scale_factor=s, mode='area') 62 | 63 | _, ref = self.m(self.norm(y), detach=True) 64 | _, acts = self.m(self.norm(x), detach=False) 65 | loss = cast( 66 | torch.Tensor, 67 | sum(self.weight[k] * self.loss_fn(acts[k], ref[k]) 68 | for k in acts.keys())) 69 | return loss / len(acts) 70 | -------------------------------------------------------------------------------- /examples/conditional.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | 4 | import crayons 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from torchvision.datasets import MNIST, CIFAR10 11 | import torchvision.transforms as TF 12 | 13 | import torchelie.nn as tnn 14 | import torchelie.models 15 | import torchelie as tch 16 | from torchelie.models import ClassCondResNetDebug 17 | from torchelie.utils import nb_parameters 18 | from torchelie.recipes.classification import Classification 19 | from torchelie.optim import RAdamW 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--cpu', action='store_true') 23 | parser.add_argument('--dataset', 24 | type=str, 25 | choices=['mnist', 'cifar10'], 26 | default='mnist') 27 | parser.add_argument('--models', default='all') 28 | opts = parser.parse_args() 29 | 30 | device = 'cpu' if opts.cpu else 'cuda' 31 | 32 | 33 | class TrueOrFakeLabelDataset: 34 | def __init__(self, dataset): 35 | self.dataset = dataset 36 | self.classes = ['Fake', 'True'] 37 | 38 | def __len__(self): 39 | return len(self.dataset) 40 | 41 | def __getitem__(self, i): 42 | x, y = self.dataset[i] 43 | if torch.randn(1).item() < 0: 44 | return x, 1, y 45 | return x, 0, torch.randint(0, 10, (1, )).item() 46 | 47 | 48 | tfms = TF.Compose([TF.Resize(32), TF.ToTensor()]) 49 | ds = TrueOrFakeLabelDataset( 50 | MNIST('~/.cache/torch/mnist', download=True, transform=tfms)) 51 | dt = TrueOrFakeLabelDataset( 52 | MNIST('~/.cache/torch/mnist', download=True, transform=tfms, train=False)) 53 | dl = torch.utils.data.DataLoader(ds, 54 | num_workers=4, 55 | batch_size=32, 56 | shuffle=True) 57 | dlt = torch.utils.data.DataLoader(dt, 58 | num_workers=4, 59 | batch_size=32, 60 | shuffle=True) 61 | 62 | 63 | def train_net(): 64 | model = ClassCondResNetDebug(2, 10, in_ch=1) 65 | 66 | def train_step(batch): 67 | x, y, z = batch 68 | 69 | out = model(x, z) 70 | loss = F.cross_entropy(out, y) 71 | loss.backward() 72 | 73 | return {'loss': loss, 'pred': out} 74 | 75 | def validation_step(batch): 76 | x, y, z = batch 77 | 78 | out = model(x, z) 79 | loss = F.cross_entropy(out, y) 80 | return {'loss': loss, 'pred': out} 81 | 82 | clf = Classification(model, train_step, validation_step, dl, dlt, 83 | ds.classes).to(device) 84 | clf.callbacks.add_callbacks( 85 | [tch.callbacks.Optimizer(tch.optim.RAdamW(model.parameters()))]) 86 | clf.run(2) 87 | 88 | 89 | train_net() 90 | -------------------------------------------------------------------------------- /torchelie/recipes/algorithm.py: -------------------------------------------------------------------------------- 1 | import torchelie.utils as tu 2 | from collections import OrderedDict 3 | 4 | 5 | @tu.experimental 6 | class Algorithm: 7 | """ 8 | Define a customizable sequence of code blocks. 
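    Each step receives a shared ``env`` dict followed by the call arguments and
    returns a dict that is merged into the final output. A minimal sketch (step
    names and bodies are illustrative):

    >>> algo = Algorithm()
    >>> @algo.add_step('double')
    ... def double(env, x):
    ...     env['twice'] = 2 * x
    ...     return {'twice': env['twice']}
    >>> @algo.add_step('square')
    ... def square(env, x):
    ...     return {'squared': env['twice'] ** 2}
    >>> algo(3)
    {'twice': 6, 'squared': 36}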
9 | """ 10 | 11 | def __init__(self) -> None: 12 | self.passes = OrderedDict() 13 | 14 | def add_step(self, name: str, f=None): 15 | if name in self.passes: 16 | raise KeyError(f'{name} is already in the algorithm') 17 | 18 | if f is not None: 19 | self.passes[name] = f 20 | return 21 | 22 | def _f(func): 23 | self.passes[name] = func 24 | return func 25 | 26 | return _f 27 | 28 | def override_step(self, name: str, f=None): 29 | if name not in self.passes: 30 | raise KeyError(f'{name} was not present in the algorithm') 31 | 32 | if f is not None: 33 | self.passes[name] = f 34 | return 35 | 36 | def _f(func): 37 | self.passes[name] = func 38 | return func 39 | 40 | return _f 41 | 42 | def __call__(self, *args, **kwargs): 43 | env = {} 44 | output = {} 45 | for pass_name, pass_ in self.passes.items(): 46 | try: 47 | out = pass_(env, *args, **kwargs) 48 | except Exception as e: 49 | print('Error during pass', pass_name) 50 | raise e 51 | output.update(out) 52 | return output 53 | 54 | def remove_step(self, name: str): 55 | if name in self.passes: 56 | del self.passes[name] 57 | 58 | def insert_before(self, key: str, name: str, func=None): 59 | 60 | def _f(f): 61 | funs = list(self.passes.items()) 62 | idx = [i for i, (k, v) in enumerate(funs) if k == key][0] 63 | funs[idx:idx] = [(name, f)] 64 | self.passes = OrderedDict(funs) 65 | return f 66 | 67 | if func is None: 68 | return _f 69 | else: 70 | _f(func) 71 | 72 | def insert_after(self, key: str, name: str, func=None): 73 | 74 | def _f(f): 75 | funs = list(self.passes.items()) 76 | idx = [i for i, (k, v) in enumerate(funs) if k == key][0] 77 | funs[idx + 1:idx + 1] = [(name, f)] 78 | self.passes = OrderedDict(funs) 79 | return f 80 | 81 | if func is None: 82 | return _f 83 | else: 84 | _f(func) 85 | 86 | def __getitem__(self, name: str): 87 | return self.passes[name] 88 | 89 | def __setitem__(self, name: str, value): 90 | self.passes[name] = value 91 | 92 | def __repr__(self) -> str: 93 | return (self.__class__.__name__ + '\n' + 94 | tu.indent('\n'.join(list(self.passes.keys()))) + "\n") 95 | -------------------------------------------------------------------------------- /tests/test_tensorboard_callback.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torchelie.callbacks as tcb 4 | import matplotlib.pyplot as plt 5 | from torch.utils.data import DataLoader 6 | from torchvision.datasets import FashionMNIST 7 | from torchvision.transforms import PILToTensor 8 | 9 | 10 | @pytest.mark.require_tensorboard 11 | def test_tesorboard(): 12 | from torchelie.recipes import Recipe 13 | 14 | batch_size = 4 15 | 16 | class Dataset: 17 | def __init__(self, batch_size): 18 | self.batch_size = batch_size 19 | self.mnist = FashionMNIST('.', download=True, transform=PILToTensor()) 20 | self.classes = self.mnist.classes 21 | self.num_classes = len(self.mnist.class_to_idx) 22 | self.target_by_classes = [[idx for idx in range(len(self.mnist)) if self.mnist.targets[idx] == i] 23 | for i in range(self.num_classes)] 24 | 25 | def __len__(self): 26 | return self.batch_size * self.num_classes 27 | 28 | def __getitem__(self, item): 29 | idx = self.target_by_classes[item//self.batch_size][item] 30 | x, y = self.mnist[idx] 31 | x = torch.stack(3*[x]).squeeze() 32 | x[2] = 0 33 | return x, y 34 | 35 | dst = Dataset(batch_size) 36 | 37 | def train(b): 38 | x, y = b 39 | fig = plt.figure() 40 | plt.plot([0, int(y[0])]) 41 | return {'letter_number_int': int(y[0]), 42 | 'letter_number_tensor': 
y[0], 43 | 'letter_text': dst.classes[int(y[0])], 44 | 'test_html': 'test HTML', 45 | 'letter_gray_img_HW': x[0, 0], 46 | 'letter_gray_img_CHW': x[0, :1], 47 | 'letter_gray_imgs_NCHW': x[:, :1], 48 | 'letter_color_img_CHW': x[0], 49 | 'letter_color_imgs_NCHW': x, 50 | 'test_matplotlib': fig} 51 | 52 | r = Recipe(train, DataLoader(dst, batch_size)) 53 | r.callbacks.add_callbacks([ 54 | tcb.Counter(), 55 | tcb.TensorboardLogger(log_every=1), 56 | tcb.Log('letter_number_int', 'letter_number_int'), 57 | tcb.Log('letter_number_tensor', 'letter_number_tensor'), 58 | tcb.Log('letter_text', 'letter_text'), 59 | tcb.Log('test_html', 'test_html'), 60 | tcb.Log('letter_gray_img_HW', 'letter_gray_img_HW'), 61 | tcb.Log('letter_gray_img_CHW', 'letter_gray_img_CHW'), 62 | tcb.Log('letter_gray_imgs_NCHW', 'letter_gray_imgs_NCHW'), 63 | tcb.Log('letter_color_img_CHW', 'letter_color_img_CHW'), 64 | tcb.Log('letter_color_imgs_NCHW', 'letter_color_imgs_NCHW'), 65 | tcb.Log('test_matplotlib', 'test_matplotlib'), 66 | ]) 67 | r.run(1) 68 | 69 | 70 | if __name__ == '__main__': 71 | test_tesorboard() 72 | -------------------------------------------------------------------------------- /torchelie/datasets/ms1m.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from io import BytesIO 4 | import struct 5 | from typing import List, Tuple, Any 6 | 7 | 8 | class MS1M: 9 | def __init__(self, rec_file: str, idx_file: str, transform=None) -> None: 10 | self.rec_file = rec_file 11 | self.idx_file = idx_file 12 | 13 | offsets = self.read_idx(idx_file) 14 | max_id = int(self.read_metadata(offsets[-1])[1][0]) - 1 15 | self.transform = transform 16 | try: 17 | self.samples = torch.load('ms1m_cache.pth') 18 | except: 19 | self.samples = [(str(off), int(self.read_metadata(off)[1])) 20 | for off, i in zip(offsets, range(max_id))] 21 | torch.save(self.samples, 'ms1m_cache.pth') 22 | self.classes = [str(i) for i in range(85742)] 23 | self.class_to_idx = {k: i for i, k in enumerate(self.classes)} 24 | self.imgs = self.samples 25 | 26 | def __len__(self) -> int: 27 | return len(self.samples) 28 | 29 | @staticmethod 30 | def read_idx(idx_file: str) -> List[int]: 31 | indices = [] 32 | with open(idx_file, 'r') as f: 33 | for line in f.readlines(): 34 | n, offset = line.strip().split('\t') 35 | indices.append(int(offset)) 36 | return indices 37 | 38 | def read_metadata(self, offset: int) -> Tuple[bytes, Any]: 39 | with open(self.rec_file, 'rb') as rec_handle: 40 | rec_handle.seek(offset) 41 | magic, lrec = struct.unpack('> 29 44 | assert cflag == 0 45 | length = lrec & ~(3 << 29) 46 | header_sz = struct.calcsize('IfQQ') 47 | flag, label, id1, id2 = struct.unpack('IfQQ', 48 | rec_handle.read(header_sz)) 49 | if flag > 0: 50 | label = struct.unpack('f' * flag, rec_handle.read(4 * flag)) 51 | header_sz -= 4 * flag 52 | img_bytes = rec_handle.read(length - header_sz) 53 | return img_bytes, label 54 | 55 | def __getitem__(self, i: int) -> Tuple[Any, int]: 56 | offset, label = self.samples[i] 57 | img_bytes, _ = self.read_metadata(int(offset)) 58 | assert img_bytes[:3] == b'\xff\xd8\xff' 59 | 60 | with BytesIO(img_bytes) as dat: 61 | img = Image.open(dat).convert('RGB') 62 | 63 | if self.transform is not None: 64 | img = self.transform(img) 65 | 66 | return img, int(label) 67 | 68 | def __repr__(self) -> str: 69 | return (f"MS1M Dataset:\n" 70 | f" n_samples: {len(self.samples)}\n" 71 | f" n_classes: {len(self.classes)}\n") 72 | 
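# Usage sketch (editor's illustration; the .rec/.idx paths are placeholders):
#
#     import torchvision.transforms as TF
#     from torch.utils.data import DataLoader
#
#     ds = MS1M('faces_emore/train.rec', 'faces_emore/train.idx',
#               transform=TF.Compose([TF.Resize(112), TF.ToTensor()]))
#     dl = DataLoader(ds, batch_size=64, shuffle=True, num_workers=4)
#     img, label = ds[0]  # transformed image and an integer identity id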
-------------------------------------------------------------------------------- /examples/gan.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | 4 | import torch 5 | 6 | from torchvision.datasets import MNIST, CIFAR10 7 | import torchvision.transforms as TF 8 | 9 | import torchelie as tch 10 | import torchelie.loss.gan.hinge as gan_loss 11 | from torchelie.recipes.gan import GANRecipe 12 | import torchelie.callbacks as tcb 13 | from torchelie.recipes import Recipe 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--cpu', action='store_true') 17 | opts = parser.parse_args() 18 | 19 | device = 'cpu' if opts.cpu else 'cuda' 20 | BS = 32 21 | 22 | tfms = TF.Compose([ 23 | TF.Resize(64), 24 | tch.transforms.AdaptPad((64, 64)), 25 | TF.RandomHorizontalFlip(), 26 | TF.ToTensor()]) 27 | ds = CIFAR10('~/.cache/torch/cifar10', download=True, transform=tfms) 28 | dl = torch.utils.data.DataLoader(ds, 29 | num_workers=4, 30 | batch_size=BS, 31 | shuffle=True) 32 | 33 | 34 | def train_net(Gen, Discr): 35 | G = Gen(in_noise=128, out_ch=3) 36 | G_polyak = copy.deepcopy(G).eval() 37 | D = Discr() 38 | print(G) 39 | print(D) 40 | 41 | def G_fun(batch): 42 | z = torch.randn(BS, 128, device=device) 43 | fake = G(z) 44 | preds = D(fake * 2 - 1).squeeze() 45 | loss = gan_loss.generated(preds) 46 | loss.backward() 47 | return {'loss': loss.item(), 'imgs': fake.detach()} 48 | 49 | def G_polyak_fun(batch): 50 | z = torch.randn(BS, 128, device=device) 51 | fake = G_polyak(z) 52 | return {'imgs': fake.detach()} 53 | 54 | def D_fun(batch): 55 | z = torch.randn(BS, 128, device=device) 56 | fake = G(z) 57 | fake_loss = gan_loss.fake(D(fake * 2 - 1)) 58 | fake_loss.backward() 59 | 60 | x = batch[0] 61 | 62 | real_loss = gan_loss.real(D(x * 2 - 1)) 63 | real_loss.backward() 64 | 65 | loss = real_loss.item() + fake_loss.item() 66 | return {'loss': loss, 'real_loss': real_loss.item(), 'fake_loss': 67 | fake_loss.item()} 68 | 69 | loop = GANRecipe(G, D, G_fun, D_fun, G_polyak_fun, dl, log_every=100).to(device) 70 | loop.register('polyak', G_polyak) 71 | loop.G_loop.callbacks.add_callbacks([ 72 | tcb.Optimizer(tch.optim.RAdamW(G.parameters(), lr=1e-4, betas=(0., 0.99))), 73 | tcb.Polyak(G, G_polyak), 74 | ]) 75 | loop.register('G_polyak', G_polyak) 76 | loop.callbacks.add_callbacks([ 77 | tcb.Log('batch.0', 'x'), 78 | tcb.WindowedMetricAvg('real_loss'), 79 | tcb.WindowedMetricAvg('fake_loss'), 80 | tcb.Optimizer(tch.optim.RAdamW(D.parameters(), lr=4e-4, betas=(0., 0.99))), 81 | ]) 82 | loop.test_loop.callbacks.add_callbacks([ 83 | tcb.Log('imgs', 'polyak_imgs'), 84 | tcb.VisdomLogger('main', prefix='test') 85 | ]) 86 | loop.to(device).run(100) 87 | 88 | 89 | train_net(tch.models.autogan_64, tch.models.snres_discr_4l) 90 | -------------------------------------------------------------------------------- /torchelie/models/alexnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torchelie.nn as tnn 3 | import torch.nn as nn 4 | 5 | from .registry import register 6 | from .classifier import ClassificationHead 7 | 8 | __all__ = ['AlexNet', 'alexnet', 'alexnet_bn', 'ZFNet', 'zfnet', 'zfnet_bn'] 9 | 10 | 11 | class AlexNet(tnn.CondSeq): 12 | 13 | def __init__(self, num_classes): 14 | super().__init__() 15 | self.features = tnn.CondSeq( 16 | OrderedDict([ 17 | # 224 -> 112 -> 56 18 | ('conv1', tnn.ConvBlock(3, 64, kernel_size=11, stride=4)), 19 | # 56 -> 
28 20 | ('pool1', nn.MaxPool2d(3, 2, 1)), 21 | ('conv2', tnn.ConvBlock(64, 192, kernel_size=5)), 22 | # 28 -> 14 23 | ('pool2', nn.MaxPool2d(3, 2, 1)), 24 | ('conv3', tnn.ConvBlock(192, 384, 3)), 25 | ('conv4', tnn.ConvBlock(384, 256, 3)), 26 | ('conv5', tnn.ConvBlock(256, 256, 3)), 27 | # 28 -> 14 28 | ('pool3', nn.MaxPool2d(3, 2, 1)), 29 | ])) 30 | self.classifier = ClassificationHead(256, num_classes) 31 | self.classifier = self.classifier.to_two_layers(4096).set_pool_size(7) 32 | 33 | def remove_batchnorm(self): 34 | for m in self.features: 35 | if isinstance(m, tnn.ConvBlock): 36 | m.remove_batchnorm() 37 | return self 38 | 39 | 40 | @register 41 | def alexnet(num_classes): 42 | return AlexNet(num_classes).remove_batchnorm() 43 | 44 | 45 | @register 46 | def alexnet_bn(num_classes): 47 | return AlexNet(num_classes) 48 | 49 | 50 | class ZFNet(tnn.CondSeq): 51 | 52 | def __init__(self, num_classes): 53 | super().__init__() 54 | self.features = tnn.CondSeq( 55 | OrderedDict([ 56 | # 224 -> 112 57 | ('conv1', tnn.ConvBlock(3, 96, kernel_size=7, stride=2)), 58 | # 112 -> 56 59 | ('pool1', nn.MaxPool2d(3, 2, 1)), 60 | # 56 -> 28 61 | ('conv2', tnn.ConvBlock(96, 256, kernel_size=5, stride=2)), 62 | # 28 -> 14 63 | ('pool2', nn.MaxPool2d(3, 2, 1)), 64 | ('conv3', tnn.ConvBlock(256, 384, 3)), 65 | ('conv4', tnn.ConvBlock(384, 256, 3)), 66 | ('conv5', tnn.ConvBlock(256, 256, 3)), 67 | # 28 -> 14 68 | ('pool3', nn.MaxPool2d(3, 2, 1)), 69 | ])) 70 | self.classifier = ClassificationHead(256, num_classes) 71 | self.classifier = self.classifier.to_two_layers(4096).set_pool_size(7) 72 | 73 | def remove_batchnorm(self): 74 | for m in self.features: 75 | if isinstance(m, tnn.ConvBlock): 76 | m.remove_batchnorm() 77 | return self 78 | 79 | 80 | @register 81 | def zfnet(num_classes): 82 | return ZFNet(num_classes).remove_batchnorm() 83 | 84 | 85 | @register 86 | def zfnet_bn(num_classes): 87 | return ZFNet(num_classes) 88 | -------------------------------------------------------------------------------- /docs/nn.rst: -------------------------------------------------------------------------------- 1 | torchelie.nn 2 | ============ 3 | 4 | Convolutions 5 | ~~~~~~~~~~~~ 6 | 7 | .. currentmodule:: torchelie.nn 8 | .. autosummary:: 9 | :toctree: generated 10 | :template: klass.rst 11 | :nosignatures: 12 | 13 | Conv2d 14 | Conv3x3 15 | Conv1x1 16 | MaskedConv2d 17 | TopLeftConv2d 18 | 19 | Normalization 20 | ~~~~~~~~~~~~~ 21 | 22 | .. autosummary:: 23 | :toctree: generated 24 | :template: klass.rst 25 | :nosignatures: 26 | 27 | AdaIN2d 28 | FiLM2d 29 | PixelNorm 30 | ImageNetInputNorm 31 | ConditionalBN2d 32 | Spade2d 33 | AttenNorm2d 34 | GhostBatchNorm2d 35 | 36 | Misc 37 | ~~~~ 38 | 39 | .. autosummary:: 40 | :toctree: generated 41 | :template: klass.rst 42 | :nosignatures: 43 | 44 | VQ 45 | MultiVQ 46 | Noise 47 | Debug 48 | Dummy 49 | Lambda 50 | Reshape 51 | Interpolate2d 52 | InterpolateBilinear2d 53 | AdaptiveConcatPool2d 54 | ModulatedConv 55 | SelfAttention2d 56 | GaussianPriorFunc 57 | UnitGaussianPrior 58 | InformationBottleneck 59 | Const 60 | SinePositionEncoding2d 61 | MinibatchStddev 62 | 63 | Blocks 64 | ~~~~~~ 65 | 66 | .. 
autosummary:: 67 | :toctree: generated 68 | :template: klass.rst 69 | :nosignatures: 70 | 71 | ConvBlock 72 | MConvNormReLU 73 | MConvBNReLU 74 | SpadeResBlock 75 | AutoGANGenBlock 76 | ResidualDiscrBlock 77 | StyleGAN2Block 78 | SEBlock 79 | PreactResBlock 80 | PreactResBlockBottleneck 81 | ResBlock 82 | ResBlockBottleneck 83 | ConvDeconvBlock 84 | UBlock 85 | 86 | 87 | Sequential 88 | ~~~~~~~~~~ 89 | 90 | .. autosummary:: 91 | :toctree: generated 92 | :template: klass.rst 93 | :nosignatures: 94 | 95 | WithSavedActivations 96 | CondSeq 97 | ModuleGraph 98 | 99 | Activations 100 | ~~~~~~~~~~~ 101 | 102 | .. autosummary:: 103 | :toctree: generated 104 | :template: klass.rst 105 | :nosignatures: 106 | 107 | HardSigmoid 108 | HardSwish 109 | 110 | torchelie.nn.utils 111 | ================== 112 | 113 | .. currentmodule:: torchelie.nn.utils 114 | 115 | .. autosummary:: 116 | :toctree: generated 117 | :template: klass.rst 118 | :nosignatures: 119 | 120 | receptive_field_for 121 | 122 | Model edition 123 | ~~~~~~~~~~~~~ 124 | 125 | .. autosummary:: 126 | :toctree: generated 127 | :template: klass.rst 128 | :nosignatures: 129 | 130 | edit_model 131 | insert_after 132 | insert_before 133 | make_leaky 134 | remove_batchnorm 135 | 136 | Lambda 137 | ~~~~~~ 138 | 139 | .. autosummary:: 140 | :toctree: generated 141 | :template: klass.rst 142 | :nosignatures: 143 | 144 | WeightLambda 145 | weight_lambda 146 | remove_weight_lambda 147 | 148 | Weight normalization / equalized learning rate 149 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 150 | 151 | .. autosummary:: 152 | :toctree: generated 153 | :template: klass.rst 154 | :nosignatures: 155 | 156 | weight_norm_and_equal_lr 157 | remove_weight_norm_and_equal_lr 158 | remove_weight_scale 159 | weight_scale 160 | net_to_equal_lr 161 | net_remove_weight_norm_and_equal_lr 162 | 163 | -------------------------------------------------------------------------------- /torchelie/nn/graph.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import Union, Tuple, List 3 | from torchelie.utils import experimental 4 | 5 | 6 | def tup(x): 7 | if isinstance(x, (tuple, list)): 8 | return list(x) 9 | return [x] 10 | 11 | 12 | ArgNames = Union[str, List[str]] 13 | NamedModule = Tuple[str, nn.Module] 14 | 15 | 16 | class ModuleGraph(nn.Sequential): 17 | """ 18 | Allows description of networks as computation graphs. The graph is 19 | constructed by labelling inputs and outputs of each node. Each node will be 20 | ran in declaration order, fetching its input values from a pool of named 21 | values populated from previous node's output values and keyword arguments 22 | in forward. 23 | 24 | Simple example: 25 | 26 | >>> m = tnn.ModuleGraph(outputs='y') 27 | >>> m.add_operation( 28 | inputs=['x'], 29 | operation=nn.Linear(10, 20), 30 | name='linear', 31 | outputs=['y']) 32 | >>> m(x=torch.randn(1, 10)) 33 | 34 | 35 | Multiple inputs example: 36 | 37 | If a layer takes more than 1 input, labels can be a tuple or a list of 38 | labels instead. The same applies if a module returns more than 1 output 39 | values. 
40 | 41 | >>> m = tnn.ModuleGraph(outputs=['x1', 'y']) 42 | >>> m.add_operation( 43 | inputs=['x0'], 44 | operation=nn.Linear(10, 20) 45 | name='linear', 46 | outputs=['x1']) 47 | >>> m.add_operation( 48 | inputs=['x1', 'z'], 49 | operation=nn.AdaIN2d(20, 3) 50 | name='adain', 51 | outputs=['y']) 52 | >>> m(x0=torch.randn(1, 10), z=torch.randn(1, 3))['y'] 53 | 54 | """ 55 | def __init__(self, outputs: Union[str, List[str]]) -> None: 56 | super().__init__() 57 | self.ins: List[List[str]] = [] 58 | self.outs: List[List[str]] = [] 59 | 60 | self.outputs = outputs 61 | 62 | def add_operation(self, inputs: List[str], outputs: List[str], name: str, 63 | operation: nn.Module) -> 'ModuleGraph': 64 | self.ins.append(inputs) 65 | self.outs.append(outputs) 66 | self.add_module(name, operation) 67 | return self 68 | 69 | def forward(self, **args): 70 | variables = dict(args) 71 | 72 | for i_names, f, o_names in zip(self.ins, self._modules.values(), 73 | self.outs): 74 | ins = [variables[k] for k in i_names] 75 | outs = tup(f(*ins)) 76 | for o, k in zip(outs, o_names): 77 | variables[k] = o 78 | 79 | if isinstance(self.outputs, str): 80 | return variables[self.outputs] 81 | return {k: variables[k] for k in self.outputs} 82 | 83 | @experimental 84 | def to_dot(self) -> str: 85 | txt = '' 86 | for i_names, f_nm, o_names in zip(self.ins, self._modules.keys(), 87 | self.outs): 88 | for k in i_names: 89 | txt += f'{k} -> {f_nm};\n' 90 | for k in o_names: 91 | txt += f'{f_nm} -> {k};\n' 92 | txt += f'{f_nm} [shape=square];\n' 93 | return txt 94 | -------------------------------------------------------------------------------- /torchelie/models/pix2pix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, cast 3 | 4 | import torchelie.nn as tnn 5 | import torchelie.utils as tu 6 | import torch.nn as nn 7 | from .unet import UNet 8 | 9 | 10 | class Pix2PixGenerator(UNet): 11 | """ 12 | UNet generator from Pix2Pix. Dropout layers have been substitued with Noise 13 | injections from StyleGAN2. 14 | 15 | Args: 16 | arch (List[int]): the number of channel for each depth level of the 17 | UNet. 
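    A quick sketch (illustrative; in practice use one of the ``pix2pix_*``
    constructors defined below rather than hand-tuning ``arch``):

    >>> G = pix2pix_256()
    >>> fake = G(torch.rand(1, 3, 256, 256))  # Sigmoid head, values in [0, 1]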
18 | """ 19 | 20 | def __init__(self, arch: List[int]) -> None: 21 | super().__init__(arch, 3) 22 | self.remove_first_batchnorm() 23 | 24 | self.features.input = tnn.ConvBlock(3, int(arch[0]), 3) 25 | self.features.input.remove_batchnorm() 26 | 27 | encdec = cast(nn.Module, self.features.encoder_decoder) 28 | for m in encdec.modules(): 29 | if isinstance(m, tnn.UBlock): 30 | m.to_bilinear_sampling() 31 | m.set_encoder_num_layers(1) 32 | m.set_decoder_num_layers(1) 33 | 34 | tnn.utils.make_leaky(self) 35 | self.classifier.relu = nn.Sigmoid() 36 | 37 | self._add_noise() 38 | 39 | def _add_noise(self): 40 | layers = [self.features.encoder_decoder] 41 | while hasattr(layers[-1].inner, 'inner'): 42 | layers.append(layers[-1].inner) 43 | tnn.utils.insert_after(layers[-1].inner, 'norm', tnn.Noise(1, True), 44 | 'noise') 45 | 46 | for m in layers: 47 | tnn.utils.insert_after( 48 | m.out_conv.conv_0, 'norm', 49 | tnn.Noise(m.out_conv.conv_0.out_channels, True), 'noise') 50 | 51 | def to_equal_lr(self) -> 'Pix2PixGenerator': 52 | return tnn.utils.net_to_equal_lr(self) 53 | 54 | def set_padding_mode(self, mode: str) -> 'Pix2PixGenerator': 55 | for m in self.modules(): 56 | if isinstance(m, nn.Conv2d): 57 | m.padding_mode = mode 58 | return self 59 | 60 | @torch.no_grad() 61 | def to_instance_norm(self, affine: bool = True) -> 'Pix2PixGenerator': 62 | """ 63 | Pix2Pix sometimes uses batch size 1, similar to instance norm. 64 | """ 65 | 66 | def to_instancenorm(m): 67 | if isinstance(m, nn.BatchNorm2d): 68 | return nn.InstanceNorm2d(m.num_features, affine=affine) 69 | return m 70 | 71 | tnn.utils.edit_model(self, to_instancenorm) 72 | 73 | return self 74 | 75 | 76 | def pix2pix_256() -> Pix2PixGenerator: 77 | """ 78 | The architecture used in `Pix2Pix `_, 79 | able to train on 256x256 or 512x512 images. 80 | """ 81 | return Pix2PixGenerator([32, 64, 128, 256, 512, 512, 512, 512]) 82 | 83 | 84 | def pix2pix_128() -> Pix2PixGenerator: 85 | """ 86 | The architecture used in `Pix2Pix `_, 87 | able to train on 128x128 or 512x512 images. 88 | """ 89 | return Pix2PixGenerator([32, 64, 128, 256, 512, 512, 512]) 90 | 91 | 92 | def pix2pix_dev() -> Pix2PixGenerator: 93 | """ 94 | A version of pix2pix_256 with less filter to use less memory and compute. 
95 | """ 96 | return Pix2PixGenerator([32, 64, 128, 128, 256, 256, 512, 512]) 97 | -------------------------------------------------------------------------------- /tests/test_nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchelie.nn import * 4 | import torchelie.nn.functional as tnnf 5 | 6 | 7 | def test_adain(): 8 | m = torch.jit.script(AdaIN2d(16, 8)) 9 | m(torch.randn(5, 16, 8, 8), torch.randn(5, 8)) 10 | 11 | m = torch.jit.script(AdaIN2d(16, 8)) 12 | m.condition(torch.randn(5, 8)) 13 | m(torch.randn(5, 16, 8, 8)) 14 | 15 | 16 | def test_film(): 17 | m = torch.jit.script(FiLM2d(16, 8)) 18 | m(torch.randn(5, 16, 8, 8), torch.randn(5, 8)) 19 | 20 | 21 | def test_bn(): 22 | for M in [NoAffineBN2d, NoAffineMABN2d, BatchNorm2d, MovingAverageBN2d]: 23 | m = torch.jit.script(M(16)) 24 | m(torch.randn(5, 16, 8, 8)) 25 | 26 | for M in [ConditionalBN2d, ConditionalMABN2d]: 27 | m = M(16, 8) 28 | m(torch.randn(5, 16, 8, 8), torch.randn(5, 8)) 29 | 30 | m = torch.jit.script(PixelNorm()) 31 | m(torch.randn(5, 16, 8, 8)) 32 | 33 | m = Lambda(lambda x: x + 1) 34 | m(torch.zeros(1)) 35 | 36 | 37 | def test_spade(): 38 | for M in [Spade2d, SpadeMA2d]: 39 | m = M(16, 8, 4) 40 | m(torch.randn(5, 16, 8, 8), torch.randn(5, 8, 8, 8)) 41 | 42 | 43 | def test_attnnorm(): 44 | m = AttenNorm2d(16, 8) 45 | m(torch.randn(1, 16, 8, 8)) 46 | 47 | 48 | def test_blocks(): 49 | m = MConvBNReLU(4, 8, 3) 50 | m(torch.randn(1, 4, 8, 8)) 51 | 52 | m = ConvBlock(4, 8, 3) 53 | m(torch.randn(1, 4, 8, 8)) 54 | 55 | m = ConvBlock(4, 8, 3).remove_batchnorm().leaky() 56 | m(torch.randn(1, 4, 8, 8)) 57 | 58 | m = ResBlock(4, 8, 1) 59 | m(torch.randn(1, 4, 8, 8)) 60 | 61 | m = PreactResBlock(4, 8, 1) 62 | m(torch.randn(1, 4, 8, 8)) 63 | 64 | m = SpadeResBlock(4, 4, 3, 1) 65 | m(torch.randn(1, 4, 8, 8), torch.randn(1, 3, 8, 8)) 66 | 67 | m = SpadeResBlock(4, 8, 3, 1) 68 | m(torch.randn(1, 4, 8, 8), torch.randn(1, 3, 8, 8)) 69 | 70 | m = AutoGANGenBlock(6, 3, []) 71 | m(torch.randn(1, 6, 8, 8)) 72 | 73 | m = AutoGANGenBlock(3, 3, []) 74 | m(torch.randn(1, 3, 8, 8)) 75 | 76 | m = AutoGANGenBlock(3, 3, [5]) 77 | m(torch.randn(1, 3, 8, 8), [torch.randn(1, 5, 4, 4)]) 78 | 79 | m = ResidualDiscrBlock(6, 3) 80 | m(torch.randn(1, 6, 8, 8)) 81 | 82 | 83 | def test_vq(): 84 | m = VQ(8, 16) 85 | m(torch.randn(10, 8)) 86 | 87 | m = VQ(8, 16) 88 | m(torch.randn(10, 8)) 89 | 90 | m = VQ(8, 16, init_mode='first') 91 | m(torch.randn(10, 8)) 92 | 93 | m = VQ(8, 16, init_mode='first') 94 | m(torch.randn(10, 8)) 95 | 96 | 97 | def test_tfms(): 98 | m = torch.jit.script(ImageNetInputNorm()) 99 | m(torch.randn(1, 3, 8, 8)) 100 | 101 | 102 | def test_maskedconv(): 103 | m = MaskedConv2d(3, 8, 3, center=True) 104 | m(torch.randn(1, 3, 8, 8)) 105 | m = TopLeftConv2d(3, 8, 3, center=True) 106 | m(torch.randn(1, 3, 8, 8)) 107 | 108 | 109 | def test_misc(): 110 | m = torch.jit.script(Noise(1)) 111 | m(torch.randn(1, 3, 8, 8)) 112 | 113 | m = Debug('test') 114 | m(torch.randn(1, 3, 8, 8)) 115 | 116 | m = torch.jit.script(Reshape(16)) 117 | m(torch.randn(1, 4, 4)) 118 | 119 | 120 | def test_laplacian(): 121 | x = torch.randn(5, 3, 32, 32) 122 | tnnf.combine_laplacians(tnnf.laplacian(x)) 123 | -------------------------------------------------------------------------------- /tests/test_recipes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchelie as tch 3 | import torch.nn as nn 4 | from torch.utils.data import DataLoader 
5 | from torchvision.transforms import ToPILImage 6 | 7 | from torchelie.recipes import CrossEntropyClassification 8 | from torchelie.recipes import DeepDream 9 | from torchelie.recipes import FeatureVis 10 | from torchelie.recipes import NeuralStyle 11 | from torchelie.recipes import TrainAndCall 12 | import torchelie.callbacks as tcb 13 | 14 | 15 | class FakeData: 16 | 17 | def __len__(self): 18 | return 10 19 | 20 | def __getitem__(self, i): 21 | cls = 0 if i < 5 else 1 22 | return torch.randn(10) + cls * 3, cls 23 | 24 | 25 | class FakeImg: 26 | 27 | def __len__(self): 28 | return 10 29 | 30 | def __getitem__(self, i): 31 | cls = 0 if i < 5 else 1 32 | return torch.randn(1, 4, 4) + cls * 3, cls 33 | 34 | 35 | def test_classification(): 36 | trainloader = DataLoader(FakeImg(), 4, shuffle=True) 37 | testloader = DataLoader(FakeImg(), 4, shuffle=True) 38 | 39 | model = nn.Sequential(tch.nn.Reshape(-1), nn.Linear(16, 2)) 40 | 41 | clf_recipe = CrossEntropyClassification(model, trainloader, testloader, 42 | ['foo', 'bar']) 43 | clf_recipe.run(1) 44 | 45 | 46 | def test_deepdream(): 47 | model = nn.Sequential(nn.Conv2d(3, 6, 3)) 48 | dd = DeepDream(model, '0') 49 | dd.fit(ToPILImage()(torch.randn(3, 128, 128)), 1) 50 | 51 | 52 | def test_featurevis(): 53 | model = nn.Sequential(nn.Conv2d(3, 6, 3)) 54 | dd = FeatureVis(model, '0', 229, lr=1) 55 | dd.fit(1, 0) 56 | 57 | 58 | def test_neuralstyle(): 59 | stylizer = NeuralStyle() 60 | 61 | content = ToPILImage()(torch.randn(3, 64, 64)) 62 | style_img = ToPILImage()(torch.randn(3, 64, 64)) 63 | 64 | stylizer.fit(1, 65 | content, 66 | style_img, 67 | 1, 68 | second_scale_ratio=1, 69 | content_layers=['conv1_1']) 70 | 71 | 72 | def test_trainandcall(): 73 | model = nn.Linear(10, 2) 74 | 75 | def train_step(batch): 76 | x, y = batch 77 | out = model(x) 78 | loss = torch.nn.functional.cross_entropy(out, y) 79 | loss.backward() 80 | return {'loss': loss} 81 | 82 | def after_train(): 83 | print('Yup.') 84 | return {} 85 | 86 | trainloader = DataLoader(FakeData(), 4, shuffle=True) 87 | trainer = TrainAndCall(model, train_step, after_train, trainloader) 88 | trainer.callbacks.add_callbacks( 89 | [tcb.Optimizer(torch.optim.Adam(model.parameters(), lr=1e-3))]) 90 | 91 | trainer.run(1) 92 | 93 | 94 | def test_callbacks(): 95 | from torchelie.recipes import Recipe 96 | 97 | def train(b): 98 | x, y = b 99 | return {'pred': torch.randn(y.shape[0])} 100 | 101 | m = nn.Linear(2, 2) 102 | 103 | r = Recipe(train, DataLoader(FakeImg(), 4)) 104 | r.callbacks.add_callbacks([ 105 | tcb.Counter(), 106 | tcb.AccAvg(), 107 | tcb.Checkpoint('/tmp/m', m), 108 | tcb.ClassificationInspector(1, ['1', '2']), 109 | tcb.ConfusionMatrix(['one', 'two']), 110 | tcb.ImageGradientVis(), 111 | tcb.MetricsTable(), 112 | ]) 113 | r.run(1) 114 | -------------------------------------------------------------------------------- /scripts/stylevgg.py: -------------------------------------------------------------------------------- 1 | """ 2 | this script equalizes the mean and std of the convolutions of a VGG network. 3 | It is used to make the VGG network used in the Gatys et al. style transfer 4 | paper to be compatible with the VGG network used in the Johnson et al. paper. 5 | It equalizes the contribution of each layer to the loss. 
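A typical invocation looks like this (the constructor name and the dataset
path below are placeholders, not values from the original docs):

    python3 scripts/stylevgg.py vgg19 /path/to/imagenet

The general usage is: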
6 | 7 | ./stylevgg 8 | 9 | must be a constructor in torchelie.models 10 | """ 11 | import torch 12 | from torchelie.datasets import FastImageFolder 13 | import torchvision.transforms as TF 14 | import torchelie.models as tchm 15 | import torchelie as tch 16 | import sys 17 | 18 | torch.autograd.set_grad_enabled(False) 19 | 20 | imagenet_path = sys.argv[2] 21 | 22 | model = sys.argv[1] 23 | m = tchm.__dict__[model](1000, pretrained='classification/imagenet') 24 | del m.classifier 25 | m.cuda() 26 | m.eval() 27 | 28 | ds = FastImageFolder(imagenet_path, 29 | transform=TF.Compose([ 30 | TF.Resize(256), 31 | TF.CenterCrop(224), 32 | TF.ToTensor(), 33 | tch.nn.ImageNetInputNorm() 34 | ])) 35 | 36 | batches = [ 37 | b[0] for _, b in zip( 38 | range(200), torch.utils.data.DataLoader( 39 | ds, batch_size=320, shuffle=True)) 40 | ] 41 | 42 | batch = batches[0].cuda() 43 | 44 | 45 | def flatvgg(): 46 | layers = [] 47 | 48 | def _rec(m): 49 | if len(list(m.children())) == 0: 50 | layers.append(m) 51 | else: 52 | for mm in m.children(): 53 | _rec(mm) 54 | 55 | _rec(m.features) 56 | return torch.nn.Sequential(*layers) 57 | 58 | 59 | idxs = [ 60 | i for i, nm in enumerate(dict(m.features.named_children()).keys()) 61 | if 'conv' in nm 62 | ] 63 | flat = flatvgg() 64 | 65 | flatidxs = [i for i, l in enumerate(flat) if isinstance(l, torch.nn.Conv2d)] 66 | print(flatidxs) 67 | #flatidxs.append(len(flat)) 68 | print(dict(m.features.named_children()).keys()) 69 | 70 | print('before') 71 | for i in idxs: 72 | with torch.cuda.amp.autocast(): 73 | out = m.features[:i + 1](batch) 74 | mean = out.cpu().float().mean(dim=(0, 2, 3)) 75 | del out 76 | print(mean.mean(), mean.std()) 77 | 78 | prev_mean = torch.tensor([1, 1, 1]).cuda() 79 | for i in range(len(flatidxs)): 80 | print('computing', i) 81 | ms = [] 82 | for b in batches: 83 | with torch.cuda.amp.autocast(): 84 | out = flat[:flatidxs[i] + 2](b.cuda()) 85 | mean = out.cpu().float().mean(dim=(0, 2, 3)) 86 | del out 87 | ms.append(mean) 88 | mean = torch.stack(ms, dim=0).mean(0).cuda() 89 | flat[flatidxs[i]].weight.data *= (prev_mean[None, :, None, None] 90 | / mean[:, None, None, None]) 91 | flat[flatidxs[i]].bias.data /= mean 92 | prev_mean = mean 93 | 94 | print('after') 95 | for i in idxs: 96 | with torch.cuda.amp.autocast(): 97 | out = m.features[:i + 1](batch) 98 | mean = out.cpu().float().mean(dim=(0, 2, 3)) 99 | del out 100 | print(mean.mean(), mean.std()) 101 | 102 | ref = tchm.__dict__[model]( 103 | 1000, pretrained='classification/imagenet').features.cuda()(batch[:128]) 104 | print((m.features(batch[:128]) 105 | - ref / prev_mean[None, :, None, None]).abs().mean().item()) 106 | torch.save(m.state_dict(), f'{model}.pth') 107 | -------------------------------------------------------------------------------- /torchelie/models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchelie.nn as tnn 4 | from torchelie.utils import kaiming 5 | from typing import cast 6 | from .classifier import ClassificationHead 7 | from .registry import register 8 | 9 | 10 | class VGG(tnn.CondSeq): 11 | """ 12 | Construct a VGG-like model. The architecture is composed of either the 13 | number of channels or 'M' for a maxpool operation. 14 | 15 | This creates a standard VGG11 with 10 classes. 16 | 17 | .. 
18 | VGG([64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 19 | 10) 20 | """ 21 | 22 | def __init__(self, arch: list, num_classes: int) -> None: 23 | super().__init__() 24 | self.arch = arch 25 | in_ch = 3 26 | self.in_channels = in_ch 27 | 28 | feats = tnn.CondSeq() 29 | block_num = 1 30 | conv_num = 1 31 | for layer in arch: 32 | if layer == 'M': 33 | feats.add_module(f'pool_{block_num}', nn.MaxPool2d(2, 2)) 34 | block_num += 1 35 | conv_num = 1 36 | else: 37 | ch = cast(int, layer) 38 | feats.add_module(f'conv_{block_num}_{conv_num}', 39 | tnn.ConvBlock(in_ch, ch, 3).remove_batchnorm()) 40 | in_ch = ch 41 | conv_num += 1 42 | self.out_channels = ch 43 | 44 | self.features = feats 45 | self.classifier = ClassificationHead(self.out_channels, num_classes) 46 | self.classifier.to_vgg_style(4096) 47 | 48 | def add_batchnorm(self, remove_first=False) -> 'VGG': 49 | for m in self.features: 50 | if isinstance(m, tnn.ConvBlock): 51 | m.restore_batchnorm() 52 | 53 | if remove_first: 54 | self.features.conv_1_1.remove_batchnorm() 55 | 56 | return self 57 | 58 | def set_input_specs(self, in_channels: int) -> 'VGG': 59 | c1 = self.features.conv_1_1 60 | assert isinstance(c1, tnn.ConvBlock) 61 | c1.conv = kaiming(tnn.Conv3x3(in_channels, c1.conv.out_channels)) 62 | return self 63 | 64 | 65 | @register 66 | def vgg11(num_classes: int) -> 'VGG': 67 | return VGG([64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 68 | num_classes) 69 | 70 | 71 | @register 72 | def vgg13(num_classes: int) -> 'VGG': 73 | return VGG([ 74 | 64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M' 75 | ], num_classes) 76 | 77 | 78 | @register 79 | def vgg16(num_classes: int) -> 'VGG': 80 | return VGG([ 81 | 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 82 | 512, 512, 'M' 83 | ], num_classes) 84 | 85 | 86 | @register 87 | def vgg19(num_classes: int) -> 'VGG': 88 | return VGG([ 89 | 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 90 | 'M', 512, 512, 512, 512, 'M' 91 | ], num_classes) 92 | 93 | 94 | @register 95 | def vgg11_bn(num_classes: int) -> 'VGG': 96 | return vgg11(num_classes).add_batchnorm() 97 | 98 | 99 | @register 100 | def vgg13_bn(num_classes: int) -> 'VGG': 101 | return vgg13(num_classes).add_batchnorm() 102 | 103 | 104 | @register 105 | def vgg16_bn(num_classes: int) -> 'VGG': 106 | return vgg16(num_classes).add_batchnorm() 107 | 108 | 109 | @register 110 | def vgg19_bn(num_classes: int) -> 'VGG': 111 | return vgg19(num_classes).add_batchnorm() 112 | -------------------------------------------------------------------------------- /torchelie/models/hourglass.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchelie.utils as tu 4 | import torchelie.nn as tnn 5 | 6 | 7 | @tu.experimental 8 | class Hourglass(nn.Module): 9 | """ 10 | Hourglass model from Deep Image Prior. 
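    A minimal usage sketch (the noise shape is illustrative; the number of
    input channels must match ``noise_dim``)::

        net = Hourglass(noise_dim=32)
        out = net(torch.randn(1, 32, 64, 64))  # (1, 3, 64, 64), sigmoid output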
11 | """ 12 | 13 | def __init__(self, 14 | noise_dim=32, 15 | down_channels=[128, 128, 128, 128, 128], 16 | skip_channels=4, 17 | down_kernel=[3, 3, 3, 3, 3], 18 | up_kernel=[3, 3, 3, 3, 3], 19 | upsampling='bilinear') -> None: 20 | super().__init__() 21 | 22 | assert (len(down_channels) == len(down_kernel)), (len(down_channels), 23 | len(down_kernel)) 24 | assert (len(down_channels) == len(up_kernel)), (len(down_channels), 25 | len(up_kernel)) 26 | 27 | self.upsampling = upsampling 28 | self.downs = nn.ModuleList( 29 | [self.down(noise_dim, down_channels[0], down_kernel[0])] + [ 30 | self.down(d1, d2, down_kernel[0]) for d1, d2, k in zip( 31 | down_channels[:-1], down_channels[1:], down_kernel[1:]) 32 | ]) 33 | 34 | self.ups = nn.ModuleList([ 35 | self.up(down_channels[-1] 36 | + skip_channels, down_channels[-1], up_kernel[0]) 37 | ] + [ 38 | self.up(d1 + skip_channels, d2, k) for d1, d2, k in zip( 39 | down_channels[:0:-1], down_channels[-2::-1], up_kernel[1:]) 40 | ]) 41 | 42 | if skip_channels != 0: 43 | self.skips = nn.ModuleList( 44 | [self.skip(d, skip_channels) for d in down_channels]) 45 | 46 | self.to_rgb = tnn.ConvBlock(down_channels[0], 3, up_kernel[-1]) 47 | self.to_rgb.no_relu() 48 | 49 | tnn.utils.make_leaky(self) 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | m.padding_mode = 'reflect' 53 | 54 | def down(self, in_ch, out_ch, ks) -> nn.Sequential: 55 | return nn.Sequential( 56 | tnn.ConvBlock(in_ch, out_ch, ks, stride=2), 57 | tnn.ConvBlock(out_ch, out_ch, ks, stride=1), 58 | ) 59 | 60 | def up(self, in_ch, out_ch, ks) -> nn.Sequential: 61 | conv = tnn.ConvBlock(in_ch, out_ch, ks) 62 | tnn.utils.insert_before(conv, 'conv', nn.BatchNorm2d(in_ch), 'pre_bn') 63 | return conv 64 | 65 | def skip(self, in_ch, out_ch) -> nn.Sequential: 66 | return tnn.ConvBlock(in_ch, out_ch, 1) 67 | 68 | def forward(self, x) -> torch.Tensor: 69 | acts = [x] 70 | for d in self.downs: 71 | acts.append(d(acts[-1])) 72 | acts = acts[1:] 73 | 74 | if hasattr(self, 'skips'): 75 | skips = [s(a) for s, a in zip(self.skips, acts)] 76 | 77 | x = acts[-1] 78 | for u, s in zip(self.ups, reversed(skips)): 79 | x = nn.functional.interpolate(x, 80 | size=s.shape[2:], 81 | mode=self.upsampling) 82 | x = u(torch.cat([x, s], dim=1)) 83 | else: 84 | x = acts[-1] 85 | for u, x2 in zip(self.ups, reversed(acts)): 86 | x = nn.functional.interpolate(x, 87 | size=x2.shape[2:], 88 | mode=self.upsampling) 89 | x = u(x) 90 | 91 | x = nn.functional.interpolate(x, scale_factor=2, mode=self.upsampling) 92 | return torch.sigmoid(self.to_rgb(x)) 93 | -------------------------------------------------------------------------------- /torchelie/recipes/trainandtest.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torchelie.callbacks as tcb 4 | from torchelie.recipes import Recipe 5 | 6 | 7 | def TrainAndTest(model, 8 | train_fun, 9 | test_fun, 10 | train_loader, 11 | test_loader, 12 | *, 13 | test_every=100, 14 | visdom_env='main', 15 | log_every=10, 16 | checkpoint='model', 17 | key_best=None): 18 | """ 19 | Two nested loops, usually one for training and one for testing, but can 20 | serve other purposes. The model is automatically registered and 21 | checkpointed as :code:`checkpoint['model']`, and put in eval mode when 22 | testing. 
23 | 24 | Training callbacks: 25 | 26 | - Counter for counting iterations, connected to the testing loop as well 27 | - VisdomLogger 28 | - StdoutLogger 29 | - SeedDistributedSampler 30 | 31 | Testing: 32 | 33 | Testing loop is in :code:`.test_loop`. 34 | 35 | Testing callbacks: 36 | 37 | - VisdomLogger 38 | - StdoutLogger 39 | - Checkpoint 40 | 41 | Args: 42 | model (nn.Model): a model 43 | train_fun (Callabble): a function that takes a batch as a single 44 | argument, performs a training step and return a dict of values to 45 | populate the recipe's state. 46 | test_fun (Callable): a function taking a batch as a single argument 47 | then performs something to evaluate your model and returns a dict 48 | to populate the state. 49 | train_loader (DataLoader): Training set dataloader 50 | test_loader (DataLoader): Testing set dataloader 51 | test_every (int): testing frequency, in number of iterations (default: 52 | 100) 53 | visdom_env (str): name of the visdom environment to use, or None for 54 | not using Visdom (default: None) 55 | log_every (int): logging frequency, in number of iterations (default: 56 | 100) 57 | checkpoint (str): checkpointing path or None for no checkpointing 58 | key_best (function or None): a key function for comparing states. 59 | Checkpointing the greatest. 60 | """ 61 | 62 | def eval_call(batch): 63 | model.eval() 64 | with torch.no_grad(): 65 | out = test_fun(batch) 66 | model.train() 67 | return out 68 | 69 | train_loop = Recipe(train_fun, train_loader) 70 | train_loop.register('model', model) 71 | 72 | test_loop = Recipe(eval_call, test_loader) 73 | train_loop.test_loop = test_loop 74 | train_loop.register('test_loop', test_loop) 75 | 76 | train_loop.callbacks.add_prologues([tcb.Counter()]) 77 | train_loop.callbacks.add_epilogues([ 78 | tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every), 79 | tcb.StdoutLogger(log_every=log_every), 80 | tcb.CallRecipe(test_loop, test_every), 81 | tcb.SeedDistributedSampler(), 82 | ]) 83 | 84 | test_loop.callbacks.add_epilogues([ 85 | tcb.VisdomLogger(visdom_env=visdom_env, 86 | log_every=-1, 87 | prefix='test_', 88 | post_epoch_ends=True), 89 | tcb.StdoutLogger(log_every=-1, prefix='Test'), 90 | ]) 91 | 92 | if checkpoint is not None: 93 | test_loop.callbacks.add_epilogues([ 94 | tcb.Checkpoint(checkpoint + '/ckpt_{iters}.pth', 95 | train_loop, 96 | key_best=key_best) 97 | ]) 98 | 99 | return train_loop 100 | -------------------------------------------------------------------------------- /torchelie/nn/maskedconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torchelie.utils import kaiming, experimental 6 | 7 | 8 | class MaskedConv2d(nn.Conv2d): 9 | """ 10 | A masked 2D convolution for PixelCNN 11 | 12 | Args: 13 | in_chan (int): number of input channels 14 | out_chan (int): number of output channels 15 | ks (int): kernel size 16 | center (bool): whereas central pixel is masked or not 17 | stride (int): stride, defaults to 1 18 | bias (2-tuple of ints): A spatial bias. 
Either the spatial dimensions 19 | of the input for a different bias at each location, or (1, 1) for 20 | the same bias everywhere (default) 21 | """ 22 | 23 | def __init__(self, in_chan, out_chan, ks, center, stride=1, bias=(1, 1)): 24 | super(MaskedConv2d, self).__init__(in_chan, 25 | out_chan, (ks // 2 + 1, ks), 26 | padding=0, 27 | stride=stride, 28 | bias=False) 29 | self.register_buffer('mask', torch.ones(ks // 2 + 1, ks)) 30 | self.mask[-1, ks // 2 + (1 if center else 0):] = 0 31 | 32 | self.spatial_bias = None 33 | if bias is not None: 34 | self.spatial_bias = nn.Parameter(torch.zeros(out_chan, *bias)) 35 | 36 | nn.init.kaiming_uniform_(self.weight) 37 | 38 | def forward(self, x): 39 | self.weight_orig = self.weight 40 | del self.weight 41 | self.weight = self.weight_orig * self.mask 42 | ks = self.weight.shape[-1] 43 | 44 | x = F.pad(x, (ks // 2, ks // 2, ks // 2, 0)) 45 | res = super(MaskedConv2d, self).forward(x) 46 | 47 | self.weight = self.weight_orig 48 | del self.weight_orig 49 | if self.spatial_bias is not None: 50 | return res + self.spatial_bias 51 | else: 52 | return res 53 | 54 | 55 | class TopLeftConv2d(nn.Module): 56 | """ 57 | A 2D convolution for PixelCNN made of a convolution above the current pixel 58 | and another on the left. 59 | 60 | Args: 61 | in_chan (int): number of input channels 62 | out_chan (int): number of output channels 63 | ks (int): kernel size 64 | center (bool): whereas central pixel is masked or not 65 | stride (int): stride, defaults to 1 66 | bias (2-tuple of ints): A spatial bias. Either the spatial dimensions 67 | of the input for a different bias at each location, or (1, 1) for 68 | the same bias everywhere (default) 69 | """ 70 | 71 | @experimental 72 | def __init__(self, in_chan, out_chan, ks, center, stride=1, bias=(1, 1)): 73 | super(TopLeftConv2d, self).__init__() 74 | self.top = kaiming( 75 | nn.Conv2d(in_chan, 76 | out_chan, (ks // 2, ks), 77 | bias=False, 78 | stride=stride)) 79 | self.left = kaiming( 80 | nn.Conv2d(in_chan, 81 | out_chan, (1, ks // 2 + (1 if center else 0)), 82 | stride=stride, 83 | bias=False)) 84 | self.ks = ks 85 | self.center = center 86 | self.bias = nn.Parameter(torch.zeros(out_chan, *bias)) 87 | 88 | def forward(self, x): 89 | top = self.top( 90 | F.pad(x[:, :, :-1, :], 91 | (self.ks // 2, self.ks // 2, self.ks // 2, 0))) 92 | if not self.center: 93 | left = self.left(F.pad(x[:, :, :, :-1], (self.ks // 2, 0, 0, 0))) 94 | else: 95 | left = self.left(F.pad(x, (self.ks // 2, 0, 0, 0))) 96 | return top + left + self.bias 97 | -------------------------------------------------------------------------------- /torchelie/loss/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torchelie.utils import experimental 6 | 7 | 8 | def ortho(w: torch.Tensor) -> torch.Tensor: 9 | r""" 10 | Returns the orthogonal loss for weight matrix `m`, from Big GAN. 
11 | 12 | https://arxiv.org/abs/1809.11096 13 | 14 | :math:`R_{\beta}(W)= ||W^T W \odot (1 - I)||_F^2` 15 | """ 16 | cosine = torch.einsum('ij,ji->ij', w, w) 17 | no_diag = (1 - torch.eye(w.shape[0], device=w.device)) 18 | return (cosine * no_diag).pow(2).sum(dim=1).mean() 19 | 20 | 21 | def total_variation(i: torch.Tensor) -> torch.Tensor: 22 | """ 23 | Returns the total variation loss for batch of images `i` 24 | """ 25 | v = F.l1_loss(i[:, :, 1:, :], i[:, :, :-1, :]) 26 | h = F.l1_loss(i[:, :, :, 1:], i[:, :, :, :-1]) 27 | return v + h 28 | 29 | 30 | @experimental 31 | def focal_loss(input: torch.Tensor, 32 | target: torch.Tensor, 33 | gamma: float = 0, 34 | weight: Optional[torch.Tensor] = None, 35 | reduction: str = 'none') -> torch.Tensor: 36 | r""" 37 | Returns the focal loss between `target` and `input` 38 | 39 | :math:`\text{FL}(p_t)=-(1-p_t)^\gamma\log(p_t)` 40 | """ 41 | if input.shape[1] == 1: 42 | mlogp = nn.functional.binary_cross_entropy_with_logits( 43 | input, target, reduction='none') 44 | else: 45 | mlogp = nn.functional.cross_entropy(input, 46 | target, 47 | weight=weight, 48 | reduction='none') 49 | p = torch.exp(-mlogp).clamp(min=1e-8, max=1 - 1e-8) 50 | loss = (1 - p)**gamma * mlogp 51 | if reduction == 'mean': 52 | return loss.mean() 53 | if reduction == 'sum': 54 | return loss.sum() 55 | if reduction == 'none': 56 | return loss 57 | assert False, f'{reduction} not a valid reduction method' 58 | 59 | 60 | def continuous_cross_entropy(pred: torch.Tensor, 61 | soft_targets: torch.Tensor, 62 | weights: Optional[torch.Tensor] = None, 63 | reduction: str = 'mean') -> torch.Tensor: 64 | r""" 65 | Compute the cross entropy between the logits `pred` and a normalized 66 | distribution `soft_targets`. If `soft_targets` is a one-hot vector, this is 67 | equivalent to `nn.functional.cross_entropy` with a label 68 | """ 69 | if weights is None: 70 | ce = torch.sum(-soft_targets * F.log_softmax(pred, 1), 1) 71 | else: 72 | ce = torch.sum(-weights * soft_targets * F.log_softmax(pred, 1), 1) 73 | 74 | if reduction == 'mean': 75 | return ce.mean() 76 | if reduction == 'sum': 77 | return ce.sum() 78 | if reduction == 'none': 79 | return ce 80 | assert False, f'{reduction} not a valid reduction method' 81 | 82 | 83 | @experimental 84 | def smoothed_cross_entropy(pred: torch.Tensor, 85 | targets: torch.tensor, 86 | smoothing: float = 0.9): 87 | """ 88 | Cross entropy with label smoothing 89 | 90 | Args: 91 | pred (FloatTensor): a 2D logits prediction 92 | targets (LongTensor): 1D indices 93 | smoothing (float): target probability value for the correct class 94 | """ 95 | prob = F.log_softmax(pred) 96 | n_classes = pred.shape[1] 97 | wrong_prob = (1 - smoothing) / (n_classes - 1) 98 | 99 | wrong_sum = prob.sum(1) * wrong_prob 100 | good = pred.gather(0, targets.unsqueeze(0)).squeeze(0) 101 | good *= (smoothing - wrong_prob) 102 | return -torch.mean(wrong_sum + good) 103 | -------------------------------------------------------------------------------- /torchelie/models/autogan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torchelie.utils import experimental 6 | import torchelie.nn as tnn 7 | from torchelie.utils import xavier 8 | from typing import List 9 | 10 | 11 | @experimental 12 | class AutoGAN(nn.Module): 13 | """ 14 | Generator discovered in AutoGAN: Neural Architecture Search for Generative 15 | Adversarial Networks. 
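    A minimal usage sketch (the architecture and noise size are
    illustrative)::

        G = AutoGAN([256, 128, 64, 32], in_noise=256)
        imgs = G(torch.randn(8, 256))  # (8, 3, 32, 32) images in [0, 1]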
16 | 17 | Args: 18 | arch (list): architecture specification: a list of output channel for 19 | each block. Each block doubles the resolution of the generated 20 | image. Example: `[512, 256, 128, 64, 32]`. 21 | n_skip_max (int): how many blocks far back will be used for the skip 22 | connections maximum. 23 | in_noise (int): dimension of the input noise vector 24 | out_ch (int): number of channels on the image 25 | batchnorm_in_output (bool): whether to have a batchnorm just before 26 | projecting to RGB. I have found it better on False, but the 27 | official AutoGAN repo has it. 28 | """ 29 | n_skip_max: int 30 | make_noise: nn.Module 31 | blocks: nn.ModuleList 32 | to_rgb: nn.Sequential 33 | 34 | def __init__(self, 35 | arch: List[int], 36 | n_skip_max: int = 2, 37 | in_noise: int = 256, 38 | out_ch: int = 3, 39 | batchnorm_in_output: bool = False) -> None: 40 | super().__init__() 41 | self.n_skip_max = n_skip_max 42 | self.make_noise = xavier(nn.Linear(in_noise, 4 * 4 * arch[0])) 43 | 44 | in_ch = arch[0] 45 | blocks = [] 46 | lasts: List[int] = [] 47 | for i, out in enumerate(arch[1:]): 48 | mode = 'nearest' if i % 2 == 0 else 'bilinear' 49 | blocks.append(tnn.AutoGANGenBlock(in_ch, out, lasts, mode=mode)) 50 | lasts = ([out] + lasts)[:n_skip_max] 51 | in_ch = out 52 | self.blocks = nn.ModuleList(blocks) 53 | if batchnorm_in_output: 54 | self.to_rgb = nn.Sequential(nn.BatchNorm2d(arch[-1]), 55 | nn.ReLU(True), 56 | xavier(tnn.Conv3x3(arch[-1], out_ch))) 57 | else: 58 | self.to_rgb = nn.Sequential(nn.ReLU(True), 59 | xavier(tnn.Conv3x3(arch[-1], out_ch))) 60 | 61 | def forward(self, z: torch.Tensor) -> torch.Tensor: 62 | """ 63 | Forward pass 64 | 65 | Args: 66 | z (tensor): A batch of noise vectors 67 | 68 | Returns: 69 | generated batch of images 70 | """ 71 | x = self.make_noise(z) 72 | x = x.view(x.shape[0], -1, 4, 4) 73 | 74 | skips: List[torch.Tensor] = [] 75 | for b in self.blocks: 76 | x, sk = b(x, skips) 77 | skips = ([sk] + skips)[:self.n_skip_max] 78 | return torch.sigmoid(self.to_rgb(F.leaky_relu(x, 0.2))) 79 | 80 | 81 | @experimental 82 | def autogan_128(in_noise: int, out_ch: int = 3) -> AutoGAN: 83 | return AutoGAN(arch=[512, 512, 256, 128, 64, 32], 84 | n_skip_max=3, 85 | in_noise=in_noise, 86 | out_ch=out_ch) 87 | 88 | 89 | @experimental 90 | def autogan_64(in_noise: int, out_ch: int = 3) -> AutoGAN: 91 | return AutoGAN(arch=[512, 256, 128, 64, 32], 92 | n_skip_max=3, 93 | in_noise=in_noise, 94 | out_ch=out_ch) 95 | 96 | 97 | @experimental 98 | def autogan_32(in_noise: int, out_ch: int = 3) -> AutoGAN: 99 | return AutoGAN(arch=[256, 128, 64, 32], 100 | n_skip_max=3, 101 | in_noise=in_noise, 102 | out_ch=out_ch) 103 | -------------------------------------------------------------------------------- /examples/cifar.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example demonstrates how to learn CIFAR-10 with Torchelie. It can be 3 | trivially modified to fit another dataset or model. 4 | 5 | Better than that, make sure to check the Classification Recipe's builtin 6 | command line interface that allows to fit a model to an image dataset without 7 | writing a single line of code. It is good to quickly estimate how hard a 8 | dataset is to learn by fitting a default model with default transforms and 9 | hyperparameters. 
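A typical invocation (the network name must be a constructor registered in
torchelie.models; vgg11_bn here is only an example):

    python3 examples/cifar.py --network vgg11_bn --device cuda --epochs 200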
10 | """ 11 | import argparse 12 | from contextlib import suppress 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | from torchvision.datasets import CIFAR10 19 | import torchvision.transforms as TF 20 | 21 | import torchelie as tch 22 | from torchelie.recipes.classification import CrossEntropyClassification 23 | 24 | 25 | def get_args(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--device', choices=['cpu', 'cuda'], default='cpu') 28 | parser.add_argument('--epochs', default=200, type=int) 29 | parser.add_argument('--lr', type=float, default=0.1) 30 | parser.add_argument('--wd', type=float, default=5e-4) 31 | parser.add_argument('--network', required=True) 32 | return parser.parse_args() 33 | 34 | 35 | def build_transforms(): 36 | tfms = TF.Compose([ 37 | TF.Resize(32), 38 | TF.ToTensor(), 39 | TF.Normalize([0.5] * 3, [0.5] * 3, True), 40 | ]) 41 | train_tfms = TF.Compose([ 42 | TF.RandomCrop(32, padding=4), 43 | TF.RandomHorizontalFlip(), 44 | TF.ToTensor(), 45 | TF.Normalize([0.5] * 3, [0.5] * 3, True), 46 | ]) 47 | return tfms, train_tfms 48 | 49 | 50 | def get_datasets(): 51 | tfms, train_tfms = build_transforms() 52 | ds = CIFAR10('~/.cache/torch/cifar10', download=True, transform=train_tfms) 53 | dst = CIFAR10('~/.cache/torch/cifar10', 54 | transform=tfms, 55 | download=True, 56 | train=False) 57 | return ds, dst 58 | 59 | 60 | def train(): 61 | opts = get_args() 62 | ds, dst = get_datasets() 63 | m = opts.network 64 | model = tch.models.get_model(m, num_classes=10) 65 | 66 | with suppress(AttributeError): 67 | # Use a stem suitable for cifar10 68 | model.features.input.set_input_specs(32) 69 | 70 | dl = torch.utils.data.DataLoader(ds, 71 | num_workers=4, 72 | batch_size=128, 73 | pin_memory=True, 74 | persistent_workers=True, 75 | shuffle=True) 76 | dlt = torch.utils.data.DataLoader(dst, 77 | num_workers=4, 78 | batch_size=256, 79 | persistent_workers=True, 80 | pin_memory=True) 81 | recipe = CrossEntropyClassification(model, 82 | dl, 83 | dlt, 84 | ds.classes, 85 | optimizer='adamw', 86 | lr=opts.lr, 87 | wd=opts.wd, 88 | beta1=0.9, 89 | log_every=100, 90 | test_every=len(dl) * opts.epochs // 10, 91 | visdom_env='cifar_' + m, 92 | n_iters=len(dl) * opts.epochs) 93 | 94 | recipe.to(opts.device) 95 | print(recipe) 96 | recipe.run(opts.epochs + 1) 97 | 98 | 99 | if __name__ == '__main__': 100 | train() 101 | -------------------------------------------------------------------------------- /torchelie/recipes/gan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchelie.utils as tu 4 | import torchelie.callbacks as tcb 5 | from torchelie.recipes.recipebase import Recipe 6 | from typing import Optional, Iterable, Any 7 | 8 | 9 | @tu.experimental 10 | def GANRecipe(G: nn.Module, 11 | D: nn.Module, 12 | G_fun, 13 | D_fun, 14 | test_fun, 15 | loader: Iterable[Any], 16 | *, 17 | test_loader: Optional[Iterable[Any]] = None, 18 | visdom_env: Optional[str] = 'main', 19 | checkpoint: Optional[str] = 'model', 20 | test_every: int = 1000, 21 | log_every: int = 10, 22 | g_every: int = 1) -> Recipe: 23 | 24 | def D_wrap(batch): 25 | tu.freeze(G) 26 | tu.unfreeze(D) 27 | 28 | return D_fun(batch) 29 | 30 | def G_wrap(batch): 31 | tu.freeze(D) 32 | tu.unfreeze(G) 33 | 34 | return G_fun(batch) 35 | 36 | def test_wrap(batch): 37 | tu.freeze(G) 38 | tu.freeze(D) 39 | D.eval() 40 | G.eval() 41 | 42 | with torch.no_grad(): 43 | out = test_fun(batch) 
44 | 45 | D.train() 46 | G.train() 47 | return out 48 | 49 | class NoLim: 50 | 51 | def __init__(self, data): 52 | self.data = data 53 | self.i = iter(data) 54 | self.did_send = False 55 | 56 | def __iter__(self): 57 | return self 58 | 59 | def __next__(self): 60 | if self.did_send: 61 | self.did_send = False 62 | raise StopIteration 63 | self.did_send = True 64 | try: 65 | return next(self.i) 66 | except: 67 | self.i = iter(self.data) 68 | return next(self.i) 69 | 70 | D_loop = Recipe(D_wrap, loader) 71 | D_loop.register('G', G) 72 | D_loop.register('D', D) 73 | G_loop = Recipe(G_wrap, NoLim(loader)) 74 | D_loop.G_loop = G_loop 75 | D_loop.register('G_loop', G_loop) 76 | 77 | test_loop = Recipe(test_wrap, 78 | NoLim(loader) if test_loader is None else test_loader) 79 | D_loop.test_loop = test_loop 80 | D_loop.register('test_loop', test_loop) 81 | 82 | def G_test(state): 83 | G_loop.callbacks.update_state({ 84 | 'epoch': state['epoch'], 85 | 'iters': state['iters'], 86 | 'epoch_batch': state['epoch_batch'] 87 | }) 88 | 89 | def prepare_test(state): 90 | test_loop.callbacks.update_state({ 91 | 'epoch': state['epoch'], 92 | 'iters': state['iters'], 93 | 'epoch_batch': state['epoch_batch'] 94 | }) 95 | 96 | D_loop.callbacks.add_prologues([tcb.Counter()]) 97 | 98 | D_loop.callbacks.add_epilogues([ 99 | tcb.Log('imgs', 'G_imgs'), 100 | tcb.CallRecipe(G_loop, g_every, init_fun=G_test, prefix='G'), 101 | tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every), 102 | tcb.StdoutLogger(log_every=log_every), 103 | tcb.CallRecipe(test_loop, 104 | test_every, 105 | init_fun=prepare_test, 106 | prefix='Test'), 107 | ]) 108 | 109 | G_loop.callbacks.add_epilogues([ 110 | tcb.WindowedMetricAvg('G_loss'), 111 | tcb.VisdomLogger(visdom_env=visdom_env, 112 | log_every=log_every, 113 | post_epoch_ends=False) 114 | ]) 115 | 116 | if checkpoint is not None: 117 | test_loop.callbacks.add_epilogues([ 118 | tcb.Checkpoint(checkpoint + '/ckpt_{iters}.pth', D_loop), 119 | tcb.VisdomLogger(visdom_env=visdom_env), 120 | ]) 121 | 122 | return D_loop 123 | -------------------------------------------------------------------------------- /torchelie/models/mlpmixer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchelie.utils as tu 4 | import torchelie.nn as tnn 5 | from .classifier import ClassificationHead 6 | from .registry import register 7 | 8 | __all__ = [ 9 | 'MlpBlockBase', 'ChannelMlpBlock', 'SpatialMlpBlock', 'MixerBlock', 10 | 'MlpMixer', 'mlpmixer_s16', 'mlpmixer_s32', 'mlpmixer_vs16', 'mlpmixer_vs32' 11 | ] 12 | 13 | 14 | class MlpBlockBase(tnn.CondSeq): 15 | 16 | def __init__(self, outer_features, inner_features, block): 17 | super().__init__() 18 | self.inner_features = inner_features 19 | self.outer_features = outer_features 20 | self.linear1 = tu.kaiming(block(outer_features, inner_features)) 21 | self.gelu = nn.GELU() 22 | self.linear2 = tu.normal_init(block(inner_features, outer_features)) 23 | 24 | 25 | class ChannelMlpBlock(MlpBlockBase): 26 | 27 | def __init__(self, outer_features, inner_features): 28 | super().__init__(outer_features, inner_features, nn.Linear) 29 | # (B, L, Cin) -> (B, L, Cout) 30 | 31 | 32 | class SpatialMlpBlock(MlpBlockBase): 33 | 34 | def __init__(self, outer_features, inner_features): 35 | super().__init__(outer_features, inner_features, 36 | lambda i, o: nn.Conv1d(i, o, 1)) 37 | # (B, Lin, C) -> (B, Lout, C) 38 | 39 | 40 | class MixerBlock(nn.Module): 41 | 42 | def __init__(self, 
seq_len, in_features, hidden_token_mix, 43 | hidden_channel_mix): 44 | super().__init__() 45 | self.norm1 = nn.LayerNorm((seq_len, in_features)) 46 | self.tokens_mlp = SpatialMlpBlock(seq_len, hidden_token_mix) 47 | self.norm2 = nn.LayerNorm((seq_len, in_features)) 48 | self.channels_mlp = ChannelMlpBlock(in_features, hidden_channel_mix) 49 | 50 | def forward(self, x): 51 | x = x + self.tokens_mlp(self.norm1(x)) 52 | x = x + self.channels_mlp(self.norm2(x)) 53 | return x 54 | 55 | 56 | class MlpMixer(tnn.CondSeq): 57 | 58 | def __init__(self, im_size, num_classes, patch_size, num_blocks, hidden_dim, 59 | tokens_mlp_dim, channels_mlp_dim): 60 | super().__init__() 61 | self.patch_size = patch_size 62 | self.im_size = im_size 63 | features = tnn.CondSeq() 64 | features.input = tu.kaiming( 65 | nn.Conv2d(3, hidden_dim, patch_size, stride=patch_size)) 66 | seq_len = im_size // patch_size 67 | seq_len2 = seq_len * seq_len 68 | features.im_size = im_size 69 | features.reshape = tnn.Reshape(hidden_dim, seq_len2) 70 | features.permute = tnn.Lambda(lambda x: x.permute(0, 2, 1)) 71 | 72 | for i in range(num_blocks): 73 | mlp = MixerBlock(seq_len2, hidden_dim, tokens_mlp_dim, 74 | channels_mlp_dim) 75 | setattr(features, f'mlp_{i}', mlp) 76 | setattr(features, f'norm_{i}', nn.LayerNorm(hidden_dim)) 77 | features.unpermute = tnn.Lambda(lambda x: x.permute(0, 2, 1)) 78 | features.unshape = tnn.Reshape(hidden_dim, seq_len, seq_len) 79 | self.features = features 80 | self.classifier = ClassificationHead(hidden_dim, num_classes) 81 | 82 | def forward(self, x): 83 | assert (x.shape[2] == self.im_size and x.shape[3] == self.im_size), ( 84 | f"input image of {self.__class__.__name__} must be of size " 85 | f"{self.im_size}x{self.im_size}") 86 | return super().forward(x) 87 | 88 | 89 | @register 90 | def mlpmixer_vs32(num_classes, im_size=224): 91 | return MlpMixer(im_size, num_classes, 16, 4, 512, 64, 512) 92 | 93 | 94 | @register 95 | def mlpmixer_vs16(num_classes, im_size=224): 96 | return MlpMixer(im_size, num_classes, 32, 4, 512, 64, 512) 97 | 98 | 99 | @register 100 | def mlpmixer_s32(num_classes, im_size=224): 101 | return MlpMixer(im_size, num_classes, 32, 8, 512, 256, 2048) 102 | 103 | 104 | @register 105 | def mlpmixer_s16(num_classes, im_size=224): 106 | return MlpMixer(im_size, num_classes, 16, 8, 512, 256, 2048) 107 | -------------------------------------------------------------------------------- /torchelie/models/unet.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, cast 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import torchelie.nn as tnn 7 | import torchelie.utils as tu 8 | 9 | 10 | @tu.experimental 11 | class UNet(tnn.CondSeq): 12 | """ 13 | U-Net from `U-Net: Convolutional Networks for Biomedical Image Segmentation 14 | `_. This net has architectural changes 15 | operations for further customization. 
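    A minimal usage sketch (the channel list and input size are
    illustrative)::

        net = UNet([64, 128, 256], num_classes=2)
        logits = net(torch.randn(1, 3, 128, 128))  # (1, 2, 128, 128)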
16 | 17 | Args: 18 | arch (List[int]): a list of channels from the outermost to innermost 19 | layers 20 | num_classes (int): number of output channels 21 | """ 22 | 23 | def __init__(self, arch: List[int], num_classes: int) -> None: 24 | super().__init__() 25 | self.arch = arch 26 | self.in_channels = 3 27 | self.out_channels = arch[-1] 28 | 29 | feats = tnn.CondSeq() 30 | feats.input = tnn.ConvBlock(3, arch[0], 3) 31 | 32 | encdec: nn.Module = tnn.ConvBlock(arch[-1], arch[-1] * 2, 3) 33 | for outer, inner in zip(arch[-2::-1], arch[:0:-1]): 34 | encdec = tnn.UBlock(outer, inner, encdec) 35 | feats.encoder_decoder = encdec 36 | self.features = feats 37 | assert isinstance(encdec.out_channels, int) 38 | self.classifier = tnn.ConvBlock(encdec.out_channels, num_classes, 39 | 3).remove_batchnorm().no_relu() 40 | 41 | @torch.no_grad() 42 | def to_instance_norm(self, affine: bool = True) -> 'UNet': 43 | """ 44 | Replace BatchNorm with InstanceNorm. 45 | """ 46 | 47 | def to_instancenorm(m): 48 | if isinstance(m, nn.BatchNorm2d): 49 | return nn.InstanceNorm2d(m.num_features, affine=affine) 50 | return m 51 | 52 | tnn.utils.edit_model(self, to_instancenorm) 53 | 54 | return self 55 | 56 | def leaky(self, leak: float = 0.2) -> 'UNet': 57 | for m in self.modules(): 58 | if isinstance(m, tnn.ConvBlock): 59 | m.leaky(leak) 60 | return self 61 | 62 | def set_padding_mode(self, mode: str) -> 'UNet': 63 | for m in self.modules(): 64 | if isinstance(m, nn.Conv2d): 65 | m.padding_mode = mode 66 | return self 67 | 68 | def set_input_specs(self, in_channels: int) -> 'UNet': 69 | assert isinstance(self.features.input, tnn.ConvBlock) 70 | c = self.features.input.conv 71 | self.features.input.conv = tu.kaiming( 72 | nn.Conv2d(in_channels, 73 | c.out_channels, 74 | cast(Tuple[int, int], c.kernel_size), 75 | bias=c.bias is not None, 76 | padding=cast(Tuple[int, int], c.padding))) 77 | return self 78 | 79 | def set_encoder_num_layers(self, num: int) -> 'UNet': 80 | for m in self.modules(): 81 | if isinstance(m, tnn.UBlock): 82 | m.set_encoder_num_layers(num) 83 | return self 84 | 85 | def remove_upsampling_conv(self, num: int) -> 'UNet': 86 | for m in self.modules(): 87 | if isinstance(m, tnn.UBlock): 88 | m.remove_upsampling_conv() 89 | return self 90 | 91 | def set_decoder_num_layers(self, num: int) -> 'UNet': 92 | for m in self.modules(): 93 | if isinstance(m, tnn.UBlock): 94 | m.set_decoder_num_layers(num) 95 | return self 96 | 97 | def to_bilinear_sampling(self) -> 'UNet': 98 | for m in self.modules(): 99 | if isinstance(m, tnn.UBlock): 100 | m.to_bilinear_sampling() 101 | return self 102 | 103 | def remove_first_batchnorm(self) -> 'UNet': 104 | assert isinstance(self.features.input, tnn.ConvBlock) 105 | self.features.input.remove_batchnorm() 106 | return self 107 | 108 | def remove_batchnorm(self) -> 'UNet': 109 | for m in self.modules(): 110 | if isinstance(m, tnn.ConvBlock): 111 | m.remove_batchnorm() 112 | return self 113 | -------------------------------------------------------------------------------- /torchelie/models/efficient.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torchelie.nn as tnn 4 | import torchelie.utils as tu 5 | 6 | 7 | @tu.experimental 8 | class MBConv(nn.Module): 9 | 10 | def __init__(self, 11 | in_ch: int, 12 | out_ch: int, 13 | ks: int, 14 | stride: int = 1, 15 | mul_factor: int = 6): 16 | super(MBConv, self).__init__() 17 | self.in_ch = in_ch 18 | self.out_ch = out_ch 19 | self.ks = ks 20 | 
self.stride = stride 21 | self.factor = mul_factor 22 | 23 | hid = in_ch * mul_factor 24 | self.branch = tnn.CondSeq( 25 | tu.xavier(tnn.Conv1x1(in_ch, hid, bias=False)), nn.BatchNorm2d(hid), 26 | tnn.HardSwish(), 27 | tu.xavier( 28 | nn.Conv2d(hid, 29 | hid, 30 | ks, 31 | stride=stride, 32 | padding=ks // 2, 33 | groups=hid, 34 | bias=False)), nn.BatchNorm2d(hid), tnn.HardSwish(), 35 | tnn.SEBlock(hid, reduction=4), tu.xavier(tnn.Conv1x1(hid, out_ch)), 36 | nn.BatchNorm2d(out_ch)) 37 | 38 | self.shortcut = tnn.CondSeq() 39 | 40 | if stride != 1: 41 | self.shortcut.add_module( 42 | 'pool', nn.AvgPool2d(stride, stride, ceil_mode=True)) 43 | 44 | if in_ch != out_ch: 45 | self.shortcut.add_module('conv', 46 | tnn.Conv1x1(in_ch, out_ch, bias=False)) 47 | self.shortcut.add_module('bn', nn.BatchNorm2d(out_ch)) 48 | 49 | def __repr__(self): 50 | return "MBConv({}, {}, factor={}, k{}x{}s{}))".format( 51 | self.in_ch, self.out_ch, self.factor, self.ks, self.ks, self.stride) 52 | 53 | def forward(self, x): 54 | return self.branch(x).add_(self.shortcut(x)) 55 | 56 | 57 | class EfficientNet(tnn.CondSeq): 58 | 59 | @tu.experimental 60 | def __init__(self, in_ch, num_classes, B=0): 61 | 62 | def ch(ch): 63 | return int(ch * 1.1**B) // 8 * 8 64 | 65 | def n_layers(d): 66 | return int(math.ceil(d * 1.2**B)) 67 | 68 | def r(): 69 | return int(224 * 1.15**B) 70 | 71 | super(EfficientNet, self).__init__( 72 | # Stage 1 73 | # nn.UpsamplingBilinear2d(size=(r(), r())), 74 | tu.kaiming(tnn.Conv3x3(in_ch, ch(32), stride=2, bias=False)), 75 | nn.BatchNorm2d(ch(32)), 76 | tnn.HardSwish(), 77 | 78 | # Stage 2 79 | MBConv(ch(32), ch(16), 3, mul_factor=1), 80 | *[ 81 | MBConv(ch(16), ch(16), 3, mul_factor=1) 82 | for _ in range(n_layers(1) - 1) 83 | ], 84 | 85 | # Stage 3 86 | MBConv(ch(16), ch(24), 3, stride=2), 87 | *[MBConv(ch(24), ch(24), 3) for _ in range(n_layers(2) - 1)], 88 | 89 | # Stage 4 90 | MBConv(ch(24), ch(40), 5, stride=2), 91 | *[MBConv(ch(40), ch(40), 5) for _ in range(n_layers(2) - 1)], 92 | 93 | # Stage 5 94 | MBConv(ch(40), ch(80), 3, stride=2), 95 | *[MBConv(ch(80), ch(80), 3) for _ in range(n_layers(3) - 1)], 96 | 97 | # Stage 6 98 | MBConv(ch(80), ch(112), 5), 99 | *[MBConv(ch(112), ch(112), 5) for _ in range(n_layers(3) - 1)], 100 | 101 | # Stage 7 102 | MBConv(ch(112), ch(192), 5, stride=2), 103 | *[MBConv(ch(192), ch(192), 5) for _ in range(n_layers(4) - 1)], 104 | 105 | # Stage 8 106 | MBConv(ch(192), ch(320), 3), 107 | *[MBConv(ch(320), ch(320), 3) for _ in range(n_layers(1) - 1)], 108 | tu.kaiming(tnn.Conv1x1(ch(320), ch(1280), bias=False)), 109 | nn.BatchNorm2d(ch(1280)), 110 | tnn.HardSwish(), 111 | nn.AdaptiveAvgPool2d(1), 112 | tnn.Reshape(-1), 113 | tu.xavier(nn.Linear(ch(1280), num_classes))) 114 | -------------------------------------------------------------------------------- /torchelie/loss/face_rec.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class L2Constraint(nn.Module): 10 | """ 11 | From `Ranjan 2017`_ , L2-constrained Softmax Loss for Discriminative Face 12 | Verification. 
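    A minimal usage sketch (the embedding size and number of identities are
    illustrative)::

        head = L2Constraint(dim=512, num_classes=1000)
        logits, cosine = head(embeddings, labels)
        loss = torch.nn.functional.cross_entropy(logits, labels)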
13 | 14 | Args: 15 | dim (int): number of channels of the feature vector 16 | num_classes (int): number of identities 17 | fixed (bool): whether to use the fixed or dynamic version of AdaCos 18 | (default: False) 19 | 20 | :: _Ranjan 2017: https://arxiv.org/abs/1703.09507 21 | """ 22 | def __init__(self, dim: int, num_classes: int, s: float = 30.): 23 | super(L2Constraint, self).__init__() 24 | self.weight = nn.Parameter(torch.FloatTensor(num_classes, dim)) 25 | nn.init.orthogonal_(self.weight) 26 | self.num_classes = num_classes 27 | self.s = s 28 | 29 | def forward(self, input: torch.Tensor, 30 | label: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 31 | """ 32 | Forward pass 33 | 34 | Args: 35 | input (tensor): feature vectors 36 | label (tensor): labels 37 | 38 | Returns: 39 | scaled cosine logits, cosine logits 40 | """ 41 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 42 | output = cosine * self.s 43 | 44 | return output, cosine 45 | 46 | def __repr__(self): 47 | return "L2CrossEntropy(s={})".format(self.s) 48 | 49 | 50 | class AdaCos(nn.Module): 51 | """ 52 | From AdaCos: [Adaptively Scaling Cosine Logits for Effectively Learning 53 | Deep Face Representations](https://arxiv.org/abs/1905.00292) 54 | 55 | Args: 56 | dim (int): number of channels of the feature vector 57 | num_classes (int): number of identities 58 | fixed (bool): whether to use the fixed or dynamic version of AdaCos 59 | (default: False) 60 | estimate_B (bool): is using dynamic AdaCos, B is estimated from the 61 | real angles of the cosine similarity. However I found that this 62 | method was not numerically stable and experimented with the 63 | approximation :code:`B = num_classes - 1` that was more satisfying. 64 | """ 65 | s: torch.Tensor 66 | 67 | def __init__(self, 68 | dim: int, 69 | num_classes: int, 70 | fixed: bool = False, 71 | estimate_B: bool = False): 72 | super(AdaCos, self).__init__() 73 | self.fixed = fixed 74 | self.weight = nn.Parameter(torch.FloatTensor(num_classes, dim)) 75 | nn.init.xavier_normal_(self.weight) 76 | self.num_classes = num_classes 77 | self.register_buffer( 78 | 's', torch.tensor(math.sqrt(2) * math.log(num_classes - 1))) 79 | self.register_buffer('B', torch.tensor(num_classes - 1.)) 80 | self.estimate_B = estimate_B 81 | 82 | def forward( 83 | self, input: torch.Tensor, label: torch.Tensor 84 | ) -> Tuple[torch.Tensor, torch.Tensor]: 85 | """ 86 | Forward pass 87 | 88 | Args: 89 | input (tensor): feature vectors 90 | label (tensor): labels 91 | 92 | Returns: 93 | scaled cosine logits, cosine logits 94 | """ 95 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 96 | 97 | if self.fixed: 98 | return cosine * self.s, cosine 99 | 100 | if self.training: 101 | with torch.no_grad(): 102 | correct_cos = torch.gather(cosine, 1, label.unsqueeze(1)) 103 | theta_med = torch.acos(correct_cos).median() 104 | 105 | if self.estimate_B: 106 | output = cosine * self.s 107 | expout = output.exp() 108 | correct_expout = torch.gather(expout, 1, 109 | label.unsqueeze(1)) 110 | correct_expout.squeeze_() 111 | self.B = torch.mean(expout.sum(1) - correct_expout, dim=0) 112 | 113 | self.s = torch.log(self.B) / math.cos( 114 | min(math.pi / 4, theta_med.item())) 115 | return cosine * self.s, cosine 116 | 117 | def __repr__(self) -> str: 118 | return "FixedAdaCos(fixed={})".format(self.fixed) 119 | -------------------------------------------------------------------------------- /torchelie/distributions.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.distributions import TransformedDistribution 6 | 7 | from .utils import experimental 8 | 9 | 10 | class Logistic(TransformedDistribution): 11 | """ 12 | Logistic distribution 13 | 14 | Args: 15 | loc (tensor): mean of the distribution 16 | scale (tensor): scale of the distribution 17 | """ 18 | 19 | def __init__(self, loc: torch.Tensor, scale: torch.Tensor) -> None: 20 | td = torch.distributions 21 | base_distribution = td.Uniform(torch.zeros_like(loc), 22 | torch.ones_like(loc)) 23 | transforms = [ 24 | td.SigmoidTransform().inv, 25 | td.AffineTransform(loc=loc, scale=scale) 26 | ] 27 | super(Logistic, self).__init__(base_distribution, transforms) 28 | 29 | 30 | class LogisticMixture: 31 | """ 32 | Mixture of Logistic distributions. Each tensor contains an additional 33 | dimension with `number of distributions` elements. 34 | 35 | Args: 36 | weights (tensor): un-normalized weights of distributions 37 | loc (tensor): mean of the distributions 38 | scale (tensor): scale of the distributions 39 | dim (int): dimension reprenseting the various distributions, that will 40 | weighted and averaged on. 41 | """ 42 | 43 | def __init__(self, weights, locs, scales, dim) -> None: 44 | self.weights = weights 45 | self.logistics = Logistic(locs, scales) 46 | self.locs = locs 47 | self.dim = dim - len(locs.shape) if dim >= 0 else dim 48 | 49 | @property 50 | def mean(self) -> torch.Tensor: 51 | w = nn.functional.softmax(self.weights, dim=self.dim) 52 | return torch.sum(self.locs * w, dim=self.dim) 53 | 54 | def log_prob(self, x: torch.Tensor) -> torch.Tensor: 55 | log_pis = nn.functional.log_softmax(self.weights, dim=self.dim) 56 | return torch.logsumexp(self.logistics.log_prob(x.unsqueeze(self.dim)) + 57 | log_pis, 58 | dim=self.dim) 59 | 60 | 61 | class GaussianMixture: 62 | """ 63 | Mixture of gaussian distributions. Each tensor contains an additional 64 | dimension with `number of distributions` elements. 65 | 66 | Args: 67 | weights (tensor): un-normalized weights of distributions 68 | loc (tensor): mean of the distributions 69 | scale (tensor): scale of the distributions 70 | dim (int): dimension reprenseting the various distributions, that will 71 | weighted and averaged on. 
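    A minimal usage sketch (shapes are illustrative; this class always mixes
    over dimension 1)::

        # batch of 8 samples, 5 mixture components each
        mix = GaussianMixture(torch.randn(8, 5), torch.randn(8, 5),
                              torch.rand(8, 5) + 0.1)
        nll = -mix.log_prob(torch.randn(8)).mean()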
72 | """ 73 | 74 | def __init__(self, weights: torch.Tensor, locs: torch.Tensor, 75 | scales: torch.Tensor) -> None: 76 | self.weights = weights 77 | self.logistics = torch.distributions.Normal(locs, scales) 78 | self.locs = locs 79 | 80 | @property 81 | def mean(self) -> torch.Tensor: 82 | w = nn.functional.softmax(self.weights, dim=1) 83 | return torch.sum(self.locs * w, dim=1) 84 | 85 | def log_prob(self, x: torch.Tensor) -> torch.Tensor: 86 | log_pis = nn.functional.log_softmax(self.weights, dim=1) 87 | return torch.logsumexp(self.logistics.log_prob(x.unsqueeze(1)) + 88 | log_pis, 89 | dim=1) 90 | 91 | 92 | @experimental 93 | def parameterized_truncated_normal(uniform: torch.Tensor, mu: float, 94 | sigma: float, a: float, 95 | b: float) -> torch.Tensor: 96 | normal = torch.distributions.normal.Normal(0, 1) 97 | 98 | alpha = torch.tensor((a - mu) / sigma) 99 | beta = torch.tensor((b - mu) / sigma) 100 | 101 | alpha_normal_cdf = normal.cdf(alpha) 102 | p = alpha_normal_cdf + (normal.cdf(beta) - alpha_normal_cdf) * uniform 103 | 104 | one = torch.tensor(1, dtype=p.dtype) 105 | epsilon = 1e-8 106 | v = torch.clamp(2 * p - 1, -one + epsilon, one - epsilon) 107 | x = mu + sigma * math.sqrt(2) * torch.erfinv(v) 108 | x = torch.clamp(x, a, b) 109 | 110 | return x 111 | 112 | 113 | @experimental 114 | def truncated_normal(uniform: torch.Tensor, a: float, b: float) -> torch.Tensor: 115 | return parameterized_truncated_normal(uniform, mu=0.0, sigma=1.0, a=a, b=b) 116 | 117 | 118 | @experimental 119 | def sample_truncated_normal(*shape, cutoff: float = 2): 120 | return truncated_normal(torch.rand(shape), a=-cutoff, b=cutoff) 121 | -------------------------------------------------------------------------------- /torchelie/nn/functional/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Function 4 | from .vq import quantize 5 | from .transformer import local_attention_2d 6 | 7 | 8 | def laplacian(images, n_down=4): 9 | """ 10 | Decompose a 4D images tensor into a laplacian pyramid. 
11 | 12 | Args: 13 | images (4D Tensor): source images 14 | n_down (int): how many times to downscale the input 15 | 16 | Returns: 17 | A list of tensor: laplacian pyramid 18 | """ 19 | lapls = [] 20 | 21 | for i in range(n_down): 22 | n = F.interpolate(images, 23 | scale_factor=0.5, 24 | mode='bilinear', 25 | align_corners=True) 26 | lapls.append(images - F.interpolate( 27 | n, size=images.shape[-2:], mode='bilinear', align_corners=True)) 28 | images = n 29 | 30 | lapls.append(images) 31 | return lapls 32 | 33 | 34 | def combine_laplacians(laplacians): 35 | """ 36 | Recombine a list of of tensor as returned by :code:`laplacian()` into an 37 | image batch 38 | 39 | Args: 40 | laplacians (list of tensors: laplacian pyramid 41 | 42 | Returns: 43 | a tensor 44 | """ 45 | biggest = laplacians[0] 46 | 47 | rescaled = [biggest] 48 | for im in laplacians[1:]: 49 | rescaled.append( 50 | F.interpolate(im, 51 | size=biggest.shape[-2:], 52 | mode='bilinear', 53 | align_corners=True)) 54 | 55 | mixed = torch.stack(rescaled, dim=-1) 56 | return mixed.sum(dim=-1) 57 | 58 | 59 | class InformationBottleneckFunc(Function): 60 | 61 | @staticmethod 62 | def forward(ctx, mu, sigma, strength=1): 63 | z = torch.randn_like(mu) 64 | x = mu + z * sigma 65 | s = torch.tensor(strength, device=x.device) 66 | ctx.save_for_backward(mu, sigma, z, s) 67 | return x 68 | 69 | @staticmethod 70 | def backward(ctx, d_out): 71 | mu, sigma, z, strength = ctx.saved_tensors 72 | 73 | # kl = -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2) 74 | # dkl/dmu = -0.5 * (-2*mu) 75 | # = mu 76 | # dkl/dsig = -0.5 * (d 2*log(sigma) - d sigma^2) 77 | # = -0.5 * (2/sigma - 2*sigma) 78 | # = -1/sigma + sigma 79 | d_mu = d_out + strength * mu 80 | d_sigma = z * d_out 81 | d_sigma += -strength / sigma + strength * sigma 82 | return d_mu, d_sigma, None 83 | 84 | 85 | information_bottleneck = InformationBottleneckFunc.apply 86 | unit_gaussian_prior = InformationBottleneckFunc.apply 87 | 88 | 89 | def drop_path(x, drop_prob=0.0, training=False): 90 | if drop_prob == 0. 
or not training: 91 | return x 92 | keep_prob = 1 - drop_prob 93 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 94 | random_tensor = keep_prob + torch.rand( 95 | shape, dtype=x.dtype, device=x.device) 96 | random_tensor.floor_() 97 | output = x * random_tensor.div(keep_prob) 98 | return output 99 | 100 | 101 | class SquaReLUFunc(torch.autograd.Function): 102 | 103 | @staticmethod 104 | def forward(ctx, input, inplace=False): 105 | x = F.relu(input, inplace=inplace) 106 | ctx.save_for_backward(x) 107 | return x * x 108 | 109 | @staticmethod 110 | def backward(ctx, grad_output): 111 | x, = ctx.saved_tensors 112 | return grad_output * 2 * x, None 113 | 114 | 115 | squarelu = SquaReLUFunc.apply 116 | 117 | 118 | class StaReLUFunc(torch.autograd.Function): 119 | 120 | @staticmethod 121 | def forward(ctx, input, weight, bias, inplace=False): 122 | x = F.relu(input) 123 | 124 | x = x * x * weight 125 | if bias is not None: 126 | x += bias 127 | 128 | if inplace: 129 | input.copy_(x) 130 | ctx.save_for_backward(input, weight, bias) 131 | return input 132 | ctx.save_for_backward(x, weight, bias) 133 | return x 134 | 135 | @staticmethod 136 | def backward(ctx, grad_output): 137 | out, w, b = ctx.saved_tensors 138 | if b is None: 139 | d_bias = None 140 | else: 141 | d_bias = torch.sum(grad_output) 142 | out -= b 143 | 144 | out /= w 145 | d_weight = torch.sum(grad_output * out) 146 | d_input = grad_output * 2 * out.sqrt() * w 147 | return d_input, d_weight, d_bias, None 148 | 149 | 150 | starelu = StaReLUFunc.apply 151 | -------------------------------------------------------------------------------- /torchelie/models/patchgan.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchelie.nn as tnn 3 | from torchelie.utils import kaiming 4 | from typing import List 5 | from torchelie.transforms.differentiable import BinomialFilter2d 6 | 7 | 8 | class PatchDiscriminator(tnn.CondSeq): 9 | 10 | def __init__(self, arch: List[int]) -> None: 11 | super().__init__() 12 | layers: List[nn.Module] = [ 13 | tnn.ConvBlock(3, arch[0], kernel_size=4, 14 | stride=2).remove_batchnorm().leaky() 15 | ] 16 | 17 | in_ch = arch[0] 18 | self.in_channels = in_ch 19 | for next_ch in arch[1:]: 20 | layers.append( 21 | tnn.ConvBlock(in_ch, next_ch, kernel_size=4, stride=2).leaky()) 22 | in_ch = next_ch 23 | assert isinstance(layers[-1], tnn.ConvBlock) 24 | layers[-1].conv.stride = (1, 1) 25 | 26 | self.features = tnn.CondSeq(*layers) 27 | self.classifier = tnn.Conv2d(in_ch, 1, 4) 28 | 29 | def to_equal_lr(self, leak=0.2) -> 'PatchDiscriminator': 30 | for m in self.modules(): 31 | if isinstance(m, (nn.Linear, nn.Conv2d)): 32 | kaiming(m, dynamic=True, a=leak) 33 | 34 | return self 35 | 36 | def remove_batchnorm(self) -> 'PatchDiscriminator': 37 | for m in self.features.modules(): 38 | if isinstance(m, tnn.ConvBlock): 39 | m.remove_batchnorm() 40 | return self 41 | 42 | def to_instance_norm(self, affine: bool = True) -> 'PatchDiscriminator': 43 | """ 44 | Pix2PixHD uses instancenorm rather than batchnorm 45 | """ 46 | 47 | def to_instancenorm(m): 48 | if isinstance(m, nn.BatchNorm2d): 49 | return nn.InstanceNorm2d(m.num_features, affine=affine) 50 | return m 51 | 52 | return tnn.utils.edit_model(self, to_instancenorm) 53 | 54 | def to_binomial_downsampling(self) -> 'PatchDiscriminator': 55 | for m in self.features.modules(): 56 | if isinstance(m, tnn.ConvBlock): 57 | if m.conv.stride[0] != 2: 58 | continue 59 | tnn.utils.insert_before(m, 'conv', 
BinomialFilter2d(2), 'pool') 60 | m.conv.stride = (1, 1) 61 | return self 62 | 63 | def to_avg_pool(self) -> 'PatchDiscriminator': 64 | for m in self.features.modules(): 65 | if isinstance(m, tnn.ConvBlock): 66 | if m.conv.stride[0] != 2: 67 | continue 68 | tnn.utils.insert_before(m, 'conv', nn.AvgPool2d(2), 'pool') 69 | m.conv.stride = (1, 1) 70 | return self 71 | 72 | def set_input_specs(self, in_channels: int) -> 'PatchDiscriminator': 73 | c = self.features[0].conv 74 | assert isinstance(c, nn.Conv2d) 75 | self.features[0].conv = kaiming(nn.Conv2d(in_channels, 76 | c.out_channels, 77 | 4, 78 | stride=2, 79 | padding=c.padding, 80 | bias=c.bias is not None), 81 | a=0.2) 82 | return self 83 | 84 | def set_kernel_size(self, kernel_size: int) -> 'PatchDiscriminator': 85 | 86 | def change_ks(m): 87 | if isinstance(m, nn.Conv2d) and m.kernel_size[0] != 1: 88 | return kaiming( 89 | nn.Conv2d(m.in_channels, 90 | m.out_channels, 91 | kernel_size, 92 | m.stride, 93 | padding=kernel_size // 2)) 94 | return m 95 | 96 | tnn.utils.edit_model(self.features, change_ks) 97 | return self 98 | 99 | 100 | def patch286() -> PatchDiscriminator: 101 | """ 102 | Patch Discriminator from pix2pix 103 | """ 104 | return PatchDiscriminator([64, 128, 256, 512, 512, 512]) 105 | 106 | 107 | def patch70() -> PatchDiscriminator: 108 | """ 109 | Patch Discriminator from pix2pix 110 | """ 111 | return PatchDiscriminator([64, 128, 256, 512]) 112 | 113 | 114 | def patch34() -> PatchDiscriminator: 115 | """ 116 | Patch Discriminator from pix2pix 117 | """ 118 | return PatchDiscriminator([64, 128, 256]) 119 | 120 | 121 | def patch16() -> PatchDiscriminator: 122 | """ 123 | Patch Discriminator from pix2pix 124 | """ 125 | return PatchDiscriminator([64, 128]) 126 | -------------------------------------------------------------------------------- /torchelie/loss/neuralstyleloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import random 5 | from typing import Dict, Optional, List, cast, Tuple 6 | 7 | from torchelie.utils import bgram 8 | 9 | import torchelie as tch 10 | import torchelie.utils as tu 11 | from torchelie.nn import ImageNetInputNorm, WithSavedActivations 12 | from torchelie.models import PerceptualNet 13 | 14 | 15 | class NeuralStyleLoss(nn.Module): 16 | """ 17 | Style Transfer loss by Leon Gatys 18 | 19 | https://arxiv.org/abs/1508.06576 20 | 21 | set the style and content before performing a forward pass. 22 | """ 23 | net: PerceptualNet 24 | 25 | def __init__(self) -> None: 26 | super(NeuralStyleLoss, self).__init__() 27 | self.style_layers = [ 28 | 'conv1_1', 29 | 'conv2_1', 30 | 'conv3_1', 31 | 'conv4_1', 32 | 'conv5_1', 33 | ] 34 | self.content_layers = ['conv3_2'] 35 | self.content = {} 36 | self.style_maps = {} 37 | self.net = PerceptualNet(self.style_layers + self.content_layers, 38 | remove_unused_layers=False) 39 | self.norm = ImageNetInputNorm() 40 | tu.freeze(self.net) 41 | 42 | def get_style_content_(self, img: torch.Tensor, 43 | detach: bool) -> Dict[str, Dict[str, torch.Tensor]]: 44 | activations: Dict[str, torch.Tensor] 45 | 46 | _, activations = self.net(self.norm(img), detach=detach) 47 | 48 | # this ain't a bug. This normalization is freakin *everything*. 
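        # (Instance-normalizing each feature map keeps every layer's
        # contribution to the style and content terms at a comparable scale.)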
49 | activations = { 50 | k: F.instance_norm(a.float()) for k, a in activations.items() 51 | } 52 | 53 | return activations 54 | 55 | def set_style(self, 56 | style_img: torch.Tensor, 57 | style_ratio: float, 58 | style_layers: Optional[List[str]] = None) -> None: 59 | """ 60 | Set the style. 61 | 62 | Args: 63 | style_img (3xHxW tensor): an image tensor 64 | style_ratio (float): a multiplier for the style loss to make it 65 | greater or smaller than the content loss 66 | style_layer (list of str, optional): the layers on which to compute 67 | the style, or `None` to keep them unchanged 68 | """ 69 | if style_layers is not None: 70 | self.style_layers = style_layers 71 | self.net.set_keep_layers(names=self.style_layers + 72 | self.content_layers) 73 | 74 | self.ratio = torch.tensor(style_ratio) 75 | 76 | with torch.no_grad(): 77 | out = self.get_style_content_(style_img, detach=True) 78 | self.style_maps = {k: bgram(out[k]) for k in self.style_layers} 79 | 80 | def set_content(self, 81 | content_img: torch.Tensor, 82 | content_layers: Optional[List[str]] = None) -> None: 83 | """ 84 | Set the content. 85 | 86 | Args: 87 | content_img (3xHxW tensor): an image tensor 88 | content_layer (str, optional): the layer on which to compute the 89 | content representation, or `None` to keep it unchanged 90 | """ 91 | if content_layers is not None: 92 | self.content_layers = content_layers 93 | self.net.set_keep_layers(names=self.style_layers + 94 | self.content_layers) 95 | 96 | with torch.no_grad(): 97 | out = self.get_style_content_(content_img, detach=True) 98 | 99 | self.content = {a: out[a] for a in self.content_layers} 100 | 101 | def forward( 102 | self, 103 | input_img: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, float]]: 104 | """ 105 | Actually compute the loss 106 | """ 107 | out = self.get_style_content_(input_img, detach=False) 108 | 109 | c_ratio = 1. - self.ratio.squeeze() 110 | s_ratio = self.ratio.squeeze() 111 | 112 | style_loss = sum( 113 | F.l1_loss(self.style_maps[a], bgram(out[a])) 114 | for a in self.style_layers) / len(self.style_maps) 115 | 116 | content_loss = sum( 117 | F.mse_loss(self.content[a], out[a]) 118 | for a in self.content_layers) / len(self.content_layers) 119 | 120 | loss = c_ratio * content_loss + s_ratio * style_loss 121 | 122 | return loss, { 123 | 'style': style_loss.item(), 124 | 'content': content_loss.item() 125 | } 126 | -------------------------------------------------------------------------------- /torchelie/recipes/deepdream.py: -------------------------------------------------------------------------------- 1 | """ 2 | Deep Dream recipe. 3 | 4 | Performs the algorithm described in 5 | https://ai.googleblog.com/2015/06/inceptionism-going-deeper-into-neural.html 6 | 7 | This implementation differs from the original one: the image is optimized in 8 | Fourier space, for greater details and colors, the model and layers are 9 | customizable. 10 | 11 | A commandline interface is provided through `python3 -m 12 | torchelie.recipes.deepdream`, and a DeepDreamRecipe is provided. 
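A typical invocation (the image paths here are purely illustrative; the flags are
the ones defined in the argparse block at the bottom of this file) looks like::

    python3 -m torchelie.recipes.deepdream --input face.jpg --out dream.jpg \
        --model googlenet --iters 4000 --device cuda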
13 | """ 14 | import random 15 | 16 | import torch 17 | import torchvision.transforms as TF 18 | 19 | import torchelie.nn as tnn 20 | from torchelie.data_learning import ParameterizedImg 21 | from torchelie.loss.deepdreamloss import DeepDreamLoss 22 | from torchelie.optim import DeepDreamOptim 23 | from torchelie.recipes.recipebase import Recipe 24 | import torchelie.callbacks as tcb 25 | 26 | from PIL import Image 27 | 28 | 29 | class DeepDream(torch.nn.Module): 30 | """ 31 | Deep Dream recipe 32 | 33 | First instantiate the recipe then call `recipe(n_iter, img)` 34 | 35 | Args: 36 | model (nn.Module): the trained model to use 37 | dream_layer (str): the layer to use on which activations will be 38 | maximized 39 | """ 40 | 41 | def __init__(self, model, dream_layer): 42 | super(DeepDream, self).__init__() 43 | self.loss = DeepDreamLoss(model, dream_layer) 44 | self.norm = tnn.ImageNetInputNorm() 45 | 46 | def fit(self, ref, iters, lr=3e-4, device='cpu', visdom_env='deepdream'): 47 | """ 48 | Args: 49 | lr (float, optional): the learning rate 50 | visdom_env (str or None): the name of the visdom env to use, or None 51 | to disable Visdom 52 | """ 53 | ref_tensor = TF.ToTensor()(ref).unsqueeze(0) 54 | canvas = ParameterizedImg(1, 3, 55 | ref_tensor.shape[2], 56 | ref_tensor.shape[3], 57 | init_img=ref_tensor, 58 | space='spectral', 59 | colors='uncorr') 60 | 61 | def forward(_): 62 | img = canvas() 63 | rnd = random.randint(0, 10) 64 | loss = self.loss(self.norm(img[:, :, rnd:, rnd:])) 65 | loss.backward() 66 | return {'loss': loss, 'img': img} 67 | 68 | loop = Recipe(forward, range(iters)) 69 | loop.register('model', self) 70 | loop.register('canvas', canvas) 71 | loop.callbacks.add_callbacks([ 72 | tcb.Counter(), 73 | tcb.Log('loss', 'loss'), 74 | tcb.Log('img', 'img'), 75 | tcb.Optimizer(DeepDreamOptim(canvas.parameters(), lr=lr)), 76 | tcb.VisdomLogger(visdom_env=visdom_env, log_every=10), 77 | tcb.StdoutLogger(log_every=10) 78 | ]) 79 | loop.to(device) 80 | loop.run(1) 81 | return canvas.render().cpu() 82 | 83 | 84 | if __name__ == '__main__': 85 | import argparse 86 | import torchvision.models as tvmodels 87 | 88 | models = { 89 | 'vgg': { 90 | 'ctor': tvmodels.vgg19, 91 | 'layer': 'features.28' 92 | }, 93 | 'inception': { 94 | 'ctor': tvmodels.inception_v3, 95 | 'layer': 'Mixed_6c' 96 | }, 97 | 'googlenet': { 98 | 'ctor': tvmodels.googlenet, 99 | 'layer': 'inception4c' 100 | }, 101 | 'resnet': { 102 | 'ctor': tvmodels.resnet18, 103 | 'layer': 'layer3' 104 | } 105 | } 106 | parser = argparse.ArgumentParser(description="DeepDream") 107 | parser.add_argument('--input', required=True) 108 | parser.add_argument('--out', required=True) 109 | parser.add_argument('--device', default='cuda') 110 | parser.add_argument('--model', default='googlenet', choices=models.keys()) 111 | parser.add_argument('--lr', default=3e-4, type=float) 112 | parser.add_argument('--iters', default=4000, type=int) 113 | parser.add_argument('--dream-layer') 114 | parser.add_argument('--visdom-env') 115 | args = parser.parse_args() 116 | 117 | model = models[args.model]['ctor'](pretrained=True) 118 | 119 | print(model) 120 | img = Image.open(args.input) 121 | dd = DeepDream(model, args.dream_layer or models[args.model]['layer']) 122 | out = dd.fit(img, 123 | args.iters, 124 | lr=args.lr, 125 | device=args.device, 126 | visdom_env=args.visdom_env) 127 | 128 | TF.ToPILImage()(out).save(args.out) 129 | -------------------------------------------------------------------------------- /torchelie/nn/llm.py: 
-------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torchelie.utils as tu 7 | 8 | 9 | class Rotary(torch.nn.Module): 10 | _cache = {} 11 | 12 | def __init__(self, dim, base=10000): 13 | super().__init__() 14 | self.dim = dim 15 | self.base = base 16 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) 17 | self.inv_freq = nn.Buffer(inv_freq) 18 | 19 | def _cache_key(self, device): 20 | return (self.dim, self.base, device) 21 | 22 | def forward(self, q, k, v, seq_dim=-2): 23 | seq_len = q.shape[seq_dim] 24 | device = q.device 25 | key = self._cache_key(device) 26 | cos_cached, sin_cached, cached_len = self._cache.get(key, (None, None, 0)) 27 | needed_len = seq_len 28 | if cached_len < needed_len or cos_cached is None: 29 | t = torch.arange(needed_len, device=device).type_as(self.inv_freq) 30 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 31 | emb = torch.cat((freqs, freqs), dim=-1).to(device) 32 | cos_cached = emb.cos() 33 | sin_cached = emb.sin() 34 | cached_len = needed_len 35 | self._cache[key] = (cos_cached, sin_cached, cached_len) 36 | cos, sin = cos_cached[:seq_len], sin_cached[:seq_len] 37 | ndim = q.ndim 38 | seq_dim = seq_dim % ndim 39 | # default layout already matches (seq_len, dim) 40 | if seq_dim != ndim - 2: 41 | shape = [1] * ndim 42 | shape[seq_dim] = seq_len 43 | shape[-1] = cos.shape[-1] 44 | cos = cos.reshape(shape) 45 | sin = sin.reshape(shape) 46 | 47 | return self.apply_rotary_pos_emb(q, k, v, cos, sin) 48 | 49 | # rotary pos emb helpers: 50 | 51 | def rotate_half(self, x: torch.Tensor) -> torch.Tensor: 52 | x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] 53 | return torch.cat( 54 | (-x2, x1), dim=x1.ndim - 1 55 | ) # dim=-1 triggers a bug in torch < 1.8.0 56 | 57 | def apply_rotary_pos_emb( 58 | self, 59 | q: torch.Tensor, 60 | k: torch.Tensor, 61 | v: torch.Tensor, 62 | cos: torch.Tensor, 63 | sin: torch.Tensor, 64 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 65 | return ( 66 | (q * cos) + (self.rotate_half(q) * sin), 67 | (k * cos) + (self.rotate_half(k) * sin), 68 | v, 69 | ) 70 | 71 | 72 | class SelfAttention(nn.Module): 73 | """ 74 | Self-attention layer 75 | Assumes input of shape (b, l, hidden_size). Uses scaled dot-product 76 | attention and rotary positional embeddings. 77 | 78 | Args: 79 | hidden_size (int): size of the hidden dimension 80 | num_heads (int): number of heads 81 | head_size (int): size of each head 82 | causal (bool, optional): whether to apply causal masking. Defaults to True. 83 | rotary (bool, optional): whether to apply RoPE. Defaults to True. 
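        Example (a minimal shape-check sketch; the sizes below are illustrative,
        not prescribed by the layer)::

            import torch
            from torchelie.nn.llm import SelfAttention

            attn = SelfAttention(hidden_size=512, num_heads=8, head_size=64)
            x = torch.randn(2, 128, 512)   # (batch, sequence length, hidden_size)
            y = attn(x)                    # -> (2, 128, 512)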
84 | """ 85 | 86 | def __init__(self, hidden_size, num_heads, head_size, causal=True, rotary=True): 87 | super().__init__() 88 | self.num_heads = num_heads 89 | self.head_size = head_size 90 | self.qkv = tu.kaiming( 91 | nn.Linear(hidden_size, head_size * num_heads * 3, bias=False) 92 | ) 93 | self.g = tu.kaiming(nn.Linear(hidden_size, num_heads)) 94 | self.fc = tu.xavier(nn.Linear(head_size * num_heads, hidden_size, bias=False)) 95 | self.rotary = Rotary(head_size) if rotary else None 96 | self.causal = causal 97 | 98 | def forward(self, x): 99 | b, l, h, d = x.shape[0], x.shape[1], self.num_heads, self.head_size 100 | # bld -> (q/k/v)bhld 101 | qkv = self.qkv(x).reshape(b, l, 3, h, d).permute(2, 0, 3, 1, 4) 102 | q, k, v = qkv[0], qkv[1], qkv[2] 103 | 104 | with torch.autocast(x.device.type, dtype=torch.float): 105 | if self.rotary is not None: 106 | q, k, v = self.rotary(q.float(), k.float(), v.float()) 107 | 108 | g = self.g(x).permute(0, 2, 1).view(b, h, l, 1) # blh -> bhl1 109 | att = nn.functional.scaled_dot_product_attention( 110 | q, k, v, is_causal=self.causal 111 | ) * torch.sigmoid(g) 112 | att = att.to(x.dtype) 113 | # bhld -> blhd 114 | att = att.permute(0, 2, 1, 3).contiguous().reshape(b, l, h * d) 115 | return self.fc(att) 116 | 117 | def extra_repr(self): 118 | return f"hidden_size={self.qkv.in_features}, num_heads={self.num_heads}, head_size={self.head_size}, causal={self.causal}" 119 | -------------------------------------------------------------------------------- /torchelie/datasets/concat.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torchelie.utils import indent 3 | from typing import List, Tuple, Any, Generic, TypeVar, Sequence 4 | from typing import overload 5 | 6 | T = TypeVar('T') 7 | U = TypeVar('U') 8 | 9 | 10 | class CatedSamples(Generic[T]): 11 | 12 | def __init__(self, samples: List[List[Tuple[T, int]]]) -> None: 13 | self.samples = samples 14 | self.n_classes = [ 15 | len(set(samp[1] for samp in sample)) for sample in samples 16 | ] 17 | 18 | def __len__(self) -> int: 19 | return sum(len(ds) for ds in self.samples) 20 | 21 | def __getitem__(self, i: int) -> Tuple[T, int]: 22 | class_offset = 0 23 | for samp, n_class in zip(self.samples, self.n_classes): 24 | if i < len(samp): 25 | return samp[i][0], samp[i][1] + class_offset 26 | i -= len(samp) 27 | class_offset += n_class 28 | raise IndexError 29 | 30 | 31 | class CatedLists(Sequence[T]): 32 | 33 | def __init__(self, ls: List[List[T]]) -> None: 34 | self.ls = ls 35 | 36 | def __len__(self) -> int: 37 | return sum([len(ds) for ds in self.ls]) 38 | 39 | @overload 40 | def __getitem__(self, i: int) -> T: 41 | ... 42 | 43 | @overload 44 | def __getitem__(self, i: slice) -> Sequence[T]: 45 | ... 46 | 47 | def __getitem__(self, i): 48 | if isinstance(i, slice): 49 | return [self[ii] for ii in range(*i.indices(len(self)))] 50 | 51 | for catedlist in self.ls: 52 | if i < len(catedlist): 53 | return catedlist[i] 54 | i -= len(catedlist) 55 | raise IndexError 56 | 57 | 58 | class HorizontalConcatDataset(Dataset): 59 | """ 60 | Concatenates multiple datasets. However, while torchvision's ConcatDataset 61 | just concatenates samples, torchelie's also relabels classes. While a 62 | vertical concat like torchvision's is useful to add more examples per 63 | class, an horizontal concat merges datasets to more classes. 
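    For instance (an illustrative sketch; the two dataset variables stand for any
    ImageFolder-style datasets and are hypothetical)::

        ds = HorizontalConcatDataset([cats_vs_dogs, flowers])
        # len(ds.classes) == len(cats_vs_dogs.classes) + len(flowers.classes)
        # labels coming from `flowers` are shifted by len(cats_vs_dogs.classes)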
64 | 65 | Args: 66 | datasets (list of Dataset): the datasets to concatenate 67 | """ 68 | classes: CatedLists[str] 69 | 70 | def __init__(self, datasets: List) -> None: 71 | self.datasets = datasets 72 | 73 | self.classes = CatedLists([ds.classes for ds in datasets]) 74 | self.samples = CatedSamples([ds.samples for ds in datasets]) 75 | self.class_to_idx = {nm: i for i, nm in enumerate(self.classes)} 76 | 77 | def __len__(self) -> int: 78 | return len(self.samples) 79 | 80 | def __getitem__(self, i: int) -> Tuple[Any, int]: 81 | class_offset = 0 82 | for ds in self.datasets: 83 | if i < len(ds): 84 | x, t = ds[i] 85 | return x, t + class_offset 86 | i -= len(ds) 87 | class_offset += len(ds.classes) 88 | raise IndexError 89 | 90 | def __repr__(self) -> str: 91 | return "DatasetConcat:\n" + '\n--\n'.join( 92 | [indent(repr(d)) for d in self.datasets]) 93 | 94 | 95 | class MergedSamples: 96 | 97 | def __init__(self, ds) -> None: 98 | self.ds = ds 99 | 100 | def __len__(self) -> int: 101 | return sum(len(d) for d in self.ds.datasets) 102 | 103 | def __getitem__(self, i: int): 104 | for ds in self.ds.datasets: 105 | if i < len(ds): 106 | x, y, *ys = ds.samples[i] 107 | return [x, self.ds.class_to_idx[ds.classes[y]]] + ys 108 | i -= len(ds) 109 | raise IndexError 110 | 111 | 112 | class MergedDataset(Dataset): 113 | 114 | def __init__(self, datasets, transform=None): 115 | self.datasets = datasets 116 | self.classes = list(set(c for d in datasets for c in d.classes)) 117 | self.classes.sort() 118 | self.class_to_idx = {c: i for i, c in enumerate(self.classes)} 119 | 120 | self.samples = MergedSamples(self) 121 | self.transform = transform or (lambda x: x) 122 | 123 | def __len__(self): 124 | return sum(len(d) for d in self.datasets) 125 | 126 | def __getitem__(self, i): 127 | for ds in self.datasets: 128 | if i < len(ds): 129 | x, y, *ys = ds[i] 130 | return [self.transform(x), self.class_to_idx[ds.classes[y]] 131 | ] + ys 132 | i -= len(ds) 133 | raise IndexError 134 | 135 | def __repr__(self): 136 | return "MergedDatasets: \n" + ( 137 | f" num_samples: {len(self)}\n" 138 | f" num_classes: {len(self.classes)}\n") + "\n".join( 139 | [indent(repr(ds)) for ds in self.datasets]) 140 | -------------------------------------------------------------------------------- /torchelie/nn/adain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchelie.utils as tu 4 | 5 | from typing import Optional 6 | 7 | 8 | class AdaIN2d(nn.Module): 9 | """ 10 | Adaptive InstanceNormalization from `*Arbitrary Style Transfer in Real-time 11 | with Adaptive Instance Normalization* (Huang et al, 2017) 12 | `_ 13 | 14 | Args: 15 | channels (int): number of input channels 16 | cond_channels (int): number of conditioning channels from which bias 17 | and scale will be derived 18 | """ 19 | 20 | weight: Optional[torch.Tensor] 21 | bias: Optional[torch.Tensor] 22 | 23 | def __init__(self, channels: int, cond_channels: int) -> None: 24 | super(AdaIN2d, self).__init__() 25 | self.make_weight = nn.Linear(cond_channels, channels) 26 | self.make_bias = nn.Linear(cond_channels, channels) 27 | self.weight = None 28 | self.bias = None 29 | 30 | def forward( 31 | self, x: torch.Tensor, z: Optional[torch.Tensor] = None 32 | ) -> torch.Tensor: 33 | """ 34 | Forward pass 35 | 36 | Args: 37 | x (4D tensor): input tensor 38 | z (2D tensor, optional): conditioning vector. 
If not present, 39 | :code:`condition(z)` must be called first 40 | 41 | Returns: 42 | x, renormalized 43 | """ 44 | if z is not None: 45 | self.condition(z) 46 | 47 | m = x.mean(dim=(2, 3), keepdim=True) 48 | s = torch.sqrt(x.var(dim=(2, 3), keepdim=True) + 1e-8) 49 | 50 | z_w = self.weight 51 | z_b = self.bias 52 | assert ( 53 | z_w is not None and z_b is not None 54 | ), "AdaIN did not receive a conditioning vector yet" 55 | weight = z_w / (s + 1e-5) 56 | bias = -m * weight + z_b 57 | out = weight * x + bias 58 | return out 59 | 60 | def condition(self, z: torch.Tensor) -> None: 61 | """ 62 | Conditions the layer before the forward pass if z will not be present 63 | when calling forward 64 | 65 | Args: 66 | z (2D tensor, optional): conditioning vector 67 | """ 68 | self.weight = self.make_weight(z)[:, :, None, None] + 1 69 | self.bias = self.make_bias(z)[:, :, None, None] 70 | 71 | 72 | class FiLM2d(nn.Module): 73 | """ 74 | Feature-wise Linear Modulation from 75 | https://distill.pub/2018/feature-wise-transformations/ 76 | The difference with AdaIN is that FiLM does not uses the input's mean and 77 | std in its calculations 78 | 79 | Args: 80 | channels (int): number of input channels 81 | cond_channels (int): number of conditioning channels from which bias 82 | and scale will be derived 83 | """ 84 | 85 | weight: Optional[torch.Tensor] 86 | bias: Optional[torch.Tensor] 87 | 88 | def __init__(self, channels: int, cond_channels: int): 89 | super(FiLM2d, self).__init__() 90 | self.make_weight = nn.Sequential( 91 | tu.constant_init(nn.Linear(cond_channels, channels), 0.02) 92 | ) 93 | self.make_weight[-1].bias.data.fill_(1.0) 94 | 95 | self.make_bias = nn.Sequential( 96 | tu.constant_init(nn.Linear(cond_channels, channels, bias=False), 0.02) 97 | ) 98 | 99 | self.weight = None 100 | self.bias = None 101 | 102 | def forward(self, x, z: Optional[torch.Tensor] = None) -> torch.Tensor: 103 | """ 104 | Forward pass 105 | 106 | Args: 107 | x (4D tensor): input tensor 108 | z (2D tensor, optional): conditioning vector. 
If not present, 109 | :code:`condition(z)` must be called first 110 | 111 | Returns: 112 | x, conditioned 113 | """ 114 | if z is not None: 115 | self.condition(z) 116 | 117 | w = self.weight 118 | assert w is not None 119 | x = x * w 120 | 121 | b = self.bias 122 | if b is not None: 123 | x = x + b 124 | return x 125 | 126 | def condition(self, z: torch.Tensor) -> None: 127 | """ 128 | Conditions the layer before the forward pass if z will not be present 129 | when calling forward 130 | 131 | Args: 132 | z (2D tensor, optional): conditioning vector 133 | """ 134 | self.weight = self.make_weight(z)[:, :, None, None] 135 | self.bias = self.make_bias(z)[:, :, None, None] 136 | 137 | 138 | class FiLM(FiLM2d): 139 | def condition(self, z: torch.Tensor) -> None: 140 | """ 141 | Conditions the layer before the forward pass if z will not be present 142 | when calling forward 143 | 144 | Args: 145 | z (2D tensor, optional): conditioning vector 146 | """ 147 | self.weight = self.make_weight(z) 148 | self.bias = self.make_bias(z) 149 | -------------------------------------------------------------------------------- /torchelie/models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchelie.utils as tu 3 | import torch.nn as nn 4 | import torchelie.nn as tnn 5 | from torchelie.models import ClassificationHead 6 | from typing import Optional 7 | from collections import OrderedDict 8 | 9 | Block = tnn.PreactResBlockBottleneck 10 | 11 | 12 | class UBlock(nn.Module): 13 | inner: Optional[nn.Module] 14 | skip: Optional[nn.Module] 15 | encode: nn.Module 16 | 17 | @tu.experimental 18 | def __init__(self, 19 | ch: int, 20 | inner: Optional[nn.Module], 21 | with_skip: bool = True) -> None: 22 | super(UBlock, self).__init__() 23 | self.inner = inner 24 | if with_skip and inner is not None: 25 | self.skip = Block(ch, ch) 26 | else: 27 | self.skip = None 28 | self.encode = tnn.CondSeq(nn.MaxPool2d(3, 1, 1), 29 | nn.UpsamplingBilinear2d(scale_factor=0.5), 30 | Block(ch, ch)) 31 | self.decode = tnn.CondSeq(Block(ch, ch), 32 | nn.UpsamplingBilinear2d(scale_factor=2)) 33 | 34 | def forward(self, x: torch.Tensor) -> torch.Tensor: 35 | e = self.encode(x) 36 | if self.inner is not None: 37 | e2 = self.inner(e) 38 | else: 39 | e2 = e 40 | 41 | if self.skip is not None: 42 | e2 += self.skip(e) 43 | 44 | return self.decode(e2) 45 | 46 | 47 | class UBlock1(nn.Module): 48 | @tu.experimental 49 | def __init__(self, ch): 50 | super(UBlock1, self).__init__() 51 | self.inner = tnn.CondSeq(nn.MaxPool2d(3, 1, 1), 52 | nn.UpsamplingBilinear2d(scale_factor=0.5), 53 | Block(ch, ch), 54 | nn.UpsamplingBilinear2d(scale_factor=2)) 55 | 56 | def forward(self, x): 57 | return self.inner(x) 58 | 59 | 60 | class AttentionBlock(nn.Module): 61 | mask: Optional[tnn.CondSeq] 62 | 63 | def __init__(self, 64 | ch: int, 65 | n_down: int, 66 | n_trunk: int = 2, 67 | n_post: int = 1, 68 | n_pre: int = 1, 69 | n_att_conv: int = 2, 70 | with_skips: bool = True) -> None: 71 | super(AttentionBlock, self).__init__() 72 | self.pre = tnn.CondSeq(*[Block(ch, ch) for _ in range(n_pre)]) 73 | self.post = tnn.CondSeq(*[Block(ch, ch) for _ in range(n_post)]) 74 | self.trunk = tnn.CondSeq(*[Block(ch, ch) for _ in range(n_trunk)]) 75 | 76 | soft: nn.Module = UBlock1(ch) 77 | for _ in range(n_down - 1): 78 | soft = UBlock(ch, soft, with_skip=with_skips) 79 | 80 | if n_down >= 0: 81 | conv1 = [soft] 82 | for i in range(n_att_conv): 83 | conv1 += [ 84 | nn.BatchNorm2d(ch), 85 | 
nn.ReLU(True),
 86 |                     tu.kaiming(tnn.Conv1x1(ch, ch, bias=(i != n_att_conv - 1)))
 87 |                 ]
 88 |             conv1.append(nn.Sigmoid())
 89 | 
 90 |             self.mask = tnn.CondSeq(*conv1)
 91 |         else:
 92 |             self.mask = None
 93 | 
 94 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
 95 |         x = self.pre(x)
 96 |         t = self.trunk(x)
 97 |         if self.mask is not None:
 98 |             t = t * (self.mask(x) + 1)
 99 |         return self.post(t)
100 | 
101 | 
102 | class Attention56Bone(tnn.CondSeq):
103 |     """
104 |     Attention56 bone
105 | 
106 |     Args:
107 |         num_classes (int): number of classes
108 |     """
109 |     @tu.experimental
110 |     def __init__(self, num_classes: int) -> None:
111 |         super(Attention56Bone, self).__init__(
112 |             OrderedDict([
113 |                 ('head',
114 |                  tnn.CondSeq(tu.kaiming(tnn.Conv2d(3, 64, 7, stride=2)),
115 |                              nn.ReLU(True), nn.MaxPool2d(3, 2, 1))),
116 |                 ('pre1', Block(64, 256)), ('attn1', AttentionBlock(256, 3)),
117 |                 ('pre2', Block(256, 512, stride=2)),
118 |                 ('attn2', AttentionBlock(512, 2)),
119 |                 ('pre3', Block(512, 1024, stride=2)),
120 |                 ('attn3', AttentionBlock(1024, 1)),
121 |                 ('pre4',
122 |                  tnn.CondSeq(
123 |                      Block(1024, 2048, stride=2),
124 |                      Block(2048, 2048),
125 |                      Block(2048, 2048),
126 |                  )),
127 |                 ('classifier', ClassificationHead(2048, num_classes))
128 |             ]))
129 | 
130 | 
131 | @tu.experimental
132 | def attention56(num_classes):
133 |     """
134 |     Build an attention56 network
135 | 
136 |     Args:
137 |         num_classes (int): number of classes
138 | 
139 |     """
140 |     return Attention56Bone(num_classes)
141 | 
--------------------------------------------------------------------------------
/torchelie/datasets/debug.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | from typing import Optional, Callable
  3 | from typing_extensions import Literal
  4 | import torch
  5 | import torchvision.transforms as TF
  6 | from torch.utils.data import Dataset
  7 | from torchvision.datasets import ImageFolder
  8 | 
  9 | from torchvision.datasets.utils import download_and_extract_archive
 10 | 
 11 | __all__ = ['ColoredColumns', 'ColoredRows', 'Imagenette', 'Imagewoof']
 12 | 
 13 | 
 14 | class ColoredColumns(Dataset):
 15 |     """
 16 |     A dataset of procedurally generated images of columns randomly colorized.
 17 | 
 18 |     Args:
 19 |         *size (int): size of images
 20 |         transform (transforms or None): the image transforms to apply to the
 21 |             generated pictures
 22 |     """
 23 | 
 24 |     def __init__(self, *size, transform=None) -> None:
 25 |         super(ColoredColumns, self).__init__()
 26 |         self.size = size
 27 |         self.transform = transform if transform is not None else (lambda x: x)
 28 | 
 29 |     def __len__(self):
 30 |         return 10000
 31 | 
 32 |     def __getitem__(self, i):
 33 |         cols = torch.randint(0, 255, (3, 1, self.size[1]))
 34 |         expanded = cols.expand(3, *self.size).float()
 35 |         img = TF.ToPILImage()(expanded / 255)
 36 |         return self.transform(img), 0
 37 | 
 38 | 
 39 | class ColoredRows(Dataset):
 40 |     """
 41 |     A dataset of procedurally generated images of rows randomly colorized.
 42 | 
 43 |     Args:
 44 |         *size (int): size of images
 45 |         transform (transforms or None): the image transforms to apply to the
 46 |             generated pictures
 47 |     """
 48 | 
 49 |     def __init__(self, *size, transform=None) -> None:
 50 |         super(ColoredRows, self).__init__()
 51 |         self.size = size
 52 |         self.transform = transform if transform is not None else (lambda x: x)
 53 | 
 54 |     def __len__(self):
 55 |         return 10000
 56 | 
 57 |     def __getitem__(self, i):
 58 |         rows = torch.randint(0, 255, (3, self.size[0], 1))
 59 |         expanded = rows.expand(3, *self.size).float()
 60 |         img = TF.ToPILImage()(expanded / 255)
 61 |         return self.transform(img), 0
 62 | 
 63 | 
 64 | class Imagenette(ImageFolder):
 65 |     """
 66 |     Imagenette by Jeremy Howard ( https://github.com/fastai/imagenette ).
 67 | 
 68 |     Args:
 69 |         root (str): root directory
 70 |         train (bool): if True, use the train split, else the validation split
 71 |         transform (Callable): image transforms
 72 |         download (bool): if True and root empty, download the dataset
 73 |         version (str): which resolution to download ('full', '320px', '160px')
 74 | 
 75 |     """
 76 | 
 77 |     def __init__(self,
 78 |                  root: str,
 79 |                  train: bool,
 80 |                  transform: Optional[Callable] = None,
 81 |                  download: bool = False,
 82 |                  version: Literal['full', '320px', '160px'] = '320px'):
 83 |         size = ({
 84 |             'full': 'imagenette2',
 85 |             '320px': 'imagenette2-320',
 86 |             '160px': 'imagenette2-160'
 87 |         })[version]
 88 | 
 89 |         split = 'train' if train else 'val'
 90 |         if not self._check_integrity(f'{root}/{size}') and download:
 91 |             download_and_extract_archive(
 92 |                 f'https://s3.amazonaws.com/fast-ai-imageclas/{size}.tgz',
 93 |                 root,
 94 |                 remove_finished=True)
 95 | 
 96 |         super().__init__(f'{root}/{size}/{split}', transform=transform)
 97 | 
 98 |     def _check_integrity(self, path):
 99 |         return os.path.exists(os.path.expanduser(path))
100 | 
101 | 
102 | class Imagewoof(ImageFolder):
103 |     """
104 |     Imagewoof by Jeremy Howard ( https://github.com/fastai/imagenette ).
105 | 
106 |     Args:
107 |         root (str): root directory
108 |         train (bool): if True, use the train split, else the validation split
109 |         transform (Callable): image transforms
110 |         download (bool): if True and root empty, download the dataset
111 |         version (str): which resolution to download ('full', '320px', '160px')
112 | 
113 |     """
114 | 
115 |     def __init__(self,
116 |                  root: str,
117 |                  train: bool,
118 |                  transform: Optional[Callable] = None,
119 |                  download: bool = False,
120 |                  version: Literal['full', '320px', '160px'] = '320px'):
121 |         size = ({
122 |             'full': 'imagewoof2',
123 |             '320px': 'imagewoof2-320',
124 |             '160px': 'imagewoof2-160'
125 |         })[version]
126 | 
127 |         split = 'train' if train else 'val'
128 |         if not self._check_integrity(f'{root}/{size}') and download:
129 |             download_and_extract_archive(
130 |                 f'https://s3.amazonaws.com/fast-ai-imageclas/{size}.tgz',
131 |                 root,
132 |                 remove_finished=True)
133 | 
134 |         super().__init__(f'{root}/{size}/{split}', transform=transform)
135 | 
136 |     def _check_integrity(self, path):
137 |         return os.path.exists(os.path.expanduser(path))
138 | 
--------------------------------------------------------------------------------
/torchelie/nn/transformer.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn.functional as F
  3 | import torch.nn as nn
  4 | import torchelie.utils as tu
  5 | from .conv import Conv1x1
  6 | from ..nn.condseq import CondSeq
  7 | from ..nn.llm import SelfAttention
  8 | from .functional.transformer import local_attention_2d
  9 | 
 10 | from typing import Optional
 11 | from typing_extensions import Literal
 12 | 
 13 | 
 14 | class LocalSelfAttentionHook(nn.Module):
 15 | 
 16 |     def forward(self, x, attn, pad):
 17 |         return x, attn, pad
 18 | 
 19 | 
 20 | class LocalSelfAttention2d(nn.Module):
 21 | 
 22 |     def __init__(self,
 23 |                  in_channels: int,
 24 |                  num_heads: int,
 25 |                  kernel_size: int,
 26 |                  hidden_channels: Optional[int] = None,
 27 |                  padding_mode: Literal['none', 'auto'] = 'none'):
 28 |         """
 29 |         Args:
 30 |             in_channels (int): number of input channels
 31 |             num_heads (int): how many self attention heads
 32 |             kernel_size (int): the self attention window size. Must divide
 33 |                 input size if padding_mode is 'none'.
 34 |             hidden_channels (int): how many channels *per head*.
 35 |             padding_mode (str): if 'none', no padding is used and kernel_size
 36 |                 must divide the input spatial size.
If 'auto', zero padding 37 | will be used to center the input feature map, and make it 38 | a multiple of kernel_size 39 | """ 40 | super().__init__() 41 | hidden_channels = hidden_channels or (in_channels // num_heads) 42 | self.hidden_channels = hidden_channels 43 | self.proj = tu.kaiming( 44 | Conv1x1(in_channels, hidden_channels * num_heads * 3, bias=False)) 45 | self.position = nn.Parameter( 46 | torch.zeros(num_heads, 2 * kernel_size, 2 * kernel_size)) 47 | self.out = tu.kaiming(Conv1x1(hidden_channels * num_heads, in_channels)) 48 | self.kernel_size = kernel_size 49 | self.num_heads = num_heads 50 | self.attn_hook = LocalSelfAttentionHook() 51 | self.padding_mode = padding_mode 52 | 53 | def unfolded_posenc(self): 54 | w = self.kernel_size 55 | pos = torch.tensor([[x, y] for x in range(w) for y in range(w)]) 56 | pos = pos[None, :, :] - pos[:, None, :] 57 | pos += w 58 | return self.position[:, pos[:, :, 0], pos[:, :, 1]] 59 | 60 | def forward(self, x: torch.Tensor) -> torch.Tensor: 61 | H, W, P = x.shape[2], x.shape[3], self.kernel_size 62 | pad = (-(W % -P), -(H % -P)) 63 | assert self.padding_mode == 'auto' or pad == (0, 0), ( 64 | f'kernel_size {self.kernel_size} does not divide size {W}x{H} ' 65 | 'and padding_mode="none" specified') 66 | pad = (pad[0] // 2, pad[0] - pad[0] // 2, pad[1] // 2, 67 | pad[1] - pad[1] // 2) 68 | if pad != (0, 0, 0, 0): 69 | x = F.pad(x, pad) 70 | 71 | x, attn = local_attention_2d(x, self.proj, self.unfolded_posenc(), 72 | self.num_heads, self.kernel_size) 73 | self.attn_hook(x, attn, pad) 74 | 75 | if pad != (0, 0, 0, 0): 76 | x = x[:, :, pad[2]:H + pad[2], pad[0]:W + pad[0]] 77 | 78 | x = self.out(x) 79 | return x 80 | 81 | 82 | class ViTBlock(nn.Module): 83 | """ 84 | Vision Transformer (ViT) block consisting of a self-attention layer and a feed-forward MLP, 85 | each followed by RMS normalization and gated residual connections. 86 | 87 | Args: 88 | d_model (int): Dimension of the model. 89 | num_heads (int): Number of attention heads. 90 | 91 | Forward Args: 92 | x (Tensor): Input tensor of shape [B, L, d_model]. 93 | z (Any): Optional conditioning input for CondSeq modules. 94 | 95 | Returns: 96 | Tensor: Output tensor of shape [B, L, d_model]. 97 | """ 98 | 99 | def __init__(self, d_model, num_heads): 100 | super().__init__() 101 | self.sa = CondSeq( 102 | nn.RMSNorm(d_model), 103 | SelfAttention( 104 | d_model, 105 | num_heads, 106 | head_size=d_model // num_heads, 107 | causal=False, 108 | rotary=True, 109 | ), 110 | ) 111 | self.mlp = CondSeq( 112 | nn.RMSNorm(d_model), 113 | tu.kaiming(nn.Linear(d_model, 4 * d_model)), 114 | nn.GELU(), 115 | tu.kaiming(nn.Linear(4 * d_model, d_model)), 116 | ) 117 | self.g1 = nn.Parameter(torch.zeros(d_model)) 118 | self.g2 = nn.Parameter(torch.zeros(d_model)) 119 | 120 | def forward(self, x): 121 | """ 122 | Forward pass for the ViTBlock. 123 | 124 | Args: 125 | x (Tensor): Input tensor of shape [B, L, d_model]. 126 | 127 | Returns: 128 | Tensor: Output tensor of shape [B, L, d_model]. 129 | """ 130 | x = self.sa(x) * self.g1 + x 131 | x = self.mlp(x) * self.g2 + x 132 | return x 133 | --------------------------------------------------------------------------------
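Usage sketch for ViTBlock (illustrative only: the shapes and the patch count are
assumptions, and the import path simply follows the file layout shown above):

    import torch
    from torchelie.nn.transformer import ViTBlock

    d_model, num_heads = 256, 8
    block = ViTBlock(d_model, num_heads)
    tokens = torch.randn(4, 196, d_model)  # e.g. a 14x14 grid of patch embeddings
    out = block(tokens)                    # -> (4, 196, 256), same shape as the input

Because g1 and g2 are initialised to zero, a freshly constructed block is exactly
the identity function, which eases optimization when many blocks are stacked.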