├── raw
│   └── .keep
├── test
│   └── .keep
├── train
│   └── .keep
├── checkpoints
│   └── .keep
├── iris
│   ├── __init__.py
│   ├── log.py
│   ├── dataset.py
│   └── network.py
├── web
│   ├── __init__.py
│   ├── models.py
│   ├── iris.py
│   └── swagger.json
├── .gitignore
├── requirements.txt
├── iris-network-test.py
├── iris-network-predict.py
├── iris-network-download.py
├── iris-network-train.py
└── README.md
/raw/.keep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/.keep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train/.keep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/checkpoints/.keep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/iris/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/web/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.pyc
3 | *__pycache__
4 | checkpoints/*
5 | test/*.csv
6 | train/*.csv
7 | raw/*.csv
8 | .DS_Store
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # By default I'm using python3 version of tensorflow.
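# Note: the wheels below target macOS; on another platform, swap in the
# matching TensorFlow 0.6 wheel for your OS and Python version (see the
# README's link to TensorFlow's installation guide).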
2 | https://storage.googleapis.com/tensorflow/mac/tensorflow-0.6.0-py3-none-any.whl 3 | # https://storage.googleapis.com/tensorflow/mac/tensorflow-0.6.0-py2-none-any.whl 4 | 5 | # Used to download the iris dataset 6 | requests 7 | tqdm 8 | 9 | # Used by the REST API 10 | falcon 11 | gunicorn 12 | simplejson 13 | -------------------------------------------------------------------------------- /iris-network-test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from iris import network 4 | from iris import log 5 | 6 | 7 | _logger = log.get_logger() 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser( 12 | description="""Use the test dataset to check for the total accuracy of 13 | the highest iteration's checkpoint model.""") 14 | 15 | parser.add_argument( 16 | "--test-dir", 17 | default="./test", 18 | type=str, 19 | help="Directory containing CSV files used in testing.") 20 | parser.add_argument( 21 | "--checkpoint-dir", 22 | default="./checkpoints", 23 | type=str, 24 | help="Location to restore checkpoint files.") 25 | parser.add_argument("-v", "--verbosity", action="count", default=0) 26 | args = parser.parse_args() 27 | 28 | log.set_verbosity(args.verbosity) 29 | 30 | test_features, test_species = network.read_data_set( 31 | "{test_dir}/*.csv".format(test_dir=args.test_dir)) 32 | 33 | network.test( 34 | test_features, 35 | test_species, 36 | args.checkpoint_dir) 37 | -------------------------------------------------------------------------------- /iris/log.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import logging 4 | 5 | """ 6 | Base logging support which also edits the level of logging used in TensorFlow. 7 | """ 8 | 9 | # Default logger for keeping track of steps. 10 | _logger = logging.getLogger("iris") 11 | 12 | _handler = logging.StreamHandler() 13 | _handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT, None)) 14 | _logger.addHandler(_handler) 15 | 16 | 17 | def get_logger(): 18 | """ 19 | Allow access to the global logger used in the application. 20 | """ 21 | return _logger 22 | 23 | 24 | def set_verbosity(verbosity): 25 | """ 26 | Adjust this logger's verbosity level on a scale which is: 27 | 0 => Error only logging 28 | 1 => Info logging 29 | Anything else => Debug logging 30 | 31 | Parameters 32 | ---------- 33 | verbosity : int 34 | Level of messages to be reported. 35 | """ 36 | log_level = logging.DEBUG 37 | if verbosity == 0: 38 | log_level = logging.ERROR 39 | elif verbosity == 1: 40 | log_level = logging.INFO 41 | 42 | _logger.setLevel(log_level) 43 | tf.logging.set_verbosity(log_level) 44 | -------------------------------------------------------------------------------- /iris-network-predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from iris import network 4 | from iris import log 5 | 6 | 7 | _logger = log.get_logger() 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser( 12 | description="""Predict the species of an Iris based on its elements""") 13 | 14 | parser.add_argument( 15 | "--checkpoint-dir", 16 | default="./checkpoints", 17 | type=str, 18 | help="Location to restore checkpoint files.") 19 | parser.add_argument( 20 | "--feature", 21 | required=True, 22 | help=""" 23 | CSV list with 4 elements as sepal_length, sepal_width, 24 | petal_length and petal_width e.g. 
25 | 1.5,2.3,4.5,6.7""", 26 | type=lambda a: [float(l) for l in a.split(",")]) 27 | parser.add_argument("-v", "--verbosity", action="count", default=0) 28 | args = parser.parse_args() 29 | 30 | log.set_verbosity(args.verbosity) 31 | 32 | features = [args.feature] 33 | y = network.predict(features, args.checkpoint_dir) 34 | 35 | for i in range(len(features)): 36 | feature = features[i] 37 | confidence = [round(p, 2) for p in y[0][i]] 38 | _logger.info( 39 | "Prediction for %s is %s with y of: %s", 40 | feature, 41 | y[1][i], 42 | confidence) 43 | -------------------------------------------------------------------------------- /iris-network-download.py: -------------------------------------------------------------------------------- 1 | from iris import dataset 2 | from iris import log 3 | 4 | import argparse 5 | 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser( 9 | description="""Download the Iris flower dataset from UCI to be used in 10 | training a feed forward neural network.""") 11 | parser.add_argument( 12 | "--raw-dir", 13 | default="./raw", 14 | type=str, 15 | help="Location to download the Iris flower dataset to.") 16 | parser.add_argument( 17 | "--test-dir", 18 | default="./test", 19 | type=str, 20 | help="Location to place the test dataset generated from raw data.") 21 | parser.add_argument( 22 | "--train-dir", 23 | default="./train", 24 | type=str, 25 | help="Location to place the train dataset generated from raw data.") 26 | parser.add_argument( 27 | "--use-backup-data-url", 28 | action="store_true", 29 | help="""Use an alternate location (other than UCI) to download 30 | IRIS data from.""") 31 | parser.add_argument("-v", "--verbosity", action="count", default=0) 32 | args = parser.parse_args() 33 | 34 | log.set_verbosity(args.verbosity) 35 | 36 | dataset.prepare( 37 | args.raw_dir, 38 | args.test_dir, 39 | args.train_dir, 40 | args.use_backup_data_url) 41 | -------------------------------------------------------------------------------- /iris-network-train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from iris import network 4 | from iris import log 5 | 6 | 7 | _logger = log.get_logger() 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser( 12 | description="""Train a neural network using the Iris dataset.""") 13 | 14 | parser.add_argument( 15 | "--train-dir", 16 | default="./train", 17 | type=str, 18 | help="Directory containing CSV files used in training.") 19 | parser.add_argument( 20 | "--checkpoint-dir", 21 | default="./checkpoints", 22 | type=str, 23 | help="Location to save checkpoint files.") 24 | parser.add_argument( 25 | "--checkpoint-save-every", 26 | type=int, 27 | help="Save a checkpoint every X iterations.") 28 | parser.add_argument( 29 | "--train-iterations", 30 | default=20000, 31 | type=int, 32 | help="Number of train iterations to run.") 33 | parser.add_argument("-v", "--verbosity", action="count", default=0) 34 | args = parser.parse_args() 35 | 36 | log.set_verbosity(args.verbosity) 37 | 38 | _logger.info( 39 | "Reading CSV files from %s.", args.train_dir) 40 | 41 | train_features, train_species = network.read_data_set( 42 | "{train_dir}/*.csv".format(train_dir=args.train_dir)) 43 | 44 | network.train( 45 | train_features, 46 | train_species, 47 | args.checkpoint_dir, 48 | args.train_iterations, 49 | args.checkpoint_save_every) 50 | -------------------------------------------------------------------------------- /README.md: 
--------------------------------------------------------------------------------
1 | # TensorFlow: From CSV to API
2 | 
3 | Code from [this tutorial](https://eerwitt.github.io/2016/01/14/tensorflow-from-csv-to-api/) on building a REST API that serves a model trained with TensorFlow. The code walks through the entire process, from downloading a CSV training file to hosting the trained model behind an API.
4 | 
5 | ## Installation
6 | 
7 | Using Python 3 in a [virtual environment](https://virtualenv.readthedocs.org/en/latest/), install the required packages via [pip](https://pip.pypa.io/en/stable/).
8 | 
9 | ```zsh
10 | pip install -r requirements.txt
11 | ```
12 | 
13 | You might need to edit the version of TensorFlow in the requirements file; details are in [TensorFlow's installation guide](https://www.tensorflow.org/versions/0.6.0/get_started/os_setup.html).
14 | 
15 | ## Running
16 | 
17 | There are five commands, each of which accepts a number of optional arguments.
18 | 
19 | | Command | Explanation |
20 | | :--- | :--- |
21 | | `python iris-network-download.py -vv` | Download example Iris data from UCI and split the CSV into test and train datasets, with each Iris species converted to a one-hot vector representation. |
22 | | `python iris-network-train.py -vv` | Train a feed-forward neural network and save checkpoint models to a directory. |
23 | | `python iris-network-test.py -vv` | Test the trained neural network to measure its accuracy. |
24 | | `python iris-network-predict.py --feature 5.5,4.2,1.4,0.2 -vv` | Predict the species of an Iris described by the features `5.5,4.2,1.4,0.2`, which correspond to Sepal Length, Sepal Width, Petal Length and Petal Width. |
25 | | `gunicorn web.iris:api` | Start a `gunicorn` webserver to host the Falcon API. |
26 | 
27 | ## Testing
28 | 
29 | This code is not meant for production use and does not include the tests needed to validate its functionality.
30 | 
31 | If you're writing TensorFlow tests, there [is a useful Python class](https://github.com/tensorflow/tensorflow/blob/f13d006e097c4e8010a4ad3ad2018a0369f5dc19/tensorflow/python/framework/test_util.py) that makes testing a graph fairly straightforward.
32 | 
--------------------------------------------------------------------------------
/web/models.py:
--------------------------------------------------------------------------------
1 | """
2 | Objects which describe the data the API responds with. These were developed
3 | based on the Swagger definitions found at:
4 | ./swagger.json or ./swagger.yml
5 | """
6 | 
7 | 
8 | class Iris(object):
9 |     """
10 |     See ./swagger.json#models for further information.
11 |     """
12 |     def __init__(self, species=None, features=None):
13 |         """
14 |         Create a new Iris model.
15 | 
16 |         Parameters
17 |         ----------
18 |         species : str, optional
19 |             The species of the Iris if it's known. When requesting a prediction
20 |             for an unknown Iris species this value will be None.
21 |         features : list(float[4]), optional
22 |             Known features of this Iris, used to classify its species.
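
        Examples
        --------
        A minimal, illustrative construction (the measurements are arbitrary):

        >>> iris = Iris(features=[5.1, 3.5, 1.4, 0.2])
        >>> iris.sepal_length
        5.1
        >>> iris.species is None
        True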
23 | """ 24 | self.species = species 25 | 26 | if features is not None: 27 | self._sepal_length = features[0] 28 | self._sepal_width = features[1] 29 | self._petal_length = features[2] 30 | self._petal_width = features[3] 31 | else: 32 | self._sepal_length = None 33 | self._sepal_width = None 34 | self._petal_length = None 35 | self._petal_width = None 36 | 37 | def _asdict(self): 38 | return { 39 | 'species': self.species, 40 | 'sepal_length': self.sepal_length, 41 | 'sepal_width': self.sepal_width, 42 | 'petal_length': self.petal_length, 43 | 'petal_width': self.petal_width 44 | } 45 | 46 | @property 47 | def sepal_length(self): 48 | return self._sepal_length 49 | 50 | @property 51 | def sepal_width(self): 52 | return self._sepal_width 53 | 54 | @property 55 | def petal_length(self): 56 | return self._petal_length 57 | 58 | @property 59 | def petal_width(self): 60 | return self._petal_width 61 | 62 | 63 | class Prediction(object): 64 | """ 65 | See ./swagger.json#models for further information. 66 | """ 67 | def __init__(self, iris, y): 68 | self._iris = iris 69 | self._y = y 70 | 71 | def _asdict(self): 72 | return { 73 | 'iris': self.iris._asdict(), 74 | 'y': self.y 75 | } 76 | 77 | @property 78 | def iris(self): 79 | return self._iris 80 | 81 | @property 82 | def y(self): 83 | return self._y 84 | 85 | 86 | class PredictionRequest(object): 87 | """ 88 | See ./swagger.json#models for further information. 89 | """ 90 | def __init__(self, iris_features): 91 | self._uuid = None 92 | self._prediction = None 93 | self._status = "pending" 94 | self._iris_features = iris_features 95 | 96 | def _asdict(self): 97 | prediction = None 98 | if self.prediction: 99 | prediction = self.prediction._asdict() 100 | 101 | return { 102 | 'uuid': self.uuid, 103 | 'status': self.status, 104 | 'iris_features': self.iris_features._asdict(), 105 | 'prediction': prediction 106 | } 107 | 108 | @property 109 | def uuid(self): 110 | return self._uuid 111 | 112 | @uuid.setter 113 | def uuid(self, value): 114 | self._uuid = value 115 | 116 | @property 117 | def prediction(self): 118 | return self._prediction 119 | 120 | @prediction.setter 121 | def prediction(self, value): 122 | self._prediction = value 123 | self._status = "fulfilled" 124 | 125 | @property 126 | def status(self): 127 | return self._status 128 | 129 | @property 130 | def iris_features(self): 131 | return self._iris_features 132 | -------------------------------------------------------------------------------- /web/iris.py: -------------------------------------------------------------------------------- 1 | from iris import network 2 | from iris import dataset 3 | 4 | from web.models import Iris 5 | from web.models import Prediction 6 | from web.models import PredictionRequest 7 | 8 | import simplejson as json 9 | 10 | import falcon 11 | import uuid 12 | 13 | 14 | class PredictionRequestStorageEngine(object): 15 | """ 16 | To avoid extra dependencies, this is a fake storage engine based on the 17 | Falcon example docs. 18 | 19 | Notes 20 | ----- 21 | Never use this in production, it is in memory only! 22 | 23 | See Also 24 | -------- 25 | http://falcon.readthedocs.org/en/latest/user/quickstart.html 26 | """ 27 | def __init__(self): 28 | self._store = {} 29 | 30 | def get_prediction_request(self, prediction_request_uuid): 31 | """ 32 | Get a prediction request based on its UUID. 33 | 34 | Parameters 35 | ---------- 36 | prediction_request_uuid : str 37 | UUID for the requested PredictionRequest. 
38 | 39 | Returns 40 | ------- 41 | prediction_request : PredictionRequest 42 | The PredictionRequest with the requested UUID. 43 | """ 44 | return self._store[prediction_request_uuid] 45 | 46 | def add_prediction_request(self, prediction_request): 47 | """ 48 | "Save" a prediction request into our in memory fake storage. 49 | 50 | Parameters 51 | ---------- 52 | prediction_request : PredictionRequest 53 | An Iris with features but no species to predict its species from. 54 | 55 | Returns 56 | ------- 57 | prediction_request : PredictionRequest 58 | The prediction request which was just "saved". 59 | """ 60 | prediction_request_uuid = str(uuid.uuid4()) 61 | prediction_request.uuid = prediction_request_uuid 62 | self._store[prediction_request_uuid] = prediction_request 63 | 64 | return prediction_request 65 | 66 | 67 | def allow_swagger_editor(req, res, resource): 68 | """ 69 | Allow cross origin requests from the Swagger Online editor at: 70 | http://editor.swagger.io/ 71 | 72 | Parameters 73 | ---------- 74 | req : FalconRequest 75 | Request being made through this Falcon middleware. 76 | res : FalconResponse 77 | Response to send back to clients from this Falcon middleware. 78 | resource : FalconResource 79 | Resource being requested. 80 | """ 81 | res.set_header( 82 | "Access-Control-Allow-Origin", "http://editor.swagger.io") 83 | res.set_header("Access-Control-Allow-Methods", "POST, GET, OPTIONS") 84 | res.set_header("Access-Control-Allow-Headers", "Content-Type") 85 | 86 | 87 | class PredictionRequestsResource(object): 88 | """ 89 | List resource for PredictionRequests. 90 | """ 91 | def __init__(self, db): 92 | self._db = db 93 | 94 | def on_post(self, req, res): 95 | """ 96 | Create a PredictionRequest based on the JSON body. 97 | 98 | See ./swagger.json for details. 99 | """ 100 | body = json.loads(req.stream.read().decode('utf8')) 101 | iris_test = Iris(features=[ 102 | body["sepal_length"], 103 | body["sepal_width"], 104 | body["petal_length"], 105 | body["petal_width"]]) 106 | prediction_request = PredictionRequest( 107 | iris_features=iris_test) 108 | 109 | self._db.add_prediction_request(prediction_request) 110 | 111 | res.status = falcon.HTTP_CREATED 112 | res.body = json.dumps( 113 | prediction_request._asdict(), 114 | use_decimal=True) 115 | 116 | 117 | class PredictionRequestResource(object): 118 | def __init__(self, db, onehot_species, net, sess): 119 | self._db = db 120 | self._onehot_species = onehot_species 121 | 122 | # Note, this session won't cleanup after itself without restarting the 123 | # webserver. 124 | self.net = net 125 | self.sess = sess 126 | 127 | def on_get(self, req, res, prediction_uuid): 128 | """ 129 | Get the results from a prediction based on a prediction request's 130 | features. 131 | 132 | See ./swagger.json for details. 
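
        Examples
        --------
        A typical client flow, assuming gunicorn's default bind address of
        localhost:8000; the UUID comes from the earlier POST response:

            POST http://localhost:8000/predictionrequest/
                 {"sepal_length": 5.1, "sepal_width": 3.5,
                  "petal_length": 1.4, "petal_width": 0.2}
            GET  http://localhost:8000/predictionrequest/<uuid>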
133 | """ 134 | prediction_request = self._db.get_prediction_request(prediction_uuid) 135 | 136 | iris_test = prediction_request.iris_features 137 | 138 | y = network.predict_with_session( 139 | self.net, 140 | [[ 141 | iris_test.sepal_length, 142 | iris_test.sepal_width, 143 | iris_test.petal_length, 144 | iris_test.petal_width 145 | ]], 146 | self.sess) 147 | 148 | predicted_onehot = network.onehot_from_argmax(y[1]) 149 | predicted_species = self._onehot_species[predicted_onehot] 150 | 151 | iris_found = Iris(species=predicted_species) 152 | 153 | prediction_tensor = list(map(lambda p: round(float(p), 4), y[0][0])) 154 | prediction = Prediction( 155 | iris=iris_found, 156 | y=prediction_tensor) 157 | 158 | prediction_request.prediction = prediction 159 | 160 | res.status = falcon.HTTP_OK 161 | res.body = json.dumps( 162 | prediction_request._asdict(), 163 | use_decimal=True) 164 | 165 | 166 | # Global session for predicting with, if we reuse sessions then the checkpoint 167 | # restore will create duplicate variables and fail. 168 | sess, net = network.predict_init("./checkpoints") 169 | db = PredictionRequestStorageEngine() 170 | 171 | onehot_species = {} 172 | for species, onehot in dataset.read_species_onehot_csv("./raw"): 173 | onehot_species[onehot] = species 174 | 175 | api = falcon.API(after=[allow_swagger_editor]) 176 | api.add_route( 177 | '/predictionrequest/', 178 | PredictionRequestsResource(db)) 179 | 180 | api.add_route( 181 | '/predictionrequest/{prediction_uuid}', 182 | PredictionRequestResource(db, onehot_species, net, sess)) 183 | -------------------------------------------------------------------------------- /web/swagger.json: -------------------------------------------------------------------------------- 1 | { 2 | "swagger": "2.0", 3 | "host": "localhost:8000", 4 | "basePath": "/", 5 | "schemes": [ 6 | "http" 7 | ], 8 | "consumes": [ 9 | "application/json" 10 | ], 11 | "produces": [ 12 | "application/json" 13 | ], 14 | "externalDocs": { 15 | "description": "Tutorial this API was created based on.", 16 | "url": "https://eerwitt.github.io/2016/01/14/tensorflow-from-csv-to-api/" 17 | }, 18 | "info": { 19 | "version": "0.0.1a", 20 | "title": "Iris", 21 | "description": "Predict the species (class) of an Iris based on data provided by\n[Iris Flower Wiki](https://en.wikipedia.org/wiki/Iris_flower_data_set) and\n[UCI](https://archive.ics.uci.edu/ml/datasets/Iris).\n", 22 | "termsOfService": "None", 23 | "contact": { 24 | "name": "Erik", 25 | "url": "https://eerwitt.github.io" 26 | }, 27 | "license": { 28 | "name": "Beerware", 29 | "url": "https://en.wikipedia.org/wiki/Beerware" 30 | } 31 | }, 32 | "definitions": { 33 | "UnknownIris": { 34 | "$ref": "#/definitions/Iris" 35 | }, 36 | "Iris": { 37 | "type": "object", 38 | "description": "A type of flower.", 39 | "properties": { 40 | "species": { 41 | "description": "Species of Iris based on the Iris dataset.", 42 | "type": "string", 43 | "enum": [ 44 | "Iris Setosa", 45 | "Iris Versicolor", 46 | "Iris Virginica" 47 | ] 48 | }, 49 | "sepal_length": { 50 | "description": "Sepal Length of an Iris", 51 | "type": "number", 52 | "format": "float" 53 | }, 54 | "sepal_width": { 55 | "description": "Sepal Width of an Iris", 56 | "type": "number", 57 | "format": "float" 58 | }, 59 | "petal_length": { 60 | "description": "Petal Length of an Iris", 61 | "type": "number", 62 | "format": "float" 63 | }, 64 | "petal_width": { 65 | "description": "Petal Width of an Iris", 66 | "type": "number", 67 | "format": "float" 68 | } 69 | } 70 | }, 71 
| "Prediction": { 72 | "type": "object", 73 | "description": "Results found from attempting to match a requested set of Iris features\nwith our trained model.\n", 74 | "properties": { 75 | "iris": { 76 | "description": "Class which most closely matched the requested features.", 77 | "$ref": "#/definitions/Iris" 78 | }, 79 | "y": { 80 | "type": "array", 81 | "description": "Y's rank 1 tensor of results in order Setsosa, Versicolor and\nVirginica.\n", 82 | "items": { 83 | "type": "number", 84 | "format": "float" 85 | } 86 | } 87 | } 88 | }, 89 | "PredictionRequest": { 90 | "type": "object", 91 | "description": "Request to predict the class of an Iris based on its attributes.\n", 92 | "properties": { 93 | "uuid": { 94 | "description": "UUID4 identifier for this PredictionRequest.", 95 | "type": "string", 96 | "format": "uuid" 97 | }, 98 | "prediction": { 99 | "$ref": "#/definitions/Prediction" 100 | }, 101 | "status": { 102 | "description": "Changes based on when this request is picked from a FIFO.", 103 | "type": "string", 104 | "enum": [ 105 | "fulfilled", 106 | "pending", 107 | "error" 108 | ] 109 | }, 110 | "iris_features": { 111 | "description": "An Iris with no class but petal and sepal information filled out.\n", 112 | "$ref": "#/definitions/Iris" 113 | } 114 | } 115 | } 116 | }, 117 | "paths": { 118 | "/predictionrequest/{uuid}": { 119 | "get": { 120 | "description": "`PredictionRequest`s are not returned immediately. They are processed\nin a queue which require continual checking until they're ready.\n\nAs a `PredictionRequest` is fulfilled, it will change the status and\ninclude a `Prediction` which is the result generated by testing the\nfeatures against the trained model `y`.\n", 121 | "parameters": [ 122 | { 123 | "name": "uuid", 124 | "in": "path", 125 | "type": "string", 126 | "description": "UUID of prediction request.", 127 | "required": true 128 | } 129 | ], 130 | "responses": { 131 | "200": { 132 | "description": "Successful response", 133 | "schema": { 134 | "$ref": "#/definitions/PredictionRequest" 135 | } 136 | } 137 | } 138 | } 139 | }, 140 | "/predictionrequest/": { 141 | "post": { 142 | "parameters": [ 143 | { 144 | "name": "UnknownIris", 145 | "in": "body", 146 | "description": "Request a prediction of what type of Iris this class relates to.\n", 147 | "schema": { 148 | "$ref": "#/definitions/UnknownIris" 149 | }, 150 | "required": true 151 | } 152 | ], 153 | "description": "Request Iris to make a prediction based on a set of features.\n", 154 | "responses": { 155 | "201": { 156 | "description": "Successfully requested a prediction of the given features.", 157 | "schema": { 158 | "$ref": "#/definitions/PredictionRequest" 159 | } 160 | } 161 | } 162 | } 163 | } 164 | } 165 | } 166 | -------------------------------------------------------------------------------- /iris/dataset.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import os.path 3 | import random 4 | import math 5 | import csv 6 | 7 | from tqdm import tqdm 8 | 9 | from collections import OrderedDict 10 | 11 | from iris import log 12 | 13 | """ 14 | Work with downloading the original Iris dataset and convert the CSV into the 15 | fields our model expects to parse. 
16 | """ 17 | 18 | _logger = log.get_logger() 19 | 20 | _MAIN_IRIS_DATA_URL = \ 21 | "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data" 22 | 23 | # A backup site including the iris dataset which is similar to that found on 24 | # UCI's site, included because archive.ics.uci.edu was having connection issues 25 | # while creating this script. 26 | _BACKUP_IRIS_DATA_URL = \ 27 | "http://cs.joensuu.fi/sipu/datasets/iris.data.txt" 28 | 29 | _SEPAL_CSV_FIELDNAMES = [ 30 | "SepalLength", 31 | "SepalWidth", 32 | "PetalLength", 33 | "PetalWidth", 34 | "Species", 35 | ] 36 | 37 | _ONEHOT_CSV_FIELDNAMES = [ 38 | "Species", 39 | "OneHot", 40 | ] 41 | 42 | 43 | def parse_raw_iris_csv(raw_csv_filename): 44 | """ 45 | A raw CSV is a file downloaded in the format expected from the UCI archive. 46 | 47 | Parameters 48 | ---------- 49 | raw_csv_filename : str 50 | Relative filename of the CSV file downloaded from UCI. 51 | 52 | Yields 53 | ------ 54 | row : dict 55 | Each row of the CSV with keys which match _SEPAL_CSV_FIELDNAMES. 56 | """ 57 | _logger.info("Opening raw CSV file %s.", raw_csv_filename) 58 | with open(raw_csv_filename, "r") as raw_csv_file: 59 | reader = csv.DictReader( 60 | raw_csv_file, 61 | fieldnames=_SEPAL_CSV_FIELDNAMES) 62 | 63 | for row in reader: 64 | _logger.debug("Raw Row: %s", row) 65 | yield row 66 | 67 | 68 | def download_iris_data(output_location, use_backup_iris_data_url): 69 | """ 70 | Download the Iris dataset using the requests library and showing progress 71 | updates using tqdm. 72 | 73 | Parameters 74 | ---------- 75 | output_location : str 76 | Directory to save the downloaded CSV into. The filename will be 77 | overwritten with a name of iris-data-raw.csv. 78 | use_backup_iris_data_url : bool 79 | If UCI's archive is offline, there is a backup URL which can be used to 80 | download the information from. 81 | 82 | Returns 83 | ------- 84 | output_filename : str 85 | The filename of downloaded file or the existing file if it already 86 | exists. 87 | downloaded : bool 88 | If no raw CSV is found, the file will be downloaded and this will return 89 | True, otherwise it is False. 90 | 91 | Notes 92 | ----- 93 | It's important to notice that if the file exists already, it won't be 94 | downloaded a second time. 95 | """ 96 | output_filename = "{ol}/iris-data-raw.csv".format( 97 | ol=output_location) 98 | 99 | if os.path.exists(output_filename): 100 | _logger.info("Downloaded files already exist, skipping download.") 101 | return output_filename, False 102 | 103 | _logger.info("Downloading files to %s", output_filename) 104 | 105 | url = _MAIN_IRIS_DATA_URL 106 | if use_backup_iris_data_url: 107 | _logger.debug("Using backup URL.") 108 | url = _BACKUP_IRIS_DATA_URL 109 | 110 | iris_data_response = requests.get(url, stream=True) 111 | with open(output_filename, "wb") as iris_data_output: 112 | for block in tqdm(iris_data_response.iter_content()): 113 | iris_data_output.write(block) 114 | 115 | _logger.debug("Finished downloading new Iris data.") 116 | return output_filename, True 117 | 118 | 119 | def generate_onehot_from_species(raw_csv_filename): 120 | """ 121 | Parse the CSV file and convert each row's Species into a one-hot vector in 122 | order to be used by TensorFlow in our model. 123 | 124 | Parameters 125 | ---------- 126 | raw_csv_filename : str 127 | Relative path to the downloaded CSV file including information on Iris' 128 | from Fishker's dataset. 
129 | 130 | Returns 131 | ------- 132 | unique_species : dict 133 | Mapping between Iris species names and their one-hot representation. 134 | """ 135 | unique_species = OrderedDict() 136 | for row in parse_raw_iris_csv(raw_csv_filename): 137 | species = row["Species"] 138 | if species not in unique_species: 139 | unique_species[species] = None 140 | 141 | unique_species_count = len(unique_species) 142 | 143 | i = 0 144 | for species_name in unique_species.keys(): 145 | unique_species[species_name] = \ 146 | "%0*d" % (unique_species_count, 10 ** i) 147 | i += 1 148 | 149 | return unique_species 150 | 151 | 152 | def write_sepal_csv(filename, rows): 153 | """ 154 | Writes a CSV in a format which we specify as: 155 | Sepal Lengh, Sepal Width, Petal Length, Petal Width, One-Hot Species 156 | 157 | Parameters 158 | ---------- 159 | filename : str 160 | Relative filename of output CSV. 161 | rows : dict 162 | Dictionary with keys matching _SEPAL_CSV_FIELDNAMES. 163 | 164 | Notes 165 | ----- 166 | See comments on using lineterminator, this is important when using tf 0.6. 167 | """ 168 | with open(filename, "w") as output_file: 169 | # Specifying the newline character is required with tf 0.6 otherwise 170 | # the CSV reader will fail to read our CSVs. 171 | writer = csv.DictWriter( 172 | output_file, 173 | lineterminator="\n", 174 | fieldnames=_SEPAL_CSV_FIELDNAMES) 175 | 176 | writer.writeheader() 177 | writer.writerows(rows) 178 | 179 | 180 | def write_species_onehot_csv(raw_dir, species_onehot): 181 | """ 182 | Writes a CSV in the format (Species, OneHot Representation) to match the 183 | one-hot version of the species class. 184 | 185 | Parameters 186 | ---------- 187 | raw_dir : str 188 | Relative location to store CSV which will be saved with the raw UCI 189 | download. 190 | species_onehot : OrderedDict 191 | An ordered dict mapping a species name to its one-hot representation as 192 | a string. 193 | 194 | Returns 195 | ------- 196 | filename : str 197 | The filename which the CSV was saved to. 198 | """ 199 | filename = "{dir}/species-onehot.csv".format(dir=raw_dir) 200 | with open(filename, "w") as output_file: 201 | writer = csv.DictWriter( 202 | output_file, 203 | lineterminator="\n", 204 | fieldnames=_ONEHOT_CSV_FIELDNAMES) 205 | 206 | writer.writeheader() 207 | for species, onehot in species_onehot.items(): 208 | writer.writerow({"Species": species, "OneHot": onehot}) 209 | 210 | return filename 211 | 212 | 213 | def read_species_onehot_csv(raw_dir): 214 | """ 215 | Reads the species one-hot vector from the CSV file which was saved while 216 | the raw UCI data was being prepared for training. 217 | 218 | Parameters 219 | ---------- 220 | raw_dir : str 221 | Directory containing the species one-hot CSV. 222 | 223 | Yields 224 | ------ 225 | species_onehot : tuple(species, onehot) 226 | Each row of the CSV translated into species and one-hot. 227 | """ 228 | filename = "{dir}/species-onehot.csv".format(dir=raw_dir) 229 | with open(filename, "r") as species_csv_file: 230 | reader = csv.DictReader( 231 | species_csv_file, 232 | fieldnames=_ONEHOT_CSV_FIELDNAMES) 233 | 234 | for row in reader: 235 | yield row["Species"], row["OneHot"] 236 | 237 | 238 | def split_test_train(raw_csv_filename, species_onehot, test_dir, train_dir): 239 | """ 240 | Split a CSV downloaded from UCI into a test and train dataset. We use 90% of 241 | the data for training and no cross validation. The reasoning is because 242 | there are only 150 examples in Fishker's dataset. 
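    Each row's species label is swapped for its one-hot string from
    ``species_onehot``, the rows are shuffled, and the 90/10 split is applied;
    with the full 150-row dataset that yields 135 training rows and 15 test
    rows.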
243 | 244 | Parameters 245 | ---------- 246 | raw_csv_filename : str 247 | Downloaded CSV from UCI. 248 | species_onehot : dict 249 | Dictionary with keys of Iris species and a value of their one-hot 250 | representation as a string. 251 | test_dir : str 252 | Relative directory to store our test dataset. 253 | train_dir : str 254 | Relative directory to store our train dataset. 255 | """ 256 | examples = [] 257 | for row in parse_raw_iris_csv(raw_csv_filename): 258 | current_species = row["Species"] 259 | iris_with_onehot = row 260 | # Convert the species to be the one-hot representation. 261 | iris_with_onehot["Species"] = species_onehot[current_species] 262 | 263 | examples.append(iris_with_onehot) 264 | 265 | random.shuffle(examples) 266 | number_of_examples = len(examples) 267 | 268 | # We're not doing a cross validation set because the Iris dataset is so 269 | # small. 270 | # Using 90% of the examples for training and the remaining for test. 271 | train_count = math.ceil(number_of_examples * 0.9) 272 | 273 | write_sepal_csv( 274 | "{train_dir}/iris-data-train.csv".format(train_dir=train_dir), 275 | examples[0:train_count]) 276 | write_sepal_csv( 277 | "{test_dir}/iris-data-test.csv".format(test_dir=test_dir), 278 | examples[train_count:]) 279 | 280 | 281 | def prepare(raw_dir, test_dir, train_dir, use_backup_iris_data_url=False): 282 | """ 283 | Prepare our test and train dataset by downloading raw Iris data, create 284 | one-hot representation of Iris species and split data into a test and train 285 | dataset. 286 | 287 | Parameters 288 | ---------- 289 | raw_dir : str 290 | Relative location to store CSV downloaded from UCI. 291 | test_dir : str 292 | Relative location to store test dataset. 293 | train_dir : str 294 | Relative location to store train dataset. 295 | use_backup_iris_data_url : bool, optional 296 | If UCI's Archive isn't responding, use a backup location which also 297 | hosts the same CSV. 298 | """ 299 | downloaded_filename, downloaded = download_iris_data( 300 | raw_dir, 301 | use_backup_iris_data_url) 302 | 303 | species_onehot = generate_onehot_from_species(downloaded_filename) 304 | 305 | # Saving the one-hot so other processes may use it. 306 | write_species_onehot_csv(raw_dir, species_onehot) 307 | 308 | split_test_train(downloaded_filename, species_onehot, test_dir, train_dir) 309 | -------------------------------------------------------------------------------- /iris/network.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from collections import namedtuple 4 | import glob 5 | 6 | from iris import log 7 | 8 | # Total guess based on the rule of thumb from this question: 9 | # http://stackoverflow.com/questions/10565868/multi-layer-perceptron-mlp-architecture-criteria-for-choosing-number-of-hidde 10 | NUM_HIDDEN = 21 11 | 12 | # Sepal Length, Sepal Width, Petal Length and Petal Width == 4 features. 13 | NUM_FEATURES = 4 14 | 15 | # Iris-setosa, Iris-versicolor and Iris-virginica == 3 16 | NUM_LABELS = 3 17 | 18 | _logger = log.get_logger() 19 | 20 | # Data structure used to keep track of important parameters used in training our 21 | # model. 22 | Topology = namedtuple( 23 | "Topology", 24 | "x y t w_hidden b_hidden hidden w_out b_out") 25 | 26 | 27 | def generate_weight(shape, name): 28 | """ 29 | TF variable filled with random values which follow a normal distribution. 
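    In this network it backs both weight matrices and both bias rows; for
    example ``w_hidden`` has shape [NUM_FEATURES, NUM_HIDDEN] = [4, 21].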
30 | 31 | Parameters 32 | ---------- 33 | shape : 1-D Tensor or Array 34 | Corresponds to the shape parameter of tf.random_normal. 35 | name : str 36 | Variable name used in saving and restoring checkpoints. 37 | 38 | Returns 39 | ------- 40 | variable : tf.Variable 41 | TensorFlow variable which is filled with random values in the shape of 42 | the parameter "shape". 43 | 44 | Notes 45 | ----- 46 | See also: 47 | https://en.wikipedia.org/wiki/Normal_distribution 48 | 49 | Originally found as part of: 50 | https://github.com/nlintz/TensorFlow-Tutorials/blob/master/3_net.py#L7 51 | """ 52 | return tf.Variable( 53 | tf.random_normal( 54 | shape, 55 | stddev=0.01, 56 | dtype=tf.float32), name=name) 57 | 58 | 59 | def build(): 60 | """ 61 | Create the topology for a feed forward neural network, we use this to 62 | declare shared functionality between the train, test and predict functions. 63 | 64 | Returns 65 | ------- 66 | topology : Topology 67 | Set of variables which are required in training, testing and storing 68 | this model in checkpoint files. 69 | 70 | Notes 71 | ----- 72 | See also: 73 | https://www.tensorflow.org/versions/0.6.0/tutorials/mnist/pros/index.html#train-the-model 74 | """ 75 | num_hidden = NUM_HIDDEN 76 | num_features = NUM_FEATURES 77 | num_labels = NUM_LABELS 78 | 79 | x = tf.placeholder(tf.float32, shape=[None, num_features]) 80 | t = tf.placeholder(tf.float32, shape=[None, num_labels]) 81 | 82 | w_hidden = generate_weight([num_features, num_hidden], "w_hidden") 83 | b_hidden = generate_weight([1, num_hidden], "b_hidden") 84 | 85 | hidden = tf.nn.relu(tf.matmul(x, w_hidden) + b_hidden) 86 | 87 | w_out = generate_weight([num_hidden, num_labels], "w_out") 88 | b_out = generate_weight([1, num_labels], "b_out") 89 | 90 | y = tf.nn.softmax(tf.matmul(hidden, w_out) + b_out, name="y") 91 | 92 | return Topology( 93 | x=x, 94 | y=y, 95 | t=t, 96 | w_hidden=w_hidden, 97 | b_hidden=b_hidden, 98 | hidden=hidden, 99 | w_out=w_out, 100 | b_out=b_out) 101 | 102 | 103 | def _split_model_name_to_negative_numeric(model_name): 104 | """ 105 | Change a model file name into the integer value at the end of the filename. 106 | 107 | Parameters 108 | ---------- 109 | model_name : str 110 | A model filename from the checkpoints directory. 111 | 112 | Returns 113 | ------- 114 | model_number : int 115 | Negative number which relates to the number in the filename. 116 | 117 | Note 118 | ---- 119 | Returns a negative number since it's used in sorting and we want to sort 120 | from highest to lowest. 121 | 122 | Example 123 | ------- 124 | >>> model_name = "model-2000" 125 | >>> _split_model_name_to_negative_numeric(model_name) 126 | -2000 127 | """ 128 | return -int(model_name.split('-')[1]) 129 | 130 | 131 | def most_recent_checkpoint(checkpoint_directory): 132 | """ 133 | Sort the checkpoint files found under the checkpoints directory and return 134 | the highest numbered checkpoint file. 135 | 136 | Parameters 137 | ---------- 138 | checkpoint_directory : str 139 | Directory including checkpoint model files. 140 | 141 | Returns 142 | ------- 143 | checkpoint_filename : str 144 | Most recent checkpoint based on the numerical extension. 145 | """ 146 | checkpoints = glob.glob("{dir}/model-*".format(dir=checkpoint_directory)) 147 | 148 | return sorted(checkpoints, key=_split_model_name_to_negative_numeric)[0] 149 | 150 | 151 | def read_data_set(directory): 152 | """ 153 | Read CSV files from a directory using TensorFLow to generate the filenames 154 | and the queue of files to be processed. 
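    The returned values are TensorFlow ops rather than materialised data: once
    the queue runners are started, each ``sess.run([features, iris_species])``
    call dequeues and parses the next CSV row from the filename queue.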
155 | 156 | Parameters 157 | ---------- 158 | directory : str 159 | A directory containing CSV files which match the format which has been 160 | pulled from our original CSV file. 161 | (sepal length, sepal width, petal length, petal width, iris species) 162 | 163 | Returns 164 | ------- 165 | features, iris_species : tuple(array, scalar) 166 | The features pulled from the CSV and the target species. 167 | 168 | Notes 169 | ----- 170 | This fails on TF 0.6 if you pass in a CSV without \n line endings. 171 | """ 172 | filename_queue = tf.train.string_input_producer( 173 | tf.train.match_filenames_once(directory), 174 | shuffle=True) 175 | 176 | line_reader = tf.TextLineReader(skip_header_lines=1) 177 | 178 | _, csv_row = line_reader.read(filename_queue) 179 | 180 | record_defaults = [[0.0], [0.0], [0.0], [0.0], [""]] 181 | sepal_length, sepal_width, petal_length, petal_width, iris_species = \ 182 | tf.decode_csv(csv_row, record_defaults=record_defaults) 183 | 184 | features = tf.pack([ 185 | sepal_length, 186 | sepal_width, 187 | petal_length, 188 | petal_width]) 189 | 190 | return features, iris_species 191 | 192 | 193 | def predict_with_session(net, features, sess): 194 | """ 195 | This code was copied between a script which uses a scoped session and a 196 | webserver which uses a session which is kept open until explicitly closed. 197 | 198 | Parameters 199 | ---------- 200 | net : Topology 201 | Network to test against. 202 | features : Tensor 203 | Features used as the `x` while testing. 204 | sess : tf.Session 205 | An active session from TensorFlow, won't close after running the code. 206 | 207 | Returns 208 | ------- 209 | prediction : Tensor(Tensor, Scalar) 210 | Prediction with raw `y` results and the species chosen based on the 211 | features. The argmax will return a number between 0 and NUM_LABELS which 212 | corresponds to where the 1 will be in our one-hot vector. 213 | """ 214 | return sess.run( 215 | [net.y, tf.argmax(net.y, 1)], 216 | feed_dict={net.x: features}) 217 | 218 | 219 | def predict_init(checkpoint_dir): 220 | """ 221 | Encapsulating shared logic between the script to run predictions and 222 | the web API. The logic is mainly related to initializing the variable and 223 | checkpoint for a network. 224 | 225 | Parameters 226 | ---------- 227 | checkpoint_dir : str 228 | Directory which is used to store the checkpoint save files. 229 | 230 | Returns 231 | ------- 232 | Tuple(tf.Session, tf.Tensor) 233 | Active TensorFlow session and a feed forward network. 234 | """ 235 | net = build() 236 | 237 | checkpoint = tf.train.Saver([ 238 | net.w_hidden, net.b_hidden, net.w_out, net.b_out]) 239 | 240 | sess = tf.Session() 241 | checkpoint.restore(sess, most_recent_checkpoint(checkpoint_dir)) 242 | 243 | return sess, net 244 | 245 | 246 | def predict(features, checkpoint_dir): 247 | """ 248 | Predict the species of an iris based on its features and the most recent 249 | checkpoint model. 250 | 251 | Parameters 252 | ---------- 253 | features : tf.Tensor 254 | Meta data about an Iris to guess its species. 255 | checkpoint_dir : str 256 | Location of checkpoint model files. 257 | 258 | Returns 259 | ------- 260 | prediction : tf.Tensor 261 | See #predict_with_session for the full return shape. 
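
    Examples
    --------
    A minimal sketch; the feature values are illustrative only:

    >>> y = predict([[5.1, 3.5, 1.4, 0.2]], "./checkpoints")
    >>> y[1][0]  # index of the predicted species for the first feature row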
262 | """ 263 | sess, net = predict_init(checkpoint_dir) 264 | 265 | prediction = predict_with_session(net, features, sess) 266 | 267 | sess.close() 268 | 269 | return prediction 270 | 271 | 272 | def test(features, iris_species, checkpoint_dir): 273 | """ 274 | Take a random test feature and attempt to predict its species. This is 275 | repeated 100x to get the accuracy of the trained model. 276 | 277 | Parameters 278 | ---------- 279 | features : tf.Tensor 280 | Iris features without their species. 281 | iris_species : tf.Tensor 282 | One-hot vector of the resulting species for each feature. 283 | checkpoint_dir : str 284 | Relative location of all the model checkpoints. 285 | """ 286 | net = build() 287 | 288 | correct_prediction = tf.equal( 289 | tf.argmax(net.y, 1), tf.argmax(net.t, 1)) 290 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 291 | 292 | checkpoint = tf.train.Saver([ 293 | net.w_hidden, net.b_hidden, net.w_out, net.b_out]) 294 | 295 | with tf.Session() as sess: 296 | # Start populating the filename queue. 297 | tf.initialize_all_variables().run() 298 | 299 | coord = tf.train.Coordinator() 300 | threads = tf.train.start_queue_runners(coord=coord) 301 | 302 | checkpoint.restore(sess, most_recent_checkpoint(checkpoint_dir)) 303 | 304 | total_accuracy = 0 305 | # We test 100 times 1-101 which pulls random elements from the test 306 | # dataset which has fewer than 100 items in it. 307 | for iteration in range(1, 101): 308 | example, label = sess.run([features, iris_species]) 309 | label = [int(l) for l in label] 310 | 311 | current_accuracy = accuracy.eval( 312 | feed_dict={ 313 | net.x: [example], 314 | net.t: [label]}) 315 | total_accuracy += current_accuracy 316 | _logger.debug( 317 | "[example: %s, label: %s, accuracy: %d]", 318 | example, 319 | label, 320 | current_accuracy) 321 | 322 | _logger.info("Total Accuracy: %0.2d", total_accuracy) 323 | coord.request_stop() 324 | coord.join(threads) 325 | 326 | 327 | def train(features, iris_species, checkpoint_dir, iterations, save_every): 328 | """ 329 | Train a model based on a training set. 330 | 331 | Parameters 332 | ---------- 333 | features : tf.Tensor 334 | Training features of Iris flowers found in the training set. 335 | iris_species : tf.Tensor 336 | Known species of Iris flowers described by features. 337 | checkpoint_dir : str 338 | Directory to save model checkpoint files to. 339 | iterations : int 340 | Maximum number of iterations (steps) to run before quitting training. 341 | Training continues to run until we tell it to stop... 342 | save_every : int 343 | Number of iterations between each save to a checkpoint model. Tweak this 344 | to stop from creating too many checkpoint files. 345 | 346 | Notes 347 | ----- 348 | Logic for training and the GradientDescentOptimizer are taken from the 349 | training steps found at: 350 | https://www.tensorflow.org/versions/0.6.0/tutorials/mnist/tf/index.html 351 | """ 352 | net = build() 353 | 354 | cross_entropy = -tf.reduce_sum(net.t * tf.log(net.y)) 355 | train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) 356 | 357 | checkpoint = tf.train.Saver([ 358 | net.w_hidden, net.b_hidden, net.w_out, net.b_out]) 359 | with tf.Session() as sess: 360 | # Start populating the filename queue. 
361 | tf.initialize_all_variables().run() 362 | 363 | coord = tf.train.Coordinator() 364 | threads = tf.train.start_queue_runners(coord=coord) 365 | 366 | if save_every is None: 367 | save_every = iterations / 10 368 | 369 | for iteration in range(1, iterations + 1): 370 | example, label = sess.run([features, iris_species]) 371 | label = [int(l) for l in label] 372 | 373 | train_step.run(feed_dict={ 374 | net.x: [example], net.t: [label]}) 375 | 376 | if iteration % save_every == 0: 377 | _logger.info("Saving iteration %i.", iteration) 378 | save_path = checkpoint.save( 379 | sess, 380 | "{cd}/model".format(cd=checkpoint_dir), 381 | global_step=iteration) 382 | _logger.debug("File saved to %s", save_path) 383 | 384 | coord.request_stop() 385 | coord.join(threads) 386 | 387 | 388 | def onehot_from_argmax(argmax_index): 389 | """ 390 | Convert the argmax output to be a one-hot representation. 391 | 392 | Parameters 393 | ---------- 394 | argmax_index : int{0..NUM_LABELS-1} 395 | Argmax output from finding the highest value in the Tensor from the 396 | prediction model. 397 | 398 | Returns 399 | ------- 400 | onehot : str 401 | One-hot vector based on the highest index found in the y Tensor. 402 | """ 403 | return "%0*d" % (NUM_LABELS, 10 ** (NUM_LABELS - argmax_index - 1)) 404 | --------------------------------------------------------------------------------