├── mogrifier
│   ├── __init__.py
│   └── mogrifier.py
├── mogrifier.png
├── transmogrifier.jpg
├── setup.py
├── LICENSE
├── .github
│   └── workflows
│       └── python-publish.yml
├── README.md
└── .gitignore
/mogrifier/__init__.py:
--------------------------------------------------------------------------------
1 | from mogrifier.mogrifier import Mogrifier
2 |
--------------------------------------------------------------------------------
/mogrifier.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/mogrifier/HEAD/mogrifier.png
--------------------------------------------------------------------------------
/transmogrifier.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/mogrifier/HEAD/transmogrifier.jpg
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
# Package metadata for the `mogrifier` distribution.
setup(
  name = 'mogrifier',
  packages = find_packages(),
  version = '0.0.5',
  license='MIT',
  description = 'Implementation of Mogrifier circuit from Deepmind',
  long_description_content_type = 'text/markdown',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/mogrifier',
  keywords = [
    'artificial intelligence',
    'natural language processing',
    'improved conditioning'
  ],
  # The package source uses `from __future__ import annotations` (3.7+), and
  # einops>=0.8 declares Python >=3.8, so older interpreters cannot work.
  python_requires = '>=3.8',
  install_requires=[
    'einops>=0.8',
    'torch'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.8',
  ],
)
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    # checkout@v2 / setup-python@v2 target a deprecated Node runtime;
    # use the current major releases instead.
    - uses: actions/checkout@v4
    - name: Set up Python
      uses: actions/setup-python@v5
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | [![PyPI version](https://badge.fury.io/py/mogrifier.svg)](https://badge.fury.io/py/mogrifier)
4 |
5 | ## Mogrifier
6 |
7 | A complete implementation of Mogrifier, a circuit for enhancing LSTMs and potentially other networks. It allows two vectors to modulate each other by having each gate the other in an interleaved, iterative fashion.
8 |
9 | ## Install
10 |
11 | ```bash
12 | $ pip install mogrifier
13 | ```
14 |
15 | ## Usage
16 |
17 | ```python
18 | import torch
19 | from mogrifier import Mogrifier
20 |
21 | mogrify = Mogrifier(
22 | dim = 512,
23 | dim_hidden = 256,
24 | iters = 5, # number of iterations; defaults to 5, as recommended by the paper for LSTMs
25 | factorize_k = 16 # if specified, factorize each weight matrix into a (dim_in x k) and a (k x dim_out) pair
26 | )
27 |
28 | x = torch.randn(1, 16, 512)
29 | h = torch.randn(1, 16, 256)
30 |
31 | out, hidden_out = mogrify(x, h) # (1, 16, 512), (1, 16, 256)
32 |
33 | assert out.shape == x.shape
34 | assert hidden_out.shape == h.shape
35 | ```
36 |
37 | ## Citation
38 |
39 | ```bibtex
40 | @inproceedings{Melis2020Mogrifier,
41 | title = {Mogrifier LSTM},
42 | author = {Gábor Melis and Tomáš Kočiský and Phil Blunsom},
43 | booktitle = {International Conference on Learning Representations},
44 | year = {2020},
45 | url = {https://openreview.net/forum?id=SJe5P6EYvS}
46 | }
47 | ```
48 |
49 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/mogrifier/mogrifier.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 | from torch import nn, Tensor
5 | from torch.nn import Module
6 |
7 | from einops import repeat, pack, unpack
8 |
# aliases

# Short name for torch's fully-connected layer, used when building projections.
Linear = torch.nn.Linear
12 |
def exists(v):
    """Report whether ``v`` holds a value, i.e. is anything other than ``None``."""
    return not (v is None)
15 |
def default(v, d):
    """Return ``v`` when it is not ``None``; otherwise fall back to ``d``."""
    if exists(v):
        return v
    return d
18 |
19 | # maybe factorized projection
20 |
def weight(
    dim_in,
    dim_out,
    k: int | None = None
):
    """Build a linear projection from ``dim_in`` to ``dim_out`` features.

    Args:
        dim_in: number of input features.
        dim_out: number of output features.
        k: optional bottleneck width. When given, the projection is built as
            two stacked linear layers (``dim_in -> k -> dim_out``) instead of
            a single one; ``k`` must be smaller than both ``dim_in`` and
            ``dim_out``.

    Returns:
        A single ``Linear`` module, or an ``nn.Sequential`` of two of them.
    """
    if exists(k):
        assert k < dim_in and k < dim_out, 'k must be of relative lower rank'
        # build in the same order as before so parameter init order is unchanged
        squeeze = Linear(dim_in, k)
        expand = Linear(k, dim_out)
        return nn.Sequential(squeeze, expand)

    return Linear(dim_in, dim_out)
35 |
36 | # main class
37 |
class Mogrifier(Module):
    """Mogrifier circuit: mutual gating of an input tensor and a hidden tensor.

    On even rounds the inputs are multiplied by ``2 * sigmoid(Q(hiddens))``;
    on odd rounds the hiddens are multiplied by ``2 * sigmoid(R(inputs))``,
    using the inputs already updated in the preceding round. The factor of 2
    means a gate value of 0.5 leaves the operand unchanged. A single Q and a
    single R are shared across all rounds.
    """

    def __init__(
        self,
        dim: int,
        iters = 5,
        factorize_k: int | None = None,
        dim_hidden: int | None = None,
        hidden_factorize_k: int | None = None,
    ):
        """
        Args:
            dim: size of the last axis of ``inputs``.
            iters: default number of gating rounds; must be > 1 so that both
                the inputs (round 0) and the hiddens (round 1) get updated.
            factorize_k: optional bottleneck width for Q (and for R unless
                ``hidden_factorize_k`` is given); see ``weight``.
            dim_hidden: size of the last axis of ``hiddens``; defaults to ``dim``.
            hidden_factorize_k: optional bottleneck width for R; defaults to
                ``factorize_k``.
        """
        super().__init__()
        assert iters > 1

        self.dim = dim

        dim_hidden = default(dim_hidden, dim)
        self.dim_hidden = dim_hidden

        self.iters = iters

        # Q: hiddens -> gate over input features
        self.Q = nn.Sequential(
            weight(dim_hidden, dim, factorize_k),
            nn.Sigmoid()
        )

        # R uses its own rank if provided, otherwise the same one as Q
        factorize_k = default(hidden_factorize_k, factorize_k)

        # R: inputs -> gate over hidden features
        self.R = nn.Sequential(
            weight(dim, dim_hidden, factorize_k),
            nn.Sigmoid()
        )

    def forward(
        self,
        inputs: Tensor,
        hiddens: Tensor,
        iters: int | None = None
    ):
        """Gate ``inputs`` and ``hiddens`` against each other.

        Args:
            inputs: tensor whose last axis has size ``self.dim``.
            hiddens: tensor whose last axis has size ``self.dim_hidden``. When
                ``inputs`` is 3-D and ``hiddens`` is 2-D, ``hiddens`` is
                repeated along the second-to-last axis of ``inputs``.
            iters: number of gating rounds for this call; falls back to
                ``self.iters`` when ``None``.

        Returns:
            Tuple ``(inputs, hiddens)`` after gating, both unpacked back to the
            leading shape of ``inputs``.
        """
        iters = default(iters, self.iters)

        if inputs.ndim == 3 and hiddens.ndim == 2:
            hiddens = repeat(hiddens, 'b d -> b n d', n = inputs.shape[-2])

        assert inputs.shape[-1] == self.dim
        assert hiddens.shape[-1] == self.dim_hidden
        assert inputs.shape[:-2] == hiddens.shape[:-2]

        # flatten all leading axes into one so the projections see 2-D tensors
        (inputs, packed_shape), (hiddens, _) = tuple(pack([t], '* d') for t in (inputs, hiddens))

        # honour the per-call override resolved above (not self.iters)
        for ind in range(iters):
            is_even = (ind % 2) == 0

            if is_even:
                inputs = 2 * self.Q(hiddens) * inputs
            else:
                hiddens = 2 * self.R(inputs) * hiddens

        inputs, hiddens = tuple(unpack(t, packed_shape, '* d')[0] for t in (inputs, hiddens))
        return inputs, hiddens
96 |
--------------------------------------------------------------------------------