├── .gitignore ├── LICENSE ├── README.md └── fitted_learning.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Yann Henon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the 
"Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fitted-learning 2 | This is a small Keras implementation of Fitted Learning: Models with Awareness of their Limits. 3 | It shows how to train a simple network that will over-generalize less than a more standard network. 4 | 5 | The dataset consists of point coordinates (x,y) from circles of different radii, as shown in the image below 6 | ![dataset](http://i.imgur.com/kMsvO6m.png) 7 | 8 | The parameter DOO can be changed to tune the degree of generalization of the network, with smaller DOO meaning more generalization. 9 | A DOO of 1 is equivalent to a standard NN with a softmax layer and crossentropy loss. The following images show how the space is classified by the NN for different values of DOO. 10 | The dataset is shown in white rather than the original colors for visibility. 11 | 12 | Class 1 is blue, class 2 is green, class 3 is red. 
The black space shows regions where the probability is low for all classes. 13 | 14 | DOO=1 (standard NN) 15 | 16 | ![1](http://i.imgur.com/yI17Euk.png) 17 | 18 | 19 | DOO=2 20 | 21 | ![2](http://i.imgur.com/RRYHup6.png) 22 | 23 | 24 | DOO=6 25 | 26 | ![6](http://i.imgur.com/OmH56Vm.png) 27 | 28 | 29 | DOO=24 30 | 31 | ![24](http://i.imgur.com/kroBjfe.png) 32 | 33 | 34 | DOO=48 35 | 36 | ![48](http://i.imgur.com/Uasd2wP.png) 37 | -------------------------------------------------------------------------------- /fitted_learning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from keras.models import Model 3 | from keras.layers import Dense, Activation, Flatten, Input 4 | import math 5 | import cv2 6 | 7 | 8 | def build_label(class_idx, n_classes, DOO): 9 | # returns the target for a training instance 10 | label = np.zeros((n_classes * DOO, )) 11 | for ii in range(DOO): 12 | label[ii * n_classes + class_idx] = 1.0 / DOO 13 | return label 14 | 15 | 16 | def infer(probs, DOO, n_classes): 17 | # infer from a test instance 18 | out = np.ones((n_classes,)) 19 | for ii in range(DOO): 20 | for jj in range(n_classes): 21 | out[jj] = out[jj] * probs[jj + ii * n_classes] * DOO 22 | return out 23 | 24 | batch_size = 128 25 | nb_classes = 3 26 | nb_epoch = 20 27 | 28 | DOO = 16 29 | 30 | input_layer = Input(shape=(2,)) 31 | x = Dense(64, activation='relu')(input_layer) 32 | x = Dense(128, activation='relu')(x) 33 | x = Dense(128, activation='relu')(x) 34 | x = Dense(128, activation='relu')(x) 35 | out = Dense(DOO * nb_classes, activation='softmax')(x) 36 | 37 | model = Model(inputs=input_layer, outputs=out) 38 | 39 | model.compile(loss='categorical_crossentropy', 40 | optimizer='adam') 41 | 42 | 43 | # Create the training data 44 | X = [] 45 | Y = [] 46 | 47 | rads = [1.0, 0.5, 0.25] 48 | 49 | for i in np.linspace(0, 2*math.pi, 1000): 50 | for ix, rad in enumerate(rads): 51 | x = rad * math.cos(i) 52 | y = rad * 
math.sin(i) 53 | X.append([x, y]) 54 | Y.append(build_label(ix, nb_classes, DOO)) 55 | 56 | X_train = np.array(X) 57 | Y_train = np.array(Y) 58 | 59 | model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, verbose=1) 60 | 61 | # plot the results 62 | x_min, x_max = -2, 2 63 | y_min, y_max = x_min, x_max 64 | h = 0.01 65 | num_px = int((x_max - x_min) / h) 66 | 67 | xx, yy = np.meshgrid(np.arange(x_min, x_max, h), 68 | np.arange(y_min, y_max, h)) 69 | 70 | 71 | Z = model.predict(np.c_[xx.ravel(), yy.ravel()]) 72 | P = np.zeros((num_px*num_px, 3)) 73 | 74 | for i in range(Z.shape[0]): 75 | P[i, :] = infer(Z[i, :], DOO, nb_classes) 76 | 77 | P = np.array(P) 78 | P = np.reshape(P, xx.shape + (nb_classes,)) 79 | 80 | img = (255*P).astype(np.uint8) 81 | 82 | scale = 0.25 * (x_max - x_min) / h 83 | for rad in rads: 84 | cv2.circle(img, (num_px/2, num_px/2), int(rad * scale), (255,255,255), thickness=2) 85 | 86 | cv2.imshow('preds', img) 87 | cv2.waitKey(0) 88 | --------------------------------------------------------------------------------