├── .gitignore
├── LICENSE.md
├── README.md
├── ctc_decoder
│   ├── __init__.py
│   ├── beam_search.py
│   ├── best_path.py
│   ├── bk_tree.py
│   ├── common.py
│   ├── language_model.py
│   ├── lexicon_search.py
│   ├── loss.py
│   ├── prefix_search.py
│   └── token_passing.py
├── data
│   ├── README.md
│   ├── line
│   │   ├── corpus.txt
│   │   ├── img.png
│   │   └── rnnOutput.csv
│   └── word
│       ├── corpus.txt
│       ├── img.png
│       └── rnnOutput.csv
├── doc
│   ├── comparison.pdf
│   ├── line.png
│   ├── mini.png
│   └── word.png
├── extras
│   ├── best_path_cl.cl
│   └── best_path_cl.py
├── requirements.txt
├── setup.py
└── tests
    ├── README.md
    ├── test_bk_tree.py
    ├── test_language_model.py
    ├── test_mini_example.py
    └── test_real_example.py


/.gitignore:
--------------------------------------------------------------------------------
*.pyc
*.log
doc/notes.txt
doc/matrix.svg
__pycache__/
.idea/
ctc_decoder.egg-info
.pytest_cache
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Harald Scheidl

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CTC Decoding Algorithms

**Update 2021: installable Python package**

Python implementation of some common **Connectionist Temporal Classification (CTC) decoding algorithms**.
A minimalistic **language model** is provided.

## Installation

* Go to the root level of the repository
* Execute `pip install .`
* Go to `tests/` and execute `pytest` to check whether the installation worked


## Usage

### Basic usage

Here is a minimalistic executable example:

````python
import numpy as np
from ctc_decoder import best_path, beam_search

mat = np.array([[0.4, 0, 0.6], [0.4, 0, 0.6]])
chars = 'ab'

print(f'Best path: "{best_path(mat, chars)}"')
print(f'Beam search: "{beam_search(mat, chars)}"')
````

The output `mat` (numpy array, softmax already applied) of the CTC-trained neural network is expected to have shape TxC
and is passed as the first argument to the decoders.
T is the number of time-steps, and C the number of characters plus one for the CTC-blank, which is the last element.
The characters that can be predicted by the neural network are passed as the `chars` string to the decoder.
Decoders return the decoded string.
Running the code outputs:

````
Best path: ""
Beam search: "a"
````

To see more examples on how to use the decoders,
please have a look at the scripts in the `tests/` folder.
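
The output above can be checked by hand. Best path decoding takes the most likely label per time-step (the blank, with 0.6, at both time-steps) and therefore returns the empty string, whereas the probability of a text is the sum over all paths that collapse to it. The following brute-force sketch makes this visible by enumerating every path through `mat`. The helpers `collapse` and `label_probabilities` are hypothetical illustrations, not part of the `ctc_decoder` package, and the number of paths grows as C^T, so this only works for tiny matrices.

````python
from itertools import groupby, product

import numpy as np


def collapse(path, chars):
    """CTC collapsing rule: merge repeated labels, then drop blanks (blank = last index)."""
    blank_idx = len(chars)
    return ''.join(chars[k] for k, _ in groupby(path) if k != blank_idx)


def label_probabilities(mat, chars):
    """Sum the probabilities of all paths that collapse to the same text (exponential in T)."""
    max_T, max_C = mat.shape
    probs = {}
    for path in product(range(max_C), repeat=max_T):
        # probability of one path: product of its per-time-step probabilities
        p = float(np.prod([mat[t, k] for t, k in enumerate(path)]))
        text = collapse(path, chars)
        probs[text] = probs.get(text, 0.0) + p
    return probs


mat = np.array([[0.4, 0, 0.6], [0.4, 0, 0.6]])
chars = 'ab'
probs = label_probabilities(mat, chars)
print({text: round(p, 2) for text, p in probs.items()})
````

For this matrix, "a" collects 0.16 + 0.24 + 0.24 = 0.64 from the paths a-a, a-blank and blank-a, while the empty text only gets 0.36 from blank-blank, consistent with beam search returning "a".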
46 | 47 | 48 | 49 | ### Language model and BK-tree 50 | 51 | Beam search can optionally integrate a character-level language model. 52 | Text statistics (bigrams) are used by beam search to improve reading accuracy. 53 | 54 | ````python 55 | from ctc_decoder import beam_search, LanguageModel 56 | 57 | # create language model instance from a (large) text 58 | lm = LanguageModel('this is some text', chars) 59 | 60 | # and use it in the beam search decoder 61 | res = beam_search(mat, chars, lm=lm) 62 | ```` 63 | 64 | The lexicon search decoder computes a first approximation with best path decoding. 65 | Then, it uses a BK-tree to retrieve similar words, scores them and finally returns the best scoring word. 66 | The BK-tree is created by providing a list of dictionary words. 67 | A tolerance parameter defines the maximum edit distance from the query word to the returned dictionary words. 68 | 69 | ````python 70 | from ctc_decoder import lexicon_search, BKTree 71 | 72 | # create BK-tree from a list of words 73 | bk_tree = BKTree(['words', 'from', 'a', 'dictionary']) 74 | 75 | # and use the tree in the lexicon search 76 | res = lexicon_search(mat, chars, bk_tree, tolerance=2) 77 | ```` 78 | 79 | ### Usage with deep learning frameworks 80 | Some notes: 81 | * No adapter for TensorFlow or PyTorch is provided 82 | * Apply softmax already in the model 83 | * Convert to numpy array 84 | * Usually, the output of an RNN layer `rnn_output` has shape TxBxC, with B the batch dimension 85 | * Decoders work on single batch elements of shape TxC 86 | * Therefore, iterate over all batch elements and apply the decoder to each of them separately 87 | * Example: extract matrix of batch element 0 `mat = rnn_output[:, 0, :]` 88 | * The CTC-blank is expected to be the last element along the character dimension 89 | * TensorFlow has the CTC-blank as last element, so nothing to do here 90 | * PyTorch, however, has the CTC-blank as first element by default, so you have to move it to the end, 
or change the default setting 91 | 92 | ## List of provided decoders 93 | 94 | Recommended decoders: 95 | * `best_path`: best path (or greedy) decoder, the fastest of all algorithms, however, other decoders often perform better 96 | * `beam_search`: beam search decoder, optionally integrates a character-level language model, can be tuned via the beam width parameter 97 | * `lexicon_search`: lexicon search decoder, returns the best scoring word from a dictionary 98 | 99 | Other decoders, from my experience not really suited for practical purposes, 100 | but might be used for experiments or research: 101 | * `prefix_search`: prefix search decoder 102 | * `token_passing`: token passing algorithm 103 | * Best path decoder implementation in OpenCL (see `extras/` folder) 104 | 105 | [This paper](./doc/comparison.pdf) gives suggestions when to use best path decoding, beam search decoding and token passing. 106 | 107 | 108 | ## Documentation of test cases and data 109 | 110 | * Documentation of [test cases](./tests/README.md) 111 | * Documentation of the [data](./data/README.md) 112 | 113 | 114 | ## References 115 | 116 | * [Graves - Supervised sequence labelling with recurrent neural networks](https://www.cs.toronto.edu/~graves/preprint.pdf) 117 | * [Hwang - Character-level incremental speech recognition with recurrent neural networks](https://arxiv.org/pdf/1601.06581.pdf) 118 | * [Shi - An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition](https://arxiv.org/pdf/1507.05717.pdf) 119 | * [Marti - The IAM-database: an English sentence database for offline handwriting recognition](http://www.fki.inf.unibe.ch/databases/iam-handwriting-database) 120 | * [Beam Search Decoding in CTC-trained Neural Networks](https://towardsdatascience.com/5a889a3d85a7) 121 | * [An Intuitive Explanation of Connectionist Temporal Classification](https://towardsdatascience.com/3797e43a86c) 122 | * [Scheidl - Comparison of 
Connectionist Temporal Classification Decoding Algorithms](./doc/comparison.pdf) 123 | * [Scheidl - Word Beam Search: A Connectionist Temporal Classification Decoding Algorithm](https://repositum.tuwien.ac.at/obvutwoa/download/pdf/2774578) 124 | -------------------------------------------------------------------------------- /ctc_decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from ctc_decoder.best_path import best_path as best_path 2 | from ctc_decoder.beam_search import beam_search as beam_search 3 | from ctc_decoder.token_passing import token_passing as token_passing 4 | from ctc_decoder.prefix_search import prefix_search_heuristic_split as prefix_search_heuristic_split 5 | from ctc_decoder.prefix_search import prefix_search as prefix_search 6 | from ctc_decoder.lexicon_search import lexicon_search as lexicon_search 7 | from ctc_decoder.loss import loss as loss 8 | from ctc_decoder.loss import probability as probability 9 | from ctc_decoder.language_model import LanguageModel 10 | from ctc_decoder.bk_tree import BKTree 11 | -------------------------------------------------------------------------------- /ctc_decoder/beam_search.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from dataclasses import dataclass 3 | from typing import Optional, List, Tuple 4 | 5 | import numpy as np 6 | 7 | from ctc_decoder.language_model import LanguageModel 8 | 9 | 10 | def log(x: float) -> float: 11 | with np.errstate(divide='ignore'): 12 | return np.log(x) 13 | 14 | 15 | @dataclass 16 | class BeamEntry: 17 | """Information about one single beam at specific time-step.""" 18 | pr_total: float = log(0) # blank and non-blank 19 | pr_non_blank: float = log(0) # non-blank 20 | pr_blank: float = log(0) # blank 21 | pr_text: float = log(1) # LM score 22 | lm_applied: bool = False # flag if LM was already applied to this beam 23 | labeling: tuple = () 
# beam-labeling 24 | 25 | 26 | class BeamList: 27 | """Information about all beams at specific time-step.""" 28 | 29 | def __init__(self) -> None: 30 | self.entries = defaultdict(BeamEntry) 31 | 32 | def normalize(self) -> None: 33 | """Length-normalise LM score.""" 34 | for k in self.entries.keys(): 35 | labeling_len = len(self.entries[k].labeling) 36 | self.entries[k].pr_text = (1.0 / (labeling_len if labeling_len else 1.0)) * self.entries[k].pr_text 37 | 38 | def sort_labelings(self) -> List[Tuple[int]]: 39 | """Return beam-labelings, sorted by probability.""" 40 | beams = self.entries.values() 41 | sorted_beams = sorted(beams, reverse=True, key=lambda x: x.pr_total + x.pr_text) 42 | return [x.labeling for x in sorted_beams] 43 | 44 | 45 | def apply_lm(parent_beam: BeamEntry, child_beam: BeamEntry, chars: str, lm: LanguageModel) -> None: 46 | """Calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars.""" 47 | if not lm or child_beam.lm_applied: 48 | return 49 | 50 | # take bigram if beam length at least 2 51 | if len(child_beam.labeling) > 1: 52 | c = chars[child_beam.labeling[-2]] 53 | d = chars[child_beam.labeling[-1]] 54 | ngram_prob = lm.get_char_bigram(c, d) 55 | # otherwise take unigram 56 | else: 57 | c = chars[child_beam.labeling[-1]] 58 | ngram_prob = lm.get_char_unigram(c) 59 | 60 | lm_factor = 0.01 # influence of language model 61 | child_beam.pr_text = parent_beam.pr_text + lm_factor * log(ngram_prob) # probability of char sequence 62 | child_beam.lm_applied = True # only apply LM once per beam entry 63 | 64 | 65 | def beam_search(mat: np.ndarray, chars: str, beam_width: int = 25, lm: Optional[LanguageModel] = None) -> str: 66 | """Beam search decoder. 67 | 68 | See the paper of Hwang et al. and the paper of Graves et al. 69 | 70 | Args: 71 | mat: Output of neural network of shape TxC. 72 | chars: The set of characters the neural network can recognize, excluding the CTC-blank. 
73 | beam_width: Number of beams kept per iteration. 74 | lm: Character level language model if specified. 75 | 76 | Returns: 77 | The decoded text. 78 | """ 79 | 80 | blank_idx = len(chars) 81 | max_T, max_C = mat.shape 82 | 83 | # initialise beam state 84 | last = BeamList() 85 | labeling = () 86 | last.entries[labeling] = BeamEntry() 87 | last.entries[labeling].pr_blank = log(1) 88 | last.entries[labeling].pr_total = log(1) 89 | 90 | # go over all time-steps 91 | for t in range(max_T): 92 | curr = BeamList() 93 | 94 | # get beam-labelings of best beams 95 | best_labelings = last.sort_labelings()[:beam_width] 96 | 97 | # go over best beams 98 | for labeling in best_labelings: 99 | 100 | # probability of paths ending with a non-blank 101 | pr_non_blank = log(0) 102 | # in case of non-empty beam 103 | if labeling: 104 | # probability of paths with repeated last char at the end 105 | pr_non_blank = last.entries[labeling].pr_non_blank + log(mat[t, labeling[-1]]) 106 | 107 | # probability of paths ending with a blank 108 | pr_blank = last.entries[labeling].pr_total + log(mat[t, blank_idx]) 109 | 110 | # fill in data for current beam 111 | curr.entries[labeling].labeling = labeling 112 | curr.entries[labeling].pr_non_blank = np.logaddexp(curr.entries[labeling].pr_non_blank, pr_non_blank) 113 | curr.entries[labeling].pr_blank = np.logaddexp(curr.entries[labeling].pr_blank, pr_blank) 114 | curr.entries[labeling].pr_total = np.logaddexp(curr.entries[labeling].pr_total, 115 | np.logaddexp(pr_blank, pr_non_blank)) 116 | curr.entries[labeling].pr_text = last.entries[labeling].pr_text 117 | curr.entries[labeling].lm_applied = True # LM already applied at previous time-step for this beam-labeling 118 | 119 | # extend current beam-labeling 120 | for c in range(max_C - 1): 121 | # add new char to current beam-labeling 122 | new_labeling = labeling + (c,) 123 | 124 | # if new labeling contains duplicate char at the end, only consider paths ending with a blank 125 | if labeling 
and labeling[-1] == c: 126 | pr_non_blank = last.entries[labeling].pr_blank + log(mat[t, c]) 127 | else: 128 | pr_non_blank = last.entries[labeling].pr_total + log(mat[t, c]) 129 | 130 | # fill in data 131 | curr.entries[new_labeling].labeling = new_labeling 132 | curr.entries[new_labeling].pr_non_blank = np.logaddexp(curr.entries[new_labeling].pr_non_blank, 133 | pr_non_blank) 134 | curr.entries[new_labeling].pr_total = np.logaddexp(curr.entries[new_labeling].pr_total, pr_non_blank) 135 | 136 | # apply LM 137 | apply_lm(curr.entries[labeling], curr.entries[new_labeling], chars, lm) 138 | 139 | # set new beam state 140 | last = curr 141 | 142 | # normalise LM scores according to beam-labeling-length 143 | last.normalize() 144 | 145 | # sort by probability 146 | best_labeling = last.sort_labelings()[0] # get most probable labeling 147 | 148 | # map label string to char string 149 | res = ''.join([chars[label] for label in best_labeling]) 150 | return res 151 | -------------------------------------------------------------------------------- /ctc_decoder/best_path.py: -------------------------------------------------------------------------------- 1 | from itertools import groupby 2 | 3 | import numpy as np 4 | 5 | 6 | def best_path(mat: np.ndarray, chars: str) -> str: 7 | """Best path (greedy) decoder. 8 | 9 | Take best-scoring character per time-step, then remove repeated characters and CTC blank characters. 10 | See dissertation of Graves, p63. 11 | 12 | Args: 13 | mat: Output of neural network of shape TxC. 14 | chars: The set of characters the neural network can recognize, excluding the CTC-blank. 15 | 16 | Returns: 17 | The decoded text. 
18 | """ 19 | 20 | # get char indices along best path 21 | best_path_indices = np.argmax(mat, axis=1) 22 | 23 | # collapse best path (using itertools.groupby), map to chars, join char list to string 24 | blank_idx = len(chars) 25 | best_chars_collapsed = [chars[k] for k, _ in groupby(best_path_indices) if k != blank_idx] 26 | res = ''.join(best_chars_collapsed) 27 | return res 28 | -------------------------------------------------------------------------------- /ctc_decoder/bk_tree.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import editdistance as ed 4 | 5 | 6 | class BKTree: 7 | """Burkhard Keller tree: used to find strings within tolerance (w.r.t. edit distance metric) 8 | to given query string.""" 9 | 10 | def __init__(self, txt_list: List[str]) -> None: 11 | """Pass list of texts (words) which are inserted into the tree.""" 12 | self.root = None 13 | for txt in txt_list: 14 | self._insert(self.root, txt) 15 | 16 | def query(self, txt: str, tolerance: int) -> List[str]: 17 | """Query strings within given tolerance (w.r.t. 
edit distance metric).""" 18 | return self._query(self.root, txt, tolerance) 19 | 20 | def _insert(self, node, txt): 21 | # insert root node 22 | if node is None: 23 | self.root = (txt, {}) 24 | return 25 | 26 | # insert all other nodes 27 | d = ed.eval(node[0], txt) 28 | if d in node[1]: 29 | self._insert(node[1][d], txt) 30 | else: 31 | node[1][d] = (txt, {}) 32 | 33 | def _query(self, node, txt, tolerance): 34 | # handle empty root node 35 | if node is None: 36 | return [] 37 | 38 | # distance between query and current node 39 | d = ed.eval(node[0], txt) 40 | 41 | # add current node to result if within tolerance 42 | res = [] 43 | if d <= tolerance: 44 | res.append(node[0]) 45 | 46 | # iterate over children 47 | for (edge, child) in node[1].items(): 48 | if d - tolerance <= edge <= d + tolerance: 49 | res += self._query(child, txt, tolerance) 50 | 51 | return res 52 | -------------------------------------------------------------------------------- /ctc_decoder/common.py: -------------------------------------------------------------------------------- 1 | def extend_by_blanks(seq, b): 2 | """Extend a label seq. by adding blanks at the beginning, end and in between each label.""" 3 | res = [b] 4 | for s in seq: 5 | res.append(s) 6 | res.append(b) 7 | return res 8 | 9 | 10 | def word_to_label_seq(w, chars): 11 | """Map a word (string of characters) to a sequence of labels (indices).""" 12 | res = [chars.index(c) for c in w] 13 | return res 14 | -------------------------------------------------------------------------------- /ctc_decoder/language_model.py: -------------------------------------------------------------------------------- 1 | class LanguageModel: 2 | "Simple character-level language model." 
3 | 4 | def __init__(self, txt: str, chars: str) -> None: 5 | """Create language model from text corpus.""" 6 | 7 | # compute unigrams 8 | self._unigram = {c: 0 for c in chars} 9 | for c in chars: 10 | # ignore unknown chars 11 | if c not in self._unigram: 12 | continue 13 | self._unigram[c] += 1 14 | 15 | # compute bigrams 16 | self._bigram = {c: {d: 0 for d in chars} for c in chars} 17 | for i in range(len(txt) - 1): 18 | c = txt[i] 19 | d = txt[i + 1] 20 | 21 | # ignore unknown chars 22 | if c not in self._bigram or d not in self._bigram[c]: 23 | continue 24 | 25 | self._bigram[c][d] += 1 26 | 27 | # normalize 28 | sum_unigram = sum(self._unigram.values()) 29 | for c in chars: 30 | self._unigram[c] /= sum_unigram 31 | 32 | for c in chars: 33 | sum_bigram = sum(self._bigram[c].values()) 34 | if sum_bigram == 0: 35 | continue 36 | for d in chars: 37 | self._bigram[c][d] /= sum_bigram 38 | 39 | def get_char_unigram(self, c: str) -> float: 40 | """Probability of character c.""" 41 | return self._unigram[c] 42 | 43 | def get_char_bigram(self, c: str, d: str) -> float: 44 | """Probability that character c is followed by character d.""" 45 | return self._bigram[c][d] 46 | -------------------------------------------------------------------------------- /ctc_decoder/lexicon_search.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ctc_decoder.best_path import best_path 4 | from ctc_decoder.bk_tree import BKTree 5 | from ctc_decoder.loss import probability 6 | 7 | 8 | def lexicon_search(mat: np.ndarray, chars: str, bk_tree: BKTree, tolerance: int) -> str: 9 | """Lexicon search decoder. 10 | 11 | The algorithm computes a first approximation using best path decoding. Similar words are queried using the BK tree. 12 | These word candidates are then scored given the neural network output, and the best one is returned. 13 | See CRNN paper from Shi, Bai and Yao. 
14 | 15 | Args: 16 | mat: Output of neural network of shape TxC. 17 | chars: The set of characters the neural network can recognize, excluding the CTC-blank. 18 | bk_tree: Instance of BKTree which is used to query similar words. 19 | tolerance: Words to be considered, which are within specified edit distance. 20 | 21 | Returns: 22 | The decoded text. 23 | """ 24 | 25 | # use best path decoding to get an approximation 26 | approx = best_path(mat, chars) 27 | 28 | # get similar words from dictionary within given tolerance 29 | words = bk_tree.query(approx, tolerance) 30 | 31 | # if there are no similar words, return empty string 32 | if not words: 33 | return '' 34 | 35 | # else compute probabilities of all similar words and return best scoring one 36 | word_probs = [(w, probability(mat, w, chars)) for w in words] 37 | word_probs.sort(key=lambda x: x[1], reverse=True) 38 | return word_probs[0][0] 39 | -------------------------------------------------------------------------------- /ctc_decoder/loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | 5 | from ctc_decoder import common 6 | 7 | 8 | def recursive_probability(t, s, mat, labeling_with_blanks, blank, cache): 9 | """Recursively compute probability of labeling, 10 | save results of sub-problems in cache to avoid recalculating them.""" 11 | 12 | # check index of labeling 13 | if s < 0: 14 | return 0.0 15 | 16 | # sub-problem already computed 17 | if cache[t][s] is not None: 18 | return cache[t][s] 19 | 20 | # initial values 21 | if t == 0: 22 | if s == 0: 23 | res = mat[0, blank] 24 | elif s == 1: 25 | res = mat[0, labeling_with_blanks[1]] 26 | else: 27 | res = 0.0 28 | 29 | cache[t][s] = res 30 | return res 31 | 32 | # recursion on s and t 33 | p1 = recursive_probability(t - 1, s, mat, labeling_with_blanks, blank, cache) 34 | p2 = recursive_probability(t - 1, s - 1, mat, labeling_with_blanks, blank, cache) 35 | res = (p1 + p2) * 
mat[t, labeling_with_blanks[s]] 36 | 37 | # in case of a blank or a repeated label, we only consider s and s-1 at t-1, so we're done 38 | if labeling_with_blanks[s] == blank or (s >= 2 and labeling_with_blanks[s - 2] == labeling_with_blanks[s]): 39 | cache[t][s] = res 40 | return res 41 | 42 | # otherwise, in case of a non-blank and non-repeated label, we additionally add s-2 at t-1 43 | p = recursive_probability(t - 1, s - 2, mat, labeling_with_blanks, blank, cache) 44 | res += p * mat[t, labeling_with_blanks[s]] 45 | cache[t][s] = res 46 | return res 47 | 48 | 49 | def empty_cache(max_T, labeling_with_blanks): 50 | """Create empty cache.""" 51 | return [[None for _ in range(len(labeling_with_blanks))] for _ in range(max_T)] 52 | 53 | 54 | def probability(mat: np.ndarray, gt: str, chars: str) -> float: 55 | """Compute probability of ground truth text gt given neural network output mat. 56 | 57 | See the CTC Forward-Backward Algorithm in Graves paper. 58 | 59 | Args: 60 | mat: Output of neural network of shape TxC. 61 | gt: Ground truth text. 62 | chars: The set of characters the neural network can recognize, excluding the CTC-blank. 63 | 64 | Returns: 65 | The probability of the text given the neural network output. 66 | """ 67 | 68 | max_T, _ = mat.shape # size of input matrix 69 | blank = len(chars) # index of blank label 70 | labeling_with_blanks = common.extend_by_blanks(common.word_to_label_seq(gt, chars), blank) 71 | cache = empty_cache(max_T, labeling_with_blanks) 72 | 73 | p1 = recursive_probability(max_T - 1, len(labeling_with_blanks) - 1, mat, labeling_with_blanks, blank, cache) 74 | p2 = recursive_probability(max_T - 1, len(labeling_with_blanks) - 2, mat, labeling_with_blanks, blank, cache) 75 | p = p1 + p2 76 | return p 77 | 78 | 79 | def loss(mat: np.ndarray, gt: str, chars: str) -> float: 80 | """Compute loss of ground truth text gt given neural network output mat. 81 | 82 | See the CTC Forward-Backward Algorithm in Graves paper. 
83 | 84 | Args: 85 | mat: Output of neural network of shape TxC. 86 | gt: Ground truth text. 87 | chars: The set of characters the neural network can recognize, excluding the CTC-blank. 88 | 89 | Returns: 90 | The probability of the text given the neural network output. 91 | """ 92 | 93 | try: 94 | return -math.log(probability(mat, gt, chars)) 95 | except ValueError: 96 | return float('inf') 97 | -------------------------------------------------------------------------------- /ctc_decoder/prefix_search.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def prefix_search(mat: np.ndarray, chars: str) -> str: 5 | """Prefix search decoding. 6 | 7 | See dissertation of Graves, p63-66. 8 | 9 | Args: 10 | mat: Output of neural network of shape TxC. 11 | chars: The set of characters the neural network can recognize, excluding the CTC-blank. 12 | 13 | Returns: 14 | The decoded text. 15 | """ 16 | 17 | blank_idx = len(chars) 18 | max_T, max_C = mat.shape 19 | 20 | # g_n and g_b: gamma in paper 21 | g_n = [] 22 | g_b = [] 23 | 24 | # p(y|x) and p(y...|x), where y is a prefix (not p as in paper to avoid confusion with probability) 25 | prob = {} 26 | prob_ext = {} 27 | 28 | # Init: 1-6 29 | for t in range(max_T): 30 | g_n.append({'': 0}) 31 | last = g_b[t - 1][''] if t > 0 else 1 32 | g_b.append({'': last * mat[t, blank_idx]}) 33 | 34 | # init for empty prefix 35 | prob[''] = g_b[max_T - 1][''] 36 | prob_ext[''] = 1 - prob[''] 37 | l_star = y_star = '' 38 | Y = {''} 39 | 40 | # Algorithm: 8-31 41 | while prob_ext[y_star] > prob[l_star]: 42 | prob_remaining = prob_ext[y_star] 43 | 44 | # for all chars 45 | for k in range(max_C - 1): 46 | y = y_star + chars[k] 47 | g_n[0][y] = mat[0, k] if len(y_star) == 0 else 0 48 | g_b[0][y] = 0 49 | prefix_prob = g_n[0][y] 50 | 51 | # for all time steps 52 | for t in range(1, max_T): 53 | new_label_prob = g_b[t - 1][y_star] + ( 54 | 0 if y_star != '' and y_star[-1] == 
chars[k] else g_n[t - 1][y_star]) 55 | g_n[t][y] = mat[t, k] * (new_label_prob + g_n[t - 1][y]) 56 | g_b[t][y] = mat[t, blank_idx] * (g_b[t - 1][y] + g_n[t - 1][y]) 57 | prefix_prob += mat[t, k] * new_label_prob 58 | 59 | prob[y] = g_n[max_T - 1][y] + g_b[max_T - 1][y] 60 | prob_ext[y] = prefix_prob - prob[y] 61 | prob_remaining -= prob_ext[y] 62 | 63 | if prob[y] > prob[l_star]: 64 | l_star = y 65 | if prob_ext[y] > prob[l_star]: 66 | Y.add(y) 67 | if prob_remaining <= prob[l_star]: 68 | break 69 | 70 | # 30 71 | Y.remove(y_star) 72 | 73 | # 31 74 | best_y = None 75 | best_prob_ext = 0 76 | for y in Y: 77 | if prob_ext[y] > best_prob_ext: 78 | best_prob_ext = prob_ext[y] 79 | best_y = y 80 | y_star = best_y 81 | 82 | # terminate if no more prefix exists 83 | if best_y is None: 84 | break 85 | 86 | # Termination: 33-34 87 | return l_star 88 | 89 | 90 | def prefix_search_heuristic_split(mat: np.ndarray, chars: str) -> str: 91 | """Prefix search decoding with heuristic to speed up the algorithm. 92 | 93 | Speed up prefix computation by splitting sequence into subsequences as described by Graves (p66). 94 | 95 | Args: 96 | mat: Output of neural network of shape TxC. 97 | chars: The set of characters the neural network can recognize, excluding the CTC-blank. 98 | 99 | Returns: 100 | The decoded text. 
101 | """ 102 | 103 | blank_idx = len(chars) 104 | max_T, _ = mat.shape 105 | 106 | # split sequence into 3 subsequences, splitting points should be roughly placed at 1/3 and 2/3 107 | split_targets = [int(max_T * 1 / 3), int(max_T * 2 / 3)] 108 | best = [{'target': s, 'bestDist': max_T, 'bestIdx': s} for s in split_targets] 109 | 110 | # find good splitting points (blanks above threshold) 111 | thres = 0.9 112 | for t in range(max_T): 113 | for b in best: 114 | if mat[t, blank_idx] > thres and abs(t - b['target']) < b['bestDist']: 115 | b['bestDist'] = abs(t - b['target']) 116 | b['bestIdx'] = t 117 | break 118 | 119 | # splitting points plus begin and end of sequence 120 | ranges = [0] + [b['bestIdx'] for b in best] + [max_T] 121 | 122 | # do prefix search for each subsequence and concatenate results 123 | res = '' 124 | for i in range(len(ranges) - 1): 125 | beg = ranges[i] 126 | end = ranges[i + 1] 127 | res += prefix_search(mat[beg: end, :], chars) 128 | 129 | return res 130 | -------------------------------------------------------------------------------- /ctc_decoder/token_passing.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List 3 | 4 | import numpy as np 5 | 6 | from ctc_decoder import common 7 | 8 | 9 | class Token: 10 | """Token for token passing algorithm. 
Each token contains a score and a history of visited words."""
 11 |
 12 |     def __init__(self, score=float('-inf'), history=None):
 13 |         self.score = score
 14 |         self.history = history if history else []
 15 |
 16 |     def __str__(self):
 17 |         res = 'class Token: ' + str(self.score) + '; '
 18 |         for w in self.history:
 19 |             res += str(w) + '; '  # history stores word indices (ints), so convert before concatenating
 20 |         return res
 21 |
 22 |
 23 | class TokenList:
 24 |     """Simplifies getting/setting tokens, which are keyed by (w, s, t)."""
 25 |
 26 |     def __init__(self):
 27 |         self.tokens = {}
 28 |
 29 |     def set(self, w, s, t, tok):
 30 |         self.tokens[(w, s, t)] = tok
 31 |
 32 |     def get(self, w, s, t):
 33 |         return self.tokens[(w, s, t)]
 34 |
 35 |     def dump(self, s, t):
 36 |         for (k, v) in self.tokens.items():
 37 |             if k[1] == s and k[2] == t:
 38 |                 print(k, v)
 39 |
 40 |
 41 | def output_indices(toks, words, s, t):
 42 |     """Return word indices sorted by ascending score of tok(w,s,t); the last one is argmax_w tok(w,s,t)."""
 43 |     res = []
 44 |     for (w_idx, _) in enumerate(words):
 45 |         res.append(toks.get(w_idx, s, t))
 46 |
 47 |     idx = [i[0] for i in sorted(enumerate(res), key=lambda x: x[1].score)]
 48 |     return idx
 49 |
 50 |
 51 | def log(val):
 52 |     """Return log(val), or -inf for val <= 0 instead of raising a ValueError as math.log does."""
 53 |     if val > 0:
 54 |         return math.log(val)
 55 |     return float('-inf')
 56 |
 57 |
 58 | def token_passing(mat: np.ndarray, chars: str, words: List[str]) -> str:
 59 |     """Token passing algorithm.
 60 |
 61 |     See dissertation of Graves, p67-69.
 62 |
 63 |     Args:
 64 |         mat: Output of neural network of shape TxC.
 65 |         chars: The set of characters the neural network can recognize, excluding the CTC-blank.
 66 |         words: List of words that can be recognized.
 67 |
 68 |     Returns:
 69 |         The decoded text.
 70 |     """
 71 |
 72 |     blank_idx = len(chars)
 73 |     max_T, _ = mat.shape
 74 |
 75 |     # special s indices for the beginning and end of a word
 76 |     beg = 0
 77 |     end = -1
 78 |
 79 |     # map characters to labels for each word
 80 |     label_words = [common.word_to_label_seq(w, chars) for w in words]
 81 |
 82 |     # w' in paper: word with blanks in front, back and between labels: for -> _f_o_r_
 83 |     prime_words = [common.extend_by_blanks(w, blank_idx) for w in label_words]
 84 |
 85 |     # data structure holding all tokens
 86 |     toks = TokenList()
 87 |
 88 |     # Initialisation: 1-9
 89 |     for w_idx, w in enumerate(label_words):
 90 |         # w: label sequence of the word, w_prime: w extended by blanks
 91 |         w_prime = prime_words[w_idx]
 92 |
 93 |         # set all toks(w,s,t) to init state
 94 |         for s in range(len(w_prime)):
 95 |             for t in range(max_T):
 96 |                 toks.set(w_idx, s + 1, t + 1, Token())
 97 |                 toks.set(w_idx, beg, t, Token())
 98 |                 toks.set(w_idx, end, t, Token())
 99 |
100 |         toks.set(w_idx, 1, 1, Token(log(mat[1 - 1, blank_idx]), [w_idx]))
101 |         c_idx = w[1 - 1]
102 |         toks.set(w_idx, 2, 1, Token(log(mat[1 - 1, c_idx]), [w_idx]))
103 |
104 |         if len(w) == 1:
105 |             toks.set(w_idx, end, 1, toks.get(w_idx, 2, 1))
106 |
107 |     # Algorithm: 11-24
108 |     t = 2
109 |     while t <= max_T:
110 |
111 |         sorted_word_idx = output_indices(toks, label_words, end, t - 1)
112 |
113 |         for w_idx in sorted_word_idx:
114 |             w_prime = prime_words[w_idx]
115 |
116 |             # 15-17
117 |             # if bigrams should be used, these lines have to be adapted
118 |             best_output_tok = toks.get(sorted_word_idx[-1], end, t - 1)
119 |             toks.set(w_idx, beg, t, Token(best_output_tok.score, best_output_tok.history + [w_idx]))
120 |
121 |             # 18-24
122 |             s = 1
123 |             while s <= len(w_prime):
124 |                 if s == 1:
125 |                     P = [toks.get(w_idx, s, t - 1), toks.get(w_idx, s - 1, t)]
126 |                 else:
127 |                     P = [toks.get(w_idx, s, t - 1), toks.get(w_idx, s - 1, t - 1)]
128 |
129 |                 if w_prime[s - 1] != blank_idx and s > 2 and w_prime[s - 2 - 1] != w_prime[s - 1]:
130 |                     tok = toks.get(w_idx, s - 2, t - 1)
131 |                     P.append(Token(tok.score, tok.history))
132 |
133 |                 max_tok = sorted(P, key=lambda x: x.score)[-1]
134 |                 c_idx = w_prime[s - 1]
135 |
136 |                 score = max_tok.score + log(mat[t - 1, c_idx])
137 |                 history = max_tok.history
138 |
139 |                 toks.set(w_idx, s, t, Token(score, history))
140 |                 s += 1
141 |
142 |             max_tok = sorted([toks.get(w_idx, len(w_prime), t), toks.get(w_idx, len(w_prime) - 1, t)],
143 |                              key=lambda x: x.score,
144 |                              reverse=True)[0]
145 |             toks.set(w_idx, end, t, max_tok)
146 |
147 |         t += 1
148 |
149 |     # Termination: 26-28
150 |     best_w_idx = output_indices(toks, label_words, end, max_T)[-1]
151 |     return ' '.join(words[i] for i in toks.get(best_w_idx, end, max_T).history)
152 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data files
2 | The data files for the **Word example** are located in `data/word/` and the files for the **Line example** in `data/line/`.
3 | Each of these directories contains:
4 | * `rnnOutput.csv`: output of the RNN layer (softmax not yet applied), with 32 or 100 time-steps and 80 label scores per time-step.
5 | * `corpus.txt`: the text from which the language model is generated.
6 | * `img.png`: the input image of the neural network. It is included only as an illustration; the decoding algorithms do not use it.
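The decoders expect `mat` with softmax already applied, whereas `rnnOutput.csv` stores raw scores. Below is a minimal sketch of loading such a file and normalizing it, assuming one time-step per row with values separated by `;` and a trailing `;` at the end of each row (the layout of `data/line/rnnOutput.csv`). The helper names `load_rnn_output` and `softmax` are hypothetical and not part of the `ctc_decoder` package:

```python
import numpy as np


def load_rnn_output(path: str) -> np.ndarray:
    """Read an rnnOutput.csv-style file into a TxC array of raw scores (hypothetical helper).

    Each row holds the C scores of one time-step, separated by ';'.
    Rows end with a trailing ';', so empty fields are dropped.
    """
    rows = []
    with open(path) as f:
        for line in f:
            fields = [x for x in line.strip().split(';') if x]
            if fields:  # skip blank lines, e.g. at the end of the file
                rows.append([float(x) for x in fields])
    return np.array(rows)


def softmax(mat: np.ndarray) -> np.ndarray:
    """Apply softmax over the character axis (axis 1) of a TxC matrix (hypothetical helper)."""
    # subtract the row-wise maximum so np.exp cannot overflow; the result is unchanged
    e = np.exp(mat - mat.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)
```

Run from the repository root, `mat = softmax(load_rnn_output('data/line/rnnOutput.csv'))` gives a TxC matrix whose rows sum to 1. The last of its 80 columns is the CTC-blank, so the matching `chars` string passed to a decoder has 79 characters.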
7 |
8 |
--------------------------------------------------------------------------------
/data/line/corpus.txt:
--------------------------------------------------------------------------------
1 | family, fake friend like of the the the
2 |
3 |
4 |
--------------------------------------------------------------------------------
/data/line/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/githubharald/CTCDecoder/6b5c3dd34944e5399a7308e241319b7f9c47e7c3/data/line/img.png
--------------------------------------------------------------------------------
/data/line/rnnOutput.csv:
--------------------------------------------------------------------------------
1 | 0.946499;-2.62179;-4.43488;-4.47514;-4.83325;-0.654089;0.645645;-2.1106;-4.7001;-3.45716;-2.88631;-3.78971;-1.78226;-0.0757015;-6.72823;-2.65625;-6.86296;-6.42798;-4.34241;-3.87848;-3.79938;-3.50903;-5.64675;-3.77184;-5.17855;-5.22721;-4.62519;-2.2997;-6.23319;-3.23933;-3.00241;-3.40662;-2.63433;-5.80278;3.44429;1.71859;1.00323;-2.35547;-3.81091;-1.1276;-3.78665;-3.46108;-3.52036;-7.31013;-5.08305;-0.378828;2.04177;-3.26892;-2.29267;-2.37863;-3.34981;0.819732;-4.95336;-3.69477;-1.71299;-3.20539;-2.05364;-4.22296;1.45303;-3.60028;-3.23615;2.15802;-5.90791;-1.62814;-0.0147691;-4.57057;-5.00541;-2.51477;-2.71035;-5.46196;-3.97351;-3.17576;6.20167;-3.90716;-3.51777;-6.86106;-7.5742;-2.19179;-3.57384;3.29054;
2 |
-0.174183;-7.75881;-5.61229;-9.04227;-7.92017;0.538565;-3.55974;-5.25639;-9.9701;-8.72368;-5.49197;-5.53916;-2.70453;-2.52938;-8.13962;-6.21383;-8.77007;-9.37458;-8.25591;-7.90203;-8.6263;-11.1362;-9.3865;-8.60642;-9.58042;-10.2776;-7.7682;-5.93898;-8.19541;-2.94996;-4.65495;-5.2892;-8.25348;-8.1374;-1.88367;1.27365;-3.69204;-6.20222;-5.76746;-3.7058;-8.28361;-3.31377;-5.43636;-9.61525;-5.92345;-3.04281;-1.38901;-4.51384;-4.61064;-5.55738;-7.93551;-3.7742;-8.23461;-3.87489;-3.35787;-2.58135;-6.31926;-2.24185;-3.72396;-6.54687;-0.822902;1.31318;-9.37465;-2.81736;-1.67882;-5.5786;-4.83753;-1.49305;-5.15088;-11.4087;-3.85501;-4.41235;1.31195;-3.02054;-5.15306;-7.46363;-12.0217;-7.28538;-7.92799;9.20873; 3 | -5.43106;-7.21524;-5.24688;-5.34667;-0.4919;-2.04607;-2.10752;-7.83262;-6.90699;-7.12708;-4.91211;-6.79647;-1.87965;-1.90614;-4.7872;-5.74314;-3.34322;-8.80112;-2.3913;-4.23099;-3.2957;-7.85354;-3.77785;-4.75481;-6.94819;-8.0857;-5.46562;-3.31103;-5.5902;0.395132;-4.27772;-0.451284;-4.81195;-3.1897;-2.48861;0.0927051;-7.02273;-1.83262;-1.50896;-3.06789;-6.58979;-1.04649;-4.46649;-5.37542;-0.656804;-1.98525;-2.08252;-3.22278;-3.13713;-4.83102;-4.24504;-4.37654;-5.7317;-0.379636;-1.89482;-1.02898;-3.64379;1.75101;-1.55682;-3.55522;4.58611;-0.553492;-11.0749;3.35529;1.52467;-1.10926;-1.71044;1.34437;-3.92288;-9.51402;0.847202;-3.02432;0.230103;-2.62194;-6.32555;-4.07167;-8.11674;-7.02032;-5.7656;-1.26966; 4 | 
-2.95039;-12.0488;-4.67461;-6.78099;-7.30997;-4.70425;-12.7195;-13.5184;-11.6151;-10.3488;-3.63271;-6.24153;-0.385705;-9.23882;-6.89611;-7.60288;-3.72016;-11.4937;-5.74643;-11.4496;-10.2566;-11.2092;-8.66557;-8.25463;-9.22235;-9.07677;-8.09764;-6.15397;-7.75287;-7.5582;-8.61403;-5.40712;-9.6946;-6.90089;-3.96626;-7.29179;-13.6775;-8.49311;-6.45663;-6.77553;-7.31827;-8.441;-10.5293;-11.0264;-2.1675;-8.49198;-7.95242;-10.1795;-12.5157;-7.04084;-9.83906;-8.34065;-10.3614;2.93558;-5.81743;-4.97269;-6.02182;7.00922;-6.37485;-4.40111;2.09331;-2.15846;-14.08;1.28135;-2.75915;-1.27837;1.97733;2.15122;-4.19958;-9.56989;3.36751;-1.41285;-1.88277;-0.452284;-8.77114;-2.94678;-5.14432;-3.9012;-7.46487;6.78853; 5 | -1.4645;-11.6449;-3.40534;-4.98474;-6.76356;-4.94564;-12.4363;-10.5024;-9.46557;-7.92521;-1.80041;-5.97612;1.35155;-9.53898;-7.03706;-4.28115;-2.73148;-10.6452;-6.69715;-11.1795;-10.1426;-8.27422;-9.21307;-8.08614;-8.91922;-7.78997;-6.46992;-4.93705;-7.64365;-7.01713;-7.06295;-4.16686;-8.99746;-7.67152;-3.50272;-7.44036;-11.065;-5.41942;-4.01568;-7.15701;-4.15922;-8.38443;-8.19981;-10.0731;-1.1008;-6.58393;-6.65714;-9.00841;-10.8595;-5.58711;-8.36389;-7.36706;-9.16931;2.20642;-6.19675;-3.6244;-4.80374;5.50555;-4.46572;-3.75827;-0.694222;-1.306;-11.2435;0.697968;-2.74837;-2.36696;2.54982;-0.98005;-5.21827;-7.30863;4.6855;0.53885;0.900379;-1.48607;-6.26543;-2.38535;-2.97961;-3.14003;-5.05537;9.122; 6 | 
3.73504;-9.64177;-2.24387;-5.16799;-8.94511;-2.80324;-10.9815;-7.24339;-8.70592;-7.43748;-0.142992;-5.38049;1.8198;-8.84219;-4.69742;-4.41884;-7.91524;-10.2392;-9.87543;-9.31653;-9.74584;-9.50039;-9.0893;-8.26333;-8.80123;-7.51767;-6.81662;-5.97463;-6.94283;-8.31248;-5.14046;-7.45118;-7.93749;-8.95051;-5.39009;-5.88786;-8.33229;-7.22362;-5.86372;-7.70749;-3.65585;-6.05738;-6.39031;-10.17;-3.78782;-5.37421;-6.21136;-8.71527;-7.4036;-6.2855;-8.39717;-7.52369;-9.05478;-0.742762;-4.64152;-5.19955;-2.14994;0.966221;-3.11853;-3.69805;-0.903308;-0.613298;-9.21649;-3.83024;-2.33663;-1.91337;2.14735;0.119307;-5.09303;-6.97529;2.94369;-0.21954;1.05736;-2.3397;-3.53675;-3.07207;-6.2505;-3.3467;-6.5013;10.4139; 7 | 8.77037;-8.89306;-2.51892;-4.69188;-8.59043;-3.15058;-9.26969;-7.75083;-8.71346;-8.96302;-0.537924;-4.38232;0.576832;-10.3493;-5.91059;-4.27766;-9.14683;-10.2449;-9.57614;-9.31443;-9.28328;-10.7112;-9.89854;-8.68585;-8.41428;-7.50034;-6.27192;-5.93108;-7.08003;-8.58337;-4.58738;-8.41256;-7.97809;-8.80979;-5.26605;-5.38534;-8.20178;-8.69414;-5.67459;-7.5723;-3.61853;-5.57477;-7.65353;-10.7836;-5.06569;-4.18815;-6.55822;-9.61374;-7.88867;-6.25673;-8.94785;-8.09962;-9.33272;-0.0722831;-4.35892;-5.86239;-1.32971;0.411561;-2.31502;-3.50077;-0.583069;-0.195578;-9.05505;-5.01366;-2.7876;-2.93195;1.89258;0.51684;-3.93723;-7.64826;1.33997;0.00452161;0.143988;-3.41613;-3.86057;-3.02418;-7.8471;-4.25601;-7.49024;7.61783; 8 | 
11.104;-5.91076;-2.2981;-3.40977;-7.8406;-2.08117;-7.54676;-8.03612;-7.47085;-7.86883;-1.90133;-0.78246;-0.44392;-9.84226;-7.08315;-5.26224;-8.94024;-9.69342;-9.37332;-8.3773;-6.47942;-10.0748;-8.92659;-8.81044;-6.51821;-6.60277;-4.23188;-4.96101;-7.16828;-6.4726;-4.75457;-6.97054;-5.49441;-6.08367;-3.21348;-6.33087;-5.31611;-8.31685;-4.20196;-6.01376;-3.7747;-5.27454;-5.80102;-9.24001;-6.08616;-0.435949;-6.02387;-8.2231;-6.48612;-6.39962;-8.77966;-6.60345;-8.0102;-0.213097;-3.21482;-4.95604;-0.900201;0.502656;0.862773;-3.30396;-2.34788;-1.00468;-5.87543;-5.11042;-2.12169;-2.52077;-0.694168;-0.606023;-3.58946;-5.88722;-0.72538;0.618049;-0.371501;-3.83559;-3.45131;-4.652;-8.21155;-5.08305;-8.40513;4.62927; 9 | 5.79349;-6.5222;-5.14419;-3.52444;-8.63415;-5.33517;-8.24045;-5.23785;-7.70759;-7.62126;-4.68251;-4.46211;-3.5313;-8.13348;-8.91575;-5.48489;-9.77029;-9.70536;-9.21914;-8.23799;-7.93878;-8.15675;-8.44091;-6.82234;-8.8701;-7.41378;-5.44876;-3.84219;-5.66541;-5.51608;-2.25172;-7.65413;-2.10473;-0.89852;-1.51582;-6.1916;0.384801;-7.0511;-1.88939;-6.01862;-3.33845;-7.10683;-3.68791;-7.68111;-6.10546;1.88457;-5.78216;-9.45017;-5.9569;-6.5809;-8.63667;-4.87126;-7.20477;-3.47835;-3.93;-3.71052;-1.45369;-3.25413;5.38125;-1.43989;-2.05025;-1.71057;-1.87423;-4.72956;-3.14646;-4.30579;-3.81183;-3.25257;-0.105351;-3.71721;-3.06538;-0.146202;-0.504611;-4.71869;-3.80483;-6.29996;-4.89122;-5.00511;-6.95227;7.89783; 10 | 
-2.74901;-4.69496;-7.07818;-2.11652;-6.26475;-6.14874;-5.63024;-6.28253;-9.4075;-7.41451;-7.01647;-8.32084;-5.48011;-6.43709;-9.37513;-5.08825;-10.5839;-9.93909;-8.33632;-9.23746;-6.13104;-7.51494;-4.10114;-5.29432;-6.52742;-4.92602;-5.07064;-3.23994;-2.20067;-3.4198;-0.842573;-5.99188;-1.31545;2.08121;-3.26139;-4.33308;2.07758;-5.67425;-1.30396;-7.36859;-3.65226;-4.03694;-0.141316;-5.1514;-3.89598;3.53481;-4.18244;-7.76483;-5.82593;-6.04277;-7.01025;-3.26865;-6.14922;-1.56558;-0.710393;-1.97148;-0.136052;-1.02525;8.57152;0.891234;-0.43523;-1.01188;0.33671;-3.58924;1.33689;-5.36834;-6.08914;-1.32345;1.47101;-2.62829;-1.68318;-1.98912;1.0193;-3.84456;-5.21764;-4.86894;-7.93091;-4.80645;-5.64151;1.40914; 11 | -4.05101;-9.60097;-6.94149;-3.90052;-11.0856;-5.05345;-13.8444;-13.3888;-12.3301;-10.6881;-5.71032;-6.06271;-4.38525;-9.91626;-8.36755;-11.7043;-11.2374;-11.6368;-12.3853;-13.4596;-8.58107;-13.7884;-6.65175;-7.29278;-6.4393;-7.05349;-7.84355;-6.63464;-5.71308;-8.70069;-7.46896;-9.83438;-9.33227;-3.80597;-8.91341;-9.88957;-8.98529;-14.2747;-9.96742;-11.4616;-9.79978;-3.23297;-6.92246;-7.56182;-6.70565;-5.84959;-8.99166;-12.0196;-13.1366;-8.43146;-11.5844;-10.5069;-11.0687;7.37453;-2.96914;-0.449531;-3.012;2.06084;2.12624;1.71564;-2.83752;1.69632;-4.25242;-5.02601;-0.982918;-4.33839;-4.376;6.414;1.69319;-2.95734;3.11079;-1.25612;-1.12002;0.0242273;-4.19803;-3.46791;-8.80281;-3.20862;-5.82567;5.24449; 12 | 
-1.82648;-10.9685;-4.60716;-6.1465;-11.2729;-3.27915;-15.4221;-11.4331;-10.0995;-11.2183;-5.35535;-6.26087;-4.02628;-10.1917;-8.07403;-9.87115;-11.4176;-11.6469;-14.0176;-11.322;-9.42;-13.5334;-8.50787;-9.59083;-8.09451;-8.89661;-10.3565;-6.40702;-6.64595;-10.9439;-6.32243;-8.01943;-12.1935;-9.4619;-8.95674;-9.95124;-9.92984;-15.2004;-12.1346;-9.99933;-7.72561;-4.19558;-7.75342;-10.8036;-7.18139;-6.90316;-8.68763;-11.184;-11.4792;-8.12882;-11.827;-10.1121;-12.0213;5.64021;-0.86789;-1.4254;-2.13311;1.873;-2.5282;-2.18906;-4.26227;2.462;-6.9952;-6.00165;-0.751997;-4.55824;-4.28635;4.9095;-0.317264;-6.94666;2.97817;0.5322;0.280396;0.995188;-1.10696;-3.63171;-6.67901;-5.31874;-6.72853;10.254; 13 | -0.339724;-9.00025;-3.24646;-4.03049;-9.74278;-2.11466;-11.7963;-9.56672;-7.08844;-8.39073;-3.67568;-5.6556;-2.73242;-7.7811;-7.94899;-6.03851;-7.76306;-11.3009;-12.159;-10.0661;-9.54729;-9.5666;-8.0348;-7.52071;-7.42862;-5.22987;-7.54611;-5.41313;-8.06713;-8.30314;-5.46902;-6.25378;-9.61342;-11.4389;-7.28363;-9.07246;-10.4846;-11.5436;-9.56965;-6.29173;-5.20313;-5.10918;-5.76979;-10.0773;-4.76397;-5.68403;-4.41745;-9.56475;-9.02766;-8.17135;-7.80484;-7.90023;-9.31773;3.7427;-3.17889;1.21698;-0.859533;4.66223;-2.8861;-2.67823;-4.97592;3.70711;-5.60132;-4.50825;-1.94637;-2.10677;-1.63363;3.54687;0.599399;-6.17295;2.59473;1.33118;1.64457;0.146806;-0.809978;-2.96042;-3.8567;-5.17732;-3.5907;6.77408; 14 | 
3.79636;-7.7167;-3.11199;-5.08308;-12.6165;-1.05462;-11.265;-9.73495;-7.62432;-11.3563;-3.78927;-3.90603;-1.21416;-7.6976;-10.5987;-5.80998;-10.0445;-11.8081;-12.3088;-11.6881;-11.8662;-12.5229;-12.105;-12.7059;-10.5594;-9.61692;-5.82978;-8.08727;-8.01799;-8.07273;-6.17231;-9.36649;-9.54504;-11.7415;-6.41013;-7.21958;-10.0159;-9.62904;-7.77113;-7.7901;-5.65472;-8.15487;-6.56516;-11.5108;-7.08519;-3.83889;-0.675303;-12.1589;-8.55212;-8.78561;-9.07797;-6.70052;-9.98271;-0.936671;-2.03956;-1.56998;-0.877334;-1.91158;-3.56299;-6.15572;-1.5578;3.12531;-8.67277;0.828247;-0.365892;-0.470086;-0.834604;-1.38856;-4.83202;-10.5119;3.48451;-0.0752825;5.69491;-3.04264;-1.26003;-5.93864;-7.65095;-7.66474;-4.12678;7.39455; 15 | -1.91415;-10.8629;-5.38742;-1.49218;-2.56216;-3.11388;-11.7016;-8.69537;-7.21209;-13.0206;-4.27031;-8.73194;-2.24534;-7.00612;-11.0634;-9.62481;-5.8464;-8.59466;-7.69313;-9.33784;-7.35519;-11.4711;-7.79024;-8.26803;-8.97365;-6.12209;-6.16864;-6.4024;-0.440737;-7.9433;-3.54608;-6.72101;-7.12049;-5.18172;-3.51406;-5.29207;-8.66198;-4.36116;-3.19881;-7.2604;-7.292;-7.54728;-4.90647;-9.119;-2.93923;-0.881316;-4.72306;-12.4661;-8.39194;-7.98957;-7.2797;-6.64637;-6.9604;-0.683689;4.16795;-1.45455;-1.07515;-0.170495;-0.0564147;-2.54999;5.77865;0.957258;-6.87962;8.76503;1.42369;1.53424;0.410303;-2.78718;-4.02396;-7.43153;-0.575593;2.15443;3.30801;-4.32443;-4.51466;-4.04791;-6.40461;-7.4839;0.248181;4.50313; 16 | 
-3.14314;-13.5289;-5.73667;-2.9643;-5.3401;-5.18285;-12.8567;-13.2812;-10.265;-13.396;-5.87743;-6.81899;-2.81654;-9.95403;-11.5593;-10.4523;-5.4462;-11.7426;-7.30941;-14.3794;-11.0508;-12.5504;-9.40463;-10.5031;-10.1428;-9.53961;-8.15985;-8.0597;-5.7747;-11.2246;-8.64935;-7.14293;-10.6275;-7.45586;-3.78526;-6.66636;-12.8152;-8.30802;-6.54655;-7.91695;-8.46068;-11.5003;-9.18255;-13.0509;-3.98907;-8.09415;-10.2229;-13.9323;-13.5698;-11.1468;-10.3986;-8.21144;-10.8624;0.997377;-0.715932;-3.71993;-5.09184;5.13454;-2.38353;-2.75587;3.07542;-1.65521;-9.06762;3.38827;0.1955;0.489612;1.83339;0.135434;-2.32138;-8.3672;1.16116;2.10593;0.804523;-0.196095;-5.70908;-3.84258;-4.12716;-2.59772;-3.69991;8.04616; 17 | -2.25038;-11.9568;-4.36028;-4.06735;-6.39706;-4.72664;-11.6558;-12.0435;-10.0542;-10.8035;-3.11878;-4.72236;-0.294151;-10.2855;-10.1443;-7.09513;-3.68754;-12.4933;-8.13424;-13.8028;-11.559;-9.75245;-8.60945;-9.70824;-9.09667;-8.62617;-6.73502;-6.54285;-7.52343;-9.78239;-8.50406;-5.07226;-10.7274;-10.1657;-4.13081;-7.4283;-11.8729;-7.47621;-3.70954;-6.62487;-6.20546;-11.8071;-8.69331;-12.5687;-1.50717;-8.62414;-9.26392;-11.2506;-12.9987;-9.16599;-8.8369;-8.02012;-10.2793;0.907807;-4.16956;-2.87711;-3.26093;6.78206;-2.78553;-3.17483;-0.179693;-1.26237;-8.0782;1.26038;-1.79033;0.790762;4.44505;-0.626725;-2.30937;-7.57798;3.6216;2.83501;0.0730027;0.613905;-5.9771;-2.71067;-1.19661;-1.29461;-4.43796;7.46922; 18 | 
0.219528;-10.1515;-2.78146;-6.04304;-11.2238;-1.82886;-12.3965;-9.5129;-10.6656;-9.34034;-1.19488;-4.3084;2.44956;-10.0849;-7.17153;-6.60649;-8.41006;-12.2398;-11.1515;-12.7436;-11.2708;-10.7199;-9.30087;-10.5133;-8.42804;-9.09536;-7.57009;-7.27954;-6.44114;-9.38247;-7.55528;-5.85978;-9.23366;-11.4899;-6.61195;-7.86712;-10.2777;-8.25523;-6.55579;-6.97075;-4.5444;-9.90678;-7.15661;-11.6441;-3.31205;-7.58245;-6.82079;-8.8043;-10.9398;-7.55806;-9.59244;-8.74401;-10.4507;-1.94326;-4.24704;-3.81583;-1.00853;3.80276;-3.39231;-4.13042;-1.7045;-1.53473;-8.72777;-2.46852;-2.74207;-0.149352;3.21061;-1.24593;-4.54082;-6.94356;4.26957;1.30051;0.133539;0.323225;-5.31951;-3.49604;-4.91556;-2.61229;-7.08817;10.5018; 19 | 4.72583;-9.94922;-1.9039;-6.66183;-12.1475;-1.79881;-10.4569;-8.8314;-10.3884;-10.4682;-0.441611;-4.42967;1.81738;-10.4748;-7.06295;-6.27093;-10.5138;-12.553;-10.5606;-12.2543;-11.8591;-12.2342;-10.1951;-10.4394;-8.36573;-8.58856;-8.5107;-6.68966;-7.36919;-9.86766;-7.32685;-7.72075;-9.30192;-11.1912;-6.52871;-7.71271;-10.0587;-9.63657;-6.98332;-7.14642;-4.7238;-8.78172;-7.97339;-11.5744;-5.06798;-6.95121;-6.62696;-9.27031;-10.4234;-7.7402;-9.57593;-8.98348;-10.5301;-1.80326;-4.63536;-5.17731;-0.893603;0.980775;-3.1565;-4.14834;-1.15513;-1.15288;-9.32712;-5.11252;-2.86136;-1.03347;2.49103;-0.563358;-4.38433;-6.95545;3.425;0.499773;-0.244409;-0.54156;-5.1072;-2.94032;-6.69866;-4.01538;-8.83002;10.0472; 20 | 
10.9201;-8.87213;-0.841528;-7.14993;-10.5281;-2.98176;-6.91782;-9.03458;-9.97309;-11.4374;-1.27378;-2.11493;0.152139;-11.7071;-8.46363;-6.22552;-10.5835;-12.962;-10.681;-12.2003;-10.8064;-13.2592;-11.4602;-10.9214;-8.15921;-9.29914;-6.35989;-5.26002;-8.35061;-7.9605;-6.90578;-7.7856;-9.65233;-10.1856;-6.43827;-6.16778;-9.75587;-9.79783;-5.71989;-5.98397;-5.10261;-6.58682;-7.75311;-12.2035;-6.33183;-4.84884;-6.77122;-8.85006;-10.3151;-8.1166;-10.2967;-9.48536;-10.5609;-0.367286;-4.7127;-5.6766;-2.00326;0.691469;-1.32465;-4.45115;-1.49526;-0.258967;-9.77248;-6.35061;-2.39967;-2.57729;1.64032;-0.0240059;-4.1488;-6.48793;1.81224;-0.229281;-0.807948;-2.53781;-5.30097;-4.46417;-9.79438;-5.63284;-10.296;6.26672; 21 | 10.3011;-7.31035;-1.47104;-7.36501;-8.47932;-2.78529;-7.13489;-9.02895;-9.40867;-10.7143;-4.14846;-1.2275;-3.5869;-11.0324;-9.98495;-7.94953;-11.6242;-12.3453;-10.1356;-10.7366;-8.45129;-13.8932;-11.4136;-10.7138;-8.82946;-10.3729;-5.18081;-4.49154;-8.47144;-5.44856;-5.62589;-6.51356;-7.23555;-6.11062;-5.0649;-6.01347;-5.64849;-8.09033;-4.58541;-4.84505;-5.73745;-6.03484;-6.74355;-10.8979;-8.39355;-1.33923;-5.91055;-7.94266;-8.57656;-8.48926;-10.2652;-7.49812;-10.3483;-0.869024;-2.32844;-3.34389;-2.52465;-1.17019;3.15907;-3.73074;-1.82584;-1.04174;-6.62005;-5.59318;-2.38308;-2.91302;-0.566659;-1.11456;-3.67084;-6.80137;-1.66461;-0.757485;-0.348378;-3.3935;-4.56314;-6.82836;-10.4655;-6.79775;-10.7753;6.0191; 22 | 
2.63703;-6.22214;-4.11437;-6.81045;-4.46452;-5.25139;-5.42593;-4.27247;-10.8021;-10.0078;-6.74516;-4.85562;-6.34029;-7.60887;-8.37108;-8.15432;-10.8799;-9.89515;-6.18506;-8.87715;-6.4883;-11.68;-7.1647;-6.22648;-7.88132;-9.01999;-5.52403;-2.85982;-4.93584;-1.88593;-2.38464;-4.52838;-3.84707;2.71375;-3.81867;-1.88965;1.52945;-3.91012;-1.06792;-5.92014;-5.77941;-4.38892;-3.5972;-6.48474;-8.36702;2.4764;-5.43715;-6.09635;-5.57727;-6.93499;-7.71363;-4.69459;-7.20823;-3.39392;-0.202776;-1.55566;-1.36904;-5.60699;7.84252;0.149181;0.881176;-1.67819;-2.64089;-2.92895;-2.31962;-3.99872;-2.48571;-3.40243;-1.2459;-4.28009;-5.15762;-2.54734;-0.582384;-4.40043;-4.36917;-7.31999;-8.95585;-5.94722;-7.74974;6.20985; 23 | -2.58583;-7.91171;-7.13424;-7.27709;-5.74095;-5.62783;-5.43986;-7.53874;-13.2965;-9.58853;-6.79907;-5.70106;-6.72346;-7.9833;-8.83189;-9.43903;-13.3207;-11.4541;-7.37611;-10.9576;-7.32609;-12.7904;-6.4292;-8.10153;-7.22822;-9.56314;-8.14174;-2.47695;-4.67653;-2.30201;-1.78639;-6.09681;-3.39958;3.64433;-5.31971;-2.40321;-0.354155;-5.34322;-1.91931;-6.00915;-5.29406;-3.08277;-3.60153;-5.98498;-7.86157;2.32235;-6.01944;-5.8826;-5.95641;-8.03371;-7.83052;-4.47993;-7.84446;-2.37673;0.240747;-1.33405;-0.440655;-2.84231;8.74028;1.3298;1.68328;-0.638313;-3.13961;-4.31082;-1.46126;-4.90579;-2.19748;-1.9358;-1.02446;-5.24171;-0.794992;-3.69855;-1.00141;-3.8835;-5.09584;-6.73415;-11.9339;-6.15064;-9.38314;3.1254; 24 | 
-6.66168;-11.777;-6.07662;-9.91141;-12.374;-5.37247;-11.4724;-11.4712;-13.3287;-10.9014;-4.77939;-3.08668;-6.02144;-10.152;-7.93688;-9.92502;-10.9381;-11.6743;-11.2827;-10.8708;-12.1413;-13.0975;-12.9741;-9.10894;-9.8693;-9.38042;-8.16785;-5.52341;-7.23948;-5.07055;-5.49947;-9.58444;-7.59544;-3.27004;-7.59501;-5.95665;-9.26788;-8.71867;-6.14839;-6.81738;-5.64876;-5.81762;-6.35069;-7.88545;-5.00672;-4.62355;-6.40042;-10.7369;-9.94587;-9.77058;-9.94634;-10.3302;-9.52645;2.40446;-5.19929;-0.230385;-2.41194;-2.46541;0.797367;-1.90867;0.393792;1.55851;-7.7359;-4.63999;-0.711079;-3.24812;2.72146;2.92298;-1.89792;-6.16944;9.44352;-0.216291;-1.33445;0.113525;-2.57085;-4.91125;-8.83035;-6.53203;-6.38003;4.07487; 25 | -5.60653;-13.8835;-6.38848;-11.2267;-16.1393;-4.89806;-15.7608;-10.0168;-13.5767;-11.949;-4.395;-5.02772;-5.346;-10.5379;-8.37236;-9.70311;-9.31819;-11.7025;-12.8201;-13.8956;-16.7999;-12.8794;-17.9536;-10.9172;-11.4402;-10.8965;-9.63053;-7.47009;-11.1154;-9.47961;-7.04854;-11.8825;-11.1747;-9.07842;-8.25262;-6.35443;-10.6179;-8.58157;-5.95559;-10.3184;-6.21486;-10.3842;-8.85094;-10.6703;-5.44844;-8.15258;-6.49949;-13.3855;-13.5231;-10.0773;-12.0426;-11.9127;-9.99785;1.67292;-7.15254;-1.11996;-2.67956;-2.39264;-2.67236;-4.59545;-0.349036;6.05079;-6.10085;-2.39027;-0.617805;-3.51474;1.04284;1.20014;-5.198;-8.10562;9.05943;-0.511233;-1.0449;-0.269429;-2.288;-3.33126;-5.92212;-7.3798;-4.74005;10.6626; 26 | 
-2.61183;-8.18409;-3.25027;-8.16381;-11.3393;-1.64679;-11.2654;-8.77576;-9.13037;-7.78491;-1.22498;-4.19671;-3.09514;-7.58071;-6.0313;-8.00516;-6.18346;-10.4861;-7.75485;-10.138;-10.2248;-10.9113;-12.9316;-10.6454;-6.69453;-5.62187;-7.96083;-8.09186;-11.9319;-4.43791;-8.87037;-4.68903;-9.05436;-8.5549;-8.54533;-3.28836;-7.52047;-5.82997;-2.2946;-10.0288;-5.63959;-8.46129;-9.35187;-7.82815;-7.92918;-5.2878;-5.10605;-6.84039;-9.91393;-8.78904;-7.87127;-9.16825;-6.30411;1.01135;-5.81945;5.2038;2.79669;2.39661;-3.13763;-4.3099;-3.08682;9.62888;-1.28573;2.27911;-0.367997;1.27859;0.56173;1.79072;-8.82676;-6.86467;3.97277;-0.00240856;0.0726139;3.82208;0.364106;0.0697557;-3.13218;-4.35839;-2.62286;5.27713; 27 | -1.57111;-7.82806;-3.72772;-9.50339;-15.2334;-0.091486;-13.2166;-9.91196;-12.5221;-10.9642;-0.267564;-6.97281;-3.05205;-8.85884;-8.85782;-8.61418;-9.29361;-11.9678;-12.4583;-11.9912;-15.5157;-12.9719;-13.9326;-13.8689;-10.037;-7.17214;-10.9786;-10.1687;-10.7863;-8.03037;-10.118;-6.5887;-11.1338;-12.4922;-11.0947;-2.6489;-7.83595;-8.63284;-3.45356;-11.889;-7.55052;-10.9896;-8.55091;-12.4702;-9.6637;-5.4721;-6.57506;-9.62086;-11.8576;-11.4348;-11.0256;-10.3536;-9.44878;-1.5653;-5.28105;1.93796;1.19473;1.88921;-5.52746;-4.00379;-3.14012;7.81066;-2.38663;0.698909;1.22039;2.66513;0.0262729;-1.60898;-7.42243;-10.4158;3.7513;0.87826;0.671533;3.35898;1.35407;-0.500593;-5.36102;-5.13008;-3.86437;9.64162; 28 | 
-1.41918;-8.09687;-3.97519;-6.20808;-11.3005;-2.67025;-11.0575;-12.0737;-8.78403;-8.23993;-2.19693;-5.28844;-1.91821;-8.61718;-7.61963;-8.2469;-4.31857;-10.4722;-9.48903;-10.3959;-12.7206;-9.11653;-9.47345;-9.85038;-8.45167;-7.38745;-6.89337;-5.48501;-7.54654;-7.44505;-9.86388;-0.9251;-9.2724;-9.92651;-10.3455;-4.41909;-10.3922;-8.58805;-5.57605;-9.88882;-6.96458;-8.61786;-8.72806;-11.2796;-6.56547;-8.61265;-7.81943;-9.65761;-12.9856;-9.92737;-7.89303;-11.8034;-9.15059;4.24111;-5.68858;2.3958;-1.11067;9.66665;-7.11262;-1.67529;-5.12305;4.62796;-5.5565;0.521545;-1.48784;3.4635;1.89594;2.17044;-4.21255;-7.48849;1.59989;1.41295;-1.78691;5.46007;-0.366628;0.657967;-3.05473;-3.19145;-0.920462;3.84387; 29 | 0.598002;-12.7952;-4.52884;-9.39001;-15.0774;-3.26512;-16.0687;-15.7246;-10.6627;-11.5577;-4.47863;-4.83981;-1.66146;-10.8306;-10.8071;-9.97911;-9.64963;-13.5559;-10.9419;-13.1462;-17.2723;-13.1889;-13.7943;-12.8441;-12.3314;-11.5276;-11.0896;-5.41711;-9.39115;-11.3964;-10.044;-2.85224;-10.4709;-11.9824;-11.2853;-7.74256;-12.1992;-12.554;-10.0264;-9.54792;-8.63838;-10.4612;-10.8832;-14.3864;-7.80922;-11.2259;-11.7145;-11.8281;-15.6764;-13.1463;-11.1925;-13.8592;-13.4335;2.81459;-7.36167;-1.01527;-1.88594;8.54943;-8.67862;-2.63347;-4.11473;2.32378;-10.439;-2.21032;-3.77266;0.763715;3.91391;-0.197764;-5.26254;-9.0529;0.733045;-0.20463;-3.36337;2.59105;-3.35371;-3.19358;-5.0553;-4.72048;-5.53473;11.7539; 30 | 
-2.04695;-14.1214;-1.43871;-6.03154;-13.4258;-2.95332;-16.5806;-11.5323;-10.1118;-11.0954;-4.95143;-6.68618;-1.55131;-10.562;-10.2279;-6.24756;-8.94924;-12.7044;-7.41591;-13.3229;-15.2343;-11.9146;-15.5641;-9.97782;-12.1835;-11.9904;-11.0484;-1.41067;-8.54326;-11.6696;-7.22381;-9.6506;-11.7011;-11.5324;-5.29649;-6.48729;-8.83213;-9.09167;-5.95074;-5.1494;-1.24605;-11.0603;-9.71149;-11.4251;-5.40164;-11.6918;-11.1035;-6.81373;-12.0257;-5.9442;-7.31571;-10.9192;-11.9354;4.32673;-5.2671;-2.97355;0.164217;3.7359;-8.95886;-2.78202;-1.73347;1.47579;-11.2458;-2.48979;-3.4778;6.89738;12.719;-1.37635;-3.19544;-8.27737;2.78773;0.811428;-2.14975;7.57546;-1.64001;2.81707;-4.04831;-4.45267;-6.29281;4.03548; 31 | -1.05644;-14.8446;-2.14157;-7.91169;-14.7699;-0.144551;-14.7385;-11.4427;-12.2054;-11.1474;-3.6955;-5.46155;0.224156;-11.0166;-11.8963;-7.37406;-7.5046;-10.5583;-10.0329;-15.2356;-16.3661;-11.6638;-16.0416;-9.49522;-11.9837;-12.5929;-10.1199;-2.65365;-9.77972;-11.5427;-7.47589;-12.3499;-12.168;-9.87899;-6.69725;-8.73628;-9.29615;-12.5346;-7.24789;-9.03491;-3.8208;-11.9648;-10.6705;-13.6822;-7.25;-12.026;-10.9373;-11.6162;-17.5958;-7.83926;-10.0893;-11.909;-12.665;2.8446;-6.4766;-3.67927;-0.885669;0.72274;-7.90203;-2.66323;-1.44227;1.27195;-11.2639;-4.67819;-2.001;1.82153;7.26258;-3.99336;-3.33785;-8.90583;0.453222;-0.559609;-0.374085;1.67825;-6.25939;-3.283;-7.01711;-5.90497;-6.13406;13.3576; 32 | 
0.214704;-8.38395;-3.97242;-8.32896;-10.8706;-2.87442;-10.0319;-11.7011;-9.51971;-10.6935;-5.71065;-5.51437;0.662933;-10.5114;-7.16723;-9.45302;-9.29984;-11.7428;-11.8748;-13.0755;-12.0731;-11.9551;-10.2097;-8.09889;-8.66657;-9.84895;-9.8289;-2.28961;-10.9076;-4.72462;-5.075;-6.96794;-7.93435;-6.99573;-7.67147;-8.71049;-8.94757;-12.16;-7.26202;-9.20646;-6.71508;-4.8316;-9.48642;-11.5467;-8.64494;-8.0236;-7.72768;-9.26058;-13.7441;-10.7322;-9.96135;-10.6483;-12.4675;3.52433;-7.43115;1.32101;5.75674;2.82627;-7.09359;-1.03502;-6.19744;0.764119;-10.3443;-8.52689;-3.00915;-2.17356;1.73326;1.31401;-2.84872;-8.36033;-1.27849;-0.163605;-0.872021;-3.07297;-5.69319;-7.3929;-8.57596;-5.97624;-7.79903;12.5935; 33 | -3.30299;-4.5005;-4.7835;-6.52498;-6.91661;-5.03191;-7.2701;-12.358;-7.61732;-9.30536;-4.13344;-3.90755;-0.210864;-11.0366;-5.94157;-6.95425;-6.57059;-11.1917;-7.23199;-11.6668;-9.0288;-9.84272;-7.92013;-5.69713;-5.48409;-7.78796;-5.35126;-2.74505;-10.4131;-1.23354;-2.85842;-5.7833;-8.11074;-5.67485;-5.81067;-8.20138;-7.39384;-8.98304;-3.34825;-7.31115;-6.63911;-3.35379;-10.6698;-8.27169;-7.45465;-6.23117;-6.69251;-5.98355;-10.7239;-8.5742;-8.15044;-9.58789;-10.286;4.74968;-6.57246;5.89378;9.69583;3.40773;-4.62601;0.337228;-3.28419;-1.12116;-8.69816;-3.82632;-2.38947;-2.2606;1.49044;2.68196;-4.3444;-4.944;-0.135577;-1.13456;-0.0957655;-1.67715;-5.97491;-5.74645;-6.28316;-4.19689;-7.30402;6.32099; 34 | 
-4.01744;-3.61099;-5.00467;-6.39848;-4.91129;-4.21442;-5.7317;-6.54228;-6.67337;-8.53652;-0.621217;-4.46856;0.26625;-8.59486;-5.81142;-3.52913;-3.25943;-10.0308;-5.48741;-12.1771;-9.34113;-8.18358;-7.08595;-5.62677;-7.82077;-7.80269;-3.33213;-4.06854;-8.44474;-4.41801;-1.21589;-6.13865;-7.82938;-4.28835;-4.53811;-5.88942;-7.1742;-6.53983;-1.18009;-6.53867;-5.51543;-5.1317;-7.83356;-8.68901;-4.32772;-4.65703;-4.33518;-5.85735;-8.8167;-6.99037;-7.88495;-8.56962;-8.96408;0.847567;-4.26314;-1.09546;8.91277;1.33979;0.811136;-0.462977;4.53814;-1.96742;-8.96371;0.180383;3.67574;-5.26836;-0.639808;-1.4238;-3.75895;-6.49564;-0.50039;-2.08098;2.71535;-2.24537;-9.51788;-6.2362;-7.1123;-3.8055;-7.13148;2.72276; 35 | 1.30652;-7.07422;-4.72248;-8.14722;-10.5733;-2.86174;-9.41452;-6.3417;-10.1369;-11.6667;2.38007;-4.94546;1.90639;-11.4965;-7.63327;-7.46678;-6.7393;-10.9255;-11.5776;-14.2619;-13.3146;-10.4501;-11.5629;-8.36563;-10.1731;-9.42328;-5.10059;-8.50585;-8.90604;-8.72451;-5.22695;-8.86622;-9.36213;-7.64778;-6.76578;-7.8665;-8.99866;-9.2445;-5.54043;-8.94209;-8.51505;-8.36668;-7.02736;-12.6107;-6.1624;-6.28979;-5.25586;-11.1611;-11.3942;-9.57777;-12.6473;-9.06217;-11.5723;-2.33475;-5.80609;-6.19833;3.16616;1.28727;-2.07532;-3.29691;1.9087;-1.80378;-9.21391;-3.18643;0.394085;-5.83865;-1.84323;-1.63405;-6.21277;-8.76049;-1.99443;-0.520725;1.72161;-4.61693;-8.66212;-7.69465;-9.89054;-2.60727;-8.56414;12.4437; 36 | 
4.1503;-6.13196;-2.22096;-5.85145;-8.68416;-0.911029;-8.13027;-4.11171;-6.706;-9.75702;2.99559;-3.96231;2.68962;-9.47816;-6.16834;-6.08152;-6.92989;-8.63072;-10.4515;-10.4355;-11.9438;-7.77602;-9.83333;-8.62671;-8.05141;-6.45754;-2.59953;-7.02851;-7.06374;-8.53676;-5.01188;-7.38702;-5.53189;-6.91954;-4.52375;-7.33511;-6.67743;-7.38663;-5.78904;-7.39508;-6.06957;-6.11659;-4.47192;-10.772;-4.87384;-4.19283;-4.41322;-9.60025;-7.85348;-8.0859;-10.2828;-6.97945;-9.63618;-2.56372;-3.67394;-6.18731;0.183922;-1.78963;-1.3154;-3.20253;1.01581;-1.84521;-6.40049;-3.51957;-0.42427;-4.11957;-1.8377;-1.75681;-5.28057;-7.17936;-1.97446;0.889029;2.52236;-4.58209;-4.98416;-5.65894;-8.91958;-2.61549;-6.49911;11.908; 37 | 7.58399;-6.76218;-1.24149;-5.15749;-8.38509;-1.56337;-8.50316;-5.22518;-5.82678;-8.42027;2.56944;-3.55403;1.74511;-8.87349;-5.11696;-5.28106;-6.59003;-7.50368;-9.23652;-8.08066;-10.7547;-7.45831;-8.32538;-8.24369;-7.08061;-4.88003;-3.71512;-6.59306;-6.43967;-9.30161;-4.95991;-7.29543;-4.47725;-7.24859;-4.09169;-7.21119;-5.76248;-7.76342;-6.02273;-6.95184;-4.83712;-6.03416;-4.81536;-10.5685;-4.83773;-3.73916;-4.90391;-9.79774;-6.84547;-7.04371;-9.20544;-6.84158;-9.04344;-2.29076;-3.21065;-6.80774;-0.621883;-2.45279;-1.89894;-2.67829;0.669033;-1.13435;-5.53428;-4.34989;-1.71072;-2.92714;-0.9586;-1.15665;-4.26667;-7.5852;-0.977576;1.53583;1.92465;-3.79709;-3.23326;-3.45911;-7.68658;-1.62149;-5.8123;11.0014; 38 | 
10.9315;-7.33023;-0.445602;-4.55104;-8.11792;-2.51893;-7.35032;-5.96008;-5.77239;-8.12898;2.44173;-3.70049;1.66324;-9.01117;-4.07142;-4.29342;-6.39797;-7.52561;-7.90602;-6.8401;-9.43267;-7.2889;-7.80742;-8.27176;-6.31107;-4.21691;-4.32673;-6.17475;-7.23557;-10.0353;-5.42717;-8.04405;-4.94523;-7.90706;-4.49634;-6.83043;-6.46715;-9.47504;-6.13306;-6.89303;-4.26674;-5.6524;-5.77022;-11.2125;-4.49639;-4.17917;-6.63544;-9.90462;-7.68341;-7.32045;-9.26769;-7.86936;-9.13858;-2.21602;-4.30291;-6.50045;-2.41773;-1.31217;-1.87191;-3.47403;0.849737;-0.15895;-6.93525;-5.05244;-2.19463;-2.93298;0.306508;0.413446;-3.71803;-7.57022;-0.285358;1.52209;0.8266;-2.81024;-3.98692;-2.0917;-7.61909;-1.91203;-5.9529;9.0404; 39 | 14.3825;-6.08338;-1.15894;-3.74865;-8.14332;-2.77266;-7.02627;-7.42051;-5.82253;-8.3884;1.04291;-2.49443;1.91475;-10.23;-3.76413;-4.79487;-7.74039;-8.01717;-8.42991;-6.85124;-6.66174;-8.30947;-8.41358;-8.39013;-5.43402;-5.20763;-4.4476;-5.49592;-8.9997;-10.1826;-7.73472;-8.97325;-6.56798;-9.27027;-7.02865;-5.70055;-7.51763;-12.4842;-8.33171;-8.01686;-5.75724;-3.22711;-6.01234;-12.4739;-5.74364;-5.42091;-7.8782;-10.254;-8.67432;-9.51242;-10.3052;-9.63225;-10.1212;-0.637224;-5.81618;-5.05929;-4.77665;-0.47556;-1.96222;-4.14186;-1.78534;1.23749;-8.28434;-8.11874;-2.9413;-4.23523;1.2578;5.99318;-3.28264;-7.45321;0.340803;1.24137;0.85491;-1.46071;-5.12973;-2.42331;-9.12795;-2.85814;-7.71195;6.4417; 40 | 
5.60581;-2.7102;0.398344;-0.367158;-4.32738;-0.961584;-8.75949;-10.82;-6.12726;-8.57663;-3.29553;-5.32383;0.755303;-9.59779;-1.30472;-7.17361;-6.94685;-6.83994;-8.64288;-7.24691;-3.21391;-8.76533;-6.40671;-5.64136;-3.87625;-4.12401;-5.63689;-4.297;-7.16138;-9.8399;-5.45003;-9.18551;-8.61124;-6.90068;-7.7779;-3.9921;-9.93893;-13.8309;-9.6474;-8.80901;-6.85045;0.456805;-6.64217;-9.92545;-9.2845;-8.45466;-7.88876;-9.6882;-8.84275;-8.38491;-9.34917;-10.242;-10.0514;5.09535;-4.97028;-3.25312;-1.01797;0.481231;-2.36877;-1.84533;-6.31753;2.00499;-9.35096;-10.2135;-3.54466;-2.31748;1.43741;13.2247;-1.37997;-5.97935;-0.230938;-0.140899;-1.91135;0.932147;-5.15276;-1.96519;-8.12753;-3.86837;-8.66114;4.91278; 41 | 1.79063;-4.73496;-3.8465;-4.37903;-7.56306;-3.23559;-7.77657;-10.2821;-9.62974;-10.4081;-3.94961;-7.18166;-2.04456;-11.4172;-5.15057;-9.02858;-11.4403;-11.3398;-11.7842;-10.8994;-8.01034;-12.728;-9.29184;-7.09091;-7.31195;-7.42555;-7.18686;-6.38122;-10.688;-10.4134;-4.68716;-11.7489;-8.09077;-7.4614;-9.61502;-6.17581;-10.6032;-16.0241;-10.8693;-7.37448;-7.20221;-2.37186;-8.30309;-10.652;-11.8753;-8.87515;-7.26859;-11.1504;-9.77983;-11.578;-11.5891;-10.6698;-12.7555;-0.450168;-5.71502;-2.40724;0.314078;-3.75708;5.07478;0.319205;-5.13818;-1.42739;-11.9218;-11.2217;-3.32196;-6.65939;1.16445;10.0387;0.81555;-6.87635;0.305304;-2.71941;-1.97468;-1.7493;-7.65634;-7.79799;-8.79688;-4.88473;-12.1559;9.14591; 42 | 
0.163707;-1.19202;-5.19302;-3.35456;-3.88782;-3.52383;-2.09194;-3.15726;-6.08952;-5.98657;-3.15926;-8.17366;-2.94329;-6.39207;-6.93584;-3.79378;-9.34596;-9.4351;-5.38045;-7.01798;-3.73478;-7.02231;-3.481;-3.07057;-6.50452;-3.92722;-2.30731;-2.5659;-8.93073;-6.7696;-2.99042;-6.58943;-2.79285;-2.92316;-4.9152;-2.88372;-4.06648;-10.077;-5.21384;-5.4725;-4.55604;-3.542;-5.26162;-7.24538;-8.09856;-3.70191;-2.62793;-6.30722;-7.06705;-8.28663;-6.62547;-5.98142;-8.4252;-2.84627;-2.90781;-3.7384;1.30607;-4.50866;14.8742;2.90146;-0.816093;-2.32623;-4.57724;-7.10666;0.766864;-7.10808;0.103411;2.51612;2.16459;-2.2637;-3.02651;-4.54008;2.01781;-2.61035;-8.86075;-8.33346;-7.73624;-2.14354;-9.73262;2.76951; 43 | 4.26204;-2.96896;-4.95045;-7.41241;-8.91052;-1.69205;-5.23628;-4.12192;-10.9335;-10.0884;-2.2054;-7.34737;-1.16168;-9.367;-8.62642;-6.75333;-11.4333;-10.5995;-10.4814;-9.42065;-7.55537;-10.4834;-8.04332;-7.78352;-8.6697;-6.99672;-2.59822;-6.38378;-10.5065;-8.55567;-5.59878;-9.88029;-5.38859;-5.3112;-6.62375;-5.46045;-6.67707;-11.7714;-5.61839;-9.25247;-8.88601;-7.19834;-7.92353;-11.8407;-10.5646;-4.23817;-3.95384;-10.0334;-10.3698;-9.43068;-11.2661;-7.68667;-11.6616;-2.93278;-3.67451;-5.56035;-1.41548;-4.67739;9.18809;0.575713;-1.59056;-1.72953;-6.86152;-8.81562;-2.22153;-7.18642;-2.2163;-0.944265;-2.58202;-5.95066;-3.57473;-3.6859;0.377838;-4.2487;-8.15276;-9.28115;-11.9675;-4.97004;-11.0773;9.56479; 44 | 
4.38944;-4.57293;-0.559667;-4.83286;-4.35083;-0.693625;-4.77391;-5.10266;-8.14407;-9.6857;-3.13351;-4.62062;-0.101405;-9.09035;-7.03548;-5.07332;-6.38319;-6.45245;-9.95489;-7.3613;-7.19011;-6.53571;-7.56193;-4.99245;-6.24246;-4.70936;1.86464;-5.48028;-6.13092;-7.27265;-4.80708;-7.88891;-3.2098;-5.27437;-3.71916;-5.41617;-6.82823;-8.83524;-6.8967;-6.29235;-6.51221;-5.60557;-4.34743;-9.46633;-5.79805;-3.62473;-3.50426;-8.41051;-6.75958;-6.98648;-9.32674;-6.67456;-8.84599;-0.294847;-4.06557;-5.1595;-4.71518;-3.26785;4.99773;-0.376357;-1.27629;-0.714138;-5.93882;-7.34986;-2.43584;-6.64168;-2.28924;-1.46908;-2.67373;-4.82748;-2.43764;-1.77199;1.62371;-4.5638;-6.56214;-6.65341;-10.463;-5.27738;-6.55317;8.54796; 45 | 7.60754;-7.28684;1.1046;-6.34259;-5.95386;-0.441784;-5.85103;-5.72821;-8.47172;-9.63356;-2.79703;-5.28298;-0.730401;-9.50762;-5.82528;-2.07551;-6.78833;-9.00875;-9.37784;-8.55157;-8.09969;-7.63429;-9.95678;-7.10658;-9.80579;-8.959;-2.18701;-6.13404;-8.02402;-7.6306;-5.49553;-7.78561;-5.72877;-8.99985;-3.16283;-4.63763;-7.38056;-8.95349;-5.54609;-5.66496;-4.69294;-6.53526;-5.31503;-12.3442;-6.31826;-4.65995;-5.46921;-8.38435;-6.41356;-6.73889;-9.48302;-7.02557;-9.58115;-1.34527;-3.89786;-6.80805;-3.81569;-1.52652;0.436487;-4.85946;-1.7046;-0.095901;-9.5899;-8.04714;-0.504783;-5.75264;-1.49127;-1.18039;-5.0234;-7.92637;-0.972971;-1.73052;2.12064;-4.09142;-4.38541;-6.62526;-11.7666;-6.25292;-8.2089;7.34673; 46 | 
11.2752;-6.62606;1.64656;-7.04787;-6.83202;0.991793;-5.49969;-7.40786;-7.4326;-9.95713;-3.62704;-3.87936;-1.45341;-8.07526;-6.89361;-2.13946;-7.24039;-9.04531;-8.77913;-8.09246;-8.24419;-8.70038;-9.60445;-8.05066;-10.2314;-9.17504;-4.15814;-7.13903;-8.73274;-6.10334;-6.01203;-6.58617;-5.5481;-8.83958;-2.35319;-4.35007;-7.06983;-8.7494;-5.98415;-5.0883;-5.07941;-6.22618;-5.60565;-12.5861;-7.0213;-2.30341;-2.94476;-8.64446;-5.91562;-6.37309;-8.99203;-5.57268;-8.77369;-2.22205;-4.2167;-6.71177;-3.84775;-1.98732;-1.80746;-6.01491;-3.06437;-0.6294;-9.22195;-8.12877;-2.28599;-5.48168;-1.97641;-1.80815;-5.60659;-8.72723;-2.6753;-1.8416;2.90342;-6.00857;-4.04612;-7.30316;-12.8831;-7.0193;-8.96778;8.04536; 47 | 4.71771;-5.46021;-3.14985;-6.21754;-5.83947;-1.84938;-2.62718;-4.48607;-6.25244;-8.27452;-5.5621;-7.64742;-3.38059;-2.64229;-7.52806;-2.4889;-6.84689;-8.0273;-7.7907;-5.13595;-6.70339;-7.03311;-6.87549;-6.28207;-10.1368;-7.57214;-6.62646;-6.67394;-5.76745;-2.29737;-2.91432;-4.63139;-2.88904;-6.58251;1.22493;-1.17219;-3.22555;-3.33601;-4.03967;-3.72071;-2.48635;-5.84473;-2.18863;-10.6673;-4.30118;1.85313;2.92449;-7.57684;-2.30888;-4.15749;-6.88213;-1.73496;-7.38892;-5.81298;-3.29994;-4.45328;-4.36557;-5.90696;-0.208059;-6.31909;0.151741;-2.12504;-6.7288;-3.44617;-1.08633;-5.28065;-4.83729;-4.76552;-4.77672;-9.30504;-3.75858;-1.60504;6.47302;-7.95158;-2.05836;-5.90064;-11.5918;-7.28354;-6.69685;6.18722; 48 | 
-3.26429;-7.051;-3.89595;-4.8059;-2.32071;-1.6067;-3.58452;-7.10989;-6.22797;-8.52742;-6.97882;-9.57641;-3.0488;-1.47494;-5.99267;-3.4206;-4.71615;-6.62238;-5.10333;-4.76566;-4.75795;-7.88139;-4.51747;-4.51089;-8.28589;-7.55345;-6.25931;-4.54527;-1.89438;-1.45584;-1.6941;-2.92864;-2.36722;-5.0328;2.39333;-0.641835;-4.82957;0.567151;-2.94025;-3.09932;-3.36625;-3.2024;-0.831491;-7.35717;-0.0976366;0.755473;1.72705;-5.51954;-1.48282;-4.81838;-4.69674;-2.24556;-5.90732;-3.40687;-1.82481;-2.70296;-4.1946;-3.93922;-0.376428;-6.18362;4.17437;-1.25003;-8.83221;1.88723;-0.257787;-4.7568;-4.17471;-2.683;-4.67022;-10.9454;-2.13975;-2.77558;3.49357;-4.8366;-3.56438;-4.84211;-10.9672;-8.2787;-4.9349;2.19827; 49 | -5.51693;-10.087;-7.15805;-8.23172;-4.93494;-3.20841;-9.14303;-11.6031;-11.7162;-11.5981;-7.63927;-8.46528;-3.84648;-4.27386;-3.81277;-8.72157;-4.52379;-8.91366;-5.43524;-7.85943;-5.00515;-11.8565;-3.46199;-6.61574;-10.128;-9.82275;-9.81911;-7.83052;-4.57429;-3.39341;-5.85469;-5.38954;-7.88681;-5.73217;-3.44333;-5.37649;-11.4279;-6.11017;-5.16085;-7.28571;-8.34679;-4.66452;-6.31002;-9.92484;-1.36741;-5.38741;-4.94805;-8.52685;-7.37537;-8.32143;-9.26059;-7.76605;-9.97148;-0.635853;-3.20425;-2.16625;-5.32822;2.07234;-2.49107;-5.05086;6.05377;-2.91001;-13.5815;1.11508;0.804056;-4.67689;-1.38708;1.12103;-4.32884;-12.177;2.2517;-3.09463;-1.23761;-3.57004;-5.53825;-5.08272;-10.0499;-7.17002;-7.95865;5.09165; 50 | 
-4.91205;-8.00757;-5.23158;-5.40514;-3.32578;-4.00721;-9.32786;-11.9167;-9.39086;-9.13833;-4.69842;-4.82084;-1.14515;-6.88937;-1.91428;-6.87586;-0.248343;-8.18634;-3.35298;-8.1715;-4.82112;-6.7616;-2.26752;-4.72095;-5.96855;-5.40213;-4.97571;-4.53024;-6.77869;-2.01962;-6.48869;0.024073;-6.4197;-5.47983;-5.5884;-6.69135;-12.6089;-7.74318;-5.30696;-7.31262;-7.09332;-4.30443;-6.34517;-6.71697;-0.205482;-6.85279;-5.51769;-7.33662;-10.0881;-8.0188;-8.2752;-7.43954;-8.10841;3.76055;-6.64242;0.0220879;-6.48163;8.36133;-3.75057;-3.71371;0.893807;-3.6114;-11.5339;-1.14995;-2.02586;-2.55225;-2.02246;3.19111;-3.41755;-5.67362;1.18608;-2.01652;-1.32911;-3.71053;-7.72944;-4.29925;-4.52548;-3.68181;-5.35751;3.73746; 51 | -2.15332;-8.23542;-3.10911;-6.07315;-7.32025;-2.37425;-10.9346;-7.92887;-9.14712;-7.33565;-1.74034;-4.57674;1.43286;-9.00453;-2.61947;-5.72484;-4.26626;-8.65507;-7.11331;-9.42592;-7.82117;-7.66287;-6.47754;-8.00399;-8.04494;-7.55981;-4.8983;-5.5511;-6.48653;-4.5779;-6.3166;-3.26204;-6.84507;-8.54726;-5.92297;-5.95518;-9.43484;-7.03212;-5.71708;-6.8168;-4.81992;-5.58886;-5.89062;-8.369;-1.75252;-5.40691;-5.46757;-7.25764;-8.05109;-6.56937;-9.26407;-6.56102;-8.81577;-0.177786;-5.9265;-2.70997;-4.35316;4.50273;-4.3336;-4.05458;-1.46566;-2.04827;-9.02862;-3.17173;-2.85405;-1.78897;-0.278964;-0.026608;-5.41153;-6.1586;2.02637;-1.48088;-0.424927;-4.11637;-4.56971;-4.38629;-5.23195;-3.63026;-5.32743;10.7869; 52 | 
0.439169;-8.79924;-1.88697;-5.72634;-8.91812;-2.42407;-10.1853;-6.90109;-8.44906;-7.30761;-1.28539;-4.80291;1.16245;-8.51105;-2.85648;-4.8301;-7.90151;-9.63988;-8.71823;-8.9428;-9.51434;-8.38711;-8.12198;-8.01608;-8.8984;-8.50463;-5.60816;-5.61496;-6.1597;-5.69384;-5.20133;-6.16635;-6.57409;-9.21711;-6.3947;-5.98211;-8.13085;-7.69361;-5.9074;-6.24666;-3.95599;-5.30255;-4.74889;-8.63247;-3.26659;-4.8545;-5.36067;-7.94277;-6.18803;-6.03405;-8.76049;-6.96284;-8.9392;-1.22689;-4.41176;-3.87841;-2.49934;0.582959;-2.81044;-3.8816;-1.53821;-1.32348;-7.99427;-5.17358;-1.3172;-0.940385;0.779427;0.463047;-4.23779;-6.07547;2.72124;-1.29677;0.875145;-2.77273;-2.31298;-3.48592;-6.79682;-4.20575;-6.04801;11.1848; 53 | 3.50265;-9.96125;-1.61336;-6.34027;-9.27911;-3.15634;-9.90522;-8.30799;-9.28821;-8.66817;-1.33536;-5.14115;0.259781;-9.63097;-4.14662;-5.24536;-9.2432;-10.9634;-9.1586;-9.44371;-10.5522;-10.3265;-9.82949;-8.98741;-9.36982;-9.69188;-7.02534;-6.28707;-7.48759;-7.02393;-5.14874;-6.93309;-7.34697;-9.96495;-6.29597;-6.51733;-9.10867;-8.10733;-6.03428;-6.77421;-4.61769;-5.57438;-6.23641;-10.1991;-4.40516;-5.48575;-6.0049;-9.42318;-7.05306;-6.5735;-9.47485;-7.78158;-9.94135;-1.28768;-4.76738;-4.61608;-2.34824;-0.168995;-3.44051;-4.61635;-1.07995;-0.763745;-9.37491;-5.32544;-1.42017;-1.80884;1.01397;0.708578;-4.7026;-7.162;2.56366;-1.16029;0.335844;-2.97826;-2.55366;-3.40761;-8.04519;-4.79041;-6.74094;10.8735; 54 | 
8.76021;-9.25283;-1.50268;-6.59155;-9.40011;-2.91189;-8.5404;-8.32029;-9.41104;-9.74768;-0.804845;-3.21146;0.498729;-10.5201;-5.70794;-5.61627;-9.5789;-10.7346;-9.06714;-9.41065;-9.87235;-11.205;-11.3486;-9.47019;-8.24569;-8.32477;-5.60033;-5.84601;-7.8905;-7.30551;-5.00775;-7.12011;-7.80896;-9.25103;-6.15506;-6.37501;-8.6277;-8.47757;-5.29317;-7.42227;-5.11919;-5.36795;-7.17554;-10.5384;-5.32185;-4.80388;-6.2359;-9.34232;-6.99744;-7.3075;-9.52903;-7.76258;-9.301;-1.0628;-5.14186;-5.2985;-2.31857;-0.538291;-2.84536;-4.82514;-0.714809;-0.0951438;-9.38226;-5.69545;-1.88066;-3.30025;0.986883;0.273752;-4.83207;-7.67339;1.04929;-1.08871;-0.368871;-3.87147;-3.44867;-4.51027;-9.19121;-5.21076;-7.08112;8.25428; 55 | 12.0723;-6.38048;-1.93699;-5.39648;-8.22012;-2.38943;-7.75447;-8.27535;-8.43385;-8.79904;-2.48053;-0.263091;-1.01214;-10.1931;-6.19539;-6.02468;-9.29897;-9.58426;-9.41295;-8.13448;-7.84086;-10.4672;-11.0606;-9.46733;-5.70979;-6.42391;-4.49145;-4.31351;-7.92305;-5.85951;-4.79817;-5.18967;-6.01629;-7.58688;-4.28899;-5.83683;-5.99286;-7.73645;-4.30936;-6.43589;-4.53846;-4.77101;-6.26514;-9.4448;-5.77055;-1.69802;-5.96137;-7.5736;-5.65989;-7.38382;-8.90563;-6.85279;-7.79544;-1.05181;-4.70549;-4.5936;-1.4423;0.517887;-2.03065;-5.02318;-2.09946;-0.906546;-6.89234;-6.05042;-1.77492;-3.76598;-0.112576;-0.445292;-4.24494;-6.82925;0.18884;-0.267762;-1.27006;-5.08769;-3.27271;-5.78736;-8.84772;-5.43841;-7.55764;5.58733; 56 | 
9.94845;-6.18689;-3.37763;-7.40653;-11.801;-2.62929;-10.6034;-9.24154;-9.60553;-10.1943;-4.2219;-1.70677;-3.58157;-9.97716;-8.88302;-8.69699;-12.4678;-11.8695;-10.9103;-10.7226;-11.6568;-11.8476;-13.8165;-11.9602;-7.34964;-7.94899;-6.1961;-5.67901;-8.69237;-5.55467;-5.03923;-5.81065;-3.94783;-5.22694;-4.21069;-7.18087;-1.71432;-7.38656;-4.9655;-6.39151;-5.75951;-7.50873;-6.81991;-8.74065;-7.46633;-0.924356;-5.38536;-9.27936;-6.22838;-9.15857;-10.4221;-5.03578;-8.90371;-4.08884;-4.8832;-2.87448;-0.804467;-1.15305;-0.0795751;-4.48721;-3.01288;-2.25461;-4.97956;-5.89072;-3.43081;-4.20077;-2.26155;-3.25468;-2.46979;-7.56215;-1.71552;-1.45234;-1.23611;-6.17145;-4.08928;-8.65507;-7.72952;-5.72645;-9.72422;9.44342; 57 | -1.71786;-2.68104;-6.53987;-4.44099;-6.1585;-5.90823;-5.81044;-3.87426;-7.37832;-6.71942;-6.20191;-5.04573;-4.22779;-4.42225;-8.91268;-4.72301;-9.60981;-7.95759;-6.85133;-7.67707;-8.20877;-5.25206;-7.00563;-5.73664;-4.02507;-4.46881;-4.36464;-1.94937;-3.66415;-1.74304;-1.20583;-3.57909;2.61108;0.941365;-0.725611;-2.82884;7.58445;-2.09407;0.66665;-4.62827;-3.60713;-6.58017;-0.851764;-4.51779;-3.56776;3.64318;-2.29567;-6.67745;-3.5392;-6.39992;-6.28508;-1.71243;-4.42587;-4.00904;-2.01223;-0.781333;2.03672;-2.60997;7.78976;-0.848377;-0.886986;-2.86807;1.31298;-3.75718;0.523734;-3.01592;-5.03434;-4.17353;1.3663;-3.76105;-0.799278;-2.21962;2.84886;-6.01659;-3.2779;-7.55284;-5.02941;-2.9487;-4.24583;-0.133237; 58 | 
-4.84821;-6.80229;-6.88376;-6.46795;-9.32538;-3.92608;-11.7593;-12.5057;-11.9679;-9.97584;-5.79771;-5.24578;-4.55287;-8.22169;-6.51734;-11.2222;-10.9704;-11.2254;-11.1823;-11.4875;-8.92473;-13.2884;-5.84803;-7.97685;-5.09056;-7.98389;-8.77259;-5.70794;-5.08791;-4.37019;-5.87046;-3.76091;-6.34226;-4.15545;-7.84025;-6.97871;-4.38163;-9.95987;-5.81529;-9.56563;-8.3489;-2.81103;-6.56516;-6.26873;-7.18966;-2.45119;-7.43715;-8.66581;-9.61766;-9.84874;-11.5403;-9.3017;-9.56583;3.82451;-2.58052;0.483593;1.04592;3.29296;0.546952;-1.14348;-2.89508;-2.1574;-4.92958;-6.83386;0.208372;-2.05935;-4.65606;5.26791;-1.32916;-4.81623;2.95921;-2.58317;-0.885145;-1.63318;-4.78686;-6.3269;-9.05433;-3.11434;-8.69488;3.73889; 59 | -3.74606;-9.18359;-5.99202;-6.68675;-9.84917;-5.7722;-14.6133;-10.9003;-11.2555;-11.8849;-5.54954;-6.33805;-5.64745;-8.78059;-8.58608;-9.12326;-8.62399;-12.2273;-9.96507;-12.3773;-11.5951;-13.4281;-9.38884;-7.25397;-9.26993;-9.31587;-8.62357;-4.5359;-5.59137;-7.38586;-5.76682;-5.63644;-10.3339;-5.53446;-9.13573;-8.38488;-8.64058;-11.3263;-8.29048;-8.90471;-7.72083;-4.70893;-7.29312;-7.27789;-6.44322;-6.27266;-8.85328;-9.06322;-13.2989;-10.0039;-11.9606;-10.8004;-10.2379;6.6436;-1.73412;0.121718;-0.403001;0.97513;-1.30054;-0.727065;-2.30231;-3.02184;-7.8003;-5.86724;0.0716836;0.214505;-1.14491;3.49866;-0.854948;-2.18879;3.71993;0.391807;-1.15612;0.664169;-3.6861;-4.60341;-4.93553;-3.53957;-7.01636;9.1739; 60 | 
-2.72474;-8.88241;-4.43391;-5.71908;-9.95407;-3.75614;-13.0526;-8.80238;-9.93238;-8.57938;-3.08653;-4.51517;-3.75237;-8.20355;-8.65439;-5.83627;-5.33644;-11.7947;-9.85274;-11.6891;-12.9475;-9.50611;-10.7726;-7.59126;-8.04109;-6.19958;-6.4179;-4.4051;-6.92885;-7.92382;-6.95096;-5.23991;-9.72631;-7.90039;-8.09865;-6.62023;-8.40363;-9.04439;-7.18982;-6.91662;-5.21854;-6.52199;-7.70131;-8.35074;-3.99618;-8.53057;-7.69004;-7.28726;-13.0144;-10.0267;-8.50756;-9.58588;-8.70802;6.09924;-3.39407;0.61461;-1.80505;5.99172;-4.44902;-1.60595;-2.44853;1.02197;-4.80081;-3.86206;-2.43511;-0.104913;1.67159;2.73646;-1.81064;-3.03553;3.81464;0.321577;-2.76612;3.77704;-3.3273;-3.06151;-1.85381;-2.8037;-4.00394;6.76757; 61 | -0.442047;-10.454;-4.347;-8.42487;-15.4764;-4.36603;-13.727;-11.9334;-10.2907;-10.5498;-3.58871;-2.9164;-2.35208;-10.0591;-11.0827;-6.27429;-10.4124;-12.5509;-12.7723;-13.9277;-17.5486;-12.5773;-15.46;-10.7796;-11.089;-9.51689;-9.37516;-6.84972;-8.55485;-8.56196;-8.18076;-6.84052;-9.14681;-10.9302;-10.5271;-7.11077;-8.29558;-10.3812;-8.35239;-7.94304;-4.75423;-10.1979;-8.7111;-10.2253;-8.26046;-8.4159;-6.74922;-9.32988;-12.0695;-10.148;-10.1164;-8.91662;-10.6862;2.37396;-5.5211;1.15806;-1.18019;1.5953;-4.82548;-2.1985;-6.40686;1.30554;-4.31129;-5.35974;-2.57423;1.79342;2.48492;-0.562477;-3.74716;-3.48249;7.26912;0.952392;-0.795247;1.13878;-0.422036;-1.67238;-1.55319;-2.57468;-4.14032;11.6908; 62 | 
-4.40926;-9.59416;-1.61614;-5.61493;-14.6056;-3.4205;-14.4952;-11.6982;-10.995;-11.0287;-3.7495;-4.06814;-1.03573;-10.0438;-9.10201;-7.22904;-10.7287;-12.4954;-8.72694;-12.8497;-13.7897;-13.4891;-14.8907;-9.33821;-8.13048;-7.53076;-10.992;-7.53934;-8.07937;-9.44754;-7.90833;-7.81581;-11.5574;-11.2885;-9.91689;-6.10864;-7.82156;-9.86835;-8.62083;-6.11121;-3.39412;-9.92527;-10.4088;-9.39708;-10.0855;-10.8223;-8.61012;-5.16149;-9.95555;-5.91341;-7.92269;-9.27874;-9.87487;3.59397;-2.48583;1.59045;0.0793994;2.63587;-7.99713;-1.62414;-6.5201;2.91766;-5.70171;-4.10977;-2.30339;9.96571;7.92442;0.746004;-3.40604;-2.62274;4.80792;2.36983;-1.05436;6.39364;0.250724;4.26391;-2.33511;0.0352806;-5.06661;4.04944; 63 | -1.58118;-13.1528;-2.21491;-7.07725;-16.3573;-2.32583;-16.8238;-12.818;-13.353;-12.8659;-2.96268;-4.3699;-1.3063;-11.1731;-12.3575;-10.2983;-11.4197;-14.547;-10.7074;-16.463;-18.4438;-14.9375;-17.3827;-9.30919;-9.91849;-10.6298;-11.3424;-7.81235;-9.27385;-13.3237;-9.3905;-11.9545;-13.2234;-11.4236;-8.72427;-10.1095;-10.8187;-12.0044;-11.394;-5.87836;-5.39645;-13.3114;-11.1108;-12.3182;-11.9299;-13.106;-10.2528;-9.04151;-15.5798;-6.81846;-9.32944;-11.8285;-12.161;1.86584;-3.35238;-1.98765;-1.6622;1.6598;-10.0565;-3.03414;-2.98597;3.22007;-9.21711;-3.00045;-3.59601;9.18024;8.04331;-1.22237;-3.32885;-4.26133;2.03499;0.507935;-1.5308;5.84944;-4.5702;0.796076;-5.77124;-4.15474;-7.17413;12.1691; 64 | 
-1.81242;-15.0607;-2.00468;-7.41436;-13.2505;-3.0505;-13.0732;-11.3286;-10.3321;-10.8263;-4.68075;-3.47334;-1.37323;-10.7115;-11.4361;-6.92377;-9.04271;-11.7716;-11.8949;-14.9887;-15.3369;-12.7997;-15.8086;-8.67624;-10.8524;-10.0209;-9.79663;-6.25682;-7.49317;-10.8302;-8.84801;-9.06244;-10.9744;-9.20288;-7.93044;-9.87892;-10.0118;-10.7094;-8.99632;-7.77721;-5.46524;-12.6212;-10.5877;-11.9878;-9.36292;-11.4128;-9.76996;-8.97375;-14.4968;-7.99691;-7.95643;-12.3435;-10.8812;2.6518;-3.23291;-1.3413;-1.04923;1.88966;-7.7214;-4.07959;-2.68624;6.14227;-8.04386;-3.35795;-2.96963;2.56325;6.43852;-3.44594;-3.10309;-5.91175;2.05894;1.34258;-0.258543;5.78544;-4.21986;-1.86233;-6.14877;-5.71566;-4.8782;13.0647; 65 | -1.48873;-8.74647;-2.89081;-7.60319;-10.8277;-3.80391;-7.33196;-10.6377;-5.47351;-7.72513;-3.91283;0.947234;0.310964;-7.88235;-6.84502;-6.68746;-6.46225;-8.89115;-9.32185;-10.992;-9.5152;-9.26485;-9.4446;-7.49673;-3.97152;-6.36453;-6.96163;-4.64578;-9.18917;-5.60779;-7.92649;-5.51132;-8.72195;-7.47939;-7.38916;-8.56679;-7.79021;-7.62136;-5.36094;-6.35044;-6.29073;-7.3722;-10.8638;-10.2539;-8.61613;-8.30165;-7.32725;-8.39399;-9.75719;-8.90919;-6.31186;-8.67446;-9.37572;5.36387;-5.13298;4.70363;4.66105;2.79329;-7.01011;-2.51176;-5.74943;4.87994;-7.15892;-2.18288;-2.78652;0.461314;2.34188;0.629925;-3.19662;-6.14668;2.56055;0.897646;0.994372;2.57631;-0.668117;-3.719;-3.3952;-2.91794;-2.09979;5.54654; 66 | 
-3.05357;-8.0613;-4.34386;-8.31737;-10.4172;-3.09113;-8.84247;-11.7488;-6.995;-8.81334;-4.37095;-0.624149;0.707711;-9.41246;-7.408;-6.78376;-5.28633;-10.2915;-9.80559;-13.5323;-11.9209;-9.87789;-10.1824;-9.47386;-5.36421;-7.9986;-6.60148;-5.09664;-10.5799;-6.09848;-7.17806;-7.02237;-10.6435;-10.3006;-7.64846;-9.44891;-8.49202;-9.6943;-4.56899;-4.52822;-6.5693;-7.64576;-10.6282;-11.5107;-8.28497;-8.4014;-8.31999;-8.60686;-10.9517;-9.96594;-7.76102;-11.1317;-10.2038;4.40534;-4.48708;5.55202;4.37723;4.57295;-6.76041;-3.36448;-3.40796;3.14761;-9.35613;-3.53136;-0.828098;-0.166254;1.51437;0.144296;-3.12228;-8.08896;3.31434;-0.548648;1.25742;1.00095;-3.62184;-5.12811;-5.53374;-4.28251;-3.95996;6.19248; 67 | -0.382353;-7.41626;-4.58206;-11.0945;-12.9862;-2.37062;-7.91162;-9.7853;-8.9677;-13.6254;-5.45748;-2.34447;1.07068;-8.97232;-8.69016;-6.08986;-7.49183;-13.3865;-10.0474;-15.6262;-15.1538;-13.3134;-14.0395;-9.52931;-11.3887;-11.7876;-7.01265;-5.13585;-11.3183;-7.07952;-7.22732;-10.3115;-11.4355;-9.32228;-8.70173;-8.54751;-8.84941;-9.06527;-5.55045;-4.96896;-6.71878;-8.76668;-9.00289;-13.5239;-7.47947;-8.12975;-7.4154;-10.0531;-11.5597;-11.2884;-11.543;-10.8333;-12.658;0.300929;-1.86502;0.359306;2.97079;0.481722;-3.71824;-5.80505;-0.324495;0.962227;-12.3853;-2.74975;3.29805;-2.52127;0.421442;-1.11188;-4.80662;-8.88099;0.884087;-3.37412;0.607943;-2.06484;-7.82315;-9.94149;-11.3339;-5.70613;-8.87777;10.7865; 68 | 
-1.77065;-4.52682;-2.95733;-7.45064;-7.24477;-4.34886;-2.45925;-5.53327;-6.83809;-11.6832;-6.21743;-6.39503;-1.38155;-4.46887;-5.23629;-3.08169;-4.30559;-9.1139;-4.62718;-12.0212;-8.82484;-9.51991;-8.92163;-5.78287;-9.27582;-8.40382;-3.48031;-4.00467;-3.96474;-3.44257;-1.94046;-5.64882;-8.01861;-4.22043;-5.35581;-6.66648;-5.32161;-4.05255;-2.12668;-1.74934;-3.2566;-5.74434;-3.92866;-9.44621;-3.47826;-5.38866;-3.24964;-5.98814;-4.21527;-6.81448;-7.26148;-3.26887;-10.0336;-1.5964;3.48847;-2.54201;4.44159;-1.64106;-0.506722;-3.50732;2.56648;-0.289785;-7.23135;1.04936;8.71007;-2.00848;-1.60932;-3.07244;-2.37348;-6.7104;-0.605108;-3.96368;3.07639;-4.32271;-4.34221;-7.22499;-9.74027;-1.62583;-4.35866;4.14686; 69 | -2.26371;-8.52026;-4.06571;-8.91766;-10.8481;-5.59726;-7.24681;-7.65409;-11.1615;-14.3048;-6.82449;-6.93963;-2.27718;-7.95674;-8.24054;-6.25638;-8.21961;-11.8337;-5.36274;-14.0392;-12.7251;-13.0934;-12.9838;-8.5382;-12.7259;-10.8824;-7.27301;-7.50892;-5.35495;-7.93295;-6.00372;-8.38499;-11.2183;-5.84681;-8.00854;-9.24167;-8.74575;-8.55199;-8.48993;-5.50222;-6.58432;-10.6191;-7.45117;-13.5411;-7.69342;-8.53687;-6.45692;-9.77185;-7.27372;-10.6285;-11.4826;-2.57784;-13.9906;-2.64246;1.0687;-7.20715;0.0380836;-0.495356;-3.75568;-4.92157;-0.437219;-0.401937;-10.3033;-1.87231;7.38611;-3.74032;-2.62739;-3.3776;-6.11941;-9.6038;-0.807573;-2.49644;2.13298;-4.60939;-4.10506;-7.38937;-8.88385;2.19262;-8.11895;11.9355; 70 | 
-5.11903;-3.89078;-1.74213;-3.6984;-8.68732;-4.33156;-4.72532;-4.97799;-5.65295;-7.90011;-1.67583;-4.92681;-1.03412;-6.52583;-5.93965;-2.62252;-3.7994;-8.60609;0.239039;-9.40139;-8.13873;-6.32511;-8.18217;-2.71118;-5.90465;-2.28529;-3.81817;-7.32626;-6.06063;-5.47794;-5.7548;-6.33552;-8.89681;-4.40692;-3.78587;-5.14247;-5.95283;-5.04271;-5.68144;-4.91966;-3.59875;-6.60624;-7.17155;-8.0469;-6.52257;-7.41166;-3.36982;-2.85732;-2.24707;-7.18631;-4.6867;2.03774;-8.24462;-1.03622;-2.62462;-2.9289;-1.96469;-0.279306;-2.52993;-0.437633;-1.76054;-0.735979;-4.53994;-0.346648;5.48934;-0.541808;1.14001;-0.616795;-1.90344;-4.13068;1.65354;0.862994;1.69461;1.31125;-3.24117;-2.59483;-2.96673;11.8711;-3.88756;0.550133; 71 | -1.78578;-4.45747;-1.88758;-6.51053;-13.3141;-2.15958;-7.02146;-4.16594;-8.11226;-8.7687;2.63492;-3.48664;0.699619;-7.72891;-7.02512;-5.00085;-7.08899;-8.82282;-2.98101;-9.95036;-9.85291;-7.33317;-10.6041;-5.11731;-7.89041;-5.38171;-5.6129;-9.9495;-8.55618;-8.52362;-8.27653;-8.88022;-8.46622;-8.61965;-5.31024;-7.97058;-6.64842;-7.49623;-7.42334;-8.44511;-6.95176;-10.4411;-7.42221;-11.301;-8.74568;-8.42354;-3.17447;-6.18557;-4.91719;-7.63005;-8.391;-0.293865;-9.85907;-3.78323;-4.25986;-5.41175;-1.75025;-2.71477;-3.50727;-0.943469;-2.37264;-1.55002;-4.42646;-3.11601;0.0196887;-2.01201;-1.23081;-3.46836;-2.40756;-5.0256;1.71261;1.37151;0.485926;-1.88353;-4.12118;-4.07582;-3.58265;10.6173;-4.67136;9.22291; 72 | 
0.811192;-4.99824;-1.00005;-6.32959;-9.3314;-1.00206;-6.86652;-3.79687;-7.43771;-8.46113;3.8092;-3.60292;2.70421;-7.65981;-5.45213;-5.05417;-5.72501;-8.85681;-6.05785;-8.14047;-9.59892;-8.01267;-8.03565;-5.06906;-7.17648;-5.30873;-3.662;-7.1901;-5.02384;-7.58822;-6.20666;-6.30363;-6.33644;-7.43779;-4.94326;-8.08934;-5.16147;-7.07841;-4.96049;-7.71846;-6.58282;-7.44335;-5.58477;-9.72297;-5.31251;-5.40201;-4.12293;-7.10098;-6.3755;-5.72685;-8.00266;-4.47423;-8.79445;-2.79498;-2.10038;-5.21386;-1.1643;-2.31094;-1.95503;-1.46459;-0.324965;-2.19785;-5.07011;-3.1581;-0.848963;-1.39076;-1.21108;-3.18451;-2.52788;-4.15262;-1.43719;1.14357;-0.289242;-1.69537;-3.7064;-3.20255;-7.15366;3.79671;-4.90886;11.2314; 73 | 1.16454;-6.40622;0.147738;-5.41025;-7.9782;-1.17707;-7.09435;-5.39584;-7.96582;-8.06883;4.95415;-4.64794;2.47547;-7.53107;-4.71383;-3.2762;-6.10814;-9.07728;-6.07198;-6.76577;-8.07248;-8.08025;-7.10642;-4.59351;-7.57181;-5.12354;-4.87761;-5.64715;-4.80374;-7.68578;-5.85847;-6.35626;-6.75865;-7.4131;-4.69525;-6.37198;-6.59547;-6.93397;-3.80558;-6.70857;-5.12769;-6.32986;-6.02953;-9.79763;-4.12765;-4.65136;-5.22835;-7.07321;-7.11982;-5.63465;-7.76554;-6.48105;-8.73887;-1.8465;-1.83038;-4.66056;-1.76651;-1.79453;-2.26284;-2.17231;-0.301765;-0.139107;-6.75352;-3.45146;-1.4549;-0.348052;-0.18541;-1.91481;-2.92295;-4.48237;-0.152061;1.80947;-0.508051;-1.10239;-3.43112;-1.47262;-6.63873;0.398282;-5.86164;10.1652; 74 | 
1.57339;-6.57055;0.611412;-4.87241;-7.81042;-1.42799;-6.27518;-5.18844;-8.21637;-7.94237;6.61062;-4.13164;3.07297;-7.06397;-5.22755;-2.50046;-5.61264;-8.48237;-6.03162;-6.67254;-7.06973;-7.25118;-7.62794;-4.37717;-8.15986;-5.02214;-5.48306;-5.88616;-6.13163;-6.73427;-6.31061;-7.52767;-7.19737;-8.40517;-4.50012;-5.12182;-7.88293;-6.83605;-2.20515;-6.49537;-5.17743;-6.66774;-5.5007;-9.93966;-3.69806;-4.22647;-5.43203;-7.28543;-6.37888;-5.29853;-7.61563;-7.03326;-8.36192;-1.78886;-2.56412;-4.18059;-2.59852;-2.08963;-3.48061;-3.21563;-0.350471;0.43935;-7.11815;-4.02832;-1.22004;-0.132688;0.552981;-1.82803;-2.7369;-6.20797;2.18345;2.62902;-0.547699;-1.00478;-2.778;-0.330646;-6.37849;-0.546834;-5.54716;7.98122; 75 | 3.83802;-7.26761;0.684772;-5.18076;-8.20416;-1.23053;-6.58348;-5.65724;-8.8548;-8.78558;6.37962;-4.80568;2.23737;-7.77691;-5.32746;-3.53521;-6.37187;-9.19616;-6.50445;-7.54603;-7.4574;-7.75244;-8.83824;-5.69124;-8.5807;-4.59638;-6.22602;-6.37216;-7.25272;-7.20618;-7.7695;-8.33765;-7.98917;-9.81226;-4.91933;-4.46565;-8.33425;-8.0567;-3.63202;-7.31438;-6.23258;-6.99782;-5.70681;-11.0672;-4.44509;-5.07238;-6.13365;-8.11105;-6.84966;-6.03666;-8.34934;-7.34714;-8.75792;-1.96961;-3.33134;-4.46553;-3.72705;-2.18004;-3.5119;-3.75295;-0.943051;0.998968;-7.38515;-4.85244;-1.2347;-0.662787;-0.621588;-1.50943;-3.16037;-6.62974;1.38477;1.75585;-0.648183;-1.61922;-3.15229;-0.798937;-7.55031;-0.623514;-6.11051;9.7882; 76 | 
4.12418;-7.87959;0.969691;-4.83621;-7.19646;-2.06687;-6.92357;-7.77114;-7.54014;-8.00393;1.79373;-4.59521;0.674539;-7.98158;-5.65249;-4.43633;-7.82558;-10.3216;-6.65032;-7.91483;-8.00733;-8.64189;-8.31615;-7.31262;-7.77707;-6.0291;-6.72401;-6.55077;-7.04094;-7.05551;-7.21029;-7.77824;-7.88707;-9.49773;-4.5639;-5.67585;-8.32014;-7.1362;-5.2161;-5.98602;-4.9878;-7.33323;-5.66239;-11.3782;-4.3516;-5.71109;-7.08097;-8.50169;-7.14393;-5.08992;-8.11226;-7.2605;-8.71917;-2.47435;-3.15776;-3.1417;-3.3673;-1.51216;-3.23814;-4.1463;-0.546054;1.08235;-7.14375;-3.87358;-1.87045;0.400461;0.156297;-1.18221;-3.28075;-7.11819;0.181476;0.0549722;0.41452;-1.35672;-2.06164;-0.26614;-7.45229;-1.0103;-6.76132;11.0436; 77 | 7.70377;-9.01651;-0.396016;-6.29216;-8.75728;-2.06118;-7.50327;-9.27345;-8.15758;-9.40032;0.759287;-4.3892;-0.238555;-9.49186;-6.76806;-5.78476;-9.40611;-11.8778;-7.89333;-9.52522;-8.66049;-11.1222;-9.93111;-8.95728;-8.66103;-8.0419;-7.87692;-7.04618;-8.04143;-7.63675;-7.37836;-8.52018;-9.40079;-10.2393;-4.77549;-6.83595;-9.42233;-8.03514;-5.63451;-6.08667;-5.55365;-7.46243;-7.11109;-12.6619;-5.89461;-6.3162;-7.46005;-9.32736;-7.93116;-6.14643;-8.89503;-7.87568;-9.37002;-2.59391;-4.12933;-3.28886;-3.19421;-1.38018;-3.45198;-4.61891;-0.479416;1.19159;-8.89759;-4.25757;-2.06367;-0.68034;0.118467;-1.01538;-4.53298;-8.41757;-0.198925;-0.571233;0.701156;-1.82474;-2.30759;-1.56476;-8.72687;-2.52903;-7.9442;9.86872; 78 | 
11.0473;-8.18993;-1.6355;-6.08834;-7.50554;-2.03256;-6.40846;-8.94288;-7.02151;-8.89186;-0.136068;-2.49997;-0.932814;-9.28984;-6.86983;-5.47481;-8.61266;-10.6131;-7.39077;-8.73832;-6.9434;-10.5782;-9.38255;-8.39632;-7.28558;-7.29576;-6.35864;-6.05233;-7.98472;-6.28894;-6.91561;-7.38841;-8.93039;-9.31256;-3.89381;-6.12146;-8.92849;-7.92159;-5.33798;-5.03085;-5.76824;-6.36748;-6.41773;-11.7628;-6.2765;-4.64794;-7.21198;-7.96437;-6.96888;-6.39393;-8.21158;-7.11174;-8.06291;-1.69475;-4.17252;-2.58172;-2.9358;-0.901373;-3.12803;-3.9619;-1.20427;1.09665;-8.8037;-4.47284;-2.34422;-1.53409;-0.523577;-1.15946;-4.46278;-7.52472;-0.70496;-0.208648;0.960509;-2.19555;-2.45993;-2.57851;-8.47278;-2.86367;-7.42505;6.91104; 79 | 11.2948;-7.33605;-2.6341;-6.61608;-6.49252;-1.61619;-6.61384;-9.42467;-6.68236;-8.38056;-2.10693;-1.47318;-2.24719;-8.87644;-7.39797;-5.54874;-8.37548;-10.137;-8.37901;-8.18157;-5.9984;-10.1938;-8.85415;-8.82754;-6.62756;-7.61687;-5.44035;-4.10846;-8.48955;-4.41413;-6.18951;-6.21546;-8.14832;-9.21384;-3.11055;-5.26969;-8.54697;-7.9704;-4.90943;-3.44057;-6.23676;-5.57741;-4.71686;-10.3196;-6.30661;-2.50254;-6.71189;-6.81277;-6.42362;-7.3436;-8.45337;-7.10685;-7.47269;-0.937048;-4.28755;-0.796476;-2.40944;-0.716313;-3.31458;-4.51828;-2.29728;0.671448;-9.12007;-5.1745;-2.28352;-1.92357;-1.34868;-1.22509;-4.27638;-7.33371;-1.33513;0.383336;1.25331;-2.61358;-2.88828;-3.96474;-9.30498;-4.69747;-7.4564;5.79402; 80 | 
7.9451;-8.82174;-4.03208;-9.05587;-9.47465;-3.11576;-10.0556;-10.9019;-9.04537;-10.4628;-5.65093;-3.43208;-4.03613;-8.57343;-9.49927;-3.97928;-8.8711;-12.1325;-9.69005;-10.3789;-9.38156;-11.7461;-10.5431;-9.97458;-8.74939;-10.6732;-7.36516;-2.80838;-8.75836;-4.21935;-4.31731;-6.35203;-8.76304;-9.38578;-4.0173;-4.26076;-7.20524;-6.53132;-3.08931;-3.87294;-5.67041;-6.42906;-4.80016;-9.98153;-4.20646;-2.3283;-5.99211;-7.45165;-8.54698;-8.20486;-9.98636;-8.34888;-8.94833;-2.26423;-3.71936;-0.102965;-2.77991;-1.78909;-4.59408;-6.12145;-1.09455;0.35635;-10.3507;-3.68797;-2.62803;-2.20273;-1.88428;-3.24471;-5.23359;-8.42032;-2.99703;-1.4209;1.71211;-2.95637;-3.90245;-6.25972;-8.16308;-7.23714;-8.25342;10.1503; 81 | -3.41529;-7.41735;-4.34678;-6.70961;-2.10076;-5.60094;-6.15139;-7.81069;-9.39473;-9.2634;-5.30068;-7.50519;-3.9764;-5.66458;-6.63616;-0.692024;-2.87266;-9.91805;-5.01813;-9.79071;-4.07702;-9.11388;-4.25206;-4.30121;-6.35863;-8.67998;-4.36864;0.82015;-4.67045;0.993148;1.58668;-3.44224;-7.75719;-3.31723;-2.66413;-1.07923;-4.73929;-0.935784;3.33045;-1.0836;-1.89238;-2.18132;-1.28021;-5.88911;1.54646;2.2276;-4.82793;-2.71575;-5.43205;-3.56245;-6.00096;-5.57574;-4.96715;-1.45623;0.782766;0.673225;1.82673;1.12831;-0.158989;-3.88678;5.90602;-0.521535;-8.18001;0.754796;3.88878;-0.686005;-1.2774;-4.43649;-3.73253;-6.53968;-3.17059;-1.77953;2.34558;-2.67573;-4.22232;-4.44611;-6.67046;-7.31023;-4.92832;-2.19853; 82 | 
-2.00212;-13.7983;-6.53884;-11.6054;-10.8422;-6.77416;-10.384;-11.4333;-14.1296;-11.7306;-3.97951;-9.35993;-2.61049;-7.82973;-8.53122;-5.53637;-4.88524;-11.6365;-8.69699;-16.0308;-11.366;-13.4932;-11.2081;-7.0485;-11.8596;-11.2283;-8.47136;-3.91088;-6.73354;-5.31113;-4.40108;-7.99517;-12.3472;-5.82494;-8.13008;-5.99709;-11.5556;-7.99393;-4.79317;-6.52159;-6.51397;-6.88348;-6.01257;-11.0472;-4.13821;-5.15592;-7.44139;-8.46457;-13.0215;-9.69343;-12.1819;-9.08601;-9.42354;-0.294999;-4.59623;-2.86603;-3.72102;-0.685477;-3.63454;-6.54864;2.27766;1.19297;-11.1973;-2.73981;-0.645458;-4.0036;-0.530853;-3.66642;-5.48185;-7.57206;-3.52335;-4.76542;-1.30365;-3.43795;-5.69945;-7.51539;-8.81593;-6.9056;-7.0006;11.1836; 83 | -4.82946;-12.1188;-3.16582;-7.31653;-6.98873;-5.15292;-10.4041;-9.0181;-9.85585;-8.89758;-3.21889;-5.23787;0.128962;-6.48811;-3.40504;-3.48195;-1.01713;-7.87348;-5.28405;-10.9952;-7.3818;-8.99628;-6.76359;-1.87176;-8.297;-6.08151;-6.1128;-0.632231;-6.22826;-4.62491;-4.1723;-7.03582;-9.97183;-3.22318;-6.87691;-5.98366;-10.735;-8.83616;-6.04594;-6.58224;-5.83016;-2.248;-7.35175;-6.96573;-4.51801;-6.1198;-7.38172;-4.91211;-12.6084;-8.52711;-6.9345;-10.3248;-6.89895;7.87916;-5.0239;0.581193;-1.56713;0.892153;-4.37377;-4.04256;-1.8694;2.63021;-10.7787;-3.07928;-3.10963;-1.29137;1.17671;3.72183;-5.38987;-3.29063;0.233309;-1.17985;-2.2307;2.07267;-3.01561;-3.86638;-4.73783;-6.39875;-4.31555;2.72425; 84 | 
-4.87847;-11.3908;-1.21989;-6.63518;-9.34425;-1.74085;-10.8472;-5.9083;-9.73478;-8.26056;-0.268449;-3.72485;2.36223;-5.41801;-4.74807;-1.55438;-2.23526;-4.94923;-6.38798;-8.76323;-8.20213;-7.41;-9.28265;-3.77497;-9.18464;-6.3748;-6.44235;-3.64342;-7.60629;-8.20912;-5.99368;-10.3986;-9.01686;-7.56312;-4.76757;-3.88427;-9.89899;-8.98685;-5.6231;-8.01492;-4.34396;-6.5473;-7.29731;-10.6324;-5.70254;-5.65912;-6.93089;-6.67487;-10.9967;-8.42561;-6.85366;-8.05646;-6.61051;3.67152;-3.72811;-0.42833;-1.71609;-3.49117;-4.20767;-3.50954;-4.17589;3.8885;-9.26727;-2.54476;-2.63263;-0.34379;2.65624;-0.76562;-3.94295;-6.67076;4.1548;3.74171;-0.773357;1.47138;-0.454469;-3.48609;-0.697072;-5.88524;-1.6518;5.76733; 85 | 1.9102;-9.19201;-0.391071;-6.38259;-9.65105;-1.23421;-8.8346;-7.23563;-8.31056;-9.03798;-2.41306;-4.17691;1.42716;-4.25287;-5.28106;-1.33437;-4.8406;-8.35719;-8.50998;-10.9177;-10.1765;-8.90122;-10.6447;-7.38351;-11.8366;-11.5543;-4.98268;-6.65564;-7.94669;-6.89457;-6.16507;-11.4535;-9.59546;-8.85939;-6.06434;-7.21163;-8.98123;-7.42188;-4.49513;-9.41182;-4.07589;-7.53824;-6.56478;-11.85;-6.62848;-5.8827;-5.75978;-8.53259;-9.29966;-8.02405;-7.49397;-7.46475;-6.36002;-1.84912;-3.50773;-1.79592;-1.14196;-5.18045;-4.31928;-4.30312;-2.69075;1.49076;-9.966;-1.609;-0.120294;-0.794611;3.26805;-2.70728;-4.74484;-8.85615;2.76874;1.05607;0.732055;-0.528727;0.131901;-3.97052;-7.61084;-6.70572;-4.33242;7.6646; 86 | 
-0.346484;-8.33663;-4.97053;-4.40144;-3.36015;-4.36811;-9.66796;-9.05767;-7.34111;-9.99411;-5.86845;-5.71877;-1.52703;-4.15601;-6.38124;-2.80963;-0.741245;-3.99527;-8.17016;-6.23261;-7.02614;-7.52411;-4.07288;-3.51545;-8.68103;-8.48026;-3.34436;-6.37438;-2.72007;-2.7;-3.51083;-3.507;-4.74101;-2.12618;-5.46636;-7.39737;-5.5324;-5.11101;-2.73794;-9.72359;-7.7453;-6.37047;-3.43312;-7.48626;-3.40627;1.05187;-3.31914;-10.6579;-10.7112;-9.40304;-7.47939;-7.35305;-4.22973;0.476775;-0.382849;1.53338;-1.16175;-2.33369;-1.48655;0.215712;-0.48572;0.0955184;-5.66573;2.95604;0.362559;-2.02661;-0.804854;-4.23034;-4.03926;-7.50521;0.260214;2.99673;-0.0105068;-4.81488;-1.71935;-5.85679;-6.71392;-8.00037;1.09791;5.60453; 87 | -5.69681;-8.8112;-4.92234;-2.46878;0.765999;-6.24865;-10.0641;-11.8418;-6.86305;-9.57725;-6.57688;-8.4271;-1.19978;-6.83175;-5.17159;-6.09567;0.307215;-3.39917;-4.96796;-8.51763;-4.77718;-7.05259;-0.721252;-2.44275;-5.23741;-7.00145;-2.28272;-5.59487;-0.591353;-2.05279;-3.27468;-2.56887;-3.71225;0.771803;-4.03826;-6.53849;-7.97214;-4.62177;-3.89236;-7.3656;-7.02886;-4.66666;-3.5009;-5.56558;-1.37022;-2.09582;-5.21446;-9.12775;-10.5514;-9.00288;-6.98138;-7.34495;-5.27489;2.35502;-0.488038;2.61032;-1.68223;3.50741;-1.56206;2.41086;-0.207511;-3.52373;-7.46283;1.85282;-1.0478;-1.46783;-1.72966;-1.55055;-1.71882;-5.03797;0.203598;2.35649;-0.986911;-4.39326;-3.42419;-5.6192;-3.33786;-6.4593;1.36348;2.30317; 88 | 
-3.93601;-9.33995;-4.35111;-5.34992;-5.78993;-5.94659;-11.8952;-12.397;-9.41307;-9.23553;-2.6128;-6.07843;2.53022;-8.55286;-3.16201;-5.34416;-2.92813;-8.78872;-6.48278;-10.4241;-7.59452;-8.21305;-3.58594;-4.84299;-6.45764;-8.4143;-6.4529;-6.50864;-5.90934;-6.88809;-5.90446;-3.98171;-7.05149;-8.09094;-4.85285;-8.51208;-10.6367;-8.0743;-6.53445;-7.01102;-4.64752;-6.23361;-6.7011;-9.86227;-0.934855;-7.24887;-7.05724;-8.3205;-12.2388;-7.94065;-9.26472;-9.19995;-9.6612;1.87835;-4.14876;-1.07486;-1.52452;7.1388;-3.59063;-3.24657;-2.4226;-2.17519;-9.93473;-1.5247;-2.52012;-1.03615;1.4512;1.43656;-1.63359;-6.69484;2.69645;-0.622019;-1.33643;-0.648901;-4.92646;-3.31666;-2.35775;-4.1291;-5.37161;4.19019; 89 | -2.38907;-8.83227;-3.90484;-5.94138;-7.49316;-2.94971;-11.2133;-8.71047;-8.95188;-7.51915;-0.225919;-4.53614;5.49726;-8.98495;-3.92188;-3.11925;-4.67579;-8.37546;-8.79366;-9.50405;-8.3165;-7.88417;-5.25776;-6.3168;-7.46414;-9.95682;-5.34285;-7.23909;-5.20086;-6.96367;-5.49455;-4.9767;-5.81909;-8.44053;-4.89504;-8.44008;-7.72012;-6.22939;-5.66445;-7.10345;-2.90673;-7.01909;-5.96408;-9.12068;-2.05574;-4.61538;-6.40849;-7.58008;-9.5347;-5.20353;-9.0802;-7.93376;-8.98985;-1.66026;-3.7475;-2.04815;-0.487398;3.94902;-3.25442;-2.60575;-2.8746;-1.43858;-8.88968;-2.49949;-1.77287;-1.88095;2.46221;-2.40813;-4.65568;-6.63697;3.18961;-0.121591;-0.369404;-1.81984;-4.1877;-4.05758;-5.06013;-4.79321;-5.20371;7.45254; 90 | 
4.11228;-7.81868;-3.54357;-6.61945;-10.1544;-1.71309;-10.902;-7.14413;-8.7187;-8.51288;0.282029;-4.02897;4.79644;-8.84002;-4.71102;-3.86184;-8.4076;-9.46288;-9.93432;-9.41785;-9.97649;-10.1279;-8.18842;-7.71933;-8.30873;-9.42744;-6.84571;-7.07525;-6.48584;-9.15322;-5.33231;-8.37937;-6.91245;-10.3667;-4.50607;-7.84884;-7.36151;-8.33195;-7.23975;-6.92671;-3.18082;-6.32011;-7.03068;-10.7112;-4.85945;-5.28823;-5.81159;-8.94941;-7.31079;-5.90152;-9.11137;-7.91656;-9.67012;-2.36217;-3.63463;-4.79598;-0.750709;-0.139776;-2.85286;-2.9045;-1.49916;-1.02357;-9.64637;-5.03467;-1.8062;-3.16529;2.20698;-0.527021;-5.09573;-7.42612;2.24045;-1.19849;0.105039;-1.71339;-2.74254;-4.42551;-6.80842;-5.50474;-7.57074;8.15608; 91 | 8.65283;-5.3494;-2.20297;-5.54195;-6.73795;-0.739942;-5.56922;-6.03835;-7.48071;-9.2222;-0.771405;-2.80907;2.33592;-8.94458;-6.39435;-3.09929;-7.10864;-8.71127;-8.5088;-8.56634;-7.9707;-9.09009;-8.03866;-6.82426;-8.55555;-8.52329;-5.54746;-5.37024;-6.00206;-7.35368;-5.36765;-6.95908;-7.50739;-9.28935;-2.57948;-5.21712;-7.4714;-8.17944;-5.92017;-4.7603;-4.17892;-3.47392;-7.30723;-11.6031;-5.68112;-4.77774;-6.17761;-7.2577;-5.64117;-4.19143;-8.82151;-7.7574;-9.0228;-0.444651;-2.47696;-6.80681;-2.95573;-0.388666;-1.51205;-5.20942;-2.01019;-0.392161;-10.8885;-6.64533;-0.25204;-4.42805;-0.383226;-0.0135871;-5.53181;-8.5823;-1.03918;-2.9767;0.813228;-3.26043;-3.50349;-5.51682;-11.1433;-5.12689;-8.72975;5.53214; 92 | 
7.97499;-4.02126;-1.0612;-5.89667;-6.94158;-0.852334;-3.07138;-3.73501;-6.99153;-9.36908;-3.51352;-4.97948;-0.752497;-6.29863;-6.65071;-0.718765;-6.50147;-7.4206;-8.21201;-7.92885;-7.10514;-6.50229;-8.17749;-5.01846;-10.7069;-8.25146;-5.493;-3.68903;-5.54527;-5.55225;-4.0715;-6.58549;-7.06335;-7.98343;0.313501;-1.67771;-5.6982;-6.77382;-5.96675;-2.44264;-2.53782;-3.86717;-5.07655;-12.4236;-4.15545;-3.38262;-3.81944;-6.80651;-3.37312;-3.1326;-8.57383;-6.1394;-8.65422;-2.61428;-1.96595;-8.9084;-5.1383;-3.04827;-1.2092;-6.46787;-3.35168;-0.719978;-10.0123;-6.95682;0.783293;-4.99774;-3.04257;-2.49809;-5.09631;-9.48183;-3.54764;-4.66205;2.77188;-5.30519;-3.78292;-7.45568;-12.4845;-5.48426;-8.83473;6.8566; 93 | 0.749601;-3.35808;-0.660999;-2.32547;-5.11285;-1.59856;-0.0106558;-4.24334;-5.77153;-7.21723;-6.87343;-7.51368;-1.71962;-2.40837;-6.55314;0.474488;-4.2056;-6.09787;-3.81306;-6.25301;-4.19532;-3.54173;-5.64003;-3.9868;-8.96008;-6.9199;-5.44231;0.821834;0.452067;-2.39095;-0.901728;-3.53343;-3.24334;-5.5348;5.05275;0.671257;-2.90618;-0.662502;-2.67892;0.874937;0.256288;-2.7742;-2.57371;-7.99922;-1.33713;-3.0687;0.473434;-2.81865;-0.0669386;0.742359;-5.36008;-2.58142;-5.89042;-2.31766;0.590628;-7.84332;-4.25301;-4.24312;-0.24072;-6.1686;-1.45079;-1.83355;-8.89213;-2.77503;1.52206;-3.70283;-5.14562;-3.56067;-4.41698;-8.73783;-5.38247;-6.76648;5.52253;-4.86821;-4.83847;-5.28699;-11.0678;-4.89145;-5.64135;1.66955; 94 | 
-3.96142;-7.17454;-3.79338;-5.21026;-6.27009;-2.39209;-6.46332;-7.52676;-7.57982;-7.93268;-8.74953;-9.69287;-2.8992;-3.00973;-7.96487;-2.36987;-3.7897;-5.7412;-1.03365;-5.95971;-5.66524;-8.28349;-5.05982;-2.80496;-10.2609;-9.08579;-8.75905;-2.11869;-0.984683;-5.29006;-4.14485;-4.75121;-7.78563;-7.6953;3.64474;-1.32721;-6.47183;-1.64563;-4.54424;-3.60733;-4.58516;-5.04011;-6.39309;-9.20984;-2.43196;-5.37687;-4.07513;-3.64647;-3.14916;-2.9418;-5.41262;-3.91992;-6.78797;-0.452288;0.230705;-7.23432;-5.20518;-2.22954;-3.96621;-6.07019;0.981917;1.23857;-11.3304;-3.47338;-0.162246;-7.47318;-5.53912;-1.51546;-6.07641;-10.6153;-2.61041;-5.61446;1.3116;-3.90417;-6.74763;-6.10574;-10.066;-3.98322;-5.95997;6.82811; 95 | -4.65476;-6.56884;-5.68861;-7.90044;-4.95902;-2.74419;-4.77394;-7.22521;-9.23247;-7.93587;-8.04542;-7.06474;-3.11115;-3.29938;-6.09134;-5.23508;-5.61896;-7.88848;-2.10147;-4.00746;-4.95479;-9.38495;-2.29161;-3.75216;-10.8729;-9.86049;-8.71701;-4.60824;-6.17034;-2.7555;-5.61683;-2.69483;-7.19523;-5.04788;-1.761;-3.21653;-8.34676;-4.04199;-3.60308;-7.21496;-7.12199;-3.67411;-6.30045;-9.0238;-3.25038;-4.35592;-5.00682;-3.8614;-4.20084;-7.09478;-7.3116;-5.61003;-7.9984;-2.53;-2.23469;-4.45681;-4.99047;0.472897;-0.10877;-4.01359;3.65657;-0.857882;-11.7055;-2.28386;1.09747;-8.13444;-3.87836;-0.604588;-4.80474;-11.2642;-0.714961;-3.94686;-1.86908;-4.11132;-7.17995;-9.83456;-9.20017;-4.57997;-8.35717;6.28216; 96 | 
-4.86316;-5.39262;-3.84599;-6.66034;-2.1211;-4.58297;-1.89893;-10.5298;-9.17591;-7.3675;-4.95095;-0.744404;0.146056;-6.03626;-2.67464;-3.9072;-0.910592;-7.90063;-1.23632;-3.31066;-0.652269;-6.36293;-1.82399;-3.07461;-7.85281;-7.36513;-4.92262;-2.57571;-5.94212;5.34976;-6.46037;1.53265;-4.35665;-2.08381;-4.30096;-3.75393;-9.13036;-3.26457;-0.0381842;-5.64123;-6.1312;-2.78317;-5.60018;-5.33703;0.390905;-1.68444;-2.98388;-2.40618;-4.88822;-4.76975;-7.52485;-5.63049;-5.87022;-0.953678;-4.03869;0.442307;-4.20902;3.28469;-0.236069;-2.91718;2.01127;-3.65561;-11.3571;-1.20572;1.97329;-4.96725;-2.63238;0.771958;-4.90076;-6.93759;0.235954;-2.27925;-0.472656;-4.61571;-6.89148;-6.01702;-5.92552;-4.38339;-5.66825;1.36249; 97 | -3.03073;-6.14744;-3.72446;-5.17724;-8.04997;-2.73899;-7.90234;-6.18448;-9.08895;-6.06592;-1.37255;-1.61884;3.28829;-7.62612;0.0384169;-4.93946;-4.81612;-6.72019;-5.53518;-5.43637;-5.33604;-6.22144;-4.38847;-4.28914;-7.52791;-7.58896;-4.00533;-4.38263;-3.69593;-2.2825;-5.70285;-2.44976;-3.78982;-5.39625;-4.69668;-3.29965;-8.11973;-6.5662;-6.12649;-6.23807;-4.86004;-2.91931;-3.08006;-5.7292;-2.08854;-3.7083;-3.98912;-5.55876;-6.26792;-4.86779;-9.08581;-5.81482;-7.01748;-2.12946;-4.58239;-3.58241;-4.1211;2.75804;-2.97375;-4.61542;-0.51006;-2.67448;-8.40379;-3.98952;-1.21721;-2.98429;-2.80251;0.222102;-5.19324;-4.79029;-0.535724;-1.95942;0.0841188;-4.173;-5.99399;-5.30593;-7.00399;-3.10154;-6.10057;7.69512; 98 | 
-0.868529;-7.11025;-2.33787;-5.17098;-8.522;-2.02765;-7.83577;-5.41059;-8.23888;-6.10601;-0.090772;-2.10984;2.92008;-7.47332;-0.230582;-3.51931;-5.84608;-6.49341;-6.38129;-4.96217;-6.7307;-6.19761;-4.89147;-4.29021;-7.74284;-7.61403;-4.4402;-3.71447;-3.39943;-3.60699;-4.60798;-4.51302;-4.06493;-5.94352;-4.81624;-3.68479;-6.82623;-6.6466;-5.51526;-5.81837;-3.78223;-2.86043;-3.29457;-5.94854;-2.12159;-3.40792;-4.11543;-6.37866;-5.75758;-4.9798;-8.37232;-6.15527;-7.01785;-1.95357;-3.76837;-4.00863;-2.86717;0.128838;-3.15668;-4.07575;-0.755147;-2.16287;-6.89057;-4.71372;-1.44613;-2.61223;-1.40601;-0.0991185;-3.8369;-4.55543;0.00717092;-1.13756;0.379532;-3.18826;-3.93195;-4.34374;-7.08338;-2.83229;-4.93772;9.11817; 99 | -0.286895;-7.77639;-1.69431;-5.06338;-8.03187;-1.93403;-7.74735;-6.36113;-8.26548;-6.57577;0.0356798;-2.5471;2.73011;-7.84341;-1.10159;-3.36938;-6.45413;-7.26133;-6.49557;-5.27702;-7.46924;-6.99965;-5.49202;-4.96605;-7.5534;-7.83423;-5.51308;-3.70531;-3.86571;-4.26314;-4.48205;-5.49117;-4.49512;-6.27911;-5.08512;-4.00586;-7.15393;-6.69563;-5.25108;-5.59171;-2.99827;-3.14982;-4.16949;-6.50946;-2.02675;-3.51492;-4.4147;-6.82462;-5.91058;-4.574;-8.33343;-6.32486;-7.55972;-1.67062;-3.28219;-3.76779;-2.25659;-0.394122;-3.56057;-3.88688;-0.503908;-1.49915;-6.70596;-4.84211;-1.59971;-2.16087;-0.896315;0.00794873;-3.26633;-4.46737;0.363503;-0.528154;0.560911;-2.80558;-3.20862;-2.79528;-6.54461;-2.83239;-4.92212;9.49514; 100 | 
1.13804;-8.02376;-1.70369;-4.93869;-7.67865;-2.15941;-7.72865;-7.20111;-8.19521;-6.54804;-0.397167;-2.5599;2.34906;-7.43585;-2.35628;-3.61438;-6.98896;-7.81229;-5.9057;-5.73325;-7.65199;-8.23293;-6.11519;-5.83785;-7.1578;-7.89235;-6.89865;-3.93112;-4.3965;-4.83615;-4.28698;-5.81881;-5.12618;-6.26743;-5.27015;-4.47107;-7.38506;-7.05148;-5.30346;-5.25727;-2.59911;-3.45182;-5.38901;-6.84489;-2.58819;-3.63724;-4.79244;-6.80831;-5.86912;-4.42105;-7.85275;-5.74353;-7.55606;-1.57175;-2.84201;-3.28634;-2.45604;-0.699608;-3.68635;-3.49498;-0.192963;-0.77096;-6.8046;-4.89796;-1.66847;-1.66584;-0.157239;0.133914;-2.81215;-4.74468;0.576509;-0.233093;0.172319;-2.08497;-2.43698;-1.73889;-5.71084;-2.91299;-5.25219;8.68117; 101 | -------------------------------------------------------------------------------- /data/word/corpus.txt: -------------------------------------------------------------------------------- 1 | appoint 2 | abandon 3 | add 4 | artificial 5 | agency 6 | arch 7 | adviser 8 | accompany 9 | assembly 10 | age 11 | authority 12 | access 13 | aircraft 14 | arrange 15 | ash 16 | attractive 17 | avant-garde 18 | agree 19 | agenda 20 | area 21 | absorb 22 | accountant 23 | abbey 24 | absent 25 | accent 26 | appreciate 27 | airplane 28 | aluminium 29 | assignment 30 | assertive 31 | alarm 32 | athlete 33 | acquaintance 34 | argument 35 | advice 36 | art 37 | ask 38 | adult 39 | advantage 40 | abuse 41 | acceptable 42 | anger 43 | account 44 | aisle 45 | aquarium 46 | addition 47 | aunt 48 | abstract 49 | agile 50 | accurate 51 | attack 52 | arm 53 | agent 54 | advertise 55 | affinity 56 | ambiguous 57 | aware 58 | afford 59 | affect 60 | ambiguity 61 | archive 62 | absolute 63 | auditor 64 | architecture 65 | aspect 66 | agreement 67 | address 68 | aid 69 | air 70 | appointment 71 | admit 72 | AIDS 73 | architect 74 | agony 75 | advertising 76 | apparatus 77 | article 78 | acute 79 | angel 80 | acquisition 81 | association 82 | abortion 83 | acquit 84 | applied 85 | 
abundant 86 | anniversary 87 | abridge 88 | available 89 | album 90 | ample 91 | analysis 92 | answer 93 | able 94 | annual 95 | atmosphere 96 | ambition 97 | allowance 98 | awful 99 | allow 100 | attic 101 | amputate 102 | arise -------------------------------------------------------------------------------- /data/word/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/githubharald/CTCDecoder/6b5c3dd34944e5399a7308e241319b7f9c47e7c3/data/word/img.png -------------------------------------------------------------------------------- /data/word/rnnOutput.csv: -------------------------------------------------------------------------------- 1 | -7.76214;-5.62231;-4.2493;-6.69015;-6.66073;-5.7536;-7.44713;-4.62077;-7.88057;-8.09553;-1.27974;-4.60857;-2.49587;-8.34131;-4.56571;-5.21836;-4.23089;-4.20566;-4.84786;-6.75935;-6.39574;-6.56756;-6.28402;-5.159;-4.78485;-6.78909;-6.0688;-2.65652;-4.09712;-6.16935;-3.90803;-5.47897;-5.94278;-1.95536;-6.65438;-5.08228;-7.87152;-6.27794;-3.24306;-4.76765;-5.75703;-6.54045;-8.64844;-8.18258;-3.23779;-5.66343;-5.8843;-6.44461;-8.74892;-5.87982;-8.21329;-7.38374;-7.86028;8.2236;-2.49935;1.50065;1.84085;-1.36386;-2.79243;0.0345535;-0.238027;-2.09928;-4.64528;-1.31079;-4.6172;-1.1936;-0.124125;-1.00592;-2.1056;-1.36663;-2.48771;0.929708;-1.35541;-1.91438;-6.10711;-0.790544;-0.971813;-1.90448;-3.85296;-2.07753; 2 | 
-13.5062;-14.6283;-9.44392;-13.5271;-14.3667;-9.27126;-13.2367;-13.0681;-15.1832;-15.439;-7.83853;-10.5661;-6.25097;-14.7536;-8.29031;-9.80701;-10.698;-10.5012;-12.672;-12.3262;-12.6259;-14.0192;-12.5816;-10.4145;-15.8736;-15.0227;-13.6422;-9.96284;-11.053;-11.3518;-10.6086;-10.7303;-12.1658;-9.58506;-10.1369;-9.31693;-14.9576;-12.3604;-8.51279;-10.1651;-8.78293;-13.2287;-12.3412;-15.2566;-7.38007;-11.0521;-10.7821;-11.2608;-11.8528;-11.3922;-15.6539;-12.0208;-15.6;3.17365;-5.84011;-2.93208;-2.00714;-2.57325;-3.90536;-5.09372;-0.830435;-3.9074;-9.81998;-5.44563;-3.97482;-7.06537;-2.7156;-4.88457;-5.93185;-8.66555;-4.6994;-3.38561;-2.69333;-3.86297;-8.64017;-6.05026;-5.34588;-6.56884;-9.22535;7.01954; 3 | -12.1859;-12.4652;-6.82487;-12.0757;-13.365;-7.44458;-12.8814;-11.1192;-14.0108;-14.0373;-6.33989;-8.85548;-5.20164;-13.3474;-6.97391;-8.54238;-8.76258;-10.3456;-11.0667;-11.3961;-9.90095;-12.0389;-11.2302;-9.90426;-14.45;-13.1266;-11.1952;-9.66451;-9.634;-9.57714;-10.1358;-9.05709;-11.0269;-9.86285;-9.39496;-7.67299;-13.235;-10.361;-5.29678;-9.68402;-7.95341;-13.0197;-12.3064;-14.1416;-6.57897;-10.2591;-9.50518;-8.30062;-10.2159;-9.58271;-14.5298;-11.7071;-14.6139;-0.564102;-5.68213;-1.90956;-3.03345;-2.16206;-3.33276;-5.45244;-1.30091;-2.90385;-8.88933;-4.88407;-2.81352;-6.70957;-1.98604;-5.39195;-6.81406;-9.15271;-3.79556;-4.61056;-3.49001;-1.32726;-5.92907;-5.71604;-2.50714;-6.42515;-8.42264;10.187; 4 | 
-11.6118;-14.66;-7.3019;-13.9755;-14.0016;-6.31368;-14.3755;-13.1861;-14.4522;-14.57;-7.57239;-6.25906;-4.65897;-13.9637;-7.8726;-9.44602;-10.1761;-13.0987;-10.6584;-11.8265;-11.1568;-11.8556;-12.6031;-11.1147;-16.5337;-14.9896;-14.0111;-12.2884;-11.9385;-10.1056;-12.3873;-9.88594;-10.8487;-12.1607;-11.3579;-9.53473;-14.873;-12.0708;-4.6368;-12.3894;-11.2626;-13.0148;-13.6093;-14.9503;-11.3756;-10.9083;-9.49112;-7.4897;-9.83718;-11.9693;-14.9191;-13.6288;-14.4484;-3.3627;-8.16219;0.0688379;-3.66225;-3.59774;-4.78024;-6.95815;-5.30174;-1.80568;-11.34;-6.98188;-0.927909;-5.43914;-2.18372;-6.39892;-8.39792;-11.8554;-2.23487;-4.8469;-4.40082;0.789544;-3.25046;-4.99633;-4.71429;-6.74818;-10.0263;11.448; 5 | -9.83744;-16.9316;-6.16608;-12.9694;-13.6088;-2.89411;-14.5942;-13.5878;-12.0375;-11.7354;-6.56959;-2.78849;-4.38017;-11.8793;-7.27476;-9.84474;-10.6382;-12.4705;-10.0986;-9.60156;-10.9088;-10.4286;-11.8998;-11.3063;-16.0812;-14.6858;-15.0739;-11.1098;-12.0953;-10.7137;-13.0305;-9.91881;-11.5149;-12.429;-13.6819;-6.47668;-14.5773;-12.6471;-7.5176;-11.9405;-12.3581;-11.0065;-13.1136;-12.815;-13.6626;-9.81414;-8.46885;-8.62408;-10.0683;-12.4524;-12.0182;-13.4937;-11.9843;-4.66908;-7.84478;-0.216665;-3.99753;-6.85787;-4.09931;-6.95805;-6.53302;5.37189;-5.86519;-7.9599;1.01826;-4.39396;-2.30493;-6.39131;-8.38805;-11.476;-0.923265;-4.42098;-3.92325;-0.300759;-3.08192;-4.01377;-7.497;-6.53692;-8.59161;8.0255; 6 | 
-10.231;-14.5504;-5.4415;-11.5239;-12.4226;-4.28594;-11.7684;-12.4438;-11.5456;-11.5324;-7.00172;-4.23643;-4.41415;-11.0449;-7.36089;-9.25219;-9.3255;-9.36099;-9.62999;-9.25328;-10.3046;-8.80444;-10.8945;-12.0211;-13.6422;-13.568;-13.5395;-8.24941;-10.0979;-10.2208;-10.4213;-7.9357;-9.55588;-9.72003;-12.4478;-3.6265;-11.3846;-10.7656;-8.04101;-10.4752;-10.7937;-9.16926;-11.1239;-11.5105;-11.961;-8.94455;-5.60038;-9.36379;-8.64543;-10.1001;-11.4091;-10.2612;-11.1512;-3.13522;-7.12402;-1.72636;-4.87284;-5.88774;-3.73617;-4.93624;-7.32597;9.27634;-3.3262;-6.33054;-0.565507;-2.15452;-0.607178;-4.42063;-6.94008;-10.7607;0.385242;-1.81374;-1.62587;-1.15451;-2.58354;-4.5485;-7.49884;-4.65541;-5.80447;4.14115; 7 | -8.95248;-13.1251;-4.59945;-8.95106;-11.1402;-3.62933;-9.17815;-11.5866;-10.2364;-10.7004;-7.05355;-3.2278;-3.9997;-10.1893;-5.93818;-7.73341;-7.48781;-6.0908;-8.42565;-7.24872;-7.21047;-7.54964;-8.47817;-8.84071;-11.4684;-12.0762;-11.8431;-7.16692;-7.4378;-8.12262;-7.9834;-4.88298;-6.79974;-8.48693;-9.96187;-1.76154;-9.03825;-8.55904;-5.47775;-8.87497;-8.36108;-7.38289;-8.38586;-10.4347;-9.56886;-8.16104;-4.82741;-8.23542;-6.91715;-8.77116;-10.9116;-8.21223;-10.404;-3.65013;-6.00837;-2.85575;-5.44772;-3.36689;-3.05714;-2.96453;-6.49788;8.98452;-1.63532;-4.61597;-1.40552;-2.06509;0.645539;-3.15808;-4.8911;-8.23132;1.00895;-0.932281;-2.48469;-2.0425;-2.91111;-6.26671;-6.19812;-4.22456;-3.21104;5.64194; 8 | 
-16.645;-27.4562;-18.4112;-19.3749;-22.3077;-11.7112;-23.3279;-25.6446;-18.4617;-19.0751;-17.9979;-8.15002;-11.328;-18.1556;-13.1519;-18.5411;-18.1676;-18.0784;-18.986;-17.6209;-16.9427;-18.6937;-19.2173;-19.931;-27.1921;-25.5541;-25.0135;-18.1449;-19.0627;-17.4223;-18.4205;-15.9784;-17.8234;-21.6035;-17.0345;-14.1729;-23.5418;-20.269;-17.107;-18.0799;-15.6752;-16.3997;-17.937;-18.8971;-18.7451;-17.4456;-14.0361;-19.9955;-19.4631;-19.4953;-19.8354;-19.7223;-18.6786;-9.00997;-12.1802;-7.19917;-9.10307;-5.40329;-9.69636;-9.11696;-10.1077;1.63956;-15.8227;-9.86524;-5.98511;-6.67588;-2.35;-6.85522;-7.7844;-20.3411;-0.761795;-5.16155;-6.56647;-6.47994;-5.35303;-9.14389;-12.913;-11.1779;-12.9401;11.232; 9 | -7.96866;-6.69133;-4.00561;-5.94063;-7.0689;-5.0134;-4.06243;-6.12373;-7.20209;-8.95335;-6.38919;-3.64889;-4.32166;-8.76579;-5.82079;-3.99017;-6.24942;-5.74224;-8.08705;-4.3993;-4.35968;-7.22993;-5.03418;-6.24252;-6.81156;-6.46931;-5.96117;-4.1187;-4.44035;-4.82785;-5.38027;-4.41075;-0.34224;-4.97356;-5.76498;-5.77329;-4.15142;-5.48975;-3.85095;-2.09686;-4.21205;-7.25816;-0.859927;-8.41624;-2.94936;-5.18913;-0.860456;-6.42991;-3.41007;-5.70094;-9.42344;-3.84324;-9.21608;-4.13559;-4.67461;-0.0830826;-4.60774;-3.49226;-0.677678;-3.76156;-8.48146;-1.65077;-2.66369;-1.92256;-3.25881;-0.199376;3.40243;-4.35621;0.774296;-3.3292;9.98369;2.55381;0.158886;-3.93171;0.873681;-2.21035;2.48233;-3.82178;-0.465842;-0.0275634; 10 | 
-16.5669;-12.6084;-9.27236;-13.4789;-17.4194;-7.77672;-12.7349;-13.5931;-15.9209;-18.6999;-9.04815;-5.54777;-7.13468;-18.2676;-12.9063;-11.4708;-14.3419;-13.8182;-15.0852;-13.7523;-12.6743;-15.1471;-15.2266;-15.5933;-13.137;-12.981;-13.4295;-11.6947;-12.0573;-10.495;-11.5381;-11.0825;-8.77909;-14.4971;-12.4755;-11.7877;-13.9532;-14.6485;-9.29668;-10.3814;-12.4241;-13.7055;-8.37453;-19.0869;-9.2287;-12.7463;-5.45073;-14.0495;-11.5821;-13.0475;-18.9731;-11.3891;-19.3996;-8.83888;-9.81837;-3.05709;-6.2765;-5.05825;-3.53628;-7.22463;-9.84781;-2.4804;-8.7069;-8.65075;-3.56874;-4.46999;-1.9476;-8.0044;-3.44242;-8.78835;5.62706;-2.90922;-2.48861;-9.16477;-5.36171;-9.63991;-7.23076;-7.89747;-7.71824;10.211; 11 | -18.9972;-24.3541;-17.9354;-18.2064;-23.0105;-12.1136;-24.8342;-23.4939;-18.0919;-20.0257;-15.6554;-6.97123;-11.1942;-18.8363;-15.1451;-18.3332;-17.9634;-19.2666;-19.3333;-18.3021;-19.3602;-19.2184;-20.605;-21.0068;-25.3007;-23.5203;-22.7839;-18.7565;-17.4209;-13.3309;-17.9659;-14.6888;-18.8647;-21.1176;-19.1857;-19.0219;-21.7772;-19.7116;-15.2986;-16.8063;-18.5646;-18.624;-18.3696;-19.572;-16.3672;-16.4963;-14.0077;-18.6513;-18.4508;-17.5722;-19.9575;-18.0381;-19.729;-8.55354;-12.4809;-1.80461;-7.55378;-3.8362;-8.3371;-8.36445;-11.0822;-3.66045;-18.1602;-10.369;-4.36354;-8.68369;-6.78907;-9.93256;-7.98308;-18.7971;-2.78281;-5.59489;-5.39652;-8.73623;-9.90903;-10.1106;-13.8105;-10.2198;-12.6777;11.8068; 12 | 
-6.80734;-7.33574;-5.43516;-6.39252;-7.72016;-3.34382;-6.71712;-6.91565;-5.4361;-5.9671;-3.92954;0.207484;-3.75977;-5.65815;-3.99059;-4.92063;-5.07312;-7.44898;-6.90007;-7.47427;-7.72153;-6.0867;-8.12021;-8.2271;-9.15766;-8.0358;-8.31851;-10.0169;-6.12393;0.932926;-7.58997;-4.37879;-6.62786;-6.06745;-9.67357;-7.1031;-7.45122;-6.56258;-3.74921;-7.19543;-7.7201;-6.69692;-6.11011;-6.21087;-5.60727;-7.49029;-3.411;-5.15052;-5.52876;-4.59045;-6.08651;-6.04986;-6.06082;-2.83395;-6.69997;10.8091;-4.45466;2.79431;-3.66496;-1.63787;-6.57299;-3.53291;-7.19715;-2.06215;-2.51935;-3.80982;-2.06067;-2.51364;-1.06067;-5.97917;0.741466;-0.146213;-3.2912;-0.92702;-0.566928;0.267785;-4.04308;-2.72765;-1.20394;1.43569; 13 | -9.0586;-8.77316;-4.39732;-8.38935;-9.32893;-5.33925;-6.73501;-9.14405;-8.83157;-10.4478;-6.1344;-3.16646;-5.67925;-9.61027;-7.28164;-5.54288;-6.82489;-8.86872;-9.10437;-9.98297;-9.37605;-8.55216;-10.3513;-9.15148;-9.27379;-9.50199;-9.05939;-10.4571;-8.45355;-1.57498;-8.65855;-7.52612;-5.53043;-8.35769;-8.16551;-7.00103;-10.05;-9.08183;-4.30787;-9.29041;-9.78816;-8.2577;-4.63451;-10.6264;-7.06075;-7.99931;-3.51182;-7.86667;-8.05985;-10.0724;-10.7255;-8.91166;-10.6506;-3.21692;-7.88592;8.3113;-3.40459;2.18664;-4.25326;-2.25807;-4.47204;-4.10091;-7.70465;-5.54247;-2.89272;-4.47158;-3.80425;-4.64722;-4.87601;-7.7248;-0.58043;-2.0211;-3.96313;-2.28383;-3.10948;-5.2328;-8.7093;-6.05275;-5.77184;6.57815; 14 | 
-16.1435;-16.6366;-9.83893;-14.9698;-17.4292;-5.49679;-12.9306;-16.5271;-15.6729;-17.6654;-10.495;-4.05914;-7.86493;-17.001;-12.4149;-12.5941;-14.0616;-14.5162;-15.972;-15.4304;-14.517;-15.6064;-15.9664;-16.5086;-17.9133;-17.2635;-15.2113;-15.4359;-14.0987;-9.80823;-14.5821;-13.7927;-12.4955;-16.5102;-13.7214;-10.7877;-17.7131;-18.1335;-10.4293;-13.7509;-16.7277;-11.3582;-12.3529;-17.8477;-15.3539;-12.6742;-10.1046;-15.6193;-15.4472;-17.314;-17.8416;-15.5194;-18.0263;-6.56391;-11.0513;1.38063;-3.06276;-1.84931;-7.98687;-6.28446;-5.41261;-1.84158;-12.8608;-10.6269;-3.6243;-6.92941;-6.62876;-6.01436;-10.5834;-14.3696;-2.14467;-5.67292;-5.13612;-4.93935;-7.45779;-10.6826;-14.7532;-10.0716;-13.7228;11.6464; 15 | -13.7181;-12.5575;-8.10632;-11.448;-13.5834;-3.10654;-9.24656;-12.7909;-12.8525;-14.3706;-8.55323;-3.38202;-6.30053;-14.0513;-12.279;-11.492;-12.1787;-12.5838;-12.7661;-12.6879;-12.2649;-13.0572;-13.151;-13.9483;-13.9855;-13.116;-11.4559;-13.3698;-9.42803;-10.0509;-12.1745;-11.6455;-8.93734;-13.0462;-10.8717;-8.33188;-13.2707;-14.6157;-9.89406;-11.2292;-14.9587;-9.84743;-9.02879;-14.0022;-13.3184;-9.88802;-7.78186;-14.4254;-12.3308;-14.3411;-14.3888;-12.3916;-14.5232;-6.62716;-8.72979;-1.54824;-3.54638;-2.84423;-6.21805;-4.78457;-4.84722;-1.41091;-8.93522;-8.32461;-4.18821;-4.61949;-6.64411;-5.71058;-9.24478;-9.71192;-0.637286;-3.37652;-3.26469;-5.41741;-7.57979;-8.46282;-11.6311;-9.33556;-9.87357;10.6531; 16 | 
-21.2109;-28.6874;-20.2671;-20.7262;-23.996;-11.4268;-24.1548;-26.7999;-20.9445;-22.2757;-17.31;-11.3097;-13.3648;-21.3316;-18.6452;-21.1541;-22.7208;-24.9391;-22.6197;-21.373;-22.8799;-22.8126;-23.5006;-24.338;-31.5487;-27.0181;-26.6626;-21.9991;-23.3153;-21.6611;-22.6834;-21.6437;-21.4761;-24.6145;-18.4413;-17.8664;-27.9531;-23.9614;-22.4755;-18.5773;-20.5472;-20.132;-19.9208;-21.8196;-21.5974;-19.4449;-15.2644;-24.3403;-21.5142;-23.1676;-22.2395;-21.2184;-22.6862;-9.8994;-12.8055;-7.05046;-9.83039;-5.40625;-10.6534;-10.5624;-8.0626;-4.88681;-22.019;-11.6992;-7.21109;-8.06215;-6.65972;-8.19594;-10.1938;-22.8913;-0.261905;-4.78007;-6.3482;-7.46272;-6.24609;-9.78128;-15.2243;-11.3962;-16.3209;11.0968; 17 | -6.273;-5.6887;-3.24169;-3.80094;-5.98833;-1.71118;-2.24441;-6.40342;-4.6916;-6.66567;-4.65059;-2.55514;-5.7078;-6.18789;-5.94327;-1.97174;-7.04251;-6.64416;-6.31223;-4.31814;-4.29726;-5.20993;-6.08147;-5.59457;-5.77814;-6.59944;-5.49074;-5.16364;-6.1018;-5.42237;-5.84739;-8.6856;-3.74955;-6.61877;-6.57567;-4.54432;-4.95148;-7.03242;-4.7992;-2.07841;-6.22507;-5.97494;-1.56072;-5.70552;-3.88393;-4.7325;-0.0971746;-6.63266;-2.40751;-6.20121;-6.08217;-2.39869;-6.93161;-4.90425;-5.14693;-0.295856;-5.29437;-3.15111;-2.82243;-4.91697;-3.42036;-0.78257;-2.74643;-4.1808;0.516135;-0.308401;0.655728;-4.35115;-1.63534;-6.55937;9.16152;0.34076;-0.698383;-2.96687;3.1022;-2.7028;-0.806812;-1.46935;-2.35283;1.61181; 18 | 
-14.5291;-14.355;-9.24071;-12.8636;-15.2932;-8.83859;-10.493;-16.0417;-14.4811;-17.2204;-8.70792;-7.30292;-8.88104;-15.6897;-13.6281;-9.88494;-14.5697;-15.7845;-14.4016;-15.0154;-12.0591;-14.1771;-16.2121;-14.8499;-14.3238;-15.2059;-13.5246;-9.71109;-13.6089;-11.3846;-10.9136;-14.6584;-10.5198;-15.2524;-11.5095;-9.38816;-15.179;-15.8933;-10.6717;-10.4381;-14.3311;-11.717;-8.85494;-16.7414;-9.08149;-13.0419;-5.41887;-14.8873;-13.0139;-13.6328;-16.2542;-10.5383;-17.3164;-6.47726;-10.5287;-5.43557;-6.83563;-4.75574;-5.44706;-9.72129;-5.7304;-0.75137;-9.50957;-11.39;-1.43281;-4.01256;-4.06523;-6.0796;-5.97181;-11.7332;3.15472;-5.22511;-3.20488;-8.16883;-6.09454;-11.6492;-10.9396;-7.75306;-10.5182;11.011; 19 | -21.4293;-27.6509;-19.0884;-20.8533;-24.6427;-14.641;-24.9671;-27.5726;-22.1023;-24.5876;-17.4908;-11.386;-12.4026;-21.9382;-19.8994;-19.4536;-21.6806;-23.7832;-23.0706;-22.833;-21.8682;-22.3944;-24.4377;-24.1178;-29.5602;-28.1047;-24.3294;-17.3491;-21.8988;-17.9898;-20.5964;-18.4888;-23.6132;-23.8328;-20.7684;-19.5121;-26.3858;-24.9188;-18.8659;-20.4613;-22.2904;-18.2741;-22.4623;-22.8575;-18.7477;-19.849;-15.6392;-21.9295;-24.0208;-20.1749;-23.5439;-20.956;-23.428;-3.60016;-14.7634;-8.49462;-8.87323;-5.63525;-12.5131;-12.3768;-9.13225;-2.36333;-23.9332;-14.3723;-4.62868;-11.4159;-9.24505;-7.48049;-10.8037;-25.2521;-3.37606;-8.9067;-5.71144;-8.55317;-14.4824;-14.6762;-19.088;-13.274;-19.4988;12.662; 20 | 
-10.9217;-7.25202;-4.83512;-8.69465;-8.85772;-6.3149;-6.68957;-10.0824;-9.09097;-10.6675;-4.88647;-6.7858;-4.12304;-9.81996;-9.58001;-7.58992;-8.07252;-9.22714;-9.77987;-11.1665;-10.9099;-8.84845;-10.149;-8.66161;-7.16901;-8.99043;-8.70395;-2.22834;-10.6042;-8.17383;-9.77649;-7.99855;-10.4668;-8.7189;-9.38903;-7.32637;-12.45;-12.0118;-10.1428;-10.7305;-12.0814;-5.88694;-12.9443;-9.49442;-7.2135;-9.58261;-8.78715;-8.76097;-12.2681;-8.22168;-9.5007;-9.39799;-9.70548;8.97014;-6.27596;-2.46211;-3.88417;3.204;-5.97809;-3.55937;-3.23383;-2.46016;-11.751;-5.41238;-4.73078;-2.39012;-2.48174;2.78955;-4.2048;-6.46183;-3.7001;-3.23048;-4.93732;1.2341;-12.6316;-5.60422;-9.13871;-5.41093;-8.66714;0.451675; 21 | -11.3928;-7.54332;-4.67187;-9.99549;-10.9002;-6.87982;-7.04603;-10.9152;-10.3545;-12.9565;-7.28677;-11.1925;-5.64887;-11.1335;-11.5099;-8.40364;-8.01119;-8.99502;-10.4769;-10.894;-10.6122;-9.4938;-9.42884;-7.95045;-9.63769;-9.62726;-9.42827;-5.19993;-10.0418;-9.29343;-9.30385;-7.80355;-8.78331;-9.60104;-7.12246;-7.0597;-11.838;-10.4173;-7.96765;-10.0347;-10.0274;-8.55928;-10.1838;-12.2348;-5.74153;-11.0263;-7.13112;-8.20146;-9.2194;-7.03371;-12.6927;-6.78688;-11.9447;4.26701;-6.65348;-3.82883;-3.7419;1.04296;-5.88977;-4.14936;-5.97768;-2.75569;-10.4318;-6.34886;-4.67852;-5.17139;-3.9067;-0.962589;-5.01905;-7.19801;-4.74674;-2.84518;-5.64713;-2.70383;-11.1046;-7.55876;-8.54959;-6.92315;-8.15911;8.58913; 22 | 
-12.8817;-12.4988;-7.99978;-12.419;-14.5194;-7.64388;-11.4522;-13.3692;-13.5337;-15.8544;-9.22042;-11.8;-6.31914;-13.5433;-12.8814;-9.23208;-10.1596;-11.5463;-13.3436;-12.7988;-12.8433;-11.2733;-11.2367;-11.2606;-15.327;-13.2319;-13.3346;-8.02277;-10.6301;-9.81606;-11.48;-9.63978;-9.99568;-12.6148;-9.29002;-8.67775;-13.1978;-13.1927;-7.30143;-11.5352;-11.5935;-9.40676;-9.65788;-15.1663;-9.00492;-12.1992;-8.8131;-10.0274;-11.2544;-10.8877;-15.6483;-10.5456;-14.7851;-0.268311;-9.29035;-2.92453;-4.55638;0.556035;-6.46229;-4.68629;-7.93809;-0.877332;-12.1029;-8.39476;-4.61604;-7.27792;-4.29535;-2.97391;-5.0996;-10.582;-4.70517;-3.4483;-6.95097;-2.42768;-9.31983;-9.83868;-9.1535;-8.02332;-10.2319;11.2746; 23 | -12.0009;-15.2282;-10.8465;-11.5795;-14.5322;-4.05495;-13.3105;-14.2095;-12.2971;-14.2436;-9.45503;-8.40874;-6.28837;-12.7461;-12.0556;-10.9082;-12.1406;-13.746;-13.1214;-12.9285;-12.2496;-11.8811;-12.2772;-12.6765;-17.0389;-15.0631;-14.2084;-10.5577;-10.4549;-11.3373;-12.9135;-10.5708;-9.83614;-12.8608;-11.5089;-10.192;-13.6683;-13.5708;-9.58842;-11.9542;-12.7791;-9.75811;-10.2066;-13.6697;-11.1113;-11.2398;-10.5772;-11.4773;-11.6152;-12.9419;-13.7376;-9.83028;-13.4083;-5.96495;-9.22633;-3.25587;-6.5791;-1.70662;-6.04294;-3.91719;-10.2049;-1.37793;-9.55177;-8.64386;-3.74104;-7.15281;-4.35463;-5.25486;-4.9779;-9.48934;-3.40461;-3.14784;-7.18311;-4.54627;-8.35667;-9.77385;-8.89604;-6.30414;-9.29717;11.3098; 24 | 
-4.9991;-10.5712;-10.0995;-5.56646;-6.89062;-2.65767;-8.47829;-8.29081;-4.70257;-6.34379;-7.52629;-7.46118;-7.34643;-6.49724;-6.69796;-7.65433;-5.55197;-7.61467;-6.15882;-6.36032;-6.9974;-5.41975;-5.37397;-6.5983;-8.59427;-6.87867;-9.23085;-4.72853;-7.15368;-9.24607;-8.49292;-6.30416;-3.63588;-7.35057;-8.93296;-7.93357;-6.94023;-8.85019;-9.69593;-7.14139;-8.45006;-6.42365;-3.75407;-7.29685;-4.65626;-7.9512;-5.01393;-8.94441;-8.668;-9.11438;-6.36174;-6.17438;-5.91937;-6.17745;-6.30723;-4.02626;-5.39871;-2.74674;3.68191;-0.0841334;-8.16753;-2.37589;-2.25918;-3.49428;-1.01308;-5.87307;-2.70209;-6.75424;6.028;-0.950421;-3.36887;-2.46741;-3.38427;-6.6556;-6.62064;-6.06423;-2.48371;-2.90236;-2.83057;4.45127; 25 | -5.60931;-8.72476;-7.92563;-5.86569;-6.18704;-1.16881;-4.98263;-6.47676;-5.16655;-6.90312;-5.29201;-7.4943;-5.81143;-6.89678;-6.46216;-6.12201;-2.872;-6.18893;-6.55101;-6.6127;-6.40438;-5.89229;-4.04733;-6.55891;-6.86163;-4.10772;-7.20434;-2.50527;-5.62841;-4.51968;-6.52436;-4.3398;-2.27234;-5.26388;-7.5158;-5.12595;-5.07723;-7.32904;-5.12683;-3.03626;-4.42124;-4.53199;0.473405;-7.78474;-1.92557;-7.1234;-3.82444;-6.47092;-6.62417;-6.82717;-7.33196;-4.39817;-7.10504;-5.15974;-6.37219;-2.47913;-6.47875;-1.283;4.73254;0.928209;-4.51568;-2.38306;-2.93453;-4.48432;-0.61104;-7.69219;-1.18316;-6.89708;8.90815;-0.860142;-2.76933;-2.73624;-2.97848;-4.63815;-5.75445;-6.25519;-3.53438;-0.87253;-2.9639;4.88026; 26 | 
-10.3876;-12.264;-9.88955;-11.4019;-11.7338;-1.07113;-8.96669;-10.6254;-10.9743;-13.1636;-7.56295;-7.21757;-6.88477;-12.3836;-9.41012;-10.7539;-7.43038;-10.826;-10.7632;-10.8693;-10.9601;-11.3654;-9.61885;-10.9413;-12.863;-9.15181;-9.63325;-7.0828;-9.31105;-5.35807;-9.19147;-7.4615;-6.96753;-10.8639;-9.15506;-6.93531;-10.2294;-11.1727;-5.75895;-6.65718;-9.01967;-7.61485;-3.80197;-13.6211;-6.21752;-8.80429;-5.50176;-7.2984;-9.48775;-10.1607;-13.3421;-7.54119;-12.9379;-5.8231;-9.57155;-1.89708;-8.87232;0.230901;-1.50246;-3.37894;-3.93338;-3.85665;-11.2514;-9.52763;-1.55925;-12.0768;-3.64594;-5.99508;1.22063;-8.66952;-4.38462;-6.00791;-5.53701;-3.99541;-8.66662;-9.48338;-8.49153;-4.64901;-7.59933;10.9674; 27 | -12.6414;-15.6479;-10.5688;-14.5001;-15.9936;1.39343;-12.1634;-13.3525;-13.527;-15.7302;-8.0675;-3.86562;-6.34362;-14.4951;-9.91035;-12.9071;-11.5489;-13.9152;-14.2125;-12.7321;-14.1033;-14.1064;-13.0403;-13.7426;-19.3563;-14.9715;-13.2192;-11.8192;-12.336;-7.01853;-12.6988;-10.8682;-10.2478;-14.1554;-12.2164;-8.89626;-13.9113;-15.2632;-9.48871;-11.3308;-13.2232;-8.53363;-9.43692;-15.3159;-12.5814;-9.09795;-8.66831;-10.4732;-11.6172;-14.5537;-15.4538;-9.73303;-15.3213;-7.5198;-11.2711;-0.389471;-8.04587;-1.82482;-4.16988;-6.38587;-4.31795;-4.84769;-14.3528;-12.1339;-1.07345;-10.8404;-5.57144;-5.86499;-4.29962;-11.7889;-4.96434;-6.60751;-5.84706;-3.48416;-11.1467;-10.8615;-11.2676;-7.47887;-10.6293;10.3836; 28 | 
-13.9625;-16.1006;-12.397;-15.1642;-17.9323;1.06935;-14.31;-14.0284;-14.8409;-17.4538;-8.75588;-4.50911;-7.29898;-15.9729;-11.8626;-13.8299;-14.4798;-17.2899;-14.9056;-14.9063;-15.2016;-16.2342;-15.8811;-16.4977;-22.9584;-16.7627;-15.562;-11.2225;-13.285;-11.3453;-14.052;-12.2526;-11.4159;-14.6909;-11.8626;-9.64001;-16.1024;-16.7757;-11.8619;-12.4382;-14.0983;-11.2499;-12.1581;-17.2572;-12.8439;-10.6754;-10.3944;-13.4508;-13.5347;-14.8794;-17.3247;-12.3999;-17.6249;-7.56381;-8.83183;-2.5927;-8.1036;-3.14862;-3.65701;-6.98684;-2.93294;-4.04735;-15.0206;-9.92781;0.118442;-9.35144;-4.6515;-5.36948;-6.46101;-12.3444;-4.59262;-5.45201;-3.46494;-5.23537;-11.7714;-11.0713;-11.4171;-10.2829;-10.9135;10.0771; 29 | -10.4978;-12.4396;-9.07436;-10.7931;-13.3886;0.807511;-12.7995;-11.6317;-10.261;-13.6957;-5.01735;-6.37962;-6.49825;-12.0441;-10.7718;-5.95541;-10.3395;-11.4658;-9.57307;-10.6093;-9.04322;-10.961;-9.59194;-11.1184;-17.0971;-12.6076;-10.8442;-3.68629;-7.24575;-8.43352;-9.81173;-6.6359;-5.34494;-8.9958;-5.6702;-4.59376;-11.0168;-9.54804;-4.17835;-8.20373;-6.00902;-7.55099;-7.00227;-12.6291;-7.04198;-6.49973;-8.94795;-8.31459;-8.7489;-6.68038;-13.5401;-8.38697;-13.7395;-4.58442;-4.90689;-5.67812;-5.82777;-1.96531;-0.19797;-5.50077;-1.39287;-1.17468;-10.4201;-5.23984;-0.328433;-5.78809;-4.20796;-2.53646;-4.82854;-11.2462;-2.61545;-3.27244;1.31567;-4.6023;-7.04678;-6.96595;-6.42697;-6.86922;-5.95585;8.97875; 30 | 
-14.5882;-17.9515;-12.4786;-14.27;-19.5324;-4.25172;-17.8458;-17.4066;-15.314;-18.2521;-7.79907;-10.1787;-8.15769;-15.8189;-13.575;-10.4072;-14.151;-15.8624;-15.212;-15.4305;-12.0821;-16.5163;-14.4634;-14.8897;-21.3163;-16.7181;-14.7763;-7.32454;-11.3358;-10.8644;-10.9987;-10.1639;-9.34151;-12.7723;-8.79709;-8.60186;-15.8344;-13.3583;-6.84352;-7.86196;-7.1588;-9.35976;-10.394;-17.7766;-9.58959;-9.77608;-9.26338;-10.8112;-12.4635;-7.95737;-18.1809;-11.6901;-17.9574;-6.10836;-6.44119;-8.54675;-6.65771;-2.17478;-3.76748;-9.13337;-0.913811;-2.38499;-14.9783;-6.08852;-2.45912;-7.77453;-4.22172;-1.41564;-6.27772;-14.9532;-3.3609;-5.80879;1.95068;-4.81334;-8.15642;-6.22375;-8.77917;-8.27524;-10.1648;14.2672; 31 | -12.4425;-21.561;-13.0964;-13.8778;-18.5622;-7.22929;-17.5512;-17.5014;-15.5607;-16.8187;-11.6874;-5.66178;-6.64332;-15.4047;-11.8127;-11.9184;-13.0036;-15.9206;-14.5632;-14.6815;-13.5714;-14.8475;-13.47;-13.2457;-24.1174;-18.1276;-15.4098;-9.09456;-11.8193;-9.72197;-10.4182;-9.38709;-11.4523;-12.5712;-10.3963;-11.1286;-16.4189;-14.3378;-7.9943;-9.40662;-8.85535;-10.4697;-10.9746;-17.0542;-10.0881;-11.1871;-7.5449;-9.7232;-12.1161;-8.19456;-16.8198;-11.6223;-16.5738;-3.47356;-8.91376;-4.52963;-6.00827;1.19185;-3.60238;-8.84962;-0.304443;-2.23194;-15.7692;-5.46351;-2.44811;-8.32986;-1.98053;-2.13116;-7.28914;-12.8237;-0.334987;-5.28257;1.90552;-1.33145;-8.93931;-6.38136;-8.05429;-7.36091;-8.88246;11.2779; 32 | 
-7.13515;-5.52204;-3.28631;-5.07854;-7.54506;-1.01173;-6.50107;-4.70085;-7.8216;-9.52412;-3.54314;0.362296;0.691244;-8.96489;-5.02481;-3.06606;-4.88325;-4.32368;-5.09434;-6.04878;-5.7849;-5.50493;-5.14916;-4.85093;-9.55186;-5.10806;-3.69755;-1.27203;-2.66242;-2.3925;-2.37045;-1.2499;-2.58325;-2.82935;-2.19384;-3.35373;-3.92999;-5.87393;-0.59945;-1.39301;-2.65359;-2.58713;-3.28158;-9.54959;-1.95553;-2.119;0.146295;-3.0373;-4.48239;-2.70575;-9.8039;-2.69899;-9.63146;-0.950681;-2.33663;-0.816389;0.624837;1.59769;2.75077;-2.62994;2.21472;-0.187826;-6.66059;-0.189182;1.83678;-4.2679;0.199314;-0.298237;-3.35348;-3.16512;2.95968;0.295224;8.84706;-0.266561;-4.97011;-3.6491;-1.91869;-0.7069;-3.68616;2.61447; 33 | -------------------------------------------------------------------------------- /doc/comparison.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/githubharald/CTCDecoder/6b5c3dd34944e5399a7308e241319b7f9c47e7c3/doc/comparison.pdf -------------------------------------------------------------------------------- /doc/line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/githubharald/CTCDecoder/6b5c3dd34944e5399a7308e241319b7f9c47e7c3/doc/line.png -------------------------------------------------------------------------------- /doc/mini.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/githubharald/CTCDecoder/6b5c3dd34944e5399a7308e241319b7f9c47e7c3/doc/mini.png -------------------------------------------------------------------------------- /doc/word.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/githubharald/CTCDecoder/6b5c3dd34944e5399a7308e241319b7f9c47e7c3/doc/word.png -------------------------------------------------------------------------------- /extras/best_path_cl.cl: 
-------------------------------------------------------------------------------- 1 | /* 2 | best path decoding implemented in OpenCL 3 | two variants are provided: 4 | variant 1: single-pass kernel 5 | variant 2: two-pass kernel 6 | time measurement for 10000 batch elements on an AMD Radeon 8570: 7 | variant 1: 25ms 8 | variant 2: 355ms + 5ms 9 | MAX_T, MAX_C and STEP_BEGIN are defined via program build options to avoid passing constant values to each kernel 10 | */ 11 | 12 | 13 | // index of BxTxC matrix to 1d index 14 | int btcOffset1d(int b, int t) 15 | { 16 | return b * MAX_C * MAX_T + t * MAX_C; 17 | } 18 | 19 | 20 | // index of BxT matrix to 1d index 21 | int btOffset1d(int b) 22 | { 23 | return b * MAX_T; 24 | } 25 | 26 | 27 | // variant 1: single pass kernel 28 | __kernel void bestPathAndCollapse(__global float* in, __global int* out) 29 | { 30 | // constants 31 | const int b = get_global_id(0); 32 | const int t = get_local_id(1); 33 | const int blankLabel = MAX_C - 1; 34 | const int btcOffset = btcOffset1d(b, t); 35 | 36 | // find character with highest probability 37 | float bestVal = 0.0f; 38 | int bestIdx = 0; 39 | for(int c = 0; c < MAX_C; ++c) 40 | { 41 | const float currVal = in[btcOffset + c]; 42 | if(currVal > bestVal) 43 | { 44 | bestVal = currVal; 45 | bestIdx = c; 46 | } 47 | } 48 | 49 | // save result in local memory 50 | __local int locIdx[MAX_T]; 51 | locIdx[t] = bestIdx; 52 | barrier(CLK_LOCAL_MEM_FENCE); 53 | 54 | // collapse 55 | if(t == 0) 56 | { 57 | const int btOffset = btOffset1d(b); 58 | int lastLabel = blankLabel; 59 | int v = 0; 60 | for(int u = 0; u < MAX_T; ++u) 61 | { 62 | const int currLabel = locIdx[u]; 63 | if(currLabel != lastLabel && currLabel != blankLabel) 64 | { 65 | out[btOffset + v] = currLabel; 66 | v++; 67 | } 68 | lastLabel = currLabel; 69 | } 70 | 71 | // put end marker at end of label string if needed 72 | if(v != MAX_T) 73 | { 74 | out[btOffset + v] = blankLabel; 75 | } 76 | } 77 | } 78 | 79 | 80 | // struct holds 
index and value of a character 81 | typedef struct __attribute__ ((packed)) 82 | { 83 | float val; 84 | int idx; 85 | } ValueIndexPair; 86 | 87 | 88 | // variant 2: pass 1/2, compute best path 89 | __kernel void bestPath(__global float* in, __global int* out) 90 | { 91 | // constants 92 | const int b = get_global_id(0); 93 | const int t = get_global_id(1); 94 | const int c = get_local_id(2); 95 | 96 | // put into local memory 97 | __local ValueIndexPair valueIndexPairs[MAX_C]; 98 | __local ValueIndexPair* currPtr = valueIndexPairs + c; 99 | currPtr->val = in[btcOffset1d(b, t)+c]; 100 | currPtr->idx = c; 101 | barrier(CLK_LOCAL_MEM_FENCE); 102 | 103 | // reduce to largest value and corresponding index 104 | for(int i = STEP_BEGIN; i > 0; i >>= 1) 105 | { 106 | if(c < i && c + i < MAX_C) 107 | { 108 | __local ValueIndexPair* otherPtr = valueIndexPairs + c + i; 109 | *currPtr = currPtr->val < otherPtr->val ? *otherPtr : *currPtr; 110 | } 111 | 112 | barrier(CLK_LOCAL_MEM_FENCE); 113 | } 114 | 115 | // write best label index to global memory 116 | if(c == 0) 117 | { 118 | out[btOffset1d(b) + t] = currPtr->idx; 119 | } 120 | } 121 | 122 | 123 | // variant 2: pass 2/2, collapse best path 124 | __kernel void collapsePath(__global int* in, __global int* out) 125 | { 126 | // constants 127 | const int b = get_global_id(0); 128 | const int blankLabel = MAX_C - 1; 129 | const int btOffset = btOffset1d(b); 130 | 131 | // collapse 132 | int lastLabel = blankLabel; 133 | int v = 0; 134 | for(int u = 0; u < MAX_T; ++u) 135 | { 136 | const int currLabel = in[btOffset + u]; 137 | if(currLabel != lastLabel && currLabel != blankLabel) 138 | { 139 | out[btOffset + v] = currLabel; 140 | v++; 141 | } 142 | lastLabel = currLabel; 143 | } 144 | 145 | // put end marker at end of label string if needed 146 | if(v != MAX_T) 147 | { 148 | out[btOffset + v] = blankLabel; 149 | } 150 | } 151 | 152 | -------------------------------------------------------------------------------- 
/extras/best_path_cl.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import pyopencl as cl 7 | 8 | 9 | class CLWrapper: 10 | "class holds information about OpenCL state" 11 | 12 | def __init__(self, batchSize, maxT, maxC, kernelVariant=1, enableGPUDebug=False): 13 | """Specify size: number of batch elements, number of time-steps, number of characters. 14 | Set kernelVariant to either 1 or 2. Set enableGPUDebug to True to debug kernel via CodeXL.""" 15 | 16 | # force rebuild of program such that GPU debugger can attach to kernel 17 | self.enableGPUDebug = enableGPUDebug 18 | if enableGPUDebug: 19 | os.environ['PYOPENCL_COMPILER_OUTPUT'] = '1' 20 | os.environ['PYOPENCL_NO_CACHE'] = '1' 21 | 22 | # consts 23 | self.batchSize = batchSize 24 | self.maxT = maxT 25 | self.maxC = maxC 26 | assert kernelVariant in [1, 2] 27 | self.kernelVariant = kernelVariant 28 | 29 | # platform, context, queue 30 | platforms = cl.get_platforms() 31 | assert platforms 32 | self.platform = platforms[0] # take first platform 33 | devices = self.platform.get_devices(cl.device_type.GPU) # get GPU devices 34 | assert devices 35 | self.device = devices[0] # take first GPU 36 | self.context = cl.Context([self.device]) # context contains the first GPU 37 | self.queue = cl.CommandQueue(self.context, self.device) # command queue to first GPU 38 | 39 | # buffers 40 | sizeOfFloat32 = 4 41 | batchBufSize = batchSize * maxC * maxT * sizeOfFloat32 42 | self.batchBuf = cl.Buffer(self.context, cl.mem_flags.READ_ONLY, size=batchBufSize, hostbuf=None) 43 | self.res = np.zeros([batchSize, maxT]).astype(np.int32) 44 | self.resBuf = cl.Buffer(self.context, cl.mem_flags.WRITE_ONLY, self.res.nbytes) 45 | self.tmpBuf = cl.Buffer(self.context, cl.mem_flags.READ_WRITE, self.res.nbytes) # written by bestPath, read by collapsePath 46 | 47 | # compile program and use defines for program-constants to avoid passing private variables 48 | buildOptions =
'-D STEP_BEGIN={} -D MAX_T={} -D MAX_C={}'.format(2 ** math.ceil(math.log2(maxC)), maxT, maxC) # STEP_BEGIN: power of two >= MAX_C, first stride of the char-wise max reduction in bestPath 49 | self.program = cl.Program(self.context, open('best_path_cl.cl').read()).build(buildOptions) 50 | 51 | # variant 1: single pass 52 | if kernelVariant == 1: 53 | self.kernel1 = cl.Kernel(self.program, 'bestPathAndCollapse') 54 | self.kernel1.set_arg(0, self.batchBuf) 55 | self.kernel1.set_arg(1, self.resBuf) 56 | 57 | # all time-steps must fit into a work-group 58 | assert maxT <= self.kernel1.get_work_group_info(cl.kernel_work_group_info.WORK_GROUP_SIZE, self.device) 59 | 60 | # variant 2: two passes 61 | else: 62 | # kernel1: calculate best path 63 | self.kernel1 = cl.Kernel(self.program, 'bestPath') 64 | self.kernel1.set_arg(0, self.batchBuf) 65 | self.kernel1.set_arg(1, self.tmpBuf) 66 | 67 | # kernel2: collapse best path 68 | self.kernel2 = cl.Kernel(self.program, 'collapsePath') 69 | self.kernel2.set_arg(0, self.tmpBuf) 70 | self.kernel2.set_arg(1, self.resBuf) 71 | 72 | # all chars must fit into a work-group 73 | assert maxC <= self.kernel1.get_work_group_info(cl.kernel_work_group_info.WORK_GROUP_SIZE, self.device) 74 | 75 | def compute(self, batch): 76 | "compute best path for each batch element. Returns blank-terminated label strings for batch elements."
77 | 78 | # measure time in GPU debug mode 79 | if self.enableGPUDebug: 80 | t0 = time.time() 81 | 82 | # copy batch to device 83 | cl.enqueue_copy(self.queue, self.batchBuf, batch.astype(np.float32), is_blocking=False) 84 | 85 | # one pass 86 | if self.kernelVariant == 1: 87 | cl.enqueue_nd_range_kernel(self.queue, self.kernel1, (self.batchSize, self.maxT), (1, self.maxT)) 88 | # two passes 89 | else: 90 | cl.enqueue_nd_range_kernel(self.queue, self.kernel1, (self.batchSize, self.maxT, self.maxC), 91 | (1, 1, self.maxC)) 92 | cl.enqueue_nd_range_kernel(self.queue, self.kernel2, (self.batchSize,), None) 93 | 94 | # copy result back from GPU and return it 95 | cl.enqueue_copy(self.queue, self.res, self.resBuf, is_blocking=True) 96 | 97 | # measure time in GPU debug mode 98 | if self.enableGPUDebug: 99 | t1 = time.time() 100 | print('BestPathCL.compute(...) time: ', t1 - t0) 101 | 102 | return self.res 103 | 104 | 105 | def ctcBestPathCL(batch, classes, clWrapper): 106 | "implements best path decoding on the GPU with OpenCL" 107 | 108 | # compute best labeling 109 | labelStrBatch = clWrapper.compute(batch) 110 | 111 | # go over batch 112 | blank = len(classes) 113 | charStrBatch = [] 114 | for b in range(clWrapper.batchSize): 115 | # map to chars 116 | charStr = '' 117 | for label in labelStrBatch[b]: 118 | if label == blank: 119 | break 120 | charStr += classes[label] 121 | charStrBatch.append(charStr) 122 | 123 | return charStrBatch 124 | 125 | 126 | def testBestPathCL(): 127 | "test decoder" 128 | classes = 'ab' 129 | mat = np.array([[0.4, 0, 0.6], [0.4, 0, 0.6]]) 130 | maxT, maxC = mat.shape 131 | clWrapper = CLWrapper(1, maxT, maxC, enableGPUDebug=True) 132 | print('Test best path decoding (CL)') 133 | expected = '' 134 | actual = ctcBestPathCL(np.stack([mat]), classes, clWrapper)[0] 135 | print('Expected: "' + expected + '"') 136 | print('Actual: "' + actual + '"') 137 | print('OK' if expected == actual else 'ERROR') 138 | 139 | 140 | if __name__
== '__main__': 141 | testBestPathCL() 142 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | editdistance 2 | numpy 3 | pytest -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='ctc-decoder', 5 | version='1.0.1', 6 | description='Connectionist Temporal Classification decoders.', 7 | author='Harald Scheidl', 8 | packages=['ctc_decoder'], 9 | url="https://github.com/githubharald/CTCDecoder", 10 | install_requires=['editdistance', 'numpy'], 11 | python_requires='>=3.7' 12 | ) 13 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Testcases 2 | 3 | ## Mini example 4 | The RNN output matrix of the **Mini example** testcase contains 2 time-steps (t0 and t1) and 3 labels (a, b, and - representing the CTC-blank). 5 | Best path decoding (see left figure) takes the most probable label per time-step, which gives the path "--" and therefore the recognized text "" with probability 0.6 * 0.6 = 0.36. 6 | Beam search, prefix search and token passing instead calculate the probability of labelings. 7 | For the labeling "a", these algorithms sum over the paths "-a", "a-" and "aa" (see right figure), giving probability 0.6 * 0.4 + 0.4 * 0.6 + 0.4 * 0.4 = 0.64. 8 | The only path that gives "" has probability 0.36, so "a" is the result returned by beam search, prefix search and token passing. 9 | 10 | ![mini](../doc/mini.png) 11 | 12 | ## Word example 13 | The **Word example** testcase contains a single word from the IAM Handwriting Database. 14 | It is used to test lexicon search.
15 | The RNN output was generated with the [SimpleHTR](https://github.com/githubharald/SimpleHTR) model (using the `--dump` option). 16 | Lexicon search first computes an approximation with best path decoding, then searches a dictionary for similar words using a BK tree, and finally scores each candidate by computing its loss, returning the most probable dictionary word. 17 | Best path decoding outputs "aircrapt"; lexicon search finds similar dictionary words such as "aircraft" and "airplane", scores each of them, and returns "aircraft", which is the correct result. 18 | The figure below shows the input image and the RNN output matrix with 32 time-steps and 80 classes (the last one being the CTC-blank). 19 | Each column sums to 1, and each entry represents the probability of seeing a label at a given time-step. 20 | 21 | 22 | ![word](../doc/word.png) 23 | 24 | ## Line example 25 | The ground-truth text of the **Line example** testcase is "the fake friend of the family, like the" and is a sample from the IAM Handwriting Database. 26 | This testcase is used to test all algorithms except lexicon search. 27 | The RNN output was generated by a partially trained TensorFlow model inspired by CRNN, which is essentially a larger version of the [SimpleHTR](https://github.com/githubharald/SimpleHTR) model. 28 | The figure below shows the input image and the RNN output matrix with 100 time-steps and 80 classes.
29 | 30 | ![line](../doc/line.png) 31 | 32 | -------------------------------------------------------------------------------- /tests/test_bk_tree.py: -------------------------------------------------------------------------------- 1 | import editdistance as ed 2 | 3 | from ctc_decoder import BKTree 4 | 5 | 6 | def test_bk_tree(): 7 | "test BK tree on words from corpus" 8 | with open('../data/word/corpus.txt') as f: 9 | words = f.read().split() 10 | 11 | tolerance = 2 12 | t = BKTree(words) 13 | q = 'air' 14 | actual = sorted(t.query(q, tolerance)) 15 | expected = sorted([w for w in words if ed.eval(q, w) <= tolerance]) 16 | assert actual == expected 17 | -------------------------------------------------------------------------------- /tests/test_language_model.py: -------------------------------------------------------------------------------- 1 | from ctc_decoder import LanguageModel 2 | 3 | 4 | def test_char_bigram(): 5 | lm = LanguageModel('aab abc', 'ab') 6 | assert lm.get_char_bigram('a', 'a') == 1 / 3 7 | assert lm.get_char_bigram('a', 'b') == 2 / 3 8 | assert lm.get_char_bigram('b', 'a') == 0 9 | -------------------------------------------------------------------------------- /tests/test_mini_example.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from ctc_decoder import beam_search 5 | from ctc_decoder import best_path 6 | from ctc_decoder import prefix_search 7 | from ctc_decoder import probability, loss 8 | from ctc_decoder import token_passing 9 | 10 | 11 | @pytest.fixture 12 | def mat(): 13 | return np.array([[0.4, 0, 0.6], [0.4, 0, 0.6]]) 14 | 15 | 16 | @pytest.fixture 17 | def chars(): 18 | return 'ab' 19 | 20 | 21 | def test_beam_search(mat, chars): 22 | expected = 'a' 23 | actual = beam_search(mat, chars) 24 | assert actual == expected 25 | 26 | 27 | def test_best_path(mat, chars): 28 | expected = '' 29 | actual = best_path(mat, chars) 30 | assert actual == 
expected 31 | 32 | 33 | def test_token_passing(mat, chars): 34 | expected = 'a' 35 | actual = token_passing(mat, chars, ['a', 'b', 'ab', 'ba']) 36 | assert actual == expected 37 | 38 | 39 | def test_prefix_search(mat, chars): 40 | expected = 'a' 41 | actual = prefix_search(mat, chars) 42 | assert actual == expected 43 | 44 | 45 | def test_probability(mat, chars): 46 | expected = 0.64 47 | actual = probability(mat, 'a', chars) 48 | assert np.isclose(actual, expected) 49 | 50 | 51 | def test_loss(mat, chars): 52 | expected = -np.log(0.64) 53 | actual = loss(mat, 'a', chars) 54 | assert np.isclose(actual, expected) 55 | -------------------------------------------------------------------------------- /tests/test_real_example.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from ctc_decoder import BKTree 7 | from ctc_decoder import LanguageModel 8 | from ctc_decoder import beam_search 9 | from ctc_decoder import best_path 10 | from ctc_decoder import lexicon_search 11 | from ctc_decoder import loss 12 | from ctc_decoder import prefix_search_heuristic_split 13 | from ctc_decoder import probability 14 | from ctc_decoder import token_passing 15 | 16 | 17 | def softmax(mat): 18 | maxT, _ = mat.shape # dim0=t, dim1=c 19 | res = np.zeros(mat.shape) 20 | for t in range(maxT): 21 | y = mat[t, :] 22 | e = np.exp(y - np.max(y)) # subtract the row maximum for numerical stability 23 | s = np.sum(e) 24 | res[t, :] = e / s 25 | return res 26 | 27 | 28 | def load_rnn_output(fn): 29 | return np.genfromtxt(fn, delimiter=';')[:, :-1] # drop the empty column left by the trailing ';' 30 | 31 | 32 | @pytest.fixture 33 | def line_mat(): 34 | return softmax(load_rnn_output('../data/line/rnnOutput.csv')) 35 | 36 | 37 | @pytest.fixture 38 | def word_mat(): 39 | return softmax(load_rnn_output('../data/word/rnnOutput.csv')) 40 | 41 | 42 | @pytest.fixture 43 | def chars(): 44 | return ' 
!"#&\'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 45 | 46 | 47 | @pytest.fixture 48 | def corpus(): 49 | with
open('../data/line/corpus.txt') as f: 50 | txt = f.read() 51 | return txt 52 | 53 | 54 | @pytest.fixture 55 | def words(): 56 | with open('../data/word/corpus.txt') as f: 57 | words = f.read().split() 58 | return words 59 | 60 | 61 | def test_line_example_best_path(line_mat, chars): 62 | mat = line_mat 63 | assert best_path(mat, chars) == 'the fak friend of the fomly hae tC' 64 | 65 | 66 | def test_line_example_prefix_search_heuristic_split(line_mat, chars): 67 | mat = line_mat 68 | assert prefix_search_heuristic_split(mat, chars) == 'the fak friend of the fomcly hae tC' 69 | 70 | 71 | def test_line_example_beam_search(line_mat, chars): 72 | mat = line_mat 73 | assert beam_search(mat, chars) == 'the fak friend of the fomcly hae tC' 74 | 75 | 76 | def test_line_example_beam_search_with_language_model(line_mat, chars, corpus): 77 | mat = line_mat 78 | 79 | # create language model from text corpus 80 | lm = LanguageModel(corpus, chars) 81 | 82 | assert beam_search(mat, chars, lm=lm) == 'the fake friend of the family, lie th' 83 | 84 | 85 | def test_line_example_token_passing(line_mat, chars, corpus): 86 | mat = line_mat 87 | 88 | # extract the word list used by token passing from the text corpus 89 | words = re.findall(r'\w+', corpus) 90 | 91 | assert token_passing(mat, chars, words) == 'the fake friend of the family fake the' 92 | 93 | 94 | def test_line_example_loss_and_probability(line_mat, chars): 95 | mat = line_mat 96 | gt = 'the fake friend of the family, like the' 97 | 98 | assert np.isclose(probability(mat, gt, chars), 6.31472642886565e-13, atol=0) # atol=0: the default atol would accept any value below ~1e-8 99 | assert np.isclose(loss(mat, gt, chars), 28.090721774903226) 100 | 101 | 102 | def test_word_example_best_path(word_mat, chars, words): 103 | mat = word_mat 104 | assert best_path(mat, chars) == 'aircrapt' 105 | 106 | 107 | def test_word_example_lexicon_search(word_mat, chars, words): 108 | mat = word_mat 109 | 110 | # create BK tree from list of words 111 | bk_tree = BKTree(words) 112 | 113 | assert lexicon_search(mat, chars, bk_tree,
tolerance=4) == 'aircraft' 114 | --------------------------------------------------------------------------------
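The Mini example in `tests/README.md` contrasts best path decoding (result `""`) with decoders that sum over all paths of a labeling (result `"a"`, probability 0.64), and the kernels in `extras/best_path_cl.cl` implement the same argmax-then-collapse rule on the GPU. For cross-checking either, a minimal pure-NumPy sketch of both ideas may help. It assumes, as the package README states, a softmax-normalized TxC matrix whose last column is the CTC-blank. The helper names `collapse`, `best_path_reference` and `labeling_probability_brute_force` are hypothetical and are not part of the `ctc_decoder` package.

```python
import itertools

import numpy as np


def collapse(path, blank):
    """Collapse a CTC path: merge repeated labels, then drop blanks."""
    labels = []
    prev = None
    for label in path:
        if label != prev and label != blank:
            labels.append(label)
        prev = label
    return labels


def best_path_reference(mat, chars):
    """Best path decoding: most probable label per time-step, then collapse.
    Assumption: the blank is the last column (index len(chars))."""
    blank = len(chars)
    path = np.argmax(mat, axis=1)
    return ''.join(chars[label] for label in collapse(path, blank))


def labeling_probability_brute_force(mat, text, chars):
    """Sum the probabilities of all paths that collapse to `text`.
    Enumerates C**T paths, so it is only usable for tiny matrices."""
    max_t, max_c = mat.shape
    blank = len(chars)
    target = [chars.index(c) for c in text]
    total = 0.0
    for path in itertools.product(range(max_c), repeat=max_t):
        if collapse(path, blank) == target:
            # probability of one path = product of its per-time-step entries
            total += np.prod([mat[t, label] for t, label in enumerate(path)])
    return total


if __name__ == '__main__':
    mat = np.array([[0.4, 0, 0.6], [0.4, 0, 0.6]])
    chars = 'ab'
    print(best_path_reference(mat, chars))                    # empty string
    print(labeling_probability_brute_force(mat, 'a', chars))  # approximately 0.64
    print(labeling_probability_brute_force(mat, '', chars))   # approximately 0.36
```

Enumerating every path makes the README's argument explicit: the single most probable path `--` (0.36) collapses to `""`, while `-a`, `a-` and `aa` together give `"a"` a total of 0.64. The package's decoders exist precisely to avoid this exponential enumeration on realistic inputs such as the 100x80 line matrix.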