├── PdfPigMLNetBlockClassifier ├── sample.pdf ├── model │ ├── model.zip │ ├── modelV2.zip │ └── results.txt ├── PdfPigMLNetBlockClassifier.csproj └── Program.cs ├── PdfPigMLNetBlockClassifier.Data ├── PdfPigMLNetBlockClassifier.Data.csproj ├── v1 │ ├── FeatureHelper.cs │ └── DataGenerator.cs └── v2 │ ├── FeatureHelper.cs │ └── DataGenerator.cs ├── PdfPigMLNetBlockClassifier.LightGbm ├── BlockCategory.cs ├── PdfPigMLNetBlockClassifier.LightGbm.csproj ├── BlockFeatures.cs ├── LightGbmBlockClassifier.cs ├── README.md └── LightGbmModelBuilder.cs ├── PdfPigMLNetBlockClassifier.LightGbmV2 ├── BlockCategory.cs ├── PdfPigMLNetBlockClassifier.LightGbmV2.csproj ├── BlockFeatures.cs ├── LightGbmBlockClassifier.cs └── LightGbmModelBuilder.cs ├── .gitattributes ├── PdfPigMLNetBlockClassifier.sln ├── .gitignore └── README.md /PdfPigMLNetBlockClassifier/sample.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BobLd/PdfPigMLNetBlockClassifier/HEAD/PdfPigMLNetBlockClassifier/sample.pdf -------------------------------------------------------------------------------- /PdfPigMLNetBlockClassifier/model/model.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BobLd/PdfPigMLNetBlockClassifier/HEAD/PdfPigMLNetBlockClassifier/model/model.zip -------------------------------------------------------------------------------- /PdfPigMLNetBlockClassifier/model/modelV2.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BobLd/PdfPigMLNetBlockClassifier/HEAD/PdfPigMLNetBlockClassifier/model/modelV2.zip -------------------------------------------------------------------------------- /PdfPigMLNetBlockClassifier.Data/PdfPigMLNetBlockClassifier.Data.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | netstandard2.0 5 | 6 | 7 | 8 | 9 | 10 | 11 | 
12 | -------------------------------------------------------------------------------- /PdfPigMLNetBlockClassifier.LightGbm/BlockCategory.cs: -------------------------------------------------------------------------------- 1 | using Microsoft.ML.Data; 2 | using System; 3 | 4 | namespace PdfPigMLNetBlockClassifier.LightGbm 5 | { 6 | /* Output row of the v1 LightGBM block classifier. LightGbmBlockClassifier.Classify uses (int)Prediction as an index into FeatureHelper.Categories and returns Score.Max() as the block's score. */ public class BlockCategory 7 | { 8 | // ColumnName attribute is used to change the column name from 9 | // its default value, which is the name of the property. 10 | [ColumnName("PredictedLabel")] 11 | public Single Prediction { get; set; } 12 | 13 | public float[] Score { get; set; } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /PdfPigMLNetBlockClassifier.LightGbmV2/BlockCategory.cs: -------------------------------------------------------------------------------- 1 | using Microsoft.ML.Data; 2 | using System; 3 | 4 | namespace PdfPigMLNetBlockClassifier.LightGbmV2 5 | { 6 | /* Output row of the v2 LightGBM block classifier. LightGbmBlockClassifier.Classify uses (int)Prediction as an index into FeatureHelper.Categories and returns Score.Max() as the block's score. */ public class BlockCategory 7 | { 8 | // ColumnName attribute is used to change the column name from 9 | // its default value, which is the name of the property.
10 | [ColumnName("PredictedLabel")] 11 | public Single Prediction { get; set; } 12 | 13 | public float[] Score { get; set; } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /PdfPigMLNetBlockClassifier.LightGbm/PdfPigMLNetBlockClassifier.LightGbm.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | netstandard2.0 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /PdfPigMLNetBlockClassifier.LightGbmV2/PdfPigMLNetBlockClassifier.LightGbmV2.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | netstandard2.0 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /PdfPigMLNetBlockClassifier/PdfPigMLNetBlockClassifier.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Exe 5 | netcoreapp2.1 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | Always 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /PdfPigMLNetBlockClassifier.LightGbm/BlockFeatures.cs: -------------------------------------------------------------------------------- 1 | using Microsoft.ML.Data; 2 | 3 | namespace PdfPigMLNetBlockClassifier.LightGbm 4 | { 5 | /* Input row of the v1 model: 13 block features (LoadColumn 0-12) followed by the training label (LoadColumn 13). LightGbmBlockClassifier fills the features from FeatureHelper.GetFeatures in the same index order and never sets Label. NOTE(review): the ColumnName values presumably must match the column names the saved model was trained with - confirm before renaming. */ public class BlockFeatures 6 | { 7 | [ColumnName("charsCount"), LoadColumn(0)] 8 | public float CharsCount { get; set; } 9 | 10 | 11 | [ColumnName("pctNumericChars"), LoadColumn(1)] 12 | public float PctNumericChars { get; set; } 13 | 14 | 15 | [ColumnName("pctAlphabeticalChars"), LoadColumn(2)] 16 | public float PctAlphabeticalChars { get; set; } 17 | 18 | 19 | [ColumnName("pctSymbolicChars"), LoadColumn(3)] 20 | public
float PctSymbolicChars { get; set; } 21 | 22 | 23 | [ColumnName("pctBulletChars"), LoadColumn(4)] 24 | public float PctBulletChars { get; set; } 25 | 26 | 27 | [ColumnName("deltaToHeight"), LoadColumn(5)] 28 | public float DeltaToHeight { get; set; } 29 | 30 | 31 | [ColumnName("pathsCount"), LoadColumn(6)] 32 | public float PathsCount { get; set; } 33 | 34 | 35 | [ColumnName("pctBezierPaths"), LoadColumn(7)] 36 | public float PctBezierPaths { get; set; } 37 | 38 | 39 | [ColumnName("pctHorPaths"), LoadColumn(8)] 40 | public float PctHorPaths { get; set; } 41 | 42 | 43 | [ColumnName("pctVertPaths"), LoadColumn(9)] 44 | public float PctVertPaths { get; set; } 45 | 46 | 47 | [ColumnName("pctOblPaths"), LoadColumn(10)] 48 | public float PctOblPaths { get; set; } 49 | 50 | 51 | [ColumnName("imagesCount"), LoadColumn(11)] 52 | public float ImagesCount { get; set; } 53 | 54 | 55 | [ColumnName("imageAvgProportion"), LoadColumn(12)] 56 | public float ImageAvgProportion { get; set; } 57 | 58 | 59 | [ColumnName("label"), LoadColumn(13)] 60 | public float Label { get; set; } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /PdfPigMLNetBlockClassifier.LightGbmV2/BlockFeatures.cs: -------------------------------------------------------------------------------- 1 | using Microsoft.ML.Data; 2 | 3 | namespace PdfPigMLNetBlockClassifier.LightGbmV2 4 | { 5 | /* Input row of the v2 model: 17 block features (LoadColumn 0-16) followed by the training label (LoadColumn 17). LightGbmBlockClassifier fills the features from FeatureHelper.GetFeatures in the same index order and never sets Label. NOTE(review): the ColumnName values presumably must match the column names the saved model was trained with - confirm before renaming. */ public class BlockFeatures 6 | { 7 | [ColumnName("blockAspectRatio"), LoadColumn(0)] 8 | public float BlockAspectRatio { get; set; } 9 | 10 | 11 | [ColumnName("charsCount"), LoadColumn(1)] 12 | public float CharsCount { get; set; } 13 | 14 | 15 | [ColumnName("wordsCount"), LoadColumn(2)] 16 | public float WordsCount { get; set; } 17 | 18 | 19 | [ColumnName("linesCount"), LoadColumn(3)] 20 | public float LinesCount { get; set; } 21 | 22 | 23 | [ColumnName("pctNumericChars"), LoadColumn(4)] 24 | public float PctNumericChars { get; set; } 25 | 26 | 27 |
[ColumnName("pctAlphabeticalChars"), LoadColumn(5)] 28 | public float PctAlphabeticalChars { get; set; } 29 | 30 | 31 | [ColumnName("pctSymbolicChars"), LoadColumn(6)] 32 | public float PctSymbolicChars { get; set; } 33 | 34 | 35 | [ColumnName("pctBulletChars"), LoadColumn(7)] 36 | public float PctBulletChars { get; set; } 37 | 38 | 39 | [ColumnName("deltaToHeight"), LoadColumn(8)] 40 | public float DeltaToHeight { get; set; } 41 | 42 | 43 | [ColumnName("pathsCount"), LoadColumn(9)] 44 | public float PathsCount { get; set; } 45 | 46 | 47 | [ColumnName("pctBezierPaths"), LoadColumn(10)] 48 | public float PctBezierPaths { get; set; } 49 | 50 | 51 | [ColumnName("pctHorPaths"), LoadColumn(11)] 52 | public float PctHorPaths { get; set; } 53 | 54 | 55 | [ColumnName("pctVertPaths"), LoadColumn(12)] 56 | public float PctVertPaths { get; set; } 57 | 58 | 59 | [ColumnName("pctOblPaths"), LoadColumn(13)] 60 | public float PctOblPaths { get; set; } 61 | 62 | 63 | [ColumnName("imagesCount"), LoadColumn(14)] 64 | public float ImagesCount { get; set; } 65 | 66 | 67 | [ColumnName("imageAvgProportion"), LoadColumn(15)] 68 | public float ImageAvgProportion { get; set; } 69 | 70 | 71 | [ColumnName("bestNormEditDistance"), LoadColumn(16)] 72 | public float BestNormEditDistance { get; set; } 73 | 74 | 75 | [ColumnName("label"), LoadColumn(17)] 76 | public float Label { get; set; } 77 | 78 | 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Set default behavior to automatically normalize line endings. 3 | ############################################################################### 4 | * text=auto 5 | 6 | ############################################################################### 7 | # Set default behavior for command prompt diff.
8 | # 9 | # This is needed for earlier builds of msysgit that do not have it on by 10 | # default for csharp files. 11 | # Note: This is only used by the command line 12 | ############################################################################### 13 | #*.cs diff=csharp 14 | 15 | ############################################################################### 16 | # Set the merge driver for project and solution files 17 | # 18 | # Merging from the command prompt will add diff markers to the files if there 19 | # are conflicts (Merging from VS is not affected by the settings below; in VS 20 | # the diff markers are never inserted). Diff markers may cause the following 21 | # file extensions to fail to load in VS. An alternative would be to treat 22 | # these files as binary and thus will always conflict and require user 23 | # intervention with every merge. To do so, just uncomment the entries below. 24 | ############################################################################### 25 | #*.sln merge=binary 26 | #*.csproj merge=binary 27 | #*.vbproj merge=binary 28 | #*.vcxproj merge=binary 29 | #*.vcproj merge=binary 30 | #*.dbproj merge=binary 31 | #*.fsproj merge=binary 32 | #*.lsproj merge=binary 33 | #*.wixproj merge=binary 34 | #*.modelproj merge=binary 35 | #*.sqlproj merge=binary 36 | #*.wwaproj merge=binary 37 | 38 | ############################################################################### 39 | # behavior for image files 40 | # 41 | # image files are treated as binary by default. 42 | ############################################################################### 43 | #*.jpg binary 44 | #*.png binary 45 | #*.gif binary 46 | 47 | ############################################################################### 48 | # diff behavior for common document formats 49 | # 50 | # Convert binary document formats to text before diffing them. This feature 51 | # is only available from the command line. Turn it on by uncommenting the 52 | # entries below.
53 | ############################################################################### 54 | #*.doc diff=astextplain 55 | #*.DOC diff=astextplain 56 | #*.docx diff=astextplain 57 | #*.DOCX diff=astextplain 58 | #*.dot diff=astextplain 59 | #*.DOT diff=astextplain 60 | #*.pdf diff=astextplain 61 | #*.PDF diff=astextplain 62 | #*.rtf diff=astextplain 63 | #*.RTF diff=astextplain 64 | -------------------------------------------------------------------------------- /PdfPigMLNetBlockClassifier.LightGbm/LightGbmBlockClassifier.cs: -------------------------------------------------------------------------------- 
1 | using Microsoft.ML;
2 | using PdfPigMLNetBlockClassifier.Data.v1;
3 | using System;
4 | using System.Collections.Generic;
5 | using System.Linq;
6 | using UglyToad.PdfPig.Content;
7 | using UglyToad.PdfPig.DocumentLayoutAnalysis;
8 | using UglyToad.PdfPig.DocumentLayoutAnalysis.PageSegmenter;
9 | using UglyToad.PdfPig.Util;
10 | 
11 | namespace PdfPigMLNetBlockClassifier.LightGbm
12 | {
13 |     /// <summary>
14 |     /// Classifies the text blocks of a PDF page with a trained ML.NET LightGBM model,
15 |     /// using the v1 block features computed by <see cref="FeatureHelper"/>.
16 |     /// </summary>
17 |     /// <remarks>
18 |     /// NOTE(review): ML.NET's PredictionEngine is not thread-safe, so an instance of this
19 |     /// class should not be shared between threads without external synchronization.
20 |     /// </remarks>
21 |     public class LightGbmBlockClassifier
22 |     {
23 |         private readonly MLContext mlContext;
24 |         private readonly ITransformer mlModel;
25 |         private readonly PredictionEngine<BlockFeatures, BlockCategory> predEngine;
26 | 
27 |         /// <summary>
28 |         /// Loads the model saved at <paramref name="modelPath"/> and creates a prediction engine for it.
29 |         /// </summary>
30 |         /// <param name="modelPath">Path to the saved ML.NET model (e.g. model.zip).</param>
31 |         public LightGbmBlockClassifier(string modelPath)
32 |         {
33 |             mlContext = new MLContext();
34 |             // The model's input schema is not needed here: BlockFeatures already describes the input row.
35 |             mlModel = mlContext.Model.Load(modelPath, out _);
36 |             predEngine = mlContext.Model.CreatePredictionEngine<BlockFeatures, BlockCategory>(mlModel);
37 |         }
38 | 
39 |         /// <summary>
40 |         /// Extracts the words of <paramref name="page"/>, segments them into text blocks and
41 |         /// classifies each block.
42 |         /// </summary>
43 |         /// <param name="page">The page whose blocks are classified.</param>
44 |         /// <param name="wordExtractor">Builds words from the page letters.</param>
45 |         /// <param name="pageSegmenter">Groups the words into text blocks.</param>
46 |         /// <returns>
47 |         /// A lazily evaluated sequence with, for each block, the predicted category name,
48 |         /// the highest class score and the block itself.
49 |         /// </returns>
50 |         /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
51 |         public IEnumerable<(string Prediction, float Score, TextBlock Block)> Classify(Page page, IWordExtractor wordExtractor, IPageSegmenter pageSegmenter)
52 |         {
53 |             // Validate eagerly: checks inside an iterator would only run on first enumeration.
54 |             if (page == null)
55 |             {
56 |                 throw new ArgumentNullException(nameof(page));
57 |             }
58 | 
59 |             if (wordExtractor == null)
60 |             {
61 |                 throw new ArgumentNullException(nameof(wordExtractor));
62 |             }
63 | 
64 |             if (pageSegmenter == null)
65 |             {
66 |                 throw new ArgumentNullException(nameof(pageSegmenter));
67 |             }
68 | 
69 |             return ClassifyBlocks(page, wordExtractor, pageSegmenter);
70 |         }
71 | 
72 |         // Lazily classifies every block of an already-validated page.
73 |         private IEnumerable<(string Prediction, float Score, TextBlock Block)> ClassifyBlocks(Page page, IWordExtractor wordExtractor, IPageSegmenter pageSegmenter)
74 |         {
75 |             var words = wordExtractor.GetWords(page.Letters);
76 |             var blocks = pageSegmenter.GetBlocks(words);
77 | 
78 |             // Page-level graphics do not depend on the block: read them once per page
79 |             // instead of once per block.
80 |             var pagePaths = page.ExperimentalAccess.Paths;
81 |             var pageImages = page.GetImages().ToList();
82 | 
83 |             foreach (var block in blocks)
84 |             {
85 |                 var letters = block.TextLines.SelectMany(li => li.Words).SelectMany(w => w.Letters);
86 |                 var paths = FeatureHelper.GetPathsInside(block.BoundingBox, pagePaths);
87 |                 var images = FeatureHelper.GetImagesInside(block.BoundingBox, pageImages);
88 |                 var features = FeatureHelper.GetFeatures(page, block.BoundingBox, letters, paths, images);
89 | 
90 |                 // features[i] goes to the property declared with LoadColumn(i), i.e. presumably
91 |                 // the column order the model was trained on.
92 |                 var blockFeatures = new BlockFeatures
93 |                 {
94 |                     CharsCount = features[0],
95 |                     PctNumericChars = features[1],
96 |                     PctAlphabeticalChars = features[2],
97 |                     PctSymbolicChars = features[3],
98 |                     PctBulletChars = features[4],
99 |                     DeltaToHeight = features[5],
100 |                     PathsCount = features[6],
101 |                     PctBezierPaths = features[7],
102 |                     PctHorPaths = features[8],
103 |                     PctVertPaths = features[9],
104 |                     PctOblPaths = features[10],
105 |                     ImagesCount = features[11],
106 |                     ImageAvgProportion = features[12]
107 |                 };
108 | 
109 |                 var result = predEngine.Predict(blockFeatures);
110 | 
111 |                 // The predicted label indexes the category names; the highest class score is returned as Score.
112 |                 yield return (FeatureHelper.Categories[(int)result.Prediction], result.Score.Max(), block);
113 |             }
114 |         }
115 |     }
116 | }
117 | -------------------------------------------------------------------------------- /PdfPigMLNetBlockClassifier.sln: -------------------------------------------------------------------------------- 1 | 2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 15 4 | VisualStudioVersion = 15.0.28307.852 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "PdfPigMLNetBlockClassifier", "PdfPigMLNetBlockClassifier\PdfPigMLNetBlockClassifier.csproj", "{1B5DA09B-E217-40C6-80CB-3BBA28F5AF39}" 7 | EndProject 8 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "PdfPigMLNetBlockClassifier.LightGbm", "PdfPigMLNetBlockClassifier.LightGbm\PdfPigMLNetBlockClassifier.LightGbm.csproj", "{42436C8B-339F-4020-A52D-F6787C267483}" 9 | EndProject 10 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "PdfPigMLNetBlockClassifier.Data", "PdfPigMLNetBlockClassifier.Data\PdfPigMLNetBlockClassifier.Data.csproj", "{D4147008-F6C7-4528-9621-219F41BC12DB}" 11 | EndProject 12 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "PdfPigMLNetBlockClassifier.LightGbmV2",
"PdfPigMLNetBlockClassifier.LightGbmV2\PdfPigMLNetBlockClassifier.LightGbmV2.csproj", "{3267E733-7D48-4A11-AB73-1261ACBC8681}" 13 | EndProject 14 | Global 15 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 16 | Debug|Any CPU = Debug|Any CPU 17 | Release|Any CPU = Release|Any CPU 18 | EndGlobalSection 19 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 20 | {1B5DA09B-E217-40C6-80CB-3BBA28F5AF39}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 21 | {1B5DA09B-E217-40C6-80CB-3BBA28F5AF39}.Debug|Any CPU.Build.0 = Debug|Any CPU 22 | {1B5DA09B-E217-40C6-80CB-3BBA28F5AF39}.Release|Any CPU.ActiveCfg = Release|Any CPU 23 | {1B5DA09B-E217-40C6-80CB-3BBA28F5AF39}.Release|Any CPU.Build.0 = Release|Any CPU 24 | {42436C8B-339F-4020-A52D-F6787C267483}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 25 | {42436C8B-339F-4020-A52D-F6787C267483}.Debug|Any CPU.Build.0 = Debug|Any CPU 26 | {42436C8B-339F-4020-A52D-F6787C267483}.Release|Any CPU.ActiveCfg = Release|Any CPU 27 | {42436C8B-339F-4020-A52D-F6787C267483}.Release|Any CPU.Build.0 = Release|Any CPU 28 | {D4147008-F6C7-4528-9621-219F41BC12DB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 29 | {D4147008-F6C7-4528-9621-219F41BC12DB}.Debug|Any CPU.Build.0 = Debug|Any CPU 30 | {D4147008-F6C7-4528-9621-219F41BC12DB}.Release|Any CPU.ActiveCfg = Release|Any CPU 31 | {D4147008-F6C7-4528-9621-219F41BC12DB}.Release|Any CPU.Build.0 = Release|Any CPU 32 | {3267E733-7D48-4A11-AB73-1261ACBC8681}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 33 | {3267E733-7D48-4A11-AB73-1261ACBC8681}.Debug|Any CPU.Build.0 = Debug|Any CPU 34 | {3267E733-7D48-4A11-AB73-1261ACBC8681}.Release|Any CPU.ActiveCfg = Release|Any CPU 35 | {3267E733-7D48-4A11-AB73-1261ACBC8681}.Release|Any CPU.Build.0 = Release|Any CPU 36 | EndGlobalSection 37 | GlobalSection(SolutionProperties) = preSolution 38 | HideSolutionNode = FALSE 39 | EndGlobalSection 40 | GlobalSection(ExtensibilityGlobals) = postSolution 41 | SolutionGuid = {91D33AE2-5546-4275-B7C7-3B029863A3FB} 42 | 
EndGlobalSection 43 | EndGlobal 44 | -------------------------------------------------------------------------------- /PdfPigMLNetBlockClassifier/Program.cs: -------------------------------------------------------------------------------- 
1 | using Microsoft.ML.Data;
2 | using PdfPigMLNetBlockClassifier.Data.v2;
3 | using PdfPigMLNetBlockClassifier.LightGbmV2;
4 | using System;
5 | using System.Collections.Generic;
6 | using System.Linq;
7 | using UglyToad.PdfPig;
8 | using UglyToad.PdfPig.DocumentLayoutAnalysis.PageSegmenter;
9 | using UglyToad.PdfPig.DocumentLayoutAnalysis.WordExtractor;
10 | using UglyToad.PdfPig.Outline;
11 | 
12 | namespace PdfPigMLNetBlockClassifier
13 | {
14 |     /// <summary>
15 |     /// Console entry point. The commented-out steps 1-3 build the v2 feature CSVs and train/evaluate
16 |     /// the model; step 4 loads the saved model and prints a prediction for every text block of sample.pdf.
17 |     /// </summary>
18 |     class Program
19 |     {
20 |         // Machine-specific PubLayNet extraction folders, only referenced by the commented-out step 1.
21 |         static readonly string TRAIN_RAW_DATA_FILEPATH = @"D:\Datasets\Document Layout Analysis\PubLayNet\extracted\train";
22 |         static readonly string TEST_RAW_DATA_FILEPATH = @"D:\Datasets\Document Layout Analysis\PubLayNet\extracted\val";
23 | 
24 |         static readonly string TRAIN_DATA_FILENAME = "features_train_v2.csv";
25 |         static readonly string TEST_DATA_FILENAME = "features_val_v2.csv";
26 | 
27 |         static readonly string MODEL_NAME = "modelV2.zip";
28 | 
29 |         static void Main(string[] args)
30 |         {
31 |             // 1. Convert pdf documents and their PAGE xml ground truth to csv files
32 |             //DataGenerator.GetCsv(TEST_RAW_DATA_FILEPATH, 0, TEST_DATA_FILENAME); // testing
33 |             //DataGenerator.GetCsv(TRAIN_RAW_DATA_FILEPATH, 0, TRAIN_DATA_FILENAME); // training
34 | 
35 |             // 2. Create the model
36 |             //LightGbmModelBuilder.TrainModel(DataGenerator.GetDataPath(TRAIN_DATA_FILENAME), MODEL_NAME);
37 | 
38 |             // 3. Evaluate the model
39 |             //LightGbmModelBuilder.Evaluate(MODEL_NAME, DataGenerator.GetDataPath(TEST_DATA_FILENAME));
40 | 
41 |             // 4. Load the trained classifier
42 |             LightGbmBlockClassifier lightGbmBlockClassifier = new LightGbmBlockClassifier(LightGbmModelBuilder.GetModelPath(MODEL_NAME));
43 | 
44 |             NearestNeighbourWordExtractor nearestNeighbourWordExtractor = new NearestNeighbourWordExtractor();
45 |             RecursiveXYCut recursiveXYCut = new RecursiveXYCut();
46 | 
47 |             using (var document = PdfDocument.Open("sample.pdf"))
48 |             {
49 |                 // 'bookmarks' may be null (e.g. when the document has no outline), hence the
50 |                 // null-conditional operator below.
51 |                 document.TryGetBookmarks(out Bookmarks bookmarks);
52 | 
53 |                 for (var i = 0; i < document.NumberOfPages; i++)
54 |                 {
55 |                     var page = document.GetPage(i + 1);
56 | 
57 |                     // Average() below throws on an empty sequence, and a page without letters
58 |                     // (blank or scanned page) has no text blocks to classify, so skip it.
59 |                     if (!page.Letters.Any())
60 |                     {
61 |                         continue;
62 |                     }
63 | 
64 |                     // Bookmarks pointing to this page (null when the document has no bookmarks).
65 |                     List<DocumentBookmarkNode> bookmarksNodes = bookmarks?.GetNodes()
66 |                         .OfType<DocumentBookmarkNode>()
67 |                         .Where(b => b.PageNumber == page.Number)
68 |                         .ToList();
69 | 
70 |                     var avgPageFontHeight = page.Letters.Select(l => l.GlyphRectangle.Height).Average();
71 | 
72 |                     var words = nearestNeighbourWordExtractor.GetWords(page.Letters);
73 |                     var blocks = recursiveXYCut.GetBlocks(words, page.Width / 3.0);
74 | 
75 |                     // Page-level graphics do not depend on the block: read them once per page.
76 |                     var pagePaths = page.ExperimentalAccess.Paths;
77 |                     var pageImages = page.GetImages().ToList();
78 | 
79 |                     foreach (var block in blocks)
80 |                     {
81 |                         var paths = FeatureHelper.GetPathsInside(block.BoundingBox, pagePaths);
82 |                         var images = FeatureHelper.GetImagesInside(block.BoundingBox, pageImages);
83 | 
84 |                         var pred = lightGbmBlockClassifier.Classify(block, paths, images, avgPageFontHeight, bookmarksNodes);
85 | 
86 |                         Console.WriteLine();
87 |                         Console.WriteLine($"{pred.Prediction} [{pred.Score:0.0%}]");
88 |                         Console.WriteLine(block.Text.Normalize(normalizationForm: System.Text.NormalizationForm.FormKC)); // remove ligatures
89 |                     }
90 |                 }
91 |             }
92 | 
93 |             // Console.ReadKey throws when the console input is redirected (e.g. piped or CI runs).
94 |             if (!Console.IsInputRedirected)
95 |             {
96 |                 Console.ReadKey();
97 |             }
98 |         }
99 |     }
100 | }
101 | -------------------------------------------------------------------------------- /PdfPigMLNetBlockClassifier/model/results.txt: -------------------------------------------------------------------------------- 1 | ===============
Cross-validating to get model's accuracy metrics =============== 2 | ************************************************************************************************************* 3 | * Metrics for Multi-class Classification model 4 | *------------------------------------------------------------------------------------------------------------ 5 | * Average MicroAccuracy: 0.931 - Standard deviation: (.001) - Confidence Interval 95%: (.001) 6 | * Average MacroAccuracy: 0.743 - Standard deviation: (.001) - Confidence Interval 95%: (.001) 7 | * Average LogLoss: .219 - Standard deviation: (.002) - Confidence Interval 95%: (.002) 8 | * Average LogLossReduction: .747 - Standard deviation: (.003) - Confidence Interval 95%: (.003) 9 | ************************************************************************************************************* 10 | =============== Training model =============== 11 | =============== End of training process =============== 12 | =============== Saving the model =============== 13 | The model is saved to ../../../model\model.zip 14 | =============== Evaluating to get model's accuracy metrics =============== 15 | ************************************************************ 16 | * Metrics for multi-class classification model 17 | *----------------------------------------------------------- 18 | MacroAccuracy = 0.7482, a value between 0 and 1, the closer to 1, the better 19 | MicroAccuracy = 0.9369, a value between 0 and 1, the closer to 1, the better 20 | LogLoss = 0.2092, the closer to 0, the better 21 | LogLoss for class 1 = 0.2156, the closer to 0, the better 22 | LogLoss for class 2 = 3.1245, the closer to 0, the better 23 | LogLoss for class 3 = 0.306, the closer to 0, the better 24 | LogLoss for class 4 = 0.1094, the closer to 0, the better 25 | LogLoss for class 5 = 0.2472, the closer to 0, the better 26 | 27 | Confusion table 28 | ||======================================== 29 | PREDICTED || 0 | 1 | 2 | 3 | 4 | Recall 30 | TRUTH 
||======================================== 31 | 0 || 1,765 | 0 | 0 | 145 | 2 | 0.9231 32 | 1 || 1 | 3 | 2 | 273 | 0 | 0.0108 33 | 2 || 0 | 0 | 623 | 63 | 5 | 0.9016 34 | 3 || 242 | 0 | 13 | 8,709 | 1 | 0.9714 35 | 4 || 1 | 0 | 2 | 2 | 71 | 0.9342 36 | ||======================================== 37 | Precision ||0.8785 |1.0000 |0.9734 |0.9475 |0.8987 | 38 | 39 | F1 Score for class 1 = 0.9003, a value between 0 and 1, the closer to 1, the better 40 | F1 Score for class 2 = 0.0213, a value between 0 and 1, the closer to 1, the better 41 | F1 Score for class 3 = 0.9361, a value between 0 and 1, the closer to 1, the better 42 | F1 Score for class 4 = 0.9593, a value between 0 and 1, the closer to 1, the better 43 | F1 Score for class 5 = 0.9161, a value between 0 and 1, the closer to 1, the better 44 | ************************************************************ 45 | =============== Permutation Feature Importance =============== 46 | PFI works by taking a labeled dataset, choosing a feature, and permuting the values for that 47 | feature across all the examples, so that each example now has a random value for the feature 48 | and the original values for all other features. The evaluation metric(e.g.micro - accuracy) is 49 | then calculated for this modified dataset, and the change in the evaluation metric from the 50 | original dataset is computed.The larger the change in the evaluation metric, the more 51 | important the feature is to the model. PFI works by performing this permutation analysis 52 | across all the features of a model, one after another. 
53 | 54 | Feature Change in MicroAccuracy 95% Confidence in the Mean Change in MicroAccuracy 55 | charsCount -0.2192 0.0008443 56 | pctNumericChars -0.04996 0.0004363 57 | deltaToHeight -0.04155 0.000428 58 | pctBulletChars -0.01571 0.0004034 59 | pctAlphabeticalChars -0.012 0.0003245 60 | pctSymbolicChars -0.01187 0.0004204 61 | pathsCount -0.01089 0.0002144 62 | pctHorPaths -0.002695 0.0001318 63 | imageAvgProportion -0.001895 3.909E-05 64 | pctVertPaths -0.001403 0.0001069 65 | pctOblPaths -0.0005032 3.612E-05 66 | pctBezierPaths -7.828E-05 1.563E-05 67 | imagesCount -2.796E-05 2.768E-05 68 | 69 | 70 | Feature Change in MacroAccuracy 95% Confidence in the Mean Change in MacroAccuracy 71 | charsCount -0.1906 0.001906 72 | pathsCount -0.07355 0.001211 73 | pctNumericChars -0.06516 0.0008356 74 | deltaToHeight -0.05476 0.001333 75 | pctOblPaths -0.01097 0.0002006 76 | pctAlphabeticalChars -0.009334 0.001237 77 | imageAvgProportion -0.00874 0.0001771 78 | pctVertPaths -0.005001 0.0003568 79 | pctHorPaths -0.004718 0.0004244 80 | pctSymbolicChars -0.001522 0.0008552 81 | pctBulletChars 0.0009088 0.0007747 82 | pctBezierPaths -0.0003737 5.705E-05 83 | imagesCount -3.805E-05 3.434E-05 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 
3 | 4 | # User-specific files 5 | *.suo 6 | *.user 7 | *.userosscache 8 | *.sln.docstates 9 | *csv 10 | *txt 11 | 12 | # User-specific files (MonoDevelop/Xamarin Studio) 13 | *.userprefs 14 | 15 | # Build results 16 | [Dd]ebug/ 17 | [Dd]ebugPublic/ 18 | [Rr]elease/ 19 | [Rr]eleases/ 20 | x64/ 21 | x86/ 22 | bld/ 23 | [Bb]in/ 24 | [Oo]bj/ 25 | [Ll]og/ 26 | 27 | # Visual Studio 2015 cache/options directory 28 | .vs/ 29 | # Uncomment if you have tasks that create the project's static files in wwwroot 30 | #wwwroot/ 31 | 32 | # MSTest test Results 33 | [Tt]est[Rr]esult*/ 34 | [Bb]uild[Ll]og.* 35 | 36 | # NUNIT 37 | *.VisualState.xml 38 | TestResult.xml 39 | 40 | # Build Results of an ATL Project 41 | [Dd]ebugPS/ 42 | [Rr]eleasePS/ 43 | dlldata.c 44 | 45 | # DNX 46 | project.lock.json 47 | project.fragment.lock.json 48 | artifacts/ 49 | 50 | *_i.c 51 | *_p.c 52 | *_i.h 53 | *.ilk 54 | *.meta 55 | *.obj 56 | *.pch 57 | *.pdb 58 | *.pgc 59 | *.pgd 60 | *.rsp 61 | *.sbr 62 | *.tlb 63 | *.tli 64 | *.tlh 65 | *.tmp 66 | *.tmp_proj 67 | *.log 68 | *.vspscc 69 | *.vssscc 70 | .builds 71 | *.pidb 72 | *.svclog 73 | *.scc 74 | 75 | # Chutzpah Test files 76 | _Chutzpah* 77 | 78 | # Visual C++ cache files 79 | ipch/ 80 | *.aps 81 | *.ncb 82 | *.opendb 83 | *.opensdf 84 | *.sdf 85 | *.cachefile 86 | *.VC.db 87 | *.VC.VC.opendb 88 | 89 | # Visual Studio profiler 90 | *.psess 91 | *.vsp 92 | *.vspx 93 | *.sap 94 | 95 | # TFS 2012 Local Workspace 96 | $tf/ 97 | 98 | # Guidance Automation Toolkit 99 | *.gpState 100 | 101 | # ReSharper is a .NET coding add-in 102 | _ReSharper*/ 103 | *.[Rr]e[Ss]harper 104 | *.DotSettings.user 105 | 106 | # JustCode is a .NET coding add-in 107 | .JustCode 108 | 109 | # TeamCity is a build add-in 110 | _TeamCity* 111 | 112 | # DotCover is a Code Coverage Tool 113 | *.dotCover 114 | 115 | # NCrunch 116 | _NCrunch_* 117 | .*crunch*.local.xml 118 | nCrunchTemp_* 119 | 120 | # MightyMoose 121 | *.mm.* 122 | AutoTest.Net/ 123 | 124 | # Web workbench (sass) 125 
| .sass-cache/ 126 | 127 | # Installshield output folder 128 | [Ee]xpress/ 129 | 130 | # DocProject is a documentation generator add-in 131 | DocProject/buildhelp/ 132 | DocProject/Help/*.HxT 133 | DocProject/Help/*.HxC 134 | DocProject/Help/*.hhc 135 | DocProject/Help/*.hhk 136 | DocProject/Help/*.hhp 137 | DocProject/Help/Html2 138 | DocProject/Help/html 139 | 140 | # Click-Once directory 141 | publish/ 142 | 143 | # Publish Web Output 144 | *.[Pp]ublish.xml 145 | *.azurePubxml 146 | # TODO: Comment the next line if you want to checkin your web deploy settings 147 | # but database connection strings (with potential passwords) will be unencrypted 148 | #*.pubxml 149 | *.publishproj 150 | 151 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 152 | # checkin your Azure Web App publish settings, but sensitive information contained 153 | # in these scripts will be unencrypted 154 | PublishScripts/ 155 | 156 | # NuGet Packages 157 | *.nupkg 158 | # The packages folder can be ignored because of Package Restore 159 | **/packages/* 160 | # except build/, which is used as an MSBuild target. 
161 | !**/packages/build/ 162 | # Uncomment if necessary however generally it will be regenerated when needed 163 | #!**/packages/repositories.config 164 | # NuGet v3's project.json files produces more ignoreable files 165 | *.nuget.props 166 | *.nuget.targets 167 | 168 | # Microsoft Azure Build Output 169 | csx/ 170 | *.build.csdef 171 | 172 | # Microsoft Azure Emulator 173 | ecf/ 174 | rcf/ 175 | 176 | # Windows Store app package directories and files 177 | AppPackages/ 178 | BundleArtifacts/ 179 | Package.StoreAssociation.xml 180 | _pkginfo.txt 181 | 182 | # Visual Studio cache files 183 | # files ending in .cache can be ignored 184 | *.[Cc]ache 185 | # but keep track of directories ending in .cache 186 | !*.[Cc]ache/ 187 | 188 | # Others 189 | ClientBin/ 190 | ~$* 191 | *~ 192 | *.dbmdl 193 | *.dbproj.schemaview 194 | *.jfm 195 | *.pfx 196 | *.publishsettings 197 | node_modules/ 198 | orleans.codegen.cs 199 | 200 | # Since there are multiple workflows, uncomment next line to ignore bower_components 201 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 202 | #bower_components/ 203 | 204 | # RIA/Silverlight projects 205 | Generated_Code/ 206 | 207 | # Backup & report files from converting an old project file 208 | # to a newer Visual Studio version. 
Backup files are not needed, 209 | # because we have git ;-) 210 | _UpgradeReport_Files/ 211 | Backup*/ 212 | UpgradeLog*.XML 213 | UpgradeLog*.htm 214 | 215 | # SQL Server files 216 | *.mdf 217 | *.ldf 218 | 219 | # Business Intelligence projects 220 | *.rdl.data 221 | *.bim.layout 222 | *.bim_*.settings 223 | 224 | # Microsoft Fakes 225 | FakesAssemblies/ 226 | 227 | # GhostDoc plugin setting file 228 | *.GhostDoc.xml 229 | 230 | # Node.js Tools for Visual Studio 231 | .ntvs_analysis.dat 232 | 233 | # Visual Studio 6 build log 234 | *.plg 235 | 236 | # Visual Studio 6 workspace options file 237 | *.opt 238 | 239 | # Visual Studio LightSwitch build output 240 | **/*.HTMLClient/GeneratedArtifacts 241 | **/*.DesktopClient/GeneratedArtifacts 242 | **/*.DesktopClient/ModelManifest.xml 243 | **/*.Server/GeneratedArtifacts 244 | **/*.Server/ModelManifest.xml 245 | _Pvt_Extensions 246 | 247 | # Paket dependency manager 248 | .paket/paket.exe 249 | paket-files/ 250 | 251 | # FAKE - F# Make 252 | .fake/ 253 | 254 | # JetBrains Rider 255 | .idea/ 256 | *.sln.iml 257 | 258 | # CodeRush 259 | .cr/ 260 | 261 | # Python Tools for Visual Studio (PTVS) 262 | __pycache__/ 263 | *.pyc
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.LightGbmV2/LightGbmBlockClassifier.cs:
--------------------------------------------------------------------------------
using Microsoft.ML;
using PdfPigMLNetBlockClassifier.Data.v2;
using System.Collections.Generic;
using System.Linq;
using UglyToad.PdfPig.Content;
using UglyToad.PdfPig.Core;
using UglyToad.PdfPig.DocumentLayoutAnalysis;
using UglyToad.PdfPig.DocumentLayoutAnalysis.PageSegmenter;
using UglyToad.PdfPig.Outline;
using UglyToad.PdfPig.Util;

namespace PdfPigMLNetBlockClassifier.LightGbmV2
{
    /// <summary>
    /// Classifies PdfPig <see cref="TextBlock"/>s as text, title, list, table or image using an
    /// ML.NET model fed with the v2 features computed by <see cref="FeatureHelper"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): ML.NET documents <see cref="PredictionEngine{TSrc, TDst}"/> as not thread-safe,
    /// so a single instance of this class should not be used from several threads at once.
    /// </remarks>
    public class LightGbmBlockClassifier
    {
        private readonly MLContext mlContext;
        private readonly ITransformer mlModel;
        private readonly PredictionEngine<BlockFeatures, BlockCategory> predEngine;

        /// <summary>
        /// Output schema of the underlying prediction engine.
        /// </summary>
        public DataViewSchema OutputSchema => predEngine?.OutputSchema;

        // NOTE(review): currently unused. It appears to map a FeatureHelper category id (the label) to
        // the index of that class in BlockCategory.Score, i.e. the class order reported in the README
        // (title, image, text, list, table). Confirm before relying on it.
        private static readonly int[] MLNetCategories = new int[] { 2, 0, 3, 4, 1 };

        /// <summary>
        /// Loads the trained model and creates the prediction engine used by the Classify methods.
        /// </summary>
        /// <param name="modelPath">Path to the saved ML.NET model (zip file).</param>
        public LightGbmBlockClassifier(string modelPath)
        {
            mlContext = new MLContext();
            mlModel = mlContext.Model.Load(modelPath, out _);
            predEngine = mlContext.Model.CreatePredictionEngine<BlockFeatures, BlockCategory>(mlModel);
        }

        /// <summary>
        /// Classifies a single text block.
        /// </summary>
        /// <param name="textBlock">The block to classify.</param>
        /// <param name="paths">The paths to take into account for this block (typically those inside its bounding box).</param>
        /// <param name="images">The images to take into account for this block (typically those inside its bounding box).</param>
        /// <param name="averagePageFontHeight">Average glyph height of the page, used for the deltaToHeight feature.</param>
        /// <param name="pageBookmarksNodes">Bookmarks of the block's page; may be null.</param>
        /// <returns>The predicted category name and the highest class score.</returns>
        public (string Prediction, float Score) Classify(TextBlock textBlock, IEnumerable<PdfPath> paths, IEnumerable<IPdfImage> images,
            double averagePageFontHeight, List<DocumentBookmarkNode> pageBookmarksNodes)
        {
            var features = FeatureHelper.GetFeatures(
                textBlock, paths,
                images, averagePageFontHeight,
                textBlock.BoundingBox.Area,
                pageBookmarksNodes);

            return Predict(features);
        }

        /// <summary>
        /// Extracts words from the page, segments them into blocks and classifies each block.
        /// Blocks are classified lazily, as the returned sequence is enumerated.
        /// </summary>
        /// <param name="page">The page to classify.</param>
        /// <param name="wordExtractor">Word extractor applied to the page's letters.</param>
        /// <param name="pageSegmenter">Page segmenter that groups the words into blocks.</param>
        /// <param name="bookmarks">Document bookmarks, if any; only those pointing to this page are used.</param>
        /// <returns>For each block: the predicted category name, the highest class score and the block itself.</returns>
        public IEnumerable<(string Prediction, float Score, TextBlock Block)> Classify(Page page, IWordExtractor wordExtractor,
            IPageSegmenter pageSegmenter, Bookmarks bookmarks = null)
        {
            List<DocumentBookmarkNode> bookmarksNodes = bookmarks?.GetNodes()
                .OfType<DocumentBookmarkNode>()
                .Where(b => b.PageNumber == page.Number)
                .ToList();

            // A page without letters (e.g. a scanned page) has nothing to classify, and
            // Average() below would throw on an empty sequence.
            if (page.Letters.Count == 0)
            {
                yield break;
            }

            double avgPageFontHeight = page.Letters.Average(l => l.GlyphRectangle.Height);

            var words = wordExtractor.GetWords(page.Letters);
            var blocks = pageSegmenter.GetBlocks(words);

            // Paths and images belong to the page, not to a block: fetch them once instead of once per block.
            var pagePaths = page.ExperimentalAccess.Paths;
            var pageImages = page.GetImages().ToList();

            foreach (var block in blocks)
            {
                var paths = FeatureHelper.GetPathsInside(block.BoundingBox, pagePaths);
                var images = FeatureHelper.GetImagesInside(block.BoundingBox, pageImages);

                var features = FeatureHelper.GetFeatures(
                    block, paths,
                    images, avgPageFontHeight,
                    block.BoundingBox.Area,
                    bookmarksNodes);

                var (prediction, score) = Predict(features);
                yield return (prediction, score, block);
            }
        }

        // Maps the 17-value v2 feature vector (column order of FeatureHelper.Header) onto the model
        // input, runs one prediction and decodes the predicted label into its category name.
        private (string Prediction, float Score) Predict(float[] features)
        {
            var blockFeatures = new BlockFeatures()
            {
                BlockAspectRatio = features[0],
                CharsCount = features[1],
                WordsCount = features[2],
                LinesCount = features[3],
                PctNumericChars = features[4],
                PctAlphabeticalChars = features[5],
                PctSymbolicChars = features[6],
                PctBulletChars = features[7],
                DeltaToHeight = features[8],
                PathsCount = features[9],
                PctBezierPaths = features[10],
                PctHorPaths = features[11],
                PctVertPaths = features[12],
                PctOblPaths = features[13],
                ImagesCount = features[14],
                ImageAvgProportion = features[15],
                BestNormEditDistance = features[16],
            };

            var result = predEngine.Predict(blockFeatures);
            return (FeatureHelper.Categories[(int)result.Prediction], result.Score.Max());
        }
    }
}
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.Data/v1/FeatureHelper.cs:
--------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Linq;
using UglyToad.PdfPig.Content;
using UglyToad.PdfPig.Core;
using static UglyToad.PdfPig.Core.PdfPath;

namespace PdfPigMLNetBlockClassifier.Data.v1
{
    /// <summary>
    /// Computes the v1 region features (13 values) and provides helpers to select the letters,
    /// images and paths lying inside a region.
    /// </summary>
    public static class FeatureHelper
    {
        /// <summary>
        /// Category id to category name.
        /// </summary>
        public static readonly Dictionary<int, string> Categories = new Dictionary<int, string>()
        {
            { 0, "text" },
            { 1, "title" },
            { 2, "list" },
            { 3, "table" },
            { 4, "image" },
        };

        // Characters counted as list bullets. The lowercase letter 'o' is part of the set, so every
        // 'o' in ordinary text also counts towards pctBulletChars.
        private static readonly char[] Bullets = new char[]
        {
            '•', 'o', '▪', '❖', '➢', '►', '✓', '➔', '⇨', '➪',
            '➨', '➫', '➬', '➭', '➮', '➯', '➱', '➲', '\u2023',
            '\u2043', '\u204C', '\u204D'
        };

        /// <summary>
        /// Computes the v1 features of a region of a page.
        /// </summary>
        /// <param name="page">The page the region belongs to; its letters give the average glyph height.</param>
        /// <param name="bbox">Bounding box of the region, used to normalise image areas.</param>
        /// <param name="letters">Letters of the region; may be null.</param>
        /// <param name="paths">Paths of the region; may be null.</param>
        /// <param name="images">Images of the region; may be null.</param>
        /// <returns>
        /// charsCount, pctNumericChars, pctAlphabeticalChars, pctSymbolicChars, pctBulletChars, deltaToHeight,
        /// pathsCount, pctBezierPaths, pctHorPaths, pctVertPaths, pctOblPaths, imagesCount, imageAvgProportion.
        /// Percentages are rounded to 5 decimals and stay 0 when there is nothing to divide by.
        /// </returns>
        public static float[] GetFeatures(Page page, PdfRectangle bbox, IEnumerable<Letter> letters, IEnumerable<PdfPath> paths, IEnumerable<IPdfImage> images)
        {
            // Letters features
            float charsCount = 0;
            float pctNumericChars = 0;
            float pctAlphabeticalChars = 0;
            float pctSymbolicChars = 0;
            float pctBulletChars = 0;
            float deltaToHeight = -1; // might be problematic

            // Materialise once: the letters are enumerated several times below.
            var letterList = letters?.ToList();
            if (letterList != null && letterList.Count > 0)
            {
                var avgHeight = page.Letters.Select(l => l.GlyphRectangle.Height).Average();

                char[] chars = letterList.SelectMany(l => l.Value).ToArray();

                charsCount = chars.Length;

                // Letters with empty values would otherwise produce 0/0 = NaN.
                if (chars.Length > 0)
                {
                    pctNumericChars = Ratio(chars.Count(c => char.IsNumber(c)), charsCount);
                    pctAlphabeticalChars = Ratio(chars.Count(c => char.IsLetter(c)), charsCount);
                    pctSymbolicChars = Ratio(chars.Count(c => !char.IsLetterOrDigit(c)), charsCount);
                    pctBulletChars = Ratio(chars.Count(c => Array.IndexOf(Bullets, c) >= 0), charsCount);
                }

                deltaToHeight = avgHeight != 0
                    ? (float)Math.Round(letterList.Select(l => l.GlyphRectangle.Height).Average() / avgHeight, 5)
                    : -1;
            }

            // Paths features
            float pathsCount = 0;
            float pctBezierPaths = 0;
            float pctHorPaths = 0;
            float pctVertPaths = 0;
            float pctOblPaths = 0;

            if (paths != null)
            {
                foreach (var path in paths)
                {
                    foreach (var command in path.Commands)
                    {
                        if (command is BezierCurve)
                        {
                            pathsCount++;
                            pctBezierPaths++;
                        }
                        else if (command is Line line)
                        {
                            pathsCount++;

                            // Exact equality: only perfectly axis-aligned lines are vertical/horizontal.
                            if (line.From.X == line.To.X)
                            {
                                pctVertPaths++;
                            }
                            else if (line.From.Y == line.To.Y)
                            {
                                pctHorPaths++;
                            }
                            else
                            {
                                pctOblPaths++;
                            }
                        }
                    }
                }

                // Only Line and BezierCurve commands are counted. If there were none, the
                // percentages stay 0 instead of becoming 0/0 = NaN.
                if (pathsCount > 0)
                {
                    pctBezierPaths = Ratio(pctBezierPaths, pathsCount);
                    pctHorPaths = Ratio(pctHorPaths, pathsCount);
                    pctVertPaths = Ratio(pctVertPaths, pathsCount);
                    pctOblPaths = Ratio(pctOblPaths, pathsCount);
                }
            }

            // Images features
            float imagesCount = 0;
            float imageAvgProportion = 0;

            var imageList = images?.ToList();
            if (imageList != null && imageList.Count > 0)
            {
                imagesCount = imageList.Count;

                // A zero-area region would give Infinity/NaN; leave the proportion at 0 instead.
                if (bbox.Area != 0)
                {
                    imageAvgProportion = (float)(imageList.Average(i => i.Bounds.Area) / bbox.Area);
                }
            }

            return new float[]
            {
                charsCount, pctNumericChars, pctAlphabeticalChars, pctSymbolicChars, pctBulletChars, deltaToHeight,
                pathsCount, pctBezierPaths, pctHorPaths, pctVertPaths, pctOblPaths,
                imagesCount, imageAvgProportion
            };
        }

        // Share of count in total, rounded to 5 decimal places.
        private static float Ratio(float count, float total) => (float)Math.Round(count / total, 5);

        // True when inner lies entirely within bound (touching edges count as inside).
        private static bool IsInside(PdfRectangle inner, PdfRectangle bound)
        {
            return inner.Left >= bound.Left &&
                   inner.Right <= bound.Right &&
                   inner.Bottom >= bound.Bottom &&
                   inner.Top <= bound.Top;
        }

        /// <summary>
        /// Letters whose glyph rectangle lies entirely within <paramref name="bound"/>.
        /// </summary>
        public static IEnumerable<Letter> GetLettersInside(PdfRectangle bound, IEnumerable<Letter> letters)
        {
            return letters.Where(l => IsInside(l.GlyphRectangle, bound));
        }

        /// <summary>
        /// Images whose bounds lie entirely within <paramref name="bound"/>.
        /// </summary>
        public static IEnumerable<IPdfImage> GetImagesInside(PdfRectangle bound, IEnumerable<IPdfImage> images)
        {
            return images.Where(i => IsInside(i.Bounds, bound));
        }

        /// <summary>
        /// Paths that have a bounding rectangle lying entirely within <paramref name="bound"/>.
        /// </summary>
        public static IEnumerable<PdfPath> GetPathsInside(PdfRectangle bound, IEnumerable<PdfPath> paths)
        {
            return paths.Where(p =>
            {
                // Compute the bounding rectangle once per path.
                var rect = p.GetBoundingRectangle();
                return rect.HasValue && IsInside(rect.Value, bound);
            });
        }
    }
}
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.LightGbm/README.md:
--------------------------------------------------------------------------------
1 | # Results 2 | Results are based on PubLayNet's validation dataset, where the page segmentation is known. For real life use, a page segmenter will be needed (see PdfPig's [PageSegmenters](https://github.com/UglyToad/PdfPig/tree/master/src/UglyToad.PdfPig.DocumentLayoutAnalysis/PageSegmenter)). The quality of the page segmentation will impact the results.
3 | ## Metrics for multi-class classification model 4 | ``` 5 | MicroAccuracy = 0.9369, a value between 0 and 1, the closer to 1, the better 6 | MacroAccuracy = 0.7482, a value between 0 and 1, the closer to 1, the better 7 | LogLoss = 0.2092, the closer to 0, the better 8 | 9 | LogLoss for class 0 (title) = 0.2156, the closer to 0, the better 10 | LogLoss for class 1 (list) = 3.1245, the closer to 0, the better 11 | LogLoss for class 2 (table) = 0.3060, the closer to 0, the better 12 | LogLoss for class 3 (text) = 0.1094, the closer to 0, the better 13 | LogLoss for class 4 (image) = 0.2472, the closer to 0, the better 14 | 15 | F1 Score for class 0 (title) = 0.9003, a value between 0 and 1, the closer to 1, the better 16 | F1 Score for class 1 (list) = 0.0213, a value between 0 and 1, the closer to 1, the better 17 | F1 Score for class 2 (table) = 0.9361, a value between 0 and 1, the closer to 1, the better 18 | F1 Score for class 3 (text) = 0.9593, a value between 0 and 1, the closer to 1, the better 19 | F1 Score for class 4 (image) = 0.9161, a value between 0 and 1, the closer to 1, the better 20 | ``` 21 | 22 | ## Confusion table 23 | ``` 24 | ||======================================================= 25 | PREDICTED || 0 | 1 | 2 | 3 | 4 | Total | Recall 26 | TRUTH ||======================================================= 27 | (title) 0 || 1,765 | 0 | 0 | 145 | 2 | 1,912 | 0.9231 28 | (list) 1 || 1 | 3 | 2 | 273 | 0 | 279 | 0.0108 29 | (table) 2 || 0 | 0 | 623 | 63 | 5 | 691 | 0.9016 30 | (text) 3 || 242 | 0 | 13 | 8,709 | 1 | 8,965 | 0.9714 31 | (image) 4 || 1 | 0 | 2 | 2 | 71 | 76 | 0.9342 32 | ||======================================================= 33 | Precision ||0.8785 |1.0000 |0.9734 |0.9475 |0.8987 | 34 | ``` 35 | 36 | ## Permutation Feature Importance 37 | PFI works by taking a labeled dataset, choosing a feature, and permuting the values for that 38 | feature across all the examples, so that each example now has a random value for the feature 39 | 
and the original values for all other features. The evaluation metric (e.g. micro-accuracy) is 40 | then calculated for this modified dataset, and the change in the evaluation metric from the 41 | original dataset is computed. The larger the change in the evaluation metric, the more 42 | important the feature is to the model. PFI works by performing this permutation analysis 43 | across all the features of a model, one after another. - [Source]( https://docs.microsoft.com/en-us/dotnet/api/microsoft.ml.permutationfeatureimportanceextensions.permutationfeatureimportance?view=ml-dotnet#Microsoft_ML_PermutationFeatureImportanceExtensions_PermutationFeatureImportance__1_Microsoft_ML_MulticlassClassificationCatalog_Microsoft_ML_ISingleFeaturePredictionTransformer___0__Microsoft_ML_IDataView_System_String_System_Boolean_System_Nullable_System_Int32__System_Int32_) 44 | 45 | ### Micro Accuracy 46 | 47 | |Feature | Description | Change in MicroAccuracy | 95% Confidence in the Mean Change in MicroAccuracy | 48 | |------------------------:|:-----------:|:-----------------------:|:--------------------------------------------------:| 49 | |**charsCount** |Characters count|-0.2192 |0.0008443 | 50 | |**pctNumericChars** |% of numeric characters|-0.04996 |0.0004363 | 51 | |**deltaToHeight** |Average delta to average page glyph height|-0.04155 |0.000428| 52 | |**pctBulletChars** |% of bullet characters|-0.01571 |0.0004034| 53 | |**pctAlphabeticalChars** |% of alphabetical characters|-0.012 |0.0003245 | 54 | |**pctSymbolicChars** |% of symbolic characters|-0.01187 |0.0004204 | 55 | |**pathsCount** |Paths count|-0.01089 |0.0002144 | 56 | |**pctHorPaths** |% of horizontal paths|-0.002695 |0.0001318 | 57 | |**imageAvgProportion** |Average area covered by images|-0.001895 |0.00003909 | 58 | |**pctVertPaths** |% of vertical paths|-0.001403 |0.0001069 | 59 | |**pctOblPaths** |% of oblique paths|-0.0005032 |0.00003612 | 60 | |**pctBezierPaths** |% of Bezier curve paths|-0.00007828 
|0.00001563 | 61 | |**imagesCount** |Images count|-0.00002796 |0.00002768 | 62 | 63 | ### Macro Accuracy 64 | 65 | |Feature | Description | Change in MacroAccuracy | 95% Confidence in the Mean Change in MacroAccuracy | 66 | |------------------------:|:-----------:|:-----------------------:|:--------------------------------------------------:| 67 | |**charsCount** |Characters count|-0.1906 |0.001906| 68 | |**pathsCount** |Paths count|-0.07355 |0.001211| 69 | |**pctNumericChars** |% of numeric characters|-0.06516 |0.0008356| 70 | |**deltaToHeight** |Average delta to average page glyph height|-0.05476 |0.001333| 71 | |**pctOblPaths** |% of oblique paths|-0.01097 |0.0002006| 72 | |**pctAlphabeticalChars** |% of alphabetical characters|-0.009334 |0.001237| 73 | |**imageAvgProportion** |Average area covered by images|-0.00874 |0.0001771| 74 | |**pctVertPaths** |% of vertical paths|-0.005001 |0.0003568| 75 | |**pctHorPaths** |% of horizontal paths|-0.004718 |0.0004244| 76 | |**pctSymbolicChars** |% of symbolic characters|-0.001522 |0.0008552| 77 | |**pctBulletChars** |% of bullet characters|0.0009088 |0.0007747| 78 | |**pctBezierPaths** |% of Bezier curve paths|-0.0003737 |0.00005705| 79 | |**imagesCount** |Images count|-0.00003805 |0.00003434| 80 | 81 | # TO DO 82 | ## Features 83 | - Add a [decoration](https://github.com/UglyToad/PdfPig/blob/master/src/UglyToad.PdfPig.DocumentLayoutAnalysis/DecorationTextBlockClassifier.cs) score/flag 84 | - Add block's number of lines, [cf.](http://www.cs.rug.nl/~aiellom/publications/ijdar.pdf) 85 | - Add block's aspect ratio: the ratio between width and height of the bounding box, [cf.](http://www.cs.rug.nl/~aiellom/publications/ijdar.pdf) 86 | - Add block's area ratio: the ratio between block area and the page area, [cf.](http://www.cs.rug.nl/~aiellom/publications/ijdar.pdf) 87 | - Add block's font style: an enumerated type, with possible values: regular, bold, italic, underline,
[cf.](http://www.cs.rug.nl/~aiellom/publications/ijdar.pdf) 88 | - Use [bookmarks](https://github.com/UglyToad/PdfPig/blob/master/src/UglyToad.PdfPig/Outline/BookmarksProvider.cs) when available with minimum edit distance score 89 | - Add % sparse lines in a block, for better table recognition [cf.](https://clgiles.ist.psu.edu/pubs/CIKM2008-table-boundaries.pdf) 90 | - Font color distance from most common color 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PdfPig ML.Net Block Classifier v2 2 | Proof of concept of training a simple Region Classifier using [PdfPig](https://github.com/UglyToad/PdfPig) and [ML.NET](https://github.com/dotnet/machinelearning). 3 | The objective is to classify each text block in a __pdf document__ page as either __title__, __text__, __list__, __table__ or __image__. 4 | 5 | [AutoML](https://docs.microsoft.com/en-us/dotnet/machine-learning/automate-training-with-model-builder) model builder was used. The model was 6 | trained on a subset of the [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet#getting-data) dataset. See their license [here](https://cdla.io/permissive-1-0/). 7 | 8 | | Generation | MicroAccuracy | MacroAccuracy | 9 | |-----------:|:-------------:|:-------------:| 10 | | **v1** | 0.937 | 0.748 | 11 | | **v2** | 0.952 | 0.801 | 12 | 13 | For v1 model results, see [here](https://github.com/BobLd/PdfPigMLNetBlockClassifier/blob/master/PdfPigMLNetBlockClassifier.LightGbm/README.md). 14 | 15 | # Results 16 | Results are based on PubLayNet's validation dataset, where the page segmentation is known. For real life use, a page segmenter will be needed (see PdfPig's [PageSegmenters](https://github.com/UglyToad/PdfPig/tree/master/src/UglyToad.PdfPig.DocumentLayoutAnalysis/PageSegmenter)). The quality of the page segmentation will impact the results.
17 | ## Metrics for multi-class classification model 18 | ``` 19 | MicroAccuracy = 0.9523, a value between 0 and 1, the closer to 1, the better 20 | MacroAccuracy = 0.8009, a value between 0 and 1, the closer to 1, the better 21 | LogLoss = 3.5333, the closer to 0, the better 22 | 23 | LogLoss for class 0 (title) = 4.6289, the closer to 0, the better 24 | LogLoss for class 1 (image) = 1.849, the closer to 0, the better 25 | LogLoss for class 2 (text) = 2.5834, the closer to 0, the better 26 | LogLoss for class 3 (list) = 28.8412, the closer to 0, the better 27 | LogLoss for class 4 (table) = 2.8326, the closer to 0, the better 28 | 29 | F1 Score for class 0 (title) = 0.9305, a value between 0 and 1, the closer to 1, the better 30 | F1 Score for class 1 (image) = 0.9412, a value between 0 and 1, the closer to 1, the better 31 | F1 Score for class 2 (text) = 0.9691, a value between 0 and 1, the closer to 1, the better 32 | F1 Score for class 3 (list) = 0.2963, a value between 0 and 1, the closer to 1, the better 33 | F1 Score for class 4 (table) = 0.9611, a value between 0 and 1, the closer to 1, the better 34 | ``` 35 | 36 | ## Confusion table 37 | ``` 38 | ||======================================================= 39 | PREDICTED || 0 | 1 | 2 | 3 | 4 | Total | Recall 40 | TRUTH ||======================================================= 41 | (title) 0 || 1,848 | 0 | 89 | 1 | 0 | 1,938 | 0.9536 42 | (image) 1 || 0 | 72 | 3 | 0 | 1 | 76 | 0.9474 43 | (text) 2 || 185 | 1 | 8,837 | 16 | 9 | 9,048 | 0.9767 44 | (list) 3 || 1 | 0 | 225 | 52 | 2 | 280 | 0.1857 45 | (table) 4 || 0 | 4 | 35 | 2 | 654 | 695 | 0.9410 46 | ||======================================================= 47 | Precision ||0.9086 |0.9351 |0.9617 |0.7324 |0.9820 | 48 | ``` 49 | 50 | ## Permutation Feature Importance 51 | PFI works by taking a labeled dataset, choosing a feature, and permuting the values for that 52 | feature across all the examples, so that each example now has a random value for the feature 
53 | and the original values for all other features. The evaluation metric (e.g. micro-accuracy) is 54 | then calculated for this modified dataset, and the change in the evaluation metric from the 55 | original dataset is computed. The larger the change in the evaluation metric, the more 56 | important the feature is to the model. PFI works by performing this permutation analysis 57 | across all the features of a model, one after another. - [Source]( https://docs.microsoft.com/en-us/dotnet/api/microsoft.ml.permutationfeatureimportanceextensions.permutationfeatureimportance?view=ml-dotnet#Microsoft_ML_PermutationFeatureImportanceExtensions_PermutationFeatureImportance__1_Microsoft_ML_MulticlassClassificationCatalog_Microsoft_ML_ISingleFeaturePredictionTransformer___0__Microsoft_ML_IDataView_System_String_System_Boolean_System_Nullable_System_Int32__System_Int32_) 58 | 59 | ### Micro Accuracy 60 | 61 | |Feature | Description | Change in MicroAccuracy | 95% Confidence in the Mean Change in MicroAccuracy | 62 | |------------------------:|:-----------:|:-----------------------:|:--------------------------------------------------:| 63 | |**wordsCount**| Words count |-0.1144 |0.0008237| 64 | |**charsCount**| Characters count |-0.09849|0.0009115| 65 | |**linesCount**| Lines count |-0.04255 |0.0005726| 66 | |**blockAspectRatio**|Ratio between the block's width and height|-0.04159|0.0005134| 67 | |**bestNormEditDistance**|Minimum edit distance between bookmark title and block text|-0.04016 | 0.0006307| 68 | |**deltaToHeight**|Average delta to average page glyph height|-0.03689 | 0.0003823| 69 | |**pctNumericChars**|% of numeric characters|-0.03375 | 0.0005174| 70 | |**pctSymbolicChars**|% of symbolic characters|-0.01035 | 0.000358| 71 | |**pctAlphabeticalChars**|% of alphabetical characters| -0.00859 | 0.0003075| 72 | |**pctBulletChars**|% of bullet characters|-0.004071 | 0.0002063| 73 | |**pathsCount**|Paths count|-0.003661 | 0.0001184| 74 | |**pctHorPaths**|% of horizontal 
paths|-0.003431 | 0.0001309| 75 | |**imageAvgProportion**|Average area covered by images|-0.001684 | 4.68E-05| 76 | |**pctVertPaths**|% of vertical paths|-0.001448 | 8.139E-05| 77 | |**pctOblPaths**|% of oblique paths|-0.0001883 | 1.734E-05| 78 | |**imagesCount**|Images count|-9.692E-05 | 1.127E-05| 79 | |**pctBezierPaths**|% of Bezier curve paths|6.212E-22 | 1.352E-05| 80 | 81 | ### Macro Accuracy 82 | 83 | |Feature | Description | Change in MacroAccuracy | 95% Confidence in the Mean Change in MacroAccuracy | 84 | |------------------------:|:-----------:|:-----------------------:|:--------------------------------------------------:| 85 | |**charsCount**|Characters count| -0.1339 | 0.002263| 86 | |**linesCount**| Lines count | -0.08378 | 0.0009925| 87 | |**deltaToHeight**|Average delta to average page glyph height| -0.06461 | 0.001223| 88 | |**blockAspectRatio**|Ratio between the block's width and height| -0.05896 | 0.001163| 89 | |**bestNormEditDistance**|Minimum edit distance between bookmark title and block text| -0.05039 | 0.001419| 90 | |**pctNumericChars**|% of numeric characters| -0.0475 | 0.001258| 91 | |**pathsCount**|Paths count| -0.0252 | 0.000614| 92 | |**wordsCount**|Words count| -0.01651 | 0.001655| 93 | |**pctSymbolicChars**|% of symbolic characters| -0.01166 | 0.001505| 94 | |**pctHorPaths**|% of horizontal paths| -0.00984 | 0.0003186| 95 | |**pctVertPaths**|% of vertical paths| -0.008558 | 0.0002925| 96 | |**pctAlphabeticalChars**|% of alphabetical characters| -0.006882 | 0.0009913| 97 | |**imageAvgProportion**|Average area covered by images| -0.006148 | 0.0001146| 98 | |**pctBulletChars**|% of bullet characters| -0.005319 | 0.0008375| 99 | |**pctOblPaths**|% of oblique paths| -0.002747 | 6.657E-05| 100 | |**imagesCount**|Images count| -0.00268 | 3.903E-05| 101 | |**pctBezierPaths**|% of Bezier curve paths| -4.078E-05 | 5.263E-05| 102 | 103 | # TO DO 104 | ## Features 105 | - Add a 
[decoration](https://github.com/UglyToad/PdfPig/blob/master/src/UglyToad.PdfPig.DocumentLayoutAnalysis/DecorationTextBlockClassifier.cs) score/flag 106 | - Add block's area ratio: the ratio between block area and the page area, [cf.](http://www.cs.rug.nl/~aiellom/publications/ijdar.pdf) 107 | - Add block's font style: an enumerated type, with possible values: regular, bold, italic, underline, [cf.](http://www.cs.rug.nl/~aiellom/publications/ijdar.pdf) 108 | - Add % sparse lines in a block, for better table recognition [cf.](https://clgiles.ist.psu.edu/pubs/CIKM2008-table-boundaries.pdf) 109 | - Font color distance from most common color 110 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.Data/v2/FeatureHelper.cs:
--------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Linq;
using UglyToad.PdfPig.Content;
using UglyToad.PdfPig.Core;
using UglyToad.PdfPig.DocumentLayoutAnalysis;
using UglyToad.PdfPig.DocumentLayoutAnalysis.WordExtractor;
using static UglyToad.PdfPig.Core.PdfPath;
using UglyToad.PdfPig.Geometry;
using UglyToad.PdfPig.Outline;
using System.Text;

namespace PdfPigMLNetBlockClassifier.Data.v2
{
    /// <summary>
    /// Computes the v2 text-block features (17 values) and provides the helpers used to collect the
    /// letters, words, lines, images and paths of a block.
    /// </summary>
    public static class FeatureHelper
    {
        /// <summary>
        /// CSV header of the feature file: the 17 features, in the order returned by
        /// <see cref="GetFeatures"/>, followed by the label.
        /// </summary>
        public static readonly string Header = "blockAspectRatio,charsCount,wordsCount,linesCount,pctNumericChars,pctAlphabeticalChars,pctSymbolicChars,pctBulletChars,deltaToHeight,pathsCount,pctBezierPaths,pctHorPaths,pctVertPaths,pctOblPaths,imagesCount,imageAvgProportion,bestNormEditDistance,label";

        /// <summary>
        /// Category id to category name.
        /// </summary>
        public static readonly Dictionary<int, string> Categories = new Dictionary<int, string>()
        {
            { 0, "text" },
            { 1, "title" },
            { 2, "list" },
            { 3, "table" },
            { 4, "image" },
        };

        // Characters counted as list bullets. The lowercase letter 'o' is part of the set, so every
        // 'o' in ordinary text also counts towards pctBulletChars.
        private static readonly char[] Bullets = new char[]
        {
            '•', 'o', '▪', '❖', '➢', '►', '✓', '➔', '⇨', '➪',
            '➨', '➫', '➬', '➭', '➮', '➯', '➱', '➲', '\u2023',
            '\u2043', '\u204C', '\u204D'
        };

        /// <summary>
        /// Computes the v2 features of a text block.
        /// </summary>
        /// <param name="textBlock">The block; may be null or have no lines, in which case only path/image features are computed.</param>
        /// <param name="paths">Paths of the block; may be null.</param>
        /// <param name="images">Images of the block; may be null.</param>
        /// <param name="averagePageFontHeight">Average glyph height of the page; 0 leaves deltaToHeight as NaN.</param>
        /// <param name="bboxArea">Area of the block's bounding box, used to normalise image areas.</param>
        /// <param name="pageBookmarksNodes">Bookmarks of the block's page; null or empty leaves bestNormEditDistance as NaN.</param>
        /// <returns>
        /// The 17 features in <see cref="Header"/> order (without the label). NaN marks a feature that
        /// could not be computed; percentages stay 0 when there is nothing to divide by.
        /// </returns>
        public static float[] GetFeatures(TextBlock textBlock, IEnumerable<PdfPath> paths, IEnumerable<IPdfImage> images,
            double averagePageFontHeight, double bboxArea, List<DocumentBookmarkNode> pageBookmarksNodes)
        {
            // text block features
            float blockAspectRatio = float.NaN;

            // Letters features
            float charsCount = 0;
            float pctNumericChars = 0;
            float pctAlphabeticalChars = 0;
            float pctSymbolicChars = 0;
            float pctBulletChars = 0;
            float deltaToHeight = float.NaN; // might be problematic

            float wordsCount = 0;
            float linesCount = 0;
            float bestNormEditDistance = float.NaN;

            if (textBlock?.TextLines != null && textBlock.TextLines.Any())
            {
                blockAspectRatio = (float)Math.Round(textBlock.BoundingBox.Width / textBlock.BoundingBox.Height, 5);

                var textLines = textBlock.TextLines;
                var words = textLines.SelectMany(tl => tl.Words).ToList();
                var letters = words.SelectMany(w => w.Letters).ToList();
                char[] chars = letters.SelectMany(l => l.Value).ToArray();

                charsCount = chars.Length;

                // Letters with empty values would otherwise produce 0/0 = NaN.
                if (chars.Length > 0)
                {
                    pctNumericChars = Ratio(chars.Count(c => char.IsNumber(c)), charsCount);
                    pctAlphabeticalChars = Ratio(chars.Count(c => char.IsLetter(c)), charsCount);
                    pctSymbolicChars = Ratio(chars.Count(c => !char.IsLetterOrDigit(c)), charsCount);
                    pctBulletChars = Ratio(chars.Count(c => Array.IndexOf(Bullets, c) >= 0), charsCount);
                }

                // Average() throws on an empty sequence, hence the letters.Count check.
                if (averagePageFontHeight != 0 && letters.Count > 0)
                {
                    deltaToHeight = (float)Math.Round(letters.Select(l => l.GlyphRectangle.Height).Average() / averagePageFontHeight, 5);
                }

                wordsCount = words.Count;
                linesCount = textLines.Count();

                if (pageBookmarksNodes != null)
                {
                    // Both texts are NFKC-normalised and lower-cased before comparison, see
                    // http://www.unicode.org/reports/tr15/
                    // ToLowerInvariant keeps the result independent of the current culture (e.g. Turkish 'I').
                    var textBlockNormalised = textBlock.Text.Normalize(NormalizationForm.FormKC).ToLowerInvariant();
                    foreach (var bookmark in pageBookmarksNodes)
                    {
                        var bookmarkTextNormalised = bookmark.Title.Normalize(NormalizationForm.FormKC).ToLowerInvariant();
                        var currentDist = Distances.MinimumEditDistanceNormalised(textBlockNormalised, bookmarkTextNormalised);
                        if (float.IsNaN(bestNormEditDistance) || currentDist < bestNormEditDistance)
                        {
                            bestNormEditDistance = (float)Math.Round(currentDist, 5);
                        }
                    }
                }
            }

            // Paths features
            float pathsCount = 0;
            float pctBezierPaths = 0;
            float pctHorPaths = 0;
            float pctVertPaths = 0;
            float pctOblPaths = 0;

            if (paths != null)
            {
                foreach (var path in paths)
                {
                    foreach (var command in path.Commands)
                    {
                        if (command is BezierCurve)
                        {
                            pathsCount++;
                            pctBezierPaths++;
                        }
                        else if (command is Line line)
                        {
                            pathsCount++;

                            // Exact equality: only perfectly axis-aligned lines are vertical/horizontal.
                            if (line.From.X == line.To.X)
                            {
                                pctVertPaths++;
                            }
                            else if (line.From.Y == line.To.Y)
                            {
                                pctHorPaths++;
                            }
                            else
                            {
                                pctOblPaths++;
                            }
                        }
                    }
                }

                // Only Line and BezierCurve commands are counted. If there were none, the
                // percentages stay 0 instead of becoming 0/0 = NaN.
                if (pathsCount > 0)
                {
                    pctBezierPaths = Ratio(pctBezierPaths, pathsCount);
                    pctHorPaths = Ratio(pctHorPaths, pathsCount);
                    pctVertPaths = Ratio(pctVertPaths, pathsCount);
                    pctOblPaths = Ratio(pctOblPaths, pathsCount);
                }
            }

            // Images features
            float imagesCount = 0;
            float imageAvgProportion = 0;

            var imageList = images?.ToList();
            if (imageList != null && imageList.Count > 0)
            {
                imagesCount = imageList.Count;

                // A zero-area block would give Infinity/NaN; leave the proportion at 0 instead.
                if (bboxArea != 0)
                {
                    imageAvgProportion = (float)(imageList.Average(i => i.Bounds.Area) / bboxArea);
                }
            }

            return new float[]
            {
                blockAspectRatio, charsCount, wordsCount, linesCount, pctNumericChars,
                pctAlphabeticalChars, pctSymbolicChars, pctBulletChars, deltaToHeight,
                pathsCount, pctBezierPaths, pctHorPaths, pctVertPaths, pctOblPaths,
                imagesCount, imageAvgProportion, bestNormEditDistance
            };
        }

        // Share of count in total, rounded to 5 decimal places.
        private static float Ratio(float count, float total) => (float)Math.Round(count / total, 5);

        // True when inner lies entirely within bound (touching edges count as inside).
        private static bool IsInside(PdfRectangle inner, PdfRectangle bound)
        {
            return inner.Left >= bound.Left &&
                   inner.Right <= bound.Right &&
                   inner.Bottom >= bound.Bottom &&
                   inner.Top <= bound.Top;
        }

        /// <summary>
        /// Letters whose glyph rectangle intersects <paramref name="bound"/>.
        /// </summary>
        public static IEnumerable<Letter> GetLettersInside(PdfRectangle bound, IEnumerable<Letter> letters)
        {
            return letters.Where(l => bound.IntersectsWith(l.GlyphRectangle));
        }

        private static readonly NearestNeighbourWordExtractor nearestNeighbourWordExtractor = NearestNeighbourWordExtractor.Instance;

        /// <summary>
        /// Groups letters into words with PdfPig's nearest-neighbour word extractor.
        /// </summary>
        public static IReadOnlyList<Word> GetWords(IReadOnlyList<Letter> letters)
        {
            return nearestNeighbourWordExtractor.GetWords(letters).ToList();
        }

        /// <summary>
        /// Groups words sharing exactly the same bottom coordinate into lines, ordered by descending
        /// bottom coordinate (top of the page first in PDF user space).
        /// </summary>
        public static IReadOnlyList<TextLine> GetLines(IReadOnlyList<Word> words)
        {
            return words.GroupBy(x => x.BoundingBox.Bottom).OrderByDescending(x => x.Key)
                .Select(x => new TextLine(x.ToList())).ToArray();
        }

        /// <summary>
        /// Images whose bounds lie entirely within <paramref name="bound"/>.
        /// </summary>
        public static IEnumerable<IPdfImage> GetImagesInside(PdfRectangle bound, IEnumerable<IPdfImage> images)
        {
            return images.Where(i => IsInside(i.Bounds, bound));
        }

        /// <summary>
        /// Paths that have a bounding rectangle lying entirely within <paramref name="bound"/>.
        /// </summary>
        public static IEnumerable<PdfPath> GetPathsInside(PdfRectangle bound, IEnumerable<PdfPath> paths)
        {
            return paths.Where(p =>
            {
                // Compute the bounding rectangle once per path.
                var rect = p.GetBoundingRectangle();
                return rect.HasValue && IsInside(rect.Value, bound);
            });
        }
    }
}
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.Data/v1/DataGenerator.cs:
--------------------------------------------------------------------------------
1 | using System; 2 | using System.Collections.Concurrent; 3 | using System.Collections.Generic; 4 | using System.IO; 5 | using System.Linq; 6 | using
System.Threading.Tasks;
using System.Globalization;
using System.Threading;
using System.Xml;
using System.Xml.Serialization;
using UglyToad.PdfPig;
using UglyToad.PdfPig.Core;
using UglyToad.PdfPig.DocumentLayoutAnalysis.Export.PAGE;
using static UglyToad.PdfPig.DocumentLayoutAnalysis.Export.PAGE.PageXmlDocument;

namespace PdfPigMLNetBlockClassifier.Data.v1
{
    public static class DataGenerator
    {
        public static readonly string OutputFolderPath = @"../../../data";
        private static readonly string header = "charsCount,pctNumericChars,pctAlphabeticalChars,pctSymbolicChars,pctBulletChars,deltaToHeight,pathsCount,pctBezierPaths,pctHorPaths,pctVertPaths,pctOblPaths,imagesCount,imageAvgProportion,label";

        /// <summary>
        /// Generate a csv file of features. You will need the pdf documents and the ground truths in PAGE xml format.
        /// </summary>
        /// <param name="dataFolder">The path to the data folder. It should contain both the pdf files
        /// and their corresponding ground truth xml files, named "{pdf name}_{0-based page number:00000}.xml".</param>
        /// <param name="numberOfPdfDocs">Number of documents to consider. 0, or a value larger than the number of
        /// pdf files, means all of them. The selection uses a fixed random seed.</param>
        /// <param name="outputFileName">Name of the csv file, written under <see cref="OutputFolderPath"/> with a
        /// ".csv" extension. Documents rejected as invalid are listed in "invalide_pdfs_{name}.txt" in the same folder.</param>
        public static void GetCsv(string dataFolder, int numberOfPdfDocs, string outputFileName)
        {
            string outputFullPath = GetDataPath(outputFileName);
            string outputErrorFullPath = Path.Combine(OutputFolderPath, "invalide_pdfs_" + Path.ChangeExtension(outputFileName, "txt"));

            ConcurrentBag<string> invalidPdfs = new ConcurrentBag<string>();
            ConcurrentBag<string> data = new ConcurrentBag<string>();

            // Progress counter shared by the parallel workers; only updated through Interlocked.
            int done = 0;

            DirectoryInfo d = new DirectoryInfo(dataFolder);
            var pdfFileLinks = d.GetFiles("*.pdf", SearchOption.TopDirectoryOnly);
            var maxPageNumber = d.GetFiles("*.xml", SearchOption.TopDirectoryOnly).Select(f => ParseXmlFileName(f.Name)).Max() + 1;

            numberOfPdfDocs = Math.Min(pdfFileLinks.Length, numberOfPdfDocs);
            numberOfPdfDocs = numberOfPdfDocs == 0 ? pdfFileLinks.Length : numberOfPdfDocs;

            var indexesSelected = GenerateRandom(numberOfPdfDocs, 0, pdfFileLinks.Length);

            Parallel.ForEach(indexesSelected, index =>
            {
                var pdfFile = pdfFileLinks[index];
                string fileName = pdfFile.Name;
                string xmlFileNamePrefix = Path.GetFileNameWithoutExtension(fileName) + "_";

                // Ground-truth pages are stored as "{pdf name}_{0-based page number:00000}.xml".
                var pageXmlLinks = Enumerable.Range(0, maxPageNumber)
                    .Select(i => Path.Combine(dataFolder, xmlFileNamePrefix + i.ToString("00000", CultureInfo.InvariantCulture) + ".xml"))
                    .Where(File.Exists)
                    .Select(path => new FileInfo(path))
                    .ToArray();

                if (pageXmlLinks.Length == 0)
                {
                    Console.BackgroundColor = ConsoleColor.DarkRed;
                    Console.WriteLine("Error for document '" + fileName + "': No PageXml files found");
                    Console.ResetColor();
                    return;
                }

                try
                {
                    var pagesNumbers = pageXmlLinks.Select(l => ParseXmlFileName(l.Name)).ToList();

                    // CSV lines ("feature1,...,featureN,category") of this document. They are only published to
                    // the shared bag once the whole document has been found valid.
                    List<string> localLines = new List<string>();
                    bool isValidDocument = true;

                    using (var doc = PdfDocument.Open(pdfFile.FullName))
                    {
                        // Checks if this pdf document looks to be valid
                        if ((pagesNumbers.Max() + 1) > doc.NumberOfPages)
                        {
                            // ignore this document as page number is not correct
                            Console.BackgroundColor = ConsoleColor.Red;
                            Console.WriteLine("Error for document '" + fileName + "': Ignoring this document as page number is not correct");
                            Console.ResetColor();
                            isValidDocument = false;
                        }

                        foreach (var pageXmlLink in pageXmlLinks)
                        {
                            if (!isValidDocument) break;

                            int pageNo = ParseXmlFileName(pageXmlLink.Name); // 0-based
                            var page = doc.GetPage(pageNo + 1);

                            if (page.Rotation.Value != 0)
                            {
                                Console.BackgroundColor = ConsoleColor.Yellow;
                                Console.ForegroundColor = ConsoleColor.Black;
                                Console.WriteLine("Error for document '" + fileName + "': Ignoring page " + (pageNo + 1) + " because it is rotated");
                                Console.ResetColor();
                                continue;
                            }

                            var pageXml = Deserialize(pageXmlLink.FullName);

                            foreach (var block in pageXml.Page.Items)
                            {
                                int category;
                                PdfRectangle bbox;

                                if (block is PageXmlTextRegion textBlock)
                                {
                                    bbox = ParsePageXmlCoord(textBlock.Coords.Points, page.Height);
                                    category = GetTextRegionCategory(textBlock.Type);
                                }
                                else if (block is PageXmlTableRegion tableBlock)
                                {
                                    bbox = ParsePageXmlCoord(tableBlock.Coords.Points, page.Height);
                                    category = 3;
                                }
                                else if (block is PageXmlImageRegion imageBlock)
                                {
                                    bbox = ParsePageXmlCoord(imageBlock.Coords.Points, page.Height);
                                    category = 4;
                                }
                                else
                                {
                                    throw new ArgumentException("Unknown region type");
                                }

                                var letters = FeatureHelper.GetLettersInside(bbox, page.Letters).ToList();

                                // A text region without any letter means the ground truth and the pdf disagree.
                                if (block is PageXmlTextRegion && letters.Count == 0)
                                {
                                    Console.BackgroundColor = ConsoleColor.Red;
                                    Console.ForegroundColor = ConsoleColor.Black;
                                    Console.WriteLine("Error for document '" + fileName + "': Ignoring this document as an empty paragraph was found");
                                    Console.ResetColor();
                                    isValidDocument = false;
                                    break;
                                }

                                var paths = FeatureHelper.GetPathsInside(bbox, page.ExperimentalAccess.Paths).ToList();
                                var images = FeatureHelper.GetImagesInside(bbox, page.GetImages());
                                var f = FeatureHelper.GetFeatures(page, bbox, letters, paths, images);

                                if (f != null)
                                {
                                    // Invariant culture: with a ',' decimal separator one value would otherwise
                                    // be split over two CSV columns.
                                    localLines.Add(string.Join(",", f.Select(v => v.ToString(CultureInfo.InvariantCulture))) + "," + category);
                                }
                            }
                        }
                    }

                    if (isValidDocument)
                    {
                        foreach (var line in localLines)
                        {
                            data.Add(line);
                        }
                    }
                    else
                    {
                        invalidPdfs.Add(pdfFile.Name);
                    }
                }
                catch (Exception ex)
                {
                    Console.ForegroundColor = ConsoleColor.Red;
                    Console.WriteLine("Error for document '" + fileName + "': " + ex.Message);
                    Console.ResetColor();
                }

                // Interlocked because a plain "done++" from parallel workers can lose increments.
                // Prints the 0-based count of documents finished before this one.
                Console.WriteLine(Interlocked.Increment(ref done) - 1);
            });

            List<string> csv = new List<string>() { header };
            csv.AddRange(data);
            Directory.CreateDirectory(OutputFolderPath);
            File.WriteAllLines(outputFullPath, csv);
            File.WriteAllLines(outputErrorFullPath, invalidPdfs);

            Console.WriteLine("Done. Csv file saved in " + outputFullPath);
        }

        /// <summary>
        /// Returns the path of the csv file named <paramref name="fileName"/> (extension forced to ".csv")
        /// inside <see cref="OutputFolderPath"/>.
        /// </summary>
        public static string GetDataPath(string fileName)
        {
            return Path.Combine(OutputFolderPath, Path.ChangeExtension(fileName, "csv"));
        }

        /// <summary>
        /// Maps a PAGE xml text region type to its label: 0 paragraph, 1 heading, 2 list label.
        /// </summary>
        /// <exception cref="ArgumentException">The type is not one of those three.</exception>
        private static int GetTextRegionCategory(PageXmlTextSimpleType type)
        {
            switch (type)
            {
                case PageXmlTextSimpleType.Paragraph:
                    return 0;
                case PageXmlTextSimpleType.Heading:
                    return 1;
                case PageXmlTextSimpleType.LisLabel:
                    return 2;
                default:
                    throw new ArgumentException("Unknown category");
            }
        }

        /// <summary>
        /// Reads a PAGE xml ground-truth file.
        /// </summary>
        private static PageXmlDocument Deserialize(string xmlPath)
        {
            XmlSerializer serializer = new XmlSerializer(typeof(PageXmlDocument));

            using (var reader = XmlReader.Create(xmlPath))
            {
                return (PageXmlDocument)serializer.Deserialize(reader);
            }
        }

        /// <summary>
        /// Converts a PAGE xml point list ("x1,y1 x2,y2 ...", y measured from the top of the page) into the
        /// axis-aligned PDF rectangle enclosing those points, flipping y with the page <paramref name="height"/>.
        /// </summary>
        private static PdfRectangle ParsePageXmlCoord(string points, double height)
        {
            List<PdfPoint> pdfPoints = new List<PdfPoint>();

            // RemoveEmptyEntries tolerates repeated or trailing spaces in the point list.
            foreach (var p in points.Split(new[] { ' ' }, StringSplitOptions.RemoveEmptyEntries))
            {
                string[] coord = p.Split(',');
                // Invariant culture: the xml content must not be parsed with the machine's locale.
                pdfPoints.Add(new PdfPoint(double.Parse(coord[0], CultureInfo.InvariantCulture),
                                           height - double.Parse(coord[1], CultureInfo.InvariantCulture)));
            }

            return new PdfRectangle(pdfPoints.Min(p => p.X), pdfPoints.Min(p => p.Y), pdfPoints.Max(p => p.X), pdfPoints.Max(p => p.Y));
        }

        /// <summary>
        /// Extracts the 0-based page number from a ground-truth file name such as "doc_00003.xml".
        /// The number is read after the last '_', so pdf names that contain underscores are supported.
        /// </summary>
        /// <exception cref="ArgumentException">The name does not end with "_{number}" before its extension.</exception>
        private static int ParseXmlFileName(string xmlFileName)
        {
            string nameWithoutExtension = Path.GetFileNameWithoutExtension(xmlFileName);
            int separatorIndex = nameWithoutExtension.LastIndexOf('_');

            if (separatorIndex >= 0 &&
                int.TryParse(nameWithoutExtension.Substring(separatorIndex + 1), NumberStyles.None, CultureInfo.InvariantCulture, out int pageNo))
            {
                return pageNo;
            }

            throw new ArgumentException("Cannot parse page number");
        }

        /// <summary>
        /// Returns <paramref name="count"/> distinct integers from [<paramref name="min"/>, <paramref name="max"/>)
        /// in random order. A fixed seed (42) makes the selection reproducible.
        /// https://codereview.stackexchange.com/questions/61338/generate-random-numbers-without-repetitions
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">The range is empty, <paramref name="count"/> is negative,
        /// or <paramref name="count"/> exceeds the size of the range.</exception>
        private static List<int> GenerateRandom(int count, int min, int max)
        {
            Random random = new Random(42);

            // Floyd's sampling algorithm:
            // initialize set S to empty
            // for J := N-M + 1 to N do
            //    T := RandInt(1, J)
            //    if T is not in S then
            //        insert T in S
            //    else
            //        insert J in S
            //
            // adapted for C# which does not have an inclusive Next(..)
            // and to make it from configurable range not just 1.

            if (max <= min || count < 0 ||
                    // max - min > 0 required to avoid overflow
                    (count > max - min && max - min > 0))
            {
                // need to use 64-bit to support big ranges (negative min, positive max)
                throw new ArgumentOutOfRangeException(nameof(count), "Range " + min + " to " + max +
                        " (" + ((Int64)max - (Int64)min) + " values), or count " + count + " is illegal");
            }

            // generate count random values.
            HashSet<int> candidates = new HashSet<int>();

            // start count values before max, and end at max
            for (int top = max - count; top < max; top++)
            {
                // May strike a duplicate.
                // Need to add +1 to make inclusive generator
                // +1 is safe even for MaxVal max value because top < max
                if (!candidates.Add(random.Next(min, top + 1)))
                {
                    // collision, add inclusive max.
                    // which could not possibly have been added before.
                    candidates.Add(top);
                }
            }

            // copy them into a list so they can be shuffled in place
            List<int> result = candidates.ToList();

            // shuffle the results because HashSet has messed
            // with the order, and the algorithm does not produce
            // random-ordered results (e.g. max-1 will never be the first value)
            for (int i = result.Count - 1; i > 0; i--)
            {
                int k = random.Next(i + 1);
                int tmp = result[k];
                result[k] = result[i];
                result[i] = tmp;
            }
            return result;
        }
    }
}

--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.Data/v2/DataGenerator.cs:
--------------------------------------------------------------------------------
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using System.Xml;
using System.Xml.Serialization;
using UglyToad.PdfPig;
using UglyToad.PdfPig.Core;
using UglyToad.PdfPig.DocumentLayoutAnalysis;
using UglyToad.PdfPig.DocumentLayoutAnalysis.Export.PAGE;
using UglyToad.PdfPig.Outline;
using static UglyToad.PdfPig.DocumentLayoutAnalysis.Export.PAGE.PageXmlDocument;

namespace PdfPigMLNetBlockClassifier.Data.v2
{
    public static class DataGenerator
    {
        public static readonly string OutputFolderPath = @"../../../data";

        /// <summary>
        /// Generate a csv file of features. You will need the pdf documents and the ground truths in PAGE xml format.
        /// </summary>
        /// <param name="dataFolder">The path to the data folder. It should contain both the pdf files
        /// and their corresponding ground truth xml files, named "{pdf name}_{0-based page number:00000}.xml".</param>
        /// <param name="numberOfPdfDocs">Number of documents to consider. 0, or a value larger than the number of
        /// pdf files, means all of them. The selection uses a fixed random seed.</param>
        /// <param name="outputFileName">Name of the csv file, written under <see cref="OutputFolderPath"/> with a
        /// ".csv" extension. Documents rejected as invalid are listed in "invalide_pdfs_{name}.txt" in the same folder.</param>
public static void GetCsv(string dataFolder, int numberOfPdfDocs, string outputFileName)
        {
            string outputFullPath = GetDataPath(outputFileName);
            string outputErrorFullPath = Path.Combine(OutputFolderPath, "invalide_pdfs_" + Path.ChangeExtension(outputFileName, "txt"));

            ConcurrentBag<string> invalidPdfs = new ConcurrentBag<string>();
            ConcurrentBag<string> data = new ConcurrentBag<string>();

            // Progress counter shared by the parallel workers; only updated through Interlocked.
            int done = 0;

            DirectoryInfo d = new DirectoryInfo(dataFolder);
            var pdfFileLinks = d.GetFiles("*.pdf", SearchOption.TopDirectoryOnly);
            var maxPageNumber = d.GetFiles("*.xml", SearchOption.TopDirectoryOnly).Select(f => ParseXmlFileName(f.Name)).Max() + 1;

            numberOfPdfDocs = Math.Min(pdfFileLinks.Length, numberOfPdfDocs);
            numberOfPdfDocs = numberOfPdfDocs == 0 ? pdfFileLinks.Length : numberOfPdfDocs;

            var indexesSelected = GenerateRandom(numberOfPdfDocs, 0, pdfFileLinks.Length);

            Parallel.ForEach(indexesSelected, index =>
            {
                var pdfFile = pdfFileLinks[index];
                string fileName = pdfFile.Name;
                string xmlFileNamePrefix = Path.GetFileNameWithoutExtension(fileName) + "_";

                // Ground-truth pages are stored as "{pdf name}_{0-based page number:00000}.xml".
                var pageXmlLinks = Enumerable.Range(0, maxPageNumber)
                    .Select(i => Path.Combine(dataFolder, xmlFileNamePrefix + i.ToString("00000", System.Globalization.CultureInfo.InvariantCulture) + ".xml"))
                    .Where(File.Exists)
                    .Select(path => new FileInfo(path))
                    .ToArray();

                if (pageXmlLinks.Length == 0)
                {
                    Console.BackgroundColor = ConsoleColor.DarkRed;
                    Console.WriteLine("Error for document '" + fileName + "': No PageXml files found");
                    Console.ResetColor();
                    return;
                }

                try
                {
                    var pagesNumbers = pageXmlLinks.Select(l => ParseXmlFileName(l.Name)).ToList();

                    // CSV lines ("feature1,...,featureN,category") of this document. They are only published to
                    // the shared bag once the whole document has been found valid.
                    List<string> localLines = new List<string>();
                    bool isValidDocument = true;

                    using (var doc = PdfDocument.Open(pdfFile.FullName))
                    {
                        // Bookmarks are optional. Without them, null is passed to the feature extraction.
                        List<DocumentBookmarkNode> bookmarksNodes = null;
                        if (doc.TryGetBookmarks(out Bookmarks bookmarks))
                        {
                            bookmarksNodes = bookmarks.GetNodes().OfType<DocumentBookmarkNode>().ToList();
                        }

                        // Checks if this pdf document looks to be valid
                        if ((pagesNumbers.Max() + 1) > doc.NumberOfPages)
                        {
                            // ignore this document as page number is not correct
                            Console.BackgroundColor = ConsoleColor.Red;
                            Console.WriteLine("Error for document '" + fileName + "': Ignoring this document as page number is not correct");
                            Console.ResetColor();
                            isValidDocument = false;
                        }

                        foreach (var pageXmlLink in pageXmlLinks)
                        {
                            if (!isValidDocument) break;

                            int pageNo = ParseXmlFileName(pageXmlLink.Name); // base 0

                            List<DocumentBookmarkNode> pageBookmarksNodes = bookmarksNodes?.Where(b => b.PageNumber == pageNo + 1).ToList();

                            var page = doc.GetPage(pageNo + 1);

                            if (page.Rotation.Value != 0)
                            {
                                Console.BackgroundColor = ConsoleColor.Yellow;
                                Console.ForegroundColor = ConsoleColor.Black;
                                Console.WriteLine("Error for document '" + fileName + "': Ignoring page " + (pageNo + 1) + " because it is rotated");
                                Console.ResetColor();
                                continue;
                            }

                            // Average() throws on an empty sequence. Skip such pages instead of losing the whole document.
                            if (!page.Letters.Any())
                            {
                                Console.BackgroundColor = ConsoleColor.Yellow;
                                Console.ForegroundColor = ConsoleColor.Black;
                                Console.WriteLine("Error for document '" + fileName + "': Ignoring page " + (pageNo + 1) + " because it has no letters");
                                Console.ResetColor();
                                continue;
                            }

                            var avgPageFontHeight = page.Letters.Select(l => l.GlyphRectangle.Height).Average();

                            var pageXml = Deserialize(pageXmlLink.FullName);

                            foreach (var block in pageXml.Page.Items)
                            {
                                int category;
                                PdfRectangle bbox;

                                if (block is PageXmlTextRegion pageTextRegion)
                                {
                                    bbox = ParsePageXmlCoord(pageTextRegion.Coords.Points, page.Height);
                                    category = GetTextRegionCategory(pageTextRegion.Type);
                                }
                                else if (block is PageXmlTableRegion tableBlock)
                                {
                                    bbox = ParsePageXmlCoord(tableBlock.Coords.Points, page.Height);
                                    category = 3;
                                }
                                else if (block is PageXmlImageRegion imageBlock)
                                {
                                    bbox = ParsePageXmlCoord(imageBlock.Coords.Points, page.Height);
                                    category = 4;
                                }
                                else
                                {
                                    throw new ArgumentException("Unknown region type");
                                }

                                var letters = FeatureHelper.GetLettersInside(bbox, page.Letters).ToList();

                                // A text region without any letter means the ground truth and the pdf disagree.
                                if (block is PageXmlTextRegion && letters.Count == 0)
                                {
                                    Console.BackgroundColor = ConsoleColor.Red;
                                    Console.ForegroundColor = ConsoleColor.Black;
                                    Console.WriteLine("Error for document '" + fileName + "': Ignoring this document as an empty paragraph was found");
                                    Console.ResetColor();
                                    isValidDocument = false;
                                    break;
                                }

                                TextBlock textBlock = null;
                                if (letters.Count > 0)
                                {
                                    var words = FeatureHelper.GetWords(letters);
                                    var lines = FeatureHelper.GetLines(words);
                                    if (lines != null && lines.Count > 0)
                                    {
                                        textBlock = new TextBlock(lines);
                                    }
                                }

                                var paths = FeatureHelper.GetPathsInside(bbox, page.ExperimentalAccess.Paths).ToList();
                                var images = FeatureHelper.GetImagesInside(bbox, page.GetImages());
                                var f = FeatureHelper.GetFeatures(textBlock, paths, images,
                                    avgPageFontHeight, bbox.Area, pageBookmarksNodes);

                                if (f != null)
                                {
                                    // NaN values are written as empty fields. Other values use the invariant
                                    // culture so that ',' never appears as a decimal separator.
                                    localLines.Add(string.Join(",", f.Select(v => double.IsNaN(v)
                                        ? ""
                                        : v.ToString(System.Globalization.CultureInfo.InvariantCulture))) + "," + category);
                                }
                            }
                        }
                    }

                    if (isValidDocument)
                    {
                        foreach (var line in localLines)
                        {
                            data.Add(line);
                        }
                    }
                    else
                    {
                        invalidPdfs.Add(pdfFile.Name);
                    }
                }
                catch (Exception ex)
                {
                    Console.ForegroundColor = ConsoleColor.Red;
                    Console.WriteLine("Error for document '" + fileName + "': " + ex.Message);
                    Console.ResetColor();
                }

                // Interlocked because a plain "done++" from parallel workers can lose increments.
                // Prints the 0-based count of documents finished before this one.
                Console.WriteLine(System.Threading.Interlocked.Increment(ref done) - 1);
            });

            List<string> csv = new List<string>() { FeatureHelper.Header };
            csv.AddRange(data);
            Directory.CreateDirectory(OutputFolderPath);
            File.WriteAllLines(outputFullPath, csv);
            File.WriteAllLines(outputErrorFullPath, invalidPdfs);

            Console.WriteLine("Done. Csv file saved in " + outputFullPath);
        }

        /// <summary>
        /// Returns the path of the csv file named <paramref name="fileName"/> (extension forced to ".csv")
        /// inside <see cref="OutputFolderPath"/>.
        /// </summary>
        public static string GetDataPath(string fileName)
        {
            return Path.Combine(OutputFolderPath, Path.ChangeExtension(fileName, "csv"));
        }

        /// <summary>
        /// Maps a PAGE xml text region type to its label: 0 paragraph, 1 heading, 2 list label.
        /// </summary>
        /// <exception cref="ArgumentException">The type is not one of those three.</exception>
        private static int GetTextRegionCategory(PageXmlTextSimpleType type)
        {
            switch (type)
            {
                case PageXmlTextSimpleType.Paragraph:
                    return 0;
                case PageXmlTextSimpleType.Heading:
                    return 1;
                case PageXmlTextSimpleType.LisLabel:
                    return 2;
                default:
                    throw new ArgumentException("Unknown category");
            }
        }

        /// <summary>
        /// Reads a PAGE xml ground-truth file.
        /// </summary>
        private static PageXmlDocument Deserialize(string xmlPath)
        {
            XmlSerializer serializer = new XmlSerializer(typeof(PageXmlDocument));

            using (var reader = XmlReader.Create(xmlPath))
            {
                return (PageXmlDocument)serializer.Deserialize(reader);
            }
        }

        /// <summary>
        /// Converts a PAGE xml point list ("x1,y1 x2,y2 ...", y measured from the top of the page) into the
        /// axis-aligned PDF rectangle enclosing those points, flipping y with the page <paramref name="height"/>.
        /// </summary>
        private static PdfRectangle ParsePageXmlCoord(string points, double height)
        {
            List<PdfPoint> pdfPoints = new List<PdfPoint>();

            // RemoveEmptyEntries tolerates repeated or trailing spaces in the point list.
            foreach (var p in points.Split(new[] { ' ' }, StringSplitOptions.RemoveEmptyEntries))
            {
                string[] coord = p.Split(',');
                // Invariant culture: the xml content must not be parsed with the machine's locale.
                pdfPoints.Add(new PdfPoint(double.Parse(coord[0], System.Globalization.CultureInfo.InvariantCulture),
                                           height - double.Parse(coord[1], System.Globalization.CultureInfo.InvariantCulture)));
            }

            return new PdfRectangle(pdfPoints.Min(p => p.X), pdfPoints.Min(p => p.Y), pdfPoints.Max(p => p.X), pdfPoints.Max(p => p.Y));
        }

        /// <summary>
        /// Extracts the 0-based page number from a ground-truth file name such as "doc_00003.xml".
        /// The number is read after the last '_', so pdf names that contain underscores are supported.
        /// </summary>
        /// <exception cref="ArgumentException">The name does not end with "_{number}" before its extension.</exception>
        private static int ParseXmlFileName(string xmlFileName)
        {
            string nameWithoutExtension = Path.GetFileNameWithoutExtension(xmlFileName);
            int separatorIndex = nameWithoutExtension.LastIndexOf('_');

            if (separatorIndex >= 0 &&
                int.TryParse(nameWithoutExtension.Substring(separatorIndex + 1),
                    System.Globalization.NumberStyles.None,
                    System.Globalization.CultureInfo.InvariantCulture,
                    out int pageNo))
            {
                return pageNo;
            }

            throw new ArgumentException("Cannot parse page number");
        }

        /// <summary>
        /// Returns <paramref name="count"/> distinct integers from [<paramref name="min"/>, <paramref name="max"/>)
        /// in random order. A fixed seed (42) makes the selection reproducible.
        /// https://codereview.stackexchange.com/questions/61338/generate-random-numbers-without-repetitions
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">The range is empty, <paramref name="count"/> is negative,
        /// or <paramref name="count"/> exceeds the size of the range.</exception>
        private static List<int> GenerateRandom(int count, int min, int max)
        {
            Random random = new Random(42);

            // Floyd's sampling algorithm:
            // initialize set S to empty
            // for J := N-M + 1 to N do
            //    T := RandInt(1, J)
            //    if T is not in S then
            //        insert T in S
            //    else
            //        insert J in S
            //
            // adapted for C# which does not have an inclusive Next(..)
            // and to make it from configurable range not just 1.

            if (max <= min || count < 0 ||
                    // max - min > 0 required to avoid overflow
                    (count > max - min && max - min > 0))
            {
                // need to use 64-bit to support big ranges (negative min, positive max)
                throw new ArgumentOutOfRangeException(nameof(count), "Range " + min + " to " + max +
                        " (" + ((Int64)max - (Int64)min) + " values), or count " + count + " is illegal");
            }

            // generate count random values.
            HashSet<int> candidates = new HashSet<int>();

            // start count values before max, and end at max
            for (int top = max - count; top < max; top++)
            {
                // May strike a duplicate.
                // Need to add +1 to make inclusive generator
                // +1 is safe even for MaxVal max value because top < max
                if (!candidates.Add(random.Next(min, top + 1)))
                {
                    // collision, add inclusive max.
                    // which could not possibly have been added before.
                    candidates.Add(top);
                }
            }

            // copy them into a list so they can be shuffled in place
            List<int> result = candidates.ToList();

            // shuffle the results because HashSet has messed
            // with the order, and the algorithm does not produce
            // random-ordered results (e.g. max-1 will never be the first value)
            for (int i = result.Count - 1; i > 0; i--)
            {
                int k = random.Next(i + 1);
                int tmp = result[k];
                result[k] = result[i];
                result[i] = tmp;
            }
            return result;
        }
    }
}

--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.LightGbmV2/LightGbmModelBuilder.cs:
--------------------------------------------------------------------------------
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;

namespace PdfPigMLNetBlockClassifier.LightGbmV2
{
    /// <summary>
    /// Trains, cross-validates, evaluates and saves the v2 LightGBM block classifier with ML.NET.
    /// </summary>
    public static class LightGbmModelBuilder
    {
        public static readonly string OutputFolderPath = @"../../../model";

        // Seeded so that ML.NET's random operations (such as the cross-validation splits) are reproducible.
        private static readonly MLContext mlContext = new MLContext(seed: 1);

        /// <summary>
        /// Input columns concatenated into the "Features" vector, in slot order. The training pipeline and the
        /// permutation feature importance report both use this array, so reported names always match their slots.
        /// </summary>
        private static readonly string[] FeatureColumnNames =
        {
            "blockAspectRatio", "charsCount",
            "wordsCount", "linesCount",
            "pctNumericChars", "pctAlphabeticalChars",
            "pctSymbolicChars", "pctBulletChars",
            "deltaToHeight", "pathsCount",
            "pctBezierPaths", "pctHorPaths",
            "pctVertPaths", "pctOblPaths",
            "imagesCount", "imageAvgProportion",
            "bestNormEditDistance"
        };

        /// <summary>
        /// Loads the training csv, prints 5-fold cross-validation metrics, trains the model on the whole data
        /// set and saves it under <see cref="OutputFolderPath"/> with a ".zip" extension.
        /// </summary>
        /// <param name="trainDataFilePath">Csv file with a header row, one column per feature plus "label".</param>
        /// <param name="outputModelName">File name of the saved model (its extension is replaced with ".zip").</param>
        public static void TrainModel(string trainDataFilePath, string outputModelName)
        {
            string outputFullPath = GetModelPath(outputModelName);

            // Load Data
            IDataView trainingDataView = mlContext.Data.LoadFromTextFile<BlockFeatures>(
                path: trainDataFilePath,
                hasHeader: true,
                separatorChar: ',',
                allowQuoting: true,
                allowSparse: false);

            // Build training pipeline
            IEstimator<ITransformer> trainingPipeline = BuildTrainingPipeline(mlContext);

            // Evaluate quality of Model
            CrossValidate(mlContext, trainingDataView, trainingPipeline);

            // Train Model
            ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline);

            // Save model
            SaveModel(mlContext, mlModel, outputFullPath, trainingDataView.Schema);
        }

        /// <summary>
        /// Loads a saved model, prints its metrics on the test csv, then prints the permutation feature
        /// importance (PFI) of every feature for micro- and macro-accuracy.
        /// </summary>
        /// <param name="modelName">File name of the model inside <see cref="OutputFolderPath"/>.</param>
        /// <param name="testDataFilePath">Csv file with the same layout as the training data.</param>
        /// <exception cref="InvalidOperationException">The model does not have the layout produced by
        /// <see cref="BuildTrainingPipeline"/>.</exception>
        public static void Evaluate(string modelName, string testDataFilePath)
        {
            string modelFullPath = GetModelPath(modelName);

            // Create new MLContext
            MLContext mlContext = new MLContext();

            // Load model
            ITransformer mlModel = mlContext.Model.Load(modelFullPath, out _);

            // Load Data
            IDataView testingDataView = mlContext.Data.LoadFromTextFile<BlockFeatures>(
                path: testDataFilePath,
                hasHeader: true,
                separatorChar: ',',
                allowQuoting: true,
                allowSparse: false);

            IDataView transformedTestingDataView = mlModel.Transform(testingDataView);

            Evaluate(mlContext, transformedTestingDataView);

            // Permutation Feature Importance
            // https://docs.microsoft.com/en-us/dotnet/api/microsoft.ml.permutationfeatureimportanceextensions.permutationfeatureimportance?view=ml-dotnet#Microsoft_ML_PermutationFeatureImportanceExtensions_PermutationFeatureImportance__1_Microsoft_ML_MulticlassClassificationCatalog_Microsoft_ML_ISingleFeaturePredictionTransformer___0__Microsoft_ML_IDataView_System_String_System_Boolean_System_Nullable_System_Int32__System_Int32_

            Console.WriteLine("=============== Permutation Feature Importance ===============");
            Console.WriteLine(@"PFI works by taking a labeled dataset, choosing a feature, and permuting the values for that
feature across all the examples, so that each example now has a random value for the feature
and the original values for all other features. The evaluation metric (e.g. micro-accuracy) is
then calculated for this modified dataset, and the change in the evaluation metric from the
original dataset is computed. The larger the change in the evaluation metric, the more
important the feature is to the model. PFI works by performing this permutation analysis
across all the features of a model, one after another.");
            // A verbatim string does not interpret "\n", so the blank line is written explicitly.
            Console.WriteLine();

            // The model is chain(value-to-key, concatenate, chain(LightGBM, key-to-value)) as built by
            // BuildTrainingPipeline: the predictor is the first transformer of the last sub-chain.
            // NOTE(review): the generic arguments were missing in this copy; OneVersusAllModelParameters is assumed
            // for the LightGBM multiclass trainer - confirm against the ML.NET version used.
            var predictor = ((mlModel as TransformerChain<ITransformer>)?.LastTransformer as TransformerChain<ITransformer>)?
                .FirstOrDefault() as MulticlassPredictionTransformer<OneVersusAllModelParameters>;
            if (predictor == null)
            {
                throw new InvalidOperationException("Could not find the multiclass prediction transformer in model '" + modelFullPath + "'.");
            }

            var pfi = mlContext.MulticlassClassification.PermutationFeatureImportance(
                predictor,
                transformedTestingDataView,
                labelColumnName: "label",
                permutationCount: 30);

            PrintFeatureImportance("MicroAccuracy", pfi.Select(x => x.MicroAccuracy).ToArray());

            Console.WriteLine();

            PrintFeatureImportance("MacroAccuracy", pfi.Select(x => x.MacroAccuracy).ToArray());
        }

        /// <summary>
        /// Returns the path of the model file named <paramref name="modelName"/> (extension forced to ".zip")
        /// inside <see cref="OutputFolderPath"/>.
        /// </summary>
        public static string GetModelPath(string modelName)
        {
            return Path.Combine(OutputFolderPath, Path.ChangeExtension(modelName, "zip"));
        }

        /// <summary>
        /// Prints one PFI table. Features are sorted by the absolute mean change of <paramref name="metricName"/>,
        /// alongside the half-width of the 95% confidence interval of that mean (1.96 standard errors).
        /// </summary>
        private static void PrintFeatureImportance(string metricName, MetricStatistics[] statistics)
        {
            Console.WriteLine("Feature\tChange in " + metricName + "\t95% Confidence in the Mean Change in " + metricName);

            var sortedIndices = statistics
                .Select((metric, index) => new { index, metric.Mean })
                .OrderByDescending(feature => Math.Abs(feature.Mean))
                .Select(feature => feature.index);

            foreach (int i in sortedIndices)
            {
                Console.WriteLine("{0}\t{1:G4}\t{2:G4}",
                    FeatureColumnNames[i],
                    statistics[i].Mean,
                    1.96 * statistics[i].StandardError);
            }
        }

        /// <summary>
        /// Builds the pipeline. It maps label to key, concatenates <see cref="FeatureColumnNames"/> into "Features",
        /// trains LightGBM multiclass, and maps the predicted key back to the original label value.
        /// </summary>
        private static IEstimator<ITransformer> BuildTrainingPipeline(MLContext mlContext)
        {
            // Data process configuration with pipeline data transformations
            var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey("label", "label")
                .Append(mlContext.Transforms.Concatenate("Features", FeatureColumnNames));

            // Set the training algorithm
            var trainer = mlContext.MulticlassClassification.Trainers.LightGbm(labelColumnName: "label", featureColumnName: "Features")
                .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));

            return dataProcessPipeline.Append(trainer);
        }

        /// <summary>
        /// Fits <paramref name="trainingPipeline"/> on <paramref name="trainingDataView"/>.
        /// </summary>
        private static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
        {
            Console.WriteLine("=============== Training model ===============");

            ITransformer model = trainingPipeline.Fit(trainingDataView);

            Console.WriteLine("=============== End of training process ===============");
            return model;
        }

        /// <summary>
        /// Runs 5-fold cross-validation and prints the averaged metrics.
        /// </summary>
        private static void CrossValidate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
        {
            // Cross-Validate with single dataset (since we don't have two datasets, one for training and for evaluate)
            // in order to evaluate and get the model's accuracy metrics
            Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ===============");
            var crossValidationResults = mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: 5, labelColumnName: "label");
            PrintMulticlassClassificationFoldsAverageMetrics(crossValidationResults);
        }

        /// <summary>
        /// Prints the multiclass metrics of an already transformed (scored) data view.
        /// </summary>
        private static void Evaluate(MLContext mlContext, IDataView testingDataView)
        {
            Console.WriteLine("=============== Evaluating to get model's accuracy metrics ===============");
            var evaluationResults = mlContext.MulticlassClassification.Evaluate(testingDataView, labelColumnName: "label");
            PrintMulticlassClassificationMetrics(evaluationResults);
        }

        /// <summary>
        /// Saves the trained model together with its input schema to a .zip file.
        /// </summary>
        private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
        {
            // Save/persist the trained model to a .ZIP file
            Console.WriteLine($"=============== Saving the model ===============");
            mlContext.Model.Save(mlModel, modelInputSchema, modelRelativePath);
            Console.WriteLine("The model is saved to {0}", modelRelativePath);
        }

        /// <summary>
        /// Prints accuracy, log-loss, the confusion matrix and the per-class F1 scores.
        /// </summary>
        private static void PrintMulticlassClassificationMetrics(MulticlassClassificationMetrics metrics)
        {
            Console.WriteLine($"************************************************************");
            Console.WriteLine($"* Metrics for multi-class classification model ");
            Console.WriteLine($"*-----------------------------------------------------------");
            Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
            Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
            Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
            for (int i = 0; i < metrics.PerClassLogLoss.Count; i++)
            {
                Console.WriteLine($" LogLoss for class {i} \t= {metrics.PerClassLogLoss[i]:0.####}, the closer to 0, the better");
            }

            Console.WriteLine(" " + metrics.ConfusionMatrix.GetFormattedConfusionTable());

            for (int i = 0; i < metrics.ConfusionMatrix.PerClassPrecision.Count; i++)
            {
                var precision = metrics.ConfusionMatrix.PerClassPrecision[i];
                var recall = metrics.ConfusionMatrix.PerClassRecall[i];
                // F1 is taken as 0 when precision and recall are both 0 (otherwise 0 / 0 prints NaN).
                var f1Score = precision + recall > 0 ? 2 * (precision * recall) / (precision + recall) : 0;
                Console.WriteLine($" F1 Score for class {i} \t= {f1Score:0.####}, a value between 0 and 1, the closer to 1, the better");
            }

            Console.WriteLine($"************************************************************");
        }

        /// <summary>
        /// Prints mean, standard deviation and 95% confidence half-width of each metric across the folds.
        /// </summary>
        private static void PrintMulticlassClassificationFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<MulticlassClassificationMetrics>> crossValResults)
        {
            var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics).ToList();

            var microAccuracyValues = metricsInMultipleFolds.Select(m => m.MicroAccuracy).ToList();
            var microAccuracyAverage = microAccuracyValues.Average();
            var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues);
            var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);

            var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.MacroAccuracy).ToList();
            var macroAccuracyAverage = macroAccuracyValues.Average();
            var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues);
            var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues);

            var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss).ToList();
            var logLossAverage = logLossValues.Average();
            var logLossStdDeviation = CalculateStandardDeviation(logLossValues);
            var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues);

            var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction).ToList();
            var logLossReductionAverage = logLossReductionValues.Average();
            var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);
            var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues);

            Console.WriteLine($"*************************************************************************************************************");
            Console.WriteLine($"* Metrics for Multi-class Classification model ");
            Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
            Console.WriteLine($"* Average MicroAccuracy: {microAccuracyAverage:0.###} - Standard deviation: ({microAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})");
            Console.WriteLine($"* Average MacroAccuracy: {macroAccuracyAverage:0.###} - Standard deviation: ({macroAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})");
            Console.WriteLine($"* Average LogLoss: {logLossAverage:#.###} - Standard deviation: ({logLossStdDeviation:#.###}) - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})");
            Console.WriteLine($"* Average LogLossReduction: {logLossReductionAverage:#.###} - Standard deviation: ({logLossReductionStdDeviation:#.###}) - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})");
            Console.WriteLine($"*************************************************************************************************************");
        }

        /// <summary>
        /// Sample standard deviation (n - 1 denominator). It is NaN for a single value.
        /// </summary>
        private static double CalculateStandardDeviation(IEnumerable<double> values)
        {
            var list = values.ToList();
            double average = list.Average();
            double sumOfSquaresOfDifferences = list.Sum(val => (val - average) * (val - average));
            return Math.Sqrt(sumOfSquaresOfDifferences / (list.Count - 1));
        }

        /// <summary>
        /// Half-width of the normal-approximation 95% confidence interval of the mean: 1.96 * s / sqrt(n),
        /// where s is the sample standard deviation.
        /// </summary>
        private static double CalculateConfidenceInterval95(IEnumerable<double> values)
        {
            var list = values.ToList();
            // The standard error of the mean divides by sqrt(n), not sqrt(n - 1).
            return 1.96 * CalculateStandardDeviation(list) / Math.Sqrt(list.Count);
        }
    }
}

--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.LightGbm/LightGbmModelBuilder.cs:
--------------------------------------------------------------------------------
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using Microsoft.ML.Trainers.LightGbm;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;

namespace PdfPigMLNetBlockClassifier.LightGbm
{
    public static class LightGbmModelBuilder
    {
        public static readonly string OutputFolderPath = @"../../../model";

        private static MLContext mlContext = new MLContext(seed: 1);

        public static void TrainModel(string trainDataFilePath, string outputModelName)
        {
            string outputFullPath = Path.Combine(OutputFolderPath, Path.ChangeExtension(outputModelName, "zip"));

            // Load Data
            IDataView trainingDataView = mlContext.Data.LoadFromTextFile(
                path: trainDataFilePath,
                hasHeader: true,
                separatorChar: ',',
allowQuoting: true, 28 | allowSparse: false); 29 | 30 | // Build training pipeline 31 | IEstimator trainingPipeline = BuildTrainingPipeline(mlContext); 32 | 33 | // Evaluate quality of Model 34 | CrossValidate(mlContext, trainingDataView, trainingPipeline); 35 | 36 | // Train Model 37 | ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline); 38 | 39 | // Save model 40 | SaveModel(mlContext, mlModel, outputFullPath, trainingDataView.Schema); 41 | } 42 | 43 | public static void Evaluate(string modelName, string testDataFilePath) 44 | { 45 | string modelFullPath = GetModelPath(modelName); 46 | 47 | // Create new MLContext 48 | MLContext mlContext = new MLContext(); 49 | 50 | // Load model & create prediction engine 51 | ITransformer mlModel = mlContext.Model.Load(modelFullPath, out var modelInputSchema); 52 | var predEngine = mlContext.Model.CreatePredictionEngine(mlModel); 53 | 54 | // Load Data 55 | IDataView testingDataView = mlContext.Data.LoadFromTextFile( 56 | path: testDataFilePath, 57 | hasHeader: true, 58 | separatorChar: ',', 59 | allowQuoting: true, 60 | allowSparse: false); 61 | 62 | IDataView transformedTestingDataView = mlModel.Transform(testingDataView); 63 | 64 | Evaluate(mlContext, transformedTestingDataView); 65 | 66 | // Permutation Feature Importance 67 | // https://docs.microsoft.com/en-us/dotnet/api/microsoft.ml.permutationfeatureimportanceextensions.permutationfeatureimportance?view=ml-dotnet#Microsoft_ML_PermutationFeatureImportanceExtensions_PermutationFeatureImportance__1_Microsoft_ML_MulticlassClassificationCatalog_Microsoft_ML_ISingleFeaturePredictionTransformer___0__Microsoft_ML_IDataView_System_String_System_Boolean_System_Nullable_System_Int32__System_Int32_ 68 | 69 | Console.WriteLine("=============== Permutation Feature Importance ==============="); 70 | Console.WriteLine(@"PFI works by taking a labeled dataset, choosing a feature, and permuting the values for that 71 | feature across all the examples, so that 
each example now has a random value for the feature 72 | and the original values for all other features. The evaluation metric (e.g. micro-accuracy) is 73 | then calculated for this modified dataset, and the change in the evaluation metric from the 74 | original dataset is computed. The larger the change in the evaluation metric, the more 75 | important the feature is to the model. PFI works by performing this permutation analysis 76 | across all the features of a model, one after another.\n"); 77 | 78 | // Get the column name of input features. 79 | string[] featureColumns = testingDataView.Schema.Select(column => column.Name) 80 | .Where(columnName => columnName != "label").ToArray(); 81 | 82 | var predictor = ((mlModel as TransformerChain).LastTransformer as TransformerChain) 83 | .First() as MulticlassPredictionTransformer; 84 | 85 | var pfi = mlContext.MulticlassClassification.PermutationFeatureImportance( 86 | predictor, 87 | transformedTestingDataView, 88 | labelColumnName: "label", 89 | permutationCount: 30); 90 | 91 | // Now let's look at which features are most important to the model 92 | // overall. Get the feature indices sorted by their impact on 93 | // microaccuracy. 94 | var sortedIndicesMicro = pfi.Select((metrics, index) => new { index, metrics.MicroAccuracy }) 95 | .OrderByDescending(feature => Math.Abs(feature.MicroAccuracy.Mean)) 96 | .Select(feature => feature.index); 97 | 98 | Console.WriteLine("Feature\tChange in MicroAccuracy\t95% Confidence in the Mean Change in MicroAccuracy"); 99 | var microAccuracy = pfi.Select(x => x.MicroAccuracy).ToArray(); 100 | 101 | foreach (int i in sortedIndicesMicro) 102 | { 103 | Console.WriteLine("{0}\t{1:G4}\t{2:G4}", 104 | featureColumns[i], 105 | microAccuracy[i].Mean, 106 | 1.96 * microAccuracy[i].StandardError); 107 | } 108 | 109 | Console.WriteLine(); 110 | 111 | // Now let's look at which features are most important to the model 112 | // overall. 
Get the feature indices sorted by their impact on 113 | // macroaccuracy. 114 | var sortedIndicesMacro = pfi.Select((metrics, index) => new { index, metrics.MacroAccuracy }) 115 | .OrderByDescending(feature => Math.Abs(feature.MacroAccuracy.Mean)) 116 | .Select(feature => feature.index); 117 | 118 | Console.WriteLine("Feature\tChange in MacroAccuracy\t95% Confidence in the Mean Change in MacroAccuracy"); 119 | var macroAccuracy = pfi.Select(x => x.MacroAccuracy).ToArray(); 120 | 121 | foreach (int i in sortedIndicesMacro) 122 | { 123 | Console.WriteLine("{0}\t{1:G4}\t{2:G4}", 124 | featureColumns[i], 125 | macroAccuracy[i].Mean, 126 | 1.96 * macroAccuracy[i].StandardError); 127 | } 128 | } 129 | 130 | public static string GetModelPath(string modelName) 131 | { 132 | return Path.Combine(OutputFolderPath, Path.ChangeExtension(modelName, "zip")); 133 | } 134 | 135 | private static IEstimator BuildTrainingPipeline(MLContext mlContext) 136 | { 137 | // Data process configuration with pipeline data transformations 138 | var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey("label", "label") 139 | .Append(mlContext.Transforms.Concatenate("Features", new[] { "charsCount", "pctNumericChars", "pctAlphabeticalChars", 140 | "pctSymbolicChars", "pctBulletChars", "deltaToHeight", 141 | "pathsCount", "pctBezierPaths", "pctHorPaths", 142 | "pctVertPaths", "pctOblPaths", "imagesCount", 143 | "imageAvgProportion" })); 144 | 145 | // Set the training algorithm 146 | var trainer = mlContext.MulticlassClassification.Trainers.LightGbm(new LightGbmMulticlassTrainer.Options() 147 | { 148 | NumberOfIterations = 150, 149 | LearningRate = 0.1158737f, 150 | NumberOfLeaves = 39, 151 | MinimumExampleCountPerLeaf = 50, 152 | UseCategoricalSplit = true, 153 | HandleMissingValue = false, 154 | MinimumExampleCountPerGroup = 50, 155 | MaximumCategoricalSplitPointCount = 32, 156 | CategoricalSmoothing = 10, 157 | L2CategoricalRegularization = 1, 158 | UseSoftmax = false, 159 | 
Booster = new GradientBooster.Options() 160 | { 161 | L2Regularization = 0, 162 | L1Regularization = 0 163 | }, 164 | LabelColumnName = "label", 165 | FeatureColumnName = "Features", 166 | //UnbalancedSets = true, // added by BobLd 167 | }).Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel")); 168 | 169 | var trainingPipeline = dataProcessPipeline.Append(trainer); 170 | 171 | return trainingPipeline; 172 | } 173 | 174 | private static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator trainingPipeline) 175 | { 176 | Console.WriteLine("=============== Training model ==============="); 177 | 178 | ITransformer model = trainingPipeline.Fit(trainingDataView); 179 | 180 | Console.WriteLine("=============== End of training process ==============="); 181 | return model; 182 | } 183 | 184 | private static void CrossValidate(MLContext mlContext, IDataView trainingDataView, IEstimator trainingPipeline) 185 | { 186 | // Cross-Validate with single dataset (since we don't have two datasets, one for training and for evaluate) 187 | // in order to evaluate and get the model's accuracy metrics 188 | Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ==============="); 189 | var crossValidationResults = mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: 5, labelColumnName: "label"); 190 | PrintMulticlassClassificationFoldsAverageMetrics(crossValidationResults); 191 | } 192 | 193 | private static void Evaluate(MLContext mlContext, IDataView testingDataView) 194 | { 195 | Console.WriteLine("=============== Evaluating to get model's accuracy metrics ==============="); 196 | var evaluationResults = mlContext.MulticlassClassification.Evaluate(testingDataView, labelColumnName: "label"); 197 | PrintMulticlassClassificationMetrics(evaluationResults); 198 | } 199 | 200 | private static void SaveModel(MLContext mlContext, ITransformer 
mlModel, string modelRelativePath, DataViewSchema modelInputSchema) 201 | { 202 | // Save/persist the trained model to a .ZIP file 203 | Console.WriteLine($"=============== Saving the model ==============="); 204 | mlContext.Model.Save(mlModel, modelInputSchema, modelRelativePath); 205 | Console.WriteLine("The model is saved to {0}", modelRelativePath); 206 | } 207 | 208 | private static void PrintMulticlassClassificationMetrics(MulticlassClassificationMetrics metrics) 209 | { 210 | Console.WriteLine($"************************************************************"); 211 | Console.WriteLine($"* Metrics for multi-class classification model "); 212 | Console.WriteLine($"*-----------------------------------------------------------"); 213 | Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better"); 214 | Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better"); 215 | Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better"); 216 | for (int i = 0; i < metrics.PerClassLogLoss.Count; i++) 217 | { 218 | Console.WriteLine($" LogLoss for class {i} \t= {metrics.PerClassLogLoss[i]:0.####}, the closer to 0, the better"); 219 | } 220 | 221 | Console.WriteLine(" " + metrics.ConfusionMatrix.GetFormattedConfusionTable()); 222 | 223 | for (int i = 0; i < metrics.ConfusionMatrix.PerClassPrecision.Count; i++) 224 | { 225 | var precision = metrics.ConfusionMatrix.PerClassPrecision[i]; 226 | var recall = metrics.ConfusionMatrix.PerClassRecall[i]; 227 | var f1Score = 2 * (precision * recall) / (precision + recall); 228 | Console.WriteLine($" F1 Score for class {i} \t= {f1Score:0.####}, a value between 0 and 1, the closer to 1, the better"); 229 | } 230 | 231 | Console.WriteLine($"************************************************************"); 232 | } 233 | 234 | private static void 
PrintMulticlassClassificationFoldsAverageMetrics(IEnumerable> crossValResults) 235 | { 236 | var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics); 237 | 238 | var microAccuracyValues = metricsInMultipleFolds.Select(m => m.MicroAccuracy); 239 | var microAccuracyAverage = microAccuracyValues.Average(); 240 | var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues); 241 | var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues); 242 | 243 | var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.MacroAccuracy); 244 | var macroAccuracyAverage = macroAccuracyValues.Average(); 245 | var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues); 246 | var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues); 247 | 248 | var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss); 249 | var logLossAverage = logLossValues.Average(); 250 | var logLossStdDeviation = CalculateStandardDeviation(logLossValues); 251 | var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues); 252 | 253 | var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction); 254 | var logLossReductionAverage = logLossReductionValues.Average(); 255 | var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues); 256 | var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues); 257 | 258 | Console.WriteLine($"*************************************************************************************************************"); 259 | Console.WriteLine($"* Metrics for Multi-class Classification model "); 260 | Console.WriteLine($"*------------------------------------------------------------------------------------------------------------"); 261 | Console.WriteLine($"* Average MicroAccuracy: {microAccuracyAverage:0.###} - Standard deviation: 
({microAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})"); 262 | Console.WriteLine($"* Average MacroAccuracy: {macroAccuracyAverage:0.###} - Standard deviation: ({macroAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})"); 263 | Console.WriteLine($"* Average LogLoss: {logLossAverage:#.###} - Standard deviation: ({logLossStdDeviation:#.###}) - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})"); 264 | Console.WriteLine($"* Average LogLossReduction: {logLossReductionAverage:#.###} - Standard deviation: ({logLossReductionStdDeviation:#.###}) - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})"); 265 | Console.WriteLine($"*************************************************************************************************************"); 266 | } 267 | 268 | private static double CalculateStandardDeviation(IEnumerable values) 269 | { 270 | double average = values.Average(); 271 | double sumOfSquaresOfDifferences = values.Select(val => (val - average) * (val - average)).Sum(); 272 | double standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (values.Count() - 1)); 273 | return standardDeviation; 274 | } 275 | 276 | private static double CalculateConfidenceInterval95(IEnumerable values) 277 | { 278 | double confidenceInterval95 = 1.96 * CalculateStandardDeviation(values) / Math.Sqrt((values.Count() - 1)); 279 | return confidenceInterval95; 280 | } 281 | } 282 | } 283 | --------------------------------------------------------------------------------