├── ochre
│   ├── __init__.py
│   ├── cwl
│   │   ├── vudnc-select-files.cwl
│   │   ├── count-chars.cwl
│   │   ├── rmgarbage.cwl
│   │   ├── sac2gs-and-ocr.cwl
│   │   ├── clin2018st-extract-text.cwl
│   │   ├── create-data-division.cwl
│   │   ├── remove-empty-files.cwl
│   │   ├── kb-tss-concat-files.cwl
│   │   ├── ocrevaluation-extract.cwl
│   │   ├── char-align.cwl
│   │   ├── remove-title-page.cwl
│   │   ├── select-test-files.cwl
│   │   ├── icdar2017st-extract-text.cwl
│   │   ├── match-ocr-and-gs.cwl
│   │   ├── lstm-synced-correct-ocr.cwl
│   │   ├── vudnc2ocr-and-gs.cwl
│   │   ├── merge-json.cwl
│   │   ├── create-word-mappings.cwl
│   │   ├── ocrevaluation-performance-wf.cwl
│   │   ├── ocr-transform.cwl
│   │   ├── onmt-tokenize-text.cwl
│   │   ├── onmt-main.cwl
│   │   ├── onmt-build-vocab.cwl
│   │   ├── icdar2017st-extract-data.cwl
│   │   ├── kb-tss-preprocess-single-dir.cwl
│   │   ├── align-texts-wf.cwl
│   │   ├── post-correct-dir.cwl
│   │   ├── word-mapping-wf.cwl
│   │   ├── post-correct-test-files.cwl
│   │   ├── lowercase-directory.cwl
│   │   └── sac-extract.cwl
│   ├── select_vudnc_files.py
│   ├── select_test_files.py
│   ├── dncvu_select_ocr_and_gs_texts.py
│   ├── count_chars.py
│   ├── kb_tss_concat_files.py
│   ├── remove_empty_files.py
│   ├── merge_json.py
│   ├── char_align.py
│   ├── create_data_division.py
│   ├── match_ocr_and_gs.py
│   ├── sac2gs_and_ocr.py
│   ├── icdar2017st_extract_text.py
│   ├── remove_title_page.py
│   ├── clin2018st_extract_text.py
│   ├── ocrevaluation_extract.py
│   ├── edlibutils.py
│   ├── scramble.py
│   ├── create_word_mappings.py
│   ├── lstm_synced_correct_ocr.py
│   ├── vudnc2ocr_and_gs.py
│   ├── rmgarbage.py
│   ├── keras_utils.py
│   ├── train_seq2seq.py
│   └── train_mt.py
├── tests
│   ├── data
│   │   ├── ocrevaluation
│   │   │   ├── in
│   │   │   │   ├── gs.txt
│   │   │   │   └── ocr.txt
│   │   │   └── out
│   │   │       └── gs_out.html
│   │   └── ocrevaluation-extract
│   │       ├── out
│   │       │   ├── empty-global.csv
│   │       │   ├── in-global.csv
│   │       │   ├── empty-character.csv
│   │       │   └── in-character.csv
│   │       └── in
│   │           ├── empty.html
│   │           └── in.html
│   ├── test_ocrerrors.py
│   ├── test_datagen.py
│   ├── test_rmgarbage.py
│   ├── test_utils.py
│   ├── test_ocrevaluation_extract.py
│   └── test_remove_title_page.py
├── requirements.txt
├── .gitignore
├── job.sh
├── nematus_translate.sh
├── .editorconfig
├── nematusA8P2.sh
├── nematusA8P3.sh
├── nematus.sh
├── transformer.sh
├── notebooks
│   ├── lowercase-directory-workflow.ipynb
│   ├── vudnc-preprocess-workflow.ipynb
│   ├── post_correction_workflows.ipynb
│   ├── icdar2019POCR-to-texts.ipynb
│   ├── sac-preprocess-workflow.ipynb
│   ├── preprocess-dbnl_ocr.ipynb
│   ├── align-workflow.ipynb
│   ├── icdar2017-ngrams.ipynb
│   ├── 2017-baseline-nn.ipynb
│   ├── kb_tss_preprocess.ipynb
│   ├── ocr-evaluation-workflow.ipynb
│   ├── Compare-performance-of-pred.ipynb
│   ├── character_counts.ipynb
│   ├── word-mapping-workflow.ipynb
│   ├── improve-keras-datagen.ipynb
│   ├── kb_xml_to_text_(unaligned).ipynb
│   ├── try-lstm.ipynb
│   ├── ICDAR2017_shared_task_workflows.ipynb
│   └── fuzzy-string-matching.ipynb
├── setup.py
├── datagen.py
├── 2017_baseline.py
└── lstm_synced.py
--------------------------------------------------------------------------------
/ochre/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import cwl_path
--------------------------------------------------------------------------------
/tests/data/ocrevaluation/in/gs.txt:
--------------------------------------------------------------------------------
1 | This is an example text.
2 | 
--------------------------------------------------------------------------------
/tests/data/ocrevaluation/in/ocr.txt:
--------------------------------------------------------------------------------
1 | This is an cxample text.
2 | -------------------------------------------------------------------------------- /tests/data/ocrevaluation-extract/out/empty-global.csv: -------------------------------------------------------------------------------- 1 | ,CER,WER,WER (order independent) 2 | empty,n/a,n/a,n/a 3 | -------------------------------------------------------------------------------- /tests/data/ocrevaluation-extract/out/in-global.csv: -------------------------------------------------------------------------------- 1 | ,CER,WER,WER (order independent) 2 | in,4.17,20.00,20.00 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | click 2 | lxml 3 | beautifulsoup4 4 | keras 5 | tensorflow 6 | edlib 7 | jupyter 8 | pytest 9 | nlppln>=0.3.1 10 | -------------------------------------------------------------------------------- /tests/data/ocrevaluation-extract/out/empty-character.csv: -------------------------------------------------------------------------------- 1 | "Character",Hex code,Total,Spurious,Confused,Lost,Error rate 2 | "n/a",n/a,n/a,n/a,n/a,n/a,n/a 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | *.egg-info 3 | *.eggs 4 | .ipynb_checkpoints 5 | 6 | build 7 | dist 8 | .cache 9 | __pycache__ 10 | 11 | htmlcov 12 | .coverage 13 | coverage.xml 14 | .pytest_cache 15 | 16 | docs/_build 17 | docs/apidocs 18 | 19 | # ide 20 | .idea 21 | .eclipse 22 | .vscode 23 | 24 | # Mac 25 | .DS_Store 26 | -------------------------------------------------------------------------------- /tests/data/ocrevaluation-extract/out/in-character.csv: -------------------------------------------------------------------------------- 1 | "Character",Hex code,Total,Spurious,Confused,Lost,Error rate 2 | " ",20,4,0,0,0,0.00 3 | ".",2e,1,0,0,0,0.00 4 | "T",54,1,0,0,0,0.00 5 | "a",61,2,0,0,0,0.00 6 | "e",65,3,0,1,0,33.33 7 | "h",68,1,0,0,0,0.00 8 | "i",69,2,0,0,0,0.00 9 | "l",6c,1,0,0,0,0.00 10 | "m",6d,1,0,0,0,0.00 11 | "n",6e,1,0,0,0,0.00 12 | "p",70,1,0,0,0,0.00 13 | "s",73,2,0,0,0,0.00 14 | "t",74,2,0,0,0,0.00 15 | "x",78,2,0,0,0,0.00 16 | -------------------------------------------------------------------------------- /ochre/cwl/vudnc-select-files.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwlrunner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | baseCommand: ["python", "-m", "ochre.select_vudnc_files"] 5 | 6 | requirements: 7 | EnvVarRequirement: 8 | envDef: 9 | LC_ALL: C.UTF-8 10 | LANG: C.UTF-8 11 | 12 | inputs: 13 | in_dir: 14 | type: Directory 15 | inputBinding: 16 | position: 1 17 | 18 | stdout: cwl.output.json 19 | 20 | outputs: 21 | out_files: 22 | type: File[] 23 | -------------------------------------------------------------------------------- /ochre/cwl/count-chars.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwlrunner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | 5 | requirements: 6 | EnvVarRequirement: 7 | envDef: 8 | LC_ALL: C.UTF-8 9 | LANG: C.UTF-8 10 | 11 | baseCommand: ["python", "-m", "ochre.count_chars"] 12 | 13 | inputs: 14 | in_file: 15 | type: File 16 | inputBinding: 17 | position: 1 18 | 19 | outputs: 20 | char_counts: 21 | type: File 22 | outputBinding: 23 | glob: "*.json" 24 | 
-------------------------------------------------------------------------------- /job.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --job-name=2017_baseline 4 | #SBATCH --output=2017_baseline.txt 5 | # 6 | #SBATCH --ntasks=1 7 | #SBATCH -C TitanX 8 | #SBATCH --time=15:00:00 9 | #SBATCH --gres=gpu:1 10 | 11 | module load python/3.5.2 12 | module load python-extra/python3.5/r0.5.0 13 | module load cuda80/toolkit/8.0.61 14 | module load tensorflow/python3.x/gpu/r1.4.0-py3 15 | module load keras/python3.5/r2.0.2 16 | module load cuDNN/cuda80/5.1.5 17 | 18 | srun python /home/jvdzwaan/code/ochre/2017_baseline.py 19 | -------------------------------------------------------------------------------- /nematus_translate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --job-name=nematus_translate 4 | #SBATCH --output=nematus_translate.txt 5 | # 6 | #SBATCH --ntasks=1 7 | #SBATCH -C TitanX 8 | #SBATCH --gres=gpu:1 9 | 10 | srun mkdir -p /var/scratch/jvdzwaan/kb-ocr/A8P1/pred/ 11 | for filename in /var/scratch/jvdzwaan/kb-ocr/A8P1/test/*.ocr; do 12 | srun python ~/code/nematus/nematus/translate.py -m /var/scratch/jvdzwaan/kb-ocr/A8P1/model/model-40000 -i "$filename" -o /var/scratch/jvdzwaan/kb-ocr/A8P1/pred/"$(basename "$filename" .ocr).pred" 13 | done 14 | -------------------------------------------------------------------------------- /ochre/cwl/rmgarbage.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwl-runner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | baseCommand: ["python", "-m", "nlppln.commands.rmgarbage"] 5 | 6 | requirements: 7 | EnvVarRequirement: 8 | envDef: 9 | LC_ALL: C.UTF-8 10 | LANG: C.UTF-8 11 | 12 | inputs: 13 | in_file: 14 | type: File 15 | 16 | outputs: 17 | out_file: 18 | type: File 19 | outputBinding: 20 | glob: $(inputs.in_file.nameroot).txt 21 | metadata_out: 22 | type: File 23 | outputBinding: 24 | glob: $(inputs.in_file.nameroot).txt 25 | -------------------------------------------------------------------------------- /ochre/cwl/sac2gs-and-ocr.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwl-runner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | baseCommand: ["python", "-m", "ochre.sac2gs_and_ocr"] 5 | 6 | requirements: 7 | EnvVarRequirement: 8 | envDef: 9 | LC_ALL: C.UTF-8 10 | LANG: C.UTF-8 11 | 12 | inputs: 13 | in_dir: 14 | type: Directory 15 | inputBinding: 16 | position: 1 17 | 18 | stdout: cwl.output.json 19 | 20 | outputs: 21 | gs_de: 22 | type: File[] 23 | ocr_de: 24 | type: File[] 25 | gs_fr: 26 | type: File[] 27 | ocr_fr: 28 | type: File[] 29 | -------------------------------------------------------------------------------- /ochre/cwl/clin2018st-extract-text.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwlrunner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | 5 | requirements: 6 | EnvVarRequirement: 7 | envDef: 8 | LC_ALL: C.UTF-8 9 | LANG: C.UTF-8 10 | 11 | baseCommand: ["python", "-m", "ochre.clin2018st_extract_text"] 12 | 13 | inputs: 14 | json_file: 15 | type: File 16 | inputBinding: 17 | position: 1 18 | 19 | outputs: 20 | gs_text: 21 | type: File 22 | outputBinding: 23 | glob: "*-gs.txt" 24 | err_text: 25 | type: File 26 | outputBinding: 27 | glob: "*-errors.txt" 28 | 
-------------------------------------------------------------------------------- /ochre/cwl/create-data-division.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwlrunner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | 5 | requirements: 6 | EnvVarRequirement: 7 | envDef: 8 | LC_ALL: C.UTF-8 9 | LANG: C.UTF-8 10 | 11 | baseCommand: ["python", "-m", "ochre.create_data_division"] 12 | 13 | inputs: 14 | in_dir: 15 | type: Directory 16 | inputBinding: 17 | position: 1 18 | out_name: 19 | type: string? 20 | inputBinding: 21 | prefix: --out_name 22 | 23 | outputs: 24 | metadata_out: 25 | type: File 26 | outputBinding: 27 | glob: "*.json" 28 | -------------------------------------------------------------------------------- /ochre/cwl/remove-empty-files.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwlrunner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | 5 | baseCommand: ["python", "-m", "ochre.remove_empty_files"] 6 | 7 | requirements: 8 | EnvVarRequirement: 9 | envDef: 10 | LC_ALL: C.UTF-8 11 | LANG: C.UTF-8 12 | 13 | inputs: 14 | ocr_dir: 15 | type: Directory 16 | inputBinding: 17 | position: 2 18 | gs_dir: 19 | type: Directory 20 | inputBinding: 21 | position: 1 22 | 23 | stdout: cwl.output.json 24 | 25 | outputs: 26 | ocr: 27 | type: File[] 28 | gs: 29 | type: File[] 30 | -------------------------------------------------------------------------------- /ochre/cwl/kb-tss-concat-files.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwlrunner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | 5 | requirements: 6 | EnvVarRequirement: 7 | envDef: 8 | LC_ALL: C.UTF-8 9 | LANG: C.UTF-8 10 | InitialWorkDirRequirement: 11 | listing: $(inputs.in_files) 12 | 13 | baseCommand: ["python", "-m", "ochre.kb_tss_concat_files"] 14 | 15 | arguments: 16 | - valueFrom: $(runtime.outdir) 17 | position: 1 18 | 19 | inputs: 20 | in_files: 21 | type: File[] 22 | 23 | outputs: 24 | out_files: 25 | type: File[] 26 | outputBinding: 27 | glob: "*.txt" 28 | -------------------------------------------------------------------------------- /ochre/cwl/ocrevaluation-extract.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwlrunner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | baseCommand: ["python", "-m", "ochre.ocrevaluation_extract"] 5 | 6 | requirements: 7 | EnvVarRequirement: 8 | envDef: 9 | LC_ALL: C.UTF-8 10 | LANG: C.UTF-8 11 | 12 | inputs: 13 | in_file: 14 | type: File 15 | inputBinding: 16 | position: 1 17 | 18 | outputs: 19 | character_data: 20 | type: File 21 | outputBinding: 22 | glob: "*-character.csv" 23 | global_data: 24 | type: File 25 | outputBinding: 26 | glob: "*-global.csv" 27 | -------------------------------------------------------------------------------- /ochre/cwl/char-align.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwlrunner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | baseCommand: ["python", "-m", "ochre.char_align"] 5 | 6 | requirements: 7 | EnvVarRequirement: 8 | envDef: 9 | LC_ALL: C.UTF-8 10 | LANG: C.UTF-8 11 | 12 | inputs: 13 | ocr_text: 14 | type: File 15 | inputBinding: 16 | position: 1 17 | gs_text: 18 | type: File 19 | inputBinding: 20 | position: 2 21 | metadata: 22 | type: File 23 | inputBinding: 24 | position: 3 25 | 26 | outputs: 27 | out_file: 28 | type: File 29 | 
outputBinding: 30 | glob: "*.json" 31 | -------------------------------------------------------------------------------- /ochre/cwl/remove-title-page.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwl-runner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | baseCommand: ["python", "-m", "ochre.remove_title_page"] 5 | 6 | requirements: 7 | EnvVarRequirement: 8 | envDef: 9 | LC_ALL: C.UTF-8 10 | LANG: C.UTF-8 11 | 12 | inputs: 13 | without_tp: 14 | type: File 15 | inputBinding: 16 | position: 1 17 | with_tp: 18 | type: File 19 | inputBinding: 20 | position: 2 21 | num_lines: 22 | type: int? 23 | inputBinding: 24 | prefix: -n 25 | 26 | outputs: 27 | out_file: 28 | type: File 29 | outputBinding: 30 | glob: "*" 31 | -------------------------------------------------------------------------------- /ochre/cwl/select-test-files.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwlrunner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | baseCommand: ["python", "-m", "ochre.select_test_files"] 5 | 6 | requirements: 7 | EnvVarRequirement: 8 | envDef: 9 | LC_ALL: C.UTF-8 10 | LANG: C.UTF-8 11 | 12 | stdout: cwl.output.json 13 | 14 | inputs: 15 | in_dir: 16 | type: Directory 17 | inputBinding: 18 | position: 1 19 | datadivision: 20 | type: File 21 | inputBinding: 22 | position: 2 23 | name: 24 | type: string? 25 | inputBinding: 26 | prefix: --name 27 | 28 | outputs: 29 | out_files: 30 | type: File[] 31 | -------------------------------------------------------------------------------- /ochre/cwl/icdar2017st-extract-text.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwlrunner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | 5 | requirements: 6 | EnvVarRequirement: 7 | envDef: 8 | LC_ALL: C.UTF-8 9 | LANG: C.UTF-8 10 | 11 | baseCommand: ["python", "-m", "ochre.icdar2017st_extract_text"] 12 | 13 | inputs: 14 | in_file: 15 | type: File 16 | inputBinding: 17 | position: 1 18 | 19 | outputs: 20 | gs: 21 | type: File 22 | outputBinding: 23 | glob: "gs/*.txt" 24 | ocr: 25 | type: File 26 | outputBinding: 27 | glob: "ocr/*.txt" 28 | aligned: 29 | type: File 30 | outputBinding: 31 | glob: "*.json" 32 | -------------------------------------------------------------------------------- /ochre/cwl/match-ocr-and-gs.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwl-runner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | 5 | requirements: 6 | EnvVarRequirement: 7 | envDef: 8 | LC_ALL: C.UTF-8 9 | LANG: C.UTF-8 10 | 11 | baseCommand: ["python", "-m", "ochre.match_ocr_and_gs"] 12 | 13 | inputs: 14 | ocr_dir: 15 | type: Directory 16 | inputBinding: 17 | position: 0 18 | gs_dir: 19 | type: Directory 20 | inputBinding: 21 | position: 1 22 | 23 | outputs: 24 | ocr: 25 | type: Directory 26 | outputBinding: 27 | glob: $(runtime.outdir)/ocr 28 | gs: 29 | type: Directory 30 | outputBinding: 31 | glob: "gs" 32 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # EditorConfig is awesome: http://EditorConfig.org 2 | 3 | # top-most EditorConfig file 4 | root = true 5 | 6 | # Unix-style newlines with a newline ending every file 7 | [*] 8 | end_of_line = lf 9 | insert_final_newline = true 10 | trim_trailing_whitespace = true 11 | charset = utf-8 12 | 13 | # Matches 
multiple files with brace expansion notation 14 | # Set default charset 15 | [*.{js,py,java,r,R,html,cwl}] 16 | indent_style = space 17 | 18 | # 4 space indentation 19 | [*.{py,java,r,R}] 20 | indent_size = 4 21 | 22 | # 2 space indentation 23 | [*.{js,json,yml,html,cwl}] 24 | indent_size = 2 25 | 26 | [*.{md,Rmd}] 27 | trim_trailing_whitespace = false 28 | -------------------------------------------------------------------------------- /ochre/cwl/lstm-synced-correct-ocr.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwlrunner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | 5 | requirements: 6 | EnvVarRequirement: 7 | envDef: 8 | LC_ALL: C.UTF-8 9 | LANG: C.UTF-8 10 | 11 | baseCommand: ["python", "-m", "ochre.lstm_synced_correct_ocr"] 12 | 13 | inputs: 14 | model: 15 | type: File 16 | inputBinding: 17 | position: 1 18 | charset: 19 | type: File 20 | inputBinding: 21 | position: 2 22 | txt: 23 | type: File 24 | inputBinding: 25 | position: 3 26 | 27 | outputs: 28 | corrected: 29 | type: File 30 | outputBinding: 31 | glob: "$(inputs.txt.basename)" 32 | -------------------------------------------------------------------------------- /ochre/cwl/vudnc2ocr-and-gs.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwlrunner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | baseCommand: ["python", "-m", "ochre.vudnc2ocr_and_gs"] 5 | 6 | requirements: 7 | EnvVarRequirement: 8 | envDef: 9 | LC_ALL: C.UTF-8 10 | LANG: C.UTF-8 11 | 12 | inputs: 13 | in_file: 14 | type: File 15 | inputBinding: 16 | position: 1 17 | out_dir: 18 | type: Directory? 19 | inputBinding: 20 | prefix: --out_dir= 21 | separate: false 22 | 23 | outputs: 24 | gs: 25 | type: File 26 | outputBinding: 27 | glob: "*.gs.txt" 28 | ocr: 29 | type: File 30 | outputBinding: 31 | glob: "*.ocr.txt" 32 | -------------------------------------------------------------------------------- /tests/data/ocrevaluation-extract/in/empty.html: -------------------------------------------------------------------------------- 1 |
General results
2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 |
CER  n/a
WER  n/a
WER (order independent)  n/a
13 | Difference spotting
14 | 15 |
16 | Error rate per character and type
17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 |
Character  Hex code  Total  Spurious  Confused  Lost  Error rate
n/a  n/a  n/a  n/a  n/a  n/a  n/a
25 | -------------------------------------------------------------------------------- /ochre/cwl/merge-json.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwlrunner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | 5 | baseCommand: ["python", "-m", "ochre.merge_json"] 6 | 7 | requirements: 8 | InitialWorkDirRequirement: 9 | listing: $(inputs.in_files) 10 | EnvVarRequirement: 11 | envDef: 12 | LC_ALL: C.UTF-8 13 | LANG: C.UTF-8 14 | 15 | arguments: 16 | - valueFrom: $(runtime.outdir) 17 | position: 1 18 | 19 | inputs: 20 | in_files: 21 | type: File[] 22 | name: 23 | type: string? 24 | inputBinding: 25 | prefix: --name= 26 | separate: false 27 | 28 | outputs: 29 | merged: 30 | type: File 31 | outputBinding: 32 | glob: "*.csv" 33 | -------------------------------------------------------------------------------- /ochre/cwl/create-word-mappings.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwlrunner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | 5 | requirements: 6 | EnvVarRequirement: 7 | envDef: 8 | LC_ALL: C.UTF-8 9 | LANG: C.UTF-8 10 | 11 | baseCommand: ["python", "-m", "ochre.create_word_mappings"] 12 | 13 | inputs: 14 | saf: 15 | type: File 16 | inputBinding: 17 | position: 1 18 | alignments: 19 | type: File 20 | inputBinding: 21 | position: 2 22 | lowercase: 23 | type: boolean? 24 | inputBinding: 25 | prefix: --lowercase 26 | name: 27 | type: string? 28 | inputBinding: 29 | prefix: --name 30 | 31 | outputs: 32 | word_mapping: 33 | type: File 34 | outputBinding: 35 | glob: "*.csv" 36 | -------------------------------------------------------------------------------- /ochre/cwl/ocrevaluation-performance-wf.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwl-runner 2 | cwlVersion: v1.0 3 | class: Workflow 4 | inputs: 5 | gt: File 6 | ocr: File 7 | xmx: string? 
8 | outputs: 9 | character_data: 10 | outputSource: ocrevaluation-extract/character_data 11 | type: File 12 | global_data: 13 | outputSource: ocrevaluation-extract/global_data 14 | type: File 15 | steps: 16 | ocrevaluation: 17 | run: https://raw.githubusercontent.com/nlppln/ocrevaluation-docker/master/ocrevaluation.cwl 18 | in: 19 | gt: gt 20 | ocr: ocr 21 | xmx: xmx 22 | out: 23 | - out_file 24 | ocrevaluation-extract: 25 | run: ocrevaluation-extract.cwl 26 | in: 27 | in_file: ocrevaluation/out_file 28 | out: 29 | - character_data 30 | - global_data 31 | -------------------------------------------------------------------------------- /ochre/select_vudnc_files.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import click 3 | import os 4 | import json 5 | 6 | from nlppln.utils import cwl_file 7 | 8 | 9 | @click.command() 10 | @click.argument('dir_in', type=click.Path(exists=True)) 11 | def command(dir_in): 12 | files_out = [] 13 | 14 | newspapers = ['ad1951', 'nrc1950', 't1950', 'tr1950', 'vk1951'] 15 | 16 | for np in newspapers: 17 | path = os.path.join(dir_in, np) 18 | for f in os.listdir(path): 19 | fi = os.path.join(path, f) 20 | if fi.endswith('.folia.xml'): 21 | files_out.append(cwl_file(fi)) 22 | 23 | stdout_text = click.get_text_stream('stdout') 24 | stdout_text.write(json.dumps({'out_files': files_out})) 25 | 26 | 27 | if __name__ == '__main__': 28 | command() 29 | -------------------------------------------------------------------------------- /ochre/select_test_files.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import click 3 | import os 4 | import json 5 | 6 | from nlppln.utils import create_dirs, cwl_file 7 | from ochre.utils import get_files 8 | 9 | 10 | @click.command() 11 | @click.argument('in_dir', type=click.Path(exists=True)) 12 | @click.argument('datadivision', type=click.File(encoding='utf-8')) 13 | @click.option('--name', '-n', default='test') 14 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path()) 15 | def command(in_dir, datadivision, name, out_dir): 16 | create_dirs(out_dir) 17 | 18 | div = json.load(datadivision) 19 | files_out = [cwl_file(f) for f in get_files(in_dir, div, name)] 20 | 21 | stdout_text = click.get_text_stream('stdout') 22 | stdout_text.write(json.dumps({'out_files': files_out})) 23 | 24 | if __name__ == '__main__': 25 | command() 26 | -------------------------------------------------------------------------------- /nematusA8P2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --job-name=nematus_kb-ocr 4 | #SBATCH --output=nematus_A8P2.txt 5 | # 6 | #SBATCH --ntasks=1 7 | #SBATCH -C TitanX 8 | #SBATCH --time=00:05:00 9 | #SBATCH --gres=gpu:1 10 | 11 | srun mkdir -p /var/scratch/jvdzwaan/kb-ocr/A8P2/model/ 12 | srun python ~/code/nematus/nematus/train.py --source_dataset /var/scratch/jvdzwaan/kb-ocr/A8P2/train.ocr --target_dataset /var/scratch/jvdzwaan/kb-ocr/A8P2/train.gs --embedding_size 256 --tie_encoder_decoder_embeddings --rnn_use_dropout --batch_size 100 --valid_source_dataset /var/scratch/jvdzwaan/kb-ocr/A8P2/val.ocr --valid_target_dataset /var/scratch/jvdzwaan/kb-ocr/A8P2/val.gs --dictionaries /var/scratch/jvdzwaan/kb-ocr/A8P1/train.json /var/scratch/jvdzwaan/kb-ocr/A8P1/train.json --valid_batch_size 100 --model /var/scratch/jvdzwaan/kb-ocr/A8P2/model/model --reload latest_checkpoint --save_freq 10000 13 | 
-------------------------------------------------------------------------------- /nematusA8P3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --job-name=nematus_kb-ocr 4 | #SBATCH --output=nematus_A8P3.txt 5 | # 6 | #SBATCH --ntasks=1 7 | #SBATCH -C TitanX 8 | #SBATCH --time=00:05:00 9 | #SBATCH --gres=gpu:1 10 | 11 | srun mkdir -p /var/scratch/jvdzwaan/kb-ocr/A8P3/model/ 12 | srun python ~/code/nematus/nematus/train.py --source_dataset /var/scratch/jvdzwaan/kb-ocr/A8P3/train.ocr --target_dataset /var/scratch/jvdzwaan/kb-ocr/A8P3/train.gs --embedding_size 256 --tie_encoder_decoder_embeddings --rnn_use_dropout --batch_size 100 --valid_source_dataset /var/scratch/jvdzwaan/kb-ocr/A8P3/val.ocr --valid_target_dataset /var/scratch/jvdzwaan/kb-ocr/A8P3/val.gs --dictionaries /var/scratch/jvdzwaan/kb-ocr/A8P1/train.json /var/scratch/jvdzwaan/kb-ocr/A8P1/train.json --valid_batch_size 100 --model /var/scratch/jvdzwaan/kb-ocr/A8P3/model/model --reload latest_checkpoint --save_freq 10000 13 | -------------------------------------------------------------------------------- /ochre/cwl/ocr-transform.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwl-runner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | baseCommand: ocr-transform 5 | 6 | requirements: 7 | - class: DockerRequirement 8 | dockerPull: ubma/ocr-fileformat 9 | - class: InlineJavascriptRequirement 10 | 11 | inputs: 12 | in_fmt: 13 | type: string 14 | inputBinding: 15 | position: 1 16 | out_fmt: 17 | type: string 18 | inputBinding: 19 | position: 2 20 | in_file: 21 | type: File 22 | inputBinding: 23 | position: 3 24 | 25 | stdout: | 26 | ${ 27 | var nameroot = inputs.in_file.nameroot; 28 | var ext = 'xml'; 29 | if(inputs.out_fmt == 'text'){ 30 | ext = 'txt'; 31 | } 32 | return nameroot + '.' + ext; 33 | } 34 | 35 | outputs: 36 | out_file: 37 | type: File 38 | outputBinding: 39 | glob: $(inputs.in_file.nameroot).* 40 | -------------------------------------------------------------------------------- /ochre/dncvu_select_ocr_and_gs_texts.py: -------------------------------------------------------------------------------- 1 | import click 2 | import json 3 | import os 4 | import glob 5 | 6 | from nlppln.utils import cwl_file 7 | 8 | 9 | @click.command() 10 | @click.argument('in_dir', type=click.Path(exists=True)) 11 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path()) 12 | def select(in_dir, out_dir): 13 | gs_files = sorted(glob.glob('{}{}*.gs.txt'.format(in_dir, os.sep))) 14 | gs_files = [cwl_file(os.path.abspath(f)) for f in gs_files] 15 | 16 | ocr_files = sorted(glob.glob('{}{}*.ocr.txt'.format(in_dir, os.sep))) 17 | ocr_files = [cwl_file(os.path.abspath(f)) for f in ocr_files] 18 | 19 | stdout_text = click.get_text_stream('stdout') 20 | stdout_text.write(json.dumps({'ocr_files': ocr_files, 21 | 'gs_files': gs_files})) 22 | 23 | 24 | if __name__ == '__main__': 25 | select() 26 | -------------------------------------------------------------------------------- /ochre/cwl/onmt-tokenize-text.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwl-runner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | baseCommand: onmt-tokenize-text 5 | 6 | doc: | 7 | Use OpenNMT tokenizer offline. 8 | See http://opennmt.net/OpenNMT-tf/tokenization.html for more information. 9 | 10 | inputs: 11 | text: 12 | type: File 13 | delimiter: 14 | type: string? 
15 | inputBinding: 16 | prefix: --delimiter 17 | tokenizer: 18 | type: 19 | type: enum 20 | symbols: 21 | - CharacterTokenizer 22 | - SpaceTokenizer 23 | default: SpaceTokenizer 24 | inputBinding: 25 | prefix: --tokenizer 26 | tokenizer_config: 27 | type: File? 28 | inputBinding: 29 | prefix: --tokenizer_config 30 | out_name: 31 | type: string 32 | 33 | stdin: $(inputs.text.path) 34 | 35 | stdout: $(inputs.out_name) 36 | 37 | outputs: 38 | tokenized: stdout 39 | -------------------------------------------------------------------------------- /nematus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --job-name=nematus_kb-ocr 4 | #SBATCH --output=nematus_kb-ocr.txt 5 | # 6 | #SBATCH --ntasks=1 7 | #SBATCH -C TitanX 8 | #SBATCH --begin=20:00 9 | #SBATCH --time=12:00:00 10 | #SBATCH --gres=gpu:1 11 | 12 | srun mkdir -p /var/scratch/jvdzwaan/kb-ocr/A8P1/model/ 13 | srun python ~/code/nematus/nematus/train.py --source_dataset /var/scratch/jvdzwaan/kb-ocr/A8P1/train.ocr --target_dataset /var/scratch/jvdzwaan/kb-ocr/A8P1/train.gs --embedding_size 256 --tie_encoder_decoder_embeddings --rnn_use_dropout --batch_size 100 --valid_source_dataset /var/scratch/jvdzwaan/kb-ocr/A8P1/val.ocr --valid_target_dataset /var/scratch/jvdzwaan/kb-ocr/A8P1/val.gs --dictionaries /var/scratch/jvdzwaan/kb-ocr/A8P1/train.json /var/scratch/jvdzwaan/kb-ocr/A8P1/train.json --valid_batch_size 100 --model /var/scratch/jvdzwaan/kb-ocr/A8P1/model/model --reload latest_checkpoint --save_freq 10000 14 | -------------------------------------------------------------------------------- /transformer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --job-name=nematus_kb-ocr 4 | #SBATCH --output=nematus_kb-ocr.txt 5 | # 6 | #SBATCH --ntasks=1 7 | #SBATCH -C TitanX 8 | #SBATCH --begin=20:00 9 | #SBATCH --time=12:00:00 10 | #SBATCH --gres=gpu:1 11 | 12 | srun mkdir -p /var/scratch/jvdzwaan/kb-ocr/A8P1-transformer/model/ 13 | srun python ~/code/nematus/nematus/train.py --model_type transformer --learning_schedule transformer --source_dataset /var/scratch/jvdzwaan/kb-ocr/A8P1/train.ocr --target_dataset /var/scratch/jvdzwaan/kb-ocr/A8P1/train.gs --embedding_size 256 --tie_encoder_decoder_embeddings --rnn_use_dropout --batch_size 100 --valid_source_dataset /var/scratch/jvdzwaan/kb-ocr/A8P1/val.ocr --valid_target_dataset /var/scratch/jvdzwaan/kb-ocr/A8P1/val.gs --dictionaries /var/scratch/jvdzwaan/kb-ocr/A8P1/train.json /var/scratch/jvdzwaan/kb-ocr/A8P1/train.json --valid_batch_size 100 --model /var/scratch/jvdzwaan/kb-ocr/A8P1-transformer/model/model --reload latest_checkpoint --save_freq 10000 14 | -------------------------------------------------------------------------------- /ochre/count_chars.py: -------------------------------------------------------------------------------- 1 | import click 2 | import json 3 | import os 4 | import codecs 5 | 6 | from collections import Counter 7 | 8 | 9 | @click.command() 10 | @click.argument('in_file', type=click.Path(exists=True)) 11 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path()) 12 | def count_chars(in_file, out_dir): 13 | chars = Counter() 14 | with codecs.open(in_file, 'r', encoding='utf-8') as f: 15 | text = f.read().strip() 16 | 17 | for char in text: 18 | chars[char] += 1 19 | 20 | fname = os.path.basename(in_file.replace('.txt', '.json')) 21 | fname = os.path.join(out_dir, fname) 22 | with codecs.open(fname, 'w', 
encoding='utf-8') as f: 23 | json.dump(chars, f, indent=2) 24 | 25 | #for c, freq in chars.most_common(): 26 | # print('{}\t{}'.format(repr(c), freq)) 27 | # print(repr(c.encode('utf-8'))) 28 | # print(len(c.encode('utf-8'))) 29 | 30 | 31 | if __name__ == '__main__': 32 | count_chars() 33 | -------------------------------------------------------------------------------- /ochre/cwl/onmt-main.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwl-runner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | baseCommand: onmt-main 5 | 6 | inputs: 7 | run_type: 8 | type: 9 | type: enum 10 | symbols: 11 | - train_and_eval 12 | - train 13 | - eval 14 | - infer 15 | - export 16 | - score 17 | inputBinding: 18 | position: 0 19 | model_type: 20 | type: 21 | - "null" 22 | - type: enum 23 | symbols: 24 | - ListenAttendSpell 25 | - NMTBig 26 | - NMTMedium 27 | - NMTSmall 28 | - SeqTagger 29 | - Transformer 30 | - TransformerAAN 31 | - TransformerBig 32 | inputBinding: 33 | prefix: --model_type 34 | model: 35 | type: File? 36 | inputBinding: 37 | prefix: --model 38 | config: 39 | type: File[] 40 | inputBinding: 41 | prefix: --config 42 | 43 | outputs: 44 | out_dir: 45 | type: Directory 46 | -------------------------------------------------------------------------------- /ochre/cwl/onmt-build-vocab.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwl-runner 2 | cwlVersion: v1.0 3 | class: CommandLineTool 4 | baseCommand: onmt-build-vocab 5 | 6 | inputs: 7 | in_files: 8 | type: File[] 9 | inputBinding: 10 | position: 1 11 | save_vocab: 12 | type: string 13 | inputBinding: 14 | prefix: --save_vocab 15 | min_frequency: 16 | type: int? 17 | default: 1 18 | inputBinding: 19 | prefix: --min_frequency 20 | size: 21 | type: int? 22 | default: 0 23 | inputBinding: 24 | prefix: --size 25 | without_sequence_tokens: 26 | type: bool? 27 | inputBinding: 28 | prefix: --without_sequence_tokens 29 | tokenizer: 30 | type: 31 | type: enum 32 | symbols: 33 | - CharacterTokenizer 34 | - SpaceTokenizer 35 | default: SpaceTokenizer 36 | tokenizer_config: 37 | type: File? 
38 | inputBinding: 39 | prefix: --tokenizer_config 40 | 41 | outputs: 42 | out_files: 43 | type: File 44 | outputBinding: 45 | glob: $(inputs.save_vocab.basename) 46 | -------------------------------------------------------------------------------- /ochre/kb_tss_concat_files.py: -------------------------------------------------------------------------------- 1 | import click 2 | import codecs 3 | import os 4 | from collections import Counter 5 | 6 | from nlppln.utils import get_files, out_file_name 7 | 8 | 9 | @click.command() 10 | @click.argument('in_dir', type=click.Path(exists=True)) 11 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path()) 12 | def concat_files(in_dir, out_dir): 13 | in_files = get_files(in_dir) 14 | 15 | counts = Counter() 16 | 17 | for in_file in in_files: 18 | parts = os.path.basename(in_file).split(u'_') 19 | prefix = u'_'.join(parts[:2]) 20 | counts[prefix] += 1 21 | 22 | out_file = out_file_name(out_dir, prefix, ext='txt') 23 | 24 | with codecs.open(in_file, 'r', encoding='utf-8') as fi: 25 | text = fi.read() 26 | text = text.replace(u'\n', u'') 27 | text = text.strip() 28 | 29 | with codecs.open(out_file, 'a', encoding='utf-8') as fo: 30 | fo.write(text) 31 | fo.write(u'\n') 32 | 33 | 34 | if __name__ == '__main__': 35 | concat_files() 36 | -------------------------------------------------------------------------------- /tests/test_ocrerrors.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pytest 3 | 4 | from ochre.ocrerrors import hyphenation_error, accent_error, real_word_error 5 | 6 | 7 | def test_hyphenation_error_true(): 8 | row = {'ocr': u'gesta- tionneerd', 9 | 'gs': u'gestationneerd' 10 | } 11 | assert hyphenation_error(row) is True 12 | 13 | 14 | def test_hyphenation_error_false(): 15 | row = {'ocr': u'gestationneerd', 16 | 'gs': u'gestationneerd' 17 | } 18 | assert hyphenation_error(row) is False 19 | 20 | 21 | def test_accent_error_true(): 22 | row = {'ocr': u'patienten', 23 | 'gs': u'patiënten' 24 | } 25 | assert accent_error(row) is True 26 | 27 | 28 | def test_accent_error_true_ocr(): 29 | row = {'ocr': u'patiënten', 30 | 'gs': u'patienten' 31 | } 32 | assert accent_error(row) is True 33 | 34 | 35 | def test_real_word_error_vs_accent_error(): 36 | row = {'ocr': u'zeker', 37 | 'gs': u'zéker' 38 | } 39 | terms = ['zeker'] 40 | assert real_word_error(row, terms) is False 41 | assert accent_error(row) is True 42 | -------------------------------------------------------------------------------- /ochre/remove_empty_files.py: -------------------------------------------------------------------------------- 1 | import click 2 | import json 3 | import os 4 | import codecs 5 | 6 | from nlppln.utils import cwl_file, get_files 7 | 8 | 9 | @click.command() 10 | @click.argument('gs_dir', type=click.Path(exists=True)) 11 | @click.argument('ocr_dir', type=click.Path(exists=True)) 12 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path()) 13 | def remove_empty_files(gs_dir, ocr_dir, out_dir): 14 | ocr_out = [] 15 | gs_out = [] 16 | 17 | ocr_files = get_files(ocr_dir) 18 | gs_files = get_files(gs_dir) 19 | 20 | for ocr, gs in zip(ocr_files, gs_files): 21 | with codecs.open(ocr, 'r', encoding='utf-8') as f: 22 | ocr_text = f.read() 23 | 24 | with codecs.open(gs, 'r', encoding='utf-8') as f: 25 | gs_text = f.read() 26 | 27 | if len(ocr_text.strip()) > 0 and len(gs_text.strip()) > 0: 28 | ocr_out.append(cwl_file(ocr)) 29 | gs_out.append(cwl_file(gs)) 30 | 31 
| stdout_text = click.get_text_stream('stdout') 32 | stdout_text.write(json.dumps({'ocr': ocr_out, 'gs': gs_out})) 33 | 34 | 35 | if __name__ == '__main__': 36 | remove_empty_files() 37 | -------------------------------------------------------------------------------- /ochre/merge_json.py: -------------------------------------------------------------------------------- 1 | import click 2 | import json 3 | import os 4 | import glob 5 | import codecs 6 | import uuid 7 | import pandas as pd 8 | 9 | 10 | @click.command() 11 | @click.argument('in_dir', type=click.Path(exists=True)) 12 | @click.option('--name', '-n', default='{}.csv'.format(uuid.uuid4())) 13 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path()) 14 | def merge_json(in_dir, name, out_dir): 15 | in_files = glob.glob('{}{}*.json'.format(in_dir, os.sep)) 16 | idx = [os.path.basename(f) for f in in_files] 17 | 18 | dfs = [] 19 | for in_file in in_files: 20 | with codecs.open(in_file, 'r', encoding='utf-8') as f: 21 | data = json.load(f) 22 | 23 | # Make sure it works if the json file contains a list of dictionaries 24 | # instead of just a single dictionary. 25 | if not isinstance(data, list): 26 | data = [data] 27 | for item in data: 28 | dfs.append(item) 29 | 30 | if len(idx) != len(dfs): 31 | result = pd.DataFrame(dfs) 32 | else: 33 | result = pd.DataFrame(dfs, index=idx) 34 | result = result.fillna(0) 35 | 36 | out_file = os.path.join(out_dir, name) 37 | result.to_csv(out_file, encoding='utf-8') 38 | 39 | 40 | if __name__ == '__main__': 41 | merge_json() 42 | -------------------------------------------------------------------------------- /tests/test_datagen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ochre.datagen import DataGenerator 4 | 5 | 6 | def dgen(): 7 | ocr_seqs = ['abc', 'ab', 'ca8'] 8 | gs_seqs = ['abc', 'bb', 'ca'] 9 | p_char = 'P' 10 | oov_char = '@' 11 | n = 3 12 | ci = {'a': 0, 'b': 1, 'c': 2, p_char: 3, oov_char: 4} 13 | dg = DataGenerator(xData=ocr_seqs, yData=gs_seqs, char_to_int=ci, 14 | seq_length=n, padding_char=p_char, oov_char=oov_char, 15 | batch_size=1, shuffle=False) 16 | 17 | return dg 18 | 19 | 20 | def test_dg(): 21 | dg = dgen() 22 | 23 | assert dg.n_vocab == len(dg.char_to_int) 24 | assert len(dg) == 3 25 | 26 | x, y = dg[0] 27 | 28 | print(x) 29 | print(y) 30 | 31 | assert np.array_equal(x[0], np.array([0, 1, 2])) 32 | assert np.array_equal(y[0], np.array([[1, 0, 0, 0, 0], 33 | [0, 1, 0, 0, 0], 34 | [0, 0, 1, 0, 0]])) 35 | 36 | 37 | def test_convert_sample(): 38 | dg = dgen() 39 | 40 | cases = { 41 | 'aaa': np.array([0, 0, 0]), 42 | 'a': np.array([0, 3, 3]), 43 | 'b8': np.array([1, 4, 3]) 44 | } 45 | 46 | for inp, outp in cases.items(): 47 | assert np.array_equal(dg._convert_sample(inp), outp) 48 | -------------------------------------------------------------------------------- /notebooks/lowercase-directory-workflow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import nlppln\n", 20 | "\n", 21 | "with nlppln.WorkflowGenerator() as wf:\n", 22 | " wf.load(steps_dir='../cwl/')\n", 23 | " print(wf.list_steps())\n", 24 | "\n", 25 | " in_dir = 
wf.add_input(in_dir='Directory')\n", 26 | " dir_name = wf.add_input(dir_name='string', default='gs_lowercase')\n", 27 | "\n", 28 | " txt_files = wf.ls(in_dir=in_dir)\n", 29 | " lowercase = wf.lowercase(in_file=txt_files, scatter='in_file', scatter_method='dotproduct')\n", 30 | " out_dir = wf.save_files_to_dir(dir_name=dir_name, in_files=lowercase)\n", 31 | "\n", 32 | " wf.add_outputs(out_dir=out_dir)\n", 33 | "\n", 34 | " wf.save('../cwl/lowercase-directory.cwl')" 35 | ] 36 | } 37 | ], 38 | "metadata": { 39 | "language_info": { 40 | "name": "python", 41 | "pygments_lexer": "ipython3" 42 | } 43 | }, 44 | "nbformat": 4, 45 | "nbformat_minor": 2 46 | } 47 | -------------------------------------------------------------------------------- /ochre/char_align.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import click 3 | import codecs 4 | import os 5 | import json 6 | 7 | from nlppln.utils import create_dirs, out_file_name 8 | 9 | from .edlibutils import align_characters 10 | 11 | 12 | @click.command() 13 | @click.argument('ocr_text', type=click.File(mode='r', encoding='utf-8')) 14 | @click.argument('gs_text', type=click.File(mode='r', encoding='utf-8')) 15 | @click.argument('metadata', type=click.File(mode='r', encoding='utf-8')) 16 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path()) 17 | def command(ocr_text, gs_text, metadata, out_dir): 18 | create_dirs(out_dir) 19 | 20 | ocr = ocr_text.read() 21 | gs = gs_text.read() 22 | md = json.load(metadata) 23 | 24 | check = True 25 | # Too many strange characters, so disable sanity check 26 | if len(set(ocr+gs)) > 127: 27 | check = False 28 | 29 | ocr_a, gs_a = align_characters(ocr, gs, md['cigar'], sanity_check=check) 30 | 31 | out_file = out_file_name(out_dir, md['doc_id'], 'json') 32 | with codecs.open(out_file, 'wb', encoding='utf-8') as f: 33 | try: 34 | json.dump({'ocr': ocr_a, 'gs': gs_a}, f, encoding='utf-8') 35 | except TypeError: 36 | json.dump({'ocr': ocr_a, 'gs': gs_a}, f) 37 | 38 | 39 | if __name__ == '__main__': 40 | command() 41 | -------------------------------------------------------------------------------- /ochre/cwl/icdar2017st-extract-data.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwl-runner 2 | cwlVersion: v1.0 3 | class: Workflow 4 | requirements: 5 | - class: ScatterFeatureRequirement 6 | inputs: 7 | in_dir: Directory 8 | ocr_dir_name: string 9 | gs_dir_name: string 10 | aligned_dir_name: string 11 | outputs: 12 | gs_dir: 13 | type: Directory 14 | outputSource: save-files-to-dir-6/out 15 | ocr_dir: 16 | type: Directory 17 | outputSource: save-files-to-dir-7/out 18 | aligned_dir: 19 | type: Directory 20 | outputSource: save-files-to-dir-8/out 21 | steps: 22 | ls-3: 23 | run: ls.cwl 24 | in: 25 | in_dir: in_dir 26 | out: 27 | - out_files 28 | icdar2017st-extract-text: 29 | run: icdar2017st-extract-text.cwl 30 | in: 31 | in_file: ls-3/out_files 32 | out: 33 | - aligned 34 | - gs 35 | - ocr 36 | scatter: 37 | - in_file 38 | scatterMethod: dotproduct 39 | save-files-to-dir-6: 40 | run: save-files-to-dir.cwl 41 | in: 42 | dir_name: gs_dir_name 43 | in_files: icdar2017st-extract-text/gs 44 | out: 45 | - out 46 | save-files-to-dir-7: 47 | run: save-files-to-dir.cwl 48 | in: 49 | dir_name: ocr_dir_name 50 | in_files: icdar2017st-extract-text/ocr 51 | out: 52 | - out 53 | save-files-to-dir-8: 54 | run: save-files-to-dir.cwl 55 | in: 56 | dir_name: aligned_dir_name 57 | in_files: 
icdar2017st-extract-text/aligned 58 | out: 59 | - out 60 | -------------------------------------------------------------------------------- /ochre/cwl/kb-tss-preprocess-single-dir.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwl-runner 2 | cwlVersion: v1.0 3 | class: Workflow 4 | requirements: 5 | - class: ScatterFeatureRequirement 6 | inputs: 7 | in_dir: Directory 8 | recursive: 9 | default: true 10 | type: boolean 11 | endswith: 12 | default: alto.xml 13 | type: string 14 | element: 15 | default: 16 | - SP 17 | type: string[] 18 | in_fmt: 19 | default: alto 20 | type: string 21 | out_fmt: 22 | default: text 23 | type: string 24 | outputs: 25 | text_files: 26 | type: 27 | type: array 28 | items: File 29 | outputSource: kb-tss-concat-files/out_files 30 | steps: 31 | ls: 32 | run: ls.cwl 33 | in: 34 | endswith: endswith 35 | in_dir: in_dir 36 | recursive: recursive 37 | out: 38 | - out_files 39 | remove-xml-elements-1: 40 | run: remove-xml-elements.cwl 41 | in: 42 | xml_file: ls/out_files 43 | element: element 44 | out: 45 | - out_file 46 | scatter: 47 | - xml_file 48 | scatterMethod: dotproduct 49 | ocr-transform-1: 50 | run: ocr-transform.cwl 51 | in: 52 | out_fmt: out_fmt 53 | in_file: remove-xml-elements-1/out_file 54 | in_fmt: in_fmt 55 | out: 56 | - out_file 57 | scatter: 58 | - in_file 59 | scatterMethod: dotproduct 60 | kb-tss-concat-files: 61 | run: kb-tss-concat-files.cwl 62 | in: 63 | in_files: ocr-transform-1/out_file 64 | out: 65 | - out_files 66 | -------------------------------------------------------------------------------- /ochre/cwl/align-texts-wf.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwl-runner 2 | cwlVersion: v1.0 3 | class: Workflow 4 | requirements: 5 | - class: ScatterFeatureRequirement 6 | inputs: 7 | gs: File[] 8 | ocr: File[] 9 | align_m: 10 | default: merged_metadata.csv 11 | type: string 12 | align_c: 13 | default: merged_changes.csv 14 | type: string 15 | outputs: 16 | alignments: 17 | outputSource: char-align-1/out_file 18 | type: 19 | type: array 20 | items: File 21 | metadata: 22 | outputSource: merge-json-2/merged 23 | type: File 24 | changes: 25 | outputSource: merge-json-3/merged 26 | type: File 27 | steps: 28 | align-1: 29 | run: https://raw.githubusercontent.com/nlppln/edlib-align/master/align.cwl 30 | in: 31 | file1: ocr 32 | file2: gs 33 | out: 34 | - changes 35 | - metadata 36 | scatter: 37 | - file1 38 | - file2 39 | scatterMethod: dotproduct 40 | merge-json-2: 41 | run: merge-json.cwl 42 | in: 43 | in_files: align-1/metadata 44 | name: align_m 45 | out: 46 | - merged 47 | merge-json-3: 48 | run: merge-json.cwl 49 | in: 50 | in_files: align-1/changes 51 | name: align_c 52 | out: 53 | - merged 54 | char-align-1: 55 | run: char-align.cwl 56 | in: 57 | gs_text: gs 58 | metadata: align-1/metadata 59 | ocr_text: ocr 60 | out: 61 | - out_file 62 | scatter: 63 | - gs_text 64 | - ocr_text 65 | - metadata 66 | scatterMethod: dotproduct 67 | -------------------------------------------------------------------------------- /ochre/create_data_division.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import click 3 | import codecs 4 | import os 5 | import json 6 | import numpy as np 7 | 8 | from nlppln.utils import create_dirs, get_files 9 | 10 | 11 | @click.command() 12 | @click.argument('in_dir', type=click.Path(exists=True)) 13 | @click.option('--out_dir', '-o', 
default=os.getcwd(), type=click.Path()) 14 | @click.option('--out_name', default='datadivision.json') 15 | def command(in_dir, out_dir, out_name): 16 | """Create a division of the data in train, test and validation sets. 17 | 18 | The result is stored to a JSON file, so it can be reused. 19 | """ 20 | # TODO: make seed and percentages options 21 | SEED = 4 22 | TEST_PERCENTAGE = 10 23 | VAL_PERCENTAGE = 10 24 | 25 | create_dirs(out_dir) 26 | 27 | in_files = get_files(in_dir) 28 | 29 | np.random.seed(SEED) 30 | np.random.shuffle(in_files) 31 | 32 | n_test = int(len(in_files)/100.0 * TEST_PERCENTAGE) 33 | n_val = int(len(in_files)/100.0 * VAL_PERCENTAGE) 34 | 35 | validation_texts = in_files[0:n_val] 36 | test_texts = in_files[n_val:n_val+n_test] 37 | train_texts = in_files[n_val+n_test:] 38 | 39 | division = { 40 | 'train': [os.path.basename(t) for t in train_texts], 41 | 'val': [os.path.basename(t) for t in validation_texts], 42 | 'test': [os.path.basename(t) for t in test_texts] 43 | } 44 | 45 | out_file = os.path.join(out_dir, out_name) 46 | with codecs.open(out_file, 'wb', encoding='utf-8') as f: 47 | json.dump(division, f, indent=4) 48 | 49 | 50 | if __name__ == '__main__': 51 | command() 52 | -------------------------------------------------------------------------------- /ochre/match_ocr_and_gs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Select files that occur in both input directories and save them. 3 | 4 | Given two input directories, create ocr and gs output directories containing 5 | the files that occur in both. 6 | 7 | Files are actually copied, to prevent problems with CWL outputs. 8 | """ 9 | import os 10 | import click 11 | import shutil 12 | 13 | from nlppln.utils import get_files, create_dirs, out_file_name 14 | 15 | 16 | def copy_file(fi, name, out_dir, dest): 17 | fo = out_file_name(os.path.join(out_dir, dest), name) 18 | create_dirs(fo, is_file=True) 19 | shutil.copy2(fi, fo) 20 | 21 | 22 | @click.command() 23 | @click.argument('ocr_dir', type=click.Path(exists=True)) 24 | @click.argument('gs_dir', type=click.Path(exists=True)) 25 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path()) 26 | def match_ocr_and_gs(ocr_dir, gs_dir, out_dir): 27 | create_dirs(out_dir) 28 | 29 | ocr_files = {os.path.basename(f): f for f in get_files(ocr_dir)} 30 | gs_files = {os.path.basename(f): f for f in get_files(gs_dir)} 31 | 32 | ocr = set(ocr_files.keys()) 33 | gs = set(gs_files.keys()) 34 | 35 | if len(ocr) == 0: 36 | raise ValueError('No ocr files in directory "{}".'.format(ocr_dir)) 37 | if len(gs) == 0: 38 | raise ValueError('No gs files in directory "{}".'.format(gs_dir)) 39 | 40 | keep = ocr.intersection(gs) 41 | 42 | if len(keep) == 0: 43 | raise ValueError('No matching ocr and gs files.') 44 | 45 | for name in keep: 46 | copy_file(ocr_files[name], name, out_dir, 'ocr') 47 | copy_file(gs_files[name], name, out_dir, 'gs') 48 | 49 | 50 | if __name__ == '__main__': 51 | match_ocr_and_gs() 52 | -------------------------------------------------------------------------------- /tests/test_rmgarbage.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from ochre.rmgarbage import rmgarbage_long, rmgarbage_alphanumeric, \ 3 | rmgarbage_row, rmgarbage_vowels, \ 4 | rmgarbage_punctuation, rmgarbage_case 5 | 6 | 7 | def test_rmgarbage_long(): 8 | assert rmgarbage_long('short') is False 9 | long_string = 'Ii'*21 10 | assert 
rmgarbage_long(long_string) is True 11 | 12 | 13 | def test_rmgarbage_alphanumeric(): 14 | assert rmgarbage_alphanumeric('.M~y~l~i~c~.I~') is True 15 | assert rmgarbage_alphanumeric('______________J.~:ys~,.j}ss.') is True 16 | assert rmgarbage_alphanumeric('14.9tv="~;ia.~:..') is True 17 | assert rmgarbage_alphanumeric('.') is False 18 | 19 | 20 | def test_rmgarbage_row(): 21 | assert rmgarbage_row('111111111111111111111111') is True 22 | assert rmgarbage_row('Pnlhrrrr') is True 23 | assert rmgarbage_row('11111k1U1M.il.uu4ailuidt]i') is True 24 | 25 | 26 | def test_rmgarbage_row_non_ascii(): 27 | assert rmgarbage_row(u'ÐÐÐÐææææ') is True 28 | 29 | 30 | def test_rmgarbage_vowels(): 31 | assert rmgarbage_vowels('CslwWkrm') is True 32 | assert rmgarbage_vowels('Tptpmn') is True 33 | assert rmgarbage_vowels('Thlrlnd') is True 34 | 35 | 36 | def test_rmgarbage_punctuation(): 37 | assert rmgarbage_punctuation('ab,cde,fg') is False 38 | assert rmgarbage_punctuation('btkvdy@us1s8') is True 39 | assert rmgarbage_punctuation('w.a.e.~tcet~oe~') is True 40 | assert rmgarbage_punctuation('iA,1llfllwl~flII~N') is True 41 | 42 | 43 | def test_rmgarbage_case(): 44 | assert rmgarbage_case('bAa') is True 45 | assert rmgarbage_case('aepauWetectronic') is True 46 | assert rmgarbage_case('sUatigraphic') is True 47 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from ochre.utils import read_text_to_predict, get_char_to_int, \ 2 | to_space_tokenized 3 | 4 | 5 | def test_read_text_to_predict_no_embedding(): 6 | text = u'aaaaaa' 7 | seq_length = 3 8 | lowercase = True 9 | char_to_int = get_char_to_int(u'a\n') 10 | n_vocab = len(char_to_int) 11 | padding_char = u'\n' 12 | predict_chars = 0 13 | step = 1 14 | char_embedding = False 15 | 16 | result = read_text_to_predict(text, seq_length, lowercase, n_vocab, 17 | char_to_int, padding_char, 18 | predict_chars=predict_chars, step=step, 19 | char_embedding=char_embedding) 20 | 21 | # The result contains one hot encoded sequences 22 | assert result.dtype == bool 23 | assert result.shape == (4, seq_length, n_vocab) 24 | 25 | 26 | def test_read_text_to_predict_embedding(): 27 | text = u'aaaaaa' 28 | seq_length = 3 29 | lowercase = True 30 | char_to_int = get_char_to_int(u'a\n') 31 | n_vocab = len(char_to_int) 32 | padding_char = u'\n' 33 | predict_chars = 0 34 | step = 1 35 | char_embedding = True 36 | 37 | result = read_text_to_predict(text, seq_length, lowercase, n_vocab, 38 | char_to_int, padding_char, 39 | predict_chars=predict_chars, step=step, 40 | char_embedding=char_embedding) 41 | 42 | # The result contains lists of ints 43 | assert result.dtype == int 44 | assert result.shape == (4, seq_length) 45 | 46 | 47 | def test_to_space_tokenized(): 48 | result = to_space_tokenized('Dit is een test.') 49 | assert 'D i t i s e e n t e s t .' 
== result 50 | -------------------------------------------------------------------------------- /ochre/sac2gs_and_ocr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import click 3 | import codecs 4 | import json 5 | import os 6 | from nlppln.utils import create_dirs, get_files, cwl_file 7 | 8 | 9 | @click.command() 10 | @click.argument('in_dir', type=click.Path(exists=True)) 11 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path()) 12 | def sac2gs_and_ocr(in_dir, out_dir): 13 | result = {} 14 | result['gs_de'] = [] 15 | result['ocr_de'] = [] 16 | result['gs_fr'] = [] 17 | result['ocr_fr'] = [] 18 | 19 | files = {} 20 | 21 | for i in range(1864, 1900): 22 | try: 23 | in_files = get_files(os.path.join(in_dir, str(i))) 24 | for fi in in_files: 25 | language = 'de' 26 | typ = 'gs' 27 | bn = os.path.basename(fi) 28 | 29 | if bn.endswith('ocr'): 30 | typ = 'ocr' 31 | if 'fr' in bn: 32 | language = 'fr' 33 | with codecs.open(fi, encoding='utf-8') as f: 34 | text = f.read() 35 | fname = '{}-{}-{}.txt'.format(i, language, typ) 36 | out_file = os.path.join(out_dir, fname) 37 | create_dirs(out_file) 38 | with codecs.open(out_file, 'a', encoding='utf-8') as fo: 39 | fo.write(text) 40 | if out_file not in files: 41 | label = '{}_{}'.format(typ, language) 42 | result[label].append(cwl_file(out_file)) 43 | files[out_file] = None 44 | except OSError: 45 | pass 46 | 47 | stdout_text = click.get_text_stream('stdout') 48 | stdout_text.write(json.dumps(result)) 49 | 50 | 51 | if __name__ == '__main__': 52 | sac2gs_and_ocr() 53 | -------------------------------------------------------------------------------- /ochre/icdar2017st_extract_text.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import click 3 | import codecs 4 | import os 5 | import json 6 | 7 | from nlppln.utils import create_dirs, out_file_name 8 | 9 | 10 | def to_character_list(string): 11 | l = [] 12 | for c in string: 13 | if c in ('@', '#'): 14 | l.append('') 15 | else: 16 | l.append(c) 17 | return l 18 | 19 | 20 | @click.command() 21 | @click.argument('in_file', type=click.File(encoding='utf-8')) 22 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path()) 23 | def command(in_file, out_dir): 24 | create_dirs(out_dir) 25 | 26 | lines = in_file.readlines() 27 | # OCR_toInput: lines[0][:14] 28 | # OCR_aligned: lines[1][:14] 29 | # GS_aligned: lines[2][:14] 30 | ocr = to_character_list(lines[1][14:].strip()) 31 | gs = to_character_list(lines[2][14:].strip()) 32 | 33 | # Write texts 34 | out_file = out_file_name(os.path.join(out_dir, 'ocr'), os.path.basename(in_file.name)) 35 | print(out_file) 36 | create_dirs(out_file) 37 | with codecs.open(out_file, 'wb', encoding='utf-8') as f: 38 | f.write(u''.join(ocr)) 39 | 40 | out_file = out_file_name(os.path.join(out_dir, 'gs'), os.path.basename(in_file.name)) 41 | print(out_file) 42 | create_dirs(out_file) 43 | with codecs.open(out_file, 'wb', encoding='utf-8') as f: 44 | f.write(u''.join(gs)) 45 | 46 | out_file = out_file_name(out_dir, os.path.basename(in_file.name), 'json') 47 | with codecs.open(out_file, 'wb', encoding='utf-8') as f: 48 | json.dump({'ocr': ocr, 'gs': gs}, f, encoding='utf-8', indent=4) 49 | 50 | # out_file = out_file_name(out_dir, fi, 'json') 51 | # with codecs.open(out_file, 'wb', encoding='utf-8') as f: 52 | # pass 53 | 54 | 55 | if __name__ == '__main__': 56 | command() 57 | 
-------------------------------------------------------------------------------- /ochre/remove_title_page.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import click 3 | import os 4 | import edlib 5 | import warnings 6 | import codecs 7 | import shutil 8 | 9 | from nlppln.utils import out_file_name 10 | 11 | 12 | @click.command() 13 | @click.argument('txt_file_without_tp', type=click.File(encoding='utf-8')) 14 | @click.argument('txt_file_with_tp', type=click.File(encoding='utf-8')) 15 | @click.option('--num_lines', '-n', default=100) 16 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path()) 17 | def remove_title_page(txt_file_without_tp, txt_file_with_tp, num_lines, 18 | out_dir): 19 | result = None 20 | lines_without = txt_file_without_tp.readlines() 21 | lines_with = txt_file_with_tp.readlines() 22 | 23 | without_txt = ''.join(lines_without[:num_lines]).lower() 24 | with_txt = ''.join(lines_with[:num_lines]).lower() 25 | res = edlib.align(without_txt, with_txt) 26 | prev_ld = res['editDistance'] 27 | 28 | for i in range(num_lines): 29 | without_txt = ''.join(lines_without[:num_lines]).lower() 30 | with_txt = ''.join(lines_with[i:num_lines]).lower() 31 | 32 | res = edlib.align(without_txt, with_txt) 33 | ld = res['editDistance'] 34 | 35 | if ld > prev_ld: 36 | result = ''.join(lines_with[i-1:]) 37 | break 38 | elif ld < prev_ld: 39 | prev_ld = ld 40 | 41 | if result is None: 42 | warnings.warn('No title page found') 43 | out_file = out_file_name(out_dir, txt_file_with_tp.name) 44 | shutil.copy2(txt_file_with_tp.name, out_file) 45 | else: 46 | out_file = out_file_name(out_dir, txt_file_with_tp.name) 47 | with codecs.open(out_file, 'w', encoding='utf8') as f: 48 | f.write(result) 49 | 50 | 51 | if __name__ == '__main__': 52 | remove_title_page() 53 | -------------------------------------------------------------------------------- /ochre/cwl/post-correct-dir.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwl-runner 2 | cwlVersion: v1.0 3 | class: Workflow 4 | requirements: 5 | - class: ScatterFeatureRequirement 6 | inputs: 7 | in_dir: Directory 8 | charset: File 9 | model: File 10 | outputs: 11 | corrected: 12 | type: 13 | items: File 14 | type: array 15 | outputSource: lstm-synced-correct-ocr/corrected 16 | steps: 17 | ls-4: 18 | run: 19 | cwlVersion: v1.0 20 | class: CommandLineTool 21 | baseCommand: [python, -m, nlppln.commands.ls] 22 | 23 | inputs: 24 | - type: Directory 25 | inputBinding: 26 | position: 2 27 | id: _:ls-4#in_dir 28 | - type: 29 | - 'null' 30 | - boolean 31 | inputBinding: 32 | prefix: --recursive 33 | 34 | id: _:ls-4#recursive 35 | stdout: cwl.output.json 36 | 37 | outputs: 38 | - type: 39 | type: array 40 | items: File 41 | id: _:ls-4#out_files 42 | id: _:ls-4 43 | in: 44 | in_dir: in_dir 45 | out: 46 | - out_files 47 | lstm-synced-correct-ocr: 48 | run: 49 | cwlVersion: v1.0 50 | class: CommandLineTool 51 | baseCommand: [python, -m, ochre.lstm_synced_correct_ocr] 52 | 53 | inputs: 54 | - type: File 55 | inputBinding: 56 | position: 2 57 | id: _:lstm-synced-correct-ocr#charset 58 | - type: File 59 | inputBinding: 60 | position: 1 61 | id: _:lstm-synced-correct-ocr#model 62 | - type: File 63 | inputBinding: 64 | position: 3 65 | 66 | id: _:lstm-synced-correct-ocr#txt 67 | outputs: 68 | - type: File 69 | outputBinding: 70 | glob: $(inputs.txt.basename) 71 | id: _:lstm-synced-correct-ocr#corrected 72 | id: 
_:lstm-synced-correct-ocr 73 | in: 74 | txt: ls-4/out_files 75 | model: model 76 | charset: charset 77 | out: 78 | - corrected 79 | scatter: 80 | - txt 81 | scatterMethod: dotproduct 82 | -------------------------------------------------------------------------------- /ochre/clin2018st_extract_text.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import click 3 | import codecs 4 | import json 5 | import os 6 | 7 | from nlppln.utils import create_dirs, remove_ext 8 | 9 | 10 | @click.command() 11 | @click.argument('json_file', type=click.File(encoding='utf-8')) 12 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path()) 13 | def clin2018st_extract_text(json_file, out_dir): 14 | create_dirs(out_dir) 15 | 16 | corrections = {} 17 | gs_text = [] 18 | text_with_errors = [] 19 | 20 | text = json.load(json_file) 21 | for w in text['corrections']: 22 | span = w['span'] 23 | # TODO: fix 'after' 24 | if 'after' in w.keys(): 25 | print('Found "after" in {}.'.format(os.path.basename(json_file.name))) 26 | for i, w_id in enumerate(span): 27 | corrections[w_id] = {} 28 | if i == 0: 29 | corrections[w_id]['text'] = w['text'] 30 | else: 31 | corrections[w_id]['text'] = u'' 32 | corrections[w_id]['last'] = False 33 | if i == (len(span) - 1): 34 | corrections[w_id]['last'] = True 35 | 36 | for w in text['words']: 37 | w_id = w['id'] 38 | gs_text.append(w['text']) 39 | if w_id in corrections.keys(): 40 | text_with_errors.append(corrections[w_id]['text']) 41 | else: 42 | text_with_errors.append(w['text']) 43 | if w['space']: 44 | gs_text.append(u' ') 45 | text_with_errors.append(u' ') 46 | 47 | gs_file = remove_ext(json_file.name) 48 | gs_file = os.path.join(out_dir, '{}-gs.txt'.format(gs_file)) 49 | with codecs.open(gs_file, 'wb', encoding='utf-8') as f: 50 | f.write(u''.join(gs_text)) 51 | 52 | err_file = remove_ext(json_file.name) 53 | err_file = os.path.join(out_dir, '{}-errors.txt'.format(err_file)) 54 | with codecs.open(err_file, 'wb', encoding='utf-8') as f: 55 | f.write(u''.join(text_with_errors)) 56 | 57 | 58 | if __name__ == '__main__': 59 | clin2018st_extract_text() 60 | -------------------------------------------------------------------------------- /ochre/cwl/word-mapping-wf.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwl-runner 2 | cwlVersion: v1.0 3 | class: Workflow 4 | doc: This workflow is meant to be used as a subworkflow. 5 | requirements: 6 | - class: SubworkflowFeatureRequirement 7 | - class: ScatterFeatureRequirement 8 | inputs: 9 | gs_files: File[] 10 | ocr_files: File[] 11 | language: string 12 | lowercase: boolean? 13 | align_m: string? 14 | align_c: string? 15 | wm_name: string? 
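# A job file for this workflow could look roughly like this (a sketch; the
# paths and the language value are made up, the keys come from the inputs
# above), run with e.g. `cwltool word-mapping-wf.cwl job.yml`:
#   gs_files: [{class: File, path: gs/doc1.txt}]
#   ocr_files: [{class: File, path: ocr/doc1.txt}]
#   language: nl
#   lowercase: true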
16 | outputs: 17 | wm_mapping: 18 | type: File 19 | outputSource: merge-csv-2/merged 20 | steps: 21 | normalize-whitespace-punctuation-2: 22 | run: normalize-whitespace-punctuation.cwl 23 | in: 24 | meta_in: gs_files 25 | out: 26 | - metadata_out 27 | scatter: 28 | - meta_in 29 | scatterMethod: dotproduct 30 | normalize-whitespace-punctuation-3: 31 | run: normalize-whitespace-punctuation.cwl 32 | in: 33 | meta_in: ocr_files 34 | out: 35 | - metadata_out 36 | scatter: 37 | - meta_in 38 | scatterMethod: dotproduct 39 | align-texts-wf-2: 40 | run: align-texts-wf.cwl 41 | in: 42 | align_m: align_m 43 | align_c: align_c 44 | ocr: normalize-whitespace-punctuation-3/metadata_out 45 | gs: normalize-whitespace-punctuation-2/metadata_out 46 | out: 47 | - alignments 48 | - changes 49 | - metadata 50 | pattern-1: 51 | run: https://raw.githubusercontent.com/nlppln/pattern-docker/master/pattern.cwl 52 | in: 53 | in_file: normalize-whitespace-punctuation-2/metadata_out 54 | language: language 55 | out: 56 | - out_files 57 | scatter: 58 | - in_file 59 | scatterMethod: dotproduct 60 | create-word-mappings-1: 61 | run: create-word-mappings.cwl 62 | in: 63 | lowercase: lowercase 64 | alignments: align-texts-wf-2/alignments 65 | saf: pattern-1/out_files 66 | out: 67 | - word_mapping 68 | scatter: 69 | - alignments 70 | - saf 71 | scatterMethod: dotproduct 72 | merge-csv-2: 73 | run: merge-csv.cwl 74 | in: 75 | in_files: create-word-mappings-1/word_mapping 76 | name: wm_name 77 | out: 78 | - merged 79 | -------------------------------------------------------------------------------- /notebooks/vudnc-preprocess-workflow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import nlppln\n", 20 | "import ochre\n", 21 | "import os\n", 22 | "\n", 23 | "working_dir = '/home/jvdzwaan/cwl-working-dir/'\n", 24 | "\n", 25 | "with nlppln.WorkflowGenerator(working_dir=working_dir) as wf:\n", 26 | " wf.load(steps_dir=ochre.cwl_path())\n", 27 | " print(wf.list_steps())\n", 28 | "\n", 29 | " archive = wf.add_input(archive='File')\n", 30 | " ocr_dir_name = wf.add_input(ocr_dir_name='string?', default='ocr')\n", 31 | " gs_dir_name = wf.add_input(gs_dir_name='string?', default='gs')\n", 32 | " \n", 33 | " in_dir = wf.archive2dir(archive=archive)\n", 34 | "\n", 35 | " vudnc_files = wf.vudnc_select_files(in_dir=in_dir)\n", 36 | " gs_with_empty, ocr_with_empty = wf.vudnc2ocr_and_gs(in_file=vudnc_files, \n", 37 | " scatter='in_file', scatter_method='dotproduct')\n", 38 | " gs_dir = wf.save_files_to_dir(dir_name=gs_dir_name, in_files=gs_with_empty)\n", 39 | " ocr_dir = wf.save_files_to_dir(dir_name=ocr_dir_name, in_files=ocr_with_empty)\n", 40 | " gs, ocr = wf.remove_empty_files(gs_dir=gs_dir, ocr_dir=ocr_dir)\n", 41 | "\n", 42 | " gs_dir = wf.save_files_to_dir(dir_name=gs_dir_name, in_files=gs)\n", 43 | " ocr_dir = wf.save_files_to_dir(dir_name=ocr_dir_name, in_files=ocr)\n", 44 | "\n", 45 | " wf.add_outputs(gs_dir=gs_dir, ocr_dir=ocr_dir)\n", 46 | " wf.save(os.path.join(ochre.cwl_path(), 'vudnc-preprocess-pack.cwl'), pack=True, relative=False)" 47 | ] 48 | } 49 | ], 50 | "metadata": { 51 | "language_info": { 52 | "name": "python", 53 | "pygments_lexer": "ipython3" 54 | } 55 
| }, 56 | "nbformat": 4, 57 | "nbformat_minor": 2 58 | } 59 | -------------------------------------------------------------------------------- /notebooks/post_correction_workflows.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from nlppln import WorkflowGenerator\n", 20 | "\n", 21 | "with WorkflowGenerator() as wf:\n", 22 | " wf.load(steps_dir='../cwl/')\n", 23 | " print(wf.list_steps())\n", 24 | "\n", 25 | " in_dir = wf.add_input(in_dir='Directory')\n", 26 | " charset = wf.add_input(charset='File')\n", 27 | " model = wf.add_input(model='File')\n", 28 | "\n", 29 | " ocr_files = wf.ls(in_dir=in_dir)\n", 30 | " corrected = wf.lstm_synced_correct_ocr(charset=charset, model=model, txt=ocr_files, scatter='txt', scatter_method='dotproduct')\n", 31 | "\n", 32 | " wf.add_outputs(corrected=corrected)\n", 33 | " wf.save('../cwl/post-correct-dir.cwl')" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "from nlppln import WorkflowGenerator\n", 43 | "\n", 44 | "with WorkflowGenerator() as wf: \n", 45 | " wf.load(steps_dir='../cwl/')\n", 46 | " print(wf.list_steps())\n", 47 | "\n", 48 | " in_dir = wf.add_input(in_dir='Directory')\n", 49 | " datadiv = wf.add_input(datadivision='File')\n", 50 | " div_name = wf.add_input(div_name='string', default='test')\n", 51 | " charset = wf.add_input(charset='File')\n", 52 | " model = wf.add_input(model='File')\n", 53 | "\n", 54 | " ocr_files = wf.select_test_files(in_dir=in_dir, datadivision=datadiv)\n", 55 | " corrected = wf.lstm_synced_correct_ocr(charset=charset, model=model, txt=ocr_files, scatter='txt', scatter_method='dotproduct')\n", 56 | "\n", 57 | " wf.add_outputs(corrected=corrected)\n", 58 | " wf.save('../cwl/post-correct-test-files.cwl')" 59 | ] 60 | } 61 | ], 62 | "metadata": { 63 | "language_info": { 64 | "name": "python", 65 | "pygments_lexer": "ipython3" 66 | } 67 | }, 68 | "nbformat": 4, 69 | "nbformat_minor": 2 70 | } 71 | -------------------------------------------------------------------------------- /ochre/cwl/post-correct-test-files.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwl-runner 2 | cwlVersion: v1.0 3 | class: Workflow 4 | requirements: 5 | - class: ScatterFeatureRequirement 6 | inputs: 7 | in_dir: Directory 8 | datadivision: File 9 | div_name: 10 | default: test 11 | type: string 12 | charset: File 13 | model: File 14 | outputs: 15 | corrected: 16 | type: 17 | items: File 18 | type: array 19 | outputSource: lstm-synced-correct-ocr-1/corrected 20 | steps: 21 | select-test-files-2: 22 | run: 23 | cwlVersion: v1.0 24 | class: CommandLineTool 25 | baseCommand: [python, -m, ochre.select_test_files] 26 | 27 | stdout: cwl.output.json 28 | 29 | inputs: 30 | - type: File 31 | inputBinding: 32 | position: 2 33 | id: _:select-test-files-2#datadivision 34 | - type: Directory 35 | inputBinding: 36 | position: 1 37 | id: _:select-test-files-2#in_dir 38 | - type: 39 | - 'null' 40 | - string 41 | inputBinding: 42 | prefix: --name 43 | 44 | id: _:select-test-files-2#name 45 | outputs: 46 | - type: 47 | type: array 48 | items: File 49 | id: 
_:select-test-files-2#out_files 50 | id: _:select-test-files-2 51 | in: 52 | in_dir: in_dir 53 | datadivision: datadivision 54 | out: 55 | - out_files 56 | lstm-synced-correct-ocr-1: 57 | run: 58 | cwlVersion: v1.0 59 | class: CommandLineTool 60 | baseCommand: [python, -m, ochre.lstm_synced_correct_ocr] 61 | 62 | inputs: 63 | - type: File 64 | inputBinding: 65 | position: 2 66 | id: _:lstm-synced-correct-ocr-1#charset 67 | - type: File 68 | inputBinding: 69 | position: 1 70 | id: _:lstm-synced-correct-ocr-1#model 71 | - type: File 72 | inputBinding: 73 | position: 3 74 | 75 | id: _:lstm-synced-correct-ocr-1#txt 76 | outputs: 77 | - type: File 78 | outputBinding: 79 | glob: $(inputs.txt.basename) 80 | id: _:lstm-synced-correct-ocr-1#corrected 81 | id: _:lstm-synced-correct-ocr-1 82 | in: 83 | txt: select-test-files-2/out_files 84 | model: model 85 | charset: charset 86 | out: 87 | - corrected 88 | scatter: 89 | - txt 90 | scatterMethod: dotproduct 91 | -------------------------------------------------------------------------------- /ochre/ocrevaluation_extract.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import click 3 | import codecs 4 | import os 5 | import tempfile 6 | 7 | from bs4 import BeautifulSoup 8 | 9 | from nlppln.utils import create_dirs, remove_ext 10 | 11 | 12 | @click.command() 13 | @click.argument('in_file', type=click.File(encoding='utf-8')) 14 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path()) 15 | def ocrevaluation_extract(in_file, out_dir): 16 | create_dirs(out_dir) 17 | 18 | tables = [] 19 | 20 | write = False 21 | 22 | (fd, tmpfile) = tempfile.mkstemp() 23 | with codecs.open(tmpfile, 'w', encoding='utf-8') as tmp: 24 | for line in in_file: 25 | if line.startswith('
<h2>General'): 26 | write = True 27 | if line.startswith('<h2>Difference'): 28 | write = False 29 | if line.startswith('<h2>Error'): 30 | write = True 31 | 32 | if write: 33 | tmp.write(line) 34 | 35 | with codecs.open(tmpfile, encoding='utf-8') as f: 36 | soup = BeautifulSoup(f.read(), 'lxml') 37 | os.remove(tmpfile) 38 | 39 | tables = soup.find_all('table') 40 | assert len(tables) == 2 41 | 42 | doc = remove_ext(in_file.name) 43 | 44 | t = tables[0] 45 | table_data = [[cell.text for cell in row('td')] for row in t('tr')] 46 | 47 | # 'transpose' table_data 48 | lines = {} 49 | for data in table_data: 50 | for i, entry in enumerate(data): 51 | if i not in lines.keys(): 52 | # add doc id to data line (but not to header) 53 | if i != 0: 54 | lines[i] = [doc] 55 | else: 56 | lines[i] = [''] 57 | lines[i].append(entry) 58 | 59 | out_file = os.path.join(out_dir, '{}-global.csv'.format(doc)) 60 | with codecs.open(out_file, 'wb', encoding='utf-8') as f: 61 | for i in range(len(lines.keys())): 62 | f.write(u','.join(lines[i])) 63 | f.write(u'\n') 64 | 65 | t = tables[1] 66 | table_data = [[cell.text for cell in row('td')] for row in t('tr')] 67 | out_file = os.path.join(out_dir, '{}-character.csv'.format(doc)) 68 | with codecs.open(out_file, 'wb', encoding='utf-8') as f: 69 | for data in table_data: 70 | f.write(u'"{}",'.format(data[0])) 71 | f.write(u','.join(data[1:])) 72 | f.write(u'\n') 73 | 74 | 75 | if __name__ == '__main__': 76 | ocrevaluation_extract() 77 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Always prefer setuptools over distutils 2 | from setuptools import setup, find_packages 3 | # To use a consistent encoding 4 | from codecs import open 5 | from os import path 6 | 7 | here = path.abspath(path.dirname(__file__)) 8 | 9 | setup( 10 | name='ochre', 11 | 12 | # Versions should comply with PEP440. For a discussion on single-sourcing 13 | # the version across setup.py and the project code, see 14 | # https://packaging.python.org/en/latest/single_source_version.html 15 | version='0.1.0', 16 | 17 | description='Command line tools for the ocr project', 18 | long_description="""Command line tools for the ocr project 19 | """, 20 | 21 | # The project's main homepage. 22 | url='https://github.com/WhatWorksWhenForWhom/nlppln', 23 | 24 | # Author details 25 | author='Janneke van der Zwaan', 26 | author_email='j.vanderzwaan@esciencecenter.nl', 27 | 28 | # Choose your license 29 | license='Apache', 30 | 31 | include_package_data=True, 32 | 33 | # See https://pypi.python.org/pypi?%3Aaction=list_classifiers 34 | classifiers=[ 35 | # How mature is this project? Common values are 36 | # 3 - Alpha 37 | # 4 - Beta 38 | # 5 - Production/Stable 39 | 'Development Status :: 3 - Alpha', 40 | 41 | # Indicate who your project is intended for 42 | 'Intended Audience :: Developers', 43 | 44 | # Pick your license as you wish (should match "license" above) 45 | 'License :: OSI Approved :: Apache Software License', 46 | 47 | # Specify the Python versions you support here. In particular, ensure 48 | # that you indicate whether you support Python 2, Python 3 or both. 49 | 'Programming Language :: Python :: 2', 50 | 'Programming Language :: Python :: 2.7', 51 | ], 52 | 53 | # What does your project relate to? 54 | keywords='text-mining, nlp, pipeline', 55 | 56 | # You can just specify the packages manually here if your project is 57 | # simple. Or you can use find_packages(). 58 | packages=find_packages(), 59 | 60 | # List run-time dependencies here.
These will be installed by pip when 61 | # your project is installed. For an analysis of "install_requires" vs pip's 62 | # requirements files see: 63 | # https://packaging.python.org/en/latest/requirements.html 64 | install_requires=[ 65 | 'cwltool==1.0.20181102182747', 66 | 'scriptcwl>=0.8.1', 67 | 'nlppln>=0.3.2', 68 | 'six', 69 | 'glob2', 70 | 'sh', 71 | ] 72 | 73 | #scripts=['recipy-cmd'] 74 | ) 75 | -------------------------------------------------------------------------------- /ochre/cwl/lowercase-directory.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwl-runner 2 | cwlVersion: v1.0 3 | class: Workflow 4 | requirements: 5 | - class: ScatterFeatureRequirement 6 | inputs: 7 | in_dir: Directory 8 | dir_name: 9 | default: gs_lowercase 10 | type: string 11 | outputs: 12 | out_dir: 13 | type: Directory 14 | outputSource: save-files-to-dir-4/out 15 | steps: 16 | ls-1: 17 | run: 18 | cwlVersion: v1.0 19 | class: CommandLineTool 20 | baseCommand: [python, -m, nlppln.commands.ls] 21 | 22 | inputs: 23 | - type: Directory 24 | inputBinding: 25 | position: 2 26 | id: _:ls-1#in_dir 27 | - type: 28 | - 'null' 29 | - boolean 30 | inputBinding: 31 | prefix: --recursive 32 | 33 | id: _:ls-1#recursive 34 | stdout: cwl.output.json 35 | 36 | outputs: 37 | - type: 38 | type: array 39 | items: File 40 | id: _:ls-1#out_files 41 | id: _:ls-1 42 | in: 43 | in_dir: in_dir 44 | out: 45 | - out_files 46 | lowercase-1: 47 | run: 48 | cwlVersion: v1.0 49 | class: CommandLineTool 50 | baseCommand: [python, -m, nlppln.commands.lowercase] 51 | 52 | inputs: 53 | - type: File 54 | inputBinding: 55 | position: 1 56 | 57 | id: _:lowercase-1#in_file 58 | stdout: $(inputs.in_file.nameroot).txt 59 | 60 | outputs: 61 | - type: File 62 | outputBinding: 63 | glob: $(inputs.in_file.nameroot).txt 64 | id: _:lowercase-1#out_files 65 | id: _:lowercase-1 66 | in: 67 | in_file: ls-1/out_files 68 | out: 69 | - out_files 70 | scatter: 71 | - in_file 72 | scatterMethod: dotproduct 73 | save-files-to-dir-4: 74 | run: 75 | cwlVersion: v1.0 76 | class: ExpressionTool 77 | 78 | requirements: 79 | - class: InlineJavascriptRequirement 80 | 81 | inputs: 82 | - type: string 83 | id: _:save-files-to-dir-4#dir_name 84 | - type: 85 | type: array 86 | items: File 87 | id: _:save-files-to-dir-4#in_files 88 | outputs: 89 | - type: Directory 90 | id: _:save-files-to-dir-4#out 91 | expression: | 92 | ${ 93 | return {"out": { 94 | "class": "Directory", 95 | "basename": inputs.dir_name, 96 | "listing": inputs.in_files 97 | } }; 98 | } 99 | id: _:save-files-to-dir-4 100 | in: 101 | dir_name: dir_name 102 | in_files: lowercase-1/out_files 103 | out: 104 | - out 105 | -------------------------------------------------------------------------------- /ochre/edlibutils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import edlib 3 | 4 | 5 | def align_characters(query, ref, empty_char=''): 6 | a = edlib.align(query, ref, task="path") 7 | ref_pos = a["locations"][0][0] 8 | query_pos = 0 9 | ref_aln = [] 10 | match_aln = "" 11 | query_aln = [] 12 | 13 | for step, code in re.findall(r"(\d+)(\D)", a["cigar"]): 14 | step = int(step) 15 | if code == "=": 16 | for c in ref[ref_pos: ref_pos + step]: 17 | ref_aln.append(c) 18 | #ref_aln += ref[ref_pos : ref_pos + step] 19 | ref_pos += step 20 | for c in query[query_pos: query_pos + step]: 21 | query_aln.append(c) 22 | #query_aln += query[query_pos : query_pos + step] 23 | query_pos += step 24 | match_aln 
+= "|" * step 25 | elif code == "X": 26 | for c in ref[ref_pos: ref_pos + step]: 27 | ref_aln.append(c) 28 | #ref_aln += ref[ref_pos : ref_pos + step] 29 | ref_pos += step 30 | for c in query[query_pos: query_pos + step]: 31 | query_aln.append(c) 32 | #query_aln += query[query_pos : query_pos + step] 33 | query_pos += step 34 | match_aln += "." * step 35 | elif code == "D": 36 | for c in ref[ref_pos: ref_pos + step]: 37 | ref_aln.append(c) 38 | #ref_aln += ref[ref_pos : ref_pos + step] 39 | ref_pos += step 40 | #query_aln += " " * step 41 | query_pos += 0 42 | for i in range(step): 43 | query_aln.append('') 44 | match_aln += " " * step 45 | elif code == "I": 46 | for i in range(step): 47 | ref_aln.append('') 48 | #ref_aln += " " * step 49 | ref_pos += 0 50 | for c in query[query_pos: query_pos + step]: 51 | query_aln.append(c) 52 | #query_aln += query[query_pos : query_pos + step] 53 | query_pos += step 54 | match_aln += " " * step 55 | else: 56 | pass 57 | 58 | return ref_aln, query_aln, match_aln 59 | 60 | 61 | def align_output_to_input(input_str, output_str, empty_char=u'@'): 62 | t_output_str = output_str.encode('ASCII', 'replace') 63 | t_input_str = input_str.encode('ASCII', 'replace') 64 | #try: 65 | # r = edlib.align(t_input_str, t_output_str, task='path') 66 | #except: 67 | # print(input_str) 68 | # print(output_str) 69 | r1, r2, _ = align_characters(input_str, output_str, 70 | empty_char=empty_char) 71 | while len(r2) < len(input_str): 72 | r2.append(empty_char) 73 | return u''.join(r2) 74 | -------------------------------------------------------------------------------- /ochre/cwl/sac-extract.cwl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cwl-runner 2 | cwlVersion: v1.0 3 | class: Workflow 4 | inputs: 5 | data_file: File 6 | ocr_dir_name: 7 | default: ocr 8 | type: string 9 | gs_dir_name: 10 | default: gs 11 | type: string 12 | de_dir_name: 13 | default: de 14 | type: string 15 | fr_dir_name: 16 | default: fr 17 | type: string 18 | outputs: 19 | ocr_de: 20 | type: Directory 21 | outputSource: save-dir-to-subdir/out 22 | gs_de: 23 | type: Directory 24 | outputSource: save-dir-to-subdir-1/out 25 | ocr_fr: 26 | type: Directory 27 | outputSource: save-dir-to-subdir-2/out 28 | gs_fr: 29 | type: Directory 30 | outputSource: save-dir-to-subdir-3/out 31 | steps: 32 | tar: 33 | run: tar.cwl 34 | in: 35 | in_file: data_file 36 | out: 37 | - out 38 | sac2gs-and-ocr: 39 | run: sac2gs-and-ocr.cwl 40 | in: 41 | in_dir: tar/out 42 | out: 43 | - gs_de 44 | - gs_fr 45 | - ocr_de 46 | - ocr_fr 47 | mkdir: 48 | run: mkdir.cwl 49 | in: 50 | dir_name: de_dir_name 51 | out: 52 | - out 53 | mkdir-1: 54 | run: mkdir.cwl 55 | in: 56 | dir_name: fr_dir_name 57 | out: 58 | - out 59 | save-files-to-dir-9: 60 | run: save-files-to-dir.cwl 61 | in: 62 | dir_name: ocr_dir_name 63 | in_files: sac2gs-and-ocr/ocr_de 64 | out: 65 | - out 66 | save-dir-to-subdir: 67 | run: save-dir-to-subdir.cwl 68 | in: 69 | outer_dir: mkdir/out 70 | inner_dir: save-files-to-dir-9/out 71 | out: 72 | - out 73 | save-files-to-dir-10: 74 | run: save-files-to-dir.cwl 75 | in: 76 | dir_name: gs_dir_name 77 | in_files: sac2gs-and-ocr/gs_de 78 | out: 79 | - out 80 | save-dir-to-subdir-1: 81 | run: save-dir-to-subdir.cwl 82 | in: 83 | outer_dir: mkdir/out 84 | inner_dir: save-files-to-dir-10/out 85 | out: 86 | - out 87 | save-files-to-dir-11: 88 | run: save-files-to-dir.cwl 89 | in: 90 | dir_name: ocr_dir_name 91 | in_files: sac2gs-and-ocr/ocr_fr 92 | out: 93 | - out 94 |
save-dir-to-subdir-2: 95 | run: save-dir-to-subdir.cwl 96 | in: 97 | outer_dir: mkdir-1/out 98 | inner_dir: save-files-to-dir-11/out 99 | out: 100 | - out 101 | save-files-to-dir-12: 102 | run: save-files-to-dir.cwl 103 | in: 104 | dir_name: gs_dir_name 105 | in_files: sac2gs-and-ocr/gs_fr 106 | out: 107 | - out 108 | save-dir-to-subdir-3: 109 | run: save-dir-to-subdir.cwl 110 | in: 111 | outer_dir: mkdir-1/out 112 | inner_dir: save-files-to-dir-12/out 113 | out: 114 | - out 115 | -------------------------------------------------------------------------------- /tests/data/ocrevaluation/out/gs_out.html: -------------------------------------------------------------------------------- 1 | <html> 2 | <head> 3 | <meta http-equiv="content-type" content="text/html; charset=utf-8"> 4 | </head> 5 | <body> 6 | 7 | <h2>General results</h2> 8 | <table border="1"> 9 | <tr> 10 | <td>CER</td> 11 | <td>4.17</td> 12 | </tr> 13 | <tr> 14 | <td>WER</td> 15 | <td>20.00</td> 16 | </tr> 17 | <tr> 18 | <td>WER (order independent)</td> 19 | <td>20.00</td> 20 | </tr> 21 | </table> 22 | <h2>Difference spotting</h2> 23 | <table border="1"> 24 | <tr> 25 | <td> 26 | <h3>gs.txt</h3> 27 | </td> 28 | <td> 29 | <h3>ocr.txt</h3> 30 | </td> 31 | </tr> 32 | <tr> 33 | <td>This is an example text.</td> 34 | <td>This is an cxample text.</td> 35 | </tr> 36 | </table> 37 | <h2>Error rate per character and type</h2> 38 | <table border="1"> 39 | <tr> 40 | <td>Character</td><td>Hex code</td><td>Total</td><td>Spurious</td><td>Confused</td><td>Lost</td><td>Error rate</td> 41 | </tr> 42 | <tr> 43 | <td> </td><td>20</td><td>4</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 44 | </tr> 45 | <tr> 46 | <td>.</td><td>2e</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 47 | </tr> 48 | <tr> 49 | <td>T</td><td>54</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 50 | </tr> 51 | <tr> 52 | <td>a</td><td>61</td><td>2</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 53 | </tr> 54 | <tr> 55 | <td>e</td><td>65</td><td>3</td><td>0</td><td>1</td><td>0</td><td>33.33</td> 56 | </tr> 57 | <tr> 58 | <td>h</td><td>68</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 59 | </tr> 60 | <tr> 61 | <td>i</td><td>69</td><td>2</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 62 | </tr> 63 | <tr> 64 | <td>l</td><td>6c</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 65 | </tr> 66 | <tr> 67 | <td>m</td><td>6d</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 68 | </tr> 69 | <tr> 70 | <td>n</td><td>6e</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 71 | </tr> 72 | <tr> 73 | <td>p</td><td>70</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 74 | </tr> 75 | <tr> 76 | <td>s</td><td>73</td><td>2</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 77 | </tr> 78 | <tr> 79 | <td>t</td><td>74</td><td>2</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 80 | </tr> 81 | <tr> 82 | <td>x</td><td>78</td><td>2</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 83 | </tr>
84 | </table> 85 | </body> 86 | </html> -------------------------------------------------------------------------------- /tests/data/ocrevaluation-extract/in/in.html: -------------------------------------------------------------------------------- 1 | <html> 2 | <head> 3 | <meta http-equiv="content-type" content="text/html; charset=utf-8"> 4 | </head> 5 | <body> 6 | 7 | <h2>General results</h2> 8 | <table border="1"> 9 | <tr> 10 | <td>CER</td> 11 | <td>4.17</td> 12 | </tr> 13 | <tr> 14 | <td>WER</td> 15 | <td>20.00</td> 16 | </tr> 17 | <tr> 18 | <td>WER (order independent)</td> 19 | <td>20.00</td> 20 | </tr> 21 | </table> 22 | <h2>Difference spotting</h2> 23 | <table border="1"> 24 | <tr> 25 | <td> 26 | <h3>gs.txt</h3> 27 | </td> 28 | <td> 29 | <h3>ocr.txt</h3> 30 | </td> 31 | </tr> 32 | <tr> 33 | <td>This is an example text.</td> 34 | <td>This is an cxample text.</td> 35 | </tr> 36 | </table> 37 | <h2>Error rate per character and type</h2> 38 | <table border="1"> 39 | <tr> 40 | <td>Character</td><td>Hex code</td><td>Total</td><td>Spurious</td><td>Confused</td><td>Lost</td><td>Error rate</td> 41 | </tr> 42 | <tr> 43 | <td> </td><td>20</td><td>4</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 44 | </tr> 45 | <tr> 46 | <td>.</td><td>2e</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 47 | </tr> 48 | <tr> 49 | <td>T</td><td>54</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 50 | </tr> 51 | <tr> 52 | <td>a</td><td>61</td><td>2</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 53 | </tr> 54 | <tr> 55 | <td>e</td><td>65</td><td>3</td><td>0</td><td>1</td><td>0</td><td>33.33</td> 56 | </tr> 57 | <tr> 58 | <td>h</td><td>68</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 59 | </tr> 60 | <tr> 61 | <td>i</td><td>69</td><td>2</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 62 | </tr> 63 | <tr> 64 | <td>l</td><td>6c</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 65 | </tr> 66 | <tr> 67 | <td>m</td><td>6d</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 68 | </tr> 69 | <tr> 70 | <td>n</td><td>6e</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 71 | </tr> 72 | <tr> 73 | <td>p</td><td>70</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 74 | </tr> 75 | <tr> 76 | <td>s</td><td>73</td><td>2</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 77 | </tr> 78 | <tr> 79 | <td>t</td><td>74</td><td>2</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 80 | </tr> 81 | <tr> 82 | <td>x</td><td>78</td><td>2</td><td>0</td><td>0</td><td>0</td><td>0.00</td> 83 | </tr>
84 | </table> 85 | </body> 86 | </html> -------------------------------------------------------------------------------- /datagen.py: -------------------------------------------------------------------------------- 1 | """Data generation functionality for ochre 2 | 3 | Source: 4 | https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly 5 | """ 6 | import keras 7 | import tensorflow 8 | 9 | import numpy as np 10 | 11 | 12 | class DataGenerator(tensorflow.keras.utils.Sequence): 13 | 'Generates data for Keras' 14 | def __init__(self, xData, yData, char_to_int, seq_length, 15 | padding_char='\n', oov_char='@', batch_size=32, shuffle=True): 16 | """ 17 | xData is list of input strings 18 | yData is list of output strings 19 | """ 20 | self.xData = xData 21 | self.yData = yData 22 | 23 | self.char_to_int = char_to_int 24 | self.padding_char = padding_char 25 | self.oov_char = oov_char 26 | 27 | self.n_vocab = len(char_to_int) 28 | self.seq_length = seq_length 29 | 30 | self.batch_size = batch_size 31 | self.shuffle = shuffle 32 | self.on_epoch_end() 33 | 34 | def __len__(self): 35 | 'Denotes the number of batches per epoch' 36 | return int(np.floor(len(self.xData) / self.batch_size)) 37 | 38 | def __getitem__(self, index): 39 | 'Generate one batch of data' 40 | # Generate indexes of the batch 41 | indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size] 42 | 43 | # Generate data 44 | X, y = self.__data_generation(indexes) 45 | 46 | return X, y 47 | 48 | def on_epoch_end(self): 49 | 'Updates indexes after each epoch' 50 | self.indexes = np.arange(len(self.xData)) 51 | if self.shuffle is True: 52 | np.random.shuffle(self.indexes) 53 | 54 | def __data_generation(self, list_IDs_temp): 55 | 'Generates data containing batch_size samples' 56 | # Initialization 57 | X = np.empty((self.batch_size, self.seq_length), dtype=np.int) 58 | ylist = list() 59 | 60 | # Generate data 61 | for i, ID in enumerate(list_IDs_temp): 62 | # input 63 | X[i, ] = self._convert_sample(self.xData[ID]) 64 | 65 | # output 66 | y_seq = self._convert_sample(self.yData[ID]) 67 | enc = keras.utils.to_categorical(y_seq, num_classes=self.n_vocab) 68 | ylist.append(enc) 69 | 70 | y = np.array(ylist) 71 | y = y.reshape(self.batch_size, self.seq_length, self.n_vocab) 72 | 73 | return X, y 74 | 75 | def _convert_sample(self, string): 76 | res = np.empty(self.seq_length, dtype=np.int) 77 | oov = self.char_to_int[self.oov_char] 78 | for i in range(self.seq_length): 79 | try: 80 | res[i] = self.char_to_int.get(string[i], oov) 81 | except IndexError: 82 | res[i] = self.char_to_int[self.padding_char] 83 | return res 84 | -------------------------------------------------------------------------------- /ochre/scramble.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import click 3 | #import codecs 4 | import numpy as np 5 | 6 | 7 | @click.command() 8 | #@click.argument('in_files', nargs=-1, type=click.Path(exists=True)) 9 | #@click.argument('out_dir', nargs=1, type=click.Path()) 10 | #def command(in_files, out_dir): 11 | def command(): 12 | # How does the scrambling work? 13 | # Randomly. We have insertions, deletions and substitutions. 14 | # These can all be more than 1 character long (say: up to 5 characters) 15 | # First you choose whether to do an insertion, deletion or substitution. 16 | # Then you choose how long it will be (for a substitution you have to pick 17 | # two lengths: what is replaced and what it is replaced with) 18 | # Hmm, you only need insertion and deletion (with their associated 19 | # lengths); sometimes you do both, and that is a substitution. 20 | # No: if anything happens, it is a substitution, and you sample the 21 | # lengths of the insert and the delete (which can sometimes be 0) 22 | # Hmm, which characters are substituted depends on the char (but now we do 23 | # totally random, this is for later) 24 | text = 'Dit is een tekst.' 25 | text_list = [c for c in text] 26 | text_length = len(text) 27 | probs2 = [0.5, 0.35, 0.14, 0.005, 0.005] 28 | 29 | change = np.random.choice(2, text_length, p=[0.8, 0.2]) 30 | insert = [] 31 | delete = [] 32 | for i in change: 33 | if i == 1: 34 | insert.append(np.random.choice(len(probs2), 1, p=probs2)[0]) 35 | delete.append(np.random.choice(len(probs2), 1, p=probs2)[0]) 36 | else: 37 | insert.append(0) 38 | delete.append(0) 39 | 40 | #insert = np.random.choice(max_length, text_length, p=probs) 41 | #delete = np.random.choice(max_length, text_length, p=probs) 42 | #insert = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 43 | #delete = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1] 44 | print(insert) 45 | print(delete) 46 | idx = 0 47 | for tidx, (i, d) in enumerate(zip(insert, delete)): 48 | #print(idx) 49 | #print(i) 50 | #print(d) 51 | #print('---') 52 | #print(idx) 53 | if i != 0 or d != 0: 54 | if d != 0: 55 | for _ in range(d): 56 | # make sure text_list[idx] exists, before deleting it 57 | if idx < len(text_list): 58 | del(text_list[idx]) 59 | # only adjust idx if no characters should be inserted 60 | if i == 0: 61 | idx += 1 62 | if i != 0: 63 | for _ in range(i): 64 | # TODO: replace by random character(s) instead of a fixed 65 | # one (possibly dependent on what was deleted) 66 | text_list.insert(idx, '*') 67 | idx += i 68 | else: 69 | idx += 1 70 | 71 | print(''.join(text_list)) 72 | 73 | 74 | if __name__ == '__main__': 75 | command() 76 | -------------------------------------------------------------------------------- /ochre/create_word_mappings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import click 3 | import os 4 | import json 5 | 6 | import pandas as pd 7 | 8 | from string import punctuation 9 | 10 | from nlppln.utils import create_dirs, remove_ext, out_file_name 11 | 12 | 13 | def find_word_boundaries(words, aligned): 14 | unaligned = [] 15 | for w in words: 16 | for c in w: 17 | unaligned.append(c) 18 | unaligned.append('@') 19 | 20 | i = 0 21 | j = 0 22 | prev = 0 23 | wb = [] 24 | while i < len(aligned) and j < len(unaligned): 25 | if unaligned[j] == '@': 26 | w = u''.join(aligned[prev:i]) 27 | # Punctuation that occurs in the ocr but not in the gold standard 28 | # is ignored.
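# (A sketch with made-up strings: if the aligned fragment for a word reads
# ' word' or ',word' because of punctuation the gold standard does not
# have, the boundary computed below starts one position later.)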
29 | if len(w) > 1 and (w[0] == u' ' or w[0] in punctuation): 30 | s = prev + 1 31 | else: 32 | s = prev 33 | #print('Word', w) 34 | wb.append((s, i)) 35 | prev = i 36 | j += 1 37 | elif aligned[i] == '' or aligned[i] == ' ': 38 | #print('Space or empty in aligned', unaligned[j]) 39 | i += 1 40 | elif aligned[i] == unaligned[j]: 41 | #print('Matching chars', unaligned[j]) 42 | i += 1 43 | j += 1 44 | else: 45 | #print('Other', unaligned[j], aligned[i]) 46 | #print('j', j, 'i', i) 47 | i += 1 48 | # add last word 49 | wb.append((prev, len(aligned))) 50 | return wb 51 | 52 | 53 | @click.command() 54 | @click.argument('saf', type=click.File(encoding='utf-8')) 55 | @click.argument('alignments', type=click.File(encoding='utf-8')) 56 | @click.option('--lowercase/--no-lowercase', default=False) 57 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path()) 58 | def create_word_mappings(saf, alignments, lowercase, out_dir): 59 | create_dirs(out_dir) 60 | 61 | alignment_data = json.load(alignments) 62 | aligned1 = alignment_data['gs'] 63 | aligned2 = alignment_data['ocr'] 64 | 65 | saf = json.load(saf) 66 | if lowercase: 67 | words = [w['word'].lower() for w in saf['tokens']] 68 | 69 | aligned1 = [c.lower() for c in aligned1] 70 | aligned2 = [c.lower() for c in aligned2] 71 | else: 72 | words = [w['word'] for w in saf['tokens']] 73 | 74 | wb = find_word_boundaries(words, aligned1) 75 | 76 | doc_id = remove_ext(alignments.name) 77 | 78 | res = {'gs': [], 'ocr': [], 'doc_id': []} 79 | for s, e in wb: 80 | w1 = u''.join(aligned1[s:e]) 81 | w2 = u''.join(aligned2[s:e]) 82 | 83 | res['gs'].append(w1.strip()) 84 | res['ocr'].append(w2.strip()) 85 | res['doc_id'].append(doc_id) 86 | 87 | # Use pandas DataFrame to create the csv, so commas and quotes are properly 88 | # escaped. 
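# (For example (a sketch, not taken from the data): a gs token like
# u'goed,' contains a comma that a naive u','.join() writer would split
# into two columns; pandas writes it quoted as "goed,".)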
89 | df = pd.DataFrame(res) 90 | 91 | out_file = out_file_name(out_dir, doc_id, ext='csv') 92 | df.to_csv(out_file, encoding='utf-8') 93 | 94 | if __name__ == '__main__': 95 | create_word_mappings() 96 | -------------------------------------------------------------------------------- /notebooks/icdar2019POCR-to-texts.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "%matplotlib inline\n", 20 | "\n", 21 | "import numpy as np\n", 22 | "import pandas as pd\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "\n", 25 | "from tqdm import tqdm_notebook as tqdm" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import os\n", 35 | "import codecs\n", 36 | "import json\n", 37 | "\n", 38 | "from nlppln.utils import get_files, out_file_name, create_dirs\n", 39 | "\n", 40 | "from ochre.icdar2017st_extract_text import to_character_list\n", 41 | "\n", 42 | "in_dir = '/home/jvdzwaan/Downloads/POCR_training_dataset/NL/NL1/'\n", 43 | "\n", 44 | "out_dir = '/home/jvdzwaan/data/icdar2019pocr-nl'\n", 45 | "\n", 46 | "def command(in_file, out_dir):\n", 47 | " create_dirs(out_dir)\n", 48 | "\n", 49 | " lines = in_file.readlines()\n", 50 | " # OCR_toInput: lines[0][:14]\n", 51 | " # OCR_aligned: lines[1][:14]\n", 52 | " # GS_aligned: lines[2][:14]\n", 53 | " ocr = to_character_list(lines[1][14:].strip())\n", 54 | " gs = to_character_list(lines[2][14:].strip())\n", 55 | "\n", 56 | " # Write texts\n", 57 | " out_file = out_file_name(os.path.join(out_dir, 'ocr'), os.path.basename(in_file.name))\n", 58 | " print(out_file)\n", 59 | " create_dirs(out_file, is_file=True)\n", 60 | " with codecs.open(out_file, 'wb', encoding='utf-8') as f:\n", 61 | " f.write(u''.join(ocr))\n", 62 | "\n", 63 | " out_file = out_file_name(os.path.join(out_dir, 'gs'), os.path.basename(in_file.name))\n", 64 | " print(out_file)\n", 65 | " create_dirs(out_file, is_file=True)\n", 66 | " with codecs.open(out_file, 'wb', encoding='utf-8') as f:\n", 67 | " f.write(u''.join(gs))\n", 68 | "\n", 69 | " out_file = out_file_name(os.path.join(out_dir, 'aligned'), os.path.basename(in_file.name), 'json')\n", 70 | " print(out_file)\n", 71 | " create_dirs(out_file, is_file=True)\n", 72 | " with codecs.open(out_file, 'wb', encoding='utf-8') as f:\n", 73 | " json.dump({'ocr': ocr, 'gs': gs}, f)\n", 74 | " \n", 75 | " print()\n", 76 | "\n", 77 | "in_files = get_files(in_dir)\n", 78 | "#print(in_files)\n", 79 | "print(len(in_files))\n", 80 | "for in_file in in_files:\n", 81 | " print(in_file)\n", 82 | " \n", 83 | " command(open(in_file), out_dir)" 84 | ] 85 | } 86 | ], 87 | "metadata": { 88 | "language_info": { 89 | "name": "python", 90 | "pygments_lexer": "ipython3" 91 | } 92 | }, 93 | "nbformat": 4, 94 | "nbformat_minor": 2 95 | } 96 | -------------------------------------------------------------------------------- /tests/test_ocrevaluation_extract.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import codecs 4 | import sh 5 | import pytest 6 | 7 | from click.testing import CliRunner 8 | from bs4 import BeautifulSoup 9 | 10 | from 
ochre import ocrevaluation_extract 11 | 12 | 13 | #def test_prettify_xml(): 14 | # runner = CliRunner() 15 | # with runner.isolated_filesystem(): 16 | # os.makedirs('in') 17 | # os.makedirs('out') 18 | # with open('in/test.xml', 'w') as f: 19 | # xml = '\n\n' \ 20 | # '\n' \ 21 | # '\n' \ 22 | # '\n' 23 | # f.write(xml) 24 | 25 | # result = runner.invoke(prettify_xml, ['in/test.xml', 26 | # '--out_dir', 'out']) 27 | 28 | # assert result.exit_code == 0 29 | 30 | # assert os.path.exists('out/test.xml') 31 | 32 | # with codecs.open('out/test.xml', 'r', encoding='utf-8') as f: 33 | # pretty = f.read() 34 | 35 | # assert pretty == BeautifulSoup(xml, 'xml').prettify() 36 | 37 | 38 | def test_ocr_evaluation_extract_cwl(tmpdir): 39 | tool = os.path.join('ochre', 'cwl', 'ocrevaluation-extract.cwl') 40 | in_file = os.path.join('tests', 'data', 'ocrevaluation-extract', 'in', 41 | 'in.html') 42 | 43 | try: 44 | sh.cwltool(['--outdir', tmpdir, tool, '--in_file', in_file]) 45 | except sh.ErrorReturnCode as e: 46 | print(e) 47 | pytest.fail(e) 48 | 49 | for out in ['in-character.csv', 'in-global.csv']: 50 | 51 | out_file = tmpdir.join(out).strpath 52 | with open(out_file) as f: 53 | actual = f.read() 54 | 55 | fname = os.path.join('tests', 'data', 'ocrevaluation-extract', 'out', 56 | out) 57 | with open(fname) as f: 58 | expected = f.read() 59 | 60 | print(out) 61 | print(' actual:', actual) 62 | print('expected:', expected) 63 | assert actual == expected 64 | 65 | 66 | def test_ocr_evaluation_extract_cwl_empty(tmpdir): 67 | tool = os.path.join('ochre', 'cwl', 'ocrevaluation-extract.cwl') 68 | in_file = os.path.join('tests', 'data', 'ocrevaluation-extract', 'in', 69 | 'empty.html') 70 | 71 | try: 72 | sh.cwltool(['--outdir', tmpdir, tool, '--in_file', in_file]) 73 | except sh.ErrorReturnCode as e: 74 | print(e) 75 | pytest.fail(e) 76 | 77 | for out in ['empty-character.csv', 'empty-global.csv']: 78 | 79 | out_file = tmpdir.join(out).strpath 80 | with open(out_file) as f: 81 | actual = f.read() 82 | 83 | fname = os.path.join('tests', 'data', 'ocrevaluation-extract', 'out', 84 | out) 85 | with open(fname) as f: 86 | expected = f.read() 87 | 88 | print(out) 89 | print(' actual:', actual) 90 | print('expected:', expected) 91 | assert actual == expected 92 | -------------------------------------------------------------------------------- /2017_baseline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from keras.models import Sequential 5 | from keras.layers import Dense 6 | from keras.layers import LSTM 7 | from keras.layers import TimeDistributed 8 | from keras.layers import RepeatVector 9 | from keras.layers import Embedding 10 | from keras.callbacks import ModelCheckpoint 11 | 12 | from datagen import DataGenerator 13 | 14 | 15 | def infinite_loop(generator): 16 | while True: 17 | for i in range(len(generator)): 18 | yield(generator[i]) 19 | generator.on_epoch_end() 20 | 21 | 22 | # load the data 23 | data_dir = '/home/jvdzwaan/data/sprint-icdar/in' # FIXME 24 | weights_dir = '/home/jvdzwaan/data/sprint-icdar/weights' 25 | 26 | if not os.path.exists(weights_dir): 27 | os.makedirs(weights_dir) 28 | 29 | seq_length = 53 30 | batch_size = 100 31 | shuffle = True 32 | pc = '\n' 33 | oc = '@' 34 | 35 | n_nodes = 1000 36 | dropout = 0.2 37 | n_embed = 256 38 | 39 | epochs = 30 40 | loss = 'categorical_crossentropy' 41 | optimizer = 'adam' 42 | metrics = ['accuracy'] 43 | 44 | with open(os.path.join(data_dir, 'train.pkl'), 'rb') as f: 45 
| gs_selected_train, ocr_selected_train = pickle.load(f) 46 | 47 | with open(os.path.join(data_dir, 'val.pkl'), 'rb') as f: 48 | gs_selected_val, ocr_selected_val = pickle.load(f) 49 | 50 | with open(os.path.join(data_dir, 'ci.pkl'), 'rb') as f: 51 | ci = pickle.load(f) 52 | 53 | n_vocab = len(ci) 54 | 55 | dg_val = DataGenerator(xData=ocr_selected_val, yData=gs_selected_val, 56 | char_to_int=ci, 57 | seq_length=seq_length, padding_char=pc, oov_char=oc, 58 | batch_size=batch_size, shuffle=shuffle) 59 | dg_train = DataGenerator(xData=ocr_selected_train, 60 | yData=gs_selected_train, char_to_int=ci, 61 | seq_length=seq_length, padding_char=pc, 62 | oov_char=oc, 63 | batch_size=batch_size, shuffle=shuffle) 64 | 65 | # create the network 66 | model = Sequential() 67 | 68 | # encoder 69 | model.add(Embedding(n_vocab, n_embed, input_length=seq_length)) 70 | model.add(LSTM(n_nodes, input_shape=(seq_length, n_vocab))) 71 | # For the decoder's input, we repeat the encoded input for each time step 72 | model.add(RepeatVector(seq_length)) 73 | model.add(LSTM(n_nodes, return_sequences=True)) 74 | 75 | # For each of step of the output sequence, decide which character should be 76 | # chosen 77 | model.add(TimeDistributed(Dense(n_vocab, activation='softmax'))) 78 | model.compile(loss=loss, optimizer=optimizer, metrics=metrics) 79 | 80 | # initialize saving of weights 81 | filepath = os.path.join(weights_dir, '{loss:.4f}-{epoch:02d}.hdf5') 82 | checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, 83 | save_best_only=True, mode='min') 84 | callbacks_list = [checkpoint] 85 | 86 | # do training (and save weights) 87 | model.fit_generator(infinite_loop(dg_train), steps_per_epoch=len(dg_train), epochs=epochs, 88 | validation_data=infinite_loop(dg_val), 89 | validation_steps=len(dg_val), callbacks=callbacks_list) 90 | -------------------------------------------------------------------------------- /ochre/lstm_synced_correct_ocr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import click 3 | import codecs 4 | import os 5 | import numpy as np 6 | 7 | from collections import Counter 8 | 9 | from keras.models import load_model 10 | 11 | from nlppln.utils import create_dirs, out_file_name 12 | 13 | from .utils import get_char_to_int, get_int_to_char, read_text_to_predict 14 | from .edlibutils import align_output_to_input 15 | 16 | 17 | @click.command() 18 | @click.argument('model', type=click.Path(exists=True)) 19 | @click.argument('charset', type=click.File(encoding='utf-8')) 20 | @click.argument('text', type=click.File(encoding='utf-8')) 21 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path()) 22 | def lstm_synced_correct_ocr(model, charset, text, out_dir): 23 | create_dirs(out_dir) 24 | 25 | # load model 26 | model = load_model(model) 27 | conf = model.get_config() 28 | conf_result = conf[0].get('config').get('batch_input_shape') 29 | seq_length = conf_result[1] 30 | char_embedding = False 31 | if conf[0].get('class_name') == u'Embedding': 32 | char_embedding = True 33 | 34 | charset = charset.read() 35 | n_vocab = len(charset) 36 | char_to_int = get_char_to_int(charset) 37 | int_to_char = get_int_to_char(charset) 38 | lowercase = True 39 | for c in u'ABCDEFGHIJKLMNOPQRSTUVWXYZ': 40 | if c in charset: 41 | lowercase = False 42 | break 43 | 44 | pad = u'\n' 45 | 46 | to_predict = read_text_to_predict(text.read(), seq_length, lowercase, 47 | n_vocab, char_to_int, padding_char=pad, 48 | 
char_embedding=char_embedding) 49 | 50 | outputs = [] 51 | inputs = [] 52 | 53 | predicted = model.predict(to_predict, verbose=0) 54 | for i, sequence in enumerate(predicted): 55 | predicted_indices = [np.random.choice(n_vocab, p=p) for p in sequence] 56 | pred_str = u''.join([int_to_char[j] for j in predicted_indices]) 57 | outputs.append(pred_str) 58 | 59 | if char_embedding: 60 | indices = to_predict[i] 61 | else: 62 | indices = np.where(to_predict[i:i+1, :, :] == True)[2] 63 | inp = u''.join([int_to_char[j] for j in indices]) 64 | inputs.append(inp) 65 | 66 | idx = 0 67 | counters = {} 68 | 69 | for input_str, output_str in zip(inputs, outputs): 70 | if pad in output_str: 71 | output_str2 = align_output_to_input(input_str, output_str, 72 | empty_char=pad) 73 | else: 74 | output_str2 = output_str 75 | for i, (inp, outp) in enumerate(zip(input_str, output_str2)): 76 | if idx + i not in counters: 77 | counters[idx+i] = Counter() 78 | counters[idx+i][outp] += 1 79 | 80 | idx += 1 81 | 82 | agg_out = [] 83 | for idx in sorted(counters.keys()): # dict order is not the text order 84 | agg_out.append(counters[idx].most_common(1)[0][0]) 85 | 86 | corrected_text = u''.join(agg_out) 87 | corrected_text = corrected_text.replace(pad, u'') 88 | 89 | out_file = out_file_name(out_dir, text.name) 90 | with codecs.open(out_file, 'wb', encoding='utf-8') as f: 91 | f.write(corrected_text) 92 | 93 | 94 | if __name__ == '__main__': 95 | lstm_synced_correct_ocr() 96 | -------------------------------------------------------------------------------- /ochre/vudnc2ocr_and_gs.py: -------------------------------------------------------------------------------- 1 | import click 2 | import codecs 3 | import os 4 | 5 | from lxml import etree 6 | from nlppln.utils import create_dirs 7 | 8 | 9 | @click.command() 10 | @click.argument('in_file', type=click.Path(exists=True)) 11 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path()) 12 | def vudnc2ocr_and_gs(in_file, out_dir): 13 | create_dirs(out_dir) 14 | 15 | ocr_text_complete = [] 16 | ocr_text = [] 17 | gold_standard = [] 18 | punctuation = [] 19 | 20 | # TODO: how to handle headings? They are sentences without a closing 21 | # punctuation mark. 22 | # I think the existing LSTM code does not take into account documents (all 23 | # text is simply concatenated together). 24 | # TODO: fix treatment of " (double quotes). After a starting quote a space 25 | # is inserted.
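# A rough sketch of the FoLiA <w> elements this loop expects (assumed
# structure, not copied from the corpus):
#   <w xml:id="...">
#     <t>gold standard word</t>
#     <t class="ocroutput">ocr word</t>
#     <pos class="LET()"/>   (only present on punctuation tokens)
#   </w>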
26 | context = etree.iterparse(in_file, tag='{http://ilk.uvt.nl/folia}w', 27 | encoding='utf-8') 28 | for action, elem in context: 29 | ocr_word = '' 30 | punc = False 31 | pos_elem = elem.find('{http://ilk.uvt.nl/folia}pos') 32 | if pos_elem is not None: 33 | if pos_elem.get('class').startswith('LET'): 34 | punc = True 35 | for t in elem.iterchildren(tag='{http://ilk.uvt.nl/folia}t'): 36 | if t.get('class') == 'ocroutput': 37 | ocr_word = t.text 38 | else: 39 | gs_word = t.text 40 | 41 | if ocr_word != '': 42 | ocr_text.append(ocr_word) 43 | ocr_text_complete.append(ocr_word) 44 | gold_standard.append(gs_word) 45 | punctuation.append(punc) 46 | 47 | result = [] 48 | for i in range(len(ocr_text_complete)): 49 | #print ocr_text_complete[i] 50 | #print gold_standard[i] 51 | result.append(gold_standard[i]) 52 | space = True 53 | if punctuation[i]: 54 | #print ocr_text_complete[i-1] 55 | if i+1 < len(ocr_text_complete): 56 | if ocr_text_complete[i+1].strip().startswith(gold_standard[i]): 57 | #print 'No space after this punctuation: {}'.format(gold_standard[i]) 58 | space = False 59 | if punctuation[i+1]: 60 | #print 'No space after this punctuation: '.format(gold_standard[i]) 61 | space = False 62 | elif i+1 < len(ocr_text_complete) and punctuation[i+1]: 63 | #print '!' 64 | if i+2 < len(ocr_text_complete): 65 | if ocr_text_complete[i+2].strip().startswith(gold_standard[i+1]): 66 | #print 'Space after this word: {}'.format(gold_standard[i]) 67 | space = True 68 | else: 69 | #print 'No space after this word: {}'.format(gold_standard[i]) 70 | space = False 71 | else: 72 | space = False 73 | if space: 74 | result.append(' ') 75 | 76 | fname = os.path.basename(in_file) 77 | fname = fname.replace('.folia.xml', '.{}.txt') 78 | fname = os.path.join(out_dir, fname) 79 | with codecs.open(fname.format('gs'), 'wb', encoding='utf-8') as f: 80 | gs = u''.join(result).strip() 81 | f.write(gs) 82 | 83 | with codecs.open(fname.format('ocr'), 'wb', encoding='utf-8') as f: 84 | ocr = u' '.join(ocr_text) 85 | f.write(ocr) 86 | 87 | 88 | if __name__ == '__main__': 89 | vudnc2ocr_and_gs() 90 | -------------------------------------------------------------------------------- /tests/test_remove_title_page.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | from click.testing import CliRunner 5 | 6 | # FIXME: import correct methods for testing 7 | from ochre.remove_title_page import remove_title_page 8 | 9 | 10 | # Documentation about testing click commands: http://click.pocoo.org/5/testing/ 11 | def test_remove_title_page_single_line(): 12 | runner = CliRunner() 13 | with runner.isolated_filesystem(): 14 | os.makedirs('in') 15 | os.makedirs('out') 16 | 17 | with open('in/test-without.txt', 'w') as f: 18 | content_without = 'Text starts here.\n' \ 19 | 'Second line.\n' 20 | f.write(content_without) 21 | 22 | with open('in/test-with.txt', 'w') as f: 23 | content = 'This is the title page\n' \ 24 | 'Text starts here.\n' \ 25 | 'Second line.\n' 26 | f.write(content) 27 | 28 | result = runner.invoke(remove_title_page, 29 | ['in/test-without.txt', 'in/test-with.txt', 30 | '--out_dir', 'out']) 31 | 32 | assert result.exit_code == 0 33 | 34 | assert os.path.exists('out/test-with.txt') 35 | 36 | with open('out/test-with.txt') as f: 37 | c = f.read() 38 | 39 | assert c == content_without 40 | 41 | 42 | def test_remove_title_page_multiple_lines(): 43 | runner = CliRunner() 44 | with runner.isolated_filesystem(): 45 | os.makedirs('in') 46 | 
os.makedirs('out') 47 | 48 | with open('in/test-without.txt', 'w') as f: 49 | content_without = 'Text starts here.\n' \ 50 | 'Second line.\n' 51 | f.write(content_without) 52 | 53 | with open('in/test-with.txt', 'w') as f: 54 | content = 'This is the title page 1.\n' \ 55 | 'This is the title page 2.\n' \ 56 | 'Text starts here.\n' \ 57 | 'Second line.\n' 58 | f.write(content) 59 | 60 | result = runner.invoke(remove_title_page, 61 | ['in/test-without.txt', 'in/test-with.txt', 62 | '--out_dir', 'out']) 63 | 64 | assert result.exit_code == 0 65 | 66 | assert os.path.exists('out/test-with.txt') 67 | 68 | with open('out/test-with.txt') as f: 69 | c = f.read() 70 | 71 | assert c == content_without 72 | 73 | 74 | def test_remove_title_page_no_lines(): 75 | runner = CliRunner() 76 | with runner.isolated_filesystem(): 77 | os.makedirs('in') 78 | os.makedirs('out') 79 | 80 | with open('in/test-without.txt', 'w') as f: 81 | content_without = 'Text starts here.\n' \ 82 | 'Second line.\n' 83 | f.write(content_without) 84 | 85 | with open('in/test-with.txt', 'w') as f: 86 | content = 'Text starts here.\n' \ 87 | 'Second line.\n' 88 | f.write(content) 89 | 90 | result = runner.invoke(remove_title_page, 91 | ['in/test-without.txt', 'in/test-with.txt', 92 | '--out_dir', 'out']) 93 | 94 | assert result.exit_code == 0 95 | 96 | assert os.path.exists('out/test-with.txt') 97 | 98 | with open('out/test-with.txt') as f: 99 | c = f.read() 100 | 101 | assert c == content_without 102 | -------------------------------------------------------------------------------- /notebooks/sac-preprocess-workflow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import os\n", 20 | "import nlppln\n", 21 | "import ochre\n", 22 | "\n", 23 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n", 24 | " wf.load(steps_dir=ochre.cwl_path())\n", 25 | " wf.load(step_file='https://raw.githubusercontent.com/nlppln/ocrevaluation-docker/master/ocrevaluation.cwl')\n", 26 | " \n", 27 | " print wf.list_steps()\n", 28 | " \n", 29 | " archive = wf.add_input(data_file='File')\n", 30 | " ocr_dir_name = wf.add_input(ocr_dir_name='string', default='ocr')\n", 31 | " gs_dir_name = wf.add_input(gs_dir_name='string', default='gs')\n", 32 | " align_dir_name = wf.add_input(align_dir_name='string', default='align')\n", 33 | " de_dir_name = wf.add_input(de_dir_name='string', default='de')\n", 34 | " fr_dir_name = wf.add_input(fr_dir_name='string', default='fr')\n", 35 | " \n", 36 | " data_dir = wf.tar(in_file=archive)\n", 37 | " gs_de, gs_fr, ocr_de, ocr_fr = wf.sac2gs_and_ocr(in_dir=data_dir)\n", 38 | " \n", 39 | " # create alignments\n", 40 | " alignments_de, changes_de, metadata_de = wf.align_texts_wf(gs=gs_de, ocr=ocr_de)\n", 41 | " alignments_fr, changes_fr, metadata_fr = wf.align_texts_wf(gs=gs_fr, ocr=ocr_fr)\n", 42 | " \n", 43 | " # save files to correct dirs\n", 44 | " de_dir = wf.mkdir(dir_name=de_dir_name)\n", 45 | " fr_dir = wf.mkdir(dir_name=fr_dir_name)\n", 46 | " \n", 47 | " ocr_de_dir = wf.save_files_to_dir(dir_name=ocr_dir_name, in_files=ocr_de)\n", 48 | " ocr_de = wf.save_dir_to_subdir(inner_dir=ocr_de_dir, outer_dir=de_dir)\n", 49 | " gs_de_dir 
= wf.save_files_to_dir(dir_name=gs_dir_name, in_files=gs_de)\n", 50 | " gs_de = wf.save_dir_to_subdir(inner_dir=gs_de_dir, outer_dir=de_dir)\n", 51 | " \n", 52 | " ocr_fr_dir = wf.save_files_to_dir(dir_name=ocr_dir_name, in_files=ocr_fr)\n", 53 | " ocr_fr = wf.save_dir_to_subdir(inner_dir=ocr_fr_dir, outer_dir=fr_dir)\n", 54 | " gs_fr_dir = wf.save_files_to_dir(dir_name=gs_dir_name, in_files=gs_fr)\n", 55 | " gs_fr = wf.save_dir_to_subdir(inner_dir=gs_fr_dir, outer_dir=fr_dir)\n", 56 | " \n", 57 | " align_de_dir = wf.save_files_to_dir(dir_name=align_dir_name, in_files=alignments_de)\n", 58 | " align_de = wf.save_dir_to_subdir(inner_dir=align_de_dir, outer_dir=de_dir)\n", 59 | " \n", 60 | " align_fr_dir = wf.save_files_to_dir(dir_name=align_dir_name, in_files=alignments_fr)\n", 61 | " align_fr = wf.save_dir_to_subdir(inner_dir=align_fr_dir, outer_dir=fr_dir)\n", 62 | " \n", 63 | " wf.add_outputs(ocr_de=ocr_de)\n", 64 | " wf.add_outputs(gs_de=gs_de)\n", 65 | " wf.add_outputs(align_de=align_de)\n", 66 | " \n", 67 | " wf.add_outputs(ocr_fr=ocr_fr)\n", 68 | " wf.add_outputs(gs_fr=gs_fr)\n", 69 | " wf.add_outputs(align_fr=align_fr)\n", 70 | " \n", 71 | " wf.save(os.path.join(ochre.cwl_path(), 'sac-preprocess.cwl'), pack=True)" 72 | ] 73 | } 74 | ], 75 | "metadata": { 76 | "language_info": { 77 | "name": "python", 78 | "pygments_lexer": "ipython3" 79 | } 80 | }, 81 | "nbformat": 4, 82 | "nbformat_minor": 2 83 | } 84 | -------------------------------------------------------------------------------- /ochre/rmgarbage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """Implementation of rmgarbage. 4 | 5 | As described in the paper: 6 | 7 | Taghva, K., Nartker, T., Condit, A. and Borsack, J., 2001. Automatic 8 | removal of "garbage strings" in OCR text: An implementation. In The 5th 9 | World Multi-Conference on Systemics, Cybernetics and Informatics. 
10 | """ 11 | import click 12 | import codecs 13 | import os 14 | import pandas as pd 15 | 16 | from string import punctuation 17 | 18 | from nlppln.utils import create_dirs, out_file_name 19 | 20 | 21 | def get_rmgarbage_errors(word): 22 | errors = [] 23 | if rmgarbage_long(word): 24 | errors.append('L') 25 | if rmgarbage_alphanumeric(word): 26 | errors.append('A') 27 | if rmgarbage_row(word): 28 | errors.append('R') 29 | if rmgarbage_vowels(word): 30 | errors.append('V') 31 | if rmgarbage_punctuation(word): 32 | errors.append('P') 33 | if rmgarbage_case(word): 34 | errors.append('C') 35 | return errors 36 | 37 | 38 | def rmgarbage_long(string, threshold=40): 39 | if len(string) > threshold: 40 | return True 41 | return False 42 | 43 | 44 | def rmgarbage_alphanumeric(string): 45 | alphanumeric_chars = sum(c.isalnum() for c in string) 46 | if len(string) > 2 and (alphanumeric_chars+0.0)/len(string) < 0.5: 47 | return True 48 | return False 49 | 50 | 51 | def rmgarbage_row(string, rep=4): 52 | for c in string: 53 | if c.isalnum(): 54 | if c * rep in string: 55 | return True 56 | return False 57 | 58 | 59 | def rmgarbage_vowels(string): 60 | string = string.lower() 61 | if len(string) > 2 and string.isalpha(): 62 | vowels = sum(c in u'aáâàåãäeéèëêuúûùüiíîìïoóôòøõö' for c in string) 63 | consonants = len(string) - vowels 64 | 65 | low = min(vowels, consonants) 66 | high = max(vowels, consonants) 67 | 68 | if low/(high+0.0) <= 0.1: 69 | return True 70 | return False 71 | 72 | 73 | def rmgarbage_punctuation(string): 74 | string = string[1:len(string)-1] 75 | 76 | punctuation_marks = set() 77 | for c in string: 78 | if c in punctuation: 79 | punctuation_marks.add(c) 80 | 81 | if len(punctuation_marks) > 1: 82 | return True 83 | return False 84 | 85 | 86 | def rmgarbage_case(string): 87 | if string[0].islower() and string[len(string)-1].islower(): 88 | for c in string: 89 | if c.isupper(): 90 | return True 91 | return False 92 | 93 | 94 | @click.command() 95 | @click.argument('in_file', type=click.File(encoding='utf-8')) 96 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path()) 97 | def rmgarbage(in_file, out_dir): 98 | create_dirs(out_dir) 99 | 100 | text = in_file.read() 101 | words = text.split() 102 | 103 | doc_id = os.path.basename(in_file.name).split('.')[0] 104 | 105 | result = [] 106 | removed = [] 107 | 108 | for word in words: 109 | errors = get_rmgarbage_errors(word) 110 | 111 | if len(errors) == 0: 112 | result.append(word) 113 | else: 114 | removed.append({'word': word, 115 | 'errors': u''.join(errors), 116 | 'doc_id': doc_id}) 117 | 118 | out_file = out_file_name(out_dir, in_file.name) 119 | with codecs.open(out_file, 'wb', encoding='utf-8') as f: 120 | f.write(u' '.join(result)) 121 | 122 | metadata_out = pd.DataFrame(removed) 123 | fname = '{}-rmgarbage-metadata.csv'.format(doc_id) 124 | out_file = out_file_name(out_dir, fname) 125 | metadata_out.to_csv(out_file, encoding='utf-8') 126 | 127 | 128 | if __name__ == '__main__': 129 | rmgarbage() 130 | -------------------------------------------------------------------------------- /ochre/keras_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import glob2 4 | 5 | from keras.models import Sequential 6 | from keras.layers import Dense 7 | from keras.layers import Dropout 8 | from keras.layers import LSTM 9 | from keras.layers import TimeDistributed 10 | from keras.layers import Bidirectional 11 | from keras.layers import RepeatVector 12 | from 
keras.layers import Embedding 13 | from keras.callbacks import ModelCheckpoint 14 | 15 | 16 | def initialize_model(n, dropout, seq_length, chars, output_size, layers, 17 | loss='categorical_crossentropy', optimizer='adam'): 18 | model = Sequential() 19 | model.add(LSTM(n, input_shape=(seq_length, len(chars)), 20 | return_sequences=True)) 21 | model.add(Dropout(dropout)) 22 | 23 | for _ in range(layers-1): 24 | model.add(LSTM(n, return_sequences=True)) 25 | model.add(Dropout(dropout)) 26 | 27 | model.add(TimeDistributed(Dense(len(chars), activation='softmax'))) 28 | 29 | model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy']) 30 | 31 | return model 32 | 33 | 34 | def initialize_model_bidirectional(n, dropout, seq_length, chars, output_size, 35 | layers, loss='categorical_crossentropy', 36 | optimizer='adam'): 37 | model = Sequential() 38 | model.add(Bidirectional(LSTM(n, return_sequences=True), 39 | input_shape=(seq_length, len(chars)))) 40 | model.add(Dropout(dropout)) 41 | 42 | for _ in range(layers-1): 43 | model.add(Bidirectional(LSTM(n, return_sequences=True))) 44 | model.add(Dropout(dropout)) 45 | 46 | model.add(TimeDistributed(Dense(len(chars), activation='softmax'))) 47 | 48 | model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy']) 49 | 50 | return model 51 | 52 | 53 | def initialize_model_seq2seq(n, dropout, seq_length, 54 | output_size, layers, char_embedding_size=0, 55 | loss='categorical_crossentropy', optimizer='adam', 56 | metrics=['accuracy']): 57 | model = Sequential() 58 | # encoder 59 | if char_embedding_size: 60 | n_embed = char_embedding_size 61 | model.add(Embedding(output_size, n_embed, input_length=seq_length)) 62 | model.add(LSTM(n)) 63 | else: 64 | model.add(LSTM(n, input_shape=(seq_length, output_size))) 65 | # For the decoder's input, we repeat the encoded input for each time step 66 | model.add(RepeatVector(seq_length)) 67 | # The decoder RNN could be multiple layers stacked or a single layer 68 | for _ in range(layers-1): 69 | model.add(LSTM(n, return_sequences=True)) 70 | 71 | # For each of step of the output sequence, decide which character should be 72 | # chosen 73 | model.add(TimeDistributed(Dense(output_size, activation='softmax'))) 74 | model.compile(loss=loss, optimizer=optimizer, metrics=metrics) 75 | 76 | return model 77 | 78 | 79 | def load_weights(model, weights_dir, loss='categorical_crossentropy', 80 | optimizer='adam'): 81 | epoch = 0 82 | weight_files = glob2.glob('{}{}*.hdf5'.format(weights_dir, os.sep)) 83 | if weight_files != []: 84 | fname = sorted(weight_files)[0] 85 | print('Loading weights from {}'.format(fname)) 86 | 87 | model.load_weights(fname) 88 | model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy']) 89 | 90 | m = re.match(r'.+-(\d\d).hdf5', fname) 91 | if m: 92 | epoch = int(m.group(1)) 93 | epoch += 1 94 | 95 | return epoch, model 96 | 97 | 98 | def add_checkpoint(weights_dir): 99 | filepath = os.path.join(weights_dir, '{loss:.4f}-{epoch:02d}.hdf5') 100 | checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, 101 | save_best_only=True, mode='min') 102 | return checkpoint 103 | -------------------------------------------------------------------------------- /notebooks/preprocess-dbnl_ocr.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | 
"cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import os\n", 20 | "import nlppln\n", 21 | "import ochre\n", 22 | "\n", 23 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n", 24 | " wf.load(steps_dir=ochre.cwl_path())\n", 25 | " #print(wf.list_steps())\n", 26 | " \n", 27 | " txt_dir = wf.add_input(txt_dir='Directory')\n", 28 | " tei_dir = wf.add_input(tei_dir='Directory')\n", 29 | " repl = wf.add_input(replacement='string', default='space')\n", 30 | " tmp_ocr_dir_name = wf.add_input(tmp_ocr_dir_name='string', default='tmp_ocr')\n", 31 | " tmp_gs_dir_name = wf.add_input(tmp_gs_dir_name='string', default='tmp_gs')\n", 32 | " convert = wf.add_input(convert='boolean', default=True)\n", 33 | " \n", 34 | " tei_files = wf.ls(in_dir=tei_dir)\n", 35 | " unnormalized_gs_files = wf.tei2txt(tei_file=tei_files, scatter='tei_file', scatter_method='dotproduct')\n", 36 | " normalized_gs_files = wf.remove_newlines(in_file=unnormalized_gs_files, replacement=repl, \n", 37 | " scatter='in_file', scatter_method='dotproduct')\n", 38 | " tmp_gs_dir = wf.save_files_to_dir(dir_name=tmp_gs_dir_name, in_files=normalized_gs_files)\n", 39 | " gs_without_empty = wf.delete_empty_files(in_dir=tmp_gs_dir)\n", 40 | " \n", 41 | " non_utf8_ocr_files = wf.ls(in_dir=txt_dir)\n", 42 | " unnormalized_ocr_files = wf.check_utf8(in_files=non_utf8_ocr_files, convert=convert)\n", 43 | " normalized_ocr_files = wf.remove_newlines(in_file=unnormalized_ocr_files, replacement=repl, \n", 44 | " scatter='in_file', scatter_method='dotproduct')\n", 45 | " tmp_ocr_dir = wf.save_files_to_dir(dir_name=tmp_ocr_dir_name, in_files=normalized_ocr_files)\n", 46 | " ocr_without_empty = wf.delete_empty_files(in_dir=tmp_ocr_dir)\n", 47 | " \n", 48 | " tmp_gs_dir = wf.save_files_to_dir(dir_name=tmp_gs_dir_name, in_files=gs_without_empty)\n", 49 | " tmp_ocr_dir = wf.save_files_to_dir(dir_name=tmp_ocr_dir_name, in_files=ocr_without_empty)\n", 50 | "\n", 51 | " gs, ocr = wf.match_ocr_and_gs(gs_dir=tmp_gs_dir, ocr_dir=tmp_ocr_dir)\n", 52 | "\n", 53 | " wf.add_outputs(gs=gs)\n", 54 | " wf.add_outputs(ocr=ocr)\n", 55 | " \n", 56 | " wf.save(os.path.join(ochre.cwl_path(),'dbnl_ocr2ocr_and_gs.cwl'), mode='pack')" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "import logging\n", 66 | "logging.basicConfig(format=\"%(asctime)s [%(process)d] %(levelname)-8s \"\n", 67 | " \"%(name)s,%(lineno)s\\t%(message)s\")\n", 68 | "logging.getLogger().setLevel('DEBUG')" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "import os\n", 78 | "import nlppln\n", 79 | "import ochre\n", 80 | "\n", 81 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n", 82 | " wf.load(steps_dir=ochre.cwl_path())\n", 83 | " \n", 84 | " in_dir = wf.add_input(in_dir='Directory')\n", 85 | " \n", 86 | " in_files = wf.ls(in_dir=in_dir)\n", 87 | " txt_files = wf.tei2txt(tei_file=in_files, scatter='tei_file')\n", 88 | " out_files = wf.delete_empty_files(in_files=txt_files)\n", 89 | " \n", 90 | " wf.add_outputs(out_files=out_files)\n", 91 | " \n", 92 | " wf.save(os.path.join(ochre.cwl_path(),'tei2txt_dir.cwl'), mode='wd')" 93 | ] 94 | } 95 | ], 96 | "metadata": { 97 | "language_info": { 98 | "name": "python", 99 | "pygments_lexer": "ipython3" 100 | } 101 | }, 102 | "nbformat": 4, 103 | 
"nbformat_minor": 2 104 | } 105 | -------------------------------------------------------------------------------- /lstm_synced.py: -------------------------------------------------------------------------------- 1 | from keras.callbacks import ModelCheckpoint 2 | 3 | from ochre.utils import create_training_data, read_texts, \ 4 | get_char_to_int 5 | from ochre.keras_utils import initialize_model, load_weights, \ 6 | initialize_model_bidirectional, initialize_model_seq2seq 7 | 8 | import click 9 | import os 10 | import json 11 | import codecs 12 | 13 | 14 | @click.command() 15 | @click.argument('datasets', type=click.File()) 16 | @click.argument('data_dir', type=click.Path(exists=True)) 17 | @click.option('--weights_dir', '-w', default=os.getcwd(), type=click.Path()) 18 | def train_lstm(datasets, data_dir, weights_dir): 19 | # lees data in en maak character mappings 20 | # genereer trainings data 21 | seq_length = 25 22 | num_nodes = 256 23 | layers = 2 24 | batch_size = 100 25 | step = 3 # step size used to create data (3 = use every third sequence) 26 | lowercase = True 27 | bidirectional = False 28 | seq2seq = True 29 | 30 | print('Sequence lenght: {}'.format(seq_length)) 31 | print('Number of nodes in hidden layers: {}'.format(num_nodes)) 32 | print('Number of hidden layers: {}'.format(layers)) 33 | print('Batch size: {}'.format(batch_size)) 34 | print('Lowercase data: {}'.format(lowercase)) 35 | print('Bidirectional layers: {}'.format(bidirectional)) 36 | print('Seq2seq: {}'.format(seq2seq)) 37 | 38 | division = json.load(datasets) 39 | 40 | raw_val, gs_val, ocr_val = read_texts(division.get('val'), data_dir) 41 | raw_test, gs_test, ocr_test = read_texts(division.get('test'), data_dir) 42 | raw_train, gs_train, ocr_train = read_texts(division.get('train'), data_dir) 43 | 44 | raw_text = ''.join([raw_val, raw_test, raw_train]) 45 | if lowercase: 46 | raw_text = raw_text.lower() 47 | 48 | #print('Number of texts: {}'.format(len(data_files))) 49 | 50 | chars = sorted(list(set(raw_text))) 51 | chars.append(u'\n') # padding character 52 | char_to_int = get_char_to_int(chars) 53 | 54 | # save charset to file 55 | if lowercase: 56 | fname = 'chars-lower.txt' 57 | else: 58 | fname = 'chars.txt' 59 | chars_file = os.path.join(weights_dir, fname) 60 | with codecs.open(chars_file, 'wb', encoding='utf-8') as f: 61 | f.write(u''.join(chars)) 62 | 63 | n_chars = len(raw_text) 64 | n_vocab = len(chars) 65 | 66 | print('Total Characters: {}'.format(n_chars)) 67 | print('Total Vocab: {}'.format(n_vocab)) 68 | 69 | numTrainSamples, trainDataGen = create_training_data(ocr_train, gs_train, char_to_int, n_vocab, seq_length=seq_length, batch_size=batch_size, lowercase=lowercase, step=step) 70 | numTestSamples, testDataGen = create_training_data(ocr_test, gs_test, char_to_int, n_vocab, seq_length=seq_length, batch_size=batch_size, lowercase=lowercase) 71 | numValSamples, valDataGen = create_training_data(ocr_val, gs_val, char_to_int, n_vocab, seq_length=seq_length, batch_size=batch_size, lowercase=lowercase) 72 | 73 | n_patterns = numTrainSamples 74 | print("Train Patterns: {}".format(n_patterns)) 75 | print("Validation Patterns: {}".format(numValSamples)) 76 | print("Test Patterns: {}".format(numTestSamples)) 77 | print('Total: {}'.format(numTrainSamples+numTestSamples+numValSamples)) 78 | 79 | if bidirectional: 80 | model = initialize_model_bidirectional(num_nodes, 0.5, seq_length, 81 | chars, n_vocab, layers) 82 | elif seq2seq: 83 | model = initialize_model_seq2seq(num_nodes, 0.5, seq_length, 84 | 
n_vocab, layers) 85 | else: 86 | model = initialize_model(num_nodes, 0.5, seq_length, chars, n_vocab, 87 | layers) 88 | epoch, model = load_weights(model, weights_dir) 89 | 90 | # initialize saving of weights 91 | filepath = os.path.join(weights_dir, '{loss:.4f}-{epoch:02d}.hdf5') 92 | checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, 93 | save_best_only=True, mode='min') 94 | callbacks_list = [checkpoint] 95 | 96 | # do training (and save weights) 97 | model.fit_generator(trainDataGen, steps_per_epoch=int(numTrainSamples/batch_size), epochs=40, validation_data=valDataGen, validation_steps=int(numValSamples/batch_size), callbacks=callbacks_list, initial_epoch=epoch) 98 | 99 | 100 | if __name__ == '__main__': 101 | train_lstm() 102 | -------------------------------------------------------------------------------- /notebooks/align-workflow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import nlppln\n", 20 | "import ochre\n", 21 | "import os\n", 22 | "\n", 23 | "working_dir = '~/cwl-working-dir/'\n", 24 | "\n", 25 | "with nlppln.WorkflowGenerator(working_dir=working_dir) as wf:\n", 26 | " wf.load(steps_dir=ochre.cwl_path())\n", 27 | " print(wf.list_steps())\n", 28 | "\n", 29 | " gs_files = wf.add_input(gs='File[]')\n", 30 | " ocr_files = wf.add_input(ocr='File[]')\n", 31 | " merged_metadata_name = wf.add_input(align_m='string', default='merged_metadata.csv')\n", 32 | " merged_changes_name = wf.add_input(align_c='string', default='merged_changes.csv')\n", 33 | " \n", 34 | " changes, metadata = wf.align(file1=ocr_files, file2=gs_files, \n", 35 | " scatter=['file1', 'file2'], scatter_method='dotproduct')\n", 36 | " merged1 = wf.merge_json(in_files=metadata, name=merged_metadata_name)\n", 37 | " merged2 = wf.merge_json(in_files=changes, name=merged_changes_name)\n", 38 | " \n", 39 | " alignments = wf.char_align(gs_text=gs_files, metadata=metadata, ocr_text=ocr_files, \n", 40 | " scatter=['gs_text', 'ocr_text', 'metadata'], scatter_method='dotproduct')\n", 41 | " \n", 42 | " wf.add_outputs(alignments=alignments)\n", 43 | " wf.add_outputs(metadata=merged1)\n", 44 | " wf.add_outputs(changes=merged2)\n", 45 | " wf.save(os.path.join(ochre.cwl_path(), 'align-texts-wf.cwl'), wd=True, relative=False)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# align directory\n", 55 | "import nlppln\n", 56 | "import ochre\n", 57 | "import os\n", 58 | "\n", 59 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n", 60 | " wf.load(steps_dir=ochre.cwl_path())\n", 61 | " print(wf.list_steps())\n", 62 | "\n", 63 | " gs = wf.add_input(gs='Directory')\n", 64 | " ocr = wf.add_input(ocr='Directory')\n", 65 | " align_dir_name = wf.add_input(align_dir_name='string', default='align')\n", 66 | " \n", 67 | " gs_files = wf.ls(in_dir=gs)\n", 68 | " ocr_files = wf.ls(in_dir=ocr)\n", 69 | " \n", 70 | " alignments, changes, metadata = wf.align_texts_wf(gs=gs_files, ocr=ocr_files)\n", 71 | " \n", 72 | " align = wf.save_files_to_dir(dir_name=align_dir_name, in_files=alignments)\n", 73 | " \n", 74 | " 
wf.add_outputs(align=align)\n", 75 | " wf.save(os.path.join(ochre.cwl_path(), 'align-dir-pack.cwl'), pack=True, relative=False)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "# align test files only\n", 85 | "import nlppln\n", 86 | "import ochre\n", 87 | "import os\n", 88 | "\n", 89 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n", 90 | " wf.load(steps_dir=ochre.cwl_path())\n", 91 | " print(wf.list_steps())\n", 92 | "\n", 93 | " gs_dir = wf.add_input(gs_dir='Directory')\n", 94 | " ocr_dir = wf.add_input(ocr_dir='Directory')\n", 95 | " data_div = wf.add_input(data_div='File')\n", 96 | " div_name = wf.add_input(div_name='string')\n", 97 | " align_dir_name = wf.add_input(align_dir_name='string', default='align')\n", 98 | "\n", 99 | " test_gs = wf.select_test_files(datadivision=data_div, name=div_name, in_dir=gs_dir)\n", 100 | " test_ocr = wf.select_test_files(datadivision=data_div, name=div_name, in_dir=ocr_dir)\n", 101 | "\n", 102 | " alignments, changes, metadata = wf.align_texts_wf(gs=test_gs, ocr=test_ocr)\n", 103 | "\n", 104 | " align = wf.save_files_to_dir(dir_name=align_dir_name, in_files=alignments)\n", 105 | "\n", 106 | " wf.add_outputs(align=align)\n", 107 | " wf.save(os.path.join(ochre.cwl_path(), 'align-test-files-pack.cwl'), pack=True, relative=False)" 108 | ] 109 | } 110 | ], 111 | "metadata": { 112 | "language_info": { 113 | "name": "python", 114 | "pygments_lexer": "ipython3" 115 | } 116 | }, 117 | "nbformat": 4, 118 | "nbformat_minor": 2 119 | } 120 | -------------------------------------------------------------------------------- /notebooks/icdar2017-ngrams.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "%matplotlib inline\n", 20 | "\n", 21 | "import numpy as np\n", 22 | "import pandas as pd\n", 23 | "import matplotlib.pyplot as plt" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "import json\n", 33 | "\n", 34 | "from ochre.utils import read_texts\n", 35 | "\n", 36 | "datasets = '/home/jvdzwaan/data/icdar2017st/eng_monograph/datadivision.json'\n", 37 | "data_dir = '/home/jvdzwaan/data/icdar2017st/eng_monograph/aligned/'\n", 38 | "\n", 39 | "with open(datasets) as d:\n", 40 | " division = json.load(d)\n", 41 | "print(len(division['train']))\n", 42 | "print(len(division['test']))\n", 43 | "print(len(division['val']))" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "from ochre.utils import get_chars, get_sequences\n", 53 | "\n", 54 | "seq_length = 53\n", 55 | "\n", 56 | "raw_val, gs_val, ocr_val = read_texts(division.get('val'), data_dir)\n", 57 | "raw_test, gs_test, ocr_test = read_texts(division.get('test'), data_dir)\n", 58 | "raw_train, gs_train, ocr_train = read_texts(division.get('train'), data_dir)\n", 59 | "\n", 60 | "chars, num_chars, ci = get_chars(raw_val, raw_test, raw_train, False)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | 
"outputs": [], 68 | "source": [ 69 | "%%time\n", 70 | "from collections import defaultdict, Counter\n", 71 | "\n", 72 | "from nlppln.utils import get_files\n", 73 | "\n", 74 | "#c = defaultdict(int)\n", 75 | "c = Counter()\n", 76 | "\n", 77 | "data_dirs = ['/home/jvdzwaan/data/icdar2017st/eng_monograph/aligned/',\n", 78 | " '/home/jvdzwaan/data/icdar2017st/eng_periodical/aligned/',\n", 79 | " '/home/jvdzwaan/data/icdar2017st/fr_monograph/aligned/',\n", 80 | " '/home/jvdzwaan/data/icdar2017st/fr_periodical/aligned/']\n", 81 | "\n", 82 | "for dd in data_dirs:\n", 83 | " in_files = get_files(dd)\n", 84 | " \n", 85 | " raw, gs, ocr = read_texts(in_files, data_dir=None)\n", 86 | " \n", 87 | " for c1, c2 in zip(ocr, gs):\n", 88 | " if c1 != c2 and c1 != ' ' and c2 != ' ':\n", 89 | " c[(c1, c2)] += 1" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "for ch, f in c.most_common(10):\n", 99 | " print(ch, f)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "c[('1', 'I')]" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "%%time\n", 118 | "from ochre.utils import get_sequences\n", 119 | "\n", 120 | "data_dirs = ['/home/jvdzwaan/data/icdar2017st/eng_monograph/aligned/',\n", 121 | " '/home/jvdzwaan/data/icdar2017st/eng_periodical/aligned/',\n", 122 | " '/home/jvdzwaan/data/icdar2017st/fr_monograph/aligned/',\n", 123 | " '/home/jvdzwaan/data/icdar2017st/fr_periodical/aligned/']\n", 124 | "\n", 125 | "c = Counter()\n", 126 | "\n", 127 | "for dd in data_dirs:\n", 128 | " in_files = get_files(dd)\n", 129 | " \n", 130 | " raw, gs, ocr = read_texts(in_files, data_dir=None)\n", 131 | " \n", 132 | " gs_seqs, ocr_seqs = get_sequences(gs, ocr, 3)\n", 133 | " \n", 134 | " for ocr, gs in zip(ocr_seqs, gs_seqs):\n", 135 | " c[(ocr, gs)] += 1" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "for ch, f in c.most_common(20):\n", 145 | " print(ch, f)" 146 | ] 147 | } 148 | ], 149 | "metadata": { 150 | "language_info": { 151 | "name": "python", 152 | "pygments_lexer": "ipython3" 153 | } 154 | }, 155 | "nbformat": 4, 156 | "nbformat_minor": 2 157 | } 158 | -------------------------------------------------------------------------------- /notebooks/2017-baseline-nn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "%matplotlib inline\n", 20 | "\n", 21 | "import numpy as np\n", 22 | "import pandas as pd\n", 23 | "import matplotlib.pyplot as plt" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "import pickle\n", 33 | "\n", 34 | "seq_length = 53\n", 35 | "\n", 36 | "with open('train.pkl', 'rb') as f:\n", 37 | " gs_selected_train, ocr_selected_train = pickle.load(f)\n", 38 | " \n", 39 | "with open('val.pkl', 'rb') as f:\n", 40 | " gs_selected_val, ocr_selected_val = pickle.load(f)\n", 41 
| " \n", 42 | "with open('ci.pkl', 'rb') as f:\n", 43 | " ci = pickle.load(f)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "gs_selected_train[:50]" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "from ochre.datagen import DataGenerator\n", 62 | "\n", 63 | "dg_val = DataGenerator(xData=ocr_selected_val[:10], yData=gs_selected_val[:10], char_to_int=ci,\n", 64 | " seq_length=seq_length, padding_char='\\n', oov_char='@',\n", 65 | " batch_size=10, shuffle=False)\n", 66 | "dg_train = DataGenerator(xData=ocr_selected_train[:50], yData=gs_selected_train[:50], char_to_int=ci,\n", 67 | " seq_length=seq_length, padding_char='\\n', oov_char='@',\n", 68 | " batch_size=10, shuffle=False)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "from keras.models import Sequential\n", 78 | "from keras.layers import Dense\n", 79 | "from keras.layers import Dropout\n", 80 | "from keras.layers import LSTM\n", 81 | "from keras.layers import TimeDistributed\n", 82 | "from keras.layers import Bidirectional\n", 83 | "from keras.layers import RepeatVector\n", 84 | "from keras.layers import Embedding\n", 85 | "from keras.callbacks import ModelCheckpoint\n", 86 | "\n", 87 | "n_nodes = 1000\n", 88 | "dropout = 0.2\n", 89 | "n_embed = 256\n", 90 | "n_vocab = len(ci)\n", 91 | "\n", 92 | "loss='categorical_crossentropy'\n", 93 | "optimizer='adam'\n", 94 | "metrics=['accuracy']\n", 95 | "\n", 96 | "model = Sequential()\n", 97 | "\n", 98 | "# encoder\n", 99 | "\n", 100 | "model.add(Embedding(n_vocab, n_embed, input_length=seq_length))\n", 101 | "model.add(LSTM(n_nodes, input_shape=(seq_length, n_vocab)))\n", 102 | "# For the decoder's input, we repeat the encoded input for each time step\n", 103 | "model.add(RepeatVector(seq_length))\n", 104 | "model.add(LSTM(n_nodes, return_sequences=True))\n", 105 | "\n", 106 | "# For each of step of the output sequence, decide which character should be\n", 107 | "# chosen\n", 108 | "model.add(TimeDistributed(Dense(n_vocab, activation='softmax')))\n", 109 | "model.compile(loss=loss, optimizer=optimizer, metrics=metrics)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "# initialize saving of weights\n", 119 | "#filepath = os.path.join(weights_dir, '{loss:.4f}-{epoch:02d}.hdf5')\n", 120 | "filepath = '{loss:.4f}-{epoch:02d}.hdf5'\n", 121 | "checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1,\n", 122 | " save_best_only=True, mode='min')\n", 123 | "callbacks_list = [checkpoint]\n", 124 | "\n", 125 | "# do training (and save weights)\n", 126 | "model.fit_generator(dg_train, steps_per_epoch=len(dg_train), epochs=10, \n", 127 | " validation_data=dg_val, \n", 128 | " validation_steps=len(dg_val), callbacks=callbacks_list,\n", 129 | " use_multiprocessing=True,\n", 130 | " workers=3)" 131 | ] 132 | } 133 | ], 134 | "metadata": { 135 | "language_info": { 136 | "name": "python", 137 | "pygments_lexer": "ipython3" 138 | } 139 | }, 140 | "nbformat": 4, 141 | "nbformat_minor": 2 142 | } 143 | -------------------------------------------------------------------------------- /notebooks/kb_tss_preprocess.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 
4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "# subworkflow to convert all alto files in the subdirectories of one directory into text files\n", 20 | "import nlppln\n", 21 | "\n", 22 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n", 23 | " wf.load(steps_dir='../ochre/cwl/')\n", 24 | " print(wf.list_steps())\n", 25 | " \n", 26 | " in_dir = wf.add_input(in_dir='Directory')\n", 27 | " \n", 28 | " # inputs with default values\n", 29 | " recursive = wf.add_input(recursive='boolean', default=True)\n", 30 | " endswith = wf.add_input(endswith='string', default='alto.xml')\n", 31 | " element = wf.add_input(element='string[]', default=['SP'])\n", 32 | " in_fmt = wf.add_input(in_fmt='string', default='alto')\n", 33 | " out_fmt = wf.add_input(out_fmt='string', default='text')\n", 34 | " \n", 35 | " in_files = wf.ls(in_dir=in_dir, recursive=recursive, endswith=endswith)\n", 36 | " cleaned_files = wf.remove_xml_elements(element=element, xml_file=in_files, \n", 37 | " scatter='xml_file', scatter_method='dotproduct')\n", 38 | " text_pages = wf.ocr_transform(in_file=cleaned_files, in_fmt=in_fmt, out_fmt=out_fmt,\n", 39 | " scatter='in_file', scatter_method='dotproduct')\n", 40 | " text_files = wf.kb_tss_concat_files(in_files=text_pages)\n", 41 | " \n", 42 | " \n", 43 | " wf.add_outputs(text_files=text_files)\n", 44 | " \n", 45 | " wf.save('../ochre/cwl/kb-tss-preprocess-single-dir.cwl', wd=True)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# Workflow to preprocess all kb_tss data\n", 55 | "import nlppln\n", 56 | "\n", 57 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n", 58 | " wf.load(steps_dir='../ochre/cwl/')\n", 59 | " print(wf.list_steps())\n", 60 | " \n", 61 | " karmac_dir = wf.add_input(karmac_dir='Directory')\n", 62 | " original_dir = wf.add_input(orignal_dir='Directory')\n", 63 | " xcago_dir = wf.add_input(xcago_dir='Directory')\n", 64 | " \n", 65 | " karmac_name = wf.add_input(karmac_name='string', default='Karmac')\n", 66 | " original_name = wf.add_input(original_name='string', default='Origineel')\n", 67 | " xcago_name = wf.add_input(xcago_name='string', default='X-Cago')\n", 68 | " \n", 69 | " karmac_aligned_name = wf.add_input(karmac_aligned_name='string', default='align-Karmac-Origineel')\n", 70 | " xcago_aligned_name = wf.add_input(xcago_aligned_name='string', default='align-X-Cago-Origineel')\n", 71 | " \n", 72 | " karmac_files = wf.kb_tss_preprocess_single_dir(in_dir=karmac_dir)\n", 73 | " original_files = wf.kb_tss_preprocess_single_dir(in_dir=original_dir)\n", 74 | " xcago_files = wf.kb_tss_preprocess_single_dir(in_dir=xcago_dir)\n", 75 | " \n", 76 | " karmac_dir_new = wf.save_files_to_dir(dir_name=karmac_name, in_files=karmac_files)\n", 77 | " original_dir_new = wf.save_files_to_dir(dir_name=original_name, in_files=original_files)\n", 78 | " xcago_dir_new = wf.save_files_to_dir(dir_name=xcago_name, in_files=xcago_files)\n", 79 | " \n", 80 | " karmac_alignments, karmac_changes, karmac_metadata = wf.align_texts_wf(gs=karmac_files, ocr=original_files)\n", 81 | " xcago_alignments, xcago_changes, xcago_metadata = wf.align_texts_wf(gs=xcago_files, 
ocr=original_files)\n", 82 | " \n", 83 | " karmac_align_dir = wf.save_files_to_dir(dir_name=karmac_aligned_name, in_files=karmac_alignments)\n", 84 | " xcago_align_dir = wf.save_files_to_dir(dir_name=xcago_aligned_name, in_files=xcago_alignments)\n", 85 | " \n", 86 | " wf.add_outputs(karmac=karmac_dir_new)\n", 87 | " wf.add_outputs(original=original_dir_new)\n", 88 | " wf.add_outputs(xcago=xcago_dir_new)\n", 89 | " wf.add_outputs(karmac_align=karmac_align_dir)\n", 90 | " wf.add_outputs(xcago_align=xcago_align_dir)\n", 91 | " \n", 92 | " wf.save('../ochre/cwl/kb-tss-preprocess-all.cwl', pack=True)" 93 | ] 94 | } 95 | ], 96 | "metadata": { 97 | "language_info": { 98 | "name": "python", 99 | "pygments_lexer": "ipython3" 100 | } 101 | }, 102 | "nbformat": 4, 103 | "nbformat_minor": 2 104 | } 105 | -------------------------------------------------------------------------------- /notebooks/ocr-evaluation-workflow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import nlppln\n", 20 | "import ochre\n", 21 | "import os\n", 22 | "\n", 23 | "with nlppln.WorkflowGenerator(working_dir='/Users/jvdzwaan/cwl-working-dir/') as wf:\n", 24 | " wf.load(steps_dir=ochre.cwl_path())\n", 25 | " wf.load(step_file='https://raw.githubusercontent.com/nlppln/ocrevaluation-docker/master/ocrevaluation.cwl')\n", 26 | " \n", 27 | " #print wf.list_steps()\n", 28 | "\n", 29 | " gt_file = wf.add_input(gt='File')\n", 30 | " ocr_file = wf.add_input(ocr='File')\n", 31 | " xmx = wf.add_input(xmx='string?')\n", 32 | "\n", 33 | " out_file = wf.ocrevaluation(gt=gt_file, ocr=ocr_file, xmx=xmx)\n", 34 | " character_data, global_data = wf.ocrevaluation_extract(in_file=out_file)\n", 35 | "\n", 36 | " wf.add_outputs(character_data=character_data)\n", 37 | " wf.add_outputs(global_data=global_data)\n", 38 | "\n", 39 | " # can be used as a subworkflow\n", 40 | " wf.save(os.path.join(ochre.cwl_path(), 'ocrevaluation-performance-wf.cwl'), mode='wd')" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "import nlppln\n", 50 | "import ochre\n", 51 | "import os\n", 52 | "\n", 53 | "with nlppln.WorkflowGenerator(working_dir='/Users/jvdzwaan/cwl-working-dir/') as wf:\n", 54 | " wf.load(steps_dir=ochre.cwl_path())\n", 55 | " \n", 56 | " print(wf.list_steps())\n", 57 | "\n", 58 | " gt_dir = wf.add_input(gt='Directory')\n", 59 | " ocr_dir = wf.add_input(ocr='Directory')\n", 60 | " xmx = wf.add_input(xmx='string?')\n", 61 | " performance_file = wf.add_input(out_name='string?', default='performance.csv')\n", 62 | "\n", 63 | " ocr_files = wf.ls(in_dir=ocr_dir)\n", 64 | " gt_files = wf.ls(in_dir=gt_dir)\n", 65 | " \n", 66 | " character_data, global_data = wf.ocrevaluation_performance_wf(gt=gt_files, ocr=ocr_files, xmx=xmx,\n", 67 | " scatter=['gt', 'ocr'], scatter_method='dotproduct')\n", 68 | " \n", 69 | " merged = wf.merge_csv(in_files=global_data, name=performance_file)\n", 70 | "\n", 71 | " wf.add_outputs(performance=merged)\n", 72 | "\n", 73 | " wf.save(os.path.join(ochre.cwl_path(), 'ocrevaluation-performance-wf-pack.cwl'), mode='pack')" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | 
"execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "import nlppln\n", 83 | "import ochre\n", 84 | "import os\n", 85 | "\n", 86 | "with nlppln.WorkflowGenerator(working_dir='/Users/jvdzwaan/cwl-working-dir/') as wf:\n", 87 | " wf.load(steps_dir=ochre.cwl_path())\n", 88 | " wf.load(step_file='https://raw.githubusercontent.com/nlppln/ocrevaluation-docker/master/ocrevaluation.cwl')\n", 89 | " print(wf.list_steps())\n", 90 | "\n", 91 | " ocr_dir = wf.add_input(ocr='Directory')\n", 92 | " gt_dir = wf.add_input(gt='Directory')\n", 93 | " datadiv = wf.add_input(datadivision='File')\n", 94 | " div_name = wf.add_input(div_name='string', default='test')\n", 95 | " gt_dir_name = wf.add_input(gt_dir_name='string', default='gs')\n", 96 | " ocr_dir_name = wf.add_input(ocr_dir_name='string', default='ocr')\n", 97 | " fname = wf.add_input(out_name='string?', default='performance.csv')\n", 98 | "\n", 99 | " ocr_files = wf.select_test_files(datadivision=datadiv, in_dir=ocr_dir, name=div_name)\n", 100 | " gt_files = wf.select_test_files(datadivision=datadiv, in_dir=gt_dir, name=div_name)\n", 101 | "\n", 102 | " character_data, global_data = wf.ocrevaluation_performance_wf(gt=gt_files, ocr=ocr_files, \n", 103 | " scatter=['gt', 'ocr'], \n", 104 | " scatter_method='dotproduct')\n", 105 | " \n", 106 | " merged = wf.merge_csv(in_files=global_data, name=performance_file)\n", 107 | "\n", 108 | " wf.add_outputs(performance=merged)\n", 109 | "\n", 110 | " wf.save(os.path.join(ochre.cwl_path(), 'ocrevaluation-performance-test-files-wf-pack.cwl'), mode='pack')" 111 | ] 112 | } 113 | ], 114 | "metadata": { 115 | "language_info": { 116 | "name": "python", 117 | "pygments_lexer": "ipython3" 118 | } 119 | }, 120 | "nbformat": 4, 121 | "nbformat_minor": 2 122 | } 123 | -------------------------------------------------------------------------------- /ochre/train_seq2seq.py: -------------------------------------------------------------------------------- 1 | from ochre.utils import initialize_model_seq2seq, load_weights, save_charset, \ 2 | create_training_data, read_texts, get_chars, \ 3 | add_checkpoint 4 | 5 | import click 6 | import os 7 | import json 8 | 9 | 10 | @click.command() 11 | @click.argument('datasets', type=click.File()) 12 | @click.argument('data_dir', type=click.Path(exists=True)) 13 | @click.option('--weights_dir', '-w', default=os.getcwd(), type=click.Path()) 14 | def train_lstm(datasets, data_dir, weights_dir): 15 | seq_length = 25 16 | pred_chars = 1 17 | num_nodes = 256 18 | layers = 2 19 | batch_size = 100 20 | step = 3 # step size used to create data (3 = use every third sequence) 21 | lowercase = True 22 | char_embedding_size = 16 23 | pad = u'\n' 24 | 25 | print('Sequence length: {}'.format(seq_length)) 26 | print('Predict characters: {}'.format(pred_chars)) 27 | print('Use character embedding: {}'.format(bool(char_embedding_size))) 28 | if char_embedding_size: 29 | print('Char embedding size: {}'.format(char_embedding_size)) 30 | print('Number of nodes in hidden layers: {}'.format(num_nodes)) 31 | print('Number of hidden layers: {}'.format(layers)) 32 | print('Batch size: {}'.format(batch_size)) 33 | print('Lowercase data: {}'.format(lowercase)) 34 | 35 | div = json.load(datasets) 36 | 37 | raw_val, gs_val, ocr_val = read_texts(div.get('val'), data_dir) 38 | raw_test, gs_test, ocr_test = read_texts(div.get('test'), data_dir) 39 | raw_train, gs_train, ocr_train = read_texts(div.get('train'), data_dir) 40 | 41 | chars, n_vocab, char_to_int = get_chars(raw_val, 
raw_test, raw_train, 42 | lowercase, padding_char=pad) 43 | # save charset to file 44 | save_charset(weights_dir, chars, lowercase) 45 | 46 | print('Total Vocab: {}'.format(n_vocab)) 47 | 48 | embed = bool(char_embedding_size) 49 | nTrainSamples, trainData = create_training_data(ocr_train, 50 | gs_train, 51 | char_to_int, 52 | n_vocab, 53 | seq_length=seq_length, 54 | batch_size=batch_size, 55 | lowercase=lowercase, 56 | step=step, 57 | predict_chars=pred_chars, 58 | char_embedding=embed) 59 | nTestSamples, testData = create_training_data(ocr_test, 60 | gs_test, 61 | char_to_int, 62 | n_vocab, 63 | seq_length=seq_length, 64 | batch_size=batch_size, 65 | lowercase=lowercase, 66 | predict_chars=pred_chars, 67 | char_embedding=embed) 68 | nValSamples, valData = create_training_data(ocr_val, 69 | gs_val, 70 | char_to_int, 71 | n_vocab, 72 | seq_length=seq_length, 73 | batch_size=batch_size, 74 | lowercase=lowercase, 75 | predict_chars=pred_chars, 76 | char_embedding=embed) 77 | 78 | n_patterns = nTrainSamples 79 | print("Train Patterns: {}".format(n_patterns)) 80 | print("Validation Patterns: {}".format(nValSamples)) 81 | print("Test Patterns: {}".format(nTestSamples)) 82 | print('Total: {}'.format(nTrainSamples+nTestSamples+nValSamples)) 83 | 84 | model = initialize_model_seq2seq(num_nodes, 0.5, seq_length, pred_chars, 85 | n_vocab, layers, char_embedding_size) 86 | epoch, model = load_weights(model, weights_dir) 87 | callbacks_list = [add_checkpoint(weights_dir)] 88 | 89 | # do training (and save weights) 90 | model.fit_generator(trainData, 91 | steps_per_epoch=int(nTrainSamples/batch_size), 92 | epochs=40, 93 | validation_data=valData, 94 | validation_steps=int(nValSamples/batch_size), 95 | callbacks=callbacks_list, 96 | initial_epoch=epoch) 97 | 98 | 99 | if __name__ == '__main__': 100 | train_lstm() 101 | -------------------------------------------------------------------------------- /notebooks/Compare-performance-of-pred.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "%matplotlib inline\n", 20 | "\n", 21 | "import numpy as np\n", 22 | "import pandas as pd\n", 23 | "import matplotlib.pyplot as plt" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "#performance_4 = pd.read_csv('/home/jvdzwaan/data/kb-ocr/performance_A8P1_model-50000_2019-05-16.csv', index_col=0)\n", 33 | "performance_4 = pd.read_csv('/home/jvdzwaan/data/kb-ocr/performance_A8P1_model-50000_2019-05-19.csv', index_col=0)\n", 34 | "performance_4.columns = ['CER_40000', 'WER_40000', 'WER (order independent)_40000']\n", 35 | "#performance_4" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "performance_12 = pd.read_csv('/home/jvdzwaan/data/kb-ocr/performance_A8P1_model-120000_2019-05-22.csv', index_col=0)\n", 45 | "performance_12.columns = ['CER_120000', 'WER_120000', 'WER (order independent)_120000']\n", 46 | "#performance_12" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "performance_fr11 = 
pd.read_csv('/home/jvdzwaan/data/kb-ocr/FR11_match-gs_vs_GT_test-files_2019-05-22.csv', index_col=0)\n", 56 | "performance_fr11.columns = ['CER_FR11', 'WER_FR11', 'WER (order independent)_FR11']\n", 57 | "#performance_fr11" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "performance_old = pd.read_csv('/home/jvdzwaan/data/kb-ocr/text_aligned_blocks-match_gs-20190514.csv', index_col=0)\n", 67 | "#performance_old" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "performance_old.columns = ['CER_old', 'WER_old', 'WER (order independent)_old']\n", 77 | "#performance_old" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "result = performance_12.join(performance_4).join(performance_old).join(performance_fr11)\n", 87 | "#result" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "result[['CER_old', 'CER_FR11', 'CER_40000', 'CER_120000']].plot(kind='bar', figsize=(15,7))" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "result[['WER_old', 'WER_FR11', 'WER_40000', 'WER_120000']].plot(kind='bar', figsize=(15,7))" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "result[['WER (order independent)_old', 'WER (order independent)_FR11', 'WER (order independent)_40000', 'WER (order independent)_120000']].plot(kind='bar', figsize=(15,7))" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "print('CER old:')\n", 124 | "print(result['CER_old'].describe())" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "print('CER model:')\n", 134 | "print(result['CER_120000'].describe())" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "print('Decrease:')\n", 144 | "print((result['CER_old'] - result['CER_120000']).describe())" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "print('CER FR11:')\n", 154 | "print(result['CER_FR11'].describe())" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "print('Decrease:')\n", 164 | "print((result['CER_old'] - result['CER_FR11']).describe())" 165 | ] 166 | } 167 | ], 168 | "metadata": { 169 | "language_info": { 170 | "name": "python", 171 | "pygments_lexer": "ipython3" 172 | } 173 | }, 174 | "nbformat": 4, 175 | "nbformat_minor": 2 176 | } 177 | -------------------------------------------------------------------------------- /notebooks/character_counts.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import pandas as pd\n", 20 | "\n", 21 | "ocr = pd.read_csv('/Users/janneke/Documents/data/ocr/char_counts_ocr.csv', index_col=0, encoding='utf-8')\n", 22 | "print(ocr.shape)\n", 23 | "ocr" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "gs = pd.read_csv('/Users/janneke/Documents/data/ocr/char_counts_gs.csv', index_col=0, encoding='utf-8')\n", 33 | "print(gs.shape)\n", 34 | "gs" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "ocr_unique = set(ocr.columns) - set(gs.columns)\n", 44 | "print(' '.join(ocr_unique))\n", 45 | "# characters that are in ocr, but not in gs" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "ocr[list(ocr_unique)].sum()" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "gs_unique = set(gs.columns) - set(ocr.columns)\n", 64 | "print(' '.join(gs_unique))\n", 65 | "# characters that are in gs, but not in ocr" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "gs[list(gs_unique)].sum()" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "for c in gs_unique:\n", 84 | "    print(c, repr(c), gs[c].sum())" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "print(ocr.columns)\n", 94 | "char_mapping = {}\n", 95 | "multiple_chars = []\n", 96 | "for c in list(ocr.columns) + list(gs.columns):\n", 97 | "    l = len(c.encode('utf-8'))\n", 98 | "    if l > 1:\n", 99 | "        print(repr(c))\n", 100 | "        multiple_chars.append(c)\n", 101 | "    else:\n", 102 | "        char_mapping[ord(c)] = c\n", 103 | "        #print(ord(c))\n", 104 | "print(len(multiple_chars))\n", 105 | "#print(char_mapping)\n", 106 | "\n", 107 | "for c in set(multiple_chars):\n", 108 | "    for i in range(1, 256):\n", 109 | "        if i not in char_mapping.keys():\n", 110 | "            char_mapping[i] = c\n", 111 | "            break\n", 112 | "print(char_mapping)\n", 113 | "print(len(char_mapping))\n", 114 | "print(sorted(char_mapping.keys()))\n", 115 | "print(u' '.join(multiple_chars))" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "mapping = {}\n", 125 | "for k, v in char_mapping.items():\n", 126 | "    mapping[v] = k\n", 127 | "print(mapping)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "print(mapping[u'ó'])\n", 137 | "print(char_mapping[8])" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "import json\n", 147 | "import codecs\n", 148 | "\n", 149 | "with codecs.open('../char_mapping.json', 'w', encoding='utf-8') as f:\n", 150 | "    json.dump(mapping, f, indent=2)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "extended_ascii = "".join([chr(i) for i in range(256)])\n", 160 | "print(extended_ascii)\n", 161 | "for i in range(256):\n", 162 | "    if len(chr(i)) > 1:\n", 163 | "        print(i)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "char_mapping = {}\n", 173 | "multiple_chars = []\n", 174 | "\n", 175 | "for c in list(ocr.columns) + list(gs.columns):\n", 176 | "    try:\n", 177 | "        char = c.encode('latin1')\n", 178 | "        char_mapping[ord(char)] = c\n", 179 | "    except UnicodeEncodeError:\n", 180 | "        multiple_chars.append(c)\n", 181 | "    \n", 182 | "for c in set(multiple_chars):\n", 183 | "    for i in range(1, 256):\n", 184 | "        if i not in char_mapping.keys():\n", 185 | "            char_mapping[i] = c\n", 186 | "            break\n", 187 | "\n", 188 | "print(char_mapping)\n", 189 | "mapping = {}\n", 190 | "for k, v in char_mapping.items():\n", 191 | "    mapping[v] = k\n", 192 | "print(mapping)\n", 193 | "with open('../char_mapping2.json', 'w') as f:\n", 194 | "    json.dump(mapping, f, indent=2)" 195 | ] 196 | } 197 | ], 198 | "metadata": { 199 | "language_info": { 200 | "name": "python", 201 | "pygments_lexer": "ipython3" 202 | } 203 | }, 204 | "nbformat": 4, 205 | "nbformat_minor": 1 206 | } 207 | -------------------------------------------------------------------------------- /notebooks/word-mapping-workflow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "# to be used as subworkflow\n", 20 | "import nlppln\n", 21 | "import ochre\n", 22 | "import os\n", 23 | "\n", 24 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n", 25 | "    wf.load(steps_dir=ochre.cwl_path())\n", 26 | "    print(wf.list_steps())\n", 27 | "    \n", 28 | "    wf.set_documentation('This workflow is meant to be used as a subworkflow.')\n", 29 | "\n", 30 | "    gs_unnormalized = wf.add_input(gs_files='File[]')\n", 31 | "    ocr_unnormalized = wf.add_input(ocr_files='File[]')\n", 32 | "    language = wf.add_input(language='string')\n", 33 | "    lowercase = wf.add_input(lowercase='boolean?')\n", 34 | "    align_metadata = wf.add_input(align_m='string?')\n", 35 | "    align_changes = wf.add_input(align_c='string?')\n", 36 | "    word_mapping_name = wf.add_input(wm_name='string?')\n", 37 | "\n", 38 | "    gs = wf.normalize_whitespace_punctuation(meta_in=gs_unnormalized, scatter=['meta_in'], scatter_method='dotproduct')\n", 39 | "    ocr = wf.normalize_whitespace_punctuation(meta_in=ocr_unnormalized, scatter=['meta_in'], scatter_method='dotproduct')\n", 40 | "\n", 41 | "    alignments, changes, metadata = wf.align_texts_wf(gs=gs, ocr=ocr, align_c=align_changes, align_m=align_metadata)\n", 42 | "    \n", 43 | "    gs_saf = wf.pattern(in_file=gs, language=language, scatter='in_file', scatter_method='dotproduct')\n", 44 | "\n", 45 | "    mappings = wf.create_word_mappings(alignments=alignments, saf=gs_saf, lowercase=lowercase, \n", 46 | "                                       scatter=['alignments', 'saf'], scatter_method='dotproduct')\n", 47 | "    merged = wf.merge_csv(in_files=mappings, name=word_mapping_name)\n", 48 | "\n", 49 | "    wf.add_outputs(wm_mapping=merged)\n", 50 | "    wf.save(os.path.join(ochre.cwl_path(), 'word-mapping-wf.cwl'), wd=True)" 51 | ] 52 | }, 53 | { 54 | 
"cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "# create word mappings for directory\n", 60 | "import nlppln\n", 61 | "import ochre\n", 62 | "import os\n", 63 | "\n", 64 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n", 65 | " wf.load(steps_dir=ochre.cwl_path())\n", 66 | " print(wf.list_steps())\n", 67 | " \n", 68 | " gs_dir = wf.add_input(gs_dir='Directory')\n", 69 | " ocr_dir = wf.add_input(ocr_dir='Directory')\n", 70 | " language = wf.add_input(language='string')\n", 71 | " lowercase = wf.add_input(lowercase='boolean?')\n", 72 | " align_metadata = wf.add_input(align_m='string?')\n", 73 | " align_changes = wf.add_input(align_c='string?')\n", 74 | " word_mapping_name = wf.add_input(wm_name='string?')\n", 75 | "\n", 76 | " gs_files = wf.ls(in_dir=gs_dir)\n", 77 | " ocr_files = wf.ls(in_dir=ocr_dir)\n", 78 | "\n", 79 | " wm_mapping = wf.word_mapping_wf(gs_files=gs_files, ocr_files=ocr_files, language=language, \n", 80 | " lowercase=lowercase, align_c=align_changes, \n", 81 | " align_m=align_metadata, wm_name=word_mapping_name)\n", 82 | "\n", 83 | " wf.add_outputs(wm_mapping=wm_mapping)\n", 84 | " wf.save(os.path.join(ochre.cwl_path(), 'word-mapping-dir.cwl'), pack=True)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "# create word mappings for test data\n", 94 | "import nlppln\n", 95 | "import ochre\n", 96 | "import os\n", 97 | "\n", 98 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n", 99 | " wf.load(steps_dir=ochre.cwl_path())\n", 100 | " print(wf.list_steps())\n", 101 | "\n", 102 | " gs_dir = wf.add_input(gs_dir='Directory')\n", 103 | " ocr_dir = wf.add_input(ocr_dir='Directory')\n", 104 | " language = wf.add_input(language='string')\n", 105 | " data_div = wf.add_input(data_div='File')\n", 106 | " lowercase = wf.add_input(lowercase='boolean?')\n", 107 | " align_metadata = wf.add_input(align_m='string?')\n", 108 | " align_changes = wf.add_input(align_c='string?')\n", 109 | " word_mapping_name = wf.add_input(wm_name='string?')\n", 110 | "\n", 111 | " test_gs = wf.select_test_files(datadivision=data_div, in_dir=gs_dir)\n", 112 | " test_ocr = wf.select_test_files(datadivision=data_div, in_dir=ocr_dir)\n", 113 | "\n", 114 | " wm_mapping = wf.word_mapping_wf(gs_files=test_gs, ocr_files=test_ocr, language=language, \n", 115 | " lowercase=lowercase, align_c=align_changes, \n", 116 | " align_m=align_metadata, wm_name=word_mapping_name)\n", 117 | "\n", 118 | " wf.add_outputs(wm_mapping=wm_mapping)\n", 119 | " wf.save(os.path.join(ochre.cwl_path(), 'word-mapping-test-files-wf.cwl'), pack=True)" 120 | ] 121 | } 122 | ], 123 | "metadata": { 124 | "language_info": { 125 | "name": "python", 126 | "pygments_lexer": "ipython3" 127 | } 128 | }, 129 | "nbformat": 4, 130 | "nbformat_minor": 2 131 | } 132 | -------------------------------------------------------------------------------- /notebooks/improve-keras-datagen.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "%matplotlib inline\n", 20 | "\n", 21 | 
"import numpy as np\n", 22 | "import pandas as pd\n", 23 | "import matplotlib.pyplot as plt" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "# read the texts (using data division)\n", 33 | "# convert texts to input and output sequences (strings)\n", 34 | "# function for converting input and output sequences to numbers (in data generator)\n", 35 | "import json\n", 36 | "\n", 37 | "from ochre.utils import read_texts\n", 38 | "\n", 39 | "datasets = '/home/jvdzwaan/data/dncvu/datadivision-small.json'\n", 40 | "data_dir = '/home/jvdzwaan/data/dncvu/aligned/'\n", 41 | "\n", 42 | "with open(datasets) as d:\n", 43 | " division = json.load(d)\n", 44 | "print(division)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "raw_val, gs_val, ocr_val = read_texts(division.get('val'), data_dir)\n", 54 | "print(gs_val)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "# input: sequence of n characters from ocr\n", 64 | "# output: sequence of n characters from gs\n", 65 | "\n", 66 | "# we can filter the sequences later if we want\n", 67 | "\n", 68 | "n = 5\n", 69 | "num = 0\n", 70 | "\n", 71 | "def get_sequences(gs, ocr, length):\n", 72 | " gs_ngrams = zip(*[gs[i:] for i in range(length)])\n", 73 | " ocr_ngrams = zip(*[ocr[i:] for i in range(length)])\n", 74 | "\n", 75 | " return [''.join(n) for n in gs_ngrams], [''.join(n) for n in ocr_ngrams]\n", 76 | "\n", 77 | "gs_seqs, ocr_seqs = get_sequences(gs_val, ocr_val, length=n)\n", 78 | "print(len(gs_seqs), len(ocr_seqs))\n", 79 | "# each item in the *_seqs is a sample" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "gs_seqs = gs_seqs[:7]\n", 89 | "ocr_seqs = ocr_seqs[:7]" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "from sklearn.preprocessing import LabelEncoder\n", 99 | "\n", 100 | "le = LabelEncoder()\n", 101 | "le.fit([c for c in raw_val])\n", 102 | "for s in ocr_seqs:\n", 103 | " print(le.fit_transform(''.join(s)))" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "raw = ''.join(gs_seqs)\n", 113 | "print(raw)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "from ochre.utils import get_chars, get_int_to_char\n", 123 | "\n", 124 | "chars, num_chars, ci = get_chars(raw, raw, raw, False)\n", 125 | "ic = get_int_to_char(ci)\n", 126 | "\n", 127 | "#print(ci)\n", 128 | "#print(ic)\n", 129 | "\n", 130 | "oov_char = '@'\n", 131 | "\n", 132 | "for s in ocr_seqs:\n", 133 | "#for s in [['wwww']]:\n", 134 | "#for s in [['w8w']]:\n", 135 | " print(s)\n", 136 | " res = np.empty(n, dtype=np.int)\n", 137 | " for i in range(n):\n", 138 | " try:\n", 139 | " if s[i] != '':\n", 140 | " res[i] = ci.get(s[i], ci[oov_char])\n", 141 | " except IndexError:\n", 142 | " res[i] = ci['\\n']\n", 143 | " print(res)\n" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "from ochre.datagen import DataGenerator\n", 153 | "\n", 154 | 
"dg = DataGenerator(xData=ocr_seqs, yData=gs_seqs, char_to_int=ci, seq_length=n, batch_size=3)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "print(len(dg))" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "inp, outp = dg[0]\n", 173 | "print(inp)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "print(outp.shape)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "print(outp)" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "print(ocr_seqs)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "dg._convert_sample('test')" 210 | ] 211 | } 212 | ], 213 | "metadata": { 214 | "language_info": { 215 | "name": "python", 216 | "pygments_lexer": "ipython3" 217 | } 218 | }, 219 | "nbformat": 4, 220 | "nbformat_minor": 2 221 | } 222 | -------------------------------------------------------------------------------- /notebooks/kb_xml_to_text_(unaligned).ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "%matplotlib inline\n", 20 | "\n", 21 | "import numpy as np\n", 22 | "import pandas as pd\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "\n", 25 | "from tqdm import tqdm_notebook as tqdm" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import os\n", 35 | "\n", 36 | "from lxml import etree\n", 37 | "\n", 38 | "from nlppln.utils import get_files\n", 39 | "from ochre.matchlines import gt_fname2ocr_fname\n", 40 | "\n", 41 | "gs_dir = '/home/jvdzwaan/ownCloud/Shared/OCR/Ground-truth/'\n", 42 | "ocr_dir = '/home/jvdzwaan/ownCloud/Shared/OCR/Originele ALTOs/'\n", 43 | "#ocr_dir = '/home/jvdzwaan/ownCloud/Shared/OCR/Opnieuw geOCRd/'\n", 44 | "\n", 45 | "gs_files = get_files(gs_dir)\n", 46 | "# remove file with \"extra\" in the name, this one is the same as the file without \"extra\" in the name\n", 47 | "gs_files = [f for f in gs_files if not 'extra' in f]\n", 48 | "\n", 49 | "ocr_files = []\n", 50 | "for gs_file in gs_files:\n", 51 | " ocr_bn = gt_fname2ocr_fname(gs_file)\n", 52 | " # the 'opnieuw' alto files have a different file name\n", 53 | " #ocr_bn = ocr_bn.replace('alto.xml', 'altoFR11.xml')\n", 54 | " ocr_file = os.path.join(ocr_dir, ocr_bn)\n", 55 | " if os.path.isfile(ocr_file):\n", 56 | " ocr_files.append(ocr_file)\n", 57 | " else:\n", 58 | " print('File not found:', ocr_file)\n", 59 | " print('GS file:', gs_file)\n", 60 | "print(len(gs_files), len(ocr_files))" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "%%time\n", 70 | "\n", 71 | "from nlppln.utils 
import create_dirs\n", 72 | "\n", 73 | "from ochre.utils import get_temp_file\n", 74 | "from ochre.matchlines import get_ns, replace_entities\n", 75 | "\n", 76 | "def get_lines(fname, alto_ns):\n", 77 | " lines = []\n", 78 | " context = etree.iterparse(fname, events=('end', ), tag=(alto_ns+'TextLine'))\n", 79 | " for event, elem in context:\n", 80 | " words = []\n", 81 | " for a in elem.getchildren():\n", 82 | " if a.tag == alto_ns+'String':\n", 83 | " if a.attrib.get('SUBS_TYPE') == 'HypPart1':\n", 84 | " words.append(a.attrib['SUBS_CONTENT'])\n", 85 | " elif a.attrib.get('SUBS_TYPE') != 'HypPart2':\n", 86 | " words.append(a.attrib['CONTENT'])\n", 87 | " \n", 88 | " lines.append(' '.join(words))\n", 89 | " \n", 90 | " # make iteration over context fast and consume less memory\n", 91 | " #https://www.ibm.com/developerworks/xml/library/x-hiperfparse\n", 92 | " elem.clear()\n", 93 | " while elem.getprevious() is not None:\n", 94 | " del elem.getparent()[0]\n", 95 | " \n", 96 | " return lines\n", 97 | "\n", 98 | "def doc_id(fname):\n", 99 | " bn = os.path.basename(fname)\n", 100 | " n = bn.rsplit('_', 1)[0]\n", 101 | " return n\n", 102 | "\n", 103 | "\n", 104 | "out_dir = '/home/jvdzwaan/data/kb-ocr/text-not-aligned/'\n", 105 | "\n", 106 | "create_dirs(out_dir)\n", 107 | "\n", 108 | "gs_dir = os.path.join(out_dir, 'gs')\n", 109 | "create_dirs(gs_dir)\n", 110 | "\n", 111 | "ocr_dir = os.path.join(out_dir, 'ocr')\n", 112 | "create_dirs(ocr_dir)\n", 113 | "\n", 114 | "for gs_file, ocr_file in tqdm(zip(gs_files, ocr_files), total=len(gs_files)):\n", 115 | " try:\n", 116 | " gs_tmp = get_temp_file()\n", 117 | " #print(gs_tmp)\n", 118 | " with open(gs_tmp, 'w') as f:\n", 119 | " f.write(replace_entities(gs_file))\n", 120 | " \n", 121 | " #ocr_tmp = get_temp_file()\n", 122 | " #print(gs_tmp)\n", 123 | " #with open(ocr_tmp, 'w') as f:\n", 124 | " # f.write(replace_entities(ocr_file))\n", 125 | " \n", 126 | " gs_lines = get_lines(gs_tmp, get_ns(gs_file))\n", 127 | " ocr_lines = get_lines(ocr_file, get_ns(ocr_file))\n", 128 | " #print(len(gs_lines), len(ocr_lines))\n", 129 | " \n", 130 | " os.remove(gs_tmp)\n", 131 | " #os.remove(ocr_tmp)\n", 132 | " \n", 133 | " #print(doc_id(gs_file))\n", 134 | " #print(doc_id(ocr_file))\n", 135 | " assert doc_id(gs_file) == doc_id(ocr_file)\n", 136 | " gs_out = os.path.join(gs_dir, '{}.txt'.format(doc_id(gs_file)))\n", 137 | " ocr_out = os.path.join(ocr_dir, '{}.txt'.format(doc_id(ocr_file)))\n", 138 | " #print(gs_out)\n", 139 | " #print(ocr_out)\n", 140 | " \n", 141 | " with open(gs_out, 'w') as f:\n", 142 | " f.write(' '.join(gs_lines))\n", 143 | " with open(ocr_out, 'w') as f:\n", 144 | " f.write(' '.join(ocr_lines))\n", 145 | " except etree.XMLSyntaxError as e:\n", 146 | " print(gs_file)\n", 147 | " print(e)\n", 148 | " print()" 149 | ] 150 | } 151 | ], 152 | "metadata": { 153 | "language_info": { 154 | "name": "python", 155 | "pygments_lexer": "ipython3" 156 | } 157 | }, 158 | "nbformat": 4, 159 | "nbformat_minor": 2 160 | } 161 | -------------------------------------------------------------------------------- /notebooks/try-lstm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": {}, 16 | "source": [ 17 | "## Links/tutorials\n", 18 | "\n", 19 | "* 
http://karpathy.github.io/2015/05/21/rnn-effectiveness/\n", 20 | "* http://machinelearningmastery.com/text-generation-lstm-recurrent-neural-networks-python-keras/\n", 21 | "* http://machinelearningmastery.com/sequence-classification-lstm-recurrent-neural-networks-python-keras/" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import numpy\n", 31 | "from keras.models import Sequential\n", 32 | "from keras.layers import Dense\n", 33 | "from keras.layers import Dropout\n", 34 | "from keras.layers import LSTM\n", 35 | "from keras.callbacks import ModelCheckpoint\n", 36 | "from keras.utils import np_utils" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# load ascii text and convert to lowercase\n", 46 | "filename = \"/home/jvdzwaan/data/alice.txt\"\n", 47 | "raw_text = open(filename).read()\n", 48 | "raw_text = raw_text.lower()\n", 49 | "print(raw_text[:100])" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "# create mapping of unique chars to integers\n", 59 | "chars = sorted(list(set(raw_text)))\n", 60 | "char_to_int = dict((c, i) for i, c in enumerate(chars))" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "print(chars)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "print(char_to_int)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "n_chars = len(raw_text)\n", 88 | "n_vocab = len(chars)\n", 89 | "print(\"Total Characters:\", n_chars)\n", 90 | "print(\"Total Vocab:\", n_vocab)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "# prepare the dataset of input to output pairs encoded as integers\n", 100 | "seq_length = 100\n", 101 | "dataX = []\n", 102 | "dataY = []\n", 103 | "for i in range(0, n_chars - seq_length, 1):\n", 104 | "    seq_in = raw_text[i:i + seq_length]\n", 105 | "    seq_out = raw_text[i + seq_length]\n", 106 | "    dataX.append([char_to_int[char] for char in seq_in])\n", 107 | "    dataY.append(char_to_int[seq_out])\n", 108 | "n_patterns = len(dataX)\n", 109 | "print(\"Total Patterns:\", n_patterns)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "print(seq_in)\n", 119 | "print('---')\n", 120 | "print(repr(seq_out))\n", 121 | "print(len(seq_in))" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "for i in range(0, n_chars - seq_length, 1):\n", 131 | "    print(i)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "print(dataX[0])\n", 141 | "print(dataY[0])" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "X = numpy.reshape(dataX, (n_patterns, seq_length, 1))\n", 151 | "# normalize\n", 152 | "X = X / float(n_vocab)\n", 153 | "# one hot encode the output 
variable\n", 154 | "y = np_utils.to_categorical(dataY)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "print X" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "print y" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "# define the LSTM model\n", 182 | "model = Sequential()\n", 183 | "model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2])))\n", 184 | "model.add(Dropout(0.2))\n", 185 | "model.add(Dense(y.shape[1], activation='softmax'))\n", 186 | "model.compile(loss='categorical_crossentropy', optimizer='adam')" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "# define the checkpoint\n", 196 | "filepath=\"/home/jvdzwaan/data/tmp/lstm/weights-improvement-{epoch:02d}-{loss:.4f}.hdf5\"\n", 197 | "checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')\n", 198 | "callbacks_list = [checkpoint]" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "model.fit(X, y, epochs=20, batch_size=128, callbacks=callbacks_list)" 208 | ] 209 | } 210 | ], 211 | "metadata": { 212 | "language_info": { 213 | "name": "python", 214 | "pygments_lexer": "ipython3" 215 | } 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 2 219 | } 220 | -------------------------------------------------------------------------------- /notebooks/ICDAR2017_shared_task_workflows.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import nlppln\n", 20 | "\n", 21 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n", 22 | " wf.load(steps_dir='../ochre/cwl/')\n", 23 | " print wf.list_steps()\n", 24 | "\n", 25 | " in_dir = wf.add_input(in_dir='Directory')\n", 26 | " ocr_dir_name = wf.add_input(ocr_dir_name='string')\n", 27 | " gs_dir_name = wf.add_input(gs_dir_name='string')\n", 28 | " aligned_dir_name = wf.add_input(aligned_dir_name='string')\n", 29 | "\n", 30 | " files = wf.ls(in_dir=in_dir)\n", 31 | " aligned, gs, ocr = wf.icdar2017st_extract_text(in_file=files, scatter=['in_file'], scatter_method='dotproduct')\n", 32 | " gs_dir = wf.save_files_to_dir(dir_name=gs_dir_name, in_files=gs)\n", 33 | " ocr_dir = wf.save_files_to_dir(dir_name=ocr_dir_name, in_files=ocr)\n", 34 | " aligned_dir = wf.save_files_to_dir(dir_name=aligned_dir_name, in_files=aligned)\n", 35 | "\n", 36 | " wf.add_outputs(gs_dir=gs_dir)\n", 37 | " wf.add_outputs(ocr_dir=ocr_dir)\n", 38 | " wf.add_outputs(aligned_dir=aligned_dir)\n", 39 | "\n", 40 | " wf.save('../ochre/cwl/icdar2017st-extract-data.cwl', wd=True)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "import nlppln\n", 50 | "\n", 51 | "with 
nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n", 52 | "    wf.load(steps_dir='../ochre/cwl/')\n", 53 | "    print(wf.list_steps())\n", 54 | "\n", 55 | "    in_dir1 = wf.add_input(in_dir1='Directory')\n", 56 | "    in_dir2 = wf.add_input(in_dir2='Directory')\n", 57 | "    in_dir3 = wf.add_input(in_dir3='Directory')\n", 58 | "    in_dir4 = wf.add_input(in_dir4='Directory')\n", 59 | "    ocr_dir_name = wf.add_input(ocr_dir_name='string', default='ocr')\n", 60 | "    gs_dir_name = wf.add_input(gs_dir_name='string', default='gs')\n", 61 | "    aligned_dir_name = wf.add_input(aligned_dir_name='string', default='aligned')\n", 62 | "\n", 63 | "    aligned_dir1, gs_dir1, ocr_dir1 = wf.icdar2017st_extract_data(aligned_dir_name=aligned_dir_name,\n", 64 | "                                                                  gs_dir_name=gs_dir_name,\n", 65 | "                                                                  ocr_dir_name=ocr_dir_name,\n", 66 | "                                                                  in_dir=in_dir1)\n", 67 | "    gs1 = wf.save_dir_to_subdir(inner_dir=gs_dir1, outer_dir=in_dir1)\n", 68 | "    ocr1 = wf.save_dir_to_subdir(inner_dir=ocr_dir1, outer_dir=in_dir1)\n", 69 | "    aligned1 = wf.save_dir_to_subdir(inner_dir=aligned_dir1, outer_dir=in_dir1)\n", 70 | "\n", 71 | "    aligned_dir2, gs_dir2, ocr_dir2 = wf.icdar2017st_extract_data(aligned_dir_name=aligned_dir_name,\n", 72 | "                                                                  gs_dir_name=gs_dir_name,\n", 73 | "                                                                  ocr_dir_name=ocr_dir_name,\n", 74 | "                                                                  in_dir=in_dir2)\n", 75 | "    gs2 = wf.save_dir_to_subdir(inner_dir=gs_dir2, outer_dir=in_dir2)\n", 76 | "    ocr2 = wf.save_dir_to_subdir(inner_dir=ocr_dir2, outer_dir=in_dir2)\n", 77 | "    aligned2 = wf.save_dir_to_subdir(inner_dir=aligned_dir2, outer_dir=in_dir2)\n", 78 | "\n", 79 | "    aligned_dir3, gs_dir3, ocr_dir3 = wf.icdar2017st_extract_data(aligned_dir_name=aligned_dir_name,\n", 80 | "                                                                  gs_dir_name=gs_dir_name,\n", 81 | "                                                                  ocr_dir_name=ocr_dir_name,\n", 82 | "                                                                  in_dir=in_dir3)\n", 83 | "    gs3 = wf.save_dir_to_subdir(inner_dir=gs_dir3, outer_dir=in_dir3)\n", 84 | "    ocr3 = wf.save_dir_to_subdir(inner_dir=ocr_dir3, outer_dir=in_dir3)\n", 85 | "    aligned3 = wf.save_dir_to_subdir(inner_dir=aligned_dir3, outer_dir=in_dir3)\n", 86 | "\n", 87 | "    aligned_dir4, gs_dir4, ocr_dir4 = wf.icdar2017st_extract_data(aligned_dir_name=aligned_dir_name,\n", 88 | "                                                                  gs_dir_name=gs_dir_name,\n", 89 | "                                                                  ocr_dir_name=ocr_dir_name,\n", 90 | "                                                                  in_dir=in_dir4)\n", 91 | "    gs4 = wf.save_dir_to_subdir(inner_dir=gs_dir4, outer_dir=in_dir4)\n", 92 | "    ocr4 = wf.save_dir_to_subdir(inner_dir=ocr_dir4, outer_dir=in_dir4)\n", 93 | "    aligned4 = wf.save_dir_to_subdir(inner_dir=aligned_dir4, outer_dir=in_dir4)\n", 94 | "\n", 95 | "\n", 96 | "    wf.add_outputs(gs1=gs1)\n", 97 | "    wf.add_outputs(gs2=gs2)\n", 98 | "    wf.add_outputs(gs3=gs3)\n", 99 | "    wf.add_outputs(gs4=gs4)\n", 100 | "\n", 101 | "    wf.add_outputs(ocr1=ocr1)\n", 102 | "    wf.add_outputs(ocr2=ocr2)\n", 103 | "    wf.add_outputs(ocr3=ocr3)\n", 104 | "    wf.add_outputs(ocr4=ocr4)\n", 105 | "\n", 106 | "    wf.add_outputs(aligned1=aligned1)\n", 107 | "    wf.add_outputs(aligned2=aligned2)\n", 108 | "    wf.add_outputs(aligned3=aligned3)\n", 109 | "    wf.add_outputs(aligned4=aligned4)\n", 110 | "\n", 111 | "    wf.save('../ochre/cwl/icdar2017st-extract-data-all.cwl', pack=True)" 112 | ] 113 | } 114 | ], 115 | "metadata": { 116 | "language_info": { 117 | "name": "python", 118 | "pygments_lexer": "ipython3" 119 | } 120 | }, 121 | "nbformat": 4, 122 | "nbformat_minor": 2 123 | } 124 | -------------------------------------------------------------------------------- /notebooks/fuzzy-string-matching.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | 
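"cell_type": "markdown", "metadata": {}, "source": [ "Fuzzy string matching experiments: locate a sentence (possibly containing OCR-style character errors, like 'attorncy' for 'attorney') inside a longer text, first with fuzzyset's FuzzySet and then with fuzzywuzzy's fuzz.partial_ratio; the final cells time partial_ratio over 10,000 generated sentences." ] }, { 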
"cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "text = \"Many force beat offer third attention. Relationship whether question ask shoulder. Professional management method soldier itself enough single value opportunity represent short defense. Last human want tree sign people form old moment myself understand support station ahead water red relationship million indeed away music. Benefit range have seem day no thing single tough explain attention far local that above one design authority nothing. Trade stage seek former white their we success president item wind yeah full soldier suggest role never. Down condition north tell than item program at. Dark my box culture treat court tell some according quickly often issue quickly quickly office at wide want. Hope start any. Government value relate cover country network according. Travel. On but property production oil produce material gun page why face serve nation expect. Concern this apply travel skill allow same friend night country young once group style new turn situation gun foreign voice mother enough. Turn commercial manager news put determine pass address start particular. Light begin style hand almost popular thousand or least. Dog walk kitchen star skin information task land continue always spend value speak quickly energy improve leg as mind sing stage purpose few. Reason task phone conference. Production move rule marriage attorney money team open pretty card fall. Small player how since list adult idea apply hold both feel house team sign wall put authority free anyone happy authority throw body art pick. Skill entire major close series ground thus letter recent example I hold base policy include rate. Man think best position his lead career stage contain around eight mean they character specific memory religious heart shake see. Just during yet section thank rise force director water last ground go keep area different. Quickly several character west wrong model attention so truth resource her test gas. Most civil soldier consider front scene never military theory not free economic capital admit yeah fall thought there specific trial low against little tree. 
Center late.\"\n", 20 | "print(text)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "from fuzzyset import FuzzySet\n", 30 | "\n", 31 | "fs = FuzzySet()\n", 32 | "fs.add(text)" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "fs.get(\"Production move rule marriage attorney money team open pretty card fall.\")" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "fs = FuzzySet()\n", 51 | "fs.add(\"Production move rule marriage attorney money team open pretty card fall.\")\n", 52 | "fs.get(text)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "from faker import Faker\n", 62 | "\n", 63 | "fake = Faker()\n", 64 | "num_words = 10\n", 65 | "sentences = []\n", 66 | "\n", 67 | "for i in range(10000):\n", 68 | " sentence = fake.sentence(nb_words=num_words, variable_nb_words=True, ext_word_list=None)\n", 69 | " sentences.append(sentence)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "fs = FuzzySet()\n", 79 | "for s in sentences:\n", 80 | " fs.add(s)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "text = ' '.join(sentences[5:10])\n", 90 | "print(text)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "fs.get(text)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "from fuzzywuzzy import fuzz\n", 109 | "\n", 110 | "part = \"Production move rule marriage attorney money team open pretty card fall.\"\n", 111 | "\n", 112 | "fuzz.partial_ratio(part, text)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "fuzz.partial_ratio(\"character they\", text)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "part = \"Production move rule marriage attorncy moncy team opcn pretty card fall.\"\n", 131 | "\n", 132 | "fuzz.partial_ratio(part, text)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "from faker import Faker\n", 142 | "\n", 143 | "fake = Faker()\n", 144 | "num_words = 10\n", 145 | "sentences = []\n", 146 | "\n", 147 | "for i in range(10000):\n", 148 | " sentence = fake.sentence(nb_words=num_words, variable_nb_words=True, ext_word_list=None)\n", 149 | " sentences.append(sentence)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "sentences[1]" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "from tqdm import tqdm_notebook as tqdm" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 
176 | "%%time\n", 177 | "\n", 178 | "for s in tqdm(sentences):\n", 179 | " fuzz.partial_ratio(s, text)" 180 | ] 181 | } 182 | ], 183 | "metadata": { 184 | "language_info": { 185 | "name": "python", 186 | "pygments_lexer": "ipython3" 187 | } 188 | }, 189 | "nbformat": 4, 190 | "nbformat_minor": 2 191 | } 192 | -------------------------------------------------------------------------------- /ochre/train_mt.py: -------------------------------------------------------------------------------- 1 | import click 2 | import os 3 | import json 4 | import codecs 5 | 6 | from keras.models import Model 7 | from keras.layers import Input, LSTM, Dense 8 | import numpy as np 9 | 10 | from ochre.utils import add_checkpoint, load_weights, get_files 11 | 12 | 13 | def read_texts(data_dir, div, name): 14 | in_files = get_files(data_dir, div, name) 15 | 16 | # Vectorize the data. 17 | input_texts = [] 18 | target_texts = [] 19 | 20 | for in_file in in_files: 21 | lines = codecs.open(in_file, 'r', encoding='utf-8').readlines() 22 | for line in lines: 23 | #print line.split('||@@||') 24 | input_text, target_text = line.split('||@@||') 25 | # We use "tab" as the "start sequence" character 26 | # for the targets, and "\n" as "end sequence" character. 27 | target_text = '\t' + target_text 28 | input_texts.append(input_text) 29 | target_texts.append(target_text) 30 | 31 | return input_texts, target_texts 32 | 33 | 34 | def convert(input_texts, target_texts, input_characters, target_characters, max_encoder_seq_length, num_encoder_tokens, max_decoder_seq_length, num_decoder_tokens): 35 | input_token_index = dict( 36 | [(char, i) for i, char in enumerate(input_characters)]) 37 | target_token_index = dict( 38 | [(char, i) for i, char in enumerate(target_characters)]) 39 | 40 | encoder_input_data = np.zeros( 41 | (len(input_texts), max_encoder_seq_length, num_encoder_tokens), 42 | dtype='float32') 43 | decoder_input_data = np.zeros( 44 | (len(input_texts), max_decoder_seq_length, num_decoder_tokens), 45 | dtype='float32') 46 | decoder_target_data = np.zeros( 47 | (len(input_texts), max_decoder_seq_length, num_decoder_tokens), 48 | dtype='float32') 49 | 50 | for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)): 51 | for t, char in enumerate(input_text): 52 | try: 53 | encoder_input_data[i, t, input_token_index[char]] = 1. 54 | except IndexError: 55 | # sequence longer than max length of training inputs 56 | pass 57 | for t, char in enumerate(target_text): 58 | # decoder_target_data is ahead of decoder_input_data by one timestep 59 | try: 60 | decoder_input_data[i, t, target_token_index[char]] = 1. 61 | except IndexError: 62 | # sequence longer than max length of training inputs 63 | pass 64 | if t > 0: 65 | # decoder_target_data will be ahead by one timestep 66 | # and will not include the start character. 67 | try: 68 | decoder_target_data[i, t - 1, target_token_index[char]] = 1. 69 | except IndexError: 70 | # sequence longer than max length of training inputs 71 | pass 72 | 73 | return encoder_input_data, decoder_input_data, decoder_target_data 74 | 75 | 76 | @click.command() 77 | @click.argument('datasets', type=click.File()) 78 | @click.argument('data_dir', type=click.Path(exists=True)) 79 | @click.option('--weights_dir', '-w', default=os.getcwd(), type=click.Path()) 80 | def train_lstm(datasets, data_dir, weights_dir): 81 | batch_size = 64 # Batch size for training. 82 | epochs = 100 # Number of epochs to train for. 83 | latent_dim = 256 # Latent dimensionality of the encoding space. 
84 | 85 | div = json.load(datasets) 86 | 87 | train_input, train_target = read_texts(data_dir, div, 'train') 88 | val_input, val_target = read_texts(data_dir, div, 'val') 89 | #test_input, test_target = read_texts(data_dir, div, 'test') 90 | 91 | input_characters = sorted(list(set(u''.join(train_input) + u''.join(val_input)))) 92 | target_characters = sorted(list(set(u''.join(train_target) + u''.join(val_target)))) 93 | num_encoder_tokens = len(input_characters) 94 | num_decoder_tokens = len(target_characters) 95 | max_encoder_seq_length = max([len(txt) for txt in train_input]) 96 | max_decoder_seq_length = max([len(txt) for txt in train_target]) 97 | 98 | print('Number of samples:', len(train_input)) 99 | print('Number of unique input tokens:', num_encoder_tokens) 100 | print('Number of unique output tokens:', num_decoder_tokens) 101 | print('Max sequence length for inputs:', max_encoder_seq_length) 102 | print('Max sequence length for outputs:', max_decoder_seq_length) 103 | print('Input characters:', u''.join(input_characters)) 104 | print('Output characters:', u''.join(target_characters)) 105 | 106 | train_enc_input, train_dec_input, train_dec_target = convert(train_input, 107 | train_target, input_characters, target_characters, 108 | max_encoder_seq_length, num_encoder_tokens, max_decoder_seq_length, 109 | num_decoder_tokens) 110 | 111 | val_enc_input, val_dec_input, val_dec_target = convert(val_input, 112 | val_target, input_characters, target_characters, 113 | max_encoder_seq_length, num_encoder_tokens, max_decoder_seq_length, 114 | num_decoder_tokens) 115 | 116 | # Define an input sequence and process it. 117 | encoder_inputs = Input(shape=(None, num_encoder_tokens)) 118 | encoder = LSTM(latent_dim, return_state=True) 119 | encoder_outputs, state_h, state_c = encoder(encoder_inputs) 120 | # We discard `encoder_outputs` and only keep the states. 121 | encoder_states = [state_h, state_c] 122 | 123 | # Set up the decoder, using `encoder_states` as initial state. 124 | decoder_inputs = Input(shape=(None, num_decoder_tokens)) 125 | # We set up our decoder to return full output sequences, 126 | # and to return internal states as well. We don't use the 127 | # return states in the training model, but we will use them in inference. 128 | decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True) 129 | decoder_outputs, _, _ = decoder_lstm(decoder_inputs, 130 | initial_state=encoder_states) 131 | decoder_dense = Dense(num_decoder_tokens, activation='softmax') 132 | decoder_outputs = decoder_dense(decoder_outputs) 133 | 134 | # Define the model that will turn 135 | # `encoder_input_data` & `decoder_input_data` into `decoder_target_data` 136 | model = Model([encoder_inputs, decoder_inputs], decoder_outputs) 137 | 138 | # Run training 139 | model.compile(optimizer='rmsprop', loss='categorical_crossentropy') 140 | epoch, model = load_weights(model, weights_dir, optimizer='rmsprop', loss='categorical_crossentropy') 141 | callbacks_list = [add_checkpoint(weights_dir)] 142 | model.fit([train_enc_input, train_dec_input], train_dec_target, 143 | batch_size=batch_size, 144 | epochs=epochs, 145 | validation_data=([val_enc_input, val_dec_input], val_dec_target), 146 | callbacks=callbacks_list, 147 | initial_epoch=epoch) 148 | 149 | 150 | if __name__ == '__main__': 151 | train_lstm() 152 | --------------------------------------------------------------------------------