├── ochre
│   ├── __init__.py
│   ├── cwl
│   │   ├── vudnc-select-files.cwl
│   │   ├── count-chars.cwl
│   │   ├── rmgarbage.cwl
│   │   ├── sac2gs-and-ocr.cwl
│   │   ├── clin2018st-extract-text.cwl
│   │   ├── create-data-division.cwl
│   │   ├── remove-empty-files.cwl
│   │   ├── kb-tss-concat-files.cwl
│   │   ├── ocrevaluation-extract.cwl
│   │   ├── char-align.cwl
│   │   ├── remove-title-page.cwl
│   │   ├── select-test-files.cwl
│   │   ├── icdar2017st-extract-text.cwl
│   │   ├── match-ocr-and-gs.cwl
│   │   ├── lstm-synced-correct-ocr.cwl
│   │   ├── vudnc2ocr-and-gs.cwl
│   │   ├── merge-json.cwl
│   │   ├── create-word-mappings.cwl
│   │   ├── ocrevaluation-performance-wf.cwl
│   │   ├── ocr-transform.cwl
│   │   ├── onmt-tokenize-text.cwl
│   │   ├── onmt-main.cwl
│   │   ├── onmt-build-vocab.cwl
│   │   ├── icdar2017st-extract-data.cwl
│   │   ├── kb-tss-preprocess-single-dir.cwl
│   │   ├── align-texts-wf.cwl
│   │   ├── post-correct-dir.cwl
│   │   ├── word-mapping-wf.cwl
│   │   ├── post-correct-test-files.cwl
│   │   ├── lowercase-directory.cwl
│   │   └── sac-extract.cwl
│   ├── select_vudnc_files.py
│   ├── select_test_files.py
│   ├── dncvu_select_ocr_and_gs_texts.py
│   ├── count_chars.py
│   ├── kb_tss_concat_files.py
│   ├── remove_empty_files.py
│   ├── merge_json.py
│   ├── char_align.py
│   ├── create_data_division.py
│   ├── match_ocr_and_gs.py
│   ├── sac2gs_and_ocr.py
│   ├── icdar2017st_extract_text.py
│   ├── remove_title_page.py
│   ├── clin2018st_extract_text.py
│   ├── ocrevaluation_extract.py
│   ├── edlibutils.py
│   ├── scramble.py
│   ├── create_word_mappings.py
│   ├── lstm_synced_correct_ocr.py
│   ├── vudnc2ocr_and_gs.py
│   ├── rmgarbage.py
│   ├── keras_utils.py
│   ├── train_seq2seq.py
│   └── train_mt.py
├── tests
│   ├── data
│   │   ├── ocrevaluation
│   │   │   ├── in
│   │   │   │   ├── gs.txt
│   │   │   │   └── ocr.txt
│   │   │   └── out
│   │   │       └── gs_out.html
│   │   └── ocrevaluation-extract
│   │       ├── out
│   │       │   ├── empty-global.csv
│   │       │   ├── in-global.csv
│   │       │   ├── empty-character.csv
│   │       │   └── in-character.csv
│   │       └── in
│   │           ├── empty.html
│   │           └── in.html
│   ├── test_ocrerrors.py
│   ├── test_datagen.py
│   ├── test_rmgarbage.py
│   ├── test_utils.py
│   ├── test_ocrevaluation_extract.py
│   └── test_remove_title_page.py
├── requirements.txt
├── .gitignore
├── job.sh
├── nematus_translate.sh
├── .editorconfig
├── nematusA8P2.sh
├── nematusA8P3.sh
├── nematus.sh
├── transformer.sh
├── notebooks
│   ├── lowercase-directory-workflow.ipynb
│   ├── vudnc-preprocess-workflow.ipynb
│   ├── post_correction_workflows.ipynb
│   ├── icdar2019POCR-to-texts.ipynb
│   ├── sac-preprocess-workflow.ipynb
│   ├── preprocess-dbnl_ocr.ipynb
│   ├── align-workflow.ipynb
│   ├── icdar2017-ngrams.ipynb
│   ├── 2017-baseline-nn.ipynb
│   ├── kb_tss_preprocess.ipynb
│   ├── ocr-evaluation-workflow.ipynb
│   ├── Compare-performance-of-pred.ipynb
│   ├── character_counts.ipynb
│   ├── word-mapping-workflow.ipynb
│   ├── improve-keras-datagen.ipynb
│   ├── kb_xml_to_text_(unaligned).ipynb
│   ├── try-lstm.ipynb
│   ├── ICDAR2017_shared_task_workflows.ipynb
│   └── fuzzy-string-matching.ipynb
├── setup.py
├── datagen.py
├── 2017_baseline.py
└── lstm_synced.py
/ochre/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import cwl_path
2 |
--------------------------------------------------------------------------------
/tests/data/ocrevaluation/in/gs.txt:
--------------------------------------------------------------------------------
1 | This is an example text.
2 |
--------------------------------------------------------------------------------
/tests/data/ocrevaluation/in/ocr.txt:
--------------------------------------------------------------------------------
1 | This is an cxample text.
2 |
--------------------------------------------------------------------------------
/tests/data/ocrevaluation-extract/out/empty-global.csv:
--------------------------------------------------------------------------------
1 | ,CER,WER,WER (order independent)
2 | empty,n/a,n/a,n/a
3 |
--------------------------------------------------------------------------------
/tests/data/ocrevaluation-extract/out/in-global.csv:
--------------------------------------------------------------------------------
1 | ,CER,WER,WER (order independent)
2 | in,4.17,20.00,20.00
3 |
--------------------------------------------------------------------------------
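The numbers in in-global.csv follow directly from the fixture texts in tests/data/ocrevaluation/in: one substituted character out of 24 and one wrong word out of five. A quick arithmetic check (a sketch in Python, not a repository file):

    chars = len("This is an example text.")          # 24 characters in gs.txt
    words = len("This is an example text.".split())  # 5 words
    print(round(1 / chars * 100, 2))  # 4.17 -> CER: 'cxample' has one bad character
    print(round(1 / words * 100, 2))  # 20.0 -> WER: one word out of five is wrong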
/requirements.txt:
--------------------------------------------------------------------------------
1 | click
2 | lxml
3 | beautifulsoup4
4 | keras
5 | tensorflow
6 | edlib
7 | jupyter
8 | pytest
9 | nlppln>=0.3.1
10 |
--------------------------------------------------------------------------------
/tests/data/ocrevaluation-extract/out/empty-character.csv:
--------------------------------------------------------------------------------
1 | "Character",Hex code,Total,Spurious,Confused,Lost,Error rate
2 | "n/a",n/a,n/a,n/a,n/a,n/a,n/a
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.py[cod]
2 | *.egg-info
3 | *.eggs
4 | .ipynb_checkpoints
5 |
6 | build
7 | dist
8 | .cache
9 | __pycache__
10 |
11 | htmlcov
12 | .coverage
13 | coverage.xml
14 | .pytest_cache
15 |
16 | docs/_build
17 | docs/apidocs
18 |
19 | # ide
20 | .idea
21 | .eclipse
22 | .vscode
23 |
24 | # Mac
25 | .DS_Store
26 |
--------------------------------------------------------------------------------
/tests/data/ocrevaluation-extract/out/in-character.csv:
--------------------------------------------------------------------------------
1 | "Character",Hex code,Total,Spurious,Confused,Lost,Error rate
2 | " ",20,4,0,0,0,0.00
3 | ".",2e,1,0,0,0,0.00
4 | "T",54,1,0,0,0,0.00
5 | "a",61,2,0,0,0,0.00
6 | "e",65,3,0,1,0,33.33
7 | "h",68,1,0,0,0,0.00
8 | "i",69,2,0,0,0,0.00
9 | "l",6c,1,0,0,0,0.00
10 | "m",6d,1,0,0,0,0.00
11 | "n",6e,1,0,0,0,0.00
12 | "p",70,1,0,0,0,0.00
13 | "s",73,2,0,0,0,0.00
14 | "t",74,2,0,0,0,0.00
15 | "x",78,2,0,0,0,0.00
16 |
--------------------------------------------------------------------------------
/ochre/cwl/vudnc-select-files.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 | baseCommand: ["python", "-m", "ochre.select_vudnc_files"]
5 |
6 | requirements:
7 | EnvVarRequirement:
8 | envDef:
9 | LC_ALL: C.UTF-8
10 | LANG: C.UTF-8
11 |
12 | inputs:
13 | in_dir:
14 | type: Directory
15 | inputBinding:
16 | position: 1
17 |
18 | stdout: cwl.output.json
19 |
20 | outputs:
21 | out_files:
22 | type: File[]
23 |
--------------------------------------------------------------------------------
/ochre/cwl/count-chars.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 |
5 | requirements:
6 | EnvVarRequirement:
7 | envDef:
8 | LC_ALL: C.UTF-8
9 | LANG: C.UTF-8
10 |
11 | baseCommand: ["python", "-m", "ochre.count_chars"]
12 |
13 | inputs:
14 | in_file:
15 | type: File
16 | inputBinding:
17 | position: 1
18 |
19 | outputs:
20 | char_counts:
21 | type: File
22 | outputBinding:
23 | glob: "*.json"
24 |
--------------------------------------------------------------------------------
/job.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | #SBATCH --job-name=2017_baseline
4 | #SBATCH --output=2017_baseline.txt
5 | #
6 | #SBATCH --ntasks=1
7 | #SBATCH -C TitanX
8 | #SBATCH --time=15:00:00
9 | #SBATCH --gres=gpu:1
10 |
11 | module load python/3.5.2
12 | module load python-extra/python3.5/r0.5.0
13 | module load cuda80/toolkit/8.0.61
14 | module load tensorflow/python3.x/gpu/r1.4.0-py3
15 | module load keras/python3.5/r2.0.2
16 | module load cuDNN/cuda80/5.1.5
17 |
18 | srun python /home/jvdzwaan/code/ochre/2017_baseline.py
19 |
--------------------------------------------------------------------------------
/nematus_translate.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | #SBATCH --job-name=nematus_translate
4 | #SBATCH --output=nematus_translate.txt
5 | #
6 | #SBATCH --ntasks=1
7 | #SBATCH -C TitanX
8 | #SBATCH --gres=gpu:1
9 |
10 | srun mkdir -p /var/scratch/jvdzwaan/kb-ocr/A8P1/pred/
11 | for filename in /var/scratch/jvdzwaan/kb-ocr/A8P1/test/*.ocr; do
12 |     srun python ~/code/nematus/nematus/translate.py \
13 |         -m /var/scratch/jvdzwaan/kb-ocr/A8P1/model/model-40000 \
14 |         -i "$filename" \
15 |         -o /var/scratch/jvdzwaan/kb-ocr/A8P1/pred/"$(basename "$filename" .ocr).pred"
16 | done
17 |
--------------------------------------------------------------------------------
/ochre/cwl/rmgarbage.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 | baseCommand: ["python", "-m", "nlppln.commands.rmgarbage"]
5 |
6 | requirements:
7 | EnvVarRequirement:
8 | envDef:
9 | LC_ALL: C.UTF-8
10 | LANG: C.UTF-8
11 |
12 | inputs:
13 | in_file:
14 | type: File
15 |
16 | outputs:
17 | out_file:
18 | type: File
19 | outputBinding:
20 | glob: $(inputs.in_file.nameroot).txt
21 | metadata_out:
22 | type: File
23 | outputBinding:
24 |       # Assumption: this glob originally duplicated out_file's glob, which
25 |       # cannot be right; adjust the pattern below to the metadata file that
26 |       # nlppln.commands.rmgarbage actually writes.
27 |       glob: $(inputs.in_file.nameroot)*.csv
28 |
--------------------------------------------------------------------------------
/ochre/cwl/sac2gs-and-ocr.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 | baseCommand: ["python", "-m", "ochre.sac2gs_and_ocr"]
5 |
6 | requirements:
7 | EnvVarRequirement:
8 | envDef:
9 | LC_ALL: C.UTF-8
10 | LANG: C.UTF-8
11 |
12 | inputs:
13 | in_dir:
14 | type: Directory
15 | inputBinding:
16 | position: 1
17 |
18 | stdout: cwl.output.json
19 |
20 | outputs:
21 | gs_de:
22 | type: File[]
23 | ocr_de:
24 | type: File[]
25 | gs_fr:
26 | type: File[]
27 | ocr_fr:
28 | type: File[]
29 |
--------------------------------------------------------------------------------
/ochre/cwl/clin2018st-extract-text.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 |
5 | requirements:
6 | EnvVarRequirement:
7 | envDef:
8 | LC_ALL: C.UTF-8
9 | LANG: C.UTF-8
10 |
11 | baseCommand: ["python", "-m", "ochre.clin2018st_extract_text"]
12 |
13 | inputs:
14 | json_file:
15 | type: File
16 | inputBinding:
17 | position: 1
18 |
19 | outputs:
20 | gs_text:
21 | type: File
22 | outputBinding:
23 | glob: "*-gs.txt"
24 | err_text:
25 | type: File
26 | outputBinding:
27 | glob: "*-errors.txt"
28 |
--------------------------------------------------------------------------------
/ochre/cwl/create-data-division.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 |
5 | requirements:
6 | EnvVarRequirement:
7 | envDef:
8 | LC_ALL: C.UTF-8
9 | LANG: C.UTF-8
10 |
11 | baseCommand: ["python", "-m", "ochre.create_data_division"]
12 |
13 | inputs:
14 | in_dir:
15 | type: Directory
16 | inputBinding:
17 | position: 1
18 | out_name:
19 | type: string?
20 | inputBinding:
21 | prefix: --out_name
22 |
23 | outputs:
24 | metadata_out:
25 | type: File
26 | outputBinding:
27 | glob: "*.json"
28 |
--------------------------------------------------------------------------------
/ochre/cwl/remove-empty-files.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 |
5 | baseCommand: ["python", "-m", "ochre.remove_empty_files"]
6 |
7 | requirements:
8 | EnvVarRequirement:
9 | envDef:
10 | LC_ALL: C.UTF-8
11 | LANG: C.UTF-8
12 |
13 | inputs:
14 | ocr_dir:
15 | type: Directory
16 | inputBinding:
17 | position: 2
18 | gs_dir:
19 | type: Directory
20 | inputBinding:
21 | position: 1
22 |
23 | stdout: cwl.output.json
24 |
25 | outputs:
26 | ocr:
27 | type: File[]
28 | gs:
29 | type: File[]
30 |
--------------------------------------------------------------------------------
/ochre/cwl/kb-tss-concat-files.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 |
5 | requirements:
6 | EnvVarRequirement:
7 | envDef:
8 | LC_ALL: C.UTF-8
9 | LANG: C.UTF-8
10 | InitialWorkDirRequirement:
11 | listing: $(inputs.in_files)
12 |
13 | baseCommand: ["python", "-m", "ochre.kb_tss_concat_files"]
14 |
15 | arguments:
16 | - valueFrom: $(runtime.outdir)
17 | position: 1
18 |
19 | inputs:
20 | in_files:
21 | type: File[]
22 |
23 | outputs:
24 | out_files:
25 | type: File[]
26 | outputBinding:
27 | glob: "*.txt"
28 |
--------------------------------------------------------------------------------
/ochre/cwl/ocrevaluation-extract.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 | baseCommand: ["python", "-m", "ochre.ocrevaluation_extract"]
5 |
6 | requirements:
7 | EnvVarRequirement:
8 | envDef:
9 | LC_ALL: C.UTF-8
10 | LANG: C.UTF-8
11 |
12 | inputs:
13 | in_file:
14 | type: File
15 | inputBinding:
16 | position: 1
17 |
18 | outputs:
19 | character_data:
20 | type: File
21 | outputBinding:
22 | glob: "*-character.csv"
23 | global_data:
24 | type: File
25 | outputBinding:
26 | glob: "*-global.csv"
27 |
--------------------------------------------------------------------------------
/ochre/cwl/char-align.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 | baseCommand: ["python", "-m", "ochre.char_align"]
5 |
6 | requirements:
7 | EnvVarRequirement:
8 | envDef:
9 | LC_ALL: C.UTF-8
10 | LANG: C.UTF-8
11 |
12 | inputs:
13 | ocr_text:
14 | type: File
15 | inputBinding:
16 | position: 1
17 | gs_text:
18 | type: File
19 | inputBinding:
20 | position: 2
21 | metadata:
22 | type: File
23 | inputBinding:
24 | position: 3
25 |
26 | outputs:
27 | out_file:
28 | type: File
29 | outputBinding:
30 | glob: "*.json"
31 |
--------------------------------------------------------------------------------
/ochre/cwl/remove-title-page.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 | baseCommand: ["python", "-m", "ochre.remove_title_page"]
5 |
6 | requirements:
7 | EnvVarRequirement:
8 | envDef:
9 | LC_ALL: C.UTF-8
10 | LANG: C.UTF-8
11 |
12 | inputs:
13 | without_tp:
14 | type: File
15 | inputBinding:
16 | position: 1
17 | with_tp:
18 | type: File
19 | inputBinding:
20 | position: 2
21 | num_lines:
22 | type: int?
23 | inputBinding:
24 | prefix: -n
25 |
26 | outputs:
27 | out_file:
28 | type: File
29 | outputBinding:
30 | glob: "*"
31 |
--------------------------------------------------------------------------------
/ochre/cwl/select-test-files.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 | baseCommand: ["python", "-m", "ochre.select_test_files"]
5 |
6 | requirements:
7 | EnvVarRequirement:
8 | envDef:
9 | LC_ALL: C.UTF-8
10 | LANG: C.UTF-8
11 |
12 | stdout: cwl.output.json
13 |
14 | inputs:
15 | in_dir:
16 | type: Directory
17 | inputBinding:
18 | position: 1
19 | datadivision:
20 | type: File
21 | inputBinding:
22 | position: 2
23 | name:
24 | type: string?
25 | inputBinding:
26 | prefix: --name
27 |
28 | outputs:
29 | out_files:
30 | type: File[]
31 |
--------------------------------------------------------------------------------
/ochre/cwl/icdar2017st-extract-text.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 |
5 | requirements:
6 | EnvVarRequirement:
7 | envDef:
8 | LC_ALL: C.UTF-8
9 | LANG: C.UTF-8
10 |
11 | baseCommand: ["python", "-m", "ochre.icdar2017st_extract_text"]
12 |
13 | inputs:
14 | in_file:
15 | type: File
16 | inputBinding:
17 | position: 1
18 |
19 | outputs:
20 | gs:
21 | type: File
22 | outputBinding:
23 | glob: "gs/*.txt"
24 | ocr:
25 | type: File
26 | outputBinding:
27 | glob: "ocr/*.txt"
28 | aligned:
29 | type: File
30 | outputBinding:
31 | glob: "*.json"
32 |
--------------------------------------------------------------------------------
/ochre/cwl/match-ocr-and-gs.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 |
5 | requirements:
6 | EnvVarRequirement:
7 | envDef:
8 | LC_ALL: C.UTF-8
9 | LANG: C.UTF-8
10 |
11 | baseCommand: ["python", "-m", "ochre.match_ocr_and_gs"]
12 |
13 | inputs:
14 | ocr_dir:
15 | type: Directory
16 | inputBinding:
17 | position: 0
18 | gs_dir:
19 | type: Directory
20 | inputBinding:
21 | position: 1
22 |
23 | outputs:
24 | ocr:
25 | type: Directory
26 | outputBinding:
27 |       glob: "ocr"
28 | gs:
29 | type: Directory
30 | outputBinding:
31 | glob: "gs"
32 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | # EditorConfig is awesome: http://EditorConfig.org
2 |
3 | # top-most EditorConfig file
4 | root = true
5 |
6 | # Unix-style newlines with a newline ending every file
7 | [*]
8 | end_of_line = lf
9 | insert_final_newline = true
10 | trim_trailing_whitespace = true
11 | charset = utf-8
12 |
13 | # Matches multiple files with brace expansion notation
14 | # Set default charset
15 | [*.{js,py,java,r,R,html,cwl}]
16 | indent_style = space
17 |
18 | # 4 space indentation
19 | [*.{py,java,r,R}]
20 | indent_size = 4
21 |
22 | # 2 space indentation
23 | [*.{js,json,yml,html,cwl}]
24 | indent_size = 2
25 |
26 | [*.{md,Rmd}]
27 | trim_trailing_whitespace = false
28 |
--------------------------------------------------------------------------------
/ochre/cwl/lstm-synced-correct-ocr.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 |
5 | requirements:
6 | EnvVarRequirement:
7 | envDef:
8 | LC_ALL: C.UTF-8
9 | LANG: C.UTF-8
10 |
11 | baseCommand: ["python", "-m", "ochre.lstm_synced_correct_ocr"]
12 |
13 | inputs:
14 | model:
15 | type: File
16 | inputBinding:
17 | position: 1
18 | charset:
19 | type: File
20 | inputBinding:
21 | position: 2
22 | txt:
23 | type: File
24 | inputBinding:
25 | position: 3
26 |
27 | outputs:
28 | corrected:
29 | type: File
30 | outputBinding:
31 | glob: "$(inputs.txt.basename)"
32 |
--------------------------------------------------------------------------------
/ochre/cwl/vudnc2ocr-and-gs.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 | baseCommand: ["python", "-m", "ochre.vudnc2ocr_and_gs"]
5 |
6 | requirements:
7 | EnvVarRequirement:
8 | envDef:
9 | LC_ALL: C.UTF-8
10 | LANG: C.UTF-8
11 |
12 | inputs:
13 | in_file:
14 | type: File
15 | inputBinding:
16 | position: 1
17 | out_dir:
18 | type: Directory?
19 | inputBinding:
20 | prefix: --out_dir=
21 | separate: false
22 |
23 | outputs:
24 | gs:
25 | type: File
26 | outputBinding:
27 | glob: "*.gs.txt"
28 | ocr:
29 | type: File
30 | outputBinding:
31 | glob: "*.ocr.txt"
32 |
--------------------------------------------------------------------------------
/tests/data/ocrevaluation-extract/in/empty.html:
--------------------------------------------------------------------------------
1 | General results
2 |
3 | | CER | n/a |
4 | | WER | n/a |
5 | | WER (order independent) | n/a |
6 |
7 | Difference spotting
8 |
9 | Error rate per character and type
10 |
11 | | Character | Hex code | Total | Spurious | Confused | Lost | Error rate |
12 | | n/a | n/a | n/a | n/a | n/a | n/a | n/a |
13 |
--------------------------------------------------------------------------------
/ochre/cwl/merge-json.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 |
5 | baseCommand: ["python", "-m", "ochre.merge_json"]
6 |
7 | requirements:
8 | InitialWorkDirRequirement:
9 | listing: $(inputs.in_files)
10 | EnvVarRequirement:
11 | envDef:
12 | LC_ALL: C.UTF-8
13 | LANG: C.UTF-8
14 |
15 | arguments:
16 | - valueFrom: $(runtime.outdir)
17 | position: 1
18 |
19 | inputs:
20 | in_files:
21 | type: File[]
22 | name:
23 | type: string?
24 | inputBinding:
25 | prefix: --name=
26 | separate: false
27 |
28 | outputs:
29 | merged:
30 | type: File
31 | outputBinding:
32 | glob: "*.csv"
33 |
--------------------------------------------------------------------------------
/ochre/cwl/create-word-mappings.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 |
5 | requirements:
6 | EnvVarRequirement:
7 | envDef:
8 | LC_ALL: C.UTF-8
9 | LANG: C.UTF-8
10 |
11 | baseCommand: ["python", "-m", "ochre.create_word_mappings"]
12 |
13 | inputs:
14 | saf:
15 | type: File
16 | inputBinding:
17 | position: 1
18 | alignments:
19 | type: File
20 | inputBinding:
21 | position: 2
22 | lowercase:
23 | type: boolean?
24 | inputBinding:
25 | prefix: --lowercase
26 | name:
27 | type: string?
28 | inputBinding:
29 | prefix: --name
30 |
31 | outputs:
32 | word_mapping:
33 | type: File
34 | outputBinding:
35 | glob: "*.csv"
36 |
--------------------------------------------------------------------------------
/ochre/cwl/ocrevaluation-performance-wf.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: Workflow
4 | inputs:
5 | gt: File
6 | ocr: File
7 | xmx: string?
8 | outputs:
9 | character_data:
10 | outputSource: ocrevaluation-extract/character_data
11 | type: File
12 | global_data:
13 | outputSource: ocrevaluation-extract/global_data
14 | type: File
15 | steps:
16 | ocrevaluation:
17 | run: https://raw.githubusercontent.com/nlppln/ocrevaluation-docker/master/ocrevaluation.cwl
18 | in:
19 | gt: gt
20 | ocr: ocr
21 | xmx: xmx
22 | out:
23 | - out_file
24 | ocrevaluation-extract:
25 | run: ocrevaluation-extract.cwl
26 | in:
27 | in_file: ocrevaluation/out_file
28 | out:
29 | - character_data
30 | - global_data
31 |
--------------------------------------------------------------------------------
/ochre/select_vudnc_files.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import click
3 | import os
4 | import json
5 |
6 | from nlppln.utils import cwl_file
7 |
8 |
9 | @click.command()
10 | @click.argument('dir_in', type=click.Path(exists=True))
11 | def command(dir_in):
12 | files_out = []
13 |
14 | newspapers = ['ad1951', 'nrc1950', 't1950', 'tr1950', 'vk1951']
15 |
16 | for np in newspapers:
17 | path = os.path.join(dir_in, np)
18 | for f in os.listdir(path):
19 | fi = os.path.join(path, f)
20 | if fi.endswith('.folia.xml'):
21 | files_out.append(cwl_file(fi))
22 |
23 | stdout_text = click.get_text_stream('stdout')
24 | stdout_text.write(json.dumps({'out_files': files_out}))
25 |
26 |
27 | if __name__ == '__main__':
28 | command()
29 |
--------------------------------------------------------------------------------
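The script writes its result as a cwl.output.json-style object on stdout, which vudnc-select-files.cwl captures via `stdout: cwl.output.json`. A minimal sketch of that contract, assuming cwl_file from nlppln wraps a path in the dict shape CWL expects for a File (the path below is hypothetical):

    import json

    # hypothetical File object, in the shape cwl_file presumably produces
    out_files = [{'class': 'File', 'path': '/data/ad1951/page1.folia.xml'}]
    print(json.dumps({'out_files': out_files}))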
/ochre/select_test_files.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import click
3 | import os
4 | import json
5 |
6 | from nlppln.utils import create_dirs, cwl_file
7 | from ochre.utils import get_files
8 |
9 |
10 | @click.command()
11 | @click.argument('in_dir', type=click.Path(exists=True))
12 | @click.argument('datadivision', type=click.File(encoding='utf-8'))
13 | @click.option('--name', '-n', default='test')
14 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
15 | def command(in_dir, datadivision, name, out_dir):
16 | create_dirs(out_dir)
17 |
18 | div = json.load(datadivision)
19 | files_out = [cwl_file(f) for f in get_files(in_dir, div, name)]
20 |
21 | stdout_text = click.get_text_stream('stdout')
22 | stdout_text.write(json.dumps({'out_files': files_out}))
23 |
24 | if __name__ == '__main__':
25 | command()
26 |
--------------------------------------------------------------------------------
/nematusA8P2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | #SBATCH --job-name=nematus_kb-ocr
4 | #SBATCH --output=nematus_A8P2.txt
5 | #
6 | #SBATCH --ntasks=1
7 | #SBATCH -C TitanX
8 | #SBATCH --time=00:05:00
9 | #SBATCH --gres=gpu:1
10 |
11 | srun mkdir -p /var/scratch/jvdzwaan/kb-ocr/A8P2/model/
12 | srun python ~/code/nematus/nematus/train.py \
13 |     --source_dataset /var/scratch/jvdzwaan/kb-ocr/A8P2/train.ocr \
14 |     --target_dataset /var/scratch/jvdzwaan/kb-ocr/A8P2/train.gs \
15 |     --embedding_size 256 \
16 |     --tie_encoder_decoder_embeddings \
17 |     --rnn_use_dropout \
18 |     --batch_size 100 \
19 |     --valid_source_dataset /var/scratch/jvdzwaan/kb-ocr/A8P2/val.ocr \
20 |     --valid_target_dataset /var/scratch/jvdzwaan/kb-ocr/A8P2/val.gs \
21 |     --dictionaries /var/scratch/jvdzwaan/kb-ocr/A8P1/train.json /var/scratch/jvdzwaan/kb-ocr/A8P1/train.json \
22 |     --valid_batch_size 100 \
23 |     --model /var/scratch/jvdzwaan/kb-ocr/A8P2/model/model \
24 |     --reload latest_checkpoint \
25 |     --save_freq 10000
26 |
--------------------------------------------------------------------------------
/nematusA8P3.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | #SBATCH --job-name=nematus_kb-ocr
4 | #SBATCH --output=nematus_A8P3.txt
5 | #
6 | #SBATCH --ntasks=1
7 | #SBATCH -C TitanX
8 | #SBATCH --time=00:05:00
9 | #SBATCH --gres=gpu:1
10 |
11 | srun mkdir -p /var/scratch/jvdzwaan/kb-ocr/A8P3/model/
12 | srun python ~/code/nematus/nematus/train.py \
13 |     --source_dataset /var/scratch/jvdzwaan/kb-ocr/A8P3/train.ocr \
14 |     --target_dataset /var/scratch/jvdzwaan/kb-ocr/A8P3/train.gs \
15 |     --embedding_size 256 \
16 |     --tie_encoder_decoder_embeddings \
17 |     --rnn_use_dropout \
18 |     --batch_size 100 \
19 |     --valid_source_dataset /var/scratch/jvdzwaan/kb-ocr/A8P3/val.ocr \
20 |     --valid_target_dataset /var/scratch/jvdzwaan/kb-ocr/A8P3/val.gs \
21 |     --dictionaries /var/scratch/jvdzwaan/kb-ocr/A8P1/train.json /var/scratch/jvdzwaan/kb-ocr/A8P1/train.json \
22 |     --valid_batch_size 100 \
23 |     --model /var/scratch/jvdzwaan/kb-ocr/A8P3/model/model \
24 |     --reload latest_checkpoint \
25 |     --save_freq 10000
26 |
--------------------------------------------------------------------------------
/ochre/cwl/ocr-transform.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 | baseCommand: ocr-transform
5 |
6 | requirements:
7 | - class: DockerRequirement
8 | dockerPull: ubma/ocr-fileformat
9 | - class: InlineJavascriptRequirement
10 |
11 | inputs:
12 | in_fmt:
13 | type: string
14 | inputBinding:
15 | position: 1
16 | out_fmt:
17 | type: string
18 | inputBinding:
19 | position: 2
20 | in_file:
21 | type: File
22 | inputBinding:
23 | position: 3
24 |
25 | stdout: |
26 | ${
27 | var nameroot = inputs.in_file.nameroot;
28 | var ext = 'xml';
29 | if(inputs.out_fmt == 'text'){
30 | ext = 'txt';
31 | }
32 | return nameroot + '.' + ext;
33 | }
34 |
35 | outputs:
36 | out_file:
37 | type: File
38 | outputBinding:
39 | glob: $(inputs.in_file.nameroot).*
40 |
--------------------------------------------------------------------------------
/ochre/dncvu_select_ocr_and_gs_texts.py:
--------------------------------------------------------------------------------
1 | import click
2 | import json
3 | import os
4 | import glob
5 |
6 | from nlppln.utils import cwl_file
7 |
8 |
9 | @click.command()
10 | @click.argument('in_dir', type=click.Path(exists=True))
11 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
12 | def select(in_dir, out_dir):
13 | gs_files = sorted(glob.glob('{}{}*.gs.txt'.format(in_dir, os.sep)))
14 | gs_files = [cwl_file(os.path.abspath(f)) for f in gs_files]
15 |
16 | ocr_files = sorted(glob.glob('{}{}*.ocr.txt'.format(in_dir, os.sep)))
17 | ocr_files = [cwl_file(os.path.abspath(f)) for f in ocr_files]
18 |
19 | stdout_text = click.get_text_stream('stdout')
20 | stdout_text.write(json.dumps({'ocr_files': ocr_files,
21 | 'gs_files': gs_files}))
22 |
23 |
24 | if __name__ == '__main__':
25 | select()
26 |
--------------------------------------------------------------------------------
/ochre/cwl/onmt-tokenize-text.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 | baseCommand: onmt-tokenize-text
5 |
6 | doc: |
7 | Use OpenNMT tokenizer offline.
8 | See http://opennmt.net/OpenNMT-tf/tokenization.html for more information.
9 |
10 | inputs:
11 | text:
12 | type: File
13 | delimiter:
14 | type: string?
15 | inputBinding:
16 | prefix: --delimiter
17 | tokenizer:
18 | type:
19 | type: enum
20 | symbols:
21 | - CharacterTokenizer
22 | - SpaceTokenizer
23 | default: SpaceTokenizer
24 | inputBinding:
25 | prefix: --tokenizer
26 | tokenizer_config:
27 | type: File?
28 | inputBinding:
29 | prefix: --tokenizer_config
30 | out_name:
31 | type: string
32 |
33 | stdin: $(inputs.text.path)
34 |
35 | stdout: $(inputs.out_name)
36 |
37 | outputs:
38 | tokenized: stdout
39 |
--------------------------------------------------------------------------------
/nematus.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | #SBATCH --job-name=nematus_kb-ocr
4 | #SBATCH --output=nematus_kb-ocr.txt
5 | #
6 | #SBATCH --ntasks=1
7 | #SBATCH -C TitanX
8 | #SBATCH --begin=20:00
9 | #SBATCH --time=12:00:00
10 | #SBATCH --gres=gpu:1
11 |
12 | srun mkdir -p /var/scratch/jvdzwaan/kb-ocr/A8P1/model/
13 | srun python ~/code/nematus/nematus/train.py \
14 |     --source_dataset /var/scratch/jvdzwaan/kb-ocr/A8P1/train.ocr \
15 |     --target_dataset /var/scratch/jvdzwaan/kb-ocr/A8P1/train.gs \
16 |     --embedding_size 256 \
17 |     --tie_encoder_decoder_embeddings \
18 |     --rnn_use_dropout \
19 |     --batch_size 100 \
20 |     --valid_source_dataset /var/scratch/jvdzwaan/kb-ocr/A8P1/val.ocr \
21 |     --valid_target_dataset /var/scratch/jvdzwaan/kb-ocr/A8P1/val.gs \
22 |     --dictionaries /var/scratch/jvdzwaan/kb-ocr/A8P1/train.json /var/scratch/jvdzwaan/kb-ocr/A8P1/train.json \
23 |     --valid_batch_size 100 \
24 |     --model /var/scratch/jvdzwaan/kb-ocr/A8P1/model/model \
25 |     --reload latest_checkpoint \
26 |     --save_freq 10000
27 |
--------------------------------------------------------------------------------
/transformer.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | #SBATCH --job-name=nematus_kb-ocr
4 | #SBATCH --output=nematus_kb-ocr.txt
5 | #
6 | #SBATCH --ntasks=1
7 | #SBATCH -C TitanX
8 | #SBATCH --begin=20:00
9 | #SBATCH --time=12:00:00
10 | #SBATCH --gres=gpu:1
11 |
12 | srun mkdir -p /var/scratch/jvdzwaan/kb-ocr/A8P1-transformer/model/
13 | srun python ~/code/nematus/nematus/train.py \
14 |     --model_type transformer \
15 |     --learning_schedule transformer \
16 |     --source_dataset /var/scratch/jvdzwaan/kb-ocr/A8P1/train.ocr \
17 |     --target_dataset /var/scratch/jvdzwaan/kb-ocr/A8P1/train.gs \
18 |     --embedding_size 256 \
19 |     --tie_encoder_decoder_embeddings \
20 |     --rnn_use_dropout \
21 |     --batch_size 100 \
22 |     --valid_source_dataset /var/scratch/jvdzwaan/kb-ocr/A8P1/val.ocr \
23 |     --valid_target_dataset /var/scratch/jvdzwaan/kb-ocr/A8P1/val.gs \
24 |     --dictionaries /var/scratch/jvdzwaan/kb-ocr/A8P1/train.json /var/scratch/jvdzwaan/kb-ocr/A8P1/train.json \
25 |     --valid_batch_size 100 \
26 |     --model /var/scratch/jvdzwaan/kb-ocr/A8P1-transformer/model/model \
27 |     --reload latest_checkpoint \
28 |     --save_freq 10000
29 |
--------------------------------------------------------------------------------
/ochre/count_chars.py:
--------------------------------------------------------------------------------
1 | import click
2 | import json
3 | import os
4 | import codecs
5 |
6 | from collections import Counter
7 |
8 |
9 | @click.command()
10 | @click.argument('in_file', type=click.Path(exists=True))
11 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
12 | def count_chars(in_file, out_dir):
13 | chars = Counter()
14 | with codecs.open(in_file, 'r', encoding='utf-8') as f:
15 | text = f.read().strip()
16 |
17 | for char in text:
18 | chars[char] += 1
19 |
20 | fname = os.path.basename(in_file.replace('.txt', '.json'))
21 | fname = os.path.join(out_dir, fname)
22 | with codecs.open(fname, 'w', encoding='utf-8') as f:
23 | json.dump(chars, f, indent=2)
24 |
25 | #for c, freq in chars.most_common():
26 | # print('{}\t{}'.format(repr(c), freq))
27 | # print(repr(c.encode('utf-8')))
28 | # print(len(c.encode('utf-8')))
29 |
30 |
31 | if __name__ == '__main__':
32 | count_chars()
33 |
--------------------------------------------------------------------------------
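The JSON written per input file is simply a character histogram. A sketch using the ocr.txt fixture text (not a repository file):

    from collections import Counter

    # count_chars.py dumps this Counter (a dict subclass) as <input name>.json
    print(Counter("This is an cxample text."))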
/ochre/cwl/onmt-main.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 | baseCommand: onmt-main
5 |
6 | inputs:
7 | run_type:
8 | type:
9 | type: enum
10 | symbols:
11 | - train_and_eval
12 | - train
13 | - eval
14 | - infer
15 | - export
16 | - score
17 | inputBinding:
18 | position: 0
19 | model_type:
20 | type:
21 | - "null"
22 | - type: enum
23 | symbols:
24 | - ListenAttendSpell
25 | - NMTBig
26 | - NMTMedium
27 | - NMTSmall
28 | - SeqTagger
29 | - Transformer
30 | - TransformerAAN
31 | - TransformerBig
32 | inputBinding:
33 | prefix: --model_type
34 | model:
35 | type: File?
36 | inputBinding:
37 | prefix: --model
38 | config:
39 | type: File[]
40 | inputBinding:
41 | prefix: --config
42 |
43 | outputs:
44 | out_dir:
45 | type: Directory
46 |
--------------------------------------------------------------------------------
/ochre/cwl/onmt-build-vocab.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: CommandLineTool
4 | baseCommand: onmt-build-vocab
5 |
6 | inputs:
7 | in_files:
8 | type: File[]
9 | inputBinding:
10 | position: 1
11 | save_vocab:
12 | type: string
13 | inputBinding:
14 | prefix: --save_vocab
15 | min_frequency:
16 | type: int?
17 | default: 1
18 | inputBinding:
19 | prefix: --min_frequency
20 | size:
21 | type: int?
22 | default: 0
23 | inputBinding:
24 | prefix: --size
25 | without_sequence_tokens:
26 |     type: boolean?
27 | inputBinding:
28 | prefix: --without_sequence_tokens
29 | tokenizer:
30 | type:
31 | type: enum
32 | symbols:
33 | - CharacterTokenizer
34 | - SpaceTokenizer
35 | default: SpaceTokenizer
36 | tokenizer_config:
37 | type: File?
38 | inputBinding:
39 | prefix: --tokenizer_config
40 |
41 | outputs:
42 | out_files:
43 | type: File
44 | outputBinding:
45 |       glob: $(inputs.save_vocab)
46 |
--------------------------------------------------------------------------------
/ochre/kb_tss_concat_files.py:
--------------------------------------------------------------------------------
1 | import click
2 | import codecs
3 | import os
4 | from collections import Counter
5 |
6 | from nlppln.utils import get_files, out_file_name
7 |
8 |
9 | @click.command()
10 | @click.argument('in_dir', type=click.Path(exists=True))
11 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
12 | def concat_files(in_dir, out_dir):
13 | in_files = get_files(in_dir)
14 |
15 | counts = Counter()
16 |
17 | for in_file in in_files:
18 | parts = os.path.basename(in_file).split(u'_')
19 | prefix = u'_'.join(parts[:2])
20 | counts[prefix] += 1
21 |
22 | out_file = out_file_name(out_dir, prefix, ext='txt')
23 |
24 | with codecs.open(in_file, 'r', encoding='utf-8') as fi:
25 | text = fi.read()
26 | text = text.replace(u'\n', u'')
27 | text = text.strip()
28 |
29 | with codecs.open(out_file, 'a', encoding='utf-8') as fo:
30 | fo.write(text)
31 | fo.write(u'\n')
32 |
33 |
34 | if __name__ == '__main__':
35 | concat_files()
36 |
--------------------------------------------------------------------------------
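Files are grouped on the first two '_'-separated parts of their base name, and each input file contributes one line to the <prefix>.txt it maps to. The grouping key, shown on a hypothetical KB file name:

    # hypothetical alto-derived text file name
    parts = 'ddd_010179031_0001_alto.txt'.split('_')
    print('_'.join(parts[:2]))  # 'ddd_010179031' -> appended to ddd_010179031.txt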
/tests/test_ocrerrors.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import pytest
3 |
4 | from ochre.ocrerrors import hyphenation_error, accent_error, real_word_error
5 |
6 |
7 | def test_hyphenation_error_true():
8 | row = {'ocr': u'gesta- tionneerd',
9 | 'gs': u'gestationneerd'
10 | }
11 | assert hyphenation_error(row) is True
12 |
13 |
14 | def test_hyphenation_error_false():
15 | row = {'ocr': u'gestationneerd',
16 | 'gs': u'gestationneerd'
17 | }
18 | assert hyphenation_error(row) is False
19 |
20 |
21 | def test_accent_error_true():
22 | row = {'ocr': u'patienten',
23 | 'gs': u'patiënten'
24 | }
25 | assert accent_error(row) is True
26 |
27 |
28 | def test_accent_error_true_ocr():
29 | row = {'ocr': u'patiënten',
30 | 'gs': u'patienten'
31 | }
32 | assert accent_error(row) is True
33 |
34 |
35 | def test_real_word_error_vs_accent_error():
36 | row = {'ocr': u'zeker',
37 | 'gs': u'zéker'
38 | }
39 | terms = ['zeker']
40 | assert real_word_error(row, terms) is False
41 | assert accent_error(row) is True
42 |
--------------------------------------------------------------------------------
/ochre/remove_empty_files.py:
--------------------------------------------------------------------------------
1 | import click
2 | import json
3 | import os
4 | import codecs
5 |
6 | from nlppln.utils import cwl_file, get_files
7 |
8 |
9 | @click.command()
10 | @click.argument('gs_dir', type=click.Path(exists=True))
11 | @click.argument('ocr_dir', type=click.Path(exists=True))
12 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
13 | def remove_empty_files(gs_dir, ocr_dir, out_dir):
14 | ocr_out = []
15 | gs_out = []
16 |
17 | ocr_files = get_files(ocr_dir)
18 | gs_files = get_files(gs_dir)
19 |
20 | for ocr, gs in zip(ocr_files, gs_files):
21 | with codecs.open(ocr, 'r', encoding='utf-8') as f:
22 | ocr_text = f.read()
23 |
24 | with codecs.open(gs, 'r', encoding='utf-8') as f:
25 | gs_text = f.read()
26 |
27 | if len(ocr_text.strip()) > 0 and len(gs_text.strip()) > 0:
28 | ocr_out.append(cwl_file(ocr))
29 | gs_out.append(cwl_file(gs))
30 |
31 | stdout_text = click.get_text_stream('stdout')
32 | stdout_text.write(json.dumps({'ocr': ocr_out, 'gs': gs_out}))
33 |
34 |
35 | if __name__ == '__main__':
36 | remove_empty_files()
37 |
--------------------------------------------------------------------------------
/ochre/merge_json.py:
--------------------------------------------------------------------------------
1 | import click
2 | import json
3 | import os
4 | import glob
5 | import codecs
6 | import uuid
7 | import pandas as pd
8 |
9 |
10 | @click.command()
11 | @click.argument('in_dir', type=click.Path(exists=True))
12 | @click.option('--name', '-n', default='{}.csv'.format(uuid.uuid4()))
13 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
14 | def merge_json(in_dir, name, out_dir):
15 | in_files = glob.glob('{}{}*.json'.format(in_dir, os.sep))
16 | idx = [os.path.basename(f) for f in in_files]
17 |
18 | dfs = []
19 | for in_file in in_files:
20 | with codecs.open(in_file, 'r', encoding='utf-8') as f:
21 | data = json.load(f)
22 |
23 | # Make sure it works if the json file contains a list of dictionaries
24 | # instead of just a single dictionary.
25 | if not isinstance(data, list):
26 | data = [data]
27 | for item in data:
28 | dfs.append(item)
29 |
30 | if len(idx) != len(dfs):
31 | result = pd.DataFrame(dfs)
32 | else:
33 | result = pd.DataFrame(dfs, index=idx)
34 | result = result.fillna(0)
35 |
36 | out_file = os.path.join(out_dir, name)
37 | result.to_csv(out_file, encoding='utf-8')
38 |
39 |
40 | if __name__ == '__main__':
41 | merge_json()
42 |
--------------------------------------------------------------------------------
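Each JSON file becomes one row of the resulting CSV, indexed by file name when the row and file counts line up. A minimal sketch with two hypothetical metadata dicts:

    import pandas as pd

    rows = [{'editDistance': 1}, {'editDistance': 3, 'length': 24}]
    idx = ['a.json', 'b.json']
    # missing values are filled with 0, as in merge_json.py's indexed branch
    print(pd.DataFrame(rows, index=idx).fillna(0))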
/tests/test_datagen.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from ochre.datagen import DataGenerator
4 |
5 |
6 | def dgen():
7 | ocr_seqs = ['abc', 'ab', 'ca8']
8 | gs_seqs = ['abc', 'bb', 'ca']
9 | p_char = 'P'
10 | oov_char = '@'
11 | n = 3
12 | ci = {'a': 0, 'b': 1, 'c': 2, p_char: 3, oov_char: 4}
13 | dg = DataGenerator(xData=ocr_seqs, yData=gs_seqs, char_to_int=ci,
14 | seq_length=n, padding_char=p_char, oov_char=oov_char,
15 | batch_size=1, shuffle=False)
16 |
17 | return dg
18 |
19 |
20 | def test_dg():
21 | dg = dgen()
22 |
23 | assert dg.n_vocab == len(dg.char_to_int)
24 | assert len(dg) == 3
25 |
26 | x, y = dg[0]
27 |
28 | print(x)
29 | print(y)
30 |
31 | assert np.array_equal(x[0], np.array([0, 1, 2]))
32 | assert np.array_equal(y[0], np.array([[1, 0, 0, 0, 0],
33 | [0, 1, 0, 0, 0],
34 | [0, 0, 1, 0, 0]]))
35 |
36 |
37 | def test_convert_sample():
38 | dg = dgen()
39 |
40 | cases = {
41 | 'aaa': np.array([0, 0, 0]),
42 | 'a': np.array([0, 3, 3]),
43 | 'b8': np.array([1, 4, 3])
44 | }
45 |
46 | for inp, outp in cases.items():
47 | assert np.array_equal(dg._convert_sample(inp), outp)
48 |
--------------------------------------------------------------------------------
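test_convert_sample pins down the padding and out-of-vocabulary behaviour. A minimal sketch of what _convert_sample must do to satisfy those cases (an assumption; the real implementation lives in the DataGenerator):

    def convert_sample(text, char_to_int, seq_length, p_char='P', oov_char='@'):
        # pad to seq_length with the padding char; map unknown chars to OOV
        padded = text + p_char * (seq_length - len(text))
        return [char_to_int.get(c, char_to_int[oov_char])
                for c in padded[:seq_length]]

    ci = {'a': 0, 'b': 1, 'c': 2, 'P': 3, '@': 4}
    assert convert_sample('b8', ci, 3) == [1, 4, 3]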
/notebooks/lowercase-directory-workflow.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import nlppln\n",
20 | "\n",
21 | "with nlppln.WorkflowGenerator() as wf:\n",
22 | " wf.load(steps_dir='../cwl/')\n",
23 | " print(wf.list_steps())\n",
24 | "\n",
25 | " in_dir = wf.add_input(in_dir='Directory')\n",
26 | " dir_name = wf.add_input(dir_name='string', default='gs_lowercase')\n",
27 | "\n",
28 | " txt_files = wf.ls(in_dir=in_dir)\n",
29 | " lowercase = wf.lowercase(in_file=txt_files, scatter='in_file', scatter_method='dotproduct')\n",
30 | " out_dir = wf.save_files_to_dir(dir_name=dir_name, in_files=lowercase)\n",
31 | "\n",
32 | " wf.add_outputs(out_dir=out_dir)\n",
33 | "\n",
34 | " wf.save('../cwl/lowercase-directory.cwl')"
35 | ]
36 | }
37 | ],
38 | "metadata": {
39 | "language_info": {
40 | "name": "python",
41 | "pygments_lexer": "ipython3"
42 | }
43 | },
44 | "nbformat": 4,
45 | "nbformat_minor": 2
46 | }
47 |
--------------------------------------------------------------------------------
/ochre/char_align.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import click
3 | import codecs
4 | import os
5 | import json
6 |
7 | from nlppln.utils import create_dirs, out_file_name
8 |
9 | from .edlibutils import align_characters
10 |
11 |
12 | @click.command()
13 | @click.argument('ocr_text', type=click.File(mode='r', encoding='utf-8'))
14 | @click.argument('gs_text', type=click.File(mode='r', encoding='utf-8'))
15 | @click.argument('metadata', type=click.File(mode='r', encoding='utf-8'))
16 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
17 | def command(ocr_text, gs_text, metadata, out_dir):
18 | create_dirs(out_dir)
19 |
20 | ocr = ocr_text.read()
21 | gs = gs_text.read()
22 | md = json.load(metadata)
23 |
24 | check = True
25 | # Too many strange characters, so disable sanity check
26 | if len(set(ocr+gs)) > 127:
27 | check = False
28 |
29 | ocr_a, gs_a = align_characters(ocr, gs, md['cigar'], sanity_check=check)
30 |
31 | out_file = out_file_name(out_dir, md['doc_id'], 'json')
32 | with codecs.open(out_file, 'wb', encoding='utf-8') as f:
33 | try:
34 | json.dump({'ocr': ocr_a, 'gs': gs_a}, f, encoding='utf-8')
35 | except TypeError:
36 | json.dump({'ocr': ocr_a, 'gs': gs_a}, f)
37 |
38 |
39 | if __name__ == '__main__':
40 | command()
41 |
--------------------------------------------------------------------------------
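The 'cigar' field in the metadata comes from an earlier edlib alignment step (see align-texts-wf.cwl). A sketch of producing one with the test fixture texts:

    import edlib

    ocr = "This is an cxample text."  # tests/data/ocrevaluation/in/ocr.txt
    gs = "This is an example text."   # tests/data/ocrevaluation/in/gs.txt
    res = edlib.align(ocr, gs, task='path')
    print(res['editDistance'])  # 1
    print(res['cigar'])         # extended CIGAR, e.g. '11=1X12='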
/ochre/cwl/icdar2017st-extract-data.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: Workflow
4 | requirements:
5 | - class: ScatterFeatureRequirement
6 | inputs:
7 | in_dir: Directory
8 | ocr_dir_name: string
9 | gs_dir_name: string
10 | aligned_dir_name: string
11 | outputs:
12 | gs_dir:
13 | type: Directory
14 | outputSource: save-files-to-dir-6/out
15 | ocr_dir:
16 | type: Directory
17 | outputSource: save-files-to-dir-7/out
18 | aligned_dir:
19 | type: Directory
20 | outputSource: save-files-to-dir-8/out
21 | steps:
22 | ls-3:
23 | run: ls.cwl
24 | in:
25 | in_dir: in_dir
26 | out:
27 | - out_files
28 | icdar2017st-extract-text:
29 | run: icdar2017st-extract-text.cwl
30 | in:
31 | in_file: ls-3/out_files
32 | out:
33 | - aligned
34 | - gs
35 | - ocr
36 | scatter:
37 | - in_file
38 | scatterMethod: dotproduct
39 | save-files-to-dir-6:
40 | run: save-files-to-dir.cwl
41 | in:
42 | dir_name: gs_dir_name
43 | in_files: icdar2017st-extract-text/gs
44 | out:
45 | - out
46 | save-files-to-dir-7:
47 | run: save-files-to-dir.cwl
48 | in:
49 | dir_name: ocr_dir_name
50 | in_files: icdar2017st-extract-text/ocr
51 | out:
52 | - out
53 | save-files-to-dir-8:
54 | run: save-files-to-dir.cwl
55 | in:
56 | dir_name: aligned_dir_name
57 | in_files: icdar2017st-extract-text/aligned
58 | out:
59 | - out
60 |
--------------------------------------------------------------------------------
/ochre/cwl/kb-tss-preprocess-single-dir.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: Workflow
4 | requirements:
5 | - class: ScatterFeatureRequirement
6 | inputs:
7 | in_dir: Directory
8 | recursive:
9 | default: true
10 | type: boolean
11 | endswith:
12 | default: alto.xml
13 | type: string
14 | element:
15 | default:
16 | - SP
17 | type: string[]
18 | in_fmt:
19 | default: alto
20 | type: string
21 | out_fmt:
22 | default: text
23 | type: string
24 | outputs:
25 | text_files:
26 | type:
27 | type: array
28 | items: File
29 | outputSource: kb-tss-concat-files/out_files
30 | steps:
31 | ls:
32 | run: ls.cwl
33 | in:
34 | endswith: endswith
35 | in_dir: in_dir
36 | recursive: recursive
37 | out:
38 | - out_files
39 | remove-xml-elements-1:
40 | run: remove-xml-elements.cwl
41 | in:
42 | xml_file: ls/out_files
43 | element: element
44 | out:
45 | - out_file
46 | scatter:
47 | - xml_file
48 | scatterMethod: dotproduct
49 | ocr-transform-1:
50 | run: ocr-transform.cwl
51 | in:
52 | out_fmt: out_fmt
53 | in_file: remove-xml-elements-1/out_file
54 | in_fmt: in_fmt
55 | out:
56 | - out_file
57 | scatter:
58 | - in_file
59 | scatterMethod: dotproduct
60 | kb-tss-concat-files:
61 | run: kb-tss-concat-files.cwl
62 | in:
63 | in_files: ocr-transform-1/out_file
64 | out:
65 | - out_files
66 |
--------------------------------------------------------------------------------
/ochre/cwl/align-texts-wf.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: Workflow
4 | requirements:
5 | - class: ScatterFeatureRequirement
6 | inputs:
7 | gs: File[]
8 | ocr: File[]
9 | align_m:
10 | default: merged_metadata.csv
11 | type: string
12 | align_c:
13 | default: merged_changes.csv
14 | type: string
15 | outputs:
16 | alignments:
17 | outputSource: char-align-1/out_file
18 | type:
19 | type: array
20 | items: File
21 | metadata:
22 | outputSource: merge-json-2/merged
23 | type: File
24 | changes:
25 | outputSource: merge-json-3/merged
26 | type: File
27 | steps:
28 | align-1:
29 | run: https://raw.githubusercontent.com/nlppln/edlib-align/master/align.cwl
30 | in:
31 | file1: ocr
32 | file2: gs
33 | out:
34 | - changes
35 | - metadata
36 | scatter:
37 | - file1
38 | - file2
39 | scatterMethod: dotproduct
40 | merge-json-2:
41 | run: merge-json.cwl
42 | in:
43 | in_files: align-1/metadata
44 | name: align_m
45 | out:
46 | - merged
47 | merge-json-3:
48 | run: merge-json.cwl
49 | in:
50 | in_files: align-1/changes
51 | name: align_c
52 | out:
53 | - merged
54 | char-align-1:
55 | run: char-align.cwl
56 | in:
57 | gs_text: gs
58 | metadata: align-1/metadata
59 | ocr_text: ocr
60 | out:
61 | - out_file
62 | scatter:
63 | - gs_text
64 | - ocr_text
65 | - metadata
66 | scatterMethod: dotproduct
67 |
--------------------------------------------------------------------------------
/ochre/create_data_division.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import click
3 | import codecs
4 | import os
5 | import json
6 | import numpy as np
7 |
8 | from nlppln.utils import create_dirs, get_files
9 |
10 |
11 | @click.command()
12 | @click.argument('in_dir', type=click.Path(exists=True))
13 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
14 | @click.option('--out_name', default='datadivision.json')
15 | def command(in_dir, out_dir, out_name):
16 |     """Create a division of the data into train, test, and validation sets.
17 |
18 | The result is stored to a JSON file, so it can be reused.
19 | """
20 | # TODO: make seed and percentages options
21 | SEED = 4
22 | TEST_PERCENTAGE = 10
23 | VAL_PERCENTAGE = 10
24 |
25 | create_dirs(out_dir)
26 |
27 | in_files = get_files(in_dir)
28 |
29 | np.random.seed(SEED)
30 | np.random.shuffle(in_files)
31 |
32 | n_test = int(len(in_files)/100.0 * TEST_PERCENTAGE)
33 | n_val = int(len(in_files)/100.0 * VAL_PERCENTAGE)
34 |
35 | validation_texts = in_files[0:n_val]
36 | test_texts = in_files[n_val:n_val+n_test]
37 | train_texts = in_files[n_val+n_test:]
38 |
39 | division = {
40 | 'train': [os.path.basename(t) for t in train_texts],
41 | 'val': [os.path.basename(t) for t in validation_texts],
42 | 'test': [os.path.basename(t) for t in test_texts]
43 | }
44 |
45 | out_file = os.path.join(out_dir, out_name)
46 | with codecs.open(out_file, 'wb', encoding='utf-8') as f:
47 | json.dump(division, f, indent=4)
48 |
49 |
50 | if __name__ == '__main__':
51 | command()
52 |
--------------------------------------------------------------------------------
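With the fixed seed and percentages above, the split is deterministic. The same logic on ten hypothetical file names (10% validation, 10% test, rest train):

    import numpy as np

    files = ['doc{}.txt'.format(i) for i in range(10)]
    np.random.seed(4)
    np.random.shuffle(files)
    n_test = n_val = int(len(files) / 100.0 * 10)  # one file each
    print({'val': files[:n_val],
           'test': files[n_val:n_val + n_test],
           'train': files[n_val + n_test:]})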
/ochre/match_ocr_and_gs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Select files that occur in both input directories and save them.
3 |
4 | Given two input directories, create ocr and gs output directories containing
5 | the files that occur in both.
6 |
7 | Files are actually copied, to prevent problems with CWL outputs.
8 | """
9 | import os
10 | import click
11 | import shutil
12 |
13 | from nlppln.utils import get_files, create_dirs, out_file_name
14 |
15 |
16 | def copy_file(fi, name, out_dir, dest):
17 | fo = out_file_name(os.path.join(out_dir, dest), name)
18 | create_dirs(fo, is_file=True)
19 | shutil.copy2(fi, fo)
20 |
21 |
22 | @click.command()
23 | @click.argument('ocr_dir', type=click.Path(exists=True))
24 | @click.argument('gs_dir', type=click.Path(exists=True))
25 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
26 | def match_ocr_and_gs(ocr_dir, gs_dir, out_dir):
27 | create_dirs(out_dir)
28 |
29 | ocr_files = {os.path.basename(f): f for f in get_files(ocr_dir)}
30 | gs_files = {os.path.basename(f): f for f in get_files(gs_dir)}
31 |
32 | ocr = set(ocr_files.keys())
33 | gs = set(gs_files.keys())
34 |
35 | if len(ocr) == 0:
36 | raise ValueError('No ocr files in directory "{}".'.format(ocr_dir))
37 | if len(gs) == 0:
38 | raise ValueError('No gs files in directory "{}".'.format(gs_dir))
39 |
40 | keep = ocr.intersection(gs)
41 |
42 | if len(keep) == 0:
43 | raise ValueError('No matching ocr and gs files.')
44 |
45 | for name in keep:
46 | copy_file(ocr_files[name], name, out_dir, 'ocr')
47 | copy_file(gs_files[name], name, out_dir, 'gs')
48 |
49 |
50 | if __name__ == '__main__':
51 | match_ocr_and_gs()
52 |
--------------------------------------------------------------------------------
/tests/test_rmgarbage.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from ochre.rmgarbage import rmgarbage_long, rmgarbage_alphanumeric, \
3 | rmgarbage_row, rmgarbage_vowels, \
4 | rmgarbage_punctuation, rmgarbage_case
5 |
6 |
7 | def test_rmgarbage_long():
8 | assert rmgarbage_long('short') is False
9 | long_string = 'Ii'*21
10 | assert rmgarbage_long(long_string) is True
11 |
12 |
13 | def test_rmgarbage_alphanumeric():
14 | assert rmgarbage_alphanumeric('.M~y~l~i~c~.I~') is True
15 | assert rmgarbage_alphanumeric('______________J.~:ys~,.j}ss.') is True
16 | assert rmgarbage_alphanumeric('14.9tv="~;ia.~:..') is True
17 | assert rmgarbage_alphanumeric('.') is False
18 |
19 |
20 | def test_rmgarbage_row():
21 | assert rmgarbage_row('111111111111111111111111') is True
22 | assert rmgarbage_row('Pnlhrrrr') is True
23 | assert rmgarbage_row('11111k1U1M.il.uu4ailuidt]i') is True
24 |
25 |
26 | def test_rmgarbage_row_non_ascii():
27 | assert rmgarbage_row(u'ÐÐÐÐææææ') is True
28 |
29 |
30 | def test_rmgarbage_vowels():
31 | assert rmgarbage_vowels('CslwWkrm') is True
32 | assert rmgarbage_vowels('Tptpmn') is True
33 | assert rmgarbage_vowels('Thlrlnd') is True
34 |
35 |
36 | def test_rmgarbage_punctuation():
37 | assert rmgarbage_punctuation('ab,cde,fg') is False
38 | assert rmgarbage_punctuation('btkvdy@us1s8') is True
39 | assert rmgarbage_punctuation('w.a.e.~tcet~oe~') is True
40 | assert rmgarbage_punctuation('iA,1llfllwl~flII~N') is True
41 |
42 |
43 | def test_rmgarbage_case():
44 | assert rmgarbage_case('bAa') is True
45 | assert rmgarbage_case('aepauWetectronic') is True
46 | assert rmgarbage_case('sUatigraphic') is True
47 |
--------------------------------------------------------------------------------
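A minimal sketch of the rule test_rmgarbage_long exercises, assuming the 'L' rule from Taghva et al.'s rmgarbage paper (strings longer than 40 characters are garbage); the real implementation lives in ochre/rmgarbage.py:

    def rmgarbage_long(string, threshold=40):
        # rule L: very long strings are almost always OCR garbage
        return len(string) > threshold

    assert rmgarbage_long('short') is False
    assert rmgarbage_long('Ii' * 21) is True  # 42 characters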
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | from ochre.utils import read_text_to_predict, get_char_to_int, \
2 | to_space_tokenized
3 |
4 |
5 | def test_read_text_to_predict_no_embedding():
6 | text = u'aaaaaa'
7 | seq_length = 3
8 | lowercase = True
9 | char_to_int = get_char_to_int(u'a\n')
10 | n_vocab = len(char_to_int)
11 | padding_char = u'\n'
12 | predict_chars = 0
13 | step = 1
14 | char_embedding = False
15 |
16 | result = read_text_to_predict(text, seq_length, lowercase, n_vocab,
17 | char_to_int, padding_char,
18 | predict_chars=predict_chars, step=step,
19 | char_embedding=char_embedding)
20 |
21 | # The result contains one hot encoded sequences
22 | assert result.dtype == bool
23 | assert result.shape == (4, seq_length, n_vocab)
24 |
25 |
26 | def test_read_text_to_predict_embedding():
27 | text = u'aaaaaa'
28 | seq_length = 3
29 | lowercase = True
30 | char_to_int = get_char_to_int(u'a\n')
31 | n_vocab = len(char_to_int)
32 | padding_char = u'\n'
33 | predict_chars = 0
34 | step = 1
35 | char_embedding = True
36 |
37 | result = read_text_to_predict(text, seq_length, lowercase, n_vocab,
38 | char_to_int, padding_char,
39 | predict_chars=predict_chars, step=step,
40 | char_embedding=char_embedding)
41 |
42 | # The result contains lists of ints
43 | assert result.dtype == int
44 | assert result.shape == (4, seq_length)
45 |
46 |
47 | def test_to_space_tokenized():
48 | result = to_space_tokenized('Dit is een test.')
49 | assert 'D i t i s e e n t e s t .' == result
50 |
--------------------------------------------------------------------------------
/ochre/sac2gs_and_ocr.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import click
3 | import codecs
4 | import json
5 | import os
6 | from nlppln.utils import create_dirs, get_files, cwl_file
7 |
8 |
9 | @click.command()
10 | @click.argument('in_dir', type=click.Path(exists=True))
11 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
12 | def sac2gs_and_ocr(in_dir, out_dir):
13 | result = {}
14 | result['gs_de'] = []
15 | result['ocr_de'] = []
16 | result['gs_fr'] = []
17 | result['ocr_fr'] = []
18 |
19 | files = {}
20 |
21 | for i in range(1864, 1900):
22 | try:
23 | in_files = get_files(os.path.join(in_dir, str(i)))
24 | for fi in in_files:
25 | language = 'de'
26 | typ = 'gs'
27 | bn = os.path.basename(fi)
28 |
29 | if bn.endswith('ocr'):
30 | typ = 'ocr'
31 | if 'fr' in bn:
32 | language = 'fr'
33 | with codecs.open(fi, encoding='utf-8') as f:
34 | text = f.read()
35 | fname = '{}-{}-{}.txt'.format(i, language, typ)
36 | out_file = os.path.join(out_dir, fname)
37 | create_dirs(out_file)
38 | with codecs.open(out_file, 'a', encoding='utf-8') as fo:
39 | fo.write(text)
40 | if out_file not in files:
41 | label = '{}_{}'.format(typ, language)
42 | result[label].append(cwl_file(out_file))
43 | files[out_file] = None
44 | except OSError:
45 | pass
46 |
47 | stdout_text = click.get_text_stream('stdout')
48 | stdout_text.write(json.dumps(result))
49 |
50 |
51 | if __name__ == '__main__':
52 | sac2gs_and_ocr()
53 |
--------------------------------------------------------------------------------
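A note on what the script above produces, sketched in Python: for every year directory found (1864 through 1899), the page texts are concatenated into one file per language and type, and a CWL-style JSON mapping is written to stdout. The year values below are illustrative.

# Naming scheme used by sac2gs_and_ocr (years illustrative):
for year in (1864, 1899):
    for language in ('de', 'fr'):
        for typ in ('gs', 'ocr'):
            print('{}-{}-{}.txt'.format(year, language, typ))

# The JSON on stdout groups the written files by type and language:
#   {"gs_de": [...], "ocr_de": [...], "gs_fr": [...], "ocr_fr": [...]}
# where each list entry is a CWL File object created by nlppln's cwl_file.

--------------------------------------------------------------------------------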
/ochre/icdar2017st_extract_text.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import click
3 | import codecs
4 | import os
5 | import json
6 |
7 | from nlppln.utils import create_dirs, out_file_name
8 |
9 |
10 | def to_character_list(string):
11 |     chars = []  # '@' and '#' are alignment placeholders, mapped to ''
12 |     for c in string:
13 |         if c in ('@', '#'):
14 |             chars.append('')
15 |         else:
16 |             chars.append(c)
17 |     return chars
18 |
19 |
20 | @click.command()
21 | @click.argument('in_file', type=click.File(encoding='utf-8'))
22 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
23 | def command(in_file, out_dir):
24 | create_dirs(out_dir)
25 |
26 | lines = in_file.readlines()
27 | # OCR_toInput: lines[0][:14]
28 | # OCR_aligned: lines[1][:14]
29 | # GS_aligned: lines[2][:14]
30 | ocr = to_character_list(lines[1][14:].strip())
31 | gs = to_character_list(lines[2][14:].strip())
32 |
33 | # Write texts
34 | out_file = out_file_name(os.path.join(out_dir, 'ocr'), os.path.basename(in_file.name))
35 | print(out_file)
36 |     create_dirs(out_file, is_file=True)
37 | with codecs.open(out_file, 'wb', encoding='utf-8') as f:
38 | f.write(u''.join(ocr))
39 |
40 | out_file = out_file_name(os.path.join(out_dir, 'gs'), os.path.basename(in_file.name))
41 | print(out_file)
42 |     create_dirs(out_file, is_file=True)
43 | with codecs.open(out_file, 'wb', encoding='utf-8') as f:
44 | f.write(u''.join(gs))
45 |
46 | out_file = out_file_name(out_dir, os.path.basename(in_file.name), 'json')
47 | with codecs.open(out_file, 'wb', encoding='utf-8') as f:
48 | json.dump({'ocr': ocr, 'gs': gs}, f, encoding='utf-8', indent=4)
49 |
50 |
51 | if __name__ == '__main__':
52 |     command()
53 |
--------------------------------------------------------------------------------
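A quick illustration of to_character_list: the shared-task files use '@' and '#' as alignment placeholders, and mapping them to empty strings keeps the character positions intact while u''.join() drops them again.

from ochre.icdar2017st_extract_text import to_character_list

chars = to_character_list(u'ex@mple')
print(chars)            # ['e', 'x', '', 'm', 'p', 'l', 'e']
print(u''.join(chars))  # exmple

--------------------------------------------------------------------------------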
/ochre/remove_title_page.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import click
3 | import os
4 | import edlib
5 | import warnings
6 | import codecs
7 | import shutil
8 |
9 | from nlppln.utils import out_file_name
10 |
11 |
12 | @click.command()
13 | @click.argument('txt_file_without_tp', type=click.File(encoding='utf-8'))
14 | @click.argument('txt_file_with_tp', type=click.File(encoding='utf-8'))
15 | @click.option('--num_lines', '-n', default=100)
16 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
17 | def remove_title_page(txt_file_without_tp, txt_file_with_tp, num_lines,
18 | out_dir):
19 | result = None
20 | lines_without = txt_file_without_tp.readlines()
21 | lines_with = txt_file_with_tp.readlines()
22 |
23 | without_txt = ''.join(lines_without[:num_lines]).lower()
24 | with_txt = ''.join(lines_with[:num_lines]).lower()
25 | res = edlib.align(without_txt, with_txt)
26 | prev_ld = res['editDistance']
27 |
28 |     for i in range(num_lines):
29 |         # without_txt is unchanged in the loop; only the candidate shifts
30 |         with_txt = ''.join(lines_with[i:num_lines]).lower()
31 |
32 | res = edlib.align(without_txt, with_txt)
33 | ld = res['editDistance']
34 |
35 | if ld > prev_ld:
36 | result = ''.join(lines_with[i-1:])
37 | break
38 | elif ld < prev_ld:
39 | prev_ld = ld
40 |
41 | if result is None:
42 | warnings.warn('No title page found')
43 | out_file = out_file_name(out_dir, txt_file_with_tp.name)
44 | shutil.copy2(txt_file_with_tp.name, out_file)
45 | else:
46 | out_file = out_file_name(out_dir, txt_file_with_tp.name)
47 |         with codecs.open(out_file, 'w', encoding='utf-8') as f:
48 | f.write(result)
49 |
50 |
51 | if __name__ == '__main__':
52 | remove_title_page()
53 |
--------------------------------------------------------------------------------
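The scan above can be illustrated with edlib directly: the edit distance between the text without the title page and ever-shorter suffixes of the text with it first shrinks, then grows once real text is being dropped. A minimal sketch with toy data:

import edlib

without = 'Text starts here.\nSecond line.\n'
lines_with = ['This is the title page\n', 'Text starts here.\n', 'Second line.\n']

for i in range(len(lines_with)):
    candidate = ''.join(lines_with[i:])
    print(i, edlib.align(without.lower(), candidate.lower())['editDistance'])
# The distance is smallest at i == 1, where the real text starts; the
# command above stops as soon as the distance starts to grow again and
# keeps everything from the previous offset.

--------------------------------------------------------------------------------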
/ochre/cwl/post-correct-dir.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: Workflow
4 | requirements:
5 | - class: ScatterFeatureRequirement
6 | inputs:
7 | in_dir: Directory
8 | charset: File
9 | model: File
10 | outputs:
11 | corrected:
12 | type:
13 | items: File
14 | type: array
15 | outputSource: lstm-synced-correct-ocr/corrected
16 | steps:
17 | ls-4:
18 | run:
19 | cwlVersion: v1.0
20 | class: CommandLineTool
21 | baseCommand: [python, -m, nlppln.commands.ls]
22 |
23 | inputs:
24 | - type: Directory
25 | inputBinding:
26 | position: 2
27 | id: _:ls-4#in_dir
28 | - type:
29 | - 'null'
30 | - boolean
31 | inputBinding:
32 | prefix: --recursive
33 |
34 | id: _:ls-4#recursive
35 | stdout: cwl.output.json
36 |
37 | outputs:
38 | - type:
39 | type: array
40 | items: File
41 | id: _:ls-4#out_files
42 | id: _:ls-4
43 | in:
44 | in_dir: in_dir
45 | out:
46 | - out_files
47 | lstm-synced-correct-ocr:
48 | run:
49 | cwlVersion: v1.0
50 | class: CommandLineTool
51 | baseCommand: [python, -m, ochre.lstm_synced_correct_ocr]
52 |
53 | inputs:
54 | - type: File
55 | inputBinding:
56 | position: 2
57 | id: _:lstm-synced-correct-ocr#charset
58 | - type: File
59 | inputBinding:
60 | position: 1
61 | id: _:lstm-synced-correct-ocr#model
62 | - type: File
63 | inputBinding:
64 | position: 3
65 |
66 | id: _:lstm-synced-correct-ocr#txt
67 | outputs:
68 | - type: File
69 | outputBinding:
70 | glob: $(inputs.txt.basename)
71 | id: _:lstm-synced-correct-ocr#corrected
72 | id: _:lstm-synced-correct-ocr
73 | in:
74 | txt: ls-4/out_files
75 | model: model
76 | charset: charset
77 | out:
78 | - corrected
79 | scatter:
80 | - txt
81 | scatterMethod: dotproduct
82 |
--------------------------------------------------------------------------------
/ochre/clin2018st_extract_text.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import click
3 | import codecs
4 | import json
5 | import os
6 |
7 | from nlppln.utils import create_dirs, remove_ext
8 |
9 |
10 | @click.command()
11 | @click.argument('json_file', type=click.File(encoding='utf-8'))
12 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
13 | def clin2018st_extract_text(json_file, out_dir):
14 | create_dirs(out_dir)
15 |
16 | corrections = {}
17 | gs_text = []
18 | text_with_errors = []
19 |
20 | text = json.load(json_file)
21 | for w in text['corrections']:
22 | span = w['span']
23 | # TODO: fix 'after'
24 |         if 'after' in w:
25 | print('Found "after" in {}.'.format(os.path.basename(json_file.name)))
26 | for i, w_id in enumerate(span):
27 | corrections[w_id] = {}
28 | if i == 0:
29 | corrections[w_id]['text'] = w['text']
30 | else:
31 | corrections[w_id]['text'] = u''
32 | corrections[w_id]['last'] = False
33 | if i == (len(span) - 1):
34 | corrections[w_id]['last'] = True
35 |
36 | for w in text['words']:
37 | w_id = w['id']
38 | gs_text.append(w['text'])
39 |         if w_id in corrections:
40 | text_with_errors.append(corrections[w_id]['text'])
41 | else:
42 | text_with_errors.append(w['text'])
43 | if w['space']:
44 | gs_text.append(u' ')
45 | text_with_errors.append(u' ')
46 |
47 | gs_file = remove_ext(json_file.name)
48 | gs_file = os.path.join(out_dir, '{}-gs.txt'.format(gs_file))
49 | with codecs.open(gs_file, 'wb', encoding='utf-8') as f:
50 | f.write(u''.join(gs_text))
51 |
52 | err_file = remove_ext(json_file.name)
53 | err_file = os.path.join(out_dir, '{}-errors.txt'.format(err_file))
54 | with codecs.open(err_file, 'wb', encoding='utf-8') as f:
55 | f.write(u''.join(text_with_errors))
56 |
57 |
58 | if __name__ == '__main__':
59 | clin2018st_extract_text()
60 |
--------------------------------------------------------------------------------
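The JSON structure this script expects is not documented here; the shape below is inferred from the fields it reads ('words' with id/text/space, 'corrections' with span/text) and is illustrative only.

# Inferred input shape, illustrative only.
example = {
    'words': [
        {'id': 'w1', 'text': 'Dit', 'space': True},
        {'id': 'w2', 'text': 'is', 'space': True},
        {'id': 'w3', 'text': 'een', 'space': True},
        {'id': 'w4', 'text': 'test', 'space': False},
    ],
    'corrections': [
        # the OCR read 'e en' where the gold standard has 'een'
        {'span': ['w3'], 'text': 'e en'},
    ],
}
# gs text:     'Dit is een test'
# errors text: 'Dit is e en test'

--------------------------------------------------------------------------------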
/ochre/cwl/word-mapping-wf.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: Workflow
4 | doc: This workflow is meant to be used as a subworkflow.
5 | requirements:
6 | - class: SubworkflowFeatureRequirement
7 | - class: ScatterFeatureRequirement
8 | inputs:
9 | gs_files: File[]
10 | ocr_files: File[]
11 | language: string
12 | lowercase: boolean?
13 | align_m: string?
14 | align_c: string?
15 | wm_name: string?
16 | outputs:
17 | wm_mapping:
18 | type: File
19 | outputSource: merge-csv-2/merged
20 | steps:
21 | normalize-whitespace-punctuation-2:
22 | run: normalize-whitespace-punctuation.cwl
23 | in:
24 | meta_in: gs_files
25 | out:
26 | - metadata_out
27 | scatter:
28 | - meta_in
29 | scatterMethod: dotproduct
30 | normalize-whitespace-punctuation-3:
31 | run: normalize-whitespace-punctuation.cwl
32 | in:
33 | meta_in: ocr_files
34 | out:
35 | - metadata_out
36 | scatter:
37 | - meta_in
38 | scatterMethod: dotproduct
39 | align-texts-wf-2:
40 | run: align-texts-wf.cwl
41 | in:
42 | align_m: align_m
43 | align_c: align_c
44 | ocr: normalize-whitespace-punctuation-3/metadata_out
45 | gs: normalize-whitespace-punctuation-2/metadata_out
46 | out:
47 | - alignments
48 | - changes
49 | - metadata
50 | pattern-1:
51 | run: https://raw.githubusercontent.com/nlppln/pattern-docker/master/pattern.cwl
52 | in:
53 | in_file: normalize-whitespace-punctuation-2/metadata_out
54 | language: language
55 | out:
56 | - out_files
57 | scatter:
58 | - in_file
59 | scatterMethod: dotproduct
60 | create-word-mappings-1:
61 | run: create-word-mappings.cwl
62 | in:
63 | lowercase: lowercase
64 | alignments: align-texts-wf-2/alignments
65 | saf: pattern-1/out_files
66 | out:
67 | - word_mapping
68 | scatter:
69 | - alignments
70 | - saf
71 | scatterMethod: dotproduct
72 | merge-csv-2:
73 | run: merge-csv.cwl
74 | in:
75 | in_files: create-word-mappings-1/word_mapping
76 | name: wm_name
77 | out:
78 | - merged
79 |
--------------------------------------------------------------------------------
/notebooks/vudnc-preprocess-workflow.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import nlppln\n",
20 | "import ochre\n",
21 | "import os\n",
22 | "\n",
23 | "working_dir = '/home/jvdzwaan/cwl-working-dir/'\n",
24 | "\n",
25 | "with nlppln.WorkflowGenerator(working_dir=working_dir) as wf:\n",
26 | " wf.load(steps_dir=ochre.cwl_path())\n",
27 | " print(wf.list_steps())\n",
28 | "\n",
29 | " archive = wf.add_input(archive='File')\n",
30 | " ocr_dir_name = wf.add_input(ocr_dir_name='string?', default='ocr')\n",
31 | " gs_dir_name = wf.add_input(gs_dir_name='string?', default='gs')\n",
32 | " \n",
33 | " in_dir = wf.archive2dir(archive=archive)\n",
34 | "\n",
35 | " vudnc_files = wf.vudnc_select_files(in_dir=in_dir)\n",
36 | " gs_with_empty, ocr_with_empty = wf.vudnc2ocr_and_gs(in_file=vudnc_files, \n",
37 | " scatter='in_file', scatter_method='dotproduct')\n",
38 | " gs_dir = wf.save_files_to_dir(dir_name=gs_dir_name, in_files=gs_with_empty)\n",
39 | " ocr_dir = wf.save_files_to_dir(dir_name=ocr_dir_name, in_files=ocr_with_empty)\n",
40 | " gs, ocr = wf.remove_empty_files(gs_dir=gs_dir, ocr_dir=ocr_dir)\n",
41 | "\n",
42 | " gs_dir = wf.save_files_to_dir(dir_name=gs_dir_name, in_files=gs)\n",
43 | " ocr_dir = wf.save_files_to_dir(dir_name=ocr_dir_name, in_files=ocr)\n",
44 | "\n",
45 | " wf.add_outputs(gs_dir=gs_dir, ocr_dir=ocr_dir)\n",
46 | " wf.save(os.path.join(ochre.cwl_path(), 'vudnc-preprocess-pack.cwl'), pack=True, relative=False)"
47 | ]
48 | }
49 | ],
50 | "metadata": {
51 | "language_info": {
52 | "name": "python",
53 | "pygments_lexer": "ipython3"
54 | }
55 | },
56 | "nbformat": 4,
57 | "nbformat_minor": 2
58 | }
59 |
--------------------------------------------------------------------------------
/notebooks/post_correction_workflows.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "from nlppln import WorkflowGenerator\n",
20 | "\n",
21 | "with WorkflowGenerator() as wf:\n",
22 | " wf.load(steps_dir='../cwl/')\n",
23 | " print(wf.list_steps())\n",
24 | "\n",
25 | " in_dir = wf.add_input(in_dir='Directory')\n",
26 | " charset = wf.add_input(charset='File')\n",
27 | " model = wf.add_input(model='File')\n",
28 | "\n",
29 | " ocr_files = wf.ls(in_dir=in_dir)\n",
30 | " corrected = wf.lstm_synced_correct_ocr(charset=charset, model=model, txt=ocr_files, scatter='txt', scatter_method='dotproduct')\n",
31 | "\n",
32 | " wf.add_outputs(corrected=corrected)\n",
33 | " wf.save('../cwl/post-correct-dir.cwl')"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": null,
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "from nlppln import WorkflowGenerator\n",
43 | "\n",
44 | "with WorkflowGenerator() as wf: \n",
45 | " wf.load(steps_dir='../cwl/')\n",
46 | " print(wf.list_steps())\n",
47 | "\n",
48 | " in_dir = wf.add_input(in_dir='Directory')\n",
49 | " datadiv = wf.add_input(datadivision='File')\n",
50 | " div_name = wf.add_input(div_name='string', default='test')\n",
51 | " charset = wf.add_input(charset='File')\n",
52 | " model = wf.add_input(model='File')\n",
53 | "\n",
54 | " ocr_files = wf.select_test_files(in_dir=in_dir, datadivision=datadiv)\n",
55 | " corrected = wf.lstm_synced_correct_ocr(charset=charset, model=model, txt=ocr_files, scatter='txt', scatter_method='dotproduct')\n",
56 | "\n",
57 | " wf.add_outputs(corrected=corrected)\n",
58 | " wf.save('../cwl/post-correct-test-files.cwl')"
59 | ]
60 | }
61 | ],
62 | "metadata": {
63 | "language_info": {
64 | "name": "python",
65 | "pygments_lexer": "ipython3"
66 | }
67 | },
68 | "nbformat": 4,
69 | "nbformat_minor": 2
70 | }
71 |
--------------------------------------------------------------------------------
/ochre/cwl/post-correct-test-files.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: Workflow
4 | requirements:
5 | - class: ScatterFeatureRequirement
6 | inputs:
7 | in_dir: Directory
8 | datadivision: File
9 | div_name:
10 | default: test
11 | type: string
12 | charset: File
13 | model: File
14 | outputs:
15 | corrected:
16 | type:
17 | items: File
18 | type: array
19 | outputSource: lstm-synced-correct-ocr-1/corrected
20 | steps:
21 | select-test-files-2:
22 | run:
23 | cwlVersion: v1.0
24 | class: CommandLineTool
25 | baseCommand: [python, -m, ochre.select_test_files]
26 |
27 | stdout: cwl.output.json
28 |
29 | inputs:
30 | - type: File
31 | inputBinding:
32 | position: 2
33 | id: _:select-test-files-2#datadivision
34 | - type: Directory
35 | inputBinding:
36 | position: 1
37 | id: _:select-test-files-2#in_dir
38 | - type:
39 | - 'null'
40 | - string
41 | inputBinding:
42 | prefix: --name
43 |
44 | id: _:select-test-files-2#name
45 | outputs:
46 | - type:
47 | type: array
48 | items: File
49 | id: _:select-test-files-2#out_files
50 | id: _:select-test-files-2
51 | in:
52 | in_dir: in_dir
53 | datadivision: datadivision
54 | out:
55 | - out_files
56 | lstm-synced-correct-ocr-1:
57 | run:
58 | cwlVersion: v1.0
59 | class: CommandLineTool
60 | baseCommand: [python, -m, ochre.lstm_synced_correct_ocr]
61 |
62 | inputs:
63 | - type: File
64 | inputBinding:
65 | position: 2
66 | id: _:lstm-synced-correct-ocr-1#charset
67 | - type: File
68 | inputBinding:
69 | position: 1
70 | id: _:lstm-synced-correct-ocr-1#model
71 | - type: File
72 | inputBinding:
73 | position: 3
74 |
75 | id: _:lstm-synced-correct-ocr-1#txt
76 | outputs:
77 | - type: File
78 | outputBinding:
79 | glob: $(inputs.txt.basename)
80 | id: _:lstm-synced-correct-ocr-1#corrected
81 | id: _:lstm-synced-correct-ocr-1
82 | in:
83 | txt: select-test-files-2/out_files
84 | model: model
85 | charset: charset
86 | out:
87 | - corrected
88 | scatter:
89 | - txt
90 | scatterMethod: dotproduct
91 |
--------------------------------------------------------------------------------
/ochre/ocrevaluation_extract.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import click
3 | import codecs
4 | import os
5 | import tempfile
6 |
7 | from bs4 import BeautifulSoup
8 |
9 | from nlppln.utils import create_dirs, remove_ext
10 |
11 |
12 | @click.command()
13 | @click.argument('in_file', type=click.File(encoding='utf-8'))
14 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
15 | def ocrevaluation_extract(in_file, out_dir):
16 | create_dirs(out_dir)
17 |
18 |     write = False
19 |     # mkstemp returns an open file descriptor; close it, because the
20 |     # path is reopened below through codecs as a utf-8 text stream
21 |     fd, tmpfile = tempfile.mkstemp()
22 |     os.close(fd)
23 | with codecs.open(tmpfile, 'w', encoding='utf-8') as tmp:
24 | for line in in_file:
25 | if line.startswith('General'):
26 | write = True
27 | if line.startswith('Difference'):
28 | write = False
29 | if line.startswith('Error'):
30 | write = True
31 |
32 | if write:
33 | tmp.write(line)
34 |
35 | with codecs.open(tmpfile, encoding='utf-8') as f:
36 | soup = BeautifulSoup(f.read(), 'lxml')
37 | os.remove(tmpfile)
38 |
39 | tables = soup.find_all('table')
40 | assert len(tables) == 2
41 |
42 | doc = remove_ext(in_file.name)
43 |
44 | t = tables[0]
45 | table_data = [[cell.text for cell in row('td')] for row in t('tr')]
46 |
47 | # 'transpose' table_data
48 | lines = {}
49 | for data in table_data:
50 | for i, entry in enumerate(data):
51 |             if i not in lines:
52 | # add doc id to data line (but not to header)
53 | if i != 0:
54 | lines[i] = [doc]
55 | else:
56 | lines[i] = ['']
57 | lines[i].append(entry)
58 |
59 | out_file = os.path.join(out_dir, '{}-global.csv'.format(doc))
60 | with codecs.open(out_file, 'wb', encoding='utf-8') as f:
61 | for i in range(len(lines.keys())):
62 | f.write(u','.join(lines[i]))
63 | f.write(u'\n')
64 |
65 | t = tables[1]
66 | table_data = [[cell.text for cell in row('td')] for row in t('tr')]
67 | out_file = os.path.join(out_dir, '{}-character.csv'.format(doc))
68 | with codecs.open(out_file, 'wb', encoding='utf-8') as f:
69 | for data in table_data:
70 | f.write(u'"{}",'.format(data[0]))
71 | f.write(u','.join(data[1:]))
72 | f.write(u'\n')
73 |
74 |
75 | if __name__ == '__main__':
76 | ocrevaluation_extract()
77 |
--------------------------------------------------------------------------------
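A worked example of the 'transpose' step above, outside the HTML plumbing: metric names become the header row, and each original table column becomes one data row prefixed with the document id (the values are illustrative).

table_data = [['CER', '4.17'], ['WER', '20.00']]
doc = 'gs_out'

lines = {}
for data in table_data:
    for i, entry in enumerate(data):
        if i not in lines:
            lines[i] = [''] if i == 0 else [doc]
        lines[i].append(entry)

for i in range(len(lines)):
    print(u','.join(lines[i]))
# ,CER,WER
# gs_out,4.17,20.00

--------------------------------------------------------------------------------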
/setup.py:
--------------------------------------------------------------------------------
1 | # Always prefer setuptools over distutils
2 | from setuptools import setup, find_packages
3 | # To use a consistent encoding
4 | from codecs import open
5 | from os import path
6 |
7 | here = path.abspath(path.dirname(__file__))
8 |
9 | setup(
10 | name='ochre',
11 |
12 | # Versions should comply with PEP440. For a discussion on single-sourcing
13 | # the version across setup.py and the project code, see
14 | # https://packaging.python.org/en/latest/single_source_version.html
15 | version='0.1.0',
16 |
17 |     description='Command line tools for the OCR project',
18 |     long_description="""Command line tools for the OCR project
19 |     """,
20 |
21 | # The project's main homepage.
22 | url='https://github.com/WhatWorksWhenForWhom/nlppln',
23 |
24 | # Author details
25 | author='Janneke van der Zwaan',
26 | author_email='j.vanderzwaan@esciencecenter.nl',
27 |
28 | # Choose your license
29 | license='Apache',
30 |
31 | include_package_data=True,
32 |
33 | # See https://pypi.python.org/pypi?%3Aaction=list_classifiers
34 | classifiers=[
35 | # How mature is this project? Common values are
36 | # 3 - Alpha
37 | # 4 - Beta
38 | # 5 - Production/Stable
39 | 'Development Status :: 3 - Alpha',
40 |
41 | # Indicate who your project is intended for
42 | 'Intended Audience :: Developers',
43 |
44 | # Pick your license as you wish (should match "license" above)
45 | 'License :: OSI Approved :: Apache Software License',
46 |
47 | # Specify the Python versions you support here. In particular, ensure
48 | # that you indicate whether you support Python 2, Python 3 or both.
49 | 'Programming Language :: Python :: 2',
50 | 'Programming Language :: Python :: 2.7',
51 | ],
52 |
53 | # What does your project relate to?
54 | keywords='text-mining, nlp, pipeline',
55 |
56 | # You can just specify the packages manually here if your project is
57 | # simple. Or you can use find_packages().
58 | packages=find_packages(),
59 |
60 | # List run-time dependencies here. These will be installed by pip when
61 | # your project is installed. For an analysis of "install_requires" vs pip's
62 | # requirements files see:
63 | # https://packaging.python.org/en/latest/requirements.html
64 | install_requires=[
65 | 'cwltool==1.0.20181102182747',
66 | 'scriptcwl>=0.8.1',
67 | 'nlppln>=0.3.2',
68 | 'six',
69 | 'glob2',
70 | 'sh',
71 | ]
72 | )
73 |
--------------------------------------------------------------------------------
/ochre/cwl/lowercase-directory.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: Workflow
4 | requirements:
5 | - class: ScatterFeatureRequirement
6 | inputs:
7 | in_dir: Directory
8 | dir_name:
9 | default: gs_lowercase
10 | type: string
11 | outputs:
12 | out_dir:
13 | type: Directory
14 | outputSource: save-files-to-dir-4/out
15 | steps:
16 | ls-1:
17 | run:
18 | cwlVersion: v1.0
19 | class: CommandLineTool
20 | baseCommand: [python, -m, nlppln.commands.ls]
21 |
22 | inputs:
23 | - type: Directory
24 | inputBinding:
25 | position: 2
26 | id: _:ls-1#in_dir
27 | - type:
28 | - 'null'
29 | - boolean
30 | inputBinding:
31 | prefix: --recursive
32 |
33 | id: _:ls-1#recursive
34 | stdout: cwl.output.json
35 |
36 | outputs:
37 | - type:
38 | type: array
39 | items: File
40 | id: _:ls-1#out_files
41 | id: _:ls-1
42 | in:
43 | in_dir: in_dir
44 | out:
45 | - out_files
46 | lowercase-1:
47 | run:
48 | cwlVersion: v1.0
49 | class: CommandLineTool
50 | baseCommand: [python, -m, nlppln.commands.lowercase]
51 |
52 | inputs:
53 | - type: File
54 | inputBinding:
55 | position: 1
56 |
57 | id: _:lowercase-1#in_file
58 | stdout: $(inputs.in_file.nameroot).txt
59 |
60 | outputs:
61 | - type: File
62 | outputBinding:
63 | glob: $(inputs.in_file.nameroot).txt
64 | id: _:lowercase-1#out_files
65 | id: _:lowercase-1
66 | in:
67 | in_file: ls-1/out_files
68 | out:
69 | - out_files
70 | scatter:
71 | - in_file
72 | scatterMethod: dotproduct
73 | save-files-to-dir-4:
74 | run:
75 | cwlVersion: v1.0
76 | class: ExpressionTool
77 |
78 | requirements:
79 | - class: InlineJavascriptRequirement
80 |
81 | inputs:
82 | - type: string
83 | id: _:save-files-to-dir-4#dir_name
84 | - type:
85 | type: array
86 | items: File
87 | id: _:save-files-to-dir-4#in_files
88 | outputs:
89 | - type: Directory
90 | id: _:save-files-to-dir-4#out
91 | expression: |
92 | ${
93 | return {"out": {
94 | "class": "Directory",
95 | "basename": inputs.dir_name,
96 | "listing": inputs.in_files
97 | } };
98 | }
99 | id: _:save-files-to-dir-4
100 | in:
101 | dir_name: dir_name
102 | in_files: lowercase-1/out_files
103 | out:
104 | - out
105 |
--------------------------------------------------------------------------------
/ochre/edlibutils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import edlib
3 |
4 |
5 | def align_characters(query, ref, empty_char=''):
6 |     """Character-align query to ref using edlib.
7 |
8 |     Returns (ref_aln, query_aln, match_aln). The two character lists have
9 |     the length of the alignment, with gaps filled by empty_char; match_aln
10 |     marks matches with '|', mismatches with '.' and gaps with ' '.
11 |     """
12 |     a = edlib.align(query, ref, task="path")
13 |     ref_pos = a["locations"][0][0]
14 |     query_pos = 0
15 |     ref_aln = []
16 |     match_aln = ""
17 |     query_aln = []
18 |
19 |     for step, code in re.findall(r"(\d+)(\D)", a["cigar"]):
20 |         step = int(step)
21 |         if code == "=":  # match: consume a character on both sides
22 |             ref_aln.extend(ref[ref_pos:ref_pos + step])
23 |             ref_pos += step
24 |             query_aln.extend(query[query_pos:query_pos + step])
25 |             query_pos += step
26 |             match_aln += "|" * step
27 |         elif code == "X":  # mismatch: consume a character on both sides
28 |             ref_aln.extend(ref[ref_pos:ref_pos + step])
29 |             ref_pos += step
30 |             query_aln.extend(query[query_pos:query_pos + step])
31 |             query_pos += step
32 |             match_aln += "." * step
33 |         elif code == "D":  # gap in the query
34 |             ref_aln.extend(ref[ref_pos:ref_pos + step])
35 |             ref_pos += step
36 |             query_aln.extend([empty_char] * step)
37 |             match_aln += " " * step
38 |         elif code == "I":  # gap in the reference
39 |             ref_aln.extend([empty_char] * step)
40 |             query_aln.extend(query[query_pos:query_pos + step])
41 |             query_pos += step
42 |             match_aln += " " * step
43 |
44 |     return ref_aln, query_aln, match_aln
45 |
46 |
47 | def align_output_to_input(input_str, output_str, empty_char=u'@'):
48 |     # align_characters returns three values; unpacking only two (as the
49 |     # old code did) raises a ValueError. With the input as query and the
50 |     # output as reference, ref_aln is the output aligned to the input.
51 |     output_aln, _, _ = align_characters(input_str, output_str,
52 |                                         empty_char=empty_char)
53 |     while len(output_aln) < len(input_str):
54 |         output_aln.append(empty_char)
55 |     return u''.join(output_aln)
56 |
--------------------------------------------------------------------------------
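A small usage sketch of align_characters: the exact gap placement depends on edlib's alignment, but the three return values always have the alignment length and line up column by column.

from ochre.edlibutils import align_characters

gs = u'example'
ocr = u'cxampl'
# query first, reference second; gaps are filled with empty_char
ref_aln, query_aln, match_aln = align_characters(ocr, gs, empty_char=u'')
for r, q, m in zip(ref_aln, query_aln, match_aln):
    print(repr(r), repr(q), m)  # '|' match, '.' mismatch, ' ' gap

--------------------------------------------------------------------------------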
/ochre/cwl/sac-extract.cwl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env cwl-runner
2 | cwlVersion: v1.0
3 | class: Workflow
4 | inputs:
5 | data_file: File
6 | ocr_dir_name:
7 | default: ocr
8 | type: string
9 | gs_dir_name:
10 | default: gs
11 | type: string
12 | de_dir_name:
13 | default: de
14 | type: string
15 | fr_dir_name:
16 | default: fr
17 | type: string
18 | outputs:
19 | ocr_de:
20 | type: Directory
21 | outputSource: save-dir-to-subdir/out
22 | gs_de:
23 | type: Directory
24 | outputSource: save-dir-to-subdir-1/out
25 | ocr_fr:
26 | type: Directory
27 | outputSource: save-dir-to-subdir-2/out
28 | gs_fr:
29 | type: Directory
30 | outputSource: save-dir-to-subdir-3/out
31 | steps:
32 | tar:
33 | run: tar.cwl
34 | in:
35 | in_file: data_file
36 | out:
37 | - out
38 | sac2gs-and-ocr:
39 | run: sac2gs-and-ocr.cwl
40 | in:
41 | in_dir: tar/out
42 | out:
43 | - gs_de
44 | - gs_fr
45 | - ocr_de
46 | - ocr_fr
47 | mkdir:
48 | run: mkdir.cwl
49 | in:
50 | dir_name: de_dir_name
51 | out:
52 | - out
53 | mkdir-1:
54 | run: mkdir.cwl
55 | in:
56 | dir_name: fr_dir_name
57 | out:
58 | - out
59 | save-files-to-dir-9:
60 | run: save-files-to-dir.cwl
61 | in:
62 | dir_name: ocr_dir_name
63 | in_files: sac2gs-and-ocr/ocr_de
64 | out:
65 | - out
66 | save-dir-to-subdir:
67 | run: save-dir-to-subdir.cwl
68 | in:
69 | outer_dir: mkdir/out
70 | inner_dir: save-files-to-dir-9/out
71 | out:
72 | - out
73 | save-files-to-dir-10:
74 | run: save-files-to-dir.cwl
75 | in:
76 | dir_name: gs_dir_name
77 | in_files: sac2gs-and-ocr/gs_de
78 | out:
79 | - out
80 | save-dir-to-subdir-1:
81 | run: save-dir-to-subdir.cwl
82 | in:
83 | outer_dir: mkdir/out
84 | inner_dir: save-files-to-dir-10/out
85 | out:
86 | - out
87 | save-files-to-dir-11:
88 | run: save-files-to-dir.cwl
89 | in:
90 | dir_name: ocr_dir_name
91 | in_files: sac2gs-and-ocr/ocr_fr
92 | out:
93 | - out
94 | save-dir-to-subdir-2:
95 | run: save-dir-to-subdir.cwl
96 | in:
97 | outer_dir: mkdir-1/out
98 | inner_dir: save-files-to-dir-11/out
99 | out:
100 | - out
101 | save-files-to-dir-12:
102 | run: save-files-to-dir.cwl
103 | in:
104 | dir_name: gs_dir_name
105 | in_files: sac2gs-and-ocr/gs_fr
106 | out:
107 | - out
108 | save-dir-to-subdir-3:
109 | run: save-dir-to-subdir.cwl
110 | in:
111 | outer_dir: mkdir-1/out
112 | inner_dir: save-files-to-dir-12/out
113 | out:
114 | - out
115 |
--------------------------------------------------------------------------------
/tests/data/ocrevaluation/out/gs_out.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | General results
8 |
9 |
10 |
11 | | CER | 4.17 |
12 |
13 |
14 | | WER | 20.00 |
15 |
16 |
17 | | WER (order independent) | 20.00 |
18 |
19 |
20 |
21 | Difference spotting
22 |
23 |
24 |
25 |
26 |
27 | gs.txt
28 | |
29 | ocr.txt
30 | |
31 |
32 |
33 | | This is an example text. | This is an cxample text. |
34 |
35 |
36 | Error rate per character and type
37 |
38 |
39 | | Character | Hex code | Total | Spurious | Confused | Lost | Error rate |
40 |
41 |
42 | | | 20 | 4 | 0 | 0 | 0 | 0.00 |
43 |
44 |
45 | | . | 2e | 1 | 0 | 0 | 0 | 0.00 |
46 |
47 |
48 | | T | 54 | 1 | 0 | 0 | 0 | 0.00 |
49 |
50 |
51 | | a | 61 | 2 | 0 | 0 | 0 | 0.00 |
52 |
53 |
54 | | e | 65 | 3 | 0 | 1 | 0 | 33.33 |
55 |
56 |
57 | | h | 68 | 1 | 0 | 0 | 0 | 0.00 |
58 |
59 |
60 | | i | 69 | 2 | 0 | 0 | 0 | 0.00 |
61 |
62 |
63 | | l | 6c | 1 | 0 | 0 | 0 | 0.00 |
64 |
65 |
66 | | m | 6d | 1 | 0 | 0 | 0 | 0.00 |
67 |
68 |
69 | | n | 6e | 1 | 0 | 0 | 0 | 0.00 |
70 |
71 |
72 | | p | 70 | 1 | 0 | 0 | 0 | 0.00 |
73 |
74 |
75 | | s | 73 | 2 | 0 | 0 | 0 | 0.00 |
76 |
77 |
78 | | t | 74 | 2 | 0 | 0 | 0 | 0.00 |
79 |
80 |
81 | | x | 78 | 2 | 0 | 0 | 0 | 0.00 |
82 |
83 |
84 |
85 |
86 |
--------------------------------------------------------------------------------
/tests/data/ocrevaluation-extract/in/in.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | General results
8 |
9 |
10 |
11 | | CER | 4.17 |
12 |
13 |
14 | | WER | 20.00 |
15 |
16 |
17 | | WER (order independent) | 20.00 |
18 |
19 |
20 |
21 | Difference spotting
22 |
23 |
24 |
25 |
26 |
27 | gs.txt
28 | |
29 | ocr.txt
30 | |
31 |
32 |
33 | | This is an example text. | This is an cxample text. |
34 |
35 |
36 | Error rate per character and type
37 |
38 |
39 | | Character | Hex code | Total | Spurious | Confused | Lost | Error rate |
40 |
41 |
42 | | | 20 | 4 | 0 | 0 | 0 | 0.00 |
43 |
44 |
45 | | . | 2e | 1 | 0 | 0 | 0 | 0.00 |
46 |
47 |
48 | | T | 54 | 1 | 0 | 0 | 0 | 0.00 |
49 |
50 |
51 | | a | 61 | 2 | 0 | 0 | 0 | 0.00 |
52 |
53 |
54 | | e | 65 | 3 | 0 | 1 | 0 | 33.33 |
55 |
56 |
57 | | h | 68 | 1 | 0 | 0 | 0 | 0.00 |
58 |
59 |
60 | | i | 69 | 2 | 0 | 0 | 0 | 0.00 |
61 |
62 |
63 | | l | 6c | 1 | 0 | 0 | 0 | 0.00 |
64 |
65 |
66 | | m | 6d | 1 | 0 | 0 | 0 | 0.00 |
67 |
68 |
69 | | n | 6e | 1 | 0 | 0 | 0 | 0.00 |
70 |
71 |
72 | | p | 70 | 1 | 0 | 0 | 0 | 0.00 |
73 |
74 |
75 | | s | 73 | 2 | 0 | 0 | 0 | 0.00 |
76 |
77 |
78 | | t | 74 | 2 | 0 | 0 | 0 | 0.00 |
79 |
80 |
81 | | x | 78 | 2 | 0 | 0 | 0 | 0.00 |
82 |
83 |
84 |
85 |
86 |
--------------------------------------------------------------------------------
/datagen.py:
--------------------------------------------------------------------------------
1 | """Data generation functionality for ochre
2 |
3 | Source:
4 | https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
5 | """
6 | import keras
7 | # stick to one keras namespace; mixing keras and tensorflow.keras breaks
8 |
9 | import numpy as np
10 |
11 |
12 | class DataGenerator(keras.utils.Sequence):
13 | 'Generates data for Keras'
14 | def __init__(self, xData, yData, char_to_int, seq_length,
15 | padding_char='\n', oov_char='@', batch_size=32, shuffle=True):
16 | """
17 | xData is list of input strings
18 | yData is list of output strings
19 | """
20 | self.xData = xData
21 | self.yData = yData
22 |
23 | self.char_to_int = char_to_int
24 | self.padding_char = padding_char
25 | self.oov_char = oov_char
26 |
27 | self.n_vocab = len(char_to_int)
28 | self.seq_length = seq_length
29 |
30 | self.batch_size = batch_size
31 | self.shuffle = shuffle
32 | self.on_epoch_end()
33 |
34 | def __len__(self):
35 | 'Denotes the number of batches per epoch'
36 | return int(np.floor(len(self.xData) / self.batch_size))
37 |
38 | def __getitem__(self, index):
39 | 'Generate one batch of data'
40 | # Generate indexes of the batch
41 | indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
42 |
43 | # Generate data
44 | X, y = self.__data_generation(indexes)
45 |
46 | return X, y
47 |
48 | def on_epoch_end(self):
49 | 'Updates indexes after each epoch'
50 | self.indexes = np.arange(len(self.xData))
51 | if self.shuffle is True:
52 | np.random.shuffle(self.indexes)
53 |
54 | def __data_generation(self, list_IDs_temp):
55 | 'Generates data containing batch_size samples'
56 | # Initialization
57 |         X = np.empty((self.batch_size, self.seq_length), dtype=int)
58 | ylist = list()
59 |
60 | # Generate data
61 | for i, ID in enumerate(list_IDs_temp):
62 | # input
63 | X[i, ] = self._convert_sample(self.xData[ID])
64 |
65 | # output
66 | y_seq = self._convert_sample(self.yData[ID])
67 | enc = keras.utils.to_categorical(y_seq, num_classes=self.n_vocab)
68 | ylist.append(enc)
69 |
70 | y = np.array(ylist)
71 | y = y.reshape(self.batch_size, self.seq_length, self.n_vocab)
72 |
73 | return X, y
74 |
75 | def _convert_sample(self, string):
76 |         res = np.empty(self.seq_length, dtype=int)
77 | oov = self.char_to_int[self.oov_char]
78 | for i in range(self.seq_length):
79 | try:
80 | res[i] = self.char_to_int.get(string[i], oov)
81 | except IndexError:
82 | res[i] = self.char_to_int[self.padding_char]
83 | return res
84 |
--------------------------------------------------------------------------------
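A minimal sketch of driving DataGenerator with toy data, assuming datagen.py is importable; the character table must contain the padding and out-of-vocabulary characters, and all strings here fit the sequence length exactly.

from datagen import DataGenerator

charset = u'\n@ abdeinst.'
char_to_int = {c: i for i, c in enumerate(charset)}

ocr = [u'dit is een tesd.'] * 64  # noisy model inputs
gs = [u'dit is een test.'] * 64   # gold-standard targets

gen = DataGenerator(xData=ocr, yData=gs, char_to_int=char_to_int,
                    seq_length=16, batch_size=32)
X, y = gen[0]
print(X.shape)  # (32, 16): integer-encoded input windows
print(y.shape)  # (32, 16, len(charset)): one-hot encoded targets

--------------------------------------------------------------------------------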
/ochre/scramble.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import click
3 | import numpy as np
4 |
5 |
6 | @click.command()
7 | def command():
8 |     # How does the scrambling work?
9 |     # Randomly. We have insertions, deletions and substitutions, each of
10 |     # which can be more than one character long (say, up to a maximum of
11 |     # 5 characters).
12 |     # First choose whether to do an insertion, a deletion or a
13 |     # substitution; then choose how long it will be (a substitution needs
14 |     # two lengths: what is replaced and what it is replaced with).
15 |     # In fact, every change can be modelled as a substitution: sample a
16 |     # length for the insert and a length for the delete (either of which
17 |     # may be 0); pure insertion and deletion are the edge cases.
18 |     # Which characters are substituted should depend on the character
19 |     # itself, but for now we substitute completely at random.
20 |     text = 'Dit is een tekst.'
21 |     text_list = [c for c in text]
22 |     text_length = len(text)
23 |     probs2 = [0.5, 0.35, 0.14, 0.005, 0.005]
24 |
25 |     change = np.random.choice(2, text_length, p=[0.8, 0.2])
26 |     insert = []
27 |     delete = []
28 |     for i in change:
29 |         if i == 1:
30 |             insert.append(np.random.choice(len(probs2), 1, p=probs2)[0])
31 |             delete.append(np.random.choice(len(probs2), 1, p=probs2)[0])
32 |         else:
33 |             insert.append(0)
34 |             delete.append(0)
35 |
36 |     print(insert)
37 |     print(delete)
38 |     idx = 0
39 |     for i, d in zip(insert, delete):
40 |         if i != 0 or d != 0:
41 |             if d != 0:
42 |                 for _ in range(d):
43 |                     # make sure text_list[idx] exists before deleting it
44 |                     if idx < len(text_list):
45 |                         del text_list[idx]
46 |                 # only adjust idx if no characters should be inserted
47 |                 if i == 0:
48 |                     idx += 1
49 |             if i != 0:
50 |                 for _ in range(i):
51 |                     # TODO: replace by random character(s) instead of a
52 |                     # fixed one (possibly dependent on what was deleted)
53 |                     text_list.insert(idx, '*')
54 |                 idx += i
55 |         else:
56 |             idx += 1
57 |
58 |     print(''.join(text_list))
59 |
60 |
61 | if __name__ == '__main__':
62 |     command()
63 |
--------------------------------------------------------------------------------
/ochre/create_word_mappings.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import click
3 | import os
4 | import json
5 |
6 | import pandas as pd
7 |
8 | from string import punctuation
9 |
10 | from nlppln.utils import create_dirs, remove_ext, out_file_name
11 |
12 |
13 | def find_word_boundaries(words, aligned):
14 | unaligned = []
15 | for w in words:
16 | for c in w:
17 | unaligned.append(c)
18 | unaligned.append('@')
19 |
20 | i = 0
21 | j = 0
22 | prev = 0
23 | wb = []
24 | while i < len(aligned) and j < len(unaligned):
25 | if unaligned[j] == '@':
26 | w = u''.join(aligned[prev:i])
27 | # Punctuation that occurs in the ocr but not in the gold standard
28 | # is ignored.
29 | if len(w) > 1 and (w[0] == u' ' or w[0] in punctuation):
30 | s = prev + 1
31 | else:
32 | s = prev
33 |             # record the span (start, end) of this word in the alignment
34 | wb.append((s, i))
35 | prev = i
36 | j += 1
37 | elif aligned[i] == '' or aligned[i] == ' ':
38 |             # space or alignment gap on the aligned side: skip it
39 | i += 1
40 | elif aligned[i] == unaligned[j]:
41 |             # characters match: advance both sequences
42 | i += 1
43 | j += 1
44 | else:
45 |             # characters differ (e.g. tokenization noise): advance the
46 |             # aligned sequence only
47 | i += 1
48 | # add last word
49 | wb.append((prev, len(aligned)))
50 | return wb
51 |
52 |
53 | @click.command()
54 | @click.argument('saf', type=click.File(encoding='utf-8'))
55 | @click.argument('alignments', type=click.File(encoding='utf-8'))
56 | @click.option('--lowercase/--no-lowercase', default=False)
57 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
58 | def create_word_mappings(saf, alignments, lowercase, out_dir):
59 | create_dirs(out_dir)
60 |
61 | alignment_data = json.load(alignments)
62 | aligned1 = alignment_data['gs']
63 | aligned2 = alignment_data['ocr']
64 |
65 | saf = json.load(saf)
66 | if lowercase:
67 | words = [w['word'].lower() for w in saf['tokens']]
68 |
69 | aligned1 = [c.lower() for c in aligned1]
70 | aligned2 = [c.lower() for c in aligned2]
71 | else:
72 | words = [w['word'] for w in saf['tokens']]
73 |
74 | wb = find_word_boundaries(words, aligned1)
75 |
76 | doc_id = remove_ext(alignments.name)
77 |
78 | res = {'gs': [], 'ocr': [], 'doc_id': []}
79 | for s, e in wb:
80 | w1 = u''.join(aligned1[s:e])
81 | w2 = u''.join(aligned2[s:e])
82 |
83 | res['gs'].append(w1.strip())
84 | res['ocr'].append(w2.strip())
85 | res['doc_id'].append(doc_id)
86 |
87 | # Use pandas DataFrame to create the csv, so commas and quotes are properly
88 | # escaped.
89 | df = pd.DataFrame(res)
90 |
91 | out_file = out_file_name(out_dir, doc_id, ext='csv')
92 | df.to_csv(out_file, encoding='utf-8')
93 |
94 | if __name__ == '__main__':
95 | create_word_mappings()
96 |
--------------------------------------------------------------------------------
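find_word_boundaries can be exercised on its own: given tokens and the gold-standard side of a character alignment, it returns (start, end) slices, which create_word_mappings then applies to the OCR side of the same alignment.

from ochre.create_word_mappings import find_word_boundaries

words = [u'dit', u'is']                         # tokens from the SAF file
aligned = [u'd', u'i', u't', u' ', u'i', u's']  # gold-standard alignment side
print(find_word_boundaries(words, aligned))     # [(0, 3), (3, 6)]
# Slicing the OCR side of the alignment with these boundaries (and
# stripping whitespace) yields the OCR string for each gold-standard word.

--------------------------------------------------------------------------------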
/notebooks/icdar2019POCR-to-texts.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "%matplotlib inline\n",
20 | "\n",
21 | "import numpy as np\n",
22 | "import pandas as pd\n",
23 | "import matplotlib.pyplot as plt\n",
24 | "\n",
25 | "from tqdm import tqdm_notebook as tqdm"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "import os\n",
35 | "import codecs\n",
36 | "import json\n",
37 | "\n",
38 | "from nlppln.utils import get_files, out_file_name, create_dirs\n",
39 | "\n",
40 | "from ochre.icdar2017st_extract_text import to_character_list\n",
41 | "\n",
42 | "in_dir = '/home/jvdzwaan/Downloads/POCR_training_dataset/NL/NL1/'\n",
43 | "\n",
44 | "out_dir = '/home/jvdzwaan/data/icdar2019pocr-nl'\n",
45 | "\n",
46 | "def command(in_file, out_dir):\n",
47 | " create_dirs(out_dir)\n",
48 | "\n",
49 | " lines = in_file.readlines()\n",
50 | " # OCR_toInput: lines[0][:14]\n",
51 | " # OCR_aligned: lines[1][:14]\n",
52 | " # GS_aligned: lines[2][:14]\n",
53 | " ocr = to_character_list(lines[1][14:].strip())\n",
54 | " gs = to_character_list(lines[2][14:].strip())\n",
55 | "\n",
56 | " # Write texts\n",
57 | " out_file = out_file_name(os.path.join(out_dir, 'ocr'), os.path.basename(in_file.name))\n",
58 | " print(out_file)\n",
59 | " create_dirs(out_file, is_file=True)\n",
60 | " with codecs.open(out_file, 'wb', encoding='utf-8') as f:\n",
61 | " f.write(u''.join(ocr))\n",
62 | "\n",
63 | " out_file = out_file_name(os.path.join(out_dir, 'gs'), os.path.basename(in_file.name))\n",
64 | " print(out_file)\n",
65 | " create_dirs(out_file, is_file=True)\n",
66 | " with codecs.open(out_file, 'wb', encoding='utf-8') as f:\n",
67 | " f.write(u''.join(gs))\n",
68 | "\n",
69 | " out_file = out_file_name(os.path.join(out_dir, 'aligned'), os.path.basename(in_file.name), 'json')\n",
70 | " print(out_file)\n",
71 | " create_dirs(out_file, is_file=True)\n",
72 | " with codecs.open(out_file, 'wb', encoding='utf-8') as f:\n",
73 | " json.dump({'ocr': ocr, 'gs': gs}, f)\n",
74 | " \n",
75 | " print()\n",
76 | "\n",
77 | "in_files = get_files(in_dir)\n",
78 | "#print(in_files)\n",
79 | "print(len(in_files))\n",
80 | "for in_file in in_files:\n",
81 | " print(in_file)\n",
82 | " \n",
83 | " command(open(in_file), out_dir)"
84 | ]
85 | }
86 | ],
87 | "metadata": {
88 | "language_info": {
89 | "name": "python",
90 | "pygments_lexer": "ipython3"
91 | }
92 | },
93 | "nbformat": 4,
94 | "nbformat_minor": 2
95 | }
96 |
--------------------------------------------------------------------------------
/tests/test_ocrevaluation_extract.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import sh
4 | import pytest
5 |
6 | from ochre import ocrevaluation_extract
7 |
8 |
9 | def test_ocr_evaluation_extract_cwl(tmpdir):
10 |     tool = os.path.join('ochre', 'cwl', 'ocrevaluation-extract.cwl')
11 |     in_file = os.path.join('tests', 'data', 'ocrevaluation-extract', 'in',
12 |                            'in.html')
13 |
14 |     try:
15 |         sh.cwltool(['--outdir', tmpdir, tool, '--in_file', in_file])
16 |     except sh.ErrorReturnCode as e:
17 |         print(e)
18 |         pytest.fail(str(e))
19 |
20 |     for out in ['in-character.csv', 'in-global.csv']:
21 |         out_file = tmpdir.join(out).strpath
22 |         with open(out_file) as f:
23 |             actual = f.read()
24 |
25 |         fname = os.path.join('tests', 'data', 'ocrevaluation-extract', 'out',
26 |                              out)
27 |         with open(fname) as f:
28 |             expected = f.read()
29 |
30 |         print(out)
31 |         print(' actual:', actual)
32 |         print('expected:', expected)
33 |         assert actual == expected
34 |
35 |
36 | def test_ocr_evaluation_extract_cwl_empty(tmpdir):
37 |     tool = os.path.join('ochre', 'cwl', 'ocrevaluation-extract.cwl')
38 |     in_file = os.path.join('tests', 'data', 'ocrevaluation-extract', 'in',
39 |                            'empty.html')
40 |
41 |     try:
42 |         sh.cwltool(['--outdir', tmpdir, tool, '--in_file', in_file])
43 |     except sh.ErrorReturnCode as e:
44 |         print(e)
45 |         pytest.fail(str(e))
46 |
47 |     for out in ['empty-character.csv', 'empty-global.csv']:
48 |         out_file = tmpdir.join(out).strpath
49 |         with open(out_file) as f:
50 |             actual = f.read()
51 |
52 |         fname = os.path.join('tests', 'data', 'ocrevaluation-extract', 'out',
53 |                              out)
54 |         with open(fname) as f:
55 |             expected = f.read()
56 |
57 |         print(out)
58 |         print(' actual:', actual)
59 |         print('expected:', expected)
60 |         assert actual == expected
61 |
--------------------------------------------------------------------------------
/2017_baseline.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | from keras.models import Sequential
5 | from keras.layers import Dense
6 | from keras.layers import LSTM
7 | from keras.layers import TimeDistributed
8 | from keras.layers import RepeatVector
9 | from keras.layers import Embedding
10 | from keras.callbacks import ModelCheckpoint
11 |
12 | from datagen import DataGenerator
13 |
14 |
15 | def infinite_loop(generator):
16 | while True:
17 | for i in range(len(generator)):
18 |             yield generator[i]
19 | generator.on_epoch_end()
20 |
21 |
22 | # load the data
23 | data_dir = '/home/jvdzwaan/data/sprint-icdar/in' # FIXME
24 | weights_dir = '/home/jvdzwaan/data/sprint-icdar/weights'
25 |
26 | if not os.path.exists(weights_dir):
27 | os.makedirs(weights_dir)
28 |
29 | seq_length = 53
30 | batch_size = 100
31 | shuffle = True
32 | pc = '\n'
33 | oc = '@'
34 |
35 | n_nodes = 1000
36 | dropout = 0.2
37 | n_embed = 256
38 |
39 | epochs = 30
40 | loss = 'categorical_crossentropy'
41 | optimizer = 'adam'
42 | metrics = ['accuracy']
43 |
44 | with open(os.path.join(data_dir, 'train.pkl'), 'rb') as f:
45 | gs_selected_train, ocr_selected_train = pickle.load(f)
46 |
47 | with open(os.path.join(data_dir, 'val.pkl'), 'rb') as f:
48 | gs_selected_val, ocr_selected_val = pickle.load(f)
49 |
50 | with open(os.path.join(data_dir, 'ci.pkl'), 'rb') as f:
51 | ci = pickle.load(f)
52 |
53 | n_vocab = len(ci)
54 |
55 | dg_val = DataGenerator(xData=ocr_selected_val, yData=gs_selected_val,
56 | char_to_int=ci,
57 | seq_length=seq_length, padding_char=pc, oov_char=oc,
58 | batch_size=batch_size, shuffle=shuffle)
59 | dg_train = DataGenerator(xData=ocr_selected_train,
60 | yData=gs_selected_train, char_to_int=ci,
61 | seq_length=seq_length, padding_char=pc,
62 | oov_char=oc,
63 | batch_size=batch_size, shuffle=shuffle)
64 |
65 | # create the network
66 | model = Sequential()
67 |
68 | # encoder
69 | model.add(Embedding(n_vocab, n_embed, input_length=seq_length))
70 | model.add(LSTM(n_nodes))  # input shape is inferred from the Embedding layer
71 | # For the decoder's input, we repeat the encoded input for each time step
72 | model.add(RepeatVector(seq_length))
73 | model.add(LSTM(n_nodes, return_sequences=True))
74 |
75 | # For each step of the output sequence, decide which character should be
76 | # chosen
77 | model.add(TimeDistributed(Dense(n_vocab, activation='softmax')))
78 | model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
79 |
80 | # initialize saving of weights
81 | filepath = os.path.join(weights_dir, '{loss:.4f}-{epoch:02d}.hdf5')
82 | checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1,
83 | save_best_only=True, mode='min')
84 | callbacks_list = [checkpoint]
85 |
86 | # do training (and save weights)
87 | model.fit_generator(infinite_loop(dg_train), steps_per_epoch=len(dg_train), epochs=epochs,
88 | validation_data=infinite_loop(dg_val),
89 | validation_steps=len(dg_val), callbacks=callbacks_list)
90 |
--------------------------------------------------------------------------------
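For reference, a self-contained sketch of the tensor shapes flowing through the encoder-decoder above, using the sizes configured in this script (the n_vocab value is illustrative, since it comes from the pickled character table):

from keras.models import Sequential
from keras.layers import Dense, Embedding, LSTM, RepeatVector, TimeDistributed

# (batch, 53) -> (batch, 53, 256) -> (batch, 1000) -> (batch, 53, 1000)
# -> (batch, 53, 1000) -> (batch, 53, n_vocab)
n_vocab, n_embed, n_nodes, seq_length = 80, 256, 1000, 53
m = Sequential()
m.add(Embedding(n_vocab, n_embed, input_length=seq_length))
m.add(LSTM(n_nodes))                 # encoder: keep only the final state
m.add(RepeatVector(seq_length))      # repeat it once per output step
m.add(LSTM(n_nodes, return_sequences=True))
m.add(TimeDistributed(Dense(n_vocab, activation='softmax')))
m.summary()  # prints the output shape of every layer

--------------------------------------------------------------------------------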
/ochre/lstm_synced_correct_ocr.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import click
3 | import codecs
4 | import os
5 | import numpy as np
6 |
7 | from collections import Counter
8 |
9 | from keras.models import load_model
10 |
11 | from nlppln.utils import create_dirs, out_file_name
12 |
13 | from .utils import get_char_to_int, get_int_to_char, read_text_to_predict
14 | from .edlibutils import align_output_to_input
15 |
16 |
17 | @click.command()
18 | @click.argument('model', type=click.Path(exists=True))
19 | @click.argument('charset', type=click.File(encoding='utf-8'))
20 | @click.argument('text', type=click.File(encoding='utf-8'))
21 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
22 | def lstm_synced_correct_ocr(model, charset, text, out_dir):
23 | create_dirs(out_dir)
24 |
25 | # load model
26 | model = load_model(model)
27 | conf = model.get_config()
28 | conf_result = conf[0].get('config').get('batch_input_shape')
29 | seq_length = conf_result[1]
30 | char_embedding = False
31 | if conf[0].get('class_name') == u'Embedding':
32 | char_embedding = True
33 |
34 | charset = charset.read()
35 | n_vocab = len(charset)
36 | char_to_int = get_char_to_int(charset)
37 | int_to_char = get_int_to_char(charset)
38 | lowercase = True
39 | for c in u'ABCDEFGHIJKLMNOPQRSTUVWXYZ':
40 | if c in charset:
41 | lowercase = False
42 | break
43 |
44 | pad = u'\n'
45 |
46 | to_predict = read_text_to_predict(text.read(), seq_length, lowercase,
47 | n_vocab, char_to_int, padding_char=pad,
48 | char_embedding=char_embedding)
49 |
50 | outputs = []
51 | inputs = []
52 |
53 | predicted = model.predict(to_predict, verbose=0)
54 | for i, sequence in enumerate(predicted):
55 | predicted_indices = [np.random.choice(n_vocab, p=p) for p in sequence]
56 | pred_str = u''.join([int_to_char[j] for j in predicted_indices])
57 | outputs.append(pred_str)
58 |
59 | if char_embedding:
60 | indices = to_predict[i]
61 | else:
62 |             indices = np.where(to_predict[i])[1]
63 | inp = u''.join([int_to_char[j] for j in indices])
64 | inputs.append(inp)
65 |
66 | idx = 0
67 | counters = {}
68 |
69 | for input_str, output_str in zip(inputs, outputs):
70 | if pad in output_str:
71 | output_str2 = align_output_to_input(input_str, output_str,
72 | empty_char=pad)
73 | else:
74 | output_str2 = output_str
75 | for i, (inp, outp) in enumerate(zip(input_str, output_str2)):
76 |             if idx + i not in counters:
77 | counters[idx+i] = Counter()
78 | counters[idx+i][outp] += 1
79 |
80 | idx += 1
81 |
82 |     agg_out = []
83 |     for position in sorted(counters):  # dict order is not guaranteed
84 |         agg_out.append(counters[position].most_common(1)[0][0])
85 |
86 | corrected_text = u''.join(agg_out)
87 | corrected_text = corrected_text.replace(pad, u'')
88 |
89 | out_file = out_file_name(out_dir, text.name)
90 | with codecs.open(out_file, 'wb', encoding='utf-8') as f:
91 | f.write(corrected_text)
92 |
93 |
94 | if __name__ == '__main__':
95 | lstm_synced_correct_ocr()
96 |
--------------------------------------------------------------------------------
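Because read_text_to_predict slides a window over the text with step 1, each character position is predicted up to seq_length times; the loop above tallies those predictions per position and keeps the majority. A toy illustration of that aggregation:

from collections import Counter

# position -> predictions from the windows that cover it
window_votes = {0: [u'c', u'e', u'e'], 1: [u'x', u'x', u'x']}
counters = {pos: Counter(votes) for pos, votes in window_votes.items()}
corrected = u''.join(counters[pos].most_common(1)[0][0]
                     for pos in sorted(counters))
print(corrected)  # 'ex': the majority wins at every position

--------------------------------------------------------------------------------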
/ochre/vudnc2ocr_and_gs.py:
--------------------------------------------------------------------------------
1 | import click
2 | import codecs
3 | import os
4 |
5 | from lxml import etree
6 | from nlppln.utils import create_dirs
7 |
8 |
9 | @click.command()
10 | @click.argument('in_file', type=click.Path(exists=True))
11 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
12 | def vudnc2ocr_and_gs(in_file, out_dir):
13 | create_dirs(out_dir)
14 |
15 | ocr_text_complete = []
16 | ocr_text = []
17 | gold_standard = []
18 | punctuation = []
19 |
20 | # TODO: how to handle headings? They are sentences without a closing
21 | # punctuation mark.
22 | # I think the existing LSTM code does not take into account documents (all
23 | # text is simply concatenated together).
24 | # TODO: fix treatment of " (double quotes). After a starting quote a space
25 | # is inserted.
26 | context = etree.iterparse(in_file, tag='{http://ilk.uvt.nl/folia}w',
27 | encoding='utf-8')
28 | for action, elem in context:
29 |         ocr_word = gs_word = ''  # gs_word could otherwise be unbound below
30 | punc = False
31 | pos_elem = elem.find('{http://ilk.uvt.nl/folia}pos')
32 | if pos_elem is not None:
33 | if pos_elem.get('class').startswith('LET'):
34 | punc = True
35 | for t in elem.iterchildren(tag='{http://ilk.uvt.nl/folia}t'):
36 | if t.get('class') == 'ocroutput':
37 | ocr_word = t.text
38 | else:
39 | gs_word = t.text
40 |
41 | if ocr_word != '':
42 | ocr_text.append(ocr_word)
43 | ocr_text_complete.append(ocr_word)
44 | gold_standard.append(gs_word)
45 | punctuation.append(punc)
46 |
47 |     # Reconstruct the gold-standard text, deciding after each word
48 |     # whether a space should follow it (based on surrounding punctuation)
49 |     result = []
50 |     for i in range(len(ocr_text_complete)):
51 |         result.append(gold_standard[i])
52 |         space = True
53 |         if punctuation[i]:
54 |             if i+1 < len(ocr_text_complete):
55 |                 if ocr_text_complete[i+1].strip().startswith(gold_standard[i]):
56 |                     space = False
57 |                 if punctuation[i+1]:
58 |                     space = False
59 |         elif i+1 < len(ocr_text_complete) and punctuation[i+1]:
60 |             if i+2 < len(ocr_text_complete):
61 |                 if ocr_text_complete[i+2].strip().startswith(gold_standard[i+1]):
62 |                     space = True
63 |                 else:
64 |                     space = False
65 |             else:
66 |                 space = False
67 |         if space:
68 |             result.append(' ')
69 |
70 |     fname = os.path.basename(in_file)
71 |     fname = fname.replace('.folia.xml', '.{}.txt')
72 |     fname = os.path.join(out_dir, fname)
73 |     with codecs.open(fname.format('gs'), 'wb', encoding='utf-8') as f:
74 |         gs = u''.join(result).strip()
75 |         f.write(gs)
76 |
77 |     with codecs.open(fname.format('ocr'), 'wb', encoding='utf-8') as f:
78 |         ocr = u' '.join(ocr_text)
79 |         f.write(ocr)
80 |
81 |
82 | if __name__ == '__main__':
83 |     vudnc2ocr_and_gs()
84 |
--------------------------------------------------------------------------------
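A minimal FoLiA-like fragment with just the elements the parser above reads (w, pos, and t with class 'ocroutput'); real FoLiA documents carry much more structure, so this is illustrative only:

import io
from lxml import etree

FOLIA = '{http://ilk.uvt.nl/folia}'
snippet = b'''<doc xmlns="http://ilk.uvt.nl/folia">
  <w><pos class="N(soort)"/><t class="ocroutput">w0rd</t><t>word</t></w>
  <w><pos class="LET()"/><t class="ocroutput">.</t><t>.</t></w>
</doc>'''

for _, w in etree.iterparse(io.BytesIO(snippet), tag=FOLIA + 'w'):
    pos = w.find(FOLIA + 'pos')
    texts = [t.text for t in w.iterchildren(FOLIA + 't')]
    print(pos.get('class'), texts)
# N(soort) ['w0rd', 'word']
# LET() ['.', '.']

--------------------------------------------------------------------------------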
/tests/test_remove_title_page.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 |
4 | from click.testing import CliRunner
5 |
6 |
7 | from ochre.remove_title_page import remove_title_page
8 |
9 |
10 | # Documentation about testing click commands: http://click.pocoo.org/5/testing/
11 | def test_remove_title_page_single_line():
12 | runner = CliRunner()
13 | with runner.isolated_filesystem():
14 | os.makedirs('in')
15 | os.makedirs('out')
16 |
17 | with open('in/test-without.txt', 'w') as f:
18 | content_without = 'Text starts here.\n' \
19 | 'Second line.\n'
20 | f.write(content_without)
21 |
22 | with open('in/test-with.txt', 'w') as f:
23 | content = 'This is the title page\n' \
24 | 'Text starts here.\n' \
25 | 'Second line.\n'
26 | f.write(content)
27 |
28 | result = runner.invoke(remove_title_page,
29 | ['in/test-without.txt', 'in/test-with.txt',
30 | '--out_dir', 'out'])
31 |
32 | assert result.exit_code == 0
33 |
34 | assert os.path.exists('out/test-with.txt')
35 |
36 | with open('out/test-with.txt') as f:
37 | c = f.read()
38 |
39 | assert c == content_without
40 |
41 |
42 | def test_remove_title_page_multiple_lines():
43 | runner = CliRunner()
44 | with runner.isolated_filesystem():
45 | os.makedirs('in')
46 | os.makedirs('out')
47 |
48 | with open('in/test-without.txt', 'w') as f:
49 | content_without = 'Text starts here.\n' \
50 | 'Second line.\n'
51 | f.write(content_without)
52 |
53 | with open('in/test-with.txt', 'w') as f:
54 | content = 'This is the title page 1.\n' \
55 | 'This is the title page 2.\n' \
56 | 'Text starts here.\n' \
57 | 'Second line.\n'
58 | f.write(content)
59 |
60 | result = runner.invoke(remove_title_page,
61 | ['in/test-without.txt', 'in/test-with.txt',
62 | '--out_dir', 'out'])
63 |
64 | assert result.exit_code == 0
65 |
66 | assert os.path.exists('out/test-with.txt')
67 |
68 | with open('out/test-with.txt') as f:
69 | c = f.read()
70 |
71 | assert c == content_without
72 |
73 |
74 | def test_remove_title_page_no_lines():
75 | runner = CliRunner()
76 | with runner.isolated_filesystem():
77 | os.makedirs('in')
78 | os.makedirs('out')
79 |
80 | with open('in/test-without.txt', 'w') as f:
81 | content_without = 'Text starts here.\n' \
82 | 'Second line.\n'
83 | f.write(content_without)
84 |
85 | with open('in/test-with.txt', 'w') as f:
86 | content = 'Text starts here.\n' \
87 | 'Second line.\n'
88 | f.write(content)
89 |
90 | result = runner.invoke(remove_title_page,
91 | ['in/test-without.txt', 'in/test-with.txt',
92 | '--out_dir', 'out'])
93 |
94 | assert result.exit_code == 0
95 |
96 | assert os.path.exists('out/test-with.txt')
97 |
98 | with open('out/test-with.txt') as f:
99 | c = f.read()
100 |
101 | assert c == content_without
102 |
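103 | # These tests run the command through click's CliRunner inside an isolated
104 | # filesystem, so nothing is written outside a temporary directory. They can
105 | # be run with, for example:
106 | #
107 | #     pytest tests/test_remove_title_page.py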
--------------------------------------------------------------------------------
/notebooks/sac-preprocess-workflow.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import os\n",
20 | "import nlppln\n",
21 | "import ochre\n",
22 | "\n",
23 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n",
24 | " wf.load(steps_dir=ochre.cwl_path())\n",
25 | " wf.load(step_file='https://raw.githubusercontent.com/nlppln/ocrevaluation-docker/master/ocrevaluation.cwl')\n",
26 | " \n",
27 |     "    print(wf.list_steps())\n",
28 | " \n",
29 | " archive = wf.add_input(data_file='File')\n",
30 | " ocr_dir_name = wf.add_input(ocr_dir_name='string', default='ocr')\n",
31 | " gs_dir_name = wf.add_input(gs_dir_name='string', default='gs')\n",
32 | " align_dir_name = wf.add_input(align_dir_name='string', default='align')\n",
33 | " de_dir_name = wf.add_input(de_dir_name='string', default='de')\n",
34 | " fr_dir_name = wf.add_input(fr_dir_name='string', default='fr')\n",
35 | " \n",
36 | " data_dir = wf.tar(in_file=archive)\n",
37 | " gs_de, gs_fr, ocr_de, ocr_fr = wf.sac2gs_and_ocr(in_dir=data_dir)\n",
38 | " \n",
39 | " # create alignments\n",
40 | " alignments_de, changes_de, metadata_de = wf.align_texts_wf(gs=gs_de, ocr=ocr_de)\n",
41 | " alignments_fr, changes_fr, metadata_fr = wf.align_texts_wf(gs=gs_fr, ocr=ocr_fr)\n",
42 | " \n",
43 | " # save files to correct dirs\n",
44 | " de_dir = wf.mkdir(dir_name=de_dir_name)\n",
45 | " fr_dir = wf.mkdir(dir_name=fr_dir_name)\n",
46 | " \n",
47 | " ocr_de_dir = wf.save_files_to_dir(dir_name=ocr_dir_name, in_files=ocr_de)\n",
48 | " ocr_de = wf.save_dir_to_subdir(inner_dir=ocr_de_dir, outer_dir=de_dir)\n",
49 | " gs_de_dir = wf.save_files_to_dir(dir_name=gs_dir_name, in_files=gs_de)\n",
50 | " gs_de = wf.save_dir_to_subdir(inner_dir=gs_de_dir, outer_dir=de_dir)\n",
51 | " \n",
52 | " ocr_fr_dir = wf.save_files_to_dir(dir_name=ocr_dir_name, in_files=ocr_fr)\n",
53 | " ocr_fr = wf.save_dir_to_subdir(inner_dir=ocr_fr_dir, outer_dir=fr_dir)\n",
54 | " gs_fr_dir = wf.save_files_to_dir(dir_name=gs_dir_name, in_files=gs_fr)\n",
55 | " gs_fr = wf.save_dir_to_subdir(inner_dir=gs_fr_dir, outer_dir=fr_dir)\n",
56 | " \n",
57 | " align_de_dir = wf.save_files_to_dir(dir_name=align_dir_name, in_files=alignments_de)\n",
58 | " align_de = wf.save_dir_to_subdir(inner_dir=align_de_dir, outer_dir=de_dir)\n",
59 | " \n",
60 | " align_fr_dir = wf.save_files_to_dir(dir_name=align_dir_name, in_files=alignments_fr)\n",
61 | " align_fr = wf.save_dir_to_subdir(inner_dir=align_fr_dir, outer_dir=fr_dir)\n",
62 | " \n",
63 | " wf.add_outputs(ocr_de=ocr_de)\n",
64 | " wf.add_outputs(gs_de=gs_de)\n",
65 | " wf.add_outputs(align_de=align_de)\n",
66 | " \n",
67 | " wf.add_outputs(ocr_fr=ocr_fr)\n",
68 | " wf.add_outputs(gs_fr=gs_fr)\n",
69 | " wf.add_outputs(align_fr=align_fr)\n",
70 | " \n",
71 | " wf.save(os.path.join(ochre.cwl_path(), 'sac-preprocess.cwl'), pack=True)"
72 | ]
73 | }
74 | ],
75 | "metadata": {
76 | "language_info": {
77 | "name": "python",
78 | "pygments_lexer": "ipython3"
79 | }
80 | },
81 | "nbformat": 4,
82 | "nbformat_minor": 2
83 | }
84 |
--------------------------------------------------------------------------------
/ochre/rmgarbage.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """Implementation of rmgarbage.
4 |
5 | As described in the paper:
6 |
7 | Taghva, K., Nartker, T., Condit, A. and Borsack, J., 2001. Automatic
8 | removal of "garbage strings" in OCR text: An implementation. In The 5th
9 | World Multi-Conference on Systemics, Cybernetics and Informatics.
10 | """
11 | import click
12 | import codecs
13 | import os
14 | import pandas as pd
15 |
16 | from string import punctuation
17 |
18 | from nlppln.utils import create_dirs, out_file_name
19 |
20 |
21 | def get_rmgarbage_errors(word):
22 | errors = []
23 | if rmgarbage_long(word):
24 | errors.append('L')
25 | if rmgarbage_alphanumeric(word):
26 | errors.append('A')
27 | if rmgarbage_row(word):
28 | errors.append('R')
29 | if rmgarbage_vowels(word):
30 | errors.append('V')
31 | if rmgarbage_punctuation(word):
32 | errors.append('P')
33 | if rmgarbage_case(word):
34 | errors.append('C')
35 | return errors
36 |
37 |
38 | def rmgarbage_long(string, threshold=40):
39 | if len(string) > threshold:
40 | return True
41 | return False
42 |
43 |
44 | def rmgarbage_alphanumeric(string):
45 | alphanumeric_chars = sum(c.isalnum() for c in string)
46 | if len(string) > 2 and (alphanumeric_chars+0.0)/len(string) < 0.5:
47 | return True
48 | return False
49 |
50 |
51 | def rmgarbage_row(string, rep=4):
52 | for c in string:
53 | if c.isalnum():
54 | if c * rep in string:
55 | return True
56 | return False
57 |
58 |
59 | def rmgarbage_vowels(string):
60 | string = string.lower()
61 | if len(string) > 2 and string.isalpha():
62 | vowels = sum(c in u'aáâàåãäeéèëêuúûùüiíîìïoóôòøõö' for c in string)
63 | consonants = len(string) - vowels
64 |
65 | low = min(vowels, consonants)
66 | high = max(vowels, consonants)
67 |
68 | if low/(high+0.0) <= 0.1:
69 | return True
70 | return False
71 |
72 |
73 | def rmgarbage_punctuation(string):
74 |     string = string[1:-1]  # punctuation on the first and last position is allowed
75 |
76 | punctuation_marks = set()
77 | for c in string:
78 | if c in punctuation:
79 | punctuation_marks.add(c)
80 |
81 | if len(punctuation_marks) > 1:
82 | return True
83 | return False
84 |
85 |
86 | def rmgarbage_case(string):
87 |     if string[0].islower() and string[-1].islower():
88 | for c in string:
89 | if c.isupper():
90 | return True
91 | return False
92 |
93 |
94 | @click.command()
95 | @click.argument('in_file', type=click.File(encoding='utf-8'))
96 | @click.option('--out_dir', '-o', default=os.getcwd(), type=click.Path())
97 | def rmgarbage(in_file, out_dir):
98 | create_dirs(out_dir)
99 |
100 | text = in_file.read()
101 | words = text.split()
102 |
103 | doc_id = os.path.basename(in_file.name).split('.')[0]
104 |
105 | result = []
106 | removed = []
107 |
108 | for word in words:
109 | errors = get_rmgarbage_errors(word)
110 |
111 | if len(errors) == 0:
112 | result.append(word)
113 | else:
114 | removed.append({'word': word,
115 | 'errors': u''.join(errors),
116 | 'doc_id': doc_id})
117 |
118 | out_file = out_file_name(out_dir, in_file.name)
119 | with codecs.open(out_file, 'wb', encoding='utf-8') as f:
120 | f.write(u' '.join(result))
121 |
122 | metadata_out = pd.DataFrame(removed)
123 | fname = '{}-rmgarbage-metadata.csv'.format(doc_id)
124 | out_file = out_file_name(out_dir, fname)
125 | metadata_out.to_csv(out_file, encoding='utf-8')
126 |
127 |
128 | if __name__ == '__main__':
129 | rmgarbage()
130 |
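131 | # Illustrative examples (not part of the original module). The flags appear
132 | # in the rule order of get_rmgarbage_errors above (L, A, R, V, P, C):
133 | #
134 | #     >>> get_rmgarbage_errors('moment')  # clean word
135 | #     []
136 | #     >>> get_rmgarbage_errors('mOment')  # uppercase inside a lowercase word
137 | #     ['C']
138 | #     >>> get_rmgarbage_errors('aaaa')    # repeated row, no consonants
139 | #     ['R', 'V']
140 | #     >>> get_rmgarbage_errors('#$%a')    # mostly non-alphanumeric, two distinct
141 | #     ['A', 'P']                          # punctuation marks inside the word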
--------------------------------------------------------------------------------
/ochre/keras_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import glob2
4 |
5 | from keras.models import Sequential
6 | from keras.layers import Dense
7 | from keras.layers import Dropout
8 | from keras.layers import LSTM
9 | from keras.layers import TimeDistributed
10 | from keras.layers import Bidirectional
11 | from keras.layers import RepeatVector
12 | from keras.layers import Embedding
13 | from keras.callbacks import ModelCheckpoint
14 |
15 |
16 | def initialize_model(n, dropout, seq_length, chars, output_size, layers,
17 | loss='categorical_crossentropy', optimizer='adam'):
18 | model = Sequential()
19 | model.add(LSTM(n, input_shape=(seq_length, len(chars)),
20 | return_sequences=True))
21 | model.add(Dropout(dropout))
22 |
23 | for _ in range(layers-1):
24 | model.add(LSTM(n, return_sequences=True))
25 | model.add(Dropout(dropout))
26 |
27 | model.add(TimeDistributed(Dense(len(chars), activation='softmax')))
28 |
29 | model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])
30 |
31 | return model
32 |
33 |
34 | def initialize_model_bidirectional(n, dropout, seq_length, chars, output_size,
35 | layers, loss='categorical_crossentropy',
36 | optimizer='adam'):
37 | model = Sequential()
38 | model.add(Bidirectional(LSTM(n, return_sequences=True),
39 | input_shape=(seq_length, len(chars))))
40 | model.add(Dropout(dropout))
41 |
42 | for _ in range(layers-1):
43 | model.add(Bidirectional(LSTM(n, return_sequences=True)))
44 | model.add(Dropout(dropout))
45 |
46 | model.add(TimeDistributed(Dense(len(chars), activation='softmax')))
47 |
48 | model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])
49 |
50 | return model
51 |
52 |
53 | def initialize_model_seq2seq(n, dropout, seq_length,
54 | output_size, layers, char_embedding_size=0,
55 | loss='categorical_crossentropy', optimizer='adam',
56 | metrics=['accuracy']):
57 | model = Sequential()
58 | # encoder
59 | if char_embedding_size:
60 | n_embed = char_embedding_size
61 | model.add(Embedding(output_size, n_embed, input_length=seq_length))
62 | model.add(LSTM(n))
63 | else:
64 | model.add(LSTM(n, input_shape=(seq_length, output_size)))
65 | # For the decoder's input, we repeat the encoded input for each time step
66 | model.add(RepeatVector(seq_length))
67 | # The decoder RNN could be multiple layers stacked or a single layer
68 | for _ in range(layers-1):
69 | model.add(LSTM(n, return_sequences=True))
70 |
71 |     # For each step of the output sequence, decide which character should be
72 | # chosen
73 | model.add(TimeDistributed(Dense(output_size, activation='softmax')))
74 | model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
75 |
76 | return model
77 |
78 |
79 | def load_weights(model, weights_dir, loss='categorical_crossentropy',
80 | optimizer='adam'):
81 | epoch = 0
82 | weight_files = glob2.glob('{}{}*.hdf5'.format(weights_dir, os.sep))
83 | if weight_files != []:
84 | fname = sorted(weight_files)[0]
85 | print('Loading weights from {}'.format(fname))
86 |
87 | model.load_weights(fname)
88 | model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])
89 |
90 |         m = re.match(r'.+-(\d\d)\.hdf5', fname)
91 | if m:
92 | epoch = int(m.group(1))
93 | epoch += 1
94 |
95 | return epoch, model
96 |
97 |
98 | def add_checkpoint(weights_dir):
99 | filepath = os.path.join(weights_dir, '{loss:.4f}-{epoch:02d}.hdf5')
100 | checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1,
101 | save_best_only=True, mode='min')
102 | return checkpoint
103 |
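104 | # Minimal usage sketch (values are assumptions; this mirrors how
105 | # lstm_synced.py and ochre/train_seq2seq.py wire the helpers together):
106 | #
107 | #     model = initialize_model_seq2seq(256, 0.5, 25, output_size=60,
108 | #                                      layers=2, char_embedding_size=16)
109 | #     epoch, model = load_weights(model, 'weights')  # resume if weights exist
110 | #     callbacks = [add_checkpoint('weights')]
111 | #
112 | # Note that initialize_model_seq2seq accepts a dropout argument but, unlike
113 | # the other two model builders, does not currently add Dropout layers.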
--------------------------------------------------------------------------------
/notebooks/preprocess-dbnl_ocr.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import os\n",
20 | "import nlppln\n",
21 | "import ochre\n",
22 | "\n",
23 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n",
24 | " wf.load(steps_dir=ochre.cwl_path())\n",
25 | " #print(wf.list_steps())\n",
26 | " \n",
27 | " txt_dir = wf.add_input(txt_dir='Directory')\n",
28 | " tei_dir = wf.add_input(tei_dir='Directory')\n",
29 | " repl = wf.add_input(replacement='string', default='space')\n",
30 | " tmp_ocr_dir_name = wf.add_input(tmp_ocr_dir_name='string', default='tmp_ocr')\n",
31 | " tmp_gs_dir_name = wf.add_input(tmp_gs_dir_name='string', default='tmp_gs')\n",
32 | " convert = wf.add_input(convert='boolean', default=True)\n",
33 | " \n",
34 | " tei_files = wf.ls(in_dir=tei_dir)\n",
35 | " unnormalized_gs_files = wf.tei2txt(tei_file=tei_files, scatter='tei_file', scatter_method='dotproduct')\n",
36 | " normalized_gs_files = wf.remove_newlines(in_file=unnormalized_gs_files, replacement=repl, \n",
37 | " scatter='in_file', scatter_method='dotproduct')\n",
38 | " tmp_gs_dir = wf.save_files_to_dir(dir_name=tmp_gs_dir_name, in_files=normalized_gs_files)\n",
39 | " gs_without_empty = wf.delete_empty_files(in_dir=tmp_gs_dir)\n",
40 | " \n",
41 | " non_utf8_ocr_files = wf.ls(in_dir=txt_dir)\n",
42 | " unnormalized_ocr_files = wf.check_utf8(in_files=non_utf8_ocr_files, convert=convert)\n",
43 | " normalized_ocr_files = wf.remove_newlines(in_file=unnormalized_ocr_files, replacement=repl, \n",
44 | " scatter='in_file', scatter_method='dotproduct')\n",
45 | " tmp_ocr_dir = wf.save_files_to_dir(dir_name=tmp_ocr_dir_name, in_files=normalized_ocr_files)\n",
46 | " ocr_without_empty = wf.delete_empty_files(in_dir=tmp_ocr_dir)\n",
47 | " \n",
48 | " tmp_gs_dir = wf.save_files_to_dir(dir_name=tmp_gs_dir_name, in_files=gs_without_empty)\n",
49 | " tmp_ocr_dir = wf.save_files_to_dir(dir_name=tmp_ocr_dir_name, in_files=ocr_without_empty)\n",
50 | "\n",
51 | " gs, ocr = wf.match_ocr_and_gs(gs_dir=tmp_gs_dir, ocr_dir=tmp_ocr_dir)\n",
52 | "\n",
53 | " wf.add_outputs(gs=gs)\n",
54 | " wf.add_outputs(ocr=ocr)\n",
55 | " \n",
56 |     "    wf.save(os.path.join(ochre.cwl_path(), 'dbnl_ocr2ocr_and_gs.cwl'), mode='pack')"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "import logging\n",
66 | "logging.basicConfig(format=\"%(asctime)s [%(process)d] %(levelname)-8s \"\n",
67 | " \"%(name)s,%(lineno)s\\t%(message)s\")\n",
68 | "logging.getLogger().setLevel('DEBUG')"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "import os\n",
78 | "import nlppln\n",
79 | "import ochre\n",
80 | "\n",
81 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n",
82 | " wf.load(steps_dir=ochre.cwl_path())\n",
83 | " \n",
84 | " in_dir = wf.add_input(in_dir='Directory')\n",
85 | " \n",
86 | " in_files = wf.ls(in_dir=in_dir)\n",
87 | " txt_files = wf.tei2txt(tei_file=in_files, scatter='tei_file')\n",
88 | " out_files = wf.delete_empty_files(in_files=txt_files)\n",
89 | " \n",
90 | " wf.add_outputs(out_files=out_files)\n",
91 | " \n",
92 |     "    wf.save(os.path.join(ochre.cwl_path(), 'tei2txt_dir.cwl'), mode='wd')"
93 | ]
94 | }
95 | ],
96 | "metadata": {
97 | "language_info": {
98 | "name": "python",
99 | "pygments_lexer": "ipython3"
100 | }
101 | },
102 | "nbformat": 4,
103 | "nbformat_minor": 2
104 | }
105 |
--------------------------------------------------------------------------------
/lstm_synced.py:
--------------------------------------------------------------------------------
1 | from keras.callbacks import ModelCheckpoint
2 |
3 | from ochre.utils import create_training_data, read_texts, \
4 | get_char_to_int
5 | from ochre.keras_utils import initialize_model, load_weights, \
6 | initialize_model_bidirectional, initialize_model_seq2seq
7 |
8 | import click
9 | import os
10 | import json
11 | import codecs
12 |
13 |
14 | @click.command()
15 | @click.argument('datasets', type=click.File())
16 | @click.argument('data_dir', type=click.Path(exists=True))
17 | @click.option('--weights_dir', '-w', default=os.getcwd(), type=click.Path())
18 | def train_lstm(datasets, data_dir, weights_dir):
19 |     # read the data and create character mappings
20 |     # generate the training data
21 | seq_length = 25
22 | num_nodes = 256
23 | layers = 2
24 | batch_size = 100
25 | step = 3 # step size used to create data (3 = use every third sequence)
26 | lowercase = True
27 | bidirectional = False
28 | seq2seq = True
29 |
30 |     print('Sequence length: {}'.format(seq_length))
31 | print('Number of nodes in hidden layers: {}'.format(num_nodes))
32 | print('Number of hidden layers: {}'.format(layers))
33 | print('Batch size: {}'.format(batch_size))
34 | print('Lowercase data: {}'.format(lowercase))
35 | print('Bidirectional layers: {}'.format(bidirectional))
36 | print('Seq2seq: {}'.format(seq2seq))
37 |
38 | division = json.load(datasets)
39 |
40 | raw_val, gs_val, ocr_val = read_texts(division.get('val'), data_dir)
41 | raw_test, gs_test, ocr_test = read_texts(division.get('test'), data_dir)
42 | raw_train, gs_train, ocr_train = read_texts(division.get('train'), data_dir)
43 |
44 | raw_text = ''.join([raw_val, raw_test, raw_train])
45 | if lowercase:
46 | raw_text = raw_text.lower()
47 |
48 | #print('Number of texts: {}'.format(len(data_files)))
49 |
50 | chars = sorted(list(set(raw_text)))
51 | chars.append(u'\n') # padding character
52 | char_to_int = get_char_to_int(chars)
53 |
54 | # save charset to file
55 | if lowercase:
56 | fname = 'chars-lower.txt'
57 | else:
58 | fname = 'chars.txt'
59 | chars_file = os.path.join(weights_dir, fname)
60 | with codecs.open(chars_file, 'wb', encoding='utf-8') as f:
61 | f.write(u''.join(chars))
62 |
63 | n_chars = len(raw_text)
64 | n_vocab = len(chars)
65 |
66 | print('Total Characters: {}'.format(n_chars))
67 | print('Total Vocab: {}'.format(n_vocab))
68 |
69 | numTrainSamples, trainDataGen = create_training_data(ocr_train, gs_train, char_to_int, n_vocab, seq_length=seq_length, batch_size=batch_size, lowercase=lowercase, step=step)
70 | numTestSamples, testDataGen = create_training_data(ocr_test, gs_test, char_to_int, n_vocab, seq_length=seq_length, batch_size=batch_size, lowercase=lowercase)
71 | numValSamples, valDataGen = create_training_data(ocr_val, gs_val, char_to_int, n_vocab, seq_length=seq_length, batch_size=batch_size, lowercase=lowercase)
72 |
73 | n_patterns = numTrainSamples
74 | print("Train Patterns: {}".format(n_patterns))
75 | print("Validation Patterns: {}".format(numValSamples))
76 | print("Test Patterns: {}".format(numTestSamples))
77 | print('Total: {}'.format(numTrainSamples+numTestSamples+numValSamples))
78 |
79 | if bidirectional:
80 | model = initialize_model_bidirectional(num_nodes, 0.5, seq_length,
81 | chars, n_vocab, layers)
82 | elif seq2seq:
83 | model = initialize_model_seq2seq(num_nodes, 0.5, seq_length,
84 | n_vocab, layers)
85 | else:
86 | model = initialize_model(num_nodes, 0.5, seq_length, chars, n_vocab,
87 | layers)
88 | epoch, model = load_weights(model, weights_dir)
89 |
90 | # initialize saving of weights
91 | filepath = os.path.join(weights_dir, '{loss:.4f}-{epoch:02d}.hdf5')
92 | checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1,
93 | save_best_only=True, mode='min')
94 | callbacks_list = [checkpoint]
95 |
96 | # do training (and save weights)
97 | model.fit_generator(trainDataGen, steps_per_epoch=int(numTrainSamples/batch_size), epochs=40, validation_data=valDataGen, validation_steps=int(numValSamples/batch_size), callbacks=callbacks_list, initial_epoch=epoch)
98 |
99 |
100 | if __name__ == '__main__':
101 | train_lstm()
102 |
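103 | # Illustrative invocation (hypothetical paths). The datasets argument is a
104 | # JSON file mapping 'train', 'val' and 'test' to lists of aligned files in
105 | # data_dir; weights and the character set are written to --weights_dir:
106 | #
107 | #     python lstm_synced.py datadivision.json data/aligned/ -w weights/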
--------------------------------------------------------------------------------
/notebooks/align-workflow.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import nlppln\n",
20 | "import ochre\n",
21 | "import os\n",
22 | "\n",
23 | "working_dir = '~/cwl-working-dir/'\n",
24 | "\n",
25 | "with nlppln.WorkflowGenerator(working_dir=working_dir) as wf:\n",
26 | " wf.load(steps_dir=ochre.cwl_path())\n",
27 | " print(wf.list_steps())\n",
28 | "\n",
29 | " gs_files = wf.add_input(gs='File[]')\n",
30 | " ocr_files = wf.add_input(ocr='File[]')\n",
31 | " merged_metadata_name = wf.add_input(align_m='string', default='merged_metadata.csv')\n",
32 | " merged_changes_name = wf.add_input(align_c='string', default='merged_changes.csv')\n",
33 | " \n",
34 | " changes, metadata = wf.align(file1=ocr_files, file2=gs_files, \n",
35 | " scatter=['file1', 'file2'], scatter_method='dotproduct')\n",
36 | " merged1 = wf.merge_json(in_files=metadata, name=merged_metadata_name)\n",
37 | " merged2 = wf.merge_json(in_files=changes, name=merged_changes_name)\n",
38 | " \n",
39 | " alignments = wf.char_align(gs_text=gs_files, metadata=metadata, ocr_text=ocr_files, \n",
40 | " scatter=['gs_text', 'ocr_text', 'metadata'], scatter_method='dotproduct')\n",
41 | " \n",
42 | " wf.add_outputs(alignments=alignments)\n",
43 | " wf.add_outputs(metadata=merged1)\n",
44 | " wf.add_outputs(changes=merged2)\n",
45 | " wf.save(os.path.join(ochre.cwl_path(), 'align-texts-wf.cwl'), wd=True, relative=False)"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "# align directory\n",
55 | "import nlppln\n",
56 | "import ochre\n",
57 | "import os\n",
58 | "\n",
59 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n",
60 | " wf.load(steps_dir=ochre.cwl_path())\n",
61 | " print(wf.list_steps())\n",
62 | "\n",
63 | " gs = wf.add_input(gs='Directory')\n",
64 | " ocr = wf.add_input(ocr='Directory')\n",
65 | " align_dir_name = wf.add_input(align_dir_name='string', default='align')\n",
66 | " \n",
67 | " gs_files = wf.ls(in_dir=gs)\n",
68 | " ocr_files = wf.ls(in_dir=ocr)\n",
69 | " \n",
70 | " alignments, changes, metadata = wf.align_texts_wf(gs=gs_files, ocr=ocr_files)\n",
71 | " \n",
72 | " align = wf.save_files_to_dir(dir_name=align_dir_name, in_files=alignments)\n",
73 | " \n",
74 | " wf.add_outputs(align=align)\n",
75 | " wf.save(os.path.join(ochre.cwl_path(), 'align-dir-pack.cwl'), pack=True, relative=False)"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "# align test files only\n",
85 | "import nlppln\n",
86 | "import ochre\n",
87 | "import os\n",
88 | "\n",
89 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n",
90 | " wf.load(steps_dir=ochre.cwl_path())\n",
91 | " print(wf.list_steps())\n",
92 | "\n",
93 | " gs_dir = wf.add_input(gs_dir='Directory')\n",
94 | " ocr_dir = wf.add_input(ocr_dir='Directory')\n",
95 | " data_div = wf.add_input(data_div='File')\n",
96 | " div_name = wf.add_input(div_name='string')\n",
97 | " align_dir_name = wf.add_input(align_dir_name='string', default='align')\n",
98 | "\n",
99 | " test_gs = wf.select_test_files(datadivision=data_div, name=div_name, in_dir=gs_dir)\n",
100 | " test_ocr = wf.select_test_files(datadivision=data_div, name=div_name, in_dir=ocr_dir)\n",
101 | "\n",
102 | " alignments, changes, metadata = wf.align_texts_wf(gs=test_gs, ocr=test_ocr)\n",
103 | "\n",
104 | " align = wf.save_files_to_dir(dir_name=align_dir_name, in_files=alignments)\n",
105 | "\n",
106 | " wf.add_outputs(align=align)\n",
107 | " wf.save(os.path.join(ochre.cwl_path(), 'align-test-files-pack.cwl'), pack=True, relative=False)"
108 | ]
109 | }
110 | ],
111 | "metadata": {
112 | "language_info": {
113 | "name": "python",
114 | "pygments_lexer": "ipython3"
115 | }
116 | },
117 | "nbformat": 4,
118 | "nbformat_minor": 2
119 | }
120 |
--------------------------------------------------------------------------------
/notebooks/icdar2017-ngrams.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "%matplotlib inline\n",
20 | "\n",
21 | "import numpy as np\n",
22 | "import pandas as pd\n",
23 | "import matplotlib.pyplot as plt"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "import json\n",
33 | "\n",
34 | "from ochre.utils import read_texts\n",
35 | "\n",
36 | "datasets = '/home/jvdzwaan/data/icdar2017st/eng_monograph/datadivision.json'\n",
37 | "data_dir = '/home/jvdzwaan/data/icdar2017st/eng_monograph/aligned/'\n",
38 | "\n",
39 | "with open(datasets) as d:\n",
40 | " division = json.load(d)\n",
41 | "print(len(division['train']))\n",
42 | "print(len(division['test']))\n",
43 | "print(len(division['val']))"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "from ochre.utils import get_chars, get_sequences\n",
53 | "\n",
54 | "seq_length = 53\n",
55 | "\n",
56 | "raw_val, gs_val, ocr_val = read_texts(division.get('val'), data_dir)\n",
57 | "raw_test, gs_test, ocr_test = read_texts(division.get('test'), data_dir)\n",
58 | "raw_train, gs_train, ocr_train = read_texts(division.get('train'), data_dir)\n",
59 | "\n",
60 | "chars, num_chars, ci = get_chars(raw_val, raw_test, raw_train, False)"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "%%time\n",
70 | "from collections import defaultdict, Counter\n",
71 | "\n",
72 | "from nlppln.utils import get_files\n",
73 | "\n",
74 | "#c = defaultdict(int)\n",
75 | "c = Counter()\n",
76 | "\n",
77 | "data_dirs = ['/home/jvdzwaan/data/icdar2017st/eng_monograph/aligned/',\n",
78 | " '/home/jvdzwaan/data/icdar2017st/eng_periodical/aligned/',\n",
79 | " '/home/jvdzwaan/data/icdar2017st/fr_monograph/aligned/',\n",
80 | " '/home/jvdzwaan/data/icdar2017st/fr_periodical/aligned/']\n",
81 | "\n",
82 | "for dd in data_dirs:\n",
83 | " in_files = get_files(dd)\n",
84 | " \n",
85 | " raw, gs, ocr = read_texts(in_files, data_dir=None)\n",
86 | " \n",
87 | " for c1, c2 in zip(ocr, gs):\n",
88 | " if c1 != c2 and c1 != ' ' and c2 != ' ':\n",
89 | " c[(c1, c2)] += 1"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "for ch, f in c.most_common(10):\n",
99 | " print(ch, f)"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": null,
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "c[('1', 'I')]"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "%%time\n",
118 | "from ochre.utils import get_sequences\n",
119 | "\n",
120 | "data_dirs = ['/home/jvdzwaan/data/icdar2017st/eng_monograph/aligned/',\n",
121 | " '/home/jvdzwaan/data/icdar2017st/eng_periodical/aligned/',\n",
122 | " '/home/jvdzwaan/data/icdar2017st/fr_monograph/aligned/',\n",
123 | " '/home/jvdzwaan/data/icdar2017st/fr_periodical/aligned/']\n",
124 | "\n",
125 | "c = Counter()\n",
126 | "\n",
127 | "for dd in data_dirs:\n",
128 | " in_files = get_files(dd)\n",
129 | " \n",
130 | " raw, gs, ocr = read_texts(in_files, data_dir=None)\n",
131 | " \n",
132 | " gs_seqs, ocr_seqs = get_sequences(gs, ocr, 3)\n",
133 | " \n",
134 | " for ocr, gs in zip(ocr_seqs, gs_seqs):\n",
135 | " c[(ocr, gs)] += 1"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "for ch, f in c.most_common(20):\n",
145 | " print(ch, f)"
146 | ]
147 | }
148 | ],
149 | "metadata": {
150 | "language_info": {
151 | "name": "python",
152 | "pygments_lexer": "ipython3"
153 | }
154 | },
155 | "nbformat": 4,
156 | "nbformat_minor": 2
157 | }
158 |
--------------------------------------------------------------------------------
/notebooks/2017-baseline-nn.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "%matplotlib inline\n",
20 | "\n",
21 | "import numpy as np\n",
22 | "import pandas as pd\n",
23 | "import matplotlib.pyplot as plt"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "import pickle\n",
33 | "\n",
34 | "seq_length = 53\n",
35 | "\n",
36 | "with open('train.pkl', 'rb') as f:\n",
37 | " gs_selected_train, ocr_selected_train = pickle.load(f)\n",
38 | " \n",
39 | "with open('val.pkl', 'rb') as f:\n",
40 | " gs_selected_val, ocr_selected_val = pickle.load(f)\n",
41 | " \n",
42 | "with open('ci.pkl', 'rb') as f:\n",
43 | " ci = pickle.load(f)"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "gs_selected_train[:50]"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "from ochre.datagen import DataGenerator\n",
62 | "\n",
63 | "dg_val = DataGenerator(xData=ocr_selected_val[:10], yData=gs_selected_val[:10], char_to_int=ci,\n",
64 | " seq_length=seq_length, padding_char='\\n', oov_char='@',\n",
65 | " batch_size=10, shuffle=False)\n",
66 | "dg_train = DataGenerator(xData=ocr_selected_train[:50], yData=gs_selected_train[:50], char_to_int=ci,\n",
67 | " seq_length=seq_length, padding_char='\\n', oov_char='@',\n",
68 | " batch_size=10, shuffle=False)"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "from keras.models import Sequential\n",
78 | "from keras.layers import Dense\n",
79 | "from keras.layers import Dropout\n",
80 | "from keras.layers import LSTM\n",
81 | "from keras.layers import TimeDistributed\n",
82 | "from keras.layers import Bidirectional\n",
83 | "from keras.layers import RepeatVector\n",
84 | "from keras.layers import Embedding\n",
85 | "from keras.callbacks import ModelCheckpoint\n",
86 | "\n",
87 | "n_nodes = 1000\n",
88 | "dropout = 0.2\n",
89 | "n_embed = 256\n",
90 | "n_vocab = len(ci)\n",
91 | "\n",
92 |     "loss = 'categorical_crossentropy'\n",
93 |     "optimizer = 'adam'\n",
94 |     "metrics = ['accuracy']\n",
95 | "\n",
96 | "model = Sequential()\n",
97 | "\n",
98 | "# encoder\n",
99 | "\n",
100 | "model.add(Embedding(n_vocab, n_embed, input_length=seq_length))\n",
101 |     "model.add(LSTM(n_nodes))  # the input shape follows from the Embedding layer\n",
102 | "# For the decoder's input, we repeat the encoded input for each time step\n",
103 | "model.add(RepeatVector(seq_length))\n",
104 | "model.add(LSTM(n_nodes, return_sequences=True))\n",
105 | "\n",
106 |     "# For each step of the output sequence, decide which character should be\n",
107 | "# chosen\n",
108 | "model.add(TimeDistributed(Dense(n_vocab, activation='softmax')))\n",
109 | "model.compile(loss=loss, optimizer=optimizer, metrics=metrics)"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "# initialize saving of weights\n",
119 | "#filepath = os.path.join(weights_dir, '{loss:.4f}-{epoch:02d}.hdf5')\n",
120 | "filepath = '{loss:.4f}-{epoch:02d}.hdf5'\n",
121 | "checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1,\n",
122 | " save_best_only=True, mode='min')\n",
123 | "callbacks_list = [checkpoint]\n",
124 | "\n",
125 | "# do training (and save weights)\n",
126 | "model.fit_generator(dg_train, steps_per_epoch=len(dg_train), epochs=10, \n",
127 | " validation_data=dg_val, \n",
128 | " validation_steps=len(dg_val), callbacks=callbacks_list,\n",
129 | " use_multiprocessing=True,\n",
130 | " workers=3)"
131 | ]
132 | }
133 | ],
134 | "metadata": {
135 | "language_info": {
136 | "name": "python",
137 | "pygments_lexer": "ipython3"
138 | }
139 | },
140 | "nbformat": 4,
141 | "nbformat_minor": 2
142 | }
143 |
--------------------------------------------------------------------------------
/notebooks/kb_tss_preprocess.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "# subworkflow to convert all alto files in the subdirectories of one directory into text files\n",
20 | "import nlppln\n",
21 | "\n",
22 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n",
23 | " wf.load(steps_dir='../ochre/cwl/')\n",
24 | " print(wf.list_steps())\n",
25 | " \n",
26 | " in_dir = wf.add_input(in_dir='Directory')\n",
27 | " \n",
28 | " # inputs with default values\n",
29 | " recursive = wf.add_input(recursive='boolean', default=True)\n",
30 | " endswith = wf.add_input(endswith='string', default='alto.xml')\n",
31 | " element = wf.add_input(element='string[]', default=['SP'])\n",
32 | " in_fmt = wf.add_input(in_fmt='string', default='alto')\n",
33 | " out_fmt = wf.add_input(out_fmt='string', default='text')\n",
34 | " \n",
35 | " in_files = wf.ls(in_dir=in_dir, recursive=recursive, endswith=endswith)\n",
36 | " cleaned_files = wf.remove_xml_elements(element=element, xml_file=in_files, \n",
37 | " scatter='xml_file', scatter_method='dotproduct')\n",
38 | " text_pages = wf.ocr_transform(in_file=cleaned_files, in_fmt=in_fmt, out_fmt=out_fmt,\n",
39 | " scatter='in_file', scatter_method='dotproduct')\n",
40 | " text_files = wf.kb_tss_concat_files(in_files=text_pages)\n",
41 | " \n",
42 | " \n",
43 | " wf.add_outputs(text_files=text_files)\n",
44 | " \n",
45 | " wf.save('../ochre/cwl/kb-tss-preprocess-single-dir.cwl', wd=True)"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "# Workflow to preprocess all kb_tss data\n",
55 | "import nlppln\n",
56 | "\n",
57 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n",
58 | " wf.load(steps_dir='../ochre/cwl/')\n",
59 | " print(wf.list_steps())\n",
60 | " \n",
61 | " karmac_dir = wf.add_input(karmac_dir='Directory')\n",
62 |     "    original_dir = wf.add_input(original_dir='Directory')\n",
63 | " xcago_dir = wf.add_input(xcago_dir='Directory')\n",
64 | " \n",
65 | " karmac_name = wf.add_input(karmac_name='string', default='Karmac')\n",
66 | " original_name = wf.add_input(original_name='string', default='Origineel')\n",
67 | " xcago_name = wf.add_input(xcago_name='string', default='X-Cago')\n",
68 | " \n",
69 | " karmac_aligned_name = wf.add_input(karmac_aligned_name='string', default='align-Karmac-Origineel')\n",
70 | " xcago_aligned_name = wf.add_input(xcago_aligned_name='string', default='align-X-Cago-Origineel')\n",
71 | " \n",
72 | " karmac_files = wf.kb_tss_preprocess_single_dir(in_dir=karmac_dir)\n",
73 | " original_files = wf.kb_tss_preprocess_single_dir(in_dir=original_dir)\n",
74 | " xcago_files = wf.kb_tss_preprocess_single_dir(in_dir=xcago_dir)\n",
75 | " \n",
76 | " karmac_dir_new = wf.save_files_to_dir(dir_name=karmac_name, in_files=karmac_files)\n",
77 | " original_dir_new = wf.save_files_to_dir(dir_name=original_name, in_files=original_files)\n",
78 | " xcago_dir_new = wf.save_files_to_dir(dir_name=xcago_name, in_files=xcago_files)\n",
79 | " \n",
80 | " karmac_alignments, karmac_changes, karmac_metadata = wf.align_texts_wf(gs=karmac_files, ocr=original_files)\n",
81 | " xcago_alignments, xcago_changes, xcago_metadata = wf.align_texts_wf(gs=xcago_files, ocr=original_files)\n",
82 | " \n",
83 | " karmac_align_dir = wf.save_files_to_dir(dir_name=karmac_aligned_name, in_files=karmac_alignments)\n",
84 | " xcago_align_dir = wf.save_files_to_dir(dir_name=xcago_aligned_name, in_files=xcago_alignments)\n",
85 | " \n",
86 | " wf.add_outputs(karmac=karmac_dir_new)\n",
87 | " wf.add_outputs(original=original_dir_new)\n",
88 | " wf.add_outputs(xcago=xcago_dir_new)\n",
89 | " wf.add_outputs(karmac_align=karmac_align_dir)\n",
90 | " wf.add_outputs(xcago_align=xcago_align_dir)\n",
91 | " \n",
92 | " wf.save('../ochre/cwl/kb-tss-preprocess-all.cwl', pack=True)"
93 | ]
94 | }
95 | ],
96 | "metadata": {
97 | "language_info": {
98 | "name": "python",
99 | "pygments_lexer": "ipython3"
100 | }
101 | },
102 | "nbformat": 4,
103 | "nbformat_minor": 2
104 | }
105 |
--------------------------------------------------------------------------------
/notebooks/ocr-evaluation-workflow.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import nlppln\n",
20 | "import ochre\n",
21 | "import os\n",
22 | "\n",
23 | "with nlppln.WorkflowGenerator(working_dir='/Users/jvdzwaan/cwl-working-dir/') as wf:\n",
24 | " wf.load(steps_dir=ochre.cwl_path())\n",
25 | " wf.load(step_file='https://raw.githubusercontent.com/nlppln/ocrevaluation-docker/master/ocrevaluation.cwl')\n",
26 | " \n",
27 |     "    #print(wf.list_steps())\n",
28 | "\n",
29 | " gt_file = wf.add_input(gt='File')\n",
30 | " ocr_file = wf.add_input(ocr='File')\n",
31 | " xmx = wf.add_input(xmx='string?')\n",
32 | "\n",
33 | " out_file = wf.ocrevaluation(gt=gt_file, ocr=ocr_file, xmx=xmx)\n",
34 | " character_data, global_data = wf.ocrevaluation_extract(in_file=out_file)\n",
35 | "\n",
36 | " wf.add_outputs(character_data=character_data)\n",
37 | " wf.add_outputs(global_data=global_data)\n",
38 | "\n",
39 | " # can be used as a subworkflow\n",
40 | " wf.save(os.path.join(ochre.cwl_path(), 'ocrevaluation-performance-wf.cwl'), mode='wd')"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "import nlppln\n",
50 | "import ochre\n",
51 | "import os\n",
52 | "\n",
53 | "with nlppln.WorkflowGenerator(working_dir='/Users/jvdzwaan/cwl-working-dir/') as wf:\n",
54 | " wf.load(steps_dir=ochre.cwl_path())\n",
55 | " \n",
56 | " print(wf.list_steps())\n",
57 | "\n",
58 | " gt_dir = wf.add_input(gt='Directory')\n",
59 | " ocr_dir = wf.add_input(ocr='Directory')\n",
60 | " xmx = wf.add_input(xmx='string?')\n",
61 | " performance_file = wf.add_input(out_name='string?', default='performance.csv')\n",
62 | "\n",
63 | " ocr_files = wf.ls(in_dir=ocr_dir)\n",
64 | " gt_files = wf.ls(in_dir=gt_dir)\n",
65 | " \n",
66 | " character_data, global_data = wf.ocrevaluation_performance_wf(gt=gt_files, ocr=ocr_files, xmx=xmx,\n",
67 | " scatter=['gt', 'ocr'], scatter_method='dotproduct')\n",
68 | " \n",
69 | " merged = wf.merge_csv(in_files=global_data, name=performance_file)\n",
70 | "\n",
71 | " wf.add_outputs(performance=merged)\n",
72 | "\n",
73 | " wf.save(os.path.join(ochre.cwl_path(), 'ocrevaluation-performance-wf-pack.cwl'), mode='pack')"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "import nlppln\n",
83 | "import ochre\n",
84 | "import os\n",
85 | "\n",
86 | "with nlppln.WorkflowGenerator(working_dir='/Users/jvdzwaan/cwl-working-dir/') as wf:\n",
87 | " wf.load(steps_dir=ochre.cwl_path())\n",
88 | " wf.load(step_file='https://raw.githubusercontent.com/nlppln/ocrevaluation-docker/master/ocrevaluation.cwl')\n",
89 | " print(wf.list_steps())\n",
90 | "\n",
91 | " ocr_dir = wf.add_input(ocr='Directory')\n",
92 | " gt_dir = wf.add_input(gt='Directory')\n",
93 | " datadiv = wf.add_input(datadivision='File')\n",
94 | " div_name = wf.add_input(div_name='string', default='test')\n",
95 | " gt_dir_name = wf.add_input(gt_dir_name='string', default='gs')\n",
96 | " ocr_dir_name = wf.add_input(ocr_dir_name='string', default='ocr')\n",
97 |     "    performance_file = wf.add_input(out_name='string?', default='performance.csv')\n",
98 | "\n",
99 | " ocr_files = wf.select_test_files(datadivision=datadiv, in_dir=ocr_dir, name=div_name)\n",
100 | " gt_files = wf.select_test_files(datadivision=datadiv, in_dir=gt_dir, name=div_name)\n",
101 | "\n",
102 | " character_data, global_data = wf.ocrevaluation_performance_wf(gt=gt_files, ocr=ocr_files, \n",
103 | " scatter=['gt', 'ocr'], \n",
104 | " scatter_method='dotproduct')\n",
105 | " \n",
106 | " merged = wf.merge_csv(in_files=global_data, name=performance_file)\n",
107 | "\n",
108 | " wf.add_outputs(performance=merged)\n",
109 | "\n",
110 | " wf.save(os.path.join(ochre.cwl_path(), 'ocrevaluation-performance-test-files-wf-pack.cwl'), mode='pack')"
111 | ]
112 | }
113 | ],
114 | "metadata": {
115 | "language_info": {
116 | "name": "python",
117 | "pygments_lexer": "ipython3"
118 | }
119 | },
120 | "nbformat": 4,
121 | "nbformat_minor": 2
122 | }
123 |
--------------------------------------------------------------------------------
/ochre/train_seq2seq.py:
--------------------------------------------------------------------------------
1 | from ochre.utils import initialize_model_seq2seq, load_weights, save_charset, \
2 | create_training_data, read_texts, get_chars, \
3 | add_checkpoint
4 |
5 | import click
6 | import os
7 | import json
8 |
9 |
10 | @click.command()
11 | @click.argument('datasets', type=click.File())
12 | @click.argument('data_dir', type=click.Path(exists=True))
13 | @click.option('--weights_dir', '-w', default=os.getcwd(), type=click.Path())
14 | def train_lstm(datasets, data_dir, weights_dir):
15 | seq_length = 25
16 | pred_chars = 1
17 | num_nodes = 256
18 | layers = 2
19 | batch_size = 100
20 | step = 3 # step size used to create data (3 = use every third sequence)
21 | lowercase = True
22 | char_embedding_size = 16
23 | pad = u'\n'
24 |
25 | print('Sequence length: {}'.format(seq_length))
26 | print('Predict characters: {}'.format(pred_chars))
27 | print('Use character embedding: {}'.format(bool(char_embedding_size)))
28 | if char_embedding_size:
29 | print('Char embedding size: {}'.format(char_embedding_size))
30 | print('Number of nodes in hidden layers: {}'.format(num_nodes))
31 | print('Number of hidden layers: {}'.format(layers))
32 | print('Batch size: {}'.format(batch_size))
33 | print('Lowercase data: {}'.format(lowercase))
34 |
35 | div = json.load(datasets)
36 |
37 | raw_val, gs_val, ocr_val = read_texts(div.get('val'), data_dir)
38 | raw_test, gs_test, ocr_test = read_texts(div.get('test'), data_dir)
39 | raw_train, gs_train, ocr_train = read_texts(div.get('train'), data_dir)
40 |
41 | chars, n_vocab, char_to_int = get_chars(raw_val, raw_test, raw_train,
42 | lowercase, padding_char=pad)
43 | # save charset to file
44 | save_charset(weights_dir, chars, lowercase)
45 |
46 | print('Total Vocab: {}'.format(n_vocab))
47 |
48 | embed = bool(char_embedding_size)
49 | nTrainSamples, trainData = create_training_data(ocr_train,
50 | gs_train,
51 | char_to_int,
52 | n_vocab,
53 | seq_length=seq_length,
54 | batch_size=batch_size,
55 | lowercase=lowercase,
56 | step=step,
57 | predict_chars=pred_chars,
58 | char_embedding=embed)
59 | nTestSamples, testData = create_training_data(ocr_test,
60 | gs_test,
61 | char_to_int,
62 | n_vocab,
63 | seq_length=seq_length,
64 | batch_size=batch_size,
65 | lowercase=lowercase,
66 | predict_chars=pred_chars,
67 | char_embedding=embed)
68 | nValSamples, valData = create_training_data(ocr_val,
69 | gs_val,
70 | char_to_int,
71 | n_vocab,
72 | seq_length=seq_length,
73 | batch_size=batch_size,
74 | lowercase=lowercase,
75 | predict_chars=pred_chars,
76 | char_embedding=embed)
77 |
78 | n_patterns = nTrainSamples
79 | print("Train Patterns: {}".format(n_patterns))
80 | print("Validation Patterns: {}".format(nValSamples))
81 | print("Test Patterns: {}".format(nTestSamples))
82 | print('Total: {}'.format(nTrainSamples+nTestSamples+nValSamples))
83 |
84 | model = initialize_model_seq2seq(num_nodes, 0.5, seq_length, pred_chars,
85 | n_vocab, layers, char_embedding_size)
86 | epoch, model = load_weights(model, weights_dir)
87 | callbacks_list = [add_checkpoint(weights_dir)]
88 |
89 | # do training (and save weights)
90 | model.fit_generator(trainData,
91 | steps_per_epoch=int(nTrainSamples/batch_size),
92 | epochs=40,
93 | validation_data=valData,
94 | validation_steps=int(nValSamples/batch_size),
95 | callbacks=callbacks_list,
96 | initial_epoch=epoch)
97 |
98 |
99 | if __name__ == '__main__':
100 | train_lstm()
101 |
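102 | # Illustrative invocation (hypothetical paths), analogous to lstm_synced.py:
103 | #
104 | #     python ochre/train_seq2seq.py datadivision.json data/aligned/ -w weights/
105 | #
106 | # Because load_weights parses the epoch from the best (lowest-loss) weight
107 | # file in weights_dir, training resumes from that checkpoint when restarted.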
--------------------------------------------------------------------------------
/notebooks/Compare-performance-of-pred.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "%matplotlib inline\n",
20 | "\n",
21 | "import numpy as np\n",
22 | "import pandas as pd\n",
23 | "import matplotlib.pyplot as plt"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "#performance_4 = pd.read_csv('/home/jvdzwaan/data/kb-ocr/performance_A8P1_model-50000_2019-05-16.csv', index_col=0)\n",
33 | "performance_4 = pd.read_csv('/home/jvdzwaan/data/kb-ocr/performance_A8P1_model-50000_2019-05-19.csv', index_col=0)\n",
34 | "performance_4.columns = ['CER_40000', 'WER_40000', 'WER (order independent)_40000']\n",
35 | "#performance_4"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "performance_12 = pd.read_csv('/home/jvdzwaan/data/kb-ocr/performance_A8P1_model-120000_2019-05-22.csv', index_col=0)\n",
45 | "performance_12.columns = ['CER_120000', 'WER_120000', 'WER (order independent)_120000']\n",
46 | "#performance_12"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": null,
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "performance_fr11 = pd.read_csv('/home/jvdzwaan/data/kb-ocr/FR11_match-gs_vs_GT_test-files_2019-05-22.csv', index_col=0)\n",
56 | "performance_fr11.columns = ['CER_FR11', 'WER_FR11', 'WER (order independent)_FR11']\n",
57 | "#performance_fr11"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "performance_old = pd.read_csv('/home/jvdzwaan/data/kb-ocr/text_aligned_blocks-match_gs-20190514.csv', index_col=0)\n",
67 | "#performance_old"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": null,
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "performance_old.columns = ['CER_old', 'WER_old', 'WER (order independent)_old']\n",
77 | "#performance_old"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "metadata": {},
84 | "outputs": [],
85 | "source": [
86 | "result = performance_12.join(performance_4).join(performance_old).join(performance_fr11)\n",
87 | "#result"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "metadata": {},
94 | "outputs": [],
95 | "source": [
96 | "result[['CER_old', 'CER_FR11', 'CER_40000', 'CER_120000']].plot(kind='bar', figsize=(15,7))"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "result[['WER_old', 'WER_FR11', 'WER_40000', 'WER_120000']].plot(kind='bar', figsize=(15,7))"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": null,
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "result[['WER (order independent)_old', 'WER (order independent)_FR11', 'WER (order independent)_40000', 'WER (order independent)_120000']].plot(kind='bar', figsize=(15,7))"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "print('CER old:')\n",
124 | "print(result['CER_old'].describe())"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "print('CER model')\n",
134 | "print(result['CER_120000'].describe())"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": null,
140 | "metadata": {},
141 | "outputs": [],
142 | "source": [
143 |     "print('Decrease:')\n",
144 | "print((result['CER_old'] - result['CER_120000']).describe())"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": null,
150 | "metadata": {},
151 | "outputs": [],
152 | "source": [
153 | "print('GT')\n",
154 | "print(result['CER_FR11'].describe())"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {},
161 | "outputs": [],
162 | "source": [
163 |     "print('Decrease:')\n",
164 | "print((result['CER_old'] - result['CER_FR11']).describe())"
165 | ]
166 | }
167 | ],
168 | "metadata": {
169 | "language_info": {
170 | "name": "python",
171 | "pygments_lexer": "ipython3"
172 | }
173 | },
174 | "nbformat": 4,
175 | "nbformat_minor": 2
176 | }
177 |
--------------------------------------------------------------------------------
/notebooks/character_counts.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import pandas as pd\n",
20 | "\n",
21 | "ocr = pd.read_csv('/Users/janneke/Documents/data/ocr/char_counts_ocr.csv', index_col=0, encoding='utf-8')\n",
22 |     "print(ocr.shape)\n",
23 | "ocr"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "gs = pd.read_csv('/Users/janneke/Documents/data/ocr/char_counts_gs.csv', index_col=0, encoding='utf-8')\n",
33 |     "print(gs.shape)\n",
34 | "gs"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "ocr_unique = set(ocr.columns) - set(gs.columns)\n",
44 |     "print(' '.join(ocr_unique))\n",
45 | "# characters that are in ocr, but not in gs"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "ocr[list(ocr_unique)].sum()"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "gs_unique = set(gs.columns) - set(ocr.columns)\n",
64 |     "print(' '.join(gs_unique))\n",
65 | "# characters that are in gs, but not in ocr"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "gs[list(gs_unique)].sum()"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "for c in gs_unique:\n",
84 |     "    print(c, repr(c), gs[c].sum())"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 |     "print(ocr.columns)\n",
94 |     "char_mapping = {}\n",
95 |     "multiple_chars = []\n",
96 |     "for c in list(ocr.columns) + list(gs.columns):\n",
97 |     "    l = len(c.encode('utf-8'))\n",
98 |     "    if l > 1:\n",
99 |     "        print(repr(c))\n",
100 |     "        multiple_chars.append(c)\n",
101 |     "    else:\n",
102 |     "        char_mapping[ord(c)] = c\n",
103 |     "        #print(ord(c))\n",
104 |     "print(len(multiple_chars))\n",
105 |     "#print(char_mapping)\n",
106 |     "\n",
107 |     "for c in set(multiple_chars):\n",
108 |     "    for i in range(1, 256):\n",
109 |     "        if i not in char_mapping:\n",
110 |     "            char_mapping[i] = c\n",
111 |     "            break\n",
112 |     "print(char_mapping)\n",
113 |     "print(len(char_mapping))\n",
114 |     "print(sorted(char_mapping.keys()))\n",
115 |     "print(u' '.join(multiple_chars))"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "mapping = {}\n",
125 | "for k, v in char_mapping.items():\n",
126 | " mapping[v] = k\n",
127 |     "print(mapping)"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": null,
133 | "metadata": {},
134 | "outputs": [],
135 | "source": [
136 |     "print(mapping[u'ó'])\n",
137 |     "print(char_mapping[8])"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "import json\n",
147 | "import codecs\n",
148 | "\n",
149 | "with codecs.open('../char_mapping.json', 'w', encoding='utf-8') as f:\n",
150 | " json.dump(mapping, f, indent=2)"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "metadata": {},
157 | "outputs": [],
158 | "source": [
159 |     "extended_ascii = bytes(range(256))\n",
160 |     "print(extended_ascii.decode('latin1'))\n",
161 |     "for i in range(256):\n",
162 |     "    if len(chr(i)) > 1:\n",
163 |     "        print(i)"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": null,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 |     "char_mapping = {}\n",
173 |     "multiple_chars = []\n",
174 |     "\n",
175 |     "for c in list(ocr.columns) + list(gs.columns):\n",
176 |     "    try:\n",
177 |     "        c.encode('latin1')  # check that c fits in latin-1\n",
178 |     "        char_mapping[ord(c)] = c\n",
179 |     "    except (UnicodeEncodeError, TypeError):\n",
180 |     "        multiple_chars.append(c)\n",
181 |     "    \n",
182 |     "for c in set(multiple_chars):\n",
183 |     "    for i in range(1, 256):\n",
184 |     "        if i not in char_mapping:\n",
185 |     "            char_mapping[i] = c\n",
186 |     "            break\n",
187 |     "\n",
188 |     "print(char_mapping)\n",
189 |     "mapping = {}\n",
190 |     "for k, v in char_mapping.items():\n",
191 |     "    mapping[v] = k\n",
192 |     "print(mapping)\n",
193 |     "with open('../char_mapping2.json', 'w') as f:\n",
194 |     "    json.dump(mapping, f, indent=2)"
195 | ]
196 | }
197 | ],
198 | "metadata": {
199 | "language_info": {
200 | "name": "python",
201 | "pygments_lexer": "ipython3"
202 | }
203 | },
204 | "nbformat": 4,
205 | "nbformat_minor": 1
206 | }
207 |
--------------------------------------------------------------------------------
/notebooks/word-mapping-workflow.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "# to be used as subworkflow\n",
20 | "import nlppln\n",
21 | "import ochre\n",
22 | "import os\n",
23 | "\n",
24 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n",
25 | " wf.load(steps_dir=ochre.cwl_path())\n",
26 | " print(wf.list_steps())\n",
27 | " \n",
28 | " wf.set_documentation('This workflow is meant to be used as a subworkflow.')\n",
29 | "\n",
30 | " gs_unnormalized = wf.add_input(gs_files='File[]')\n",
31 | " ocr_unnormalized = wf.add_input(ocr_files='File[]')\n",
32 | " language = wf.add_input(language='string')\n",
33 | " lowercase = wf.add_input(lowercase='boolean?')\n",
34 | " align_metadata = wf.add_input(align_m='string?')\n",
35 | " align_changes = wf.add_input(align_c='string?')\n",
36 | " word_mapping_name = wf.add_input(wm_name='string?')\n",
37 | "\n",
38 | " gs = wf.normalize_whitespace_punctuation(meta_in=gs_unnormalized, scatter=['meta_in'], scatter_method='dotproduct')\n",
39 | " ocr = wf.normalize_whitespace_punctuation(meta_in=ocr_unnormalized, scatter=['meta_in'], scatter_method='dotproduct')\n",
40 | "\n",
41 | " alignments, changes, metadata = wf.align_texts_wf(gs=gs, ocr=ocr, align_c=align_changes, align_m=align_metadata)\n",
42 | " \n",
43 | " gs_saf = wf.pattern(in_file=gs, language=language, scatter='in_file', scatter_method='dotproduct')\n",
44 | "\n",
45 | " mappings = wf.create_word_mappings(alignments=alignments, saf=gs_saf, lowercase=lowercase, \n",
46 | " scatter=['alignments', 'saf'], scatter_method='dotproduct')\n",
47 | " merged = wf.merge_csv(in_files=mappings, name=word_mapping_name)\n",
48 | "\n",
49 | " wf.add_outputs(wm_mapping=merged)\n",
50 | " wf.save(os.path.join(ochre.cwl_path(), 'word-mapping-wf.cwl'), wd=True)"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "# create word mappings for directory\n",
60 | "import nlppln\n",
61 | "import ochre\n",
62 | "import os\n",
63 | "\n",
64 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n",
65 | " wf.load(steps_dir=ochre.cwl_path())\n",
66 | " print(wf.list_steps())\n",
67 | " \n",
68 | " gs_dir = wf.add_input(gs_dir='Directory')\n",
69 | " ocr_dir = wf.add_input(ocr_dir='Directory')\n",
70 | " language = wf.add_input(language='string')\n",
71 | " lowercase = wf.add_input(lowercase='boolean?')\n",
72 | " align_metadata = wf.add_input(align_m='string?')\n",
73 | " align_changes = wf.add_input(align_c='string?')\n",
74 | " word_mapping_name = wf.add_input(wm_name='string?')\n",
75 | "\n",
76 | " gs_files = wf.ls(in_dir=gs_dir)\n",
77 | " ocr_files = wf.ls(in_dir=ocr_dir)\n",
78 | "\n",
79 | " wm_mapping = wf.word_mapping_wf(gs_files=gs_files, ocr_files=ocr_files, language=language, \n",
80 | " lowercase=lowercase, align_c=align_changes, \n",
81 | " align_m=align_metadata, wm_name=word_mapping_name)\n",
82 | "\n",
83 | " wf.add_outputs(wm_mapping=wm_mapping)\n",
84 | " wf.save(os.path.join(ochre.cwl_path(), 'word-mapping-dir.cwl'), pack=True)"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "# create word mappings for test data\n",
94 | "import nlppln\n",
95 | "import ochre\n",
96 | "import os\n",
97 | "\n",
98 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n",
99 | " wf.load(steps_dir=ochre.cwl_path())\n",
100 | " print(wf.list_steps())\n",
101 | "\n",
102 | " gs_dir = wf.add_input(gs_dir='Directory')\n",
103 | " ocr_dir = wf.add_input(ocr_dir='Directory')\n",
104 | " language = wf.add_input(language='string')\n",
105 | " data_div = wf.add_input(data_div='File')\n",
106 | " lowercase = wf.add_input(lowercase='boolean?')\n",
107 | " align_metadata = wf.add_input(align_m='string?')\n",
108 | " align_changes = wf.add_input(align_c='string?')\n",
109 | " word_mapping_name = wf.add_input(wm_name='string?')\n",
110 | "\n",
111 | " test_gs = wf.select_test_files(datadivision=data_div, in_dir=gs_dir)\n",
112 | " test_ocr = wf.select_test_files(datadivision=data_div, in_dir=ocr_dir)\n",
113 | "\n",
114 | " wm_mapping = wf.word_mapping_wf(gs_files=test_gs, ocr_files=test_ocr, language=language, \n",
115 | " lowercase=lowercase, align_c=align_changes, \n",
116 | " align_m=align_metadata, wm_name=word_mapping_name)\n",
117 | "\n",
118 | " wf.add_outputs(wm_mapping=wm_mapping)\n",
119 | " wf.save(os.path.join(ochre.cwl_path(), 'word-mapping-test-files-wf.cwl'), pack=True)"
120 | ]
121 | }
122 | ],
123 | "metadata": {
124 | "language_info": {
125 | "name": "python",
126 | "pygments_lexer": "ipython3"
127 | }
128 | },
129 | "nbformat": 4,
130 | "nbformat_minor": 2
131 | }
132 |
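The workflows saved above are meant to be executed with a CWL runner. A minimal sketch of driving the packed word-mapping-dir.cwl with cwltool via a JSON job order; the keys mirror the wf.add_input() calls above, cwltool is assumed to be installed, and all paths are hypothetical:

    import json
    import subprocess

    # Job order: input ids match the workflow inputs defined above
    # (paths are hypothetical placeholders).
    job = {
        'gs_dir': {'class': 'Directory', 'path': '/data/gs'},
        'ocr_dir': {'class': 'Directory', 'path': '/data/ocr'},
        'language': 'nl',
    }
    with open('word-mapping-dir-job.json', 'w') as f:
        json.dump(job, f)

    # Assumes cwltool is on the PATH.
    subprocess.check_call(['cwltool', 'word-mapping-dir.cwl',
                           'word-mapping-dir-job.json'])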
--------------------------------------------------------------------------------
/notebooks/improve-keras-datagen.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "%matplotlib inline\n",
20 | "\n",
21 | "import numpy as np\n",
22 | "import pandas as pd\n",
23 | "import matplotlib.pyplot as plt"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "# read the texts (using data division)\n",
33 | "# convert texts to input and output sequences (strings)\n",
34 | "# function for converting input and output sequences to numbers (in data generator)\n",
35 | "import json\n",
36 | "\n",
37 | "from ochre.utils import read_texts\n",
38 | "\n",
39 | "datasets = '/home/jvdzwaan/data/dncvu/datadivision-small.json'\n",
40 | "data_dir = '/home/jvdzwaan/data/dncvu/aligned/'\n",
41 | "\n",
42 | "with open(datasets) as d:\n",
43 | " division = json.load(d)\n",
44 | "print(division)"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "raw_val, gs_val, ocr_val = read_texts(division.get('val'), data_dir)\n",
54 | "print(gs_val)"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "# input: sequence of n characters from ocr\n",
64 | "# output: sequence of n characters from gs\n",
65 | "\n",
66 | "# we can filter the sequences later if we want\n",
67 | "\n",
68 | "n = 5\n",
69 | "num = 0\n",
70 | "\n",
71 | "def get_sequences(gs, ocr, length):\n",
72 | " gs_ngrams = zip(*[gs[i:] for i in range(length)])\n",
73 | " ocr_ngrams = zip(*[ocr[i:] for i in range(length)])\n",
74 | "\n",
75 | " return [''.join(n) for n in gs_ngrams], [''.join(n) for n in ocr_ngrams]\n",
76 | "\n",
77 | "gs_seqs, ocr_seqs = get_sequences(gs_val, ocr_val, length=n)\n",
78 | "print(len(gs_seqs), len(ocr_seqs))\n",
79 | "# each item in the *_seqs is a sample"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "gs_seqs = gs_seqs[:7]\n",
89 | "ocr_seqs = ocr_seqs[:7]"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "from sklearn.preprocessing import LabelEncoder\n",
99 | "\n",
100 | "le = LabelEncoder()\n",
101 | "le.fit([c for c in raw_val])\n",
102 | "for s in ocr_seqs:\n",
103 | " print(le.fit_transform(''.join(s)))"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": null,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "raw = ''.join(gs_seqs)\n",
113 | "print(raw)"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": null,
119 | "metadata": {},
120 | "outputs": [],
121 | "source": [
122 | "from ochre.utils import get_chars, get_int_to_char\n",
123 | "\n",
124 | "chars, num_chars, ci = get_chars(raw, raw, raw, False)\n",
125 | "ic = get_int_to_char(ci)\n",
126 | "\n",
127 | "#print(ci)\n",
128 | "#print(ic)\n",
129 | "\n",
130 | "oov_char = '@'\n",
131 | "\n",
132 | "for s in ocr_seqs:\n",
133 | "#for s in [['wwww']]:\n",
134 | "#for s in [['w8w']]:\n",
135 | " print(s)\n",
136 | " res = np.empty(n, dtype=np.int)\n",
137 | " for i in range(n):\n",
138 | " try:\n",
139 | " if s[i] != '':\n",
140 | " res[i] = ci.get(s[i], ci[oov_char])\n",
141 | " except IndexError:\n",
142 | " res[i] = ci['\\n']\n",
143 | " print(res)\n"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": null,
149 | "metadata": {},
150 | "outputs": [],
151 | "source": [
152 | "from ochre.datagen import DataGenerator\n",
153 | "\n",
154 | "dg = DataGenerator(xData=ocr_seqs, yData=gs_seqs, char_to_int=ci, seq_length=n, batch_size=3)"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {},
161 | "outputs": [],
162 | "source": [
163 | "print(len(dg))"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": null,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "inp, outp = dg[0]\n",
173 | "print(inp)"
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": null,
179 | "metadata": {},
180 | "outputs": [],
181 | "source": [
182 | "print(outp.shape)"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": null,
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "print(outp)"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": null,
197 | "metadata": {},
198 | "outputs": [],
199 | "source": [
200 | "print(ocr_seqs)"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": null,
206 | "metadata": {},
207 | "outputs": [],
208 | "source": [
209 | "dg._convert_sample('test')"
210 | ]
211 | }
212 | ],
213 | "metadata": {
214 | "language_info": {
215 | "name": "python",
216 | "pygments_lexer": "ipython3"
217 | }
218 | },
219 | "nbformat": 4,
220 | "nbformat_minor": 2
221 | }
222 |
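The get_sequences() helper above relies on the zip-of-shifted-slices idiom: zip(*[s[i:] for i in range(length)]) yields every length-n window of s, truncated at the shortest slice. A tiny worked example on hypothetical strings:

    def get_sequences(gs, ocr, length):
        # zip of shifted slices: one tuple per character window
        gs_ngrams = zip(*[gs[i:] for i in range(length)])
        ocr_ngrams = zip(*[ocr[i:] for i in range(length)])
        return [''.join(g) for g in gs_ngrams], [''.join(o) for o in ocr_ngrams]

    gs_seqs, ocr_seqs = get_sequences('example', 'exannpl', length=5)
    print(gs_seqs)   # ['examp', 'xampl', 'ample']
    print(ocr_seqs)  # ['exann', 'xannp', 'annpl']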
--------------------------------------------------------------------------------
/notebooks/kb_xml_to_text_(unaligned).ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "%matplotlib inline\n",
20 | "\n",
21 | "import numpy as np\n",
22 | "import pandas as pd\n",
23 | "import matplotlib.pyplot as plt\n",
24 | "\n",
25 | "from tqdm import tqdm_notebook as tqdm"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "import os\n",
35 | "\n",
36 | "from lxml import etree\n",
37 | "\n",
38 | "from nlppln.utils import get_files\n",
39 | "from ochre.matchlines import gt_fname2ocr_fname\n",
40 | "\n",
41 | "gs_dir = '/home/jvdzwaan/ownCloud/Shared/OCR/Ground-truth/'\n",
42 | "ocr_dir = '/home/jvdzwaan/ownCloud/Shared/OCR/Originele ALTOs/'\n",
43 | "#ocr_dir = '/home/jvdzwaan/ownCloud/Shared/OCR/Opnieuw geOCRd/'\n",
44 | "\n",
45 | "gs_files = get_files(gs_dir)\n",
46 | "# remove file with \"extra\" in the name, this one is the same as the file without \"extra\" in the name\n",
47 | "gs_files = [f for f in gs_files if not 'extra' in f]\n",
48 | "\n",
49 | "ocr_files = []\n",
50 | "for gs_file in gs_files:\n",
51 | " ocr_bn = gt_fname2ocr_fname(gs_file)\n",
52 | " # the 'opnieuw' alto files have a different file name\n",
53 | " #ocr_bn = ocr_bn.replace('alto.xml', 'altoFR11.xml')\n",
54 | " ocr_file = os.path.join(ocr_dir, ocr_bn)\n",
55 | " if os.path.isfile(ocr_file):\n",
56 | " ocr_files.append(ocr_file)\n",
57 | " else:\n",
58 | " print('File not found:', ocr_file)\n",
59 | " print('GS file:', gs_file)\n",
60 | "print(len(gs_files), len(ocr_files))"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "%%time\n",
70 | "\n",
71 | "from nlppln.utils import create_dirs\n",
72 | "\n",
73 | "from ochre.utils import get_temp_file\n",
74 | "from ochre.matchlines import get_ns, replace_entities\n",
75 | "\n",
76 | "def get_lines(fname, alto_ns):\n",
77 | " lines = []\n",
78 | " context = etree.iterparse(fname, events=('end', ), tag=(alto_ns+'TextLine'))\n",
79 | " for event, elem in context:\n",
80 | " words = []\n",
81 | " for a in elem.getchildren():\n",
82 | " if a.tag == alto_ns+'String':\n",
83 | " if a.attrib.get('SUBS_TYPE') == 'HypPart1':\n",
84 | " words.append(a.attrib['SUBS_CONTENT'])\n",
85 | " elif a.attrib.get('SUBS_TYPE') != 'HypPart2':\n",
86 | " words.append(a.attrib['CONTENT'])\n",
87 | " \n",
88 | " lines.append(' '.join(words))\n",
89 | " \n",
90 | " # make iteration over context fast and consume less memory\n",
91 | " #https://www.ibm.com/developerworks/xml/library/x-hiperfparse\n",
92 | " elem.clear()\n",
93 | " while elem.getprevious() is not None:\n",
94 | " del elem.getparent()[0]\n",
95 | " \n",
96 | " return lines\n",
97 | "\n",
98 | "def doc_id(fname):\n",
99 | " bn = os.path.basename(fname)\n",
100 | " n = bn.rsplit('_', 1)[0]\n",
101 | " return n\n",
102 | "\n",
103 | "\n",
104 | "out_dir = '/home/jvdzwaan/data/kb-ocr/text-not-aligned/'\n",
105 | "\n",
106 | "create_dirs(out_dir)\n",
107 | "\n",
108 | "gs_dir = os.path.join(out_dir, 'gs')\n",
109 | "create_dirs(gs_dir)\n",
110 | "\n",
111 | "ocr_dir = os.path.join(out_dir, 'ocr')\n",
112 | "create_dirs(ocr_dir)\n",
113 | "\n",
114 | "for gs_file, ocr_file in tqdm(zip(gs_files, ocr_files), total=len(gs_files)):\n",
115 | " try:\n",
116 | " gs_tmp = get_temp_file()\n",
117 | " #print(gs_tmp)\n",
118 | " with open(gs_tmp, 'w') as f:\n",
119 | " f.write(replace_entities(gs_file))\n",
120 | " \n",
121 | " #ocr_tmp = get_temp_file()\n",
122 | " #print(gs_tmp)\n",
123 | " #with open(ocr_tmp, 'w') as f:\n",
124 | " # f.write(replace_entities(ocr_file))\n",
125 | " \n",
126 | " gs_lines = get_lines(gs_tmp, get_ns(gs_file))\n",
127 | " ocr_lines = get_lines(ocr_file, get_ns(ocr_file))\n",
128 | " #print(len(gs_lines), len(ocr_lines))\n",
129 | " \n",
130 | " os.remove(gs_tmp)\n",
131 | " #os.remove(ocr_tmp)\n",
132 | " \n",
133 | " #print(doc_id(gs_file))\n",
134 | " #print(doc_id(ocr_file))\n",
135 | " assert doc_id(gs_file) == doc_id(ocr_file)\n",
136 | " gs_out = os.path.join(gs_dir, '{}.txt'.format(doc_id(gs_file)))\n",
137 | " ocr_out = os.path.join(ocr_dir, '{}.txt'.format(doc_id(ocr_file)))\n",
138 | " #print(gs_out)\n",
139 | " #print(ocr_out)\n",
140 | " \n",
141 | " with open(gs_out, 'w') as f:\n",
142 | " f.write(' '.join(gs_lines))\n",
143 | " with open(ocr_out, 'w') as f:\n",
144 | " f.write(' '.join(ocr_lines))\n",
145 | " except etree.XMLSyntaxError as e:\n",
146 | " print(gs_file)\n",
147 | " print(e)\n",
148 | " print()"
149 | ]
150 | }
151 | ],
152 | "metadata": {
153 | "language_info": {
154 | "name": "python",
155 | "pygments_lexer": "ipython3"
156 | }
157 | },
158 | "nbformat": 4,
159 | "nbformat_minor": 2
160 | }
161 |
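For readers unfamiliar with ALTO, the SUBS_TYPE handling in get_lines() above reconstructs hyphenated words: the first half (HypPart1) is replaced by the full word from SUBS_CONTENT and the second half (HypPart2) is dropped. A self-contained illustration on a hand-written ALTO fragment (the namespace is the ALTO v2 one; the fragment itself is a simplified, hypothetical example):

    from lxml import etree

    alto = b'''<alto xmlns="http://www.loc.gov/standards/alto/ns-v2#">
      <TextLine>
        <String CONTENT="exam-" SUBS_TYPE="HypPart1" SUBS_CONTENT="example"/>
      </TextLine>
      <TextLine>
        <String CONTENT="ple" SUBS_TYPE="HypPart2" SUBS_CONTENT="example"/>
        <String CONTENT="text"/>
      </TextLine>
    </alto>'''

    ns = '{http://www.loc.gov/standards/alto/ns-v2#}'
    for line in etree.fromstring(alto).iter(ns + 'TextLine'):
        words = []
        for s in line.iter(ns + 'String'):
            if s.attrib.get('SUBS_TYPE') == 'HypPart1':
                words.append(s.attrib['SUBS_CONTENT'])  # full dehyphenated word
            elif s.attrib.get('SUBS_TYPE') != 'HypPart2':
                words.append(s.attrib['CONTENT'])
        print(' '.join(words))
    # prints "example" then "text"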
--------------------------------------------------------------------------------
/notebooks/try-lstm.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {},
16 | "source": [
17 | "## Links/tutorials\n",
18 | "\n",
19 | "* http://karpathy.github.io/2015/05/21/rnn-effectiveness/\n",
20 | "* http://machinelearningmastery.com/text-generation-lstm-recurrent-neural-networks-python-keras/\n",
21 | "* http://machinelearningmastery.com/sequence-classification-lstm-recurrent-neural-networks-python-keras/"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "import numpy\n",
31 | "from keras.models import Sequential\n",
32 | "from keras.layers import Dense\n",
33 | "from keras.layers import Dropout\n",
34 | "from keras.layers import LSTM\n",
35 | "from keras.callbacks import ModelCheckpoint\n",
36 | "from keras.utils import np_utils"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "# load ascii text and covert to lowercase\n",
46 | "filename = \"/home/jvdzwaan/data/alice.txt\"\n",
47 | "raw_text = open(filename).read()\n",
48 | "raw_text = raw_text.lower()\n",
49 | "print raw_text[:100]"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "# create mapping of unique chars to integers\n",
59 | "chars = sorted(list(set(raw_text)))\n",
60 | "char_to_int = dict((c, i) for i, c in enumerate(chars))"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "print chars"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "print char_to_int"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "n_chars = len(raw_text)\n",
88 | "n_vocab = len(chars)\n",
89 | "print \"Total Characters: \", n_chars\n",
90 | "print \"Total Vocab: \", n_vocab"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": null,
96 | "metadata": {},
97 | "outputs": [],
98 | "source": [
99 | "# prepare the dataset of input to output pairs encoded as integers\n",
100 | "seq_length = 100\n",
101 | "dataX = []\n",
102 | "dataY = []\n",
103 | "for i in range(0, n_chars - seq_length, 1):\n",
104 | " seq_in = raw_text[i:i + seq_length]\n",
105 | " seq_out = raw_text[i + seq_length]\n",
106 | " dataX.append([char_to_int[char] for char in seq_in])\n",
107 | " dataY.append(char_to_int[seq_out])\n",
108 | "n_patterns = len(dataX)\n",
109 | "print \"Total Patterns: \", n_patterns"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "print seq_in\n",
119 | "print '---'\n",
120 | "print repr(seq_out)\n",
121 | "print len(seq_in)"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "for i in range(0, n_chars - seq_length, 1):\n",
131 | " print i"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "print dataX[0]\n",
141 | "print dataY[0]"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "X = numpy.reshape(dataX, (n_patterns, seq_length, 1))\n",
151 | "# normalize\n",
152 | "X = X / float(n_vocab)\n",
153 | "# one hot encode the output variable\n",
154 | "y = np_utils.to_categorical(dataY)"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {},
161 | "outputs": [],
162 | "source": [
163 | "print X"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": null,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "print y"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": null,
178 | "metadata": {},
179 | "outputs": [],
180 | "source": [
181 | "# define the LSTM model\n",
182 | "model = Sequential()\n",
183 | "model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2])))\n",
184 | "model.add(Dropout(0.2))\n",
185 | "model.add(Dense(y.shape[1], activation='softmax'))\n",
186 | "model.compile(loss='categorical_crossentropy', optimizer='adam')"
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": null,
192 | "metadata": {},
193 | "outputs": [],
194 | "source": [
195 | "# define the checkpoint\n",
196 | "filepath=\"/home/jvdzwaan/data/tmp/lstm/weights-improvement-{epoch:02d}-{loss:.4f}.hdf5\"\n",
197 | "checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')\n",
198 | "callbacks_list = [checkpoint]"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": null,
204 | "metadata": {},
205 | "outputs": [],
206 | "source": [
207 | "model.fit(X, y, epochs=20, batch_size=128, callbacks=callbacks_list)"
208 | ]
209 | }
210 | ],
211 | "metadata": {
212 | "language_info": {
213 | "name": "python",
214 | "pygments_lexer": "ipython3"
215 | }
216 | },
217 | "nbformat": 4,
218 | "nbformat_minor": 2
219 | }
220 |
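As a sanity check on the data preparation above: each integer-encoded window becomes one sample of shape (seq_length, 1), scaled into [0, 1), and the targets become one-hot vectors over the vocabulary. A tiny sketch with toy values:

    import numpy as np

    seq_length, n_vocab = 3, 5
    dataX = [[0, 1, 2], [1, 2, 3]]  # two integer-encoded input windows
    dataY = [3, 4]                  # next-character targets

    X = np.reshape(dataX, (len(dataX), seq_length, 1)) / float(n_vocab)
    y = np.eye(n_vocab)[dataY]      # one-hot, like np_utils.to_categorical
    print(X.shape, y.shape)         # (2, 3, 1) (2, 5)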
--------------------------------------------------------------------------------
/notebooks/ICDAR2017_shared_task_workflows.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import nlppln\n",
20 | "\n",
21 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n",
22 | " wf.load(steps_dir='../ochre/cwl/')\n",
23 | " print wf.list_steps()\n",
24 | "\n",
25 | " in_dir = wf.add_input(in_dir='Directory')\n",
26 | " ocr_dir_name = wf.add_input(ocr_dir_name='string')\n",
27 | " gs_dir_name = wf.add_input(gs_dir_name='string')\n",
28 | " aligned_dir_name = wf.add_input(aligned_dir_name='string')\n",
29 | "\n",
30 | " files = wf.ls(in_dir=in_dir)\n",
31 | " aligned, gs, ocr = wf.icdar2017st_extract_text(in_file=files, scatter=['in_file'], scatter_method='dotproduct')\n",
32 | " gs_dir = wf.save_files_to_dir(dir_name=gs_dir_name, in_files=gs)\n",
33 | " ocr_dir = wf.save_files_to_dir(dir_name=ocr_dir_name, in_files=ocr)\n",
34 | " aligned_dir = wf.save_files_to_dir(dir_name=aligned_dir_name, in_files=aligned)\n",
35 | "\n",
36 | " wf.add_outputs(gs_dir=gs_dir)\n",
37 | " wf.add_outputs(ocr_dir=ocr_dir)\n",
38 | " wf.add_outputs(aligned_dir=aligned_dir)\n",
39 | "\n",
40 | " wf.save('../ochre/cwl/icdar2017st-extract-data.cwl', wd=True)"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "import nlppln\n",
50 | "\n",
51 | "with nlppln.WorkflowGenerator(working_dir='/home/jvdzwaan/cwl-working-dir/') as wf:\n",
52 | " wf.load(steps_dir='../ochre/cwl/')\n",
53 | " print wf.list_steps()\n",
54 | "\n",
55 | " in_dir1 = wf.add_input(in_dir1='Directory')\n",
56 | " in_dir2 = wf.add_input(in_dir2='Directory')\n",
57 | " in_dir3 = wf.add_input(in_dir3='Directory')\n",
58 | " in_dir4 = wf.add_input(in_dir4='Directory')\n",
59 | " ocr_dir_name = wf.add_input(ocr_dir_name='string', default='ocr')\n",
60 | " gs_dir_name = wf.add_input(gs_dir_name='string', default='gs')\n",
61 | " aligned_dir_name = wf.add_input(aligned_dir_name='string', default='aligned')\n",
62 | "\n",
63 | " aligned_dir1, gs_dir1, ocr_dir1 = wf.icdar2017st_extract_data(aligned_dir_name=aligned_dir_name, \n",
64 | " gs_dir_name=gs_dir_name, \n",
65 | " ocr_dir_name=ocr_dir_name,\n",
66 | " in_dir=in_dir1)\n",
67 | " gs1 = wf.save_dir_to_subdir(inner_dir=gs_dir1, outer_dir=in_dir1)\n",
68 | " ocr1 = wf.save_dir_to_subdir(inner_dir=ocr_dir1, outer_dir=in_dir1)\n",
69 | " aligned1 = wf.save_dir_to_subdir(inner_dir=aligned_dir1, outer_dir=in_dir1)\n",
70 | "\n",
71 | " aligned_dir2, gs_dir2, ocr_dir2 = wf.icdar2017st_extract_data(aligned_dir_name=aligned_dir_name, \n",
72 | " gs_dir_name=gs_dir_name, \n",
73 | " ocr_dir_name=ocr_dir_name,\n",
74 | " in_dir=in_dir2)\n",
75 | " gs2 = wf.save_dir_to_subdir(inner_dir=gs_dir2, outer_dir=in_dir2)\n",
76 | " ocr2 = wf.save_dir_to_subdir(inner_dir=ocr_dir2, outer_dir=in_dir2)\n",
77 | " aligned2 = wf.save_dir_to_subdir(inner_dir=aligned_dir2, outer_dir=in_dir2)\n",
78 | "\n",
79 | " aligned_dir3, gs_dir3, ocr_dir3 = wf.icdar2017st_extract_data(aligned_dir_name=aligned_dir_name, \n",
80 | " gs_dir_name=gs_dir_name, \n",
81 | " ocr_dir_name=ocr_dir_name,\n",
82 | " in_dir=in_dir3)\n",
83 | " gs3 = wf.save_dir_to_subdir(inner_dir=gs_dir3, outer_dir=in_dir3)\n",
84 | " ocr3 = wf.save_dir_to_subdir(inner_dir=ocr_dir3, outer_dir=in_dir3)\n",
85 | " aligned3 = wf.save_dir_to_subdir(inner_dir=aligned_dir3, outer_dir=in_dir3)\n",
86 | "\n",
87 | " aligned_dir4, gs_dir4, ocr_dir4 = wf.icdar2017st_extract_data(aligned_dir_name=aligned_dir_name, \n",
88 | " gs_dir_name=gs_dir_name, \n",
89 | " ocr_dir_name=ocr_dir_name,\n",
90 | " in_dir=in_dir4)\n",
91 | " gs4 = wf.save_dir_to_subdir(inner_dir=gs_dir4, outer_dir=in_dir4)\n",
92 | " ocr4 = wf.save_dir_to_subdir(inner_dir=ocr_dir4, outer_dir=in_dir4)\n",
93 | " aligned4 = wf.save_dir_to_subdir(inner_dir=aligned_dir4, outer_dir=in_dir4)\n",
94 | "\n",
95 | "\n",
96 | " wf.add_outputs(gs1=gs1)\n",
97 | " wf.add_outputs(gs2=gs2)\n",
98 | " wf.add_outputs(gs3=gs3)\n",
99 | " wf.add_outputs(gs4=gs4)\n",
100 | "\n",
101 | " wf.add_outputs(ocr1=ocr1)\n",
102 | " wf.add_outputs(ocr2=ocr2)\n",
103 | " wf.add_outputs(ocr3=ocr3)\n",
104 | " wf.add_outputs(ocr4=ocr4)\n",
105 | "\n",
106 | " wf.add_outputs(aligned1=aligned4)\n",
107 | " wf.add_outputs(aligned2=aligned3)\n",
108 | " wf.add_outputs(aligned3=aligned2)\n",
109 | " wf.add_outputs(aligned4=aligned1)\n",
110 | "\n",
111 | " wf.save('../ochre/cwl/icdar2017st-extract-data-all.cwl', pack=True)"
112 | ]
113 | }
114 | ],
115 | "metadata": {
116 | "language_info": {
117 | "name": "python",
118 | "pygments_lexer": "ipython3"
119 | }
120 | },
121 | "nbformat": 4,
122 | "nbformat_minor": 2
123 | }
124 |
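The scatter_method='dotproduct' used throughout these workflows pairs the scattered inputs elementwise, so the scattered lists must have equal length. Roughly the behaviour of Python's zip, as a sketch with hypothetical file lists:

    # dotproduct scatter runs the step once per aligned pair of inputs
    gs_files = ['a.gs.txt', 'b.gs.txt']
    ocr_files = ['a.ocr.txt', 'b.ocr.txt']
    for gs, ocr in zip(gs_files, ocr_files):
        print(gs, ocr)  # one step invocation per (gs, ocr) pair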
--------------------------------------------------------------------------------
/notebooks/fuzzy-string-matching.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "text = \"Many force beat offer third attention. Relationship whether question ask shoulder. Professional management method soldier itself enough single value opportunity represent short defense. Last human want tree sign people form old moment myself understand support station ahead water red relationship million indeed away music. Benefit range have seem day no thing single tough explain attention far local that above one design authority nothing. Trade stage seek former white their we success president item wind yeah full soldier suggest role never. Down condition north tell than item program at. Dark my box culture treat court tell some according quickly often issue quickly quickly office at wide want. Hope start any. Government value relate cover country network according. Travel. On but property production oil produce material gun page why face serve nation expect. Concern this apply travel skill allow same friend night country young once group style new turn situation gun foreign voice mother enough. Turn commercial manager news put determine pass address start particular. Light begin style hand almost popular thousand or least. Dog walk kitchen star skin information task land continue always spend value speak quickly energy improve leg as mind sing stage purpose few. Reason task phone conference. Production move rule marriage attorney money team open pretty card fall. Small player how since list adult idea apply hold both feel house team sign wall put authority free anyone happy authority throw body art pick. Skill entire major close series ground thus letter recent example I hold base policy include rate. Man think best position his lead career stage contain around eight mean they character specific memory religious heart shake see. Just during yet section thank rise force director water last ground go keep area different. Quickly several character west wrong model attention so truth resource her test gas. Most civil soldier consider front scene never military theory not free economic capital admit yeah fall thought there specific trial low against little tree. Center late.\"\n",
20 | "print(text)"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "from fuzzyset import FuzzySet\n",
30 | "\n",
31 | "fs = FuzzySet()\n",
32 | "fs.add(text)"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "fs.get(\"Production move rule marriage attorney money team open pretty card fall.\")"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "fs = FuzzySet()\n",
51 | "fs.add(\"Production move rule marriage attorney money team open pretty card fall.\")\n",
52 | "fs.get(text)"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "from faker import Faker\n",
62 | "\n",
63 | "fake = Faker()\n",
64 | "num_words = 10\n",
65 | "sentences = []\n",
66 | "\n",
67 | "for i in range(10000):\n",
68 | " sentence = fake.sentence(nb_words=num_words, variable_nb_words=True, ext_word_list=None)\n",
69 | " sentences.append(sentence)"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "fs = FuzzySet()\n",
79 | "for s in sentences:\n",
80 | " fs.add(s)"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "text = ' '.join(sentences[5:10])\n",
90 | "print(text)"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": null,
96 | "metadata": {},
97 | "outputs": [],
98 | "source": [
99 | "fs.get(text)"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": null,
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "from fuzzywuzzy import fuzz\n",
109 | "\n",
110 | "part = \"Production move rule marriage attorney money team open pretty card fall.\"\n",
111 | "\n",
112 | "fuzz.partial_ratio(part, text)"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "fuzz.partial_ratio(\"character they\", text)"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "part = \"Production move rule marriage attorncy moncy team opcn pretty card fall.\"\n",
131 | "\n",
132 | "fuzz.partial_ratio(part, text)"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": null,
138 | "metadata": {},
139 | "outputs": [],
140 | "source": [
141 | "from faker import Faker\n",
142 | "\n",
143 | "fake = Faker()\n",
144 | "num_words = 10\n",
145 | "sentences = []\n",
146 | "\n",
147 | "for i in range(10000):\n",
148 | " sentence = fake.sentence(nb_words=num_words, variable_nb_words=True, ext_word_list=None)\n",
149 | " sentences.append(sentence)"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": null,
155 | "metadata": {},
156 | "outputs": [],
157 | "source": [
158 | "sentences[1]"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": null,
164 | "metadata": {},
165 | "outputs": [],
166 | "source": [
167 | "from tqdm import tqdm_notebook as tqdm"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "metadata": {},
174 | "outputs": [],
175 | "source": [
176 | "%%time\n",
177 | "\n",
178 | "for s in tqdm(sentences):\n",
179 | " fuzz.partial_ratio(s, text)"
180 | ]
181 | }
182 | ],
183 | "metadata": {
184 | "language_info": {
185 | "name": "python",
186 | "pygments_lexer": "ipython3"
187 | }
188 | },
189 | "nbformat": 4,
190 | "nbformat_minor": 2
191 | }
192 |
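The timing cell above leans on fuzz.partial_ratio, which scores the best-matching substring of the longer string rather than comparing the strings whole, so a short OCR'ed sentence can still match inside a long ground-truth text. An illustration:

    from fuzzywuzzy import fuzz

    text = 'pretty card fall'
    print(fuzz.ratio('card fall', text))          # penalized by the length difference
    print(fuzz.partial_ratio('card fall', text))  # 100: exact substring match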
--------------------------------------------------------------------------------
/ochre/train_mt.py:
--------------------------------------------------------------------------------
1 | import click
2 | import os
3 | import json
4 | import codecs
5 |
6 | from keras.models import Model
7 | from keras.layers import Input, LSTM, Dense
8 | import numpy as np
9 |
10 | from ochre.utils import add_checkpoint, load_weights, get_files
11 |
12 |
13 | def read_texts(data_dir, div, name):
14 | in_files = get_files(data_dir, div, name)
15 |
16 | # Vectorize the data.
17 | input_texts = []
18 | target_texts = []
19 |
20 | for in_file in in_files:
21 | with codecs.open(in_file, 'r', encoding='utf-8') as f:
22 | lines = f.readlines()
23 | for line in lines:
24 | input_text, target_text = line.split('||@@||')
25 | # We use "tab" as the "start sequence" character
26 | # for the targets, and "\n" as "end sequence" character.
27 | target_text = '\t' + target_text
28 | input_texts.append(input_text)
29 | target_texts.append(target_text)
30 |
31 | return input_texts, target_texts
32 |
33 |
34 | def convert(input_texts, target_texts, input_characters, target_characters, max_encoder_seq_length, num_encoder_tokens, max_decoder_seq_length, num_decoder_tokens):
35 | input_token_index = dict(
36 | [(char, i) for i, char in enumerate(input_characters)])
37 | target_token_index = dict(
38 | [(char, i) for i, char in enumerate(target_characters)])
39 |
40 | encoder_input_data = np.zeros(
41 | (len(input_texts), max_encoder_seq_length, num_encoder_tokens),
42 | dtype='float32')
43 | decoder_input_data = np.zeros(
44 | (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
45 | dtype='float32')
46 | decoder_target_data = np.zeros(
47 | (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
48 | dtype='float32')
49 |
50 | for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
51 | for t, char in enumerate(input_text):
52 | try:
53 | encoder_input_data[i, t, input_token_index[char]] = 1.
54 | except IndexError:
55 | # sequence longer than max length of training inputs
56 | pass
57 | for t, char in enumerate(target_text):
58 | # decoder_input_data covers the full target sequence, including the start character
59 | try:
60 | decoder_input_data[i, t, target_token_index[char]] = 1.
61 | except IndexError:
62 | # sequence longer than max length of training targets
63 | pass
64 | if t > 0:
65 | # decoder_target_data will be ahead by one timestep
66 | # and will not include the start character.
67 | try:
68 | decoder_target_data[i, t - 1, target_token_index[char]] = 1.
69 | except IndexError:
70 | # sequence longer than max length of training targets
71 | pass
72 |
73 | return encoder_input_data, decoder_input_data, decoder_target_data
74 |
75 |
76 | @click.command()
77 | @click.argument('datasets', type=click.File())
78 | @click.argument('data_dir', type=click.Path(exists=True))
79 | @click.option('--weights_dir', '-w', default=os.getcwd(), type=click.Path())
80 | def train_lstm(datasets, data_dir, weights_dir):
81 | batch_size = 64 # Batch size for training.
82 | epochs = 100 # Number of epochs to train for.
83 | latent_dim = 256 # Latent dimensionality of the encoding space.
84 |
85 | div = json.load(datasets)
86 |
87 | train_input, train_target = read_texts(data_dir, div, 'train')
88 | val_input, val_target = read_texts(data_dir, div, 'val')
89 | #test_input, test_target = read_texts(data_dir, div, 'test')
90 |
91 | input_characters = sorted(list(set(u''.join(train_input) + u''.join(val_input))))
92 | target_characters = sorted(list(set(u''.join(train_target) + u''.join(val_target))))
93 | num_encoder_tokens = len(input_characters)
94 | num_decoder_tokens = len(target_characters)
95 | max_encoder_seq_length = max([len(txt) for txt in train_input])
96 | max_decoder_seq_length = max([len(txt) for txt in train_target])
97 |
98 | print('Number of samples:', len(train_input))
99 | print('Number of unique input tokens:', num_encoder_tokens)
100 | print('Number of unique output tokens:', num_decoder_tokens)
101 | print('Max sequence length for inputs:', max_encoder_seq_length)
102 | print('Max sequence length for outputs:', max_decoder_seq_length)
103 | print('Input characters:', u''.join(input_characters))
104 | print('Output characters:', u''.join(target_characters))
105 |
106 | train_enc_input, train_dec_input, train_dec_target = convert(train_input,
107 | train_target, input_characters, target_characters,
108 | max_encoder_seq_length, num_encoder_tokens, max_decoder_seq_length,
109 | num_decoder_tokens)
110 |
111 | val_enc_input, val_dec_input, val_dec_target = convert(val_input,
112 | val_target, input_characters, target_characters,
113 | max_encoder_seq_length, num_encoder_tokens, max_decoder_seq_length,
114 | num_decoder_tokens)
115 |
116 | # Define an input sequence and process it.
117 | encoder_inputs = Input(shape=(None, num_encoder_tokens))
118 | encoder = LSTM(latent_dim, return_state=True)
119 | encoder_outputs, state_h, state_c = encoder(encoder_inputs)
120 | # We discard `encoder_outputs` and only keep the states.
121 | encoder_states = [state_h, state_c]
122 |
123 | # Set up the decoder, using `encoder_states` as initial state.
124 | decoder_inputs = Input(shape=(None, num_decoder_tokens))
125 | # We set up our decoder to return full output sequences,
126 | # and to return internal states as well. We don't use the
127 | # return states in the training model, but we will use them in inference.
128 | decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
129 | decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
130 | initial_state=encoder_states)
131 | decoder_dense = Dense(num_decoder_tokens, activation='softmax')
132 | decoder_outputs = decoder_dense(decoder_outputs)
133 |
134 | # Define the model that will turn
135 | # `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
136 | model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
137 |
138 | # Run training
139 | model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
140 | epoch, model = load_weights(model, weights_dir, optimizer='rmsprop', loss='categorical_crossentropy')
141 | callbacks_list = [add_checkpoint(weights_dir)]
142 | model.fit([train_enc_input, train_dec_input], train_dec_target,
143 | batch_size=batch_size,
144 | epochs=epochs,
145 | validation_data=([val_enc_input, val_dec_input], val_dec_target),
146 | callbacks=callbacks_list,
147 | initial_epoch=epoch)
148 |
149 |
150 | if __name__ == '__main__':
151 | train_lstm()
152 |
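train_mt.py only builds and fits the training model; the comment about reusing the decoder states refers to inference. A sketch of the standard Keras seq2seq inference setup, reusing the encoder_inputs, encoder_states, decoder_inputs, decoder_lstm, decoder_dense and latent_dim names from train_lstm() above (so it is not self-contained, and the actual decoding loop is omitted):

    from keras.models import Model
    from keras.layers import Input

    # Encoder: map an input sequence to the initial decoder states.
    encoder_model = Model(encoder_inputs, encoder_states)

    # Decoder: run one step at a time, feeding the states back in.
    decoder_state_input_h = Input(shape=(latent_dim,))
    decoder_state_input_c = Input(shape=(latent_dim,))
    decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
    decoder_outputs, state_h, state_c = decoder_lstm(
        decoder_inputs, initial_state=decoder_states_inputs)
    decoder_outputs = decoder_dense(decoder_outputs)
    decoder_model = Model(
        [decoder_inputs] + decoder_states_inputs,
        [decoder_outputs, state_h, state_c])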
--------------------------------------------------------------------------------