├── tests ├── __init__.py └── fake_test.py ├── experiments ├── results │ ├── performance_30epochs_combined.json │ ├── MNIST │ │ ├── accuracy-mnist.png │ │ ├── accuracy-30epoch-mnist.png │ │ ├── performance-30epochs-mnist.png │ │ ├── performance_combined.json │ │ └── performance_30epochs_combined.json │ └── Fashion-MNIST │ │ ├── accuracy-fashion-mnist.png │ │ ├── accuracy-30epoch-fashion-mnist.png │ │ ├── performance-30epochs-fashion.png │ │ ├── performance_combined.json │ │ └── performance_30epochs_combined.json ├── combine_results.ipynb ├── 30epochresults.ipynb └── deep_learning_paper.ipynb ├── keras_svm ├── __init__.py └── model_svm_wrapper.py ├── .gitignore ├── MANIFEST.in ├── send.bat ├── setup.cfg ├── tox.ini ├── LICENSE ├── setup.py └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/results/performance_30epochs_combined.json: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/fake_test.py: -------------------------------------------------------------------------------- 1 | def test_success(): 2 | assert True 3 | -------------------------------------------------------------------------------- /keras_svm/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_svm_wrapper import ModelSVMWrapper 2 | 3 | __all__ = ["ModelSVMWrapper"] 4 | -------------------------------------------------------------------------------- /experiments/results/MNIST/accuracy-mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luttik/keras_svm/HEAD/experiments/results/MNIST/accuracy-mnist.png -------------------------------------------------------------------------------- 
/experiments/results/MNIST/accuracy-30epoch-mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luttik/keras_svm/HEAD/experiments/results/MNIST/accuracy-30epoch-mnist.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | env/ 3 | .tox/ 4 | *.egg-info/ 5 | dist/ 6 | build/ 7 | .pytest_cache/ 8 | **/__pycache__/ 9 | *.pyc 10 | **/.ipynb_checkpoints/ 11 | -------------------------------------------------------------------------------- /experiments/results/MNIST/performance-30epochs-mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luttik/keras_svm/HEAD/experiments/results/MNIST/performance-30epochs-mnist.png -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Include the README 2 | include *.md 3 | 4 | # Include the data files 5 | recursive-include data * 6 | 7 | include *.py 8 | include *.txt 9 | include LICENSE -------------------------------------------------------------------------------- /experiments/results/Fashion-MNIST/accuracy-fashion-mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luttik/keras_svm/HEAD/experiments/results/Fashion-MNIST/accuracy-fashion-mnist.png -------------------------------------------------------------------------------- /experiments/results/Fashion-MNIST/accuracy-30epoch-fashion-mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luttik/keras_svm/HEAD/experiments/results/Fashion-MNIST/accuracy-30epoch-fashion-mnist.png 
-------------------------------------------------------------------------------- /experiments/results/Fashion-MNIST/performance-30epochs-fashion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luttik/keras_svm/HEAD/experiments/results/Fashion-MNIST/performance-30epochs-fashion.png -------------------------------------------------------------------------------- /send.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | SET /P variable="Did you remember to run tox?" 3 | 4 | if %variable%==Y set variable=y 5 | if %variable%==y ( 6 | rm -r dist/* 7 | python setup.py bdist_wheel --universal 8 | python setup.py sdist 9 | twine upload dist/* 10 | ) else ( 11 | echo upload cancelled 12 | ) -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | # This includes the license file in the wheel. 3 | license_file = LICENSE 4 | [bdist_wheel] 5 | # This flag says to generate wheels that support both Python 2 and Python 6 | # 3. If your code will not run unchanged on both Python 2 and 3, you will 7 | # need to generate separate wheels for each Python version that you 8 | # support. 9 | universal=1 -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # Use: 2 | # 3 | # - check-manifest 4 | # confirm items checked into vcs are in your sdist 5 | # - python setup.py check (using the readme_renderer extension) 6 | # confirms your long_description will render correctly on pypi 7 | # 8 | # and also to help confirm pull requests to this project. 
9 | 10 | [tox] 11 | envlist = py{27,36} 12 | 13 | [testenv] 14 | basepython = 15 | py27: python2.7 16 | py34: python3.4 17 | py35: python3.5 18 | py36: python3.6 19 | deps = 20 | docutils 21 | check-manifest 22 | ; readme_renderer 23 | flake8 24 | pytest 25 | commands = 26 | check-manifest --ignore tox.ini,tests*,*.bat 27 | python setup.py check -m -r -s 28 | flake8 . 29 | py.test tests 30 | [flake8] 31 | 32 | ignore = E501 33 | exclude = .tox,*.egg,build,data,env,*.bat 34 | select = E,W,F -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Daan Luttik 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='keras_svm', 5 | version='1.0.0b10', 6 | description='A model to use keras models with Support Vector Machines', 7 | url='https://github.com/Luttik/keras_svm/tree/master', # Optional 8 | author='Daan Luttik', # Optional 9 | author_email='d.t.luttik@gmail.com', # Optional 10 | license='MIT', 11 | classifiers=[ 12 | 'Development Status :: 4 - Beta', 13 | 'Intended Audience :: Developers', 14 | 'Intended Audience :: Education', 15 | 'Intended Audience :: Science/Research', 16 | 'License :: OSI Approved :: MIT License', 17 | 'Programming Language :: Python :: 2', 18 | 'Programming Language :: Python :: 2.7', 19 | 'Programming Language :: Python :: 3', 20 | 'Programming Language :: Python :: 3.4', 21 | 'Programming Language :: Python :: 3.5', 22 | 'Programming Language :: Python :: 3.6', 23 | 'Topic :: Software Development :: Libraries', 24 | 'Topic :: Software Development :: Libraries :: Python Modules' 25 | ], 26 | keywords='keras sklearn svm ml', # Optional 27 | packages=find_packages(exclude=['contrib', 'docs', 'tests']), # Required 28 | install_requires=['keras', 'scikit-learn'], # Optional 29 | ) 30 | -------------------------------------------------------------------------------- /experiments/results/MNIST/performance_combined.json: -------------------------------------------------------------------------------- 1 | [{"with_svm@-2": [0.9898, 0.9912, 0.9929, 0.9926, 0.9925, 0.9922, 0.9924, 0.9928, 0.9919, 0.9926], "with_svm@-3": [0.9912, 0.9916, 0.9926, 0.9921, 0.9925, 0.9931, 0.9931, 0.9931, 0.9935, 0.9927], "without_svm": [0.9873, 0.9897, 0.9905, 0.99, 0.9925, 0.9932, 0.9901, 0.9912, 0.99, 0.9907]}, 2 | {"with_svm@-2": [0.9891, 0.9902, 0.9918, 0.9927, 0.9926, 0.9935, 0.993, 0.9939, 0.9927, 0.9933], "with_svm@-3": 
[0.9912, 0.9917, 0.9918, 0.9935, 0.9926, 0.9933, 0.9936, 0.9932, 0.9935, 0.9932], "without_svm": [0.9836, 0.9847, 0.9909, 0.9908, 0.9906, 0.9898, 0.9922, 0.9905, 0.9925, 0.9906]}, 3 | {"with_svm@-2": [0.9893, 0.9905, 0.9921, 0.9929, 0.9932, 0.9935, 0.9935, 0.9935, 0.9921, 0.9928], "with_svm@-3": [0.9906, 0.9909, 0.992, 0.9934, 0.9929, 0.9938, 0.9927, 0.9935, 0.9934, 0.9937], "without_svm": [0.9823, 0.9894, 0.9916, 0.9908, 0.9906, 0.9929, 0.9934, 0.9927, 0.9919, 0.9921]}, 4 | {"with_svm@-2": [0.9897, 0.9912, 0.9919, 0.9923, 0.9913, 0.9923, 0.992, 0.9921, 0.9931, 0.9927], "with_svm@-3": [0.9907, 0.992, 0.9929, 0.9922, 0.9926, 0.993, 0.9927, 0.9931, 0.993, 0.9931], "without_svm": [0.9849, 0.9895, 0.9866, 0.9896, 0.9916, 0.9915, 0.991, 0.9928, 0.9924, 0.9917]}, 5 | {"with_svm@-2": [0.9891, 0.9905, 0.9916, 0.9924, 0.992, 0.9927, 0.9923, 0.9935, 0.9927, 0.9935], "with_svm@-3": [0.9903, 0.9917, 0.9927, 0.993, 0.9933, 0.9933, 0.9923, 0.9931, 0.9928, 0.9939], "without_svm": [0.9869, 0.9879, 0.9918, 0.991, 0.9902, 0.9891, 0.9921, 0.9924, 0.9924, 0.9921]}] -------------------------------------------------------------------------------- /experiments/results/Fashion-MNIST/performance_combined.json: -------------------------------------------------------------------------------- 1 | [{"with_svm@-2": [0.8779, 0.8947, 0.9037, 0.9061, 0.911, 0.9137, 0.9087, 0.9128, 0.9109, 0.9116], "with_svm@-3": [0.8973, 0.9043, 0.905, 0.911, 0.9104, 0.9107, 0.9112, 0.9121, 0.9117, 0.9118], "without_svm": [0.8581, 0.8506, 0.8879, 0.8987, 0.9048, 0.9031, 0.9064, 0.9076, 0.8957, 0.9061]}, 2 | {"with_svm@-2": [0.8799, 0.8943, 0.9013, 0.907, 0.9101, 0.9114, 0.9078, 0.9137, 0.9132, 0.9146], "with_svm@-3": [0.8957, 0.9004, 0.9042, 0.9052, 0.9103, 0.9109, 0.9109, 0.9112, 0.9114, 0.9128], "without_svm": [0.8444, 0.8788, 0.8809, 0.8927, 0.9043, 0.9091, 0.9024, 0.9059, 0.9057, 0.9053]}, 3 | {"with_svm@-2": [0.8743, 0.8937, 0.902, 0.9039, 0.9108, 0.9113, 0.9075, 0.9107, 0.9103, 0.912], "with_svm@-3": 
[0.8929, 0.9007, 0.9035, 0.908, 0.9056, 0.9078, 0.9083, 0.9052, 0.909, 0.9072], "without_svm": [0.8517, 0.8701, 0.8958, 0.8759, 0.901, 0.9015, 0.9015, 0.8979, 0.9033, 0.9049]}, 4 | {"with_svm@-2": [0.8794, 0.8946, 0.9012, 0.9044, 0.9095, 0.9138, 0.9124, 0.9127, 0.9126, 0.9128], "with_svm@-3": [0.8922, 0.901, 0.9026, 0.9046, 0.9059, 0.9102, 0.9081, 0.908, 0.9114, 0.9119], "without_svm": [0.8536, 0.8766, 0.8935, 0.8707, 0.8994, 0.9079, 0.9052, 0.9005, 0.909, 0.9099]}, 5 | {"with_svm@-2": [0.8876, 0.8992, 0.9077, 0.912, 0.9115, 0.9109, 0.9122, 0.9145, 0.9122, 0.9123], "with_svm@-3": [0.8999, 0.9046, 0.9097, 0.9095, 0.9111, 0.9115, 0.9116, 0.912, 0.9121, 0.9109], "without_svm": [0.8667, 0.8758, 0.8939, 0.9006, 0.903, 0.8985, 0.9041, 0.9062, 0.8882, 0.9062]}] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Keras SVM 2 | 3 | [![PyPI - Status](https://img.shields.io/pypi/status/keras-svm.svg)](https://pypi.org/project/keras-svm/) 4 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/keras-svm.svg)](https://pypi.org/project/keras-svm/) 5 | [![PyPI - License](https://img.shields.io/pypi/l/keras-svm.svg)](https://github.com/Luttik/keras_svm/blob/master/LICENSE) 6 | [![PyPI](https://img.shields.io/pypi/v/keras-svm.svg)](https://pypi.org/project/keras-svm/) 7 | 8 | ## Purpose 9 | Provides a wrapper class that effectively replaces the softmax of your Keras model with a SVM. 10 | 11 | The SVM has no impact on the training of the Neural Network, but replacing softmax with an SVM has been shown to perform better on unseen data. 
12 | 13 | ## Code examples 14 | ### Example construction 15 | ``` 16 | # Build a classical model 17 | def build_model(): 18 | model = models.Sequential() 19 | model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1))) 20 | model.add(layers.MaxPooling2D((2, 2))) 21 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 22 | model.add(layers.MaxPooling2D((2, 2))) 23 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 24 | model.add(layers.Flatten(name="intermediate_output")) 25 | model.add(layers.Dense(64, activation='relu')) 26 | model.add(layers.Dense(10, activation='softmax')) 27 | 28 | # The extra metric is important for the evaluate function 29 | model.compile(optimizer='rmsprop', 30 | loss='categorical_crossentropy', 31 | metrics=['accuracy']) 32 | return model 33 | 34 | # Wrap it in the ModelSVMWrapper 35 | wrapper = ModelSVMWrapper(build_model()) 36 | ``` 37 | 38 | ### Training while maintaining an accuracy score 39 | ``` 40 | accuracy = { 41 | "with_svm": [], 42 | "without_svm": [] 43 | } 44 | 45 | epochs = 10 46 | for i in range(epochs): 47 | print('Starting run: {}'.format(i)) 48 | wrapper.fit(train_images, train_labels, epochs=1, batch_size=64) 49 | accuracy["with_svm"].append(wrapper.evaluate(test_images, test_labels)) 50 | accuracy["without_svm"].append( 51 | wrapper.model.evaluate(test_images, to_categorical(test_labels))[1]) 52 | ``` 53 | -------------------------------------------------------------------------------- /experiments/combine_results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 9, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import plotly\n", 10 | "import plotly.plotly as py\n", 11 | "import plotly.graph_objs as go\n", 12 | "import json\n", 13 | "import numpy as np" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 10, 19 | "metadata": {}, 20 | "outputs": 
[], 21 | "source": [ 22 | "results_json = json.load(open('results/Fashion-MNIST/performance_combined.json'))" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 11, 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "data": { 32 | "text/plain": [ 33 | "{'with_svm@-2': [[0.8779, 0.8799, 0.8743, 0.8794, 0.8876],\n", 34 | " [0.8947, 0.8943, 0.8937, 0.8946, 0.8992],\n", 35 | " [0.9037, 0.9013, 0.902, 0.9012, 0.9077],\n", 36 | " [0.9061, 0.907, 0.9039, 0.9044, 0.912],\n", 37 | " [0.911, 0.9101, 0.9108, 0.9095, 0.9115],\n", 38 | " [0.9137, 0.9114, 0.9113, 0.9138, 0.9109],\n", 39 | " [0.9087, 0.9078, 0.9075, 0.9124, 0.9122],\n", 40 | " [0.9128, 0.9137, 0.9107, 0.9127, 0.9145],\n", 41 | " [0.9109, 0.9132, 0.9103, 0.9126, 0.9122],\n", 42 | " [0.9116, 0.9146, 0.912, 0.9128, 0.9123]],\n", 43 | " 'with_svm@-3': [[0.8973, 0.8957, 0.8929, 0.8922, 0.8999],\n", 44 | " [0.9043, 0.9004, 0.9007, 0.901, 0.9046],\n", 45 | " [0.905, 0.9042, 0.9035, 0.9026, 0.9097],\n", 46 | " [0.911, 0.9052, 0.908, 0.9046, 0.9095],\n", 47 | " [0.9104, 0.9103, 0.9056, 0.9059, 0.9111],\n", 48 | " [0.9107, 0.9109, 0.9078, 0.9102, 0.9115],\n", 49 | " [0.9112, 0.9109, 0.9083, 0.9081, 0.9116],\n", 50 | " [0.9121, 0.9112, 0.9052, 0.908, 0.912],\n", 51 | " [0.9117, 0.9114, 0.909, 0.9114, 0.9121],\n", 52 | " [0.9118, 0.9128, 0.9072, 0.9119, 0.9109]],\n", 53 | " 'without_svm': [[0.8581, 0.8444, 0.8517, 0.8536, 0.8667],\n", 54 | " [0.8506, 0.8788, 0.8701, 0.8766, 0.8758],\n", 55 | " [0.8879, 0.8809, 0.8958, 0.8935, 0.8939],\n", 56 | " [0.8987, 0.8927, 0.8759, 0.8707, 0.9006],\n", 57 | " [0.9048, 0.9043, 0.901, 0.8994, 0.903],\n", 58 | " [0.9031, 0.9091, 0.9015, 0.9079, 0.8985],\n", 59 | " [0.9064, 0.9024, 0.9015, 0.9052, 0.9041],\n", 60 | " [0.9076, 0.9059, 0.8979, 0.9005, 0.9062],\n", 61 | " [0.8957, 0.9057, 0.9033, 0.909, 0.8882],\n", 62 | " [0.9061, 0.9053, 0.9049, 0.9099, 0.9062]]}" 63 | ] 64 | }, 65 | "execution_count": 11, 66 | "metadata": {}, 67 | "output_type": "execute_result" 68 | } 
69 | ], 70 | "source": [ 71 | "results = {key:[[results_json[j][key][i] for j in range(len(results_json))] for i in range(len(results_json[0][key]))]\n", 72 | " for key in results_json[0]}\n", 73 | "results" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 12, 79 | "metadata": {}, 80 | "outputs": [ 81 | { 82 | "data": { 83 | "text/html": [ 84 | "" 85 | ], 86 | "text/plain": [ 87 | "" 88 | ] 89 | }, 90 | "execution_count": 12, 91 | "metadata": {}, 92 | "output_type": "execute_result" 93 | } 94 | ], 95 | "source": [ 96 | "plotly.tools.set_credentials_file(username='luttik', api_key='6p0E4ba6sVIgV575o3uX')\n", 97 | "\n", 98 | "data = [\n", 99 | " go.Scatter(\n", 100 | " x=list(range(10)),\n", 101 | " y=[np.mean(results[key][epoch]) for epoch in range(len(results[key]))],\n", 102 | " name=key,\n", 103 | " error_y=dict(\n", 104 | " type='data',\n", 105 | " array=[np.std(results[key][epoch]) for epoch in range(len(results[key]))],\n", 106 | " visible=True\n", 107 | " )\n", 108 | " )\n", 109 | " for key in results\n", 110 | "]\n", 111 | "\n", 112 | "layout = dict(title = 'Accuracy on the Fashion-MNIST dataset',\n", 113 | " xaxis = dict(title = 'Epoch'),\n", 114 | " yaxis = dict(title = 'Accuracy'),\n", 115 | " )\n", 116 | "\n", 117 | "py.iplot(data, filename='results-svm-replacement-fashion', layout=layout)" 118 | ] 119 | } 120 | ], 121 | "metadata": { 122 | "kernelspec": { 123 | "display_name": "Python 3", 124 | "language": "python", 125 | "name": "python3" 126 | }, 127 | "language_info": { 128 | "codemirror_mode": { 129 | "name": "ipython", 130 | "version": 3 131 | }, 132 | "file_extension": ".py", 133 | "mimetype": "text/x-python", 134 | "name": "python", 135 | "nbconvert_exporter": "python", 136 | "pygments_lexer": "ipython3", 137 | "version": "3.6.4" 138 | } 139 | }, 140 | "nbformat": 4, 141 | "nbformat_minor": 2 142 | } 143 | -------------------------------------------------------------------------------- /experiments/30epochresults.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import plotly\n", 10 | "import plotly.plotly as py\n", 11 | "import plotly.graph_objs as go\n", 12 | "import json\n", 13 | "import numpy as np" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "data = json.load(open('results/MNIST/performance_30epochs_combined.json'))\n", 23 | "history = [x['history'] for x in data]\n", 24 | "performances = [x['performance'] for x in data]" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 4, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "data": { 34 | "text/plain": [ 35 | "(array([[0.94743333, 0.94448333, 0.94555 , 0.9476 , 0.9471 ],\n", 36 | " [0.9853 , 0.98511667, 0.98516667, 0.98528333, 0.98583333],\n", 37 | " [0.98936667, 0.99013333, 0.99 , 0.99033333, 0.98983333],\n", 38 | " [0.99188333, 0.99286667, 0.9922 , 0.9926 , 0.99235 ],\n", 39 | " [0.99345 , 0.99396667, 0.9939 , 0.99435 , 0.9941 ],\n", 40 | " [0.99508333, 0.99548333, 0.99503333, 0.99541667, 0.99521667],\n", 41 | " [0.99598333, 0.99593333, 0.9957 , 0.99635 , 0.9963 ],\n", 42 | " [0.99655 , 0.99675 , 0.9967 , 0.99688333, 0.99675 ],\n", 43 | " [0.99718333, 0.99715 , 0.99726667, 0.99751667, 0.99725 ],\n", 44 | " [0.99716667, 0.99738333, 0.99766667, 0.99785 , 0.99778333],\n", 45 | " [0.99795 , 0.99775 , 0.9979 , 0.99805 , 0.99793333],\n", 46 | " [0.99818333, 0.99835 , 0.99823333, 0.99825 , 0.99841667],\n", 47 | " [0.99816667, 0.99846667, 0.99841667, 0.99846667, 0.99838333],\n", 48 | " [0.99826667, 0.99858333, 0.9986 , 0.9988 , 0.99871667],\n", 49 | " [0.99866667, 0.99891667, 0.99863333, 0.99873333, 0.99888333],\n", 50 | " [0.99856667, 0.99886667, 0.99886667, 0.999 , 0.99865 ],\n", 51 | " [0.99885 , 0.999 , 0.99878333, 0.99905 , 0.99881667],\n", 
52 | " [0.99891667, 0.99881667, 0.99883333, 0.99918333, 0.99911667],\n", 53 | " [0.99903333, 0.99916667, 0.99895 , 0.99906667, 0.99918333],\n", 54 | " [0.99908333, 0.9994 , 0.99903333, 0.99928333, 0.9991 ],\n", 55 | " [0.99911667, 0.99896667, 0.99918333, 0.99933333, 0.9992 ],\n", 56 | " [0.99925 , 0.99911667, 0.9992 , 0.99875 , 0.9992 ],\n", 57 | " [0.99931667, 0.99916667, 0.99928333, 0.99915 , 0.99938333],\n", 58 | " [0.99928333, 0.99938333, 0.99925 , 0.99956667, 0.99926667],\n", 59 | " [0.99933333, 0.99953333, 0.99923333, 0.99921667, 0.99926667],\n", 60 | " [0.99916667, 0.99946667, 0.99908333, 0.99938333, 0.99946667],\n", 61 | " [0.99931667, 0.99936667, 0.99941667, 0.99941667, 0.99928333],\n", 62 | " [0.99935 , 0.99965 , 0.99958333, 0.99906667, 0.99935 ],\n", 63 | " [0.99945 , 0.9994 , 0.99953333, 0.99926667, 0.99935 ],\n", 64 | " [0.99953333, 0.99936667, 0.99935 , 0.99955 , 0.99938333]]),\n", 65 | " array([[0.981 , 0.9844, 0.9848, 0.9743, 0.9843],\n", 66 | " [0.9902, 0.9881, 0.9878, 0.9881, 0.9892],\n", 67 | " [0.9916, 0.9916, 0.9886, 0.9894, 0.9915],\n", 68 | " [0.9847, 0.9893, 0.9912, 0.9909, 0.9915],\n", 69 | " [0.9933, 0.993 , 0.9903, 0.9913, 0.9916],\n", 70 | " [0.9897, 0.9919, 0.9857, 0.9933, 0.9917],\n", 71 | " [0.9917, 0.9919, 0.9917, 0.9915, 0.9923],\n", 72 | " [0.9891, 0.992 , 0.9911, 0.9914, 0.9932],\n", 73 | " [0.9929, 0.9911, 0.9904, 0.9912, 0.9911],\n", 74 | " [0.9925, 0.9935, 0.9927, 0.9932, 0.9921],\n", 75 | " [0.9906, 0.9923, 0.9917, 0.9917, 0.9892],\n", 76 | " [0.9919, 0.9938, 0.9919, 0.994 , 0.9917],\n", 77 | " [0.993 , 0.9933, 0.9917, 0.9936, 0.9928],\n", 78 | " [0.9891, 0.9931, 0.991 , 0.9922, 0.9922],\n", 79 | " [0.993 , 0.9928, 0.9908, 0.9924, 0.9918],\n", 80 | " [0.993 , 0.9922, 0.9943, 0.9943, 0.9921],\n", 81 | " [0.9933, 0.9934, 0.992 , 0.9935, 0.9919],\n", 82 | " [0.9916, 0.9924, 0.9926, 0.9925, 0.9933],\n", 83 | " [0.9923, 0.9929, 0.9918, 0.9935, 0.9918],\n", 84 | " [0.9923, 0.9931, 0.993 , 0.9935, 0.9916],\n", 85 | " [0.9922, 0.9925, 
0.9914, 0.9896, 0.9923],\n", 86 | " [0.9919, 0.9922, 0.9927, 0.993 , 0.9925],\n", 87 | " [0.9915, 0.9928, 0.9924, 0.9933, 0.9914],\n", 88 | " [0.993 , 0.9935, 0.9931, 0.9928, 0.9922],\n", 89 | " [0.9917, 0.9935, 0.9926, 0.9921, 0.9916],\n", 90 | " [0.9913, 0.9937, 0.993 , 0.9922, 0.9921],\n", 91 | " [0.9933, 0.9932, 0.9908, 0.993 , 0.9916],\n", 92 | " [0.9935, 0.9922, 0.9917, 0.9924, 0.9918],\n", 93 | " [0.9933, 0.9918, 0.9937, 0.9916, 0.9934],\n", 94 | " [0.9923, 0.9914, 0.9926, 0.9921, 0.9923]]))" 95 | ] 96 | }, 97 | "execution_count": 4, 98 | "metadata": {}, 99 | "output_type": "execute_result" 100 | } 101 | ], 102 | "source": [ 103 | "train_accuracies = np.swapaxes(np.array([x['acc'] for x in history], dtype=float), 0, 1)\n", 104 | "test_accuracy = np.swapaxes(np.array([x['val_acc'] for x in history], dtype=float), 0, 1)\n", 105 | "train_accuracies, test_accuracy" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 7, 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "data": { 115 | "text/html": [ 116 | "" 117 | ], 118 | "text/plain": [ 119 | "" 120 | ] 121 | }, 122 | "execution_count": 7, 123 | "metadata": {}, 124 | "output_type": "execute_result" 125 | } 126 | ], 127 | "source": [ 128 | "results = {\n", 129 | " \"train_accuracies\": train_accuracies,\n", 130 | " \"test_accuracy\": test_accuracy\n", 131 | "}\n", 132 | "\n", 133 | "data = [\n", 134 | " go.Scatter(\n", 135 | " x=list(range(30)),\n", 136 | " y=[np.mean(results[key][epoch]) for epoch in range(len(results[key]))],\n", 137 | " name=key,\n", 138 | " error_y=dict(\n", 139 | " type='data',\n", 140 | " array=[np.std(results[key][epoch]) for epoch in range(len(results[key]))],\n", 141 | " visible=True\n", 142 | " )\n", 143 | " )\n", 144 | " for key in results\n", 145 | "]\n", 146 | "\n", 147 | "layout = dict(title = 'Accuracy on the Fashion-MNIST dataset 30-epochs',\n", 148 | " xaxis = dict(title = 'Epoch'),\n", 149 | " yaxis = dict(title = 'Accuracy'),\n", 150 | " )\n", 151 
| "\n", 152 | "py.iplot(data, filename='results-30epochs-mnist', layout=layout)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "performances" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 6, 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "data": { 171 | "text/html": [ 172 | "" 173 | ], 174 | "text/plain": [ 175 | "" 176 | ] 177 | }, 178 | "execution_count": 6, 179 | "metadata": {}, 180 | "output_type": "execute_result" 181 | } 182 | ], 183 | "source": [ 184 | "plot_data = {\n", 185 | " 'Without SVM': test_accuracy[-1],\n", 186 | " 'SVM_@-2': [x['with_svm@-2'] for x in performances],\n", 187 | " 'SVM_@-3': [x['with_svm@-3'] for x in performances]\n", 188 | "}\n", 189 | "\n", 190 | "data = [\n", 191 | " go.Box(\n", 192 | " y=plot_data[key],\n", 193 | " name=key,\n", 194 | " ) for key in plot_data\n", 195 | "]\n", 196 | "\n", 197 | "layout = dict(title = 'Accuracy on the Fashion-MNIST dataset 30-epochs',\n", 198 | " xaxis = dict(title = 'Epoch'),\n", 199 | " yaxis = dict(title = 'Accuracy'),\n", 200 | " )\n", 201 | "\n", 202 | "py.iplot(data, filename='performances-30epochs-mnist', layout=layout)" 203 | ] 204 | } 205 | ], 206 | "metadata": { 207 | "kernelspec": { 208 | "display_name": "Python 3", 209 | "language": "python", 210 | "name": "python3" 211 | }, 212 | "language_info": { 213 | "codemirror_mode": { 214 | "name": "ipython", 215 | "version": 3 216 | }, 217 | "file_extension": ".py", 218 | "mimetype": "text/x-python", 219 | "name": "python", 220 | "nbconvert_exporter": "python", 221 | "pygments_lexer": "ipython3", 222 | "version": "3.6.4" 223 | } 224 | }, 225 | "nbformat": 4, 226 | "nbformat_minor": 2 227 | } 228 | -------------------------------------------------------------------------------- /experiments/results/Fashion-MNIST/performance_30epochs_combined.json: 
-------------------------------------------------------------------------------- 1 | [{"history": {"val_loss": [0.39961640694141387, 0.3240031172275543, 0.36486296498775483, 0.2999567676544189, 0.26941082603931427, 0.28105862131118775, 0.2583421004772186, 0.27426248898506167, 0.26412350838184356, 0.2587658328771591, 0.3263972431182861, 0.3037034754753113, 0.2993787395477295, 0.3444756777048111, 0.34303277168273927, 0.3952072314620018, 0.323935725069046, 0.43344608159065245, 0.43881502020359037, 0.40885772070884707, 0.48054315140247345, 0.40082721934318544, 0.4527245432853699, 0.47377985467910766, 0.4480116040706634, 0.4445895487546921, 0.5960354033231735, 0.5005660709023476, 0.4928376198232174, 0.5406574306488037], "val_acc": [0.8512, 0.8836, 0.8636, 0.8917, 0.9024, 0.9024, 0.9076, 0.9061, 0.9144, 0.9106, 0.8941, 0.9075, 0.9111, 0.9024, 0.9087, 0.8956, 0.9104, 0.9007, 0.9093, 0.9069, 0.9029, 0.9001, 0.9099, 0.9088, 0.9077, 0.9047, 0.9076, 0.9067, 0.9097, 0.9044], "loss": [0.5369842763264974, 0.33226358030637104, 0.28292571523189547, 0.2514223607858022, 0.22796711163520814, 0.20951138263543448, 0.19243775055408477, 0.17820718478361766, 0.1669310154000918, 0.15484501781662305, 0.14333154555956523, 0.13416243276198705, 0.12566612465679647, 0.119671809245646, 0.11197880250414212, 0.10515761943757534, 0.09846485815346241, 0.09648188420385122, 0.09266547986765702, 0.0884354511961341, 0.0859529619326194, 0.08248054644614458, 0.08047926588232318, 0.0785362726101031, 0.07778030724997322, 0.07500176610127092, 0.07231801753346809, 0.07249056691670169, 0.07015723809991031, 0.07295623635863885], "acc": [0.80485, 0.8794666666666666, 0.8975666666666666, 0.9077833333333334, 0.9161833333333333, 0.9232333333333334, 0.9284, 0.934, 0.9381, 0.9432333333333334, 0.9473166666666667, 0.9504166666666667, 0.9542833333333334, 0.9563833333333334, 0.9586333333333333, 0.9607166666666667, 0.9641, 0.9638333333333333, 0.96575, 0.9675166666666667, 0.9686333333333333, 0.96995, 0.9702833333333334, 
0.97225, 0.9716666666666667, 0.9734, 0.9745166666666667, 0.97495, 0.9749333333333333, 0.9749166666666667]}, "performance": {"with_svm@-3": 0.9058, "with_svm@-2": 0.9093}},{"history": {"val_loss": [0.4150441631793976, 0.3261751090049744, 0.3454334979057312, 0.3094462700843811, 0.28191435914039614, 0.2875005770921707, 0.2872392510890961, 0.26201132280826567, 0.2847232746839523, 0.2788535724401474, 0.33057111032009123, 0.2983163303613663, 0.30958148255348206, 0.34721010646820066, 0.3631596899986267, 0.3546803478240967, 0.4272496466159821, 0.3690841608285904, 0.3891567182540894, 0.4191183884859085, 0.4348145247459412, 0.4770137900710106, 0.4338083899974823, 0.5618760014295578, 0.5968949678063392, 0.5492106608390808, 0.47301609473228456, 0.5954591445505619, 0.5699526505947113, 0.6091273111760617], "val_acc": [0.8442, 0.8812, 0.8708, 0.8908, 0.9007, 0.8984, 0.8951, 0.9066, 0.9122, 0.9021, 0.9064, 0.9103, 0.9012, 0.9002, 0.9087, 0.9068, 0.8991, 0.9074, 0.9076, 0.9053, 0.9058, 0.9052, 0.9051, 0.8962, 0.9071, 0.8982, 0.9004, 0.908, 0.9008, 0.9042], "loss": [0.5372069796085358, 0.33147892653942107, 0.2819155935049057, 0.2501675777196884, 0.22795790247519812, 0.20883161357243857, 0.19210073427359264, 0.17625245781143506, 0.16437627596060436, 0.15312131391763686, 0.14188234887719153, 0.13113620449503263, 0.1231539293517669, 0.11442032988071442, 0.10762197708835204, 0.10159016521821419, 0.09620393821696441, 0.09138783883477251, 0.08485465567509333, 0.08136600005974372, 0.07779915048554539, 0.07684610707126557, 0.07310853577925203, 0.07031081450295945, 0.06876748004704714, 0.06674423448170225, 0.06708652312122286, 0.0629670261045297, 0.06223219020180404, 0.060887808101888125], "acc": [0.8011833333333334, 0.8782, 0.896, 0.9074166666666666, 0.91605, 0.92305, 0.9288833333333333, 0.9350166666666667, 0.9386666666666666, 0.9424, 0.94645, 0.9512333333333334, 0.9545833333333333, 0.9567666666666667, 0.9592333333333334, 0.9621166666666666, 0.96345, 0.9660166666666666, 0.9680666666666666, 
0.9688666666666667, 0.97135, 0.9718, 0.9725833333333334, 0.97455, 0.9753833333333334, 0.9762166666666666, 0.9763666666666667, 0.977, 0.9783833333333334, 0.9784833333333334]}, "performance": {"with_svm@-3": 0.9023, "with_svm@-2": 0.9039}},{"history": {"val_loss": [0.4043342845916748, 0.4140978363990784, 0.2880860612392426, 0.2902471554279327, 0.26663297791481017, 0.27250645632743836, 0.27851846265792846, 0.2674907832622528, 0.2509124885082245, 0.2766284414768219, 0.30450538427829743, 0.3396841156244278, 0.33489206714630126, 0.32968835673332214, 0.35049580488204957, 0.40639078683853147, 0.4156203591585159, 0.44790499334335326, 0.39392427277565, 0.4508971324920654, 0.45144741268157956, 0.4733809838980436, 0.43593108036518097, 0.47289209780693053, 0.5302549986958504, 0.5181801124095917, 0.4857735851764679, 0.43312914094924926, 0.589274645756185, 0.5597955271720886], "val_acc": [0.8577, 0.8536, 0.8938, 0.8919, 0.9056, 0.9018, 0.8997, 0.9096, 0.9146, 0.9096, 0.9093, 0.9034, 0.8991, 0.9053, 0.9027, 0.9061, 0.9061, 0.9006, 0.9028, 0.9073, 0.9044, 0.9051, 0.9003, 0.907, 0.9047, 0.9059, 0.8963, 0.9007, 0.9079, 0.9054], "loss": [0.5449977326552073, 0.331432382162412, 0.27721960909366605, 0.24537183661460876, 0.2236899551153183, 0.2034526499390602, 0.18604228529532751, 0.1735770888586839, 0.1608558345834414, 0.14848035674492518, 0.13826039456129074, 0.13101752441922823, 0.12131975610057513, 0.11398900530735652, 0.10702471297780673, 0.10125248184750478, 0.09447303883483013, 0.09300156107619405, 0.08732006960138679, 0.08322347532113393, 0.08042012545143565, 0.07930759765207768, 0.07606005681777994, 0.07313436928779507, 0.06995374247866372, 0.0693866363959387, 0.06551957191874584, 0.06589129398725926, 0.06280896504372359, 0.0626244676458494], "acc": [0.7985, 0.8795, 0.8988833333333334, 0.9090833333333334, 0.9173666666666667, 0.92485, 0.93225, 0.9366833333333333, 0.9411166666666667, 0.94515, 0.94975, 0.9526833333333333, 0.9558, 0.95695, 0.9602666666666667, 0.9631166666666666, 
0.9650166666666666, 0.9656833333333333, 0.9686166666666667, 0.9690833333333333, 0.9704666666666667, 0.9719, 0.9723333333333334, 0.97415, 0.97455, 0.9749666666666666, 0.9765166666666667, 0.97695, 0.9780333333333333, 0.9779666666666667]}, "performance": {"with_svm@-3": 0.9026, "with_svm@-2": 0.9083}},{"history": {"val_loss": [0.38714364495277404, 0.3567921408176422, 0.3250357707738876, 0.2858080851316452, 0.28071209461688995, 0.2722078594684601, 0.2801001211643219, 0.27960221977233884, 0.2834079221725464, 0.27649084842205046, 0.2873888593196869, 0.3608335973739624, 0.3075114233016968, 0.33885010154247286, 0.36965319867134094, 0.3924022182226181, 0.3339719518661499, 0.4083260585546494, 0.40416117042303085, 0.40474343745708463, 0.42218335597515105, 0.4105245063304901, 0.45178317244052885, 0.44313831994533537, 0.4897380232691765, 0.47371806797981264, 0.5233340344429016, 0.5715650954663754, 0.5431698856592179, 0.6270964526116848], "val_acc": [0.864, 0.8749, 0.8814, 0.8912, 0.8966, 0.9025, 0.9056, 0.9048, 0.9077, 0.9085, 0.909, 0.9108, 0.9093, 0.9108, 0.906, 0.9056, 0.9076, 0.9087, 0.9049, 0.907, 0.906, 0.8994, 0.9052, 0.9018, 0.9004, 0.9033, 0.9063, 0.9024, 0.8953, 0.9066], "loss": [0.5441179068406423, 0.3305619689186414, 0.28244102187951403, 0.24943221237659455, 0.22933951362371444, 0.2084388442079226, 0.19123481418987115, 0.17787826199332873, 0.16468209545016288, 0.15482280835111936, 0.14399771384596824, 0.13411573451360068, 0.12475449653317532, 0.11885996306836605, 0.11200869298875332, 0.1069638219366471, 0.10219919406970342, 0.09681515558858712, 0.0917619632239143, 0.08723219532544414, 0.08620026492799322, 0.08276339541425308, 0.08095510593801737, 0.07946293067646523, 0.07687996114244064, 0.07559641222606103, 0.07310451997717221, 0.07273915510618438, 0.072224404023091, 0.0677077879728439], "acc": [0.80055, 0.8808333333333334, 0.8976666666666666, 0.9079333333333334, 0.91625, 0.9231666666666667, 0.93035, 0.9345333333333333, 0.9392, 0.9429833333333333, 0.9468, 
0.9507666666666666, 0.9537833333333333, 0.9559333333333333, 0.9582666666666667, 0.9615, 0.9625, 0.96465, 0.96645, 0.9681333333333333, 0.9686333333333333, 0.9702333333333333, 0.9708166666666667, 0.9718, 0.97205, 0.9737, 0.9737833333333333, 0.9746333333333334, 0.9753833333333334, 0.97585]}, "performance": {"with_svm@-3": 0.9048, "with_svm@-2": 0.9104}},{"history": {"val_loss": [0.46697981705665587, 0.3832242013454437, 0.303250941491127, 0.28147033443450925, 0.2645062158644199, 0.30668587293624877, 0.28016208398342135, 0.2843065684996545, 0.27898327768445014, 0.2912071517586708, 0.31837791128754617, 0.3626342387139797, 0.31595948444008826, 0.36499525430202484, 0.39556382931470874, 0.3669290903389454, 0.3701426920939237, 0.3836698907200247, 0.42279923573583367, 0.48294994099140165, 0.4073905217766762, 0.5648117405116558, 0.5509298311325721, 0.52655683157444, 0.5822678367150947, 0.5299802899837494, 0.48815737073849885, 0.5766397311203182, 0.5554989251375199, 0.5038782718300819], "val_acc": [0.8383, 0.8641, 0.8908, 0.8942, 0.9051, 0.8855, 0.9032, 0.9099, 0.9059, 0.9062, 0.9071, 0.9045, 0.9099, 0.9064, 0.9019, 0.9069, 0.9125, 0.9094, 0.9087, 0.9053, 0.8955, 0.9033, 0.9056, 0.9079, 0.9121, 0.9115, 0.9071, 0.9091, 0.9092, 0.8997], "loss": [0.5585080780824025, 0.3369152142286301, 0.28359521205425264, 0.25012954001426696, 0.22671767033338547, 0.20606048639615376, 0.18930058318376541, 0.17374658589363098, 0.1618980585793654, 0.14980819566051165, 0.13848809727629025, 0.12958743358751137, 0.1217845716059208, 0.11529610383758943, 0.10612352059980233, 0.10063879103461901, 0.09551240612169107, 0.09058071209937334, 0.0849770180746913, 0.0819070641172429, 0.08037483161017299, 0.07584481724972526, 0.07303781550178925, 0.07187192963063717, 0.07019264071881771, 0.06639131543164452, 0.06384243606943638, 0.06306189757228518, 0.06210263610463589, 0.06092731468016282], "acc": [0.79055, 0.8768666666666667, 0.8955166666666666, 0.9082, 0.9159333333333334, 0.9237166666666666, 0.9308, 
0.9363333333333334, 0.9404166666666667, 0.9446333333333333, 0.94825, 0.9519666666666666, 0.95445, 0.9571833333333334, 0.9613333333333334, 0.9628833333333333, 0.9644166666666667, 0.9666166666666667, 0.9682, 0.9696333333333333, 0.9715333333333334, 0.97235, 0.9737333333333333, 0.9744333333333334, 0.9748833333333333, 0.9766, 0.9770166666666666, 0.9782, 0.9781666666666666, 0.9792666666666666]}, "performance": {"with_svm@-3": 0.8975, "with_svm@-2": 0.9038}}] -------------------------------------------------------------------------------- /experiments/results/MNIST/performance_30epochs_combined.json: -------------------------------------------------------------------------------- 1 | [{"history": {"val_loss": [0.058012286397721616, 0.031465802982915195, 0.02752863751942932, 0.04448511460040463, 0.024455885998660234, 0.03770310556006152, 0.0302407726525661, 0.040125972307659685, 0.02950342787828249, 0.034729696841456874, 0.055798425468014376, 0.04136685403668739, 0.03716646059967136, 0.06317648815578528, 0.04977548953918097, 0.04911087775120434, 0.044314033737478914, 0.05617084128602239, 0.05534849855535915, 0.06092766595187452, 0.06630350116573029, 0.06982053399169147, 0.07034379402384695, 0.05951280631003713, 0.08072586988518408, 0.07859032812471571, 0.06860585440432163, 0.060573564682116304, 0.06489135026952284, 0.07252224225466455], "val_acc": [0.981, 0.9902, 0.9916, 0.9847, 0.9933, 0.9897, 0.9917, 0.9891, 0.9929, 0.9925, 0.9906, 0.9919, 0.993, 0.9891, 0.993, 0.993, 0.9933, 0.9916, 0.9923, 0.9923, 0.9922, 0.9919, 0.9915, 0.993, 0.9917, 0.9913, 0.9933, 0.9935, 0.9933, 0.9923], "loss": [0.17301000547260045, 0.04833274687674517, 0.03296006406117231, 0.026375785614883837, 0.020759016964376983, 0.016109659135122394, 0.013145960034359087, 0.01125916710313201, 0.009907794772488463, 0.00860547799981235, 0.006869320600131293, 0.006863135235392868, 0.006310418321777403, 0.005953495336345638, 0.004055292068486461, 0.004956505267661699, 0.00472531110749945, 0.004115091456871733, 
0.0036294369941880632, 0.004000084901630392, 0.0034543862089507003, 0.0031846033174160616, 0.002197130242962362, 0.003921002310771739, 0.00240644144151604, 0.0034369908087196715, 0.003028492050913519, 0.002963888820605924, 0.0024517073052008377, 0.002116798436639715], "acc": [0.9474333333333333, 0.9853, 0.9893666666666666, 0.9918833333333333, 0.99345, 0.9950833333333333, 0.9959833333333333, 0.99655, 0.9971833333333333, 0.9971666666666666, 0.99795, 0.9981833333333333, 0.9981666666666666, 0.9982666666666666, 0.9986666666666667, 0.9985666666666667, 0.99885, 0.9989166666666667, 0.9990333333333333, 0.9990833333333333, 0.9991166666666667, 0.99925, 0.9993166666666666, 0.9992833333333333, 0.9993333333333333, 0.9991666666666666, 0.9993166666666666, 0.99935, 0.99945, 0.9995333333333334]}, "performance": {"with_svm@-3": 0.9942, "with_svm@-2": 0.9938}}, 2 | {"history": {"val_loss": [0.04501374350236729, 0.03333432343536988, 0.02289775217374554, 0.031247786466928665, 0.024255429341574198, 0.02684926290004114, 0.03141559111404458, 0.028666918717404406, 0.04446157635115014, 0.028694591341076922, 0.03957730199245957, 0.04202645219727212, 0.040510537447208415, 0.0429101706364025, 0.04585446695649855, 0.0456064468267775, 0.044591218002707614, 0.058358724957093475, 0.05272812839602186, 0.05444539611888879, 0.05924669900309011, 0.05979339384530924, 0.06361616533229401, 0.053325072051109715, 0.05930827118715099, 0.05690898257817064, 0.06264420522789153, 0.07012136972692072, 0.06826655435531541, 0.07514357501242923], "val_acc": [0.9844, 0.9881, 0.9916, 0.9893, 0.993, 0.9919, 0.9919, 0.992, 0.9911, 0.9935, 0.9923, 0.9938, 0.9933, 0.9931, 0.9928, 0.9922, 0.9934, 0.9924, 0.9929, 0.9931, 0.9925, 0.9922, 0.9928, 0.9935, 0.9935, 0.9937, 0.9932, 0.9922, 0.9918, 0.9914], "loss": [0.17729154649289947, 0.04786290728921692, 0.03310972272466558, 0.024000238360309352, 0.019628610597027, 0.015570734798592943, 0.012905735168078535, 0.01112630117221476, 0.009569634471618307, 0.008479778113915093, 
0.007240269398996497, 0.005718392407475357, 0.005387276505376697, 0.005427486746183755, 0.00405911881381312, 0.004863275600015732, 0.0036052914443208163, 0.004098738280737507, 0.0034470039245908233, 0.002539891782689559, 0.003724600748225195, 0.003396888139204172, 0.0033139408396933127, 0.0022997476812821713, 0.002101743097013415, 0.0023254889946794643, 0.0026494044639331226, 0.001954368177306363, 0.0031213346675570846, 0.002741553527614049], "acc": [0.9444833333333333, 0.9851166666666666, 0.9901333333333333, 0.9928666666666667, 0.9939666666666667, 0.9954833333333334, 0.9959333333333333, 0.99675, 0.99715, 0.9973833333333333, 0.99775, 0.99835, 0.9984666666666666, 0.9985833333333334, 0.9989166666666667, 0.9988666666666667, 0.999, 0.9988166666666667, 0.9991666666666666, 0.9994, 0.9989666666666667, 0.9991166666666667, 0.9991666666666666, 0.9993833333333333, 0.9995333333333334, 0.9994666666666666, 0.9993666666666666, 0.99965, 0.9994, 0.9993666666666666]}, "performance": {"with_svm@-3": 0.9943, "with_svm@-2": 0.9932}}, 3 | {"history": {"val_loss": [0.04693662054911256, 0.040643049051705744, 0.03836872058830922, 0.025879924827218202, 0.03365807937731297, 0.05606620579953305, 0.03101422648470616, 0.033968908638895845, 0.03749096577653436, 0.03477393601403474, 0.03989164399835893, 0.04330464110312171, 0.052758552326085666, 0.046145228933404746, 0.05790804634753588, 0.039930098790110356, 0.05093961141671082, 0.05108419878352672, 0.06115888950322489, 0.05455096907358804, 0.0614969135381117, 0.05646109203540775, 0.06690327637175754, 0.05403370705581249, 0.05814710223793893, 0.06744184541873186, 0.07847057491659128, 0.07086941969375286, 0.06166929620088142, 0.06629630808312378], "val_acc": [0.9848, 0.9878, 0.9886, 0.9912, 0.9903, 0.9857, 0.9917, 0.9911, 0.9904, 0.9927, 0.9917, 0.9919, 0.9917, 0.991, 0.9908, 0.9943, 0.992, 0.9926, 0.9918, 0.993, 0.9914, 0.9927, 0.9924, 0.9931, 0.9926, 0.993, 0.9908, 0.9917, 0.9937, 0.9926], "loss": [0.17716153866623838, 0.04742282863749812, 
0.03301722237803042, 0.025223383712318415, 0.02021720212032475, 0.01563415939128075, 0.013349200006359025, 0.011451028558874896, 0.009647474916346982, 0.007851272325570154, 0.00654441191043555, 0.006293635521197999, 0.00570135567472183, 0.004649824481886086, 0.0051237779484377445, 0.0044031188274873175, 0.0044619473003985905, 0.004122088544306773, 0.0037621076239143275, 0.003428344476895821, 0.003696741009939372, 0.0028606461014387377, 0.0029217197914563407, 0.00375240520837721, 0.0035898833562101116, 0.003996992303055117, 0.0024696943116979354, 0.0015655488069841491, 0.002664865174772005, 0.003095264735105055], "acc": [0.94555, 0.9851666666666666, 0.99, 0.9922, 0.9939, 0.9950333333333333, 0.9957, 0.9967, 0.9972666666666666, 0.9976666666666667, 0.9979, 0.9982333333333333, 0.9984166666666666, 0.9986, 0.9986333333333334, 0.9988666666666667, 0.9987833333333334, 0.9988333333333334, 0.99895, 0.9990333333333333, 0.9991833333333333, 0.9992, 0.9992833333333333, 0.99925, 0.9992333333333333, 0.9990833333333333, 0.9994166666666666, 0.9995833333333334, 0.9995333333333334, 0.99935]}, "performance": {"with_svm@-3": 0.994, "with_svm@-2": 0.9934}}, 4 | {"history": {"val_loss": [0.08197933812029659, 0.03546777663987596, 0.03341330313124054, 0.030251626034558284, 0.02884039005855011, 0.019837814035794873, 0.0317348777288682, 0.03626116072119066, 0.03655646794341883, 0.028666779191147088, 0.04297479794507308, 0.039568193257491224, 0.03202172394558131, 0.046409738939450426, 0.046686414847980924, 0.03982525034669272, 0.05124423405043328, 0.05239241906636428, 0.04661062939958031, 0.055866467449014294, 0.08729090600506854, 0.06305750775740036, 0.054992045577612614, 0.06343066945878484, 0.06362634397384556, 0.06779158965707477, 0.057760000991913536, 0.06725425281132523, 0.06973312734179865, 0.07662225833157338], "val_acc": [0.9743, 0.9881, 0.9894, 0.9909, 0.9913, 0.9933, 0.9915, 0.9914, 0.9912, 0.9932, 0.9917, 0.994, 0.9936, 0.9922, 0.9924, 0.9943, 0.9935, 0.9925, 0.9935, 0.9935, 0.9896, 
0.993, 0.9933, 0.9928, 0.9921, 0.9922, 0.993, 0.9924, 0.9916, 0.9921], "loss": [0.16861701912457744, 0.046583191910013554, 0.03097322490364313, 0.023580448029361045, 0.018423834095839024, 0.015051568441031123, 0.012091321246521208, 0.010600623359851655, 0.009061368080696427, 0.007479306346113405, 0.00671511196024926, 0.005556289483133651, 0.0056452080929059855, 0.004374737399550767, 0.00513432675843607, 0.0039756081365345, 0.003821543846153865, 0.003121435883364787, 0.004367182645120101, 0.0028328010872671863, 0.002666306109979875, 0.0049319935316656256, 0.003590493196360785, 0.002394134152485617, 0.0033287287435602064, 0.002699324307992515, 0.0029116885613174265, 0.004292509903183471, 0.00416508137487399, 0.0020753812079482865], "acc": [0.9476, 0.9852833333333333, 0.9903333333333333, 0.9926, 0.99435, 0.9954166666666666, 0.99635, 0.9968833333333333, 0.9975166666666667, 0.99785, 0.99805, 0.99825, 0.9984666666666666, 0.9988, 0.9987333333333334, 0.999, 0.99905, 0.9991833333333333, 0.9990666666666667, 0.9992833333333333, 0.9993333333333333, 0.99875, 0.99915, 0.9995666666666667, 0.9992166666666666, 0.9993833333333333, 0.9994166666666666, 0.9990666666666667, 0.9992666666666666, 0.99955]}, "performance": {"with_svm@-3": 0.9942, "with_svm@-2": 0.9939}}, 5 | {"history": {"val_loss": [0.04706196300601587, 0.031698632218583954, 0.03227365807255264, 0.02973660909422033, 0.03172424109021194, 0.029596417490038582, 0.024765829553188814, 0.03213707902151118, 0.04085539187522227, 0.03789611699151364, 0.056589442667745536, 0.05191770297432388, 0.042685455790007584, 0.04466887094765375, 0.059763871695658895, 0.05341054688095057, 0.04867746317094354, 0.046260142465698234, 0.05594641119233786, 0.05880422080311575, 0.05863813357189838, 0.0646812159224218, 0.07454232090098196, 0.06272837779143121, 0.0716926159349325, 0.060035764309630875, 0.07311587181051857, 0.07722717706860649, 0.06542106488027886, 0.07567073885974498], "val_acc": [0.9843, 0.9892, 0.9915, 0.9915, 0.9916, 0.9917, 
from keras.models import Model
from sklearn.svm import SVC
from keras.utils import to_categorical


class ModelSVMWrapper:
    """
    Wraps a Keras model so that its final (softmax) layers can be replaced
    by a Support Vector Machine trained on intermediate-layer features.

    # Arguments
        model: The underlying Keras model (e.g. a `Sequential` instance).
        svm: The Support Vector Machine to use. Defaults to a linear `SVC`.
    """

    def __init__(self, model, svm=None):
        self.model = model
        # Truncated feature-extractor network; populated by fit_svm().
        self.intermediate_model = None  # type: Model
        # Default to a linear SVM when none is supplied.
        self.svm = svm if svm is not None else SVC(kernel='linear')

    def add(self, layer):
        """Delegates to the wrapped model: appends `layer` to the stack."""
        return self.model.add(layer)

    def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.,
            validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0,
            steps_per_epoch=None, validation_steps=None, **kwargs):
        """
        Trains the underlying model for a fixed number of epochs, then
        (re)fits the SVM on the features produced by the split layer.

        `y` is expected as integer class labels: it is one-hot encoded with
        `to_categorical` before being fed to the network, while the raw
        labels are used to train the SVM.

        NOTE(review): labels inside `validation_data` are forwarded
        unchanged, so callers must one-hot encode them themselves — confirm
        this asymmetry is intended.

        # Arguments
            x: Numpy array of training data.
            y: Numpy array of integer target (label) data.
            batch_size: Integer or `None`. Number of samples per gradient
                update. If unspecified, Keras defaults to 32.
            epochs: Integer. Final epoch index to train until (together with
                `initial_epoch`), not a count of additional iterations.
            verbose: 0, 1, or 2. 0 = silent, 1 = progress bar,
                2 = one line per epoch.
            callbacks: List of `keras.callbacks.Callback` instances applied
                during training.
            validation_split: Float between 0 and 1. Fraction of the
                training data held out for validation (taken from the end
                of `x`/`y` before shuffling).
            validation_data: Tuple `(x_val, y_val)` or
                `(x_val, y_val, val_sample_weights)` evaluated at the end of
                each epoch; overrides `validation_split`.
            shuffle: Boolean (shuffle before each epoch) or `'batch'` for
                HDF5-friendly chunked shuffling. No effect when
                `steps_per_epoch` is not `None`.
            class_weight: Optional dict mapping class indices to loss
                weights (training only).
            sample_weight: Optional Numpy array of per-sample loss weights,
                flat (1D) or `(samples, sequence_length)` for temporal data
                (requires `sample_weight_mode="temporal"` in `compile()`).
            initial_epoch: Epoch at which to start training (useful for
                resuming a previous run).
            steps_per_epoch: Total number of batches before declaring one
                epoch finished.
            validation_steps: Only relevant if `steps_per_epoch` is
                specified; number of validation batches.

        # Returns
            A Keras `History` object whose `History.history` attribute
            records training (and, if applicable, validation) loss and
            metric values per epoch.

        # Raises
            RuntimeError: If the model was never compiled.
            ValueError: In case of mismatch between the provided input data
                and what the model expects.
        """
        # Keyword arguments guard against positional drift in Model.fit.
        history = self.model.fit(
            x, to_categorical(y),
            batch_size=batch_size, epochs=epochs, verbose=verbose,
            callbacks=callbacks, validation_split=validation_split,
            validation_data=validation_data, shuffle=shuffle,
            class_weight=class_weight, sample_weight=sample_weight,
            initial_epoch=initial_epoch, steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps, **kwargs)

        # Retrain the SVM on the freshly updated feature extractor.
        self.fit_svm(x, y, self.__get_split_layer())

        return history

    def fit_svm(self, x, y, split_layer):
        """
        Fits the SVM on features from `split_layer` without changing the
        underlying neural network.

        # Arguments
            x: Numpy array of training data.
            y: Numpy array of integer target (label) data.
            split_layer: The layer whose output feeds the SVM.
        """
        # Truncate the network at the split layer and cache the resulting
        # feature extractor for predict()/evaluate().
        self.intermediate_model = Model(inputs=self.model.input,
                                        outputs=split_layer.output)
        # Use the intermediate features to train the SVM.
        features = self.intermediate_model.predict(x)
        self.svm.fit(features, y)

    def evaluate(self, x=None, y=None, batch_size=None, verbose=1, steps=None):
        """
        Computes the classification accuracy of the network + SVM pipeline.

        # Arguments
            x: input data, as a Numpy array or list of Numpy arrays
                (if the model has multiple inputs).
            y: integer class labels, as a Numpy array.
            batch_size: Integer. If unspecified, it will default to 32.
            verbose: verbosity mode, 0 or 1.
            steps: Integer or `None`. Total number of batches before
                declaring the evaluation round finished. Ignored with the
                default value of `None`.

        # Returns
            Accuracy (fraction in [0, 1]) of the model mapping x to y.

        # Raises
            RuntimeError: if the model was never fit.
            ValueError: if `x` contains no samples.
        """
        if self.intermediate_model is None:
            # RuntimeError is a subclass of Exception, so existing
            # `except Exception` handlers keep working.
            raise RuntimeError("A model must be fit before running evaluate")

        predictions = self.predict(x, batch_size, verbose, steps)
        if len(predictions) == 0:
            # Guard the division below instead of raising ZeroDivisionError.
            raise ValueError("evaluate requires at least one sample")

        hits = sum(int(predicted == actual)
                   for predicted, actual in zip(predictions, y))
        return hits / len(predictions)

    def predict(self, x, batch_size=None, verbose=0, steps=None):
        """
        Predicts class labels: runs `x` through the truncated network and
        classifies the resulting intermediate features with the SVM.

        # Arguments
            x: input data, as a Numpy array or list of Numpy arrays
                (if the model has multiple inputs).
            batch_size: Integer. If unspecified, it will default to 32.
            verbose: verbosity mode, 0 or 1.
            steps: Integer or `None`. Total number of batches before
                declaring the prediction round finished. Ignored with the
                default value of `None`.

        # Returns
            Numpy array of predicted class labels.

        # Raises
            RuntimeError: if the model was never fit.
        """
        if self.intermediate_model is None:
            raise RuntimeError("A model must be fit before running predict")
        features = self.intermediate_model.predict(x, batch_size, verbose, steps)
        return self.svm.predict(features)

    def __get_split_layer(self):
        """
        Gets the layer to split on: either the layer named "split_layer" or,
        by default, the third-to-last layer.

        :return: The layer to split on: from where the svm must replace the
            existing model.
        :raises ValueError: If fewer than three layers exist, which leaves
            no meaningful place to split.
        """
        if len(self.model.layers) < 3:
            raise ValueError('self.layers too small for a relevant split')

        for layer in self.model.layers:
            if layer.name == "split_layer":
                return layer

        # NOTE(review): the fallback is layers[-3], which drops the last TWO
        # layers (e.g. the final dense AND the softmax head), not just the
        # softmax layer — confirm this default is intended.
        return self.model.layers[-3]
directly required: keras in /usr/local/lib/python3.6/dist-packages (from keras-svm) (2.1.6)\n", 47 | "Requirement not upgraded as not directly required: pyyaml in /usr/local/lib/python3.6/dist-packages (from keras->keras-svm) (3.12)\n", 48 | "Requirement not upgraded as not directly required: h5py in /usr/local/lib/python3.6/dist-packages (from keras->keras-svm) (2.7.1)\n", 49 | "Requirement not upgraded as not directly required: six>=1.9.0 in /usr/local/lib/python3.6/dist-packages (from keras->keras-svm) (1.11.0)\n", 50 | "Requirement not upgraded as not directly required: numpy>=1.9.1 in /usr/local/lib/python3.6/dist-packages (from keras->keras-svm) (1.14.2)\n", 51 | "Requirement not upgraded as not directly required: scipy>=0.14 in /usr/local/lib/python3.6/dist-packages (from keras->keras-svm) (0.19.1)\n" 52 | ] 53 | } 54 | ], 55 | "source": [ 56 | "# I've created a pip package containing a wrapper for the model.\n", 57 | "!pip install --upgrade keras-svm" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": { 64 | "colab": { 65 | "autoexec": { 66 | "startup": false, 67 | "wait_interval": 0 68 | } 69 | }, 70 | "colab_type": "code", 71 | "id": "AyMkuF7nsips" 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "import keras\n", 76 | "from keras_svm.model_svm_wrapper import ModelSVMWrapper\n", 77 | "from keras import layers, models, backend\n", 78 | "from keras.datasets import mnist, fashion_mnist\n", 79 | "from keras.utils import to_categorical\n", 80 | "from keras.models import Model\n", 81 | "from keras.engine.topology import Layer\n", 82 | "import matplotlib.pyplot as plt\n", 83 | "from google.colab import files\n", 84 | "import pickle\n", 85 | "import json" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": { 92 | "colab": { 93 | "autoexec": { 94 | "startup": false, 95 | "wait_interval": 0 96 | } 97 | }, 98 | "colab_type": "code", 99 | "id": "LGqkiYtkCPnY" 100 | }, 101 | 
"outputs": [], 102 | "source": [ 103 | "(train_images, train_labels), (test_images, test_labels) = \\\n", 104 | " fashion_mnist.load_data() \n", 105 | " # or mnist.load_data()\n", 106 | "\n", 107 | "train_images = train_images.reshape((60000, 28, 28, 1))\n", 108 | "train_images = train_images.astype('float32') / 255\n", 109 | "\n", 110 | "test_images = test_images.reshape((10000, 28, 28, 1))\n", 111 | "test_images = test_images.astype('float32') / 255" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": { 117 | "colab_type": "text", 118 | "id": "V7oJ5coGo7Oe" 119 | }, 120 | "source": [ 121 | "# Build a generic CNN\n", 122 | "\n", 123 | "----\n", 124 | "(based on https://github.com/fchollet/deep-learning-with-python-notebooks)\n", 125 | "\n", 126 | "Creating a simple CNN with three convolutional layers." 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": { 133 | "colab": { 134 | "autoexec": { 135 | "startup": false, 136 | "wait_interval": 0 137 | } 138 | }, 139 | "colab_type": "code", 140 | "id": "0wmrpcg2o7Oe" 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "def build_model():\n", 145 | " model = models.Sequential()\n", 146 | " model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))\n", 147 | " model.add(layers.MaxPooling2D((2, 2)))\n", 148 | " model.add(layers.Conv2D(64, (3, 3), activation='relu'))\n", 149 | " model.add(layers.MaxPooling2D((2, 2)))\n", 150 | " model.add(layers.Conv2D(64, (3, 3), activation='relu'))\n", 151 | " model.add(layers.Flatten())\n", 152 | " \n", 153 | " model.add(layers.Dense(64, activation='relu'))\n", 154 | " model.add(layers.Dense(10, activation='softmax'))\n", 155 | " model.compile(optimizer='rmsprop',\n", 156 | " loss='categorical_crossentropy',\n", 157 | " metrics=['accuracy'])\n", 158 | " return model\n", 159 | "\n", 160 | "wrapper = ModelSVMWrapper(build_model())" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | 
"metadata": { 166 | "colab_type": "text", 167 | "id": "d6LQYxH4o7Oh" 168 | }, 169 | "source": [ 170 | "Let's display the architecture of our convnet so far:" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 21, 176 | "metadata": { 177 | "colab": { 178 | "autoexec": { 179 | "startup": false, 180 | "wait_interval": 0 181 | }, 182 | "base_uri": "https://localhost:8080/", 183 | "height": 408 184 | }, 185 | "colab_type": "code", 186 | "executionInfo": { 187 | "elapsed": 417, 188 | "status": "ok", 189 | "timestamp": 1525036629639, 190 | "user": { 191 | "displayName": "Daan Luttik", 192 | "photoUrl": "//lh3.googleusercontent.com/-CPv5nanSWKo/AAAAAAAAAAI/AAAAAAAAE3I/gkG30GZ_TJs/s50-c-k-no/photo.jpg", 193 | "userId": "106328974539959585216" 194 | }, 195 | "user_tz": -120 196 | }, 197 | "id": "7UaZvzRyzUIo", 198 | "outputId": "b2f985ce-274b-4235-fed2-9ba088be73de" 199 | }, 200 | "outputs": [ 201 | { 202 | "name": "stdout", 203 | "output_type": "stream", 204 | "text": [ 205 | "_________________________________________________________________\n", 206 | "Layer (type) Output Shape Param # \n", 207 | "=================================================================\n", 208 | "conv2d_19 (Conv2D) (None, 26, 26, 32) 320 \n", 209 | "_________________________________________________________________\n", 210 | "max_pooling2d_13 (MaxPooling (None, 13, 13, 32) 0 \n", 211 | "_________________________________________________________________\n", 212 | "conv2d_20 (Conv2D) (None, 11, 11, 64) 18496 \n", 213 | "_________________________________________________________________\n", 214 | "max_pooling2d_14 (MaxPooling (None, 5, 5, 64) 0 \n", 215 | "_________________________________________________________________\n", 216 | "conv2d_21 (Conv2D) (None, 3, 3, 64) 36928 \n", 217 | "_________________________________________________________________\n", 218 | "flatten_7 (Flatten) (None, 576) 0 \n", 219 | "_________________________________________________________________\n", 220 
| "dense_13 (Dense) (None, 64) 36928 \n", 221 | "_________________________________________________________________\n", 222 | "dense_14 (Dense) (None, 10) 650 \n", 223 | "=================================================================\n", 224 | "Total params: 93,322\n", 225 | "Trainable params: 93,322\n", 226 | "Non-trainable params: 0\n", 227 | "_________________________________________________________________\n" 228 | ] 229 | } 230 | ], 231 | "source": [ 232 | "wrapper.model.summary()" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "metadata": { 238 | "colab_type": "text", 239 | "id": "XRy9LUEio7Os" 240 | }, 241 | "source": [ 242 | "Prepare the data for the CNN" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "metadata": { 248 | "colab_type": "text", 249 | "id": "9F7NHlli9nq6" 250 | }, 251 | "source": [ 252 | "Train model and store intermediate test results" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 22, 258 | "metadata": { 259 | "colab": { 260 | "autoexec": { 261 | "startup": false, 262 | "wait_interval": 0 263 | }, 264 | "base_uri": "https://localhost:8080/", 265 | "height": 1244 266 | }, 267 | "colab_type": "code", 268 | "executionInfo": { 269 | "elapsed": 3310064, 270 | "status": "ok", 271 | "timestamp": 1525039939835, 272 | "user": { 273 | "displayName": "Daan Luttik", 274 | "photoUrl": "//lh3.googleusercontent.com/-CPv5nanSWKo/AAAAAAAAAAI/AAAAAAAAE3I/gkG30GZ_TJs/s50-c-k-no/photo.jpg", 275 | "userId": "106328974539959585216" 276 | }, 277 | "user_tz": -120 278 | }, 279 | "id": "UyLwib_Jo7Ow", 280 | "outputId": "4ccc51a5-f194-474a-cad9-be204e8982d6" 281 | }, 282 | "outputs": [ 283 | { 284 | "name": "stdout", 285 | "output_type": "stream", 286 | "text": [ 287 | "Starting epoch: 1\n", 288 | "Epoch 1/1\n", 289 | "60000/60000 [==============================] - 12s 195us/step - loss: 0.5336 - acc: 0.8041\n", 290 | "10000/10000 [==============================] - 1s 88us/step\n", 291 | "10000/10000 
[==============================] - 1s 131us/step\n", 292 | "10000/10000 [==============================] - 1s 94us/step\n", 293 | "{'with_svm@-2': [0.8876], 'with_svm@-3': [0.8999], 'without_svm': [0.8667]}\n", 294 | "Starting epoch: 2\n", 295 | "Epoch 1/1\n", 296 | "38912/60000 [==================>...........] - ETA: 3s - loss: 0.3295 - acc: 0.880160000/60000 [==============================] - 11s 187us/step - loss: 0.3198 - acc: 0.8834\n", 297 | "10000/10000 [==============================] - 1s 91us/step\n", 298 | "10000/10000 [==============================] - 1s 123us/step\n", 299 | "10000/10000 [==============================] - 1s 91us/step\n", 300 | "{'with_svm@-2': [0.8876, 0.8992], 'with_svm@-3': [0.8999, 0.9046], 'without_svm': [0.8667, 0.8758]}\n", 301 | "Starting epoch: 3\n", 302 | "Epoch 1/1\n", 303 | "51008/60000 [========================>.....] - ETA: 1s - loss: 0.2726 - acc: 0.900660000/60000 [==============================] - 11s 187us/step - loss: 0.2725 - acc: 0.9010\n", 304 | "10000/10000 [==============================] - 1s 89us/step\n", 305 | "10000/10000 [==============================] - 1s 125us/step\n", 306 | "10000/10000 [==============================] - 1s 93us/step\n", 307 | "{'with_svm@-2': [0.8876, 0.8992, 0.9077], 'with_svm@-3': [0.8999, 0.9046, 0.9097], 'without_svm': [0.8667, 0.8758, 0.8939]}\n", 308 | "Starting epoch: 4\n", 309 | "Epoch 1/1\n", 310 | "50432/60000 [========================>.....] 
- ETA: 1s - loss: 0.2437 - acc: 0.911760000/60000 [==============================] - 11s 187us/step - loss: 0.2421 - acc: 0.9121\n", 311 | "10000/10000 [==============================] - 1s 95us/step\n", 312 | "10000/10000 [==============================] - 1s 118us/step\n", 313 | "10000/10000 [==============================] - 1s 94us/step\n", 314 | "{'with_svm@-2': [0.8876, 0.8992, 0.9077, 0.912], 'with_svm@-3': [0.8999, 0.9046, 0.9097, 0.9095], 'without_svm': [0.8667, 0.8758, 0.8939, 0.9006]}\n", 315 | "Starting epoch: 5\n", 316 | "Epoch 1/1\n", 317 | "47744/60000 [======================>.......] - ETA: 2s - loss: 0.2189 - acc: 0.919460000/60000 [==============================] - 11s 188us/step - loss: 0.2188 - acc: 0.9189\n", 318 | "10000/10000 [==============================] - 1s 89us/step\n", 319 | "10000/10000 [==============================] - 1s 121us/step\n", 320 | "10000/10000 [==============================] - 1s 95us/step\n", 321 | "{'with_svm@-2': [0.8876, 0.8992, 0.9077, 0.912, 0.9115], 'with_svm@-3': [0.8999, 0.9046, 0.9097, 0.9095, 0.9111], 'without_svm': [0.8667, 0.8758, 0.8939, 0.9006, 0.903]}\n", 322 | "Starting epoch: 6\n", 323 | "Epoch 1/1\n", 324 | "44928/60000 [=====================>........] - ETA: 2s - loss: 0.2001 - acc: 0.926360000/60000 [==============================] - 11s 189us/step - loss: 0.2015 - acc: 0.9258\n", 325 | "10000/10000 [==============================] - 1s 90us/step\n", 326 | "10000/10000 [==============================] - 1s 123us/step\n", 327 | "10000/10000 [==============================] - 1s 96us/step\n", 328 | "{'with_svm@-2': [0.8876, 0.8992, 0.9077, 0.912, 0.9115, 0.9109], 'with_svm@-3': [0.8999, 0.9046, 0.9097, 0.9095, 0.9111, 0.9115], 'without_svm': [0.8667, 0.8758, 0.8939, 0.9006, 0.903, 0.8985]}\n", 329 | "Starting epoch: 7\n", 330 | "Epoch 1/1\n", 331 | "41600/60000 [===================>..........] 
- ETA: 3s - loss: 0.1851 - acc: 0.932860000/60000 [==============================] - 12s 195us/step - loss: 0.1851 - acc: 0.9325\n", 332 | "10000/10000 [==============================] - 1s 96us/step\n", 333 | "10000/10000 [==============================] - 1s 120us/step\n", 334 | "10000/10000 [==============================] - 1s 96us/step\n", 335 | "{'with_svm@-2': [0.8876, 0.8992, 0.9077, 0.912, 0.9115, 0.9109, 0.9122], 'with_svm@-3': [0.8999, 0.9046, 0.9097, 0.9095, 0.9111, 0.9115, 0.9116], 'without_svm': [0.8667, 0.8758, 0.8939, 0.9006, 0.903, 0.8985, 0.9041]}\n", 336 | "Starting epoch: 8\n", 337 | "Epoch 1/1\n", 338 | "40576/60000 [===================>..........] - ETA: 3s - loss: 0.1695 - acc: 0.938060000/60000 [==============================] - 11s 190us/step - loss: 0.1727 - acc: 0.9371\n", 339 | "10000/10000 [==============================] - 1s 91us/step\n", 340 | "10000/10000 [==============================] - 1s 123us/step\n", 341 | "10000/10000 [==============================] - 1s 97us/step\n", 342 | "{'with_svm@-2': [0.8876, 0.8992, 0.9077, 0.912, 0.9115, 0.9109, 0.9122, 0.9145], 'with_svm@-3': [0.8999, 0.9046, 0.9097, 0.9095, 0.9111, 0.9115, 0.9116, 0.912], 'without_svm': [0.8667, 0.8758, 0.8939, 0.9006, 0.903, 0.8985, 0.9041, 0.9062]}\n", 343 | "Starting epoch: 9\n", 344 | "Epoch 1/1\n", 345 | "38912/60000 [==================>...........] 
- ETA: 3s - loss: 0.1588 - acc: 0.940760000/60000 [==============================] - 11s 188us/step - loss: 0.1580 - acc: 0.9412\n", 346 | "10000/10000 [==============================] - 1s 91us/step\n", 347 | "10000/10000 [==============================] - 1s 121us/step\n", 348 | "10000/10000 [==============================] - 1s 95us/step\n", 349 | "{'with_svm@-2': [0.8876, 0.8992, 0.9077, 0.912, 0.9115, 0.9109, 0.9122, 0.9145, 0.9122], 'with_svm@-3': [0.8999, 0.9046, 0.9097, 0.9095, 0.9111, 0.9115, 0.9116, 0.912, 0.9121], 'without_svm': [0.8667, 0.8758, 0.8939, 0.9006, 0.903, 0.8985, 0.9041, 0.9062, 0.8882]}\n", 350 | "Starting epoch: 10\n", 351 | "Epoch 1/1\n", 352 | "37568/60000 [=================>............] - ETA: 4s - loss: 0.1453 - acc: 0.946360000/60000 [==============================] - 11s 187us/step - loss: 0.1467 - acc: 0.9462\n", 353 | "10000/10000 [==============================] - 1s 93us/step\n", 354 | "10000/10000 [==============================] - 1s 122us/step\n", 355 | "10000/10000 [==============================] - 1s 100us/step\n", 356 | "{'with_svm@-2': [0.8876, 0.8992, 0.9077, 0.912, 0.9115, 0.9109, 0.9122, 0.9145, 0.9122, 0.9123], 'with_svm@-3': [0.8999, 0.9046, 0.9097, 0.9095, 0.9111, 0.9115, 0.9116, 0.912, 0.9121, 0.9109], 'without_svm': [0.8667, 0.8758, 0.8939, 0.9006, 0.903, 0.8985, 0.9041, 0.9062, 0.8882, 0.9062]}\n", 357 | "performance4.json\n" 358 | ] 359 | } 360 | ], 361 | "source": [ 362 | "for j in range(5): \n", 363 | " wrapper = ModelSVMWrapper(build_model())\n", 364 | "\n", 365 | " epochs = 10\n", 366 | " performance = {\n", 367 | " \"with_svm@-2\": [],\n", 368 | " \"with_svm@-3\": [],\n", 369 | " \"without_svm\": []\n", 370 | " }\n", 371 | " for i in range(epochs):\n", 372 | " print('Starting epoch: {}'.format(i + 1))\n", 373 | " wrapper.fit(train_images, train_labels, epochs=1, batch_size=64)\n", 374 | " performance[\"with_svm@-3\"].append(wrapper.evaluate(test_images, test_labels))\n", 375 | " 
performance[\"without_svm\"].append(\n", 376 | " wrapper.model.evaluate(test_images, to_categorical(test_labels))[1])\n", 377 | "\n", 378 | " # Try it for the different SVM\n", 379 | " wrapper.fit_svm(train_images, train_labels, wrapper.model.layers[-2])\n", 380 | " performance[\"with_svm@-2\"].append(wrapper.evaluate(test_images, test_labels))\n", 381 | " \n", 382 | " print(performance)\n", 383 | " filename = 'performance{}.json'.format(j)\n", 384 | " print(filename)\n", 385 | " with open(filename, 'w') as file:\n", 386 | " json.dump(performance, file)" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": null, 392 | "metadata": { 393 | "colab": { 394 | "autoexec": { 395 | "startup": false, 396 | "wait_interval": 0 397 | } 398 | }, 399 | "colab_type": "code", 400 | "id": "GZFCd5D5dGPy" 401 | }, 402 | "outputs": [], 403 | "source": [ 404 | "for j in range(5): \n", 405 | " filename = 'performance{}.json'.format(j)\n", 406 | " files.download(filename)" 407 | ] 408 | } 409 | ], 410 | "metadata": { 411 | "accelerator": "GPU", 412 | "colab": { 413 | "collapsed_sections": [], 414 | "default_view": {}, 415 | "name": "Deep Learning Paper", 416 | "provenance": [], 417 | "toc_visible": true, 418 | "version": "0.3.2", 419 | "views": {} 420 | }, 421 | "kernelspec": { 422 | "display_name": "Python 3", 423 | "language": "python", 424 | "name": "python3" 425 | }, 426 | "language_info": { 427 | "codemirror_mode": { 428 | "name": "ipython", 429 | "version": 3 430 | }, 431 | "file_extension": ".py", 432 | "mimetype": "text/x-python", 433 | "name": "python", 434 | "nbconvert_exporter": "python", 435 | "pygments_lexer": "ipython3", 436 | "version": "3.6.4" 437 | } 438 | }, 439 | "nbformat": 4, 440 | "nbformat_minor": 2 441 | } 442 | --------------------------------------------------------------------------------