├── .gitignore ├── README.md ├── assets ├── speed_comp.mov ├── training_log_pytorch.png └── training_log_tf_eager.png ├── pytorch ├── __init__.py ├── app.py ├── model ├── model.py ├── test_2.png ├── test_4.png ├── test_6.png └── train.py ├── requirements.txt └── tensorflow_eager ├── __init__.py ├── app.py ├── checkpoint ├── ckpt-9380.data-00000-of-00001 ├── ckpt-9380.index ├── model.py ├── test_2.png ├── test_4.png ├── test_6.png └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # Others 104 | .DS_Store 105 | data/ 106 | *.swp 107 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MNIST Interactive Examples in PyTorch & TensorFlow Eager Mode 2 | 3 | Modified from PyTorch MNIST official example. Recreated with TensorFlow under eager execution mode. With detailed comments & interactive interface. 
4 | 5 | [Deep Learning Beginner's Village: Getting Started with PyTorch (in Chinese)](https://pyliaorachel.github.io/blog/tech/deeplearning/2017/10/16/getting-started-with-deep-learning-with-pytorch.html) 6 | 7 | ## Structure 8 | 9 | ``` 10 | pytorch/ 11 | train.py # Train the model 12 | model.py # The defined model 13 | app.py # Interactive predictor 14 | model # Pretrained model, will be overridden when you start training 15 | test_n.png # Sample images for the use of interactive predictor 16 | 17 | tensorflow_eager/ 18 | train.py # Train the model 19 | model.py # The defined model 20 | app.py # Interactive predictor 21 | checkpoint, ckpt-* # Pretrained model, the number after prefix is the final training step 22 | test_n.png # Sample images for the use of interactive predictor 23 | ``` 24 | 25 | ## Usage 26 | 27 | ```bash 28 | # clone project 29 | $ git clone https://github.com/pyliaorachel/MNIST-pytorch-tensorflow-eager-interactive.git 30 | $ cd MNIST-pytorch-tensorflow-eager-interactive 31 | 32 | # install dependencies 33 | $ pip3 install -r requirements.txt 34 | 35 | # train & test model 36 | $ python3 -m pytorch.train 37 | # ...data will be fetched to ../data/ 38 | # ...trained model will be saved to ./pytorch/model 39 | # or 40 | $ python3 -m tensorflow_eager.train 41 | # ...data will be cached by Keras (by default under ~/.keras/datasets/) 42 | # ...trained model will be saved to ./tensorflow_eager/checkpoint & ./tensorflow_eager/ckpt-* 43 | 44 | # test model interactively 45 | $ python3 -m pytorch.app --image= 46 | # or 47 | $ python3 -m tensorflow_eager.app --image= 48 | ``` 49 | 50 | ## Experiments 51 | 52 | ###### Machine Settings 53 | 54 | |OS|CPU|Memory| 55 | |:-:|:-:|:-:| 56 | |MacOS 10.12.4|2 GHz Intel Core i5|16 GB 1867 MHz LPDDR3| 57 | 58 | ###### Results 59 | 60 | ||TensorFlow Eager|PyTorch| 61 | |:-:|:-|:-| 62 | |Time|`real 6m4.446s`<br>`user 13m42.909s`<br>`sys 1m54.327s`|`real 3m59.340s`<br>`user 3m28.285s`<br>`sys 0m57.395s`| 63 | |Avg. Loss (Test)|0.0610|0.0473| 64 | |Accuracy| 9856/10000 (99%) | 9845/10000 (98%) | 65 | 66 | Avg. Loss and Accuracy are expected to be more or less the same. PyTorch takes roughly two-thirds of TensorFlow's wall-clock time on CPU (`real` 3m59s vs. 6m4s), while the code complexity is the same. 67 | -------------------------------------------------------------------------------- /assets/speed_comp.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyliaorachel/MNIST-pytorch-tensorflow-eager-interactive/1979746fa6a2c25cb3ff8874e7cbc2a28d279125/assets/speed_comp.mov -------------------------------------------------------------------------------- /assets/training_log_pytorch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyliaorachel/MNIST-pytorch-tensorflow-eager-interactive/1979746fa6a2c25cb3ff8874e7cbc2a28d279125/assets/training_log_pytorch.png -------------------------------------------------------------------------------- /assets/training_log_tf_eager.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyliaorachel/MNIST-pytorch-tensorflow-eager-interactive/1979746fa6a2c25cb3ff8874e7cbc2a28d279125/assets/training_log_tf_eager.png -------------------------------------------------------------------------------- /pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyliaorachel/MNIST-pytorch-tensorflow-eager-interactive/1979746fa6a2c25cb3ff8874e7cbc2a28d279125/pytorch/__init__.py -------------------------------------------------------------------------------- /pytorch/app.py: -------------------------------------------------------------------------------- 1 | # 2 | # Handwritten number predictor 3 | # 4 | 5 | import os 6 | import argparse 7 | from PIL import Image 8 | import torch 9 | from torchvision 
import transforms 10 | 11 | from .model import Net 12 | 13 | 14 | """Settings""" 15 | 16 | package_dir = os.path.dirname(os.path.abspath(__file__)) 17 | default_img_path = os.path.join(package_dir,'test_2.png') 18 | 19 | parser = argparse.ArgumentParser(description='PyTorch MNIST Predictor') 20 | parser.add_argument('--image', type=str, default=default_img_path, metavar='IMG', 21 | help='image for prediction (default: {})'.format(default_img_path)) 22 | args = parser.parse_args() 23 | 24 | 25 | """Make Prediction""" 26 | 27 | # Load model 28 | model_path = os.path.join(package_dir,'model') 29 | model = Net() 30 | model.load_state_dict(torch.load(model_path)) 31 | 32 | # Load & transform image 33 | ori_img = Image.open(args.image).convert('L') 34 | t = transforms.Compose([ 35 | transforms.Resize((28, 28)), 36 | transforms.ToTensor(), 37 | transforms.Normalize((0.1307,), (0.3081,)) 38 | ]) 39 | img = torch.autograd.Variable(t(ori_img).unsqueeze(0)) 40 | ori_img.close() 41 | 42 | # Predict 43 | model.eval() 44 | output = model(img) 45 | pred = output.data.max(1, keepdim=True)[1][0][0] 46 | print('Prediction: {}'.format(pred)) 47 | -------------------------------------------------------------------------------- /pytorch/model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyliaorachel/MNIST-pytorch-tensorflow-eager-interactive/1979746fa6a2c25cb3ff8874e7cbc2a28d279125/pytorch/model -------------------------------------------------------------------------------- /pytorch/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torchvision import datasets, transforms # torchvision contains common utilities for computer vision 6 | 7 | 8 | class Net(nn.Module): # Inherit from `nn.Module`, define `__init__` & `forward` 9 | def __init__(self): 10 | # Always 
call the init function of the parent class `nn.Module` 11 | # so that magics can be set up. 12 | super(Net, self).__init__() 13 | 14 | # Define the parameters in your network. 15 | # This is achieved by defining the shapes of the multiple layers in the network. 16 | 17 | # Define two 2D convolutional layers (1 x 10, 10 x 20 each) 18 | # with convolution kernel of size (5 x 5). 19 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 20 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 21 | 22 | # Define a dropout layer 23 | self.conv2_drop = nn.Dropout2d() 24 | 25 | # Define a fully-connected layer (320 x 10) 26 | self.fc = nn.Linear(320, 10) 27 | 28 | def forward(self, x): 29 | # Define the network architecture. 30 | # This is achieved by defining how the network forward propagates your inputs 31 | 32 | # Input image size: 28 x 28, input channel: 1, batch size (training): 64 33 | 34 | # Input (64 x 1 x 28 x 28) -> Conv1 (64 x 10 x 24 x 24) -> Max Pooling (64 x 10 x 12 x 12) -> ReLU -> ... 35 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 36 | 37 | # ... -> Conv2 (64 x 20 x 8 x 8) -> Dropout -> Max Pooling (64 x 20 x 4 x 4) -> ReLU -> ... 38 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 39 | 40 | # ... -> Flatten (64 x 320) -> ... 41 | x = x.view(-1, 320) 42 | 43 | # ... -> FC (64 x 10) -> ... 44 | x = self.fc(x) 45 | 46 | # ... 
-> Log Softmax -> Output 47 | return F.log_softmax(x, dim=1) 48 | -------------------------------------------------------------------------------- /pytorch/test_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyliaorachel/MNIST-pytorch-tensorflow-eager-interactive/1979746fa6a2c25cb3ff8874e7cbc2a28d279125/pytorch/test_2.png -------------------------------------------------------------------------------- /pytorch/test_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyliaorachel/MNIST-pytorch-tensorflow-eager-interactive/1979746fa6a2c25cb3ff8874e7cbc2a28d279125/pytorch/test_4.png -------------------------------------------------------------------------------- /pytorch/test_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyliaorachel/MNIST-pytorch-tensorflow-eager-interactive/1979746fa6a2c25cb3ff8874e7cbc2a28d279125/pytorch/test_6.png -------------------------------------------------------------------------------- /pytorch/train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Modified from PyTorch examples: 3 | # https://github.com/pytorch/examples/blob/master/mnist/main.py 4 | # 5 | 6 | import os 7 | import argparse 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | from torchvision import datasets, transforms # torchvision contains common utilities for computer vision 13 | from torch.autograd import Variable 14 | 15 | from .model import Net 16 | 17 | 18 | def load_data(train_batch_size, test_batch_size): 19 | """Fetch MNIST dataset 20 | 21 | MNIST dataset has built-in utilities set up in the `torchvision` package, so we just use the `torchvision.datasets.MNIST` module 
(http://pytorch.org/docs/master/torchvision/datasets.html#torchvision.datasets.MNIST) to make our lives easier. 22 | """ 23 | 24 | kwargs = {} 25 | 26 | # Fetch training data 27 | train_loader = torch.utils.data.DataLoader( 28 | datasets.MNIST('../data', train=True, download=True, 29 | transform=transforms.Compose([ 30 | transforms.ToTensor(), 31 | transforms.Normalize((0.1307,), (0.3081,)) 32 | ])), 33 | batch_size=train_batch_size, shuffle=True, **kwargs) 34 | 35 | # Fetch test data 36 | test_loader = torch.utils.data.DataLoader( 37 | datasets.MNIST('../data', train=False, transform=transforms.Compose([ 38 | transforms.ToTensor(), 39 | transforms.Normalize((0.1307,), (0.3081,)) 40 | ])), 41 | batch_size=test_batch_size, shuffle=True, **kwargs) 42 | 43 | return (train_loader, test_loader) 44 | 45 | def train(model, optimizer, epoch, train_loader, log_interval): 46 | # State that you are training the model 47 | model.train() 48 | 49 | # Iterate over batches of data 50 | for batch_idx, (data, target) in enumerate(train_loader): 51 | # Wrap the input and target output in the `Variable` wrapper 52 | data, target = Variable(data), Variable(target) 53 | 54 | # Clear the gradients, since PyTorch accumulates them 55 | optimizer.zero_grad() 56 | 57 | # Forward propagation 58 | output = model(data) 59 | 60 | # Calculate negative log likelihood loss 61 | loss = F.nll_loss(output, target) 62 | 63 | # Backward propagation 64 | loss.backward() 65 | 66 | # Update the parameters using the computed gradients 67 | optimizer.step() 68 | 69 | # Output debug message 70 | if batch_idx % log_interval == 0: 71 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 72 | epoch, batch_idx * len(data), len(train_loader.dataset), 73 | 100. * batch_idx / len(train_loader), loss.data.item())) 74 | 75 | def test(model, test_loader): 76 | # State that you are testing the model; this disables training-only behavior of layers such as Dropout 77 | model.eval() 78 | 79 | # Init loss & correct prediction accumulators 80 | test_loss = 0 81 | correct = 0 82 | 83 | # Disable gradient tracking during evaluation with `torch.no_grad()` 84 | with torch.no_grad(): 85 | # Iterate over data 86 | for data, target in test_loader: # Under `torch.no_grad()`, no need to wrap data & target in `Variable` 87 | # Retrieve output 88 | output = model(data) 89 | 90 | # Calculate & accumulate loss 91 | test_loss += F.nll_loss(output, target, reduction='sum').data.item() 92 | 93 | # Get the index of the max log-probability (the predicted output label) 94 | pred = output.data.argmax(1) 95 | 96 | # Count correct predictions; `.item()` converts the 0-dim tensor to a Python int 97 | correct += pred.eq(target.data).sum().item() 98 | 99 | # Print out average test loss & accuracy 100 | test_loss /= len(test_loader.dataset) 101 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 102 | test_loss, correct, len(test_loader.dataset), 103 | 100. * correct / len(test_loader.dataset))) 104 | 105 | 106 | if __name__ == '__main__': 107 | # Set up training settings from command line options, or use default 108 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 109 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 110 | help='input batch size for training (default: 64)') 111 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 112 | help='input batch size for testing (default: 1000)') 113 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 114 | help='number of epochs to train (default: 10)') 115 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 116 | help='learning rate (default: 0.01)') 117 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 118 | help='SGD momentum (default: 0.5)') 119 | parser.add_argument('--seed', type=int, default=1, metavar='S', 120 | help='random seed (default: 1)') 121 | 
parser.add_argument('--log-interval', type=int, default=10, metavar='N', 122 | help='how many batches to wait before logging training status') 123 | args = parser.parse_args() 124 | 125 | # Provide seed for the pseudorandom number generator s.t. the same results can be reproduced 126 | torch.manual_seed(args.seed) 127 | 128 | 129 | # Instantiate the model 130 | model = Net() 131 | 132 | # Choose SGD as the optimizer, initialize it with the parameters & settings 133 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 134 | 135 | # Load data 136 | train_loader, test_loader = load_data(args.batch_size, args.test_batch_size) 137 | 138 | # Train & test the model 139 | for epoch in range(1, args.epochs + 1): 140 | train(model, optimizer, epoch, train_loader, log_interval=args.log_interval) 141 | test(model, test_loader) 142 | 143 | 144 | # Save the model for future use 145 | package_dir = os.path.dirname(os.path.abspath(__file__)) 146 | model_path = os.path.join(package_dir,'model') 147 | torch.save(model.state_dict(), model_path) 148 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.13.3 2 | olefile==0.44 3 | Pillow==4.3.0 4 | PyYAML==3.12 5 | six==1.11.0 6 | torch==0.2.0.post3 7 | torchvision==0.1.9 8 | tensorflow==1.9.0 9 | -------------------------------------------------------------------------------- /tensorflow_eager/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyliaorachel/MNIST-pytorch-tensorflow-eager-interactive/1979746fa6a2c25cb3ff8874e7cbc2a28d279125/tensorflow_eager/__init__.py -------------------------------------------------------------------------------- /tensorflow_eager/app.py: -------------------------------------------------------------------------------- 1 | # 2 | # Handwritten number predictor 3 
| # 4 | 5 | import os 6 | import argparse 7 | from PIL import Image 8 | 9 | import tensorflow as tf 10 | import tensorflow.contrib.eager as tfe 11 | tf.enable_eager_execution() 12 | 13 | from .model import Net 14 | 15 | 16 | """Settings""" 17 | 18 | package_dir = os.path.dirname(os.path.abspath(__file__)) 19 | default_img_path = os.path.join(package_dir,'test_2.png') 20 | 21 | parser = argparse.ArgumentParser(description='TensorFlow Eager MNIST Predictor') 22 | parser.add_argument('--image', type=str, default=default_img_path, metavar='IMG', 23 | help='image for prediction (default: {})'.format(default_img_path)) 24 | args = parser.parse_args() 25 | 26 | 27 | """Make Prediction""" 28 | 29 | # Load & transform image 30 | img = tf.image.decode_png(tf.read_file(args.image), channels=1) 31 | img = tf.image.resize_images(img, (28, 28)) 32 | img = ((img / 255) - 0.1307) / 0.3081 # Normalize 33 | img = tf.expand_dims(img, 0) # Insert a leading batch dimension 34 | 35 | # Create model 36 | model = Net() 37 | 38 | # Load parameters; they will only be restored on the first run of the model, when its variables are lazily created 39 | checkpoint_dir = os.path.dirname(os.path.abspath(__file__)) 40 | with tfe.restore_variables_on_create(tf.train.latest_checkpoint(checkpoint_dir)): 41 | global_step = tf.train.get_or_create_global_step() 42 | 43 | # Predict 44 | output = model(img, training=False) 45 | pred = tf.argmax(output, 1) 46 | print('Prediction: {}'.format(pred.numpy()[0])) 47 | -------------------------------------------------------------------------------- /tensorflow_eager/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "/Users/liaopeiyu/Workspace/temp/pytorch-mnist-interactive/tensorflow_eager/ckpt-9380" 2 | all_model_checkpoint_paths: "/Users/liaopeiyu/Workspace/temp/pytorch-mnist-interactive/tensorflow_eager/ckpt-9380" 3 | 
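Note on the `checkpoint` state file above: it records absolute paths from the machine the model was trained on, so `tf.train.latest_checkpoint` may not find `ckpt-9380` after the repository is cloned elsewhere. The sketch below is one possible workaround, not part of the repository. It rewrites the entries to bare file names using plain text processing. The helper name `relativize_checkpoint_state` is hypothetical, and the sketch assumes that TensorFlow resolves non-absolute entries relative to the checkpoint directory, which should be verified against the installed version.

```python
import posixpath
import re


def relativize_checkpoint_state(text):
    """Replace each quoted path in a `checkpoint` state file with its file name.

    Hypothetical helper for illustration; it only edits text and never calls TensorFlow.
    """
    def _strip_dir(match):
        key, path = match.group(1), match.group(2)
        return '{}: "{}"'.format(key, posixpath.basename(path))

    # Each entry has the form `key: "path"`, one per line
    return re.sub(r'^(\w+): "([^"]*)"', _strip_dir, text, flags=re.MULTILINE)


# The two entries exactly as they appear in the state file above
state = (
    'model_checkpoint_path: "/Users/liaopeiyu/Workspace/temp/pytorch-mnist-interactive/tensorflow_eager/ckpt-9380"\n'
    'all_model_checkpoint_paths: "/Users/liaopeiyu/Workspace/temp/pytorch-mnist-interactive/tensorflow_eager/ckpt-9380"\n'
)
portable_state = relativize_checkpoint_state(state)
print(portable_state)
```

To apply it, one would read `tensorflow_eager/checkpoint`, pass its contents through the helper, and write the result back before running `python3 -m tensorflow_eager.app`.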
-------------------------------------------------------------------------------- /tensorflow_eager/ckpt-9380.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyliaorachel/MNIST-pytorch-tensorflow-eager-interactive/1979746fa6a2c25cb3ff8874e7cbc2a28d279125/tensorflow_eager/ckpt-9380.data-00000-of-00001 -------------------------------------------------------------------------------- /tensorflow_eager/ckpt-9380.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyliaorachel/MNIST-pytorch-tensorflow-eager-interactive/1979746fa6a2c25cb3ff8874e7cbc2a28d279125/tensorflow_eager/ckpt-9380.index -------------------------------------------------------------------------------- /tensorflow_eager/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras, nn 3 | 4 | 5 | class Net(keras.Model): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | self.conv1 = keras.layers.Conv2D(10, kernel_size=5) 10 | self.conv2 = keras.layers.Conv2D(20, kernel_size=5) 11 | self.dense = keras.layers.Dense(10, activation='softmax') 12 | 13 | self.dropout = keras.layers.Dropout(0.5) 14 | self.max_pool = keras.layers.MaxPooling2D(2) 15 | 16 | def call(self, x, training=True): 17 | y = nn.relu(self.max_pool(self.conv1(x))) 18 | y = nn.relu(self.max_pool(self.dropout(self.conv2(y), training=training))) 19 | 20 | # Flatten feature matrix 21 | batch_size = y.shape[0] 22 | y = tf.reshape(y, (batch_size, -1)) 23 | 24 | y = self.dense(y) 25 | return y 26 | -------------------------------------------------------------------------------- /tensorflow_eager/test_2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pyliaorachel/MNIST-pytorch-tensorflow-eager-interactive/1979746fa6a2c25cb3ff8874e7cbc2a28d279125/tensorflow_eager/test_2.png -------------------------------------------------------------------------------- /tensorflow_eager/test_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyliaorachel/MNIST-pytorch-tensorflow-eager-interactive/1979746fa6a2c25cb3ff8874e7cbc2a28d279125/tensorflow_eager/test_4.png -------------------------------------------------------------------------------- /tensorflow_eager/test_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyliaorachel/MNIST-pytorch-tensorflow-eager-interactive/1979746fa6a2c25cb3ff8874e7cbc2a28d279125/tensorflow_eager/test_6.png -------------------------------------------------------------------------------- /tensorflow_eager/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import tensorflow as tf 5 | import tensorflow.contrib.eager as tfe 6 | tf.enable_eager_execution() 7 | 8 | from .model import Net 9 | 10 | 11 | def load_data(train_batch_size, test_batch_size): 12 | # Load data 13 | mnist = tf.keras.datasets.mnist 14 | 15 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 16 | x_train, x_test = ((x_train / 255.0) - 0.1307) / 0.3081, ((x_test / 255.0 - 0.1307)) / 0.3081 # Normalize 17 | x_train = tf.expand_dims(x_train, -1) # Append channel dim 18 | x_test = tf.expand_dims(x_test, -1) 19 | 20 | train_size = x_train.shape[0].value 21 | test_size = x_test.shape[0].value 22 | 23 | # Wrap in tf dataset; type casting so that tf.equal can work in cal_acc 24 | dataset_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))\ 25 | .map(lambda x, y: (x, tf.cast(y, tf.int64)))\ 26 | .shuffle(1000)\ 27 | .batch(train_batch_size) 28 | dataset_test = 
tf.data.Dataset.from_tensor_slices((x_test, y_test))\ 29 | .map(lambda x, y: (x, tf.cast(y, tf.int64)))\ 30 | .shuffle(1000)\ 31 | .batch(test_batch_size) 32 | 33 | return (dataset_train, dataset_test, train_size, test_size) 34 | 35 | def cal_loss(model, data, target, training=True): 36 | # Forward propagation 37 | output = model(data, training=training) 38 | loss = tf.keras.losses.sparse_categorical_crossentropy(target, output) # Argument order is (y_true, y_pred), as specified in the Keras docs 39 | return (loss, output) 40 | 41 | def train(model, optimizer, epoch, train_loader, batch_size, train_size, log_interval): 42 | # Iterate over batches of data 43 | for batch_idx, (data, target) in enumerate(train_loader): 44 | # Calculate loss & gradient 45 | with tf.GradientTape() as tape: 46 | loss, output = cal_loss(model, data, target) 47 | loss_value = tf.reduce_mean(loss) 48 | gradients = tape.gradient(loss_value, model.variables) 49 | grads_and_vars = zip(gradients, model.variables) 50 | 51 | # Apply the gradients to update the variables 52 | optimizer.apply_gradients(grads_and_vars, global_step=tf.train.get_or_create_global_step()) 53 | 54 | # Output debug message 55 | if batch_idx % log_interval == 0: 56 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 57 | epoch, batch_idx * batch_size, train_size, 58 | 100. 
* batch_idx * batch_size / train_size, loss_value.numpy())) 59 | 60 | def test(model, test_size, test_loader): 61 | # Init loss & correct prediction accumulators 62 | test_loss = 0 63 | correct = 0 64 | 65 | # Iterate over data, accumulate total loss & correct predictions 66 | for data, target in test_loader: 67 | # Get loss and output from network 68 | loss, output = cal_loss(model, data, target, training=False) 69 | test_loss += tf.reduce_sum(loss).numpy() 70 | 71 | # Get correct number of predictions; `.numpy()` converts the scalar tensor to a plain number 72 | pred = tf.argmax(output, 1) 73 | correct += tf.reduce_sum(tf.cast(tf.equal(pred, target), tf.int32)).numpy() 74 | 75 | # Print out average test loss & accuracy 76 | test_loss /= test_size 77 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 78 | test_loss, correct, test_size, 79 | 100. * correct / test_size)) 80 | 81 | 82 | if __name__ == '__main__': 83 | parser = argparse.ArgumentParser(description='TensorFlow Eager MNIST Example') 84 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 85 | help='input batch size for training (default: 64)') 86 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 87 | help='input batch size for testing (default: 1000)') 88 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 89 | help='number of epochs to train (default: 10)') 90 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 91 | help='learning rate (default: 0.01)') 92 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 93 | help='SGD momentum (default: 0.5)') 94 | parser.add_argument('--seed', type=int, default=1, metavar='S', 95 | help='random seed (default: 1)') 96 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 97 | help='how many batches to wait before logging training status') 98 | args = parser.parse_args() 99 | 100 | # Configure tf 101 | tf.set_random_seed(args.seed) 102 | 103 | # Load data 104 | train_loader, test_loader, 
train_size, test_size = load_data(args.batch_size, args.test_batch_size) 105 | 106 | # Build model 107 | model = Net() 108 | optimizer = tf.train.MomentumOptimizer(args.lr, args.momentum) 109 | 110 | # Train & test model 111 | for epoch in range(1, args.epochs + 1): 112 | train(model, optimizer, epoch, train_loader, args.batch_size, train_size, log_interval=args.log_interval) 113 | test(model, test_size, test_loader) 114 | 115 | # Save model for future use 116 | checkpoint_dir = os.path.dirname(os.path.abspath(__file__)) 117 | checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt') 118 | 119 | global_step = tf.train.get_or_create_global_step() 120 | all_variables = (model.variables + [global_step]) 121 | tfe.Saver(all_variables).save(checkpoint_prefix, global_step=global_step) 122 | --------------------------------------------------------------------------------
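Note on the evaluation loop: both `test` functions in this repository sum a per-sample negative log-likelihood over the test set, divide by the dataset size, and count how often the arg-max prediction matches the label. The framework-free sketch below mirrors that bookkeeping in plain Python so the arithmetic can be checked by hand. The `evaluate` helper and the toy probabilities are hypothetical illustrations, not code or data from the repository.

```python
import math


def evaluate(probabilities, targets):
    """Return (average loss, number correct, accuracy in percent).

    Hypothetical helper: `probabilities` holds one list of class probabilities
    per sample, `targets` holds the integer class labels.
    """
    total_loss = 0.0
    correct = 0
    for probs, target in zip(probabilities, targets):
        # Per-sample negative log-likelihood of the true class
        total_loss += -math.log(probs[target])
        # Predicted label is the index of the largest probability
        pred = max(range(len(probs)), key=probs.__getitem__)
        correct += int(pred == target)
    n = len(targets)
    return total_loss / n, correct, 100.0 * correct / n


# Made-up 3-class outputs for four samples (illustration only)
toy_probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.4, 0.3], [0.25, 0.25, 0.5]]
toy_targets = [0, 1, 2, 2]

avg_loss, correct, accuracy = evaluate(toy_probs, toy_targets)
print('Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
    avg_loss, correct, len(toy_targets), accuracy))
```

Because the loss is summed before dividing by the dataset size, the reported average does not depend on the batch size used during evaluation.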