├── PdfPigMLNetBlockClassifier
│   ├── sample.pdf
│   ├── model
│   │   ├── model.zip
│   │   ├── modelV2.zip
│   │   └── results.txt
│   ├── PdfPigMLNetBlockClassifier.csproj
│   └── Program.cs
├── PdfPigMLNetBlockClassifier.Data
│   ├── PdfPigMLNetBlockClassifier.Data.csproj
│   ├── v1
│   │   ├── FeatureHelper.cs
│   │   └── DataGenerator.cs
│   └── v2
│       ├── FeatureHelper.cs
│       └── DataGenerator.cs
├── PdfPigMLNetBlockClassifier.LightGbm
│   ├── BlockCategory.cs
│   ├── PdfPigMLNetBlockClassifier.LightGbm.csproj
│   ├── BlockFeatures.cs
│   ├── LightGbmBlockClassifier.cs
│   ├── README.md
│   └── LightGbmModelBuilder.cs
├── PdfPigMLNetBlockClassifier.LightGbmV2
│   ├── BlockCategory.cs
│   ├── PdfPigMLNetBlockClassifier.LightGbmV2.csproj
│   ├── BlockFeatures.cs
│   ├── LightGbmBlockClassifier.cs
│   └── LightGbmModelBuilder.cs
├── .gitattributes
├── PdfPigMLNetBlockClassifier.sln
├── .gitignore
└── README.md
/PdfPigMLNetBlockClassifier/sample.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobLd/PdfPigMLNetBlockClassifier/HEAD/PdfPigMLNetBlockClassifier/sample.pdf
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier/model/model.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobLd/PdfPigMLNetBlockClassifier/HEAD/PdfPigMLNetBlockClassifier/model/model.zip
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier/model/modelV2.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobLd/PdfPigMLNetBlockClassifier/HEAD/PdfPigMLNetBlockClassifier/model/modelV2.zip
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.Data/PdfPigMLNetBlockClassifier.Data.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | netstandard2.0
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.LightGbm/BlockCategory.cs:
--------------------------------------------------------------------------------
1 | using Microsoft.ML.Data;
2 | using System;
3 |
namespace PdfPigMLNetBlockClassifier.LightGbm
{
    /// <summary>
    /// Output row of the block classification model, filled in by the ML.NET prediction engine.
    /// </summary>
    public class BlockCategory
    {
        /// <summary>
        /// Predicted label. The <see cref="ColumnNameAttribute"/> maps it to the model's
        /// "PredictedLabel" column instead of the default column name (the property name).
        /// </summary>
        [ColumnName("PredictedLabel")]
        public float Prediction { get; set; }

        /// <summary>
        /// Per-class scores produced by the model (read from the default "Score" column).
        /// </summary>
        public float[] Score { get; set; }
    }
}
16 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.LightGbmV2/BlockCategory.cs:
--------------------------------------------------------------------------------
1 | using Microsoft.ML.Data;
2 | using System;
3 |
namespace PdfPigMLNetBlockClassifier.LightGbmV2
{
    /// <summary>
    /// Output row of the v2 block classification model, filled in by the ML.NET prediction engine.
    /// </summary>
    public class BlockCategory
    {
        /// <summary>
        /// Predicted label. The <see cref="ColumnNameAttribute"/> maps it to the model's
        /// "PredictedLabel" column instead of the default column name (the property name).
        /// </summary>
        [ColumnName("PredictedLabel")]
        public float Prediction { get; set; }

        /// <summary>
        /// Per-class scores produced by the model (read from the default "Score" column).
        /// </summary>
        public float[] Score { get; set; }
    }
}
16 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.LightGbm/PdfPigMLNetBlockClassifier.LightGbm.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | netstandard2.0
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.LightGbmV2/PdfPigMLNetBlockClassifier.LightGbmV2.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | netstandard2.0
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier/PdfPigMLNetBlockClassifier.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Exe
5 | netcoreapp2.1
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 | Always
35 |
36 |
37 |
38 |
39 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.LightGbm/BlockFeatures.cs:
--------------------------------------------------------------------------------
1 | using Microsoft.ML.Data;
2 |
namespace PdfPigMLNetBlockClassifier.LightGbm
{
    /// <summary>
    /// Input row of the v1 block classification model: 13 features loaded from csv columns 0-12,
    /// followed by the training label in column 13. Property order matches the indices of the
    /// feature array built in LightGbmBlockClassifier.Classify (features[0] .. features[12]).
    /// </summary>
    public class BlockFeatures
    {
        /// <summary>csv column 0, model column <c>charsCount</c>.</summary>
        [LoadColumn(0)]
        [ColumnName("charsCount")]
        public float CharsCount { get; set; }

        /// <summary>csv column 1, model column <c>pctNumericChars</c>.</summary>
        [LoadColumn(1)]
        [ColumnName("pctNumericChars")]
        public float PctNumericChars { get; set; }

        /// <summary>csv column 2, model column <c>pctAlphabeticalChars</c>.</summary>
        [LoadColumn(2)]
        [ColumnName("pctAlphabeticalChars")]
        public float PctAlphabeticalChars { get; set; }

        /// <summary>csv column 3, model column <c>pctSymbolicChars</c>.</summary>
        [LoadColumn(3)]
        [ColumnName("pctSymbolicChars")]
        public float PctSymbolicChars { get; set; }

        /// <summary>csv column 4, model column <c>pctBulletChars</c>.</summary>
        [LoadColumn(4)]
        [ColumnName("pctBulletChars")]
        public float PctBulletChars { get; set; }

        /// <summary>csv column 5, model column <c>deltaToHeight</c>.</summary>
        [LoadColumn(5)]
        [ColumnName("deltaToHeight")]
        public float DeltaToHeight { get; set; }

        /// <summary>csv column 6, model column <c>pathsCount</c>.</summary>
        [LoadColumn(6)]
        [ColumnName("pathsCount")]
        public float PathsCount { get; set; }

        /// <summary>csv column 7, model column <c>pctBezierPaths</c>.</summary>
        [LoadColumn(7)]
        [ColumnName("pctBezierPaths")]
        public float PctBezierPaths { get; set; }

        /// <summary>csv column 8, model column <c>pctHorPaths</c>.</summary>
        [LoadColumn(8)]
        [ColumnName("pctHorPaths")]
        public float PctHorPaths { get; set; }

        /// <summary>csv column 9, model column <c>pctVertPaths</c>.</summary>
        [LoadColumn(9)]
        [ColumnName("pctVertPaths")]
        public float PctVertPaths { get; set; }

        /// <summary>csv column 10, model column <c>pctOblPaths</c>.</summary>
        [LoadColumn(10)]
        [ColumnName("pctOblPaths")]
        public float PctOblPaths { get; set; }

        /// <summary>csv column 11, model column <c>imagesCount</c>.</summary>
        [LoadColumn(11)]
        [ColumnName("imagesCount")]
        public float ImagesCount { get; set; }

        /// <summary>csv column 12, model column <c>imageAvgProportion</c>.</summary>
        [LoadColumn(12)]
        [ColumnName("imageAvgProportion")]
        public float ImageAvgProportion { get; set; }

        /// <summary>csv column 13, model column <c>label</c>: the training target (left unset at prediction time).</summary>
        [LoadColumn(13)]
        [ColumnName("label")]
        public float Label { get; set; }
    }
}
63 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.LightGbmV2/BlockFeatures.cs:
--------------------------------------------------------------------------------
1 | using Microsoft.ML.Data;
2 |
namespace PdfPigMLNetBlockClassifier.LightGbmV2
{
    /// <summary>
    /// Input row of the v2 block classification model: 17 features loaded from csv columns 0-16,
    /// followed by the training label in column 17. Property order matches the indices of the
    /// feature array built in LightGbmBlockClassifier.Classify (features[0] .. features[16]).
    /// </summary>
    public class BlockFeatures
    {
        /// <summary>csv column 0, model column <c>blockAspectRatio</c>.</summary>
        [LoadColumn(0)]
        [ColumnName("blockAspectRatio")]
        public float BlockAspectRatio { get; set; }

        /// <summary>csv column 1, model column <c>charsCount</c>.</summary>
        [LoadColumn(1)]
        [ColumnName("charsCount")]
        public float CharsCount { get; set; }

        /// <summary>csv column 2, model column <c>wordsCount</c>.</summary>
        [LoadColumn(2)]
        [ColumnName("wordsCount")]
        public float WordsCount { get; set; }

        /// <summary>csv column 3, model column <c>linesCount</c>.</summary>
        [LoadColumn(3)]
        [ColumnName("linesCount")]
        public float LinesCount { get; set; }

        /// <summary>csv column 4, model column <c>pctNumericChars</c>.</summary>
        [LoadColumn(4)]
        [ColumnName("pctNumericChars")]
        public float PctNumericChars { get; set; }

        /// <summary>csv column 5, model column <c>pctAlphabeticalChars</c>.</summary>
        [LoadColumn(5)]
        [ColumnName("pctAlphabeticalChars")]
        public float PctAlphabeticalChars { get; set; }

        /// <summary>csv column 6, model column <c>pctSymbolicChars</c>.</summary>
        [LoadColumn(6)]
        [ColumnName("pctSymbolicChars")]
        public float PctSymbolicChars { get; set; }

        /// <summary>csv column 7, model column <c>pctBulletChars</c>.</summary>
        [LoadColumn(7)]
        [ColumnName("pctBulletChars")]
        public float PctBulletChars { get; set; }

        /// <summary>csv column 8, model column <c>deltaToHeight</c>.</summary>
        [LoadColumn(8)]
        [ColumnName("deltaToHeight")]
        public float DeltaToHeight { get; set; }

        /// <summary>csv column 9, model column <c>pathsCount</c>.</summary>
        [LoadColumn(9)]
        [ColumnName("pathsCount")]
        public float PathsCount { get; set; }

        /// <summary>csv column 10, model column <c>pctBezierPaths</c>.</summary>
        [LoadColumn(10)]
        [ColumnName("pctBezierPaths")]
        public float PctBezierPaths { get; set; }

        /// <summary>csv column 11, model column <c>pctHorPaths</c>.</summary>
        [LoadColumn(11)]
        [ColumnName("pctHorPaths")]
        public float PctHorPaths { get; set; }

        /// <summary>csv column 12, model column <c>pctVertPaths</c>.</summary>
        [LoadColumn(12)]
        [ColumnName("pctVertPaths")]
        public float PctVertPaths { get; set; }

        /// <summary>csv column 13, model column <c>pctOblPaths</c>.</summary>
        [LoadColumn(13)]
        [ColumnName("pctOblPaths")]
        public float PctOblPaths { get; set; }

        /// <summary>csv column 14, model column <c>imagesCount</c>.</summary>
        [LoadColumn(14)]
        [ColumnName("imagesCount")]
        public float ImagesCount { get; set; }

        /// <summary>csv column 15, model column <c>imageAvgProportion</c>.</summary>
        [LoadColumn(15)]
        [ColumnName("imageAvgProportion")]
        public float ImageAvgProportion { get; set; }

        /// <summary>csv column 16, model column <c>bestNormEditDistance</c>.</summary>
        [LoadColumn(16)]
        [ColumnName("bestNormEditDistance")]
        public float BestNormEditDistance { get; set; }

        /// <summary>csv column 17, model column <c>label</c>: the training target (left unset at prediction time).</summary>
        [LoadColumn(17)]
        [ColumnName("label")]
        public float Label { get; set; }
    }
}
81 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Set default behavior to automatically normalize line endings.
3 | ###############################################################################
4 | * text=auto
5 |
6 | ###############################################################################
7 | # Set default behavior for command prompt diff.
8 | #
9 | # This is needed for earlier builds of msysgit that do not have it on by
10 | # default for csharp files.
11 | # Note: This is only used by command line
12 | ###############################################################################
13 | #*.cs diff=csharp
14 |
15 | ###############################################################################
16 | # Set the merge driver for project and solution files
17 | #
18 | # Merging from the command prompt will add diff markers to the files if there
19 | # are conflicts (Merging from VS is not affected by the settings below, in VS
20 | # the diff markers are never inserted). Diff markers may cause the following
21 | # file extensions to fail to load in VS. An alternative would be to treat
22 | # these files as binary and thus will always conflict and require user
23 | # intervention with every merge. To do so, just uncomment the entries below
24 | ###############################################################################
25 | #*.sln merge=binary
26 | #*.csproj merge=binary
27 | #*.vbproj merge=binary
28 | #*.vcxproj merge=binary
29 | #*.vcproj merge=binary
30 | #*.dbproj merge=binary
31 | #*.fsproj merge=binary
32 | #*.lsproj merge=binary
33 | #*.wixproj merge=binary
34 | #*.modelproj merge=binary
35 | #*.sqlproj merge=binary
36 | #*.wwaproj merge=binary
37 |
38 | ###############################################################################
39 | # behavior for image files
40 | #
41 | # image files are treated as binary by default.
42 | ###############################################################################
43 | #*.jpg binary
44 | #*.png binary
45 | #*.gif binary
46 |
47 | ###############################################################################
48 | # diff behavior for common document formats
49 | #
50 | # Convert binary document formats to text before diffing them. This feature
51 | # is only available from the command line. Turn it on by uncommenting the
52 | # entries below.
53 | ###############################################################################
54 | #*.doc diff=astextplain
55 | #*.DOC diff=astextplain
56 | #*.docx diff=astextplain
57 | #*.DOCX diff=astextplain
58 | #*.dot diff=astextplain
59 | #*.DOT diff=astextplain
60 | #*.pdf diff=astextplain
61 | #*.PDF diff=astextplain
62 | #*.rtf diff=astextplain
63 | #*.RTF diff=astextplain
64 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.LightGbm/LightGbmBlockClassifier.cs:
--------------------------------------------------------------------------------
1 | using Microsoft.ML;
2 | using PdfPigMLNetBlockClassifier.Data.v1;
3 | using System.Collections.Generic;
4 | using System.Linq;
5 | using UglyToad.PdfPig.Content;
6 | using UglyToad.PdfPig.DocumentLayoutAnalysis;
7 | using UglyToad.PdfPig.DocumentLayoutAnalysis.PageSegmenter;
8 | using UglyToad.PdfPig.Util;
9 |
namespace PdfPigMLNetBlockClassifier.LightGbm
{
    /// <summary>
    /// Classifies the text blocks of a PDF page with a trained ML.NET LightGBM model, using the v1
    /// feature set (<see cref="BlockFeatures"/>).
    /// </summary>
    /// <remarks>
    /// All predictions go through a single ML.NET prediction engine, which ML.NET documents as not
    /// thread-safe: do not share one instance of this class between threads.
    /// </remarks>
    public class LightGbmBlockClassifier
    {
        private readonly MLContext mlContext;
        private readonly ITransformer mlModel;
        private readonly PredictionEngine<BlockFeatures, BlockCategory> predEngine;

        /// <summary>
        /// Loads the trained model from <paramref name="modelPath"/> and creates the prediction engine.
        /// </summary>
        /// <param name="modelPath">Path to the saved ML.NET model (.zip).</param>
        public LightGbmBlockClassifier(string modelPath)
        {
            mlContext = new MLContext();
            // The loaded input schema is not needed: inputs are always BlockFeatures.
            mlModel = mlContext.Model.Load(modelPath, out _);
            predEngine = mlContext.Model.CreatePredictionEngine<BlockFeatures, BlockCategory>(mlModel);
        }

        /// <summary>
        /// Segments <paramref name="page"/> into text blocks and predicts a category for each block.
        /// </summary>
        /// <param name="page">The page to classify.</param>
        /// <param name="wordExtractor">Builds words from the page letters.</param>
        /// <param name="pageSegmenter">Groups the words into text blocks.</param>
        /// <returns>
        /// A lazily evaluated sequence with one tuple per block: the category name, the highest class
        /// score, and the block itself.
        /// </returns>
        /// <exception cref="System.ArgumentNullException">
        /// Thrown immediately (not on enumeration) when any argument is null.
        /// </exception>
        public IEnumerable<(string Prediction, float Score, TextBlock Block)> Classify(Page page, IWordExtractor wordExtractor, IPageSegmenter pageSegmenter)
        {
            // Validate eagerly: checks inside an iterator would only run once the caller starts enumerating.
            if (page == null)
            {
                throw new System.ArgumentNullException(nameof(page));
            }

            if (wordExtractor == null)
            {
                throw new System.ArgumentNullException(nameof(wordExtractor));
            }

            if (pageSegmenter == null)
            {
                throw new System.ArgumentNullException(nameof(pageSegmenter));
            }

            return ClassifyBlocks(page, wordExtractor, pageSegmenter);
        }

        // Lazy worker behind Classify: extracts words, segments them into blocks and yields one prediction per block.
        private IEnumerable<(string Prediction, float Score, TextBlock Block)> ClassifyBlocks(Page page, IWordExtractor wordExtractor, IPageSegmenter pageSegmenter)
        {
            var words = wordExtractor.GetWords(page.Letters);
            var blocks = pageSegmenter.GetBlocks(words);

            // The page images are identical for every block: fetch them once, when the first block arrives.
            List<IPdfImage> pageImages = null;

            foreach (var block in blocks)
            {
                if (pageImages == null)
                {
                    pageImages = page.GetImages().ToList();
                }

                var letters = block.TextLines.SelectMany(li => li.Words).SelectMany(w => w.Letters);
                var paths = FeatureHelper.GetPathsInside(block.BoundingBox, page.ExperimentalAccess.Paths);
                var images = FeatureHelper.GetImagesInside(block.BoundingBox, pageImages);
                var features = FeatureHelper.GetFeatures(page, block.BoundingBox, letters, paths, images);

                var blockFeatures = new BlockFeatures
                {
                    CharsCount = features[0],
                    PctNumericChars = features[1],
                    PctAlphabeticalChars = features[2],
                    PctSymbolicChars = features[3],
                    PctBulletChars = features[4],
                    DeltaToHeight = features[5],
                    PathsCount = features[6],
                    PctBezierPaths = features[7],
                    PctHorPaths = features[8],
                    PctVertPaths = features[9],
                    PctOblPaths = features[10],
                    ImagesCount = features[11],
                    ImageAvgProportion = features[12]
                };

                var result = predEngine.Predict(blockFeatures);

                // The predicted label is used as an index into the category names; the maximum score is
                // reported as the confidence of that prediction.
                yield return (FeatureHelper.Categories[(int)result.Prediction], result.Score.Max(), block);
            }
        }
    }
}
61 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio 15
4 | VisualStudioVersion = 15.0.28307.852
5 | MinimumVisualStudioVersion = 10.0.40219.1
6 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "PdfPigMLNetBlockClassifier", "PdfPigMLNetBlockClassifier\PdfPigMLNetBlockClassifier.csproj", "{1B5DA09B-E217-40C6-80CB-3BBA28F5AF39}"
7 | EndProject
8 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "PdfPigMLNetBlockClassifier.LightGbm", "PdfPigMLNetBlockClassifier.LightGbm\PdfPigMLNetBlockClassifier.LightGbm.csproj", "{42436C8B-339F-4020-A52D-F6787C267483}"
9 | EndProject
10 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "PdfPigMLNetBlockClassifier.Data", "PdfPigMLNetBlockClassifier.Data\PdfPigMLNetBlockClassifier.Data.csproj", "{D4147008-F6C7-4528-9621-219F41BC12DB}"
11 | EndProject
12 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "PdfPigMLNetBlockClassifier.LightGbmV2", "PdfPigMLNetBlockClassifier.LightGbmV2\PdfPigMLNetBlockClassifier.LightGbmV2.csproj", "{3267E733-7D48-4A11-AB73-1261ACBC8681}"
13 | EndProject
14 | Global
15 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
16 | Debug|Any CPU = Debug|Any CPU
17 | Release|Any CPU = Release|Any CPU
18 | EndGlobalSection
19 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
20 | {1B5DA09B-E217-40C6-80CB-3BBA28F5AF39}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
21 | {1B5DA09B-E217-40C6-80CB-3BBA28F5AF39}.Debug|Any CPU.Build.0 = Debug|Any CPU
22 | {1B5DA09B-E217-40C6-80CB-3BBA28F5AF39}.Release|Any CPU.ActiveCfg = Release|Any CPU
23 | {1B5DA09B-E217-40C6-80CB-3BBA28F5AF39}.Release|Any CPU.Build.0 = Release|Any CPU
24 | {42436C8B-339F-4020-A52D-F6787C267483}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
25 | {42436C8B-339F-4020-A52D-F6787C267483}.Debug|Any CPU.Build.0 = Debug|Any CPU
26 | {42436C8B-339F-4020-A52D-F6787C267483}.Release|Any CPU.ActiveCfg = Release|Any CPU
27 | {42436C8B-339F-4020-A52D-F6787C267483}.Release|Any CPU.Build.0 = Release|Any CPU
28 | {D4147008-F6C7-4528-9621-219F41BC12DB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
29 | {D4147008-F6C7-4528-9621-219F41BC12DB}.Debug|Any CPU.Build.0 = Debug|Any CPU
30 | {D4147008-F6C7-4528-9621-219F41BC12DB}.Release|Any CPU.ActiveCfg = Release|Any CPU
31 | {D4147008-F6C7-4528-9621-219F41BC12DB}.Release|Any CPU.Build.0 = Release|Any CPU
32 | {3267E733-7D48-4A11-AB73-1261ACBC8681}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
33 | {3267E733-7D48-4A11-AB73-1261ACBC8681}.Debug|Any CPU.Build.0 = Debug|Any CPU
34 | {3267E733-7D48-4A11-AB73-1261ACBC8681}.Release|Any CPU.ActiveCfg = Release|Any CPU
35 | {3267E733-7D48-4A11-AB73-1261ACBC8681}.Release|Any CPU.Build.0 = Release|Any CPU
36 | EndGlobalSection
37 | GlobalSection(SolutionProperties) = preSolution
38 | HideSolutionNode = FALSE
39 | EndGlobalSection
40 | GlobalSection(ExtensibilityGlobals) = postSolution
41 | SolutionGuid = {91D33AE2-5546-4275-B7C7-3B029863A3FB}
42 | EndGlobalSection
43 | EndGlobal
44 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier/Program.cs:
--------------------------------------------------------------------------------
1 | using Microsoft.ML.Data;
2 | using PdfPigMLNetBlockClassifier.Data.v2;
3 | using PdfPigMLNetBlockClassifier.LightGbmV2;
4 | using System;
5 | using System.Collections.Generic;
6 | using System.Linq;
7 | using UglyToad.PdfPig;
8 | using UglyToad.PdfPig.DocumentLayoutAnalysis.PageSegmenter;
9 | using UglyToad.PdfPig.DocumentLayoutAnalysis.WordExtractor;
10 | using UglyToad.PdfPig.Outline;
11 |
namespace PdfPigMLNetBlockClassifier
{
    /// <summary>
    /// Console driver for the LightGBM (v2) block classifier. Steps 1-3 (csv generation, training,
    /// evaluation) are kept commented out; step 4 loads the trained model and prints the predicted
    /// category, its score and the text of every block in <c>sample.pdf</c>.
    /// </summary>
    class Program
    {
        // Folders holding the extracted PubLayNet training / validation data (only used by step 1).
        static readonly string TRAIN_RAW_DATA_FILEPATH = @"D:\Datasets\Document Layout Analysis\PubLayNet\extracted\train";
        static readonly string TEST_RAW_DATA_FILEPATH = @"D:\Datasets\Document Layout Analysis\PubLayNet\extracted\val";

        // Feature csv files written by step 1 and read by steps 2 and 3.
        static readonly string TRAIN_DATA_FILENAME = "features_train_v2.csv";
        static readonly string TEST_DATA_FILENAME = "features_val_v2.csv";

        // File name of the trained model, resolved through LightGbmModelBuilder.GetModelPath.
        static readonly string MODEL_NAME = "modelV2.zip";

        static void Main(string[] args)
        {
            // 1. Convert pdf documents and their PAGE xml ground truth to csv files
            //DataGenerator.GetCsv(TEST_RAW_DATA_FILEPATH, 0, TEST_DATA_FILENAME); // testing
            //DataGenerator.GetCsv(TRAIN_RAW_DATA_FILEPATH, 0, TRAIN_DATA_FILENAME); // training

            // 2. Create the model
            //LightGbmModelBuilder.TrainModel(DataGenerator.GetDataPath(TRAIN_DATA_FILENAME), MODEL_NAME);

            // 3. Evaluate the model
            //LightGbmModelBuilder.Evaluate(MODEL_NAME, DataGenerator.GetDataPath(TEST_DATA_FILENAME));

            // 4. Load the trained classifier
            var lightGbmBlockClassifier = new LightGbmBlockClassifier(LightGbmModelBuilder.GetModelPath(MODEL_NAME));

            var nearestNeighbourWordExtractor = new NearestNeighbourWordExtractor();
            var recursiveXYCut = new RecursiveXYCut();

            using (var document = PdfDocument.Open("sample.pdf"))
            {
                // Flatten the outline once for the whole document. When there are no bookmarks,
                // `bookmarks` is null and a null list is passed on to the classifier.
                document.TryGetBookmarks(out Bookmarks bookmarks);
                List<DocumentBookmarkNode> documentBookmarkNodes = bookmarks?.GetNodes()
                    .OfType<DocumentBookmarkNode>()
                    .ToList();

                for (var i = 0; i < document.NumberOfPages; i++)
                {
                    var page = document.GetPage(i + 1);

                    // Enumerable.Average throws on an empty sequence, so skip pages without letters
                    // (e.g. scanned pages); they yield no words and therefore no blocks anyway.
                    if (page.Letters.Count == 0)
                    {
                        continue;
                    }

                    List<DocumentBookmarkNode> bookmarksNodes = documentBookmarkNodes?
                        .Where(b => b.PageNumber == page.Number)
                        .ToList();

                    var avgPageFontHeight = page.Letters.Select(l => l.GlyphRectangle.Height).Average();

                    var words = nearestNeighbourWordExtractor.GetWords(page.Letters);
                    var blocks = recursiveXYCut.GetBlocks(words, page.Width / 3.0);

                    // Page-level inputs do not depend on the block: fetch them once per page.
                    var pagePaths = page.ExperimentalAccess.Paths;
                    var pageImages = page.GetImages().ToList();

                    foreach (var block in blocks)
                    {
                        var paths = FeatureHelper.GetPathsInside(block.BoundingBox, pagePaths);
                        var images = FeatureHelper.GetImagesInside(block.BoundingBox, pageImages);

                        var pred = lightGbmBlockClassifier.Classify(block, paths, images, avgPageFontHeight, bookmarksNodes);

                        Console.WriteLine();
                        Console.WriteLine($"{pred.Prediction} [{pred.Score.ToString("0.0%")}]");
                        // FormKC compatibility normalization folds ligatures (e.g. "ﬁ") into plain letters.
                        Console.WriteLine(block.Text.Normalize(normalizationForm: System.Text.NormalizationForm.FormKC));
                    }
                }
            }

            // Keep the window open when run interactively; Console.ReadKey throws when stdin is redirected.
            if (!Console.IsInputRedirected)
            {
                Console.ReadKey();
            }
        }
    }
}
81 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier/model/results.txt:
--------------------------------------------------------------------------------
1 | =============== Cross-validating to get model's accuracy metrics ===============
2 | *************************************************************************************************************
3 | * Metrics for Multi-class Classification model
4 | *------------------------------------------------------------------------------------------------------------
5 | * Average MicroAccuracy: 0.931 - Standard deviation: (.001) - Confidence Interval 95%: (.001)
6 | * Average MacroAccuracy: 0.743 - Standard deviation: (.001) - Confidence Interval 95%: (.001)
7 | * Average LogLoss: .219 - Standard deviation: (.002) - Confidence Interval 95%: (.002)
8 | * Average LogLossReduction: .747 - Standard deviation: (.003) - Confidence Interval 95%: (.003)
9 | *************************************************************************************************************
10 | =============== Training model ===============
11 | =============== End of training process ===============
12 | =============== Saving the model ===============
13 | The model is saved to ../../../model\model.zip
14 | =============== Evaluating to get model's accuracy metrics ===============
15 | ************************************************************
16 | * Metrics for multi-class classification model
17 | *-----------------------------------------------------------
18 | MacroAccuracy = 0.7482, a value between 0 and 1, the closer to 1, the better
19 | MicroAccuracy = 0.9369, a value between 0 and 1, the closer to 1, the better
20 | LogLoss = 0.2092, the closer to 0, the better
21 | LogLoss for class 1 = 0.2156, the closer to 0, the better
22 | LogLoss for class 2 = 3.1245, the closer to 0, the better
23 | LogLoss for class 3 = 0.306, the closer to 0, the better
24 | LogLoss for class 4 = 0.1094, the closer to 0, the better
25 | LogLoss for class 5 = 0.2472, the closer to 0, the better
26 |
27 | Confusion table
28 | ||========================================
29 | PREDICTED || 0 | 1 | 2 | 3 | 4 | Recall
30 | TRUTH ||========================================
31 | 0 || 1,765 | 0 | 0 | 145 | 2 | 0.9231
32 | 1 || 1 | 3 | 2 | 273 | 0 | 0.0108
33 | 2 || 0 | 0 | 623 | 63 | 5 | 0.9016
34 | 3 || 242 | 0 | 13 | 8,709 | 1 | 0.9714
35 | 4 || 1 | 0 | 2 | 2 | 71 | 0.9342
36 | ||========================================
37 | Precision ||0.8785 |1.0000 |0.9734 |0.9475 |0.8987 |
38 |
39 | F1 Score for class 1 = 0.9003, a value between 0 and 1, the closer to 1, the better
40 | F1 Score for class 2 = 0.0213, a value between 0 and 1, the closer to 1, the better
41 | F1 Score for class 3 = 0.9361, a value between 0 and 1, the closer to 1, the better
42 | F1 Score for class 4 = 0.9593, a value between 0 and 1, the closer to 1, the better
43 | F1 Score for class 5 = 0.9161, a value between 0 and 1, the closer to 1, the better
44 | ************************************************************
45 | =============== Permutation Feature Importance ===============
46 | PFI works by taking a labeled dataset, choosing a feature, and permuting the values for that
47 | feature across all the examples, so that each example now has a random value for the feature
48 | and the original values for all other features. The evaluation metric (e.g. micro-accuracy) is
49 | then calculated for this modified dataset, and the change in the evaluation metric from the
50 | original dataset is computed. The larger the change in the evaluation metric, the more
51 | important the feature is to the model. PFI works by performing this permutation analysis
52 | across all the features of a model, one after another.
53 |
54 | Feature Change in MicroAccuracy 95% Confidence in the Mean Change in MicroAccuracy
55 | charsCount -0.2192 0.0008443
56 | pctNumericChars -0.04996 0.0004363
57 | deltaToHeight -0.04155 0.000428
58 | pctBulletChars -0.01571 0.0004034
59 | pctAlphabeticalChars -0.012 0.0003245
60 | pctSymbolicChars -0.01187 0.0004204
61 | pathsCount -0.01089 0.0002144
62 | pctHorPaths -0.002695 0.0001318
63 | imageAvgProportion -0.001895 3.909E-05
64 | pctVertPaths -0.001403 0.0001069
65 | pctOblPaths -0.0005032 3.612E-05
66 | pctBezierPaths -7.828E-05 1.563E-05
67 | imagesCount -2.796E-05 2.768E-05
68 |
69 |
70 | Feature Change in MacroAccuracy 95% Confidence in the Mean Change in MacroAccuracy
71 | charsCount -0.1906 0.001906
72 | pathsCount -0.07355 0.001211
73 | pctNumericChars -0.06516 0.0008356
74 | deltaToHeight -0.05476 0.001333
75 | pctOblPaths -0.01097 0.0002006
76 | pctAlphabeticalChars -0.009334 0.001237
77 | imageAvgProportion -0.00874 0.0001771
78 | pctVertPaths -0.005001 0.0003568
79 | pctHorPaths -0.004718 0.0004244
80 | pctSymbolicChars -0.001522 0.0008552
81 | pctBulletChars 0.0009088 0.0007747
82 | pctBezierPaths -0.0003737 5.705E-05
83 | imagesCount -3.805E-05 3.434E-05
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 |
4 | # User-specific files
5 | *.suo
6 | *.user
7 | *.userosscache
8 | *.sln.docstates
9 | *csv
10 | *txt
11 |
12 | # User-specific files (MonoDevelop/Xamarin Studio)
13 | *.userprefs
14 |
15 | # Build results
16 | [Dd]ebug/
17 | [Dd]ebugPublic/
18 | [Rr]elease/
19 | [Rr]eleases/
20 | x64/
21 | x86/
22 | bld/
23 | [Bb]in/
24 | [Oo]bj/
25 | [Ll]og/
26 |
27 | # Visual Studio 2015 cache/options directory
28 | .vs/
29 | # Uncomment if you have tasks that create the project's static files in wwwroot
30 | #wwwroot/
31 |
32 | # MSTest test Results
33 | [Tt]est[Rr]esult*/
34 | [Bb]uild[Ll]og.*
35 |
36 | # NUNIT
37 | *.VisualState.xml
38 | TestResult.xml
39 |
40 | # Build Results of an ATL Project
41 | [Dd]ebugPS/
42 | [Rr]eleasePS/
43 | dlldata.c
44 |
45 | # DNX
46 | project.lock.json
47 | project.fragment.lock.json
48 | artifacts/
49 |
50 | *_i.c
51 | *_p.c
52 | *_i.h
53 | *.ilk
54 | *.meta
55 | *.obj
56 | *.pch
57 | *.pdb
58 | *.pgc
59 | *.pgd
60 | *.rsp
61 | *.sbr
62 | *.tlb
63 | *.tli
64 | *.tlh
65 | *.tmp
66 | *.tmp_proj
67 | *.log
68 | *.vspscc
69 | *.vssscc
70 | .builds
71 | *.pidb
72 | *.svclog
73 | *.scc
74 |
75 | # Chutzpah Test files
76 | _Chutzpah*
77 |
78 | # Visual C++ cache files
79 | ipch/
80 | *.aps
81 | *.ncb
82 | *.opendb
83 | *.opensdf
84 | *.sdf
85 | *.cachefile
86 | *.VC.db
87 | *.VC.VC.opendb
88 |
89 | # Visual Studio profiler
90 | *.psess
91 | *.vsp
92 | *.vspx
93 | *.sap
94 |
95 | # TFS 2012 Local Workspace
96 | $tf/
97 |
98 | # Guidance Automation Toolkit
99 | *.gpState
100 |
101 | # ReSharper is a .NET coding add-in
102 | _ReSharper*/
103 | *.[Rr]e[Ss]harper
104 | *.DotSettings.user
105 |
106 | # JustCode is a .NET coding add-in
107 | .JustCode
108 |
109 | # TeamCity is a build add-in
110 | _TeamCity*
111 |
112 | # DotCover is a Code Coverage Tool
113 | *.dotCover
114 |
115 | # NCrunch
116 | _NCrunch_*
117 | .*crunch*.local.xml
118 | nCrunchTemp_*
119 |
120 | # MightyMoose
121 | *.mm.*
122 | AutoTest.Net/
123 |
124 | # Web workbench (sass)
125 | .sass-cache/
126 |
127 | # Installshield output folder
128 | [Ee]xpress/
129 |
130 | # DocProject is a documentation generator add-in
131 | DocProject/buildhelp/
132 | DocProject/Help/*.HxT
133 | DocProject/Help/*.HxC
134 | DocProject/Help/*.hhc
135 | DocProject/Help/*.hhk
136 | DocProject/Help/*.hhp
137 | DocProject/Help/Html2
138 | DocProject/Help/html
139 |
140 | # Click-Once directory
141 | publish/
142 |
143 | # Publish Web Output
144 | *.[Pp]ublish.xml
145 | *.azurePubxml
146 | # TODO: Comment the next line if you want to checkin your web deploy settings
147 | # but database connection strings (with potential passwords) will be unencrypted
148 | #*.pubxml
149 | *.publishproj
150 |
151 | # Microsoft Azure Web App publish settings. Comment the next line if you want to
152 | # checkin your Azure Web App publish settings, but sensitive information contained
153 | # in these scripts will be unencrypted
154 | PublishScripts/
155 |
156 | # NuGet Packages
157 | *.nupkg
158 | # The packages folder can be ignored because of Package Restore
159 | **/packages/*
160 | # except build/, which is used as an MSBuild target.
161 | !**/packages/build/
162 | # Uncomment if necessary however generally it will be regenerated when needed
163 | #!**/packages/repositories.config
164 | # NuGet v3's project.json files produces more ignoreable files
165 | *.nuget.props
166 | *.nuget.targets
167 |
168 | # Microsoft Azure Build Output
169 | csx/
170 | *.build.csdef
171 |
172 | # Microsoft Azure Emulator
173 | ecf/
174 | rcf/
175 |
176 | # Windows Store app package directories and files
177 | AppPackages/
178 | BundleArtifacts/
179 | Package.StoreAssociation.xml
180 | _pkginfo.txt
181 |
182 | # Visual Studio cache files
183 | # files ending in .cache can be ignored
184 | *.[Cc]ache
185 | # but keep track of directories ending in .cache
186 | !*.[Cc]ache/
187 |
188 | # Others
189 | ClientBin/
190 | ~$*
191 | *~
192 | *.dbmdl
193 | *.dbproj.schemaview
194 | *.jfm
195 | *.pfx
196 | *.publishsettings
197 | node_modules/
198 | orleans.codegen.cs
199 |
200 | # Since there are multiple workflows, uncomment next line to ignore bower_components
201 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
202 | #bower_components/
203 |
204 | # RIA/Silverlight projects
205 | Generated_Code/
206 |
207 | # Backup & report files from converting an old project file
208 | # to a newer Visual Studio version. Backup files are not needed,
209 | # because we have git ;-)
210 | _UpgradeReport_Files/
211 | Backup*/
212 | UpgradeLog*.XML
213 | UpgradeLog*.htm
214 |
215 | # SQL Server files
216 | *.mdf
217 | *.ldf
218 |
219 | # Business Intelligence projects
220 | *.rdl.data
221 | *.bim.layout
222 | *.bim_*.settings
223 |
224 | # Microsoft Fakes
225 | FakesAssemblies/
226 |
227 | # GhostDoc plugin setting file
228 | *.GhostDoc.xml
229 |
230 | # Node.js Tools for Visual Studio
231 | .ntvs_analysis.dat
232 |
233 | # Visual Studio 6 build log
234 | *.plg
235 |
236 | # Visual Studio 6 workspace options file
237 | *.opt
238 |
239 | # Visual Studio LightSwitch build output
240 | **/*.HTMLClient/GeneratedArtifacts
241 | **/*.DesktopClient/GeneratedArtifacts
242 | **/*.DesktopClient/ModelManifest.xml
243 | **/*.Server/GeneratedArtifacts
244 | **/*.Server/ModelManifest.xml
245 | _Pvt_Extensions
246 |
247 | # Paket dependency manager
248 | .paket/paket.exe
249 | paket-files/
250 |
251 | # FAKE - F# Make
252 | .fake/
253 |
254 | # JetBrains Rider
255 | .idea/
256 | *.sln.iml
257 |
258 | # CodeRush
259 | .cr/
260 |
261 | # Python Tools for Visual Studio (PTVS)
262 | __pycache__/
263 | *.pyc
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.LightGbmV2/LightGbmBlockClassifier.cs:
--------------------------------------------------------------------------------
1 | using Microsoft.ML;
2 | using PdfPigMLNetBlockClassifier.Data.v2;
3 | using System.Collections.Generic;
4 | using System.Linq;
5 | using UglyToad.PdfPig.Content;
6 | using UglyToad.PdfPig.Core;
7 | using UglyToad.PdfPig.DocumentLayoutAnalysis;
8 | using UglyToad.PdfPig.DocumentLayoutAnalysis.PageSegmenter;
9 | using UglyToad.PdfPig.Outline;
10 | using UglyToad.PdfPig.Util;
11 |
namespace PdfPigMLNetBlockClassifier.LightGbmV2
{
    /// <summary>
    /// Classifies PdfPig text blocks into one of the categories of
    /// <see cref="FeatureHelper.Categories"/> using a trained ML.NET model and the v2 features
    /// computed by <see cref="FeatureHelper.GetFeatures"/>.
    /// </summary>
    /// <remarks>
    /// All predictions go through a single ML.NET <c>PredictionEngine</c>, which ML.NET documents
    /// as not thread-safe: use one instance per thread, or synchronise calls externally.
    /// </remarks>
    public class LightGbmBlockClassifier
    {
        private readonly MLContext mlContext;
        private readonly ITransformer mlModel;
        private readonly PredictionEngine<BlockFeatures, BlockCategory> predEngine;

        /// <summary>
        /// Output schema of the underlying prediction engine.
        /// </summary>
        public DataViewSchema OutputSchema => predEngine?.OutputSchema;

        // Appears to map a FeatureHelper.Categories key (text=0, title=1, list=2, table=3, image=4)
        // to the class index ML.NET assigned to that category (title=0, image=1, text=2, list=3,
        // table=4, the class order reported in the repository README metrics).
        // NOTE(review): currently unused. Classify indexes FeatureHelper.Categories directly with the
        // predicted value, which is only right if the model maps PredictedLabel back to the original
        // label values (key-to-value). Confirm against the training pipeline.
        private static readonly int[] MLNetCategories = new int[] { 2, 0, 3, 4, 1 };

        /// <summary>
        /// Loads the trained model and creates the prediction engine.
        /// </summary>
        /// <param name="modelPath">Path to the saved ML.NET model file (e.g. model/modelV2.zip).</param>
        public LightGbmBlockClassifier(string modelPath)
        {
            mlContext = new MLContext();
            // The input schema is not needed: BlockFeatures defines the input columns.
            mlModel = mlContext.Model.Load(modelPath, out _);
            predEngine = mlContext.Model.CreatePredictionEngine<BlockFeatures, BlockCategory>(mlModel);
        }

        /// <summary>
        /// Classifies a single text block.
        /// </summary>
        /// <param name="textBlock">The block to classify.</param>
        /// <param name="paths">Paths used for the path features (the page overload passes the paths inside the block's bounding box).</param>
        /// <param name="images">Images used for the image features (the page overload passes the images inside the block's bounding box).</param>
        /// <param name="averagePageFontHeight">Average glyph height of the page's letters, used for the deltaToHeight feature.</param>
        /// <param name="pageBookmarksNodes">Bookmarks pointing to the block's page, or null when unavailable.</param>
        /// <returns>The predicted category name and the highest class score.</returns>
        public (string Prediction, float Score) Classify(TextBlock textBlock, IEnumerable<PdfPath> paths, IEnumerable<IPdfImage> images,
            double averagePageFontHeight, List<DocumentBookmarkNode> pageBookmarksNodes)
        {
            var features = FeatureHelper.GetFeatures(
                textBlock, paths,
                images, averagePageFontHeight,
                textBlock.BoundingBox.Area,
                pageBookmarksNodes);

            return Predict(features);
        }

        /// <summary>
        /// Segments a page into text blocks and classifies each of them.
        /// </summary>
        /// <param name="page">The page to process.</param>
        /// <param name="wordExtractor">Word extractor applied to the page's letters.</param>
        /// <param name="pageSegmenter">Page segmenter that groups the words into blocks.</param>
        /// <param name="bookmarks">Document bookmarks, or null. Only document bookmarks pointing to this page are used.</param>
        /// <returns>
        /// One (category, score, block) tuple per block, produced lazily as the sequence is enumerated.
        /// A page without letters yields nothing.
        /// </returns>
        public IEnumerable<(string Prediction, float Score, TextBlock Block)> Classify(Page page, IWordExtractor wordExtractor,
            IPageSegmenter pageSegmenter, Bookmarks bookmarks = null)
        {
            // Average() below throws on an empty sequence, and a page without letters has no
            // blocks to classify anyway.
            if (page.Letters.Count == 0)
            {
                yield break;
            }

            List<DocumentBookmarkNode> bookmarksNodes = bookmarks?.GetNodes()
                .OfType<DocumentBookmarkNode>()
                .Where(b => b.PageNumber == page.Number)
                .ToList();

            var avgPageFontHeight = page.Letters.Select(l => l.GlyphRectangle.Height).Average();

            // Loop-invariant: read the page's images once rather than once per block.
            var pageImages = page.GetImages().ToList();

            var words = wordExtractor.GetWords(page.Letters);
            var blocks = pageSegmenter.GetBlocks(words);

            foreach (var block in blocks)
            {
                var paths = FeatureHelper.GetPathsInside(block.BoundingBox, page.ExperimentalAccess.Paths);
                var images = FeatureHelper.GetImagesInside(block.BoundingBox, pageImages);

                var features = FeatureHelper.GetFeatures(
                    block, paths,
                    images, avgPageFontHeight,
                    block.BoundingBox.Area,
                    bookmarksNodes);

                var (prediction, score) = Predict(features);
                yield return (prediction, score, block);
            }
        }

        /// <summary>
        /// Maps the feature vector returned by FeatureHelper.GetFeatures (same order as
        /// FeatureHelper.Header) onto BlockFeatures and runs the model on it.
        /// </summary>
        /// <param name="features">The 17 v2 features, in FeatureHelper.Header order.</param>
        /// <returns>The predicted category name and the highest class score.</returns>
        private (string Prediction, float Score) Predict(float[] features)
        {
            BlockFeatures blockFeatures = new BlockFeatures()
            {
                BlockAspectRatio = features[0],
                CharsCount = features[1],
                WordsCount = features[2],
                LinesCount = features[3],
                PctNumericChars = features[4],
                PctAlphabeticalChars = features[5],
                PctSymbolicChars = features[6],
                PctBulletChars = features[7],
                DeltaToHeight = features[8],
                PathsCount = features[9],
                PctBezierPaths = features[10],
                PctHorPaths = features[11],
                PctVertPaths = features[12],
                PctOblPaths = features[13],
                ImagesCount = features[14],
                ImageAvgProportion = features[15],
                BestNormEditDistance = features[16],
            };

            var result = predEngine.Predict(blockFeatures);

            return (FeatureHelper.Categories[(int)result.Prediction], result.Score.Max());
        }
    }
}
125 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.Data/v1/FeatureHelper.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.Linq;
4 | using UglyToad.PdfPig.Content;
5 | using UglyToad.PdfPig.Core;
6 | using static UglyToad.PdfPig.Core.PdfPath;
7 |
namespace PdfPigMLNetBlockClassifier.Data.v1
{
    /// <summary>
    /// Computes the v1 block-classifier features (letter, path and image statistics) and
    /// provides helpers that select the page content lying inside a block.
    /// </summary>
    public static class FeatureHelper
    {
        /// <summary>
        /// Maps a label index to its category name.
        /// </summary>
        public static readonly Dictionary<int, string> Categories = new Dictionary<int, string>()
        {
            { 0, "text" },
            { 1, "title" },
            { 2, "list" },
            { 3, "table" },
            { 4, "image" },
        };

        // Characters counted by the pctBulletChars feature. Note that this includes the plain
        // letter 'o'. A set gives an O(1) lookup per character.
        private static readonly HashSet<char> Bullets = new HashSet<char>
        {
            '•', 'o', '▪', '❖', '➢', '►', '✓', '➔', '⇨', '➪',
            '➨', '➫', '➬', '➭', '➮', '➯', '➱', '➲', '\u2023',
            '\u2043', '\u204C', '\u204D'
        };

        /// <summary>
        /// Computes the 13 v1 features of a block.
        /// </summary>
        /// <param name="page">Page containing the block. Its letters give the average page glyph height.</param>
        /// <param name="bbox">Block bounding box, used to normalise the image area.</param>
        /// <param name="letters">Letters of the block, may be null.</param>
        /// <param name="paths">Paths of the block, may be null.</param>
        /// <param name="images">Images of the block, may be null.</param>
        /// <returns>
        /// In order: charsCount, pctNumericChars, pctAlphabeticalChars, pctSymbolicChars,
        /// pctBulletChars, deltaToHeight, pathsCount, pctBezierPaths, pctHorPaths, pctVertPaths,
        /// pctOblPaths, imagesCount, imageAvgProportion. Ratios are rounded to 5 decimals.
        /// </returns>
        public static float[] GetFeatures(Page page, PdfRectangle bbox, IEnumerable<Letter> letters, IEnumerable<PdfPath> paths, IEnumerable<IPdfImage> images)
        {
            // Materialise each input once. They may be deferred LINQ queries (GetLettersInside,
            // GetPathsInside and GetImagesInside all return one), and each is enumerated several
            // times below.
            List<Letter> letterList = letters?.ToList();
            List<PdfPath> pathList = paths?.ToList();
            List<IPdfImage> imageList = images?.ToList();

            // Letters features
            float charsCount = 0;
            float pctNumericChars = 0;
            float pctAlphabeticalChars = 0;
            float pctSymbolicChars = 0;
            float pctBulletChars = 0;
            float deltaToHeight = -1; // stays -1 when there are no letters or the page's average glyph height is 0

            if (letterList != null && letterList.Count > 0)
            {
                var avgHeight = page.Letters.Select(l => l.GlyphRectangle.Height).Average();

                char[] chars = letterList.SelectMany(l => l.Value).ToArray();

                // NOTE(review): if every letter has an empty Value, charsCount is 0 and the ratios
                // below evaluate to NaN (0 / 0).
                charsCount = chars.Length;
                pctNumericChars = (float)Math.Round(chars.Count(c => char.IsNumber(c)) / charsCount, 5);
                pctAlphabeticalChars = (float)Math.Round(chars.Count(c => char.IsLetter(c)) / charsCount, 5);
                pctSymbolicChars = (float)Math.Round(chars.Count(c => !char.IsLetterOrDigit(c)) / charsCount, 5);
                pctBulletChars = (float)Math.Round(chars.Count(c => Bullets.Contains(c)) / charsCount, 5);

                // Ratio of the block's average glyph height to the page's average glyph height.
                deltaToHeight = avgHeight != 0 ? (float)Math.Round(letterList.Select(l => l.GlyphRectangle.Height).Average() / avgHeight, 5) : -1;
            }

            // Paths features. Each Line or BezierCurve command counts as one "path".
            float pathsCount = 0;
            float pctBezierPaths = 0;
            float pctHorPaths = 0;
            float pctVertPaths = 0;
            float pctOblPaths = 0;

            if (pathList != null && pathList.Count > 0)
            {
                foreach (var command in pathList.SelectMany(p => p.Commands))
                {
                    if (command is BezierCurve)
                    {
                        pathsCount++;
                        pctBezierPaths++;
                    }
                    else if (command is Line line)
                    {
                        pathsCount++;
                        if (line.From.X == line.To.X)
                        {
                            pctVertPaths++;
                        }
                        else if (line.From.Y == line.To.Y)
                        {
                            pctHorPaths++;
                        }
                        else
                        {
                            pctOblPaths++;
                        }
                    }
                    // Any other command type is ignored.
                }

                // NOTE(review): when no command is a Line or BezierCurve, pathsCount is 0 and these
                // ratios evaluate to NaN (0 / 0).
                pctBezierPaths = (float)Math.Round(pctBezierPaths / pathsCount, 5);
                pctHorPaths = (float)Math.Round(pctHorPaths / pathsCount, 5);
                pctVertPaths = (float)Math.Round(pctVertPaths / pathsCount, 5);
                pctOblPaths = (float)Math.Round(pctOblPaths / pathsCount, 5);
            }

            // Images features
            float imagesCount = 0;
            float imageAvgProportion = 0;

            if (imageList != null && imageList.Count > 0)
            {
                imagesCount = imageList.Count;
                // Average image area relative to the block's area.
                imageAvgProportion = (float)(imageList.Average(i => i.Bounds.Area) / bbox.Area);
            }

            return new float[]
            {
                charsCount, pctNumericChars, pctAlphabeticalChars, pctSymbolicChars, pctBulletChars, deltaToHeight,
                pathsCount, pctBezierPaths, pctHorPaths, pctVertPaths, pctOblPaths,
                imagesCount, imageAvgProportion
            };
        }

        /// <summary>
        /// Lazily selects the letters whose glyph rectangle lies entirely inside <paramref name="bound"/>.
        /// </summary>
        public static IEnumerable<Letter> GetLettersInside(PdfRectangle bound, IEnumerable<Letter> letters)
        {
            return letters.Where(l => l.GlyphRectangle.Left >= bound.Left &&
                                      l.GlyphRectangle.Right <= bound.Right &&
                                      l.GlyphRectangle.Bottom >= bound.Bottom &&
                                      l.GlyphRectangle.Top <= bound.Top);
        }

        /// <summary>
        /// Lazily selects the images whose bounds lie entirely inside <paramref name="bound"/>.
        /// </summary>
        public static IEnumerable<IPdfImage> GetImagesInside(PdfRectangle bound, IEnumerable<IPdfImage> images)
        {
            return images.Where(b => b.Bounds.Left >= bound.Left &&
                                     b.Bounds.Right <= bound.Right &&
                                     b.Bounds.Bottom >= bound.Bottom &&
                                     b.Bounds.Top <= bound.Top);
        }

        /// <summary>
        /// Lazily selects the paths that have a bounding rectangle lying entirely inside <paramref name="bound"/>.
        /// Paths without a bounding rectangle are skipped.
        /// </summary>
        public static IEnumerable<PdfPath> GetPathsInside(PdfRectangle bound, IEnumerable<PdfPath> paths)
        {
            // Evaluate GetBoundingRectangle() once per path instead of once per comparison.
            return paths.Select(p => (Path: p, Box: p.GetBoundingRectangle()))
                        .Where(x => x.Box.HasValue &&
                                    x.Box.Value.Left >= bound.Left &&
                                    x.Box.Value.Right <= bound.Right &&
                                    x.Box.Value.Bottom >= bound.Bottom &&
                                    x.Box.Value.Top <= bound.Top)
                        .Select(x => x.Path);
        }
    }
}
139 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.LightGbm/README.md:
--------------------------------------------------------------------------------
1 | # Results
2 | Results are based on PubLayNet's validation dataset, where the page segmentation is known. For real-life use, a page segmenter will be needed (see PdfPig's [PageSegmenters](https://github.com/UglyToad/PdfPig/tree/master/src/UglyToad.PdfPig.DocumentLayoutAnalysis/PageSegmenter)). The quality of the page segmentation will impact the results.
3 | ## Metrics for multi-class classification model
4 | ```
5 | MicroAccuracy = 0.9369, a value between 0 and 1, the closer to 1, the better
6 | MacroAccuracy = 0.7482, a value between 0 and 1, the closer to 1, the better
7 | LogLoss = 0.2092, the closer to 0, the better
8 |
9 | LogLoss for class 0 (title) = 0.2156, the closer to 0, the better
10 | LogLoss for class 1 (list) = 3.1245, the closer to 0, the better
11 | LogLoss for class 2 (table) = 0.3060, the closer to 0, the better
12 | LogLoss for class 3 (text) = 0.1094, the closer to 0, the better
13 | LogLoss for class 4 (image) = 0.2472, the closer to 0, the better
14 |
15 | F1 Score for class 0 (title) = 0.9003, a value between 0 and 1, the closer to 1, the better
16 | F1 Score for class 1 (list) = 0.0213, a value between 0 and 1, the closer to 1, the better
17 | F1 Score for class 2 (table) = 0.9361, a value between 0 and 1, the closer to 1, the better
18 | F1 Score for class 3 (text) = 0.9593, a value between 0 and 1, the closer to 1, the better
19 | F1 Score for class 4 (image) = 0.9161, a value between 0 and 1, the closer to 1, the better
20 | ```
21 |
22 | ## Confusion table
23 | ```
24 | ||=======================================================
25 | PREDICTED || 0 | 1 | 2 | 3 | 4 | Total | Recall
26 | TRUTH ||=======================================================
27 | (title) 0 || 1,765 | 0 | 0 | 145 | 2 | 1,912 | 0.9231
28 | (list) 1 || 1 | 3 | 2 | 273 | 0 | 279 | 0.0108
29 | (table) 2 || 0 | 0 | 623 | 63 | 5 | 691 | 0.9016
30 | (text) 3 || 242 | 0 | 13 | 8,709 | 1 | 8,965 | 0.9714
31 | (image) 4 || 1 | 0 | 2 | 2 | 71 | 76 | 0.9342
32 | ||=======================================================
33 | Precision ||0.8785 |1.0000 |0.9734 |0.9475 |0.8987 |
34 | ```
35 |
36 | ## Permutation Feature Importance
37 | PFI works by taking a labeled dataset, choosing a feature, and permuting the values for that
38 | feature across all the examples, so that each example now has a random value for the feature
39 | and the original values for all other features. The evaluation metric (e.g. micro-accuracy) is
40 | then calculated for this modified dataset, and the change in the evaluation metric from the
41 | original dataset is computed. The larger the change in the evaluation metric, the more
42 | important the feature is to the model. PFI works by performing this permutation analysis
43 | across all the features of a model, one after another. - [Source]( https://docs.microsoft.com/en-us/dotnet/api/microsoft.ml.permutationfeatureimportanceextensions.permutationfeatureimportance?view=ml-dotnet#Microsoft_ML_PermutationFeatureImportanceExtensions_PermutationFeatureImportance__1_Microsoft_ML_MulticlassClassificationCatalog_Microsoft_ML_ISingleFeaturePredictionTransformer___0__Microsoft_ML_IDataView_System_String_System_Boolean_System_Nullable_System_Int32__System_Int32_)
44 |
45 | ### Micro Accuracy
46 |
47 | |Feature | Description | Change in MicroAccuracy | 95% Confidence in the Mean Change in MicroAccuracy |
48 | |------------------------:|:-----------:|:-----------------------:|:--------------------------------------------------:|
49 | |**charsCount** |Characters count|-0.2192 |0.0008443 |
50 | |**pctNumericChars** |% of numeric characters|-0.04996 |0.0004363 |
51 | |**deltaToHeight** |Ratio of the block's average glyph height to the page's average glyph height|-0.04155 |0.000428|
52 | |**pctBulletChars** |% of bullet characters|-0.01571 |0.0004034|
53 | |**pctAlphabeticalChars** |% of alphabetical characters|-0.012 |0.0003245 |
54 | |**pctSymbolicChars** |% of symbolic characters|-0.01187 |0.0004204 |
55 | |**pathsCount** |Paths count|-0.01089 |0.0002144 |
56 | |**pctHorPaths** |% of horizontal paths|-0.002695 |0.0001318 |
57 | |**imageAvgProportion** |Average area covered by images|-0.001895 |0.00003909 |
58 | |**pctVertPaths** |% of vertical paths|-0.001403 |0.0001069 |
59 | |**pctOblPaths** |% of oblique paths|-0.0005032 |0.00003612 |
60 | |**pctBezierPaths** |% of Bezier curve paths|-0.00007828 |0.00001563 |
61 | |**imagesCount** |Images count|-0.00002796 |0.00002768 |
62 |
63 | ### Macro Accuracy
64 |
65 | |Feature | Description | Change in MacroAccuracy | 95% Confidence in the Mean Change in MacroAccuracy |
66 | |------------------------:|:-----------:|:-----------------------:|:--------------------------------------------------:|
67 | |**charsCount** |Characters count|-0.1906 |0.001906|
68 | |**pathsCount** |Paths count|-0.07355 |0.001211|
69 | |**pctNumericChars** |% of numeric characters|-0.06516 |0.0008356|
70 | |**deltaToHeight** |Ratio of the block's average glyph height to the page's average glyph height|-0.05476 |0.001333|
71 | |**pctOblPaths** |% of oblique paths|-0.01097 |0.0002006|
72 | |**pctAlphabeticalChars** |% of alphabetical characters|-0.009334 |0.001237|
73 | |**imageAvgProportion** |Average area covered by images|-0.00874 |0.0001771|
74 | |**pctVertPaths** |% of vertical paths|-0.005001 |0.0003568|
75 | |**pctHorPaths** |% of horizontal paths|-0.004718 |0.0004244|
76 | |**pctSymbolicChars** |% of symbolic characters|-0.001522 |0.0008552|
77 | |**pctBulletChars** |% of bullet characters|0.0009088 |0.0007747|
78 | |**pctBezierPaths** |% of Bezier curve paths|-0.0003737 |0.00005705|
79 | |**imagesCount** |Images count|-0.00003805 |0.00003434|
80 |
81 | # TO DO
82 | ## Features
83 | - Add a [decoration](https://github.com/UglyToad/PdfPig/blob/master/src/UglyToad.PdfPig.DocumentLayoutAnalysis/DecorationTextBlockClassifier.cs) score/flag
84 | - Add the block's number of lines, [cf.](http://www.cs.rug.nl/~aiellom/publications/ijdar.pdf)
85 | - Add block's aspect ratio: the ratio between width and height of the bounding box, [cf.](http://www.cs.rug.nl/~aiellom/publications/ijdar.pdf)
86 | - Add block's area ratio: the ratio between block area and the page area, [cf.](http://www.cs.rug.nl/~aiellom/publications/ijdar.pdf)
87 | - Add block's font style: an enumerated type, with possible values: regular, bold, italic, underline, [cf.](http://www.cs.rug.nl/~aiellom/publications/ijdar.pdf)
88 | - Use [bookmarks](https://github.com/UglyToad/PdfPig/blob/master/src/UglyToad.PdfPig/Outline/BookmarksProvider.cs) when available with minimum edit distance score
89 | - Add % sparse lines in a block, for better table recognition [cf.](https://clgiles.ist.psu.edu/pubs/CIKM2008-table-boundaries.pdf)
90 | - Font color distance from most common color
91 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PdfPig ML.Net Block Classifier v2
2 | Proof of concept of training a simple Region Classifier using [PdfPig](https://github.com/UglyToad/PdfPig) and [ML.NET](https://github.com/dotnet/machinelearning).
3 | The objective is to classify each text block in a __pdf document__ page as either __title__, __text__, __list__, __table__ or __image__.
4 |
5 | The [AutoML](https://docs.microsoft.com/en-us/dotnet/machine-learning/automate-training-with-model-builder) Model Builder was used. The model was
6 | trained on a subset of the [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet#getting-data) dataset. See their license [here](https://cdla.io/permissive-1-0/).
7 |
8 | | Generation | MicroAccuracy | MacroAccuracy |
9 | |-----------:|:-------------:|:-------------:|
10 | | **v1** | 0.937 | 0.748 |
11 | | **v2** | 0.952 | 0.801 |
12 |
13 | For v1 model results, see [here](https://github.com/BobLd/PdfPigMLNetBlockClassifier/blob/master/PdfPigMLNetBlockClassifier.LightGbm/README.md).
14 |
15 | # Results
16 | Results are based on PubLayNet's validation dataset, where the page segmentation is known. For real-life use, a page segmenter will be needed (see PdfPig's [PageSegmenters](https://github.com/UglyToad/PdfPig/tree/master/src/UglyToad.PdfPig.DocumentLayoutAnalysis/PageSegmenter)). The quality of the page segmentation will impact the results.
17 | ## Metrics for multi-class classification model
18 | ```
19 | MicroAccuracy = 0.9523, a value between 0 and 1, the closer to 1, the better
20 | MacroAccuracy = 0.8009, a value between 0 and 1, the closer to 1, the better
21 | LogLoss = 3.5333, the closer to 0, the better
22 |
23 | LogLoss for class 0 (title) = 4.6289, the closer to 0, the better
24 | LogLoss for class 1 (image) = 1.849, the closer to 0, the better
25 | LogLoss for class 2 (text) = 2.5834, the closer to 0, the better
26 | LogLoss for class 3 (list) = 28.8412, the closer to 0, the better
27 | LogLoss for class 4 (table) = 2.8326, the closer to 0, the better
28 |
29 | F1 Score for class 0 (title) = 0.9305, a value between 0 and 1, the closer to 1, the better
30 | F1 Score for class 1 (image) = 0.9412, a value between 0 and 1, the closer to 1, the better
31 | F1 Score for class 2 (text) = 0.9691, a value between 0 and 1, the closer to 1, the better
32 | F1 Score for class 3 (list) = 0.2963, a value between 0 and 1, the closer to 1, the better
33 | F1 Score for class 4 (table) = 0.9611, a value between 0 and 1, the closer to 1, the better
34 | ```
35 |
36 | ## Confusion table
37 | ```
38 | ||=======================================================
39 | PREDICTED || 0 | 1 | 2 | 3 | 4 | Total | Recall
40 | TRUTH ||=======================================================
41 | (title) 0 || 1,848 | 0 | 89 | 1 | 0 | 1,938 | 0.9536
42 | (image) 1 || 0 | 72 | 3 | 0 | 1 | 76 | 0.9474
43 | (text) 2 || 185 | 1 | 8,837 | 16 | 9 | 9,048 | 0.9767
44 | (list) 3 || 1 | 0 | 225 | 52 | 2 | 280 | 0.1857
45 | (table) 4 || 0 | 4 | 35 | 2 | 654 | 695 | 0.9410
46 | ||=======================================================
47 | Precision ||0.9086 |0.9351 |0.9617 |0.7324 |0.9820 |
48 | ```
49 |
50 | ## Permutation Feature Importance
51 | PFI works by taking a labeled dataset, choosing a feature, and permuting the values for that
52 | feature across all the examples, so that each example now has a random value for the feature
53 | and the original values for all other features. The evaluation metric (e.g. micro-accuracy) is
54 | then calculated for this modified dataset, and the change in the evaluation metric from the
55 | original dataset is computed. The larger the change in the evaluation metric, the more
56 | important the feature is to the model. PFI works by performing this permutation analysis
57 | across all the features of a model, one after another. - [Source]( https://docs.microsoft.com/en-us/dotnet/api/microsoft.ml.permutationfeatureimportanceextensions.permutationfeatureimportance?view=ml-dotnet#Microsoft_ML_PermutationFeatureImportanceExtensions_PermutationFeatureImportance__1_Microsoft_ML_MulticlassClassificationCatalog_Microsoft_ML_ISingleFeaturePredictionTransformer___0__Microsoft_ML_IDataView_System_String_System_Boolean_System_Nullable_System_Int32__System_Int32_)
58 |
59 | ### Micro Accuracy
60 |
61 | |Feature | Description | Change in MicroAccuracy | 95% Confidence in the Mean Change in MicroAccuracy |
62 | |------------------------:|:-----------:|:-----------------------:|:--------------------------------------------------:|
63 | |**wordsCount**| Words count |-0.1144 |0.0008237|
64 | |**charsCount**| Characters count |-0.09849|0.0009115|
65 | |**linesCount**| Lines count |-0.04255 |0.0005726|
66 | |**blockAspectRatio**|Ratio between the block's width and height|-0.04159|0.0005134|
67 | |**bestNormEditDistance**|Smallest normalised edit distance between the block text and the page's bookmark titles|-0.04016 | 0.0006307|
68 | |**deltaToHeight**|Ratio of the block's average glyph height to the page's average glyph height|-0.03689 | 0.0003823|
69 | |**pctNumericChars**|% of numeric characters|-0.03375 | 0.0005174|
70 | |**pctSymbolicChars**|% of symbolic characters|-0.01035 | 0.000358|
71 | |**pctAlphabeticalChars**|% of alphabetical characters| -0.00859 | 0.0003075|
72 | |**pctBulletChars**|% of bullet characters|-0.004071 | 0.0002063|
73 | |**pathsCount**|Paths count|-0.003661 | 0.0001184|
74 | |**pctHorPaths**|% of horizontal paths|-0.003431 | 0.0001309|
75 | |**imageAvgProportion**|Average area covered by images|-0.001684 | 4.68E-05|
76 | |**pctVertPaths**|% of vertical paths|-0.001448 | 8.139E-05|
77 | |**pctOblPaths**|% of oblique paths|-0.0001883 | 1.734E-05|
78 | |**imagesCount**|Images count|-9.692E-05 | 1.127E-05|
79 | |**pctBezierPaths**|% of Bezier curve paths|6.212E-22 | 1.352E-05|
80 |
81 | ### Macro Accuracy
82 |
83 | |Feature | Description | Change in MacroAccuracy | 95% Confidence in the Mean Change in MacroAccuracy |
84 | |------------------------:|:-----------:|:-----------------------:|:--------------------------------------------------:|
85 | |**charsCount**|Characters count| -0.1339 | 0.002263|
86 | |**linesCount**| Lines count | -0.08378 | 0.0009925|
87 | |**deltaToHeight**|Ratio of the block's average glyph height to the page's average glyph height| -0.06461 | 0.001223|
88 | |**blockAspectRatio**|Ratio between the block's width and height| -0.05896 | 0.001163|
89 | |**bestNormEditDistance**|Smallest normalised edit distance between the block text and the page's bookmark titles| -0.05039 | 0.001419|
90 | |**pctNumericChars**|% of numeric characters| -0.0475 | 0.001258|
91 | |**pathsCount**|Paths count| -0.0252 | 0.000614|
92 | |**wordsCount**|Words count| -0.01651 | 0.001655|
93 | |**pctSymbolicChars**|% of symbolic characters| -0.01166 | 0.001505|
94 | |**pctHorPaths**|% of horizontal paths| -0.00984 | 0.0003186|
95 | |**pctVertPaths**|% of vertical paths| -0.008558 | 0.0002925|
96 | |**pctAlphabeticalChars**|% of alphabetical characters| -0.006882 | 0.0009913|
97 | |**imageAvgProportion**|Average area covered by images| -0.006148 | 0.0001146|
98 | |**pctBulletChars**|% of bullet characters| -0.005319 | 0.0008375|
99 | |**pctOblPaths**|% of oblique paths| -0.002747 | 6.657E-05|
100 | |**imagesCount**|Images count| -0.00268 | 3.903E-05|
101 | |**pctBezierPaths**|% of Bezier curve paths| -4.078E-05 | 5.263E-05|
102 |
103 | # TO DO
104 | ## Features
105 | - Add a [decoration](https://github.com/UglyToad/PdfPig/blob/master/src/UglyToad.PdfPig.DocumentLayoutAnalysis/DecorationTextBlockClassifier.cs) score/flag
106 | - Add block's area ratio: the ratio between block area and the page area, [cf.](http://www.cs.rug.nl/~aiellom/publications/ijdar.pdf)
107 | - Add block's font style: an enumerated type, with possible values: regular, bold, italic, underline, [cf.](http://www.cs.rug.nl/~aiellom/publications/ijdar.pdf)
108 | - Add % sparse lines in a block, for better table recognition [cf.](https://clgiles.ist.psu.edu/pubs/CIKM2008-table-boundaries.pdf)
109 | - Font color distance from most common color
110 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.Data/v2/FeatureHelper.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.Linq;
4 | using UglyToad.PdfPig.Content;
5 | using UglyToad.PdfPig.Core;
6 | using UglyToad.PdfPig.DocumentLayoutAnalysis;
7 | using UglyToad.PdfPig.DocumentLayoutAnalysis.WordExtractor;
8 | using static UglyToad.PdfPig.Core.PdfPath;
9 | using UglyToad.PdfPig.Geometry;
10 | using UglyToad.PdfPig.Outline;
11 | using System.Text;
12 |
namespace PdfPigMLNetBlockClassifier.Data.v2
{
    /// <summary>
    /// Computes the v2 block-classifier features (block shape, letter, word/line, path, image and
    /// bookmark statistics) and provides helpers to build and select page content.
    /// </summary>
    public static class FeatureHelper
    {
        /// <summary>
        /// Comma-separated column names: the features in the order returned by
        /// <see cref="GetFeatures"/>, followed by the label column.
        /// </summary>
        public static readonly string Header = "blockAspectRatio,charsCount,wordsCount,linesCount,pctNumericChars,pctAlphabeticalChars,pctSymbolicChars,pctBulletChars,deltaToHeight,pathsCount,pctBezierPaths,pctHorPaths,pctVertPaths,pctOblPaths,imagesCount,imageAvgProportion,bestNormEditDistance,label";

        /// <summary>
        /// Maps a label index to its category name.
        /// </summary>
        public static readonly Dictionary<int, string> Categories = new Dictionary<int, string>()
        {
            { 0, "text" },
            { 1, "title" },
            { 2, "list" },
            { 3, "table" },
            { 4, "image" },
        };

        // Characters counted by the pctBulletChars feature. Note that this includes the plain
        // letter 'o'. A set gives an O(1) lookup per character.
        private static readonly HashSet<char> Bullets = new HashSet<char>
        {
            '•', 'o', '▪', '❖', '➢', '►', '✓', '➔', '⇨', '➪',
            '➨', '➫', '➬', '➭', '➮', '➯', '➱', '➲', '\u2023',
            '\u2043', '\u204C', '\u204D'
        };

        /// <summary>
        /// Computes the 17 v2 features of a text block, in <see cref="Header"/> order (without the label).
        /// </summary>
        /// <param name="textBlock">The block, may be null. Text features keep their defaults when it has no lines.</param>
        /// <param name="paths">Paths of the block, may be null.</param>
        /// <param name="images">Images of the block, may be null.</param>
        /// <param name="averagePageFontHeight">Average glyph height of the page. deltaToHeight stays NaN when it is 0.</param>
        /// <param name="bboxArea">Block area, used to normalise the image area.</param>
        /// <param name="pageBookmarksNodes">Bookmarks of the block's page, or null. bestNormEditDistance stays NaN without bookmarks.</param>
        /// <returns>The feature vector. Ratios are rounded to 5 decimals, and NaN marks a feature that could not be computed.</returns>
        public static float[] GetFeatures(TextBlock textBlock, IEnumerable<PdfPath> paths, IEnumerable<IPdfImage> images,
            double averagePageFontHeight, double bboxArea, List<DocumentBookmarkNode> pageBookmarksNodes)
        {
            // Materialise once: the inputs may be deferred LINQ queries (GetPathsInside and
            // GetImagesInside return one) and are enumerated several times below.
            List<PdfPath> pathList = paths?.ToList();
            List<IPdfImage> imageList = images?.ToList();

            // text block features
            float blockAspectRatio = float.NaN;

            // Letters features
            float charsCount = 0;
            float pctNumericChars = 0;
            float pctAlphabeticalChars = 0;
            float pctSymbolicChars = 0;
            float pctBulletChars = 0;
            float deltaToHeight = float.NaN;

            float wordsCount = 0;
            float linesCount = 0;
            float bestNormEditDistance = float.NaN;

            if (textBlock?.TextLines != null && textBlock.TextLines.Any())
            {
                blockAspectRatio = (float)Math.Round(textBlock.BoundingBox.Width / textBlock.BoundingBox.Height, 5);

                var textLines = textBlock.TextLines;
                var words = textLines.SelectMany(tl => tl.Words).ToList();
                var letters = words.SelectMany(w => w.Letters).ToList();
                char[] chars = letters.SelectMany(l => l.Value).ToArray();

                // NOTE(review): if every letter has an empty Value, charsCount is 0 and these ratios are NaN (0 / 0).
                charsCount = chars.Length;
                pctNumericChars = (float)Math.Round(chars.Count(c => char.IsNumber(c)) / charsCount, 5);
                pctAlphabeticalChars = (float)Math.Round(chars.Count(c => char.IsLetter(c)) / charsCount, 5);
                pctSymbolicChars = (float)Math.Round(chars.Count(c => !char.IsLetterOrDigit(c)) / charsCount, 5);
                pctBulletChars = (float)Math.Round(chars.Count(c => Bullets.Contains(c)) / charsCount, 5);

                if (averagePageFontHeight != 0)
                {
                    // Ratio of the block's average glyph height to the page's average glyph height.
                    deltaToHeight = (float)Math.Round(letters.Select(l => l.GlyphRectangle.Height).Average() / averagePageFontHeight, 5);
                }

                wordsCount = words.Count;
                linesCount = textLines.Count();

                if (pageBookmarksNodes != null)
                {
                    // Compatibility-normalise both texts (NFKC, http://www.unicode.org/reports/tr15/)
                    // and lower-case them with the invariant culture. Otherwise the feature would
                    // depend on the machine's current culture (e.g. 'I' becomes dotless 'ı' under tr-TR).
                    var textBlockNormalised = textBlock.Text.Normalize(NormalizationForm.FormKC).ToLowerInvariant();
                    foreach (var bookmark in pageBookmarksNodes)
                    {
                        var bookmarkTextNormalised = bookmark.Title.Normalize(NormalizationForm.FormKC).ToLowerInvariant();
                        var currentDist = Distances.MinimumEditDistanceNormalised(textBlockNormalised, bookmarkTextNormalised);

                        // Keep the smallest distance over all of the page's bookmarks.
                        if (float.IsNaN(bestNormEditDistance) || currentDist < bestNormEditDistance)
                        {
                            bestNormEditDistance = (float)Math.Round(currentDist, 5);
                        }
                    }
                }
            }

            // Paths features. Each Line or BezierCurve command counts as one "path".
            float pathsCount = 0;
            float pctBezierPaths = 0;
            float pctHorPaths = 0;
            float pctVertPaths = 0;
            float pctOblPaths = 0;

            if (pathList != null && pathList.Count > 0)
            {
                foreach (var command in pathList.SelectMany(p => p.Commands))
                {
                    if (command is BezierCurve)
                    {
                        pathsCount++;
                        pctBezierPaths++;
                    }
                    else if (command is Line line)
                    {
                        pathsCount++;
                        if (line.From.X == line.To.X)
                        {
                            pctVertPaths++;
                        }
                        else if (line.From.Y == line.To.Y)
                        {
                            pctHorPaths++;
                        }
                        else
                        {
                            pctOblPaths++;
                        }
                    }
                    // Any other command type is ignored.
                }

                // NOTE(review): when no command is a Line or BezierCurve, pathsCount is 0 and these
                // ratios evaluate to NaN (0 / 0).
                pctBezierPaths = (float)Math.Round(pctBezierPaths / pathsCount, 5);
                pctHorPaths = (float)Math.Round(pctHorPaths / pathsCount, 5);
                pctVertPaths = (float)Math.Round(pctVertPaths / pathsCount, 5);
                pctOblPaths = (float)Math.Round(pctOblPaths / pathsCount, 5);
            }

            // Images features
            float imagesCount = 0;
            float imageAvgProportion = 0;

            if (imageList != null && imageList.Count > 0)
            {
                imagesCount = imageList.Count;
                // Average image area relative to the block's area.
                imageAvgProportion = (float)(imageList.Average(i => i.Bounds.Area) / bboxArea);
            }

            return new float[]
            {
                blockAspectRatio, charsCount, wordsCount, linesCount, pctNumericChars,
                pctAlphabeticalChars, pctSymbolicChars, pctBulletChars, deltaToHeight,
                pathsCount, pctBezierPaths, pctHorPaths, pctVertPaths, pctOblPaths,
                imagesCount, imageAvgProportion, bestNormEditDistance
            };
        }

        /// <summary>
        /// Lazily selects the letters whose glyph rectangle intersects <paramref name="bound"/>.
        /// </summary>
        public static IEnumerable<Letter> GetLettersInside(PdfRectangle bound, IEnumerable<Letter> letters)
        {
            return letters.Where(l => bound.IntersectsWith(l.GlyphRectangle));
        }

        // Shared word extractor instance used by GetWords.
        private static readonly NearestNeighbourWordExtractor nearestNeighbourWordExtractor = NearestNeighbourWordExtractor.Instance;

        /// <summary>
        /// Groups letters into words with the nearest-neighbour word extractor.
        /// </summary>
        public static IReadOnlyList<Word> GetWords(IReadOnlyList<Letter> letters)
        {
            return nearestNeighbourWordExtractor.GetWords(letters).ToList();
        }

        /// <summary>
        /// Groups words into lines by exact equality of their bounding-box bottom, ordered by
        /// descending bottom. Words within a line keep their input order.
        /// </summary>
        public static IReadOnlyList<TextLine> GetLines(IReadOnlyList<Word> words)
        {
            return words.GroupBy(x => x.BoundingBox.Bottom).OrderByDescending(x => x.Key)
                .Select(x => new TextLine(x.ToList())).ToArray();
        }

        /// <summary>
        /// Lazily selects the images whose bounds lie entirely inside <paramref name="bound"/>.
        /// </summary>
        public static IEnumerable<IPdfImage> GetImagesInside(PdfRectangle bound, IEnumerable<IPdfImage> images)
        {
            return images.Where(b => b.Bounds.Left >= bound.Left &&
                                     b.Bounds.Right <= bound.Right &&
                                     b.Bounds.Bottom >= bound.Bottom &&
                                     b.Bounds.Top <= bound.Top);
        }

        /// <summary>
        /// Lazily selects the paths that have a bounding rectangle lying entirely inside <paramref name="bound"/>.
        /// Paths without a bounding rectangle are skipped.
        /// </summary>
        public static IEnumerable<PdfPath> GetPathsInside(PdfRectangle bound, IEnumerable<PdfPath> paths)
        {
            // Evaluate GetBoundingRectangle() once per path instead of once per comparison.
            return paths.Select(p => (Path: p, Box: p.GetBoundingRectangle()))
                        .Where(x => x.Box.HasValue &&
                                    x.Box.Value.Left >= bound.Left &&
                                    x.Box.Value.Right <= bound.Right &&
                                    x.Box.Value.Bottom >= bound.Bottom &&
                                    x.Box.Value.Top <= bound.Top)
                        .Select(x => x.Path);
        }
    }
}
192 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.Data/v1/DataGenerator.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Concurrent;
3 | using System.Collections.Generic;
4 | using System.IO;
5 | using System.Linq;
6 | using System.Threading.Tasks;
7 | using System.Xml;
8 | using System.Xml.Serialization;
9 | using UglyToad.PdfPig;
10 | using UglyToad.PdfPig.Core;
11 | using UglyToad.PdfPig.DocumentLayoutAnalysis.Export.PAGE;
12 | using static UglyToad.PdfPig.DocumentLayoutAnalysis.Export.PAGE.PageXmlDocument;
13 |
14 | namespace PdfPigMLNetBlockClassifier.Data.v1
15 | {
16 | public static class DataGenerator
17 | {
18 | public static readonly string OutputFolderPath = @"../../../data";
19 | private static readonly string header = "charsCount,pctNumericChars,pctAlphabeticalChars,pctSymbolicChars,pctBulletChars,deltaToHeight,pathsCount,pctBezierPaths,pctHorPaths,pctVertPaths,pctOblPaths,imagesCount,imageAvgProportion,label";
20 |
21 | ///
22 | /// Generate a csv file of features. You will need the pdf documents and the ground truths in PAGE xml format.
23 | ///
24 | /// The path to the data folder. It should contain both the pdf files
25 | /// and their corresponding ground truth xml files.
26 | /// Number of documents to consider.
27 | public static void GetCsv(string dataFolder, int numberOfPdfDocs, string outputFileName)
28 | {
29 | string outputFullPath = GetDataPath(outputFileName);
30 | string outputErrorFullPath = Path.Combine(OutputFolderPath, "invalide_pdfs_" + Path.ChangeExtension(outputFileName, "txt"));
31 |
32 | ConcurrentBag invalidPdfs = new ConcurrentBag();
33 | ConcurrentBag data = new ConcurrentBag();
34 |
35 | int done = 0;
36 |
37 | DirectoryInfo d = new DirectoryInfo(dataFolder);
38 | var pdfFileLinks = d.GetFiles("*.pdf", SearchOption.TopDirectoryOnly);
39 | var maxPageNumber = d.GetFiles("*.xml", SearchOption.TopDirectoryOnly).Select(f => ParseXmlFileName(f.Name)).Max() + 1;
40 |
41 | numberOfPdfDocs = Math.Min(pdfFileLinks.Count(), numberOfPdfDocs);
42 | numberOfPdfDocs = numberOfPdfDocs == 0 ? pdfFileLinks.Count() : numberOfPdfDocs;
43 |
44 | var indexesSelected = GenerateRandom(numberOfPdfDocs, 0, pdfFileLinks.Length);
45 |
46 | Parallel.ForEach(indexesSelected, index =>
47 | {
48 | var pdfFile = pdfFileLinks[index];
49 | string fileName = pdfFile.Name;
50 | string xmlFileNameTemplate = fileName.Replace(".pdf", "_");
51 |
52 | var pageXmlLinksCandidates = Enumerable.Range(0, maxPageNumber).Select(i =>
53 | Path.Combine(dataFolder, fileName.Replace(".pdf", "_" + string.Format("{0:00000}", i) + ".xml"))).ToArray();
54 | var pageXmlLinks = pageXmlLinksCandidates.Where(l => File.Exists(l)).Select(l => new FileInfo(l)).ToArray();
55 |
56 | if (pageXmlLinks.Length == 0)
57 | {
58 | Console.BackgroundColor = ConsoleColor.DarkRed;
59 | Console.WriteLine("Error for document '" + fileName + "': No PageXml files found");
60 | Console.ResetColor();
61 | return;
62 | }
63 |
64 | try
65 | {
66 | var pagesNumbers = pageXmlLinks.Select(l => ParseXmlFileName(l.Name)).ToList();
67 | List localFeatures = new List();
68 | List localCategories = new List();
69 | bool isValidDocument = true;
70 |
71 | using (var doc = PdfDocument.Open(pdfFile.FullName))
72 | {
73 | // Checks if this pdf document looks to be valid
74 | if ((pagesNumbers.Max() + 1) > doc.NumberOfPages)
75 | {
76 | // ignore this document as page number is not correct
77 | Console.BackgroundColor = ConsoleColor.Red;
78 | Console.WriteLine("Error for document '" + fileName + "': Ignoring this document as page number is not correct");
79 | Console.ResetColor();
80 | isValidDocument = false;
81 | }
82 |
83 | foreach (var pageXmlLink in pageXmlLinks)
84 | {
85 | if (!isValidDocument) break;
86 |
87 | int pageNo = ParseXmlFileName(pageXmlLink.Name);
88 | var page = doc.GetPage(pageNo + 1);
89 |
90 | if (page.Rotation.Value != 0)
91 | {
92 | Console.BackgroundColor = ConsoleColor.Yellow;
93 | Console.ForegroundColor = ConsoleColor.Black;
94 | Console.WriteLine("Error for document '" + fileName + "': Ignoring page " + (pageNo + 1) + " because it is rotated");
95 | Console.ResetColor();
96 | continue;
97 | }
98 |
99 | var pageXml = Deserialize(pageXmlLink.FullName);
100 |
101 | var blocks = pageXml.Page.Items;
102 |
103 | foreach (var block in blocks)
104 | {
105 | int category = -1;
106 | PdfRectangle bbox = new PdfRectangle();
107 |
108 | if (block is PageXmlTextRegion textBlock)
109 | {
110 | bbox = ParsePageXmlCoord(textBlock.Coords.Points, page.Height);
111 | switch (textBlock.Type)
112 | {
113 | case PageXmlTextSimpleType.Paragraph:
114 | category = 0;
115 | break;
116 | case PageXmlTextSimpleType.Heading:
117 | category = 1;
118 | break;
119 | case PageXmlTextSimpleType.LisLabel:
120 | category = 2;
121 | break;
122 | default:
123 | throw new ArgumentException("Unknown category");
124 | }
125 |
126 | if (FeatureHelper.GetLettersInside(bbox, page.Letters).Count() == 0)
127 | {
128 | Console.BackgroundColor = ConsoleColor.Red;
129 | Console.ForegroundColor = ConsoleColor.Black;
130 | Console.WriteLine("Error for document '" + fileName + "': Ignoring this document as an empty paragraph was found");
131 | Console.ResetColor();
132 | isValidDocument = false;
133 | break;
134 | }
135 | }
136 | else if (block is PageXmlTableRegion tableBlock)
137 | {
138 | bbox = ParsePageXmlCoord(tableBlock.Coords.Points, page.Height);
139 | category = 3;
140 | }
141 | else if (block is PageXmlImageRegion imageBlock)
142 | {
143 | bbox = ParsePageXmlCoord(imageBlock.Coords.Points, page.Height);
144 | category = 4;
145 | }
146 | else
147 | {
148 | throw new ArgumentException("Unknown region type");
149 | }
150 |
151 | var letters = FeatureHelper.GetLettersInside(bbox, page.Letters).ToList();
152 | var paths = FeatureHelper.GetPathsInside(bbox, page.ExperimentalAccess.Paths).ToList();
153 | var images = FeatureHelper.GetImagesInside(bbox, page.GetImages());
154 | var f = FeatureHelper.GetFeatures(page, bbox, letters, paths, images);
155 |
156 | if (category == -1)
157 | {
158 | throw new ArgumentException("Unknown category number.");
159 | }
160 |
161 | if (f != null)
162 | {
163 | localFeatures.Add(f);
164 | localCategories.Add(category);
165 | }
166 | }
167 | }
168 | }
169 |
170 | if (isValidDocument)
171 | {
172 | if (localFeatures.Count != localCategories.Count)
173 | {
174 | throw new ArgumentException("features and categories don't have the same size");
175 | }
176 |
177 | foreach (var line in localFeatures.Zip(localCategories, (f, c) => string.Join(",", f) + "," + c))
178 | {
179 | data.Add(line);
180 | }
181 | }
182 | else
183 | {
184 | invalidPdfs.Add(pdfFile.Name);
185 | }
186 | }
187 | catch (Exception ex)
188 | {
189 | Console.ForegroundColor = ConsoleColor.Red;
190 | Console.WriteLine("Error for document '" + fileName + "': " + ex.Message);
191 | Console.ResetColor();
192 | }
193 | Console.WriteLine(done++);
194 | });
195 |
196 | List csv = new List() { header };
197 | csv.AddRange(data);
198 | File.WriteAllLines(outputFullPath, csv);
199 | File.WriteAllLines(outputErrorFullPath, invalidPdfs);
200 |
201 | Console.WriteLine("Done. Csv file saved in " + outputFullPath);
202 | }
203 |
204 | public static string GetDataPath(string fileName)
205 | {
206 | return Path.Combine(OutputFolderPath, Path.ChangeExtension(fileName, "csv"));
207 | }
208 |
209 | private static PageXmlDocument Deserialize(string xmlPath)
210 | {
211 | XmlSerializer serializer = new XmlSerializer(typeof(PageXmlDocument));
212 |
213 | using (var reader = XmlReader.Create(xmlPath))
214 | {
215 | return (PageXmlDocument)serializer.Deserialize(reader);
216 | }
217 | }
218 |
219 | private static PdfRectangle ParsePageXmlCoord(string points, double height)
220 | {
221 | string[] pointsStr = points.Split(' ');
222 |
223 | List pdfPoints = new List();
224 |
225 | foreach (var p in pointsStr)
226 | {
227 | string[] coord = p.Split(',');
228 | pdfPoints.Add(new PdfPoint(double.Parse(coord[0]), height - double.Parse(coord[1])));
229 | }
230 |
231 | return new PdfRectangle(pdfPoints.Min(p => p.X), pdfPoints.Min(p => p.Y), pdfPoints.Max(p => p.X), pdfPoints.Max(p => p.Y));
232 | }
233 |
234 | private static int ParseXmlFileName(string xmlFileName)
235 | {
236 | string split = xmlFileName.Split('_')[1].Replace(".xml", "");
237 | if (int.TryParse(split, out int pageNo))
238 | {
239 | return pageNo;
240 | }
241 |
242 | throw new ArgumentException("Cannot parse page number");
243 | }
244 |
245 | ///
246 | /// https://codereview.stackexchange.com/questions/61338/generate-random-numbers-without-repetitions
247 | ///
248 | private static List GenerateRandom(int count, int min, int max)
249 | {
250 | Random random = new Random(42);
251 |
252 | // initialize set S to empty
253 | // for J := N-M + 1 to N do
254 | // T := RandInt(1, J)
255 | // if T is not in S then
256 | // insert T in S
257 | // else
258 | // insert J in S
259 | //
260 | // adapted for C# which does not have an inclusive Next(..)
261 | // and to make it from configurable range not just 1.
262 |
263 | if (max <= min || count < 0 ||
264 | // max - min > 0 required to avoid overflow
265 | (count > max - min && max - min > 0))
266 | {
267 | // need to use 64-bit to support big ranges (negative min, positive max)
268 | throw new ArgumentOutOfRangeException("Range " + min + " to " + max +
269 | " (" + ((Int64)max - (Int64)min) + " values), or count " + count + " is illegal");
270 | }
271 |
272 | // generate count random values.
273 | HashSet candidates = new HashSet();
274 |
275 | // start count values before max, and end at max
276 | for (int top = max - count; top < max; top++)
277 | {
278 | // May strike a duplicate.
279 | // Need to add +1 to make inclusive generator
280 | // +1 is safe even for MaxVal max value because top < max
281 | if (!candidates.Add(random.Next(min, top + 1)))
282 | {
283 | // collision, add inclusive max.
284 | // which could not possibly have been added before.
285 | candidates.Add(top);
286 | }
287 | }
288 |
289 | // load them in to a list, to sort
290 | List result = candidates.ToList();
291 |
292 | // shuffle the results because HashSet has messed
293 | // with the order, and the algorithm does not produce
294 | // random-ordered results (e.g. max-1 will never be the first value)
295 | for (int i = result.Count - 1; i > 0; i--)
296 | {
297 | int k = random.Next(i + 1);
298 | int tmp = result[k];
299 | result[k] = result[i];
300 | result[i] = tmp;
301 | }
302 | return result;
303 | }
304 | }
305 | }
306 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.Data/v2/DataGenerator.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Concurrent;
3 | using System.Collections.Generic;
4 | using System.IO;
5 | using System.Linq;
6 | using System.Threading.Tasks;
7 | using System.Xml;
8 | using System.Xml.Serialization;
9 | using UglyToad.PdfPig;
10 | using UglyToad.PdfPig.Core;
11 | using UglyToad.PdfPig.DocumentLayoutAnalysis;
12 | using UglyToad.PdfPig.DocumentLayoutAnalysis.Export.PAGE;
13 | using UglyToad.PdfPig.Outline;
14 | using static UglyToad.PdfPig.DocumentLayoutAnalysis.Export.PAGE.PageXmlDocument;
15 |
namespace PdfPigMLNetBlockClassifier.Data.v2
{
    using System.Globalization;
    using System.Threading;

    /// <summary>
    /// Generates the v2 training data set: one csv row (block features followed by the block label)
    /// for every region of the PAGE xml ground truths that accompany a folder of pdf documents.
    /// The v2 features are computed from the region's text block, paths, images, the page's average
    /// glyph height, the region area and the bookmarks pointing at the page.
    /// </summary>
    public static class DataGenerator
    {
        public static readonly string OutputFolderPath = @"../../../data";

        // Console colours are process-wide state. Without this lock, messages written concurrently
        // from the Parallel.ForEach in GetCsv would pick up (or reset) each other's colours.
        private static readonly object consoleLock = new object();

        /// <summary>
        /// Generate a csv file of features. You will need the pdf documents and the ground truths in PAGE xml format.
        /// </summary>
        /// <param name="dataFolder">The path to the data folder. It should contain both the pdf files
        /// and their corresponding ground truth xml files, named '&lt;pdf name&gt;_&lt;0-based page number, zero-padded to 5 digits&gt;.xml'.</param>
        /// <param name="numberOfPdfDocs">Number of documents to consider, picked at random with a fixed seed.
        /// 0, or a value larger than the number of pdf files, means all documents.</param>
        /// <param name="outputFileName">Name of the csv file written to <see cref="OutputFolderPath"/> (its extension is
        /// replaced by ".csv"). The names of the documents rejected as invalid are written to 'invalide_pdfs_&lt;name&gt;.txt'.</param>
        public static void GetCsv(string dataFolder, int numberOfPdfDocs, string outputFileName)
        {
            string outputFullPath = GetDataPath(outputFileName);
            string outputErrorFullPath = Path.Combine(OutputFolderPath, "invalide_pdfs_" + Path.ChangeExtension(outputFileName, "txt"));

            var invalidPdfs = new ConcurrentBag<string>();
            var data = new ConcurrentBag<string>();
            int done = 0;

            var directory = new DirectoryInfo(dataFolder);
            FileInfo[] pdfFileLinks = directory.GetFiles("*.pdf", SearchOption.TopDirectoryOnly);

            // Exclusive upper bound of the page numbers probed for each document. Xml files whose name
            // does not end with a page number are ignored rather than aborting the whole run.
            int maxPageNumber = directory.GetFiles("*.xml", SearchOption.TopDirectoryOnly)
                .Select(f => TryParseXmlFileName(f.Name, out int pageNo) ? pageNo : -1)
                .DefaultIfEmpty(-1)
                .Max() + 1;

            numberOfPdfDocs = Math.Min(pdfFileLinks.Length, numberOfPdfDocs);
            numberOfPdfDocs = numberOfPdfDocs == 0 ? pdfFileLinks.Length : numberOfPdfDocs;

            // GenerateRandom rejects an empty range, so an empty folder simply yields a header-only csv.
            List<int> indexesSelected = pdfFileLinks.Length == 0
                ? new List<int>()
                : GenerateRandom(numberOfPdfDocs, 0, pdfFileLinks.Length);

            Parallel.ForEach(indexesSelected, index =>
            {
                FileInfo pdfFile = pdfFileLinks[index];
                string fileName = pdfFile.Name;

                FileInfo[] pageXmlLinks = FindPageXmlFiles(dataFolder, fileName, maxPageNumber);
                if (pageXmlLinks.Length == 0)
                {
                    WriteConsole("Error for document '" + fileName + "': No PageXml files found", ConsoleColor.DarkRed);
                    return;
                }

                try
                {
                    List<string> rows = ExtractRows(pdfFile, pageXmlLinks);
                    if (rows == null)
                    {
                        invalidPdfs.Add(fileName);
                    }
                    else
                    {
                        foreach (string row in rows)
                        {
                            data.Add(row);
                        }
                    }
                }
                catch (Exception ex)
                {
                    WriteConsole("Error for document '" + fileName + "': " + ex.Message, foreground: ConsoleColor.Red);
                }

                // Interlocked keeps the progress counter exact when documents finish concurrently.
                WriteConsole((Interlocked.Increment(ref done) - 1).ToString(CultureInfo.InvariantCulture));
            });

            var csv = new List<string> { FeatureHelper.Header };
            csv.AddRange(data);

            // Create the output folder if needed so the results are not lost at the very end of a long run.
            Directory.CreateDirectory(Path.GetDirectoryName(Path.GetFullPath(outputFullPath)));
            Directory.CreateDirectory(Path.GetDirectoryName(Path.GetFullPath(outputErrorFullPath)));
            File.WriteAllLines(outputFullPath, csv);
            File.WriteAllLines(outputErrorFullPath, invalidPdfs);

            Console.WriteLine("Done. Csv file saved in " + outputFullPath);
        }

        /// <summary>
        /// Returns the path, under <see cref="OutputFolderPath"/>, of the csv file with the given name
        /// (its extension is replaced by ".csv").
        /// </summary>
        public static string GetDataPath(string fileName)
        {
            return Path.Combine(OutputFolderPath, Path.ChangeExtension(fileName, "csv"));
        }

        /// <summary>
        /// Returns, in page order, the existing ground-truth files '&lt;pdf name&gt;_NNNNN.xml' of a pdf,
        /// probing page numbers 0 to <paramref name="maxPageNumber"/> - 1.
        /// </summary>
        private static FileInfo[] FindPageXmlFiles(string dataFolder, string pdfFileName, int maxPageNumber)
        {
            // Strip the extension whatever its case ('.pdf', '.PDF', ...).
            string baseName = Path.GetFileNameWithoutExtension(pdfFileName);
            return Enumerable.Range(0, maxPageNumber)
                .Select(i => new FileInfo(Path.Combine(dataFolder, baseName + "_" + i.ToString("00000", CultureInfo.InvariantCulture) + ".xml")))
                .Where(f => f.Exists)
                .ToArray();
        }

        /// <summary>
        /// Computes the csv rows (features, then label) of every labelled region of one document.
        /// NaN feature values are written as empty fields.
        /// </summary>
        /// <returns>The rows, or null when the document is judged invalid: a ground-truth page lies beyond
        /// the end of the pdf, or a text region contains no letters. Rotated pages are skipped.</returns>
        /// <exception cref="ArgumentException">An xml file name has no page number, or a region has an unknown type or category.</exception>
        /// <exception cref="InvalidOperationException">A non-rotated page has no letters (its average glyph height is undefined).</exception>
        private static List<string> ExtractRows(FileInfo pdfFile, FileInfo[] pageXmlLinks)
        {
            string fileName = pdfFile.Name;
            int lastPageNo = pageXmlLinks.Max(l => ParseXmlFileName(l.Name));
            var rows = new List<string>();

            using (var doc = PdfDocument.Open(pdfFile.FullName))
            {
                // Bookmarks are passed to FeatureHelper.GetFeatures; null when the document has none.
                List<DocumentBookmarkNode> bookmarksNodes = null;
                if (doc.TryGetBookmarks(out Bookmarks bookmarks))
                {
                    bookmarksNodes = bookmarks.GetNodes().OfType<DocumentBookmarkNode>().ToList();
                }

                // Checks if this pdf document looks to be valid.
                if (lastPageNo + 1 > doc.NumberOfPages)
                {
                    WriteConsole("Error for document '" + fileName + "': Ignoring this document as page number is not correct", ConsoleColor.Red);
                    return null;
                }

                foreach (FileInfo pageXmlLink in pageXmlLinks)
                {
                    int pageNo = ParseXmlFileName(pageXmlLink.Name); // base 0, while GetPage is base 1
                    var page = doc.GetPage(pageNo + 1);

                    if (page.Rotation.Value != 0)
                    {
                        WriteConsole("Error for document '" + fileName + "': Ignoring page " + (pageNo + 1) + " because it is rotated", ConsoleColor.Yellow, ConsoleColor.Black);
                        continue;
                    }

                    // Only the bookmarks that target this page.
                    List<DocumentBookmarkNode> pageBookmarksNodes = bookmarksNodes?.Where(b => b.PageNumber == pageNo + 1).ToList();

                    // Computed after the rotation check so that a skipped page cannot fail here. Average()
                    // throws on a page without letters; GetCsv then reports the error for the whole document.
                    double avgPageFontHeight = page.Letters.Select(l => l.GlyphRectangle.Height).Average();

                    PageXmlDocument pageXml = Deserialize(pageXmlLink.FullName);
                    foreach (var block in pageXml.Page.Items)
                    {
                        int category = GetCategory(block, page.Height, out PdfRectangle bbox);
                        var letters = FeatureHelper.GetLettersInside(bbox, page.Letters).ToList();

                        if (block is PageXmlTextRegion && letters.Count == 0)
                        {
                            WriteConsole("Error for document '" + fileName + "': Ignoring this document as an empty paragraph was found", ConsoleColor.Red, ConsoleColor.Black);
                            return null;
                        }

                        // Group the region's letters into words and lines to build the text block features.
                        TextBlock textBlock = null;
                        if (letters.Count > 0)
                        {
                            var lines = FeatureHelper.GetLines(FeatureHelper.GetWords(letters));
                            if (lines != null && lines.Count > 0)
                            {
                                textBlock = new TextBlock(lines);
                            }
                        }

                        var paths = FeatureHelper.GetPathsInside(bbox, page.ExperimentalAccess.Paths).ToList();
                        var images = FeatureHelper.GetImagesInside(bbox, page.GetImages());
                        var features = FeatureHelper.GetFeatures(textBlock, paths, images,
                            avgPageFontHeight, bbox.Area, pageBookmarksNodes);

                        if (features != null)
                        {
                            // Invariant culture: a ',' decimal separator would otherwise shift the csv columns.
                            string values = string.Join(",", features.Select(v =>
                                double.IsNaN(v) ? string.Empty : v.ToString(CultureInfo.InvariantCulture)));
                            rows.Add(values + "," + category.ToString(CultureInfo.InvariantCulture));
                        }
                    }
                }
            }

            return rows;
        }

        /// <summary>
        /// Maps a PAGE xml region to its label (0 paragraph, 1 heading, 2 list label, 3 table, 4 image)
        /// and parses its bounding box in pdf coordinates.
        /// </summary>
        /// <exception cref="ArgumentException">Unknown text region type or unknown region kind.</exception>
        private static int GetCategory(object block, double pageHeight, out PdfRectangle bbox)
        {
            if (block is PageXmlTextRegion textRegion)
            {
                bbox = ParsePageXmlCoord(textRegion.Coords.Points, pageHeight);
                switch (textRegion.Type)
                {
                    case PageXmlTextSimpleType.Paragraph:
                        return 0;
                    case PageXmlTextSimpleType.Heading:
                        return 1;
                    case PageXmlTextSimpleType.LisLabel:
                        return 2;
                    default:
                        throw new ArgumentException("Unknown category");
                }
            }

            if (block is PageXmlTableRegion tableRegion)
            {
                bbox = ParsePageXmlCoord(tableRegion.Coords.Points, pageHeight);
                return 3;
            }

            if (block is PageXmlImageRegion imageRegion)
            {
                bbox = ParsePageXmlCoord(imageRegion.Coords.Points, pageHeight);
                return 4;
            }

            throw new ArgumentException("Unknown region type");
        }

        /// <summary>
        /// Writes one console line, optionally coloured, without interleaving with other threads' colours.
        /// </summary>
        private static void WriteConsole(string message, ConsoleColor? background = null, ConsoleColor? foreground = null)
        {
            lock (consoleLock)
            {
                if (background.HasValue)
                {
                    Console.BackgroundColor = background.Value;
                }

                if (foreground.HasValue)
                {
                    Console.ForegroundColor = foreground.Value;
                }

                Console.WriteLine(message);

                if (background.HasValue || foreground.HasValue)
                {
                    Console.ResetColor();
                }
            }
        }

        /// <summary>
        /// Deserializes a PAGE xml ground-truth file.
        /// </summary>
        private static PageXmlDocument Deserialize(string xmlPath)
        {
            var serializer = new XmlSerializer(typeof(PageXmlDocument));

            using (var reader = XmlReader.Create(xmlPath))
            {
                return (PageXmlDocument)serializer.Deserialize(reader);
            }
        }

        /// <summary>
        /// Parses a PAGE xml point list "x1,y1 x2,y2 ..." into the smallest rectangle containing all the
        /// points. The y axis is flipped (y' = height - y) to go from PAGE's top-left origin to pdf's bottom-left origin.
        /// </summary>
        private static PdfRectangle ParsePageXmlCoord(string points, double height)
        {
            var pdfPoints = new List<PdfPoint>();

            // RemoveEmptyEntries tolerates repeated or trailing spaces in the point list.
            foreach (string point in points.Split(new[] { ' ' }, StringSplitOptions.RemoveEmptyEntries))
            {
                string[] coord = point.Split(',');
                double x = double.Parse(coord[0], CultureInfo.InvariantCulture);
                double y = double.Parse(coord[1], CultureInfo.InvariantCulture);
                pdfPoints.Add(new PdfPoint(x, height - y));
            }

            return new PdfRectangle(pdfPoints.Min(p => p.X), pdfPoints.Min(p => p.Y), pdfPoints.Max(p => p.X), pdfPoints.Max(p => p.Y));
        }

        /// <summary>
        /// Extracts the 0-based page number from a '&lt;pdf name&gt;_NNNNN.xml' file name.
        /// </summary>
        /// <exception cref="ArgumentException">The name does not end with '_&lt;number&gt;'.</exception>
        private static int ParseXmlFileName(string xmlFileName)
        {
            if (TryParseXmlFileName(xmlFileName, out int pageNo))
            {
                return pageNo;
            }

            throw new ArgumentException("Cannot parse page number");
        }

        /// <summary>
        /// Tries to extract the 0-based page number from a '&lt;pdf name&gt;_NNNNN.xml' file name. The number is
        /// the token after the LAST '_', so pdf names that themselves contain '_' are supported.
        /// </summary>
        private static bool TryParseXmlFileName(string xmlFileName, out int pageNo)
        {
            string name = Path.GetFileNameWithoutExtension(xmlFileName);
            int separator = name.LastIndexOf('_');
            if (separator < 0)
            {
                pageNo = 0;
                return false;
            }

            return int.TryParse(name.Substring(separator + 1), NumberStyles.Integer, CultureInfo.InvariantCulture, out pageNo);
        }

        /// <summary>
        /// Returns <paramref name="count"/> distinct random integers from [<paramref name="min"/>, <paramref name="max"/>),
        /// in random order. The fixed seed (42) makes the selection repeatable between runs.
        /// Adapted from https://codereview.stackexchange.com/questions/61338/generate-random-numbers-without-repetitions
        /// (Floyd's sampling algorithm followed by a Fisher-Yates shuffle).
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">The range is empty, count is negative, or count exceeds the range size.</exception>
        private static List<int> GenerateRandom(int count, int min, int max)
        {
            Random random = new Random(42);

            // initialize set S to empty
            // for J := N-M + 1 to N do
            //   T := RandInt(1, J)
            //   if T is not in S then
            //     insert T in S
            //   else
            //     insert J in S
            //
            // adapted for C# which does not have an inclusive Next(..)
            // and to make it from configurable range not just 1.

            if (max <= min || count < 0 ||
                // max - min > 0 required to avoid overflow
                (count > max - min && max - min > 0))
            {
                // need to use 64-bit to support big ranges (negative min, positive max)
                throw new ArgumentOutOfRangeException("Range " + min + " to " + max +
                    " (" + ((Int64)max - (Int64)min) + " values), or count " + count + " is illegal");
            }

            // generate count random values.
            HashSet<int> candidates = new HashSet<int>();

            // start count values before max, and end at max
            for (int top = max - count; top < max; top++)
            {
                // May strike a duplicate.
                // Need to add +1 to make inclusive generator
                // +1 is safe even for MaxVal max value because top < max
                if (!candidates.Add(random.Next(min, top + 1)))
                {
                    // collision, add inclusive max.
                    // which could not possibly have been added before.
                    candidates.Add(top);
                }
            }

            // load them in to a list, to shuffle
            List<int> result = candidates.ToList();

            // shuffle the results because HashSet has messed
            // with the order, and the algorithm does not produce
            // random-ordered results (e.g. max-1 will never be the first value)
            for (int i = result.Count - 1; i > 0; i--)
            {
                int k = random.Next(i + 1);
                int tmp = result[k];
                result[k] = result[i];
                result[i] = tmp;
            }

            return result;
        }
    }
}
337 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.LightGbmV2/LightGbmModelBuilder.cs:
--------------------------------------------------------------------------------
1 | using Microsoft.ML;
2 | using Microsoft.ML.Data;
3 | using Microsoft.ML.Trainers;
4 | using System;
5 | using System.Collections.Generic;
6 | using System.IO;
7 | using System.Linq;
8 |
9 | namespace PdfPigMLNetBlockClassifier.LightGbmV2
10 | {
11 | public static class LightGbmModelBuilder
12 | {
/// <summary>
/// Folder, relative to the current working directory, where models are saved to and loaded from.
/// </summary>
public static readonly string OutputFolderPath = @"../../../model";

// Context shared by the training helpers. It is seeded so that its random operations repeat between
// runs, and it is readonly because it is never reassigned.
private static readonly MLContext mlContext = new MLContext(seed: 1);
16 |
/// <summary>
/// Trains the LightGbm block classifier from a csv file of block features. It first prints
/// cross-validation metrics, then fits the model on all the data and saves it as a .zip file
/// in <see cref="OutputFolderPath"/>.
/// </summary>
/// <param name="trainDataFilePath">Path of the training csv file (with a header row).</param>
/// <param name="outputModelName">Name of the model file; its extension is replaced by ".zip".</param>
public static void TrainModel(string trainDataFilePath, string outputModelName)
{
    string modelPath = GetModelPath(outputModelName);

    // Read the csv rows into the BlockFeatures columns.
    IDataView trainingData = mlContext.Data.LoadFromTextFile<BlockFeatures>(
        path: trainDataFilePath,
        hasHeader: true,
        separatorChar: ',',
        allowQuoting: true,
        allowSparse: false);

    IEstimator<ITransformer> pipeline = BuildTrainingPipeline(mlContext);

    // Report the expected quality first, then fit the final model on the whole data set.
    CrossValidate(mlContext, trainingData, pipeline);
    ITransformer model = TrainModel(mlContext, trainingData, pipeline);

    SaveModel(mlContext, model, modelPath, trainingData.Schema);
}
41 |
/// <summary>
/// Evaluates a saved model on a labelled csv test file. It prints the multiclass metrics, then the
/// permutation feature importance (PFI) of each input feature for micro- and macro-accuracy.
/// </summary>
/// <param name="modelName">Name of the model in <see cref="OutputFolderPath"/>; its extension is replaced by ".zip".</param>
/// <param name="testDataFilePath">Path of the csv test file (with a header row).</param>
/// <exception cref="InvalidOperationException">The loaded model does not have the pipeline shape built by
/// BuildTrainingPipeline (data processing, then a chain starting with the multiclass predictor).</exception>
public static void Evaluate(string modelName, string testDataFilePath)
{
    string modelFullPath = GetModelPath(modelName);

    // Seeded like the training context so that the PFI permutations are repeatable between runs.
    MLContext mlContext = new MLContext(seed: 1);

    // Load the model. Its saved input schema is not needed here.
    ITransformer mlModel = mlContext.Model.Load(modelFullPath, out _);

    // Load data
    IDataView testingDataView = mlContext.Data.LoadFromTextFile<BlockFeatures>(
        path: testDataFilePath,
        hasHeader: true,
        separatorChar: ',',
        allowQuoting: true,
        allowSparse: false);

    IDataView transformedTestingDataView = mlModel.Transform(testingDataView);

    Evaluate(mlContext, transformedTestingDataView);

    // Permutation Feature Importance
    // https://docs.microsoft.com/en-us/dotnet/api/microsoft.ml.permutationfeatureimportanceextensions.permutationfeatureimportance?view=ml-dotnet#Microsoft_ML_PermutationFeatureImportanceExtensions_PermutationFeatureImportance__1_Microsoft_ML_MulticlassClassificationCatalog_Microsoft_ML_ISingleFeaturePredictionTransformer___0__Microsoft_ML_IDataView_System_String_System_Boolean_System_Nullable_System_Int32__System_Int32_

    Console.WriteLine("=============== Permutation Feature Importance ===============");
    Console.WriteLine(@"PFI works by taking a labeled dataset, choosing a feature, and permuting the values for that
feature across all the examples, so that each example now has a random value for the feature
and the original values for all other features. The evaluation metric (e.g. micro-accuracy) is
then calculated for this modified dataset, and the change in the evaluation metric from the
original dataset is computed. The larger the change in the evaluation metric, the more
important the feature is to the model. PFI works by performing this permutation analysis
across all the features of a model, one after another.");
    // Blank line after the explanation. A "\n" inside the verbatim string above would print literally.
    Console.WriteLine();

    // Get the column name of input features. The names are assumed to be in the same order as the
    // slots of the "Features" vector.
    string[] featureColumns = testingDataView.Schema.Select(column => column.Name)
        .Where(columnName => columnName != "label").ToArray();

    // The model is [data processing] -> [predictor -> MapKeyToValue]; PFI needs the predictor.
    var predictor = ((mlModel as TransformerChain<ITransformer>)?.LastTransformer as TransformerChain<ITransformer>)
        ?.FirstOrDefault() as MulticlassPredictionTransformer<OneVersusAllModelParameters>;
    if (predictor == null)
    {
        throw new InvalidOperationException("Could not find the multiclass predictor in model '" + modelFullPath + "'.");
    }

    var pfi = mlContext.MulticlassClassification.PermutationFeatureImportance(
        predictor,
        transformedTestingDataView,
        labelColumnName: "label",
        permutationCount: 30);

    // Prints one line per feature, sorted by the absolute mean change of the given metric
    // (the larger the change, the more important the feature).
    void PrintImportance(string metricName, MetricStatistics[] statistics)
    {
        Console.WriteLine("Feature\tChange in " + metricName + "\t95% Confidence in the Mean Change in " + metricName);
        IEnumerable<int> sortedIndices = Enumerable.Range(0, statistics.Length)
            .OrderByDescending(i => Math.Abs(statistics[i].Mean));

        foreach (int i in sortedIndices)
        {
            // 1.96 standard errors: half-width of a 95% confidence interval under a normal approximation.
            Console.WriteLine("{0}\t{1:G4}\t{2:G4}",
                featureColumns[i],
                statistics[i].Mean,
                1.96 * statistics[i].StandardError);
        }
    }

    PrintImportance("MicroAccuracy", pfi.Select(x => x.MicroAccuracy).ToArray());
    Console.WriteLine();
    PrintImportance("MacroAccuracy", pfi.Select(x => x.MacroAccuracy).ToArray());
}
128 |
/// <summary>
/// Returns the location of a model file: <paramref name="modelName"/> with its extension
/// replaced by ".zip", placed under <see cref="OutputFolderPath"/>.
/// </summary>
public static string GetModelPath(string modelName)
{
    string zipFileName = Path.ChangeExtension(modelName, "zip");
    return Path.Combine(OutputFolderPath, zipFileName);
}
133 |
/// <summary>
/// Builds the training pipeline. It maps "label" to a key, concatenates the block features into a
/// "Features" vector, trains a LightGbm multiclass classifier, and maps the predicted key back to a label value.
/// </summary>
private static IEstimator<ITransformer> BuildTrainingPipeline(MLContext mlContext)
{
    // Input columns gathered, in this order, into the "Features" vector consumed by the trainer.
    string[] featureColumnNames =
    {
        "blockAspectRatio", "charsCount",
        "wordsCount", "linesCount",
        "pctNumericChars", "pctAlphabeticalChars",
        "pctSymbolicChars", "pctBulletChars",
        "deltaToHeight", "pathsCount",
        "pctBezierPaths", "pctHorPaths",
        "pctVertPaths", "pctOblPaths",
        "imagesCount", "imageAvgProportion",
        "bestNormEditDistance"
    };

    var preprocessing = mlContext.Transforms.Conversion.MapValueToKey("label", "label")
        .Append(mlContext.Transforms.Concatenate("Features", featureColumnNames));

    var lightGbm = mlContext.MulticlassClassification.Trainers.LightGbm(labelColumnName: "label", featureColumnName: "Features")
        .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));

    return preprocessing.Append(lightGbm);
}
159 |
/// <summary>
/// Fits <paramref name="trainingPipeline"/> on <paramref name="trainingDataView"/> and returns the trained model.
/// <paramref name="mlContext"/> is currently unused.
/// </summary>
private static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
{
    Console.WriteLine("=============== Training model ===============");
    var fitted = trainingPipeline.Fit(trainingDataView);
    Console.WriteLine("=============== End of training process ===============");

    return fitted;
}
169 |
/// <summary>
/// Prints the averaged metrics of a 5-fold cross-validation of the pipeline over the training data.
/// It is used because there is no separate evaluation data set at training time.
/// </summary>
private static void CrossValidate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
{
    Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ===============");
    PrintMulticlassClassificationFoldsAverageMetrics(
        mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: 5, labelColumnName: "label"));
}
178 |
/// <summary>
/// Computes and prints the multiclass metrics of data already scored by the model
/// (it must contain the "label" column along with the prediction columns).
/// </summary>
private static void Evaluate(MLContext mlContext, IDataView testingDataView)
{
    Console.WriteLine("=============== Evaluating to get model's accuracy metrics ===============");
    PrintMulticlassClassificationMetrics(
        mlContext.MulticlassClassification.Evaluate(testingDataView, labelColumnName: "label"));
}
185 |
/// <summary>
/// Saves the model, together with its input schema, to <paramref name="modelRelativePath"/>.
/// The containing folder is created when it does not exist yet.
/// </summary>
/// <param name="mlContext">Context used to save the model.</param>
/// <param name="mlModel">The trained model.</param>
/// <param name="modelRelativePath">Destination .zip file path.</param>
/// <param name="modelInputSchema">Schema of the data the model was trained on.</param>
private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
{
    // Save/persist the trained model to a .ZIP file
    Console.WriteLine($"=============== Saving the model ===============");

    // Make sure the destination folder exists, so a missing folder does not lose a model that has just been trained.
    string folder = Path.GetDirectoryName(Path.GetFullPath(modelRelativePath));
    if (!string.IsNullOrEmpty(folder))
    {
        Directory.CreateDirectory(folder);
    }

    mlContext.Model.Save(mlModel, modelInputSchema, modelRelativePath);
    Console.WriteLine("The model is saved to {0}", modelRelativePath);
}
193 |
/// <summary>
/// Prints the macro/micro accuracies, the overall and per-class log-loss, the confusion table
/// and the per-class F1 score of a multiclass evaluation.
/// </summary>
/// <param name="metrics">The evaluation metrics to print.</param>
private static void PrintMulticlassClassificationMetrics(MulticlassClassificationMetrics metrics)
{
    Console.WriteLine($"************************************************************");
    Console.WriteLine($"* Metrics for multi-class classification model ");
    Console.WriteLine($"*-----------------------------------------------------------");
    Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
    Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
    Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
    for (int i = 0; i < metrics.PerClassLogLoss.Count; i++)
    {
        Console.WriteLine($" LogLoss for class {i} \t= {metrics.PerClassLogLoss[i]:0.####}, the closer to 0, the better");
    }

    Console.WriteLine(" " + metrics.ConfusionMatrix.GetFormattedConfusionTable());

    for (int i = 0; i < metrics.ConfusionMatrix.PerClassPrecision.Count; i++)
    {
        double precision = metrics.ConfusionMatrix.PerClassPrecision[i];
        double recall = metrics.ConfusionMatrix.PerClassRecall[i];

        // F1 is conventionally 0 when precision and recall are both 0; the plain formula would give 0/0 = NaN.
        double denominator = precision + recall;
        double f1Score = denominator > 0 ? 2 * (precision * recall) / denominator : 0;
        Console.WriteLine($" F1 Score for class {i} \t= {f1Score:0.####}, a value between 0 and 1, the closer to 1, the better");
    }

    Console.WriteLine($"************************************************************");
}
219 |
/// <summary>
/// Prints, across the cross-validation folds, the average, standard deviation and 95% confidence
/// interval of the micro-accuracy, macro-accuracy, log-loss and log-loss reduction.
/// </summary>
/// <param name="crossValResults">One result per fold.</param>
private static void PrintMulticlassClassificationFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<MulticlassClassificationMetrics>> crossValResults)
{
    // Materialized once, because each statistic below enumerates the per-fold metrics again.
    List<MulticlassClassificationMetrics> metricsInMultipleFolds = crossValResults.Select(r => r.Metrics).ToList();

    var microAccuracyValues = metricsInMultipleFolds.Select(m => m.MicroAccuracy).ToList();
    var microAccuracyAverage = microAccuracyValues.Average();
    var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues);
    var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);

    var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.MacroAccuracy).ToList();
    var macroAccuracyAverage = macroAccuracyValues.Average();
    var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues);
    var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues);

    var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss).ToList();
    var logLossAverage = logLossValues.Average();
    var logLossStdDeviation = CalculateStandardDeviation(logLossValues);
    var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues);

    var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction).ToList();
    var logLossReductionAverage = logLossReductionValues.Average();
    var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);
    var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues);

    // "0.###" rather than "#.###": the latter prints nothing for 0 and drops the leading zero of values below 1.
    Console.WriteLine($"*************************************************************************************************************");
    Console.WriteLine($"* Metrics for Multi-class Classification model ");
    Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
    Console.WriteLine($"* Average MicroAccuracy: {microAccuracyAverage:0.###} - Standard deviation: ({microAccuraciesStdDeviation:0.###}) - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:0.###})");
    Console.WriteLine($"* Average MacroAccuracy: {macroAccuracyAverage:0.###} - Standard deviation: ({macroAccuraciesStdDeviation:0.###}) - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:0.###})");
    Console.WriteLine($"* Average LogLoss: {logLossAverage:0.###} - Standard deviation: ({logLossStdDeviation:0.###}) - Confidence Interval 95%: ({logLossConfidenceInterval95:0.###})");
    Console.WriteLine($"* Average LogLossReduction: {logLossReductionAverage:0.###} - Standard deviation: ({logLossReductionStdDeviation:0.###}) - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:0.###})");
    Console.WriteLine($"*************************************************************************************************************");
}
253 |
/// <summary>
/// Computes the sample (n - 1, Bessel-corrected) standard deviation of <paramref name="values"/>.
/// </summary>
/// <param name="values">The sample; it is enumerated exactly once.</param>
/// <returns>
/// The sample standard deviation, or <see cref="double.NaN"/> when fewer than two values are
/// given (the n - 1 denominator is then zero and the statistic is undefined).
/// </returns>
/// <exception cref="InvalidOperationException">Thrown when <paramref name="values"/> is empty.</exception>
private static double CalculateStandardDeviation(IEnumerable<double> values)
{
    // Materialize once so a deferred LINQ query is not evaluated for Average, Sum and Count separately.
    double[] sample = values.ToArray();

    // Average throws InvalidOperationException on an empty sample.
    double average = sample.Average();
    if (sample.Length < 2)
    {
        return double.NaN;
    }

    double sumOfSquaresOfDifferences = sample.Sum(val => (val - average) * (val - average));
    return Math.Sqrt(sumOfSquaresOfDifferences / (sample.Length - 1));
}
261 |
/// <summary>
/// Computes the half-width of the normal-approximation 95% confidence interval for the mean
/// of <paramref name="values"/>: 1.96 * s / sqrt(n), where s is the sample standard deviation.
/// </summary>
/// <param name="values">The sample; it is enumerated exactly once.</param>
/// <returns>The interval half-width, or <see cref="double.NaN"/> for fewer than two values.</returns>
/// <exception cref="InvalidOperationException">Thrown when <paramref name="values"/> is empty.</exception>
private static double CalculateConfidenceInterval95(IEnumerable<double> values)
{
    double[] sample = values.ToArray();

    // The standard error of the mean is s / sqrt(n). Dividing by sqrt(n - 1) would apply the
    // n - 1 correction a second time (it is already inside s) and overstate the interval.
    // NOTE(review): for very small samples (e.g. a handful of folds) a Student-t critical value
    // (about 2.776 for 4 degrees of freedom) would be more accurate than the normal 1.96.
    return 1.96 * CalculateStandardDeviation(sample) / Math.Sqrt(sample.Length);
}
267 | }
268 | }
269 |
--------------------------------------------------------------------------------
/PdfPigMLNetBlockClassifier.LightGbm/LightGbmModelBuilder.cs:
--------------------------------------------------------------------------------
1 | using Microsoft.ML;
2 | using Microsoft.ML.Data;
3 | using Microsoft.ML.Trainers;
4 | using Microsoft.ML.Trainers.LightGbm;
5 | using System;
6 | using System.Collections.Generic;
7 | using System.IO;
8 | using System.Linq;
9 |
10 | namespace PdfPigMLNetBlockClassifier.LightGbm
11 | {
/// <summary>
/// Builds, cross-validates, trains, saves and evaluates the ML.NET LightGBM multiclass
/// model used to classify text blocks.
/// </summary>
/// <remarks>
/// NOTE(review): the generic type arguments in this file were lost in the source dump. They are
/// restored here as <c>BlockFeatures</c> (input row type), <c>IEstimator&lt;ITransformer&gt;</c>,
/// <c>TransformerChain&lt;ITransformer&gt;</c>,
/// <c>MulticlassPredictionTransformer&lt;OneVersusAllModelParameters&gt;</c> and
/// <c>TrainCatalogBase.CrossValidationResult&lt;MulticlassClassificationMetrics&gt;</c> — confirm against the repository.
/// </remarks>
public static class LightGbmModelBuilder
{
    /// <summary>Folder (relative to the working directory) where models are saved and loaded.</summary>
    public static readonly string OutputFolderPath = @"../../../model";

    /// <summary>Name of the label column in the CSV data.</summary>
    private const string LabelColumn = "label";

    /// <summary>Name of the concatenated feature vector column fed to the trainer.</summary>
    private const string FeaturesColumn = "Features";

    /// <summary>
    /// Input columns concatenated, in this order, into <see cref="FeaturesColumn"/>.
    /// Permutation feature importance reports one entry per slot of that vector, so the same
    /// array is used to label its results.
    /// </summary>
    private static readonly string[] FeatureColumnNames =
    {
        "charsCount", "pctNumericChars", "pctAlphabeticalChars",
        "pctSymbolicChars", "pctBulletChars", "deltaToHeight",
        "pathsCount", "pctBezierPaths", "pctHorPaths",
        "pctVertPaths", "pctOblPaths", "imagesCount",
        "imageAvgProportion",
    };

    // Shared context with a fixed seed (1), used by TrainModel(string, string).
    private static readonly MLContext mlContext = new MLContext(seed: 1);

    /// <summary>
    /// Loads the training data, prints 5-fold cross-validation metrics, trains the model on the
    /// full data set and saves it to <see cref="OutputFolderPath"/>.
    /// </summary>
    /// <param name="trainDataFilePath">Comma-separated file with a header row.</param>
    /// <param name="outputModelName">Model file name; its extension is replaced by "zip".</param>
    public static void TrainModel(string trainDataFilePath, string outputModelName)
    {
        string outputFullPath = GetModelPath(outputModelName);

        IDataView trainingDataView = LoadData(mlContext, trainDataFilePath);
        IEstimator<ITransformer> trainingPipeline = BuildTrainingPipeline(mlContext);

        // Evaluate model quality before training on the full data set.
        CrossValidate(mlContext, trainingDataView, trainingPipeline);

        ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline);
        SaveModel(mlContext, mlModel, outputFullPath, trainingDataView.Schema);
    }

    /// <summary>
    /// Loads a saved model, prints its metrics on the test data and then its permutation
    /// feature importance.
    /// </summary>
    /// <param name="modelName">Model file name inside <see cref="OutputFolderPath"/>; its extension is replaced by "zip".</param>
    /// <param name="testDataFilePath">Comma-separated file with a header row.</param>
    public static void Evaluate(string modelName, string testDataFilePath)
    {
        string modelFullPath = GetModelPath(modelName);

        // Separate, unseeded context for evaluation.
        // NOTE(review): if PFI output should be repeatable between runs, consider new MLContext(seed: 1).
        MLContext mlContext = new MLContext();

        // The input schema stored with the model is not needed here.
        ITransformer mlModel = mlContext.Model.Load(modelFullPath, out _);

        IDataView testingDataView = LoadData(mlContext, testDataFilePath);
        IDataView transformedTestingDataView = mlModel.Transform(testingDataView);

        Evaluate(mlContext, transformedTestingDataView);
        PrintPermutationFeatureImportance(mlContext, mlModel, transformedTestingDataView);
    }

    /// <summary>Returns the path of <paramref name="modelName"/> (extension forced to "zip") inside <see cref="OutputFolderPath"/>.</summary>
    public static string GetModelPath(string modelName)
    {
        return Path.Combine(OutputFolderPath, Path.ChangeExtension(modelName, "zip"));
    }

    /// <summary>Loads a comma-separated file with a header row (quoting allowed, no sparse rows).</summary>
    private static IDataView LoadData(MLContext mlContext, string path)
    {
        return mlContext.Data.LoadFromTextFile<BlockFeatures>(
            path: path,
            hasHeader: true,
            separatorChar: ',',
            allowQuoting: true,
            allowSparse: false);
    }

    /// <summary>
    /// Prints permutation feature importance (PFI) of the model's LightGBM predictor on
    /// <paramref name="transformedData"/> (30 permutations), ranked by the absolute mean change
    /// in micro-accuracy and then in macro-accuracy.
    /// </summary>
    /// <remarks>
    /// https://docs.microsoft.com/en-us/dotnet/api/microsoft.ml.permutationfeatureimportanceextensions.permutationfeatureimportance?view=ml-dotnet
    /// </remarks>
    private static void PrintPermutationFeatureImportance(MLContext mlContext, ITransformer mlModel, IDataView transformedData)
    {
        Console.WriteLine("=============== Permutation Feature Importance ===============");
        Console.WriteLine(string.Join(Environment.NewLine, new[]
        {
            "PFI works by taking a labeled dataset, choosing a feature, and permuting the values for that",
            "feature across all the examples, so that each example now has a random value for the feature",
            "and the original values for all other features. The evaluation metric (e.g. micro-accuracy) is",
            "then calculated for this modified dataset, and the change in the evaluation metric from the",
            "original dataset is computed. The larger the change in the evaluation metric, the more",
            "important the feature is to the model. PFI works by performing this permutation analysis",
            "across all the features of a model, one after another.",
        }));

        // Blank separator line. An escape such as "\n" inside a verbatim @"..." literal would be
        // printed as the two characters '\' and 'n', not as a newline.
        Console.WriteLine();

        var predictor = GetPredictor(mlModel);

        var pfi = mlContext.MulticlassClassification.PermutationFeatureImportance(
            predictor,
            transformedData,
            labelColumnName: LabelColumn,
            permutationCount: 30);

        PrintFeatureImportance("MicroAccuracy", pfi.Select(x => x.MicroAccuracy).ToArray());
        Console.WriteLine();
        PrintFeatureImportance("MacroAccuracy", pfi.Select(x => x.MacroAccuracy).ToArray());
    }

    /// <summary>
    /// Extracts the LightGBM prediction transformer from a model built by <see cref="BuildTrainingPipeline"/>:
    /// the outer chain's last transformer is the inner (trainer + key-to-value) chain, whose first
    /// transformer is the predictor.
    /// </summary>
    /// <exception cref="InvalidOperationException">The model does not have that shape.</exception>
    private static MulticlassPredictionTransformer<OneVersusAllModelParameters> GetPredictor(ITransformer mlModel)
    {
        if (mlModel is TransformerChain<ITransformer> outerChain
            && outerChain.LastTransformer is TransformerChain<ITransformer> innerChain
            && innerChain.FirstOrDefault() is MulticlassPredictionTransformer<OneVersusAllModelParameters> predictor)
        {
            return predictor;
        }

        throw new InvalidOperationException(
            "Unexpected model shape: expected TransformerChain(..., TransformerChain(MulticlassPredictionTransformer, ...)).");
    }

    /// <summary>
    /// Prints one row per feature slot: its name, the mean change in <paramref name="metricName"/>
    /// and the 95% confidence half-width (1.96 * standard error), sorted by descending absolute mean change.
    /// </summary>
    private static void PrintFeatureImportance(string metricName, MetricStatistics[] statistics)
    {
        Console.WriteLine($"Feature\tChange in {metricName}\t95% Confidence in the Mean Change in {metricName}");

        // OrderByDescending is a stable sort, so slots with equal impact keep their original order.
        var sortedIndices = Enumerable.Range(0, statistics.Length)
            .OrderByDescending(i => Math.Abs(statistics[i].Mean));

        foreach (int i in sortedIndices)
        {
            Console.WriteLine("{0}\t{1:G4}\t{2:G4}",
                GetFeatureName(i),
                statistics[i].Mean,
                1.96 * statistics[i].StandardError);
        }
    }

    /// <summary>
    /// Name of slot <paramref name="index"/> of the feature vector, or "Feature{index}" when the vector
    /// has more slots than <see cref="FeatureColumnNames"/> (e.g. if an input column is itself a vector).
    /// </summary>
    private static string GetFeatureName(int index)
    {
        return index < FeatureColumnNames.Length ? FeatureColumnNames[index] : $"Feature{index}";
    }

    /// <summary>
    /// Builds the pipeline: label to key, concatenation of <see cref="FeatureColumnNames"/>, the LightGBM
    /// multiclass trainer, and mapping of the predicted key back to its original label value.
    /// </summary>
    private static IEstimator<ITransformer> BuildTrainingPipeline(MLContext mlContext)
    {
        // Data process configuration with pipeline data transformations.
        var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey(LabelColumn, LabelColumn)
            .Append(mlContext.Transforms.Concatenate(FeaturesColumn, FeatureColumnNames));

        // Training algorithm and its hyper-parameters.
        var trainer = mlContext.MulticlassClassification.Trainers.LightGbm(new LightGbmMulticlassTrainer.Options()
        {
            NumberOfIterations = 150,
            LearningRate = 0.1158737f,
            NumberOfLeaves = 39,
            MinimumExampleCountPerLeaf = 50,
            UseCategoricalSplit = true,
            HandleMissingValue = false,
            MinimumExampleCountPerGroup = 50,
            MaximumCategoricalSplitPointCount = 32,
            CategoricalSmoothing = 10,
            L2CategoricalRegularization = 1,
            UseSoftmax = false,
            Booster = new GradientBooster.Options()
            {
                L2Regularization = 0,
                L1Regularization = 0
            },
            LabelColumnName = LabelColumn,
            FeatureColumnName = FeaturesColumn,
            //UnbalancedSets = true, // added by BobLd
        }).Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));

        return dataProcessPipeline.Append(trainer);
    }

    /// <summary>Fits <paramref name="trainingPipeline"/> on <paramref name="trainingDataView"/>.</summary>
    private static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
    {
        Console.WriteLine("=============== Training model ===============");

        ITransformer model = trainingPipeline.Fit(trainingDataView);

        Console.WriteLine("=============== End of training process ===============");
        return model;
    }

    /// <summary>
    /// Runs 5-fold cross-validation on the single available data set (there is no separate
    /// evaluation set at training time) and prints the fold-averaged metrics.
    /// </summary>
    private static void CrossValidate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
    {
        Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ===============");
        var crossValidationResults = mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: 5, labelColumnName: LabelColumn);
        PrintMulticlassClassificationFoldsAverageMetrics(crossValidationResults);
    }

    /// <summary>Evaluates already-transformed (scored) data and prints the multiclass metrics.</summary>
    private static void Evaluate(MLContext mlContext, IDataView testingDataView)
    {
        Console.WriteLine("=============== Evaluating to get model's accuracy metrics ===============");
        var evaluationResults = mlContext.MulticlassClassification.Evaluate(testingDataView, labelColumnName: LabelColumn);
        PrintMulticlassClassificationMetrics(evaluationResults);
    }

    /// <summary>Saves <paramref name="mlModel"/> and its input schema to a .zip file at <paramref name="modelRelativePath"/>.</summary>
    private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
    {
        Console.WriteLine("=============== Saving the model ===============");
        mlContext.Model.Save(mlModel, modelInputSchema, modelRelativePath);
        Console.WriteLine("The model is saved to {0}", modelRelativePath);
    }

    /// <summary>
    /// Prints macro/micro accuracy, overall and per-class log-loss, the confusion table and the
    /// per-class F1 scores of a multiclass evaluation.
    /// </summary>
    private static void PrintMulticlassClassificationMetrics(MulticlassClassificationMetrics metrics)
    {
        Console.WriteLine("************************************************************");
        Console.WriteLine("* Metrics for multi-class classification model ");
        Console.WriteLine("*-----------------------------------------------------------");
        Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
        Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
        Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
        for (int i = 0; i < metrics.PerClassLogLoss.Count; i++)
        {
            Console.WriteLine($" LogLoss for class {i} \t= {metrics.PerClassLogLoss[i]:0.####}, the closer to 0, the better");
        }

        Console.WriteLine(" " + metrics.ConfusionMatrix.GetFormattedConfusionTable());

        var confusionMatrix = metrics.ConfusionMatrix;
        for (int i = 0; i < confusionMatrix.PerClassPrecision.Count; i++)
        {
            double precision = confusionMatrix.PerClassPrecision[i];
            double recall = confusionMatrix.PerClassRecall[i];

            // F1 is taken as 0 when precision and recall are both 0; the formula alone would give 0/0 = NaN.
            double f1Score = precision + recall == 0 ? 0.0 : 2 * (precision * recall) / (precision + recall);
            Console.WriteLine($" F1 Score for class {i} \t= {f1Score:0.####}, a value between 0 and 1, the closer to 1, the better");
        }

        Console.WriteLine("************************************************************");
    }

    /// <summary>
    /// Prints, for micro-accuracy, macro-accuracy, log-loss and log-loss-reduction, the average over
    /// all folds together with the sample standard deviation and the 95% confidence half-width.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown (before anything is printed) when there are no folds.</exception>
    private static void PrintMulticlassClassificationFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<MulticlassClassificationMetrics>> crossValResults)
    {
        // Materialize once: otherwise the fold sequence is re-enumerated for every statistic below.
        List<MulticlassClassificationMetrics> metricsInMultipleFolds = crossValResults.Select(r => r.Metrics).ToList();

        // "0.###" rather than "#.###": the "#" placeholder drops the leading zero (".123") and
        // renders an exact zero as an empty string ("Standard deviation: ()").
        string FormatFoldsLine(string label, Func<MulticlassClassificationMetrics, double> selector)
        {
            List<double> values = metricsInMultipleFolds.Select(selector).ToList();
            double average = values.Average();
            double standardDeviation = CalculateStandardDeviation(values);
            double confidenceInterval95 = CalculateConfidenceInterval95(values);
            return $"* {label}: {average:0.###} - Standard deviation: ({standardDeviation:0.###}) - Confidence Interval 95%: ({confidenceInterval95:0.###})";
        }

        // Every line is computed before anything is written, so an empty fold list throws
        // (from Average) without leaving a half-printed banner.
        string microAccuracyLine = FormatFoldsLine("Average MicroAccuracy", m => m.MicroAccuracy);
        string macroAccuracyLine = FormatFoldsLine("Average MacroAccuracy", m => m.MacroAccuracy);
        string logLossLine = FormatFoldsLine("Average LogLoss", m => m.LogLoss);
        string logLossReductionLine = FormatFoldsLine("Average LogLossReduction", m => m.LogLossReduction);

        Console.WriteLine("*************************************************************************************************************");
        Console.WriteLine("* Metrics for Multi-class Classification model ");
        Console.WriteLine("*------------------------------------------------------------------------------------------------------------");
        Console.WriteLine(microAccuracyLine);
        Console.WriteLine(macroAccuracyLine);
        Console.WriteLine(logLossLine);
        Console.WriteLine(logLossReductionLine);
        Console.WriteLine("*************************************************************************************************************");
    }

    /// <summary>
    /// Computes the sample (n - 1, Bessel-corrected) standard deviation of <paramref name="values"/>,
    /// enumerating them once. Returns <see cref="double.NaN"/> for fewer than two values.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when <paramref name="values"/> is empty.</exception>
    private static double CalculateStandardDeviation(IEnumerable<double> values)
    {
        double[] sample = values.ToArray();

        // Average throws InvalidOperationException on an empty sample.
        double average = sample.Average();
        if (sample.Length < 2)
        {
            return double.NaN;
        }

        double sumOfSquaresOfDifferences = sample.Sum(val => (val - average) * (val - average));
        return Math.Sqrt(sumOfSquaresOfDifferences / (sample.Length - 1));
    }

    /// <summary>
    /// Computes the half-width of the normal-approximation 95% confidence interval for the mean:
    /// 1.96 * s / sqrt(n), where s is the sample standard deviation.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when <paramref name="values"/> is empty.</exception>
    private static double CalculateConfidenceInterval95(IEnumerable<double> values)
    {
        double[] sample = values.ToArray();

        // The standard error of the mean is s / sqrt(n). Dividing by sqrt(n - 1) would apply the
        // n - 1 correction a second time (it is already inside s) and overstate the interval.
        // NOTE(review): with only 5 folds a Student-t critical value (about 2.776 for 4 degrees of
        // freedom) would be more accurate than the normal 1.96.
        return 1.96 * CalculateStandardDeviation(sample) / Math.Sqrt(sample.Length);
    }
}
282 | }
283 |
--------------------------------------------------------------------------------