├── .github └── workflows │ ├── build_doc.yml │ ├── deploy_pypi.yml │ └── test.yml ├── .gitignore ├── LICENSE ├── README.rst ├── doc ├── Makefile ├── _static │ ├── FLOP.pdf │ ├── TRL.png │ ├── cp-conv.png │ ├── deep_tensor_nets_pros_circle.png │ ├── domain-adaptation.png │ ├── favicon │ │ ├── android-chrome-192x192.png │ │ ├── android-chrome-256x256.png │ │ ├── apple-touch-icon.png │ │ ├── favicon-16x16.png │ │ ├── favicon-32x32.png │ │ ├── favicon.ico │ │ ├── safari-pinned-tab.svg │ │ └── site.webmanifest │ ├── logos │ │ ├── logo_caltech.png │ │ ├── logo_nvidia.png │ │ └── tensorly-torch-logo.png │ ├── mobilenet-conv.pdf │ ├── mobilenet-v2-conv.pdf │ ├── tensorly-torch-pyramid.png │ ├── transduction.png │ └── tucker-conv.pdf ├── _templates │ ├── class.rst │ └── function.rst ├── about.rst ├── conf.py ├── dev_guide │ ├── api.rst │ ├── contributing.rst │ └── index.rst ├── index.rst ├── install.rst ├── make.bat ├── minify.py ├── modules │ └── api.rst ├── requirements_doc.txt └── user_guide │ ├── factorized_conv.rst │ ├── factorized_embeddings.rst │ ├── factorized_tensors.rst │ ├── index.rst │ ├── tensor_hooks.rst │ ├── tensorized_linear.rst │ └── trl.rst ├── requirements.txt ├── setup.py └── tltorch ├── __init__.py ├── factorized_layers ├── __init__.py ├── factorized_convolution.py ├── factorized_embedding.py ├── factorized_linear.py ├── tensor_contraction_layers.py ├── tensor_regression_layers.py └── tests │ ├── __init__.py │ ├── test_factorized_convolution.py │ ├── test_factorized_embedding.py │ ├── test_factorized_linear.py │ ├── test_tensor_contraction_layers.py │ └── test_trl.py ├── factorized_tensors ├── __init__.py ├── complex_factorized_tensors.py ├── complex_tensorized_matrices.py ├── core.py ├── factorized_tensors.py ├── init.py ├── tensorized_matrices.py └── tests │ ├── __init__.py │ └── test_factorizations.py ├── functional ├── __init__.py ├── convolution.py ├── factorized_linear.py ├── factorized_tensordot.py ├── linear.py ├── tensor_regression.py └── tests │ ├── __init__.py │ └── test_factorized_linear.py ├── tensor_hooks ├── __init__.py ├── _tensor_dropout.py ├── _tensor_lasso.py └── tests │ ├── __init__.py │ ├── test_tensor_dropout.py │ └── test_tensor_lasso.py └── utils ├── __init__.py ├── parameter_list.py └── tensorize_shape.py /.github/workflows/build_doc.yml: -------------------------------------------------------------------------------- 1 | name: Build documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build: 10 | 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v3 16 | - name: Install Python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: 3.9 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | python -m pip install -r requirements.txt 24 | python -m pip install -r doc/requirements_doc.txt 25 | python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu 26 | - name: Install package 27 | run: | 28 | python -m pip install -e . 29 | - name: Make doc 30 | run: | 31 | cd doc 32 | python minify.py 33 | make html 34 | cd .. 
35 | - name: Push docs 36 | run: | 37 | # See https://github.community/t/github-actions-bot-email-address/17204/5 38 | git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com" 39 | git config --global user.name "github-actions" 40 | git fetch origin gh-pages 41 | git checkout gh-pages 42 | git rm -r dev/* 43 | cp -r doc/_build/html/* dev 44 | git add dev 45 | # If the doc is up to date, the script shouldn't fail, hence --allow-empty 46 | # Might be a cleaner way to check 47 | git commit --allow-empty -m "Deployed to GitHub Pages" 48 | git push --force origin gh-pages 49 | -------------------------------------------------------------------------------- /.github/workflows/deploy_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Deploy to Pypi 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' 7 | 8 | jobs: 9 | build: 10 | 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v3 16 | - name: Install Python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: 3.9 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | python -m pip install -r requirements.txt 24 | python -m pip install -r doc/requirements_doc.txt 25 | python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu 26 | - name: Install package 27 | run: | 28 | python -m pip install -e . 29 | pip install setuptools wheel 30 | - name: Build a binary wheel and a source tarball 31 | run: | 32 | python setup.py sdist bdist_wheel 33 | - name: Publish package to TestPyPI 34 | uses: pypa/gh-action-pypi-publish@master 35 | with: 36 | user: __token__ 37 | password: ${{ secrets.TEST_PYPI_PASSWORD }} 38 | repository_url: https://test.pypi.org/legacy/ 39 | - name: Publish package to PyPI 40 | uses: pypa/gh-action-pypi-publish@master 41 | with: 42 | user: __token__ 43 | password: ${{ secrets.PYPI_PASSWORD }} 44 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test TensorLy-Torch 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v3 12 | - name: Set up Python 3.12 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: 3.12 16 | - name: Install dependencies 17 | run: | 18 | echo "* Updating pip" 19 | python -m pip install --upgrade pip 20 | echo "* Installing requirements" 21 | python -m pip install -r requirements.txt 22 | echo "* Installing documentation requirements" 23 | python -m pip install -r doc/requirements_doc.txt 24 | echo "* Installing PyTorch" 25 | python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu 26 | - name: Install TensorLy dev 27 | run: | 28 | CWD=`pwd` 29 | echo 'Cloning TensorLy in ${CWD}' 30 | mkdir git_repos 31 | cd git_repos 32 | git clone https://github.com/tensorly/tensorly 33 | cd tensorly 34 | python -m pip install -e . 35 | cd .. 36 | - name: Install package 37 | run: | 38 | python -m pip install -e . 
39 | - name: Test with pytest 40 | run: | 41 | pytest -vvv tltorch 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.DS_Store 6 | *.vscode 7 | 8 | # C extensions 9 | *.so 10 | *.py~ 11 | 12 | # Pycharm 13 | .idea 14 | 15 | # vim temp files 16 | *.swp 17 | 18 | # Sphinx doc 19 | doc/_build/ 20 | doc/auto_examples/ 21 | doc/modules/generated/ 22 | 23 | # Distribution / packaging 24 | .Python 25 | env/ 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | .pytest_cache/ 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *,cover 61 | .hypothesis/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | #Ipython Notebook 77 | .ipynb_checkpoints 78 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, The TensorLy-Torch Developers 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | .. 
image:: https://badge.fury.io/py/tensorly-torch.svg 2 | :target: https://badge.fury.io/py/tensorly-torch 3 | 4 | 5 | ============== 6 | TensorLy-Torch 7 | ============== 8 | 9 | TensorLy-Torch is a Python library for deep tensor networks that 10 | builds on top of `TensorLy `_ 11 | and `PyTorch `_. 12 | It allows to easily leverage tensor methods in a deep learning setting and comes with all batteries included. 13 | 14 | - **Website:** http://tensorly.org/torch/ 15 | - **Source-code:** https://github.com/tensorly/torch 16 | 17 | 18 | With TensorLy-Torch, you can easily: 19 | 20 | - **Tensor Factorizations**: decomposing, manipulating and initializing tensor decompositions can be tricky. We take care of it all, in a convenient, unified API. 21 | - **Leverage structure in your data**: with tensor layers, you can easily leverage the structure in your data, through Tensor Regression Layers, Factorized Convolutions, etc 22 | - **Built-in tensor layers**: all you have to do is import tensorly torch and include the layers we provide directly within your PyTorch models! 23 | - **Tensor hooks**: you can easily augment your architectures with our built-in Tensor Hooks. Robustify your network with Tensor Dropout and automatically select the rank end-to-end with L1 Regularization! 24 | - **All the methods available**: we are always adding more methods to make it easy to compare between the performance of various deep tensor based methods! 25 | 26 | Deep Tensorized Learning 27 | ======================== 28 | 29 | Tensor methods generalize matrix algebraic operations to higher-orders. Deep neural networks typically map between higher-order tensors. 30 | In fact, it is the ability of deep convolutional neural networks to preserve and leverage local structure that, along with large datasets and efficient hardware, made the current levels of performance possible. 31 | Tensor methods allow to further leverage and preserve that structure, for individual layers or whole networks. 32 | 33 | .. image:: ./doc/_static/tensorly-torch-pyramid.png 34 | 35 | TensorLy is a Python library that aims at making tensor learning simple and accessible. 36 | It provides a high-level API for tensor methods, including core tensor operations, tensor decomposition and regression. 37 | It has a flexible backend that allows running operations seamlessly using NumPy, PyTorch, TensorFlow, JAX, MXNet and CuPy. 38 | 39 | **TensorLy-Torch** is a PyTorch only library that builds on top of TensorLy and provides out-of-the-box tensor layers. 40 | 41 | Improve your neural networks with tensor methods 42 | ------------------------------------------------ 43 | 44 | Tensor methods generalize matrix algebraic operations to higher-orders. Deep neural networks typically map between higher-order tensors. 45 | In fact, it is the ability of deep convolutional neural networks to preserve and leverage local structure that, along with large datasets and efficient hardware, made the current levels of performance possible. 46 | Tensor methods allow to further leverage and preserve that structure, for individual layers or whole networks. 47 | 48 | In TensorLy-Torch, we provide convenient layers that do all the heavy lifting for you 49 | and provide the benefits tensor based layers wrapped in a nice, well documented and tested API. 50 | 51 | For instance, convolution layers of any order (2D, 3D or more), can be efficiently parametrized 52 | using tensor decomposition. 
Using a CP decomposition results in a separable convolution 53 | and you can replace your original convolution with a series of small efficient ones: 54 | 55 | .. image:: ./doc/_static/cp-conv.png 56 | 57 | These can be easily perform with FactorizedConv in TensorLy-Torch. 58 | We also have Tucker convolutions and new tensor-train convolutions! 59 | We also implement various other methods such as tensor regression and contraction layers, 60 | tensorized linear layers, tensor dropout and more! 61 | 62 | 63 | Installing TensorLy-Torch 64 | ========================= 65 | 66 | Through pip 67 | ----------- 68 | 69 | .. code:: 70 | 71 | pip install tensorly-torch 72 | 73 | 74 | From source 75 | ----------- 76 | 77 | .. code:: 78 | 79 | git clone https://github.com/tensorly/torch 80 | cd torch 81 | pip install -e . 82 | 83 | 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | .PHONY: debug 23 | debug: 24 | @$(SPHINXBUILD) -M "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -vv 25 | 26 | -------------------------------------------------------------------------------- /doc/_static/FLOP.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/FLOP.pdf -------------------------------------------------------------------------------- /doc/_static/TRL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/TRL.png -------------------------------------------------------------------------------- /doc/_static/cp-conv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/cp-conv.png -------------------------------------------------------------------------------- /doc/_static/deep_tensor_nets_pros_circle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/deep_tensor_nets_pros_circle.png -------------------------------------------------------------------------------- /doc/_static/domain-adaptation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/domain-adaptation.png -------------------------------------------------------------------------------- /doc/_static/favicon/android-chrome-192x192.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/favicon/android-chrome-192x192.png -------------------------------------------------------------------------------- /doc/_static/favicon/android-chrome-256x256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/favicon/android-chrome-256x256.png -------------------------------------------------------------------------------- /doc/_static/favicon/apple-touch-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/favicon/apple-touch-icon.png -------------------------------------------------------------------------------- /doc/_static/favicon/favicon-16x16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/favicon/favicon-16x16.png -------------------------------------------------------------------------------- /doc/_static/favicon/favicon-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/favicon/favicon-32x32.png -------------------------------------------------------------------------------- /doc/_static/favicon/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/favicon/favicon.ico -------------------------------------------------------------------------------- /doc/_static/favicon/safari-pinned-tab.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 7 | 8 | Created by potrace 1.11, written by Peter Selinger 2001-2013 9 | 10 | 12 | 17 | 20 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /doc/_static/favicon/site.webmanifest: -------------------------------------------------------------------------------- 1 | { 2 | "name": "", 3 | "short_name": "", 4 | "icons": [ 5 | { 6 | "src": "/android-chrome-192x192.png", 7 | "sizes": "192x192", 8 | "type": "image/png" 9 | }, 10 | { 11 | "src": "/android-chrome-256x256.png", 12 | "sizes": "256x256", 13 | "type": "image/png" 14 | } 15 | ], 16 | "theme_color": "#ffffff", 17 | "background_color": "#ffffff", 18 | "display": "standalone" 19 | } 20 | -------------------------------------------------------------------------------- /doc/_static/logos/logo_caltech.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/logos/logo_caltech.png -------------------------------------------------------------------------------- /doc/_static/logos/logo_nvidia.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/logos/logo_nvidia.png -------------------------------------------------------------------------------- 
/doc/_static/logos/tensorly-torch-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/logos/tensorly-torch-logo.png -------------------------------------------------------------------------------- /doc/_static/mobilenet-conv.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/mobilenet-conv.pdf -------------------------------------------------------------------------------- /doc/_static/mobilenet-v2-conv.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/mobilenet-v2-conv.pdf -------------------------------------------------------------------------------- /doc/_static/tensorly-torch-pyramid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/tensorly-torch-pyramid.png -------------------------------------------------------------------------------- /doc/_static/transduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/transduction.png -------------------------------------------------------------------------------- /doc/_static/tucker-conv.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/doc/_static/tucker-conv.pdf -------------------------------------------------------------------------------- /doc/_templates/class.rst: -------------------------------------------------------------------------------- 1 | :mod:`{{module}}`.{{objname}} 2 | {{ underline }}============== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autoclass:: {{ objname }} 7 | :members: 8 | 9 | .. raw:: html 10 | 11 |
12 | -------------------------------------------------------------------------------- /doc/_templates/function.rst: -------------------------------------------------------------------------------- 1 | :mod:`{{module}}`.{{objname}} 2 | {{ underline }}==================== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autofunction:: {{ objname }} 7 | 8 | .. raw:: html 9 | 10 |
11 | -------------------------------------------------------------------------------- /doc/about.rst: -------------------------------------------------------------------------------- 1 | .. _about_us: 2 | 3 | About Us 4 | ======== 5 | 6 | TensorLy-Torch is an open-source effort, led primarily at NVIDIA Research by `Jean Kossaifi`_. 7 | It is part of the TensorLy project and builds on top of the core TensorLy in order to provide out-of-the box PyTorch tensor layers for deep learning. 8 | 9 | 10 | Core team 11 | --------- 12 | 13 | * `Anima Anandkumar`_ 14 | * `Wonmin Byeon`_ 15 | * `Jean Kossaifi`_ 16 | * `Saurav Muralidharan`_ 17 | 18 | Supporters 19 | ---------- 20 | 21 | The TensorLy project is supported by: 22 | 23 | .. image:: _static/logos/logo_nvidia.png 24 | :width: 150pt 25 | :align: center 26 | :target: https://www.nvidia.com 27 | :alt: NVIDIA 28 | 29 | ........ 30 | 31 | .. image:: _static/logos/logo_caltech.png 32 | :width: 150pt 33 | :align: center 34 | :target: https://www.caltech.edu 35 | :alt: California Institute of Technology 36 | 37 | 38 | .. _Jean Kossaifi: http://jeankossaifi.com/ 39 | .. _Anima Anandkumar: http://tensorlab.cms.caltech.edu/users/anima/ 40 | .. _Wonmin Byeon: https://wonmin-byeon.github.io/ 41 | .. _Saurav Muralidharan: https://www.sauravm.com -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('..')) 16 | # sys.path.insert(0, os.path.abspath('../examples')) 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'TensorLy-Torch' 21 | from datetime import datetime 22 | year = datetime.now().year 23 | copyright = f'{year}, Jean Kossaifi' 24 | author = 'Jean Kossaifi' 25 | 26 | # The full version, including alpha/beta/rc tags 27 | import tltorch 28 | release = tltorch.__version__ 29 | 30 | 31 | # -- General configuration --------------------------------------------------- 32 | 33 | # Add any Sphinx extension module names here, as strings. They can be 34 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 35 | # ones. 
36 | extensions = [ 37 | 'sphinx.ext.autodoc', 38 | 'sphinx.ext.autosummary', 39 | 'sphinx.ext.todo', 40 | 'sphinx.ext.viewcode', 41 | 'sphinx.ext.githubpages', 42 | # "nbsphinx", 43 | # "myst_nb", 44 | # "sphinx_nbexamples", 45 | # 'jupyter_sphinx', 46 | # 'matplotlib.sphinxext.plot_directive', 47 | 'sphinx.ext.mathjax', #'sphinx.ext.imgmath', 48 | 'numpydoc.numpydoc', 49 | ] 50 | 51 | # # # Sphinx-nbexamples 52 | # process_examples = False 53 | # example_gallery_config = dict(pattern='+/+.ipynb') 54 | 55 | # Remove the permalinks ("¶" symbols) 56 | html_add_permalinks = "" 57 | 58 | # NumPy 59 | numpydoc_class_members_toctree = False 60 | numpydoc_show_class_members = True 61 | numpydoc_show_inherited_class_members = False 62 | 63 | # generate autosummary even if no references 64 | autosummary_generate = True 65 | autodoc_member_order = 'bysource' 66 | autodoc_default_flags = ['members'] 67 | 68 | # Napoleon 69 | napoleon_google_docstring = False 70 | napoleon_use_rtype = False 71 | 72 | # imgmath/mathjax 73 | imgmath_image_format = 'svg' 74 | 75 | # The suffix(es) of source filenames. 76 | # You can specify multiple suffix as a list of string: 77 | # source_suffix = ['.rst', '.md'] 78 | source_suffix = '.rst' 79 | 80 | 81 | # The master toctree document. 82 | master_doc = 'index' 83 | 84 | # Add any paths that contain templates here, relative to this directory. 85 | templates_path = ['_templates'] 86 | 87 | # List of patterns, relative to source directory, that match files and 88 | # directories to ignore when looking for source files. 89 | # This pattern also affects html_static_path and html_extra_path. 90 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 91 | 92 | 93 | primary_domain = 'py' 94 | 95 | # -- Options for HTML output ------------------------------------------------- 96 | 97 | # The theme to use for HTML and HTML Help pages. See the documentation for 98 | # a list of builtin themes. 99 | html_theme = 'tensorly_sphinx_theme' 100 | html_logo = '_static/logos/tensorly-torch-logo.png' 101 | 102 | # Add any paths that contain custom static files (such as style sheets) here, 103 | # relative to this directory. They are copied after the builtin static files, 104 | # so a file named "default.css" will overwrite the builtin "default.css". 105 | html_static_path = ['_static'] 106 | 107 | html_theme_options = { 108 | 'github_url': 'https://github.com/tensorly/torch', 109 | 'google_analytics' : 'G-QSPLEF75VT', 110 | 'searchbar_text': 'Search TensorLy-Torch', 111 | 'nav_links' : [('Install', 'install'), 112 | ('User Guide', 'user_guide/index'), 113 | ('API', 'modules/api'), 114 | ('About Us', 'about')], 115 | 'nav_dropdowns' : [('Ecosystem', 116 | [('TensorLy', 'http://tensorly.org/dev'), 117 | ('TensorLy-Viz', 'http://tensorly.org/viz'), 118 | ('TensorLy-Quantum', 'http://tensorly.org/quantum'), 119 | ]), 120 | ], 121 | } 122 | 123 | # -- Options for LaTeX output --------------------------------------------- 124 | 125 | # Grouping the document tree into LaTeX files. List of tuples 126 | # (source start file, target name, title, 127 | # author, documentclass [howto, manual, or own class]). 
128 | latex_documents = [ 129 | (master_doc, 'tensorly.tex', 'Tensor operations in Python', 130 | 'Jean Kossaifi', 'manual'), 131 | ] 132 | 133 | latex_preamble = r""" 134 | \usepackage{amsmath}\usepackage{amsfonts} 135 | \setcounter{MaxMatrixCols}{20} 136 | """ 137 | 138 | imgmath_latex_preamble = latex_preamble 139 | 140 | latex_elements = { 141 | 'classoptions': ',oneside', 142 | 'babel': '\\usepackage[english]{babel}', 143 | # Get completely rid of index 144 | 'printindex': '', 145 | 'preamble': latex_preamble, 146 | } 147 | -------------------------------------------------------------------------------- /doc/dev_guide/api.rst: -------------------------------------------------------------------------------- 1 | ============= 2 | Style and API 3 | ============= 4 | 5 | In TensorLy-Torch (and more generally in the TensorLy project), 6 | we try to maintain a simple and consistent API. 7 | 8 | Here are some elements to consider *Coming Soon* -------------------------------------------------------------------------------- /doc/dev_guide/contributing.rst: -------------------------------------------------------------------------------- 1 | Contributing 2 | ============ 3 | 4 | We actively welcome new contributions! 5 | If you know of a tensor method that should be in TensorLy-Torch 6 | or if you spot any mistake, typo, missing documentation etc, 7 | please report it, and even better, open a Pull-Request! 8 | 9 | How-to 10 | ------ 11 | 12 | To make sure the contribution is relevant and is not already worked on, 13 | you can `open an issue `_! 14 | 15 | To add code of fix issues in TensorLy, 16 | you will want to open a `Pull-Request `_ 17 | on the Github repository of the project. 18 | 19 | Guidelines 20 | ---------- 21 | 22 | For each function or class, we expect helpful docstrings in the NumPy format, 23 | as well as unit-tests to make sure it is working as expected 24 | (especially helpful for future refactoring to make sure no exising code is broken!) 25 | 26 | Check the existing code for examples, 27 | and don't hesitate to contact us if you are unsure or have any question! 28 | 29 | -------------------------------------------------------------------------------- /doc/dev_guide/index.rst: -------------------------------------------------------------------------------- 1 | .. _dev_guide: 2 | 3 | ================= 4 | Development guide 5 | ================= 6 | 7 | .. toctree:: 8 | 9 | contributing.rst 10 | api.rst 11 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | :no-toc: 2 | :no-localtoc: 3 | :no-pagination: 4 | 5 | .. TensorLy-Torch documentation 6 | 7 | .. only:: html 8 | 9 | .. raw:: html 10 | 11 |

12 | 13 | .. image:: _static/logos/tensorly-torch-logo.png 14 | :align: center 15 | :width: 500 16 | 17 | 18 | .. only:: html 19 | 20 | .. raw:: html 21 | 22 |
23 |          Tensor Batteries Included
24 | 
25 | 
26 | 27 | 28 | **TensorLy-Torch** is a PyTorch only library that builds on top of `TensorLy `_ and provides out-of-the-box tensor layers. 29 | It comes with all batteries included and tries to make it as easy as possible to use tensor methods within your deep networks. 30 | 31 | - **Leverage structure in your data**: with tensor layers, you can easily leverage the structure in your data, through :ref:`TRL `, :ref:`TCL `, :ref:`Factorized convolutions ` and more! 32 | - **Factorized tensors** as first class citizens: you can transparently directly create, manipulate and index factorized tensors and regular (dense) pytorch tensors alike! 33 | - **Built-in tensor layers**: all you have to do is import tensorly torch and include the layers we provide directly within your PyTorch models! 34 | - **Initialization**: initializing tensor decompositions can be tricky. We take care of it all, whether you want to initialize randomly using our :ref:`init_ref` module or from a pretrained layer. 35 | - **Tensor hooks**: you can easily augment your architectures with our built-in :mod:`Tensor Hooks `. Robustify your network with Tensor Dropout and automatically select the rank end-to-end with L1 Regularization! 36 | - **All the methods available**: we are always adding more methods to make it easy to compare between the performance of various deep tensor based methods! 37 | 38 | Deep Tensorized Learning 39 | ======================== 40 | 41 | Tensor methods generalize matrix algebraic operations to higher-orders. Deep neural networks typically map between higher-order tensors. 42 | In fact, it is the ability of deep convolutional neural networks to preserve and leverage local structure that, along with large datasets and efficient hardware, made the current levels of performance possible. 43 | Tensor methods allow to further leverage and preserve that structure, for individual layers or whole networks. 44 | 45 | .. image:: _static/tensorly-torch-pyramid.png 46 | :align: center 47 | :width: 800 48 | 49 | TensorLy is a Python library that aims at making tensor learning simple and accessible. 50 | It provides a high-level API for tensor methods, including core tensor operations, tensor decomposition and regression. 51 | It has a flexible backend that allows running operations seamlessly using NumPy, PyTorch, TensorFlow, JAX, MXNet and CuPy. 52 | 53 | **TensorLy-Torch** is a PyTorch only library that builds on top of TensorLy and provides out-of-the-box tensor layers. 54 | 55 | Improve your neural networks with tensor methods 56 | ------------------------------------------------ 57 | 58 | Tensor methods generalize matrix algebraic operations to higher-orders. Deep neural networks typically map between higher-order tensors. 59 | In fact, it is the ability of deep convolutional neural networks to preserve and leverage local structure that, along with large datasets and efficient hardware, made the current levels of performance possible. 60 | Tensor methods allow to further leverage and preserve that structure, for individual layers or whole networks. 61 | 62 | .. image:: _static/deep_tensor_nets_pros_circle.png 63 | :align: center 64 | :width: 350 65 | 66 | 67 | In TensorLy-Torch, we provide convenient layers that do all the heavy lifting for you 68 | and provide the benefits tensor based layers wrapped in a nice, well documented and tested API. 69 | 70 | For instance, convolution layers of any order (2D, 3D or more), can be efficiently parametrized 71 | using tensor decomposition. 
Using a CP decomposition results in a separable convolution 72 | and you can replace your original convolution with a series of small efficient ones: 73 | 74 | .. image:: _static/cp-conv.png 75 | :width: 500 76 | :align: center 77 | 78 | These can be easily perform with the :ref:`factorized_conv_ref` module in TensorLy-Torch. 79 | We also have Tucker convolutions and new tensor-train convolutions! 80 | We also implement various other methods such as tensor regression and contraction layers, 81 | tensorized linear layers, tensor dropout and more! 82 | 83 | 84 | .. toctree:: 85 | :maxdepth: 1 86 | :hidden: 87 | 88 | install 89 | modules/api 90 | user_guide/index 91 | dev_guide/index 92 | about 93 | 94 | 95 | .. only:: html 96 | 97 | .. raw:: html 98 | 99 |

101 | 102 | 107 | 108 | -------------------------------------------------------------------------------- /doc/install.rst: -------------------------------------------------------------------------------- 1 | ========================= 2 | Installing tensorly-Torch 3 | ========================= 4 | 5 | 6 | Pre-requisite 7 | ============= 8 | 9 | You will need to have Python 3 installed, as well as NumPy, Scipy and `TensorLy `_. 10 | If you are starting with Python or generally want a pain-free experience, I recommend you install the `Anaconda distribiution `_. It comes with all you need shipped-in and ready to use! 11 | 12 | 13 | Installing with pip (recommended) 14 | ================================= 15 | 16 | 17 | Simply run, in your terminal:: 18 | 19 | pip install -U tensorly-torch 20 | 21 | (the `-U` is optional, use it if you want to update the package). 22 | 23 | 24 | Cloning the github repository 25 | ============================= 26 | 27 | Clone the repository and cd there:: 28 | 29 | git clone https://github.com/tensorly/torch 30 | cd torch 31 | 32 | Then install the package (here in editable mode with `-e` or equivalently `--editable`:: 33 | 34 | pip install -e . 35 | 36 | Running the tests 37 | ================= 38 | 39 | Uni-testing is an vital part of this package. 40 | You can run all the tests using `pytest`:: 41 | 42 | pip install pytest 43 | pytest tltorch 44 | 45 | Building the documentation 46 | ========================== 47 | 48 | You will need to install slimit and minify:: 49 | 50 | pip install slimit rcssmin 51 | 52 | You are now ready to build the doc (here in html):: 53 | 54 | make html 55 | 56 | The results will be in `_build/html` 57 | 58 | -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /doc/minify.py: -------------------------------------------------------------------------------- 1 | from jsmin import jsmin 2 | from rcssmin import cssmin 3 | 4 | from pathlib import Path 5 | asset_path = Path('./themes/tensorly/static') 6 | 7 | for path in asset_path.glob('*.js'): 8 | # Ignore already minified files 9 | if '.min.' 
in str(path): 10 | continue 11 | target_path = path.with_suffix('.min.js') 12 | with open(path.as_posix(), 'r') as f: 13 | text = f.read() 14 | minified = jsmin(text, mangle=True, mangle_toplevel=True) 15 | with open(target_path.as_posix(), 'w') as f: 16 | f.write(minified) 17 | 18 | for path in asset_path.glob('*.css'): 19 | # Ignore already minified files 20 | if '.min.' in str(path): 21 | continue 22 | target_path = path.with_suffix('.min.css') 23 | with open(path.as_posix(), 'r') as f: 24 | text = f.read() 25 | minified = cssmin(text) 26 | with open(target_path.as_posix(), 'w') as f: 27 | f.write(minified) 28 | 29 | -------------------------------------------------------------------------------- /doc/modules/api.rst: -------------------------------------------------------------------------------- 1 | ============= 2 | API reference 3 | ============= 4 | 5 | :mod:`tltorch`: Tensorized Deep Neural Networks 6 | 7 | .. automodule:: tltorch 8 | :no-members: 9 | :no-inherited-members: 10 | 11 | .. _factorized_tensor_ref: 12 | 13 | Factorized Tensors 14 | ================== 15 | 16 | TensorLy-Torch builds on top of TensorLy and provides out of the box PyTorch layers for tensor based operations. 17 | The core of this is the concept of factorized tensors, which factorize our layers, instead of regular, dense PyTorch tensors. 18 | 19 | You can create any factorized tensor through the main class using: 20 | 21 | .. autosummary:: 22 | :toctree: generated 23 | :template: class.rst 24 | 25 | FactorizedTensor 26 | 27 | You can create a tensor of any form using ``FactorizedTensor.new(shape, rank, factorization)``, where factorization can be `Dense`, `CP`, `Tucker` or `TT`. 28 | Note that if you use ``factorization = 'dense'`` you are just creating a regular, unfactorized tensor. 29 | This allows to manipulate any tensor, factorized or not, with a simple, unified interface. 30 | 31 | Alternatively, you can also directly create a specific subclass: 32 | 33 | .. autosummary:: 34 | :toctree: generated 35 | :template: class.rst 36 | 37 | DenseTensor 38 | CPTensor 39 | TuckerTensor 40 | TTTensor 41 | 42 | .. _factorized_matrix_ref: 43 | 44 | Tensorized Matrices 45 | =================== 46 | 47 | In TensorLy-Torch , you can also represent matrices in *tensorized* form, as low-rank tensors. 48 | Just as for factorized tensor, you can create a tensorized matrix through the main class using: 49 | 50 | .. autosummary:: 51 | :toctree: generated 52 | :template: class.rst 53 | 54 | TensorizedTensor 55 | 56 | You can create a tensor of any form using ``TensorizedTensor.new(tensorized_shape, rank, factorization)``, where factorization can be `Dense`, `CP`, `Tucker` or `BlockTT`. 57 | 58 | You can also explicitly create the type of tensor you want using the following classes: 59 | 60 | .. autosummary:: 61 | :toctree: generated 62 | :template: class.rst 63 | 64 | DenseTensorized 65 | TensorizedTensor 66 | CPTensorized 67 | BlockTT 68 | 69 | .. _complex_ref: 70 | 71 | Complex Tensors 72 | =============== 73 | 74 | In theory, you can simply specify ``dtype=torch.cfloat`` in the creation of any of the tensors of tensorized matrices above, to automatically create a complex valued tensor. 75 | However, in practice, there are many issues in complex support. Distributed Data Parallelism in particular, is not supported. 76 | 77 | In TensorLy-Torch, we propose a convenient and transparent way around this: simply use ``ComplexTensor`` instead. 
78 | This will store the factors of the decomposition in real form (by explicitly storing the real and imaginary parts) 79 | but will transparently return you a complex valued tensor or reconstruction. 80 | 81 | .. autosummary:: 82 | :toctree: generated 83 | :template: class.rst 84 | 85 | ComplexDenseTensor 86 | ComplexCPTensor 87 | ComplexTuckerTensor 88 | ComplexTTTensor 89 | 90 | 91 | ComplexDenseTensorized 92 | ComplexTuckerTensorized 93 | ComplexCPTensorized 94 | ComplexBlockTT 95 | 96 | 97 | You can also transparently instanciate any of these using directly the main classes, ``TensorizedTensor`` or ``FactorizedTensor`` and specifying 98 | ``factorization="ComplexCP"`` or in general ``ComplexFactorization`` with `Factorization` any of the supported decompositions. 99 | 100 | 101 | .. _init_ref: 102 | 103 | Initialization 104 | ============== 105 | 106 | .. automodule:: tltorch.factorized_tensors 107 | :no-members: 108 | :no-inherited-members: 109 | 110 | Initialization is particularly important in the context of deep learning. 111 | We provide convenient functions to directly initialize factorized tensor (i.e. their factors) 112 | such that their reconstruction follows approximately a centered Gaussian distribution. 113 | 114 | .. currentmodule:: tltorch.factorized_tensors.init 115 | 116 | .. autosummary:: 117 | :toctree: generated 118 | :template: function.rst 119 | 120 | tensor_init 121 | cp_init 122 | tucker_init 123 | tt_init 124 | block_tt_init 125 | 126 | .. _trl_ref: 127 | 128 | Tensor Regression Layers 129 | ======================== 130 | 131 | .. automodule:: tltorch.factorized_layers 132 | :no-members: 133 | :no-inherited-members: 134 | 135 | .. currentmodule:: tltorch.factorized_layers 136 | 137 | .. autosummary:: 138 | :toctree: generated 139 | :template: class.rst 140 | 141 | TRL 142 | 143 | .. _tcl_ref: 144 | 145 | Tensor Contraction Layers 146 | ========================= 147 | 148 | .. autosummary:: 149 | :toctree: generated 150 | :template: class.rst 151 | 152 | TCL 153 | 154 | .. _factorized_linear_ref: 155 | 156 | Factorized Linear Layers 157 | ======================== 158 | 159 | .. autosummary:: 160 | :toctree: generated 161 | :template: class.rst 162 | 163 | FactorizedLinear 164 | 165 | .. _factorized_conv_ref: 166 | 167 | Factorized Convolutions 168 | ======================= 169 | 170 | General N-Dimensional convolutions in Factorized forms 171 | 172 | .. autosummary:: 173 | :toctree: generated 174 | :template: class.rst 175 | 176 | FactorizedConv 177 | 178 | .. _tensor_dropout_ref: 179 | 180 | Factorized Embeddings 181 | ===================== 182 | 183 | A drop-in replacement for PyTorch's embeddings but using an efficient tensor parametrization that never reconstructs the full table. 184 | 185 | .. autosummary:: 186 | :toctree: generated 187 | :template: class.rst 188 | 189 | FactorizedEmbedding 190 | 191 | .. _tensor_dropout_ref: 192 | 193 | Tensor Dropout 194 | ============== 195 | 196 | .. currentmodule:: tltorch.tensor_hooks 197 | 198 | .. automodule:: tltorch.tensor_hooks 199 | :no-members: 200 | :no-inherited-members: 201 | 202 | These functions allow you to easily add or remove tensor dropout from tensor layers. 203 | 204 | 205 | .. autosummary:: 206 | :toctree: generated 207 | :template: function.rst 208 | 209 | tensor_dropout 210 | remove_tensor_dropout 211 | 212 | 213 | You can also use the class API below but unless you have a particular use for the classes, you should use the convenient functions provided instead. 214 | 215 | .. 
autosummary:: 216 | :toctree: generated 217 | :template: class.rst 218 | 219 | TensorDropout 220 | 221 | .. _tensor_lasso_ref: 222 | 223 | L1 Regularization 224 | ================= 225 | 226 | L1 Regularization on tensor modules. 227 | 228 | .. currentmodule:: tltorch.tensor_hooks 229 | 230 | .. autosummary:: 231 | :toctree: generated 232 | :template: function.rst 233 | 234 | tensor_lasso 235 | remove_tensor_lasso 236 | 237 | Utilities 238 | ========= 239 | 240 | Utility functions 241 | 242 | .. currentmodule:: tltorch.utils 243 | 244 | .. autosummary:: 245 | :toctree: generated 246 | :template: function.rst 247 | 248 | get_tensorized_shape -------------------------------------------------------------------------------- /doc/requirements_doc.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | jsmin 3 | rcssmin 4 | numpydoc 5 | sphinx-gallery 6 | myst-nb 7 | tensorly_sphinx_theme 8 | -------------------------------------------------------------------------------- /doc/user_guide/factorized_conv.rst: -------------------------------------------------------------------------------- 1 | Factorized Convolutional Layers 2 | =============================== 3 | 4 | It is possible to apply low-rank tensor factorization to convolution 5 | kernels to compress the network and reduce the number of parameters. 6 | 7 | In TensorLy-Torch, you can easily try factorized convolutions: first, let’s 8 | import the library: 9 | 10 | .. code:: python 11 | 12 | import tltorch 13 | import torch 14 | 15 | Let’s now create some random data to try our modules: we can choose the 16 | size of the convolutions. 17 | 18 | .. code:: python 19 | 20 | device='cpu' 21 | input_channels = 16 22 | output_channels = 32 23 | kernel_size = 3 24 | batch_size = 2 25 | size = 24 26 | order = 2 27 | 28 | input_shape = (batch_size, input_channels) + (size, )*order 29 | kernel_shape = (output_channels, input_channels) + (kernel_size, )*order 30 | 31 | We can create some random input data: 32 | 33 | .. code:: python 34 | 35 | data = torch.randn(input_shape, dtype=torch.float32, device=device) 36 | 37 | Creating Factorized Convolutions 38 | -------------------------------- 39 | 40 | From Random 41 | ~~~~~~~~~~~ 42 | 43 | In PyTorch, you would create a convolution as follows: 44 | 45 | .. code:: python 46 | 47 | conv = torch.nn.Conv2d(input_channels, output_channels, kernel_size) 48 | 49 | In TensorLy Torch, it is exactly the same except that factorized 50 | convolutions are by default of any order: either you specify the kernel 51 | size or your specify the order 52 | 53 | .. code:: python 54 | 55 | conv = tltorch.FactorizedConv(input_channels, output_channels, kernel_size, order=2, rank='same', factorization='cp') 56 | 57 | .. code:: python 58 | 59 | conv = torch.nn.Conv2d(input_channels, output_channels, kernel_size=3) 60 | 61 | In TensorLy-Torch, factorized convolutions can be of any order, so you 62 | have to specify the order at creation (in Pytorch, you specify it 63 | through the class name, e.g. Conv2d or Conv3d): 64 | 65 | .. code:: python 66 | 67 | fact_conv = tltorch.FactorizedConv(input_channels, output_channels, kernel_size=3, order=2, rank='same') 68 | 69 | 70 | Or, you can specify the order directly by passing a tuple as kernel_size 71 | (in which case, ``order = len(kernel_size)`` is used). 72 | 73 | .. 
code:: python 74 | 75 | fact_conv = tltorch.FactorizedConv(input_channels, output_channels, kernel_size=(3, 3), rank='same') 76 | 77 | From an existing Convolution 78 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 79 | 80 | You can create a Factorized convolution from an existing (PyTorch) 81 | convolution: 82 | 83 | .. code:: python 84 | 85 | fact_conv = tltorch.FactorizedConv.from_conv(conv, rank=0.5, decompose_weights=True, factorization='tucker') 86 | 87 | Efficient Convolutional Blocks 88 | ------------------------------ 89 | 90 | If you compress a convolutional kernel, you can get efficient 91 | convolutional blocks by applying tensor factorization. For instance, if 92 | you apply CP decomposition, you can get a MobileNet-v2 block: 93 | 94 | .. code:: python 95 | 96 | fact_conv = tltorch.FactorizedConv.from_conv(conv, rank=0.5, factorization='cp', implementation='mobilenet') 97 | 98 | 99 | Similarly, if you apply Tucker decomposition, you can get a ResNet 100 | BottleNeck block: 101 | 102 | .. code:: python 103 | 104 | fact_conv = tltorch.FactorizedConv.from_conv(conv, rank=0.5, factorization='tucker', implementation='factorized') 105 | -------------------------------------------------------------------------------- /doc/user_guide/factorized_embeddings.rst: -------------------------------------------------------------------------------- 1 | Factorized embedding layers 2 | =========================== 3 | 4 | In TensorLy-Torch, we also provide out-of-the-box tensorized embedding layers. 5 | 6 | Just as for the case of factorized linear, you can either create a factorized embedding from scratch, here automatically determine the 7 | input and output tensorized shapes, to have 3 dimensions each: 8 | 9 | .. code-block:: python 10 | 11 | import tltorch 12 | import torch 13 | 14 | from_embedding = tltorch.FactorizedEmbedding(num_embeddings, embedding_dim, auto_reshape=True, d=3, rank=0.4) 15 | 16 | 17 | Or, you can create it by decomposing an existing embedding layer: 18 | 19 | .. code-block:: python 20 | 21 | from_embedding = tltorch.FactorizedEmbedding.from_embedding(embedding_layer, auto_reshape=True, 22 | factorization='blocktt', n_tensorized_modes=3, rank=0.4) -------------------------------------------------------------------------------- /doc/user_guide/factorized_tensors.rst: -------------------------------------------------------------------------------- 1 | Factorized tensors 2 | ================== 3 | 4 | The core concept in TensorLy-Torch is that of *factorized tensors*. 5 | We provide a :class:`~tltorch.FactorizedTensor` class that can be used just like any `PyTorch.Tensor` but 6 | provides all tensor factorization through one, simple API. 7 | 8 | 9 | Creating factorized tensors 10 | --------------------------- 11 | 12 | You can create a new factorized tensor easily: 13 | 14 | The signature is: 15 | 16 | .. code-block:: python 17 | 18 | factorized_tensor = FactorizedTensor.new(shape, rank, factorization) 19 | 20 | For instance, to create a tensor in Tucker form, that has half the parameters of a dense (non-factorized) tensor of the same shape, you would simply write: 21 | 22 | .. code-block:: python 23 | 24 | tucker_tensor = FactorizedTensor.new(shape, rank=0.5, factorization='tucker') 25 | 26 | Since TensorLy-Torch builds on top of TensorLy, it also comes with tensor decomposition out-of-the-box. 27 | To initialize a factorized tensor in CP (Canonical-Polyadic) form, also known as Parafac, or Kruskal tensor, 28 | with 1/10th of the parameters, you can simply write: 29 | 30 | .. 
code-block:: python 31 | 32 | cp_tensor = FactorizedTensor.new(dense_tensor, rank=0.1, factorization='CP') 33 | 34 | 35 | Manipulating factorized tensors 36 | ------------------------------- 37 | 38 | The first thing you want to do, if you created a new tensor from scratch (by using the ``new`` method), is to initialize it, 39 | e.g. so that the element of the reconstruction approximately follow a Gaussian distribution: 40 | 41 | .. code-block:: python 42 | 43 | cp_tensor.normal_(mean=0, std=0.02) 44 | 45 | You can even use PyTorch's functions! This works: 46 | 47 | .. code-block:: python 48 | 49 | from torch.nn import init 50 | 51 | init.kaiming_normal(cp_tensor) 52 | 53 | Finally, you can index tensors directly in factorized form, which will return another factorized tensor, whenever possible! 54 | 55 | >>> cp_tensor[:2, :2] 56 | CPTensor(shape=(2, 2, 2), rank=2) 57 | 58 | If not possible, a dense tensor will be returned: 59 | 60 | 61 | >>> cp_tensor[2, 3, 1] 62 | tensor(0.0250, grad_fn=) 63 | 64 | 65 | Note how, above, indexing tracks gradients as well! 66 | 67 | Tensorized tensors 68 | ================== 69 | 70 | In addition to tensor in factorized forms, TensorLy-Torch provides out-of-the-box for **Tensorized** tensors. 71 | The most common case is that of tensorized matrices, where a matrix is first *tensorized*, i.e. reshaped into 72 | a higher-order tensor which is then decomposed and stored in factorized form. 73 | 74 | A commonly used tensorized tensor is the tensor-train matrix (also known as Matrix-Product Operator in quantum physics), 75 | or, in general, Block-TT. 76 | 77 | Creation 78 | -------- 79 | 80 | You can create one in TensorLy-Torch, from a matrix, just as easily as a regular tensor, using the :class:`tltorch.TensorizedTensor` class, 81 | with the following signature: 82 | 83 | .. code-block:: python 84 | 85 | TensorizedTensor.from_matrix(matrix, tensorized_row_shape, tensorized_column_shape, rank) 86 | 87 | where tensorized_row_shape and tensorized_column_shape indicate the shape to which to tensorize the row and column size of the given matrix. 88 | For instance, if you have a matrix of size 16x21, you could use tensorized_row_shape=(4, 4) and tensorized_column_shape=(3, 7). 89 | 90 | 91 | In general, you can tensorize any tensor, not just matrices, even with batched modes (dimensions)! 92 | 93 | .. code-block:: python 94 | 95 | tensorized_tensor = TensorizedTensor.new(tensorized_shape, rank, factorization) 96 | 97 | 98 | ``tensorized_shape`` is a nested tuple, in which an int represents a batched mode, and a tuple a tensorized mode. 99 | 100 | For instance, a batch of 5 matrices of size 16x21 could be tensorized into 101 | a batch of 5 tensorized matrices of size (4x4)x(3x7), in the BlockTT form. In code, you would do this using 102 | 103 | .. code-block:: python 104 | 105 | tensorized_tensor = TensorizedTensor.from_tensor(tensor, (5, (4, 4), (3, 7)), rank=0.7, factorization='BlockTT') 106 | 107 | You can of course tensorize any size tensors, e.g. 
a batch of 5 matrices of size 8x27 can be tensorized into: 108 | 109 | >>> ftt = tltorch.TensorizedTensor.new((5, (2, 2, 2), (3, 3, 3)), rank=0.5, factorization='BlockTT') 110 | 111 | This returns a tensorized tensor, stored in decomposed form: 112 | >>> ftt 113 | BlockTT(shape=[5, 8, 27], tensorized_shape=(5, (2, 2, 2), (3, 3, 3)), rank=[1, 20, 20, 1]) 114 | 115 | Manipulation 116 | ------------- 117 | 118 | As for factorized tensors, you can directly index them: 119 | 120 | >>> ftt[2] 121 | BlockTT(shape=[8, 27], tensorized_shape=[(2, 2, 2), (3, 3, 3)], rank=[1, 20, 20, 1]) 122 | 123 | >>> ftt[0, :2, :2] 124 | tensor([[-0.0009, 0.0004], 125 | [ 0.0007, 0.0003]], grad_fn=) 126 | 127 | Again, notice that gradients are tracked and all operations on factorized and tensorized tensors are back-propagatable! 128 | -------------------------------------------------------------------------------- /doc/user_guide/index.rst: -------------------------------------------------------------------------------- 1 | .. _user_guide: 2 | 3 | 4 | User guide 5 | ========== 6 | 7 | TensorLy-Torch was written to provide out-of-the-box Tensor layers that can be readily used within any PyTorch network or code. 8 | It builds on top of `TensorLy `_ and enables anyone to use tensor methods within deep networks, even without a deep knowledge of tensor algebra. 9 | 10 | .. toctree:: 11 | 12 | factorized_tensors 13 | trl 14 | factorized_conv 15 | tensorized_linear 16 | factorized_embeddings 17 | tensor_hooks 18 | -------------------------------------------------------------------------------- /doc/user_guide/tensor_hooks.rst: -------------------------------------------------------------------------------- 1 | Tensor Hooks 2 | ============ 3 | 4 | TensorLy-Torch also makes it very easy to manipulate the tensor decomposition parametrizing a tensor module. 5 | 6 | Tensor dropout 7 | -------------- 8 | For instance, you can apply very easily tensor dropout to any tensor factorization: let's first create a simple TRL layer. 9 | 10 | .. code:: python 11 | 12 | import tltorch 13 | trl = tltorch.TRL((10, 10), (10, ), rank='same') 14 | 15 | To add tensor dropout, simply apply the helper function: 16 | 17 | .. code:: python 18 | 19 | trl = tltorch.tucker_dropout(trl.weight, p=0.5) 20 | 21 | 22 | Similarly, to remove tensor dropout: 23 | 24 | .. code:: python 25 | 26 | tltorch.remove_tucker_dropout(trl.weight) 27 | 28 | 29 | Lasso rank regularization 30 | ------------------------- 31 | 32 | Rank selection is a hard problem. One way to choose the rank while training is to apply 33 | an l1 penalty (Lasso) on the rank. 34 | 35 | This was used previously for CP decomposition, and we extended it in TensorLy-Torch to Tucker and Tensor-Train, 36 | by introducing new weights in the decomposition. 37 | 38 | To use is, you can define a regularizer object that will take care of everything. 39 | Using our previously defined TRL: 40 | 41 | .. code:: python 42 | 43 | l1_reg = tltorch.TuckerL1Regularizer(penalty=0.01) 44 | l1_reg.apply(trl.weight) 45 | x = trl(x) 46 | loss = my_loss(x) + l1_reg.loss 47 | l1_reg.res 48 | 49 | After each iteration, don't forget to reset the loss so you don't keep accumulating: 50 | 51 | .. 
55 | Initializing tensor decomposition 56 | --------------------------------- 57 | 58 | Another issue is that of initializing tensor decompositions: 59 | if you simply initialize each component randomly without care, 60 | the reconstructed (full) tensor can have arbitrarily large or small values, 61 | potentially leading to vanishing or exploding gradients during training. 62 | 63 | In TensorLy-Torch, we provide a module for initialization that will 64 | properly initialize the factors of the decomposition 65 | so that the reconstruction has zero mean and the specified standard deviation! 66 | 67 | For any tensor factorization ``fact_tensor``: 68 | 69 | .. code:: python 70 | 71 | fact_tensor.normal_(0, 0.02) 72 | 73 | 74 | -------------------------------------------------------------------------------- /doc/user_guide/tensorized_linear.rst: -------------------------------------------------------------------------------- 1 | Tensorized Linear Layers 2 | ======================== 3 | 4 | Linear layers are parametrized by matrices. However, it is possible to 5 | *tensorize* them, i.e. reshape them into higher-order tensors, in order 6 | to compress them. 7 | 8 | You can do this easily in TensorLy-Torch: 9 | 10 | .. code-block:: python 11 | 12 | import tltorch 13 | import torch 14 | 15 | Let’s create a batch of 4 data points of size 16 each: 16 | 17 | .. code-block:: python 18 | 19 | data = torch.randn((4, 16), dtype=torch.float32) 20 | 21 | Now, imagine you already have a linear layer: 22 | 23 | .. code-block:: python 24 | 25 | linear = torch.nn.Linear(in_features=16, out_features=10) 26 | 27 | You can easily compress it into a tensorized linear layer: here we specify the shape to which to tensorize the weights, 28 | and use ``rank=0.5``, which means the rank is automatically determined so that the factorization uses approximately half the 29 | number of parameters. 30 | 31 | .. code-block:: python 32 | 33 | fact_linear = tltorch.FactorizedLinear.from_linear(linear, auto_tensorize=False, 34 | in_tensorized_features=(4, 4), out_tensorized_features=(2, 5), rank=0.5) 35 | 36 | 37 | The tensorized weights will have the following shape: 38 | 39 | .. parsed-literal:: 40 | 41 | torch.Size([4, 4, 2, 5]) 42 | 43 | 44 | Note that you can also let TensorLy-Torch automatically determine the tensorization shape. In this case we just instruct it to 45 | find ``in_tensorized_features`` and ``out_tensorized_features`` of length `2`: 46 | 47 | .. code-block:: python 48 | 49 | fact_linear = tltorch.FactorizedLinear.from_linear(linear, auto_tensorize=True, n_tensorized_modes=2, rank=0.5) 50 | 51 | 52 | You can also create tensorized layers from scratch: 53 | 54 | .. code-block:: python 55 | 56 | fact_linear = tltorch.FactorizedLinear(in_tensorized_features=(4, 4), 57 | out_tensorized_features=(2, 5), 58 | factorization='tucker', rank=0.5) 59 | 60 | Finally, during the forward pass, you can either reconstruct the full weights (``implementation='reconstructed'``) and perform a regular linear layer forward pass, 61 | or let TensorLy-Torch directly contract the input tensor with the *factors of the decomposition* (``implementation='factorized'``), 62 | which can be faster, particularly if the rank (and therefore the factors of the decomposition) is very small.
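For instance, here is a minimal sketch (re-using the ``linear`` layer and ``data`` defined above) showing how to select the implementation when creating the layer and run a forward pass:

.. code-block:: python

    # Contract the input directly with the factors of the decomposition
    fact_linear = tltorch.FactorizedLinear.from_linear(
        linear, auto_tensorize=True, n_tensorized_modes=2,
        rank=0.5, implementation='factorized')

    out = fact_linear(data)  # drop-in replacement for the original nn.Linear
    print(out.shape)         # expected: torch.Size([4, 10])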
-------------------------------------------------------------------------------- /doc/user_guide/trl.rst: -------------------------------------------------------------------------------- 1 | Tensor Regression Layers 2 | ======================== 3 | 4 | In deep neural networks, while convolutional layers map between 5 | high-order activation tensors, the output is still obtained through 6 | linear regression: the activation is first flattened before being passed 7 | through linear layers. 8 | 9 | This approach has several drawbacks: 10 | 11 | * Linear regression discards topological (e.g. spatial) information. 12 | * Very large number of parameters 13 | (product of the dimensions 14 | of the input tensor times the size of the output) 15 | * Lack of robustness 16 | 17 | A Tensor Regression Layer (TRL) generalizes the concept of linear 18 | regression to higher-order inputs while alleviating the above issues. It makes it possible to 19 | preserve and leverage multi-linear structure while being parsimonious in 20 | terms of number of parameters. The low-rank constraint also acts as an 21 | implicit regularization on the model, typically leading to better sample 22 | efficiency and robustness. 23 | 24 | .. image:: /_static/TRL.png 25 | :align: center 26 | :width: 800 27 | 28 | 29 | TRL in TensorLy-Torch 30 | --------------------- 31 | 32 | Now, let’s see how to do this in code with TensorLy-Torch. 33 | 34 | Random TRL 35 | ---------- 36 | 37 | Let’s first see how to create and train a TRL from scratch. 38 | 39 | .. code:: python 40 | 41 | import tltorch 42 | import torch 43 | from torch import nn 44 | import numpy as np 45 | 46 | .. code:: python 47 | 48 | input_shape = (4, 5) 49 | output_shape = (6, 2) 50 | batch_size = 2 51 | 52 | device = 'cpu' 53 | 54 | x = torch.randn((batch_size,) + input_shape, 55 | dtype=torch.float32, device=device) 56 | 57 | .. code:: python 58 | 59 | trl = tltorch.TRL(input_shape, output_shape, rank='same') 60 | 61 | 62 | ..
code:: python 63 | 64 | result = trl(x) 65 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | pytest 4 | pytest-cov 5 | tensorly 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | try: 2 | from setuptools import setup, find_packages 3 | except ImportError: 4 | from distutils.core import setup, find_packages 5 | 6 | import re 7 | from pathlib import Path 8 | 9 | def version(root_path): 10 | """Returns the version taken from __init__.py 11 | 12 | Parameters 13 | ---------- 14 | root_path : pathlib.Path 15 | path to the root of the package 16 | 17 | Reference 18 | --------- 19 | https://packaging.python.org/guides/single-sourcing-package-version/ 20 | """ 21 | version_path = root_path.joinpath('tltorch', '__init__.py') 22 | with version_path.open() as f: 23 | version_file = f.read() 24 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", 25 | version_file, re.M) 26 | if version_match: 27 | return version_match.group(1) 28 | raise RuntimeError("Unable to find version string.") 29 | 30 | 31 | def readme(root_path): 32 | """Returns the text content of the README.rst of the package 33 | 34 | Parameters 35 | ---------- 36 | root_path : pathlib.Path 37 | path to the root of the package 38 | """ 39 | with root_path.joinpath('README.rst').open(encoding='UTF-8') as f: 40 | return f.read() 41 | 42 | 43 | root_path = Path(__file__).parent 44 | README = readme(root_path) 45 | VERSION = version(root_path) 46 | 47 | 48 | config = { 49 | 'name': 'tensorly-torch', 50 | 'packages': find_packages(), 51 | 'description': 'Deep Learning with Tensors in Python, using PyTorch and TensorLy.', 52 | 'long_description': README, 53 | 'long_description_content_type' : 'text/x-rst', 54 | 'author': 'Jean Kossaifi', 55 | 'author_email': 'jean.kossaifi@gmail.com', 56 | 'version': VERSION, 57 | 'url': 'https://github.com/tensorly/tensorly-torch', 58 | 'download_url': 'https://github.com/tensorly/tensorly-torch/tarball/' + VERSION, 59 | 'install_requires': ['numpy', 'scipy'], 60 | 'license': 'Modified BSD', 61 | 'scripts': [], 62 | 'classifiers': [ 63 | 'Topic :: Scientific/Engineering', 64 | 'License :: OSI Approved :: BSD License', 65 | 'Programming Language :: Python :: 3' 66 | ], 67 | } 68 | 69 | setup(**config) 70 | -------------------------------------------------------------------------------- /tltorch/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.5.0' 2 | 3 | from . import utils 4 | from . import factorized_tensors 5 | from .factorized_tensors import init 6 | from . import functional 7 | from . 
import factorized_layers 8 | 9 | from .factorized_layers import FactorizedLinear, FactorizedConv, TRL, TCL, FactorizedEmbedding 10 | from .factorized_tensors import FactorizedTensor, DenseTensor, CPTensor, TTTensor, TuckerTensor, tensor_init 11 | from .factorized_tensors import (TensorizedTensor, CPTensorized, BlockTT, 12 | DenseTensorized, TuckerTensorized) 13 | from .factorized_tensors import (ComplexCPTensor, ComplexTuckerTensor, 14 | ComplexTTTensor, ComplexDenseTensor) 15 | from .factorized_tensors import (ComplexCPTensorized, ComplexBlockTT, 16 | ComplexDenseTensorized, ComplexTuckerTensorized) 17 | from .tensor_hooks import (tensor_lasso, remove_tensor_lasso, 18 | tensor_dropout, remove_tensor_dropout) 19 | -------------------------------------------------------------------------------- /tltorch/factorized_layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .factorized_convolution import FactorizedConv 2 | from .tensor_regression_layers import TRL 3 | from .tensor_contraction_layers import TCL 4 | from .factorized_linear import FactorizedLinear 5 | from .factorized_embedding import FactorizedEmbedding 6 | -------------------------------------------------------------------------------- /tltorch/factorized_layers/factorized_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | from tltorch.factorized_tensors import TensorizedTensor,tensor_init 5 | from tltorch.utils import get_tensorized_shape 6 | 7 | # Authors: Cole Hawkins 8 | # Jean Kossaifi 9 | 10 | class FactorizedEmbedding(nn.Module): 11 | """ 12 | Tensorized Embedding Layers For Efficient Model Compression 13 | Tensorized drop-in replacement for `torch.nn.Embedding` 14 | 15 | Parameters 16 | ---------- 17 | num_embeddings : int 18 | number of entries in the lookup table 19 | embedding_dim : int 20 | number of dimensions per entry 21 | auto_tensorize : bool 22 | whether to use automatic reshaping for the embedding dimensions 23 | n_tensorized_modes : int or int tuple 24 | number of reshape dimensions for both embedding table dimension 25 | tensorized_num_embeddings : int tuple 26 | tensorized shape of the first embedding table dimension 27 | tensorized_embedding_dim : int tuple 28 | tensorized shape of the second embedding table dimension 29 | factorization : str 30 | tensor type 31 | rank : int tuple or str 32 | rank of the tensor factorization 33 | """ 34 | def __init__(self, 35 | num_embeddings, 36 | embedding_dim, 37 | auto_tensorize=True, 38 | n_tensorized_modes=3, 39 | tensorized_num_embeddings=None, 40 | tensorized_embedding_dim=None, 41 | factorization='blocktt', 42 | rank=8, 43 | n_layers=1, 44 | device=None, 45 | dtype=None): 46 | super().__init__() 47 | 48 | if auto_tensorize: 49 | 50 | if tensorized_num_embeddings is not None and tensorized_embedding_dim is not None: 51 | raise ValueError( 52 | "Either use auto_tensorize or specify tensorized_num_embeddings and tensorized_embedding_dim." 
53 | ) 54 | 55 | tensorized_num_embeddings, tensorized_embedding_dim = get_tensorized_shape(in_features=num_embeddings, out_features=embedding_dim, order=n_tensorized_modes, min_dim=2, verbose=False) 56 | 57 | else: 58 | #check that dimensions match factorization 59 | computed_num_embeddings = np.prod(tensorized_num_embeddings) 60 | computed_embedding_dim = np.prod(tensorized_embedding_dim) 61 | 62 | if computed_num_embeddings!=num_embeddings: 63 | raise ValueError("Tensorized embeddding number {} does not match num_embeddings argument {}".format(computed_num_embeddings,num_embeddings)) 64 | if computed_embedding_dim!=embedding_dim: 65 | raise ValueError("Tensorized embeddding dimension {} does not match embedding_dim argument {}".format(computed_embedding_dim,embedding_dim)) 66 | 67 | self.num_embeddings = num_embeddings 68 | self.embedding_dim = embedding_dim 69 | self.tensor_shape = (tensorized_num_embeddings, 70 | tensorized_embedding_dim) 71 | self.weight_shape = (self.num_embeddings, self.embedding_dim) 72 | 73 | self.n_layers = n_layers 74 | if n_layers > 1: 75 | self.tensor_shape = (n_layers, ) + self.tensor_shape 76 | self.weight_shape = (n_layers, ) + self.weight_shape 77 | 78 | self.factorization = factorization 79 | 80 | self.weight = TensorizedTensor.new(self.tensor_shape, 81 | rank=rank, 82 | factorization=self.factorization, 83 | device=device, 84 | dtype=dtype) 85 | self.reset_parameters() 86 | 87 | self.rank = self.weight.rank 88 | 89 | def reset_parameters(self): 90 | #Parameter initialization from Yin et al. 91 | #TT-Rec: Tensor Train Compression for Deep Learning Recommendation Model Embeddings 92 | target_stddev = 1 / np.sqrt(3 * self.num_embeddings) 93 | with torch.no_grad(): 94 | tensor_init(self.weight,std=target_stddev) 95 | 96 | def forward(self, input, indices=0): 97 | #to handle case where input is not 1-D 98 | output_shape = (*input.shape, self.embedding_dim) 99 | 100 | flattened_input = input.reshape(-1) 101 | 102 | if self.n_layers == 1: 103 | if indices == 0: 104 | embeddings = self.weight[flattened_input, :] 105 | else: 106 | embeddings = self.weight[indices, flattened_input, :] 107 | 108 | #CPTensorized returns CPTensorized when indexing 109 | if self.factorization.lower() == 'cp': 110 | embeddings = embeddings.to_matrix() 111 | 112 | #TuckerTensorized returns tensor not matrix, 113 | # and requires reshape not view for contiguous 114 | elif self.factorization.lower() == 'tucker': 115 | embeddings = embeddings.reshape(input.shape[0], -1) 116 | 117 | return embeddings.view(output_shape) 118 | 119 | @classmethod 120 | def from_embedding(cls, 121 | embedding_layer, 122 | rank=8, 123 | factorization='blocktt', 124 | n_tensorized_modes=2, 125 | decompose_weights=True, 126 | auto_tensorize=True, 127 | decomposition_kwargs=dict(), 128 | **kwargs): 129 | """ 130 | Create a tensorized embedding layer from a regular embedding layer 131 | 132 | Parameters 133 | ---------- 134 | embedding_layer : torch.nn.Embedding 135 | rank : int tuple or str 136 | rank of the tensor decomposition 137 | factorization : str 138 | tensor type 139 | decompose_weights: bool 140 | whether to decompose weights and use for initialization 141 | auto_tensorize: bool 142 | if True, automatically reshape dimensions for TensorizedTensor 143 | decomposition_kwargs: dict 144 | specify kwargs for the decomposition 145 | """ 146 | num_embeddings, embedding_dim = embedding_layer.weight.shape 147 | 148 | instance = cls(num_embeddings, 149 | embedding_dim, 150 | auto_tensorize=auto_tensorize, 151 | 
factorization=factorization, 152 | n_tensorized_modes=n_tensorized_modes, 153 | rank=rank, 154 | **kwargs) 155 | 156 | if decompose_weights: 157 | with torch.no_grad(): 158 | instance.weight.init_from_matrix(embedding_layer.weight.data, 159 | **decomposition_kwargs) 160 | 161 | else: 162 | instance.reset_parameters() 163 | 164 | return instance 165 | 166 | @classmethod 167 | def from_embedding_list(cls, 168 | embedding_layer_list, 169 | rank=8, 170 | factorization='blocktt', 171 | n_tensorized_modes=2, 172 | decompose_weights=True, 173 | auto_tensorize=True, 174 | decomposition_kwargs=dict(), 175 | **kwargs): 176 | """ 177 | Create a tensorized embedding layer from a regular embedding layer 178 | 179 | Parameters 180 | ---------- 181 | embedding_layer : torch.nn.Embedding 182 | rank : int tuple or str 183 | tensor rank 184 | factorization : str 185 | tensor decomposition to use 186 | decompose_weights: bool 187 | decompose weights and use for initialization 188 | auto_tensorize: bool 189 | automatically reshape dimensions for TensorizedTensor 190 | decomposition_kwargs: dict 191 | specify kwargs for the decomposition 192 | """ 193 | n_layers = len(embedding_layer_list) 194 | num_embeddings, embedding_dim = embedding_layer_list[0].weight.shape 195 | 196 | for i, layer in enumerate(embedding_layer_list[1:]): 197 | # Just some checks on the size of the embeddings 198 | # They need to have the same size so they can be jointly factorized 199 | new_num_embeddings, new_embedding_dim = layer.weight.shape 200 | if num_embeddings != new_num_embeddings: 201 | msg = 'All embedding layers must have the same num_embeddings.' 202 | msg += f'Yet, got embedding_layer_list[0] with num_embeddings={num_embeddings} ' 203 | msg += f' and embedding_layer_list[{i+1}] with num_embeddings={new_num_embeddings}.' 204 | raise ValueError(msg) 205 | if embedding_dim != new_embedding_dim: 206 | msg = 'All embedding layers must have the same embedding_dim.' 207 | msg += f'Yet, got embedding_layer_list[0] with embedding_dim={embedding_dim} ' 208 | msg += f' and embedding_layer_list[{i+1}] with embedding_dim={new_embedding_dim}.' 209 | raise ValueError(msg) 210 | 211 | instance = cls(num_embeddings, 212 | embedding_dim, 213 | n_tensorized_modes=n_tensorized_modes, 214 | auto_tensorize=auto_tensorize, 215 | factorization=factorization, 216 | rank=rank, 217 | n_layers=n_layers, 218 | **kwargs) 219 | 220 | if decompose_weights: 221 | weight_tensor = torch.stack([layer.weight.data for layer in embedding_layer_list]) 222 | with torch.no_grad(): 223 | instance.weight.init_from_matrix(weight_tensor, 224 | **decomposition_kwargs) 225 | 226 | else: 227 | instance.reset_parameters() 228 | 229 | return instance 230 | 231 | 232 | def get_embedding(self, indices): 233 | if self.n_layers == 1: 234 | raise ValueError('A single linear is parametrized, directly use the main class.') 235 | 236 | return SubFactorizedEmbedding(self, indices) 237 | 238 | 239 | class SubFactorizedEmbedding(nn.Module): 240 | """Class representing one of the embeddings from the mother joint factorized embedding layer 241 | 242 | Parameters 243 | ---------- 244 | 245 | Notes 246 | ----- 247 | This relies on the fact that nn.Parameters are not duplicated: 248 | if the same nn.Parameter is assigned to multiple modules, they all point to the same data, 249 | which is shared. 
250 | """ 251 | def __init__(self, main_layer, indices): 252 | super().__init__() 253 | self.main_layer = main_layer 254 | self.indices = indices 255 | 256 | def forward(self, x): 257 | return self.main_layer(x, self.indices) 258 | 259 | def extra_repr(self): 260 | return '' 261 | 262 | def __repr__(self): 263 | msg = f' {self.__class__.__name__} {self.indices} from main factorized layer.' 264 | msg += f'\n{self.__class__.__name__}(' 265 | msg += self.extra_repr() 266 | msg += ')' 267 | return msg -------------------------------------------------------------------------------- /tltorch/factorized_layers/factorized_linear.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from torch import nn 4 | import torch 5 | from torch.utils import checkpoint 6 | 7 | from ..functional import factorized_linear 8 | from ..factorized_tensors import TensorizedTensor 9 | from tltorch.utils import get_tensorized_shape 10 | 11 | # Author: Jean Kossaifi 12 | # License: BSD 3 clause 13 | 14 | 15 | class FactorizedLinear(nn.Module): 16 | """Tensorized Fully-Connected Layers 17 | 18 | The weight matrice is tensorized to a tensor of size `(*in_tensorized_features, *out_tensorized_features)`. 19 | That tensor is expressed as a low-rank tensor. 20 | 21 | During inference, the full tensor is reconstructed, and unfolded back into a matrix, 22 | used for the forward pass in a regular linear layer. 23 | 24 | Parameters 25 | ---------- 26 | in_tensorized_features : int tuple 27 | shape to which the input_features dimension is tensorized to 28 | e.g. if in_features is 8 in_tensorized_features could be (2, 2, 2) 29 | should verify prod(in_tensorized_features) = in_features 30 | out_tensorized_features : int tuple 31 | shape to which the input_features dimension is tensorized to. 32 | factorization : str, default is 'cp' 33 | rank : int tuple or str 34 | implementation : {'factorized', 'reconstructed'}, default is 'factorized' 35 | which implementation to use for forward function: 36 | - if 'factorized', will directly contract the input with the factors of the decomposition 37 | - if 'reconstructed', the full weight matrix is reconstructed from the factorized version and used for a regular linear layer forward pass. 
38 | n_layers : int, default is 1 39 | number of linear layers to be parametrized with a single factorized tensor 40 | bias : bool, default is True 41 | checkpointing : bool 42 | whether to enable gradient checkpointing to save memory during training-mode forward, default is False 43 | device : PyTorch device to use, default is None 44 | dtype : PyTorch dtype, default is None 45 | """ 46 | def __init__(self, in_tensorized_features, out_tensorized_features, bias=True, 47 | factorization='cp', rank='same', implementation='factorized', n_layers=1, 48 | checkpointing=False, device=None, dtype=None): 49 | super().__init__() 50 | if factorization == 'TTM' and n_layers != 1: 51 | raise ValueError(f'TTM factorization only support single factorized layers but got n_layers={n_layers}.') 52 | 53 | self.in_features = np.prod(in_tensorized_features) 54 | self.out_features = np.prod(out_tensorized_features) 55 | self.in_tensorized_features = in_tensorized_features 56 | self.out_tensorized_features = out_tensorized_features 57 | self.tensorized_shape = out_tensorized_features + in_tensorized_features 58 | self.weight_shape = (self.out_features, self.in_features) 59 | self.input_rank = rank 60 | self.implementation = implementation 61 | self.checkpointing = checkpointing 62 | 63 | if bias: 64 | if n_layers == 1: 65 | self.bias = nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype)) 66 | self.has_bias = True 67 | else: 68 | self.bias = nn.Parameter(torch.empty((n_layers, self.out_features), device=device, dtype=dtype)) 69 | self.has_bias = np.zeros(n_layers) 70 | else: 71 | self.register_parameter('bias', None) 72 | 73 | self.rank = rank 74 | self.n_layers = n_layers 75 | if n_layers > 1: 76 | tensor_shape = (n_layers, out_tensorized_features, in_tensorized_features) 77 | else: 78 | tensor_shape = (out_tensorized_features, in_tensorized_features) 79 | 80 | if isinstance(factorization, TensorizedTensor): 81 | self.weight = factorization.to(device).to(dtype) 82 | else: 83 | self.weight = TensorizedTensor.new(tensor_shape, rank=rank, factorization=factorization, device=device, dtype=dtype) 84 | self.reset_parameters() 85 | 86 | self.rank = self.weight.rank 87 | 88 | def reset_parameters(self): 89 | with torch.no_grad(): 90 | self.weight.normal_(0, math.sqrt(5)/math.sqrt(self.in_features)) 91 | if self.bias is not None: 92 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) 93 | bound = 1 / math.sqrt(fan_in) 94 | torch.nn.init.uniform_(self.bias, -bound, bound) 95 | 96 | def forward(self, x, indices=0): 97 | if self.n_layers == 1: 98 | if indices == 0: 99 | weight, bias = self.weight(), self.bias 100 | else: 101 | raise ValueError(f'Only one convolution was parametrized (n_layers=1) but tried to access {indices}.') 102 | 103 | elif isinstance(self.n_layers, int): 104 | if not isinstance(indices, int): 105 | raise ValueError(f'Expected indices to be in int but got indices={indices}' 106 | f', but this conv was created with n_layers={self.n_layers}.') 107 | weight = self.weight(indices) 108 | bias = self.bias[indices] if self.bias is not None else None 109 | elif len(indices) != len(self.n_layers): 110 | raise ValueError(f'Got indices={indices}, but this conv was created with n_layers={self.n_layers}.') 111 | else: 112 | weight = self.weight(indices) 113 | bias = self.bias[indices] if self.bias is not None else None 114 | 115 | def _inner_forward(x): # move weight() out to avoid register_hooks from being executed twice during recomputation 116 | return factorized_linear(x, 
weight, bias=bias, in_features=self.in_features, 117 | implementation=self.implementation) 118 | 119 | if self.checkpointing and x.requires_grad: 120 | x = checkpoint.checkpoint(_inner_forward, x) 121 | else: 122 | x = _inner_forward(x) 123 | return x 124 | 125 | def get_linear(self, indices): 126 | if self.n_layers == 1: 127 | raise ValueError('A single linear is parametrized, directly use the main class.') 128 | 129 | return SubFactorizedLinear(self, indices) 130 | 131 | def __getitem__(self, indices): 132 | return self.get_linear(indices) 133 | 134 | @classmethod 135 | def from_linear(cls, linear, rank='same', auto_tensorize=True, n_tensorized_modes=3, 136 | in_tensorized_features=None, out_tensorized_features=None, 137 | bias=True, factorization='CP', implementation="reconstructed", 138 | checkpointing=False, decomposition_kwargs=dict(), verbose=False): 139 | """Class method to create an instance from an existing linear layer 140 | 141 | Parameters 142 | ---------- 143 | linear : torch.nn.Linear 144 | layer to tensorize 145 | auto_tensorize : bool, default is True 146 | if True, automatically find values for the tensorized_shapes 147 | n_tensorized_modes : int, default is 3 148 | Order (number of dims) of the tensorized weights if auto_tensorize is True 149 | in_tensorized_features, out_tensorized_features : tuple 150 | shape to tensorized the factorized_weight matrix to. 151 | Must verify np.prod(tensorized_shape) == np.prod(linear.factorized_weight.shape) 152 | factorization : str, default is 'cp' 153 | implementation : str 154 | which implementation to use for forward function. support 'factorized' and 'reconstructed', default is 'factorized' 155 | checkpointing : bool 156 | whether to enable gradient checkpointing to save memory during training-mode forward, default is False 157 | rank : {rank of the decomposition, 'same', float} 158 | if float, percentage of parameters of the original factorized_weights to use 159 | if 'same' use the same number of parameters 160 | bias : bool, default is True 161 | verbose : bool, default is False 162 | """ 163 | out_features, in_features = linear.weight.shape 164 | 165 | if auto_tensorize: 166 | 167 | if out_tensorized_features is not None and in_tensorized_features is not None: 168 | raise ValueError( 169 | "Either use auto_reshape or specify out_tensorized_features and in_tensorized_features." 
170 | ) 171 | 172 | in_tensorized_features, out_tensorized_features = get_tensorized_shape( 173 | in_features=in_features, out_features=out_features, order=n_tensorized_modes, min_dim=2, verbose=verbose) 174 | else: 175 | assert(out_features == np.prod(out_tensorized_features)) 176 | assert(in_features == np.prod(in_tensorized_features)) 177 | 178 | instance = cls(in_tensorized_features, out_tensorized_features, bias=bias, 179 | factorization=factorization, rank=rank, implementation=implementation, 180 | n_layers=1, checkpointing=checkpointing, device=linear.weight.device, dtype=linear.weight.dtype) 181 | 182 | instance.weight.init_from_matrix(linear.weight.data, **decomposition_kwargs) 183 | 184 | if bias and linear.bias is not None: 185 | instance.bias.data = linear.bias.data 186 | 187 | return instance 188 | 189 | @classmethod 190 | def from_linear_list(cls, linear_list, in_tensorized_features, out_tensorized_features, rank, bias=True, 191 | factorization='CP', implementation="reconstructed", checkpointing=False, decomposition_kwargs=dict(init='random')): 192 | """Class method to create an instance from an existing linear layer 193 | 194 | Parameters 195 | ---------- 196 | linear : torch.nn.Linear 197 | layer to tensorize 198 | tensorized_shape : tuple 199 | shape to tensorized the weight matrix to. 200 | Must verify np.prod(tensorized_shape) == np.prod(linear.weight.shape) 201 | factorization : str, default is 'cp' 202 | implementation : str 203 | which implementation to use for forward function. support 'factorized' and 'reconstructed', default is 'factorized' 204 | checkpointing : bool 205 | whether to enable gradient checkpointing to save memory during training-mode forward, default is False 206 | rank : {rank of the decomposition, 'same', float} 207 | if float, percentage of parameters of the original weights to use 208 | if 'same' use the same number of parameters 209 | bias : bool, default is True 210 | """ 211 | if factorization == 'TTM' and len(linear_list) > 1: 212 | raise ValueError(f'TTM factorization only support single factorized layers but got {len(linear_list)} layers.') 213 | 214 | for linear in linear_list: 215 | out_features, in_features = linear.weight.shape 216 | assert(out_features == np.prod(out_tensorized_features)) 217 | assert(in_features == np.prod(in_tensorized_features)) 218 | 219 | instance = cls(in_tensorized_features, out_tensorized_features, bias=bias, 220 | factorization=factorization, rank=rank, implementation=implementation, 221 | n_layers=len(linear_list), checkpointing=checkpointing, device=linear.weight.device, dtype=linear.weight.dtype) 222 | weight_tensor = torch.stack([layer.weight.data for layer in linear_list]) 223 | instance.weight.init_from_matrix(weight_tensor, **decomposition_kwargs) 224 | 225 | if bias: 226 | for i, layer in enumerate(linear_list): 227 | if layer.bias is not None: 228 | instance.bias.data[i] = layer.bias.data 229 | instance.has_bias[i] = 1 230 | 231 | return instance 232 | 233 | def __repr__(self): 234 | msg = (f'{self.__class__.__name__}(in_features={self.in_features}, out_features={self.out_features},' 235 | f' weight of size ({self.out_features}, {self.in_features}) tensorized to ({self.out_tensorized_features}, {self.in_tensorized_features}),' 236 | f'factorization={self.weight._name}, rank={self.rank}, implementation={self.implementation}') 237 | if self.bias is None: 238 | msg += f', bias=False' 239 | 240 | if self.n_layers == 1: 241 | msg += ', with a single layer parametrized, ' 242 | return msg 243 | 244 | msg += 
f' with {self.n_layers} layers jointly parametrized.' 245 | 246 | return msg 247 | 248 | 249 | class SubFactorizedLinear(nn.Module): 250 | """Class representing one of the convolutions from the mother joint factorized convolution 251 | 252 | Parameters 253 | ---------- 254 | 255 | Notes 256 | ----- 257 | This relies on the fact that nn.Parameters are not duplicated: 258 | if the same nn.Parameter is assigned to multiple modules, they all point to the same data, 259 | which is shared. 260 | """ 261 | def __init__(self, main_linear, indices): 262 | super().__init__() 263 | self.main_linear = main_linear 264 | self.indices = indices 265 | 266 | def forward(self, x): 267 | return self.main_linear(x, self.indices) 268 | 269 | def extra_repr(self): 270 | msg = f'in_features={self.main_linear.in_features}, out_features={self.main_linear.out_features}' 271 | if self.main_linear.has_bias[self.indices]: 272 | msg += ', bias=True' 273 | return msg 274 | 275 | def __repr__(self): 276 | msg = f' {self.__class__.__name__} {self.indices} from main factorized layer.' 277 | msg += f'\n{self.__class__.__name__}(' 278 | msg += self.extra_repr() 279 | msg += ')' 280 | return msg 281 | -------------------------------------------------------------------------------- /tltorch/factorized_layers/tensor_contraction_layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tensor Contraction Layers 3 | """ 4 | 5 | # Author: Jean Kossaifi 6 | # License: BSD 3 clause 7 | 8 | from tensorly import tenalg 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import init 12 | 13 | import math 14 | 15 | import tensorly as tl 16 | tl.set_backend('pytorch') 17 | 18 | 19 | class TCL(nn.Module): 20 | """Tensor Contraction Layer [1]_ 21 | 22 | Parameters 23 | ---------- 24 | input_size : int iterable 25 | shape of the input, excluding batch size 26 | rank : int list or int 27 | rank of the TCL, will also be the output-shape (excluding batch-size) 28 | if int, the same rank will be used for all dimensions 29 | verbose : int, default is 1 30 | level of verbosity 31 | 32 | References 33 | ---------- 34 | .. [1] J. Kossaifi, A. Khanna, Z. Lipton, T. Furlanello and A. Anandkumar, 35 | "Tensor Contraction Layers for Parsimonious Deep Nets," 2017 IEEE Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), 36 | Honolulu, HI, 2017, pp. 1940-1946, doi: 10.1109/CVPRW.2017.243. 
37 | """ 38 | def __init__(self, input_shape, rank, verbose=0, bias=False, 39 | device=None, dtype=None, **kwargs): 40 | super().__init__(**kwargs) 41 | self.verbose = verbose 42 | 43 | if isinstance(input_shape, int): 44 | self.input_shape = (input_shape, ) 45 | else: 46 | self.input_shape = tuple(input_shape) 47 | 48 | self.order = len(input_shape) 49 | 50 | if isinstance(rank, int): 51 | self.rank = (rank, )*self.order 52 | else: 53 | self.rank = tuple(rank) 54 | 55 | # Start at 1 as the batch-size is not projected 56 | self.contraction_modes = list(range(1, self.order + 1)) 57 | for i, (s, r) in enumerate(zip(self.input_shape, self.rank)): 58 | self.register_parameter(f'factor_{i}', nn.Parameter(torch.empty((r, s), device=device, dtype=dtype))) 59 | 60 | # self.factors = ParameterList(parameters=factors) 61 | if bias: 62 | self.bias = nn.Parameter( 63 | torch.empty(self.output_shape, device=device, dtype=dtype), requires_grad=True) 64 | else: 65 | self.register_parameter('bias', None) 66 | 67 | self.reset_parameters() 68 | 69 | @property 70 | def factors(self): 71 | return [getattr(self, f'factor_{i}') for i in range(self.order)] 72 | 73 | def forward(self, x): 74 | """Performs a forward pass""" 75 | x = tenalg.multi_mode_dot( 76 | x, self.factors, modes=self.contraction_modes) 77 | 78 | if self.bias is not None: 79 | return x + self.bias 80 | else: 81 | return x 82 | 83 | def reset_parameters(self): 84 | """Sets the parameters' values randomly 85 | 86 | Todo 87 | ---- 88 | This may be renamed to init_from_random for consistency with TensorModules 89 | """ 90 | for i in range(self.order): 91 | init.kaiming_uniform_(getattr(self, f'factor_{i}'), a=math.sqrt(5)) 92 | if self.bias is not None: 93 | bound = 1 / math.sqrt(self.input_shape[0]) 94 | init.uniform_(self.bias, -bound, bound) 95 | -------------------------------------------------------------------------------- /tltorch/factorized_layers/tensor_regression_layers.py: -------------------------------------------------------------------------------- 1 | """Tensor Regression Layers 2 | """ 3 | 4 | # Author: Jean Kossaifi 5 | # License: BSD 3 clause 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | import tensorly as tl 11 | tl.set_backend('pytorch') 12 | from ..functional.tensor_regression import trl 13 | 14 | from ..factorized_tensors import FactorizedTensor 15 | 16 | class TRL(nn.Module): 17 | """Tensor Regression Layers 18 | 19 | Parameters 20 | ---------- 21 | input_shape : int iterable 22 | shape of the input, excluding batch size 23 | output_shape : int iterable 24 | shape of the output, excluding batch size 25 | verbose : int, default is 0 26 | level of verbosity 27 | 28 | References 29 | ---------- 30 | .. [1] Tensor Regression Networks, Jean Kossaifi, Zachary C. Lipton, Arinbjorn Kolbeinsson, 31 | Aran Khanna, Tommaso Furlanello, Anima Anandkumar, JMLR, 2020. 
32 | """ 33 | def __init__(self, input_shape, output_shape, bias=False, verbose=0, 34 | factorization='cp', rank='same', n_layers=1, 35 | device=None, dtype=None, **kwargs): 36 | super().__init__(**kwargs) 37 | self.verbose = verbose 38 | 39 | if isinstance(input_shape, int): 40 | self.input_shape = (input_shape, ) 41 | else: 42 | self.input_shape = tuple(input_shape) 43 | 44 | if isinstance(output_shape, int): 45 | self.output_shape = (output_shape, ) 46 | else: 47 | self.output_shape = tuple(output_shape) 48 | 49 | self.n_input = len(self.input_shape) 50 | self.n_output = len(self.output_shape) 51 | self.weight_shape = self.input_shape + self.output_shape 52 | self.order = len(self.weight_shape) 53 | 54 | if bias: 55 | self.bias = nn.Parameter(torch.empty(self.output_shape, device=device, dtype=dtype)) 56 | else: 57 | self.bias = None 58 | 59 | if n_layers == 1: 60 | factorization_shape = self.weight_shape 61 | elif isinstance(n_layers, int): 62 | factorization_shape = (n_layers, ) + self.weight_shape 63 | elif isinstance(n_layers, tuple): 64 | factorization_shape = n_layers + self.weight_shape 65 | 66 | if isinstance(factorization, FactorizedTensor): 67 | self.weight = factorization.to(device).to(dtype) 68 | else: 69 | self.weight = FactorizedTensor.new(factorization_shape, rank=rank, factorization=factorization, 70 | device=device, dtype=dtype) 71 | self.init_from_random() 72 | 73 | self.factorization = self.weight.name 74 | 75 | def forward(self, x): 76 | """Performs a forward pass""" 77 | return trl(x, self.weight, bias=self.bias) 78 | 79 | def init_from_random(self, decompose_full_weight=False): 80 | """Initialize the module randomly 81 | 82 | Parameters 83 | ---------- 84 | decompose_full_weight : bool, default is False 85 | if True, constructs a full weight tensor and decomposes it to initialize the factors 86 | otherwise, the factors are directly initialized randomlys 87 | """ 88 | with torch.no_grad(): 89 | if decompose_full_weight: 90 | full_weight = torch.normal(0.0, 0.02, size=self.weight_shape) 91 | self.weight.init_from_tensor(full_weight) 92 | else: 93 | self.weight.normal_() 94 | if self.bias is not None: 95 | self.bias.uniform_(-1, 1) 96 | 97 | def init_from_linear(self, linear, unsqueezed_modes=None, **kwargs): 98 | """Initialise the TRL from the weights of a fully connected layer 99 | 100 | Parameters 101 | ---------- 102 | linear : torch.nn.Linear 103 | unsqueezed_modes : int list or None 104 | For Tucker factorization, this allows to replace pooling layers and instead 105 | learn the average pooling for the specified modes ("unsqueezed_modes"). 106 | **for factorization='Tucker' only** 107 | """ 108 | if unsqueezed_modes is not None: 109 | if self.factorization != 'Tucker': 110 | raise ValueError(f'unsqueezed_modes is only supported for factorization="tucker" but factorization is {self.factorization}.') 111 | 112 | unsqueezed_modes = sorted(unsqueezed_modes) 113 | weight_shape = list(self.weight_shape) 114 | for mode in unsqueezed_modes[::-1]: 115 | if mode == 0: 116 | raise ValueError(f'Cannot learn pooling for mode-0 (channels).') 117 | if mode > self.n_input: 118 | msg = 'Can only learn pooling for the input tensor. ' 119 | msg += f'The input has only {self.n_input} modes, yet got a unsqueezed_mode for mode {mode}.' 
120 | raise ValueError(msg) 121 | 122 | weight_shape.pop(mode) 123 | kwargs['unsqueezed_modes'] = unsqueezed_modes 124 | else: 125 | weight_shape = self.weight_shape 126 | 127 | with torch.no_grad(): 128 | weight = torch.t(linear.weight).contiguous().view(weight_shape) 129 | 130 | self.weight.init_from_tensor(weight, **kwargs) 131 | if self.bias is not None: 132 | self.bias.data = linear.bias.data 133 | -------------------------------------------------------------------------------- /tltorch/factorized_layers/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/tltorch/factorized_layers/tests/__init__.py -------------------------------------------------------------------------------- /tltorch/factorized_layers/tests/test_factorized_convolution.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | import tensorly as tl 6 | tl.set_backend('pytorch') 7 | from ... import FactorizedTensor 8 | from tltorch.factorized_layers import FactorizedConv 9 | from tensorly.testing import assert_array_almost_equal 10 | 11 | from ..factorized_convolution import (FactorizedConv, kernel_shape_to_factorization_shape, tensor_to_kernel) 12 | 13 | @pytest.mark.parametrize('factorization, implementation', 14 | [('CP', 'factorized'), ('CP', 'reconstructed'), ('CP', 'mobilenet'), 15 | ('Tucker', 'factorized'), ('Tucker', 'reconstructed'), 16 | ('TT', 'factorized'), ('TT', 'reconstructed')]) 17 | def test_single_conv(factorization, implementation, 18 | order=2, rank=0.5, rng=None, input_channels=4, output_channels=5, 19 | kernel_size=3, batch_size=1, activation_size=(8, 7), device='cpu'): 20 | rng = tl.check_random_state(rng) 21 | input_shape = (batch_size, input_channels) + activation_size 22 | kernel_shape = (output_channels, input_channels) + (kernel_size, )*order 23 | 24 | if rank is None: 25 | rank = max(kernel_shape) 26 | 27 | if order == 1: 28 | FullConv = nn.Conv1d 29 | elif order == 2: 30 | FullConv = nn.Conv2d 31 | elif order == 3: 32 | FullConv = nn.Conv3d 33 | 34 | # Factorized input tensor 35 | factorization_shape = kernel_shape_to_factorization_shape(factorization, kernel_shape) 36 | decomposed_weights = FactorizedTensor.new(shape=factorization_shape, rank=rank, factorization=factorization).normal_(0, 1) 37 | full_weights = tensor_to_kernel(factorization, decomposed_weights.to_tensor().to(device)) 38 | data = torch.tensor(rng.random_sample(input_shape), dtype=torch.float32).to(device) 39 | 40 | # PyTorch regular Conv 41 | conv = FullConv(input_channels, output_channels, kernel_size, bias=True, padding=1) 42 | true_bias = conv.bias.data 43 | conv.weight.data = full_weights 44 | 45 | # Factorized conv 46 | fact_conv = FactorizedConv.from_factorization(decomposed_weights, implementation=implementation, 47 | bias=true_bias, padding=1) 48 | 49 | # First check it has the correct implementation 50 | msg = f'Created implementation={implementation} but {fact_conv.implementation} was created.' 51 | assert fact_conv.implementation == implementation, msg 52 | 53 | # Check that it gives the same result as the full conv 54 | true_res = conv(data) 55 | res = fact_conv(data) 56 | msg = f'{fact_conv.__class__.__name__} does not give same result as {FullConv.__class__.__name__}.' 
57 | assert_array_almost_equal(true_res, res, decimal=4, err_msg=msg) 58 | 59 | # Check that the parameters of the decomposition are transposed back correctly 60 | decomposed_weights, bias = fact_conv.weight, fact_conv.bias 61 | rec = tensor_to_kernel(factorization, decomposed_weights.to_tensor()) 62 | msg = msg = f'{fact_conv.__class__.__name__} does not return the decomposition it was constructed with.' 63 | assert_array_almost_equal(rec, full_weights, err_msg=msg) 64 | msg = msg = f'{fact_conv.__class__.__name__} does not return the bias it was constructed with.' 65 | assert_array_almost_equal(bias, true_bias, err_msg=msg) 66 | 67 | conv = FullConv(input_channels, output_channels, kernel_size, bias=True, padding=1) 68 | conv.weight.data.uniform_(-1, 1) 69 | fact_conv = FactorizedConv.from_conv(conv, rank=30, factorization=factorization)#, decomposition_kwargs=dict(init='svd', l2_reg=1e-5)) 70 | true_res = conv(data) 71 | res = fact_conv(data) 72 | msg = f'{fact_conv.__class__.__name__} does not give same result as {FullConv.__class__.__name__}.' 73 | assert_array_almost_equal(true_res, res, decimal=2, err_msg=msg) 74 | -------------------------------------------------------------------------------- /tltorch/factorized_layers/tests/test_factorized_embedding.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | from ..factorized_embedding import FactorizedEmbedding 5 | 6 | import tensorly as tl 7 | tl.set_backend('pytorch') 8 | from tensorly import testing 9 | 10 | #@pytest.mark.parametrize('factorization', ['BlockTT']) 11 | @pytest.mark.parametrize('factorization', ['CP','Tucker', 'BlockTT']) 12 | @pytest.mark.parametrize('dims', [(256,16), (1000,32)]) 13 | def test_FactorizedEmbedding(factorization,dims): 14 | NUM_EMBEDDINGS, EMBEDDING_DIM = dims 15 | 16 | #create factorized embedding 17 | factorized_embedding = FactorizedEmbedding(NUM_EMBEDDINGS, EMBEDDING_DIM, factorization=factorization) 18 | 19 | #make test embedding of same shape and same weight 20 | test_embedding = torch.nn.Embedding(factorized_embedding.weight.shape[0], factorized_embedding.weight.shape[1]) 21 | test_embedding.weight.data.copy_(factorized_embedding.weight.to_matrix().detach()) 22 | 23 | #create batch and test using all entries (shuffled since entries may not be sorted) 24 | batch = torch.randperm(NUM_EMBEDDINGS)#.view(-1,1) 25 | normal_embed = test_embedding(batch) 26 | factorized_embed = factorized_embedding(batch) 27 | testing.assert_array_almost_equal(normal_embed,factorized_embed,decimal=2) 28 | 29 | #split batch into tensor with first dimension 3 30 | batch = torch.randperm(NUM_EMBEDDINGS) 31 | split_size = NUM_EMBEDDINGS//5 32 | 33 | split_batch = [batch[:1*split_size],batch[1*split_size:2*split_size],batch[3*split_size:4*split_size]] 34 | 35 | split_batch = torch.stack(split_batch,0) 36 | 37 | normal_embed = test_embedding(split_batch) 38 | factorized_embed = factorized_embedding(split_batch) 39 | testing.assert_array_almost_equal(normal_embed,factorized_embed,decimal=2) 40 | 41 | #BlockTT has no init_from_matrix, so skip that test 42 | if factorization=='BlockTT': 43 | return 44 | 45 | del factorized_embedding 46 | 47 | #init from test layer which is low rank 48 | factorized_embedding = FactorizedEmbedding.from_embedding(test_embedding,factorization=factorization,rank=8) 49 | 50 | #test using same batch as before, only test that shapes match 51 | normal_embed = test_embedding(batch) 52 | factorized_embed = 
factorized_embedding(batch) 53 | testing.assert_array_almost_equal(normal_embed.shape,factorized_embed.shape,decimal=2) 54 | -------------------------------------------------------------------------------- /tltorch/factorized_layers/tests/test_factorized_linear.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from torch import nn 3 | from ..factorized_linear import FactorizedLinear 4 | from ... import TensorizedTensor 5 | 6 | import tensorly as tl 7 | tl.set_backend('pytorch') 8 | from tensorly import testing 9 | 10 | @pytest.mark.parametrize('factorization', ['CP', 'Tucker', 'BlockTT']) 11 | def test_FactorizedLinear(factorization): 12 | random_state = 12345 13 | rng = tl.check_random_state(random_state) 14 | batch_size = 2 15 | in_features = 9 16 | in_shape = (3, 3) 17 | out_features = 16 18 | out_shape = (4, 4) 19 | data = tl.tensor(rng.random_sample((batch_size, in_features)), dtype=tl.float32) 20 | 21 | # Creat from a tensor factorization 22 | tensor = TensorizedTensor.new((out_shape, in_shape), rank='same', factorization=factorization) 23 | tensor.normal_() 24 | fc = nn.Linear(in_features, out_features, bias=True) 25 | fc.weight.data = tensor.to_matrix() 26 | tfc = FactorizedLinear(in_shape, out_shape, rank='same', factorization=tensor, bias=True) 27 | tfc.bias.data = fc.bias 28 | res_fc = fc(data) 29 | res_tfc = tfc(data) 30 | testing.assert_array_almost_equal(res_fc, res_tfc, decimal=2) 31 | 32 | # Decompose an existing layer 33 | fc = nn.Linear(in_features, out_features, bias=True) 34 | tfc = FactorizedLinear.from_linear(fc, auto_tensorize=False, 35 | in_tensorized_features=(3, 3), out_tensorized_features=(4, 4), rank=34, bias=True) 36 | res_fc = fc(data) 37 | res_tfc = tfc(data) 38 | testing.assert_array_almost_equal(res_fc, res_tfc, decimal=2) 39 | 40 | # Decompose an existing layer, automatically determine tensorization shape 41 | fc = nn.Linear(in_features, out_features, bias=True) 42 | tfc = FactorizedLinear.from_linear(fc, auto_tensorize=True, n_tensorized_modes=2, rank=34, bias=True) 43 | res_fc = fc(data) 44 | res_tfc = tfc(data) 45 | testing.assert_array_almost_equal(res_fc, res_tfc, decimal=2) 46 | 47 | # Multi-layer factorization 48 | fc1 = nn.Linear(in_features, out_features, bias=True) 49 | fc2 = nn.Linear(in_features, out_features, bias=True) 50 | tfc = FactorizedLinear.from_linear_list([fc1, fc2], in_tensorized_features=in_shape, out_tensorized_features=out_shape, rank=38, bias=True) 51 | ## Test first parametrized conv 52 | res_fc = fc1(data) 53 | res_tfc = tfc[0](data) 54 | testing.assert_array_almost_equal(res_fc, res_tfc, decimal=2) 55 | ## Test second parametrized conv 56 | res_fc = fc2(data) 57 | res_tfc = tfc[1](data) 58 | testing.assert_array_almost_equal(res_fc, res_tfc, decimal=2) -------------------------------------------------------------------------------- /tltorch/factorized_layers/tests/test_tensor_contraction_layers.py: -------------------------------------------------------------------------------- 1 | from ..tensor_contraction_layers import TCL 2 | import tensorly as tl 3 | from tensorly import testing 4 | tl.set_backend('pytorch') 5 | 6 | 7 | def test_tcl(): 8 | random_state = 12345 9 | rng = tl.check_random_state(random_state) 10 | batch_size = 2 11 | in_shape = (4, 5, 6) 12 | out_shape = (2, 3, 5) 13 | data = tl.tensor(rng.random_sample((batch_size, ) + in_shape), dtype=tl.float32) 14 | 15 | expected_shape = (batch_size, ) + out_shape 16 | tcl = TCL(input_shape=in_shape, rank=out_shape, 
bias=False) 17 | res = tcl(data) 18 | testing.assert_(res.shape==expected_shape, 19 | msg=f'Wrong output size of TCL, expected {expected_shape} but got {res.shape}') 20 | -------------------------------------------------------------------------------- /tltorch/factorized_layers/tests/test_trl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from torch import nn 4 | 5 | from ..tensor_regression_layers import TRL 6 | from ... import FactorizedTensor 7 | 8 | import tensorly as tl 9 | tl.set_backend('pytorch') 10 | from tensorly import tenalg 11 | from tensorly import random 12 | from tensorly import testing 13 | 14 | import pytest 15 | 16 | 17 | def optimize_trl(trl, loader, lr=0.005, n_epoch=200, verbose=False): 18 | """Function that takes as input a TRL, dataset and optimizes the TRL 19 | 20 | Parameters 21 | ---------- 22 | trl : tltorch.TRL 23 | loader : Pytorch dataset, returning batches (batch, labels) 24 | lr : float, default is 0.1 25 | learning rate 26 | n_epoch : int, default is 100 27 | verbose : bool, default is False 28 | level of verbosity 29 | 30 | Returns 31 | ------- 32 | (trl, objective function loss) 33 | """ 34 | optimizer = torch.optim.Adam(trl.parameters(), lr=lr, weight_decay=1e-7) 35 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=20, cooldown=20, factor=0.5, verbose=verbose) 36 | 37 | for epoch in range(n_epoch): 38 | for i, (sample_batch, label_batch) in enumerate(loader): 39 | 40 | # Important: do not forget to reset the gradients 41 | optimizer.zero_grad() 42 | 43 | # Reconstruct the tensor from the decomposed form 44 | pred = trl(sample_batch) 45 | 46 | # squared l2 loss 47 | loss = tl.norm(pred - label_batch, 2) 48 | 49 | loss.backward() 50 | optimizer.step() 51 | scheduler.step(loss) 52 | 53 | if not i or epoch % 10 == 0: 54 | if verbose: 55 | print(f"Epoch {epoch},. 
loss: {loss.item()}") 56 | 57 | return trl, loss.item() 58 | 59 | 60 | @pytest.mark.parametrize('factorization, true_rank, rank', 61 | [('Tucker', 0.5, 0.6), #(5, 5, 5, 4, 1, 3)), 62 | ('CP', 0.5, 0.6), 63 | ('TT', 0.5, 0.6)]) 64 | def test_trl(factorization, true_rank, rank): 65 | """Test for the TRL 66 | """ 67 | # Parameter of the experiment 68 | input_shape = (3, 4) 69 | output_shape = (2, 1) 70 | batch_size = 500 71 | 72 | # fix the random seed for reproducibility 73 | random_state = 12345 74 | 75 | rng = tl.check_random_state(random_state) 76 | tol = 0.08 77 | 78 | # Generate a random tensor 79 | samples = tl.tensor(rng.normal(size=(batch_size, *input_shape), loc=0, scale=1), dtype=tl.float32) 80 | true_bias = tl.tensor(rng.uniform(size=output_shape), dtype=tl.float32) 81 | 82 | with torch.no_grad(): 83 | true_weight = FactorizedTensor.new(shape=input_shape+output_shape, 84 | rank=true_rank, factorization=factorization) 85 | true_weight = true_weight.normal_(0, 0.1).to_tensor() 86 | labels = tenalg.inner(samples, true_weight, n_modes=len(input_shape)) + true_bias 87 | 88 | dataset = data.TensorDataset(samples, labels) 89 | loader = data.DataLoader(dataset, batch_size=32) 90 | 91 | trl = TRL(input_shape=input_shape, output_shape=output_shape, factorization=factorization, rank=rank, bias=True) 92 | trl.weight.normal_(0, 0.1) # TODO: do this through reset_parameters 93 | with torch.no_grad(): 94 | trl.bias.data.uniform_(-0.01, 0.01) 95 | 96 | print(f'Testing {trl.__class__.__name__}.') 97 | #trl.init_from_random(decompose_full_weight=True) 98 | trl, _ = optimize_trl(trl, loader, verbose=False) 99 | 100 | with torch.no_grad(): 101 | rec_weights = trl.weight.to_tensor() 102 | rec_loss = tl.norm(rec_weights - true_weight)/tl.norm(true_weight) 103 | 104 | with torch.no_grad(): 105 | bias_rec_loss = tl.norm(trl.bias - true_bias)/tl.norm(true_bias) 106 | 107 | testing.assert_(rec_loss <= tol, msg=f'Rec_loss of the weights={rec_loss} higher than tolerance={tol}') 108 | testing.assert_(bias_rec_loss <= tol, msg=f'Rec_loss of the bias={bias_rec_loss} higher than tolerance={tol}') 109 | 110 | 111 | @pytest.mark.parametrize('order', [2, 3]) 112 | @pytest.mark.parametrize('project_input', [False, True]) 113 | @pytest.mark.parametrize('learn_pool', [True, False]) 114 | def test_TuckerTRL(order, project_input, learn_pool): 115 | """Test for Tucker TRL 116 | 117 | Here, we test specifically for init from fully-connected layer 118 | (both when learning the pooling and not). 
119 | 120 | We also test that projecting the input or not doesn't change the results 121 | """ 122 | in_features = 10 123 | out_features = 12 124 | batch_size = 2 125 | spatial_size = 4 126 | in_rank = 10 127 | out_rank = 12 128 | order= 2 129 | 130 | # fix the random seed for reproducibility and create random input 131 | random_state = 12345 132 | rng = tl.check_random_state(random_state) 133 | data = tl.tensor(rng.random_sample((batch_size, in_features) + (spatial_size, )*order), dtype=tl.float32) 134 | 135 | # Build a simple net with avg-pool, flatten + fully-connected 136 | if order == 2: 137 | pool = nn.AdaptiveAvgPool2d((1, 1)) 138 | else: 139 | pool = nn.AdaptiveAvgPool3d((1, 1, 1)) 140 | fc = nn.Linear(in_features, out_features, bias=False) 141 | 142 | def net(data): 143 | x = pool(data) 144 | x = x.squeeze() 145 | x = fc(x) 146 | return x 147 | 148 | res_fc = net(tl.copy(data)) 149 | 150 | # A replacement TRL 151 | out_shape = (out_features, ) 152 | 153 | if learn_pool: 154 | # Learn the average pool as part of the TRL 155 | in_shape = (in_features, ) + (spatial_size, )*order 156 | rank = (in_rank, ) + (1, )*order + (out_rank, ) 157 | unsqueezed_modes = list(range(1, order+1)) 158 | else: 159 | in_shape = (in_features, ) 160 | rank = (in_rank, out_rank) 161 | unsqueezed_modes = None 162 | data = pool(data).squeeze() 163 | 164 | trl = TRL(in_shape, out_shape, rank=rank, factorization='tucker') 165 | trl.init_from_linear(fc, unsqueezed_modes=unsqueezed_modes) 166 | res_trl = trl(data) 167 | 168 | testing.assert_array_almost_equal(res_fc, res_trl) 169 | 170 | 171 | @pytest.mark.parametrize('factorization', ['CP', 'TT']) 172 | @pytest.mark.parametrize('bias', [True, False]) 173 | def test_TRL_from_linear(factorization, bias): 174 | """Test for CP and TT TRL 175 | 176 | Here, we test specifically for init from fully-connected layer 177 | """ 178 | in_features = 10 179 | out_features = 12 180 | batch_size = 2 181 | 182 | # fix the random seed for reproducibility and create random input 183 | random_state = 12345 184 | rng = tl.check_random_state(random_state) 185 | data = tl.tensor(rng.random_sample((batch_size, in_features)), dtype=tl.float32) 186 | fc = nn.Linear(in_features, out_features, bias=bias) 187 | res_fc = fc(tl.copy(data)) 188 | trl = TRL((in_features, ), (out_features, ), rank=10, bias=bias, factorization=factorization) 189 | trl.init_from_linear(fc) 190 | res_trl = trl(data) 191 | 192 | testing.assert_array_almost_equal(res_fc, res_trl, decimal=2) 193 | 194 | -------------------------------------------------------------------------------- /tltorch/factorized_tensors/__init__.py: -------------------------------------------------------------------------------- 1 | from .factorized_tensors import (CPTensor, TuckerTensor, TTTensor, 2 | DenseTensor, FactorizedTensor) 3 | from .tensorized_matrices import (TensorizedTensor, CPTensorized, BlockTT, 4 | DenseTensorized, TuckerTensorized) 5 | from .complex_factorized_tensors import (ComplexCPTensor, ComplexTuckerTensor, 6 | ComplexTTTensor, ComplexDenseTensor) 7 | from .complex_tensorized_matrices import (ComplexCPTensorized, ComplexBlockTT, 8 | ComplexDenseTensorized, ComplexTuckerTensorized) 9 | from .init import tensor_init, cp_init, tucker_init, tt_init, block_tt_init 10 | -------------------------------------------------------------------------------- /tltorch/factorized_tensors/complex_factorized_tensors.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch 
import nn 4 | 5 | import tensorly as tl 6 | tl.set_backend('pytorch') 7 | from tltorch.factorized_tensors.factorized_tensors import TuckerTensor, CPTensor, TTTensor, DenseTensor 8 | from tltorch.utils.parameter_list import FactorList, ComplexFactorList 9 | 10 | 11 | # Author: Jean Kossaifi 12 | # License: BSD 3 clause 13 | 14 | class ComplexHandler(): 15 | def __setattr__(self, key, value): 16 | if isinstance(value, (FactorList)): 17 | value = ComplexFactorList(value) 18 | super().__setattr__(key, value) 19 | 20 | elif isinstance(value, nn.Parameter): 21 | self.register_parameter(key, value) 22 | elif torch.is_tensor(value): 23 | self.register_buffer(key, value) 24 | else: 25 | super().__setattr__(key, value) 26 | 27 | def __getattr__(self, key): 28 | value = super().__getattr__(key) 29 | if torch.is_tensor(value): 30 | value = torch.view_as_complex(value) 31 | return value 32 | 33 | def register_parameter(self, key, value): 34 | value = nn.Parameter(torch.view_as_real(value)) 35 | super().register_parameter(key, value) 36 | 37 | def register_buffer(self, key, value): 38 | value = torch.view_as_real(value) 39 | super().register_buffer(key, value) 40 | 41 | 42 | class ComplexDenseTensor(ComplexHandler, DenseTensor, name='ComplexDense'): 43 | """Complex Dense Factorization 44 | """ 45 | @classmethod 46 | def new(cls, shape, rank=None, device=None, dtype=torch.cfloat, **kwargs): 47 | return super().new(shape, rank, device=device, dtype=dtype, **kwargs) 48 | 49 | class ComplexTuckerTensor(ComplexHandler, TuckerTensor, name='ComplexTucker'): 50 | """Complex Tucker Factorization 51 | """ 52 | @classmethod 53 | def new(cls, shape, rank='same', fixed_rank_modes=None, 54 | device=None, dtype=torch.cfloat, **kwargs): 55 | return super().new(shape, rank, fixed_rank_modes=fixed_rank_modes, 56 | device=device, dtype=dtype, **kwargs) 57 | 58 | class ComplexTTTensor(ComplexHandler, TTTensor, name='ComplexTT'): 59 | """Complex TT Factorization 60 | """ 61 | @classmethod 62 | def new(cls, shape, rank='same', fixed_rank_modes=None, 63 | device=None, dtype=torch.cfloat, **kwargs): 64 | return super().new(shape, rank, 65 | device=device, dtype=dtype, **kwargs) 66 | 67 | class ComplexCPTensor(ComplexHandler, CPTensor, name='ComplexCP'): 68 | """Complex CP Factorization 69 | """ 70 | @classmethod 71 | def new(cls, shape, rank='same', fixed_rank_modes=None, 72 | device=None, dtype=torch.cfloat, **kwargs): 73 | return super().new(shape, rank, 74 | device=device, dtype=dtype, **kwargs) 75 | -------------------------------------------------------------------------------- /tltorch/factorized_tensors/complex_tensorized_matrices.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tensorly as tl 3 | tl.set_backend('pytorch') 4 | from tltorch.factorized_tensors.tensorized_matrices import TuckerTensorized, DenseTensorized, CPTensorized, BlockTT 5 | from .complex_factorized_tensors import ComplexHandler 6 | 7 | # Author: Jean Kossaifi 8 | # License: BSD 3 clause 9 | 10 | 11 | class ComplexDenseTensorized(ComplexHandler, DenseTensorized, name='ComplexDense'): 12 | """Complex DenseTensorized Factorization 13 | """ 14 | _complex_params = ['tensor'] 15 | 16 | @classmethod 17 | def new(cls, tensorized_shape, rank=None, device=None, dtype=torch.cfloat, **kwargs): 18 | return super().new(tensorized_shape, rank, device=device, dtype=dtype, **kwargs) 19 | 20 | class ComplexTuckerTensorized(ComplexHandler, TuckerTensorized, name='ComplexTucker'): 21 | """Complex 
TuckerTensorized Factorization 22 | """ 23 | _complex_params = ['core', 'factors'] 24 | 25 | @classmethod 26 | def new(cls, tensorized_shape, rank=None, device=None, dtype=torch.cfloat, **kwargs): 27 | return super().new(tensorized_shape, rank, device=device, dtype=dtype, **kwargs) 28 | 29 | class ComplexBlockTT(ComplexHandler, BlockTT, name='ComplexTT'): 30 | """Complex BlockTT Factorization 31 | """ 32 | _complex_params = ['factors'] 33 | 34 | @classmethod 35 | def new(cls, tensorized_shape, rank=None, device=None, dtype=torch.cfloat, **kwargs): 36 | return super().new(tensorized_shape, rank, device=device, dtype=dtype, **kwargs) 37 | 38 | class ComplexCPTensorized(ComplexHandler, CPTensorized, name='ComplexCP'): 39 | """Complex Tensorized CP Factorization 40 | """ 41 | _complex_params = ['weights', 'factors'] 42 | 43 | @classmethod 44 | def new(cls, tensorized_shape, rank=None, device=None, dtype=torch.cfloat, **kwargs): 45 | return super().new(tensorized_shape, rank, device=device, dtype=dtype, **kwargs) 46 | -------------------------------------------------------------------------------- /tltorch/factorized_tensors/factorized_tensors.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | 7 | import tensorly as tl 8 | tl.set_backend('pytorch') 9 | from tensorly import tenalg 10 | from tensorly.decomposition import parafac, tucker, tensor_train 11 | 12 | from .core import FactorizedTensor 13 | from ..utils import FactorList 14 | 15 | 16 | # Author: Jean Kossaifi 17 | # License: BSD 3 clause 18 | 19 | 20 | class DenseTensor(FactorizedTensor, name='Dense'): 21 | """Dense tensor 22 | """ 23 | def __init__(self, tensor, shape=None, rank=None): 24 | super().__init__() 25 | if shape is not None and rank is not None: 26 | self.shape, self.rank = shape, rank 27 | else: 28 | self.shape = tensor.shape 29 | self.rank = None 30 | self.order = len(self.shape) 31 | 32 | if isinstance(tensor, nn.Parameter): 33 | self.register_parameter('tensor', tensor) 34 | else: 35 | self.register_buffer('tensor', tensor) 36 | 37 | @classmethod 38 | def new(cls, shape, rank=None, device=None, dtype=None, **kwargs): 39 | # Register the parameters 40 | tensor = nn.Parameter(torch.empty(shape, device=device, dtype=dtype)) 41 | 42 | return cls(tensor) 43 | 44 | @classmethod 45 | def from_tensor(cls, tensor, rank='same', **kwargs): 46 | return cls(nn.Parameter(tl.copy(tensor))) 47 | 48 | def init_from_tensor(self, tensor, l2_reg=1e-5, **kwargs): 49 | with torch.no_grad(): 50 | self.tensor = nn.Parameter(tl.copy(tensor)) 51 | return self 52 | 53 | @property 54 | def decomposition(self): 55 | return self.tensor 56 | 57 | def to_tensor(self): 58 | return self.tensor 59 | 60 | def normal_(self, mean=0, std=1): 61 | with torch.no_grad(): 62 | self.tensor.data.normal_(mean, std) 63 | return self 64 | 65 | def __getitem__(self, indices): 66 | return self.__class__(self.tensor[indices]) 67 | 68 | 69 | class CPTensor(FactorizedTensor, name='CP'): 70 | """CP Factorization 71 | 72 | Parameters 73 | ---------- 74 | weights 75 | factors 76 | shape 77 | rank 78 | """ 79 | def __init__(self, weights, factors, shape=None, rank=None): 80 | super().__init__() 81 | if shape is not None and rank is not None: 82 | self.shape, self.rank = shape, rank 83 | else: 84 | self.shape, self.rank = tl.cp_tensor._validate_cp_tensor((weights, factors)) 85 | self.order = len(self.shape) 86 | 87 | # self.weights = weights 88 | if 
isinstance(weights, nn.Parameter): 89 | self.register_parameter('weights', weights) 90 | else: 91 | self.register_buffer('weights', weights) 92 | 93 | self.factors = FactorList(factors) 94 | 95 | @classmethod 96 | def new(cls, shape, rank, device=None, dtype=None, **kwargs): 97 | rank = tl.cp_tensor.validate_cp_rank(shape, rank) 98 | 99 | # Register the parameters 100 | weights = nn.Parameter(torch.empty(rank, device=device, dtype=dtype)) 101 | # Avoid the issues with ParameterList 102 | factors = [nn.Parameter(torch.empty((s, rank), device=device, dtype=dtype)) for s in shape] 103 | 104 | return cls(weights, factors) 105 | 106 | @classmethod 107 | def from_tensor(cls, tensor, rank='same', **kwargs): 108 | shape = tensor.shape 109 | rank = tl.cp_tensor.validate_cp_rank(shape, rank) 110 | dtype = tensor.dtype 111 | 112 | with torch.no_grad(): 113 | weights, factors = parafac(tensor.to(torch.float64), rank, **kwargs) 114 | 115 | return cls(nn.Parameter(weights.to(dtype).contiguous()), [nn.Parameter(f.to(dtype).contiguous()) for f in factors]) 116 | 117 | def init_from_tensor(self, tensor, l2_reg=1e-5, **kwargs): 118 | with torch.no_grad(): 119 | weights, factors = parafac(tensor, self.rank, l2_reg=l2_reg, **kwargs) 120 | 121 | self.weights = nn.Parameter(weights.contiguous()) 122 | self.factors = FactorList([nn.Parameter(f.contiguous()) for f in factors]) 123 | return self 124 | 125 | @property 126 | def decomposition(self): 127 | return self.weights, self.factors 128 | 129 | def to_tensor(self): 130 | return tl.cp_to_tensor(self.decomposition) 131 | 132 | def normal_(self, mean=0, std=1): 133 | super().normal_(mean, std) 134 | std_factors = (std/math.sqrt(self.rank))**(1/self.order) 135 | 136 | with torch.no_grad(): 137 | self.weights.fill_(1) 138 | for factor in self.factors: 139 | factor.data.normal_(0, std_factors) 140 | return self 141 | 142 | def __getitem__(self, indices): 143 | if isinstance(indices, int): 144 | # Select one dimension of one mode 145 | mixing_factor, *factors = self.factors 146 | weights = self.weights*mixing_factor[indices, :] 147 | return self.__class__(weights, factors) 148 | 149 | elif isinstance(indices, slice): 150 | # Index part of a factor 151 | mixing_factor, *factors = self.factors 152 | factors = [mixing_factor[indices, :], *factors] 153 | weights = self.weights 154 | return self.__class__(weights, factors) 155 | 156 | else: 157 | # Index multiple dimensions 158 | factors = self.factors 159 | index_factors = [] 160 | weights = self.weights 161 | for index in indices: 162 | if index is Ellipsis: 163 | raise ValueError(f'Ellipsis is not yet supported, yet got indices={indices} which contains one.') 164 | 165 | mixing_factor, *factors = factors 166 | if isinstance(index, (np.integer, int)): 167 | if factors or index_factors: 168 | weights = weights*mixing_factor[index, :] 169 | else: 170 | # No factors left 171 | return tl.sum(weights*mixing_factor[index, :]) 172 | else: 173 | index_factors.append(mixing_factor[index, :]) 174 | 175 | return self.__class__(weights, index_factors + factors) 176 | # return self.__class__(*tl.cp_indexing(self.weights, self.factors, indices)) 177 | 178 | def transduct(self, new_dim, mode=0, new_factor=None): 179 | """Transduction adds a new dimension to the existing factorization 180 | 181 | Parameters 182 | ---------- 183 | new_dim : int 184 | dimension of the new mode to add 185 | mode : where to insert the new dimension, after the channels, default is 0 186 | by default, insert the new dimensions before the existing ones 187 | 
(e.g. add time before height and width) 188 | 189 | Returns 190 | ------- 191 | self 192 | """ 193 | factors = self.factors 194 | # Important: don't increment the order before accessing factors which uses order! 195 | self.order += 1 196 | self.shape = self.shape[:mode] + (new_dim,) + self.shape[mode:] 197 | 198 | if new_factor is None: 199 | new_factor = torch.ones(new_dim, self.rank)#/new_dim 200 | 201 | factors.insert(mode, nn.Parameter(new_factor.to(factors[0].device).contiguous())) 202 | self.factors = FactorList(factors) 203 | 204 | return self 205 | 206 | 207 | class TuckerTensor(FactorizedTensor, name='Tucker'): 208 | """Tucker Factorization 209 | 210 | Parameters 211 | ---------- 212 | core 213 | factors 214 | shape 215 | rank 216 | """ 217 | def __init__(self, core, factors, shape=None, rank=None): 218 | super().__init__() 219 | if shape is not None and rank is not None: 220 | self.shape, self.rank = shape, rank 221 | else: 222 | self.shape, self.rank = tl.tucker_tensor._validate_tucker_tensor((core, factors)) 223 | 224 | self.order = len(self.shape) 225 | # self.core = core 226 | if isinstance(core, nn.Parameter): 227 | self.register_parameter('core', core) 228 | else: 229 | self.register_buffer('core', core) 230 | 231 | self.factors = FactorList(factors) 232 | 233 | @classmethod 234 | def new(cls, shape, rank, fixed_rank_modes=None, 235 | device=None, dtype=None, **kwargs): 236 | rank = tl.tucker_tensor.validate_tucker_rank(shape, rank, fixed_modes=fixed_rank_modes) 237 | 238 | # Register the parameters 239 | core = nn.Parameter(torch.empty(rank, device=device, dtype=dtype)) 240 | # Avoid the issues with ParameterList 241 | factors = [nn.Parameter(torch.empty((s, r), device=device, dtype=dtype)) for (s, r) in zip(shape, rank)] 242 | 243 | return cls(core, factors) 244 | 245 | @classmethod 246 | def from_tensor(cls, tensor, rank='same', fixed_rank_modes=None, **kwargs): 247 | shape = tensor.shape 248 | rank = tl.tucker_tensor.validate_tucker_rank(shape, rank, fixed_modes=fixed_rank_modes) 249 | 250 | with torch.no_grad(): 251 | core, factors = tucker(tensor, rank, **kwargs) 252 | 253 | return cls(nn.Parameter(core.contiguous()), [nn.Parameter(f.contiguous()) for f in factors]) 254 | 255 | def init_from_tensor(self, tensor, unsqueezed_modes=None, unsqueezed_init='average', **kwargs): 256 | """Initialize the tensor factorization from a tensor 257 | 258 | Parameters 259 | ---------- 260 | tensor : torch.Tensor 261 | full tensor to decompose 262 | unsqueezed_modes : int list 263 | list of modes for which the rank is 1 that don't correspond to a mode in the full tensor 264 | essentially we are adding a new dimension for which the core has dim 1, 265 | and that is not initialized through decomposition. 266 | Instead first `tensor` is decomposed into the other factors. 267 | The `unsqueezed factors` are then added and initialized e.g. with 1/dim[i] 268 | unsqueezed_init : 'average' or float 269 | if unsqueezed_modes, this is how the added "unsqueezed" factors will be initialized 270 | if 'average', then unsqueezed_factor[i] will have value 1/tensor.shape[i] 271 | """ 272 | if unsqueezed_modes is not None: 273 | unsqueezed_modes = sorted(unsqueezed_modes) 274 | for mode in unsqueezed_modes[::-1]: 275 | if self.rank[mode] != 1: 276 | msg = 'It is only possible to initialize by averagig over mode for which rank=1.' 277 | msg += f'However, got unsqueezed_modes={unsqueezed_modes} but rank[{mode}]={self.rank[mode]} != 1.' 
278 | raise ValueError(msg) 279 | 280 | rank = tuple(r for (i, r) in enumerate(self.rank) if i not in unsqueezed_modes) 281 | else: 282 | rank = self.rank 283 | 284 | with torch.no_grad(): 285 | core, factors = tucker(tensor, rank, **kwargs) 286 | 287 | if unsqueezed_modes is not None: 288 | # Initialise with 1/shape[mode] or given value 289 | for mode in unsqueezed_modes: 290 | size = self.shape[mode] 291 | factor = torch.ones(size, 1) 292 | if unsqueezed_init == 'average': 293 | factor /= size 294 | else: 295 | factor *= unsqueezed_init 296 | factors.insert(mode, factor) 297 | core = core.unsqueeze(mode) 298 | 299 | self.core = nn.Parameter(core.contiguous()) 300 | self.factors = FactorList([nn.Parameter(f.contiguous()) for f in factors]) 301 | return self 302 | 303 | @property 304 | def decomposition(self): 305 | return self.core, self.factors 306 | 307 | def to_tensor(self): 308 | return tl.tucker_to_tensor(self.decomposition) 309 | 310 | def normal_(self, mean=0, std=1): 311 | if mean != 0: 312 | raise ValueError(f'Currently only mean=0 is supported, but got mean={mean}') 313 | 314 | r = np.prod([math.sqrt(r) for r in self.rank]) 315 | std_factors = (std/r)**(1/(self.order+1)) 316 | 317 | with torch.no_grad(): 318 | self.core.data.normal_(0, std_factors) 319 | for factor in self.factors: 320 | factor.data.normal_(0, std_factors) 321 | return self 322 | 323 | def __getitem__(self, indices): 324 | if isinstance(indices, int): 325 | # Select one dimension of one mode 326 | mixing_factor, *factors = self.factors 327 | core = tenalg.mode_dot(self.core, mixing_factor[indices, :], 0) 328 | return self.__class__(core, factors) 329 | 330 | elif isinstance(indices, slice): 331 | mixing_factor, *factors = self.factors 332 | factors = [mixing_factor[indices, :], *factors] 333 | return self.__class__(self.core, factors) 334 | 335 | else: 336 | # Index multiple dimensions 337 | modes = [] 338 | factors = [] 339 | factors_contract = [] 340 | for i, (index, factor) in enumerate(zip(indices, self.factors)): 341 | if index is Ellipsis: 342 | raise ValueError(f'Ellipsis is not yet supported, yet got indices={indices}, indices[{i}]={index}.') 343 | if isinstance(index, int): 344 | modes.append(i) 345 | factors_contract.append(factor[index, :]) 346 | else: 347 | factors.append(factor[index, :]) 348 | 349 | if modes: 350 | core = tenalg.multi_mode_dot(self.core, factors_contract, modes=modes) 351 | else: 352 | core = self.core 353 | factors = factors + self.factors[i+1:] 354 | 355 | if factors: 356 | return self.__class__(core, factors) 357 | 358 | # Fully contracted tensor 359 | return core 360 | 361 | 362 | class TTTensor(FactorizedTensor, name='TT'): 363 | """Tensor-Train (Matrix-Product-State) Factorization 364 | 365 | Parameters 366 | ---------- 367 | factors 368 | shape 369 | rank 370 | """ 371 | def __init__(self, factors, shape=None, rank=None): 372 | super().__init__() 373 | if shape is None or rank is None: 374 | self.shape, self.rank = tl.tt_tensor._validate_tt_tensor(factors) 375 | else: 376 | self.shape, self.rank = shape, rank 377 | 378 | self.order = len(self.shape) 379 | self.factors = FactorList(factors) 380 | 381 | @classmethod 382 | def new(cls, shape, rank, device=None, dtype=None, **kwargs): 383 | rank = tl.tt_tensor.validate_tt_rank(shape, rank) 384 | 385 | # Avoid the issues with ParameterList 386 | factors = [nn.Parameter(torch.empty((rank[i], s, rank[i+1]), device=device, dtype=dtype)) for i, s in enumerate(shape)] 387 | 388 | return cls(factors) 389 | 390 | @classmethod 391 | def 
from_tensor(cls, tensor, rank='same', **kwargs): 392 | shape = tensor.shape 393 | rank = tl.tt_tensor.validate_tt_rank(shape, rank) 394 | 395 | with torch.no_grad(): 396 | # TODO: deal properly with wrong kwargs 397 | factors = tensor_train(tensor, rank) 398 | 399 | return cls([nn.Parameter(f.contiguous()) for f in factors]) 400 | 401 | def init_from_tensor(self, tensor, **kwargs): 402 | with torch.no_grad(): 403 | # TODO: deal properly with wrong kwargs 404 | factors = tensor_train(tensor, self.rank) 405 | 406 | self.factors = FactorList([nn.Parameter(f.contiguous()) for f in factors]) 407 | self.rank = tuple([f.shape[0] for f in factors] + [1]) 408 | return self 409 | 410 | @property 411 | def decomposition(self): 412 | return self.factors 413 | 414 | def to_tensor(self): 415 | return tl.tt_to_tensor(self.decomposition) 416 | 417 | def normal_(self, mean=0, std=1): 418 | if mean != 0: 419 | raise ValueError(f'Currently only mean=0 is supported, but got mean={mean}') 420 | 421 | r = np.prod(self.rank) 422 | std_factors = (std/r)**(1/self.order) 423 | with torch.no_grad(): 424 | for factor in self.factors: 425 | factor.data.normal_(0, std_factors) 426 | return self 427 | 428 | def __getitem__(self, indices): 429 | if isinstance(indices, int): 430 | # Select one dimension of one mode 431 | factor, next_factor, *factors = self.factors 432 | next_factor = tenalg.mode_dot(next_factor, factor[:, indices, :].squeeze(1), 0) 433 | return self.__class__([next_factor, *factors]) 434 | 435 | elif isinstance(indices, slice): 436 | mixing_factor, *factors = self.factors 437 | factors = [mixing_factor[:, indices], *factors] 438 | return self.__class__(factors) 439 | 440 | else: 441 | factors = [] 442 | all_contracted = True 443 | for i, index in enumerate(indices): 444 | if index is Ellipsis: 445 | raise ValueError(f'Ellipsis is not yet supported, yet got indices={indices}, indices[{i}]={index}.') 446 | if isinstance(index, int): 447 | if i: 448 | factor = tenalg.mode_dot(factor, self.factors[i][:, index, :].T, -1) 449 | else: 450 | factor = self.factors[i][:, index, :] 451 | else: 452 | if i: 453 | if all_contracted: 454 | factor = tenalg.mode_dot(self.factors[i][:, index, :], factor, 0) 455 | else: 456 | factors.append(factor) 457 | factor = self.factors[i][:, index, :] 458 | else: 459 | factor = self.factors[i][:, index, :] 460 | all_contracted = False 461 | 462 | if factor.ndim == 2: # We have contracted all cores, so have a 2D matrix 463 | if self.order == (i+1): 464 | # No factors left 465 | return factor.squeeze() 466 | else: 467 | next_factor, *factors = self.factors[i+1:] 468 | factor = tenalg.mode_dot(next_factor, factor, 0) 469 | return self.__class__([factor, *factors]) 470 | else: 471 | return self.__class__([*factors, factor, *self.factors[i+1:]]) 472 | 473 | def transduct(self, new_dim, mode=0, new_factor=None): 474 | """Transduction adds a new dimension to the existing factorization 475 | 476 | Parameters 477 | ---------- 478 | new_dim : int 479 | dimension of the new mode to add 480 | mode : where to insert the new dimension, after the channels, default is 0 481 | by default, insert the new dimensions before the existing ones 482 | (e.g. add time before height and width) 483 | 484 | Returns 485 | ------- 486 | self 487 | """ 488 | factors = self.factors 489 | 490 | # Important: don't increment the order before accessing factors which uses order! 
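        # Bookkeeping note: the inserted mode reuses the neighbouring TT-rank at `mode`,
        # so the new core created below has shape (new_rank, new_dim, new_rank).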
491 |         self.order += 1
492 |         new_rank = self.rank[mode]
493 |         self.rank = self.rank[:mode] + (new_rank, ) + self.rank[mode:]
494 |         self.shape = self.shape[:mode] + (new_dim, ) + self.shape[mode:]
495 | 
496 |         # Init so the reconstruction is equivalent to concatenating the previous self new_dim times
497 |         if new_factor is None:
498 |             new_factor = torch.zeros(new_rank, new_dim, new_rank)
499 |             for i in range(new_dim):
500 |                 new_factor[:, i, :] = torch.eye(new_rank)#/new_dim
501 |             # Below: <=> static prediction
502 |             # new_factor[:, new_dim//2, :] = torch.eye(new_rank)
503 | 
504 |         factors.insert(mode, nn.Parameter(new_factor.to(factors[0].device).contiguous()))
505 |         self.factors = FactorList(factors)
506 | 
507 |         return self
508 | 
--------------------------------------------------------------------------------
/tltorch/factorized_tensors/init.py:
--------------------------------------------------------------------------------
1 | """Module for initializing tensor decompositions
2 | """
3 | 
4 | # Author: Jean Kossaifi
5 | # License: BSD 3 clause
6 | 
7 | import torch
8 | import math
9 | import numpy as np
10 | 
11 | import tensorly as tl
12 | tl.set_backend('pytorch')
13 | 
14 | def tensor_init(tensor, std=0.02):
15 |     """Initializes directly the parameters of a factorized tensor so the reconstruction has the specified standard deviation and 0 mean
16 | 
17 |     Parameters
18 |     ----------
19 |     tensor : torch.Tensor or FactorizedTensor
20 |     std : float, default is 0.02
21 |         the desired standard deviation of the full (reconstructed) tensor
22 |     """
23 |     from .factorized_tensors import FactorizedTensor
24 | 
25 |     if isinstance(tensor, FactorizedTensor):
26 |         tensor.normal_(0, std)
27 |     elif torch.is_tensor(tensor):
28 |         tensor.normal_(0, std)
29 |     else:
30 |         raise ValueError(f'Got tensor of class {tensor.__class__.__name__} but expected torch.Tensor or FactorizedTensor.')
31 | 
32 | 
33 | def cp_init(cp_tensor, std=0.02):
34 |     """Initializes directly the weights and factors of a CP decomposition so the reconstruction has the specified std and 0 mean
35 | 
36 |     Parameters
37 |     ----------
38 |     cp_tensor : CPTensor
39 |     std : float, default is 0.02
40 |         the desired standard deviation of the full (reconstructed) tensor
41 | 
42 |     Notes
43 |     -----
44 |     We assume the given (weights, factors) form a correct CP decomposition, no checks are done here.
45 |     """
46 |     rank = cp_tensor.rank # We assume we are given a valid CP
47 |     order = cp_tensor.order
48 |     std_factors = (std/math.sqrt(rank))**(1/order)
49 | 
50 |     with torch.no_grad():
51 |         cp_tensor.weights.fill_(1)
52 |         for factor in cp_tensor.factors:
53 |             factor.normal_(0, std_factors)
54 |     return cp_tensor
55 | 
56 | def tucker_init(tucker_tensor, std=0.02):
57 |     """Initializes directly the core and factors of a Tucker decomposition so the reconstruction has the specified std and 0 mean
58 | 
59 |     Parameters
60 |     ----------
61 |     tucker_tensor : TuckerTensor
62 |     std : float, default is 0.02
63 |         the desired standard deviation of the full (reconstructed) tensor
64 | 
65 |     Notes
66 |     -----
67 |     We assume the given (core, factors) form a correct Tucker decomposition, no checks are done here.
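
    Examples
    --------
    A minimal usage sketch (assuming a Tucker tensor created through
    ``FactorizedTensor.new``, as done in the tests of this package); after the
    init, the reconstruction's standard deviation should be close to ``std``:

    >>> from tltorch.factorized_tensors import FactorizedTensor
    >>> t = FactorizedTensor.new((8, 8, 8), rank=(4, 4, 4), factorization='tucker')
    >>> t = tucker_init(t, std=0.02)
    >>> # t.to_tensor().std() is then approximately 0.02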
68 | """ 69 | order = tucker_tensor.order 70 | rank = tucker_tensor.rank 71 | r = np.prod([math.sqrt(r) for r in rank]) 72 | std_factors = (std/r)**(1/(order+1)) 73 | with torch.no_grad(): 74 | tucker_tensor.core.normal_(0, std_factors) 75 | for factor in tucker_tensor.factors: 76 | factor.normal_(0, std_factors) 77 | return tucker_tensor 78 | 79 | def tt_init(tt_tensor, std=0.02): 80 | """Initializes directly the weights and factors of a TT decomposition so the reconstruction has the specified std and 0 mean 81 | 82 | Parameters 83 | ---------- 84 | tt_tensor : TTTensor 85 | std : float, default is 0.02 86 | the desired standard deviation of the full (reconstructed) tensor 87 | 88 | Notes 89 | ----- 90 | We assume the given factors form a correct TT decomposition, no checks are done here. 91 | """ 92 | order = tt_tensor.order 93 | r = np.prod(tt_tensor.rank) 94 | std_factors = (std/r)**(1/order) 95 | with torch.no_grad(): 96 | for factor in tt_tensor.factors: 97 | factor.normal_(0, std_factors) 98 | return tt_tensor 99 | 100 | 101 | def block_tt_init(block_tt, std=0.02): 102 | """Initializes directly the weights and factors of a BlockTT decomposition so the reconstruction has the specified std and 0 mean 103 | 104 | Parameters 105 | ---------- 106 | block_tt : Matrix in the tensor-train format 107 | std : float, default is 0.02 108 | the desired standard deviation of the full (reconstructed) tensor 109 | 110 | Notes 111 | ----- 112 | We assume the given factors form a correct Block-TT decomposition, no checks are done here. 113 | """ 114 | return tt_init(block_tt, std=std) -------------------------------------------------------------------------------- /tltorch/factorized_tensors/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/tltorch/factorized_tensors/tests/__init__.py -------------------------------------------------------------------------------- /tltorch/factorized_tensors/tests/test_factorizations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import math 4 | import torch 5 | from torch import testing 6 | 7 | from tltorch.factorized_tensors.tensorized_matrices import CPTensorized, TuckerTensorized, BlockTT 8 | from tltorch.factorized_tensors.core import TensorizedTensor 9 | 10 | from ..factorized_tensors import FactorizedTensor, CPTensor, TuckerTensor, TTTensor 11 | 12 | 13 | @pytest.mark.parametrize('factorization', ['CP', 'Tucker', 'TT']) 14 | def test_FactorizedTensor(factorization): 15 | """Test for FactorizedTensor""" 16 | shape = (4, 3, 2, 5) 17 | fact_tensor = FactorizedTensor.new(shape=shape, rank='same', factorization=factorization) 18 | fact_tensor.normal_() 19 | 20 | # Check that the correct type of factorized tensor is created 21 | assert fact_tensor._name.lower() == factorization.lower() 22 | mapping = dict(CP=CPTensor, Tucker=TuckerTensor, TT=TTTensor) 23 | assert isinstance(fact_tensor, mapping[factorization]) 24 | 25 | # Check the shape of the factorized tensor and reconstruction 26 | reconstruction = fact_tensor.to_tensor() 27 | assert fact_tensor.shape == reconstruction.shape == shape 28 | 29 | # Check that indexing the factorized tensor returns the correct result 30 | # np_s converts intuitive array indexing to proper tuples 31 | indices = [ 32 | np.s_[:, 2, :], # = (slice(None), 2, slice(None)) 33 | np.s_[2:, :2, :, 0], 34 | np.s_[0, 2, 1, 
3], 35 | np.s_[-1, :, :, ::2] 36 | ] 37 | for idx in indices: 38 | assert reconstruction[idx].shape == fact_tensor[idx].shape 39 | res = fact_tensor[idx] 40 | if not torch.is_tensor(res): 41 | res = res.to_tensor() 42 | testing.assert_close(reconstruction[idx], res) 43 | 44 | 45 | @pytest.mark.parametrize('factorization', ['BlockTT', 'CP']) #['CP', 'Tucker', 'BlockTT']) 46 | @pytest.mark.parametrize('batch_size', [(), (4,)]) 47 | def test_TensorizedMatrix(factorization, batch_size): 48 | """Test for TensorizedMatrix""" 49 | row_tensor_shape = (4, 3, 2) 50 | column_tensor_shape = (5, 3, 2) 51 | row_shape = math.prod(row_tensor_shape) 52 | column_shape = math.prod(column_tensor_shape) 53 | tensor_shape = batch_size + (row_tensor_shape, column_tensor_shape) 54 | 55 | fact_tensor = TensorizedTensor.new(tensor_shape, 56 | rank=0.5, factorization=factorization) 57 | fact_tensor.normal_() 58 | 59 | # Check that the correct type of factorized tensor is created 60 | assert fact_tensor._name.lower() == factorization.lower() 61 | mapping = dict(CP=CPTensorized, Tucker=TuckerTensorized, BlockTT=BlockTT) 62 | assert isinstance(fact_tensor, mapping[factorization]) 63 | 64 | # Check that the matrix has the right shape 65 | reconstruction = fact_tensor.to_matrix() 66 | if batch_size: 67 | assert fact_tensor.shape[1] == row_shape == reconstruction.shape[1] 68 | assert fact_tensor.shape[2] == column_shape == reconstruction.shape[2] 69 | assert fact_tensor.ndim == 3 70 | else: 71 | assert fact_tensor.shape[0] == row_shape == reconstruction.shape[0] 72 | assert fact_tensor.shape[1] == column_shape == reconstruction.shape[1] 73 | assert fact_tensor.ndim == 2 74 | 75 | # Check that indexing the factorized tensor returns the correct result 76 | # np_s converts intuitive array indexing to proper tuples 77 | indices = [ 78 | np.s_[:, :], # = (slice(None), slice(None)) 79 | np.s_[:, 2], 80 | np.s_[2, 3], 81 | np.s_[1, :] 82 | ] 83 | for idx in indices: 84 | assert tuple(reconstruction[idx].shape) == tuple(fact_tensor[idx].shape) 85 | res = fact_tensor[idx] 86 | if not torch.is_tensor(res): 87 | res = res.to_matrix() 88 | testing.assert_close(reconstruction[idx], res) 89 | 90 | 91 | @pytest.mark.parametrize('factorization', ['CP', 'TT']) 92 | def test_transduction(factorization): 93 | """Test for transduction""" 94 | shape = (3, 4, 5) 95 | new_dim = 2 96 | for mode in range(3): 97 | fact_tensor = FactorizedTensor.new(shape=shape, rank=6, factorization=factorization) 98 | fact_tensor.normal_() 99 | original_rec = fact_tensor.to_tensor() 100 | fact_tensor = fact_tensor.transduct(new_dim, mode=mode) 101 | rec = fact_tensor.to_tensor() 102 | true_shape = list(shape); true_shape.insert(mode, new_dim) 103 | assert tuple(fact_tensor.shape) == tuple(rec.shape) == tuple(true_shape) 104 | 105 | indices = [slice(None)]*mode 106 | for i in range(new_dim): 107 | testing.assert_close(original_rec, rec[tuple(indices + [i])]) 108 | 109 | @pytest.mark.parametrize('unsqueezed_init', ['average', 1.2]) 110 | def test_tucker_init_unsqueezed_modes(unsqueezed_init): 111 | """Test for Tucker Factorization init from tensor with unsqueezed_modes 112 | """ 113 | tensor = FactorizedTensor.new((4, 4, 4), rank=(4, 1, 4), factorization='tucker') 114 | mat = torch.randn((4, 4)) 115 | 116 | tensor.init_from_tensor(mat, unsqueezed_modes=[1], unsqueezed_init=unsqueezed_init) 117 | rec = tensor.to_tensor() 118 | 119 | if unsqueezed_init == 'average': 120 | coef = 1/4 121 | else: 122 | coef = unsqueezed_init 123 | 124 | for i in range(4): 125 | 
testing.assert_close(rec[:, i], mat*coef) 126 | 127 | 128 | @pytest.mark.parametrize('factorization', ['ComplexCP', 'ComplexTucker', 'ComplexTT', 'ComplexDense']) 129 | def test_ComplexFactorizedTensor(factorization): 130 | """Test for ComplexFactorizedTensor""" 131 | shape = (4, 3, 2, 5) 132 | fact_tensor = FactorizedTensor.new(shape=shape, rank='same', factorization=factorization) 133 | fact_tensor.normal_() 134 | 135 | assert fact_tensor.to_tensor().shape == shape 136 | assert fact_tensor.to_tensor().dtype == torch.cfloat 137 | for param in fact_tensor.parameters(): 138 | assert param.dtype == torch.float32 139 | -------------------------------------------------------------------------------- /tltorch/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .convolution import convolve, tucker_conv 2 | from .linear import factorized_linear -------------------------------------------------------------------------------- /tltorch/functional/convolution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import tensorly as tl 5 | tl.set_backend('pytorch') 6 | from tensorly import tenalg 7 | 8 | from ..factorized_tensors import CPTensor, TTTensor, TuckerTensor, DenseTensor 9 | 10 | # Author: Jean Kossaifi 11 | # License: BSD 3 clause 12 | 13 | 14 | _CONVOLUTION = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d} 15 | 16 | def convolve(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 17 | """Convolution of any specified order, wrapper on torch's F.convNd 18 | 19 | Parameters 20 | ---------- 21 | x : torch.Tensor or FactorizedTensor 22 | input tensor 23 | weight : torch.Tensor 24 | convolutional weights 25 | bias : bool, optional 26 | by default None 27 | stride : int, optional 28 | by default 1 29 | padding : int, optional 30 | by default 0 31 | dilation : int, optional 32 | by default 1 33 | groups : int, optional 34 | by default 1 35 | 36 | Returns 37 | ------- 38 | torch.Tensor 39 | `x` convolved with `weight` 40 | """ 41 | try: 42 | if torch.is_tensor(weight): 43 | return _CONVOLUTION[weight.ndim - 2](x, weight, bias=bias, stride=stride, padding=padding, 44 | dilation=dilation, groups=groups) 45 | else: 46 | if isinstance(weight, TTTensor): 47 | weight = tl.moveaxis(weight.to_tensor(), -1, 0) 48 | else: 49 | weight = weight.to_tensor() 50 | return _CONVOLUTION[weight.ndim - 2](x, weight, bias=bias, stride=stride, padding=padding, 51 | dilation=dilation, groups=groups) 52 | except KeyError: 53 | raise ValueError(f'Got tensor of order={weight.ndim} but pytorch only supports up to 3rd order (3D) Convs.') 54 | 55 | 56 | def general_conv1d_(x, kernel, mode, bias=None, stride=1, padding=0, groups=1, dilation=1, verbose=False): 57 | """General 1D convolution along the mode-th dimension 58 | 59 | Parameters 60 | ---------- 61 | x : batch-dize, in_channels, K1, ..., KN 62 | kernel : out_channels, in_channels/groups, K{mode} 63 | mode : int 64 | weight along which to perform the decomposition 65 | stride : int 66 | padding : int 67 | groups : 1 68 | typically would be equal to thhe number of input-channels 69 | at least for CP convolutions 70 | 71 | Returns 72 | ------- 73 | x convolved with the given kernel, along dimension `mode` 74 | """ 75 | if verbose: 76 | print(f'Convolving {x.shape} with {kernel.shape} along mode {mode}, ' 77 | f'stride={stride}, padding={padding}, groups={groups}') 78 | 79 | in_channels = tl.shape(x)[1] 80 | n_dim = 
tl.ndim(x) 81 | permutation = list(range(n_dim)) 82 | spatial_dim = permutation.pop(mode) 83 | channels_dim = permutation.pop(1) 84 | permutation += [channels_dim, spatial_dim] 85 | x = tl.transpose(x, permutation) 86 | x_shape = list(x.shape) 87 | x = tl.reshape(x, (-1, in_channels, x_shape[-1])) 88 | x = F.conv1d(x.contiguous(), kernel, bias=bias, stride=stride, dilation=dilation, padding=padding, groups=groups) 89 | x_shape[-2:] = x.shape[-2:] 90 | x = tl.reshape(x, x_shape) 91 | permutation = list(range(n_dim))[:-2] 92 | permutation.insert(1, n_dim - 2) 93 | permutation.insert(mode, n_dim - 1) 94 | x = tl.transpose(x, permutation) 95 | 96 | return x 97 | 98 | 99 | def general_conv1d(x, kernel, mode, bias=None, stride=1, padding=0, groups=1, dilation=1, verbose=False): 100 | """General 1D convolution along the mode-th dimension 101 | 102 | Uses an ND convolution under the hood 103 | 104 | Parameters 105 | ---------- 106 | x : batch-dize, in_channels, K1, ..., KN 107 | kernel : out_channels, in_channels/groups, K{mode} 108 | mode : int 109 | weight along which to perform the decomposition 110 | stride : int 111 | padding : int 112 | groups : 1 113 | typically would be equal to the number of input-channels 114 | at least for CP convolutions 115 | 116 | Returns 117 | ------- 118 | x convolved with the given kernel, along dimension `mode` 119 | """ 120 | if verbose: 121 | print(f'Convolving {x.shape} with {kernel.shape} along mode {mode}, ' 122 | f'stride={stride}, padding={padding}, groups={groups}') 123 | 124 | def _pad_value(value, mode, order, padding=1): 125 | return tuple([value if i == (mode - 2) else padding for i in range(order)]) 126 | 127 | ndim = tl.ndim(x) 128 | order = ndim - 2 129 | for i in range(2, ndim): 130 | if i != mode: 131 | kernel = kernel.unsqueeze(i) 132 | 133 | return _CONVOLUTION[order](x, kernel, bias=bias, 134 | stride=_pad_value(stride, mode, order), 135 | padding=_pad_value(padding, mode, order, padding=0), 136 | dilation=_pad_value(dilation, mode, order), 137 | groups=groups) 138 | 139 | 140 | def tucker_conv(x, tucker_tensor, bias=None, stride=1, padding=0, dilation=1): 141 | # Extract the rank from the actual decomposition in case it was changed by, e.g. 
dropout 142 | rank = tucker_tensor.rank 143 | 144 | batch_size = x.shape[0] 145 | n_dim = tl.ndim(x) 146 | 147 | # Change the number of channels to the rank 148 | x_shape = list(x.shape) 149 | x = x.reshape((batch_size, x_shape[1], -1)).contiguous() 150 | 151 | # This can be done with a tensor contraction 152 | # First conv == tensor contraction 153 | # from (in_channels, rank) to (rank == out_channels, in_channels, 1) 154 | x = F.conv1d(x, tl.transpose(tucker_tensor.factors[1]).unsqueeze(2)) 155 | 156 | x_shape[1] = rank[1] 157 | x = x.reshape(x_shape) 158 | 159 | modes = list(range(2, n_dim+1)) 160 | weight = tl.tenalg.multi_mode_dot(tucker_tensor.core, tucker_tensor.factors[2:], modes=modes) 161 | x = convolve(x, weight, bias=None, stride=stride, padding=padding, dilation=dilation) 162 | 163 | # Revert back number of channels from rank to output_channels 164 | x_shape = list(x.shape) 165 | x = x.reshape((batch_size, x_shape[1], -1)) 166 | # Last conv == tensor contraction 167 | # From (out_channels, rank) to (out_channels, in_channels == rank, 1) 168 | x = F.conv1d(x, tucker_tensor.factors[0].unsqueeze(2), bias=bias) 169 | 170 | x_shape[1] = x.shape[1] 171 | x = x.reshape(x_shape) 172 | 173 | return x 174 | 175 | 176 | def tt_conv(x, tt_tensor, bias=None, stride=1, padding=0, dilation=1): 177 | """Perform a factorized tt convolution 178 | 179 | Parameters 180 | ---------- 181 | x : torch.tensor 182 | tensor of shape (batch_size, C, I_2, I_3, ..., I_N) 183 | 184 | Returns 185 | ------- 186 | NDConv(x) with an tt kernel 187 | """ 188 | shape = tt_tensor.shape 189 | rank = tt_tensor.rank 190 | 191 | batch_size = x.shape[0] 192 | order = len(shape) - 2 193 | 194 | if isinstance(padding, int): 195 | padding = (padding, )*order 196 | if isinstance(stride, int): 197 | stride = (stride, )*order 198 | if isinstance(dilation, int): 199 | dilation = (dilation, )*order 200 | 201 | # Change the number of channels to the rank 202 | x_shape = list(x.shape) 203 | x = x.reshape((batch_size, x_shape[1], -1)).contiguous() 204 | 205 | # First conv == tensor contraction 206 | # from (1, in_channels, rank) to (rank == out_channels, in_channels, 1) 207 | x = F.conv1d(x, tl.transpose(tt_tensor.factors[0], [2, 1, 0])) 208 | 209 | x_shape[1] = x.shape[1]#rank[1] 210 | x = x.reshape(x_shape) 211 | 212 | # convolve over non-channels 213 | for i in range(order): 214 | # From (in_rank, kernel_size, out_rank) to (out_rank, in_rank, kernel_size) 215 | kernel = tl.transpose(tt_tensor.factors[i+1], [2, 0, 1]) 216 | x = general_conv1d(x.contiguous(), kernel, i+2, stride=stride[i], padding=padding[i], dilation=dilation[i]) 217 | 218 | # Revert back number of channels from rank to output_channels 219 | x_shape = list(x.shape) 220 | x = x.reshape((batch_size, x_shape[1], -1)) 221 | # Last conv == tensor contraction 222 | # From (rank, out_channels, 1) to (out_channels, in_channels == rank, 1) 223 | x = F.conv1d(x, tl.transpose(tt_tensor.factors[-1], [1, 0, 2]), bias=bias) 224 | 225 | x_shape[1] = x.shape[1] 226 | x = x.reshape(x_shape) 227 | 228 | return x 229 | 230 | 231 | def cp_conv(x, cp_tensor, bias=None, stride=1, padding=0, dilation=1): 232 | """Perform a factorized CP convolution 233 | 234 | Parameters 235 | ---------- 236 | x : torch.tensor 237 | tensor of shape (batch_size, C, I_2, I_3, ..., I_N) 238 | 239 | Returns 240 | ------- 241 | NDConv(x) with an CP kernel 242 | """ 243 | shape = cp_tensor.shape 244 | rank = cp_tensor.rank 245 | 246 | batch_size = x.shape[0] 247 | order = len(shape) - 2 248 | 249 | if 
isinstance(padding, int): 250 | padding = (padding, )*order 251 | if isinstance(stride, int): 252 | stride = (stride, )*order 253 | if isinstance(dilation, int): 254 | dilation = (dilation, )*order 255 | 256 | # Change the number of channels to the rank 257 | x_shape = list(x.shape) 258 | x = x.reshape((batch_size, x_shape[1], -1)).contiguous() 259 | 260 | # First conv == tensor contraction 261 | # from (in_channels, rank) to (rank == out_channels, in_channels, 1) 262 | x = F.conv1d(x, tl.transpose(cp_tensor.factors[1]).unsqueeze(2)) 263 | 264 | x_shape[1] = rank 265 | x = x.reshape(x_shape) 266 | 267 | # convolve over non-channels 268 | for i in range(order): 269 | # From (kernel_size, rank) to (rank, 1, kernel_size) 270 | kernel = tl.transpose(cp_tensor.factors[i+2]).unsqueeze(1) 271 | x = general_conv1d(x.contiguous(), kernel, i+2, stride=stride[i], padding=padding[i], dilation=dilation[i], groups=rank) 272 | 273 | # Revert back number of channels from rank to output_channels 274 | x_shape = list(x.shape) 275 | x = x.reshape((batch_size, x_shape[1], -1)) 276 | # Last conv == tensor contraction 277 | # From (out_channels, rank) to (out_channels, in_channels == rank, 1) 278 | x = F.conv1d(x*cp_tensor.weights.unsqueeze(1).unsqueeze(0), cp_tensor.factors[0].unsqueeze(2), bias=bias) 279 | 280 | x_shape[1] = x.shape[1] # = out_channels 281 | x = x.reshape(x_shape) 282 | 283 | return x 284 | 285 | 286 | def cp_conv_mobilenet(x, cp_tensor, bias=None, stride=1, padding=0, dilation=1): 287 | """Perform a factorized CP convolution 288 | 289 | Parameters 290 | ---------- 291 | x : torch.tensor 292 | tensor of shape (batch_size, C, I_2, I_3, ..., I_N) 293 | 294 | Returns 295 | ------- 296 | NDConv(x) with an CP kernel 297 | """ 298 | factors = cp_tensor.factors 299 | shape = cp_tensor.shape 300 | rank = cp_tensor.rank 301 | 302 | batch_size = x.shape[0] 303 | order = len(shape) - 2 304 | 305 | # Change the number of channels to the rank 306 | x_shape = list(x.shape) 307 | x = x.reshape((batch_size, x_shape[1], -1)).contiguous() 308 | 309 | # First conv == tensor contraction 310 | # from (in_channels, rank) to (rank == out_channels, in_channels, 1) 311 | x = F.conv1d(x, tl.transpose(factors[1]).unsqueeze(2)) 312 | 313 | x_shape[1] = rank 314 | x = x.reshape(x_shape) 315 | 316 | # convolve over merged actual dimensions 317 | # Spatial convs 318 | # From (kernel_size, rank) to (out_rank, 1, kernel_size) 319 | if order == 1: 320 | weight = tl.transpose(factors[2]).unsqueeze(1) 321 | x = F.conv1d(x.contiguous(), weight, stride=stride, padding=padding, dilation=dilation, groups=rank) 322 | elif order == 2: 323 | weight = tenalg.tensordot(tl.transpose(factors[2]), 324 | tl.transpose(factors[3]), modes=(), batched_modes=0 325 | ).unsqueeze(1) 326 | x = F.conv2d(x.contiguous(), weight, stride=stride, padding=padding, dilation=dilation, groups=rank) 327 | elif order == 3: 328 | weight = tenalg.tensordot(tl.transpose(factors[2]), 329 | tenalg.tensordot(tl.transpose(factors[3]), tl.transpose(factors[4]), modes=(), batched_modes=0), 330 | modes=(), batched_modes=0 331 | ).unsqueeze(1) 332 | x = F.conv3d(x.contiguous(), weight, stride=stride, padding=padding, dilation=dilation, groups=rank) 333 | 334 | # Revert back number of channels from rank to output_channels 335 | x_shape = list(x.shape) 336 | x = x.reshape((batch_size, x_shape[1], -1)) 337 | 338 | # Last conv == tensor contraction 339 | # From (out_channels, rank) to (out_channels, in_channels == rank, 1) 340 | x = 
F.conv1d(x*cp_tensor.weights.unsqueeze(1).unsqueeze(0), factors[0].unsqueeze(2), bias=bias) 341 | 342 | x_shape[1] = x.shape[1] # = out_channels 343 | x = x.reshape(x_shape) 344 | 345 | return x 346 | 347 | 348 | def _get_factorized_conv(factorization, implementation='factorized'): 349 | if implementation == 'reconstructed' or factorization == 'Dense': 350 | return convolve 351 | if isinstance(factorization, CPTensor): 352 | if implementation == 'factorized': 353 | return cp_conv 354 | elif implementation == 'mobilenet': 355 | return cp_conv_mobilenet 356 | elif isinstance(factorization, TuckerTensor): 357 | return tucker_conv 358 | elif isinstance(factorization, TTTensor): 359 | return tt_conv 360 | raise ValueError(f'Got unknown type {factorization}') 361 | 362 | 363 | def convNd(x, weight, bias=None, stride=1, padding=0, dilation=1, implementation='factorized'): 364 | if implementation=='reconstructed': 365 | weight = weight.to_tensor() 366 | 367 | if isinstance(weight, DenseTensor): 368 | return convolve(x, weight.tensor, bias=bias, stride=stride, padding=padding, dilation=dilation) 369 | 370 | if torch.is_tensor(weight): 371 | return convolve(x, weight, bias=bias, stride=stride, padding=padding, dilation=dilation) 372 | 373 | if isinstance(weight, CPTensor): 374 | if implementation == 'factorized': 375 | return cp_conv(x, weight, bias=bias, stride=stride, padding=padding, dilation=dilation) 376 | elif implementation == 'mobilenet': 377 | return cp_conv_mobilenet(x, weight, bias=bias, stride=stride, padding=padding, dilation=dilation) 378 | elif isinstance(weight, TuckerTensor): 379 | return tucker_conv(x, weight, bias=bias, stride=stride, padding=padding, dilation=dilation) 380 | elif isinstance(weight, TTTensor): 381 | return tt_conv(x, weight, bias=bias, stride=stride, padding=padding, dilation=dilation) 382 | -------------------------------------------------------------------------------- /tltorch/functional/factorized_linear.py: -------------------------------------------------------------------------------- 1 | from .factorized_tensordot import tensor_dot_tucker, tensor_dot_cp 2 | import tensorly as tl 3 | tl.set_backend('pytorch') 4 | 5 | # Author: Jean Kossaifi 6 | 7 | def linear_tucker(tensor, tucker_matrix, transpose=True, channels_first=True): 8 | if transpose: 9 | contraction_axis = 1 10 | else: 11 | contraction_axis = 0 12 | n_rows = len(tucker_matrix.tensorized_shape[contraction_axis]) 13 | tensor = tensor.reshape(-1, *tucker_matrix.tensorized_shape[contraction_axis]) 14 | 15 | modes_tensor = list(range(tensor.ndim - n_rows, tensor.ndim)) 16 | if transpose: 17 | modes_tucker = list(range(n_rows, tucker_matrix.order)) 18 | else: 19 | modes_tucker = list(range(n_rows)) 20 | 21 | return tensor_dot_tucker(tensor, tucker_matrix, (modes_tensor, modes_tucker)) 22 | 23 | def linear_cp(tensor, cp_matrix, transpose=True): 24 | if transpose: 25 | out_features, in_features = len(cp_matrix.tensorized_shape[0]), len(cp_matrix.tensorized_shape[1]) 26 | in_shape = cp_matrix.tensorized_shape[1] 27 | modes_cp = list(range(out_features, cp_matrix.order)) 28 | else: 29 | in_features, out_features = len(cp_matrix.tensorized_shape[0]), len(cp_matrix.tensorized_shape[1]) 30 | in_shape = cp_matrix.tensorized_shape[0] 31 | modes_cp = list(range(in_features)) 32 | tensor = tensor.reshape(-1, *in_shape) 33 | 34 | modes_tensor = list(range(1, tensor.ndim)) 35 | 36 | return tensor_dot_cp(tensor, cp_matrix, (modes_tensor, modes_cp)) 37 | 38 | 39 | def linear_blocktt(tensor, tt_matrix, transpose=True): 
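    """Linear layer with a weight stored as a BlockTT (tensor-train matrix).

    Sketch of the computation: the dense input is reshaped to the tensorized
    input shape of ``tt_matrix``, contracted with all the TT factors in a
    single einsum, and the result is flattened back to shape (batch_size, -1).

    Parameters
    ----------
    tensor : torch.Tensor of shape (batch_size, in_features)
    tt_matrix : BlockTT
        tensorized weight matrix
    transpose : bool, default is True
        if True, the weight is interpreted as (out_features, in_features),
        as in ``torch.nn.functional.linear``
    """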
40 | if transpose: 41 | contraction_axis = 1 42 | else: 43 | contraction_axis = 0 44 | ndim = len(tt_matrix.tensorized_shape[contraction_axis]) 45 | tensor = tensor.reshape(-1, *tt_matrix.tensorized_shape[contraction_axis]) 46 | 47 | bs = 'a' 48 | start = ord(bs) + 1 49 | in_idx = bs + ''.join(chr(i) for i in [start+i for i in range(ndim)]) 50 | factors_idx = [] 51 | for i in range(ndim): 52 | if transpose: 53 | idx = [start+ndim*2+i, start+ndim+i, start+i, start+ndim*2+i+1] 54 | else: 55 | idx = [start+ndim*2+i, start+i, start+ndim+i, start+ndim*2+i+1] 56 | factors_idx.append(''.join(chr(j) for j in idx)) 57 | out_idx = bs + ''.join(chr(i) for i in [start + ndim + i for i in range(ndim)]) 58 | eq = in_idx + ',' + ','.join(i for i in factors_idx) + '->' + out_idx 59 | res = tl.einsum(eq, tensor, *tt_matrix.factors) 60 | return tl.reshape(res, (tl.shape(res)[0], -1)) 61 | 62 | -------------------------------------------------------------------------------- /tltorch/functional/factorized_tensordot.py: -------------------------------------------------------------------------------- 1 | # Author: Jean Kossaifi 2 | 3 | import tensorly as tl 4 | from tensorly.tenalg.tenalg_utils import _validate_contraction_modes 5 | tl.set_backend('pytorch') 6 | 7 | einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' 8 | 9 | 10 | def tensor_dot_tucker(tensor, tucker, modes, batched_modes=()): 11 | """Batched tensor contraction between a dense tensor and a Tucker tensor on specified modes 12 | 13 | Parameters 14 | ---------- 15 | tensor : DenseTensor 16 | tucker : TuckerTensor 17 | modes : int list or int 18 | modes on which to contract tensor1 and tensor2 19 | batched_modes : int or tuple[int] 20 | 21 | Returns 22 | ------- 23 | contraction : tensor contracted with cp on the specified modes 24 | """ 25 | modes_tensor, modes_tucker = _validate_contraction_modes( 26 | tl.shape(tensor), tucker.tensor_shape, modes) 27 | input_order = tensor.ndim 28 | weight_order = tucker.order 29 | 30 | batched_modes_tensor, batched_modes_tucker = _validate_contraction_modes( 31 | tl.shape(tensor), tucker.tensor_shape, batched_modes) 32 | 33 | sorted_modes_tucker = sorted(modes_tucker+batched_modes_tucker, reverse=True) 34 | sorted_modes_tensor = sorted(modes_tensor+batched_modes_tensor, reverse=True) 35 | 36 | # Symbol for dimensionality of the core 37 | rank_sym = [einsum_symbols[i] for i in range(weight_order)] 38 | 39 | # Symbols for tucker weight size 40 | tucker_sym = [einsum_symbols[i+weight_order] for i in range(weight_order)] 41 | 42 | # Symbols for input tensor 43 | tensor_sym = [einsum_symbols[i+2*weight_order] for i in range(tensor.ndim)] 44 | 45 | # Output: input + weights symbols after removing contraction symbols 46 | output_sym = tensor_sym + tucker_sym 47 | for m in sorted_modes_tucker: 48 | if m in modes_tucker: #not batched 49 | output_sym.pop(m+input_order) 50 | for m in sorted_modes_tensor: 51 | # It's batched, always remove 52 | output_sym.pop(m) 53 | 54 | # print(tensor_sym, tucker_sym, modes_tensor, batched_modes_tensor) 55 | for i, e in enumerate(modes_tensor): 56 | tensor_sym[e] = tucker_sym[modes_tucker[i]] 57 | for i, e in enumerate(batched_modes_tensor): 58 | tensor_sym[e] = tucker_sym[batched_modes_tucker[i]] 59 | 60 | # Form the actual equation: tensor, core, factors -> output 61 | eq = ''.join(tensor_sym) 62 | eq += ',' + ''.join(rank_sym) 63 | eq += ',' + ','.join(f'{s}{r}' for s,r in zip(tucker_sym,rank_sym)) 64 | eq += '->' + ''.join(output_sym) 65 | 66 | return tl.einsum(eq, 
tensor, tucker.core, *tucker.factors)
67 | 
68 | 
69 | 
70 | def tensor_dot_cp(tensor, cp, modes):
71 |     """Contracts a dense tensor with a CP tensor, in factorized form
72 | 
73 |     Returns
74 |     -------
75 |     contraction : `tensor` contracted with `cp` over the specified modes (for a tensorized matrix, this is tensor x cp.to_matrix().T)
76 |     """
77 |     try:
78 |         cp_shape = cp.tensor_shape
79 |     except AttributeError:
80 |         cp_shape = cp.shape
81 |     modes_tensor, modes_cp = _validate_contraction_modes(tl.shape(tensor), cp_shape, modes)
82 | 
83 |     tensor_order = tl.ndim(tensor)
84 |     # CP rank = 'a', start at b
85 |     start = ord('b')
86 |     eq_in = ''.join(f'{chr(start+index)}' for index in range(tensor_order))
87 |     eq_factors = []
88 |     eq_res = ''.join(eq_in[i] if i not in modes_tensor else '' for i in range(tensor_order))
89 |     counter_joint = 0 # contraction modes, shared indices between tensor and CP
90 |     counter_free = 0 # new uncontracted modes from the CP
91 |     for i in range(len(cp.factors)):
92 |         if i in modes_cp:
93 |             eq_factors.append(f'{eq_in[modes_tensor[counter_joint]]}a')
94 |             counter_joint += 1
95 |         else:
96 |             eq_factors.append(f'{chr(start+tensor_order+counter_free)}a')
97 |             eq_res += f'{chr(start+tensor_order+counter_free)}'
98 |             counter_free += 1
99 | 
100 |     eq_factors = ','.join(f for f in eq_factors)
101 |     eq = eq_in + ',a,' + eq_factors + '->' + eq_res
102 |     res = tl.einsum(eq, tensor, cp.weights, *cp.factors)
103 | 
104 |     return res
--------------------------------------------------------------------------------
/tltorch/functional/linear.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from ..factorized_tensors import TensorizedTensor
5 | from ..factorized_tensors.tensorized_matrices import CPTensorized, TuckerTensorized, BlockTT
6 | from .factorized_linear import linear_blocktt, linear_cp, linear_tucker
7 | 
8 | import tensorly as tl
9 | tl.set_backend('pytorch')
10 | 
11 | # Author: Jean Kossaifi
12 | # License: BSD 3 clause
13 | 
14 | 
15 | def factorized_linear(x, weight, bias=None, in_features=None, implementation="factorized"):
16 |     """Linear layer with a dense input x and a factorized weight
17 |     """
18 |     assert implementation in {"factorized", "reconstructed"}, f"Expected implementation to be one of [factorized, reconstructed], but got {implementation}"
19 | 
20 |     if in_features is None:
21 |         in_features = np.prod(x.shape[-1])
22 | 
23 |     if not torch.is_tensor(weight):
24 |         # Weights are in the form (out_features, in_features)
25 |         # PyTorch's linear returns dot(x, weight.T)!
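        # Dispatch to a specialized factorized contraction when one exists
        # (CP, Tucker, BlockTT); otherwise reconstruct the full matrix and
        # fall back to the dense F.linear call below.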
26 | if isinstance(weight, TensorizedTensor): 27 | if implementation == "factorized": 28 | x_shape = x.shape[:-1] + weight.tensorized_shape[1] 29 | out_shape = x.shape[:-1] + (-1, ) 30 | if isinstance(weight, CPTensorized): 31 | x = linear_cp(x.reshape(x_shape), weight).reshape(out_shape) 32 | if bias is not None: 33 | x = x + bias 34 | return x 35 | elif isinstance(weight, TuckerTensorized): 36 | x = linear_tucker(x.reshape(x_shape), weight).reshape(out_shape) 37 | if bias is not None: 38 | x = x + bias 39 | return x 40 | elif isinstance(weight, BlockTT): 41 | x = linear_blocktt(x.reshape(x_shape), weight).reshape(out_shape) 42 | if bias is not None: 43 | x = x + bias 44 | return x 45 | # if no efficient implementation available or force to use reconstructed mode: use reconstruction 46 | weight = weight.to_matrix() 47 | else: 48 | weight = weight.to_tensor() 49 | 50 | return F.linear(x, torch.reshape(weight, (-1, in_features)), bias=bias) 51 | -------------------------------------------------------------------------------- /tltorch/functional/tensor_regression.py: -------------------------------------------------------------------------------- 1 | from ..factorized_tensors import FactorizedTensor, TuckerTensor 2 | 3 | import tensorly as tl 4 | tl.set_backend('pytorch') 5 | from tensorly import tenalg 6 | 7 | # Author: Jean Kossaifi 8 | # License: BSD 3 clause 9 | 10 | 11 | def trl(x, weight, bias=None, **kwargs): 12 | """Tensor Regression Layer 13 | 14 | Parameters 15 | ---------- 16 | x : torch.tensor 17 | batch of inputs 18 | weight : FactorizedTensor 19 | factorized weights of the TRL 20 | bias : torch.Tensor, optional 21 | 1D tensor, by default None 22 | 23 | Returns 24 | ------- 25 | result 26 | input x contracted with regression weights 27 | """ 28 | if isinstance(weight, TuckerTensor): 29 | return tucker_trl(x, weight, bias=bias, **kwargs) 30 | else: 31 | if bias is None: 32 | return tenalg.inner(x, weight.to_tensor(), n_modes=tl.ndim(x)-1) 33 | else: 34 | return tenalg.inner(x, weight.to_tensor(), n_modes=tl.ndim(x)-1) + bias 35 | 36 | 37 | def tucker_trl(x, weight, project_input=False, bias=None): 38 | n_input = tl.ndim(x) - 1 39 | if project_input: 40 | x = tenalg.multi_mode_dot(x, weight.factors[:n_input], modes=range(1, n_input+1), transpose=True) 41 | regression_weights = tenalg.multi_mode_dot(weight.core, weight.factors[n_input:], 42 | modes=range(n_input, weight.order)) 43 | else: 44 | regression_weights = weight.to_tensor() 45 | 46 | if bias is None: 47 | return tenalg.inner(x, regression_weights, n_modes=tl.ndim(x)-1) 48 | else: 49 | return tenalg.inner(x, regression_weights, n_modes=tl.ndim(x)-1) + bias 50 | 51 | -------------------------------------------------------------------------------- /tltorch/functional/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/tltorch/functional/tests/__init__.py -------------------------------------------------------------------------------- /tltorch/functional/tests/test_factorized_linear.py: -------------------------------------------------------------------------------- 1 | from ...factorized_tensors import TensorizedTensor 2 | from ..factorized_linear import linear_tucker, linear_blocktt, linear_cp 3 | import torch 4 | 5 | import tensorly as tl 6 | tl.set_backend('pytorch') 7 | from tensorly import testing 8 | from math import prod 9 | 10 | import pytest 11 | 12 | # Author: Jean Kossaifi 13 | 
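
# Illustrative sketch (hypothetical helper, not used by the tests): how these
# factorized linear functions are meant to be called. It mirrors the
# parametrized test below for the BlockTT case and returns both the dense and
# the factorized results so they can be compared.
def _example_blocktt_linear(batch_size=2, in_shape=(4, 5), out_shape=(6, 2), rank=3):
    # Random dense input of shape (batch_size, in_features)
    x = tl.randn((batch_size, prod(in_shape)), dtype=tl.float32)
    # Tensorized weight whose flattened shape is (prod(out_shape), prod(in_shape))
    weight = TensorizedTensor.new((out_shape, in_shape), rank=rank, factorization='blocktt')
    weight.normal_()
    # Dense reference: x @ W.T with the reconstructed matrix
    dense = torch.matmul(x, weight.to_matrix().T)
    # Factorized path: contract the input directly with the TT factors
    factorized = linear_blocktt(x, weight, transpose=True).reshape(batch_size, -1)
    return dense, factorized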
14 | 15 | @pytest.mark.parametrize('factorization, factorized_linear', 16 | [('tucker', linear_tucker), ('blocktt', linear_blocktt), ('cp', linear_cp)]) 17 | def test_linear_tensor_dot_tucker(factorization, factorized_linear): 18 | in_shape = (4, 5) 19 | in_dim = prod(in_shape) 20 | out_shape = (6, 2) 21 | rank = 3 22 | batch_size = 2 23 | 24 | tensor = tl.randn((batch_size, in_dim), dtype=tl.float32) 25 | fact_weight = TensorizedTensor.new((out_shape, in_shape), rank=rank, 26 | factorization=factorization) 27 | fact_weight.normal_() 28 | full_weight = fact_weight.to_matrix() 29 | true_res = torch.matmul(tensor, full_weight.T) 30 | res = factorized_linear(tensor, fact_weight, transpose=True) 31 | res = res.reshape(batch_size, -1) 32 | testing.assert_array_almost_equal(true_res, res, decimal=5) 33 | 34 | -------------------------------------------------------------------------------- /tltorch/tensor_hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from ._tensor_lasso import tensor_lasso, remove_tensor_lasso, TensorLasso 2 | from ._tensor_dropout import tensor_dropout, remove_tensor_dropout, TensorDropout -------------------------------------------------------------------------------- /tltorch/tensor_hooks/_tensor_dropout.py: -------------------------------------------------------------------------------- 1 | """Tensor Dropout for TensorModules""" 2 | 3 | # Author: Jean Kossaifi 4 | # License: BSD 3 clause 5 | 6 | import tensorly as tl 7 | tl.set_backend('pytorch') 8 | import torch 9 | from ..factorized_tensors import TuckerTensor, CPTensor, TTTensor 10 | 11 | 12 | class TensorDropout(): 13 | """Decomposition Hook for Tensor Dropout on FactorizedTensor 14 | 15 | Parameters 16 | ---------- 17 | name : FactorizedTensor parameter on which to apply the dropout 18 | proba : float, probability of dropout 19 | min_dim : int 20 | Minimum dimension size for which to apply dropout. 21 | For instance, if a tensor if of shape (32, 32, 3, 3) and min_dim = 4 22 | then dropout will *not* be applied to the last two modes. 23 | """ 24 | _factorizations = dict() 25 | 26 | def __init_subclass__(cls, factorization, **kwargs): 27 | """When a subclass is created, register it in _factorizations""" 28 | cls._factorizations[factorization.__name__] = cls 29 | 30 | def __init__(self, proba, min_dim=1, min_values=1, drop_test=False): 31 | assert 0 <= proba < 1, f'Got prob={proba} but tensor dropout is defined for 0 <= proba < 1.' 
32 | self.proba = proba 33 | self.min_dim = min_dim 34 | self.min_values = min_values 35 | self.drop_test = drop_test 36 | 37 | def __call__(self, module, input, factorized_tensor): 38 | return self._apply_tensor_dropout(factorized_tensor, training=module.training) 39 | 40 | def _apply_tensor_dropout(self, factorized_tensor, training=True): 41 | raise NotImplementedError() 42 | 43 | @classmethod 44 | def apply(cls, module, proba, min_dim=3, min_values=1, drop_test=False): 45 | cls = cls._factorizations[module.__class__.__name__] 46 | for k, hook in module._forward_hooks.items(): 47 | if isinstance(hook, cls): 48 | raise RuntimeError("Cannot register two weight_norm hooks on " 49 | "the same parameter") 50 | 51 | dropout = cls(proba, min_dim=min_dim, min_values=min_values, drop_test=drop_test) 52 | handle = module.register_forward_hook(dropout) 53 | return handle 54 | 55 | 56 | class TuckerDropout(TensorDropout, factorization=TuckerTensor): 57 | def _apply_tensor_dropout(self, tucker_tensor, training=True): 58 | if (not self.proba) or ((not training) and (not self.drop_test)): 59 | return tucker_tensor 60 | 61 | core, factors = tucker_tensor.core, tucker_tensor.factors 62 | tucker_rank = tucker_tensor.rank 63 | 64 | sampled_indices = [] 65 | for rank in tucker_rank: 66 | idx = tl.arange(rank, device=core.device, dtype=torch.int64) 67 | if rank > self.min_dim: 68 | idx = idx[torch.bernoulli(torch.ones(rank, device=core.device)*(1 - self.proba), 69 | out=torch.empty(rank, device=core.device, dtype=torch.bool))] 70 | if len(idx) == 0: 71 | idx = torch.randint(0, rank, size=(self.min_values, ), device=core.device, dtype=torch.int64) 72 | 73 | sampled_indices.append(idx) 74 | 75 | if training: 76 | core = core[torch.meshgrid(*sampled_indices)]*(1/((1 - self.proba)**core.ndim)) 77 | else: 78 | core = core[torch.meshgrid(*sampled_indices)] 79 | 80 | factors = [factor[:, idx] for (factor, idx) in zip(factors, sampled_indices)] 81 | 82 | return TuckerTensor(core, factors) 83 | 84 | 85 | class CPDropout(TensorDropout, factorization=CPTensor): 86 | def _apply_tensor_dropout(self, cp_tensor, training=True): 87 | if (not self.proba) or ((not training) and (not self.drop_test)): 88 | return cp_tensor 89 | 90 | rank = cp_tensor.rank 91 | device = cp_tensor.factors[0].device 92 | 93 | if rank > self.min_dim: 94 | sampled_indices = tl.arange(rank, device=device, dtype=torch.int64) 95 | sampled_indices = sampled_indices[torch.bernoulli(torch.ones(rank, device=device)*(1 - self.proba), 96 | out=torch.empty(rank, device=device, dtype=torch.bool))] 97 | if len(sampled_indices) == 0: 98 | sampled_indices = torch.randint(0, rank, size=(self.min_values, ), device=device, dtype=torch.int64) 99 | 100 | factors = [factor[:, sampled_indices] for factor in cp_tensor.factors] 101 | if training: 102 | weights = cp_tensor.weights[sampled_indices]*(1/(1 - self.proba)) 103 | else: 104 | weights = cp_tensor.weights[sampled_indices] 105 | 106 | return CPTensor(weights, factors) 107 | 108 | 109 | class TTDropout(TensorDropout, factorization=TTTensor): 110 | def _apply_tensor_dropout(self, tt_tensor, training=True): 111 | if (not self.proba) or ((not training) and (not self.drop_test)): 112 | return tt_tensor 113 | 114 | device = tt_tensor.factors[0].device 115 | 116 | sampled_indices = [] 117 | for i, rank in enumerate(tt_tensor.rank[1:]): 118 | if rank > self.min_dim: 119 | idx = tl.arange(rank, device=device, dtype=torch.int64) 120 | idx = idx[torch.bernoulli(torch.ones(rank, device=device)*(1 - self.proba), 121 | 
out=torch.empty(rank, device=device, dtype=torch.bool))] 122 | if len(idx) == 0: 123 | idx = torch.randint(0, rank, size=(self.min_values, ), device=device, dtype=torch.int64) 124 | else: 125 | idx = tl.arange(rank, device=device, dtype=torch.int64).tolist() 126 | 127 | sampled_indices.append(idx) 128 | 129 | sampled_factors = [] 130 | if training: 131 | scaling = 1/(1 - self.proba) 132 | else: 133 | scaling = 1 134 | for i, f in enumerate(tt_tensor.factors): 135 | if i == 0: 136 | sampled_factors.append(f[..., sampled_indices[i]]*scaling) 137 | elif i == (tt_tensor.order - 1): 138 | sampled_factors.append(f[sampled_indices[i-1], ...]) 139 | else: 140 | sampled_factors.append(f[sampled_indices[i-1], ...][..., sampled_indices[i]]*scaling) 141 | 142 | return TTTensor(sampled_factors) 143 | 144 | 145 | def tensor_dropout(factorized_tensor, p=0, min_dim=3, min_values=1, drop_test=False): 146 | """Tensor Dropout 147 | 148 | Parameters 149 | ---------- 150 | factorized_tensor : FactorizedTensor 151 | the tensor module parametrized by the tensor decomposition to which to apply tensor dropout 152 | p : float 153 | dropout probability 154 | if 0, no dropout is applied 155 | if 1, all the components but 1 are dropped in the latent space 156 | min_dim : int, default is 3 157 | only apply dropout to modes with dimension larger than `min_dim` 158 | min_values : int, default is 1 159 | minimum number of components to select 160 | 161 | Returns 162 | ------- 163 | FactorizedTensor 164 | the module to which tensor dropout has been attached 165 | 166 | Examples 167 | -------- 168 | >>> tensor = FactorizedTensor.new((3, 4, 2), rank=0.5, factorization='CP').normal_() 169 | >>> tensor = tensor_dropout(tensor, p=0.5) 170 | >>> remove_tensor_dropout(tensor) 171 | """ 172 | TensorDropout.apply(factorized_tensor, p, min_dim=min_dim, min_values=min_values, drop_test=drop_test) 173 | 174 | return factorized_tensor 175 | 176 | 177 | def remove_tensor_dropout(factorized_tensor): 178 | """Removes the tensor dropout from a TensorModule 179 | 180 | Parameters 181 | ---------- 182 | factorized_tensor : tltorch.FactorizedTensor 183 | the tensor module parametrized by the tensor decomposition to which to apply tensor dropout 184 | 185 | Examples 186 | -------- 187 | >>> tensor = FactorizedTensor.new((3, 4, 2), rank=0.5, factorization='CP').normal_() 188 | >>> tensor = tensor_dropout(tensor, p=0.5) 189 | >>> remove_tensor_dropout(tensor) 190 | """ 191 | for key, hook in factorized_tensor._forward_hooks.items(): 192 | if isinstance(hook, TensorDropout): 193 | del factorized_tensor._forward_hooks[key] 194 | return factorized_tensor 195 | 196 | raise ValueError(f'TensorLasso not found in factorized tensor {factorized_tensor}') 197 | -------------------------------------------------------------------------------- /tltorch/tensor_hooks/_tensor_lasso.py: -------------------------------------------------------------------------------- 1 | import tensorly as tl 2 | tl.set_backend('pytorch') 3 | 4 | import warnings 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from ..factorized_tensors import TuckerTensor, TTTensor, CPTensor 10 | from ..utils import ParameterList 11 | 12 | # Author: Jean Kossaifi 13 | # License: BSD 3 clause 14 | 15 | class TensorLasso: 16 | """Generalized Tensor Lasso on factorized tensors 17 | 18 | Applies a generalized Lasso (l1 regularization) on a factorized tensor. 
19 | 20 | 21 | Parameters 22 | ---------- 23 | penalty : float, default is 0.01 24 | scaling factor for the loss 25 | 26 | clamp_weights : bool, default is True 27 | if True, the lasso weights are clamp between -1 and 1 28 | 29 | threshold : float, default is 1e-6 30 | if a lasso weight is lower than the set threshold, it is set to 0 31 | 32 | normalize_loss : bool, default is True 33 | If True, the loss will be between 0 and 1. 34 | Otherwise, the raw sum of absolute weights will be returned. 35 | 36 | Examples 37 | -------- 38 | 39 | First you need to create an instance of the regularizer: 40 | 41 | >>> regularizer = tensor_lasso(factorization='cp') 42 | 43 | You can apply the regularizer to one or several layers: 44 | 45 | >>> trl = TRL((5, 5), (5, 5), rank='same') 46 | >>> trl2 = TRL((5, 5), (2, ), rank='same') 47 | >>> regularizer.apply(trl.weight) 48 | >>> regularizer.apply(trl2.weight) 49 | 50 | The lasso is automatically applied: 51 | 52 | >>> x = trl(x) 53 | >>> pred = trl2(x) 54 | >>> loss = your_loss_function(pred) 55 | 56 | Add the Lasso loss: 57 | 58 | >>> loss = loss + regularizer.loss 59 | 60 | You can now backpropagate through your loss as usual: 61 | 62 | >>> loss.backwards() 63 | 64 | After you finish updating the weights, don't forget to reset the regularizer, 65 | otherwise it will keep accumulating values! 66 | 67 | >>> loss.reset() 68 | 69 | You can also remove the regularizer with `regularizer.remove(trl)`. 70 | """ 71 | _factorizations = dict() 72 | 73 | def __init_subclass__(cls, factorization, **kwargs): 74 | """When a subclass is created, register it in _factorizations""" 75 | cls._factorizations[factorization.__name__] = cls 76 | 77 | def __init__(self, penalty=0.01, clamp_weights=True, threshold=1e-6, normalize_loss=True): 78 | self.penalty = penalty 79 | self.clamp_weights = clamp_weights 80 | self.threshold = threshold 81 | self.normalize_loss = normalize_loss 82 | 83 | # Initialize the counters 84 | self.reset() 85 | 86 | def reset(self): 87 | """Reset the loss, should be called at the end of each iteration. 88 | """ 89 | self._loss = 0 90 | self.n_element = 0 91 | 92 | @property 93 | def loss(self): 94 | """Returns the current Lasso (l1) loss for the layers that have been called so far. 95 | 96 | Returns 97 | ------- 98 | float 99 | l1 regularization on the tensor layers the regularization has been applied to. 
100 | """ 101 | if self.n_element == 0: 102 | warnings.warn('The L1Regularization was not applied to any weights.') 103 | return 0 104 | elif self.normalize_loss: 105 | return self.penalty*self._loss/self.n_element 106 | else: 107 | return self.penalty*self._loss 108 | 109 | def __call__(self, module, input, tucker_tensor): 110 | raise NotImplementedError 111 | 112 | def apply_lasso(self, tucker_tensor, lasso_weights): 113 | """Applies the lasso to a decomposed tensor 114 | """ 115 | raise NotImplementedError 116 | 117 | @classmethod 118 | def from_factorization(cls, factorization, penalty=0.01, clamp_weights=True, threshold=1e-6, normalize_loss=True): 119 | return cls.from_factorization_name(factorization.__class__.__name__, penalty=penalty, 120 | clamp_weights=clamp_weights, threshold=threshold, normalize_loss=normalize_loss) 121 | 122 | @classmethod 123 | def from_factorization_name(cls, factorization_name, penalty=0.01, clamp_weights=True, threshold=1e-6, normalize_loss=True): 124 | cls = cls._factorizations[factorization_name] 125 | lasso = cls(penalty=penalty, clamp_weights=clamp_weights, threshold=threshold, normalize_loss=normalize_loss) 126 | return lasso 127 | 128 | def remove(self, module): 129 | raise NotImplementedError 130 | 131 | 132 | class CPLasso(TensorLasso, factorization=CPTensor): 133 | """Decomposition Hook for Tensor Lasso on CP tensors 134 | 135 | Parameters 136 | ---------- 137 | penalty : float, default is 0.01 138 | scaling factor for the loss 139 | 140 | clamp_weights : bool, default is True 141 | if True, the lasso weights are clamp between -1 and 1 142 | 143 | threshold : float, default is 1e-6 144 | if a lasso weight is lower than the set threshold, it is set to 0 145 | 146 | normalize_loss : bool, default is True 147 | If True, the loss will be between 0 and 1. 148 | Otherwise, the raw sum of absolute weights will be returned. 
149 | """ 150 | def __call__(self, module, input, cp_tensor): 151 | """CP already includes weights, we'll just take their l1 norm 152 | """ 153 | weights = getattr(module, 'lasso_weights') 154 | 155 | with torch.no_grad(): 156 | if self.clamp_weights: 157 | weights.data = torch.clamp(weights.data, -1, 1) 158 | setattr(module, 'lasso_weights', weights) 159 | 160 | if self.threshold: 161 | weights.data = F.threshold(weights.data, threshold=self.threshold, value=0, inplace=True) 162 | setattr(module, 'lasso_weights', weights) 163 | 164 | self.n_element += weights.numel() 165 | self._loss = self._loss + self.penalty*torch.norm(weights, 1) 166 | return cp_tensor 167 | 168 | def apply(self, module): 169 | """Apply an instance of the L1Regularizer to a tensor module 170 | 171 | Parameters 172 | ---------- 173 | module : TensorModule 174 | module on which to add the regularization 175 | 176 | Returns 177 | ------- 178 | TensorModule (with Regularization hook) 179 | """ 180 | context = tl.context(module.factors[0]) 181 | lasso_weights = nn.Parameter(torch.ones(module.rank, **context)) 182 | setattr(module, 'lasso_weights', lasso_weights) 183 | 184 | module.register_forward_hook(self) 185 | return module 186 | 187 | def remove(self, module): 188 | delattr(module, 'lasso_weights') 189 | 190 | def set_weights(self, module, value): 191 | with torch.no_grad(): 192 | module.lasso_weights.data.fill_(value) 193 | 194 | 195 | class TuckerLasso(TensorLasso, factorization=TuckerTensor): 196 | """Decomposition Hook for Tensor Lasso on Tucker tensors 197 | 198 | Applies a generalized Lasso (l1 regularization) on the tensor layers the regularization it is applied to. 199 | 200 | 201 | Parameters 202 | ---------- 203 | penalty : float, default is 0.01 204 | scaling factor for the loss 205 | 206 | clamp_weights : bool, default is True 207 | if True, the lasso weights are clamp between -1 and 1 208 | 209 | threshold : float, default is 1e-6 210 | if a lasso weight is lower than the set threshold, it is set to 0 211 | 212 | normalize_loss : bool, default is True 213 | If True, the loss will be between 0 and 1. 214 | Otherwise, the raw sum of absolute weights will be returned. 
215 | """ 216 | _log = [] 217 | 218 | def __call__(self, module, input, tucker_tensor): 219 | lasso_weights = getattr(module, 'lasso_weights') 220 | order = len(lasso_weights) 221 | 222 | with torch.no_grad(): 223 | for i in range(order): 224 | if self.clamp_weights: 225 | lasso_weights[i].data = torch.clamp(lasso_weights[i].data, -1, 1) 226 | 227 | if self.threshold: 228 | lasso_weights[i] = F.threshold(lasso_weights[i], threshold=self.threshold, value=0, inplace=True) 229 | 230 | setattr(module, 'lasso_weights', lasso_weights) 231 | 232 | for weight in lasso_weights: 233 | self.n_element += weight.numel() 234 | self._loss = self._loss + torch.sum(torch.abs(weight)) 235 | 236 | return self.apply_lasso(tucker_tensor, lasso_weights) 237 | 238 | def apply_lasso(self, tucker_tensor, lasso_weights): 239 | """Applies the lasso to a decomposed tensor 240 | """ 241 | factors = tucker_tensor.factors 242 | factors = [factor*w for (factor, w) in zip(factors, lasso_weights)] 243 | return TuckerTensor(tucker_tensor.core, factors) 244 | 245 | def apply(self, module): 246 | """Apply an instance of the L1Regularizer to a tensor module 247 | 248 | Parameters 249 | ---------- 250 | module : TensorModule 251 | module on which to add the regularization 252 | 253 | Returns 254 | ------- 255 | TensorModule (with Regularization hook) 256 | """ 257 | rank = module.rank 258 | context = tl.context(module.core) 259 | lasso_weights = ParameterList([nn.Parameter(torch.ones(r, **context)) for r in rank]) 260 | setattr(module, 'lasso_weights', lasso_weights) 261 | module.register_forward_hook(self) 262 | 263 | return module 264 | 265 | def remove(self, module): 266 | delattr(module, 'lasso_weights') 267 | 268 | def set_weights(self, module, value): 269 | with torch.no_grad(): 270 | for weight in module.lasso_weights: 271 | weight.data.fill_(value) 272 | 273 | 274 | class TTLasso(TensorLasso, factorization=TTTensor): 275 | """Decomposition Hook for Tensor Lasso on TT tensors 276 | 277 | Parameters 278 | ---------- 279 | penalty : float, default is 0.01 280 | scaling factor for the loss 281 | 282 | clamp_weights : bool, default is True 283 | if True, the lasso weights are clamp between -1 and 1 284 | 285 | threshold : float, default is 1e-6 286 | if a lasso weight is lower than the set threshold, it is set to 0 287 | 288 | normalize_loss : bool, default is True 289 | If True, the loss will be between 0 and 1. 290 | Otherwise, the raw sum of absolute weights will be returned. 
291 | """ 292 | def __call__(self, module, input, tt_tensor): 293 | lasso_weights = getattr(module, 'lasso_weights') 294 | order = len(lasso_weights) 295 | 296 | with torch.no_grad(): 297 | for i in range(order): 298 | if self.clamp_weights: 299 | lasso_weights[i].data = torch.clamp(lasso_weights[i].data, -1, 1) 300 | 301 | if self.threshold: 302 | lasso_weights[i] = F.threshold(lasso_weights[i], threshold=self.threshold, value=0, inplace=True) 303 | 304 | setattr(module, 'lasso_weights', lasso_weights) 305 | 306 | for weight in lasso_weights: 307 | self.n_element += weight.numel() 308 | self._loss = self._loss + torch.sum(torch.abs(weight)) 309 | 310 | return self.apply_lasso(tt_tensor, lasso_weights) 311 | 312 | def apply_lasso(self, tt_tensor, lasso_weights): 313 | """Applies the lasso to a decomposed tensor 314 | """ 315 | factors = tt_tensor.factors 316 | factors = [factor*w for (factor, w) in zip(factors, lasso_weights)] + [factors[-1]] 317 | return TTTensor(factors) 318 | 319 | def apply(self, module): 320 | """Apply an instance of the L1Regularizer to a tensor module 321 | 322 | Parameters 323 | ---------- 324 | module : TensorModule 325 | module on which to add the regularization 326 | 327 | Returns 328 | ------- 329 | TensorModule (with Regularization hook) 330 | """ 331 | rank = module.rank[1:-1] 332 | lasso_weights = ParameterList([nn.Parameter(torch.ones(1, 1, r)) for r in rank]) 333 | setattr(module, 'lasso_weights', lasso_weights) 334 | handle = module.register_forward_hook(self) 335 | return module 336 | 337 | def remove(self, module): 338 | """Remove the Regularization from a module. 339 | """ 340 | delattr(module, 'lasso_weights') 341 | 342 | def set_weights(self, module, value): 343 | with torch.no_grad(): 344 | for weight in module.lasso_weights: 345 | weight.data.fill_(value) 346 | 347 | 348 | def tensor_lasso(factorization='CP', penalty=0.01, clamp_weights=True, threshold=1e-6, normalize_loss=True): 349 | """Generalized Tensor Lasso from a factorized tensors 350 | 351 | Applies a generalized Lasso (l1 regularization) on a factorized tensor. 352 | 353 | 354 | Parameters 355 | ---------- 356 | factorization : str 357 | 358 | penalty : float, default is 0.01 359 | scaling factor for the loss 360 | 361 | clamp_weights : bool, default is True 362 | if True, the lasso weights are clamp between -1 and 1 363 | 364 | threshold : float, default is 1e-6 365 | if a lasso weight is lower than the set threshold, it is set to 0 366 | 367 | normalize_loss : bool, default is True 368 | If True, the loss will be between 0 and 1. 369 | Otherwise, the raw sum of absolute weights will be returned. 370 | 371 | Examples 372 | -------- 373 | 374 | Let's say you have a set of factorized (here, CP) tensors: 375 | 376 | >>> tensor = FactorizedTensor.new((3, 4, 2), rank='same', factorization='CP').normal_() 377 | >>> tensor2 = FactorizedTensor.new((5, 6, 7), rank=0.5, factorization='CP').normal_() 378 | 379 | First you need to create an instance of the regularizer: 380 | 381 | >>> regularizer = TensorLasso(factorization='cp', penalty=penalty) 382 | 383 | You can apply the regularizer to one or several layers: 384 | 385 | >>> regularizer.apply(tensor) 386 | >>> regularizer.apply(tensor2) 387 | 388 | The lasso is automatically applied: 389 | 390 | >>> sum = torch.sum(tensor() + tensor2()) 391 | 392 | You can access the Lasso loss from your instance: 393 | 394 | >>> l1_loss = regularizer.loss 395 | 396 | You can optimize and backpropagate through your loss as usual. 
397 | 398 | After you finish updating the weights, don't forget to reset the regularizer, 399 | otherwise it will keep accumulating values! 400 | 401 | >>> regularizer.reset() 402 | 403 | You can also remove the regularizer with `regularizer.remove(tensor)`, 404 | or `remove_tensor_lasso(tensor)`. 405 | """ 406 | factorization = factorization.lower() 407 | mapping = dict(cp='CPTensor', tucker='TuckerTensor', tt='TTTensor') 408 | return TensorLasso.from_factorization_name(mapping[factorization], penalty=penalty, clamp_weights=clamp_weights, 409 | threshold=threshold, normalize_loss=normalize_loss) 410 | 411 | def remove_tensor_lasso(factorized_tensor): 412 | """Removes the tensor lasso from a TensorModule 413 | 414 | Parameters 415 | ---------- 416 | factorized_tensor : tltorch.FactorizedTensor 417 | the tensor module parametrized by the tensor decomposition from which to remove the tensor lasso 418 | 419 | Examples 420 | -------- 421 | >>> tensor = FactorizedTensor.new((3, 4, 2), rank=0.5, factorization='CP').normal_() 422 | >>> tensor_lasso('cp').apply(tensor) 423 | >>> remove_tensor_lasso(tensor) 424 | """ 425 | for key, hook in factorized_tensor._forward_hooks.items(): 426 | if isinstance(hook, TensorLasso): 427 | hook.remove(factorized_tensor) 428 | del factorized_tensor._forward_hooks[key] 429 | return factorized_tensor 430 | 431 | raise ValueError(f'TensorLasso not found in factorized tensor {factorized_tensor}') 432 | -------------------------------------------------------------------------------- /tltorch/tensor_hooks/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorly/torch/e602edf6f25127f9e1fa0657b0069e5dfadeca95/tltorch/tensor_hooks/tests/__init__.py -------------------------------------------------------------------------------- /tltorch/tensor_hooks/tests/test_tensor_dropout.py: -------------------------------------------------------------------------------- 1 | from ...
import FactorizedTensor 2 | from ...factorized_layers import TRL 3 | from .._tensor_dropout import tensor_dropout, remove_tensor_dropout 4 | 5 | import tensorly as tl 6 | tl.set_backend('pytorch') 7 | 8 | def test_tucker_dropout(): 9 | """Test for Tucker Dropout""" 10 | shape = (10, 11, 12) 11 | rank = (7, 8, 9) 12 | tensor = FactorizedTensor.new(shape, rank=rank, factorization='Tucker') 13 | tensor = tensor_dropout(tensor, 0.999) 14 | core = tensor().core 15 | assert (tl.shape(core) == (1, 1, 1)) 16 | 17 | remove_tensor_dropout(tensor) 18 | assert (not tensor._forward_hooks) 19 | 20 | tensor = tensor_dropout(tensor, 0) 21 | core = tensor().core 22 | assert (tl.shape(core) == rank) 23 | 24 | def test_cp_dropout(): 25 | """Test for CP Dropout""" 26 | shape = (10, 11, 12) 27 | rank = 8 28 | tensor = FactorizedTensor.new(shape, rank=rank, factorization='CP') 29 | tensor = tensor_dropout(tensor, 0.999) 30 | weights = tensor().weights 31 | assert (len(weights) == (1)) 32 | 33 | remove_tensor_dropout(tensor) 34 | assert (not tensor._forward_hooks) 35 | 36 | tensor = tensor_dropout(tensor, 0) 37 | weights = tensor().weights 38 | assert (len(weights) == rank) 39 | 40 | 41 | def test_tt_dropout(): 42 | """Test for TT Dropout""" 43 | shape = (10, 11, 12) 44 | # Use the same rank for all factors 45 | rank = 4 46 | tensor = FactorizedTensor.new(shape, rank=rank, factorization='TT') 47 | tensor = tensor_dropout(tensor, 0.999) 48 | factors = tensor().factors 49 | for f in factors: 50 | assert (f.shape[0] == f.shape[-1] == 1) 51 | 52 | remove_tensor_dropout(tensor) 53 | assert (not tensor._forward_hooks) 54 | 55 | tensor = tensor_dropout(tensor, 0) 56 | factors = tensor().factors 57 | for i, f in enumerate(factors): 58 | if i: 59 | assert (f.shape[0] == rank) 60 | else: # boundary conditions: first and last rank are equal to 1 61 | assert (f.shape[-1] == rank) 62 | -------------------------------------------------------------------------------- /tltorch/tensor_hooks/tests/test_tensor_lasso.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from .._tensor_lasso import TuckerLasso, TTLasso, CPLasso, remove_tensor_lasso, tensor_lasso 4 | from ... 
import FactorizedTensor 5 | 6 | import tensorly as tl 7 | tl.set_backend('pytorch') 8 | from tensorly import testing 9 | import numpy as np 10 | import torch 11 | 12 | @pytest.mark.parametrize('factorization', ['cp', 'tucker', 'tt']) 13 | def test_tensor_lasso(factorization): 14 | shape = (5, 5, 6) 15 | rank = 3 16 | tensor1 = FactorizedTensor.new(shape, rank, factorization=factorization).normal_() 17 | tensor2 = FactorizedTensor.new(shape, rank, factorization=factorization).normal_() 18 | 19 | lasso = tensor_lasso(factorization, penalty=1, clamp_weights=False, normalize_loss=False) 20 | 21 | lasso.apply(tensor1) 22 | lasso.apply(tensor2) 23 | 24 | # Sum of weights all equal to a given value 25 | value = 1.5 26 | lasso.set_weights(tensor1, value) 27 | lasso.set_weights(tensor2, value) 28 | 29 | data = torch.sum(tensor1()) 30 | l1 = lasso.loss 31 | data = torch.sum(tensor2()) 32 | l2 = lasso.loss 33 | 34 | # The result should be n-param * value 35 | # First tensor 36 | if factorization == 'tt': 37 | sum_rank = sum(tensor1.rank[1:-1]) 38 | elif factorization == 'tucker': 39 | sum_rank = sum(tensor1.rank) 40 | elif factorization == 'cp': 41 | sum_rank = tensor1.rank 42 | testing.assert_(l1 == sum_rank*value) 43 | # Second tensor lasso 44 | if factorization == 'tt': 45 | sum_rank += sum(tensor2.rank[1:-1]) 46 | elif factorization == 'tucker': 47 | sum_rank += sum(tensor2.rank) 48 | elif factorization == 'cp': 49 | sum_rank += tensor2.rank 50 | testing.assert_(l2 == sum_rank*value) 51 | 52 | testing.assert_(tensor1._forward_hooks) 53 | testing.assert_(tensor2._forward_hooks) 54 | 55 | ### Test when all weights are 0 56 | lasso.reset() 57 | lasso.set_weights(tensor1, 0) 58 | lasso.set_weights(tensor2, 0) 59 | torch.sum(tensor1()) + torch.sum(tensor2()) 60 | testing.assert_(lasso.loss == 0) 61 | 62 | # Check the Lasso correctly removed 63 | remove_tensor_lasso(tensor1) 64 | testing.assert_(not tensor1._forward_hooks) 65 | testing.assert_(tensor2._forward_hooks) 66 | remove_tensor_lasso(tensor2) 67 | testing.assert_(not tensor1._forward_hooks) 68 | 69 | # Check normalization between 0 and 1 70 | lasso = tensor_lasso(factorization, penalty=1, normalize_loss=True, clamp_weights=False) 71 | lasso.apply(tensor1) 72 | tensor1() 73 | l1 = lasso.loss 74 | assert(abs(l1 - 1) < 1e-5) 75 | remove_tensor_lasso(tensor1) -------------------------------------------------------------------------------- /tltorch/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .parameter_list import ParameterList, FactorList, ComplexFactorList 2 | from .tensorize_shape import get_tensorized_shape -------------------------------------------------------------------------------- /tltorch/utils/parameter_list.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | 4 | 5 | class FactorList(nn.Module): 6 | def __init__(self, parameters=None): 7 | super().__init__() 8 | self.keys = [] 9 | self.counter = 0 10 | if parameters is not None: 11 | self.extend(parameters) 12 | 13 | def _unique_key(self): 14 | """Creates a new unique key""" 15 | key = f'factor_{self.counter}' 16 | self.counter += 1 17 | return key 18 | 19 | def append(self, element): 20 | key = self._unique_key() 21 | if torch.is_tensor(element): 22 | if isinstance(element, nn.Parameter): 23 | self.register_parameter(key, element) 24 | else: 25 | self.register_buffer(key, element) 26 | else: 27 | setattr(self, key, self.__class__(element)) 28 | 
self.keys.append(key) 29 | 30 | def insert(self, index, element): 31 | key = self._unique_key() 32 | setattr(self ,key, element) 33 | self.keys.insert(index, key) 34 | 35 | def pop(self, index=-1): 36 | item = self[index] 37 | self.__delitem__(index) 38 | return item 39 | 40 | def __getitem__(self, index): 41 | keys = self.keys[index] 42 | if isinstance(keys, list): 43 | return self.__class__([getattr(self, key) for key in keys]) 44 | return getattr(self, keys) 45 | 46 | def __setitem__(self, index, value): 47 | setattr(self, self.keys[index], value) 48 | 49 | def __delitem__(self, index): 50 | delattr(self, self.keys[index]) 51 | self.keys.__delitem__(index) 52 | 53 | def __len__(self): 54 | return len(self.keys) 55 | 56 | def extend(self, parameters): 57 | for param in parameters: 58 | self.append(param) 59 | 60 | def __iadd__(self, parameters): 61 | return self.extend(parameters) 62 | 63 | def __add__(self, parameters): 64 | instance = self.__class__(self) 65 | instance.extend(parameters) 66 | return instance 67 | 68 | def __radd__(self, parameters): 69 | instance = self.__class__(parameters) 70 | instance.extend(self) 71 | return instance 72 | 73 | def extra_repr(self) -> str: 74 | child_lines = [] 75 | for k, p in self._parameters.items(): 76 | size_str = 'x'.join(str(size) for size in p.size()) 77 | device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device()) 78 | parastr = 'Parameter containing: [{} of size {}{}]'.format( 79 | torch.typename(p), size_str, device_str) 80 | child_lines.append(' (' + str(k) + '): ' + parastr) 81 | tmpstr = '\n'.join(child_lines) 82 | return tmpstr 83 | 84 | 85 | class ComplexFactorList(FactorList): 86 | def __getitem__(self, index): 87 | if isinstance(index, int): 88 | value = getattr(self, self.keys[index]) 89 | if torch.is_tensor(value): 90 | value = torch.view_as_complex(value) 91 | return value 92 | else: 93 | keys = self.keys[index] 94 | return self.__class__([torch.view_as_complex(getattr(self, key)) for key in keys]) 95 | 96 | def __setitem__(self, index, value): 97 | if torch.is_tensor(value): 98 | value = torch.view_as_real(value) 99 | setattr(self, self.keys[index], value) 100 | 101 | def register_parameter(self, key, value): 102 | value = nn.Parameter(torch.view_as_real(value)) 103 | super().register_parameter(key, value) 104 | 105 | def register_buffer(self, key, value): 106 | value = torch.view_as_real(value) 107 | super().register_buffer(key, value) 108 | 109 | 110 | class ParameterList(nn.Module): 111 | def __init__(self, parameters=None): 112 | super().__init__() 113 | self.keys = [] 114 | self.counter = 0 115 | if parameters is not None: 116 | self.extend(parameters) 117 | 118 | def _unique_key(self): 119 | """Creates a new unique key""" 120 | key = f'param_{self.counter}' 121 | self.counter += 1 122 | return key 123 | 124 | def append(self, element): 125 | # p = nn.Parameter(element) 126 | key = self._unique_key() 127 | self.register_parameter(key, element) 128 | self.keys.append(key) 129 | 130 | def insert(self, index, element): 131 | # p = nn.Parameter(element) 132 | key = self._unique_key() 133 | self.register_parameter(key, element) 134 | self.keys.insert(index, key) 135 | 136 | def pop(self, index=-1): 137 | item = self[index] 138 | self.__delitem__(index) 139 | return item 140 | 141 | def __getitem__(self, index): 142 | keys = self.keys[index] 143 | if isinstance(keys, list): 144 | return self.__class__([getattr(self, key) for key in keys]) 145 | return getattr(self, keys) 146 | 147 | def __setitem__(self, index, 
value): 148 | self.register_parameter(self.keys[index], value) 149 | 150 | def __delitem__(self, index): 151 | delattr(self, self.keys[index]) 152 | self.keys.__delitem__(index) 153 | 154 | def __len__(self): 155 | return len(self.keys) 156 | 157 | def extend(self, parameters): 158 | for param in parameters: 159 | self.append(param) 160 | 161 | def __iadd__(self, parameters): 162 | return self.extend(parameters) 163 | 164 | def extra_repr(self) -> str: 165 | child_lines = [] 166 | for k, p in self._parameters.items(): 167 | size_str = 'x'.join(str(size) for size in p.size()) 168 | device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device()) 169 | parastr = 'Parameter containing: [{} of size {}{}]'.format( 170 | torch.typename(p), size_str, device_str) 171 | child_lines.append(' (' + str(k) + '): ' + parastr) 172 | tmpstr = '\n'.join(child_lines) 173 | return tmpstr 174 | -------------------------------------------------------------------------------- /tltorch/utils/tensorize_shape.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | from bisect import insort_left 4 | 5 | # Author : Jean Kossaifi 6 | 7 | def factorize(value, min_value=2, remaining=-1): 8 | """Factorize an integer input value into it's smallest divisors 9 | 10 | Parameters 11 | ---------- 12 | value : int 13 | integer to factorize 14 | min_value : int, default is 2 15 | smallest divisors to use 16 | remaining : int, default is -1 17 | DO NOT SPECIFY THIS VALUE, IT IS USED FOR TAIL RECURSION 18 | 19 | Returns 20 | ------- 21 | factorization : int tuple 22 | ints such that prod(factorization) == value 23 | """ 24 | if value <= min_value or remaining == 0: 25 | return (value, ) 26 | lim = math.isqrt(value) 27 | for i in range(min_value, lim+1): 28 | if value == i: 29 | return (i, ) 30 | if not (value % i): 31 | return (i, *factorize(value//i, min_value=min_value, remaining=remaining-1)) 32 | return (value, ) 33 | 34 | def merge_ints(values, size): 35 | """Utility function to merge the smallest values in a given tuple until it's length is the given size 36 | 37 | Parameters 38 | ---------- 39 | values : int list 40 | list of values to merge 41 | size : int 42 | target len of the list 43 | stop merging when len(values) <= size 44 | 45 | Returns 46 | ------- 47 | merge_values : list of size ``size`` 48 | """ 49 | if len(values) <= 1: 50 | return values 51 | 52 | values = sorted(list(values)) 53 | while (len(values) > size): 54 | a, b, *values = values 55 | insort_left(values, a*b) 56 | 57 | return tuple(values) 58 | 59 | def get_tensorized_shape(in_features, out_features, order=None, min_dim=2, verbose=True): 60 | """ Factorizes in_features and out_features such that: 61 | * they both are factorized into the same number of integers 62 | * they should both be factorized into `order` integers 63 | * each of the factors should be at least min_dim 64 | 65 | This is used to tensorize a matrix of size (in_features, out_features) into a higher order tensor 66 | 67 | Parameters 68 | ---------- 69 | in_features, out_features : int 70 | order : int 71 | the number of integers that each input should be factorized into 72 | min_dim : int 73 | smallest acceptable integer value for the factors 74 | 75 | Returns 76 | ------- 77 | in_tensorized, out_tensorized : tuple[int] 78 | tuples of ints used to tensorize each dimension 79 | 80 | Notes 81 | ----- 82 | This is a bruteforce solution but is enough for the dimensions we encounter in DNNs 83 | """ 84 | in_ten = 
factorize(in_features, min_value=min_dim) 85 | out_ten = factorize(out_features, min_value=min_dim, remaining=len(in_ten)) 86 | if order is not None: 87 | merge_size = min(order, len(in_ten), len(out_ten)) 88 | else: 89 | merge_size = min(len(in_ten), len(out_ten)) 90 | 91 | if len(in_ten) > merge_size: 92 | in_ten = merge_ints(in_ten, size=merge_size) 93 | if len(out_ten) > merge_size: 94 | out_ten = merge_ints(out_ten, size=merge_size) 95 | 96 | if verbose: 97 | print(f'Tensorizing (in, out)=({in_features, out_features}) -> ({in_ten, out_ten})') 98 | return in_ten, out_ten 99 | --------------------------------------------------------------------------------
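A quick usage sketch of the tensorization helpers above (illustrative toy sizes; `factorize` and `merge_ints` are imported from their module since they are not re-exported in `tltorch.utils`):

>>> import math
>>> from tltorch.utils import get_tensorized_shape
>>> from tltorch.utils.tensorize_shape import factorize, merge_ints
>>> factorize(12)
(2, 2, 3)
>>> merge_ints((2, 2, 3), size=2)
(3, 4)
>>> in_ten, out_ten = get_tensorized_shape(16, 8, order=3, verbose=False)
>>> (math.prod(in_ten), math.prod(out_ten))
(16, 8)
>>> len(in_ten) == len(out_ten) == 3
True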