├── .coveragerc ├── .gitignore ├── LICENSE-MIT ├── MANIFEST.in ├── README.md ├── nalu ├── __init__.py ├── core │ ├── __init__.py │ ├── nac_cell.py │ └── nalu_cell.py └── layers │ ├── __init__.py │ └── nalu_layer.py ├── pyproject.toml ├── requirements.txt ├── resources └── NALU.png ├── setup.py ├── tests └── __init__.py └── tox.ini /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source = 3 | nalu 4 | tests 5 | branch = True 6 | omit = 7 | nalu/cli.py 8 | 9 | [report] 10 | exclude_lines = 11 | no cov 12 | no qa 13 | noqa 14 | pragma: no cover 15 | if __name__ == .__main__.: 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | #nohup 108 | nohup.out 109 | 110 | #others 111 | notebooks/* 112 | #LICENSE-MIT 113 | #pyproject.toml 114 | #tox.ini 115 | #tests/* 116 | #.coveragerc 117 | #MANIFEST.in 118 | #setup.py -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 Bharath G.S 2 | 3 | MIT License 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to 
the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE-MIT 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Arithmetic Logic Units (NALU) 2 | 3 | [![Downloads](https://pepy.tech/badge/nalu)](https://pepy.tech/project/nalu) 4 | 5 | 6 | 7 | ![GitHub](https://img.shields.io/github/license/mashape/apistatus.svg) 8 | ![Hackage-Deps](https://img.shields.io/hackage-deps/v/lens.svg) 9 | ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/nalu.svg) 10 | 11 | ![](resources/NALU.png) 12 | 13 | Basic PyTorch implementation of NAC/NALU from [Neural Arithmetic Logic Units](https://arxiv.org/pdf/1808.00508.pdf) by Trask et al. 14 | 15 | ## Installation 16 | 17 | ```bash 18 | pip install NALU 19 | ``` 20 | 21 | ## Usage 22 | 23 | ```python 24 | from nalu.core import NaluCell, NacCell 25 | from nalu.layers import NaluLayer 26 | ``` 27 | -------------------------------------------------------------------------------- /nalu/__init__.py:
-------------------------------------------------------------------------------- 1 | __version__ = '0.0.3' 2 | 3 | from .core import * 4 | from .layers import * 5 | -------------------------------------------------------------------------------- /nalu/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .nac_cell import NacCell 2 | from .nalu_cell import NaluCell 3 | -------------------------------------------------------------------------------- /nalu/core/nac_cell.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, nn 2 | from torch.nn.parameter import Parameter 3 | from torch.nn.init import xavier_uniform_ 4 | from torch.nn.functional import linear 5 | from torch import sigmoid, tanh 6 | 7 | 8 | class NacCell(nn.Module): 9 | """Basic NAC unit implementation 10 | from https://arxiv.org/pdf/1808.00508.pdf 11 | """ 12 | 13 | def __init__(self, in_shape, out_shape): 14 | """ 15 | in_shape: input sample dimension 16 | out_shape: output sample dimension 17 | """ 18 | super().__init__() 19 | self.in_shape = in_shape 20 | self.out_shape = out_shape 21 | self.W_ = Parameter(Tensor(out_shape, in_shape)) 22 | self.M_ = Parameter(Tensor(out_shape, in_shape)) 23 | xavier_uniform_(self.W_), xavier_uniform_(self.M_) 24 | self.register_parameter('bias', None) 25 | 26 | def forward(self, input): 27 | W = tanh(self.W_) * sigmoid(self.M_) 28 | return linear(input, W, self.bias) 29 | -------------------------------------------------------------------------------- /nalu/core/nalu_cell.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, exp, log, nn 2 | from torch.nn.parameter import Parameter 3 | from torch.nn.init import xavier_uniform_ 4 | from torch.nn.functional import linear 5 | from torch import sigmoid 6 | from .nac_cell import NacCell 7 | 8 | 9 | class NaluCell(nn.Module): 10 | """Basic NALU unit 
implementation 11 | from https://arxiv.org/pdf/1808.00508.pdf 12 | """ 13 | 14 | def __init__(self, in_shape, out_shape): 15 | """ 16 | in_shape: input sample dimension 17 | out_shape: output sample dimension 18 | """ 19 | super().__init__() 20 | self.in_shape = in_shape 21 | self.out_shape = out_shape 22 | self.G = Parameter(Tensor(out_shape, in_shape)) 23 | self.nac = NacCell(out_shape, in_shape) 24 | xavier_uniform_(self.G) 25 | self.eps = 1e-5 26 | self.register_parameter('bias', None) 27 | 28 | def forward(self, input): 29 | a = self.nac(input) 30 | g = sigmoid(linear(input, self.G, self.bias)) 31 | ag = g * a 32 | log_in = log(abs(input) + self.eps) 33 | m = exp(self.nac(log_in)) 34 | md = (1 - g) * m 35 | return ag + md 36 | -------------------------------------------------------------------------------- /nalu/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .nalu_layer import NaluLayer 2 | -------------------------------------------------------------------------------- /nalu/layers/nalu_layer.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Sequential 2 | from torch import nn 3 | from nalu.core.nalu_cell import NaluCell 4 | 5 | 6 | class NaluLayer(nn.Module): 7 | def __init__(self, input_shape, output_shape, n_layers, hidden_shape): 8 | super().__init__() 9 | self.input_shape = input_shape 10 | self.output_shape = output_shape 11 | self.n_layers = n_layers 12 | self.hidden_shape = hidden_shape 13 | layers = [NaluCell(hidden_shape if n > 0 else input_shape, 14 | hidden_shape if n < n_layers - 1 else output_shape) for n in range(n_layers)] 15 | self.model = Sequential(*layers) 16 | 17 | def forward(self, data): 18 | return self.model(data) 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [metadata] 2 | 
name = 'NALU' 3 | version = '0.0.1' 4 | description = 'basic implementation of Neural arithmetic and logic units as described in arxiv.org/pdf/1808.00508.pdf' 5 | author = 'Bharath G.S' 6 | author_email = 'royalkingpin@gmail.com' 7 | license = 'MIT' 8 | url = 'https://github.com/bharathgs/NALU' 9 | 10 | [requires] 11 | python_version = ['3.6'] 12 | 13 | [build-system] 14 | requires = ['setuptools', 'wheel'] 15 | 16 | [tool.hatch.commands] 17 | prerelease = 'hatch build' 18 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | 4 | -------------------------------------------------------------------------------- /resources/NALU.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bharathgs/NALU/5d52cc270786563b67837a3856841baafba20e60/resources/NALU.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from io import open 2 | 3 | from setuptools import find_packages, setup 4 | 5 | with open('nalu/__init__.py', 'r') as f: 6 | for line in f: 7 | if line.startswith('__version__'): 8 | version = line.strip().split('=')[1].strip(' \'"') 9 | break 10 | else: 11 | version = '0.0.1' 12 | 13 | with open('README.md', 'r', encoding='utf-8') as f: 14 | readme = f.read() 15 | 16 | REQUIRES = ['torch', 'numpy'] 17 | 18 | setup( 19 | name='NALU', 20 | version=version, 21 | description='basic implementation of Neural arithmetic and logic units as described in arxiv.org/pdf/1808.00508.pdf', 22 | long_description=readme, 23 | author='Bharath G.S', 24 | author_email='royalkingpin@gmail.com', 25 | maintainer='Bharath G.S', 26 | maintainer_email='royalkingpin@gmail.com', 27 | url='https://github.com/bharathgs/NALU', 28 | license='MIT', 29 | 30 | 
keywords=[ 31 | 'NALU', 'ALU', 'neural', 'neural-networks', 'pytorch', 'NAC', 'torch', 'machine-learning' 32 | ], 33 | 34 | classifiers=[ 35 | 'Development Status :: 4 - Beta', 36 | 'Intended Audience :: Developers', 37 | 'License :: OSI Approved :: MIT License', 38 | 'Natural Language :: English', 39 | 'Operating System :: OS Independent', 40 | 'Programming Language :: Python :: 3.6', 41 | 'Programming Language :: Python :: Implementation :: CPython', 42 | ], 43 | 44 | install_requires=REQUIRES, 45 | tests_require=['coverage', 'pytest'], 46 | 47 | packages=find_packages(), 48 | ) 49 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bharathgs/NALU/5d52cc270786563b67837a3856841baafba20e60/tests/__init__.py -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = 3 | py36, 4 | 5 | [testenv] 6 | passenv = * 7 | deps = 8 | coverage 9 | pytest 10 | commands = 11 | python setup.py --quiet clean develop 12 | coverage run --parallel-mode -m pytest 13 | coverage combine --append 14 | coverage report -m 15 | --------------------------------------------------------------------------------