├── pytorch_fitmodule
│   ├── __init__.py
│   ├── utils.py
│   └── fit_module.py
├── run_example.py
├── .gitignore
└── README.md

/pytorch_fitmodule/__init__.py:
--------------------------------------------------------------------------------
from .fit_module import FitModule
--------------------------------------------------------------------------------
/run_example.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_fitmodule import FitModule
from sklearn.datasets import make_multilabel_classification


SEED = 1701


def print_title(s):
    print("\n\n{0}\n{1}\n{0}".format("="*len(s), s))


##### Generate training set #####
print_title("Generating data set")

n_feats, n_classes = 200, 5
X, y = make_multilabel_classification(
    n_samples=10000, n_features=n_feats, n_classes=n_classes, n_labels=0.01,
    length=50, allow_unlabeled=False, sparse=False, return_indicator='dense',
    return_distributions=False, random_state=SEED
)
y = np.argmax(y, axis=1)
X = torch.from_numpy(X).float()
y = torch.from_numpy(y).long()


##### Define model #####
print_title("Building model")

class MLP(FitModule):
    def __init__(self, n_feats, n_classes, hidden_size=50):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(n_feats, hidden_size)
        self.fc2 = nn.Linear(hidden_size, n_classes)

    def forward(self, x):
        # dim=1: softmax over classes (left implicit in older PyTorch)
        return F.log_softmax(self.fc2(F.relu(self.fc1(x))), dim=1)

f = MLP(n_feats, n_classes)


##### Train model #####
print_title("Training model")

def accuracy(y_true, y_pred):
    return np.mean(y_true.numpy() == np.argmax(y_pred.numpy(), axis=1))

f.fit(
    X, y, epochs=10, validation_split=0.3, seed=SEED, metrics=[accuracy]
)
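

##### Evaluate model #####
# A short sketch that is not part of the original script: FitModule also
# exposes predict(), which returns a Tensor of per-class log-probabilities;
# the accuracy helper above converts them back to label predictions.
print_title("Evaluating model")

y_pred = f.predict(X)
print("Training accuracy: {0:.4f}".format(accuracy(y, y_pred)))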
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/pytorch_fitmodule/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import sys
import torch

from torch.utils.data import DataLoader, TensorDataset


##### Data utils #####

def get_loader(X, y=None, batch_size=1, shuffle=False):
    """Convert X and y Tensors to a DataLoader.

    If y is None, use a dummy Tensor (for prediction, where no targets exist).
    """
    if y is None:
        y = torch.Tensor(X.size()[0])
    return DataLoader(TensorDataset(X, y), batch_size, shuffle)


##### Logging #####

def add_metrics_to_log(log, metrics, y_true, y_pred, prefix=''):
    for metric in metrics:
        q = metric(y_true, y_pred)
        log[prefix + metric.__name__] = q
    return log


def log_to_message(log, precision=4):
    fmt = "{0}: {1:." + str(precision) + "f}"
    return " ".join(fmt.format(k, v) for k, v in log.items())


class ProgressBar(object):
    """Cheers @ajratner"""

    def __init__(self, n, length=40):
        # Protect against division by zero
        self.n = max(1, n)
        self.nf = float(self.n)
        self.length = length
        # Precalculate the i values that should trigger a write operation
        self.ticks = set([round(i/100.0 * n) for i in range(101)])
        self.ticks.add(n-1)
        self.bar(0)

    def bar(self, i, message=""):
        """Assumes i ranges through [0, n-1]"""
        if i in self.ticks:
            b = int(np.ceil(((i+1) / self.nf) * self.length))
            sys.stdout.write("\r[{0}{1}] {2}%\t{3}".format(
                "="*b, " "*(self.length-b), int(100*((i+1) / self.nf)), message
            ))
            sys.stdout.flush()

    def close(self, message=""):
        # Move the bar to 100% before closing
        self.bar(self.n-1)
        sys.stdout.write("{0}\n\n".format(message))
        sys.stdout.flush()
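

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: drive the bar
    # through all n steps with a per-step message, then close it.
    import time
    pb = ProgressBar(50)
    for i in range(50):
        time.sleep(0.02)
        pb.bar(i, "step {0}".format(i+1))
    pb.close("done")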
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# A super simple `fit` method for PyTorch `Module`s

Ever wanted a pretty, Keras-like `fit` method for your PyTorch `Module`s?
Here's one. It lacks some of Keras' advanced functionality, but it's easy to use:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_fitmodule import FitModule

X, Y, n_classes = torch.get_me_some_data()

class MLP(FitModule):
    def __init__(self, n_feats, n_classes, hidden_size=50):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(n_feats, hidden_size)
        self.fc2 = nn.Linear(hidden_size, n_classes)
    def forward(self, x):
        return F.log_softmax(self.fc2(F.relu(self.fc1(x))), dim=1)

f = MLP(X.size()[1], n_classes)

def n_correct(y_true, y_pred):
    return (y_true == torch.max(y_pred, 1)[1]).sum()

f.fit(X, Y, epochs=5, validation_split=0.3, metrics=[n_correct])
```
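
`fit` also takes a custom `loss` and `optimizer`. The `optimizer` argument is a
callable that receives the model's parameters, so a `functools.partial` over any
`torch.optim` class works. A minimal sketch, reusing `f`, `X`, and `Y` from above
(note that `NLLLoss` is the exact pairing for a model that returns `log_softmax`
outputs):

```python
from functools import partial
from torch.nn import NLLLoss
from torch.optim import Adam

f.fit(
    X, Y, epochs=5,
    loss=NLLLoss(),                    # matches the model's log-probabilities
    optimizer=partial(Adam, lr=1e-3),  # invoked as optimizer(f.parameters())
)
```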

## Installation

Just clone this repo and add it to your Python path. You'll need
* [PyTorch](http://pytorch.org)
* [NumPy](http://numpy.org/)
* [Scikit-Learn](http://scikit-learn.org/) (just for the example)

all of which are available via [Anaconda](https://www.continuum.io/downloads).
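
For example, a minimal sketch (the clone URL below is a placeholder; substitute
wherever you obtained this repo):

```bash
git clone https://github.com/<user>/pytorch-fitmodule.git
export PYTHONPATH="$PWD/pytorch-fitmodule:$PYTHONPATH"
```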

## Example

Try out a simple example with the included script:

```bash
python run_example.py
```

```
Epoch 1 / 10
[========================================] 100% loss: 1.3285 accuracy: 0.5676 val_loss: 1.0450 val_accuracy: 0.5693

Epoch 2 / 10
[========================================] 100% loss: 0.8004 accuracy: 0.8900 val_loss: 0.5804 val_accuracy: 0.8900

Epoch 3 / 10
[========================================] 100% loss: 0.4638 accuracy: 0.8981 val_loss: 0.3845 val_accuracy: 0.8983

Epoch 4 / 10
[========================================] 100% loss: 0.3357 accuracy: 0.9033 val_loss: 0.2998 val_accuracy: 0.9043

Epoch 5 / 10
[========================================] 100% loss: 0.2684 accuracy: 0.9196 val_loss: 0.2462 val_accuracy: 0.9213

Epoch 6 / 10
[========================================] 100% loss: 0.2215 accuracy: 0.9374 val_loss: 0.2061 val_accuracy: 0.9423

Epoch 7 / 10
[========================================] 100% loss: 0.1841 accuracy: 0.9586 val_loss: 0.1738 val_accuracy: 0.9590

Epoch 8 / 10
[========================================] 100% loss: 0.1543 accuracy: 0.9704 val_loss: 0.1478 val_accuracy: 0.9673

Epoch 9 / 10
[========================================] 100% loss: 0.1298 accuracy: 0.9806 val_loss: 0.1266 val_accuracy: 0.9747

Epoch 10 / 10
[========================================] 100% loss: 0.1099 accuracy: 0.9861 val_loss: 0.1094 val_accuracy: 0.9800
```
--------------------------------------------------------------------------------
/pytorch_fitmodule/fit_module.py:
--------------------------------------------------------------------------------
import torch

from collections import OrderedDict
from functools import partial
from torch.autograd import Variable
from torch.nn import CrossEntropyLoss, Module
from torch.optim import SGD

from .utils import add_metrics_to_log, get_loader, log_to_message, ProgressBar


DEFAULT_LOSS = CrossEntropyLoss()
DEFAULT_OPTIMIZER = partial(SGD, lr=0.001, momentum=0.9)


class FitModule(Module):

    def fit(self,
            X,
            y,
            batch_size=32,
            epochs=1,
            verbose=1,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            initial_epoch=0,
            seed=None,
            loss=DEFAULT_LOSS,
            optimizer=DEFAULT_OPTIMIZER,
            metrics=None):
        """Trains the model, similar to Keras' .fit(...) method.

        # Arguments
            X: training data Tensor.
            y: target data Tensor.
            batch_size: integer. Number of samples per gradient update.
            epochs: integer, the number of times to iterate
                over the training data arrays.
            verbose: 0, 1. Verbosity mode.
                0 = silent, 1 = verbose.
            validation_split: float between 0 and 1:
                fraction of the training data to be used as validation data.
                The model will set apart this fraction of the training data,
                will not train on it, and will evaluate the loss and any
                model metrics on this data at the end of each epoch.
            validation_data: (x_val, y_val) tuple on which to evaluate
                the loss and any model metrics at the end of each epoch.
                The model will not be trained on this data.
            shuffle: boolean, whether to shuffle the training data
                before each epoch.
            initial_epoch: epoch at which to start training
                (useful for resuming a previous training run).
            seed: random seed.
            loss: training loss.
            optimizer: training optimizer.
            metrics: list of functions with signatures `metric(y_true, y_pred)`
                where y_true and y_pred are both Tensors.

        # Returns
            list of OrderedDicts with training metrics
        """
        if seed is not None and seed >= 0:
            torch.manual_seed(seed)
        # Prepare validation data
        if validation_data:
            X_val, y_val = validation_data
        elif validation_split and 0. < validation_split < 1.:
            split = int(X.size()[0] * (1. - validation_split))
            X, X_val = X[:split], X[split:]
            y, y_val = y[:split], y[split:]
        else:
            X_val, y_val = None, None
        # Build DataLoaders
        train_data = get_loader(X, y, batch_size, shuffle)
        # Compile optimizer
        opt = optimizer(self.parameters())
        # Run training loop
        logs = []
        for t in range(initial_epoch, epochs):
            if verbose:
                print("Epoch {0} / {1}".format(t+1, epochs))
                # Setup logger
                pb = ProgressBar(len(train_data))
            log = OrderedDict()
            epoch_loss = 0.0
            # Re-enter train mode each epoch; predict() below sets eval mode
            self.train()
            # Run batches
            for batch_i, batch_data in enumerate(train_data):
                # Get batch data
                X_batch = Variable(batch_data[0])
                y_batch = Variable(batch_data[1])
                # Backprop
                opt.zero_grad()
                y_batch_pred = self(X_batch)
                batch_loss = loss(y_batch_pred, y_batch)
                batch_loss.backward()
                opt.step()
                # Update status
                epoch_loss += batch_loss.data[0]
                log['loss'] = float(epoch_loss) / (batch_i + 1)
                if verbose:
                    pb.bar(batch_i, log_to_message(log))
            # Run metrics
            if metrics:
                y_train_pred = self.predict(X, batch_size)
                add_metrics_to_log(log, metrics, y, y_train_pred)
            if X_val is not None and y_val is not None:
                y_val_pred = self.predict(X_val, batch_size)
                val_loss = loss(Variable(y_val_pred), Variable(y_val))
                log['val_loss'] = val_loss.data[0]
                if metrics:
                    add_metrics_to_log(log, metrics, y_val, y_val_pred, 'val_')
            logs.append(log)
            if verbose:
                pb.close(log_to_message(log))
        return logs

    def predict(self, X, batch_size=32):
        """Generates output predictions for the input samples.

        Computation is done in batches.

        # Arguments
            X: input data Tensor.
            batch_size: integer.

        # Returns
            prediction Tensor.
        """
        # Build DataLoader
        data = get_loader(X, batch_size=batch_size)
        # Batch prediction
        self.eval()
        r, n = 0, X.size()[0]
        for batch_data in data:
            # Predict on batch
            X_batch = Variable(batch_data[0])
            y_batch_pred = self(X_batch).data
            # Infer prediction shape from the first batch
            if r == 0:
                y_pred = torch.zeros((n,) + y_batch_pred.size()[1:])
            # Add to prediction tensor
            y_pred[r : min(n, r + batch_size)] = y_batch_pred
            r += batch_size
        return y_pred
--------------------------------------------------------------------------------