├── .coveragerc
├── .gitignore
├── LICENSE-MIT
├── MANIFEST.in
├── README.md
├── nalu
│   ├── __init__.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── nac_cell.py
│   │   └── nalu_cell.py
│   └── layers
│       ├── __init__.py
│       └── nalu_layer.py
├── pyproject.toml
├── requirements.txt
├── resources
│   └── NALU.png
├── setup.py
├── tests
│   └── __init__.py
└── tox.ini
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | source =
3 | nalu
4 | tests
5 | branch = True
6 | omit =
7 | nalu/cli.py
8 |
9 | [report]
10 | exclude_lines =
11 | no cov
12 | no qa
13 | noqa
14 | pragma: no cover
15 | if __name__ == .__main__.:
16 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 |
107 | #nohup
108 | nohup.out
109 |
110 | #others
111 | notebooks/*
112 | #LICENSE-MIT
113 | #pyproject.toml
114 | #tox.ini
115 | #tests/*
116 | #.coveragerc
117 | #MANIFEST.in
118 | #setup.py
--------------------------------------------------------------------------------
/LICENSE-MIT:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018 Bharath G.S
2 |
3 | MIT License
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include LICENSE-MIT
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Arithmetic Logic Units (NALU)
2 |
3 | [![Downloads](https://pepy.tech/badge/nalu)](https://pepy.tech/project/nalu)
4 |
5 |
6 |
7 | 
8 | 
9 | 
10 |
11 | 
12 |
13 | A basic PyTorch implementation of NAC/NALU from [Neural Arithmetic Logic Units](https://arxiv.org/pdf/1808.00508.pdf) by Trask et al.
14 |
15 | ## Installation
16 |
17 | ```bash
18 | pip install NALU
19 | ```
20 |
21 | ## Usage
22 |
23 | ```python
24 | from nalu.core import NaluCell, NacCell
25 | from nalu.layers import NaluLayer
26 | ```
27 |
--------------------------------------------------------------------------------
/nalu/__init__.py:
--------------------------------------------------------------------------------
"""PyTorch implementation of Neural Arithmetic Logic Units (NAC / NALU).

Public API: ``NacCell`` and ``NaluCell`` (from ``nalu.core``) and
``NaluLayer`` (from ``nalu.layers``).
"""
# setup.py extracts the version from the first line that starts with
# ``__version__`` and splits it on '=', so keep this a one-line assignment.
__version__ = '0.0.3'

# Explicit re-exports instead of ``import *`` so the package namespace
# contains only the intended public classes.
from .core import NacCell, NaluCell
from .layers import NaluLayer

__all__ = ['NacCell', 'NaluCell', 'NaluLayer']
5 |
--------------------------------------------------------------------------------
/nalu/core/__init__.py:
--------------------------------------------------------------------------------
1 | from .nac_cell import NacCell
2 | from .nalu_cell import NaluCell
3 |
--------------------------------------------------------------------------------
/nalu/core/nac_cell.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor, nn
2 | from torch.nn.parameter import Parameter
3 | from torch.nn.init import xavier_uniform_
4 | from torch.nn.functional import linear
5 | from torch import sigmoid, tanh
6 |
7 |
class NacCell(nn.Module):
    """Neural Accumulator (NAC) cell.

    Implementation of the NAC from "Neural Arithmetic Logic Units"
    (https://arxiv.org/pdf/1808.00508.pdf): a bias-free linear map whose
    effective weight matrix is ``tanh(W_) * sigmoid(M_)``. Because tanh lies
    in (-1, 1) and sigmoid in (0, 1), every effective weight lies in (-1, 1).
    """

    def __init__(self, in_shape, out_shape):
        """
        in_shape: input sample dimension (size of the last axis of the input)
        out_shape: output sample dimension (size of the last axis of the output)
        """
        super().__init__()
        self.in_shape = in_shape
        self.out_shape = out_shape
        # Stored as (out, in): the weight layout torch.nn.functional.linear expects.
        self.W_ = Parameter(Tensor(out_shape, in_shape))
        self.M_ = Parameter(Tensor(out_shape, in_shape))
        # Two separate statements rather than a throwaway tuple expression.
        xavier_uniform_(self.W_)
        xavier_uniform_(self.M_)
        # The NAC has no bias; registering it as None keeps ``self.bias``
        # defined (as nn.Linear does) so it can be passed to ``linear``.
        self.register_parameter('bias', None)

    def forward(self, input):
        """Apply the NAC to ``input``; its last axis must have size ``in_shape``."""
        W = tanh(self.W_) * sigmoid(self.M_)
        return linear(input, W, self.bias)

    def extra_repr(self):
        """Show the layer dimensions when the module is printed."""
        return 'in_shape={}, out_shape={}'.format(self.in_shape, self.out_shape)
29 |
--------------------------------------------------------------------------------
/nalu/core/nalu_cell.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor, exp, log, nn
2 | from torch.nn.parameter import Parameter
3 | from torch.nn.init import xavier_uniform_
4 | from torch.nn.functional import linear
5 | from torch import sigmoid
6 | from .nac_cell import NacCell
7 |
8 |
class NaluCell(nn.Module):
    """Neural Arithmetic Logic Unit (NALU) cell.

    Implementation of the NALU from "Neural Arithmetic Logic Units"
    (https://arxiv.org/pdf/1808.00508.pdf):

        a = NAC(x)                          # additive path
        m = exp(NAC(log(|x| + eps)))        # multiplicative path
        g = sigmoid(G x)                    # learned gate
        y = g * a + (1 - g) * m

    The same NacCell (and hence the same weights) serves both paths.
    """

    def __init__(self, in_shape, out_shape, eps=1e-5):
        """
        in_shape: input sample dimension
        out_shape: output sample dimension
        eps: small constant added to |x| before the log, keeping log finite
             for zero inputs (default 1e-5, the previously hard-coded value)
        """
        super().__init__()
        self.in_shape = in_shape
        self.out_shape = out_shape
        self.G = Parameter(Tensor(out_shape, in_shape))
        # NacCell's signature is (in_shape, out_shape). It must map
        # in_shape -> out_shape just like the gate G. Passing the two swapped
        # only worked when in_shape == out_shape.
        self.nac = NacCell(in_shape, out_shape)
        xavier_uniform_(self.G)
        self.eps = eps
        # No bias on the gate; kept as None so ``linear`` can be called uniformly.
        self.register_parameter('bias', None)

    def forward(self, input):
        """Apply the NALU to ``input``; its last axis must have size ``in_shape``."""
        a = self.nac(input)
        g = sigmoid(linear(input, self.G, self.bias))
        log_in = log(input.abs() + self.eps)
        # exp(...) is always positive, so the multiplicative path alone
        # cannot produce negative outputs.
        m = exp(self.nac(log_in))
        return g * a + (1 - g) * m
36 |
--------------------------------------------------------------------------------
/nalu/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .nalu_layer import NaluLayer
2 |
--------------------------------------------------------------------------------
/nalu/layers/nalu_layer.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Sequential
2 | from torch import nn
3 | from nalu.core.nalu_cell import NaluCell
4 |
5 |
class NaluLayer(nn.Module):
    """A stack of ``n_layers`` NaluCell modules applied in sequence.

    The first cell maps input_shape -> hidden_shape, every intermediate cell
    maps hidden_shape -> hidden_shape, and the last maps
    hidden_shape -> output_shape. With ``n_layers == 1`` a single cell maps
    input_shape -> output_shape and ``hidden_shape`` is unused.
    """

    def __init__(self, input_shape, output_shape, n_layers, hidden_shape):
        """
        input_shape: input sample dimension
        output_shape: output sample dimension
        n_layers: number of stacked NaluCell modules (must be >= 1)
        hidden_shape: dimension between consecutive cells

        Raises ValueError if n_layers < 1. An empty stack would silently act
        as the identity and ignore output_shape.
        """
        super().__init__()
        if n_layers < 1:
            raise ValueError('n_layers must be >= 1, got {!r}'.format(n_layers))
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.n_layers = n_layers
        self.hidden_shape = hidden_shape
        # Chain of feature sizes, e.g. [in, h, h, out] for n_layers == 3;
        # consecutive pairs give each cell's (in, out) dimensions.
        dims = [input_shape] + [hidden_shape] * (n_layers - 1) + [output_shape]
        layers = [NaluCell(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        self.model = Sequential(*layers)

    def forward(self, data):
        """Run ``data`` through every cell in order."""
        return self.model(data)
19 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = 'NALU'
3 | version = '0.0.3'
4 | description = 'basic implementation of Neural arithmetic and logic units as described in arxiv.org/pdf/1808.00508.pdf'
5 | author = 'Bharath G.S'
6 | author_email = 'royalkingpin@gmail.com'
7 | license = 'MIT'
8 | url = 'https://github.com/bharathgs/NALU'
9 |
10 | [requires]
11 | python_version = ['3.6']
12 |
13 | [build-system]
14 | requires = ['setuptools', 'wheel']
15 |
16 | [tool.hatch.commands]
17 | prerelease = 'hatch build'
18 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | numpy
3 |
4 |
--------------------------------------------------------------------------------
/resources/NALU.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bharathgs/NALU/5d52cc270786563b67837a3856841baafba20e60/resources/NALU.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import os
from io import open

from setuptools import find_packages, setup

# Resolve project files relative to this script so that setup.py also works
# when it is invoked from a different working directory.
HERE = os.path.abspath(os.path.dirname(__file__))


def _read(*path_parts):
    """Return the UTF-8 decoded text of a file located relative to setup.py."""
    with open(os.path.join(HERE, *path_parts), 'r', encoding='utf-8') as f:
        return f.read()


def _find_version(source, default='0.0.1'):
    """Return the value of the first line in ``source`` starting with ``__version__``.

    Falls back to ``default`` when no such line exists.
    """
    for line in source.splitlines():
        if line.startswith('__version__'):
            # Split on the first '=' only, then drop spaces and quotes.
            return line.strip().split('=', 1)[1].strip(' \'"')
    return default


version = _find_version(_read('nalu', '__init__.py'))
readme = _read('README.md')

REQUIRES = ['torch', 'numpy']

setup(
    name='NALU',
    version=version,
    description='basic implementation of Neural arithmetic and logic units as described in arxiv.org/pdf/1808.00508.pdf',
    long_description=readme,
    # README.md is Markdown; without this PyPI renders it as reStructuredText.
    long_description_content_type='text/markdown',
    author='Bharath G.S',
    author_email='royalkingpin@gmail.com',
    maintainer='Bharath G.S',
    maintainer_email='royalkingpin@gmail.com',
    url='https://github.com/bharathgs/NALU',
    license='MIT',

    keywords=[
        'NALU', 'ALU', 'neural', 'neural-networks', 'pytorch', 'NAC', 'torch', 'machine-learning'
    ],

    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'License :: OSI Approved :: MIT License',
        'Natural Language :: English',
        'Operating System :: OS Independent',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: Implementation :: CPython',
    ],

    install_requires=REQUIRES,
    tests_require=['coverage', 'pytest'],

    # tests/ contains an __init__.py; exclude it so a top-level ``tests``
    # package is not installed into site-packages next to ``nalu``.
    packages=find_packages(exclude=['tests', 'tests.*']),
)
49 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bharathgs/NALU/5d52cc270786563b67837a3856841baafba20e60/tests/__init__.py
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist =
3 | py36,
4 |
5 | [testenv]
6 | passenv = *
7 | deps =
8 | coverage
9 | pytest
10 | commands =
11 | python setup.py --quiet clean develop
12 | coverage run --parallel-mode -m pytest
13 | coverage combine --append
14 | coverage report -m
15 |
--------------------------------------------------------------------------------