├── .formatter.exs ├── .gitignore ├── LICENSE ├── README.md ├── lib ├── mockingjay.ex └── mockingjay │ ├── strategies │ ├── gemm.ex │ ├── perfect_tree_traversal.ex │ └── tree_traversal.ex │ ├── strategy.ex │ └── tree.ex ├── mix.exs ├── mix.lock └── test ├── mockingjay ├── gemm_test.exs └── tree_test.exs ├── mockingjay_test.exs ├── support └── model.ex └── test_helper.exs /.formatter.exs: -------------------------------------------------------------------------------- 1 | # Used by "mix format" 2 | [ 3 | inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"] 4 | ] 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # The directory Mix will write compiled artifacts to. 2 | /_build/ 3 | 4 | # If you run "mix test --cover", coverage assets end up here. 5 | /cover/ 6 | 7 | # The directory Mix downloads your dependencies sources to. 8 | /deps/ 9 | 10 | # Where third-party dependencies like ExDoc output generated docs. 11 | /doc/ 12 | 13 | # Ignore .fetch files in case you like to edit your project deps locally. 14 | /.fetch 15 | 16 | # If the VM crashes, it generates a dump, let's ignore it too. 17 | erl_crash.dump 18 | 19 | # Also ignore archive artifacts (built via "mix archive.build"). 20 | *.ez 21 | 22 | # Ignore package tarball (built via "mix hex.build"). 23 | mockingjay-*.tar 24 | 25 | # Temporary files, for example, from tests. 26 | /tmp/ 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mockingjay 2 | 3 | [![Documentation](https://img.shields.io/badge/-Documentation-blueviolet)](https://hexdocs.pm/mockingjay) 4 | 5 | Implementation of Microsoft's [Hummingbird](https://github.com/microsoft/hummingbird) library for converting trained Decision Tree 6 | models into tensor computations. 
7 | 8 | ## How to Use 9 | 10 | Implement the `DecisionTree` protocol for any data source you would like to compile. Then you can use `Mockingjay.convert/1` 11 | to generate an `Nx.Defn` prediction function that makes inferences. The output of `convert` will be a function with the signature 12 | `fn x -> predict(x)`. The three strategies are GEMM, TreeTraversal, and PerfectTreeTraversal. You can specify the strategy using the 13 | `:strategy` option in `convert` or use a heuristic strategy by default. The heuristic used is generally: 14 | 15 | * GEMM: Shallow trees (depth <= 3) 16 | * PerfectTreeTraversal: Tall trees where depth <= 10 17 | * TreeTraversal: Tall trees unfit for PerfectTreeTraversal (depth > 10) 18 | 19 | ## Installation 20 | 21 | ```elixir 22 | def deps do 23 | [ 24 | {:mockingjay, "~> 0.1"} 25 | ] 26 | end 27 | ``` 28 | -------------------------------------------------------------------------------- /lib/mockingjay.ex: -------------------------------------------------------------------------------- 1 | defmodule Mockingjay do 2 | @moduledoc """ 3 | Mockingjay is a library for compiling trained decision trees to `Nx` `defn` functions. 4 | 5 | It is based on the paper [Taming Model Serving Complexity, Performance and Cost: 6 | A Compilation to Tensor Computations Approach](https://scnakandala.github.io/papers/TR_2020_Hummingbird.pdf) 7 | and the accompanying [Hummingbird library](https://github.com/microsoft/hummingbird) from Microsoft. 8 | 9 | ## Protocol 10 | 11 | Mockingjay can be used with any model that implements the `Mockingjay.DecisionTree` protocol. For an example implementation, 12 | this protocol is implemented by `EXGBoost` in its `EXGBoost.Compile` module. This protocol is used to extract the trees from the model 13 | and to get the number of classes and features in the model. 
14 | 15 | ## Strategies 16 | 17 | Mockingjay supports three strategies for compiling decision trees: `:gemm`, `:tree_traversal`, and `:perfect_tree_traversal`, 18 | or `:auto` to select using heuristics. The `:auto` strategy will select the best strategy based on the depth of the tree 19 | according to the following rules: 20 | 21 | * GEMM: Shallow Trees (<=3) 22 | 23 | * PerfectTreeTraversal: Tall trees where depth <= 10 24 | 25 | * TreeTraversal: Tall trees unfit for PerfectTreeTraversal (depth > 10) 26 | 27 | """ 28 | 29 | @doc """ 30 | Compiles a model that implements the `Mockingjay.DecisionTree` protocol to a `defn` function. 31 | 32 | ## Options 33 | 34 | * `:reorder_trees` - whether to reorder the trees in the model to optimize inference accuracy. Defaults to `true`. This assumes 35 | that trees are ordered such that they classify classes in order 0..n then repeat (e.g. a cyclic class prediction). If this is not 36 | the case, set this to `false` and implement custom ordering in the DecisionTree protocol implementation. 37 | 38 | * `:post_transform` - the post transform to use. Must be one of :none, :softmax, :sigmoid, :log_softmax, :log_sigmoid or :linear, 39 | or a custom function that receives the aggregation results. Defaults to sigmoid if n_classes <= 2, otherwise softmax. 
40 | """ 41 | def convert(data, opts \\ []) do 42 | {strategy, opts} = Keyword.pop(opts, :strategy, :auto) 43 | 44 | strategy = 45 | case strategy do 46 | :gemm -> 47 | Mockingjay.Strategies.GEMM 48 | 49 | :tree_traversal -> 50 | Mockingjay.Strategies.TreeTraversal 51 | 52 | :perfect_tree_traversal -> 53 | Mockingjay.Strategies.PerfectTreeTraversal 54 | 55 | :auto -> 56 | Mockingjay.Strategy.get_strategy(data, opts) 57 | 58 | _ -> 59 | raise ArgumentError, 60 | "strategy must be one of :gemm, :tree_traversal, :perfect_tree_traversal, or :auto" 61 | end 62 | 63 | {post_transform, opts} = Keyword.pop(opts, :post_transform, nil) 64 | state = strategy.init(data, opts) 65 | 66 | fn data -> 67 | result = strategy.forward(data, state) 68 | {_, n_trees, n_classes} = Nx.shape(result) 69 | 70 | result 71 | |> aggregate(n_trees, n_classes) 72 | |> post_transform(post_transform, n_classes) 73 | end 74 | end 75 | 76 | defp aggregate(x, n_trees, n_classes) do 77 | cond do 78 | n_classes > 1 and n_trees > 1 -> 79 | n_gbdt_classes = if n_classes > 2, do: n_classes, else: 1 80 | n_trees_per_class = trunc(n_trees / n_gbdt_classes) 81 | 82 | x 83 | |> Nx.reshape({:auto, n_gbdt_classes, n_trees_per_class}) 84 | |> Nx.sum(axes: [2]) 85 | 86 | n_classes > 1 and n_trees == 1 -> 87 | Nx.squeeze(x, axes: [1]) 88 | 89 | true -> 90 | raise "unknown output type from strategy" 91 | end 92 | end 93 | 94 | defp post_transform(x, post_transform, n_classes) do 95 | fun = post_transform_to_fun(post_transform || infer_post_transform(n_classes)) 96 | fun.(x) 97 | end 98 | 99 | defp infer_post_transform(n_classes) when n_classes <= 2, do: :sigmoid 100 | defp infer_post_transform(_), do: :softmax 101 | 102 | defp post_transform_to_fun(:none) do 103 | &Function.identity/1 104 | end 105 | 106 | defp post_transform_to_fun(post_transform) 107 | when post_transform in [:softmax, :linear, :sigmoid, :log_softmax, :log_sigmoid] do 108 | &apply(Axon.Activations, post_transform, [&1]) 109 | end 110 | 111 | defp 
post_transform_to_fun(post_transform) when is_function(post_transform, 1) do 112 | post_transform 113 | end 114 | 115 | defp post_transform_to_fun(post_transform) do 116 | raise ArgumentError, 117 | "invalid post_transform: #{inspect(post_transform)} -- must be one of :none, :softmax, :sigmoid, :log_softmax, :log_sigmoid or :linear -- or a custom function of arity 1" 118 | end 119 | end 120 | -------------------------------------------------------------------------------- /lib/mockingjay/strategies/gemm.ex: -------------------------------------------------------------------------------- 1 | defmodule Mockingjay.Strategies.GEMM do 2 | @moduledoc false 3 | import Nx.Defn 4 | 5 | alias Mockingjay.Tree 6 | alias Mockingjay.DecisionTree 7 | 8 | @behaviour Mockingjay.Strategy 9 | 10 | @impl true 11 | def init(ensemble, opts \\ []) do 12 | opts = Keyword.validate!(opts, reorder_trees: true) 13 | trees = DecisionTree.trees(ensemble) 14 | 15 | num_features = DecisionTree.num_features(ensemble) 16 | condition = DecisionTree.condition(ensemble) 17 | 18 | # Overall number of classes for classification, 1 for regression 19 | n_classes = DecisionTree.num_classes(ensemble) 20 | 21 | trees = 22 | if opts[:reorder_trees] do 23 | for j <- 0..(n_classes - 1), 24 | i <- 0..(Integer.floor_div(length(trees), n_classes) - 1), 25 | do: Enum.at(trees, i * n_classes + j) 26 | else 27 | trees 28 | end 29 | 30 | # Number of classes each weak learner can predict 31 | # We infer from the shape of a leaf's :value key 32 | n_weak_learner_classes = 33 | trees 34 | |> hd() 35 | |> Tree.get_decision_values() 36 | |> hd() 37 | |> case do 38 | value when is_list(value) -> 39 | length(value) 40 | 41 | _value -> 42 | 1 43 | end 44 | 45 | {max_decision_nodes, max_leaf_nodes} = 46 | Enum.reduce(trees, {0, 0}, fn tree, {h1, h2} -> 47 | {max(h1, length(Tree.get_decision_nodes(tree))), 48 | max(h2, length(Tree.get_leaf_nodes(tree)))} 49 | end) 50 | 51 | n_trees = length(trees) 52 | 53 | {mat_A, mat_B} = 
generate_matrices_AB(trees, num_features, max_decision_nodes) 54 | mat_C = generate_matrix_C(trees, max_decision_nodes, max_leaf_nodes) 55 | {mat_D, mat_E} = generate_matrices_DE(trees, max_leaf_nodes, n_weak_learner_classes) 56 | 57 | arg = %{ 58 | mat_A: mat_A, 59 | mat_B: mat_B, 60 | mat_C: mat_C, 61 | mat_D: mat_D, 62 | mat_E: mat_E 63 | } 64 | 65 | opts = [ 66 | condition: Mockingjay.Strategy.cond_to_fun(condition), 67 | n_trees: n_trees, 68 | max_decision_nodes: max_decision_nodes, 69 | max_leaf_nodes: max_leaf_nodes, 70 | n_weak_learner_classes: n_weak_learner_classes, 71 | n_classes: n_classes 72 | ] 73 | 74 | {arg, opts} 75 | end 76 | 77 | @impl true 78 | deftransform forward(x, {arg, opts}) do 79 | opts = 80 | Keyword.validate!(opts, [ 81 | :condition, 82 | :n_trees, 83 | :n_classes, 84 | :max_decision_nodes, 85 | :max_leaf_nodes, 86 | :n_weak_learner_classes, 87 | :custom_forward 88 | ]) 89 | 90 | _forward(x, arg, opts) 91 | end 92 | 93 | defnp _forward(x, arg, opts \\ []) do 94 | %{mat_A: mat_A, mat_B: mat_B, mat_C: mat_C, mat_D: mat_D, mat_E: mat_E} = arg 95 | 96 | condition = opts[:condition] 97 | n_trees = opts[:n_trees] 98 | n_classes = opts[:n_classes] 99 | max_decision_nodes = opts[:max_decision_nodes] 100 | max_leaf_nodes = opts[:max_leaf_nodes] 101 | n_weak_learner_classes = opts[:n_weak_learner_classes] 102 | 103 | mat_A 104 | |> Nx.dot([1], x, [1]) 105 | |> condition.(mat_B) 106 | |> Nx.reshape({n_trees, max_decision_nodes, :auto}) 107 | |> then(&Nx.dot(mat_C, [2], [0], &1, [1], [0])) 108 | |> Nx.reshape({n_trees * max_leaf_nodes, :auto}) 109 | |> Nx.equal(mat_D) 110 | |> Nx.reshape({n_trees, max_leaf_nodes, :auto}) 111 | |> then(&Nx.dot(mat_E, [2], [0], &1, [1], [0])) 112 | |> Nx.reshape({n_trees, n_weak_learner_classes, :auto}) 113 | |> Nx.transpose() 114 | |> Nx.reshape({:auto, n_trees, n_classes}) 115 | end 116 | 117 | # Leaves are ordered as DFS rather than BFS that internal nodes are 118 | # TO-DO: make TCOptimizable 119 | defp 
get_leaf_left_depths(root) do 120 | _get_leaf_left_depths(root, 0) 121 | end 122 | 123 | defp _get_leaf_left_depths(root, depth) do 124 | case root do 125 | %{left: nil, right: nil} -> 126 | [depth] 127 | 128 | %{left: left, right: right} -> 129 | _get_leaf_left_depths(left, depth + 1) ++ _get_leaf_left_depths(right, depth) 130 | end 131 | end 132 | 133 | defp generate_matrices_AB(trees, num_features, max_decision_nodes) do 134 | n_trees = length(trees) 135 | 136 | {indices_list, updates_list} = 137 | trees 138 | |> Enum.with_index() 139 | |> Enum.flat_map(fn {tree, tree_index} -> 140 | Enum.with_index(Tree.get_decision_values(tree), fn value, node_index -> 141 | {[tree_index, node_index, value.feature], value.threshold} 142 | end) 143 | end) 144 | |> Enum.unzip() 145 | 146 | a_indices = Nx.tensor(indices_list) 147 | b_indices = a_indices[[.., 0..1]] 148 | 149 | a_updates = Nx.broadcast(1, {Nx.axis_size(a_indices, 0)}) 150 | b_updates = Nx.tensor(updates_list) 151 | 152 | a_zeros = Nx.broadcast(0, {n_trees, max_decision_nodes, num_features}) 153 | 154 | b_zeros = Nx.slice_along_axis(a_zeros, 0, 1, axis: -1) |> Nx.squeeze(axes: [-1]) 155 | 156 | a = Nx.indexed_put(a_zeros, a_indices, a_updates) 157 | 158 | b = Nx.indexed_put(b_zeros, b_indices, b_updates) 159 | 160 | num_rows = n_trees * max_decision_nodes 161 | {Nx.reshape(a, {num_rows, num_features}), Nx.reshape(b, {num_rows, 1})} 162 | end 163 | 164 | defp generate_matrix_C(trees, max_decision_nodes, max_leaf_nodes) do 165 | n_trees = length(trees) 166 | 167 | child_matrix = 168 | Enum.flat_map(Enum.with_index(trees), fn {tree, tree_index} -> 169 | Enum.flat_map(Enum.with_index(Tree.get_decision_nodes(tree)), fn {internal_node, 170 | internal_index} -> 171 | Enum.with_index(Tree.get_leaf_nodes(tree), fn leaf_node, leaf_index -> 172 | truth_value = 173 | cond do 174 | Tree.child?(internal_node.left, leaf_node.id) -> 1 175 | Tree.child?(internal_node.right, leaf_node.id) -> -1 176 | true -> 0 177 | end 178 | 179 | 
[tree_index, leaf_index, internal_index, truth_value] 180 | end) 181 | end) 182 | end) 183 | |> Nx.tensor() 184 | 185 | # Gets the tensor of 'truth values' 186 | axis_size = Nx.axis_size(child_matrix, -1) 187 | updates = Nx.transpose(child_matrix)[-1] 188 | indices = Nx.slice_along_axis(child_matrix, 0, axis_size - 1, axis: -1) 189 | 190 | Nx.indexed_put( 191 | Nx.broadcast(0, {n_trees, max_leaf_nodes, max_decision_nodes}), 192 | indices, 193 | updates 194 | ) 195 | end 196 | 197 | defp generate_matrices_DE(trees, max_leaf_nodes, n_weak_learner_classes) do 198 | n_trees = length(trees) 199 | 200 | {indices_list, updates_list} = 201 | trees 202 | |> Enum.with_index() 203 | |> Enum.flat_map(fn {tree, index} -> 204 | Enum.with_index(Tree.get_leaf_nodes(tree), fn node, node_index -> 205 | if n_weak_learner_classes == 1 do 206 | {[index, 0, node_index], node.value} 207 | else 208 | {[index, trunc(node.value), node_index], 1} 209 | end 210 | end) 211 | end) 212 | |> Enum.unzip() 213 | 214 | e_indices = Nx.tensor(indices_list) 215 | 216 | d_indices = Nx.take(e_indices, Nx.tensor([0, 2]), axis: 1) 217 | 218 | d_updates = trees |> Enum.flat_map(&get_leaf_left_depths/1) |> Nx.tensor() 219 | d_zero = Nx.broadcast(0, {n_trees, max_leaf_nodes}) 220 | 221 | d = Nx.indexed_put(d_zero, d_indices, d_updates) 222 | 223 | e_updates = Nx.tensor(updates_list) 224 | 225 | e_zero = Nx.broadcast(0, {n_trees, n_weak_learner_classes, max_leaf_nodes}) 226 | 227 | e = Nx.indexed_put(e_zero, e_indices, e_updates) 228 | 229 | {Nx.reshape(d, {:auto, 1}), e} 230 | end 231 | end 232 | -------------------------------------------------------------------------------- /lib/mockingjay/strategies/perfect_tree_traversal.ex: -------------------------------------------------------------------------------- 1 | defmodule Mockingjay.Strategies.PerfectTreeTraversal do 2 | @moduledoc false 3 | alias Mockingjay.Tree 4 | alias Mockingjay.DecisionTree 5 | import Nx.Defn 6 | @behaviour Mockingjay.Strategy 7 | 8 | # 
Derived from Binary Tree structure 9 | @factor 2 10 | 11 | @impl true 12 | def init(data, opts \\ []) do 13 | opts = Keyword.validate!(opts, reorder_trees: true) 14 | trees = DecisionTree.trees(data) 15 | condition = DecisionTree.condition(data) 16 | n_classes = DecisionTree.num_classes(data) 17 | num_trees = length(trees) 18 | 19 | trees = 20 | if opts[:reorder_trees] do 21 | for j <- 0..(n_classes - 1), 22 | i <- 0..(Integer.floor_div(length(trees), n_classes) - 1), 23 | do: Enum.at(trees, i * n_classes + j) 24 | else 25 | trees 26 | end 27 | 28 | # Number of classes each weak learner can predict 29 | # We infer from the shape of a leaf's :value key 30 | n_weak_learner_classes = 31 | trees 32 | |> hd() 33 | |> Tree.get_decision_values() 34 | |> hd() 35 | |> case do 36 | value when is_list(value) -> 37 | length(value) 38 | 39 | _value -> 40 | 1 41 | end 42 | 43 | [h | t] = trees 44 | 45 | max_tree_depth = 46 | Enum.reduce(t, Tree.depth(h), fn tree, acc -> 47 | max(acc, Tree.depth(tree)) 48 | end) 49 | 50 | perfect_trees = Enum.map(trees, &make_tree_perfect(&1, 0, max_tree_depth)) 51 | 52 | {features, tt_tv} = 53 | Enum.map(perfect_trees, fn tree -> 54 | {tf, tt, tv} = 55 | Enum.reduce(Tree.bfs(tree), {[], [], []}, fn 56 | node, {tree_features, tree_thresholds, tree_values} -> 57 | case node do 58 | %Tree{left: nil, right: nil} -> 59 | tree_values = [[node.value] | tree_values] 60 | {tree_features, tree_thresholds, tree_values} 61 | 62 | %Tree{left: _left, right: _right} -> 63 | tree_features = [node.value.feature | tree_features] 64 | tree_thresholds = [node.value.threshold | tree_thresholds] 65 | {tree_features, tree_thresholds, tree_values} 66 | end 67 | end) 68 | 69 | tf = Enum.reverse(tf) 70 | tt = Enum.reverse(tt) 71 | tv = Enum.reverse(tv) 72 | 73 | {tf, {tt, tv}} 74 | end) 75 | |> Enum.unzip() 76 | 77 | {thresholds, values} = Enum.unzip(tt_tv) 78 | 79 | # shape of {num_trees, 2 ** max_tree_depth - 1} 80 | features = 81 | features 82 | |> Nx.tensor() 83 | |> 
Nx.reshape({num_trees, @factor ** max_tree_depth - 1}) 84 | 85 | # shape of {num_trees, 2 ** max_tree_depth - 1} 86 | thresholds = 87 | thresholds 88 | |> Nx.tensor() 89 | |> Nx.reshape({num_trees, @factor ** max_tree_depth - 1}) 90 | 91 | # shape of {num_trees, 2 ** max_tree_depth} 92 | values = 93 | values 94 | |> Nx.tensor() 95 | |> Nx.reshape({:auto, n_weak_learner_classes}) 96 | 97 | root_features = Nx.flatten(features[[.., 0]]) 98 | 99 | root_thresholds = Nx.flatten(thresholds[[.., 0]]) 100 | 101 | {features, thresholds} = 102 | Enum.reduce(1..(max_tree_depth - 1), {[], []}, fn depth, {all_nodes, all_biases} -> 103 | start = @factor ** depth - 1 104 | stop = @factor ** (depth + 1) - 2 105 | 106 | n = Nx.flatten(features[[.., start..stop]]) 107 | 108 | b = Nx.flatten(thresholds[[.., start..stop]]) 109 | 110 | {[n | all_nodes], [b | all_biases]} 111 | end) 112 | 113 | features = Enum.reverse(features) |> List.to_tuple() 114 | thresholds = Enum.reverse(thresholds) |> List.to_tuple() 115 | 116 | nt = @factor * num_trees 117 | 118 | indices = 0..(nt - 1)//2 |> Enum.into([]) |> Nx.tensor(type: :s64) 119 | 120 | [ 121 | indices: indices, 122 | num_trees: num_trees, 123 | max_tree_depth: max_tree_depth, 124 | features: features, 125 | thresholds: thresholds, 126 | root_features: root_features, 127 | root_thresholds: root_thresholds, 128 | values: values, 129 | condition: Mockingjay.Strategy.cond_to_fun(condition), 130 | n_classes: n_classes 131 | ] 132 | end 133 | 134 | @impl true 135 | def forward(x, opts \\ []) do 136 | opts = 137 | Keyword.validate!(opts, [ 138 | :custom_forward, 139 | :root_features, 140 | :root_thresholds, 141 | :condition, 142 | :indices, 143 | :num_trees, 144 | :n_classes, 145 | :thresholds, 146 | :features, 147 | :max_tree_depth, 148 | :values 149 | ]) 150 | 151 | _forward( 152 | x, 153 | opts[:root_features], 154 | opts[:root_thresholds], 155 | opts[:features], 156 | opts[:thresholds], 157 | opts[:values], 158 | opts[:indices], 159 | 
Keyword.take(opts, [:num_trees, :condition, :n_classes]) 160 | ) 161 | end 162 | 163 | defnp _forward( 164 | x, 165 | root_features, 166 | root_thresholds, 167 | features, 168 | thresholds, 169 | values, 170 | indices, 171 | opts \\ [] 172 | ) do 173 | prev_indices = 174 | x 175 | |> Nx.take(root_features, axis: 1) 176 | |> opts[:condition].(root_thresholds) 177 | |> Nx.add(indices) 178 | |> Nx.reshape({:auto}) 179 | |> forward_reduce_features(x, features, thresholds, opts) 180 | 181 | Nx.take(values, prev_indices) 182 | |> Nx.reshape({:auto, opts[:num_trees], opts[:n_classes]}) 183 | end 184 | 185 | deftransformp forward_reduce_features(prev_indices, x, features, thresholds, opts \\ []) do 186 | Enum.zip_reduce( 187 | Tuple.to_list(features), 188 | Tuple.to_list(thresholds), 189 | prev_indices, 190 | fn nodes, biases, acc -> 191 | gather_indices = nodes |> Nx.take(acc) |> Nx.reshape({:auto, opts[:num_trees]}) 192 | features = Nx.take_along_axis(x, gather_indices, axis: 1) |> Nx.reshape({:auto}) 193 | 194 | acc 195 | |> Nx.multiply(@factor) 196 | |> Nx.add(opts[:condition].(features, Nx.take(biases, acc))) 197 | end 198 | ) 199 | end 200 | 201 | defp make_tree_perfect(tree, current_depth, max_depth) do 202 | case tree do 203 | %Tree{left: nil, right: nil} -> 204 | if current_depth < max_depth do 205 | %Tree{ 206 | id: make_ref(), 207 | # This can be anything since either path results in the same leaf 208 | value: %{feature: 0, threshold: 0}, 209 | left: make_tree_perfect(tree, current_depth + 1, max_depth), 210 | right: make_tree_perfect(tree, current_depth + 1, max_depth) 211 | } 212 | else 213 | tree 214 | end 215 | 216 | %Tree{left: left, right: right} -> 217 | struct(tree, 218 | left: make_tree_perfect(left, current_depth + 1, max_depth), 219 | right: make_tree_perfect(right, current_depth + 1, max_depth) 220 | ) 221 | end 222 | end 223 | end 224 | -------------------------------------------------------------------------------- 
/lib/mockingjay/strategies/tree_traversal.ex: -------------------------------------------------------------------------------- 1 | defmodule Mockingjay.Strategies.TreeTraversal do 2 | @moduledoc false 3 | import Nx.Defn 4 | alias Mockingjay.Tree 5 | alias Mockingjay.DecisionTree 6 | @behaviour Mockingjay.Strategy 7 | 8 | @impl true 9 | def init(ensemble, opts \\ []) do 10 | opts = Keyword.validate!(opts, reorder_trees: true) 11 | 12 | trees = DecisionTree.trees(ensemble) 13 | 14 | condition = DecisionTree.condition(ensemble) 15 | n_classes = DecisionTree.num_classes(ensemble) 16 | num_trees = length(trees) 17 | 18 | trees = 19 | if opts[:reorder_trees] do 20 | for j <- 0..(n_classes - 1), 21 | i <- 0..(Integer.floor_div(length(trees), n_classes) - 1), 22 | do: Enum.at(trees, i * n_classes + j) 23 | else 24 | trees 25 | end 26 | 27 | # Number of classes each weak learner can predict 28 | # We infer from the shape of a leaf's :value key 29 | n_weak_learner_classes = 30 | case trees |> hd |> Tree.get_decision_values() |> hd do 31 | value when is_list(value) -> 32 | length(value) 33 | 34 | _value -> 35 | 1 36 | end 37 | 38 | num_nodes = 39 | Enum.reduce(trees, 0, fn tree, acc -> 40 | max(acc, length(Tree.bfs(tree))) 41 | end) 42 | 43 | max_tree_depth = 44 | Enum.reduce(trees, 0, fn tree, acc -> 45 | max(acc, Tree.depth(tree)) 46 | end) 47 | 48 | {lefts, rights, features, thresholds, values} = 49 | trees 50 | |> Enum.reduce({[], [], [], [], []}, fn tree, 51 | {all_lefts, all_rights, all_features, 52 | all_thresholds, all_values} -> 53 | dfs_tree = Tree.dfs(tree) 54 | 55 | id_to_index = 56 | Enum.reduce(Enum.with_index(dfs_tree), %{}, fn {node, index}, acc -> 57 | Map.put_new(acc, node.id, index) 58 | end) 59 | 60 | {tl, tr, tf, tt, tv} = 61 | Enum.reduce(dfs_tree, {[], [], [], [], []}, fn node, 62 | {tree_lefts, tree_rights, tree_features, 63 | tree_thresholds, tree_values} -> 64 | case node do 65 | %Tree{left: nil, right: nil} -> 66 | tree_lefts = tree_lefts ++ 
[id_to_index[node.id]] 67 | tree_rights = tree_rights ++ [id_to_index[node.id]] 68 | tree_features = tree_features ++ [0] 69 | tree_thresholds = tree_thresholds ++ [0] 70 | current_value = if is_list(node.value), do: node.value, else: [node.value] 71 | tree_values = tree_values ++ [current_value] 72 | {tree_lefts, tree_rights, tree_features, tree_thresholds, tree_values} 73 | 74 | %Tree{left: left, right: right} -> 75 | tree_lefts = tree_lefts ++ [id_to_index[left.id]] 76 | tree_rights = tree_rights ++ [id_to_index[right.id]] 77 | tree_features = tree_features ++ [node.value.feature] 78 | tree_thresholds = tree_thresholds ++ [node.value.threshold] 79 | tree_values = tree_values ++ [[-1]] 80 | {tree_lefts, tree_rights, tree_features, tree_thresholds, tree_values} 81 | end 82 | end) 83 | 84 | tl = Nx.tensor(tl) |> Nx.pad(0, [{0, num_nodes - length(tl), 0}]) 85 | tr = Nx.tensor(tr) |> Nx.pad(0, [{0, num_nodes - length(tr), 0}]) 86 | tf = Nx.tensor(tf) |> Nx.pad(0, [{0, num_nodes - length(tf), 0}]) 87 | tt = Nx.tensor(tt) |> Nx.pad(0, [{0, num_nodes - length(tt), 0}]) 88 | tv = Nx.tensor(tv) |> Nx.pad(0, [{0, num_nodes - length(tv), 0}, {0, 0, 0}]) 89 | 90 | {[tl | all_lefts], [tr | all_rights], [tf | all_features], [tt | all_thresholds], 91 | [tv | all_values]} 92 | end) 93 | 94 | lefts = 95 | Nx.stack(Enum.reverse(lefts)) 96 | |> Nx.reshape({:auto}) 97 | 98 | rights = 99 | Nx.stack(Enum.reverse(rights)) 100 | |> Nx.reshape({:auto}) 101 | 102 | features = 103 | Nx.stack(Enum.reverse(features)) 104 | |> Nx.reshape({:auto}) 105 | 106 | thresholds = 107 | Nx.stack(Enum.reverse(thresholds)) 108 | |> Nx.reshape({:auto}) 109 | 110 | values = 111 | Nx.stack(Enum.reverse(values)) 112 | |> Nx.reshape({:auto, n_weak_learner_classes}) 113 | 114 | nodes_offset = 115 | Nx.iota({1, num_trees}, type: :s64) 116 | |> Nx.multiply(num_nodes) 117 | 118 | [ 119 | nodes_offset: nodes_offset, 120 | num_trees: num_trees, 121 | max_tree_depth: max_tree_depth, 122 | lefts: lefts, 123 | rights: 
rights, 124 | features: features, 125 | thresholds: thresholds, 126 | values: values, 127 | condition: Mockingjay.Strategy.cond_to_fun(condition), 128 | n_classes: n_classes 129 | ] 130 | end 131 | 132 | @impl true 133 | deftransform forward(x, opts \\ []) do 134 | opts = 135 | Keyword.validate!(opts, [ 136 | :custom_forward, 137 | :max_tree_depth, 138 | :num_trees, 139 | :n_classes, 140 | :nodes_offset, 141 | :lefts, 142 | :rights, 143 | :features, 144 | :thresholds, 145 | :values, 146 | :condition, 147 | unroll: false 148 | ]) 149 | 150 | _forward( 151 | x, 152 | opts[:features], 153 | opts[:lefts], 154 | opts[:rights], 155 | opts[:thresholds], 156 | opts[:nodes_offset], 157 | opts[:values], 158 | opts 159 | ) 160 | end 161 | 162 | defn _forward(x, features, lefts, rights, thresholds, nodes_offset, values, opts \\ []) do 163 | max_tree_depth = opts[:max_tree_depth] 164 | num_trees = opts[:num_trees] 165 | n_classes = opts[:n_classes] 166 | condition = opts[:condition] 167 | unroll = opts[:unroll] 168 | 169 | batch_size = Nx.axis_size(x, 0) 170 | 171 | indices = 172 | nodes_offset 173 | |> Nx.broadcast({batch_size, num_trees}) 174 | |> Nx.reshape({:auto}) 175 | 176 | {indices, _} = 177 | while {tree_nodes = indices, {features, lefts, rights, thresholds, nodes_offset, x}}, 178 | _ <- 1..max_tree_depth, 179 | unroll: unroll do 180 | feature_nodes = Nx.take(features, tree_nodes) |> Nx.reshape({:auto, num_trees}) 181 | feature_values = Nx.take_along_axis(x, feature_nodes, axis: 1) 182 | local_thresholds = Nx.take(thresholds, tree_nodes) |> Nx.reshape({:auto, num_trees}) 183 | local_lefts = Nx.take(lefts, tree_nodes) |> Nx.reshape({:auto, num_trees}) 184 | local_rights = Nx.take(rights, tree_nodes) |> Nx.reshape({:auto, num_trees}) 185 | 186 | result = 187 | Nx.select( 188 | condition.(feature_values, local_thresholds), 189 | local_lefts, 190 | local_rights 191 | ) 192 | |> Nx.add(nodes_offset) 193 | |> Nx.reshape({:auto}) 194 | 195 | {result, {features, lefts, rights, 
thresholds, nodes_offset, x}} 196 | end 197 | 198 | values 199 | |> Nx.take(indices) 200 | |> Nx.reshape({:auto, num_trees, n_classes}) 201 | end 202 | end 203 | -------------------------------------------------------------------------------- /lib/mockingjay/strategy.ex: -------------------------------------------------------------------------------- 1 | defmodule Mockingjay.Strategy do 2 | @moduledoc false 3 | @type t :: Nx.Container.t() 4 | 5 | @callback init(data :: any(), opts :: Keyword.t()) :: term() 6 | @callback forward(x :: Nx.Container.t(), term()) :: Nx.Tensor.t() 7 | 8 | def cond_to_fun(condition) 9 | when condition in [:greater, :less, :greater_equal, :less_equal, :equal, :not_equal] do 10 | &apply(Nx, condition, [&1, &2]) 11 | end 12 | 13 | def cond_to_fun(condition) when is_function(condition, 2) do 14 | condition 15 | end 16 | 17 | def cond_to_fun(condition), 18 | do: 19 | raise( 20 | ArgumentError, 21 | "Invalid condition: #{inspect(condition)} -- must be one of :greater, :less, :greater_equal, :less_equal, :equal, :not_equal -- or a custom function of arity 2" 22 | ) 23 | 24 | def get_strategy(ensemble, opts \\ []) do 25 | opts = Keyword.validate!(opts, high: 10, low: 3) 26 | # The current heuristic is such that GEMM <= low < PerfTreeTrav <= high < TreeTrav 27 | max_tree_depth = 28 | Enum.reduce(Mockingjay.DecisionTree.trees(ensemble), 0, fn tree, acc -> 29 | max(acc, Mockingjay.Tree.depth(tree)) 30 | end) 31 | 32 | cond do 33 | max_tree_depth <= opts[:low] -> 34 | Mockingjay.Strategies.GEMM 35 | 36 | max_tree_depth <= opts[:high] -> 37 | Mockingjay.Strategies.PerfectTreeTraversal 38 | 39 | true -> 40 | Mockingjay.Strategies.TreeTraversal 41 | end 42 | end 43 | end 44 | -------------------------------------------------------------------------------- /lib/mockingjay/tree.ex: -------------------------------------------------------------------------------- 1 | defprotocol Mockingjay.DecisionTree do 2 | @moduledoc """ 3 | A protocol for extracting 
decision trees from a model and getting information about the model. 4 | 5 | This protocol can be used for any model that uses decision trees as its base model. 6 | As such, this can be used for both ensemble and single decision tree models. This protocol 7 | requires that the model implement the `Mockingjay.Tree` struct for representing decision trees. 8 | This protocol also requires that all decision split conditions in the model are the same condition. 9 | The model does not need to be a perfect binary tree, but it must be a binary tree. 10 | """ 11 | 12 | @doc """ 13 | Returns a list of `Mockingjay.Tree` structs representing the decision trees. 14 | """ 15 | @spec trees(data :: any) :: [Tree.t()] 16 | def trees(data) 17 | 18 | @doc """ 19 | Returns the number of classes in the model. 20 | """ 21 | @spec num_classes(data :: any) :: pos_integer() 22 | def num_classes(data) 23 | 24 | @doc """ 25 | Returns the number of features in the model. 26 | """ 27 | @spec num_features(data :: any) :: pos_integer() 28 | def num_features(data) 29 | 30 | @doc """ 31 | Returns the condition used to split the data. 32 | """ 33 | @spec condition(data :: any) :: :greater | :less | :greater_equal | :less_equal 34 | def condition(data) 35 | end 36 | 37 | defmodule Mockingjay.Tree do 38 | @moduledoc """ 39 | A struct containing a convenient in-memory representation of a decision tree. 40 | 41 | Each "node" in the tree is a `Tree` struct. A "decision" or "inner" node is any node whose 42 | `:left` and `:right` values are a valid `Tree`. A "leaf" or "outer" node is any node whose `:left` 43 | and `:right` values are `nil`. `:left` and `:right` must either both be `nil` or both be a `Tree`. 44 | 45 | * `:id` - The id of the node. This is a unique reference. This is generated automatically when using the `from_map` function. 46 | * `:left` - The left child of the node. This is `nil` for leaf nodes. This is a `Tree` for decision nodes. 47 | * `:right` - The right child of the node.
This is `nil` for leaf nodes. This is a `Tree` for decision nodes. 48 | * `:value` - The value of the node: 49 | * For leaf nodes, this is the value as a number. 50 | * For non-leaf nodes, this is a map containing the following keys: 51 | * `:feature` - The feature used to split the data (if it is not a leaf). 52 | * `:threshold` - The threshold used to split the data (if it is not a leaf). 53 | """ 54 | 55 | @enforce_keys [:id, :value] 56 | defstruct [:id, :left, :right, :value] 57 | 58 | @doc """ 59 | Returns a `Tree` struct from a map. 60 | The map must have the appropriate required keys for the `Tree` struct. Any extra keys are ignored. 61 | """ 62 | def from_map(%__MODULE__{} = t), do: t 63 | 64 | def from_map(%{} = map) do 65 | case map do 66 | %{left: nil, right: nil, value: value} when is_number(value) -> 67 | %__MODULE__{ 68 | id: make_ref(), 69 | left: nil, 70 | right: nil, 71 | value: value 72 | } 73 | 74 | %{value: value} when is_number(value) -> 75 | %__MODULE__{ 76 | id: make_ref(), 77 | left: nil, 78 | right: nil, 79 | value: value 80 | } 81 | 82 | %{left: nil, right: nil, value: value} -> 83 | raise ArgumentError, "Leaf nodes must have a numeric value. Got: #{inspect(value)}" 84 | 85 | %{left: left, right: right, value: %{threshold: threshold, feature: feature}} 86 | when is_number(threshold) and is_number(feature) -> 87 | %__MODULE__{ 88 | id: make_ref(), 89 | left: from_map(left), 90 | right: from_map(right), 91 | value: %{threshold: threshold, feature: feature} 92 | } 93 | 94 | %{left: _left, right: _right, value: %{threshold: _threshold, feature: _feature}} -> 95 | raise ArgumentError, 96 | "Non-leaf nodes must have a numeric threshold and feature. Got: #{inspect(map)}" 97 | 98 | %{value: value} -> 99 | raise ArgumentError, "Leaf nodes must have a numeric value. Got: #{inspect(value)}" 100 | 101 | _ -> 102 | raise ArgumentError, "Invalid tree map: #{inspect(map)}" 103 | end 104 | end 105 | 106 | @typedoc "A simple binary tree implementation." 
107 | @type t() :: %__MODULE__{ 108 | id: reference(), 109 | left: t() | nil, 110 | right: t() | nil, 111 | value: number() | %{feature: pos_integer(), threshold: number()} 112 | } 113 | 114 | # TO-DO: make TCOptimizable 115 | @doc """ 116 | Returns tree nodes as a list in DFS order. 117 | """ 118 | def dfs(root) do 119 | _dfs(root, []) 120 | end 121 | 122 | defp _dfs(root, acc) do 123 | case root do 124 | %{left: nil, right: nil} -> 125 | acc ++ [root] 126 | 127 | %{left: left, right: right} -> 128 | [root] ++ _dfs(left, acc) ++ _dfs(right, acc) 129 | end 130 | end 131 | 132 | # Credit to this SO answer: https://stackoverflow.com/questions/55327307/flatten-a-binary-tree-to-list-ordered 133 | @doc """ 134 | Returns a list of nodes in BFS order. 135 | For the uses in Mockingjay, BFS is tree-level order from right to left on each level. 136 | The nodes include their children nodes. 137 | """ 138 | def bfs(root) do 139 | root 140 | |> reduce_tree([], fn val, acc -> 141 | [val | acc] 142 | end) 143 | |> Enum.reverse() 144 | end 145 | 146 | @doc """ 147 | Traverse the tree in BFS order, applying a reducer function to each node. 148 | """ 149 | def reduce_tree(root, acc, reducer) do 150 | :queue.new() 151 | |> :queue.snoc(root) 152 | |> process_queue(acc, reducer) 153 | end 154 | 155 | defp process_queue(queue, acc, reducer) do 156 | case :queue.out(queue) do 157 | {{:value, %{left: nil, right: nil} = node}, popped} -> 158 | new_acc = reducer.(node, acc) 159 | process_queue(popped, new_acc, reducer) 160 | 161 | {{:value, %{left: left, right: right} = node}, popped} -> 162 | new_acc = reducer.(node, acc) 163 | 164 | popped 165 | |> :queue.snoc(right) 166 | |> :queue.snoc(left) 167 | |> process_queue(new_acc, reducer) 168 | 169 | _other -> 170 | acc 171 | end 172 | end 173 | 174 | @doc """ 175 | Returns the depth of the tree. The root node is at depth 0. 
176 | """ 177 | def depth(tree) do 178 | depth_from_level(tree, 0) 179 | end 180 | 181 | @doc """ 182 | Returns the depth of the tree starting from the given level. 183 | """ 184 | def depth_from_level(%{} = tree, current_depth) do 185 | case tree.value do 186 | %{} -> 187 | max( 188 | depth_from_level(tree.left, current_depth + 1), 189 | depth_from_level(tree.right, current_depth + 1) 190 | ) 191 | 192 | _ -> 193 | current_depth 194 | end 195 | end 196 | 197 | @doc """ 198 | Returns a list of the decision nodes in the tree in BFS order. 199 | """ 200 | def get_decision_nodes(tree) do 201 | tree 202 | |> bfs() 203 | |> Enum.filter(fn node -> 204 | is_map(node.value) 205 | end) 206 | end 207 | 208 | @doc """ 209 | Returns a list of the leaf nodes in the tree in DFS order (left to right). 210 | """ 211 | def get_leaf_nodes(tree) do 212 | tree 213 | |> do_get_leaf_nodes([]) 214 | |> Enum.reverse() 215 | end 216 | 217 | defp do_get_leaf_nodes(tree, nodes) do 218 | # Unlike get_decision_nodes, this function returns the leaf nodes in DFS order. 219 | case tree do 220 | %__MODULE__{left: nil, right: nil} -> 221 | [tree | nodes] 222 | 223 | %__MODULE__{left: left, right: right} -> 224 | left 225 | |> do_get_leaf_nodes(nodes) 226 | |> then(&do_get_leaf_nodes(right, &1)) 227 | end 228 | end 229 | 230 | @doc """ 231 | Returns a list of the values of the leaf nodes in the tree in DFS order (left to right). 232 | """ 233 | def get_leaf_values(tree) do 234 | tree 235 | |> get_leaf_nodes() 236 | |> Enum.map(& &1.value) 237 | end 238 | 239 | @doc """ 240 | Returns a list of the values of the decision nodes in the tree in BFS order. 241 | """ 242 | def get_decision_values(tree) do 243 | tree 244 | |> bfs() 245 | |> Enum.filter(fn node -> 246 | is_map(node.value) 247 | end) 248 | |> Enum.map(& &1.value) 249 | end 250 | 251 | @doc """ 252 | Checks if the given child_id exists in the tree.
253 | """ 254 | def child?(tree, child_id) do 255 | case tree do 256 | nil -> 257 | false 258 | 259 | %__MODULE__{id: id, left: left, right: right} -> 260 | id == child_id or child?(left, child_id) or child?(right, child_id) 261 | end 262 | end 263 | end 264 | -------------------------------------------------------------------------------- /mix.exs: -------------------------------------------------------------------------------- 1 | defmodule Mockingjay.MixProject do 2 | use Mix.Project 3 | 4 | def project do 5 | [ 6 | app: :mockingjay, 7 | version: "0.1.0", 8 | elixir: "~> 1.14", 9 | start_permanent: Mix.env() == :prod, 10 | elixirc_paths: elixirc_paths(Mix.env()), 11 | deps: deps(), 12 | package: package(), 13 | docs: docs(), 14 | preferred_cli_env: [ 15 | docs: :docs, 16 | "hex.publish": :docs 17 | ], 18 | name: "Mockingjay", 19 | description: 20 | "A library to convert trained decision tree models into [Nx](https://hexdocs.pm/nx/Nx.html) tensor operations." 21 | ] 22 | end 23 | 24 | defp elixirc_paths(:test), do: ["lib", "test/support"] 25 | defp elixirc_paths(_), do: ["lib"] 26 | 27 | def application do 28 | [ 29 | extra_applications: [:logger] 30 | ] 31 | end 32 | 33 | defp deps do 34 | [ 35 | {:nx, "~> 0.5"}, 36 | {:axon, "~> 0.5"}, 37 | {:ex_doc, "~> 0.29", only: :docs} 38 | ] 39 | end 40 | 41 | defp package do 42 | [ 43 | maintainers: ["Andres Alejos"], 44 | licenses: ["Apache-2.0"], 45 | links: %{"GitHub" => "https://github.com/acalejos/mockingjay"} 46 | ] 47 | end 48 | 49 | defp docs do 50 | [ 51 | main: "Mockingjay" 52 | ] 53 | end 54 | end 55 | -------------------------------------------------------------------------------- /mix.lock: -------------------------------------------------------------------------------- 1 | %{ 2 | "axon": {:hex, :axon, "0.5.1", "1ae3a2193df45e51fca912158320b2ca87cb7fba4df242bd3ebe245504d0ea1a", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:kino_vega_lite, "~> 0.1.7", [hex: :kino_vega_lite, repo: 
"hexpm", optional: true]}, {:nx, "~> 0.5.0", [hex: :nx, repo: "hexpm", optional: false]}, {:table_rex, "~> 3.1.1", [hex: :table_rex, repo: "hexpm", optional: true]}], "hexpm", "d36f2a11c34c6c2b458f54df5c71ffdb7ed91c6a9ccd908faba909c84cc6a38e"}, 3 | "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, 4 | "earmark_parser": {:hex, :earmark_parser, "1.4.32", "fa739a0ecfa34493de19426681b23f6814573faee95dfd4b4aafe15a7b5b32c6", [:mix], [], "hexpm", "b8b0dd77d60373e77a3d7e8afa598f325e49e8663a51bcc2b88ef41838cca755"}, 5 | "ex_doc": {:hex, :ex_doc, "0.29.4", "6257ecbb20c7396b1fe5accd55b7b0d23f44b6aa18017b415cb4c2b91d997729", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "2c6699a737ae46cb61e4ed012af931b57b699643b24dabe2400a8168414bc4f5"}, 6 | "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, 7 | "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, 8 | "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", 
"f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, 9 | "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, 10 | "nx": {:hex, :nx, "0.5.3", "6ad5534f9b82429dafa12329952708c2fdd6ab01b306e86333fdea72383147ee", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "d1072fc4423809ed09beb729e73c200ce177ddecac425d9eb6fba643669623ec"}, 11 | "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, 12 | } 13 | -------------------------------------------------------------------------------- /test/mockingjay/gemm_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Mockingjay.GEMMTest do 2 | use ExUnit.Case, async: true 3 | 4 | alias Mockingjay.Tree 5 | alias Mockingjay.Strategies.GEMM 6 | 7 | # Test tree and matrix outputs from the Hummingbird paper 8 | setup do 9 | tree = %Tree{ 10 | id: 1, 11 | value: %{feature: 2, condition: :lt, threshold: 0.5}, 12 | left: %Tree{ 13 | id: 2, 14 | value: %{feature: 1, condition: :lt, threshold: 2.0}, 15 | left: %Tree{id: 3, value: 0}, 16 | right: %Tree{id: 4, value: 1} 17 | }, 18 | right: %Tree{ 19 | id: 5, 20 | value: %{feature: 4, condition: :lt, threshold: 5.5}, 21 | left: %Tree{ 22 | id: 6, 23 | value: %{feature: 2, condition: :lt, threshold: 2.4}, 24 | left: %Tree{id: 7, value: 1}, 25 | right: %Tree{id: 8, value: 0} 26 | }, 27 | right: %Tree{id: 9, value: 0} 28 | } 29 | } 30 | 31 | trees = [tree] 32 | 33 | {hidden_one_size, hidden_two_size} = 34 | Enum.reduce(trees, {0, 0}, fn tree, {h1, h2} -> 35 | {max(h1, length(Tree.get_decision_nodes(tree))), 36 
| max(h2, length(Tree.get_leaf_nodes(tree)))} 37 | end) 38 | 39 | %{ 40 | tree: tree, 41 | trees: trees, 42 | hidden_one_size: hidden_one_size, 43 | hidden_two_size: hidden_two_size, 44 | num_features: 5, 45 | num_classes: 2 46 | } 47 | end 48 | 49 | test "convert", context do 50 | model = %Mockingjay.Model{ 51 | trees: context.trees, 52 | num_classes: 1, 53 | num_features: 5, 54 | output_type: :classification, 55 | condition: :less 56 | } 57 | 58 | f = Mockingjay.convert(model) 59 | assert is_function(f, 1) 60 | end 61 | end 62 | -------------------------------------------------------------------------------- /test/mockingjay/tree_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Mockingjay.TreeTest do 2 | use ExUnit.Case, async: true 3 | 4 | alias Mockingjay.Tree 5 | 6 | setup do 7 | %{ 8 | tree: %Tree{ 9 | id: 1, 10 | value: %{feature: 2, threshold: 3}, 11 | left: %Tree{ 12 | id: 2, 13 | value: %{feature: 1, threshold: 5}, 14 | left: %Tree{id: 3, value: 10}, 15 | right: %Tree{ 16 | id: 4, 17 | value: %{feature: 3, threshold: 3}, 18 | left: %Tree{id: 5, value: 40}, 19 | right: %Tree{id: 6, value: 50} 20 | } 21 | }, 22 | right: %Tree{ 23 | id: 7, 24 | value: %{feature: 0, threshold: 1.2}, 25 | left: %Tree{id: 8, value: 30}, 26 | right: %Tree{id: 9, value: 20} 27 | } 28 | } 29 | } 30 | end 31 | 32 | test "from_map", context do 33 | new_tree = 34 | Tree.from_map(%{ 35 | value: %{feature: 2, threshold: 3}, 36 | left: %{ 37 | id: 2, 38 | value: %{feature: 1, threshold: 5}, 39 | left: %{id: 3, value: 10}, 40 | right: %{ 41 | id: 4, 42 | value: %{feature: 3, threshold: 3}, 43 | left: %{id: 5, value: 40}, 44 | right: %{id: 6, value: 50} 45 | } 46 | }, 47 | right: %{ 48 | id: 7, 49 | value: %{feature: 0, threshold: 1.2}, 50 | left: %{id: 8, value: 30}, 51 | right: %{id: 9, value: 20} 52 | } 53 | }) 54 | 55 | Enum.zip(Tree.bfs(context.tree), Tree.bfs(new_tree)) 56 | |> Enum.each(fn {a, b} -> 57 | assert a.value == 
b.value 58 | end) 59 | end 60 | 61 | test "bfs", context do 62 | assert Tree.bfs(context.tree) |> Enum.map(& &1.value) == [ 63 | %{feature: 2, threshold: 3}, 64 | %{feature: 0, threshold: 1.2}, 65 | %{feature: 1, threshold: 5}, 66 | 20, 67 | 30, 68 | %{feature: 3, threshold: 3}, 69 | 10, 70 | 50, 71 | 40 72 | ] 73 | end 74 | 75 | test "depth", context do 76 | assert Tree.depth(context.tree) == 3 77 | end 78 | 79 | test "get decision values", context do 80 | assert Tree.get_decision_values(context.tree) == [ 81 | %{feature: 2, threshold: 3}, 82 | %{feature: 0, threshold: 1.2}, 83 | %{feature: 1, threshold: 5}, 84 | %{feature: 3, threshold: 3} 85 | ] 86 | end 87 | 88 | test "get leaf values", context do 89 | assert Tree.get_leaf_values(context.tree) == [10, 40, 50, 30, 20] 90 | end 91 | end 92 | -------------------------------------------------------------------------------- /test/mockingjay_test.exs: -------------------------------------------------------------------------------- 1 | defmodule MockingjayTest do 2 | end 3 | -------------------------------------------------------------------------------- /test/support/model.ex: -------------------------------------------------------------------------------- 1 | defmodule Mockingjay.Model do 2 | defstruct [:trees, :num_classes, :num_features, :output_type, :condition] 3 | 4 | defimpl Mockingjay.DecisionTree do 5 | def trees(data) do 6 | Enum.map(data.trees, &Mockingjay.Tree.from_map/1) 7 | end 8 | 9 | def num_classes(data), do: data.num_classes 10 | def num_features(data), do: data.num_features 11 | def output_type(data), do: data.output_type 12 | def condition(data), do: data.condition 13 | end 14 | end 15 | -------------------------------------------------------------------------------- /test/test_helper.exs: -------------------------------------------------------------------------------- 1 | ExUnit.start() 2 | --------------------------------------------------------------------------------