├── .gitignore ├── .travis.yml ├── LICENSE.md ├── README.md ├── bonspy ├── __init__.py ├── bonsai.py ├── features.py ├── graph_builder.py ├── logistic.py ├── tests │ ├── __init__.py │ ├── conftest.py │ ├── data │ │ └── test.csv.gz │ ├── test_bonsai.py │ ├── test_features.py │ └── test_graph_builder.py └── utils.py ├── requirements.txt ├── requirements_test.txt ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | *.env 81 | 82 | # virtualenv 83 | .venv/ 84 | venv/ 85 | ENV/ 86 | src/ 87 | 88 | # Spyder project settings 89 | .spyderproject 90 | 91 | # Rope project settings 92 | .ropeproject 93 | 94 | # Pycharm stuff 95 | .idea 96 | 97 | # Data stuff 98 | *.csv 99 | *.json 100 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | 3 | notifications: 4 | - email: false 5 | 6 | language: python 7 | 8 | os: 9 | - linux 10 | 11 | python: 12 | - "3.3" 13 | - "3.4" 14 | - "3.5" 15 | - "3.6" 16 | - "nightly" 17 | 18 | install: 19 | - pip install -r requirements.txt 20 | - pip install -r requirements_test.txt 21 | - pip install pytest-cov coveralls 22 | 23 | before_script: 24 | - flake8 --show-source . 25 | 26 | script: 27 | - py.test --cov=bonspy . 
28 | 29 | after_success: 30 | - coveralls 31 | - bash <(curl -s https://codecov.io/bash) 32 | 33 | deploy: 34 | provider: pypi 35 | user: waltherg 36 | password: 37 | secure: TzR/rGNwLAfV6r6WrGqblCWXpmik43slBgOnL6MbKd6DILHJKgxHjmqQmSwbC1ce1AQWmEKW4oJtraAbDEaNMQ3YBx9EZ0oOKnaZgevDcMULNy/Ax5BUM3p/oTbdHHAE4pqnXlWubdovueXxH3Vq87aSIIXHH+Y5AM1Y/Bh7Qr4wJcWfbAREPVQ6Tao6FtmBZLUskQ2AR2g6X2H8yGnPiRvfk0qxml+etqhS3ARRTN7jm3mWLzGBukV3Louw/9Cuch6y0YlsDKhTbuXJyk8r4KVH5STVOXrYoAvbXjHAdui+o0pcV4s2aM6EXT7upCRDJZwsWZMpW0apX8L4DGAuE84NlQi1kL6vvYF14b/pg6FCmwaY9HY2/iAg+oCjz9wZlL0BLQjQKzysEg1NQTW4pzXRup8mzOVNEtO6C4mFzdYd71EEWV0pRm6r1IOnHYws7Bt6QumbXiH+vGqnNVnlD3cXmS4yn3kMZmj8o3l7GQ52LD6+J4whtJYRHzqTaUDucw6/7LSz4jZ8RXacsCDBQgdLSy7MDTC/n8kZfa8edyblNt1MRHxd7H0f2eQGHAOvpCk1xYsEN3JPWos3V0eDeLpDIJQJyuyYqKbC0bO5Iz0T0FxehkylTRbE/OlJjOlZxTm2cXbh/Rt3ZJ/ZpnhjpC8ET+TCXF4GrP9VrBrHsmQ= 38 | on: 39 | tags: true 40 | branch: master 41 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | License 2 | ------- 3 | 4 | Copyright (c) 2015 Rocket Internet SE. 5 | All rights reserved. 6 | 7 | Authors: 8 | - Alexander Volkmann (Rocket Internet SE) 9 | - Georg Walther (Rocket Internet SE) 10 | 11 | Redistribution and use in source and binary forms, with or without 12 | modification, are permitted provided that the following conditions are met: 13 | 14 | 1. Redistributions of source code must retain the above copyright notice, this 15 | list of conditions and the following disclaimer. 16 | 2. Redistributions in binary form must reproduce the above copyright notice, 17 | this list of conditions and the following disclaimer in the documentation 18 | and/or other materials provided with the distribution. 19 | 3. Neither the name of the copyright holder nor the names of its contributors 20 | may be used to endorse or promote products derived from this software 21 | without specific prior written permission. 22 | 23 | 24 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 25 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 26 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 28 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 29 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 30 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 31 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 32 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 33 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bonspy 2 | 3 | [![PyPI version](https://badge.fury.io/py/bonspy.svg)](https://badge.fury.io/py/bonspy) 4 | [![Build Status](https://travis-ci.org/markovianhq/bonspy.svg)](https://travis-ci.org/markovianhq/bonspy) 5 | [![codecov](https://codecov.io/gh/markovianhq/bonspy/branch/master/graph/badge.svg)](https://codecov.io/gh/markovianhq/bonspy) 6 | [![Join the chat at https://gitter.im/markovianhq/Lobby](https://badges.gitter.im/markovianhq/Lobby.svg)](https://gitter.im/markovianhq/Lobby) 7 | 8 | 9 | Bonspy converts bidding trees from various input formats to the 10 | [Bonsai bidding language of AppNexus](http://developers.appnexus.com/introduction-to-the-bonsai-decision-tree-language/). 11 | 12 | As an intermediate format, bonspy constructs a [NetworkX](https://networkx.github.io/) graph from which it produces the 13 | Bonsai language output. 14 | Bidding trees may also be constructed directly in this NetworkX format (see first example below). 15 | 16 | At present, bonspy provides a converter from trained [sklearn](http://scikit-learn.org/stable/) logistic regression 17 | classifiers with categorical, one-hot encoded features to the intermediate NetworkX format (see second example below). 18 | 19 | In combination with our AppNexus API wrapper [`nexusadspy`](https://github.com/markovianhq/nexusadspy), it is also 20 | straightforward to check your bidding tree for syntax errors and upload it for real-time bidding (third example below). 21 | 22 | This package was developed and tested on Python 3.5. 23 | However, the examples below have also been tested successfully on Python 2.7. 24 | 25 | Versions 0.9.2 and higher no longer support Python 2.7. 26 | 27 | ## Installation 28 | 29 | ### Installation as a regular library 30 | 31 | Install the latest release from PyPI: 32 | 33 | $ pip install bonspy 34 | 35 | To install the latest `master` branch commit of bonspy: 36 | 37 | $ pip install -e git+git@github.com:markovianhq/bonspy.git@master#egg=bonspy 38 | 39 | To install a specific commit, e.g. `97c41e9`: 40 | 41 | $ pip install -e git+git@github.com:markovianhq/bonspy.git@97c41e9#egg=bonspy 42 | 43 | ### Installation for development 44 | 45 | To install bonspy for local development, you may want to create a virtual environment.
46 | Assuming you use [Continuum Anaconda](https://www.continuum.io/downloads), create 47 | a new virtual environment as follows: 48 | 49 | $ conda create --name bonspy python=3 -y 50 | 51 | Activate the environment: 52 | 53 | $ source activate bonspy 54 | 55 | Install the requirements: 56 | 57 | $ pip install -r requirements.txt 58 | 59 | Now install bonspy in development mode: 60 | 61 | $ python setup.py develop 62 | 63 | To run the tests, install these additional packages: 64 | 65 | $ pip install -r requirements_test.txt 66 | 67 | Now run the tests: 68 | 69 | $ py.test bonspy --flake8 70 | 71 | ## Example: NetworkX tree to Bonsai output 72 | 73 | import networkx as nx 74 | 75 | from bonspy import BonsaiTree 76 | 77 | 78 | g = nx.DiGraph() 79 | 80 | g.add_node(0, split='segment', state={}) 81 | g.add_node(1, split='age', state={'segment': 12345}) 82 | g.add_node(2, split='age', state={'segment': 67890}) 83 | g.add_node(3, split='country', state={'segment': 12345, 'age': (None, 10.)}) 84 | g.add_node(4, split='country', state={'segment': 12345, 'age': (10., None)}) 85 | g.add_node(5, split='country', state={'segment': 67890, 'age': (None, 10.)}) 86 | g.add_node(6, split='country', state={'segment': 67890, 'age': (10., None)}) 87 | g.add_node(7, is_leaf=True, output=0.10, state={'segment': 12345, 'age': (None, 10.), 'country': ('GB', 'DE')}) 88 | g.add_node(8, is_leaf=True, output=0.20, state={'segment': 12345, 'age': (None, 10.), 'country': ('US', 'BR')}) 89 | g.add_node(9, is_leaf=True, output=0.10, state={'segment': 12345, 'age': (10., None), 'country': ('GB', 'DE')}) 90 | g.add_node(10, is_leaf=True, output=0.20, state={'segment': 12345, 'age': (10., None), 'country': ('US', 'BR')}) 91 | g.add_node(11, is_leaf=True, output=0.10, state={'segment': 67890, 'age': (None, 10.), 'country': ('GB', 'DE')}) 92 | g.add_node(12, is_leaf=True, output=0.20, state={'segment': 67890, 'age': (None, 10.), 'country': ('US', 'BR')}) 93 | g.add_node(13, is_leaf=True, output=0.10, state={'segment': 67890, 'age': (10., None), 'country': ('GB', 'DE')}) 94 | g.add_node(14, is_leaf=True, output=0.20, state={'segment': 67890, 'age': (10., None), 'country': ('US', 'BR')}) 95 | g.add_node(15, is_default_leaf=True, output=0.05, state={}) 96 | g.add_node(16, is_default_leaf=True, output=0.05, state={'segment': 12345}) 97 | g.add_node(17, is_default_leaf=True, output=0.05, state={'segment': 67890}) 98 | g.add_node(18, is_default_leaf=True, output=0.05, state={'segment': 12345, 'age': (None, 10.)}) 99 | g.add_node(19, is_default_leaf=True, output=0.05, state={'segment': 12345, 'age': (10., None)}) 100 | g.add_node(20, is_default_leaf=True, output=0.05, state={'segment': 67890, 'age': (None, 10.)}) 101 | g.add_node(21, is_default_leaf=True, output=0.05, state={'segment': 67890, 'age': (10., None)}) 102 | 103 | g.add_edge(0, 1, value=12345, type='assignment') 104 | g.add_edge(0, 2, value=67890, type='assignment') 105 | g.add_edge(1, 3, value=(None, 10.), type='range') 106 | g.add_edge(1, 4, value=(10., None), type='range') 107 | g.add_edge(2, 5, value=(None, 10.), type='range') 108 | g.add_edge(2, 6, value=(10., None), type='range') 109 | g.add_edge(3, 7, value=('GB', 'DE'), type='membership') 110 | g.add_edge(3, 8, value=('US', 'BR'), type='membership') 111 | g.add_edge(4, 9, value=('GB', 'DE'), type='membership') 112 | g.add_edge(4, 10, value=('US', 'BR'), type='membership') 113 | g.add_edge(5, 11, value=('GB', 'DE'), type='membership') 114 | g.add_edge(5, 12, value=('US', 'BR'), type='membership') 115 | 
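    # The two remaining 'country' edges follow; the untyped edges after them
    # (0 -> 15 through 6 -> 21) attach the 'is_default_leaf' nodes, which
    # BonsaiTree renders as the else/default branches of the output shown below.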
g.add_edge(6, 13, value=('GB', 'DE'), type='membership') 116 | g.add_edge(6, 14, value=('US', 'BR'), type='membership') 117 | g.add_edge(0, 15) 118 | g.add_edge(1, 16) 119 | g.add_edge(2, 17) 120 | g.add_edge(3, 18) 121 | g.add_edge(4, 19) 122 | g.add_edge(5, 20) 123 | g.add_edge(6, 21) 124 | 125 | tree = BonsaiTree(g) 126 | 127 | This `tree` looks as follows (note the image below is old: `geo` has been replaced with `country`, 128 | and `UK` with `GB`): 129 | 130 | ![tree_example](https://cloud.githubusercontent.com/assets/3273502/10993831/4cf94712-8472-11e5-8256-4f736814d7eb.png) 131 | 132 | Note that non-leaf nodes track the next user variable to be split on in their `split` attribute while 133 | the current choice of user features is tracked in their `state` attribute. 134 | Leaves designate their output (the bid) in their `output` attribute. 135 | 136 | The Bonsai text representation of the above `tree` is stored in its `.bonsai` attribute: 137 | 138 | print(tree.bonsai) 139 | 140 | prints out 141 | 142 | if segment[12345]: 143 | switch segment[12345].age: 144 | case (.. 10): 145 | if country in ("GB","DE"): 146 | 0.1000 147 | elif country in ("US","BR"): 148 | 0.2000 149 | else: 150 | 0.0500 151 | case (11 ..): 152 | if country in ("GB","DE"): 153 | 0.1000 154 | elif country in ("US","BR"): 155 | 0.2000 156 | else: 157 | 0.0500 158 | default: 159 | 0.0500 160 | elif segment[67890]: 161 | switch segment[67890].age: 162 | case (.. 10): 163 | if country in ("GB","DE"): 164 | 0.1000 165 | elif country in ("US","BR"): 166 | 0.2000 167 | else: 168 | 0.0500 169 | case (11 ..): 170 | if country in ("GB","DE"): 171 | 0.1000 172 | elif country in ("US","BR"): 173 | 0.2000 174 | else: 175 | 0.0500 176 | default: 177 | 0.0500 178 | else: 179 | 0.0500 180 | 181 | ## Example: Sklearn logistic regression classifier to Bonsai output 182 | 183 | **This example is old and has not been tested lately!** 184 | 185 | from bonspy import LogisticConverter 186 | from bonspy import BonsaiTree 187 | 188 | features = ['segment', 'age', 'geo'] 189 | 190 | vocabulary = { 191 | 'segment=12345': 0, 192 | 'segment=67890': 1, 193 | 'age=0': 2, 194 | 'age=1': 3, 195 | 'geo=UK': 4, 196 | 'geo=DE': 5, 197 | 'geo=US': 6, 198 | 'geo=BR': 7 199 | } 200 | 201 | weights = [.1, .2, .15, .25, .1, .1, .2, .2] 202 | intercept = .4 203 | 204 | buckets = { 205 | 'age': { 206 | '0': (None, 10), 207 | '1': (10, None) 208 | } 209 | } 210 | 211 | types = { 212 | 'segment': 'assignment', 213 | 'age': 'range', 214 | 'geo': 'assignment' 215 | } 216 | 217 | conv = LogisticConverter(features=features, vocabulary=vocabulary, 218 | weights=weights, intercept=intercept, 219 | types=types, base_bid=2., buckets=buckets) 220 | 221 | tree = BonsaiTree(conv.graph) 222 | 223 | print(tree.bonsai) 224 | 225 | Prints out 226 | 227 | if segment 67890: 228 | if segment 67890 age > 10: 229 | if geo="US": 230 | 1.4815 231 | elif geo="UK": 232 | 1.4422 233 | elif geo="BR": 234 | 1.4815 235 | elif geo="DE": 236 | 1.4422 237 | else: 238 | 1.4011 239 | elif segment 67890 age <= 10: 240 | if geo="US": 241 | 1.4422 242 | elif geo="UK": 243 | 1.4011 244 | elif geo="BR": 245 | 1.4422 246 | elif geo="DE": 247 | 1.4011 248 | else: 249 | 1.3584 250 | else: 251 | 1.2913 252 | elif segment 12345: 253 | if segment 12345 age > 10: 254 | if geo="US": 255 | 1.4422 256 | elif geo="DE": 257 | 1.4011 258 | elif geo="UK": 259 | 1.4011 260 | elif geo="BR": 261 | 1.4422 262 | else: 263 | 1.3584 264 | elif segment 12345 age <= 10: 265 | if geo="US": 266 | 1.4011 267 | elif 
geo="DE": 268 | 1.3584 269 | elif geo="UK": 270 | 1.3584 271 | elif geo="BR": 272 | 1.4011 273 | else: 274 | 1.3140 275 | else: 276 | 1.2449 277 | else: 278 | 1.1974 279 | 280 | ## Example: Uploading the Bonsai output to AppNexus 281 | 282 | Use our [`nexusadspy` library](https://github.com/markovianhq/nexusadspy) to 283 | send the encoded `tree` to the AppNexus parser and check 284 | for any syntax errors: 285 | 286 | from nexusadspy import AppnexusClient 287 | 288 | check_tree = { 289 | "custom-model-parser": { 290 | "model_text": tree.bonsai_encoded 291 | } 292 | } 293 | 294 | with AppnexusClient('.appnexus_auth.json') as client: 295 | r = client.request('custom-model-parser', 'POST', data=check_tree) 296 | 297 | If the AppNexus API does not return any errors for our `tree`, we can now 298 | upload it as follows: 299 | 300 | custom_model = { 301 | "custom_model": { 302 | "name": "Insert tree name (visible in the AppNexus advertiser UI)", 303 | "member_id": # add your integer member ID, 304 | "advertiser_id": # add your integer advertiser ID, 305 | "custom_model_structure": "decision_tree", 306 | "model_output": "bid", 307 | "model_text": tree.bonsai_encoded 308 | } 309 | } 310 | 311 | r = client.request('custom-model', 'POST', data=custom_model) 312 | 313 | Check the response `r` for the integer identifier assigned to your bidding tree by AppNexus. 314 | You will use this identifier to set the uploaded tree as the bidder for your advertising 315 | campaigns in the AppNexus advertiser UI. 316 | 317 | For more details see https://wiki.appnexus.com/display/console/AppNexus+Programmable+Bidder. -------------------------------------------------------------------------------- /bonspy/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import ( 4 | print_function, division, generators, 5 | absolute_import, unicode_literals 6 | ) 7 | 8 | from bonspy.bonsai import BonsaiTree 9 | from bonspy.logistic import LogisticConverter 10 | -------------------------------------------------------------------------------- /bonspy/bonsai.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import ( 4 | print_function, division, generators, 5 | absolute_import, unicode_literals 6 | ) 7 | 8 | import base64 9 | 10 | from collections import deque, OrderedDict 11 | from functools import cmp_to_key 12 | 13 | import networkx as nx 14 | 15 | from bonspy.features import compound_features, get_validated, objects 16 | from bonspy.utils import compare_vectors, is_absent_value 17 | 18 | try: 19 | basestring 20 | except NameError: 21 | basestring = str 22 | 23 | RANGE_EPSILON = 1 24 | 25 | 26 | class BonsaiTree(nx.DiGraph): 27 | """ 28 | A NetworkX DiGraph (directed graph) subclass that knows how to print 29 | itself out in the AppNexus Bonsai bidding tree language. 30 | 31 | See the readme for the expected graph structure: 32 | 33 | https://github.com/markovianhq/bonspy 34 | 35 | The Bonsai text representation of this tree is stored in its `bonsai` attribute. 36 | 37 | :param graph: (optional) NetworkX graph to be exported to Bonsai. 38 | :param feature_order: (optional) iterable, required when a parent node is split on more than one feature. 39 | Splitting the parent node on more than one feature is indicated through its `split` attribute 40 | set to an OrderedDict object [(child id, feature the parent node is split on)].
41 | The list `feature_order` then provides the order these different features appear in the 42 | Bonsai language output. 43 | :param feature_value_order: (optional), Similar to `feature_order` but a dictionary of lists 44 | of the form {feature: [feature value 1, feature value 2, ...]}. 45 | :param absence_values: (optional), Dictionary feature name -> iterable of values whose communal absence 46 | signals absence of the respective feature. 47 | :param slice_features: (optional) iterable, features to be used for slicing. The private _slice_graph method slices 48 | out the part of the graph where the "slice_features" have a value that is equal to the value of the 49 | "slice_feature_values" dict. 50 | Moreover, it splices out the levels where the "splice_features" are split. 51 | The "slice" method assumes that a node never splits on the "slice_features" together with another feature. 52 | :param slice_feature_values: (optional) dict, slice_feature -> feature values to not be sliced off the graph. 53 | """ 54 | 55 | def __init__(self, graph=None, feature_order=(), feature_value_order={}, absence_values=None, 56 | slice_features=None, slice_feature_values=(), **kwargs): 57 | if graph is not None: 58 | super(BonsaiTree, self).__init__(graph) 59 | self.feature_order = self._convert_to_dict(feature_order) 60 | self.feature_value_order = self._get_feature_value_order(feature_value_order) 61 | self.absence_values = absence_values or {} 62 | self.slice_features = slice_features or () 63 | self.slice_feature_values = slice_feature_values or {} 64 | for key, value in kwargs.items(): 65 | setattr(self, key, value) 66 | self._transform_splits() 67 | self._slice_graph() 68 | self._replace_absent_values() 69 | self._remove_missing_compound_features() 70 | self._validate_feature_values() 71 | self._assign_indent() 72 | self._assign_condition() 73 | self._handle_switch_statements() 74 | self.bonsai = ''.join(self._tree_to_bonsai()) 75 | else: 76 | super(BonsaiTree, self).__init__(**kwargs) 77 | 78 | @staticmethod 79 | def _convert_to_dict(feature_order): 80 | for index, f in enumerate(feature_order): 81 | if isinstance(f, list): 82 | feature_order[index] = tuple(f) 83 | feature_order = {f: index for index, f in enumerate(feature_order)} 84 | return feature_order 85 | 86 | def _get_feature_value_order(self, feature_value_order): 87 | return {feature: self._convert_to_dict(list_) for feature, list_ in feature_value_order.items()} 88 | 89 | @property 90 | def bonsai_encoded(self): 91 | return base64.b64encode(self.bonsai.encode('ascii')).decode() 92 | 93 | def _transform_splits(self): 94 | root_id = self._get_root() 95 | 96 | for node_id in self.bfs_nodes(root_id): 97 | try: 98 | split = self.node[node_id]['split'] 99 | except KeyError: 100 | continue 101 | 102 | if not isinstance(split, dict): 103 | self.node[node_id]['split'] = OrderedDict() 104 | 105 | for child_id in self.successors_iter(node_id): 106 | if not self.node[child_id].get('is_default_leaf', self.node[child_id].get('is_default_node')): 107 | self.node[node_id]['split'][child_id] = split 108 | 109 | def _slice_graph(self): 110 | for slice_feature in self.slice_features: 111 | self._slice_feature_out_of_graph(slice_feature) 112 | 113 | def _slice_feature_out_of_graph(self, slice_feature): 114 | root_id = self._get_root() 115 | 116 | queue = deque([root_id]) 117 | while queue: 118 | node_id = queue.popleft() 119 | if self.node[node_id].get('is_default_leaf'): 120 | continue 121 | split_contains_slice_feature = 
self._split_contains_slice_feature(node_id, slice_feature) 122 | 123 | if not split_contains_slice_feature: 124 | next_nodes = self.successors(node_id) 125 | queue.extend(next_nodes) 126 | else: 127 | queue = self._update_sub_graph(node_id, slice_feature, queue) 128 | 129 | def _split_contains_slice_feature(self, node_id, slice_feature): 130 | try: 131 | split = self.node[node_id]['split'] 132 | return slice_feature in split.values() 133 | except KeyError: # default leaf or leaf 134 | return False 135 | 136 | def _update_sub_graph(self, node_id, slice_feature, queue): 137 | self._prune_unwanted_children(node_id, slice_feature) 138 | 139 | default_child = next((n for n in self.successors_iter(node_id) if self.node[n].get('is_default_leaf'))) 140 | 141 | try: 142 | normal_child = self._get_normal_child(node_id, slice_feature) 143 | other_children = [n for n in self.successors_iter(node_id) if n not in {normal_child, default_child}] 144 | queue.extend(other_children) 145 | 146 | if self.node[normal_child].get('is_leaf'): 147 | self._remove_leaves_and_update_parent_default( 148 | node_id, slice_feature, normal_child, default_child, other_children 149 | ) 150 | else: 151 | self._splice_out_node(normal_child, slice_feature, slicing=True) 152 | 153 | except StopIteration: # slice feature value not present in subtree 154 | other_children = [n for n in self.successors_iter(node_id) if n != default_child] 155 | if other_children: 156 | queue.extend(other_children) 157 | else: 158 | self._cut_single_default_child(node_id, default_child) 159 | 160 | return queue 161 | 162 | def _prune_unwanted_children(self, node_id, slice_feature): 163 | prunable_children = [ 164 | n for n in self.successors_iter(node_id) if not self.node[n].get('is_default_leaf') and 165 | slice_feature in self.node[n]['state'] and 166 | self.node[n]['state'].get(slice_feature) != self.slice_feature_values[slice_feature] 167 | ] 168 | for prunable_child in prunable_children: 169 | if self.node[node_id].get('split'): 170 | del self.node[node_id]['split'][prunable_child] 171 | self._remove_sub_graph(prunable_child) 172 | 173 | def _remove_sub_graph(self, node): 174 | queue = deque([node]) 175 | while queue: 176 | current_node = queue.popleft() 177 | next_nodes = self.successors(current_node) 178 | self.remove_node(current_node) 179 | queue.extend(next_nodes) 180 | 181 | def _get_normal_child(self, node_id, slice_feature): 182 | return next(( 183 | n for n in self.successors_iter(node_id) if not self.node[n].get('is_default_leaf') and 184 | slice_feature in self.node[n]['state'] 185 | )) 186 | 187 | def _remove_leaves_and_update_parent_default(self, node_id, slice_feature, normal_child, 188 | default_child, other_children): 189 | if not other_children: 190 | del self.node[node_id]['split'] 191 | self._remove_feature_from_state(node_id, slice_feature) 192 | self.node[node_id] = self.node[normal_child].copy() 193 | 194 | self.remove_edge(node_id, default_child) 195 | self.remove_node(default_child) 196 | else: 197 | del self.node[node_id]['split'][normal_child] 198 | self._remove_feature_from_state(node_id, slice_feature) 199 | self.node[default_child] = self.node[normal_child].copy() 200 | del self.node[default_child]['is_leaf'] 201 | self.node[default_child]['is_default_leaf'] = True 202 | 203 | self.remove_edge(node_id, normal_child) 204 | self.remove_node(normal_child) 205 | 206 | def _remove_feature_from_state(self, source, feature): 207 | for node_id in self.bfs_nodes(source): 208 | try: 209 | del 
self.node[node_id]['state'][feature] 210 | except KeyError: # node_id is default leaf 211 | pass 212 | 213 | def _splice_out_node(self, source, feature, slicing=False): 214 | self._remove_feature_from_state(source, feature) 215 | self._skip_node(source, slicing) 216 | 217 | def _skip_node(self, node_id, slicing): 218 | parent_id = next(iter(self.predecessors_iter(node_id))) 219 | 220 | if slicing: 221 | self._skip_node_slicing(node_id, parent_id) 222 | else: 223 | self._skip_node_non_slicing(node_id, parent_id) 224 | 225 | self.remove_edge(parent_id, node_id) 226 | self.remove_node(node_id) 227 | 228 | def _cut_single_default_child(self, parent_id, default_child): 229 | if not self.node[parent_id].get('is_default_node'): 230 | self.node[parent_id] = self.node[default_child] 231 | del self.node[parent_id]['is_default_leaf'] 232 | self.node[parent_id]['is_leaf'] = True 233 | else: 234 | self.node[parent_id] = self.node[default_child] 235 | self.remove_node(default_child) 236 | 237 | def _replace_absent_values(self): 238 | root_id = self._get_root() 239 | 240 | for parent_id, child_id in nx.bfs_edges(self, root_id): 241 | try: 242 | feature = next(reversed(self.node[child_id]['state'])) 243 | except StopIteration: 244 | continue # node_id is root_id 245 | 246 | value = self.node[child_id]['state'][feature] 247 | 248 | if self.absence_values.get(feature) and value is None: 249 | self._replace_absent_value_split(parent_id, child_id, feature) 250 | self._replace_absent_value_edge(parent_id, child_id, feature) 251 | self._replace_absent_value_state(child_id, feature) 252 | 253 | def _replace_absent_value_split(self, parent_id, child_id, feature): 254 | values = self.absence_values[feature] 255 | self.node[parent_id]['split'][child_id] = tuple(feature for value in values) 256 | 257 | def _replace_absent_value_edge(self, parent_id, child_id, feature): 258 | values = self.absence_values[feature] 259 | 260 | self.edge[parent_id][child_id]['value'] = values 261 | self.edge[parent_id][child_id]['type'] = ['assignment' for value in values] 262 | self.edge[parent_id][child_id]['is_negated'] = [True for value in values] 263 | 264 | def _replace_absent_value_state(self, source, feature): 265 | absent_values = self.absence_values[feature] 266 | 267 | for node_id in self.bfs_nodes(source): 268 | state = self.node[node_id]['state'] 269 | absent_feature = tuple(feature for value in absent_values) 270 | 271 | state = OrderedDict( 272 | [(k, v) if k != feature else (absent_feature, absent_values) for k, v in state.items()] 273 | ) 274 | 275 | self.node[node_id]['state'] = state 276 | 277 | def _remove_missing_compound_features(self): 278 | root_id = self._get_root() 279 | 280 | for node_id in self.bfs_nodes(root_id): 281 | try: 282 | feature = next(reversed(self.node[node_id]['state'])) 283 | except StopIteration: 284 | continue # node_id is root_id 285 | 286 | value = self.node[node_id]['state'][feature] 287 | 288 | is_compound_attribute = self._is_compound_attribute(feature) 289 | 290 | if is_compound_attribute and value is None: 291 | if self.node[node_id].get('is_leaf'): 292 | self.remove_node(node_id) 293 | else: 294 | self._splice_out_node(node_id, feature) 295 | 296 | self._remove_disconnected_nodes() 297 | self._prune_redundant_default_leaves() 298 | 299 | def bfs_nodes(self, source): 300 | queue = deque([source]) 301 | 302 | while queue: 303 | node_id = queue.popleft() 304 | child_ids = self.successors_iter(node_id) 305 | queue.extend(child_ids) 306 | 307 | yield node_id 308 | 309 | def 
_is_compound_attribute(self, feature): 310 | if '.' in feature: 311 | return True 312 | else: 313 | return False 314 | 315 | def _skip_node_non_slicing(self, node_id, parent_id): 316 | for _, child_id, edge_data in self.edges(nbunch=(node_id,), data=True): 317 | if self.node[child_id].get('is_default_leaf'): 318 | continue 319 | else: 320 | self.add_edge(parent_id, child_id, attr_dict=edge_data) 321 | self.remove_edge(node_id, child_id) 322 | del self.node[parent_id]['split'][node_id] 323 | self._update_split(parent_id, node_id) 324 | 325 | def _skip_node_slicing(self, node_id, parent_id): 326 | for _, child_id, edge_data in self.out_edges(nbunch=(node_id,), data=True): 327 | if self.node[child_id].get('is_default_leaf'): 328 | self._update_parent_default_leaf(parent_id, child_id) 329 | del self.node[child_id] 330 | else: 331 | self.add_edge(parent_id, child_id, attr_dict=edge_data) 332 | self._update_split(parent_id, node_id, child_id=child_id) 333 | self.remove_edge(node_id, child_id) 334 | del self.node[parent_id]['split'][node_id] 335 | 336 | def _update_parent_default_leaf(self, parent_id, new_default): 337 | current_parent_default = next(iter( 338 | [n for n in self.successors(parent_id) if self.node[n].get('is_default_leaf')] 339 | )) 340 | self.node[current_parent_default] = self.node[new_default].copy() 341 | 342 | def _update_split(self, parent_id, node_id, child_id=None): 343 | node_split = self.node[node_id]['split'] 344 | if child_id: 345 | self.node[parent_id]['split'][child_id] = node_split[child_id] 346 | else: 347 | self.node[parent_id]['split'].update(node_split) 348 | 349 | def _remove_disconnected_nodes(self): 350 | node_ids = self._get_disconnected_nodes() 351 | 352 | while node_ids: 353 | self.remove_nodes_from(node_ids) 354 | 355 | node_ids = self._get_disconnected_nodes() 356 | 357 | def _get_disconnected_nodes(self): 358 | root = self._get_root() 359 | node_ids = [n for n in self.nodes_iter() if not self.successors(n) and not self.predecessors(n) and n != root] 360 | return node_ids 361 | 362 | def _prune_redundant_default_leaves(self): 363 | only_child_default_leaves = self._get_only_child_default_leaves() 364 | queue = deque(only_child_default_leaves) 365 | 366 | while queue: 367 | node_id = queue.popleft() 368 | parent_id = next(iter(self.predecessors_iter(node_id))) 369 | 370 | if not self.node[parent_id].get('is_default_node'): 371 | self.node[parent_id] = self.node[node_id] 372 | del self.node[parent_id]['is_default_leaf'] 373 | self.node[parent_id]['is_leaf'] = True 374 | else: 375 | self.node[parent_id] = self.node[node_id] 376 | queue.extend(parent_id) 377 | 378 | self.remove_node(node_id) 379 | 380 | def _get_only_child_default_leaves(self): 381 | default_edges = ((p, c) for (p, c) in self.edges_iter() if self.node[c].get('is_default_leaf')) 382 | only_child_default_leaves = (c for (p, c) in default_edges if self._has_only_one_child(p)) 383 | return only_child_default_leaves 384 | 385 | def _has_only_one_child(self, parent_id): 386 | return len(self.successors(parent_id)) == 1 387 | 388 | def _validate_feature_values(self): 389 | self._validate_node_states() 390 | self._validate_edge_values() 391 | 392 | def _validate_node_states(self): 393 | for node, data in self.nodes_iter(data=True): 394 | for feature, value in data.get('state', {}).items(): 395 | self.node[node]['state'][feature] = get_validated(feature, value) 396 | 397 | def _validate_edge_values(self): 398 | for parent, child, data in self.edges_iter(data=True): 399 | feature = 
self.node[parent]['split'] 400 | if isinstance(feature, dict): 401 | feature = feature.get(child) 402 | try: 403 | value = data['value'] 404 | self.edge[parent][child]['value'] = get_validated(feature, value) 405 | except KeyError: 406 | pass # edge has no value attribute, nothing to validate 407 | 408 | def _get_root(self): 409 | for node in self.nodes(): 410 | if len(self.predecessors(node)) == 0: 411 | return node 412 | 413 | def _assign_indent(self): 414 | root = self._get_root() 415 | queue = deque([root]) 416 | 417 | self.node[root]['indent'] = '' 418 | 419 | while queue: 420 | node = queue.popleft() 421 | indent = self.node[node]['indent'] 422 | 423 | next_nodes = self.successors(node) 424 | for node in next_nodes: 425 | self.node[node]['indent'] = indent + '\t' 426 | 427 | next_nodes = sorted(next_nodes, key=self._sort_key) 428 | 429 | queue.extend(next_nodes) 430 | 431 | @property 432 | def _sort_key(self): 433 | comparison_function = self._get_comparison_function() 434 | return cmp_to_key(comparison_function) 435 | 436 | def _get_comparison_function(self): 437 | _get_default_extended_vector = self._get_default_extended_vector 438 | 439 | def compare_nodes(x, y): 440 | x_ = _get_default_extended_vector(x) 441 | y_ = _get_default_extended_vector(y) 442 | 443 | return compare_vectors(x_, y_) 444 | 445 | return compare_nodes 446 | 447 | def _get_default_extended_vector(self, x): 448 | vec = [self.node[x].get('is_default_leaf', False), self.node[x].get('is_default_node', False)] 449 | vec += self._get_sorted_values(x) 450 | 451 | return vec 452 | 453 | def _get_sorted_values(self, x): 454 | values = [] 455 | 456 | for feature, value in self.node[x]['state'].items(): 457 | feature_key = self._get_feature_order_key(feature) 458 | value_key = self._get_value_order_key(feature, value) 459 | values.append(feature_key) 460 | values.append(value_key) 461 | 462 | return values 463 | 464 | def _get_feature_order_key(self, feature): 465 | feature_order = self.feature_order 466 | feature_order_key = self._get_order_key(dict_=feature_order, key=feature) 467 | return feature_order_key 468 | 469 | def _get_value_order_key(self, feature, value): 470 | value_order = self.feature_value_order.get(feature, {}) 471 | value_order_key = self._get_order_key(dict_=value_order, key=value) 472 | return value_order_key 473 | 474 | @staticmethod 475 | def _get_order_key(dict_, key): 476 | order_key = 0 477 | if not dict_ == {}: 478 | try: 479 | order_key = dict_[key] 480 | except KeyError: 481 | order_key = max(dict_.values()) + 1 482 | 483 | return order_key 484 | 485 | def _assign_condition(self): 486 | root = self._get_root() 487 | queue = deque([root]) 488 | 489 | while queue: 490 | node = queue.popleft() 491 | 492 | next_nodes = self.successors(node) 493 | next_nodes = sorted(next_nodes, key=self._sort_key) 494 | 495 | for n_i, n in enumerate(next_nodes): 496 | if n_i == 0: 497 | condition = 'if' 498 | elif n_i == len(next_nodes) - 1: 499 | condition = 'else' 500 | else: 501 | condition = 'elif' 502 | 503 | self.node[n]['condition'] = condition 504 | 505 | queue.extend(next_nodes) 506 | 507 | def _handle_switch_statements(self): 508 | self._assign_switch_headers() 509 | self._adapt_switch_indentation() 510 | self._adapt_switch_header_indentation() 511 | 512 | def _assign_switch_headers(self): 513 | root = self._get_root() 514 | stack = deque(self._get_sorted_out_edges(root)) 515 | 516 | while stack: 517 | parent, child = stack.popleft() 518 | 519 | next_edges = self._get_sorted_out_edges(child) 520 | 
stack.extendleft(next_edges[::-1]) # extendleft reverses order! 521 | 522 | type_ = self.edge[parent][child].get('type') 523 | 524 | if type_ == 'range' and len(set(self.node[parent]['split'].values())) == 1: 525 | feature = self._get_feature(parent, child, state_node=parent) 526 | 527 | header = 'switch {}:'.format(feature) # appropriate indentation added later 528 | 529 | self.node[parent]['switch_header'] = header 530 | 531 | def _adapt_switch_indentation(self): 532 | switch_header_nodes = [n for n, d in self.nodes_iter(data=True) if d.get('switch_header')] 533 | stack = deque(switch_header_nodes) 534 | 535 | while stack: 536 | node = stack.popleft() 537 | next_nodes = self.successors(node) 538 | stack.extendleft(next_nodes[::-1]) # extendleft reverses order! 539 | 540 | self.node[node]['indent'] += '\t' 541 | 542 | def _adapt_switch_header_indentation(self): 543 | for node, data in self.nodes_iter(data=True): 544 | if data.get('switch_header'): 545 | try: 546 | parent = self.predecessors(node)[0] 547 | except IndexError: # node is root 548 | continue 549 | parent_indent = self.node[parent]['indent'] 550 | switch_header = self.node[node]['switch_header'] 551 | self.node[node]['switch_header'] = parent_indent + '\t' + switch_header 552 | 553 | def _get_sorted_out_edges(self, node): 554 | edges = self.out_edges_iter(node) 555 | edges = sorted(edges, key=lambda x: self._sort_key(x[1])) 556 | return edges 557 | 558 | def _get_output_text(self, node): 559 | out_text = '' 560 | if self.node[node].get('is_leaf') or self.node[node].get('is_default_leaf'): 561 | if not self.node[node].get('is_smart'): 562 | out_text = self._get_leaf_output(node) 563 | else: 564 | name_line = self._get_name_line(node) 565 | value_line = self._get_value_line(node) 566 | out_text = name_line + value_line 567 | 568 | return out_text 569 | 570 | def _get_leaf_output(self, node): 571 | out_indent = self.node[node]['indent'] 572 | out_value = self.node[node]['output'] 573 | out_text = '{indent}{value:.4f}\n'.format(indent=out_indent, value=out_value) 574 | 575 | return out_text 576 | 577 | def _get_name_line(self, node): 578 | try: 579 | out_indent = self.node[node]['indent'] 580 | out_name = self.node[node]['leaf_name'] 581 | name_line = '{indent}leaf_name: "{name}"\n'.format(indent=out_indent, name=out_name) 582 | except KeyError: 583 | name_line = '' # leaf_name is optional 584 | 585 | return name_line 586 | 587 | def _get_value_line(self, node): 588 | out_indent = self.node[node]['indent'] 589 | out_value = self._get_smart_leaf_output_value(node) 590 | value_line = '{indent}{value}\n'.format(indent=out_indent, value=out_value) 591 | 592 | return value_line 593 | 594 | def _get_smart_leaf_output_value(self, node): 595 | if isinstance(self.node[node].get('value'), (int, float)): 596 | out_value = self._get_smart_leaf_output_bid_syntax(node) 597 | else: 598 | out_value = self._get_smart_leaf_output_compute_syntax(node) 599 | 600 | return out_value 601 | 602 | def _get_smart_leaf_output_bid_syntax(self, node): 603 | bid_value = self.node[node]['value'] 604 | if round(bid_value, 4) <= 0: 605 | out_value = 'value: no_bid' 606 | else: 607 | out_value = 'value: {bid_value:.4f}'.format(bid_value=bid_value) 608 | return out_value 609 | 610 | def _get_smart_leaf_output_compute_syntax(self, node): 611 | input_field = self.node[node]['input_field'] 612 | multiplier = self._get_compute_input(node, 'multiplier') 613 | offset = self._get_compute_input(node, 'offset') 614 | min_value = self._get_compute_input(node, 'min_value') 615 
| max_value = self._get_compute_input(node, 'max_value') 616 | 617 | return 'value: compute({input_field}, {multiplier}, {offset}, {min_value}, {max_value})'.format( 618 | input_field=input_field, 619 | multiplier=multiplier, 620 | offset=offset, 621 | min_value=min_value, 622 | max_value=max_value 623 | ) 624 | 625 | def _get_compute_input(self, node, parameter): 626 | node_dict = self.node[node] 627 | try: 628 | value = round(node_dict[parameter], 4) 629 | except KeyError: 630 | value = '_' 631 | return value 632 | 633 | def _get_conditional_text(self, parent, child): 634 | pre_out = self._get_pre_out_statement(parent, child) 635 | out = self._get_out_statement(parent, child) 636 | 637 | return pre_out + out 638 | 639 | def _get_pre_out_statement(self, parent, child): 640 | type_ = self.edge[parent][child].get('type') 641 | conditional = self.node[child]['condition'] 642 | 643 | pre_out = '' 644 | 645 | if type_ == 'range' and conditional == 'if' and len(set(self.node[parent]['split'].values())) == 1: 646 | pre_out = self.node[parent]['switch_header'] + '\n' 647 | 648 | return pre_out 649 | 650 | def _get_out_statement(self, parent, child): 651 | indent = self.node[parent]['indent'] 652 | value = self.edge[parent][child].get('value') 653 | type_ = self.edge[parent][child].get('type') 654 | conditional = self.node[child]['condition'] 655 | feature = self._get_feature(parent, child, state_node=child) 656 | switch_header = self.node[parent].get('switch_header') 657 | join_statement = self.edge[parent][child].get('join_statement', 'every') 658 | is_negated = self._get_is_negated(parent, child, feature) 659 | 660 | if switch_header and type_ == 'range': 661 | out = self._get_switch_header_range_statement(indent, value) 662 | else: 663 | out = '{indent}{conditional}' 664 | if type_ is not None and all(isinstance(x, (list, tuple)) for x in (feature, type_)): 665 | out += ' ' + join_statement + ' ' + ', '.join( 666 | self._get_if_conditional(v, t, f, i, join_statement=join_statement) for v, t, f, i 667 | in zip(value, type_, feature, is_negated) 668 | ) 669 | elif type_ is not None and not any(isinstance(x, (list, tuple)) for x in (feature, type_)): 670 | out += ' ' + self._get_if_conditional(value, type_, feature, is_negated) 671 | elif type_ is None: 672 | out += '' 673 | else: 674 | raise ValueError( 675 | 'Unable to deduce if-conditional ' 676 | 'for feature "{}" and type "{}".'.format( 677 | feature, type_ 678 | ) 679 | ) 680 | out += ':\n' 681 | 682 | out = out.format(indent=indent, conditional=conditional) 683 | 684 | return out 685 | 686 | def _get_feature(self, parent, child, state_node): 687 | feature = self.node[parent].get('split') 688 | if isinstance(feature, dict): 689 | try: 690 | feature = feature[child] 691 | except KeyError: 692 | assert self.node[child].get('is_default_leaf', self.node[child].get('is_default_node', False)) 693 | if isinstance(feature, (list, tuple)): 694 | return self._get_formatted_multidimensional_compound_feature(feature, state_node) 695 | elif '.' 
in feature: 696 | return self._get_formatted_compound_feature(feature, state_node) 697 | else: 698 | return feature 699 | 700 | def _get_is_negated(self, parent, child, feature): 701 | try: 702 | return self.edge[parent][child]['is_negated'] 703 | except KeyError: 704 | if isinstance(feature, (list, tuple)): 705 | return len(feature) * (False,) 706 | else: 707 | return False 708 | 709 | def _get_formatted_multidimensional_compound_feature(self, feature, state_node): 710 | attribute_indices = self._get_attribute_indices(feature) 711 | feature = list(feature) 712 | for i in attribute_indices: 713 | feature[i] = self._get_formatted_compound_feature(feature[i], state_node) 714 | 715 | return tuple(feature) 716 | 717 | @staticmethod 718 | def _get_attribute_indices(feature): 719 | return [feature.index(f) for f in feature if '.' in f and f.split('.')[0] in feature] 720 | 721 | def _get_formatted_compound_feature(self, feature, state_node): 722 | object_, attribute = feature.split('.') 723 | try: 724 | attribute, value = attribute.split('__') 725 | except ValueError: 726 | try: 727 | value = self.node[state_node]['state'][object_] 728 | except KeyError: 729 | value = self.__getattribute__(object_) 730 | 731 | feature = '{feature}[{value}].{attribute}'.format( 732 | feature=object_, 733 | value=value, 734 | attribute=attribute 735 | ) 736 | 737 | return feature 738 | 739 | @staticmethod 740 | def _get_switch_header_range_statement(indent, value): 741 | if value is None: 742 | return '' 743 | 744 | left_bound, right_bound = value 745 | try: 746 | left_bound = round(left_bound, 4) 747 | _ = int(left_bound) # NOQA 748 | except (TypeError, OverflowError): 749 | left_bound = '' 750 | try: 751 | right_bound = round(right_bound, 4) 752 | _ = int(right_bound) # NOQA 753 | except (TypeError, OverflowError): 754 | right_bound = '' 755 | 756 | if left_bound == right_bound == '': 757 | raise ValueError( 758 | 'Value "{}" not reasonable as value of a range feature.'.format( 759 | value 760 | ) 761 | ) 762 | 763 | out = '{indent}case ({left_bound} .. 
{right_bound}):\n'.format( 764 | indent=indent, 765 | left_bound=left_bound, 766 | right_bound=right_bound 767 | ) 768 | 769 | return out 770 | 771 | def _get_if_conditional(self, value, type_, feature, is_negated, join_statement=None): 772 | 773 | if type_ not in {'range', 'membership', 'assignment'}: 774 | raise ValueError( 775 | 'Unable to deduce conditional statement for type "{}".'.format(type_) 776 | ) 777 | 778 | if is_absent_value(value): 779 | out = self._get_if_conditional_missing_value(type_, feature) 780 | else: 781 | out = self._get_if_conditional_present_value(value, type_, feature, join_statement=join_statement) 782 | 783 | if is_negated: 784 | out = 'not {}'.format(out) 785 | 786 | return out 787 | 788 | def _get_if_conditional_missing_value(self, type_, feature): 789 | out = '{feature} absent'.format(feature=feature) 790 | 791 | return out 792 | 793 | def _get_if_conditional_present_value(self, value, type_, feature, join_statement=None): 794 | if type_ == 'range': 795 | out = self._get_range_statement(value, feature, join_statement=join_statement) 796 | elif type_ == 'membership': 797 | value = tuple(value) 798 | if isinstance(value[0], basestring): 799 | value = '(\"{}\")'.format('\",\"'.join(value)) 800 | out = '{feature} in {value}'.format( 801 | feature=feature, 802 | value=value 803 | ) 804 | elif type_ == 'assignment': 805 | comparison = '=' 806 | value = '"{}"'.format(value) if not self._is_numerical(value) else value 807 | 808 | if feature.split('.')[0] not in compound_features: 809 | out = '{feature}{comparison}{value}'.format( 810 | feature=feature, 811 | comparison=comparison, 812 | value=value 813 | ) 814 | elif feature in compound_features: 815 | out = '{feature}[{value}]'.format( 816 | feature=feature, 817 | value=value 818 | ) 819 | else: 820 | object_, attribute = feature.split('.') 821 | out = '{feature}[{value}].{attribute}'.format( 822 | feature=object_, 823 | value=value, 824 | attribute=attribute 825 | ) 826 | 827 | return out 828 | 829 | def _get_range_statement(self, value, feature, join_statement=None): 830 | left_bound, right_bound = value 831 | 832 | if self._is_finite(left_bound) and self._is_finite(right_bound): 833 | left_bound = round(left_bound, 4) 834 | right_bound = round(right_bound, 4) 835 | out = self._get_range_output_for_finite_boundary_points( 836 | left_bound=left_bound, right_bound=right_bound, feature=feature, join_statement=join_statement 837 | ) 838 | elif not self._is_finite(left_bound) and self._is_finite(right_bound): 839 | right_bound = round(right_bound, 4) 840 | out = '{feature} <= {right_bound}'.format(feature=feature, right_bound=right_bound) 841 | elif self._is_finite(left_bound) and not self._is_finite(right_bound): 842 | left_bound = round(left_bound, 4) 843 | out = '{feature} >= {left_bound}'.format(feature=feature, left_bound=left_bound) 844 | else: 845 | raise ValueError( 846 | 'Value "{}" not reasonable as value of a range feature.'.format( 847 | value 848 | ) 849 | ) 850 | 851 | return out 852 | 853 | def _get_range_output_for_finite_boundary_points(self, left_bound, right_bound, feature, join_statement=None): 854 | if left_bound < right_bound and all([obj not in feature for obj in objects]): 855 | out = '{feature} range ({left_bound}, {right_bound})'.format( 856 | feature=feature, 857 | left_bound=left_bound, 858 | right_bound=right_bound 859 | ) 860 | elif left_bound < right_bound and any([obj in feature for obj in objects]): 861 | join = self._get_join(join_statement) 862 | out = '{join}{feature} >= 
{left_bound}, {feature} <= {right_bound}'.format( 863 | join=join, 864 | feature=feature, 865 | left_bound=left_bound, 866 | right_bound=right_bound 867 | ) 868 | else: 869 | out = '{feature} = {left_bound}'.format( 870 | feature=feature, 871 | left_bound=left_bound 872 | ) 873 | return out 874 | 875 | @staticmethod 876 | def _get_join(join_statement): 877 | if join_statement == 'any': 878 | raise ValueError( 879 | 'Cannot combine object feature "range" with "any" join_statement.' 880 | 'Object features are: {}.'.format(objects) 881 | ) 882 | join = '' if join_statement else 'every ' 883 | return join 884 | 885 | def _get_default_conditional_text(self, parent, child): 886 | type_ = self._get_sibling_type(parent, child) 887 | indent = self.node[parent]['indent'] 888 | 889 | conditional = 'default' if type_ == 'range' and len(set(self.node[parent]['split'].values())) == 1 else 'else' 890 | 891 | return '{indent}{conditional}:\n'.format(indent=indent, conditional=conditional) 892 | 893 | def _get_edge_siblings(self, parent, child): 894 | this_edge = (parent, child) 895 | sibling_edges = [edge for edge in self.out_edges(parent) if edge != this_edge] 896 | 897 | return sibling_edges 898 | 899 | def _get_sibling_type(self, parent, child): 900 | sibling_edges = self._get_edge_siblings(parent, child) 901 | sibling_types = [self.edge[sibling_parent][sibling_child]['type'] 902 | for sibling_parent, sibling_child in sibling_edges] 903 | 904 | return sibling_types[0] 905 | 906 | def _tree_to_bonsai(self): 907 | root = self._get_root() 908 | stack = deque(self._get_sorted_out_edges(root)) 909 | 910 | while stack: 911 | parent, child = stack.popleft() 912 | 913 | next_edges = self._get_sorted_out_edges(child) 914 | stack.extendleft(next_edges[::-1]) # extendleft reverses order! 
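            # Each traversed edge contributes its conditional line (if/elif/else,
            # or case/default under a switch header) plus the child's output text;
            # the depth-first order established above yields the final Bonsai string.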
915 | 916 | if not self.node[child].get('is_default_leaf', False): 917 | conditional_text = self._get_conditional_text(parent, child) 918 | elif self.node[child].get('is_default_leaf', False): 919 | conditional_text = self._get_default_conditional_text(parent, child) 920 | 921 | out_text = self._get_output_text(child) 922 | 923 | yield conditional_text + out_text 924 | 925 | @staticmethod 926 | def _is_numerical(x): 927 | try: 928 | int(x) 929 | float(x) 930 | return True 931 | except ValueError: 932 | return False 933 | 934 | @staticmethod 935 | def _is_finite(x): 936 | try: 937 | is_finite = abs(x) < float('inf') 938 | return is_finite 939 | except TypeError: 940 | return False 941 | -------------------------------------------------------------------------------- /bonspy/features.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import ( 4 | print_function, division, generators, 5 | absolute_import, unicode_literals 6 | ) 7 | 8 | from itertools import product 9 | 10 | objects = ['advertiser', 'line_item', 'campaign'] 11 | attributes = ['recency', 'day_frequency', 'lifetime_frequency'] 12 | compound_features = objects + ['segment'] 13 | 14 | FLOORS = { 15 | 'segment.age': 0, 16 | 'user_hour': 0, 17 | 'segment.value': 1 18 | } 19 | 20 | CEILINGS = { 21 | 'user_hour': 23 22 | } 23 | 24 | OPERATIONS = { 25 | 'domain': [lambda value: str.lstrip(value, 'www.')] 26 | } 27 | 28 | TYPES = { 29 | 'segment': int, 30 | 'segment.age': int, 31 | 'segment.value': int, 32 | 'user_hour': int 33 | } 34 | 35 | TYPES.update({'{object}.{attribute}'.format(object=k, attribute=v): int for (k, v) in product(objects, attributes)}) 36 | 37 | 38 | def get_validated(feature, value): 39 | """ 40 | Returns the passed feature value or collection of feature values 41 | clamped to expected ceilings and floors. 42 | Further casts feature values to their expected data types. 43 | 44 | :param feature: str, Name of the feature 45 | :param value: Either one feature value or a tuple / list of feature values. 46 | :return: The return value has the same dimensionality as the input `value`. 
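
    Illustrative examples, exercising the FLOORS, CEILINGS and TYPES tables defined above:

    >>> get_validated('user_hour', 25.0)          # clamped to the 23h ceiling, cast to int
    23
    >>> get_validated('segment.age', (-5, None))  # floored at 0, None passed through unchanged
    (0, None)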
47 | """ 48 | 49 | if isinstance(value, (list, tuple)): 50 | orig_type = type(value) 51 | value = list(value) 52 | for index, a_value in enumerate(value): 53 | value[index] = _get_valid_value(feature, a_value) 54 | value = orig_type(value) 55 | return value 56 | else: 57 | return _get_valid_value(feature, value) 58 | 59 | 60 | def _get_valid_value(feature, value): 61 | value = _get_ceiling(feature, value) 62 | value = _get_floor(feature, value) 63 | value = _type_cast(feature, value) 64 | value = _apply_operations(feature, value) 65 | 66 | return value 67 | 68 | 69 | def _get_ceiling(feature, value): 70 | if value is None: 71 | return value 72 | 73 | try: 74 | return min(CEILINGS[feature], value) 75 | except KeyError: 76 | return value 77 | 78 | 79 | def _get_floor(feature, value): 80 | if value is None: 81 | return value 82 | 83 | try: 84 | return max(FLOORS[feature], value) 85 | except KeyError: 86 | return value 87 | 88 | 89 | def _type_cast(feature, value): 90 | try: 91 | return TYPES[feature](value) 92 | except KeyError: 93 | return value 94 | except TypeError: 95 | # `value` is None 96 | return value 97 | except OverflowError: 98 | # `value` is -inf or inf 99 | return value 100 | 101 | 102 | def _apply_operations(feature, value): 103 | try: 104 | operations = OPERATIONS[feature] 105 | except KeyError: 106 | return value 107 | 108 | for operation in operations: 109 | try: 110 | value = operation(value) 111 | except TypeError: # `value` is None 112 | pass 113 | 114 | return value 115 | -------------------------------------------------------------------------------- /bonspy/graph_builder.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from collections import OrderedDict, defaultdict 3 | from csv import DictReader 4 | import gzip 5 | from glob import glob 6 | 7 | import networkx as nx 8 | 9 | 10 | class GraphBuilder: 11 | 12 | def __init__(self, input_, features, lazy_formatters=(), types_dict={}, functions=()): 13 | """ 14 | :param input_: str or list of str, path to gzipped csv input 15 | :param features: iterable, ordered features to build the tree with 16 | :param lazy_formatters: tuple of tuples, e.g. 
(('os', str), (user_day, int)) or dict 17 | :param types_dict: dict, types to be used for split, defaults to "assignment" 18 | :param functions: iterable, functions that return node_dict and take node_dict and row as arguments 19 | """ 20 | self.input_ = glob(input_) if isinstance(input_, str) else input_ 21 | self.features = features 22 | self.types_iterable = self._get_types_iterable(types_dict) 23 | self.lazy_formatters = self._get_lazy_formatter(lazy_formatters) 24 | self.functions = functions 25 | 26 | def _get_types_iterable(self, types_dict): 27 | return tuple(types_dict.get(f, 'assignment') for f in self.features) 28 | 29 | @staticmethod 30 | def _get_lazy_formatter(formatters): 31 | if not formatters: 32 | return defaultdict(lambda: (lambda x: x)) 33 | else: 34 | lazy_formatters = defaultdict(lambda: str) 35 | lazy_formatters.update(formatters) 36 | return lazy_formatters 37 | 38 | def get_data(self): 39 | for file in self.input_: 40 | data = DictReader(gzip.open(file, 'rt', encoding='utf-8')) 41 | yield from data 42 | 43 | def get_graph(self, graph=None): 44 | graph, node_index = self._seed_graph(graph) 45 | 46 | data = self.get_data() 47 | for row in data: 48 | graph, node_index = self._add_branch(graph, row, node_index) 49 | 50 | return graph 51 | 52 | @staticmethod 53 | def _seed_graph(graph): 54 | if not graph: 55 | graph = nx.DiGraph() 56 | root = 0 57 | graph.add_node(root, state=OrderedDict()) 58 | node_index = 1 + max((n for n in graph.nodes_iter())) 59 | return graph, node_index 60 | 61 | def _add_branch(self, graph, row, node_index): 62 | parent = 0 63 | graph.node[parent] = self._apply_functions(graph.node[parent], row) 64 | 65 | for feature in self.features: 66 | feature_value = row[feature] 67 | child = self._get_child(graph, parent, feature, feature_value) 68 | if child is None: 69 | 70 | childless = self._check_if_childless(graph, parent) 71 | if childless: 72 | default_leaf = node_index 73 | state = self._get_state(graph, parent) 74 | graph.add_node(default_leaf, state=state, is_default_leaf=True) 75 | graph.add_edge(parent, default_leaf) 76 | node_index += 1 77 | 78 | child = node_index 79 | state = self._get_state(graph, parent, new_feature=(feature, feature_value)) 80 | graph.add_node(child, state=state) 81 | graph = self._connect_node_to_parent(graph, parent, child, feature, feature_value) 82 | graph = self._update_parent_split(graph, parent, feature) 83 | node_index += 1 84 | 85 | graph.node[child] = self._apply_functions(graph.node[child], row) 86 | parent = child 87 | else: 88 | graph.node[child]['is_leaf'] = True 89 | 90 | return graph, node_index 91 | 92 | @staticmethod 93 | def _check_if_childless(graph, parent): 94 | edges = graph.edges_iter(parent) 95 | try: 96 | _ = next(edges) # NOQA 97 | return False 98 | except StopIteration: 99 | return True 100 | 101 | def _get_state(self, graph, parent, new_feature=None): 102 | state = graph.node[parent]['state'].copy() 103 | new_state = self._add_new_feature(state, new_feature) if new_feature else state 104 | return new_state 105 | 106 | def _get_child(self, graph, parent, feature, feature_value): 107 | edges = graph.edges_iter(parent, data=True) 108 | children = ((child, data) for _, child, data in edges if data) # filter out default leaves 109 | formatter = self._get_formatter(self.lazy_formatters[feature]) 110 | try: 111 | child = next(child for child, data in children if data.get('value') == formatter(feature_value)) 112 | except StopIteration: 113 | child = None 114 | return child 115 | 116 | def 
_connect_node_to_parent(self, graph, parent, new_node, feature, feature_value): 117 | feature_index = self.features.index(feature) 118 | type_ = self.types_iterable[feature_index] 119 | formatter = self._get_formatter(self.lazy_formatters[feature]) 120 | graph.add_edge(parent, new_node, type=type_, value=formatter(feature_value)) 121 | return graph 122 | 123 | @staticmethod 124 | def _update_parent_split(graph, parent, feature): 125 | graph.node[parent]['split'] = feature 126 | return graph 127 | 128 | def _apply_functions(self, node_dict, row): 129 | for function_ in self.functions: 130 | node_dict = function_(node_dict, row) 131 | return node_dict 132 | 133 | def _add_new_feature(self, state, new_feature): 134 | feature, value = new_feature 135 | formatter = self._get_formatter(self.lazy_formatters[feature]) 136 | state[feature] = formatter(value) 137 | return state 138 | 139 | @staticmethod 140 | def _get_formatter(formatter): 141 | return lambda x: formatter(x) if len(x) > 0 else None 142 | 143 | 144 | class Bidder(metaclass=ABCMeta): 145 | 146 | def compute_bids(self, graph): 147 | leaves = self.get_leaves(graph) 148 | for leaf in leaves: 149 | output_dict = self.get_bid(graph=graph, leaf=leaf) 150 | for key, value in output_dict.items(): 151 | graph.node[leaf][key] = value 152 | return graph 153 | 154 | @abstractmethod 155 | def get_bid(self, *args, **kwargs): 156 | pass 157 | 158 | @staticmethod 159 | def get_leaves(graph): 160 | leaves = ( 161 | n for n in graph.nodes_iter() if graph.node[n].get('is_leaf', graph.node[n].get('is_default_leaf', False)) 162 | ) 163 | return leaves 164 | 165 | 166 | class ConstantBidder(Bidder): 167 | 168 | def __init__(self, bid=1., **kwargs): 169 | self.bid = bid 170 | for key, value in kwargs.items(): 171 | setattr(self, key, value) 172 | 173 | def get_bid(self, *args, **kwargs): 174 | return {'output': self.bid} 175 | 176 | 177 | class EstimatorBidder(Bidder): 178 | 179 | def __init__(self, base_bid=1., estimators=(), **kwargs): 180 | self.base_bid = base_bid 181 | self.estimators = estimators 182 | for key, value in kwargs.items(): 183 | setattr(self, key, value) 184 | 185 | def get_bid(self, *args, **kwargs): 186 | graph = kwargs['graph'] 187 | leaf = kwargs['leaf'] 188 | state = graph.node[leaf]['state'] 189 | bid = self.base_bid 190 | for estimator in self.estimators: 191 | x = estimator.dict_vectorizer(state, **self.__dict__) 192 | try: 193 | bid *= estimator.predict(x)[0] 194 | except TypeError: 195 | bid *= estimator.predict(x) 196 | return {'output': bid} 197 | -------------------------------------------------------------------------------- /bonspy/logistic.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import ( 4 | print_function, division, generators, 5 | absolute_import, unicode_literals 6 | ) 7 | 8 | from collections import defaultdict 9 | import math 10 | 11 | import networkx as nx 12 | 13 | 14 | class LogisticConverter: 15 | """ 16 | Converter that translates a trained sklearn logistic regression classifier 17 | with one-hot-encoded, categorical features to a NetworkX graph that can 18 | be output to Bonsai with the `bonspy.BonsaiTree` converter. 19 | 20 | Attributes: 21 | features (list): List of feature names. 
22 | vocabulary (dict): `vocabulary_` attribute of your trained `DictVectorizer` 23 | (http://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.DictVectorizer.html) 24 | weights (list): `coef_` attribute of your trained `SGDClassifier(loss='log', ...)` 25 | (http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html) 26 | intercept (float): `intercept_` attribute of your trained `SGDClassifier(loss='log', ...)` 27 | types (dict): Variable assignment type definitions: 'assignment', 'range', or membership. 28 | base_bid (float): Constant value that the output of the trained classifier 29 | is multiplied with to produce the output (bid). 30 | buckets (dict): Optional. Map for range features from bucket ID's to their bounds. 31 | """ 32 | 33 | def __init__(self, features, vocabulary, weights, intercept, types, base_bid, 34 | buckets=None): 35 | 36 | self.features = features 37 | self.vocabulary = vocabulary 38 | self.weights = weights 39 | self.intercept = intercept 40 | self.types = types 41 | self.base_bid = base_bid 42 | self.buckets = buckets or {} 43 | 44 | self.feature_map = self._get_feature_map() 45 | 46 | self.graph = self._create_graph() 47 | 48 | def _get_feature_map(self): 49 | buckets = self.buckets 50 | map_ = defaultdict(dict) 51 | for key, index in self.vocabulary.items(): 52 | feature, value = key.split('=') 53 | range_ = buckets.get(feature, {}).get(value) 54 | 55 | if range_ is None: 56 | map_[feature][value] = index 57 | else: 58 | map_[feature][range_] = index 59 | 60 | return map_ 61 | 62 | def _create_graph(self): 63 | g = self._create_graph_skeleton() 64 | g = self._populate_nodes(g) 65 | g = self._populate_edges(g) 66 | 67 | return g 68 | 69 | def _create_graph_skeleton(self): 70 | g = nx.DiGraph() 71 | 72 | features = [tuple()] + self.features 73 | queue = [tuple()] 74 | g.add_node(tuple(), weight=self.intercept) 75 | 76 | while len(queue) > 0: 77 | parent = queue.pop(0) 78 | index = len(parent) 79 | 80 | try: 81 | next_feature = features[index + 1] 82 | except IndexError: 83 | continue 84 | 85 | for value, weight_index in self.feature_map[next_feature].items(): 86 | child = tuple(list(parent) + [value]) 87 | g.add_edge(parent, child) 88 | g.node[child]['weight'] = self.weights[weight_index] 89 | queue.append(child) 90 | 91 | # add default leaf / else node: 92 | value = None 93 | child = tuple(list(parent) + [value]) 94 | g.add_edge(parent, child) 95 | g.node[child]['weight'] = 0. 
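            # The zero weight means this default ("else") branch adds nothing to the
            # cumulative weight sum, so `_add_default_leaf_output` later bids
            # sigmoid(sum of weights accumulated on the path so far) * base_bid for it.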
96 | 97 | return g 98 | 99 | def _populate_nodes(self, g): 100 | g = self._add_state(g) 101 | g = self._add_split(g) 102 | g = self._sum_weights(g) 103 | g = self._add_leaf_output(g) 104 | g = self._add_default_leaf_output(g) 105 | 106 | return g 107 | 108 | def _populate_edges(self, g): 109 | g = self._add_value(g) 110 | g = self._add_type(g) 111 | 112 | return g 113 | 114 | def _add_state(self, g): 115 | for node in nx.dfs_preorder_nodes(g, tuple()): 116 | if node == tuple(): 117 | state = {} 118 | elif node[-1] is None: 119 | parent = g.predecessors(node)[0] 120 | state = g.node[parent]['state'] 121 | else: 122 | state = {feat: value for feat, value in zip(self.features, node)} 123 | 124 | g.node[node]['state'] = state 125 | 126 | return g 127 | 128 | def _add_split(self, g): 129 | for node in g.nodes(): 130 | if node != tuple() and node[-1] is None: 131 | continue # skip default leaf 132 | 133 | index = len(node) 134 | try: 135 | split = self.features[index] 136 | g.node[node]['split'] = split 137 | except IndexError: 138 | continue 139 | 140 | return g 141 | 142 | def _sum_weights(self, g): 143 | queue = [tuple()] 144 | g.node[tuple()]['sum'] = g.node[tuple()]['weight'] 145 | 146 | while len(queue) > 0: 147 | parent = queue.pop(0) 148 | parent_sum = g.node[parent]['sum'] 149 | 150 | children = g.successors(parent) 151 | queue += children 152 | 153 | for child in children: 154 | g.node[child]['sum'] = parent_sum + g.node[child]['weight'] 155 | 156 | return g 157 | 158 | def _add_leaf_output(self, g): 159 | for node in g.nodes(): 160 | if len(g.successors(node)) > 0 or node[-1] is None: 161 | continue 162 | 163 | g.node[node]['is_leaf'] = True 164 | g.node[node]['output'] = self._sigmoid(g.node[node]['sum']) * self.base_bid 165 | 166 | return g 167 | 168 | def _add_default_leaf_output(self, g): 169 | for node in g.nodes(): 170 | if len(g.successors(node)) > 0 or node[-1] is not None: 171 | continue 172 | 173 | g.node[node]['is_default_leaf'] = True 174 | g.node[node]['output'] = self._sigmoid(g.node[node]['sum']) * self.base_bid 175 | 176 | return g 177 | 178 | def _add_value(self, g): 179 | for parent, child in g.edges(): 180 | value = child[-1] 181 | if value is None: 182 | continue 183 | 184 | g.edge[parent][child]['value'] = value 185 | 186 | return g 187 | 188 | def _add_type(self, g): 189 | for parent, child in g.edges(): 190 | if child[-1] is None: 191 | continue 192 | 193 | feature = g.node[parent]['split'] 194 | g.edge[parent][child]['type'] = self.types[feature] 195 | 196 | return g 197 | 198 | @staticmethod 199 | def _sigmoid(x): 200 | return 1. / (1. 
+ math.exp(-x)) 201 | -------------------------------------------------------------------------------- /bonspy/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import ( 4 | print_function, division, generators, 5 | absolute_import, unicode_literals 6 | ) 7 | 8 | -------------------------------------------------------------------------------- /bonspy/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import ( 4 | print_function, division, generators, 5 | absolute_import, unicode_literals 6 | ) 7 | 8 | from collections import OrderedDict 9 | import gzip 10 | import os 11 | 12 | import networkx as nx 13 | 14 | import pytest 15 | 16 | 17 | @pytest.fixture 18 | def graph(): 19 | g = nx.DiGraph() 20 | 21 | g.add_node(0, split='segment', state=OrderedDict()) 22 | g.add_node(1, split=OrderedDict([(4, 'segment.age'), (26, 'browser'), (5, 'language')]), 23 | state=OrderedDict([('segment', 12345)])) 24 | g.add_node(2, split='segment.age', 25 | state=OrderedDict([('segment', 67890)])) 26 | g.add_node(3, split='geo', 27 | state=OrderedDict([('segment', 13579)])) 28 | g.add_node(4, split='geo', 29 | state=OrderedDict([('segment', 12345), ('segment.age', (-float('inf'), 10.))])) 30 | g.add_node(5, split='geo', 31 | state=OrderedDict([('segment', 12345), ('language', 'english')])) 32 | g.add_node(6, split=OrderedDict([(14, 'os'), (15, 'geo')]), 33 | state=OrderedDict([('segment', 67890), ('segment.age', (-float('inf'), 20.))])) 34 | g.add_node(7, split='geo', 35 | state=OrderedDict([('segment', 67890), ('segment.age', (20., 40.))])) 36 | g.add_node(26, is_leaf=True, output=1.1, 37 | state=OrderedDict([('segment', 12345), ('browser', 'safari')])) 38 | g.add_node(8, is_leaf=True, output=0.13, 39 | state=OrderedDict([('segment', 13579), 40 | ('geo', ('UK', 'DE', 'US'))])) 41 | g.add_node(9, is_leaf=True, output=1.2, 42 | state=OrderedDict([('segment', 13579), 43 | ('geo', ('BR',))])) 44 | g.add_node(10, is_leaf=True, output=0.10, 45 | state=OrderedDict([('segment', 12345), ('segment.age', (0, 10.)), 46 | ('geo', ('UK', 'DE'))])) 47 | g.add_node(11, is_leaf=True, output=0.20, 48 | state=OrderedDict([('segment', 12345), ('segment.age', (0, 10.)), 49 | ('geo', None)])) 50 | g.add_node(12, is_leaf=True, output=0.10, 51 | state=OrderedDict([('segment', 12345), ('segment.age', (10., 20.)), 52 | ('geo', ('UK', 'DE'))])) 53 | g.add_node(13, is_leaf=True, output=0.20, 54 | state=OrderedDict([('segment', 12345), ('segment.age', (10., 20.)), 55 | ('geo', ('US', 'BR'))])) 56 | g.add_node(14, is_leaf=True, is_smart=True, value=0., 57 | state=OrderedDict([('segment', 67890), ('segment.age', (0., 20.)), 58 | ('os', 'windows')])) 59 | g.add_node(15, is_leaf=True, output=0.20, 60 | state=OrderedDict([('segment', 67890), ('segment.age', (0., 20.)), 61 | ('geo', ('US', 'BR'))])) 62 | g.add_node(16, is_leaf=True, is_smart=True, value=0.10, 63 | state=OrderedDict([('segment', 67890), ('segment.age', (20., float('inf'))), 64 | ('geo', ('UK', 'DE'))])) 65 | g.add_node(17, is_leaf=True, is_smart=True, 66 | input_field='uniform', offset=0.4, max_value=1., 67 | state=OrderedDict([('segment', 67890), ('segment.age', (20., 40.)), 68 | ('geo', None)])) 69 | g.add_node(18, is_default_leaf=True, output=0.05, state=OrderedDict()) 70 | g.add_node(19, is_default_leaf=True, is_smart=True, 71 | input_field='uniform', multiplier=1.2, 
min_value=1., 72 | state=OrderedDict([('segment', 12345)])) 73 | g.add_node(20, is_default_leaf=True, is_smart=True, leaf_name='default_17', 74 | input_field='uniform', multiplier=.32, max_value=3.1, 75 | state=OrderedDict([('segment', 67890)])) 76 | g.add_node(21, is_default_leaf=True, output=0.09, 77 | state=OrderedDict([('segment', 13579)])) 78 | g.add_node(22, is_default_leaf=True, output=0.05, 79 | state=OrderedDict([('segment', 12345), ('segment.age', (0., 10.))])) 80 | g.add_node(23, is_default_leaf=True, output=0.05, 81 | state=OrderedDict([('segment', 12345), ('segment.age', (10., 20.))])) 82 | g.add_node(24, is_default_leaf=True, output=0.05, 83 | state=OrderedDict([('segment', 67890), ('segment.age', (0., 20.))])) 84 | g.add_node(25, is_default_leaf=True, output=0.05, 85 | state=OrderedDict([('segment', 67890), ('segment.age', (20., 40.))])) 86 | 87 | g.add_edge(0, 1, value=12345, type='assignment') 88 | g.add_edge(0, 2, value=67890, type='assignment') 89 | g.add_edge(0, 3, value=13579, type='assignment') 90 | g.add_edge(1, 4, value=(0., 10.), type='range') 91 | g.add_edge(1, 5, value='english', type='assignment') 92 | g.add_edge(1, 26, value='safari', type='assignment') 93 | g.add_edge(2, 6, value=(0., 20.), type='range') 94 | g.add_edge(2, 7, value=(20., 40.), type='range') 95 | g.add_edge(3, 8, value=('UK', 'DE', 'US'), type='membership') 96 | g.add_edge(3, 9, value=('BR',), type='membership') 97 | g.add_edge(4, 10, value=('UK', 'DE'), type='membership') 98 | g.add_edge(4, 11, value=None, type='membership') 99 | g.add_edge(5, 12, value=('UK', 'DE'), type='membership') 100 | g.add_edge(5, 13, value=('US', 'BR'), type='membership') 101 | g.add_edge(6, 14, value='windows', type='assignment') 102 | g.add_edge(6, 15, value=('US', 'BR'), type='membership') 103 | g.add_edge(7, 16, value=('UK', 'DE'), type='membership') 104 | g.add_edge(7, 17, value=None, type='membership') 105 | g.add_edge(0, 18) 106 | g.add_edge(1, 19) 107 | g.add_edge(2, 20) 108 | g.add_edge(3, 21) 109 | g.add_edge(4, 22) 110 | g.add_edge(5, 23) 111 | g.add_edge(6, 24) 112 | g.add_edge(7, 25) 113 | 114 | return g 115 | 116 | 117 | @pytest.fixture 118 | def graph_two_range_features(): 119 | g = nx.DiGraph() 120 | 121 | g.add_node(0, split=('segment', 'segment.age'), state=OrderedDict()) 122 | g.add_node(1, split='user_hour', state=OrderedDict([('segment', 12345), ('segment.age', (0., 10.))])) 123 | g.add_node(2, split='user_hour', state=OrderedDict([('segment', 12345), ('segment.age', (10., 20.))])) 124 | g.add_node(3, split='user_hour', state=OrderedDict([('segment', 67890), ('segment.age', (0., 20.))])) 125 | g.add_node(4, split='user_hour', state=OrderedDict([('segment', 67890), ('segment.age', (20., 40.))])) 126 | g.add_node(5, is_leaf=True, output=0.10, 127 | state=OrderedDict([ 128 | ('segment', 12345), 129 | ('segment.age', (0, 10.)), 130 | ('user_hour', (0., 12.)) 131 | ])) 132 | g.add_node(6, is_leaf=True, output=0.20, 133 | state=OrderedDict([ 134 | ('segment', 12345), 135 | ('segment.age', (0, 10.)), 136 | ('user_hour', (12., 100.)) 137 | ])) 138 | g.add_node(7, is_leaf=True, output=0.10, 139 | state=OrderedDict([ 140 | ('segment', 12345), 141 | ('segment.age', (10., 20.)), 142 | ('user_hour', (0., 12.)) 143 | ])) 144 | g.add_node(8, is_leaf=True, output=0.20, 145 | state=OrderedDict([ 146 | ('segment', 12345), 147 | ('segment.age', (10., 20.)), 148 | ('user_hour', (12., 100.)) 149 | ])) 150 | g.add_node(9, is_leaf=True, output=0.10, 151 | state=OrderedDict([ 152 | ('segment', 67890), 153 | 
('segment.age', (0., 20.)), 154 | ('user_hour', (0., 12.)) 155 | ])) 156 | g.add_node(10, is_leaf=True, output=0.20, 157 | state=OrderedDict([ 158 | ('segment', 67890), 159 | ('segment.age', (0., 20.)), 160 | ('user_hour', (12., 100.)) 161 | ])) 162 | g.add_node(11, is_leaf=True, output=0.10, 163 | state=OrderedDict([ 164 | ('segment', 67890), 165 | ('segment.age', (20., 40.)), 166 | ('user_hour', (0., 12.)) 167 | ])) 168 | g.add_node(12, is_leaf=True, output=0.20, 169 | state=OrderedDict([ 170 | ('segment', 67890), 171 | ('segment.age', (20., 40.)), 172 | ('user_hour', (12., 100.)) 173 | ])) 174 | g.add_node(13, is_default_leaf=True, output=0.05, state=OrderedDict()) 175 | g.add_node(14, is_default_leaf=True, output=0.05, state=OrderedDict([('segment', 12345)])) 176 | g.add_node(15, is_default_leaf=True, output=0.05, state=OrderedDict([('segment', 67890)])) 177 | g.add_node(16, is_default_leaf=True, output=0.05, 178 | state=OrderedDict([('segment', 12345), ('segment.age', (0., 10.))])) 179 | g.add_node(17, is_default_leaf=True, output=0.05, 180 | state=OrderedDict([('segment', 12345), ('segment.age', (10., 20.))])) 181 | g.add_node(18, is_default_leaf=True, output=0.05, 182 | state=OrderedDict([('segment', 67890), ('segment.age', (0., 20.))])) 183 | g.add_node(19, is_default_leaf=True, output=0.05, 184 | state=OrderedDict([('segment', 67890), ('segment.age', (20., 40.))])) 185 | 186 | g.add_edge(0, 1, value=(12345, (0., 10.)), type=('assignment', 'range')) 187 | g.add_edge(0, 2, value=(12345, (10., 20.)), type=('assignment', 'range')) 188 | g.add_edge(0, 3, value=(67890, (0., 20.)), type=('assignment', 'range')) 189 | g.add_edge(0, 4, value=(67890, (20., 40.)), type=('assignment', 'range')) 190 | g.add_edge(1, 5, value=(0., 12.), type='range') 191 | g.add_edge(1, 6, value=(12., 100.), type='range') 192 | g.add_edge(2, 7, value=(0., 12.), type='range') 193 | g.add_edge(2, 8, value=(12., 100.), type='range') 194 | g.add_edge(3, 9, value=(0., 12.), type='range') 195 | g.add_edge(3, 10, value=(12., 100.), type='range') 196 | g.add_edge(4, 11, value=(0., 12.), type='range') 197 | g.add_edge(4, 12, value=(12., 100.), type='range') 198 | g.add_edge(0, 13) 199 | g.add_edge(1, 16) 200 | g.add_edge(2, 17) 201 | g.add_edge(3, 18) 202 | g.add_edge(4, 19) 203 | 204 | return g 205 | 206 | 207 | @pytest.fixture 208 | def graph_compound_feature(): 209 | g = nx.DiGraph() 210 | 211 | g.add_node(0, split='geo', state=OrderedDict()) 212 | g.add_node(1, split=('site_id', 'placement_id'), state=OrderedDict([('geo', 'DE')])) 213 | g.add_node(2, split=('site_id', 'placement_id'), state=OrderedDict([('geo', 'UK')])) 214 | g.add_node( 215 | 3, state=OrderedDict([('geo', 'DE'), ('site_id', 1), ('placement_id', 'a')]), 216 | split='os' 217 | ) 218 | g.add_node( 219 | 4, is_leaf=True, output=.4, 220 | state=OrderedDict([('geo', 'DE'), ('site_id', 1), ('placement_id', 'b')]) 221 | ) 222 | g.add_node( 223 | 5, state=OrderedDict([('geo', 'UK'), ('site_id', 1), ('placement_id', 'a')]), 224 | split='os' 225 | ) 226 | g.add_node( 227 | 6, is_leaf=True, output=.6, 228 | state=OrderedDict([('geo', 'UK'), ('site_id', 1), ('placement_id', 'b')]) 229 | ) 230 | g.add_node( 231 | 7, is_leaf=True, output=.9, 232 | state=OrderedDict([('geo', 'UK'), ('site_id', 2), ('placement_id', 'a')]) 233 | ) 234 | g.add_node( 235 | 8, is_leaf=True, output=.2, 236 | state=OrderedDict([('geo', 'DE'), ('site_id', 1), ('placement_id', 'a'), ('os', 'linux')]) 237 | ) 238 | g.add_node( 239 | 15, is_leaf=True, output=.1, 240 | 
state=OrderedDict([('geo', 'DE'), ('site_id', 1), ('placement_id', 'a'), ('os', 'windows')]) 241 | ) 242 | g.add_node( 243 | 9, is_leaf=True, output=.3, 244 | state=OrderedDict([('geo', 'UK'), ('site_id', 2), ('placement_id', 'a'), ('os', 'windows')]) 245 | ) 246 | g.add_node( 247 | 10, is_default_leaf=True, output=.1, state=OrderedDict() 248 | ) 249 | g.add_node( 250 | 11, is_default_leaf=True, output=.5, state=OrderedDict([('geo', 'DE')]) 251 | ) 252 | g.add_node( 253 | 12, is_default_leaf=True, output=.05, state=OrderedDict([('geo', 'UK')]) 254 | ) 255 | g.add_node( 256 | 13, is_default_leaf=True, output=.2, state=OrderedDict([('geo', 'DE'), ('site_id', 1), ('placement_id', 'a')]) 257 | ) 258 | g.add_node( 259 | 14, is_default_leaf=True, output=.3, state=OrderedDict([('geo', 'UK'), ('site_id', 1), ('placement_id', 'a')]) 260 | ) 261 | 262 | g.add_edge(0, 1, value='DE', type='assignment') 263 | g.add_edge(0, 2, value='UK', type='assignment') 264 | g.add_edge(1, 3, value=(1, 'a'), type=('assignment', 'assignment')) 265 | g.add_edge(1, 4, value=(1, 'b'), type=('assignment', 'assignment')) 266 | g.add_edge(2, 5, value=(1, 'a'), type=('assignment', 'assignment')) 267 | g.add_edge(2, 6, value=(1, 'b'), type=('assignment', 'assignment')) 268 | g.add_edge(2, 7, value=(2, 'a'), type=('assignment', 'assignment')) 269 | g.add_edge(3, 8, value='linux', type='assignment') 270 | g.add_edge(3, 15, value='windows', type='assignment') 271 | g.add_edge(5, 9, value='windows', type='assignment') 272 | g.add_edge(0, 10) 273 | g.add_edge(1, 11) 274 | g.add_edge(2, 12) 275 | g.add_edge(3, 13) 276 | g.add_edge(5, 14) 277 | 278 | return g 279 | 280 | 281 | @pytest.fixture 282 | def graph_with_default_node(): 283 | g = nx.DiGraph() 284 | 285 | g.add_node(0, split='geo', state=OrderedDict()) 286 | g.add_node(1, split=('site_id', 'placement_id'), state=OrderedDict([('geo', 'DE')])) 287 | g.add_node(2, is_default_node=True, split=('site_id', 'placement_id'), state=OrderedDict()) 288 | g.add_node( 289 | 3, state=OrderedDict([('geo', 'DE'), ('site_id', 1), ('placement_id', 'a')]), 290 | split='os' 291 | ) 292 | g.add_node( 293 | 4, is_leaf=True, output=.4, 294 | state=OrderedDict([('geo', 'DE'), ('site_id', 1), ('placement_id', 'b')]) 295 | ) 296 | g.add_node( 297 | 5, state=OrderedDict([('site_id', 1), ('placement_id', 'a')]), 298 | split='os' 299 | ) 300 | g.add_node( 301 | 6, is_leaf=True, output=.6, 302 | state=OrderedDict([('site_id', 1), ('placement_id', 'b')]) 303 | ) 304 | g.add_node( 305 | 7, is_leaf=True, output=.9, 306 | state=OrderedDict([('site_id', 2), ('placement_id', 'a')]) 307 | ) 308 | g.add_node( 309 | 8, is_leaf=True, output=.2, 310 | state=OrderedDict([('geo', 'DE'), ('site_id', 1), ('placement_id', 'a'), ('os', 'linux')]) 311 | ) 312 | g.add_node( 313 | 9, is_leaf=True, output=.3, 314 | state=OrderedDict([('site_id', 1), ('placement_id', 'a'), ('os', 'windows')]) 315 | ) 316 | g.add_node( 317 | 10, is_default_leaf=True, output=.5, state=OrderedDict([('geo', 'DE')]) 318 | ) 319 | g.add_node( 320 | 11, is_default_leaf=True, output=.05, state=OrderedDict([('geo', 'UK')]) 321 | ) 322 | g.add_node( 323 | 12, is_default_leaf=True, output=.2, state=OrderedDict([('geo', 'DE'), ('site_id', 1), ('placement_id', 'a')]) 324 | ) 325 | g.add_node( 326 | 13, is_default_leaf=True, output=.3, state=OrderedDict([('site_id', 1), ('placement_id', 'a')]) 327 | ) 328 | 329 | g.add_edge(0, 1, value='DE', type='assignment') 330 | g.add_edge(0, 2) 331 | g.add_edge(1, 3, value=(1, 'a'), type=('assignment', 'assignment')) 
332 | g.add_edge(1, 4, value=(1, 'b'), type=('assignment', 'assignment')) 333 | g.add_edge(2, 5, value=(1, 'a'), type=('assignment', 'assignment')) 334 | g.add_edge(2, 6, value=(1, 'b'), type=('assignment', 'assignment')) 335 | g.add_edge(2, 7, value=(2, 'a'), type=('assignment', 'assignment')) 336 | g.add_edge(3, 8, value='linux', type='assignment') 337 | g.add_edge(5, 9, value='windows', type='assignment') 338 | g.add_edge(1, 10) 339 | g.add_edge(2, 11) 340 | g.add_edge(3, 12) 341 | g.add_edge(5, 13) 342 | 343 | return g 344 | 345 | 346 | @pytest.fixture 347 | def small_graph(): 348 | g = nx.DiGraph() 349 | 350 | g.add_node(0, split='user_hour', state=OrderedDict()) 351 | g.add_node(1, split='user_day', 352 | state=OrderedDict([('user_hour', (None, 10))])) 353 | 354 | g.add_node(2, split=OrderedDict([(5, 'user_day'), (6, 'os')]), 355 | state=OrderedDict([('user_hour', (11.3, 15))])) 356 | 357 | g.add_node(3, is_leaf=True, output=1.3, 358 | state=OrderedDict([('user_hour', (None, 10)), ('user_day', (1, 4))])) 359 | 360 | g.add_node(4, is_leaf=True, output=1.4, 361 | state=OrderedDict([('user_hour', (None, 10)), ('user_day', (5, 6))])) 362 | 363 | g.add_node(5, is_leaf=True, output=1.5, 364 | state=OrderedDict([('user_hour', (11.3, 15)), ('user_day', (3, 6))])) 365 | 366 | g.add_node(6, is_leaf=True, output=1.6, 367 | state=OrderedDict([('user_hour', (11.3, 15)), ('os', 'linux')])) 368 | 369 | g.add_node(7, is_default_leaf=True, output=0.7, state=OrderedDict()) 370 | g.add_node(8, is_default_leaf=True, output=0.8, state=OrderedDict([('user_hour', (None, 10))])) 371 | g.add_node(9, is_default_leaf=True, output=0.9, state=OrderedDict([('user_hour', (11.3, 15))])) 372 | 373 | g.add_edge(0, 1, value=(None, 10), type='range') 374 | g.add_edge(0, 2, value=(11.3, 15), type='range') 375 | g.add_edge(1, 3, value=(1, 4), type='range') 376 | g.add_edge(1, 4, value=(5, 6), type='range') 377 | g.add_edge(2, 5, value=(3, 6), type='range') 378 | g.add_edge(2, 6, value='linux', type='assignment') 379 | g.add_edge(0, 7) 380 | g.add_edge(1, 8) 381 | g.add_edge(2, 9) 382 | 383 | return g 384 | 385 | 386 | @pytest.fixture(params=['graph', 'graph_two_range_features', 'graph_compound_feature', 387 | 'graph_with_default_node', 'small_graph']) 388 | def parameterized_graph(request): 389 | return request.getfuncargvalue(request.param) 390 | 391 | 392 | @pytest.fixture 393 | def missing_values_graph(): 394 | g = nx.DiGraph() 395 | 396 | g.add_node('root', split='segment', state=OrderedDict()) 397 | g.add_node( 398 | 'root_default', 399 | is_default_leaf=True, 400 | state=OrderedDict(), 401 | output=.1 402 | ) 403 | g.add_node('segment_1', split='segment.age', state=OrderedDict([('segment', 1)])) 404 | g.add_node( 405 | 'segment_1_default', 406 | is_default_leaf=True, 407 | state=OrderedDict([('segment', 1)]), 408 | output=.1 409 | ) 410 | g.add_node('segment_2', split='os', state=OrderedDict([('segment', 2)])) 411 | g.add_node( 412 | 'segment_2_default', 413 | is_default_leaf=True, 414 | state=OrderedDict([('segment', 2)]), 415 | output=.1 416 | ) 417 | g.add_node('segment_missing', split='segment.age', state=OrderedDict([('segment', None)])) 418 | g.add_node( 419 | 'segment_missing_default', 420 | is_default_leaf=True, 421 | state=OrderedDict([('segment', None)]), 422 | output=.1 423 | ) 424 | g.add_node( 425 | 'segment_1_age_lower', 426 | is_leaf=True, 427 | state=OrderedDict([('segment', 1), ('segment.age', (-float('inf'), 10.))]), 428 | output=.1 429 | ) 430 | g.add_node( 431 | 'segment_1_age_upper', 432 | 
is_leaf=True, 433 | state=OrderedDict([('segment', 1), ('segment.age', (10., float('inf')))]), 434 | output=.1 435 | ) 436 | g.add_node( 437 | 'segment_2_os_known', 438 | is_leaf=True, 439 | state=OrderedDict([('segment', 2), ('os', ('linux', 'osx'))]), 440 | output=.1 441 | ) 442 | g.add_node( 443 | 'segment_2_os_unknown', 444 | is_leaf=True, 445 | state=OrderedDict([('segment', 2), ('os', None)]), 446 | output=.1 447 | ) 448 | g.add_node( 449 | 'segment_missing_age_missing', 450 | split='os', 451 | state=OrderedDict([('segment', None), ('segment.age', None)]) 452 | ) 453 | g.add_node( 454 | 'segment_missing_age_missing_default', 455 | is_default_leaf=True, 456 | state=OrderedDict([('segment', None), ('segment.age', None)]), 457 | output=.1 458 | ) 459 | g.add_node( 460 | 'segment_missing_age_missing_os_known', 461 | is_leaf=True, 462 | state=OrderedDict([('segment', None), ('segment.age', None), ('os', ('linux',))]), 463 | output=.1 464 | ) 465 | 466 | g.add_edge('root', 'segment_1', value=1, type='assignment') 467 | g.add_edge('root', 'segment_2', value=2, type='assignment') 468 | g.add_edge('root', 'segment_missing', value=None, type='assignment') 469 | g.add_edge('root', 'root_default') 470 | 471 | g.add_edge( 472 | 'segment_1', 473 | 'segment_1_age_lower', 474 | value=(-float('inf'), 10.), 475 | type='range' 476 | ) 477 | g.add_edge( 478 | 'segment_1', 479 | 'segment_1_age_upper', 480 | value=(10., float('inf')), 481 | type='range' 482 | ) 483 | g.add_edge( 484 | 'segment_1', 485 | 'segment_1_default' 486 | ) 487 | 488 | g.add_edge( 489 | 'segment_2', 490 | 'segment_2_os_known', 491 | value=('linux', 'osx'), 492 | type='membership' 493 | ) 494 | g.add_edge( 495 | 'segment_2', 496 | 'segment_2_os_unknown', 497 | value=None, 498 | type='membership' 499 | ) 500 | g.add_edge( 501 | 'segment_2', 502 | 'segment_2_default' 503 | ) 504 | 505 | g.add_edge( 506 | 'segment_missing', 507 | 'segment_missing_age_missing', 508 | value=None, 509 | type='range' 510 | ) 511 | g.add_edge( 512 | 'segment_missing', 513 | 'segment_missing_default' 514 | ) 515 | 516 | g.add_edge( 517 | 'segment_missing_age_missing', 518 | 'segment_missing_age_missing_os_known', 519 | value=('linux',), 520 | type='membership' 521 | ) 522 | 523 | g.add_edge( 524 | 'segment_missing_age_missing', 525 | 'segment_missing_age_missing_default' 526 | ) 527 | 528 | return g 529 | 530 | 531 | @pytest.fixture 532 | def missing_values_graph_short(): 533 | g = nx.DiGraph() 534 | 535 | g.add_node('root', split='segment', state=OrderedDict()) 536 | g.add_node( 537 | 'root_default', 538 | is_default_leaf=True, 539 | state=OrderedDict(), 540 | output=.1 541 | ) 542 | g.add_node('segment_1', split='segment.age', state=OrderedDict([('segment', 1)])) 543 | g.add_node( 544 | 'segment_1_default', 545 | is_default_leaf=True, 546 | state=OrderedDict([('segment', 1)]), 547 | output=.1 548 | ) 549 | g.add_node('segment_2', is_leaf=True, state=OrderedDict([('segment', 2)]), output=0.1) 550 | g.add_node('segment_missing', split='segment.age', state=OrderedDict([('segment', None)])) 551 | g.add_node( 552 | 'segment_missing_default', 553 | is_default_leaf=True, 554 | state=OrderedDict([('segment', None)]), 555 | output=.1 556 | ) 557 | g.add_node( 558 | 'segment_1_age_lower', 559 | is_leaf=True, 560 | state=OrderedDict([('segment', 1), ('segment.age', (-float('inf'), 10.))]), 561 | output=.1 562 | ) 563 | g.add_node( 564 | 'segment_1_age_missing', 565 | is_leaf=True, 566 | state=OrderedDict([('segment', 1), ('segment.age', None)]), 567 | output=.1 568 | 
) 569 | g.add_node( 570 | 'segment_missing_age_missing', 571 | is_leaf=True, 572 | state=OrderedDict([('segment', None), ('segment.age', None)]), 573 | output=.1 574 | ) 575 | 576 | g.add_edge('root', 'segment_1', value=1, type='assignment') 577 | g.add_edge('root', 'segment_2', value=2, type='assignment') 578 | g.add_edge('root', 'segment_missing', value=None, type='assignment') 579 | g.add_edge('root', 'root_default') 580 | 581 | g.add_edge( 582 | 'segment_1', 583 | 'segment_1_age_lower', 584 | value=(-float('inf'), 10.), 585 | type='range' 586 | ) 587 | g.add_edge( 588 | 'segment_1', 589 | 'segment_1_age_missing', 590 | value=None, 591 | type='range' 592 | ) 593 | g.add_edge( 594 | 'segment_1', 595 | 'segment_1_default' 596 | ) 597 | 598 | g.add_edge( 599 | 'segment_missing', 600 | 'segment_missing_age_missing', 601 | value=None, 602 | type='range' 603 | ) 604 | g.add_edge( 605 | 'segment_missing', 606 | 'segment_missing_default' 607 | ) 608 | 609 | return g 610 | 611 | 612 | @pytest.fixture 613 | def negated_values_graph(): 614 | g = nx.DiGraph() 615 | 616 | g.add_node('root', split=OrderedDict([ 617 | ('every_segment', ('segment', 'segment')), 618 | ('negated_every_segment', ('segment', 'segment', 'segment')), 619 | ('any_segment', ('segment', 'segment')), 620 | ('negated_any_segment', ('segment', 'segment')) 621 | ]), state=OrderedDict()) 622 | 623 | g.add_node( 624 | 'every_segment', 625 | is_leaf=True, 626 | output=0.1, 627 | state=OrderedDict([(('segment', 'segment'), (1, 2))]) 628 | ) 629 | 630 | g.add_node( 631 | 'negated_every_segment', 632 | is_leaf=True, 633 | output=0.1, 634 | state=OrderedDict([(('segment', 'segment', 'segment'), (1, 2, 3))]) 635 | ) 636 | 637 | g.add_node( 638 | 'any_segment', 639 | is_leaf=True, 640 | output=0.1, 641 | state=OrderedDict([(('segment', 'segment'), (1, 2))]) 642 | ) 643 | 644 | g.add_node( 645 | 'negated_any_segment', 646 | is_leaf=True, 647 | output=0.1, 648 | state=OrderedDict([(('segment', 'segment'), (1, 10))]) 649 | ) 650 | 651 | g.add_edge( 652 | 'root', 653 | 'every_segment', 654 | value=(1, 2), 655 | type=('assignment', 'assignment') 656 | ) 657 | 658 | g.add_edge( 659 | 'root', 660 | 'negated_every_segment', 661 | value=(1, 2, 3), 662 | type=('assignment', 'assignment', 'assignment'), 663 | is_negated=(False, True, False) 664 | ) 665 | 666 | g.add_edge( 667 | 'root', 668 | 'any_segment', 669 | value=(1, 2), 670 | type=('assignment', 'assignment'), 671 | join_statement='any' 672 | ) 673 | 674 | g.add_edge( 675 | 'root', 676 | 'negated_any_segment', 677 | value=(1, 10), 678 | type=('assignment', 'assignment'), 679 | join_statement='any', 680 | is_negated=(False, True) 681 | ) 682 | 683 | return g 684 | 685 | 686 | @pytest.fixture 687 | def multiple_compound_features_graph(): 688 | g = nx.DiGraph() 689 | 690 | g.add_node( 691 | 'root', 692 | split=OrderedDict( 693 | [ 694 | ('segment_1', 'segment'), 695 | ('segment_2', 'segment') 696 | ] 697 | ), 698 | state=OrderedDict([]) 699 | ) 700 | 701 | g.add_node( 702 | 'segment_1', 703 | split=OrderedDict( 704 | [ 705 | ('segment_1_age_1', 'segment.age'), 706 | ('segment_1_age_2', 'segment.age') 707 | ] 708 | ), 709 | state=OrderedDict( 710 | [ 711 | ('segment', 1) 712 | ] 713 | ) 714 | ) 715 | 716 | g.add_node( 717 | 'segment_2', 718 | split=OrderedDict( 719 | [ 720 | ('segment_2_age_1', 'segment.age') 721 | ] 722 | ), 723 | state=OrderedDict( 724 | [ 725 | ('segment', 2) 726 | ] 727 | ) 728 | ) 729 | 730 | g.add_node( 731 | 'segment_1_age_1', 732 | split=OrderedDict( 733 | [ 734 | 
('segment_1_age_1_freq_1', 'advertiser.lifetime_frequency'), 735 | ('segment_1_age_1_freq_2', 'advertiser.lifetime_frequency') 736 | ] 737 | ), 738 | state=OrderedDict( 739 | [ 740 | ('segment', 1), 741 | ('segment.age', (0, 10)) 742 | ] 743 | ) 744 | ) 745 | 746 | g.add_node( 747 | 'segment_1_age_2', 748 | is_leaf=True, 749 | output=0.1, 750 | state=OrderedDict( 751 | [ 752 | ('segment', 1), 753 | ('segment.age', (10, 20)) 754 | ] 755 | ) 756 | ) 757 | 758 | g.add_node( 759 | 'segment_2_age_1', 760 | split=OrderedDict( 761 | [ 762 | ('segment_2_age_1_freq_2', 'advertiser.lifetime_frequency'), 763 | ('segment_2_age_1_freq_3', 'advertiser.lifetime_frequency'), 764 | ('segment_2_age_1_freq_4', 'advertiser.lifetime_frequency'), 765 | ('segment_2_age_1_user_day', 'user_day') 766 | ] 767 | ), 768 | state=OrderedDict( 769 | [ 770 | ('segment', 2), 771 | ('segment.age', (0, 10)) 772 | ] 773 | ) 774 | ) 775 | 776 | g.add_node( 777 | 'segment_1_age_1_freq_1', 778 | is_leaf=True, 779 | output=0.5, 780 | state=OrderedDict( 781 | [ 782 | ('segment', 1), 783 | ('segment.age', (0, 10)), 784 | ('advertiser.lifetime_frequency', (5, 10)) 785 | ] 786 | ) 787 | ) 788 | 789 | g.add_node( 790 | 'segment_1_age_1_freq_2', 791 | is_leaf=True, 792 | output=0.3, 793 | state=OrderedDict( 794 | [ 795 | ('segment', 1), 796 | ('segment.age', (0, 10)), 797 | ('advertiser.lifetime_frequency', (None, 4)) 798 | ] 799 | ) 800 | ) 801 | 802 | g.add_node( 803 | 'segment_2_age_1_freq_2', 804 | is_leaf=True, 805 | output=0.6, 806 | state=OrderedDict( 807 | [ 808 | ('segment', 2), 809 | ('segment.age', (0, 10)), 810 | ('advertiser.lifetime_frequency', (11, 11)) 811 | ] 812 | ) 813 | ) 814 | 815 | g.add_node( 816 | 'segment_2_age_1_freq_3', 817 | is_leaf=True, 818 | output=0.7, 819 | state=OrderedDict( 820 | [ 821 | ('segment', 2), 822 | ('segment.age', (0, 10)), 823 | ('advertiser.lifetime_frequency', (12, None)) 824 | ] 825 | ) 826 | ) 827 | 828 | g.add_node( 829 | 'segment_2_age_1_freq_4', 830 | is_leaf=True, 831 | output=0.9, 832 | state=OrderedDict( 833 | [ 834 | ('segment', 2), 835 | ('segment.age', (0, 10)), 836 | ('advertiser.lifetime_frequency', (0, 10)) 837 | ] 838 | ) 839 | ) 840 | 841 | g.add_node( 842 | 'segment_2_age_1_user_day', 843 | is_leaf=True, 844 | output=1., 845 | state=OrderedDict( 846 | [ 847 | ('segment', 2), 848 | ('segment.age', (0, 10)), 849 | ('user_day', (0, 3)) 850 | ] 851 | ) 852 | ) 853 | 854 | g.add_edge( 855 | 'root', 856 | 'segment_1', 857 | value=1, 858 | type='assignment' 859 | ) 860 | 861 | g.add_edge( 862 | 'root', 863 | 'segment_2', 864 | value=2, 865 | type='assignment' 866 | ) 867 | 868 | g.add_edge( 869 | 'segment_1', 870 | 'segment_1_age_1', 871 | value=(0, 10), 872 | type='range' 873 | ) 874 | 875 | g.add_edge( 876 | 'segment_1', 877 | 'segment_1_age_2', 878 | value=(10, 20), 879 | type='range' 880 | ) 881 | 882 | g.add_edge( 883 | 'segment_2', 884 | 'segment_2_age_1', 885 | value=(0, 10), 886 | type='range' 887 | ) 888 | 889 | g.add_edge( 890 | 'segment_1_age_1', 891 | 'segment_1_age_1_freq_1', 892 | value=(5, 10), 893 | type='range' 894 | ) 895 | 896 | g.add_edge( 897 | 'segment_1_age_1', 898 | 'segment_1_age_1_freq_2', 899 | value=(None, 4), 900 | type='range' 901 | ) 902 | 903 | g.add_edge( 904 | 'segment_2_age_1', 905 | 'segment_2_age_1_freq_2', 906 | value=(11, 11), 907 | type='range' 908 | ) 909 | 910 | g.add_edge( 911 | 'segment_2_age_1', 912 | 'segment_2_age_1_freq_3', 913 | value=(12, None), 914 | type='range' 915 | ) 916 | 917 | g.add_edge( 918 | 'segment_2_age_1', 919 | 
'segment_2_age_1_freq_4', 920 | value=(0, 10), 921 | type='range' 922 | ) 923 | 924 | g.add_edge( 925 | 'segment_2_age_1', 926 | 'segment_2_age_1_user_day', 927 | value=(0, 3), 928 | type='range' 929 | ) 930 | 931 | return g 932 | 933 | 934 | @pytest.fixture 935 | def data_features_and_file(): 936 | base_path = os.path.dirname(__file__) 937 | path = os.path.join(base_path, 'data/test.csv.gz') 938 | 939 | features = ['country', 940 | 'region', 941 | 'city', 942 | 'user_day', 943 | 'user_hour', 944 | 'os_extended', 945 | 'browser', 946 | 'language' 947 | ] 948 | 949 | return features, path 950 | 951 | 952 | @pytest.fixture 953 | def small_data_features_and_file(tmpdir): 954 | features = b'os,city\n' 955 | data = [ 956 | b'iOS,Berlin\n', 957 | b'Android,Berlin\n', 958 | b'iOS,Hamburg\n', 959 | b'iOS,Hamburg\n', 960 | b'iOS,\n', 961 | b',Berlin\n', 962 | b',Bremen\n' 963 | ] 964 | 965 | path = str(tmpdir.join('test_data.csv.gz')) 966 | 967 | with gzip.open(path, 'wb') as file: 968 | file.write(features) 969 | for line in data: 970 | file.write(line) 971 | 972 | return ['os', 'city'], path 973 | 974 | 975 | @pytest.fixture 976 | def small_data_features_and_file_numeric(tmpdir): 977 | features = b'city,user_day\n' 978 | data = [ 979 | b'Berlin,0\n', 980 | b'Berlin,1\n', 981 | b'Hamburg,\n', 982 | b'Hamburg,6\n', 983 | b',2\n', 984 | b'Berlin,\n', 985 | b'Bremen,2\n' 986 | ] 987 | path = str(tmpdir.join('test_data_numeric.csv.gz')) 988 | 989 | with gzip.open(path, 'wb') as file: 990 | file.write(features) 991 | for line in data: 992 | file.write(line) 993 | 994 | return ['city', 'user_day'], path 995 | 996 | 997 | @pytest.fixture 998 | def unsliced_graph(): 999 | g = nx.DiGraph() 1000 | # root 1001 | g.add_node(0, split='some_feature', state=OrderedDict()) 1002 | 1003 | # level one 1004 | g.add_node(1, state=OrderedDict([('some_feature', 'value_one')]), is_leaf=True, output=1.) 1005 | g.add_node(2, state=OrderedDict([('some_feature', 'value_two')]), split='slice_feature') 1006 | g.add_node('default_one', is_default_leaf=True, state=OrderedDict(), output=0.1) 1007 | 1008 | # connect root with level one 1009 | g.add_edge(0, 1, value='value_one', type='assignment') 1010 | g.add_edge(0, 2, value='value_two', type='assignment') 1011 | g.add_edge(0, 'default_one') 1012 | 1013 | # level two 1014 | g.add_node(3, state=OrderedDict([('some_feature', 'value_two'), ('slice_feature', 'bad')]), is_leaf=True, 1015 | output=1.) 1016 | g.add_node(4, state=OrderedDict([('some_feature', 'value_two'), ('slice_feature', 'good')]), split='other_feature') 1017 | g.add_node('default_two', is_default_leaf=True, state=OrderedDict([('some_feature', 'value_two')]), output=.1) 1018 | 1019 | # connect level one with level two 1020 | g.add_edge(2, 3, value='bad', type='assignment') 1021 | g.add_edge(2, 4, value='good', type='assignment') 1022 | g.add_edge(2, 'default_two') 1023 | 1024 | # level three 1025 | g.add_node(5, state=OrderedDict( 1026 | [('some_feature', 'value_two'), ('slice_feature', 'good'), ('other_feature', 'blah')] 1027 | ), is_leaf=True, output=1.) 1028 | g.add_node(6, state=OrderedDict( 1029 | [('some_feature', 'value_two'), ('slice_feature', 'good'), ('other_feature', 'blub')] 1030 | ), is_leaf=True, output=1.) 
1031 | g.add_node('default_three', is_default_leaf=True, 1032 | state=OrderedDict([('some_feature', 'value_two'), ('slice_feature', 'good')]), output=.1) 1033 | 1034 | # connect level two with level three 1035 | g.add_edge(4, 5, value='blah', type='assignment') 1036 | g.add_edge(4, 6, value='blub', type='assignment') 1037 | g.add_edge(4, 'default_three') 1038 | 1039 | return g 1040 | 1041 | 1042 | @pytest.fixture 1043 | def small_unsliced_graph(): 1044 | g = nx.DiGraph() 1045 | # root 1046 | g.add_node(0, split='slice_feature', state=OrderedDict()) 1047 | 1048 | # level one 1049 | g.add_node(1, state=OrderedDict([('slice_feature', 'good')]), is_leaf=True, output=5.) 1050 | g.add_node(2, state=OrderedDict([('slice_feature', 'bad')]), is_leaf=True, output=1.) 1051 | g.add_node('default_one', is_default_leaf=True, state=OrderedDict(), output=1.) 1052 | 1053 | # connect root with level one 1054 | g.add_edge(0, 1, value='good', type='assignment') 1055 | g.add_edge(0, 2, value='bad', type='assignment') 1056 | g.add_edge(0, 'default_one') 1057 | 1058 | return g 1059 | 1060 | 1061 | @pytest.fixture 1062 | def small_unsliced_graph_single_slice_feature_value(): 1063 | g = nx.DiGraph() 1064 | # root 1065 | g.add_node(0, split='slice_feature', state=OrderedDict()) 1066 | 1067 | # level one 1068 | g.add_node(1, state=OrderedDict([('slice_feature', 'value')]), is_leaf=True, output=5.) 1069 | g.add_node('default_one', is_default_leaf=True, state=OrderedDict(), output=1.) 1070 | 1071 | # connect root with level one 1072 | g.add_edge(0, 1, value='value', type='assignment') 1073 | g.add_edge(0, 'default_one') 1074 | 1075 | return g 1076 | 1077 | 1078 | @pytest.fixture 1079 | def small_unsliced_graph_mixed_split(): 1080 | g = nx.DiGraph() 1081 | # root 1082 | g.add_node( 1083 | 0, split=OrderedDict([(1, 'slice_feature'), (2, 'slice_feature'), (3, 'other_feature')]), state=OrderedDict() 1084 | ) 1085 | 1086 | # level one 1087 | g.add_node(1, state=OrderedDict([('slice_feature', 'good')]), is_leaf=True, output=5.) 1088 | g.add_node(2, state=OrderedDict([('slice_feature', 'bad')]), is_leaf=True, output=1.) 1089 | g.add_node(3, state=OrderedDict([('other_feature', 'value')]), is_leaf=True, output=3.) 1090 | g.add_node('default_one', is_default_leaf=True, state=OrderedDict(), output=1.) 
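    # The root's `split` OrderedDict maps each child node id to the feature tested on the
    # corresponding edge: children 1 and 2 branch on 'slice_feature', child 3 on
    # 'other_feature', which is what makes this split "mixed".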
1091 | 1092 | # connect root with level one 1093 | g.add_edge(0, 1, value='good', type='assignment') 1094 | g.add_edge(0, 2, value='bad', type='assignment') 1095 | g.add_edge(0, 3, value='other_value', type='assignment') 1096 | g.add_edge(0, 'default_one') 1097 | 1098 | return g 1099 | -------------------------------------------------------------------------------- /bonspy/tests/data/test.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovianhq/bonspy/f13b6d0fbf91e8cafa19f685c500be7fa04563d5/bonspy/tests/data/test.csv.gz -------------------------------------------------------------------------------- /bonspy/tests/test_bonsai.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import ( 4 | print_function, division, generators, 5 | absolute_import, unicode_literals 6 | ) 7 | 8 | from collections import deque 9 | import networkx as nx 10 | import pytest 11 | import re 12 | 13 | from bonspy import BonsaiTree 14 | 15 | 16 | def test_switch_header(graph): 17 | tree = BonsaiTree(graph) 18 | text = tree.bonsai.replace('\t', '').split('\n') 19 | switch_header_nodes = [d for _, d in tree.nodes_iter(data=True) if d.get('split') is not None and 20 | set(d.get('split').values()) == {'segment.age'}] 21 | 22 | assert len(switch_header_nodes) == 1 23 | assert all([d.get('switch_header') is not None for d in switch_header_nodes]) 24 | 25 | for row in text: 26 | if '.age' in row and 'segment[67890]' in row: 27 | assert row in {'switch segment[67890].age:'} 28 | elif '.age' in row and 'segment[12345]' in row: 29 | assert row not in {'switch segment[12345].age:'} 30 | assert 'elif' in row or 'if' in row 31 | 32 | 33 | def test_switch_indent(graph): 34 | tree = BonsaiTree(graph) 35 | 36 | switch_header_nodes = [n for n, d in tree.nodes_iter(data=True) if d.get('split') == 'segment.age'] 37 | 38 | for node in switch_header_nodes: 39 | node_indent = tree.node[node]['indent'].count('\t') 40 | header_indent = tree.node[node]['switch_header'].count('\t') 41 | 42 | children_indent = [tree.node[c]['indent'].count('\t') for c in tree.successors_iter(node)] 43 | 44 | assert node_indent - 1 == header_indent 45 | assert all([header_indent + 2 == child_indent for child_indent in children_indent]) 46 | 47 | 48 | def test_compound_feature_presence(graph): 49 | tree = BonsaiTree(graph) 50 | 51 | text = tree.bonsai.replace('\t', '').split('\n') 52 | 53 | for row in text: 54 | if 'segment' in row: 55 | assert any(['segment[{id}]'.format(id=i) in row for i in [12345, 67890, 13579]]) 56 | 57 | 58 | def test_multiple_compound_features(multiple_compound_features_graph): 59 | feature_value_order = { 60 | 'segment': [1, 2], 61 | 'segment.age': [(0, 10), (10, 20)], 62 | 'advertiser.lifetime_frequency': [(5, 10), (None, 4), (11, 11), (12, None), (0, 10)] 63 | } 64 | 65 | feature_order = ['segment', 'segment.age', 'advertiser.lifetime_frequency', 'user_day'] 66 | 67 | tree = BonsaiTree( 68 | multiple_compound_features_graph, 69 | feature_order=feature_order, 70 | feature_value_order=feature_value_order, 71 | advertiser=1 72 | ) 73 | 74 | expected_tree = ''' 75 | if segment[1]: 76 | \tswitch segment[1].age: 77 | \t\tcase (0 .. 10): 78 | \t\t\tswitch advertiser[1].lifetime_frequency: 79 | \t\t\t\tcase (5 .. 10): 80 | \t\t\t\t\t0.5000 81 | \t\t\t\tcase ( .. 4): 82 | \t\t\t\t\t0.3000 83 | \t\tcase (10 .. 
20): 84 | \t\t\t0.1000 85 | else segment[2]: 86 | \tswitch segment[2].age: 87 | \t\tcase (0 .. 10): 88 | \t\t\tif advertiser[1].lifetime_frequency = 11: 89 | \t\t\t\t0.6000 90 | \t\t\telif advertiser[1].lifetime_frequency >= 12: 91 | \t\t\t\t0.7000 92 | \t\t\telif every advertiser[1].lifetime_frequency >= 0, advertiser[1].lifetime_frequency <= 10: 93 | \t\t\t\t0.9000 94 | \t\t\telse user_day range (0, 3): 95 | \t\t\t\t1.0000 96 | '''.replace(8 * ' ', '').strip().lstrip('\n') + '\n' 97 | 98 | assert tree.bonsai == expected_tree 99 | 100 | 101 | def test_get_range_output_for_finite_boundary_points(graph): 102 | some_graph = nx.DiGraph(graph) 103 | tree = BonsaiTree(some_graph) 104 | 105 | for join in ['any', 'every', None]: 106 | out = tree._get_range_output_for_finite_boundary_points(0, 1, 'user_hour', join_statement=join) 107 | assert out == 'user_hour range (0, 1)' 108 | 109 | for join in ['any', 'every', None]: 110 | out = tree._get_range_output_for_finite_boundary_points(1, 1, 'user_hour', join_statement=join) 111 | assert out == 'user_hour = 1' 112 | 113 | for join in ['any', 'every', None]: 114 | out = tree._get_range_output_for_finite_boundary_points(1, 1, 'advertiser[123].recency', join_statement=join) 115 | assert out == 'advertiser[123].recency = 1' 116 | 117 | for join in ['every', None]: 118 | out = tree._get_range_output_for_finite_boundary_points(1, 2, 'advertiser[123].recency', join_statement=join) 119 | assert out == (join is None) * 'every ' + 'advertiser[123].recency >= 1, advertiser[123].recency <= 2' 120 | 121 | join = 'any' 122 | with pytest.raises(ValueError): 123 | tree._get_range_output_for_finite_boundary_points(1, 2, 'advertiser[123].recency', join_statement=join) 124 | 125 | 126 | def test_two_range_features(graph_two_range_features): 127 | tree = BonsaiTree(graph_two_range_features) 128 | 129 | switch_nodes = [n for n, d in tree.nodes_iter(data=True) if d.get('switch_header')] 130 | 131 | for node in switch_nodes: 132 | parent = tree.predecessors(node)[0] 133 | 134 | header_indent = tree.node[node]['switch_header'].count('\t') 135 | parent_indent = tree.node[parent]['indent'].count('\t') 136 | 137 | assert header_indent - 1 == parent_indent 138 | 139 | 140 | def test_feature_validation(graph_two_range_features): 141 | tree = BonsaiTree(graph_two_range_features) 142 | 143 | for node, data in tree.nodes_iter(data=True): 144 | try: 145 | lower, upper = data['state']['segment.age'] 146 | assert lower >= 0 147 | assert isinstance(lower, int) 148 | assert isinstance(upper, int) 149 | except KeyError: 150 | pass 151 | 152 | for node, data in tree.nodes_iter(data=True): 153 | try: 154 | lower, upper = data['state']['user_hour'] 155 | 156 | assert lower >= 0 157 | assert upper <= 23 158 | assert isinstance(lower, int) 159 | assert isinstance(upper, int) 160 | except KeyError: 161 | pass 162 | 163 | for parent, _, data in tree.edges_iter(data=True): 164 | if tree.node[parent]['split'] == 'segment.age': 165 | try: 166 | lower, upper = data['value'] 167 | 168 | assert lower >= 0 169 | assert isinstance(lower, int) 170 | assert isinstance(upper, int) 171 | except KeyError: 172 | pass 173 | 174 | for parent, _, data in tree.edges_iter(data=True): 175 | if tree.node[parent]['split'] == 'user_hour': 176 | try: 177 | lower, upper = data['value'] 178 | 179 | assert lower >= 0 180 | assert upper <= 23 181 | assert isinstance(lower, int) 182 | assert isinstance(upper, int) 183 | except KeyError: 184 | pass 185 | 186 | 187 | def test_compound_feature(graph_compound_feature): 188 | 
tree = BonsaiTree(graph_compound_feature) 189 | 190 | assert 'every site_id=1, placement_id="a":' in tree.bonsai 191 | assert 'every site_id=1, placement_id="b":' in tree.bonsai 192 | 193 | 194 | def test_if_elif_else_switch_default(parameterized_graph): 195 | tree = BonsaiTree(parameterized_graph) 196 | 197 | line_list = tree.bonsai.split('\n') 198 | line_list = line_list[:-1] if line_list[-1] == '' else line_list 199 | line_list = [line for line in line_list if 'leaf_name: ' not in line] 200 | indent_dict = {line: len(line.split('\t')) - 1 for line in line_list} 201 | 202 | indent_list = [indent_dict[i] for i in line_list] 203 | assert all(indent != next_indent for indent, next_indent in zip(indent_list, indent_list[1:])) 204 | 205 | queue = deque([line_list]) 206 | 207 | while len(queue) > 0: 208 | sub_list = queue.pop() 209 | indent_list = [indent_dict[i] for i in sub_list] 210 | 211 | outermost_level = min(indent_list) 212 | indices = [i for i, v in enumerate(indent_list) if v == outermost_level] 213 | 214 | first = sub_list[indices[0]] 215 | last = sub_list[indices[-1]] 216 | but_last = [sub_list[i] for i in indices[0:-1]] 217 | middle = [sub_list[i] for i in indices[1:-1]] 218 | 219 | if_elif_else_level = 'if' in first and 'else:' in last and all('elif' in line for line in middle) 220 | switch_level = 'switch' in first and len(indices) == 1 221 | case_level = all('case' in line for line in but_last) and 'default' in last 222 | 223 | assert if_elif_else_level or switch_level or case_level 224 | 225 | for i, j in zip(indices, indices[1:] + [None]): 226 | new_sublist = sub_list[i + 1:j] 227 | if len(new_sublist) == 0: 228 | continue 229 | elif len(new_sublist) == 1: 230 | try: 231 | assert float(new_sublist[0].strip()) 232 | except ValueError: 233 | assert re.match( 234 | r"value: (no_bid|\d+\.\d*|compute\(\w+, (\d+\.\d*|_), " 235 | r"(\d+\.\d*|_), (\d+\.\d*|_), (\d+\.\d*|_)\))", 236 | new_sublist[0].strip() 237 | ) 238 | else: 239 | queue.append(new_sublist) 240 | 241 | 242 | def test_segment_order(graph): 243 | tree = BonsaiTree(graph) 244 | 245 | assert 'if segment[12345]' in tree.bonsai 246 | assert 'elif segment[67890]' in tree.bonsai 247 | 248 | 249 | def test_segment_order_mapping(graph): 250 | tree = BonsaiTree( 251 | graph, 252 | feature_value_order={ 253 | 'segment': [67890, 12345] 254 | } 255 | ) 256 | 257 | assert 'if segment[67890]' in tree.bonsai 258 | assert 'elif segment[12345]' in tree.bonsai 259 | 260 | 261 | def test_language_order_mapping(graph_compound_feature): 262 | tree = BonsaiTree( 263 | graph_compound_feature, 264 | feature_value_order={ 265 | 'os': ['windows', 'linux'] 266 | } 267 | ) 268 | 269 | bonsai = tree.bonsai.replace('\n', '').replace('\t', '') 270 | 271 | assert 'if os="windows":0.1000elif os="linux":0.2000' in bonsai 272 | 273 | 274 | def test_language_order_mapping_one_value(graph_compound_feature): 275 | tree = BonsaiTree( 276 | graph_compound_feature, 277 | feature_value_order={ 278 | 'os': ['windows'] 279 | } 280 | ) 281 | 282 | bonsai = tree.bonsai.replace('\n', '').replace('\t', '') 283 | 284 | assert 'if os="windows":0.1000elif os="linux":0.2000' in bonsai 285 | 286 | 287 | def test_language_segment_age_order(graph): 288 | tree = BonsaiTree(graph) 289 | 290 | assert 'if segment[12345].age' in tree.bonsai 291 | assert 'elif language' in tree.bonsai 292 | 293 | 294 | def test_feature_order_mapping(graph): 295 | tree = BonsaiTree( 296 | graph, 297 | feature_order=['language', 'segment.age'] 298 | ) 299 | 300 | # language comes first 301 | 
assert 'if language' in tree.bonsai 302 | # segment.age comes before browser 303 | assert 'elif segment[12345].age' in tree.bonsai.split('elif browser="safari"')[0] 304 | 305 | 306 | def test_no_bid_present_in_output(graph): 307 | tree = BonsaiTree(graph) 308 | text = tree.bonsai 309 | 310 | assert 'value: no_bid' in text 311 | 312 | 313 | def test_switch_non_switch_range(small_graph): 314 | graph = small_graph 315 | tree = BonsaiTree(graph) 316 | 317 | assert 'switch user_hour' in tree.bonsai 318 | assert 'switch user_day' in tree.bonsai 319 | assert 'case ( .. 10)' in tree.bonsai 320 | assert 'case (11 .. 15)' in tree.bonsai 321 | assert 'if user_day range (3, 6)' in tree.bonsai 322 | 323 | 324 | def test_get_range_statement(): 325 | bonsai = BonsaiTree() 326 | get_range_statement = bonsai._get_range_statement 327 | values_dict = {1: (None, 1), 328 | 2: (1.54389, None), 329 | 3: (0.9, 2.1), 330 | 4: (None, None), 331 | 5: (0.99999, 1.4), 332 | 6: (-float('inf'), 1), 333 | 7: (1.3, float('inf')) 334 | } 335 | feature = 'some_feature' 336 | 337 | assert get_range_statement(values_dict[1], feature) == 'some_feature <= 1' 338 | assert get_range_statement(values_dict[2], feature) == 'some_feature >= 1.5439' 339 | assert get_range_statement(values_dict[3], feature) == 'some_feature range (0.9, 2.1)' 340 | with pytest.raises(ValueError): 341 | get_range_statement(values_dict[4], feature) 342 | assert get_range_statement(values_dict[5], feature) == 'some_feature range (1.0, 1.4)' 343 | assert get_range_statement(values_dict[6], feature) == 'some_feature <= 1' 344 | assert get_range_statement(values_dict[7], feature) == 'some_feature >= 1.3' 345 | 346 | 347 | def test_missing_values(missing_values_graph): 348 | graph = missing_values_graph 349 | 350 | feature_value_order = { 351 | 'segment': [1, 2], 352 | 'os': [("linux", "osx"), ("linux",)], 353 | 'segment.age': [(0, 10), (10, float('inf'))] 354 | } 355 | 356 | feature_order = ['segment', 'os'] 357 | 358 | tree = BonsaiTree( 359 | graph, 360 | feature_order=feature_order, 361 | feature_value_order=feature_value_order, 362 | absence_values={'segment': (1, 2)} 363 | ) 364 | 365 | expected_tree = ''' 366 | if segment[1]: 367 | \tswitch segment[1].age: 368 | \t\tcase (0 .. 10): 369 | \t\t\t0.1000 370 | \t\tcase (10 .. ): 371 | \t\t\t0.1000 372 | \t\tdefault: 373 | \t\t\t0.1000 374 | elif segment[2]: 375 | \tif os in ("linux","osx"): 376 | \t\t0.1000 377 | \telif os absent: 378 | \t\t0.1000 379 | \telse: 380 | \t\t0.1000 381 | elif every not segment[1], not segment[2]: 382 | \tif os in ("linux"): 383 | \t\t0.1000 384 | \telse: 385 | \t\t0.1000 386 | else: 387 | \t0.1000 388 | '''.replace(8 * ' ', '').strip().lstrip('\n') + '\n' 389 | 390 | assert tree.bonsai == expected_tree 391 | 392 | 393 | def test_missing_values_short(missing_values_graph_short): 394 | graph = missing_values_graph_short 395 | 396 | feature_value_order = { 397 | 'segment': [1, 2], 398 | 'segment.age': [(0, 10)] 399 | } 400 | 401 | feature_order = ['segment'] 402 | 403 | tree = BonsaiTree( 404 | graph, 405 | feature_order=feature_order, 406 | feature_value_order=feature_value_order, 407 | absence_values={'segment': (1, 2)} 408 | ) 409 | 410 | expected_tree = ''' 411 | if segment[1]: 412 | \tswitch segment[1].age: 413 | \t\tcase (0 .. 
10): 414 | \t\t\t0.1000 415 | \t\tdefault: 416 | \t\t\t0.1000 417 | elif segment[2]: 418 | \t0.1000 419 | elif every not segment[1], not segment[2]: 420 | \t0.1000 421 | else: 422 | \t0.1000 423 | '''.replace(8 * ' ', '').strip().lstrip('\n') + '\n' 424 | 425 | assert tree.bonsai == expected_tree 426 | 427 | 428 | def test_negated_values(negated_values_graph): 429 | graph = negated_values_graph 430 | 431 | tree = BonsaiTree( 432 | graph, 433 | feature_order=[('segment', 'segment'), ('segment', 'segment', 'segment')], 434 | feature_value_order={ 435 | ('segment', 'segment'): [(1, 10), (1, 2)] 436 | } 437 | ) 438 | 439 | expected_conditions = [ 440 | 'every segment[1], segment[2]', 441 | 'every segment[1], not segment[2], segment[3]', 442 | 'any segment[1], segment[2]', 443 | 'any segment[1], not segment[10]' 444 | ] 445 | 446 | assert all([e in tree.bonsai for e in expected_conditions]) 447 | 448 | indexes = [tree.bonsai.index(e) for e in expected_conditions] 449 | assert tree.bonsai.index('any segment[1], not segment[10]') == min(indexes) 450 | 451 | 452 | def test_feature_slicer(unsliced_graph, small_unsliced_graph): 453 | tree = BonsaiTree( 454 | unsliced_graph, 455 | slice_features=('slice_feature',), 456 | slice_feature_values={'slice_feature': 'good'} 457 | ) 458 | 459 | assert all(['slice_feature' not in tree.node[n].get('state', set()) for n in tree.node]) 460 | assert all(['slice_feature' not in tree.node[n].get('split', dict()).values() for n in tree.node]) 461 | 462 | tree = BonsaiTree( 463 | small_unsliced_graph, 464 | slice_features=('slice_feature',), 465 | slice_feature_values={'slice_feature': 'good'} 466 | ) 467 | 468 | assert all(['slice_feature' not in tree.node[n].get('state', set()) for n in tree.node]) 469 | assert all(['slice_feature' not in tree.node[n].get('split', dict()).values() for n in tree.node]) 470 | assert tree.node[0]['output'] == 5. 471 | 472 | 473 | def test_feature_slicer_single_wrong_slice_feature_value(small_unsliced_graph_single_slice_feature_value): 474 | tree = BonsaiTree( 475 | small_unsliced_graph_single_slice_feature_value, 476 | slice_features=('slice_feature',), 477 | slice_feature_values={'slice_feature': 'value'} 478 | ) 479 | 480 | assert all(['slice_feature' not in tree.node[n].get('state', set()) for n in tree.node]) 481 | assert all(['slice_feature' not in tree.node[n].get('split', dict()).values() for n in tree.node]) 482 | assert tree.node[0]['output'] == 5. 483 | 484 | 485 | def test_feature_slicer_single_correct_slice_feature_value(small_unsliced_graph_single_slice_feature_value): 486 | tree = BonsaiTree( 487 | small_unsliced_graph_single_slice_feature_value, 488 | slice_features=('slice_feature',), 489 | slice_feature_values={'slice_feature': 'other_value'} 490 | ) 491 | 492 | assert all(['slice_feature' not in tree.node[n].get('state', set()) for n in tree.node]) 493 | assert all(['slice_feature' not in tree.node[n].get('split', dict()).values() for n in tree.node]) 494 | assert tree.node[0]['output'] == 1. 
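    # Slicing on a value that is absent from this graph ('other_value' instead of the
    # graph's 'value') is expected to collapse the root to the default leaf's output of 1.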
495 | 
496 | 
497 | def test_feature_slicer_mixed_split(small_unsliced_graph_mixed_split):
498 |     tree = BonsaiTree(
499 |         small_unsliced_graph_mixed_split,
500 |         slice_features=('slice_feature',),
501 |         slice_feature_values={'slice_feature': 'good'}
502 |     )
503 | 
504 |     assert all(['slice_feature' not in tree.node[n].get('state', set()) for n in tree.node])
505 |     assert all(['slice_feature' not in tree.node[n].get('split', dict()).values() for n in tree.node])
506 |     assert 'output' not in tree.node[0]
507 |     assert tree.node['default_one']['output'] == 5.
508 | 
--------------------------------------------------------------------------------
/bonspy/tests/test_features.py:
--------------------------------------------------------------------------------
1 | from bonspy.features import _apply_operations
2 | 
3 | 
4 | def test_apply_operations_domain():
5 |     value = _apply_operations('domain', 'www.test.com')
6 | 
7 |     assert value == 'test.com'
8 | 
9 | 
10 | def test_apply_operations_other_feature():
11 |     value = _apply_operations('other_feature', 'www.test.com')
12 | 
13 |     assert value == 'www.test.com'
14 | 
15 | 
16 | def test_apply_operations_segment():
17 |     value = _apply_operations('segment', 1)
18 | 
19 |     assert value == 1
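# Note: taken together, these tests suggest _apply_operations normalises values
# per feature -- 'domain' values appear to have a leading 'www.' stripped, while
# values of other features (including numeric 'segment' ids) pass through unchanged.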
--------------------------------------------------------------------------------
/bonspy/tests/test_graph_builder.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | from unittest.mock import Mock
3 | from random import random
4 | from bonspy.graph_builder import GraphBuilder, ConstantBidder, EstimatorBidder
5 | 
6 | 
7 | def test_graph_builder_small(small_data_features_and_file):
8 |     features, path = small_data_features_and_file
9 |     builder = GraphBuilder(path, features, lazy_formatters=(('os', str), ('city', str)))
10 |     graph = builder.get_graph()
11 |     leaves = [n for n in graph.node if graph.out_degree(n) == 0]
12 |     normal_leaves = [n for n in leaves if graph.node[n].get('is_leaf')]
13 |     default_leaves = [n for n in leaves if graph.node[n].get('is_default_leaf')]
14 | 
15 |     assert len(normal_leaves) + len(default_leaves) == len(leaves)
16 |     assert len(normal_leaves) == 6
17 |     assert len(default_leaves) == 4
18 |     assert len(graph.node) == len(leaves) + 4
19 | 
20 | 
21 | def test_graph_builder_lazy_formatters(small_data_features_and_file_numeric):
22 |     features, path = small_data_features_and_file_numeric
23 |     builder = GraphBuilder(path, features, lazy_formatters=(('city', str), ('user_day', int)))
24 |     graph = builder.get_graph()
25 | 
26 |     for node in graph.node:
27 |         state = graph.node[node]['state']
28 |         for feature, feature_value in state.items():
29 |             assert isinstance(feature_value, builder.lazy_formatters[feature]) or feature_value is None
30 | 
31 | 
32 | def test_graph_builder(data_features_and_file):
33 |     features, path = data_features_and_file
34 |     builder = GraphBuilder(path, features)
35 |     graph = builder.get_graph()
36 |     leaves = [n for n in graph.node if graph.out_degree(n) == 0]
37 | 
38 |     assert all([graph.node[n].get('is_leaf', graph.node[n].get('is_default_leaf', False)) for n in leaves])
39 | 
40 | 
41 | def test_graph_builder_functions(data_features_and_file):
42 |     features, path = data_features_and_file
43 | 
44 |     def events_counter(node_dict, *args):
45 |         try:
46 |             node_dict['events'] += 1
47 |         except KeyError:
48 |             node_dict['events'] = 1
49 |         return node_dict
50 | 
51 |     builder = GraphBuilder(path, features, functions=(events_counter,))
52 |     graph = builder.get_graph()
53 | 
54 |     normal_leaves = [n for n in graph.node if graph.node[n].get('is_leaf')]
55 |     file = gzip.GzipFile(path)
56 |     headers = next(file)  # NOQA
57 |     events = sum([1 for _ in file])
58 | 
59 |     assert sum([graph.node[n]['events'] for n in normal_leaves]) == graph.node[0]['events'] == events
60 | 
61 | 
62 | def test_constant_bidder(data_features_and_file):
63 |     features, path = data_features_and_file
64 |     builder = GraphBuilder(path, features)
65 |     graph = builder.get_graph()
66 | 
67 |     bidder = ConstantBidder(bid=1.)
68 |     graph = bidder.compute_bids(graph)
69 |     leaves = [n for n in graph.node if graph.out_degree(n) == 0]
70 | 
71 |     assert all([graph.node[n]['output'] == 1. for n in leaves])
72 | 
73 | 
74 | def test_estimator_bidder(data_features_and_file):
75 |     features, path = data_features_and_file
76 |     builder = GraphBuilder(path, features)
77 |     graph = builder.get_graph()
78 | 
79 |     rate_estimator = Mock()
80 |     rate_estimator.dict_vectorizer = lambda x, **kwargs: x
81 |     rate_estimator.predict = lambda x: 0.5 * (1 + random())
82 | 
83 |     bidder = EstimatorBidder(base_bid=5., estimators=(rate_estimator, ))
84 |     graph = bidder.compute_bids(graph)
85 |     leaves = [n for n in graph.node if graph.out_degree(n) == 0]
86 | 
87 |     assert all([2.5 <= graph.node[n]['output'] <= 5. for n in leaves])
88 | 
--------------------------------------------------------------------------------
/bonspy/utils.py:
--------------------------------------------------------------------------------
1 | def compare_vectors(x, y):
2 |     for x_i, y_i in zip(x, y):
3 |         comparison = _compare(x_i, y_i)
4 |         if comparison == 0:
5 |             continue
6 |         else:
7 |             return comparison
8 |     return 0
9 | 
10 | 
11 | def _compare(x, y):
12 |     if x is not None and y is not None:
13 |         return int(x > y) - int(x < y)
14 |     elif x is not None and y is None:
15 |         return -1
16 |     elif x is None and y is not None:
17 |         return 1
18 |     else:
19 |         return 0
20 | 
21 | 
22 | def is_absent_value(value):
23 |     return value in (None, '', (), [])
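# Illustrative usage (hypothetical values): compare_vectors orders vectors
# lexicographically and treats None as greater than any concrete value, e.g.
#
#     compare_vectors([1, 2], [1, 3])     # -> -1, since 2 < 3
#     compare_vectors([1, None], [1, 3])  # ->  1, since None sorts after 3
#     compare_vectors([1, 2], [1, 2])     # ->  0, equal vectors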
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | networkx==1.11
--------------------------------------------------------------------------------
/requirements_test.txt:
--------------------------------------------------------------------------------
1 | pytest>=3.0.1
2 | pytest-flake8>=0.6
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E226,E302,E41
3 | max-line-length = 120
4 | exclude = tests/*,__init__.py
5 | max-complexity = 20
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | try:
4 |     from setuptools import setup
5 | except ImportError:
6 |     from distutils.core import setup
7 | 
8 | 
9 | setup(
10 |     name='bonspy',
11 |     version='1.2.9',
12 |     description='Library that converts bidding trees to the AppNexus Bonsai language.',
13 |     author='Alexander Volkmann, Georg Walther',
14 |     author_email='contact@markovian.com',
15 |     packages=['bonspy'],
16 |     package_dir={'bonspy': 'bonspy'},
17 |     package_data={'bonspy': ['tests/data/*.csv.gz']},
18 |     url='https://github.com/markovianhq/bonspy',
19 |     download_url='https://github.com/markovianhq/bonspy/tarball/master',
20 |     classifiers=[
21 |         'Development Status :: 3 - Alpha',
22 |         'License :: OSI Approved :: BSD License',
23 |         'Programming Language :: Python :: 3'
24 |     ]
25 | )
26 | 
--------------------------------------------------------------------------------