├── .gitignore
├── datapoint_generator.py
├── img
│   └── example.png
├── main.py
├── my_data.py
├── readme.md
├── softmax_regression.py
└── trained_models
    └── model.pth

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
.idea/

--------------------------------------------------------------------------------
/datapoint_generator.py:
--------------------------------------------------------------------------------
from torch.distributions.multivariate_normal import MultivariateNormal
import torch
import matplotlib.pyplot as plt


class DataPoint2DGenerator:
    """Generate 2-D points for each class from a Gaussian with a shared covariance."""

    def __init__(self, means, cov, n_point_per_class=100):
        self.means = means
        self.cov = cov
        self.n_point_per_class = n_point_per_class
        self.n_class = len(means)
        self.total_points = self.n_class * self.n_point_per_class
        self.data = torch.empty(self.total_points, 2)
        self.labels = torch.empty(self.total_points, dtype=torch.long)

    def generate(self):
        point_idx = 0
        for i in range(self.n_class):
            torch_mean = torch.tensor(self.means[i])
            torch_cov = torch.tensor(self.cov)
            distribution = MultivariateNormal(torch_mean, torch_cov)

            # draw n_point_per_class samples around the i-th mean, all labelled i
            for _ in range(self.n_point_per_class):
                self.labels[point_idx] = i
                self.data[point_idx] = distribution.sample()
                point_idx += 1

        return self.data, self.labels

    def display(self):
        plt.scatter(self.data[:, 0], self.data[:, 1])

--------------------------------------------------------------------------------
/img/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/trqminh/softmax-regression/860b44a0a5d567f8300ae38af282898ce0cc00af/img/example.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from datapoint_generator import DataPoint2DGenerator
from softmax_regression import SoftmaxRegression
from my_data import MyDataset
import matplotlib.pyplot as plt
import torch
import numpy as np
import argparse

if __name__ == "__main__":
    np.random.seed(0)
    torch.manual_seed(0)
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', type=int, default=0, help='1 to train a new model, 0 to visualize the saved one')
    args = parser.parse_args()

    if bool(args.train):
        print("Training...")
        means = [[2., 2.], [-2., -2.], [-5., 6.]]
        cov = [[1., 0.], [0., 1.]]

        data_generator = DataPoint2DGenerator(means, cov)
        data = data_generator.generate()
        data_generator.display()

        dataset = MyDataset(data[0], data[1])
        soft_reg = SoftmaxRegression(dataset, data_generator.n_class)
        soft_reg.train()

        # decision boundaries and accuracy on the training set
        soft_reg.visualize()
        print('Accuracy on train set: ', soft_reg.accuracy_on_train_set())

        # show all plots
        plt.show()
    else:
        print("Visualization: ")
        means = [[2., 2.], [-2., -2.], [-5., 6.]]
        cov = [[1., 0.], [0., 1.]]

        data_generator = DataPoint2DGenerator(means, cov)
        data = data_generator.generate()
        data_generator.display()

        dataset = MyDataset(data[0], data[1])
        soft_reg = SoftmaxRegression(dataset, data_generator.n_class)

        # decision boundaries from the saved model
        soft_reg.visualize(True)

        # show all plots
        plt.show()
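
# Note: the default branch (--train 0) does not train anything; it regenerates the same data
# (the seeds above are fixed) and reloads trained_models/model.pth, which ships with the
# repository, to plot its decision boundaries.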
--------------------------------------------------------------------------------
/my_data.py:
--------------------------------------------------------------------------------
from torch.utils.data.dataset import Dataset


class MyDataset(Dataset):
    """A thin Dataset wrapper around pre-generated input and label tensors."""

    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        x = self.inputs[index]
        y = self.labels[index]

        return x, y

--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# softmax-regression
In this repository, I wrote some code to understand softmax regression. It was also a chance to get familiar with PyTorch: the model is trained with a standard PyTorch training loop and uses `Dataset` and `DataLoader`.

The source code is intended to be easy to read. The class means (and therefore the number of classes) are set in `main.py`; training hyper-parameters such as the batch size, learning rate, and number of epochs live in `softmax_regression.py`. Note that the decision boundaries drawn after training are only valid for 2-dimensional data. Run `python main.py --train 1` to generate data, train, and save a model, or `python main.py` to draw the boundaries of the saved `trained_models/model.pth`.

Here is an example of classifying 3 classes:
![](img/example.png)

## References
- [Softmax Regression](https://machinelearningcoban.com/2017/02/17/softmax/)
- [PyTorch tutorials](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)

--------------------------------------------------------------------------------
/softmax_regression.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader
import numpy as np
import matplotlib.pyplot as plt


class SoftmaxRegression:
    def __init__(self, dataset, n_class):
        self.data_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)
        self.n_class = n_class
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.weights = nn.Linear(dataset[0][0].shape[0], self.n_class, bias=True)
        # Note: nn.CrossEntropyLoss already applies log-softmax internally, so keeping an
        # explicit Softmax in the model squashes the scores twice; it still converges on
        # this toy data, but passing raw logits is the standard setup.
        self.model = nn.Sequential(self.weights, nn.Softmax(dim=1)).to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        self.model_path = 'trained_models/model.pth'

    def train(self):
        loss = None
        for epoch in range(1000):
            for inputs, labels in self.data_loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                self.optimizer.zero_grad()
                outputs = self.model(inputs)

                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

            if epoch % 100 == 0:
                print('Loss at {} epoch: {}'.format(epoch, loss.item()))

        print('Loss at last epoch: ', loss.item())
        print('Saving the model: ')
        torch.save(self.model.state_dict(), self.model_path)

    def accuracy_on_train_set(self):
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in self.data_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                _, predict = torch.max(outputs, 1)

                correct += (predict == labels).sum().item()
                total += labels.shape[0]

        return correct / total
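
    # How visualize() draws the class regions below: the Linear layer assigns each class i a
    # score s_i(x) = w_i . x + b_i.  For every class i it plots the line where
    # 2 * s_i(x) = sum_{k != i} s_k(x); with the default 3 classes this is the set of points
    # where class i's score equals the average of the other two scores.  draw() then plots
    # that line, w[0] * x + w[1] * y + b = 0, over x in [-8, 5].  This sketches the class
    # regions for 2-D data; it is not the exact pairwise boundary s_i(x) = s_j(x).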

    @staticmethod
    def draw(w, b):
        # plot the line w[0] * x + w[1] * y + b = 0
        x = np.linspace(-8, 5, 100)
        y = (w[0] * x + b) / -w[1]

        plt.plot(x, y)

    def visualize(self, load_from_trained_model=False):
        if load_from_trained_model:
            self.model.load_state_dict(torch.load(self.model_path, map_location='cpu'))

        with torch.no_grad():
            weights, biases = None, None
            # the model's parameters are the Linear layer's weight matrix followed by its bias
            for i, p in enumerate(self.model.parameters()):
                if i == 0:
                    weights = p.cpu().numpy()
                else:
                    biases = p.cpu().numpy()

            for i in range(len(weights)):
                # build the coefficients of 2 * s_i(x) - sum_{k != i} s_k(x) = 0
                w = np.zeros(len(weights[i]))
                b = 2 * biases[i]
                for j in range(len(weights[i])):
                    w[j] = 2 * weights[i][j]
                    for k in range(len(weights)):
                        if k != i:
                            w[j] -= weights[k][j]
                for kb in range(len(weights)):
                    if kb != i:
                        b -= biases[kb]

                self.draw(w, b)

--------------------------------------------------------------------------------
/trained_models/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/trqminh/softmax-regression/860b44a0a5d567f8300ae38af282898ce0cc00af/trained_models/model.pth
--------------------------------------------------------------------------------
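
For completeness, here is a minimal sketch (not a file in the repository) of how the saved checkpoint could be reloaded for inference outside of `main.py`, assuming the default 2-feature, 3-class setup used in `main.py`:

```python
import torch
import torch.nn as nn

# Rebuild the architecture that SoftmaxRegression saves: a Linear layer followed by Softmax.
model = nn.Sequential(nn.Linear(2, 3), nn.Softmax(dim=1))
model.load_state_dict(torch.load('trained_models/model.pth', map_location='cpu'))
model.eval()

with torch.no_grad():
    point = torch.tensor([[2.0, 2.0]])  # a point near the first class mean in main.py
    probs = model(point)
    print(probs, 'predicted class:', probs.argmax(dim=1).item())
```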