├── .gitignore ├── LICENSE ├── README.md ├── main.py ├── nam_pytorch ├── __init__.py └── nam_pytorch.py ├── pic.png └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Rishabh Anand 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nam-pytorch 2 | Unofficial PyTorch implementation of Neural Additive Models (NAM) by Agarwal, et al. [[`abs`](https://arxiv.org/abs/2004.13912), [`pdf`](https://arxiv.org/pdf/2004.13912.pdf)] 3 | 4 | 5 | 6 | --- 7 | 8 | ## Installation 9 | 10 | You can access `nam-pytorch` via `pip`: 11 | 12 | ```bash 13 | $ pip install nam-pytorch 14 | ``` 15 | 16 | ## Usage 17 | 18 | ```python 19 | import torch 20 | from nam_pytorch import NAM 21 | 22 | nam = NAM( 23 | num_features=784, 24 | link_func="sigmoid" 25 | ) 26 | 27 | images = torch.rand(32, 784) 28 | pred = nam(images) # [32, 1] 29 | ``` 30 | 31 | ## Contributing 32 | As always, if there are any issues with / suggestions for the code, feel free to raise an issue or submit a PR. 33 | 34 | ## License 35 | [MIT](https://github.com/rish-16/nam-pytorch/blob/main/LICENSE) 36 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from nam_pytorch import NAM 3 | 4 | nam = NAM(784, "tanh") 5 | x = torch.rand(32, 784) 6 | 7 | y = nam(x) 8 | 9 | print (y.shape) -------------------------------------------------------------------------------- /nam_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from nam_pytorch.nam_pytorch import NAM -------------------------------------------------------------------------------- /nam_pytorch/nam_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FeatureNetwork(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | self.l1 = nn.Linear(1, 64) 9 | self.l2 = nn.Linear(64, 128) 10 | self.l3 = nn.Linear(128, 256) 11 | self.l4 = nn.Linear(256, 128) 12 | self.l5 = nn.Linear(128, 64) 13 | self.l6 = nn.Linear(64, 1) 14 | 15 | def forward(self, x): 16 | x = F.relu(self.l1(x)) 17 | x = F.relu(self.l2(x)) 18 | x = F.relu(self.l3(x)) 19 | x = F.relu(self.l4(x)) 20 | x = F.relu(self.l5(x)) 21 | out = self.l6(x) 22 | 23 | return out 24 | 25 | class NAM(nn.Module): 26 | def __init__(self, num_features, link_func="sigmoid"): 27 | super().__init__() 28 | self.networks = nn.ModuleList([ 29 | FeatureNetwork() for _ in range(num_features) 30 | ]) 31 | self.num_features = num_features 32 | self.bias = nn.Parameter(torch.rand(1)) # extra beta term 33 | self.link_func = link_func 34 | 35 | def forward(self, x): 36 | B, dim = x.shape 37 | outs = torch.Tensor(B, dim) 38 | 39 | for i in range(B): 40 | temp = torch.Tensor(dim) 41 | for j in range(self.num_features): # for all dim 42 | net = self.networks[j] 43 | xi = x[i, j].unsqueeze(dim=0) 44 | temp[j] = net(xi) 45 | outs[i] = temp 46 | 47 | summed = outs.sum(axis=1) + self.bias 48 | 49 | if self.link_func == "sigmoid": 50 | res = torch.sigmoid(summed).view(B, 1) 51 | elif self.link_func == "tanh": 52 | res = torch.tanh(summed).view(B, 1) 53 | 54 | return res -------------------------------------------------------------------------------- /pic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rish-16/nam-pytorch/7c52537e6f31ee9532258c76c9cded501dcc7f5d/pic.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('README.md') as readme_file: 4 | README = readme_file.read() 5 | 6 | setup( 7 | name = 'nam_pytorch', 8 | packages = find_packages(exclude=[]), 9 | version = '0.0.1', 10 | license='MIT', 11 | description = 'Neural Additive Models (NAM) - Pytorch', 12 | long_description_content_type="text/markdown", 13 | long_description=README, 14 | author = 'Rishabh Anand', 15 | author_email = 'mail.rishabh.anand@gmail.com', 16 | url = 'https://github.com/rish-16/nam-pytorch', 17 | keywords = [ 18 | 'artificial intelligence', 19 | 'deep learning', 20 | 'nam', 21 | 'neural additive models', 22 | 'generalized additive models' 23 | ], 24 | install_requires=[ 25 | 'torch>=1.6' 26 | ], 27 | classifiers=[ 28 | 'Development Status :: 4 - Beta', 29 | 'Intended Audience :: Developers', 30 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 31 | 'License :: OSI Approved :: MIT License', 32 | 'Programming Language :: Python :: 3.6', 33 | ], 34 | ) --------------------------------------------------------------------------------