├── rev ├── spec │ └── __init__.py ├── text │ ├── localizer.py │ ├── __init__.py │ ├── classifier.py │ ├── feature_extractor.py │ ├── rectutils.py │ └── ocr.py ├── mark │ ├── __init__.py │ └── classifier.py ├── __init__.py ├── utils.py ├── chart.py └── textbox.py ├── docs ├── README.md └── _config.yml ├── examples ├── vega1.png ├── vega1-debug.png ├── vega1-mask.png ├── vega1-pred1-debug.png ├── vega1-texts.csv └── vega1-pred1-texts.csv ├── .gitignore ├── testing.py ├── scripts ├── rate_mark_type_classifier.py ├── rate_text_role_classifier.py ├── run_feature_extraction.py └── run_text_role_classifier.py └── README.md /rev/spec/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rev/text/localizer.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | Initial Text 2 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-tactile -------------------------------------------------------------------------------- /rev/text/__init__.py: -------------------------------------------------------------------------------- 1 | from .classifier import TextClassifier 2 | -------------------------------------------------------------------------------- /rev/mark/__init__.py: -------------------------------------------------------------------------------- 1 | from rev.mark.classifier import MarkClassifier 2 | -------------------------------------------------------------------------------- /rev/__init__.py: -------------------------------------------------------------------------------- 1 | from .chart import Chart 2 | from .chart import chart_dataset -------------------------------------------------------------------------------- /examples/vega1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uwdata/rev/master/examples/vega1.png -------------------------------------------------------------------------------- /examples/vega1-debug.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uwdata/rev/master/examples/vega1-debug.png -------------------------------------------------------------------------------- /examples/vega1-mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uwdata/rev/master/examples/vega1-mask.png -------------------------------------------------------------------------------- /examples/vega1-pred1-debug.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uwdata/rev/master/examples/vega1-pred1-debug.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Setuptools distribution folder. 5 | /dist/ 6 | 7 | # Python egg metadata, regenerated from source files by setuptools. 8 | /*.egg-info 9 | 10 | data/* 11 | models/* 12 | 13 | data 14 | models 15 | 16 | out.* -------------------------------------------------------------------------------- /examples/vega1-texts.csv: -------------------------------------------------------------------------------- 1 | id,x,y,width,height,text,type 2 | 0,5.0,119.0,22.0,188.0,Number of Records,y-axis-title 3 | 1,31.0,5.5,19.0,17.0,45,y-axis-label 4 | 2,31.0,49.5,19.0,17.0,40,y-axis-label 5 | 3,31.0,94.5,19.0,17.0,35,y-axis-label 6 | 4,31.0,138.5,19.0,17.0,30,y-axis-label 7 | 5,31.0,183.5,19.0,17.0,25,y-axis-label 8 | 6,31.0,227.5,19.0,17.0,20,y-axis-label 9 | 7,31.0,272.5,19.0,17.0,15,y-axis-label 10 | 8,31.0,316.5,19.0,17.0,10,y-axis-label 11 | 9,39.0,361.5,11.0,17.0,5,y-axis-label 12 | 10,39.0,405.5,11.0,17.0,0,y-axis-label 13 | 11,204.5,441.0,105.0,24.0,BIN(yield),x-axis-title 14 | 12,447.5,417.0,19.0,21.0,70,x-axis-label 15 | 13,414.5,417.0,19.0,21.0,65,x-axis-label 16 | 14,380.5,417.0,19.0,21.0,60,x-axis-label 17 | 15,347.5,417.0,19.0,21.0,55,x-axis-label 18 | 16,314.5,417.0,19.0,21.0,50,x-axis-label 19 | 17,280.5,417.0,19.0,21.0,45,x-axis-label 20 | 18,247.5,417.0,19.0,21.0,40,x-axis-label 21 | 19,214.5,417.0,19.0,21.0,35,x-axis-label 22 | 20,180.5,417.0,19.0,21.0,30,x-axis-label 23 | 21,147.5,417.0,19.0,21.0,25,x-axis-label 24 | 22,114.5,417.0,19.0,21.0,20,x-axis-label 25 | 23,80.5,417.0,19.0,21.0,15,x-axis-label 26 | 24,47.5,417.0,19.0,21.0,10,x-axis-label 27 | -------------------------------------------------------------------------------- /examples/vega1-pred1-texts.csv: -------------------------------------------------------------------------------- 1 | id,x,y,width,height,text,type 2 | 0,5.0,119.0,22.0,188.0,Number of Records,y-axis-title 3 | 1,31.0,5.5,19.0,17.0,45,y-axis-label 4 | 2,31.0,49.5,19.0,17.0,40,y-axis-label 5 | 3,31.0,94.5,19.0,17.0,35,y-axis-label 6 | 4,31.0,138.5,19.0,17.0,30,y-axis-label 7 | 5,31.0,183.5,19.0,17.0,25,y-axis-label 8 | 6,31.0,227.5,19.0,17.0,20,y-axis-label 9 | 7,31.0,272.5,19.0,17.0,15,y-axis-label 10 | 8,31.0,316.5,19.0,17.0,10,y-axis-label 11 | 9,39.0,361.5,11.0,17.0,5,y-axis-label 12 | 10,39.0,405.5,11.0,17.0,0,y-axis-label 13 | 11,204.5,441.0,105.0,24.0,BIN(yield),x-axis-title 14 | 12,447.5,417.0,19.0,21.0,70,x-axis-label 15 | 13,414.5,417.0,19.0,21.0,65,x-axis-label 16 | 14,380.5,417.0,19.0,21.0,60,x-axis-label 17 | 15,347.5,417.0,19.0,21.0,55,x-axis-label 18 | 16,314.5,417.0,19.0,21.0,50,x-axis-label 19 | 17,280.5,417.0,19.0,21.0,45,x-axis-label 20 | 18,247.5,417.0,19.0,21.0,40,x-axis-label 21 | 19,214.5,417.0,19.0,21.0,35,x-axis-label 22 | 20,180.5,417.0,19.0,21.0,30,x-axis-label 23 | 21,147.5,417.0,19.0,21.0,25,x-axis-label 24 | 22,114.5,417.0,19.0,21.0,20,x-axis-label 25 | 23,80.5,417.0,19.0,21.0,15,x-axis-label 26 | 24,47.5,417.0,19.0,21.0,10,x-axis-label 27 | -------------------------------------------------------------------------------- /testing.py: -------------------------------------------------------------------------------- 1 | import rev 2 | 3 | # load a chart 4 | chart = rev.Chart('examples/vega1.png', text_from=0) 5 | 6 | ########################################################################################## 7 | # mark type classifier 8 | # import rev.mark 9 | # mark_clf = rev.mark.MarkClassifier() 10 | # print mark_clf.classify([chart]) 11 | 12 | ########################################################################################## 13 | # feature extraction (single) 14 | # import rev.text 15 | # text_features = rev.text.feature_extractor.from_chart(chart) 16 | # print text_features 17 | 18 | ########################################################################################## 19 | # feature extraction (corpus) 20 | # import rev.text 21 | # text_features = rev.text.feature_extractor.from_chart(chart) 22 | # print text_features 23 | 24 | ########################################################################################## 25 | # text role classification 26 | # import rev.text 27 | # text_clf = rev.text.TextClassifier('default') 28 | # text_type_preds = text_clf.classify(chart) 29 | # print text_type_preds 30 | 31 | ########################################################################################## 32 | # training text role classifier 33 | import pandas as pd 34 | import rev.text 35 | data = pd.read_csv('data/features_all.csv') 36 | features = data[rev.text.classifier.VALID_COLUMNS] 37 | types = data['type'] 38 | 39 | text_clf = rev.text.TextClassifier() 40 | text_clf.train(features, types) 41 | # text_clf.save_model('models/text_role_classifier/text_type_classifier_new.pkl') 42 | 43 | -------------------------------------------------------------------------------- /scripts/rate_mark_type_classifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to compute accuracy. 3 | 4 | Usage: 5 | rate_mark_type_classifier.py MODEL_NAME TEST_FILE [--show] 6 | rate_mark_type_classifier.py (-h | --help) 7 | rate_mark_type_classifier.py --version 8 | 9 | Options: 10 | --show Show incorrect predictions. 11 | -h --help Show this screen. 12 | --version Show version. 13 | 14 | Example: 15 | python scripts/rate_mark_type_classifier.py revision models/mark_classifier/revision/test.txt 16 | """ 17 | from docopt import docopt 18 | import numpy as np 19 | 20 | from sklearn.metrics import classification_report, accuracy_score, confusion_matrix 21 | 22 | import rev.mark 23 | import rev.utils as u 24 | 25 | 26 | def report_accuracy(y_true, y_pred, labels): 27 | cm = confusion_matrix(y_true, y_pred, labels) 28 | 29 | print 'accuracy: ', accuracy_score(y_true, y_pred) 30 | u.print_cm(cm, labels=labels) 31 | print classification_report(y_true, y_pred, target_names=labels) 32 | 33 | 34 | def main(): 35 | model_name = args['MODEL_NAME'] 36 | test_file = args['TEST_FILE'] 37 | 38 | # loading model 39 | mark_clf = rev.mark.MarkClassifier(model_name) 40 | 41 | # loading testing data 42 | test_data = np.genfromtxt(test_file, dtype=None) 43 | # test_data = test_data[0:100] 44 | test_charts = [rev.Chart(item[0]) for item in test_data] 45 | true_types = [mark_clf.categories[item[1]] for item in test_data] 46 | 47 | # classifying and evaluating 48 | pred_types = mark_clf.classify(test_charts) 49 | report_accuracy(true_types, pred_types, mark_clf.categories) 50 | 51 | 52 | if __name__ == '__main__': 53 | args = docopt(__doc__, version='1.0') 54 | if args['--show']: 55 | SHOW = True 56 | 57 | main() 58 | -------------------------------------------------------------------------------- /scripts/rate_text_role_classifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to compute accuracy. 3 | 4 | Usage: 5 | rate_text_role_classifier.py features FEATURES_CSV [--group_size GROUP_SIZE] 6 | rate_text_role_classifier.py (-h | --help) 7 | rate_text_role_classifier.py --version 8 | 9 | Options: 10 | --group_size GROUP_SIZE Number of elements per group. -1 means not sampling [Default: -1]. 11 | -h --help Show this screen. 12 | --version Show version. 13 | 14 | Examples: 15 | python scripts/rate_text_role_classifier.py features data/features_academic.csv 16 | python scripts/rate_text_role_classifier.py features data/features_quarts.csv 17 | python scripts/rate_text_role_classifier.py features data/features_vega.csv 18 | """ 19 | from docopt import docopt 20 | import pandas as pd 21 | import numpy as np 22 | 23 | import rev.text 24 | 25 | 26 | def sample_group(data, samples_per_group): 27 | def sampling(group, num_samples): 28 | if num_samples < 0 or num_samples > len(group): 29 | num_samples = len(group) 30 | return group.sample(num_samples) 31 | 32 | return data.groupby('type').apply(sampling, num_samples=samples_per_group) 33 | 34 | 35 | def main(): 36 | if args['features']: 37 | features_file = args['FEATURES_CSV'] 38 | samples_per_group = int(args['--group_size']) 39 | 40 | # loading model 41 | text_clf = rev.text.TextClassifier() 42 | 43 | # loading test data 44 | np.random.seed(seed=0) 45 | data = pd.read_csv(features_file) 46 | data = sample_group(data, samples_per_group) 47 | test_features = data[rev.text.classifier.VALID_COLUMNS] 48 | true_types = data['type'] 49 | 50 | # cross-validation 51 | text_clf.cross_validation(test_features, true_types, cv=5) 52 | 53 | return 54 | 55 | 56 | if __name__ == '__main__': 57 | args = docopt(__doc__, version='1.0') 58 | main() 59 | -------------------------------------------------------------------------------- /scripts/run_feature_extraction.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to run text box feature extraction 3 | 4 | Usage: 5 | run_feature_extraction.py single INPUT_PNG OUTPUT_CSV [--from_bbs=FROM] 6 | run_feature_extraction.py multiple INPUT_LIST_TXT OUTPUT_CSV [--from_bbs FROM] 7 | run_feature_extraction.py (-h | --help) 8 | run_feature_extraction.py --version 9 | 10 | Options: 11 | --from_bbs FROM 0: from bbs.csv [default: 0] 12 | 1: from pred1-bbs.csv 13 | 2: from pred2-bbs.csv 14 | -h --help Show this screen. 15 | --version Show version. 16 | 17 | Examples: 18 | python scripts/run_feature_extraction.py single examples/vega1.png out.csv 19 | python scripts/run_feature_extraction.py multiple data/academic.txt out.csv 20 | """ 21 | from docopt import docopt 22 | import pandas as pd 23 | 24 | from joblib import Parallel, delayed 25 | import multiprocessing 26 | 27 | from rev import Chart, chart_dataset 28 | from rev.text.feature_extractor import from_chart 29 | 30 | 31 | def sample_group(data, samples_per_group): 32 | def sampling(group, num_samples): 33 | if num_samples < 0 or num_samples > len(group): 34 | num_samples = len(group) 35 | return group.sample(num_samples) 36 | 37 | return data.groupby('type').apply(sampling, num_samples=samples_per_group) 38 | 39 | 40 | def main(): 41 | from_bbs = int(args['--from_bbs']) 42 | if args['single']: 43 | image_name = args['INPUT_PNG'] 44 | 45 | chart = Chart(image_name, text_from=from_bbs) 46 | text_features = from_chart(chart) 47 | text_features.to_csv(args['OUTPUT_CSV'], index=False) 48 | 49 | if args['multiple']: 50 | chart_list = args['INPUT_LIST_TXT'] 51 | 52 | # run in parallel 53 | num_cores = multiprocessing.cpu_count() 54 | results = Parallel(n_jobs=num_cores, verbose=1, backend='multiprocessing')( 55 | delayed(from_chart)(chart) for chart in chart_dataset(chart_list, from_bbs)) 56 | 57 | df = pd.concat(results) 58 | df.to_csv(args['OUTPUT_CSV'], index=False) 59 | 60 | 61 | if __name__ == '__main__': 62 | args = docopt(__doc__, version='1.0') 63 | main() 64 | -------------------------------------------------------------------------------- /rev/mark/classifier.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import caffe 3 | 4 | models = { 5 | 'revision': { 6 | 'path': 'models/mark_classifier/revision/', 7 | 'model_file': 'deploy.prototxt', 8 | 'weights_file': 'model_iter_50000.caffemodel', 9 | 'mean_file': 'ilsvrc_2012_mean.npy', 10 | 'categories_file': 'categories.txt' 11 | }, 12 | 'charts5cats': { 13 | 'path': 'models/mark_classifier/charts5cats/', 14 | 'model_file': 'deploy.prototxt', 15 | 'weights_file': 'model_iter_50000.caffemodel', 16 | 'mean_file': 'ilsvrc_2012_mean.npy', 17 | 'categories_file': 'categories.txt' 18 | } 19 | } 20 | 21 | 22 | class MarkClassifier: 23 | def __init__(self, model_name=None): 24 | model = models[model_name if model_name is not None else 'charts5cats'] 25 | print model['path']+model['model_file'] 26 | print model['path']+'snapshots/'+model['weights_file'] 27 | print model['path']+model['mean_file'] 28 | self._net = caffe.Classifier( 29 | model_file=model['path']+model['model_file'], 30 | pretrained_file=model['path']+'snapshots/'+model['weights_file'], 31 | mean=np.load(model['path']+model['mean_file']).mean(1).mean(1), 32 | channel_swap=(2, 1, 0), 33 | raw_scale=255) 34 | 35 | self._categories = np.genfromtxt(model['path']+model['categories_file'], dtype=None, encoding=None) 36 | 37 | @property 38 | def categories(self): 39 | return self._categories 40 | 41 | def train(self): 42 | pass 43 | 44 | def classify(self, charts): 45 | def chunks(l, n): 46 | # Yield successive n-sized chunks from l. 47 | for i in xrange(0, len(l), n): 48 | yield l[i:i + n] 49 | 50 | all_predictions = [] 51 | for block_charts in chunks(charts, 100): 52 | inputs = [caffe.io.load_image(chart.filename) for chart in block_charts] 53 | 54 | predictions = self._net.predict(inputs, True) 55 | predictions = predictions.argmax(1) 56 | all_predictions.append(predictions) 57 | 58 | predictions = np.hstack(all_predictions) 59 | 60 | return self._categories[predictions] 61 | -------------------------------------------------------------------------------- /rev/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def rgba2rgb(img): 5 | """ 6 | Convert the rgba image into a rgb with white background. 7 | :param img: 8 | :return: 9 | """ 10 | arr = img.astype('float') / 255. 11 | alpha = arr[..., -1] 12 | channels = arr[..., :-1] 13 | out = np.empty_like(channels) 14 | 15 | background = (1, 1, 1) 16 | for ichan in range(channels.shape[-1]): 17 | out[..., ichan] = np.clip( 18 | (1 - alpha) * background[ichan] + alpha * channels[..., ichan], 19 | a_min=0, a_max=1) 20 | 21 | return (out * 255.0).astype('uint8') 22 | 23 | 24 | def ttoi(t): 25 | """ 26 | Converts tuples values to tuple of rounded integers. 27 | """ 28 | return tuple(map(int, map(round, t))) 29 | 30 | 31 | def print_cm(cm, labels, hide_zeroes=False, hide_diagonal=False, hide_threshold=None): 32 | """pretty print for confusion matrixes""" 33 | columnwidth = max([len(x) for x in labels]+[5]) # 5 is value length 34 | empty_cell = " " * columnwidth 35 | # Print header 36 | print " " + empty_cell, 37 | for label in labels: 38 | print "%{0}s".format(columnwidth) % label, 39 | print 40 | # Print rows 41 | for i, label1 in enumerate(labels): 42 | print " %{0}s".format(columnwidth) % label1, 43 | for j in range(len(labels)): 44 | cell = "%{0}.1f".format(columnwidth) % cm[i, j] 45 | if hide_zeroes: 46 | cell = cell if float(cm[i, j]) != 0 else empty_cell 47 | if hide_diagonal: 48 | cell = cell if i != j else empty_cell 49 | if hide_threshold: 50 | cell = cell if cm[i, j] > hide_threshold else empty_cell 51 | print cell, 52 | print 53 | 54 | 55 | def create_predicted1_bbs(chart): 56 | """ 57 | Create an empty bbs file with empty texts and types. 58 | :param chart: 59 | :return: 60 | """ 61 | import rev 62 | ifn = chart.filename.replace('.png', '-texts.csv') 63 | text_boxes = rev.chart.load_texts(ifn) 64 | 65 | # cleaning type field 66 | for text_box in text_boxes: 67 | text_box._type = '' 68 | 69 | ofn = chart.filename.replace('.png', '-pred1-texts.csv') 70 | rev.chart.save_texts(text_boxes, ofn) 71 | -------------------------------------------------------------------------------- /scripts/run_text_role_classifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to predict type of text tole and update *-texts.csv file. 3 | 4 | Usage: 5 | run_text_box_classifier.py train FEATURES_CSV OUTPUT_MODEL_PLK 6 | run_text_box_classifier.py single INPUT_PNG [--from_bbs=FROM] [--with_post] [--pad=PAD] 7 | run_text_box_classifier.py multiple INPUT_LIST_TXT [--from_bbs=FROM] [--with_post] [--pad=PAD] 8 | run_text_box_classifier.py (-h | --help) 9 | run_text_box_classifier.py --version 10 | 11 | Options: 12 | --from_bbs FROM 1: from predicted1-bbs.csv [default: 1] 13 | 2: from predicted2-bbs.csv 14 | --with_post Boolean, run post processing? 15 | --pad PAD Add padding to boxes [default: 0] 16 | -h --help Show this screen. 17 | --version Show version. 18 | 19 | Examples: 20 | # train text role classifier 21 | python scripts/run_text_role_classifier.py train data/features_all.csv out.plk 22 | 23 | # run text role classifier in a chart to test 24 | python scripts/run_text_role_classifier.py single examples/vega1.png 25 | 26 | # run text role classifier in multiple charts 27 | python scripts/run_text_role_classifier.py multiple data/academic.txt 28 | """ 29 | 30 | from docopt import docopt 31 | from joblib import Parallel, delayed 32 | import multiprocessing 33 | import pandas as pd 34 | 35 | import pandas as pd 36 | 37 | import rev.text 38 | from rev import Chart, chart_dataset 39 | 40 | 41 | def __classify(clf, chart, with_post=False, draw_debug=False, pad=0, save=False): 42 | print clf.classify(chart, with_post, draw_debug, pad, save) 43 | 44 | 45 | if __name__ == '__main__': 46 | args = docopt(__doc__, version='1.0') 47 | draw_debug = True 48 | 49 | if args['train']: 50 | features_file = args['FEATURES_CSV'] 51 | output_file = args['OUTPUT_MODEL_PLK'] 52 | 53 | data = pd.read_csv(features_file) 54 | features = data[rev.text.classifier.VALID_COLUMNS] 55 | types = data['type'] 56 | 57 | text_clf = rev.text.TextClassifier() 58 | text_clf.train(features, types) 59 | text_clf.save_model(output_file) 60 | 61 | if args['single']: 62 | # clf = bc.load_classifier() 63 | image_name = args['INPUT_PNG'] 64 | from_bbs = int(args['--from_bbs']) 65 | with_post = args['--with_post'] 66 | pad = int(args['--pad']) 67 | print with_post 68 | 69 | chart = Chart(image_name, text_from=from_bbs) 70 | text_clf = rev.text.TextClassifier('default') 71 | pred_types = text_clf.classify(chart, with_post, draw_debug, pad, save=True) 72 | print pred_types 73 | 74 | if args['multiple']: 75 | chart_list = args['INPUT_LIST_TXT'] 76 | from_bbs = int(args['--from_bbs']) 77 | with_post = args['--with_post'] 78 | pad = int(args['--pad']) 79 | 80 | text_clf = rev.text.TextClassifier('default') 81 | # run in parallel 82 | num_cores = multiprocessing.cpu_count() 83 | results = Parallel(n_jobs=num_cores, verbose=1, backend='multiprocessing')( 84 | delayed(__classify)(text_clf, chart, with_post, draw_debug, pad, True) 85 | for chart in chart_dataset(chart_list, from_bbs)) 86 | 87 | # print 'Total boxes : %d' % sum(results) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reverse-Engineering Visualizations (REV) 2 | 3 | REV([paper](http://idl.cs.washington.edu/papers/reverse-engineering-vis/)) is a text analysis pipeline which detects text elements in a chart, classifies their role (e.g., chart title, x-axis label, y-axis title, etc.), and recovers the text content using optical character recognition. It also uses a Convolutional Neural Network for mark type classification. Using the identified text elements and graphical mark type, it infers the encoding specification of an input chart image. 4 | 5 | Our pipeline consist of the following steps: 6 | 7 | * Text localization and recognition 8 | * Text role classification 9 | * Mark type classification 10 | * Specification induction 11 | 12 | ## Installation 13 | You first need to download our code: 14 | ```sh 15 | git clone git@github.com:uwdata/rev.git 16 | ``` 17 | 18 | Then, download the data and modes are in the following 19 | [link](https://drive.google.com/open?id=1Bg9hyxlt2szXj6CBWIIt3yInIjKEqPFx). 20 | You have to unzip the files in the project folder. 21 | 22 | 23 | ### Dependencies 24 | * conda create -n rev python=2.7 opencv=3.4 pandas scikit-image scikit-learn 25 | docopt joblib 26 | * caffe 1 (https://caffe.berkeleyvision.org/installation.html) 27 | 28 | ## Using our API 29 | In this example we assume that we have the text elements from a chart. For a given image (`image.png`), text elements should be provided in a CSV file named `image-texts.csv` with the following format. 30 | 31 | ```CSV 32 | id,x,y,width,height,text,type 33 | 1,30,5,19,17,"45",y-axis-label 34 | ... 35 | ``` 36 | Check file `examples/vega1-texts.csv` for an example. 37 | 38 | Text `type` can be: `title`, `x-axis-title`, `x-axis-label`, `y-axis-title`, 39 | `y-axis-label`, `legend-title`, `legend-label`, and `text-label`. 40 | 41 | However, in most cases we do not have access to the text elements, then, we can infer them using our pipeline. Each step of our pipeline can be run independently. 42 | 43 | 44 | 45 | #### Text localization and recognition 46 | 47 | #### Text role classification 48 | ```python 49 | import rev.text 50 | 51 | # feature extraction (single) 52 | text_features = rev.text.feature_extractor.from_chart(chart) 53 | print text_features 54 | 55 | # feature extraction (corpus) 56 | text_features = rev.text.feature_extractor.from_chart(chart) 57 | print text_features 58 | 59 | # text role classification 60 | text_clf = rev.text.TextClassifier('default') 61 | text_type_preds = text_clf.classify(chart) 62 | print text_type_preds 63 | 64 | # training text role classifier 65 | import pandas as pd 66 | data = pd.read_csv('data/features_all.csv') 67 | features = data[rev.text.classifier.VALID_COLUMNS] 68 | types = data['type'] 69 | 70 | text_clf = rev.text.TextClassifier() 71 | text_clf.train(features, types) 72 | text_clf.save_model('out.pkl') 73 | ``` 74 | 75 | #### Mark type classification 76 | ```python 77 | import rev.mark 78 | mark_clf = rev.mark.MarkClassifier() 79 | print mark_clf.classify([chart]) 80 | ``` 81 | 82 | #### Specification induction 83 | 84 | 85 | ## Scripts 86 | Some usefull script to reproduce results from paper: 87 | ```shell 88 | # code to rate the text-role classifier (Table 4 from paper) 89 | python scripts/rate_text_role_classifier.py features data/features_academic.csv 90 | python scripts/rate_text_role_classifier.py features data/features_quartz.csv 91 | python scripts/rate_text_role_classifier.py features data/features_vega.csv 92 | 93 | # script to extract features 94 | python scripts/run_feature_extraction.py multiple data/academic.txt out.csv 95 | 96 | # train text-role classifier 97 | python scripts/run_text_role_classifier.py train data/features_all.csv out.plk 98 | 99 | # run text-role classifier in a chart to test 100 | python scripts/run_text_role_classifier.py single examples/vega1.png 101 | 102 | # run text-role classifier in multiple charts 103 | python scripts/run_text_role_classifier.py multiple data/academic.txt 104 | ``` 105 | 106 | -------------------------------------------------------------------------------- /rev/chart.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import pandas as pd 3 | import os.path 4 | import numpy as np 5 | 6 | from .textbox import TextBox 7 | from . import utils as u 8 | 9 | 10 | ''' 11 | The attribute 'text_from' means: 12 | 0: read ground truth data: 13 | 'chart-texts.csv' 14 | 'chart-mask.png' 15 | 'chart-debug.png' 16 | 1: read text from 'prediction 1': 17 | i.e. ground truth boxes and output of text role classification and output of OCR. 18 | 'chart-pred1-texts.csv' 19 | 'chart-pred1-mask.png' 20 | 'chart-pred1-debug.png' 21 | 2: read text from 'prediction 2', 22 | i.e., output of text localization and output of text role classification, and output of OCR. 23 | 'chart-pred2-texts.csv' 24 | 'chart-pred2-mask.png' 25 | 'chart-pred2-debug.png' 26 | ''' 27 | prefixes = {0: '', 1: '-pred1', 2: '-pred2'} 28 | 29 | 30 | class Chart(object): 31 | def __init__(self, fn, _id=None, text_from=0): 32 | self._fn = fn 33 | self._id = _id 34 | self._text_from = text_from 35 | 36 | self._image = None 37 | self._texts = None 38 | self._mark_type = None 39 | 40 | self._prefix = prefixes[text_from] 41 | 42 | @property 43 | def filename(self): 44 | return self._fn 45 | 46 | @property 47 | def text_boxes_filename(self): 48 | return self._fn.replace('.png', self._prefix + '-texts.csv') 49 | 50 | @property 51 | def id(self): 52 | return self._id 53 | 54 | @property 55 | def text_from(self): 56 | return self._text_from 57 | 58 | @property 59 | def image(self): 60 | if self._image is None: 61 | print 62 | self._image = cv2.imread(self._fn, cv2.IMREAD_UNCHANGED) 63 | if self._image.dtype == 'uint16': 64 | self._image = (self._image / 256.0).astype('uint8') 65 | else: 66 | self._image = self._image.astype('uint8') 67 | 68 | if len(self._image.shape) == 2: 69 | self._image = cv2.merge((self._image, self._image, self._image)) 70 | elif self._image.shape[2] == 4: 71 | self._image = u.rgba2rgb(self._image) 72 | cv2.imwrite(self._fn, self._image) 73 | 74 | return self._image 75 | 76 | @property 77 | def text_boxes(self): 78 | if self._texts is None: 79 | fn = self._fn.replace('.png', self._prefix + '-texts.csv') 80 | self._texts = load_texts(fn) 81 | 82 | return self._texts 83 | 84 | @property 85 | def mask(self, force_to_create=False): 86 | fn = self._fn.replace('.png', self._prefix + '-mask.png') 87 | if not os.path.exists(fn) or force_to_create: 88 | h, w, _ = self.image.shape 89 | mask = create_mask((h, w), self.texts) 90 | cv2.imwrite(fn, mask) 91 | 92 | return cv2.imread(fn, cv2.IMREAD_GRAYSCALE) 93 | 94 | # @property 95 | # def pixel_mask(self, force_to_create=False): 96 | # fn = self._fn.replace('.png', 'predicted-mask.png') 97 | # if not os.path.exists(fn) or force_to_create: 98 | # # from mask_predictor import predict_mask 99 | # # predict_mask(self) 100 | # pass 101 | # 102 | # return cv2.imread(fn, cv2.IMREAD_GRAYSCALE) 103 | 104 | @property 105 | def debug(self): 106 | fn = self._fn.replace('.png', self._prefix + '-debug.png') 107 | return cv2.imread(fn, cv2.IMREAD_COLOR) 108 | 109 | 110 | def save_text_boxes(self): 111 | save_texts(self.text_boxes, self.text_boxes_filename) 112 | 113 | 114 | def create_mask((h, w), texts): 115 | mask = np.zeros((h, w), np.uint8) 116 | for t in texts: 117 | cv2.rectangle(mask, u.ttoi(t.p1), u.ttoi(t.p2), 255, thickness=-1) 118 | 119 | return mask 120 | 121 | 122 | def load_texts(fn): 123 | df = pd.read_csv(fn) 124 | df.replace(np.nan, '', inplace=True) 125 | 126 | # force text column to be string 127 | df.text = df.text.astype(str) 128 | 129 | texts = [] 130 | for idx, row in df.iterrows(): 131 | text = TextBox(row.id, row.x, row.y, row.width, row.height, row.type, row.text) 132 | texts.append(text) 133 | 134 | return texts 135 | 136 | 137 | def save_texts(text_boxes, fn): 138 | rows = [] 139 | for t in text_boxes: 140 | rows.append(t.to_dict()) 141 | df = pd.DataFrame(rows) 142 | df = df[rows[0].keys()] 143 | df.to_csv(fn, index=False) 144 | 145 | 146 | def chart_dataset(chart_list, from_bbs=0): 147 | corpus = os.path.splitext(os.path.basename(chart_list))[0] 148 | with open(chart_list) as f: 149 | for idx, line in enumerate(f): 150 | yield Chart(line.strip(), _id='%s-%04d' % (corpus, idx), text_from=from_bbs) 151 | -------------------------------------------------------------------------------- /rev/text/classifier.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import pandas as pd 3 | import sys 4 | import os 5 | 6 | from sklearn.externals import joblib 7 | from sklearn.model_selection import cross_val_predict 8 | from sklearn import svm 9 | import sklearn.metrics as metrics 10 | from sklearn.pipeline import make_pipeline 11 | from sklearn.preprocessing import StandardScaler 12 | from sklearn.utils.multiclass import unique_labels 13 | 14 | from .. import utils as u 15 | from . import feature_extractor 16 | 17 | # import warnings 18 | # warnings.filterwarnings("ignore") 19 | 20 | model_files = { 21 | 'default': 'models/text_role_classifier/text_type_classifier.pkl', 22 | 'testing': 'models/text_role_classifier/text_type_classifier_new.pkl' 23 | } 24 | 25 | # valid features 26 | VALID_COLUMNS = [ 27 | # 'fig_id', 28 | 'vscore', 'hscore', 29 | 'vrange', 'hrange', 30 | 'vfreq', 'hfreq', 31 | # 'x', 'y', 32 | # 'x2', 'y2', 33 | # 'xc', 'yc', 34 | # 'w', 'h', 35 | # 'xp', 'yp', 36 | # 'x2p', 'y2p', 37 | 'xcp', 'ycp', 38 | # 'wp', 'hp', 39 | 'aspect', 40 | 'ang', 41 | 'quad', 42 | # 'bw', 'bh', 43 | 'bwp', 'bhp', 44 | # 'u', 'v', 45 | # 'u2', 'v2', 46 | 'uc', 'vc', 47 | # 'du', 'dv', 48 | # 'adu', 'adv', 49 | 'rad', 50 | # 'text', 51 | # 'type' 52 | ] 53 | 54 | 55 | class TextClassifier: 56 | def __init__(self, model_name=None): 57 | if model_name is None: 58 | # Pipeline: standardization -> svm 59 | my_svm = svm.SVC(C=100, gamma=0.1, class_weight='balanced', kernel='rbf') 60 | self._clf = make_pipeline(StandardScaler(), my_svm) 61 | else: 62 | model_file = model_files[model_name] 63 | self._clf = joblib.load(model_file) 64 | 65 | def train(self, features, types): 66 | """ 67 | Train an svm model with the complete dataset. 68 | This classifier will be used for the following steps in the pipeline. 69 | :param features: 70 | :param types: 71 | :return: 72 | """ 73 | print >> sys.stderr, 'fitting...', 74 | self._clf.fit(features, types) 75 | print 'DONE' 76 | 77 | print >> sys.stderr, 'evaluating...', 78 | pred_types = self._clf.predict(features) 79 | print 'DONE' 80 | 81 | cm = metrics.confusion_matrix(types, pred_types, labels=self._clf.classes_) 82 | u.print_cm(cm, labels=self._clf.classes_) 83 | print 'accuracy: ', metrics.accuracy_score(types, pred_types) 84 | print 'wrong boxes: ', sum(types != pred_types) 85 | 86 | def cross_validation(self, features, true_types, cv): 87 | labels = unique_labels(true_types) 88 | print 'total after sampling:', len(true_types) 89 | print pd.value_counts(true_types)[labels] 90 | 91 | # cross-validation 92 | pred_type = cross_val_predict(self._clf, features, true_types, cv=cv, n_jobs=-1) 93 | print metrics.classification_report(true_types, pred_type, target_names=labels) 94 | print 'Accuracy: ', metrics.accuracy_score(true_types, pred_type) 95 | 96 | cm = metrics.confusion_matrix(true_types, pred_type, labels=labels) 97 | u.print_cm(cm, labels=labels) 98 | 99 | def classify(self, chart, with_post=False, draw_debug=False, pad=0, save=False): 100 | """ 101 | Classify text boxes in a chart and save them in a cvs file 102 | :param chart: 103 | :param with_post 104 | :param draw_debug 105 | :param save: save pred_type in the *-texts.csv file 106 | :return: 107 | """ 108 | if chart.text_from == 1 and not os.path.isfile(chart.text_boxes_filename): 109 | u.create_predicted1_bbs(chart) 110 | 111 | # extract boxes from chart 112 | fh, fw, _ = chart.image.shape 113 | text_boxes = copy.deepcopy(chart.text_boxes) 114 | for b in text_boxes: 115 | b.wrap_rect((fh, fw), padx=pad, pady=pad) 116 | 117 | pred_types = self.classify_from_boxes(text_boxes, (fh, fw), with_post) 118 | 119 | if save: 120 | for text_box, pred_type in zip(chart.text_boxes, pred_types): 121 | text_box._type = pred_type 122 | chart.save_text_boxes() 123 | 124 | return pred_types 125 | 126 | def classify_from_boxes(self, text_boxes, shape, with_post=False): 127 | """ 128 | Classify text boxes 129 | :param text_boxes: bounding boxes 130 | :param shape: (fh, fw) figure height and width. 131 | :return: 132 | """ 133 | data = feature_extractor.from_text_boxes(text_boxes, shape, 0, '') 134 | features = data[VALID_COLUMNS] 135 | 136 | # predict class 137 | pred_types = self._clf.predict(features) 138 | 139 | # if with_post: 140 | # self.post_process(boxes) 141 | 142 | return pred_types 143 | 144 | def save_model(self, filename): 145 | joblib.dump(self._clf, filename) 146 | 147 | 148 | 149 | -------------------------------------------------------------------------------- /rev/textbox.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import OrderedDict 3 | import cv2 4 | 5 | 6 | class TextBox(object): 7 | def __init__(self, id, x, y, w, h, type='', text=''): 8 | self._id = id 9 | self._rect = (float(x), float(y), float(w), float(h)) 10 | self._type = type 11 | self._text = text 12 | 13 | self._regions = [] # the connected components 14 | 15 | def __str__(self): 16 | return 'bbox[{4}]: [{0} {1} {2} {3}] : [{5}] : [{6}]'\ 17 | .format(self.x, self.y, self.w, self.h, self._id, self._type, self._text) 18 | 19 | __repr__ = __str__ 20 | 21 | def center(self): 22 | return self.x + self.w / 2.0, self.y + self.h / 2.0 23 | 24 | def area(self): 25 | return self.w * self.h 26 | 27 | # def expand(self, factor): 28 | # _w, _h = self.w * factor, self.h * factor 29 | # _x = self.x - (_w - self.w) / 2.0 30 | # _y = self.y - (_h - self.h) / 2.0 31 | # return TextBox(self._id, self._type, _x, _y, _w, _h) 32 | 33 | def wrap_rect(self, (fh, fw), padx=2, pady=None): 34 | pady = padx if pady is None else pady 35 | nx, ny = max(self.x - padx, 0), max(self.y - pady, 0) 36 | nw = min(self.x + self.w + padx, fw) - nx 37 | nh = min(self.y + self.h + pady, fh) - ny 38 | self._rect = (nx, ny, nw, nh) 39 | 40 | @property 41 | def type(self): 42 | return self._type 43 | 44 | @property 45 | def text(self): 46 | return self._text 47 | 48 | @property 49 | def x(self): 50 | return self._rect[0] 51 | 52 | @property 53 | def y(self): 54 | return self._rect[1] 55 | 56 | @property 57 | def w(self): 58 | return self._rect[2] 59 | 60 | @property 61 | def h(self): 62 | return self._rect[3] 63 | 64 | @property 65 | def rect0(self): 66 | x, y, w, h = self._rect 67 | return [x, y, x + w - 1, y + h - 1] 68 | 69 | @property 70 | def x1(self): 71 | return self._rect[0] 72 | 73 | @property 74 | def y1(self): 75 | return self._rect[1] 76 | 77 | @property 78 | def x2(self): 79 | return self.x + self.w - 1 80 | 81 | @property 82 | def y2(self): 83 | return self.y + self.h - 1 84 | 85 | @property 86 | def xc(self): 87 | return self.x + self.w / 2.0 88 | 89 | @property 90 | def yc(self): 91 | return self.y + self.h / 2.0 92 | 93 | @property 94 | def p1(self): 95 | return self.x, self.y 96 | 97 | @property 98 | def p2(self): 99 | return self.x + self.w - 1, self.y + self.h - 1 100 | 101 | def to_dict(self): 102 | row = OrderedDict() 103 | row['id'] = self._id 104 | row['x'] = self.x 105 | row['y'] = self.y 106 | row['width'] = self.w 107 | row['height'] = self.h 108 | row['text'] = self._text 109 | row['type'] = self._type 110 | return row 111 | 112 | def jaccard_similarity(self, tbox): 113 | """ 114 | Calculates the Jaccard similarity (the similarity used in the 115 | PASCAL VOC) 116 | Note: the are could be computed as: 117 | area_intersection = bbox.copy().intersect(self).area() 118 | but we replicate the code for efficency reason. 119 | 120 | copied from https://github.com/lorisbaz/self-taught_localization/blob/master/textbox.py 121 | """ 122 | xmin = max(self.x, tbox.x) 123 | ymin = max(self.y, tbox.y) 124 | xmax = min(self.x + self.w, tbox.x + tbox.w) 125 | ymax = min(self.y + self.h, tbox.y + tbox.h) 126 | if (xmin > xmax) or (ymin > ymax): 127 | xmin = 0.0 128 | ymin = 0.0 129 | xmax = 0.0 130 | ymax = 0.0 131 | 132 | area_intersection = np.abs(xmax - xmin) * np.abs(ymax - ymin) 133 | area_union = self.area() + tbox.area() - area_intersection 134 | return area_intersection / float(area_union) 135 | 136 | def matching_score(self, tbox): 137 | """ 138 | Score use to determine the matching between two rectangles. 139 | http://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=D7BFE2DA34118919E31A9A2FC5F85170?doi=10.1.1.104.1667&rep=rep1&type=pdf 140 | 141 | Note matching_score >= jaccard_similarity 142 | :param tbox: 143 | :return: 144 | """ 145 | xmin = max(self.x, tbox.x) 146 | ymin = max(self.y, tbox.y) 147 | xmax = min(self.x + self.w, tbox.x + tbox.w) 148 | ymax = min(self.y + self.h, tbox.y + tbox.h) 149 | if (xmin > xmax) or (ymin > ymax): 150 | xmin = 0.0 151 | ymin = 0.0 152 | xmax = 0.0 153 | ymax = 0.0 154 | 155 | area_intersection = np.abs(xmax - xmin) * np.abs(ymax - ymin) 156 | if self.area() + tbox.area() == 0: 157 | return 0 158 | 159 | return 2.0 * area_intersection / (self.area() + tbox.area()) 160 | 161 | def find_best_match(self, texts, return_all=False): 162 | coeffs = [] 163 | for tbox in texts: 164 | coeff = self.matching_score(tbox) 165 | coeffs.append((tbox, coeff)) 166 | 167 | if return_all: 168 | return coeffs 169 | 170 | return max(coeffs, key=lambda t: t[1]) 171 | 172 | @staticmethod 173 | def merge_boxes(texts, id=0): 174 | points = [] 175 | new_tbox = TextBox(id, 0, 0, 0, 0) 176 | for tbox in texts: 177 | points.append(ru.points(tbox.rect)) 178 | new_tbox._regions.extend(tbox.regions) 179 | 180 | new_tbox._rect = cv2.boundingRect(np.concatenate(points).astype('float32')) 181 | 182 | return new_tbox 183 | 184 | @property 185 | def num_comp(self): 186 | return len(self._regions) 187 | 188 | -------------------------------------------------------------------------------- /rev/text/feature_extractor.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from collections import defaultdict 4 | 5 | 6 | filtered_columns = [ 7 | 'fig_id', 'fig_fn', 8 | 'fw', 'fh', 9 | 'vscore', 'hscore', 10 | 'vrange', 'hrange', 11 | 'vfreq', 'hfreq', 12 | 'x', 'y', 13 | 'x2', 'y2', 14 | 'xc', 'yc', 15 | 'w', 'h', 16 | 'xp', 'yp', 17 | 'x2p', 'y2p', 18 | 'xcp', 'ycp', 19 | 'wp', 'hp', 20 | 'aspect', 21 | 'ang', 22 | 'quad', 23 | 'bw', 'bh', 24 | 'bwp', 'bhp', 25 | 'u', 'v', 26 | 'u2', 'v2', 27 | 'uc', 'vc', 28 | 'du', 'dv', 29 | 'adu', 'adv', 30 | 'rad', 31 | 'text', 32 | 'type' 33 | ] 34 | 35 | 36 | def from_chart(chart): 37 | """ 38 | Extract geometric features from bounding boxes in a chart. 39 | Assuming all the boxes belong to a figure. 40 | 41 | :param chart: chart object 42 | :return: panda DataFrame with features for all the boxes. 43 | """ 44 | text_boxes = chart.text_boxes 45 | fh, fw, _ = chart.image.shape 46 | features = from_text_boxes(text_boxes, (fh, fw), chart.id, chart.filename) 47 | 48 | return features 49 | 50 | 51 | def from_text_boxes(boxes, (fh, fw), chart_id, chart_fn=''): 52 | """ 53 | Extract geometric features from bounding boxes. 54 | Assuming all the boxes belong to a figure. 55 | 56 | :param boxes: 57 | :param (fh, fw): 58 | :param chart_id: figure id. 59 | :param chart_fn: file name for debug. 60 | :return: panda DataFrame with features for all the boxes 61 | """ 62 | rows = [] 63 | for box in boxes: 64 | vscore, hscore, vrange, hrange, vfreq, hfreq = alignment_scores(box, boxes, (fh, fw)) 65 | 66 | row = box.to_dict() 67 | row['vscore'] = vscore 68 | row['hscore'] = hscore 69 | row['vrange'] = vrange 70 | row['hrange'] = hrange 71 | row['vfreq'] = vfreq 72 | row['hfreq'] = hfreq 73 | # TODO(jpocom) 74 | # temporal fix, because in this class we use 'w' and 'h' instead of 'width' and 'height' 75 | row['w'] = row['width'] 76 | row['h'] = row['height'] 77 | row['fig_id'] = chart_id 78 | row['fig_fn'] = chart_fn 79 | row['fw'] = fw 80 | row['fh'] = fh 81 | 82 | rows.append(row) 83 | 84 | df = pd.DataFrame(rows) 85 | 86 | if not rows: 87 | return df 88 | 89 | # right-bottom coordinate 90 | df['x2'] = df.x + df.w 91 | df['y2'] = df.y + df.h 92 | # center coordinate 93 | df['xc'] = df.x + df.w / 2.0 94 | df['yc'] = df.y + df.h / 2.0 95 | 96 | # normalized top-left coordinate 97 | df['xp'] = df.x / fw 98 | df['yp'] = df.y / fh 99 | # normalized right-bottom coordinate 100 | df['x2p'] = df.x2 / fw 101 | df['y2p'] = df.y2 / fh 102 | # normalized center coordinate 103 | df['xcp'] = df.xc / fw 104 | df['ycp'] = df.yc / fh 105 | # normalized box size 106 | df['wp'] = df.w / fw 107 | df['hp'] = df.h / fh 108 | 109 | # aspect ratio in log-10 units 110 | df['aspect'] = np.log10(df.w / df.h) 111 | 112 | # angle from actual center [-1,1] 113 | # 0+ -> counter-clockwise from positive x-axis 114 | df['ang'] = np.arctan2(df.yc - fh / 2.0, df.xc - fw / 2.0) / np.pi 115 | # discretize angles into quadrants (0, 1, 2, 3) 116 | df['quad'] = np.floor(2 * (df.ang + 1)) % 4 117 | 118 | xmin = df['x'].min() 119 | ymin = df['y'].min() 120 | x2max = df['x2'].max() 121 | y2max = df['y2'].max() 122 | 123 | # bounding-width (bw) and bounding-height (bh) 124 | # bounding box of region containing text boxes 125 | df['bw'] = (x2max - xmin) 126 | df['bh'] = (y2max - ymin) 127 | df['bwp'] = df.bw / fw 128 | df['bhp'] = df.bh / fh 129 | 130 | # normalized top-left coordinate in container box 131 | df['u'] = (df.x - xmin) / df.bw 132 | df['v'] = (df.y - ymin) / df.bh 133 | 134 | # normalized bottom-right coordinate in container box 135 | df['u2'] = (df.x2 - xmin) / df.bw 136 | df['v2'] = (df.y2 - ymin) / df.bh 137 | 138 | # normalized bottom-right coordinate in container box 139 | df['uc'] = (df.xc - xmin) / df.bw 140 | df['vc'] = (df.yc - ymin) / df.bh 141 | 142 | def extremum(a, b): 143 | return np.where(abs(b) < abs(a), a, b) 144 | 145 | df['du'] = extremum(2 * df.u - 1, 2 * df.u2 - 1) 146 | df['dv'] = extremum(2 * df.v - 1, 2 * df.v2 - 1) 147 | 148 | # absolute extremal point [0,1] 149 | df['adu'] = abs(df.du) 150 | df['adv'] = abs(df.dv) 151 | 152 | # radius from normalized center 153 | df['rad'] = np.sqrt(df.du * df.du + df.dv * df.dv) 154 | 155 | return df[filtered_columns] 156 | 157 | 158 | def alignment_scores(ref_box, boxes, (fh, fw)): 159 | """ 160 | Return the number of boxes which intersect vertically and horizontally the 161 | 'ref_box'. These values are normalized by the total number of boxes. 162 | 163 | :param ref_box: reference box 164 | :param boxes: set of boxes in figure 165 | :return: (vscore, hscore, vrange, hrange, vfreq, hfreq) 166 | """ 167 | getters = { 168 | 'vert': {'left': lambda b: b.x1, 'right': lambda b: b.x2, 'center': lambda b: b.xc}, 169 | 'hori': {'left': lambda b: b.y1, 'right': lambda b: b.y2, 'center': lambda b: b.yc} 170 | } 171 | 172 | r = ref_box 173 | count = defaultdict(int) 174 | aboxes = defaultdict(list) 175 | span = defaultdict(float) 176 | freq = defaultdict(float) 177 | th = 3 178 | for orient in ['vert', 'hori']: 179 | getter = getters[orient] 180 | 181 | for b in boxes: 182 | aligned = {'left': True, 'right': True, 'center': True} 183 | for pos in aligned.keys(): 184 | if abs(getter[pos](r) - getter[pos](b)) > th: 185 | aligned[pos] = False 186 | 187 | if any(aligned.values()): 188 | count[orient] += 1 189 | aboxes[orient].append(b) 190 | 191 | getter = getters['vert' if orient == 'hori' else 'hori'] 192 | aboxes[orient].sort(key=lambda b: getter['center'](b)) 193 | 194 | values = [getter['center'](b) for b in aboxes[orient]] 195 | span[orient] = abs(values[-1] - values[0]) 196 | 197 | pos = np.searchsorted(values, getter['center'](r)) 198 | values[pos] += 1e10 199 | closest = np.argmin(np.abs(np.array(values) - getter['center'](r))) 200 | values[pos] -= 1e10 201 | 202 | freq[orient] = 0.0 if span[orient] == 0 else 1.0 - abs(values[closest] - values[pos]) / span[orient] 203 | 204 | num_boxes = float(len(boxes)) 205 | return count['vert'] / num_boxes, count['hori'] / num_boxes, \ 206 | span['vert'] / fh, span['hori'] / fw, \ 207 | freq['vert'], freq['hori'] 208 | -------------------------------------------------------------------------------- /rev/text/rectutils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on https://github.com/szakrewsky/text-search/blob/master/rectutils.py 3 | """ 4 | import cv2 5 | import numpy as np 6 | from colormath.color_objects import sRGBColor, LabColor 7 | from colormath.color_conversions import convert_color 8 | from colormath.color_diff import delta_e_cie2000 9 | 10 | from shapely.geometry import LineString 11 | 12 | from .. import utils as u 13 | 14 | def on_same_line(r1, r2, horiz=True): 15 | x1, y1, w1, h1 = r1 if horiz else (r1[1], r1[0], r1[3], r1[2]) 16 | x2, y2, w2, h2 = r2 if horiz else (r2[1], r2[0], r2[3], r2[2]) 17 | 18 | over, d = range_overlap((y1, y1+h1), (y2, y2+h2)) 19 | if over and d > min(h1, h2) / 2.0: 20 | return True 21 | return False 22 | 23 | 24 | def next_on_same_line(r1, r2, dist=None, horiz=True): 25 | x1, y1, w1, h1 = r1 if horiz else (r1[1], r1[0], r1[3], r1[2]) 26 | x2, y2, w2, h2 = r2 if horiz else (r2[1], r2[0], r2[3], r2[2]) 27 | dist = min(h1, h2) / float(2) if dist is None else dist 28 | 29 | if not on_same_line(r1, r2, horiz=horiz) or abs(x1 + w1 - x2) > dist: 30 | return False 31 | return True 32 | 33 | 34 | def on_consecutive_line(r1, r2): 35 | x1, y1, w1, h1 = r1 36 | x2, y2, w2, h2 = r2 37 | if abs(y1 + h1 - y2) > min(h1,h2)/float(2): 38 | return False 39 | return True 40 | 41 | 42 | def same_height(r1, r2, max_diff=None, horiz=True): 43 | x1, y1, w1, h1 = r1 if horiz else (r1[1], r1[0], r1[3], r1[2]) 44 | x2, y2, w2, h2 = r2 if horiz else (r2[1], r2[0], r2[3], r2[2]) 45 | max_diff = min(h1, h2) if max_diff is None else max_diff 46 | 47 | if abs(h1 - h2) > max_diff: 48 | return False 49 | return True 50 | 51 | 52 | def overlap(r1, r2): 53 | """ 54 | Based on http://codereview.stackexchange.com/questions/31352/overlapping-rectangles 55 | Overlapping rectangles overlap both horizontally & vertically 56 | """ 57 | x1, y1, w1, h1 = r1 58 | x2, y2, w2, h2 = r2 59 | over1, _ = range_overlap((x1, x1 + w1), (x2, x2 + w2)) 60 | over2, _ = range_overlap((y1, y1 + h1), (y2, y2 + h2)) 61 | return over1 and over2 62 | 63 | 64 | def range_overlap((a_min, a_max), (b_min, b_max)): 65 | """ 66 | Based on http://codereview.stackexchange.com/questions/31352/overlapping-rectangles 67 | Neither range is completely greater than the other 68 | """ 69 | if (a_min <= b_max) and (b_min <= a_max): 70 | return True, min(a_max, b_max) - max(a_min, b_min) + 1 71 | return False, -1 72 | 73 | 74 | def inside(r1, r2): 75 | """ 76 | Check if r1 is inside r2 77 | """ 78 | x1, y1, w1, h1 = r1 79 | x2, y2, w2, h2 = r2 80 | return (x1 >= x2) and (y1 >= y2) and (x1+w1 <= x2+w2) and (y1+h1 <= y2 + h2) 81 | 82 | 83 | def rect_segment_intersection(rect, seg): 84 | """ 85 | Returns the intersection of a rectangle rect and seg 86 | :param rect: (x, y, w, h) 87 | :param seg: (point1, point2) 88 | :return tuple(x, y, v), where 89 | (x, y) is the intersection 90 | v == False if there are 0 or inf. intersections (invalid) 91 | v == True if it has a unique intersection ON the segment 92 | """ 93 | x, y, w, h = rect 94 | segments = LineString([(x, y), (x + w, y), (x + w, y + h), (x, y + h), (x, y)]) 95 | segment = LineString(seg) 96 | inter = segments.intersection(segment) 97 | if inter.geom_type == 'Point': 98 | return inter.x, inter.y, True 99 | 100 | return 0, 0, False 101 | 102 | 103 | def center(r): 104 | x, y, w, h = r 105 | w = w if w > 2 else w - 0.1 106 | h = h if h > 2 else h - 0.1 107 | return x + w / 2.0, y + h / 2.0 108 | 109 | def filter_duplicates(rects): 110 | print "Filtering %d regions..." % (len(rects)) 111 | 112 | th = 10 113 | C = np.zeros((len(rects), len(rects)), dtype=bool) 114 | for i, r1 in enumerate(rects): 115 | for j, r2 in enumerate(rects): 116 | # if abs(r1[0] - r2[0]) < th and abs(r1[1] - r2[1]) < th and \ 117 | # abs(r1[2] - r2[2]) < th and abs(r1[3] - r2[3]) < th: 118 | # C[i, j] = True 119 | # if overlap(r1, r2): 120 | # C[i, j] = True 121 | if inside(r1, r2): 122 | C[i, j] = True 123 | 124 | rects, group_indices = __bfs_bbx(rects, C) 125 | 126 | print "\tto %d regions" % (len(rects)) 127 | return rects, group_indices 128 | 129 | 130 | def mean_color(img, bw, rect): 131 | x, y, w, h = wrap_rect(rect, bw.shape, padx=1) 132 | roi = img[y:y + h, x:x + w, :] 133 | roi_bw = bw[y:y + h, x:x + w] 134 | 135 | pos = np.transpose(np.nonzero(roi_bw)) 136 | rows = pos[:, 0] 137 | cols = pos[:, 1] 138 | meancolor = np.mean(roi[rows, cols], axis=0) 139 | 140 | # vis = cv2.cvtColor(roi_bw, cv2.COLOR_GRAY2BGR) 141 | # vis[roi_bw==255] = meancolor 142 | # 143 | # vis = np.hstack((roi, cv2.cvtColor(roi_bw, cv2.COLOR_GRAY2BGR), vis)) 144 | # show_image('img (%d, %d, %d, %d)' % (x, y, w, h), vis) 145 | 146 | return meancolor 147 | 148 | 149 | def color_dist(img, bw, r1, r2): 150 | c1 = mean_color(img, bw, r1)/255. 151 | c2 = mean_color(img, bw, r2)/255. 152 | 153 | c1_lab = convert_color(sRGBColor(c1[2], c1[1], c1[0]), LabColor) 154 | c2_lab = convert_color(sRGBColor(c2[2], c2[1], c2[0]), LabColor) 155 | delta_e = delta_e_cie2000(c1_lab, c2_lab) 156 | return delta_e 157 | 158 | 159 | def find_words(rects, img): 160 | C = np.zeros((len(rects), len(rects)), dtype=bool) 161 | for i, r1 in enumerate(rects): 162 | x1, y1, w1, h1 = r1 163 | for j, r2 in enumerate(rects): 164 | x2, y2, w2, h2, = r2 165 | if i == j or \ 166 | (abs(y1 - y2) < min(h1, h2)/float(2) and # almost same level 167 | abs(h1 - h2) < min(h1, h2)/2. and # almost same height 168 | # (inside(r1, r2) or inside(r2, r1)) and 169 | # (abs(x1 + w1 - x2) < 10 or abs(x2 + w2 - x1) < 10) and # boxes distance 170 | (abs(x1 + w1 - x2) < 10 or abs(x2 + w2 - x1) < 10 or inside(r1, r2) or inside(r2, r1)) and # boxes distance 171 | color_dist(img, r1, r2) < 10): # almost same color 172 | C[i, j] = True 173 | 174 | rects, group_indices = __bfs_bbx(rects, C) 175 | return rects 176 | 177 | 178 | def find_words2(rects, img): 179 | C = np.zeros((len(rects), len(rects)), dtype=bool) 180 | for i, r1 in enumerate(rects): 181 | y1, x1, h1, w1 = r1 182 | for j, r2 in enumerate(rects): 183 | y2, x2, h2, w2, = r2 184 | if i == j or \ 185 | (abs(y1 - y2) < min(h1, h2)/float(2) and # almost same level 186 | abs(h1 - h2) < min(h1, h2)/float(2) and # almost same height 187 | (abs(x1 + w1 - x2) < 10 or abs(x2 + w2 - x1) < 10) and # boxes distance 188 | color_dist(img, r1, r2) < 10): # almost same color 189 | C[i, j] = True 190 | 191 | rects, group_indices = __bfs_bbx(rects, C) 192 | return rects 193 | 194 | 195 | def __bfs_bbx(rects, C): 196 | visited = set() 197 | isclose = {} 198 | for i in range(0, len(rects)): 199 | if i in visited: 200 | continue 201 | 202 | visited.add(i) 203 | neighbors = isclose.get(i, []) 204 | neighbors.extend(np.where(C[i] == True)[0]) 205 | isclose[i] = neighbors 206 | visited = visited | set(neighbors) 207 | 208 | j = 0 209 | while j < len(neighbors): 210 | s = neighbors[j] 211 | s_neighbors = set(np.where(C[s] == True)[0]) 212 | s_neighbors = s_neighbors - visited 213 | neighbors.extend(s_neighbors) 214 | visited = visited | s_neighbors 215 | j += 1 216 | 217 | newrects = [] 218 | group_indices = [] 219 | for value in isclose.values(): 220 | newrects.append(cv2.boundingRect(np.concatenate([u.points(rects[r]) for r in 221 | value]))) 222 | group_indices.append(value) 223 | 224 | return newrects, group_indices 225 | 226 | 227 | def points(rect): 228 | x, y, w, h = rect 229 | return [(x, y), (x+w-1, y), (x+w-1, y+h-1), (x, y+h-1)] 230 | 231 | 232 | # def wrap_rect(rect, (fh, fw), padx=2, pady=None): 233 | # if pady is None: 234 | # pady = padx 235 | # x, y, w, h = rect 236 | # nx, ny = max(x - padx, 0), max(y - pady, 0) 237 | # nw = min(x + w + padx, fw) - nx 238 | # nh = min(y + h + pady, fh) - ny 239 | # return nx, ny, nw, nh 240 | -------------------------------------------------------------------------------- /rev/text/ocr.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import cv2 3 | import os 4 | import re 5 | import numpy as np 6 | 7 | import scipy.misc as smp 8 | 9 | from tesserocr import PyTessBaseAPI, PSM 10 | 11 | from skimage import morphology 12 | from skimage.segmentation import clear_border 13 | 14 | import rev.utils as u 15 | # import chartprocessor.utils as u 16 | # import chartprocessor.rectutils as ru 17 | # from chartprocessor.chart import save_bbs 18 | # from chartprocessor.third_party.textconvert import lossy_unicode_to_ascii 19 | # import chartprocessor.utils as u 20 | 21 | 22 | SHOW = False 23 | 24 | 25 | def post_process_text(text): 26 | # '1O99' -> '1099' 27 | # '4o' -> '40' 28 | # 'l' -> '1' 29 | tmp = text.replace('O', '0') 30 | tmp = tmp.replace('o', '0') 31 | tmp = tmp.replace('l', '1') 32 | if u.is_number(tmp): 33 | text = tmp 34 | 35 | # '1 5%' -> '15%' 36 | # '51 ,050' -> '51,050' 37 | # '1001 -5000' -> '1001-5000' 38 | pos = text.find('1 ') 39 | while pos != -1: 40 | skip = 2 41 | 42 | # 'Q1 14' -> 'Q1 14' 43 | if pos - 1 >= 0 and text[pos - 1] in '0OQ': 44 | # pos = text.find('1 ', pos + 2) 45 | skip = 2 46 | # continue 47 | 48 | # '1 999' -> '1999' 49 | if text[pos + 2].isdigit() or text[pos + 2] in ',.-': 50 | text = text[:pos + 1] + text[pos + 2:] 51 | skip = 1 52 | 53 | pos = text.find('1 ', pos + skip) 54 | 55 | # In Quartz is common to use Q1, Q2, Q3 and Q4 56 | # ex. '02' -> 'Q2' 57 | # ex. '02 14' -> 'Q2 14' 58 | if re.match(r'0\d(?:\s|$)', text): 59 | text = 'Q' + text[1:] 60 | 61 | # '°/o of keynote' -> '% of keynote' 62 | text = text.replace('°/o', '%') 63 | 64 | return text 65 | 66 | 67 | # def run_ocr_in_boxes2(img, boxes, pad=0, psm=PSM.SINGLE_LINE): 68 | # """ 69 | # Run OCR for all the boxes. 70 | # :param img: 71 | # :param boxes: 72 | # :param pad: padding before applying ocr 73 | # :param psm: PSM.SINGLE_WORD or PSM.SINGLE_LINE 74 | # :return: 75 | # """ 76 | # # add a padding to the initial figure 77 | # fpad = 1 78 | # img = cv2.copyMakeBorder(img.copy(), fpad, fpad, fpad, fpad, cv2.BORDER_CONSTANT, value=(255, 255, 255)) 79 | # fh, fw, _ = img.shape 80 | # 81 | # api = PyTessBaseAPI(psm=psm, lang='eng') 82 | # clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(4, 4)) 83 | # 84 | # for box in boxes: 85 | # # adding a pad to original image. Some case in quartz corpus, the text touch the border. 86 | # x, y, w, h = ru.wrap_rect(u.ttoi(box.rect), (fh, fw), padx=pad, pady=pad) 87 | # x, y = x + fpad, y + fpad 88 | # 89 | # if w * h == 0: 90 | # box.text = '' 91 | # continue 92 | # 93 | # # crop region of interest 94 | # roi = img[y:y + h, x:x + w] 95 | # # to gray scale 96 | # roi_gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY) 97 | # # 98 | # roi_gray = cv2.resize(roi_gray, None, fx=3, fy=3, interpolation=cv2.INTER_CUBIC) 99 | # # binarization 100 | # _, roi_bw = cv2.threshold(roi_gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU) 101 | # # removing noise from borders 102 | # roi_bw = 255 - clear_border(255-roi_bw) 103 | # 104 | # # roi_gray = cv2.copyMakeBorder(roi_gray, 5, 5, 5, 5, cv2.BORDER_CONSTANT, value=255) 105 | # 106 | # # when testing boxes from csv files 107 | # if box.num_comp == 0: 108 | # # Apply Contrast Limited Adaptive Histogram Equalization 109 | # roi_gray2 = clahe.apply(roi_gray) 110 | # _, roi_bw2 = cv2.threshold(roi_gray2, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU) 111 | # _, num_comp = morphology.label(roi_bw2, return_num=True, background=255) 112 | # box.regions.extend(range(num_comp)) 113 | # 114 | # pil_img = smp.toimage(roi_bw) 115 | # if SHOW: 116 | # pil_img.show() 117 | # max_conf = -np.inf 118 | # min_dist = np.inf 119 | # correct_text = '' 120 | # correct_angle = 0 121 | # u.log('---------------') 122 | # for angle in [0, -90, 90]: 123 | # rot_img = pil_img.rotate(angle, expand=1) 124 | # 125 | # api.SetImage(rot_img) 126 | # conf = api.MeanTextConf() 127 | # text = api.GetUTF8Text().strip() 128 | # dist = abs(len(text.replace(' ', '')) - box.num_comp) 129 | # 130 | # u.log('text: %s conf: %f dist: %d' % (text, conf, dist)) 131 | # if conf > max_conf and dist <= min_dist: 132 | # max_conf = conf 133 | # correct_text = text 134 | # correct_angle = angle 135 | # min_dist = dist 136 | # 137 | # box.text = post_process_text(lossy_unicode_to_ascii(correct_text)) 138 | # box.text_conf = max_conf 139 | # box.text_dist = min_dist 140 | # box.text_angle = correct_angle 141 | # 142 | # u.log('num comp %d' % box.num_comp) 143 | # u.log(u'** text: {} conf: {} angle: {}'.format(correct_text, max_conf, correct_angle)) 144 | # 145 | # api.End() 146 | # 147 | # return boxes 148 | 149 | 150 | def run_ocr_in_chart(chart, pad=0, psm=PSM.SINGLE_LINE): 151 | """ 152 | Run OCR for all the boxes. 153 | :param img: 154 | :param boxes: 155 | :param pad: padding before applying ocr 156 | :param psm: PSM.SINGLE_WORD or PSM.SINGLE_LINE 157 | :return: 158 | """ 159 | img = chart.image 160 | 161 | # add a padding to the initial figure 162 | fpad = 1 163 | img = cv2.copyMakeBorder(img.copy(), fpad, fpad, fpad, fpad, cv2.BORDER_CONSTANT, value=(255, 255, 255)) 164 | fh, fw, _ = img.shape 165 | 166 | api = PyTessBaseAPI(psm=psm, lang='eng') 167 | clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(4, 4)) 168 | 169 | for tbox in chart.texts: 170 | # adding a pad to original image. Some case in quartz corpus, the text touch the border. 171 | x, y, w, h = ru.wrap_rect(u.ttoi(tbox.rect), (fh, fw), padx=pad, pady=pad) 172 | x, y = x + fpad, y + fpad 173 | 174 | if w * h == 0: 175 | tbox.text = '' 176 | continue 177 | 178 | # crop region of interest 179 | roi = img[y:y + h, x:x + w] 180 | # to gray scale 181 | roi_gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY) 182 | # 183 | roi_gray = cv2.resize(roi_gray, None, fx=3, fy=3, interpolation=cv2.INTER_CUBIC) 184 | # binarization 185 | _, roi_bw = cv2.threshold(roi_gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU) 186 | # removing noise from borders 187 | roi_bw = 255 - clear_border(255-roi_bw) 188 | 189 | # roi_gray = cv2.copyMakeBorder(roi_gray, 5, 5, 5, 5, cv2.BORDER_CONSTANT, value=255) 190 | 191 | # when testing boxes from csv files 192 | if tbox.num_comp == 0: 193 | # Apply Contrast Limited Adaptive Histogram Equalization 194 | roi_gray2 = clahe.apply(roi_gray) 195 | _, roi_bw2 = cv2.threshold(roi_gray2, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU) 196 | _, num_comp = morphology.label(roi_bw2, return_num=True, background=255) 197 | tbox.regions.extend(range(num_comp)) 198 | 199 | pil_img = smp.toimage(roi_bw) 200 | if SHOW: 201 | pil_img.show() 202 | max_conf = -np.inf 203 | min_dist = np.inf 204 | correct_text = '' 205 | correct_angle = 0 206 | u.log('---------------') 207 | for angle in [0, -90, 90]: 208 | rot_img = pil_img.rotate(angle, expand=1) 209 | 210 | api.SetImage(rot_img) 211 | conf = api.MeanTextConf() 212 | text = api.GetUTF8Text().strip() 213 | dist = abs(len(text.replace(' ', '')) - tbox.num_comp) 214 | 215 | u.log('text: %s conf: %f dist: %d' % (text, conf, dist)) 216 | if conf > max_conf and dist <= min_dist: 217 | max_conf = conf 218 | correct_text = text 219 | correct_angle = angle 220 | min_dist = dist 221 | 222 | tbox.text = post_process_text(lossy_unicode_to_ascii(correct_text)) 223 | tbox.text_conf = max_conf 224 | tbox.text_dist = min_dist 225 | tbox.text_angle = correct_angle 226 | 227 | u.log('num comp %d' % tbox.num_comp) 228 | u.log(u'** text: {} conf: {} angle: {}'.format(correct_text, max_conf, correct_angle)) 229 | 230 | api.End() 231 | 232 | # return boxes 233 | 234 | 235 | # def run_ocr_in_chart(chart, from_bbs=1, pad=0): 236 | # """ 237 | # Run OCR for all the text boxes in a chart and save them in a csv file. 238 | # :param chart: 239 | # :param from_bbs: 1: from predicted1-bbs.csv 240 | # 2: from predicted2-bbs.csv [default: 1] 241 | # :return: 242 | # """ 243 | # # assert (from_bbs != 0) 244 | # # if from_bbs == 1 and not os.path.isfile(chart.predicted_bbs_name(1)): 245 | # # u.create_predicted1_bbs(chart) 246 | # 247 | # 248 | # boxes = chart.predicted_bbs(from_bbs) 249 | # run_ocr_in_boxes(chart.image(), boxes, pad=pad) 250 | # 251 | # bb_name = chart.predicted_bbs_name(from_bbs) 252 | # chart.save_bbs(bb_name, boxes) 253 | --------------------------------------------------------------------------------