├── .DS_Store
├── AUTHORS.txt
├── LICENSE.txt
├── README.md
├── doc
│   ├── README_apply_PCA_parameters.txt
│   ├── README_boosted_test.txt
│   ├── README_correct_labels.txt
│   ├── README_count_common.txt
│   ├── README_dataset_info.txt
│   ├── README_extract_symbol.txt
│   ├── README_get_PCA_parameters.txt
│   ├── README_get_enhanced_clustered_set.txt
│   ├── README_get_training_set.txt
│   ├── README_parallel_evaluate.txt
│   ├── README_parallel_prob_evaluate.txt
│   ├── README_random_forest_classify.txt
│   ├── README_svm_lin_classifier.txt
│   ├── README_svm_rbf_classifier.txt
│   ├── README_train_adaboost.txt
│   └── README_train_c45.txt
├── server
│   ├── PenStrokeServer.py
│   ├── best_full2013_SVMRBF_new.dat
│   ├── generic_symbol_table.csv
│   ├── mathSymbol.py
│   ├── start
│   ├── symbol_classifier.py
│   ├── test_classifier.py
│   └── traceInfo.py
└── src
    ├── adaboost_c45.c
    ├── ambiguous.txt
    ├── apply_PCA_parameters.py
    ├── boosted_test.py
    ├── c45.h
    ├── correct_labels.py
    ├── count_common.py
    ├── dataset_info.py
    ├── dataset_ops.py
    ├── distorter.py
    ├── distorter_lib.c
    ├── evaluation_ops.py
    ├── extract_symbol.py
    ├── get_PCA_parameters.py
    ├── get_enhanced_clustered_set.py
    ├── get_training_set.py
    ├── load_inkml.py
    ├── mathSymbol.py
    ├── parallel_evaluate.py
    ├── parallel_prob_evaluate.py
    ├── random_forest_classify.py
    ├── svm_lin_classifier.py
    ├── svm_rbf_classifier.py
    ├── symbol_classifier.py
    ├── test_classifier.py
    ├── test_classify_inkml.py
    ├── traceInfo.py
    ├── train_adaboost.py
    └── train_c45.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DPRL/MathSymbolRecognizer/21a68677770a953d176de57356685027e27c0414/.DS_Store
--------------------------------------------------------------------------------
/AUTHORS.txt:
--------------------------------------------------------------------------------
1 | DPRL Math Symbol Recognizers
2 | Copyright (c) 2014 Kenny Davila, Richard Zanibbi
3 |
4 | This collection of classifiers for isolated handwritten math symbols was
5 | created by Kenny Davila on-and-off over a year-and-a-half. The code was
6 | first created for a project for Dr. Zanibbi's Pattern Recognition course in
7 | Fall 2012, was then used for the RIT entry in the CROHME 2013 handwritten math
8 | recognition competition, and was then later extended and revised under Dr.
9 | Zanibbi's supervision.
10 |
11 | The system was developed at the Rochester Institute of Technology in
12 | Rochester, New York, USA, primarily in the Document and Pattern Recognition
13 | Lab (DPRL) located within the Computer Science Department.
14 |
15 | Please see the LICENSE.txt file for terms of use (the system is being
16 | distributed under the GNU General Public License version 3).
17 |
18 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DPRL Math Symbol Recognizers
2 |
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 | ***
5 |
6 | RIT DPRL Math Symbol Recognizers is free software: you can redistribute it
7 | and/or modify it under the terms of the GNU General Public License as published
8 | by the Free Software Foundation, either version 3 of the License, or (at your
9 | option) any later version.
10 |
11 | RIT DPRL Math Symbol Recognizers is distributed in the hope that it will be
12 | useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
14 | Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License along with
17 | RIT DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
18 |
19 | Contact:
20 | * Kenny Davila: kxd7282@rit.edu
21 | * Richard Zanibbi: rlaz@cs.rit.edu
22 |
23 | ***
24 | The system is composed of different tools for data extraction, training,
25 | evaluation, and other miscellaneous tasks for isolated math symbol
26 | recognition.
27 |
28 | A README file is included in the doc/ directory for each tool, describing
29 | its purpose, how to use it, and what its parameters are.
30 |
33 |
34 | The executable scripts in this release are the following:
35 |
36 | * Preprocessing of data:
37 | apply_PCA_parameters.py
38 | correct_labels.py
39 | get_enhanced_clustered_set.py
40 | get_PCA_parameters.py
41 | get_training_set.py
42 | * Analysis of datasets:
43 | count_common.py
44 | dataset_info.py
45 | extract_symbol.py
46 |
47 | * Training a symbol classifier:
48 | random_forest_classify.py
49 | svm_lin_classifier.py
50 | svm_rbf_classifier.py
51 | train_adaboost.py
52 | train_c45.py
53 | * Tools for evaluation:
54 | boosted_test.py
55 | parallel_evaluate.py
56 | parallel_prob_evaluate.py
57 |
58 |
59 |
60 | ***
61 | # SOURCE FILES
62 | Source code (in Python and C) is provided in the src/ directory.
--------------------------------------------------------------------------------
/doc/README_apply_PCA_parameters.txt:
--------------------------------------------------------------------------------
1 |
2 | Tool for Application of PCA parameters
3 |
4 | The tool will take a file containing a dataset and a file containing PCA parameters
5 | (normalization, eigenvectors, eigenvalues, max number of PCs to use), and will produce the projected
6 | version of the dataset and store it in the specified file.
7 |
8 | Usage: python apply_PCA_parameters.py training_set PCA_params output
9 | Where
10 | training_set = Path to the file of the training_set
11 | PCA_params = Path to the file of the PCA parameters
12 | output = File to output the final dataset with reduced dimensionality
13 |
14 |
15 |
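16 | Example (the file names below are only illustrative):
17 | 
18 |     python apply_PCA_parameters.py ds_train.txt ds_train.params ds_train_pca.txt
19 | 
20 | This loads ds_train.txt, applies the stored normalization and PCA projection, and
21 | writes the reduced-dimensionality dataset to ds_train_pca.txt.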
--------------------------------------------------------------------------------
/doc/README_boosted_test.txt:
--------------------------------------------------------------------------------
1 |
2 | Tool for evaluating AdaBoost C4.5 classifier
3 |
4 | This tool evaluates the training and testing performance of an AdaBoost with C4.5 decision trees
5 | classifier. It stores the final confusion matrix into a file. It can also save the failure cases
6 | for the testing set in .SVG format if the file with the sources for each sample is present in the
7 | current directory. Saving failure cases requires a directory called "output" to exist in
8 | the same directory where the tool is located.
9 |
10 | Usage: python boosted_test.py classifier training_set testing_set [save_fail]
11 | Where
12 | classifier = Path to the .bc45 file that contains the classifier
13 | training_set = Path to the file of the training set
14 | testing_set = Path to the file of the testing set
15 | save_fail = Optional, will output failure cases if specified
16 |
17 |
18 | ================================================
19 | Notes
20 | ================================================
21 |
22 | - All tools that work with AdaBoost require the AdaBoost library to be located
23 | in the same directory. This is a shared library that can be compiled from
24 | the provided source code using the following commands:
25 |
26 | Linux:
27 | gcc -shared -fPIC adaboost_c45.c -o adaboost_c45.so
28 |
29 | Windows (using MinGW):
30 | gcc -shared adaboost_c45.c -o adaboost_c45.so
31 |
32 |
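33 | - Example run that also saves failure cases as .SVG (file names are illustrative;
34 |   the "output" directory and the testing set sources file, e.g.
35 |   ds_test.txt.sources.txt, must exist beforehand):
36 | 
37 |     mkdir output
38 |     python boosted_test.py ds_train.txt.bc45 ds_train.txt ds_test.txt 1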
--------------------------------------------------------------------------------
/doc/README_correct_labels.txt:
--------------------------------------------------------------------------------
1 |
2 | Tool for correcting labels in CROHME datasets
3 |
4 | CROHME 2013 datasets contain more labels than classes as defined by the competition. To fix this
5 | issue, use this tool to automatically correct the additional labels to the set of 101 valid classes.
6 |
7 | Usage: python correct_labels.py training_set output_set
8 | Where
9 | training_set = Path to the file of the training_set
10 | output_set = Path to the output file with corrected labels
11 |
12 |
--------------------------------------------------------------------------------
/doc/README_count_common.txt:
--------------------------------------------------------------------------------
1 |
2 | Tool for counting and extracting data from classes common to two datasets
3 |
4 | Use count_common.py to analyze the set of labels common to two given datasets.
5 | If specified, the program can extract the data of the common classes from the second
6 | dataset and store it in a new file called: [Name of dataset 2]".common.txt".
7 |
8 | This tool can be used, for example, to extract the data from a CROHME 2013 dataset that
9 | has the same labels as a CROHME 2012 dataset.
10 |
11 |
12 | Usage: python count_common.py dataset_1 dataset_2 extract
13 | Where
14 | dataset_1 = Path to the file of the first dataset
15 | dataset_2 = Path to the file of the second dataset
16 | extract = Extract common samples from second dataset
17 |
18 |
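19 | Example (file names are illustrative):
20 | 
21 |     python count_common.py ds_2012_train.txt ds_2013_train.txt 1
22 | 
23 | Assuming a non-zero extract value enables extraction, the samples of the common
24 | classes would be taken from the second dataset and written to
25 | ds_2013_train.txt.common.txt.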
--------------------------------------------------------------------------------
/doc/README_dataset_info.txt:
--------------------------------------------------------------------------------
1 |
2 | Tool to extract general information from dataset file
3 |
4 | Use dataset_info.py to compute the number of classes present in a dataset file and
5 | the number of samples per class. It also shows the current number of attributes.
6 | A histogram of the class representation is built using the specified number of
7 | bins (10 by default).
8 |
9 | Usage: python dataset_info.py dataset [n_bins]
10 | Where
11 | dataset = Path to file that contains the data set
12 | n_bins = Optional, number of bins for histogram of class representation
13 |
14 |
--------------------------------------------------------------------------------
/doc/README_extract_symbol.txt:
--------------------------------------------------------------------------------
1 |
2 | Tool to extract a specific symbol from an inkml file
3 |
4 | Use extract_symbol.py to extract a symbol and store it in .SVG format in the specified
5 | output file. It requires the full path to the inkml file where the desired symbol is
6 | located and the id of the symbol to extract from that file. This tool is useful for
7 | visualizing special cases.
8 |
9 | Usage: python extract_symbol.py inkml_file sym_id output
10 | Where
11 | inkml_file = Path to the inkml file that contains the symbol
12 | sym_id = Id of the symbol to extract
13 | output = File where the extracted symbol will be stored
14 |
15 |
--------------------------------------------------------------------------------
/doc/README_get_PCA_parameters.txt:
--------------------------------------------------------------------------------
1 |
2 | Tool for extraction of PCA parameters
3 |
4 | The tool will take a file containing a dataset, compute the corresponding
5 | PCA parameters (normalization, eigenvectors, eigenvalues, max number of principal components),
6 | and store them in a file. The user can specify a maximum number of principal components to use
7 | and a maximum percentage of variance to cover; the program will use the smaller of these two
8 | as the final number of principal components to store.
9 |
10 | Usage: python get_PCA_parameters.py training_set output_params var_max k_max
11 | Where
12 | training_set = Path to the file of the training_set
13 | output_params = File to output PCA preprocessing parameters
14 | var_max = Maximum percentage of variance to add
15 | k_max = Maximum number of Principal Components to Add
16 |
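17 | Example (file names and values are illustrative):
18 | 
19 |     python get_PCA_parameters.py ds_train.txt ds_train.params 0.95 100
20 | 
21 | Here at most 100 principal components would be stored, or fewer if a smaller
22 | number of components already covers 95% of the variance.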
--------------------------------------------------------------------------------
/doc/README_get_enhanced_clustered_set.txt:
--------------------------------------------------------------------------------
1 |
2 | Tool used for expansion of a dataset
3 |
4 | Use get_enhanced_clustered_set.py to expand a dataset using synthetic data. Two ways of expansion are
5 | provided, one for under-represented classes and the second for over-represented classes. The min_prc
6 | parameter defines a threshold that separates the under-represented classes from the over-represented.
7 | Note that min_prc also defines the minimum representation per class based on the size of the largest class.
8 | For example, if the largest class has 5,000 samples and min_prc is set to 0.20, then all classes will have
9 | at least 20% x 5,000 = 1,000 samples. For over-represented classes clustering is applied, and then the
10 | clust_prc parameter is used to define the minimum final size of each cluster based on the size
11 | of the largest cluster in each class.
12 |
13 | Usage: python get_enhanced_clustered_set.py inkml_path output min_prc diag_dist
14 | max_clusters clust_prc [verbose] [count_only]
15 | Where
16 | inkml_path = Path to directory that contains the inkml files
17 | output = File name of the output file
18 | min_prc = Minimum representation based on (%) of largest class
19 | diag_dist = Distortion factor relative to length of main diagonal
20 | max_clusters = Maximum number of clusters per large class
21 | clust_prc = Minimum cluster size based on (%) of largest
22 | verbose = Optional, print detailed messages
23 | count_only = Optional, will only report what the final size of the dataset would be
24 |
25 |
26 | ================================================
27 | Notes
28 | ================================================
29 |
30 | - All tools that work with synthetic data generation require the distorter library
31 | to be located in the same directory. This is a shared library that can be compiled from
32 | the provided source code using the following commands:
33 |
34 | Linux:
35 | gcc -shared -fPIC distorter_lib.c -o distorter_lib.so
36 |
37 | Windows (using MinGW):
38 | gcc -shared distorter_lib.c -o distorter_lib.so
39 |
40 | - To only expand data without using the clustering option for large classes
41 | use max_clusters = 1 and clust_prc = 0.0.
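42 | 
43 | - Example (directory, file names and parameter values are illustrative):
44 | 
45 |     python get_enhanced_clustered_set.py ./inkml_train ds_enhanced.txt 0.20 0.05 5 0.50
46 | 
47 |   This asks for every class to reach at least 20% of the size of the largest class,
48 |   uses a distortion factor of 0.05 relative to the main diagonal, allows at most 5
49 |   clusters per large class, and requires each cluster to reach at least 50% of the
50 |   size of the largest cluster in its class.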
--------------------------------------------------------------------------------
/doc/README_get_training_set.txt:
--------------------------------------------------------------------------------
1 |
2 | Tool used to produce a training set from INKML files
3 |
4 | Use the get_training_set.py tool to process and extract the features of the isolated
5 | symbols present in a set of INKML files. The system first extracts all symbols found
6 | in the inkml files present in the given directory, and then for each symbol it
7 | extracts the current set of features as defined in mathSymbol.py. The final dataset,
8 | ready to use for training, is then stored in the specified output file.
9 |
10 | Usage: python get_training_set.py inkml_path output
11 | Where
12 | inkml_path = Path to directory that contains the inkml files
13 | output = File name of the output file
14 |
--------------------------------------------------------------------------------
/doc/README_parallel_evaluate.txt:
--------------------------------------------------------------------------------
1 |
2 | Tool for evaluation of classifier performance in parallel
3 |
4 | Use parallel_evaluate.py to evaluate the performance of a given classifier over
5 | training and testing sets in parallel threads for shorter evaluation time.
6 | Currently, this tool only supports classifiers from the Scikit-learn library like
7 | SVC (support vector classifier) and Random Forests. It does not support AdaBoost with
8 | C4.5 classifier. For AdaBoost C4.5 classifier use "boosted_test.py".
9 |
10 | The tool computes global and per class accuracy with the corresponding confusion matrix. If
11 | specified, the tool computes the global and per class accuracy with a confusion matrix for
12 | the case where errors between ambiguous classes are ignored. The list of ambiguous classes
13 | simply contains pairs of symbols considered ambiguous, one pair per line, with the two
14 | symbols on each line separated by a comma (see the example in the notes below).
15 |
16 |
17 | Usage: python parallel_evaluate.py training_set testing_set classifier normalize
18 | workers [test_only] [ambiguous]
19 | Where
20 | training_set = Path to the file of the training set
21 | testing_set = Path to the file of the testing set
22 | classifier = File that contains the pickled classifier
23 | normalize = Whether the training data was normalized prior to training
24 | workers = Number of parallel threads to use
25 | test_only = Optional, Only execute for testing set
26 | ambiguous = Optional, file that contains the list of ambiguous classes
27 |
28 |
29 | ================================================
30 | Notes
31 | ================================================
32 |
33 | - For SVM classifiers, remember to set the "normalize" option to 1.
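34 | 
35 | - Example of an ambiguous classes file (one comma-separated pair per line; these
36 |   pairs are taken from src/ambiguous.txt):
37 | 
38 |     x,X
39 |     o,0
40 |     1,|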
--------------------------------------------------------------------------------
/doc/README_parallel_prob_evaluate.txt:
--------------------------------------------------------------------------------
1 |
2 | Tool for evaluation of probabilistic classifier performance in parallel
3 |
4 | Use parallel_prob_evaluate.py to evaluate the performance of a given probabilistic
5 | classifier over training and testing sets in parallel threads for shorter evaluation time.
6 | Currently, this tool only supports classifiers from the Scikit-learn library like
7 | SVC (support vector classifier) trained with the probabilistic option and Random Forests.
8 | It does not support AdaBoost with C4.5 classifier. For AdaBoost C4.5 classifier use "boosted_test.py".
9 |
10 | The tool computes global and per class accuracy from Top-1 to Top-5. If
11 | specified, the tool computes the global and per class accuracy from Top-1 to Top-5 for
12 | the case where errors between ambiguous classes are ignored. The list of ambiguous classes
13 | simply contains pairs of symbols considered ambiguous, one pair per line, with the two
14 | symbols on each line separated by a comma (same format as for parallel_evaluate.py).
15 |
16 |
17 | Usage: python parallel_prob_evaluate.py training_set testing_set classifier normalize
18 | workers [test_only] [ambiguous]
19 | Where
20 | training_set = Path to the file of the training set
21 | testing_set = Path to the file of the testing set
22 | classifier = File that contains the pickled classifier
23 | normalize = Whether the training data was normalized prior to training
24 | workers = Number of parallel threads to use
25 | test_only = Optional, Only execute for testing set
26 | ambiguous = Optional, file that contains the list of ambiguous classes
27 |
28 |
29 | ================================================
30 | Notes
31 | ================================================
32 |
33 | - For SVM classifiers, remember to set the "normalize" option to 1.
34 | - For SVM classifiers, only the ones trained as probabilistic are supported.
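35 | 
36 | - Example run with a probabilistic classifier, normalization enabled, and 4 worker
37 |   threads (file names are illustrative):
38 | 
39 |     python parallel_prob_evaluate.py ds_train.txt ds_test.txt classifier.dat 1 4 0 ambiguous.txt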
--------------------------------------------------------------------------------
/doc/README_random_forest_classify.txt:
--------------------------------------------------------------------------------
1 |
2 | Tool for training and testing of Random Forest classifiers
3 |
4 | Use random_forest_classify.py to train and test the performance of Random Forest classifiers
5 | over the given training and testing sets. You must specify the number of decision trees to include
6 | in each forest and the maximum depth of each tree. The Random Forest implementation used is the one
7 | included in the Scikit-Learn library. This implementation randomizes the available
8 | features at each split, and the max_feats parameter sets the maximum number of features
9 | to consider at each split. Two split criteria are available: Gini impurity and Entropy.
10 | 
11 | Since random forests introduce randomness in the training process, a single classifier might not be
12 | enough to characterize the performance of this type of classifier on the given dataset. The
13 | times parameter allows the user to define how many classifiers will be trained and tested, and the
14 | mean of their results is used as the final metric. At the end, the system keeps the classifier
15 | that achieved the highest global accuracy and stores it in a file named: [Name of training set]".best.RF".
16 | 
17 | Use the n_jobs parameter to define the number of threads to use during the training process. If omitted,
18 | everything will be done on a single thread.
19 |
20 | Usage: python random_forest_classify.py training_set testing_set N_trees max_D max_feats
21 | type times [n_jobs]
22 | Where
23 | training_set = Path to the file of the training set
24 | testing_set = Path to the file of the testing set
25 | N_trees = Number of trees to use
26 | max_D = Maximum Depth
27 | max_feats = Maximum Features
28 | type = Type of Decision trees (criterion for splits)
29 | 0 - Gini
30 | 1 - Entropy
31 | times = Number of times to repeat experiments
32 | n_jobs = Optional, number of parallel threads to use
33 |
34 |
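35 | Example (file names and parameter values are illustrative):
36 | 
37 |     python random_forest_classify.py ds_train.txt ds_test.txt 50 20 16 1 5 4
38 | 
39 | This would train and test 5 forests of 50 entropy-based trees each (maximum depth 20,
40 | at most 16 features per split) using 4 threads, report their mean accuracy, and keep
41 | the best forest in ds_train.txt.best.RF.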
--------------------------------------------------------------------------------
/doc/README_svm_lin_classifier.txt:
--------------------------------------------------------------------------------
1 |
2 | Tool for training and evaluation of SVM with linear kernel.
3 |
4 | Use svm_lin_classifier.py to train and evaluate the performance of an SVM classifier with a
5 | linear kernel over the specified dataset. Since testing can take a very long time for large
6 | datasets, the "evaluate" option is provided. By default the program will evaluate
7 | the classifier training and testing error after the training process is complete.
8 | However, if the evaluate parameter is 0, the program will terminate after training is
9 | done and parallel_evaluate.py or parallel_prob_evaluate.py can be used for evaluation
10 | later with even more complete output. The probab parameter can be specified to make
11 | the classifier probabilistic. Note that the top-1 accuracy of probabilistic classifiers
12 | might be a little lower than that of non-probabilistic classifiers. Probabilistic
13 | classifiers are required to compute top-5 classification accuracy. After training
14 | is finished, the program will store the trained classifier in a file called:
15 | [Name of training]".LSVM"
16 |
17 | Usage: python svm_lin_classifier.py training testing [evaluate] [probab]
18 | Where
19 | training = Path to training set
20 | testing = Path to testing set
21 | evaluate = Optional, run evaluation or just training
22 | probab = Optional, make it a probabilistic classifier
23 |
24 |
--------------------------------------------------------------------------------
/doc/README_svm_rbf_classifier.txt:
--------------------------------------------------------------------------------
1 |
2 | Tool for training and evaluation of SVM with Radial-Basis Function (RBF) kernel.
3 |
4 | Use svm_rbf_classifier.py to train and evaluate the performance of an SVM classifier
5 | with an RBF kernel over the specified dataset. Since testing can take a very long time for large
6 | datasets, the "evaluate" option is provided. By default the program will evaluate
7 | the classifier training and testing error after the training process is complete.
8 | However, if the evaluate parameter is 0, the program will terminate after training is
9 | done and parallel_evaluate.py or parallel_prob_evaluate.py can be used for evaluation
10 | later with even more complete output. The probab parameter can be specified to make
11 | the classifier probabilistic. Note that the top-1 accuracy of probabilistic classifiers
12 | might be a little lower than that of non-probabilistic classifiers. Probabilistic
13 | classifiers are required to compute top-5 classification accuracy. After training
14 | is finished, the program will store the trained classifier in a file called:
15 | [Name of training]".LSVM"
16 |
17 | Usage: python svm_rbf_classifier.py training testing C Gamma eval [probab]
18 | Where
19 | training = Path to training set
20 | testing = Path to testing set
21 | C = C parameter of the RBF SVM
22 | Gamma = Gamma parameter of the RBF SVM
23 | eval = Optional, set to 0 to skip evaluation after training
24 | probab = Optional, make it a probabilistic classifier
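25 | 
26 | Example (file names and parameter values are illustrative):
27 | 
28 |     python svm_rbf_classifier.py ds_train.txt ds_test.txt 10.0 0.01 1 1
29 | 
30 | This would train a probabilistic RBF SVM with C = 10.0 and Gamma = 0.01, and then
31 | evaluate it on both the training and the testing sets.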
--------------------------------------------------------------------------------
/doc/README_train_adaboost.txt:
--------------------------------------------------------------------------------
1 |
2 | Tool for training and evaluation of AdaBoost with C4.5 decision trees classifier.
3 |
4 | Use train_adaboost.py to train and evaluate the performance of AdaBoost with C4.5
5 | decision trees over the specified dataset. After training is finished,
6 | the program will store the trained classifier in a file called:
7 | [Name of training]".BC45"
8 |
9 | Usage: python train_adaboost.py training_set testing_set rounds
10 | Where
11 | training_set = Path to the file of the training set
12 | testing_set = Path to the file of the testing set
13 | rounds = Rounds to use for AdaBoost
14 |
15 |
16 | ================================================
17 | Notes
18 | ================================================
19 |
20 | - All tools that work with AdaBoost require the AdaBoost library to be located
21 | in the same directory. This is a shared library that can be compiled from
22 | the provided source code using the following commands:
23 |
24 | Linux:
25 | gcc -shared -fPIC adaboost_c45.c -o adaboost_c45.so
26 |
27 | Windows (using MinGW):
28 | gcc -shared adaboost_c45.c -o adaboost_c45.so
29 |
30 |
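31 | - Example run (file names and the number of rounds are illustrative):
32 | 
33 |     python train_adaboost.py ds_train.txt ds_test.txt 50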
--------------------------------------------------------------------------------
/doc/README_train_c45.txt:
--------------------------------------------------------------------------------
1 |
2 | Tool for training and evaluation of C4.5 decision tree classifier.
3 |
4 | Use train_c45.py to train and evaluate the performance of a single C4.5
5 | decision tree over the specified dataset.
6 |
7 | Usage: python train_c45.py training_set testing_set
8 | Where
9 | training_set = Path to the file of the training set
10 | testing_set = Path to the file of the testing set
11 |
12 |
13 | ================================================
14 | Notes
15 | ================================================
16 |
17 | - All tools that work with C4.5 decision trees require the AdaBoost library
18 | to be located in the same directory. This is a shared library that can be
19 | compiled from the provided source code using the following commands:
20 |
21 | Linux:
22 | gcc -shared -fPIC adaboost_c45.c -o adaboost_c45.so
23 |
24 | Windows (using MinGW):
25 | gcc -shared adaboost_c45.c -o adaboost_c45.so
26 |
27 |
--------------------------------------------------------------------------------
/server/PenStrokeServer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # Author: Awelemdy Orakwue March 10, 2015
3 |
4 | import sys
5 | import os
6 | import csv
7 | import subprocess
8 | import SocketServer
9 | import SimpleHTTPServer
10 | import base64
11 | import urllib
12 | import httplib
13 | import socket
14 | from xml.dom.minidom import Document, parseString
15 | import cPickle
16 | from symbol_classifier import SymbolClassifier
17 |
18 | classifier_filename = "best_full2013_SVMRBF_new.dat"
19 |
20 | classifier = ''
21 |
22 | class RecognitionServer(SimpleHTTPServer.SimpleHTTPRequestHandler):
23 | instance_id = 0
24 |
25 | def do_GET(self):
26 | """
27 | This will process a request which comes into the server.
28 | """
29 | #print classifier
30 | unquoted_url = urllib.unquote(self.path)
31 | unquoted_url = unquoted_url.replace("/?segmentList=", "")
32 | unquoted_url = unquoted_url.replace("&segment=false", "")
33 |
34 | dom = parseString(unquoted_url)
35 |
36 | Segments = dom.getElementsByTagName("Segment")
37 | numSegments = Segments.length
38 | segmentIDS = []
39 | classifierPoints = []
40 | for j in range(numSegments):
41 | segmentIDS.append(Segments[j].getAttribute("instanceID"))
42 | points = Segments[j].getAttribute("points")
43 | url_parts = points.split('|')
44 |
45 | translation = Segments[j].getAttribute("translation").split(",")
46 | tempPoints = []
47 | for i in range(len(url_parts)):
48 | pt = url_parts[i].split(",");
49 | pt2 = (int(pt[0]) + int(translation[0]), int(pt[1]) + int(translation[1]))
50 | tempPoints.append(pt2)
51 | classifierPoints.append(tempPoints)
52 |
53 | print classifierPoints
54 | results = classifier.classify_points_prob(classifierPoints, 30)
55 |
56 | doc = Document()
57 | root = doc.createElement("RecognitionResults")
58 | doc.appendChild(root)
59 | root.setAttribute("instanceIDs", ",".join(segmentIDS))
60 |
61 | for k in range(len(results)):
62 | r = doc.createElement("Result")
63 | s = str(results[k][0]).replace("\\", "")
64 | sym = dict.get(s)
65 | if(sym == None):
66 | sym = s
67 | # special case due to CSV file
68 | if(sym.lower() == "comma"):
69 | sym = ","
70 | v = str(results[k][1])
71 | c = format(float(v), '.35f')
72 | r.setAttribute("symbol", sym)
73 | r.setAttribute("certainty", c)
74 | root.appendChild(r)
75 |
76 | xml_str = doc.toxml()
77 | xml_str = xml_str.replace("amp;","")
78 |
79 | self.send_response(httplib.OK, 'OK')
80 | self.send_header("Content-length", len(xml_str))
81 | self.send_header("Access-Control-Allow-Origin", "*")
82 | self.end_headers()
83 | self.wfile.write(xml_str)
84 |
85 | ####### UTILITY FUNCTIONS #######
86 |
87 |
88 | if __name__ == "__main__":
89 | usage = "python PenStrokeServer <port>"
90 | if(len(sys.argv) < 2):
91 | print usage
92 | sys.exit()
93 |
94 | HOST, PORT = "localhost", int(sys.argv[1])
95 | print("Loading classifier")
96 |
97 | global classifier
98 | global dict
99 | in_file = open(classifier_filename, 'rb')
100 | classifier = cPickle.load(in_file)
101 | in_file.close()
102 |
103 | if not isinstance(classifier, SymbolClassifier):
104 | print("Invalid classifier file!")
105 | sys.exit(1)
106 | print "Reading symbol code from generic_symbol_table.csv"
107 | dict = {}
108 | with open('generic_symbol_table.csv', 'rt') as csvfile2:
109 | reader2 = csv.reader(csvfile2, delimiter=',')
110 | for row in reader2:
111 | dict[row[3]] = row[0]
112 | #print dict
113 | print "Starting server."
114 |
115 | try:
116 | server = SocketServer.TCPServer(("", PORT), RecognitionServer)
117 | except socket.error, e:
118 | print e
119 | exit(1)
120 | #proc = subprocess.Popen(['python','kdtreeServer.py'])
121 | print "Serving"
122 | server.serve_forever()
123 |
124 |
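125 | # The handler above expects a GET request whose path looks like
126 | # "/?segmentList=<XML>&segment=false" (URL-encoded in practice). The XML must
127 | # contain one <Segment> element per stroke; the root element name and the values
128 | # below are only illustrative:
129 | #
130 | #   <Segments>
131 | #     <Segment instanceID="1" points="10,20|15,25|20,30" translation="0,0"/>
132 | #     <Segment instanceID="2" points="15,10|15,40" translation="0,0"/>
133 | #   </Segments>
134 | #
135 | # "points" is a '|'-separated list of "x,y" pairs and "translation" is an "x,y"
136 | # offset added to every point before classification. The response is an XML
137 | # <RecognitionResults> document with one <Result symbol="..." certainty="...">
138 | # element per top candidate class.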
--------------------------------------------------------------------------------
/server/best_full2013_SVMRBF_new.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DPRL/MathSymbolRecognizer/21a68677770a953d176de57356685027e27c0414/server/best_full2013_SVMRBF_new.dat
--------------------------------------------------------------------------------
/server/generic_symbol_table.csv:
--------------------------------------------------------------------------------
1 | Codepoint,Symbol,Classifier,LaTeX,
2 | 0,0,0,0,
3 | 1,1,1,1,
4 | 2,2,2,2,
5 | 3,3,3,3,
6 | 4,4,4,4,
7 | 5,5,5,5,
8 | 6,6,6,6,
9 | 7,7,7,7,
10 | 8,8,8,8,
11 | 9,9,9,9,
12 | +,+,pluslower,+,
13 | -,-,horzlinelower,-,
14 | =,=,eqlower,=,
15 | ≥,≥,geqlower,geq,
16 | <,<,ltlower,lt,
17 | >,>,gtlower,geq,
18 | ≠,≠,neqlower,neq,
19 | ≤,≤,less_than_or_equal,leq,
20 | ∫,∫,Integralupper,int,
21 | Π,Π,product,Pi,
22 | Σ,Σ,summation,sum,
23 | √,√,sqrtlower,sqrt,
24 | [,[,lbracketlower,[,
25 | ],],rbracketlower,],
26 | (,(,lparenlower,(,
27 | ),),rparenlower,),
28 | ∞,∞,infinlower,infty,
29 | A,A,Aupper,A,
30 | B,B,Bupper,B,
31 | C,C,Cupper,C,
32 | D,D,Dupper,D,
33 | E,E,Eupper,E,
34 | F,F,Fupper,F,
35 | G,G,Gupper,G,
36 | H,H,Hupper,H,
37 | I,I,Iupper,I,
38 | J,J,Jupper,J,
39 | K,K,Kupper,K,
40 | L,L,Lupper,L,
41 | M,M,Mupper,M,
42 | N,N,Nupper,N,
43 | O,O,Oupper,O,
44 | P,P,Pupper,P,
45 | Q,Q,Qupper,Q,
46 | R,R,Rupper,R,
47 | S,S,Supper,S,
48 | T,T,Tupper,T,
49 | U,U,Uupper,U,
50 | V,V,Vupper,V,
51 | W,W,Wupper,W,
52 | X,X,Xupper,X,
53 | Y,Y,Yupper,Y,
54 | Z,Z,Zupper,Z,
55 | a,a,alower,a,
56 | b,b,blower,b,
57 | c,c,clower,c,
58 | d,d,dlower,d,
59 | e,e,elower,e,
60 | f,f,flower,f,
61 | g,g,glower,g,
62 | h,h,hlower,h,
63 | i,i,ilower,i,
64 | j,j,jlower,j,
65 | k,k,klower,k,
66 | l,l,llower,l,
67 | m,m,mlower,m,
68 | n,n,nlower,n,
69 | o,o,olower,o,
70 | p,p,plower,p,
71 | q,q,qlower,q,
72 | r,r,rlower,r,
73 | s,s,slower,s,
74 | t,t,tlower,t,
75 | u,u,ulower,u,
76 | v,v,vlower,v,
77 | w,w,wlower,w,
78 | x,x,xlower,x,
79 | y,y,ylower,y,
80 | z,z,zlower,z,
81 | α,α,alphalower,alpha,
82 | β,β,betalower,beta,
83 | γ,γ,gammalower,gamma,
84 | δ,δ,deltalower,Delta,
85 | ε,ε,epsilonlower,epsilon,
86 | ζ,ζ,zetalower,zeta,
87 | η,η,etalower,eta,
88 | θ,θ,thetalower,theta,
89 | ι,ι,iotalower,iota,
90 | κ,κ,kappalower,kappa,
91 | λ,λ,lambdalower,lambda,
92 | μ,μ,mulower,mu,
93 | ν,ν,nulower,nu,
94 | ξ,ξ,xilower,xi,
95 | ο,ο,omicronlower,omicron,
96 | π,π,pilower,pi,
97 | ρ,ρ,rholower,rho
98 | σ,σ,sigmalower,sigma
99 | τ,τ,taulower,tau
100 | υ,υ,upsilonlower,upsilon
101 | φ,φ,philower,phi
102 | χ,χ,chilower,chi
103 | ψ,ψ,psilower,psi
104 | ω,ω,omegalower,omega
105 | Α,Αlpha,Alphaupper,Alpha
106 | Β,Beta,Betaupper,Beta
107 | Γ,Γ,Gammaupper,Gamma
108 | Δ,Δ,Deltaupper,Delta
109 | Ε,Εpsilon,Epsilonupper,Epsilon
110 | Ζ,Ζeta,Zetaupper,Zeta
111 | Η,Eta,Etaupper,Eta
112 | Θ,Θ,Thetaupper,Theta
113 | Ι,Ιota,Iotaupper,Iota
114 | Κ,Κappa,Kappaupper,Kappa
115 | Λ,Λ,Lambdaupper,Lambda
116 | Μ,Μu,Muupper,Mu
117 | Ν,Νu,Nuupper,Nu
118 | Ξ,Ξ,Xiupper,Xi
119 | Ο,Οmicron,Omicronupper,Omicron
120 | Π,Pi,Piupper,Pi
121 | Ρ,Rho,Rhoupper,Rho
122 | Σ,Sigma,Sigmaupper,Sigma
123 | Τ,Τau,Tauupper,tau
124 | Υ,Upsilon,Upsilonupper,Upsilon
125 | Φ,Φ,Phiupper,Phi
126 | Χ,Chi,Chiupper,Chi
127 | Ψ,Ψ,Psiupper,Psi
128 | Ω,Ω,Omegaupper,Omega
129 | ±,±,plusminus,pm
130 | ∃,∃,exists,exists
131 | ∀,∀,forall,forall
132 | ÷,÷,div,div
133 | ∈,∈,in,in
134 | ⨉,⨉,times,times
135 | …,…,ldots,ldots
136 | >,>,gt,gt
137 |
--------------------------------------------------------------------------------
/server/start:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | nohup python PenStrokeServer.py 6501 &
4 |
--------------------------------------------------------------------------------
/server/symbol_classifier.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 | import numpy as np
25 | from traceInfo import TraceInfo
26 | from mathSymbol import MathSymbol
27 |
28 | class SymbolClassifier:
29 | TypeRandomForest = 1
30 | TypeSVMLIN = 2
31 | TypeSVMRBF = 3
32 |
33 | def __init__(self, type, trained_classifier, classes_list, classes_dict, scaler=None, probabilistic=False):
34 | self.type = type
35 | self.trained_classifier = trained_classifier
36 | self.classes_list = classes_list
37 | self.classes_dict = classes_dict
38 | self.scaler = scaler
39 | self.probabilistic = probabilistic
40 |
41 | def predict(self, dataset):
42 | return self.trained_classifier.predict(dataset)
43 |
44 | def predict_proba(self, dataset):
45 | return self.trained_classifier.predict_proba(dataset)
46 |
47 | def get_raw_classes(self):
48 | return self.trained_classifier.classes_
49 |
50 | def get_symbol_from_points(self, points_lists):
51 |
52 | traces = []
53 | for trace_id, point_list in enumerate(points_lists):
54 | object_trace = TraceInfo(trace_id, point_list)
55 |
56 | traces.append(object_trace)
57 |
58 | # apply general trace pre processing...
59 | # 1) first step of pre processing: Remove duplicated points
60 | object_trace.removeDuplicatedPoints()
61 |
62 | # Add points to the trace...
63 | object_trace.addMissingPoints()
64 |
65 | # Apply smoothing to the trace...
66 | object_trace.applySmoothing()
67 |
68 | # it should not ... but .....
69 | if object_trace.hasDuplicatedPoints():
70 | # ...remove them! ....
71 | object_trace.removeDuplicatedPoints()
72 |
73 | new_symbol = MathSymbol(0, traces, '{Unknown}')
74 |
75 | # normalize size and locations
76 | new_symbol.normalize()
77 |
78 | return new_symbol
79 |
80 | def get_symbol_features(self, symbol):
81 | # get raw features
82 | features = symbol.getFeatures()
83 |
84 | # put them in python format
85 | mat_features = np.mat(features, dtype=np.float64)
86 |
87 | # automatically transform features
88 | if self.scaler is not None:
89 | mat_features = self.scaler.transform(mat_features)
90 |
91 | return mat_features
92 |
93 | def classify_points(self, points_lists):
94 | symbol = self.get_symbol_from_points(points_lists)
95 |
96 | return self.classify_symbol(symbol)
97 |
98 | def classify_points_prob(self, points_lists, top_n=None):
99 | symbol = self.get_symbol_from_points(points_lists)
100 |
101 | return self.classify_symbol_prob(symbol, top_n)
102 |
103 | def classify_symbol(self, symbol):
104 | features = self.get_symbol_features(symbol)
105 |
106 | predicted = self.trained_classifier.predict(features)
107 |
108 | return self.classes_list[predicted[0]]
109 |
110 |
111 | def classify_symbol_prob(self, symbol, top_n=None):
112 | features = self.get_symbol_features(symbol)
113 |
114 | #try:
115 | predicted = self.trained_classifier.predict_proba(features)
116 | #except:
117 | # raise Exception("Classifier was not trained as probabilistic classifier")
118 |
119 | scores = sorted([(predicted[0, k], k) for k in range(predicted.shape[1])], reverse=True)
120 |
121 | tempo_classes = self.trained_classifier.classes_
122 | n_classes = len(tempo_classes)
123 | if top_n is None or top_n > n_classes:
124 | top_n = n_classes
125 |
126 | confidences = [(self.classes_list[tempo_classes[scores[k][1]]], scores[k][0]) for k in range(top_n)]
127 |
128 | return confidences
129 |
130 |
131 |
--------------------------------------------------------------------------------
/server/test_classifier.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 |
25 | import sys
26 | import cPickle
27 | from symbol_classifier import SymbolClassifier
28 |
29 | def main():
30 | # usage check...
31 | if len(sys.argv) < 2:
32 | print("Usage: python test_classifer.py classifier")
33 | print("Where")
34 | print("\tclassifier\t= Path to trained symbol classifier")
35 | return
36 |
37 | classifier_file = sys.argv[1]
38 |
39 | print("Loading classifier")
40 |
41 | in_file = open(classifier_file, 'rb')
42 | classifier = cPickle.load(in_file)
43 | in_file.close()
44 |
45 | if not isinstance(classifier, SymbolClassifier):
46 | print("Invalid classifier file!")
47 | return
48 |
49 | # get mapping
50 | classes_dict = classifier.classes_dict
51 | classes_l = classifier.classes_list
52 | n_classes = len(classes_l)
53 |
54 | # create test characters from points
55 | sample_x = [[(-0.8, -0.8), (0.1, 0.12), (0.8, 0.79)], [(-0.85, 0.79), (0.01, 0.005), (0.79, -0.83)]]
56 | sample_1 = [[(0.15, 0.7), (0.2, 1.0), (0.21, -1.2)], [(0.05, -1.19), (0.3, -1.25)]]
57 | sample_eq = [[(-1.5, -0.4), (1.5, -0.4)], [(-1.5, 0.4), (1.5, 0.4)]]
58 |
59 | # classify them
60 | class_x = classifier.classify_points(sample_x)
61 | class_1 = classifier.classify_points(sample_1)
62 | class_eq = classifier.classify_points(sample_eq)
63 |
64 | print("X classified as " + class_x)
65 | print("1 classified as " + class_1)
66 | print("= classified as " + class_eq)
67 |
68 | # now with confidence ...
69 | classes_x = classifier.classify_points_prob(sample_x, 3)
70 | classes_1 = classifier.classify_points_prob(sample_1, 3)
71 | classes_eq = classifier.classify_points_prob(sample_eq, 3)
72 |
73 | print("X top classes are: " + str(classes_x))
74 | print("1 top classes are: " + str(classes_1))
75 | print("= top classes are: " + str(classes_eq))
76 |
77 | if __name__ == '__main__':
78 | main()
--------------------------------------------------------------------------------
/src/ambiguous.txt:
--------------------------------------------------------------------------------
1 | x,X
2 | x,\times
3 | X,\times
4 | 1,|
5 | (,|
6 | ),|
7 | 1,(
8 | 1,)
9 | 1,/
10 | 1,COMMA
11 | c,C
12 | ),COMMA
13 | p,P
14 | \prime,COMMA
15 | \prime,|
16 | v,V
17 | s,S
18 | q,9
19 | o,0
20 | \prime,/
21 | /,COMMA
--------------------------------------------------------------------------------
/src/apply_PCA_parameters.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 | import sys
25 | import pickle
26 | from sklearn.preprocessing import StandardScaler
27 | from dataset_ops import *
28 |
29 | #=====================================================================
30 | # Uses the provided PCA paramaters and applies them to a given
31 | # dataset
32 | #
33 | # Created by:
34 | # - Kenny Davila (Feb 1, 2012-2014)
35 | # Modified By:
36 | # - Kenny Davila (Feb 1, 2012-2014)
37 | #
38 | #=====================================================================
39 |
40 |
41 | def load_PCA_parameters(file_name):
42 | file_params = open(file_name, 'r')
43 |
44 | normalization = pickle.load(file_params)
45 | pca_vector = pickle.load(file_params)
46 | pca_k = pickle.load(file_params)
47 |
48 | file_params.close()
49 |
50 | return (normalization, pca_vector, pca_k)
51 |
52 |
53 | def project_PCA(training_set, eig_vectors, k):
54 | #in case that K > # of atts, then just clamp....
55 | n_samples = training_set.shape[0]
56 | n_atts = training_set.shape[1]
57 |
58 | k = min(k, n_atts)
59 |
60 | projected = np.zeros((n_samples, k))
61 |
62 | #...for each sample...
63 | for n in range(n_samples):
64 | x = np.mat(training_set[n, :])
65 |
66 | #...for each eigenvector
67 | for i in range(k):
68 | #use dot product to project...
69 | #p = np.dot(x, eig_vectors[i])[0, 0]
70 | p = np.dot(x, eig_vectors[i])[0, 0]
71 |
72 | projected[n, i] = p.real
73 |
74 | return projected
75 |
76 |
77 | def main():
78 | #usage check
79 | if len(sys.argv) != 4:
80 | print("Usage: python apply_PCA_parameters.py training_set PCA_params output")
81 | print("Where")
82 | print("\ttraining_set\t= Path to the file of the training_set")
83 | print("\tPCA_params\t= Path to the file of the PCA parameters")
84 | print("\toutput\t= File to output the final dataset with reduced dimensionality")
85 | return
86 |
87 | input_filename = sys.argv[1]
88 | params_filename = sys.argv[2]
89 | output_filename = sys.argv[3]
90 |
91 | #...load training set ...
92 | print("Loading data....")
93 | training, labels_l, att_types = load_dataset(input_filename)
94 |
95 | #...the parameters....
96 | print("Loading parameters...")
97 | #normalization, pca_vector, pca_k = load_PCA_parameters(params_filename)
98 | scaler, pca_vector, pca_k = load_PCA_parameters(params_filename)
99 |
100 | #...apply normalization...
101 | print("Normalizing data...")
102 | #new_data = normalize_data_from_params(training, normalization)
103 | new_data = scaler.transform(training)
104 |
105 | #...transform....
106 | print("Applying transformation...")
107 | projected = project_PCA(new_data, pca_vector, pca_k)
108 |
109 | #...save final version...
110 | print("Saving to file...")
111 | final_atts = np.zeros((pca_k, 1), dtype=np.int32)
112 | final_atts[:, :] = 1 # Continuous attributes
113 |
114 | save_dataset_string_labels(projected, labels_l, final_atts, output_filename)
115 |
116 | print("Finished!")
117 |
118 | main()
119 |
120 |
--------------------------------------------------------------------------------
/src/boosted_test.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 | import sys
25 | import ctypes
26 | import numpy as np
27 | from dataset_ops import *
28 | from evaluation_ops import *
29 | from load_inkml import *
30 |
31 | #=====================================================================
32 | # this program takes as input boosted classifier and the dataset
33 | # used for training, and a testing set and computes the classifier accuracy
34 | #
35 | # Created by:
36 | # - Kenny Davila (Jan 19, 2013)
37 | # Modified By:
38 | # - Kenny Davila (Jan 19, 2013)
39 | # - Kenny Davila (Feb 26, 2012-2014)
40 | # - eliminated mapping compatibility
41 | # - Added evaluation metrics: top-n, per-class avg
42 | #
43 | #=====================================================================
44 |
45 | c45_lib = ctypes.CDLL('./adaboost_c45.so')
46 |
47 |
48 | def count_per_class(labels, n_samples, n_classes):
49 | result_counts = [0 for x in range(n_classes)]
50 |
51 | for i in range(n_samples):
52 | result_counts[labels[i]] += 1
53 |
54 | return result_counts
55 |
56 |
57 | def output_sample(file_path, sym_id, out_path):
58 | #load file...
59 | symbols = load_inkml( file_path, True )
60 |
61 | #find symbol ...
62 | for symbol in symbols:
63 | if symbol.id == sym_id:
64 | symbol.saveAsSVG( out_path )
65 |
66 | return True
67 |
68 | return False
69 |
70 |
71 | def predict_top_n_data(classifier, data, top_n, n_classes):
72 | n_samples = np.size(data, 0)
73 | p_classes = np.zeros((1, n_classes), dtype=np.float64)
74 | p_classes_p = p_classes.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
75 |
76 | results = np.zeros((n_samples, top_n))
77 |
78 | for i in range(n_samples):
79 | c_sample = data[i, :]
80 | c_sample_p = c_sample.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
81 |
82 | c45_lib.boosted_c45_probabilistic_classify(classifier, c_sample_p, p_classes_p)
83 |
84 | tempo_values = []
85 | for k in range(n_classes):
86 | tempo_values.append((p_classes[0, k], k))
87 |
88 | #now, sort!
89 | tempo_values = sorted(tempo_values, reverse=True)
90 |
91 | #use the top-N
92 | for k in range(top_n):
93 | results[i, k] = tempo_values[k][1]
94 |
95 | return results
96 |
97 |
98 | def main():
99 | #usage check
100 | if len(sys.argv) < 4:
101 | print("Usage: python boosted_test.py classifier training_set testing_set [save_fail]")
102 | print("Where")
103 | print("\tclassifier\t= Path to the .bc45 file that contains the classifier")
104 | print("\ttraining_set\t= Path to the file of the training set")
105 | print("\ttesting_set\t= Path to the file of the testing set")
106 | print("\tsave_fail\t= Optional, will output failure case if specified")
107 | return
108 |
109 | #...load training data from file...
110 | print "...Loading Training Data..."
111 |
112 | file_name = sys.argv[2]
113 | training, labels_l, att_types = load_dataset(file_name)
114 | #...generate mapping...
115 | classes_dict, classes_l = get_label_mapping(labels_l)
116 | #...generate mapped labels...
117 | labels_train = get_mapped_labels(labels_l, classes_dict)
118 |
119 | if len(sys.argv) >= 5:
120 | save_fail = int(sys.argv[4]) > 0
121 | else:
122 | save_fail = False
123 |
124 | #...load testing data from file...
125 | print "...Loading Testing Data..."
126 |
127 | testing, test_labels_l, att_types = load_dataset(sys.argv[3])
128 | #...generate mapped labels...
129 | labels_test = get_mapped_labels(test_labels_l, classes_dict)
130 |
131 | if save_fail:
132 | #need to load sources...
133 | test_sources = load_ds_sources(sys.argv[3] + ".sources.txt")
134 |
135 | if test_sources is None:
136 | print("Sources are unavailable")
137 | save_fail = False
138 |
139 | print "...Loading classifier..."
140 | ensemble = c45_lib.boosted_c45_load( sys.argv[1])
141 |
142 | #...info....
143 | n_classes = len(classes_l)
144 | n_train_samples = np.size(training, 0)
145 | n_test_samples = np.size(testing, 0)
146 |
147 | print("Classifier: " + sys.argv[1])
148 | print("Training Set: " + sys.argv[2])
149 | print("Testing Set: " + sys.argv[3])
150 |
151 |
152 | print "...evaluating..."
153 |
154 | top_n = 5
155 | predicted = predict_top_n_data(ensemble, training, top_n, n_classes)
156 | print "Training Samples: " + str(n_train_samples)
157 | print "Training Results"
158 | print "Top\tAccuracy\tClass Average\tClass STD "
159 |
160 | #....on main thread, compute final statistics
161 | for i in range(top_n):
162 | total_correct, counts_per_class, errors_per_class = compute_topn_error_counts(predicted, labels_train,
163 | n_classes, i + 1)
164 | accuracy = (float(total_correct) / float(n_train_samples)) * 100
165 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
166 |
167 | print str(i+1) + "\t" + str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0)
168 |
169 | predicted = predict_top_n_data(ensemble, testing, top_n, n_classes)
170 | print "Testing Samples: " + str(n_test_samples)
171 | print "Testing Results"
172 | print "Top\tAccuracy\tClass Average\tClass STD "
173 | for i in range(top_n):
174 | total_correct, counts_per_class, errors_per_class = compute_topn_error_counts(predicted, labels_test,
175 | n_classes, i + 1)
176 | accuracy = (float(total_correct) / float(n_test_samples)) * 100
177 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
178 |
179 | print str(i+1) + "\t" + str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0)
180 |
181 |
182 | #repeat for Top=1 accuracy...
183 | all_failure_info = []
184 | confusion_matrix = np.zeros((n_classes, n_classes), dtype = np.int32)
185 | total_correct = 0
186 | train_counts = count_per_class(labels_train, n_train_samples, n_classes)
187 | for k in range(n_test_samples):
188 | top_label = predicted[k, 0]
189 |
190 | if top_label != labels_test[k, 0]:
191 | #incorrect...
192 | if save_fail:
193 | file_path, sym_id = test_sources[k]
194 | output_sample(file_path, sym_id, "output//error_" + str(k) + ".svg")
195 |
196 | all_failure_info.append((k, file_path, sym_id, classes_l[top_label], classes_l[labels_test[k, 0]]))
197 | else:
198 | total_correct += 1
199 |
200 | confusion_matrix[labels_test[k, 0], top_label] += 1
201 |
202 | if save_fail:
203 | #...save additional info of errors...
204 | content = 'id, path, sym_id, predicted, expected\n'
205 | for error_info in all_failure_info:
206 |
207 | for i in range(len(error_info)):
208 | if i > 0:
209 | content += ","
210 |
211 | content += str(error_info[i])
212 |
213 | content += "\n"
214 |
215 | file_name = sys.argv[1] + ".failures.csv"
216 | try:
217 | f = open(file_name, 'w')
218 | f.write(content)
219 | f.close()
220 | except:
221 | print("ERROR WRITING RESULTS TO FILE! <" + file_name + ">")
222 |
223 |
224 | #... print results....
225 | accuracy = (float(total_correct) / float(n_test_samples)) * 100.0
226 | print "Testing Accuracy = " + str(accuracy)
227 |
228 | #.... save confusion matrix to file...
229 | out_str = "Samples:," + str(n_test_samples) + "\n"
230 | out_str += "Correct:," + str(total_correct) + "\n"
231 | out_str += "Wrong:," + str(n_test_samples - total_correct) + "\n"
232 | out_str += "Accuracy:," + str(accuracy) + "\n\n\n"
233 | out_str += "Rows:,Expected \n"
234 | out_str += "Column:,Predicted \n\n"
235 | out_str += "Full Matrix\n"
236 | out_str += "X"
237 | only_err = "X"
238 | #...build header...
239 | for i in range(n_classes ):
240 | c_class = classes_l[i]
241 | if c_class == ",":
242 | c_class = "COMMA"
243 |
244 | out_str += "," + c_class
245 | only_err += "," + c_class
246 |
247 | out_str += ",Total,Train Count\n"
248 | only_err += ",Total\n"
249 |
250 | #...build the content...
251 | for i in range(n_classes):
252 | c_class = classes_l[i]
253 | if c_class == ",":
254 | c_class = "COMMA"
255 |
256 | out_str += c_class
257 | only_err += c_class
258 |
259 | t_errs = 0
260 | t_samples = 0
261 | for k in range(n_classes):
262 | out_str += "," + str(confusion_matrix[i, k])
263 |
264 | t_samples += confusion_matrix[i, k]
265 | if i == k:
266 | only_err += ",0"
267 | else:
268 | t_errs += confusion_matrix[i, k]
269 | only_err += "," + str(confusion_matrix[i, k])
270 |
271 | out_str += "," + str(t_samples) + "," + str(train_counts[i]) + "\n"
272 | only_err += "," + str(t_errs) + "\n"
273 |
274 | #...build the totals row...
275 | out_str += "Total"
276 | only_err += "Total"
277 | for i in range(n_classes):
278 | t_errs = 0
279 | t_samples = 0
280 |
281 | for k in range(n_classes):
282 | #...inverted, k = row, i = column
283 | t_samples += confusion_matrix[k, i]
284 | if i != k:
285 | t_errs += confusion_matrix[k, i]
286 |
287 | out_str += "," + str(t_samples)
288 | only_err += "," + str(t_errs)
289 |
290 |
291 | out_str += "\n\n\nOnly Errors\n\n" + only_err
292 |
293 | file_name = sys.argv[1] + ".results.csv"
294 |
295 | try:
296 | f = open(file_name, 'w')
297 | f.write( out_str )
298 | f.close()
299 | except:
300 | print("ERROR WRITING RESULTS TO FILE! <" + file_name + ">")
301 |
302 | print "Finished!"
303 |
304 | main()
305 |
--------------------------------------------------------------------------------
/src/correct_labels.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 | import sys
25 | from dataset_ops import *
26 |
27 |
28 | #=====================================================================
29 | # Corrects certain common labeling errors usually found in CROHME
30 | # datasets
31 | #
32 | # Created by:
33 | # - Kenny Davila (Oct, 2012)
34 | # Modified By:
35 | # - Kenny Davila (Oct 25, 2013)
36 | # - Kenny Davila (Feb 5, 2012-2014)
37 | #
38 | #=====================================================================
39 |
40 | def get_classes_list(labels):
41 | all_classes = {}
42 |
43 | for label in labels:
44 | if not label in all_classes:
45 | all_classes[label] = True
46 |
47 | return all_classes.keys()
48 |
49 |
50 | def replace_labels(labels):
51 | new_labels = []
52 |
53 | for label in labels:
54 | if label == "\\tg":
55 | new_label = "\\tan"
56 | elif label == '>':
57 | new_label = '\gt'
58 | elif label == '<':
59 | new_label = '\lt'
60 | elif label == '\'':
61 | new_label = '\prime'
62 | elif label == '\cdots':
63 | new_label = '\ldots'
64 | elif label == '\\vec':
65 | new_label = '\\rightarrow'
66 | elif label == '\cdot':
67 | new_label = '.'
68 | elif label == ',':
69 | new_label = 'COMMA'
70 | elif label == '\\frac':
71 | new_label = '-'
72 | else:
73 | #unchanged...
74 | new_label = label
75 |
76 | new_labels.append(new_label)
77 |
78 | return new_labels
79 |
80 |
81 | def main():
82 | #usage check
83 | if len(sys.argv) != 3:
84 | print("Usage: python correct_labels.py training_set output_set ")
85 | print("Where")
86 | print("\ttraining_set\t= Path to the file of the training_set")
87 | print("\toutput_set\t= File to output file with corrected labels")
88 | return
89 |
90 | #...load training data from file...
91 | input_filename = sys.argv[1]
92 | output_filename = sys.argv[2]
93 |
94 | #file_name = 'ds_test_2012.txt'
95 | training, labels_l, att_types = load_dataset(input_filename)
96 |
97 | if training is None:
98 | print("Data not could not be loaded")
99 | else:
100 | #extract list of classes...
101 | all_classes = get_classes_list(labels_l)
102 |
103 | print("Original number of classes = " + str(len(all_classes)))
104 |
105 | new_labels = replace_labels(labels_l)
106 |
107 | new_all_classes = get_classes_list(new_labels)
108 | print("New number of classes = " + str(len(new_all_classes)))
109 |
110 | save_dataset_string_labels(training, new_labels, att_types, output_filename)
111 |
112 | print("Success!")
113 |
114 | main()
115 |
116 |
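A quick, hypothetical illustration of the relabeling performed by replace_labels() above; the dictionary restates the same mapping inline (correct_labels.py calls main() when imported, so the script itself is not imported here), and labels not listed pass through unchanged.

# Inline restatement (for illustration only) of the replace_labels() mapping above.
crohme_fixes = {"\\tg": "\\tan", ">": "\\gt", "<": "\\lt", "'": "\\prime",
                "\\cdots": "\\ldots", "\\vec": "\\rightarrow", "\\cdot": ".",
                ",": "COMMA", "\\frac": "-"}

labels = ["\\tg", ">", "'", ",", "x"]
print([crohme_fixes.get(label, label) for label in labels])
# ['\\tan', '\\gt', '\\prime', 'COMMA', 'x']
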
--------------------------------------------------------------------------------
/src/count_common.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 | import sys
25 | import numpy as np
26 | from dataset_ops import *
27 |
28 |
29 | def get_counts_per_class(labels_l):
30 | count_per_class = {}
31 |
32 | for label in labels_l:
33 | if label in count_per_class:
34 | count_per_class[label] += 1
35 | else:
36 | count_per_class[label] = 1
37 |
38 | return count_per_class
39 |
40 |
41 | def count_in_common(counts_l1, counts_l2):
42 | common_1 = 0
43 | common_2 = 0
44 | common_labels = []
45 | only1 = []
46 | only2 = []
47 |
48 | for label1 in counts_l1:
49 | if label1 in counts_l2:
50 | common_1 += counts_l1[label1]
51 | common_2 += counts_l2[label1]
52 |
53 | common_labels.append(label1)
54 | else:
55 | only1.append(label1)
56 |
57 | for label2 in counts_l2:
58 | if not label2 in counts_l1:
59 | only2.append(label2)
60 |
61 |
62 | return common_1, common_2, common_labels, only1, only2
63 |
64 |
65 | def extract_common_classes(dataset, labels, common_labels):
66 | #...first, find the references to samples from common labels...
67 | common_refs = []
68 | for idx, label in enumerate(labels):
69 | if label in common_labels:
70 | common_refs.append(idx)
71 |
72 | #now... create a new dataset only with common refs...
73 | n_common_size = len(common_refs)
74 | n_atts = np.size(dataset, 1)
75 | common_set_data = np.zeros((n_common_size, n_atts), dtype=np.float64)
76 | common_set_labels = []
77 |
78 | for idx, ref_idx in enumerate(common_refs):
79 | common_set_data[idx, :] = dataset[ref_idx, :]
80 | common_set_labels.append(labels[ref_idx])
81 |
82 | return common_set_data, common_set_labels
83 |
84 | def main():
85 | if len(sys.argv) != 4:
86 | print("Usage: python count_common.py dataset_1 dataset_2")
87 | print("Where")
88 | print("\tdataset_1\t= Path to the file of the first dataset")
89 | print("\tdataset_2\t= Path to the file of the second dataset")
90 | print("\textract\t= Extract common samples from second dataset")
91 | return
92 |
93 | data1_filename = sys.argv[1]
94 | data2_filename = sys.argv[2]
95 |
96 | try:
97 | do_extraction = int(sys.argv[3]) > 0
98 | except:
99 | print("Invalid value for extract parameter")
100 | return
101 |
102 | #...loading dataset...
103 | print("...Loading data set!")
104 | data1, labels_l1, att_types_1 = load_dataset(data1_filename)
105 | data2, labels_l2, att_types_2 = load_dataset(data2_filename)
106 |
107 | print("...Getting counts!")
108 | counts_l1 = get_counts_per_class(labels_l1)
109 | counts_l2 = get_counts_per_class(labels_l2)
110 |
111 | print("...Finding class overlap!")
112 | common1, common2, common_labels, only1, only2 = count_in_common(counts_l1, counts_l2)
113 |
114 | print("Classes on dataset 1: " + str(len(counts_l1.keys())))
115 | print("Classes on dataset 2: " + str(len(counts_l2.keys())))
116 | print("Total common classes: " + str(len(common_labels)))
117 | for k in common_labels:
118 | print(k)
119 | print("Samples of common classes on dataset 1: " + str(common1))
120 | print("Samples of common classes on dataset 2: " + str(common2))
121 |
122 | print("Total classes only on 1: " + str(len(only1)))
123 | for k in only1:
124 | print(k)
125 | print("Total classes only on 2: " + str(len(only2)))
126 | for k in only2:
127 | print(k)
128 |
129 | if do_extraction:
130 | print("Extracting samples with common labels from dataset 2")
131 | filtered_data, filtered_labels = extract_common_classes(data2, labels_l2, common_labels)
132 |
133 | print("Saving filtered samples ")
134 | save_dataset_string_labels(filtered_data,filtered_labels,att_types_2, data2_filename + ".common.txt")
135 |
136 | print("Finished!")
137 |
138 | main()
139 |
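A toy sketch (made-up counts) of the overlap bookkeeping that count_in_common() above performs, restated inline since count_common.py also calls main() at load time.

counts_1 = {"x": 100, "+": 40, "\\alpha": 5}   # per-class counts, dataset 1
counts_2 = {"x": 80, "+": 10, "\\beta": 7}     # per-class counts, dataset 2

common_labels = [k for k in counts_1 if k in counts_2]
common_1 = sum(counts_1[k] for k in common_labels)            # 140 samples
common_2 = sum(counts_2[k] for k in common_labels)            # 90 samples
only1 = [k for k in counts_1 if k not in counts_2]            # ['\\alpha']
only2 = [k for k in counts_2 if k not in counts_1]            # ['\\beta']
print(common_1, common_2, common_labels, only1, only2)
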
--------------------------------------------------------------------------------
/src/dataset_info.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 |
25 | import sys
26 | import math
27 | import numpy as np
28 | from dataset_ops import *
29 | #=====================================================================
30 | # Simple script that loads an existing dataset from a file and
31 | # prints general summary information about it.
32 | #
33 | # Created by:
34 | # - Kenny Davila (Feb 7, 2012-2014)
35 | # Modified By:
36 | # - Kenny Davila (Feb 7, 2012-2014)
37 | # - Kenny Davila (Feb 7, 2012-2014)
38 | #
39 | #=====================================================================
40 |
41 |
42 | def main():
43 | #usage check
44 | if len(sys.argv) < 2:
45 | print("Usage: python dataset_info.py datset [n_bins]")
46 | print("Where")
47 | print("\tdataset\t= Path to file that contains the data set")
48 | print("\tn_bins\t= Optional, number of bins for histogram of class representation")
49 | return
50 |
51 | input_filename = sys.argv[1]
52 |
53 | if len(sys.argv) >= 3:
54 | try:
55 | n_bins = int(sys.argv[2])
56 | if n_bins < 1:
57 | print("Invalid n_bins value")
58 | return
59 | except:
60 | print("Invalid n_bins value")
61 | return
62 | else:
63 | n_bins = 10
64 |
65 | print("Loading data....")
66 | training, labels_l, att_types = load_dataset(input_filename);
67 | print("Data loaded!")
68 |
69 | print("Getting information...")
70 | n_samples = np.size(training, 0)
71 | n_atts = np.size(training, 1)
72 |
73 | #...counts per class...
74 | count_per_class = {}
75 | for idx in range(n_samples):
76 | #s_label = labels_train[idx, 0]
77 | s_label = labels_l[idx]
78 |
79 | if s_label in count_per_class:
80 | count_per_class[s_label] += 1
81 | else:
82 | count_per_class[s_label] = 1
83 |
84 | #...distribution...
85 | #...first pass, compute minimum and maximum...
86 | smallest_class_size = n_samples
87 | smallest_class_label = ""
88 | largest_class_size = 0
89 | largest_class_label = ""
90 | for label in count_per_class:
91 | #...check minimum
92 | if count_per_class[label] < smallest_class_size:
93 | smallest_class_size = count_per_class[label]
94 | smallest_class_label = label
95 |
96 | #...check maximum
97 | if count_per_class[label] > largest_class_size:
98 | largest_class_size = count_per_class[label]
99 | largest_class_label = label
100 |
101 | print("Class\t" + label + "\tCount\t" + str(count_per_class[label]))
102 |
103 | #...second pass... create histogram...
104 | count_bins = [0 for x in range(n_bins)]
105 | samples_bins = [0 for x in range(n_bins)]
106 | size_per_bin = int(math.ceil(float(largest_class_size + 1) / float(n_bins)))
107 | for label in count_per_class:
108 | current_bin = int(count_per_class[label] / size_per_bin)
109 | count_bins[current_bin] += 1
110 | samples_bins[current_bin] += count_per_class[label]
111 |
112 | #...print... bins...
113 | print("Class sizes distribution")
114 | for i in range(n_bins):
115 | start = i * size_per_bin + 1
116 | end = (i + 1) * size_per_bin
117 | percentage = (float(samples_bins[i]) / float(n_samples)) * 100.0
118 | print("... From " + str(start) + "\t to " + str(end) + "\t : " +
119 | str(count_bins[i]) + "\t (" + str(percentage) + "% of data)")
120 |
121 | n_classes = len(count_per_class.keys())
122 |
123 | print("Total Samples: " + str(n_samples))
124 | print("Total Attributes: " + str(n_atts))
125 | print("Total Classes: " + str(n_classes))
126 | print("-> Largest Class: " + largest_class_label + "\t: " + str(largest_class_size) + " samples")
127 | print("-> Smallest Class: " + smallest_class_label + "\t: " + str(smallest_class_size) + " samples")
128 | print("Finished...")
129 |
130 | main()
131 |
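A tiny worked example (made-up sizes) of the histogram binning above: class counts are mapped into bins of width ceil((largest_class_size + 1) / n_bins).

import math

largest_class_size, n_bins = 99, 10
size_per_bin = int(math.ceil(float(largest_class_size + 1) / float(n_bins)))

current_bin = int(37 / size_per_bin)     # a class with 37 samples
print(size_per_bin, current_bin)         # 10 3  -> bin 3 covers class sizes 31..40
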
--------------------------------------------------------------------------------
/src/dataset_ops.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 |
25 | import numpy as np
26 |
27 | #=====================================================================
28 | # General functions used to load and save datasets stored as
29 | # semicolon-separated text files
30 | #
31 | # Created by:
32 | # - Kenny Davila (Dic 1, 2013)
33 | # Modified By:
34 | # - Kenny Davila (Jan 26, 2012-2014)
35 | # - Added function to save dataset
36 | # - Kenny Davila (Feb 1, 2012-2014)
37 | # - Added functions for data normalization
38 | # - Kenny Davila (Feb 10, 2012-2014)
39 | # - empty lines are now skipped!
40 | # - Kenny Davila (Feb 18, 2012-2014)
41 | # - load_dataset is now more memory efficient
42 | #
43 | #=====================================================================
44 |
45 | #=========================================
46 | # Loads a dataset from a file
47 | # returns
48 | # Samples, Labels, Att Types
49 | # (NP Matrix, List, NP Matrix)
50 | #=========================================
51 | def load_dataset(file_name):
52 | try:
53 |
54 | data_file = open(file_name, 'r')
55 | lines = data_file.readlines()
56 | data_file.close()
57 |
58 | #now process every line...
59 | #first line must be the attribute type for each feature...
60 | att_types_s1 = lines[0].split(';')
61 | att_types_s2 = [ s.strip().upper() for s in att_types_s1 ]
62 | n_atts = len( att_types_s2 )
63 |
64 | #..proces the attributes...
65 | att_types = np.zeros( (n_atts, 1), dtype=np.int32 )
66 | for i in xrange(n_atts):
67 | if att_types_s2[i] == 'D':
68 | att_types[i] = 2
69 | else:
70 | att_types[i] = 1
71 |
72 | estimated_samples = len(lines) - 1
73 | tempo_samples = np.zeros( (estimated_samples, n_atts), dtype = np.float64 )
74 |
75 | count_samples = 0
76 | labels_l = []
77 | for i in xrange(1, len(lines)):
78 | values_s = lines[i].split(';')
79 | if len(values_s) == 1:
80 | if values_s[0].strip() == "":
81 | #just skip the empty line...
82 | continue
83 |
84 | #assume last value is class label
85 | label = values_s[-1].strip()
86 | del values_s[-1]
87 |
88 | #read values...
89 | values = []
90 | for idx, att_type in enumerate(att_types):
91 | values.append( float( values_s[idx].strip() ) )
92 |
93 |
94 | #validate number of attributes on the sample
95 | if len(values) != n_atts:
96 | print("Number of values is different to number of attributes")
97 | print("Atts: " + str(n_atts))
98 | print("Values: " + str(len(values)))
99 | return None, None, None
100 |
101 | #add sample
102 | for att in xrange(n_atts):
103 | tempo_samples[count_samples, att] = values[att]
104 |
105 | count_samples += 1
106 |
107 | labels_l.append( label )
108 |
109 | n_samples = count_samples
110 | if n_samples != estimated_samples:
111 | samples = tempo_samples[:n_samples, :].copy()
112 | tempo_samples = None
113 | else:
114 | samples = tempo_samples
115 |
116 |
117 | return samples, labels_l, att_types
118 |
119 | except Exception as e:
120 | print("Error loading dataset from file")
121 | print( e )
122 | return None, None, None
123 |
124 |
125 |
126 | #==============================================
127 | # Generates a mapping for the unique set of
128 | # classes present on the given list of labels
129 | #==============================================
130 | def get_label_mapping(labels_l):
131 | classes_dict = {}
132 | classes_l = []
133 |
134 | #... for each sample...
135 | n_samples = len(labels_l)
136 | for i in xrange( n_samples ):
137 | label = labels_l[ i ]
138 |
139 | #check mapping of labels...
140 | if not label in classes_dict:
141 | #...add label to mapping...
142 | label_val = len( classes_l )
143 | classes_l.append( label )
144 |
145 | classes_dict[ label ] = label_val
146 |
147 | return classes_dict, classes_l
148 |
149 | def get_mapped_labels(labels_l, classes_dict):
150 | n_samples = len( labels_l )
151 |
152 | #...for the mapped labels...
153 | labels = np.zeros( (n_samples, 1), dtype = np.int32 )
154 |
155 | #... for each sample ...
156 | for i in xrange( n_samples ):
157 | #get current label...
158 | label = labels_l[ i ]
159 | #get mapped label...
160 | label_val = classes_dict[ label ]
161 | #add mapped label
162 | labels[ i, 0 ] = label_val
163 |
164 | return labels
165 |
166 |
167 | def load_ds_sources(file_name):
168 | try:
169 | #read all lines...
170 | file_source = open(file_name, 'r')
171 | lines = file_source.readlines()
172 | file_source.close()
173 |
174 | all_sources = []
175 |
176 | #get filename, symbol id per each symbol in DS
177 | for i in range(len(lines)):
178 | values_s = lines[i].split(',')
179 |
180 | #should contain 2 values...
181 | if len(values_s) != 2:
182 | print( "Invalid line <" + str(i) + "> in Auxiliary file: " + lines[i] )
183 | return None
184 | else:
185 | file_path = values_s[0].strip()
186 | sym_id = int( values_s[1] )
187 |
188 | all_sources.append( (file_path, sym_id) )
189 |
190 | return all_sources
191 |
192 | except Exception as e:
193 | print(e)
194 | return None
195 |
196 | def save_dataset(data, labels, att_types, out_file):
197 | n_samples = np.size(data, 0)
198 |
199 | try:
200 | out_file = open(out_file, 'w')
201 | except:
202 | print( "File <" + out_file + "> could not be created")
203 | return
204 |
205 | #...writing first header....
206 | content = ''
207 | #print as headers the types for each feature...
208 | n_atts = np.size(att_types, 0)
209 | for i in range(n_atts):
210 | if i > 0:
211 | content += '; '
212 |
213 | if att_types[i, 0] == 2:
214 | content += 'D'
215 | else:
216 | content += 'C'
217 |
218 | content += '\r\n'
219 | out_file.write(content)
220 |
221 | #...now, write the samples....
222 | content = ''
223 | for idx in range(n_samples):
224 | #...add the values...
225 | line = ''
226 | for k in range(n_atts):
227 | if k > 0:
228 | line += '; '
229 |
230 | line += str(data[idx, k])
231 |
232 | #... add the label...
233 | line += "; " + str(labels[idx, 0]) + "\r\n"
234 |
235 | content += line
236 | #....check if buffer is full!
237 | if len(content) >= 50000:
238 | out_file.write(content)
239 | content = ''
240 |
241 | #....write any remaining content
242 | out_file.write(content)
243 |
244 | out_file.close()
245 |
246 | def save_label_mapping(base_classes, extra_mapping, file_name):
247 | #...create....
248 | try:
249 | out_file = open(file_name, 'w')
250 | except:
251 | print( "File <" + file_name + "> could not be created")
252 | return
253 |
254 | #... First, write the sizes of both mappings
255 | content = str(len(base_classes)) + ";" + str(len(extra_mapping.keys())) + "\r\n"
256 | out_file.write(content)
257 |
258 | #...write the original class list...
259 | content = ''
260 | for c_class in base_classes:
261 | content += c_class + "\r\n"
262 | out_file.write(content)
263 |
264 | #...write now the extra mapping...
265 | content = ''
266 | for key in extra_mapping:
267 | content += str(key) + ";" + str(extra_mapping[key]) + "\r\n"
268 |
269 | out_file.write(content)
270 |
271 | #...close...
272 | out_file.close()
273 |
274 |
275 | def load_label_mapping(file_name):
276 | #...open....
277 | try:
278 | data_file = open(file_name, 'r')
279 | lines = data_file.readlines()
280 | data_file.close()
281 |
282 | #first line should contain the size of the mappings..
283 | sizes_s = lines[0].split(';')
284 | n_classes = int(sizes_s[0])
285 | n_mapped = int(sizes_s[1])
286 |
287 | print "...Loading class mapping..."
288 | print "N-real-classes: " + str(n_classes)
289 | print "N-virtual-classes: " + str(n_mapped)
290 |
291 | class_l = []
292 | class_dict = {}
293 | for i in range(n_classes):
294 | label = lines[1 + i].strip()
295 |
296 | class_l.append(label)
297 | class_dict[label] = i
298 |
299 | class_mapping = {}
300 | for i in range(n_mapped):
301 | #get the key -> value pair...
302 | mapped = lines[1 + n_classes + i].split(';')
303 |
304 | #...split...
305 | new_label = int(mapped[0])
306 | original_label = int(mapped[1])
307 |
308 | class_mapping[new_label] = original_label
309 |
310 | return class_l, class_dict, class_mapping
311 | except Exception as e:
312 | print(e)
313 | return None, None, None
314 |
315 |
316 | def append_symbols(symbols, out_file):
317 | n_samples = len(symbols)
318 |
319 | print("...adding samples " + str(n_samples) + " to output file...")
320 |
321 | n_atts = 0
322 | content = ''
323 | for idx, symbol in enumerate(symbols):
324 | sample = symbol.getFeatures() + [ symbol.truth ]
325 | if idx == 0:
326 | n_atts = len(sample) - 1
327 |
328 | line = ''
329 | for i, v in enumerate(sample):
330 | if i > 0:
331 | line += '; '
332 |
333 | line += str(v)
334 |
335 | line += '\r\n'
336 | content += line
337 |
338 | #check if buffer is full!
339 | if len(content) >= 50000:
340 | out_file.write(content)
341 | content = ''
342 |
343 | #write any remaining content
344 | out_file.write(content)
345 |
346 | print("... samples added to file successfully!")
347 |
348 | return n_atts
349 |
350 |
351 | def append_dataset(data, labels, out_file):
352 | #...will append samples in dataset to output file...
353 | n_samples = np.size(data, 0)
354 | n_atts = np.size(data, 1)
355 |
356 | print("...adding samples " + str(n_samples) + " to output file...")
357 |
358 | #...now, write the samples....
359 | content = ''
360 | for idx in range(n_samples):
361 | #...add the values...
362 | line = ''
363 | for k in range(n_atts):
364 | if k > 0:
365 | line += '; '
366 |
367 | line += str(data[idx, k])
368 |
369 | #... add the label...
370 | line += "; " + str(labels[idx, 0]) + "\r\n"
371 |
372 | content += line
373 | #....check if buffer is full!
374 | if len(content) >= 50000:
375 | out_file.write(content)
376 | content = ''
377 |
378 | #....write any remaining content
379 | out_file.write(content)
380 |
381 | print("... samples added to file successfully!")
382 |
383 |
384 | def append_dataset_string_labels(data, labels, out_file):
385 | #...will append samples in dataset to output file...
386 | n_samples = np.size(data, 0)
387 | n_atts = np.size(data, 1)
388 |
389 | print("...adding samples " + str(n_samples) + " to output file...")
390 |
391 | #...now, write the samples....
392 | content = ''
393 | for idx in range(n_samples):
394 | #...add the values...
395 | line = ''
396 | for k in range(n_atts):
397 | if k > 0:
398 | line += '; '
399 |
400 | line += str(data[idx, k])
401 |
402 | #... add the label...
403 | line += "; " + labels[idx] + "\r\n"
404 |
405 | content += line
406 | #....check if buffer is full!
407 | if len(content) >= 50000:
408 | out_file.write(content)
409 | content = ''
410 |
411 | #....write any remaining content
412 | out_file.write(content)
413 |
414 | print("... samples added to file successfully!")
415 |
416 |
417 | def save_dataset_string_labels(data, labels, att_types, out_filename):
418 | n_samples = np.size(data, 0)
419 |
420 | try:
421 | out_file = open(out_filename, 'w')
422 | except:
423 | print( "File <" + out_filename + "> could not be created")
424 | return
425 |
426 | #...writing first header....
427 | content = ''
428 | #print as headers the types for each feature...
429 | n_atts = np.size(att_types, 0)
430 | for i in range(n_atts):
431 | if i > 0:
432 | content += '; '
433 |
434 | if att_types[i, 0] == 2:
435 | content += 'D'
436 | else:
437 | content += 'C'
438 |
439 | content += '\r\n'
440 | out_file.write(content)
441 |
442 | #...now, write the samples....
443 | content = ''
444 | for idx in range(n_samples):
445 | #...add the values...
446 | line = ''
447 | for k in range(n_atts):
448 | if k > 0:
449 | line += '; '
450 |
451 | line += str(data[idx, k])
452 |
453 | #... add the label...
454 | line += "; " + str(labels[idx]) + "\r\n"
455 |
456 | content += line
457 | #....check if buffer is full!
458 | if len(content) >= 50000:
459 | out_file.write(content)
460 | content = ''
461 |
462 | #....write any remaining content
463 | out_file.write(content)
464 |
465 | out_file.close()
466 |
467 |
468 | #====================================================
469 | # Data Normalization: Centering and scaling
470 | #====================================================
471 |
472 | #----------------------------------------------------
473 | # Function that computes normalization parameters
474 | # and applies them to given data returning a
475 | # normalized version of it
476 | #----------------------------------------------------
477 | def normalize_data(data):
478 | params = []
479 | n_samples = np.size(data,0)
480 | n_atts = np.size(data, 1)
481 | new_data = np.zeros((n_samples, n_atts))
482 |
483 | #for each attribute...
484 | for att in range(n_atts):
485 | #the mean and std_dev for each att
486 | cut = data[:, att]
487 | cut_mean = cut.mean()
488 | cut_std = cut.std()
489 |
490 | #add to parameters list...
491 | params.append((cut_mean, cut_std))
492 |
493 | #now normalize...
494 | if cut_std > 0.0:
495 | new_data[:, att] = (data[:, att] - cut_mean) / cut_std
496 | else:
497 | #only center...
498 | new_data[:, att] = (data[:, att] - cut_mean)
499 |
500 | return new_data, params
501 |
502 | #------------------------------------------------------
503 | # Function that receives a dataset and normalization
504 | # parameters, and applies them returning a normalized
505 | # version of the data
506 | #------------------------------------------------------
507 | def normalize_data_from_params(data, params):
508 | n_samples = np.size(data,0)
509 | n_atts = np.size(data, 1)
510 | new_data = np.zeros((n_samples, n_atts))
511 |
512 | #for each attribute...
513 | for att in range(n_atts):
514 | cut_mean, cut_std = params[att]
515 |
516 | #now normalize...
517 | if cut_std > 0.0:
518 | new_data[:, att] = (data[:, att] - cut_mean) / cut_std
519 | else:
520 | #only center...
521 | new_data[:, att] = (data[:, att] - cut_mean)
522 |
523 | return new_data
524 |
525 |
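A minimal sketch of the semicolon-separated format that load_dataset() above expects, together with the label-mapping helpers; the file name and values are hypothetical. The header row gives one attribute type per feature ('D' = discrete, anything else continuous) and each sample row ends with its class label.

from dataset_ops import load_dataset, get_label_mapping, get_mapped_labels

with open("tiny_ds.txt", "w") as tmp:
    tmp.write("C; C; D\n")              # attribute types: two continuous, one discrete
    tmp.write("0.5; 1.25; 3; x\n")      # three feature values, then the label
    tmp.write("0.1; -2.0; 1; y\n")

data, labels_l, att_types = load_dataset("tiny_ds.txt")
classes_dict, classes_l = get_label_mapping(labels_l)
labels = get_mapped_labels(labels_l, classes_dict)

print(data.shape)      # (2, 3)
print(classes_l)       # ['x', 'y']
print(labels[:, 0])    # [0 1]
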
--------------------------------------------------------------------------------
/src/distorter.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 | import numpy as np
25 | import ctypes
26 | import scipy.ndimage as ndimage
27 | from mathSymbol import *
28 |
29 | #=====================================================================
30 | # Generates distorted versions of samples
31 | #
32 | # Created by:
33 | # - Kenny Davila (Jan 14, 2012-2014)
34 | # Modified By:
35 | # - Kenny Davila (Jan 14, 2012-2014)
36 | # - Kenny Davila (Jan 20, 2012-2014)
37 | #
38 | #=====================================================================
39 |
40 | distorter_lib = ctypes.CDLL('./distorter_lib.so')
41 | distorter_lib.distorter_init()
42 |
43 |
44 | class Distorter:
45 | def __init__(self):
46 | self.noise_size = 64 # 128
47 | self.noise_map_size = 8
48 | self.noise_map_depth = 6
49 |
50 | def getBoundingBox(self, points):
51 | min_x = points[0][0]
52 | max_x = points[0][0]
53 | min_y = points[0][1]
54 | max_y = points[0][1]
55 |
56 | for i in range(1, len(points)):
57 | x, y = points[i]
58 |
59 | if x < min_x:
60 | min_x = x
61 | if x > max_x:
62 | max_x = x
63 | if y < min_y:
64 | min_y = y
65 | if y > max_y:
66 | max_y = y
67 |
68 | return (min_x, max_x, min_y, max_y)
69 |
70 | def distortPoints(self, points, max_diagonal):
71 | new_points = []
72 |
73 | #check original box...
74 | min_x, max_x, min_y, max_y = self.getBoundingBox(points)
75 | w = max_x - min_x
76 | h = max_y - min_y
77 |
78 | if w == 0.0:
79 | w = 0.000001
80 |
81 | if h == 0.0:
82 | h = 0.000001
83 |
84 | line_dist = math.sqrt(w ** 2 + h ** 2) * max_diagonal
85 |
86 | #Get noise maps...
87 | #...x....
88 | #noise_x = self.getNoiseMap(self.noise_size, self.noise_map_size, self.noise_map_depth)
89 | noise_x = np.zeros((self.noise_size, self.noise_size))
90 | noise_x_p = noise_x.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
91 | distorter_lib.distorter_create_noise_map(noise_x_p, self.noise_size, self.noise_map_size, self.noise_map_depth)
92 | #...smooth x....
93 | noise_x = ndimage.gaussian_filter(noise_x, sigma=(2, 2), order=0)
94 |
95 | #...y....
96 | #noise_y = self.getNoiseMap(self.noise_size, self.noise_map_size, self.noise_map_depth)
97 | noise_y = np.zeros((self.noise_size, self.noise_size))
98 | noise_y_p = noise_y.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
99 | distorter_lib.distorter_create_noise_map(noise_y_p, self.noise_size, self.noise_map_size, self.noise_map_depth)
100 | #...smooth y....
101 | noise_y = ndimage.gaussian_filter(noise_y, sigma=(2, 2), order=0)
102 |
103 | #now, apply distortion....
104 | #...for each point....
105 | for i in range(len(points)):
106 | x, y = points[i]
107 |
108 | #compute relative position...
109 | p_x = (x - min_x) / w
110 | p_y = (y - min_y) / h
111 |
112 | #compute noise position...
113 | p_nx = int(p_x * self.noise_size)
114 | p_ny = int(p_y * self.noise_size)
115 | if p_nx >= self.noise_size:
116 | p_nx = self.noise_size - 1
117 | if p_ny >= self.noise_size:
118 | p_ny = self.noise_size - 1
119 |
120 | #get distortion...
121 | dist_x = noise_x[p_nx, p_ny]
122 | dist_y = noise_y[p_nx, p_ny]
123 |
124 |
125 | #print "Pre: (" + str(dist_x) + ", " + str(dist_y) + ")"
126 | """
127 | #normalize...
128 | norm = math.sqrt(dist_x * dist_x + dist_y * dist_y)
129 | if norm > 0.0:
130 | dist_x /= norm
131 | dist_y /= norm
132 | else:
133 | dist_x = 0.0
134 | dist_y = 0.0
135 | """
136 | #print "Post: (" + str(dist_x) + ", " + str(dist_y) + ")"
137 |
138 |
139 | off_x = dist_x * line_dist
140 | off_y = dist_y * line_dist
141 |
142 | new_x = x + off_x
143 | new_y = y + off_y
144 |
145 | new_points.append((new_x, new_y))
146 |
147 | return new_points
148 |
149 | #create a distorted version of the given symbol...
150 | def distortSymbol(self, symbol, max_diagonal):
151 |
152 | #create distorted traces...
153 | all_traces = []
154 | for t in symbol.traces:
155 | new_points = self.distortPoints(t.original_points, max_diagonal)
156 |
157 | new_trace = TraceInfo(t.id, new_points)
158 |
159 | #try smoothing the distorted version...
160 | new_trace.removeDuplicatedPoints()
161 |
162 | #Add points to the trace...
163 | new_trace.addMissingPoints()
164 |
165 | #Apply smoothing to the trace...
166 | new_trace.applySmoothing()
167 |
168 | #it should not ... but .....
169 | if new_trace.hasDuplicatedPoints():
170 | #...remove them! ....
171 | new_trace.removeDuplicatedPoints()
172 |
173 | all_traces.append(new_trace)
174 |
175 | #now, create the new symbol...
176 | new_symbol = MathSymbol(symbol.id, all_traces, symbol.truth)
177 | new_symbol.normalize()
178 |
179 | return new_symbol
180 |
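A brief usage sketch for the Distorter above. It assumes distorter_lib.so has already been compiled into the working directory (the module loads it at import time) and that 'sample.inkml' is a hypothetical CROHME file readable by load_inkml().

from load_inkml import load_inkml
from distorter import Distorter

symbols = load_inkml("sample.inkml", True)
distorter = Distorter()

# point offsets are scaled by 5% of the symbol's bounding-box diagonal
warped = distorter.distortSymbol(symbols[0], 0.05)
print(warped.truth, len(warped.traces))
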
--------------------------------------------------------------------------------
/src/distorter_lib.c:
--------------------------------------------------------------------------------
1 | /*
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | */
24 |
25 | //Compile using:
26 | // gcc -shared -fPIC distorter_lib.c -o distorter_lib.so
27 |
28 | #include <stdio.h>
29 | #include <stdlib.h>
30 | #include <string.h>
31 | #include <math.h>
32 | #include <time.h>
33 |
34 | void distorter_init(){
35 | //Use current time as a seed
36 | srand(time(0));
37 | }
38 |
39 | double distorter_rand(){
40 | return ((double)rand() / (double)RAND_MAX);
41 | }
42 |
43 | double distorter_mirror(double value){
44 |
45 | double i_value = floor(value);
46 |
47 | if ( ((int)i_value) % 2 == 0 ){
48 | //Not mirror
49 | value -= i_value;
50 | } else {
51 | //Mirror
52 | value = 1.0 - (value - i_value);
53 | }
54 |
55 | return value;
56 | }
57 |
58 | double distorter_bilinear_filtering(double* map, double x, double y, int n_rows, int n_cols){
59 | if ( x < 0.0){
60 | x = 0.0;
61 | }
62 | if ( y < 0.0){
63 | y = 0.0;
64 | }
65 | //...row (y) and its weights...
66 | double v_row = y * (n_rows - 1);
67 | int r0 = (int)floor(v_row);
68 | double r_w1;
69 | if ( r0 >= n_rows - 1 ){
70 | r_w1 = 1.0;
71 | r0 = n_rows - 2;
72 | } else {
73 | r_w1 = v_row - floor(v_row);
74 | }
75 | int r1 = r0 + 1;
76 |
77 | //...col (x) and its weights
78 | double v_col = x * (n_cols - 1);
79 | int c0 = (int)floor(v_col);
80 | double c_w1;
81 | if (c0 >= n_cols - 1){
82 | c_w1 = 1.0;
83 | c0 = n_cols - 2;
84 | }else{
85 | c_w1 = v_col - c0;
86 | }
87 | int c1 = c0 + 1;
88 |
89 | //Bilinear filtering...
90 | double final_val = map[r0 * n_cols + c0] * (1.0 - r_w1) * (1.0 - c_w1) +
91 | map[r1 * n_cols + c0] * r_w1 * (1.0 - c_w1) +
92 | map[r0 * n_cols + c1] * (1.0 - r_w1) * c_w1 +
93 | map[r1 * n_cols + c1] * r_w1 * c_w1;
94 |
95 | return final_val;
96 | }
97 |
98 | void distorter_create_noise_map(double* map_buffer, int map_n, int noise_n, int max_d){
99 | int i, row, col;
100 |
101 | //Create random noise...
102 | int t_noise = noise_n * noise_n;
103 | double* c_map = (double *)calloc(t_noise, sizeof(double));
104 |
105 | //...Set a random value per element...
106 | double* p_map = c_map;
107 | for (i = 0; i < t_noise; i++ ){
108 | *p_map = distorter_rand() * 2.0 - 1.0;
109 | p_map++;
110 | }
111 |
112 | //..also random displacements...
113 | int t_disp = 2 * max_d;
114 | double* disp = (double *)calloc(2 * max_d, sizeof(double));
115 | double* p_disp = disp;
116 | for ( i = 0; i < t_disp; i++){
117 | *p_disp = distorter_rand();
118 | p_disp++;
119 | }
120 |
121 | //Combine the maps according to the Perlin noise algorithm...
122 | for (row = 0; row < map_n; row++ ){
123 | for (col = 0; col < map_n; col++ ){
124 | map_buffer[row * map_n + col] = 0.0;
125 | }
126 | }
127 |
128 | //... for each size of the map...
129 | double w, p_r, p_c, c_pow;
130 | for ( i = 0; i < max_d; i++){
131 | double dr = disp[ i * 2 ];
132 | double dc = disp[ i * 2 + 1 ];
133 |
134 | if ( i < max_d - 1 ){
135 | w = 1.0 / pow(2, i + 1);
136 | } else {
137 | w = 1.0 / pow(2, i);
138 | }
139 |
140 | c_pow = pow(2, i);
141 |
142 | //...for each pixel....
143 | for ( row = 0; row < map_n; row++){
144 | p_r = ((double)row / (double)map_n) * c_pow;
145 | //...mirror mapping...
146 | p_r = distorter_mirror(p_r + dr);
147 |
148 | for ( col = 0; col < map_n; col++){
149 | p_c = ((double)col /(double)map_n) * c_pow;
150 | //...mirror mapping...
151 | p_c = distorter_mirror(p_c + dc);
152 |
153 | map_buffer[row * map_n + col] += w * distorter_bilinear_filtering(c_map, p_c, p_r, noise_n, noise_n);
154 | }
155 | }
156 | }
157 |
158 | //...Release allocated memory....
159 | free( disp );
160 | free( c_map );
161 |
162 | }
163 |
164 | /*
165 | def getBoundingBox(self, points):
166 | min_x = points[0][0]
167 | max_x = points[0][0]
168 | min_y = points[0][1]
169 | max_y = points[0][1]
170 |
171 | for i in range(1, len(points)):
172 | x, y = points[i]
173 |
174 | if x < min_x:
175 | min_x = x
176 | if x > max_x:
177 | max_x = x
178 | if y < min_y:
179 | min_y = y
180 | if y > max_y:
181 | max_y = y
182 |
183 | return (min_x, max_x, min_y, max_y)
184 |
185 | def distortPoints(self, points, max_diagonal):
186 | new_points = []
187 |
188 | #check original box...
189 | min_x, max_x, min_y, max_y = self.getBoundingBox(points)
190 | w = max_x - min_x
191 | h = max_y - min_y
192 |
193 | if w == 0.0:
194 | w = 0.000001
195 |
196 | if h == 0.0:
197 | h = 0.000001
198 |
199 | line_dist = math.sqrt(w ** 2 + h ** 2) * max_diagonal
200 |
201 | #Get noise maps...
202 | #...x....
203 | noise_x = self.getNoiseMap(self.noise_size, self.noise_map_size, self.noise_map_depth)
204 | #...smooth x....
205 | noise_x = ndimage.gaussian_filter(noise_x, sigma=(2, 2), order=0)
206 | #...y....
207 | noise_y = self.getNoiseMap(self.noise_size, self.noise_map_size, self.noise_map_depth)
208 | #...smooth y....
209 | noise_y = ndimage.gaussian_filter(noise_y, sigma=(2, 2), order=0)
210 |
211 | #now, apply distortion....
212 | #...for each point....
213 | for i in range(len(points)):
214 | x, y = points[i]
215 |
216 | #compute relative position...
217 | p_x = (x - min_x) / w
218 | p_y = (y - min_y) / h
219 |
220 | #compute noise position...
221 | p_nx = int(p_x * self.noise_size)
222 | p_ny = int(p_y * self.noise_size)
223 | if p_nx >= self.noise_size:
224 | p_nx = self.noise_size - 1
225 | if p_ny >= self.noise_size:
226 | p_ny = self.noise_size - 1
227 |
228 | #get distortion...
229 | dist_x = noise_x[p_nx, p_ny]
230 | dist_y = noise_y[p_nx, p_ny]
231 |
232 | #normalize...
233 | norm = math.sqrt(dist_x * dist_x + dist_y * dist_y)
234 | if norm > 0.0:
235 | dist_x /= norm
236 | dist_y /= norm
237 | else:
238 | dist_x = 0.0
239 | dist_y = 0.0
240 |
241 | off_x = dist_x * line_dist
242 | off_y = dist_y * line_dist
243 |
244 | new_x = x + off_x
245 | new_y = y + off_y
246 |
247 | new_points.append((new_x, new_y))
248 |
249 | return new_points
250 |
251 | #create a distorted version of the given symbol...
252 | def distortSymbol(self, symbol, max_diagonal):
253 |
254 | #create distorted traces...
255 | all_traces = []
256 | for t in symbol.traces:
257 | new_points = self.distortPoints(t.original_points, max_diagonal)
258 |
259 | new_trace = TraceInfo(t.id, new_points)
260 |
261 | #try smoothing the distorted version...
262 | new_trace.removeDuplicatedPoints()
263 |
264 | #Add points to the trace...
265 | new_trace.addMissingPoints()
266 |
267 | #Apply smoothing to the trace...
268 | new_trace.applySmoothing()
269 |
270 | #it should not ... but .....
271 | if new_trace.hasDuplicatedPoints():
272 | #...remove them! ....
273 | new_trace.removeDuplicatedPoints()
274 |
275 | all_traces.append(new_trace)
276 |
277 | #now, create the new symbol...
278 | new_symbol = MathSymbol(symbol.id, all_traces, symbol.truth)
279 | new_symbol.normalize()
280 |
281 | return new_symbol
282 |
283 | */
284 |
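A minimal ctypes sketch that exercises the C noise-map generator directly, mirroring how distorter.py uses it; it assumes the library has been built as distorter_lib.so in the working directory (for example with gcc -shared -fPIC distorter_lib.c -o distorter_lib.so).

import ctypes
import numpy as np

lib = ctypes.CDLL("./distorter_lib.so")
lib.distorter_init()                      # seed the C random generator

size = 64
noise = np.zeros((size, size))
ptr = noise.ctypes.data_as(ctypes.POINTER(ctypes.c_double))

# arguments: output buffer, output size, base noise grid size, number of octaves
lib.distorter_create_noise_map(ptr, size, 8, 6)
print(noise.min(), noise.max())           # values roughly in [-1, 1]
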
--------------------------------------------------------------------------------
/src/evaluation_ops.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 | import numpy as np
25 |
26 | #=====================================================================
27 | # General functions used for evaluation
28 | #
29 | # Created by:
30 | # - Kenny Davila (Feb 6, 2012-2014)
31 | # Modified By:
32 | # - Kenny Davila (Feb 6, 2012-2014)
33 | # - Kenny Davila (Feb 23, 2012-2014)
34 | # - changes to handle ambiguous classes
35 | # - output confusion matrices
36 | # - Kenny Davila (Feb 25, 2012-2014)
37 | # - Add Top-N evaluation functions
38 | #
39 | #=====================================================================
40 |
41 |
42 | def get_average_class_accuracy(counts_per_class, errors_per_class, n_classes):
43 | #now, only consider classes that existed in testing set...
44 | valid_classes = counts_per_class > 0
45 |
46 | all_accuracies = 1.0 - errors_per_class[valid_classes] / counts_per_class[valid_classes]
47 |
48 | avg_accuracy = all_accuracies.mean()
49 | std_accuracy = all_accuracies.std()
50 |
51 | return avg_accuracy, std_accuracy
52 |
53 | def compute_error_counts(predicted, labels, n_classes):
54 | counts_per_class = np.zeros(n_classes)
55 | errors_per_class = np.zeros(n_classes)
56 | total_correct = 0
57 | n_samples = np.size(labels, 0)
58 | for k in range(n_samples):
59 | expected = labels[k, 0]
60 |
61 | counts_per_class[expected] += 1
62 |
63 | if predicted[k] == expected:
64 | total_correct += 1
65 | else:
66 | errors_per_class[expected] += 1
67 |
68 | return total_correct, counts_per_class, errors_per_class
69 |
70 |
71 | def compute_topn_error_counts(predicted, labels, n_classes, top_n):
72 | counts_per_class = np.zeros(n_classes)
73 | errors_per_class = np.zeros(n_classes)
74 | total_correct = 0
75 | n_samples = np.size(labels, 0)
76 | for k in range(n_samples):
77 | expected = labels[k, 0]
78 |
79 | counts_per_class[expected] += 1
80 |
81 | found = False
82 | for n in range(top_n):
83 | if predicted[k, n] == expected:
84 | found = True
85 | break
86 |
87 | if found:
88 | total_correct += 1
89 | else:
90 | errors_per_class[expected] += 1
91 |
92 | return total_correct, counts_per_class, errors_per_class
93 |
94 |
95 | def compute_confusion_matrix(predicted, labels, n_classes):
96 | confusion_matrix = np.zeros((n_classes, n_classes), dtype=np.int32)
97 |
98 | n_test_samples = np.size(labels, 0)
99 |
100 | #... for each sample
101 | for k in range(n_test_samples):
102 | expected = int(labels[k, 0])
103 |
104 | confusion_matrix[expected, int(predicted[k])] += 1
105 |
106 | return confusion_matrix
107 |
108 |
109 | def compute_ambiguous_confusion_matrix(all_predicted, labels, n_classes, ambiguous):
110 | confusion_matrix = np.zeros((n_classes, n_classes), dtype=np.int32)
111 |
112 | n_test_samples = np.size(labels, 0)
113 |
114 | #... for each sample
115 | for k in range(n_test_samples):
116 | expected = int(labels[k, 0])
117 | predicted = int(all_predicted[k])
118 |
119 | if ambiguous[expected, predicted] == 1:
120 | #no error will be reported between ambiguous classes
121 | confusion_matrix[expected, expected] += 1
122 | else:
123 | #usual error condition
124 | confusion_matrix[expected, predicted] += 1
125 |
126 | return confusion_matrix
127 |
128 | def save_evaluation_results(classes_l, confusion_matrix, file_name, ambiguous):
129 |
130 | n_classes = np.size(confusion_matrix, 0)
131 |
132 | per_class_correct = np.zeros(n_classes)
133 | per_class_counts = confusion_matrix.sum(1)
134 |
135 | class_avg_list = []
136 | for i in range(n_classes):
137 | per_class_correct[i] = confusion_matrix[i, i]
138 |
139 | if per_class_counts[i] > 0:
140 | class_acc = per_class_correct[i] / per_class_counts[i]
141 |
142 | class_avg_list.append((class_acc, i, classes_l[i]))
143 |
144 | #...sort the list...
145 | class_avg_list = sorted(class_avg_list,reverse=True)
146 |
147 | #now, only consider classes that existed in testing set...
148 | valid_classes = per_class_counts > 0
149 | per_class_acc = per_class_correct[valid_classes] / per_class_counts[valid_classes]
150 |
151 | n_test_samples = confusion_matrix.sum()
152 | total_correct = per_class_correct.sum()
153 |
154 | accuracy = (float(total_correct) / float(n_test_samples)) * 100.0
155 | per_class_avg = per_class_acc.mean() * 100.0
156 | per_class_std = per_class_acc.std() * 100.0
157 |
158 | #.... save confusion matrix to file...
159 | out_str = "Samples:," + str(n_test_samples) + "\n"
160 | out_str += "N Classes:," + str(n_classes) + "\n"
161 | out_str += "Correct:," + str(total_correct) + "\n"
162 | out_str += "Wrong:," + str(n_test_samples - total_correct) + "\n"
163 | out_str += "Accuracy:," + str(accuracy) + "\n"
164 | out_str += "Per class AVG:, " + str(per_class_avg) + "\n"
165 | out_str += "Per class STD:, " + str(per_class_std) + "\n"
166 | out_str += "\n\n"
167 | out_str += "Rows:,Expected \n"
168 | out_str += "Column:,Predicted \n\n"
169 | out_str += "Full Matrix\n"
170 | out_str += "X"
171 | only_err = "X"
172 | #...build header...
173 | for i in range(n_classes ):
174 | c_class = classes_l[i]
175 | if c_class == ",":
176 | c_class = "COMMA"
177 |
178 | out_str += "," + c_class
179 | only_err += "," + c_class
180 |
181 | out_str += ",Total,Train Count\n"
182 | only_err += ",Total\n"
183 |
184 | #...build the content...
185 | for i in range(n_classes):
186 | c_class = classes_l[i]
187 | if c_class == ",":
188 | c_class = "COMMA"
189 |
190 | out_str += c_class
191 | only_err += c_class
192 |
193 | t_errs = 0
194 | t_samples = 0
195 | for k in range(n_classes):
196 | out_str += "," + str(confusion_matrix[i, k])
197 |
198 | t_samples += confusion_matrix[i, k]
199 | if i == k:
200 | only_err += ",0"
201 | else:
202 | t_errs += confusion_matrix[i, k]
203 | only_err += "," + str(confusion_matrix[i, k])
204 |
205 | out_str += "," + str(t_samples) + "\n"
206 | only_err += "," + str(t_errs) + "\n"
207 |
208 | #...build the totals row...
209 | out_str += "Total"
210 | only_err += "Total"
211 | for i in range(n_classes):
212 | t_errs = 0
213 | t_samples = 0
214 |
215 | for k in range(n_classes):
216 | #...inverted, k = row, i = column
217 | t_samples += confusion_matrix[k, i]
218 | if i != k:
219 | t_errs += confusion_matrix[k, i]
220 |
221 | out_str += "," + str(t_samples)
222 | only_err += "," + str(t_errs)
223 |
224 | out_str += "\n\n\nOnly Errors\n\n" + only_err + "\n\n"
225 |
226 | #now, add the class average per class....
227 | #...also, if list of ambiguous is available... mark them!
228 | if not ambiguous is None:
229 | ambiguous_list = ambiguous.sum(0)
230 |
231 | out_str += "idx,Class,Accuracy,Ambiguous\n"
232 | else:
233 | out_str += "idx,Class,Accuracy\n"
234 |
235 | for class_acc, idx, class_name in class_avg_list:
236 | if class_name == ",":
237 | class_name = "COMMA"
238 |
239 | out_str += str(idx) + "," + class_name + "," + str(class_acc)
240 | if not ambiguous is None:
241 | if ambiguous_list[idx] > 0:
242 | out_str += ",1"
243 | else:
244 | out_str += ",0"
245 | out_str += "\n"
246 |
247 | try:
248 | f = open(file_name, 'w')
249 | f.write(out_str)
250 | f.close()
251 | except:
252 | print("ERROR WRITING RESULTS TO FILE! <" + file_name + ">")
253 |
254 |
255 | def load_ambiguous(filename, class_dict, skip_failures):
256 | #first, load all the content from the file...
257 | file_ambiguous = open(filename, 'r')
258 | all_ambiguous_lines = file_ambiguous.readlines()
259 | file_ambiguous.close()
260 |
261 | n_classes = len(class_dict.keys())
262 |
263 | total_ambiguous = 0
264 | ambiguous = np.zeros((n_classes, n_classes), dtype=np.int32)
265 | for idx, line in enumerate(all_ambiguous_lines):
266 | #split and check...
267 | parts = line.split(',')
268 | if len(parts) != 2:
269 | print("Invalid content found in ambiguous file!")
270 | print("Line " + str(idx + 1) + ": " + line)
271 | return None
272 |
273 | grapheme_1 = parts[0].strip()
274 | grapheme_2 = parts[1].strip()
275 |
276 | #check if valid ...
277 | if not grapheme_1 in class_dict.keys():
278 | print("Invalid class found: " + str(grapheme_1) + ".")
279 | print("Line " + str(idx + 1) + ": " + line)
280 |
281 | if skip_failures:
282 | continue
283 | else:
284 | return
285 |
286 | if not grapheme_2 in class_dict.keys():
287 | print("Invalid class found: " + str(grapheme_2) + ".")
288 | print("Line " + str(idx + 1) + ": " + line)
289 |
290 | if skip_failures:
291 | continue
292 | else:
293 | return
294 |
295 | #...they are valid, now build the dictionary but using their mapped versions
296 | class1 = class_dict[grapheme_1]
297 | class2 = class_dict[grapheme_2]
298 |
299 | #...mark them as ambiguous (symmetric relation)
300 | if ambiguous[class1, class2] == 0:
301 | ambiguous[class1, class2] = 1
302 | ambiguous[class2, class1] = 1
303 |
304 | total_ambiguous += 1
305 |
306 |
307 | return ambiguous, total_ambiguous
308 |
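A small sketch of the evaluation helpers above on made-up predictions; labels are an (n, 1) array of mapped class indices and predictions a flat array of the same length.

import numpy as np
from evaluation_ops import (compute_error_counts, compute_confusion_matrix,
                            get_average_class_accuracy)

labels = np.array([[0], [0], [1], [2], [2]])
predicted = np.array([0, 1, 1, 2, 0])

correct, counts, errors = compute_error_counts(predicted, labels, 3)
conf_matrix = compute_confusion_matrix(predicted, labels, 3)
avg_acc, std_acc = get_average_class_accuracy(counts, errors, 3)

print(correct)        # 3 of 5 samples correct
print(conf_matrix)    # rows = expected class, columns = predicted class
print(avg_acc)        # mean per-class accuracy, ~0.667 here
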
--------------------------------------------------------------------------------
/src/extract_symbol.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 | import sys
25 | from load_inkml import *
26 |
27 | #=====================================================================
28 | # Extracts a symbol from a given inkml file using a symbol id, and
29 | # then outputs the symbol to an SVG file
30 | #
31 | # Created by:
32 | # - Kenny Davila (Feb 11, 2012-2014)
33 | # Modified By:
34 | # - Kenny Davila (Feb 11, 2012-2014)
35 | #
36 | #=====================================================================
37 |
38 | def output_sample( file_path, sym_id, out_path ):
39 | #load file...
40 | symbols = load_inkml(file_path, True)
41 |
42 | #find symbol ...
43 | for symbol in symbols:
44 | if symbol.id == sym_id:
45 | print("Symbol found, class: " + symbol.truth)
46 | symbol.saveAsSVG( out_path )
47 |
48 | return True
49 |
50 | return False
51 |
52 | def main():
53 | #usage check
54 | if len(sys.argv) != 4:
55 | print("Usage: python extract_symbol.py inkml_file sym_id output")
56 | print("Where")
57 | print("\tinkml_file\t= Path to the inkml file that contains the symbol")
58 | print("\tsym_id\t\t= Id of the symbol to extract")
59 | print("\toutput\t\t= File where the extracted symbol will be stored")
60 | return
61 |
62 | file_path = sys.argv[1]
63 | sym_id = int(sys.argv[2])
64 | output_path = sys.argv[3]
65 |
66 | if output_sample(file_path, sym_id, output_path):
67 | print("Sample extracted successfully!")
68 | else:
69 | print("Sample was not found in the given file!")
70 |
71 | main()
72 |
--------------------------------------------------------------------------------
/src/get_PCA_parameters.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 |
25 | import sys
26 | from sklearn.preprocessing import StandardScaler
27 | import numpy as np
28 | import pickle
29 | from dataset_ops import *
30 |
31 | #=====================================================================
32 | # Finds the PCA parameters for a given training set. The learned
33 | # parameters can later be applied for dimensionality reduction
34 | # of the dataset
35 | #
36 | # Created by:
37 | # - Kenny Davila (Feb 1, 2012-2014)
38 | # Modified By:
39 | # - Kenny Davila (Feb 1, 2012-2014)
40 | #
41 | #=====================================================================
42 |
43 |
44 | def get_PCA( training_set ):
45 | #get the covariance matrix...
46 | cov_matrix = np.cov(training_set.transpose())
47 |
48 | #...obtain eigenvectors and eigenvalues...
49 | eig_values, eig_matrix = np.linalg.eig(cov_matrix)
50 |
51 | #...sort them....
52 | pair_list = [(eig_values[i], i) for i in range(cov_matrix.shape[0])]
53 | pair_list = sorted( pair_list, reverse=True )
54 |
55 | sorted_values = []
56 | sorted_vectors = []
57 |
58 | for eigenvalue, idx in pair_list:
59 | sorted_values.append(eigenvalue)
60 | sorted_vectors.append(np.mat(eig_matrix[:, idx]).T)
61 |
62 | return sorted_values, sorted_vectors
63 |
64 |
65 | def get_variance_K(values, variance):
66 | #...get the variance percentages...
67 | n_atts = len(values)
68 | total_variance = 0.0
69 | cumulative_variance = []
70 | for i in xrange(n_atts):
71 | eigenvalue = abs(values[i])
72 |
73 | total_variance += eigenvalue
74 | cumulative_variance.append(total_variance)
75 |
76 | #...normalize...
77 | for i in xrange(n_atts):
78 | cumulative_variance[i] /= total_variance
79 |
80 | if cumulative_variance[i] >= variance:
81 | return i + 1, cumulative_variance
82 |
83 | return n_atts, cumulative_variance
84 |
85 |
86 | def save_PCA_parameters(file_name, normalization, pca_vector, pca_k):
87 | file_params = open(file_name, 'w')
88 | pickle.dump(normalization, file_params)
89 | pickle.dump(pca_vector, file_params)
90 | pickle.dump(pca_k, file_params)
91 | file_params.close()
92 |
93 |
94 | def main():
95 | #usage check
96 | if len(sys.argv) != 5:
97 | print("Usage: python get_PCA_parameters.py training_set output_params var_max k_max")
98 | print("Where")
99 | print("\ttraining_set\t= Path to the file of the training_set")
100 | print("\toutput_params\t= File to output PCA preprocessing parameters")
101 | print("\tvar_max\t\t= Maximum percentage of variance to add")
102 | print("\tk_max\t\t= Maximum number of Principal Components to Add")
103 | return
104 |
105 | #get parameters
106 | input_filename = sys.argv[1]
107 | output_filename = sys.argv[2]
108 |
109 | try:
110 | var_max = float(sys.argv[3])
111 | if var_max <= 0.0 or var_max > 1.0:
112 | print("Invalid var_max value! ")
113 | return
114 | except:
115 | print("Invalid var_max value! ")
116 | return
117 |
118 | try:
119 | k_max = int(sys.argv[4])
120 | if k_max <= 0:
121 | print("Invalid k_max value!")
122 | return
123 | except:
124 | print("Invalid k_max value!")
125 | return
126 |
127 |
128 | #...load training set ...
129 | print("Loading data....")
130 | training, labels_l, att_types = load_dataset(input_filename)
131 |
132 | print("Data loaded! ... normalizing ...")
133 | #new_data, norm_params = normalize_data(training)
134 | scaler = StandardScaler()
135 | new_data = scaler.fit_transform(training)
136 |
137 | print( "Normalized! ... Applying PCA ..." )
138 | values, vectors = get_PCA(new_data)
139 |
140 | variance_k, all_variances = get_variance_K(values, var_max)
141 |
142 | final_k = min(variance_k, k_max)
143 |
144 | print("Final K = " + str(final_k) + ", variance = " + str(all_variances[final_k - 1]) )
145 |
146 | print("Saving Parameters ... ")
147 | save_PCA_parameters(output_filename, scaler, vectors, final_k)
148 |
149 | print("Finished!")
150 |
151 |
152 | main()
153 |
--------------------------------------------------------------------------------
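The script above (get_PCA_parameters.py) only learns and pickles the parameters; the projection step itself lives in apply_PCA_parameters.py, which appears in the tree but is not shown in this section. As a rough illustration, here is a minimal sketch of how the saved parameters could be applied, assuming only the pickle layout produced by save_PCA_parameters() (scaler, eigenvector list, K); it is not the repository's implementation.

```python
# Minimal sketch (not the repository's apply_PCA_parameters.py): load the
# parameters in the order save_PCA_parameters() dumps them and project a
# feature matrix onto the first K principal components.
import pickle
import numpy as np


def apply_pca_parameters(params_filename, data):
    # parameters were dumped in this order: StandardScaler, eigenvectors, K
    with open(params_filename, 'rb') as params_file:
        scaler = pickle.load(params_file)
        pca_vectors = pickle.load(params_file)
        pca_k = pickle.load(params_file)

    # normalize with the scaler fitted on the original training set
    scaled = scaler.transform(data)

    # stack the first K eigenvectors (column matrices) into (n_atts x K)
    projection = np.hstack(pca_vectors[:pca_k])

    # (n_samples x n_atts) x (n_atts x K) -> (n_samples x K)
    return np.asarray(np.dot(scaled, projection))
```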
/src/get_enhanced_clustered_set.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 | import os
25 | import sys
26 | import fnmatch
27 | from sklearn.preprocessing import StandardScaler
28 | from sklearn.cluster import Ward
29 | from traceInfo import *
30 | from mathSymbol import *
31 | from load_inkml import *
32 | from distorter import *
33 | from dataset_ops import *
34 |
35 | #=====================================================================
36 | # generates an enhanced training set from a directory containing
37 | # inkml files by adding extra samples generated with a distortion
38 | # model. It also applies clustering to automatically identify
39 | # under-represented classes.
40 | #
41 | # Created by:
42 | # - Kenny Davila (Jan 27, 2012-2014)
43 | # Modified By:
44 | # - Kenny Davila (Jan 27, 2012-2014)
45 | #
46 | #=====================================================================
47 |
48 | def print_overwrite(text):
49 | sys.stdout.write("\r")
50 | sys.stdout.write(" " * 79)
51 | sys.stdout.write("\r")
52 | sys.stdout.write(text)
53 |
54 |
55 | def replace_labels(labels):
56 | new_labels = []
57 |
58 | for label in labels:
59 | if label == "\\tg":
60 | new_label = "\\tan"
61 | elif label == '>':
62 | new_label = '\gt'
63 | elif label == '<':
64 | new_label = '\lt'
65 | elif label == '\'':
66 | new_label = '\prime'
67 | elif label == '\cdots':
68 | new_label = '\ldots'
69 | elif label == '\\vec':
70 | new_label = '\\rightarrow'
71 | elif label == '\cdot':
72 | new_label = '.'
73 | elif label == ',':
74 | new_label = 'COMMA'
75 | elif label == '\\frac':
76 | new_label = '-'
77 | else:
78 | #unchanged...
79 | new_label = label
80 |
81 | new_labels.append(new_label)
82 |
83 | return new_labels
84 |
85 |
86 | def get_symbol_features(symbols, n_atts, verbose):
87 | if verbose:
88 | print("...Getting features for samples....")
89 |
90 | #...get features....
91 | total_extra = len(symbols)
92 | extra_features = np.zeros((total_extra, n_atts))
93 |
94 | for idx, symbol in enumerate(symbols):
95 | features = symbol.getFeatures()
96 |
97 | #copy features from list to numpy array
98 | for k in range(n_atts):
99 | extra_features[idx, k] = features[k]
100 |
101 | #...some output....
102 | if verbose and (idx == total_extra - 1 or idx % 10 == 0):
103 | print_overwrite("...Processed " + str(idx + 1) + " of " + str(total_extra))
104 |
105 | return extra_features
106 |
107 |
108 | def main():
109 | #usage check
110 | if len(sys.argv) < 7:
111 | print("Usage: python get_enhanced_clustered_set.py inkml_path output min_prc diag_dist max_clusters " +
112 | "clust_prc [verbose] [count_only]")
113 | print("Where")
114 | print("\tinkml_path\t= Path to directory that contains the inkml files")
115 | print("\toutput\t\t= File name of the output file")
116 | print("\tmin_prc\t\t= Minimum representation based on (%) of largest class")
117 | print("\tdiag_dist\t= Distortion factor relative to length of main diagonal")
118 | print("\tmax_clusters\t= Maximum number of clusters per large class")
119 | print("\tclust_prc\t= Minimum cluster size based on (%) of largest cluster")
120 | print("\tverbose\t= Optional, print detailed messages ")
121 | print("\tcount_only\t= Optional, only count what the final size of the dataset would be")
122 | return
123 |
124 | #load and filter the list of files, the result is a list of inkml files only
125 | try:
126 | complete_list = os.listdir(sys.argv[1])
127 | filtered_list = []
128 | for file in complete_list:
129 | if fnmatch.fnmatch(file, '*.inkml'):
130 | filtered_list.append( file )
131 | except:
132 | print( "The inkml path <" + sys.argv[1] + "> is invalid!" )
133 | return
134 |
135 | output_filename = sys.argv[2]
136 |
137 | try:
138 | min_prc = float(sys.argv[3])
139 | if min_prc < 0.0:
140 | print("Invalid minimum percentage")
141 | return
142 | except:
143 | print("Invalid minimum percentage")
144 | return
145 |
146 | try:
147 | diag = float(sys.argv[4])
148 | if diag < 0.0:
149 | print("Invalid distortion factor")
150 | return
151 | except:
152 | print("Invalid distortion factor")
153 | return
154 |
155 | try:
156 | components = int(sys.argv[5])
157 | if components < 0:
158 | print("Invalid max_clusters"); return
159 | except:
160 | print("Invalid max_clusters")
161 | return
162 |
163 | try:
164 | clust_prc = float(sys.argv[6])
165 | if clust_prc < 0.0:
166 | print("Invalid minimum cluster percentage")
167 | return
168 | except:
169 | print("Invalid minimum cluster percentage")
170 | return
171 |
172 | if len(sys.argv) > 7:
173 | try:
174 | verbose = int(sys.argv[7]) > 0
175 | except:
176 | print("Invalid value for verbose")
177 | return
178 | else:
179 | #by default, print all messages
180 | verbose = True
181 |
182 | if len(sys.argv) > 8:
183 | try:
184 | count_only = int(sys.argv[8]) > 0
185 | except:
186 | print("Invalid value for count_only")
187 | return
188 | else:
189 | #by default...
190 | count_only = False
191 |
192 | #....read every inkml file in the path specified...
193 | #....create the initial symbol objects...
194 | all_symbols = []
195 | sources = []
196 | print("Loading samples from files.... ")
197 | for i in range(len(filtered_list)):
198 | file_name = filtered_list[i]
199 | file_path = sys.argv[1] + '//' + file_name
200 | advance = float(i) / len(filtered_list)
201 | if verbose:
202 | print_overwrite(("Processing => {:.2%} => " + file_path).format(advance))
203 |
204 | #print()
205 |
206 | symbols = load_inkml(file_path, True)
207 |
208 | for new_symbol in symbols:
209 | all_symbols.append(new_symbol)
210 | sources.append((file_path, new_symbol.id))
211 |
212 | print("")
213 | print("....samples loaded!")
214 |
215 | #...get features of original training set...
216 | print("Getting features of base training set...")
217 | training = None
218 | labels_l = None
219 | att_types = None
220 | n_original_samples = len(all_symbols)
221 | for idx, symbol in enumerate(all_symbols):
222 | features = symbol.getFeatures()
223 |
224 | if idx == 0:
225 | #...create training set ....
226 |
227 | #...first, the feature types...
228 | n_atts = len(features)
229 | feature_types = symbol.getFeaturesTypes()
230 | att_types = np.zeros((n_atts, 1), dtype=np.int32)
231 | for i in xrange(n_atts):
232 | if feature_types[i] == 'D':
233 | att_types[i] = 2
234 | else:
235 | att_types[i] = 1
236 |
237 | #...for labels...
238 | labels_l = []
239 |
240 | #...finally, the matrix for training set itself...
241 | training = np.zeros((n_original_samples, n_atts))
242 |
243 | #copy features from list to numpy array
244 | for k in range(n_atts):
245 | training[idx, k] = features[k]
246 |
247 | #copy label
248 | labels_l.append(symbol.truth)
249 |
250 | #...some output....
251 | if verbose and (idx == n_original_samples - 1 or idx % 10 == 0):
252 | print_overwrite("...Processed " + str(idx + 1) + " of " + str(n_original_samples))
253 |
254 | #...correct labels....
255 | labels_l = replace_labels(labels_l)
256 |
257 |
258 | #...scale original data...
259 | print("Scaling original data...")
260 | scaler = StandardScaler()
261 | scaled_training = scaler.fit_transform(training)
262 |
263 | #...save original samples to file....
264 | if not count_only:
265 | print("Saving original data...")
266 | save_dataset_string_labels(training, labels_l, att_types, output_filename)
267 | #append_dataset_string_labels
268 |
269 | #... identify under-represented classes ...
270 | #... first, count samples per class ...
271 | print("Getting counts...")
272 | count_per_class = {}
273 | refs_per_class = {}
274 | for idx in range(n_original_samples):
275 | #s_label = labels_train[idx, 0]
276 | s_label = labels_l[idx]
277 |
278 | if s_label in count_per_class:
279 | count_per_class[s_label] += 1
280 | refs_per_class[s_label].append(idx)
281 | else:
282 | count_per_class[s_label] = 1
283 | refs_per_class[s_label] = [idx]
284 |
285 | n_classes = len(count_per_class.keys())
286 |
287 | #...find the size of the largest class...
288 | largest_size = 0
289 | for label in count_per_class:
290 | #...check largest...
291 | if count_per_class[label] > largest_size:
292 | largest_size = count_per_class[label]
293 |
294 | print("Samples on largest class: " + str(largest_size))
295 |
296 | #Now, do the data enhancement....
297 | distorter = Distorter()
298 |
299 | #....for each class....
300 | n_extra_samples = 0
301 | min_elements = int(math.ceil(min_prc * largest_size))
302 | print("Analyzing data per class...")
303 | for label in count_per_class:
304 | current_refs = refs_per_class[label]
305 | n_class_samples = count_per_class[label]
306 |
307 | if n_class_samples < min_elements:
308 | print("Class: " + label + ", count = " + str(n_class_samples) + ", adding samples...")
309 |
310 | #class is under-represented, generate distorted artificial samples....
311 | to_create = min_elements - n_class_samples
312 | n_extra_samples += to_create
313 |
314 | if not count_only:
315 | #...create...
316 | extra_samples = []
317 | for i in range(to_create):
318 | #...take one element from original samples
319 | base_symbol = all_symbols[current_refs[(i % n_class_samples)]]
320 |
321 | #...create distorted version...
322 | new_symbol = distorter.distortSymbol(base_symbol, diag)
323 |
324 | #...add...
325 | extra_samples.append(new_symbol)
326 |
327 | #...get features....
328 | extra_features = get_symbol_features(extra_samples, n_atts, verbose)
329 | #...create labels....
330 | extra_labels = [label] * len(extra_samples)
331 |
332 | #...now, save the extra samples to file...
333 | out_file = open(output_filename, 'a')
334 | append_dataset_string_labels(extra_features, extra_labels, out_file)
335 | out_file.close()
336 | else:
337 | #class has more than enough data, try clustering and then enhancing small clusters...
338 |
339 | print("Class: " + label + ", count = " + str(n_class_samples) + ", clustering...")
340 |
341 | if components == 0:
342 | print("...Skipping clustering...")
343 | continue
344 |
345 | #create empty dataset ...
346 | class_data = np.zeros((n_class_samples, n_atts))
347 | #fill with samples...
348 | for i in xrange(n_class_samples):
349 | class_data[i, :] = scaled_training[current_refs[i], :]
350 |
351 | #apply clustering...
352 | ward = Ward(n_clusters=components).fit(class_data)
353 | cluster_labels = ward.labels_
354 |
355 | #separate references per cluster, and get counts...
356 | counts_per_cluster = [0] * components
357 | refs_per_cluster = {}
358 | for i in range(n_class_samples):
359 | c_label = int(cluster_labels[i])
360 |
361 | if c_label in refs_per_cluster:
362 | refs_per_cluster[c_label].append(current_refs[i])
363 | else:
364 | refs_per_cluster[c_label] = [current_refs[i]]
365 |
366 | #increase the count of samples per class...
367 | counts_per_cluster[c_label] += 1
368 |
369 | #...check minimum size for enhancement
370 | largest_cluster = max(counts_per_cluster)
371 | #print "Largest cluster: " + str(largest_cluster)
372 | min_cluster_size = int(math.ceil(clust_prc * largest_cluster))
373 | final_counts_per_cluster = []
374 |
375 | #...for each cluster...
376 | extra_samples = []
377 | for i in range(components):
378 | #...if it has less than the minimum number of elements...
379 | if counts_per_cluster[i] < min_cluster_size:
380 | #enhance....
381 | c_cluster = counts_per_cluster[i]
382 | to_create = min_cluster_size - c_cluster
383 | cluster_refs = refs_per_cluster[i]
384 |
385 | n_extra_samples += to_create
386 |
387 | if not count_only:
388 | #...create...
389 | for k in range(to_create):
390 | #...take one element from original samples
391 | base_symbol = all_symbols[cluster_refs[(k % c_cluster)]]
392 | #...create distorted version...
393 | new_symbol = distorter.distortSymbol(base_symbol, diag)
394 | #...modify label (use new mapped label)...
395 | #new_symbol.truth = classes_dict[new_symbol.truth]
396 | #...add...
397 | extra_samples.append(new_symbol)
398 |
399 | final_counts_per_cluster.append(min_cluster_size)
400 | else:
401 | final_counts_per_cluster.append(counts_per_cluster[i])
402 |
403 | #print counts_per_cluster
404 | #print final_counts_per_cluster
405 |
406 | print ("Cluster #" + str(i + 1) + ", i. size = " + str(counts_per_cluster[i]) +
407 | ", f. size = " + str(final_counts_per_cluster[i]))
408 |
409 | if not count_only:
410 | #...get features....
411 | extra_features = get_symbol_features(extra_samples, n_atts, verbose)
412 |
413 | #...create labels matrix....
414 | extra_labels = [label] * len(extra_samples)
415 |
416 | #...now, save the extra samples to file...
417 | out_file = open(output_filename, 'a')
418 | #append_dataset(extra_features, extra_labels, out_file)
419 | append_dataset_string_labels(extra_features, extra_labels, out_file)
420 | out_file.close()
421 |
422 | extra_samples = []
423 |
424 | print("Initial size: " + str(n_original_samples))
425 | print("Extra samples: " + str(n_extra_samples))
426 | print("Final size: " + str(n_extra_samples + n_original_samples))
427 | print("Added Ratio: " + str((float(n_extra_samples) / float(n_original_samples)) * 100.0))
428 |
429 | print("Finished!")
430 |
431 | main()
432 |
--------------------------------------------------------------------------------
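Note that the sklearn.cluster.Ward class imported above was deprecated and later removed from scikit-learn; on recent releases the same per-class clustering step can be expressed with AgglomerativeClustering using Ward linkage. A minimal, hedged equivalent (assuming a scikit-learn version that ships AgglomerativeClustering):

```python
# Sketch: replacement for Ward(n_clusters=components).fit(class_data)
# on scikit-learn versions where the Ward class no longer exists.
from sklearn.cluster import AgglomerativeClustering


def cluster_class_samples(class_data, components):
    # hierarchical clustering with Ward linkage on the scaled class samples
    ward = AgglomerativeClustering(n_clusters=components, linkage='ward')
    cluster_labels = ward.fit_predict(class_data)
    return cluster_labels
```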
/src/get_training_set.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 | import os
25 | import sys
26 | import fnmatch
27 | import string
28 | from traceInfo import *
29 | from mathSymbol import *
30 | from load_inkml import *
31 |
32 | #=====================================================================
33 | # generates a training set from a directory containing the inkml files
34 | #
35 | # Created by:
36 | # - Kenny Davila (Oct, 2012)
37 | # Modified By:
38 | # - Kenny Davila (Oct 19, 2012)
39 | # - Kenny Davila (Nov 25, 2013)
40 | # - Added ID to symbol
41 | # - Added AUX file to output origin for each sample in DS
42 | # - Kenny Davila (Jan 16, 2012-2014)
43 | # - Print number of attributes found
44 | # - Kenny Davila (March 2016)
45 | # - Added additional error handling for files with errors
46 | #
47 | #=====================================================================
48 |
49 | def main():
50 | #usage check
51 | if len(sys.argv) != 3:
52 | print("Usage: python get_training_set.py inkml_path output")
53 | print("Where")
54 | print("\tinkml_path\t= Path to directory that contains the inkml files")
55 | print("\toutput\t\t= File name of the output file")
56 | return
57 |
58 | #load and filter the list of files, the result is a list of inkml files only
59 | try:
60 | complete_list = os.listdir(sys.argv[1])
61 | filtered_list = []
62 | for file in complete_list:
63 | if fnmatch.fnmatch(file, '*.inkml'):
64 | filtered_list.append( file )
65 | except:
66 | print( "The inkml path <" + sys.argv[1] + "> is invalid!" )
67 | return
68 |
69 | samples = []
70 | labels_found = {}
71 | sources = []
72 | error_files = []
73 |
74 | #read every file in the path specified...
75 | for i in range(len(filtered_list)):
76 | file_name = filtered_list[i]
77 | file_path = sys.argv[1] + '//' + file_name
78 | advance = float(i) / len(filtered_list)
79 | print(("Processing => {:.2%} => " + file_path).format( advance ))
80 |
81 | try:
82 | symbols = load_inkml( file_path, True )
83 | except:
84 | print("Failed processing: " + file_path)
85 | error_files.append(file_path)
86 | symbols = []
87 |
88 | for new_symbol in symbols:
89 | #now generate the features and add them to the list, including the tag
90 | #for the expected class....
91 | sample = new_symbol.getFeatures() + [ new_symbol.truth ]
92 | samples.append( sample )
93 |
94 | #count samples per class
95 | if not new_symbol.truth in labels_found:
96 | labels_found[ new_symbol.truth ] = 1
97 | else:
98 | labels_found[ new_symbol.truth ] += 1
99 |
100 | #the source of current symbol will be
101 | #exported as auxiliary file
102 | sources.append( ( file_path, new_symbol.id) )
103 |
104 |
105 | print("Total input files: " + str(len(filtered_list)))
106 | print("Total valid files: " + str(len(filtered_list) - len(error_files)))
107 | print("Files with errors: ")
108 | for filename in error_files:
109 | print("\t- " + filename)
110 |
111 | print("Total files with error: " + str(len(error_files)))
112 |
113 | print( "Found: " + str(len(labels_found.keys())) + " different classes" )
114 | if len(samples) > 0:
115 | print( "Found: " + str(len(samples[0]) - 1 ) + " different attributes" )
116 |
117 | print "Saving main .... "
118 | #now that all the samples have been collected, write them all
119 | #in the output file
120 | try:
121 | file = open(sys.argv[2], 'w')
122 | except:
123 | print( "File <" + sys.argv[2] + "> could not be created")
124 | return
125 |
126 | content = ''
127 | #print as headers the types for each feature...
128 | feature_types = new_symbol.getFeaturesTypes()
129 | for i, feat_type in enumerate(feature_types):
130 | if i > 0:
131 | content += '; '
132 | content += feat_type
133 | content += '\r\n'
134 |
135 | for sample in samples:
136 | line = ''
137 | for i, v in enumerate(sample):
138 | if i > 0:
139 | line += '; '
140 |
141 | if v.__class__.__name__ == "list":
142 | #multiple values...
143 | for j, sv in enumerate(v):
144 | if j > 0:
145 | line += '; '
146 | line += str(sv)
147 | else:
148 | #single value...
149 | line += str(v)
150 |
151 | line += '\r\n'
152 | content += line
153 |
154 | if len(content) >= 50000:
155 | file.write(content)
156 | content = ''
157 |
158 | file.write(content)
159 |
160 | file.close()
161 |
162 | print "Saving auxiliary.... "
163 |
164 | #Now, add the auxiliary file
165 | try:
166 | aux_file = open(sys.argv[2] + ".sources.txt" , 'w')
167 | except:
168 | print( "File <" + sys.argv[2] + ".sources.txt> could not be created")
169 | return
170 |
171 | content = ''
172 | for source_path, sym_id in sources:
173 | content += source_path + ', ' + str(sym_id) + '\r\n'
174 |
175 | aux_file.write(content)
176 |
177 | aux_file.close()
178 |
179 | print "Done!"
180 | main()
181 |
--------------------------------------------------------------------------------
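For reference, the file written above has a simple line-oriented layout: a first line with the attribute type codes returned by getFeaturesTypes(), then one sample per line with the feature values and, as the last field, the class label, all separated by '; '. The repository reads this format back through dataset_ops.load_dataset (not shown in this section); the sketch below is only a hedged illustration of that layout, not the actual loader.

```python
# Hedged sketch of parsing the training-set file written by get_training_set.py:
# header line = attribute type codes, remaining lines = feature values + label.
def read_training_file(filename):
    with open(filename, 'r') as in_file:
        lines = [line.strip() for line in in_file if line.strip() != '']

    att_types = lines[0].split('; ')

    samples = []
    for line in lines[1:]:
        parts = line.split('; ')
        features = [float(value) for value in parts[:-1]]
        label = parts[-1]
        samples.append((features, label))

    return att_types, samples
```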
/src/load_inkml.py:
--------------------------------------------------------------------------------
1 | from traceInfo import *
2 | from mathSymbol import *
3 | import xml.etree.ElementTree as ET
4 |
5 | #=====================================================================
6 | # Load an INKML file using the mathSymbol and traceInfo classes
7 | # to represent the data
8 | #
9 | # Created by:
10 | # - Kenny Davila (Oct, 2012)
11 | # Modified By:
12 | # - Kenny Davila (Oct 24, 2013)
13 | # - Kenny Davila (Nov 25, 2013)
14 | # - Added ID to symbol
15 | # - Kenny Davila (Jan 17, 2014)
16 | # - Fixed ID cases for CROHME 2013 data containing colon
17 | # - Kenny Davila (Feb 3, 2014)
18 | # - Fixed cases where symbol loaded has no traces
19 | # - Kenny Davila (Apr 11, 2014)
20 | # - Fixed cases where symbol loaded has no traces
21 | # - Added junk symbol compatibility
22 | # - Kenny Davila (March, 2016)
23 | # - Error handling
24 | #
25 | #=====================================================================
26 |
27 |
28 | #debug flags
29 | debug_raw = False
30 | debug_added = False
31 | debug_smoothing = False
32 | debug_normalization = False
33 |
34 | #the current XML namespace prefix...
35 | INKML_NAMESPACE = '{http://www.w3.org/2003/InkML}'
36 |
37 | def load_inkml_traces(file_name):
38 | #first load the tree...
39 | tree = ET.parse(file_name)
40 | root = tree.getroot()
41 |
42 | #extract all the traces first...
43 | traces_objects = {}
44 | for trace in root.findall(INKML_NAMESPACE + 'trace'):
45 | #text contains all points as string, parse them and put them
46 | #into a list of tuples...
47 | points_s = trace.text.split(",")
48 | points_f = []
49 | for p_s in points_s:
50 | #split again...
51 | coords_s = p_s.split()
52 | #add...
53 | points_f.append( (float(coords_s[0]), float(coords_s[1])) )
54 |
55 | trace_id = int(trace.attrib['id'])
56 |
57 | #now create the element
58 | object_trace = TraceInfo(trace_id, points_f )
59 |
60 | #add to the dictionary...
61 | traces_objects[trace_id] = object_trace
62 |
63 | #apply general trace pre-processing...
64 |
65 | #1) first step of pre-processing: Remove duplicated points
66 | object_trace.removeDuplicatedPoints()
67 |
68 | if debug_raw:
69 | #output raw data
70 | file = open('out_raw_' + trace.attrib['id'] + '.txt', 'w')
71 | file.write( str(object_trace) )
72 | file.close()
73 |
74 | #Add points to the trace...
75 | object_trace.addMissingPoints()
76 |
77 | if debug_added:
78 | #output raw data
79 | file = open('out_added_' + trace.attrib["id"] + '.txt', 'w')
80 | file.write( str(object_trace) )
81 | file.close()
82 |
83 | #Apply smoothing to the trace...
84 | object_trace.applySmoothing()
85 |
86 | #it should not ... but .....
87 | if object_trace.hasDuplicatedPoints():
88 | #...remove them! ....
89 | object_trace.removeDuplicatedPoints()
90 |
91 | if debug_smoothing:
92 | #output data after smoothing
93 | file = open('out_smoothed_' + trace.attrib["id"] + '.txt', 'w')
94 | file.write( str(object_trace) )
95 | file.close()
96 |
97 |
98 |
99 | return root, traces_objects
100 |
101 | def extract_symbols( root, traces_objects, truth_available ):
102 | #put all the traces together with their corresponding symbols...
103 | #first, find the root of the trace groups...
104 | groups_root = root.find(INKML_NAMESPACE + 'traceGroup')
105 | trace_groups = groups_root.findall(INKML_NAMESPACE + 'traceGroup')
106 |
107 | symbols = []
108 | avg_width = 0.0
109 | avg_height = 0.0
110 |
111 | for group in trace_groups:
112 | if truth_available:
113 | #search for class label...
114 | symbol_class = group.find(INKML_NAMESPACE + 'annotation').text
115 |
116 | #search for id attribute...
117 | symbol_id = 0
118 | for id_att_name in group.attrib:
119 | if id_att_name[-2:] == "id":
120 | try:
121 | symbol_id = int(group.attrib[id_att_name])
122 | except:
123 | #could not convert to int, try splitting...
124 | symbol_id = int( group.attrib[id_att_name].split(":")[0] )
125 |
126 | else:
127 | #unknown
128 | symbol_class = '{Unknown}'
129 | symbol_id = 0
130 |
131 | #link with corresponding traces...
132 | group_traces = group.findall(INKML_NAMESPACE + 'traceView')
133 | symbol_list = []
134 | for trace in group_traces:
135 | object_trace = traces_objects[int(trace.attrib["traceDataRef"])]
136 | symbol_list.append(object_trace)
137 |
138 | #create the math symbol...
139 | try:
140 | new_symbol = MathSymbol(symbol_id, symbol_list, symbol_class)
141 | except Exception as e:
142 | print("Failed to load symbol!!")
143 | print(e)
144 | #skip this symbol...
145 | continue
146 |
147 | #capture statistics of relative size....
148 | symMinX, symMaxX, symMinY, symMaxY = new_symbol.original_box
149 | #...add...
150 | avg_width += (symMaxX - symMinX)
151 | avg_height += (symMaxY - symMinY)
152 |
153 | #now normalize size and locations for traces in current symbol
154 | new_symbol.normalize()
155 |
156 | if debug_normalization:
157 | #output data after relocated
158 | for trace in group_traces:
159 | object_trace = traces_objects[int(trace.attrib["traceDataRef"])]
160 |
161 | file = open('out_reloc_' + trace.attrib["traceDataRef"] + '.txt', 'w')
162 | file.write( str(object_trace) )
163 | file.close()
164 |
165 | symbols.append(new_symbol)
166 |
167 | if len(trace_groups) > 0:
168 | avg_width /= len(trace_groups)
169 | avg_height /= len(trace_groups)
170 |
171 | for s in symbols:
172 | s.setSizeRatio(avg_width, avg_height)
173 |
174 | return symbols
175 |
176 | def load_inkml(file_name, truth_available):
177 |
178 | root, traces_objects = load_inkml_traces(file_name)
179 |
180 | symbols = extract_symbols(root, traces_objects, truth_available)
181 |
182 | return symbols
183 |
184 | def extract_junk_symbol(traces_objects, junk_class_name):
185 | symbol_id = 0
186 | symbol_class = junk_class_name
187 |
188 | symbol_list = []
189 | for trace_id in traces_objects:
190 | symbol_list.append(traces_objects[trace_id])
191 |
192 | #create the math symbol...
193 | try:
194 | new_symbol = MathSymbol(symbol_id, symbol_list, symbol_class)
195 | except Exception as e:
196 | print("Failed to load symbol!!")
197 | print(e)
198 | return None
199 |
200 | #now normalize size and locations for traces in current symbol
201 | new_symbol.normalize()
202 |
203 | return new_symbol
204 |
205 | def load_junk_inkml(file_name, junk_class_name):
206 |
207 | root, traces_objects = load_inkml_traces(file_name)
208 |
209 | symbols = [extract_junk_symbol(traces_objects, junk_class_name)]
210 |
211 | return symbols
212 |
--------------------------------------------------------------------------------
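A short usage sketch for the loader above; 'sample.inkml' is a hypothetical file name, and only attributes that appear elsewhere in this repository (truth, id, getFeatures()) are used:

```python
# Hedged usage sketch: load the labeled symbols of one InkML file and
# inspect their feature vectors ('sample.inkml' is a placeholder path).
from load_inkml import load_inkml

symbols = load_inkml('sample.inkml', True)  # True: ground truth is available

for symbol in symbols:
    features = symbol.getFeatures()
    print(symbol.truth + " (id " + str(symbol.id) + "): " + str(len(features)) + " features")
```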
/src/parallel_evaluate.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 |
25 |
26 | import sys
27 | import time
28 | import math
29 | import numpy as np
30 | import cPickle
31 | import multiprocessing
32 | from sklearn.preprocessing import StandardScaler
33 | from dataset_ops import *
34 | from evaluation_ops import *
35 | from symbol_classifier import SymbolClassifier
36 |
37 | #=====================================================================
38 | # this program takes as input a training set, a testing set and a
39 | # classifier, runs the evaluation in parallel, and then computes
40 | # the metrics.
41 | #
42 | # Created by:
43 | # - Kenny Davila (Feb 12, 2012-2014)
44 | # Modified By:
45 | # - Kenny Davila (Feb 12, 2012-2014)
46 | # - Kenny Davila (Mar 11, 2015)
47 | # - Symbol classifier class now adopted
48 | #
49 | #=====================================================================
50 |
51 |
52 | def evaluate_data(classifier, data, first, last, results_queue):
53 |
54 | predicted = classifier.predict(data)
55 |
56 | results_queue.put((first, last, predicted))
57 |
58 | def parallel_evaluate(classifier, dataset, workers):
59 | #...determine the ranges per worker...
60 | n_samples = np.size(dataset, 0)
61 | max_per_worker = int(math.ceil(float(n_samples) / float(workers)))
62 |
63 | #...prepare for parallel processing...
64 | results_queue = multiprocessing.Queue()
65 | processes = []
66 | final_results = np.zeros(n_samples)
67 |
68 | #...start parallel threads....
69 | for i in range(workers):
70 | first_element = max_per_worker * i
71 | last_element = min(max_per_worker * (i + 1), n_samples)
72 |
73 | params = (classifier, dataset[first_element:last_element, :], first_element, last_element, results_queue)
74 | p = multiprocessing.Process(target=evaluate_data, args=params)
75 | p.start()
76 |
77 | processes.append(p)
78 |
79 | #...wait for all results to be ready...
80 | for i in range(workers):
81 | first, last, predicted = results_queue.get()
82 |
83 | #merge with current final results...
84 | final_results[first:last] = predicted
85 |
86 | return final_results
87 |
88 |
89 | def main():
90 | if len(sys.argv) < 5:
91 | print("Usage: python parallel_evaluate.py training_set testing_set classifier " +
92 | "workers [test_only] [ambiguous]")
93 | print("Where")
94 | print("\ttraining_set\t= Path to the file of the training set")
95 | print("\ttesting_set\t= Path to the file of the testing set")
96 | print("\tclassifier\t= File that contains the pickled classifier")
97 | print("\tworkers\t\t= Number of parallel threads to use")
98 | print("\ttest_only\t= Optional, Only execute for testing set")
99 | print("\tambiguous\t= Optional, file that contains the list of ambiguous classes")
100 | return
101 |
102 | training_file = sys.argv[1]
103 | testing_file = sys.argv[2]
104 | classifier_file = sys.argv[3]
105 |
106 | try:
107 | workers = int(sys.argv[4])
108 | if workers < 1:
109 | print("Invalid number of workers")
110 | return
111 | except:
112 | print("Invalid number of workers")
113 | return
114 |
115 | if len(sys.argv) >= 6:
116 | try:
117 | test_only = int(sys.argv[5]) > 0
118 | except:
119 | print("Invalid value for test_only")
120 | return
121 | else:
122 | test_only = False
123 |
124 | if len(sys.argv) >= 7:
125 | allograph_file = sys.argv[6]
126 | else:
127 | allograph_file = None
128 | ambiguous = None
129 |
130 | print("Loading classifier...")
131 |
132 | in_file = open(classifier_file, 'rb')
133 | classifier = cPickle.load(in_file)
134 | in_file.close()
135 |
136 | if not isinstance(classifier, SymbolClassifier):
137 | print("Invalid classifier file!")
138 | return
139 |
140 | # get mapping
141 | classes_dict = classifier.classes_dict
142 | classes_l = classifier.classes_list
143 | n_classes = len(classes_l)
144 |
145 | print("Loading data...")
146 |
147 | if not test_only:
148 | #...loading training data...
149 | training, labels_l, att_types = load_dataset(training_file)
150 | if att_types is None:
151 | print("Error loading File <" + training_file + ">")
152 | return
153 |
154 | #...generate mapped labels...
155 | labels_train = get_mapped_labels(labels_l, classes_dict)
156 | else:
157 | training, labels_l, labels_train = (None, None, None)
158 |
159 | #...loading testing data...
160 | testing, test_labels_l, att_types = load_dataset(testing_file)
161 |
162 | if att_types is None:
163 | print("Error loading File <" + testing_file + ">")
164 | return
165 |
166 | labels_test = get_mapped_labels(test_labels_l, classes_dict)
167 |
168 | if classifier.scaler is not None:
169 | print("Normalizing...")
170 |
171 | scaler = classifier.scaler
172 | if not test_only:
173 | training = scaler.transform(training)
174 |
175 | testing = scaler.transform(testing)
176 |
177 | if not allograph_file is None:
178 | print("Loading ambiguous file...")
179 | ambiguous, total_ambiguous = load_ambiguous(allograph_file, classes_dict, True)
180 |
181 | print("...A total of " + str(total_ambiguous) + " were found")
182 |
183 | start_time = time.time()
184 |
185 | print("Evaluating in multiple threads...")
186 |
187 | if not test_only:
188 | # Training data
189 | n_training_samples = np.size(training, 0)
190 |
191 | # ....first, evaluate samples on multiple threads
192 | predicted = parallel_evaluate(classifier, training, workers)
193 | #....on main thread, compute final statistics
194 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_train, n_classes)
195 | accuracy = (float(total_correct) / float(n_training_samples)) * 100
196 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
197 | print "Training Samples: " + str(n_training_samples)
198 | print "Training Results"
199 | print "Accuracy\tClass Average\tClass STD "
200 | print str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0)
201 | else:
202 | print("...Skipping Training set...")
203 |
204 | n_testing_samples = np.size(testing, 0)
205 |
206 | #Testing data
207 | #....first, evaluate samples on multiple threads
208 | predicted = parallel_evaluate(classifier, testing, workers)
209 | #....on main thread, compute final statistics
210 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_test, n_classes)
211 | accuracy = (float(total_correct) / float(n_testing_samples)) * 100
212 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
213 | print "Testing Samples: " + str(n_testing_samples)
214 | print "Testing Results"
215 | print "Accuracy\tClass Average\tClass STD "
216 | print str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0)
217 |
218 | end_time = time.time()
219 | total_elapsed = end_time - start_time
220 |
221 | print "Total Elapsed: " + str(total_elapsed)
222 |
223 | #get confusion matrix...
224 | print("Generating confusion matrix for test...")
225 | confussion = compute_confusion_matrix(predicted, labels_test, n_classes)
226 | print("Saving results...")
227 | results_file = classifier_file + ".results.csv"
228 | save_evaluation_results(classes_l, confussion, results_file, ambiguous)
229 |
230 | #...check for ambiguous...
231 | if not ambiguous is None:
232 | #...recompute metrics considering ambiguous...
233 | confussion = compute_ambiguous_confusion_matrix(predicted, labels_test, n_classes, ambiguous)
234 |
235 | results_file = classifier_file + ".results_ambiguous.csv"
236 | save_evaluation_results(classes_l, confussion, results_file, ambiguous)
237 |
238 | print("...Finished!")
239 |
240 | if __name__ == '__main__':
241 | main()
242 |
243 |
244 |
245 |
246 |
--------------------------------------------------------------------------------
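The "Class Average" and "Class STD" columns printed above come from get_average_class_accuracy() in evaluation_ops.py, which is not included in this section. As a hedged reading of those numbers, the sketch below computes an unweighted per-class accuracy mean and standard deviation from the per-class counts, under the assumption that counts_per_class[i] holds the number of samples of class i and errors_per_class[i] the number of those that were misclassified:

```python
# Hedged sketch of the per-class metrics reported above (the actual
# implementation lives in evaluation_ops.py, not shown here).
import numpy as np


def class_average_accuracy(counts_per_class, errors_per_class):
    counts = np.asarray(counts_per_class, dtype=np.float64)
    errors = np.asarray(errors_per_class, dtype=np.float64)

    # per-class accuracy (recall), skipping classes with no samples
    valid = counts > 0
    per_class = (counts[valid] - errors[valid]) / counts[valid]

    # unweighted mean and standard deviation across classes
    return per_class.mean(), per_class.std()
```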
/src/parallel_prob_evaluate.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 | import sys
25 | import time
26 | import math
27 | import numpy as np
28 | import cPickle
29 | import multiprocessing
30 | from sklearn.preprocessing import StandardScaler
31 | from dataset_ops import *
32 | from evaluation_ops import *
33 | from symbol_classifier import SymbolClassifier
34 |
35 | #=====================================================================
36 | # this program takes as input a training set, a testing set and a
37 | # classifier, runs the evaluation in parallel, and then computes
38 | # the metrics.
39 | #
40 | # Created by:
41 | # - Kenny Davila (Feb 12, 2012-2014)
42 | # Modified By:
43 | # - Kenny Davila (Feb 12, 2012-2014)
44 | # - Kenny Davila (Feb 25, 2012-2014)
45 | # - Now includes top-5 accuracy too
46 | # - Kenny Davila (Mar 11, 2015)
47 | # - Symbol classifier class now adopted
48 | #
49 | #=====================================================================
50 |
51 |
52 | def evaluate_data(classifier, data, first, last, top_n, results_queue):
53 | #use probabilistic evaluation to get probability per class...
54 | predicted = classifier.predict_proba(data)
55 |
56 | n_samples = np.size(data, 0)
57 | n_classes = np.size(predicted, 1)
58 | results = np.zeros((n_samples, top_n))
59 |
60 | #find the top N values per sample...
61 | raw_classes = classifier.get_raw_classes()
62 | for i in range(n_samples):
63 | tempo_values = []
64 | for k in range(n_classes):
65 | tempo_values.append((predicted[i, k], raw_classes[k]))
66 |
67 | #now, sort!
68 | tempo_values = sorted(tempo_values, reverse=True)
69 |
70 | #use the top-N
71 | for k in range(top_n):
72 | results[i, k] = tempo_values[k][1]
73 |
74 | results_queue.put((first, last, results))
75 |
76 | def parallel_evaluate(classifier, dataset, workers, top_n):
77 | #...determine the ranges per worker...
78 | n_samples = np.size(dataset, 0)
79 | max_per_worker = int(math.ceil(float(n_samples) / float(workers)))
80 |
81 | #...prepare for parallel processing...
82 | results_queue = multiprocessing.Queue()
83 | processes = []
84 | final_results = np.zeros((n_samples, top_n))
85 |
86 | #...start parallel threads....
87 | for i in range(workers):
88 | first_element = max_per_worker * i
89 | last_element = min(max_per_worker * (i + 1), n_samples)
90 |
91 | params = (classifier, dataset[first_element:last_element, :], first_element, last_element, top_n, results_queue)
92 | p = multiprocessing.Process(target=evaluate_data, args=params)
93 | p.start()
94 |
95 | processes.append(p)
96 |
97 | #...wait for all results to be ready...
98 | for i in range(workers):
99 | first, last, predicted = results_queue.get()
100 |
101 | #merge with current final results...
102 | final_results[first:last, :] = predicted
103 |
104 | return final_results
105 |
106 |
107 | def main():
108 | if len(sys.argv) < 5:
109 | print("Usage: python parallel_prob_evaluate.py training_set testing_set classifier " +
110 | "workers [test_only] [ambiguous]")
111 | print("Where")
112 | print("\ttraining_set\t= Path to the file of the training set")
113 | print("\ttesting_set\t= Path to the file of the testing set")
114 | print("\tclassifier\t= File that contains the pickled classifier")
115 | print("\tworkers\t\t= Number of parallel threads to use")
116 | print("\ttest_only\t= Optional, Only execute for testing set")
117 | print("\tambiguous\t= Optional, file that contains the list of ambiguous classes")
118 | return
119 |
120 | training_file = sys.argv[1]
121 | testing_file = sys.argv[2]
122 | classifier_file = sys.argv[3]
123 |
124 | try:
125 | workers = int(sys.argv[4])
126 | if workers < 1:
127 | print("Invalid number of workers")
128 | return
129 | except:
130 | print("Invalid number of workers")
131 | return
132 |
133 | if len(sys.argv) >= 6:
134 | try:
135 | test_only = int(sys.argv[5]) > 0
136 | except:
137 | print("Invalid value for test_only")
138 | return
139 | else:
140 | test_only = False
141 |
142 | if len(sys.argv) >= 7:
143 | allograph_file = sys.argv[6]
144 | else:
145 | allograph_file = None
146 | ambiguous = None
147 |
148 | print("Loading classifier...")
149 |
150 | in_file = open(classifier_file, 'rb')
151 | classifier = cPickle.load(in_file)
152 | in_file.close()
153 |
154 | if not isinstance(classifier, SymbolClassifier):
155 | print("Invalid classifier file!")
156 | return
157 |
158 | print("Loading data...")
159 |
160 | # get mapping
161 | classes_dict = classifier.classes_dict
162 | classes_l = classifier.classes_list
163 | n_classes = len(classes_l)
164 |
165 | if not test_only:
166 | #...loading training data...
167 | training, labels_l, att_types = load_dataset(training_file)
168 | if att_types is None:
169 | print("Error loading File <" + training_file + ">")
170 | return
171 |
172 | #...generate mapped labels...
173 | labels_train = get_mapped_labels(labels_l, classes_dict)
174 | else:
175 | training, labels_l, labels_train = (None, None, None)
176 |
177 | #...loading testing data...
178 | testing, test_labels_l, att_types = load_dataset(testing_file)
179 |
180 | if att_types is None:
181 | print("Error loading File <" + testing_file + ">")
182 | return
183 |
184 | labels_test = get_mapped_labels(test_labels_l, classes_dict)
185 |
186 | if classifier.scaler is not None:
187 | print("Normalizing...")
188 |
189 | scaler = classifier.scaler
190 | if not test_only:
191 | training = scaler.transform(training)
192 |
193 | testing = scaler.transform(testing)
194 |
195 | if not allograph_file is None:
196 | print("Loading ambiguous file...")
197 | ambiguous, total_ambiguous = load_ambiguous(allograph_file, classes_dict, True)
198 |
199 | print("...A total of " + str(total_ambiguous) + " were found")
200 |
201 | print("Training data set: " + training_file)
202 | print("Testing data set: " + testing_file)
203 | print("Evaluating in multiple threads...")
204 |
205 | start_time = time.time()
206 |
207 | top_n = 5
208 |
209 | if not test_only:
210 | # Training data
211 | n_training_samples = np.size(training, 0)
212 |
213 | # ....first, evaluate samples on multiple threads
214 | predicted = parallel_evaluate(classifier, training, workers, top_n)
215 |
216 | print "Training Samples: " + str(n_training_samples)
217 | print "Training Results"
218 | print "Top\tAccuracy\tClass Average\tClass STD "
219 |
220 | # ....on main thread, compute final statistics
221 | for i in range(top_n):
222 | total_correct, counts_per_class, errors_per_class = compute_topn_error_counts(predicted, labels_train,
223 | n_classes, i + 1)
224 | accuracy = (float(total_correct) / float(n_training_samples)) * 100
225 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
226 |
227 | print str(i+1) + "\t" + str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0)
228 | else:
229 | print("...Skipping Training set...")
230 |
231 | n_testing_samples = np.size(testing, 0)
232 |
233 | # Testing data
234 | # ....first, evaluate samples on multiple threads
235 | predicted = parallel_evaluate(classifier, testing, workers, top_n)
236 |
237 | #....on main thread, compute final statistics
238 | print "Testing Samples: " + str(n_testing_samples)
239 | print "Testing Results"
240 | print "Top\tAccuracy\tClass Average\tClass STD "
241 |
242 | for i in range(top_n):
243 | total_correct, counts_per_class, errors_per_class = compute_topn_error_counts(predicted, labels_test,
244 | n_classes, i + 1)
245 | accuracy = (float(total_correct) / float(n_testing_samples)) * 100
246 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
247 |
248 | print str(i+1) + "\t" + str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0)
249 |
250 | end_time = time.time()
251 | total_elapsed = end_time - start_time
252 |
253 | print "Total Elapsed: " + str(total_elapsed)
254 |
255 | print("...Finished!")
256 |
257 | if __name__ == '__main__':
258 | main()
259 |
260 |
--------------------------------------------------------------------------------
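The top-N numbers printed above are produced by compute_topn_error_counts() in evaluation_ops.py (not shown in this section). A hedged sketch of the underlying idea, assuming the (n_samples x top_n) matrix of ranked class ids returned by parallel_evaluate() and one true class id per sample:

```python
# Hedged sketch of top-N accuracy: a sample counts as correct when its true
# class id appears among its first n ranked predictions.
import numpy as np


def topn_accuracy(ranked_predictions, true_labels, n):
    n_samples = np.size(ranked_predictions, 0)
    true_labels = np.ravel(true_labels)  # accept (n, 1) label arrays as well

    hits = 0
    for i in range(n_samples):
        if true_labels[i] in ranked_predictions[i, :n]:
            hits += 1

    return float(hits) / float(n_samples)
```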
/src/random_forest_classify.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 | import sys
25 | import time
26 | import cPickle
27 | import math
28 | from sklearn.ensemble import RandomForestClassifier
29 | from dataset_ops import *
30 | from evaluation_ops import *
31 | from symbol_classifier import SymbolClassifier
32 |
33 | #=====================================================================
34 | # this program takes as input a data set and, using Random Forests,
35 | # generates a set of decision trees to classify the data.
36 | #
37 | # Created by:
38 | # - Kenny Davila (Feb 1, 2012-2014)
39 | # Modified By:
40 | # - Kenny Davila (Feb 1, 2014)
41 | # - Kenny Davila (March 11, 2015)
42 | # - Incorporated SymbolClassifier class
43 | #
44 | #=====================================================================
45 |
46 | def predict_in_chunks(data, classifier, max_chunk_size):
47 | n_samples = np.size(data, 0)
48 | all_predicted = np.zeros(n_samples)
49 | pos = 0
50 | while pos < n_samples:
51 | last = min(pos + max_chunk_size, n_samples)
52 |
53 | #...get predictions for current chunk...
54 | predicted = classifier.predict(data[pos:last])
55 | #...add them to result...
56 | all_predicted[pos:last] = predicted
57 |
58 | #print("...Evaluated " + str(last) + " out of " + str(n_samples) )
59 |
60 | pos = last
61 |
62 | return all_predicted
63 |
64 |
65 |
66 | def main():
67 | #usage check
68 | if len(sys.argv) < 8:
69 | print("Usage: python random_forest_classify.py training_set testing_set N_trees max_D ")
70 | print(" max_feats type times [n_jobs] [out_file]")
71 | print("Where")
72 | print("\ttraining_set\t= Path to the file of the training set")
73 | print("\ttesting_set\t= Path to the file of the testing set")
74 | print("\tN_trees\t\t= Number of trees to use")
75 | print("\tmax_D\t\t= Maximum Depth")
76 | print("\tmax_feats\t= Maximum Features")
77 | print("\ttype\t\t= Type of Decision trees (criterion for splits)")
78 | print("\t\t\t\t0 - Gini")
79 | print("\t\t\t\t1 - Entropy")
80 | print("\ttimes\t\t= Number of times to repeat experiments")
81 | print ("\tn_jobs\t\t= Optional, number of parallel threads to use")
82 | print ("\tout_file\t= Optional, file where classifier will be stored")
83 | return
84 |
85 | print("Loading data....")
86 | #...load training data from file...
87 | train_filename = sys.argv[1]
88 | training, labels_l, att_types = load_dataset(train_filename)
89 | #...generate mapping...
90 | classes_dict, classes_l = get_label_mapping(labels_l)
91 | n_classes = len(classes_l)
92 | #...generate mapped labels...
93 | labels_train = get_mapped_labels(labels_l, classes_dict)
94 |
95 | #...load testing data from file...
96 | test_filename = sys.argv[2]
97 | testing, test_labels_l, att_types = load_dataset(test_filename)
98 | #...generate mapped labels...
99 | labels_test = get_mapped_labels(test_labels_l, classes_dict)
100 |
101 | try:
102 | n_trees = int(sys.argv[3])
103 | if n_trees < 1:
104 | print("Invalid N_trees value")
105 | return
106 | except:
107 | print("Invalid N_trees value")
108 | return
109 |
110 | try:
111 | max_D = int(sys.argv[4])
112 | if max_D < 1:
113 | print("Invalid max_D value")
114 | return
115 | except:
116 | print("Invalid max_D value")
117 | return
118 |
119 |
120 | try:
121 | max_features = int(sys.argv[5])
122 | if max_features < 1:
123 | print("Invalid max_feats value")
124 | return
125 | except:
126 | print("Invalid max_feats value")
127 | return
128 |
129 | try:
130 | criterion_t = 'gini' if int(sys.argv[6]) == 0 else 'entropy'
131 | except:
132 | print("Invalid type value")
133 | return
134 |
135 | try:
136 | times = int(sys.argv[7])
137 | if times < 1:
138 | print("Invalid times value")
139 | return
140 | except:
141 | print("Invalid times value")
142 | return
143 |
144 | if len(sys.argv) >= 9:
145 | try:
146 | num_jobs = int(sys.argv[8])
147 | except:
148 | print("Invalid number of jobs")
149 | return
150 | else:
151 | num_jobs = 1
152 |
153 | if len(sys.argv) >= 10:
154 | out_filename = sys.argv[9]
155 | else:
156 | out_filename = train_filename + ".best.RF"
157 |
158 |
159 | print("....Data loaded!")
160 |
161 | start_time = time.time()
162 | total_training_time = 0
163 | total_testing_time = 0
164 |
165 | all_training_accuracies = np.zeros(times)
166 | all_training_averages = np.zeros(times)
167 | all_training_stds = np.zeros(times)
168 | all_testing_accuracies = np.zeros(times)
169 | all_testing_averages = np.zeros(times)
170 | all_testing_stds = np.zeros(times)
171 |
172 | n_training_samples = np.size(training, 0)
173 | n_testing_samples = np.size(testing, 0)
174 |
175 | #tree = DecisionTreeClassifier(criterion="gini")
176 | #tree = DecisionTreeClassifier(criterion="entropy")
177 | #boosting = AdaBoostClassifier(DecisionTreeClassifier(criterion="gini", max_depth=25), n_estimators=50,learning_rate=1.0)
178 |
179 | print " \tTrain\tTrain\tTrain\tTest\tTest\tTest"
180 | print " # \tACC\tC. AVG\tC. STD\tACC\tC. AVG\tC. STD"
181 | best_forest_ref = None
182 | best_forest_accuracy = None
183 |
184 | for i in range(times):
185 | start_training_time = time.time()
186 | #print "Training #" + str(i + 1)
187 | forest = RandomForestClassifier(n_estimators=n_trees,criterion=criterion_t,
188 | max_features=max_features, max_depth=max_D,n_jobs=num_jobs) #n_jobs=-1?
189 |
190 | forest.fit(training, np.ravel(labels_train))
191 |
192 | end_training_time = time.time()
193 | total_training_time += end_training_time - start_training_time
194 |
195 | start_testing_time = time.time()
196 |
197 | #...first, training error
198 | predicted = predict_in_chunks(training, forest, 1000)
199 | #...compute metrics...
200 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_train, n_classes)
201 | accuracy = (float(total_correct) / float(n_training_samples)) * 100
202 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
203 | all_training_accuracies[i] = accuracy
204 | all_training_averages[i] = avg_accuracy * 100.0
205 | all_training_stds[i] = std_accuracy * 100.0
206 |
207 | #...then, testing error
208 | #predicted = forest.predict(testing)
209 | predicted = predict_in_chunks(testing, forest, 1000)
210 | #...compute metrics ...
211 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_test, n_classes)
212 | accuracy = (float(total_correct) / float(n_testing_samples)) * 100
213 | test_accuracy = accuracy
214 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
215 | all_testing_accuracies[i] = accuracy
216 | all_testing_averages[i] = avg_accuracy * 100.0
217 | all_testing_stds[i] = std_accuracy * 100.0
218 |
219 | end_testing_time = time.time()
220 |
221 | total_testing_time += end_testing_time - start_testing_time
222 |
223 | print(" " + str(i+1) + "\t" + str(round(all_training_accuracies[i], 3)) + "\t" +
224 | str(round(all_training_averages[i], 3)) + "\t" +
225 | str(round(all_training_stds[i], 3)) + "\t" +
226 | str(round(all_testing_accuracies[i], 3)) + "\t" +
227 | str(round(all_testing_averages[i], 3)) + "\t" +
228 | str(round(all_testing_stds[i], 3)))
229 |
230 | if best_forest_ref is None or best_forest_accuracy < test_accuracy:
231 | best_forest_ref = forest
232 | best_forest_accuracy = test_accuracy
233 |
234 |
235 | end_time = time.time()
236 | total_elapsed = end_time - start_time
237 |
238 | print(" Mean\t" + str(round(all_training_accuracies.mean(), 3)) + "\t" +
239 | str(round(all_training_averages.mean(), 3)) + "\t" +
240 | str(round(all_training_stds.mean(), 3)) + "\t" +
241 | str(round(all_testing_accuracies.mean(), 3)) + "\t" +
242 | str(round(all_testing_averages.mean(), 3)) + "\t" +
243 | str(round(all_testing_stds.mean(), 3)))
244 |
245 | print(" STD\t" + str(round(all_training_accuracies.std(), 3)) + "\t" +
246 | str(round(all_training_averages.std(), 3)) + "\t" +
247 | str(round(all_training_stds.std(), 3)) + "\t" +
248 | str(round(all_testing_accuracies.std(), 3)) + "\t" +
249 | str(round(all_testing_averages.std(), 3)) + "\t" +
250 | str(round(all_testing_stds.std(), 3)))
251 |
252 | #...Save to file...
253 | classifier = SymbolClassifier(SymbolClassifier.TypeRandomForest, best_forest_ref, classes_l, classes_dict)
254 |
255 | out_file = open(out_filename, 'wb')
256 | cPickle.dump(classifier, out_file, cPickle.HIGHEST_PROTOCOL)
257 | out_file.close()
258 |
259 |
260 | print("Total Elapsed: " + str(total_elapsed))
261 | print("Mean Elapsed Time: " + str(total_elapsed / times))
262 | print("Total Training time: " + str(total_training_time))
263 | print("Mean Training time: " + str(total_training_time / times))
264 | print("Total Evaluating time: " + str(total_testing_time))
265 | print("Mean Evaluating time: " + str(total_testing_time / times))
266 |
267 | print("Training File: " + train_filename)
268 | print("Training samples: " + str(n_training_samples))
269 | print("Testing File: " + test_filename)
270 | print("Testing samples: " + str(n_testing_samples))
271 | print("N_trees: " + str(n_trees))
272 | print("Criterion: " + criterion_t)
273 | print("Max features: " + str(max_features))
274 | print("Max Depth: " + str(max_D))
275 | print("Finished!")
276 |
277 | if __name__ == '__main__':
278 | main()
279 |
--------------------------------------------------------------------------------
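Once training finishes, the best forest is wrapped in a SymbolClassifier and pickled to out_filename. Below is a hedged sketch of loading that file and classifying new feature vectors; 'train.txt.best.RF' and 'new_samples.txt' are placeholder names, and the sketch assumes SymbolClassifier.predict() wraps the forest's predict() as it is used in parallel_evaluate.py above:

```python
# Hedged sketch: reload the pickled SymbolClassifier and classify new data.
import cPickle
from dataset_ops import load_dataset

with open('train.txt.best.RF', 'rb') as in_file:   # placeholder file name
    classifier = cPickle.load(in_file)

# placeholder dataset file, in the same format as the training set
data, labels, att_types = load_dataset('new_samples.txt')

# this random forest variant is saved without a scaler, so the raw feature
# vectors are passed directly to the classifier
predicted = classifier.predict(data)
print(predicted[:10])
```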
/src/svm_lin_classifier.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 | import sys
25 | import time
26 | import numpy as np
27 | import cPickle
28 | from sklearn import svm
29 | from sklearn.preprocessing import StandardScaler
30 | from dataset_ops import *
31 | from evaluation_ops import *
32 | from symbol_classifier import SymbolClassifier
33 |
34 |
35 | def main():
36 | # usage check...
37 | if len(sys.argv) < 3:
38 |         print("Usage: python svm_lin_classifier.py training testing [evaluate] [probab] [output]")
39 | print("Where")
40 | print("\ttraining\t= Path to training set")
41 | print("\ttesting\t= Path to testing set")
42 |         print("\tevaluate\t= Optional, run evaluation if greater than 0 (default: 1)")
43 | print("\tprobab\t= Optional, make it a probabilistic classifier")
44 | print("\toutput\t= Optional, path to store trained classifier")
45 | return
46 |
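    # Example invocation (hypothetical paths):
    #   python svm_lin_classifier.py train_ds.txt test_ds.txt 1 1 model.LSVM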
47 | print("loading data")
48 |
49 | #...load training data from file...
50 | file_name = sys.argv[1]
51 | #file_name = 'ds_test_2012.txt'
52 | training, labels_l, att_types = load_dataset(file_name)
53 | #...generate mapping...
54 | classes_dict, classes_l = get_label_mapping(labels_l)
55 | n_classes = len(classes_l)
56 | #...generate mapped labels...
57 | labels_train = get_mapped_labels(labels_l, classes_dict)
58 |
59 | #now, load the testing set...
60 | file_name = sys.argv[2]
61 | testing, test_labels_l, att_types = load_dataset(file_name)
62 | #...generate mapped labels...
63 | labels_test = get_mapped_labels(test_labels_l, classes_dict)
64 |
65 | if len(sys.argv) >= 4:
66 | try:
67 | evaluate = int(sys.argv[3]) > 0
68 | except:
69 | print("Invalid evaluate value")
70 | return
71 | else:
72 | evaluate = True
73 |
74 | if len(sys.argv) >= 5:
75 | try:
76 | make_probabilistic = int(sys.argv[4]) > 0
77 | except:
78 | print("Invalid probab value")
79 | return
80 | else:
81 | make_probabilistic = False
82 |
83 | if len(sys.argv) >= 6:
84 | out_filename = sys.argv[5]
85 | else:
86 | out_filename = sys.argv[1] + ".LSVM"
87 |
88 | print("Training with: " + sys.argv[1])
89 | print("Training Samples: " + str(np.size(training, 0)))
90 | print("Testing with: " + sys.argv[2])
91 | print("Testing Samples: " + str(np.size(testing, 0)))
92 |
93 | start_time = time.time()
94 |
95 | #try scaling...
96 | scaler = StandardScaler()
97 | training = scaler.fit_transform(training)
98 | testing = scaler.transform(testing)
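    # the scaler is fitted on the training data only and is stored inside the
    # SymbolClassifier below, so the same standardization is re-applied at
    # prediction time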
99 |
100 | print("...training...")
101 | classifier = svm.SVC(kernel='linear', probability=make_probabilistic)
102 | classifier.fit(training, np.ravel(labels_train) )
103 |
104 | print("...Saving to file...")
105 | symbol_classifier = SymbolClassifier(SymbolClassifier.TypeSVMLIN, classifier, classes_l, classes_dict, scaler,
106 | make_probabilistic)
107 | out_file = open(out_filename, 'wb')
108 | cPickle.dump(symbol_classifier, out_file, cPickle.HIGHEST_PROTOCOL)
109 | out_file.close()
110 |
111 | if not evaluate:
112 | #...finish early...
113 | end_time = time.time()
114 | total_elapsed = end_time - start_time
115 |
116 | print "Total Elapsed: " + str(total_elapsed)
117 | print("Finished!")
118 | return
119 |
120 | print("...Evaluating...")
121 |
122 | #...first, training error
123 | n_training_samples = np.size(training, 0)
124 | predicted = classifier.predict(training)
125 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_train, n_classes)
126 | accuracy = (float(total_correct) / float(n_training_samples)) * 100
127 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
128 | print "Training Samples: " + str(n_training_samples)
129 | print "Training Results"
130 | print "Accuracy\tClass Average\tClass STD "
131 | print str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0)
132 |
133 | #...then, testing error
134 | n_testing_samples = np.size(testing, 0)
135 | predicted = classifier.predict(testing)
136 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_test, n_classes)
137 | accuracy = (float(total_correct) / float(n_testing_samples)) * 100
138 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
139 | print "Testing Samples: " + str(n_testing_samples)
140 | print "Testing Results"
141 | print "Accuracy\tClass Average\tClass STD "
142 | print str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0)
143 |
144 | end_time = time.time()
145 | total_elapsed = end_time - start_time
146 |
147 | print "Total Elapsed: " + str(total_elapsed)
148 |
149 | print("Finished")
150 |
151 | if __name__ == '__main__':
152 |     main()
--------------------------------------------------------------------------------
/src/svm_rbf_classifier.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 | import sys
25 | import time
26 | import numpy as np
27 | import cPickle
28 | from sklearn import svm
29 | from sklearn.preprocessing import StandardScaler
30 | from dataset_ops import *
31 | from evaluation_ops import *
32 | from symbol_classifier import SymbolClassifier
33 |
34 |
35 | def main():
36 | #usage check...
37 | if len(sys.argv) < 5:
38 |         print("Usage: python svm_rbf_classifier.py training testing C Gamma [eval] [probab] [output]")
39 | print("Where")
40 | print("\ttraining\t= Path to training set")
41 | print("\ttesting\t\t= Path to testing set")
42 | print("\tC\t\t= C parameter of the RBF SVM")
43 | print("\tGamma\t\t= Gamma parameter of the RBF SVM")
44 |         print("\teval\t\t= Optional, skip evaluation if equal to 0 (default: 1)")
45 | print("\tprobab\t\t= Optional, make it a probabilistic classifier")
46 | print("\toutput\t= Optional, path to store trained classifier")
47 | return
48 |
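    # Example invocation (hypothetical paths and parameter values):
    #   python svm_rbf_classifier.py train_ds.txt test_ds.txt 10.0 0.01 1 1 model.RSVM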
49 | print("loading data")
50 |
51 | #...load training data from file...
52 | file_name = sys.argv[1]
53 | #file_name = 'ds_test_2012.txt'
54 | training, labels_l, att_types = load_dataset(file_name)
55 | #...generate mapping...
56 | classes_dict, classes_l = get_label_mapping(labels_l)
57 | n_classes = len(classes_l)
58 | #...generate mapped labels...
59 | labels_train = get_mapped_labels(labels_l, classes_dict)
60 |
61 | #now, load the testing set...
62 | file_name = sys.argv[2]
63 | testing, test_labels_l, att_types = load_dataset(file_name)
64 | #...generate mapped labels...
65 | labels_test = get_mapped_labels(test_labels_l, classes_dict)
66 |
67 | try:
68 | rbf_C = float(sys.argv[3])
69 | except:
70 | print("Invalid value for C")
71 | return
72 |
73 | try:
74 | rbf_Gamma = float(sys.argv[4])
75 | except:
76 | print("Invalid value for Gamma")
77 | return
78 |
79 | if len(sys.argv) >= 6:
80 | try:
81 | evaluate = int(sys.argv[5]) > 0
82 | except:
83 | print("Invalid evaluate value")
84 | return
85 | else:
86 | evaluate = True
87 |
88 | if len(sys.argv) >= 7:
89 | try:
90 | make_probabilistic = int(sys.argv[6]) > 0
91 | except:
92 | print("Invalid probab value")
93 | return
94 | else:
95 | make_probabilistic = False
96 |
97 | if len(sys.argv) >= 8:
98 | out_filename = sys.argv[7]
99 | else:
100 | out_filename = sys.argv[1] + ".RSVM"
101 |
102 | print("Training with: " + sys.argv[1])
103 | print("Testing with: " + sys.argv[2])
104 | print("Current C = " + str(rbf_C))
105 | print("Current Gamma = " + str(rbf_Gamma))
106 |
107 | start_time = time.time()
108 |
109 | #try scaling...
110 | scaler = StandardScaler()
111 | training = scaler.fit_transform(training)
112 | testing = scaler.transform(testing)
113 |
114 | print("...training...")
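    # C penalizes misclassified training samples and gamma controls the width of
    # the RBF kernel; both values are taken directly from the command line above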
115 | classifier = svm.SVC(C=rbf_C, gamma=rbf_Gamma, probability=make_probabilistic)
116 | classifier.fit(training, np.ravel(labels_train))
117 |
118 | #store the classifier...
119 | print("...Saving to file...")
120 | symbol_classifier = SymbolClassifier(SymbolClassifier.TypeSVMRBF, classifier, classes_l, classes_dict, scaler,
121 | make_probabilistic)
122 | out_file = open(out_filename, 'wb')
123 | cPickle.dump(symbol_classifier, out_file, cPickle.HIGHEST_PROTOCOL)
124 | out_file.close()
125 |
126 | if not evaluate:
127 | #...finish early...
128 | end_time = time.time()
129 | total_elapsed = end_time - start_time
130 |
131 | print "Total Elapsed: " + str(total_elapsed)
132 | print("Finished!")
133 | return
134 |
135 | print("...Evaluating...")
136 | #...first, training error
137 | n_training_samples = np.size(training, 0)
138 | predicted = classifier.predict(training)
139 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_train, n_classes)
140 | accuracy = (float(total_correct) / float(n_training_samples)) * 100
141 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
142 | print "Training Samples: " + str(n_training_samples)
143 | print "Training Results"
144 | print "Accuracy\tClass Average\tClass STD "
145 | print str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0)
146 |
147 |
148 | #...then, testing error
149 | n_testing_samples = np.size(testing, 0)
150 | predicted = classifier.predict(testing)
151 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_test, n_classes)
152 | accuracy = (float(total_correct) / float(n_testing_samples)) * 100
153 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
154 | print "Testing Samples: " + str(n_testing_samples)
155 | print "Testing Results"
156 | print "Accuracy\tClass Average\tClass STD "
157 | print str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0)
158 |
159 | end_time = time.time()
160 | total_elapsed = end_time - start_time
161 |
162 | print "Total Elapsed: " + str(total_elapsed)
163 |
164 | print("Finished")
165 |
166 | if __name__ == '__main__':
167 |     main()
--------------------------------------------------------------------------------
/src/symbol_classifier.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 | import numpy as np
25 | from traceInfo import TraceInfo
26 | from mathSymbol import MathSymbol
27 |
28 | class SymbolClassifier:
29 | TypeRandomForest = 1
30 | TypeSVMLIN = 2
31 | TypeSVMRBF = 3
32 |
33 | def __init__(self, type, trained_classifier, classes_list, classes_dict, scaler=None, probabilistic=False):
34 | self.type = type
35 | self.trained_classifier = trained_classifier
36 | self.classes_list = classes_list
37 | self.classes_dict = classes_dict
38 | self.scaler = scaler
39 | self.probabilistic = probabilistic
40 |
41 | def predict(self, dataset):
42 | return self.trained_classifier.predict(dataset)
43 |
44 | def predict_proba(self, dataset):
45 | return self.trained_classifier.predict_proba(dataset)
46 |
47 | def get_raw_classes(self):
48 | return self.trained_classifier.classes_
49 |
50 | def get_symbol_from_points(self, points_lists):
51 |
52 | traces = []
53 | for trace_id, point_list in enumerate(points_lists):
54 | object_trace = TraceInfo(trace_id, point_list)
55 |
56 | traces.append(object_trace)
57 |
58 | # apply general trace pre processing...
59 | # 1) first step of pre processing: Remove duplicated points
60 | object_trace.removeDuplicatedPoints()
61 |
62 | # Add points to the trace...
63 | object_trace.addMissingPoints()
64 |
65 | # Apply smoothing to the trace...
66 | object_trace.applySmoothing()
67 |
68 |             # smoothing should not re-introduce duplicated points, but check just in case
69 | if object_trace.hasDuplicatedPoints():
70 | # ...remove them! ....
71 | object_trace.removeDuplicatedPoints()
72 |
73 | new_symbol = MathSymbol(0, traces, '{Unknown}')
74 |
75 | # normalize size and locations
76 | new_symbol.normalize()
77 |
78 | return new_symbol
79 |
80 | def get_symbol_features(self, symbol):
81 | # get raw features
82 | features = symbol.getFeatures()
83 |
84 |         # put them into a numpy matrix
85 | mat_features = np.mat(features, dtype=np.float64)
86 |
87 | # automatically transform features
88 | if self.scaler is not None:
89 | mat_features = self.scaler.transform(mat_features)
90 |
91 | return mat_features
92 |
93 | def classify_points(self, points_lists):
94 | symbol = self.get_symbol_from_points(points_lists)
95 |
96 | return self.classify_symbol(symbol)
97 |
98 | def classify_points_prob(self, points_lists, top_n=None):
99 | symbol = self.get_symbol_from_points(points_lists)
100 |
101 | return self.classify_symbol_prob(symbol, top_n)
102 |
103 | def classify_symbol(self, symbol):
104 | features = self.get_symbol_features(symbol)
105 |
106 | predicted = self.trained_classifier.predict(features)
107 |
108 | return self.classes_list[predicted[0]]
109 |
110 |
111 | def classify_symbol_prob(self, symbol, top_n=None):
112 | features = self.get_symbol_features(symbol)
113 |
114 | try:
115 | predicted = self.trained_classifier.predict_proba(features)
116 | except:
117 |             raise Exception("Classifier was not trained as a probabilistic classifier")
118 |
119 | scores = sorted([(predicted[0, k], k) for k in range(predicted.shape[1])], reverse=True)
120 |
121 | tempo_classes = self.trained_classifier.classes_
122 | n_classes = len(tempo_classes)
123 | if top_n is None or top_n > n_classes:
124 | top_n = n_classes
125 |
126 | confidences = [(self.classes_list[tempo_classes[scores[k][1]]], scores[k][0]) for k in range(top_n)]
127 |
128 | return confidences
129 |
130 |
131 |
--------------------------------------------------------------------------------
/src/test_classifier.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 |
25 | import sys
26 | import cPickle
27 | from symbol_classifier import SymbolClassifier
28 |
29 | def main():
30 | # usage check...
31 | if len(sys.argv) < 2:
32 |         print("Usage: python test_classifier.py classifier")
33 | print("Where")
34 | print("\tclassifier\t= Path to trained symbol classifier")
35 | return
36 |
37 | classifier_file = sys.argv[1]
38 |
39 | print("Loading classifier")
40 |
41 | in_file = open(classifier_file, 'rb')
42 | classifier = cPickle.load(in_file)
43 | in_file.close()
44 |
45 | if not isinstance(classifier, SymbolClassifier):
46 | print("Invalid classifier file!")
47 | return
48 |
49 | # get mapping
50 | classes_dict = classifier.classes_dict
51 | classes_l = classifier.classes_list
52 | n_classes = len(classes_l)
53 |
54 | # create test characters from points
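    # each sample is a list of strokes, and each stroke is a list of (x, y)
    # points, matching the input expected by SymbolClassifier.classify_points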
55 | sample_x = [[(-0.8, -0.8), (0.1, 0.12), (0.8, 0.79)], [(-0.85, 0.79), (0.01, 0.005), (0.79, -0.83)]]
56 | sample_1 = [[(0.15, 0.7), (0.2, 1.0), (0.21, -1.2)], [(0.05, -1.19), (0.3, -1.25)]]
57 | sample_eq = [[(-1.5, -0.4), (1.5, -0.4)], [(-1.5, 0.4), (1.5, 0.4)]]
58 |
59 | # classify them
60 | class_x = classifier.classify_points(sample_x)
61 | class_1 = classifier.classify_points(sample_1)
62 | class_eq = classifier.classify_points(sample_eq)
63 |
64 | print("X classified as " + class_x)
65 | print("1 classified as " + class_1)
66 | print("= classified as " + class_eq)
67 |
68 | # now with confidence ...
69 | classes_x = classifier.classify_points_prob(sample_x, 3)
70 | classes_1 = classifier.classify_points_prob(sample_1, 3)
71 | classes_eq = classifier.classify_points_prob(sample_eq, 3)
72 |
73 | print("X top classes are: " + str(classes_x))
74 | print("1 top classes are: " + str(classes_1))
75 | print("= top classes are: " + str(classes_eq))
76 |
77 | if __name__ == '__main__':
78 | main()
--------------------------------------------------------------------------------
/src/test_classify_inkml.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2016 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 |
25 | import sys
26 | import cPickle
27 | from load_inkml import *
28 | from symbol_classifier import SymbolClassifier
29 |
30 | def main():
31 | # usage check...
32 | if len(sys.argv) < 3:
33 |         print("Usage: python test_classify_inkml.py classifier inkml_file [save_svg] [svg_path]")
34 | print("Where")
35 | print("\tclassifier\t= Path to trained symbol classifier")
36 | print("\tinkml_file\t= Path to inkml file to load")
37 | print("\tsave_svg\t= Optional, save SVG file for each symbol")
38 | print("\t\t0 - No (Default)")
39 | print("\t\t1 - Save SVG after pre-processing")
40 | print("\t\t2 - Save SVG before pre-processing")
41 | print("\tsvg_path\t= Optional, path prefix used for saved SVG files")
42 | return
43 |
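    # Example invocation (hypothetical paths):
    #   python test_classify_inkml.py trained_classifier.dat expression.inkml 1 ./svg/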
44 | classifier_file = sys.argv[1]
45 | inkml_file = sys.argv[2]
46 |
47 | if len(sys.argv) >= 4:
48 | try:
49 |             save_svg = int(sys.argv[3])
50 |             if save_svg < 0 or save_svg > 2:
51 | print("Invalid value for save_svg")
52 | return
53 | except:
54 | print("Invalid value for save_svg")
55 | return
56 | else:
57 |         save_svg = 0
58 |
59 | if len(sys.argv) >= 5:
60 | svg_path = sys.argv[4]
61 | else:
62 | svg_path = ""
63 |
64 | print("Loading classifier")
65 |
66 | in_file = open(classifier_file, 'rb')
67 | classifier = cPickle.load(in_file)
68 | in_file.close()
69 |
70 | if not isinstance(classifier, SymbolClassifier):
71 | print("Invalid classifier file!")
72 | return
73 |
74 | ground_truth_available = True
75 | try:
76 | symbols = load_inkml(inkml_file, True )
77 | except:
78 | ground_truth_available = False
79 | try:
80 | symbols = load_inkml(inkml_file, False)
81 | except:
82 | print("Failed processing: " + inkml_file)
83 | return
84 |
85 | # run the classifier for each symbol in the file
86 | for symbol in symbols:
87 | trace_ids = [str(trace.id) for trace in symbol.traces]
88 | print("Symbol id: " + str(symbol.id))
89 | print("=> Traces: " + ", ".join(trace_ids))
90 | print("=> Ground Truth Class: " + symbol.truth)
91 |
92 | # classify
93 | main_class = classifier.classify_symbol(symbol)
94 | print("=> Predicted Class: " + main_class)
95 |
96 | top_5 = classifier.classify_symbol_prob(symbol, 5)
97 | top_5_desc = [class_name + " ({0:.2f}%)".format(prob * 100) for class_name, prob in top_5]
98 | print("=> Top 5 classes: " + ",".join(top_5_desc))
99 | print("")
100 |
101 |         if save_svg == 2:
102 | symbol.swapPoints()
103 |
104 |         if save_svg >= 1:
105 | symbol.saveAsSVG(svg_path + "symbol_" + str(symbol.id) + ".svg")
106 |
107 |
108 | print("Finished")
109 |
110 | if __name__ == '__main__':
111 | main()
--------------------------------------------------------------------------------
/src/train_adaboost.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 | import sys
25 | import ctypes
26 | import time
27 | import numpy as np
28 | from dataset_ops import *
29 | from evaluation_ops import *
30 |
31 | #=====================================================================
32 | # This program takes as input a data set and, through AdaBoost M1, generates
33 | # an ensemble of decision trees to classify the data. The classifier is then
34 | # written to a file for further use.
35 | #
36 | # Created by:
37 | # - Kenny Davila (Dec 27, 2013)
38 | # Modified By:
39 | # - Kenny Davila (Dec 27, 2013)
40 | # - Kenny Davila (Jan 19, 2013)
41 | # - Dataset functions are now imported from external file
42 | # - Kenny Davila (Jan 27, 2013)
43 | # - Added mapping
44 | # - Save result to file
45 | #
46 | #=====================================================================
47 |
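# Rough sketch of the AdaBoost M1 procedure implemented on the C side (see
# adaboost_c45.c); shown here only as a reference, the exact details live in
# the C implementation:
#
#   for t in 1..rounds:
#       h_t    = C4.5 tree trained on the weighted samples
#       eps_t  = total weight of the samples that h_t misclassifies
#       beta_t = eps_t / (1 - eps_t)
#       scale the weights of correctly classified samples by beta_t, renormalize
#   final prediction: weighted vote of all h_t, each weighted by log(1 / beta_t)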
48 | c45_lib = ctypes.CDLL('./adaboost_c45.so')
49 |
50 |
51 | def ensemble_predict(classifier, data):
52 | n_samples = np.size(data, 0)
53 | predicted = np.zeros(n_samples)
54 |
55 | for k in range(n_samples):
56 | c_sample = data[k, :]
57 | c_sample_p = c_sample.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
58 |
59 | predicted[k] = c45_lib.boosted_c45_classify(classifier, c_sample_p)
60 |
61 | return predicted
62 |
63 | def main():
64 | #usage check
65 | if len(sys.argv) < 4:
66 | print("Usage: python train_adaboost.py training_set testing_set rounds [mapped] [save_weights]")
67 | print("Where")
68 | print("\ttraining_set\t= Path to the file of the training set")
69 | print("\ttesting_set\t= Path to the file of the testing set")
70 | print("\trounds\t\t= Rounds to use for AdaBoost")
71 | print("\tmapped\t= Optional, reads class mapping if specified")
72 | print("\tsave_weights\t= Optional, save the weights used for boosting")
73 | return
74 |
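    # Example invocation (hypothetical paths):
    #   python train_adaboost.py train_ds.txt test_ds.txt 50 0 0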
75 | try:
76 | rounds = int(sys.argv[3])
77 | if ( rounds < 1):
78 | print("Must be at least 1 round")
79 | return
80 | except:
81 | print("Invalid number of rounds")
82 | return
83 |
84 | #...load training data from file...
85 | #file_name = 'ds_2012.txt'
86 | file_name = sys.argv[1]
87 | #file_name = 'ds_test_2012.txt'
88 | training, labels_l, att_types = load_dataset(file_name)
89 | #...generate mapping...
90 | classes_dict, classes_l = get_label_mapping(labels_l)
91 | #...generate mapped labels...
92 | labels_train = get_mapped_labels(labels_l, classes_dict)
93 |
94 | if len(sys.argv) >= 5:
95 | try:
96 | mapped = int(sys.argv[4]) > 0
97 | except:
98 | mapped = False
99 |
100 | #...used for the clustered training set
101 | if mapped:
102 | o_classes_l, o_classes_dict, class_mapping = load_label_mapping(file_name + ".mapping.txt")
103 | else:
104 |         mapped = False
105 |
106 | if len(sys.argv) >= 6:
107 | try:
108 | save_weights = int(sys.argv[5]) > 0
109 | except:
110 | save_weights = False
111 | else:
112 | save_weights = False
113 |
114 | #...load testing data from file...
115 | #file_name = 'ds_test_2012.txt'
116 | file_name = sys.argv[2]
117 | #file_name = 'ds_2012.txt'
118 | testing, test_labels_l, att_types = load_dataset(file_name)
119 | #...generate mapped labels...
120 | if mapped:
121 | labels_test = get_mapped_labels(test_labels_l, o_classes_dict)
122 | else:
123 | labels_test = get_mapped_labels(test_labels_l, classes_dict)
124 |
125 | #... for learning....
126 | n_samples = np.size(training, 0)
127 | n_atts = np.size(att_types, 0)
128 | n_classes = len( classes_l )
129 | print "Samples : " + str(n_samples)
130 | print "Atts: " + str(n_atts)
131 | print "Classes: " + str(n_classes)
132 |
133 | #...create distribution
134 | init_value = 1.0 / float(n_samples)
135 | distrib_train = np.zeros( (n_samples, 1), dtype=np.float64 )
136 | for i in range(n_samples):
137 | distrib_train[i, 0] = init_value
138 |
139 | #prepare to pass data to C-side
140 | samples_p = training.ctypes.data_as( ctypes.POINTER( ctypes.c_double ) )
141 |
142 | labels_p = labels_train.ctypes.data_as( ctypes.POINTER( ctypes.c_int ) )
143 | atts_p = att_types.ctypes.data_as( ctypes.POINTER( ctypes.c_int ) )
144 |
145 | #for testing
146 | n_test_samples = np.size(testing, 0)
147 |
148 | start_time = time.time()
149 |
150 | #test construction...
151 | m_rounds = rounds
152 | ensemble = c45_lib.created_boosted_c45( samples_p, labels_p, atts_p, n_samples, n_atts, n_classes, m_rounds, 5, 1, "Fold #1")
153 |
154 | print "Saving..."
155 | c45_lib.boosted_c45_save(ensemble, sys.argv[1] + ".bc45")
156 |
157 | if save_weights:
158 | c45_lib.boosted_c45_save_training_weights(ensemble, sys.argv[1] + ".weights.bin", 0)
159 | c45_lib.boosted_c45_save_training_weights(ensemble, sys.argv[1] + ".weights.txt", 1)
160 |
161 | #test accuracy of tree...
162 |
163 | #...training...
164 | predicted = ensemble_predict(ensemble, training)
165 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_train, n_classes)
166 | accuracy = (float(total_correct) / float(n_samples)) * 100
167 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
168 | print "Training Samples: " + str(n_samples)
169 | print "Training Results"
170 | print "Accuracy\tClass Average\tClass STD "
171 | print str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0)
172 |
173 | #...testing...
174 | n_testing_samples = np.size(testing, 0)
175 | predicted = ensemble_predict(ensemble, testing)
176 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_test, n_classes)
177 | accuracy = (float(total_correct) / float(n_testing_samples)) * 100
178 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
179 | print "Testing Samples: " + str(n_testing_samples)
180 | print "Testing Results"
181 | print "Accuracy\tClass Average\tClass STD "
182 | print str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0)
183 |
184 |
185 | c45_lib.release_boosted_c45( ensemble )
186 |
187 | end_time = time.time()
188 | total_elapsed = end_time - start_time
189 | print "Total Elapsed: " + str(total_elapsed)
190 | print("Training: " + sys.argv[1])
191 | print("Testing: " + sys.argv[2])
192 |
193 | if __name__ == '__main__':
194 |     main()
195 |
196 |
--------------------------------------------------------------------------------
/src/train_c45.py:
--------------------------------------------------------------------------------
1 | """
2 | DPRL Math Symbol Recognizers
3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi
4 |
5 | This file is part of DPRL Math Symbol Recognizers.
6 |
7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with DPRL Math Symbol Recognizers. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | Contact:
21 | - Kenny Davila: kxd7282@rit.edu
22 | - Richard Zanibbi: rlaz@cs.rit.edu
23 | """
24 | import ctypes
25 | import time
26 | import sys
27 | import numpy as np
28 | from dataset_ops import *
29 | from evaluation_ops import *
30 |
31 | #=====================================================================
32 | # This program takes as input a data set and builds a single C4.5 decision
33 | # tree to classify the data. The tree is evaluated on both the training and
34 | # testing sets before and after pruning (confidence factor 0.25), and the
35 | # elapsed construction time is reported.
36 | #
37 | # Created by:
38 | # - Kenny Davila (Dec 27, 2013)
39 | # Modified By:
40 | # - Kenny Davila (Jan 16, 2013)
41 | # - Kenny Davila (Jan 19, 2013)
42 | # - Dataset functions are now imported from external file
43 | #
44 | #=====================================================================
45 |
46 | c45_lib = ctypes.CDLL('./adaboost_c45.so')
47 |
48 |
49 | def get_majority_class(labels, weights, classes):
50 | l_counts = {}
51 | n_samples = np.size( labels, 0)
52 | majority_class = None
53 | majority_w = 0
54 | for i in range(n_samples):
55 | label = labels[i, 0]
56 |
57 | if label in l_counts:
58 | l_counts[ label ] += weights[i, 0]
59 | else:
60 | l_counts[ label ] = weights[i, 0]
61 |
62 | if majority_class == None:
63 | majority_class = label
64 | majority_w = l_counts[ label ]
65 | else:
66 | if majority_w < l_counts[ label ]:
67 | majority_w = l_counts[ label ]
68 | majority_class = label
69 |
70 | """
71 | for l in l_counts:
72 | print str(classes[l]) + " - " + str(l) + " - " + str(l_counts[l] )
73 | """
74 |
75 | return majority_class, majority_w
76 |
77 |
78 | def tree_predict(root, data):
79 | n_samples = np.size(data, 0)
80 | predicted = np.zeros(n_samples)
81 |
82 | for k in range(n_samples):
83 | c_sample = data[k, :]
84 | c_sample_p = c_sample.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
85 |
86 | predicted[k] = c45_lib.c45_node_evaluate(root, c_sample_p)
87 |
88 | return predicted
89 |
90 |
91 | def main():
92 | if len(sys.argv) < 3:
93 | print("Usage: python train_c45.py training_set testing_set [mapped]")
94 | print("Where")
95 | print("\ttraining_set\t= Path to the file of the training set")
96 | print("\ttesting_set\t= Path to the file of the testing set")
97 | print("\tmapped\t= Optional, reads class mapping if specified")
98 | return
99 |
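    # Example invocation (hypothetical paths):
    #   python train_c45.py train_ds.txt test_ds.txt 0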
100 | #...load training data from file...
101 | #file_name = 'ds_2012.txt'
102 | file_name = sys.argv[1]
103 | training, labels_l, att_types = load_dataset(file_name)
104 | #...generate mapping...
105 | classes_dict, classes_l = get_label_mapping(labels_l)
106 | #...generate mapped labels...
107 | labels_train = get_mapped_labels(labels_l, classes_dict)
108 |
109 | #...check for class mapping...
110 | if len(sys.argv) >= 4:
111 | try:
112 | mapped = int(sys.argv[3]) > 0
113 | except:
114 | mapped = False
115 |
116 | #...will over-write the class list and class dict, and will extract the class mapping
117 | #...used for the clustered training set
118 | if mapped:
119 | o_classes_l, o_classes_dict, class_mapping = load_label_mapping(file_name + ".mapping.txt")
120 |
121 | else:
122 | mapped = False
123 |
124 |
125 | #...load testing data from file...
126 | #file_name = 'ds_test_2012.txt'
127 | file_name = sys.argv[2]
128 | testing, test_labels_l, att_types = load_dataset(file_name)
129 | #...generate mapped labels...
130 | if mapped:
131 | labels_test = get_mapped_labels(test_labels_l, o_classes_dict)
132 | else:
133 | labels_test = get_mapped_labels(test_labels_l, classes_dict)
134 | #print("UNCOMMENT MAPPING OF TESTING LABELS")
135 |
136 |
137 |
138 | print("Training: " + sys.argv[1])
139 | print("Testing: " + sys.argv[2])
140 |
141 | #for learning and testing....
142 | n_samples = np.size(training, 0)
143 | n_test_samples = np.size(testing, 0)
144 | n_atts = np.size(att_types, 0)
145 | n_classes = len( classes_l )
146 | print "Training Samples : " + str(n_samples)
147 | print "Testing Samples : " + str(n_test_samples)
148 | print "Atts: " + str(n_atts)
149 | print "Classes: " + str(n_classes)
150 |
151 | #...create distribution
152 | init_value = 1.0 / float(n_samples)
153 | distrib_train = np.zeros((n_samples, 1), dtype=np.float64)
154 | for i in range(n_samples):
155 | distrib_train[i, 0] = init_value
156 |
157 | #prepare to pass data to C-side
158 | samples_p = training.ctypes.data_as( ctypes.POINTER( ctypes.c_double ) )
159 |
160 | labels_p = labels_train.ctypes.data_as( ctypes.POINTER( ctypes.c_int ) )
161 | atts_p = att_types.ctypes.data_as( ctypes.POINTER( ctypes.c_int ) )
162 | distrib_p = distrib_train.ctypes.data_as( ctypes.POINTER( ctypes.c_double ) )
163 |
164 |
165 |
166 | #get majority class...
167 | majority_l, majority_w = get_majority_class(labels_train, distrib_train, classes_l)
168 |
169 | parent = None
170 | majority = int(majority_l)
171 | #print majority
172 | #print type(majority)
173 |
174 | start_time = time.time()
175 |
176 | #test construction...
177 | times = 1
178 | for i in range( times ):
179 | root = c45_lib.c45_tree_construct(samples_p, labels_p, atts_p, distrib_p, n_samples, n_atts, n_classes, majority, 5)
180 |
181 | """
182 | #un-comment to test file
183 | print "to write ... "
184 | c45_lib.c45_save_to_file( root, "tree_2012.tree" )
185 | print "to read ... "
186 | root = c45_lib.c45_load_from_file( "tree_2012.tree" )
187 | #return
188 | """
189 |
190 | #test accuracy of tree...
191 | predicted = tree_predict(root, training)
192 |
193 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_train, n_classes)
194 | accuracy = (float(total_correct) / float(n_samples)) * 100
195 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
196 | print "Results\tAccuracy\tClass Average\tClass STD "
197 | print "BP. Training\t" + str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0)
198 |
199 | #...testing...
200 | predicted = tree_predict(root, testing)
201 |
202 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_test, n_classes)
203 | accuracy = (float(total_correct) / float(n_test_samples)) * 100
204 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
205 | print "BP. Testing\t" + str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0)
206 |
207 | #do pruning
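    # argtypes must be declared so that ctypes converts the confidence factor
    # (0.25) into a C double; plain Python floats are not converted automatically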
208 | c45_lib.c45_prune_tree.argtypes = (ctypes.c_void_p, ctypes.c_int, ctypes.c_double )
209 | #c45_lib.c45_prune_tree( root, 0, 0.25 )
210 | c45_lib.c45_prune_tree(root, 1, 0.25)
211 |
212 | #test accuracy of tree...
213 | #...training...
214 | predicted = tree_predict(root, training)
215 |
216 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_train, n_classes)
217 | accuracy = (float(total_correct) / float(n_samples)) * 100
218 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
219 | print "AP. Training\t" + str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0)
220 |
221 | #...testing...
222 | predicted = tree_predict(root, testing)
223 |
224 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_test, n_classes)
225 | accuracy = (float(total_correct) / float(n_test_samples)) * 100
226 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes)
227 | print "AP. Testing\t" + str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0)
228 |
229 |
230 | c45_lib.c45_node_release( root, 1 )
231 |
232 | end_time = time.time()
233 | total_elapsed = end_time - start_time
234 | mean_elapsed = total_elapsed / times
235 | print "Total Elapsed: " + str(total_elapsed)
236 | print "Mean Elapsed: " + str(mean_elapsed)
237 | print("")
238 |
239 | if __name__ == '__main__':
240 |     main()
241 |
242 |
--------------------------------------------------------------------------------