├── .circleci └── config.yml ├── .gitattributes ├── .gitignore ├── LICENSE.txt ├── Makefile ├── data └── raw │ ├── emnist │ ├── metadata.toml │ └── readme.md │ ├── fsdl_handwriting │ ├── fsdl_handwriting.json │ ├── metadata.toml │ └── readme.md │ └── iam │ ├── metadata.toml │ └── readme.md ├── environment.yml ├── lab1 ├── readme.md ├── text_recognizer │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── base_data_module.py │ │ ├── mnist.py │ │ └── util.py │ ├── lit_models │ │ ├── __init__.py │ │ └── base.py │ ├── models │ │ ├── __init__.py │ │ └── mlp.py │ └── util.py └── training │ ├── __init__.py │ └── run_experiment.py ├── lab2 ├── notebooks │ ├── 01-look-at-emnist.ipynb │ └── 02-look-at-emnist-lines.ipynb ├── readme.md ├── resblock.png ├── text_recognizer │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── base_data_module.py │ │ ├── emnist.py │ │ ├── emnist_essentials.json │ │ ├── emnist_lines.py │ │ ├── mnist.py │ │ ├── sentence_generator.py │ │ └── util.py │ ├── lit_models │ │ ├── __init__.py │ │ └── base.py │ ├── models │ │ ├── __init__.py │ │ ├── cnn.py │ │ └── mlp.py │ └── util.py └── training │ ├── __init__.py │ └── run_experiment.py ├── lab3 ├── notebooks │ ├── 01-look-at-emnist.ipynb │ └── 02-look-at-emnist-lines.ipynb ├── readme.md ├── text_recognizer │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── base_data_module.py │ │ ├── emnist.py │ │ ├── emnist_essentials.json │ │ ├── emnist_lines.py │ │ ├── mnist.py │ │ ├── sentence_generator.py │ │ └── util.py │ ├── lit_models │ │ ├── __init__.py │ │ ├── base.py │ │ ├── ctc.py │ │ ├── metrics.py │ │ └── util.py │ ├── models │ │ ├── __init__.py │ │ ├── cnn.py │ │ ├── line_cnn.py │ │ ├── line_cnn_lstm.py │ │ ├── line_cnn_simple.py │ │ └── mlp.py │ └── util.py └── training │ ├── __init__.py │ └── run_experiment.py ├── lab4 ├── notebooks │ ├── 01-look-at-emnist.ipynb │ └── 02-look-at-emnist-lines.ipynb ├── readme.md ├── text_recognizer │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── base_data_module.py │ │ ├── emnist.py │ │ ├── emnist_essentials.json │ │ ├── emnist_lines.py │ │ ├── mnist.py │ │ ├── sentence_generator.py │ │ └── util.py │ ├── lit_models │ │ ├── __init__.py │ │ ├── base.py │ │ ├── ctc.py │ │ ├── metrics.py │ │ ├── transformer.py │ │ └── util.py │ ├── models │ │ ├── __init__.py │ │ ├── cnn.py │ │ ├── line_cnn.py │ │ ├── line_cnn_lstm.py │ │ ├── line_cnn_simple.py │ │ ├── line_cnn_transformer.py │ │ ├── mlp.py │ │ └── transformer_util.py │ └── util.py └── training │ ├── __init__.py │ └── run_experiment.py ├── lab5 ├── notebooks │ ├── 01-look-at-emnist.ipynb │ ├── 02-look-at-emnist-lines.ipynb │ ├── 02b-look-at-emnist-lines2.ipynb │ └── 03-look-at-iam-lines.ipynb ├── readme.md ├── text_recognizer │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── base_data_module.py │ │ ├── emnist.py │ │ ├── emnist_essentials.json │ │ ├── emnist_lines.py │ │ ├── emnist_lines2.py │ │ ├── iam.py │ │ ├── iam_lines.py │ │ ├── mnist.py │ │ ├── sentence_generator.py │ │ └── util.py │ ├── lit_models │ │ ├── __init__.py │ │ ├── base.py │ │ ├── ctc.py │ │ ├── metrics.py │ │ ├── transformer.py │ │ └── util.py │ ├── models │ │ ├── __init__.py │ │ ├── cnn.py │ │ ├── line_cnn.py │ │ ├── line_cnn_lstm.py │ │ ├── line_cnn_simple.py │ │ ├── line_cnn_transformer.py │ │ ├── mlp.py │ │ └── transformer_util.py │ └── util.py └── training │ ├── __init__.py │ ├── run_experiment.py │ └── sweeps │ └── emnist_lines2_line_cnn_transformer.yml ├── lab6 ├── annotation_interface.png └── readme.md ├── lab7 ├── notebooks │ ├── 
01-look-at-emnist.ipynb │ ├── 02-look-at-emnist-lines.ipynb │ ├── 02b-look-at-emnist-lines2.ipynb │ ├── 03-look-at-iam-lines.ipynb │ ├── 04-look-at-iam-paragraphs.ipynb │ └── 04b-look-at-paragraph-predictions.ipynb ├── readme.md ├── text_recognizer │ ├── __init__.py │ ├── artifacts │ │ └── paragraph_text_recognizer │ │ │ ├── config.json │ │ │ ├── model.pt │ │ │ └── run_command.txt │ ├── data │ │ ├── __init__.py │ │ ├── base_data_module.py │ │ ├── emnist.py │ │ ├── emnist_essentials.json │ │ ├── emnist_lines.py │ │ ├── emnist_lines2.py │ │ ├── iam.py │ │ ├── iam_lines.py │ │ ├── iam_original_and_synthetic_paragraphs.py │ │ ├── iam_paragraphs.py │ │ ├── iam_synthetic_paragraphs.py │ │ ├── mnist.py │ │ ├── sentence_generator.py │ │ └── util.py │ ├── lit_models │ │ ├── __init__.py │ │ ├── base.py │ │ ├── ctc.py │ │ ├── metrics.py │ │ ├── transformer.py │ │ └── util.py │ ├── models │ │ ├── __init__.py │ │ ├── cnn.py │ │ ├── line_cnn.py │ │ ├── line_cnn_lstm.py │ │ ├── line_cnn_simple.py │ │ ├── line_cnn_transformer.py │ │ ├── mlp.py │ │ ├── resnet_transformer.py │ │ └── transformer_util.py │ ├── paragraph_text_recognizer.py │ └── util.py └── training │ ├── __init__.py │ ├── run_experiment.py │ ├── save_best_model.py │ └── sweeps │ └── emnist_lines2_line_cnn_transformer.yml ├── lab8 ├── .pylintrc ├── notebooks │ ├── 01-look-at-emnist.ipynb │ ├── 02-look-at-emnist-lines.ipynb │ ├── 02b-look-at-emnist-lines2.ipynb │ ├── 03-look-at-iam-lines.ipynb │ ├── 04-look-at-iam-paragraphs.ipynb │ └── 04b-look-at-paragraph-predictions.ipynb ├── pyproject.toml ├── readme.md ├── setup.cfg ├── tasks │ ├── lint.sh │ └── test.sh ├── text_recognizer │ ├── __init__.py │ ├── artifacts │ │ └── paragraph_text_recognizer │ │ │ ├── config.json │ │ │ ├── model.pt │ │ │ └── run_command.txt │ ├── data │ │ ├── __init__.py │ │ ├── base_data_module.py │ │ ├── emnist.py │ │ ├── emnist_essentials.json │ │ ├── emnist_lines.py │ │ ├── emnist_lines2.py │ │ ├── fake_images.py │ │ ├── iam.py │ │ ├── iam_lines.py │ │ ├── iam_original_and_synthetic_paragraphs.py │ │ ├── iam_paragraphs.py │ │ ├── iam_synthetic_paragraphs.py │ │ ├── mnist.py │ │ ├── sentence_generator.py │ │ └── util.py │ ├── evaluation │ │ └── evaluate_paragraph_text_recognizer.py │ ├── lit_models │ │ ├── __init__.py │ │ ├── base.py │ │ ├── ctc.py │ │ ├── metrics.py │ │ ├── transformer.py │ │ └── util.py │ ├── models │ │ ├── __init__.py │ │ ├── cnn.py │ │ ├── line_cnn.py │ │ ├── line_cnn_lstm.py │ │ ├── line_cnn_simple.py │ │ ├── line_cnn_transformer.py │ │ ├── mlp.py │ │ ├── resnet_transformer.py │ │ └── transformer_util.py │ ├── paragraph_text_recognizer.py │ ├── tests │ │ ├── support │ │ │ └── paragraphs │ │ │ │ ├── a01-077.png │ │ │ │ ├── a01-087.png │ │ │ │ ├── a01-107.png │ │ │ │ ├── a02-046.png │ │ │ │ └── data_by_file_id.json │ │ └── test_paragraph_text_recognizer.py │ └── util.py └── training │ ├── __init__.py │ ├── run_experiment.py │ ├── save_best_model.py │ ├── sweeps │ └── emnist_lines2_line_cnn_transformer.yml │ └── tests │ └── test_run_experiment.sh ├── lab9 ├── .dockerignore ├── .pylintrc ├── api_server │ ├── Dockerfile │ ├── __init__.py │ ├── app.py │ └── tests │ │ └── test_app.py ├── api_serverless │ ├── Dockerfile │ ├── __init__.py │ └── app.py ├── notebooks │ ├── 01-look-at-emnist.ipynb │ ├── 02-look-at-emnist-lines.ipynb │ ├── 02b-look-at-emnist-lines2.ipynb │ ├── 03-look-at-iam-lines.ipynb │ ├── 04-look-at-iam-paragraphs.ipynb │ └── 04b-look-at-paragraph-predictions.ipynb ├── pyproject.toml ├── readme.md ├── setup.cfg ├── tasks │ ├── lint.sh │ └── 
test.sh ├── text_recognizer │ ├── __init__.py │ ├── artifacts │ │ └── paragraph_text_recognizer │ │ │ ├── config.json │ │ │ ├── model.pt │ │ │ └── run_command.txt │ ├── data │ │ ├── __init__.py │ │ ├── base_data_module.py │ │ ├── emnist.py │ │ ├── emnist_essentials.json │ │ ├── emnist_lines.py │ │ ├── emnist_lines2.py │ │ ├── fake_images.py │ │ ├── iam.py │ │ ├── iam_lines.py │ │ ├── iam_original_and_synthetic_paragraphs.py │ │ ├── iam_paragraphs.py │ │ ├── iam_synthetic_paragraphs.py │ │ ├── mnist.py │ │ ├── sentence_generator.py │ │ └── util.py │ ├── evaluation │ │ └── evaluate_paragraph_text_recognizer.py │ ├── lit_models │ │ ├── __init__.py │ │ ├── base.py │ │ ├── ctc.py │ │ ├── metrics.py │ │ ├── transformer.py │ │ └── util.py │ ├── models │ │ ├── __init__.py │ │ ├── cnn.py │ │ ├── line_cnn.py │ │ ├── line_cnn_lstm.py │ │ ├── line_cnn_simple.py │ │ ├── line_cnn_transformer.py │ │ ├── mlp.py │ │ ├── resnet_transformer.py │ │ └── transformer_util.py │ ├── paragraph_text_recognizer.py │ ├── tests │ │ ├── support │ │ │ └── paragraphs │ │ │ │ ├── a01-077.png │ │ │ │ ├── a01-087.png │ │ │ │ ├── a01-107.png │ │ │ │ ├── a02-046.png │ │ │ │ └── data_by_file_id.json │ │ └── test_paragraph_text_recognizer.py │ └── util.py └── training │ ├── __init__.py │ ├── run_experiment.py │ ├── save_best_model.py │ ├── sweeps │ └── emnist_lines2_line_cnn_transformer.yml │ └── tests │ └── test_run_experiment.sh ├── readme.md ├── requirements ├── dev.in ├── dev.txt ├── prod.in └── prod.txt └── setup ├── colab_lab1.png ├── colab_runtime.png ├── colab_vscode.png ├── colab_vscode_2.png └── readme.md /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | # Python CircleCI 2.0 configuration file 2 | # 3 | # Check https://circleci.com/docs/2.0/language-python/ for more details 4 | version: 2 5 | jobs: 6 | build: 7 | docker: 8 | - image: circleci/python:3.6 9 | 10 | steps: 11 | - checkout 12 | 13 | - restore_cache: 14 | keys: 15 | - cache-{{ checksum "requirements/prod.txt" }}-{{ checksum "requirements/dev.txt" }} 16 | 17 | - run: 18 | name: Install Git LFS 19 | command: | 20 | curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash 21 | sudo apt-get install git-lfs 22 | git lfs install 23 | 24 | - run: 25 | name: Pull LFS Files 26 | command: git lfs pull 27 | 28 | - run: 29 | name: Install Shellcheck 30 | command: | 31 | curl -OL https://github.com/koalaman/shellcheck/releases/download/stable/shellcheck-stable.linux.x86_64.tar.xz 32 | tar xf shellcheck-stable.linux.x86_64.tar.xz 33 | sudo mv shellcheck-stable/shellcheck /usr/local/bin 34 | working_directory: /tmp/shellcheck 35 | 36 | - run: 37 | name: install dependencies 38 | command: | 39 | pip install --quiet -r requirements/prod.txt -r requirements/dev.txt 40 | 41 | - save_cache: 42 | key: cache-{{ checksum "requirements/prod.txt" }}-{{ checksum "requirements/dev.txt" }} 43 | paths: 44 | - ~/.local 45 | 46 | - run: 47 | name: run linting 48 | when: always 49 | command: | 50 | cd lab8 || true; PYTHONPATH=. ./tasks/lint.sh 51 | 52 | - run: 53 | name: run prediction tests 54 | when: always 55 | command: | 56 | cd lab8 || true; PYTHONPATH=. 
./tasks/test.sh 57 | 58 | - store_artifacts: 59 | path: test-reports 60 | destination: test-reports 61 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pt filter=lfs diff=lfs merge=lfs -text 2 | *.ckpt filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Data 2 | data/downloaded 3 | data/processed 4 | data/interim 5 | 6 | # Logs 7 | training/logs 8 | 9 | # Editors 10 | .vscode 11 | 12 | # Node 13 | node_modules 14 | 15 | # Python 16 | __pycache__ 17 | .pytest_cache 18 | .ipynb_checkpoints 19 | 20 | # Distribution / packaging 21 | .Python 22 | env/ 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | 38 | # W&B 39 | wandb 40 | 41 | # Misc 42 | .DS_Store 43 | _labs 44 | logs 45 | .mypy_cache 46 | notebooks/lightning_logs 47 | lightning_logs/ 48 | lab9/requirements.txt 49 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Full Stack Deep Learning, LLC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Arcane incantation to print all the other targets, from https://stackoverflow.com/a/26339924 2 | help: 3 | @$(MAKE) -pRrq -f $(lastword $(MAKEFILE_LIST)) : 2>/dev/null | awk -v RS= -F: '/^# File/,/^# Finished Make data base/ {if ($$1 !~ "^[#.]") {print $$1}}' | sort | egrep -v -e '^[^[:alnum:]]' -e '^$@$$' 4 | 5 | # Install exact Python and CUDA versions 6 | conda-update: 7 | conda env update --prune -f environment.yml 8 | echo "!!!RUN RIGHT NOW:\nconda activate fsdl-text-recognizer-2021" 9 | 10 | # Compile and install exact pip packages 11 | pip-tools: 12 | pip install pip-tools 13 | pip-compile requirements/prod.in && pip-compile requirements/dev.in 14 | pip-sync requirements/prod.txt requirements/dev.txt 15 | 16 | # Example training command 17 | train-mnist-cnn-ddp: 18 | python training/run_experiment.py --max_epochs=10 --gpus=-1 --accelerator=ddp --num_workers=20 --data_class=MNIST --model_class=CNN 19 | 20 | # Lint 21 | lint: 22 | tasks/lint.sh 23 | -------------------------------------------------------------------------------- /data/raw/emnist/metadata.toml: -------------------------------------------------------------------------------- 1 | filename = 'matlab.zip' 2 | sha256 = 'e1fa805cdeae699a52da0b77c2db17f6feb77eed125f9b45c022e7990444df95' 3 | url = 'https://s3-us-west-2.amazonaws.com/fsdl-public-assets/matlab.zip' 4 | -------------------------------------------------------------------------------- /data/raw/emnist/readme.md: -------------------------------------------------------------------------------- 1 | # EMNIST dataset 2 | 3 | "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19 4 | and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset." 5 | From https://www.nist.gov/itl/iad/image-group/emnist-dataset 6 | 7 | Original URL is http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/matlab.zip 8 | 9 | We uploaded the same file to our S3 bucket for faster download. 10 | -------------------------------------------------------------------------------- /data/raw/fsdl_handwriting/metadata.toml: -------------------------------------------------------------------------------- 1 | url = "https://dataturks.com/projects/sergeykarayev/fsdl_handwriting/export" 2 | filename = "fsdl_handwriting.json" 3 | sha256 = "720d6c72b4317a9a5492630a1c9f6d83a20d36101a29311a5cf7825c1d60c180" 4 | -------------------------------------------------------------------------------- /data/raw/fsdl_handwriting/readme.md: -------------------------------------------------------------------------------- 1 | # FSDL Handwriting Dataset 2 | 3 | ## Collection 4 | 5 | Handwritten paragraphs were collected in the FSDL March 2019 class. 
6 | The resulting PDF was stored at https://fsdl-public-assets.s3-us-west-2.amazonaws.com/fsdl_handwriting_20190302.pdf 7 | 8 | Pages were extracted from the PDF by running `gs -q -dBATCH -dNOPAUSE -sDEVICE=jpeg -r300 -sOutputFile=page-%03d.jpg -f fsdl_handwriting_20190302.pdf` and uploaded to S3, with urls like https://fsdl-public-assets.s3-us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-001.jpg 9 | -------------------------------------------------------------------------------- /data/raw/iam/metadata.toml: -------------------------------------------------------------------------------- 1 | url = 'https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam/iamdb.zip' 2 | filename = 'iamdb.zip' 3 | sha256 = 'f3c9e87a88a313e557c6d3548ed8a2a1af2dc3c4a678c5f3fc6f972ba4a50c55' 4 | test_ids = ["m01-049","m01-060","m01-079","m01-084","m01-090","m01-095","m01-104","m01-121","m01-110","m01-131","m01-115","m01-125","m01-136","m01-149","m01-160","m02-048","m02-052","m02-055","m02-059","m02-066","m02-069","m02-072","m02-075","m02-106","m02-080","m02-083","m02-087","m02-090","m02-095","m02-102","m02-109","m02-112","m03-006","m03-013","m03-020","m03-033","m03-062","m03-095","m03-110","m03-114","m03-118","m04-000","m04-007","m04-012","m04-019","m04-024","m04-030","m04-038","m04-043","m04-061","m04-072","m04-078","m04-081","m04-093","m04-100","m04-107","m04-113","m04-123","m04-131","m04-138","m04-145","m04-152","m04-164","m04-180","m04-251","m04-190","m04-200","m04-209","m04-216","m04-222","m04-231","m04-238","m04-246","n04-000","n04-009","m06-019","n06-148","n06-156","n06-163","n06-169","n06-175","n06-182","n06-186","n06-194","n06-201","m06-031","m06-042","m06-048","m06-056","m06-067","m06-076","m06-083","m06-091","m06-098","m06-106","n01-000","n01-009","n01-004","n01-020","n01-031","n01-045","n01-036","n01-052","n01-057","n02-000","n02-016","n02-004","n02-009","n02-028","n02-049","n02-033","n02-037","n02-040","n02-045","n02-054","n02-062","n02-082","n02-098","p03-057","p03-087","p03-096","p03-103","p03-112","n02-151","n02-154","n02-157","n03-038","n03-064","n03-066","n03-079","n03-082","n03-091","n03-097","n03-103","n03-106","n03-113","n03-120","n03-126","n04-015","n04-022","n04-031","n04-039","n04-044","n04-048","n04-052","n04-060","n04-068","n04-075","n04-084","n04-092","n04-100","n04-107","n04-114","n04-130","n04-139","n04-149","n04-156","n04-163","n04-171","n04-183","n04-190","n04-195","n04-202","n04-209","n04-213","n04-218","n06-074","n06-082","n06-092","n06-100","n06-111","n06-119","n06-123","n06-128","n06-133","n06-140","p01-147","p01-155","p01-168","p01-174","p02-000","p02-008","p02-017","p02-022","p02-027","p02-069","p02-090","p02-076","p02-081","p02-101","p02-105","p02-109","p02-115","p02-121","p02-127","p02-131","p02-135","p02-139","p02-144","p02-150","p02-155","p03-004","p03-009","p03-012","p03-023","p03-027","p03-029","p03-033","p03-040","p03-047","p03-069","p03-072","p03-080","p03-121","p03-135","p03-142","p03-151","p03-158","p03-163","p03-173","p03-181","p03-185","p03-189","p06-030","p06-042","p06-047","p06-052","p06-058","p06-069","p06-088","p06-096","p06-104"] 5 | -------------------------------------------------------------------------------- /data/raw/iam/readme.md: -------------------------------------------------------------------------------- 1 | # IAM Dataset 2 | 3 | The IAM Handwriting Database contains forms of handwritten English text which can be used to train and test handwritten text recognizers and to perform writer identification and verification experiments. 
4 | 5 | - 657 writers contributed samples of their handwriting 6 | - 1,539 pages of scanned text 7 | - 13,353 isolated and labeled text lines 8 | 9 | - http://www.fki.inf.unibe.ch/databases/iam-handwriting-database 10 | 11 | ## Pre-processing 12 | 13 | First, all forms were placed into one directory called `forms`, from original directories like `formsA-D`. 14 | 15 | To save space, I converted the original PNG files to JPG, and resized them to half-size 16 | ``` 17 | mkdir forms-resized 18 | cd forms 19 | ls -1 *.png | parallel --eta -j 6 convert '{}' -adaptive-resize 50% '../forms-resized/{.}.jpg' 20 | ``` 21 | 22 | ## Split 23 | 24 | The data split we will use is 25 | IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. 26 | 27 | - The validation set has been merged into the train set. 28 | - The train set has 7,101 lines from 326 writers. 29 | - The test set has 1,861 lines from 128 writers. 30 | - The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. 31 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: fsdl-text-recognizer-2021 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.6 # Google Colab is still on Python 3.6 6 | - cudatoolkit=10.1 7 | - cudnn=7.6 8 | - pip 9 | - pip: 10 | - pip-tools 11 | -------------------------------------------------------------------------------- /lab1/text_recognizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab1/text_recognizer/__init__.py -------------------------------------------------------------------------------- /lab1/text_recognizer/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import BaseDataset 2 | from .base_data_module import BaseDataModule 3 | from .mnist import MNIST 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /lab1/text_recognizer/data/mnist.py: -------------------------------------------------------------------------------- 1 | """MNIST DataModule""" 2 | import argparse 3 | 4 | from torch.utils.data import random_split 5 | from torchvision.datasets import MNIST as TorchMNIST 6 | from torchvision import transforms 7 | 8 | from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info 9 | 10 | DOWNLOADED_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" 11 | 12 | # NOTE: temp fix until https://github.com/pytorch/vision/issues/1938 is resolved 13 | from six.moves import urllib # pylint: disable=wrong-import-position, wrong-import-order 14 | 15 | opener = urllib.request.build_opener() 16 | opener.addheaders = [("User-agent", "Mozilla/5.0")] 17 | urllib.request.install_opener(opener) 18 | 19 | 20 | class MNIST(BaseDataModule): 21 | """ 22 | MNIST DataModule. 
23 | Learn more at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html 24 | """ 25 | 26 | def __init__(self, args: argparse.Namespace) -> None: 27 | super().__init__(args) 28 | self.data_dir = DOWNLOADED_DATA_DIRNAME 29 | self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 30 | self.dims = (1, 28, 28) # dims are returned when calling `.size()` on this object. 31 | self.output_dims = (1,) 32 | self.mapping = list(range(10)) 33 | 34 | def prepare_data(self, *args, **kwargs) -> None: 35 | """Download train and test MNIST data from PyTorch canonical source.""" 36 | TorchMNIST(self.data_dir, train=True, download=True) 37 | TorchMNIST(self.data_dir, train=False, download=True) 38 | 39 | def setup(self, stage=None) -> None: 40 | """Split into train, val, test, and set dims.""" 41 | mnist_full = TorchMNIST(self.data_dir, train=True, transform=self.transform) 42 | self.data_train, self.data_val = random_split(mnist_full, [55000, 5000]) # type: ignore 43 | self.data_test = TorchMNIST(self.data_dir, train=False, transform=self.transform) 44 | 45 | 46 | if __name__ == "__main__": 47 | load_and_print_info(MNIST) 48 | -------------------------------------------------------------------------------- /lab1/text_recognizer/lit_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseLitModel 2 | 3 | -------------------------------------------------------------------------------- /lab1/text_recognizer/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import MLP 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /lab1/text_recognizer/models/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | FC1_DIM = 1024 10 | FC2_DIM = 128 11 | 12 | 13 | class MLP(nn.Module): 14 | """Simple MLP suitable for recognizing single characters.""" 15 | 16 | def __init__( 17 | self, 18 | data_config: Dict[str, Any], 19 | args: argparse.Namespace = None, 20 | ) -> None: 21 | super().__init__() 22 | self.args = vars(args) if args is not None else {} 23 | 24 | input_dim = np.prod(data_config["input_dims"]) 25 | num_classes = len(data_config["mapping"]) 26 | 27 | fc1_dim = self.args.get("fc1", FC1_DIM) 28 | fc2_dim = self.args.get("fc2", FC2_DIM) 29 | 30 | self.dropout = nn.Dropout(0.5) 31 | self.fc1 = nn.Linear(input_dim, fc1_dim) 32 | self.fc2 = nn.Linear(fc1_dim, fc2_dim) 33 | self.fc3 = nn.Linear(fc2_dim, num_classes) 34 | 35 | def forward(self, x): 36 | x = torch.flatten(x, 1) 37 | x = self.fc1(x) 38 | x = F.relu(x) 39 | x = self.dropout(x) 40 | x = self.fc2(x) 41 | x = F.relu(x) 42 | x = self.dropout(x) 43 | x = self.fc3(x) 44 | return x 45 | 46 | @staticmethod 47 | def add_to_argparse(parser): 48 | parser.add_argument("--fc1", type=int, default=1024) 49 | parser.add_argument("--fc2", type=int, default=128) 50 | return parser 51 | -------------------------------------------------------------------------------- /lab1/text_recognizer/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions for text_recognizer module.""" 2 | from io import BytesIO 3 | from pathlib import Path 4 | from typing import Union 5 | from 
urllib.request import urlretrieve 6 | import base64 7 | import hashlib 8 | 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import numpy as np 12 | import smart_open 13 | 14 | 15 | def to_categorical(y, num_classes): 16 | """1-hot encode a tensor.""" 17 | return np.eye(num_classes, dtype="uint8")[y] 18 | 19 | 20 | def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image: 21 | with smart_open.open(image_uri, "rb") as image_file: 22 | return read_image_pil_file(image_file, grayscale) 23 | 24 | 25 | def read_image_pil_file(image_file, grayscale=False) -> Image: 26 | with Image.open(image_file) as image: 27 | if grayscale: 28 | image = image.convert(mode="L") 29 | else: 30 | image = image.convert(mode=image.mode) 31 | return image 32 | 33 | 34 | 35 | 36 | def compute_sha256(filename: Union[Path, str]): 37 | """Return SHA256 checksum of a file.""" 38 | with open(filename, "rb") as f: 39 | return hashlib.sha256(f.read()).hexdigest() 40 | 41 | 42 | class TqdmUpTo(tqdm): 43 | """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" 44 | 45 | def update_to(self, blocks=1, bsize=1, tsize=None): 46 | """ 47 | Parameters 48 | ---------- 49 | blocks: int, optional 50 | Number of blocks transferred so far [default: 1]. 51 | bsize: int, optional 52 | Size of each block (in tqdm units) [default: 1]. 53 | tsize: int, optional 54 | Total size (in tqdm units). If [default: None] remains unchanged. 55 | """ 56 | if tsize is not None: 57 | self.total = tsize # pylint: disable=attribute-defined-outside-init 58 | self.update(blocks * bsize - self.n) # will also set self.n = b * bsize 59 | 60 | 61 | def download_url(url, filename): 62 | """Download a file from url to filename, with a progress bar.""" 63 | with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: 64 | urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec 65 | -------------------------------------------------------------------------------- /lab1/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab1/training/__init__.py -------------------------------------------------------------------------------- /lab2/resblock.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab2/resblock.png -------------------------------------------------------------------------------- /lab2/text_recognizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab2/text_recognizer/__init__.py -------------------------------------------------------------------------------- /lab2/text_recognizer/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import BaseDataset 2 | from .base_data_module import BaseDataModule 3 | from .mnist import MNIST 4 | 5 | # Hide lines below until Lab 2 6 | from .emnist import EMNIST 7 | from .emnist_lines import EMNISTLines 8 | 9 | # Hide lines above until Lab 2 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /lab2/text_recognizer/data/emnist_essentials.json: 
-------------------------------------------------------------------------------- 1 | {"characters": ["<B>", "<S>", "<E>", "<P>
", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} -------------------------------------------------------------------------------- /lab2/text_recognizer/data/mnist.py: -------------------------------------------------------------------------------- 1 | """MNIST DataModule""" 2 | import argparse 3 | 4 | from torch.utils.data import random_split 5 | from torchvision.datasets import MNIST as TorchMNIST 6 | from torchvision import transforms 7 | 8 | from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info 9 | 10 | DOWNLOADED_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" 11 | 12 | # NOTE: temp fix until https://github.com/pytorch/vision/issues/1938 is resolved 13 | from six.moves import urllib # pylint: disable=wrong-import-position, wrong-import-order 14 | 15 | opener = urllib.request.build_opener() 16 | opener.addheaders = [("User-agent", "Mozilla/5.0")] 17 | urllib.request.install_opener(opener) 18 | 19 | 20 | class MNIST(BaseDataModule): 21 | """ 22 | MNIST DataModule. 23 | Learn more at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html 24 | """ 25 | 26 | def __init__(self, args: argparse.Namespace) -> None: 27 | super().__init__(args) 28 | self.data_dir = DOWNLOADED_DATA_DIRNAME 29 | self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 30 | self.dims = (1, 28, 28) # dims are returned when calling `.size()` on this object. 
31 | self.output_dims = (1,) 32 | self.mapping = list(range(10)) 33 | 34 | def prepare_data(self, *args, **kwargs) -> None: 35 | """Download train and test MNIST data from PyTorch canonical source.""" 36 | TorchMNIST(self.data_dir, train=True, download=True) 37 | TorchMNIST(self.data_dir, train=False, download=True) 38 | 39 | def setup(self, stage=None) -> None: 40 | """Split into train, val, test, and set dims.""" 41 | mnist_full = TorchMNIST(self.data_dir, train=True, transform=self.transform) 42 | self.data_train, self.data_val = random_split(mnist_full, [55000, 5000]) # type: ignore 43 | self.data_test = TorchMNIST(self.data_dir, train=False, transform=self.transform) 44 | 45 | 46 | if __name__ == "__main__": 47 | load_and_print_info(MNIST) 48 | -------------------------------------------------------------------------------- /lab2/text_recognizer/data/sentence_generator.py: -------------------------------------------------------------------------------- 1 | """SentenceGenerator class and supporting functions.""" 2 | import itertools 3 | import re 4 | import string 5 | from typing import Optional 6 | 7 | import nltk 8 | import numpy as np 9 | 10 | from text_recognizer.data.base_data_module import BaseDataModule 11 | 12 | NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk" 13 | 14 | 15 | class SentenceGenerator: 16 | """Generate text sentences using the Brown corpus.""" 17 | 18 | def __init__(self, max_length: Optional[int] = None): 19 | self.text = brown_text() 20 | self.word_start_inds = [0] + [_.start(0) + 1 for _ in re.finditer(" ", self.text)] 21 | self.max_length = max_length 22 | 23 | def generate(self, max_length: Optional[int] = None) -> str: 24 | """ 25 | Sample a string from text of the Brown corpus of length at least one word and at most max_length. 
26 | """ 27 | if max_length is None: 28 | max_length = self.max_length 29 | if max_length is None: 30 | raise ValueError("Must provide max_length to this method or when making this object.") 31 | 32 | for _ in range(10): # Try several times to generate before actually erroring 33 | try: 34 | first_ind = np.random.randint(0, len(self.word_start_inds) - 1) 35 | start_ind = self.word_start_inds[first_ind] 36 | end_ind_candidates = [] 37 | for ind in range(first_ind + 1, len(self.word_start_inds)): 38 | if self.word_start_inds[ind] - start_ind > max_length: 39 | break 40 | end_ind_candidates.append(self.word_start_inds[ind]) 41 | end_ind = np.random.choice(end_ind_candidates) 42 | sampled_text = self.text[start_ind:end_ind].strip() 43 | return sampled_text 44 | except Exception: # pylint: disable=broad-except 45 | pass 46 | raise RuntimeError("Was not able to generate a valid string") 47 | 48 | 49 | def brown_text(): 50 | """Return a single string with the Brown corpus with all punctuation stripped.""" 51 | sents = load_nltk_brown_corpus() 52 | text = " ".join(itertools.chain.from_iterable(sents)) 53 | text = text.translate({ord(c): None for c in string.punctuation}) 54 | text = re.sub(" +", " ", text) 55 | return text 56 | 57 | 58 | def load_nltk_brown_corpus(): 59 | """Load the Brown corpus using the NLTK library.""" 60 | nltk.data.path.append(NLTK_DATA_DIRNAME) 61 | try: 62 | nltk.corpus.brown.sents() 63 | except LookupError: 64 | NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) 65 | nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) 66 | return nltk.corpus.brown.sents() 67 | -------------------------------------------------------------------------------- /lab2/text_recognizer/lit_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseLitModel 2 | 3 | -------------------------------------------------------------------------------- /lab2/text_recognizer/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import MLP 2 | 3 | # Hide lines below until Lab 2 4 | from .cnn import CNN 5 | 6 | # Hide lines above until Lab 2 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /lab2/text_recognizer/models/cnn.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | CONV_DIM = 64 10 | FC_DIM = 128 11 | IMAGE_SIZE = 28 12 | 13 | 14 | class ConvBlock(nn.Module): 15 | """ 16 | Simple 3x3 conv with padding size 1 (to leave the input size unchanged), followed by a ReLU. 
17 | """ 18 | 19 | def __init__(self, input_channels: int, output_channels: int) -> None: 20 | super().__init__() 21 | self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1) 22 | self.relu = nn.ReLU() 23 | 24 | def forward(self, x: torch.Tensor) -> torch.Tensor: 25 | """ 26 | Parameters 27 | ---------- 28 | x 29 | of dimensions (B, C, H, W) 30 | 31 | Returns 32 | ------- 33 | torch.Tensor 34 | of dimensions (B, C, H, W) 35 | """ 36 | c = self.conv(x) 37 | r = self.relu(c) 38 | return r 39 | 40 | 41 | class CNN(nn.Module): 42 | """Simple CNN for recognizing characters in a square image.""" 43 | 44 | def __init__(self, data_config: Dict[str, Any], args: argparse.Namespace = None) -> None: 45 | super().__init__() 46 | self.args = vars(args) if args is not None else {} 47 | 48 | input_dims = data_config["input_dims"] 49 | num_classes = len(data_config["mapping"]) 50 | 51 | conv_dim = self.args.get("conv_dim", CONV_DIM) 52 | fc_dim = self.args.get("fc_dim", FC_DIM) 53 | 54 | self.conv1 = ConvBlock(input_dims[0], conv_dim) 55 | self.conv2 = ConvBlock(conv_dim, conv_dim) 56 | self.dropout = nn.Dropout(0.25) 57 | self.max_pool = nn.MaxPool2d(2) 58 | 59 | # Because our 3x3 convs have padding size 1, they leave the input size unchanged. 60 | # The 2x2 max-pool divides the input size by 2. Flattening squares it. 61 | conv_output_size = IMAGE_SIZE // 2 62 | fc_input_dim = int(conv_output_size * conv_output_size * conv_dim) 63 | self.fc1 = nn.Linear(fc_input_dim, fc_dim) 64 | self.fc2 = nn.Linear(fc_dim, num_classes) 65 | 66 | def forward(self, x: torch.Tensor) -> torch.Tensor: 67 | """ 68 | Args: 69 | x 70 | (B, C, H, W) tensor, where H and W must equal IMAGE_SIZE 71 | 72 | Returns 73 | ------- 74 | torch.Tensor 75 | (B, C) tensor 76 | """ 77 | _B, _C, H, W = x.shape 78 | assert H == W == IMAGE_SIZE 79 | x = self.conv1(x) 80 | x = self.conv2(x) 81 | x = self.max_pool(x) 82 | x = self.dropout(x) 83 | x = torch.flatten(x, 1) 84 | x = self.fc1(x) 85 | x = F.relu(x) 86 | x = self.fc2(x) 87 | return x 88 | 89 | @staticmethod 90 | def add_to_argparse(parser): 91 | parser.add_argument("--conv_dim", type=int, default=CONV_DIM) 92 | parser.add_argument("--fc_dim", type=int, default=FC_DIM) 93 | return parser 94 | -------------------------------------------------------------------------------- /lab2/text_recognizer/models/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | FC1_DIM = 1024 10 | FC2_DIM = 128 11 | 12 | 13 | class MLP(nn.Module): 14 | """Simple MLP suitable for recognizing single characters.""" 15 | 16 | def __init__( 17 | self, 18 | data_config: Dict[str, Any], 19 | args: argparse.Namespace = None, 20 | ) -> None: 21 | super().__init__() 22 | self.args = vars(args) if args is not None else {} 23 | 24 | input_dim = np.prod(data_config["input_dims"]) 25 | num_classes = len(data_config["mapping"]) 26 | 27 | fc1_dim = self.args.get("fc1", FC1_DIM) 28 | fc2_dim = self.args.get("fc2", FC2_DIM) 29 | 30 | self.dropout = nn.Dropout(0.5) 31 | self.fc1 = nn.Linear(input_dim, fc1_dim) 32 | self.fc2 = nn.Linear(fc1_dim, fc2_dim) 33 | self.fc3 = nn.Linear(fc2_dim, num_classes) 34 | 35 | def forward(self, x): 36 | x = torch.flatten(x, 1) 37 | x = self.fc1(x) 38 | x = F.relu(x) 39 | x = self.dropout(x) 40 | x = self.fc2(x) 41 | x = F.relu(x) 42 | x = self.dropout(x) 43 | x = 
self.fc3(x) 44 | return x 45 | 46 | @staticmethod 47 | def add_to_argparse(parser): 48 | parser.add_argument("--fc1", type=int, default=1024) 49 | parser.add_argument("--fc2", type=int, default=128) 50 | return parser 51 | -------------------------------------------------------------------------------- /lab2/text_recognizer/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions for text_recognizer module.""" 2 | from io import BytesIO 3 | from pathlib import Path 4 | from typing import Union 5 | from urllib.request import urlretrieve 6 | import base64 7 | import hashlib 8 | 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import numpy as np 12 | import smart_open 13 | 14 | 15 | def to_categorical(y, num_classes): 16 | """1-hot encode a tensor.""" 17 | return np.eye(num_classes, dtype="uint8")[y] 18 | 19 | 20 | def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image: 21 | with smart_open.open(image_uri, "rb") as image_file: 22 | return read_image_pil_file(image_file, grayscale) 23 | 24 | 25 | def read_image_pil_file(image_file, grayscale=False) -> Image: 26 | with Image.open(image_file) as image: 27 | if grayscale: 28 | image = image.convert(mode="L") 29 | else: 30 | image = image.convert(mode=image.mode) 31 | return image 32 | 33 | 34 | 35 | 36 | def compute_sha256(filename: Union[Path, str]): 37 | """Return SHA256 checksum of a file.""" 38 | with open(filename, "rb") as f: 39 | return hashlib.sha256(f.read()).hexdigest() 40 | 41 | 42 | class TqdmUpTo(tqdm): 43 | """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" 44 | 45 | def update_to(self, blocks=1, bsize=1, tsize=None): 46 | """ 47 | Parameters 48 | ---------- 49 | blocks: int, optional 50 | Number of blocks transferred so far [default: 1]. 51 | bsize: int, optional 52 | Size of each block (in tqdm units) [default: 1]. 53 | tsize: int, optional 54 | Total size (in tqdm units). If [default: None] remains unchanged. 
55 | """ 56 | if tsize is not None: 57 | self.total = tsize # pylint: disable=attribute-defined-outside-init 58 | self.update(blocks * bsize - self.n) # will also set self.n = b * bsize 59 | 60 | 61 | def download_url(url, filename): 62 | """Download a file from url to filename, with a progress bar.""" 63 | with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: 64 | urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec 65 | -------------------------------------------------------------------------------- /lab2/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab2/training/__init__.py -------------------------------------------------------------------------------- /lab3/text_recognizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab3/text_recognizer/__init__.py -------------------------------------------------------------------------------- /lab3/text_recognizer/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import BaseDataset 2 | from .base_data_module import BaseDataModule 3 | from .mnist import MNIST 4 | 5 | # Hide lines below until Lab 2 6 | from .emnist import EMNIST 7 | from .emnist_lines import EMNISTLines 8 | 9 | # Hide lines above until Lab 2 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /lab3/text_recognizer/data/emnist_essentials.json: -------------------------------------------------------------------------------- 1 | {"characters": ["", "", "", "

", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} -------------------------------------------------------------------------------- /lab3/text_recognizer/data/mnist.py: -------------------------------------------------------------------------------- 1 | """MNIST DataModule""" 2 | import argparse 3 | 4 | from torch.utils.data import random_split 5 | from torchvision.datasets import MNIST as TorchMNIST 6 | from torchvision import transforms 7 | 8 | from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info 9 | 10 | DOWNLOADED_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" 11 | 12 | # NOTE: temp fix until https://github.com/pytorch/vision/issues/1938 is resolved 13 | from six.moves import urllib # pylint: disable=wrong-import-position, wrong-import-order 14 | 15 | opener = urllib.request.build_opener() 16 | opener.addheaders = [("User-agent", "Mozilla/5.0")] 17 | urllib.request.install_opener(opener) 18 | 19 | 20 | class MNIST(BaseDataModule): 21 | """ 22 | MNIST DataModule. 23 | Learn more at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html 24 | """ 25 | 26 | def __init__(self, args: argparse.Namespace) -> None: 27 | super().__init__(args) 28 | self.data_dir = DOWNLOADED_DATA_DIRNAME 29 | self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 30 | self.dims = (1, 28, 28) # dims are returned when calling `.size()` on this object. 
31 | self.output_dims = (1,) 32 | self.mapping = list(range(10)) 33 | 34 | def prepare_data(self, *args, **kwargs) -> None: 35 | """Download train and test MNIST data from PyTorch canonical source.""" 36 | TorchMNIST(self.data_dir, train=True, download=True) 37 | TorchMNIST(self.data_dir, train=False, download=True) 38 | 39 | def setup(self, stage=None) -> None: 40 | """Split into train, val, test, and set dims.""" 41 | mnist_full = TorchMNIST(self.data_dir, train=True, transform=self.transform) 42 | self.data_train, self.data_val = random_split(mnist_full, [55000, 5000]) # type: ignore 43 | self.data_test = TorchMNIST(self.data_dir, train=False, transform=self.transform) 44 | 45 | 46 | if __name__ == "__main__": 47 | load_and_print_info(MNIST) 48 | -------------------------------------------------------------------------------- /lab3/text_recognizer/data/sentence_generator.py: -------------------------------------------------------------------------------- 1 | """SentenceGenerator class and supporting functions.""" 2 | import itertools 3 | import re 4 | import string 5 | from typing import Optional 6 | 7 | import nltk 8 | import numpy as np 9 | 10 | from text_recognizer.data.base_data_module import BaseDataModule 11 | 12 | NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk" 13 | 14 | 15 | class SentenceGenerator: 16 | """Generate text sentences using the Brown corpus.""" 17 | 18 | def __init__(self, max_length: Optional[int] = None): 19 | self.text = brown_text() 20 | self.word_start_inds = [0] + [_.start(0) + 1 for _ in re.finditer(" ", self.text)] 21 | self.max_length = max_length 22 | 23 | def generate(self, max_length: Optional[int] = None) -> str: 24 | """ 25 | Sample a string from text of the Brown corpus of length at least one word and at most max_length. 
26 | """ 27 | if max_length is None: 28 | max_length = self.max_length 29 | if max_length is None: 30 | raise ValueError("Must provide max_length to this method or when making this object.") 31 | 32 | for _ in range(10): # Try several times to generate before actually erroring 33 | try: 34 | first_ind = np.random.randint(0, len(self.word_start_inds) - 1) 35 | start_ind = self.word_start_inds[first_ind] 36 | end_ind_candidates = [] 37 | for ind in range(first_ind + 1, len(self.word_start_inds)): 38 | if self.word_start_inds[ind] - start_ind > max_length: 39 | break 40 | end_ind_candidates.append(self.word_start_inds[ind]) 41 | end_ind = np.random.choice(end_ind_candidates) 42 | sampled_text = self.text[start_ind:end_ind].strip() 43 | return sampled_text 44 | except Exception: # pylint: disable=broad-except 45 | pass 46 | raise RuntimeError("Was not able to generate a valid string") 47 | 48 | 49 | def brown_text(): 50 | """Return a single string with the Brown corpus with all punctuation stripped.""" 51 | sents = load_nltk_brown_corpus() 52 | text = " ".join(itertools.chain.from_iterable(sents)) 53 | text = text.translate({ord(c): None for c in string.punctuation}) 54 | text = re.sub(" +", " ", text) 55 | return text 56 | 57 | 58 | def load_nltk_brown_corpus(): 59 | """Load the Brown corpus using the NLTK library.""" 60 | nltk.data.path.append(NLTK_DATA_DIRNAME) 61 | try: 62 | nltk.corpus.brown.sents() 63 | except LookupError: 64 | NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) 65 | nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) 66 | return nltk.corpus.brown.sents() 67 | -------------------------------------------------------------------------------- /lab3/text_recognizer/lit_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseLitModel 2 | 3 | # Hide lines below until Lab 3 4 | from .ctc import CTCLitModel 5 | 6 | # Hide lines above until Lab 3 7 | -------------------------------------------------------------------------------- /lab3/text_recognizer/lit_models/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | import editdistance 6 | 7 | 8 | class CharacterErrorRate(pl.metrics.Metric): 9 | """Character error rate metric, computed using Levenshtein distance.""" 10 | 11 | def __init__(self, ignore_tokens: Sequence[int], *args): 12 | super().__init__(*args) 13 | self.ignore_tokens = set(ignore_tokens) 14 | self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum") # pylint: disable=not-callable 15 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") # pylint: disable=not-callable 16 | self.error: torch.Tensor 17 | self.total: torch.Tensor 18 | 19 | def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None: 20 | N = preds.shape[0] 21 | for ind in range(N): 22 | pred = [_ for _ in preds[ind].tolist() if _ not in self.ignore_tokens] 23 | target = [_ for _ in targets[ind].tolist() if _ not in self.ignore_tokens] 24 | distance = editdistance.distance(pred, target) 25 | error = distance / max(len(pred), len(target)) 26 | self.error = self.error + error 27 | self.total = self.total + N 28 | 29 | def compute(self) -> torch.Tensor: 30 | return self.error / self.total 31 | 32 | 33 | def test_character_error_rate(): 34 | metric = CharacterErrorRate([0, 1]) 35 | X = torch.tensor( # pylint: disable=not-callable 36 | [ 37 | [0, 2, 2, 3, 3, 1], # 
error will be 0 38 | [0, 2, 1, 1, 1, 1], # error will be .75 39 | [0, 2, 2, 4, 4, 1], # error will be .5 40 | ] 41 | ) 42 | Y = torch.tensor( # pylint: disable=not-callable 43 | [ 44 | [0, 2, 2, 3, 3, 1], 45 | [0, 2, 2, 3, 3, 1], 46 | [0, 2, 2, 3, 3, 1], 47 | ] 48 | ) 49 | metric(X, Y) 50 | print(metric.compute()) 51 | assert metric.compute() == sum([0, 0.75, 0.5]) / 3 52 | 53 | 54 | if __name__ == "__main__": 55 | test_character_error_rate() 56 | -------------------------------------------------------------------------------- /lab3/text_recognizer/lit_models/util.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | 5 | 6 | def first_element(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor: 7 | """ 8 | Return indices of first occurence of element in x. If not found, return length of x along dim. 9 | 10 | Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9 11 | 12 | Examples 13 | -------- 14 | >>> first_element(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1]]), 3) 15 | tensor([2, 1, 3]) 16 | """ 17 | nonz = x == element 18 | ind = ((nonz.cumsum(dim) == 1) & nonz).max(dim).indices 19 | ind[ind == 0] = x.shape[dim] 20 | return ind 21 | -------------------------------------------------------------------------------- /lab3/text_recognizer/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import MLP 2 | 3 | # Hide lines below until Lab 2 4 | from .cnn import CNN 5 | 6 | # Hide lines above until Lab 2 7 | 8 | # Hide lines below until Lab 3 9 | from .line_cnn_simple import LineCNNSimple 10 | from .line_cnn import LineCNN 11 | from .line_cnn_lstm import LineCNNLSTM 12 | 13 | # Hide lines above until Lab 3 14 | 15 | 16 | -------------------------------------------------------------------------------- /lab3/text_recognizer/models/line_cnn_lstm.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .line_cnn import LineCNN 7 | 8 | LSTM_DIM = 512 9 | LSTM_LAYERS = 1 10 | LSTM_DROPOUT = 0.2 11 | 12 | 13 | class LineCNNLSTM(nn.Module): 14 | """Process the line through a CNN and process the resulting sequence through LSTM layers.""" 15 | 16 | def __init__( 17 | self, 18 | data_config: Dict[str, Any], 19 | args: argparse.Namespace = None, 20 | ) -> None: 21 | super().__init__() 22 | self.data_config = data_config 23 | self.args = vars(args) if args is not None else {} 24 | 25 | num_classes = len(data_config["mapping"]) 26 | lstm_dim = self.args.get("lstm_dim", LSTM_DIM) 27 | lstm_layers = self.args.get("lstm_layers", LSTM_LAYERS) 28 | lstm_dropout = self.args.get("lstm_dropout", LSTM_DROPOUT) 29 | 30 | self.line_cnn = LineCNN(data_config=data_config, args=args) 31 | # LineCNN outputs (B, C, S) log probs, with C == num_classes 32 | 33 | self.lstm = nn.LSTM( 34 | input_size=num_classes, 35 | hidden_size=lstm_dim, 36 | num_layers=lstm_layers, 37 | dropout=lstm_dropout, 38 | bidirectional=True, 39 | ) 40 | self.fc = nn.Linear(lstm_dim, num_classes) 41 | 42 | def forward(self, x: torch.Tensor) -> torch.Tensor: 43 | """ 44 | Parameters 45 | ---------- 46 | x 47 | (B, H, W) input image 48 | 49 | Returns 50 | ------- 51 | torch.Tensor 52 | (B, C, S) logits, where S is the length of the sequence and C is the number of classes 53 | S can be computed from W and CHAR_WIDTH 54 | C is num_classes 
55 | """ 56 | x = self.line_cnn(x) # -> (B, C, S) 57 | B, _C, S = x.shape 58 | x = x.permute(2, 0, 1) # -> (S, B, C) 59 | 60 | x, _ = self.lstm(x) # -> (S, B, 2 * H) where H is lstm_dim 61 | 62 | # Sum up both directions of the LSTM: 63 | x = x.view(S, B, 2, -1).sum(dim=2) # -> (S, B, H) 64 | 65 | x = self.fc(x) # -> (S, B, C) 66 | 67 | return x.permute(1, 2, 0) # -> (B, C, S) 68 | 69 | @staticmethod 70 | def add_to_argparse(parser): 71 | LineCNN.add_to_argparse(parser) 72 | parser.add_argument("--lstm_dim", type=int, default=LSTM_DIM) 73 | parser.add_argument("--lstm_layers", type=int, default=LSTM_LAYERS) 74 | parser.add_argument("--lstm_dropout", type=float, default=LSTM_DROPOUT) 75 | return parser 76 | -------------------------------------------------------------------------------- /lab3/text_recognizer/models/line_cnn_simple.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .cnn import CNN, IMAGE_SIZE 9 | 10 | WINDOW_WIDTH = 28 11 | WINDOW_STRIDE = 28 12 | 13 | 14 | class LineCNNSimple(nn.Module): 15 | """LeNet based model that takes a line of width that is a multiple of CHAR_WIDTH.""" 16 | 17 | def __init__( 18 | self, 19 | data_config: Dict[str, Any], 20 | args: argparse.Namespace = None, 21 | ) -> None: 22 | super().__init__() 23 | self.args = vars(args) if args is not None else {} 24 | 25 | self.WW = self.args.get("window_width", WINDOW_WIDTH) 26 | self.WS = self.args.get("window_stride", WINDOW_STRIDE) 27 | self.limit_output_length = self.args.get("limit_output_length", False) 28 | 29 | self.num_classes = len(data_config["mapping"]) 30 | self.output_length = data_config["output_dims"][0] 31 | self.cnn = CNN(data_config=data_config, args=args) 32 | 33 | def forward(self, x: torch.Tensor) -> torch.Tensor: 34 | """ 35 | Parameters 36 | ---------- 37 | x 38 | (B, C, H, W) input image 39 | 40 | Returns 41 | ------- 42 | torch.Tensor 43 | (B, C, S) logits, where S is the length of the sequence and C is the number of classes 44 | S can be computed from W and CHAR_WIDTH 45 | C is self.num_classes 46 | """ 47 | B, _C, H, W = x.shape 48 | assert H == IMAGE_SIZE # Make sure we can use our CNN class 49 | 50 | # Compute number of windows 51 | S = math.floor((W - self.WW) / self.WS + 1) 52 | 53 | # NOTE: type_as properly sets device 54 | activations = torch.zeros((B, self.num_classes, S)).type_as(x) 55 | for s in range(S): 56 | start_w = self.WS * s 57 | end_w = start_w + self.WW 58 | window = x[:, :, :, start_w:end_w] # -> (B, C, H, self.WW) 59 | activations[:, :, s] = self.cnn(window) 60 | 61 | if self.limit_output_length: 62 | # S might not match ground truth, so let's only take enough activations as are expected 63 | activations = activations[:, :, : self.output_length] 64 | return activations 65 | 66 | @staticmethod 67 | def add_to_argparse(parser): 68 | CNN.add_to_argparse(parser) 69 | parser.add_argument( 70 | "--window_width", 71 | type=int, 72 | default=WINDOW_WIDTH, 73 | help="Width of the window that will slide over the input image.", 74 | ) 75 | parser.add_argument( 76 | "--window_stride", 77 | type=int, 78 | default=WINDOW_STRIDE, 79 | help="Stride of the window that will slide over the input image.", 80 | ) 81 | parser.add_argument("--limit_output_length", action="store_true", default=False) 82 | return parser 83 | -------------------------------------------------------------------------------- 
/lab3/text_recognizer/models/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | FC1_DIM = 1024 10 | FC2_DIM = 128 11 | 12 | 13 | class MLP(nn.Module): 14 | """Simple MLP suitable for recognizing single characters.""" 15 | 16 | def __init__( 17 | self, 18 | data_config: Dict[str, Any], 19 | args: argparse.Namespace = None, 20 | ) -> None: 21 | super().__init__() 22 | self.args = vars(args) if args is not None else {} 23 | 24 | input_dim = np.prod(data_config["input_dims"]) 25 | num_classes = len(data_config["mapping"]) 26 | 27 | fc1_dim = self.args.get("fc1", FC1_DIM) 28 | fc2_dim = self.args.get("fc2", FC2_DIM) 29 | 30 | self.dropout = nn.Dropout(0.5) 31 | self.fc1 = nn.Linear(input_dim, fc1_dim) 32 | self.fc2 = nn.Linear(fc1_dim, fc2_dim) 33 | self.fc3 = nn.Linear(fc2_dim, num_classes) 34 | 35 | def forward(self, x): 36 | x = torch.flatten(x, 1) 37 | x = self.fc1(x) 38 | x = F.relu(x) 39 | x = self.dropout(x) 40 | x = self.fc2(x) 41 | x = F.relu(x) 42 | x = self.dropout(x) 43 | x = self.fc3(x) 44 | return x 45 | 46 | @staticmethod 47 | def add_to_argparse(parser): 48 | parser.add_argument("--fc1", type=int, default=1024) 49 | parser.add_argument("--fc2", type=int, default=128) 50 | return parser 51 | -------------------------------------------------------------------------------- /lab3/text_recognizer/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions for text_recognizer module.""" 2 | from io import BytesIO 3 | from pathlib import Path 4 | from typing import Union 5 | from urllib.request import urlretrieve 6 | import base64 7 | import hashlib 8 | 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import numpy as np 12 | import smart_open 13 | 14 | 15 | def to_categorical(y, num_classes): 16 | """1-hot encode a tensor.""" 17 | return np.eye(num_classes, dtype="uint8")[y] 18 | 19 | 20 | def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image: 21 | with smart_open.open(image_uri, "rb") as image_file: 22 | return read_image_pil_file(image_file, grayscale) 23 | 24 | 25 | def read_image_pil_file(image_file, grayscale=False) -> Image: 26 | with Image.open(image_file) as image: 27 | if grayscale: 28 | image = image.convert(mode="L") 29 | else: 30 | image = image.convert(mode=image.mode) 31 | return image 32 | 33 | 34 | 35 | 36 | def compute_sha256(filename: Union[Path, str]): 37 | """Return SHA256 checksum of a file.""" 38 | with open(filename, "rb") as f: 39 | return hashlib.sha256(f.read()).hexdigest() 40 | 41 | 42 | class TqdmUpTo(tqdm): 43 | """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" 44 | 45 | def update_to(self, blocks=1, bsize=1, tsize=None): 46 | """ 47 | Parameters 48 | ---------- 49 | blocks: int, optional 50 | Number of blocks transferred so far [default: 1]. 51 | bsize: int, optional 52 | Size of each block (in tqdm units) [default: 1]. 53 | tsize: int, optional 54 | Total size (in tqdm units). If [default: None] remains unchanged. 
55 | """ 56 | if tsize is not None: 57 | self.total = tsize # pylint: disable=attribute-defined-outside-init 58 | self.update(blocks * bsize - self.n) # will also set self.n = b * bsize 59 | 60 | 61 | def download_url(url, filename): 62 | """Download a file from url to filename, with a progress bar.""" 63 | with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: 64 | urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec 65 | -------------------------------------------------------------------------------- /lab3/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab3/training/__init__.py -------------------------------------------------------------------------------- /lab4/readme.md: -------------------------------------------------------------------------------- 1 | # Lab 4: Recognize synthetic sequences with Transformers 2 | 3 | Our goals are to introduce `LineCNNTransformer` and `TransformerLitModel`. 4 | 5 | ## LineCNNTransformer 6 | 7 | In Lab 3, we trained a `LineCNN` + LSTM model with CTC loss. 8 | 9 | In this lab, we will use the same `LineCNN` architecture as an "encoder" of the image, and then send it through Transformer decoder layers. 10 | 11 | The [PyTorch docs](https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html 12 | ) for Transformer are not very good, and you might find our [simple Colab notebook](https://colab.research.google.com/drive/1swXWW5sOLW8zSZBaQBYcGQkQ_Bje_bmI) helpful. 13 | 14 | In the `LineCNNTransformer` class, pay attention to the `predict` method. 15 | 16 | ## TransformerLitModel 17 | 18 | Nothing too fancy here. 19 | We're now back to using cross entropy loss, but we're keeping the character error rate metrics from `CTCLitModel`. 20 | 21 | ## Training 22 | 23 | I find that more epochs are necessary with the Transformer than with our LSTM+CTC model. 24 | ~30 epochs gives the same performance as we were able to obtain before. 25 | 26 | I also changed the window width to 20 from 28, and window stride to 12, just because. 27 | 28 | ``` 29 | python training/run_experiment.py --max_epochs=40 --gpus=1 --num_workers=16 --data_class=EMNISTLines --min_overlap=0 --max_overlap=0.33 --model_class=LineCNNTransformer --window_width=20 --window_stride=12 --loss=transformer 30 | 31 | DATALOADER:0 TEST RESULTS 32 | {'test_acc': tensor(0.9022, device='cuda:0'), 33 | 'test_cer': tensor(0.1749, device='cuda:0')} 34 | ``` 35 | 36 | ## Homework 37 | 38 | Standard stuff: try training with some different hyperparameters, explain what you tried. 39 | 40 | There is also an opportunity to speed up the `predict` method that you could try. 
41 | -------------------------------------------------------------------------------- /lab4/text_recognizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab4/text_recognizer/__init__.py -------------------------------------------------------------------------------- /lab4/text_recognizer/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import BaseDataset 2 | from .base_data_module import BaseDataModule 3 | from .mnist import MNIST 4 | 5 | # Hide lines below until Lab 2 6 | from .emnist import EMNIST 7 | from .emnist_lines import EMNISTLines 8 | 9 | # Hide lines above until Lab 2 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /lab4/text_recognizer/data/emnist_essentials.json: -------------------------------------------------------------------------------- 1 | {"characters": ["", "", "", "
<P>
", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} -------------------------------------------------------------------------------- /lab4/text_recognizer/data/mnist.py: -------------------------------------------------------------------------------- 1 | """MNIST DataModule""" 2 | import argparse 3 | 4 | from torch.utils.data import random_split 5 | from torchvision.datasets import MNIST as TorchMNIST 6 | from torchvision import transforms 7 | 8 | from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info 9 | 10 | DOWNLOADED_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" 11 | 12 | # NOTE: temp fix until https://github.com/pytorch/vision/issues/1938 is resolved 13 | from six.moves import urllib # pylint: disable=wrong-import-position, wrong-import-order 14 | 15 | opener = urllib.request.build_opener() 16 | opener.addheaders = [("User-agent", "Mozilla/5.0")] 17 | urllib.request.install_opener(opener) 18 | 19 | 20 | class MNIST(BaseDataModule): 21 | """ 22 | MNIST DataModule. 23 | Learn more at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html 24 | """ 25 | 26 | def __init__(self, args: argparse.Namespace) -> None: 27 | super().__init__(args) 28 | self.data_dir = DOWNLOADED_DATA_DIRNAME 29 | self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 30 | self.dims = (1, 28, 28) # dims are returned when calling `.size()` on this object. 
31 | self.output_dims = (1,) 32 | self.mapping = list(range(10)) 33 | 34 | def prepare_data(self, *args, **kwargs) -> None: 35 | """Download train and test MNIST data from PyTorch canonical source.""" 36 | TorchMNIST(self.data_dir, train=True, download=True) 37 | TorchMNIST(self.data_dir, train=False, download=True) 38 | 39 | def setup(self, stage=None) -> None: 40 | """Split into train, val, test, and set dims.""" 41 | mnist_full = TorchMNIST(self.data_dir, train=True, transform=self.transform) 42 | self.data_train, self.data_val = random_split(mnist_full, [55000, 5000]) # type: ignore 43 | self.data_test = TorchMNIST(self.data_dir, train=False, transform=self.transform) 44 | 45 | 46 | if __name__ == "__main__": 47 | load_and_print_info(MNIST) 48 | -------------------------------------------------------------------------------- /lab4/text_recognizer/data/sentence_generator.py: -------------------------------------------------------------------------------- 1 | """SentenceGenerator class and supporting functions.""" 2 | import itertools 3 | import re 4 | import string 5 | from typing import Optional 6 | 7 | import nltk 8 | import numpy as np 9 | 10 | from text_recognizer.data.base_data_module import BaseDataModule 11 | 12 | NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk" 13 | 14 | 15 | class SentenceGenerator: 16 | """Generate text sentences using the Brown corpus.""" 17 | 18 | def __init__(self, max_length: Optional[int] = None): 19 | self.text = brown_text() 20 | self.word_start_inds = [0] + [_.start(0) + 1 for _ in re.finditer(" ", self.text)] 21 | self.max_length = max_length 22 | 23 | def generate(self, max_length: Optional[int] = None) -> str: 24 | """ 25 | Sample a string from text of the Brown corpus of length at least one word and at most max_length. 
26 | """ 27 | if max_length is None: 28 | max_length = self.max_length 29 | if max_length is None: 30 | raise ValueError("Must provide max_length to this method or when making this object.") 31 | 32 | for _ in range(10): # Try several times to generate before actually erroring 33 | try: 34 | first_ind = np.random.randint(0, len(self.word_start_inds) - 1) 35 | start_ind = self.word_start_inds[first_ind] 36 | end_ind_candidates = [] 37 | for ind in range(first_ind + 1, len(self.word_start_inds)): 38 | if self.word_start_inds[ind] - start_ind > max_length: 39 | break 40 | end_ind_candidates.append(self.word_start_inds[ind]) 41 | end_ind = np.random.choice(end_ind_candidates) 42 | sampled_text = self.text[start_ind:end_ind].strip() 43 | return sampled_text 44 | except Exception: # pylint: disable=broad-except 45 | pass 46 | raise RuntimeError("Was not able to generate a valid string") 47 | 48 | 49 | def brown_text(): 50 | """Return a single string with the Brown corpus with all punctuation stripped.""" 51 | sents = load_nltk_brown_corpus() 52 | text = " ".join(itertools.chain.from_iterable(sents)) 53 | text = text.translate({ord(c): None for c in string.punctuation}) 54 | text = re.sub(" +", " ", text) 55 | return text 56 | 57 | 58 | def load_nltk_brown_corpus(): 59 | """Load the Brown corpus using the NLTK library.""" 60 | nltk.data.path.append(NLTK_DATA_DIRNAME) 61 | try: 62 | nltk.corpus.brown.sents() 63 | except LookupError: 64 | NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) 65 | nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) 66 | return nltk.corpus.brown.sents() 67 | -------------------------------------------------------------------------------- /lab4/text_recognizer/lit_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseLitModel 2 | 3 | # Hide lines below until Lab 3 4 | from .ctc import CTCLitModel 5 | 6 | # Hide lines above until Lab 3 7 | # Hide lines below until Lab 4 8 | from .transformer import TransformerLitModel 9 | 10 | # Hide lines above until Lab 4 11 | -------------------------------------------------------------------------------- /lab4/text_recognizer/lit_models/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | import editdistance 6 | 7 | 8 | class CharacterErrorRate(pl.metrics.Metric): 9 | """Character error rate metric, computed using Levenshtein distance.""" 10 | 11 | def __init__(self, ignore_tokens: Sequence[int], *args): 12 | super().__init__(*args) 13 | self.ignore_tokens = set(ignore_tokens) 14 | self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum") # pylint: disable=not-callable 15 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") # pylint: disable=not-callable 16 | self.error: torch.Tensor 17 | self.total: torch.Tensor 18 | 19 | def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None: 20 | N = preds.shape[0] 21 | for ind in range(N): 22 | pred = [_ for _ in preds[ind].tolist() if _ not in self.ignore_tokens] 23 | target = [_ for _ in targets[ind].tolist() if _ not in self.ignore_tokens] 24 | distance = editdistance.distance(pred, target) 25 | error = distance / max(len(pred), len(target)) 26 | self.error = self.error + error 27 | self.total = self.total + N 28 | 29 | def compute(self) -> torch.Tensor: 30 | return self.error / self.total 31 | 32 | 33 | def test_character_error_rate(): 34 
| metric = CharacterErrorRate([0, 1]) 35 | X = torch.tensor( # pylint: disable=not-callable 36 | [ 37 | [0, 2, 2, 3, 3, 1], # error will be 0 38 | [0, 2, 1, 1, 1, 1], # error will be .75 39 | [0, 2, 2, 4, 4, 1], # error will be .5 40 | ] 41 | ) 42 | Y = torch.tensor( # pylint: disable=not-callable 43 | [ 44 | [0, 2, 2, 3, 3, 1], 45 | [0, 2, 2, 3, 3, 1], 46 | [0, 2, 2, 3, 3, 1], 47 | ] 48 | ) 49 | metric(X, Y) 50 | print(metric.compute()) 51 | assert metric.compute() == sum([0, 0.75, 0.5]) / 3 52 | 53 | 54 | if __name__ == "__main__": 55 | test_character_error_rate() 56 | -------------------------------------------------------------------------------- /lab4/text_recognizer/lit_models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | try: 3 | import wandb 4 | except ModuleNotFoundError: 5 | pass 6 | 7 | 8 | from .metrics import CharacterErrorRate 9 | from .base import BaseLitModel 10 | 11 | 12 | class TransformerLitModel(BaseLitModel): # pylint: disable=too-many-ancestors 13 | """ 14 | Generic PyTorch-Lightning class that must be initialized with a PyTorch module. 15 | 16 | The module must take x, y as inputs, and have a special predict() method. 17 | """ 18 | 19 | def __init__(self, model, args=None): 20 | super().__init__(model, args) 21 | 22 | self.mapping = self.model.data_config["mapping"] 23 | inverse_mapping = {val: ind for ind, val in enumerate(self.mapping)} 24 | start_index = inverse_mapping[""] 25 | end_index = inverse_mapping[""] 26 | padding_index = inverse_mapping["
<P>
"] 27 | 28 | self.loss_fn = nn.CrossEntropyLoss(ignore_index=padding_index) 29 | 30 | ignore_tokens = [start_index, end_index, padding_index] 31 | self.val_cer = CharacterErrorRate(ignore_tokens) 32 | self.test_cer = CharacterErrorRate(ignore_tokens) 33 | 34 | def forward(self, x): 35 | return self.model.predict(x) 36 | 37 | def training_step(self, batch, batch_idx): # pylint: disable=unused-argument 38 | x, y = batch 39 | logits = self.model(x, y[:, :-1]) 40 | loss = self.loss_fn(logits, y[:, 1:]) 41 | self.log("train_loss", loss) 42 | return loss 43 | 44 | def validation_step(self, batch, batch_idx): # pylint: disable=unused-argument 45 | x, y = batch 46 | logits = self.model(x, y[:, :-1]) 47 | loss = self.loss_fn(logits, y[:, 1:]) 48 | self.log("val_loss", loss, prog_bar=True) 49 | 50 | pred = self.model.predict(x) 51 | self.val_cer(pred, y) 52 | self.log("val_cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) 53 | 54 | def test_step(self, batch, batch_idx): # pylint: disable=unused-argument 55 | x, y = batch 56 | pred = self.model.predict(x) 57 | self.test_cer(pred, y) 58 | self.log("test_cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) 59 | -------------------------------------------------------------------------------- /lab4/text_recognizer/lit_models/util.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | 5 | 6 | def first_element(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor: 7 | """ 8 | Return indices of first occurence of element in x. If not found, return length of x along dim. 9 | 10 | Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9 11 | 12 | Examples 13 | -------- 14 | >>> first_element(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1]]), 3) 15 | tensor([2, 1, 3]) 16 | """ 17 | nonz = x == element 18 | ind = ((nonz.cumsum(dim) == 1) & nonz).max(dim).indices 19 | ind[ind == 0] = x.shape[dim] 20 | return ind 21 | -------------------------------------------------------------------------------- /lab4/text_recognizer/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import MLP 2 | 3 | # Hide lines below until Lab 2 4 | from .cnn import CNN 5 | 6 | # Hide lines above until Lab 2 7 | 8 | # Hide lines below until Lab 3 9 | from .line_cnn_simple import LineCNNSimple 10 | from .line_cnn import LineCNN 11 | from .line_cnn_lstm import LineCNNLSTM 12 | 13 | # Hide lines above until Lab 3 14 | 15 | # Hide lines below until Lab 4 16 | from .line_cnn_transformer import LineCNNTransformer 17 | 18 | # Hide lines above until Lab 4 19 | 20 | -------------------------------------------------------------------------------- /lab4/text_recognizer/models/line_cnn_lstm.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .line_cnn import LineCNN 7 | 8 | LSTM_DIM = 512 9 | LSTM_LAYERS = 1 10 | LSTM_DROPOUT = 0.2 11 | 12 | 13 | class LineCNNLSTM(nn.Module): 14 | """Process the line through a CNN and process the resulting sequence through LSTM layers.""" 15 | 16 | def __init__( 17 | self, 18 | data_config: Dict[str, Any], 19 | args: argparse.Namespace = None, 20 | ) -> None: 21 | super().__init__() 22 | self.data_config = data_config 23 | self.args = vars(args) if args is not None else {} 24 | 25 | num_classes = len(data_config["mapping"]) 26 
| lstm_dim = self.args.get("lstm_dim", LSTM_DIM) 27 | lstm_layers = self.args.get("lstm_layers", LSTM_LAYERS) 28 | lstm_dropout = self.args.get("lstm_dropout", LSTM_DROPOUT) 29 | 30 | self.line_cnn = LineCNN(data_config=data_config, args=args) 31 | # LineCNN outputs (B, C, S) log probs, with C == num_classes 32 | 33 | self.lstm = nn.LSTM( 34 | input_size=num_classes, 35 | hidden_size=lstm_dim, 36 | num_layers=lstm_layers, 37 | dropout=lstm_dropout, 38 | bidirectional=True, 39 | ) 40 | self.fc = nn.Linear(lstm_dim, num_classes) 41 | 42 | def forward(self, x: torch.Tensor) -> torch.Tensor: 43 | """ 44 | Parameters 45 | ---------- 46 | x 47 | (B, H, W) input image 48 | 49 | Returns 50 | ------- 51 | torch.Tensor 52 | (B, C, S) logits, where S is the length of the sequence and C is the number of classes 53 | S can be computed from W and CHAR_WIDTH 54 | C is num_classes 55 | """ 56 | x = self.line_cnn(x) # -> (B, C, S) 57 | B, _C, S = x.shape 58 | x = x.permute(2, 0, 1) # -> (S, B, C) 59 | 60 | x, _ = self.lstm(x) # -> (S, B, 2 * H) where H is lstm_dim 61 | 62 | # Sum up both directions of the LSTM: 63 | x = x.view(S, B, 2, -1).sum(dim=2) # -> (S, B, H) 64 | 65 | x = self.fc(x) # -> (S, B, C) 66 | 67 | return x.permute(1, 2, 0) # -> (B, C, S) 68 | 69 | @staticmethod 70 | def add_to_argparse(parser): 71 | LineCNN.add_to_argparse(parser) 72 | parser.add_argument("--lstm_dim", type=int, default=LSTM_DIM) 73 | parser.add_argument("--lstm_layers", type=int, default=LSTM_LAYERS) 74 | parser.add_argument("--lstm_dropout", type=float, default=LSTM_DROPOUT) 75 | return parser 76 | -------------------------------------------------------------------------------- /lab4/text_recognizer/models/line_cnn_simple.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .cnn import CNN, IMAGE_SIZE 9 | 10 | WINDOW_WIDTH = 28 11 | WINDOW_STRIDE = 28 12 | 13 | 14 | class LineCNNSimple(nn.Module): 15 | """LeNet based model that takes a line of width that is a multiple of CHAR_WIDTH.""" 16 | 17 | def __init__( 18 | self, 19 | data_config: Dict[str, Any], 20 | args: argparse.Namespace = None, 21 | ) -> None: 22 | super().__init__() 23 | self.args = vars(args) if args is not None else {} 24 | 25 | self.WW = self.args.get("window_width", WINDOW_WIDTH) 26 | self.WS = self.args.get("window_stride", WINDOW_STRIDE) 27 | self.limit_output_length = self.args.get("limit_output_length", False) 28 | 29 | self.num_classes = len(data_config["mapping"]) 30 | self.output_length = data_config["output_dims"][0] 31 | self.cnn = CNN(data_config=data_config, args=args) 32 | 33 | def forward(self, x: torch.Tensor) -> torch.Tensor: 34 | """ 35 | Parameters 36 | ---------- 37 | x 38 | (B, C, H, W) input image 39 | 40 | Returns 41 | ------- 42 | torch.Tensor 43 | (B, C, S) logits, where S is the length of the sequence and C is the number of classes 44 | S can be computed from W and CHAR_WIDTH 45 | C is self.num_classes 46 | """ 47 | B, _C, H, W = x.shape 48 | assert H == IMAGE_SIZE # Make sure we can use our CNN class 49 | 50 | # Compute number of windows 51 | S = math.floor((W - self.WW) / self.WS + 1) 52 | 53 | # NOTE: type_as properly sets device 54 | activations = torch.zeros((B, self.num_classes, S)).type_as(x) 55 | for s in range(S): 56 | start_w = self.WS * s 57 | end_w = start_w + self.WW 58 | window = x[:, :, :, start_w:end_w] # -> (B, C, H, self.WW) 59 | activations[:, :, s] = 
self.cnn(window) 60 | 61 | if self.limit_output_length: 62 | # S might not match ground truth, so let's only take enough activations as are expected 63 | activations = activations[:, :, : self.output_length] 64 | return activations 65 | 66 | @staticmethod 67 | def add_to_argparse(parser): 68 | CNN.add_to_argparse(parser) 69 | parser.add_argument( 70 | "--window_width", 71 | type=int, 72 | default=WINDOW_WIDTH, 73 | help="Width of the window that will slide over the input image.", 74 | ) 75 | parser.add_argument( 76 | "--window_stride", 77 | type=int, 78 | default=WINDOW_STRIDE, 79 | help="Stride of the window that will slide over the input image.", 80 | ) 81 | parser.add_argument("--limit_output_length", action="store_true", default=False) 82 | return parser 83 | -------------------------------------------------------------------------------- /lab4/text_recognizer/models/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | FC1_DIM = 1024 10 | FC2_DIM = 128 11 | 12 | 13 | class MLP(nn.Module): 14 | """Simple MLP suitable for recognizing single characters.""" 15 | 16 | def __init__( 17 | self, 18 | data_config: Dict[str, Any], 19 | args: argparse.Namespace = None, 20 | ) -> None: 21 | super().__init__() 22 | self.args = vars(args) if args is not None else {} 23 | 24 | input_dim = np.prod(data_config["input_dims"]) 25 | num_classes = len(data_config["mapping"]) 26 | 27 | fc1_dim = self.args.get("fc1", FC1_DIM) 28 | fc2_dim = self.args.get("fc2", FC2_DIM) 29 | 30 | self.dropout = nn.Dropout(0.5) 31 | self.fc1 = nn.Linear(input_dim, fc1_dim) 32 | self.fc2 = nn.Linear(fc1_dim, fc2_dim) 33 | self.fc3 = nn.Linear(fc2_dim, num_classes) 34 | 35 | def forward(self, x): 36 | x = torch.flatten(x, 1) 37 | x = self.fc1(x) 38 | x = F.relu(x) 39 | x = self.dropout(x) 40 | x = self.fc2(x) 41 | x = F.relu(x) 42 | x = self.dropout(x) 43 | x = self.fc3(x) 44 | return x 45 | 46 | @staticmethod 47 | def add_to_argparse(parser): 48 | parser.add_argument("--fc1", type=int, default=1024) 49 | parser.add_argument("--fc2", type=int, default=128) 50 | return parser 51 | -------------------------------------------------------------------------------- /lab4/text_recognizer/models/transformer_util.py: -------------------------------------------------------------------------------- 1 | """Position Encoding and other utilities for Tranformers""" 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | 7 | 8 | 9 | 10 | class PositionalEncoding(torch.nn.Module): 11 | """Classic Attention-is-all-you-need positional encoding.""" 12 | 13 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None: 14 | super().__init__() 15 | self.dropout = torch.nn.Dropout(p=dropout) 16 | pe = self.make_pe(d_model=d_model, max_len=max_len) # (max_len, 1, d_model) 17 | self.register_buffer("pe", pe) 18 | 19 | @staticmethod 20 | def make_pe(d_model: int, max_len: int) -> torch.Tensor: 21 | pe = torch.zeros(max_len, d_model) 22 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 23 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 24 | pe[:, 0::2] = torch.sin(position * div_term) 25 | pe[:, 1::2] = torch.cos(position * div_term) 26 | pe = pe.unsqueeze(1) 27 | return pe 28 | 29 | def forward(self, x: torch.Tensor) -> 
torch.Tensor: 30 | # x.shape = (S, B, d_model) 31 | assert x.shape[2] == self.pe.shape[2] # type: ignore 32 | x = x + self.pe[: x.size(0)] # type: ignore 33 | return self.dropout(x) 34 | 35 | 36 | def generate_square_subsequent_mask(size: int) -> torch.Tensor: 37 | """Generate a triangular (size, size) mask.""" 38 | mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1) 39 | mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0)) 40 | return mask 41 | -------------------------------------------------------------------------------- /lab4/text_recognizer/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions for text_recognizer module.""" 2 | from io import BytesIO 3 | from pathlib import Path 4 | from typing import Union 5 | from urllib.request import urlretrieve 6 | import base64 7 | import hashlib 8 | 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import numpy as np 12 | import smart_open 13 | 14 | 15 | def to_categorical(y, num_classes): 16 | """1-hot encode a tensor.""" 17 | return np.eye(num_classes, dtype="uint8")[y] 18 | 19 | 20 | def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image: 21 | with smart_open.open(image_uri, "rb") as image_file: 22 | return read_image_pil_file(image_file, grayscale) 23 | 24 | 25 | def read_image_pil_file(image_file, grayscale=False) -> Image: 26 | with Image.open(image_file) as image: 27 | if grayscale: 28 | image = image.convert(mode="L") 29 | else: 30 | image = image.convert(mode=image.mode) 31 | return image 32 | 33 | 34 | 35 | 36 | def compute_sha256(filename: Union[Path, str]): 37 | """Return SHA256 checksum of a file.""" 38 | with open(filename, "rb") as f: 39 | return hashlib.sha256(f.read()).hexdigest() 40 | 41 | 42 | class TqdmUpTo(tqdm): 43 | """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" 44 | 45 | def update_to(self, blocks=1, bsize=1, tsize=None): 46 | """ 47 | Parameters 48 | ---------- 49 | blocks: int, optional 50 | Number of blocks transferred so far [default: 1]. 51 | bsize: int, optional 52 | Size of each block (in tqdm units) [default: 1]. 53 | tsize: int, optional 54 | Total size (in tqdm units). If [default: None] remains unchanged. 
55 | """ 56 | if tsize is not None: 57 | self.total = tsize # pylint: disable=attribute-defined-outside-init 58 | self.update(blocks * bsize - self.n) # will also set self.n = b * bsize 59 | 60 | 61 | def download_url(url, filename): 62 | """Download a file from url to filename, with a progress bar.""" 63 | with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: 64 | urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec 65 | -------------------------------------------------------------------------------- /lab4/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab4/training/__init__.py -------------------------------------------------------------------------------- /lab5/text_recognizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab5/text_recognizer/__init__.py -------------------------------------------------------------------------------- /lab5/text_recognizer/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import BaseDataset 2 | from .base_data_module import BaseDataModule 3 | from .mnist import MNIST 4 | 5 | # Hide lines below until Lab 2 6 | from .emnist import EMNIST 7 | from .emnist_lines import EMNISTLines 8 | 9 | # Hide lines above until Lab 2 10 | 11 | # Hide lines below until Lab 5 12 | from .emnist_lines2 import EMNISTLines2 13 | from .iam_lines import IAMLines 14 | 15 | # Hide lines above until Lab 5 16 | 17 | 18 | -------------------------------------------------------------------------------- /lab5/text_recognizer/data/emnist_essentials.json: -------------------------------------------------------------------------------- 1 | {"characters": ["", "", "", "
<P>
", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} -------------------------------------------------------------------------------- /lab5/text_recognizer/data/mnist.py: -------------------------------------------------------------------------------- 1 | """MNIST DataModule""" 2 | import argparse 3 | 4 | from torch.utils.data import random_split 5 | from torchvision.datasets import MNIST as TorchMNIST 6 | from torchvision import transforms 7 | 8 | from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info 9 | 10 | DOWNLOADED_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" 11 | 12 | # NOTE: temp fix until https://github.com/pytorch/vision/issues/1938 is resolved 13 | from six.moves import urllib # pylint: disable=wrong-import-position, wrong-import-order 14 | 15 | opener = urllib.request.build_opener() 16 | opener.addheaders = [("User-agent", "Mozilla/5.0")] 17 | urllib.request.install_opener(opener) 18 | 19 | 20 | class MNIST(BaseDataModule): 21 | """ 22 | MNIST DataModule. 23 | Learn more at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html 24 | """ 25 | 26 | def __init__(self, args: argparse.Namespace) -> None: 27 | super().__init__(args) 28 | self.data_dir = DOWNLOADED_DATA_DIRNAME 29 | self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 30 | self.dims = (1, 28, 28) # dims are returned when calling `.size()` on this object. 
31 | self.output_dims = (1,) 32 | self.mapping = list(range(10)) 33 | 34 | def prepare_data(self, *args, **kwargs) -> None: 35 | """Download train and test MNIST data from PyTorch canonical source.""" 36 | TorchMNIST(self.data_dir, train=True, download=True) 37 | TorchMNIST(self.data_dir, train=False, download=True) 38 | 39 | def setup(self, stage=None) -> None: 40 | """Split into train, val, test, and set dims.""" 41 | mnist_full = TorchMNIST(self.data_dir, train=True, transform=self.transform) 42 | self.data_train, self.data_val = random_split(mnist_full, [55000, 5000]) # type: ignore 43 | self.data_test = TorchMNIST(self.data_dir, train=False, transform=self.transform) 44 | 45 | 46 | if __name__ == "__main__": 47 | load_and_print_info(MNIST) 48 | -------------------------------------------------------------------------------- /lab5/text_recognizer/data/sentence_generator.py: -------------------------------------------------------------------------------- 1 | """SentenceGenerator class and supporting functions.""" 2 | import itertools 3 | import re 4 | import string 5 | from typing import Optional 6 | 7 | import nltk 8 | import numpy as np 9 | 10 | from text_recognizer.data.base_data_module import BaseDataModule 11 | 12 | NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk" 13 | 14 | 15 | class SentenceGenerator: 16 | """Generate text sentences using the Brown corpus.""" 17 | 18 | def __init__(self, max_length: Optional[int] = None): 19 | self.text = brown_text() 20 | self.word_start_inds = [0] + [_.start(0) + 1 for _ in re.finditer(" ", self.text)] 21 | self.max_length = max_length 22 | 23 | def generate(self, max_length: Optional[int] = None) -> str: 24 | """ 25 | Sample a string from text of the Brown corpus of length at least one word and at most max_length. 
26 | """ 27 | if max_length is None: 28 | max_length = self.max_length 29 | if max_length is None: 30 | raise ValueError("Must provide max_length to this method or when making this object.") 31 | 32 | for _ in range(10): # Try several times to generate before actually erroring 33 | try: 34 | first_ind = np.random.randint(0, len(self.word_start_inds) - 1) 35 | start_ind = self.word_start_inds[first_ind] 36 | end_ind_candidates = [] 37 | for ind in range(first_ind + 1, len(self.word_start_inds)): 38 | if self.word_start_inds[ind] - start_ind > max_length: 39 | break 40 | end_ind_candidates.append(self.word_start_inds[ind]) 41 | end_ind = np.random.choice(end_ind_candidates) 42 | sampled_text = self.text[start_ind:end_ind].strip() 43 | return sampled_text 44 | except Exception: # pylint: disable=broad-except 45 | pass 46 | raise RuntimeError("Was not able to generate a valid string") 47 | 48 | 49 | def brown_text(): 50 | """Return a single string with the Brown corpus with all punctuation stripped.""" 51 | sents = load_nltk_brown_corpus() 52 | text = " ".join(itertools.chain.from_iterable(sents)) 53 | text = text.translate({ord(c): None for c in string.punctuation}) 54 | text = re.sub(" +", " ", text) 55 | return text 56 | 57 | 58 | def load_nltk_brown_corpus(): 59 | """Load the Brown corpus using the NLTK library.""" 60 | nltk.data.path.append(NLTK_DATA_DIRNAME) 61 | try: 62 | nltk.corpus.brown.sents() 63 | except LookupError: 64 | NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) 65 | nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) 66 | return nltk.corpus.brown.sents() 67 | -------------------------------------------------------------------------------- /lab5/text_recognizer/lit_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseLitModel 2 | 3 | # Hide lines below until Lab 3 4 | from .ctc import CTCLitModel 5 | 6 | # Hide lines above until Lab 3 7 | # Hide lines below until Lab 4 8 | from .transformer import TransformerLitModel 9 | 10 | # Hide lines above until Lab 4 11 | -------------------------------------------------------------------------------- /lab5/text_recognizer/lit_models/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | import editdistance 6 | 7 | 8 | class CharacterErrorRate(pl.metrics.Metric): 9 | """Character error rate metric, computed using Levenshtein distance.""" 10 | 11 | def __init__(self, ignore_tokens: Sequence[int], *args): 12 | super().__init__(*args) 13 | self.ignore_tokens = set(ignore_tokens) 14 | self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum") # pylint: disable=not-callable 15 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") # pylint: disable=not-callable 16 | self.error: torch.Tensor 17 | self.total: torch.Tensor 18 | 19 | def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None: 20 | N = preds.shape[0] 21 | for ind in range(N): 22 | pred = [_ for _ in preds[ind].tolist() if _ not in self.ignore_tokens] 23 | target = [_ for _ in targets[ind].tolist() if _ not in self.ignore_tokens] 24 | distance = editdistance.distance(pred, target) 25 | error = distance / max(len(pred), len(target)) 26 | self.error = self.error + error 27 | self.total = self.total + N 28 | 29 | def compute(self) -> torch.Tensor: 30 | return self.error / self.total 31 | 32 | 33 | def test_character_error_rate(): 34 
| metric = CharacterErrorRate([0, 1]) 35 | X = torch.tensor( # pylint: disable=not-callable 36 | [ 37 | [0, 2, 2, 3, 3, 1], # error will be 0 38 | [0, 2, 1, 1, 1, 1], # error will be .75 39 | [0, 2, 2, 4, 4, 1], # error will be .5 40 | ] 41 | ) 42 | Y = torch.tensor( # pylint: disable=not-callable 43 | [ 44 | [0, 2, 2, 3, 3, 1], 45 | [0, 2, 2, 3, 3, 1], 46 | [0, 2, 2, 3, 3, 1], 47 | ] 48 | ) 49 | metric(X, Y) 50 | print(metric.compute()) 51 | assert metric.compute() == sum([0, 0.75, 0.5]) / 3 52 | 53 | 54 | if __name__ == "__main__": 55 | test_character_error_rate() 56 | -------------------------------------------------------------------------------- /lab5/text_recognizer/lit_models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | try: 3 | import wandb 4 | except ModuleNotFoundError: 5 | pass 6 | 7 | 8 | from .metrics import CharacterErrorRate 9 | from .base import BaseLitModel 10 | 11 | 12 | class TransformerLitModel(BaseLitModel): # pylint: disable=too-many-ancestors 13 | """ 14 | Generic PyTorch-Lightning class that must be initialized with a PyTorch module. 15 | 16 | The module must take x, y as inputs, and have a special predict() method. 17 | """ 18 | 19 | def __init__(self, model, args=None): 20 | super().__init__(model, args) 21 | 22 | self.mapping = self.model.data_config["mapping"] 23 | inverse_mapping = {val: ind for ind, val in enumerate(self.mapping)} 24 | start_index = inverse_mapping[""] 25 | end_index = inverse_mapping[""] 26 | padding_index = inverse_mapping["
<P>
"] 27 | 28 | self.loss_fn = nn.CrossEntropyLoss(ignore_index=padding_index) 29 | 30 | ignore_tokens = [start_index, end_index, padding_index] 31 | self.val_cer = CharacterErrorRate(ignore_tokens) 32 | self.test_cer = CharacterErrorRate(ignore_tokens) 33 | 34 | def forward(self, x): 35 | return self.model.predict(x) 36 | 37 | def training_step(self, batch, batch_idx): # pylint: disable=unused-argument 38 | x, y = batch 39 | logits = self.model(x, y[:, :-1]) 40 | loss = self.loss_fn(logits, y[:, 1:]) 41 | self.log("train_loss", loss) 42 | return loss 43 | 44 | def validation_step(self, batch, batch_idx): # pylint: disable=unused-argument 45 | x, y = batch 46 | logits = self.model(x, y[:, :-1]) 47 | loss = self.loss_fn(logits, y[:, 1:]) 48 | self.log("val_loss", loss, prog_bar=True) 49 | 50 | pred = self.model.predict(x) 51 | # Hide lines below until Lab 5 52 | pred_str = "".join(self.mapping[_] for _ in pred[0].tolist() if _ != 3) 53 | try: 54 | self.logger.experiment.log({"val_pred_examples": [wandb.Image(x[0], caption=pred_str)]}) 55 | except AttributeError: 56 | pass 57 | # Hide lines above until Lab 5 58 | self.val_cer(pred, y) 59 | self.log("val_cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) 60 | 61 | def test_step(self, batch, batch_idx): # pylint: disable=unused-argument 62 | x, y = batch 63 | pred = self.model.predict(x) 64 | # Hide lines below until Lab 5 65 | pred_str = "".join(self.mapping[_] for _ in pred[0].tolist() if _ != 3) 66 | try: 67 | self.logger.experiment.log({"test_pred_examples": [wandb.Image(x[0], caption=pred_str)]}) 68 | except AttributeError: 69 | pass 70 | # Hide lines above until Lab 5 71 | self.test_cer(pred, y) 72 | self.log("test_cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) 73 | -------------------------------------------------------------------------------- /lab5/text_recognizer/lit_models/util.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | 5 | 6 | def first_element(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor: 7 | """ 8 | Return indices of first occurence of element in x. If not found, return length of x along dim. 
9 | 10 | Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9 11 | 12 | Examples 13 | -------- 14 | >>> first_element(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1]]), 3) 15 | tensor([2, 1, 3]) 16 | """ 17 | nonz = x == element 18 | ind = ((nonz.cumsum(dim) == 1) & nonz).max(dim).indices 19 | ind[ind == 0] = x.shape[dim] 20 | return ind 21 | -------------------------------------------------------------------------------- /lab5/text_recognizer/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import MLP 2 | 3 | # Hide lines below until Lab 2 4 | from .cnn import CNN 5 | 6 | # Hide lines above until Lab 2 7 | 8 | # Hide lines below until Lab 3 9 | from .line_cnn_simple import LineCNNSimple 10 | from .line_cnn import LineCNN 11 | from .line_cnn_lstm import LineCNNLSTM 12 | 13 | # Hide lines above until Lab 3 14 | 15 | # Hide lines below until Lab 4 16 | from .line_cnn_transformer import LineCNNTransformer 17 | 18 | # Hide lines above until Lab 4 19 | 20 | -------------------------------------------------------------------------------- /lab5/text_recognizer/models/line_cnn_lstm.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .line_cnn import LineCNN 7 | 8 | LSTM_DIM = 512 9 | LSTM_LAYERS = 1 10 | LSTM_DROPOUT = 0.2 11 | 12 | 13 | class LineCNNLSTM(nn.Module): 14 | """Process the line through a CNN and process the resulting sequence through LSTM layers.""" 15 | 16 | def __init__( 17 | self, 18 | data_config: Dict[str, Any], 19 | args: argparse.Namespace = None, 20 | ) -> None: 21 | super().__init__() 22 | self.data_config = data_config 23 | self.args = vars(args) if args is not None else {} 24 | 25 | num_classes = len(data_config["mapping"]) 26 | lstm_dim = self.args.get("lstm_dim", LSTM_DIM) 27 | lstm_layers = self.args.get("lstm_layers", LSTM_LAYERS) 28 | lstm_dropout = self.args.get("lstm_dropout", LSTM_DROPOUT) 29 | 30 | self.line_cnn = LineCNN(data_config=data_config, args=args) 31 | # LineCNN outputs (B, C, S) log probs, with C == num_classes 32 | 33 | self.lstm = nn.LSTM( 34 | input_size=num_classes, 35 | hidden_size=lstm_dim, 36 | num_layers=lstm_layers, 37 | dropout=lstm_dropout, 38 | bidirectional=True, 39 | ) 40 | self.fc = nn.Linear(lstm_dim, num_classes) 41 | 42 | def forward(self, x: torch.Tensor) -> torch.Tensor: 43 | """ 44 | Parameters 45 | ---------- 46 | x 47 | (B, H, W) input image 48 | 49 | Returns 50 | ------- 51 | torch.Tensor 52 | (B, C, S) logits, where S is the length of the sequence and C is the number of classes 53 | S can be computed from W and CHAR_WIDTH 54 | C is num_classes 55 | """ 56 | x = self.line_cnn(x) # -> (B, C, S) 57 | B, _C, S = x.shape 58 | x = x.permute(2, 0, 1) # -> (S, B, C) 59 | 60 | x, _ = self.lstm(x) # -> (S, B, 2 * H) where H is lstm_dim 61 | 62 | # Sum up both directions of the LSTM: 63 | x = x.view(S, B, 2, -1).sum(dim=2) # -> (S, B, H) 64 | 65 | x = self.fc(x) # -> (S, B, C) 66 | 67 | return x.permute(1, 2, 0) # -> (B, C, S) 68 | 69 | @staticmethod 70 | def add_to_argparse(parser): 71 | LineCNN.add_to_argparse(parser) 72 | parser.add_argument("--lstm_dim", type=int, default=LSTM_DIM) 73 | parser.add_argument("--lstm_layers", type=int, default=LSTM_LAYERS) 74 | parser.add_argument("--lstm_dropout", type=float, default=LSTM_DROPOUT) 75 | return parser 76 | 
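A quick aside on the direction-summing trick above (`x.view(S, B, 2, -1).sum(dim=2)`): PyTorch's bidirectional LSTM returns the forward and backward hidden states concatenated along the last dimension, so the `view` separates the two directions and the `sum` merges them back down to `lstm_dim`. A self-contained shape check, with 83 standing in for the EMNIST mapping size:

```python
import torch
import torch.nn as nn

S, B, H, C = 7, 4, 512, 83  # sequence length, batch, lstm_dim, num_classes
lstm = nn.LSTM(input_size=C, hidden_size=H, bidirectional=True)

x = torch.randn(S, B, C)
out, _ = lstm(x)                           # (S, B, 2 * H): [forward; backward] concatenated
summed = out.view(S, B, 2, -1).sum(dim=2)  # (S, B, H)

assert summed.shape == (S, B, H)
assert torch.allclose(summed, out[..., :H] + out[..., H:])  # same as explicit split-and-add
```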
-------------------------------------------------------------------------------- /lab5/text_recognizer/models/line_cnn_simple.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .cnn import CNN, IMAGE_SIZE 9 | 10 | WINDOW_WIDTH = 28 11 | WINDOW_STRIDE = 28 12 | 13 | 14 | class LineCNNSimple(nn.Module): 15 | """LeNet based model that takes a line of width that is a multiple of CHAR_WIDTH.""" 16 | 17 | def __init__( 18 | self, 19 | data_config: Dict[str, Any], 20 | args: argparse.Namespace = None, 21 | ) -> None: 22 | super().__init__() 23 | self.args = vars(args) if args is not None else {} 24 | 25 | self.WW = self.args.get("window_width", WINDOW_WIDTH) 26 | self.WS = self.args.get("window_stride", WINDOW_STRIDE) 27 | self.limit_output_length = self.args.get("limit_output_length", False) 28 | 29 | self.num_classes = len(data_config["mapping"]) 30 | self.output_length = data_config["output_dims"][0] 31 | self.cnn = CNN(data_config=data_config, args=args) 32 | 33 | def forward(self, x: torch.Tensor) -> torch.Tensor: 34 | """ 35 | Parameters 36 | ---------- 37 | x 38 | (B, C, H, W) input image 39 | 40 | Returns 41 | ------- 42 | torch.Tensor 43 | (B, C, S) logits, where S is the length of the sequence and C is the number of classes 44 | S can be computed from W and CHAR_WIDTH 45 | C is self.num_classes 46 | """ 47 | B, _C, H, W = x.shape 48 | assert H == IMAGE_SIZE # Make sure we can use our CNN class 49 | 50 | # Compute number of windows 51 | S = math.floor((W - self.WW) / self.WS + 1) 52 | 53 | # NOTE: type_as properly sets device 54 | activations = torch.zeros((B, self.num_classes, S)).type_as(x) 55 | for s in range(S): 56 | start_w = self.WS * s 57 | end_w = start_w + self.WW 58 | window = x[:, :, :, start_w:end_w] # -> (B, C, H, self.WW) 59 | activations[:, :, s] = self.cnn(window) 60 | 61 | if self.limit_output_length: 62 | # S might not match ground truth, so let's only take enough activations as are expected 63 | activations = activations[:, :, : self.output_length] 64 | return activations 65 | 66 | @staticmethod 67 | def add_to_argparse(parser): 68 | CNN.add_to_argparse(parser) 69 | parser.add_argument( 70 | "--window_width", 71 | type=int, 72 | default=WINDOW_WIDTH, 73 | help="Width of the window that will slide over the input image.", 74 | ) 75 | parser.add_argument( 76 | "--window_stride", 77 | type=int, 78 | default=WINDOW_STRIDE, 79 | help="Stride of the window that will slide over the input image.", 80 | ) 81 | parser.add_argument("--limit_output_length", action="store_true", default=False) 82 | return parser 83 | -------------------------------------------------------------------------------- /lab5/text_recognizer/models/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | FC1_DIM = 1024 10 | FC2_DIM = 128 11 | 12 | 13 | class MLP(nn.Module): 14 | """Simple MLP suitable for recognizing single characters.""" 15 | 16 | def __init__( 17 | self, 18 | data_config: Dict[str, Any], 19 | args: argparse.Namespace = None, 20 | ) -> None: 21 | super().__init__() 22 | self.args = vars(args) if args is not None else {} 23 | 24 | input_dim = np.prod(data_config["input_dims"]) 25 | num_classes = len(data_config["mapping"]) 26 | 27 | 
fc1_dim = self.args.get("fc1", FC1_DIM) 28 | fc2_dim = self.args.get("fc2", FC2_DIM) 29 | 30 | self.dropout = nn.Dropout(0.5) 31 | self.fc1 = nn.Linear(input_dim, fc1_dim) 32 | self.fc2 = nn.Linear(fc1_dim, fc2_dim) 33 | self.fc3 = nn.Linear(fc2_dim, num_classes) 34 | 35 | def forward(self, x): 36 | x = torch.flatten(x, 1) 37 | x = self.fc1(x) 38 | x = F.relu(x) 39 | x = self.dropout(x) 40 | x = self.fc2(x) 41 | x = F.relu(x) 42 | x = self.dropout(x) 43 | x = self.fc3(x) 44 | return x 45 | 46 | @staticmethod 47 | def add_to_argparse(parser): 48 | parser.add_argument("--fc1", type=int, default=1024) 49 | parser.add_argument("--fc2", type=int, default=128) 50 | return parser 51 | -------------------------------------------------------------------------------- /lab5/text_recognizer/models/transformer_util.py: -------------------------------------------------------------------------------- 1 | """Position Encoding and other utilities for Tranformers""" 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | 7 | 8 | 9 | 10 | class PositionalEncoding(torch.nn.Module): 11 | """Classic Attention-is-all-you-need positional encoding.""" 12 | 13 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None: 14 | super().__init__() 15 | self.dropout = torch.nn.Dropout(p=dropout) 16 | pe = self.make_pe(d_model=d_model, max_len=max_len) # (max_len, 1, d_model) 17 | self.register_buffer("pe", pe) 18 | 19 | @staticmethod 20 | def make_pe(d_model: int, max_len: int) -> torch.Tensor: 21 | pe = torch.zeros(max_len, d_model) 22 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 23 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 24 | pe[:, 0::2] = torch.sin(position * div_term) 25 | pe[:, 1::2] = torch.cos(position * div_term) 26 | pe = pe.unsqueeze(1) 27 | return pe 28 | 29 | def forward(self, x: torch.Tensor) -> torch.Tensor: 30 | # x.shape = (S, B, d_model) 31 | assert x.shape[2] == self.pe.shape[2] # type: ignore 32 | x = x + self.pe[: x.size(0)] # type: ignore 33 | return self.dropout(x) 34 | 35 | 36 | def generate_square_subsequent_mask(size: int) -> torch.Tensor: 37 | """Generate a triangular (size, size) mask.""" 38 | mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1) 39 | mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0)) 40 | return mask 41 | -------------------------------------------------------------------------------- /lab5/text_recognizer/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions for text_recognizer module.""" 2 | from io import BytesIO 3 | from pathlib import Path 4 | from typing import Union 5 | from urllib.request import urlretrieve 6 | import base64 7 | import hashlib 8 | 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import numpy as np 12 | import smart_open 13 | 14 | 15 | def to_categorical(y, num_classes): 16 | """1-hot encode a tensor.""" 17 | return np.eye(num_classes, dtype="uint8")[y] 18 | 19 | 20 | def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image: 21 | with smart_open.open(image_uri, "rb") as image_file: 22 | return read_image_pil_file(image_file, grayscale) 23 | 24 | 25 | def read_image_pil_file(image_file, grayscale=False) -> Image: 26 | with Image.open(image_file) as image: 27 | if grayscale: 28 | image = image.convert(mode="L") 29 | else: 30 | image = image.convert(mode=image.mode) 31 | return 
image 32 | 33 | 34 | 35 | 36 | def compute_sha256(filename: Union[Path, str]): 37 | """Return SHA256 checksum of a file.""" 38 | with open(filename, "rb") as f: 39 | return hashlib.sha256(f.read()).hexdigest() 40 | 41 | 42 | class TqdmUpTo(tqdm): 43 | """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" 44 | 45 | def update_to(self, blocks=1, bsize=1, tsize=None): 46 | """ 47 | Parameters 48 | ---------- 49 | blocks: int, optional 50 | Number of blocks transferred so far [default: 1]. 51 | bsize: int, optional 52 | Size of each block (in tqdm units) [default: 1]. 53 | tsize: int, optional 54 | Total size (in tqdm units). If [default: None] remains unchanged. 55 | """ 56 | if tsize is not None: 57 | self.total = tsize # pylint: disable=attribute-defined-outside-init 58 | self.update(blocks * bsize - self.n) # will also set self.n = b * bsize 59 | 60 | 61 | def download_url(url, filename): 62 | """Download a file from url to filename, with a progress bar.""" 63 | with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: 64 | urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec 65 | -------------------------------------------------------------------------------- /lab5/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab5/training/__init__.py -------------------------------------------------------------------------------- /lab5/training/sweeps/emnist_lines2_line_cnn_transformer.yml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - ${interpreter} 4 | - ${program} 5 | - "--wandb" 6 | - ${args} 7 | program: training/run_experiment.py 8 | method: random 9 | metric: 10 | goal: minimize 11 | name: val_loss 12 | early_terminate: 13 | type: hyperband 14 | min_iter: 20 15 | parameters: 16 | conv_dim: 17 | values: [32, 64] 18 | window_width: 19 | values: [8, 16] 20 | window_stride: 21 | value: 8 22 | fc_dim: 23 | values: [512, 1024] 24 | tf_dim: 25 | values: [128, 256] 26 | tf_fc_dim: 27 | values: [256, 1024] 28 | tf_nhead: 29 | values: [4, 8] 30 | tf_layers: 31 | values: [2, 4, 6] 32 | lr: 33 | values: [0.01, 0.001, 0.0003] 34 | num_workers: 35 | value: 20 36 | gpus: 37 | value: -1 38 | data_class: 39 | value: EMNISTLines2 40 | model_class: 41 | value: LineCNNTransformer 42 | loss: 43 | value: transformer 44 | precision: 45 | value: 16 46 | -------------------------------------------------------------------------------- /lab6/annotation_interface.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab6/annotation_interface.png -------------------------------------------------------------------------------- /lab6/readme.md: -------------------------------------------------------------------------------- 1 | # Lab 6: Data Labeling 2 | 3 | In this lab we will annotate some handwriting samples that we collected using the open-source tool Label Studio. 4 | 5 | ## Collection 6 | 7 | Handwritten paragraphs were collected in the FSDL March 2019 class. 
8 | 9 | The resulting PDF was stored at https://fsdl-public-assets.s3-us-west-2.amazonaws.com/fsdl_handwriting_20190302.pdf 10 | 11 | Pages were extracted from the PDF using Ghostscript by running 12 | 13 | ```sh 14 | gs -q -dBATCH -dNOPAUSE -sDEVICE=jpeg -r300 -sOutputFile=page-%03d.jpg -f fsdl_handwriting_20190302.pdf 15 | ``` 16 | 17 | and uploaded to S3, with urls like https://fsdl-public-assets.s3-us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-001.jpg 18 | 19 | ## Annotation 20 | 21 | Video recording: https://www.loom.com/share/6c60a353a7fa44d5af75a455f8c3b0c4 22 | 23 | We will use the open source tool Label Studio (https://labelstud.io) to do our annotation. 24 | 25 | Set up Label Studio on your machine using pip or Docker (recommended): 26 | 27 | ```sh 28 | docker run --rm -p 8080:8080 -v `pwd`/my_project:/label-studio/my_project --name label-studio heartexlabs/label-studio:latest label-studio start my_project --init 29 | ``` 30 | 31 | After launching it, import a few page images that you downloaded from S3 into Label Studio. 32 | 33 | Define the labeling interface to be able to both annotate lines and their text content. This is one possibility, a combination of the “Bbox object detection” and “Transcription per region” templates: 34 | 35 | ![annotation interface](./annotation_interface.png) 36 | 37 | You can also experiment with a polygon (vs rectangle) annotation, or figure something else out! 38 | 39 | Then, annotate the few pages you imported. Make sure to annotate each page fully. 40 | 41 | When you’re done, you can export the data from the main Tasks page as a JSON file. 42 | 43 | ## Questions 44 | 45 | Think about the decision you made in annotating. 46 | 47 | Did you use axis-aligned rectangles, or rotated-as-needed rectangles, or polygon annotations? Why? 48 | 49 | Did you transcribe text exactly as you read it, including spelling errors if any, or did you correct the spelling errors? What did you do if you couldn’t quite make some characters out? Why? 
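Once you have an export, it can help to sanity-check it programmatically before building a dataset from it. The sketch below assumes the Label Studio 0.x export layout (`completions` → `result` → `value`); field names vary across Label Studio versions and labeling configs, so treat them as placeholders to adapt:

```python
import json

with open("tasks_export.json") as f:  # hypothetical path to your exported tasks
    tasks = json.load(f)

for task in tasks:
    image_url = task["data"].get("image")
    for completion in task.get("completions", []):
        for region in completion.get("result", []):
            value = region["value"]
            if region["type"] == "rectanglelabels":
                # x / y / width / height are percentages of the image size
                print(image_url, value["x"], value["y"], value["width"], value["height"])
            elif region["type"] == "textarea":
                print(image_url, "text:", value.get("text"))
```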
50 | -------------------------------------------------------------------------------- /lab7/text_recognizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab7/text_recognizer/__init__.py -------------------------------------------------------------------------------- /lab7/text_recognizer/artifacts/paragraph_text_recognizer/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "lr": 0.0001, 3 | "gpus": -1, 4 | "loss": "transformer", 5 | "wandb": true, 6 | "fc_dim": 512, 7 | "logger": true, 8 | "tf_dim": 256, 9 | "plugins": "None", 10 | "conv_dim": 32, 11 | "profiler": "None", 12 | "tf_nhead": 4, 13 | "amp_level": "O2", 14 | "benchmark": false, 15 | "max_steps": "None", 16 | "min_steps": "None", 17 | "num_nodes": 1, 18 | "optimizer": "Adam", 19 | "precision": 32, 20 | "test_only": false, 21 | "tf_fc_dim": 256, 22 | "tf_layers": 4, 23 | "tpu_cores": "_gpus_arg_default", 24 | "batch_size": 16, 25 | "data_class": "IAMOriginalAndSyntheticParagraphs", 26 | "max_epochs": "None", 27 | "min_epochs": "None", 28 | "tf_dropout": 0.4, 29 | "accelerator": "ddp", 30 | "amp_backend": "native", 31 | "model_class": "ResnetTransformer", 32 | "num_workers": 24, 33 | "augment_data": "true", 34 | "auto_lr_find": false, 35 | "fast_dev_run": false, 36 | "window_width": 16, 37 | "deterministic": false, 38 | "num_processes": 1, 39 | "window_stride": 8, 40 | "log_gpu_memory": "None", 41 | "sync_batchnorm": false, 42 | "load_checkpoint": "None", 43 | "overfit_batches": 0, 44 | "track_grad_norm": -1, 45 | "weights_summary": "top", 46 | "auto_select_gpus": false, 47 | "default_root_dir": "None", 48 | "one_cycle_max_lr": "None", 49 | "process_position": 0, 50 | "terminate_on_nan": true, 51 | "gradient_clip_val": 0, 52 | "limit_val_batches": 1, 53 | "log_every_n_steps": 50, 54 | "weights_save_path": "None", 55 | "limit_test_batches": 1, 56 | "val_check_interval": 1, 57 | "checkpoint_callback": true, 58 | "distributed_backend": "None", 59 | "enable_pl_optimizer": "None", 60 | "limit_output_length": false, 61 | "limit_train_batches": 1, 62 | "move_metrics_to_cpu": false, 63 | "replace_sampler_ddp": true, 64 | "num_sanity_val_steps": 2, 65 | "truncated_bptt_steps": "None", 66 | "auto_scale_batch_size": false, 67 | "limit_predict_batches": 1, 68 | "one_cycle_total_steps": 100, 69 | "prepare_data_per_node": true, 70 | "stochastic_weight_avg": false, 71 | "automatic_optimization": "None", 72 | "resume_from_checkpoint": "None", 73 | "accumulate_grad_batches": 4, 74 | "check_val_every_n_epoch": 10, 75 | "flush_logs_every_n_steps": 100, 76 | "multiple_trainloader_mode": "max_size_cycle", 77 | "progress_bar_refresh_rate": "None", 78 | "reload_dataloaders_every_epoch": false 79 | } -------------------------------------------------------------------------------- /lab7/text_recognizer/artifacts/paragraph_text_recognizer/model.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ef4b7ca47edaf2e7be1eeb6cccbd1c0df864f43de564ea3a668f5019751dffe6 3 | size 546255010 4 | -------------------------------------------------------------------------------- /lab7/text_recognizer/artifacts/paragraph_text_recognizer/run_command.txt: -------------------------------------------------------------------------------- 1 | python 
training/run_experiment.py --wandb --gpus=-1 --data_class=IAMOriginalAndSyntheticParagraphs --model_class=ResnetTransformer --loss=transformer --batch_size=16 --check_val_every_n_epoch=10 --terminate_on_nan=1 --num_workers=24 --accelerator=ddp --lr=0.0001 --accumulate_grad_batches=4 -------------------------------------------------------------------------------- /lab7/text_recognizer/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import BaseDataset 2 | from .base_data_module import BaseDataModule 3 | from .mnist import MNIST 4 | 5 | # Hide lines below until Lab 2 6 | from .emnist import EMNIST 7 | from .emnist_lines import EMNISTLines 8 | 9 | # Hide lines above until Lab 2 10 | 11 | # Hide lines below until Lab 5 12 | from .emnist_lines2 import EMNISTLines2 13 | from .iam_lines import IAMLines 14 | 15 | # Hide lines above until Lab 5 16 | 17 | # Hide lines below until Lab 7 18 | from .iam_paragraphs import IAMParagraphs 19 | from .iam_original_and_synthetic_paragraphs import IAMOriginalAndSyntheticParagraphs 20 | 21 | # Hide lines above until Lab 7 22 | 23 | -------------------------------------------------------------------------------- /lab7/text_recognizer/data/emnist_essentials.json: -------------------------------------------------------------------------------- 1 | {"characters": ["<B>", "<S>", "<E>", "<P>
", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} -------------------------------------------------------------------------------- /lab7/text_recognizer/data/mnist.py: -------------------------------------------------------------------------------- 1 | """MNIST DataModule""" 2 | import argparse 3 | 4 | from torch.utils.data import random_split 5 | from torchvision.datasets import MNIST as TorchMNIST 6 | from torchvision import transforms 7 | 8 | from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info 9 | 10 | DOWNLOADED_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" 11 | 12 | # NOTE: temp fix until https://github.com/pytorch/vision/issues/1938 is resolved 13 | from six.moves import urllib # pylint: disable=wrong-import-position, wrong-import-order 14 | 15 | opener = urllib.request.build_opener() 16 | opener.addheaders = [("User-agent", "Mozilla/5.0")] 17 | urllib.request.install_opener(opener) 18 | 19 | 20 | class MNIST(BaseDataModule): 21 | """ 22 | MNIST DataModule. 23 | Learn more at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html 24 | """ 25 | 26 | def __init__(self, args: argparse.Namespace) -> None: 27 | super().__init__(args) 28 | self.data_dir = DOWNLOADED_DATA_DIRNAME 29 | self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 30 | self.dims = (1, 28, 28) # dims are returned when calling `.size()` on this object. 
31 | self.output_dims = (1,) 32 | self.mapping = list(range(10)) 33 | 34 | def prepare_data(self, *args, **kwargs) -> None: 35 | """Download train and test MNIST data from PyTorch canonical source.""" 36 | TorchMNIST(self.data_dir, train=True, download=True) 37 | TorchMNIST(self.data_dir, train=False, download=True) 38 | 39 | def setup(self, stage=None) -> None: 40 | """Split into train, val, test, and set dims.""" 41 | mnist_full = TorchMNIST(self.data_dir, train=True, transform=self.transform) 42 | self.data_train, self.data_val = random_split(mnist_full, [55000, 5000]) # type: ignore 43 | self.data_test = TorchMNIST(self.data_dir, train=False, transform=self.transform) 44 | 45 | 46 | if __name__ == "__main__": 47 | load_and_print_info(MNIST) 48 | -------------------------------------------------------------------------------- /lab7/text_recognizer/data/sentence_generator.py: -------------------------------------------------------------------------------- 1 | """SentenceGenerator class and supporting functions.""" 2 | import itertools 3 | import re 4 | import string 5 | from typing import Optional 6 | 7 | import nltk 8 | import numpy as np 9 | 10 | from text_recognizer.data.base_data_module import BaseDataModule 11 | 12 | NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk" 13 | 14 | 15 | class SentenceGenerator: 16 | """Generate text sentences using the Brown corpus.""" 17 | 18 | def __init__(self, max_length: Optional[int] = None): 19 | self.text = brown_text() 20 | self.word_start_inds = [0] + [_.start(0) + 1 for _ in re.finditer(" ", self.text)] 21 | self.max_length = max_length 22 | 23 | def generate(self, max_length: Optional[int] = None) -> str: 24 | """ 25 | Sample a string from text of the Brown corpus of length at least one word and at most max_length. 
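Both the start and the end of the sample are drawn from word-start indices, so the returned string never cuts a word in half.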
26 | """ 27 | if max_length is None: 28 | max_length = self.max_length 29 | if max_length is None: 30 | raise ValueError("Must provide max_length to this method or when making this object.") 31 | 32 | for _ in range(10): # Try several times to generate before actually erroring 33 | try: 34 | first_ind = np.random.randint(0, len(self.word_start_inds) - 1) 35 | start_ind = self.word_start_inds[first_ind] 36 | end_ind_candidates = [] 37 | for ind in range(first_ind + 1, len(self.word_start_inds)): 38 | if self.word_start_inds[ind] - start_ind > max_length: 39 | break 40 | end_ind_candidates.append(self.word_start_inds[ind]) 41 | end_ind = np.random.choice(end_ind_candidates) 42 | sampled_text = self.text[start_ind:end_ind].strip() 43 | return sampled_text 44 | except Exception: # pylint: disable=broad-except 45 | pass 46 | raise RuntimeError("Was not able to generate a valid string") 47 | 48 | 49 | def brown_text(): 50 | """Return a single string with the Brown corpus with all punctuation stripped.""" 51 | sents = load_nltk_brown_corpus() 52 | text = " ".join(itertools.chain.from_iterable(sents)) 53 | text = text.translate({ord(c): None for c in string.punctuation}) 54 | text = re.sub(" +", " ", text) 55 | return text 56 | 57 | 58 | def load_nltk_brown_corpus(): 59 | """Load the Brown corpus using the NLTK library.""" 60 | nltk.data.path.append(NLTK_DATA_DIRNAME) 61 | try: 62 | nltk.corpus.brown.sents() 63 | except LookupError: 64 | NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) 65 | nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) 66 | return nltk.corpus.brown.sents() 67 | -------------------------------------------------------------------------------- /lab7/text_recognizer/lit_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseLitModel 2 | 3 | # Hide lines below until Lab 3 4 | from .ctc import CTCLitModel 5 | 6 | # Hide lines above until Lab 3 7 | # Hide lines below until Lab 4 8 | from .transformer import TransformerLitModel 9 | 10 | # Hide lines above until Lab 4 11 | -------------------------------------------------------------------------------- /lab7/text_recognizer/lit_models/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | import editdistance 6 | 7 | 8 | class CharacterErrorRate(pl.metrics.Metric): 9 | """Character error rate metric, computed using Levenshtein distance.""" 10 | 11 | def __init__(self, ignore_tokens: Sequence[int], *args): 12 | super().__init__(*args) 13 | self.ignore_tokens = set(ignore_tokens) 14 | self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum") # pylint: disable=not-callable 15 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") # pylint: disable=not-callable 16 | self.error: torch.Tensor 17 | self.total: torch.Tensor 18 | 19 | def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None: 20 | N = preds.shape[0] 21 | for ind in range(N): 22 | pred = [_ for _ in preds[ind].tolist() if _ not in self.ignore_tokens] 23 | target = [_ for _ in targets[ind].tolist() if _ not in self.ignore_tokens] 24 | distance = editdistance.distance(pred, target) 25 | error = distance / max(len(pred), len(target)) 26 | self.error = self.error + error 27 | self.total = self.total + N 28 | 29 | def compute(self) -> torch.Tensor: 30 | return self.error / self.total 31 | 32 | 33 | def test_character_error_rate(): 34 
| metric = CharacterErrorRate([0, 1]) 35 | X = torch.tensor(  # pylint: disable=not-callable 36 | [ 37 | [0, 2, 2, 3, 3, 1],  # error will be 0 38 | [0, 2, 1, 1, 1, 1],  # error will be .75 39 | [0, 2, 2, 4, 4, 1],  # error will be .5 40 | ] 41 | ) 42 | Y = torch.tensor(  # pylint: disable=not-callable 43 | [ 44 | [0, 2, 2, 3, 3, 1], 45 | [0, 2, 2, 3, 3, 1], 46 | [0, 2, 2, 3, 3, 1], 47 | ] 48 | ) 49 | metric(X, Y) 50 | print(metric.compute()) 51 | assert metric.compute() == sum([0, 0.75, 0.5]) / 3 52 | 53 | 54 | if __name__ == "__main__": 55 | test_character_error_rate() 56 | -------------------------------------------------------------------------------- /lab7/text_recognizer/lit_models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | try: 3 | import wandb 4 | except ModuleNotFoundError: 5 | pass 6 | 7 | 8 | from .metrics import CharacterErrorRate 9 | from .base import BaseLitModel 10 | 11 | 12 | class TransformerLitModel(BaseLitModel):  # pylint: disable=too-many-ancestors 13 | """ 14 | Generic PyTorch-Lightning class that must be initialized with a PyTorch module. 15 | 16 | The module must take x, y as inputs, and have a special predict() method. 17 | """ 18 | 19 | def __init__(self, model, args=None): 20 | super().__init__(model, args) 21 | 22 | self.mapping = self.model.data_config["mapping"] 23 | inverse_mapping = {val: ind for ind, val in enumerate(self.mapping)} 24 | start_index = inverse_mapping["<S>"] 25 | end_index = inverse_mapping["<E>"] 26 | padding_index = inverse_mapping["<P>"] 27 | 28 | self.loss_fn = nn.CrossEntropyLoss(ignore_index=padding_index) 29 | 30 | ignore_tokens = [start_index, end_index, padding_index] 31 | self.val_cer = CharacterErrorRate(ignore_tokens) 32 | self.test_cer = CharacterErrorRate(ignore_tokens) 33 | 34 | def forward(self, x): 35 | return self.model.predict(x) 36 | 37 | def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument 38 | x, y = batch 39 | logits = self.model(x, y[:, :-1]) 40 | loss = self.loss_fn(logits, y[:, 1:]) 41 | self.log("train_loss", loss) 42 | return loss 43 | 44 | def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument 45 | x, y = batch 46 | logits = self.model(x, y[:, :-1]) 47 | loss = self.loss_fn(logits, y[:, 1:]) 48 | self.log("val_loss", loss, prog_bar=True) 49 | 50 | pred = self.model.predict(x) 51 | # Hide lines below until Lab 5 52 | pred_str = "".join(self.mapping[_] for _ in pred[0].tolist() if _ != 3)  # 3 is the padding token <P> in the mapping 53 | try: 54 | self.logger.experiment.log({"val_pred_examples": [wandb.Image(x[0], caption=pred_str)]}) 55 | except AttributeError: 56 | pass 57 | # Hide lines above until Lab 5 58 | self.val_cer(pred, y) 59 | self.log("val_cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) 60 | 61 | def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument 62 | x, y = batch 63 | pred = self.model.predict(x) 64 | # Hide lines below until Lab 5 65 | pred_str = "".join(self.mapping[_] for _ in pred[0].tolist() if _ != 3)  # 3 is the padding token <P> in the mapping 66 | try: 67 | self.logger.experiment.log({"test_pred_examples": [wandb.Image(x[0], caption=pred_str)]}) 68 | except AttributeError: 69 | pass 70 | # Hide lines above until Lab 5 71 | self.test_cer(pred, y) 72 | self.log("test_cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) 73 | -------------------------------------------------------------------------------- /lab7/text_recognizer/lit_models/util.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | 5 | 6 | def first_element(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor: 7 | """ 8 | Return indices of first occurrence of element in x. If not found, return length of x along dim.
9 | 10 | Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9 11 | 12 | Examples 13 | -------- 14 | >>> first_element(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1]]), 3) 15 | tensor([2, 1, 3]) 16 | """ 17 | nonz = x == element 18 | ind = ((nonz.cumsum(dim) == 1) & nonz).max(dim).indices 19 | ind[ind == 0] = x.shape[dim] 20 | return ind 21 | -------------------------------------------------------------------------------- /lab7/text_recognizer/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import MLP 2 | 3 | # Hide lines below until Lab 2 4 | from .cnn import CNN 5 | 6 | # Hide lines above until Lab 2 7 | 8 | # Hide lines below until Lab 3 9 | from .line_cnn_simple import LineCNNSimple 10 | from .line_cnn import LineCNN 11 | from .line_cnn_lstm import LineCNNLSTM 12 | 13 | # Hide lines above until Lab 3 14 | 15 | # Hide lines below until Lab 4 16 | from .line_cnn_transformer import LineCNNTransformer 17 | 18 | # Hide lines above until Lab 4 19 | 20 | # Hide lines below until Lab 7 21 | from .resnet_transformer import ResnetTransformer 22 | 23 | # Hide lines above until Lab 7 24 | -------------------------------------------------------------------------------- /lab7/text_recognizer/models/line_cnn_lstm.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .line_cnn import LineCNN 7 | 8 | LSTM_DIM = 512 9 | LSTM_LAYERS = 1 10 | LSTM_DROPOUT = 0.2 11 | 12 | 13 | class LineCNNLSTM(nn.Module): 14 | """Process the line through a CNN and process the resulting sequence through LSTM layers.""" 15 | 16 | def __init__( 17 | self, 18 | data_config: Dict[str, Any], 19 | args: argparse.Namespace = None, 20 | ) -> None: 21 | super().__init__() 22 | self.data_config = data_config 23 | self.args = vars(args) if args is not None else {} 24 | 25 | num_classes = len(data_config["mapping"]) 26 | lstm_dim = self.args.get("lstm_dim", LSTM_DIM) 27 | lstm_layers = self.args.get("lstm_layers", LSTM_LAYERS) 28 | lstm_dropout = self.args.get("lstm_dropout", LSTM_DROPOUT) 29 | 30 | self.line_cnn = LineCNN(data_config=data_config, args=args) 31 | # LineCNN outputs (B, C, S) log probs, with C == num_classes 32 | 33 | self.lstm = nn.LSTM( 34 | input_size=num_classes, 35 | hidden_size=lstm_dim, 36 | num_layers=lstm_layers, 37 | dropout=lstm_dropout, 38 | bidirectional=True, 39 | ) 40 | self.fc = nn.Linear(lstm_dim, num_classes) 41 | 42 | def forward(self, x: torch.Tensor) -> torch.Tensor: 43 | """ 44 | Parameters 45 | ---------- 46 | x 47 | (B, H, W) input image 48 | 49 | Returns 50 | ------- 51 | torch.Tensor 52 | (B, C, S) logits, where S is the length of the sequence and C is the number of classes 53 | S can be computed from W and CHAR_WIDTH 54 | C is num_classes 55 | """ 56 | x = self.line_cnn(x) # -> (B, C, S) 57 | B, _C, S = x.shape 58 | x = x.permute(2, 0, 1) # -> (S, B, C) 59 | 60 | x, _ = self.lstm(x) # -> (S, B, 2 * H) where H is lstm_dim 61 | 62 | # Sum up both directions of the LSTM: 63 | x = x.view(S, B, 2, -1).sum(dim=2) # -> (S, B, H) 64 | 65 | x = self.fc(x) # -> (S, B, C) 66 | 67 | return x.permute(1, 2, 0) # -> (B, C, S) 68 | 69 | @staticmethod 70 | def add_to_argparse(parser): 71 | LineCNN.add_to_argparse(parser) 72 | parser.add_argument("--lstm_dim", type=int, default=LSTM_DIM) 73 | parser.add_argument("--lstm_layers", type=int, default=LSTM_LAYERS) 74 | 
parser.add_argument("--lstm_dropout", type=float, default=LSTM_DROPOUT) 75 | return parser 76 | -------------------------------------------------------------------------------- /lab7/text_recognizer/models/line_cnn_simple.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .cnn import CNN, IMAGE_SIZE 9 | 10 | WINDOW_WIDTH = 28 11 | WINDOW_STRIDE = 28 12 | 13 | 14 | class LineCNNSimple(nn.Module): 15 | """LeNet based model that takes a line of width that is a multiple of CHAR_WIDTH.""" 16 | 17 | def __init__( 18 | self, 19 | data_config: Dict[str, Any], 20 | args: argparse.Namespace = None, 21 | ) -> None: 22 | super().__init__() 23 | self.args = vars(args) if args is not None else {} 24 | 25 | self.WW = self.args.get("window_width", WINDOW_WIDTH) 26 | self.WS = self.args.get("window_stride", WINDOW_STRIDE) 27 | self.limit_output_length = self.args.get("limit_output_length", False) 28 | 29 | self.num_classes = len(data_config["mapping"]) 30 | self.output_length = data_config["output_dims"][0] 31 | self.cnn = CNN(data_config=data_config, args=args) 32 | 33 | def forward(self, x: torch.Tensor) -> torch.Tensor: 34 | """ 35 | Parameters 36 | ---------- 37 | x 38 | (B, C, H, W) input image 39 | 40 | Returns 41 | ------- 42 | torch.Tensor 43 | (B, C, S) logits, where S is the length of the sequence and C is the number of classes 44 | S can be computed from W and CHAR_WIDTH 45 | C is self.num_classes 46 | """ 47 | B, _C, H, W = x.shape 48 | assert H == IMAGE_SIZE # Make sure we can use our CNN class 49 | 50 | # Compute number of windows 51 | S = math.floor((W - self.WW) / self.WS + 1) 52 | 53 | # NOTE: type_as properly sets device 54 | activations = torch.zeros((B, self.num_classes, S)).type_as(x) 55 | for s in range(S): 56 | start_w = self.WS * s 57 | end_w = start_w + self.WW 58 | window = x[:, :, :, start_w:end_w] # -> (B, C, H, self.WW) 59 | activations[:, :, s] = self.cnn(window) 60 | 61 | if self.limit_output_length: 62 | # S might not match ground truth, so let's only take enough activations as are expected 63 | activations = activations[:, :, : self.output_length] 64 | return activations 65 | 66 | @staticmethod 67 | def add_to_argparse(parser): 68 | CNN.add_to_argparse(parser) 69 | parser.add_argument( 70 | "--window_width", 71 | type=int, 72 | default=WINDOW_WIDTH, 73 | help="Width of the window that will slide over the input image.", 74 | ) 75 | parser.add_argument( 76 | "--window_stride", 77 | type=int, 78 | default=WINDOW_STRIDE, 79 | help="Stride of the window that will slide over the input image.", 80 | ) 81 | parser.add_argument("--limit_output_length", action="store_true", default=False) 82 | return parser 83 | -------------------------------------------------------------------------------- /lab7/text_recognizer/models/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | FC1_DIM = 1024 10 | FC2_DIM = 128 11 | 12 | 13 | class MLP(nn.Module): 14 | """Simple MLP suitable for recognizing single characters.""" 15 | 16 | def __init__( 17 | self, 18 | data_config: Dict[str, Any], 19 | args: argparse.Namespace = None, 20 | ) -> None: 21 | super().__init__() 22 | self.args = vars(args) if args is not None else {} 23 | 24 | input_dim 
= np.prod(data_config["input_dims"]) 25 | num_classes = len(data_config["mapping"]) 26 | 27 | fc1_dim = self.args.get("fc1", FC1_DIM) 28 | fc2_dim = self.args.get("fc2", FC2_DIM) 29 | 30 | self.dropout = nn.Dropout(0.5) 31 | self.fc1 = nn.Linear(input_dim, fc1_dim) 32 | self.fc2 = nn.Linear(fc1_dim, fc2_dim) 33 | self.fc3 = nn.Linear(fc2_dim, num_classes) 34 | 35 | def forward(self, x): 36 | x = torch.flatten(x, 1) 37 | x = self.fc1(x) 38 | x = F.relu(x) 39 | x = self.dropout(x) 40 | x = self.fc2(x) 41 | x = F.relu(x) 42 | x = self.dropout(x) 43 | x = self.fc3(x) 44 | return x 45 | 46 | @staticmethod 47 | def add_to_argparse(parser): 48 | parser.add_argument("--fc1", type=int, default=1024) 49 | parser.add_argument("--fc2", type=int, default=128) 50 | return parser 51 | -------------------------------------------------------------------------------- /lab7/text_recognizer/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions for text_recognizer module.""" 2 | from io import BytesIO 3 | from pathlib import Path 4 | from typing import Union 5 | from urllib.request import urlretrieve 6 | import base64 7 | import hashlib 8 | 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import numpy as np 12 | import smart_open 13 | 14 | 15 | def to_categorical(y, num_classes): 16 | """1-hot encode a tensor.""" 17 | return np.eye(num_classes, dtype="uint8")[y] 18 | 19 | 20 | def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image: 21 | with smart_open.open(image_uri, "rb") as image_file: 22 | return read_image_pil_file(image_file, grayscale) 23 | 24 | 25 | def read_image_pil_file(image_file, grayscale=False) -> Image: 26 | with Image.open(image_file) as image: 27 | if grayscale: 28 | image = image.convert(mode="L") 29 | else: 30 | image = image.convert(mode=image.mode) 31 | return image 32 | 33 | 34 | 35 | 36 | def compute_sha256(filename: Union[Path, str]): 37 | """Return SHA256 checksum of a file.""" 38 | with open(filename, "rb") as f: 39 | return hashlib.sha256(f.read()).hexdigest() 40 | 41 | 42 | class TqdmUpTo(tqdm): 43 | """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" 44 | 45 | def update_to(self, blocks=1, bsize=1, tsize=None): 46 | """ 47 | Parameters 48 | ---------- 49 | blocks: int, optional 50 | Number of blocks transferred so far [default: 1]. 51 | bsize: int, optional 52 | Size of each block (in tqdm units) [default: 1]. 53 | tsize: int, optional 54 | Total size (in tqdm units). If [default: None] remains unchanged. 
55 | """ 56 | if tsize is not None: 57 | self.total = tsize # pylint: disable=attribute-defined-outside-init 58 | self.update(blocks * bsize - self.n) # will also set self.n = b * bsize 59 | 60 | 61 | def download_url(url, filename): 62 | """Download a file from url to filename, with a progress bar.""" 63 | with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: 64 | urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec 65 | -------------------------------------------------------------------------------- /lab7/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab7/training/__init__.py -------------------------------------------------------------------------------- /lab7/training/sweeps/emnist_lines2_line_cnn_transformer.yml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - ${interpreter} 4 | - ${program} 5 | - "--wandb" 6 | - ${args} 7 | program: training/run_experiment.py 8 | method: random 9 | metric: 10 | goal: minimize 11 | name: val_loss 12 | early_terminate: 13 | type: hyperband 14 | min_iter: 20 15 | parameters: 16 | conv_dim: 17 | values: [32, 64] 18 | window_width: 19 | values: [8, 16] 20 | window_stride: 21 | value: 8 22 | fc_dim: 23 | values: [512, 1024] 24 | tf_dim: 25 | values: [128, 256] 26 | tf_fc_dim: 27 | values: [256, 1024] 28 | tf_nhead: 29 | values: [4, 8] 30 | tf_layers: 31 | values: [2, 4, 6] 32 | lr: 33 | values: [0.01, 0.001, 0.0003] 34 | num_workers: 35 | value: 20 36 | gpus: 37 | value: -1 38 | data_class: 39 | value: EMNISTLines2 40 | model_class: 41 | value: LineCNNTransformer 42 | loss: 43 | value: transformer 44 | precision: 45 | value: 16 46 | -------------------------------------------------------------------------------- /lab8/.pylintrc: -------------------------------------------------------------------------------- 1 | [TYPECHECK] 2 | # Without this, pylint will complain that torch does not contain some methods 3 | # https://github.com/pytorch/pytorch/issues/701 4 | generated-members=numpy.*,torch.* 5 | 6 | [MESSAGES CONTROL] 7 | disable= 8 | abstract-method, 9 | arguments-differ, 10 | attribute-defined-outside-init, 11 | duplicate-code, 12 | invalid-name, 13 | fixme, 14 | missing-function-docstring, 15 | missing-module-docstring, 16 | too-few-public-methods 17 | 18 | [FORMAT] 19 | max-line-length = 120 20 | 21 | [DESIGN] 22 | max-attributes = 18 23 | max-args = 8 24 | max-locals = 16 25 | -------------------------------------------------------------------------------- /lab8/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ['py37'] 4 | -------------------------------------------------------------------------------- /lab8/setup.cfg: -------------------------------------------------------------------------------- 1 | [pycodestyle] 2 | max-line-length = 120 3 | ignore = E203,E402,E501,W503 4 | 5 | [pydocstyle] 6 | convention = numpy 7 | add-ignore = D100,D102,D103,D104,D105,D200,D205,D400 8 | 9 | [mypy] 10 | ignore_missing_imports = True 11 | 12 | [tool:pytest] 13 | addopts = --doctest-modules 14 | -------------------------------------------------------------------------------- /lab8/tasks/lint.sh: -------------------------------------------------------------------------------- 1 | 
#!/bin/bash 2 | set -uo pipefail 3 | set +e 4 | 5 | FAILURE=false 6 | 7 | echo "safety (failure is tolerated)" 8 | FILE=requirements/prod.txt 9 | if [ -f "$FILE" ]; then 10 | # We're in the main repo 11 | safety check -r requirements/prod.txt -r requirements/dev.txt 12 | else 13 | # We're in the labs repo 14 | safety check -r ../requirements/prod.txt -r ../requirements/dev.txt 15 | fi 16 | 17 | echo "pylint" 18 | pylint text_recognizer training || FAILURE=true 19 | 20 | echo "pycodestyle" 21 | pycodestyle text_recognizer training || FAILURE=true 22 | 23 | echo "pydocstyle" 24 | pydocstyle text_recognizer training || FAILURE=true 25 | 26 | echo "mypy" 27 | mypy text_recognizer training || FAILURE=true 28 | 29 | echo "bandit" 30 | bandit -ll -r {text_recognizer,training} || FAILURE=true 31 | 32 | echo "shellcheck" 33 | find . -name "*.sh" -print0 | xargs -0 shellcheck || FAILURE=true 34 | 35 | if [ "$FAILURE" = true ]; then 36 | echo "Linting failed" 37 | exit 1 38 | fi 39 | echo "Linting passed" 40 | exit 0 41 | -------------------------------------------------------------------------------- /lab8/tasks/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -uo pipefail 3 | set +e 4 | 5 | FAILURE=false 6 | 7 | ./training/tests/test_run_experiment.sh || FAILURE=true 8 | pytest -s . || FAILURE=true 9 | 10 | if [ "$FAILURE" = true ]; then 11 | echo "Tests failed" 12 | exit 1 13 | fi 14 | echo "Tests passed" 15 | exit 0 16 | -------------------------------------------------------------------------------- /lab8/text_recognizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab8/text_recognizer/__init__.py -------------------------------------------------------------------------------- /lab8/text_recognizer/artifacts/paragraph_text_recognizer/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "lr": 0.0001, 3 | "gpus": -1, 4 | "loss": "transformer", 5 | "wandb": true, 6 | "fc_dim": 512, 7 | "logger": true, 8 | "tf_dim": 256, 9 | "plugins": "None", 10 | "conv_dim": 32, 11 | "profiler": "None", 12 | "tf_nhead": 4, 13 | "amp_level": "O2", 14 | "benchmark": false, 15 | "max_steps": "None", 16 | "min_steps": "None", 17 | "num_nodes": 1, 18 | "optimizer": "Adam", 19 | "precision": 32, 20 | "test_only": false, 21 | "tf_fc_dim": 256, 22 | "tf_layers": 4, 23 | "tpu_cores": "_gpus_arg_default", 24 | "batch_size": 16, 25 | "data_class": "IAMOriginalAndSyntheticParagraphs", 26 | "max_epochs": "None", 27 | "min_epochs": "None", 28 | "tf_dropout": 0.4, 29 | "accelerator": "ddp", 30 | "amp_backend": "native", 31 | "model_class": "ResnetTransformer", 32 | "num_workers": 24, 33 | "augment_data": "true", 34 | "auto_lr_find": false, 35 | "fast_dev_run": false, 36 | "window_width": 16, 37 | "deterministic": false, 38 | "num_processes": 1, 39 | "window_stride": 8, 40 | "log_gpu_memory": "None", 41 | "sync_batchnorm": false, 42 | "load_checkpoint": "None", 43 | "overfit_batches": 0, 44 | "track_grad_norm": -1, 45 | "weights_summary": "top", 46 | "auto_select_gpus": false, 47 | "default_root_dir": "None", 48 | "one_cycle_max_lr": "None", 49 | "process_position": 0, 50 | "terminate_on_nan": true, 51 | "gradient_clip_val": 0, 52 | "limit_val_batches": 1, 53 | "log_every_n_steps": 50, 54 | "weights_save_path": "None", 55 | "limit_test_batches": 1, 
56 | "val_check_interval": 1, 57 | "checkpoint_callback": true, 58 | "distributed_backend": "None", 59 | "enable_pl_optimizer": "None", 60 | "limit_output_length": false, 61 | "limit_train_batches": 1, 62 | "move_metrics_to_cpu": false, 63 | "replace_sampler_ddp": true, 64 | "num_sanity_val_steps": 2, 65 | "truncated_bptt_steps": "None", 66 | "auto_scale_batch_size": false, 67 | "limit_predict_batches": 1, 68 | "one_cycle_total_steps": 100, 69 | "prepare_data_per_node": true, 70 | "stochastic_weight_avg": false, 71 | "automatic_optimization": "None", 72 | "resume_from_checkpoint": "None", 73 | "accumulate_grad_batches": 4, 74 | "check_val_every_n_epoch": 10, 75 | "flush_logs_every_n_steps": 100, 76 | "multiple_trainloader_mode": "max_size_cycle", 77 | "progress_bar_refresh_rate": "None", 78 | "reload_dataloaders_every_epoch": false 79 | } -------------------------------------------------------------------------------- /lab8/text_recognizer/artifacts/paragraph_text_recognizer/model.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ef4b7ca47edaf2e7be1eeb6cccbd1c0df864f43de564ea3a668f5019751dffe6 3 | size 546255010 4 | -------------------------------------------------------------------------------- /lab8/text_recognizer/artifacts/paragraph_text_recognizer/run_command.txt: -------------------------------------------------------------------------------- 1 | python training/run_experiment.py --wandb --gpus=-1 --data_class=IAMOriginalAndSyntheticParagraphs --model_class=ResnetTransformer --loss=transformer --batch_size=16 --check_val_every_n_epoch=10 --terminate_on_nan=1 --num_workers=24 --accelerator=ddp --lr=0.0001 --accumulate_grad_batches=4 -------------------------------------------------------------------------------- /lab8/text_recognizer/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import BaseDataset 2 | from .base_data_module import BaseDataModule 3 | from .mnist import MNIST 4 | 5 | # Hide lines below until Lab 2 6 | from .emnist import EMNIST 7 | from .emnist_lines import EMNISTLines 8 | 9 | # Hide lines above until Lab 2 10 | 11 | # Hide lines below until Lab 5 12 | from .emnist_lines2 import EMNISTLines2 13 | from .iam_lines import IAMLines 14 | 15 | # Hide lines above until Lab 5 16 | 17 | # Hide lines below until Lab 7 18 | from .iam_paragraphs import IAMParagraphs 19 | from .iam_original_and_synthetic_paragraphs import IAMOriginalAndSyntheticParagraphs 20 | 21 | # Hide lines above until Lab 7 22 | 23 | # Hide lines below until Lab 8 24 | from .fake_images import FakeImageData 25 | 26 | # Hide lines above until Lab 8 27 | -------------------------------------------------------------------------------- /lab8/text_recognizer/data/emnist_essentials.json: -------------------------------------------------------------------------------- 1 | {"characters": ["<B>", "<S>", "<E>", "<P>
", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} -------------------------------------------------------------------------------- /lab8/text_recognizer/data/fake_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fake images dataset. 3 | """ 4 | import argparse 5 | import torch 6 | import torchvision 7 | from text_recognizer.data.base_data_module import BaseDataModule 8 | 9 | 10 | _NUM_SAMPLES = 512 11 | _IMAGE_LEN = 28 12 | _NUM_CLASSES = 10 13 | 14 | 15 | class FakeImageData(BaseDataModule): 16 | """ 17 | Fake images dataset. 18 | """ 19 | 20 | def __init__(self, args: argparse.Namespace = None): 21 | super().__init__(args) 22 | self.num_samples = self.args.get("num_samples", _NUM_SAMPLES) 23 | self.dims = (1, self.args.get("image_height", _IMAGE_LEN), self.args.get("image_width", _IMAGE_LEN)) 24 | 25 | self.num_classes = self.args.get("num_classes", _NUM_CLASSES) 26 | self.output_dims = (self.num_classes, 1) 27 | self.mapping = list(range(0, self.num_classes)) 28 | 29 | @staticmethod 30 | def add_to_argparse(parser): 31 | BaseDataModule.add_to_argparse(parser) 32 | parser.add_argument("--num_samples", type=int, default=_NUM_SAMPLES) 33 | parser.add_argument("--num_classes", type=int, default=_NUM_CLASSES) 34 | parser.add_argument("--image_height", type=int, default=_IMAGE_LEN) 35 | parser.add_argument("--image_width", type=int, default=_IMAGE_LEN) 36 | return parser 37 | 38 | def setup(self, stage: str = None) -> None: 39 | fake_dataset = torchvision.datasets.FakeData( 40 | size=self.num_samples, 41 | image_size=self.dims, 42 | num_classes=self.output_dims[0], 43 | transform=torchvision.transforms.ToTensor(), 44 | ) 45 | val_size = int(self.num_samples * 0.25) 46 | self.data_train, self.data_val, self.data_test = torch.utils.data.random_split( # type: ignore 47 | dataset=fake_dataset, lengths=[self.num_samples - 2 * val_size, val_size, val_size] 48 | ) 49 | -------------------------------------------------------------------------------- /lab8/text_recognizer/data/mnist.py: -------------------------------------------------------------------------------- 1 | """MNIST DataModule""" 2 | import argparse 3 | 4 | from torch.utils.data import random_split 5 | from torchvision.datasets import MNIST as TorchMNIST 6 | from torchvision import transforms 7 | 8 | from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info 9 | 10 | DOWNLOADED_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" 11 | 12 | # NOTE: temp fix until https://github.com/pytorch/vision/issues/1938 is resolved 13 | from six.moves import urllib # pylint: disable=wrong-import-position, wrong-import-order 14 | 15 | opener = urllib.request.build_opener() 16 | opener.addheaders = [("User-agent", "Mozilla/5.0")] 17 | urllib.request.install_opener(opener) 18 | 19 | 20 | class MNIST(BaseDataModule): 21 | """ 22 | MNIST DataModule. 
23 | Learn more at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html 24 | """ 25 | 26 | def __init__(self, args: argparse.Namespace) -> None: 27 | super().__init__(args) 28 | self.data_dir = DOWNLOADED_DATA_DIRNAME 29 | self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 30 | self.dims = (1, 28, 28) # dims are returned when calling `.size()` on this object. 31 | self.output_dims = (1,) 32 | self.mapping = list(range(10)) 33 | 34 | def prepare_data(self, *args, **kwargs) -> None: 35 | """Download train and test MNIST data from PyTorch canonical source.""" 36 | TorchMNIST(self.data_dir, train=True, download=True) 37 | TorchMNIST(self.data_dir, train=False, download=True) 38 | 39 | def setup(self, stage=None) -> None: 40 | """Split into train, val, test, and set dims.""" 41 | mnist_full = TorchMNIST(self.data_dir, train=True, transform=self.transform) 42 | self.data_train, self.data_val = random_split(mnist_full, [55000, 5000]) # type: ignore 43 | self.data_test = TorchMNIST(self.data_dir, train=False, transform=self.transform) 44 | 45 | 46 | if __name__ == "__main__": 47 | load_and_print_info(MNIST) 48 | -------------------------------------------------------------------------------- /lab8/text_recognizer/data/sentence_generator.py: -------------------------------------------------------------------------------- 1 | """SentenceGenerator class and supporting functions.""" 2 | import itertools 3 | import re 4 | import string 5 | from typing import Optional 6 | 7 | import nltk 8 | import numpy as np 9 | 10 | from text_recognizer.data.base_data_module import BaseDataModule 11 | 12 | NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk" 13 | 14 | 15 | class SentenceGenerator: 16 | """Generate text sentences using the Brown corpus.""" 17 | 18 | def __init__(self, max_length: Optional[int] = None): 19 | self.text = brown_text() 20 | self.word_start_inds = [0] + [_.start(0) + 1 for _ in re.finditer(" ", self.text)] 21 | self.max_length = max_length 22 | 23 | def generate(self, max_length: Optional[int] = None) -> str: 24 | """ 25 | Sample a string from text of the Brown corpus of length at least one word and at most max_length. 
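Both the start and the end of the sample are drawn from word-start indices, so the returned string never cuts a word in half.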
26 | """ 27 | if max_length is None: 28 | max_length = self.max_length 29 | if max_length is None: 30 | raise ValueError("Must provide max_length to this method or when making this object.") 31 | 32 | for _ in range(10): # Try several times to generate before actually erroring 33 | try: 34 | first_ind = np.random.randint(0, len(self.word_start_inds) - 1) 35 | start_ind = self.word_start_inds[first_ind] 36 | end_ind_candidates = [] 37 | for ind in range(first_ind + 1, len(self.word_start_inds)): 38 | if self.word_start_inds[ind] - start_ind > max_length: 39 | break 40 | end_ind_candidates.append(self.word_start_inds[ind]) 41 | end_ind = np.random.choice(end_ind_candidates) 42 | sampled_text = self.text[start_ind:end_ind].strip() 43 | return sampled_text 44 | except Exception: # pylint: disable=broad-except 45 | pass 46 | raise RuntimeError("Was not able to generate a valid string") 47 | 48 | 49 | def brown_text(): 50 | """Return a single string with the Brown corpus with all punctuation stripped.""" 51 | sents = load_nltk_brown_corpus() 52 | text = " ".join(itertools.chain.from_iterable(sents)) 53 | text = text.translate({ord(c): None for c in string.punctuation}) 54 | text = re.sub(" +", " ", text) 55 | return text 56 | 57 | 58 | def load_nltk_brown_corpus(): 59 | """Load the Brown corpus using the NLTK library.""" 60 | nltk.data.path.append(NLTK_DATA_DIRNAME) 61 | try: 62 | nltk.corpus.brown.sents() 63 | except LookupError: 64 | NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) 65 | nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) 66 | return nltk.corpus.brown.sents() 67 | -------------------------------------------------------------------------------- /lab8/text_recognizer/evaluation/evaluate_paragraph_text_recognizer.py: -------------------------------------------------------------------------------- 1 | """Run validation test for paragraph_text_recognizer module.""" 2 | import os 3 | import argparse 4 | import time 5 | import unittest 6 | import torch 7 | import pytorch_lightning as pl 8 | from text_recognizer.data import IAMParagraphs 9 | from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizer 10 | 11 | 12 | _TEST_CHARACTER_ERROR_RATE = 0.17 13 | 14 | 15 | class TestEvaluateParagraphTextRecognizer(unittest.TestCase): 16 | """Evaluate ParagraphTextRecognizer on the IAMParagraphs test dataset.""" 17 | 18 | @torch.no_grad() 19 | def test_evaluate(self): 20 | dataset = IAMParagraphs(argparse.Namespace(batch_size=16, num_workers=10)) 21 | dataset.prepare_data() 22 | dataset.setup() 23 | 24 | text_recog = ParagraphTextRecognizer() 25 | trainer = pl.Trainer(gpus=1) 26 | 27 | start_time = time.time() 28 | metrics = trainer.test(text_recog.lit_model, datamodule=dataset) 29 | end_time = time.time() 30 | 31 | test_cer = round(metrics[0]["test_cer"], 2) 32 | time_taken = round((end_time - start_time) / 60, 2) 33 | 34 | print(f"Character error rate: {test_cer}, time_taken: {time_taken} m") 35 | self.assertEqual(test_cer, _TEST_CHARACTER_ERROR_RATE) 36 | self.assertLess(time_taken, 45) 37 | -------------------------------------------------------------------------------- /lab8/text_recognizer/lit_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseLitModel 2 | 3 | # Hide lines below until Lab 3 4 | from .ctc import CTCLitModel 5 | 6 | # Hide lines above until Lab 3 7 | # Hide lines below until Lab 4 8 | from .transformer import TransformerLitModel 9 | 10 | # Hide lines above until Lab 4 11 | 
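The evaluation module above asserts a test-set character error rate of 0.17, i.e. roughly one wrong character in six. As a reminder of what that number means, here is a small self-contained sketch of the same arithmetic, using the editdistance package these modules already import (the two strings are invented for illustration):

```python
import editdistance

ground_truth = "Handwriting is hard to read."
prediction = "Handwriting is hard to readl"  # one substituted character

# CER = Levenshtein distance divided by the length of the longer string,
# matching CharacterErrorRate in metrics.py below and the helper in
# test_paragraph_text_recognizer.py.
cer = editdistance.eval(ground_truth, prediction) / max(len(ground_truth), len(prediction))
print(round(cer, 3))  # one edit over 28 characters -> 0.036
```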
-------------------------------------------------------------------------------- /lab8/text_recognizer/lit_models/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | import editdistance 6 | 7 | 8 | class CharacterErrorRate(pl.metrics.Metric): 9 | """Character error rate metric, computed using Levenshtein distance.""" 10 | 11 | def __init__(self, ignore_tokens: Sequence[int], *args): 12 | super().__init__(*args) 13 | self.ignore_tokens = set(ignore_tokens) 14 | self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")  # pylint: disable=not-callable 15 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")  # pylint: disable=not-callable 16 | self.error: torch.Tensor 17 | self.total: torch.Tensor 18 | 19 | def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None: 20 | N = preds.shape[0] 21 | for ind in range(N): 22 | pred = [_ for _ in preds[ind].tolist() if _ not in self.ignore_tokens] 23 | target = [_ for _ in targets[ind].tolist() if _ not in self.ignore_tokens] 24 | distance = editdistance.distance(pred, target) 25 | error = distance / max(len(pred), len(target)) 26 | self.error = self.error + error 27 | self.total = self.total + N 28 | 29 | def compute(self) -> torch.Tensor: 30 | return self.error / self.total 31 | 32 | 33 | def test_character_error_rate(): 34 | metric = CharacterErrorRate([0, 1]) 35 | X = torch.tensor(  # pylint: disable=not-callable 36 | [ 37 | [0, 2, 2, 3, 3, 1],  # error will be 0 38 | [0, 2, 1, 1, 1, 1],  # error will be .75 39 | [0, 2, 2, 4, 4, 1],  # error will be .5 40 | ] 41 | ) 42 | Y = torch.tensor(  # pylint: disable=not-callable 43 | [ 44 | [0, 2, 2, 3, 3, 1], 45 | [0, 2, 2, 3, 3, 1], 46 | [0, 2, 2, 3, 3, 1], 47 | ] 48 | ) 49 | metric(X, Y) 50 | print(metric.compute()) 51 | assert metric.compute() == sum([0, 0.75, 0.5]) / 3 52 | 53 | 54 | if __name__ == "__main__": 55 | test_character_error_rate() 56 | -------------------------------------------------------------------------------- /lab8/text_recognizer/lit_models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | try: 3 | import wandb 4 | except ModuleNotFoundError: 5 | pass 6 | 7 | 8 | from .metrics import CharacterErrorRate 9 | from .base import BaseLitModel 10 | 11 | 12 | class TransformerLitModel(BaseLitModel):  # pylint: disable=too-many-ancestors 13 | """ 14 | Generic PyTorch-Lightning class that must be initialized with a PyTorch module. 15 | 16 | The module must take x, y as inputs, and have a special predict() method. 17 | """ 18 | 19 | def __init__(self, model, args=None): 20 | super().__init__(model, args) 21 | 22 | self.mapping = self.model.data_config["mapping"] 23 | inverse_mapping = {val: ind for ind, val in enumerate(self.mapping)} 24 | start_index = inverse_mapping["<S>"] 25 | end_index = inverse_mapping["<E>"] 26 | padding_index = inverse_mapping["<P>"] 27 | 28 | self.loss_fn = nn.CrossEntropyLoss(ignore_index=padding_index) 29 | 30 | ignore_tokens = [start_index, end_index, padding_index] 31 | self.val_cer = CharacterErrorRate(ignore_tokens) 32 | self.test_cer = CharacterErrorRate(ignore_tokens) 33 | 34 | def forward(self, x): 35 | return self.model.predict(x) 36 | 37 | def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument 38 | x, y = batch 39 | logits = self.model(x, y[:, :-1]) 40 | loss = self.loss_fn(logits, y[:, 1:]) 41 | self.log("train_loss", loss) 42 | return loss 43 | 44 | def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument 45 | x, y = batch 46 | logits = self.model(x, y[:, :-1]) 47 | loss = self.loss_fn(logits, y[:, 1:]) 48 | self.log("val_loss", loss, prog_bar=True) 49 | 50 | pred = self.model.predict(x) 51 | # Hide lines below until Lab 5 52 | pred_str = "".join(self.mapping[_] for _ in pred[0].tolist() if _ != 3)  # 3 is the padding token <P> in the mapping 53 | try: 54 | self.logger.experiment.log({"val_pred_examples": [wandb.Image(x[0], caption=pred_str)]}) 55 | except AttributeError: 56 | pass 57 | # Hide lines above until Lab 5 58 | self.val_cer(pred, y) 59 | self.log("val_cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) 60 | 61 | def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument 62 | x, y = batch 63 | pred = self.model.predict(x) 64 | # Hide lines below until Lab 5 65 | pred_str = "".join(self.mapping[_] for _ in pred[0].tolist() if _ != 3)  # 3 is the padding token <P> in the mapping 66 | try: 67 | self.logger.experiment.log({"test_pred_examples": [wandb.Image(x[0], caption=pred_str)]}) 68 | except AttributeError: 69 | pass 70 | # Hide lines above until Lab 5 71 | self.test_cer(pred, y) 72 | self.log("test_cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) 73 | -------------------------------------------------------------------------------- /lab8/text_recognizer/lit_models/util.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | 5 | 6 | def first_element(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor: 7 | """ 8 | Return indices of first occurrence of element in x. If not found, return length of x along dim.
9 | 10 | Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9 11 | 12 | Examples 13 | -------- 14 | >>> first_element(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1]]), 3) 15 | tensor([2, 1, 3]) 16 | """ 17 | nonz = x == element 18 | ind = ((nonz.cumsum(dim) == 1) & nonz).max(dim).indices 19 | ind[ind == 0] = x.shape[dim] 20 | return ind 21 | -------------------------------------------------------------------------------- /lab8/text_recognizer/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import MLP 2 | 3 | # Hide lines below until Lab 2 4 | from .cnn import CNN 5 | 6 | # Hide lines above until Lab 2 7 | 8 | # Hide lines below until Lab 3 9 | from .line_cnn_simple import LineCNNSimple 10 | from .line_cnn import LineCNN 11 | from .line_cnn_lstm import LineCNNLSTM 12 | 13 | # Hide lines above until Lab 3 14 | 15 | # Hide lines below until Lab 4 16 | from .line_cnn_transformer import LineCNNTransformer 17 | 18 | # Hide lines above until Lab 4 19 | 20 | # Hide lines below until Lab 7 21 | from .resnet_transformer import ResnetTransformer 22 | 23 | # Hide lines above until Lab 7 24 | -------------------------------------------------------------------------------- /lab8/text_recognizer/models/line_cnn_lstm.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .line_cnn import LineCNN 7 | 8 | LSTM_DIM = 512 9 | LSTM_LAYERS = 1 10 | LSTM_DROPOUT = 0.2 11 | 12 | 13 | class LineCNNLSTM(nn.Module): 14 | """Process the line through a CNN and process the resulting sequence through LSTM layers.""" 15 | 16 | def __init__( 17 | self, 18 | data_config: Dict[str, Any], 19 | args: argparse.Namespace = None, 20 | ) -> None: 21 | super().__init__() 22 | self.data_config = data_config 23 | self.args = vars(args) if args is not None else {} 24 | 25 | num_classes = len(data_config["mapping"]) 26 | lstm_dim = self.args.get("lstm_dim", LSTM_DIM) 27 | lstm_layers = self.args.get("lstm_layers", LSTM_LAYERS) 28 | lstm_dropout = self.args.get("lstm_dropout", LSTM_DROPOUT) 29 | 30 | self.line_cnn = LineCNN(data_config=data_config, args=args) 31 | # LineCNN outputs (B, C, S) log probs, with C == num_classes 32 | 33 | self.lstm = nn.LSTM( 34 | input_size=num_classes, 35 | hidden_size=lstm_dim, 36 | num_layers=lstm_layers, 37 | dropout=lstm_dropout, 38 | bidirectional=True, 39 | ) 40 | self.fc = nn.Linear(lstm_dim, num_classes) 41 | 42 | def forward(self, x: torch.Tensor) -> torch.Tensor: 43 | """ 44 | Parameters 45 | ---------- 46 | x 47 | (B, H, W) input image 48 | 49 | Returns 50 | ------- 51 | torch.Tensor 52 | (B, C, S) logits, where S is the length of the sequence and C is the number of classes 53 | S can be computed from W and CHAR_WIDTH 54 | C is num_classes 55 | """ 56 | x = self.line_cnn(x) # -> (B, C, S) 57 | B, _C, S = x.shape 58 | x = x.permute(2, 0, 1) # -> (S, B, C) 59 | 60 | x, _ = self.lstm(x) # -> (S, B, 2 * H) where H is lstm_dim 61 | 62 | # Sum up both directions of the LSTM: 63 | x = x.view(S, B, 2, -1).sum(dim=2) # -> (S, B, H) 64 | 65 | x = self.fc(x) # -> (S, B, C) 66 | 67 | return x.permute(1, 2, 0) # -> (B, C, S) 68 | 69 | @staticmethod 70 | def add_to_argparse(parser): 71 | LineCNN.add_to_argparse(parser) 72 | parser.add_argument("--lstm_dim", type=int, default=LSTM_DIM) 73 | parser.add_argument("--lstm_layers", type=int, default=LSTM_LAYERS) 74 | 
parser.add_argument("--lstm_dropout", type=float, default=LSTM_DROPOUT) 75 | return parser 76 | -------------------------------------------------------------------------------- /lab8/text_recognizer/models/line_cnn_simple.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .cnn import CNN, IMAGE_SIZE 9 | 10 | WINDOW_WIDTH = 28 11 | WINDOW_STRIDE = 28 12 | 13 | 14 | class LineCNNSimple(nn.Module): 15 | """LeNet based model that takes a line of width that is a multiple of CHAR_WIDTH.""" 16 | 17 | def __init__( 18 | self, 19 | data_config: Dict[str, Any], 20 | args: argparse.Namespace = None, 21 | ) -> None: 22 | super().__init__() 23 | self.args = vars(args) if args is not None else {} 24 | 25 | self.WW = self.args.get("window_width", WINDOW_WIDTH) 26 | self.WS = self.args.get("window_stride", WINDOW_STRIDE) 27 | self.limit_output_length = self.args.get("limit_output_length", False) 28 | 29 | self.num_classes = len(data_config["mapping"]) 30 | self.output_length = data_config["output_dims"][0] 31 | self.cnn = CNN(data_config=data_config, args=args) 32 | 33 | def forward(self, x: torch.Tensor) -> torch.Tensor: 34 | """ 35 | Parameters 36 | ---------- 37 | x 38 | (B, C, H, W) input image 39 | 40 | Returns 41 | ------- 42 | torch.Tensor 43 | (B, C, S) logits, where S is the length of the sequence and C is the number of classes 44 | S can be computed from W and CHAR_WIDTH 45 | C is self.num_classes 46 | """ 47 | B, _C, H, W = x.shape 48 | assert H == IMAGE_SIZE # Make sure we can use our CNN class 49 | 50 | # Compute number of windows 51 | S = math.floor((W - self.WW) / self.WS + 1) 52 | 53 | # NOTE: type_as properly sets device 54 | activations = torch.zeros((B, self.num_classes, S)).type_as(x) 55 | for s in range(S): 56 | start_w = self.WS * s 57 | end_w = start_w + self.WW 58 | window = x[:, :, :, start_w:end_w] # -> (B, C, H, self.WW) 59 | activations[:, :, s] = self.cnn(window) 60 | 61 | if self.limit_output_length: 62 | # S might not match ground truth, so let's only take enough activations as are expected 63 | activations = activations[:, :, : self.output_length] 64 | return activations 65 | 66 | @staticmethod 67 | def add_to_argparse(parser): 68 | CNN.add_to_argparse(parser) 69 | parser.add_argument( 70 | "--window_width", 71 | type=int, 72 | default=WINDOW_WIDTH, 73 | help="Width of the window that will slide over the input image.", 74 | ) 75 | parser.add_argument( 76 | "--window_stride", 77 | type=int, 78 | default=WINDOW_STRIDE, 79 | help="Stride of the window that will slide over the input image.", 80 | ) 81 | parser.add_argument("--limit_output_length", action="store_true", default=False) 82 | return parser 83 | -------------------------------------------------------------------------------- /lab8/text_recognizer/models/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | FC1_DIM = 1024 10 | FC2_DIM = 128 11 | 12 | 13 | class MLP(nn.Module): 14 | """Simple MLP suitable for recognizing single characters.""" 15 | 16 | def __init__( 17 | self, 18 | data_config: Dict[str, Any], 19 | args: argparse.Namespace = None, 20 | ) -> None: 21 | super().__init__() 22 | self.args = vars(args) if args is not None else {} 23 | 24 | input_dim 
= np.prod(data_config["input_dims"]) 25 | num_classes = len(data_config["mapping"]) 26 | 27 | fc1_dim = self.args.get("fc1", FC1_DIM) 28 | fc2_dim = self.args.get("fc2", FC2_DIM) 29 | 30 | self.dropout = nn.Dropout(0.5) 31 | self.fc1 = nn.Linear(input_dim, fc1_dim) 32 | self.fc2 = nn.Linear(fc1_dim, fc2_dim) 33 | self.fc3 = nn.Linear(fc2_dim, num_classes) 34 | 35 | def forward(self, x): 36 | x = torch.flatten(x, 1) 37 | x = self.fc1(x) 38 | x = F.relu(x) 39 | x = self.dropout(x) 40 | x = self.fc2(x) 41 | x = F.relu(x) 42 | x = self.dropout(x) 43 | x = self.fc3(x) 44 | return x 45 | 46 | @staticmethod 47 | def add_to_argparse(parser): 48 | parser.add_argument("--fc1", type=int, default=1024) 49 | parser.add_argument("--fc2", type=int, default=128) 50 | return parser 51 | -------------------------------------------------------------------------------- /lab8/text_recognizer/tests/support/paragraphs/a01-077.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab8/text_recognizer/tests/support/paragraphs/a01-077.png -------------------------------------------------------------------------------- /lab8/text_recognizer/tests/support/paragraphs/a01-087.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab8/text_recognizer/tests/support/paragraphs/a01-087.png -------------------------------------------------------------------------------- /lab8/text_recognizer/tests/support/paragraphs/a01-107.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab8/text_recognizer/tests/support/paragraphs/a01-107.png -------------------------------------------------------------------------------- /lab8/text_recognizer/tests/support/paragraphs/a02-046.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab8/text_recognizer/tests/support/paragraphs/a02-046.png -------------------------------------------------------------------------------- /lab8/text_recognizer/tests/test_paragraph_text_recognizer.py: -------------------------------------------------------------------------------- 1 | """Test for paragraph_text_recognizer module.""" 2 | import os 3 | import json 4 | from pathlib import Path 5 | import time 6 | import editdistance 7 | from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizer 8 | 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 11 | 12 | 13 | _FILE_DIRNAME = Path(__file__).parents[0].resolve() 14 | _SUPPORT_DIRNAME = _FILE_DIRNAME / "support" / "paragraphs" 15 | 16 | # restricting the number of samples to prevent CircleCI from running out of time 17 | _NUM_MAX_SAMPLES = 2 if os.environ.get("CIRCLECI", False) else 100 18 | 19 | 20 | def test_paragraph_text_recognizer(): 21 | """Test ParagraphTextRecognizer.""" 22 | support_filenames = list(_SUPPORT_DIRNAME.glob("*.png")) 23 | with open(_SUPPORT_DIRNAME / "data_by_file_id.json", "r") as f: 24 | support_data_by_file_id = json.load(f) 25 | 26 | start_time = time.time() 27 | text_recognizer = ParagraphTextRecognizer() 28 | end_time =
time.time() 29 | print(f"Time taken to initialize ParagraphTextRecognizer: {round(end_time - start_time, 2)}s") 30 | 31 | for i, support_filename in enumerate(support_filenames): 32 | if i >= _NUM_MAX_SAMPLES: 33 | break 34 | expected_text = support_data_by_file_id[support_filename.stem]["predicted_text"] 35 | start_time = time.time() 36 | predicted_text = _test_paragraph_text_recognizer(support_filename, expected_text, text_recognizer) 37 | end_time = time.time() 38 | time_taken = round(end_time - start_time, 2) 39 | 40 | cer = _character_error_rate(support_data_by_file_id[support_filename.stem]["ground_truth_text"], predicted_text) 41 | print(f"Character error rate is {round(cer, 3)} for file {support_filename.name} (time taken: {time_taken}s)") 42 | 43 | 44 | def _test_paragraph_text_recognizer(image_filename: Path, expected_text: str, text_recognizer: ParagraphTextRecognizer): 45 | """Test ParagraphTextRecognizer on 1 image.""" 46 | predicted_text = text_recognizer.predict(image_filename) 47 | assert predicted_text == expected_text, f"predicted text does not match expected for {image_filename.name}" 48 | return predicted_text 49 | 50 | 51 | def _character_error_rate(str_a: str, str_b: str) -> float: 52 | """Return character error rate.""" 53 | return editdistance.eval(str_a, str_b) / max(len(str_a), len(str_b)) 54 | -------------------------------------------------------------------------------- /lab8/text_recognizer/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions for text_recognizer module.""" 2 | from io import BytesIO 3 | from pathlib import Path 4 | from typing import Union 5 | from urllib.request import urlretrieve 6 | import base64 7 | import hashlib 8 | 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import numpy as np 12 | import smart_open 13 | 14 | 15 | def to_categorical(y, num_classes): 16 | """1-hot encode a tensor.""" 17 | return np.eye(num_classes, dtype="uint8")[y] 18 | 19 | 20 | def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image: 21 | with smart_open.open(image_uri, "rb") as image_file: 22 | return read_image_pil_file(image_file, grayscale) 23 | 24 | 25 | def read_image_pil_file(image_file, grayscale=False) -> Image: 26 | with Image.open(image_file) as image: 27 | if grayscale: 28 | image = image.convert(mode="L") 29 | else: 30 | image = image.convert(mode=image.mode) 31 | return image 32 | 33 | 34 | 35 | 36 | def compute_sha256(filename: Union[Path, str]): 37 | """Return SHA256 checksum of a file.""" 38 | with open(filename, "rb") as f: 39 | return hashlib.sha256(f.read()).hexdigest() 40 | 41 | 42 | class TqdmUpTo(tqdm): 43 | """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" 44 | 45 | def update_to(self, blocks=1, bsize=1, tsize=None): 46 | """ 47 | Parameters 48 | ---------- 49 | blocks: int, optional 50 | Number of blocks transferred so far [default: 1]. 51 | bsize: int, optional 52 | Size of each block (in tqdm units) [default: 1]. 53 | tsize: int, optional 54 | Total size (in tqdm units). If [default: None] remains unchanged. 
55 | """ 56 | if tsize is not None: 57 | self.total = tsize # pylint: disable=attribute-defined-outside-init 58 | self.update(blocks * bsize - self.n) # will also set self.n = b * bsize 59 | 60 | 61 | def download_url(url, filename): 62 | """Download a file from url to filename, with a progress bar.""" 63 | with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: 64 | urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec 65 | -------------------------------------------------------------------------------- /lab8/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab8/training/__init__.py -------------------------------------------------------------------------------- /lab8/training/sweeps/emnist_lines2_line_cnn_transformer.yml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - ${interpreter} 4 | - ${program} 5 | - "--wandb" 6 | - ${args} 7 | program: training/run_experiment.py 8 | method: random 9 | metric: 10 | goal: minimize 11 | name: val_loss 12 | early_terminate: 13 | type: hyperband 14 | min_iter: 20 15 | parameters: 16 | conv_dim: 17 | values: [32, 64] 18 | window_width: 19 | values: [8, 16] 20 | window_stride: 21 | value: 8 22 | fc_dim: 23 | values: [512, 1024] 24 | tf_dim: 25 | values: [128, 256] 26 | tf_fc_dim: 27 | values: [256, 1024] 28 | tf_nhead: 29 | values: [4, 8] 30 | tf_layers: 31 | values: [2, 4, 6] 32 | lr: 33 | values: [0.01, 0.001, 0.0003] 34 | num_workers: 35 | value: 20 36 | gpus: 37 | value: -1 38 | data_class: 39 | value: EMNISTLines2 40 | model_class: 41 | value: LineCNNTransformer 42 | loss: 43 | value: transformer 44 | precision: 45 | value: 16 46 | -------------------------------------------------------------------------------- /lab8/training/tests/test_run_experiment.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | FAILURE=false 3 | 4 | python training/run_experiment.py --data_class=FakeImageData --model_class=CNN --conv_dim=32 --fc_dim=16 --loss=cross_entropy --num_workers=4 --max_epochs=4 || FAILURE=true 5 | 6 | if [ "$FAILURE" = true ]; then 7 | echo "Test for run_experiment.py failed" 8 | exit 1 9 | fi 10 | echo "Test for run_experiment.py passed" 11 | exit 0 12 | -------------------------------------------------------------------------------- /lab9/.dockerignore: -------------------------------------------------------------------------------- 1 | .git 2 | .*_cache 3 | data 4 | -------------------------------------------------------------------------------- /lab9/.pylintrc: -------------------------------------------------------------------------------- 1 | [TYPECHECK] 2 | # Without this, pylint will complain that torch does not contain some methods 3 | # https://github.com/pytorch/pytorch/issues/701 4 | generated-members=numpy.*,torch.* 5 | 6 | [MESSAGES CONTROL] 7 | disable= 8 | abstract-method, 9 | arguments-differ, 10 | attribute-defined-outside-init, 11 | duplicate-code, 12 | invalid-name, 13 | fixme, 14 | missing-function-docstring, 15 | missing-module-docstring, 16 | too-few-public-methods 17 | 18 | [FORMAT] 19 | max-line-length = 120 20 | 21 | [DESIGN] 22 | max-attributes = 18 23 | max-args = 8 24 | max-locals = 16 25 | -------------------------------------------------------------------------------- 
/lab9/api_server/Dockerfile: -------------------------------------------------------------------------------- 1 | # The "buster" flavor of the official docker Python image is based on Debian and includes common packages. 2 | FROM python:3.6-buster 3 | 4 | # Create the working directory 5 | RUN set -ex && mkdir /repo 6 | WORKDIR /repo 7 | 8 | # Install Python dependencies 9 | COPY requirements.txt ./requirements.txt 10 | RUN sed -i 's/cu101/cpu/' requirements.txt 11 | RUN pip install --upgrade pip~=21.0.0 12 | RUN pip install -r requirements.txt 13 | 14 | # Copy only the relevant directories 15 | COPY text_recognizer/ ./text_recognizer 16 | COPY api_server/ ./api 17 | 18 | # Run the web server 19 | EXPOSE 8000 20 | ENV PYTHONPATH /repo 21 | CMD python3 /repo/api/app.py 22 | -------------------------------------------------------------------------------- /lab9/api_server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab9/api_server/__init__.py -------------------------------------------------------------------------------- /lab9/api_server/app.py: -------------------------------------------------------------------------------- 1 | """Flask web server serving text_recognizer predictions.""" 2 | import os 3 | import logging 4 | 5 | from flask import Flask, request, jsonify 6 | from PIL import ImageStat 7 | 8 | from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizer 9 | import text_recognizer.util as util 10 | 11 | os.environ["CUDA_VISIBLE_DEVICES"] = "" # Do not use GPU 12 | 13 | app = Flask(__name__) # pylint: disable=invalid-name 14 | model = ParagraphTextRecognizer() 15 | logging.basicConfig(level=logging.INFO) 16 | 17 | @app.route("/") 18 | def index(): 19 | """Provide simple health check route.""" 20 | return "Hello, world!" 21 | 22 | 23 | @app.route("/v1/predict", methods=["GET", "POST"]) 24 | def predict(): 25 | """Provide main prediction API route. 
Responds to both GET and POST requests.""" 26 | image = _load_image() 27 | pred = model.predict(image) 28 | image_stat = ImageStat.Stat(image) 29 | logging.info("METRIC image_mean_intensity {}".format(image_stat.mean[0])) 30 | logging.info("METRIC image_area {}".format(image.size[0] * image.size[1])) 31 | logging.info("METRIC pred_length {}".format(len(pred))) 32 | logging.info("pred {}".format(pred)) 33 | return jsonify({"pred": str(pred)}) 34 | 35 | 36 | def _load_image(): 37 | if request.method == "POST": 38 | data = request.get_json() 39 | if data is None: 40 | return "no json received" 41 | return util.read_b64_image(data["image"], grayscale=True) 42 | if request.method == "GET": 43 | image_url = request.args.get("image_url") 44 | if image_url is None: 45 | return "no image_url defined in query string" 46 | logging.info("url {}".format(image_url)) 47 | return util.read_image_pil(image_url, grayscale=True) 48 | raise ValueError("Unsupported HTTP method") 49 | 50 | 51 | def main(): 52 | """Run the app.""" 53 | app.run(host="0.0.0.0", port=8000, debug=False) # nosec 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /lab9/api_server/tests/test_app.py: -------------------------------------------------------------------------------- 1 | """Tests for web app.""" 2 | import os 3 | from pathlib import Path 4 | from unittest import TestCase 5 | import base64 6 | 7 | from api_server.app import app 8 | 9 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 10 | 11 | REPO_DIRNAME = Path(__file__).parents[2].resolve() 12 | SUPPORT_DIRNAME = REPO_DIRNAME / "text_recognizer" / "tests" / "support" / "paragraphs" 13 | FILENAME = SUPPORT_DIRNAME / "a01-077.png" 14 | EXPECTED_PRED = "And, since this is election year in West\nGermany, Dr. Adenauer is in a tough\nspot. Joyce Egginton cables: President\nKennedy at his Washington Press con-\nference admitted he did not know\nwhether America was lagging behind\nRussia in missile power. He said he\nwas waiting for his senior military\naides to come up with the answer on\nFebruary 20." 15 | 16 | 17 | class TestIntegrations(TestCase): 18 | def setUp(self): 19 | self.app = app.test_client() 20 | 21 | def test_index(self): 22 | response = self.app.get("/") 23 | assert response.get_data().decode() == "Hello, world!" 
24 | 25 | def test_predict(self): 26 | with open(FILENAME, "rb") as f: 27 | b64_image = base64.b64encode(f.read()) 28 | response = self.app.post("/v1/predict", json={"image": f"data:image/png;base64,{b64_image.decode()}"}) 29 | json_data = response.get_json() 30 | self.assertEqual(json_data["pred"], EXPECTED_PRED) 31 | -------------------------------------------------------------------------------- /lab9/api_serverless/Dockerfile: -------------------------------------------------------------------------------- 1 | # Starting from an official AWS image 2 | FROM amazon/aws-lambda-python:3.6 3 | 4 | # Install Python dependencies 5 | COPY requirements.txt ./requirements.txt 6 | RUN sed -i 's/cu101/cpu/' requirements.txt 7 | RUN pip install --upgrade pip~=21.0.0 8 | RUN pip install -r requirements.txt 9 | 10 | # Copy only the relevant directories and files 11 | COPY text_recognizer/ ./text_recognizer 12 | COPY api_serverless/app.py ./app.py 13 | 14 | CMD ["app.handler"] 15 | -------------------------------------------------------------------------------- /lab9/api_serverless/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab9/api_serverless/__init__.py -------------------------------------------------------------------------------- /lab9/api_serverless/app.py: -------------------------------------------------------------------------------- 1 | """AWS Lambda function serving text_recognizer predictions.""" 2 | from PIL import ImageStat 3 | 4 | from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizer 5 | import text_recognizer.util as util 6 | 7 | model = ParagraphTextRecognizer() 8 | 9 | 10 | def handler(event, _context): 11 | """Provide main prediction API""" 12 | image = _load_image(event) 13 | pred = model.predict(image) 14 | image_stat = ImageStat.Stat(image) 15 | print("METRIC image_mean_intensity {}".format(image_stat.mean[0])) 16 | print("METRIC image_area {}".format(image.size[0] * image.size[1])) 17 | print("METRIC pred_length {}".format(len(pred))) 18 | print("INFO pred {}".format(pred)) 19 | return {"pred": str(pred)} 20 | 21 | 22 | def _load_image(event): 23 | image_url = event.get("image_url") 24 | if image_url is None: 25 | return "no image_url provided in event" 26 | print("INFO url {}".format(image_url)) 27 | return util.read_image_pil(image_url, grayscale=True) 28 | -------------------------------------------------------------------------------- /lab9/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ['py37'] 4 | -------------------------------------------------------------------------------- /lab9/setup.cfg: -------------------------------------------------------------------------------- 1 | [pycodestyle] 2 | max-line-length = 120 3 | ignore = E203,E402,E501,W503 4 | 5 | [pydocstyle] 6 | convention = numpy 7 | add-ignore = D100,D102,D103,D104,D105,D200,D205,D400 8 | 9 | [mypy] 10 | ignore_missing_imports = True 11 | 12 | [tool:pytest] 13 | addopts = --doctest-modules 14 | -------------------------------------------------------------------------------- /lab9/tasks/lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -uo pipefail 3 | set +e 4 | 5 | FAILURE=false 6 | 7 | echo "safety (failure is tolerated)" 8 | 
FILE=requirements/prod.txt 9 | if [ -f "$FILE" ]; then 10 | # We're in the main repo 11 | safety check -r requirements/prod.txt -r requirements/dev.txt 12 | else 13 | # We're in the labs repo 14 | safety check -r ../requirements/prod.txt -r ../requirements/dev.txt 15 | fi 16 | 17 | echo "pylint" 18 | pylint text_recognizer training || FAILURE=true 19 | 20 | echo "pycodestyle" 21 | pycodestyle text_recognizer training || FAILURE=true 22 | 23 | echo "pydocstyle" 24 | pydocstyle text_recognizer training || FAILURE=true 25 | 26 | echo "mypy" 27 | mypy text_recognizer training || FAILURE=true 28 | 29 | echo "bandit" 30 | bandit -ll -r {text_recognizer,training} || FAILURE=true 31 | 32 | echo "shellcheck" 33 | find . -name "*.sh" -print0 | xargs -0 shellcheck || FAILURE=true 34 | 35 | if [ "$FAILURE" = true ]; then 36 | echo "Linting failed" 37 | exit 1 38 | fi 39 | echo "Linting passed" 40 | exit 0 41 | -------------------------------------------------------------------------------- /lab9/tasks/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -uo pipefail 3 | set +e 4 | 5 | FAILURE=false 6 | 7 | ./training/tests/test_run_experiment.sh || FAILURE=true 8 | pytest -s . || FAILURE=true 9 | 10 | if [ "$FAILURE" = true ]; then 11 | echo "Tests failed" 12 | exit 1 13 | fi 14 | echo "Tests passed" 15 | exit 0 16 | -------------------------------------------------------------------------------- /lab9/text_recognizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab9/text_recognizer/__init__.py -------------------------------------------------------------------------------- /lab9/text_recognizer/artifacts/paragraph_text_recognizer/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "lr": 0.0001, 3 | "gpus": -1, 4 | "loss": "transformer", 5 | "wandb": true, 6 | "fc_dim": 512, 7 | "logger": true, 8 | "tf_dim": 256, 9 | "plugins": "None", 10 | "conv_dim": 32, 11 | "profiler": "None", 12 | "tf_nhead": 4, 13 | "amp_level": "O2", 14 | "benchmark": false, 15 | "max_steps": "None", 16 | "min_steps": "None", 17 | "num_nodes": 1, 18 | "optimizer": "Adam", 19 | "precision": 32, 20 | "test_only": false, 21 | "tf_fc_dim": 256, 22 | "tf_layers": 4, 23 | "tpu_cores": "_gpus_arg_default", 24 | "batch_size": 16, 25 | "data_class": "IAMOriginalAndSyntheticParagraphs", 26 | "max_epochs": "None", 27 | "min_epochs": "None", 28 | "tf_dropout": 0.4, 29 | "accelerator": "ddp", 30 | "amp_backend": "native", 31 | "model_class": "ResnetTransformer", 32 | "num_workers": 24, 33 | "augment_data": "true", 34 | "auto_lr_find": false, 35 | "fast_dev_run": false, 36 | "window_width": 16, 37 | "deterministic": false, 38 | "num_processes": 1, 39 | "window_stride": 8, 40 | "log_gpu_memory": "None", 41 | "sync_batchnorm": false, 42 | "load_checkpoint": "None", 43 | "overfit_batches": 0, 44 | "track_grad_norm": -1, 45 | "weights_summary": "top", 46 | "auto_select_gpus": false, 47 | "default_root_dir": "None", 48 | "one_cycle_max_lr": "None", 49 | "process_position": 0, 50 | "terminate_on_nan": true, 51 | "gradient_clip_val": 0, 52 | "limit_val_batches": 1, 53 | "log_every_n_steps": 50, 54 | "weights_save_path": "None", 55 | "limit_test_batches": 1, 56 | "val_check_interval": 1, 57 | "checkpoint_callback": true, 58 | "distributed_backend": "None", 59 | 
"enable_pl_optimizer": "None", 60 | "limit_output_length": false, 61 | "limit_train_batches": 1, 62 | "move_metrics_to_cpu": false, 63 | "replace_sampler_ddp": true, 64 | "num_sanity_val_steps": 2, 65 | "truncated_bptt_steps": "None", 66 | "auto_scale_batch_size": false, 67 | "limit_predict_batches": 1, 68 | "one_cycle_total_steps": 100, 69 | "prepare_data_per_node": true, 70 | "stochastic_weight_avg": false, 71 | "automatic_optimization": "None", 72 | "resume_from_checkpoint": "None", 73 | "accumulate_grad_batches": 4, 74 | "check_val_every_n_epoch": 10, 75 | "flush_logs_every_n_steps": 100, 76 | "multiple_trainloader_mode": "max_size_cycle", 77 | "progress_bar_refresh_rate": "None", 78 | "reload_dataloaders_every_epoch": false 79 | } -------------------------------------------------------------------------------- /lab9/text_recognizer/artifacts/paragraph_text_recognizer/model.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ef4b7ca47edaf2e7be1eeb6cccbd1c0df864f43de564ea3a668f5019751dffe6 3 | size 546255010 4 | -------------------------------------------------------------------------------- /lab9/text_recognizer/artifacts/paragraph_text_recognizer/run_command.txt: -------------------------------------------------------------------------------- 1 | python training/run_experiment.py --wandb --gpus=-1 --data_class=IAMOriginalAndSyntheticParagraphs --model_class=ResnetTransformer --loss=transformer --batch_size=16 --check_val_every_n_epoch=10 --terminate_on_nan=1 --num_workers=24 --accelerator=ddp --lr=0.0001 --accumulate_grad_batches=4 -------------------------------------------------------------------------------- /lab9/text_recognizer/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import BaseDataset 2 | from .base_data_module import BaseDataModule 3 | from .mnist import MNIST 4 | 5 | # Hide lines below until Lab 2 6 | from .emnist import EMNIST 7 | from .emnist_lines import EMNISTLines 8 | 9 | # Hide lines above until Lab 2 10 | 11 | # Hide lines below until Lab 5 12 | from .emnist_lines2 import EMNISTLines2 13 | from .iam_lines import IAMLines 14 | 15 | # Hide lines above until Lab 5 16 | 17 | # Hide lines below until Lab 7 18 | from .iam_paragraphs import IAMParagraphs 19 | from .iam_original_and_synthetic_paragraphs import IAMOriginalAndSyntheticParagraphs 20 | 21 | # Hide lines above until Lab 7 22 | 23 | # Hide lines below until Lab 8 24 | from .fake_images import FakeImageData 25 | 26 | # Hide lines above until Lab 8 27 | -------------------------------------------------------------------------------- /lab9/text_recognizer/data/emnist_essentials.json: -------------------------------------------------------------------------------- 1 | {"characters": ["", "", "", "

", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} -------------------------------------------------------------------------------- /lab9/text_recognizer/data/fake_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fake images dataset. 3 | """ 4 | import argparse 5 | import torch 6 | import torchvision 7 | from text_recognizer.data.base_data_module import BaseDataModule 8 | 9 | 10 | _NUM_SAMPLES = 512 11 | _IMAGE_LEN = 28 12 | _NUM_CLASSES = 10 13 | 14 | 15 | class FakeImageData(BaseDataModule): 16 | """ 17 | Fake images dataset. 18 | """ 19 | 20 | def __init__(self, args: argparse.Namespace = None): 21 | super().__init__(args) 22 | self.num_samples = self.args.get("num_samples", _NUM_SAMPLES) 23 | self.dims = (1, self.args.get("image_height", _IMAGE_LEN), self.args.get("image_width", _IMAGE_LEN)) 24 | 25 | self.num_classes = self.args.get("num_classes", _NUM_CLASSES) 26 | self.output_dims = (self.num_classes, 1) 27 | self.mapping = list(range(0, self.num_classes)) 28 | 29 | @staticmethod 30 | def add_to_argparse(parser): 31 | BaseDataModule.add_to_argparse(parser) 32 | parser.add_argument("--num_samples", type=int, default=_NUM_SAMPLES) 33 | parser.add_argument("--num_classes", type=int, default=_NUM_CLASSES) 34 | parser.add_argument("--image_height", type=int, default=_IMAGE_LEN) 35 | parser.add_argument("--image_width", type=int, default=_IMAGE_LEN) 36 | return parser 37 | 38 | def setup(self, stage: str = None) -> None: 39 | fake_dataset = torchvision.datasets.FakeData( 40 | size=self.num_samples, 41 | image_size=self.dims, 42 | num_classes=self.output_dims[0], 43 | transform=torchvision.transforms.ToTensor(), 44 | ) 45 | val_size = int(self.num_samples * 0.25) 46 | self.data_train, self.data_val, self.data_test = torch.utils.data.random_split( # type: ignore 47 | dataset=fake_dataset, lengths=[self.num_samples - 2 * val_size, val_size, val_size] 48 | ) 49 | -------------------------------------------------------------------------------- /lab9/text_recognizer/data/mnist.py: -------------------------------------------------------------------------------- 1 | """MNIST DataModule""" 2 | import argparse 3 | 4 | from torch.utils.data import random_split 5 | from torchvision.datasets import MNIST as TorchMNIST 6 | from torchvision import transforms 7 | 8 | from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info 9 | 10 | DOWNLOADED_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" 11 | 12 | # NOTE: temp fix until https://github.com/pytorch/vision/issues/1938 is resolved 13 | from six.moves import urllib # pylint: disable=wrong-import-position, wrong-import-order 14 | 15 | opener = urllib.request.build_opener() 16 | opener.addheaders = [("User-agent", "Mozilla/5.0")] 17 | urllib.request.install_opener(opener) 18 | 19 | 20 | class MNIST(BaseDataModule): 21 | """ 22 | MNIST DataModule. 
23 | Learn more at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html 24 | """ 25 | 26 | def __init__(self, args: argparse.Namespace) -> None: 27 | super().__init__(args) 28 | self.data_dir = DOWNLOADED_DATA_DIRNAME 29 | self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 30 | self.dims = (1, 28, 28) # dims are returned when calling `.size()` on this object. 31 | self.output_dims = (1,) 32 | self.mapping = list(range(10)) 33 | 34 | def prepare_data(self, *args, **kwargs) -> None: 35 | """Download train and test MNIST data from PyTorch canonical source.""" 36 | TorchMNIST(self.data_dir, train=True, download=True) 37 | TorchMNIST(self.data_dir, train=False, download=True) 38 | 39 | def setup(self, stage=None) -> None: 40 | """Split into train, val, test, and set dims.""" 41 | mnist_full = TorchMNIST(self.data_dir, train=True, transform=self.transform) 42 | self.data_train, self.data_val = random_split(mnist_full, [55000, 5000]) # type: ignore 43 | self.data_test = TorchMNIST(self.data_dir, train=False, transform=self.transform) 44 | 45 | 46 | if __name__ == "__main__": 47 | load_and_print_info(MNIST) 48 | -------------------------------------------------------------------------------- /lab9/text_recognizer/data/sentence_generator.py: -------------------------------------------------------------------------------- 1 | """SentenceGenerator class and supporting functions.""" 2 | import itertools 3 | import re 4 | import string 5 | from typing import Optional 6 | 7 | import nltk 8 | import numpy as np 9 | 10 | from text_recognizer.data.base_data_module import BaseDataModule 11 | 12 | NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk" 13 | 14 | 15 | class SentenceGenerator: 16 | """Generate text sentences using the Brown corpus.""" 17 | 18 | def __init__(self, max_length: Optional[int] = None): 19 | self.text = brown_text() 20 | self.word_start_inds = [0] + [_.start(0) + 1 for _ in re.finditer(" ", self.text)] 21 | self.max_length = max_length 22 | 23 | def generate(self, max_length: Optional[int] = None) -> str: 24 | """ 25 | Sample a string from text of the Brown corpus of length at least one word and at most max_length. 
26 | """ 27 | if max_length is None: 28 | max_length = self.max_length 29 | if max_length is None: 30 | raise ValueError("Must provide max_length to this method or when making this object.") 31 | 32 | for _ in range(10): # Try several times to generate before actually erroring 33 | try: 34 | first_ind = np.random.randint(0, len(self.word_start_inds) - 1) 35 | start_ind = self.word_start_inds[first_ind] 36 | end_ind_candidates = [] 37 | for ind in range(first_ind + 1, len(self.word_start_inds)): 38 | if self.word_start_inds[ind] - start_ind > max_length: 39 | break 40 | end_ind_candidates.append(self.word_start_inds[ind]) 41 | end_ind = np.random.choice(end_ind_candidates) 42 | sampled_text = self.text[start_ind:end_ind].strip() 43 | return sampled_text 44 | except Exception: # pylint: disable=broad-except 45 | pass 46 | raise RuntimeError("Was not able to generate a valid string") 47 | 48 | 49 | def brown_text(): 50 | """Return a single string with the Brown corpus with all punctuation stripped.""" 51 | sents = load_nltk_brown_corpus() 52 | text = " ".join(itertools.chain.from_iterable(sents)) 53 | text = text.translate({ord(c): None for c in string.punctuation}) 54 | text = re.sub(" +", " ", text) 55 | return text 56 | 57 | 58 | def load_nltk_brown_corpus(): 59 | """Load the Brown corpus using the NLTK library.""" 60 | nltk.data.path.append(NLTK_DATA_DIRNAME) 61 | try: 62 | nltk.corpus.brown.sents() 63 | except LookupError: 64 | NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) 65 | nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) 66 | return nltk.corpus.brown.sents() 67 | -------------------------------------------------------------------------------- /lab9/text_recognizer/evaluation/evaluate_paragraph_text_recognizer.py: -------------------------------------------------------------------------------- 1 | """Run validation test for paragraph_text_recognizer module.""" 2 | import os 3 | import argparse 4 | import time 5 | import unittest 6 | import torch 7 | import pytorch_lightning as pl 8 | from text_recognizer.data import IAMParagraphs 9 | from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizer 10 | 11 | 12 | _TEST_CHARACTER_ERROR_RATE = 0.17 13 | 14 | 15 | class TestEvaluateParagraphTextRecognizer(unittest.TestCase): 16 | """Evaluate ParagraphTextRecognizer on the IAMParagraphs test dataset.""" 17 | 18 | @torch.no_grad() 19 | def test_evaluate(self): 20 | dataset = IAMParagraphs(argparse.Namespace(batch_size=16, num_workers=10)) 21 | dataset.prepare_data() 22 | dataset.setup() 23 | 24 | text_recog = ParagraphTextRecognizer() 25 | trainer = pl.Trainer(gpus=1) 26 | 27 | start_time = time.time() 28 | metrics = trainer.test(text_recog.lit_model, datamodule=dataset) 29 | end_time = time.time() 30 | 31 | test_cer = round(metrics[0]["test_cer"], 2) 32 | time_taken = round((end_time - start_time) / 60, 2) 33 | 34 | print(f"Character error rate: {test_cer}, time_taken: {time_taken} m") 35 | self.assertEqual(test_cer, _TEST_CHARACTER_ERROR_RATE) 36 | self.assertLess(time_taken, 45) 37 | -------------------------------------------------------------------------------- /lab9/text_recognizer/lit_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseLitModel 2 | 3 | # Hide lines below until Lab 3 4 | from .ctc import CTCLitModel 5 | 6 | # Hide lines above until Lab 3 7 | # Hide lines below until Lab 4 8 | from .transformer import TransformerLitModel 9 | 10 | # Hide lines above until Lab 4 11 | 
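A quick note on the file that follows: lit_models/metrics.py defines a CharacterErrorRate metric that drops ignore tokens (start/end/padding indices) from each sequence and averages normalized Levenshtein distances. As a minimal usage sketch (not part of the repository; it mirrors the self-test at the bottom of that file, and the token indices 0 and 1 are illustrative stand-ins for start/end tokens):

import torch
from text_recognizer.lit_models.metrics import CharacterErrorRate

# Indices 0 and 1 play the role of start/end tokens and are ignored.
metric = CharacterErrorRate(ignore_tokens=[0, 1])
preds = torch.tensor([[0, 2, 1, 1, 1, 1]])    # after dropping 0/1: [2]
targets = torch.tensor([[0, 2, 2, 3, 3, 1]])  # after dropping 0/1: [2, 2, 3, 3]
metric(preds, targets)                        # calling the metric updates its state
print(metric.compute())                       # 3 edits / max(1, 4) -> tensor(0.7500)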
-------------------------------------------------------------------------------- /lab9/text_recognizer/lit_models/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | import editdistance 6 | 7 | 8 | class CharacterErrorRate(pl.metrics.Metric): 9 | """Character error rate metric, computed using Levenshtein distance.""" 10 | 11 | def __init__(self, ignore_tokens: Sequence[int], *args): 12 | super().__init__(*args) 13 | self.ignore_tokens = set(ignore_tokens) 14 | self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum") # pylint: disable=not-callable 15 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") # pylint: disable=not-callable 16 | self.error: torch.Tensor 17 | self.total: torch.Tensor 18 | 19 | def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None: 20 | N = preds.shape[0] 21 | for ind in range(N): 22 | pred = [_ for _ in preds[ind].tolist() if _ not in self.ignore_tokens] 23 | target = [_ for _ in targets[ind].tolist() if _ not in self.ignore_tokens] 24 | distance = editdistance.distance(pred, target) 25 | error = distance / max(len(pred), len(target)) 26 | self.error = self.error + error 27 | self.total = self.total + N 28 | 29 | def compute(self) -> torch.Tensor: 30 | return self.error / self.total 31 | 32 | 33 | def test_character_error_rate(): 34 | metric = CharacterErrorRate([0, 1]) 35 | X = torch.tensor( # pylint: disable=not-callable 36 | [ 37 | [0, 2, 2, 3, 3, 1], # error will be 0 38 | [0, 2, 1, 1, 1, 1], # error will be .75 39 | [0, 2, 2, 4, 4, 1], # error will be .5 40 | ] 41 | ) 42 | Y = torch.tensor( # pylint: disable=not-callable 43 | [ 44 | [0, 2, 2, 3, 3, 1], 45 | [0, 2, 2, 3, 3, 1], 46 | [0, 2, 2, 3, 3, 1], 47 | ] 48 | ) 49 | metric(X, Y) 50 | print(metric.compute()) 51 | assert metric.compute() == sum([0, 0.75, 0.5]) / 3 52 | 53 | 54 | if __name__ == "__main__": 55 | test_character_error_rate() 56 | -------------------------------------------------------------------------------- /lab9/text_recognizer/lit_models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | try: 3 | import wandb 4 | except ModuleNotFoundError: 5 | pass 6 | 7 | 8 | from .metrics import CharacterErrorRate 9 | from .base import BaseLitModel 10 | 11 | 12 | class TransformerLitModel(BaseLitModel): # pylint: disable=too-many-ancestors 13 | """ 14 | Generic PyTorch-Lightning class that must be initialized with a PyTorch module. 15 | 16 | The module must take x, y as inputs, and have a special predict() method. 17 | """ 18 | 19 | def __init__(self, model, args=None): 20 | super().__init__(model, args) 21 | 22 | self.mapping = self.model.data_config["mapping"] 23 | inverse_mapping = {val: ind for ind, val in enumerate(self.mapping)} 24 | start_index = inverse_mapping["<S>"] 25 | end_index = inverse_mapping["<E>"] 26 | padding_index = inverse_mapping["<P>
"] 27 | 28 | self.loss_fn = nn.CrossEntropyLoss(ignore_index=padding_index) 29 | 30 | ignore_tokens = [start_index, end_index, padding_index] 31 | self.val_cer = CharacterErrorRate(ignore_tokens) 32 | self.test_cer = CharacterErrorRate(ignore_tokens) 33 | 34 | def forward(self, x): 35 | return self.model.predict(x) 36 | 37 | def training_step(self, batch, batch_idx): # pylint: disable=unused-argument 38 | x, y = batch 39 | logits = self.model(x, y[:, :-1]) 40 | loss = self.loss_fn(logits, y[:, 1:]) 41 | self.log("train_loss", loss) 42 | return loss 43 | 44 | def validation_step(self, batch, batch_idx): # pylint: disable=unused-argument 45 | x, y = batch 46 | logits = self.model(x, y[:, :-1]) 47 | loss = self.loss_fn(logits, y[:, 1:]) 48 | self.log("val_loss", loss, prog_bar=True) 49 | 50 | pred = self.model.predict(x) 51 | # Hide lines below until Lab 5 52 | pred_str = "".join(self.mapping[_] for _ in pred[0].tolist() if _ != 3) 53 | try: 54 | self.logger.experiment.log({"val_pred_examples": [wandb.Image(x[0], caption=pred_str)]}) 55 | except AttributeError: 56 | pass 57 | # Hide lines above until Lab 5 58 | self.val_cer(pred, y) 59 | self.log("val_cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) 60 | 61 | def test_step(self, batch, batch_idx): # pylint: disable=unused-argument 62 | x, y = batch 63 | pred = self.model.predict(x) 64 | # Hide lines below until Lab 5 65 | pred_str = "".join(self.mapping[_] for _ in pred[0].tolist() if _ != 3) 66 | try: 67 | self.logger.experiment.log({"test_pred_examples": [wandb.Image(x[0], caption=pred_str)]}) 68 | except AttributeError: 69 | pass 70 | # Hide lines above until Lab 5 71 | self.test_cer(pred, y) 72 | self.log("test_cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) 73 | -------------------------------------------------------------------------------- /lab9/text_recognizer/lit_models/util.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | 5 | 6 | def first_element(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor: 7 | """ 8 | Return indices of first occurence of element in x. If not found, return length of x along dim. 
9 | 10 | Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9 11 | 12 | Examples 13 | -------- 14 | >>> first_element(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1]]), 3) 15 | tensor([2, 1, 3]) 16 | """ 17 | nonz = x == element 18 | ind = ((nonz.cumsum(dim) == 1) & nonz).max(dim).indices 19 | ind[ind == 0] = x.shape[dim] 20 | return ind 21 | -------------------------------------------------------------------------------- /lab9/text_recognizer/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import MLP 2 | 3 | # Hide lines below until Lab 2 4 | from .cnn import CNN 5 | 6 | # Hide lines above until Lab 2 7 | 8 | # Hide lines below until Lab 3 9 | from .line_cnn_simple import LineCNNSimple 10 | from .line_cnn import LineCNN 11 | from .line_cnn_lstm import LineCNNLSTM 12 | 13 | # Hide lines above until Lab 3 14 | 15 | # Hide lines below until Lab 4 16 | from .line_cnn_transformer import LineCNNTransformer 17 | 18 | # Hide lines above until Lab 4 19 | 20 | # Hide lines below until Lab 7 21 | from .resnet_transformer import ResnetTransformer 22 | 23 | # Hide lines above until Lab 7 24 | -------------------------------------------------------------------------------- /lab9/text_recognizer/models/line_cnn_lstm.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .line_cnn import LineCNN 7 | 8 | LSTM_DIM = 512 9 | LSTM_LAYERS = 1 10 | LSTM_DROPOUT = 0.2 11 | 12 | 13 | class LineCNNLSTM(nn.Module): 14 | """Process the line through a CNN and process the resulting sequence through LSTM layers.""" 15 | 16 | def __init__( 17 | self, 18 | data_config: Dict[str, Any], 19 | args: argparse.Namespace = None, 20 | ) -> None: 21 | super().__init__() 22 | self.data_config = data_config 23 | self.args = vars(args) if args is not None else {} 24 | 25 | num_classes = len(data_config["mapping"]) 26 | lstm_dim = self.args.get("lstm_dim", LSTM_DIM) 27 | lstm_layers = self.args.get("lstm_layers", LSTM_LAYERS) 28 | lstm_dropout = self.args.get("lstm_dropout", LSTM_DROPOUT) 29 | 30 | self.line_cnn = LineCNN(data_config=data_config, args=args) 31 | # LineCNN outputs (B, C, S) log probs, with C == num_classes 32 | 33 | self.lstm = nn.LSTM( 34 | input_size=num_classes, 35 | hidden_size=lstm_dim, 36 | num_layers=lstm_layers, 37 | dropout=lstm_dropout, 38 | bidirectional=True, 39 | ) 40 | self.fc = nn.Linear(lstm_dim, num_classes) 41 | 42 | def forward(self, x: torch.Tensor) -> torch.Tensor: 43 | """ 44 | Parameters 45 | ---------- 46 | x 47 | (B, H, W) input image 48 | 49 | Returns 50 | ------- 51 | torch.Tensor 52 | (B, C, S) logits, where S is the length of the sequence and C is the number of classes 53 | S can be computed from W and CHAR_WIDTH 54 | C is num_classes 55 | """ 56 | x = self.line_cnn(x) # -> (B, C, S) 57 | B, _C, S = x.shape 58 | x = x.permute(2, 0, 1) # -> (S, B, C) 59 | 60 | x, _ = self.lstm(x) # -> (S, B, 2 * H) where H is lstm_dim 61 | 62 | # Sum up both directions of the LSTM: 63 | x = x.view(S, B, 2, -1).sum(dim=2) # -> (S, B, H) 64 | 65 | x = self.fc(x) # -> (S, B, C) 66 | 67 | return x.permute(1, 2, 0) # -> (B, C, S) 68 | 69 | @staticmethod 70 | def add_to_argparse(parser): 71 | LineCNN.add_to_argparse(parser) 72 | parser.add_argument("--lstm_dim", type=int, default=LSTM_DIM) 73 | parser.add_argument("--lstm_layers", type=int, default=LSTM_LAYERS) 74 | 
parser.add_argument("--lstm_dropout", type=float, default=LSTM_DROPOUT) 75 | return parser 76 | -------------------------------------------------------------------------------- /lab9/text_recognizer/models/line_cnn_simple.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .cnn import CNN, IMAGE_SIZE 9 | 10 | WINDOW_WIDTH = 28 11 | WINDOW_STRIDE = 28 12 | 13 | 14 | class LineCNNSimple(nn.Module): 15 | """LeNet based model that takes a line of width that is a multiple of CHAR_WIDTH.""" 16 | 17 | def __init__( 18 | self, 19 | data_config: Dict[str, Any], 20 | args: argparse.Namespace = None, 21 | ) -> None: 22 | super().__init__() 23 | self.args = vars(args) if args is not None else {} 24 | 25 | self.WW = self.args.get("window_width", WINDOW_WIDTH) 26 | self.WS = self.args.get("window_stride", WINDOW_STRIDE) 27 | self.limit_output_length = self.args.get("limit_output_length", False) 28 | 29 | self.num_classes = len(data_config["mapping"]) 30 | self.output_length = data_config["output_dims"][0] 31 | self.cnn = CNN(data_config=data_config, args=args) 32 | 33 | def forward(self, x: torch.Tensor) -> torch.Tensor: 34 | """ 35 | Parameters 36 | ---------- 37 | x 38 | (B, C, H, W) input image 39 | 40 | Returns 41 | ------- 42 | torch.Tensor 43 | (B, C, S) logits, where S is the length of the sequence and C is the number of classes 44 | S can be computed from W and CHAR_WIDTH 45 | C is self.num_classes 46 | """ 47 | B, _C, H, W = x.shape 48 | assert H == IMAGE_SIZE # Make sure we can use our CNN class 49 | 50 | # Compute number of windows 51 | S = math.floor((W - self.WW) / self.WS + 1) 52 | 53 | # NOTE: type_as properly sets device 54 | activations = torch.zeros((B, self.num_classes, S)).type_as(x) 55 | for s in range(S): 56 | start_w = self.WS * s 57 | end_w = start_w + self.WW 58 | window = x[:, :, :, start_w:end_w] # -> (B, C, H, self.WW) 59 | activations[:, :, s] = self.cnn(window) 60 | 61 | if self.limit_output_length: 62 | # S might not match ground truth, so let's only take enough activations as are expected 63 | activations = activations[:, :, : self.output_length] 64 | return activations 65 | 66 | @staticmethod 67 | def add_to_argparse(parser): 68 | CNN.add_to_argparse(parser) 69 | parser.add_argument( 70 | "--window_width", 71 | type=int, 72 | default=WINDOW_WIDTH, 73 | help="Width of the window that will slide over the input image.", 74 | ) 75 | parser.add_argument( 76 | "--window_stride", 77 | type=int, 78 | default=WINDOW_STRIDE, 79 | help="Stride of the window that will slide over the input image.", 80 | ) 81 | parser.add_argument("--limit_output_length", action="store_true", default=False) 82 | return parser 83 | -------------------------------------------------------------------------------- /lab9/text_recognizer/models/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import argparse 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | FC1_DIM = 1024 10 | FC2_DIM = 128 11 | 12 | 13 | class MLP(nn.Module): 14 | """Simple MLP suitable for recognizing single characters.""" 15 | 16 | def __init__( 17 | self, 18 | data_config: Dict[str, Any], 19 | args: argparse.Namespace = None, 20 | ) -> None: 21 | super().__init__() 22 | self.args = vars(args) if args is not None else {} 23 | 24 | input_dim 
= np.prod(data_config["input_dims"]) 25 | num_classes = len(data_config["mapping"]) 26 | 27 | fc1_dim = self.args.get("fc1", FC1_DIM) 28 | fc2_dim = self.args.get("fc2", FC2_DIM) 29 | 30 | self.dropout = nn.Dropout(0.5) 31 | self.fc1 = nn.Linear(input_dim, fc1_dim) 32 | self.fc2 = nn.Linear(fc1_dim, fc2_dim) 33 | self.fc3 = nn.Linear(fc2_dim, num_classes) 34 | 35 | def forward(self, x): 36 | x = torch.flatten(x, 1) 37 | x = self.fc1(x) 38 | x = F.relu(x) 39 | x = self.dropout(x) 40 | x = self.fc2(x) 41 | x = F.relu(x) 42 | x = self.dropout(x) 43 | x = self.fc3(x) 44 | return x 45 | 46 | @staticmethod 47 | def add_to_argparse(parser): 48 | parser.add_argument("--fc1", type=int, default=1024) 49 | parser.add_argument("--fc2", type=int, default=128) 50 | return parser 51 | -------------------------------------------------------------------------------- /lab9/text_recognizer/tests/support/paragraphs/a01-077.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab9/text_recognizer/tests/support/paragraphs/a01-077.png -------------------------------------------------------------------------------- /lab9/text_recognizer/tests/support/paragraphs/a01-087.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab9/text_recognizer/tests/support/paragraphs/a01-087.png -------------------------------------------------------------------------------- /lab9/text_recognizer/tests/support/paragraphs/a01-107.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab9/text_recognizer/tests/support/paragraphs/a01-107.png -------------------------------------------------------------------------------- /lab9/text_recognizer/tests/support/paragraphs/a02-046.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab9/text_recognizer/tests/support/paragraphs/a02-046.png -------------------------------------------------------------------------------- /lab9/text_recognizer/tests/test_paragraph_text_recognizer.py: -------------------------------------------------------------------------------- 1 | """Test for paragraph_text_recognizer module.""" 2 | import os 3 | import json 4 | from pathlib import Path 5 | import time 6 | import editdistance 7 | from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizer 8 | 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 11 | 12 | 13 | _FILE_DIRNAME = Path(__file__).parents[0].resolve() 14 | _SUPPORT_DIRNAME = _FILE_DIRNAME / "support" / "paragraphs" 15 | 16 | # restricting the number of samples to prevent CircleCI from running out of time 17 | _NUM_MAX_SAMPLES = 2 if os.environ.get("CIRCLECI", False) else 100 18 | 19 | 20 | def test_paragraph_text_recognizer():
 21 | """Test ParagraphTextRecognizer.""" 22 | support_filenames = list(_SUPPORT_DIRNAME.glob("*.png")) 23 | with open(_SUPPORT_DIRNAME / "data_by_file_id.json", "r") as f: 24 | support_data_by_file_id = json.load(f) 25 | 26 | start_time = time.time() 27 | text_recognizer = ParagraphTextRecognizer() 28 | end_time =
time.time() 29 | print(f"Time taken to initialize ParagraphTextRecognizer: {round(end_time - start_time, 2)}s") 30 | 31 | for i, support_filename in enumerate(support_filenames): 32 | if i >= _NUM_MAX_SAMPLES: 33 | break 34 | expected_text = support_data_by_file_id[support_filename.stem]["predicted_text"] 35 | start_time = time.time() 36 | predicted_text = _test_paragraph_text_recognizer(support_filename, expected_text, text_recognizer) 37 | end_time = time.time() 38 | time_taken = round(end_time - start_time, 2) 39 | 40 | cer = _character_error_rate(support_data_by_file_id[support_filename.stem]["ground_truth_text"], predicted_text) 41 | print(f"Character error rate is {round(cer, 3)} for file {support_filename.name} (time taken: {time_taken}s)") 42 | 43 | 44 | def _test_paragraph_text_recognizer(image_filename: Path, expected_text: str, text_recognizer: ParagraphTextRecognizer): 45 | """Test ParagraphTextRecognizer on 1 image.""" 46 | predicted_text = text_recognizer.predict(image_filename) 47 | assert predicted_text == expected_text, f"predicted text does not match expected for {image_filename.name}" 48 | return predicted_text 49 | 50 | 51 | def _character_error_rate(str_a: str, str_b: str) -> float: 52 | """Return character error rate.""" 53 | return editdistance.eval(str_a, str_b) / max(len(str_a), len(str_b)) 54 | -------------------------------------------------------------------------------- /lab9/text_recognizer/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions for text_recognizer module.""" 2 | from io import BytesIO 3 | from pathlib import Path 4 | from typing import Union 5 | from urllib.request import urlretrieve 6 | import base64 7 | import hashlib 8 | 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import numpy as np 12 | import smart_open 13 | 14 | 15 | def to_categorical(y, num_classes): 16 | """1-hot encode a tensor.""" 17 | return np.eye(num_classes, dtype="uint8")[y] 18 | 19 | 20 | def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image: 21 | with smart_open.open(image_uri, "rb") as image_file: 22 | return read_image_pil_file(image_file, grayscale) 23 | 24 | 25 | def read_image_pil_file(image_file, grayscale=False) -> Image: 26 | with Image.open(image_file) as image: 27 | if grayscale: 28 | image = image.convert(mode="L") 29 | else: 30 | image = image.convert(mode=image.mode) 31 | return image 32 | 33 | 34 | # Hide lines below until Lab 9 35 | def read_b64_image(b64_string, grayscale=False): # pylint: disable=unused-argument 36 | """Load base64-encoded images.""" 37 | try: 38 | _, b64_data = b64_string.split(",") # pylint: disable=unused-variable 39 | image_file = BytesIO(base64.b64decode(b64_data)) 40 | return read_image_pil_file(image_file, grayscale) 41 | except Exception as exception: 42 | raise ValueError("Could not load image from b64 {}: {}".format(b64_string, exception)) from exception 43 | 44 | 45 | # Hide lines above until Lab 9 46 | 47 | 48 | def compute_sha256(filename: Union[Path, str]): 49 | """Return SHA256 checksum of a file.""" 50 | with open(filename, "rb") as f: 51 | return hashlib.sha256(f.read()).hexdigest() 52 | 53 | 54 | class TqdmUpTo(tqdm): 55 | """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" 56 | 57 | def update_to(self, blocks=1, bsize=1, tsize=None): 58 | """ 59 | Parameters 60 | ---------- 61 | blocks: int, optional 62 | Number of blocks transferred so far [default: 1]. 
63 | bsize: int, optional 64 | Size of each block (in tqdm units) [default: 1]. 65 | tsize: int, optional 66 | Total size (in tqdm units). If [default: None] remains unchanged. 67 | """ 68 | if tsize is not None: 69 | self.total = tsize # pylint: disable=attribute-defined-outside-init 70 | self.update(blocks * bsize - self.n) # will also set self.n = b * bsize 71 | 72 | 73 | def download_url(url, filename): 74 | """Download a file from url to filename, with a progress bar.""" 75 | with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: 76 | urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec 77 | -------------------------------------------------------------------------------- /lab9/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/lab9/training/__init__.py -------------------------------------------------------------------------------- /lab9/training/sweeps/emnist_lines2_line_cnn_transformer.yml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - ${interpreter} 4 | - ${program} 5 | - "--wandb" 6 | - ${args} 7 | program: training/run_experiment.py 8 | method: random 9 | metric: 10 | goal: minimize 11 | name: val_loss 12 | early_terminate: 13 | type: hyperband 14 | min_iter: 20 15 | parameters: 16 | conv_dim: 17 | values: [32, 64] 18 | window_width: 19 | values: [8, 16] 20 | window_stride: 21 | value: 8 22 | fc_dim: 23 | values: [512, 1024] 24 | tf_dim: 25 | values: [128, 256] 26 | tf_fc_dim: 27 | values: [256, 1024] 28 | tf_nhead: 29 | values: [4, 8] 30 | tf_layers: 31 | values: [2, 4, 6] 32 | lr: 33 | values: [0.01, 0.001, 0.0003] 34 | num_workers: 35 | value: 20 36 | gpus: 37 | value: -1 38 | data_class: 39 | value: EMNISTLines2 40 | model_class: 41 | value: LineCNNTransformer 42 | loss: 43 | value: transformer 44 | precision: 45 | value: 16 46 | -------------------------------------------------------------------------------- /lab9/training/tests/test_run_experiment.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | FAILURE=false 3 | 4 | python training/run_experiment.py --data_class=FakeImageData --model_class=CNN --conv_dim=32 --fc_dim=16 --loss=cross_entropy --num_workers=4 --max_epochs=4 || FAILURE=true 5 | 6 | if [ "$FAILURE" = true ]; then 7 | echo "Test for run_experiment.py failed" 8 | exit 1 9 | fi 10 | echo "Test for run_experiment.py passed" 11 | exit 0 12 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Full Stack Deep Learning Spring 2021 Labs 2 | 3 | Welcome! 4 | 5 | As part of Full Stack Deep Learning 2021, we will incrementally develop a complete deep learning codebase to understand the content of handwritten paragraphs. 6 | 7 | We will use the modern stack of PyTorch and PyTorch-Lightning. 8 | 9 | We will use the main workhorses of DL today: CNNs, RNNs, and Transformers. 10 | 11 | We will manage our experiments using what we believe to be the best tool for the job: Weights & Biases. 12 | 13 | We will set up a continuous integration system for our codebase using CircleCI. 14 | 15 | We will package up the prediction system as a REST API using Flask, and deploy it as a Docker container on AWS Lambda.
16 | 17 | We will set up monitoring that alerts us when the incoming data distribution changes. 18 | 19 | Sequence: 20 | 21 | - [Lab Setup](setup/readme.md): Set up our computing environment. 22 | - [Lab 1: Intro](lab1/readme.md): Formulate problem, structure codebase, train an MLP for MNIST. 23 | - [Lab 2: CNNs](lab2/readme.md): Introduce EMNIST, generate synthetic handwritten lines, and train CNNs. 24 | - [Lab 3: RNNs](lab3/readme.md): Using CNN + LSTM with CTC loss for line text recognition. 25 | - [Lab 4: Transformers](lab4/readme.md): Using Transformers for line text recognition. 26 | - [Lab 5: Experiment Management](lab5/readme.md): Real handwriting data, Weights & Biases, and hyperparameter sweeps. 27 | - [Lab 6: Data Labeling](lab6/readme.md): Label our own handwriting data and properly store it. 28 | - [Lab 7: Paragraph Recognition](lab7/readme.md): Train and evaluate whole-paragraph recognition. 29 | - [Lab 8: Continuous Integration](lab8/readme.md): Add continuous linting and testing of our code. 30 | - [Lab 9: Deployment](lab9/readme.md): Run as a REST API locally, then in Docker, then put in production using AWS Lambda. 31 | - [Lab 10: Monitoring](lab10/readme.md): Set up monitoring that alerts us when the incoming data distribution changes. 32 | -------------------------------------------------------------------------------- /requirements/dev.in: -------------------------------------------------------------------------------- 1 | -c prod.txt 2 | bandit 3 | black 4 | editdistance 5 | itermplot 6 | jupyterlab 7 | matplotlib 8 | mypy 9 | nltk 10 | pycodestyle 11 | pydocstyle 12 | pylint 13 | pytest 14 | pyyaml 15 | tornado 16 | safety 17 | scipy 18 | pillow 19 | wandb 20 | -------------------------------------------------------------------------------- /requirements/prod.in: -------------------------------------------------------------------------------- 1 | boltons 2 | editdistance 3 | flask 4 | h5py 5 | numpy 6 | pytorch-lightning 7 | requests 8 | smart_open 9 | toml 10 | torch<1.8 11 | torchvision<0.9 12 | tqdm 13 | -------------------------------------------------------------------------------- /setup/colab_lab1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/setup/colab_lab1.png -------------------------------------------------------------------------------- /setup/colab_runtime.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/setup/colab_runtime.png -------------------------------------------------------------------------------- /setup/colab_vscode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/setup/colab_vscode.png -------------------------------------------------------------------------------- /setup/colab_vscode_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-full-stack/fsdl-text-recognizer-2021-labs/174ebbdc065442175d9457b7a97d6e065f3d9cd0/setup/colab_vscode_2.png --------------------------------------------------------------------------------
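A closing usage sketch (not part of the repository): with the lab9 Flask server from api_server/app.py running locally on port 8000, a client can POST a base64-encoded image to the /v1/predict route. The payload shape mirrors api_server/tests/test_app.py, and the image path below is one of the repo's test fixtures; the requests library is an assumed client-side dependency.

import base64
import requests  # assumed to be available in the client environment

# One of the repo's test fixture images.
IMAGE_PATH = "lab9/text_recognizer/tests/support/paragraphs/a01-077.png"

with open(IMAGE_PATH, "rb") as f:
    b64_image = base64.b64encode(f.read()).decode()

# The server expects a data-URI-style base64 string under the "image" key.
response = requests.post(
    "http://localhost:8000/v1/predict",
    json={"image": f"data:image/png;base64,{b64_image}"},
)
print(response.json()["pred"])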