├── .gitattributes ├── .gitignore ├── README.md ├── balance.py ├── categories.txt ├── classes.txt ├── classes ├── classes_2012.txt └── classes_2013.txt ├── data ├── CROHME_full_v2.zip └── CROHME_papers │ ├── CROHME_ICDAR_2011.pdf │ ├── CROHME_ICDAR_2013.pdf │ ├── CROHME_ICFHR_2012.pdf │ ├── CROHME_ICFHR_2014.pdf │ └── Thumbs.db ├── extract.py ├── extract_hog.py ├── extract_phog.py ├── histograms ├── all_labels_distribution.png ├── capital_letters_distribution.png ├── digits_distribution.png ├── greek_letters_distribution.png ├── labels_histogram.txt ├── lowercase_letters_distribution.png ├── math_symbols_distribution.png └── special_characters_distribution.png ├── one_hot.py ├── requirements.txt ├── visualization.png └── visualize.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | 4 | # Custom for Visual Studio 5 | *.cs diff=csharp 6 | 7 | # Standard to msysgit 8 | *.doc diff=astextplain 9 | *.DOC diff=astextplain 10 | *.docx diff=astextplain 11 | *.DOCX diff=astextplain 12 | *.dot diff=astextplain 13 | *.DOT diff=astextplain 14 | *.pdf diff=astextplain 15 | *.PDF diff=astextplain 16 | *.rtf diff=astextplain 17 | *.RTF diff=astextplain 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | data/CROHME_full_v2/ 3 | outputs/ 4 | HOG/ 5 | 6 | 7 | # Windows image file caches 8 | Thumbs.db 9 | ehthumbs.db 10 | 11 | # Folder config file 12 | Desktop.ini 13 | 14 | # Recycle Bin used on file shares 15 | $RECYCLE.BIN/ 16 | 17 | # Windows Installer files 18 | *.cab 19 | *.msi 20 | *.msm 21 | *.msp 22 | 23 | # Windows shortcuts 24 | *.lnk 25 | 26 | # ========================= 27 | # Operating System Files 28 | # ========================= 29 | 30 | # OSX 31 | # ========================= 32 | 33 | .DS_Store 
34 | .AppleDouble 35 | .LSOverride 36 | 37 | # Thumbnails 38 | ._* 39 | 40 | # Files that might appear in the root of a volume 41 | .DocumentRevisions-V100 42 | .fseventsd 43 | .Spotlight-V100 44 | .TemporaryItems 45 | .Trashes 46 | .VolumeIcon.icns 47 | 48 | # Directories potentially created on remote AFP share 49 | .AppleDB 50 | .AppleDesktop 51 | Network Trash Folder 52 | Temporary Items 53 | .apdisk 54 | *.pyc 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Abstract 2 | The CROHME datasets were originally designed for the _online handwriting_ recognition task. 3 | Besides the drawn traces themselves, inkml files also capture trace drawing time. 4 | For offline recognition we therefore need to extract a new feature map, namely matrices of pixel intensities. 5 | 6 | The following scripts will get you started with the _offline math symbol recognition_ task. 7 | 8 | 9 | ## Setup 10 | All code is compatible with Python **3.5.***. 11 | 12 | 1. Extract the contents of **_CROHME_full_v2.zip_** (found inside the **_data_** directory) before running any of the scripts below. 13 | 14 | 2. Install the specified dependencies with pip (the Python package manager) using the following shell command: 15 | ``` 16 | pip install -U -r requirements.txt 17 | ``` 18 | 19 | 20 | ## Scripts info 21 | 1. **_extract.py_** 22 | - Extracts trace groups from inkml files. 23 | - Converts extracted trace groups into images. Images are **square shaped** bitmaps with only black (value 0) and white (value 1) pixels. Black color denotes patterns (ROI). 24 | - Labels those images (according to inkml files). 25 | - Flattens images to one-dimensional vectors. 26 | - Converts labels to one-hot format. 27 | - Dumps training and testing sets separately into the **_outputs_** folder. 
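The dumped sets can be read back with a few lines of Python. The sketch below is illustrative and not part of the repository: `load_split`, `split_columns` and `decode_label` are hypothetical helper names, and it assumes each dumped file is a pickled list of dicts with `features` and `label` keys, which is the layout the scripts in this repository write.

```python
import pickle


def load_split(path):
    """Load one dumped split: a list of {'features': ..., 'label': ...} records."""
    with open(path, 'rb') as f:
        return pickle.load(f)


def split_columns(records):
    """Separate records into parallel lists of feature vectors and one-hot labels."""
    features = [rec['features'] for rec in records]
    labels = [rec['label'] for rec in records]
    return features, labels


def decode_label(one_hot_label, classes):
    """Map a one-hot vector back to its class name (position of the largest entry)."""
    values = list(one_hot_label)
    return classes[values.index(max(values))]
```

With the default layout, the training set would be at `outputs/train/train.pickle` and the class names, one per line, in `classes.txt`. Each feature vector has `BOX_SIZE * BOX_SIZE` entries and can be reshaped back into a square image.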
28 | 29 | **Command line arguments**: -b [BOX_SIZE] -d [DATASET_VERSION] -c [CATEGORY] -t [THICKNESS] 30 | 31 | **Example usage**: `python extract.py -b 50 -d 2011 2012 2013 -c digits lowercase_letters operators -t 5` 32 | 33 | **Caution**: The script doesn't work properly for images bigger than 200x200 (the reason is not yet known). 34 | 35 | 2. **_balance.py_** script balances the overall distribution of classes. 36 | 37 | **Command line arguments**: -b [BOX_SIZE] -ub [UPPER_BOUND][Optional] 38 | 39 | **Example usage**: `python balance.py -b 50 -ub 6000` 40 | 41 | 3. **_visualize.py_** script plots a single figure depicting a random batch of **extracted** data. 42 | 43 | **Command line arguments**: -b [BOX_SIZE] -n [N_SAMPLES] -c [COLUMNS] 44 | 45 | **Example usage**: `python visualize.py -b 50 -n 40 -c 8` 46 | 47 | **Sample Plot**: 48 | ![crohme_extractor_plot](https://user-images.githubusercontent.com/22115481/30137213-9c619b0a-9362-11e7-839a-624f08e606f7.png) 49 | 50 | 4. **_extract_hog.py_** script extracts **HoG features**. 51 | This script accepts 1 command line argument, namely **hog_cell_size**. 52 | **hog_cell_size** corresponds to the **pixels_per_cell** parameter of the **skimage.feature.hog** function. 53 | We use **skimage.feature.hog** to extract HoG features. 54 | Example of script execution: `python extract_hog.py 5` <-- pixels_per_cell=(5, 5) 55 | This script loads data previously dumped by **_extract.py_** and again dumps its outputs (train, test) separately. 56 | 57 | 58 | 5. **_extract_phog.py_** script extracts **PHoG features**. 59 | For PHoG features, HoG feature maps using different cell sizes are concatenated into a single feature vector. 60 | So this script takes an arbitrary number of **hog_cell_size** values (HoG features have to be extracted beforehand with **_extract_hog.py_**). 61 | Example of script execution: `python extract_phog.py 5 10 20` <-- loads HoGs with 5x5, 10x10 and 20x20 cell sizes respectively. 62 | 63 | 64 | 6.
**_histograms_** folder contains histograms representing **distribution of labels** based on different label categories. These diagrams help you better understand extracted data. 65 | 66 | 67 | ## Distribution of classes 68 | ![all_labels_distribution](https://cloud.githubusercontent.com/assets/22115481/26694312/413fb646-4707-11e7-943c-b8ecebd0c986.png) 69 | Labels were combined from **_train_** and **_test_** sets. 70 | -------------------------------------------------------------------------------- /balance.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This script makes class_infos more balanced. 3 | ''' 4 | import os 5 | import argparse 6 | import pickle 7 | import one_hot 8 | from random import shuffle 9 | import numpy as np 10 | from keras.preprocessing.image import ImageDataGenerator 11 | import matplotlib.pyplot as plt 12 | 13 | outputs_dir = 'outputs' 14 | train_out_dir = os.path.join(outputs_dir, 'train') 15 | test_out_dir = os.path.join(outputs_dir, 'test') 16 | 17 | ap = argparse.ArgumentParser() 18 | ap.add_argument('-b', '--box_size', required=True, help="Specify a length of square box side.") 19 | ap.add_argument('-ub', '--upper_bound', required=False, help="Specify the upper bound which essentially is the eventual distribution of each class after balance.") 20 | args = vars(ap.parse_args()) 21 | 22 | box_size = int(args.get('box_size')) 23 | # Balance ratio 24 | b_ratio = 1.0 25 | batch_size = 32 26 | 27 | # Load data 28 | with open(os.path.join(train_out_dir, 'train.pickle'), 'rb') as data: 29 | train = pickle.load(data) 30 | with open(os.path.join(test_out_dir, 'test.pickle'), 'rb') as data: 31 | test = pickle.load(data) 32 | 33 | print('Training set size:', len(train)) 34 | print('Testing set size:', len(test)) 35 | 36 | # Initialize keras image generator 37 | datagen = ImageDataGenerator(rotation_range=10, shear_range=0.1) 38 | 39 | # Load all class_infos that were extracted 40 | classes = 
[label.strip() for label in list(open('classes.txt', 'r'))] 41 | class_infos = [{'class': class_name, 'occurrences': 0} for class_name in classes] 42 | 43 | for train_sample in train: 44 | label = one_hot.decode(train_sample['label'], classes) 45 | # Find index of this label in class_infos list 46 | class_idx = classes.index(label) 47 | # Update the number of occurrences 48 | class_infos[class_idx]['occurrences'] += 1 49 | 50 | # Sort class_infos by occurrences 51 | class_infos = sorted(class_infos, key=lambda class_info: class_info['occurrences'], reverse=True) 52 | if not args.get('upper_bound'): 53 | max_occurances = class_infos[0]['occurrences'] 54 | else: 55 | max_occurances = int(args.get('upper_bound')) 56 | 57 | min_occurances = class_infos[len(class_infos)-1]['occurrences'] 58 | for class_info in class_infos: 59 | class_info['deviation'] = max_occurances - class_info['occurrences'] 60 | 61 | print('====================== Distribution of classes ======================') 62 | for label in class_infos: 63 | print('CLASS: {}; occurrences: {}; deviation: {}'.format(label['class'], label['occurrences'], label['deviation'])) 64 | print('Max occurrences:', max_occurances) 65 | print('Min occurrences:', min_occurances) 66 | print('=====================================================================') 67 | 68 | for class_info in class_infos: 69 | # Get one_hot representation of current class 70 | hot_class = one_hot.encode(class_info['class'], classes) 71 | # Calculate how many new samples have to be generated 72 | how_many_gen = int(round(class_info['deviation'] * b_ratio)) 73 | print('\nClass: {}; How many new samples to generate: {}'.format(class_info['class'], how_many_gen)) 74 | # Create images and labels for data representing current class 75 | images = np.asarray([train_rec['features'].reshape((box_size, box_size, 1)) for train_rec in train if np.array_equal(train_rec['label'], hot_class)]) 76 | labels = np.tile(hot_class, reps=(class_info['occurrences'], 
1)) 77 | 78 | # Generate new images 79 | # datagen.fit(images) 80 | new_data = [] 81 | for X_batch, y_batch in datagen.flow(images, labels, batch_size=batch_size): 82 | # # Plot newly generated images 83 | # n_cols = 4 84 | # n_rows = int(np.ceil(len(X_batch) / 4)) 85 | # figure, axis_arr = plt.subplots(n_rows, n_cols, figsize=(12, 4)) 86 | # for row in range(n_rows): 87 | # for col in range(n_cols): 88 | # axis_arr[row, col].imshow(X_batch[row*n_cols + col].reshape((box_size, box_size)), cmap='gray') 89 | # # Remove explicit axises 90 | # # axis_arr[row, col].axis('off') 91 | # plt.show() 92 | 93 | # If enough samples were generated 94 | if len(new_data) >= how_many_gen: 95 | break; 96 | for idx in range(len(X_batch)): 97 | new_record = {'features': X_batch[idx].flatten(), 'label': y_batch[idx]} 98 | new_data.append(new_record) 99 | 100 | print('CLASS: {}; NEW records: {};'.format(class_info['class'], len(new_data))) 101 | # Append newly generated data & shuffle given dataset 102 | train += new_data 103 | 104 | # Shuffle sets 105 | print('\nShuffling training set ...') 106 | shuffle(train) 107 | 108 | print('\nNEW Training set size:', len(train)) 109 | 110 | with open(os.path.join(train_out_dir, 'train.pickle'), 'wb') as f: 111 | pickle.dump(train, f, protocol=pickle.HIGHEST_PROTOCOL) 112 | print('Training data has been successfully dumped into', f.name) 113 | with open(os.path.join(test_out_dir, 'test.pickle'), 'wb') as f: 114 | pickle.dump(test, f, protocol=pickle.HIGHEST_PROTOCOL) 115 | print('Testing data has been successfully dumped into', f.name) 116 | 117 | print('\n\n# Like our facebook page @ https://www.facebook.com/mathocr/') 118 | -------------------------------------------------------------------------------- /categories.txt: -------------------------------------------------------------------------------- 1 | all: ! ( ) + , - . 
/ 0 1 2 3 4 5 6 7 8 9 = A B C E F G H I L M N P R S T V X Y [ \Delta \alpha \beta \cos \div \exists \forall \gamma \geq \gt \in \infty \int \lambda \ldots \leq \lim \log \lt \mu \neq \phi \pi \pm \prime \rightarrow \sigma \sin \sqrt \sum \tan \theta \times \{ \} ] a b c d e f g h i j k l m n o p q r s t u v w x y z | 2 | digits: 0 1 2 3 4 5 6 7 8 9 3 | operators: ( ) [ ] + - = 4 | lowercase_letters: a b c d e f g h i j k l m n o p q r s t u v w x y z 5 | uppercase_letters: A B C E F G H I L M N P R S T V X Y 6 | greek: \Delta \alpha \beta \gamma \lambda \mu \phi \pi \sigma \theta 7 | miscellaneous: \times = \in \pm ! \rightarrow \sin \cos \tan \lim \log \exists \forall \sqrt / \geq \gt \leq \lt \neq \div 8 | symbols: \int \infty \sqrt -------------------------------------------------------------------------------- /classes.txt: -------------------------------------------------------------------------------- 1 | ( 2 | ) 3 | + 4 | - 5 | 0 6 | 1 7 | 2 8 | 3 9 | 4 10 | 5 11 | 6 12 | 7 13 | 8 14 | 9 15 | = 16 | [ 17 | \infty 18 | \int 19 | \sqrt 20 | ] 21 | a 22 | b 23 | c 24 | d 25 | e 26 | f 27 | g 28 | h 29 | i 30 | j 31 | k 32 | l 33 | m 34 | n 35 | o 36 | p 37 | q 38 | r 39 | s 40 | t 41 | u 42 | v 43 | w 44 | x 45 | y 46 | z 47 | -------------------------------------------------------------------------------- /classes/classes_2012.txt: -------------------------------------------------------------------------------- 1 | 8 2 | \ldots 3 | ( 4 | \theta 5 | \in 6 | = 7 | \forall 8 | p 9 | \times 10 | e 11 | B 12 | k 13 | g 14 | \leq 15 | i 16 | m 17 | X 18 | \int 19 | [ 20 | \tan 21 | x 22 | \div 23 | r 24 | \beta 25 | d 26 | Y 27 | \infty 28 | \gamma 29 | t 30 | \geq 31 | 5 32 | ] 33 | b 34 | \exists 35 | \phi 36 | 6 37 | . 38 | F 39 | \pi 40 | 9 41 | \pm 42 | j 43 | n 44 | 0 45 | , 46 | \lim 47 | - 48 | \log 49 | A 50 | ! 
51 | \gt 52 | 7 53 | \{ 54 | ) 55 | z 56 | 2 57 | \sin 58 | \sqrt 59 | / 60 | \rightarrow 61 | C 62 | + 63 | \alpha 64 | y 65 | \lt 66 | a 67 | 3 68 | c 69 | \cos 70 | 1 71 | \neq 72 | 4 73 | f 74 | \} 75 | \sum 76 | -------------------------------------------------------------------------------- /classes/classes_2013.txt: -------------------------------------------------------------------------------- 1 | V 2 | ] 3 | \pi 4 | F 5 | X 6 | \lambda 7 | \forall 8 | 8 9 | \div 10 | p 11 | s 12 | = 13 | R 14 | \log 15 | b 16 | \sigma 17 | B 18 | j 19 | 9 20 | . 21 | h 22 | z 23 | \gt 24 | \infty 25 | 2 26 | y 27 | \sqrt 28 | L 29 | N 30 | f 31 | l 32 | - 33 | H 34 | \lim 35 | t 36 | 6 37 | [ 38 | g 39 | M 40 | \geq 41 | a 42 | \prime 43 | G 44 | \mu 45 | \Delta 46 | \} 47 | | 48 | \leq 49 | \tan 50 | I 51 | i 52 | u 53 | 7 54 | C 55 | S 56 | \theta 57 | o 58 | \int 59 | \exists 60 | r 61 | T 62 | \alpha 63 | e 64 | q 65 | x 66 | \phi 67 | \sum 68 | m 69 | 5 70 | ( 71 | \{ 72 | \sin 73 | + 74 | \in 75 | Y 76 | 3 77 | \lt 78 | ) 79 | 4 80 | \times 81 | / 82 | P 83 | , 84 | w 85 | \gamma 86 | k 87 | 1 88 | ! 
89 | \rightarrow 90 | \pm 91 | n 92 | 0 93 | A 94 | d 95 | \cos 96 | \beta 97 | \ldots 98 | \neq 99 | E 100 | v 101 | c 102 | -------------------------------------------------------------------------------- /data/CROHME_full_v2.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThomasLech/CROHME_extractor/ca2d9328b8ad4dabec94587d8f98f94851ff3a49/data/CROHME_full_v2.zip -------------------------------------------------------------------------------- /data/CROHME_papers/CROHME_ICDAR_2011.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThomasLech/CROHME_extractor/ca2d9328b8ad4dabec94587d8f98f94851ff3a49/data/CROHME_papers/CROHME_ICDAR_2011.pdf -------------------------------------------------------------------------------- /data/CROHME_papers/CROHME_ICDAR_2013.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThomasLech/CROHME_extractor/ca2d9328b8ad4dabec94587d8f98f94851ff3a49/data/CROHME_papers/CROHME_ICDAR_2013.pdf -------------------------------------------------------------------------------- /data/CROHME_papers/CROHME_ICFHR_2012.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThomasLech/CROHME_extractor/ca2d9328b8ad4dabec94587d8f98f94851ff3a49/data/CROHME_papers/CROHME_ICFHR_2012.pdf -------------------------------------------------------------------------------- /data/CROHME_papers/CROHME_ICFHR_2014.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThomasLech/CROHME_extractor/ca2d9328b8ad4dabec94587d8f98f94851ff3a49/data/CROHME_papers/CROHME_ICFHR_2014.pdf -------------------------------------------------------------------------------- /data/CROHME_papers/Thumbs.db: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThomasLech/CROHME_extractor/ca2d9328b8ad4dabec94587d8f98f94851ff3a49/data/CROHME_papers/Thumbs.db -------------------------------------------------------------------------------- /extract.py: -------------------------------------------------------------------------------- 1 | # Use like: 2 | # python extract.py -b 28 -d 2011 2012 2013 -c digits symbols -t 20 3 | 4 | import os 5 | import argparse 6 | # Parse XML 7 | import xml.etree.ElementTree as ET 8 | import numpy as np 9 | import cv2 10 | # One-hot encoder/decoder 11 | import one_hot 12 | # Load / dump data 13 | import pickle 14 | 15 | data_dir = os.path.join('data', 'CROHME_full_v2') 16 | # Construct the argument parser and parse the arguments 17 | version_choices = ['2011', '2012', '2013'] 18 | # Load categories from `categories.txt` file 19 | categories = [{'name': cat.split(':')[0], 'classes': cat.split(':')[1].split()} for cat in list(open('categories.txt', 'r'))] 20 | category_names = [cat['name'] for cat in categories] 21 | 22 | ap = argparse.ArgumentParser() 23 | ap.add_argument('-b', '--box_size', required=True, help="Specify the side length of the square box.") 24 | ap.add_argument('-d', '--dataset_version', required=True, help="Specify which dataset versions have to be extracted.", choices=version_choices, nargs='+') 25 | ap.add_argument('-c', '--category', required=True, help="Specify which symbol categories have to be extracted.", choices=category_names, nargs='+') 26 | ap.add_argument('-t', '--thickness', required=False, help="Specify the thickness of extracted patterns.", default=1, type=int) 27 | args = vars(ap.parse_args()) 28 | # Get classes that have to be extracted (based on categories selected by user) 29 | classes_to_extract = [] 30 | for cat_name in args.get('category'): 31 | cat_idx = category_names.index(cat_name) 32 | classes_to_extract += categories[cat_idx]['classes'] 33 | 34 | #
Extract INKML files 35 | all_inkml_files = [] 36 | for d_version in args.get('dataset_version'): 37 | # Choose directory containing data based on dataset version selected 38 | working_dir = os.path.join(data_dir, 'CROHME{}_data'.format(d_version)) 39 | # List folders found within working_dir 40 | for folder in os.listdir(working_dir): 41 | curr_folder = os.path.join(working_dir, folder) 42 | if os.path.isdir(curr_folder): 43 | # List files & folders found within folder 44 | content = os.listdir(curr_folder) 45 | # Filter inkml files and folders 46 | inkml_files = [os.path.join(curr_folder, inmkl_file) for inmkl_file in content if inmkl_file.endswith('.inkml')] 47 | sub_folders = [sub_folder for sub_folder in content if os.path.isdir(os.path.join(curr_folder, sub_folder))] 48 | 49 | print('FOLDER:', curr_folder) 50 | print('Numb. of inkml files:', len(inkml_files)) 51 | 52 | all_inkml_files += inkml_files 53 | for sub_folder in sub_folders: 54 | # Extract inkml files from within sub_folder 55 | sub_folder_path = os.path.join(curr_folder, sub_folder) 56 | inkml_files = [os.path.join(sub_folder_path, inmkl_file) for inmkl_file in os.listdir(sub_folder_path) if inmkl_file.endswith('.inkml')] 57 | all_inkml_files += inkml_files 58 | 59 | print('FOLDER:', sub_folder_path) 60 | print('Numb.
of inkml files:', len(inkml_files)) 61 | print('\n') 62 | 63 | # Filter inkml files that are used for training and those used for testing 64 | training_inkmls = [inkml_file for inkml_file in all_inkml_files if 'CROHME_training' in inkml_file or 'trainData' in inkml_file or 'TrainINKML' in inkml_file] 65 | testing_inkmls = [inkml_file for inkml_file in all_inkml_files if 'CROHME_testGT' in inkml_file or 'testDataGT' in inkml_file or ('TestINKMLGT' in inkml_file and 'Prime_in_row' not in inkml_file)] 66 | print('Number of training INKML files:', len(training_inkmls)) 67 | print('Number of testing INKML files:', len(testing_inkmls)) 68 | 69 | classes = [] 70 | def extract_trace_grps(inkml_file_abs_path): 71 | trace_grps = [] 72 | 73 | tree = ET.parse(inkml_file_abs_path) 74 | root = tree.getroot() 75 | doc_namespace = "{http://www.w3.org/2003/InkML}" 76 | 77 | # Find traceGroup wrapper - traceGroup wrapping important traceGroups 78 | traceGrpWrapper = root.findall(doc_namespace + 'traceGroup')[0] 79 | traceGroups = traceGrpWrapper.findall(doc_namespace + 'traceGroup') 80 | for traceGrp in traceGroups: 81 | latex_class = traceGrp.findall(doc_namespace + 'annotation')[0].text 82 | traceViews = traceGrp.findall(doc_namespace + 'traceView') 83 | # Get traceid of traces that refer to latex_class extracted above 84 | id_traces = [traceView.get('traceDataRef') for traceView in traceViews] 85 | # Construct pattern object 86 | trace_grp = {'label': latex_class, 'traces': []} 87 | 88 | # Find traces referenced by latex_class 89 | traces = [trace for trace in root.findall(doc_namespace + 'trace') if trace.get('id') in id_traces] 90 | # Extract trace coords 91 | for idx, trace in enumerate(traces): 92 | coords = [] 93 | for coord in trace.text.replace('\n', '').split(','): 94 | # Remove empty strings from coord list (e.g.
['', '-238', '-91'] -> [-238', '-91']) 95 | coord = list(filter(None, coord.split(' '))) 96 | # Unpack coordinates 97 | x, y = coord[:2] 98 | # print('{}, {}'.format(x, y)) 99 | if not float(x).is_integer(): 100 | # Count decimal places of x coordinate 101 | d_places = len(x.split('.')[-1]) 102 | # ! Get rid of decimal places (e.g. '13.5662' -> '135662') 103 | # x = float(x) * (10 ** len(x.split('.')[-1]) + 1) 104 | x = float(x) * 10000 105 | else: 106 | x = float(x) 107 | if not float(y).is_integer(): 108 | # Count decimal places of y coordinate 109 | d_places = len(y.split('.')[-1]) 110 | # ! Get rid of decimal places (e.g. '13.5662' -> '135662') 111 | # y = float(y) * (10 ** len(y.split('.')[-1]) + 1) 112 | y = float(y) * 10000 113 | else: 114 | y = float(y) 115 | 116 | # Cast x & y coords to integer 117 | x, y = round(x), round(y) 118 | coords.append([x, y]) 119 | trace_grp['traces'].append(coords) 120 | trace_grps.append(trace_grp) 121 | 122 | # print('Pattern: {};'.format(pattern)) 123 | return trace_grps 124 | 125 | 126 | def get_tracegrp_properties(trace_group): 127 | x_mins, y_mins, x_maxs, y_maxs = [], [], [], [] 128 | for trace in trace_group['traces']: 129 | 130 | x_min, y_min = np.amin(trace, axis=0) 131 | x_max, y_max = np.amax(trace, axis=0) 132 | x_mins.append(x_min) 133 | x_maxs.append(x_max) 134 | y_mins.append(y_min) 135 | y_maxs.append(y_max) 136 | # print('X_min: {}; Y_min: {}; X_max: {}; Y_max: {}'.format(min(x_mins), min(y_mins), max(x_maxs), max(y_maxs))) 137 | return min(x_mins), min(y_mins), max(x_maxs) - min(x_mins), max(y_maxs) - min(y_mins) 138 | 139 | def shift_trace_group(trace_grp, x_min, y_min): 140 | shifted_traces = [] 141 | for trace in trace_grp['traces']: 142 | shifted_traces.append(np.subtract(trace, [x_min, y_min])) 143 | return {'label': trace_grp['label'], 'traces': shifted_traces} 144 | 145 | def get_scale(width, height, box_size): 146 | ratio = width / height 147 | if ratio < 1.0: 148 | return box_size / height 149 | 
else: 150 | return box_size / width 151 | 152 | def rescale_trace_group(trace_grp, width, height, box_size): 153 | # Get scale - we will use this scale to interpolate trace_group so that it fits into (box_size X box_size) square box. 154 | scale = get_scale(width, height, box_size) 155 | rescaled_traces = [] 156 | for trace in trace_grp['traces']: 157 | # Interpolate contour and round coordinate values to int type (int32, because uint8 cannot hold coordinates above 255) 158 | rescaled_trace = np.around(np.asarray(trace) * scale).astype(dtype=np.int32) 159 | rescaled_traces.append(rescaled_trace) 160 | 161 | return {'label': trace_grp['label'], 'traces': rescaled_traces} 162 | 163 | def draw_trace(trace_grp, box_size, thickness): 164 | placeholder = np.ones(shape=(box_size, box_size), dtype=np.uint8) * 255 165 | for trace in trace_grp['traces']: 166 | for coord_idx in range(1, len(trace)): 167 | cv2.line(placeholder, tuple(trace[coord_idx - 1]), tuple(trace[coord_idx]), color=(0), thickness=thickness) 168 | return placeholder 169 | 170 | def convert_to_img(trace_group): 171 | # Extract command line arguments 172 | box_size = int(args.get('box_size')) 173 | thickness = int(args.get('thickness')) 174 | # Calculate Thickness Padding 175 | thickness_pad = (thickness - 1) // 2 176 | # Convert traces to np.array 177 | trace_group['traces'] = np.asarray(trace_group['traces']) 178 | # Get properties of a trace group 179 | x, y, width, height = get_tracegrp_properties(trace_group) 180 | 181 | # 1. Shift trace_group 182 | trace_group = shift_trace_group(trace_group, x_min=x, y_min=y) 183 | x, y, width, height = get_tracegrp_properties(trace_group) 184 | # 2.
Rescale trace_group 185 | trace_group = rescale_trace_group(trace_group, width, height, box_size=box_size-thickness_pad*2) 186 | x, y, width_r, height_r = get_tracegrp_properties(trace_group) 187 | # Shift trace_group by thickness padding 188 | trace_group = shift_trace_group(trace_group, x_min=-thickness_pad, y_min=-thickness_pad) 189 | # Center inside square box (box_size X box_size) 190 | margin_x = (box_size - (width_r + thickness_pad*2)) // 2 191 | margin_y = (box_size - (height_r + thickness_pad*2)) // 2 192 | trace_group = shift_trace_group(trace_group, x_min=-margin_x, y_min=-margin_y) 193 | image = draw_trace(trace_group, box_size, thickness=thickness) 194 | # Get pattern's width & height 195 | pat_width, pat_height = width_r + thickness_pad*2, height_r + thickness_pad*2 196 | 197 | # ! TESTS 198 | # cv2.imshow('image', image) 199 | # cv2.waitKey(0) 200 | if width < box_size and height < box_size: 201 | raise Exception('Trace group is too small.') 202 | if x != 0 or y != 0: 203 | raise Exception('Trace group was improperly shifted.') 204 | if pat_width == 0 or pat_height == 0: 205 | raise Exception('Some sides are 0 length.') 206 | if pat_width < box_size and pat_height < box_size: 207 | raise Exception('Both sides are < box_size.') 208 | if pat_width > box_size or pat_height > box_size: 209 | raise Exception('Some sides are > box_size.') 210 | return image 211 | 212 | damaged = 0 213 | # Extract TRAINING data 214 | train = [] 215 | for training_inkml in training_inkmls: 216 | print(training_inkml) 217 | trace_groups = extract_trace_grps(training_inkml) 218 | for trace_grp in trace_groups: 219 | label = trace_grp['label'] 220 | # Extract only classes selected by user (selecting categories) 221 | if label not in classes_to_extract: 222 | continue 223 | try: 224 | if label not in classes: 225 | classes.append(label) 226 | # Convert patterns to images 227 | image = convert_to_img(trace_grp) 228 | # Flatten image & construct pattern object 229 | pattern =
{'features': image.flatten(), 'label': label} 230 | train.append(pattern) 231 | except Exception as e: 232 | print(e) 233 | # Ignore damaged trace groups 234 | damaged += 1 235 | 236 | # Extract TESTING data 237 | test = [] 238 | for testing_inkml in testing_inkmls: 239 | print(testing_inkml) 240 | trace_groups = extract_trace_grps(testing_inkml) 241 | for trace_grp in trace_groups: 242 | label = trace_grp['label'] 243 | # Extract only classes selected by user (selecting categories) 244 | if label not in classes_to_extract: 245 | continue 246 | try: 247 | if label not in classes: 248 | classes.append(label) 249 | # Convert patterns to images 250 | image = convert_to_img(trace_grp) 251 | # Flatten image & construct pattern object 252 | pattern = {'features': image.flatten(), 'label': label} 253 | test.append(pattern) 254 | except Exception as e: 255 | print(e) 256 | # Ignore damaged trace groups 257 | damaged += 1 258 | 259 | # Sort classes alphabetically 260 | classes = sorted(classes) 261 | print('\nTraining set size:', len(train)) 262 | print('Testing set size:', len(test)) 263 | print('How many rejected trace groups:', damaged, '\n') 264 | 265 | # Data POST-processing 266 | # 1. Normalize features 267 | # 2. 
# Convert labels to one-hot format
for dataset in (train, test):
    for pat in dataset:
        # Divide by 255 and cast to uint8, so a pixel value of 255 becomes 1
        # and anything lower becomes 0
        pat['features'] = (pat['features'] / 255).astype(dtype=np.uint8)
        pat['label'] = one_hot.encode(pat['label'], classes)

# Dump extracted data
outputs_dir = 'outputs'
train_out_dir = os.path.join(outputs_dir, 'train')
test_out_dir = os.path.join(outputs_dir, 'test')
# Make directories if needed (this also creates the parent 'outputs' directory)
os.makedirs(train_out_dir, exist_ok=True)
os.makedirs(test_out_dir, exist_ok=True)

with open(os.path.join(train_out_dir, 'train.pickle'), 'wb') as f:
    pickle.dump(train, f, protocol=pickle.HIGHEST_PROTOCOL)
    print('Training data has been successfully dumped into', f.name)
with open(os.path.join(test_out_dir, 'test.pickle'), 'wb') as f:
    pickle.dump(test, f, protocol=pickle.HIGHEST_PROTOCOL)
    print('Testing data has been successfully dumped into', f.name)
# Save all labels in 'classes.txt' file
with open('classes.txt', 'w') as f:
    for r_class in classes:
        f.write(r_class + '\n')
    print('All classes that were extracted are listed in {} file.'.format(f.name))

print('\n# Like our facebook page @ https://www.facebook.com/mathocr/')

--------------------------------------------------------------------------------
/extract_hog.py:
--------------------------------------------------------------------------------
import sys
import os

import pickle

from skimage.feature import hog
import matplotlib.pyplot as plt

# Constants
outputs_rel_path = 'outputs'
train_dir = os.path.join(outputs_rel_path, 'train')
test_dir = os.path.join(outputs_rel_path, 'test')
validation_dir = os.path.join(outputs_rel_path, 'validation')


if __name__ == '__main__':

    # Parse command-line input
    print(' # Script flags:', '', '\n')

    # Parse 1st argument: the HOG cell size in pixels
    if len(sys.argv) < 2:
        print('\n + Usage:', sys.argv[0], '', '\n')
        sys.exit()

    try:
        hog_cell_size = int(sys.argv[1])
    except ValueError as e:
        print(e)
        sys.exit()

    # Load pickled data
    with open(os.path.join(train_dir, 'train.pickle'), 'rb') as train:
        print('Restoring training set ...')
        train_set = pickle.load(train)

    with open(os.path.join(test_dir, 'test.pickle'), 'rb') as test:
        print('Restoring test set ...')
        test_set = pickle.load(test)

    with open(os.path.join(validation_dir, 'validation.pickle'), 'rb') as validation:
        print('Restoring validation set ...')
        validation_set = pickle.load(validation)

    # **** HOG PARAMS ****
    orientations = 8
    pixels_per_cell = (hog_cell_size, hog_cell_size)
    cells_per_block = (1, 1)

    # **** Extract HOG features ****
    # The keyword is spelled 'visualise' in the pinned scikit-image 0.13.0;
    # later releases renamed it to 'visualize'.

    # TRAIN SET
    # Note: the slice below skips the first 40 training patterns.
    print('Extracting hog - TRAIN set ...')
    train_hog = []
    for pattern_enc in train_set[40:]:
        hog_enc = {
            'label': pattern_enc.get('label'),
            'features': hog(pattern_enc.get('features'),
                            orientations=orientations,
                            pixels_per_cell=pixels_per_cell,
                            cells_per_block=cells_per_block,
                            visualise=False,
                            block_norm='L2-Hys'),
        }
        train_hog.append(hog_enc)

    # TEST SET
    print('Extracting hog - TEST set ...')
    test_hog = []
    for pattern_enc in test_set:
        hog_enc = {
            'label': pattern_enc.get('label'),
            'features': hog(pattern_enc.get('features'),
                            orientations=orientations,
                            pixels_per_cell=pixels_per_cell,
                            cells_per_block=cells_per_block,
                            visualise=False,
                            block_norm='L2-Hys'),
        }
        test_hog.append(hog_enc)

    # VALIDATION SET
    print('Extracting hog - VALIDATION set ...')
    validation_hog = []
    for pattern_enc in validation_set:
        hog_enc = {
            'label': pattern_enc.get('label'),
            'features': hog(pattern_enc.get('features'),
                            orientations=orientations,
                            pixels_per_cell=pixels_per_cell,
                            cells_per_block=cells_per_block,
                            visualise=False,
                            block_norm='L2-Hys'),
        }
        validation_hog.append(hog_enc)

    # DUMP DATA
    print('\nDumping extracted data ...')
    # Make dirs if needed
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)
    os.makedirs(validation_dir, exist_ok=True)

    # File names carry the cell size, e.g. 'train_hog_5x5.pickle'
    cell_size_str = str(hog_cell_size) + 'x' + str(hog_cell_size)

    with open(os.path.join(train_dir, 'train_hog_' + cell_size_str + '.pickle'), 'wb') as train:
        pickle.dump(train_hog, train, protocol=pickle.HIGHEST_PROTOCOL)
        print('Data has been successfully dumped into', train.name)

    with open(os.path.join(test_dir, 'test_hog_' + cell_size_str + '.pickle'), 'wb') as test:
        pickle.dump(test_hog, test, protocol=pickle.HIGHEST_PROTOCOL)
        print('Data has been successfully dumped into', test.name)

    with open(os.path.join(validation_dir, 'validation_hog_' + cell_size_str + '.pickle'), 'wb') as validation:
        pickle.dump(validation_hog, validation, protocol=pickle.HIGHEST_PROTOCOL)
        print('Data has been successfully dumped into', validation.name)

--------------------------------------------------------------------------------
/extract_phog.py:
--------------------------------------------------------------------------------
import sys
import os

import pickle
import numpy as np


# Constants
outputs_rel_path = 'outputs'
train_dir = os.path.join(outputs_rel_path, 'train')
test_dir = os.path.join(outputs_rel_path, 'test')
validation_dir = os.path.join(outputs_rel_path, 'validation')


if __name__ == '__main__':

    # Parse command-line input
    print(' # Script flags:', '', '', '', '...')

    # Every argument is one HOG cell size
    if len(sys.argv) < 2:
        print('\n + Usage:', sys.argv[0], '', '', '', '...')
        sys.exit()

    # Cell sizes stay strings because they are only used to build file names
    hog_cell_sizes = sys.argv[1:]

    # Load TRAIN HOGs
    train_dif_hogs = []
    for hog_cell_size in hog_cell_sizes:
        with open(os.path.join(train_dir, 'train_hog_' + hog_cell_size + 'x' + hog_cell_size + '.pickle'), 'rb') as train:
            train_dif_hogs.append(pickle.load(train))

    train_phog_size = len(train_dif_hogs[0])
    train_phog = []
    for hog_enc_idx in range(train_phog_size):

        # **** MERGE all different hog representations of EACH PATTERN ****
        PHOG_features = []
        for train_dif_hog in train_dif_hogs:
            PHOG_features += train_dif_hog[hog_enc_idx]['features'].tolist()

        PHOG_enc = {'label': train_dif_hogs[0][hog_enc_idx]['label'], 'features': np.asarray(PHOG_features)}
        train_phog.append(PHOG_enc)

    # ' **** MERGE all different hog representations of EACH PATTERN **** '
    # train_phog = [{'label': train_dif_hogs[0][hog_enc_idx]['label'], \
    #     'phog': np.asarray([train_dif_hog[hog_enc_idx]['hog'] for train_dif_hog in train_dif_hogs], dtype=np.float32)} \
    #     for hog_enc_idx in range(len(train_dif_hogs[0]))]

    # Load TEST HOGs
    test_dif_hogs = []
    for hog_cell_size in hog_cell_sizes:
        with open(os.path.join(test_dir, 'test_hog_' + hog_cell_size + 'x' + hog_cell_size + '.pickle'), 'rb') as test:
            test_dif_hogs.append(pickle.load(test))

    test_phog_size = len(test_dif_hogs[0])
    test_phog = []
    for hog_enc_idx in range(test_phog_size):

        # **** MERGE all different hog representations of EACH PATTERN ****
        PHOG_features = []
        for test_dif_hog in test_dif_hogs:
            PHOG_features += test_dif_hog[hog_enc_idx]['features'].tolist()

        PHOG_enc = {'label': test_dif_hogs[0][hog_enc_idx]['label'], 'features': np.asarray(PHOG_features)}
        test_phog.append(PHOG_enc)

    # Load VALIDATION HOGs
    validation_dif_hogs = []
    for hog_cell_size in hog_cell_sizes:
        with open(os.path.join(validation_dir, 'validation_hog_' + hog_cell_size + 'x' + hog_cell_size + '.pickle'), 'rb') as validation:
            validation_dif_hogs.append(pickle.load(validation))

    validation_phog_size = len(validation_dif_hogs[0])
    validation_phog = []
    for hog_enc_idx in range(validation_phog_size):

        # **** MERGE all different hog representations of EACH PATTERN ****
        PHOG_features = []
        for validation_dif_hog in validation_dif_hogs:
            PHOG_features += validation_dif_hog[hog_enc_idx]['features'].tolist()

        PHOG_enc = {'label': validation_dif_hogs[0][hog_enc_idx]['label'], 'features': np.asarray(PHOG_features)}
        validation_phog.append(PHOG_enc)

    # DUMP DATA
    print('\nDumping extracted data ...')

    # Join the cell sizes with '_' for the output file names
    phog_cell_sizes_str = '_'.join(hog_cell_sizes)

    with open(os.path.join(train_dir, 'train_phog_' + phog_cell_sizes_str + '.pickle'), 'wb') as train:
        pickle.dump(train_phog, train, protocol=pickle.HIGHEST_PROTOCOL)
        print('Data has been successfully dumped into', train.name)

    with open(os.path.join(test_dir, 'test_phog_' + phog_cell_sizes_str + '.pickle'), 'wb') as test:
        pickle.dump(test_phog, test, protocol=pickle.HIGHEST_PROTOCOL)
        print('Data has been successfully dumped into', test.name)

    with open(os.path.join(validation_dir, 'validation_phog_' + phog_cell_sizes_str + '.pickle'), 'wb') as validation:
        pickle.dump(validation_phog, validation, protocol=pickle.HIGHEST_PROTOCOL)
        print('Data has been successfully dumped into', validation.name)

--------------------------------------------------------------------------------
/histograms/all_labels_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThomasLech/CROHME_extractor/ca2d9328b8ad4dabec94587d8f98f94851ff3a49/histograms/all_labels_distribution.png

--------------------------------------------------------------------------------
/histograms/capital_letters_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThomasLech/CROHME_extractor/ca2d9328b8ad4dabec94587d8f98f94851ff3a49/histograms/capital_letters_distribution.png

--------------------------------------------------------------------------------
/histograms/digits_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThomasLech/CROHME_extractor/ca2d9328b8ad4dabec94587d8f98f94851ff3a49/histograms/digits_distribution.png

--------------------------------------------------------------------------------
/histograms/greek_letters_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThomasLech/CROHME_extractor/ca2d9328b8ad4dabec94587d8f98f94851ff3a49/histograms/greek_letters_distribution.png

--------------------------------------------------------------------------------
/histograms/labels_histogram.txt:
--------------------------------------------------------------------------------
\exists : 11
\forall : 26
\lambda : 31
\in : 33
\Delta : 56
\mu : 63
\sigma : 66
\gt : 80
I : 96
\phi : 106
G : 114
\gamma : 116
\prime : 127
H : 127
o : 130
Y : 134
l : 137
M : 140
\} : 147
\{ : 147
\lt : 148
T : 150
V : 158
\neq : 160
P : 165
S : 166
w : 166
N : 173
\ldots : 181
L : 185
E : 196
\geq : 198
\pm : 225
\div : 231
R : 236
/ : 243
\leq : 275
h : 291
s : 291
[ : 293
] : 293
F : 307
q : 347
v : 348
! : 368
A : 382
X : 383
u : 383
g : 389
B : 398
j : 445
C : 469
\infty : 508
\lim : 531
| : 536
r : 541
\rightarrow : 545
\log : 619
\beta : 623
m : 627
p : 649
\pi : 680
\tan : 744
\sum : 753
. : 760
f : 766
, : 790
\theta : 792
\int : 804
e : 804
\alpha : 819
t : 907
\cos : 923
k : 927
\times : 984
7 : 1025
8 : 1045
6 : 1085
9 : 1091
c : 1235
5 : 1257
\sin : 1314
i : 1478
d : 1481
z : 1833
b : 2137
0 : 2509
4 : 2536
\sqrt : 2607
y : 2687
n : 3075
a : 3267
3 : 3785
= : 4800
) : 5420
( : 5425
x : 8109
+ : 8877
2 : 9455
1 : 9460
- : 12328

--------------------------------------------------------------------------------
/histograms/lowercase_letters_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThomasLech/CROHME_extractor/ca2d9328b8ad4dabec94587d8f98f94851ff3a49/histograms/lowercase_letters_distribution.png

--------------------------------------------------------------------------------
/histograms/math_symbols_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThomasLech/CROHME_extractor/ca2d9328b8ad4dabec94587d8f98f94851ff3a49/histograms/math_symbols_distribution.png

--------------------------------------------------------------------------------
/histograms/special_characters_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThomasLech/CROHME_extractor/ca2d9328b8ad4dabec94587d8f98f94851ff3a49/histograms/special_characters_distribution.png

--------------------------------------------------------------------------------
/one_hot.py:
--------------------------------------------------------------------------------
import numpy as np
# Encode to one_hot format
def encode(class_name, classes):

    one_hot = np.zeros(shape=(len(classes)), dtype=np.int8)
    class_index = classes.index(class_name)
    one_hot[class_index] = 1

    return one_hot

# Decode from one_hot format to string
def decode(one_hot, classes):
    index = one_hot.argmax()
    return classes[index]

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.12.1+mkl
scikit-image==0.13.0
pickleshare==0.7.4
matplotlib==2.0.0

--------------------------------------------------------------------------------
/visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThomasLech/CROHME_extractor/ca2d9328b8ad4dabec94587d8f98f94851ff3a49/visualization.png

--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
import sys
import os
import argparse

import pickle

import math
import random
# Image processing
from skimage.feature import hog
# Data visualization
import matplotlib.pyplot as plt
# One-hot encoder/decoder
import one_hot

# Constants
outputs_rel_path = 'outputs'
train_dir = os.path.join(outputs_rel_path, 'train')
test_dir = os.path.join(outputs_rel_path, 'test')

ap = argparse.ArgumentParser()
ap.add_argument('-b', '--box_size', required=True, help="Specify the side length of the square box.")
ap.add_argument('-n', '--n_samples', required=True, help="Specify the number of samples to show.")
ap.add_argument('-c', '--columns', required=True, help="Specify the number of columns.")
args = vars(ap.parse_args())

# Load pickled data
with open(os.path.join(train_dir, 'train.pickle'), 'rb') as train:
    print('Restoring training set ...')
    train_set = pickle.load(train)

with open(os.path.join(test_dir, 'test.pickle'), 'rb') as test:
    print('Restoring test set ...')
    test_set = pickle.load(test)

# Extract command-line arguments
box_size = int(args.get('box_size'))
n_samples = int(args.get('n_samples'))
n_cols = int(args.get('columns'))

# Load classes
with open('classes.txt', 'r') as classes_file:
    classes = classes_file.read().split()

# Compute the number of rows from the number of columns and samples provided by the user
rows_numb = math.ceil(n_samples / n_cols)

# Instantiate a figure to plot samples on.
# squeeze=False keeps axis_arr two-dimensional even for a single row or column.
figure, axis_arr = plt.subplots(rows_numb, n_cols, figsize=(12, 4), squeeze=False)
figure.patch.set_facecolor((0.91, 0.91, 0.91))

sample_id = 0
for row in range(rows_numb):
    for col in range(n_cols):

        if sample_id < n_samples:
            # Generate a random sample id; randrange excludes the upper bound,
            # so the id is always a valid index into test_set
            random_id = random.randrange(len(test_set))
            sample = test_set[random_id]
            # Decode from one-hot format to string
            label = one_hot.decode(sample['label'], classes)

            axis_arr[row, col].imshow(sample['features'].reshape((box_size, box_size)), cmap='gray')
            axis_arr[row, col].set_title('Class: "' + label + '"', size=13, y=1.2)

            # Hide the axes
            axis_arr[row, col].axis('off')

        sample_id += 1

# Adjust spacing between subplots and window border
figure.subplots_adjust(hspace=1.4, wspace=0.2)
plt.savefig('visualization.png')

# Show the figure window
plt.show()

--------------------------------------------------------------------------------
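The grid arithmetic and random sampling used by visualize.py can be sketched on their own, without matplotlib or the pickled data. This is a minimal illustration under stated assumptions, not code from the repository: the helper names `grid_shape`, `grid_cells` and `pick_indices` are hypothetical. It shows the row count obtained with `math.ceil`, the row-major (row, col) position of each sample, and index sampling with `randrange`, which never returns the upper bound.

```python
import math
import random


def grid_shape(n_samples, n_cols):
    """Return (rows, cols) needed to show n_samples in n_cols columns."""
    if n_samples < 1 or n_cols < 1:
        raise ValueError('n_samples and n_cols must be positive')
    return math.ceil(n_samples / n_cols), n_cols


def grid_cells(n_samples, n_cols):
    """Return the (row, col) position of every sample in row-major order."""
    return [divmod(sample_id, n_cols) for sample_id in range(n_samples)]


def pick_indices(population_size, n_samples, seed=None):
    """Pick n_samples random indices, each in range(population_size)."""
    rng = random.Random(seed)
    # randrange excludes population_size itself, so every index is valid
    return [rng.randrange(population_size) for _ in range(n_samples)]


if __name__ == '__main__':
    print(grid_shape(n_samples=10, n_cols=4))   # (3, 4)
    print(grid_cells(n_samples=5, n_cols=2))    # [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0)]
    print(pick_indices(population_size=3, n_samples=5, seed=0))
```

Sampling is done with replacement here, as in the script, so the same pattern may appear more than once in a figure; `random.sample` would be the stdlib alternative if distinct samples were wanted.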