├── .gitignore ├── README.md ├── data └── MNIST_data │ ├── t10k-images-idx3-ubyte.gz │ ├── t10k-labels-idx1-ubyte.gz │ ├── train-images-idx3-ubyte.gz │ └── train-labels-idx1-ubyte.gz └── mnist_eager.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | .static_storage/ 56 | .media/ 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow-Eager-Examples 2 | 3 | Simple examples using TensorFlow Eager Execution. 4 | 5 | ## Install TensorFlow with Eager Execution Support 6 | 7 | CPU version: 8 | ``` 9 | pip install tf-nightly 10 | ``` 11 | 12 | GPU version: 13 | ``` 14 | pip install tf-nightly-gpu 15 | ``` 16 | 17 | ### Examples 18 | 19 | - mnist_eager.py : Rewrite the [official MNIST tutorial](https://www.tensorflow.org/get_started/mnist/beginners) with Eager mode. 20 | -------------------------------------------------------------------------------- /data/MNIST_data/t10k-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/TensorFlow-Eager-Examples/e90dca042a1ebe0e862583d51689d1eeee3eeeff/data/MNIST_data/t10k-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /data/MNIST_data/t10k-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/TensorFlow-Eager-Examples/e90dca042a1ebe0e862583d51689d1eeee3eeeff/data/MNIST_data/t10k-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /data/MNIST_data/train-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/TensorFlow-Eager-Examples/e90dca042a1ebe0e862583d51689d1eeee3eeeff/data/MNIST_data/train-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /data/MNIST_data/train-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/hzy46/TensorFlow-Eager-Examples/e90dca042a1ebe0e862583d51689d1eeee3eeeff/data/MNIST_data/train-labels-idx1-ubyte.gz
# --------------------------------------------------------------------------------
# /mnist_eager.py:
# --------------------------------------------------------------------------------
"""Softmax-regression MNIST classifier written with TensorFlow eager execution.

The script makes one pass over the MNIST training set, updating a single
linear layer (``W``, ``b``) with plain gradient descent, and then prints the
accuracy measured on the MNIST test set.
"""
import tensorflow as tf
import tensorflow.contrib.eager as tfe
from tensorflow.examples.tutorials.mnist import input_data

# Eager mode is switched on before any variable or op is created below.
tfe.enable_eager_execution()

# Lower bound applied to probabilities before taking their logarithm, so that
# an underflowed softmax output cannot produce log(0) = -inf (and then NaN
# once multiplied by a zero label).
_MIN_PROBABILITY = 1e-10

# Parameters of the linear layer: 784 inputs per image, 10 output classes.
W = tf.get_variable(name="W", shape=(784, 10))
b = tf.get_variable(name="b", shape=(10, ))


def softmax_model(image_batch):
    """Return the softmax class probabilities for a batch of images.

    Args:
        image_batch: 2-D batch of images; the second dimension must be 784
            so it can be multiplied with ``W``.

    Returns:
        Tensor with one row of 10 softmax probabilities per input image.
    """
    logits = tf.matmul(image_batch, W) + b
    return tf.nn.softmax(logits)


def cross_entropy(model_output, label_batch):
    """Return the mean cross-entropy between predictions and labels.

    Args:
        model_output: probabilities as produced by ``softmax_model``.
        label_batch: labels with the same shape as ``model_output`` (the
            script loads them one-hot encoded).

    Returns:
        Scalar tensor: the per-example cross-entropy averaged over the batch.
    """
    # Clip before the log: a probability that underflows to 0 would otherwise
    # give 0 * -inf = NaN and corrupt every following gradient step.
    safe_output = tf.clip_by_value(model_output, _MIN_PROBABILITY, 1.0)
    per_example = -tf.reduce_sum(label_batch * tf.log(safe_output), axis=1)
    return tf.reduce_mean(per_example)


@tfe.implicit_value_and_gradients
def cal_gradient(image_batch, label_batch):
    """Compute the training loss for one batch.

    Because of the decorator, calling this function yields the loss together
    with the gradient/variable pairs that the training loop hands to
    ``optimizer.apply_gradients``.
    """
    return cross_entropy(softmax_model(image_batch), label_batch)


def main(data_dir="data/MNIST_data/", batch_size=100, learning_rate=0.5,
         shuffle_buffer=1000):
    """Train for one pass over the training set, then report test accuracy.

    Args:
        data_dir: directory holding the MNIST archives.
        batch_size: number of examples per gradient step.
        learning_rate: step size of the gradient-descent optimizer.
        shuffle_buffer: buffer size used when shuffling the training data.
    """
    data = input_data.read_data_sets(data_dir, one_hot=True)

    # Parenthesised chain instead of backslash continuations; labels are cast
    # to float32, presumably to match the dtype of the model output.
    train_ds = (
        tf.data.Dataset.from_tensor_slices(
            (data.train.images, data.train.labels))
        .map(lambda x, y: (x, tf.cast(y, tf.float32)))
        .shuffle(buffer_size=shuffle_buffer)
        .batch(batch_size)
    )

    optimizer = tf.train.GradientDescentOptimizer(learning_rate)

    for step, (image_batch, label_batch) in enumerate(tfe.Iterator(train_ds)):
        loss, grads_and_vars = cal_gradient(image_batch, label_batch)
        optimizer.apply_gradients(grads_and_vars)
        print("step: {} loss: {}".format(step, loss.numpy()))

    # Evaluate on the whole test set in a single forward pass.
    model_test_output = softmax_model(data.test.images)
    model_test_label = data.test.labels
    correct_prediction = tf.equal(tf.argmax(model_test_output, 1),
                                  tf.argmax(model_test_label, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    print("test accuracy = {}".format(accuracy.numpy()))


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------