├── tests
│   ├── __init__.py
│   └── fake_test.py
├── experiments
│   ├── results
│   │   ├── performance_30epochs_combined.json
│   │   ├── MNIST
│   │   │   ├── accuracy-mnist.png
│   │   │   ├── accuracy-30epoch-mnist.png
│   │   │   ├── performance-30epochs-mnist.png
│   │   │   ├── performance_combined.json
│   │   │   └── performance_30epochs_combined.json
│   │   └── Fashion-MNIST
│   │       ├── accuracy-fashion-mnist.png
│   │       ├── accuracy-30epoch-fashion-mnist.png
│   │       ├── performance-30epochs-fashion.png
│   │       ├── performance_combined.json
│   │       └── performance_30epochs_combined.json
│   ├── combine_results.ipynb
│   ├── 30epochresults.ipynb
│   └── deep_learning_paper.ipynb
├── keras_svm
│   ├── __init__.py
│   └── model_svm_wrapper.py
├── .gitignore
├── MANIFEST.in
├── send.bat
├── setup.cfg
├── tox.ini
├── LICENSE
├── setup.py
└── README.md
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiments/results/performance_30epochs_combined.json:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/fake_test.py:
--------------------------------------------------------------------------------
1 | def test_success():
2 |     assert True
3 |
--------------------------------------------------------------------------------
/keras_svm/__init__.py:
--------------------------------------------------------------------------------
1 | from .model_svm_wrapper import ModelSVMWrapper
2 |
3 | __all__ = ["ModelSVMWrapper"]
4 |
--------------------------------------------------------------------------------
/experiments/results/MNIST/accuracy-mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luttik/keras_svm/HEAD/experiments/results/MNIST/accuracy-mnist.png
--------------------------------------------------------------------------------
/experiments/results/MNIST/accuracy-30epoch-mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luttik/keras_svm/HEAD/experiments/results/MNIST/accuracy-30epoch-mnist.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | env/
3 | .tox/
4 | *.egg-info/
5 | dist/
6 | build/
7 | .pytest_cache/
8 | **/__pycache__/
9 | *.pyc
10 | **/.ipynb_checkpoints/
11 |
--------------------------------------------------------------------------------
/experiments/results/MNIST/performance-30epochs-mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luttik/keras_svm/HEAD/experiments/results/MNIST/performance-30epochs-mnist.png
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | # Include the README
2 | include *.md
3 |
4 | # Include the data files
5 | recursive-include data *
6 |
7 | include *.py
8 | include *.txt
9 | include LICENSE
--------------------------------------------------------------------------------
/experiments/results/Fashion-MNIST/accuracy-fashion-mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luttik/keras_svm/HEAD/experiments/results/Fashion-MNIST/accuracy-fashion-mnist.png
--------------------------------------------------------------------------------
/experiments/results/Fashion-MNIST/accuracy-30epoch-fashion-mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luttik/keras_svm/HEAD/experiments/results/Fashion-MNIST/accuracy-30epoch-fashion-mnist.png
--------------------------------------------------------------------------------
/experiments/results/Fashion-MNIST/performance-30epochs-fashion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luttik/keras_svm/HEAD/experiments/results/Fashion-MNIST/performance-30epochs-fashion.png
--------------------------------------------------------------------------------
/send.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | SET /P variable="Did you remember to run tox?"
3 |
4 | if %variable%==Y set variable=y
5 | if %variable%==y (
6 |     rm -r dist/*
7 |     python setup.py bdist_wheel --universal
8 |     python setup.py sdist
9 |     twine upload dist/*
10 | ) else (
11 |     echo upload cancelled
12 | )
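13 |
14 | REM Note (added for clarity): building and uploading require the `wheel` and `twine`
15 | REM packages (pip install wheel twine), and the `rm` call above assumes a Unix-style
16 | REM rm is available on PATH (for example via Git for Windows).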
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | # This includes the license file in the wheel.
3 | license_file = LICENSE
4 | [bdist_wheel]
5 | # This flag says to generate wheels that support both Python 2 and Python
6 | # 3. If your code will not run unchanged on both Python 2 and 3, you will
7 | # need to generate separate wheels for each Python version that you
8 | # support.
9 | universal=1
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | # Use:
2 | #
3 | # - check-manifest
4 | # confirm items checked into vcs are in your sdist
5 | # - python setup.py check (using the readme_renderer extension)
6 | # confirms your long_description will render correctly on pypi
7 | #
8 | # and also to help confirm pull requests to this project.
9 |
10 | [tox]
11 | envlist = py{27,36}
12 |
13 | [testenv]
14 | basepython =
15 |     py27: python2.7
16 |     py34: python3.4
17 |     py35: python3.5
18 |     py36: python3.6
19 | deps =
20 |     docutils
21 |     check-manifest
22 |     ; readme_renderer
23 |     flake8
24 |     pytest
25 | commands =
26 |     check-manifest --ignore tox.ini,tests*,*.bat
27 |     python setup.py check -m -r -s
28 |     flake8 .
29 |     py.test tests
30 | [flake8]
31 |
32 | ignore = E501
33 | exclude = .tox,*.egg,build,data,env,*.bat
34 | select = E,W,F
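35 |
36 | # To reproduce these checks locally, install tox (e.g. `pip install tox`) and run
37 | # `tox` from the repository root: it builds the py27 and py36 environments from
38 | # `envlist` above and runs check-manifest, the setup.py check, flake8 and pytest.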
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Daan Luttik
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 |     name='keras_svm',
5 |     version='1.0.0b10',
6 |     description='A model to use keras models with Support Vector Machines',
7 |     url='https://github.com/Luttik/keras_svm/tree/master',  # Optional
8 |     author='Daan Luttik',  # Optional
9 |     author_email='d.t.luttik@gmail.com',  # Optional
10 |     license='MIT',
11 |     classifiers=[
12 |         'Development Status :: 4 - Beta',
13 |         'Intended Audience :: Developers',
14 |         'Intended Audience :: Education',
15 |         'Intended Audience :: Science/Research',
16 |         'License :: OSI Approved :: MIT License',
17 |         'Programming Language :: Python :: 2',
18 |         'Programming Language :: Python :: 2.7',
19 |         'Programming Language :: Python :: 3',
20 |         'Programming Language :: Python :: 3.4',
21 |         'Programming Language :: Python :: 3.5',
22 |         'Programming Language :: Python :: 3.6',
23 |         'Topic :: Software Development :: Libraries',
24 |         'Topic :: Software Development :: Libraries :: Python Modules'
25 |     ],
26 |     keywords='keras sklearn svm ml',  # Optional
27 |     packages=find_packages(exclude=['contrib', 'docs', 'tests']),  # Required
28 |     install_requires=['keras', 'scikit-learn'],  # Optional
29 | )
30 |
--------------------------------------------------------------------------------
/experiments/results/MNIST/performance_combined.json:
--------------------------------------------------------------------------------
1 | [{"with_svm@-2": [0.9898, 0.9912, 0.9929, 0.9926, 0.9925, 0.9922, 0.9924, 0.9928, 0.9919, 0.9926], "with_svm@-3": [0.9912, 0.9916, 0.9926, 0.9921, 0.9925, 0.9931, 0.9931, 0.9931, 0.9935, 0.9927], "without_svm": [0.9873, 0.9897, 0.9905, 0.99, 0.9925, 0.9932, 0.9901, 0.9912, 0.99, 0.9907]},
2 | {"with_svm@-2": [0.9891, 0.9902, 0.9918, 0.9927, 0.9926, 0.9935, 0.993, 0.9939, 0.9927, 0.9933], "with_svm@-3": [0.9912, 0.9917, 0.9918, 0.9935, 0.9926, 0.9933, 0.9936, 0.9932, 0.9935, 0.9932], "without_svm": [0.9836, 0.9847, 0.9909, 0.9908, 0.9906, 0.9898, 0.9922, 0.9905, 0.9925, 0.9906]},
3 | {"with_svm@-2": [0.9893, 0.9905, 0.9921, 0.9929, 0.9932, 0.9935, 0.9935, 0.9935, 0.9921, 0.9928], "with_svm@-3": [0.9906, 0.9909, 0.992, 0.9934, 0.9929, 0.9938, 0.9927, 0.9935, 0.9934, 0.9937], "without_svm": [0.9823, 0.9894, 0.9916, 0.9908, 0.9906, 0.9929, 0.9934, 0.9927, 0.9919, 0.9921]},
4 | {"with_svm@-2": [0.9897, 0.9912, 0.9919, 0.9923, 0.9913, 0.9923, 0.992, 0.9921, 0.9931, 0.9927], "with_svm@-3": [0.9907, 0.992, 0.9929, 0.9922, 0.9926, 0.993, 0.9927, 0.9931, 0.993, 0.9931], "without_svm": [0.9849, 0.9895, 0.9866, 0.9896, 0.9916, 0.9915, 0.991, 0.9928, 0.9924, 0.9917]},
5 | {"with_svm@-2": [0.9891, 0.9905, 0.9916, 0.9924, 0.992, 0.9927, 0.9923, 0.9935, 0.9927, 0.9935], "with_svm@-3": [0.9903, 0.9917, 0.9927, 0.993, 0.9933, 0.9933, 0.9923, 0.9931, 0.9928, 0.9939], "without_svm": [0.9869, 0.9879, 0.9918, 0.991, 0.9902, 0.9891, 0.9921, 0.9924, 0.9924, 0.9921]}]
--------------------------------------------------------------------------------
/experiments/results/Fashion-MNIST/performance_combined.json:
--------------------------------------------------------------------------------
1 | [{"with_svm@-2": [0.8779, 0.8947, 0.9037, 0.9061, 0.911, 0.9137, 0.9087, 0.9128, 0.9109, 0.9116], "with_svm@-3": [0.8973, 0.9043, 0.905, 0.911, 0.9104, 0.9107, 0.9112, 0.9121, 0.9117, 0.9118], "without_svm": [0.8581, 0.8506, 0.8879, 0.8987, 0.9048, 0.9031, 0.9064, 0.9076, 0.8957, 0.9061]},
2 | {"with_svm@-2": [0.8799, 0.8943, 0.9013, 0.907, 0.9101, 0.9114, 0.9078, 0.9137, 0.9132, 0.9146], "with_svm@-3": [0.8957, 0.9004, 0.9042, 0.9052, 0.9103, 0.9109, 0.9109, 0.9112, 0.9114, 0.9128], "without_svm": [0.8444, 0.8788, 0.8809, 0.8927, 0.9043, 0.9091, 0.9024, 0.9059, 0.9057, 0.9053]},
3 | {"with_svm@-2": [0.8743, 0.8937, 0.902, 0.9039, 0.9108, 0.9113, 0.9075, 0.9107, 0.9103, 0.912], "with_svm@-3": [0.8929, 0.9007, 0.9035, 0.908, 0.9056, 0.9078, 0.9083, 0.9052, 0.909, 0.9072], "without_svm": [0.8517, 0.8701, 0.8958, 0.8759, 0.901, 0.9015, 0.9015, 0.8979, 0.9033, 0.9049]},
4 | {"with_svm@-2": [0.8794, 0.8946, 0.9012, 0.9044, 0.9095, 0.9138, 0.9124, 0.9127, 0.9126, 0.9128], "with_svm@-3": [0.8922, 0.901, 0.9026, 0.9046, 0.9059, 0.9102, 0.9081, 0.908, 0.9114, 0.9119], "without_svm": [0.8536, 0.8766, 0.8935, 0.8707, 0.8994, 0.9079, 0.9052, 0.9005, 0.909, 0.9099]},
5 | {"with_svm@-2": [0.8876, 0.8992, 0.9077, 0.912, 0.9115, 0.9109, 0.9122, 0.9145, 0.9122, 0.9123], "with_svm@-3": [0.8999, 0.9046, 0.9097, 0.9095, 0.9111, 0.9115, 0.9116, 0.912, 0.9121, 0.9109], "without_svm": [0.8667, 0.8758, 0.8939, 0.9006, 0.903, 0.8985, 0.9041, 0.9062, 0.8882, 0.9062]}]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Keras SVM
2 |
3 | [](https://pypi.org/project/keras-svm/)
4 | [](https://pypi.org/project/keras-svm/)
5 | [](https://github.com/Luttik/keras_svm/blob/master/LICENSE)
6 | [](https://pypi.org/project/keras-svm/)
7 |
8 | ## Purpose
9 | Provides a wrapper class that effectively replaces the softmax layer of your Keras model with an SVM.
10 |
11 | The SVM has no impact on the training of the neural network itself, but replacing the softmax layer with an SVM has been shown to perform better on unseen data.
12 |
13 | ## Code examples
14 | ### Example construction
15 | ```
16 | # Build a classical model
17 | def build_model():
18 |     model = models.Sequential()
19 |     model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
20 |     model.add(layers.MaxPooling2D((2, 2)))
21 |     model.add(layers.Conv2D(64, (3, 3), activation='relu'))
22 |     model.add(layers.MaxPooling2D((2, 2)))
23 |     model.add(layers.Conv2D(64, (3, 3), activation='relu'))
24 |     model.add(layers.Flatten(name="intermediate_output"))
25 |     model.add(layers.Dense(64, activation='relu'))
26 |     model.add(layers.Dense(10, activation='softmax'))
27 |
28 |     # The extra metric is important for the evaluate function
29 |     model.compile(optimizer='rmsprop',
30 |                   loss='categorical_crossentropy',
31 |                   metrics=['accuracy'])
32 |     return model
33 |
34 | # Wrap it in the ModelSVMWrapper
35 | wrapper = ModelSVMWrapper(build_model())
36 | ```
37 |
38 | ### Training while maintaining an accuracy score
39 | ```
40 | accuracy = {
41 |     "with_svm": [],
42 |     "without_svm": []
43 | }
44 |
45 | epochs = 10
46 | for i in range(epochs):
47 |     print('Starting run: {}'.format(i))
48 |     wrapper.fit(train_images, train_labels, epochs=1, batch_size=64)
49 |     accuracy["with_svm"].append(wrapper.evaluate(test_images, test_labels))
50 |     accuracy["without_svm"].append(
51 |         wrapper.model.evaluate(test_images, to_categorical(test_labels))[1])
52 | ```
53 |
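54 | ### Evaluation and prediction
55 |
56 | A minimal sketch of scoring the wrapped model after training, continuing from the snippet above (it assumes `wrapper`, `test_images` and `test_labels` are already defined). `wrapper.evaluate` compares the SVM's class predictions against the integer test labels, and `wrapper.predict` returns those predictions directly (see `keras_svm/model_svm_wrapper.py`).
57 |
58 | ```
59 | # Accuracy of the CNN + SVM combination on the held-out test set
60 | print(wrapper.evaluate(test_images, test_labels))
61 |
62 | # Class labels predicted by the SVM head
63 | predictions = wrapper.predict(test_images)
64 | ```
65 |
66 | By default the SVM is trained on the output of `model.layers[-3]`; to choose the split point explicitly, give that layer the name `"split_layer"`, e.g. `layers.Flatten(name="split_layer")`.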
--------------------------------------------------------------------------------
/experiments/combine_results.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 9,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import plotly\n",
10 | "import plotly.plotly as py\n",
11 | "import plotly.graph_objs as go\n",
12 | "import json\n",
13 | "import numpy as np"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 10,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "results_json = json.load(open('results/Fashion-MNIST/performance_combined.json'))"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 11,
28 | "metadata": {},
29 | "outputs": [
30 | {
31 | "data": {
32 | "text/plain": [
33 | "{'with_svm@-2': [[0.8779, 0.8799, 0.8743, 0.8794, 0.8876],\n",
34 | " [0.8947, 0.8943, 0.8937, 0.8946, 0.8992],\n",
35 | " [0.9037, 0.9013, 0.902, 0.9012, 0.9077],\n",
36 | " [0.9061, 0.907, 0.9039, 0.9044, 0.912],\n",
37 | " [0.911, 0.9101, 0.9108, 0.9095, 0.9115],\n",
38 | " [0.9137, 0.9114, 0.9113, 0.9138, 0.9109],\n",
39 | " [0.9087, 0.9078, 0.9075, 0.9124, 0.9122],\n",
40 | " [0.9128, 0.9137, 0.9107, 0.9127, 0.9145],\n",
41 | " [0.9109, 0.9132, 0.9103, 0.9126, 0.9122],\n",
42 | " [0.9116, 0.9146, 0.912, 0.9128, 0.9123]],\n",
43 | " 'with_svm@-3': [[0.8973, 0.8957, 0.8929, 0.8922, 0.8999],\n",
44 | " [0.9043, 0.9004, 0.9007, 0.901, 0.9046],\n",
45 | " [0.905, 0.9042, 0.9035, 0.9026, 0.9097],\n",
46 | " [0.911, 0.9052, 0.908, 0.9046, 0.9095],\n",
47 | " [0.9104, 0.9103, 0.9056, 0.9059, 0.9111],\n",
48 | " [0.9107, 0.9109, 0.9078, 0.9102, 0.9115],\n",
49 | " [0.9112, 0.9109, 0.9083, 0.9081, 0.9116],\n",
50 | " [0.9121, 0.9112, 0.9052, 0.908, 0.912],\n",
51 | " [0.9117, 0.9114, 0.909, 0.9114, 0.9121],\n",
52 | " [0.9118, 0.9128, 0.9072, 0.9119, 0.9109]],\n",
53 | " 'without_svm': [[0.8581, 0.8444, 0.8517, 0.8536, 0.8667],\n",
54 | " [0.8506, 0.8788, 0.8701, 0.8766, 0.8758],\n",
55 | " [0.8879, 0.8809, 0.8958, 0.8935, 0.8939],\n",
56 | " [0.8987, 0.8927, 0.8759, 0.8707, 0.9006],\n",
57 | " [0.9048, 0.9043, 0.901, 0.8994, 0.903],\n",
58 | " [0.9031, 0.9091, 0.9015, 0.9079, 0.8985],\n",
59 | " [0.9064, 0.9024, 0.9015, 0.9052, 0.9041],\n",
60 | " [0.9076, 0.9059, 0.8979, 0.9005, 0.9062],\n",
61 | " [0.8957, 0.9057, 0.9033, 0.909, 0.8882],\n",
62 | " [0.9061, 0.9053, 0.9049, 0.9099, 0.9062]]}"
63 | ]
64 | },
65 | "execution_count": 11,
66 | "metadata": {},
67 | "output_type": "execute_result"
68 | }
69 | ],
70 | "source": [
71 | "results = {key:[[results_json[j][key][i] for j in range(len(results_json))] for i in range(len(results_json[0][key]))]\n",
72 | " for key in results_json[0]}\n",
73 | "results"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 12,
79 | "metadata": {},
80 | "outputs": [
81 | {
82 | "data": {
83 | "text/html": [
84 | ""
85 | ],
86 | "text/plain": [
87 | ""
88 | ]
89 | },
90 | "execution_count": 12,
91 | "metadata": {},
92 | "output_type": "execute_result"
93 | }
94 | ],
95 | "source": [
96 | "plotly.tools.set_credentials_file(username='luttik', api_key='6p0E4ba6sVIgV575o3uX')\n",
97 | "\n",
98 | "data = [\n",
99 | " go.Scatter(\n",
100 | " x=list(range(10)),\n",
101 | " y=[np.mean(results[key][epoch]) for epoch in range(len(results[key]))],\n",
102 | " name=key,\n",
103 | " error_y=dict(\n",
104 | " type='data',\n",
105 | " array=[np.std(results[key][epoch]) for epoch in range(len(results[key]))],\n",
106 | " visible=True\n",
107 | " )\n",
108 | " )\n",
109 | " for key in results\n",
110 | "]\n",
111 | "\n",
112 | "layout = dict(title = 'Accuracy on the Fashion-MNIST dataset',\n",
113 | " xaxis = dict(title = 'Epoch'),\n",
114 | " yaxis = dict(title = 'Accuracy'),\n",
115 | " )\n",
116 | "\n",
117 | "py.iplot(data, filename='results-svm-replacement-fashion', layout=layout)"
118 | ]
119 | }
120 | ],
121 | "metadata": {
122 | "kernelspec": {
123 | "display_name": "Python 3",
124 | "language": "python",
125 | "name": "python3"
126 | },
127 | "language_info": {
128 | "codemirror_mode": {
129 | "name": "ipython",
130 | "version": 3
131 | },
132 | "file_extension": ".py",
133 | "mimetype": "text/x-python",
134 | "name": "python",
135 | "nbconvert_exporter": "python",
136 | "pygments_lexer": "ipython3",
137 | "version": "3.6.4"
138 | }
139 | },
140 | "nbformat": 4,
141 | "nbformat_minor": 2
142 | }
143 |
--------------------------------------------------------------------------------
/experiments/30epochresults.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import plotly\n",
10 | "import plotly.plotly as py\n",
11 | "import plotly.graph_objs as go\n",
12 | "import json\n",
13 | "import numpy as np"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 2,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "data = json.load(open('results/MNIST/performance_30epochs_combined.json'))\n",
23 | "history = [x['history'] for x in data]\n",
24 | "performances = [x['performance'] for x in data]"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 4,
30 | "metadata": {},
31 | "outputs": [
32 | {
33 | "data": {
34 | "text/plain": [
35 | "(array([[0.94743333, 0.94448333, 0.94555 , 0.9476 , 0.9471 ],\n",
36 | " [0.9853 , 0.98511667, 0.98516667, 0.98528333, 0.98583333],\n",
37 | " [0.98936667, 0.99013333, 0.99 , 0.99033333, 0.98983333],\n",
38 | " [0.99188333, 0.99286667, 0.9922 , 0.9926 , 0.99235 ],\n",
39 | " [0.99345 , 0.99396667, 0.9939 , 0.99435 , 0.9941 ],\n",
40 | " [0.99508333, 0.99548333, 0.99503333, 0.99541667, 0.99521667],\n",
41 | " [0.99598333, 0.99593333, 0.9957 , 0.99635 , 0.9963 ],\n",
42 | " [0.99655 , 0.99675 , 0.9967 , 0.99688333, 0.99675 ],\n",
43 | " [0.99718333, 0.99715 , 0.99726667, 0.99751667, 0.99725 ],\n",
44 | " [0.99716667, 0.99738333, 0.99766667, 0.99785 , 0.99778333],\n",
45 | " [0.99795 , 0.99775 , 0.9979 , 0.99805 , 0.99793333],\n",
46 | " [0.99818333, 0.99835 , 0.99823333, 0.99825 , 0.99841667],\n",
47 | " [0.99816667, 0.99846667, 0.99841667, 0.99846667, 0.99838333],\n",
48 | " [0.99826667, 0.99858333, 0.9986 , 0.9988 , 0.99871667],\n",
49 | " [0.99866667, 0.99891667, 0.99863333, 0.99873333, 0.99888333],\n",
50 | " [0.99856667, 0.99886667, 0.99886667, 0.999 , 0.99865 ],\n",
51 | " [0.99885 , 0.999 , 0.99878333, 0.99905 , 0.99881667],\n",
52 | " [0.99891667, 0.99881667, 0.99883333, 0.99918333, 0.99911667],\n",
53 | " [0.99903333, 0.99916667, 0.99895 , 0.99906667, 0.99918333],\n",
54 | " [0.99908333, 0.9994 , 0.99903333, 0.99928333, 0.9991 ],\n",
55 | " [0.99911667, 0.99896667, 0.99918333, 0.99933333, 0.9992 ],\n",
56 | " [0.99925 , 0.99911667, 0.9992 , 0.99875 , 0.9992 ],\n",
57 | " [0.99931667, 0.99916667, 0.99928333, 0.99915 , 0.99938333],\n",
58 | " [0.99928333, 0.99938333, 0.99925 , 0.99956667, 0.99926667],\n",
59 | " [0.99933333, 0.99953333, 0.99923333, 0.99921667, 0.99926667],\n",
60 | " [0.99916667, 0.99946667, 0.99908333, 0.99938333, 0.99946667],\n",
61 | " [0.99931667, 0.99936667, 0.99941667, 0.99941667, 0.99928333],\n",
62 | " [0.99935 , 0.99965 , 0.99958333, 0.99906667, 0.99935 ],\n",
63 | " [0.99945 , 0.9994 , 0.99953333, 0.99926667, 0.99935 ],\n",
64 | " [0.99953333, 0.99936667, 0.99935 , 0.99955 , 0.99938333]]),\n",
65 | " array([[0.981 , 0.9844, 0.9848, 0.9743, 0.9843],\n",
66 | " [0.9902, 0.9881, 0.9878, 0.9881, 0.9892],\n",
67 | " [0.9916, 0.9916, 0.9886, 0.9894, 0.9915],\n",
68 | " [0.9847, 0.9893, 0.9912, 0.9909, 0.9915],\n",
69 | " [0.9933, 0.993 , 0.9903, 0.9913, 0.9916],\n",
70 | " [0.9897, 0.9919, 0.9857, 0.9933, 0.9917],\n",
71 | " [0.9917, 0.9919, 0.9917, 0.9915, 0.9923],\n",
72 | " [0.9891, 0.992 , 0.9911, 0.9914, 0.9932],\n",
73 | " [0.9929, 0.9911, 0.9904, 0.9912, 0.9911],\n",
74 | " [0.9925, 0.9935, 0.9927, 0.9932, 0.9921],\n",
75 | " [0.9906, 0.9923, 0.9917, 0.9917, 0.9892],\n",
76 | " [0.9919, 0.9938, 0.9919, 0.994 , 0.9917],\n",
77 | " [0.993 , 0.9933, 0.9917, 0.9936, 0.9928],\n",
78 | " [0.9891, 0.9931, 0.991 , 0.9922, 0.9922],\n",
79 | " [0.993 , 0.9928, 0.9908, 0.9924, 0.9918],\n",
80 | " [0.993 , 0.9922, 0.9943, 0.9943, 0.9921],\n",
81 | " [0.9933, 0.9934, 0.992 , 0.9935, 0.9919],\n",
82 | " [0.9916, 0.9924, 0.9926, 0.9925, 0.9933],\n",
83 | " [0.9923, 0.9929, 0.9918, 0.9935, 0.9918],\n",
84 | " [0.9923, 0.9931, 0.993 , 0.9935, 0.9916],\n",
85 | " [0.9922, 0.9925, 0.9914, 0.9896, 0.9923],\n",
86 | " [0.9919, 0.9922, 0.9927, 0.993 , 0.9925],\n",
87 | " [0.9915, 0.9928, 0.9924, 0.9933, 0.9914],\n",
88 | " [0.993 , 0.9935, 0.9931, 0.9928, 0.9922],\n",
89 | " [0.9917, 0.9935, 0.9926, 0.9921, 0.9916],\n",
90 | " [0.9913, 0.9937, 0.993 , 0.9922, 0.9921],\n",
91 | " [0.9933, 0.9932, 0.9908, 0.993 , 0.9916],\n",
92 | " [0.9935, 0.9922, 0.9917, 0.9924, 0.9918],\n",
93 | " [0.9933, 0.9918, 0.9937, 0.9916, 0.9934],\n",
94 | " [0.9923, 0.9914, 0.9926, 0.9921, 0.9923]]))"
95 | ]
96 | },
97 | "execution_count": 4,
98 | "metadata": {},
99 | "output_type": "execute_result"
100 | }
101 | ],
102 | "source": [
103 | "train_accuracies = np.swapaxes(np.array([x['acc'] for x in history], dtype=float), 0, 1)\n",
104 | "test_accuracy = np.swapaxes(np.array([x['val_acc'] for x in history], dtype=float), 0, 1)\n",
105 | "train_accuracies, test_accuracy"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 7,
111 | "metadata": {},
112 | "outputs": [
113 | {
114 | "data": {
115 | "text/html": [
116 | ""
117 | ],
118 | "text/plain": [
119 | ""
120 | ]
121 | },
122 | "execution_count": 7,
123 | "metadata": {},
124 | "output_type": "execute_result"
125 | }
126 | ],
127 | "source": [
128 | "results = {\n",
129 | " \"train_accuracies\": train_accuracies,\n",
130 | " \"test_accuracy\": test_accuracy\n",
131 | "}\n",
132 | "\n",
133 | "data = [\n",
134 | " go.Scatter(\n",
135 | " x=list(range(30)),\n",
136 | " y=[np.mean(results[key][epoch]) for epoch in range(len(results[key]))],\n",
137 | " name=key,\n",
138 | " error_y=dict(\n",
139 | " type='data',\n",
140 | " array=[np.std(results[key][epoch]) for epoch in range(len(results[key]))],\n",
141 | " visible=True\n",
142 | " )\n",
143 | " )\n",
144 | " for key in results\n",
145 | "]\n",
146 | "\n",
147 | "layout = dict(title = 'Accuracy on the Fashion-MNIST dataset 30-epochs',\n",
148 | " xaxis = dict(title = 'Epoch'),\n",
149 | " yaxis = dict(title = 'Accuracy'),\n",
150 | " )\n",
151 | "\n",
152 | "py.iplot(data, filename='results-30epochs-mnist', layout=layout)"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": null,
158 | "metadata": {},
159 | "outputs": [],
160 | "source": [
161 | "performances"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": 6,
167 | "metadata": {},
168 | "outputs": [
169 | {
170 | "data": {
171 | "text/html": [
172 | ""
173 | ],
174 | "text/plain": [
175 | ""
176 | ]
177 | },
178 | "execution_count": 6,
179 | "metadata": {},
180 | "output_type": "execute_result"
181 | }
182 | ],
183 | "source": [
184 | "plot_data = {\n",
185 | " 'Without SVM': test_accuracy[-1],\n",
186 | " 'SVM_@-2': [x['with_svm@-2'] for x in performances],\n",
187 | " 'SVM_@-3': [x['with_svm@-3'] for x in performances]\n",
188 | "}\n",
189 | "\n",
190 | "data = [\n",
191 | " go.Box(\n",
192 | " y=plot_data[key],\n",
193 | " name=key,\n",
194 | " ) for key in plot_data\n",
195 | "]\n",
196 | "\n",
197 | "layout = dict(title = 'Accuracy on the Fashion-MNIST dataset 30-epochs',\n",
198 | " xaxis = dict(title = 'Epoch'),\n",
199 | " yaxis = dict(title = 'Accuracy'),\n",
200 | " )\n",
201 | "\n",
202 | "py.iplot(data, filename='performances-30epochs-mnist', layout=layout)"
203 | ]
204 | }
205 | ],
206 | "metadata": {
207 | "kernelspec": {
208 | "display_name": "Python 3",
209 | "language": "python",
210 | "name": "python3"
211 | },
212 | "language_info": {
213 | "codemirror_mode": {
214 | "name": "ipython",
215 | "version": 3
216 | },
217 | "file_extension": ".py",
218 | "mimetype": "text/x-python",
219 | "name": "python",
220 | "nbconvert_exporter": "python",
221 | "pygments_lexer": "ipython3",
222 | "version": "3.6.4"
223 | }
224 | },
225 | "nbformat": 4,
226 | "nbformat_minor": 2
227 | }
228 |
--------------------------------------------------------------------------------
/experiments/results/Fashion-MNIST/performance_30epochs_combined.json:
--------------------------------------------------------------------------------
1 | [{"history": {"val_loss": [0.39961640694141387, 0.3240031172275543, 0.36486296498775483, 0.2999567676544189, 0.26941082603931427, 0.28105862131118775, 0.2583421004772186, 0.27426248898506167, 0.26412350838184356, 0.2587658328771591, 0.3263972431182861, 0.3037034754753113, 0.2993787395477295, 0.3444756777048111, 0.34303277168273927, 0.3952072314620018, 0.323935725069046, 0.43344608159065245, 0.43881502020359037, 0.40885772070884707, 0.48054315140247345, 0.40082721934318544, 0.4527245432853699, 0.47377985467910766, 0.4480116040706634, 0.4445895487546921, 0.5960354033231735, 0.5005660709023476, 0.4928376198232174, 0.5406574306488037], "val_acc": [0.8512, 0.8836, 0.8636, 0.8917, 0.9024, 0.9024, 0.9076, 0.9061, 0.9144, 0.9106, 0.8941, 0.9075, 0.9111, 0.9024, 0.9087, 0.8956, 0.9104, 0.9007, 0.9093, 0.9069, 0.9029, 0.9001, 0.9099, 0.9088, 0.9077, 0.9047, 0.9076, 0.9067, 0.9097, 0.9044], "loss": [0.5369842763264974, 0.33226358030637104, 0.28292571523189547, 0.2514223607858022, 0.22796711163520814, 0.20951138263543448, 0.19243775055408477, 0.17820718478361766, 0.1669310154000918, 0.15484501781662305, 0.14333154555956523, 0.13416243276198705, 0.12566612465679647, 0.119671809245646, 0.11197880250414212, 0.10515761943757534, 0.09846485815346241, 0.09648188420385122, 0.09266547986765702, 0.0884354511961341, 0.0859529619326194, 0.08248054644614458, 0.08047926588232318, 0.0785362726101031, 0.07778030724997322, 0.07500176610127092, 0.07231801753346809, 0.07249056691670169, 0.07015723809991031, 0.07295623635863885], "acc": [0.80485, 0.8794666666666666, 0.8975666666666666, 0.9077833333333334, 0.9161833333333333, 0.9232333333333334, 0.9284, 0.934, 0.9381, 0.9432333333333334, 0.9473166666666667, 0.9504166666666667, 0.9542833333333334, 0.9563833333333334, 0.9586333333333333, 0.9607166666666667, 0.9641, 0.9638333333333333, 0.96575, 0.9675166666666667, 0.9686333333333333, 0.96995, 0.9702833333333334, 0.97225, 0.9716666666666667, 0.9734, 0.9745166666666667, 0.97495, 0.9749333333333333, 0.9749166666666667]}, "performance": {"with_svm@-3": 0.9058, "with_svm@-2": 0.9093}},{"history": {"val_loss": [0.4150441631793976, 0.3261751090049744, 0.3454334979057312, 0.3094462700843811, 0.28191435914039614, 0.2875005770921707, 0.2872392510890961, 0.26201132280826567, 0.2847232746839523, 0.2788535724401474, 0.33057111032009123, 0.2983163303613663, 0.30958148255348206, 0.34721010646820066, 0.3631596899986267, 0.3546803478240967, 0.4272496466159821, 0.3690841608285904, 0.3891567182540894, 0.4191183884859085, 0.4348145247459412, 0.4770137900710106, 0.4338083899974823, 0.5618760014295578, 0.5968949678063392, 0.5492106608390808, 0.47301609473228456, 0.5954591445505619, 0.5699526505947113, 0.6091273111760617], "val_acc": [0.8442, 0.8812, 0.8708, 0.8908, 0.9007, 0.8984, 0.8951, 0.9066, 0.9122, 0.9021, 0.9064, 0.9103, 0.9012, 0.9002, 0.9087, 0.9068, 0.8991, 0.9074, 0.9076, 0.9053, 0.9058, 0.9052, 0.9051, 0.8962, 0.9071, 0.8982, 0.9004, 0.908, 0.9008, 0.9042], "loss": [0.5372069796085358, 0.33147892653942107, 0.2819155935049057, 0.2501675777196884, 0.22795790247519812, 0.20883161357243857, 0.19210073427359264, 0.17625245781143506, 0.16437627596060436, 0.15312131391763686, 0.14188234887719153, 0.13113620449503263, 0.1231539293517669, 0.11442032988071442, 0.10762197708835204, 0.10159016521821419, 0.09620393821696441, 0.09138783883477251, 0.08485465567509333, 0.08136600005974372, 0.07779915048554539, 0.07684610707126557, 0.07310853577925203, 0.07031081450295945, 0.06876748004704714, 0.06674423448170225, 0.06708652312122286, 
0.0629670261045297, 0.06223219020180404, 0.060887808101888125], "acc": [0.8011833333333334, 0.8782, 0.896, 0.9074166666666666, 0.91605, 0.92305, 0.9288833333333333, 0.9350166666666667, 0.9386666666666666, 0.9424, 0.94645, 0.9512333333333334, 0.9545833333333333, 0.9567666666666667, 0.9592333333333334, 0.9621166666666666, 0.96345, 0.9660166666666666, 0.9680666666666666, 0.9688666666666667, 0.97135, 0.9718, 0.9725833333333334, 0.97455, 0.9753833333333334, 0.9762166666666666, 0.9763666666666667, 0.977, 0.9783833333333334, 0.9784833333333334]}, "performance": {"with_svm@-3": 0.9023, "with_svm@-2": 0.9039}},{"history": {"val_loss": [0.4043342845916748, 0.4140978363990784, 0.2880860612392426, 0.2902471554279327, 0.26663297791481017, 0.27250645632743836, 0.27851846265792846, 0.2674907832622528, 0.2509124885082245, 0.2766284414768219, 0.30450538427829743, 0.3396841156244278, 0.33489206714630126, 0.32968835673332214, 0.35049580488204957, 0.40639078683853147, 0.4156203591585159, 0.44790499334335326, 0.39392427277565, 0.4508971324920654, 0.45144741268157956, 0.4733809838980436, 0.43593108036518097, 0.47289209780693053, 0.5302549986958504, 0.5181801124095917, 0.4857735851764679, 0.43312914094924926, 0.589274645756185, 0.5597955271720886], "val_acc": [0.8577, 0.8536, 0.8938, 0.8919, 0.9056, 0.9018, 0.8997, 0.9096, 0.9146, 0.9096, 0.9093, 0.9034, 0.8991, 0.9053, 0.9027, 0.9061, 0.9061, 0.9006, 0.9028, 0.9073, 0.9044, 0.9051, 0.9003, 0.907, 0.9047, 0.9059, 0.8963, 0.9007, 0.9079, 0.9054], "loss": [0.5449977326552073, 0.331432382162412, 0.27721960909366605, 0.24537183661460876, 0.2236899551153183, 0.2034526499390602, 0.18604228529532751, 0.1735770888586839, 0.1608558345834414, 0.14848035674492518, 0.13826039456129074, 0.13101752441922823, 0.12131975610057513, 0.11398900530735652, 0.10702471297780673, 0.10125248184750478, 0.09447303883483013, 0.09300156107619405, 0.08732006960138679, 0.08322347532113393, 0.08042012545143565, 0.07930759765207768, 0.07606005681777994, 0.07313436928779507, 0.06995374247866372, 0.0693866363959387, 0.06551957191874584, 0.06589129398725926, 0.06280896504372359, 0.0626244676458494], "acc": [0.7985, 0.8795, 0.8988833333333334, 0.9090833333333334, 0.9173666666666667, 0.92485, 0.93225, 0.9366833333333333, 0.9411166666666667, 0.94515, 0.94975, 0.9526833333333333, 0.9558, 0.95695, 0.9602666666666667, 0.9631166666666666, 0.9650166666666666, 0.9656833333333333, 0.9686166666666667, 0.9690833333333333, 0.9704666666666667, 0.9719, 0.9723333333333334, 0.97415, 0.97455, 0.9749666666666666, 0.9765166666666667, 0.97695, 0.9780333333333333, 0.9779666666666667]}, "performance": {"with_svm@-3": 0.9026, "with_svm@-2": 0.9083}},{"history": {"val_loss": [0.38714364495277404, 0.3567921408176422, 0.3250357707738876, 0.2858080851316452, 0.28071209461688995, 0.2722078594684601, 0.2801001211643219, 0.27960221977233884, 0.2834079221725464, 0.27649084842205046, 0.2873888593196869, 0.3608335973739624, 0.3075114233016968, 0.33885010154247286, 0.36965319867134094, 0.3924022182226181, 0.3339719518661499, 0.4083260585546494, 0.40416117042303085, 0.40474343745708463, 0.42218335597515105, 0.4105245063304901, 0.45178317244052885, 0.44313831994533537, 0.4897380232691765, 0.47371806797981264, 0.5233340344429016, 0.5715650954663754, 0.5431698856592179, 0.6270964526116848], "val_acc": [0.864, 0.8749, 0.8814, 0.8912, 0.8966, 0.9025, 0.9056, 0.9048, 0.9077, 0.9085, 0.909, 0.9108, 0.9093, 0.9108, 0.906, 0.9056, 0.9076, 0.9087, 0.9049, 0.907, 0.906, 0.8994, 0.9052, 0.9018, 0.9004, 0.9033, 0.9063, 0.9024, 0.8953, 0.9066], 
"loss": [0.5441179068406423, 0.3305619689186414, 0.28244102187951403, 0.24943221237659455, 0.22933951362371444, 0.2084388442079226, 0.19123481418987115, 0.17787826199332873, 0.16468209545016288, 0.15482280835111936, 0.14399771384596824, 0.13411573451360068, 0.12475449653317532, 0.11885996306836605, 0.11200869298875332, 0.1069638219366471, 0.10219919406970342, 0.09681515558858712, 0.0917619632239143, 0.08723219532544414, 0.08620026492799322, 0.08276339541425308, 0.08095510593801737, 0.07946293067646523, 0.07687996114244064, 0.07559641222606103, 0.07310451997717221, 0.07273915510618438, 0.072224404023091, 0.0677077879728439], "acc": [0.80055, 0.8808333333333334, 0.8976666666666666, 0.9079333333333334, 0.91625, 0.9231666666666667, 0.93035, 0.9345333333333333, 0.9392, 0.9429833333333333, 0.9468, 0.9507666666666666, 0.9537833333333333, 0.9559333333333333, 0.9582666666666667, 0.9615, 0.9625, 0.96465, 0.96645, 0.9681333333333333, 0.9686333333333333, 0.9702333333333333, 0.9708166666666667, 0.9718, 0.97205, 0.9737, 0.9737833333333333, 0.9746333333333334, 0.9753833333333334, 0.97585]}, "performance": {"with_svm@-3": 0.9048, "with_svm@-2": 0.9104}},{"history": {"val_loss": [0.46697981705665587, 0.3832242013454437, 0.303250941491127, 0.28147033443450925, 0.2645062158644199, 0.30668587293624877, 0.28016208398342135, 0.2843065684996545, 0.27898327768445014, 0.2912071517586708, 0.31837791128754617, 0.3626342387139797, 0.31595948444008826, 0.36499525430202484, 0.39556382931470874, 0.3669290903389454, 0.3701426920939237, 0.3836698907200247, 0.42279923573583367, 0.48294994099140165, 0.4073905217766762, 0.5648117405116558, 0.5509298311325721, 0.52655683157444, 0.5822678367150947, 0.5299802899837494, 0.48815737073849885, 0.5766397311203182, 0.5554989251375199, 0.5038782718300819], "val_acc": [0.8383, 0.8641, 0.8908, 0.8942, 0.9051, 0.8855, 0.9032, 0.9099, 0.9059, 0.9062, 0.9071, 0.9045, 0.9099, 0.9064, 0.9019, 0.9069, 0.9125, 0.9094, 0.9087, 0.9053, 0.8955, 0.9033, 0.9056, 0.9079, 0.9121, 0.9115, 0.9071, 0.9091, 0.9092, 0.8997], "loss": [0.5585080780824025, 0.3369152142286301, 0.28359521205425264, 0.25012954001426696, 0.22671767033338547, 0.20606048639615376, 0.18930058318376541, 0.17374658589363098, 0.1618980585793654, 0.14980819566051165, 0.13848809727629025, 0.12958743358751137, 0.1217845716059208, 0.11529610383758943, 0.10612352059980233, 0.10063879103461901, 0.09551240612169107, 0.09058071209937334, 0.0849770180746913, 0.0819070641172429, 0.08037483161017299, 0.07584481724972526, 0.07303781550178925, 0.07187192963063717, 0.07019264071881771, 0.06639131543164452, 0.06384243606943638, 0.06306189757228518, 0.06210263610463589, 0.06092731468016282], "acc": [0.79055, 0.8768666666666667, 0.8955166666666666, 0.9082, 0.9159333333333334, 0.9237166666666666, 0.9308, 0.9363333333333334, 0.9404166666666667, 0.9446333333333333, 0.94825, 0.9519666666666666, 0.95445, 0.9571833333333334, 0.9613333333333334, 0.9628833333333333, 0.9644166666666667, 0.9666166666666667, 0.9682, 0.9696333333333333, 0.9715333333333334, 0.97235, 0.9737333333333333, 0.9744333333333334, 0.9748833333333333, 0.9766, 0.9770166666666666, 0.9782, 0.9781666666666666, 0.9792666666666666]}, "performance": {"with_svm@-3": 0.8975, "with_svm@-2": 0.9038}}]
--------------------------------------------------------------------------------
/experiments/results/MNIST/performance_30epochs_combined.json:
--------------------------------------------------------------------------------
1 | [{"history": {"val_loss": [0.058012286397721616, 0.031465802982915195, 0.02752863751942932, 0.04448511460040463, 0.024455885998660234, 0.03770310556006152, 0.0302407726525661, 0.040125972307659685, 0.02950342787828249, 0.034729696841456874, 0.055798425468014376, 0.04136685403668739, 0.03716646059967136, 0.06317648815578528, 0.04977548953918097, 0.04911087775120434, 0.044314033737478914, 0.05617084128602239, 0.05534849855535915, 0.06092766595187452, 0.06630350116573029, 0.06982053399169147, 0.07034379402384695, 0.05951280631003713, 0.08072586988518408, 0.07859032812471571, 0.06860585440432163, 0.060573564682116304, 0.06489135026952284, 0.07252224225466455], "val_acc": [0.981, 0.9902, 0.9916, 0.9847, 0.9933, 0.9897, 0.9917, 0.9891, 0.9929, 0.9925, 0.9906, 0.9919, 0.993, 0.9891, 0.993, 0.993, 0.9933, 0.9916, 0.9923, 0.9923, 0.9922, 0.9919, 0.9915, 0.993, 0.9917, 0.9913, 0.9933, 0.9935, 0.9933, 0.9923], "loss": [0.17301000547260045, 0.04833274687674517, 0.03296006406117231, 0.026375785614883837, 0.020759016964376983, 0.016109659135122394, 0.013145960034359087, 0.01125916710313201, 0.009907794772488463, 0.00860547799981235, 0.006869320600131293, 0.006863135235392868, 0.006310418321777403, 0.005953495336345638, 0.004055292068486461, 0.004956505267661699, 0.00472531110749945, 0.004115091456871733, 0.0036294369941880632, 0.004000084901630392, 0.0034543862089507003, 0.0031846033174160616, 0.002197130242962362, 0.003921002310771739, 0.00240644144151604, 0.0034369908087196715, 0.003028492050913519, 0.002963888820605924, 0.0024517073052008377, 0.002116798436639715], "acc": [0.9474333333333333, 0.9853, 0.9893666666666666, 0.9918833333333333, 0.99345, 0.9950833333333333, 0.9959833333333333, 0.99655, 0.9971833333333333, 0.9971666666666666, 0.99795, 0.9981833333333333, 0.9981666666666666, 0.9982666666666666, 0.9986666666666667, 0.9985666666666667, 0.99885, 0.9989166666666667, 0.9990333333333333, 0.9990833333333333, 0.9991166666666667, 0.99925, 0.9993166666666666, 0.9992833333333333, 0.9993333333333333, 0.9991666666666666, 0.9993166666666666, 0.99935, 0.99945, 0.9995333333333334]}, "performance": {"with_svm@-3": 0.9942, "with_svm@-2": 0.9938}},
2 | {"history": {"val_loss": [0.04501374350236729, 0.03333432343536988, 0.02289775217374554, 0.031247786466928665, 0.024255429341574198, 0.02684926290004114, 0.03141559111404458, 0.028666918717404406, 0.04446157635115014, 0.028694591341076922, 0.03957730199245957, 0.04202645219727212, 0.040510537447208415, 0.0429101706364025, 0.04585446695649855, 0.0456064468267775, 0.044591218002707614, 0.058358724957093475, 0.05272812839602186, 0.05444539611888879, 0.05924669900309011, 0.05979339384530924, 0.06361616533229401, 0.053325072051109715, 0.05930827118715099, 0.05690898257817064, 0.06264420522789153, 0.07012136972692072, 0.06826655435531541, 0.07514357501242923], "val_acc": [0.9844, 0.9881, 0.9916, 0.9893, 0.993, 0.9919, 0.9919, 0.992, 0.9911, 0.9935, 0.9923, 0.9938, 0.9933, 0.9931, 0.9928, 0.9922, 0.9934, 0.9924, 0.9929, 0.9931, 0.9925, 0.9922, 0.9928, 0.9935, 0.9935, 0.9937, 0.9932, 0.9922, 0.9918, 0.9914], "loss": [0.17729154649289947, 0.04786290728921692, 0.03310972272466558, 0.024000238360309352, 0.019628610597027, 0.015570734798592943, 0.012905735168078535, 0.01112630117221476, 0.009569634471618307, 0.008479778113915093, 0.007240269398996497, 0.005718392407475357, 0.005387276505376697, 0.005427486746183755, 0.00405911881381312, 0.004863275600015732, 0.0036052914443208163, 0.004098738280737507, 0.0034470039245908233, 0.002539891782689559, 0.003724600748225195, 0.003396888139204172, 0.0033139408396933127, 0.0022997476812821713, 0.002101743097013415, 0.0023254889946794643, 0.0026494044639331226, 0.001954368177306363, 0.0031213346675570846, 0.002741553527614049], "acc": [0.9444833333333333, 0.9851166666666666, 0.9901333333333333, 0.9928666666666667, 0.9939666666666667, 0.9954833333333334, 0.9959333333333333, 0.99675, 0.99715, 0.9973833333333333, 0.99775, 0.99835, 0.9984666666666666, 0.9985833333333334, 0.9989166666666667, 0.9988666666666667, 0.999, 0.9988166666666667, 0.9991666666666666, 0.9994, 0.9989666666666667, 0.9991166666666667, 0.9991666666666666, 0.9993833333333333, 0.9995333333333334, 0.9994666666666666, 0.9993666666666666, 0.99965, 0.9994, 0.9993666666666666]}, "performance": {"with_svm@-3": 0.9943, "with_svm@-2": 0.9932}},
3 | {"history": {"val_loss": [0.04693662054911256, 0.040643049051705744, 0.03836872058830922, 0.025879924827218202, 0.03365807937731297, 0.05606620579953305, 0.03101422648470616, 0.033968908638895845, 0.03749096577653436, 0.03477393601403474, 0.03989164399835893, 0.04330464110312171, 0.052758552326085666, 0.046145228933404746, 0.05790804634753588, 0.039930098790110356, 0.05093961141671082, 0.05108419878352672, 0.06115888950322489, 0.05455096907358804, 0.0614969135381117, 0.05646109203540775, 0.06690327637175754, 0.05403370705581249, 0.05814710223793893, 0.06744184541873186, 0.07847057491659128, 0.07086941969375286, 0.06166929620088142, 0.06629630808312378], "val_acc": [0.9848, 0.9878, 0.9886, 0.9912, 0.9903, 0.9857, 0.9917, 0.9911, 0.9904, 0.9927, 0.9917, 0.9919, 0.9917, 0.991, 0.9908, 0.9943, 0.992, 0.9926, 0.9918, 0.993, 0.9914, 0.9927, 0.9924, 0.9931, 0.9926, 0.993, 0.9908, 0.9917, 0.9937, 0.9926], "loss": [0.17716153866623838, 0.04742282863749812, 0.03301722237803042, 0.025223383712318415, 0.02021720212032475, 0.01563415939128075, 0.013349200006359025, 0.011451028558874896, 0.009647474916346982, 0.007851272325570154, 0.00654441191043555, 0.006293635521197999, 0.00570135567472183, 0.004649824481886086, 0.0051237779484377445, 0.0044031188274873175, 0.0044619473003985905, 0.004122088544306773, 0.0037621076239143275, 0.003428344476895821, 0.003696741009939372, 0.0028606461014387377, 0.0029217197914563407, 0.00375240520837721, 0.0035898833562101116, 0.003996992303055117, 0.0024696943116979354, 0.0015655488069841491, 0.002664865174772005, 0.003095264735105055], "acc": [0.94555, 0.9851666666666666, 0.99, 0.9922, 0.9939, 0.9950333333333333, 0.9957, 0.9967, 0.9972666666666666, 0.9976666666666667, 0.9979, 0.9982333333333333, 0.9984166666666666, 0.9986, 0.9986333333333334, 0.9988666666666667, 0.9987833333333334, 0.9988333333333334, 0.99895, 0.9990333333333333, 0.9991833333333333, 0.9992, 0.9992833333333333, 0.99925, 0.9992333333333333, 0.9990833333333333, 0.9994166666666666, 0.9995833333333334, 0.9995333333333334, 0.99935]}, "performance": {"with_svm@-3": 0.994, "with_svm@-2": 0.9934}},
4 | {"history": {"val_loss": [0.08197933812029659, 0.03546777663987596, 0.03341330313124054, 0.030251626034558284, 0.02884039005855011, 0.019837814035794873, 0.0317348777288682, 0.03626116072119066, 0.03655646794341883, 0.028666779191147088, 0.04297479794507308, 0.039568193257491224, 0.03202172394558131, 0.046409738939450426, 0.046686414847980924, 0.03982525034669272, 0.05124423405043328, 0.05239241906636428, 0.04661062939958031, 0.055866467449014294, 0.08729090600506854, 0.06305750775740036, 0.054992045577612614, 0.06343066945878484, 0.06362634397384556, 0.06779158965707477, 0.057760000991913536, 0.06725425281132523, 0.06973312734179865, 0.07662225833157338], "val_acc": [0.9743, 0.9881, 0.9894, 0.9909, 0.9913, 0.9933, 0.9915, 0.9914, 0.9912, 0.9932, 0.9917, 0.994, 0.9936, 0.9922, 0.9924, 0.9943, 0.9935, 0.9925, 0.9935, 0.9935, 0.9896, 0.993, 0.9933, 0.9928, 0.9921, 0.9922, 0.993, 0.9924, 0.9916, 0.9921], "loss": [0.16861701912457744, 0.046583191910013554, 0.03097322490364313, 0.023580448029361045, 0.018423834095839024, 0.015051568441031123, 0.012091321246521208, 0.010600623359851655, 0.009061368080696427, 0.007479306346113405, 0.00671511196024926, 0.005556289483133651, 0.0056452080929059855, 0.004374737399550767, 0.00513432675843607, 0.0039756081365345, 0.003821543846153865, 0.003121435883364787, 0.004367182645120101, 0.0028328010872671863, 0.002666306109979875, 0.0049319935316656256, 0.003590493196360785, 0.002394134152485617, 0.0033287287435602064, 0.002699324307992515, 0.0029116885613174265, 0.004292509903183471, 0.00416508137487399, 0.0020753812079482865], "acc": [0.9476, 0.9852833333333333, 0.9903333333333333, 0.9926, 0.99435, 0.9954166666666666, 0.99635, 0.9968833333333333, 0.9975166666666667, 0.99785, 0.99805, 0.99825, 0.9984666666666666, 0.9988, 0.9987333333333334, 0.999, 0.99905, 0.9991833333333333, 0.9990666666666667, 0.9992833333333333, 0.9993333333333333, 0.99875, 0.99915, 0.9995666666666667, 0.9992166666666666, 0.9993833333333333, 0.9994166666666666, 0.9990666666666667, 0.9992666666666666, 0.99955]}, "performance": {"with_svm@-3": 0.9942, "with_svm@-2": 0.9939}},
5 | {"history": {"val_loss": [0.04706196300601587, 0.031698632218583954, 0.03227365807255264, 0.02973660909422033, 0.03172424109021194, 0.029596417490038582, 0.024765829553188814, 0.03213707902151118, 0.04085539187522227, 0.03789611699151364, 0.056589442667745536, 0.05191770297432388, 0.042685455790007584, 0.04466887094765375, 0.059763871695658895, 0.05341054688095057, 0.04867746317094354, 0.046260142465698234, 0.05594641119233786, 0.05880422080311575, 0.05863813357189838, 0.0646812159224218, 0.07454232090098196, 0.06272837779143121, 0.0716926159349325, 0.060035764309630875, 0.07311587181051857, 0.07722717706860649, 0.06542106488027886, 0.07567073885974498], "val_acc": [0.9843, 0.9892, 0.9915, 0.9915, 0.9916, 0.9917, 0.9923, 0.9932, 0.9911, 0.9921, 0.9892, 0.9917, 0.9928, 0.9922, 0.9918, 0.9921, 0.9919, 0.9933, 0.9918, 0.9916, 0.9923, 0.9925, 0.9914, 0.9922, 0.9916, 0.9921, 0.9916, 0.9918, 0.9934, 0.9923], "loss": [0.17061692195584377, 0.046234829489390054, 0.03182135101857905, 0.02431265958369865, 0.018748356542420515, 0.015698805837876473, 0.012273000637336372, 0.010425754852034274, 0.008624888020616497, 0.007014024585414942, 0.007216674821420172, 0.005889621360516246, 0.005983667603303639, 0.004918486540910453, 0.00394124438706978, 0.0043690883281159885, 0.0042438452046275115, 0.003430537353717273, 0.0030983400962485198, 0.003630599020303695, 0.0027655376554707573, 0.00308147610446437, 0.0030024490120792962, 0.003198153611999343, 0.0034311861089857, 0.0029338489132527, 0.0032705091191530073, 0.0022210830026674635, 0.0030188847804404834, 0.0033922651689746847], "acc": [0.9471, 0.9858333333333333, 0.9898333333333333, 0.99235, 0.9941, 0.9952166666666666, 0.9963, 0.99675, 0.99725, 0.9977833333333334, 0.9979333333333333, 0.9984166666666666, 0.9983833333333333, 0.9987166666666667, 0.9988833333333333, 0.99865, 0.9988166666666667, 0.9991166666666667, 0.9991833333333333, 0.9991, 0.9992, 0.9992, 0.9993833333333333, 0.9992666666666666, 0.9992666666666666, 0.9994666666666666, 0.9992833333333333, 0.99935, 0.99935, 0.9993833333333333]}, "performance": {"with_svm@-3": 0.9933, "with_svm@-2": 0.9933}}]
--------------------------------------------------------------------------------
/keras_svm/model_svm_wrapper.py:
--------------------------------------------------------------------------------
1 | from keras.models import Model
2 | from sklearn.svm import SVC
3 | from keras.utils import to_categorical
4 |
5 |
6 | class ModelSVMWrapper:
7 | """
8 | Linear stack of layers with the option to replace the end of the stack with a Support Vector Machine
9 | # Arguments
10 | layers: list of layers to add to the model.
11 | svm: The Support Vector Machine to use.
12 | """
13 | def __init__(self, model, svm=None):
14 | super().__init__()
15 |
16 | self.model = model
17 | self.intermediate_model = None # type: Model
18 | self.svm = svm
19 |
20 | if svm is None:
21 | self.svm = SVC(kernel='linear')
22 |
23 | def add(self, layer):
24 | return self.model.add(layer)
25 |
26 |     def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.,
27 |             validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0,
28 |             steps_per_epoch=None, validation_steps=None, **kwargs):
29 |         """
30 |         Trains the model for a fixed number of epochs (iterations on a dataset).
31 |
32 |         # Arguments
33 |             x: Numpy array of training data.
34 |                 If the input layer in the model is named, you can also pass a
35 |                 dictionary mapping the input name to a Numpy array.
36 |                 `x` can be `None` (default) if feeding from
37 |                 framework-native tensors (e.g. TensorFlow data tensors).
38 |             y: Numpy array of integer class labels (converted internally with `to_categorical`).
39 |                 If the output layer in the model is named, you can also pass a
40 |                 dictionary mapping the output name to a Numpy array.
41 |                 `y` can be `None` (default) if feeding from
42 |                 framework-native tensors (e.g. TensorFlow data tensors).
43 |             batch_size: Integer or `None`.
44 |                 Number of samples per gradient update.
45 |                 If unspecified, it will default to 32.
46 |             epochs: Integer. Number of epochs to train the model.
47 |                 An epoch is an iteration over the entire `x` and `y`
48 |                 data provided.
49 |                 Note that in conjunction with `initial_epoch`,
50 |                 `epochs` is to be understood as "final epoch".
51 |                 The model is not trained for a number of iterations
52 |                 given by `epochs`, but merely until the epoch
53 |                 of index `epochs` is reached.
54 |             verbose: 0, 1, or 2. Verbosity mode.
55 |                 0 = silent, 1 = progress bar, 2 = one line per epoch.
56 |             callbacks: List of `keras.callbacks.Callback` instances.
57 |                 List of callbacks to apply during training.
58 |                 See [callbacks](/callbacks).
59 |             validation_split: Float between 0 and 1.
60 |                 Fraction of the training data to be used as validation data.
61 |                 The model will set apart this fraction of the training data,
62 |                 will not train on it, and will evaluate
63 |                 the loss and any model metrics
64 |                 on this data at the end of each epoch.
65 |                 The validation data is selected from the last samples
66 |                 in the `x` and `y` data provided, before shuffling.
67 |             validation_data: tuple `(x_val, y_val)` or tuple
68 |                 `(x_val, y_val, val_sample_weights)` on which to evaluate
69 |                 the loss and any model metrics at the end of each epoch.
70 |                 The model will not be trained on this data.
71 |                 This will override `validation_split`.
72 |             shuffle: Boolean (whether to shuffle the training data
73 |                 before each epoch) or str (for 'batch').
74 |                 'batch' is a special option for dealing with the
75 |                 limitations of HDF5 data; it shuffles in batch-sized chunks.
76 |                 Has no effect when `steps_per_epoch` is not `None`.
77 |             class_weight: Optional dictionary mapping class indices (integers)
78 |                 to a weight (float) value, used for weighting the loss function
79 |                 (during training only).
80 |                 This can be useful to tell the model to
81 |                 "pay more attention" to samples from
82 |                 an under-represented class.
83 |             sample_weight: Optional Numpy array of weights for
84 |                 the training samples, used for weighting the loss function
85 |                 (during training only). You can either pass a flat (1D)
86 |                 Numpy array with the same length as the input samples
87 |                 (1:1 mapping between weights and samples),
88 |                 or in the case of temporal data,
89 |                 you can pass a 2D array with shape
90 |                 `(samples, sequence_length)`,
91 |                 to apply a different weight to every timestep of every sample.
92 |                 In this case you should make sure to specify
93 |                 `sample_weight_mode="temporal"` in `compile()`.
94 |             initial_epoch: Epoch at which to start training
95 |                 (useful for resuming a previous training run).
96 |             steps_per_epoch: Total number of steps (batches of samples)
97 |                 before declaring one epoch finished and starting the
98 |                 next epoch. When training with input tensors such as
99 |                 TensorFlow data tensors, the default `None` is equal to
100 |                 the number of samples in your dataset divided by
101 |                 the batch size, or 1 if that cannot be determined.
102 |             validation_steps: Only relevant if `steps_per_epoch`
103 |                 is specified. Total number of steps (batches of samples)
104 |                 to validate before stopping.
105 |
106 |         # Returns
107 |             A `History` object. Its `History.history` attribute is
108 |             a record of training loss values and metrics values
109 |             at successive epochs, as well as validation loss values
110 |             and validation metrics values (if applicable).
111 |
112 |         # Raises
113 |             RuntimeError: If the model was never compiled.
114 |             ValueError: In case of mismatch between the provided input data
115 |                 and what the model expects.
116 |         """
117 |         fit = self.model.fit(x, to_categorical(y), batch_size, epochs, verbose, callbacks, validation_split,
118 |                              validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch,
119 |                              validation_steps, **kwargs)
120 |
121 |         self.fit_svm(x, y, self.__get_split_layer())
122 |
123 |         return fit
124 |
125 |     def fit_svm(self, x, y, split_layer):
126 |         """
127 |         Fits the SVM on the data without changing the underlying neural network
128 |         # Arguments
129 |             x: Numpy array of training data.
130 |                 If the input layer in the model is named, you can also pass a
131 |                 dictionary mapping the input name to a Numpy array.
132 |                 `x` can be `None` (default) if feeding from
133 |                 framework-native tensors (e.g. TensorFlow data tensors).
134 |             y: Numpy array of target (label) data.
135 |                 If the output layer in the model is named, you can also pass a
136 |                 dictionary mapping the output name to a Numpy array.
137 |                 `y` can be `None` (default) if feeding from
138 |                 framework-native tensors (e.g. TensorFlow data tensors).
139 |             split_layer:
140 |                 The layer to split on.
141 |         """
142 |         # Store intermediate model
143 |         self.intermediate_model = Model(inputs=self.model.input,
144 |                                         outputs=split_layer.output)
145 |         # Use output of intermediate model to train SVM
146 |         intermediate_output = self.intermediate_model.predict(x)
147 |         self.svm.fit(intermediate_output, y)
148 |
149 |     def evaluate(self, x=None, y=None, batch_size=None, verbose=1, steps=None):
150 |         """
151 |         Computes the accuracy on some input data, batch by batch.
152 |
153 |         # Arguments
154 |             x: input data, as a Numpy array or list of Numpy arrays
155 |                 (if the model has multiple inputs).
156 |                 `x` can be `None` (default) if feeding from
157 |                 framework-native tensors (e.g. TensorFlow data tensors).
158 |             y: labels, as a Numpy array.
159 |                 `y` can be `None` (default) if feeding from
160 |                 framework-native tensors (e.g. TensorFlow data tensors).
161 |             batch_size: Integer. If unspecified, it will default to 32.
162 |             verbose: verbosity mode, 0 or 1.
163 |             sample_weight: sample weights, as a Numpy array.
164 |             steps: Integer or `None`.
165 |                 Total number of steps (batches of samples)
166 |                 before declaring the evaluation round finished.
167 |                 Ignored with the default value of `None`.
168 |
169 |         # Returns
170 |             Accuracy of the model mapping x to y.
171 |
172 |         # Raises
173 |             RuntimeError: if the model was never compiled.
174 |         """
175 |
176 |         if self.intermediate_model is None:
177 |             raise Exception("A model must be fit before running evaluate")
178 |         output = self.predict(x, batch_size, verbose, steps)
179 |         correct = [output[i] == y[i]
180 |                    for i in range(len(output))]
181 |
182 |         accuracy = sum(correct) / len(correct)
183 |
184 |         return accuracy
185 |
186 |     def predict(self, x, batch_size=None, verbose=0, steps=None):
187 |         """
188 |         Computes class predictions for the input samples, batch by batch.
189 |
190 |         # Arguments
191 |             x: input data, as a Numpy array or list of Numpy arrays
192 |                 (if the model has multiple inputs).
193 |                 `x` can be `None` (default) if feeding from
194 |                 framework-native tensors (e.g. TensorFlow data tensors).
195 |             y: labels, as a Numpy array.
196 |                 `y` can be `None` (default) if feeding from
197 |                 framework-native tensors (e.g. TensorFlow data tensors).
198 |             batch_size: Integer. If unspecified, it will default to 32.
199 |             verbose: verbosity mode, 0 or 1.
200 |             sample_weight: sample weights, as a Numpy array.
201 |             steps: Integer or `None`.
202 |                 Total number of steps (batches of samples)
203 |                 before declaring the evaluation round finished.
204 |                 Ignored with the default value of `None`.
205 |
206 |         # Returns
207 |             Numpy array of class predictions: for each input sample,
208 |             the label predicted by the SVM on the output of the
209 |             intermediate model (i.e. the network up to the split
210 |             layer).
211 |
212 |         # Raises
213 |             RuntimeError: if the model was never compiled.
214 |         """
215 |         intermediate_prediction = self.intermediate_model.predict(x, batch_size, verbose, steps)
216 |         output = self.svm.predict(intermediate_prediction)
217 |
218 |         return output
219 |
220 | def __get_split_layer(self):
221 | """
222 |         Gets the layer to split on: either the layer named "split_layer" or, by default, the third-to-last layer.
223 |
224 |         :return: The layer to split on: from this layer's output onward, the SVM replaces the existing model.
225 |         :raises ValueError: If not enough layers exist for a meaningful split (at least three required).
226 | """
227 | if len(self.model.layers) < 3:
228 |             raise ValueError('self.model.layers is too small for a relevant split')
229 |
230 | for layer in self.model.layers:
231 | if layer.name == "split_layer":
232 | return layer
233 |
234 |         # If no explicit cut-off point is named, split at the third-to-last layer, so the SVM replaces the final hidden Dense layer and the softmax output layer.
235 | return self.model.layers[-3]
236 |
--------------------------------------------------------------------------------
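
A minimal usage sketch of the wrapper documented above, mirroring the calls made in the notebook that follows. The tiny layer sizes and the random stand-in data are illustrative assumptions, not part of the repository; any compiled Keras classifier with at least three layers works, and naming a layer "split_layer" selects it as the split point explicitly.

import numpy as np
from keras import layers, models
from keras_svm.model_svm_wrapper import ModelSVMWrapper

# A small compiled CNN; the wrapper's fit() applies to_categorical() itself,
# so integer class labels are expected.
model = models.Sequential([
    layers.Conv2D(8, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(name='split_layer'),  # optional: mark the split point explicitly
    layers.Dense(32, activation='relu'),
    layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

wrapper = ModelSVMWrapper(model)

# Random stand-in data with the (Fashion-)MNIST shapes used in the experiments.
x_train = np.random.rand(256, 28, 28, 1).astype('float32')
y_train = np.arange(256) % 10  # integer labels covering all ten classes

wrapper.fit(x_train, y_train, epochs=1, batch_size=64)  # trains the CNN, then fits the SVM on the split layer
print(wrapper.evaluate(x_train, y_train))               # accuracy of CNN features + SVM
print(wrapper.predict(x_train)[:5])                     # SVM-predicted class labels
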
/experiments/deep_learning_paper.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "MvPDrkNDsyO9"
8 | },
9 | "source": []
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 17,
14 | "metadata": {
15 | "colab": {
16 | "autoexec": {
17 | "startup": false,
18 | "wait_interval": 0
19 | },
20 | "base_uri": "https://localhost:8080/",
21 | "height": 170
22 | },
23 | "colab_type": "code",
24 | "executionInfo": {
25 | "elapsed": 3158,
26 | "status": "ok",
27 | "timestamp": 1525036626944,
28 | "user": {
29 | "displayName": "Daan Luttik",
30 | "photoUrl": "//lh3.googleusercontent.com/-CPv5nanSWKo/AAAAAAAAAAI/AAAAAAAAE3I/gkG30GZ_TJs/s50-c-k-no/photo.jpg",
31 | "userId": "106328974539959585216"
32 | },
33 | "user_tz": -120
34 | },
35 | "id": "GaBXH7PZiJsK",
36 | "outputId": "4c5b7299-ab38-4082-fa49-527be31d49a0"
37 | },
38 | "outputs": [
39 | {
40 | "name": "stdout",
41 | "output_type": "stream",
42 | "text": [
43 | "Looking in indexes: https://pypi.org/simple, https://legacy.pypi.org/simple\n",
44 | "Requirement already up-to-date: keras-svm in /usr/local/lib/python3.6/dist-packages (1.0.0b10)\n",
45 | "Requirement not upgraded as not directly required: scikit-learn in /usr/local/lib/python3.6/dist-packages (from keras-svm) (0.19.1)\n",
46 | "Requirement not upgraded as not directly required: keras in /usr/local/lib/python3.6/dist-packages (from keras-svm) (2.1.6)\n",
47 | "Requirement not upgraded as not directly required: pyyaml in /usr/local/lib/python3.6/dist-packages (from keras->keras-svm) (3.12)\n",
48 | "Requirement not upgraded as not directly required: h5py in /usr/local/lib/python3.6/dist-packages (from keras->keras-svm) (2.7.1)\n",
49 | "Requirement not upgraded as not directly required: six>=1.9.0 in /usr/local/lib/python3.6/dist-packages (from keras->keras-svm) (1.11.0)\n",
50 | "Requirement not upgraded as not directly required: numpy>=1.9.1 in /usr/local/lib/python3.6/dist-packages (from keras->keras-svm) (1.14.2)\n",
51 | "Requirement not upgraded as not directly required: scipy>=0.14 in /usr/local/lib/python3.6/dist-packages (from keras->keras-svm) (0.19.1)\n"
52 | ]
53 | }
54 | ],
55 | "source": [
56 | "# I've created a pip package containing a wrapper for the model.\n",
57 | "!pip install --upgrade keras-svm"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {
64 | "colab": {
65 | "autoexec": {
66 | "startup": false,
67 | "wait_interval": 0
68 | }
69 | },
70 | "colab_type": "code",
71 | "id": "AyMkuF7nsips"
72 | },
73 | "outputs": [],
74 | "source": [
75 | "import keras\n",
76 | "from keras_svm.model_svm_wrapper import ModelSVMWrapper\n",
77 | "from keras import layers, models, backend\n",
78 | "from keras.datasets import mnist, fashion_mnist\n",
79 | "from keras.utils import to_categorical\n",
80 | "from keras.models import Model\n",
81 | "from keras.engine.topology import Layer\n",
82 | "import matplotlib.pyplot as plt\n",
83 | "from google.colab import files\n",
84 | "import pickle\n",
85 | "import json"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": null,
91 | "metadata": {
92 | "colab": {
93 | "autoexec": {
94 | "startup": false,
95 | "wait_interval": 0
96 | }
97 | },
98 | "colab_type": "code",
99 | "id": "LGqkiYtkCPnY"
100 | },
101 | "outputs": [],
102 | "source": [
103 | "(train_images, train_labels), (test_images, test_labels) = \\\n",
104 | " fashion_mnist.load_data() \n",
105 | " # or mnist.load_data()\n",
106 | "\n",
107 | "train_images = train_images.reshape((60000, 28, 28, 1))\n",
108 | "train_images = train_images.astype('float32') / 255\n",
109 | "\n",
110 | "test_images = test_images.reshape((10000, 28, 28, 1))\n",
111 | "test_images = test_images.astype('float32') / 255"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "metadata": {
117 | "colab_type": "text",
118 | "id": "V7oJ5coGo7Oe"
119 | },
120 | "source": [
121 | "# Build a generic CNN\n",
122 | "\n",
123 | "----\n",
124 | "(based on https://github.com/fchollet/deep-learning-with-python-notebooks)\n",
125 | "\n",
126 | "Creating a simple CNN with three convolutional layers."
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "metadata": {
133 | "colab": {
134 | "autoexec": {
135 | "startup": false,
136 | "wait_interval": 0
137 | }
138 | },
139 | "colab_type": "code",
140 | "id": "0wmrpcg2o7Oe"
141 | },
142 | "outputs": [],
143 | "source": [
144 | "def build_model():\n",
145 | " model = models.Sequential()\n",
146 | " model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))\n",
147 | " model.add(layers.MaxPooling2D((2, 2)))\n",
148 | " model.add(layers.Conv2D(64, (3, 3), activation='relu'))\n",
149 | " model.add(layers.MaxPooling2D((2, 2)))\n",
150 | " model.add(layers.Conv2D(64, (3, 3), activation='relu'))\n",
151 | " model.add(layers.Flatten())\n",
152 | " \n",
153 | " model.add(layers.Dense(64, activation='relu'))\n",
154 | " model.add(layers.Dense(10, activation='softmax'))\n",
155 | " model.compile(optimizer='rmsprop',\n",
156 | " loss='categorical_crossentropy',\n",
157 | " metrics=['accuracy'])\n",
158 | " return model\n",
159 | "\n",
160 | "wrapper = ModelSVMWrapper(build_model())"
161 | ]
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "metadata": {
166 | "colab_type": "text",
167 | "id": "d6LQYxH4o7Oh"
168 | },
169 | "source": [
170 | "Let's display the architecture of our convnet so far:"
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": 21,
176 | "metadata": {
177 | "colab": {
178 | "autoexec": {
179 | "startup": false,
180 | "wait_interval": 0
181 | },
182 | "base_uri": "https://localhost:8080/",
183 | "height": 408
184 | },
185 | "colab_type": "code",
186 | "executionInfo": {
187 | "elapsed": 417,
188 | "status": "ok",
189 | "timestamp": 1525036629639,
190 | "user": {
191 | "displayName": "Daan Luttik",
192 | "photoUrl": "//lh3.googleusercontent.com/-CPv5nanSWKo/AAAAAAAAAAI/AAAAAAAAE3I/gkG30GZ_TJs/s50-c-k-no/photo.jpg",
193 | "userId": "106328974539959585216"
194 | },
195 | "user_tz": -120
196 | },
197 | "id": "7UaZvzRyzUIo",
198 | "outputId": "b2f985ce-274b-4235-fed2-9ba088be73de"
199 | },
200 | "outputs": [
201 | {
202 | "name": "stdout",
203 | "output_type": "stream",
204 | "text": [
205 | "_________________________________________________________________\n",
206 | "Layer (type) Output Shape Param # \n",
207 | "=================================================================\n",
208 | "conv2d_19 (Conv2D) (None, 26, 26, 32) 320 \n",
209 | "_________________________________________________________________\n",
210 | "max_pooling2d_13 (MaxPooling (None, 13, 13, 32) 0 \n",
211 | "_________________________________________________________________\n",
212 | "conv2d_20 (Conv2D) (None, 11, 11, 64) 18496 \n",
213 | "_________________________________________________________________\n",
214 | "max_pooling2d_14 (MaxPooling (None, 5, 5, 64) 0 \n",
215 | "_________________________________________________________________\n",
216 | "conv2d_21 (Conv2D) (None, 3, 3, 64) 36928 \n",
217 | "_________________________________________________________________\n",
218 | "flatten_7 (Flatten) (None, 576) 0 \n",
219 | "_________________________________________________________________\n",
220 | "dense_13 (Dense) (None, 64) 36928 \n",
221 | "_________________________________________________________________\n",
222 | "dense_14 (Dense) (None, 10) 650 \n",
223 | "=================================================================\n",
224 | "Total params: 93,322\n",
225 | "Trainable params: 93,322\n",
226 | "Non-trainable params: 0\n",
227 | "_________________________________________________________________\n"
228 | ]
229 | }
230 | ],
231 | "source": [
232 | "wrapper.model.summary()"
233 | ]
234 | },
235 | {
236 | "cell_type": "markdown",
237 | "metadata": {
238 | "colab_type": "text",
239 | "id": "XRy9LUEio7Os"
240 | },
241 | "source": [
242 | "Prepare the data for the CNN"
243 | ]
244 | },
245 | {
246 | "cell_type": "markdown",
247 | "metadata": {
248 | "colab_type": "text",
249 | "id": "9F7NHlli9nq6"
250 | },
251 | "source": [
252 | "Train model and store intermediate test results"
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": 22,
258 | "metadata": {
259 | "colab": {
260 | "autoexec": {
261 | "startup": false,
262 | "wait_interval": 0
263 | },
264 | "base_uri": "https://localhost:8080/",
265 | "height": 1244
266 | },
267 | "colab_type": "code",
268 | "executionInfo": {
269 | "elapsed": 3310064,
270 | "status": "ok",
271 | "timestamp": 1525039939835,
272 | "user": {
273 | "displayName": "Daan Luttik",
274 | "photoUrl": "//lh3.googleusercontent.com/-CPv5nanSWKo/AAAAAAAAAAI/AAAAAAAAE3I/gkG30GZ_TJs/s50-c-k-no/photo.jpg",
275 | "userId": "106328974539959585216"
276 | },
277 | "user_tz": -120
278 | },
279 | "id": "UyLwib_Jo7Ow",
280 | "outputId": "4ccc51a5-f194-474a-cad9-be204e8982d6"
281 | },
282 | "outputs": [
283 | {
284 | "name": "stdout",
285 | "output_type": "stream",
286 | "text": [
287 | "Starting epoch: 1\n",
288 | "Epoch 1/1\n",
289 | "60000/60000 [==============================] - 12s 195us/step - loss: 0.5336 - acc: 0.8041\n",
290 | "10000/10000 [==============================] - 1s 88us/step\n",
291 | "10000/10000 [==============================] - 1s 131us/step\n",
292 | "10000/10000 [==============================] - 1s 94us/step\n",
293 | "{'with_svm@-2': [0.8876], 'with_svm@-3': [0.8999], 'without_svm': [0.8667]}\n",
294 | "Starting epoch: 2\n",
295 | "Epoch 1/1\n",
296 | "38912/60000 [==================>...........] - ETA: 3s - loss: 0.3295 - acc: 0.880160000/60000 [==============================] - 11s 187us/step - loss: 0.3198 - acc: 0.8834\n",
297 | "10000/10000 [==============================] - 1s 91us/step\n",
298 | "10000/10000 [==============================] - 1s 123us/step\n",
299 | "10000/10000 [==============================] - 1s 91us/step\n",
300 | "{'with_svm@-2': [0.8876, 0.8992], 'with_svm@-3': [0.8999, 0.9046], 'without_svm': [0.8667, 0.8758]}\n",
301 | "Starting epoch: 3\n",
302 | "Epoch 1/1\n",
303 | "51008/60000 [========================>.....] - ETA: 1s - loss: 0.2726 - acc: 0.900660000/60000 [==============================] - 11s 187us/step - loss: 0.2725 - acc: 0.9010\n",
304 | "10000/10000 [==============================] - 1s 89us/step\n",
305 | "10000/10000 [==============================] - 1s 125us/step\n",
306 | "10000/10000 [==============================] - 1s 93us/step\n",
307 | "{'with_svm@-2': [0.8876, 0.8992, 0.9077], 'with_svm@-3': [0.8999, 0.9046, 0.9097], 'without_svm': [0.8667, 0.8758, 0.8939]}\n",
308 | "Starting epoch: 4\n",
309 | "Epoch 1/1\n",
310 | "50432/60000 [========================>.....] - ETA: 1s - loss: 0.2437 - acc: 0.911760000/60000 [==============================] - 11s 187us/step - loss: 0.2421 - acc: 0.9121\n",
311 | "10000/10000 [==============================] - 1s 95us/step\n",
312 | "10000/10000 [==============================] - 1s 118us/step\n",
313 | "10000/10000 [==============================] - 1s 94us/step\n",
314 | "{'with_svm@-2': [0.8876, 0.8992, 0.9077, 0.912], 'with_svm@-3': [0.8999, 0.9046, 0.9097, 0.9095], 'without_svm': [0.8667, 0.8758, 0.8939, 0.9006]}\n",
315 | "Starting epoch: 5\n",
316 | "Epoch 1/1\n",
317 | "47744/60000 [======================>.......] - ETA: 2s - loss: 0.2189 - acc: 0.919460000/60000 [==============================] - 11s 188us/step - loss: 0.2188 - acc: 0.9189\n",
318 | "10000/10000 [==============================] - 1s 89us/step\n",
319 | "10000/10000 [==============================] - 1s 121us/step\n",
320 | "10000/10000 [==============================] - 1s 95us/step\n",
321 | "{'with_svm@-2': [0.8876, 0.8992, 0.9077, 0.912, 0.9115], 'with_svm@-3': [0.8999, 0.9046, 0.9097, 0.9095, 0.9111], 'without_svm': [0.8667, 0.8758, 0.8939, 0.9006, 0.903]}\n",
322 | "Starting epoch: 6\n",
323 | "Epoch 1/1\n",
324 | "44928/60000 [=====================>........] - ETA: 2s - loss: 0.2001 - acc: 0.926360000/60000 [==============================] - 11s 189us/step - loss: 0.2015 - acc: 0.9258\n",
325 | "10000/10000 [==============================] - 1s 90us/step\n",
326 | "10000/10000 [==============================] - 1s 123us/step\n",
327 | "10000/10000 [==============================] - 1s 96us/step\n",
328 | "{'with_svm@-2': [0.8876, 0.8992, 0.9077, 0.912, 0.9115, 0.9109], 'with_svm@-3': [0.8999, 0.9046, 0.9097, 0.9095, 0.9111, 0.9115], 'without_svm': [0.8667, 0.8758, 0.8939, 0.9006, 0.903, 0.8985]}\n",
329 | "Starting epoch: 7\n",
330 | "Epoch 1/1\n",
331 | "41600/60000 [===================>..........] - ETA: 3s - loss: 0.1851 - acc: 0.932860000/60000 [==============================] - 12s 195us/step - loss: 0.1851 - acc: 0.9325\n",
332 | "10000/10000 [==============================] - 1s 96us/step\n",
333 | "10000/10000 [==============================] - 1s 120us/step\n",
334 | "10000/10000 [==============================] - 1s 96us/step\n",
335 | "{'with_svm@-2': [0.8876, 0.8992, 0.9077, 0.912, 0.9115, 0.9109, 0.9122], 'with_svm@-3': [0.8999, 0.9046, 0.9097, 0.9095, 0.9111, 0.9115, 0.9116], 'without_svm': [0.8667, 0.8758, 0.8939, 0.9006, 0.903, 0.8985, 0.9041]}\n",
336 | "Starting epoch: 8\n",
337 | "Epoch 1/1\n",
338 | "40576/60000 [===================>..........] - ETA: 3s - loss: 0.1695 - acc: 0.938060000/60000 [==============================] - 11s 190us/step - loss: 0.1727 - acc: 0.9371\n",
339 | "10000/10000 [==============================] - 1s 91us/step\n",
340 | "10000/10000 [==============================] - 1s 123us/step\n",
341 | "10000/10000 [==============================] - 1s 97us/step\n",
342 | "{'with_svm@-2': [0.8876, 0.8992, 0.9077, 0.912, 0.9115, 0.9109, 0.9122, 0.9145], 'with_svm@-3': [0.8999, 0.9046, 0.9097, 0.9095, 0.9111, 0.9115, 0.9116, 0.912], 'without_svm': [0.8667, 0.8758, 0.8939, 0.9006, 0.903, 0.8985, 0.9041, 0.9062]}\n",
343 | "Starting epoch: 9\n",
344 | "Epoch 1/1\n",
345 | "38912/60000 [==================>...........] - ETA: 3s - loss: 0.1588 - acc: 0.940760000/60000 [==============================] - 11s 188us/step - loss: 0.1580 - acc: 0.9412\n",
346 | "10000/10000 [==============================] - 1s 91us/step\n",
347 | "10000/10000 [==============================] - 1s 121us/step\n",
348 | "10000/10000 [==============================] - 1s 95us/step\n",
349 | "{'with_svm@-2': [0.8876, 0.8992, 0.9077, 0.912, 0.9115, 0.9109, 0.9122, 0.9145, 0.9122], 'with_svm@-3': [0.8999, 0.9046, 0.9097, 0.9095, 0.9111, 0.9115, 0.9116, 0.912, 0.9121], 'without_svm': [0.8667, 0.8758, 0.8939, 0.9006, 0.903, 0.8985, 0.9041, 0.9062, 0.8882]}\n",
350 | "Starting epoch: 10\n",
351 | "Epoch 1/1\n",
352 | "37568/60000 [=================>............] - ETA: 4s - loss: 0.1453 - acc: 0.946360000/60000 [==============================] - 11s 187us/step - loss: 0.1467 - acc: 0.9462\n",
353 | "10000/10000 [==============================] - 1s 93us/step\n",
354 | "10000/10000 [==============================] - 1s 122us/step\n",
355 | "10000/10000 [==============================] - 1s 100us/step\n",
356 | "{'with_svm@-2': [0.8876, 0.8992, 0.9077, 0.912, 0.9115, 0.9109, 0.9122, 0.9145, 0.9122, 0.9123], 'with_svm@-3': [0.8999, 0.9046, 0.9097, 0.9095, 0.9111, 0.9115, 0.9116, 0.912, 0.9121, 0.9109], 'without_svm': [0.8667, 0.8758, 0.8939, 0.9006, 0.903, 0.8985, 0.9041, 0.9062, 0.8882, 0.9062]}\n",
357 | "performance4.json\n"
358 | ]
359 | }
360 | ],
361 | "source": [
362 | "for j in range(5): \n",
363 | " wrapper = ModelSVMWrapper(build_model())\n",
364 | "\n",
365 | " epochs = 10\n",
366 | " performance = {\n",
367 | " \"with_svm@-2\": [],\n",
368 | " \"with_svm@-3\": [],\n",
369 | " \"without_svm\": []\n",
370 | " }\n",
371 | " for i in range(epochs):\n",
372 | " print('Starting epoch: {}'.format(i + 1))\n",
373 | " wrapper.fit(train_images, train_labels, epochs=1, batch_size=64)\n",
374 | " performance[\"with_svm@-3\"].append(wrapper.evaluate(test_images, test_labels))\n",
375 | " performance[\"without_svm\"].append(\n",
376 | " wrapper.model.evaluate(test_images, to_categorical(test_labels))[1])\n",
377 | "\n",
378 | " # Try it for the different SVM\n",
379 | " wrapper.fit_svm(train_images, train_labels, wrapper.model.layers[-2])\n",
380 | " performance[\"with_svm@-2\"].append(wrapper.evaluate(test_images, test_labels))\n",
381 | " \n",
382 | " print(performance)\n",
383 | " filename = 'performance{}.json'.format(j)\n",
384 | " print(filename)\n",
385 | " with open(filename, 'w') as file:\n",
386 | " json.dump(performance, file)"
387 | ]
388 | },
389 | {
390 | "cell_type": "code",
391 | "execution_count": null,
392 | "metadata": {
393 | "colab": {
394 | "autoexec": {
395 | "startup": false,
396 | "wait_interval": 0
397 | }
398 | },
399 | "colab_type": "code",
400 | "id": "GZFCd5D5dGPy"
401 | },
402 | "outputs": [],
403 | "source": [
404 | "for j in range(5): \n",
405 | " filename = 'performance{}.json'.format(j)\n",
406 | " files.download(filename)"
407 | ]
408 | }
409 | ],
410 | "metadata": {
411 | "accelerator": "GPU",
412 | "colab": {
413 | "collapsed_sections": [],
414 | "default_view": {},
415 | "name": "Deep Learning Paper",
416 | "provenance": [],
417 | "toc_visible": true,
418 | "version": "0.3.2",
419 | "views": {}
420 | },
421 | "kernelspec": {
422 | "display_name": "Python 3",
423 | "language": "python",
424 | "name": "python3"
425 | },
426 | "language_info": {
427 | "codemirror_mode": {
428 | "name": "ipython",
429 | "version": 3
430 | },
431 | "file_extension": ".py",
432 | "mimetype": "text/x-python",
433 | "name": "python",
434 | "nbconvert_exporter": "python",
435 | "pygments_lexer": "ipython3",
436 | "version": "3.6.4"
437 | }
438 | },
439 | "nbformat": 4,
440 | "nbformat_minor": 2
441 | }
442 |
--------------------------------------------------------------------------------
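
The training loop in the notebook above writes one JSON file per run (performance0.json through performance4.json), each mapping the three variants ('with_svm@-2', 'with_svm@-3', 'without_svm') to a list of ten per-epoch test accuracies. Below is a minimal sketch, under the assumption that those files sit in the working directory, of reading them back and averaging the per-epoch accuracies across the five runs; it only illustrates the file format and is not the repository's own aggregation code.

import json
import numpy as np

# Read the five per-run result files written by the training loop.
runs = []
for j in range(5):
    with open('performance{}.json'.format(j)) as f:
        runs.append(json.load(f))

# Average the per-epoch test accuracy over runs for each variant.
for variant in ('with_svm@-2', 'with_svm@-3', 'without_svm'):
    mean_per_epoch = np.mean([run[variant] for run in runs], axis=0)
    print(variant, np.round(mean_per_epoch, 4))
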