├── README.md
├── models.py
└── train_elm.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch ELM

PyTorch implementation of Extreme Learning Machines.

## Prerequisites

- Python 3.3+
- [PyTorch](https://pytorch.org/)

## Usage

This code is implemented for the MNIST dataset.

Usage example:

    $ python train_elm.py --hsize 1000

`--hsize` sets the number of neurons in the hidden layer.
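## How it works

An ELM is a single-hidden-layer network in which the input-to-hidden weights and the hidden bias are fixed at random initialization and never trained. Only the output weights are learned, in closed form: `fit` computes the hidden activations `H = sigmoid(x @ alpha + bias)` and solves `beta = pinverse(H) @ t` with the Moore–Penrose pseudoinverse, so no gradient descent is needed.

The `ELM` class in `models.py` can also be used on its own. Below is a minimal sketch on random tensors; the shapes and values are illustrative only and not tied to MNIST:

    import torch
    from models import ELM

    x = torch.rand(64, 784)                          # 64 flattened 28x28 inputs
    t = torch.eye(10)[torch.randint(0, 10, (64,))]   # one-hot targets

    elm = ELM(input_size=784, h_size=100, num_classes=10)
    elm.fit(x, t)                                    # closed-form solve for beta
    print(elm.evaluate(x, t))                        # accuracy on the same data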
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

###############
# ELM
###############
class ELM:
    def __init__(self, input_size, h_size, num_classes, device=None):
        self._input_size = input_size
        self._h_size = h_size
        self._output_size = num_classes
        self._device = device

        # Random input-to-hidden and hidden-to-output weights. Only _beta is
        # re-estimated in fit(); _alpha and _bias stay as initialized.
        self._alpha = nn.init.uniform_(torch.empty(self._input_size, self._h_size, device=self._device), a=-1., b=1.)
        self._beta = nn.init.uniform_(torch.empty(self._h_size, self._output_size, device=self._device), a=-1., b=1.)

        self._bias = torch.zeros(self._h_size, device=self._device)

        self._activation = torch.sigmoid

    def predict(self, x):
        h = self._activation(torch.add(x.mm(self._alpha), self._bias))
        out = h.mm(self._beta)

        return out

    def fit(self, x, t):
        # Hidden-layer activations H = sigmoid(x @ alpha + bias).
        temp = x.mm(self._alpha)
        H = self._activation(torch.add(temp, self._bias))

        # Closed-form output weights: beta = pinverse(H) @ t.
        H_pinv = torch.pinverse(H)
        self._beta = H_pinv.mm(t)

    def evaluate(self, x, t):
        y_pred = self.predict(x)
        acc = torch.sum(torch.argmax(y_pred, dim=1) == torch.argmax(t, dim=1)).item() / len(t)
        return acc

#####################
# Helper Functions
#####################
def to_onehot(batch_size, num_classes, y, device):
    # Build a (batch_size, num_classes) one-hot matrix from integer class labels y.
    y_onehot = torch.zeros(batch_size, num_classes, device=device)
    y = torch.unsqueeze(y, dim=1)  # scatter_ expects int64 indices of shape (batch_size, 1)
    y_onehot.scatter_(1, y, 1)

    return y_onehot
--------------------------------------------------------------------------------
/train_elm.py:
--------------------------------------------------------------------------------
import argparse
import torch
import torchvision.datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from models import ELM, to_onehot


parser = argparse.ArgumentParser(description='PyTorch ELM on MNIST')
parser.add_argument('--hsize', type=int, default=500, help='Number of neurons in the hidden layer.')
opt = parser.parse_args()


#################
# Parameters
#################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_size = 28*28
hidden_size = opt.hsize
num_classes = 10

##################
# Datasets
##################
transform = transforms.Compose([
    transforms.ToTensor(),
])
dataset = torchvision.datasets.MNIST(root='~/AI/Datasets/mnist/data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='~/AI/Datasets/mnist/data', train=False, transform=transform, download=True)

def get_all_data(dataset, num_workers=30, shuffle=False):
    # Load the entire dataset as a single batch of flattened images.
    data_loader = DataLoader(dataset, batch_size=len(dataset),
                             num_workers=num_workers, shuffle=shuffle)

    images, labels = next(iter(data_loader))
    return images.view(len(dataset), -1).to(device), labels.to(device)

train_images, train_labels = get_all_data(dataset, shuffle=True)
train_labels = to_onehot(batch_size=len(dataset), num_classes=num_classes, y=train_labels, device=device)

test_images, test_labels = get_all_data(test_dataset, shuffle=False)
test_labels = to_onehot(batch_size=len(test_dataset), num_classes=num_classes, y=test_labels, device=device)


#################
# Model
#################
elm = ELM(input_size=image_size, h_size=hidden_size, num_classes=num_classes, device=device)
elm.fit(train_images, train_labels)
accuracy = elm.evaluate(test_images, test_labels)

print('Accuracy: {}'.format(accuracy))
--------------------------------------------------------------------------------