├── app_gradio ├── __init__.py ├── logo.jpeg ├── README.md ├── util.py ├── Dockerfile ├── tests │ └── test_app.py └── app.py ├── training ├── __init__.py ├── sweep │ ├── process_setup_output.py │ ├── sweep_setup.py │ ├── simple-overfit-sweep.yaml │ ├── sweep.sh │ └── main-sweep.yaml ├── tests │ ├── test_run_experiment.sh │ ├── test_memorize_ddp.sh │ ├── test_memorize_caption.sh │ └── test_model_development.sh ├── util.py ├── test_model.py ├── cleanup_artifacts.py ├── run_experiment.py └── stage_model.py ├── .aws └── config ├── question_answer ├── evaluation │ ├── best_pica_f1.txt │ └── evaluate_pica.py ├── lit_models │ ├── __init__.py │ ├── metrics.py │ ├── util.py │ └── gpt2.py ├── tests │ ├── support │ │ ├── questions │ │ │ ├── question2.txt │ │ │ ├── question3.txt │ │ │ ├── question1.txt │ │ │ └── question.txt │ │ ├── images │ │ │ ├── img.jpg │ │ │ ├── img1.jpg │ │ │ ├── img2.jpg │ │ │ └── img3.jpg │ │ └── data_by_file_id.json │ ├── test_callback_utils.py │ ├── test_answer.py │ └── test_data.py ├── __init__.py ├── models │ ├── __init__.py │ └── vit2gpt2.py ├── metadata │ ├── __init__.py │ ├── shared.py │ └── pica.py ├── callbacks │ ├── __init__.py │ ├── util.py │ ├── optim.py │ ├── model.py │ └── imtotext.py ├── artifacts │ └── run_command.txt ├── data │ ├── __init__.py │ ├── util.py │ ├── base_data_module.py │ └── pica.py ├── stems │ ├── webcam.py │ └── image.py ├── util.py └── answer.py ├── .dockerignore ├── .env.template ├── api_serverless ├── __init__.py ├── Dockerfile └── api.py ├── assets ├── demo.png └── inference_pipeline.png ├── .devcontainer ├── gpu │ └── devcontainer.json ├── devcontainer.json └── gpu-from-scratch │ ├── devcontainer.json │ └── dev-gpu.Dockerfile ├── deploy ├── aws_login.sh ├── cont_deploy.sh ├── aws_lambda.py └── aws_lambda.ipynb ├── tasks ├── test.sh ├── integration_test.sh └── unit_test.sh ├── environment.yml ├── requirements ├── dev-lint.in ├── prod.in ├── dev.in ├── prod.txt └── dev.txt ├── .github └── workflows │ ├── pre-commit.yml │ ├── docker.yml │ ├── overfit.yml │ └── test.yml ├── pyproject.toml ├── load_test ├── locust_http_user.py └── locust.ipynb ├── LICENSE ├── .gitignore ├── Makefile ├── .pre-commit-config.yaml ├── .flake8 ├── monitoring └── monitor.ipynb ├── CONTRIBUTING.md └── README.md /app_gradio/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.aws/config: -------------------------------------------------------------------------------- 1 | [default] 2 | region = us-west-1 3 | -------------------------------------------------------------------------------- /question_answer/evaluation/best_pica_f1.txt: -------------------------------------------------------------------------------- 1 | 0.9892508056428697 -------------------------------------------------------------------------------- /question_answer/lit_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .gpt2 import GPT2 2 | -------------------------------------------------------------------------------- /question_answer/tests/support/questions/question2.txt: -------------------------------------------------------------------------------- 1 | Am I a person? 
-------------------------------------------------------------------------------- /question_answer/tests/support/questions/question3.txt: -------------------------------------------------------------------------------- 1 | What am I chewing? -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .git 2 | .*_cache 3 | data 4 | **/logs 5 | **/lightning_logs 6 | -------------------------------------------------------------------------------- /.env.template: -------------------------------------------------------------------------------- 1 | OPENAI_API_KEY= 2 | BACKEND_URL= 3 | TOKENIZERS_PARALLELISM=false -------------------------------------------------------------------------------- /question_answer/tests/support/questions/question1.txt: -------------------------------------------------------------------------------- 1 | What color is my shirt? -------------------------------------------------------------------------------- /question_answer/tests/support/questions/question.txt: -------------------------------------------------------------------------------- 1 | How many beds are in the room? -------------------------------------------------------------------------------- /api_serverless/__init__.py: -------------------------------------------------------------------------------- 1 | """Cloud function-backed API for question answering.""" 2 | -------------------------------------------------------------------------------- /assets/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhinh/admirer/HEAD/assets/demo.png -------------------------------------------------------------------------------- /question_answer/__init__.py: -------------------------------------------------------------------------------- 1 | """Modules for creating and running a question answerer.""" 2 | -------------------------------------------------------------------------------- /app_gradio/logo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhinh/admirer/HEAD/app_gradio/logo.jpeg -------------------------------------------------------------------------------- /assets/inference_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhinh/admirer/HEAD/assets/inference_pipeline.png -------------------------------------------------------------------------------- /question_answer/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Models for image captioning.""" 2 | 3 | from .vit2gpt2 import ViT2GPT2 4 | -------------------------------------------------------------------------------- /question_answer/metadata/__init__.py: -------------------------------------------------------------------------------- 1 | """Python definitions of metadata for datasets in question_answer.data.""" 2 | -------------------------------------------------------------------------------- /question_answer/tests/support/images/img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhinh/admirer/HEAD/question_answer/tests/support/images/img.jpg -------------------------------------------------------------------------------- /question_answer/tests/support/images/img1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhinh/admirer/HEAD/question_answer/tests/support/images/img1.jpg -------------------------------------------------------------------------------- /question_answer/tests/support/images/img2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhinh/admirer/HEAD/question_answer/tests/support/images/img2.jpg -------------------------------------------------------------------------------- /question_answer/tests/support/images/img3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhinh/admirer/HEAD/question_answer/tests/support/images/img3.jpg -------------------------------------------------------------------------------- /question_answer/metadata/shared.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | DATA_DIRNAME = Path(__file__).resolve().parents[2] / "data" 4 | DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME / "downloaded" 5 | -------------------------------------------------------------------------------- /question_answer/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import ModelSizeLogger 2 | from .optim import LearningRateMonitor 3 | 4 | from . import imtotext 5 | from .imtotext import ImageToTextTableLogger as ImageToTextLogger 6 | from .imtotext import ImageToTextPrintLogger 7 | -------------------------------------------------------------------------------- /.devcontainer/gpu/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "GPU Development Environment", 3 | 4 | "image": "admirer/development:latest", 5 | 6 | "runArgs": [ 7 | "--gpus", 8 | "all" 9 | ], 10 | 11 | "extensions": [ 12 | "ms-python.python" 13 | ] 14 | } 15 | -------------------------------------------------------------------------------- /question_answer/artifacts/run_command.txt: -------------------------------------------------------------------------------- 1 | python training/run_experiment.py --wandb --gpus=-1 --data_class=PICa --model_class=ViT2GPT2 --batch_size=16 --check_val_every_n_epoch=10 --terminate_on_nan=1 --num_workers=24 --accelerator=ddp --lr=0.0001 --accumulate_grad_batches=4 2 | -------------------------------------------------------------------------------- /deploy/aws_login.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | AWS_REGION=$(aws configure get region) 3 | AWS_ACCOUNT_ID=$(aws sts get-caller-identity --query Account | sed 's/"//g') 4 | aws --region "$AWS_REGION" ecr get-login-password | docker login --username AWS --password-stdin "${AWS_ACCOUNT_ID}.dkr.ecr.${AWS_REGION}.amazonaws.com" -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "GPU Development Environment", 3 | 4 | "image": "admirer/development:latest", 5 | 6 | "extensions": [ 7 | "ms-python.python" 8 | ], 9 | 10 | "postStartCommand": "git config --global --add safe.directory ${containerWorkspaceFolder}" 11 | } 12 | -------------------------------------------------------------------------------- /tasks/test.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -uo pipefail 3 | set +e 4 | 5 | FAILURE=false 6 | 7 | ./tasks/unit_test.sh || FAILURE=true 8 | ./tasks/integration_test.sh || FAILURE=true 9 | 10 | if [ "$FAILURE" = true ]; then 11 | echo "Tests failed" 12 | exit 1 13 | fi 14 | echo "Tests passed" 15 | exit 0 16 | -------------------------------------------------------------------------------- /.devcontainer/gpu-from-scratch/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "GPU Development Environment", 3 | 4 | "build": { 5 | "dockerfile": "dev-gpu.Dockerfile" 6 | }, 7 | 8 | "runArgs": [ 9 | "--gpus", 10 | "all" 11 | ], 12 | 13 | "extensions": [ 14 | "ms-python.python" 15 | ] 16 | } 17 | -------------------------------------------------------------------------------- /tasks/integration_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -uo pipefail 3 | set +e 4 | 5 | FAILURE=false 6 | 7 | ./training/tests/test_model_development.sh || FAILURE=true 8 | 9 | if [ "$FAILURE" = true ]; then 10 | echo "Integration tests failed" 11 | exit 1 12 | fi 13 | echo "Integration tests passed" 14 | exit 0 15 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: admirer 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - python=3.7 # versioned to match Google Colab # version also pinned in Dockerfile 7 | - cudatoolkit=11.3.1 8 | - cudnn=8.3.2 9 | - pip=21.1.3 # versioned to match Google Colab # version also pinned in Dockerfile 10 | -------------------------------------------------------------------------------- /requirements/dev-lint.in: -------------------------------------------------------------------------------- 1 | -c prod.txt 2 | -c dev.txt 3 | bandit 4 | black 5 | darglint 6 | flake8<4 7 | flake8-annotations<2 8 | flake8-bandit 9 | flake8-bugbear 10 | flake8-black 11 | flake8-docstrings 12 | flake8-import-order 13 | mypy==0.960 14 | # mypy version also pinned in .pre-commit-config.yaml 15 | safety 16 | shellcheck-py 17 | types-toml 18 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | # allows this Action to be triggered manually 7 | workflow_dispatch: 8 | 9 | jobs: 10 | pre-commit: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | - uses: actions/setup-python@v3 15 | - uses: pre-commit/action@v3.0.0 16 | -------------------------------------------------------------------------------- /app_gradio/README.md: -------------------------------------------------------------------------------- 1 | ## Outside Knowledge Visual Question Answering 2 | 3 | For more on how this application works, 4 | [check out the GitHub repo](https://github.com/andrewhinh/admirer). 5 | 6 | 7 | ### Flagging 8 | 9 | If the model outputs in the top-right are wrong in some way, 10 | let us know by clicking the "flagging" buttons underneath. 11 | 12 | We'll analyze the results and use them to improve the model! 
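### Programmatic access

If you'd rather skip the UI, the prediction endpoint can be called directly. The sketch below mirrors the request format exercised in `app_gradio/tests/test_app.py`; the URL is a placeholder for wherever the app is being served (the container listens on port 11700), and it assumes the repo's `question_answer.util` helpers are on your `PYTHONPATH`.

```python
import json

import requests

from question_answer import util

# Placeholder URL -- match wherever the app is served (the container listens on port 11700).
APP_URL = "http://localhost:11700/"

# Encode a support-test image and pair it with a question, mirroring app_gradio/tests/test_app.py.
image_b64 = util.encode_b64_image(util.read_image_pil("question_answer/tests/support/images/img.jpg"))
question = "How many beds are in the room?"

payload = json.dumps({"data": ["data:image/png;base64," + image_b64, "data:question/str;str," + question]})
headers = {"Content-Type": "application/json"}
response = requests.post(f"{APP_URL}api/predict", data=payload, headers=headers)
print(response.json()["data"])  # the answer is returned under the "data" key
```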
13 | -------------------------------------------------------------------------------- /app_gradio/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions for app_gradio module.""" 2 | import base64 3 | from io import BytesIO 4 | 5 | 6 | def encode_b64_image(image, format="png"): 7 | """Encode a PIL image as a base64 string.""" 8 | _buffer = BytesIO() # bytes that live in memory 9 | image.save(_buffer, format=format) # but which we write to like a file 10 | encoded_image = base64.b64encode(_buffer.getvalue()).decode("utf8") 11 | return encoded_image 12 | -------------------------------------------------------------------------------- /question_answer/callbacks/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.basicConfig(level=logging.WARNING) 4 | 5 | 6 | def check_and_warn(logger, attribute, feature): 7 | if not hasattr(logger, attribute): 8 | warn_no_attribute(feature, attribute) 9 | return True 10 | 11 | 12 | def warn_no_attribute(blocked_feature, missing_attribute): 13 | logging.warning(f"Unable to log {blocked_feature}: logger does not have attribute {missing_attribute}.") 14 | -------------------------------------------------------------------------------- /question_answer/callbacks/optim.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | 3 | KEY = "optimizer" 4 | 5 | 6 | class LearningRateMonitor(pl.callbacks.LearningRateMonitor): 7 | """Extends Lightning's LearningRateMonitor with a prefix. 8 | 9 | Logs the learning rate during training. See the docs for 10 | pl.callbacks.LearningRateMonitor for details. 11 | """ 12 | 13 | def _add_prefix(self, *args, **kwargs) -> str: 14 | return f"{KEY}/" + super()._add_prefix(*args, **kwargs) 15 | -------------------------------------------------------------------------------- /training/sweep/process_setup_output.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def process(url): 5 | parts = url.split("/") 6 | print(parts[3], parts[-1]) # Entity and Sweep ID 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser(add_help=False) 11 | parser.add_argument( 12 | "--url", 13 | help="The project to log the sweep results to.", 14 | ) 15 | args = parser.parse_args() 16 | 17 | process(args.url) 18 | 19 | 20 | if __name__ == "__main__": 21 | main() 22 | -------------------------------------------------------------------------------- /question_answer/data/__init__.py: -------------------------------------------------------------------------------- 1 | """Module containing submodules for each dataset. 2 | 3 | Each dataset is defined as a class in that submodule. 4 | 5 | The datasets should have a .config method that returns 6 | any configuration information needed by the model. 7 | 8 | Most datasets define their constants in a submodule 9 | of the metadata module that is parallel to this one in the 10 | hierarchy. 
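As a purely illustrative sketch (the names below are hypothetical, not part of
this codebase), such a .config method might look like:

    def config(self):
        return {"input_dims": self.dims, "output_dims": self.output_dims}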
11 | """ 12 | from .util import BaseDataset 13 | from .base_data_module import BaseDataModule 14 | 15 | from .pica import PICa 16 | -------------------------------------------------------------------------------- /question_answer/tests/support/data_by_file_id.json: -------------------------------------------------------------------------------- 1 | { 2 | "img": { 3 | "ground_truth_text": "two", 4 | "predicted_text": "one" 5 | }, 6 | "img1": { 7 | "ground_truth_text": "music", 8 | "predicted_text": "?" 9 | }, 10 | "img2": { 11 | "ground_truth_text": "a water bottle", 12 | "predicted_text": "cigarette and remote" 13 | }, 14 | "img3": { 15 | "ground_truth_text": "guitar songs", 16 | "predicted_text": "i like play guitar" 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /training/tests/test_run_experiment.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -uo pipefail 3 | set +e 4 | 5 | FAILURE=false 6 | 7 | echo "running fast_dev_run test of real model class on real data" 8 | python training/run_experiment.py --data_class=PICa --model_class=ViT2GPT2 \ 9 | --batch_size 2 --lr 0.0001 \ 10 | --fast_dev_run --num_sanity_val_steps 0 \ 11 | --num_workers 1 || FAILURE=true 12 | 13 | if [ "$FAILURE" = true ]; then 14 | echo "Test for run_experiment.py failed" 15 | exit 1 16 | fi 17 | echo "Tests for run_experiment.py passed" 18 | exit 0 19 | -------------------------------------------------------------------------------- /requirements/prod.in: -------------------------------------------------------------------------------- 1 | h5py 2 | importlib-metadata>=4.4 3 | mkl-service==2.4.0 4 | numpy 5 | pyngrok 6 | requests 7 | smart_open[s3] 8 | tqdm 9 | # versioned for stability 10 | gradio==3.0.21 11 | # versioned to match Google Colab up to minor 12 | Jinja2>=2.11,<2.12 13 | pillow<7.2 14 | torch>=1.12,<1.13 15 | torchvision>=0.13.0 16 | # versioned to avoid breaking change in minor version update 17 | markupsafe<2.1 18 | # for models 19 | openai==1.1.1 20 | transformers 21 | timm 22 | onnxruntime==1.12.1 23 | python-dotenv 24 | # for continual learning 25 | wandb==0.12.17 -------------------------------------------------------------------------------- /requirements/dev.in: -------------------------------------------------------------------------------- 1 | -c prod.txt 2 | boltons 3 | coverage[toml] 4 | defusedxml 5 | great-expectations 6 | itermplot 7 | ipywidgets 8 | language_tool_python 9 | matplotlib 10 | notebook 11 | nltk 12 | pre-commit 13 | pytest 14 | pytest-cov 15 | scipy 16 | toml 17 | zenml 18 | # versioned to give pip hints 19 | coverage[toml]==6.4 20 | pytest==7.1.1 21 | pytest-cov==3.0.0 22 | # versioned to match Google Colab 23 | seaborn>=0.11,<0.12 24 | # tornado>=5.1,<5.2 # Doesn't work with notebook 25 | # versioned to improve stability 26 | pytorch-lightning==1.6.3 27 | torchmetrics<0.8 -------------------------------------------------------------------------------- /question_answer/evaluation/evaluate_pica.py: -------------------------------------------------------------------------------- 1 | """Run validation test for question_answer module.""" 2 | import unittest 3 | 4 | from question_answer.answer import Pipeline 5 | 6 | 7 | class TestEvaluatePICa(unittest.TestCase): 8 | """Evaluate Caption on the additionally-added PICa examples.""" 9 | 10 | def test_evaluate(self): 11 | pipeline = Pipeline() 12 | return pipeline.evaluate() 13 | 14 | 15 | def main(): 16 | testcase = 
TestEvaluatePICa() 17 | return testcase.test_evaluate() 18 | 19 | 20 | if __name__ == "__main__": 21 | main() 22 | -------------------------------------------------------------------------------- /tasks/unit_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -uo pipefail 3 | set +e 4 | 5 | FAILURE=false 6 | 7 | export WANDB_PROJECT="admirer" 8 | 9 | # unit tests check whether current best model is working, so we stage it 10 | python ./training/stage_model.py --fetch --from_project "$WANDB_PROJECT" || FAILURE=true 11 | # pytest configuration in pyproject.toml 12 | python -m pytest || FAILURE=true 13 | 14 | ./training/tests/test_run_experiment.sh || FAILURE=true 15 | 16 | if [ "$FAILURE" = true ]; then 17 | echo "Unit tests failed" 18 | exit 1 19 | fi 20 | echo "Unit tests passed" 21 | exit 0 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.flake8] # configured in .flake8 2 | [tool.darglint] # configured in .flake8 3 | 4 | [tool.black] 5 | line-length = 120 6 | target-version = ["py37"] 7 | 8 | [tool.mypy] 9 | ignore_missing_imports = true 10 | exclude = ["training/logs"] 11 | 12 | [tool.pytest.ini_options] 13 | markers = [ 14 | "slow: marks a test as slow (deselect with '-m \"not slow\"']", 15 | "data: marks a test as dependent on a data download (deselect with '-m \"not data\"')" 16 | ] 17 | addopts = "--cov training --cov text_recognizer --cov-branch --doctest-modules --ignore training/logs -m 'not data' --ignore-glob **/bootstrap.py" 18 | -------------------------------------------------------------------------------- /api_serverless/Dockerfile: -------------------------------------------------------------------------------- 1 | # Starting from an official AWS image 2 | # Keep any dependencies and versions in this file aligned with the environment.yml and Makefile 3 | FROM public.ecr.aws/lambda/python:3.7 4 | 5 | # Install Python dependencies 6 | COPY requirements/prod.txt ./requirements.txt 7 | RUN pip install --upgrade pip==21.1.3 8 | RUN pip install -r requirements.txt 9 | 10 | # Copy only the relevant directories and files 11 | # note that we use a .dockerignore file to avoid copying logs etc. 
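# .env is git-ignored, so it must be created locally (presumably from .env.template) before
# building this image -- otherwise the COPY of .env below will fail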
12 | COPY question_answer/ ./question_answer 13 | COPY api_serverless/api.py ./api.py 14 | COPY .env .env 15 | 16 | CMD ["api.handler"] 17 | -------------------------------------------------------------------------------- /load_test/locust_http_user.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from locust import constant, HttpUser, task 4 | 5 | 6 | image_url = "question_answer/tests/support/images/img.jpg" 7 | question = "What color is my hair" 8 | 9 | 10 | class AdmirerUser(HttpUser): 11 | """ 12 | Simulated AWS Lambda User 13 | """ 14 | 15 | wait_time = constant(1) 16 | headers = {"Content-type": "application/json"} 17 | payload = json.dumps({"image_url": image_url, "question": question}) 18 | 19 | @task 20 | def predict(self): 21 | response = self.client.post("/", data=self.payload, headers=self.headers) 22 | pred = response.json()["pred"] 23 | -------------------------------------------------------------------------------- /training/sweep/sweep_setup.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import wandb 5 | import yaml 6 | 7 | wb_api = wandb.Api() 8 | 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser(add_help=False) 12 | parser.add_argument( 13 | "--project", 14 | help="The project to log the sweep results to.", 15 | ) 16 | parser.add_argument( 17 | "--config", 18 | help="The configuration path to set up the sweep with.", 19 | ) 20 | args = parser.parse_args() 21 | 22 | config = yaml.safe_load(Path(args.config).read_text()) 23 | sweep_id = wandb.sweep(sweep=config, project=args.project) 24 | 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /.github/workflows/docker.yml: -------------------------------------------------------------------------------- 1 | name: docker 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'main' 7 | # rebuild whenever the .devcontainer or the requirements change 8 | paths: 9 | - '.devcontainer/gpu-from-scratch/**' 10 | - 'requirements/**' 11 | # allows workflows to be triggered manually 12 | workflow_dispatch: 13 | 14 | jobs: 15 | docker: 16 | runs-on: ubuntu-latest 17 | steps: 18 | - 19 | name: Login to DockerHub 20 | uses: docker/login-action@v2 21 | with: 22 | username: admirer 23 | password: ${{ secrets.DOCKERHUB_PASSWORD }} 24 | - 25 | name: Build and push 26 | uses: docker/build-push-action@v3 27 | with: 28 | push: true 29 | tags: admirer/development:latest 30 | file: .devcontainer/gpu-from-scratch/dev-gpu.Dockerfile 31 | -------------------------------------------------------------------------------- /question_answer/metadata/pica.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import question_answer.metadata.shared as shared 3 | 4 | 5 | ARTIFACT_PATH = Path(__file__).resolve().parents[2] / "question_answer" / "artifacts" / "answer" 6 | RAW_DATA_DIRNAME = ARTIFACT_PATH / "coco_annotations" 7 | PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "pica" 8 | 9 | NUM_ORIGINAL_EXAMPLES = 9009 10 | NUM_ADDED_EXAMPLES = 1236 11 | NUM_TEST_EXAMPLES = 36 12 | 13 | TRAIN_VAL_SPLIT = 0.9 14 | NUM_TRAINVAL = NUM_ADDED_EXAMPLES - NUM_TEST_EXAMPLES 15 | NUM_TRAIN_EXAMPLES = NUM_TRAINVAL * TRAIN_VAL_SPLIT 16 | NUM_VAL_EXAMPLES = NUM_TRAINVAL * (1 - TRAIN_VAL_SPLIT) 17 | 18 | IMAGE_HEIGHT, IMAGE_WIDTH = 224, 224 # Originally = 600, 800 19 | 
IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH) 20 | 21 | MAX_LABEL_LENGTH = 50 22 | 23 | DIMS = (3, IMAGE_HEIGHT, IMAGE_WIDTH) 24 | OUTPUT_DIMS = (MAX_LABEL_LENGTH, 1) 25 | -------------------------------------------------------------------------------- /training/util.py: -------------------------------------------------------------------------------- 1 | """Utilities for model development scripts: training and staging.""" 2 | import argparse 3 | import importlib 4 | 5 | DATA_CLASS_MODULE = "question_answer.data" 6 | MODEL_CLASS_MODULE = "question_answer.models" 7 | 8 | 9 | def import_class(module_and_class_name: str) -> type: 10 | """Import class from a module, e.g. 'text_recognizer.models.MLP'.""" 11 | module_name, class_name = module_and_class_name.rsplit(".", 1) 12 | module = importlib.import_module(module_name) 13 | class_ = getattr(module, class_name) 14 | return class_ 15 | 16 | 17 | def setup_data_and_model_from_args(args: argparse.Namespace): 18 | data_class = import_class(f"{DATA_CLASS_MODULE}.{args.data_class}") 19 | model_class = import_class(f"{MODEL_CLASS_MODULE}.{args.model_class}") 20 | 21 | data = data_class(args) 22 | model = model_class(args=args) 23 | 24 | return data, model 25 | -------------------------------------------------------------------------------- /deploy/cont_deploy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | while : 3 | do 4 | git fetch 5 | if [ "$(git rev-list HEAD...origin/main --count)" != 0 ]; then 6 | echo "Checking if backend needs to be updated..." 7 | git pull 8 | 9 | python3 ./training/stage_model.py --fetch --from_project admirer 10 | 11 | CURRENT_F1_PATH=./question_answer/evaluation/best_pica_f1.txt 12 | CURRENT_F1_SCORE=$(< "$CURRENT_F1_PATH") 13 | NEW_F1=$(python3 ./question_answer/evaluation/evaluate_pica.py) 14 | if [ "$NEW_F1" \> "$CURRENT_F1_SCORE" ]; then 15 | echo "Updating backend..." 16 | 17 | echo "$NEW_F1" >| "$CURRENT_F1_PATH" 18 | 19 | . ./deploy/aws_login.sh 20 | 21 | python3 deploy/aws_lambda.py 22 | else 23 | echo "No improvement -> no updates made" 24 | 25 | rm -rf ./question_answer/artifacts/answer 26 | fi 27 | fi 28 | done 29 | exit 0 -------------------------------------------------------------------------------- /app_gradio/Dockerfile: -------------------------------------------------------------------------------- 1 | # The "buster" flavor of the official docker Python image is based on Debian and includes common packages. 2 | # Keep any dependencies and versions in this file aligned with the environment.yml and Makefile 3 | FROM python:3.7-buster 4 | 5 | # Create the working directory 6 | # set -x prints commands and set -e causes us to stop on errors 7 | RUN set -ex && mkdir /repo 8 | WORKDIR /repo 9 | 10 | # Install Python dependencies 11 | COPY requirements/prod.txt ./requirements.txt 12 | RUN pip install --upgrade pip==21.1.3 13 | RUN pip install -r requirements.txt 14 | ENV PYTHONPATH ".:" 15 | 16 | # Copy only the relevant directories 17 | # note that we use a .dockerignore file to avoid copying logs etc. 
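# as in the serverless image, .env must exist locally before building (see .env.template);
# .aws/config is copied in, presumably so AWS clients inside the container pick up a default region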
18 | COPY question_answer/ ./question_answer 19 | COPY app_gradio/ ./app_gradio 20 | COPY ./.aws/config ./.aws/config 21 | COPY .env .env 22 | 23 | # Use docker run -it --rm -p$PORT:11700 to run the web server and listen on host $PORT 24 | # add --help top see help for the Python script 25 | ENTRYPOINT ["python3", "app_gradio/app.py"] -------------------------------------------------------------------------------- /training/sweep/simple-overfit-sweep.yaml: -------------------------------------------------------------------------------- 1 | # first we specify what we're sweeping 2 | # we specify a program to run 3 | program: training/run_experiment.py 4 | # we optionally specify how to run it, including setting default arguments 5 | command: 6 | - ${env} 7 | - ${interpreter} 8 | - ${program} 9 | - "--wandb" 10 | - "--overfit_batches" 11 | - "1" 12 | - "--log_every_n_steps" 13 | - "25" 14 | - "--max_epochs" 15 | - "100" 16 | - "--limit_test_batches" 17 | - "0" 18 | - ${args} # these arguments come from the sweep parameters below 19 | 20 | # and we specify which parameters to sweep over, what we're optimizing, and how we want to optimize it 21 | method: random # generally, random searches perform well, can also be "grid" or "bayes" 22 | metric: 23 | name: train/loss 24 | goal: minimize 25 | parameters: 26 | # we can also fix some values, just like we set default arguments 27 | gpus: 28 | value: 1 29 | model_class: 30 | value: ViT2GPT2 31 | data_class: 32 | value: PICa -------------------------------------------------------------------------------- /training/tests/test_memorize_ddp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -uo pipefail 3 | set +e 4 | 5 | FAILURE=false 6 | 7 | # constants and CLI args set by aiming for <10 min test on 8xV100 8 | MAX_EPOCHS="${1:-600}" 9 | CRITERION="${2:-0.1}" 10 | 11 | echo "running with configuration tuned on 8xV100" 12 | echo "- note that num_workers > 1 speeds up training but results in multiprocessing errors in terminal" 13 | python ./training/run_experiment.py \ 14 | --data_class=PICa --model_class=ViT2GPT2 \ 15 | --limit_test_batches 0.0 --overfit_batches 1 --num_sanity_val_steps 0 \ 16 | --augment_data false --tf_dropout 0.0 \ 17 | --gpus 8 --precision 16 --strategy=ddp_find_unused_parameters_false --num_workers 1 --batch_size 16 --lr 0.0001 \ 18 | --log_every_n_steps 50 --max_epochs "$MAX_EPOCHS" --wandb || FAILURE=true 19 | 20 | python -c "import json; loss = json.load(open('training/logs/wandb/latest-run/files/wandb-summary.json'))['train/loss']; assert loss < $CRITERION" || FAILURE=true 21 | 22 | if [ "$FAILURE" = true ]; then 23 | echo "Overfitting test failed" 24 | exit 1 25 | fi 26 | echo "Overfitting test passed" 27 | exit 0 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Andrew Hinh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be 
included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /question_answer/lit_models/metrics.py: -------------------------------------------------------------------------------- 1 | """Special-purpose metrics for tracking our model performance.""" 2 | from pathlib import Path 3 | 4 | import torchmetrics 5 | 6 | BERT_SCORE_PATH = ( 7 | Path(__file__).resolve().parents[2] / "question_answer" / "artifacts" / "answer" / "transformers" / "bert_score" 8 | ) 9 | 10 | 11 | class BertF1Score(torchmetrics.text.bert.BERTScore): 12 | """Character error rate metric, allowing for tokens to be ignored.""" 13 | 14 | def __init__(self, model_type=BERT_SCORE_PATH): 15 | super().__init__(model_type) 16 | 17 | def __call__(self, preds, targets): 18 | f1s = super().__call__(preds, targets)["f1"] 19 | return sum(f1s) / len(f1s) 20 | 21 | 22 | def test_bert_f1_score(): 23 | bert_f1 = BertF1Score() 24 | preds = ["hello there", "general kenobi"] 25 | target = ["hello there", "master kenobi"] 26 | f1 = bert_f1(preds, target) 27 | ex_f1s = [0.9999998807907104, 0.9960542917251587] # On main page of torchmetrics page for BERTScore 28 | assert f1 == sum(ex_f1s) / len(ex_f1s) 29 | 30 | 31 | if __name__ == "__main__": 32 | test_bert_f1_score() 33 | -------------------------------------------------------------------------------- /question_answer/stems/webcam.py: -------------------------------------------------------------------------------- 1 | """PICa Stem class.""" 2 | import torchvision.transforms as transforms 3 | 4 | import question_answer.metadata.pica as metadata 5 | from question_answer.stems.image import ImageStem 6 | 7 | 8 | IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH 9 | IMAGE_SHAPE = metadata.IMAGE_SHAPE 10 | MEAN, STD = 0.5, 0.5 11 | 12 | 13 | class WebcamStem(ImageStem): 14 | """A stem for handling webcam screenshots.""" 15 | 16 | def __init__( 17 | self, 18 | augment=False, 19 | ): 20 | super().__init__() 21 | 22 | if not augment: 23 | self.pil_transforms = transforms.Compose([transforms.Resize(IMAGE_SHAPE)]) 24 | else: 25 | # IMAGE_SHAPE is (600, 800) 26 | self.pil_transforms = transforms.Compose( 27 | [ 28 | transforms.Resize(IMAGE_SHAPE), 29 | transforms.RandomHorizontalFlip(), 30 | transforms.RandomApply([transforms.RandomRotation(degrees=20)], p=0.1), 31 | ] 32 | ) 33 | self.torch_transforms = transforms.Compose([transforms.Normalize(mean=MEAN, std=STD)]) 34 | -------------------------------------------------------------------------------- /question_answer/stems/image.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | 4 | 5 | class ImageStem: 6 | """A stem for models operating on images. 7 | 8 | Images are presumed to be provided as PIL images, 9 | as is standard for torchvision Datasets. 
10 | 11 | Transforms are split into two categories: 12 | pil_transforms, which take in and return PIL images, and 13 | torch_transforms, which take in and return Torch tensors. 14 | 15 | By default, these two transforms are both identities. 16 | In between, the images are mapped to tensors. 17 | 18 | The torch_transforms are wrapped in a torch.nn.Sequential 19 | and so are compatible with torchscript if the underyling 20 | Modules are compatible. 21 | """ 22 | 23 | def __init__(self): 24 | self.pil_transforms = transforms.Compose([]) 25 | self.pil_to_tensor = transforms.ToTensor() 26 | self.torch_transforms = torch.nn.Sequential() 27 | 28 | def __call__(self, img): 29 | img = self.pil_transforms(img) 30 | img = self.pil_to_tensor(img) 31 | 32 | with torch.no_grad(): 33 | img = self.torch_transforms(img) 34 | 35 | return img 36 | -------------------------------------------------------------------------------- /training/tests/test_memorize_caption.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -uo pipefail 3 | set +e 4 | 5 | # tests whether we can achieve a criterion loss 6 | # on a single batch within a certain number of epochs 7 | 8 | FAILURE=false 9 | 10 | # constants and CLI args set by aiming for <5 min test on commodity GPU, 11 | # including data download step 12 | MAX_EPOCHS="${1:-100}" # syntax for basic optional arguments in bash 13 | CRITERION="${2:-1.0}" 14 | 15 | # train on GPU if it's available 16 | GPU=$(python -c 'import torch; print(int(torch.cuda.is_available()))') 17 | 18 | python ./training/run_experiment.py \ 19 | --data_class=PICa --model_class=ViT2GPT2 \ 20 | --limit_test_batches 0.0 --overfit_batches 1 --num_sanity_val_steps 0 \ 21 | --augment_data false --tf_dropout 0.0 \ 22 | --gpus "$GPU" --precision 16 --batch_size 16 --lr 0.0001 \ 23 | --log_every_n_steps 25 --max_epochs "$MAX_EPOCHS" --num_workers 2 --wandb || FAILURE=true 24 | 25 | python -c "import json; loss = json.load(open('training/logs/wandb/latest-run/files/wandb-summary.json'))['train/loss']; assert loss < $CRITERION" || FAILURE=true 26 | 27 | if [ "$FAILURE" = true ]; then 28 | echo "Memorization test failed at loss criterion $CRITERION" 29 | exit 1 30 | fi 31 | echo "Memorization test passed at loss criterion $CRITERION" 32 | exit 0 33 | -------------------------------------------------------------------------------- /.github/workflows/overfit.yml: -------------------------------------------------------------------------------- 1 | name: overfit 2 | 3 | on: 4 | # once GPU runners are available, this workflow can be run 5 | # schedule: 6 | # daily, not on hour start, see https://crontab.guru/ 7 | # - cron: '17 1 * * *' 8 | # allows manual triggering of this workflow 9 | workflow_dispatch: 10 | 11 | jobs: 12 | 13 | overfit: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v3 18 | - name: Set up Python 3.7 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: "3.7" 22 | - name: Full Python environment cacheing 23 | # see AI2 blogpost for details: https://blog.allenai.org/python-caching-in-github-actions-e9452698e98d 24 | uses: actions/cache@v2 25 | with: 26 | path: ${{ env.pythonLocation }} 27 | key: v1-${{ env.pythonLocation }}-${{ hashFiles('requirements/dev.txt') }}-${{ hashFiles('requirements/prod.txt') }} 28 | - name: Install dependencies with pip 29 | run: | 30 | pip install --quiet -r requirements/prod.txt -r requirements/dev.txt 31 | - name: Run overfitting test 32 | run: | 33 | 
./training/tests/test_memorize_caption.sh 34 | env: 35 | PYTHONPATH: . 36 | WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} 37 | WANDB_PROJECT: admirer-ci 38 | -------------------------------------------------------------------------------- /app_gradio/tests/test_app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import requests 5 | 6 | from app_gradio import app 7 | from question_answer import util 8 | 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 11 | 12 | 13 | TEST_IMAGE = "question_answer/tests/support/images/img.jpg" 14 | TEST_QUESTION = "question_answer/tests/support/questions/question.txt" 15 | if os.path.exists(TEST_QUESTION): 16 | with open(TEST_QUESTION, "r") as f: 17 | TEST_QUESTION = f.readline() 18 | 19 | 20 | def test_local_run(): 21 | """A quick test to make sure we can build the app and ping the API locally.""" 22 | backend = app.PredictorBackend() 23 | frontend = app.make_frontend(fn=backend.run) 24 | 25 | # run the UI without blocking 26 | frontend.launch(share=False, prevent_thread_lock=True) 27 | local_url = frontend.local_url 28 | get_response = requests.get(local_url) 29 | assert get_response.status_code == 200 30 | 31 | image_b64 = util.encode_b64_image(util.read_image_pil(TEST_IMAGE)) 32 | 33 | local_api = f"{local_url}api/predict" 34 | headers = {"Content-Type": "application/json"} 35 | payload = json.dumps({"data": ["data:image/png;base64," + image_b64, "data:question/str;str," + TEST_QUESTION]}) 36 | post_response = requests.post(local_api, data=payload, headers=headers) 37 | assert "error" not in post_response.json() 38 | assert "data" in post_response.json() 39 | -------------------------------------------------------------------------------- /training/sweep/sweep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "Getting entity and sweep ID" 3 | 4 | 5 | # Setting default values 6 | DEFAULT_PROJECT="admirer-training" 7 | DEFAULT_SWEEP_CONFIG="training/sweep/main-sweep.yaml" 8 | 9 | 10 | # Getting arguments through flags 11 | while getopts p:c: flag 12 | do 13 | case "${flag}" in 14 | p) project=${OPTARG};; 15 | c) config=${OPTARG};; 16 | *);; 17 | esac 18 | done 19 | 20 | # Setting project and config values 21 | if [ -z "${project}" ]; then 22 | PROJECT=$DEFAULT_PROJECT 23 | else 24 | PROJECT=${project} 25 | fi 26 | 27 | if [ -z "${config}" ]; then 28 | SWEEP_CONFIG=$DEFAULT_SWEEP_CONFIG 29 | else 30 | SWEEP_CONFIG=${config} 31 | fi 32 | 33 | 34 | # Getting entity and sweep ID 35 | OUTPUT=$(python training/sweep/sweep_setup.py --project "$PROJECT" --config "$SWEEP_CONFIG") 36 | OUTPUT=$(echo "$OUTPUT" | cut -d' ' -f3 | sed -n '2p') 37 | OUTPUT=$(python training/sweep/process_setup_output.py --url "$OUTPUT") 38 | ENTITY="$(echo "$OUTPUT" | cut -d' ' -f1)" 39 | SWEEP_ID="$(echo "$OUTPUT" | cut -d' ' -f2)" 40 | 41 | # Exporting variables for get access in tmux 42 | export PROJECT="$PROJECT" 43 | export ENTITY="$ENTITY" 44 | export SWEEP_ID="$SWEEP_ID" 45 | 46 | # Start a tmux and for every GPU, change GPU_IDX accordingly, create a new window, and run: 47 | # CUDA_VISIBLE_DEVICES=GPU_IDX wandb agent --project ${PROJECT} --entity ${ENTITY} ${SWEEP_ID} 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Data 2 | data/downloaded 3 | data/processed 4 | data/interim 5 | 6 | # Editors 7 | .vscode 8 | *.sw? 
9 | *~ 10 | 11 | # Node 12 | node_modules 13 | 14 | # Python 15 | __pycache__ 16 | .pytest_cache 17 | 18 | # notebooks 19 | .ipynb_checkpoints 20 | *.nbconvert*.ipynb 21 | .notebook_test.sh 22 | 23 | # Distribution / packaging 24 | .Python 25 | env/ 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | 41 | # logging 42 | wandb 43 | *.pt 44 | *.ckpt 45 | lightning_logs/ 46 | logs 47 | */training/logs 48 | */training/*sweep.yaml 49 | flagged 50 | 51 | # Misc 52 | .aws/credentials 53 | .DS_Store 54 | .env 55 | .mypy_cache 56 | .coverage* 57 | # /requirements.txt 58 | requirements/dev-lint.txt 59 | bootstrap.py 60 | **/fixme.py 61 | 62 | #ADDED################################### 63 | # Training/context data 64 | /data/downloaded/ 65 | 66 | # Label Environment 67 | /data_manage/label-env/ 68 | 69 | # Model Checkpoint 70 | /training/model.pth 71 | 72 | # Production files 73 | /question_answer/artifacts/answer/coco_annotations/ 74 | /question_answer/artifacts/answer/coco_clip_new/ 75 | /question_answer/artifacts/answer/onnx/ 76 | /question_answer/artifacts/answer/transformers/ 77 | 78 | # OpenAI API Key 79 | .env 80 | 81 | # ZenML local config 82 | .zen/ 83 | 84 | #ADDED################################### -------------------------------------------------------------------------------- /question_answer/tests/test_callback_utils.py: -------------------------------------------------------------------------------- 1 | """Tests for the text_recognizer.callbacks.util module.""" 2 | import random 3 | import string 4 | import tempfile 5 | 6 | import pytorch_lightning as pl 7 | 8 | from question_answer.callbacks.util import check_and_warn 9 | 10 | 11 | def test_check_and_warn_simple(): 12 | """Test the success and failure in the case of a simple class we control.""" 13 | 14 | class Foo: 15 | pass # a class with no special attributes 16 | 17 | letters = string.ascii_lowercase 18 | random_attribute = "".join(random.choices(letters, k=10)) 19 | assert check_and_warn(Foo(), random_attribute, "random feature") 20 | assert not check_and_warn(Foo(), "__doc__", "feature of all Python objects") 21 | 22 | 23 | def test_check_and_warn_tblogger(): 24 | """Test that we return a truthy value when trying to log tables with TensorBoard. 25 | 26 | We added check_and_warn in order to prevent a crash if this happens. 27 | """ 28 | tblogger = pl.loggers.TensorBoardLogger(save_dir=tempfile.TemporaryDirectory()) 29 | assert check_and_warn(tblogger, "log_table", "tables") 30 | 31 | 32 | def test_check_and_warn_wandblogger(): 33 | """Test that we return a falsy value when we try to log tables with W&B. 34 | 35 | In adding check_and_warn, we don't want to block the feature in the happy path. 
36 | """ 37 | wandblogger = pl.loggers.WandbLogger(anonymous=True) 38 | assert not check_and_warn(wandblogger, "log_table", "tables") 39 | -------------------------------------------------------------------------------- /training/sweep/main-sweep.yaml: -------------------------------------------------------------------------------- 1 | # first we specify what we're sweeping 2 | # we specify a program to run 3 | program: training/run_experiment.py 4 | # we optionally specify how to run it, including setting default arguments 5 | command: 6 | - ${env} 7 | - ${interpreter} 8 | - ${program} 9 | - "--wandb" 10 | - "--limit_test_batches" # Experiment setup 11 | - "0" 12 | - "--log_every_n_steps" 13 | - "25" 14 | - "--max_epochs" 15 | - "10" 16 | - "--augment_data" # First things to check if training errors occur 17 | - "True" 18 | - "--num_workers" 19 | - "1" 20 | - ${args} # these arguments come from the sweep parameters below 21 | 22 | # and we specify which parameters to sweep over, what we're optimizing, and how we want to optimize it 23 | method: random # generally, random searches perform well, can also be "grid" or "bayes" 24 | metric: 25 | name: validation/loss 26 | goal: minimize 27 | parameters: 28 | batch_size: 29 | values: [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] 30 | one_cycle_max_lr: 31 | max: 1.0 32 | min: 0.000001 33 | top_k: 34 | max: 1000 35 | min: 1 36 | top_p: 37 | max: 1.00 38 | min: 0.01 39 | max_label_length: 40 | max: 100 41 | min: 1 42 | # we can also fix some values, just like we set default arguments 43 | gpus: 44 | value: 1 45 | model_class: 46 | value: ViT2GPT2 47 | data_class: 48 | value: PICa -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Arcane incantation to print all the other targets, from https://stackoverflow.com/a/26339924 2 | help: 3 | @$(MAKE) -pRrq -f $(lastword $(MAKEFILE_LIST)) : 2>/dev/null | awk -v RS= -F: '/^# File/,/^# Finished Make data base/ {if ($$1 !~ "^[#.]") {print $$1}}' | sort | egrep -v -e '^[^[:alnum:]]' -e '^$@$$' 4 | 5 | # Install exact Python and CUDA versions 6 | conda-update: 7 | conda env update --prune -f environment.yml 8 | echo "!!!RUN THE conda activate COMMAND ABOVE RIGHT NOW!!!" 
9 | 10 | # Compile and install exact pip packages 11 | pip-tools: 12 | pip install pip-tools==6.3.1 setuptools==59.5.0 13 | pip-compile requirements/prod.in && pip-compile requirements/dev.in 14 | pip-sync requirements/prod.txt requirements/dev.txt 15 | 16 | # Compile and install the requirements for local linting (optional) 17 | pip-tools-lint: 18 | pip install pip-tools==6.3.1 setuptools==59.5.0 19 | pip-compile requirements/prod.in && pip-compile requirements/dev.in && pip-compile requirements/dev-lint.in 20 | pip-sync requirements/prod.txt requirements/dev.txt requirements/dev-lint.txt 21 | 22 | # Bump versions of transitive dependencies 23 | pip-tools-upgrade: 24 | pip install pip-tools==6.3.1 setuptools==59.5.0 25 | pip-compile --upgrade requirements/prod.in && pip-compile --upgrade requirements/dev.in && pip-compile --upgrade requirements/dev-lint.in 26 | pip-sync requirements/prod.txt requirements/dev.txt requirements/dev-lint.txt 27 | 28 | # Example training command 29 | train-pica-vit2gpt2-ddp: 30 | python training/run_experiment.py --max_epochs=10 --gpus=-1 --accelerator=ddp --num_workers=20 --data_class=PICa --model_class=ViT2GPT2 31 | 32 | # Lint 33 | lint: 34 | tasks/lint.sh -------------------------------------------------------------------------------- /.devcontainer/gpu-from-scratch/dev-gpu.Dockerfile: -------------------------------------------------------------------------------- 1 | # use nvidia cuda/cudnn image with miniconda on top 2 | FROM gpuci/miniconda-cuda:11.3-devel-ubuntu18.04 3 | 4 | # update GPG key and install linux development CLI tools 5 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub \ 6 | && apt update \ 7 | && apt install -y \ 8 | git \ 9 | make \ 10 | sed \ 11 | tmux \ 12 | vim \ 13 | wget 14 | 15 | # allow history search in terminal 16 | RUN echo "\"\e[A\": history-search-backward" > $HOME/.inputrc && echo "\"\e[B\": history-search-forward" $HOME/.inputrc 17 | 18 | # move into the root user's home directory 19 | WORKDIR /root 20 | 21 | # install core Python environment and system packages 22 | COPY ./Makefile ./environment.yml ./ 23 | RUN make conda-update 24 | 25 | # switch to a login shell after cleaning up config: 26 | # removing error-causing line in /root/.profile, see https://www.educative.io/answers/error-mesg-ttyname-failed-inappropriate-ioctl-for-device 27 | # removing environment-setting in /root/.bashrc 28 | RUN sed -i "s/mesg n || true/tty -s \&\& mesg n/" $HOME/.profile 29 | RUN sed -i "s/conda activate base//" $HOME/.bashrc 30 | SHELL ["conda", "run", "--no-capture-output", "-n", "admirer", "/bin/bash", "-c"] 31 | 32 | # install the core requirements, then remove build files 33 | COPY ./requirements ./requirements 34 | RUN make pip-tools && rm -rf ./Makefile ./requirements ./environment.yml 35 | 36 | # add current dir to PYTHONPATH so libraries are importable 37 | ENV PYTHONPATH=.:$PYTHONPATH 38 | 39 | # run all commands inside the conda environment 40 | ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "admirer", "/bin/bash"] 41 | -------------------------------------------------------------------------------- /training/tests/test_model_development.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -uo pipefail 3 | set +e 4 | 5 | FAILURE=false 6 | 7 | CI="${CI:-false}" 8 | if [ "$CI" = false ]; then 9 | export WANDB_PROJECT="admirer-testing" 10 | else 11 | export WANDB_PROJECT="admirer-testing-ci" 12 | 
fi 13 | 14 | echo "training smaller version of real model class on real data" 15 | python training/run_experiment.py --data_class=PICa --model_class=ViT2GPT2 \ 16 | --batch_size 2 --lr 0.0001 \ 17 | --limit_train_batches 1 --limit_val_batches 1 --limit_test_batches 1 --num_sanity_val_steps 0 \ 18 | --num_workers 1 --wandb || FAILURE=true 19 | 20 | TRAIN_RUN=$(find ./training/logs/wandb/latest-run/* | grep -Eo "run-([[:alnum:]])+\.wandb" | sed -e "s/^run-//" -e "s/\.wandb//") 21 | 22 | echo "staging trained model from run $TRAIN_RUN" 23 | python training/stage_model.py --run "$TRAIN_RUN" --staged_model_name test-dummy --ckpt_alias latest --to_project "$WANDB_PROJECT" --from_project "$WANDB_PROJECT" || FAILURE=true 24 | 25 | echo "fetching staged model" 26 | python training/stage_model.py --fetch --staged_model_name test-dummy --from_project "$WANDB_PROJECT" || FAILURE=true 27 | STAGE_RUN=$(find ./training/logs/wandb/latest-run/* | grep -Eo "run-([[:alnum:]])+\.wandb" | sed -e "s/^run-//" -e "s/\.wandb//") 28 | 29 | if [ "$FAILURE" = true ]; then 30 | echo "Model development test failed" 31 | echo "cleaning up local files" 32 | rm -rf question_answer/artifacts/test-dummy 33 | echo "leaving remote files in place" 34 | exit 1 35 | fi 36 | echo "cleaning up local and remote files" 37 | rm -rf question_answer/artifacts/test-dummy 38 | python training/cleanup_artifacts.py --project "$WANDB_PROJECT" --run_ids "$TRAIN_RUN" "$STAGE_RUN" --all -v 39 | # note: if $TRAIN_RUN and $STAGE_RUN are not set, this will fail. 40 | # that's good because it avoids all artifacts from the project being deleted due to the --all. 41 | echo "Model development test passed" 42 | exit 0 43 | -------------------------------------------------------------------------------- /question_answer/models/vit2gpt2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from transformers import EncoderDecoderModel, GPT2Tokenizer 5 | import torch.nn as nn 6 | 7 | SAVE_PATH = Path(__file__).resolve().parents[2] / "question_answer" / "artifacts" / "answer" / "transformers" 8 | VIT_MODEL = "google/vit-base-patch16-224-in21k" 9 | DISTIL_GPT2 = "distilgpt2" 10 | 11 | 12 | class ViT2GPT2(nn.Module): 13 | """Pass an image through a ViT and decode the resulting embedding with GPT-2.""" 14 | 15 | def __init__( 16 | self, 17 | args: argparse.Namespace = None, 18 | ) -> None: 19 | super().__init__() 20 | # Arguments 21 | self.args = vars(args) if args is not None else {} 22 | self.encoder_path = self.args.get("encoder_path", SAVE_PATH / VIT_MODEL) 23 | self.decoder_and_tokenizer_path = self.args.get("decoder_and_tokenizer_path", SAVE_PATH / DISTIL_GPT2) 24 | 25 | # model 26 | self.vit2gpt2 = EncoderDecoderModel.from_encoder_decoder_pretrained( 27 | self.encoder_path, self.decoder_and_tokenizer_path 28 | ) 29 | 30 | # tokenizer 31 | # make sure GPT2 appends EOS in begin and end 32 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 33 | outputs = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] 34 | return outputs 35 | 36 | GPT2Tokenizer.build_inputs_with_special_tokens = build_inputs_with_special_tokens 37 | gpt2_tokenizer = GPT2Tokenizer.from_pretrained(self.decoder_and_tokenizer_path) 38 | # set pad_token_id to unk_token_id -> be careful here as unk_token_id == eos_token_id == bos_token_id 39 | gpt2_tokenizer.pad_token = gpt2_tokenizer.unk_token 40 | self.gpt2_tokenizer = gpt2_tokenizer 41 | 42 | @staticmethod 43 | def 
add_to_argparse(parser): 44 | parser.add_argument("--encoder_path", type=Path, default=SAVE_PATH / VIT_MODEL) 45 | parser.add_argument("--decoder_and_tokenizer_path", type=Path, default=SAVE_PATH / DISTIL_GPT2) 46 | return parser 47 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # a set of useful Python-based pre-commit hooks 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v4.1.0 5 | hooks: 6 | # list of definitions and supported hooks: https://pre-commit.com/hooks.html 7 | - id: trailing-whitespace # removes any whitespace at the ends of lines 8 | - id: check-toml # check toml syntax by loading all toml files 9 | - id: check-yaml # check yaml syntax by loading all yaml files 10 | - id: check-json # check-json syntax by loading all json files 11 | - id: check-merge-conflict # check for files with merge conflict strings 12 | args: ["--assume-in-merge"] # and run this check even when not explicitly in a merge 13 | - id: check-added-large-files # check that no "large" files have been added 14 | args: ["--maxkb=10240"] # where large means 10MB+, as in Hugging Face's git server 15 | - id: debug-statements # check for python debug statements (import pdb, breakpoint, etc.) 16 | - id: detect-private-key # checks for private keys (BEGIN X PRIVATE KEY, etc.) 17 | 18 | # black python autoformatting 19 | - repo: https://github.com/psf/black 20 | rev: 22.3.0 21 | hooks: 22 | - id: black 23 | # additional configuration of black in pyproject.toml 24 | 25 | # flake8 python linter with all the fixins 26 | - repo: https://github.com/PyCQA/flake8 27 | rev: 3.9.2 28 | hooks: 29 | - id: flake8 30 | additional_dependencies: 31 | [ 32 | flake8-annotations, 33 | flake8-bandit, 34 | flake8-bugbear, 35 | flake8-black, 36 | flake8-docstrings, 37 | flake8-import-order, 38 | darglint, 39 | mypy==0.960, 40 | pycodestyle, 41 | pydocstyle, 42 | ] 43 | args: ["--config", ".flake8"] 44 | # additional configuration of flake8 and extensions in .flake8 45 | 46 | # shellcheck-py for linting shell files 47 | - repo: https://github.com/shellcheck-py/shellcheck-py 48 | rev: v0.8.0.4 49 | hooks: 50 | - id: shellcheck 51 | -------------------------------------------------------------------------------- /question_answer/data/util.py: -------------------------------------------------------------------------------- 1 | """Base Dataset class.""" 2 | from typing import Any, Callable, Dict, Sequence, Tuple, Union 3 | 4 | from PIL import Image 5 | import torch 6 | 7 | 8 | SequenceOrTensor = Union[Sequence, torch.Tensor] 9 | 10 | 11 | class BaseDataset(torch.utils.data.Dataset): 12 | """Base Dataset class that simply processes data and targets through optional transforms. 
13 | 14 | Read more: https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset 15 | 16 | Parameters 17 | ---------- 18 | data 19 | commonly these are torch tensors, numpy arrays, or PIL Images 20 | targets 21 | commonly these are torch tensors or numpy arrays 22 | transform 23 | function that takes a datum and returns the same 24 | target_transform 25 | function that takes a target and returns the same 26 | """ 27 | 28 | def __init__( 29 | self, 30 | data: SequenceOrTensor, 31 | targets: SequenceOrTensor, 32 | transform: Callable = None, 33 | target_transform: Callable = None, 34 | ) -> None: 35 | if len(data) != len(targets): 36 | raise ValueError("Data and targets must be of equal length") 37 | super().__init__() 38 | self.data = data 39 | self.targets = targets 40 | self.transform = transform 41 | self.target_transform = target_transform 42 | 43 | def __len__(self) -> int: 44 | """Return length of the dataset.""" 45 | return len(self.data) 46 | 47 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 48 | """ 49 | Return a datum and its target, after processing by transforms. 50 | 51 | Parameters 52 | ---------- 53 | index 54 | 55 | Returns 56 | ------- 57 | (datum, target) 58 | """ 59 | datum, target = self.data[index], self.targets[index] 60 | 61 | if self.transform is not None: 62 | datum = self.transform(datum) 63 | 64 | if self.target_transform is not None: 65 | target = self.target_transform(target) 66 | 67 | return datum, target 68 | -------------------------------------------------------------------------------- /deploy/aws_lambda.py: -------------------------------------------------------------------------------- 1 | # Imports 2 | import os 3 | import subprocess 4 | 5 | 6 | def main(): 7 | # Specify ECR repo name 8 | os.environ["LAMBDA_NAME"] = "admirer-backend" 9 | 10 | # Get image URI 11 | proc = subprocess.run( 12 | [ 13 | "aws", 14 | "sts", 15 | "get-caller-identity", 16 | "--query", 17 | "Account", 18 | ], 19 | stdout=subprocess.PIPE, 20 | text=True, 21 | ) 22 | aws_account_id = proc.stdout 23 | proc = subprocess.run( 24 | [ 25 | "aws", 26 | "configure", 27 | "get", 28 | "region", 29 | ], 30 | stdout=subprocess.PIPE, 31 | text=True, 32 | ) 33 | aws_region = proc.stdout 34 | os.environ["AWS_REGION"] = aws_region.strip("\n") 35 | os.environ["AWS_ACCOUNT_ID"] = aws_account_id.replace('"', "").strip("\n") 36 | os.environ["ECR_URI"] = ".".join( 37 | [os.environ["AWS_ACCOUNT_ID"], "dkr", "ecr", os.environ["AWS_REGION"], "amazonaws.com"] 38 | ) 39 | os.environ["IMAGE_URI"] = "/".join([os.environ["ECR_URI"], os.environ["LAMBDA_NAME"]]) 40 | 41 | # Build container image 42 | subprocess.run( 43 | [ 44 | "docker", 45 | "build", 46 | "--no-cache", 47 | "-t", 48 | os.environ["LAMBDA_NAME"], 49 | ".", 50 | "--file", 51 | "./api_serverless/Dockerfile", 52 | ] 53 | ) 54 | 55 | # Upload to the container registry 56 | subprocess.run(["docker", "tag", os.environ["LAMBDA_NAME"] + ":latest", os.environ["IMAGE_URI"] + ":latest"]) 57 | subprocess.run(["docker", "push", os.environ["IMAGE_URI"] + ":latest"]) 58 | 59 | # Update the AWS Lambda function accordingly 60 | proc = subprocess.run( 61 | [ 62 | "aws", 63 | "lambda", 64 | "update-function-code", 65 | "--function-name", 66 | os.environ["LAMBDA_NAME"], 67 | "--image-uri", 68 | os.environ["IMAGE_URI"] + ":latest", 69 | ], 70 | ) 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /.flake8: 
-------------------------------------------------------------------------------- 1 | [flake8] 2 | select = ANN,B,B9,BLK,C,D,E,F,I,S,W 3 | # only check selected error codes 4 | max-complexity = 12 5 | # C9 - flake8 McCabe Complexity checker -- threshold 6 | max-line-length = 120 7 | # E501 - flake8 -- line length too long, actually handled by black 8 | extend-ignore = 9 | # E W - flake8 PEP style check 10 | E203,E402,E501,W503, # whitespace, import, line length, binary operator line breaks 11 | # S - flake8-bandit safety check 12 | S101,S311,S105, # assert removed in bytecode, pRNG not secure, hardcoded password 13 | # ANN - flake8-annotations type annotation check 14 | ANN,ANN002,ANN003,ANN101,ANN102,ANN202, # ignore all for now, but always ignore some 15 | # D1 - flake8-docstrings docstring style check 16 | D100,D102,D103,D104,D105, # missing docstrings 17 | # D2 D4 - flake8-docstrings docstring style check 18 | D200,D205,D400,D401, # whitespace issues and first line content 19 | # DAR - flake8-darglint docstring correctness check 20 | DAR103, # mismatched or missing type in docstring 21 | application-import-names = app_gradio,question_answer,tests,training 22 | # flake8-import-order: which names are first party? 23 | import-order-style = google 24 | # flake8-import-order: which import order style guide do we use? 25 | docstring-convention = numpy 26 | # flake8-docstrings: which docstring style guide do we use? 27 | strictness = short 28 | # darglint: how "strict" are we with docstring completeness? 29 | docstring-style = numpy 30 | # darglint: which docstring style guide do we use? 31 | suppress-none-returning = true 32 | # flake8-annotations: do we allow un-annotated Nones in returns? 33 | mypy-init-return = true 34 | # flake8-annotations: do we allow init to have no return annotation? 
35 | per-file-ignores = 36 | # list of case-by-case ignores, see files for details 37 | */__init__.py:F401,I 38 | */data/*.py:DAR 39 | data/*.py:F,I 40 | 41 | *question_answer/util.py:DAR101,F401 42 | *app_gradio/app.py:I202 43 | 44 | # Added 45 | *question_answer/*:I100,I202,I201,DAR101,F401,DAR201,DAR002 46 | *question_answer/answer.py:C901 47 | *training/run_experiment.py:I100,I202 48 | *training/sweep/sweep_setup.py:F841 49 | *load_test/locust_http_user.py:F841 50 | *deploy/aws_lambda.py:S404,S607,S602,S603 51 | *app_gradio/app.py:S113 52 | -------------------------------------------------------------------------------- /monitoring/monitor.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Model Monitoring with Gradio\n", 8 | "- Run after local Gradio app with flagging enabled is run and a result has been flagged" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "from pathlib import Path\n", 18 | "from IPython.display import display\n", 19 | "\n", 20 | "import pandas as pd\n", 21 | "import sys\n", 22 | "sys.path.append(\"../\")\n", 23 | "\n", 24 | "from question_answer.util import read_image_pil" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "log_path = Path(\"../flagged\") / \"log.csv\"\n", 34 | "\n", 35 | "flagged_df = None\n", 36 | "if log_path.exists():\n", 37 | " flagged_df = pd.read_csv(log_path, quotechar=\"'\") # quoting can be painful for natural text data\n", 38 | " flagged_df = flagged_df.dropna(subset=[\"Webcam Image\"]) # drop any flags without an image\n", 39 | "\n", 40 | "flagged_df" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "if flagged_df is not None:\n", 50 | " row = flagged_df.iloc[-1]\n", 51 | " print(row[\"output\"])\n", 52 | " display(read_image_pil(Path(\"../flagged\") / row[\"Webcam Image\"]))" 53 | ] 54 | } 55 | ], 56 | "metadata": { 57 | "kernelspec": { 58 | "display_name": "Python 3.7.13 64-bit ('admirer')", 59 | "language": "python", 60 | "name": "python3" 61 | }, 62 | "language_info": { 63 | "codemirror_mode": { 64 | "name": "ipython", 65 | "version": 3 66 | }, 67 | "file_extension": ".py", 68 | "mimetype": "text/x-python", 69 | "name": "python", 70 | "nbconvert_exporter": "python", 71 | "pygments_lexer": "ipython3", 72 | "version": "3.7.13" 73 | }, 74 | "orig_nbformat": 4, 75 | "vscode": { 76 | "interpreter": { 77 | "hash": "4c4de3d17692a4fce36158e1e6b4cc65d2c1c1dbb8a445fcd77e7a07c1299f79" 78 | } 79 | } 80 | }, 81 | "nbformat": 4, 82 | "nbformat_minor": 2 83 | } 84 | -------------------------------------------------------------------------------- /question_answer/lit_models/util.py: -------------------------------------------------------------------------------- 1 | from tqdm.auto import tqdm 2 | from typing import List, Optional, Union 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def top_k_top_p_filtering( 9 | next_token_logits: torch.FloatTensor, 10 | top_k: Optional[float] = None, 11 | top_p: Optional[float] = None, 12 | device: Union[str, torch.device] = "cpu", 13 | ) -> torch.FloatTensor: 14 | if top_k is None: 15 | top_k = next_token_logits.shape[-1] 16 | if top_p is None: 17 | top_p = 1.0 18 | 19 | p, 
largest_p_idx = F.softmax(next_token_logits, dim=-1).topk(top_k, dim=-1) 20 | cumulative_p = p.cumsum(dim=-1) 21 | threshold_repeated = top_p + torch.zeros((len(p), 1)).to(device) 22 | idx = torch.searchsorted(cumulative_p, threshold_repeated).clip(max=top_k - 1).squeeze() 23 | cutoffs = cumulative_p[torch.arange(len(cumulative_p)), idx] 24 | censored_p = (cumulative_p <= cutoffs[:, None]) * p 25 | renormalized_p = censored_p / censored_p.sum(dim=-1, keepdims=True) 26 | 27 | final_p = torch.zeros_like(next_token_logits) 28 | row_idx = torch.arange(len(p)).unsqueeze(1).repeat(1, top_k).to(device) 29 | final_p[row_idx, largest_p_idx] = renormalized_p.to(final_p.dtype) 30 | 31 | return final_p 32 | 33 | 34 | def generate_sentence_from_image( 35 | model, encoder_outputs, tokenizer, max_text_length: int, device, top_k: int, top_p: float 36 | ) -> List[str]: 37 | generated_so_far = torch.LongTensor([[tokenizer.bos_token_id]] * len(encoder_outputs.last_hidden_state)).to(device) 38 | with torch.no_grad(): 39 | for _ in tqdm(range(max_text_length)): 40 | attention_mask = torch.ones_like(generated_so_far) 41 | decoder_out = model( 42 | decoder_input_ids=generated_so_far, 43 | decoder_attention_mask=attention_mask, 44 | encoder_outputs=encoder_outputs, 45 | ) 46 | 47 | next_token_logits = decoder_out["logits"][:, -1, :] 48 | filtered_p = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p, device=device) 49 | next_token = torch.multinomial(filtered_p, num_samples=1) 50 | generated_so_far = torch.cat((generated_so_far, next_token), dim=1) 51 | 52 | return [tokenizer.decode(coded_sentence) for coded_sentence in generated_so_far] 53 | -------------------------------------------------------------------------------- /question_answer/tests/test_answer.py: -------------------------------------------------------------------------------- 1 | """Test for answer module.""" 2 | import json 3 | import os 4 | from pathlib import Path 5 | import time 6 | 7 | from question_answer.answer import Pipeline 8 | from question_answer.lit_models.metrics import BertF1Score 9 | 10 | 11 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 12 | 13 | 14 | _FILE_DIRNAME = Path(__file__).parents[0].resolve() 15 | _SUPPORT_DIRNAME = _FILE_DIRNAME / "support" 16 | _IMAGES_SUPPORT_DIRNAME = _SUPPORT_DIRNAME / "images" 17 | _QUESTIONS_SUPPORT_DIRNAME = _SUPPORT_DIRNAME / "questions" 18 | 19 | # restricting number of samples to prevent CircleCI from running out of time 20 | _NUM_MAX_SAMPLES = 2 if os.environ.get("CIRCLECI", False) else 100 21 | 22 | 23 | def test_answer(): 24 | """Test Pipeline.""" 25 | support_images = sorted(_IMAGES_SUPPORT_DIRNAME.glob("*.jpg")) 26 | support_questions = sorted(_QUESTIONS_SUPPORT_DIRNAME.glob("*.txt")) 27 | with open(_SUPPORT_DIRNAME / "data_by_file_id.json", "r") as f: 28 | support_data_by_file_id = json.load(f) 29 | 30 | start_time = time.time() 31 | pipeline = Pipeline() 32 | end_time = time.time() 33 | print(f"Time taken to initialize Pipeline: {round(end_time - start_time, 2)}s") 34 | 35 | for i, (support_image, support_question) in enumerate(zip(support_images, support_questions)): 36 | if i >= _NUM_MAX_SAMPLES: 37 | break 38 | expected_text = support_data_by_file_id[support_image.stem]["predicted_text"] 39 | start_time = time.time() 40 | predicted_text = _test_answer(support_image, support_question, expected_text, pipeline) 41 | end_time = time.time() 42 | time_taken = round(end_time - start_time, 2) 43 | 44 | ground_truth = support_data_by_file_id[support_image.stem]["ground_truth_text"] 45 | f1 = 
BertF1Score()(predicted_text, ground_truth).item() 46 | print( 47 | f"Bert F1 score is {round(f1, 3)} for files {support_image.name} and {support_question.name} (time taken: {time_taken}s)" 48 | ) 49 | 50 | 51 | def _test_answer(image_filename: Path, question_filename: Path, expected_text: str, pipeline: Pipeline): 52 | """Test the Pipeline on a single image-question pair.""" 53 | predicted_text = pipeline.predict(image_filename, question_filename) 54 | assert predicted_text == expected_text, f"predicted text does not match expected for {image_filename.name}" 55 | return predicted_text 56 | -------------------------------------------------------------------------------- /api_serverless/api.py: -------------------------------------------------------------------------------- 1 | """AWS Lambda function serving question_answer predictions.""" 2 | import json 3 | 4 | from PIL import ImageStat 5 | 6 | from question_answer.answer import Pipeline 7 | import question_answer.util as util 8 | 9 | model = Pipeline() 10 | 11 | 12 | def handler(event, _context): 13 | """Provide main prediction API.""" 14 | print("INFO loading image") 15 | image = _load_image(event) 16 | if image is None: 17 | return {"statusCode": 400, "message": "neither image_url nor image found in event"} 18 | question = _load_question(event) 19 | if question is None: 20 | return {"statusCode": 400, "message": "neither question_url nor question found in event"} 21 | print("INFO image loaded") 22 | print("INFO starting inference") 23 | pred = model.predict(image, question) 24 | print("INFO inference complete") 25 | image_stat = ImageStat.Stat(image) 26 | print("METRIC image_mean_intensity {}".format(image_stat.mean[0])) 27 | print("METRIC image_area {}".format(image.size[0] * image.size[1])) 28 | print("METRIC pred_length {}".format(len(pred))) 29 | print("INFO pred {}".format(pred)) 30 | return {"pred": str(pred)} 31 | 32 | 33 | def _load_image(event): 34 | event = _from_string(event) 35 | event = _from_string(event.get("body", event)) 36 | image_url = event.get("image_url") 37 | if image_url is not None: 38 | print("INFO url {}".format(image_url)) 39 | return util.read_image_pil(image_url) 40 | else: 41 | image = event.get("image") 42 | if image is not None: 43 | print("INFO reading image from event") 44 | return util.read_b64_image(image) 45 | else: 46 | return None 47 | 48 | 49 | def _load_question(event): 50 | event = _from_string(event) 51 | event = _from_string(event.get("body", event)) 52 | question_url = event.get("question_url") 53 | if question_url is not None: 54 | print("INFO url {}".format(question_url)) 55 | with open(question_url, "r") as f: 56 | question = f.readline() 57 | return question 58 | else: 59 | question = event.get("question") 60 | if question is not None: 61 | print("INFO reading question from event") 62 | return question 63 | else: 64 | return None 65 | 66 | 67 | def _from_string(event): 68 | if isinstance(event, str): 69 | return json.loads(event) 70 | else: 71 | return event 72 | -------------------------------------------------------------------------------- /question_answer/tests/test_data.py: -------------------------------------------------------------------------------- 1 | """Test submodules of the data module.""" 2 | import os 3 | import shutil 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from question_answer.data import pica 9 | from question_answer.metadata.pica import TRAIN_VAL_SPLIT 10 | 11 | 12 | @pytest.mark.data 13 | class TestDataset: 14 | """Tests downloading and setup of a dataset.""" 15 | 16 | 17 | pica_dir = 
pica.PROCESSED_DATA_DIRNAME 18 | 19 | 20 | @pytest.fixture(scope="module") 21 | def pica_dataset(): 22 | _remove_if_exist(pica_dir) 23 | dataset = pica.PICa() 24 | dataset.prepare_data() 25 | return dataset 26 | 27 | 28 | def _exist(dir): 29 | return os.path.exists(dir) 30 | 31 | 32 | def _remove_if_exist(dir): 33 | shutil.rmtree(dir, ignore_errors=True) 34 | 35 | 36 | class TestPICa(TestDataset): 37 | """Tests downloading and properties of the dataset.""" 38 | 39 | dir = pica_dir 40 | 41 | def test_prepare_data(self, pica_dataset): 42 | """Tests whether the prepare_data method has produced the expected directories.""" 43 | assert _exist(self.dir) 44 | 45 | def test_setup(self, pica_dataset): 46 | """Tests features of the fully set up dataset.""" 47 | dataset = pica_dataset 48 | dataset.setup() 49 | assert all(map(lambda s: hasattr(dataset, s), ["x_trainval", "y_trainval", "x_test", "y_test"])) 50 | splits = [dataset.x_trainval, dataset.y_trainval, dataset.x_test, dataset.y_test] 51 | assert all(map(lambda attr: type(attr) == np.ndarray, splits)) 52 | observed_train_frac = len(dataset.data_train) / (len(dataset.data_train) + len(dataset.data_val)) 53 | assert np.isclose(observed_train_frac, TRAIN_VAL_SPLIT) 54 | assert dataset.input_dims[-2:] == dataset.x_trainval[0].shape # ToTensor() adds a dimension 55 | assert len(dataset.output_dims) == len(dataset.y_trainval.shape) # == 1 56 | 57 | def test_paired(self, pica_dataset): 58 | """Tests that we retrieve the same number of captions and screenshots.""" 59 | for id in pica_dataset.all_ids: 60 | assert len(pica_dataset.caption_by_id[id]) == len(pica_dataset.screenshot_url_by_id[id]) 61 | 62 | def test_data_splits(self, pica_dataset): 63 | """Fails when any identifiers are shared between training, test, or validation.""" 64 | assert not set(pica_dataset.train_ids) & set(pica_dataset.validation_ids) 65 | assert not set(pica_dataset.train_ids) & set(pica_dataset.test_ids) 66 | assert not set(pica_dataset.validation_ids) & set(pica_dataset.test_ids) 67 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | # allows manual triggering of this workflow 8 | workflow_dispatch: 9 | 10 | jobs: 11 | 12 | unit-tests: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: ["3.7"] 17 | 18 | steps: 19 | - uses: actions/checkout@v3 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v4 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Full Python environment cacheing 25 | # see AI2 blogpost for details: https://blog.allenai.org/python-caching-in-github-actions-e9452698e98d 26 | uses: actions/cache@v2 27 | with: 28 | path: ${{ env.pythonLocation }} 29 | key: v1-${{ env.pythonLocation }}-${{ hashFiles('requirements/dev.txt') }}-${{ hashFiles('requirements/prod.txt') }} 30 | - name: Install dependencies with pip 31 | run: | 32 | pip install --quiet -r requirements/prod.txt -r requirements/dev.txt 33 | - name: Run unit tests 34 | run: | 35 | ./tasks/unit_test.sh 36 | env: 37 | PYTHONPATH: .
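# PYTHONPATH is set to the repository root so the test modules can import the project packages;
# the variables below are supplied from GitHub Actions repository secrets (W&B and AWS credentials).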
38 | WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} 39 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} 40 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 41 | 42 | integration-tests: 43 | runs-on: ubuntu-latest 44 | strategy: 45 | matrix: 46 | python-version: ["3.7"] 47 | 48 | steps: 49 | - uses: actions/checkout@v3 50 | - name: Set up Python ${{ matrix.python-version }} 51 | uses: actions/setup-python@v4 52 | with: 53 | python-version: ${{ matrix.python-version }} 54 | - name: Full Python environment cacheing 55 | # see AI2 blogpost for details: https://blog.allenai.org/python-caching-in-github-actions-e9452698e98d 56 | uses: actions/cache@v2 57 | with: 58 | path: ${{ env.pythonLocation }} 59 | key: v1-${{ env.pythonLocation }}-${{ hashFiles('requirements/dev.txt') }}-${{ hashFiles('requirements/prod.txt') }} 60 | - name: Install dependencies with pip 61 | run: | 62 | pip install --quiet -r requirements/prod.txt -r requirements/dev.txt 63 | - name: Run integration tests 64 | run: | 65 | ./tasks/integration_test.sh 66 | env: 67 | PYTHONPATH: . 68 | WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} 69 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} 70 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 71 | -------------------------------------------------------------------------------- /question_answer/callbacks/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import tempfile 4 | 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 7 | import torch 8 | 9 | from .util import check_and_warn, logging 10 | 11 | try: 12 | import torchviz 13 | 14 | has_torchviz = True 15 | except ImportError: 16 | has_torchviz = False 17 | 18 | 19 | class ModelSizeLogger(pl.Callback): 20 | """Logs information about model size (in parameters and on disk).""" 21 | 22 | def __init__(self, print_size=True): 23 | super().__init__() 24 | self.print_size = print_size 25 | 26 | @rank_zero_only 27 | def on_fit_start(self, trainer, module): 28 | self._run(trainer, module) 29 | 30 | def _run(self, trainer, module): 31 | metrics = {} 32 | metrics["mb_disk"] = self.get_model_disksize(module) 33 | metrics["nparams"] = count_params(module) 34 | 35 | if self.print_size: 36 | print(f"Model State Dict Disk Size: {round(metrics['mb_disk'], 2)} MB") 37 | 38 | metrics = {f"size/{key}": value for key, value in metrics.items()} 39 | 40 | trainer.logger.log_metrics(metrics, step=-1) 41 | 42 | @staticmethod 43 | def get_model_disksize(module): 44 | """Determine the model's size on disk by saving it to disk.""" 45 | with tempfile.NamedTemporaryFile() as f: 46 | torch.save(module.state_dict(), f) 47 | size_mb = os.path.getsize(f.name) / 1e6 48 | return size_mb 49 | 50 | 51 | class GraphLogger(pl.Callback): 52 | """Logs a compute graph as an image.""" 53 | 54 | def __init__(self, output_key="logits"): 55 | super().__init__() 56 | self.graph_logged = False 57 | self.output_key = output_key 58 | if not has_torchviz: 59 | raise ImportError("GraphLogCallback requires torchviz." 
"") 60 | 61 | @rank_zero_only 62 | def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx, dataloader_idx): 63 | if not self.graph_logged: 64 | try: 65 | outputs = outputs[0][0]["extra"] 66 | self.log_graph(trainer, module, outputs[self.output_key]) 67 | except KeyError: 68 | logging.warning(f"Unable to log graph: outputs not found at key {self.output_key}") 69 | self.graph_logged = True 70 | 71 | @staticmethod 72 | def log_graph(trainer, module, outputs): 73 | if check_and_warn(trainer.logger, "log_image", "graph"): 74 | return 75 | params_dict = dict(list(module.named_parameters())) 76 | graph = torchviz.make_dot(outputs, params=params_dict) 77 | graph.format = "png" 78 | fname = Path(trainer.logger.experiment.dir) / "graph" 79 | graph.render(fname) 80 | fname = str(fname.with_suffix("." + graph.format)) 81 | trainer.logger.log_image(key="graph", images=[fname]) 82 | 83 | 84 | def count_params(module): 85 | """Counts the number of parameters in a Torch Module.""" 86 | return sum(p.numel() for p in module.parameters()) 87 | -------------------------------------------------------------------------------- /training/test_model.py: -------------------------------------------------------------------------------- 1 | """Experiment-running framework.""" 2 | import argparse 3 | 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | import torch 7 | 8 | from question_answer import callbacks as cb 9 | from question_answer import lit_models 10 | from training.util import DATA_CLASS_MODULE, import_class, MODEL_CLASS_MODULE, setup_data_and_model_from_args 11 | 12 | 13 | # In order to ensure reproducible experiments, we must set random seeds. 14 | np.random.seed(42) 15 | torch.manual_seed(42) 16 | 17 | 18 | def _setup_parser(): 19 | """Set up Python's ArgumentParser with data, model, trainer, and other arguments.""" 20 | parser = argparse.ArgumentParser(add_help=False) 21 | 22 | # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision 23 | trainer_parser = pl.Trainer.add_argparse_args(parser) 24 | trainer_parser._action_groups[1].title = "Trainer Args" 25 | parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser]) 26 | parser.set_defaults(max_epochs=1) 27 | 28 | # Basic arguments 29 | parser.add_argument( 30 | "--data_class", 31 | type=str, 32 | default="PICa", 33 | help=f"String identifier for the data class, relative to {DATA_CLASS_MODULE}.", 34 | ) 35 | parser.add_argument( 36 | "--model_class", 37 | type=str, 38 | default="ViT2GPT2", 39 | help=f"String identifier for the model class, relative to {MODEL_CLASS_MODULE}.", 40 | ) 41 | parser.add_argument("--load_checkpoint", type=str, default=None, help="Loads a model from the provided path.") 42 | parser.add_argument( 43 | "--stop_early", 44 | type=int, 45 | default=0, 46 | help="If non-zero, applies early stopping, with the provided value as the 'patience' argument." 
47 | + " Default is 0.", 48 | ) 49 | 50 | # Get the data and model classes, so that we can add their specific arguments 51 | temp_args, _ = parser.parse_known_args() 52 | data_class = import_class(f"{DATA_CLASS_MODULE}.{temp_args.data_class}") 53 | model_class = import_class(f"{MODEL_CLASS_MODULE}.{temp_args.model_class}") 54 | 55 | # Get data, model, and LitModel specific arguments 56 | data_group = parser.add_argument_group("Data Args") 57 | data_class.add_to_argparse(data_group) 58 | 59 | model_group = parser.add_argument_group("Model Args") 60 | model_class.add_to_argparse(model_group) 61 | 62 | lit_model_group = parser.add_argument_group("LitModel Args") 63 | lit_models.GPT2.add_to_argparse(lit_model_group) 64 | 65 | parser.add_argument("--help", "-h", action="help") 66 | return parser 67 | 68 | 69 | def main(): 70 | """ 71 | Test a GPT2-decoder model on the PICa test dataset. 72 | """ 73 | parser = _setup_parser() 74 | args = parser.parse_args() 75 | data, model = setup_data_and_model_from_args(args) 76 | assert args.load_checkpoint, "Need to provide a model checkpoint to test." 77 | 78 | lit_model_class = lit_models.GPT2 79 | lit_model = lit_model_class.load_from_checkpoint( 80 | args.load_checkpoint, args=args, model=model.vit2gpt2, tokenizer=model.gpt2_tokenizer 81 | ) 82 | 83 | callbacks = [cb.ImageToTextPrintLogger()] 84 | 85 | trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks) 86 | profiler = pl.profiler.PassThroughProfiler() 87 | trainer.profiler = profiler 88 | 89 | trainer.test(lit_model, datamodule=data) 90 | 91 | 92 | if __name__ == "__main__": 93 | main() 94 | -------------------------------------------------------------------------------- /load_test/locust.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Load Testing with Locust" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import IPython.display\n", 24 | "import pandas as pd" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "!pip install -q locust" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "## Running the load test" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "!locust --locustfile=locust_http_user.py \\\n", 50 | " --headless \\\n", 51 | " --users=10 \\\n", 52 | " --spawn-rate=1 \\\n", 53 | " --run-time=2m \\\n", 54 | " --host=https://joiajq6syp65ueonto4mswttzu0apfbi.lambda-url.us-west-1.on.aws \\\n", 55 | " --html=locust_report.html \\\n", 56 | " --csv=locust_report" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "## Viewing the results" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "!ls -lh locust_report*" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "IPython.display.HTML(\"locust_report.html\")" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "## 
Analyzing load test data programmatically" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "csv_path = \"locust_report_stats_history.csv\"\n", 98 | "results = pd.read_csv(csv_path)\n", 99 | "results[\"Timestamp\"] = pd.to_datetime(results[\"Timestamp\"], unit=\"s\")\n", 100 | "results.tail()" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "request_columns = [\"Total Request Count\", \"Total Failure Count\", \"User Count\"]\n", 110 | "results.plot(x=\"Timestamp\", y=request_columns, subplots=True, sharey=True);" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "response_columns = [\"Total Average Response Time\", \"Total Max Response Time\"]\n", 120 | "results.plot(x=\"Timestamp\", y=response_columns);" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "results.groupby(\"Total Median Response Time\").describe()" 130 | ] 131 | } 132 | ], 133 | "metadata": { 134 | "kernelspec": { 135 | "display_name": "Python 3.7.13 64-bit ('admirer')", 136 | "language": "python", 137 | "name": "python3" 138 | }, 139 | "language_info": { 140 | "name": "python", 141 | "version": "3.7.13" 142 | }, 143 | "orig_nbformat": 4, 144 | "vscode": { 145 | "interpreter": { 146 | "hash": "4c4de3d17692a4fce36158e1e6b4cc65d2c1c1dbb8a445fcd77e7a07c1299f79" 147 | } 148 | } 149 | }, 150 | "nbformat": 4, 151 | "nbformat_minor": 2 152 | } 153 | -------------------------------------------------------------------------------- /question_answer/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions for question_answer module.""" 2 | import base64 3 | import contextlib 4 | import hashlib 5 | from io import BytesIO 6 | import os 7 | from pathlib import Path 8 | from typing import Union 9 | from urllib.request import urlretrieve 10 | 11 | import numpy as np 12 | from PIL import Image 13 | import smart_open 14 | from tqdm import tqdm 15 | 16 | 17 | def to_categorical(y, num_classes): 18 | """1-hot encode a tensor.""" 19 | return np.eye(num_classes, dtype="uint8")[y] 20 | 21 | 22 | def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image: 23 | with smart_open.open(image_uri, "rb") as image_file: 24 | return read_image_pil_file(image_file, grayscale) 25 | 26 | 27 | def read_image_pil_file(image_file, grayscale=False) -> Image: 28 | with Image.open(image_file) as image: 29 | if grayscale: 30 | image = image.convert(mode="L") 31 | else: 32 | image = image.convert(mode=image.mode) 33 | return image 34 | 35 | 36 | @contextlib.contextmanager 37 | def temporary_working_directory(working_dir: Union[str, Path]): 38 | """Temporarily switches to a directory, then returns to the original directory on exit.""" 39 | curdir = os.getcwd() 40 | os.chdir(working_dir) 41 | try: 42 | yield 43 | finally: 44 | os.chdir(curdir) 45 | 46 | 47 | # Hide lines below until Lab 08 48 | def read_b64_image(b64_string, grayscale=False): 49 | """Load base64-encoded images.""" 50 | try: 51 | image_file = read_b64_string(b64_string) 52 | return read_image_pil_file(image_file, grayscale) 53 | except Exception as exception: 54 | raise ValueError("Could not load image from b64 {}: {}".format(b64_string, 
exception)) from exception 55 | 56 | 57 | def read_b64_string(b64_string, return_data_type=False): 58 | """Read a base64-encoded string into an in-memory file-like object.""" 59 | data_header, b64_data = split_and_validate_b64_string(b64_string) 60 | b64_buffer = BytesIO(base64.b64decode(b64_data)) 61 | if return_data_type: 62 | return get_b64_filetype(data_header), b64_buffer 63 | else: 64 | return b64_buffer 65 | 66 | 67 | def get_b64_filetype(data_header): 68 | """Retrieves the filetype information from the data type header of a base64-encoded object.""" 69 | _, file_type = data_header.split("/") 70 | return file_type 71 | 72 | 73 | def split_and_validate_b64_string(b64_string): 74 | """Return the data_type and data of a b64 string, with validation.""" 75 | header, data = b64_string.split(",", 1) 76 | assert header.startswith("data:") 77 | assert header.endswith(";base64") 78 | data_type = header.split(";")[0].split(":")[1] 79 | return data_type, data 80 | 81 | 82 | # Hide lines above until Lab 08 83 | 84 | 85 | def encode_b64_image(image, format="png"): 86 | """Encode a PIL image as a base64 string.""" 87 | _buffer = BytesIO() # bytes that live in memory 88 | image.save(_buffer, format=format) # but which we write to like a file 89 | encoded_image = base64.b64encode(_buffer.getvalue()).decode("utf8") 90 | return encoded_image 91 | 92 | 93 | def compute_sha256(filename: Union[Path, str]): 94 | """Return SHA256 checksum of a file.""" 95 | with open(filename, "rb") as f: 96 | return hashlib.sha256(f.read()).hexdigest() 97 | 98 | 99 | class TqdmUpTo(tqdm): 100 | """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" 101 | 102 | def update_to(self, blocks=1, bsize=1, tsize=None): 103 | """ 104 | Parameters 105 | ---------- 106 | blocks: int, optional 107 | Number of blocks transferred so far [default: 1]. 108 | bsize: int, optional 109 | Size of each block (in tqdm units) [default: 1]. 110 | tsize: int, optional 111 | Total size (in tqdm units). If [default: None] remains unchanged. 
112 | """ 113 | if tsize is not None: 114 | self.total = tsize 115 | self.update(blocks * bsize - self.n) # will also set self.n = b * bsize 116 | 117 | 118 | def download_url(url, filename): 119 | """Download a file from url to filename, with a progress bar.""" 120 | with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: 121 | urlretrieve(url, filename, reporthook=t.update_to, data=None) # noqa: S310 122 | -------------------------------------------------------------------------------- /question_answer/data/base_data_module.py: -------------------------------------------------------------------------------- 1 | """Base DataModule class.""" 2 | import argparse 3 | import os 4 | from pathlib import Path 5 | from typing import Collection, Dict, Optional, Tuple, Union 6 | 7 | import pytorch_lightning as pl 8 | import torch 9 | from torch.utils.data import ConcatDataset, DataLoader 10 | 11 | from question_answer.data.util import BaseDataset 12 | import question_answer.metadata.shared as metadata 13 | 14 | 15 | def load_and_print_info(data_module_class) -> None: 16 | """Load Dataset and print info.""" 17 | parser = argparse.ArgumentParser() 18 | data_module_class.add_to_argparse(parser) 19 | args = parser.parse_args() 20 | dataset = data_module_class(args) 21 | dataset.prepare_data() 22 | dataset.setup() 23 | print(dataset) 24 | 25 | 26 | BATCH_SIZE = 8 27 | NUM_AVAIL_CPUS = len(os.sched_getaffinity(0)) 28 | NUM_AVAIL_GPUS = torch.cuda.device_count() 29 | 30 | # sensible multiprocessing defaults: at most one worker per CPU 31 | DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS 32 | # but in distributed data parallel mode, we launch a training on each GPU, so must divide out to keep total at one worker per CPU 33 | DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS // NUM_AVAIL_GPUS if NUM_AVAIL_GPUS else DEFAULT_NUM_WORKERS 34 | 35 | 36 | class BaseDataModule(pl.LightningDataModule): 37 | """Base for all of our LightningDataModules. 38 | 39 | Learn more at about LDMs at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html 40 | """ 41 | 42 | def __init__(self, args: argparse.Namespace = None) -> None: 43 | super().__init__() 44 | self.args = vars(args) if args is not None else {} 45 | self.batch_size = self.args.get("batch_size", BATCH_SIZE) 46 | self.num_workers = self.args.get("num_workers", DEFAULT_NUM_WORKERS) 47 | 48 | self.on_gpu = isinstance(self.args.get("gpus", None), (str, int)) 49 | 50 | # Make sure to set the variables below in subclasses 51 | self.input_dims: Tuple[int, ...] 52 | self.output_dims: Tuple[int, ...] 53 | self.data_train: Union[BaseDataset, ConcatDataset] 54 | self.data_val: Union[BaseDataset, ConcatDataset] 55 | self.data_test: Union[BaseDataset, ConcatDataset] 56 | 57 | @classmethod 58 | def data_dirname(cls): 59 | return metadata.DATA_DIRNAME 60 | 61 | @staticmethod 62 | def add_to_argparse(parser): 63 | parser.add_argument( 64 | "--batch_size", 65 | type=int, 66 | default=BATCH_SIZE, 67 | help=f"Number of examples to operate on per forward step. Default is {BATCH_SIZE}.", 68 | ) 69 | parser.add_argument( 70 | "--num_workers", 71 | type=int, 72 | default=DEFAULT_NUM_WORKERS, 73 | help=f"Number of additional processes to load data. 
Default is {DEFAULT_NUM_WORKERS}.", 74 | ) 75 | return parser 76 | 77 | def config(self): 78 | """Return important settings of the dataset, which will be passed to instantiate models.""" 79 | return {"input_dims": self.input_dims, "output_dims": self.output_dims} 80 | 81 | def prepare_data(self, *args, **kwargs) -> None: 82 | """Take the first steps to prepare data for use. 83 | 84 | Use this method to do things that might write to disk or that need to be done only from a single GPU 85 | in distributed settings (so don't set state `self.x = y`). 86 | """ 87 | 88 | def setup(self, stage: Optional[str] = None) -> None: 89 | """Perform final setup to prepare data for consumption by DataLoader. 90 | 91 | Here is where we typically split into train, validation, and test. This is done once per GPU in a DDP setting. 92 | Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test. 93 | """ 94 | 95 | def train_dataloader(self): 96 | return DataLoader( 97 | self.data_train, 98 | shuffle=True, 99 | batch_size=self.batch_size, 100 | num_workers=self.num_workers, 101 | pin_memory=self.on_gpu, 102 | ) 103 | 104 | def val_dataloader(self): 105 | return DataLoader( 106 | self.data_val, 107 | shuffle=False, 108 | batch_size=self.batch_size, 109 | num_workers=self.num_workers, 110 | pin_memory=self.on_gpu, 111 | ) 112 | 113 | def test_dataloader(self): 114 | return DataLoader( 115 | self.data_test, 116 | shuffle=False, 117 | batch_size=self.batch_size, 118 | num_workers=self.num_workers, 119 | pin_memory=self.on_gpu, 120 | ) 121 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Thank you for investing your time in contributing to my project! Any contribution you make will be reflected on [docs.github.com](https://docs.github.com/en) :sparkles:. 4 | 5 | In this guide you will get an overview of the contribution workflow from opening an issue, creating a PR, reviewing, and merging the PR. 6 | 7 | ## New contributor guide 8 | 9 | To get an overview of the project, read the [README](README.md). Here are some resources to help you get started with open source contributions: 10 | 11 | - [Finding ways to contribute to open source on GitHub](https://docs.github.com/en/get-started/exploring-projects-on-github/finding-ways-to-contribute-to-open-source-on-github) 12 | - [Set up Git](https://docs.github.com/en/get-started/quickstart/set-up-git) 13 | - [GitHub flow](https://docs.github.com/en/get-started/quickstart/github-flow) 14 | - [Collaborating with pull requests](https://docs.github.com/en/github/collaborating-with-pull-requests) 15 | 16 | ## Getting started 17 | 18 | ### Notion 19 | 20 | Head [here](https://stump-molecule-6c9.notion.site/Project-Home-04728cd6ba2042c59e535979733065cd) to see the current set of features being worked on. To gain edit access, contact Andrew Hinh @ ajhinh@gmail.com. 21 | 22 | ### Issues 23 | 24 | #### Create a new issue 25 | 26 | If you spot a problem with the docs, [search if an issue already exists](https://docs.github.com/en/github/searching-for-information-on-github/searching-on-github/searching-issues-and-pull-requests#search-by-the-title-body-or-comments). If a related issue doesn't exist, you can open a new issue using a relevant [issue form](https://github.com/andrewhinh/admirer/issues/new). 
27 | 28 | #### Solve an issue 29 | 30 | Scan through our [existing issues](https://github.com/andrewhinh/admirer/issues) to find one that interests you. You can narrow down the search using `labels` as filters. As a general rule, we don’t assign issues to anyone. If you find an issue to work on, you are welcome to open a PR with a fix. 31 | 32 | ### Make Changes 33 | 34 | #### Make changes in the UI 35 | 36 | Click **Make a contribution** at the bottom of any docs page to make small changes such as a typo, sentence fix, or a broken link. This takes you to the `.md` file where you can make your changes and [create a pull request](#pull-request) for a review. 37 | 38 | #### Make changes in a codespace 39 | 40 | For more information about using a codespace for working on GitHub documentation, see "[Working in a codespace](https://github.com/github/docs/blob/main/contributing/codespace.md)." 41 | 42 | #### Make changes locally 43 | 44 | 1. Fork the repository. 45 | 46 | - Using GitHub Desktop: 47 | - [Getting started with GitHub Desktop](https://docs.github.com/en/desktop/installing-and-configuring-github-desktop/getting-started-with-github-desktop) will guide you through setting up Desktop. 48 | - Once Desktop is set up, you can use it to [fork the repo](https://docs.github.com/en/desktop/contributing-and-collaborating-using-github-desktop/cloning-and-forking-repositories-from-github-desktop)! 49 | 50 | - Using the command line: 51 | - [Fork the repo](https://docs.github.com/en/github/getting-started-with-github/fork-a-repo#fork-an-example-repository) so that you can make your changes without affecting the original project until you're ready to merge them. 52 | 53 | 2. Create a working branch and start with your changes! 54 | 55 | ### Commit your update 56 | 57 | Commit the changes once you are happy with them. Don't forget to self-review to speed up the review process:zap:. 58 | 59 | ### Pull Request 60 | 61 | When you're finished with the changes, create a pull request, also known as a PR. 62 | 63 | - Fill the "Ready for review" template so that we can review your PR. This template helps reviewers understand your changes as well as the purpose of your pull request. 64 | - Don't forget to [link PR to issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue) if you are solving one. 65 | - Enable the checkbox to [allow maintainer edits](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/allowing-changes-to-a-pull-request-branch-created-from-a-fork) so the branch can be updated for a merge. 66 | Once you submit your PR, a Docs team member will review your proposal. We may ask questions or request additional information. 67 | - We may ask for changes to be made before a PR can be merged, either using [suggested changes](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/incorporating-feedback-in-your-pull-request) or pull request comments. You can apply suggested changes directly through the UI. You can make any other changes in your fork, then commit them to your branch. 68 | - As you update your PR and apply changes, mark each conversation as [resolved](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/commenting-on-a-pull-request#resolving-conversations). 69 | - If you run into any merge issues, checkout this [git tutorial](https://github.com/skills/resolve-merge-conflicts) to help you resolve merge conflicts and other issues. 
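
If your branch falls behind `main` while a PR is open, it usually helps to merge the latest upstream changes into your working branch and resolve conflicts locally before pushing again. The sketch below shows one common way to do that; the `upstream` remote name and `my-working-branch` are illustrative placeholders rather than anything this repository requires:

```bash
# one-time setup: point an "upstream" remote at the original repository
git remote add upstream https://github.com/andrewhinh/admirer.git

# bring your working branch up to date with the latest upstream main
git fetch upstream
git checkout my-working-branch
git merge upstream/main   # fix any conflicts git reports here

# after resolving conflicts, stage the fixes, complete the merge, and update your PR
git add .
git commit
git push origin my-working-branch
```

GitHub's web editor can resolve small conflicts directly in the PR, but the local flow above is easier to verify because you can re-run the test suite before pushing.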
70 | 71 | ### Your PR is merged 72 | 73 | Congratulations :tada::tada: Once your PR is merged, your contributions will be publicly visible on the [GitHub docs](https://docs.github.com/en). 74 | -------------------------------------------------------------------------------- /question_answer/lit_models/gpt2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | import torch.nn as nn 6 | from torchvision import transforms 7 | from typing import List, Tuple 8 | import wandb 9 | 10 | 11 | import question_answer.metadata.pica as metadata 12 | 13 | 14 | OPTIMIZER = "Adam" 15 | LR = 1e-4 16 | ONE_CYCLE_TOTAL_STEPS = 100 17 | 18 | TOP_K = 1000 19 | TOP_P = 0.95 20 | MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH 21 | LABEL_MASK = -100 22 | 23 | 24 | class GPT2(pl.LightningModule): 25 | """ 26 | GPT2 PyTorch-Lightning class. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | model: nn.Module, 32 | tokenizer, 33 | args: argparse.Namespace = None, 34 | ): 35 | super().__init__() 36 | self.model = model 37 | self.tokenizer = tokenizer 38 | self.args = vars(args) if args is not None else {} 39 | 40 | optimizer = self.args.get("optimizer", OPTIMIZER) 41 | self.optimizer_class = getattr(torch.optim, optimizer) 42 | self.lr = self.args.get("lr", LR) 43 | self.one_cycle_max_lr = self.args.get("one_cycle_max_lr", None) 44 | self.one_cycle_total_steps = self.args.get("one_cycle_total_steps", ONE_CYCLE_TOTAL_STEPS) 45 | 46 | self.top_k = self.args.get("top_k", TOP_K) 47 | self.top_p = self.args.get("top_p", TOP_P) 48 | self.max_label_length = self.args.get("max_label_length", MAX_LABEL_LENGTH) 49 | self.label_mask = self.args.get("label_mask", LABEL_MASK) 50 | 51 | model.eval() 52 | for p in model.parameters(): 53 | p.requires_grad = False 54 | 55 | # only allow training of cross attention parameters 56 | for layer in model.decoder.transformer.h: 57 | layer.crossattention.train() 58 | for p in layer.crossattention.parameters(): 59 | p.requires_grad = True 60 | layer.ln_cross_attn.train() 61 | for p in layer.ln_cross_attn.parameters(): 62 | p.requires_grad = True 63 | 64 | @staticmethod 65 | def add_to_argparse(parser): 66 | parser.add_argument("--optimizer", type=str, default=OPTIMIZER, help="optimizer class from torch.optim") 67 | parser.add_argument("--lr", type=float, default=LR) 68 | parser.add_argument("--one_cycle_max_lr", type=float, default=None) 69 | parser.add_argument("--one_cycle_total_steps", type=int, default=ONE_CYCLE_TOTAL_STEPS) 70 | 71 | parser.add_argument("--top_k", type=int, default=TOP_K) 72 | parser.add_argument("--top_p", type=float, default=TOP_P) 73 | parser.add_argument("--max_label_length", type=int, default=MAX_LABEL_LENGTH) 74 | parser.add_argument("--label_mask", type=float, default=LABEL_MASK) 75 | return parser 76 | 77 | def configure_optimizers(self): 78 | optimizer = self.optimizer_class(self.parameters(), lr=self.lr) 79 | if self.one_cycle_max_lr is None: 80 | return optimizer 81 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 82 | optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps 83 | ) 84 | return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "validation/loss"} 85 | 86 | def common_step(self, batch: Tuple[torch.FloatTensor, List[str]]) -> torch.FloatTensor: 87 | images, captions = batch 88 | tokenized_captions = { 89 | k: v.to(self.device) 90 | for k, v in self.tokenizer( 91 | captions, 92 | 
max_length=self.max_label_length, 93 | truncation=True, 94 | padding=True, 95 | return_tensors="pt", 96 | ).items() 97 | } 98 | labels = tokenized_captions["input_ids"].clone() 99 | labels[tokenized_captions["attention_mask"] == 0] = self.label_mask 100 | encoder_outputs = self.model.encoder(pixel_values=images) 101 | outputs = self.model( 102 | encoder_outputs=encoder_outputs, 103 | decoder_input_ids=tokenized_captions["input_ids"], 104 | decoder_attention_mask=tokenized_captions["attention_mask"], 105 | labels=labels, 106 | return_dict=True, 107 | ) 108 | 109 | return outputs["loss"] 110 | 111 | def training_step(self, batch: Tuple[torch.FloatTensor, List[str]], batch_idx: int) -> torch.FloatTensor: 112 | loss = self.common_step(batch) 113 | self.log("train/loss", loss) 114 | 115 | return loss 116 | 117 | def validation_step(self, batch: Tuple[torch.FloatTensor, List[str]], batch_idx: int): 118 | loss = self.common_step(batch) 119 | self.log("validation/loss", loss, prog_bar=True, sync_dist=True) 120 | 121 | def test_step(self, batch: Tuple[torch.FloatTensor, List[str]], batch_idx: int): 122 | loss = self.common_step(batch) 123 | self.log("test/loss", loss, on_step=False, on_epoch=True) 124 | 125 | def on_after_backward(self): 126 | if self.trainer.global_step % 50 == 0: # don't make the tf file huge 127 | for name, param in self.model.named_parameters(): 128 | if "weight" in name and "norm" not in name and param.requires_grad: 129 | self.logger.experiment.log({f"{name}_grad": wandb.Histogram(param.grad.detach().cpu())}) 130 | self.logger.experiment.log({f"{name}": wandb.Histogram(param.detach().cpu())}) 131 | -------------------------------------------------------------------------------- /question_answer/callbacks/imtotext.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from pytorch_lightning.utilities import rank_zero_only 3 | 4 | try: 5 | import wandb 6 | 7 | has_wandb = True 8 | except ImportError: 9 | has_wandb = False 10 | 11 | from .util import check_and_warn 12 | from question_answer.lit_models.util import generate_sentence_from_image 13 | from torchvision import transforms 14 | 15 | 16 | descale = transforms.Compose( 17 | [ 18 | transforms.Normalize(mean=[0.0, 0.0, 0.0], std=1 / 0.5), 19 | transforms.Normalize(mean=-0.5, std=[1.0, 1.0, 1.0]), 20 | ] 21 | ) 22 | 23 | 24 | class ImageToTextPrintLogger(pl.Callback): 25 | """Logs the inputs and outputs of an image-to-text model to Weights & Biases.""" 26 | 27 | def __init__(self, max_images_to_log=32, on_train=True): 28 | super().__init__() 29 | self.max_images_to_log = min(max(max_images_to_log, 1), 32) 30 | self.on_train = on_train 31 | 32 | @rank_zero_only 33 | def on_test_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx): 34 | self._log_image_text_table(trainer, output, batch, "test/predictions") 35 | 36 | def _log_image_text_table(self, trainer, output, batch, key): 37 | images, actual_sentences = batch 38 | trainer = trainer.model # For easy access to the model 39 | if hasattr(trainer, "module"): # For DDP 40 | trainer = trainer.module.module 41 | encoder_outputs = trainer.model.encoder(pixel_values=images.to(trainer.device)) 42 | generated_sentences = generate_sentence_from_image( 43 | trainer.model, 44 | encoder_outputs, 45 | trainer.tokenizer, 46 | trainer.max_label_length, 47 | trainer.device, 48 | trainer.top_k, 49 | trainer.top_p, 50 | ) 51 | print("Actual: ", actual_sentences) 52 | print("Generated: ", 
generated_sentences) 53 | 54 | 55 | class ImageToTextTableLogger(pl.Callback): 56 | """Logs the inputs and outputs of an image-to-text model to Weights & Biases.""" 57 | 58 | def __init__(self, max_images_to_log=32, on_train=True): 59 | super().__init__() 60 | self.max_images_to_log = min(max(max_images_to_log, 1), 32) 61 | self.on_train = on_train 62 | 63 | @rank_zero_only 64 | def on_train_batch_end(self, trainer, module, output, batch, batch_idx): 65 | if self.on_train: 66 | if check_and_warn(trainer.logger, "log_table", "image-to-text table"): 67 | return 68 | else: 69 | self._log_image_text_table(trainer, output, batch, "train/predictions") 70 | 71 | @rank_zero_only 72 | def on_validation_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx): 73 | if check_and_warn(trainer.logger, "log_table", "image-to-text table"): 74 | return 75 | else: 76 | self._log_image_text_table(trainer, output, batch, "validation/predictions") 77 | 78 | def _log_image_text_table(self, trainer, output, batch, key): 79 | images, actual_sentences = batch 80 | trainer = trainer.model # For easy access to the model 81 | if hasattr(trainer, "module"): # For DDP 82 | trainer = trainer.module.module 83 | encoder_outputs = trainer.model.encoder(pixel_values=images.to(trainer.device)) 84 | generated_sentences = generate_sentence_from_image( 85 | trainer.model, 86 | encoder_outputs, 87 | trainer.tokenizer, 88 | trainer.max_label_length, 89 | trainer.device, 90 | trainer.top_k, 91 | trainer.top_p, 92 | ) 93 | images = [wandb.Image(transforms.ToPILImage()(descale(image))) for image in images] 94 | data = list(map(list, zip(images, actual_sentences, generated_sentences))) 95 | columns = ["Images", "Actual Sentence", "Generated Sentence"] 96 | table = wandb.Table(data=data, columns=columns) 97 | trainer.logger.experiment.log({f"epoch {trainer.current_epoch} results": table}) 98 | 99 | 100 | class ImageToTextCaptionLogger(pl.Callback): 101 | """Logs the inputs and outputs of an image-to-text model to Weights & Biases.""" 102 | 103 | def __init__(self, max_images_to_log=32, on_train=True): 104 | super().__init__() 105 | self.max_images_to_log = min(max(max_images_to_log, 1), 32) 106 | self.on_train = on_train 107 | self._required_keys = ["gt_strs", "pred_strs"] 108 | 109 | @rank_zero_only 110 | def on_train_batch_end(self, trainer, module, output, batch, batch_idx): 111 | if self.has_metrics(output): 112 | if check_and_warn(trainer.logger, "log_image", "image-to-text"): 113 | return 114 | else: 115 | self._log_image_text_caption(trainer, output, batch, "train/predictions") 116 | 117 | @rank_zero_only 118 | def on_validation_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx): 119 | if self.has_metrics(output): 120 | if check_and_warn(trainer.logger, "log_image", "image-to-text"): 121 | return 122 | else: 123 | self._log_image_text_caption(trainer, output, batch, "validation/predictions") 124 | 125 | @rank_zero_only 126 | def on_test_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx): 127 | if self.has_metrics(output): 128 | if check_and_warn(trainer.logger, "log_image", "image-to-text"): 129 | return 130 | else: 131 | self._log_image_text_caption(trainer, output, batch, "test/predictions") 132 | 133 | def _log_image_text_caption(self, trainer, output, batch, key): 134 | xs, _ = batch 135 | gt_strs = output["gt_strs"] 136 | pred_strs = output["pred_strs"] 137 | 138 | mx = self.max_images_to_log 139 | xs, gt_strs, pred_strs = list(xs[:mx]), gt_strs[:mx], 
pred_strs[:mx] 140 | 141 | trainer.logger.log_image(key, xs, caption=pred_strs) 142 | 143 | def has_metrics(self, output): 144 | return all(key in output.keys() for key in self._required_keys) 145 | -------------------------------------------------------------------------------- /training/cleanup_artifacts.py: -------------------------------------------------------------------------------- 1 | """Removes artifacts from projects and runs. 2 | 3 | Artifacts are binary files that we want to track 4 | and version but don't want to include in git, 5 | generally because they are too large, 6 | because they don't have meaningful diffs, 7 | or because they change more quickly than code. 8 | 9 | During development, we often generate artifacts 10 | that we don't really need, e.g. model weights for 11 | an overfitting test run. Space on artifact storage 12 | is generally very large, but it is limited, 13 | so we should occasionally delete unneeded artifacts 14 | to reclaim some of that space. 15 | 16 | For usage help, run 17 | python training/cleanup_artifacts.py --help 18 | """ 19 | import argparse 20 | 21 | import wandb 22 | 23 | 24 | api = wandb.Api() 25 | 26 | DEFAULT_PROJECT = "admirer-training" 27 | DEFAULT_ENTITY = api.default_entity 28 | 29 | 30 | def _setup_parser(): 31 | parser = argparse.ArgumentParser(description=__doc__) 32 | parser.add_argument( 33 | "--project", 34 | type=str, 35 | default=DEFAULT_PROJECT, 36 | help=f"The project from which to remove artifacts. Default is {DEFAULT_PROJECT}", 37 | ) 38 | parser.add_argument( 39 | "--run_ids", 40 | type=str, 41 | default=None, 42 | nargs="*", 43 | help="One or more run IDs from which to remove artifacts. Default is None.", 44 | ) 45 | parser.add_argument( 46 | "--run_name_res", 47 | type=str, 48 | default=None, 49 | nargs="*", 50 | help="One or more regular expressions to use to select runs (by display name) from which to remove artifacts. See wandb.Api.runs documentation for details on the syntax. Beware that this is a footgun and consider using interactively with --dryrun and -v. Default is None.", 51 | metavar="RUN_NAME_REGEX", 52 | ) 53 | 54 | flags = parser.add_mutually_exclusive_group() 55 | flags.add_argument("--all", action="store_true", help="Delete all artifacts from selected runs.") 56 | flags.add_argument( 57 | "--no-alias", action="store_true", help="Delete all artifacts without an alias from selected runs." 
58 | ) 59 | flags.add_argument( 60 | "--aliases", 61 | type=str, 62 | nargs="*", 63 | help="Delete artifacts that have any of the aliases from the provided list from selected runs.", 64 | ) 65 | 66 | parser.add_argument( 67 | "-v", 68 | action="store_true", 69 | dest="verbose", 70 | help="Display information about targeted entities, projects, runs, and artifacts.", 71 | ) 72 | parser.add_argument( 73 | "--dryrun", 74 | action="store_true", 75 | help="Select artifacts without deleting them and display which artifacts were selected.", 76 | ) 77 | return parser 78 | 79 | 80 | def main(args): 81 | project_path = f"{DEFAULT_ENTITY}/{args.project}" 82 | 83 | runs = _get_runs(project_path, args.run_ids, args.run_name_res, verbose=args.verbose) 84 | artifact_selector = _get_selector_from(args) 85 | protect_aliases = args.no_alias # avoid deletion of any aliased artifacts 86 | 87 | for run in runs: 88 | clean_run_artifacts( 89 | run, selector=artifact_selector, protect_aliases=protect_aliases, verbose=args.verbose, dryrun=args.dryrun 90 | ) 91 | 92 | 93 | def clean_run_artifacts(run, selector, protect_aliases=True, verbose=False, dryrun=True): 94 | artifacts = run.logged_artifacts() 95 | for artifact in artifacts: 96 | if selector(artifact): 97 | remove_artifact(artifact, protect_aliases=protect_aliases, verbose=verbose, dryrun=dryrun) 98 | 99 | 100 | def remove_artifact(artifact, protect_aliases, verbose=False, dryrun=True): 101 | project, entity, id = artifact.project, artifact.entity, artifact.id 102 | type, aliases = artifact.type, artifact.aliases 103 | if verbose or dryrun: 104 | print(f"selecting for deletion artifact {project}/{entity}/{id} of type {type} with aliases {aliases}") 105 | if not dryrun: 106 | artifact.delete(delete_aliases=not protect_aliases) 107 | 108 | 109 | def _get_runs(project_path, run_ids=None, run_name_res=None, verbose=False): 110 | if run_ids is None: 111 | run_ids = [] 112 | 113 | if run_name_res is None: 114 | run_name_res = [] 115 | 116 | runs = [] 117 | for run_id in run_ids: 118 | runs.append(_get_run_by_id(project_path, run_id, verbose=verbose)) 119 | 120 | for run_name_re in run_name_res: 121 | runs += _get_runs_by_name_re(project_path, run_name_re, verbose=verbose) 122 | 123 | return runs 124 | 125 | 126 | def _get_run_by_id(project_path, run_id, verbose=False): 127 | path = f"{project_path}/{run_id}" 128 | run = api.run(path) 129 | if verbose: 130 | print(f"selecting run {run.entity}/{run.project}/{run.id} with display name {run.name}") 131 | return run 132 | 133 | 134 | def _get_runs_by_name_re(project_path, run_name_re, verbose=False): 135 | matching_runs = api.runs(path=project_path, filters={"display_name": {"$regex": run_name_re}}) 136 | 137 | if verbose: 138 | for run in matching_runs: 139 | print(f"selecting run {run.entity}/{run.project}/{run.id} with display name {run.name}") 140 | 141 | return matching_runs 142 | 143 | 144 | def _get_selector_from(args, verbose=False): 145 | if args.all: 146 | if verbose: 147 | print("removing all artifacts from matching runs") 148 | return lambda _: True 149 | 150 | if args.no_alias: 151 | if verbose: 152 | print("removing all artifacts with no aliases from matching runs") 153 | return lambda artifact: artifact.aliases == [] 154 | 155 | if args.aliases: 156 | if verbose: 157 | print(f"removing all artifacts with any of {args.aliases} in aliases from matching runs") 158 | return lambda artifact: any(alias in artifact.aliases for alias in args.aliases) 159 | 160 | if verbose: 161 | print("removing no artifacts 
matching runs") 162 | return lambda _: False 163 | 164 | 165 | if __name__ == "__main__": 166 | parser = _setup_parser() 167 | args = parser.parse_args() 168 | main(args) 169 | -------------------------------------------------------------------------------- /app_gradio/app.py: -------------------------------------------------------------------------------- 1 | """Provide a webcam screenshot and a burning question and get back a lover-like accurate answer!""" 2 | import json 3 | import logging 4 | import os 5 | from pathlib import Path 6 | from typing import Callable 7 | 8 | from dotenv import load_dotenv 9 | import gradio as gr 10 | from PIL import ImageStat 11 | from PIL.Image import Image 12 | import requests 13 | from util import encode_b64_image 14 | 15 | 16 | os.environ["CUDA_VISIBLE_DEVICES"] = "" # do not use GPU 17 | 18 | logging.basicConfig(level=logging.INFO) 19 | 20 | load_dotenv() # load environment variables from a .env file if it exists 21 | BACKEND_URL = os.getenv("BACKEND_URL") # URL of a backend to which to send image data 22 | 23 | APP_DIR = Path(__file__).resolve().parent # what is the directory for this application? 24 | FAVICON = APP_DIR / "logo.jpeg" # path to a small image for display in browser tab and social media 25 | README = APP_DIR / "README.md" # path to an app readme file in HTML/markdown 26 | 27 | DEFAULT_PORT = 11700 28 | 29 | 30 | def main(): 31 | predictor = PredictorBackend(use_url=True) 32 | frontend = make_frontend(predictor.run, flagging=True) 33 | frontend.launch( 34 | server_name="0.0.0.0", # make server accessible, binding all interfaces # noqa: S104 35 | favicon_path=FAVICON, # what icon should we display in the address bar? 36 | ) 37 | 38 | 39 | def make_frontend(fn: Callable[[Image, str], str], flagging: bool = False): 40 | """Creates a gradio.Interface frontend for an image + text to text function.""" 41 | img_examples_dir = Path("question_answer") / "tests" / "support" / "images" 42 | img_example_fnames = [elem for elem in os.listdir(img_examples_dir) if elem.endswith(".jpg")] 43 | img_example_paths = [img_examples_dir / fname for fname in img_example_fnames] 44 | img_example_paths = sorted(img_example_paths) 45 | 46 | question_examples_dir = Path("question_answer") / "tests" / "support" / "questions" 47 | question_example_fnames = [elem for elem in os.listdir(question_examples_dir) if elem.endswith(".txt")] 48 | question_example_paths = [question_examples_dir / fname for fname in question_example_fnames] 49 | question_example_paths = sorted(question_example_paths) 50 | 51 | questions = [] 52 | for path in question_example_paths: 53 | with open(path, "r") as f: 54 | questions.append(f.readline()) 55 | 56 | examples = [[str(img_path), question] for img_path, question in zip(img_example_paths, questions)] 57 | 58 | allow_flagging = "never" 59 | if flagging: # logging user feedback to a local CSV file 60 | allow_flagging = "manual" 61 | flagging_callback = gr.CSVLogger() 62 | flagging_dir = "flagged" 63 | else: 64 | flagging_callback, flagging_dir = None, None 65 | 66 | readme = _load_readme(with_logging=allow_flagging == "manual") 67 | 68 | # build a basic browser interface to a Python function 69 | frontend = gr.Interface( 70 | fn=fn, # which Python function are we interacting with? 71 | outputs=gr.components.Textbox(), # what output widgets does it need? the default text widget 72 | # what input widgets does it need? 
we configure an image widget 73 | inputs=[ 74 | gr.components.Image(type="pil", label="Webcam Image", source="webcam"), 75 | gr.components.Textbox(label="Question"), 76 | ], 77 | title="Admirer", # what should we display at the top of the page? 78 | thumbnail=FAVICON, # what should we display when the link is shared, e.g. on social media? 79 | description=__doc__, # what should we display just above the interface? 80 | article=readme, # what long-form content should we display below the interface? 81 | examples=examples, # which potential inputs should we provide? 82 | cache_examples=False, # should we cache those inputs for faster inference? slows down start 83 | allow_flagging=allow_flagging, # should we show users the option to "flag" outputs? 84 | flagging_options=["incorrect", "offensive", "other"], # what options do users have for feedback? 85 | flagging_callback=flagging_callback, 86 | flagging_dir=flagging_dir, 87 | ) 88 | 89 | return frontend 90 | 91 | 92 | class PredictorBackend: 93 | """Interface to a backend that serves predictions. 94 | 95 | To communicate with a backend accessible via a URL, provide the url kwarg. 96 | 97 | Otherwise, runs a predictor locally. 98 | """ 99 | 100 | def __init__(self, use_url): 101 | if use_url: 102 | self.url = BACKEND_URL 103 | self._predict = self._predict_from_endpoint 104 | # Uncomment the following lines to run the predictor locally 105 | # else: 106 | # from question_answer.answer import Pipeline 107 | 108 | # model = Pipeline() 109 | # self._predict = model.predict 110 | 111 | def run(self, image, question): 112 | pred, metrics = self._predict_with_metrics(image, question) 113 | self._log_inference(pred, metrics) 114 | return pred 115 | 116 | def _predict_with_metrics(self, image, question): 117 | pred = self._predict(image, question) 118 | 119 | stats = ImageStat.Stat(image) 120 | metrics = { 121 | "image_mean_intensity": stats.mean, 122 | "image_median": stats.median, 123 | "image_extrema": stats.extrema, 124 | "image_area": image.size[0] * image.size[1], 125 | "pred_length": len(pred), 126 | } 127 | return pred, metrics 128 | 129 | def _predict_from_endpoint(self, image, question): 130 | """Send an image and question to an endpoint that accepts JSON and return the predicted text. 131 | 132 | The endpoint should expect a base64 representation of the image, encoded as a string, 133 | under the key "image" and a str representation of the question. It should return the predicted text under the key "pred". 134 | 135 | Parameters 136 | ---------- 137 | image 138 | A PIL image of handwritten text to be converted into a string 139 | 140 | question 141 | A string containing the user's question 142 | 143 | Returns 144 | ------- 145 | pred 146 | A string containing the predictor's guess of the text in the image. 
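        Examples
        --------
        For illustration only: the base64 string is a placeholder for the encoded
        image built below, and the answer is a hypothetical backend response. This
        shows the JSON shapes this method sends and expects back:

            payload = {
                "image": "data:image/jpg;base64,<BASE64-ENCODED-IMAGE>",
                "question": "data:question/str;str,How many people are in the room?",
            }
            response = {"pred": "two"}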
147 | """ 148 | encoded_image = encode_b64_image(image) 149 | 150 | headers = {"Content-type": "application/json"} 151 | payload = json.dumps( 152 | {"image": "data:image/jpg;base64," + encoded_image, "question": "data:question/str;str," + question} 153 | ) 154 | 155 | response = requests.post(self.url, data=payload, headers=headers) 156 | print(response.json()) 157 | pred = response.json()["pred"] 158 | 159 | return pred 160 | 161 | def _log_inference(self, pred, metrics): 162 | for key, value in metrics.items(): 163 | logging.info(f"METRIC {key} {value}") 164 | logging.info(f"PRED >begin\n{pred}\nPRED >end") 165 | 166 | 167 | def _load_readme(with_logging=False): 168 | with open(README) as f: 169 | lines = f.readlines() 170 | if not with_logging: 171 | lines = lines[: lines.index("\n")] 172 | 173 | readme = "".join(lines) 174 | return readme 175 | 176 | 177 | if __name__ == "__main__": 178 | main() 179 | -------------------------------------------------------------------------------- /training/run_experiment.py: -------------------------------------------------------------------------------- 1 | """Experiment-running framework.""" 2 | import argparse 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only 8 | import torch 9 | 10 | from question_answer import callbacks as cb 11 | from question_answer import lit_models 12 | from training.util import DATA_CLASS_MODULE, import_class, MODEL_CLASS_MODULE, setup_data_and_model_from_args 13 | 14 | 15 | # In order to ensure reproducible experiments, we must set random seeds. 16 | np.random.seed(42) 17 | torch.manual_seed(42) 18 | 19 | 20 | def _setup_parser(): 21 | """Set up Python's ArgumentParser with data, model, trainer, and other arguments.""" 22 | parser = argparse.ArgumentParser(add_help=False) 23 | 24 | # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision 25 | trainer_parser = pl.Trainer.add_argparse_args(parser) 26 | trainer_parser._action_groups[1].title = "Trainer Args" 27 | parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser]) 28 | parser.set_defaults(max_epochs=1) 29 | 30 | # Basic arguments 31 | parser.add_argument( 32 | "--wandb", 33 | action="store_true", 34 | default=False, 35 | help="If passed, logs experiment results to Weights & Biases. Otherwise logs only to local Tensorboard.", 36 | ) 37 | parser.add_argument( 38 | "--profile", 39 | action="store_true", 40 | default=False, 41 | help="If passed, uses the PyTorch Profiler to track computation, exported as a Chrome-style trace.", 42 | ) 43 | parser.add_argument( 44 | "--data_class", 45 | type=str, 46 | default="PICa", 47 | help=f"String identifier for the data class, relative to {DATA_CLASS_MODULE}.", 48 | ) 49 | parser.add_argument( 50 | "--model_class", 51 | type=str, 52 | default="ViT2GPT2", 53 | help=f"String identifier for the model class, relative to {MODEL_CLASS_MODULE}.", 54 | ) 55 | parser.add_argument( 56 | "--load_checkpoint", type=str, default=None, help="If passed, loads a model from the provided path." 57 | ) 58 | parser.add_argument( 59 | "--stop_early", 60 | type=int, 61 | default=0, 62 | help="If non-zero, applies early stopping, with the provided value as the 'patience' argument." 
63 | + " Default is 0.", 64 | ) 65 | 66 | # Get the data and model classes, so that we can add their specific arguments 67 | temp_args, _ = parser.parse_known_args() 68 | data_class = import_class(f"{DATA_CLASS_MODULE}.{temp_args.data_class}") 69 | model_class = import_class(f"{MODEL_CLASS_MODULE}.{temp_args.model_class}") 70 | 71 | # Get data, model, and LitModel specific arguments 72 | data_group = parser.add_argument_group("Data Args") 73 | data_class.add_to_argparse(data_group) 74 | 75 | model_group = parser.add_argument_group("Model Args") 76 | model_class.add_to_argparse(model_group) 77 | 78 | lit_model_group = parser.add_argument_group("LitModel Args") 79 | lit_models.GPT2.add_to_argparse(lit_model_group) 80 | 81 | parser.add_argument("--help", "-h", action="help") 82 | return parser 83 | 84 | 85 | @rank_zero_only 86 | def _ensure_logging_dir(experiment_dir): 87 | """Create the logging directory via the rank-zero process, if necessary.""" 88 | Path(experiment_dir).mkdir(parents=True, exist_ok=True) 89 | 90 | 91 | def main(): 92 | """ 93 | Run an experiment. 94 | 95 | Sample command: 96 | ``` 97 | python training/run_experiment.py --max_epochs=3 --gpus='0,' --num_workers=20 --model_class=MLP --data_class=MNIST 98 | ``` 99 | 100 | For basic help documentation, run the command 101 | ``` 102 | python training/run_experiment.py --help 103 | ``` 104 | 105 | The available command line args differ depending on some of the arguments, including --model_class and --data_class. 106 | 107 | To see which command line args are available and read their documentation, provide values for those arguments 108 | before invoking --help, like so: 109 | ``` 110 | python training/run_experiment.py --model_class=MLP --data_class=MNIST --help 111 | """ 112 | parser = _setup_parser() 113 | args = parser.parse_args() 114 | data, model = setup_data_and_model_from_args(args) 115 | 116 | lit_model_class = lit_models.GPT2 117 | 118 | if args.load_checkpoint is not None: 119 | lit_model = lit_model_class.load_from_checkpoint( 120 | args.load_checkpoint, args=args, model=model.vit2gpt2, tokenizer=model.gpt2_tokenizer 121 | ) 122 | else: 123 | lit_model = lit_model_class(model=model.vit2gpt2, tokenizer=model.gpt2_tokenizer, args=args) 124 | 125 | log_dir = Path("training") / "logs" 126 | _ensure_logging_dir(log_dir) 127 | logger = pl.loggers.TensorBoardLogger(log_dir) 128 | experiment_dir = logger.log_dir 129 | 130 | goldstar_metric = "validation/loss" 131 | filename_format = "epoch={epoch:04d}-validation.loss={validation/loss:.3f}" 132 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 133 | save_top_k=5, 134 | filename=filename_format, 135 | monitor=goldstar_metric, 136 | mode="min", 137 | auto_insert_metric_name=False, 138 | dirpath=experiment_dir, 139 | every_n_epochs=args.check_val_every_n_epoch, 140 | ) 141 | 142 | summary_callback = pl.callbacks.ModelSummary(max_depth=2) 143 | 144 | callbacks = [summary_callback, checkpoint_callback] 145 | if args.wandb: 146 | logger = pl.loggers.WandbLogger(log_model="all", save_dir=str(log_dir), job_type="train") 147 | logger.watch(model, log_freq=max(100, args.log_every_n_steps)) 148 | logger.log_hyperparams(vars(args)) 149 | experiment_dir = logger.experiment.dir 150 | callbacks += [cb.ModelSizeLogger(), cb.LearningRateMonitor()] 151 | if args.stop_early: 152 | early_stopping_callback = pl.callbacks.EarlyStopping( 153 | monitor="validation/loss", mode="min", patience=args.stop_early 154 | ) 155 | callbacks.append(early_stopping_callback) 156 | 157 | if args.wandb: 158 | 
callbacks.append(cb.ImageToTextLogger()) 159 | 160 | trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger) 161 | if args.profile: 162 | sched = torch.profiler.schedule(wait=0, warmup=3, active=4, repeat=0) 163 | profiler = pl.profiler.PyTorchProfiler(export_to_chrome=True, schedule=sched, dirpath=experiment_dir) 164 | profiler.STEP_FUNCTIONS = {"training_step"} # only profile training 165 | else: 166 | profiler = pl.profiler.PassThroughProfiler() 167 | 168 | trainer.profiler = profiler 169 | 170 | trainer.tune(lit_model, datamodule=data) # If passing --auto_lr_find, this will set learning rate 171 | 172 | trainer.fit(lit_model, datamodule=data) 173 | 174 | trainer.profiler = pl.profiler.PassThroughProfiler() # turn profiling off during testing 175 | 176 | best_model_path = checkpoint_callback.best_model_path 177 | if best_model_path: 178 | rank_zero_info(f"Best model saved at: {best_model_path}") 179 | if args.wandb: 180 | rank_zero_info("Best model also uploaded to W&B ") 181 | trainer.test(datamodule=data, ckpt_path=best_model_path) 182 | else: 183 | trainer.test(lit_model, datamodule=data) 184 | 185 | 186 | if __name__ == "__main__": 187 | main() 188 | -------------------------------------------------------------------------------- /requirements/prod.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with python 3.7 3 | # To update, run: 4 | # 5 | # pip-compile requirements/prod.in 6 | # 7 | aiohttp==3.8.1 8 | # via gradio 9 | aiosignal==1.2.0 10 | # via aiohttp 11 | analytics-python==1.4.0 12 | # via gradio 13 | anyio==3.6.1 14 | # via 15 | # httpcore 16 | # openai 17 | # starlette 18 | asgiref==3.5.2 19 | # via uvicorn 20 | async-timeout==4.0.2 21 | # via aiohttp 22 | asynctest==0.13.0 23 | # via aiohttp 24 | attrs==20.3.0 25 | # via aiohttp 26 | backoff==1.10.0 27 | # via analytics-python 28 | bcrypt==3.2.2 29 | # via paramiko 30 | boto3==1.24.40 31 | # via smart-open 32 | botocore==1.27.40 33 | # via 34 | # boto3 35 | # s3transfer 36 | cached-property==1.5.2 37 | # via h5py 38 | certifi==2021.10.8 39 | # via 40 | # httpcore 41 | # httpx 42 | # requests 43 | # sentry-sdk 44 | cffi==1.15.0 45 | # via 46 | # bcrypt 47 | # cryptography 48 | # pynacl 49 | charset-normalizer==2.0.12 50 | # via 51 | # aiohttp 52 | # requests 53 | click==8.1.2 54 | # via 55 | # uvicorn 56 | # wandb 57 | coloredlogs==15.0.1 58 | # via onnxruntime 59 | cryptography==37.0.2 60 | # via paramiko 61 | cycler==0.11.0 62 | # via matplotlib 63 | distro==1.8.0 64 | # via openai 65 | docker-pycreds==0.4.0 66 | # via wandb 67 | fastapi==0.78.0 68 | # via gradio 69 | ffmpy==0.3.0 70 | # via gradio 71 | filelock==3.8.0 72 | # via 73 | # huggingface-hub 74 | # transformers 75 | flatbuffers==22.9.24 76 | # via onnxruntime 77 | fonttools==4.33.3 78 | # via matplotlib 79 | frozenlist==1.3.0 80 | # via 81 | # aiohttp 82 | # aiosignal 83 | fsspec==2022.5.0 84 | # via gradio 85 | gitdb==4.0.9 86 | # via gitpython 87 | gitpython==3.1.29 88 | # via wandb 89 | gradio==3.0.21 90 | # via -r requirements/prod.in 91 | h11==0.12.0 92 | # via 93 | # httpcore 94 | # uvicorn 95 | h5py==3.6.0 96 | # via -r requirements/prod.in 97 | httpcore==0.15.0 98 | # via httpx 99 | httpx==0.23.0 100 | # via 101 | # gradio 102 | # openai 103 | huggingface-hub==0.10.0 104 | # via 105 | # timm 106 | # transformers 107 | humanfriendly==10.0 108 | # via coloredlogs 109 | idna==3.3 110 | # via 111 | # anyio 112 | # requests 113 | # rfc3986 
114 | # yarl 115 | importlib-metadata==4.11.3 116 | # via 117 | # -r requirements/prod.in 118 | # click 119 | # huggingface-hub 120 | # transformers 121 | intel-openmp==2022.2.0 122 | # via mkl 123 | jinja2==2.11.3 124 | # via 125 | # -r requirements/prod.in 126 | # gradio 127 | jmespath==1.0.1 128 | # via 129 | # boto3 130 | # botocore 131 | kiwisolver==1.4.2 132 | # via matplotlib 133 | linkify-it-py==1.0.3 134 | # via markdown-it-py 135 | markdown-it-py[linkify,plugins]==2.1.0 136 | # via 137 | # gradio 138 | # mdit-py-plugins 139 | markupsafe==1.1.1 140 | # via 141 | # -r requirements/prod.in 142 | # jinja2 143 | matplotlib==3.5.2 144 | # via gradio 145 | mdit-py-plugins==0.3.0 146 | # via markdown-it-py 147 | mdurl==0.1.1 148 | # via markdown-it-py 149 | mkl==2022.2.0 150 | # via mkl-service 151 | mkl-service==2.4.0 152 | # via -r requirements/prod.in 153 | monotonic==1.6 154 | # via analytics-python 155 | mpmath==1.2.1 156 | # via sympy 157 | multidict==6.0.2 158 | # via 159 | # aiohttp 160 | # yarl 161 | numpy==1.21.6 162 | # via 163 | # -r requirements/prod.in 164 | # gradio 165 | # h5py 166 | # matplotlib 167 | # onnxruntime 168 | # pandas 169 | # torchvision 170 | # transformers 171 | onnxruntime==1.12.1 172 | # via -r requirements/prod.in 173 | openai==1.1.1 174 | # via -r requirements/prod.in 175 | orjson==3.7.2 176 | # via gradio 177 | packaging==20.9 178 | # via 179 | # huggingface-hub 180 | # matplotlib 181 | # onnxruntime 182 | # transformers 183 | pandas==1.3.5 184 | # via gradio 185 | paramiko==2.11.0 186 | # via gradio 187 | pathtools==0.1.2 188 | # via wandb 189 | pillow==7.1.2 190 | # via 191 | # -r requirements/prod.in 192 | # gradio 193 | # matplotlib 194 | # torchvision 195 | promise==2.3 196 | # via wandb 197 | protobuf==3.20.3 198 | # via 199 | # onnxruntime 200 | # wandb 201 | psutil==5.9.3 202 | # via wandb 203 | pycparser==2.21 204 | # via cffi 205 | pycryptodome==3.14.1 206 | # via gradio 207 | pydantic==1.9.1 208 | # via 209 | # fastapi 210 | # gradio 211 | # openai 212 | pydub==0.25.1 213 | # via gradio 214 | pynacl==1.5.0 215 | # via paramiko 216 | pyngrok==5.1.0 217 | # via -r requirements/prod.in 218 | pyparsing==2.4.2 219 | # via 220 | # matplotlib 221 | # packaging 222 | python-dateutil==2.8.2 223 | # via 224 | # analytics-python 225 | # botocore 226 | # matplotlib 227 | # pandas 228 | # wandb 229 | python-dotenv==0.21.0 230 | # via -r requirements/prod.in 231 | python-multipart==0.0.5 232 | # via gradio 233 | pytz==2022.1 234 | # via pandas 235 | pyyaml==5.4.1 236 | # via 237 | # huggingface-hub 238 | # pyngrok 239 | # timm 240 | # transformers 241 | # wandb 242 | regex==2022.9.13 243 | # via transformers 244 | requests==2.27.1 245 | # via 246 | # -r requirements/prod.in 247 | # analytics-python 248 | # gradio 249 | # huggingface-hub 250 | # torchvision 251 | # transformers 252 | # wandb 253 | rfc3986[idna2008]==1.5.0 254 | # via httpx 255 | s3transfer==0.6.0 256 | # via boto3 257 | sentry-sdk==1.10.1 258 | # via wandb 259 | setproctitle==1.3.2 260 | # via wandb 261 | shortuuid==1.0.9 262 | # via wandb 263 | six==1.16.0 264 | # via 265 | # analytics-python 266 | # docker-pycreds 267 | # mkl-service 268 | # paramiko 269 | # promise 270 | # python-dateutil 271 | # python-multipart 272 | # wandb 273 | smart-open[s3]==5.2.1 274 | # via -r requirements/prod.in 275 | smmap==5.0.0 276 | # via gitdb 277 | sniffio==1.2.0 278 | # via 279 | # anyio 280 | # httpcore 281 | # httpx 282 | starlette==0.19.1 283 | # via fastapi 284 | sympy==1.10.1 285 | # via 
onnxruntime 286 | tbb==2021.7.0 287 | # via mkl 288 | timm==0.6.11 289 | # via -r requirements/prod.in 290 | tokenizers==0.12.1 291 | # via transformers 292 | torch==1.12.0 293 | # via 294 | # -r requirements/prod.in 295 | # timm 296 | # torchvision 297 | torchvision==0.13.0 298 | # via 299 | # -r requirements/prod.in 300 | # timm 301 | tqdm==4.64.0 302 | # via 303 | # -r requirements/prod.in 304 | # huggingface-hub 305 | # openai 306 | # transformers 307 | transformers==4.22.2 308 | # via -r requirements/prod.in 309 | typing-extensions==4.7.1 310 | # via 311 | # aiohttp 312 | # anyio 313 | # asgiref 314 | # async-timeout 315 | # gitpython 316 | # huggingface-hub 317 | # importlib-metadata 318 | # kiwisolver 319 | # markdown-it-py 320 | # openai 321 | # pydantic 322 | # starlette 323 | # torch 324 | # torchvision 325 | # uvicorn 326 | # yarl 327 | uc-micro-py==1.0.1 328 | # via linkify-it-py 329 | urllib3==1.26.12 330 | # via 331 | # botocore 332 | # requests 333 | # sentry-sdk 334 | uvicorn==0.17.6 335 | # via gradio 336 | wandb==0.12.17 337 | # via -r requirements/prod.in 338 | yarl==1.7.2 339 | # via aiohttp 340 | zipp==3.8.0 341 | # via importlib-metadata 342 | 343 | # The following packages are considered to be unsafe in a requirements file: 344 | # setuptools 345 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # admirer 2 | 3 | ![demo](./assets/demo.png) 4 | 5 | ## Contents 6 | 7 | - [admirer](#admirer) 8 | - [Contents](#contents) 9 | - [The Transformees](#the-transformees) 10 | - [Description](#description) 11 | - [Inference Pipeline](#inference-pipeline) 12 | - [Usage](#usage) 13 | - [Production](#production) 14 | - [Development](#development) 15 | - [Contributing](#contributing) 16 | - [Setup](#setup) 17 | - [Repository Structure](#repository-structure) 18 | - [Testing](#testing) 19 | - [Code Style](#code-style) 20 | - [Credit](#credit) 21 | 22 | ## The Transformees 23 | 24 | 1. [Andrew Hinh](https://github.com/andrewhinh) 25 | 2. [Aleks Hiidenhovi](https://github.com/alekshiidenhovi) 26 | 27 | ## Description 28 | 29 | A website that uses webcam feeds to answer open-ended questions requiring outside knowledge. For more info, check out the ZenML [blog post](https://bit.ly/3BZ4YpB). 30 | 31 | ### Inference Pipeline 32 | 33 | ![inference_pipeline](./assets/inference_pipeline.png) 34 | 35 | The visual question-answering pipeline is inspired by the paper from Microsoft linked in the [credit section](#credit). In short, we prompt GPT-3 with a generated image caption and object tag list, the question-answer pair, and context examples that demonstrate the task at hand in a few-shot learning method, achieving a [BERTScore](http://bit.ly/3tM1mmc) computed F1 score of around .989 on the test set. 36 | 37 | ### Usage 38 | 39 | As a direct consequence of not feeding the image data directly to GPT-3, the best queries involve asking descriptive, counting, or similar questions about one or more objects visible in the background. For example, if there are two people in the image, one wearing a hat and the other wearing glasses, questions that would work well could include the following: 40 | 41 | - "How many people are in the room?" 42 | - "What color is the hat in the picture?" 43 | - "How many people are wearing glasses?" 44 | 45 | ## Production 46 | 47 | To setup the production server for the website, we: 48 | 49 | 1. 
Create an AWS Lambda function for the backend: 50 | 51 | ```bash 52 | . deploy/aws_login.sh 53 | python deploy/aws_lambda.py 54 | ``` 55 | 56 | 2. Implement continual development by updating the AWS Lambda backend whenever a commit is pushed to the repo and the BERTScore computed F1 score of the pipeline has improved: 57 | 58 | ```bash 59 | . deploy/cont_deploy.sh 60 | ``` 61 | 62 | ## Development 63 | 64 | ### Contributing 65 | 66 | To contribute, check out the [guide](./CONTRIBUTING.md). 67 | 68 | ### Setup 69 | 70 | 1. Install conda if necessary: 71 | 72 | ```bash 73 | # Install conda: https://conda.io/projects/conda/en/latest/user-guide/install/index.html#regular-installation 74 | # If on Windows, install chocolately: https://chocolatey.org/install. Then, run: 75 | # choco install make 76 | ``` 77 | 78 | 2. Create the conda environment locally: 79 | 80 | ```bash 81 | cd admirer 82 | make conda-update 83 | conda activate admirer 84 | make pip-tools 85 | export PYTHONPATH=. 86 | echo "export PYTHONPATH=.:$PYTHONPATH" >> ~/.bashrc 87 | ``` 88 | 89 | 3. Install pre-commit: 90 | 91 | ```bash 92 | pre-commit install 93 | ``` 94 | 95 | 4. Sign up for an OpenAI account and get an API key [here](https://beta.openai.com/account/api-keys). 96 | 5. Populate a `.env` file with your key and the backend URL in the format of `.env.template`, and reactivate the environment. 97 | 6. Sign up for a Weights and Biases account [here](https://wandb.ai/signup) and download the CLIP ONNX file locally: 98 | 99 | ```bash 100 | wandb login 101 | python ./training/stage_model.py --fetch --from_project admirer 102 | ``` 103 | 104 | 7. (Optional) Sign up for an AWS account [here](https://us-west-2.console.aws.amazon.com/ecr/create-repository?region=us-west-2) and set up your AWS credentials locally, referring to [this](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-quickstart.html#cli-configure-quickstart-config) as needed: 105 | 106 | ```bash 107 | aws configure 108 | ``` 109 | 110 | If the instructions aren't working for you, head to [this Google Colab](https://colab.research.google.com/drive/1Z34DLHJm1i1e1tnknICujfZC6IaToU3k?usp=sharing), make a copy of it, and run the cells there to get an environment set up. 111 | 112 | ### Repository Structure 113 | 114 | The repo is separated into main folders that each describe a part of the ML-project lifecycle, some of which contain interactive notebooks, and supporting files and folders that store configurations and workflow scripts: 115 | 116 | ```bash 117 | . 118 | ├── api_serverless # the backend handler code using AWS Lambda. 119 | ├── app_gradio # the frontend code using Gradio. 120 | ├── deploy # the AWS Lambda backend setup and continuous deployment code. 121 | ├── data_manage # the data management code using AWS S3 for training data and ZenML log storage, boto3 for data exploration, and ZenML + Great Expectations for data validation. 122 | ├── load_test # the load testing code using Locust. 123 | ├── monitoring # the model monitoring code using Gradio's flagging feature. 124 | ├── question_answer # the inference code. 125 | ├── tasks # the pipeline testing code. 126 | ├── training # the model development code using PyTorch, PyTorch Lightning, and Weights and Biases. 127 | ``` 128 | 129 | ### Testing 130 | 131 | From the main directory, there are various ways to test the pipeline: 132 | 133 | - To start a W&B hyperparameter optimization sweep for the caption model (on one GPU): 134 | 135 | ```bash 136 | . 
./training/sweep/sweep.sh 137 | CUDA_VISIBLE_DEVICES=0 wandb agent --project ${PROJECT} --entity ${ENTITY} ${SWEEP_ID} 138 | ``` 139 | 140 | - To train the caption model (add `--strategy ddp_find_unused_parameters_false` for multi-GPU machines; takes ~7.5 hrs on an 8xA100 Lambda Labs instance): 141 | 142 | ```bash 143 | python ./training/run_experiment.py \ 144 | --data_class PICa --model_class ViT2GPT2 --gpus "-1" \ 145 | --wandb --log_every_n_steps 25 --max_epochs 300 \ 146 | --augment_data True --num_workers "$(nproc)" \ 147 | --batch_size 2 --one_cycle_max_lr 0.01 --top_k 780 --top_p 0.65 --max_label_length 50 148 | ``` 149 | 150 | - To test the caption model (best model can be downloaded from [here](https://wandb.ai/admirer/admirer-training/artifacts/model/model-2vgqajre/v4/files)): 151 | 152 | ```bash 153 | python ./training/test_model.py \ 154 | --data_class PICa --model_class ViT2GPT2 \ 155 | --num_workers "$(nproc)" --load_checkpoint training/model.pth 156 | ``` 157 | 158 | - To start the app locally (uncomment code in PredictorBackend.__init__ and set use_url=False to use the local model instead of the API): 159 | 160 | ```bash 161 | python app_gradio/app.py 162 | ``` 163 | 164 | - To test the Gradio frontend by launching and pinging the frontend locally: 165 | 166 | ```bash 167 | python -c "from app_gradio.tests.test_app import test_local_run; test_local_run()" 168 | ``` 169 | 170 | - To test the caption model's ability to memorize a single batch: 171 | 172 | ```bash 173 | . ./training/tests/test_memorize_caption.sh 174 | ``` 175 | 176 | - To run integration tests for the model pipeline: 177 | 178 | ```bash 179 | . ./tasks/integration_test.sh 180 | ``` 181 | 182 | - To run unit tests for the model pipeline: 183 | 184 | ```bash 185 | . ./tasks/unit_test.sh 186 | ``` 187 | 188 | - To test the whole model pipeline: 189 | 190 | ```bash 191 | . ./tasks/test.sh 192 | ``` 193 | 194 | ### Code Style 195 | 196 | - To lint your code: 197 | 198 | ```bash 199 | pre-commit run --all-files 200 | ``` 201 | 202 | ## Credit 203 | 204 | - GI4E for their [database](https://www.unavarra.es/gi4e/databases/gi4e/?languageId=1) and [Scale AI](https://scale.com/) for their annotations. 205 | - Facebook for their [image segmentation model](https://huggingface.co/facebook/detr-resnet-50-panoptic). 206 | - NLP Connect for their [base image caption model](https://huggingface.co/nlpconnect/vit-gpt2-image-captioning) and Sachin Abeywardana for his [fine-tuning code](https://sachinruk.github.io/blog/pytorch/huggingface/2021/12/28/vit-to-gpt2-encoder-decoder-model.html). 207 | - OpenAI for their [CLIP text and image encoder code](https://huggingface.co/openai/clip-vit-base-patch16) and [GPT-3 API](https://openai.com/api/). 208 | - Microsoft for their [visual question answering code](https://github.com/microsoft/PICa). 
209 | -------------------------------------------------------------------------------- /question_answer/data/pica.py: -------------------------------------------------------------------------------- 1 | """PICa Dataset class.""" 2 | import argparse 3 | import json 4 | from pathlib import Path 5 | import requests 6 | from io import BytesIO 7 | from typing import Callable, Dict, Optional, Sequence, Tuple 8 | 9 | import numpy as np 10 | import pandas as pd 11 | from PIL import Image 12 | from pytorch_lightning.utilities.rank_zero import rank_zero_info 13 | import torch 14 | 15 | from question_answer import util 16 | from question_answer.data.base_data_module import BaseDataModule, load_and_print_info 17 | from question_answer.data.util import BaseDataset 18 | import question_answer.metadata.pica as metadata 19 | from question_answer.stems.webcam import WebcamStem 20 | 21 | 22 | IMAGE_SHAPE = metadata.IMAGE_SHAPE 23 | RAW_DATA_DIRNAME = metadata.RAW_DATA_DIRNAME 24 | PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME 25 | 26 | NUM_ADDED_EXAMPLES = metadata.NUM_ADDED_EXAMPLES 27 | NUM_TRAINVAL = metadata.NUM_TRAINVAL 28 | NUM_VAL_EXAMPLES = metadata.NUM_VAL_EXAMPLES 29 | 30 | # In order to ensure reproducible experiments, we must set random seeds. 31 | np.random.seed(42) 32 | torch.manual_seed(42) 33 | 34 | 35 | class PICa(BaseDataModule): 36 | """PICa webcam screenshots + annotations dataset.""" 37 | 38 | def __init__(self, args: argparse.Namespace = None): 39 | super().__init__(args) 40 | self.augment = self.args.get("augment_data", "true").lower() == "true" 41 | 42 | self.input_dims = metadata.DIMS # We assert that this is correct in setup() 43 | self.output_dims = metadata.OUTPUT_DIMS # We assert that this is correct in setup() 44 | 45 | self.transform = WebcamStem() 46 | self.trainval_transform = WebcamStem(augment=self.augment) 47 | 48 | self.data_file = RAW_DATA_DIRNAME / "admirer-pica.json" 49 | 50 | self.test_ids = self.calc_test_ids() 51 | self.validation_ids = self.calc_validation_ids() 52 | self.train_ids = self.calc_train_ids() 53 | self.all_ids = self.train_ids + self.validation_ids + self.test_ids 54 | 55 | @staticmethod 56 | def add_to_argparse(parser): 57 | BaseDataModule.add_to_argparse(parser) 58 | parser.add_argument("--augment_data", type=str, default="true") 59 | return parser 60 | 61 | def prepare_data(self, *args, **kwargs) -> None: 62 | if (PROCESSED_DATA_DIRNAME / "_properties.json").exists(): 63 | return 64 | rank_zero_info("PICa.prepare_data: Logging dataset info to a json file...") 65 | 66 | properties = {} 67 | for split in ["train", "val", "test"]: 68 | screenshots, captions = get_screenshots_and_captions(self, split=split) 69 | save_screenshots_and_captions(screenshots=screenshots, captions=captions, split=split) 70 | 71 | properties.update( 72 | { 73 | id_: { 74 | "image_shape": screenshots[id_].size[::-1], 75 | "num_words": _num_words(caption), 76 | } 77 | for id_, caption in captions.items() 78 | } 79 | ) 80 | 81 | with open(PROCESSED_DATA_DIRNAME / "_properties.json", "w") as f: 82 | json.dump(properties, f, indent=4) 83 | 84 | def load_image(self, id: str) -> Image.Image: 85 | """Load and return an image of a webcam screenshot.""" 86 | url = self.screenshot_url_by_id(id) 87 | response = requests.get(url) 88 | return util.read_image_pil_file(BytesIO(response.content)) 89 | 90 | def setup(self, stage: str = None) -> None: 91 | def _load_dataset(split: str, transform: Callable) -> BaseDataset: 92 | screenshots, captions = 
load_processed_crops_and_labels(split) 93 | return BaseDataset(screenshots, captions, transform=transform) 94 | 95 | rank_zero_info(f"PICa.setup({stage}): Loading PICa webcam screenshots and captions...") 96 | validate_input_and_output_dimensions(input_dims=self.input_dims, output_dims=self.output_dims) 97 | 98 | if stage == "fit" or stage is None: 99 | self.data_train = _load_dataset(split="train", transform=self.trainval_transform) 100 | self.data_val = _load_dataset(split="val", transform=self.transform) 101 | 102 | if stage == "test" or stage is None: 103 | self.data_test = _load_dataset(split="test", transform=self.transform) 104 | 105 | def __repr__(self) -> str: 106 | """Print info about the dataset.""" 107 | basic = "PICa Dataset\n" f"Input dims : {self.input_dims}\n" f"Output dims: {self.output_dims}\n" 108 | if self.data_train is None and self.data_val is None and self.data_test is None: 109 | return basic 110 | 111 | x, y = next(iter(self.train_dataloader())) 112 | xt, yt = next(iter(self.test_dataloader())) 113 | data = ( 114 | f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" 115 | f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" 116 | f"Train Batch y stats: {(y)}\n" 117 | f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" 118 | f"Test Batch y stats: {(yt)}\n" 119 | ) 120 | return basic + data 121 | 122 | def ids_by_split(self, split): 123 | return {"train": self.train_ids, "val": self.validation_ids, "test": self.test_ids}[split] 124 | 125 | def calc_train_ids(self): 126 | """A list of screenshot IDs which are in the training set.""" 127 | return list(set(range(0, NUM_ADDED_EXAMPLES)) - (set(self.test_ids) | set(self.validation_ids))) 128 | 129 | def calc_validation_ids(self): 130 | """A list of screenshot IDs which are in the validation set.""" 131 | ids = [] 132 | while len(ids) < NUM_VAL_EXAMPLES: 133 | id = np.random.randint(low=0, high=int(NUM_TRAINVAL)) 134 | if id in ids: 135 | continue 136 | else: 137 | ids.append(id) 138 | return ids 139 | 140 | def calc_test_ids(self): 141 | """A list of screenshot IDs which are in the test set.""" 142 | return list(range(NUM_TRAINVAL, NUM_ADDED_EXAMPLES)) 143 | 144 | def screenshot_url_by_id(self, id): 145 | """A dict mapping a screenshot id to its filename.""" 146 | df = pd.read_json(self.data_file) 147 | return df.loc[id, "webcam"] 148 | 149 | def caption_by_id(self, id): 150 | """A dict mapping a screenshot id to its caption.""" 151 | df = pd.read_json(self.data_file) 152 | return df.loc[id, "caption"] 153 | 154 | 155 | def validate_input_and_output_dimensions( 156 | input_dims: Optional[Tuple[int, ...]], output_dims: Optional[Tuple[int, ...]] 157 | ) -> None: 158 | """Validate input and output dimensions against the properties of the dataset.""" 159 | properties = get_dataset_properties() 160 | 161 | max_image_shape = properties["image_shape"]["max"] 162 | assert input_dims is not None and input_dims[1] >= max_image_shape[0] and input_dims[2] >= max_image_shape[1] 163 | 164 | # Add 2 because of start and end tokens 165 | assert output_dims is not None and output_dims[0] >= properties["num_words"]["max"] + 2 166 | 167 | 168 | def get_screenshots_and_captions(dataset: PICa, split: str) -> Tuple[Dict[str, Image.Image], Dict[str, str]]: 169 | """Create screenshots + captions for a given split, with resizing.""" 170 | screenshots = {} 171 | captions = {} 172 | ids = dataset.ids_by_split(split) 173 | for id in ids: 
174 | image = dataset.load_image(id) 175 | screenshots[id] = image.resize(IMAGE_SHAPE) 176 | captions[id] = dataset.caption_by_id(id) 177 | assert len(screenshots) == len(captions) 178 | return screenshots, captions 179 | 180 | 181 | def save_screenshots_and_captions(screenshots: Dict[str, Image.Image], captions: Dict[str, str], split: str): 182 | """Save crops, labels and shapes of crops of a split.""" 183 | (PROCESSED_DATA_DIRNAME / split).mkdir(parents=True, exist_ok=True) 184 | 185 | with open(_captions_filename(split), "w") as f: 186 | json.dump(captions, f, indent=4) 187 | 188 | for id_, crop in screenshots.items(): 189 | crop.save(_screenshot_filename(id_, split)) 190 | 191 | 192 | def load_processed_crops_and_labels(split: str) -> Tuple[Sequence[Image.Image], Sequence[str]]: 193 | """Load processed crops and labels for given split.""" 194 | with open(_captions_filename(split), "r") as f: 195 | labels = json.load(f) 196 | 197 | sorted_ids = sorted(labels.keys()) 198 | ordered_screenshots = [] 199 | ordered_captions = [] 200 | for id_ in sorted_ids: 201 | image = Image.open(_screenshot_filename(id_, split)) 202 | ordered_screenshots.append(image.convert(mode=image.mode)) 203 | ordered_captions.append(labels[id_]) 204 | 205 | assert len(ordered_screenshots) == len(ordered_captions) 206 | return ordered_screenshots, ordered_captions 207 | 208 | 209 | def get_dataset_properties() -> dict: 210 | """Return properties describing the overall dataset.""" 211 | with open(PROCESSED_DATA_DIRNAME / "_properties.json", "r") as f: 212 | properties = json.load(f) 213 | 214 | def _get_property_values(key: str) -> list: 215 | return [_[key] for _ in properties.values()] 216 | 217 | image_shapes = np.array(_get_property_values("image_shape")) 218 | aspect_ratios = image_shapes[:, 1] / image_shapes[:, 0] 219 | return { 220 | "num_words": {"min": min(_get_property_values("num_words")), "max": max(_get_property_values("num_words"))}, 221 | "image_shape": {"min": image_shapes.min(axis=0), "max": image_shapes.max(axis=0)}, 222 | "aspect_ratio": {"min": aspect_ratios.min(), "max": aspect_ratios.max()}, 223 | } 224 | 225 | 226 | def _captions_filename(split: str) -> Path: 227 | """Return filename of processed labels.""" 228 | return PROCESSED_DATA_DIRNAME / split / "_captions.json" 229 | 230 | 231 | def _screenshot_filename(id_: str, split: str) -> Path: 232 | """Return filename of processed crop.""" 233 | return PROCESSED_DATA_DIRNAME / split / f"{id_}.png" 234 | 235 | 236 | def _num_words(caption: str) -> int: 237 | """Return number of words in caption.""" 238 | word_list = caption.split() 239 | return len(word_list) 240 | 241 | 242 | if __name__ == "__main__": 243 | load_and_print_info(PICa) 244 | -------------------------------------------------------------------------------- /training/stage_model.py: -------------------------------------------------------------------------------- 1 | """Stages a model for use in production. 2 | 3 | If based on a checkpoint, the model is saved locally and uploaded to W&B. 4 | 5 | If based on a model that is already uploaded, the model file is downloaded locally. 6 | 7 | For details on how the W&B artifacts backing the checkpoints and models are handled, 8 | see the documenation for stage_model.find_artifact. 
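Example invocations (illustrative sketches; the run ID is a placeholder and the
flags are the ones defined in _setup_parser below):

    # download the latest already-staged model files locally
    python training/stage_model.py --fetch

    # stage the best checkpoint from a specific training run as a prod-ready artifact
    python training/stage_model.py --run <RUN_ID> --ckpt_alias best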
9 | """ 10 | # Imports 11 | import argparse 12 | from pathlib import Path 13 | import tempfile 14 | 15 | from dotenv import load_dotenv 16 | import torch 17 | import wandb 18 | from wandb import Artifact 19 | from wandb.sdk.wandb_run import Run 20 | 21 | from question_answer.lit_models import GPT2 22 | from training.util import setup_data_and_model_from_args 23 | 24 | 25 | # Variables 26 | # these names are all set by the pl.loggers.WandbLogger 27 | MODEL_CHECKPOINT_TYPE = "model" 28 | BEST_CHECKPOINT_ALIAS = "best" 29 | MODEL_CHECKPOINT_PATH = "model.ckpt" 30 | LOG_DIR = Path("training") / "logs" 31 | 32 | STAGED_MODEL_TYPE = "prod-ready" 33 | STAGED_MODEL_FILENAME = "model.pt" # standard nomenclature; pytorch_model.bin is also used 34 | 35 | PROJECT_ROOT = Path(__file__).resolve().parents[1] 36 | DIRECTORY = Path("question_answer/") 37 | 38 | LITMODEL_CLASS = GPT2 39 | 40 | api = wandb.Api() 41 | 42 | DEFAULT_ENTITY = api.default_entity 43 | DEFAULT_FROM_PROJECT = "admirer-training" 44 | DEFAULT_TO_PROJECT = "admirer-training" 45 | DEFAULT_STAGED_MODEL_NAME = "answer" 46 | 47 | PROD_STAGING_ROOT = PROJECT_ROOT / DIRECTORY / Path("artifacts") 48 | PROD_PATHS = ["coco_annotations", "coco_clip_new", "transformers", "onnx"] 49 | 50 | load_dotenv() 51 | 52 | 53 | def main(args): 54 | prod_staging_directory = PROD_STAGING_ROOT / args.staged_model_name 55 | prod_staging_directory.mkdir(exist_ok=True, parents=True) 56 | # if we're just fetching an already compiled model 57 | if args.fetch: 58 | staged_files = f"{DEFAULT_ENTITY}/{args.from_project}/{args.staged_model_name}:latest" 59 | artifact = download_artifact(staged_files, prod_staging_directory) 60 | print_info(artifact) 61 | return # and we're done 62 | 63 | # otherwise, we'll need to download the weights, compile the model, and save it 64 | with wandb.init( 65 | job_type="stage", project=args.to_project, dir=LOG_DIR 66 | ): # log staging to W&B so prod and training are connected 67 | # find the model checkpoint and retrieve its artifact name and an api handle 68 | if args.run: 69 | ckpt_at, ckpt_api = find_artifact( 70 | project=args.from_project, type=MODEL_CHECKPOINT_TYPE, alias=args.ckpt_alias, run=args.run 71 | ) 72 | 73 | # get the run that produced that checkpoint 74 | logging_run = get_logging_run(ckpt_api) 75 | print_info(ckpt_api, logging_run) 76 | metadata = get_checkpoint_metadata(logging_run, ckpt_api) 77 | 78 | with tempfile.TemporaryDirectory() as tmp_dir: 79 | # download the checkpoint to a temporary directory 80 | download_artifact(ckpt_at, tmp_dir) 81 | # reload the model from that checkpoint 82 | model = load_model_from_checkpoint(metadata, directory=tmp_dir) 83 | # save the model to .pt in the staging directory 84 | save_model_to_pt(model, directory=prod_staging_directory / "transformers" / "trained_caption") 85 | 86 | # create an artifact for the staged, deployable model 87 | staged_at = wandb.Artifact(args.staged_model_name, type=STAGED_MODEL_TYPE) 88 | # upload the staged model so it can be downloaded elsewhere 89 | upload_staged_model(staged_at, from_directory=prod_staging_directory) 90 | 91 | 92 | def find_artifact(project: str, type: str, alias: str, run=None): 93 | """Finds the artifact of a given type with a given alias under the entity and project. 94 | 95 | Parameters 96 | ---------- 97 | project 98 | The project to find the artifact from. 99 | type 100 | The name of the type of the artifact. 101 | alias : str 102 | The alias for this artifact. 
This alias must be unique within the 103 | provided type for the run, if provided, or for the project, 104 | if the run is not provided. 105 | run : str 106 | Optionally, the run in which the artifact is located. 107 | 108 | Returns 109 | ------- 110 | Tuple[path, artifact] 111 | An identifying path and an API handle for a matching artifact. 112 | """ 113 | if run is not None: 114 | path = _find_artifact_run(project=project, type=type, run=run, alias=alias) 115 | else: 116 | path = _find_artifact_project(project=project, type=type, alias=alias) 117 | return path, api.artifact(path) 118 | 119 | 120 | def get_logging_run(artifact: Artifact) -> Run: 121 | """Get the W&B run that logged the artifact""" 122 | api_run = artifact.logged_by() 123 | return api_run 124 | 125 | 126 | def print_info(artifact: Artifact, run=None) -> None: 127 | """Prints info about the artifact and the run""" 128 | run = get_logging_run(artifact) 129 | 130 | full_artifact_name = f"{artifact.entity}/{artifact.project}/{artifact.name}" 131 | print(f"Using artifact {full_artifact_name}") 132 | artifact_url_prefix = f"https://wandb.ai/{artifact.entity}/{artifact.project}/artifacts/{artifact.type}" 133 | artifact_url_suffix = f"{artifact.name.replace(':', '/')}" 134 | print(f"View at URL: {artifact_url_prefix}/{artifact_url_suffix}") 135 | 136 | print(f"Logged by {run.name} -- {run.project}/{run.entity}/{run.id}") 137 | print(f"View at URL: {run.url}") 138 | 139 | 140 | def get_checkpoint_metadata(run, checkpoint): 141 | config = run.config 142 | out = {"config": config} 143 | try: 144 | ckpt_filename = checkpoint.metadata["original_filename"] 145 | out["original_filename"] = ckpt_filename 146 | metric_key = checkpoint.metadata["ModelCheckpoint"]["monitor"] 147 | metric_score = checkpoint.metadata["score"] 148 | out[metric_key] = metric_score 149 | except KeyError: 150 | pass 151 | return out 152 | 153 | 154 | def download_artifact(artifact_path: str, target_directory: Path) -> Artifact: 155 | """Downloads the artifact at artifact_path to the target directory.""" 156 | if wandb.run is not None: # if we are inside a W&B run, track that we used this artifact 157 | artifact: Artifact = wandb.use_artifact(artifact_path) 158 | else: # otherwise, just download the artifact via the API 159 | artifact: Artifact = api.artifact(artifact_path) 160 | artifact.download(root=target_directory) 161 | 162 | return artifact 163 | 164 | 165 | def load_model_from_checkpoint(ckpt_metadata, directory): 166 | config = ckpt_metadata["config"] 167 | args = argparse.Namespace(**config) 168 | 169 | _, model = setup_data_and_model_from_args(args) 170 | 171 | # load LightningModule from checkpoint 172 | pth = Path(directory) / MODEL_CHECKPOINT_PATH 173 | lit_model = LITMODEL_CLASS.load_from_checkpoint( 174 | checkpoint_path=pth, args=args, model=model.vit2gpt2, tokenizer=model.gpt2_tokenizer, strict=False 175 | ) 176 | lit_model.eval() 177 | 178 | return lit_model 179 | 180 | 181 | def save_model_to_pt(model, directory): 182 | path = Path(directory) / STAGED_MODEL_FILENAME 183 | torch.save(model.state_dict(), path) 184 | 185 | 186 | def upload_staged_model(staged_at: Artifact, from_directory: Path) -> None: 187 | """Uploads a staged arfifact to W&B""" 188 | staged_at.add_dir(from_directory) 189 | wandb.log_artifact(staged_at) 190 | 191 | 192 | def _find_artifact_run(project, type, run, alias): 193 | run_name = f"{DEFAULT_ENTITY}/{project}/{run}" 194 | api_run = api.run(run_name) 195 | artifacts = api_run.logged_artifacts() 196 | 197 | match = [art for 
art in artifacts if alias in art.aliases and art.type == type] 198 | if not match: 199 | raise ValueError(f"No artifact with alias {alias} found at {run_name} of type {type}") 200 | if len(match) > 1: 201 | raise ValueError(f"Multiple artifacts ({len(match)}) with alias {alias} found at {run_name} of type {type}") 202 | return f"{DEFAULT_ENTITY}/{project}/{match[0].name}" 203 | 204 | 205 | def _find_artifact_project(project, type, alias): 206 | project_name = f"{DEFAULT_ENTITY}/{project}" 207 | api_project = api.project(project, entity=DEFAULT_ENTITY) 208 | api_artifact_types = api_project.artifacts_types() 209 | # loop through all artifact types in this project 210 | for artifact_type in api_artifact_types: 211 | if artifact_type.name != type: 212 | continue # skipping those that don't match type 213 | collections = artifact_type.collections() 214 | # loop through all artifacts and their versions 215 | for collection in collections: 216 | versions = collection.versions() 217 | for version in versions: 218 | if alias in version.aliases: # looking for the first one that matches the alias 219 | return f"{project_name}/{version.name}" 220 | raise ValueError(f"Artifact with alias {alias} not found in type {type} in {project_name}") 221 | raise ValueError(f"Artifact type {type} not found. {project_name} could be private or not exist.") 222 | 223 | 224 | def _setup_parser(): 225 | parser = argparse.ArgumentParser(description=__doc__) 226 | parser.add_argument( 227 | "--fetch", 228 | action="store_true", 229 | help=f"If provided, download the latest version of artifact files to {PROD_STAGING_ROOT}.", 230 | ) 231 | parser.add_argument( 232 | "--from_project", 233 | type=str, 234 | default=DEFAULT_FROM_PROJECT, 235 | help=f"Project from which to download the checkpoint. Default is {DEFAULT_FROM_PROJECT}", 236 | ) 237 | parser.add_argument( 238 | "--to_project", 239 | type=str, 240 | default=DEFAULT_TO_PROJECT, 241 | help=f"Project to which to upload the compiled model. Default is {DEFAULT_TO_PROJECT}.", 242 | ) 243 | parser.add_argument( 244 | "--run", 245 | type=str, 246 | default=None, 247 | help=f"Optionally, the name of a run to check for an artifact of type {MODEL_CHECKPOINT_TYPE} that has the provided CKPT_ALIAS. Default is None.", 248 | ) 249 | parser.add_argument( 250 | "--ckpt_alias", 251 | type=str, 252 | default=BEST_CHECKPOINT_ALIAS, 253 | help=f"Alias that identifies which model checkpoint should be staged.The artifact's alias can be set manually or programmatically elsewhere. Default is '{BEST_CHECKPOINT_ALIAS}'.", 254 | ) 255 | parser.add_argument( 256 | "--staged_model_name", 257 | type=str, 258 | default=DEFAULT_STAGED_MODEL_NAME, 259 | help=f"Name to give the staged model artifact. 
Default is '{DEFAULT_STAGED_MODEL_NAME}'.", 260 | ) 261 | return parser 262 | 263 | 264 | if __name__ == "__main__": 265 | parser = _setup_parser() 266 | args = parser.parse_args() 267 | main(args) 268 | -------------------------------------------------------------------------------- /deploy/aws_lambda.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Serverless Backend Setup using AWS Lambda\n", 8 | "- Commented cells are meant to be run in the terminal separately" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "## Imports" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "%cd .." 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "import json\n", 34 | "import os\n", 35 | "import requests\n", 36 | "\n", 37 | "from app_gradio import app\n", 38 | "from question_answer.answer import Pipeline" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "## Build container image" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "os.environ[\"LAMBDA_NAME\"] = \"admirer-backend\"" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "!docker build -t $LAMBDA_NAME . --file ./api_serverless/Dockerfile #--no-cache" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "# export LAMBDA_NAME=admirer-backend\n", 73 | "# docker run -p 9000:8080 $LAMBDA_NAME\\:latest" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "!curl -XPOST \\\n", 83 | " \"http://localhost:9000/2015-03-31/functions/function/invocations\" \\\n", 84 | " -d '{\"image_url\": \"./question_answer/tests/support/images/img.jpg\", \"question\": \"What color is my hair\"}'" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "## Upload to the container registry" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "# aws configure" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "aws_account_id, = !aws sts get-caller-identity \\\n", 110 | " --query \"Account\"\n", 111 | "aws_region, = !aws configure get region \n", 112 | "\n", 113 | "os.environ[\"AWS_REGION\"] = aws_region\n", 114 | "os.environ[\"AWS_ACCOUNT_ID\"] = aws_account_id.strip('\"')\n", 115 | "\n", 116 | "!echo $AWS_ACCOUNT_ID\n", 117 | "!echo $AWS_REGION" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "os.environ[\"ECR_URI\"] = \".\".join(\n", 127 | " [os.environ[\"AWS_ACCOUNT_ID\"], \"dkr\", \"ecr\", os.environ[\"AWS_REGION\"], \"amazonaws.com\"])\n", 128 | "\n", 129 | "!echo $ECR_URI" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | 
"outputs": [], 137 | "source": [ 138 | "!aws ecr get-login-password --region $AWS_REGION \\\n", 139 | " | docker login --username AWS --password-stdin $ECR_URI" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "!aws ecr create-repository \\\n", 149 | " --repository-name $LAMBDA_NAME \\\n", 150 | " --image-scanning-configuration scanOnPush=true --image-tag-mutability MUTABLE \\\n", 151 | " | jq -C" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "os.environ[\"IMAGE_URI\"] = \"/\".join([os.environ[\"ECR_URI\"], os.environ[\"LAMBDA_NAME\"]])" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "!docker tag $LAMBDA_NAME\\:latest $IMAGE_URI\\:latest" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "!docker push $IMAGE_URI\\:latest" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "## Create a Lambda function" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "os.environ[\"LAMBDA_ROLE_NAME\"] = \"lambda-role\"" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "!aws iam create-role \\\n", 204 | " --role-name $LAMBDA_ROLE_NAME \\\n", 205 | " --assume-role-policy-document '{\"Version\": \"2012-10-17\", \"Statement\": [{\"Effect\": \"Allow\", \"Principal\": {\"Service\": \"lambda.amazonaws.com\"}, \"Action\": \"sts:AssumeRole\"}]}' \\\n", 206 | " | jq -C" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "lambda_role_arn, = !aws iam get-role --role-name $LAMBDA_ROLE_NAME --output json | jq -r '.Role.Arn'\n", 216 | "lambda_role_arn = lambda_role_arn.strip('\"')\n", 217 | "\n", 218 | "os.environ[\"LAMBDA_ROLE_ARN\"] = lambda_role_arn\n", 219 | "!echo $LAMBDA_ROLE_ARN" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "# allow this IAM role to execute Lambdas\n", 229 | "!aws iam attach-role-policy \\\n", 230 | " --role-name $LAMBDA_ROLE_NAME \\\n", 231 | " --policy-arn arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "# allow this IAM role to write to logs -- required and also important for debugging Lambdas\n", 241 | "!aws iam attach-role-policy \\\n", 242 | " --role-name $LAMBDA_ROLE_NAME \\\n", 243 | " --policy-arn arn:aws:iam::aws:policy/AWSXRayDaemonWriteAccess" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "!aws lambda create-function \\\n", 253 | " --function-name $LAMBDA_NAME \\\n", 254 | " --region $AWS_REGION \\\n", 255 | " --package-type Image \\\n", 256 | " --code ImageUri=$IMAGE_URI:latest \\\n", 257 | " --role $LAMBDA_ROLE_ARN | jq -C" 258 | ] 259 | }, 260 | { 261 | "cell_type": 
"code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "!aws lambda update-function-configuration \\\n", 267 | " --function-name $LAMBDA_NAME \\\n", 268 | " --region $AWS_REGION \\\n", 269 | " --timeout 60 \\\n", 270 | " --memory-size 10240 | jq -C" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "!aws lambda invoke \\\n", 280 | " --function-name $LAMBDA_NAME \\\n", 281 | " --invocation-type RequestResponse \\\n", 282 | " --payload '{\"image_url\": \"./question_answer/tests/support/images/img.jpg\", \"question\": \"What color is my hair\"}' \\\n", 283 | " --cli-binary-format raw-in-base64-out lambda.out | jq -C\n", 284 | "\n", 285 | "!cat lambda.out" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "## Add an HTTP endpoint with a URL" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "!aws lambda create-function-url-config \\\n", 302 | " --function-name $LAMBDA_NAME \\\n", 303 | " --auth-type NONE \\\n", 304 | " --cors '{\"AllowOrigins\": [\"*\"], \"AllowCredentials\": false}' \\\n", 305 | " | jq -C" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": null, 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [ 314 | "# Careful here!!!\n", 315 | "# \"\"\"\n", 316 | "!aws lambda add-permission \\\n", 317 | " --function-name $LAMBDA_NAME \\\n", 318 | " --action lambda:invokeFunctionUrl \\\n", 319 | " --statement-id \"open-access\" \\\n", 320 | " --principal \"*\" \\\n", 321 | " --function-url-auth-type NONE | jq -C\n", 322 | "# \"\"\"" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": null, 328 | "metadata": {}, 329 | "outputs": [], 330 | "source": [ 331 | "lambda_url, = !aws lambda get-function-url-config --function-name $LAMBDA_NAME | jq .FunctionUrl\n", 332 | "lambda_url = lambda_url.strip('\"')\n", 333 | "\n", 334 | "lambda_url" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "metadata": {}, 341 | "outputs": [], 342 | "source": [ 343 | "image_url = \"./question_answer/tests/support/images/img.jpg\"\n", 344 | "question_path = \"./question_answer/tests/support/questions/question.txt\"\n", 345 | "with open(question_path, \"r\") as f: question = f.readline()\n", 346 | "\n", 347 | "headers = {\"Content-type\": \"application/json\"}\n", 348 | "payload = json.dumps({\"image_url\": image_url, \"question\": question})\n", 349 | "\n", 350 | "response = requests.post(\n", 351 | " lambda_url, data=payload, headers=headers)\n", 352 | "pred = response.json()[\"pred\"]\n", 353 | "\n", 354 | "print(pred)" 355 | ] 356 | }, 357 | { 358 | "cell_type": "markdown", 359 | "metadata": {}, 360 | "source": [ 361 | "## Connect AWS with Gradio" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": null, 367 | "metadata": {}, 368 | "outputs": [], 369 | "source": [ 370 | "serverless_backend = app.PredictorBackend(url=lambda_url)" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": null, 376 | "metadata": {}, 377 | "outputs": [], 378 | "source": [ 379 | "frontend_serverless_backend = app.make_frontend(serverless_backend.run, flagging=True)\n", 380 | "frontend_serverless_backend.launch(share=True)" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 
null, 386 | "metadata": {}, 387 | "outputs": [], 388 | "source": [ 389 | "frontend_serverless_backend.close()" 390 | ] 391 | } 392 | ], 393 | "metadata": { 394 | "kernelspec": { 395 | "display_name": "Python 3.7.13 64-bit ('admirer')", 396 | "language": "python", 397 | "name": "python3" 398 | }, 399 | "language_info": { 400 | "codemirror_mode": { 401 | "name": "ipython", 402 | "version": 3 403 | }, 404 | "file_extension": ".py", 405 | "mimetype": "text/x-python", 406 | "name": "python", 407 | "nbconvert_exporter": "python", 408 | "pygments_lexer": "ipython3", 409 | "version": "3.7.13" 410 | }, 411 | "orig_nbformat": 4, 412 | "vscode": { 413 | "interpreter": { 414 | "hash": "4c4de3d17692a4fce36158e1e6b4cc65d2c1c1dbb8a445fcd77e7a07c1299f79" 415 | } 416 | } 417 | }, 418 | "nbformat": 4, 419 | "nbformat_minor": 2 420 | } 421 | -------------------------------------------------------------------------------- /requirements/dev.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with python 3.7 3 | # To update, run: 4 | # 5 | # pip-compile requirements/dev.in 6 | # 7 | absl-py==1.0.0 8 | # via 9 | # ml-metadata 10 | # ml-pipelines-sdk 11 | # tensorboard 12 | aiohttp==3.8.1 13 | # via 14 | # -c requirements/prod.txt 15 | # fsspec 16 | aiosignal==1.2.0 17 | # via 18 | # -c requirements/prod.txt 19 | # aiohttp 20 | altair==4.2.0 21 | # via great-expectations 22 | analytics-python==1.4.0 23 | # via 24 | # -c requirements/prod.txt 25 | # zenml 26 | anyio==3.6.1 27 | # via 28 | # -c requirements/prod.txt 29 | # jupyter-server 30 | argon2-cffi==21.3.0 31 | # via 32 | # jupyter-server 33 | # nbclassic 34 | # notebook 35 | argon2-cffi-bindings==21.2.0 36 | # via argon2-cffi 37 | async-timeout==4.0.2 38 | # via 39 | # -c requirements/prod.txt 40 | # aiohttp 41 | asynctest==0.13.0 42 | # via 43 | # -c requirements/prod.txt 44 | # aiohttp 45 | attrs==20.3.0 46 | # via 47 | # -c requirements/prod.txt 48 | # aiohttp 49 | # jsonschema 50 | # ml-metadata 51 | # pytest 52 | backcall==0.2.0 53 | # via ipython 54 | backoff==1.10.0 55 | # via 56 | # -c requirements/prod.txt 57 | # analytics-python 58 | backports.zoneinfo==0.2.1 59 | # via 60 | # pytz-deprecation-shim 61 | # tzlocal 62 | bcrypt==3.2.2 63 | # via 64 | # -c requirements/prod.txt 65 | # passlib 66 | beautifulsoup4==4.11.1 67 | # via nbconvert 68 | bleach==5.0.0 69 | # via nbconvert 70 | boltons==21.0.0 71 | # via -r requirements/dev.in 72 | cachetools==4.2.4 73 | # via google-auth 74 | certifi==2021.10.8 75 | # via 76 | # -c requirements/prod.txt 77 | # requests 78 | cffi==1.15.0 79 | # via 80 | # -c requirements/prod.txt 81 | # argon2-cffi-bindings 82 | # bcrypt 83 | # cryptography 84 | cfgv==3.3.1 85 | # via pre-commit 86 | charset-normalizer==2.0.12 87 | # via 88 | # -c requirements/prod.txt 89 | # aiohttp 90 | # requests 91 | click==8.1.2 92 | # via 93 | # -c requirements/prod.txt 94 | # click-params 95 | # great-expectations 96 | # nltk 97 | # zenml 98 | click-params==0.3.0 99 | # via zenml 100 | colorama==0.4.6 101 | # via great-expectations 102 | commonmark==0.9.1 103 | # via rich 104 | coverage[toml]==6.4 105 | # via 106 | # -r requirements/dev.in 107 | # pytest-cov 108 | cryptography==37.0.2 109 | # via 110 | # -c requirements/prod.txt 111 | # great-expectations 112 | cycler==0.11.0 113 | # via 114 | # -c requirements/prod.txt 115 | # matplotlib 116 | decorator==5.1.1 117 | # via 118 | # ipython 119 | # validators 120 | defusedxml==0.7.1 121 | # via 122 | # 
-r requirements/dev.in 123 | # nbconvert 124 | distlib==0.3.4 125 | # via virtualenv 126 | distro==1.8.0 127 | # via 128 | # -c requirements/prod.txt 129 | # zenml 130 | docker==4.4.4 131 | # via ml-pipelines-sdk 132 | entrypoints==0.4 133 | # via 134 | # altair 135 | # jupyter-client 136 | # nbconvert 137 | fasteners==0.18 138 | # via google-apitools 139 | fastjsonschema==2.15.3 140 | # via nbformat 141 | filelock==3.8.0 142 | # via 143 | # -c requirements/prod.txt 144 | # virtualenv 145 | fonttools==4.33.3 146 | # via 147 | # -c requirements/prod.txt 148 | # matplotlib 149 | frozenlist==1.3.0 150 | # via 151 | # -c requirements/prod.txt 152 | # aiohttp 153 | # aiosignal 154 | fsspec[http]==2022.5.0 155 | # via 156 | # -c requirements/prod.txt 157 | # pytorch-lightning 158 | gitdb==4.0.9 159 | # via 160 | # -c requirements/prod.txt 161 | # gitpython 162 | gitpython==3.1.29 163 | # via 164 | # -c requirements/prod.txt 165 | # zenml 166 | google-api-core==2.10.2 167 | # via google-api-python-client 168 | google-api-python-client==1.12.11 169 | # via ml-pipelines-sdk 170 | google-apitools==0.5.32 171 | # via ml-pipelines-sdk 172 | google-auth==2.6.5 173 | # via 174 | # google-api-core 175 | # google-api-python-client 176 | # google-auth-httplib2 177 | # google-auth-oauthlib 178 | # tensorboard 179 | google-auth-httplib2==0.1.0 180 | # via google-api-python-client 181 | google-auth-oauthlib==0.4.6 182 | # via tensorboard 183 | googleapis-common-protos==1.56.4 184 | # via google-api-core 185 | great-expectations==0.15.28 186 | # via -r requirements/dev.in 187 | greenlet==1.1.3.post0 188 | # via sqlalchemy 189 | grpcio==1.44.0 190 | # via 191 | # ml-metadata 192 | # tensorboard 193 | httplib2==0.19.1 194 | # via 195 | # google-api-python-client 196 | # google-apitools 197 | # google-auth-httplib2 198 | # oauth2client 199 | # zenml 200 | identify==2.5.1 201 | # via pre-commit 202 | idna==3.3 203 | # via 204 | # -c requirements/prod.txt 205 | # anyio 206 | # requests 207 | # yarl 208 | importlib-metadata==4.11.3 209 | # via 210 | # -c requirements/prod.txt 211 | # click 212 | # great-expectations 213 | # jsonschema 214 | # markdown 215 | # pluggy 216 | # pre-commit 217 | # pytest 218 | # sqlalchemy 219 | # virtualenv 220 | importlib-resources==5.9.0 221 | # via jsonschema 222 | iniconfig==1.1.1 223 | # via pytest 224 | ipykernel==4.10.1 225 | # via 226 | # ipywidgets 227 | # nbclassic 228 | # notebook 229 | ipython==7.32.0 230 | # via 231 | # great-expectations 232 | # ipykernel 233 | # ipywidgets 234 | ipython-genutils==0.2.0 235 | # via 236 | # ipywidgets 237 | # nbclassic 238 | # notebook 239 | ipywidgets==7.7.1 240 | # via 241 | # -r requirements/dev.in 242 | # great-expectations 243 | # rich 244 | itermplot==0.331 245 | # via -r requirements/dev.in 246 | jedi==0.18.1 247 | # via ipython 248 | jinja2==2.11.3 249 | # via 250 | # -c requirements/prod.txt 251 | # altair 252 | # great-expectations 253 | # jupyter-server 254 | # ml-pipelines-sdk 255 | # nbclassic 256 | # nbconvert 257 | # notebook 258 | joblib==1.2.0 259 | # via nltk 260 | jsonpatch==1.32 261 | # via great-expectations 262 | jsonpointer==2.3 263 | # via jsonpatch 264 | jsonschema==4.4.0 265 | # via 266 | # altair 267 | # great-expectations 268 | # nbformat 269 | jupyter-client==7.4.4 270 | # via 271 | # ipykernel 272 | # jupyter-server 273 | # nbclassic 274 | # nbclient 275 | # notebook 276 | jupyter-core==4.9.2 277 | # via 278 | # jupyter-client 279 | # jupyter-server 280 | # nbclassic 281 | # nbconvert 282 | # nbformat 283 | # 
notebook 284 | jupyter-server==1.21.0 285 | # via 286 | # nbclassic 287 | # notebook-shim 288 | jupyterlab-pygments==0.2.2 289 | # via nbconvert 290 | jupyterlab-widgets==1.1.1 291 | # via ipywidgets 292 | kiwisolver==1.4.2 293 | # via 294 | # -c requirements/prod.txt 295 | # matplotlib 296 | language-tool-python==2.7.1 297 | # via -r requirements/dev.in 298 | makefun==1.15.0 299 | # via great-expectations 300 | markdown==3.3.6 301 | # via tensorboard 302 | markupsafe==1.1.1 303 | # via 304 | # -c requirements/prod.txt 305 | # jinja2 306 | # zenml 307 | marshmallow==3.18.0 308 | # via great-expectations 309 | matplotlib==3.5.2 310 | # via 311 | # -c requirements/prod.txt 312 | # -r requirements/dev.in 313 | # itermplot 314 | # seaborn 315 | matplotlib-inline==0.1.3 316 | # via ipython 317 | mistune==0.8.4 318 | # via 319 | # great-expectations 320 | # nbconvert 321 | ml-metadata==1.8.0 322 | # via ml-pipelines-sdk 323 | ml-pipelines-sdk==1.8.0 324 | # via zenml 325 | monotonic==1.6 326 | # via 327 | # -c requirements/prod.txt 328 | # analytics-python 329 | multidict==6.0.2 330 | # via 331 | # -c requirements/prod.txt 332 | # aiohttp 333 | # yarl 334 | nbclassic==0.4.5 335 | # via notebook 336 | nbclient==0.5.13 337 | # via nbconvert 338 | nbconvert==6.4.4 339 | # via 340 | # jupyter-server 341 | # nbclassic 342 | # notebook 343 | # zenml 344 | nbformat==5.3.0 345 | # via 346 | # great-expectations 347 | # jupyter-server 348 | # nbclassic 349 | # nbclient 350 | # nbconvert 351 | # notebook 352 | nest-asyncio==1.5.6 353 | # via 354 | # jupyter-client 355 | # nbclassic 356 | # nbclient 357 | # notebook 358 | nltk==3.7 359 | # via -r requirements/dev.in 360 | nodeenv==1.6.0 361 | # via pre-commit 362 | notebook==6.5.1 363 | # via 364 | # -r requirements/dev.in 365 | # great-expectations 366 | # widgetsnbextension 367 | notebook-shim==0.2.0 368 | # via nbclassic 369 | numpy==1.21.6 370 | # via 371 | # -c requirements/prod.txt 372 | # altair 373 | # great-expectations 374 | # itermplot 375 | # matplotlib 376 | # pandas 377 | # pyarrow 378 | # pytorch-lightning 379 | # scipy 380 | # seaborn 381 | # tensorboard 382 | # torchmetrics 383 | oauth2client==4.1.3 384 | # via google-apitools 385 | oauthlib==3.2.0 386 | # via requests-oauthlib 387 | packaging==20.9 388 | # via 389 | # -c requirements/prod.txt 390 | # great-expectations 391 | # jupyter-server 392 | # marshmallow 393 | # matplotlib 394 | # ml-pipelines-sdk 395 | # pytest 396 | # pytorch-lightning 397 | # torchmetrics 398 | pandas==1.3.5 399 | # via 400 | # -c requirements/prod.txt 401 | # altair 402 | # great-expectations 403 | # seaborn 404 | # zenml 405 | pandocfilters==1.5.0 406 | # via nbconvert 407 | parso==0.8.3 408 | # via jedi 409 | passlib[bcrypt]==1.7.4 410 | # via zenml 411 | pexpect==4.8.0 412 | # via ipython 413 | pickleshare==0.7.5 414 | # via ipython 415 | pillow==7.1.2 416 | # via 417 | # -c requirements/prod.txt 418 | # matplotlib 419 | platformdirs==2.5.2 420 | # via virtualenv 421 | pluggy==1.0.0 422 | # via pytest 423 | portpicker==1.5.2 424 | # via ml-pipelines-sdk 425 | pre-commit==2.19.0 426 | # via -r requirements/dev.in 427 | prometheus-client==0.15.0 428 | # via 429 | # jupyter-server 430 | # nbclassic 431 | # notebook 432 | prompt-toolkit==3.0.29 433 | # via ipython 434 | protobuf==3.20.3 435 | # via 436 | # -c requirements/prod.txt 437 | # google-api-core 438 | # googleapis-common-protos 439 | # ml-metadata 440 | # ml-pipelines-sdk 441 | # tensorboard 442 | psutil==5.9.3 443 | # via 444 | # -c 
requirements/prod.txt 445 | # portpicker 446 | ptyprocess==0.7.0 447 | # via 448 | # pexpect 449 | # terminado 450 | py==1.11.0 451 | # via pytest 452 | pyarrow==7.0.0 453 | # via zenml 454 | pyasn1==0.4.8 455 | # via 456 | # oauth2client 457 | # pyasn1-modules 458 | # rsa 459 | pyasn1-modules==0.2.8 460 | # via 461 | # google-auth 462 | # oauth2client 463 | pycparser==2.21 464 | # via 465 | # -c requirements/prod.txt 466 | # cffi 467 | pydantic==1.9.1 468 | # via 469 | # -c requirements/prod.txt 470 | # sqlmodel 471 | # zenml 472 | pydeprecate==0.3.2 473 | # via 474 | # pytorch-lightning 475 | # torchmetrics 476 | pygments==2.11.2 477 | # via 478 | # ipython 479 | # nbconvert 480 | # rich 481 | pyparsing==2.4.2 482 | # via 483 | # -c requirements/prod.txt 484 | # great-expectations 485 | # httplib2 486 | # matplotlib 487 | # packaging 488 | # zenml 489 | pyrsistent==0.18.1 490 | # via jsonschema 491 | pytest==7.1.1 492 | # via 493 | # -r requirements/dev.in 494 | # pytest-cov 495 | pytest-cov==3.0.0 496 | # via -r requirements/dev.in 497 | python-dateutil==2.8.2 498 | # via 499 | # -c requirements/prod.txt 500 | # analytics-python 501 | # great-expectations 502 | # jupyter-client 503 | # matplotlib 504 | # pandas 505 | # zenml 506 | python-terraform==0.10.1 507 | # via zenml 508 | pytorch-lightning==1.6.3 509 | # via -r requirements/dev.in 510 | pytz==2022.1 511 | # via 512 | # -c requirements/prod.txt 513 | # great-expectations 514 | # pandas 515 | pytz-deprecation-shim==0.1.0.post0 516 | # via tzlocal 517 | pyyaml==5.4.1 518 | # via 519 | # -c requirements/prod.txt 520 | # pre-commit 521 | # pytorch-lightning 522 | # zenml 523 | pyzmq==24.0.1 524 | # via 525 | # jupyter-client 526 | # jupyter-server 527 | # nbclassic 528 | # notebook 529 | regex==2022.9.13 530 | # via 531 | # -c requirements/prod.txt 532 | # nltk 533 | requests==2.27.1 534 | # via 535 | # -c requirements/prod.txt 536 | # analytics-python 537 | # docker 538 | # fsspec 539 | # google-api-core 540 | # great-expectations 541 | # language-tool-python 542 | # requests-oauthlib 543 | # tensorboard 544 | requests-oauthlib==1.3.1 545 | # via google-auth-oauthlib 546 | rich[jupyter]==12.6.0 547 | # via zenml 548 | rsa==4.8 549 | # via 550 | # google-auth 551 | # oauth2client 552 | ruamel.yaml==0.17.17 553 | # via great-expectations 554 | ruamel.yaml.clib==0.2.7 555 | # via ruamel.yaml 556 | scipy==1.7.3 557 | # via 558 | # -r requirements/dev.in 559 | # great-expectations 560 | # seaborn 561 | seaborn==0.11.2 562 | # via -r requirements/dev.in 563 | send2trash==1.8.0 564 | # via 565 | # jupyter-server 566 | # nbclassic 567 | # notebook 568 | six==1.16.0 569 | # via 570 | # -c requirements/prod.txt 571 | # absl-py 572 | # analytics-python 573 | # bleach 574 | # docker 575 | # google-api-python-client 576 | # google-apitools 577 | # google-auth 578 | # google-auth-httplib2 579 | # grpcio 580 | # itermplot 581 | # ml-metadata 582 | # oauth2client 583 | # python-dateutil 584 | # validators 585 | # virtualenv 586 | smmap==5.0.0 587 | # via 588 | # -c requirements/prod.txt 589 | # gitdb 590 | sniffio==1.2.0 591 | # via 592 | # -c requirements/prod.txt 593 | # anyio 594 | soupsieve==2.3.2.post1 595 | # via beautifulsoup4 596 | sqlalchemy==1.4.41 597 | # via sqlmodel 598 | sqlalchemy2-stubs==0.0.2a29 599 | # via sqlmodel 600 | sqlmodel==0.0.8 601 | # via zenml 602 | tensorboard==2.8.0 603 | # via pytorch-lightning 604 | tensorboard-data-server==0.6.1 605 | # via tensorboard 606 | tensorboard-plugin-wit==1.8.1 607 | # via tensorboard 608 
| termcolor==2.0.1 609 | # via great-expectations 610 | terminado==0.13.3 611 | # via 612 | # jupyter-server 613 | # nbclassic 614 | # notebook 615 | testpath==0.6.0 616 | # via nbconvert 617 | toml==0.10.2 618 | # via 619 | # -r requirements/dev.in 620 | # pre-commit 621 | tomli==2.0.1 622 | # via 623 | # coverage 624 | # pytest 625 | toolz==0.12.0 626 | # via altair 627 | torch==1.12.0 628 | # via 629 | # -c requirements/prod.txt 630 | # pytorch-lightning 631 | # torchmetrics 632 | torchmetrics==0.7.3 633 | # via 634 | # -r requirements/dev.in 635 | # pytorch-lightning 636 | tornado==6.2 637 | # via 638 | # ipykernel 639 | # jupyter-client 640 | # jupyter-server 641 | # nbclassic 642 | # notebook 643 | # terminado 644 | tqdm==4.64.0 645 | # via 646 | # -c requirements/prod.txt 647 | # great-expectations 648 | # language-tool-python 649 | # nltk 650 | # pytorch-lightning 651 | traitlets==5.1.1 652 | # via 653 | # ipykernel 654 | # ipython 655 | # ipywidgets 656 | # jupyter-client 657 | # jupyter-core 658 | # jupyter-server 659 | # matplotlib-inline 660 | # nbclassic 661 | # nbclient 662 | # nbconvert 663 | # nbformat 664 | # notebook 665 | typing-extensions==4.7.1 666 | # via 667 | # -c requirements/prod.txt 668 | # aiohttp 669 | # anyio 670 | # argon2-cffi 671 | # async-timeout 672 | # gitpython 673 | # great-expectations 674 | # importlib-metadata 675 | # jsonschema 676 | # kiwisolver 677 | # pydantic 678 | # pytorch-lightning 679 | # rich 680 | # sqlalchemy2-stubs 681 | # torch 682 | # yarl 683 | tzdata==2022.5 684 | # via pytz-deprecation-shim 685 | tzlocal==4.2 686 | # via great-expectations 687 | uritemplate==3.0.1 688 | # via google-api-python-client 689 | urllib3==1.26.12 690 | # via 691 | # -c requirements/prod.txt 692 | # great-expectations 693 | # requests 694 | validators==0.18.2 695 | # via click-params 696 | virtualenv==20.14.1 697 | # via pre-commit 698 | wcwidth==0.2.5 699 | # via prompt-toolkit 700 | webencodings==0.5.1 701 | # via bleach 702 | websocket-client==1.4.1 703 | # via 704 | # docker 705 | # jupyter-server 706 | werkzeug==2.1.2 707 | # via tensorboard 708 | wheel==0.37.1 709 | # via tensorboard 710 | widgetsnbextension==3.6.1 711 | # via ipywidgets 712 | yarl==1.7.2 713 | # via 714 | # -c requirements/prod.txt 715 | # aiohttp 716 | zenml==0.20.5 717 | # via -r requirements/dev.in 718 | zipp==3.8.0 719 | # via 720 | # -c requirements/prod.txt 721 | # importlib-metadata 722 | # importlib-resources 723 | 724 | # The following packages are considered to be unsafe in a requirements file: 725 | # setuptools 726 | -------------------------------------------------------------------------------- /question_answer/answer.py: -------------------------------------------------------------------------------- 1 | # Imports 2 | import argparse 3 | from collections import defaultdict 4 | import json 5 | import os 6 | from pathlib import Path 7 | import random 8 | from typing import Any, Dict, List, Tuple, Union 9 | 10 | from dotenv import load_dotenv 11 | import numpy as np 12 | from onnxruntime import InferenceSession 13 | from openai import OpenAI 14 | from PIL import Image 15 | import torch 16 | from transformers import ( 17 | AutoTokenizer, 18 | CLIPProcessor, 19 | DetrFeatureExtractor, 20 | DetrForSegmentation, 21 | pipeline, 22 | VisionEncoderDecoderModel, 23 | ViTFeatureExtractor, 24 | ) 25 | 26 | import question_answer.metadata.pica as metadata 27 | 28 | # Loading env variables 29 | load_dotenv() 30 | 31 | # Variables 32 | # OpenAI params 33 | CLIENT = 
OpenAI(api_key=os.getenv("OPENAI_API_KEY")) 34 | MODEL = "gpt-3.5-turbo-1106" 35 | 36 | # Artifact path 37 | artifact_path = Path(__file__).resolve().parent / "artifacts" / "answer" 38 | 39 | # PICa formatting/config 40 | img_id = 100 # Random idx for inference 41 | question_id = 1005 # Random idx for inference 42 | # Significant driver of performance with little extra cost 43 | # PICa paper's max = 16, but can set higher if model's speed + context size can handle it 44 | n_shot = 16 45 | coco_path = artifact_path / "coco_annotations" 46 | similarity_path = artifact_path / "coco_clip_new" 47 | 48 | # Model setup 49 | transformers_path = artifact_path / "transformers" 50 | onnx_path = artifact_path / "onnx" 51 | 52 | # Segmentation model config 53 | tag_model = transformers_path / "facebook" / "detr-resnet-50-panoptic" 54 | max_length = 16 55 | num_beams = 4 56 | 57 | # Caption model config 58 | caption_model = transformers_path / "nlpconnect" / "vit-gpt2-image-captioning" 59 | 60 | # CLIP Encoders config 61 | clip_processor = transformers_path / "openai" / "clip-vit-base-patch16" 62 | clip_onnx = onnx_path / "clip.onnx" 63 | 64 | # Dataset variables 65 | NUM_ORIGINAL_EXAMPLES = metadata.NUM_ORIGINAL_EXAMPLES 66 | NUM_ADDED_EXAMPLES = metadata.NUM_ADDED_EXAMPLES 67 | NUM_TEST_EXAMPLES = metadata.NUM_TEST_EXAMPLES 68 | 69 | 70 | # Helper/main classes 71 | class PICa_OKVQA: 72 | """ 73 | Question Answering Class 74 | """ 75 | 76 | def __init__( 77 | self, 78 | caption_info: Dict[Any, Any] = None, 79 | tag_info: Dict[Any, Any] = None, 80 | questions: Dict[str, List[Dict[str, str]]] = None, 81 | context_idxs: Dict[str, str] = None, 82 | question_features: np.ndarray = None, 83 | image_features: np.ndarray = None, 84 | evaluate: bool = False, 85 | ): 86 | self.evaluate = evaluate 87 | ( 88 | self.traincontext_caption_dict, 89 | self.traincontext_answer_dict, 90 | self.traincontext_question_dict, 91 | ) = self.load_anno( 92 | "%s/captions_train2014.json" % coco_path, 93 | "%s/mscoco_train2014_annotations.json" % coco_path, 94 | "%s/OpenEnded_mscoco_train2014_questions.json" % coco_path, 95 | ) 96 | ( 97 | self.traincontext_caption_dict, 98 | _, 99 | self.traincontext_answer_dict, 100 | self.traincontext_question_dict, 101 | ) = self.add_anno( 102 | "%s/admirer-pica.json" % coco_path, 103 | self.traincontext_caption_dict, 104 | self.traincontext_answer_dict, 105 | self.traincontext_question_dict, 106 | ) 107 | if evaluate: 108 | ( 109 | self.testcontext_caption_dict, 110 | self.testcontext_tags_dict, 111 | self.testcontext_answer_dict, 112 | self.testcontext_question_dict, 113 | ) = self.add_anno( 114 | "%s/admirer-pica.json" % coco_path, 115 | evaluate=evaluate, 116 | ) 117 | # load cached image representation (Coco caption & Tags) 118 | self.inputtext_dict = self.load_cachetext(self.testcontext_caption_dict, self.testcontext_tags_dict) 119 | self.load_similarity(evaluate=evaluate) 120 | question_dict_keys = list(self.testcontext_question_dict.keys()) 121 | image_ids, question_ids = [key.split("<->")[0] for key in question_dict_keys], [ 122 | key.split("<->")[1] for key in question_dict_keys 123 | ] 124 | list_questions = list(self.testcontext_question_dict.values()) 125 | self.questions = { 126 | "questions": [ 127 | {"image_id": image_id, "question": question_str, "question_id": quest_id} 128 | for image_id, question_str, quest_id in zip(image_ids, list_questions, question_ids) 129 | ] 130 | } 131 | else: 132 | # load cached image representation (Coco caption & Tags) 133 | self.inputtext_dict = 
self.load_cachetext(caption_info, tag_info) 134 | _ = self.load_similarity(context_idxs, question_features, image_features) 135 | self.questions = questions 136 | 137 | self.train_keys = list(self.traincontext_answer_dict.keys()) 138 | 139 | def answer_gen(self): 140 | _, _, question_dict = self.load_anno(questions=self.questions) 141 | 142 | if self.evaluate: 143 | pred_answers = [] 144 | gt_answers = [] 145 | 146 | keys = list(question_dict.keys()) 147 | for key in keys: 148 | img_key = int(key.split("<->")[0]) 149 | question, caption = ( 150 | question_dict[key], 151 | self.inputtext_dict[img_key], 152 | ) 153 | 154 | context_key_list = self.get_context_keys( 155 | key, 156 | n_shot, 157 | ) 158 | 159 | # prompt format following OpenAI QA API 160 | messages = [] 161 | system_message = { 162 | "role": "system", 163 | "content": str( 164 | f"You are given {n_shot} examples of image content, a question about the image, and an answer. " 165 | + "Given a new set of content and question, " 166 | + "you are tasked with coming up with an answer in a similar way to the examples. " 167 | + "If the content is not enough to answer the question, " 168 | + "make up an answer structured as:" 169 | + "\n" 170 | + "1) an acknowledgment of not knowing the correct answer to the question," 171 | + "\n" 172 | + "2) a comedic reply using what you can from the content." 173 | + "\n" 174 | + "For example, if the question is 'What is the color of the user's shirt?', " 175 | + "and the context is 'The user is wearing a shirt with a picture of a cat on it', " 176 | + "a good answer could be 'I don't know, but I think the cat is cute!'" 177 | ), 178 | } 179 | messages.append(system_message) 180 | for ni in range(n_shot): 181 | if context_key_list is None: 182 | context_key = self.train_keys[random.randint(0, len(self.train_keys) - 1)] 183 | else: 184 | context_key = context_key_list[ni] 185 | img_context_key = int(context_key.split("<->")[0]) 186 | while True: # make sure we get a context example with a valid question and answer 187 | if ( 188 | len(self.traincontext_question_dict[context_key]) != 0 189 | and len(self.traincontext_answer_dict[context_key][0]) != 0 190 | ): 191 | break 192 | context_key = self.train_keys[random.randint(0, len(self.train_keys) - 1)] 193 | context_caption = self.traincontext_caption_dict[img_context_key] # separate names so the query's caption/question are not overwritten 194 | context_question = self.traincontext_question_dict[context_key] 195 | context_answer = self.traincontext_answer_dict[context_key] 196 | if isinstance(context_caption, list): 197 | context_caption = context_caption[0] # sometimes annotators messed up 198 | if isinstance(context_question, list): 199 | context_question = context_question[0] 200 | if isinstance(context_answer, list): 201 | context_answer = context_answer[0] 202 | user_message = { 203 | "role": "user", 204 | "content": str( 205 | "Image content: " + context_caption + "\n" + "Question: " + context_question + "\n" + "Answer: " + context_answer 206 | ), 207 | } 208 | messages.append(user_message) 209 | current_user_message = { # the actual query, formatted like the examples with the answer left blank 210 | "role": "user", 211 | "content": str("Image content: " + caption + "\n" + "Question: " + question + "\n" + "Answer: "), 212 | } 213 | messages.append(current_user_message) 214 | try: 215 | response = CLIENT.chat.completions.create( 216 | model=MODEL, 217 | messages=messages, 218 | ) 219 | except Exception as e: 220 | print(e) 221 | exit(0) 222 | 223 | pred_answer = response.choices[0].message.content 224 | 225 | if self.evaluate: 226 | answer = self.testcontext_answer_dict[key] 227 | pred_answers.append(pred_answer) 228 | gt_answers.append(answer) 229 | else: 230 | return pred_answer 231 | 232 | from question_answer.lit_models.metrics 
import BertF1Score 233 | 234 | return BertF1Score()(pred_answers, gt_answers) 235 | 236 | def get_context_keys(self, key: str, n: int) -> List[str]: 237 | """Get context keys based on similarity scores""" 238 | # combined with Q-similarity (image+question) 239 | lineid = self.valkey2idx[key] 240 | 241 | # Removing validation key from train similarity arrays if needed 242 | temp_train_feature = None 243 | temp_image_train_feature = None 244 | temp_train_idx = None 245 | 246 | for idx in range(NUM_ORIGINAL_EXAMPLES, NUM_ORIGINAL_EXAMPLES + NUM_ADDED_EXAMPLES): 247 | question_feature_equal = np.array_equal(self.val_feature[lineid], self.train_feature[idx]) 248 | image_feature_equal = np.array_equal(self.val_feature[lineid], self.image_train_feature[idx]) 249 | if question_feature_equal and image_feature_equal: 250 | mask = np.ones(len(self.train_feature), dtype=bool) 251 | mask[[idx]] = False 252 | temp_train_feature = self.train_feature[mask] 253 | temp_image_train_feature = self.image_train_feature[mask] 254 | temp_train_idx = self.train_idx.pop(str(idx)) 255 | break 256 | 257 | removed = temp_train_feature is not None and temp_image_train_feature is not None and temp_train_idx is not None 258 | if removed: 259 | question_similarity: np.ndarray = np.matmul(temp_train_feature, self.val_feature[lineid, :]) 260 | # end of Q-similarity 261 | similarity: np.ndarray = question_similarity + np.matmul( 262 | temp_image_train_feature, self.image_val_feature[lineid, :] 263 | ) 264 | else: 265 | question_similarity: np.ndarray = np.matmul(self.train_feature, self.val_feature[lineid, :]) 266 | # end of Q-similarity 267 | similarity: np.ndarray = question_similarity + np.matmul( 268 | self.image_train_feature, self.image_val_feature[lineid, :] 269 | ) 270 | 271 | index: np.ndarray = similarity.argsort()[-n:][::-1] 272 | return [self.train_idx[str(x)] for x in index] 273 | 274 | def load_similarity( 275 | self, 276 | context_idxs: Dict[str, str] = None, 277 | question_features: np.ndarray = None, 278 | image_features: np.ndarray = None, 279 | evaluate=False, 280 | ): 281 | # Add question train feature, image train feature, and train idx 282 | self.train_feature = np.load("%s/coco_clip_vitb16_train2014_okvqa_question.npy" % similarity_path) 283 | self.train_idx: Dict[str, str] = json.load( 284 | open( 285 | "%s/okvqa_qa_line2sample_idx_train2014.json" % similarity_path, 286 | "r", 287 | ) 288 | ) 289 | self.image_train_feature = np.load( 290 | "%s/coco_clip_vitb16_train2014_okvqa_convertedidx_image.npy" % similarity_path 291 | ) 292 | 293 | if evaluate: 294 | context_idxs = dict(list(self.train_idx.items())[NUM_ORIGINAL_EXAMPLES:]) 295 | new_keys = [str(idx) for idx in range(len(context_idxs))] 296 | context_idxs = dict(zip(new_keys, list(context_idxs.values()))) 297 | self.val_feature = self.train_feature[-NUM_ADDED_EXAMPLES:, :] 298 | self.image_val_feature = self.image_train_feature[-NUM_ADDED_EXAMPLES:, :] 299 | else: 300 | self.val_feature = question_features 301 | self.image_val_feature = image_features 302 | 303 | val_idx = context_idxs 304 | self.valkey2idx: Dict[str, int] = {} 305 | for ii in val_idx: 306 | self.valkey2idx[val_idx[ii]] = int(ii) 307 | 308 | def load_tags( 309 | self, 310 | tag_info: Dict[Any, List[str]], 311 | ) -> Dict[int, str]: 312 | """Loads tags for an image""" 313 | tags_dict = {} 314 | image_ids, list_tags = list(tag_info.keys()), list(tag_info.values()) 315 | # Tags arrive already joined into a single string per image 316 | list_str_tags = [tags for tags in list_tags] 317 | for id in 
range(len(image_ids)): 318 | tags_dict[image_ids[id]] = list_str_tags[id] 319 | return tags_dict 320 | 321 | def load_cachetext( 322 | self, 323 | caption_info: Dict[Any, List[str]], 324 | tag_info: Dict[Any, List[str]], 325 | ): 326 | """Loads captions and appends the cached tag text to each caption""" 327 | tags_dict = self.load_tags(tag_info) 328 | caption_dict = {} 329 | image_ids, captions = list(caption_info.keys()), list(caption_info.values()) 330 | for id in range(len(image_ids)): 331 | caption_dict[image_ids[id]] = captions[id] + ". " + list(tags_dict.values())[id] 332 | return caption_dict 333 | 334 | def load_anno( 335 | self, 336 | coco_caption_file: Path = None, 337 | answer_anno_file: Path = None, 338 | question_anno_file: Path = None, 339 | questions: Dict[str, List[Dict[str, str]]] = None, 340 | ) -> Tuple[Dict[int, List[str]], Dict[str, List[str]], Dict[str, str]]: 341 | """Loads caption, answer, and question annotations""" 342 | # Define default dictionaries 343 | caption_dict: defaultdict[int, List[str]] = defaultdict(list) 344 | answer_dict: defaultdict[str, List[str]] = defaultdict(list) 345 | question_dict: defaultdict[str, str] = defaultdict(list) 346 | 347 | # Create caption dictionary 348 | if coco_caption_file is not None: 349 | coco_caption = json.load(open(coco_caption_file, "r")) 350 | if isinstance(coco_caption, dict): 351 | coco_caption: List[Dict[str, Union[str, int]]] = coco_caption["annotations"] 352 | for sample in coco_caption: 353 | caption_dict[sample["image_id"]].append(sample["caption"]) # int -> sample[image_id] 354 | 355 | # Create answer dictionary 356 | if answer_anno_file is not None: 357 | answer_data = json.load(open(answer_anno_file, "r")) 358 | answer_annotations: List[Dict[str, Any]] = answer_data["annotations"] 359 | for sample in answer_annotations: 360 | id = str(sample["image_id"]) + "<->" + str(sample["question_id"]) 361 | if id not in answer_dict: 362 | answer_dict[id] = [x["answer"] for x in sample["answers"]] 363 | 364 | # Create question dictionary 365 | if question_anno_file is not None: 366 | question_data = json.load(open(question_anno_file, "r")) 367 | else: 368 | question_data = questions 369 | 370 | question_annotations: List[Dict[str, Union[str, int]]] = question_data["questions"] 371 | for sample in question_annotations: 372 | id = str(sample["image_id"]) + "<->" + str(sample["question_id"]) 373 | if id not in question_dict: 374 | question_dict[id] = sample["question"] 375 | 376 | return dict(caption_dict), dict(answer_dict), dict(question_dict) 377 | 378 | def add_anno( 379 | self, 380 | add: Path, 381 | context_caption_dict: Dict[int, List[str]] = None, 382 | context_answer_dict: Dict[str, List[str]] = None, 383 | context_question_dict: Dict[str, str] = None, 384 | evaluate=False, 385 | ): 386 | """Load/add extra annotations to the annotations dictionaries""" 387 | add_dict = json.load(open(add, "r")) 388 | 389 | context_tag_dict = {} 390 | 391 | caption_add = dict(zip(list(add_dict["image_id"].values()), list(add_dict["caption"].values()))) 392 | tags_add = dict(zip(list(add_dict["image_id"].values()), list(add_dict["tags"].values()))) 393 | combine_ids = [ 394 | str(image_id) + "<->" + str(question_id) 395 | for image_id, question_id in zip( 396 | list(add_dict["image_id"].values()), list(add_dict["question_id"].values()) 397 | ) 398 | ] 399 | answer_add = dict(zip(combine_ids, list(add_dict["answer"].values()))) 400 | question_add = dict(zip(combine_ids, list(add_dict["question"].values()))) 401 | 402 | if evaluate: 403 | context_caption_dict = {} 404 
| context_answer_dict = {} 405 | context_question_dict = {} 406 | context_caption_dict.update(caption_add) 407 | context_tag_dict.update(tags_add) 408 | context_answer_dict.update(answer_add) 409 | context_question_dict.update(question_add) 410 | 411 | if evaluate: 412 | context_caption_dict = dict(list(context_caption_dict.items())[-NUM_TEST_EXAMPLES:]) 413 | context_tag_dict = dict(list(context_tag_dict.items())[-NUM_TEST_EXAMPLES:]) 414 | context_answer_dict = dict(list(context_answer_dict.items())[-NUM_TEST_EXAMPLES:]) 415 | context_question_dict = dict(list(context_question_dict.items())[-NUM_TEST_EXAMPLES:]) 416 | 417 | return context_caption_dict, context_tag_dict, context_answer_dict, context_question_dict 418 | 419 | 420 | class Pipeline: 421 | """ 422 | Main inference class 423 | """ 424 | 425 | def __init__(self): 426 | # Tagging model setup 427 | segment_model = DetrForSegmentation.from_pretrained(tag_model, use_pretrained_backbone=False) 428 | self.segment = pipeline( 429 | "image-segmentation", model=segment_model, feature_extractor=DetrFeatureExtractor.from_pretrained(tag_model) 430 | ) 431 | self.tags = [] 432 | 433 | # Caption model setup 434 | self.caption_model = VisionEncoderDecoderModel.from_pretrained(caption_model) 435 | self.caption_feature_extractor = ViTFeatureExtractor.from_pretrained(caption_model) 436 | self.caption_tokenizer = AutoTokenizer.from_pretrained(caption_model) 437 | self.device = torch.device("cpu") # torch.device("cuda" if torch.cuda.is_available() else "cpu") 438 | 439 | # CLIP Setup 440 | self.clip_session = InferenceSession(str(clip_onnx)) 441 | self.clip_processor = CLIPProcessor.from_pretrained(clip_processor) 442 | 443 | def predict_caption(self, image): 444 | pixel_values = self.caption_feature_extractor(images=[image], return_tensors="pt").pixel_values 445 | pixel_values = pixel_values.to(self.device) 446 | 447 | gen_kwargs = {"max_length": max_length, "num_beams": num_beams} 448 | output_ids = self.caption_model.generate(pixel_values, **gen_kwargs) 449 | 450 | preds = self.caption_tokenizer.batch_decode(output_ids, skip_special_tokens=True) 451 | preds = [pred.strip() for pred in preds] 452 | return preds[0] 453 | 454 | def predict(self, image: Union[str, Path, Image.Image], question: Union[str, Path]) -> str: 455 | if not isinstance(image, Image.Image): 456 | image_pil = Image.open(image) 457 | if image_pil.mode != "RGB": 458 | image_pil = image_pil.convert(mode="RGB") 459 | else: 460 | image_pil = image 461 | if isinstance(question, Path) or os.path.exists(question): 462 | with open(question, "r") as f: 463 | question_str = f.readline() 464 | else: 465 | question_str = question 466 | 467 | # Generating image tag(s) 468 | # (rebuild the tag list on every call so repeated predictions don't accumulate tags) 469 | self.tags = [dic["label"] for dic in self.segment(image_pil)] 470 | if not self.tags: 471 | self.tags.append("") 472 | tag_info: Dict[int, str] = {img_id: ", ".join(self.tags)} 473 | 474 | # Generating image caption 475 | caption = self.predict_caption(image_pil) 476 | if not caption: 477 | caption = "" 478 | caption_info: Dict[int, str] = {img_id: caption} 479 | 480 | # Generating image/question features 481 | inputs = self.clip_processor(text=[question_str], images=image_pil, return_tensors="np", padding=True) 482 | # for i in session.get_outputs(): print(i.name) 483 | outputs = self.clip_session.run( 484 | output_names=["logits_per_image", "logits_per_text", "text_embeds", "image_embeds"], input_feed=dict(inputs) 485 | ) 486 | 487 | # Generating context idxs 488 | context_idxs: Dict[str, str] = {"0": 
str(img_id) + "<->" + str(question_id)} 489 | 490 | # Answering question 491 | questions = {"questions": [{"image_id": img_id, "question": question_str, "question_id": question_id}]} 492 | okvqa = PICa_OKVQA( 493 | caption_info, tag_info, questions, context_idxs, outputs[2], outputs[3] 494 | ) # Have to initialize here because necessary objects need to be generated 495 | answer = okvqa.answer_gen() 496 | # rationale = okvqa.rationale(answer) 497 | 498 | return answer # + " because " + rationale 499 | 500 | def evaluate(self): 501 | okvqa = PICa_OKVQA( 502 | evaluate=True, 503 | ) 504 | acc = okvqa.answer_gen() 505 | print(acc) 506 | return acc 507 | 508 | 509 | # Running model 510 | def main(): 511 | parser = argparse.ArgumentParser() 512 | 513 | # Inputs 514 | parser.add_argument("--image", type=str, required=True) 515 | parser.add_argument("--question", type=str, required=True) 516 | 517 | args = parser.parse_args() 518 | 519 | # Answering question 520 | pipeline = Pipeline() 521 | pred_str = pipeline.predict(args.image, args.question) 522 | 523 | print(pred_str) 524 | 525 | 526 | if __name__ == "__main__": 527 | main() 528 | --------------------------------------------------------------------------------
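For orientation, a minimal usage sketch of the inference entrypoint above (not a file in the repository). It assumes the artifact files expected under question_answer/artifacts/answer (COCO annotations, cached CLIP features, transformer weights, and clip.onnx) are present locally, and that OPENAI_API_KEY is set as in .env.template; the image and question files referenced are the test fixtures under question_answer/tests/support.

# Hypothetical driver script -- a sketch, not part of the repo.
from question_answer.answer import Pipeline

pipeline = Pipeline()  # loads the DETR tagger, ViT-GPT2 captioner, and ONNX CLIP encoders from question_answer/artifacts/answer
answer = pipeline.predict(
    "question_answer/tests/support/images/img.jpg",  # image path or a PIL.Image
    "question_answer/tests/support/questions/question.txt",  # question as a .txt path or a plain string
)
print(answer)

The same call is exposed on the command line by answer.py's main(), e.g.: python question_answer/answer.py --image question_answer/tests/support/images/img.jpg --question "How many beds are in the room?"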