├── .DS_Store ├── AUTHORS.txt ├── LICENSE.txt ├── README.md ├── doc ├── README_apply_PCA_parameters.txt ├── README_boosted_test.txt ├── README_correct_labels.txt ├── README_count_common.txt ├── README_dataset_info.txt ├── README_extract_symbol.txt ├── README_get_PCA_parameters.txt ├── README_get_enhanced_clustered_set.txt ├── README_get_training_set.txt ├── README_parallel_evaluate.txt ├── README_parallel_prob_evaluate.txt ├── README_random_forest_classify.txt ├── README_svm_lin_classifier.txt ├── README_svm_rbf_classifier.txt ├── README_train_adaboost.txt └── README_train_c45.txt ├── server ├── PenStrokeServer.py ├── best_full2013_SVMRBF_new.dat ├── generic_symbol_table.csv ├── mathSymbol.py ├── start ├── symbol_classifier.py ├── test_classifier.py └── traceInfo.py └── src ├── adaboost_c45.c ├── ambiguous.txt ├── apply_PCA_parameters.py ├── boosted_test.py ├── c45.h ├── correct_labels.py ├── count_common.py ├── dataset_info.py ├── dataset_ops.py ├── distorter.py ├── distorter_lib.c ├── evaluation_ops.py ├── extract_symbol.py ├── get_PCA_parameters.py ├── get_enhanced_clustered_set.py ├── get_training_set.py ├── load_inkml.py ├── mathSymbol.py ├── parallel_evaluate.py ├── parallel_prob_evaluate.py ├── random_forest_classify.py ├── svm_lin_classifier.py ├── svm_rbf_classifier.py ├── symbol_classifier.py ├── test_classifier.py ├── test_classify_inkml.py ├── traceInfo.py ├── train_adaboost.py └── train_c45.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DPRL/MathSymbolRecognizer/21a68677770a953d176de57356685027e27c0414/.DS_Store -------------------------------------------------------------------------------- /AUTHORS.txt: -------------------------------------------------------------------------------- 1 | DPRL Math Symbol Recognizers 2 | Copyright (c) 2014 Kenny Davila, Richard Zanibbi 3 | 4 | This collection of classifiers for isolated handwritten math symbols was 5 | created by Kenny Davila on-and-off over a year-and-a-half. The code was 6 | first created for a project for Dr. Zanibbi's Pattern Recognition course in 7 | Fall 2012, was then used for the RIT entry in the CROHME 2013 handwritten math 8 | recognition competition, and was then later extended and revised under Dr. 9 | Zanibbi's supervision. 10 | 11 | The system was developed at the Rochester Institute of Technology in 12 | Rochester, New York, USA, primarily in the Document and Pattern Recognition 13 | Lab (DPRL) located within the Computer Science Department. 14 | 15 | Please see the LICENSE.txt file for terms of use (the system is being 16 | distributed under the GNU General Public License version 3). 17 | 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DPRL Math Symbol Recognizers 2 | 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | *** 5 | 6 | RIT DPRL Math Symbol Recognizers is free software: you can redistribute it 7 | and/or modify it under the terms of the GNU General Public License as published 8 | by the Free Software Foundation, either version 3 of the License, or (at your 9 | option) any later version. 10 | 11 | RIT DPRL Math Symbol Recognizers is distributed in the hope that it will be 12 | useful, but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General 14 | Public License for more details. 
15 | 16 | You should have received a copy of the GNU General Public License along with 17 | RIT DPRL Math Symbol Recognizers. If not, see . 18 | 19 | Contact: 20 | * Kenny Davila: kxd7282@rit.edu 21 | * Richard Zanibbi: rlaz@cs.rit.edu 22 | 23 | *** 24 | The system divides is composed of different tools for data extraction, training 25 | and evaluation and other miscellaneous tools for isolated math symbol 26 | recognition. 27 | 28 | A README file is included in the doc/ directory for each tool that describes 29 | its purpose, how to use it and what its parameters are. 30 | 31 | A README file is included in the doc/ directory for each tool that describes 32 | its purpose, how to use it and what its parameters are. 33 | 34 | The executable scripts on this release are the following: 35 | 36 | * Preprocessing of data: 37 | apply_PCA_parameters.py 38 | correct_labels.py 39 | get_enhanced_clustered_set.py 40 | get_PCA_parameters.py 41 | get_training_set.py 42 | * Analysis of datasets: 43 | count_common.py 44 | dataset_info.py 45 | extract_symbol.py 46 | 47 | * Training a symbol classifier: 48 | random_forest_classify.py 49 | svm_lin_classifier.py 50 | svm_rbf_classifier.py 51 | train_adaboost.py 52 | train_c45.py 53 | * Tools for evaluation 54 | boosted_test.py 55 | parallel_evaluate.py 56 | parallel_prob_evaluate.py 57 | 58 | 59 | 60 | *** 61 | # SOURCE FILES 62 | Source code (in Python and C) is provided in the src/ directory. -------------------------------------------------------------------------------- /doc/README_apply_PCA_parameters.txt: -------------------------------------------------------------------------------- 1 | 2 | Tool for Application of PCA parameters 3 | 4 | The tool will take a file containing a dataset, a file containing PCA parameters 5 | (Normalization, Eigenvectors, Eigenvalues, Max PC to use) and will produce the projected version 6 | of the dataset and store it in the specified file. 7 | 8 | Usage: python apply_PCA_parameters.py training_set PCA_params output 9 | Where 10 | training_set = Path to the file of the training_set 11 | PCA_params = Path to the file of the PCA parameters 12 | output = File to output the final dataset with reduced dimensionality 13 | 14 | 15 | -------------------------------------------------------------------------------- /doc/README_boosted_test.txt: -------------------------------------------------------------------------------- 1 | 2 | Tool for evaluating AdaBoost C4.5 classifier 3 | 4 | This tool evaluates the training and testing performance of an AdaBoost with C4.5 decision trees 5 | classifier. It stores the final confusion matrix into a file. It can also save the failure cases 6 | for the testing set in .SVG format if the file with the sources for each sample is present on 7 | current directory. Saving failure cases requires a directory called "output" to be created on 8 | the same directory where the tool is located. 9 | 10 | Usage: python boosted_test.py classifier training_set testing_set [save_fail] 11 | Where 12 | classifier = Path to the .bc45 file that contains the classifier 13 | training_set = Path to the file of the training set 14 | testing_set = Path to the file of the testing set 15 | save_fail = Optional, will output failure case if specified 16 | 17 | 18 | ================================================ 19 | Notes 20 | ================================================ 21 | 22 | - All tools that work with AdaBoost require the AdaBoost library to be located 23 | on the same directory. 
This is a shared library that can be compiled from 24 | the provided source code using the following commands: 25 | 26 | Linux: 27 | gcc -shared -fPIC adaboost_c45.c -o adaboost_c45.so 28 | 29 | Windows (using MinGW): 30 | gcc -shared adaboost_c45.c -o adaboost_c45.so 31 | 32 | -------------------------------------------------------------------------------- /doc/README_correct_labels.txt: -------------------------------------------------------------------------------- 1 | 2 | Tool for correcting labels in CROHME datasets 3 | 4 | CROHME 2013 datasets contain more labels than classes as defined by the competition. To fix this 5 | issue use this tool to automatically correct the additional labels to the set of 101 valid classes. 6 | 7 | Usage: python correct_labels.py training_set output_set 8 | Where 9 | training_set = Path to the file of the training_set 10 | output_set = File to output file with corrected labels 11 | 12 | -------------------------------------------------------------------------------- /doc/README_count_common.txt: -------------------------------------------------------------------------------- 1 | 2 | Tool for count and extraction of data from common classes between datasets 3 | 4 | Use count_common.py to analyze sets of common labels between two given datasets. 5 | If specified, the program can extract the data of the common classes from the second 6 | dataset and store it in a new file called: [Name of dataset 2]".common.txt". 7 | 8 | This tool can be used for example to extract data from CROHME 2013 dataset that 9 | have the same labels as a CROHME 2012 dataset. 10 | 11 | 12 | Usage: python count_common.py dataset_1 dataset_2 extract 13 | Where 14 | dataset_1 = Path to the file of the first dataset 15 | dataset_2 = Path to the file of the second dataset 16 | extract = Extract common samples from second dataset 17 | 18 | -------------------------------------------------------------------------------- /doc/README_dataset_info.txt: -------------------------------------------------------------------------------- 1 | 2 | Tool to extract general information from dataset file 3 | 4 | Use dataset_info.py to compute the number of classes present in dataset file and 5 | the number of samples per class. It also shows the number of current attributes. 6 | An histogram of the class representation is built using the specified number of 7 | bins (10 by default). 8 | 9 | Usage: python dataset_info.py dataset [n_bins] 10 | Where 11 | dataset = Path to file that contains the data set 12 | n_bins = Optional, number of bins for histogram of class representation 13 | 14 | -------------------------------------------------------------------------------- /doc/README_extract_symbol.txt: -------------------------------------------------------------------------------- 1 | 2 | Tool to extract an specific symbol from an inkml file 3 | 4 | Use extract_symbol.py to extract a symbol and store in .SVG format on the specified 5 | output file. Requires the full path to the inkml_file where the desired symbol is 6 | located and the id of the symbol to extract from that file. This tool is useful to 7 | visualize special cases. 
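For example, a (hypothetical) invocation that extracts the symbol with id 12 from one
CROHME inkml file and writes it to an SVG file could look like this; the file names
here are placeholders, not files shipped with the distribution:

    python extract_symbol.py /path/to/CROHME/formula_001.inkml 12 sym_12.svg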
8 | 9 | Usage: python extract_symbol.py inkml_file sym_id output 10 | Where 11 | inkml_file = Path to the inkml file that contains the symbol 12 | sym_id = Id of the symbol to extract 13 | output = File where the extracted symbol will be stored 14 | 15 | -------------------------------------------------------------------------------- /doc/README_get_PCA_parameters.txt: -------------------------------------------------------------------------------- 1 | 2 | Tool for extraction of PCA parameters 3 | 4 | The tool will take a file containing a dataset and will compute the corresponding 5 | PCA parameters (Normalization, Eigenvectors, Eigenvalues, max number of PCA). 6 | and will store them to a file. The user can specify a maximum number of PCA to use 7 | and a maximum value of variance, the program will use the smallest of these two 8 | as the final number of principal components to store. 9 | 10 | Usage: python get_PCA_parameters.py training_set output_params var_max k_max 11 | Where 12 | training_set = Path to the file of the training_set 13 | output_params = File to output PCA preprocessing parameters 14 | var_max = Maximum percentage of variance to add 15 | k_max = Maximum number of Principal Components to Add 16 | -------------------------------------------------------------------------------- /doc/README_get_enhanced_clustered_set.txt: -------------------------------------------------------------------------------- 1 | 2 | Tool used for expansion of a dataset 3 | 4 | Use get_enhanced_clustered_set.py to expand a dataset using synthetic data. Two ways of expansion are 5 | provided, one for under-represented classes and the second for over-represented classes. The min_prc 6 | parameter defines a threshold that separates the under_represented classes from the over-represented. 7 | Note that min_prc also defines the minimum representation per class based on the size of the largest class. 8 | For example, if the largest class has 5,000 samples and min_prc is set to 0.20, then all classes will have 9 | at least 20% x 5,0000 = 1,000 samples. For over-represented classes clustering is applied, and then the 10 | clust_prc parameter is used to defined the minimum final size for each cluster per class based on the size 11 | of largest cluster per each class. 12 | 13 | Usage: python get_enhanced_clustered_set.py inkml_path output min_prc diag_dist 14 | max_clusters clust_prc [verbose] [count_only] 15 | Where 16 | inkml_path = Path to directory that contains the inkml files 17 | output = File name of the output file 18 | min_prc = Minimum representation based on (%) of largest class 19 | diag_dist = Distortion factor relative to length of main diagonal 20 | max_clusters = Maximum number of clusters per large class 21 | clust_prc = Minimum cluster size based on (%) of largest 22 | verbose = Optional, print detailed messages 23 | count_only = Will only count what will be the final size of dataset 24 | 25 | 26 | ================================================ 27 | Notes 28 | ================================================ 29 | 30 | - All tools that work with synthetic data generation require the distorter library 31 | to be located on the same directory. 
This is a shared library that can be compiled from 32 | the provided source code using the following commands: 33 | 34 | Linux: 35 | gcc -shared -fPIC distorter_lib.c -o distorter_lib.so 36 | 37 | Windows (using MinGW): 38 | gcc -shared distorter_lib.c -o distorter_lib.so 39 | 40 | - To only expand data without using the clustering option for large classes 41 | use max_clusters = 1 and clust_prc = 0.0. -------------------------------------------------------------------------------- /doc/README_get_training_set.txt: -------------------------------------------------------------------------------- 1 | 2 | Tool used to produce a training set from INKML files 3 | 4 | Use get_training_set.py tool to process and extract the features of the isolated 5 | symbols present in a set of INKML files. The system first extract all symbols found 6 | in the inkml files present in the given directory and then for each symbol it will 7 | extract the current set of features as defined in MathSymbol.py. Then, final dataset 8 | ready to use for training will be stored in the specified output file. 9 | 10 | Usage: python get_training_set.py inkml_path output 11 | Where 12 | inkml_path = Path to directory that contains the inkml files 13 | output = File name of the output file 14 | -------------------------------------------------------------------------------- /doc/README_parallel_evaluate.txt: -------------------------------------------------------------------------------- 1 | 2 | Tool for evaluation of classifier performance in parallel 3 | 4 | Use parallel_evaluate.py to evaluate the performance of a given classifier over 5 | training and testing sets in parallel threads for shorter evaluation time. 6 | Currently, this tool only supports classifiers from the Scikit-learn library like 7 | SVC (support vector classifier) and Random Forests. It does not support AdaBoost with 8 | C4.5 classifier. For AdaBoost C4.5 classifier use "boosted_test.py". 9 | 10 | The tool computes global and per class accuracy with the corresponding confusion matrix. If 11 | specified, the tool computes the the global and per class accuracy with confusion matrix for 12 | the case where errors between ambiguous classes are ignored. The list of ambiguous classes 13 | simply contains pairs of symbols considered ambiguous, one pair per line. Separate the 14 | symbols on each line with comma. 15 | 16 | 17 | Usage: python parallel_evaluate.py training_set testing_set classifier normalize 18 | workers [test_only] [ambiguous] 19 | Where 20 | training_set = Path to the file of the training set 21 | testing_set = Path to the file of the testing set 22 | classifier = File that contains the pickled classifier 23 | normalized = Whether training data was normalized prior training 24 | workers = Number of parallel threads to use 25 | test_only = Optional, Only execute for testing set 26 | ambiguous = Optional, file that contains the list of ambiguous 27 | 28 | 29 | ================================================ 30 | Notes 31 | ================================================ 32 | 33 | - For SVM classifiers remember to use the "normalized" option as 1. 
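- The optional ambiguous file is plain text with one comma-separated pair of
  labels per line; src/ambiguous.txt in this distribution is an example of the
  expected format, e.g.:

      x,X
      1,|
      o,0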
-------------------------------------------------------------------------------- /doc/README_parallel_prob_evaluate.txt: -------------------------------------------------------------------------------- 1 | 2 | Tool for evaluation of probabilistic classifier performance in parallel 3 | 4 | Use parallel_prob_evaluate.py to evaluate the performance of a given probabilistic 5 | classifier over training and testing sets in parallel threads for shorter evaluation time. 6 | Currently, this tool only supports classifiers from the Scikit-learn library like 7 | SVC (support vector classifier) trained with the probabilistic option and Random Forests. 8 | It does not support AdaBoost with C4.5 classifier. For AdaBoost C4.5 classifier use "boosted_test.py". 9 | 10 | The tool computes global and per class accuracy from Top-1 to Top-5. If 11 | specified, the tool computes the the global and per class accuracy Top-1 to Top-5 for 12 | the case where errors between ambiguous classes are ignored. The list of ambiguous classes 13 | simply contains pairs of symbols considered ambiguous, one pair per line. Separate the 14 | symbols on each line with comma. 15 | 16 | 17 | Usage: python parallel_prob_evaluate.py training_set testing_set classifier normalize 18 | workers [test_only] [ambiguous] 19 | Where 20 | training_set = Path to the file of the training set 21 | testing_set = Path to the file of the testing set 22 | classifier = File that contains the pickled classifier 23 | normalized = Whether training data was normalized prior training 24 | workers = Number of parallel threads to use 25 | test_only = Optional, Only execute for testing set 26 | ambiguous = Optional, file that contains the list of ambiguous 27 | 28 | 29 | ================================================ 30 | Notes 31 | ================================================ 32 | 33 | - For SVM classifiers remember to use the "normalized" option as 1. 34 | - For SVM classifiers, only the ones trained as probabilistic are supported. -------------------------------------------------------------------------------- /doc/README_random_forest_classify.txt: -------------------------------------------------------------------------------- 1 | 2 | Tool for training and testing of Random Forest classifiers 3 | 4 | Use random_forest_classify.py to train and test the performance of Random Forest classifiers 5 | over the given training and testing sets. Must specify the number of decision trees to include 6 | on each forest and the maximum depth of each tree. Also, the current implementation of random 7 | forest used is included in the Scikit-Learn library. This implementation randomize the available 8 | features at each split, and the max_feats parameter is used to set the maximum number of features 9 | to consider at each split. Two split criterion are available: Gini impurity and Entropy. 10 | 11 | Since random forests introduce randomness on the training process, a single classifier might not be 12 | enough to define the final performance of this type of classifier over the given dataset. The 13 | parameter times allows the user to define how many classifiers will be trained and tested and the 14 | final mean of them is going to be used for the final metric. At the end, the system keeps the classifier 15 | that achieved the highest global accuracy and stores it to a final named: [Name of training set]".best.RF". 16 | 17 | Use n_jobs parameter to define the number of threads to use during the training process. 
if omitted, 18 | everything will be done on a single thread. 19 | 20 | Usage: python random_forest_classify.py training_set testing_set N_trees max_D max_feats 21 | type times [n_jobs] 22 | Where 23 | training_set = Path to the file of the training set 24 | testing_set = Path to the file of the testing set 25 | N_trees = Number of trees to use 26 | max_D = Maximum Depth 27 | max_feats = Maximum Features 28 | type = Type of Decision trees (criterion for splits) 29 | 0 - Gini 30 | 1 - Entropy 31 | times = Number of times to repeat experiments 32 | n_jobs = Optional, number of parallel threads to use 33 | 34 | -------------------------------------------------------------------------------- /doc/README_svm_lin_classifier.txt: -------------------------------------------------------------------------------- 1 | 2 | Tool for training and evaluation of SVM with linear kernel. 3 | 4 | Use svm_lin_classifier.py to train and evaluate the performance of SVM classifier with 5 | linear kernel over the specified dataset. Since testing can take very long for large 6 | datasets, the "evaluate" option is provided. By default the program will evaluate 7 | the classifier training and testing error after the training process is complete. 8 | However, if evaluate parameter is 0, the program will terminate after training is 9 | done and parallel_evalaute.py or parallel_prob_evalaute.py can be used for evaluation 10 | later with even more complete output. The probab parameter can be specified to make 11 | the classifier probabilistic. Note that top-1 accuracy of probabilistic classifiers 12 | might be a little bit lower than non-probabilistic classifier. Probabilistic 13 | classifiers are required to compute top-5 classification accuracy. After training 14 | is finished, the program will store the trained classifier on a file called: 15 | [Name of training]".LSVM" 16 | 17 | Usage: python svm_lin_classifier.py training testing [evaluate] [probab] 18 | Where 19 | training = Path to training set 20 | testing = Path to testing set 21 | evaluate = Optional, run evaluation or just training 22 | probab = Optional, make it a probabilistic classifier 23 | 24 | -------------------------------------------------------------------------------- /doc/README_svm_rbf_classifier.txt: -------------------------------------------------------------------------------- 1 | 2 | Tool for training and evaluation of SVM with Radial-Basis Function (RBF) kernel. 3 | 4 | Use svm_rbf_classifier.py to train and evaluate the performance of SVM classifier 5 | with RBF kernel over the specified dataset. Since testing can take very long for large 6 | datasets, the "evaluate" option is provided. By default the program will evaluate 7 | the classifier training and testing error after the training process is complete. 8 | However, if evaluate parameter is 0, the program will terminate after training is 9 | done and parallel_evalaute.py or parallel_prob_evalaute.py can be used for evaluation 10 | later with even more complete output. The probab parameter can be specified to make 11 | the classifier probabilistic. Note that top-1 accuracy of probabilistic classifiers 12 | might be a little bit lower than non-probabilistic classifier. Probabilistic 13 | classifiers are required to compute top-5 classification accuracy. 
After training 14 | is finished, the program will store the trained classifier on a file called: 15 | [Name of training]".LSVM" 16 | 17 | Usage: python svm_rbf_classifier.py training testing C Gamma eval [probab] 18 | Where 19 | training = Path to training set 20 | testing = Path to testing set 21 | C = C parameter of the RBF SVM 22 | Gamma = Gamma parameter of the RBF SVM 23 | eval = Optional, will not evaluate if equal to 1 24 | probab = Optional, make it a probabilistic classifier -------------------------------------------------------------------------------- /doc/README_train_adaboost.txt: -------------------------------------------------------------------------------- 1 | 2 | Tool for training and evaluation of AdaBoost with C4.5 decision trees classifier. 3 | 4 | Use train_adaboost.py to train and evaluate the performance of AdaBoost with C4.5 5 | decision trees over the specified dataset. After training is finished, 6 | the program will store the trained classifier on a file called: 7 | [Name of training]".BC45" 8 | 9 | Usage: python train_adaboost.py training_set testing_set rounds 10 | Where 11 | training_set = Path to the file of the training set 12 | testing_set = Path to the file of the testing set 13 | rounds = Rounds to use for AdaBoost 14 | 15 | 16 | ================================================ 17 | Notes 18 | ================================================ 19 | 20 | - All tools that work with AdaBoost require the AdaBoost library to be located 21 | on the same directory. This is a shared library that can be compiled from 22 | the provided source code using the following commands: 23 | 24 | Linux: 25 | gcc -shared -fPIC adaboost_c45.c -o adaboost_c45.so 26 | 27 | Windows (using MinGW): 28 | gcc -shared adaboost_c45.c -o adaboost_c45.so 29 | 30 | -------------------------------------------------------------------------------- /doc/README_train_c45.txt: -------------------------------------------------------------------------------- 1 | 2 | Tool for training and evaluation of C4.5 decision tree classifier. 3 | 4 | Use train_c45.py to train and evaluate the performance of a single C4.5 5 | decision trees over the specified dataset. 6 | 7 | Usage: python train_c45.py training_set testing_set 8 | Where 9 | training_set = Path to the file of the training set 10 | testing_set = Path to the file of the testing set 11 | 12 | 13 | ================================================ 14 | Notes 15 | ================================================ 16 | 17 | - All tools that work with C4.5 decision trees require the AdaBoost library 18 | to be located on the same directory. 
This is a shared library that can be 19 | compiled from the provided source code using the following commands: 20 | 21 | Linux: 22 | gcc -shared -fPIC adaboost_c45.c -o adaboost_c45.so 23 | 24 | Windows (using MinGW): 25 | gcc -shared adaboost_c45.c -o adaboost_c45.so 26 | 27 | -------------------------------------------------------------------------------- /server/PenStrokeServer.py: -------------------------------------------------------------------------------- 1 | # Author: Awelemdy Orakwue March 10, 2015 2 | #!/usr/bin/python 3 | 4 | import sys 5 | import os 6 | import csv 7 | import subprocess 8 | import SocketServer 9 | import SimpleHTTPServer 10 | import base64 11 | import urllib 12 | import httplib 13 | import socket 14 | from xml.dom.minidom import Document, parseString 15 | import cPickle 16 | from symbol_classifier import SymbolClassifier 17 | 18 | classifier_filename = "best_full2013_SVMRBF_new.dat" 19 | 20 | classifier = '' 21 | 22 | class RecognitionServer(SimpleHTTPServer.SimpleHTTPRequestHandler): 23 | instance_id = 0 24 | 25 | def do_GET(self): 26 | """ 27 | This will process a request which comes into the server. 28 | """ 29 | #print classifier 30 | unquoted_url = urllib.unquote(self.path) 31 | unquoted_url = unquoted_url.replace("/?segmentList=", "") 32 | unquoted_url = unquoted_url.replace("&segment=false", "") 33 | 34 | dom = parseString(unquoted_url) 35 | 36 | Segments = dom.getElementsByTagName("Segment") 37 | numSegments = Segments.length 38 | segmentIDS = [] 39 | classifierPoints = [] 40 | for j in range(numSegments): 41 | segmentIDS.append(Segments[j].getAttribute("instanceID")) 42 | points = Segments[j].getAttribute("points") 43 | url_parts = points.split('|') 44 | 45 | translation = Segments[j].getAttribute("translation").split(",") 46 | tempPoints = [] 47 | for i in range(len(url_parts)): 48 | pt = url_parts[i].split(","); 49 | pt2 = (int(pt[0]) + int(translation[0]), int(pt[1]) + int(translation[1])) 50 | tempPoints.append(pt2) 51 | classifierPoints.append(tempPoints) 52 | 53 | print classifierPoints 54 | results = classifier.classify_points_prob(classifierPoints, 30) 55 | 56 | doc = Document() 57 | root = doc.createElement("RecognitionResults") 58 | doc.appendChild(root) 59 | root.setAttribute("instanceIDs", ",".join(segmentIDS)) 60 | 61 | for k in range(len(results)): 62 | r = doc.createElement("Result") 63 | s = str(results[k][0]).replace("\\", "") 64 | sym = dict.get(s) 65 | if(sym == None): 66 | sym = s 67 | # special case due to CSV file 68 | if(sym.lower() == "comma"): 69 | sym = "," 70 | v = str(results[k][1]) 71 | c = format(float(v), '.35f') 72 | r.setAttribute("symbol", sym) 73 | r.setAttribute("certainty", c) 74 | root.appendChild(r) 75 | 76 | xml_str = doc.toxml() 77 | xml_str = xml_str.replace("amp;","") 78 | 79 | self.send_response(httplib.OK, 'OK') 80 | self.send_header("Content-length", len(xml_str)) 81 | self.send_header("Access-Control-Allow-Origin", "*") 82 | self.end_headers() 83 | self.wfile.write(xml_str) 84 | 85 | ####### UTILITY FUNCTIONS ####### 86 | 87 | 88 | if __name__ == "__main__": 89 | usage = "python PenStrokeServer " 90 | if(len(sys.argv) < 2): 91 | print usage 92 | sys.exit() 93 | 94 | HOST, PORT = "localhost", int(sys.argv[1]) 95 | print("Loading classifier") 96 | 97 | global classifer 98 | global dict 99 | in_file = open(classifier_filename, 'rb') 100 | classifier = cPickle.load(in_file) 101 | in_file.close() 102 | 103 | if not isinstance(classifier, SymbolClassifier): 104 | print("Invalid classifier file!") 105 | #return 
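# NOTE: this check runs at module level (inside the __main__ block), so a bare
# 'return' is not legal here, which is presumably why it is commented out; as
# written, the server keeps going even if the unpickled object is not a
# SymbolClassifier. Calling sys.exit(1) would be one way to abort instead.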
106 | print "Reading symbol code from generic_symbol_table.csv" 107 | dict = {} 108 | with open('generic_symbol_table.csv', 'rt') as csvfile2: 109 | reader2 = csv.reader(csvfile2, delimiter=',') 110 | for row in reader2: 111 | dict[row[3]] = row[0] 112 | #print dict 113 | print "Starting server." 114 | 115 | try: 116 | server = SocketServer.TCPServer(("", PORT), RecognitionServer) 117 | except socket.error, e: 118 | print e 119 | exit(1) 120 | #proc = subprocess.Popen(['python','kdtreeServer.py']) 121 | print "Serving" 122 | server.serve_forever() 123 | 124 | -------------------------------------------------------------------------------- /server/best_full2013_SVMRBF_new.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DPRL/MathSymbolRecognizer/21a68677770a953d176de57356685027e27c0414/server/best_full2013_SVMRBF_new.dat -------------------------------------------------------------------------------- /server/generic_symbol_table.csv: -------------------------------------------------------------------------------- 1 | Codepoint,Symbol,Classifier,LaTeX, 2 | 0,0,0,0, 3 | 1,1,1,1, 4 | 2,2,2,2, 5 | 3,3,3,3, 6 | 4,4,4,4, 7 | 5,5,5,5, 8 | 6,6,6,6, 9 | 7,7,7,7, 10 | 8,8,8,8, 11 | 9,9,9,9, 12 | +,+,pluslower,+, 13 | -,-,horzlinelower,-, 14 | =,=,eqlower,=, 15 | ≥,≥,geqlower,geq, 16 | <,<,ltlower,lt, 17 | >,>,gtlower,geq, 18 | ≠,≠,neqlower,neq, 19 | ≤,≤,less_than_or_equal,leq, 20 | ∫,∫,Integralupper,int, 21 | Π,Π,product,Pi, 22 | Σ,Σ,summation,sum, 23 | √,√,sqrtlower,sqrt, 24 | [,[,lbracketlower,[, 25 | ],],rbracketlower,], 26 | (,(,lparenlower,(, 27 | ),),rparenlower,), 28 | ∞,∞,infinlower,infty, 29 | A,A,Aupper,A, 30 | B,B,Bupper,B, 31 | C,C,Cupper,C, 32 | D,D,Dupper,D, 33 | E,E,Eupper,E, 34 | F,F,Fupper,F, 35 | G,G,Gupper,G, 36 | H,H,Hupper,H, 37 | I,I,Iupper,I, 38 | J,J,Jupper,J, 39 | K,K,Kupper,K, 40 | L,L,Lupper,L, 41 | M,M,Mupper,M, 42 | N,N,Nupper,N, 43 | O,O,Oupper,O, 44 | P,P,Pupper,P, 45 | Q,Q,Qupper,Q, 46 | R,R,Rupper,R, 47 | S,S,Supper,S, 48 | T,T,Tupper,T, 49 | U,U,Uupper,U, 50 | V,V,Vupper,V, 51 | W,W,Wupper,W, 52 | X,X,Xupper,X, 53 | Y,Y,Yupper,Y, 54 | Z,Z,Zupper,Z, 55 | a,a,alower,a, 56 | b,b,blower,b, 57 | c,c,clower,c, 58 | d,d,dlower,d, 59 | e,e,elower,e, 60 | f,f,flower,f, 61 | g,g,glower,g, 62 | h,h,hlower,h, 63 | i,i,ilower,i, 64 | j,j,jlower,j, 65 | k,k,klower,k, 66 | l,l,llower,l, 67 | m,m,mlower,m, 68 | n,n,nlower,n, 69 | o,o,olower,o, 70 | p,p,plower,p, 71 | q,q,qlower,q, 72 | r,r,rlower,r, 73 | s,s,slower,s, 74 | t,t,tlower,t, 75 | u,u,ulower,u, 76 | v,v,vlower,v, 77 | w,w,wlower,w, 78 | x,x,xlower,x, 79 | y,y,ylower,y, 80 | z,z,zlower,z, 81 | α,α,alphalower,alpha, 82 | β,β,betalower,beta, 83 | γ,γ,gammalower,gamma, 84 | δ,δ,deltalower,Delta, 85 | ε,ε,epsilonlower,epsilon, 86 | ζ,ζ,zetalower,zeta, 87 | η,η,etalower,eta, 88 | θ,θ,thetalower,theta, 89 | ι,ι,iotalower,iota, 90 | κ,κ,kappalower,kappa, 91 | λ,λ,lambdalower,lambda, 92 | μ,μ,mulower,mu, 93 | ν,ν,nulower,nu, 94 | ξ,ξ,xilower,xi, 95 | ο,ο,omicronlower,omicron, 96 | π,π,pilower,pi, 97 | ρ,ρ,rholower,rho 98 | σ,σ,sigmalower,sigma 99 | τ,τ,taulower,tau 100 | υ,υ,upsilonlower,upsilon 101 | φ,φ,philower,phi 102 | χ,χ,chilower,chi 103 | ψ,ψ,psilower,psi 104 | ω,ω,omegalower,omega 105 | Α,Αlpha,Alphaupper,Alpha 106 | Β,Beta,Betaupper,Beta 107 | Γ,Γ,Gammaupper,Gamma 108 | Δ,Δ,Deltaupper,Delta 109 | Ε,Εpsilon,Epsilonupper,Epsilon 110 | Ζ,Ζeta,Zetaupper,Zeta 111 | Η,Eta,Etaupper,Eta 112 | Θ,Θ,Thetaupper,Theta 113 | Ι,Ιota,Iotaupper,Iota 114 | 
Κ,Κappa,Kappaupper,Kappa 115 | Λ,Λ,Lambdaupper,Lambda 116 | Μ,Μu,Muupper,Mu 117 | Ν,Νu,Nuupper,Nu 118 | Ξ,Ξ,Xiupper,Xi 119 | Ο,Οmicron,Omicronupper,Omicron 120 | Π,Pi,Piupper,Pi 121 | Ρ,Rho,Rhoupper,Rho 122 | Σ,Sigma,Sigmaupper,Sigma 123 | Τ,Τau,Tauupper,tau 124 | Υ,Upsilon,Upsilonupper,Upsilon 125 | Φ,Φ,Phiupper,Phi 126 | Χ,Chi,Chiupper,Chi 127 | Ψ,Ψ,Psiupper,Psi 128 | Ω,Ω,Omegaupper,Omega 129 | ±,±,plusminus,pm 130 | ∃,∃,exists,exists 131 | ∀,∀,forall,forall 132 | ÷,÷,div,div 133 | ∈,∈,in,in 134 | ⨉,⨉,times,times 135 | …,…,ldots,ldots 136 | >,>,gt,gt 137 | -------------------------------------------------------------------------------- /server/start: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | nohup python PenStrokeServer.py 6501 & 4 | -------------------------------------------------------------------------------- /server/symbol_classifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | import numpy as np 25 | from traceInfo import TraceInfo 26 | from mathSymbol import MathSymbol 27 | 28 | class SymbolClassifier: 29 | TypeRandomForest = 1 30 | TypeSVMLIN = 2 31 | TypeSVMRBF = 3 32 | 33 | def __init__(self, type, trained_classifier, classes_list, classes_dict, scaler=None, probabilistic=False): 34 | self.type = type 35 | self.trained_classifier = trained_classifier 36 | self.classes_list = classes_list 37 | self.classes_dict = classes_dict 38 | self.scaler = scaler 39 | self.probabilistic = probabilistic 40 | 41 | def predict(self, dataset): 42 | return self.trained_classifier.predict(dataset) 43 | 44 | def predict_proba(self, dataset): 45 | return self.trained_classifier.predict_proba(dataset) 46 | 47 | def get_raw_classes(self): 48 | return self.trained_classifier.classes_ 49 | 50 | def get_symbol_from_points(self, points_lists): 51 | 52 | traces = [] 53 | for trace_id, point_list in enumerate(points_lists): 54 | object_trace = TraceInfo(trace_id, point_list) 55 | 56 | traces.append(object_trace) 57 | 58 | # apply general trace pre processing... 59 | # 1) first step of pre processing: Remove duplicated points 60 | object_trace.removeDuplicatedPoints() 61 | 62 | # Add points to the trace... 63 | object_trace.addMissingPoints() 64 | 65 | # Apply smoothing to the trace... 66 | object_trace.applySmoothing() 67 | 68 | # it should not ... but ..... 69 | if object_trace.hasDuplicatedPoints(): 70 | # ...remove them! .... 
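# (duplicated points may reappear after the preprocessing steps above,
# hence this second check before the symbol is built)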
71 | object_trace.removeDuplicatedPoints() 72 | 73 | new_symbol = MathSymbol(0, traces, '{Unknown}') 74 | 75 | # normalize size and locations 76 | new_symbol.normalize() 77 | 78 | return new_symbol 79 | 80 | def get_symbol_features(self, symbol): 81 | # get raw features 82 | features = symbol.getFeatures() 83 | 84 | # put them in python format 85 | mat_features = np.mat(features, dtype=np.float64) 86 | 87 | # automatically transform features 88 | if self.scaler is not None: 89 | mat_features = self.scaler.transform(mat_features) 90 | 91 | return mat_features 92 | 93 | def classify_points(self, points_lists): 94 | symbol = self.get_symbol_from_points(points_lists) 95 | 96 | return self.classify_symbol(symbol) 97 | 98 | def classify_points_prob(self, points_lists, top_n=None): 99 | symbol = self.get_symbol_from_points(points_lists) 100 | 101 | return self.classify_symbol_prob(symbol, top_n) 102 | 103 | def classify_symbol(self, symbol): 104 | features = self.get_symbol_features(symbol) 105 | 106 | predicted = self.trained_classifier.predict(features) 107 | 108 | return self.classes_list[predicted[0]] 109 | 110 | 111 | def classify_symbol_prob(self, symbol, top_n=None): 112 | features = self.get_symbol_features(symbol) 113 | 114 | #try: 115 | predicted = self.trained_classifier.predict_proba(features) 116 | #except: 117 | # raise Exception("Classifier was not trained as probabilistic classifier") 118 | 119 | scores = sorted([(predicted[0, k], k) for k in range(predicted.shape[1])], reverse=True) 120 | 121 | tempo_classes = self.trained_classifier.classes_ 122 | n_classes = len(tempo_classes) 123 | if top_n is None or top_n > n_classes: 124 | top_n = n_classes 125 | 126 | confidences = [(self.classes_list[tempo_classes[scores[k][1]]], scores[k][0]) for k in range(top_n)] 127 | 128 | return confidences 129 | 130 | 131 | -------------------------------------------------------------------------------- /server/test_classifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | 25 | import sys 26 | import cPickle 27 | from symbol_classifier import SymbolClassifier 28 | 29 | def main(): 30 | # usage check... 
31 | if len(sys.argv) < 2: 32 | print("Usage: python test_classifer.py classifier") 33 | print("Where") 34 | print("\tclassifier\t= Path to trained symbol classifier") 35 | return 36 | 37 | classifier_file = sys.argv[1] 38 | 39 | print("Loading classifier") 40 | 41 | in_file = open(classifier_file, 'rb') 42 | classifier = cPickle.load(in_file) 43 | in_file.close() 44 | 45 | if not isinstance(classifier, SymbolClassifier): 46 | print("Invalid classifier file!") 47 | return 48 | 49 | # get mapping 50 | classes_dict = classifier.classes_dict 51 | classes_l = classifier.classes_list 52 | n_classes = len(classes_l) 53 | 54 | # create test characters from points 55 | sample_x = [[(-0.8, -0.8), (0.1, 0.12), (0.8, 0.79)], [(-0.85, 0.79), (0.01, 0.005), (0.79, -0.83)]] 56 | sample_1 = [[(0.15, 0.7), (0.2, 1.0), (0.21, -1.2)], [(0.05, -1.19), (0.3, -1.25)]] 57 | sample_eq = [[(-1.5, -0.4), (1.5, -0.4)], [(-1.5, 0.4), (1.5, 0.4)]] 58 | 59 | # classify them 60 | class_x = classifier.classify_points(sample_x) 61 | class_1 = classifier.classify_points(sample_1) 62 | class_eq = classifier.classify_points(sample_eq) 63 | 64 | print("X classified as " + class_x) 65 | print("1 classified as " + class_1) 66 | print("= classified as " + class_eq) 67 | 68 | # now with confidence ... 69 | classes_x = classifier.classify_points_prob(sample_x, 3) 70 | classes_1 = classifier.classify_points_prob(sample_1, 3) 71 | classes_eq = classifier.classify_points_prob(sample_eq, 3) 72 | 73 | print("X top classes are: " + str(classes_x)) 74 | print("1 top classes are: " + str(classes_1)) 75 | print("= top classes are: " + str(classes_eq)) 76 | 77 | if __name__ == '__main__': 78 | main() -------------------------------------------------------------------------------- /src/ambiguous.txt: -------------------------------------------------------------------------------- 1 | x,X 2 | x,\times 3 | X,\times 4 | 1,| 5 | (,| 6 | ),| 7 | 1,( 8 | 1,) 9 | 1,/ 10 | 1,COMMA 11 | c,C 12 | ),COMMA 13 | p,P 14 | \prime,COMMA 15 | \prime,| 16 | v,V 17 | s,S 18 | q,9 19 | o,0 20 | \prime,/ 21 | /,COMMA -------------------------------------------------------------------------------- /src/apply_PCA_parameters.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 
19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | import sys 25 | import pickle 26 | from sklearn.preprocessing import StandardScaler 27 | from dataset_ops import * 28 | 29 | #===================================================================== 30 | # Uses the provided PCA paramaters and applies them to a given 31 | # dataset 32 | # 33 | # Created by: 34 | # - Kenny Davila (Feb 1, 2012-2014) 35 | # Modified By: 36 | # - Kenny Davila (Feb 1, 2012-2014) 37 | # 38 | #===================================================================== 39 | 40 | 41 | def load_PCA_parameters(file_name): 42 | file_params = open(file_name, 'r') 43 | 44 | normalization = pickle.load(file_params) 45 | pca_vector = pickle.load(file_params) 46 | pca_k = pickle.load(file_params) 47 | 48 | file_params.close() 49 | 50 | return (normalization, pca_vector, pca_k) 51 | 52 | 53 | def project_PCA(training_set, eig_vectors, k): 54 | #in case that K > # of atts, then just clamp.... 55 | n_samples = training_set.shape[0] 56 | n_atts = training_set.shape[1] 57 | 58 | k = min(k, n_atts) 59 | 60 | projected = np.zeros((n_samples, k)) 61 | 62 | #...for each sample... 63 | for n in range(n_samples): 64 | x = np.mat(training_set[n, :]) 65 | 66 | #...for each eigenvector 67 | for i in range(k): 68 | #use dot product to project... 69 | #p = np.dot(x, eig_vectors[i])[0, 0] 70 | p = np.dot(x, eig_vectors[i])[0, 0] 71 | 72 | projected[n, i] = p.real 73 | 74 | return projected 75 | 76 | 77 | def main(): 78 | #usage check 79 | if len(sys.argv) != 4: 80 | print("Usage: python apply_PCA_parameters.py training_set PCA_params output") 81 | print("Where") 82 | print("\ttraining_set\t= Path to the file of the training_set") 83 | print("\tPCA_params\t= Path to the file of the PCA parameters") 84 | print("\toutput\t= File to output the final dataset with reduced dimensionality") 85 | return 86 | 87 | input_filename = sys.argv[1] 88 | params_filename = sys.argv[2] 89 | output_filename = sys.argv[3] 90 | 91 | #...load training set ... 92 | print("Loading data....") 93 | training, labels_l, att_types = load_dataset(input_filename) 94 | 95 | #...the parameters.... 96 | print("Loading parameters...") 97 | #normalization, pca_vector, pca_k = load_PCA_parameters(params_filename) 98 | scaler, pca_vector, pca_k = load_PCA_parameters(params_filename) 99 | 100 | #...apply normalization... 101 | print("Normalizing data...") 102 | #new_data = normalize_data_from_params(training, normalization) 103 | new_data = scaler.transform(training) 104 | 105 | #...transform.... 106 | print("Applying transformation...") 107 | projected = project_PCA(new_data, pca_vector, pca_k) 108 | 109 | #...save final version... 110 | print("Saving to file...") 111 | final_atts = np.zeros((pca_k, 1), dtype=np.int32) 112 | final_atts[:, :] = 1 # Continuous attributes 113 | 114 | save_dataset_string_labels(projected, labels_l, final_atts, output_filename) 115 | 116 | print("Finished!") 117 | 118 | main() 119 | 120 | -------------------------------------------------------------------------------- /src/boosted_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 
6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | import sys 25 | import ctypes 26 | import numpy as np 27 | from dataset_ops import * 28 | from evaluation_ops import * 29 | from load_inkml import * 30 | 31 | #===================================================================== 32 | # this program takes as input boosted classifier and the dataset 33 | # used for training, and a testing set and computes the classifier accuracy 34 | # 35 | # Created by: 36 | # - Kenny Davila (Jan 19, 2013) 37 | # Modified By: 38 | # - Kenny Davila (Jan 19, 2013) 39 | # - Kenny Davila (Feb 26, 2012-2014) 40 | # - eliminated mapping compatibility 41 | # - Added evaluation metrics: top-n, per-class avg 42 | # 43 | #===================================================================== 44 | 45 | c45_lib = ctypes.CDLL('./adaboost_c45.so') 46 | 47 | 48 | def count_per_class(labels, n_samples, n_classes): 49 | result_counts = [0 for x in range(n_classes)] 50 | 51 | for i in range(n_samples): 52 | result_counts[labels[i]] += 1 53 | 54 | return result_counts 55 | 56 | 57 | def output_sample(file_path, sym_id, out_path): 58 | #load file... 59 | symbols = load_inkml( file_path, True ) 60 | 61 | #find symbol ... 62 | for symbol in symbols: 63 | if symbol.id == sym_id: 64 | symbol.saveAsSVG( out_path ) 65 | 66 | return True 67 | 68 | return False 69 | 70 | 71 | def predict_top_n_data(classifier, data, top_n, n_classes): 72 | n_samples = np.size(data, 0) 73 | p_classes = np.zeros((1, n_classes), dtype=np.float64) 74 | p_classes_p = p_classes.ctypes.data_as(ctypes.POINTER(ctypes.c_double)) 75 | 76 | results = np.zeros((n_samples, top_n)) 77 | 78 | for i in range(n_samples): 79 | c_sample = data[i, :] 80 | c_sample_p = c_sample.ctypes.data_as(ctypes.POINTER(ctypes.c_double)) 81 | 82 | c45_lib.boosted_c45_probabilistic_classify(classifier, c_sample_p, p_classes_p) 83 | 84 | tempo_values = [] 85 | for k in range(n_classes): 86 | tempo_values.append((p_classes[0, k], k)) 87 | 88 | #now, sort! 89 | tempo_values = sorted(tempo_values, reverse=True) 90 | 91 | #use the top-N 92 | for k in range(top_n): 93 | results[i, k] = tempo_values[k][1] 94 | 95 | return results 96 | 97 | 98 | def main(): 99 | #usage check 100 | if len(sys.argv) < 4: 101 | print("Usage: python boosted_test.py classifier training_set testing_set [save_fail]") 102 | print("Where") 103 | print("\tclassifier\t= Path to the .bc45 file that contains the classifier") 104 | print("\ttraining_set\t= Path to the file of the training set") 105 | print("\ttesting_set\t= Path to the file of the testing set") 106 | print("\tsave_fail\t= Optional, will output failure case if specified") 107 | return 108 | 109 | #...load training data from file... 110 | print "...Loading Training Data..." 
111 | 112 | file_name = sys.argv[2] 113 | training, labels_l, att_types = load_dataset(file_name) 114 | #...generate mapping... 115 | classes_dict, classes_l = get_label_mapping(labels_l) 116 | #...generate mapped labels... 117 | labels_train = get_mapped_labels(labels_l, classes_dict) 118 | 119 | if len(sys.argv) >= 5: 120 | save_fail = int(sys.argv[4]) > 0 121 | else: 122 | save_fail = False 123 | 124 | #...load testing data from file... 125 | print "...Loading Testing Data..." 126 | 127 | testing, test_labels_l, att_types = load_dataset(sys.argv[3]) 128 | #...generate mapped labels... 129 | labels_test = get_mapped_labels(test_labels_l, classes_dict) 130 | 131 | if save_fail: 132 | #need to load sources... 133 | test_sources = load_ds_sources(sys.argv[3] + ".sources.txt") 134 | 135 | if test_sources is None: 136 | print("Sources are unavailable") 137 | save_fail = False 138 | 139 | print "...Loading classifier..." 140 | ensemble = c45_lib.boosted_c45_load( sys.argv[1]) 141 | 142 | #...info.... 143 | n_classes = len(classes_l) 144 | n_train_samples = np.size(training, 0) 145 | n_test_samples = np.size(testing, 0) 146 | 147 | print("Classifier: " + sys.argv[1]) 148 | print("Training Set: " + sys.argv[2]) 149 | print("Testing Set: " + sys.argv[3]) 150 | 151 | 152 | print "...evaluating..." 153 | 154 | top_n = 5 155 | predicted = predict_top_n_data(ensemble, training, top_n, n_classes) 156 | print "Training Samples: " + str(n_train_samples) 157 | print "Training Results" 158 | print "Top\tAccuracy\tClass Average\tClass STD " 159 | 160 | #....on main thread, compute final statistics 161 | for i in range(top_n): 162 | total_correct, counts_per_class, errors_per_class = compute_topn_error_counts(predicted, labels_train, 163 | n_classes, i + 1) 164 | accuracy = (float(total_correct) / float(n_train_samples)) * 100 165 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 166 | 167 | print str(i+1) + "\t" + str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0) 168 | 169 | predicted = predict_top_n_data(ensemble, testing, top_n, n_classes) 170 | print "Testing Samples: " + str(n_test_samples) 171 | print "Testing Results" 172 | print "Top\tAccuracy\tClass Average\tClass STD " 173 | for i in range(top_n): 174 | total_correct, counts_per_class, errors_per_class = compute_topn_error_counts(predicted, labels_test, 175 | n_classes, i + 1) 176 | accuracy = (float(total_correct) / float(n_test_samples)) * 100 177 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 178 | 179 | print str(i+1) + "\t" + str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0) 180 | 181 | 182 | #repeat for Top=1 accuracy... 183 | all_failure_info = [] 184 | confusion_matrix = np.zeros((n_classes, n_classes), dtype = np.int32) 185 | total_correct = 0 186 | train_counts = count_per_class(labels_train, n_train_samples, n_classes) 187 | for k in range(n_test_samples): 188 | top_label = predicted[i, 0] 189 | 190 | if top_label != labels_test[k, 0]: 191 | #inccorrect... 
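# NOTE: 'top_label' above is read from predicted[i, 0], where 'i' is the index
# left over from the earlier top-n reporting loop; this Top-1 / confusion-matrix
# pass presumably intends predicted[k, 0] for the current test sample.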
192 | if save_fail: 193 | file_path, sym_id = test_sources[k] 194 | output_sample(file_path, sym_id, "output//error_" + str(k) + ".svg") 195 | 196 | all_failure_info.append((k, file_path, sym_id, classes_l[top_label], classes_l[labels_test[k, 0]])) 197 | else: 198 | total_correct += 1 199 | 200 | confusion_matrix[labels_test[k, 0], top_label] += 1 201 | 202 | if save_fail: 203 | #...save additional info of errors... 204 | content = 'id, path, sym_id, predicted, expected\n' 205 | for error_info in all_failure_info: 206 | 207 | for i in range(len(error_info)): 208 | if i > 0: 209 | content += "," 210 | 211 | content += str(error_info[i]) 212 | 213 | content += "\n" 214 | 215 | file_name = sys.argv[1] + ".failures.csv" 216 | try: 217 | f = open(file_name, 'w') 218 | f.write(content) 219 | f.close() 220 | except: 221 | print("ERROR WRITING RESULTS TO FILE! <" + file_name + ">") 222 | 223 | 224 | #... print results.... 225 | accuracy = (float(total_correct) / float(n_test_samples)) * 100.0 226 | print "Testing Accuracy = " + str(accuracy) 227 | 228 | #.... save confusion matrix to file... 229 | out_str = "Samples:," + str(n_test_samples) + "\n" 230 | out_str += "Correct:," + str(total_correct) + "\n" 231 | out_str += "Wrong:," + str(n_test_samples - total_correct) + "\n" 232 | out_str += "Accuracy:," + str(accuracy) + "\n\n\n" 233 | out_str += "Rows:,Expected \n" 234 | out_str += "Column:,Predicted \n\n" 235 | out_str += "Full Matrix\n" 236 | out_str += "X" 237 | only_err = "X" 238 | #...build header... 239 | for i in range(n_classes ): 240 | c_class = classes_l[i] 241 | if c_class == ",": 242 | c_class = "COMMA" 243 | 244 | out_str += "," + c_class 245 | only_err += "," + c_class 246 | 247 | out_str += ",Total,Train Count\n" 248 | only_err += ",Total\n" 249 | 250 | #...build the content... 251 | for i in range(n_classes): 252 | c_class = classes_l[i] 253 | if c_class == ",": 254 | c_class = "COMMA" 255 | 256 | out_str += c_class 257 | only_err += c_class 258 | 259 | t_errs = 0 260 | t_samples = 0 261 | for k in range(n_classes): 262 | out_str += "," + str(confusion_matrix[i, k]) 263 | 264 | t_samples += confusion_matrix[i, k] 265 | if i == k: 266 | only_err += ",0" 267 | else: 268 | t_errs += confusion_matrix[i, k] 269 | only_err += "," + str(confusion_matrix[i, k]) 270 | 271 | out_str += "," + str(t_samples) + "," + str(train_counts[i]) + "\n" 272 | only_err += "," + str(t_errs) + "\n" 273 | 274 | #...build the totals row... 275 | out_str += "Total" 276 | only_err += "Total" 277 | for i in range(n_classes): 278 | t_errs = 0 279 | t_samples = 0 280 | 281 | for k in range(n_classes): 282 | #...inverted, k = row, i = column 283 | t_samples += confusion_matrix[k, i] 284 | if i != k: 285 | t_errs += confusion_matrix[k, i] 286 | 287 | out_str += "," + str(t_samples) 288 | only_err += "," + str(t_errs) 289 | 290 | 291 | out_str += "\n\n\nOnly Errors\n\n" + only_err 292 | 293 | file_name = sys.argv[1] + ".results.csv" 294 | 295 | try: 296 | f = open(file_name, 'w') 297 | f.write( out_str ) 298 | f.close() 299 | except: 300 | print("ERROR WRITING RESULTS TO FILE! <" + file_name + ">") 301 | 302 | print "Finished!" 303 | 304 | main() 305 | -------------------------------------------------------------------------------- /src/correct_labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 
6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | import sys 25 | from dataset_ops import * 26 | 27 | 28 | #===================================================================== 29 | # Corrects certain common labeling errors usually found in CROHME 30 | # datasets 31 | # 32 | # Created by: 33 | # - Kenny Davila (Oct, 2012) 34 | # Modified By: 35 | # - Kenny Davila (Oct 25, 2013) 36 | # - Kenny Davila (Feb 5, 2012-2014) 37 | # 38 | #===================================================================== 39 | 40 | def get_classes_list(labels): 41 | all_classes = {} 42 | 43 | for label in labels: 44 | if not label in all_classes: 45 | all_classes[label] = True 46 | 47 | return all_classes.keys() 48 | 49 | 50 | def replace_labels(labels): 51 | new_labels = [] 52 | 53 | for label in labels: 54 | if label == "\\tg": 55 | new_label = "\\tan" 56 | elif label == '>': 57 | new_label = '\gt' 58 | elif label == '<': 59 | new_label = '\lt' 60 | elif label == '\'': 61 | new_label = '\prime' 62 | elif label == '\cdots': 63 | new_label = '\ldots' 64 | elif label == '\\vec': 65 | new_label = '\\rightarrow' 66 | elif label == '\cdot': 67 | new_label = '.' 68 | elif label == ',': 69 | new_label = 'COMMA' 70 | elif label == '\\frac': 71 | new_label = '-' 72 | else: 73 | #unchanged... 74 | new_label = label 75 | 76 | new_labels.append(new_label) 77 | 78 | return new_labels 79 | 80 | 81 | def main(): 82 | #usage check 83 | if len(sys.argv) != 3: 84 | print("Usage: python correct_labels.py training_set output_set ") 85 | print("Where") 86 | print("\ttraining_set\t= Path to the file of the training_set") 87 | print("\toutput_set\t= File to output file with corrected labels") 88 | return 89 | 90 | #...load training data from file... 91 | input_filename = sys.argv[1] 92 | output_filename = sys.argv[2] 93 | 94 | #file_name = 'ds_test_2012.txt' 95 | training, labels_l, att_types = load_dataset(input_filename) 96 | 97 | if training is None: 98 | print("Data not could not be loaded") 99 | else: 100 | #extract list of classes... 
101 | all_classes = get_classes_list(labels_l) 102 | 103 | print("Original number of classes = " + str(len(all_classes))) 104 | 105 | new_labels = replace_labels(labels_l) 106 | 107 | new_all_classes = get_classes_list(new_labels) 108 | print("New number of classes = " + str(len(new_all_classes))) 109 | 110 | save_dataset_string_labels(training, new_labels, att_types, output_filename) 111 | 112 | print("Success!") 113 | 114 | main() 115 | 116 | -------------------------------------------------------------------------------- /src/count_common.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | import sys 25 | import numpy as np 26 | from dataset_ops import * 27 | 28 | 29 | def get_counts_per_class(labels_l): 30 | count_per_class = {} 31 | 32 | for label in labels_l: 33 | if label in count_per_class: 34 | count_per_class[label] += 1 35 | else: 36 | count_per_class[label] = 1 37 | 38 | return count_per_class 39 | 40 | 41 | def count_in_common(counts_l1, counts_l2): 42 | common_1 = 0 43 | common_2 = 0 44 | common_labels = [] 45 | only1 = [] 46 | only2 = [] 47 | 48 | for label1 in counts_l1: 49 | if label1 in counts_l2: 50 | common_1 += counts_l1[label1] 51 | common_2 += counts_l2[label1] 52 | 53 | common_labels.append(label1) 54 | else: 55 | only1.append(label1) 56 | 57 | for label2 in counts_l2: 58 | if not label2 in counts_l1: 59 | only2.append(label2) 60 | 61 | 62 | return common_1, common_2, common_labels, only1, only2 63 | 64 | 65 | def extract_common_classes(dataset, labels, common_labels): 66 | #...first, find the references to samples from common labels... 67 | common_refs = [] 68 | for idx, label in enumerate(labels): 69 | if label in common_labels: 70 | common_refs.append(idx) 71 | 72 | #now... create a new dataset only with common refs... 
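# Sketch of the step below: rows whose label is in common_labels are copied, in their original
# order, into a fresh matrix. For instance, if labels == ['x', 'z', 'y'] and
# common_labels == ['x', 'y'], then common_refs == [0, 2] and only those two rows (with their
# labels) end up in the filtered dataset.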
73 | n_common_size = len(common_refs) 74 | n_atts = np.size(dataset, 1) 75 | common_set_data = np.zeros((n_common_size, n_atts), dtype=np.float64) 76 | common_set_labels = [] 77 | 78 | for idx, ref_idx in enumerate(common_refs): 79 | common_set_data[idx, :] = dataset[ref_idx, :] 80 | common_set_labels.append(labels[ref_idx]) 81 | 82 | return common_set_data, common_set_labels 83 | 84 | def main(): 85 | if len(sys.argv) != 4: 86 | print("Usage: python count_common.py dataset_1 dataset_2") 87 | print("Where") 88 | print("\tdataset_1\t= Path to the file of the first dataset") 89 | print("\tdataset_2\t= Path to the file of the second dataset") 90 | print("\textract\t= Extract common samples from second dataset") 91 | return 92 | 93 | data1_filename = sys.argv[1] 94 | data2_filename = sys.argv[2] 95 | 96 | try: 97 | do_extraction = int(sys.argv[3]) > 0 98 | except: 99 | print("Invalid value for extract parameter") 100 | return 101 | 102 | #...loading dataset... 103 | print("...Loading data set!") 104 | data1, labels_l1, att_types_1 = load_dataset(data1_filename) 105 | data2, labels_l2, att_types_2 = load_dataset(data2_filename) 106 | 107 | print("...Getting counts!") 108 | counts_l1 = get_counts_per_class(labels_l1) 109 | counts_l2 = get_counts_per_class(labels_l2) 110 | 111 | print("...Finding class overlap!") 112 | common1, common2, common_labels, only1, only2 = count_in_common(counts_l1, counts_l2) 113 | 114 | print("Classes on dataset 1: " + str(len(counts_l1.keys()))) 115 | print("Classes on dataset 2: " + str(len(counts_l2.keys()))) 116 | print("Total common classes: " + str(len(common_labels))) 117 | for k in common_labels: 118 | print(k) 119 | print("Samples of common classes on dataset 1: " + str(common1)) 120 | print("Samples of common classes on dataset 2: " + str(common2)) 121 | 122 | print("Total classes only on 1: " + str(len(only1))) 123 | for k in only1: 124 | print(k) 125 | print("Total classes only on 2: " + str(len(only2))) 126 | for k in only2: 127 | print(k) 128 | 129 | if do_extraction: 130 | print("Extracting samples with common labels from dataset 2") 131 | filtered_data, filtered_labels = extract_common_classes(data2, labels_l2, common_labels) 132 | 133 | print("Saving filtered samples ") 134 | save_dataset_string_labels(filtered_data,filtered_labels,att_types_2, data2_filename + ".common.txt") 135 | 136 | print("Finished!") 137 | 138 | main() 139 | -------------------------------------------------------------------------------- /src/dataset_info.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 
19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | 25 | import sys 26 | import math 27 | import numpy as np 28 | from dataset_ops import * 29 | #===================================================================== 30 | # Simple script that loads an existing dataset from a file and 31 | # outputs the most general information. 32 | # 33 | # Created by: 34 | # - Kenny Davila (Feb 7, 2012-2014) 35 | # Modified By: 36 | # - Kenny Davila (Feb 7, 2012-2014) 37 | # - Kenny Davila (Feb 7, 2012-2014) 38 | # 39 | #===================================================================== 40 | 41 | 42 | def main(): 43 | #usage check 44 | if len(sys.argv) < 2: 45 | print("Usage: python dataset_info.py datset [n_bins]") 46 | print("Where") 47 | print("\tdataset\t= Path to file that contains the data set") 48 | print("\tn_bins\t= Optional, number of bins for histogram of class representation") 49 | return 50 | 51 | input_filename = sys.argv[1] 52 | 53 | if len(sys.argv) >= 3: 54 | try: 55 | n_bins = int(sys.argv[2]) 56 | if n_bins < 1: 57 | print("Invalid n_bins value") 58 | return 59 | except: 60 | print("Invalid n_bins value") 61 | return 62 | else: 63 | n_bins = 10 64 | 65 | print("Loading data....") 66 | training, labels_l, att_types = load_dataset(input_filename); 67 | print("Data loaded!") 68 | 69 | print("Getting information...") 70 | n_samples = np.size(training, 0) 71 | n_atts = np.size(training, 1) 72 | 73 | #...counts per class... 74 | count_per_class = {} 75 | for idx in range(n_samples): 76 | #s_label = labels_train[idx, 0] 77 | s_label = labels_l[idx] 78 | 79 | if s_label in count_per_class: 80 | count_per_class[s_label] += 1 81 | else: 82 | count_per_class[s_label] = 1 83 | 84 | #...distribution... 85 | #...first pass, compute minimum and maximum... 86 | smallest_class_size = n_samples 87 | smallest_class_label = "" 88 | largest_class_size = 0 89 | largest_class_label = "" 90 | for label in count_per_class: 91 | #...check minimum 92 | if count_per_class[label] < smallest_class_size: 93 | smallest_class_size = count_per_class[label] 94 | smallest_class_label = label 95 | 96 | #...check maximum 97 | if count_per_class[label] > largest_class_size: 98 | largest_class_size = count_per_class[label] 99 | largest_class_label = label 100 | 101 | print("Class\t" + label + "\tCount\t" + str(count_per_class[label])) 102 | 103 | #...second pass... create histogram... 104 | count_bins = [0 for x in range(n_bins)] 105 | samples_bins = [0 for x in range(n_bins)] 106 | size_per_bin = int(math.ceil(float(largest_class_size + 1) / float(n_bins))) 107 | for label in count_per_class: 108 | current_bin = int(count_per_class[label] / size_per_bin) 109 | count_bins[current_bin] += 1 110 | samples_bins[current_bin] += count_per_class[label] 111 | 112 | #...print... bins... 113 | print("Class sizes distribution") 114 | for i in range(n_bins): 115 | start = i * size_per_bin + 1 116 | end = (i + 1) * size_per_bin 117 | percentage = (float(samples_bins[i]) / float(n_samples)) * 100.0 118 | print("... 
From " + str(start) + "\t to " + str(end) + "\t : " + 119 | str(count_bins[i]) + "\t (" + str(percentage) + " of data)") 120 | 121 | n_classes = len(count_per_class.keys()) 122 | 123 | print("Total Samples: " + str(n_samples)) 124 | print("Total Attributes: " + str(n_atts)) 125 | print("Total Classes: " + str(n_classes)) 126 | print("-> Largest Class: " + largest_class_label + "\t: " + str(largest_class_size) + " samples") 127 | print("-> Smallest Class: " + smallest_class_label + "\t: " + str(smallest_class_size) + " samples") 128 | print("Finished...") 129 | 130 | main() 131 | -------------------------------------------------------------------------------- /src/dataset_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | 25 | import numpy as np 26 | 27 | #===================================================================== 28 | # Most general functions used to load and save datasets from 29 | # text files separated by semi-colon 30 | # 31 | # Created by: 32 | # - Kenny Davila (Dic 1, 2013) 33 | # Modified By: 34 | # - Kenny Davila (Jan 26, 2012-2014) 35 | # - Added function to save dataset 36 | # - Kenny Davila (Feb 1, 2012-2014) 37 | # - Added functions for data normalization 38 | # - Kenny Davila (Feb 10, 2012-2014) 39 | # - empty lines are now skipped! 40 | # - Kenny Davila (Feb 18, 2012-2014) 41 | # - load_dataset is now more memory efficient 42 | # 43 | #===================================================================== 44 | 45 | #========================================= 46 | # Loads a dataset from a file 47 | # returns 48 | # Samples, Labels, Att Types 49 | # (NP Matrix, List, NP Matrix) 50 | #========================================= 51 | def load_dataset(file_name): 52 | try: 53 | 54 | data_file = open(file_name, 'r') 55 | lines = data_file.readlines() 56 | data_file.close() 57 | 58 | #now process every line... 59 | #first line must be the attribute type for each feature... 60 | att_types_s1 = lines[0].split(';') 61 | att_types_s2 = [ s.strip().upper() for s in att_types_s1 ] 62 | n_atts = len( att_types_s2 ) 63 | 64 | #..proces the attributes... 
65 | att_types = np.zeros( (n_atts, 1), dtype=np.int32 ) 66 | for i in xrange(n_atts): 67 | if att_types_s2[i] == 'D': 68 | att_types[i] = 2 69 | else: 70 | att_types[i] = 1 71 | 72 | estimated_samples = len(lines) - 1 73 | tempo_samples = np.zeros( (estimated_samples, n_atts), dtype = np.float64 ) 74 | 75 | count_samples = 0 76 | labels_l = [] 77 | for i in xrange(1, len(lines)): 78 | values_s = lines[i].split(';') 79 | if len(values_s) == 1: 80 | if values_s[0].strip() == "": 81 | #just skip the empty line... 82 | continue 83 | 84 | #assume last value is class label 85 | label = values_s[-1].strip() 86 | del values_s[-1] 87 | 88 | #read values... 89 | values = [] 90 | for idx, att_type in enumerate(att_types): 91 | values.append( float( values_s[idx].strip() ) ) 92 | 93 | 94 | #validate number of attributes on the sample 95 | if len(values) != n_atts: 96 | print("Number of values is different to number of attributes") 97 | print("Atts: " + str(n_atts)) 98 | print("Values: " + str(len(values))) 99 | return None, None, None 100 | 101 | #add sample 102 | for att in xrange(n_atts): 103 | tempo_samples[count_samples, att] = values[att] 104 | 105 | count_samples += 1 106 | 107 | labels_l.append( label ) 108 | 109 | n_samples = count_samples 110 | if n_samples != estimated_samples: 111 | samples = tempo_samples[:n_samples, :].copy() 112 | tempo_samples = None 113 | else: 114 | samples = tempo_samples 115 | 116 | 117 | return samples, labels_l, att_types 118 | 119 | except Exception as e: 120 | print("Error loading dataset from file") 121 | print( e ) 122 | return None, None, None 123 | 124 | 125 | 126 | #============================================== 127 | # Generates a mapping for the unique set of 128 | # classes present on the given list of labels 129 | #============================================== 130 | def get_label_mapping(labels_l): 131 | classes_dict = {} 132 | classes_l = [] 133 | 134 | #... for each sample... 135 | n_samples = len(labels_l) 136 | for i in xrange( n_samples ): 137 | label = labels_l[ i ] 138 | 139 | #check mapping of labels... 140 | if not label in classes_dict: 141 | #...add label to mapping... 142 | label_val = len( classes_l ) 143 | classes_l.append( label ) 144 | 145 | classes_dict[ label ] = label_val 146 | 147 | return classes_dict, classes_l 148 | 149 | def get_mapped_labels(labels_l, classes_dict): 150 | n_samples = len( labels_l ) 151 | 152 | #...for the mapped labels... 153 | labels = np.zeros( (n_samples, 1), dtype = np.int32 ) 154 | 155 | #... for each sample ... 156 | for i in xrange( n_samples ): 157 | #get current label... 158 | label = labels_l[ i ] 159 | #get mapped label... 160 | label_val = classes_dict[ label ] 161 | #add mapped label 162 | labels[ i, 0 ] = label_val 163 | 164 | return labels 165 | 166 | 167 | def load_ds_sources(file_name): 168 | try: 169 | #read all lines... 170 | file_source = open(file_name, 'r') 171 | lines = file_source.readlines() 172 | file_source.close() 173 | 174 | all_sources = [] 175 | 176 | #get filename, symbol id per each symbol in DS 177 | for i in range(len(lines)): 178 | values_s = lines[i].split(',') 179 | 180 | #should contain 2 values... 
181 | if len(values_s) != 2: 182 | print( "Invalid line <" + str(i) + "> in Auxiliary file: " + lines[i] ) 183 | return None 184 | else: 185 | file_path = values_s[0].strip() 186 | sym_id = int( values_s[1] ) 187 | 188 | all_sources.append( (file_path, sym_id) ) 189 | 190 | return all_sources 191 | 192 | except Exception as e: 193 | print(e) 194 | return None 195 | 196 | def save_dataset(data, labels, att_types, out_file): 197 | n_samples = np.size(data, 0) 198 | 199 | try: 200 | out_file = open(out_file, 'w') 201 | except: 202 | print( "File <" + out_file + "> could not be created") 203 | return 204 | 205 | #...writing first header.... 206 | content = '' 207 | #print as headers the types for each feature... 208 | n_atts = np.size(att_types, 0) 209 | for i in range(n_atts): 210 | if i > 0: 211 | content += '; ' 212 | 213 | if att_types[i, 0] == 2: 214 | content += 'D' 215 | else: 216 | content += 'C' 217 | 218 | content += '\r\n' 219 | out_file.write(content) 220 | 221 | #...now, write the samples.... 222 | content = '' 223 | for idx in range(n_samples): 224 | #...add the values... 225 | line = '' 226 | for k in range(n_atts): 227 | if k > 0: 228 | line += '; ' 229 | 230 | line += str(data[idx, k]) 231 | 232 | #... add the label... 233 | line += "; " + str(labels[idx, 0]) + "\r\n" 234 | 235 | content += line 236 | #....check if buffer is full! 237 | if len(content) >= 50000: 238 | out_file.write(content) 239 | content = '' 240 | 241 | #....write any remaining content 242 | out_file.write(content) 243 | 244 | out_file.close() 245 | 246 | def save_label_mapping(base_classes, extra_mapping, file_name): 247 | #...create.... 248 | try: 249 | out_file = open(file_name, 'w') 250 | except: 251 | print( "File <" + file_name + "> could not be created") 252 | return 253 | 254 | #... First, write the sizes of both mappings 255 | content = str(len(base_classes)) + ";" + str(len(extra_mapping.keys())) + "\r\n" 256 | out_file.write(content) 257 | 258 | #...write the original class list... 259 | content = '' 260 | for c_class in base_classes: 261 | content += c_class + "\r\n" 262 | out_file.write(content) 263 | 264 | #...write now the extra mapping... 265 | content = '' 266 | for key in extra_mapping: 267 | content += str(key) + ";" + str(extra_mapping[key]) + "\r\n" 268 | 269 | out_file.write(content) 270 | 271 | #...close... 272 | out_file.close() 273 | 274 | 275 | def load_label_mapping(file_name): 276 | #...open.... 277 | try: 278 | data_file = open(file_name, 'r') 279 | lines = data_file.readlines() 280 | data_file.close() 281 | 282 | #first line should contain the size of the mappings.. 283 | sizes_s = lines[0].split(';') 284 | n_classes = int(sizes_s[0]) 285 | n_mapped = int(sizes_s[1]) 286 | 287 | print "...Loading class mapping..." 288 | print "N-real-classes: " + str(n_classes) 289 | print "N-virtual-classes: " + str(n_mapped) 290 | 291 | class_l = [] 292 | class_dict = {} 293 | for i in range(n_classes): 294 | label = lines[1 + i].strip() 295 | 296 | class_l.append(label) 297 | class_dict[label] = i 298 | 299 | class_mapping = {} 300 | for i in range(n_mapped): 301 | #get the key -> value pair... 302 | mapped = lines[1 + n_classes + i].split(';') 303 | 304 | #...split... 
305 | new_label = int(mapped[0]) 306 | original_label = int(mapped[1]) 307 | 308 | class_mapping[new_label] = original_label 309 | 310 | return class_l, class_dict, class_mapping 311 | except Exception as e: 312 | print(e) 313 | return None, None 314 | 315 | 316 | def append_symbols(symbols, out_file): 317 | n_samples = len(symbols) 318 | 319 | print("...adding samples " + str(n_samples) + " to output file...") 320 | 321 | n_atts = 0 322 | content = '' 323 | for idx, symbol in enumerate(symbols): 324 | sample = symbol.getFeatures() + [ symbol.truth ] 325 | if idx == 0: 326 | n_atts = len(sample) - 1 327 | 328 | line = '' 329 | for i, v in enumerate(sample): 330 | if i > 0: 331 | line += '; ' 332 | 333 | line += str(v) 334 | 335 | line += '\r\n' 336 | content += line 337 | 338 | #check if buffer is full! 339 | if len(content) >= 50000: 340 | out_file.write(content) 341 | content = '' 342 | 343 | #write any remaining content 344 | out_file.write(content) 345 | 346 | print("... samples added to file successfully!") 347 | 348 | return n_atts 349 | 350 | 351 | def append_dataset(data, labels, out_file): 352 | #...will append samples in dataset to output file... 353 | n_samples = np.size(data, 0) 354 | n_atts = np.size(data, 1) 355 | 356 | print("...adding samples " + str(n_samples) + " to output file...") 357 | 358 | #...now, write the samples.... 359 | content = '' 360 | for idx in range(n_samples): 361 | #...add the values... 362 | line = '' 363 | for k in range(n_atts): 364 | if k > 0: 365 | line += '; ' 366 | 367 | line += str(data[idx, k]) 368 | 369 | #... add the label... 370 | line += "; " + str(labels[idx, 0]) + "\r\n" 371 | 372 | content += line 373 | #....check if buffer is full! 374 | if len(content) >= 50000: 375 | out_file.write(content) 376 | content = '' 377 | 378 | #....write any remaining content 379 | out_file.write(content) 380 | 381 | print("... samples added to file successfully!") 382 | 383 | 384 | def append_dataset_string_labels(data, labels, out_file): 385 | #...will append samples in dataset to output file... 386 | n_samples = np.size(data, 0) 387 | n_atts = np.size(data, 1) 388 | 389 | print("...adding samples " + str(n_samples) + " to output file...") 390 | 391 | #...now, write the samples.... 392 | content = '' 393 | for idx in range(n_samples): 394 | #...add the values... 395 | line = '' 396 | for k in range(n_atts): 397 | if k > 0: 398 | line += '; ' 399 | 400 | line += str(data[idx, k]) 401 | 402 | #... add the label... 403 | line += "; " + labels[idx] + "\r\n" 404 | 405 | content += line 406 | #....check if buffer is full! 407 | if len(content) >= 50000: 408 | out_file.write(content) 409 | content = '' 410 | 411 | #....write any remaining content 412 | out_file.write(content) 413 | 414 | print("... samples added to file successfully!") 415 | 416 | 417 | def save_dataset_string_labels(data, labels, att_types, out_filename): 418 | n_samples = np.size(data, 0) 419 | 420 | try: 421 | out_file = open(out_filename, 'w') 422 | except: 423 | print( "File <" + out_filename + "> could not be created") 424 | return 425 | 426 | #...writing first header.... 427 | content = '' 428 | #print as headers the types for each feature... 429 | n_atts = np.size(att_types, 0) 430 | for i in range(n_atts): 431 | if i > 0: 432 | content += '; ' 433 | 434 | if att_types[i, 0] == 2: 435 | content += 'D' 436 | else: 437 | content += 'C' 438 | 439 | content += '\r\n' 440 | out_file.write(content) 441 | 442 | #...now, write the samples.... 
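# Rows are written in the same "; "-separated layout that load_dataset() parses: attribute
# values first, string label last. A three-attribute sample of class 'x' would be emitted as
# the line (illustrative values):  0.25; -1.3; 4.0; x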
443 | content = '' 444 | for idx in range(n_samples): 445 | #...add the values... 446 | line = '' 447 | for k in range(n_atts): 448 | if k > 0: 449 | line += '; ' 450 | 451 | line += str(data[idx, k]) 452 | 453 | #... add the label... 454 | line += "; " + str(labels[idx]) + "\r\n" 455 | 456 | content += line 457 | #....check if buffer is full! 458 | if len(content) >= 50000: 459 | out_file.write(content) 460 | content = '' 461 | 462 | #....write any remaining content 463 | out_file.write(content) 464 | 465 | out_file.close() 466 | 467 | 468 | #==================================================== 469 | # Data Normalization: Centering and scaling 470 | #==================================================== 471 | 472 | #---------------------------------------------------- 473 | # Function that computes normalization parameters 474 | # and applies them to given data returning a 475 | # normalized version of it 476 | #---------------------------------------------------- 477 | def normalize_data(data): 478 | params = [] 479 | n_samples = np.size(data,0) 480 | n_atts = np.size(data, 1) 481 | new_data = np.zeros((n_samples, n_atts)) 482 | 483 | #for each attribute... 484 | for att in range(n_atts): 485 | #the mean and std_dev for each att 486 | cut = data[:, att] 487 | cut_mean = cut.mean() 488 | cut_std = cut.std() 489 | 490 | #add to parameters list... 491 | params.append((cut_mean, cut_std)) 492 | 493 | #now normalize... 494 | if cut_std > 0.0: 495 | new_data[:, att] = (data[:, att] - cut_mean) / cut_std 496 | else: 497 | #only center... 498 | new_data[:, att] = (data[:, att] - cut_mean) 499 | 500 | return new_data, params 501 | 502 | #------------------------------------------------------ 503 | # Function that receives a dataset and normalization 504 | # parameters, and applies them returning a normalized 505 | # version of the data 506 | #------------------------------------------------------ 507 | def normalize_data_from_params(data, params): 508 | n_samples = np.size(data,0) 509 | n_atts = np.size(data, 1) 510 | new_data = np.zeros((n_samples, n_atts)) 511 | 512 | #for each attribute... 513 | for att in range(n_atts): 514 | cut_mean, cut_std = params[att] 515 | 516 | #now normalize... 517 | if cut_std > 0.0: 518 | new_data[:, att] = (data[:, att] - cut_mean) / cut_std 519 | else: 520 | #only center... 521 | new_data[:, att] = (data[:, att] - cut_mean) 522 | 523 | return new_data 524 | 525 | -------------------------------------------------------------------------------- /src/distorter.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 
19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | import numpy as np 25 | import ctypes 26 | import scipy.ndimage as ndimage 27 | from mathSymbol import * 28 | 29 | #===================================================================== 30 | # Generates distorted versions of samples 31 | # 32 | # Created by: 33 | # - Kenny Davila (Jan 14, 2012-2014) 34 | # Modified By: 35 | # - Kenny Davila (Jan 14, 2012-2014) 36 | # - Kenny Davila (Jan 20, 2012-2014) 37 | # 38 | #===================================================================== 39 | 40 | distorter_lib = ctypes.CDLL('./distorter_lib.so') 41 | distorter_lib.distorter_init() 42 | 43 | 44 | class Distorter: 45 | def __init__(self): 46 | self.noise_size = 64 # 128 47 | self.noise_map_size = 8 48 | self.noise_map_depth = 6 49 | 50 | def getBoundingBox(self, points): 51 | min_x = points[0][0] 52 | max_x = points[0][0] 53 | min_y = points[0][1] 54 | max_y = points[0][1] 55 | 56 | for i in range(1, len(points)): 57 | x, y = points[i] 58 | 59 | if x < min_x: 60 | min_x = x 61 | if x > max_x: 62 | max_x = x 63 | if y < min_y: 64 | min_y = y 65 | if y > max_y: 66 | max_y = y 67 | 68 | return (min_x, max_x, min_y, max_y) 69 | 70 | def distortPoints(self, points, max_diagonal): 71 | new_points = [] 72 | 73 | #check original box... 74 | min_x, max_x, min_y, max_y = self.getBoundingBox(points) 75 | w = max_x - min_x 76 | h = max_y - min_y 77 | 78 | if w == 0.0: 79 | w = 0.000001 80 | 81 | if h == 0.0: 82 | h = 0.000001 83 | 84 | line_dist = math.sqrt(w ** 2 + h ** 2) * max_diagonal 85 | 86 | #Get noise maps... 87 | #...x.... 88 | #noise_x = self.getNoiseMap(self.noise_size, self.noise_map_size, self.noise_map_depth) 89 | noise_x = np.zeros((self.noise_size, self.noise_size)) 90 | noise_x_p = noise_x.ctypes.data_as(ctypes.POINTER(ctypes.c_double)) 91 | distorter_lib.distorter_create_noise_map(noise_x_p, self.noise_size, self.noise_map_size, self.noise_map_depth) 92 | #...smooth x.... 93 | noise_x = ndimage.gaussian_filter(noise_x, sigma=(2, 2), order=0) 94 | 95 | #...y.... 96 | #noise_y = self.getNoiseMap(self.noise_size, self.noise_map_size, self.noise_map_depth) 97 | noise_y = np.zeros((self.noise_size, self.noise_size)) 98 | noise_y_p = noise_y.ctypes.data_as(ctypes.POINTER(ctypes.c_double)) 99 | distorter_lib.distorter_create_noise_map(noise_y_p, self.noise_size, self.noise_map_size, self.noise_map_depth) 100 | #...smooth y.... 101 | noise_y = ndimage.gaussian_filter(noise_y, sigma=(2, 2), order=0) 102 | 103 | #now, apply distortion.... 104 | #...for each point.... 105 | for i in range(len(points)): 106 | x, y = points[i] 107 | 108 | #compute relative position... 109 | p_x = (x - min_x) / w 110 | p_y = (y - min_y) / h 111 | 112 | #compute noise position... 113 | p_nx = int(p_x * self.noise_size) 114 | p_ny = int(p_y * self.noise_size) 115 | if p_nx >= self.noise_size: 116 | p_nx = self.noise_size - 1 117 | if p_ny >= self.noise_size: 118 | p_ny = self.noise_size - 1 119 | 120 | #get distortion... 121 | dist_x = noise_x[p_nx, p_ny] 122 | dist_y = noise_y[p_nx, p_ny] 123 | 124 | 125 | #print "Pre: (" + str(dist_x) + ", " + str(dist_y) + ")" 126 | """ 127 | #normalize... 
128 | norm = math.sqrt(dist_x * dist_x + dist_y * dist_y) 129 | if norm > 0.0: 130 | dist_x /= norm 131 | dist_y /= norm 132 | else: 133 | dist_x = 0.0 134 | dist_y = 0.0 135 | """ 136 | #print "Post: (" + str(dist_x) + ", " + str(dist_y) + ")" 137 | 138 | 139 | off_x = dist_x * line_dist 140 | off_y = dist_y * line_dist 141 | 142 | new_x = x + off_x 143 | new_y = y + off_y 144 | 145 | new_points.append((new_x, new_y)) 146 | 147 | return new_points 148 | 149 | #create a distorted version of the given symbol... 150 | def distortSymbol(self, symbol, max_diagonal): 151 | 152 | #create distorted traces... 153 | all_traces = [] 154 | for t in symbol.traces: 155 | new_points = self.distortPoints(t.original_points, max_diagonal) 156 | 157 | new_trace = TraceInfo(t.id, new_points) 158 | 159 | #try smoothing the distorted version... 160 | new_trace.removeDuplicatedPoints() 161 | 162 | #Add points to the trace... 163 | new_trace.addMissingPoints() 164 | 165 | #Apply smoothing to the trace... 166 | new_trace.applySmoothing() 167 | 168 | #it should not ... but ..... 169 | if new_trace.hasDuplicatedPoints(): 170 | #...remove them! .... 171 | new_trace.removeDuplicatedPoints() 172 | 173 | all_traces.append(new_trace) 174 | 175 | #now, create the new symbol... 176 | new_symbol = MathSymbol(symbol.id, all_traces, symbol.truth) 177 | new_symbol.normalize() 178 | 179 | return new_symbol 180 | -------------------------------------------------------------------------------- /src/distorter_lib.c: -------------------------------------------------------------------------------- 1 | /* 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | */ 24 | 25 | //Compile using: 26 | // gcc -shared distorter.c -o distorter.so 27 | 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | 34 | double distorter_init(){ 35 | //Use current time as a seed 36 | srand(time(0)); 37 | } 38 | 39 | double distorter_rand(){ 40 | return ((double)rand() / (double)RAND_MAX); 41 | } 42 | 43 | double distorter_mirror(double value){ 44 | 45 | double i_value = floor(value); 46 | 47 | if ( ((int)i_value) % 2 == 0 ){ 48 | //Not mirror 49 | value -= i_value; 50 | } else { 51 | //Mirror 52 | value = 1.0 - (value - i_value); 53 | } 54 | 55 | return value; 56 | } 57 | 58 | double distorter_bilinear_filtering(double* map, double x, double y, int n_rows, int n_cols){ 59 | if ( x < 0.0){ 60 | x = 0.0; 61 | } 62 | if ( y < 0.0){ 63 | y = 0.0; 64 | } 65 | //...row (y) and its weights... 
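/* Standard bilinear interpolation from here on: y in [0,1] is mapped to a fractional row index,
   r0/r1 are the two bracketing rows and r_w1 is the weight given to r1; the columns are handled
   the same way for x. E.g. with n_rows = 64 and y = 0.5, v_row = 31.5, so r0 = 31, r1 = 32 and
   r_w1 = 0.5 (both rows contribute equally). */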
66 | double v_row = y * (n_rows - 1); 67 | int r0 = (int)floor(v_row); 68 | double r_w1; 69 | if ( r0 >= n_rows - 1 ){ 70 | r_w1 = 1.0; 71 | r0 = n_rows - 2; 72 | } else { 73 | r_w1 = v_row - floor(v_row); 74 | } 75 | int r1 = r0 + 1; 76 | 77 | //...col (x) and its weights 78 | double v_col = x * (n_cols - 1); 79 | int c0 = (int)floor(v_col); 80 | double c_w1; 81 | if (c0 >= n_cols - 1){ 82 | c_w1 = 1.0; 83 | c0 = n_cols - 2; 84 | }else{ 85 | c_w1 = v_col - c0; 86 | } 87 | int c1 = c0 + 1; 88 | 89 | //Bilinear filtering... 90 | double final_val = map[r0 * n_cols + c0] * (1.0 - r_w1) * (1.0 - c_w1) + 91 | map[r1 * n_cols + c0] * r_w1 * (1.0 - c_w1) + 92 | map[r0 * n_cols + c1] * (1.0 - r_w1) * c_w1 + 93 | map[r1 * n_cols + c1] * r_w1 * c_w1; 94 | 95 | return final_val; 96 | } 97 | 98 | void distorter_create_noise_map(double* map_buffer, int map_n, int noise_n, int max_d){ 99 | int i, row, col; 100 | 101 | //Create random noise... 102 | int t_noise = noise_n * noise_n; 103 | double* c_map = (double *)calloc(t_noise, sizeof(double)); 104 | 105 | //...Set a random value per element... 106 | double* p_map = c_map; 107 | for (i = 0; i < t_noise; i++ ){ 108 | *p_map = distorter_rand() * 2.0 - 1.0; 109 | p_map++; 110 | } 111 | 112 | //..also random displacements... 113 | int t_disp = 2 * max_d; 114 | double* disp = (double *)calloc(2 * max_d, sizeof(double)); 115 | double* p_disp = disp; 116 | for ( i = 0; i < t_disp; i++){ 117 | *p_disp = distorter_rand(); 118 | p_disp++; 119 | } 120 | 121 | //Combine the maps acording to perlin noise algorithm... 122 | for (row = 0; row < map_n; row++ ){ 123 | for (col = 0; col < map_n; col++ ){ 124 | map_buffer[row * map_n + col] = 0.0; 125 | } 126 | } 127 | 128 | //... for each size of the map... 129 | double w, p_r, p_c, c_pow; 130 | for ( i = 0; i < max_d; i++){ 131 | double dr = disp[ i * 2 ]; 132 | double dc = disp[ i * 2 + 1 ]; 133 | 134 | if ( i < max_d - 1 ){ 135 | w = 1.0 / pow(2, i + 1); 136 | } else { 137 | w = 1.0 / pow(2, i); 138 | } 139 | 140 | c_pow = pow(2, i); 141 | 142 | //...for each pixel.... 143 | for ( row = 0; row < map_n; row++){ 144 | p_r = ((double)row / (double)map_n) * c_pow; 145 | //...mirror mapping... 146 | p_r = distorter_mirror(p_r + dr); 147 | 148 | for ( col = 0; col < map_n; col++){ 149 | p_c = ((double)col /(double)map_n) * c_pow; 150 | //...mirror mapping... 151 | p_c = distorter_mirror(p_c + dc); 152 | 153 | map_buffer[row * map_n + col] += w * distorter_bilinear_filtering(c_map, p_c, p_r, noise_n, noise_n); 154 | } 155 | } 156 | } 157 | 158 | //...Release allocated memory.... 159 | free( disp ); 160 | free( c_map ); 161 | 162 | } 163 | 164 | /* 165 | def getBoundingBox(self, points): 166 | min_x = points[0][0] 167 | max_x = points[0][0] 168 | min_y = points[0][1] 169 | max_y = points[0][1] 170 | 171 | for i in range(1, len(points)): 172 | x, y = points[i] 173 | 174 | if x < min_x: 175 | min_x = x 176 | if x > max_x: 177 | max_x = x 178 | if y < min_y: 179 | min_y = y 180 | if y > max_y: 181 | max_y = y 182 | 183 | return (min_x, max_x, min_y, max_y) 184 | 185 | def distortPoints(self, points, max_diagonal): 186 | new_points = [] 187 | 188 | #check original box... 189 | min_x, max_x, min_y, max_y = self.getBoundingBox(points) 190 | w = max_x - min_x 191 | h = max_y - min_y 192 | 193 | if w == 0.0: 194 | w = 0.000001 195 | 196 | if h == 0.0: 197 | h = 0.000001 198 | 199 | line_dist = math.sqrt(w ** 2 + h ** 2) * max_diagonal 200 | 201 | #Get noise maps... 202 | #...x.... 
203 | noise_x = self.getNoiseMap(self.noise_size, self.noise_map_size, self.noise_map_depth) 204 | #...smooth x.... 205 | noise_x = ndimage.gaussian_filter(noise_x, sigma=(2, 2), order=0) 206 | #...y.... 207 | noise_y = self.getNoiseMap(self.noise_size, self.noise_map_size, self.noise_map_depth) 208 | #...smooth y.... 209 | noise_y = ndimage.gaussian_filter(noise_y, sigma=(2, 2), order=0) 210 | 211 | #now, apply distortion.... 212 | #...for each point.... 213 | for i in range(len(points)): 214 | x, y = points[i] 215 | 216 | #compute relative position... 217 | p_x = (x - min_x) / w 218 | p_y = (y - min_y) / h 219 | 220 | #compute noise position... 221 | p_nx = int(p_x * self.noise_size) 222 | p_ny = int(p_y * self.noise_size) 223 | if p_nx >= self.noise_size: 224 | p_nx = self.noise_size - 1 225 | if p_ny >= self.noise_size: 226 | p_ny = self.noise_size - 1 227 | 228 | #get distortion... 229 | dist_x = noise_x[p_nx, p_ny] 230 | dist_y = noise_y[p_nx, p_ny] 231 | 232 | #normalize... 233 | norm = math.sqrt(dist_x * dist_x + dist_y * dist_y) 234 | if norm > 0.0: 235 | dist_x /= norm 236 | dist_y /= norm 237 | else: 238 | dist_x = 0.0 239 | dist_y = 0.0 240 | 241 | off_x = dist_x * line_dist 242 | off_y = dist_y * line_dist 243 | 244 | new_x = x + off_x 245 | new_y = y + off_y 246 | 247 | new_points.append((new_x, new_y)) 248 | 249 | return new_points 250 | 251 | #create a distorted version of the given symbol... 252 | def distortSymbol(self, symbol, max_diagonal): 253 | 254 | #create distorted traces... 255 | all_traces = [] 256 | for t in symbol.traces: 257 | new_points = self.distortPoints(t.original_points, max_diagonal) 258 | 259 | new_trace = TraceInfo(t.id, new_points) 260 | 261 | #try smoothing the distorted version... 262 | new_trace.removeDuplicatedPoints() 263 | 264 | #Add points to the trace... 265 | new_trace.addMissingPoints() 266 | 267 | #Apply smoothing to the trace... 268 | new_trace.applySmoothing() 269 | 270 | #it should not ... but ..... 271 | if new_trace.hasDuplicatedPoints(): 272 | #...remove them! .... 273 | new_trace.removeDuplicatedPoints() 274 | 275 | all_traces.append(new_trace) 276 | 277 | #now, create the new symbol... 278 | new_symbol = MathSymbol(symbol.id, all_traces, symbol.truth) 279 | new_symbol.normalize() 280 | 281 | return new_symbol 282 | 283 | */ 284 | -------------------------------------------------------------------------------- /src/evaluation_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 
19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | import numpy as np 25 | 26 | #===================================================================== 27 | # General functions used for evaluation 28 | # 29 | # Created by: 30 | # - Kenny Davila (Feb 6, 2012-2014) 31 | # Modified By: 32 | # - Kenny Davila (Feb 6, 2012-2014) 33 | # - Kenny Davila (Feb 23, 2012-2014) 34 | # - changes to handle ambiguous classes 35 | # - output confusion matrices 36 | # - Kenny Davila (Feb 25, 2012-2014) 37 | # - Add Top-N evaluation functions 38 | # 39 | #===================================================================== 40 | 41 | 42 | def get_average_class_accuracy(counts_per_class, errors_per_class, n_classes): 43 | #now, only consider classes that existed in testing set... 44 | valid_classes = counts_per_class > 0 45 | 46 | all_accuracies = 1.0 - errors_per_class[valid_classes] / counts_per_class[valid_classes] 47 | 48 | avg_accuracy = all_accuracies.mean() 49 | std_accuracy = all_accuracies.std() 50 | 51 | return avg_accuracy, std_accuracy 52 | 53 | def compute_error_counts(predicted, labels, n_classes): 54 | counts_per_class = np.zeros(n_classes) 55 | errors_per_class = np.zeros(n_classes) 56 | total_correct = 0 57 | n_samples = np.size(labels, 0) 58 | for k in range(n_samples): 59 | expected = labels[k, 0] 60 | 61 | counts_per_class[expected] += 1 62 | 63 | if predicted[k] == expected: 64 | total_correct += 1 65 | else: 66 | errors_per_class[expected] += 1 67 | 68 | return total_correct, counts_per_class, errors_per_class 69 | 70 | 71 | def compute_topn_error_counts(predicted, labels, n_classes, top_n): 72 | counts_per_class = np.zeros(n_classes) 73 | errors_per_class = np.zeros(n_classes) 74 | total_correct = 0 75 | n_samples = np.size(labels, 0) 76 | for k in range(n_samples): 77 | expected = labels[k, 0] 78 | 79 | counts_per_class[expected] += 1 80 | 81 | found = False 82 | for n in range(top_n): 83 | if predicted[k, n] == expected: 84 | found = True 85 | break 86 | 87 | if found: 88 | total_correct += 1 89 | else: 90 | errors_per_class[expected] += 1 91 | 92 | return total_correct, counts_per_class, errors_per_class 93 | 94 | 95 | def compute_confusion_matrix(predicted, labels, n_classes): 96 | confusion_matrix = np.zeros((n_classes, n_classes), dtype=np.int32) 97 | 98 | n_test_samples = np.size(labels, 0) 99 | 100 | #... for each sample 101 | for k in range(n_test_samples): 102 | expected = int(labels[k, 0]) 103 | 104 | confusion_matrix[expected, int(predicted[k])] += 1 105 | 106 | return confusion_matrix 107 | 108 | 109 | def compute_ambiguous_confusion_matrix(all_predicted, labels, n_classes, ambiguous): 110 | confusion_matrix = np.zeros((n_classes, n_classes), dtype=np.int32) 111 | 112 | n_test_samples = np.size(labels, 0) 113 | 114 | #... 
for each sample 115 | for k in range(n_test_samples): 116 | expected = int(labels[k, 0]) 117 | predicted = int(all_predicted[k]) 118 | 119 | if ambiguous[expected, predicted] == 1: 120 | #no error will be reported between ambiguous classes 121 | confusion_matrix[expected, expected] += 1 122 | else: 123 | #usual error condition 124 | confusion_matrix[expected, predicted] += 1 125 | 126 | return confusion_matrix 127 | 128 | def save_evaluation_results(classes_l, confusion_matrix, file_name, ambiguous): 129 | 130 | n_classes = np.size(confusion_matrix, 0) 131 | 132 | per_class_correct = np.zeros(n_classes) 133 | per_class_counts = confusion_matrix.sum(1) 134 | 135 | class_avg_list = [] 136 | for i in range(n_classes): 137 | per_class_correct[i] = confusion_matrix[i, i] 138 | 139 | if per_class_counts[i] > 0: 140 | class_acc = per_class_correct[i] / per_class_counts[i] 141 | 142 | class_avg_list.append((class_acc, i, classes_l[i])) 143 | 144 | #...sort the list... 145 | class_avg_list = sorted(class_avg_list,reverse=True) 146 | 147 | #now, only consider classes that existed in testing set... 148 | valid_classes = per_class_counts > 0 149 | per_class_acc = per_class_correct[valid_classes] / per_class_counts[valid_classes] 150 | 151 | n_test_samples = confusion_matrix.sum() 152 | total_correct = per_class_correct.sum() 153 | 154 | accuracy = (float(total_correct) / float(n_test_samples)) * 100.0 155 | per_class_avg = per_class_acc.mean() * 100.0 156 | per_class_std = per_class_acc.std() * 100.0 157 | 158 | #.... save confusion matrix to file... 159 | out_str = "Samples:," + str(n_test_samples) + "\n" 160 | out_str += "N Classes:," + str(n_classes) + "\n" 161 | out_str += "Correct:," + str(total_correct) + "\n" 162 | out_str += "Wrong:," + str(n_test_samples - total_correct) + "\n" 163 | out_str += "Accuracy:," + str(accuracy) + "\n" 164 | out_str += "Per class AVG:, " + str(per_class_avg) + "\n" 165 | out_str += "Per class STD:, " + str(per_class_std) + "\n" 166 | out_str += "\n\n" 167 | out_str += "Rows:,Expected \n" 168 | out_str += "Column:,Predicted \n\n" 169 | out_str += "Full Matrix\n" 170 | out_str += "X" 171 | only_err = "X" 172 | #...build header... 173 | for i in range(n_classes ): 174 | c_class = classes_l[i] 175 | if c_class == ",": 176 | c_class = "COMMA" 177 | 178 | out_str += "," + c_class 179 | only_err += "," + c_class 180 | 181 | out_str += ",Total,Train Count\n" 182 | only_err += ",Total\n" 183 | 184 | #...build the content... 185 | for i in range(n_classes): 186 | c_class = classes_l[i] 187 | if c_class == ",": 188 | c_class = "COMMA" 189 | 190 | out_str += c_class 191 | only_err += c_class 192 | 193 | t_errs = 0 194 | t_samples = 0 195 | for k in range(n_classes): 196 | out_str += "," + str(confusion_matrix[i, k]) 197 | 198 | t_samples += confusion_matrix[i, k] 199 | if i == k: 200 | only_err += ",0" 201 | else: 202 | t_errs += confusion_matrix[i, k] 203 | only_err += "," + str(confusion_matrix[i, k]) 204 | 205 | out_str += "," + str(t_samples) + "\n" 206 | only_err += "," + str(t_errs) + "\n" 207 | 208 | #...build the totals row... 
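# The totals row sums each column of the confusion matrix (k indexes rows, i columns below),
# i.e. how many samples were predicted as each class; t_errs skips the diagonal, so it counts
# samples of other classes mislabeled as class i. For a two-class matrix [[5, 1], [2, 7]] the
# totals row would read 7, 8 and the error-only totals 2, 1.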
209 | out_str += "Total" 210 | only_err += "Total" 211 | for i in range(n_classes): 212 | t_errs = 0 213 | t_samples = 0 214 | 215 | for k in range(n_classes): 216 | #...inverted, k = row, i = column 217 | t_samples += confusion_matrix[k, i] 218 | if i != k: 219 | t_errs += confusion_matrix[k, i] 220 | 221 | out_str += "," + str(t_samples) 222 | only_err += "," + str(t_errs) 223 | 224 | out_str += "\n\n\nOnly Errors\n\n" + only_err + "\n\n" 225 | 226 | #now, add the class average per class.... 227 | #...also, if list of ambiguous is available... mark them! 228 | if not ambiguous is None: 229 | ambiguous_list = ambiguous.sum(0) 230 | 231 | out_str += "idx,Class,Accuracy,Ambiguous\n" 232 | else: 233 | out_str += "idx,Class,Accuracy\n" 234 | 235 | for class_acc, idx, class_name in class_avg_list: 236 | if class_name == ",": 237 | class_name = "COMMA" 238 | 239 | out_str += str(idx) + "," + class_name + "," + str(class_acc) 240 | if not ambiguous is None: 241 | if ambiguous_list[idx] > 0: 242 | out_str += ",1" 243 | else: 244 | out_str += ",0" 245 | out_str += "\n" 246 | 247 | try: 248 | f = open(file_name, 'w') 249 | f.write(out_str) 250 | f.close() 251 | except: 252 | print("ERROR WRITING RESULTS TO FILE! <" + file_name + ">") 253 | 254 | 255 | def load_ambiguous(filename, class_dict, skip_failures): 256 | #first, load all the content from the file... 257 | file_ambiguous = open(filename, 'r') 258 | all_ambiguous_lines = file_ambiguous.readlines() 259 | file_ambiguous.close() 260 | 261 | n_classes = len(class_dict.keys()) 262 | 263 | total_ambiguous = 0 264 | ambiguous = np.zeros((n_classes, n_classes), dtype=np.int32) 265 | for idx, line in enumerate(all_ambiguous_lines): 266 | #split and check... 267 | parts = line.split(',') 268 | if len(parts) != 2: 269 | print("Invalid content found in ambiguous file!") 270 | print("Line " + str(idx + 1) + ": " + line) 271 | return None 272 | 273 | grapheme_1 = parts[0].strip() 274 | grapheme_2 = parts[1].strip() 275 | 276 | #check if valid ... 277 | if not grapheme_1 in class_dict.keys(): 278 | print("Invalid class found: " + str(grapheme_1) + ".") 279 | print("Line " + str(idx + 1) + ": " + line) 280 | 281 | if skip_failures: 282 | continue 283 | else: 284 | return 285 | 286 | if not grapheme_2 in class_dict.keys(): 287 | print("Invalid class found: " + str(grapheme_2) + ".") 288 | print("Line " + str(idx + 1) + ": " + line) 289 | 290 | if skip_failures: 291 | continue 292 | else: 293 | return 294 | 295 | #...they are valid, now build the dictionary but using their mapped versions 296 | class1 = class_dict[grapheme_1] 297 | class2 = class_dict[grapheme_2] 298 | 299 | #...mark them as ambiguous (symmetric relation) 300 | if ambiguous[class1, class2] == 0: 301 | ambiguous[class1, class2] = 1 302 | ambiguous[class2, class1] = 1 303 | 304 | total_ambiguous += 1 305 | 306 | 307 | return ambiguous, total_ambiguous 308 | -------------------------------------------------------------------------------- /src/extract_symbol.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 
11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | import sys 25 | from load_inkml import * 26 | 27 | #===================================================================== 28 | # Extract a symbol from a given inkml file using a sym id, and then 29 | # outputs the symbol to a SVG file 30 | # 31 | # Created by: 32 | # - Kenny Davila (Feb 11, 2012-2014) 33 | # Modified By: 34 | # - Kenny Davila (Feb 11, 2012-2014) 35 | # 36 | #===================================================================== 37 | 38 | def output_sample( file_path, sym_id, out_path ): 39 | #load file... 40 | symbols = load_inkml(file_path, True) 41 | 42 | #find symbol ... 43 | for symbol in symbols: 44 | if symbol.id == sym_id: 45 | print("Symbol found, class: " + symbol.truth) 46 | symbol.saveAsSVG( out_path ) 47 | 48 | return True 49 | 50 | return False 51 | 52 | def main(): 53 | #usage check 54 | if len(sys.argv) != 4: 55 | print("Usage: python extract_symbol.py inkml_file sym_id output") 56 | print("Where") 57 | print("\tinkml_file\t= Path to the inkml file that contains the symbol") 58 | print("\tsym_id\t\t= Id of the symbol to extract") 59 | print("\toutput\t\t= File where the extracted symbol will be stored") 60 | return 61 | 62 | file_path = sys.argv[1] 63 | sym_id = int(sys.argv[2]) 64 | output_path = sys.argv[3] 65 | 66 | if output_sample(file_path, sym_id, output_path): 67 | print("Sample extracted successfully!") 68 | else: 69 | print("Sample was not found in the given file!") 70 | 71 | main() 72 | -------------------------------------------------------------------------------- /src/get_PCA_parameters.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | 25 | import sys 26 | from sklearn.preprocessing import StandardScaler 27 | import numpy as np 28 | import pickle 29 | from dataset_ops import * 30 | 31 | #===================================================================== 32 | # Find the PCA parameters for a given training set. 
The learn 33 | # parameters can be later applied for dimensionality reduction 34 | # of the dataset 35 | # 36 | # Created by: 37 | # - Kenny Davila (Feb 1, 2012-2014) 38 | # Modified By: 39 | # - Kenny Davila (Feb 1, 2012-2014) 40 | # 41 | #===================================================================== 42 | 43 | 44 | def get_PCA( training_set ): 45 | #get the covariance matrix... 46 | cov_matrix = np.cov(training_set.transpose()) 47 | 48 | #...obtain eigenvectors and eigenvalues... 49 | eig_values, eig_matrix = np.linalg.eig(cov_matrix) 50 | 51 | #...sort them.... 52 | pair_list = [(eig_values[i], i) for i in range(cov_matrix.shape[0])] 53 | pair_list = sorted( pair_list, reverse=True ) 54 | 55 | sorted_values = [] 56 | sorted_vectors = [] 57 | 58 | for eigenvalue, idx in pair_list: 59 | sorted_values.append(eigenvalue) 60 | sorted_vectors.append(np.mat(eig_matrix[:, idx]).T) 61 | 62 | return sorted_values, sorted_vectors 63 | 64 | 65 | def get_variance_K(values, variance): 66 | #...get the variance percentages... 67 | n_atts = len(values) 68 | total_variance = 0.0 69 | cumulative_variance = [] 70 | for i in xrange(n_atts): 71 | eigenvalue = abs(values[i]) 72 | 73 | total_variance += eigenvalue 74 | cumulative_variance.append(total_variance) 75 | 76 | #...normalize... 77 | for i in xrange(n_atts): 78 | cumulative_variance[i] /= total_variance 79 | 80 | if cumulative_variance[i] >= variance: 81 | return i + 1, cumulative_variance 82 | 83 | return n_atts, cumulative_variance 84 | 85 | 86 | def save_PCA_parameters(file_name, normalization, pca_vector, pca_k): 87 | file_params = open(file_name, 'w') 88 | pickle.dump(normalization, file_params) 89 | pickle.dump(pca_vector, file_params) 90 | pickle.dump(pca_k, file_params) 91 | file_params.close() 92 | 93 | 94 | def main(): 95 | #usage check 96 | if len(sys.argv) != 5: 97 | print("Usage: python get_PCA_parameters.py training_set output_params var_max k_max") 98 | print("Where") 99 | print("\ttraining_set\t= Path to the file of the training_set") 100 | print("\toutput_params\t= File to output PCA preprocessing parameters") 101 | print("\tvar_max\t\t= Maximum percentage of variance to add") 102 | print("\tk_max\t\t= Maximum number of Principal Components to Add") 103 | return 104 | 105 | #get parameters 106 | input_filename = sys.argv[1] 107 | output_filename = sys.argv[2] 108 | 109 | try: 110 | var_max = float(sys.argv[3]) 111 | if var_max <= 0.0 or var_max > 1.0: 112 | print("Invalid var_max value! ") 113 | return 114 | except: 115 | print("Invalid var_max value! ") 116 | return 117 | 118 | try: 119 | k_max = int(sys.argv[4]) 120 | if k_max <= 0: 121 | print("Invalid k_max value!") 122 | return 123 | except: 124 | print("Invalid k_max value!") 125 | return 126 | 127 | 128 | #...load training set ... 129 | print("Loading data....") 130 | training, labels_l, att_types = load_dataset(input_filename) 131 | 132 | print("Data loaded! ... normalizing ...") 133 | #new_data, norm_params = normalize_data(training) 134 | scaler = StandardScaler() 135 | new_data = scaler.fit_transform(training) 136 | 137 | print( "Normalized! ... Applying PCA ..." ) 138 | values, vectors = get_PCA(new_data) 139 | 140 | variance_k, all_variances = get_variance_K(values, var_max) 141 | 142 | final_k = min(variance_k, k_max) 143 | 144 | print("Final K = " + str(final_k) + ", variance = " + str(all_variances[final_k - 1]) ) 145 | 146 | print("Saving Parameters ... 
") 147 | save_PCA_parameters(output_filename, scaler, vectors, final_k) 148 | 149 | print("Finished!") 150 | 151 | 152 | main() 153 | -------------------------------------------------------------------------------- /src/get_enhanced_clustered_set.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | import os 25 | import sys 26 | import fnmatch 27 | from sklearn.preprocessing import StandardScaler 28 | from sklearn.cluster import Ward 29 | from traceInfo import * 30 | from mathSymbol import * 31 | from load_inkml import * 32 | from distorter import * 33 | from dataset_ops import * 34 | 35 | #===================================================================== 36 | # generates an enhanced training set from a directory containing 37 | # inkml files by adding extra samples generated with a distortion 38 | # model. it also applies clustering for to automatically identify 39 | # under represented classes. 40 | # 41 | # Created by: 42 | # - Kenny Davila (Jan 27, 2012-2014) 43 | # Modified By: 44 | # - Kenny Davila (Jan 27, 2012-2014) 45 | # 46 | #===================================================================== 47 | 48 | def print_overwrite(text): 49 | sys.stdout.write("\r") 50 | sys.stdout.write(" " * 79) 51 | sys.stdout.write("\r") 52 | sys.stdout.write(text) 53 | 54 | 55 | def replace_labels(labels): 56 | new_labels = [] 57 | 58 | for label in labels: 59 | if label == "\\tg": 60 | new_label = "\\tan" 61 | elif label == '>': 62 | new_label = '\gt' 63 | elif label == '<': 64 | new_label = '\lt' 65 | elif label == '\'': 66 | new_label = '\prime' 67 | elif label == '\cdots': 68 | new_label = '\ldots' 69 | elif label == '\\vec': 70 | new_label = '\\rightarrow' 71 | elif label == '\cdot': 72 | new_label = '.' 73 | elif label == ',': 74 | new_label = 'COMMA' 75 | elif label == '\\frac': 76 | new_label = '-' 77 | else: 78 | #unchanged... 79 | new_label = label 80 | 81 | new_labels.append(new_label) 82 | 83 | return new_labels 84 | 85 | 86 | def get_symbol_features(symbols, n_atts, verbose): 87 | if verbose: 88 | print("...Getting features for samples....") 89 | 90 | #...get features.... 91 | total_extra = len(symbols) 92 | extra_features = np.zeros((total_extra, n_atts)) 93 | 94 | for idx, symbol in enumerate(symbols): 95 | features = symbol.getFeatures() 96 | 97 | #copy features from list to numpy array 98 | for k in range(n_atts): 99 | extra_features[idx, k] = features[k] 100 | 101 | #...some output.... 
102 | if verbose and (idx == total_extra - 1 or idx % 10 == 0): 103 | print_overwrite("...Processed " + str(idx + 1) + " of " + str(total_extra)) 104 | 105 | return extra_features 106 | 107 | 108 | def main(): 109 | #usage check 110 | if len(sys.argv) < 7: 111 | print("Usage: python get_enhanced_clustered_set.py inkml_path output min_prc diag_dist max_clusters " + 112 | "clust_prc [verbose] [count_only]") 113 | print("Where") 114 | print("\tinkml_path\t= Path to directory that contains the inkml files") 115 | print("\toutput\t\t= File name of the output file") 116 | print("\tmin_prc\t\t= Minimum representation based on (%) of largest class") 117 | print("\tdiag_dist\t= Distortion factor relative to length of main diagonal") 118 | print("\tmax_clusters\t= Maximum number of clusters per large class") 119 | print("\tclust_prc\t= Minimum cluster size based on (%) of largest ") 120 | print("\tverbose\t= Optional, print detailed messages ") 121 | print("\tcount_only\t= Will only count what will be the final size of dataset") 122 | return 123 | 124 | #load and filter the list of files, the result is a list of inkml files only 125 | try: 126 | complete_list = os.listdir(sys.argv[1]) 127 | filtered_list = [] 128 | for file in complete_list: 129 | if fnmatch.fnmatch(file, '*.inkml'): 130 | filtered_list.append( file ) 131 | except: 132 | print( "The inkml path <" + sys.argv[1] + "> is invalid!" ) 133 | return 134 | 135 | output_filename = sys.argv[2] 136 | 137 | try: 138 | min_prc = float(sys.argv[3]) 139 | if min_prc < 0.0: 140 | print("Invalid minimum percentage") 141 | return 142 | except: 143 | print("Invalid minimum percentage") 144 | return 145 | 146 | try: 147 | diag = float(sys.argv[4]) 148 | if diag < 0.0: 149 | print("Invalid distortion factor") 150 | return 151 | except: 152 | print("Invalid distortion factor") 153 | return 154 | 155 | try: 156 | components = int(sys.argv[5]) 157 | if components < 0: 158 | print("Invalid max_clusters") 159 | except: 160 | print("Invalid max_clusters") 161 | return 162 | 163 | try: 164 | clust_prc = float(sys.argv[6]) 165 | if clust_prc < 0.0: 166 | print("Invalid minimum cluster percentage") 167 | return 168 | except: 169 | print("Invalid minimum cluster percentage") 170 | return 171 | 172 | if len(sys.argv) > 7: 173 | try: 174 | verbose = int(sys.argv[7]) > 0 175 | except: 176 | print("Invalid value for verbose") 177 | return 178 | else: 179 | #by default, print all messages 180 | verbose = True 181 | 182 | if len(sys.argv) > 8: 183 | try: 184 | count_only = int(sys.argv[8]) > 0 185 | except: 186 | print("Invalid value for count_only") 187 | return 188 | else: 189 | #by default... 190 | count_only = False 191 | 192 | #....read every inkml file in the path specified... 193 | #....create the initial symbol objects... 194 | all_symbols = [] 195 | sources = [] 196 | print("Loading samples from files.... ") 197 | for i in range(len(filtered_list)): 198 | file_name = filtered_list[i] 199 | file_path = sys.argv[1] + '//' + file_name 200 | advance = float(i) / len(filtered_list) 201 | if verbose: 202 | print_overwrite(("Processing => {:.2%} => " + file_path).format(advance)) 203 | 204 | #print() 205 | 206 | symbols = load_inkml(file_path, True) 207 | 208 | for new_symbol in symbols: 209 | all_symbols.append(new_symbol) 210 | sources.append((file_path, new_symbol.id)) 211 | 212 | print("") 213 | print("....samples loaded!") 214 | 215 | #...get features of original training set... 
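    # The loop below builds the base feature matrix one row per symbol: the first
    # symbol determines the number of attributes and their types ('D' is treated
    # as discrete and coded 2, anything else as continuous and coded 1), every row
    # of the (n_samples x n_atts) matrix is filled from symbol.getFeatures(), and
    # the ground-truth label strings are accumulated in labels_l.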
216 | print("Getting features of base training set...") 217 | training = None 218 | labels_l = None 219 | att_types = None 220 | n_original_samples = len(all_symbols) 221 | for idx, symbol in enumerate(all_symbols): 222 | features = symbol.getFeatures() 223 | 224 | if idx == 0: 225 | #...create training set .... 226 | 227 | #...first, the feature types... 228 | n_atts = len(features) 229 | feature_types = symbol.getFeaturesTypes() 230 | att_types = np.zeros((n_atts, 1), dtype=np.int32) 231 | for i in xrange(n_atts): 232 | if feature_types[i] == 'D': 233 | att_types[i] = 2 234 | else: 235 | att_types[i] = 1 236 | 237 | #...for labels... 238 | labels_l = [] 239 | 240 | #...finally, the matrix for training set itself... 241 | training = np.zeros((n_original_samples, n_atts)) 242 | 243 | #copy features from list to numpy array 244 | for k in range(n_atts): 245 | training[idx, k] = features[k] 246 | 247 | #copy label 248 | labels_l.append(symbol.truth) 249 | 250 | #...some output.... 251 | if verbose and (idx == n_original_samples - 1 or idx % 10 == 0): 252 | print_overwrite("...Processed " + str(idx + 1) + " of " + str(n_original_samples)) 253 | 254 | #...correct labels.... 255 | labels_l = replace_labels(labels_l) 256 | 257 | 258 | #...scale original data... 259 | print("Scaling original data...") 260 | scaler = StandardScaler() 261 | scaled_training = scaler.fit_transform(training) 262 | 263 | #...save original samples to file.... 264 | if not count_only: 265 | print("Saving original data...") 266 | save_dataset_string_labels(training, labels_l, att_types, output_filename) 267 | #append_dataset_string_labels 268 | 269 | #... identify under-represented classes ... 270 | #... first, count samples per class ... 271 | print("Getting counts...") 272 | count_per_class = {} 273 | refs_per_class = {} 274 | for idx in range(n_original_samples): 275 | #s_label = labels_train[idx, 0] 276 | s_label = labels_l[idx] 277 | 278 | if s_label in count_per_class: 279 | count_per_class[s_label] += 1 280 | refs_per_class[s_label].append(idx) 281 | else: 282 | count_per_class[s_label] = 1 283 | refs_per_class[s_label] = [idx] 284 | 285 | n_classes = len(count_per_class.keys()) 286 | 287 | #...distribute data in classes... 288 | largest_size = 0 289 | for label in count_per_class: 290 | #...check largest... 291 | if count_per_class[label] > largest_size: 292 | largest_size = count_per_class[label] 293 | 294 | print("Samples on largest class: " + str(largest_size)) 295 | 296 | #Now, do the data enhancement.... 297 | distorter = Distorter() 298 | 299 | #....for each class.... 300 | n_extra_samples = 0 301 | min_elements = int(math.ceil(min_prc * largest_size)) 302 | print("Analyzing data per class...") 303 | for label in count_per_class: 304 | current_refs = refs_per_class[label] 305 | n_class_samples = count_per_class[label] 306 | 307 | if n_class_samples < min_elements: 308 | print("Class: " + label + ", count = " + str(n_class_samples) + ", adding samples...") 309 | 310 | #class is under represented, generate distorted artificial samples.... 311 | to_create = min_elements - n_class_samples 312 | n_extra_samples += to_create 313 | 314 | if not count_only: 315 | #...create... 316 | extra_samples = [] 317 | for i in range(to_create): 318 | #...take one element from original samples 319 | base_symbol = all_symbols[current_refs[(i % n_class_samples)]] 320 | 321 | #...create distorted version... 322 | new_symbol = distorter.distortSymbol(base_symbol, diag) 323 | 324 | #...add... 
325 | extra_samples.append(new_symbol) 326 | 327 | #...get features.... 328 | extra_features = get_symbol_features(extra_samples, n_atts, verbose) 329 | #...create labels.... 330 | extra_labels = [label] * len(extra_samples) 331 | 332 | #...now, save the extra samples to file... 333 | out_file = open(output_filename, 'a') 334 | append_dataset_string_labels(extra_features, extra_labels, out_file) 335 | out_file.close() 336 | else: 337 | #class has more than enough data, try clustering and then enhancing small clusters... 338 | 339 | print("Class: " + label + ", count = " + str(n_class_samples) + ", clustering...") 340 | 341 | if components == 0: 342 | print("...Skipping clustering...") 343 | continue 344 | 345 | #create empty dataset ... 346 | class_data = np.zeros((n_class_samples, n_atts)) 347 | #fill with samples... 348 | for i in xrange(n_class_samples): 349 | class_data[i, :] = scaled_training[current_refs[i], :] 350 | 351 | #apply clustering... 352 | ward = Ward(n_clusters=components).fit(class_data) 353 | cluster_labels = ward.labels_ 354 | 355 | #separate references per cluster, and get counts... 356 | counts_per_cluster = [0] * components 357 | refs_per_cluster = {} 358 | for i in range(n_class_samples): 359 | c_label = int(cluster_labels[i]) 360 | 361 | if c_label in refs_per_cluster: 362 | refs_per_cluster[c_label].append(current_refs[i]) 363 | else: 364 | refs_per_cluster[c_label] = [current_refs[i]] 365 | 366 | #increase the count of samples per class... 367 | counts_per_cluster[c_label] += 1 368 | 369 | #...check minimum size for enhancement 370 | largest_cluster = max(counts_per_cluster) 371 | #print "Largest cluster: " + str(largest_cluster) 372 | min_cluster_size = int(math.ceil(clust_prc * largest_cluster)) 373 | final_counts_per_cluster = [] 374 | 375 | #...for each cluster... 376 | extra_samples = [] 377 | for i in range(components): 378 | #...if has les than the minimum number of elements... 379 | if counts_per_cluster[i] < min_cluster_size: 380 | #enhance.... 381 | c_cluster = counts_per_cluster[i] 382 | to_create = min_cluster_size - c_cluster 383 | cluster_refs = refs_per_cluster[i] 384 | 385 | n_extra_samples += to_create 386 | 387 | if not count_only: 388 | #...create... 389 | for k in range(to_create): 390 | #...take one element from original samples 391 | base_symbol = all_symbols[cluster_refs[(k % c_cluster)]] 392 | #...create distorted version... 393 | new_symbol = distorter.distortSymbol(base_symbol, diag) 394 | #...modify label (use new mapped label)... 395 | #new_symbol.truth = classes_dict[new_symbol.truth] 396 | #...add... 397 | extra_samples.append(new_symbol) 398 | 399 | final_counts_per_cluster.append(min_cluster_size) 400 | else: 401 | final_counts_per_cluster.append(counts_per_cluster[i]) 402 | 403 | #print counts_per_cluster 404 | #print final_counts_per_cluster 405 | 406 | print ("Cluster #" + str(i + 1) + ", i. size = " + str(counts_per_cluster[i]) + 407 | ", f. size = " + str(final_counts_per_cluster[i])) 408 | 409 | if not count_only: 410 | #...get features.... 411 | extra_features = get_symbol_features(extra_samples, n_atts, verbose) 412 | 413 | #...create labels matrix.... 414 | extra_labels = [label] * len(extra_samples) 415 | 416 | #...now, save the extra samples to file... 
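                # The synthetic samples produced for the small clusters are appended to
                # the same output file that already holds the original data (opened in
                # append mode), so the final dataset mixes original and distorted rows.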
417 | out_file = open(output_filename, 'a') 418 | #append_dataset(extra_features, extra_labels, out_file) 419 | append_dataset_string_labels(extra_features, extra_labels, out_file) 420 | out_file.close() 421 | 422 | extra_samples = [] 423 | 424 | print("Initial size: " + str(n_original_samples)) 425 | print("Extra samples: " + str(n_extra_samples)) 426 | print("Final size: " + str(n_extra_samples + n_original_samples)) 427 | print("Added Ratio: " + str((float(n_extra_samples) / float(n_original_samples)) * 100.0)) 428 | 429 | print("Finished!") 430 | 431 | main() 432 | -------------------------------------------------------------------------------- /src/get_training_set.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | import os 25 | import sys 26 | import fnmatch 27 | import string 28 | from traceInfo import * 29 | from mathSymbol import * 30 | from load_inkml import * 31 | 32 | #===================================================================== 33 | # generates a training set from a directory containing the inkml 34 | # 35 | # Created by: 36 | # - Kenny Davila (Oct, 2012) 37 | # Modified By: 38 | # - Kenny Davila (Oct 19, 2012) 39 | # - Kenny Davila (Nov 25, 2013) 40 | # - Added ID to symbol 41 | # - Added AUX file to output origin for each sample in DS 42 | # - Kenny Davila (Jan 16, 2012-2014) 43 | # - Print number of attributes found 44 | # - Kenny Davila (March 2016) 45 | # - Added additional error handling for files with errors 46 | # 47 | #===================================================================== 48 | 49 | def main(): 50 | #usage check 51 | if len(sys.argv) != 3: 52 | print("Usage: python get_training_set.py inkml_path output") 53 | print("Where") 54 | print("\tinkml_path\t= Path to directory that contains the inkml files") 55 | print("\toutput\t\t= File name of the output file") 56 | return 57 | 58 | #load and filter the list of files, the result is a list of inkml files only 59 | try: 60 | complete_list = os.listdir(sys.argv[1]) 61 | filtered_list = [] 62 | for file in complete_list: 63 | if fnmatch.fnmatch(file, '*.inkml'): 64 | filtered_list.append( file ) 65 | except: 66 | print( "The inkml path <" + sys.argv[1] + "> is invalid!" ) 67 | return 68 | 69 | samples = [] 70 | labels_found = {} 71 | sources = [] 72 | error_files = [] 73 | 74 | #read every file in the path specified... 
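    # For every inkml file: load its symbols with load_inkml(); files that raise an
    # exception are recorded in error_files and skipped; each symbol contributes one
    # sample row (features + truth label) and one (file_path, symbol id) entry that
    # is later written to the auxiliary sources file.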
75 | for i in range(len(filtered_list)): 76 | file_name = filtered_list[i] 77 | file_path = sys.argv[1] + '//' + file_name; 78 | advance = float(i) / len(filtered_list) 79 | print(("Processing => {:.2%} => " + file_path).format( advance )) 80 | 81 | try: 82 | symbols = load_inkml( file_path, True ) 83 | except: 84 | print("Failed processing: " + file_path) 85 | error_files.append(file_path) 86 | symbols = [] 87 | 88 | for new_symbol in symbols: 89 | #now generate the features and add them to the list, including the tag 90 | #for the expected class.... 91 | sample = new_symbol.getFeatures() + [ new_symbol.truth ] 92 | samples.append( sample ) 93 | 94 | #count samples per class 95 | if not new_symbol.truth in labels_found: 96 | labels_found[ new_symbol.truth ] = 1 97 | else: 98 | labels_found[ new_symbol.truth ] += 1 99 | 100 | #the source of current symbol will be 101 | #exported as auxiliary file 102 | sources.append( ( file_path, new_symbol.id) ) 103 | 104 | 105 | print("Total input files: " + str(len(filtered_list))) 106 | print("Total valid files: " + str(len(filtered_list) - len(error_files))) 107 | print("Files with errors: ") 108 | for filename in error_files: 109 | print("\t- " + filename) 110 | 111 | print("Total files with error: " + str(len(error_files))) 112 | 113 | print( "Found: " + str(len(labels_found.keys())) + " different classes" ) 114 | if len(samples) > 0: 115 | print( "Found: " + str(len(samples[0]) - 1 ) + " different attributes" ) 116 | 117 | print "Saving main .... " 118 | #now that all the samples have been collected, write them all 119 | #in the output file 120 | try: 121 | file = open(sys.argv[2], 'w') 122 | except: 123 | print( "File <" + sys.argv[2] + "> could not be created") 124 | return 125 | 126 | content = '' 127 | #print as headers the types for each feature... 128 | feature_types = new_symbol.getFeaturesTypes() 129 | for i, feat_type in enumerate(feature_types): 130 | if i > 0: 131 | content += '; ' 132 | content += feat_type 133 | content += '\r\n' 134 | 135 | for sample in samples: 136 | line = '' 137 | for i, v in enumerate(sample): 138 | if i > 0: 139 | line += '; ' 140 | 141 | if v.__class__.__name__ == "list": 142 | #multiple values... 143 | for j, sv in enumerate(v): 144 | if j > 0: 145 | line += '; ' 146 | line += str(sv) 147 | else: 148 | #single value... 149 | line += str(v) 150 | 151 | line += '\r\n' 152 | content += line 153 | 154 | if len(content) >= 50000: 155 | file.write(content) 156 | content = '' 157 | 158 | file.write(content) 159 | 160 | file.close() 161 | 162 | print "Saving auxiliary.... " 163 | 164 | #Now, add the auxiliary file 165 | try: 166 | aux_file = open(sys.argv[2] + ".sources.txt" , 'w') 167 | except: 168 | print( "File <" + sys.argv[2] + ".sources.txt> could not be created") 169 | return 170 | 171 | content = '' 172 | for source_path, sym_id in sources: 173 | content += source_path + ', ' + str(sym_id) + '\r\n' 174 | 175 | aux_file.write(content) 176 | 177 | aux_file.close() 178 | 179 | print "Done!" 
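# Example invocation (paths are illustrative only):
#   python get_training_set.py ./trainingSymbols/ train_set.txt
# This writes the feature file given as "output" plus an auxiliary
# "<output>.sources.txt" file listing the source inkml file and symbol id
# for every sample.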
180 | main() 181 | -------------------------------------------------------------------------------- /src/load_inkml.py: -------------------------------------------------------------------------------- 1 | from traceInfo import * 2 | from mathSymbol import * 3 | import xml.etree.ElementTree as ET 4 | 5 | #===================================================================== 6 | # Load an INKML file using the mathSymbol and traceInfo classes 7 | # to represent the data 8 | # 9 | # Created by: 10 | # - Kenny Davila (Oct, 2012) 11 | # Modified By: 12 | # - Kenny Davila (Oct 24, 2013) 13 | # - Kenny Davila (Nov 25, 2013) 14 | # - Added ID to symbol 15 | # - Kenny Davila (Jan 17, 2014) 16 | # - Fixed ID cases for CHROME 2013 data containing colon 17 | # - Kenny Davila (Feb 3, 2014) 18 | # - Fixed cases where symbol loaded has no traces 19 | # - Kenny Davila (Apr 11, 2014) 20 | # - Fixed cases where symbol loaded has no traces 21 | # - Added junk symbol compatibility 22 | # - Kenny Davila (March, 2016) 23 | # - Error handling 24 | # 25 | #===================================================================== 26 | 27 | 28 | #debug flags 29 | debug_raw = False 30 | debug_added = False 31 | debug_smoothing = False 32 | debug_normalization = False 33 | 34 | #the current XML namespace prefix... 35 | INKML_NAMESPACE = '{http://www.w3.org/2003/InkML}' 36 | 37 | def load_inkml_traces(file_name): 38 | #first load the tree... 39 | tree = ET.parse(file_name) 40 | root = tree.getroot() 41 | 42 | #extract all the traces first... 43 | traces_objects = {} 44 | for trace in root.findall(INKML_NAMESPACE + 'trace'): 45 | #text contains all points as string, parse them and put them 46 | #into a list of tuples... 47 | points_s = trace.text.split(","); 48 | points_f = [] 49 | for p_s in points_s: 50 | #split again... 51 | coords_s = p_s.split() 52 | #add... 53 | points_f.append( (float(coords_s[0]), float(coords_s[1])) ) 54 | 55 | trace_id = int(trace.attrib['id']) 56 | 57 | #now create the element 58 | object_trace = TraceInfo(trace_id, points_f ) 59 | 60 | #add to the diccionary... 61 | traces_objects[trace_id] = object_trace 62 | 63 | #apply general trace pre processing... 64 | 65 | #1) first step of pre processing: Remove duplicated points 66 | object_trace.removeDuplicatedPoints() 67 | 68 | if debug_raw: 69 | #output raw data 70 | file = open('out_raw_' + trace.attrib['id'] + '.txt', 'w') 71 | file.write( str(object_trace) ) 72 | file.close() 73 | 74 | #Add points to the trace... 75 | object_trace.addMissingPoints() 76 | 77 | if debug_added: 78 | #output raw data 79 | file = open('out_added_' + trace.attrib["id"] + '.txt', 'w') 80 | file.write( str(object_trace) ) 81 | file.close() 82 | 83 | #Apply smoothing to the trace... 84 | object_trace.applySmoothing() 85 | 86 | #it should not ... but ..... 87 | if object_trace.hasDuplicatedPoints(): 88 | #...remove them! .... 89 | object_trace.removeDuplicatedPoints() 90 | 91 | if debug_smoothing: 92 | #output data after smoothing 93 | file = open('out_smoothed_' + trace.attrib["id"] + '.txt', 'w') 94 | file.write( str(object_trace) ) 95 | file.close() 96 | 97 | 98 | 99 | return root, traces_objects 100 | 101 | def extract_symbols( root, traces_objects, truth_available ): 102 | #put all the traces together with their corresponding symbols... 103 | #first, find the root of the trace groups... 
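    # Structure expected by the code below: one outer <traceGroup> whose child
    # <traceGroup> elements each describe a symbol, carrying an <annotation> with the
    # class label, an attribute ending in "id", and <traceView traceDataRef="..."/>
    # children that point back to the <trace> elements parsed in load_inkml_traces().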
104 | groups_root = root.find(INKML_NAMESPACE + 'traceGroup') 105 | trace_groups = groups_root.findall(INKML_NAMESPACE + 'traceGroup') 106 | 107 | symbols = [] 108 | avg_width = 0.0 109 | avg_height = 0.0 110 | 111 | for group in trace_groups: 112 | if truth_available: 113 | #search for class label... 114 | symbol_class = group.find(INKML_NAMESPACE + 'annotation').text 115 | 116 | #search for id attribute... 117 | symbol_id = 0 118 | for id_att_name in group.attrib: 119 | if id_att_name[-2:] == "id": 120 | try: 121 | symbol_id = int(group.attrib[id_att_name]) 122 | except: 123 | #could not convert to int, try spliting... 124 | symbol_id = int( group.attrib[id_att_name].split(":")[0] ) 125 | 126 | else: 127 | #unknown 128 | symbol_class = '{Unknown}' 129 | symbol_id = 0 130 | 131 | #link with corresponding traces... 132 | group_traces = group.findall(INKML_NAMESPACE + 'traceView') 133 | symbol_list = [] 134 | for trace in group_traces: 135 | object_trace = traces_objects[int(trace.attrib["traceDataRef"])] 136 | symbol_list.append(object_trace) 137 | 138 | #create the math symbol... 139 | try: 140 | new_symbol = MathSymbol(symbol_id, symbol_list, symbol_class) 141 | except Exception as e: 142 | print("Failed to load symbol!!") 143 | print(e) 144 | #skip this symbol... 145 | continue 146 | 147 | #capture statistics of relative size.... 148 | symMinX, symMaxX, symMinY, symMaxY = new_symbol.original_box 149 | #...add... 150 | avg_width += (symMaxX - symMinX) 151 | avg_height += (symMaxY - symMinY) 152 | 153 | #now normalize size and locations for traces in current symbol 154 | new_symbol.normalize() 155 | 156 | if debug_normalization: 157 | #output data after relocated 158 | for trace in group_traces: 159 | object_trace = traces_objects[int(trace.attrib["traceDataRef"])] 160 | 161 | file = open('out_reloc_' + trace.attrib["traceDataRef"] + '.txt', 'w') 162 | file.write( str(object_trace) ) 163 | file.close() 164 | 165 | symbols.append(new_symbol) 166 | 167 | if len(trace_groups) > 0: 168 | avg_width /= len(trace_groups) 169 | avg_height /= len(trace_groups) 170 | 171 | for s in symbols: 172 | s.setSizeRatio(avg_width, avg_height) 173 | 174 | return symbols 175 | 176 | def load_inkml(file_name, truth_available): 177 | 178 | root, traces_objects = load_inkml_traces(file_name) 179 | 180 | symbols = extract_symbols(root, traces_objects, truth_available) 181 | 182 | return symbols 183 | 184 | def extract_junk_symbol(traces_objects, junk_class_name): 185 | symbol_id = 0 186 | symbol_class = junk_class_name 187 | 188 | symbol_list = [] 189 | for trace_id in traces_objects: 190 | symbol_list.append(traces_objects[trace_id]) 191 | 192 | #create the math symbol... 
193 | try: 194 | new_symbol = MathSymbol(symbol_id, symbol_list, symbol_class) 195 | except Exception as e: 196 | print("Failed to load symbol!!") 197 | print(e) 198 | return None 199 | 200 | #now normalize size and locations for traces in current symbol 201 | new_symbol.normalize() 202 | 203 | return new_symbol 204 | 205 | def load_junk_inkml(file_name, junk_class_name): 206 | 207 | root, traces_objects = load_inkml_traces(file_name) 208 | 209 | symbols = [extract_junk_symbol(traces_objects, junk_class_name)] 210 | 211 | return symbols 212 | -------------------------------------------------------------------------------- /src/parallel_evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | 25 | 26 | import sys 27 | import time 28 | import math 29 | import numpy as np 30 | import cPickle 31 | import multiprocessing 32 | from sklearn.preprocessing import StandardScaler 33 | from dataset_ops import * 34 | from evaluation_ops import * 35 | from symbol_classifier import SymbolClassifier 36 | 37 | #===================================================================== 38 | # this program takes as input a training set, a testing set and a 39 | # classifier and runs the evaluation on parallel, then metrics 40 | # are computed. 41 | # 42 | # Created by: 43 | # - Kenny Davila (Feb 12, 2012-2014) 44 | # Modified By: 45 | # - Kenny Davila (Feb 12, 2012-2014) 46 | # - Kenny Davila (Mar 11, 2015) 47 | # - Symbol classifier class now adopted 48 | # 49 | #===================================================================== 50 | 51 | 52 | def evaluate_data(classifier, data, first, last, results_queue): 53 | 54 | predicted = classifier.predict(data) 55 | 56 | results_queue.put((first, last, predicted)) 57 | 58 | def parallel_evaluate(classifier, dataset, workers): 59 | #...determine the ranges per worker... 60 | n_samples = np.size(dataset, 0) 61 | max_per_worker = int(math.ceil(float(n_samples) / float(workers))) 62 | 63 | #...prepare for parallel processing... 64 | results_queue = multiprocessing.Queue() 65 | processes = [] 66 | final_results = np.zeros(n_samples) 67 | 68 | #...start parallel threads.... 69 | for i in range(workers): 70 | first_element = max_per_worker * i 71 | last_element = min(max_per_worker * (i + 1), n_samples) 72 | 73 | params = (classifier, dataset[first_element:last_element, :], first_element, last_element, results_queue) 74 | p = multiprocessing.Process(target=evaluate_data, args=params) 75 | p.start() 76 | 77 | processes.append(p) 78 | 79 | #...await for al results to be ready... 
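    # Each worker puts one (first, last, predicted) tuple on the queue; get() blocks
    # until a result is available, and because results can arrive in any order the
    # (first, last) indices are used to write each slice back into final_results.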
80 | for i in range(workers): 81 | first, last, predicted = results_queue.get() 82 | 83 | #merge with current final results... 84 | final_results[first:last] = predicted 85 | 86 | return final_results 87 | 88 | 89 | def main(): 90 | if len(sys.argv) < 5: 91 | print("Usage: python parallel_evaluate.py training_set testing_set classifier normalize " + 92 | "workers [test_only] [ambiguous]") 93 | print("Where") 94 | print("\ttraining_set\t= Path to the file of the training set") 95 | print("\ttesting_set\t= Path to the file of the testing set") 96 | print("\tclassifier\t= File that contains the pickled classifier") 97 | print("\tworkers\t\t= Number of parallel threads to use") 98 | print("\ttest_only\t= Optional, Only execute for testing set") 99 | print("\tambiguous\t= Optional, file that contains the list of ambiguous") 100 | return 101 | 102 | training_file = sys.argv[1] 103 | testing_file = sys.argv[2] 104 | classifier_file = sys.argv[3] 105 | 106 | try: 107 | workers = int(sys.argv[4]) 108 | if workers < 1: 109 | print("Invalid number of workers") 110 | return 111 | except: 112 | print("Invalid number of workers") 113 | return 114 | 115 | if len(sys.argv) >= 6: 116 | try: 117 | test_only = int(sys.argv[5]) > 0 118 | except: 119 | print("Invalid value for test_only") 120 | return 121 | else: 122 | test_only = False 123 | 124 | if len(sys.argv) >= 7: 125 | allograph_file = sys.argv[6] 126 | else: 127 | allograph_file = None 128 | ambiguous = None 129 | 130 | print("Loading classifier...") 131 | 132 | in_file = open(classifier_file, 'rb') 133 | classifier = cPickle.load(in_file) 134 | in_file.close() 135 | 136 | if not isinstance(classifier, SymbolClassifier): 137 | print("Invalid classifier file!") 138 | return 139 | 140 | # get mapping 141 | classes_dict = classifier.classes_dict 142 | classes_l = classifier.classes_list 143 | n_classes = len(classes_l) 144 | 145 | print("Loading data...") 146 | 147 | if not test_only: 148 | #...loading traininig data... 149 | training, labels_l, att_types = load_dataset(training_file) 150 | if att_types is None: 151 | print("Error loading File <" + training_file + ">") 152 | return 153 | 154 | #...generate mapped labels... 155 | labels_train = get_mapped_labels(labels_l, classes_dict) 156 | else: 157 | training, labels_l, labels_train = (None, None, None) 158 | 159 | #...loading testing data... 
160 | testing, test_labels_l, att_types = load_dataset(testing_file) 161 | 162 | if att_types is None: 163 | print("Error loading File <" + testing_file + ">") 164 | return 165 | 166 | labels_test = get_mapped_labels(test_labels_l, classes_dict) 167 | 168 | if classifier.scaler is not None: 169 | print("Normalizing...") 170 | 171 | scaler = classifier.scaler 172 | if not test_only: 173 | training = scaler.transform(training) 174 | 175 | testing = scaler.transform(testing) 176 | 177 | if not allograph_file is None: 178 | print("Loading ambiguous file...") 179 | ambiguous, total_ambiguous = load_ambiguous(allograph_file, classes_dict, True) 180 | 181 | print("...A total of " + str(total_ambiguous) + " were found") 182 | 183 | start_time = time.time() 184 | 185 | print("Evaluating in multiple threads...") 186 | 187 | if not test_only: 188 | # Training data 189 | n_training_samples = np.size(training, 0) 190 | 191 | # ....first, evaluate samples on multiple threads 192 | predicted = parallel_evaluate(classifier, training, workers) 193 | #....on main thread, compute final statistics 194 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_train, n_classes) 195 | accuracy = (float(total_correct) / float(n_training_samples)) * 100 196 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 197 | print "Training Samples: " + str(n_training_samples) 198 | print "Training Results" 199 | print "Accuracy\tClass Average\tClass STD " 200 | print str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0) 201 | else: 202 | print("...Skipping Training set...") 203 | 204 | n_testing_samples = np.size(testing, 0) 205 | 206 | #Testing data 207 | #....first, evaluate samples on multiple threads 208 | predicted = parallel_evaluate(classifier, testing, workers) 209 | #....on main thread, compute final statistics 210 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_test, n_classes) 211 | accuracy = (float(total_correct) / float(n_testing_samples)) * 100 212 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 213 | print "Testing Samples: " + str(n_testing_samples) 214 | print "Testing Results" 215 | print "Accuracy\tClass Average\tClass STD " 216 | print str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0) 217 | 218 | end_time = time.time() 219 | total_elapsed = end_time - start_time 220 | 221 | print "Total Elapsed: " + str(total_elapsed) 222 | 223 | #get confusion matrix... 224 | print("Generating confusion matrix for test...") 225 | confussion = compute_confusion_matrix(predicted, labels_test, n_classes) 226 | print("Saving results...") 227 | results_file = classifier_file + ".results.csv" 228 | save_evaluation_results(classes_l, confussion, results_file, ambiguous) 229 | 230 | #...check for ambiguous... 231 | if not ambiguous is None: 232 | #...recompute metrics considering ambiguous... 
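        # compute_ambiguous_confusion_matrix is assumed to also accept a prediction
        # when it is listed as ambiguous/interchangeable with the true label in the
        # ambiguous file (see evaluation_ops.py), unlike the plain matrix above.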
233 | confussion = compute_ambiguous_confusion_matrix(predicted, labels_test, n_classes, ambiguous) 234 | 235 | results_file = classifier_file + ".results_ambiguous.csv" 236 | save_evaluation_results(classes_l, confussion, results_file, ambiguous) 237 | 238 | print("...Finished!") 239 | 240 | if __name__ == '__main__': 241 | main() 242 | 243 | 244 | 245 | 246 | -------------------------------------------------------------------------------- /src/parallel_prob_evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | import sys 25 | import time 26 | import math 27 | import numpy as np 28 | import cPickle 29 | import multiprocessing 30 | from sklearn.preprocessing import StandardScaler 31 | from dataset_ops import * 32 | from evaluation_ops import * 33 | from symbol_classifier import SymbolClassifier 34 | 35 | #===================================================================== 36 | # this program takes as input a training set, a testing set and a 37 | # classifier and runs the evaluation on parallel, then metrics 38 | # are computed. 39 | # 40 | # Created by: 41 | # - Kenny Davila (Feb 12, 2012-2014) 42 | # Modified By: 43 | # - Kenny Davila (Feb 12, 2012-2014) 44 | # - Kenny Davila (Feb 25, 2012-2014 45 | # - Now includes top-5 accuracy too 46 | # - Kenny Davila (Mar 11, 2015) 47 | # - Symbol classifier class now adopted 48 | # 49 | #===================================================================== 50 | 51 | 52 | def evaluate_data(classifier, data, first, last, top_n, results_queue): 53 | #use probabilistic evaluation to get probability per class... 54 | predicted = classifier.predict_proba(data) 55 | 56 | n_samples = np.size(data, 0) 57 | n_classes = np.size(predicted, 1) 58 | results = np.zeros((n_samples, top_n)) 59 | 60 | #find the top N values per sample... 61 | raw_classes = classifier.get_raw_classes() 62 | for i in range(n_samples): 63 | tempo_values = [] 64 | for k in range(n_classes): 65 | tempo_values.append((predicted[i, k], raw_classes[k])) 66 | 67 | #now, sort! 68 | tempo_values = sorted(tempo_values, reverse=True) 69 | 70 | #use the top-N 71 | for k in range(top_n): 72 | results[i, k] = tempo_values[k][1] 73 | 74 | results_queue.put((first, last, results)) 75 | 76 | def parallel_evaluate(classifier, dataset, workers, top_n): 77 | #...determine the ranges per worker... 78 | n_samples = np.size(dataset, 0) 79 | max_per_worker = int(math.ceil(float(n_samples) / float(workers))) 80 | 81 | #...prepare for parallel processing... 
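    # Unlike parallel_evaluate.py, each worker here returns for every sample the
    # labels of its top_n most probable classes, so final_results is an
    # (n_samples x top_n) matrix later scored with compute_topn_error_counts.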
82 | results_queue = multiprocessing.Queue() 83 | processes = [] 84 | final_results = np.zeros((n_samples, top_n)) 85 | 86 | #...start parallel threads.... 87 | for i in range(workers): 88 | first_element = max_per_worker * i 89 | last_element = min(max_per_worker * (i + 1), n_samples) 90 | 91 | params = (classifier, dataset[first_element:last_element, :], first_element, last_element, top_n, results_queue) 92 | p = multiprocessing.Process(target=evaluate_data, args=params) 93 | p.start() 94 | 95 | processes.append(p) 96 | 97 | #...await for al results to be ready... 98 | for i in range(workers): 99 | first, last, predicted = results_queue.get() 100 | 101 | #merge with current final results... 102 | final_results[first:last, :] = predicted 103 | 104 | return final_results 105 | 106 | 107 | def main(): 108 | if len(sys.argv) < 5: 109 | print("Usage: python parallel_prob_evaluate.py training_set testing_set classifier normalize " + 110 | "workers [test_only] [ambiguous]") 111 | print("Where") 112 | print("\ttraining_set\t= Path to the file of the training set") 113 | print("\ttesting_set\t= Path to the file of the testing set") 114 | print("\tclassifier\t= File that contains the pickled classifier") 115 | print("\tworkers\t\t= Number of parallel threads to use") 116 | print("\ttest_only\t= Optional, Only execute for testing set") 117 | print("\tambiguous\t= Optional, file that contains the list of ambiguous") 118 | return 119 | 120 | training_file = sys.argv[1] 121 | testing_file = sys.argv[2] 122 | classifier_file = sys.argv[3] 123 | 124 | try: 125 | workers = int(sys.argv[4]) 126 | if workers < 1: 127 | print("Invalid number of workers") 128 | return 129 | except: 130 | print("Invalid number of workers") 131 | return 132 | 133 | if len(sys.argv) >= 6: 134 | try: 135 | test_only = int(sys.argv[5]) > 0 136 | except: 137 | print("Invalid value for test_only") 138 | return 139 | else: 140 | test_only = False 141 | 142 | if len(sys.argv) >= 7: 143 | allograph_file = sys.argv[6] 144 | else: 145 | allograph_file = None 146 | ambiguous = None 147 | 148 | print("Loading classifier...") 149 | 150 | in_file = open(classifier_file, 'rb') 151 | classifier = cPickle.load(in_file) 152 | in_file.close() 153 | 154 | if not isinstance(classifier, SymbolClassifier): 155 | print("Invalid classifier file!") 156 | return 157 | 158 | print("Loading data...") 159 | 160 | # get mapping 161 | classes_dict = classifier.classes_dict 162 | classes_l = classifier.classes_list 163 | n_classes = len(classes_l) 164 | 165 | if not test_only: 166 | #...loading traininig data... 167 | training, labels_l, att_types = load_dataset(training_file) 168 | if att_types is None: 169 | print("Error loading File <" + training_file + ">") 170 | return 171 | 172 | #...generate mapped labels... 173 | labels_train = get_mapped_labels(labels_l, classes_dict) 174 | else: 175 | training, labels_l, labels_train = (None, None, None) 176 | 177 | #...loading testing data... 
178 | testing, test_labels_l, att_types = load_dataset(testing_file) 179 | 180 | if att_types is None: 181 | print("Error loading File <" + testing_file + ">") 182 | return 183 | 184 | labels_test = get_mapped_labels(test_labels_l, classes_dict) 185 | 186 | if classifier.scaler is not None: 187 | print("Normalizing...") 188 | 189 | scaler = classifier.scaler 190 | if not test_only: 191 | training = scaler.transform(training) 192 | 193 | testing = scaler.transform(testing) 194 | 195 | if not allograph_file is None: 196 | print("Loading ambiguous file...") 197 | ambiguous, total_ambiguous = load_ambiguous(allograph_file, classes_dict, True) 198 | 199 | print("...A total of " + str(total_ambiguous) + " were found") 200 | 201 | print("Training data set: " + training_file) 202 | print("Testing data set: " + testing_file) 203 | print("Evaluating in multiple threads...") 204 | 205 | start_time = time.time() 206 | 207 | top_n = 5 208 | 209 | if not test_only: 210 | # Training data 211 | n_training_samples = np.size(training, 0) 212 | 213 | # ....first, evaluate samples on multiple threads 214 | predicted = parallel_evaluate(classifier, training, workers, top_n) 215 | 216 | print "Training Samples: " + str(n_training_samples) 217 | print "Training Results" 218 | print "Top\tAccuracy\tClass Average\tClass STD " 219 | 220 | # ....on main thread, compute final statistics 221 | for i in range(top_n): 222 | total_correct, counts_per_class, errors_per_class = compute_topn_error_counts(predicted, labels_train, 223 | n_classes, i + 1) 224 | accuracy = (float(total_correct) / float(n_training_samples)) * 100 225 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 226 | 227 | print str(i+1) + "\t" + str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0) 228 | else: 229 | print("...Skipping Training set...") 230 | 231 | n_testing_samples = np.size(testing, 0) 232 | 233 | # Testing data 234 | # ....first, evaluate samples on multiple threads 235 | predicted = parallel_evaluate(classifier, testing, workers, top_n) 236 | 237 | #....on main thread, compute final statistics 238 | print "Testing Samples: " + str(n_testing_samples) 239 | print "Testing Results" 240 | print "Top\tAccuracy\tClass Average\tClass STD " 241 | 242 | for i in range(top_n): 243 | total_correct, counts_per_class, errors_per_class = compute_topn_error_counts(predicted, labels_test, 244 | n_classes, i + 1) 245 | accuracy = (float(total_correct) / float(n_testing_samples)) * 100 246 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 247 | 248 | print str(i+1) + "\t" + str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0) 249 | 250 | end_time = time.time() 251 | total_elapsed = end_time - start_time 252 | 253 | print "Total Elapsed: " + str(total_elapsed) 254 | 255 | print("...Finished!") 256 | 257 | if __name__ == '__main__': 258 | main() 259 | 260 | -------------------------------------------------------------------------------- /src/random_forest_classify.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 
6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | import sys 25 | import time 26 | import cPickle 27 | import math 28 | from sklearn.ensemble import RandomForestClassifier 29 | from dataset_ops import * 30 | from evaluation_ops import * 31 | from symbol_classifier import SymbolClassifier 32 | 33 | #===================================================================== 34 | # this program takes as input a data set and using Random forest 35 | # it generates a set of decision trees to classify data. 36 | # 37 | # Created by: 38 | # - Kenny Davila (Feb 1, 2012-2014) 39 | # Modified By: 40 | # - Kenny Davila (Feb 1, 2014) 41 | # - Kenny Davila (March 11, 2015) 42 | # - Incorporated SymbolClassifier class 43 | # 44 | #===================================================================== 45 | 46 | def predict_in_chunks(data, classifier, max_chunk_size): 47 | n_samples = np.size(data, 0) 48 | all_predicted = np.zeros(n_samples) 49 | pos = 0 50 | while pos < n_samples: 51 | last = min(pos + max_chunk_size, n_samples) 52 | 53 | #...get predictions for current chuck... 54 | predicted = classifier.predict(data[pos:last]) 55 | #...add them to result... 56 | all_predicted[pos:last] = predicted 57 | 58 | #print("...Evaluated " + str(last) + " out of " + str(n_samples) ) 59 | 60 | pos = last 61 | 62 | return all_predicted 63 | 64 | 65 | 66 | def main(): 67 | #usage check 68 | if len(sys.argv) < 8: 69 | print("Usage: python random_forest_classify.py training_set testing_set N_trees max_D ") 70 | print(" max_feats type times [n_jobs] [out_file]") 71 | print("Where") 72 | print("\ttraining_set\t= Path to the file of the training set") 73 | print("\ttesting_set\t= Path to the file of the testing set") 74 | print("\tN_trees\t\t= Number of trees to use") 75 | print("\tmax_D\t\t= Maximum Depth") 76 | print("\tmax_feats\t= Maximum Features") 77 | print("\ttype\t\t= Type of Decision trees (criterion for splits)") 78 | print("\t\t\t\t0 - Gini") 79 | print("\t\t\t\t1 - Entropy") 80 | print("\ttimes\t\t= Number of times to repeat experiments") 81 | print ("\tn_jobs\t\t= Optional, number of parallel threads to use") 82 | print ("\tout_file\t= Optional, file where classifier will be stored") 83 | return 84 | 85 | print("Loading data....") 86 | #...load training data from file... 87 | train_filename = sys.argv[1] 88 | training, labels_l, att_types = load_dataset(train_filename) 89 | #...generate mapping... 90 | classes_dict, classes_l = get_label_mapping(labels_l) 91 | n_classes = len(classes_l) 92 | #...generate mapped labels... 93 | labels_train = get_mapped_labels(labels_l, classes_dict) 94 | 95 | #...load testing data from file... 96 | test_filename = sys.argv[2] 97 | testing, test_labels_l, att_types = load_dataset(test_filename) 98 | #...generate mapped labels... 
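    # The testing labels are mapped through the classes_dict built from the training
    # labels, so both sets share the same numeric class indices used by the forest.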
99 | labels_test = get_mapped_labels(test_labels_l, classes_dict) 100 | 101 | try: 102 | n_trees = int(sys.argv[3]) 103 | if n_trees < 1: 104 | print("Invalid N_trees value") 105 | return 106 | except: 107 | print("Invalid N_trees value") 108 | return 109 | 110 | try: 111 | max_D = int(sys.argv[4]) 112 | if max_D < 1: 113 | print("Invalid max_D value") 114 | return 115 | except: 116 | print("Invalid max_D value") 117 | return 118 | 119 | 120 | try: 121 | max_features = int(sys.argv[5]) 122 | if max_features < 1: 123 | print("Invalid max_feats value") 124 | return 125 | except: 126 | print("Invalid max_feats value") 127 | return 128 | 129 | try: 130 | criterion_t = 'gini' if int(sys.argv[6]) == 0 else 'entropy' 131 | except: 132 | print("Invalid type value") 133 | return 134 | 135 | try: 136 | times = int(sys.argv[7]) 137 | if times < 1: 138 | print("Invalid times value") 139 | return 140 | except: 141 | print("Invalid times value") 142 | return 143 | 144 | if len(sys.argv) >= 9: 145 | try: 146 | num_jobs = int(sys.argv[8]) 147 | except: 148 | print("Invalid number of jobs") 149 | return 150 | else: 151 | num_jobs = 1 152 | 153 | if len(sys.argv) >= 10: 154 | out_filename = sys.argv[9] 155 | else: 156 | out_filename = train_filename + ".best.RF" 157 | 158 | 159 | print("....Data loaded!") 160 | 161 | start_time = time.time() 162 | total_training_time = 0 163 | total_testing_time = 0 164 | 165 | all_training_accuracies = np.zeros(times) 166 | all_training_averages = np.zeros(times) 167 | all_training_stds = np.zeros(times) 168 | all_testing_accuracies = np.zeros(times) 169 | all_testing_averages = np.zeros(times) 170 | all_testing_stds = np.zeros(times) 171 | 172 | n_training_samples = np.size(training, 0) 173 | n_testing_samples = np.size(testing, 0) 174 | 175 | #tree = DecisionTreeClassifier(criterion="gini") 176 | #tree = DecisionTreeClassifier(criterion="entropy") 177 | #boosting = AdaBoostClassifier(DecisionTreeClassifier(criterion="gini", max_depth=25), n_estimators=50,learning_rate=1.0) 178 | 179 | print " \tTrain\tTrain\tTrain\tTest\tTest\tTest" 180 | print " # \tACC\tC. AVG\tC. STD\tACC\tC. AVG\tC. STD" 181 | best_forest_ref = None 182 | best_forest_accuracy = None 183 | 184 | for i in range(times): 185 | start_training_time = time.time() 186 | #print "Training #" + str(i + 1) 187 | forest = RandomForestClassifier(n_estimators=n_trees,criterion=criterion_t, 188 | max_features=max_features, max_depth=max_D,n_jobs=num_jobs) #n_jobs=-1? 189 | 190 | forest.fit(training, np.ravel(labels_train)) 191 | 192 | end_training_time = time.time() 193 | total_training_time += end_training_time - start_training_time 194 | 195 | start_testing_time = time.time() 196 | 197 | #...first, training error 198 | predicted = predict_in_chunks(training, forest, 1000) 199 | #...compute metrics... 200 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_train, n_classes) 201 | accuracy = (float(total_correct) / float(n_training_samples)) * 100 202 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 203 | all_training_accuracies[i] = accuracy 204 | all_training_averages[i] = avg_accuracy * 100.0 205 | all_training_stds[i] = std_accuracy * 100.0 206 | 207 | #...then, testing error 208 | #predicted = forest.predict(testing) 209 | predicted = predict_in_chunks(testing, forest, 1000) 210 | #...compute metrics ... 
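            # Reported columns: overall accuracy (correct / total samples, in %), and,
            # presumably, the unweighted mean ("C. AVG") and standard deviation
            # ("C. STD") of the per-class accuracies from get_average_class_accuracy().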
211 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_test, n_classes) 212 | accuracy = (float(total_correct) / float(n_testing_samples)) * 100 213 | test_accuracy = accuracy 214 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 215 | all_testing_accuracies[i] = accuracy 216 | all_testing_averages[i] = avg_accuracy * 100.0 217 | all_testing_stds[i] = std_accuracy * 100.0 218 | 219 | end_testing_time = time.time() 220 | 221 | total_testing_time += end_testing_time - start_testing_time 222 | 223 | print(" " + str(i+1) + "\t" + str(round(all_training_accuracies[i], 3)) + "\t" + 224 | str(round(all_training_averages[i], 3)) + "\t" + 225 | str(round(all_training_stds[i], 3)) + "\t" + 226 | str(round(all_testing_accuracies[i], 3)) + "\t" + 227 | str(round(all_testing_averages[i], 3)) + "\t" + 228 | str(round(all_testing_stds[i], 3))) 229 | 230 | if best_forest_ref is None or best_forest_accuracy < test_accuracy: 231 | best_forest_ref = forest 232 | best_forest_accuracy = test_accuracy 233 | 234 | 235 | end_time = time.time() 236 | total_elapsed = end_time - start_time 237 | 238 | print(" Mean\t" + str(round(all_training_accuracies.mean(), 3)) + "\t" + 239 | str(round(all_training_averages.mean(), 3)) + "\t" + 240 | str(round(all_training_stds.mean(), 3)) + "\t" + 241 | str(round(all_testing_accuracies.mean(), 3)) + "\t" + 242 | str(round(all_testing_averages.mean(), 3)) + "\t" + 243 | str(round(all_testing_stds.mean(), 3))) 244 | 245 | print(" STD\t" + str(round(all_training_accuracies.std(), 3)) + "\t" + 246 | str(round(all_training_averages.std(), 3)) + "\t" + 247 | str(round(all_training_stds.std(), 3)) + "\t" + 248 | str(round(all_testing_accuracies.std(), 3)) + "\t" + 249 | str(round(all_testing_averages.std(), 3)) + "\t" + 250 | str(round(all_testing_stds.std(), 3))) 251 | 252 | #...Save to file... 253 | classifier = SymbolClassifier(SymbolClassifier.TypeRandomForest, best_forest_ref, classes_l, classes_dict) 254 | 255 | out_file = open(out_filename, 'wb') 256 | cPickle.dump(classifier, out_file, cPickle.HIGHEST_PROTOCOL) 257 | out_file.close() 258 | 259 | 260 | print("Total Elapsed: " + str(total_elapsed)) 261 | print("Mean Elapsed Time: " + str(total_elapsed / times)) 262 | print("Total Training time: " + str(total_training_time)) 263 | print("Mean Training time: " + str(total_training_time / times)) 264 | print("Total Evaluating time: " + str(total_testing_time)) 265 | print("Mean Evaluating time: " + str(total_testing_time / times)) 266 | 267 | print("Training File: " + train_filename) 268 | print("Training samples: " + str(n_training_samples)) 269 | print("Testing File: " + test_filename) 270 | print("Testing samples: " + str(n_testing_samples)) 271 | print("N_trees: " + str(n_trees)) 272 | print("Criterion: " + criterion_t) 273 | print("Max features: " + str(max_features)) 274 | print("Max Depth: " + str(max_D)) 275 | print("Finished!") 276 | 277 | if __name__ == '__main__': 278 | main() 279 | -------------------------------------------------------------------------------- /src/svm_lin_classifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 
6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | import sys 25 | import time 26 | import numpy as np 27 | import cPickle 28 | from sklearn import svm 29 | from sklearn.preprocessing import StandardScaler 30 | from dataset_ops import * 31 | from evaluation_ops import * 32 | from symbol_classifier import SymbolClassifier 33 | 34 | 35 | def main(): 36 | # usage check... 37 | if len(sys.argv) < 3: 38 | print("Usage: python svm_lin_classifier.py training testing [evaluate] [probab]") 39 | print("Where") 40 | print("\ttraining\t= Path to training set") 41 | print("\ttesting\t= Path to testing set") 42 | print("\tevaluate\t= Optional, run evaluation or just training ") 43 | print("\tprobab\t= Optional, make it a probabilistic classifier") 44 | print("\toutput\t= Optional, path to store trained classifier") 45 | return 46 | 47 | print("loading data") 48 | 49 | #...load training data from file... 50 | file_name = sys.argv[1] 51 | #file_name = 'ds_test_2012.txt' 52 | training, labels_l, att_types = load_dataset(file_name) 53 | #...generate mapping... 54 | classes_dict, classes_l = get_label_mapping(labels_l) 55 | n_classes = len(classes_l) 56 | #...generate mapped labels... 57 | labels_train = get_mapped_labels(labels_l, classes_dict) 58 | 59 | #now, load the testing set... 60 | file_name = sys.argv[2] 61 | testing, test_labels_l, att_types = load_dataset(file_name) 62 | #...generate mapped labels... 63 | labels_test = get_mapped_labels(test_labels_l, classes_dict) 64 | 65 | if len(sys.argv) >= 4: 66 | try: 67 | evaluate = int(sys.argv[3]) > 0 68 | except: 69 | print("Invalid evaluate value") 70 | return 71 | else: 72 | evaluate = True 73 | 74 | if len(sys.argv) >= 5: 75 | try: 76 | make_probabilistic = int(sys.argv[4]) > 0 77 | except: 78 | print("Invalid probab value") 79 | return 80 | else: 81 | make_probabilistic = False 82 | 83 | if len(sys.argv) >= 6: 84 | out_filename = sys.argv[5] 85 | else: 86 | out_filename = sys.argv[1] + ".LSVM" 87 | 88 | print("Training with: " + sys.argv[1]) 89 | print("Training Samples: " + str(np.size(training, 0))) 90 | print("Testing with: " + sys.argv[2]) 91 | print("Testing Samples: " + str(np.size(testing, 0))) 92 | 93 | start_time = time.time() 94 | 95 | #try scaling... 
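    # StandardScaler standardizes every feature to zero mean and unit variance; it is
    # fit on the training set only, the same transform is applied to the testing set,
    # and the fitted scaler is stored inside the SymbolClassifier below so that new
    # samples can be scaled identically at prediction time.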
96 | scaler = StandardScaler() 97 | training = scaler.fit_transform(training) 98 | testing = scaler.transform(testing) 99 | 100 | print("...training...") 101 | classifier = svm.SVC(kernel='linear', probability=make_probabilistic) 102 | classifier.fit(training, np.ravel(labels_train) ) 103 | 104 | print("...Saving to file...") 105 | symbol_classifier = SymbolClassifier(SymbolClassifier.TypeSVMLIN, classifier, classes_l, classes_dict, scaler, 106 | make_probabilistic) 107 | out_file = open(out_filename, 'wb') 108 | cPickle.dump(symbol_classifier, out_file, cPickle.HIGHEST_PROTOCOL) 109 | out_file.close() 110 | 111 | if not evaluate: 112 | #...finish early... 113 | end_time = time.time() 114 | total_elapsed = end_time - start_time 115 | 116 | print "Total Elapsed: " + str(total_elapsed) 117 | print("Finished!") 118 | return 119 | 120 | print("...Evaluating...") 121 | 122 | #...first, training error 123 | n_training_samples = np.size(training, 0) 124 | predicted = classifier.predict(training) 125 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_train, n_classes) 126 | accuracy = (float(total_correct) / float(n_training_samples)) * 100 127 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 128 | print "Training Samples: " + str(n_training_samples) 129 | print "Training Results" 130 | print "Accuracy\tClass Average\tClass STD " 131 | print str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0) 132 | 133 | #...then, testing error 134 | n_testing_samples = np.size(testing, 0) 135 | predicted = classifier.predict(testing) 136 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_test, n_classes) 137 | accuracy = (float(total_correct) / float(n_testing_samples)) * 100 138 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 139 | print "Testing Samples: " + str(n_testing_samples) 140 | print "Testing Results" 141 | print "Accuracy\tClass Average\tClass STD " 142 | print str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0) 143 | 144 | end_time = time.time() 145 | total_elapsed = end_time - start_time 146 | 147 | print "Total Elapsed: " + str(total_elapsed) 148 | 149 | print("Finished") 150 | 151 | main() 152 | -------------------------------------------------------------------------------- /src/svm_rbf_classifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 
19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | import sys 25 | import time 26 | import numpy as np 27 | import cPickle 28 | from sklearn import svm 29 | from sklearn.preprocessing import StandardScaler 30 | from dataset_ops import * 31 | from evaluation_ops import * 32 | from symbol_classifier import SymbolClassifier 33 | 34 | 35 | def main(): 36 | #usage check... 37 | if len(sys.argv) < 5: 38 | print("Usage: python svm_rbf_classifier.py training testing C Gamma eval [probab]") 39 | print("Where") 40 | print("\ttraining\t= Path to training set") 41 | print("\ttesting\t\t= Path to testing set") 42 | print("\tC\t\t= C parameter of the RBF SVM") 43 | print("\tGamma\t\t= Gamma parameter of the RBF SVM") 44 | print("\teval\t\t= Optional, will not evaluate if equal to 1") 45 | print("\tprobab\t\t= Optional, make it a probabilistic classifier") 46 | print("\toutput\t= Optional, path to store trained classifier") 47 | return 48 | 49 | print("loading data") 50 | 51 | #...load training data from file... 52 | file_name = sys.argv[1] 53 | #file_name = 'ds_test_2012.txt' 54 | training, labels_l, att_types = load_dataset(file_name) 55 | #...generate mapping... 56 | classes_dict, classes_l = get_label_mapping(labels_l) 57 | n_classes = len(classes_l) 58 | #...generate mapped labels... 59 | labels_train = get_mapped_labels(labels_l, classes_dict) 60 | 61 | #now, load the testing set... 62 | file_name = sys.argv[2] 63 | testing, test_labels_l, att_types = load_dataset(file_name) 64 | #...generate mapped labels... 65 | labels_test = get_mapped_labels(test_labels_l, classes_dict) 66 | 67 | try: 68 | rbf_C = float(sys.argv[3]) 69 | except: 70 | print("Invalid value for C") 71 | return 72 | 73 | try: 74 | rbf_Gamma = float(sys.argv[4]) 75 | except: 76 | print("Invalid value for Gamma") 77 | return 78 | 79 | if len(sys.argv) >= 6: 80 | try: 81 | evaluate = int(sys.argv[5]) > 0 82 | except: 83 | print("Invalid evaluate value") 84 | return 85 | else: 86 | evaluate = True 87 | 88 | if len(sys.argv) >= 7: 89 | try: 90 | make_probabilistic = int(sys.argv[6]) > 0 91 | except: 92 | print("Invalid probab value") 93 | return 94 | else: 95 | make_probabilistic = False 96 | 97 | if len(sys.argv) >= 8: 98 | out_filename = sys.argv[7] 99 | else: 100 | out_filename = sys.argv[1] + ".RSVM" 101 | 102 | print("Training with: " + sys.argv[1]) 103 | print("Testing with: " + sys.argv[2]) 104 | print("Current C = " + str(rbf_C)) 105 | print("Current Gamma = " + str(rbf_Gamma)) 106 | 107 | start_time = time.time() 108 | 109 | #try scaling... 110 | scaler = StandardScaler() 111 | training = scaler.fit_transform(training) 112 | testing = scaler.transform(testing) 113 | 114 | print("...training...") 115 | classifier = svm.SVC(C=rbf_C, gamma=rbf_Gamma, probability=make_probabilistic) 116 | classifier.fit(training, np.ravel(labels_train)) 117 | 118 | #store the classifier... 119 | print("...Saving to file...") 120 | symbol_classifier = SymbolClassifier(SymbolClassifier.TypeSVMRBF, classifier, classes_l, classes_dict, scaler, 121 | make_probabilistic) 122 | out_file = open(out_filename, 'wb') 123 | cPickle.dump(symbol_classifier, out_file, cPickle.HIGHEST_PROTOCOL) 124 | out_file.close() 125 | 126 | if not evaluate: 127 | #...finish early... 
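A note on the probab flag handled above: scikit-learn's SVC only exposes predict_proba() when it was constructed with probability=True, which is what classify_symbol_prob() in symbol_classifier.py later relies on. A small self-contained sketch on toy data (the C and gamma values are arbitrary, not the repository's tuned parameters):

    import numpy as np
    from sklearn import svm

    rng = np.random.RandomState(0)
    X = np.vstack([rng.randn(20, 2) - 2.0, rng.randn(20, 2) + 2.0])  # two toy classes
    y = np.array([0] * 20 + [1] * 20)

    clf = svm.SVC(C=10.0, gamma=0.5, probability=True)   # toy values, not tuned
    clf.fit(X, y)
    print(clf.predict_proba([[2.0, 2.0]]))   # per-class probability estimates

Training without probability=True and then calling predict_proba() raises an error, which is why the non-probabilistic models cannot be used for top-n classification.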
128 | end_time = time.time() 129 | total_elapsed = end_time - start_time 130 | 131 | print "Total Elapsed: " + str(total_elapsed) 132 | print("Finished!") 133 | return 134 | 135 | print("...Evaluating...") 136 | #...first, training error 137 | n_training_samples = np.size(training, 0) 138 | predicted = classifier.predict(training) 139 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_train, n_classes) 140 | accuracy = (float(total_correct) / float(n_training_samples)) * 100 141 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 142 | print "Training Samples: " + str(n_training_samples) 143 | print "Training Results" 144 | print "Accuracy\tClass Average\tClass STD " 145 | print str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0) 146 | 147 | 148 | #...then, testing error 149 | n_testing_samples = np.size(testing, 0) 150 | predicted = classifier.predict(testing) 151 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_test, n_classes) 152 | accuracy = (float(total_correct) / float(n_testing_samples)) * 100 153 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 154 | print "Testing Samples: " + str(n_testing_samples) 155 | print "Testing Results" 156 | print "Accuracy\tClass Average\tClass STD " 157 | print str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0) 158 | 159 | end_time = time.time() 160 | total_elapsed = end_time - start_time 161 | 162 | print "Total Elapsed: " + str(total_elapsed) 163 | 164 | print("Finished") 165 | 166 | main() 167 | -------------------------------------------------------------------------------- /src/symbol_classifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 
19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | import numpy as np 25 | from traceInfo import TraceInfo 26 | from mathSymbol import MathSymbol 27 | 28 | class SymbolClassifier: 29 | TypeRandomForest = 1 30 | TypeSVMLIN = 2 31 | TypeSVMRBF = 3 32 | 33 | def __init__(self, type, trained_classifier, classes_list, classes_dict, scaler=None, probabilistic=False): 34 | self.type = type 35 | self.trained_classifier = trained_classifier 36 | self.classes_list = classes_list 37 | self.classes_dict = classes_dict 38 | self.scaler = scaler 39 | self.probabilistic = probabilistic 40 | 41 | def predict(self, dataset): 42 | return self.trained_classifier.predict(dataset) 43 | 44 | def predict_proba(self, dataset): 45 | return self.trained_classifier.predict_proba(dataset) 46 | 47 | def get_raw_classes(self): 48 | return self.trained_classifier.classes_ 49 | 50 | def get_symbol_from_points(self, points_lists): 51 | 52 | traces = [] 53 | for trace_id, point_list in enumerate(points_lists): 54 | object_trace = TraceInfo(trace_id, point_list) 55 | 56 | traces.append(object_trace) 57 | 58 | # apply general trace pre processing... 59 | # 1) first step of pre processing: Remove duplicated points 60 | object_trace.removeDuplicatedPoints() 61 | 62 | # Add points to the trace... 63 | object_trace.addMissingPoints() 64 | 65 | # Apply smoothing to the trace... 66 | object_trace.applySmoothing() 67 | 68 | # it should not ... but ..... 69 | if object_trace.hasDuplicatedPoints(): 70 | # ...remove them! .... 71 | object_trace.removeDuplicatedPoints() 72 | 73 | new_symbol = MathSymbol(0, traces, '{Unknown}') 74 | 75 | # normalize size and locations 76 | new_symbol.normalize() 77 | 78 | return new_symbol 79 | 80 | def get_symbol_features(self, symbol): 81 | # get raw features 82 | features = symbol.getFeatures() 83 | 84 | # put them in python format 85 | mat_features = np.mat(features, dtype=np.float64) 86 | 87 | # automatically transform features 88 | if self.scaler is not None: 89 | mat_features = self.scaler.transform(mat_features) 90 | 91 | return mat_features 92 | 93 | def classify_points(self, points_lists): 94 | symbol = self.get_symbol_from_points(points_lists) 95 | 96 | return self.classify_symbol(symbol) 97 | 98 | def classify_points_prob(self, points_lists, top_n=None): 99 | symbol = self.get_symbol_from_points(points_lists) 100 | 101 | return self.classify_symbol_prob(symbol, top_n) 102 | 103 | def classify_symbol(self, symbol): 104 | features = self.get_symbol_features(symbol) 105 | 106 | predicted = self.trained_classifier.predict(features) 107 | 108 | return self.classes_list[predicted[0]] 109 | 110 | 111 | def classify_symbol_prob(self, symbol, top_n=None): 112 | features = self.get_symbol_features(symbol) 113 | 114 | try: 115 | predicted = self.trained_classifier.predict_proba(features) 116 | except: 117 | raise Exception("Classifier was not trained as probabilistic classifier") 118 | 119 | scores = sorted([(predicted[0, k], k) for k in range(predicted.shape[1])], reverse=True) 120 | 121 | tempo_classes = self.trained_classifier.classes_ 122 | n_classes = len(tempo_classes) 123 | if top_n is None or top_n > n_classes: 124 | top_n = n_classes 125 | 126 | confidences = [(self.classes_list[tempo_classes[scores[k][1]]], scores[k][0]) for k in range(top_n)] 127 | 128 | return confidences 129 | 130 | 131 | -------------------------------------------------------------------------------- /src/test_classifier.py: 
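classify_symbol_prob() above returns a list of (class_name, probability) pairs sorted by decreasing probability and truncated to top_n. The values below are made up and only illustrate the shape of the result and how the scripts later format it:

    # shape of the value returned by classify_symbol_prob(symbol, top_n=3)
    confidences = [('x', 0.81), ('X', 0.11), (r'\times', 0.05)]

    best_label, best_score = confidences[0]
    print(best_label + " ({0:.2f}%)".format(best_score * 100))   # x (81.00%)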
-------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | 25 | import sys 26 | import cPickle 27 | from symbol_classifier import SymbolClassifier 28 | 29 | def main(): 30 | # usage check... 31 | if len(sys.argv) < 2: 32 | print("Usage: python test_classifer.py classifier") 33 | print("Where") 34 | print("\tclassifier\t= Path to trained symbol classifier") 35 | return 36 | 37 | classifier_file = sys.argv[1] 38 | 39 | print("Loading classifier") 40 | 41 | in_file = open(classifier_file, 'rb') 42 | classifier = cPickle.load(in_file) 43 | in_file.close() 44 | 45 | if not isinstance(classifier, SymbolClassifier): 46 | print("Invalid classifier file!") 47 | return 48 | 49 | # get mapping 50 | classes_dict = classifier.classes_dict 51 | classes_l = classifier.classes_list 52 | n_classes = len(classes_l) 53 | 54 | # create test characters from points 55 | sample_x = [[(-0.8, -0.8), (0.1, 0.12), (0.8, 0.79)], [(-0.85, 0.79), (0.01, 0.005), (0.79, -0.83)]] 56 | sample_1 = [[(0.15, 0.7), (0.2, 1.0), (0.21, -1.2)], [(0.05, -1.19), (0.3, -1.25)]] 57 | sample_eq = [[(-1.5, -0.4), (1.5, -0.4)], [(-1.5, 0.4), (1.5, 0.4)]] 58 | 59 | # classify them 60 | class_x = classifier.classify_points(sample_x) 61 | class_1 = classifier.classify_points(sample_1) 62 | class_eq = classifier.classify_points(sample_eq) 63 | 64 | print("X classified as " + class_x) 65 | print("1 classified as " + class_1) 66 | print("= classified as " + class_eq) 67 | 68 | # now with confidence ... 69 | classes_x = classifier.classify_points_prob(sample_x, 3) 70 | classes_1 = classifier.classify_points_prob(sample_1, 3) 71 | classes_eq = classifier.classify_points_prob(sample_eq, 3) 72 | 73 | print("X top classes are: " + str(classes_x)) 74 | print("1 top classes are: " + str(classes_1)) 75 | print("= top classes are: " + str(classes_eq)) 76 | 77 | if __name__ == '__main__': 78 | main() -------------------------------------------------------------------------------- /src/test_classify_inkml.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2016 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 
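The samples above (sample_x, sample_1, sample_eq) show the input format expected by classify_points(): a list of strokes, each stroke a list of (x, y) tuples in drawing order. A hypothetical two-stroke plus sign in the same style:

    # hypothetical '+' drawn as a horizontal stroke followed by a vertical stroke
    sample_plus = [[(-0.9, 0.0), (0.0, 0.0), (0.9, 0.0)],
                   [(0.0, 0.9), (0.0, 0.0), (0.0, -0.9)]]

    # with a loaded SymbolClassifier instance, as in main() above:
    # print(classifier.classify_points(sample_plus))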
11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | 25 | import sys 26 | import cPickle 27 | from load_inkml import * 28 | from symbol_classifier import SymbolClassifier 29 | 30 | def main(): 31 | # usage check... 32 | if len(sys.argv) < 3: 33 | print("Usage: python test_classifer.py classifier inkml_file [save_svg] [svg_path]") 34 | print("Where") 35 | print("\tclassifier\t= Path to trained symbol classifier") 36 | print("\tinkml_file\t= Path to inkml file to load") 37 | print("\tsave_svg\t= Optional, save SVG file for each symbol") 38 | print("\t\t0 - No (Default)") 39 | print("\t\t1 - Save SVG after pre-processing") 40 | print("\t\t2 - Save SVG before pre-processing") 41 | print("\tsvg_path\t= Optional, path prefix used for saved SVG files") 42 | return 43 | 44 | classifier_file = sys.argv[1] 45 | inkml_file = sys.argv[2] 46 | 47 | if len(sys.argv) >= 4: 48 | try: 49 | save_sgv = int(sys.argv[3]) 50 | if save_sgv < 0 or save_sgv > 2: 51 | print("Invalid value for save_svg") 52 | return 53 | except: 54 | print("Invalid value for save_svg") 55 | return 56 | else: 57 | save_sgv = 0 58 | 59 | if len(sys.argv) >= 5: 60 | svg_path = sys.argv[4] 61 | else: 62 | svg_path = "" 63 | 64 | print("Loading classifier") 65 | 66 | in_file = open(classifier_file, 'rb') 67 | classifier = cPickle.load(in_file) 68 | in_file.close() 69 | 70 | if not isinstance(classifier, SymbolClassifier): 71 | print("Invalid classifier file!") 72 | return 73 | 74 | ground_truth_available = True 75 | try: 76 | symbols = load_inkml(inkml_file, True ) 77 | except: 78 | ground_truth_available = False 79 | try: 80 | symbols = load_inkml(inkml_file, False) 81 | except: 82 | print("Failed processing: " + inkml_file) 83 | return 84 | 85 | # run the classifier for each symbol in the file 86 | for symbol in symbols: 87 | trace_ids = [str(trace.id) for trace in symbol.traces] 88 | print("Symbol id: " + str(symbol.id)) 89 | print("=> Traces: " + ", ".join(trace_ids)) 90 | print("=> Ground Truth Class: " + symbol.truth) 91 | 92 | # classify 93 | main_class = classifier.classify_symbol(symbol) 94 | print("=> Predicted Class: " + main_class) 95 | 96 | top_5 = classifier.classify_symbol_prob(symbol, 5) 97 | top_5_desc = [class_name + " ({0:.2f}%)".format(prob * 100) for class_name, prob in top_5] 98 | print("=> Top 5 classes: " + ",".join(top_5_desc)) 99 | print("") 100 | 101 | if save_sgv == 2: 102 | symbol.swapPoints() 103 | 104 | if save_sgv >= 1: 105 | symbol.saveAsSVG(svg_path + "symbol_" + str(symbol.id) + ".svg") 106 | 107 | 108 | print("Finished") 109 | 110 | if __name__ == '__main__': 111 | main() -------------------------------------------------------------------------------- /src/train_adaboost.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 
6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | import sys 25 | import ctypes 26 | import time 27 | import numpy as np 28 | from dataset_ops import * 29 | from evaluation_ops import * 30 | 31 | #===================================================================== 32 | # this program takes as input a data set and through AdaBoost M1 it generates 33 | # a set of decision trees to classify data. the classifier is then 34 | # written to a file for further use. 35 | # 36 | # Created by: 37 | # - Kenny Davila (Dec 27, 2013) 38 | # Modified By: 39 | # - Kenny Davila (Dec 27, 2013) 40 | # - Kenny Davila (Jan 19, 2013) 41 | # - Dataset functions are now imported from external file 42 | # - Kenny Davila (Jan 27, 2013) 43 | # - Added mapping 44 | # - Save result to file 45 | # 46 | #===================================================================== 47 | 48 | c45_lib = ctypes.CDLL('./adaboost_c45.so') 49 | 50 | 51 | def ensemble_predict(classifier, data): 52 | n_samples = np.size(data, 0) 53 | predicted = np.zeros(n_samples) 54 | 55 | for k in range(n_samples): 56 | c_sample = data[k, :] 57 | c_sample_p = c_sample.ctypes.data_as(ctypes.POINTER(ctypes.c_double)) 58 | 59 | predicted[k] = c45_lib.boosted_c45_classify(classifier, c_sample_p) 60 | 61 | return predicted 62 | 63 | def main(): 64 | #usage check 65 | if len(sys.argv) < 4: 66 | print("Usage: python train_adaboost.py training_set testing_set rounds [mapped] [save_weights]") 67 | print("Where") 68 | print("\ttraining_set\t= Path to the file of the training set") 69 | print("\ttesting_set\t= Path to the file of the testing set") 70 | print("\trounds\t\t= Rounds to use for AdaBoost") 71 | print("\tmapped\t= Optional, reads class mapping if specified") 72 | print("\tsave_weights\t= Optional, save the weights used for boosting") 73 | return 74 | 75 | try: 76 | rounds = int(sys.argv[3]) 77 | if ( rounds < 1): 78 | print("Must be at least 1 round") 79 | return 80 | except: 81 | print("Invalid number of rounds") 82 | return 83 | 84 | #...load training data from file... 85 | #file_name = 'ds_2012.txt' 86 | file_name = sys.argv[1] 87 | #file_name = 'ds_test_2012.txt' 88 | training, labels_l, att_types = load_dataset(file_name) 89 | #...generate mapping... 90 | classes_dict, classes_l = get_label_mapping(labels_l) 91 | #...generate mapped labels... 
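train_adaboost.py (like train_c45.py below) loads ./adaboost_c45.so from the current working directory via ctypes.CDLL, so the C sources in src/ must be compiled into a shared library first. A small guard that fails with a clearer message; the build command in the comment is only an assumption, the repository does not ship a build script for it:

    import os
    import ctypes

    lib_path = './adaboost_c45.so'
    if not os.path.isfile(lib_path):
        # assumed build command, e.g.: gcc -shared -fPIC -O2 adaboost_c45.c -o adaboost_c45.so -lm
        raise IOError("adaboost_c45.so not found; build it from src/adaboost_c45.c first")

    c45_lib = ctypes.CDLL(lib_path)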
92 | labels_train = get_mapped_labels(labels_l, classes_dict) 93 | 94 | if len(sys.argv) >= 5: 95 | try: 96 | mapped = int(sys.argv[4]) > 0 97 | except: 98 | mapped = False 99 | 100 | #...used for the clustered training set 101 | if mapped: 102 | o_classes_l, o_classes_dict, class_mapping = load_label_mapping(file_name + ".mapping.txt") 103 | else: 104 | mapped =False 105 | 106 | if len(sys.argv) >= 6: 107 | try: 108 | save_weights = int(sys.argv[5]) > 0 109 | except: 110 | save_weights = False 111 | else: 112 | save_weights = False 113 | 114 | #...load testing data from file... 115 | #file_name = 'ds_test_2012.txt' 116 | file_name = sys.argv[2] 117 | #file_name = 'ds_2012.txt' 118 | testing, test_labels_l, att_types = load_dataset(file_name) 119 | #...generate mapped labels... 120 | if mapped: 121 | labels_test = get_mapped_labels(test_labels_l, o_classes_dict) 122 | else: 123 | labels_test = get_mapped_labels(test_labels_l, classes_dict) 124 | 125 | #... for learning.... 126 | n_samples = np.size(training, 0) 127 | n_atts = np.size(att_types, 0) 128 | n_classes = len( classes_l ) 129 | print "Samples : " + str(n_samples) 130 | print "Atts: " + str(n_atts) 131 | print "Classes: " + str(n_classes) 132 | 133 | #...create distribution 134 | init_value = 1.0 / float(n_samples) 135 | distrib_train = np.zeros( (n_samples, 1), dtype=np.float64 ) 136 | for i in range(n_samples): 137 | distrib_train[i, 0] = init_value 138 | 139 | #prepare to pass data to C-side 140 | samples_p = training.ctypes.data_as( ctypes.POINTER( ctypes.c_double ) ) 141 | 142 | labels_p = labels_train.ctypes.data_as( ctypes.POINTER( ctypes.c_int ) ) 143 | atts_p = att_types.ctypes.data_as( ctypes.POINTER( ctypes.c_int ) ) 144 | 145 | #for testing 146 | n_test_samples = np.size(testing, 0) 147 | 148 | start_time = time.time() 149 | 150 | #test construction... 151 | m_rounds = rounds 152 | ensemble = c45_lib.created_boosted_c45( samples_p, labels_p, atts_p, n_samples, n_atts, n_classes, m_rounds, 5, 1, "Fold #1") 153 | 154 | print "Saving..." 155 | c45_lib.boosted_c45_save(ensemble, sys.argv[1] + ".bc45") 156 | 157 | if save_weights: 158 | c45_lib.boosted_c45_save_training_weights(ensemble, sys.argv[1] + ".weights.bin", 0) 159 | c45_lib.boosted_c45_save_training_weights(ensemble, sys.argv[1] + ".weights.txt", 1) 160 | 161 | #test accuracy of tree... 162 | 163 | #...training... 164 | predicted = ensemble_predict(ensemble, training) 165 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_train, n_classes) 166 | accuracy = (float(total_correct) / float(n_samples)) * 100 167 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 168 | print "Training Samples: " + str(n_samples) 169 | print "Training Results" 170 | print "Accuracy\tClass Average\tClass STD " 171 | print str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0) 172 | 173 | #...testing... 
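The ctypes pointers above hand raw numpy buffers to the C library, which only works when the arrays are contiguous and of the expected element type (float64 for samples, weights and the per-row feature vectors; int for the mapped labels). A minimal sketch of the pattern, independent of the repository's C code:

    import numpy as np
    import ctypes

    row = np.ascontiguousarray(np.array([0.5, 1.25, -3.0], dtype=np.float64))
    row_p = row.ctypes.data_as(ctypes.POINTER(ctypes.c_double))

    # the C side sees a plain double*; indexing the pointer reads the buffer back
    print([row_p[k] for k in range(row.size)])   # [0.5, 1.25, -3.0]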
174 | n_testing_samples = np.size(testing, 0) 175 | predicted = ensemble_predict(ensemble, testing) 176 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_test, n_classes) 177 | accuracy = (float(total_correct) / float(n_testing_samples)) * 100 178 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 179 | print "Testing Samples: " + str(n_testing_samples) 180 | print "Testing Results" 181 | print "Accuracy\tClass Average\tClass STD " 182 | print str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0) 183 | 184 | 185 | c45_lib.release_boosted_c45( ensemble ) 186 | 187 | end_time = time.time() 188 | total_elapsed = end_time - start_time 189 | print "Total Elapsed: " + str(total_elapsed) 190 | print("Training: " + sys.argv[1]) 191 | print("Testing: " + sys.argv[2]) 192 | 193 | main() 194 | 195 | 196 | -------------------------------------------------------------------------------- /src/train_c45.py: -------------------------------------------------------------------------------- 1 | """ 2 | DPRL Math Symbol Recognizers 3 | Copyright (c) 2012-2014 Kenny Davila, Richard Zanibbi 4 | 5 | This file is part of DPRL Math Symbol Recognizers. 6 | 7 | DPRL Math Symbol Recognizers is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | DPRL Math Symbol Recognizers is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with DPRL Math Symbol Recognizers. If not, see . 19 | 20 | Contact: 21 | - Kenny Davila: kxd7282@rit.edu 22 | - Richard Zanibbi: rlaz@cs.rit.edu 23 | """ 24 | import ctypes 25 | import time 26 | import sys 27 | import numpy as np 28 | from dataset_ops import * 29 | from evaluation_ops import * 30 | 31 | #===================================================================== 32 | # this program takes as input a data set and through AdaBoost M1 it generates 33 | # a set of decision trees to classify data. The data set is divided in folds 34 | # and cross validation is then applied, the best classifier of all the 35 | # folds is then written to a file for further use. 
36 | # 37 | # Created by: 38 | # - Kenny Davila (Dec 27, 2013) 39 | # Modified By: 40 | # - Kenny Davila (Jan 16, 2013) 41 | # - Kenny Davila (Jan 19, 2013) 42 | # - Dataset functions are now imported from external file 43 | # 44 | #===================================================================== 45 | 46 | c45_lib = ctypes.CDLL('./adaboost_c45.so') 47 | 48 | 49 | def get_majority_class(labels, weights, classes): 50 | l_counts = {} 51 | n_samples = np.size( labels, 0) 52 | majority_class = None 53 | majority_w = 0 54 | for i in range(n_samples): 55 | label = labels[i, 0] 56 | 57 | if label in l_counts: 58 | l_counts[ label ] += weights[i, 0] 59 | else: 60 | l_counts[ label ] = weights[i, 0] 61 | 62 | if majority_class == None: 63 | majority_class = label 64 | majority_w = l_counts[ label ] 65 | else: 66 | if majority_w < l_counts[ label ]: 67 | majority_w = l_counts[ label ] 68 | majority_class = label 69 | 70 | """ 71 | for l in l_counts: 72 | print str(classes[l]) + " - " + str(l) + " - " + str(l_counts[l] ) 73 | """ 74 | 75 | return majority_class, majority_w 76 | 77 | 78 | def tree_predict(root, data): 79 | n_samples = np.size(data, 0) 80 | predicted = np.zeros(n_samples) 81 | 82 | for k in range(n_samples): 83 | c_sample = data[k, :] 84 | c_sample_p = c_sample.ctypes.data_as(ctypes.POINTER(ctypes.c_double)) 85 | 86 | predicted[k] = c45_lib.c45_node_evaluate(root, c_sample_p) 87 | 88 | return predicted 89 | 90 | 91 | def main(): 92 | if len(sys.argv) < 3: 93 | print("Usage: python train_c45.py training_set testing_set [mapped]") 94 | print("Where") 95 | print("\ttraining_set\t= Path to the file of the training set") 96 | print("\ttesting_set\t= Path to the file of the testing set") 97 | print("\tmapped\t= Optional, reads class mapping if specified") 98 | return 99 | 100 | #...load training data from file... 101 | #file_name = 'ds_2012.txt' 102 | file_name = sys.argv[1] 103 | training, labels_l, att_types = load_dataset(file_name) 104 | #...generate mapping... 105 | classes_dict, classes_l = get_label_mapping(labels_l) 106 | #...generate mapped labels... 107 | labels_train = get_mapped_labels(labels_l, classes_dict) 108 | 109 | #...check for class mapping... 110 | if len(sys.argv) >= 4: 111 | try: 112 | mapped = int(sys.argv[3]) > 0 113 | except: 114 | mapped = False 115 | 116 | #...will over-write the class list and class dict, and will extract the class mapping 117 | #...used for the clustered training set 118 | if mapped: 119 | o_classes_l, o_classes_dict, class_mapping = load_label_mapping(file_name + ".mapping.txt") 120 | 121 | else: 122 | mapped = False 123 | 124 | 125 | #...load testing data from file... 126 | #file_name = 'ds_test_2012.txt' 127 | file_name = sys.argv[2] 128 | testing, test_labels_l, att_types = load_dataset(file_name) 129 | #...generate mapped labels... 130 | if mapped: 131 | labels_test = get_mapped_labels(test_labels_l, o_classes_dict) 132 | else: 133 | labels_test = get_mapped_labels(test_labels_l, classes_dict) 134 | #print("UNCOMMENT MAPPING OF TESTING LABELS") 135 | 136 | 137 | 138 | print("Training: " + sys.argv[1]) 139 | print("Testing: " + sys.argv[2]) 140 | 141 | #for learning and testing.... 
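get_majority_class() above accumulates the weight of every label and keeps the heaviest class, which is later passed to the C side as the default prediction for the tree root. The same quantity can be checked with numpy on a toy distribution (sketch only, not the repository's code path):

    import numpy as np

    # toy mapped labels (column vector) and uniform weights, as in main() above
    labels = np.array([[0], [2], [2], [1]])
    weights = np.full((4, 1), 0.25)

    totals = np.bincount(labels.ravel(), weights=weights.ravel())
    majority_class = int(np.argmax(totals))
    print(majority_class)            # 2
    print(totals[majority_class])    # 0.5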
142 | n_samples = np.size(training, 0) 143 | n_test_samples = np.size(testing, 0) 144 | n_atts = np.size(att_types, 0) 145 | n_classes = len( classes_l ) 146 | print "Training Samples : " + str(n_samples) 147 | print "Testing Samples : " + str(n_test_samples) 148 | print "Atts: " + str(n_atts) 149 | print "Classes: " + str(n_classes) 150 | 151 | #...create distribution 152 | init_value = 1.0 / float(n_samples) 153 | distrib_train = np.zeros((n_samples, 1), dtype=np.float64) 154 | for i in range(n_samples): 155 | distrib_train[i, 0] = init_value 156 | 157 | #prepare to pass data to C-side 158 | samples_p = training.ctypes.data_as( ctypes.POINTER( ctypes.c_double ) ) 159 | 160 | labels_p = labels_train.ctypes.data_as( ctypes.POINTER( ctypes.c_int ) ) 161 | atts_p = att_types.ctypes.data_as( ctypes.POINTER( ctypes.c_int ) ) 162 | distrib_p = distrib_train.ctypes.data_as( ctypes.POINTER( ctypes.c_double ) ) 163 | 164 | 165 | 166 | #get majority class... 167 | majority_l, majority_w = get_majority_class(labels_train, distrib_train, classes_l) 168 | 169 | parent = None 170 | majority = int(majority_l) 171 | #print majority 172 | #print type(majority) 173 | 174 | start_time = time.time() 175 | 176 | #test construction... 177 | times = 1 178 | for i in range( times ): 179 | root = c45_lib.c45_tree_construct(samples_p, labels_p, atts_p, distrib_p, n_samples, n_atts, n_classes, majority, 5) 180 | 181 | """ 182 | #un-comment to test file 183 | print "to write ... " 184 | c45_lib.c45_save_to_file( root, "tree_2012.tree" ) 185 | print "to read ... " 186 | root = c45_lib.c45_load_from_file( "tree_2012.tree" ) 187 | #return 188 | """ 189 | 190 | #test accuracy of tree... 191 | predicted = tree_predict(root, training) 192 | 193 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_train, n_classes) 194 | accuracy = (float(total_correct) / float(n_samples)) * 100 195 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 196 | print "Results\tAccuracy\tClass Average\tClass STD " 197 | print "BP. Training\t" + str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0) 198 | 199 | #...testing... 200 | predicted = tree_predict(root, testing) 201 | 202 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_test, n_classes) 203 | accuracy = (float(total_correct) / float(n_test_samples)) * 100 204 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 205 | print "BP. Testing\t" + str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0) 206 | 207 | #do pruning 208 | c45_lib.c45_prune_tree.argtypes = (ctypes.c_void_p, ctypes.c_int, ctypes.c_double ) 209 | #c45_lib.c45_prune_tree( root, 0, 0.25 ) 210 | c45_lib.c45_prune_tree(root, 1, 0.25) 211 | 212 | #test accuracy of tree... 213 | #...training... 214 | predicted = tree_predict(root, training) 215 | 216 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_train, n_classes) 217 | accuracy = (float(total_correct) / float(n_samples)) * 100 218 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 219 | print "AP. Training\t" + str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0) 220 | 221 | #...testing... 
222 | predicted = tree_predict(root, testing) 223 | 224 | total_correct, counts_per_class, errors_per_class = compute_error_counts(predicted, labels_test, n_classes) 225 | accuracy = (float(total_correct) / float(n_test_samples)) * 100 226 | avg_accuracy, std_accuracy = get_average_class_accuracy(counts_per_class, errors_per_class, n_classes) 227 | print "AP. Testing\t" + str(accuracy) + "\t" + str(avg_accuracy * 100.0) + "\t" + str(std_accuracy * 100.0) 228 | 229 | 230 | c45_lib.c45_node_release( root, 1 ) 231 | 232 | end_time = time.time() 233 | total_elapsed = end_time - start_time 234 | mean_elapsed = total_elapsed / times 235 | print "Total Elapsed: " + str(total_elapsed) 236 | print "Mean Elapsed: " + str(mean_elapsed) 237 | print("") 238 | 239 | main() 240 | 241 | 242 | --------------------------------------------------------------------------------
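Every evaluation block in this section prints the same three numbers: overall accuracy, "Class Average" (the macro average of per-class accuracies) and "Class STD" (its spread across classes), computed by helpers in evaluation_ops.py that are not reproduced here. The sketch below only shows the assumed semantics of compute_error_counts / get_average_class_accuracy, not their actual implementation:

    import numpy as np

    def sketch_class_average(predicted, labels, n_classes):
        # assumed semantics: accuracy per class, then mean and std over classes
        per_class_acc = []
        for c in range(n_classes):
            mask = (labels == c)
            if mask.sum() > 0:
                per_class_acc.append(float((predicted[mask] == c).sum()) / mask.sum())
        return np.mean(per_class_acc), np.std(per_class_acc)

    predicted = np.array([0, 0, 1, 2, 2, 2])
    labels    = np.array([0, 1, 1, 2, 2, 0])
    print(sketch_class_average(predicted, labels, 3))   # approx (0.667, 0.236)

The macro average weights every class equally, so it drops below the overall accuracy whenever rare symbol classes are classified worse than common ones.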