├── .gitignore
├── .travis.yml
├── CHANGELOG.rst
├── LICENSE.txt
├── README.rst
├── docs
│   └── logo
│       ├── strawberry-large.png
│       ├── strawberry-medium.png
│       └── strawberry-small.png
├── examples
│   ├── iris.py
│   ├── iris_data.csv
│   └── trivial_model.py
├── pypi_upload
│   ├── .pypirc
│   ├── pypi_upload.sh
│   └── setup.py
├── requirements.txt
├── smart_fruit
│   ├── __init__.py
│   ├── feature_class.py
│   ├── feature_types
│   │   ├── __init__.py
│   │   ├── compound_types.py
│   │   ├── feature_type_base.py
│   │   └── simple_types.py
│   ├── model.py
│   ├── model_selection.py
│   └── utils.py
└── tests
    ├── __init__.py
    ├── test_compound_types.py
    ├── test_feature_serialization.py
    ├── test_label_types.py
    ├── test_model_training.py
    ├── test_simple_types.py
    ├── test_trivial_model.py
    └── test_utils
        ├── __init__.py
        ├── example_csv.csv
        └── test_csv_open.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # PyPI build files
2 | pypi_upload/.pypirc
3 | build
4 | dist
5 | smart_fruit.egg-info
6 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | cache: pip
3 | python:
4 | - "3.6"
5 | install:
6 | - pip install -r requirements.txt
7 | script:
8 | - python -m unittest
9 |
--------------------------------------------------------------------------------
/CHANGELOG.rst:
--------------------------------------------------------------------------------
1 | Changelog
2 | =========
3 |
4 | 1.2.1
5 | -----
6 |
7 | Bug fixes:
8 |
9 | - Fix broken build.
10 |
11 | 1.2.0
12 | -----
13 |
14 | Features:
15 |
16 | - When reading CSV files, assume columns in same order as defined in class if not given in file.
17 | - Add the ``Integer``, ``Complex``, ``Vector``, and ``Tag`` feature types.
18 | - Add the ``Model.predict`` ``yield_inputs`` parameter.
19 | - Add the ``Model.train`` ``random_state`` parameter.
20 |
21 | Bug fixes:
22 |
23 | - Support unicode in CSV files.
24 | - Fix bug which raised an error when predicting a ``Number`` that wasn't the first feature in ``Output``.
25 |
26 | 1.1.0
27 | -----
28 |
29 | Features:
30 |
31 | - Add basic feature type validation and coercion functionality.
32 | - Add test/train split parameters to ``Model.train``.
33 |
34 | Bug fixes:
35 |
36 | - Allow use of non-orderable types as ``Label`` labels.
37 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 madman-bob
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | Smart Fruit
2 | ===========
3 |
4 | Purpose
5 | -------
6 |
7 | A Python machine learning library, for creating quick and easy machine learning models.
8 | It is schema-based, and wraps `scikit-learn <https://scikit-learn.org/>`_.
9 |
10 | Usage
11 | -----
12 |
13 | Create and use a machine learning model in 3 steps:
14 |
15 | 1. Create a schema representing your input and output features.
16 | 2. Train a model from your data.
17 | 3. Make predictions from your model.
18 |
19 | Example
20 | -------
21 |
22 | To get a feel for the library, consider the classic `Iris <https://archive.ics.uci.edu/ml/datasets/Iris>`_ dataset,
23 | where we predict the class of an iris plant from measurements of its sepal and petal.
24 |
25 | First, we create a schema describing our inputs and outputs.
26 | For our inputs, we have the length and width of both the sepal and the petal.
27 | All of these input values happen to be numbers.
28 | For our output, we have just the class of iris, which may be one of the labels ``Iris-setosa``, ``Iris-versicolor``, or ``Iris-virginica``.
29 |
30 | We define this in code as follows:
31 |
32 | .. code:: python
33 |
34 | from smart_fruit import Model
35 | from smart_fruit.feature_types import Number, Label
36 |
37 |
38 | class Iris(Model):
39 | class Input:
40 | sepal_length_cm = Number()
41 | sepal_width_cm = Number()
42 | petal_length_cm = Number()
43 | petal_width_cm = Number()
44 |
45 | class Output:
46 | iris_class = Label(['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'])
47 |
48 | Then, we train a model:
49 |
50 | .. code:: python
51 |
52 | model = Iris.train(Iris.features_from_csv('iris_data.csv'))
53 |
54 | with data file `iris_data.csv `_.
55 |
56 | ::
57 |
58 | sepal_length_cm,sepal_width_cm,petal_length_cm,petal_width_cm,iris_class
59 | 5.1,3.5,1.4,0.2,Iris-setosa
60 | ...
61 |
62 | Finally, we use our new model to make predictions:
63 |
64 | .. code:: python
65 |
66 | for prediction in model.predict([Iris.Input(5.1, 3.5, 1.4, 0.2)]):
67 | print(prediction.iris_class)
68 |
69 | Reference
70 | ---------
71 |
72 | Models
73 | ~~~~~~
74 |
75 | - ``Model.Input`` - Schema for defining your input features.
76 |
77 | - ``Model.Output`` - Schema for defining your output features.
78 |
79 | Define ``Model.Input`` and ``Model.Output`` as classes with ``FeatureType`` attributes.
80 |
81 | eg. Consider the ``Iris`` class defined above.
82 |
83 | These classes can then be used to create objects representing the appropriate collections of features.
84 |
85 | eg.
86 |
87 | .. code:: python
88 |
89 | >>> iris_input = Iris.Input(5.1, 3.5, 1.4, 0.2)
90 | >>> iris_input
91 | Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2)
92 | >>> iris_input.sepal_length_cm
93 | 5.1
94 |
95 | >>> Iris.Input.from_json({'sepal_length_cm': 5.1, 'sepal_width_cm': 3.5, 'petal_length_cm': 1.4, 'petal_width_cm': 0.2})
96 | Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2)
97 |
98 | - ``Model.features_from_list(lists)`` - Deserialize an iterable of lists into an iterable of input/output feature pairs.
99 |
100 | eg.
101 |
102 | .. code:: python
103 |
104 | >>> list(Iris.features_from_list([[5.1, 3.5, 1.4, 0.2, 'Iris-setosa']]))
105 | [(Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2), Output(iris_class='Iris-setosa'))]
106 |
107 | - ``Model.input_features_from_list(lists)`` - Deserialize an iterable of lists into an iterable of input features.
108 |
109 | eg.
110 |
111 | .. code:: python
112 |
113 | >>> list(Iris.input_features_from_list([[5.1, 3.5, 1.4, 0.2]]))
114 | [Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2)]
115 |
116 | - ``Model.features_from_json(json)`` - Deserialize an iterable of dictionaries into an iterable of input/output feature pairs.
117 |
118 | eg.
119 |
120 | .. code:: python
121 |
122 | >>> list(Iris.features_from_json([{'sepal_length_cm': 5.1, 'sepal_width_cm': 3.5, 'petal_length_cm': 1.4, 'petal_width_cm': 0.2, 'iris_class': 'Iris-setosa'}]))
123 | [(Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2), Output(iris_class='Iris-setosa'))]
124 |
125 | - ``Model.input_features_from_json(json)`` - Deserialize an iterable of dictionaries into an iterable of input features.
126 |
127 | eg.
128 |
129 | .. code:: python
130 |
131 | >>> list(Iris.input_features_from_json([{'sepal_length_cm': 5.1, 'sepal_width_cm': 3.5, 'petal_length_cm': 1.4, 'petal_width_cm': 0.2}]))
132 | [Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2)]
133 |
134 | - ``Model.features_from_csv(csv_path)`` - Take a path to a CSV file, or a file-like object, and deserialize it into an iterable of input/output feature pairs.
135 |
136 | If column headings are not given in the file, assume the input features are followed by the output features, in the order they are defined in their respective classes.
137 |
138 | eg.
139 |
140 | .. code:: python
141 |
142 | >>> list(Iris.features_from_csv('iris_data.csv'))
143 | [(Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2), Output(iris_class='Iris-setosa')), ...]
144 |
145 | - ``Model.input_features_from_csv(csv_path)`` - Take a path to a CSV file, or a file-like object, and deserialize it into an iterable of input features.
146 |
147 | If column headings are not given in the file, assume they are in the order they are defined in the ``Input`` class.
148 |
149 | eg.
150 |
151 | .. code:: python
152 |
153 | >>> list(Iris.input_features_from_csv('iris_data.csv'))
154 | [Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2), ...]
155 |
156 | - ``Model.model_class`` - How to model the relation between the input and output data.
157 |
158 | Default: ``sklearn.linear_model.LinearRegression``
159 |
160 | This attribute accepts any class with ``fit``, ``predict``, and ``score`` methods defined as for ``scikit-learn`` multi-response regression models.
161 | In particular, this attribute accepts any ``scikit-learn`` multi-response regression model,
162 | ie. any ``scikit-learn`` regression model where the ``y`` parameter of ``fit`` accepts a numpy array of shape ``[n_samples, n_targets]``.
163 |
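eg. A minimal sketch of swapping in a different model class (``RidgeIris`` is a name invented here for illustration; ``sklearn.linear_model.Ridge`` is chosen because it supports multi-target ``y``):

.. code:: python

    from sklearn import linear_model


    class RidgeIris(Iris):
        # Everything else (Input, Output, training, prediction) is inherited from Iris
        model_class = linear_model.Ridge

The project's own tests use the same pattern with ``linear_model.HuberRegressor`` (see ``tests/test_model_training.py``).
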
164 | - ``Model.train(features, train_test_split_ratio=None, test_sample_count=None, random_state=None)``
165 |
166 | Train a new model on the given iterable of input/output pairs.
167 |
168 | Parameters:
169 |
170 | - ``features`` - An iterable of input/output pairs.
171 |
172 | - ``train_test_split_ratio`` - Proportion of data to use as cross-validation test data.
173 |
174 | - ``test_sample_count`` - Number of samples of data to use as cross-validation test data.
175 |
176 | If ``train_test_split_ratio`` or ``test_sample_count`` are provided, perform cross-validation of the given data.
177 | Return both the trained model, and the score of the test data on that model.
178 |
179 | - ``random_state`` - Either a ``numpy`` ``RandomState``, or the seed to use for the PRNG.
180 |
181 | Useful for getting consistent results, for example for automated tests.
182 | Do not use this parameter when generating models you plan to use in production settings.
183 |
184 | eg.
185 |
186 | .. code:: python
187 |
188 | >>> iris_model = Iris.train([(Iris.Input(5.1, 3.5, 1.4, 0.2), Iris.Output('Iris-setosa'))])
189 |
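eg. A sketch of training with cross-validation (``0.2`` and ``random_state=0`` are illustrative values); when a split is requested, ``train`` returns both the model and its score on the held-out test data:

.. code:: python

    >>> iris_model, score = Iris.train(
    ...     Iris.features_from_csv('iris_data.csv'),
    ...     train_test_split_ratio=0.2,
    ...     random_state=0
    ... )
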
190 | - ``model.predict(input_features, yield_inputs=False)`` - Predict the outputs for a given iterable of inputs.
191 |
192 | If ``yield_inputs`` is ``True`` then yield the prediction with the input used to generate it, as ``input``, ``output`` pairs.
193 | Otherwise, yield just the predictions, in the same order the inputs are given to the model.
194 |
195 | eg.
196 |
197 | .. code:: python
198 |
199 | >>> list(iris_model.predict([Iris.Input(5.1, 3.5, 1.4, 0.2)]))
200 | [Output(iris_class='Iris-setosa')]
201 |
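eg. A sketch using ``yield_inputs=True``, so each prediction stays paired with the input that produced it:

.. code:: python

    >>> for input_, output in iris_model.predict([Iris.Input(5.1, 3.5, 1.4, 0.2)], yield_inputs=True):
    ...     print(input_.sepal_length_cm, output.iris_class)
    5.1 Iris-setosa
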
202 | Feature Types
203 | ~~~~~~~~~~~~~
204 |
205 | Smart Fruit recognizes the following data types for input and output features.
206 | Custom types may be made by extending the ``FeatureType`` class, as sketched after the list below.
207 |
208 | - ``Number()`` - A real-valued feature.
209 |
210 | eg. ``0``, ``1``, ``3.141592``, ``-17``, ...
211 |
212 | - ``Integer()`` - A whole number feature.
213 |
214 | eg. ``0``, ``1``, ``3``, ``-17``, ...
215 |
216 | - ``Complex()`` - A complex-valued number feature.
217 |
218 | eg. ``0``, ``1``, ``3 + 4j``, ``-1 + 7j``, ...
219 |
220 | - ``Label(labels)`` - An enumerated feature, ie. one which may take one of a pre-defined list of available values.
221 |
222 | eg. For ``labels = ['red', 'green', 'blue']``, our label may take the value ``'red'``, but not ``'purple'``.
223 |
224 | - ``Vector(feature_types)`` - A feature made of other features. Useful for grouping conceptually related features.
225 |
226 | eg. For ``feature_types = [Number(), Label(['red', 'green', 'blue'])]``, we may take values such as ``(0, 'red')``, and ``(1, 'blue')``.
227 |
228 | - ``Tag()`` - A feature that is ignored when making predictions. Useful for keeping track of ID numbers.
229 |
230 | Accepts any Python value.
231 |
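A minimal sketch of a custom feature type (``Boolean`` is a hypothetical example, not part of the library), built by overriding the ``validate``, ``to_series``, and ``from_series`` methods defined on ``FeatureType`` (see ``smart_fruit/feature_types/feature_type_base.py``):

.. code:: python

    from pandas import Series

    from smart_fruit.feature_types import FeatureType


    class Boolean(FeatureType):
        feature_count = 1  # encoded as a single numeric column

        def validate(self, value):
            return bool(value)

        def to_series(self, value):
            # Encode the boolean as 0.0 or 1.0 for the underlying regressor
            return Series([1.0 if value else 0.0])

        def from_series(self, features):
            # Decode the predicted column back to a boolean
            return bool(features.iloc[0] > 0.5)
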
232 | Requirements
233 | ------------
234 |
235 | Smart Fruit requires Python 3.6+, and uses
236 | `scikit-learn <https://scikit-learn.org/>`_,
237 | `scipy <https://scipy.org/>`_,
238 | and `pandas <https://pandas.pydata.org/>`_.
239 |
240 | Installation
241 | ------------
242 |
243 | Install and update using the standard Python package manager `pip <https://pip.pypa.io/>`_:
244 |
245 | .. code:: text
246 |
247 | pip install smart-fruit
248 |
249 | Donate
250 | ------
251 |
252 | To support the continued development of Smart Fruit, please
253 | `donate `_.
254 |
--------------------------------------------------------------------------------
/docs/logo/strawberry-large.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madman-bob/Smart-Fruit/8739874811334226c073489fd14b6ca0d0ff9a7b/docs/logo/strawberry-large.png
--------------------------------------------------------------------------------
/docs/logo/strawberry-medium.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madman-bob/Smart-Fruit/8739874811334226c073489fd14b6ca0d0ff9a7b/docs/logo/strawberry-medium.png
--------------------------------------------------------------------------------
/docs/logo/strawberry-small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madman-bob/Smart-Fruit/8739874811334226c073489fd14b6ca0d0ff9a7b/docs/logo/strawberry-small.png
--------------------------------------------------------------------------------
/examples/iris.py:
--------------------------------------------------------------------------------
1 | from smart_fruit import Model
2 | from smart_fruit.feature_types import Number, Label
3 |
4 |
5 | class Iris(Model):
6 | """
7 | Example class for the "Iris" data set:
8 |
9 | https://archive.ics.uci.edu/ml/datasets/Iris
10 | """
11 |
12 | class Input:
13 | sepal_length_cm = Number()
14 | sepal_width_cm = Number()
15 | petal_length_cm = Number()
16 | petal_width_cm = Number()
17 |
18 | class Output:
19 | iris_class = Label(['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'])
20 |
21 |
22 | def main():
23 | features = list(Iris.features_from_csv('iris_data.csv'))
24 |
25 | model = Iris.train(features)
26 |
27 | input_features = [input_feature for input_feature, output_feature in features]
28 |
29 | for (input_, output), predicted_output in zip(features, model.predict(input_features)):
30 | print(list(input_), output.iris_class, predicted_output.iris_class)
31 |
32 |
33 | if __name__ == "__main__":
34 | main()
35 |
--------------------------------------------------------------------------------
/examples/iris_data.csv:
--------------------------------------------------------------------------------
1 | sepal_length_cm,sepal_width_cm,petal_length_cm,petal_width_cm,iris_class
2 | 5.1,3.5,1.4,0.2,Iris-setosa
3 | 4.9,3.0,1.4,0.2,Iris-setosa
4 | 4.7,3.2,1.3,0.2,Iris-setosa
5 | 4.6,3.1,1.5,0.2,Iris-setosa
6 | 5.0,3.6,1.4,0.2,Iris-setosa
7 | 5.4,3.9,1.7,0.4,Iris-setosa
8 | 4.6,3.4,1.4,0.3,Iris-setosa
9 | 5.0,3.4,1.5,0.2,Iris-setosa
10 | 4.4,2.9,1.4,0.2,Iris-setosa
11 | 4.9,3.1,1.5,0.1,Iris-setosa
12 | 5.4,3.7,1.5,0.2,Iris-setosa
13 | 4.8,3.4,1.6,0.2,Iris-setosa
14 | 4.8,3.0,1.4,0.1,Iris-setosa
15 | 4.3,3.0,1.1,0.1,Iris-setosa
16 | 5.8,4.0,1.2,0.2,Iris-setosa
17 | 5.7,4.4,1.5,0.4,Iris-setosa
18 | 5.4,3.9,1.3,0.4,Iris-setosa
19 | 5.1,3.5,1.4,0.3,Iris-setosa
20 | 5.7,3.8,1.7,0.3,Iris-setosa
21 | 5.1,3.8,1.5,0.3,Iris-setosa
22 | 5.4,3.4,1.7,0.2,Iris-setosa
23 | 5.1,3.7,1.5,0.4,Iris-setosa
24 | 4.6,3.6,1.0,0.2,Iris-setosa
25 | 5.1,3.3,1.7,0.5,Iris-setosa
26 | 4.8,3.4,1.9,0.2,Iris-setosa
27 | 5.0,3.0,1.6,0.2,Iris-setosa
28 | 5.0,3.4,1.6,0.4,Iris-setosa
29 | 5.2,3.5,1.5,0.2,Iris-setosa
30 | 5.2,3.4,1.4,0.2,Iris-setosa
31 | 4.7,3.2,1.6,0.2,Iris-setosa
32 | 4.8,3.1,1.6,0.2,Iris-setosa
33 | 5.4,3.4,1.5,0.4,Iris-setosa
34 | 5.2,4.1,1.5,0.1,Iris-setosa
35 | 5.5,4.2,1.4,0.2,Iris-setosa
36 | 4.9,3.1,1.5,0.1,Iris-setosa
37 | 5.0,3.2,1.2,0.2,Iris-setosa
38 | 5.5,3.5,1.3,0.2,Iris-setosa
39 | 4.9,3.1,1.5,0.1,Iris-setosa
40 | 4.4,3.0,1.3,0.2,Iris-setosa
41 | 5.1,3.4,1.5,0.2,Iris-setosa
42 | 5.0,3.5,1.3,0.3,Iris-setosa
43 | 4.5,2.3,1.3,0.3,Iris-setosa
44 | 4.4,3.2,1.3,0.2,Iris-setosa
45 | 5.0,3.5,1.6,0.6,Iris-setosa
46 | 5.1,3.8,1.9,0.4,Iris-setosa
47 | 4.8,3.0,1.4,0.3,Iris-setosa
48 | 5.1,3.8,1.6,0.2,Iris-setosa
49 | 4.6,3.2,1.4,0.2,Iris-setosa
50 | 5.3,3.7,1.5,0.2,Iris-setosa
51 | 5.0,3.3,1.4,0.2,Iris-setosa
52 | 7.0,3.2,4.7,1.4,Iris-versicolor
53 | 6.4,3.2,4.5,1.5,Iris-versicolor
54 | 6.9,3.1,4.9,1.5,Iris-versicolor
55 | 5.5,2.3,4.0,1.3,Iris-versicolor
56 | 6.5,2.8,4.6,1.5,Iris-versicolor
57 | 5.7,2.8,4.5,1.3,Iris-versicolor
58 | 6.3,3.3,4.7,1.6,Iris-versicolor
59 | 4.9,2.4,3.3,1.0,Iris-versicolor
60 | 6.6,2.9,4.6,1.3,Iris-versicolor
61 | 5.2,2.7,3.9,1.4,Iris-versicolor
62 | 5.0,2.0,3.5,1.0,Iris-versicolor
63 | 5.9,3.0,4.2,1.5,Iris-versicolor
64 | 6.0,2.2,4.0,1.0,Iris-versicolor
65 | 6.1,2.9,4.7,1.4,Iris-versicolor
66 | 5.6,2.9,3.6,1.3,Iris-versicolor
67 | 6.7,3.1,4.4,1.4,Iris-versicolor
68 | 5.6,3.0,4.5,1.5,Iris-versicolor
69 | 5.8,2.7,4.1,1.0,Iris-versicolor
70 | 6.2,2.2,4.5,1.5,Iris-versicolor
71 | 5.6,2.5,3.9,1.1,Iris-versicolor
72 | 5.9,3.2,4.8,1.8,Iris-versicolor
73 | 6.1,2.8,4.0,1.3,Iris-versicolor
74 | 6.3,2.5,4.9,1.5,Iris-versicolor
75 | 6.1,2.8,4.7,1.2,Iris-versicolor
76 | 6.4,2.9,4.3,1.3,Iris-versicolor
77 | 6.6,3.0,4.4,1.4,Iris-versicolor
78 | 6.8,2.8,4.8,1.4,Iris-versicolor
79 | 6.7,3.0,5.0,1.7,Iris-versicolor
80 | 6.0,2.9,4.5,1.5,Iris-versicolor
81 | 5.7,2.6,3.5,1.0,Iris-versicolor
82 | 5.5,2.4,3.8,1.1,Iris-versicolor
83 | 5.5,2.4,3.7,1.0,Iris-versicolor
84 | 5.8,2.7,3.9,1.2,Iris-versicolor
85 | 6.0,2.7,5.1,1.6,Iris-versicolor
86 | 5.4,3.0,4.5,1.5,Iris-versicolor
87 | 6.0,3.4,4.5,1.6,Iris-versicolor
88 | 6.7,3.1,4.7,1.5,Iris-versicolor
89 | 6.3,2.3,4.4,1.3,Iris-versicolor
90 | 5.6,3.0,4.1,1.3,Iris-versicolor
91 | 5.5,2.5,4.0,1.3,Iris-versicolor
92 | 5.5,2.6,4.4,1.2,Iris-versicolor
93 | 6.1,3.0,4.6,1.4,Iris-versicolor
94 | 5.8,2.6,4.0,1.2,Iris-versicolor
95 | 5.0,2.3,3.3,1.0,Iris-versicolor
96 | 5.6,2.7,4.2,1.3,Iris-versicolor
97 | 5.7,3.0,4.2,1.2,Iris-versicolor
98 | 5.7,2.9,4.2,1.3,Iris-versicolor
99 | 6.2,2.9,4.3,1.3,Iris-versicolor
100 | 5.1,2.5,3.0,1.1,Iris-versicolor
101 | 5.7,2.8,4.1,1.3,Iris-versicolor
102 | 6.3,3.3,6.0,2.5,Iris-virginica
103 | 5.8,2.7,5.1,1.9,Iris-virginica
104 | 7.1,3.0,5.9,2.1,Iris-virginica
105 | 6.3,2.9,5.6,1.8,Iris-virginica
106 | 6.5,3.0,5.8,2.2,Iris-virginica
107 | 7.6,3.0,6.6,2.1,Iris-virginica
108 | 4.9,2.5,4.5,1.7,Iris-virginica
109 | 7.3,2.9,6.3,1.8,Iris-virginica
110 | 6.7,2.5,5.8,1.8,Iris-virginica
111 | 7.2,3.6,6.1,2.5,Iris-virginica
112 | 6.5,3.2,5.1,2.0,Iris-virginica
113 | 6.4,2.7,5.3,1.9,Iris-virginica
114 | 6.8,3.0,5.5,2.1,Iris-virginica
115 | 5.7,2.5,5.0,2.0,Iris-virginica
116 | 5.8,2.8,5.1,2.4,Iris-virginica
117 | 6.4,3.2,5.3,2.3,Iris-virginica
118 | 6.5,3.0,5.5,1.8,Iris-virginica
119 | 7.7,3.8,6.7,2.2,Iris-virginica
120 | 7.7,2.6,6.9,2.3,Iris-virginica
121 | 6.0,2.2,5.0,1.5,Iris-virginica
122 | 6.9,3.2,5.7,2.3,Iris-virginica
123 | 5.6,2.8,4.9,2.0,Iris-virginica
124 | 7.7,2.8,6.7,2.0,Iris-virginica
125 | 6.3,2.7,4.9,1.8,Iris-virginica
126 | 6.7,3.3,5.7,2.1,Iris-virginica
127 | 7.2,3.2,6.0,1.8,Iris-virginica
128 | 6.2,2.8,4.8,1.8,Iris-virginica
129 | 6.1,3.0,4.9,1.8,Iris-virginica
130 | 6.4,2.8,5.6,2.1,Iris-virginica
131 | 7.2,3.0,5.8,1.6,Iris-virginica
132 | 7.4,2.8,6.1,1.9,Iris-virginica
133 | 7.9,3.8,6.4,2.0,Iris-virginica
134 | 6.4,2.8,5.6,2.2,Iris-virginica
135 | 6.3,2.8,5.1,1.5,Iris-virginica
136 | 6.1,2.6,5.6,1.4,Iris-virginica
137 | 7.7,3.0,6.1,2.3,Iris-virginica
138 | 6.3,3.4,5.6,2.4,Iris-virginica
139 | 6.4,3.1,5.5,1.8,Iris-virginica
140 | 6.0,3.0,4.8,1.8,Iris-virginica
141 | 6.9,3.1,5.4,2.1,Iris-virginica
142 | 6.7,3.1,5.6,2.4,Iris-virginica
143 | 6.9,3.1,5.1,2.3,Iris-virginica
144 | 5.8,2.7,5.1,1.9,Iris-virginica
145 | 6.8,3.2,5.9,2.3,Iris-virginica
146 | 6.7,3.3,5.7,2.5,Iris-virginica
147 | 6.7,3.0,5.2,2.3,Iris-virginica
148 | 6.3,2.5,5.0,1.9,Iris-virginica
149 | 6.5,3.0,5.2,2.0,Iris-virginica
150 | 6.2,3.4,5.4,2.3,Iris-virginica
151 | 5.9,3.0,5.1,1.8,Iris-virginica
152 |
--------------------------------------------------------------------------------
/examples/trivial_model.py:
--------------------------------------------------------------------------------
1 | from smart_fruit import Model
2 | from smart_fruit.feature_types import Number
3 |
4 |
5 | class TrivialModel(Model):
6 | class Input:
7 | input_ = Number()
8 |
9 | class Output:
10 | output = Number()
11 |
12 |
13 | def main():
14 | multiply_by_10 = TrivialModel.train([
15 | (TrivialModel.Input(n), TrivialModel.Output(10 * n))
16 | for n in [1, 2, 4, 5, 7, 8, 10]
17 | ])
18 |
19 | inputs = [3, 6, 9]
20 | predictions = multiply_by_10.predict([TrivialModel.Input(n) for n in inputs], yield_inputs=True)
21 | for input_, predicted_output in predictions:
22 | print(input_, predicted_output.output)
23 |
24 |
25 | if __name__ == "__main__":
26 | main()
27 |
--------------------------------------------------------------------------------
/pypi_upload/.pypirc:
--------------------------------------------------------------------------------
1 | [distutils]
2 | index-servers =
3 | pypi
4 | testpypi
5 |
6 | [pypi]
7 | username:
8 | password:
9 |
10 | [testpypi]
11 | repository: https://test.pypi.org/legacy/
12 | username:
13 | password:
14 |
--------------------------------------------------------------------------------
/pypi_upload/pypi_upload.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Build distribution
4 | python3 pypi_upload/setup.py sdist bdist_wheel
5 |
6 | # Upload distribution to PyPI
7 | twine upload --config-file pypi_upload/.pypirc dist/*
8 | # For testing, add `--repository testpypi`
9 |
10 | # Tidy up
11 | rm -rf build
12 | rm -rf dist
13 | rm -rf smart_fruit.egg-info
14 |
--------------------------------------------------------------------------------
/pypi_upload/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | from os import path
3 |
4 | import re
5 |
6 | project_root = path.join(path.abspath(path.dirname(__file__)), '..')
7 |
8 |
9 | def get_version():
10 | with open(path.join(project_root, 'smart_fruit', '__init__.py'), encoding='utf-8') as init_file:
11 | return re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", init_file.read(), re.M).group(1)
12 |
13 |
14 | def get_long_description():
15 | with open(path.join(project_root, 'README.rst'), encoding='utf-8') as readme_file:
16 | return readme_file.read()
17 |
18 |
19 | def get_requirements():
20 | with open(path.join(project_root, 'requirements.txt'), encoding='utf-8') as requirements_file:
21 | return [requirement.strip() for requirement in requirements_file if requirement.strip()]
22 |
23 |
24 | setup(
25 | name='smart-fruit',
26 | version=get_version(),
27 | packages=find_packages(include=('smart_fruit', 'smart_fruit.*')),
28 | install_requires=get_requirements(),
29 |
30 | author='Robert Wright',
31 | author_email='madman.bob@hotmail.co.uk',
32 |
33 | description='A Python schema-based machine learning library',
34 | long_description=get_long_description(),
35 | long_description_content_type='text/x-rst',
36 | url='https://github.com/madman-bob/Smart-Fruit',
37 | license='MIT',
38 | classifiers=[
39 | 'License :: OSI Approved :: MIT License',
40 | 'Programming Language :: Python :: 3',
41 | 'Programming Language :: Python :: 3.6',
42 | 'Programming Language :: Python :: 3.7',
43 | 'Programming Language :: Python :: 3 :: Only',
44 | ],
45 | python_requires='>=3.6'
46 | )
47 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas
2 | scipy
3 | scikit-learn
4 |
--------------------------------------------------------------------------------
/smart_fruit/__init__.py:
--------------------------------------------------------------------------------
1 | from smart_fruit.model import Model
2 |
3 | __version__ = '1.2.1'
4 |
5 | __all__ = ["Model"]
6 |
--------------------------------------------------------------------------------
/smart_fruit/feature_class.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | from smart_fruit.feature_types import FeatureType
4 |
5 | __all__ = ["FeatureClassMeta"]
6 |
7 |
8 | class FeatureClassMixin:
9 | def validate(self):
10 | return self.__class__(*(
11 | feature_type.validate(value)
12 | for feature_type, value in zip(self.__class__, self)
13 | ))
14 |
15 | @classmethod
16 | def from_json(cls, json):
17 | return cls(**{
18 | key: value
19 | for key, value in json.items()
20 | if key in cls._fields
21 | })
22 |
23 | def to_json(self):
24 | return self._asdict()
25 |
26 |
27 | class FeatureClassMeta(type):
28 | def __new__(cls, name, bases, namespace):
29 | base_feature_type = bases[0]
30 | features = tuple(
31 | key
32 | for key, value in base_feature_type.__dict__.items()
33 | if isinstance(value, FeatureType)
34 | )
35 |
36 | return type.__new__(
37 | cls,
38 | name,
39 | tuple(bases) + (namedtuple(base_feature_type.__name__, features), FeatureClassMixin,),
40 | namespace
41 | )
42 |
43 | def __iter__(self):
44 | for field_name in self._fields:
45 | yield getattr(self, field_name)
46 |
47 | def __len__(self):
48 | return len(self._fields)
49 |
--------------------------------------------------------------------------------
/smart_fruit/feature_types/__init__.py:
--------------------------------------------------------------------------------
1 | from smart_fruit.feature_types.feature_type_base import FeatureType
2 | from smart_fruit.feature_types.simple_types import Number, Integer, Complex, Label, Tag
3 | from smart_fruit.feature_types.compound_types import Vector
4 |
5 | __all__ = ["FeatureType", "Number", "Integer", "Complex", "Label", "Vector", "Tag"]
6 |
--------------------------------------------------------------------------------
/smart_fruit/feature_types/compound_types.py:
--------------------------------------------------------------------------------
1 | from pandas import concat
2 |
3 | from smart_fruit.feature_types.feature_type_base import FeatureType
4 |
5 | __all__ = ["Vector"]
6 |
7 |
8 | class Vector(FeatureType):
9 | def __init__(self, feature_types):
10 | self.feature_types = feature_types
11 |
12 | @property
13 | def feature_count(self):
14 | return sum(feature_type.feature_count for feature_type in self.feature_types)
15 |
16 | def validate(self, value):
17 | if len(value) != len(self.feature_types):
18 | raise ValueError(
19 | "Incorrect length vector (expected {}, got {!r})".format(len(self.feature_types), len(value))
20 | )
21 |
22 | return tuple(
23 | feature_type.validate(subvalue)
24 | for subvalue, feature_type in zip(value, self.feature_types)
25 | )
26 |
27 | def to_series(self, value):
28 | if len(value) != len(self.feature_types):
29 | raise ValueError(
30 | "Incorrect length vector (expected {}, got {!r})".format(len(self.feature_types), len(value))
31 | )
32 |
33 | return concat([
34 | feature_type.to_series(subvalue)
35 | for subvalue, feature_type in zip(value, self.feature_types)
36 | ], ignore_index=True)
37 |
38 | def from_series(self, features):
39 | return tuple(
40 | feature_type.from_series(chunk)
41 | for chunk, feature_type in self._chunk_series(features, self.feature_types)
42 | )
43 |
44 | @staticmethod
45 | def _chunk_series(series, feature_types):
46 | start = 0
47 | for feature_type in feature_types:
48 | yield series.iloc[start:start + feature_type.feature_count].reset_index(drop=True), feature_type
49 | start += feature_type.feature_count
50 |
--------------------------------------------------------------------------------
/smart_fruit/feature_types/feature_type_base.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta
2 |
3 | from pandas import Series
4 |
5 | __all__ = ["FeatureType"]
6 |
7 |
8 | class FeatureType(metaclass=ABCMeta):
9 | _index = None
10 | feature_count = 1
11 |
12 | def __get__(self, instance, owner):
13 | if instance is None:
14 | return self
15 |
16 | if self._index is None:
17 | self._index = next(
18 | i
19 | for i, name in enumerate(owner._fields)
20 | if getattr(owner, name) is self
21 | )
22 |
23 | return instance[self._index]
24 |
25 | def validate(self, value):
26 | return value
27 |
28 | def to_series(self, value):
29 | return Series([value])
30 |
31 | def from_series(self, features):
32 | return features.iloc[0]
33 |
--------------------------------------------------------------------------------
/smart_fruit/feature_types/simple_types.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | from numpy import isfinite
4 | from pandas import Series
5 |
6 | from smart_fruit.feature_types.feature_type_base import FeatureType
7 |
8 | __all__ = ["Number", "Integer", "Complex", "Label", "Tag"]
9 |
10 |
11 | class Number(FeatureType):
12 | def validate(self, value):
13 | value = float(value)
14 |
15 | if not isfinite(value):
16 | raise ValueError(
17 | "May not assign non-finite value {} to a {}".format(
18 | value,
19 | self.__class__.__name__
20 | )
21 | )
22 |
23 | return value
24 |
25 |
26 | class Integer(Number):
27 | def validate(self, value):
28 | return int(round(super().validate(value)))
29 |
30 | def from_series(self, features):
31 | return int(round(super().from_series(features)))
32 |
33 |
34 | class Complex(FeatureType):
35 | feature_count = 2
36 |
37 | def validate(self, value):
38 | value = complex(value)
39 |
40 | if not isfinite(value):
41 | raise ValueError(
42 | "May not assign non-finite value {} to a {}".format(
43 | value,
44 | self.__class__.__name__
45 | )
46 | )
47 |
48 | return value
49 |
50 | def to_series(self, value):
51 | return Series([value.real, value.imag])
52 |
53 | def from_series(self, features):
54 | return complex(*features)
55 |
56 |
57 | class Label(FeatureType, namedtuple('Label', ['labels'])):
58 | @property
59 | def feature_count(self):
60 | return len(self.labels)
61 |
62 | def validate(self, value):
63 | if value not in self.labels:
64 | raise TypeError(
65 | "May not use non-existent label {!r} in a {!r}".format(
66 | value,
67 | self
68 | )
69 | )
70 |
71 | return value
72 |
73 | def to_series(self, value):
74 | return Series([int(value == label) for label in self.labels])
75 |
76 | def from_series(self, features):
77 | return max(zip(features, enumerate(self.labels)))[1][1]
78 |
79 |
80 | class Tag(FeatureType):
81 | feature_count = 0
82 |
83 | def to_series(self, value):
84 | return Series()
85 |
86 | def from_series(self, features):
87 | raise TypeError(
88 | "May not predict a {}".format(self.__class__.__name__)
89 | )
90 |
--------------------------------------------------------------------------------
/smart_fruit/model.py:
--------------------------------------------------------------------------------
1 | from pandas import DataFrame, concat
2 | from pandas.core.apply import frame_apply
3 |
4 | from sklearn import linear_model
5 |
6 | from smart_fruit.feature_class import FeatureClassMeta
7 | from smart_fruit.model_selection import train_test_split
8 | from smart_fruit.utils import csv_open
9 |
10 | __all__ = ["Model"]
11 |
12 |
13 | class ModelMeta(type):
14 | def __init__(cls, name, bases, namespace):
15 | super().__init__(name, bases, namespace)
16 |
17 | for feature_type in ['Input', 'Output']:
18 | feature_class = getattr(cls, feature_type)
19 |
20 | setattr(cls, feature_type, FeatureClassMeta(feature_class.__name__, (feature_class,), {}))
21 |
22 |
23 | class Model(metaclass=ModelMeta):
24 | model_class = linear_model.LinearRegression
25 |
26 | class Input:
27 | pass
28 |
29 | class Output:
30 | pass
31 |
32 | def __init__(self, *args, **kwargs):
33 | self.model = self.model_class(*args, **kwargs)
34 |
35 | @classmethod
36 | def input_features_from_list(cls, lists):
37 | for l in lists:
38 | yield cls.Input(*l).validate()
39 |
40 | @classmethod
41 | def input_features_from_json(cls, json):
42 | for feature in json:
43 | yield cls.Input.from_json(feature).validate()
44 |
45 | @classmethod
46 | def input_features_from_csv(cls, csv_path):
47 | yield from cls.input_features_from_json(csv_open(csv_path, cls.Input._fields))
48 |
49 | @classmethod
50 | def features_from_list(cls, lists):
51 | for l in lists:
52 | yield cls.Input(*l[:len(cls.Input._fields)]).validate(), cls.Output(*l[len(cls.Input._fields):]).validate()
53 |
54 | @classmethod
55 | def features_from_json(cls, json):
56 | for feature in json:
57 | yield cls.Input.from_json(feature).validate(), cls.Output.from_json(feature).validate()
58 |
59 | @classmethod
60 | def features_from_csv(cls, csv_path):
61 | yield from cls.features_from_json(csv_open(csv_path, cls.Input._fields + cls.Output._fields))
62 |
63 | @staticmethod
64 | def _to_raw_features(dataframe, feature_class):
65 | return concat([
66 | column.apply(feature_type.to_series)
67 | for (i, column), feature_type in zip(dataframe.iteritems(), feature_class)
68 | ], axis=1)
69 |
70 | def _dataframes_from_features(self, features):
71 | dataframe = DataFrame(list(input_) + list(output) for input_, output in features)
72 |
73 | input_dataframe = self._to_raw_features(dataframe, self.Input)
74 | output_dataframe = self._to_raw_features(dataframe.loc[:, len(self.Input):], self.Output)
75 |
76 | return input_dataframe, output_dataframe
77 |
78 | @classmethod
79 | def train(cls, features, train_test_split_ratio=None, test_sample_count=None, random_state=None):
80 | if train_test_split_ratio is not None or test_sample_count is not None:
81 | train_features, test_features = train_test_split(
82 | features,
83 | train_test_split_ratio=train_test_split_ratio,
84 | test_sample_count=test_sample_count,
85 | random_state=random_state
86 | )
87 |
88 | model = cls.train(train_features)
89 |
90 | return model, model.score(test_features)
91 |
92 | model = cls()
93 |
94 | model.model.fit(*model._dataframes_from_features(features))
95 |
96 | return model
97 |
98 | def score(self, features):
99 | return self.model.score(*self._dataframes_from_features(features))
100 |
101 | @staticmethod
102 | def _chunk_dataframe(dataframe, feature_types):
103 | start = 0
104 | for feature_type in feature_types:
105 | chunk = dataframe.loc[:, start:start + feature_type.feature_count - 1]
106 | chunk.columns = range(len(chunk.columns))
107 |
108 | yield chunk, feature_type
109 | start += feature_type.feature_count
110 |
111 | def predict(self, input_features, yield_inputs=False):
112 | input_features_dataframe = DataFrame(input_features)
113 |
114 | raw_features = self._to_raw_features(input_features_dataframe, self.Input)
115 |
116 | raw_prediction_dataframe = DataFrame(self.model.predict(raw_features))
117 |
118 | # Using frame_apply instead of chunk.apply as chunk.apply doesn't behave as expected for 0 columns
119 | prediction_dataframe = concat([
120 | frame_apply(chunk, feature_type.from_series, axis=1).apply_standard()
121 | for chunk, feature_type in self._chunk_dataframe(raw_prediction_dataframe, self.Output)
122 | ], axis=1)
123 |
124 | if yield_inputs:
125 | for (_, input_series), (_, output_series) in zip(
126 | input_features_dataframe.iterrows(),
127 | prediction_dataframe.iterrows()
128 | ):
129 | yield self.Input(*input_series), self.Output(*output_series)
130 | else:
131 | for _, output_series in prediction_dataframe.iterrows():
132 | yield self.Output(*output_series)
133 |
--------------------------------------------------------------------------------
/smart_fruit/model_selection.py:
--------------------------------------------------------------------------------
1 | from sklearn.model_selection import train_test_split as sk_train_test_split
2 |
3 | __all__ = ["train_test_split"]
4 |
5 |
6 | def train_test_split(features, train_test_split_ratio=None, test_sample_count=None, random_state=None):
7 | if (train_test_split_ratio is None) == (test_sample_count is None):
8 | raise ValueError(
9 | "Must provide exactly one of train_test_split_ratio or test_sample_count "
10 | "to perform train/test split"
11 | )
12 |
13 | if train_test_split_ratio is not None:
14 | train_test_split_ratio = float(train_test_split_ratio)
15 |
16 | if train_test_split_ratio <= 0 or train_test_split_ratio >= 1:
17 | raise ValueError(
18 | "train_test_split_ratio must be strictly between 0 and 1 (given {})".format(train_test_split_ratio)
19 | )
20 |
21 | if test_sample_count is not None:
22 | test_sample_count = round(test_sample_count)
23 |
24 | if test_sample_count <= 0:
25 | raise ValueError(
26 | "test_sample_count must be strictly positive (given {})".format(test_sample_count)
27 | )
28 |
29 | return sk_train_test_split(
30 | list(features),
31 | test_size=train_test_split_ratio or test_sample_count,
32 | random_state=random_state
33 | )
34 |
--------------------------------------------------------------------------------
/smart_fruit/utils.py:
--------------------------------------------------------------------------------
1 | from csv import reader as csv_reader
2 | from itertools import chain
3 |
4 | __all__ = ["csv_open"]
5 |
6 |
7 | def csv_open(file, expected_columns):
8 | """
9 | Yields rows of csv file as dictionaries
10 |
11 | Parameters:
12 | file - Path, or file-like object, of the CSV file to use
13 | expected_columns - Columns of the csv file
14 | If the first row of the CSV file consists of these labels, take the columns in that order
15 | Otherwise, take the columns in the order given by expected_columns
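
    Example (a doctest-style sketch; mirrors tests/test_utils/test_csv_open.py):

    >>> from io import StringIO
    >>> list(csv_open(StringIO("1,2"), expected_columns=('a', 'b')))
    [{'a': '1', 'b': '2'}]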
16 | """
17 |
18 | if isinstance(file, str):
19 | with open(file, encoding='utf-8') as f:
20 | yield from csv_open(f, expected_columns=expected_columns)
21 | return
22 |
23 | expected_columns = tuple(expected_columns)
24 |
25 | csv_iter = csv_reader(file)
26 |
27 | first_row = next(csv_iter)
28 |
29 | if set(first_row) == set(expected_columns):
30 | columns = first_row
31 | else:
32 | columns = expected_columns
33 | csv_iter = chain([first_row], csv_iter)
34 |
35 | for row in csv_iter:
36 | if len(row) < len(columns):
37 | raise IndexError("Too few columns in row {!r}".format(row))
38 |
39 | yield dict(zip(columns, row))
40 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madman-bob/Smart-Fruit/8739874811334226c073489fd14b6ca0d0ff9a7b/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_compound_types.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from smart_fruit import Model
4 | from smart_fruit.feature_types import Number, Label, Vector
5 |
6 |
7 | class TestCompoundTypes(TestCase):
8 | def test_vector(self):
9 | class ExampleModel(Model):
10 | class Input:
11 | a = Number()
12 |
13 | class Output:
14 | b = Vector([
15 | Number(),
16 | Label(['a', 'b']),
17 | Number()
18 | ])
19 |
20 | samples = [
21 | (0, (1, 'a', 2)),
22 | (1, (3, 'b', 4))
23 | ]
24 |
25 | model = ExampleModel.train(ExampleModel.features_from_list(samples))
26 |
27 | predictions = model.predict(ExampleModel.input_features_from_list([[0], [1]]))
28 |
29 | for (sample_input, sample_output), prediction in zip(samples, predictions):
30 | with self.subTest(sample=sample_input):
31 | self.assertAlmostEqual(sample_output[0], prediction.b[0])
32 | self.assertEqual(sample_output[1], prediction.b[1])
33 | self.assertAlmostEqual(sample_output[2], prediction.b[2])
34 |
35 | def test_nested_vectors(self):
36 | class ExampleModel(Model):
37 | class Input:
38 | a = Number()
39 |
40 | class Output:
41 | b = Vector([
42 | Number(),
43 | Vector([
44 | Number(),
45 | Label(['a', 'b'])
46 | ]),
47 | Number()
48 | ])
49 |
50 | samples = [
51 | (0, (1, (2, 'a'), 3)),
52 | (1, (4, (5, 'b'), 6))
53 | ]
54 |
55 | model = ExampleModel.train(ExampleModel.features_from_list(samples))
56 |
57 | predictions = model.predict(ExampleModel.input_features_from_list([[0], [1]]))
58 |
59 | for (sample_input, sample_output), prediction in zip(samples, predictions):
60 | with self.subTest(sample=sample_input):
61 | self.assertAlmostEqual(sample_output[0], prediction.b[0])
62 | self.assertAlmostEqual(sample_output[1][0], prediction.b[1][0])
63 | self.assertEqual(sample_output[1][1], prediction.b[1][1])
64 | self.assertAlmostEqual(sample_output[2], prediction.b[2])
65 |
66 | def test_vector_validation(self):
67 | feature_type = Vector([
68 | Number(),
69 | Label(['a', 'b'])
70 | ])
71 |
72 | for a in ((1, 'a'), (3, 'b'), (17, 'a')):
73 | with self.subTest(a=a):
74 | self.assertEqual(feature_type.validate(a), a)
75 |
76 | for a in ((1,), (1, 'a', 2), ('a', 1), (1, 1)):
77 | with self.subTest(a=a), \
78 | self.assertRaises((TypeError, ValueError)):
79 | feature_type.validate(a)
80 |
--------------------------------------------------------------------------------
/tests/test_feature_serialization.py:
--------------------------------------------------------------------------------
1 | from io import StringIO
2 | from itertools import product
3 | from unittest import TestCase
4 |
5 | from smart_fruit import Model
6 | from smart_fruit.feature_types import Number, Label
7 |
8 |
9 | class TestFeatureSerialization(TestCase):
10 | class ExampleModel(Model):
11 | class Input:
12 | number = Number()
13 | label = Label(['a', 'b', 'c'])
14 |
15 | class Output:
16 | number_a = Number()
17 | number_b = Number()
18 |
19 | example_input = (1, 'a')
20 | example_output = (1, 2)
21 |
22 | example_json_input = {'number': 1, 'label': 'a'}
23 | example_json_output = {'number_a': 1, 'number_b': 2}
24 |
25 | valid_iterable_inputs = (
26 | (1, 'a'),
27 | (1, 'b'),
28 | (2, 'a')
29 | )
30 |
31 | invalid_iterable_inputs = (
32 | ('a', 'a'),
33 | (float('nan'), 'a'),
34 | (float('inf'), 'a'),
35 | (- float('inf'), 'a'),
36 |
37 | (1, 1),
38 | (1, 'd')
39 | )
40 |
41 | def test_iterable_deserialization(self):
42 | with self.subTest(feature_type=self.ExampleModel.Input):
43 | feature = self.ExampleModel.Input(*self.example_input)
44 |
45 | self.assertEqual(feature.number, 1)
46 | self.assertEqual(feature.label, 'a')
47 |
48 | with self.subTest(feature_type=self.ExampleModel.Output):
49 | feature = self.ExampleModel.Output(*self.example_output)
50 |
51 | self.assertEqual(feature.number_a, 1)
52 | self.assertEqual(feature.number_b, 2)
53 |
54 | with self.subTest(func=self.ExampleModel.input_features_from_list):
55 | features = list(self.ExampleModel.input_features_from_list([self.example_input]))
56 |
57 | self.assertEqual(features[0].number, 1)
58 | self.assertEqual(features[0].label, 'a')
59 |
60 | with self.subTest(func=self.ExampleModel.features_from_list):
61 | features = list(self.ExampleModel.features_from_list([self.example_input + self.example_output]))
62 |
63 | self.assertEqual(features[0][0].number, 1)
64 | self.assertEqual(features[0][0].label, 'a')
65 |
66 | self.assertEqual(features[0][1].number_a, 1)
67 | self.assertEqual(features[0][1].number_b, 2)
68 |
69 | def test_json_deserialization(self):
70 | with self.subTest(feature_type=self.ExampleModel.Input):
71 | feature = self.ExampleModel.Input.from_json(self.example_json_input)
72 |
73 | self.assertEqual(feature.number, 1)
74 | self.assertEqual(feature.label, 'a')
75 |
76 | with self.subTest(feature_type=self.ExampleModel.Output):
77 | feature = self.ExampleModel.Output.from_json(self.example_json_output)
78 |
79 | self.assertEqual(feature.number_a, 1)
80 | self.assertEqual(feature.number_b, 2)
81 |
82 | with self.subTest(func=self.ExampleModel.input_features_from_json):
83 | features = list(self.ExampleModel.input_features_from_json([self.example_json_input]))
84 |
85 | self.assertEqual(features[0].number, 1)
86 | self.assertEqual(features[0].label, 'a')
87 |
88 | with self.subTest(func=self.ExampleModel.features_from_json):
89 | features = list(self.ExampleModel.features_from_json([
90 | {**self.example_json_input, **self.example_json_output}
91 | ]))
92 |
93 | self.assertEqual(features[0][0].number, 1)
94 | self.assertEqual(features[0][0].label, 'a')
95 |
96 | self.assertEqual(features[0][1].number_a, 1)
97 | self.assertEqual(features[0][1].number_b, 2)
98 |
99 | def test_csv_deserialization(self):
100 | with self.subTest(func=self.ExampleModel.input_features_from_csv):
101 | features = list(self.ExampleModel.input_features_from_csv(
102 | StringIO(",".join(map(str, self.example_input)))
103 | ))
104 |
105 | self.assertEqual(features[0].number, 1)
106 | self.assertEqual(features[0].label, 'a')
107 |
108 | with self.subTest(func=self.ExampleModel.features_from_csv):
109 | features = list(self.ExampleModel.features_from_csv(
110 | StringIO(",".join(map(str, self.example_input + self.example_output)))
111 | ))
112 |
113 | self.assertEqual(features[0][0].number, 1)
114 | self.assertEqual(features[0][0].label, 'a')
115 |
116 | self.assertEqual(features[0][1].number_a, 1)
117 | self.assertEqual(features[0][1].number_b, 2)
118 |
119 | def test_feature_equality(self):
120 | for a, b in product(self.valid_iterable_inputs, repeat=2):
121 | with self.subTest(a=a, b=b):
122 | self.assertEqual(
123 | self.ExampleModel.Input(*a) == self.ExampleModel.Input(*b),
124 | a == b
125 | )
126 |
127 | def test_iterable_serialization(self):
128 | with self.subTest(feature_type=self.ExampleModel.Input):
129 | feature = self.ExampleModel.Input(*self.example_input)
130 | self.assertEqual(tuple(feature), self.example_input)
131 |
132 | with self.subTest(feature_type=self.ExampleModel.Output):
133 | feature = self.ExampleModel.Output(*self.example_output)
134 | self.assertEqual(tuple(feature), self.example_output)
135 |
136 | def test_json_serialization(self):
137 | with self.subTest(feature_type=self.ExampleModel.Input):
138 | feature = self.ExampleModel.Input(*self.example_input)
139 | self.assertEqual(dict(feature.to_json()), self.example_json_input)
140 |
141 | with self.subTest(feature_type=self.ExampleModel.Output):
142 | feature = self.ExampleModel.Output(*self.example_output)
143 | self.assertEqual(dict(feature.to_json()), self.example_json_output)
144 |
145 | def test_value_coercion(self):
146 | self.assertEqual(
147 | self.ExampleModel.Input('1', 'a').validate(),
148 | self.ExampleModel.Input(1, 'a')
149 | )
150 |
151 | def test_invalid_deserialization(self):
152 | for invalid_input in self.invalid_iterable_inputs:
153 | with self.subTest(invalid_input=invalid_input), \
154 | self.assertRaises((TypeError, ValueError)):
155 | self.ExampleModel.Input(*invalid_input).validate()
156 |
--------------------------------------------------------------------------------
/tests/test_label_types.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from unittest import TestCase
3 |
4 | from pandas import Series
5 |
6 | from smart_fruit import Model
7 | from smart_fruit.feature_types import Label
8 |
9 |
10 | class TestLabelTypes(TestCase):
11 | def _test_labels(self, labels):
12 | class ExampleModel(Model):
13 | class Input:
14 | label = Label(labels)
15 |
16 | with self.subTest("Label to pandas Series"):
17 | for i, label in enumerate(labels):
18 | self.assertTrue(
19 | ExampleModel.Input.label.to_series(label).equals(
20 | self._basis_vector(len(labels), i)
21 | )
22 | )
23 |
24 | with self.subTest("Pandas Series to label"):
25 | for i, label in enumerate(labels):
26 | self.assertEqual(
27 | ExampleModel.Input.label.from_series(self._basis_vector(len(labels), i)),
28 | label
29 | )
30 |
31 | @staticmethod
32 | def _basis_vector(length, index):
33 | vector = Series([0 for _ in range(length)])
34 | vector[index] = 1
35 | return vector
36 |
37 | def test_string_labels(self):
38 | self._test_labels(['a', 'b', 'c'])
39 |
40 | def test_number_labels(self):
41 | self._test_labels([1, 17, 3.1415])
42 |
43 | def test_enum_labels(self):
44 | class Colours(Enum):
45 | red = 0
46 | green = 1
47 | blue = 2
48 |
49 | self._test_labels(Colours)
50 |
--------------------------------------------------------------------------------
/tests/test_model_training.py:
--------------------------------------------------------------------------------
1 | from random import Random
2 | from unittest import TestCase
3 |
4 | from numpy import median
5 |
6 | from sklearn import linear_model
7 |
8 | from examples.trivial_model import TrivialModel
9 |
10 |
11 | class TestModelTraining(TestCase):
12 | @staticmethod
13 | def _train_test_split(func, model_class=TrivialModel, **kwargs):
14 | model, score = model_class.train(
15 | model_class.features_from_list((n, func(n)) for n in range(20)),
16 | **kwargs
17 | )
18 |
19 | return score
20 |
21 | def _median_train_test_split(self, *args, attempts=3, **kwargs):
22 | return median([
23 | self._train_test_split(*args, **kwargs)
24 | for _ in range(attempts)
25 | ])
26 |
27 | def test_train_test_split_perfect(self):
28 | self.assertEqual(
29 | self._train_test_split(lambda n: 10 * n, train_test_split_ratio=0.2),
30 | 1
31 | )
32 |
33 | def test_train_test_split_noisy(self):
34 | random = Random(0).random
35 |
36 | for n in range(1, 10):
37 | with self.subTest(train_test_split_ratio=n / 10):
38 | score = self._median_train_test_split(
39 | lambda m: m + 5 * random(),
40 | train_test_split_ratio=n / 10,
41 | random_state=0,
42 | attempts=5
43 | )
44 |
45 | self.assertGreater(score, 0)
46 | self.assertLess(score, 1)
47 |
48 | def test_train_test_split_pure_noise(self):
49 | random = Random(0).random
50 |
51 | score = self._median_train_test_split(
52 | lambda n: random(),
53 | train_test_split_ratio=0.2,
54 | random_state=0,
55 | attempts=5
56 | )
57 |
58 | self.assertLess(score, 0)
59 |
60 | def test_train_test_split_errors(self):
61 | for train_test_split_ratio in (-1, -0.5, 0, 1, 1.5, 2):
62 | with self.subTest(train_test_split_ratio=train_test_split_ratio), \
63 | self.assertRaises(ValueError):
64 | self._train_test_split(lambda n: 10 * n, train_test_split_ratio=train_test_split_ratio)
65 |
66 | for test_sample_count in (-1, 0):
67 | with self.subTest(test_sample_count=test_sample_count), \
68 | self.assertRaises(ValueError):
69 | self._train_test_split(lambda n: 10 * n, test_sample_count=test_sample_count)
70 |
71 | def test_custom_model_class(self):
72 | class HuberModel(TrivialModel):
73 | model_class = linear_model.HuberRegressor
74 |
75 | almost_straight_line = lambda n: 10 * n if n else 10
76 |
77 | self.assertLess(
78 | self._median_train_test_split(
79 | almost_straight_line,
80 | model_class=TrivialModel,
81 | train_test_split_ratio=0.2
82 | ),
83 | self._median_train_test_split(
84 | almost_straight_line,
85 | model_class=HuberModel,
86 | train_test_split_ratio=0.2,
87 | )
88 | )
89 |
--------------------------------------------------------------------------------
/tests/test_simple_types.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from smart_fruit import Model
4 | from smart_fruit.feature_types import Number, Integer, Complex, Label, Tag
5 |
6 |
7 | class TestSimpleTypes(TestCase):
8 | def test_number(self):
9 | class ExampleModel(Model):
10 | class Input:
11 | a = Number()
12 |
13 | class Output:
14 | b = Number()
15 |
16 | model = ExampleModel.train(ExampleModel.features_from_list([
17 | (0, 0),
18 | (1, 10)
19 | ]))
20 |
21 | predictions = model.predict(ExampleModel.input_features_from_list([[0], [1]]))
22 |
23 | for n, output in zip((0, 1), predictions):
24 | with self.subTest(n=n):
25 | self.assertAlmostEqual(output.b, 10 * n)
26 |
27 | def test_number_validation(self):
28 | feature_type = Number()
29 |
30 | for n in (0, 1, 3.141592, -17):
31 | with self.subTest(n=n):
32 | self.assertEqual(feature_type.validate(n), n)
33 |
34 | for a, b in (("1", 1),):
35 | with self.subTest(a=a):
36 | self.assertEqual(feature_type.validate(a), b)
37 |
38 | for n in (1j, float("nan"), float("inf"), -float("inf"), "a"):
39 | with self.subTest(n=n), \
40 | self.assertRaises((TypeError, ValueError)):
41 | feature_type.validate(n)
42 |
43 | def test_integer(self):
44 | class ExampleModel(Model):
45 | class Input:
46 | a = Number()
47 |
48 | class Output:
49 | b = Integer()
50 |
51 | model = ExampleModel.train(ExampleModel.features_from_list([
52 | (0, 0),
53 | (1, 10)
54 | ]))
55 |
56 | predictions = model.predict(ExampleModel.input_features_from_list([[0], [0.01], [0.99], [1]]))
57 |
58 | for a, n, output in zip((0, 0.01, 0.99, 1), (0, 0, 10, 10), predictions):
59 | with self.subTest(a=a):
60 | self.assertEqual(output.b, n)
61 |
62 | def test_integer_validation(self):
63 | feature_type = Integer()
64 |
65 | for a, b in ((0, 0), (1, 1), (3.141592, 3), (-17, -17)):
66 | with self.subTest(a=a):
67 | self.assertEqual(feature_type.validate(a), b)
68 |
69 | for a in (1j, float("nan"), float("inf"), -float("inf"), "a"):
70 | with self.subTest(a=a), \
71 | self.assertRaises((TypeError, ValueError)):
72 | feature_type.validate(a)
73 |
74 | def test_complex(self):
75 | class ExampleModel(Model):
76 | class Input:
77 | a = Number()
78 |
79 | class Output:
80 | b = Complex()
81 |
82 | model = ExampleModel.train(ExampleModel.features_from_list([
83 | (0, 1),
84 | (1, 1j)
85 | ]))
86 |
87 | predictions = model.predict(ExampleModel.input_features_from_list([[0], [0.5], [1]]))
88 |
89 | for a, b, output in zip((0, 0.5, 1), (1, 0.5 + 0.5j, 1j), predictions):
90 | with self.subTest(a=a):
91 | self.assertAlmostEqual(output.b, b)
92 |
93 | def test_complex_validation(self):
94 | feature_type = Complex()
95 |
96 | for a in (0, 1, 3 + 4j, -1 + 7j):
97 | with self.subTest(a=a):
98 | self.assertEqual(feature_type.validate(a), a)
99 |
100 | for a in (float("nan"), float("inf"), -float("inf"), "a"):
101 | with self.subTest(a=a), \
102 | self.assertRaises((TypeError, ValueError)):
103 | feature_type.validate(a)
104 |
105 | def test_label(self):
106 | class ExampleModel(Model):
107 | class Input:
108 | a = Number()
109 |
110 | class Output:
111 | b = Label(['a', 'b'])
112 |
113 | model = ExampleModel.train(ExampleModel.features_from_list([
114 | (0, 'a'),
115 | (1, 'b')
116 | ]))
117 |
118 | predictions = model.predict(ExampleModel.input_features_from_list([[0], [1]]))
119 |
120 | for n, label, output in zip((0, 1), ('a', 'b'), predictions):
121 | with self.subTest(n=n, label=label):
122 | self.assertEqual(output.b, label)
123 |
124 | def test_label_validation(self):
125 | ob = object()
126 | bad_ob = object()
127 |
128 | feature_type = Label(['a', 'b', 0, ob])
129 |
130 | for a in ('a', 'b', 0, ob):
131 | with self.subTest(a=a):
132 | self.assertEqual(feature_type.validate(a), a)
133 |
134 | for a in ('c', 1, bad_ob):
135 | with self.subTest(a=a), \
136 | self.assertRaises(TypeError):
137 | feature_type.validate(a)
138 |
139 | def test_tag_input(self):
140 | class ExampleModel(Model):
141 | class Input:
142 | a = Number()
143 | b = Tag()
144 |
145 | class Output:
146 | c = Number()
147 |
148 | sample_inputs = [
149 | (0, 0),
150 | (1, 1),
151 | (0, 1),
152 | (1, 0),
153 | (0, "a"),
154 | (0, object())
155 | ]
156 |
157 | model = ExampleModel.train(ExampleModel.features_from_list([
158 | (0, 0, 0),
159 | (1, 1, 1)
160 | ]))
161 |
162 | predictions = model.predict(ExampleModel.input_features_from_list(sample_inputs))
163 |
164 | for sample_input, prediction in zip(sample_inputs, predictions):
165 | with self.subTest(sample=sample_input):
166 | self.assertAlmostEqual(prediction.c, sample_input[0])
167 |
168 | def test_tag_output(self):
169 | class ExampleModel(Model):
170 | class Input:
171 | a = Number()
172 |
173 | class Output:
174 | b = Tag()
175 | c = Number()
176 |
177 | model = ExampleModel.train(ExampleModel.features_from_list([
178 | (0, 0, 0),
179 | (1, 1, 1)
180 | ]))
181 |
182 | with self.assertRaisesRegex(TypeError, "May not predict a Tag"):
183 | next(model.predict([ExampleModel.Input(0)]))
184 |
185 | def test_multiple_types_in_single_model(self):
186 | class ExampleModel(Model):
187 | class Input:
188 | a = Number()
189 | b = Number()
190 |
191 | class Output:
192 | c = Number()
193 | d = Number()
194 | e = Label(['a', 'b'])
195 | f = Complex()
196 | g = Number()
197 |
198 | samples = [
199 | (0, 0, 0, 0, 'a', 0, 0),
200 | (0, 1, 0, 3, 'b', 1j, 0),
201 | (1, 0, 2, 0, 'a', 1, 0),
202 | (1, 1, 2, 3, 'b', 1 + 1j, 0)
203 | ]
204 |
205 | model = ExampleModel.train(ExampleModel.features_from_list(samples))
206 |
207 | predictions = model.predict(ExampleModel.input_features_from_list([
208 | sample[:2] for sample in samples
209 | ]))
210 |
211 | for sample, prediction in zip(samples, predictions):
212 | with self.subTest(sample=sample):
213 | self.assertAlmostEqual(sample[2], prediction.c)
214 | self.assertAlmostEqual(sample[3], prediction.d)
215 | self.assertEqual(sample[4], prediction.e)
216 | self.assertAlmostEqual(sample[5], prediction.f)
217 | self.assertAlmostEqual(sample[6], prediction.g)
218 |
--------------------------------------------------------------------------------
/tests/test_trivial_model.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from examples.trivial_model import TrivialModel
4 |
5 |
6 | class TestTrivialModel(TestCase):
7 | def _test_exact_relationship(self, func):
8 | model = TrivialModel.train(TrivialModel.features_from_list(
9 | (n, func(n))
10 | for n in range(20)
11 | ))
12 |
13 | test_values = (-5, 3, 17, 30, 3.1415)
14 | predictions = model.predict(TrivialModel.Input(n) for n in test_values)
15 |
16 | for n, prediction in zip(test_values, predictions):
17 | with self.subTest(n=n):
18 | self.assertIsInstance(prediction, TrivialModel.Output)
19 | self.assertAlmostEqual(prediction.output, func(n))
20 |
21 | def test_linear_relationship(self):
22 | self._test_exact_relationship(lambda n: 10 * n)
23 |
24 | def test_affine_relationship(self):
25 | self._test_exact_relationship(lambda n: 10 * n + 1)
26 |
27 | def test_quadratic_relationship(self):
28 | model = TrivialModel.train(TrivialModel.features_from_list(
29 | (n, n ** 2)
30 | for n in range(20)
31 | ))
32 |
33 | test_values = (0, 3, 10, 17, 20)
34 | predictions = model.predict(TrivialModel.Input(n) for n in test_values)
35 |
36 | for n, prediction in zip(test_values, predictions):
37 | with self.subTest(n=n):
38 | self.assertIsInstance(prediction, TrivialModel.Output)
39 |
40 | # Check that it's not exact, but still in the ballpark
41 | self.assertNotAlmostEqual(prediction.output, n ** 2, places=0)
42 | self.assertAlmostEqual(prediction.output, n ** 2, places=-3)
43 |
--------------------------------------------------------------------------------
/tests/test_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madman-bob/Smart-Fruit/8739874811334226c073489fd14b6ca0d0ff9a7b/tests/test_utils/__init__.py
--------------------------------------------------------------------------------
/tests/test_utils/example_csv.csv:
--------------------------------------------------------------------------------
1 | a,b,c
2 | 1,2,3
3 | 4,5,6
4 | α,β,γ
5 |
--------------------------------------------------------------------------------
/tests/test_utils/test_csv_open.py:
--------------------------------------------------------------------------------
1 | from io import StringIO
2 | from unittest import TestCase
3 |
4 | from smart_fruit.utils import csv_open
5 |
6 |
7 | class TestCSVOpen(TestCase):
8 | test_csv_path = "tests/test_utils/example_csv.csv"
9 | test_csv_columns = ('a', 'b', 'c')
10 | test_csv_response = [
11 | {'a': '1', 'b': '2', 'c': '3'},
12 | {'a': '4', 'b': '5', 'c': '6'},
13 | {'a': 'α', 'b': 'β', 'c': 'γ'}
14 | ]
15 |
16 | def test_opens_csv_paths(self):
17 | self.assertEqual(
18 | list(csv_open(self.test_csv_path, self.test_csv_columns)),
19 | self.test_csv_response
20 | )
21 |
22 | def test_opens_csv_file_handles(self):
23 | with open(self.test_csv_path, encoding='utf-8') as csv_file:
24 | self.assertEqual(
25 | list(csv_open(csv_file, self.test_csv_columns)),
26 | self.test_csv_response
27 | )
28 |
29 | def test_no_given_columns(self):
30 | self.assertEqual(
31 | list(csv_open(StringIO("1,2,3\n4,5,6\nα,β,γ"), self.test_csv_columns)),
32 | self.test_csv_response
33 | )
34 |
35 | def test_different_column_order(self):
36 | self.assertEqual(
37 | list(csv_open(self.test_csv_path, ('b', 'a', 'c'))),
38 | self.test_csv_response
39 | )
40 |
41 | def test_missing_columns(self):
42 | with self.assertRaises(IndexError):
43 | list(csv_open(StringIO("1,2"), self.test_csv_columns))
44 |
--------------------------------------------------------------------------------