├── .gitignore ├── .travis.yml ├── CHANGELOG.rst ├── LICENSE.txt ├── README.rst ├── docs └── logo │ ├── strawberry-large.png │ ├── strawberry-medium.png │ └── strawberry-small.png ├── examples ├── iris.py ├── iris_data.csv └── trivial_model.py ├── pypi_upload ├── .pypirc ├── pypi_upload.sh └── setup.py ├── requirements.txt ├── smart_fruit ├── __init__.py ├── feature_class.py ├── feature_types │ ├── __init__.py │ ├── compound_types.py │ ├── feature_type_base.py │ └── simple_types.py ├── model.py ├── model_selection.py └── utils.py └── tests ├── __init__.py ├── test_compound_types.py ├── test_feature_serialization.py ├── test_label_types.py ├── test_model_training.py ├── test_simple_types.py ├── test_trivial_model.py └── test_utils ├── __init__.py ├── example_csv.csv └── test_csv_open.py /.gitignore: -------------------------------------------------------------------------------- 1 | # PyPI build files 2 | pypi_upload/.pypirc 3 | build 4 | dist 5 | smart_fruit.egg-info 6 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | cache: pip 3 | python: 4 | - "3.6" 5 | install: 6 | - pip install -r requirements.txt 7 | script: 8 | - python -m unittest 9 | -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | Changelog 2 | ========= 3 | 4 | 1.2.1 5 | ----- 6 | 7 | Bug fixes: 8 | 9 | - Fix broken build. 10 | 11 | 1.2.0 12 | ----- 13 | 14 | Features: 15 | 16 | - When reading CSV files, assume columns in same order as defined in class if not given in file. 17 | - Add the ``Integer``, ``Complex``, ``Vector``, and ``Tag`` feature types. 18 | - Add the ``Model.predict`` ``yield_inputs`` parameter. 19 | - Add the ``Model.train`` ``random_state`` parameter. 20 | 21 | Bug fixes: 22 | 23 | - Support unicode in CSV files. 24 | - Fix bug which raised an error when predicting a ``Number`` that wasn't the first feature in ``Output``. 25 | 26 | 1.1.0 27 | ----- 28 | 29 | Features: 30 | 31 | - Add basic feature type validation, and coercion, function. 32 | - Add test/train split parameters to ``Model.train``. 33 | 34 | Bug fixes: 35 | 36 | - Allow use of non-orderable types as ``Label`` labels. 37 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 madman-bob 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | Smart Fruit 2 | =========== 3 | 4 | Purpose 5 | ------- 6 | 7 | A Python machine learning library for creating quick and easy machine learning models. 8 | It is schema-based, and wraps `scikit-learn `_. 9 | 10 | Usage 11 | ----- 12 | 13 | Create and use a machine learning model in 3 steps: 14 | 15 | 1. Create a schema representing your input and output features. 16 | 2. Train a model from your data. 17 | 3. Make predictions from your model. 18 | 19 | Example 20 | ------- 21 | 22 | To get a feel for the library, consider the classic `Iris `_ dataset, 23 | where we predict the class of an iris plant from measurements of its sepal and petal. 24 | 25 | First, we create a schema describing our inputs and outputs. 26 | For our inputs, we have the length and width of both the sepal and the petal. 27 | All of these input values happen to be numbers. 28 | For our output, we have just the class of iris, which may be one of the labels ``Iris-setosa``, ``Iris-versicolor``, or ``Iris-virginica``. 29 | 30 | We define this in code as follows: 31 | 32 | .. code:: python 33 | 34 | from smart_fruit import Model 35 | from smart_fruit.feature_types import Number, Label 36 | 37 | 38 | class Iris(Model): 39 | class Input: 40 | sepal_length_cm = Number() 41 | sepal_width_cm = Number() 42 | petal_length_cm = Number() 43 | petal_width_cm = Number() 44 | 45 | class Output: 46 | iris_class = Label(['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']) 47 | 48 | Then, we train a model: 49 | 50 | .. code:: python 51 | 52 | model = Iris.train(Iris.features_from_csv('iris_data.csv')) 53 | 54 | with the data file `iris_data.csv `_. 55 | 56 | :: 57 | 58 | sepal_length_cm,sepal_width_cm,petal_length_cm,petal_width_cm,iris_class 59 | 5.1,3.5,1.4,0.2,Iris-setosa 60 | ... 61 | 62 | Finally, we use our new model to make predictions: 63 | 64 | .. code:: python 65 | 66 | for prediction in model.predict([Iris.Input(5.1, 3.5, 1.4, 0.2)]): 67 | print(prediction.iris_class) 68 | 69 | Reference 70 | --------- 71 | 72 | Models 73 | ~~~~~~ 74 | 75 | - ``Model.Input`` - Schema for defining your input features. 76 | 77 | - ``Model.Output`` - Schema for defining your output features. 78 | 79 | Define ``Model.Input`` and ``Model.Output`` as classes with ``FeatureType`` attributes. 80 | 81 | eg. Consider the ``Iris`` class defined above. 82 | 83 | These classes can then be used to create objects representing the appropriate collections of features. 84 | 85 | eg. 86 | 87 | .. code:: python 88 | 89 | >>> iris_input = Iris.Input(5.1, 3.5, 1.4, 0.2) 90 | >>> iris_input 91 | Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2) 92 | >>> iris_input.sepal_length_cm 93 | 5.1 94 | 95 | >>> Iris.Input.from_json({'sepal_length_cm': 5.1, 'sepal_width_cm': 3.5, 'petal_length_cm': 1.4, 'petal_width_cm': 0.2}) 96 | Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2) 97 | 98 | - ``Model.features_from_list(lists)`` - Deserialize an iterable of lists into an iterable of input/output feature pairs. 99 | 100 | eg. 101 | 102 | ..
code:: python 103 | 104 | >>> list(Iris.features_from_list([[5.1, 3.5, 1.4, 0.2, 'Iris-setosa']])) 105 | [(Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2), Output(iris_class='Iris-setosa'))] 106 | 107 | - ``Model.input_features_from_list(lists)`` - Deserialize an iterable of lists into an iterable of input features. 108 | 109 | eg. 110 | 111 | .. code:: python 112 | 113 | >>> list(Iris.input_features_from_list([[5.1, 3.5, 1.4, 0.2]])) 114 | [Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2)] 115 | 116 | - ``Model.features_from_json(json)`` - Deserialize an iterable of dictionaries into an iterable of input/output feature pairs. 117 | 118 | eg. 119 | 120 | .. code:: python 121 | 122 | >>> list(Iris.features_from_json([{'sepal_length_cm': 5.1, 'sepal_width_cm': 3.5, 'petal_length_cm': 1.4, 'petal_width_cm': 0.2, 'iris_class': 'Iris-setosa'}])) 123 | [(Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2), Output(iris_class='Iris-setosa'))] 124 | 125 | - ``Model.input_features_from_json(json)`` - Deserialize an iterable of dictionaries into an iterable of input features. 126 | 127 | eg. 128 | 129 | .. code:: python 130 | 131 | >>> list(Iris.input_features_from_json([{'sepal_length_cm': 5.1, 'sepal_width_cm': 3.5, 'petal_length_cm': 1.4, 'petal_width_cm': 0.2}])) 132 | [Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2)] 133 | 134 | - ``Model.features_from_csv(csv_path)`` - Take a path to a CSV file, or a file-like object, and deserialize it into an iterable of input/output feature pairs. 135 | 136 | If column headings are not given in the file, assume the input features are followed by the output features, in the order they are defined in their respective classes. 137 | 138 | eg. 139 | 140 | .. code:: python 141 | 142 | >>> list(Iris.features_from_csv('iris_data.csv')) 143 | [(Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2), Output(iris_class='Iris-setosa')), ...] 144 | 145 | - ``Model.input_features_from_csv(csv_path)`` - Take a path to a CSV file, or a file-like object, and deserialize it into an iterable of input features. 146 | 147 | If column headings are not given in the file, assume they are in the order they are defined in the ``Input`` class. 148 | 149 | eg. 150 | 151 | .. code:: python 152 | 153 | >>> list(Iris.input_features_from_csv('iris_data.csv')) 154 | [Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2), ...] 155 | 156 | - ``Model.model_class`` - How to model the relation between the input and output data. 157 | 158 | Default: ``sklearn.linear_model.LinearRegression`` 159 | 160 | This attribute accepts any class with ``fit``, ``predict``, and ``score`` methods defined as for ``scikit-learn`` multi-response regression models. 161 | In particular, this attribute accepts any ``scikit-learn`` multi-response regression models, 162 | ie. any ``scikit-learn`` regression model where the ``y`` parameter of ``fit`` accepts a numpy array of shape ``[n_samples, n_targets]``. 163 | 164 | - ``Model.train(features, train_test_split_ratio=None, test_sample_count=None, random_state=None)`` 165 | 166 | Train a new model on the given iterable of input/output pairs. 167 | 168 | Parameters: 169 | 170 | - ``features`` - An iterable of input/output pairs. 171 | 172 | - ``train_test_split_ratio`` - Proportion of data to use as cross-validation test data. 
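eg. A minimal sketch of cross-validated training (this assumes the ``Iris`` model and ``iris_data.csv`` file from the example above):

  .. code:: python

      >>> iris_model, score = Iris.train(Iris.features_from_csv('iris_data.csv'), train_test_split_ratio=0.2)

  Here ``score`` is the trained model's score on the held-out test data.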
173 | 174 | - ``test_sample_count`` - Number of samples of data to use as cross-validation test data. 175 | 176 | If ``train_test_split_ratio`` or ``test_sample_count`` are provided, perform cross-validation of the given data. 177 | Return both the trained model, and the score of the test data on that model. 178 | 179 | - ``random_state`` - Either a ``numpy`` ``RandomState``, or the seed to use for the PRNG. 180 | 181 | Useful for getting consistent results, for example for automated tests. 182 | Do not use this parameter when generating models you plan to use in production settings. 183 | 184 | eg. 185 | 186 | .. code:: python 187 | 188 | >>> iris_model = Iris.train([(Iris.Input(5.1, 3.5, 1.4, 0.2), Iris.Output('Iris-setosa'))]) 189 | 190 | - ``model.predict(input_features, yield_inputs=False)`` - Predict the outputs for a given iterable of inputs. 191 | 192 | If ``yield_inputs`` is ``True`` then yield the prediction with the input used to generate it, as ``input``, ``output`` pairs. 193 | Otherwise, yield just the predictions, in the same order the inputs are given to the model. 194 | 195 | eg. 196 | 197 | .. code:: python 198 | 199 | >>> list(iris_model.predict([Iris.Input(5.1, 3.5, 1.4, 0.2)])) 200 | [Output(iris_class='Iris-setosa')] 201 | 202 | Feature Types 203 | ~~~~~~~~~~~~~ 204 | 205 | Smart Fruit recognizes the following data types for input and output features. 206 | Custom types may be made by extending the ``FeatureType`` class. 207 | 208 | - ``Number()`` - A real-valued feature. 209 | 210 | eg. ``0``, ``1``, ``3.141592``, ``-17``, ... 211 | 212 | - ``Integer()`` - A whole number feature. 213 | 214 | eg. ``0``, ``1``, ``3``, ``-17``, ... 215 | 216 | - ``Complex()`` - A complex-valued number feature. 217 | 218 | eg. ``0``, ``1``, ``3 + 4j``, ``-1 + 7j``, ... 219 | 220 | - ``Label(labels)`` - An enumerated feature, ie. one which may take one of a pre-defined list of available values. 221 | 222 | eg. For ``labels = ['red', 'green', 'blue']``, our label may take the value ``'red'``, but not ``'purple'``. 223 | 224 | - ``Vector(feature_types)`` - A feature made of other features. Useful for grouping conceptually related features. 225 | 226 | eg. For ``feature_types = [Number(), Label(['red', 'green', 'blue'])]``, we may take values such as ``(0, 'red')``, and ``(1, 'blue')``. 227 | 228 | - ``Tag()`` - A feature that is ignored when making predictions. Useful for keeping track of ID numbers. 229 | 230 | Accepts any Python value. 231 | 232 | Requirements 233 | ------------ 234 | 235 | Smart Fruit requires Python 3.6+, and uses 236 | `scikit-learn `_, 237 | `scipy `_, 238 | and `pandas `_. 239 | 240 | Installation 241 | ------------ 242 | 243 | Install and update using the standard Python package manager `pip `_: 244 | 245 | .. code:: text 246 | 247 | pip install smart-fruit 248 | 249 | Donate 250 | ------ 251 | 252 | To support the continued development of Smart Fruit, please 253 | `donate `_. 
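As a closing example: to use something other than the default linear regression, override ``Model.model_class`` (see the Reference section above). A minimal sketch, assuming the ``Iris`` model defined earlier, and using ``sklearn.linear_model.Ridge`` purely as an illustration:

.. code:: python

    from sklearn import linear_model


    class RidgeIris(Iris):
        model_class = linear_model.Ridge

``RidgeIris`` then trains and predicts exactly as ``Iris`` does.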
254 | -------------------------------------------------------------------------------- /docs/logo/strawberry-large.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madman-bob/Smart-Fruit/8739874811334226c073489fd14b6ca0d0ff9a7b/docs/logo/strawberry-large.png -------------------------------------------------------------------------------- /docs/logo/strawberry-medium.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madman-bob/Smart-Fruit/8739874811334226c073489fd14b6ca0d0ff9a7b/docs/logo/strawberry-medium.png -------------------------------------------------------------------------------- /docs/logo/strawberry-small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madman-bob/Smart-Fruit/8739874811334226c073489fd14b6ca0d0ff9a7b/docs/logo/strawberry-small.png -------------------------------------------------------------------------------- /examples/iris.py: -------------------------------------------------------------------------------- 1 | from smart_fruit import Model 2 | from smart_fruit.feature_types import Number, Label 3 | 4 | 5 | class Iris(Model): 6 | """ 7 | Example class for the "Iris" data set: 8 | 9 | https://archive.ics.uci.edu/ml/datasets/Iris 10 | """ 11 | 12 | class Input: 13 | sepal_length_cm = Number() 14 | sepal_width_cm = Number() 15 | petal_length_cm = Number() 16 | petal_width_cm = Number() 17 | 18 | class Output: 19 | iris_class = Label(['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']) 20 | 21 | 22 | def main(): 23 | features = list(Iris.features_from_csv('iris_data.csv')) 24 | 25 | model = Iris.train(features) 26 | 27 | input_features = [input_feature for input_feature, output_feature in features] 28 | 29 | for (input_, output), predicted_output in zip(features, model.predict(input_features)): 30 | print(list(input_), output.iris_class, predicted_output.iris_class) 31 | 32 | 33 | if __name__ == "__main__": 34 | main() 35 | -------------------------------------------------------------------------------- /examples/iris_data.csv: -------------------------------------------------------------------------------- 1 | sepal_length_cm,sepal_width_cm,petal_length_cm,petal_width_cm,iris_class 2 | 5.1,3.5,1.4,0.2,Iris-setosa 3 | 4.9,3.0,1.4,0.2,Iris-setosa 4 | 4.7,3.2,1.3,0.2,Iris-setosa 5 | 4.6,3.1,1.5,0.2,Iris-setosa 6 | 5.0,3.6,1.4,0.2,Iris-setosa 7 | 5.4,3.9,1.7,0.4,Iris-setosa 8 | 4.6,3.4,1.4,0.3,Iris-setosa 9 | 5.0,3.4,1.5,0.2,Iris-setosa 10 | 4.4,2.9,1.4,0.2,Iris-setosa 11 | 4.9,3.1,1.5,0.1,Iris-setosa 12 | 5.4,3.7,1.5,0.2,Iris-setosa 13 | 4.8,3.4,1.6,0.2,Iris-setosa 14 | 4.8,3.0,1.4,0.1,Iris-setosa 15 | 4.3,3.0,1.1,0.1,Iris-setosa 16 | 5.8,4.0,1.2,0.2,Iris-setosa 17 | 5.7,4.4,1.5,0.4,Iris-setosa 18 | 5.4,3.9,1.3,0.4,Iris-setosa 19 | 5.1,3.5,1.4,0.3,Iris-setosa 20 | 5.7,3.8,1.7,0.3,Iris-setosa 21 | 5.1,3.8,1.5,0.3,Iris-setosa 22 | 5.4,3.4,1.7,0.2,Iris-setosa 23 | 5.1,3.7,1.5,0.4,Iris-setosa 24 | 4.6,3.6,1.0,0.2,Iris-setosa 25 | 5.1,3.3,1.7,0.5,Iris-setosa 26 | 4.8,3.4,1.9,0.2,Iris-setosa 27 | 5.0,3.0,1.6,0.2,Iris-setosa 28 | 5.0,3.4,1.6,0.4,Iris-setosa 29 | 5.2,3.5,1.5,0.2,Iris-setosa 30 | 5.2,3.4,1.4,0.2,Iris-setosa 31 | 4.7,3.2,1.6,0.2,Iris-setosa 32 | 4.8,3.1,1.6,0.2,Iris-setosa 33 | 5.4,3.4,1.5,0.4,Iris-setosa 34 | 5.2,4.1,1.5,0.1,Iris-setosa 35 | 5.5,4.2,1.4,0.2,Iris-setosa 36 | 4.9,3.1,1.5,0.1,Iris-setosa 37 | 5.0,3.2,1.2,0.2,Iris-setosa 38 | 
5.5,3.5,1.3,0.2,Iris-setosa 39 | 4.9,3.1,1.5,0.1,Iris-setosa 40 | 4.4,3.0,1.3,0.2,Iris-setosa 41 | 5.1,3.4,1.5,0.2,Iris-setosa 42 | 5.0,3.5,1.3,0.3,Iris-setosa 43 | 4.5,2.3,1.3,0.3,Iris-setosa 44 | 4.4,3.2,1.3,0.2,Iris-setosa 45 | 5.0,3.5,1.6,0.6,Iris-setosa 46 | 5.1,3.8,1.9,0.4,Iris-setosa 47 | 4.8,3.0,1.4,0.3,Iris-setosa 48 | 5.1,3.8,1.6,0.2,Iris-setosa 49 | 4.6,3.2,1.4,0.2,Iris-setosa 50 | 5.3,3.7,1.5,0.2,Iris-setosa 51 | 5.0,3.3,1.4,0.2,Iris-setosa 52 | 7.0,3.2,4.7,1.4,Iris-versicolor 53 | 6.4,3.2,4.5,1.5,Iris-versicolor 54 | 6.9,3.1,4.9,1.5,Iris-versicolor 55 | 5.5,2.3,4.0,1.3,Iris-versicolor 56 | 6.5,2.8,4.6,1.5,Iris-versicolor 57 | 5.7,2.8,4.5,1.3,Iris-versicolor 58 | 6.3,3.3,4.7,1.6,Iris-versicolor 59 | 4.9,2.4,3.3,1.0,Iris-versicolor 60 | 6.6,2.9,4.6,1.3,Iris-versicolor 61 | 5.2,2.7,3.9,1.4,Iris-versicolor 62 | 5.0,2.0,3.5,1.0,Iris-versicolor 63 | 5.9,3.0,4.2,1.5,Iris-versicolor 64 | 6.0,2.2,4.0,1.0,Iris-versicolor 65 | 6.1,2.9,4.7,1.4,Iris-versicolor 66 | 5.6,2.9,3.6,1.3,Iris-versicolor 67 | 6.7,3.1,4.4,1.4,Iris-versicolor 68 | 5.6,3.0,4.5,1.5,Iris-versicolor 69 | 5.8,2.7,4.1,1.0,Iris-versicolor 70 | 6.2,2.2,4.5,1.5,Iris-versicolor 71 | 5.6,2.5,3.9,1.1,Iris-versicolor 72 | 5.9,3.2,4.8,1.8,Iris-versicolor 73 | 6.1,2.8,4.0,1.3,Iris-versicolor 74 | 6.3,2.5,4.9,1.5,Iris-versicolor 75 | 6.1,2.8,4.7,1.2,Iris-versicolor 76 | 6.4,2.9,4.3,1.3,Iris-versicolor 77 | 6.6,3.0,4.4,1.4,Iris-versicolor 78 | 6.8,2.8,4.8,1.4,Iris-versicolor 79 | 6.7,3.0,5.0,1.7,Iris-versicolor 80 | 6.0,2.9,4.5,1.5,Iris-versicolor 81 | 5.7,2.6,3.5,1.0,Iris-versicolor 82 | 5.5,2.4,3.8,1.1,Iris-versicolor 83 | 5.5,2.4,3.7,1.0,Iris-versicolor 84 | 5.8,2.7,3.9,1.2,Iris-versicolor 85 | 6.0,2.7,5.1,1.6,Iris-versicolor 86 | 5.4,3.0,4.5,1.5,Iris-versicolor 87 | 6.0,3.4,4.5,1.6,Iris-versicolor 88 | 6.7,3.1,4.7,1.5,Iris-versicolor 89 | 6.3,2.3,4.4,1.3,Iris-versicolor 90 | 5.6,3.0,4.1,1.3,Iris-versicolor 91 | 5.5,2.5,4.0,1.3,Iris-versicolor 92 | 5.5,2.6,4.4,1.2,Iris-versicolor 93 | 6.1,3.0,4.6,1.4,Iris-versicolor 94 | 5.8,2.6,4.0,1.2,Iris-versicolor 95 | 5.0,2.3,3.3,1.0,Iris-versicolor 96 | 5.6,2.7,4.2,1.3,Iris-versicolor 97 | 5.7,3.0,4.2,1.2,Iris-versicolor 98 | 5.7,2.9,4.2,1.3,Iris-versicolor 99 | 6.2,2.9,4.3,1.3,Iris-versicolor 100 | 5.1,2.5,3.0,1.1,Iris-versicolor 101 | 5.7,2.8,4.1,1.3,Iris-versicolor 102 | 6.3,3.3,6.0,2.5,Iris-virginica 103 | 5.8,2.7,5.1,1.9,Iris-virginica 104 | 7.1,3.0,5.9,2.1,Iris-virginica 105 | 6.3,2.9,5.6,1.8,Iris-virginica 106 | 6.5,3.0,5.8,2.2,Iris-virginica 107 | 7.6,3.0,6.6,2.1,Iris-virginica 108 | 4.9,2.5,4.5,1.7,Iris-virginica 109 | 7.3,2.9,6.3,1.8,Iris-virginica 110 | 6.7,2.5,5.8,1.8,Iris-virginica 111 | 7.2,3.6,6.1,2.5,Iris-virginica 112 | 6.5,3.2,5.1,2.0,Iris-virginica 113 | 6.4,2.7,5.3,1.9,Iris-virginica 114 | 6.8,3.0,5.5,2.1,Iris-virginica 115 | 5.7,2.5,5.0,2.0,Iris-virginica 116 | 5.8,2.8,5.1,2.4,Iris-virginica 117 | 6.4,3.2,5.3,2.3,Iris-virginica 118 | 6.5,3.0,5.5,1.8,Iris-virginica 119 | 7.7,3.8,6.7,2.2,Iris-virginica 120 | 7.7,2.6,6.9,2.3,Iris-virginica 121 | 6.0,2.2,5.0,1.5,Iris-virginica 122 | 6.9,3.2,5.7,2.3,Iris-virginica 123 | 5.6,2.8,4.9,2.0,Iris-virginica 124 | 7.7,2.8,6.7,2.0,Iris-virginica 125 | 6.3,2.7,4.9,1.8,Iris-virginica 126 | 6.7,3.3,5.7,2.1,Iris-virginica 127 | 7.2,3.2,6.0,1.8,Iris-virginica 128 | 6.2,2.8,4.8,1.8,Iris-virginica 129 | 6.1,3.0,4.9,1.8,Iris-virginica 130 | 6.4,2.8,5.6,2.1,Iris-virginica 131 | 7.2,3.0,5.8,1.6,Iris-virginica 132 | 7.4,2.8,6.1,1.9,Iris-virginica 133 | 7.9,3.8,6.4,2.0,Iris-virginica 134 | 6.4,2.8,5.6,2.2,Iris-virginica 135 | 
6.3,2.8,5.1,1.5,Iris-virginica 136 | 6.1,2.6,5.6,1.4,Iris-virginica 137 | 7.7,3.0,6.1,2.3,Iris-virginica 138 | 6.3,3.4,5.6,2.4,Iris-virginica 139 | 6.4,3.1,5.5,1.8,Iris-virginica 140 | 6.0,3.0,4.8,1.8,Iris-virginica 141 | 6.9,3.1,5.4,2.1,Iris-virginica 142 | 6.7,3.1,5.6,2.4,Iris-virginica 143 | 6.9,3.1,5.1,2.3,Iris-virginica 144 | 5.8,2.7,5.1,1.9,Iris-virginica 145 | 6.8,3.2,5.9,2.3,Iris-virginica 146 | 6.7,3.3,5.7,2.5,Iris-virginica 147 | 6.7,3.0,5.2,2.3,Iris-virginica 148 | 6.3,2.5,5.0,1.9,Iris-virginica 149 | 6.5,3.0,5.2,2.0,Iris-virginica 150 | 6.2,3.4,5.4,2.3,Iris-virginica 151 | 5.9,3.0,5.1,1.8,Iris-virginica 152 | -------------------------------------------------------------------------------- /examples/trivial_model.py: -------------------------------------------------------------------------------- 1 | from smart_fruit import Model 2 | from smart_fruit.feature_types import Number 3 | 4 | 5 | class TrivialModel(Model): 6 | class Input: 7 | input_ = Number() 8 | 9 | class Output: 10 | output = Number() 11 | 12 | 13 | def main(): 14 | multiply_by_10 = TrivialModel.train([ 15 | (TrivialModel.Input(n), TrivialModel.Output(10 * n)) 16 | for n in [1, 2, 4, 5, 7, 8, 10] 17 | ]) 18 | 19 | inputs = [3, 6, 9] 20 | predictions = multiply_by_10.predict([TrivialModel.Input(n) for n in inputs], yield_inputs=True) 21 | for input_, predicted_output in predictions: 22 | print(input_, predicted_output.output) 23 | 24 | 25 | if __name__ == "__main__": 26 | main() 27 | -------------------------------------------------------------------------------- /pypi_upload/.pypirc: -------------------------------------------------------------------------------- 1 | [distutils] 2 | index-servers = 3 | pypi 4 | testpypi 5 | 6 | [pypi] 7 | username: 8 | password: 9 | 10 | [testpypi] 11 | repository: https://test.pypi.org/legacy/ 12 | username: 13 | password: 14 | -------------------------------------------------------------------------------- /pypi_upload/pypi_upload.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Build distribution 4 | python3 pypi_upload/setup.py sdist bdist_wheel 5 | 6 | # Upload distribution to PyPI 7 | twine upload --config-file pypi_upload/.pypirc dist/* 8 | # For testing, add `--repository testpypi` 9 | 10 | # Tidy up 11 | rm -rf build 12 | rm -rf dist 13 | rm -rf smart_fruit.egg-info 14 | -------------------------------------------------------------------------------- /pypi_upload/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from os import path 3 | 4 | import re 5 | 6 | project_root = path.join(path.abspath(path.dirname(__file__)), '..') 7 | 8 | 9 | def get_version(): 10 | with open(path.join(project_root, 'smart_fruit', '__init__.py'), encoding='utf-8') as init_file: 11 | return re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", init_file.read(), re.M).group(1) 12 | 13 | 14 | def get_long_description(): 15 | with open(path.join(project_root, 'README.rst'), encoding='utf-8') as readme_file: 16 | return readme_file.read() 17 | 18 | 19 | def get_requirements(): 20 | with open(path.join(project_root, 'requirements.txt'), encoding='utf-8') as requirements_file: 21 | return [requirement.strip() for requirement in requirements_file if requirement.strip()] 22 | 23 | 24 | setup( 25 | name='smart-fruit', 26 | version=get_version(), 27 | packages=find_packages(include=('smart_fruit', 'smart_fruit.*')), 28 | 
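# Runtime dependencies come from requirements.txt (see get_requirements above), so the two stay in sync.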
install_requires=get_requirements(), 29 | 30 | author='Robert Wright', 31 | author_email='madman.bob@hotmail.co.uk', 32 | 33 | description='A Python schema-based machine learning library', 34 | long_description=get_long_description(), 35 | long_description_content_type='text/x-rst', 36 | url='https://github.com/madman-bob/Smart-Fruit', 37 | license='MIT', 38 | classifiers=[ 39 | 'License :: OSI Approved :: MIT License', 40 | 'Programming Language :: Python :: 3', 41 | 'Programming Language :: Python :: 3.6', 42 | 'Programming Language :: Python :: 3.7', 43 | 'Programming Language :: Python :: 3 :: Only', 44 | ], 45 | python_requires='>=3.6' 46 | ) 47 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | scipy 3 | scikit-learn 4 | -------------------------------------------------------------------------------- /smart_fruit/__init__.py: -------------------------------------------------------------------------------- 1 | from smart_fruit.model import Model 2 | 3 | __version__ = '1.2.1' 4 | 5 | __all__ = ["Model"] 6 | -------------------------------------------------------------------------------- /smart_fruit/feature_class.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | from smart_fruit.feature_types import FeatureType 4 | 5 | __all__ = ["FeatureClassMeta"] 6 | 7 | 8 | class FeatureClassMixin: 9 | def validate(self): 10 | return self.__class__(*( 11 | feature_type.validate(value) 12 | for feature_type, value in zip(self.__class__, self) 13 | )) 14 | 15 | @classmethod 16 | def from_json(cls, json): 17 | return cls(**{ 18 | key: value 19 | for key, value in json.items() 20 | if key in cls._fields 21 | }) 22 | 23 | def to_json(self): 24 | return self._asdict() 25 | 26 | 27 | class FeatureClassMeta(type): 28 | def __new__(cls, name, bases, namespace): 29 | base_feature_type = bases[0] 30 | features = tuple( 31 | key 32 | for key, value in base_feature_type.__dict__.items() 33 | if isinstance(value, FeatureType) 34 | ) 35 | 36 | return type.__new__( 37 | cls, 38 | name, 39 | tuple(bases) + (namedtuple(base_feature_type.__name__, features), FeatureClassMixin,), 40 | namespace 41 | ) 42 | 43 | def __iter__(self): 44 | for field_name in self._fields: 45 | yield getattr(self, field_name) 46 | 47 | def __len__(self): 48 | return len(self._fields) 49 | -------------------------------------------------------------------------------- /smart_fruit/feature_types/__init__.py: -------------------------------------------------------------------------------- 1 | from smart_fruit.feature_types.feature_type_base import FeatureType 2 | from smart_fruit.feature_types.simple_types import Number, Integer, Complex, Label, Tag 3 | from smart_fruit.feature_types.compound_types import Vector 4 | 5 | __all__ = ["FeatureType", "Number", "Integer", "Complex", "Label", "Vector", "Tag"] 6 | -------------------------------------------------------------------------------- /smart_fruit/feature_types/compound_types.py: -------------------------------------------------------------------------------- 1 | from pandas import concat 2 | 3 | from smart_fruit.feature_types.feature_type_base import FeatureType 4 | 5 | __all__ = ["Vector"] 6 | 7 | 8 | class Vector(FeatureType): 9 | def __init__(self, feature_types): 10 | self.feature_types = feature_types 11 | 12 | @property 13 | def feature_count(self): 14 | 
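# A Vector spans as many raw feature columns as all of its component feature types combined.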
return sum(feature_type.feature_count for feature_type in self.feature_types) 15 | 16 | def validate(self, value): 17 | if len(value) != len(self.feature_types): 18 | raise ValueError( 19 | "Incorrect length vector (expected {}, got {!r})".format(len(self.feature_types), len(value)) 20 | ) 21 | 22 | return tuple( 23 | feature_type.validate(subvalue) 24 | for subvalue, feature_type in zip(value, self.feature_types) 25 | ) 26 | 27 | def to_series(self, value): 28 | if len(value) != len(self.feature_types): 29 | raise ValueError( 30 | "Incorrect length vector (expected {}, got {!r})".format(len(self.feature_types), len(value)) 31 | ) 32 | 33 | return concat([ 34 | feature_type.to_series(subvalue) 35 | for subvalue, feature_type in zip(value, self.feature_types) 36 | ], ignore_index=True) 37 | 38 | def from_series(self, features): 39 | return tuple( 40 | feature_type.from_series(chunk) 41 | for chunk, feature_type in self._chunk_series(features, self.feature_types) 42 | ) 43 | 44 | @staticmethod 45 | def _chunk_series(series, feature_types): 46 | start = 0 47 | for feature_type in feature_types: 48 | yield series.iloc[start:start + feature_type.feature_count].reset_index(drop=True), feature_type 49 | start += feature_type.feature_count 50 | -------------------------------------------------------------------------------- /smart_fruit/feature_types/feature_type_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | 3 | from pandas import Series 4 | 5 | __all__ = ["FeatureType"] 6 | 7 | 8 | class FeatureType(metaclass=ABCMeta): 9 | _index = None 10 | feature_count = 1 11 | 12 | def __get__(self, instance, owner): 13 | if instance is None: 14 | return self 15 | 16 | if self._index is None: 17 | self._index = next( 18 | i 19 | for i, name in enumerate(owner._fields) 20 | if getattr(owner, name) is self 21 | ) 22 | 23 | return instance[self._index] 24 | 25 | def validate(self, value): 26 | return value 27 | 28 | def to_series(self, value): 29 | return Series([value]) 30 | 31 | def from_series(self, features): 32 | return features.iloc[0] 33 | -------------------------------------------------------------------------------- /smart_fruit/feature_types/simple_types.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | from numpy import isfinite 4 | from pandas import Series 5 | 6 | from smart_fruit.feature_types.feature_type_base import FeatureType 7 | 8 | __all__ = ["Number", "Integer", "Complex", "Label", "Tag"] 9 | 10 | 11 | class Number(FeatureType): 12 | def validate(self, value): 13 | value = float(value) 14 | 15 | if not isfinite(value): 16 | raise ValueError( 17 | "May not assign non-finite value {} to a {}".format( 18 | value, 19 | self.__class__.__name__ 20 | ) 21 | ) 22 | 23 | return value 24 | 25 | 26 | class Integer(Number): 27 | def validate(self, value): 28 | return int(round(super().validate(value))) 29 | 30 | def from_series(self, features): 31 | return int(round(super().from_series(features))) 32 | 33 | 34 | class Complex(FeatureType): 35 | feature_count = 2 36 | 37 | def validate(self, value): 38 | value = complex(value) 39 | 40 | if not isfinite(value): 41 | raise ValueError( 42 | "May not assign non-finite value {} to a {}".format( 43 | value, 44 | self.__class__.__name__ 45 | ) 46 | ) 47 | 48 | return value 49 | 50 | def to_series(self, value): 51 | return Series([value.real, value.imag]) 52 | 53 | def from_series(self, features): 54 | 
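# to_series stores a complex value as two columns, [real, imag]; rebuild the complex number from those two entries.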
return complex(*features) 55 | 56 | 57 | class Label(FeatureType, namedtuple('Label', ['labels'])): 58 | @property 59 | def feature_count(self): 60 | return len(self.labels) 61 | 62 | def validate(self, value): 63 | if value not in self.labels: 64 | raise TypeError( 65 | "May not use non-existent label {!r} in a {!r}".format( 66 | value, 67 | self 68 | ) 69 | ) 70 | 71 | return value 72 | 73 | def to_series(self, value): 74 | return Series([int(value == label) for label in self.labels]) 75 | 76 | def from_series(self, features): 77 | return max(zip(features, enumerate(self.labels)))[1][1] 78 | 79 | 80 | class Tag(FeatureType): 81 | feature_count = 0 82 | 83 | def to_series(self, value): 84 | return Series() 85 | 86 | def from_series(self, features): 87 | raise TypeError( 88 | "May not predict a {}".format(self.__class__.__name__) 89 | ) 90 | -------------------------------------------------------------------------------- /smart_fruit/model.py: -------------------------------------------------------------------------------- 1 | from pandas import DataFrame, concat 2 | from pandas.core.apply import frame_apply 3 | 4 | from sklearn import linear_model 5 | 6 | from smart_fruit.feature_class import FeatureClassMeta 7 | from smart_fruit.model_selection import train_test_split 8 | from smart_fruit.utils import csv_open 9 | 10 | __all__ = ["Model"] 11 | 12 | 13 | class ModelMeta(type): 14 | def __init__(cls, name, bases, namespace): 15 | super().__init__(name, bases, namespace) 16 | 17 | for feature_type in ['Input', 'Output']: 18 | feature_class = getattr(cls, feature_type) 19 | 20 | setattr(cls, feature_type, FeatureClassMeta(feature_class.__name__, (feature_class,), {})) 21 | 22 | 23 | class Model(metaclass=ModelMeta): 24 | model_class = linear_model.LinearRegression 25 | 26 | class Input: 27 | pass 28 | 29 | class Output: 30 | pass 31 | 32 | def __init__(self, *args, **kwargs): 33 | self.model = self.model_class(*args, **kwargs) 34 | 35 | @classmethod 36 | def input_features_from_list(cls, lists): 37 | for l in lists: 38 | yield cls.Input(*l).validate() 39 | 40 | @classmethod 41 | def input_features_from_json(cls, json): 42 | for feature in json: 43 | yield cls.Input.from_json(feature).validate() 44 | 45 | @classmethod 46 | def input_features_from_csv(cls, csv_path): 47 | yield from cls.input_features_from_json(csv_open(csv_path, cls.Input._fields)) 48 | 49 | @classmethod 50 | def features_from_list(cls, lists): 51 | for l in lists: 52 | yield cls.Input(*l[:len(cls.Input._fields)]).validate(), cls.Output(*l[len(cls.Input._fields):]).validate() 53 | 54 | @classmethod 55 | def features_from_json(cls, json): 56 | for feature in json: 57 | yield cls.Input.from_json(feature).validate(), cls.Output.from_json(feature).validate() 58 | 59 | @classmethod 60 | def features_from_csv(cls, csv_path): 61 | yield from cls.features_from_json(csv_open(csv_path, cls.Input._fields + cls.Output._fields)) 62 | 63 | @staticmethod 64 | def _to_raw_features(dataframe, feature_class): 65 | return concat([ 66 | column.apply(feature_type.to_series) 67 | for (i, column), feature_type in zip(dataframe.iteritems(), feature_class) 68 | ], axis=1) 69 | 70 | def _dataframes_from_features(self, features): 71 | dataframe = DataFrame(list(input_) + list(output) for input_, output in features) 72 | 73 | input_dataframe = self._to_raw_features(dataframe, self.Input) 74 | output_dataframe = self._to_raw_features(dataframe.loc[:, len(self.Input):], self.Output) 75 | 76 | return input_dataframe, output_dataframe 77 | 78 | @classmethod 
79 | def train(cls, features, train_test_split_ratio=None, test_sample_count=None, random_state=None): 80 | if train_test_split_ratio is not None or test_sample_count is not None: 81 | train_features, test_features = train_test_split( 82 | features, 83 | train_test_split_ratio=train_test_split_ratio, 84 | test_sample_count=test_sample_count, 85 | random_state=random_state 86 | ) 87 | 88 | model = cls.train(train_features) 89 | 90 | return model, model.score(test_features) 91 | 92 | model = cls() 93 | 94 | model.model.fit(*model._dataframes_from_features(features)) 95 | 96 | return model 97 | 98 | def score(self, features): 99 | return self.model.score(*self._dataframes_from_features(features)) 100 | 101 | @staticmethod 102 | def _chunk_dataframe(dataframe, feature_types): 103 | start = 0 104 | for feature_type in feature_types: 105 | chunk = dataframe.loc[:, start:start + feature_type.feature_count - 1] 106 | chunk.columns = range(len(chunk.columns)) 107 | 108 | yield chunk, feature_type 109 | start += feature_type.feature_count 110 | 111 | def predict(self, input_features, yield_inputs=False): 112 | input_features_dataframe = DataFrame(input_features) 113 | 114 | raw_features = self._to_raw_features(input_features_dataframe, self.Input) 115 | 116 | raw_prediction_dataframe = DataFrame(self.model.predict(raw_features)) 117 | 118 | # Using frame_apply instead of chunk.apply as chunk.apply doesn't behave as expected for 0 columns 119 | prediction_dataframe = concat([ 120 | frame_apply(chunk, feature_type.from_series, axis=1).apply_standard() 121 | for chunk, feature_type in self._chunk_dataframe(raw_prediction_dataframe, self.Output) 122 | ], axis=1) 123 | 124 | if yield_inputs: 125 | for (_, input_series), (_, output_series) in zip( 126 | input_features_dataframe.iterrows(), 127 | prediction_dataframe.iterrows() 128 | ): 129 | yield self.Input(*input_series), self.Output(*output_series) 130 | else: 131 | for _, output_series in prediction_dataframe.iterrows(): 132 | yield self.Output(*output_series) 133 | -------------------------------------------------------------------------------- /smart_fruit/model_selection.py: -------------------------------------------------------------------------------- 1 | from sklearn.model_selection import train_test_split as sk_train_test_split 2 | 3 | __all__ = ["train_test_split"] 4 | 5 | 6 | def train_test_split(features, train_test_split_ratio=None, test_sample_count=None, random_state=None): 7 | if (train_test_split_ratio is None) == (test_sample_count is None): 8 | raise ValueError( 9 | "Must provide exactly one of train_test_split_ratio or test_sample_count " 10 | "to perform train/test split" 11 | ) 12 | 13 | if train_test_split_ratio is not None: 14 | train_test_split_ratio = float(train_test_split_ratio) 15 | 16 | if train_test_split_ratio <= 0 or train_test_split_ratio >= 1: 17 | raise ValueError( 18 | "train_test_split_ratio must be strictly between 0 and 1 (given {})".format(train_test_split_ratio) 19 | ) 20 | 21 | if test_sample_count is not None: 22 | test_sample_count = round(test_sample_count) 23 | 24 | if test_sample_count <= 0: 25 | raise ValueError( 26 | "test_sample_count must be strictly positive (given {})".format(test_sample_count) 27 | ) 28 | 29 | return sk_train_test_split( 30 | list(features), 31 | test_size=train_test_split_ratio or test_sample_count, 32 | random_state=random_state 33 | ) 34 | -------------------------------------------------------------------------------- /smart_fruit/utils.py: 
-------------------------------------------------------------------------------- 1 | from csv import reader as csv_reader 2 | from itertools import chain 3 | 4 | __all__ = ["csv_open"] 5 | 6 | 7 | def csv_open(file, expected_columns): 8 | """ 9 | Yields rows of csv file as dictionaries 10 | 11 | Parameters: 12 | file - Path, or file-like object, of the CSV file to use 13 | expected_columns - Columns of the csv file 14 | If the first row of the CSV file are these labels, take the columns in that order 15 | Otherwise, take the columns in the order given by expected_columns 16 | """ 17 | 18 | if isinstance(file, str): 19 | with open(file, encoding='utf-8') as f: 20 | yield from csv_open(f, expected_columns=expected_columns) 21 | return 22 | 23 | expected_columns = tuple(expected_columns) 24 | 25 | csv_iter = csv_reader(file) 26 | 27 | first_row = next(csv_iter) 28 | 29 | if set(first_row) == set(expected_columns): 30 | columns = first_row 31 | else: 32 | columns = expected_columns 33 | csv_iter = chain([first_row], csv_iter) 34 | 35 | for row in csv_iter: 36 | if len(row) < len(columns): 37 | raise IndexError("Too few columns in row {!r}".format(row)) 38 | 39 | yield dict(zip(columns, row)) 40 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madman-bob/Smart-Fruit/8739874811334226c073489fd14b6ca0d0ff9a7b/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_compound_types.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from smart_fruit import Model 4 | from smart_fruit.feature_types import Number, Label, Vector 5 | 6 | 7 | class TestCompoundTypes(TestCase): 8 | def test_vector(self): 9 | class ExampleModel(Model): 10 | class Input: 11 | a = Number() 12 | 13 | class Output: 14 | b = Vector([ 15 | Number(), 16 | Label(['a', 'b']), 17 | Number() 18 | ]) 19 | 20 | samples = [ 21 | (0, (1, 'a', 2)), 22 | (1, (3, 'b', 4)) 23 | ] 24 | 25 | model = ExampleModel.train(ExampleModel.features_from_list(samples)) 26 | 27 | predictions = model.predict(ExampleModel.input_features_from_list([[0], [1]])) 28 | 29 | for (sample_input, sample_output), prediction in zip(samples, predictions): 30 | with self.subTest(sample=sample_input): 31 | self.assertAlmostEqual(sample_output[0], prediction.b[0]) 32 | self.assertEqual(sample_output[1], prediction.b[1]) 33 | self.assertAlmostEqual(sample_output[2], prediction.b[2]) 34 | 35 | def test_nested_vectors(self): 36 | class ExampleModel(Model): 37 | class Input: 38 | a = Number() 39 | 40 | class Output: 41 | b = Vector([ 42 | Number(), 43 | Vector([ 44 | Number(), 45 | Label(['a', 'b']) 46 | ]), 47 | Number() 48 | ]) 49 | 50 | samples = [ 51 | (0, (1, (2, 'a'), 3)), 52 | (1, (4, (5, 'b'), 6)) 53 | ] 54 | 55 | model = ExampleModel.train(ExampleModel.features_from_list(samples)) 56 | 57 | predictions = model.predict(ExampleModel.input_features_from_list([[0], [1]])) 58 | 59 | for (sample_input, sample_output), prediction in zip(samples, predictions): 60 | with self.subTest(sample=sample_input): 61 | self.assertAlmostEqual(sample_output[0], prediction.b[0]) 62 | self.assertAlmostEqual(sample_output[1][0], prediction.b[1][0]) 63 | self.assertEqual(sample_output[1][1], prediction.b[1][1]) 64 | self.assertAlmostEqual(sample_output[2], prediction.b[2]) 65 | 66 | def 
test_vector_validation(self): 67 | feature_type = Vector([ 68 | Number(), 69 | Label(['a', 'b']) 70 | ]) 71 | 72 | for a in ((1, 'a'), (3, 'b'), (17, 'a')): 73 | with self.subTest(a=a): 74 | self.assertEqual(feature_type.validate(a), a) 75 | 76 | for a in ((1,), (1, 'a', 2), ('a', 1), (1, 1)): 77 | with self.subTest(a=a), \ 78 | self.assertRaises((TypeError, ValueError)): 79 | feature_type.validate(a) 80 | -------------------------------------------------------------------------------- /tests/test_feature_serialization.py: -------------------------------------------------------------------------------- 1 | from io import StringIO 2 | from itertools import product 3 | from unittest import TestCase 4 | 5 | from smart_fruit import Model 6 | from smart_fruit.feature_types import Number, Label 7 | 8 | 9 | class TestFeatureSerialization(TestCase): 10 | class ExampleModel(Model): 11 | class Input: 12 | number = Number() 13 | label = Label(['a', 'b', 'c']) 14 | 15 | class Output: 16 | number_a = Number() 17 | number_b = Number() 18 | 19 | example_input = (1, 'a') 20 | example_output = (1, 2) 21 | 22 | example_json_input = {'number': 1, 'label': 'a'} 23 | example_json_output = {'number_a': 1, 'number_b': 2} 24 | 25 | valid_iterable_inputs = ( 26 | (1, 'a'), 27 | (1, 'b'), 28 | (2, 'a') 29 | ) 30 | 31 | invalid_iterable_inputs = ( 32 | ('a', 'a'), 33 | (float('nan'), 'a'), 34 | (float('inf'), 'a'), 35 | (- float('inf'), 'a'), 36 | 37 | (1, 1), 38 | (1, 'd') 39 | ) 40 | 41 | def test_iterable_deserialization(self): 42 | with self.subTest(feature_type=self.ExampleModel.Input): 43 | feature = self.ExampleModel.Input(*self.example_input) 44 | 45 | self.assertEqual(feature.number, 1) 46 | self.assertEqual(feature.label, 'a') 47 | 48 | with self.subTest(feature_type=self.ExampleModel.Output): 49 | feature = self.ExampleModel.Output(*self.example_output) 50 | 51 | self.assertEqual(feature.number_a, 1) 52 | self.assertEqual(feature.number_b, 2) 53 | 54 | with self.subTest(func=self.ExampleModel.input_features_from_list): 55 | features = list(self.ExampleModel.input_features_from_list([self.example_input])) 56 | 57 | self.assertEqual(features[0].number, 1) 58 | self.assertEqual(features[0].label, 'a') 59 | 60 | with self.subTest(func=self.ExampleModel.features_from_list): 61 | features = list(self.ExampleModel.features_from_list([self.example_input + self.example_output])) 62 | 63 | self.assertEqual(features[0][0].number, 1) 64 | self.assertEqual(features[0][0].label, 'a') 65 | 66 | self.assertEqual(features[0][1].number_a, 1) 67 | self.assertEqual(features[0][1].number_b, 2) 68 | 69 | def test_json_deserialization(self): 70 | with self.subTest(feature_type=self.ExampleModel.Input): 71 | feature = self.ExampleModel.Input.from_json(self.example_json_input) 72 | 73 | self.assertEqual(feature.number, 1) 74 | self.assertEqual(feature.label, 'a') 75 | 76 | with self.subTest(feature_type=self.ExampleModel.Output): 77 | feature = self.ExampleModel.Output.from_json(self.example_json_output) 78 | 79 | self.assertEqual(feature.number_a, 1) 80 | self.assertEqual(feature.number_b, 2) 81 | 82 | with self.subTest(func=self.ExampleModel.input_features_from_json): 83 | features = list(self.ExampleModel.input_features_from_json([self.example_json_input])) 84 | 85 | self.assertEqual(features[0].number, 1) 86 | self.assertEqual(features[0].label, 'a') 87 | 88 | with self.subTest(func=self.ExampleModel.features_from_json): 89 | features = list(self.ExampleModel.features_from_json([ 90 | {**self.example_json_input, 
**self.example_json_output} 91 | ])) 92 | 93 | self.assertEqual(features[0][0].number, 1) 94 | self.assertEqual(features[0][0].label, 'a') 95 | 96 | self.assertEqual(features[0][1].number_a, 1) 97 | self.assertEqual(features[0][1].number_b, 2) 98 | 99 | def test_csv_deserialization(self): 100 | with self.subTest(func=self.ExampleModel.input_features_from_csv): 101 | features = list(self.ExampleModel.input_features_from_csv( 102 | StringIO(",".join(map(str, self.example_input))) 103 | )) 104 | 105 | self.assertEqual(features[0].number, 1) 106 | self.assertEqual(features[0].label, 'a') 107 | 108 | with self.subTest(func=self.ExampleModel.features_from_csv): 109 | features = list(self.ExampleModel.features_from_csv( 110 | StringIO(",".join(map(str, self.example_input + self.example_output))) 111 | )) 112 | 113 | self.assertEqual(features[0][0].number, 1) 114 | self.assertEqual(features[0][0].label, 'a') 115 | 116 | self.assertEqual(features[0][1].number_a, 1) 117 | self.assertEqual(features[0][1].number_b, 2) 118 | 119 | def test_feature_equality(self): 120 | for a, b in product(self.valid_iterable_inputs, repeat=2): 121 | with self.subTest(a=a, b=b): 122 | self.assertEqual( 123 | self.ExampleModel.Input(*a) == self.ExampleModel.Input(*b), 124 | a == b 125 | ) 126 | 127 | def test_iterable_serialization(self): 128 | with self.subTest(feature_type=self.ExampleModel.Input): 129 | feature = self.ExampleModel.Input(*self.example_input) 130 | self.assertEqual(tuple(feature), self.example_input) 131 | 132 | with self.subTest(feature_type=self.ExampleModel.Output): 133 | feature = self.ExampleModel.Output(*self.example_output) 134 | self.assertEqual(tuple(feature), self.example_output) 135 | 136 | def test_json_serialization(self): 137 | with self.subTest(feature_type=self.ExampleModel.Input): 138 | feature = self.ExampleModel.Input(*self.example_input) 139 | self.assertEqual(dict(feature.to_json()), self.example_json_input) 140 | 141 | with self.subTest(feature_type=self.ExampleModel.Output): 142 | feature = self.ExampleModel.Output(*self.example_output) 143 | self.assertEqual(dict(feature.to_json()), self.example_json_output) 144 | 145 | def test_value_coercion(self): 146 | self.assertEqual( 147 | self.ExampleModel.Input('1', 'a').validate(), 148 | self.ExampleModel.Input(1, 'a') 149 | ) 150 | 151 | def test_invalid_deserialization(self): 152 | for invalid_input in self.invalid_iterable_inputs: 153 | with self.subTest(invalid_input=invalid_input), \ 154 | self.assertRaises((TypeError, ValueError)): 155 | self.ExampleModel.Input(*invalid_input).validate() 156 | -------------------------------------------------------------------------------- /tests/test_label_types.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from unittest import TestCase 3 | 4 | from pandas import Series 5 | 6 | from smart_fruit import Model 7 | from smart_fruit.feature_types import Label 8 | 9 | 10 | class TestLabelTypes(TestCase): 11 | def _test_labels(self, labels): 12 | class ExampleModel(Model): 13 | class Input: 14 | label = Label(labels) 15 | 16 | with self.subTest("Label to pandas Series"): 17 | for i, label in enumerate(labels): 18 | self.assertTrue( 19 | ExampleModel.Input.label.to_series(label).equals( 20 | self._basis_vector(len(labels), i) 21 | ) 22 | ) 23 | 24 | with self.subTest("Pandas Series to label"): 25 | for i, label in enumerate(labels): 26 | self.assertEqual( 27 | ExampleModel.Input.label.from_series(self._basis_vector(len(labels), i)), 28 
| label 29 | ) 30 | 31 | @staticmethod 32 | def _basis_vector(length, index): 33 | vector = Series([0 for _ in range(length)]) 34 | vector[index] = 1 35 | return vector 36 | 37 | def test_string_labels(self): 38 | self._test_labels(['a', 'b', 'c']) 39 | 40 | def test_number_labels(self): 41 | self._test_labels([1, 17, 3.1415]) 42 | 43 | def test_enum_labels(self): 44 | class Colours(Enum): 45 | red = 0 46 | green = 1 47 | blue = 2 48 | 49 | self._test_labels(Colours) 50 | -------------------------------------------------------------------------------- /tests/test_model_training.py: -------------------------------------------------------------------------------- 1 | from random import Random 2 | from unittest import TestCase 3 | 4 | from numpy import median 5 | 6 | from sklearn import linear_model 7 | 8 | from examples.trivial_model import TrivialModel 9 | 10 | 11 | class TestModelTraining(TestCase): 12 | @staticmethod 13 | def _train_test_split(func, model_class=TrivialModel, **kwargs): 14 | model, score = model_class.train( 15 | model_class.features_from_list((n, func(n)) for n in range(20)), 16 | **kwargs 17 | ) 18 | 19 | return score 20 | 21 | def _median_train_test_split(self, *args, attempts=3, **kwargs): 22 | return median([ 23 | self._train_test_split(*args, **kwargs) 24 | for _ in range(attempts) 25 | ]) 26 | 27 | def test_train_test_split_perfect(self): 28 | self.assertEqual( 29 | self._train_test_split(lambda n: 10 * n, train_test_split_ratio=0.2), 30 | 1 31 | ) 32 | 33 | def test_train_test_split_noisy(self): 34 | random = Random(0).random 35 | 36 | for n in range(1, 10): 37 | with self.subTest(train_test_split_ratio=n / 10): 38 | score = self._median_train_test_split( 39 | lambda m: m + 5 * random(), 40 | train_test_split_ratio=n / 10, 41 | random_state=0, 42 | attempts=5 43 | ) 44 | 45 | self.assertGreater(score, 0) 46 | self.assertLess(score, 1) 47 | 48 | def test_train_test_split_pure_noise(self): 49 | random = Random(0).random 50 | 51 | score = self._median_train_test_split( 52 | lambda n: random(), 53 | train_test_split_ratio=0.2, 54 | random_state=0, 55 | attempts=5 56 | ) 57 | 58 | self.assertLess(score, 0) 59 | 60 | def test_train_test_split_errors(self): 61 | for train_test_split_ratio in (-1, -0.5, 0, 1, 1.5, 2): 62 | with self.subTest(train_test_split_ratio=train_test_split_ratio), \ 63 | self.assertRaises(ValueError): 64 | self._train_test_split(lambda n: 10 * n, train_test_split_ratio=train_test_split_ratio) 65 | 66 | for test_sample_count in (-1, 0): 67 | with self.subTest(test_sample_count=test_sample_count), \ 68 | self.assertRaises(ValueError): 69 | self._train_test_split(lambda n: 10 * n, test_sample_count=test_sample_count) 70 | 71 | def test_custom_model_class(self): 72 | class HuberModel(TrivialModel): 73 | model_class = linear_model.HuberRegressor 74 | 75 | almost_straight_line = lambda n: 10 * n if n else 10 76 | 77 | self.assertLess( 78 | self._median_train_test_split( 79 | almost_straight_line, 80 | model_class=TrivialModel, 81 | train_test_split_ratio=0.2 82 | ), 83 | self._median_train_test_split( 84 | almost_straight_line, 85 | model_class=HuberModel, 86 | train_test_split_ratio=0.2, 87 | ) 88 | ) 89 | -------------------------------------------------------------------------------- /tests/test_simple_types.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from smart_fruit import Model 4 | from smart_fruit.feature_types import Number, Integer, Complex, Label, Tag 5 | 6 | 7 | 
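# For each simple feature type, these tests check prediction round-trips through a minimal one-feature model, plus validation and coercion of raw values.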
class TestSimpleTypes(TestCase): 8 | def test_number(self): 9 | class ExampleModel(Model): 10 | class Input: 11 | a = Number() 12 | 13 | class Output: 14 | b = Number() 15 | 16 | model = ExampleModel.train(ExampleModel.features_from_list([ 17 | (0, 0), 18 | (1, 10) 19 | ])) 20 | 21 | predictions = model.predict(ExampleModel.input_features_from_list([[0], [1]])) 22 | 23 | for n, output in zip((0, 1), predictions): 24 | with self.subTest(n=n): 25 | self.assertAlmostEqual(output.b, 10 * n) 26 | 27 | def test_number_validation(self): 28 | feature_type = Number() 29 | 30 | for n in (0, 1, 3.141592, -17): 31 | with self.subTest(n=n): 32 | self.assertEqual(feature_type.validate(n), n) 33 | 34 | for a, b in (("1", 1),): 35 | with self.subTest(a=a): 36 | self.assertEqual(feature_type.validate(a), b) 37 | 38 | for n in (1j, float("nan"), float("inf"), -float("inf"), "a"): 39 | with self.subTest(n=n), \ 40 | self.assertRaises((TypeError, ValueError)): 41 | feature_type.validate(n) 42 | 43 | def test_integer(self): 44 | class ExampleModel(Model): 45 | class Input: 46 | a = Number() 47 | 48 | class Output: 49 | b = Integer() 50 | 51 | model = ExampleModel.train(ExampleModel.features_from_list([ 52 | (0, 0), 53 | (1, 10) 54 | ])) 55 | 56 | predictions = model.predict(ExampleModel.input_features_from_list([[0], [0.01], [0.99], [1]])) 57 | 58 | for a, n, output in zip((0, 0.01, 0.99, 1), (0, 0, 10, 10), predictions): 59 | with self.subTest(a=a): 60 | self.assertEqual(output.b, n) 61 | 62 | def test_integer_validation(self): 63 | feature_type = Integer() 64 | 65 | for a, b in ((0, 0), (1, 1), (3.141592, 3), (-17, -17)): 66 | with self.subTest(a=a): 67 | self.assertEqual(feature_type.validate(a), b) 68 | 69 | for a in (1j, float("nan"), float("inf"), -float("inf"), "a"): 70 | with self.subTest(a=a), \ 71 | self.assertRaises((TypeError, ValueError)): 72 | feature_type.validate(a) 73 | 74 | def test_complex(self): 75 | class ExampleModel(Model): 76 | class Input: 77 | a = Number() 78 | 79 | class Output: 80 | b = Complex() 81 | 82 | model = ExampleModel.train(ExampleModel.features_from_list([ 83 | (0, 1), 84 | (1, 1j) 85 | ])) 86 | 87 | predictions = model.predict(ExampleModel.input_features_from_list([[0], [0.5], [1]])) 88 | 89 | for a, b, output in zip((0, 0.5, 1), (1, 0.5 + 0.5j, 1j), predictions): 90 | with self.subTest(a=a): 91 | self.assertAlmostEqual(output.b, b) 92 | 93 | def test_complex_validation(self): 94 | feature_type = Complex() 95 | 96 | for a in (0, 1, 3 + 4j, -1 + 7j): 97 | with self.subTest(a=a): 98 | self.assertEqual(feature_type.validate(a), a) 99 | 100 | for a in (float("nan"), float("inf"), -float("inf"), "a"): 101 | with self.subTest(a=a), \ 102 | self.assertRaises((TypeError, ValueError)): 103 | feature_type.validate(a) 104 | 105 | def test_label(self): 106 | class ExampleModel(Model): 107 | class Input: 108 | a = Number() 109 | 110 | class Output: 111 | b = Label(['a', 'b']) 112 | 113 | model = ExampleModel.train(ExampleModel.features_from_list([ 114 | (0, 'a'), 115 | (1, 'b') 116 | ])) 117 | 118 | predictions = model.predict(ExampleModel.input_features_from_list([[0], [1]])) 119 | 120 | for n, label, output in zip((0, 1), ('a', 'b'), predictions): 121 | with self.subTest(n=n, label=label): 122 | self.assertEqual(output.b, label) 123 | 124 | def test_label_validation(self): 125 | ob = object() 126 | bad_ob = object() 127 | 128 | feature_type = Label(['a', 'b', 0, ob]) 129 | 130 | for a in ('a', 'b', 0, ob): 131 | with self.subTest(a=a): 132 | 
self.assertEqual(feature_type.validate(a), a) 133 | 134 | for a in ('c', 1, bad_ob): 135 | with self.subTest(a=a), \ 136 | self.assertRaises(TypeError): 137 | feature_type.validate(a) 138 | 139 | def test_tag_input(self): 140 | class ExampleModel(Model): 141 | class Input: 142 | a = Number() 143 | b = Tag() 144 | 145 | class Output: 146 | c = Number() 147 | 148 | sample_inputs = [ 149 | (0, 0), 150 | (1, 1), 151 | (0, 1), 152 | (1, 0), 153 | (0, "a"), 154 | (0, object()) 155 | ] 156 | 157 | model = ExampleModel.train(ExampleModel.features_from_list([ 158 | (0, 0, 0), 159 | (1, 1, 1) 160 | ])) 161 | 162 | predictions = model.predict(ExampleModel.input_features_from_list(sample_inputs)) 163 | 164 | for sample_input, prediction in zip(sample_inputs, predictions): 165 | with self.subTest(sample=sample_input): 166 | self.assertAlmostEqual(prediction.c, sample_input[0]) 167 | 168 | def test_tag_output(self): 169 | class ExampleModel(Model): 170 | class Input: 171 | a = Number() 172 | 173 | class Output: 174 | b = Tag() 175 | c = Number() 176 | 177 | model = ExampleModel.train(ExampleModel.features_from_list([ 178 | (0, 0, 0), 179 | (1, 1, 1) 180 | ])) 181 | 182 | with self.assertRaisesRegex(TypeError, "May not predict a Tag"): 183 | next(model.predict([ExampleModel.Input(0)])) 184 | 185 | def test_multiple_types_in_single_model(self): 186 | class ExampleModel(Model): 187 | class Input: 188 | a = Number() 189 | b = Number() 190 | 191 | class Output: 192 | c = Number() 193 | d = Number() 194 | e = Label(['a', 'b']) 195 | f = Complex() 196 | g = Number() 197 | 198 | samples = [ 199 | (0, 0, 0, 0, 'a', 0, 0), 200 | (0, 1, 0, 3, 'b', 1j, 0), 201 | (1, 0, 2, 0, 'a', 1, 0), 202 | (1, 1, 2, 3, 'b', 1 + 1j, 0) 203 | ] 204 | 205 | model = ExampleModel.train(ExampleModel.features_from_list(samples)) 206 | 207 | predictions = model.predict(ExampleModel.input_features_from_list([ 208 | sample[:2] for sample in samples 209 | ])) 210 | 211 | for sample, prediction in zip(samples, predictions): 212 | with self.subTest(sample=sample): 213 | self.assertAlmostEqual(sample[2], prediction.c) 214 | self.assertAlmostEqual(sample[3], prediction.d) 215 | self.assertEqual(sample[4], prediction.e) 216 | self.assertAlmostEqual(sample[5], prediction.f) 217 | self.assertAlmostEqual(sample[6], prediction.g) 218 | -------------------------------------------------------------------------------- /tests/test_trivial_model.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from examples.trivial_model import TrivialModel 4 | 5 | 6 | class TestTrivialModel(TestCase): 7 | def _test_exact_relationship(self, func): 8 | model = TrivialModel.train(TrivialModel.features_from_list( 9 | (n, func(n)) 10 | for n in range(20) 11 | )) 12 | 13 | test_values = (-5, 3, 17, 30, 3.1415) 14 | predictions = model.predict(TrivialModel.Output(n) for n in test_values) 15 | 16 | for n, prediction in zip(test_values, predictions): 17 | with self.subTest(n=n): 18 | self.assertIsInstance(prediction, TrivialModel.Output) 19 | self.assertAlmostEqual(prediction.output, func(n)) 20 | 21 | def test_linear_relationship(self): 22 | self._test_exact_relationship(lambda n: 10 * n) 23 | 24 | def test_affine_relationship(self): 25 | self._test_exact_relationship(lambda n: 10 * n + 1) 26 | 27 | def test_quadratic_relationship(self): 28 | model = TrivialModel.train(TrivialModel.features_from_list( 29 | (n, n ** 2) 30 | for n in range(20) 31 | )) 32 | 33 | test_values = (0, 3, 10, 17, 20) 34 | 
predictions = model.predict(TrivialModel.Output(n) for n in test_values) 35 | 36 | for n, prediction in zip(test_values, predictions): 37 | with self.subTest(n=n): 38 | self.assertIsInstance(prediction, TrivialModel.Output) 39 | 40 | # Check that it's not exact, but still in the ballpark 41 | self.assertNotAlmostEqual(prediction.output, n ** 2, places=0) 42 | self.assertAlmostEqual(prediction.output, n ** 2, places=-3) 43 | -------------------------------------------------------------------------------- /tests/test_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madman-bob/Smart-Fruit/8739874811334226c073489fd14b6ca0d0ff9a7b/tests/test_utils/__init__.py -------------------------------------------------------------------------------- /tests/test_utils/example_csv.csv: -------------------------------------------------------------------------------- 1 | a,b,c 2 | 1,2,3 3 | 4,5,6 4 | α,β,γ 5 | -------------------------------------------------------------------------------- /tests/test_utils/test_csv_open.py: -------------------------------------------------------------------------------- 1 | from io import StringIO 2 | from unittest import TestCase 3 | 4 | from smart_fruit.utils import csv_open 5 | 6 | 7 | class TestCSVOpen(TestCase): 8 | test_csv_path = "tests/test_utils/example_csv.csv" 9 | test_csv_columns = ('a', 'b', 'c') 10 | test_csv_response = [ 11 | {'a': '1', 'b': '2', 'c': '3'}, 12 | {'a': '4', 'b': '5', 'c': '6'}, 13 | {'a': 'α', 'b': 'β', 'c': 'γ'} 14 | ] 15 | 16 | def test_opens_csv_paths(self): 17 | self.assertEqual( 18 | list(csv_open(self.test_csv_path, self.test_csv_columns)), 19 | self.test_csv_response 20 | ) 21 | 22 | def test_opens_csv_file_handles(self): 23 | with open(self.test_csv_path, encoding='utf-8') as csv_file: 24 | self.assertEqual( 25 | list(csv_open(csv_file, self.test_csv_columns)), 26 | self.test_csv_response 27 | ) 28 | 29 | def test_no_given_columns(self): 30 | self.assertEqual( 31 | list(csv_open(StringIO("1,2,3\n4,5,6\nα,β,γ"), self.test_csv_columns)), 32 | self.test_csv_response 33 | ) 34 | 35 | def test_different_column_order(self): 36 | self.assertEqual( 37 | list(csv_open(self.test_csv_path, ('b', 'a', 'c'))), 38 | self.test_csv_response 39 | ) 40 | 41 | def test_missing_columns(self): 42 | with self.assertRaises(IndexError): 43 | list(csv_open(StringIO("1,2"), self.test_csv_columns)) 44 | --------------------------------------------------------------------------------