├── .gitattributes ├── .gitignore ├── LICENSE.md ├── README.md ├── XgbFeatureInteractions.sln ├── XgbFeatureInteractions ├── App.config ├── FIScoreComparer.cs ├── FeatureInteraction.cs ├── FeatureInteractions.cs ├── FeatureScoreComparer.cs ├── GlobalSettings.cs ├── GlobalStats.cs ├── Program.cs ├── Properties │ ├── AssemblyInfo.cs │ ├── Settings.Designer.cs │ └── Settings.settings ├── SplitValueHistogram.cs ├── XgbFeatureInteractions.csproj ├── XgbModel.cs ├── XgbModelParser.cs ├── XgbTree.cs ├── XgbTreeNode.cs └── packages.config ├── bin ├── XgbFeatureInteractions.exe ├── XgbFeatureInteractions.exe.config └── lib │ ├── EPPlus.dll │ └── NGenerics.dll ├── doc ├── ScoresExample.png └── ScoresExample_small.png └── xgbfi_cc ├── xgbfi.cc └── xgbfi.h /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | 4 | # Custom for Visual Studio 5 | *.cs diff=csharp 6 | 7 | # Standard to msysgit 8 | *.doc diff=astextplain 9 | *.DOC diff=astextplain 10 | *.docx diff=astextplain 11 | *.DOCX diff=astextplain 12 | *.dot diff=astextplain 13 | *.DOT diff=astextplain 14 | *.pdf diff=astextplain 15 | *.PDF diff=astextplain 16 | *.rtf diff=astextplain 17 | *.RTF diff=astextplain 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ################# 2 | ## Eclipse 3 | ################# 4 | 5 | *.pydevproject 6 | .project 7 | .metadata 8 | bin/ 9 | tmp/ 10 | *.tmp 11 | *.bak 12 | *.swp 13 | *~.nib 14 | local.properties 15 | .classpath 16 | .settings/ 17 | .loadpath 18 | 19 | # External tool builders 20 | .externalToolBuilders/ 21 | 22 | # Locally stored "Eclipse launch configurations" 23 | *.launch 24 | 25 | # CDT-specific 26 | .cproject 27 | 28 | # PDT-specific 29 | .buildpath 30 | 31 | 32 | ################# 33 | ## Visual Studio 34 | 
################# 35 | 36 | ## Ignore Visual Studio temporary files, build results, and 37 | ## files generated by popular Visual Studio add-ons. 38 | 39 | # User-specific files 40 | *.suo 41 | *.user 42 | *.sln.docstates 43 | 44 | # Build results 45 | 46 | [Dd]ebug/ 47 | [Rr]elease/ 48 | x64/ 49 | build/ 50 | [Bb]in/ 51 | [Oo]bj/ 52 | 53 | # MSTest test Results 54 | [Tt]est[Rr]esult*/ 55 | [Bb]uild[Ll]og.* 56 | 57 | *_i.c 58 | *_p.c 59 | *.ilk 60 | *.meta 61 | *.obj 62 | *.pch 63 | *.pdb 64 | *.pgc 65 | *.pgd 66 | *.rsp 67 | *.sbr 68 | *.tlb 69 | *.tli 70 | *.tlh 71 | *.tmp 72 | *.tmp_proj 73 | *.log 74 | *.vspscc 75 | *.vssscc 76 | .builds 77 | *.pidb 78 | *.log 79 | *.scc 80 | 81 | # Visual C++ cache files 82 | ipch/ 83 | *.aps 84 | *.ncb 85 | *.opensdf 86 | *.sdf 87 | *.cachefile 88 | 89 | # Visual Studio profiler 90 | *.psess 91 | *.vsp 92 | *.vspx 93 | 94 | # Guidance Automation Toolkit 95 | *.gpState 96 | 97 | # ReSharper is a .NET coding add-in 98 | _ReSharper*/ 99 | *.[Rr]e[Ss]harper 100 | 101 | # TeamCity is a build add-in 102 | _TeamCity* 103 | 104 | # DotCover is a Code Coverage Tool 105 | *.dotCover 106 | 107 | # NCrunch 108 | *.ncrunch* 109 | .*crunch*.local.xml 110 | 111 | # Installshield output folder 112 | [Ee]xpress/ 113 | 114 | # DocProject is a documentation generator add-in 115 | DocProject/buildhelp/ 116 | DocProject/Help/*.HxT 117 | DocProject/Help/*.HxC 118 | DocProject/Help/*.hhc 119 | DocProject/Help/*.hhk 120 | DocProject/Help/*.hhp 121 | DocProject/Help/Html2 122 | DocProject/Help/html 123 | 124 | # Click-Once directory 125 | publish/ 126 | 127 | # Publish Web Output 128 | *.Publish.xml 129 | *.pubxml 130 | *.publishproj 131 | 132 | # NuGet Packages Directory 133 | ## TODO: If you have NuGet Package Restore enabled, uncomment the next line 134 | #packages/ 135 | 136 | # Windows Azure Build Output 137 | csx 138 | *.build.csdef 139 | 140 | # Windows Store app package directory 141 | AppPackages/ 142 | 143 | # Others 144 | sql/ 145 | 
*.Cache 146 | ClientBin/ 147 | [Ss]tyle[Cc]op.* 148 | ~$* 149 | *~ 150 | *.dbmdl 151 | *.[Pp]ublish.xml 152 | *.pfx 153 | *.publishsettings 154 | 155 | # RIA/Silverlight projects 156 | Generated_Code/ 157 | 158 | # Backup & report files from converting an old project file to a newer 159 | # Visual Studio version. Backup files are not needed, because we have git ;-) 160 | _UpgradeReport_Files/ 161 | Backup*/ 162 | UpgradeLog*.XML 163 | UpgradeLog*.htm 164 | 165 | # SQL Server files 166 | App_Data/*.mdf 167 | App_Data/*.ldf 168 | 169 | ############# 170 | ## Windows detritus 171 | ############# 172 | 173 | # Windows image file caches 174 | Thumbs.db 175 | ehthumbs.db 176 | 177 | # Folder config file 178 | Desktop.ini 179 | 180 | # Recycle Bin used on file shares 181 | $RECYCLE.BIN/ 182 | 183 | # Mac crap 184 | .DS_Store 185 | 186 | 187 | ############# 188 | ## Python 189 | ############# 190 | 191 | *.py[cod] 192 | 193 | # Packages 194 | *.egg 195 | *.egg-info 196 | dist/ 197 | build/ 198 | eggs/ 199 | parts/ 200 | var/ 201 | sdist/ 202 | develop-eggs/ 203 | .installed.cfg 204 | 205 | # Installer logs 206 | pip-log.txt 207 | 208 | # Unit test / coverage reports 209 | .coverage 210 | .tox 211 | 212 | #Translations 213 | *.mo 214 | 215 | #Mr Developer 216 | .mr.developer.cfg 217 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | 
The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Xgbfi 2 | XGBoost Feature Interactions & Importance 3 | 4 | ### What is Xgbfi? 5 | Xgbfi is a [XGBoost](https://github.com/dmlc/xgboost) model dump parser, which ranks features as well as feature interactions by different metrics. 
6 | 7 | ### Siblings 8 | [Xgbfir](https://github.com/limexp/xgbfir) - Python port 9 | 10 | ### The Metrics 11 | * **Gain**: Total gain of each feature or feature interaction 12 | * **FScore**: Number of possible splits taken on a feature or feature interaction 13 | * **wFScore**: Number of possible splits taken on a feature or feature interaction, weighted by the probability that the splits take place 14 | * **Average wFScore**: *wFScore* divided by *FScore* 15 | * **Average Gain**: *Gain* divided by *FScore* 16 | * **Expected Gain**: Total gain of each feature or feature interaction, weighted by the probability of obtaining that gain 17 | * **Average Tree Index** 18 | * **Average Tree Depth** 19 | 20 | ### Additional Features 21 | * **Leaf Statistics** 22 | * **Split Value Histograms** 23 | 24 | **Example:** 25 | 26 | ![](https://raw.githubusercontent.com/Far0n/xgbfi/master/doc/ScoresExample_small.png) 27 | 28 | ### Usage 29 | *[mono] XgbFeatureInteractions.exe [-help|options]* 30 | 31 | ### Quick Guide 32 | a) Creating a feature map (fmap) 33 | ```python 34 | def create_feature_map(fmap_filename, features): 35 | """ 36 | features: enumerable of feature names 37 | """ 38 | outfile = open(fmap_filename, 'w') 39 | for i, feat in enumerate(features): 40 | outfile.write('{0}\t{1}\tq\n'.format(i, feat)) 41 | outfile.close() 42 | 43 | create_feature_map('xgb.fmap', features) 44 | ``` 45 | 46 | b) Dumping an [XGBoost](https://github.com/dmlc/xgboost) model 47 | ```python 48 | gbdt.dump_model('xgb.dump',fmap='xgb.fmap', with_stats=True) 49 | ``` 50 | 51 | c) Editing parameters in *XgbFeatureInteractions.exe.config* 52 | ```xml 53 | <setting name="XgbModelFile" serializeAs="String"> 54 | <value>xgb.dump</value> 55 | </setting> 56 | ``` 57 | 58 | d) Running *[mono] XgbFeatureInteractions.exe* without command-line parameters 59 | -------------------------------------------------------------------------------- /XgbFeatureInteractions.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio
Solution File, Format Version 12.00 3 | # Visual Studio 14 4 | VisualStudioVersion = 14.0.23107.0 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "XgbFeatureInteractions", "XgbFeatureInteractions\XgbFeatureInteractions.csproj", "{09B47150-EEB5-4416-8D9A-7258CB0C717B}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|Any CPU = Debug|Any CPU 11 | Release|Any CPU = Release|Any CPU 12 | EndGlobalSection 13 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 14 | {09B47150-EEB5-4416-8D9A-7258CB0C717B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 15 | {09B47150-EEB5-4416-8D9A-7258CB0C717B}.Debug|Any CPU.Build.0 = Debug|Any CPU 16 | {09B47150-EEB5-4416-8D9A-7258CB0C717B}.Release|Any CPU.ActiveCfg = Release|Any CPU 17 | {09B47150-EEB5-4416-8D9A-7258CB0C717B}.Release|Any CPU.Build.0 = Release|Any CPU 18 | EndGlobalSection 19 | GlobalSection(SolutionProperties) = preSolution 20 | HideSolutionNode = FALSE 21 | EndGlobalSection 22 | EndGlobal 23 | -------------------------------------------------------------------------------- /XgbFeatureInteractions/App.config: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 |
6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | xgb.model 20 | 21 | 22 | XgbFeatureInteractions.xlsx 23 | 24 | 25 | 100 26 | 27 | 28 | -1 29 | 30 | 31 | Gain 32 | 33 | 34 | -1 35 | 36 | 37 | 2 38 | 39 | 40 | 10 41 | 42 | 43 | 44 | 45 |
--------------------------------------------------------------------------------
/XgbFeatureInteractions/FIScoreComparer.cs:
--------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace XgbFeatureInteractions
{
    /// <summary>
    /// Metric by which feature interactions are ordered. Every comparer built
    /// from these values sorts in descending order (largest value first).
    /// </summary>
    public enum SortingMetric
    {
        Gain,
        FScore,
        FScoreWeighted,
        AverageFScoreWeighted,
        AverageGain,
        ExpectedGain
    }

    /// <summary>
    /// Process-wide comparison used by <see cref="FeatureInteraction.CompareTo"/>.
    /// Defaults to <see cref="SortingMetric.Gain"/>.
    /// </summary>
    static class FIScoreComparer
    {
        private static Func<FeatureInteraction, FeatureInteraction, int> _comparer { set; get; }
        public static SortingMetric SortBy { private set; get; }

        /// <summary>
        /// Selects the comparer from a metric name. Case and spaces are ignored.
        /// Recognised names: gain, fscore, wfscore, avgwfscore, avggain, expgain.
        /// A null or unrecognised name leaves the current comparer unchanged.
        /// </summary>
        /// <param name="sortBy">Metric name, e.g. "Gain" or "avg wfscore".</param>
        public static void SetComparer(string sortBy)
        {
            if (sortBy == null)
            {
                // Previously a NullReferenceException; now treated like any other unknown name.
                return;
            }

            // Invariant casing: culture-sensitive ToLower() maps 'I' to a dotless i
            // under Turkish culture, so "AVGGAIN" would never match "avggain".
            switch (sortBy.ToLowerInvariant().Replace(" ", ""))
            {
                case "gain":
                    SetComparer(SortingMetric.Gain);
                    break;
                case "fscore":
                    SetComparer(SortingMetric.FScore);
                    break;
                case "wfscore":
                    SetComparer(SortingMetric.FScoreWeighted);
                    break;
                case "avgwfscore":
                    SetComparer(SortingMetric.AverageFScoreWeighted);
                    break;
                case "avggain":
                    SetComparer(SortingMetric.AverageGain);
                    break;
                case "expgain":
                    SetComparer(SortingMetric.ExpectedGain);
                    break;
            }
        }

        /// <summary>
        /// Selects the comparer for <paramref name="sortingMetric"/> and records it in
        /// <see cref="SortBy"/>. Values outside the enum fall back to Gain.
        /// </summary>
        public static void SetComparer(SortingMetric sortingMetric)
        {
            switch (sortingMetric)
            {
                case SortingMetric.FScore:
                    SortBy = SortingMetric.FScore;
                    _comparer = (a, b) => -a.FScore.CompareTo(b.FScore);
                    break;
                case SortingMetric.FScoreWeighted:
                    SortBy = SortingMetric.FScoreWeighted;
                    _comparer = (a, b) => -a.FScoreWeighted.CompareTo(b.FScoreWeighted);
                    break;
                case SortingMetric.AverageFScoreWeighted:
                    SortBy = SortingMetric.AverageFScoreWeighted;
                    _comparer = (a, b) => -a.AverageFScoreWeighted.CompareTo(b.AverageFScoreWeighted);
                    break;
                case SortingMetric.AverageGain:
                    SortBy = SortingMetric.AverageGain;
                    _comparer = (a, b) => -a.AverageGain.CompareTo(b.AverageGain);
                    break;
                case SortingMetric.ExpectedGain:
                    SortBy = SortingMetric.ExpectedGain;
                    _comparer = (a, b) => -a.ExpectedGain.CompareTo(b.ExpectedGain);
                    break;
                case SortingMetric.Gain:
                default:
                    SortBy = SortingMetric.Gain;
                    _comparer = (a, b) => -a.Gain.CompareTo(b.Gain);
                    break;
            }
        }

        static FIScoreComparer()
        {
            SetComparer(SortingMetric.Gain);
        }

        /// <summary>Compares two interactions with the currently selected metric (descending).</summary>
        public static int Compare(FeatureInteraction a, FeatureInteraction b)
        {
            return _comparer(a, b);
        }
    }
}
--------------------------------------------------------------------------------
/XgbFeatureInteractions/FeatureInteraction.cs:
--------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace XgbFeatureInteractions
{
    /// <summary>
    /// Accumulated statistics for one feature (Depth 0) or one combination of
    /// features (Depth = number of features - 1). Identity is the Name alone.
    /// </summary>
    public class FeatureInteraction : IComparable
    {
        public string Name { get; set; }
        public int Depth { get; set; }
        public double Gain { get; set; }
        public double Cover { get; set; }
        public double FScore { get; set; }
        public double FScoreWeighted { get; set; }
        public double AverageFScoreWeighted { get; set; }
        public double AverageGain { get; set; }
        public double ExpectedGain { get; set; }
        public double TreeIndex { get; set; }
        public double AverageTreeIndex { get; set; }
        public double TreeDepth { get; set; }
        public double AverageTreeDepth { get; set; }
        public SplitValueHistogram SplitValueHistogram { get; set; }

        public bool HasLeafStatistics { get; set; }
        public double SumLeafValuesLeft { get; set; }
        public double SumLeafCoversLeft { get; set; }
        public double SumLeafValuesRight { get; set; }
        public double SumLeafCoversRight { get; set; }

        /// <summary>
        /// Builds the statistics for a single occurrence of an interaction.
        /// </summary>
        /// <param name="interaction">Nodes forming the interaction; their Feature values, sorted and joined with "|", become the Name.</param>
        /// <param name="gain">Gain attributed to this occurrence.</param>
        /// <param name="cover">Cover attributed to this occurrence.</param>
        /// <param name="pathProbability">Stored as FScoreWeighted and used to weight ExpectedGain.</param>
        /// <param name="depth">Stored as TreeDepth.</param>
        /// <param name="treeIndex">Stored as TreeIndex.</param>
        /// <param name="fScore">Occurrence count; divisor for all averages.</param>
        // NOTE(review): the element type was stripped from this dump; XgbTreeNode is
        // inferred from the project tree and the Feature/SplitValue members used below — confirm.
        public FeatureInteraction(HashSet<XgbTreeNode> interaction, double gain, double cover, double pathProbability, double depth, double treeIndex, double fScore = 1)
        {
            SplitValueHistogram = new SplitValueHistogram();
            var features = interaction.OrderBy(x => x.Feature).Select(y => y.Feature).ToList();

            Name = string.Join("|", features);
            Depth = interaction.Count - 1;
            Gain = gain;
            Cover = cover;
            FScore = fScore;
            FScoreWeighted = pathProbability;
            AverageFScoreWeighted = FScoreWeighted / FScore;
            AverageGain = Gain / FScore;
            ExpectedGain = Gain * pathProbability;
            TreeIndex = treeIndex;
            TreeDepth = depth;
            AverageTreeIndex = TreeIndex / FScore;
            AverageTreeDepth = TreeDepth / FScore;
            HasLeafStatistics = false;

            // Split values are only recorded for single features.
            if (Depth == 0)
            {
                SplitValueHistogram.AddValue(interaction.First().SplitValue);
            }
        }

        /// <summary>
        /// Orders by the metric currently selected in <see cref="FIScoreComparer"/>.
        /// </summary>
        /// <exception cref="ArgumentException"><paramref name="obj"/> is not a FeatureInteraction.</exception>
        public int CompareTo(object obj)
        {
            if (obj == null)
            {
                // IComparable convention: every instance sorts after null.
                return 1;
            }

            var other = obj as FeatureInteraction;
            if (other == null)
            {
                throw new ArgumentException("Object is not a FeatureInteraction.", "obj");
            }
            return FIScoreComparer.Compare(this, other);
        }

        public override string ToString()
        {
            return String.Format("{0}:{1}", Name, Gain);
        }

        /// <summary>Two interactions are equal when their names are equal.</summary>
        public override bool Equals(object obj)
        {
            var other = obj as FeatureInteraction;
            // null or a foreign type is simply "not equal" instead of a NullReferenceException.
            return other != null && string.Equals(Name, other.Name);
        }

        public override int GetHashCode()
        {
            return Name == null ? 0 : Name.GetHashCode();
        }
    }
}
--------------------------------------------------------------------------------
/XgbFeatureInteractions/FeatureInteractions.cs:
--------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using OfficeOpenXml;
using XgbFeatureInteractions.Properties;

namespace XgbFeatureInteractions
{
    /// <summary>
    /// Collection of feature interactions indexed by a string key, with merging
    /// and Excel export.
    /// </summary>
    public class FeatureInteractions : Dictionary<string, FeatureInteraction>
    {
        /// <summary>Largest Depth in the collection. Throws InvalidOperationException when empty.</summary>
        public int MaxDepth
        {
            get { return this.Max(x => x.Value.Depth); }
        }

        public double TotalGain
        {
            get { return this.Sum(x => x.Value.Gain); }
        }

        public double TotalCover
        {
            get { return this.Sum(x => x.Value.Cover); }
        }

        public double TotalFScore
        {
            get { return this.Sum(x => x.Value.FScore); }
        }

        public FeatureInteractions() : base()
        {
        }

        /// <summary>Copies the given key/value pairs into a new collection.</summary>
        public FeatureInteractions(IEnumerable<KeyValuePair<string, FeatureInteraction>> featureInteractions) : base()
        {
            foreach (var fi in featureInteractions)
            {
                Add(fi.Key, fi.Value);
            }
        }

        /// <summary>
        /// Folds <paramref name="interactions"/> into this collection. Unknown keys are
        /// added by reference; known keys have their sums accumulated and their
        /// averages recomputed from the new sums.
        /// </summary>
        // NOTE(review): HasLeafStatistics is not merged, so leaf sums added to an entry
        // whose flag is false never reach the "Leaf Statistics" sheet — confirm intent.
        public void Merge(FeatureInteractions interactions)
        {
            foreach (KeyValuePair<string, FeatureInteraction> fi in interactions)
            {
                FeatureInteraction target;
                if (!TryGetValue(fi.Key, out target))
                {
                    Add(fi.Key, fi.Value);
                    continue;
                }

                FeatureInteraction source = fi.Value;

                target.Gain += source.Gain;
                target.Cover += source.Cover;
                target.FScore += source.FScore;
                target.FScoreWeighted += source.FScoreWeighted;
                target.AverageFScoreWeighted = target.FScoreWeighted / target.FScore;
                target.AverageGain = target.Gain / target.FScore;
                target.ExpectedGain += source.ExpectedGain;
                target.SumLeafCoversLeft += source.SumLeafCoversLeft;
                target.SumLeafCoversRight += source.SumLeafCoversRight;
                target.SumLeafValuesLeft += source.SumLeafValuesLeft;
                target.SumLeafValuesRight += source.SumLeafValuesRight;
                target.TreeIndex += source.TreeIndex;
                target.AverageTreeIndex = target.TreeIndex / target.FScore;
                target.TreeDepth += source.TreeDepth;
                target.AverageTreeDepth = target.TreeDepth / target.FScore;
                target.SplitValueHistogram.Merge(source.SplitValueHistogram);
            }
        }

        /// <summary>Returns a new collection holding only the entries with the given Depth.</summary>
        public FeatureInteractions GetFeatureInteractionsOfDepth(int depth)
        {
            // One copy is enough; the original wrapped the result in a second, identical copy.
            return new FeatureInteractions(this.Where(x => x.Value.Depth == depth));
        }

        /// <summary>Returns a new collection holding only the entries flagged HasLeafStatistics.</summary>
        public FeatureInteractions GetFeatureInteractionsWithLeafStatistics()
        {
            return new FeatureInteractions(this.Where(x => x.Value.HasLeafStatistics == true));
        }

        /// <summary>
        /// Writes the collection to an .xlsx workbook: one sheet per interaction depth,
        /// then "Leaf Statistics" and "Split Value Histograms" when there is data for them.
        /// An existing file of the same name is deleted first.
        /// </summary>
        /// <param name="fileName">Target path.</param>
        /// <param name="topK">Maximum rows per depth sheet; values &lt;= 0 mean no limit.</param>
        /// <returns>true on success; false if any exception occurred (its message is printed).</returns>
        public bool WriteToXlsx(string fileName, int topK)
        {
            try
            {
                if (File.Exists(fileName))
                {
                    File.Delete(fileName);
                }

                // ExcelPackage is IDisposable; the original never released it.
                using (ExcelPackage pck = new ExcelPackage(new FileInfo(fileName)))
                {
                    Console.ResetColor();
                    Console.WriteLine("Writing {0}", fileName);

                    WriteInteractionSheets(pck, topK);
                    WriteLeafStatisticsSheet(pck);
                    WriteSplitValueHistogramSheet(pck);

                    pck.Save();
                }

                Console.ForegroundColor = ConsoleColor.Green;
                Console.WriteLine("{0} has been written.", fileName);
                Console.ResetColor();
                return true;
            }
            catch (Exception e)
            {
                Console.ForegroundColor = ConsoleColor.Yellow;
                Console.WriteLine("ERROR: {0}", e.Message);
                Console.ResetColor();
                return false;
            }
        }

        /// <summary>Bold, centered, 20-unit-high header row; optionally centers column 1 as well.</summary>
        private static void StyleHeader(ExcelWorksheet ws, bool centerFirstColumn)
        {
            ws.Row(1).Height = 20;
            ws.Row(1).Style.Font.Bold = true;
            ws.Row(1).Style.VerticalAlignment = OfficeOpenXml.Style.ExcelVerticalAlignment.Center;
            ws.Row(1).Style.HorizontalAlignment = OfficeOpenXml.Style.ExcelHorizontalAlignment.Center;

            if (centerFirstColumn)
            {
                ws.Column(1).Style.VerticalAlignment = OfficeOpenXml.Style.ExcelVerticalAlignment.Center;
                ws.Column(1).Style.HorizontalAlignment = OfficeOpenXml.Style.ExcelHorizontalAlignment.Center;
            }
        }

        private static void WriteStatus(string message)
        {
            Console.ForegroundColor = ConsoleColor.DarkGreen;
            Console.WriteLine(message);
            Console.ResetColor();
        }

        /// <summary>One "Interaction Depth N" sheet per depth, stopping at the first depth with no rows.</summary>
        private void WriteInteractionSheets(ExcelPackage pck, int topK)
        {
            string[] headers =
            {
                "Interaction", "Gain", "FScore", "wFScore", "Average wFScore", "Average Gain",
                "Expected Gain", "Gain Rank", "FScore Rank", "wFScore Rank", "Avg wFScore Rank",
                "Avg Gain Rank", "Expected Gain Rank", "Average Rank", "Average Tree Index",
                "Average Tree Depth"
            };
            // Widths of columns 2..16; column 1 is sized from the longest name.
            double[] widths = { 17, 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 19, 17, 19, 19 };

            Func<double, double> formatNumber = x => Double.Parse(String.Format("{0:0.00}", x));

            int maxDepth = MaxDepth;
            for (int depth = 0; depth <= maxDepth; depth++)
            {
                WriteStatus(String.Format("Writing feature interactions with depth {0} ", depth));

                List<FeatureInteraction> interactions = GetFeatureInteractionsOfDepth(depth).Values.ToList();
                interactions.Sort();

                if (topK > 0)
                {
                    interactions = interactions.Take(topK).ToList();
                }

                if (interactions.Count == 0)
                {
                    break;
                }

                var ws = pck.Workbook.Worksheets.Add(String.Format("Interaction Depth {0}", depth));
                StyleHeader(ws, true);

                ws.Column(1).Width = interactions.Max(x => x.Name.Length) + 10;
                for (int i = 0; i < widths.Length; i++)
                {
                    ws.Column(i + 2).Width = widths[i];
                }
                for (int i = 0; i < headers.Length; i++)
                {
                    ws.Cells[1, i + 1].Value = headers[i];
                }

                // Ranks are positions within the rows actually written (after topK).
                var gainSorted = interactions.OrderBy(x => -x.Gain).ToList();
                var fScoreSorted = interactions.OrderBy(x => -x.FScore).ToList();
                var fScoreWeightedSorted = interactions.OrderBy(x => -x.FScoreWeighted).ToList();
                var averagefScoreWeightedSorted = interactions.OrderBy(x => -x.AverageFScoreWeighted).ToList();
                var averageGainSorted = interactions.OrderBy(x => -x.AverageGain).ToList();
                var expectedGainSorted = interactions.OrderBy(x => -x.ExpectedGain).ToList();

                List<object[]> excelData = new List<object[]>();

                foreach (FeatureInteraction fi in interactions)
                {
                    List<object> rowValues = new List<object>();

                    rowValues.Add(fi.Name);                                                             //1
                    rowValues.Add(fi.Gain);                                                             //2
                    rowValues.Add(fi.FScore);                                                           //3
                    rowValues.Add(fi.FScoreWeighted);                                                   //4
                    rowValues.Add(fi.AverageFScoreWeighted);                                            //5
                    rowValues.Add(fi.AverageGain);                                                      //6
                    rowValues.Add(fi.ExpectedGain);                                                     //7
                    rowValues.Add(gainSorted.FindIndex(x => x.Name == fi.Name) + 1);                    //8
                    rowValues.Add(fScoreSorted.FindIndex(x => x.Name == fi.Name) + 1);                  //9
                    rowValues.Add(fScoreWeightedSorted.FindIndex(x => x.Name == fi.Name) + 1);          //10
                    rowValues.Add(averagefScoreWeightedSorted.FindIndex(x => x.Name == fi.Name) + 1);   //11
                    rowValues.Add(averageGainSorted.FindIndex(x => x.Name == fi.Name) + 1);             //12
                    rowValues.Add(expectedGainSorted.FindIndex(x => x.Name == fi.Name) + 1);            //13
                    // Mean of the six rank columns just added (entries 8..13).
                    rowValues.Add(formatNumber(rowValues.Skip(7).Average(x => Double.Parse(x.ToString())))); //14
                    rowValues.Add(formatNumber(fi.AverageTreeIndex));                                   //15
                    rowValues.Add(formatNumber(fi.AverageTreeDepth));                                   //16

                    excelData.Add(rowValues.ToArray());
                }
                ws.Cells["A2"].LoadFromArrays(excelData);
            }
        }

        /// <summary>"Leaf Statistics" sheet for every entry flagged HasLeafStatistics; skipped when there are none.</summary>
        private void WriteLeafStatisticsSheet(ExcelPackage pck)
        {
            List<FeatureInteraction> interactions = GetFeatureInteractionsWithLeafStatistics().Values.ToList();
            if (interactions.Count == 0)
            {
                return;
            }

            WriteStatus(String.Format("Writing leaf statistics"));
            interactions.Sort();

            var ws = pck.Workbook.Worksheets.Add(String.Format("Leaf Statistics"));
            StyleHeader(ws, true);

            ws.Column(1).Width = interactions.Max(x => x.Name.Length) + 10;
            ws.Column(2).Width = 20;
            ws.Column(3).Width = 20;
            ws.Column(4).Width = 20;
            ws.Column(5).Width = 20;

            ws.Cells[1, 1].Value = "Interaction";
            ws.Cells[1, 2].Value = "Sum Leaf Values Left";
            ws.Cells[1, 3].Value = "Sum Leaf Values Right";
            ws.Cells[1, 4].Value = "Sum Leaf Covers Left";
            ws.Cells[1, 5].Value = "Sum Leaf Covers Right";

            List<object[]> excelData = new List<object[]>();
            foreach (FeatureInteraction fi in interactions)
            {
                excelData.Add(new object[]
                {
                    fi.Name,                //1
                    fi.SumLeafValuesLeft,   //2
                    fi.SumLeafValuesRight,  //3
                    fi.SumLeafCoversLeft,   //4
                    fi.SumLeafCoversRight   //5
                });
            }
            ws.Cells["A2"].LoadFromArrays(excelData);
        }

        /// <summary>
        /// "Split Value Histograms" sheet: two columns (key, value) per depth-0 feature,
        /// for at most GlobalSettings.MaxHistograms features; skipped when there are none.
        /// </summary>
        private void WriteSplitValueHistogramSheet(ExcelPackage pck)
        {
            List<FeatureInteraction> interactions = GetFeatureInteractionsOfDepth(0).Values.ToList();
            if (interactions.Count == 0)
            {
                return;
            }

            WriteStatus(String.Format("Writing split value histograms"));
            interactions.Sort();

            var ws = pck.Workbook.Worksheets.Add(String.Format("Split Value Histograms"));
            StyleHeader(ws, false);

            for (int i = 0; i < interactions.Count; i++)
            {
                // Equality test kept on purpose: a negative MaxHistograms means "no limit".
                if (i == GlobalSettings.MaxHistograms)
                {
                    break;
                }

                var fi = interactions[i];
                int c1 = i * 2 + 1;
                int c2 = c1 + 1;
                ws.Cells[1, c1, 1, c2].Merge = true;
                ws.Cells[1, c1, 1, c2].Value = fi.Name;
                ws.Column(c1).Width = Math.Max(10, (fi.Name.Length + 4) / 2);
                ws.Column(c2).Width = Math.Max(10, (fi.Name.Length + 4) / 2);

                int row = 2;
                foreach (var kvp in fi.SplitValueHistogram)
                {
                    ws.Cells[row, c1].Value = kvp.Key;
                    ws.Cells[row, c2].Value = kvp.Value;
                    row++;
                }
            }
        }
    }
}
--------------------------------------------------------------------------------
/XgbFeatureInteractions/FeatureScoreComparer.cs:
--------------------------------------------------------------------------------
using System;
using
System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace XgbFeatureInteractions
{
    // NOTE(review): FIScoreComparer.cs declares an enum with the same name in the same
    // namespace (its fourth member is AverageFScoreWeighted). Both declarations cannot be
    // compiled into one assembly — confirm which of the two files the .csproj includes.
    public enum SortingMetric
    {
        Gain,
        FScore,
        FScoreWeighted,
        FScoreWeightedAverage,
        AverageGain,
        ExpectedGain
    }

    /// <summary>
    /// Holds a process-wide descending comparison for <see cref="FeatureInteraction"/>;
    /// defaults to <see cref="SortingMetric.Gain"/>.
    /// </summary>
    static class FeatureScoreComparer
    {
        private static Func<FeatureInteraction, FeatureInteraction, int> _comparer { set; get; }

        /// <summary>
        /// Selects the comparer from a metric name (case and spaces ignored):
        /// gain, fscore, wfscore, avgwfscore, avggain, expgain.
        /// Null or unrecognised names leave the current comparer unchanged.
        /// </summary>
        public static void SetComparer(string sortingMetric)
        {
            // Previously an empty stub that silently ignored every name.
            if (sortingMetric == null)
            {
                return;
            }

            switch (sortingMetric.ToLowerInvariant().Replace(" ", ""))
            {
                case "gain":
                    SetComparer(SortingMetric.Gain);
                    break;
                case "fscore":
                    SetComparer(SortingMetric.FScore);
                    break;
                case "wfscore":
                    SetComparer(SortingMetric.FScoreWeighted);
                    break;
                case "avgwfscore":
                    SetComparer(SortingMetric.FScoreWeightedAverage);
                    break;
                case "avggain":
                    SetComparer(SortingMetric.AverageGain);
                    break;
                case "expgain":
                    SetComparer(SortingMetric.ExpectedGain);
                    break;
            }
        }

        /// <summary>Selects the comparer for the given metric; unknown values fall back to Gain.</summary>
        public static void SetComparer(SortingMetric sortingMetric)
        {
            switch (sortingMetric)
            {
                case SortingMetric.FScore:
                    _comparer = (a, b) => -a.FScore.CompareTo(b.FScore);
                    break;
                case SortingMetric.FScoreWeighted:
                    _comparer = (a, b) => -a.FScoreWeighted.CompareTo(b.FScoreWeighted);
                    break;
                case SortingMetric.FScoreWeightedAverage:
                    // FeatureInteraction exposes this value as AverageFScoreWeighted;
                    // the FScoreWeightedAverage property read here before does not exist.
                    _comparer = (a, b) => -a.AverageFScoreWeighted.CompareTo(b.AverageFScoreWeighted);
                    break;
                case SortingMetric.AverageGain:
                    _comparer = (a, b) => -a.AverageGain.CompareTo(b.AverageGain);
                    break;
                case SortingMetric.ExpectedGain:
                    _comparer = (a, b) => -a.ExpectedGain.CompareTo(b.ExpectedGain);
                    break;
                case SortingMetric.Gain:
                default:
                    _comparer = (a, b) => -a.Gain.CompareTo(b.Gain);
                    break;
            }
        }

        static FeatureScoreComparer()
        {
            SetComparer(SortingMetric.Gain);
        }

        /// <summary>Compares two interactions with the currently selected metric (descending).</summary>
        public static int Compare(FeatureInteraction a, FeatureInteraction b)
        {
            return _comparer(a, b);
        }
    }
}
--------------------------------------------------------------------------------
/XgbFeatureInteractions/GlobalSettings.cs:
--------------------------------------------------------------------------------
using System;
using
System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using XgbFeatureInteractions.Properties;

namespace XgbFeatureInteractions
{
    /// <summary>
    /// Mutable, process-wide run settings. Initial values are copied from
    /// Settings.Default the first time the class is touched.
    /// </summary>
    public static class GlobalSettings
    {
        public static string XgbModelFile { get; set; }
        public static string OutputXlsxFile { get; set; }
        public static int MaxInteractionDepth { get; set; }
        public static int TopK { get; set; }
        public static int MaxTrees { get; set; }
        public static int MaxDeepening { get; set; }
        public static string SortBy { get; set; }
        public static int MaxHistograms { get; set; }

        static GlobalSettings()
        {
            // The two path settings have any double quotes removed.
            XgbModelFile = WithoutQuotes(Settings.Default.XgbModelFile);
            OutputXlsxFile = WithoutQuotes(Settings.Default.OutputXlsxFile);

            MaxInteractionDepth = Settings.Default.MaxInteractionDepth;
            TopK = Settings.Default.TopK;
            MaxTrees = Settings.Default.MaxTrees;
            MaxDeepening = Settings.Default.MaxDeepening;
            SortBy = Settings.Default.SortBy;
            MaxHistograms = Settings.Default.MaxHistograms;
        }

        /// <summary>Removes every double-quote character from <paramref name="text"/>.</summary>
        private static string WithoutQuotes(string text)
        {
            return text.Replace("\"", "");
        }
    }
}
--------------------------------------------------------------------------------
/XgbFeatureInteractions/GlobalStats.cs:
--------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace XgbFeatureInteractions
{
    /// <summary>Process-wide counters describing one run.</summary>
    public static class GlobalStats
    {
        /// <summary>Number of trees parsed from the model.</summary>
        public static int ParsedTrees { get; set; }

        /// <summary>Number of feature interactions collected.</summary>
        public static long CollectedFeatureInteractions { get; set; }

        /// <summary>Time spent parsing the model and collecting interactions.</summary>
        public static TimeSpan ElapsedTime { get; set; }

        // NOTE(review): presumably the model dump size in bytes — not assigned anywhere in the visible code.
        public static long ModelFileSize { get; set; }
    }
}
--------------------------------------------------------------------------------
/XgbFeatureInteractions/Program.cs:
-------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Globalization; 4 | using System.Linq; 5 | using System.Text; 6 | using System.Text.RegularExpressions; 7 | using System.Threading; 8 | using System.Threading.Tasks; 9 | using XgbFeatureInteractions.Properties; 10 | 11 | namespace XgbFeatureInteractions 12 | { 13 | class Program 14 | { 15 | 16 | static void Main(string[] args) 17 | { 18 | Console.ForegroundColor = ConsoleColor.Green; 19 | Console.WriteLine("********************************"); 20 | Console.WriteLine("* XGBOOST Feature Interactions *"); 21 | Console.WriteLine("********************************"); 22 | Console.ResetColor(); 23 | 24 | ParseArgs(args); 25 | FIScoreComparer.SetComparer(GlobalSettings.SortBy); 26 | 27 | Console.WriteLine(""); 28 | 29 | Console.WriteLine("Settings:"); 30 | Console.WriteLine("========="); 31 | Console.WriteLine(String.Format("XgbModelFile: {0}", GlobalSettings.XgbModelFile)); 32 | Console.WriteLine(String.Format("OutputXlsxFile: {0}", GlobalSettings.OutputXlsxFile)); 33 | Console.WriteLine(String.Format("MaxInteractionDepth: {0}", GlobalSettings.MaxInteractionDepth)); 34 | Console.WriteLine(String.Format("MaxDeepening: {0}", GlobalSettings.MaxDeepening)); 35 | Console.WriteLine(String.Format("MaxTrees: {0}", GlobalSettings.MaxTrees)); 36 | Console.WriteLine(String.Format("TopK: {0}", GlobalSettings.TopK)); 37 | Console.WriteLine(String.Format("SortBy: {0}", FIScoreComparer.SortBy)); 38 | Console.WriteLine(String.Format("MaxHistograms: {0}", GlobalSettings.MaxHistograms)); 39 | Console.WriteLine(); 40 | 41 | if(args.Length == 0) 42 | { 43 | Thread t1 = new Thread(() => 44 | { 45 | for (int c = 3; c >= 0; c--) 46 | { 47 | Console.Write("\rStarting in {0:00} seconds with the settings above. 
Press any key to abort ...", c); 48 | if (c == 0) break; 49 | for (int i = 0; i < 20; i++) 50 | { 51 | if (Console.KeyAvailable) 52 | { 53 | Console.ReadKey(true); 54 | Console.WriteLine(); 55 | Environment.Exit(0); 56 | } 57 | Thread.Sleep(50); 58 | } 59 | } 60 | 61 | return; 62 | }); 63 | 64 | t1.Start(); 65 | t1.Join(); 66 | Console.WriteLine("\n"); 67 | } 68 | 69 | var start = DateTime.Now; 70 | XgbModel xgbModel = XgbModelParser.GetXgbModelFromFile(GlobalSettings.XgbModelFile, GlobalSettings.MaxTrees); 71 | 72 | if(xgbModel == null) 73 | { 74 | Environment.Exit(-1); 75 | } 76 | 77 | var featureInteractions = xgbModel.GetFeatureInteractions(GlobalSettings.MaxInteractionDepth, GlobalSettings.MaxDeepening); 78 | 79 | var end = DateTime.Now; 80 | GlobalStats.ElapsedTime = (end - start); 81 | GlobalStats.ParsedTrees = xgbModel.NumTrees; 82 | GlobalStats.CollectedFeatureInteractions = featureInteractions.Count; 83 | 84 | featureInteractions.WriteToXlsx(GlobalSettings.OutputXlsxFile, GlobalSettings.TopK); 85 | 86 | end = DateTime.Now; 87 | Console.WriteLine(String.Format("Elapsed Time: {0}", (end - start))); 88 | 89 | } 90 | 91 | static void ParseArgs(string[] args) 92 | { 93 | if (args.Length == 0) 94 | { 95 | PrintHelp(); 96 | return; 97 | } 98 | 99 | string cmds = string.Join(" ", args); 100 | Match m = null; 101 | 102 | if(cmds.Contains("-help")) 103 | { 104 | PrintHelp(); 105 | Environment.Exit(0); 106 | } 107 | 108 | m = Regex.Match(cmds, @"-m\s([^\s]*)"); 109 | if (m.Success) 110 | { 111 | var model_file = m.Groups[1].Value; 112 | GlobalSettings.XgbModelFile = model_file; 113 | } 114 | m = Regex.Match(cmds, @"-o\s([^\s]*)"); 115 | if(m.Success) { 116 | var output_file = m.Groups[1].Value; 117 | if(!output_file.EndsWith(".xlsx")) 118 | { 119 | output_file = output_file + ".xlsx"; 120 | } 121 | GlobalSettings.OutputXlsxFile = output_file; 122 | } 123 | 124 | m = Regex.Match(cmds, @"-d\s([^\s]*)"); 125 | if (m.Success) 126 | { 127 | var max_depth = 
m.Groups[1].Value; 128 | var tmp = 0; 129 | int.TryParse(max_depth, out tmp); 130 | GlobalSettings.MaxInteractionDepth = tmp; 131 | } 132 | m = Regex.Match(cmds, @"-g\s([^\s]*)"); 133 | if (m.Success) 134 | { 135 | var max_deepening = m.Groups[1].Value; 136 | var tmp = 0; 137 | int.TryParse(max_deepening, out tmp); 138 | GlobalSettings.MaxDeepening = tmp; 139 | } 140 | m = Regex.Match(cmds, @"-t\s([^\s]*)"); 141 | if (m.Success) 142 | { 143 | var ntrees = m.Groups[1].Value; 144 | var tmp = 0; 145 | int.TryParse(ntrees, out tmp); 146 | GlobalSettings.MaxTrees = tmp; 147 | } 148 | m = Regex.Match(cmds, @"-k\s([^\s]*)"); 149 | if (m.Success) 150 | { 151 | var k = m.Groups[1].Value; 152 | var tmp = 0; 153 | int.TryParse(k, out tmp); 154 | GlobalSettings.TopK = tmp; 155 | } 156 | m = Regex.Match(cmds, @"-s\s([^\s]*)"); 157 | if (m.Success) 158 | { 159 | GlobalSettings.SortBy = m.Groups[1].Value; 160 | } 161 | m = Regex.Match(cmds, @"-h\s([^\s]*)"); 162 | if (m.Success) 163 | { 164 | var h = m.Groups[1].Value; 165 | var tmp = 0; 166 | int.TryParse(h, out tmp); 167 | GlobalSettings.MaxHistograms = tmp; 168 | } 169 | 170 | return; 171 | } 172 | 173 | static void PrintHelp() 174 | { 175 | Console.ResetColor(); 176 | Console.WriteLine("Usage: XgbFeatureInteractions.exe [Options]"); 177 | Console.WriteLine("\nOptions:"); 178 | Console.WriteLine("\t-m Xgboost model dump (dumped w/ 'with_stats=True')"); 179 | Console.WriteLine("\t-d Upper bound for extracted feature interactions depth"); 180 | Console.WriteLine("\t-g Upper bound for interaction start deepening (zero deepening => interactions starting @root only)"); 181 | Console.WriteLine("\t-t Upper bound for trees to be parsed"); 182 | Console.WriteLine("\t-k Upper bound for exported feature interactions per depth level"); 183 | Console.WriteLine("\t-s Score metric to sort by (Gain, FScore, wFScore, AvgwFScore, AvgGain, ExpGain)"); 184 | Console.WriteLine("\t-o Xlsx file to be written"); 185 | Console.WriteLine("\t-h Amounts 
of split value histograms"); 186 | } 187 | } 188 | } 189 | -------------------------------------------------------------------------------- /XgbFeatureInteractions/Properties/AssemblyInfo.cs: -------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | using System.Runtime.CompilerServices; 3 | using System.Runtime.InteropServices; 4 | 5 | // General Information about an assembly is controlled through the following 6 | // set of attributes. Change these attribute values to modify the information 7 | // associated with an assembly. 8 | [assembly: AssemblyTitle("XgbFeatureInteractions")] 9 | [assembly: AssemblyDescription("XgbFeatureInteractions by Faron")] 10 | [assembly: AssemblyConfiguration("")] 11 | [assembly: AssemblyCompany("")] 12 | [assembly: AssemblyProduct("XgbFeatureInteractions")] 13 | [assembly: AssemblyCopyright("Copyright © 2015")] 14 | [assembly: AssemblyTrademark("")] 15 | [assembly: AssemblyCulture("")] 16 | 17 | // Setting ComVisible to false makes the types in this assembly not visible 18 | // to COM components. If you need to access a type in this assembly from 19 | // COM, set the ComVisible attribute to true on that type. 
20 | [assembly: ComVisible(false)] 21 | 22 | // The following GUID is for the ID of the typelib if this project is exposed to COM 23 | [assembly: Guid("09b47150-eeb5-4416-8d9a-7258cb0c717b")] 24 | 25 | // Version information for an assembly consists of the following four values: 26 | // 27 | // Major Version 28 | // Minor Version 29 | // Build Number 30 | // Revision 31 | // 32 | // You can specify all the values or you can default the Build and Revision Numbers 33 | // by using the '*' as shown below: 34 | // [assembly: AssemblyVersion("1.0.*")] 35 | [assembly: AssemblyVersion("1.0.0.0")] 36 | [assembly: AssemblyFileVersion("1.0.0.0")] 37 | -------------------------------------------------------------------------------- /XgbFeatureInteractions/Properties/Settings.Designer.cs: -------------------------------------------------------------------------------- 1 | //------------------------------------------------------------------------------ 2 | // 3 | // This code was generated by a tool. 4 | // Runtime Version:4.0.30319.42000 5 | // 6 | // Changes to this file may cause incorrect behavior and will be lost if 7 | // the code is regenerated. 
8 | // 9 | //------------------------------------------------------------------------------ 10 | 11 | namespace XgbFeatureInteractions.Properties { 12 | 13 | 14 | [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] 15 | [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.Editors.SettingsDesigner.SettingsSingleFileGenerator", "14.0.0.0")] 16 | internal sealed partial class Settings : global::System.Configuration.ApplicationSettingsBase { 17 | 18 | private static Settings defaultInstance = ((Settings)(global::System.Configuration.ApplicationSettingsBase.Synchronized(new Settings()))); 19 | 20 | public static Settings Default { 21 | get { 22 | return defaultInstance; 23 | } 24 | } 25 | 26 | [global::System.Configuration.ApplicationScopedSettingAttribute()] 27 | [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] 28 | [global::System.Configuration.DefaultSettingValueAttribute("xgb.model")] 29 | public string XgbModelFile { 30 | get { 31 | return ((string)(this["XgbModelFile"])); 32 | } 33 | } 34 | 35 | [global::System.Configuration.ApplicationScopedSettingAttribute()] 36 | [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] 37 | [global::System.Configuration.DefaultSettingValueAttribute("XgbFeatureInteractions.xlsx")] 38 | public string OutputXlsxFile { 39 | get { 40 | return ((string)(this["OutputXlsxFile"])); 41 | } 42 | } 43 | 44 | [global::System.Configuration.ApplicationScopedSettingAttribute()] 45 | [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] 46 | [global::System.Configuration.DefaultSettingValueAttribute("100")] 47 | public int TopK { 48 | get { 49 | return ((int)(this["TopK"])); 50 | } 51 | } 52 | 53 | [global::System.Configuration.ApplicationScopedSettingAttribute()] 54 | [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] 55 | [global::System.Configuration.DefaultSettingValueAttribute("-1")] 56 | public int MaxDeepening { 57 | get { 58 | return 
((int)(this["MaxDeepening"])); 59 | } 60 | } 61 | 62 | [global::System.Configuration.ApplicationScopedSettingAttribute()] 63 | [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] 64 | [global::System.Configuration.DefaultSettingValueAttribute("Gain")] 65 | public string SortBy { 66 | get { 67 | return ((string)(this["SortBy"])); 68 | } 69 | } 70 | 71 | [global::System.Configuration.ApplicationScopedSettingAttribute()] 72 | [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] 73 | [global::System.Configuration.DefaultSettingValueAttribute("-1")] 74 | public int MaxTrees { 75 | get { 76 | return ((int)(this["MaxTrees"])); 77 | } 78 | } 79 | 80 | [global::System.Configuration.ApplicationScopedSettingAttribute()] 81 | [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] 82 | [global::System.Configuration.DefaultSettingValueAttribute("2")] 83 | public int MaxInteractionDepth { 84 | get { 85 | return ((int)(this["MaxInteractionDepth"])); 86 | } 87 | } 88 | 89 | [global::System.Configuration.ApplicationScopedSettingAttribute()] 90 | [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] 91 | [global::System.Configuration.DefaultSettingValueAttribute("10")] 92 | public int MaxHistograms { 93 | get { 94 | return ((int)(this["MaxHistograms"])); 95 | } 96 | } 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /XgbFeatureInteractions/Properties/Settings.settings: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | 6 | xgb.model 7 | 8 | 9 | XgbFeatureInteractions.xlsx 10 | 11 | 12 | 100 13 | 14 | 15 | -1 16 | 17 | 18 | Gain 19 | 20 | 21 | -1 22 | 23 | 24 | 2 25 | 26 | 27 | 10 28 | 29 | 30 | -------------------------------------------------------------------------------- /XgbFeatureInteractions/SplitValueHistogram.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 
| using System.Linq; 4 | using System.Text; 5 | 6 | namespace XgbFeatureInteractions 7 | { 8 | public class SplitValueHistogram : SortedDictionary 9 | { 10 | public SplitValueHistogram() : base() 11 | { 12 | 13 | } 14 | 15 | public void AddValue(double splitValue, double count=1) 16 | { 17 | if(!this.ContainsKey(splitValue)){ 18 | this.Add(splitValue, 0); 19 | } 20 | this[splitValue] += count; 21 | } 22 | 23 | public void Merge(SortedDictionary histogram) 24 | { 25 | foreach(var kvp in histogram) 26 | { 27 | this.AddValue(kvp.Key, kvp.Value); 28 | } 29 | } 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /XgbFeatureInteractions/XgbFeatureInteractions.csproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | Debug 6 | AnyCPU 7 | {09B47150-EEB5-4416-8D9A-7258CB0C717B} 8 | Exe 9 | Properties 10 | XgbFeatureInteractions 11 | XgbFeatureInteractions 12 | v4.5.2 13 | 512 14 | true 15 | 16 | 17 | AnyCPU 18 | true 19 | full 20 | false 21 | bin\Debug\ 22 | DEBUG;TRACE 23 | prompt 24 | 4 25 | 26 | 27 | x64 28 | pdbonly 29 | true 30 | bin\Release\ 31 | TRACE 32 | prompt 33 | 4 34 | 35 | 36 | 37 | ..\packages\EPPlus.4.0.4\lib\net20\EPPlus.dll 38 | True 39 | 40 | 41 | ..\packages\NGenerics.1.4.1.0\lib\net35\NGenerics.dll 42 | True 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | True 62 | True 63 | Settings.settings 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | SettingsSingleFileGenerator 77 | Settings.Designer.cs 78 | 79 | 80 | 81 | 82 | rd /S /Q md $(TargetDir)lib 83 | 84 | 85 | md $(TargetDir)lib 86 | move /Y "$(TargetDir)*.dll" "$(TargetDir)lib" 87 | 88 | 95 | -------------------------------------------------------------------------------- /XgbFeatureInteractions/XgbModel.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 
3 | using System.IO; 4 | using System.Linq; 5 | using System.Text; 6 | using System.Threading.Tasks; 7 | using XgbFeatureInteractions.Properties; 8 | 9 | namespace XgbFeatureInteractions 10 | { 11 | public class XgbModel 12 | { 13 | public List XgbTrees { get; set; } 14 | private FeatureInteractions _treeFeatureInteractions { get; set; } 15 | private int _maxInteractionDepth { get; set; } 16 | private HashSet _pathMemo { get; set; } 17 | private double _treeIndex { get; set; } 18 | private int _maxDeepening { get; set; } 19 | public int NumTrees 20 | { 21 | get { return XgbTrees.Count; } 22 | } 23 | public XgbModel() 24 | { 25 | XgbTrees = new List(); 26 | } 27 | public FeatureInteractions GetFeatureInteractions(int maxInteractionDepth = -1, int maxDeepening = -1) 28 | { 29 | FeatureInteractions xgbFeatureInteractions = new FeatureInteractions(); 30 | _maxInteractionDepth = maxInteractionDepth; 31 | _maxDeepening = maxDeepening; 32 | 33 | Console.ResetColor(); 34 | if(_maxInteractionDepth == -1) 35 | Console.WriteLine(String.Format("Collecting feature interactions")); 36 | else 37 | Console.WriteLine(String.Format("Collecting feature interactions up to depth {0}", _maxInteractionDepth)); 38 | 39 | for (int i = 0; i < NumTrees; i++) 40 | { 41 | Console.ForegroundColor = ConsoleColor.DarkGreen; 42 | Console.Write(String.Format("Collecting feature interactions within tree #{0} ", i+1)); 43 | 44 | _treeFeatureInteractions = new FeatureInteractions(); 45 | _pathMemo = new HashSet(); 46 | _treeIndex = i; 47 | CollectFeatureInteractions(XgbTrees[i], new HashSet() , currentGain: 0, currentCover: 0, pathProbability: 1, depth: 0, deepening: 0); 48 | 49 | //double treeGain = _treeFeatureInteractions.GetFeatureInteractionsOfDepth(0).Sum(x => x.Value.Gain); 50 | //foreach (KeyValuePair fi in _treeFeatureInteractions) 51 | //{ 52 | // fi.Value.Gain /= treeGain; 53 | //} 54 | 55 | Console.WriteLine(String.Format("=> number of interactions: {0}", 
_treeFeatureInteractions.Count)); 56 | Console.ResetColor(); 57 | xgbFeatureInteractions.Merge(_treeFeatureInteractions); 58 | } 59 | 60 | Console.ForegroundColor = ConsoleColor.Green; 61 | Console.WriteLine(String.Format("{0} feature interactions has been collected.\n", xgbFeatureInteractions.Count)); 62 | Console.ResetColor(); 63 | 64 | return xgbFeatureInteractions; 65 | } 66 | private void CollectFeatureInteractions(XgbTree tree, HashSet currentInteraction, double currentGain, double currentCover, double pathProbability, int depth, int deepening) 67 | { 68 | if (tree.IsLeafNode) 69 | { 70 | return; 71 | } 72 | 73 | currentInteraction.Add(tree.Data); 74 | currentGain += tree.Data.Gain; 75 | currentCover += tree.Data.Cover; 76 | 77 | var pathProbabilityLeft = pathProbability * (((XgbTree)tree.Left).Data.Cover / tree.Data.Cover); 78 | var pathProbabilityRight = pathProbability * (((XgbTree)tree.Right).Data.Cover / tree.Data.Cover); 79 | 80 | var fi = new FeatureInteraction(currentInteraction, currentGain, currentCover, pathProbability, depth, _treeIndex, 1); 81 | 82 | if (depth < _maxDeepening || _maxDeepening < 0) 83 | { 84 | var newInteractionLeft = new HashSet() { }; 85 | var newInteractionRight = new HashSet() { }; 86 | 87 | CollectFeatureInteractions((XgbTree)tree.Left, newInteractionLeft, 0, 0, pathProbabilityLeft, depth + 1, deepening + 1); 88 | CollectFeatureInteractions((XgbTree)tree.Right, newInteractionRight, 0, 0, pathProbabilityRight, depth + 1, deepening + 1); 89 | } 90 | 91 | var path = string.Join("-", currentInteraction.Select(x => x.Number)); 92 | 93 | if (!_treeFeatureInteractions.ContainsKey(fi.Name)) 94 | { 95 | _treeFeatureInteractions.Add(fi.Name, fi); 96 | _pathMemo.Add(path); 97 | } 98 | else 99 | { 100 | if(_pathMemo.Contains(path)) 101 | { 102 | return; 103 | } 104 | 105 | _pathMemo.Add(path); 106 | var tfi = _treeFeatureInteractions[fi.Name]; 107 | tfi.Gain += currentGain; 108 | tfi.Cover += currentCover; 109 | tfi.FScore += 1; 110 | 
tfi.FScoreWeighted += pathProbability; 111 | tfi.AverageFScoreWeighted = tfi.FScoreWeighted / tfi.FScore; 112 | tfi.AverageGain = tfi.Gain / tfi.FScore; 113 | tfi.ExpectedGain += currentGain * pathProbability; 114 | tfi.TreeDepth += depth; 115 | tfi.AverageTreeDepth = tfi.TreeDepth / tfi.FScore; 116 | tfi.TreeIndex += _treeIndex; 117 | tfi.AverageTreeIndex = tfi.TreeIndex / tfi.FScore; 118 | tfi.SplitValueHistogram.Merge(fi.SplitValueHistogram); 119 | } 120 | 121 | if (currentInteraction.Count - 1 == _maxInteractionDepth) 122 | return; 123 | 124 | 125 | var currentInteractionLeft = new HashSet(currentInteraction); 126 | var currentInteractionRight = new HashSet(currentInteraction); 127 | 128 | var leftTree = (XgbTree)(tree.Left); 129 | var rightTree = (XgbTree)(tree.Right); 130 | 131 | if (leftTree.IsLeafNode && deepening == 0) 132 | { 133 | var tfi = _treeFeatureInteractions[fi.Name]; 134 | tfi.SumLeafValuesLeft += leftTree.Data.LeafValue; 135 | tfi.SumLeafCoversLeft += leftTree.Data.Cover; 136 | tfi.HasLeafStatistics = true; 137 | } 138 | 139 | if (rightTree.IsLeafNode && deepening == 0) 140 | { 141 | var tfi = _treeFeatureInteractions[fi.Name]; 142 | tfi.SumLeafValuesRight += rightTree.Data.LeafValue; 143 | tfi.SumLeafCoversRight += rightTree.Data.Cover; 144 | tfi.HasLeafStatistics = true; 145 | } 146 | 147 | CollectFeatureInteractions((XgbTree)tree.Left, currentInteractionLeft, currentGain, currentCover, pathProbabilityLeft, depth + 1, deepening); 148 | CollectFeatureInteractions((XgbTree)tree.Right, currentInteractionRight, currentGain, currentCover, pathProbabilityRight, depth + 1, deepening); 149 | 150 | } 151 | 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /XgbFeatureInteractions/XgbModelParser.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Globalization; 4 | using System.IO; 5 | using 
System.Linq; 6 | using System.Text; 7 | using System.Text.RegularExpressions; 8 | using System.Threading.Tasks; 9 | using NGenerics.DataStructures.Trees; 10 | using XgbFeatureInteractions.Properties; 11 | 12 | namespace XgbFeatureInteractions 13 | { 14 | public static class XgbModelParser 15 | { 16 | private static Dictionary xgbNodeList = new Dictionary(); 17 | private static Regex nodeRegex = new Regex(@"(\d+):\[(.*)<(.+)\]\syes=(.*),no=(.*),missing=.*,gain=(.*),cover=(.*)", RegexOptions.Compiled); 18 | private static Regex leafRegex = new Regex(@"(\d+):leaf=(.*),cover=(.*)", RegexOptions.Compiled); 19 | 20 | public static XgbModel GetXgbModelFromFile(string fileName, int maxTrees) 21 | { 22 | XgbModel xgbModel = new XgbModel(); 23 | 24 | if (!File.Exists(fileName)) 25 | { 26 | Console.ForegroundColor = ConsoleColor.Yellow; 27 | Console.WriteLine(String.Format("Error: File {0} does not exist.",fileName)); 28 | Console.ResetColor(); 29 | return null; 30 | } 31 | 32 | 33 | Console.ResetColor(); 34 | Console.WriteLine(String.Format("Parsing {0}", fileName)); 35 | 36 | int numTree = 0; 37 | 38 | var fileInfo = new FileInfo(fileName); 39 | GlobalStats.ModelFileSize = fileInfo.Length; 40 | 41 | using (StreamReader sr = new StreamReader(fileName)) 42 | { 43 | while(!sr.EndOfStream) { 44 | 45 | 46 | var line = sr.ReadLine().Trim(); 47 | if (line.StartsWith("booster") || line == String.Empty) 48 | { 49 | if (xgbNodeList.Count > 0) 50 | { 51 | numTree++; 52 | Console.ForegroundColor = ConsoleColor.DarkGreen; 53 | Console.Write(String.Format("Constructing tree #{0} ", numTree)); 54 | 55 | 56 | XgbTree tree = new XgbTree(xgbNodeList[0]); 57 | ConstructXgbTree(tree); 58 | 59 | Console.WriteLine(String.Format("=> depth: {0} ({1} nodes)", tree.Height, xgbNodeList.Count)); 60 | Console.ResetColor(); 61 | 62 | xgbModel.XgbTrees.Add(tree); 63 | xgbNodeList.Clear(); 64 | if (numTree == maxTrees) break; 65 | } 66 | } 67 | else 68 | { 69 | var node = ParseXgbTreeNode(line); 70 | if 
(node == null) 71 | { 72 | return null; 73 | } 74 | xgbNodeList.Add(node.Number, node); 75 | } 76 | } 77 | } 78 | if (xgbNodeList.Count > 0 && (maxTrees < 0 || numTree < maxTrees)) 79 | { 80 | numTree++; 81 | Console.ForegroundColor = ConsoleColor.DarkGreen; 82 | Console.Write(String.Format("Constructing tree #{0} ", numTree)); 83 | 84 | 85 | XgbTree tree = new XgbTree(xgbNodeList[0]); 86 | ConstructXgbTree(tree); 87 | 88 | Console.WriteLine(String.Format("=> depth: {0} ({1} nodes)", tree.Height, xgbNodeList.Count)); 89 | Console.ResetColor(); 90 | 91 | xgbModel.XgbTrees.Add(tree); 92 | xgbNodeList.Clear(); 93 | } 94 | Console.ForegroundColor = ConsoleColor.Green; 95 | Console.WriteLine(String.Format("{0} trees has been constructed.\n", xgbModel.NumTrees)); 96 | Console.ResetColor(); 97 | return xgbModel; 98 | } 99 | 100 | private static XgbTreeNode ParseXgbTreeNode(string line) 101 | { 102 | var node = new XgbTreeNode(); 103 | try { 104 | if (line.Contains("leaf")) 105 | { 106 | Match m = leafRegex.Match(line); 107 | node.Number = Int32.Parse(m.Groups[1].Value); 108 | node.LeafValue = Double.Parse(m.Groups[2].Value, CultureInfo.InvariantCulture); 109 | node.Cover = Double.Parse(m.Groups[3].Value, CultureInfo.InvariantCulture); 110 | node.IsLeaf = true; 111 | } 112 | else 113 | { 114 | Match m = nodeRegex.Match(line); 115 | node.Number = Int32.Parse(m.Groups[1].Value); 116 | node.Feature = m.Groups[2].Value; 117 | node.SplitValue = Double.Parse(m.Groups[3].Value, CultureInfo.InvariantCulture); 118 | node.LeftChild = Int32.Parse(m.Groups[4].Value); 119 | node.RightChild = Int32.Parse(m.Groups[5].Value); 120 | node.Gain = Double.Parse(m.Groups[6].Value, CultureInfo.InvariantCulture); 121 | node.Cover = Double.Parse(m.Groups[7].Value, CultureInfo.InvariantCulture); 122 | node.IsLeaf = false; 123 | } 124 | } catch(Exception e) 125 | { 126 | Console.ForegroundColor = ConsoleColor.Yellow; 127 | Console.WriteLine(String.Format("Error: Invalid model file. 
Did you dump the model w/ with_stats=True?")); 128 | Console.WriteLine(String.Format("Unable to parse line '{0}'", line)); 129 | Console.WriteLine(e.Message); 130 | Console.ResetColor(); 131 | return null; 132 | } 133 | 134 | return node; 135 | } 136 | 137 | private static void ConstructXgbTree(XgbTree tree) 138 | { 139 | if (tree.Data.LeftChild != null) 140 | { 141 | tree.Add(new XgbTree(xgbNodeList[(int)tree.Data.LeftChild])); 142 | ConstructXgbTree((XgbTree)tree.Left); 143 | } 144 | 145 | 146 | if (tree.Data.RightChild != null) 147 | { 148 | tree.Add(new XgbTree(xgbNodeList[(int)tree.Data.RightChild])); 149 | ConstructXgbTree((XgbTree)tree.Right); 150 | } 151 | 152 | } 153 | 154 | 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /XgbFeatureInteractions/XgbTree.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Linq; 4 | using System.Text; 5 | using System.Threading.Tasks; 6 | using NGenerics.DataStructures.Trees; 7 | 8 | namespace XgbFeatureInteractions 9 | { 10 | public class XgbTree : BinaryTree 11 | { 12 | public XgbTree(XgbTreeNode root) : base(root) 13 | { 14 | 15 | } 16 | 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /XgbFeatureInteractions/XgbTreeNode.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Linq; 4 | using System.Text; 5 | using System.Threading.Tasks; 6 | 7 | namespace XgbFeatureInteractions 8 | { 9 | public class XgbTreeNode 10 | { 11 | public int Number { get; set; } 12 | public string Feature { get; set; } 13 | public double Gain { get; set; } 14 | public double Cover { get; set; } 15 | public int? LeftChild { get; set; } 16 | public int? 
RightChild { get; set; } 17 | public bool IsLeaf { get; set; } 18 | public double SplitValue { get; set; } 19 | public double LeafValue { get; set; } 20 | 21 | 22 | public XgbTreeNode() 23 | { 24 | Feature = String.Empty; 25 | Gain = 0; 26 | Cover = 0; 27 | Number = -1; 28 | LeftChild = null; 29 | RightChild = null; 30 | LeafValue = 0; 31 | SplitValue = 0; 32 | IsLeaf = false; 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /XgbFeatureInteractions/packages.config: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /bin/XgbFeatureInteractions.exe: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Far0n/xgbfi/f779f7c7042557bac2289e2f40ea7e9475dc8473/bin/XgbFeatureInteractions.exe -------------------------------------------------------------------------------- /bin/XgbFeatureInteractions.exe.config: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 |
6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | xgb.dump 20 | 21 | 22 | XgbFeatureInteractions.xlsx 23 | 24 | 25 | 100 26 | 27 | 28 | Gain 29 | 30 | 31 | -1 32 | 33 | 34 | 100 35 | 36 | 37 | 2 38 | 39 | 40 | 10 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /bin/lib/EPPlus.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Far0n/xgbfi/f779f7c7042557bac2289e2f40ea7e9475dc8473/bin/lib/EPPlus.dll -------------------------------------------------------------------------------- /bin/lib/NGenerics.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Far0n/xgbfi/f779f7c7042557bac2289e2f40ea7e9475dc8473/bin/lib/NGenerics.dll -------------------------------------------------------------------------------- /doc/ScoresExample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Far0n/xgbfi/f779f7c7042557bac2289e2f40ea7e9475dc8473/doc/ScoresExample.png -------------------------------------------------------------------------------- /doc/ScoresExample_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Far0n/xgbfi/f779f7c7042557bac2289e2f40ea7e9475dc8473/doc/ScoresExample_small.png -------------------------------------------------------------------------------- /xgbfi_cc/xgbfi.cc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Far0n/xgbfi/f779f7c7042557bac2289e2f40ea7e9475dc8473/xgbfi_cc/xgbfi.cc -------------------------------------------------------------------------------- /xgbfi_cc/xgbfi.h: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Far0n/xgbfi/f779f7c7042557bac2289e2f40ea7e9475dc8473/xgbfi_cc/xgbfi.h --------------------------------------------------------------------------------