├── .github
│   └── workflows
│       ├── build_doc.yml
│       ├── deploy_pypi.yml
│       └── test.yml
├── .gitignore
├── LICENSE
├── README.rst
├── doc
│   ├── Makefile
│   ├── _static
│   │   ├── FLOP.pdf
│   │   ├── TRL.png
│   │   ├── cp-conv.png
│   │   ├── deep_tensor_nets_pros_circle.png
│   │   ├── domain-adaptation.png
│   │   ├── favicon
│   │   │   ├── android-chrome-192x192.png
│   │   │   ├── android-chrome-256x256.png
│   │   │   ├── apple-touch-icon.png
│   │   │   ├── favicon-16x16.png
│   │   │   ├── favicon-32x32.png
│   │   │   ├── favicon.ico
│   │   │   ├── safari-pinned-tab.svg
│   │   │   └── site.webmanifest
│   │   ├── logos
│   │   │   ├── logo_caltech.png
│   │   │   ├── logo_nvidia.png
│   │   │   └── tensorly-torch-logo.png
│   │   ├── mobilenet-conv.pdf
│   │   ├── mobilenet-v2-conv.pdf
│   │   ├── tensorly-torch-pyramid.png
│   │   ├── transduction.png
│   │   └── tucker-conv.pdf
│   ├── _templates
│   │   ├── class.rst
│   │   └── function.rst
│   ├── about.rst
│   ├── conf.py
│   ├── dev_guide
│   │   ├── api.rst
│   │   ├── contributing.rst
│   │   └── index.rst
│   ├── index.rst
│   ├── install.rst
│   ├── make.bat
│   ├── minify.py
│   ├── modules
│   │   └── api.rst
│   ├── requirements_doc.txt
│   └── user_guide
│       ├── factorized_conv.rst
│       ├── factorized_embeddings.rst
│       ├── factorized_tensors.rst
│       ├── index.rst
│       ├── tensor_hooks.rst
│       ├── tensorized_linear.rst
│       └── trl.rst
├── requirements.txt
├── setup.py
└── tltorch
    ├── __init__.py
    ├── factorized_layers
    │   ├── __init__.py
    │   ├── factorized_convolution.py
    │   ├── factorized_embedding.py
    │   ├── factorized_linear.py
    │   ├── tensor_contraction_layers.py
    │   ├── tensor_regression_layers.py
    │   └── tests
    │       ├── __init__.py
    │       ├── test_factorized_convolution.py
    │       ├── test_factorized_embedding.py
    │       ├── test_factorized_linear.py
    │       ├── test_tensor_contraction_layers.py
    │       └── test_trl.py
    ├── factorized_tensors
    │   ├── __init__.py
    │   ├── complex_factorized_tensors.py
    │   ├── complex_tensorized_matrices.py
    │   ├── core.py
    │   ├── factorized_tensors.py
    │   ├── init.py
    │   ├── tensorized_matrices.py
    │   └── tests
    │       ├── __init__.py
    │       └── test_factorizations.py
    ├── functional
    │   ├── __init__.py
    │   ├── convolution.py
    │   ├── factorized_linear.py
    │   ├── factorized_tensordot.py
    │   ├── linear.py
    │   ├── tensor_regression.py
    │   └── tests
    │       ├── __init__.py
    │       └── test_factorized_linear.py
    ├── tensor_hooks
    │   ├── __init__.py
    │   ├── _tensor_dropout.py
    │   ├── _tensor_lasso.py
    │   └── tests
    │       ├── __init__.py
    │       ├── test_tensor_dropout.py
    │       └── test_tensor_lasso.py
    └── utils
        ├── __init__.py
        ├── parameter_list.py
        └── tensorize_shape.py
/.github/workflows/build_doc.yml:
--------------------------------------------------------------------------------
1 | name: Build documentation
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | jobs:
9 | build:
10 |
11 | runs-on: ubuntu-latest
12 |
13 | steps:
14 | - name: Checkout code
15 | uses: actions/checkout@v3
16 | - name: Install Python
17 | uses: actions/setup-python@v4
18 | with:
19 | python-version: 3.9
20 | - name: Install dependencies
21 | run: |
22 | python -m pip install --upgrade pip
23 | python -m pip install -r requirements.txt
24 | python -m pip install -r doc/requirements_doc.txt
25 | python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
26 | - name: Install package
27 | run: |
28 | python -m pip install -e .
29 | - name: Make doc
30 | run: |
31 | cd doc
32 | python minify.py
33 | make html
34 | cd ..
35 | - name: Push docs
36 | run: |
37 | # See https://github.community/t/github-actions-bot-email-address/17204/5
38 | git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com"
39 | git config --global user.name "github-actions"
40 | git fetch origin gh-pages
41 | git checkout gh-pages
42 | git rm -r dev/*
43 | cp -r doc/_build/html/* dev
44 | git add dev
45 | # If the doc is up to date, the script shouldn't fail, hence --allow-empty
46 | # Might be a cleaner way to check
47 | git commit --allow-empty -m "Deployed to GitHub Pages"
48 | git push --force origin gh-pages
49 |
--------------------------------------------------------------------------------
/.github/workflows/deploy_pypi.yml:
--------------------------------------------------------------------------------
1 | name: Deploy to PyPI
2 |
3 | on:
4 | push:
5 | tags:
6 | - '*'
7 |
8 | jobs:
9 | build:
10 |
11 | runs-on: ubuntu-latest
12 |
13 | steps:
14 | - name: Checkout code
15 | uses: actions/checkout@v3
16 | - name: Install Python
17 | uses: actions/setup-python@v4
18 | with:
19 | python-version: 3.9
20 | - name: Install dependencies
21 | run: |
22 | python -m pip install --upgrade pip
23 | python -m pip install -r requirements.txt
24 | python -m pip install -r doc/requirements_doc.txt
25 | python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
26 | - name: Install package
27 | run: |
28 | python -m pip install -e .
29 | pip install setuptools wheel
30 | - name: Build a binary wheel and a source tarball
31 | run: |
32 | python setup.py sdist bdist_wheel
33 | - name: Publish package to TestPyPI
34 | uses: pypa/gh-action-pypi-publish@master
35 | with:
36 | user: __token__
37 | password: ${{ secrets.TEST_PYPI_PASSWORD }}
38 | repository_url: https://test.pypi.org/legacy/
39 | - name: Publish package to PyPI
40 | uses: pypa/gh-action-pypi-publish@master
41 | with:
42 | user: __token__
43 | password: ${{ secrets.PYPI_PASSWORD }}
44 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Test TensorLy-Torch
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: ubuntu-latest
9 |
10 | steps:
11 | - uses: actions/checkout@v3
12 | - name: Set up Python 3.12
13 | uses: actions/setup-python@v4
14 | with:
15 | python-version: 3.12
16 | - name: Install dependencies
17 | run: |
18 | echo "* Updating pip"
19 | python -m pip install --upgrade pip
20 | echo "* Installing requirements"
21 | python -m pip install -r requirements.txt
22 | echo "* Installing documentation requirements"
23 | python -m pip install -r doc/requirements_doc.txt
24 | echo "* Installing PyTorch"
25 | python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
26 | - name: Install TensorLy dev
27 | run: |
28 | CWD=`pwd`
29 | echo 'Cloning TensorLy in ${CWD}'
30 | mkdir git_repos
31 | cd git_repos
32 | git clone https://github.com/tensorly/tensorly
33 | cd tensorly
34 | python -m pip install -e .
35 | cd ..
36 | - name: Install package
37 | run: |
38 | python -m pip install -e .
39 | - name: Test with pytest
40 | run: |
41 | pytest -vvv tltorch
42 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.DS_Store
6 | *.vscode
7 |
8 | # C extensions
9 | *.so
10 | *.py~
11 |
12 | # Pycharm
13 | .idea
14 |
15 | # vim temp files
16 | *.swp
17 |
18 | # Sphinx doc
19 | doc/_build/
20 | doc/auto_examples/
21 | doc/modules/generated/
22 |
23 | # Distribution / packaging
24 | .Python
25 | env/
26 | build/
27 | develop-eggs/
28 | dist/
29 | downloads/
30 | eggs/
31 | .eggs/
32 | lib/
33 | lib64/
34 | parts/
35 | sdist/
36 | var/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 | .pytest_cache/
41 |
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 |
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 |
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .coverage
56 | .coverage.*
57 | .cache
58 | nosetests.xml
59 | coverage.xml
60 | *,cover
61 | .hypothesis/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | #Ipython Notebook
77 | .ipynb_checkpoints
78 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020, The TensorLy-Torch Developers
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | .. image:: https://badge.fury.io/py/tensorly-torch.svg
2 | :target: https://badge.fury.io/py/tensorly-torch
3 |
4 |
5 | ==============
6 | TensorLy-Torch
7 | ==============
8 |
9 | TensorLy-Torch is a Python library for deep tensor networks that
10 | builds on top of `TensorLy <http://tensorly.org/dev>`_
11 | and `PyTorch <https://pytorch.org>`_.
12 | It makes it easy to leverage tensor methods in a deep learning setting and comes with all batteries included.
13 |
14 | - **Website:** http://tensorly.org/torch/
15 | - **Source-code:** https://github.com/tensorly/torch
16 |
17 |
18 | With TensorLy-Torch, you can easily:
19 |
20 | - **Tensor Factorizations**: decomposing, manipulating and initializing tensor decompositions can be tricky. We take care of it all, in a convenient, unified API.
21 | - **Leverage structure in your data**: with tensor layers, you can easily leverage the structure in your data, through Tensor Regression Layers, Factorized Convolutions, etc.
22 | - **Built-in tensor layers**: all you have to do is import TensorLy-Torch and include the layers we provide directly within your PyTorch models!
23 | - **Tensor hooks**: you can easily augment your architectures with our built-in Tensor Hooks. Robustify your network with Tensor Dropout and automatically select the rank end-to-end with L1 Regularization!
24 | - **All the methods available**: we are always adding more methods to make it easy to compare the performance of various deep tensor-based methods!
25 |
26 | Deep Tensorized Learning
27 | ========================
28 |
29 | Tensor methods generalize matrix algebraic operations to higher-orders. Deep neural networks typically map between higher-order tensors.
30 | In fact, it is the ability of deep convolutional neural networks to preserve and leverage local structure that, along with large datasets and efficient hardware, made the current levels of performance possible.
31 | Tensor methods allow you to further leverage and preserve that structure, for individual layers or whole networks.
32 |
33 | .. image:: ./doc/_static/tensorly-torch-pyramid.png
34 |
35 | TensorLy is a Python library that aims at making tensor learning simple and accessible.
36 | It provides a high-level API for tensor methods, including core tensor operations, tensor decomposition and regression.
37 | It has a flexible backend that allows running operations seamlessly using NumPy, PyTorch, TensorFlow, JAX, MXNet and CuPy.
38 |
39 | **TensorLy-Torch** is a PyTorch only library that builds on top of TensorLy and provides out-of-the-box tensor layers.
40 |
41 | Improve your neural networks with tensor methods
42 | ------------------------------------------------
43 |
44 | Tensor methods generalize matrix algebraic operations to higher-orders. Deep neural networks typically map between higher-order tensors.
45 | In fact, it is the ability of deep convolutional neural networks to preserve and leverage local structure that, along with large datasets and efficient hardware, made the current levels of performance possible.
46 | Tensor methods allow you to further leverage and preserve that structure, for individual layers or whole networks.
47 |
48 | In TensorLy-Torch, we provide convenient layers that do all the heavy lifting for you
49 | and provide the benefits of tensor-based layers wrapped in a nice, well-documented and tested API.
50 |
51 | For instance, convolution layers of any order (2D, 3D or more) can be efficiently parametrized
52 | using tensor decomposition. Using a CP decomposition results in a separable convolution
53 | and you can replace your original convolution with a series of small efficient ones:
54 |
55 | .. image:: ./doc/_static/cp-conv.png
56 |
57 | These can easily be performed with FactorizedConv in TensorLy-Torch.
58 | We also have Tucker convolutions and new tensor-train convolutions!
59 | We also implement various other methods such as tensor regression and contraction layers,
60 | tensorized linear layers, tensor dropout and more!
61 |
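As a quick illustration, a minimal sketch of replacing an existing ``Conv2d`` by a CP-factorized convolution (the sizes, rank and factorization below are illustrative choices, not defaults):

.. code:: python

   import torch
   import tltorch

   # A regular PyTorch convolution...
   conv = torch.nn.Conv2d(16, 32, kernel_size=3)

   # ...replaced by a factorized convolution initialized from its weights
   # (rank=0.5 and factorization='cp' are illustrative choices)
   fact_conv = tltorch.FactorizedConv.from_conv(conv, rank=0.5, factorization='cp')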
62 |
63 | Installing TensorLy-Torch
64 | =========================
65 |
66 | Through pip
67 | -----------
68 |
69 | .. code::
70 |
71 | pip install tensorly-torch
72 |
73 |
74 | From source
75 | -----------
76 |
77 | .. code::
78 |
79 | git clone https://github.com/tensorly/torch
80 | cd torch
81 | pip install -e .
82 |
83 |
84 |
85 |
86 |
87 |
--------------------------------------------------------------------------------
/doc/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
22 | .PHONY: debug
23 | debug:
24 | @$(SPHINXBUILD) -M "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -vv
25 |
26 |
--------------------------------------------------------------------------------
/doc/_static/FLOP.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/FLOP.pdf
--------------------------------------------------------------------------------
/doc/_static/TRL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/TRL.png
--------------------------------------------------------------------------------
/doc/_static/cp-conv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/cp-conv.png
--------------------------------------------------------------------------------
/doc/_static/deep_tensor_nets_pros_circle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/deep_tensor_nets_pros_circle.png
--------------------------------------------------------------------------------
/doc/_static/domain-adaptation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/domain-adaptation.png
--------------------------------------------------------------------------------
/doc/_static/favicon/android-chrome-192x192.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/favicon/android-chrome-192x192.png
--------------------------------------------------------------------------------
/doc/_static/favicon/android-chrome-256x256.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/favicon/android-chrome-256x256.png
--------------------------------------------------------------------------------
/doc/_static/favicon/apple-touch-icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/favicon/apple-touch-icon.png
--------------------------------------------------------------------------------
/doc/_static/favicon/favicon-16x16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/favicon/favicon-16x16.png
--------------------------------------------------------------------------------
/doc/_static/favicon/favicon-32x32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/favicon/favicon-32x32.png
--------------------------------------------------------------------------------
/doc/_static/favicon/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/favicon/favicon.ico
--------------------------------------------------------------------------------
/doc/_static/favicon/safari-pinned-tab.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/favicon/safari-pinned-tab.svg
--------------------------------------------------------------------------------
/doc/_static/favicon/site.webmanifest:
--------------------------------------------------------------------------------
1 | {
2 | "name": "",
3 | "short_name": "",
4 | "icons": [
5 | {
6 | "src": "/android-chrome-192x192.png",
7 | "sizes": "192x192",
8 | "type": "image/png"
9 | },
10 | {
11 | "src": "/android-chrome-256x256.png",
12 | "sizes": "256x256",
13 | "type": "image/png"
14 | }
15 | ],
16 | "theme_color": "#ffffff",
17 | "background_color": "#ffffff",
18 | "display": "standalone"
19 | }
20 |
--------------------------------------------------------------------------------
/doc/_static/logos/logo_caltech.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/logos/logo_caltech.png
--------------------------------------------------------------------------------
/doc/_static/logos/logo_nvidia.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/logos/logo_nvidia.png
--------------------------------------------------------------------------------
/doc/_static/logos/tensorly-torch-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/logos/tensorly-torch-logo.png
--------------------------------------------------------------------------------
/doc/_static/mobilenet-conv.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/mobilenet-conv.pdf
--------------------------------------------------------------------------------
/doc/_static/mobilenet-v2-conv.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/mobilenet-v2-conv.pdf
--------------------------------------------------------------------------------
/doc/_static/tensorly-torch-pyramid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/tensorly-torch-pyramid.png
--------------------------------------------------------------------------------
/doc/_static/transduction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/transduction.png
--------------------------------------------------------------------------------
/doc/_static/tucker-conv.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/tucker-conv.pdf
--------------------------------------------------------------------------------
/doc/_templates/class.rst:
--------------------------------------------------------------------------------
1 | :mod:`{{module}}`.{{objname}}
2 | {{ underline }}==============
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autoclass:: {{ objname }}
7 | :members:
8 |
9 | .. raw:: html
10 |
11 |
12 |
--------------------------------------------------------------------------------
/doc/_templates/function.rst:
--------------------------------------------------------------------------------
1 | :mod:`{{module}}`.{{objname}}
2 | {{ underline }}====================
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autofunction:: {{ objname }}
7 |
8 | .. raw:: html
9 |
10 |
11 |
--------------------------------------------------------------------------------
/doc/about.rst:
--------------------------------------------------------------------------------
1 | .. _about_us:
2 |
3 | About Us
4 | ========
5 |
6 | TensorLy-Torch is an open-source effort, led primarily at NVIDIA Research by `Jean Kossaifi`_.
7 | It is part of the TensorLy project and builds on top of the core TensorLy library in order to provide out-of-the-box PyTorch tensor layers for deep learning.
8 |
9 |
10 | Core team
11 | ---------
12 |
13 | * `Anima Anandkumar`_
14 | * `Wonmin Byeon`_
15 | * `Jean Kossaifi`_
16 | * `Saurav Muralidharan`_
17 |
18 | Supporters
19 | ----------
20 |
21 | The TensorLy project is supported by:
22 |
23 | .. image:: _static/logos/logo_nvidia.png
24 | :width: 150pt
25 | :align: center
26 | :target: https://www.nvidia.com
27 | :alt: NVIDIA
28 |
29 | ........
30 |
31 | .. image:: _static/logos/logo_caltech.png
32 | :width: 150pt
33 | :align: center
34 | :target: https://www.caltech.edu
35 | :alt: California Institute of Technology
36 |
37 |
38 | .. _Jean Kossaifi: http://jeankossaifi.com/
39 | .. _Anima Anandkumar: http://tensorlab.cms.caltech.edu/users/anima/
40 | .. _Wonmin Byeon: https://wonmin-byeon.github.io/
41 | .. _Saurav Muralidharan: https://www.sauravm.com
--------------------------------------------------------------------------------
/doc/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 |
13 | import os
14 | import sys
15 | sys.path.insert(0, os.path.abspath('..'))
16 | # sys.path.insert(0, os.path.abspath('../examples'))
17 |
18 | # -- Project information -----------------------------------------------------
19 |
20 | project = 'TensorLy-Torch'
21 | from datetime import datetime
22 | year = datetime.now().year
23 | copyright = f'{year}, Jean Kossaifi'
24 | author = 'Jean Kossaifi'
25 |
26 | # The full version, including alpha/beta/rc tags
27 | import tltorch
28 | release = tltorch.__version__
29 |
30 |
31 | # -- General configuration ---------------------------------------------------
32 |
33 | # Add any Sphinx extension module names here, as strings. They can be
34 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
35 | # ones.
36 | extensions = [
37 | 'sphinx.ext.autodoc',
38 | 'sphinx.ext.autosummary',
39 | 'sphinx.ext.todo',
40 | 'sphinx.ext.viewcode',
41 | 'sphinx.ext.githubpages',
42 | # "nbsphinx",
43 | # "myst_nb",
44 | # "sphinx_nbexamples",
45 | # 'jupyter_sphinx',
46 | # 'matplotlib.sphinxext.plot_directive',
47 | 'sphinx.ext.mathjax', #'sphinx.ext.imgmath',
48 | 'numpydoc.numpydoc',
49 | ]
50 |
51 | # # # Sphinx-nbexamples
52 | # process_examples = False
53 | # example_gallery_config = dict(pattern='+/+.ipynb')
54 |
55 | # Remove the permalinks ("¶" symbols)
56 | html_add_permalinks = ""
57 |
58 | # NumPy
59 | numpydoc_class_members_toctree = False
60 | numpydoc_show_class_members = True
61 | numpydoc_show_inherited_class_members = False
62 |
63 | # generate autosummary even if no references
64 | autosummary_generate = True
65 | autodoc_member_order = 'bysource'
66 | autodoc_default_flags = ['members']
67 |
68 | # Napoleon
69 | napoleon_google_docstring = False
70 | napoleon_use_rtype = False
71 |
72 | # imgmath/mathjax
73 | imgmath_image_format = 'svg'
74 |
75 | # The suffix(es) of source filenames.
76 | # You can specify multiple suffix as a list of string:
77 | # source_suffix = ['.rst', '.md']
78 | source_suffix = '.rst'
79 |
80 |
81 | # The master toctree document.
82 | master_doc = 'index'
83 |
84 | # Add any paths that contain templates here, relative to this directory.
85 | templates_path = ['_templates']
86 |
87 | # List of patterns, relative to source directory, that match files and
88 | # directories to ignore when looking for source files.
89 | # This pattern also affects html_static_path and html_extra_path.
90 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
91 |
92 |
93 | primary_domain = 'py'
94 |
95 | # -- Options for HTML output -------------------------------------------------
96 |
97 | # The theme to use for HTML and HTML Help pages. See the documentation for
98 | # a list of builtin themes.
99 | html_theme = 'tensorly_sphinx_theme'
100 | html_logo = '_static/logos/tensorly-torch-logo.png'
101 |
102 | # Add any paths that contain custom static files (such as style sheets) here,
103 | # relative to this directory. They are copied after the builtin static files,
104 | # so a file named "default.css" will overwrite the builtin "default.css".
105 | html_static_path = ['_static']
106 |
107 | html_theme_options = {
108 | 'github_url': 'https://github.com/tensorly/torch',
109 | 'google_analytics' : 'G-QSPLEF75VT',
110 | 'searchbar_text': 'Search TensorLy-Torch',
111 | 'nav_links' : [('Install', 'install'),
112 | ('User Guide', 'user_guide/index'),
113 | ('API', 'modules/api'),
114 | ('About Us', 'about')],
115 | 'nav_dropdowns' : [('Ecosystem',
116 | [('TensorLy', 'http://tensorly.org/dev'),
117 | ('TensorLy-Viz', 'http://tensorly.org/viz'),
118 | ('TensorLy-Quantum', 'http://tensorly.org/quantum'),
119 | ]),
120 | ],
121 | }
122 |
123 | # -- Options for LaTeX output ---------------------------------------------
124 |
125 | # Grouping the document tree into LaTeX files. List of tuples
126 | # (source start file, target name, title,
127 | # author, documentclass [howto, manual, or own class]).
128 | latex_documents = [
129 | (master_doc, 'tensorly.tex', 'Tensor operations in Python',
130 | 'Jean Kossaifi', 'manual'),
131 | ]
132 |
133 | latex_preamble = r"""
134 | \usepackage{amsmath}\usepackage{amsfonts}
135 | \setcounter{MaxMatrixCols}{20}
136 | """
137 |
138 | imgmath_latex_preamble = latex_preamble
139 |
140 | latex_elements = {
141 | 'classoptions': ',oneside',
142 | 'babel': '\\usepackage[english]{babel}',
143 | # Get completely rid of index
144 | 'printindex': '',
145 | 'preamble': latex_preamble,
146 | }
147 |
--------------------------------------------------------------------------------
/doc/dev_guide/api.rst:
--------------------------------------------------------------------------------
1 | =============
2 | Style and API
3 | =============
4 |
5 | In TensorLy-Torch (and more generally in the TensorLy project),
6 | we try to maintain a simple and consistent API.
7 |
8 | Here are some elements to consider. *Coming Soon*
--------------------------------------------------------------------------------
/doc/dev_guide/contributing.rst:
--------------------------------------------------------------------------------
1 | Contributing
2 | ============
3 |
4 | We actively welcome new contributions!
5 | If you know of a tensor method that should be in TensorLy-Torch
6 | or if you spot any mistake, typo, missing documentation etc,
7 | please report it, and even better, open a Pull-Request!
8 |
9 | How-to
10 | ------
11 |
12 | To make sure the contribution is relevant and is not already worked on,
13 | you can `open an issue <https://github.com/tensorly/torch/issues>`_!
14 |
15 | To add code or fix issues in TensorLy-Torch,
16 | you will want to open a `Pull-Request <https://github.com/tensorly/torch/pulls>`_
17 | on the Github repository of the project.
18 |
19 | Guidelines
20 | ----------
21 |
22 | For each function or class, we expect helpful docstrings in the NumPy format,
23 | as well as unit-tests to make sure it is working as expected
24 | (especially helpful for future refactoring to make sure no existing code is broken!)
25 |
26 | Check the existing code for examples,
27 | and don't hesitate to contact us if you are unsure or have any question!
28 |
29 |
--------------------------------------------------------------------------------
/doc/dev_guide/index.rst:
--------------------------------------------------------------------------------
1 | .. _dev_guide:
2 |
3 | =================
4 | Development guide
5 | =================
6 |
7 | .. toctree::
8 |
9 | contributing.rst
10 | api.rst
11 |
--------------------------------------------------------------------------------
/doc/index.rst:
--------------------------------------------------------------------------------
1 | :no-toc:
2 | :no-localtoc:
3 | :no-pagination:
4 |
5 | .. TensorLy-Torch documentation
6 |
7 | .. only:: html
8 |
9 | .. raw:: html
10 |
11 |
12 |
13 | .. image:: _static/logos/tensorly-torch-logo.png
14 | :align: center
15 | :width: 500
16 |
17 |
18 | .. only:: html
19 |
20 | .. raw:: html
21 |
22 |
23 |       Tensor Batteries Included
24 |
25 |
26 |
27 |
28 | **TensorLy-Torch** is a PyTorch only library that builds on top of `TensorLy <http://tensorly.org/dev>`_ and provides out-of-the-box tensor layers.
29 | It comes with all batteries included and tries to make it as easy as possible to use tensor methods within your deep networks.
30 |
31 | - **Leverage structure in your data**: with tensor layers, you can easily leverage the structure in your data, through :ref:`TRL `, :ref:`TCL `, :ref:`Factorized convolutions ` and more!
32 | - **Factorized tensors** as first class citizens: you can transparently create, manipulate and index factorized tensors and regular (dense) PyTorch tensors alike!
33 | - **Built-in tensor layers**: all you have to do is import TensorLy-Torch and include the layers we provide directly within your PyTorch models!
34 | - **Initialization**: initializing tensor decompositions can be tricky. We take care of it all, whether you want to initialize randomly using our :ref:`init_ref` module or from a pretrained layer.
35 | - **Tensor hooks**: you can easily augment your architectures with our built-in :mod:`Tensor Hooks `. Robustify your network with Tensor Dropout and automatically select the rank end-to-end with L1 Regularization!
36 | - **All the methods available**: we are always adding more methods to make it easy to compare the performance of various deep tensor-based methods!
37 |
38 | Deep Tensorized Learning
39 | ========================
40 |
41 | Tensor methods generalize matrix algebraic operations to higher-orders. Deep neural networks typically map between higher-order tensors.
42 | In fact, it is the ability of deep convolutional neural networks to preserve and leverage local structure that, along with large datasets and efficient hardware, made the current levels of performance possible.
43 | Tensor methods allow you to further leverage and preserve that structure, for individual layers or whole networks.
44 |
45 | .. image:: _static/tensorly-torch-pyramid.png
46 | :align: center
47 | :width: 800
48 |
49 | TensorLy is a Python library that aims at making tensor learning simple and accessible.
50 | It provides a high-level API for tensor methods, including core tensor operations, tensor decomposition and regression.
51 | It has a flexible backend that allows running operations seamlessly using NumPy, PyTorch, TensorFlow, JAX, MXNet and CuPy.
52 |
53 | **TensorLy-Torch** is a PyTorch only library that builds on top of TensorLy and provides out-of-the-box tensor layers.
54 |
55 | Improve your neural networks with tensor methods
56 | ------------------------------------------------
57 |
58 | Tensor methods generalize matrix algebraic operations to higher-orders. Deep neural networks typically map between higher-order tensors.
59 | In fact, it is the ability of deep convolutional neural networks to preserve and leverage local structure that, along with large datasets and efficient hardware, made the current levels of performance possible.
60 | Tensor methods allow you to further leverage and preserve that structure, for individual layers or whole networks.
61 |
62 | .. image:: _static/deep_tensor_nets_pros_circle.png
63 | :align: center
64 | :width: 350
65 |
66 |
67 | In TensorLy-Torch, we provide convenient layers that do all the heavy lifting for you
68 | and provide the benefits of tensor-based layers wrapped in a nice, well-documented and tested API.
69 |
70 | For instance, convolution layers of any order (2D, 3D or more) can be efficiently parametrized
71 | using tensor decomposition. Using a CP decomposition results in a separable convolution
72 | and you can replace your original convolution with a series of small efficient ones:
73 |
74 | .. image:: _static/cp-conv.png
75 | :width: 500
76 | :align: center
77 |
78 | These can easily be performed with the :ref:`factorized_conv_ref` module in TensorLy-Torch.
79 | We also have Tucker convolutions and new tensor-train convolutions!
80 | We also implement various other methods such as tensor regression and contraction layers,
81 | tensorized linear layers, tensor dropout and more!
82 |
83 |
84 | .. toctree::
85 | :maxdepth: 1
86 | :hidden:
87 |
88 | install
89 | modules/api
90 | user_guide/index
91 | dev_guide/index
92 | about
93 |
94 |
95 | .. only:: html
96 |
97 | .. raw:: html
98 |
99 |
100 |
101 |
102 |
107 |
108 |
--------------------------------------------------------------------------------
/doc/install.rst:
--------------------------------------------------------------------------------
1 | =========================
2 | Installing TensorLy-Torch
3 | =========================
4 |
5 |
6 | Pre-requisite
7 | =============
8 |
9 | You will need to have Python 3 installed, as well as NumPy, SciPy and `TensorLy <http://tensorly.org/dev>`_.
10 | If you are starting with Python or generally want a pain-free experience, I recommend you install the `Anaconda distribution <https://www.anaconda.com/download/>`_. It comes with all you need shipped-in and ready to use!
11 |
12 |
13 | Installing with pip (recommended)
14 | =================================
15 |
16 |
17 | Simply run, in your terminal::
18 |
19 | pip install -U tensorly-torch
20 |
21 | (the `-U` is optional, use it if you want to update the package).
22 |
23 |
24 | Cloning the github repository
25 | =============================
26 |
27 | Clone the repository and cd there::
28 |
29 | git clone https://github.com/tensorly/torch
30 | cd torch
31 |
32 | Then install the package (here in editable mode with `-e` or equivalently `--editable`)::
33 |
34 | pip install -e .
35 |
36 | Running the tests
37 | =================
38 |
39 | Unit-testing is a vital part of this package.
40 | You can run all the tests using `pytest`::
41 |
42 | pip install pytest
43 | pytest tltorch
44 |
45 | Building the documentation
46 | ==========================
47 |
48 | You will need to install jsmin and rcssmin (used by the ``minify.py`` script)::
49 |
50 |     pip install jsmin rcssmin
51 |
52 | You are now ready to build the doc from the ``doc`` folder (here in html)::
53 |
54 | make html
55 |
56 | The results will be in `_build/html`
57 |
58 |
--------------------------------------------------------------------------------
/doc/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/doc/minify.py:
--------------------------------------------------------------------------------
1 | from jsmin import jsmin
2 | from rcssmin import cssmin
3 |
4 | from pathlib import Path
5 | asset_path = Path('./themes/tensorly/static')
6 |
7 | for path in asset_path.glob('*.js'):
8 | # Ignore already minified files
9 | if '.min.' in str(path):
10 | continue
11 | target_path = path.with_suffix('.min.js')
12 | with open(path.as_posix(), 'r') as f:
13 | text = f.read()
14 |     minified = jsmin(text)  # jsmin does not support name mangling
15 | with open(target_path.as_posix(), 'w') as f:
16 | f.write(minified)
17 |
18 | for path in asset_path.glob('*.css'):
19 | # Ignore already minified files
20 | if '.min.' in str(path):
21 | continue
22 | target_path = path.with_suffix('.min.css')
23 | with open(path.as_posix(), 'r') as f:
24 | text = f.read()
25 | minified = cssmin(text)
26 | with open(target_path.as_posix(), 'w') as f:
27 | f.write(minified)
28 |
29 |
--------------------------------------------------------------------------------
/doc/modules/api.rst:
--------------------------------------------------------------------------------
1 | =============
2 | API reference
3 | =============
4 |
5 | :mod:`tltorch`: Tensorized Deep Neural Networks
6 |
7 | .. automodule:: tltorch
8 | :no-members:
9 | :no-inherited-members:
10 |
11 | .. _factorized_tensor_ref:
12 |
13 | Factorized Tensors
14 | ==================
15 |
16 | TensorLy-Torch builds on top of TensorLy and provides out of the box PyTorch layers for tensor based operations.
17 | The core of this is the concept of factorized tensors, which parametrize our layers in factorized form, instead of regular, dense PyTorch tensors.
18 |
19 | You can create any factorized tensor through the main class using:
20 |
21 | .. autosummary::
22 | :toctree: generated
23 | :template: class.rst
24 |
25 | FactorizedTensor
26 |
27 | You can create a tensor of any form using ``FactorizedTensor.new(shape, rank, factorization)``, where factorization can be `Dense`, `CP`, `Tucker` or `TT`.
28 | Note that if you use ``factorization = 'dense'`` you are just creating a regular, unfactorized tensor.
29 | This allows you to manipulate any tensor, factorized or not, with a simple, unified interface.
30 |
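For example, a minimal sketch (the shape and rank below are illustrative):

.. code-block:: python

   import tltorch

   # A 3rd-order tensor of shape (5, 5, 5) stored in Tucker form, with
   # roughly half the parameters of the dense equivalent (rank=0.5)
   tucker_tensor = tltorch.FactorizedTensor.new((5, 5, 5), rank=0.5, factorization='tucker')
   tucker_tensor.normal_(0, 0.02)  # initialize the factors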
31 | Alternatively, you can also directly create a specific subclass:
32 |
33 | .. autosummary::
34 | :toctree: generated
35 | :template: class.rst
36 |
37 | DenseTensor
38 | CPTensor
39 | TuckerTensor
40 | TTTensor
41 |
42 | .. _factorized_matrix_ref:
43 |
44 | Tensorized Matrices
45 | ===================
46 |
47 | In TensorLy-Torch, you can also represent matrices in *tensorized* form, as low-rank tensors.
48 | Just as for factorized tensors, you can create a tensorized matrix through the main class using:
49 |
50 | .. autosummary::
51 | :toctree: generated
52 | :template: class.rst
53 |
54 | TensorizedTensor
55 |
56 | You can create a tensor of any form using ``TensorizedTensor.new(tensorized_shape, rank, factorization)``, where factorization can be `Dense`, `CP`, `Tucker` or `BlockTT`.
57 |
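For example, a minimal sketch (the tensorized shape and rank below are illustrative):

.. code-block:: python

   import tltorch

   # A 256x256 matrix tensorized as (16, 16) x (16, 16), stored in BlockTT form
   # (rank=0.3 is an illustrative choice)
   tensorized_weight = tltorch.TensorizedTensor.new(((16, 16), (16, 16)),
                                                    rank=0.3, factorization='BlockTT')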
58 | You can also explicitly create the type of tensor you want using the following classes:
59 |
60 | .. autosummary::
61 | :toctree: generated
62 | :template: class.rst
63 |
64 | DenseTensorized
65 | TensorizedTensor
66 | CPTensorized
67 | BlockTT
68 |
69 | .. _complex_ref:
70 |
71 | Complex Tensors
72 | ===============
73 |
74 | In theory, you can simply specify ``dtype=torch.cfloat`` when creating any of the tensors or tensorized matrices above, to automatically create a complex-valued tensor.
75 | However, in practice, complex support still has many issues; Distributed Data Parallelism, in particular, is not supported.
76 |
77 | In TensorLy-Torch, we propose a convenient and transparent way around this: simply use ``ComplexTensor`` instead.
78 | This will store the factors of the decomposition in real form (by explicitly storing the real and imaginary parts)
79 | but will transparently return a complex-valued tensor for the reconstruction.
80 |
81 | .. autosummary::
82 | :toctree: generated
83 | :template: class.rst
84 |
85 | ComplexDenseTensor
86 | ComplexCPTensor
87 | ComplexTuckerTensor
88 | ComplexTTTensor
89 |
90 |
91 | ComplexDenseTensorized
92 | ComplexTuckerTensorized
93 | ComplexCPTensorized
94 | ComplexBlockTT
95 |
96 |
97 | You can also transparently instantiate any of these directly through the main classes, ``TensorizedTensor`` or ``FactorizedTensor``, by specifying
98 | ``factorization="ComplexCP"`` or, in general, ``ComplexFactorization`` with `Factorization` being any of the supported decompositions.
99 |
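For instance, a minimal sketch (the shape and rank are illustrative):

.. code-block:: python

   import tltorch

   # Complex-valued CP tensor: the real and imaginary parts of the factors are
   # stored as real tensors, but the reconstruction is complex-valued
   complex_cp = tltorch.FactorizedTensor.new((4, 4, 4), rank=0.5, factorization='ComplexCP')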
100 |
101 | .. _init_ref:
102 |
103 | Initialization
104 | ==============
105 |
106 | .. automodule:: tltorch.factorized_tensors
107 | :no-members:
108 | :no-inherited-members:
109 |
110 | Initialization is particularly important in the context of deep learning.
111 | We provide convenient functions to directly initialize factorized tensors (i.e. their factors)
112 | such that their reconstruction follows approximately a centered Gaussian distribution.
113 |
114 | .. currentmodule:: tltorch.factorized_tensors.init
115 |
116 | .. autosummary::
117 | :toctree: generated
118 | :template: function.rst
119 |
120 | tensor_init
121 | cp_init
122 | tucker_init
123 | tt_init
124 | block_tt_init
125 |
126 | .. _trl_ref:
127 |
128 | Tensor Regression Layers
129 | ========================
130 |
131 | .. automodule:: tltorch.factorized_layers
132 | :no-members:
133 | :no-inherited-members:
134 |
135 | .. currentmodule:: tltorch.factorized_layers
136 |
137 | .. autosummary::
138 | :toctree: generated
139 | :template: class.rst
140 |
141 | TRL
142 |
143 | .. _tcl_ref:
144 |
145 | Tensor Contraction Layers
146 | =========================
147 |
148 | .. autosummary::
149 | :toctree: generated
150 | :template: class.rst
151 |
152 | TCL
153 |
154 | .. _factorized_linear_ref:
155 |
156 | Factorized Linear Layers
157 | ========================
158 |
159 | .. autosummary::
160 | :toctree: generated
161 | :template: class.rst
162 |
163 | FactorizedLinear
164 |
165 | .. _factorized_conv_ref:
166 |
167 | Factorized Convolutions
168 | =======================
169 |
170 | General N-Dimensional convolutions in Factorized forms
171 |
172 | .. autosummary::
173 | :toctree: generated
174 | :template: class.rst
175 |
176 | FactorizedConv
177 |
178 | .. _factorized_embedding_ref:
179 |
180 | Factorized Embeddings
181 | =====================
182 |
183 | A drop-in replacement for PyTorch's embeddings but using an efficient tensor parametrization that never reconstructs the full table.
184 |
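A minimal sketch of wrapping an existing embedding table (the sizes and hyper-parameters below are illustrative; see the user guide for details):

.. code-block:: python

   import torch
   import tltorch

   embedding = torch.nn.Embedding(1000, 64)  # illustrative sizes
   fact_embedding = tltorch.FactorizedEmbedding.from_embedding(
       embedding, auto_reshape=True, factorization='blocktt',
       n_tensorized_modes=3, rank=0.4)
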
185 | .. autosummary::
186 | :toctree: generated
187 | :template: class.rst
188 |
189 | FactorizedEmbedding
190 |
191 | .. _tensor_dropout_ref:
192 |
193 | Tensor Dropout
194 | ==============
195 |
196 | .. currentmodule:: tltorch.tensor_hooks
197 |
198 | .. automodule:: tltorch.tensor_hooks
199 | :no-members:
200 | :no-inherited-members:
201 |
202 | These functions allow you to easily add or remove tensor dropout from tensor layers.
203 |
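For example, a minimal sketch on a TRL layer (the layer, factorization and dropout probability are illustrative):

.. code-block:: python

   import tltorch

   trl = tltorch.TRL((10, 10), (10, ), rank='same')

   # Attach tensor dropout to the factorized weight...
   tltorch.tensor_hooks.tensor_dropout(trl.weight, p=0.5)
   # ...and remove it when no longer needed
   tltorch.tensor_hooks.remove_tensor_dropout(trl.weight)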
204 |
205 | .. autosummary::
206 | :toctree: generated
207 | :template: function.rst
208 |
209 | tensor_dropout
210 | remove_tensor_dropout
211 |
212 |
213 | You can also use the class API below, but unless you have a particular use for the classes, you should use the convenience functions provided instead.
214 |
215 | .. autosummary::
216 | :toctree: generated
217 | :template: class.rst
218 |
219 | TensorDropout
220 |
221 | .. _tensor_lasso_ref:
222 |
223 | L1 Regularization
224 | =================
225 |
226 | L1 Regularization on tensor modules.
227 |
228 | .. currentmodule:: tltorch.tensor_hooks
229 |
230 | .. autosummary::
231 | :toctree: generated
232 | :template: function.rst
233 |
234 | tensor_lasso
235 | remove_tensor_lasso
236 |
237 | Utilities
238 | =========
239 |
240 | Utility functions
241 |
242 | .. currentmodule:: tltorch.utils
243 |
244 | .. autosummary::
245 | :toctree: generated
246 | :template: function.rst
247 |
248 | get_tensorized_shape
--------------------------------------------------------------------------------
/doc/requirements_doc.txt:
--------------------------------------------------------------------------------
1 | sphinx
2 | jsmin
3 | rcssmin
4 | numpydoc
5 | sphinx-gallery
6 | myst-nb
7 | tensorly_sphinx_theme
8 |
--------------------------------------------------------------------------------
/doc/user_guide/factorized_conv.rst:
--------------------------------------------------------------------------------
1 | Factorized Convolutional Layers
2 | ===============================
3 |
4 | It is possible to apply low-rank tensor factorization to convolution
5 | kernels to compress the network and reduce the number of parameters.
6 |
7 | In TensorLy-Torch, you can easily try factorized convolutions: first, let’s
8 | import the library:
9 |
10 | .. code:: python
11 |
12 | import tltorch
13 | import torch
14 |
15 | Let’s now create some random data to try our modules: we can choose the
16 | size of the convolutions.
17 |
18 | .. code:: python
19 |
20 | device='cpu'
21 | input_channels = 16
22 | output_channels = 32
23 | kernel_size = 3
24 | batch_size = 2
25 | size = 24
26 | order = 2
27 |
28 | input_shape = (batch_size, input_channels) + (size, )*order
29 | kernel_shape = (output_channels, input_channels) + (kernel_size, )*order
30 |
31 | We can create some random input data:
32 |
33 | .. code:: python
34 |
35 | data = torch.randn(input_shape, dtype=torch.float32, device=device)
36 |
37 | Creating Factorized Convolutions
38 | --------------------------------
39 |
40 | From Random
41 | ~~~~~~~~~~~
42 |
43 | In PyTorch, you would create a convolution as follows:
44 |
45 | .. code:: python
46 |
47 | conv = torch.nn.Conv2d(input_channels, output_channels, kernel_size)
48 |
49 | In TensorLy-Torch, it is exactly the same, except that factorized
50 | convolutions are by default of any order: either you specify the kernel
51 | size or you specify the order.
52 |
53 | .. code:: python
54 |
55 | conv = tltorch.FactorizedConv(input_channels, output_channels, kernel_size, order=2, rank='same', factorization='cp')
56 |
57 | .. code:: python
58 |
59 | conv = torch.nn.Conv2d(input_channels, output_channels, kernel_size=3)
60 |
61 | In TensorLy-Torch, factorized convolutions can be of any order, so you
62 | have to specify the order at creation (in Pytorch, you specify it
63 | through the class name, e.g. Conv2d or Conv3d):
64 |
65 | .. code:: python
66 |
67 | fact_conv = tltorch.FactorizedConv(input_channels, output_channels, kernel_size=3, order=2, rank='same')
68 |
69 |
70 | Or, you can specify the order directly by passing a tuple as kernel_size
71 | (in which case, ``order = len(kernel_size)`` is used).
72 |
73 | .. code:: python
74 |
75 | fact_conv = tltorch.FactorizedConv(input_channels, output_channels, kernel_size=(3, 3), rank='same')
76 |
77 | From an existing Convolution
78 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
79 |
80 | You can create a Factorized convolution from an existing (PyTorch)
81 | convolution:
82 |
83 | .. code:: python
84 |
85 | fact_conv = tltorch.FactorizedConv.from_conv(conv, rank=0.5, decompose_weights=True, factorization='tucker')
86 |
87 | Efficient Convolutional Blocks
88 | ------------------------------
89 |
90 | If you compress a convolutional kernel, you can get efficient
91 | convolutional blocks by applying tensor factorization. For instance, if
92 | you apply CP decomposition, you can get a MobileNet-v2 block:
93 |
94 | .. code:: python
95 |
96 | fact_conv = tltorch.FactorizedConv.from_conv(conv, rank=0.5, factorization='cp', implementation='mobilenet')
97 |
98 |
99 | Similarly, if you apply Tucker decomposition, you can get a ResNet
100 | BottleNeck block:
101 |
102 | .. code:: python
103 |
104 | fact_conv = tltorch.FactorizedConv.from_conv(conv, rank=0.5, factorization='tucker', implementation='factorized')
105 |
--------------------------------------------------------------------------------
/doc/user_guide/factorized_embeddings.rst:
--------------------------------------------------------------------------------
1 | Factorized embedding layers
2 | ===========================
3 |
4 | In TensorLy-Torch, we also provide out-of-the-box tensorized embedding layers.
5 |
6 | Just as for factorized linear layers, you can either create a factorized embedding from scratch (here, automatically determining the
7 | input and output tensorized shapes so that they have 3 dimensions each):
8 |
9 | .. code-block:: python
10 |
11 | import tltorch
12 | import torch
13 |
14 | from_embedding = tltorch.FactorizedEmbedding(num_embeddings, embedding_dim, auto_reshape=True, d=3, rank=0.4)
15 |
16 |
17 | Or, you can create it by decomposing an existing embedding layer:
18 |
19 | .. code-block:: python
20 |
21 | from_embedding = tltorch.FactorizedEmbedding.from_embedding(embedding_layer, auto_reshape=True,
22 | factorization='blocktt', n_tensorized_modes=3, rank=0.4)
--------------------------------------------------------------------------------
/doc/user_guide/factorized_tensors.rst:
--------------------------------------------------------------------------------
1 | Factorized tensors
2 | ==================
3 |
4 | The core concept in TensorLy-Torch is that of *factorized tensors*.
5 | We provide a :class:`~tltorch.FactorizedTensor` class that can be used just like any ``torch.Tensor`` but
6 | provides all tensor factorization through one, simple API.
7 |
8 |
9 | Creating factorized tensors
10 | ---------------------------
11 |
12 | You can create a new factorized tensor easily:
13 |
14 | The signature is:
15 |
16 | .. code-block:: python
17 |
18 | factorized_tensor = FactorizedTensor.new(shape, rank, factorization)
19 |
20 | For instance, to create a tensor in Tucker form, that has half the parameters of a dense (non-factorized) tensor of the same shape, you would simply write:
21 |
22 | .. code-block:: python
23 |
24 | tucker_tensor = FactorizedTensor.new(shape, rank=0.5, factorization='tucker')
25 |
26 | Since TensorLy-Torch builds on top of TensorLy, it also comes with tensor decomposition out-of-the-box.
27 | To initialize a factorized tensor in CP (Canonical-Polyadic) form, also known as Parafac, or Kruskal tensor,
28 | with 1/10th of the parameters, you can simply write:
29 |
30 | .. code-block:: python
31 |
32 |     cp_tensor = FactorizedTensor.from_tensor(dense_tensor, rank=0.1, factorization='CP')
33 |
34 |
35 | Manipulating factorized tensors
36 | -------------------------------
37 |
38 | The first thing you want to do, if you created a new tensor from scratch (by using the ``new`` method), is to initialize it,
39 | e.g. so that the elements of the reconstruction approximately follow a Gaussian distribution:
40 |
41 | .. code-block:: python
42 |
43 | cp_tensor.normal_(mean=0, std=0.02)
44 |
45 | You can even use PyTorch's functions! This works:
46 |
47 | .. code-block:: python
48 |
49 | from torch.nn import init
50 |
51 |     init.kaiming_normal_(cp_tensor)
52 |
53 | Finally, you can index tensors directly in factorized form, which will return another factorized tensor, whenever possible!
54 |
55 | >>> cp_tensor[:2, :2]
56 | CPTensor(shape=(2, 2, 2), rank=2)
57 |
58 | If not possible, a dense tensor will be returned:
59 |
60 |
61 | >>> cp_tensor[2, 3, 1]
62 | tensor(0.0250, grad_fn=)
63 |
64 |
65 | Note how, above, indexing tracks gradients as well!
66 |
67 | Tensorized tensors
68 | ==================
69 |
70 | In addition to tensors in factorized form, TensorLy-Torch provides out-of-the-box support for **tensorized** tensors.
71 | The most common case is that of tensorized matrices, where a matrix is first *tensorized*, i.e. reshaped into
72 | a higher-order tensor which is then decomposed and stored in factorized form.
73 |
74 | A commonly used tensorized tensor is the tensor-train matrix (also known as Matrix-Product Operator in quantum physics),
75 | or, in general, Block-TT.
76 |
77 | Creation
78 | --------
79 |
80 | You can create one in TensorLy-Torch, from a matrix, just as easily as a regular tensor, using the :class:`tltorch.TensorizedTensor` class,
81 | with the following signature:
82 |
83 | .. code-block:: python
84 |
85 | TensorizedTensor.from_matrix(matrix, tensorized_row_shape, tensorized_column_shape, rank)
86 |
87 | where ``tensorized_row_shape`` and ``tensorized_column_shape`` indicate the shapes into which to tensorize the rows and columns of the given matrix.
88 | For instance, if you have a matrix of size 16x21, you could use tensorized_row_shape=(4, 4) and tensorized_column_shape=(3, 7).
89 |
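In code, a minimal sketch (the random matrix and rank are illustrative):

.. code-block:: python

   import torch
   import tltorch

   matrix = torch.randn(16, 21)
   # Tensorize the 16 rows as (4, 4) and the 21 columns as (3, 7)
   ttm = tltorch.TensorizedTensor.from_matrix(matrix, (4, 4), (3, 7), rank=0.5)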
90 |
91 | In general, you can tensorize any tensor, not just matrices, even with batched modes (dimensions)!
92 |
93 | .. code-block:: python
94 |
95 | tensorized_tensor = TensorizedTensor.new(tensorized_shape, rank, factorization)
96 |
97 |
98 | ``tensorized_shape`` is a nested tuple, in which an int represents a batched mode, and a tuple a tensorized mode.
99 |
100 | For instance, a batch of 5 matrices of size 16x21 could be tensorized into
101 | a batch of 5 tensorized matrices of size (4x4)x(3x7), in the BlockTT form. In code, you would do this using
102 |
103 | .. code-block:: python
104 |
105 | tensorized_tensor = TensorizedTensor.from_tensor(tensor, (5, (4, 4), (3, 7)), rank=0.7, factorization='BlockTT')
106 |
107 | You can of course tensorize tensors of any size, e.g. a batch of 5 matrices of size 8x27 can be tensorized into:
108 |
109 | >>> ftt = tltorch.TensorizedTensor.new((5, (2, 2, 2), (3, 3, 3)), rank=0.5, factorization='BlockTT')
110 |
111 | This returns a tensorized tensor, stored in decomposed form:
112 | >>> ftt
113 | BlockTT(shape=[5, 8, 27], tensorized_shape=(5, (2, 2, 2), (3, 3, 3)), rank=[1, 20, 20, 1])
114 |
115 | Manipulation
116 | -------------
117 |
118 | As for factorized tensors, you can directly index them:
119 |
120 | >>> ftt[2]
121 | BlockTT(shape=[8, 27], tensorized_shape=[(2, 2, 2), (3, 3, 3)], rank=[1, 20, 20, 1])
122 |
123 | >>> ftt[0, :2, :2]
124 | tensor([[-0.0009, 0.0004],
125 | [ 0.0007, 0.0003]], grad_fn=)
126 |
127 | Again, notice that gradients are tracked: all operations on factorized and tensorized tensors support back-propagation!
128 |
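As a minimal sketch (reusing the ``ftt`` created above), you can backpropagate through an indexed slice directly:

.. code-block:: python

    # Index in factorized form, reduce to a scalar and backpropagate:
    # the gradients reach the factors of the decomposition
    loss = ftt[0, :2, :2].sum()
    loss.backward()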
--------------------------------------------------------------------------------
/doc/user_guide/index.rst:
--------------------------------------------------------------------------------
1 | .. _user_guide:
2 |
3 |
4 | User guide
5 | ==========
6 |
7 | TensorLy-Torch was written to provide out-of-the-box Tensor layers that can be readily used within any PyTorch network or code.
8 | It builds on top of `TensorLy <http://tensorly.org>`_ and enables anyone to use tensor methods within deep networks, even without a deep knowledge of tensor algebra.
9 |
10 | .. toctree::
11 |
12 | factorized_tensors
13 | trl
14 | factorized_conv
15 | tensorized_linear
16 | factorized_embeddings
17 | tensor_hooks
18 |
--------------------------------------------------------------------------------
/doc/user_guide/tensor_hooks.rst:
--------------------------------------------------------------------------------
1 | Tensor Hooks
2 | ============
3 |
4 | TensorLy-Torch also makes it very easy to manipulate the tensor decomposition parametrizing a tensor module.
5 |
6 | Tensor dropout
7 | --------------
8 | For instance, you can very easily apply tensor dropout to any tensor factorization. Let's first create a simple TRL layer:
9 |
10 | .. code:: python
11 |
12 | import tltorch
13 | trl = tltorch.TRL((10, 10), (10, ), rank='same')
14 |
15 | To add tensor dropout, simply apply the helper function:
16 |
17 | .. code:: python
18 |
19 | tltorch.tensor_dropout(trl.weight, p=0.5)
20 |
21 |
22 | Similarly, to remove tensor dropout:
23 |
24 | .. code:: python
25 |
26 | tltorch.remove_tensor_dropout(trl.weight)
27 |
28 |
29 | Lasso rank regularization
30 | -------------------------
31 |
32 | Rank selection is a hard problem. One way to choose the rank while training is to apply
33 | an l1 penalty (Lasso) on the rank.
34 |
35 | This was used previously for CP decomposition, and we extended it in TensorLy-Torch to Tucker and Tensor-Train,
36 | by introducing new weights in the decomposition.
37 |
38 | To use it, you can define a regularizer object that will take care of everything.
39 | Using our previously defined TRL:
40 |
41 | .. code:: python
42 |
43 | l1_reg = tltorch.TuckerL1Regularizer(penalty=0.01)
44 | l1_reg.apply(trl.weight)
45 | x = trl(x)
46 | loss = my_loss(x) + l1_reg.loss
48 |
49 | After each iteration, don't forget to reset the loss so you don't keep accumulating:
50 |
51 | .. code:: python
52 |
53 | l1_reg.reset()
54 |
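Putting it together, here is a minimal sketch of one training iteration (``data_loader``, ``optimizer`` and ``my_loss`` are assumed to be defined elsewhere; they are not part of TensorLy-Torch):

.. code:: python

    for x in data_loader:
        optimizer.zero_grad()
        out = trl(x)
        # task loss plus the accumulated lasso penalty
        loss = my_loss(out) + l1_reg.loss
        loss.backward()
        optimizer.step()
        # reset the accumulated penalty before the next iteration
        l1_reg.reset()
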
55 | Initializing tensor decomposition
56 | ---------------------------------
57 |
58 | Another issue is that of initializing tensor decompositions:
59 | if you simply initialize each component randomly without care,
60 | the reconstructed (full) tensor can have arbitrarily large or small values,
61 | potentially leading to vanishing or exploding gradients during training.
62 |
63 | In TensorLy-Torch, we provide a module for initialization that will
64 | properly initialize the factors of the decomposition
65 | so that the reconstruction has zero mean and the specified standard deviation!
66 |
67 | For any tensor factorization ``fact_tensor``:
68 |
69 | .. code:: python
70 |
71 | fact_tensor.normal_(0, 0.02)
72 |
73 |
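As a quick sanity check (a sketch, assuming ``fact_tensor`` is a factorized tensor like the ones created in the previous sections), you can reconstruct the full tensor and inspect its statistics:

.. code:: python

    full = fact_tensor.to_tensor()
    print(full.mean(), full.std())  # approximately 0 and 0.02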
74 |
--------------------------------------------------------------------------------
/doc/user_guide/tensorized_linear.rst:
--------------------------------------------------------------------------------
1 | Tensorized Linear Layers
2 | ========================
3 |
4 | Linear layers are parametrized by matrices. However, it is possible to
5 | *tensorize* them, i.e. reshape them into higher-order tensors in order
6 | to compress them.
7 |
8 | You can do this easily in TensorLy-Torch:
9 |
10 | .. code-block:: python
11 |
12 | import tltorch
13 | import torch
14 |
15 | Let’s create a batch of 4 data points of size 16 each:
16 |
17 | .. code-block:: python
18 |
19 | data = torch.randn((4, 16), dtype=torch.float32)
20 |
21 | Now, imagine you already have a linear layer:
22 |
23 | .. code-block:: python
24 |
25 | linear = torch.nn.Linear(in_features=16, out_features=10)
26 |
27 | You can easily compress it into a tensorized linear layer: here, we specify the shape to which to tensorize the weights,
28 | and use ``rank=0.5``, which tells TensorLy-Torch to automatically determine the rank so that the factorization uses approximately half the
29 | number of parameters.
30 |
31 | .. code-block:: python
32 |
33 | fact_linear = tltorch.FactorizedLinear.from_linear(linear, auto_tensorize=False,
34 | in_tensorized_features=(4, 4), out_tensorized_features=(2, 5), rank=0.5)
35 |
36 |
37 | The tensorized weights will have the following shape:
38 |
39 | .. parsed-literal::
40 |
41 | torch.Size([4, 4, 2, 5])
42 |
43 |
44 | Note that you can also let TensorLy-Torch automatically determine the tensorization shape. In this case, we just instruct it to
45 | find ``in_tensorized_features`` and ``out_tensorized_features`` of length ``2``:
46 |
47 | .. code-block:: python
48 |
49 | fact_linear = tltorch.FactorizedLinear.from_linear(linear, auto_tensorize=True, n_tensorized_modes=2, rank=0.5)
50 |
51 |
52 | You can also create tensorized layers from scratch:
53 |
54 | .. code-block:: python
55 |
56 | fact_linear = tltorch.FactorizedLinear(in_tensorized_features=(4, 4),
57 | out_tensorized_features=(2, 5),
58 | factorization='tucker', rank=0.5)
59 |
60 | Finally, during the forward pass, you can reconstruct the full weights (``implementation='reconstructed'``) and perform a regular linear layer forward pass.
61 | Alternatively, you can let TensorLy-Torch directly contract the input tensor with the *factors of the decomposition* (``implementation='factorized'``),
62 | which can be faster, particularly if the rank, and hence the factors of the decomposition, are very small.
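
The ``implementation`` argument can be passed both to the constructor and to ``from_linear``. As a minimal sketch, reusing the ``data`` tensor defined above:

.. code-block:: python

    fact_linear = tltorch.FactorizedLinear(in_tensorized_features=(4, 4),
                                           out_tensorized_features=(2, 5),
                                           factorization='tucker', rank=0.5,
                                           implementation='factorized')
    out = fact_linear(data)  # same interface as torch.nn.Linear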
--------------------------------------------------------------------------------
/doc/user_guide/trl.rst:
--------------------------------------------------------------------------------
1 | Tensor Regression Layers
2 | ========================
3 |
4 | In deep neural networks, while convolutional layers map between
5 | high-order activation tensors, the output is still obtained through
6 | linear regression: the activations are first flattened, then passed
7 | through linear layers.
8 |
9 | This approach has several drawbacks:
10 |
11 | * Linear regression discards topological (e.g. spatial) information.
12 | * Very large number of parameters
13 | (product of the dimensions
14 | of the input tensor times the size of the output)
15 | * Lack of robustness
16 |
17 | A Tensor Regression Layer (TRL) generalizes the concept of linear
18 | regression to higher-order inputs while alleviating the above issues. It
19 | preserves and leverages the multi-linear structure while being parsimonious in
20 | terms of the number of parameters. The low-rank constraint also acts as an
21 | implicit regularization on the model, typically leading to better sample
22 | efficiency and robustness.
23 |
24 | .. image:: /_static/TRL.png
25 | :align: center
26 | :width: 800
27 |
28 |
29 | TRL in TensorLy-Torch
30 | ---------------------
31 |
32 | Now, let’s see how to do this in code with TensorLy-Torch.
33 |
34 | Random TRL
35 | ----------
36 |
37 | Let’s first see how to create and train a TRL from scratch.
38 |
39 | .. code:: python
40 |
41 | import tltorch
42 | import torch
43 | from torch import nn
44 | import numpy as np
45 |
46 | .. code:: python
47 |
48 | input_shape = (4, 5)
49 | output_shape = (6, 2)
50 | batch_size = 2
51 |
52 | device = 'cpu'
53 |
54 | x = torch.randn((batch_size,) + input_shape,
55 | dtype=torch.float32, device=device)
56 |
57 | .. code:: python
58 |
59 | trl = tltorch.TRL(input_shape, output_shape, rank='same')
60 |
61 |
62 | .. code:: python
63 |
64 | result = trl(x)
65 |
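To actually train the TRL, optimize its parameters like those of any other PyTorch module. Here is a minimal sketch with random targets (the target ``y``, the number of iterations and the optimizer settings are purely illustrative):

.. code:: python

    y = torch.randn((batch_size,) + output_shape, device=device)
    optimizer = torch.optim.Adam(trl.parameters(), lr=1e-3)

    for _ in range(10):
        optimizer.zero_grad()
        loss = torch.norm(trl(x) - y)
        loss.backward()
        optimizer.step()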
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scipy
3 | pytest
4 | pytest-cov
5 | tensorly
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | try:
2 | from setuptools import setup, find_packages
3 | except ImportError:
4 | from distutils.core import setup, find_packages
5 |
6 | import re
7 | from pathlib import Path
8 |
9 | def version(root_path):
10 | """Returns the version taken from __init__.py
11 |
12 | Parameters
13 | ----------
14 | root_path : pathlib.Path
15 | path to the root of the package
16 |
17 | Reference
18 | ---------
19 | https://packaging.python.org/guides/single-sourcing-package-version/
20 | """
21 | version_path = root_path.joinpath('tltorch', '__init__.py')
22 | with version_path.open() as f:
23 | version_file = f.read()
24 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
25 | version_file, re.M)
26 | if version_match:
27 | return version_match.group(1)
28 | raise RuntimeError("Unable to find version string.")
29 |
30 |
31 | def readme(root_path):
32 | """Returns the text content of the README.rst of the package
33 |
34 | Parameters
35 | ----------
36 | root_path : pathlib.Path
37 | path to the root of the package
38 | """
39 | with root_path.joinpath('README.rst').open(encoding='UTF-8') as f:
40 | return f.read()
41 |
42 |
43 | root_path = Path(__file__).parent
44 | README = readme(root_path)
45 | VERSION = version(root_path)
46 |
47 |
48 | config = {
49 | 'name': 'tensorly-torch',
50 | 'packages': find_packages(),
51 | 'description': 'Deep Learning with Tensors in Python, using PyTorch and TensorLy.',
52 | 'long_description': README,
53 | 'long_description_content_type' : 'text/x-rst',
54 | 'author': 'Jean Kossaifi',
55 | 'author_email': 'jean.kossaifi@gmail.com',
56 | 'version': VERSION,
57 | 'url': 'https://github.com/tensorly/tensorly-torch',
58 | 'download_url': 'https://github.com/tensorly/tensorly-torch/tarball/' + VERSION,
59 | 'install_requires': ['numpy', 'scipy'],
60 | 'license': 'Modified BSD',
61 | 'scripts': [],
62 | 'classifiers': [
63 | 'Topic :: Scientific/Engineering',
64 | 'License :: OSI Approved :: BSD License',
65 | 'Programming Language :: Python :: 3'
66 | ],
67 | }
68 |
69 | setup(**config)
70 |
--------------------------------------------------------------------------------
/tltorch/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.5.0'
2 |
3 | from . import utils
4 | from . import factorized_tensors
5 | from .factorized_tensors import init
6 | from . import functional
7 | from . import factorized_layers
8 |
9 | from .factorized_layers import FactorizedLinear, FactorizedConv, TRL, TCL, FactorizedEmbedding
10 | from .factorized_tensors import FactorizedTensor, DenseTensor, CPTensor, TTTensor, TuckerTensor, tensor_init
11 | from .factorized_tensors import (TensorizedTensor, CPTensorized, BlockTT,
12 | DenseTensorized, TuckerTensorized)
13 | from .factorized_tensors import (ComplexCPTensor, ComplexTuckerTensor,
14 | ComplexTTTensor, ComplexDenseTensor)
15 | from .factorized_tensors import (ComplexCPTensorized, ComplexBlockTT,
16 | ComplexDenseTensorized, ComplexTuckerTensorized)
17 | from .tensor_hooks import (tensor_lasso, remove_tensor_lasso,
18 | tensor_dropout, remove_tensor_dropout)
19 |
--------------------------------------------------------------------------------
/tltorch/factorized_layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .factorized_convolution import FactorizedConv
2 | from .tensor_regression_layers import TRL
3 | from .tensor_contraction_layers import TCL
4 | from .factorized_linear import FactorizedLinear
5 | from .factorized_embedding import FactorizedEmbedding
6 |
--------------------------------------------------------------------------------
/tltorch/factorized_layers/factorized_embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch import nn
4 | from tltorch.factorized_tensors import TensorizedTensor,tensor_init
5 | from tltorch.utils import get_tensorized_shape
6 |
7 | # Authors: Cole Hawkins
8 | # Jean Kossaifi
9 |
10 | class FactorizedEmbedding(nn.Module):
11 | """
12 | Tensorized Embedding Layers For Efficient Model Compression
13 | Tensorized drop-in replacement for `torch.nn.Embedding`
14 |
15 | Parameters
16 | ----------
17 | num_embeddings : int
18 | number of entries in the lookup table
19 | embedding_dim : int
20 | number of dimensions per entry
21 | auto_tensorize : bool
22 | whether to use automatic reshaping for the embedding dimensions
23 | n_tensorized_modes : int or int tuple
24 | number of reshape dimensions for both embedding table dimensions
25 | tensorized_num_embeddings : int tuple
26 | tensorized shape of the first embedding table dimension
27 | tensorized_embedding_dim : int tuple
28 | tensorized shape of the second embedding table dimension
29 | factorization : str
30 | tensor type
31 | rank : int tuple or str
32 | rank of the tensor factorization
33 | """
34 | def __init__(self,
35 | num_embeddings,
36 | embedding_dim,
37 | auto_tensorize=True,
38 | n_tensorized_modes=3,
39 | tensorized_num_embeddings=None,
40 | tensorized_embedding_dim=None,
41 | factorization='blocktt',
42 | rank=8,
43 | n_layers=1,
44 | device=None,
45 | dtype=None):
46 | super().__init__()
47 |
48 | if auto_tensorize:
49 |
50 | if tensorized_num_embeddings is not None and tensorized_embedding_dim is not None:
51 | raise ValueError(
52 | "Either use auto_tensorize or specify tensorized_num_embeddings and tensorized_embedding_dim."
53 | )
54 |
55 | tensorized_num_embeddings, tensorized_embedding_dim = get_tensorized_shape(in_features=num_embeddings, out_features=embedding_dim, order=n_tensorized_modes, min_dim=2, verbose=False)
56 |
57 | else:
58 | #check that dimensions match factorization
59 | computed_num_embeddings = np.prod(tensorized_num_embeddings)
60 | computed_embedding_dim = np.prod(tensorized_embedding_dim)
61 |
62 | if computed_num_embeddings!=num_embeddings:
63 | raise ValueError("Tensorized embedding number {} does not match num_embeddings argument {}".format(computed_num_embeddings,num_embeddings))
64 | if computed_embedding_dim!=embedding_dim:
65 | raise ValueError("Tensorized embedding dimension {} does not match embedding_dim argument {}".format(computed_embedding_dim,embedding_dim))
66 |
67 | self.num_embeddings = num_embeddings
68 | self.embedding_dim = embedding_dim
69 | self.tensor_shape = (tensorized_num_embeddings,
70 | tensorized_embedding_dim)
71 | self.weight_shape = (self.num_embeddings, self.embedding_dim)
72 |
73 | self.n_layers = n_layers
74 | if n_layers > 1:
75 | self.tensor_shape = (n_layers, ) + self.tensor_shape
76 | self.weight_shape = (n_layers, ) + self.weight_shape
77 |
78 | self.factorization = factorization
79 |
80 | self.weight = TensorizedTensor.new(self.tensor_shape,
81 | rank=rank,
82 | factorization=self.factorization,
83 | device=device,
84 | dtype=dtype)
85 | self.reset_parameters()
86 |
87 | self.rank = self.weight.rank
88 |
89 | def reset_parameters(self):
90 | #Parameter initialization from Yin et al.
91 | #TT-Rec: Tensor Train Compression for Deep Learning Recommendation Model Embeddings
92 | target_stddev = 1 / np.sqrt(3 * self.num_embeddings)
93 | with torch.no_grad():
94 | tensor_init(self.weight,std=target_stddev)
95 |
96 | def forward(self, input, indices=0):
97 | #to handle case where input is not 1-D
98 | output_shape = (*input.shape, self.embedding_dim)
99 |
100 | flattened_input = input.reshape(-1)
101 |
102 | if self.n_layers == 1:
103 | if indices == 0:
104 | embeddings = self.weight[flattened_input, :]
105 | else:
106 | embeddings = self.weight[indices, flattened_input, :]
107 |
108 | #CPTensorized returns CPTensorized when indexing
109 | if self.factorization.lower() == 'cp':
110 | embeddings = embeddings.to_matrix()
111 |
112 | #TuckerTensorized returns tensor not matrix,
113 | # and requires reshape not view for contiguous
114 | elif self.factorization.lower() == 'tucker':
115 | embeddings = embeddings.reshape(input.shape[0], -1)
116 |
117 | return embeddings.view(output_shape)
118 |
119 | @classmethod
120 | def from_embedding(cls,
121 | embedding_layer,
122 | rank=8,
123 | factorization='blocktt',
124 | n_tensorized_modes=2,
125 | decompose_weights=True,
126 | auto_tensorize=True,
127 | decomposition_kwargs=dict(),
128 | **kwargs):
129 | """
130 | Create a tensorized embedding layer from a regular embedding layer
131 |
132 | Parameters
133 | ----------
134 | embedding_layer : torch.nn.Embedding
135 | rank : int tuple or str
136 | rank of the tensor decomposition
137 | factorization : str
138 | tensor type
139 | decompose_weights: bool
140 | whether to decompose weights and use for initialization
141 | auto_tensorize: bool
142 | if True, automatically reshape dimensions for TensorizedTensor
143 | decomposition_kwargs: dict
144 | specify kwargs for the decomposition
145 | """
146 | num_embeddings, embedding_dim = embedding_layer.weight.shape
147 |
148 | instance = cls(num_embeddings,
149 | embedding_dim,
150 | auto_tensorize=auto_tensorize,
151 | factorization=factorization,
152 | n_tensorized_modes=n_tensorized_modes,
153 | rank=rank,
154 | **kwargs)
155 |
156 | if decompose_weights:
157 | with torch.no_grad():
158 | instance.weight.init_from_matrix(embedding_layer.weight.data,
159 | **decomposition_kwargs)
160 |
161 | else:
162 | instance.reset_parameters()
163 |
164 | return instance
165 |
166 | @classmethod
167 | def from_embedding_list(cls,
168 | embedding_layer_list,
169 | rank=8,
170 | factorization='blocktt',
171 | n_tensorized_modes=2,
172 | decompose_weights=True,
173 | auto_tensorize=True,
174 | decomposition_kwargs=dict(),
175 | **kwargs):
176 | """
177 | Create a tensorized embedding layer from a list of regular embedding layers
178 |
179 | Parameters
180 | ----------
181 | embedding_layer_list : list of torch.nn.Embedding
182 | rank : int tuple or str
183 | tensor rank
184 | factorization : str
185 | tensor decomposition to use
186 | decompose_weights: bool
187 | decompose weights and use for initialization
188 | auto_tensorize: bool
189 | automatically reshape dimensions for TensorizedTensor
190 | decomposition_kwargs: dict
191 | specify kwargs for the decomposition
192 | """
193 | n_layers = len(embedding_layer_list)
194 | num_embeddings, embedding_dim = embedding_layer_list[0].weight.shape
195 |
196 | for i, layer in enumerate(embedding_layer_list[1:]):
197 | # Just some checks on the size of the embeddings
198 | # They need to have the same size so they can be jointly factorized
199 | new_num_embeddings, new_embedding_dim = layer.weight.shape
200 | if num_embeddings != new_num_embeddings:
201 | msg = 'All embedding layers must have the same num_embeddings.'
202 | msg += f'Yet, got embedding_layer_list[0] with num_embeddings={num_embeddings} '
203 | msg += f' and embedding_layer_list[{i+1}] with num_embeddings={new_num_embeddings}.'
204 | raise ValueError(msg)
205 | if embedding_dim != new_embedding_dim:
206 | msg = 'All embedding layers must have the same embedding_dim.'
207 | msg += f'Yet, got embedding_layer_list[0] with embedding_dim={embedding_dim} '
208 | msg += f' and embedding_layer_list[{i+1}] with embedding_dim={new_embedding_dim}.'
209 | raise ValueError(msg)
210 |
211 | instance = cls(num_embeddings,
212 | embedding_dim,
213 | n_tensorized_modes=n_tensorized_modes,
214 | auto_tensorize=auto_tensorize,
215 | factorization=factorization,
216 | rank=rank,
217 | n_layers=n_layers,
218 | **kwargs)
219 |
220 | if decompose_weights:
221 | weight_tensor = torch.stack([layer.weight.data for layer in embedding_layer_list])
222 | with torch.no_grad():
223 | instance.weight.init_from_matrix(weight_tensor,
224 | **decomposition_kwargs)
225 |
226 | else:
227 | instance.reset_parameters()
228 |
229 | return instance
230 |
231 |
232 | def get_embedding(self, indices):
233 | if self.n_layers == 1:
234 | raise ValueError('A single linear is parametrized, directly use the main class.')
235 |
236 | return SubFactorizedEmbedding(self, indices)
237 |
238 |
239 | class SubFactorizedEmbedding(nn.Module):
240 | """Class representing one of the embeddings from the mother joint factorized embedding layer
241 |
242 | Parameters
243 | ----------
244 |
245 | Notes
246 | -----
247 | This relies on the fact that nn.Parameters are not duplicated:
248 | if the same nn.Parameter is assigned to multiple modules, they all point to the same data,
249 | which is shared.
250 | """
251 | def __init__(self, main_layer, indices):
252 | super().__init__()
253 | self.main_layer = main_layer
254 | self.indices = indices
255 |
256 | def forward(self, x):
257 | return self.main_layer(x, self.indices)
258 |
259 | def extra_repr(self):
260 | return ''
261 |
262 | def __repr__(self):
263 | msg = f' {self.__class__.__name__} {self.indices} from main factorized layer.'
264 | msg += f'\n{self.__class__.__name__}('
265 | msg += self.extra_repr()
266 | msg += ')'
267 | return msg
--------------------------------------------------------------------------------
/tltorch/factorized_layers/factorized_linear.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | from torch import nn
4 | import torch
5 | from torch.utils import checkpoint
6 |
7 | from ..functional import factorized_linear
8 | from ..factorized_tensors import TensorizedTensor
9 | from tltorch.utils import get_tensorized_shape
10 |
11 | # Author: Jean Kossaifi
12 | # License: BSD 3 clause
13 |
14 |
15 | class FactorizedLinear(nn.Module):
16 | """Tensorized Fully-Connected Layers
17 |
18 | The weight matrix is tensorized to a tensor of size `(*in_tensorized_features, *out_tensorized_features)`.
19 | That tensor is expressed as a low-rank tensor.
20 |
21 | During inference, the full tensor is reconstructed, and unfolded back into a matrix,
22 | used for the forward pass in a regular linear layer.
23 |
24 | Parameters
25 | ----------
26 | in_tensorized_features : int tuple
27 | shape to which the input_features dimension is tensorized to
28 | e.g. if in_features is 8 in_tensorized_features could be (2, 2, 2)
29 | should verify prod(in_tensorized_features) = in_features
30 | out_tensorized_features : int tuple
31 | shape to which the out_features dimension is tensorized to.
32 | factorization : str, default is 'cp'
33 | rank : int tuple or str
34 | implementation : {'factorized', 'reconstructed'}, default is 'factorized'
35 | which implementation to use for forward function:
36 | - if 'factorized', will directly contract the input with the factors of the decomposition
37 | - if 'reconstructed', the full weight matrix is reconstructed from the factorized version and used for a regular linear layer forward pass.
38 | n_layers : int, default is 1
39 | number of linear layers to be parametrized with a single factorized tensor
40 | bias : bool, default is True
41 | checkpointing : bool
42 | whether to enable gradient checkpointing to save memory during training-mode forward, default is False
43 | device : PyTorch device to use, default is None
44 | dtype : PyTorch dtype, default is None
45 | """
46 | def __init__(self, in_tensorized_features, out_tensorized_features, bias=True,
47 | factorization='cp', rank='same', implementation='factorized', n_layers=1,
48 | checkpointing=False, device=None, dtype=None):
49 | super().__init__()
50 | if factorization == 'TTM' and n_layers != 1:
51 | raise ValueError(f'TTM factorization only support single factorized layers but got n_layers={n_layers}.')
52 |
53 | self.in_features = np.prod(in_tensorized_features)
54 | self.out_features = np.prod(out_tensorized_features)
55 | self.in_tensorized_features = in_tensorized_features
56 | self.out_tensorized_features = out_tensorized_features
57 | self.tensorized_shape = out_tensorized_features + in_tensorized_features
58 | self.weight_shape = (self.out_features, self.in_features)
59 | self.input_rank = rank
60 | self.implementation = implementation
61 | self.checkpointing = checkpointing
62 |
63 | if bias:
64 | if n_layers == 1:
65 | self.bias = nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
66 | self.has_bias = True
67 | else:
68 | self.bias = nn.Parameter(torch.empty((n_layers, self.out_features), device=device, dtype=dtype))
69 | self.has_bias = np.zeros(n_layers)
70 | else:
71 | self.register_parameter('bias', None)
72 |
73 | self.rank = rank
74 | self.n_layers = n_layers
75 | if n_layers > 1:
76 | tensor_shape = (n_layers, out_tensorized_features, in_tensorized_features)
77 | else:
78 | tensor_shape = (out_tensorized_features, in_tensorized_features)
79 |
80 | if isinstance(factorization, TensorizedTensor):
81 | self.weight = factorization.to(device).to(dtype)
82 | else:
83 | self.weight = TensorizedTensor.new(tensor_shape, rank=rank, factorization=factorization, device=device, dtype=dtype)
84 | self.reset_parameters()
85 |
86 | self.rank = self.weight.rank
87 |
88 | def reset_parameters(self):
89 | with torch.no_grad():
90 | self.weight.normal_(0, math.sqrt(5)/math.sqrt(self.in_features))
91 | if self.bias is not None:
92 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
93 | bound = 1 / math.sqrt(fan_in)
94 | torch.nn.init.uniform_(self.bias, -bound, bound)
95 |
96 | def forward(self, x, indices=0):
97 | if self.n_layers == 1:
98 | if indices == 0:
99 | weight, bias = self.weight(), self.bias
100 | else:
101 | raise ValueError(f'Only one convolution was parametrized (n_layers=1) but tried to access {indices}.')
102 |
103 | elif isinstance(self.n_layers, int):
104 | if not isinstance(indices, int):
105 | raise ValueError(f'Expected indices to be in int but got indices={indices}'
106 | f', but this conv was created with n_layers={self.n_layers}.')
107 | weight = self.weight(indices)
108 | bias = self.bias[indices] if self.bias is not None else None
109 | elif len(indices) != len(self.n_layers):
110 | raise ValueError(f'Got indices={indices}, but this conv was created with n_layers={self.n_layers}.')
111 | else:
112 | weight = self.weight(indices)
113 | bias = self.bias[indices] if self.bias is not None else None
114 |
115 | def _inner_forward(x): # move weight() out to avoid register_hooks from being executed twice during recomputation
116 | return factorized_linear(x, weight, bias=bias, in_features=self.in_features,
117 | implementation=self.implementation)
118 |
119 | if self.checkpointing and x.requires_grad:
120 | x = checkpoint.checkpoint(_inner_forward, x)
121 | else:
122 | x = _inner_forward(x)
123 | return x
124 |
125 | def get_linear(self, indices):
126 | if self.n_layers == 1:
127 | raise ValueError('A single linear is parametrized, directly use the main class.')
128 |
129 | return SubFactorizedLinear(self, indices)
130 |
131 | def __getitem__(self, indices):
132 | return self.get_linear(indices)
133 |
134 | @classmethod
135 | def from_linear(cls, linear, rank='same', auto_tensorize=True, n_tensorized_modes=3,
136 | in_tensorized_features=None, out_tensorized_features=None,
137 | bias=True, factorization='CP', implementation="reconstructed",
138 | checkpointing=False, decomposition_kwargs=dict(), verbose=False):
139 | """Class method to create an instance from an existing linear layer
140 |
141 | Parameters
142 | ----------
143 | linear : torch.nn.Linear
144 | layer to tensorize
145 | auto_tensorize : bool, default is True
146 | if True, automatically find values for the tensorized_shapes
147 | n_tensorized_modes : int, default is 3
148 | Order (number of dims) of the tensorized weights if auto_tensorize is True
149 | in_tensorized_features, out_tensorized_features : tuple
150 | shape to tensorized the factorized_weight matrix to.
151 | Must verify np.prod(tensorized_shape) == np.prod(linear.factorized_weight.shape)
152 | factorization : str, default is 'cp'
153 | implementation : str
154 | which implementation to use for the forward function. Supports 'factorized' and 'reconstructed'; default is 'reconstructed'
155 | checkpointing : bool
156 | whether to enable gradient checkpointing to save memory during training-mode forward, default is False
157 | rank : {rank of the decomposition, 'same', float}
158 | if float, percentage of parameters of the original factorized_weights to use
159 | if 'same' use the same number of parameters
160 | bias : bool, default is True
161 | verbose : bool, default is False
162 | """
163 | out_features, in_features = linear.weight.shape
164 |
165 | if auto_tensorize:
166 |
167 | if out_tensorized_features is not None and in_tensorized_features is not None:
168 | raise ValueError(
169 | "Either use auto_reshape or specify out_tensorized_features and in_tensorized_features."
170 | )
171 |
172 | in_tensorized_features, out_tensorized_features = get_tensorized_shape(
173 | in_features=in_features, out_features=out_features, order=n_tensorized_modes, min_dim=2, verbose=verbose)
174 | else:
175 | assert(out_features == np.prod(out_tensorized_features))
176 | assert(in_features == np.prod(in_tensorized_features))
177 |
178 | instance = cls(in_tensorized_features, out_tensorized_features, bias=bias,
179 | factorization=factorization, rank=rank, implementation=implementation,
180 | n_layers=1, checkpointing=checkpointing, device=linear.weight.device, dtype=linear.weight.dtype)
181 |
182 | instance.weight.init_from_matrix(linear.weight.data, **decomposition_kwargs)
183 |
184 | if bias and linear.bias is not None:
185 | instance.bias.data = linear.bias.data
186 |
187 | return instance
188 |
189 | @classmethod
190 | def from_linear_list(cls, linear_list, in_tensorized_features, out_tensorized_features, rank, bias=True,
191 | factorization='CP', implementation="reconstructed", checkpointing=False, decomposition_kwargs=dict(init='random')):
192 | """Class method to create an instance from an existing linear layer
193 |
194 | Parameters
195 | ----------
196 | linear_list : list of torch.nn.Linear
197 | layers to jointly tensorize and factorize
198 | tensorized_shape : tuple
199 | shape to tensorized the weight matrix to.
200 | Must verify np.prod(tensorized_shape) == np.prod(linear.weight.shape)
201 | factorization : str, default is 'cp'
202 | implementation : str
203 | which implementation to use for the forward function. Supports 'factorized' and 'reconstructed'; default is 'reconstructed'
204 | checkpointing : bool
205 | whether to enable gradient checkpointing to save memory during training-mode forward, default is False
206 | rank : {rank of the decomposition, 'same', float}
207 | if float, percentage of parameters of the original weights to use
208 | if 'same' use the same number of parameters
209 | bias : bool, default is True
210 | """
211 | if factorization == 'TTM' and len(linear_list) > 1:
212 | raise ValueError(f'TTM factorization only support single factorized layers but got {len(linear_list)} layers.')
213 |
214 | for linear in linear_list:
215 | out_features, in_features = linear.weight.shape
216 | assert(out_features == np.prod(out_tensorized_features))
217 | assert(in_features == np.prod(in_tensorized_features))
218 |
219 | instance = cls(in_tensorized_features, out_tensorized_features, bias=bias,
220 | factorization=factorization, rank=rank, implementation=implementation,
221 | n_layers=len(linear_list), checkpointing=checkpointing, device=linear.weight.device, dtype=linear.weight.dtype)
222 | weight_tensor = torch.stack([layer.weight.data for layer in linear_list])
223 | instance.weight.init_from_matrix(weight_tensor, **decomposition_kwargs)
224 |
225 | if bias:
226 | for i, layer in enumerate(linear_list):
227 | if layer.bias is not None:
228 | instance.bias.data[i] = layer.bias.data
229 | instance.has_bias[i] = 1
230 |
231 | return instance
232 |
233 | def __repr__(self):
234 | msg = (f'{self.__class__.__name__}(in_features={self.in_features}, out_features={self.out_features},'
235 | f' weight of size ({self.out_features}, {self.in_features}) tensorized to ({self.out_tensorized_features}, {self.in_tensorized_features}),'
236 | f'factorization={self.weight._name}, rank={self.rank}, implementation={self.implementation}')
237 | if self.bias is None:
238 | msg += f', bias=False'
239 |
240 | if self.n_layers == 1:
241 | msg += ', with a single layer parametrized, '
242 | return msg
243 |
244 | msg += f' with {self.n_layers} layers jointly parametrized.'
245 |
246 | return msg
247 |
248 |
249 | class SubFactorizedLinear(nn.Module):
250 | """Class representing one of the convolutions from the mother joint factorized convolution
251 |
252 | Parameters
253 | ----------
254 |
255 | Notes
256 | -----
257 | This relies on the fact that nn.Parameters are not duplicated:
258 | if the same nn.Parameter is assigned to multiple modules, they all point to the same data,
259 | which is shared.
260 | """
261 | def __init__(self, main_linear, indices):
262 | super().__init__()
263 | self.main_linear = main_linear
264 | self.indices = indices
265 |
266 | def forward(self, x):
267 | return self.main_linear(x, self.indices)
268 |
269 | def extra_repr(self):
270 | msg = f'in_features={self.main_linear.in_features}, out_features={self.main_linear.out_features}'
271 | if self.main_linear.has_bias[self.indices]:
272 | msg += ', bias=True'
273 | return msg
274 |
275 | def __repr__(self):
276 | msg = f' {self.__class__.__name__} {self.indices} from main factorized layer.'
277 | msg += f'\n{self.__class__.__name__}('
278 | msg += self.extra_repr()
279 | msg += ')'
280 | return msg
281 |
--------------------------------------------------------------------------------
/tltorch/factorized_layers/tensor_contraction_layers.py:
--------------------------------------------------------------------------------
1 | """
2 | Tensor Contraction Layers
3 | """
4 |
5 | # Author: Jean Kossaifi
6 | # License: BSD 3 clause
7 |
8 | from tensorly import tenalg
9 | import torch
10 | import torch.nn as nn
11 | from torch.nn import init
12 |
13 | import math
14 |
15 | import tensorly as tl
16 | tl.set_backend('pytorch')
17 |
18 |
19 | class TCL(nn.Module):
20 | """Tensor Contraction Layer [1]_
21 |
22 | Parameters
23 | ----------
24 | input_shape : int iterable
25 | shape of the input, excluding batch size
26 | rank : int list or int
27 | rank of the TCL, will also be the output-shape (excluding batch-size)
28 | if int, the same rank will be used for all dimensions
29 | verbose : int, default is 1
30 | level of verbosity
31 |
32 | References
33 | ----------
34 | .. [1] J. Kossaifi, A. Khanna, Z. Lipton, T. Furlanello and A. Anandkumar,
35 | "Tensor Contraction Layers for Parsimonious Deep Nets," 2017 IEEE Conference on Computer Vision and Pattern Recognition Workshops (CVPRW),
36 | Honolulu, HI, 2017, pp. 1940-1946, doi: 10.1109/CVPRW.2017.243.
37 | """
38 | def __init__(self, input_shape, rank, verbose=0, bias=False,
39 | device=None, dtype=None, **kwargs):
40 | super().__init__(**kwargs)
41 | self.verbose = verbose
42 |
43 | if isinstance(input_shape, int):
44 | self.input_shape = (input_shape, )
45 | else:
46 | self.input_shape = tuple(input_shape)
47 |
48 | self.order = len(self.input_shape)
49 |
50 | if isinstance(rank, int):
51 | self.rank = (rank, )*self.order
52 | else:
53 | self.rank = tuple(rank)
54 |
55 | # Start at 1 as the batch-size is not projected
56 | self.contraction_modes = list(range(1, self.order + 1))
57 | for i, (s, r) in enumerate(zip(self.input_shape, self.rank)):
58 | self.register_parameter(f'factor_{i}', nn.Parameter(torch.empty((r, s), device=device, dtype=dtype)))
59 |
60 | # self.factors = ParameterList(parameters=factors)
61 | if bias:
62 | self.bias = nn.Parameter(
63 | torch.empty(self.rank, device=device, dtype=dtype), requires_grad=True)
64 | else:
65 | self.register_parameter('bias', None)
66 |
67 | self.reset_parameters()
68 |
69 | @property
70 | def factors(self):
71 | return [getattr(self, f'factor_{i}') for i in range(self.order)]
72 |
73 | def forward(self, x):
74 | """Performs a forward pass"""
75 | x = tenalg.multi_mode_dot(
76 | x, self.factors, modes=self.contraction_modes)
77 |
78 | if self.bias is not None:
79 | return x + self.bias
80 | else:
81 | return x
82 |
83 | def reset_parameters(self):
84 | """Sets the parameters' values randomly
85 |
86 | Todo
87 | ----
88 | This may be renamed to init_from_random for consistency with TensorModules
89 | """
90 | for i in range(self.order):
91 | init.kaiming_uniform_(getattr(self, f'factor_{i}'), a=math.sqrt(5))
92 | if self.bias is not None:
93 | bound = 1 / math.sqrt(self.input_shape[0])
94 | init.uniform_(self.bias, -bound, bound)
95 |
--------------------------------------------------------------------------------
/tltorch/factorized_layers/tensor_regression_layers.py:
--------------------------------------------------------------------------------
1 | """Tensor Regression Layers
2 | """
3 |
4 | # Author: Jean Kossaifi
5 | # License: BSD 3 clause
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | import tensorly as tl
11 | tl.set_backend('pytorch')
12 | from ..functional.tensor_regression import trl
13 |
14 | from ..factorized_tensors import FactorizedTensor
15 |
16 | class TRL(nn.Module):
17 | """Tensor Regression Layers
18 |
19 | Parameters
20 | ----------
21 | input_shape : int iterable
22 | shape of the input, excluding batch size
23 | output_shape : int iterable
24 | shape of the output, excluding batch size
25 | verbose : int, default is 0
26 | level of verbosity
27 |
28 | References
29 | ----------
30 | .. [1] Tensor Regression Networks, Jean Kossaifi, Zachary C. Lipton, Arinbjorn Kolbeinsson,
31 | Aran Khanna, Tommaso Furlanello, Anima Anandkumar, JMLR, 2020.
32 | """
33 | def __init__(self, input_shape, output_shape, bias=False, verbose=0,
34 | factorization='cp', rank='same', n_layers=1,
35 | device=None, dtype=None, **kwargs):
36 | super().__init__(**kwargs)
37 | self.verbose = verbose
38 |
39 | if isinstance(input_shape, int):
40 | self.input_shape = (input_shape, )
41 | else:
42 | self.input_shape = tuple(input_shape)
43 |
44 | if isinstance(output_shape, int):
45 | self.output_shape = (output_shape, )
46 | else:
47 | self.output_shape = tuple(output_shape)
48 |
49 | self.n_input = len(self.input_shape)
50 | self.n_output = len(self.output_shape)
51 | self.weight_shape = self.input_shape + self.output_shape
52 | self.order = len(self.weight_shape)
53 |
54 | if bias:
55 | self.bias = nn.Parameter(torch.empty(self.output_shape, device=device, dtype=dtype))
56 | else:
57 | self.bias = None
58 |
59 | if n_layers == 1:
60 | factorization_shape = self.weight_shape
61 | elif isinstance(n_layers, int):
62 | factorization_shape = (n_layers, ) + self.weight_shape
63 | elif isinstance(n_layers, tuple):
64 | factorization_shape = n_layers + self.weight_shape
65 |
66 | if isinstance(factorization, FactorizedTensor):
67 | self.weight = factorization.to(device).to(dtype)
68 | else:
69 | self.weight = FactorizedTensor.new(factorization_shape, rank=rank, factorization=factorization,
70 | device=device, dtype=dtype)
71 | self.init_from_random()
72 |
73 | self.factorization = self.weight.name
74 |
75 | def forward(self, x):
76 | """Performs a forward pass"""
77 | return trl(x, self.weight, bias=self.bias)
78 |
79 | def init_from_random(self, decompose_full_weight=False):
80 | """Initialize the module randomly
81 |
82 | Parameters
83 | ----------
84 | decompose_full_weight : bool, default is False
85 | if True, constructs a full weight tensor and decomposes it to initialize the factors
86 | otherwise, the factors are directly initialized randomly
87 | """
88 | with torch.no_grad():
89 | if decompose_full_weight:
90 | full_weight = torch.normal(0.0, 0.02, size=self.weight_shape)
91 | self.weight.init_from_tensor(full_weight)
92 | else:
93 | self.weight.normal_()
94 | if self.bias is not None:
95 | self.bias.uniform_(-1, 1)
96 |
97 | def init_from_linear(self, linear, unsqueezed_modes=None, **kwargs):
98 | """Initialise the TRL from the weights of a fully connected layer
99 |
100 | Parameters
101 | ----------
102 | linear : torch.nn.Linear
103 | unsqueezed_modes : int list or None
104 | For Tucker factorization, this allows to replace pooling layers and instead
105 | learn the average pooling for the specified modes ("unsqueezed_modes").
106 | **for factorization='Tucker' only**
107 | """
108 | if unsqueezed_modes is not None:
109 | if self.factorization != 'Tucker':
110 | raise ValueError(f'unsqueezed_modes is only supported for factorization="tucker" but factorization is {self.factorization}.')
111 |
112 | unsqueezed_modes = sorted(unsqueezed_modes)
113 | weight_shape = list(self.weight_shape)
114 | for mode in unsqueezed_modes[::-1]:
115 | if mode == 0:
116 | raise ValueError(f'Cannot learn pooling for mode-0 (channels).')
117 | if mode > self.n_input:
118 | msg = 'Can only learn pooling for the input tensor. '
119 | msg += f'The input has only {self.n_input} modes, yet got a unsqueezed_mode for mode {mode}.'
120 | raise ValueError(msg)
121 |
122 | weight_shape.pop(mode)
123 | kwargs['unsqueezed_modes'] = unsqueezed_modes
124 | else:
125 | weight_shape = self.weight_shape
126 |
127 | with torch.no_grad():
128 | weight = torch.t(linear.weight).contiguous().view(weight_shape)
129 |
130 | self.weight.init_from_tensor(weight, **kwargs)
131 | if self.bias is not None:
132 | self.bias.data = linear.bias.data
133 |
--------------------------------------------------------------------------------
/tltorch/factorized_layers/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/tltorch/factorized_layers/tests/__init__.py
--------------------------------------------------------------------------------
/tltorch/factorized_layers/tests/test_factorized_convolution.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch import nn
4 |
5 | import tensorly as tl
6 | tl.set_backend('pytorch')
7 | from ... import FactorizedTensor
8 | from tltorch.factorized_layers import FactorizedConv
9 | from tensorly.testing import assert_array_almost_equal
10 |
11 | from ..factorized_convolution import (FactorizedConv, kernel_shape_to_factorization_shape, tensor_to_kernel)
12 |
13 | @pytest.mark.parametrize('factorization, implementation',
14 | [('CP', 'factorized'), ('CP', 'reconstructed'), ('CP', 'mobilenet'),
15 | ('Tucker', 'factorized'), ('Tucker', 'reconstructed'),
16 | ('TT', 'factorized'), ('TT', 'reconstructed')])
17 | def test_single_conv(factorization, implementation,
18 | order=2, rank=0.5, rng=None, input_channels=4, output_channels=5,
19 | kernel_size=3, batch_size=1, activation_size=(8, 7), device='cpu'):
20 | rng = tl.check_random_state(rng)
21 | input_shape = (batch_size, input_channels) + activation_size
22 | kernel_shape = (output_channels, input_channels) + (kernel_size, )*order
23 |
24 | if rank is None:
25 | rank = max(kernel_shape)
26 |
27 | if order == 1:
28 | FullConv = nn.Conv1d
29 | elif order == 2:
30 | FullConv = nn.Conv2d
31 | elif order == 3:
32 | FullConv = nn.Conv3d
33 |
34 | # Factorized input tensor
35 | factorization_shape = kernel_shape_to_factorization_shape(factorization, kernel_shape)
36 | decomposed_weights = FactorizedTensor.new(shape=factorization_shape, rank=rank, factorization=factorization).normal_(0, 1)
37 | full_weights = tensor_to_kernel(factorization, decomposed_weights.to_tensor().to(device))
38 | data = torch.tensor(rng.random_sample(input_shape), dtype=torch.float32).to(device)
39 |
40 | # PyTorch regular Conv
41 | conv = FullConv(input_channels, output_channels, kernel_size, bias=True, padding=1)
42 | true_bias = conv.bias.data
43 | conv.weight.data = full_weights
44 |
45 | # Factorized conv
46 | fact_conv = FactorizedConv.from_factorization(decomposed_weights, implementation=implementation,
47 | bias=true_bias, padding=1)
48 |
49 | # First check it has the correct implementation
50 | msg = f'Created implementation={implementation} but {fact_conv.implementation} was created.'
51 | assert fact_conv.implementation == implementation, msg
52 |
53 | # Check that it gives the same result as the full conv
54 | true_res = conv(data)
55 | res = fact_conv(data)
56 | msg = f'{fact_conv.__class__.__name__} does not give same result as {FullConv.__class__.__name__}.'
57 | assert_array_almost_equal(true_res, res, decimal=4, err_msg=msg)
58 |
59 | # Check that the parameters of the decomposition are transposed back correctly
60 | decomposed_weights, bias = fact_conv.weight, fact_conv.bias
61 | rec = tensor_to_kernel(factorization, decomposed_weights.to_tensor())
62 | msg = f'{fact_conv.__class__.__name__} does not return the decomposition it was constructed with.'
63 | assert_array_almost_equal(rec, full_weights, err_msg=msg)
64 | msg = f'{fact_conv.__class__.__name__} does not return the bias it was constructed with.'
65 | assert_array_almost_equal(bias, true_bias, err_msg=msg)
66 |
67 | conv = FullConv(input_channels, output_channels, kernel_size, bias=True, padding=1)
68 | conv.weight.data.uniform_(-1, 1)
69 | fact_conv = FactorizedConv.from_conv(conv, rank=30, factorization=factorization)#, decomposition_kwargs=dict(init='svd', l2_reg=1e-5))
70 | true_res = conv(data)
71 | res = fact_conv(data)
72 | msg = f'{fact_conv.__class__.__name__} does not give same result as {FullConv.__class__.__name__}.'
73 | assert_array_almost_equal(true_res, res, decimal=2, err_msg=msg)
74 |
--------------------------------------------------------------------------------
/tltorch/factorized_layers/tests/test_factorized_embedding.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch import nn
4 | from ..factorized_embedding import FactorizedEmbedding
5 |
6 | import tensorly as tl
7 | tl.set_backend('pytorch')
8 | from tensorly import testing
9 |
10 | #@pytest.mark.parametrize('factorization', ['BlockTT'])
11 | @pytest.mark.parametrize('factorization', ['CP','Tucker', 'BlockTT'])
12 | @pytest.mark.parametrize('dims', [(256,16), (1000,32)])
13 | def test_FactorizedEmbedding(factorization,dims):
14 | NUM_EMBEDDINGS, EMBEDDING_DIM = dims
15 |
16 | #create factorized embedding
17 | factorized_embedding = FactorizedEmbedding(NUM_EMBEDDINGS, EMBEDDING_DIM, factorization=factorization)
18 |
19 | #make test embedding of same shape and same weight
20 | test_embedding = torch.nn.Embedding(factorized_embedding.weight.shape[0], factorized_embedding.weight.shape[1])
21 | test_embedding.weight.data.copy_(factorized_embedding.weight.to_matrix().detach())
22 |
23 | #create batch and test using all entries (shuffled since entries may not be sorted)
24 | batch = torch.randperm(NUM_EMBEDDINGS)#.view(-1,1)
25 | normal_embed = test_embedding(batch)
26 | factorized_embed = factorized_embedding(batch)
27 | testing.assert_array_almost_equal(normal_embed,factorized_embed,decimal=2)
28 |
29 | #split batch into tensor with first dimension 3
30 | batch = torch.randperm(NUM_EMBEDDINGS)
31 | split_size = NUM_EMBEDDINGS//5
32 |
33 | split_batch = [batch[:1*split_size],batch[1*split_size:2*split_size],batch[3*split_size:4*split_size]]
34 |
35 | split_batch = torch.stack(split_batch,0)
36 |
37 | normal_embed = test_embedding(split_batch)
38 | factorized_embed = factorized_embedding(split_batch)
39 | testing.assert_array_almost_equal(normal_embed,factorized_embed,decimal=2)
40 |
41 | #BlockTT has no init_from_matrix, so skip that test
42 | if factorization=='BlockTT':
43 | return
44 |
45 | del factorized_embedding
46 |
47 | #init from test layer which is low rank
48 | factorized_embedding = FactorizedEmbedding.from_embedding(test_embedding,factorization=factorization,rank=8)
49 |
50 | #test using same batch as before, only test that shapes match
51 | normal_embed = test_embedding(batch)
52 | factorized_embed = factorized_embedding(batch)
53 | testing.assert_array_almost_equal(normal_embed.shape,factorized_embed.shape,decimal=2)
54 |
--------------------------------------------------------------------------------
/tltorch/factorized_layers/tests/test_factorized_linear.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from torch import nn
3 | from ..factorized_linear import FactorizedLinear
4 | from ... import TensorizedTensor
5 |
6 | import tensorly as tl
7 | tl.set_backend('pytorch')
8 | from tensorly import testing
9 |
10 | @pytest.mark.parametrize('factorization', ['CP', 'Tucker', 'BlockTT'])
11 | def test_FactorizedLinear(factorization):
12 | random_state = 12345
13 | rng = tl.check_random_state(random_state)
14 | batch_size = 2
15 | in_features = 9
16 | in_shape = (3, 3)
17 | out_features = 16
18 | out_shape = (4, 4)
19 | data = tl.tensor(rng.random_sample((batch_size, in_features)), dtype=tl.float32)
20 |
21 | # Create from a tensor factorization
22 | tensor = TensorizedTensor.new((out_shape, in_shape), rank='same', factorization=factorization)
23 | tensor.normal_()
24 | fc = nn.Linear(in_features, out_features, bias=True)
25 | fc.weight.data = tensor.to_matrix()
26 | tfc = FactorizedLinear(in_shape, out_shape, rank='same', factorization=tensor, bias=True)
27 | tfc.bias.data = fc.bias
28 | res_fc = fc(data)
29 | res_tfc = tfc(data)
30 | testing.assert_array_almost_equal(res_fc, res_tfc, decimal=2)
31 |
32 | # Decompose an existing layer
33 | fc = nn.Linear(in_features, out_features, bias=True)
34 | tfc = FactorizedLinear.from_linear(fc, auto_tensorize=False,
35 | in_tensorized_features=(3, 3), out_tensorized_features=(4, 4), rank=34, bias=True)
36 | res_fc = fc(data)
37 | res_tfc = tfc(data)
38 | testing.assert_array_almost_equal(res_fc, res_tfc, decimal=2)
39 |
40 | # Decompose an existing layer, automatically determine tensorization shape
41 | fc = nn.Linear(in_features, out_features, bias=True)
42 | tfc = FactorizedLinear.from_linear(fc, auto_tensorize=True, n_tensorized_modes=2, rank=34, bias=True)
43 | res_fc = fc(data)
44 | res_tfc = tfc(data)
45 | testing.assert_array_almost_equal(res_fc, res_tfc, decimal=2)
46 |
47 | # Multi-layer factorization
48 | fc1 = nn.Linear(in_features, out_features, bias=True)
49 | fc2 = nn.Linear(in_features, out_features, bias=True)
50 | tfc = FactorizedLinear.from_linear_list([fc1, fc2], in_tensorized_features=in_shape, out_tensorized_features=out_shape, rank=38, bias=True)
51 | ## Test first parametrized conv
52 | res_fc = fc1(data)
53 | res_tfc = tfc[0](data)
54 | testing.assert_array_almost_equal(res_fc, res_tfc, decimal=2)
55 | ## Test second parametrized conv
56 | res_fc = fc2(data)
57 | res_tfc = tfc[1](data)
58 | testing.assert_array_almost_equal(res_fc, res_tfc, decimal=2)
--------------------------------------------------------------------------------
/tltorch/factorized_layers/tests/test_tensor_contraction_layers.py:
--------------------------------------------------------------------------------
1 | from ..tensor_contraction_layers import TCL
2 | import tensorly as tl
3 | from tensorly import testing
4 | tl.set_backend('pytorch')
5 |
6 |
7 | def test_tcl():
8 | random_state = 12345
9 | rng = tl.check_random_state(random_state)
10 | batch_size = 2
11 | in_shape = (4, 5, 6)
12 | out_shape = (2, 3, 5)
13 | data = tl.tensor(rng.random_sample((batch_size, ) + in_shape), dtype=tl.float32)
14 |
15 | expected_shape = (batch_size, ) + out_shape
16 | tcl = TCL(input_shape=in_shape, rank=out_shape, bias=False)
17 | res = tcl(data)
18 | testing.assert_(res.shape==expected_shape,
19 | msg=f'Wrong output size of TCL, expected {expected_shape} but got {res.shape}')
20 |
--------------------------------------------------------------------------------
/tltorch/factorized_layers/tests/test_trl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils import data
3 | from torch import nn
4 |
5 | from ..tensor_regression_layers import TRL
6 | from ... import FactorizedTensor
7 |
8 | import tensorly as tl
9 | tl.set_backend('pytorch')
10 | from tensorly import tenalg
11 | from tensorly import random
12 | from tensorly import testing
13 |
14 | import pytest
15 |
16 |
17 | def optimize_trl(trl, loader, lr=0.005, n_epoch=200, verbose=False):
18 | """Function that takes as input a TRL, dataset and optimizes the TRL
19 |
20 | Parameters
21 | ----------
22 | trl : tltorch.TRL
23 | loader : Pytorch dataset, returning batches (batch, labels)
24 | lr : float, default is 0.005
25 | learning rate
26 | n_epoch : int, default is 200
27 | verbose : bool, default is False
28 | level of verbosity
29 |
30 | Returns
31 | -------
32 | (trl, objective function loss)
33 | """
34 | optimizer = torch.optim.Adam(trl.parameters(), lr=lr, weight_decay=1e-7)
35 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=20, cooldown=20, factor=0.5, verbose=verbose)
36 |
37 | for epoch in range(n_epoch):
38 | for i, (sample_batch, label_batch) in enumerate(loader):
39 |
40 | # Important: do not forget to reset the gradients
41 | optimizer.zero_grad()
42 |
43 | # Reconstruct the tensor from the decomposed form
44 | pred = trl(sample_batch)
45 |
46 | # squared l2 loss
47 | loss = tl.norm(pred - label_batch, 2)
48 |
49 | loss.backward()
50 | optimizer.step()
51 | scheduler.step(loss)
52 |
53 | if not i or epoch % 10 == 0:
54 | if verbose:
55 | print(f"Epoch {epoch},. loss: {loss.item()}")
56 |
57 | return trl, loss.item()
58 |
59 |
60 | @pytest.mark.parametrize('factorization, true_rank, rank',
61 | [('Tucker', 0.5, 0.6), #(5, 5, 5, 4, 1, 3)),
62 | ('CP', 0.5, 0.6),
63 | ('TT', 0.5, 0.6)])
64 | def test_trl(factorization, true_rank, rank):
65 | """Test for the TRL
66 | """
67 | # Parameter of the experiment
68 | input_shape = (3, 4)
69 | output_shape = (2, 1)
70 | batch_size = 500
71 |
72 | # fix the random seed for reproducibility
73 | random_state = 12345
74 |
75 | rng = tl.check_random_state(random_state)
76 | tol = 0.08
77 |
78 | # Generate a random tensor
79 | samples = tl.tensor(rng.normal(size=(batch_size, *input_shape), loc=0, scale=1), dtype=tl.float32)
80 | true_bias = tl.tensor(rng.uniform(size=output_shape), dtype=tl.float32)
81 |
82 | with torch.no_grad():
83 | true_weight = FactorizedTensor.new(shape=input_shape+output_shape,
84 | rank=true_rank, factorization=factorization)
85 | true_weight = true_weight.normal_(0, 0.1).to_tensor()
86 | labels = tenalg.inner(samples, true_weight, n_modes=len(input_shape)) + true_bias
87 |
88 | dataset = data.TensorDataset(samples, labels)
89 | loader = data.DataLoader(dataset, batch_size=32)
90 |
91 | trl = TRL(input_shape=input_shape, output_shape=output_shape, factorization=factorization, rank=rank, bias=True)
92 | trl.weight.normal_(0, 0.1) # TODO: do this through reset_parameters
93 | with torch.no_grad():
94 | trl.bias.data.uniform_(-0.01, 0.01)
95 |
96 | print(f'Testing {trl.__class__.__name__}.')
97 | #trl.init_from_random(decompose_full_weight=True)
98 | trl, _ = optimize_trl(trl, loader, verbose=False)
99 |
100 | with torch.no_grad():
101 | rec_weights = trl.weight.to_tensor()
102 | rec_loss = tl.norm(rec_weights - true_weight)/tl.norm(true_weight)
103 |
104 | with torch.no_grad():
105 | bias_rec_loss = tl.norm(trl.bias - true_bias)/tl.norm(true_bias)
106 |
107 | testing.assert_(rec_loss <= tol, msg=f'Rec_loss of the weights={rec_loss} higher than tolerance={tol}')
108 | testing.assert_(bias_rec_loss <= tol, msg=f'Rec_loss of the bias={bias_rec_loss} higher than tolerance={tol}')
109 |
110 |
111 | @pytest.mark.parametrize('order', [2, 3])
112 | @pytest.mark.parametrize('project_input', [False, True])
113 | @pytest.mark.parametrize('learn_pool', [True, False])
114 | def test_TuckerTRL(order, project_input, learn_pool):
115 | """Test for Tucker TRL
116 |
117 | Here, we test specifically for init from fully-connected layer
118 | (both when learning the pooling and not).
119 |
120 | We also test that projecting the input or not doesn't change the results
121 | """
122 | in_features = 10
123 | out_features = 12
124 | batch_size = 2
125 | spatial_size = 4
126 | in_rank = 10
127 | out_rank = 12
128 | order= 2
129 |
130 | # fix the random seed for reproducibility and create random input
131 | random_state = 12345
132 | rng = tl.check_random_state(random_state)
133 | data = tl.tensor(rng.random_sample((batch_size, in_features) + (spatial_size, )*order), dtype=tl.float32)
134 |
135 | # Build a simple net with avg-pool, flatten + fully-connected
136 | if order == 2:
137 | pool = nn.AdaptiveAvgPool2d((1, 1))
138 | else:
139 | pool = nn.AdaptiveAvgPool3d((1, 1, 1))
140 | fc = nn.Linear(in_features, out_features, bias=False)
141 |
142 | def net(data):
143 | x = pool(data)
144 | x = x.squeeze()
145 | x = fc(x)
146 | return x
147 |
148 | res_fc = net(tl.copy(data))
149 |
150 | # A replacement TRL
151 | out_shape = (out_features, )
152 |
153 | if learn_pool:
154 | # Learn the average pool as part of the TRL
155 | in_shape = (in_features, ) + (spatial_size, )*order
156 | rank = (in_rank, ) + (1, )*order + (out_rank, )
157 | unsqueezed_modes = list(range(1, order+1))
158 | else:
159 | in_shape = (in_features, )
160 | rank = (in_rank, out_rank)
161 | unsqueezed_modes = None
162 | data = pool(data).squeeze()
163 |
164 | trl = TRL(in_shape, out_shape, rank=rank, factorization='tucker')
165 | trl.init_from_linear(fc, unsqueezed_modes=unsqueezed_modes)
166 | res_trl = trl(data)
167 |
168 | testing.assert_array_almost_equal(res_fc, res_trl)
169 |
170 |
171 | @pytest.mark.parametrize('factorization', ['CP', 'TT'])
172 | @pytest.mark.parametrize('bias', [True, False])
173 | def test_TRL_from_linear(factorization, bias):
174 | """Test for CP and TT TRL
175 |
176 | Here, we test specifically for init from fully-connected layer
177 | """
178 | in_features = 10
179 | out_features = 12
180 | batch_size = 2
181 |
182 | # fix the random seed for reproducibility and create random input
183 | random_state = 12345
184 | rng = tl.check_random_state(random_state)
185 | data = tl.tensor(rng.random_sample((batch_size, in_features)), dtype=tl.float32)
186 | fc = nn.Linear(in_features, out_features, bias=bias)
187 | res_fc = fc(tl.copy(data))
188 | trl = TRL((in_features, ), (out_features, ), rank=10, bias=bias, factorization=factorization)
189 | trl.init_from_linear(fc)
190 | res_trl = trl(data)
191 |
192 | testing.assert_array_almost_equal(res_fc, res_trl, decimal=2)
193 |
194 |
--------------------------------------------------------------------------------
/tltorch/factorized_tensors/__init__.py:
--------------------------------------------------------------------------------
1 | from .factorized_tensors import (CPTensor, TuckerTensor, TTTensor,
2 | DenseTensor, FactorizedTensor)
3 | from .tensorized_matrices import (TensorizedTensor, CPTensorized, BlockTT,
4 | DenseTensorized, TuckerTensorized)
5 | from .complex_factorized_tensors import (ComplexCPTensor, ComplexTuckerTensor,
6 | ComplexTTTensor, ComplexDenseTensor)
7 | from .complex_tensorized_matrices import (ComplexCPTensorized, ComplexBlockTT,
8 | ComplexDenseTensorized, ComplexTuckerTensorized)
9 | from .init import tensor_init, cp_init, tucker_init, tt_init, block_tt_init
10 |
--------------------------------------------------------------------------------
/tltorch/factorized_tensors/complex_factorized_tensors.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch import nn
4 |
5 | import tensorly as tl
6 | tl.set_backend('pytorch')
7 | from tltorch.factorized_tensors.factorized_tensors import TuckerTensor, CPTensor, TTTensor, DenseTensor
8 | from tltorch.utils.parameter_list import FactorList, ComplexFactorList
9 |
10 |
11 | # Author: Jean Kossaifi
12 | # License: BSD 3 clause
13 |
14 | class ComplexHandler():
15 | def __setattr__(self, key, value):
16 | if isinstance(value, (FactorList)):
17 | value = ComplexFactorList(value)
18 |             super().__setattr__(key, value)
19 |
20 | elif isinstance(value, nn.Parameter):
21 | self.register_parameter(key, value)
22 | elif torch.is_tensor(value):
23 | self.register_buffer(key, value)
24 | else:
25 | super().__setattr__(key, value)
26 |
27 | def __getattr__(self, key):
28 | value = super().__getattr__(key)
29 | if torch.is_tensor(value):
30 | value = torch.view_as_complex(value)
31 | return value
32 |
33 | def register_parameter(self, key, value):
34 | value = nn.Parameter(torch.view_as_real(value))
35 | super().register_parameter(key, value)
36 |
37 | def register_buffer(self, key, value):
38 | value = torch.view_as_real(value)
39 | super().register_buffer(key, value)
40 |
41 |
42 | class ComplexDenseTensor(ComplexHandler, DenseTensor, name='ComplexDense'):
43 | """Complex Dense Factorization
44 | """
45 | @classmethod
46 | def new(cls, shape, rank=None, device=None, dtype=torch.cfloat, **kwargs):
47 | return super().new(shape, rank, device=device, dtype=dtype, **kwargs)
48 |
49 | class ComplexTuckerTensor(ComplexHandler, TuckerTensor, name='ComplexTucker'):
50 | """Complex Tucker Factorization
51 | """
52 | @classmethod
53 | def new(cls, shape, rank='same', fixed_rank_modes=None,
54 | device=None, dtype=torch.cfloat, **kwargs):
55 | return super().new(shape, rank, fixed_rank_modes=fixed_rank_modes,
56 | device=device, dtype=dtype, **kwargs)
57 |
58 | class ComplexTTTensor(ComplexHandler, TTTensor, name='ComplexTT'):
59 | """Complex TT Factorization
60 | """
61 | @classmethod
62 | def new(cls, shape, rank='same', fixed_rank_modes=None,
63 | device=None, dtype=torch.cfloat, **kwargs):
64 | return super().new(shape, rank,
65 | device=device, dtype=dtype, **kwargs)
66 |
67 | class ComplexCPTensor(ComplexHandler, CPTensor, name='ComplexCP'):
68 | """Complex CP Factorization
69 | """
70 | @classmethod
71 | def new(cls, shape, rank='same', fixed_rank_modes=None,
72 | device=None, dtype=torch.cfloat, **kwargs):
73 | return super().new(shape, rank,
74 | device=device, dtype=dtype, **kwargs)
75 |
--------------------------------------------------------------------------------
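Usage sketch (not part of the repository): a minimal illustration of the ComplexHandler behaviour defined above. Parameters are registered as real tensors through view_as_real, while attribute access returns complex views. The shape and rank below are arbitrary.

import torch
from tltorch.factorized_tensors import FactorizedTensor

# Create a complex Tucker factorization (dispatches to ComplexTuckerTensor)
fact = FactorizedTensor.new((4, 3, 2), rank='same', factorization='ComplexTucker')
fact.normal_()

# Attribute access goes through ComplexHandler.__getattr__ and returns a complex view
assert fact.core.dtype == torch.cfloat

# ...but the registered parameters are real (real/imag stacked on the last dimension)
assert all(p.dtype == torch.float32 for p in fact.parameters())

# The reconstruction is a complex tensor of the requested shape
assert fact.to_tensor().shape == (4, 3, 2)
assert fact.to_tensor().dtype == torch.cfloat
--------------------------------------------------------------------------------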
/tltorch/factorized_tensors/complex_tensorized_matrices.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tensorly as tl
3 | tl.set_backend('pytorch')
4 | from tltorch.factorized_tensors.tensorized_matrices import TuckerTensorized, DenseTensorized, CPTensorized, BlockTT
5 | from .complex_factorized_tensors import ComplexHandler
6 |
7 | # Author: Jean Kossaifi
8 | # License: BSD 3 clause
9 |
10 |
11 | class ComplexDenseTensorized(ComplexHandler, DenseTensorized, name='ComplexDense'):
12 | """Complex DenseTensorized Factorization
13 | """
14 | _complex_params = ['tensor']
15 |
16 | @classmethod
17 | def new(cls, tensorized_shape, rank=None, device=None, dtype=torch.cfloat, **kwargs):
18 | return super().new(tensorized_shape, rank, device=device, dtype=dtype, **kwargs)
19 |
20 | class ComplexTuckerTensorized(ComplexHandler, TuckerTensorized, name='ComplexTucker'):
21 | """Complex TuckerTensorized Factorization
22 | """
23 | _complex_params = ['core', 'factors']
24 |
25 | @classmethod
26 | def new(cls, tensorized_shape, rank=None, device=None, dtype=torch.cfloat, **kwargs):
27 | return super().new(tensorized_shape, rank, device=device, dtype=dtype, **kwargs)
28 |
29 | class ComplexBlockTT(ComplexHandler, BlockTT, name='ComplexTT'):
30 | """Complex BlockTT Factorization
31 | """
32 | _complex_params = ['factors']
33 |
34 | @classmethod
35 | def new(cls, tensorized_shape, rank=None, device=None, dtype=torch.cfloat, **kwargs):
36 | return super().new(tensorized_shape, rank, device=device, dtype=dtype, **kwargs)
37 |
38 | class ComplexCPTensorized(ComplexHandler, CPTensorized, name='ComplexCP'):
39 | """Complex Tensorized CP Factorization
40 | """
41 | _complex_params = ['weights', 'factors']
42 |
43 | @classmethod
44 | def new(cls, tensorized_shape, rank=None, device=None, dtype=torch.cfloat, **kwargs):
45 | return super().new(tensorized_shape, rank, device=device, dtype=dtype, **kwargs)
46 |
--------------------------------------------------------------------------------
/tltorch/factorized_tensors/factorized_tensors.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 |
7 | import tensorly as tl
8 | tl.set_backend('pytorch')
9 | from tensorly import tenalg
10 | from tensorly.decomposition import parafac, tucker, tensor_train
11 |
12 | from .core import FactorizedTensor
13 | from ..utils import FactorList
14 |
15 |
16 | # Author: Jean Kossaifi
17 | # License: BSD 3 clause
18 |
19 |
20 | class DenseTensor(FactorizedTensor, name='Dense'):
21 | """Dense tensor
22 | """
23 | def __init__(self, tensor, shape=None, rank=None):
24 | super().__init__()
25 | if shape is not None and rank is not None:
26 | self.shape, self.rank = shape, rank
27 | else:
28 | self.shape = tensor.shape
29 | self.rank = None
30 | self.order = len(self.shape)
31 |
32 | if isinstance(tensor, nn.Parameter):
33 | self.register_parameter('tensor', tensor)
34 | else:
35 | self.register_buffer('tensor', tensor)
36 |
37 | @classmethod
38 | def new(cls, shape, rank=None, device=None, dtype=None, **kwargs):
39 | # Register the parameters
40 | tensor = nn.Parameter(torch.empty(shape, device=device, dtype=dtype))
41 |
42 | return cls(tensor)
43 |
44 | @classmethod
45 | def from_tensor(cls, tensor, rank='same', **kwargs):
46 | return cls(nn.Parameter(tl.copy(tensor)))
47 |
48 | def init_from_tensor(self, tensor, l2_reg=1e-5, **kwargs):
49 | with torch.no_grad():
50 | self.tensor = nn.Parameter(tl.copy(tensor))
51 | return self
52 |
53 | @property
54 | def decomposition(self):
55 | return self.tensor
56 |
57 | def to_tensor(self):
58 | return self.tensor
59 |
60 | def normal_(self, mean=0, std=1):
61 | with torch.no_grad():
62 | self.tensor.data.normal_(mean, std)
63 | return self
64 |
65 | def __getitem__(self, indices):
66 | return self.__class__(self.tensor[indices])
67 |
68 |
69 | class CPTensor(FactorizedTensor, name='CP'):
70 | """CP Factorization
71 |
72 | Parameters
73 | ----------
74 | weights
75 | factors
76 | shape
77 | rank
78 | """
79 | def __init__(self, weights, factors, shape=None, rank=None):
80 | super().__init__()
81 | if shape is not None and rank is not None:
82 | self.shape, self.rank = shape, rank
83 | else:
84 | self.shape, self.rank = tl.cp_tensor._validate_cp_tensor((weights, factors))
85 | self.order = len(self.shape)
86 |
87 | # self.weights = weights
88 | if isinstance(weights, nn.Parameter):
89 | self.register_parameter('weights', weights)
90 | else:
91 | self.register_buffer('weights', weights)
92 |
93 | self.factors = FactorList(factors)
94 |
95 | @classmethod
96 | def new(cls, shape, rank, device=None, dtype=None, **kwargs):
97 | rank = tl.cp_tensor.validate_cp_rank(shape, rank)
98 |
99 | # Register the parameters
100 | weights = nn.Parameter(torch.empty(rank, device=device, dtype=dtype))
101 | # Avoid the issues with ParameterList
102 | factors = [nn.Parameter(torch.empty((s, rank), device=device, dtype=dtype)) for s in shape]
103 |
104 | return cls(weights, factors)
105 |
106 | @classmethod
107 | def from_tensor(cls, tensor, rank='same', **kwargs):
108 | shape = tensor.shape
109 | rank = tl.cp_tensor.validate_cp_rank(shape, rank)
110 | dtype = tensor.dtype
111 |
112 | with torch.no_grad():
113 | weights, factors = parafac(tensor.to(torch.float64), rank, **kwargs)
114 |
115 | return cls(nn.Parameter(weights.to(dtype).contiguous()), [nn.Parameter(f.to(dtype).contiguous()) for f in factors])
116 |
117 | def init_from_tensor(self, tensor, l2_reg=1e-5, **kwargs):
118 | with torch.no_grad():
119 | weights, factors = parafac(tensor, self.rank, l2_reg=l2_reg, **kwargs)
120 |
121 | self.weights = nn.Parameter(weights.contiguous())
122 | self.factors = FactorList([nn.Parameter(f.contiguous()) for f in factors])
123 | return self
124 |
125 | @property
126 | def decomposition(self):
127 | return self.weights, self.factors
128 |
129 | def to_tensor(self):
130 | return tl.cp_to_tensor(self.decomposition)
131 |
132 | def normal_(self, mean=0, std=1):
133 | super().normal_(mean, std)
134 | std_factors = (std/math.sqrt(self.rank))**(1/self.order)
135 |
136 | with torch.no_grad():
137 | self.weights.fill_(1)
138 | for factor in self.factors:
139 | factor.data.normal_(0, std_factors)
140 | return self
141 |
142 | def __getitem__(self, indices):
143 | if isinstance(indices, int):
144 | # Select one dimension of one mode
145 | mixing_factor, *factors = self.factors
146 | weights = self.weights*mixing_factor[indices, :]
147 | return self.__class__(weights, factors)
148 |
149 | elif isinstance(indices, slice):
150 | # Index part of a factor
151 | mixing_factor, *factors = self.factors
152 | factors = [mixing_factor[indices, :], *factors]
153 | weights = self.weights
154 | return self.__class__(weights, factors)
155 |
156 | else:
157 | # Index multiple dimensions
158 | factors = self.factors
159 | index_factors = []
160 | weights = self.weights
161 | for index in indices:
162 | if index is Ellipsis:
163 | raise ValueError(f'Ellipsis is not yet supported, yet got indices={indices} which contains one.')
164 |
165 | mixing_factor, *factors = factors
166 | if isinstance(index, (np.integer, int)):
167 | if factors or index_factors:
168 | weights = weights*mixing_factor[index, :]
169 | else:
170 | # No factors left
171 | return tl.sum(weights*mixing_factor[index, :])
172 | else:
173 | index_factors.append(mixing_factor[index, :])
174 |
175 | return self.__class__(weights, index_factors + factors)
176 | # return self.__class__(*tl.cp_indexing(self.weights, self.factors, indices))
177 |
178 | def transduct(self, new_dim, mode=0, new_factor=None):
179 | """Transduction adds a new dimension to the existing factorization
180 |
181 | Parameters
182 | ----------
183 | new_dim : int
184 | dimension of the new mode to add
185 |         mode : int, default is 0
186 |             where to insert the new dimension, after the channels; by default, the new dimension is inserted before the existing ones
187 | (e.g. add time before height and width)
188 |
189 | Returns
190 | -------
191 | self
192 | """
193 | factors = self.factors
194 | # Important: don't increment the order before accessing factors which uses order!
195 | self.order += 1
196 | self.shape = self.shape[:mode] + (new_dim,) + self.shape[mode:]
197 |
198 | if new_factor is None:
199 | new_factor = torch.ones(new_dim, self.rank)#/new_dim
200 |
201 | factors.insert(mode, nn.Parameter(new_factor.to(factors[0].device).contiguous()))
202 | self.factors = FactorList(factors)
203 |
204 | return self
205 |
206 |
207 | class TuckerTensor(FactorizedTensor, name='Tucker'):
208 | """Tucker Factorization
209 |
210 | Parameters
211 | ----------
212 | core
213 | factors
214 | shape
215 | rank
216 | """
217 | def __init__(self, core, factors, shape=None, rank=None):
218 | super().__init__()
219 | if shape is not None and rank is not None:
220 | self.shape, self.rank = shape, rank
221 | else:
222 | self.shape, self.rank = tl.tucker_tensor._validate_tucker_tensor((core, factors))
223 |
224 | self.order = len(self.shape)
225 | # self.core = core
226 | if isinstance(core, nn.Parameter):
227 | self.register_parameter('core', core)
228 | else:
229 | self.register_buffer('core', core)
230 |
231 | self.factors = FactorList(factors)
232 |
233 | @classmethod
234 | def new(cls, shape, rank, fixed_rank_modes=None,
235 | device=None, dtype=None, **kwargs):
236 | rank = tl.tucker_tensor.validate_tucker_rank(shape, rank, fixed_modes=fixed_rank_modes)
237 |
238 | # Register the parameters
239 | core = nn.Parameter(torch.empty(rank, device=device, dtype=dtype))
240 | # Avoid the issues with ParameterList
241 | factors = [nn.Parameter(torch.empty((s, r), device=device, dtype=dtype)) for (s, r) in zip(shape, rank)]
242 |
243 | return cls(core, factors)
244 |
245 | @classmethod
246 | def from_tensor(cls, tensor, rank='same', fixed_rank_modes=None, **kwargs):
247 | shape = tensor.shape
248 | rank = tl.tucker_tensor.validate_tucker_rank(shape, rank, fixed_modes=fixed_rank_modes)
249 |
250 | with torch.no_grad():
251 | core, factors = tucker(tensor, rank, **kwargs)
252 |
253 | return cls(nn.Parameter(core.contiguous()), [nn.Parameter(f.contiguous()) for f in factors])
254 |
255 | def init_from_tensor(self, tensor, unsqueezed_modes=None, unsqueezed_init='average', **kwargs):
256 | """Initialize the tensor factorization from a tensor
257 |
258 | Parameters
259 | ----------
260 | tensor : torch.Tensor
261 | full tensor to decompose
262 | unsqueezed_modes : int list
263 | list of modes for which the rank is 1 that don't correspond to a mode in the full tensor
264 | essentially we are adding a new dimension for which the core has dim 1,
265 | and that is not initialized through decomposition.
266 | Instead first `tensor` is decomposed into the other factors.
267 | The `unsqueezed factors` are then added and initialized e.g. with 1/dim[i]
268 | unsqueezed_init : 'average' or float
269 | if unsqueezed_modes, this is how the added "unsqueezed" factors will be initialized
270 | if 'average', then unsqueezed_factor[i] will have value 1/tensor.shape[i]
271 | """
272 | if unsqueezed_modes is not None:
273 | unsqueezed_modes = sorted(unsqueezed_modes)
274 | for mode in unsqueezed_modes[::-1]:
275 | if self.rank[mode] != 1:
276 |                     msg = 'It is only possible to initialize by averaging over modes for which rank=1. '
277 |                     msg += f'However, got unsqueezed_modes={unsqueezed_modes} but rank[{mode}]={self.rank[mode]} != 1.'
278 | raise ValueError(msg)
279 |
280 | rank = tuple(r for (i, r) in enumerate(self.rank) if i not in unsqueezed_modes)
281 | else:
282 | rank = self.rank
283 |
284 | with torch.no_grad():
285 | core, factors = tucker(tensor, rank, **kwargs)
286 |
287 | if unsqueezed_modes is not None:
288 | # Initialise with 1/shape[mode] or given value
289 | for mode in unsqueezed_modes:
290 | size = self.shape[mode]
291 | factor = torch.ones(size, 1)
292 | if unsqueezed_init == 'average':
293 | factor /= size
294 | else:
295 | factor *= unsqueezed_init
296 | factors.insert(mode, factor)
297 | core = core.unsqueeze(mode)
298 |
299 | self.core = nn.Parameter(core.contiguous())
300 | self.factors = FactorList([nn.Parameter(f.contiguous()) for f in factors])
301 | return self
302 |
303 | @property
304 | def decomposition(self):
305 | return self.core, self.factors
306 |
307 | def to_tensor(self):
308 | return tl.tucker_to_tensor(self.decomposition)
309 |
310 | def normal_(self, mean=0, std=1):
311 | if mean != 0:
312 | raise ValueError(f'Currently only mean=0 is supported, but got mean={mean}')
313 |
314 | r = np.prod([math.sqrt(r) for r in self.rank])
315 | std_factors = (std/r)**(1/(self.order+1))
316 |
317 | with torch.no_grad():
318 | self.core.data.normal_(0, std_factors)
319 | for factor in self.factors:
320 | factor.data.normal_(0, std_factors)
321 | return self
322 |
323 | def __getitem__(self, indices):
324 | if isinstance(indices, int):
325 | # Select one dimension of one mode
326 | mixing_factor, *factors = self.factors
327 | core = tenalg.mode_dot(self.core, mixing_factor[indices, :], 0)
328 | return self.__class__(core, factors)
329 |
330 | elif isinstance(indices, slice):
331 | mixing_factor, *factors = self.factors
332 | factors = [mixing_factor[indices, :], *factors]
333 | return self.__class__(self.core, factors)
334 |
335 | else:
336 | # Index multiple dimensions
337 | modes = []
338 | factors = []
339 | factors_contract = []
340 | for i, (index, factor) in enumerate(zip(indices, self.factors)):
341 | if index is Ellipsis:
342 | raise ValueError(f'Ellipsis is not yet supported, yet got indices={indices}, indices[{i}]={index}.')
343 | if isinstance(index, int):
344 | modes.append(i)
345 | factors_contract.append(factor[index, :])
346 | else:
347 | factors.append(factor[index, :])
348 |
349 | if modes:
350 | core = tenalg.multi_mode_dot(self.core, factors_contract, modes=modes)
351 | else:
352 | core = self.core
353 | factors = factors + self.factors[i+1:]
354 |
355 | if factors:
356 | return self.__class__(core, factors)
357 |
358 | # Fully contracted tensor
359 | return core
360 |
361 |
362 | class TTTensor(FactorizedTensor, name='TT'):
363 | """Tensor-Train (Matrix-Product-State) Factorization
364 |
365 | Parameters
366 | ----------
367 | factors
368 | shape
369 | rank
370 | """
371 | def __init__(self, factors, shape=None, rank=None):
372 | super().__init__()
373 | if shape is None or rank is None:
374 | self.shape, self.rank = tl.tt_tensor._validate_tt_tensor(factors)
375 | else:
376 | self.shape, self.rank = shape, rank
377 |
378 | self.order = len(self.shape)
379 | self.factors = FactorList(factors)
380 |
381 | @classmethod
382 | def new(cls, shape, rank, device=None, dtype=None, **kwargs):
383 | rank = tl.tt_tensor.validate_tt_rank(shape, rank)
384 |
385 | # Avoid the issues with ParameterList
386 | factors = [nn.Parameter(torch.empty((rank[i], s, rank[i+1]), device=device, dtype=dtype)) for i, s in enumerate(shape)]
387 |
388 | return cls(factors)
389 |
390 | @classmethod
391 | def from_tensor(cls, tensor, rank='same', **kwargs):
392 | shape = tensor.shape
393 | rank = tl.tt_tensor.validate_tt_rank(shape, rank)
394 |
395 | with torch.no_grad():
396 | # TODO: deal properly with wrong kwargs
397 | factors = tensor_train(tensor, rank)
398 |
399 | return cls([nn.Parameter(f.contiguous()) for f in factors])
400 |
401 | def init_from_tensor(self, tensor, **kwargs):
402 | with torch.no_grad():
403 | # TODO: deal properly with wrong kwargs
404 | factors = tensor_train(tensor, self.rank)
405 |
406 | self.factors = FactorList([nn.Parameter(f.contiguous()) for f in factors])
407 | self.rank = tuple([f.shape[0] for f in factors] + [1])
408 | return self
409 |
410 | @property
411 | def decomposition(self):
412 | return self.factors
413 |
414 | def to_tensor(self):
415 | return tl.tt_to_tensor(self.decomposition)
416 |
417 | def normal_(self, mean=0, std=1):
418 | if mean != 0:
419 | raise ValueError(f'Currently only mean=0 is supported, but got mean={mean}')
420 |
421 | r = np.prod(self.rank)
422 | std_factors = (std/r)**(1/self.order)
423 | with torch.no_grad():
424 | for factor in self.factors:
425 | factor.data.normal_(0, std_factors)
426 | return self
427 |
428 | def __getitem__(self, indices):
429 | if isinstance(indices, int):
430 | # Select one dimension of one mode
431 | factor, next_factor, *factors = self.factors
432 | next_factor = tenalg.mode_dot(next_factor, factor[:, indices, :].squeeze(1), 0)
433 | return self.__class__([next_factor, *factors])
434 |
435 | elif isinstance(indices, slice):
436 | mixing_factor, *factors = self.factors
437 | factors = [mixing_factor[:, indices], *factors]
438 | return self.__class__(factors)
439 |
440 | else:
441 | factors = []
442 | all_contracted = True
443 | for i, index in enumerate(indices):
444 | if index is Ellipsis:
445 | raise ValueError(f'Ellipsis is not yet supported, yet got indices={indices}, indices[{i}]={index}.')
446 | if isinstance(index, int):
447 | if i:
448 | factor = tenalg.mode_dot(factor, self.factors[i][:, index, :].T, -1)
449 | else:
450 | factor = self.factors[i][:, index, :]
451 | else:
452 | if i:
453 | if all_contracted:
454 | factor = tenalg.mode_dot(self.factors[i][:, index, :], factor, 0)
455 | else:
456 | factors.append(factor)
457 | factor = self.factors[i][:, index, :]
458 | else:
459 | factor = self.factors[i][:, index, :]
460 | all_contracted = False
461 |
462 | if factor.ndim == 2: # We have contracted all cores, so have a 2D matrix
463 | if self.order == (i+1):
464 | # No factors left
465 | return factor.squeeze()
466 | else:
467 | next_factor, *factors = self.factors[i+1:]
468 | factor = tenalg.mode_dot(next_factor, factor, 0)
469 | return self.__class__([factor, *factors])
470 | else:
471 | return self.__class__([*factors, factor, *self.factors[i+1:]])
472 |
473 | def transduct(self, new_dim, mode=0, new_factor=None):
474 | """Transduction adds a new dimension to the existing factorization
475 |
476 | Parameters
477 | ----------
478 | new_dim : int
479 | dimension of the new mode to add
480 |         mode : int, default is 0
481 |             where to insert the new dimension, after the channels; by default, the new dimension is inserted before the existing ones
482 | (e.g. add time before height and width)
483 |
484 | Returns
485 | -------
486 | self
487 | """
488 | factors = self.factors
489 |
490 | # Important: don't increment the order before accessing factors which uses order!
491 | self.order += 1
492 | new_rank = self.rank[mode]
493 | self.rank = self.rank[:mode] + (new_rank, ) + self.rank[mode:]
494 | self.shape = self.shape[:mode] + (new_dim, ) + self.shape[mode:]
495 |
496 | # Init so the reconstruction is equivalent to concatenating the previous self new_dim times
497 | if new_factor is None:
498 | new_factor = torch.zeros(new_rank, new_dim, new_rank)
499 | for i in range(new_dim):
500 | new_factor[:, i, :] = torch.eye(new_rank)#/new_dim
501 |             # Below: <=> static prediction
502 | # new_factor[:, new_dim//2, :] = torch.eye(new_rank)
503 |
504 | factors.insert(mode, nn.Parameter(new_factor.to(factors[0].device).contiguous()))
505 | self.factors = FactorList(factors)
506 |
507 | return self
508 |
--------------------------------------------------------------------------------
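Usage sketch (not part of the repository): a short example of the factorized tensor classes above, using the CP case. Indexing and transduction happen directly in factorized form; the shapes and ranks below are arbitrary.

import torch
import tensorly as tl
tl.set_backend('pytorch')
from tltorch.factorized_tensors import FactorizedTensor

# Create and initialize a CP-factorized tensor
fact = FactorizedTensor.new((5, 4, 3), rank='same', factorization='CP')
fact.normal_(0, 0.02)              # factors scaled so the reconstruction has std ~0.02
full = fact.to_tensor()            # dense reconstruction, shape (5, 4, 3)

# Indexing operates on the factors, without reconstructing the full tensor
sub = fact[1:3, :, 2]              # still a CPTensor
torch.testing.assert_close(sub.to_tensor(), full[1:3, :, 2])

# Transduction inserts a new mode of size 2 at position 0, reusing the existing factors
fact = fact.transduct(2, mode=0)
print(fact.shape)                  # (2, 5, 4, 3)
--------------------------------------------------------------------------------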
/tltorch/factorized_tensors/init.py:
--------------------------------------------------------------------------------
1 | """Module for initializing tensor decompositions
2 | """
3 |
4 | # Author: Jean Kossaifi
5 | # License: BSD 3 clause
6 |
7 | import torch
8 | import math
9 | import numpy as np
10 |
11 | import tensorly as tl
12 | tl.set_backend('pytorch')
13 |
14 | def tensor_init(tensor, std=0.02):
15 | """Initializes directly the parameters of a factorized tensor so the reconstruction has the specified standard deviation and 0 mean
16 |
17 | Parameters
18 | ----------
19 | tensor : torch.Tensor or FactorizedTensor
20 | std : float, default is 0.02
21 | the desired standard deviation of the full (reconstructed) tensor
22 | """
23 | from .factorized_tensors import FactorizedTensor
24 |
25 | if isinstance(tensor, FactorizedTensor):
26 | tensor.normal_(0, std)
27 | elif torch.is_tensor(tensor):
28 | tensor.normal_(0, std)
29 | else:
30 |         raise ValueError(f'Got tensor of class {tensor.__class__.__name__} but expected torch.Tensor or FactorizedTensor.')
31 |
32 |
33 | def cp_init(cp_tensor, std=0.02):
34 | """Initializes directly the weights and factors of a CP decomposition so the reconstruction has the specified std and 0 mean
35 |
36 | Parameters
37 | ----------
38 | cp_tensor : CPTensor
39 | std : float, default is 0.02
40 | the desired standard deviation of the full (reconstructed) tensor
41 |
42 | Notes
43 | -----
44 | We assume the given (weights, factors) form a correct CP decomposition, no checks are done here.
45 | """
46 | rank = cp_tensor.rank # We assume we are given a valid CP
47 |     order = cp_tensor.order
48 | std_factors = (std/math.sqrt(rank))**(1/order)
49 |
50 | with torch.no_grad():
51 | cp_tensor.weights.fill_(1)
52 | for factor in cp_tensor.factors:
53 | factor.normal_(0, std_factors)
54 | return cp_tensor
55 |
56 | def tucker_init(tucker_tensor, std=0.02):
57 | """Initializes directly the weights and factors of a Tucker decomposition so the reconstruction has the specified std and 0 mean
58 |
59 | Parameters
60 | ----------
61 | tucker_tensor : TuckerTensor
62 | std : float, default is 0.02
63 | the desired standard deviation of the full (reconstructed) tensor
64 |
65 | Notes
66 | -----
67 | We assume the given (core, factors) form a correct Tucker decomposition, no checks are done here.
68 | """
69 | order = tucker_tensor.order
70 | rank = tucker_tensor.rank
71 | r = np.prod([math.sqrt(r) for r in rank])
72 | std_factors = (std/r)**(1/(order+1))
73 | with torch.no_grad():
74 | tucker_tensor.core.normal_(0, std_factors)
75 | for factor in tucker_tensor.factors:
76 | factor.normal_(0, std_factors)
77 | return tucker_tensor
78 |
79 | def tt_init(tt_tensor, std=0.02):
80 | """Initializes directly the weights and factors of a TT decomposition so the reconstruction has the specified std and 0 mean
81 |
82 | Parameters
83 | ----------
84 | tt_tensor : TTTensor
85 | std : float, default is 0.02
86 | the desired standard deviation of the full (reconstructed) tensor
87 |
88 | Notes
89 | -----
90 | We assume the given factors form a correct TT decomposition, no checks are done here.
91 | """
92 | order = tt_tensor.order
93 | r = np.prod(tt_tensor.rank)
94 | std_factors = (std/r)**(1/order)
95 | with torch.no_grad():
96 | for factor in tt_tensor.factors:
97 | factor.normal_(0, std_factors)
98 | return tt_tensor
99 |
100 |
101 | def block_tt_init(block_tt, std=0.02):
102 | """Initializes directly the weights and factors of a BlockTT decomposition so the reconstruction has the specified std and 0 mean
103 |
104 | Parameters
105 | ----------
106 | block_tt : Matrix in the tensor-train format
107 | std : float, default is 0.02
108 | the desired standard deviation of the full (reconstructed) tensor
109 |
110 | Notes
111 | -----
112 | We assume the given factors form a correct Block-TT decomposition, no checks are done here.
113 | """
114 | return tt_init(block_tt, std=std)
--------------------------------------------------------------------------------
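Usage sketch (not part of the repository): how the helpers above are meant to be used. The factors are scaled so that the reconstructed tensor has approximately the requested standard deviation; the shape and std below are arbitrary.

import tensorly as tl
tl.set_backend('pytorch')
from tltorch.factorized_tensors import FactorizedTensor, tensor_init

fact = FactorizedTensor.new((16, 16, 16), rank='same', factorization='Tucker')
tensor_init(fact, std=0.02)   # calls fact.normal_(0, 0.02), same scaling as tucker_init above

rec = fact.to_tensor()
print(float(rec.mean()), float(rec.std()))   # ~0 and ~0.02, up to sampling noise
--------------------------------------------------------------------------------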
/tltorch/factorized_tensors/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/tltorch/factorized_tensors/tests/__init__.py
--------------------------------------------------------------------------------
/tltorch/factorized_tensors/tests/test_factorizations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import math
4 | import torch
5 | from torch import testing
6 |
7 | from tltorch.factorized_tensors.tensorized_matrices import CPTensorized, TuckerTensorized, BlockTT
8 | from tltorch.factorized_tensors.core import TensorizedTensor
9 |
10 | from ..factorized_tensors import FactorizedTensor, CPTensor, TuckerTensor, TTTensor
11 |
12 |
13 | @pytest.mark.parametrize('factorization', ['CP', 'Tucker', 'TT'])
14 | def test_FactorizedTensor(factorization):
15 | """Test for FactorizedTensor"""
16 | shape = (4, 3, 2, 5)
17 | fact_tensor = FactorizedTensor.new(shape=shape, rank='same', factorization=factorization)
18 | fact_tensor.normal_()
19 |
20 | # Check that the correct type of factorized tensor is created
21 | assert fact_tensor._name.lower() == factorization.lower()
22 | mapping = dict(CP=CPTensor, Tucker=TuckerTensor, TT=TTTensor)
23 | assert isinstance(fact_tensor, mapping[factorization])
24 |
25 | # Check the shape of the factorized tensor and reconstruction
26 | reconstruction = fact_tensor.to_tensor()
27 | assert fact_tensor.shape == reconstruction.shape == shape
28 |
29 | # Check that indexing the factorized tensor returns the correct result
30 | # np_s converts intuitive array indexing to proper tuples
31 | indices = [
32 | np.s_[:, 2, :], # = (slice(None), 2, slice(None))
33 | np.s_[2:, :2, :, 0],
34 | np.s_[0, 2, 1, 3],
35 | np.s_[-1, :, :, ::2]
36 | ]
37 | for idx in indices:
38 | assert reconstruction[idx].shape == fact_tensor[idx].shape
39 | res = fact_tensor[idx]
40 | if not torch.is_tensor(res):
41 | res = res.to_tensor()
42 | testing.assert_close(reconstruction[idx], res)
43 |
44 |
45 | @pytest.mark.parametrize('factorization', ['BlockTT', 'CP']) #['CP', 'Tucker', 'BlockTT'])
46 | @pytest.mark.parametrize('batch_size', [(), (4,)])
47 | def test_TensorizedMatrix(factorization, batch_size):
48 | """Test for TensorizedMatrix"""
49 | row_tensor_shape = (4, 3, 2)
50 | column_tensor_shape = (5, 3, 2)
51 | row_shape = math.prod(row_tensor_shape)
52 | column_shape = math.prod(column_tensor_shape)
53 | tensor_shape = batch_size + (row_tensor_shape, column_tensor_shape)
54 |
55 | fact_tensor = TensorizedTensor.new(tensor_shape,
56 | rank=0.5, factorization=factorization)
57 | fact_tensor.normal_()
58 |
59 | # Check that the correct type of factorized tensor is created
60 | assert fact_tensor._name.lower() == factorization.lower()
61 | mapping = dict(CP=CPTensorized, Tucker=TuckerTensorized, BlockTT=BlockTT)
62 | assert isinstance(fact_tensor, mapping[factorization])
63 |
64 | # Check that the matrix has the right shape
65 | reconstruction = fact_tensor.to_matrix()
66 | if batch_size:
67 | assert fact_tensor.shape[1] == row_shape == reconstruction.shape[1]
68 | assert fact_tensor.shape[2] == column_shape == reconstruction.shape[2]
69 | assert fact_tensor.ndim == 3
70 | else:
71 | assert fact_tensor.shape[0] == row_shape == reconstruction.shape[0]
72 | assert fact_tensor.shape[1] == column_shape == reconstruction.shape[1]
73 | assert fact_tensor.ndim == 2
74 |
75 | # Check that indexing the factorized tensor returns the correct result
76 | # np_s converts intuitive array indexing to proper tuples
77 | indices = [
78 | np.s_[:, :], # = (slice(None), slice(None))
79 | np.s_[:, 2],
80 | np.s_[2, 3],
81 | np.s_[1, :]
82 | ]
83 | for idx in indices:
84 | assert tuple(reconstruction[idx].shape) == tuple(fact_tensor[idx].shape)
85 | res = fact_tensor[idx]
86 | if not torch.is_tensor(res):
87 | res = res.to_matrix()
88 | testing.assert_close(reconstruction[idx], res)
89 |
90 |
91 | @pytest.mark.parametrize('factorization', ['CP', 'TT'])
92 | def test_transduction(factorization):
93 | """Test for transduction"""
94 | shape = (3, 4, 5)
95 | new_dim = 2
96 | for mode in range(3):
97 | fact_tensor = FactorizedTensor.new(shape=shape, rank=6, factorization=factorization)
98 | fact_tensor.normal_()
99 | original_rec = fact_tensor.to_tensor()
100 | fact_tensor = fact_tensor.transduct(new_dim, mode=mode)
101 | rec = fact_tensor.to_tensor()
102 | true_shape = list(shape); true_shape.insert(mode, new_dim)
103 | assert tuple(fact_tensor.shape) == tuple(rec.shape) == tuple(true_shape)
104 |
105 | indices = [slice(None)]*mode
106 | for i in range(new_dim):
107 | testing.assert_close(original_rec, rec[tuple(indices + [i])])
108 |
109 | @pytest.mark.parametrize('unsqueezed_init', ['average', 1.2])
110 | def test_tucker_init_unsqueezed_modes(unsqueezed_init):
111 | """Test for Tucker Factorization init from tensor with unsqueezed_modes
112 | """
113 | tensor = FactorizedTensor.new((4, 4, 4), rank=(4, 1, 4), factorization='tucker')
114 | mat = torch.randn((4, 4))
115 |
116 | tensor.init_from_tensor(mat, unsqueezed_modes=[1], unsqueezed_init=unsqueezed_init)
117 | rec = tensor.to_tensor()
118 |
119 | if unsqueezed_init == 'average':
120 | coef = 1/4
121 | else:
122 | coef = unsqueezed_init
123 |
124 | for i in range(4):
125 | testing.assert_close(rec[:, i], mat*coef)
126 |
127 |
128 | @pytest.mark.parametrize('factorization', ['ComplexCP', 'ComplexTucker', 'ComplexTT', 'ComplexDense'])
129 | def test_ComplexFactorizedTensor(factorization):
130 | """Test for ComplexFactorizedTensor"""
131 | shape = (4, 3, 2, 5)
132 | fact_tensor = FactorizedTensor.new(shape=shape, rank='same', factorization=factorization)
133 | fact_tensor.normal_()
134 |
135 | assert fact_tensor.to_tensor().shape == shape
136 | assert fact_tensor.to_tensor().dtype == torch.cfloat
137 | for param in fact_tensor.parameters():
138 | assert param.dtype == torch.float32
139 |
--------------------------------------------------------------------------------
/tltorch/functional/__init__.py:
--------------------------------------------------------------------------------
1 | from .convolution import convolve, tucker_conv
2 | from .linear import factorized_linear
--------------------------------------------------------------------------------
/tltorch/functional/convolution.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import tensorly as tl
5 | tl.set_backend('pytorch')
6 | from tensorly import tenalg
7 |
8 | from ..factorized_tensors import CPTensor, TTTensor, TuckerTensor, DenseTensor
9 |
10 | # Author: Jean Kossaifi
11 | # License: BSD 3 clause
12 |
13 |
14 | _CONVOLUTION = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d}
15 |
16 | def convolve(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
17 | """Convolution of any specified order, wrapper on torch's F.convNd
18 |
19 | Parameters
20 | ----------
21 | x : torch.Tensor or FactorizedTensor
22 | input tensor
23 | weight : torch.Tensor
24 | convolutional weights
25 |     bias : torch.Tensor, optional
26 | by default None
27 | stride : int, optional
28 | by default 1
29 | padding : int, optional
30 | by default 0
31 | dilation : int, optional
32 | by default 1
33 | groups : int, optional
34 | by default 1
35 |
36 | Returns
37 | -------
38 | torch.Tensor
39 | `x` convolved with `weight`
40 | """
41 | try:
42 | if torch.is_tensor(weight):
43 | return _CONVOLUTION[weight.ndim - 2](x, weight, bias=bias, stride=stride, padding=padding,
44 | dilation=dilation, groups=groups)
45 | else:
46 | if isinstance(weight, TTTensor):
47 | weight = tl.moveaxis(weight.to_tensor(), -1, 0)
48 | else:
49 | weight = weight.to_tensor()
50 | return _CONVOLUTION[weight.ndim - 2](x, weight, bias=bias, stride=stride, padding=padding,
51 | dilation=dilation, groups=groups)
52 | except KeyError:
53 | raise ValueError(f'Got tensor of order={weight.ndim} but pytorch only supports up to 3rd order (3D) Convs.')
54 |
55 |
56 | def general_conv1d_(x, kernel, mode, bias=None, stride=1, padding=0, groups=1, dilation=1, verbose=False):
57 | """General 1D convolution along the mode-th dimension
58 |
59 | Parameters
60 | ----------
61 |     x : torch.Tensor of shape (batch_size, in_channels, K1, ..., KN)
62 |     kernel : torch.Tensor of shape (out_channels, in_channels/groups, K{mode})
63 |     mode : int
64 |         dimension along which the 1D convolution is applied
65 |     stride : int
66 |     padding : int
67 |     groups : int, default is 1
68 |         typically equal to the number of input channels,
69 |         at least for CP convolutions
70 |
71 | Returns
72 | -------
73 | x convolved with the given kernel, along dimension `mode`
74 | """
75 | if verbose:
76 | print(f'Convolving {x.shape} with {kernel.shape} along mode {mode}, '
77 | f'stride={stride}, padding={padding}, groups={groups}')
78 |
79 | in_channels = tl.shape(x)[1]
80 | n_dim = tl.ndim(x)
81 | permutation = list(range(n_dim))
82 | spatial_dim = permutation.pop(mode)
83 | channels_dim = permutation.pop(1)
84 | permutation += [channels_dim, spatial_dim]
85 | x = tl.transpose(x, permutation)
86 | x_shape = list(x.shape)
87 | x = tl.reshape(x, (-1, in_channels, x_shape[-1]))
88 | x = F.conv1d(x.contiguous(), kernel, bias=bias, stride=stride, dilation=dilation, padding=padding, groups=groups)
89 | x_shape[-2:] = x.shape[-2:]
90 | x = tl.reshape(x, x_shape)
91 | permutation = list(range(n_dim))[:-2]
92 | permutation.insert(1, n_dim - 2)
93 | permutation.insert(mode, n_dim - 1)
94 | x = tl.transpose(x, permutation)
95 |
96 | return x
97 |
98 |
99 | def general_conv1d(x, kernel, mode, bias=None, stride=1, padding=0, groups=1, dilation=1, verbose=False):
100 | """General 1D convolution along the mode-th dimension
101 |
102 | Uses an ND convolution under the hood
103 |
104 | Parameters
105 | ----------
106 |     x : torch.Tensor of shape (batch_size, in_channels, K1, ..., KN)
107 |     kernel : torch.Tensor of shape (out_channels, in_channels/groups, K{mode})
108 |     mode : int
109 |         dimension along which the 1D convolution is applied
110 |     stride : int
111 |     padding : int
112 |     groups : int, default is 1
113 |         typically equal to the number of input channels,
114 |         at least for CP convolutions
115 |
116 | Returns
117 | -------
118 | x convolved with the given kernel, along dimension `mode`
119 | """
120 | if verbose:
121 | print(f'Convolving {x.shape} with {kernel.shape} along mode {mode}, '
122 | f'stride={stride}, padding={padding}, groups={groups}')
123 |
124 | def _pad_value(value, mode, order, padding=1):
125 | return tuple([value if i == (mode - 2) else padding for i in range(order)])
126 |
127 | ndim = tl.ndim(x)
128 | order = ndim - 2
129 | for i in range(2, ndim):
130 | if i != mode:
131 | kernel = kernel.unsqueeze(i)
132 |
133 | return _CONVOLUTION[order](x, kernel, bias=bias,
134 | stride=_pad_value(stride, mode, order),
135 | padding=_pad_value(padding, mode, order, padding=0),
136 | dilation=_pad_value(dilation, mode, order),
137 | groups=groups)
138 |
139 |
140 | def tucker_conv(x, tucker_tensor, bias=None, stride=1, padding=0, dilation=1):
141 | # Extract the rank from the actual decomposition in case it was changed by, e.g. dropout
142 | rank = tucker_tensor.rank
143 |
144 | batch_size = x.shape[0]
145 | n_dim = tl.ndim(x)
146 |
147 | # Change the number of channels to the rank
148 | x_shape = list(x.shape)
149 | x = x.reshape((batch_size, x_shape[1], -1)).contiguous()
150 |
151 | # This can be done with a tensor contraction
152 | # First conv == tensor contraction
153 | # from (in_channels, rank) to (rank == out_channels, in_channels, 1)
154 | x = F.conv1d(x, tl.transpose(tucker_tensor.factors[1]).unsqueeze(2))
155 |
156 | x_shape[1] = rank[1]
157 | x = x.reshape(x_shape)
158 |
159 | modes = list(range(2, n_dim+1))
160 | weight = tl.tenalg.multi_mode_dot(tucker_tensor.core, tucker_tensor.factors[2:], modes=modes)
161 | x = convolve(x, weight, bias=None, stride=stride, padding=padding, dilation=dilation)
162 |
163 | # Revert back number of channels from rank to output_channels
164 | x_shape = list(x.shape)
165 | x = x.reshape((batch_size, x_shape[1], -1))
166 | # Last conv == tensor contraction
167 | # From (out_channels, rank) to (out_channels, in_channels == rank, 1)
168 | x = F.conv1d(x, tucker_tensor.factors[0].unsqueeze(2), bias=bias)
169 |
170 | x_shape[1] = x.shape[1]
171 | x = x.reshape(x_shape)
172 |
173 | return x
174 |
175 |
176 | def tt_conv(x, tt_tensor, bias=None, stride=1, padding=0, dilation=1):
177 | """Perform a factorized tt convolution
178 |
179 | Parameters
180 | ----------
181 | x : torch.tensor
182 | tensor of shape (batch_size, C, I_2, I_3, ..., I_N)
183 |
184 | Returns
185 | -------
186 |     NDConv(x) with a TT kernel
187 | """
188 | shape = tt_tensor.shape
189 | rank = tt_tensor.rank
190 |
191 | batch_size = x.shape[0]
192 | order = len(shape) - 2
193 |
194 | if isinstance(padding, int):
195 | padding = (padding, )*order
196 | if isinstance(stride, int):
197 | stride = (stride, )*order
198 | if isinstance(dilation, int):
199 | dilation = (dilation, )*order
200 |
201 | # Change the number of channels to the rank
202 | x_shape = list(x.shape)
203 | x = x.reshape((batch_size, x_shape[1], -1)).contiguous()
204 |
205 | # First conv == tensor contraction
206 | # from (1, in_channels, rank) to (rank == out_channels, in_channels, 1)
207 | x = F.conv1d(x, tl.transpose(tt_tensor.factors[0], [2, 1, 0]))
208 |
209 | x_shape[1] = x.shape[1]#rank[1]
210 | x = x.reshape(x_shape)
211 |
212 | # convolve over non-channels
213 | for i in range(order):
214 | # From (in_rank, kernel_size, out_rank) to (out_rank, in_rank, kernel_size)
215 | kernel = tl.transpose(tt_tensor.factors[i+1], [2, 0, 1])
216 | x = general_conv1d(x.contiguous(), kernel, i+2, stride=stride[i], padding=padding[i], dilation=dilation[i])
217 |
218 | # Revert back number of channels from rank to output_channels
219 | x_shape = list(x.shape)
220 | x = x.reshape((batch_size, x_shape[1], -1))
221 | # Last conv == tensor contraction
222 | # From (rank, out_channels, 1) to (out_channels, in_channels == rank, 1)
223 | x = F.conv1d(x, tl.transpose(tt_tensor.factors[-1], [1, 0, 2]), bias=bias)
224 |
225 | x_shape[1] = x.shape[1]
226 | x = x.reshape(x_shape)
227 |
228 | return x
229 |
230 |
231 | def cp_conv(x, cp_tensor, bias=None, stride=1, padding=0, dilation=1):
232 | """Perform a factorized CP convolution
233 |
234 | Parameters
235 | ----------
236 | x : torch.tensor
237 | tensor of shape (batch_size, C, I_2, I_3, ..., I_N)
238 |
239 | Returns
240 | -------
241 |     NDConv(x) with a CP kernel
242 | """
243 | shape = cp_tensor.shape
244 | rank = cp_tensor.rank
245 |
246 | batch_size = x.shape[0]
247 | order = len(shape) - 2
248 |
249 | if isinstance(padding, int):
250 | padding = (padding, )*order
251 | if isinstance(stride, int):
252 | stride = (stride, )*order
253 | if isinstance(dilation, int):
254 | dilation = (dilation, )*order
255 |
256 | # Change the number of channels to the rank
257 | x_shape = list(x.shape)
258 | x = x.reshape((batch_size, x_shape[1], -1)).contiguous()
259 |
260 | # First conv == tensor contraction
261 | # from (in_channels, rank) to (rank == out_channels, in_channels, 1)
262 | x = F.conv1d(x, tl.transpose(cp_tensor.factors[1]).unsqueeze(2))
263 |
264 | x_shape[1] = rank
265 | x = x.reshape(x_shape)
266 |
267 | # convolve over non-channels
268 | for i in range(order):
269 | # From (kernel_size, rank) to (rank, 1, kernel_size)
270 | kernel = tl.transpose(cp_tensor.factors[i+2]).unsqueeze(1)
271 | x = general_conv1d(x.contiguous(), kernel, i+2, stride=stride[i], padding=padding[i], dilation=dilation[i], groups=rank)
272 |
273 | # Revert back number of channels from rank to output_channels
274 | x_shape = list(x.shape)
275 | x = x.reshape((batch_size, x_shape[1], -1))
276 | # Last conv == tensor contraction
277 | # From (out_channels, rank) to (out_channels, in_channels == rank, 1)
278 | x = F.conv1d(x*cp_tensor.weights.unsqueeze(1).unsqueeze(0), cp_tensor.factors[0].unsqueeze(2), bias=bias)
279 |
280 | x_shape[1] = x.shape[1] # = out_channels
281 | x = x.reshape(x_shape)
282 |
283 | return x
284 |
285 |
286 | def cp_conv_mobilenet(x, cp_tensor, bias=None, stride=1, padding=0, dilation=1):
287 | """Perform a factorized CP convolution
288 |
289 | Parameters
290 | ----------
291 | x : torch.tensor
292 | tensor of shape (batch_size, C, I_2, I_3, ..., I_N)
293 |
294 | Returns
295 | -------
296 |     NDConv(x) with a CP kernel
297 | """
298 | factors = cp_tensor.factors
299 | shape = cp_tensor.shape
300 | rank = cp_tensor.rank
301 |
302 | batch_size = x.shape[0]
303 | order = len(shape) - 2
304 |
305 | # Change the number of channels to the rank
306 | x_shape = list(x.shape)
307 | x = x.reshape((batch_size, x_shape[1], -1)).contiguous()
308 |
309 | # First conv == tensor contraction
310 | # from (in_channels, rank) to (rank == out_channels, in_channels, 1)
311 | x = F.conv1d(x, tl.transpose(factors[1]).unsqueeze(2))
312 |
313 | x_shape[1] = rank
314 | x = x.reshape(x_shape)
315 |
316 | # convolve over merged actual dimensions
317 | # Spatial convs
318 | # From (kernel_size, rank) to (out_rank, 1, kernel_size)
319 | if order == 1:
320 | weight = tl.transpose(factors[2]).unsqueeze(1)
321 | x = F.conv1d(x.contiguous(), weight, stride=stride, padding=padding, dilation=dilation, groups=rank)
322 | elif order == 2:
323 | weight = tenalg.tensordot(tl.transpose(factors[2]),
324 | tl.transpose(factors[3]), modes=(), batched_modes=0
325 | ).unsqueeze(1)
326 | x = F.conv2d(x.contiguous(), weight, stride=stride, padding=padding, dilation=dilation, groups=rank)
327 | elif order == 3:
328 | weight = tenalg.tensordot(tl.transpose(factors[2]),
329 | tenalg.tensordot(tl.transpose(factors[3]), tl.transpose(factors[4]), modes=(), batched_modes=0),
330 | modes=(), batched_modes=0
331 | ).unsqueeze(1)
332 | x = F.conv3d(x.contiguous(), weight, stride=stride, padding=padding, dilation=dilation, groups=rank)
333 |
334 | # Revert back number of channels from rank to output_channels
335 | x_shape = list(x.shape)
336 | x = x.reshape((batch_size, x_shape[1], -1))
337 |
338 | # Last conv == tensor contraction
339 | # From (out_channels, rank) to (out_channels, in_channels == rank, 1)
340 | x = F.conv1d(x*cp_tensor.weights.unsqueeze(1).unsqueeze(0), factors[0].unsqueeze(2), bias=bias)
341 |
342 | x_shape[1] = x.shape[1] # = out_channels
343 | x = x.reshape(x_shape)
344 |
345 | return x
346 |
347 |
348 | def _get_factorized_conv(factorization, implementation='factorized'):
349 | if implementation == 'reconstructed' or factorization == 'Dense':
350 | return convolve
351 | if isinstance(factorization, CPTensor):
352 | if implementation == 'factorized':
353 | return cp_conv
354 | elif implementation == 'mobilenet':
355 | return cp_conv_mobilenet
356 | elif isinstance(factorization, TuckerTensor):
357 | return tucker_conv
358 | elif isinstance(factorization, TTTensor):
359 | return tt_conv
360 | raise ValueError(f'Got unknown type {factorization}')
361 |
362 |
363 | def convNd(x, weight, bias=None, stride=1, padding=0, dilation=1, implementation='factorized'):
364 | if implementation=='reconstructed':
365 | weight = weight.to_tensor()
366 |
367 | if isinstance(weight, DenseTensor):
368 | return convolve(x, weight.tensor, bias=bias, stride=stride, padding=padding, dilation=dilation)
369 |
370 | if torch.is_tensor(weight):
371 | return convolve(x, weight, bias=bias, stride=stride, padding=padding, dilation=dilation)
372 |
373 | if isinstance(weight, CPTensor):
374 | if implementation == 'factorized':
375 | return cp_conv(x, weight, bias=bias, stride=stride, padding=padding, dilation=dilation)
376 | elif implementation == 'mobilenet':
377 | return cp_conv_mobilenet(x, weight, bias=bias, stride=stride, padding=padding, dilation=dilation)
378 | elif isinstance(weight, TuckerTensor):
379 | return tucker_conv(x, weight, bias=bias, stride=stride, padding=padding, dilation=dilation)
380 | elif isinstance(weight, TTTensor):
381 | return tt_conv(x, weight, bias=bias, stride=stride, padding=padding, dilation=dilation)
382 |
--------------------------------------------------------------------------------
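Usage sketch (not part of the repository): checking that the factorized convolutions above give results close to a dense convolution with the reconstructed kernel. The kernel shape is (out_channels, in_channels, height, width); all values below are arbitrary.

import torch
import torch.nn.functional as F
import tensorly as tl
tl.set_backend('pytorch')
from tltorch.factorized_tensors import FactorizedTensor
from tltorch.functional.convolution import convNd

x = torch.randn(2, 4, 8, 8)                              # (batch, channels, H, W)
kernel = FactorizedTensor.new((6, 4, 3, 3), rank=0.5, factorization='CP')
kernel.normal_(0, 0.02)

out_factorized = convNd(x, kernel, padding=1, implementation='factorized')
out_dense = F.conv2d(x, kernel.to_tensor(), padding=1)
print(torch.max(torch.abs(out_factorized - out_dense)))  # small, up to floating point error
--------------------------------------------------------------------------------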
/tltorch/functional/factorized_linear.py:
--------------------------------------------------------------------------------
1 | from .factorized_tensordot import tensor_dot_tucker, tensor_dot_cp
2 | import tensorly as tl
3 | tl.set_backend('pytorch')
4 |
5 | # Author: Jean Kossaifi
6 |
7 | def linear_tucker(tensor, tucker_matrix, transpose=True, channels_first=True):
8 | if transpose:
9 | contraction_axis = 1
10 | else:
11 | contraction_axis = 0
12 | n_rows = len(tucker_matrix.tensorized_shape[contraction_axis])
13 | tensor = tensor.reshape(-1, *tucker_matrix.tensorized_shape[contraction_axis])
14 |
15 | modes_tensor = list(range(tensor.ndim - n_rows, tensor.ndim))
16 | if transpose:
17 | modes_tucker = list(range(n_rows, tucker_matrix.order))
18 | else:
19 | modes_tucker = list(range(n_rows))
20 |
21 | return tensor_dot_tucker(tensor, tucker_matrix, (modes_tensor, modes_tucker))
22 |
23 | def linear_cp(tensor, cp_matrix, transpose=True):
24 | if transpose:
25 | out_features, in_features = len(cp_matrix.tensorized_shape[0]), len(cp_matrix.tensorized_shape[1])
26 | in_shape = cp_matrix.tensorized_shape[1]
27 | modes_cp = list(range(out_features, cp_matrix.order))
28 | else:
29 | in_features, out_features = len(cp_matrix.tensorized_shape[0]), len(cp_matrix.tensorized_shape[1])
30 | in_shape = cp_matrix.tensorized_shape[0]
31 | modes_cp = list(range(in_features))
32 | tensor = tensor.reshape(-1, *in_shape)
33 |
34 | modes_tensor = list(range(1, tensor.ndim))
35 |
36 | return tensor_dot_cp(tensor, cp_matrix, (modes_tensor, modes_cp))
37 |
38 |
39 | def linear_blocktt(tensor, tt_matrix, transpose=True):
40 | if transpose:
41 | contraction_axis = 1
42 | else:
43 | contraction_axis = 0
44 | ndim = len(tt_matrix.tensorized_shape[contraction_axis])
45 | tensor = tensor.reshape(-1, *tt_matrix.tensorized_shape[contraction_axis])
46 |
47 | bs = 'a'
48 | start = ord(bs) + 1
49 | in_idx = bs + ''.join(chr(i) for i in [start+i for i in range(ndim)])
50 | factors_idx = []
51 | for i in range(ndim):
52 | if transpose:
53 | idx = [start+ndim*2+i, start+ndim+i, start+i, start+ndim*2+i+1]
54 | else:
55 | idx = [start+ndim*2+i, start+i, start+ndim+i, start+ndim*2+i+1]
56 | factors_idx.append(''.join(chr(j) for j in idx))
57 | out_idx = bs + ''.join(chr(i) for i in [start + ndim + i for i in range(ndim)])
58 | eq = in_idx + ',' + ','.join(i for i in factors_idx) + '->' + out_idx
59 | res = tl.einsum(eq, tensor, *tt_matrix.factors)
60 | return tl.reshape(res, (tl.shape(res)[0], -1))
61 |
62 |
--------------------------------------------------------------------------------
/tltorch/functional/factorized_tensordot.py:
--------------------------------------------------------------------------------
1 | # Author: Jean Kossaifi
2 |
3 | import tensorly as tl
4 | from tensorly.tenalg.tenalg_utils import _validate_contraction_modes
5 | tl.set_backend('pytorch')
6 |
7 | einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
8 |
9 |
10 | def tensor_dot_tucker(tensor, tucker, modes, batched_modes=()):
11 | """Batched tensor contraction between a dense tensor and a Tucker tensor on specified modes
12 |
13 | Parameters
14 | ----------
15 |     tensor : torch.Tensor
16 |     tucker : TuckerTensor
17 |     modes : int list or int
18 |         modes on which to contract the dense tensor and the Tucker tensor
19 | batched_modes : int or tuple[int]
20 |
21 | Returns
22 | -------
23 |     contraction : tensor contracted with the Tucker tensor on the specified modes
24 | """
25 | modes_tensor, modes_tucker = _validate_contraction_modes(
26 | tl.shape(tensor), tucker.tensor_shape, modes)
27 | input_order = tensor.ndim
28 | weight_order = tucker.order
29 |
30 | batched_modes_tensor, batched_modes_tucker = _validate_contraction_modes(
31 | tl.shape(tensor), tucker.tensor_shape, batched_modes)
32 |
33 | sorted_modes_tucker = sorted(modes_tucker+batched_modes_tucker, reverse=True)
34 | sorted_modes_tensor = sorted(modes_tensor+batched_modes_tensor, reverse=True)
35 |
36 | # Symbol for dimensionality of the core
37 | rank_sym = [einsum_symbols[i] for i in range(weight_order)]
38 |
39 | # Symbols for tucker weight size
40 | tucker_sym = [einsum_symbols[i+weight_order] for i in range(weight_order)]
41 |
42 | # Symbols for input tensor
43 | tensor_sym = [einsum_symbols[i+2*weight_order] for i in range(tensor.ndim)]
44 |
45 | # Output: input + weights symbols after removing contraction symbols
46 | output_sym = tensor_sym + tucker_sym
47 | for m in sorted_modes_tucker:
48 | if m in modes_tucker: #not batched
49 | output_sym.pop(m+input_order)
50 | for m in sorted_modes_tensor:
51 | # It's batched, always remove
52 | output_sym.pop(m)
53 |
54 | # print(tensor_sym, tucker_sym, modes_tensor, batched_modes_tensor)
55 | for i, e in enumerate(modes_tensor):
56 | tensor_sym[e] = tucker_sym[modes_tucker[i]]
57 | for i, e in enumerate(batched_modes_tensor):
58 | tensor_sym[e] = tucker_sym[batched_modes_tucker[i]]
59 |
60 | # Form the actual equation: tensor, core, factors -> output
61 | eq = ''.join(tensor_sym)
62 | eq += ',' + ''.join(rank_sym)
63 | eq += ',' + ','.join(f'{s}{r}' for s,r in zip(tucker_sym,rank_sym))
64 | eq += '->' + ''.join(output_sym)
65 |
66 | return tl.einsum(eq, tensor, tucker.core, *tucker.factors)
67 |
68 |
69 |
70 | def tensor_dot_cp(tensor, cp, modes):
71 |     """Contracts a dense tensor with a CP tensor in factorized form
72 |
73 | Returns
74 | -------
75 | tensor = tensor x cp_matrix.to_matrix().T
76 | """
77 | try:
78 | cp_shape = cp.tensor_shape
79 | except AttributeError:
80 | cp_shape = cp.shape
81 | modes_tensor, modes_cp = _validate_contraction_modes(tl.shape(tensor), cp_shape, modes)
82 |
83 | tensor_order = tl.ndim(tensor)
84 | # CP rank = 'a', start at b
85 | start = ord('b')
86 | eq_in = ''.join(f'{chr(start+index)}' for index in range(tensor_order))
87 | eq_factors = []
88 | eq_res = ''.join(eq_in[i] if i not in modes_tensor else '' for i in range(tensor_order))
89 | counter_joint = 0 # contraction modes, shared indices between tensor and CP
90 | counter_free = 0 # new uncontracted modes from the CP
91 | for i in range(len(cp.factors)):
92 | if i in modes_cp:
93 | eq_factors.append(f'{eq_in[modes_tensor[counter_joint]]}a')
94 | counter_joint += 1
95 | else:
96 | eq_factors.append(f'{chr(start+tensor_order+counter_free)}a')
97 | eq_res += f'{chr(start+tensor_order+counter_free)}'
98 | counter_free += 1
99 |
100 | eq_factors = ','.join(f for f in eq_factors)
101 | eq = eq_in + ',a,' + eq_factors + '->' + eq_res
102 | res = tl.einsum(eq, tensor, cp.weights, *cp.factors)
103 |
104 | return res
--------------------------------------------------------------------------------
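Usage sketch (not part of the repository): contracting a dense tensor with a CP tensor directly in factorized form via tensor_dot_cp, checked against the dense contraction. Shapes and contraction modes below are arbitrary.

import torch
import tensorly as tl
tl.set_backend('pytorch')
from tltorch.factorized_tensors import FactorizedTensor
from tltorch.functional.factorized_tensordot import tensor_dot_cp

x = torch.randn(8, 5, 4)
cp = FactorizedTensor.new((5, 4, 6), rank=0.5, factorization='CP')
cp.normal_(0, 0.02)

# Contract modes (1, 2) of x with modes (0, 1) of the CP tensor
res = tensor_dot_cp(x, cp, modes=([1, 2], [0, 1]))
ref = torch.tensordot(x, cp.to_tensor(), dims=([1, 2], [0, 1]))
print(res.shape, torch.max(torch.abs(res - ref)))   # torch.Size([8, 6]) and a small value
--------------------------------------------------------------------------------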
/tltorch/functional/linear.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from ..factorized_tensors import TensorizedTensor
5 | from ..factorized_tensors.tensorized_matrices import CPTensorized, TuckerTensorized, BlockTT
6 | from .factorized_linear import linear_blocktt, linear_cp, linear_tucker
7 |
8 | import tensorly as tl
9 | tl.set_backend('pytorch')
10 |
11 | # Author: Jean Kossaifi
12 | # License: BSD 3 clause
13 |
14 |
15 | def factorized_linear(x, weight, bias=None, in_features=None, implementation="factorized"):
16 | """Linear layer with a dense input x and factorized weight
17 | """
18 | assert implementation in {"factorized", "reconstructed"}, f"Expect implementation from [factorized, reconstructed], but got {implementation}"
19 |
20 | if in_features is None:
21 | in_features = np.prod(x.shape[-1])
22 |
23 | if not torch.is_tensor(weight):
24 | # Weights are in the form (out_features, in_features)
25 | # PyTorch's linear returns dot(x, weight.T)!
26 | if isinstance(weight, TensorizedTensor):
27 | if implementation == "factorized":
28 | x_shape = x.shape[:-1] + weight.tensorized_shape[1]
29 | out_shape = x.shape[:-1] + (-1, )
30 | if isinstance(weight, CPTensorized):
31 | x = linear_cp(x.reshape(x_shape), weight).reshape(out_shape)
32 | if bias is not None:
33 | x = x + bias
34 | return x
35 | elif isinstance(weight, TuckerTensorized):
36 | x = linear_tucker(x.reshape(x_shape), weight).reshape(out_shape)
37 | if bias is not None:
38 | x = x + bias
39 | return x
40 | elif isinstance(weight, BlockTT):
41 | x = linear_blocktt(x.reshape(x_shape), weight).reshape(out_shape)
42 | if bias is not None:
43 | x = x + bias
44 | return x
45 | # No efficient implementation available, or "reconstructed" was requested: fall back to reconstructing the full weight
46 | weight = weight.to_matrix()
47 | else:
48 | weight = weight.to_tensor()
49 |
50 | return F.linear(x, torch.reshape(weight, (-1, in_features)), bias=bias)
51 |
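A minimal usage sketch for factorized_linear (not part of this file; it assumes the module path tltorch.functional.linear and the TensorizedTensor API exercised in the tests below):

import torch
from tltorch.factorized_tensors import TensorizedTensor
from tltorch.functional.linear import factorized_linear

x = torch.randn(8, 20)                                       # batch of 8, in_features = 4*5
weight = TensorizedTensor.new(((6, 2), (4, 5)), rank=3,
                              factorization='blocktt').normal_()

y = factorized_linear(x, weight, in_features=20)             # efficient factorized contraction
y_ref = torch.nn.functional.linear(x, weight.to_matrix())    # dense reference
torch.testing.assert_close(y.reshape(8, -1), y_ref, rtol=1e-4, atol=1e-4)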
--------------------------------------------------------------------------------
/tltorch/functional/tensor_regression.py:
--------------------------------------------------------------------------------
1 | from ..factorized_tensors import FactorizedTensor, TuckerTensor
2 |
3 | import tensorly as tl
4 | tl.set_backend('pytorch')
5 | from tensorly import tenalg
6 |
7 | # Author: Jean Kossaifi
8 | # License: BSD 3 clause
9 |
10 |
11 | def trl(x, weight, bias=None, **kwargs):
12 | """Tensor Regression Layer
13 |
14 | Parameters
15 | ----------
16 | x : torch.Tensor
17 | batch of inputs
18 | weight : FactorizedTensor
19 | factorized weights of the TRL
20 | bias : torch.Tensor, optional
21 | 1D tensor, by default None
22 |
23 | Returns
24 | -------
25 | result : torch.Tensor
26 | input x contracted with the regression weights
27 | """
28 | if isinstance(weight, TuckerTensor):
29 | return tucker_trl(x, weight, bias=bias, **kwargs)
30 | else:
31 | if bias is None:
32 | return tenalg.inner(x, weight.to_tensor(), n_modes=tl.ndim(x)-1)
33 | else:
34 | return tenalg.inner(x, weight.to_tensor(), n_modes=tl.ndim(x)-1) + bias
35 |
36 |
37 | def tucker_trl(x, weight, project_input=False, bias=None):
38 | n_input = tl.ndim(x) - 1
39 | if project_input:
40 | x = tenalg.multi_mode_dot(x, weight.factors[:n_input], modes=range(1, n_input+1), transpose=True)
41 | regression_weights = tenalg.multi_mode_dot(weight.core, weight.factors[n_input:],
42 | modes=range(n_input, weight.order))
43 | else:
44 | regression_weights = weight.to_tensor()
45 |
46 | if bias is None:
47 | return tenalg.inner(x, regression_weights, n_modes=tl.ndim(x)-1)
48 | else:
49 | return tenalg.inner(x, regression_weights, n_modes=tl.ndim(x)-1) + bias
50 |
51 |
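A small sketch of the TRL forward pass (not part of this file; shapes are illustrative and it assumes FactorizedTensor is exposed at the package top level, as in the tests): the weight shares its leading modes with the input and its trailing modes with the output.

import torch
import tensorly as tl
tl.set_backend('pytorch')
from tensorly import tenalg
from tltorch import FactorizedTensor
from tltorch.functional.tensor_regression import trl

x = torch.randn(16, 8, 8)                   # batch of 16 inputs of shape (8, 8)
weight = FactorizedTensor.new((8, 8, 5), rank=(4, 4, 3), factorization='Tucker').normal_()
bias = torch.zeros(5)

y = trl(x, weight, bias=bias)               # contracts the two input modes -> shape (16, 5)
y_ref = tenalg.inner(x, weight.to_tensor(), n_modes=2) + bias
torch.testing.assert_close(y, y_ref)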
--------------------------------------------------------------------------------
/tltorch/functional/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/tltorch/functional/tests/__init__.py
--------------------------------------------------------------------------------
/tltorch/functional/tests/test_factorized_linear.py:
--------------------------------------------------------------------------------
1 | from ...factorized_tensors import TensorizedTensor
2 | from ..factorized_linear import linear_tucker, linear_blocktt, linear_cp
3 | import torch
4 |
5 | import tensorly as tl
6 | tl.set_backend('pytorch')
7 | from tensorly import testing
8 | from math import prod
9 |
10 | import pytest
11 |
12 | # Author: Jean Kossaifi
13 |
14 |
15 | @pytest.mark.parametrize('factorization, factorized_linear',
16 | [('tucker', linear_tucker), ('blocktt', linear_blocktt), ('cp', linear_cp)])
17 | def test_factorized_linear(factorization, factorized_linear):
18 | in_shape = (4, 5)
19 | in_dim = prod(in_shape)
20 | out_shape = (6, 2)
21 | rank = 3
22 | batch_size = 2
23 |
24 | tensor = tl.randn((batch_size, in_dim), dtype=tl.float32)
25 | fact_weight = TensorizedTensor.new((out_shape, in_shape), rank=rank,
26 | factorization=factorization)
27 | fact_weight.normal_()
28 | full_weight = fact_weight.to_matrix()
29 | true_res = torch.matmul(tensor, full_weight.T)
30 | res = factorized_linear(tensor, fact_weight, transpose=True)
31 | res = res.reshape(batch_size, -1)
32 | testing.assert_array_almost_equal(true_res, res, decimal=5)
33 |
34 |
--------------------------------------------------------------------------------
/tltorch/tensor_hooks/__init__.py:
--------------------------------------------------------------------------------
1 | from ._tensor_lasso import tensor_lasso, remove_tensor_lasso, TensorLasso
2 | from ._tensor_dropout import tensor_dropout, remove_tensor_dropout, TensorDropout
--------------------------------------------------------------------------------
/tltorch/tensor_hooks/_tensor_dropout.py:
--------------------------------------------------------------------------------
1 | """Tensor Dropout for TensorModules"""
2 |
3 | # Author: Jean Kossaifi
4 | # License: BSD 3 clause
5 |
6 | import tensorly as tl
7 | tl.set_backend('pytorch')
8 | import torch
9 | from ..factorized_tensors import TuckerTensor, CPTensor, TTTensor
10 |
11 |
12 | class TensorDropout():
13 | """Decomposition Hook for Tensor Dropout on FactorizedTensor
14 |
15 | Parameters
16 | ----------
17 | proba : float
18 | probability of dropout
19 | min_dim : int
20 | Minimum dimension size for which to apply dropout.
21 | For instance, if a tensor is of shape (32, 32, 3, 3) and min_dim = 4,
22 | then dropout will *not* be applied to the last two modes.
23 | """
24 | _factorizations = dict()
25 |
26 | def __init_subclass__(cls, factorization, **kwargs):
27 | """When a subclass is created, register it in _factorizations"""
28 | cls._factorizations[factorization.__name__] = cls
29 |
30 | def __init__(self, proba, min_dim=1, min_values=1, drop_test=False):
31 | assert 0 <= proba < 1, f'Got proba={proba} but tensor dropout is defined for 0 <= proba < 1.'
32 | self.proba = proba
33 | self.min_dim = min_dim
34 | self.min_values = min_values
35 | self.drop_test = drop_test
36 |
37 | def __call__(self, module, input, factorized_tensor):
38 | return self._apply_tensor_dropout(factorized_tensor, training=module.training)
39 |
40 | def _apply_tensor_dropout(self, factorized_tensor, training=True):
41 | raise NotImplementedError()
42 |
43 | @classmethod
44 | def apply(cls, module, proba, min_dim=3, min_values=1, drop_test=False):
45 | cls = cls._factorizations[module.__class__.__name__]
46 | for k, hook in module._forward_hooks.items():
47 | if isinstance(hook, cls):
48 | raise RuntimeError("Cannot register two tensor dropout hooks on "
49 | "the same factorized tensor")
50 |
51 | dropout = cls(proba, min_dim=min_dim, min_values=min_values, drop_test=drop_test)
52 | handle = module.register_forward_hook(dropout)
53 | return handle
54 |
55 |
56 | class TuckerDropout(TensorDropout, factorization=TuckerTensor):
57 | def _apply_tensor_dropout(self, tucker_tensor, training=True):
58 | if (not self.proba) or ((not training) and (not self.drop_test)):
59 | return tucker_tensor
60 |
61 | core, factors = tucker_tensor.core, tucker_tensor.factors
62 | tucker_rank = tucker_tensor.rank
63 |
64 | sampled_indices = []
65 | for rank in tucker_rank:
66 | idx = tl.arange(rank, device=core.device, dtype=torch.int64)
67 | if rank > self.min_dim:
68 | idx = idx[torch.bernoulli(torch.ones(rank, device=core.device)*(1 - self.proba),
69 | out=torch.empty(rank, device=core.device, dtype=torch.bool))]
70 | if len(idx) == 0:
71 | idx = torch.randint(0, rank, size=(self.min_values, ), device=core.device, dtype=torch.int64)
72 |
73 | sampled_indices.append(idx)
74 |
75 | if training:
76 | core = core[torch.meshgrid(*sampled_indices)]*(1/((1 - self.proba)**core.ndim))
77 | else:
78 | core = core[torch.meshgrid(*sampled_indices)]
79 |
80 | factors = [factor[:, idx] for (factor, idx) in zip(factors, sampled_indices)]
81 |
82 | return TuckerTensor(core, factors)
83 |
84 |
85 | class CPDropout(TensorDropout, factorization=CPTensor):
86 | def _apply_tensor_dropout(self, cp_tensor, training=True):
87 | if (not self.proba) or ((not training) and (not self.drop_test)):
88 | return cp_tensor
89 |
90 | rank = cp_tensor.rank
91 | device = cp_tensor.factors[0].device
92 |
93 | if rank > self.min_dim:
94 | sampled_indices = tl.arange(rank, device=device, dtype=torch.int64)
95 | sampled_indices = sampled_indices[torch.bernoulli(torch.ones(rank, device=device)*(1 - self.proba),
96 | out=torch.empty(rank, device=device, dtype=torch.bool))]
97 | if len(sampled_indices) == 0:
98 | sampled_indices = torch.randint(0, rank, size=(self.min_values, ), device=device, dtype=torch.int64)
99 |
100 | factors = [factor[:, sampled_indices] for factor in cp_tensor.factors]
101 | if training:
102 | weights = cp_tensor.weights[sampled_indices]*(1/(1 - self.proba))
103 | else:
104 | weights = cp_tensor.weights[sampled_indices]
105 |
106 | return CPTensor(weights, factors)
107 |
108 |
109 | class TTDropout(TensorDropout, factorization=TTTensor):
110 | def _apply_tensor_dropout(self, tt_tensor, training=True):
111 | if (not self.proba) or ((not training) and (not self.drop_test)):
112 | return tt_tensor
113 |
114 | device = tt_tensor.factors[0].device
115 |
116 | sampled_indices = []
117 | for i, rank in enumerate(tt_tensor.rank[1:]):
118 | if rank > self.min_dim:
119 | idx = tl.arange(rank, device=device, dtype=torch.int64)
120 | idx = idx[torch.bernoulli(torch.ones(rank, device=device)*(1 - self.proba),
121 | out=torch.empty(rank, device=device, dtype=torch.bool))]
122 | if len(idx) == 0:
123 | idx = torch.randint(0, rank, size=(self.min_values, ), device=device, dtype=torch.int64)
124 | else:
125 | idx = tl.arange(rank, device=device, dtype=torch.int64).tolist()
126 |
127 | sampled_indices.append(idx)
128 |
129 | sampled_factors = []
130 | if training:
131 | scaling = 1/(1 - self.proba)
132 | else:
133 | scaling = 1
134 | for i, f in enumerate(tt_tensor.factors):
135 | if i == 0:
136 | sampled_factors.append(f[..., sampled_indices[i]]*scaling)
137 | elif i == (tt_tensor.order - 1):
138 | sampled_factors.append(f[sampled_indices[i-1], ...])
139 | else:
140 | sampled_factors.append(f[sampled_indices[i-1], ...][..., sampled_indices[i]]*scaling)
141 |
142 | return TTTensor(sampled_factors)
143 |
144 |
145 | def tensor_dropout(factorized_tensor, p=0, min_dim=3, min_values=1, drop_test=False):
146 | """Tensor Dropout
147 |
148 | Parameters
149 | ----------
150 | factorized_tensor : FactorizedTensor
151 | the tensor module parametrized by the tensor decomposition to which to apply tensor dropout
152 | p : float
153 | dropout probability
154 | if 0, no dropout is applied
155 | as p approaches 1, almost all components are dropped (at least `min_values` are always kept)
156 | min_dim : int, default is 3
157 | only apply dropout to modes with dimension larger than `min_dim`
158 | min_values : int, default is 1
159 | minimum number of components to select
160 |
161 | Returns
162 | -------
163 | FactorizedTensor
164 | the module to which tensor dropout has been attached
165 |
166 | Examples
167 | --------
168 | >>> tensor = FactorizedTensor.new((3, 4, 2), rank=0.5, factorization='CP').normal_()
169 | >>> tensor = tensor_dropout(tensor, p=0.5)
170 | >>> remove_tensor_dropout(tensor)
171 | """
172 | TensorDropout.apply(factorized_tensor, p, min_dim=min_dim, min_values=min_values, drop_test=drop_test)
173 |
174 | return factorized_tensor
175 |
176 |
177 | def remove_tensor_dropout(factorized_tensor):
178 | """Removes the tensor dropout from a TensorModule
179 |
180 | Parameters
181 | ----------
182 | factorized_tensor : tltorch.FactorizedTensor
183 | the tensor module parametrized by the tensor decomposition from which to remove tensor dropout
184 |
185 | Examples
186 | --------
187 | >>> tensor = FactorizedTensor.new((3, 4, 2), rank=0.5, factorization='CP').normal_()
188 | >>> tensor = tensor_dropout(tensor, p=0.5)
189 | >>> remove_tensor_dropout(tensor)
190 | """
191 | for key, hook in factorized_tensor._forward_hooks.items():
192 | if isinstance(hook, TensorDropout):
193 | del factorized_tensor._forward_hooks[key]
194 | return factorized_tensor
195 |
196 | raise ValueError(f'TensorDropout not found in factorized tensor {factorized_tensor}')
197 |
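For reference, a minimal end-to-end sketch of the dropout hook (not part of this file; it mirrors the doctest above and the tests below):

import tensorly as tl
tl.set_backend('pytorch')
from tltorch import FactorizedTensor
from tltorch.tensor_hooks import tensor_dropout, remove_tensor_dropout

tensor = FactorizedTensor.new((10, 11, 12), rank=(7, 8, 9), factorization='Tucker').normal_()
tensor = tensor_dropout(tensor, p=0.5)

sampled = tensor()              # the forward pass returns a randomly thinned Tucker tensor,
print(sampled.core.shape)       # rescaled by 1/(1 - p) per mode while training
remove_tensor_dropout(tensor)   # detaches the hook, leaving the original factorization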
--------------------------------------------------------------------------------
/tltorch/tensor_hooks/_tensor_lasso.py:
--------------------------------------------------------------------------------
1 | import tensorly as tl
2 | tl.set_backend('pytorch')
3 |
4 | import warnings
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 |
9 | from ..factorized_tensors import TuckerTensor, TTTensor, CPTensor
10 | from ..utils import ParameterList
11 |
12 | # Author: Jean Kossaifi
13 | # License: BSD 3 clause
14 |
15 | class TensorLasso:
16 | """Generalized Tensor Lasso on factorized tensors
17 |
18 | Applies a generalized Lasso (l1 regularization) on a factorized tensor.
19 |
20 |
21 | Parameters
22 | ----------
23 | penalty : float, default is 0.01
24 | scaling factor for the loss
25 |
26 | clamp_weights : bool, default is True
27 | if True, the lasso weights are clamped between -1 and 1
28 |
29 | threshold : float, default is 1e-6
30 | if a lasso weight is lower than the set threshold, it is set to 0
31 |
32 | normalize_loss : bool, default is True
33 | If True, the loss will be between 0 and 1.
34 | Otherwise, the raw sum of absolute weights will be returned.
35 |
36 | Examples
37 | --------
38 |
39 | First you need to create an instance of the regularizer:
40 |
41 | >>> regularizer = tensor_lasso(factorization='cp')
42 |
43 | You can apply the regularizer to one or several layers:
44 |
45 | >>> trl = TRL((5, 5), (5, 5), rank='same')
46 | >>> trl2 = TRL((5, 5), (2, ), rank='same')
47 | >>> regularizer.apply(trl.weight)
48 | >>> regularizer.apply(trl2.weight)
49 |
50 | The lasso is automatically applied:
51 |
52 | >>> x = trl(x)
53 | >>> pred = trl2(x)
54 | >>> loss = your_loss_function(pred)
55 |
56 | Add the Lasso loss:
57 |
58 | >>> loss = loss + regularizer.loss
59 |
60 | You can now backpropagate through your loss as usual:
61 |
62 | >>> loss.backward()
63 |
64 | After you finish updating the weights, don't forget to reset the regularizer,
65 | otherwise it will keep accumulating values!
66 |
67 | >>> regularizer.reset()
68 |
69 | You can also remove the regularizer with `regularizer.remove(trl)`.
70 | """
71 | _factorizations = dict()
72 |
73 | def __init_subclass__(cls, factorization, **kwargs):
74 | """When a subclass is created, register it in _factorizations"""
75 | cls._factorizations[factorization.__name__] = cls
76 |
77 | def __init__(self, penalty=0.01, clamp_weights=True, threshold=1e-6, normalize_loss=True):
78 | self.penalty = penalty
79 | self.clamp_weights = clamp_weights
80 | self.threshold = threshold
81 | self.normalize_loss = normalize_loss
82 |
83 | # Initialize the counters
84 | self.reset()
85 |
86 | def reset(self):
87 | """Reset the loss, should be called at the end of each iteration.
88 | """
89 | self._loss = 0
90 | self.n_element = 0
91 |
92 | @property
93 | def loss(self):
94 | """Returns the current Lasso (l1) loss for the layers that have been called so far.
95 |
96 | Returns
97 | -------
98 | float
99 | l1 regularization on the tensor layers the regularization has been applied to.
100 | """
101 | if self.n_element == 0:
102 | warnings.warn('The L1Regularization was not applied to any weights.')
103 | return 0
104 | elif self.normalize_loss:
105 | return self.penalty*self._loss/self.n_element
106 | else:
107 | return self.penalty*self._loss
108 |
109 | def __call__(self, module, input, tucker_tensor):
110 | raise NotImplementedError
111 |
112 | def apply_lasso(self, tucker_tensor, lasso_weights):
113 | """Applies the lasso to a decomposed tensor
114 | """
115 | raise NotImplementedError
116 |
117 | @classmethod
118 | def from_factorization(cls, factorization, penalty=0.01, clamp_weights=True, threshold=1e-6, normalize_loss=True):
119 | return cls.from_factorization_name(factorization.__class__.__name__, penalty=penalty,
120 | clamp_weights=clamp_weights, threshold=threshold, normalize_loss=normalize_loss)
121 |
122 | @classmethod
123 | def from_factorization_name(cls, factorization_name, penalty=0.01, clamp_weights=True, threshold=1e-6, normalize_loss=True):
124 | cls = cls._factorizations[factorization_name]
125 | lasso = cls(penalty=penalty, clamp_weights=clamp_weights, threshold=threshold, normalize_loss=normalize_loss)
126 | return lasso
127 |
128 | def remove(self, module):
129 | raise NotImplementedError
130 |
131 |
132 | class CPLasso(TensorLasso, factorization=CPTensor):
133 | """Decomposition Hook for Tensor Lasso on CP tensors
134 |
135 | Parameters
136 | ----------
137 | penalty : float, default is 0.01
138 | scaling factor for the loss
139 |
140 | clamp_weights : bool, default is True
141 | if True, the lasso weights are clamped between -1 and 1
142 |
143 | threshold : float, default is 1e-6
144 | if a lasso weight is lower than the set threshold, it is set to 0
145 |
146 | normalize_loss : bool, default is True
147 | If True, the loss will be between 0 and 1.
148 | Otherwise, the raw sum of absolute weights will be returned.
149 | """
150 | def __call__(self, module, input, cp_tensor):
151 | """CP already includes weights, we'll just take their l1 norm
152 | """
153 | weights = getattr(module, 'lasso_weights')
154 |
155 | with torch.no_grad():
156 | if self.clamp_weights:
157 | weights.data = torch.clamp(weights.data, -1, 1)
158 | setattr(module, 'lasso_weights', weights)
159 |
160 | if self.threshold:
161 | weights.data = F.threshold(weights.data, threshold=self.threshold, value=0, inplace=True)
162 | setattr(module, 'lasso_weights', weights)
163 |
164 | self.n_element += weights.numel()
165 | self._loss = self._loss + torch.norm(weights, 1)  # the penalty scaling is applied once, in the ``loss`` property
166 | return cp_tensor
167 |
168 | def apply(self, module):
169 | """Apply an instance of the L1Regularizer to a tensor module
170 |
171 | Parameters
172 | ----------
173 | module : TensorModule
174 | module on which to add the regularization
175 |
176 | Returns
177 | -------
178 | TensorModule (with Regularization hook)
179 | """
180 | context = tl.context(module.factors[0])
181 | lasso_weights = nn.Parameter(torch.ones(module.rank, **context))
182 | setattr(module, 'lasso_weights', lasso_weights)
183 |
184 | module.register_forward_hook(self)
185 | return module
186 |
187 | def remove(self, module):
188 | delattr(module, 'lasso_weights')
189 |
190 | def set_weights(self, module, value):
191 | with torch.no_grad():
192 | module.lasso_weights.data.fill_(value)
193 |
194 |
195 | class TuckerLasso(TensorLasso, factorization=TuckerTensor):
196 | """Decomposition Hook for Tensor Lasso on Tucker tensors
197 |
198 | Applies a generalized Lasso (l1 regularization) to the Tucker factors of the tensor layers it is registered on.
199 |
200 |
201 | Parameters
202 | ----------
203 | penalty : float, default is 0.01
204 | scaling factor for the loss
205 |
206 | clamp_weights : bool, default is True
207 | if True, the lasso weights are clamped between -1 and 1
208 |
209 | threshold : float, default is 1e-6
210 | if a lasso weight is lower than the set threshold, it is set to 0
211 |
212 | normalize_loss : bool, default is True
213 | If True, the loss will be between 0 and 1.
214 | Otherwise, the raw sum of absolute weights will be returned.
215 | """
216 | _log = []
217 |
218 | def __call__(self, module, input, tucker_tensor):
219 | lasso_weights = getattr(module, 'lasso_weights')
220 | order = len(lasso_weights)
221 |
222 | with torch.no_grad():
223 | for i in range(order):
224 | if self.clamp_weights:
225 | lasso_weights[i].data = torch.clamp(lasso_weights[i].data, -1, 1)
226 |
227 | if self.threshold:
228 | lasso_weights[i] = F.threshold(lasso_weights[i], threshold=self.threshold, value=0, inplace=True)
229 |
230 | setattr(module, 'lasso_weights', lasso_weights)
231 |
232 | for weight in lasso_weights:
233 | self.n_element += weight.numel()
234 | self._loss = self._loss + torch.sum(torch.abs(weight))
235 |
236 | return self.apply_lasso(tucker_tensor, lasso_weights)
237 |
238 | def apply_lasso(self, tucker_tensor, lasso_weights):
239 | """Applies the lasso to a decomposed tensor
240 | """
241 | factors = tucker_tensor.factors
242 | factors = [factor*w for (factor, w) in zip(factors, lasso_weights)]
243 | return TuckerTensor(tucker_tensor.core, factors)
244 |
245 | def apply(self, module):
246 | """Apply an instance of the L1Regularizer to a tensor module
247 |
248 | Parameters
249 | ----------
250 | module : TensorModule
251 | module on which to add the regularization
252 |
253 | Returns
254 | -------
255 | TensorModule (with Regularization hook)
256 | """
257 | rank = module.rank
258 | context = tl.context(module.core)
259 | lasso_weights = ParameterList([nn.Parameter(torch.ones(r, **context)) for r in rank])
260 | setattr(module, 'lasso_weights', lasso_weights)
261 | module.register_forward_hook(self)
262 |
263 | return module
264 |
265 | def remove(self, module):
266 | delattr(module, 'lasso_weights')
267 |
268 | def set_weights(self, module, value):
269 | with torch.no_grad():
270 | for weight in module.lasso_weights:
271 | weight.data.fill_(value)
272 |
273 |
274 | class TTLasso(TensorLasso, factorization=TTTensor):
275 | """Decomposition Hook for Tensor Lasso on TT tensors
276 |
277 | Parameters
278 | ----------
279 | penalty : float, default is 0.01
280 | scaling factor for the loss
281 |
282 | clamp_weights : bool, default is True
283 | if True, the lasso weights are clamped between -1 and 1
284 |
285 | threshold : float, default is 1e-6
286 | if a lasso weight is lower than the set threshold, it is set to 0
287 |
288 | normalize_loss : bool, default is True
289 | If True, the loss will be between 0 and 1.
290 | Otherwise, the raw sum of absolute weights will be returned.
291 | """
292 | def __call__(self, module, input, tt_tensor):
293 | lasso_weights = getattr(module, 'lasso_weights')
294 | order = len(lasso_weights)
295 |
296 | with torch.no_grad():
297 | for i in range(order):
298 | if self.clamp_weights:
299 | lasso_weights[i].data = torch.clamp(lasso_weights[i].data, -1, 1)
300 |
301 | if self.threshold:
302 | lasso_weights[i] = F.threshold(lasso_weights[i], threshold=self.threshold, value=0, inplace=True)
303 |
304 | setattr(module, 'lasso_weights', lasso_weights)
305 |
306 | for weight in lasso_weights:
307 | self.n_element += weight.numel()
308 | self._loss = self._loss + torch.sum(torch.abs(weight))
309 |
310 | return self.apply_lasso(tt_tensor, lasso_weights)
311 |
312 | def apply_lasso(self, tt_tensor, lasso_weights):
313 | """Applies the lasso to a decomposed tensor
314 | """
315 | factors = tt_tensor.factors
316 | factors = [factor*w for (factor, w) in zip(factors, lasso_weights)] + [factors[-1]]
317 | return TTTensor(factors)
318 |
319 | def apply(self, module):
320 | """Apply an instance of the L1Regularizer to a tensor module
321 |
322 | Parameters
323 | ----------
324 | module : TensorModule
325 | module on which to add the regularization
326 |
327 | Returns
328 | -------
329 | TensorModule (with Regularization hook)
330 | """
331 | rank = module.rank[1:-1]
332 | lasso_weights = ParameterList([nn.Parameter(torch.ones(1, 1, r)) for r in rank])
333 | setattr(module, 'lasso_weights', lasso_weights)
334 | handle = module.register_forward_hook(self)
335 | return module
336 |
337 | def remove(self, module):
338 | """Remove the Regularization from a module.
339 | """
340 | delattr(module, 'lasso_weights')
341 |
342 | def set_weights(self, module, value):
343 | with torch.no_grad():
344 | for weight in module.lasso_weights:
345 | weight.data.fill_(value)
346 |
347 |
348 | def tensor_lasso(factorization='CP', penalty=0.01, clamp_weights=True, threshold=1e-6, normalize_loss=True):
349 | """Generalized Tensor Lasso from a factorized tensors
350 |
351 | Applies a generalized Lasso (l1 regularization) on a factorized tensor.
352 |
353 |
354 | Parameters
355 | ----------
356 | factorization : str, type of factorization to regularize, one of {'cp', 'tucker', 'tt'}
357 |
358 | penalty : float, default is 0.01
359 | scaling factor for the loss
360 |
361 | clamp_weights : bool, default is True
362 | if True, the lasso weights are clamped between -1 and 1
363 |
364 | threshold : float, default is 1e-6
365 | if a lasso weight is lower than the set threshold, it is set to 0
366 |
367 | normalize_loss : bool, default is True
368 | If True, the loss will be between 0 and 1.
369 | Otherwise, the raw sum of absolute weights will be returned.
370 |
371 | Examples
372 | --------
373 |
374 | Let's say you have a set of factorized (here, CP) tensors:
375 |
376 | >>> tensor = FactorizedTensor.new((3, 4, 2), rank='same', factorization='CP').normal_()
377 | >>> tensor2 = FactorizedTensor.new((5, 6, 7), rank=0.5, factorization='CP').normal_()
378 |
379 | First you need to create an instance of the regularizer:
380 |
381 | >>> regularizer = tensor_lasso(factorization='cp', penalty=0.01)
382 |
383 | You can apply the regularizer to one or several layers:
384 |
385 | >>> regularizer.apply(tensor)
386 | >>> regularizer.apply(tensor2)
387 |
388 | The lasso is automatically applied:
389 |
390 | >>> sum = torch.sum(tensor() + tensor2())
391 |
392 | You can access the Lasso loss from your instance:
393 |
394 | >>> l1_loss = regularizer.loss
395 |
396 | You can optimize and backpropagate through your loss as usual.
397 |
398 | After you finish updating the weights, don't forget to reset the regularizer,
399 | otherwise it will keep accumulating values!
400 |
401 | >>> regularizer.reset()
402 |
403 | You can also remove the regularizer with `regularizer.remove(tensor)`,
404 | or `remove_tensor_lasso(tensor)`.
405 | """
406 | factorization = factorization.lower()
407 | mapping = dict(cp='CPTensor', tucker='TuckerTensor', tt='TTTensor')
408 | return TensorLasso.from_factorization_name(mapping[factorization], penalty=penalty, clamp_weights=clamp_weights,
409 | threshold=threshold, normalize_loss=normalize_loss)
410 |
411 | def remove_tensor_lasso(factorized_tensor):
412 | """Removes the tensor lasso from a TensorModule
413 |
414 | Parameters
415 | ----------
416 | factorized_tensor : tltorch.FactorizedTensor
417 | the tensor module parametrized by the tensor decomposition from which to remove the tensor lasso
418 |
419 | Examples
420 | --------
421 | >>> tensor = FactorizedTensor.new((3, 4, 2), rank=0.5, factorization='CP').normal_()
422 | >>> regularizer = tensor_lasso('cp'); regularizer.apply(tensor)
423 | >>> remove_tensor_lasso(tensor)
424 | """
425 | for key, hook in factorized_tensor._forward_hooks.items():
426 | if isinstance(hook, TensorLasso):
427 | hook.remove(factorized_tensor)
428 | del factorized_tensor._forward_hooks[key]
429 | return factorized_tensor
430 |
431 | raise ValueError(f'TensorLasso not found in factorized tensor {factorized_tensor}')
432 |
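A sketch of a typical training step with the regularizer (not part of this file; model, loader, criterion and optimizer are assumed to exist and are purely illustrative):

from tltorch.tensor_hooks import tensor_lasso

regularizer = tensor_lasso('tucker', penalty=0.01)
regularizer.apply(model.weight)          # model.weight is assumed to be a Tucker FactorizedTensor

for x, y in loader:
    optimizer.zero_grad()
    pred = model(x)                      # the forward pass triggers the lasso hook
    loss = criterion(pred, y) + regularizer.loss
    loss.backward()
    optimizer.step()
    regularizer.reset()                  # reset the accumulated loss at every iteration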
--------------------------------------------------------------------------------
/tltorch/tensor_hooks/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/tltorch/tensor_hooks/tests/__init__.py
--------------------------------------------------------------------------------
/tltorch/tensor_hooks/tests/test_tensor_dropout.py:
--------------------------------------------------------------------------------
1 | from ... import FactorizedTensor
2 | from ...factorized_layers import TRL
3 | from .._tensor_dropout import tensor_dropout, remove_tensor_dropout
4 |
5 | import tensorly as tl
6 | tl.set_backend('pytorch')
7 |
8 | def test_tucker_dropout():
9 | """Test for Tucker Dropout"""
10 | shape = (10, 11, 12)
11 | rank = (7, 8, 9)
12 | tensor = FactorizedTensor.new(shape, rank=rank, factorization='Tucker')
13 | tensor = tensor_dropout(tensor, 0.999)
14 | core = tensor().core
15 | assert (tl.shape(core) == (1, 1, 1))
16 |
17 | remove_tensor_dropout(tensor)
18 | assert (not tensor._forward_hooks)
19 |
20 | tensor = tensor_dropout(tensor, 0)
21 | core = tensor().core
22 | assert (tl.shape(core) == rank)
23 |
24 | def test_cp_dropout():
25 | """Test for CP Dropout"""
26 | shape = (10, 11, 12)
27 | rank = 8
28 | tensor = FactorizedTensor.new(shape, rank=rank, factorization='CP')
29 | tensor = tensor_dropout(tensor, 0.999)
30 | weights = tensor().weights
31 | assert (len(weights) == (1))
32 |
33 | remove_tensor_dropout(tensor)
34 | assert (not tensor._forward_hooks)
35 |
36 | tensor = tensor_dropout(tensor, 0)
37 | weights = tensor().weights
38 | assert (len(weights) == rank)
39 |
40 |
41 | def test_tt_dropout():
42 | """Test for TT Dropout"""
43 | shape = (10, 11, 12)
44 | # Use the same rank for all factors
45 | rank = 4
46 | tensor = FactorizedTensor.new(shape, rank=rank, factorization='TT')
47 | tensor = tensor_dropout(tensor, 0.999)
48 | factors = tensor().factors
49 | for f in factors:
50 | assert (f.shape[0] == f.shape[-1] == 1)
51 |
52 | remove_tensor_dropout(tensor)
53 | assert (not tensor._forward_hooks)
54 |
55 | tensor = tensor_dropout(tensor, 0)
56 | factors = tensor().factors
57 | for i, f in enumerate(factors):
58 | if i:
59 | assert (f.shape[0] == rank)
60 | else: # boundary conditions: first and last rank are equal to 1
61 | assert (f.shape[-1] == rank)
62 |
--------------------------------------------------------------------------------
/tltorch/tensor_hooks/tests/test_tensor_lasso.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from .._tensor_lasso import TuckerLasso, TTLasso, CPLasso, remove_tensor_lasso, tensor_lasso
4 | from ... import FactorizedTensor
5 |
6 | import tensorly as tl
7 | tl.set_backend('pytorch')
8 | from tensorly import testing
9 | import numpy as np
10 | import torch
11 |
12 | @pytest.mark.parametrize('factorization', ['cp', 'tucker', 'tt'])
13 | def test_tensor_lasso(factorization):
14 | shape = (5, 5, 6)
15 | rank = 3
16 | tensor1 = FactorizedTensor.new(shape, rank, factorization=factorization).normal_()
17 | tensor2 = FactorizedTensor.new(shape, rank, factorization=factorization).normal_()
18 |
19 | lasso = tensor_lasso(factorization, penalty=1, clamp_weights=False, normalize_loss=False)
20 |
21 | lasso.apply(tensor1)
22 | lasso.apply(tensor2)
23 |
24 | # Sum of weights all equal to a given value
25 | value = 1.5
26 | lasso.set_weights(tensor1, value)
27 | lasso.set_weights(tensor2, value)
28 |
29 | data = torch.sum(tensor1())
30 | l1 = lasso.loss
31 | data = torch.sum(tensor2())
32 | l2 = lasso.loss
33 |
34 | # The result should be n-param * value
35 | # First tensor
36 | if factorization == 'tt':
37 | sum_rank = sum(tensor1.rank[1:-1])
38 | elif factorization == 'tucker':
39 | sum_rank = sum(tensor1.rank)
40 | elif factorization == 'cp':
41 | sum_rank = tensor1.rank
42 | testing.assert_(l1 == sum_rank*value)
43 | # Second tensor lasso
44 | if factorization == 'tt':
45 | sum_rank += sum(tensor2.rank[1:-1])
46 | elif factorization == 'tucker':
47 | sum_rank += sum(tensor2.rank)
48 | elif factorization == 'cp':
49 | sum_rank += tensor2.rank
50 | testing.assert_(l2 == sum_rank*value)
51 |
52 | testing.assert_(tensor1._forward_hooks)
53 | testing.assert_(tensor2._forward_hooks)
54 |
55 | ### Test when all weights are 0
56 | lasso.reset()
57 | lasso.set_weights(tensor1, 0)
58 | lasso.set_weights(tensor2, 0)
59 | torch.sum(tensor1()) + torch.sum(tensor2())
60 | testing.assert_(lasso.loss == 0)
61 |
62 | # Check the Lasso correctly removed
63 | remove_tensor_lasso(tensor1)
64 | testing.assert_(not tensor1._forward_hooks)
65 | testing.assert_(tensor2._forward_hooks)
66 | remove_tensor_lasso(tensor2)
67 | testing.assert_(not tensor1._forward_hooks)
68 |
69 | # Check normalization between 0 and 1
70 | lasso = tensor_lasso(factorization, penalty=1, normalize_loss=True, clamp_weights=False)
71 | lasso.apply(tensor1)
72 | tensor1()
73 | l1 = lasso.loss
74 | assert(abs(l1 - 1) < 1e-5)
75 | remove_tensor_lasso(tensor1)
--------------------------------------------------------------------------------
/tltorch/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .parameter_list import ParameterList, FactorList, ComplexFactorList
2 | from .tensorize_shape import get_tensorized_shape
--------------------------------------------------------------------------------
/tltorch/utils/parameter_list.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 |
4 |
5 | class FactorList(nn.Module):
6 | def __init__(self, parameters=None):
7 | super().__init__()
8 | self.keys = []
9 | self.counter = 0
10 | if parameters is not None:
11 | self.extend(parameters)
12 |
13 | def _unique_key(self):
14 | """Creates a new unique key"""
15 | key = f'factor_{self.counter}'
16 | self.counter += 1
17 | return key
18 |
19 | def append(self, element):
20 | key = self._unique_key()
21 | if torch.is_tensor(element):
22 | if isinstance(element, nn.Parameter):
23 | self.register_parameter(key, element)
24 | else:
25 | self.register_buffer(key, element)
26 | else:
27 | setattr(self, key, self.__class__(element))
28 | self.keys.append(key)
29 |
30 | def insert(self, index, element):
31 | key = self._unique_key()
32 | setattr(self, key, element)
33 | self.keys.insert(index, key)
34 |
35 | def pop(self, index=-1):
36 | item = self[index]
37 | self.__delitem__(index)
38 | return item
39 |
40 | def __getitem__(self, index):
41 | keys = self.keys[index]
42 | if isinstance(keys, list):
43 | return self.__class__([getattr(self, key) for key in keys])
44 | return getattr(self, keys)
45 |
46 | def __setitem__(self, index, value):
47 | setattr(self, self.keys[index], value)
48 |
49 | def __delitem__(self, index):
50 | delattr(self, self.keys[index])
51 | self.keys.__delitem__(index)
52 |
53 | def __len__(self):
54 | return len(self.keys)
55 |
56 | def extend(self, parameters):
57 | for param in parameters:
58 | self.append(param)
59 |
60 | def __iadd__(self, parameters):
61 | return self.extend(parameters)
62 |
63 | def __add__(self, parameters):
64 | instance = self.__class__(self)
65 | instance.extend(parameters)
66 | return instance
67 |
68 | def __radd__(self, parameters):
69 | instance = self.__class__(parameters)
70 | instance.extend(self)
71 | return instance
72 |
73 | def extra_repr(self) -> str:
74 | child_lines = []
75 | for k, p in self._parameters.items():
76 | size_str = 'x'.join(str(size) for size in p.size())
77 | device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
78 | parastr = 'Parameter containing: [{} of size {}{}]'.format(
79 | torch.typename(p), size_str, device_str)
80 | child_lines.append(' (' + str(k) + '): ' + parastr)
81 | tmpstr = '\n'.join(child_lines)
82 | return tmpstr
83 |
84 |
85 | class ComplexFactorList(FactorList):
86 | def __getitem__(self, index):
87 | if isinstance(index, int):
88 | value = getattr(self, self.keys[index])
89 | if torch.is_tensor(value):
90 | value = torch.view_as_complex(value)
91 | return value
92 | else:
93 | keys = self.keys[index]
94 | return self.__class__([torch.view_as_complex(getattr(self, key)) for key in keys])
95 |
96 | def __setitem__(self, index, value):
97 | if torch.is_tensor(value):
98 | value = torch.view_as_real(value)
99 | setattr(self, self.keys[index], value)
100 |
101 | def register_parameter(self, key, value):
102 | value = nn.Parameter(torch.view_as_real(value))
103 | super().register_parameter(key, value)
104 |
105 | def register_buffer(self, key, value):
106 | value = torch.view_as_real(value)
107 | super().register_buffer(key, value)
108 |
109 |
110 | class ParameterList(nn.Module):
111 | def __init__(self, parameters=None):
112 | super().__init__()
113 | self.keys = []
114 | self.counter = 0
115 | if parameters is not None:
116 | self.extend(parameters)
117 |
118 | def _unique_key(self):
119 | """Creates a new unique key"""
120 | key = f'param_{self.counter}'
121 | self.counter += 1
122 | return key
123 |
124 | def append(self, element):
125 | # p = nn.Parameter(element)
126 | key = self._unique_key()
127 | self.register_parameter(key, element)
128 | self.keys.append(key)
129 |
130 | def insert(self, index, element):
131 | # p = nn.Parameter(element)
132 | key = self._unique_key()
133 | self.register_parameter(key, element)
134 | self.keys.insert(index, key)
135 |
136 | def pop(self, index=-1):
137 | item = self[index]
138 | self.__delitem__(index)
139 | return item
140 |
141 | def __getitem__(self, index):
142 | keys = self.keys[index]
143 | if isinstance(keys, list):
144 | return self.__class__([getattr(self, key) for key in keys])
145 | return getattr(self, keys)
146 |
147 | def __setitem__(self, index, value):
148 | self.register_parameter(self.keys[index], value)
149 |
150 | def __delitem__(self, index):
151 | delattr(self, self.keys[index])
152 | self.keys.__delitem__(index)
153 |
154 | def __len__(self):
155 | return len(self.keys)
156 |
157 | def extend(self, parameters):
158 | for param in parameters:
159 | self.append(param)
160 |
161 | def __iadd__(self, parameters):
162 | return self.extend(parameters)
163 |
164 | def extra_repr(self) -> str:
165 | child_lines = []
166 | for k, p in self._parameters.items():
167 | size_str = 'x'.join(str(size) for size in p.size())
168 | device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
169 | parastr = 'Parameter containing: [{} of size {}{}]'.format(
170 | torch.typename(p), size_str, device_str)
171 | child_lines.append(' (' + str(k) + '): ' + parastr)
172 | tmpstr = '\n'.join(child_lines)
173 | return tmpstr
174 |
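A quick sketch of how FactorList behaves (not part of this file): parameters and plain tensors are registered under generated keys, and slicing returns a new list.

import torch
from torch import nn
from tltorch.utils import FactorList

factors = FactorList([nn.Parameter(torch.randn(5, 3)), torch.zeros(4, 3)])
print(len(factors))          # 2: one registered Parameter and one registered buffer
print(factors[0].shape)      # torch.Size([5, 3])

factors.append(nn.Parameter(torch.randn(2, 3)))
sub = factors[1:]            # slicing returns a new FactorList with the selected entries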
--------------------------------------------------------------------------------
/tltorch/utils/tensorize_shape.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | from bisect import insort_left
4 |
5 | # Author : Jean Kossaifi
6 |
7 | def factorize(value, min_value=2, remaining=-1):
8 | """Factorize an integer input value into it's smallest divisors
9 |
10 | Parameters
11 | ----------
12 | value : int
13 | integer to factorize
14 | min_value : int, default is 2
15 | smallest divisors to use
16 | remaining : int, default is -1
17 | DO NOT SPECIFY THIS VALUE, IT IS USED FOR TAIL RECURSION
18 |
19 | Returns
20 | -------
21 | factorization : int tuple
22 | ints such that prod(factorization) == value
23 | """
24 | if value <= min_value or remaining == 0:
25 | return (value, )
26 | lim = math.isqrt(value)
27 | for i in range(min_value, lim+1):
28 | if value == i:
29 | return (i, )
30 | if not (value % i):
31 | return (i, *factorize(value//i, min_value=min_value, remaining=remaining-1))
32 | return (value, )
33 |
34 | def merge_ints(values, size):
35 | """Utility function to merge the smallest values in a given tuple until it's length is the given size
36 |
37 | Parameters
38 | ----------
39 | values : int list
40 | list of values to merge
41 | size : int
42 | target len of the list
43 | stop merging when len(values) <= size
44 |
45 | Returns
46 | -------
47 | merge_values : list of size ``size``
48 | """
49 | if len(values) <= 1:
50 | return values
51 |
52 | values = sorted(list(values))
53 | while (len(values) > size):
54 | a, b, *values = values
55 | insort_left(values, a*b)
56 |
57 | return tuple(values)
58 |
59 | def get_tensorized_shape(in_features, out_features, order=None, min_dim=2, verbose=True):
60 | """ Factorizes in_features and out_features such that:
61 | * both are factorized into the same number of integers
62 | * if `order` is given, both are factorized into at most `order` integers
63 | * each of the factors is at least `min_dim`
64 |
65 | This is used to tensorize a matrix of size (in_features, out_features) into a higher order tensor
66 |
67 | Parameters
68 | ----------
69 | in_features, out_features : int
70 | order : int
71 | the number of integers that each input should be factorized into
72 | min_dim : int
73 | smallest acceptable integer value for the factors
74 |
75 | Returns
76 | -------
77 | in_tensorized, out_tensorized : tuple[int]
78 | tuples of ints used to tensorize each dimension
79 |
80 | Notes
81 | -----
82 | This is a brute-force solution, but it is enough for the dimensions typically encountered in DNNs
83 | """
84 | in_ten = factorize(in_features, min_value=min_dim)
85 | out_ten = factorize(out_features, min_value=min_dim, remaining=len(in_ten))
86 | if order is not None:
87 | merge_size = min(order, len(in_ten), len(out_ten))
88 | else:
89 | merge_size = min(len(in_ten), len(out_ten))
90 |
91 | if len(in_ten) > merge_size:
92 | in_ten = merge_ints(in_ten, size=merge_size)
93 | if len(out_ten) > merge_size:
94 | out_ten = merge_ints(out_ten, size=merge_size)
95 |
96 | if verbose:
97 | print(f'Tensorizing (in, out)=({in_features}, {out_features}) -> ({in_ten}, {out_ten})')
98 | return in_ten, out_ten
99 |
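For example (a sketch, not part of this file), tensorizing a linear layer with in_features=512 and out_features=384 into third-order factors gives roughly balanced factorizations whose products recover the original dimensions:

import math
from tltorch.utils import get_tensorized_shape

in_ten, out_ten = get_tensorized_shape(512, 384, order=3, verbose=False)
print(in_ten, out_ten)       # expected: (4, 8, 16) (4, 6, 16)
assert math.prod(in_ten) == 512 and math.prod(out_ten) == 384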
--------------------------------------------------------------------------------