├── .gitignore ├── LICENSE ├── README.md ├── requirements.txt └── src ├── deepsets ├── __init__.py ├── datasets.py ├── experiments.py ├── networks.py └── settings.py └── run.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # PyCharm 104 | .idea/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Yasser Souri 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-deep-sets 2 | PyTorch implementation of parts of "Deep Sets" (NIPS 2017) 3 | 4 | ## Requirements 5 | 6 | * Python 3.6 7 | * PyTorch 0.3 8 | * torchvision 9 | * matplotlib 10 | * click 11 | * tqdm 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | matplotlib 4 | ipython 5 | jupyter 6 | tqdm 7 | click 8 | numpy 9 | tensorboardX 10 | -------------------------------------------------------------------------------- /src/deepsets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yassersouri/pytorch-deep-sets/5190e3eee8a438a0e6f599882786502d8fa0b09e/src/deepsets/__init__.py -------------------------------------------------------------------------------- /src/deepsets/datasets.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from torch import FloatTensor 6 | from torch.utils.data.dataset import Dataset 7 | from torchvision.datasets import MNIST 8 | from torchvision.transforms import Compose, ToTensor, Normalize 9 | 10 | from .settings import DATA_ROOT 11 | 12 | MNIST_TRANSFORM = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) 13 | 14 | 15 | class MNISTSummation(Dataset): 16 | def __init__(self, min_len: int, max_len: int, dataset_len: int, train: bool = True, transform: Compose = None): 17 | self.min_len = min_len 18 | self.max_len = max_len 19 | self.dataset_len = dataset_len 20 | self.train = train 21 | self.transform = transform 22 | 23 | self.mnist = MNIST(DATA_ROOT, train=self.train, transform=self.transform, download=True) 24 | mnist_len = self.mnist.__len__() 25 | mnist_items_range = np.arange(0, mnist_len) 26 | 27 | items_len_range = np.arange(self.min_len, self.max_len + 1) 28 | items_len = np.random.choice(items_len_range, size=self.dataset_len, replace=True) 29 | self.mnist_items = [] 30 | for i in range(self.dataset_len): 31 | self.mnist_items.append(np.random.choice(mnist_items_range, size=items_len[i], replace=True)) 32 | 33 | def __len__(self) -> int: 34 | return self.dataset_len 35 | 36 | def __getitem__(self, item: int) -> Tuple[FloatTensor, FloatTensor]: 37 | mnist_items = self.mnist_items[item] 38 | 39 | the_sum = 0 40 | images = [] 41 | for mi in mnist_items: 42 | img, target = self.mnist.__getitem__(mi) 43 | the_sum += target 44 | images.append(img) 45 | 46 | return torch.stack(images, dim=0), torch.FloatTensor([the_sum]) 47 | -------------------------------------------------------------------------------- /src/deepsets/experiments.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from tensorboardX import SummaryWriter 5 | from torch import optim 6 | from torch.autograd import Variable 7 | from tqdm import tqdm 8 | 9 | from .datasets import MNISTSummation, MNIST_TRANSFORM 10 | from .networks import InvariantModel, SmallMNISTCNNPhi, SmallRho 11 | 12 | 13 | class SumOfDigits(object): 14 | def __init__(self, lr=1e-3, wd=5e-3): 15 | self.lr = lr 16 | self.wd = wd 17 | self.train_db = MNISTSummation(min_len=2, max_len=10, dataset_len=100000, train=True, transform=MNIST_TRANSFORM) 18 | self.test_db = MNISTSummation(min_len=5, max_len=50, dataset_len=100000, train=False, transform=MNIST_TRANSFORM) 19 | 20 | self.the_phi = SmallMNISTCNNPhi() 21 | self.the_rho = SmallRho(input_size=10, output_size=1) 22 | 23 | self.model = InvariantModel(phi=self.the_phi, rho=self.the_rho) 24 | if torch.cuda.is_available(): 25 | self.model.cuda() 26 | 27 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.wd) 28 | 29 | self.summary_writer = SummaryWriter( 30 | log_dir='/home/souri/temp/deepsets/exp-lr:%1.5f-wd:%1.5f/' % (self.lr, self.wd)) 31 | 32 | def train_1_epoch(self, epoch_num: int = 0): 33 | self.model.train() 34 | for i in tqdm(range(len(self.train_db))): 35 | loss = self.train_1_item(i) 36 | self.summary_writer.add_scalar('train_loss', loss, i + len(self.train_db) * epoch_num) 37 | 38 | def train_1_item(self, item_number: int) -> float: 39 | x, target = self.train_db.__getitem__(item_number) 40 | if torch.cuda.is_available(): 41 | x, target = x.cuda(), target.cuda() 42 | 43 | x, target = Variable(x), Variable(target) 44 | 45 | self.optimizer.zero_grad() 46 | pred = self.model.forward(x) 47 | the_loss = F.mse_loss(pred, target) 48 | 49 | the_loss.backward() 50 | self.optimizer.step() 51 | 52 | the_loss_tensor = the_loss.data 53 | if torch.cuda.is_available(): 54 | the_loss_tensor = the_loss_tensor.cpu() 55 | 56 | the_loss_numpy = the_loss_tensor.numpy().flatten() 57 | the_loss_float = float(the_loss_numpy[0]) 58 | 59 | return the_loss_float 60 | 61 | def evaluate(self): 62 | self.model.eval() 63 | totals = [0] * 51 64 | corrects = [0] * 51 65 | 66 | for i in tqdm(range(len(self.test_db))): 67 | x, target = self.test_db.__getitem__(i) 68 | 69 | item_size = x.shape[0] 70 | 71 | if torch.cuda.is_available(): 72 | x = x.cuda() 73 | 74 | pred = self.model.forward(Variable(x)).data 75 | 76 | if torch.cuda.is_available(): 77 | pred = pred.cpu().numpy().flatten() 78 | 79 | pred = int(round(float(pred[0]))) 80 | target = int(round(float(target.numpy()[0]))) 81 | 82 | totals[item_size] += 1 83 | 84 | if pred == target: 85 | corrects[item_size] += 1 86 | 87 | totals = np.array(totals) 88 | corrects = np.array(corrects) 89 | 90 | print(corrects / totals) 91 | -------------------------------------------------------------------------------- /src/deepsets/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import FloatTensor 7 | from torch.autograd import Variable 8 | 9 | NetIO = Union[FloatTensor, Variable] 10 | 11 | 12 | class InvariantModel(nn.Module): 13 | def __init__(self, phi: nn.Module, rho: nn.Module): 14 | super().__init__() 15 | self.phi = phi 16 | self.rho = rho 17 | 18 | def forward(self, x: NetIO) -> NetIO: 19 | # compute the representation for each data point 20 | x = self.phi.forward(x) 21 | 22 | # sum up the representations 23 | # here I have assumed that x is 2D and the each row is representation of an input, so the following operation 24 | # will reduce the number of rows to 1, but it will keep the tensor as a 2D tensor. 25 | x = torch.sum(x, dim=0, keepdim=True) 26 | 27 | # compute the output 28 | out = self.rho.forward(x) 29 | 30 | return out 31 | 32 | 33 | class SmallMNISTCNNPhi(nn.Module): 34 | def __init__(self): 35 | super().__init__() 36 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 37 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 38 | self.conv2_drop = nn.Dropout2d() 39 | self.fc1 = nn.Linear(320, 50) 40 | self.fc1_drop = nn.Dropout2d() 41 | self.fc2 = nn.Linear(50, 10) 42 | 43 | def forward(self, x: NetIO) -> NetIO: 44 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 45 | x = self.conv2_drop(self.conv2(x)) 46 | x = F.relu(F.max_pool2d(x, 2)) 47 | x = x.view(-1, 320) 48 | x = F.relu(self.fc1(x)) 49 | x = self.fc1_drop(x) 50 | x = F.relu(self.fc2(x)) 51 | return x 52 | 53 | 54 | class SmallRho(nn.Module): 55 | def __init__(self, input_size: int, output_size: int = 1): 56 | super().__init__() 57 | self.input_size = input_size 58 | self.output_size = output_size 59 | 60 | self.fc1 = nn.Linear(self.input_size, 10) 61 | self.fc2 = nn.Linear(10, self.output_size) 62 | 63 | def forward(self, x: NetIO) -> NetIO: 64 | x = F.relu(self.fc1(x)) 65 | x = self.fc2(x) 66 | return x 67 | -------------------------------------------------------------------------------- /src/deepsets/settings.py: -------------------------------------------------------------------------------- 1 | RANDOM_SEED = 0 2 | DATA_ROOT = '~/datasets/' 3 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | import click 2 | import numpy as np 3 | import torch 4 | 5 | from deepsets.experiments import SumOfDigits 6 | from deepsets.settings import RANDOM_SEED 7 | 8 | 9 | @click.command() 10 | @click.option('--random-seed', envvar='SEED', default=RANDOM_SEED) 11 | def main(random_seed): 12 | np.random.seed(random_seed) 13 | torch.manual_seed(random_seed) 14 | torch.cuda.manual_seed_all(random_seed) 15 | 16 | the_experiment = SumOfDigits(lr=1e-3) 17 | 18 | for i in range(20): 19 | the_experiment.train_1_epoch(i) 20 | the_experiment.evaluate() 21 | 22 | 23 | if __name__ == '__main__': 24 | main() 25 | --------------------------------------------------------------------------------