├── .gitignore ├── README.md ├── api-demo ├── .dockerignore ├── .env ├── Dockerfile ├── app.py ├── assets │ ├── Cinnamon_Bootcamp_1.png │ ├── Cinnamon_Bootcamp_2.png │ ├── cat.jpg │ ├── cat_liquid.jpg │ ├── cat_liquid2.jpg │ ├── cat_liquid3.jpg │ ├── football-match.jpg │ └── lion.jpg ├── docker-compose.yaml ├── fastapi_backend.py ├── gradio_frontend.py ├── requirements.txt ├── static │ └── styles.css └── templates │ └── index.html ├── image-to-image-search ├── demo.ipynb └── requirements.txt ├── leaked_container ├── .env ├── Dockerfile ├── requirements.txt └── src │ └── index.py ├── rag-foundation ├── .gitignore ├── .pre-commit-config.yaml ├── README.md ├── data │ ├── dense.csv │ ├── llama2.pdf │ ├── sparse.csv │ └── sparse_metadata.json ├── evaluate.py ├── qasper-test-v0.3.json ├── requirements.txt ├── sample_predictions.jsonl ├── scripts │ ├── __init__.py │ └── main.py ├── setup.cfg └── vector_store │ ├── __init__.py │ ├── base.py │ ├── node.py │ ├── semantic_vector_store.py │ └── sparse_vector_store.py └── streamlit_demo ├── .gitignore ├── README.md ├── assets └── screenshot_app.png ├── constants.py ├── launch.py ├── lessons ├── __init__.py ├── cache_flow.py ├── execution_flow.py ├── layout.py └── m10.jpg ├── requirements.txt └── shared ├── __init__.py ├── crud ├── __init__.py └── feedbacks.py ├── models ├── __init__.py ├── engine.py └── models.py ├── models_ai ├── __init__.py ├── base.py └── yolov8.py ├── schemas.py ├── utils ├── __init__.py ├── files.py ├── log.py └── pages.py └── views ├── __init__.py ├── app ├── __init__.py └── view.py └── canvas ├── __init__.py ├── canvas.py ├── frontend ├── .env ├── .gitignore ├── .prettierrc ├── package-lock.json ├── package.json ├── public │ └── index.html ├── src │ ├── DrawableCanvas.tsx │ ├── DrawableCanvasState.tsx │ ├── components │ │ ├── CanvasToolbar.module.css │ │ ├── CanvasToolbar.tsx │ │ └── UpdateStreamlit.tsx │ ├── img │ │ ├── bin.png │ │ ├── download.png │ │ └── undo.png │ ├── index.css │ ├── index.tsx │ └── react-app-env.d.ts └── tsconfig.json └── processor.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .venv 124 | env/ 125 | venv/ 126 | ENV/ 127 | env.bak/ 128 | venv.bak/ 129 | 130 | # Spyder project settings 131 | .spyderproject 132 | .spyproject 133 | 134 | # Rope project settings 135 | .ropeproject 136 | 137 | # mkdocs documentation 138 | /site 139 | 140 | # mypy 141 | .mypy_cache/ 142 | .dmypy.json 143 | dmypy.json 144 | 145 | # Pyre type checker 146 | .pyre/ 147 | 148 | # pytype static type analyzer 149 | .pytype/ 150 | 151 | # Cython debug symbols 152 | cython_debug/ 153 | 154 | # PyCharm 155 | .idea/ 156 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cinnamon AI Bootcamp 2024 2 | Lecture materials for Cinnamon AI Bootcamp. 3 | 4 | For Cinnamon AI Bootcamp 2024 use only, DO NOT DISTRIBUTE. 
5 | -------------------------------------------------------------------------------- /api-demo/.dockerignore: -------------------------------------------------------------------------------- 1 | 2 | # Git 3 | .git 4 | .gitignore 5 | 6 | # CI 7 | .codeclimate.yml 8 | .travis.yml 9 | .taskcluster.yml 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | */__pycache__/ 14 | */*/__pycache__/ 15 | */*/*/__pycache__/ 16 | */*/*.py[cod] 17 | */*/*/*.py[cod] 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | env/ 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | *.pyo 39 | *.pyd 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .coverage 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Virtual environment 73 | .venv/ 74 | venv/ 75 | 76 | # PyCharm 77 | .idea 78 | 79 | # Python mode for VIM 80 | .ropeproject 81 | */.ropeproject 82 | */*/.ropeproject 83 | */*/*/.ropeproject 84 | 85 | # Vim swap files 86 | *.swp 87 | */*.swp 88 | */*/*.swp 89 | */*/*/*.swp 90 | 91 | # Data 92 | data/ 93 | 94 | # Notebook 95 | .ipynb_checkpoints/ 96 | *.ipynb 97 | -------------------------------------------------------------------------------- /api-demo/.env: -------------------------------------------------------------------------------- 1 | # Development settings 2 | FASTAPI_HOST="0.0.0.0" 3 | FASTAPI_PORT=8000 4 | 5 | GRADIO_HOST="0.0.0.0" 6 | GRADIO_PORT=8080 7 | 8 | AWS_ACCOUNT_ID=******* 9 | AWS_REGION=us-west-2 10 | PREDICT_ENDPOINT="http://127.0.0.1:${FASTAPI_PORT}/predict/" 11 | DOCKER_REGISTRY="${AWS_ACCOUNT_ID}.dkr.ecr.${AWS_REGION}.amazonaws.com" 12 | -------------------------------------------------------------------------------- /api-demo/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10 2 | 3 | RUN mkdir /workspace 4 | COPY . 
/workspace/ 5 | 6 | WORKDIR /workspace 7 | RUN pip install -r requirements.txt 8 | 9 | ENV HOME=/workspace 10 | 11 | CMD [ "python" , "app.py" ] 12 | -------------------------------------------------------------------------------- /api-demo/app.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Process 2 | 3 | from fastapi_backend import main as run_backend 4 | from gradio_frontend import main as run_frontend 5 | 6 | 7 | if __name__ == "__main__": 8 | backend_process = Process(target=run_backend) 9 | backend_process.start() 10 | 11 | frontend_process = Process(target=run_frontend) 12 | frontend_process.start() 13 | 14 | backend_process.join() 15 | frontend_process.join() 16 | -------------------------------------------------------------------------------- /api-demo/assets/Cinnamon_Bootcamp_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/api-demo/assets/Cinnamon_Bootcamp_1.png -------------------------------------------------------------------------------- /api-demo/assets/Cinnamon_Bootcamp_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/api-demo/assets/Cinnamon_Bootcamp_2.png -------------------------------------------------------------------------------- /api-demo/assets/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/api-demo/assets/cat.jpg -------------------------------------------------------------------------------- /api-demo/assets/cat_liquid.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/api-demo/assets/cat_liquid.jpg -------------------------------------------------------------------------------- /api-demo/assets/cat_liquid2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/api-demo/assets/cat_liquid2.jpg -------------------------------------------------------------------------------- /api-demo/assets/cat_liquid3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/api-demo/assets/cat_liquid3.jpg -------------------------------------------------------------------------------- /api-demo/assets/football-match.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/api-demo/assets/football-match.jpg -------------------------------------------------------------------------------- /api-demo/assets/lion.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/api-demo/assets/lion.jpg -------------------------------------------------------------------------------- /api-demo/docker-compose.yaml: 
-------------------------------------------------------------------------------- 1 | version: '3.9' 2 | 3 | services: 4 | classification_app: 5 | container_name: classification_demo 6 | image: "${DOCKER_REGISTRY}/classification_demo:0.0.2" 7 | build: 8 | context: . 9 | dockerfile: Dockerfile 10 | restart: always 11 | ports: 12 | - "8000:8000" 13 | - "8080:8080" 14 | env_file: 15 | - .env 16 | -------------------------------------------------------------------------------- /api-demo/fastapi_backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union 3 | 4 | import torch 5 | import uvicorn 6 | from PIL import Image 7 | from dotenv import load_dotenv 8 | from fastapi import FastAPI, UploadFile, Request 9 | from fastapi.staticfiles import StaticFiles 10 | from fastapi.templating import Jinja2Templates 11 | from transformers import ViTImageProcessor, ViTForImageClassification 12 | 13 | load_dotenv() 14 | app = FastAPI(title="fastapi-classification-demo") 15 | app.mount("/static", StaticFiles(directory="static"), name="static") 16 | app.mount("/assets", StaticFiles(directory="assets"), name="assets") 17 | templates = Jinja2Templates(directory="templates") 18 | 19 | 20 | def predict_imagenet_confidences(image: Union[Image.Image, str]) -> dict: 21 | """[A normal python function] 22 | Receive an image and predict confidences for ImageNet classes. 23 | 24 | Args: 25 | image (Union[Image.Image, str]): Image to predict confidences for. 26 | 27 | Returns: 28 | dict: Dictionary of 1000 classes in ImageNet and their confidence scores (float). 29 | """ 30 | if isinstance(image, str): 31 | image = Image.open(image) 32 | # Get the model and processor 33 | processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224") 34 | model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") 35 | 36 | inputs = processor(images=image, return_tensors="pt") 37 | outputs = model(**inputs) 38 | # Get confidence scores for all 1000 classes 39 | logits = outputs.logits 40 | confidences_id = torch.nn.functional.softmax(logits[0], dim=0) 41 | confidences_labels = { 42 | model.config.id2label[i]: float(confidences_id[i]) for i in range(1000) 43 | } 44 | 45 | return confidences_labels 46 | 47 | 48 | @app.get("/") 49 | def home(request: Request): 50 | return templates.TemplateResponse( 51 | "index.html", {"request": request, "name": "Class of Cinnamon AI Bootcamp 2023"} 52 | ) 53 | 54 | 55 | @app.post("/predict/") 56 | async def predict(file: UploadFile): 57 | """[FastAPI endpoint] 58 | Predict confidences for ImageNet classes from an uploaded image. 59 | 60 | Args: 61 | file (UploadFile): Uploaded image file. 62 | 63 | Returns: 64 | dict: Dictionary of 1000 classes in ImageNet and their confidence scores (float). 
65 | """ 66 | file_obj = file.file 67 | image = Image.open(file_obj) 68 | confidences = predict_imagenet_confidences(image) 69 | return confidences 70 | 71 | 72 | def main(): 73 | # Run web server with uvicorn 74 | uvicorn.run( 75 | "fastapi_backend:app", 76 | host=os.getenv("FASTAPI_HOST", "127.0.0.1"), 77 | port=int(os.getenv("FASTAPI_PORT", 8000)), 78 | # reload=True, # Uncomment this for debug 79 | ) 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /api-demo/gradio_frontend.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | 4 | import gradio as gr 5 | import requests 6 | from PIL import Image 7 | from dotenv import load_dotenv 8 | 9 | load_dotenv() 10 | 11 | 12 | def predict_imagenet_confidences_via_request(image: Image.Image) -> dict: 13 | """[Send a POST request to the fastapi backend] 14 | Receive an image and predict confidences for ImageNet classes. 15 | 16 | Args: 17 | image: Image to predict confidences for. 18 | 19 | Returns: 20 | dict: Dictionary of 1000 classes in ImageNet and their confidence scores (float). 21 | """ 22 | # Get the prediction endpoint 23 | url = os.getenv("PREDICT_ENDPOINT", "http://127.0.0.1:8000/predict/") 24 | 25 | # Convert PIL Image to bytes to send via requests POST 26 | img_byte_arr = io.BytesIO() 27 | image.save(img_byte_arr, format="PNG") 28 | img_byte_arr = img_byte_arr.getvalue() 29 | 30 | # Send POST request to predict endpoint 31 | files = {"file": img_byte_arr} 32 | response = requests.post(url, files=files) 33 | return response.json() 34 | 35 | 36 | def main(): 37 | # Gradio front-end interface 38 | gr_interface = gr.Interface( 39 | fn=predict_imagenet_confidences_via_request, 40 | inputs=gr.Image(type="pil"), 41 | outputs=gr.Label(num_top_classes=5), 42 | examples=["assets/cat.jpg", "assets/lion.jpg"], 43 | ) 44 | 45 | # Launch the web server 46 | gr_interface.launch( 47 | server_name=os.getenv("GRADIO_HOST", "127.0.0.1"), 48 | server_port=int(os.getenv("GRADIO_PORT", 8080)), 49 | ) 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /api-demo/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | aiofiles==23.1.0 3 | aiohttp==3.8.4 4 | aiohttp-retry==2.8.3 5 | aioitertools==0.11.0 6 | aiosignal==1.3.1 7 | altair==5.0.1 8 | amqp==5.1.1 9 | annotated-types==0.5.0 10 | antlr4-python3-runtime==4.9.3 11 | anyio==3.7.1 12 | appdirs==1.4.4 13 | argon2-cffi==21.3.0 14 | argon2-cffi-bindings==21.2.0 15 | arrow==1.2.3 16 | asttokens==2.2.1 17 | async-lru==2.0.3 18 | async-timeout==4.0.2 19 | asyncssh==2.13.1 20 | atpublic==3.1.1 21 | attrs==23.1.0 22 | Babel==2.12.1 23 | backcall==0.2.0 24 | beautifulsoup4==4.12.2 25 | bleach==6.0.0 26 | blessed==1.20.0 27 | cachetools==4.2.4 28 | certifi==2023.7.22 29 | cffi==1.15.1 30 | charset-normalizer==3.2.0 31 | click==8.1.6 32 | click-didyoumean==0.3.0 33 | click-plugins==1.1.1 34 | click-repl==0.2.0 35 | cmake==3.27.0 36 | comm==0.1.3 37 | contourpy==1.1.0 38 | cycler==0.11.0 39 | debugpy==1.6.7 40 | decorator==5.1.1 41 | defusedxml==0.7.1 42 | dictdiffer==0.9.0 43 | dulwich==0.21.5 44 | dvc-http==2.30.2 45 | exceptiongroup==1.1.2 46 | executing==1.2.0 47 | fastapi==0.100.0 48 | fastjsonschema==2.18.0 49 | ffmpy==0.3.1 50 | filelock==3.12.0 51 | flatten-dict==0.4.2 52 | fonttools==4.41.1 53 | 
fqdn==1.5.1 54 | frozenlist==1.3.3 55 | fsspec==2023.6.0 56 | funcy==2.0 57 | future==0.18.3 58 | gitdb==4.0.10 59 | GitPython==3.1.31 60 | google-auth==2.22.0 61 | google-auth-oauthlib==1.0.0 62 | gpustat==1.1 63 | gradio==3.38.0 64 | gradio_client==0.2.10 65 | grpcio==1.56.2 66 | h11==0.14.0 67 | httpcore==0.17.3 68 | httptools==0.6.0 69 | httpx==0.24.1 70 | huggingface-hub==0.16.4 71 | hydra-core==1.3.2 72 | idna==3.4 73 | importlib-metadata==6.8.0 74 | importlib-resources==5.12.0 75 | ipykernel==6.25.0 76 | ipython==8.12.2 77 | isoduration==20.11.0 78 | iterative-telemetry==0.0.8 79 | jedi==0.18.2 80 | Jinja2==3.1.2 81 | json5==0.9.14 82 | jsonpointer==2.4 83 | jsonschema==4.18.4 84 | jsonschema-specifications==2023.7.1 85 | jupyter-events==0.6.3 86 | jupyter-lsp==2.2.0 87 | jupyter_client==8.3.0 88 | jupyter_core==5.3.1 89 | jupyter_server==2.7.0 90 | jupyter_server_terminals==0.4.4 91 | jupyterlab==4.0.3 92 | jupyterlab-pygments==0.2.2 93 | jupyterlab_server==2.24.0 94 | kiwisolver==1.4.4 95 | linkify-it-py==2.0.2 96 | lit==16.0.6 97 | Markdown==3.4.3 98 | markdown-it-py==2.2.0 99 | MarkupSafe==2.1.3 100 | matplotlib==3.7.2 101 | matplotlib-inline==0.1.6 102 | mdit-py-plugins==0.3.3 103 | mdurl==0.1.2 104 | mistune==3.0.1 105 | mpmath==1.3.0 106 | multidict==6.0.4 107 | nanotime==0.5.2 108 | nbclient==0.8.0 109 | nbconvert==7.7.3 110 | nbformat==5.9.1 111 | nest-asyncio==1.5.6 112 | networkx==3.1 113 | notebook_shim==0.2.3 114 | numpy==1.23.4 115 | nvidia-cublas-cu11==11.10.3.66 116 | nvidia-cuda-cupti-cu11==11.7.101 117 | nvidia-cuda-nvrtc-cu11==11.7.99 118 | nvidia-cuda-runtime-cu11==11.7.99 119 | nvidia-cudnn-cu11==8.5.0.96 120 | nvidia-cufft-cu11==10.9.0.58 121 | nvidia-curand-cu11==10.2.10.91 122 | nvidia-cusolver-cu11==11.4.0.1 123 | nvidia-cusparse-cu11==11.7.4.91 124 | nvidia-ml-py==11.525.112 125 | nvidia-nccl-cu11==2.14.3 126 | nvidia-nvtx-cu11==11.7.91 127 | oauthlib==3.2.2 128 | omegaconf==2.3.0 129 | orjson==3.9.2 130 | overrides==7.3.1 131 | packaging==23.1 132 | pandas==2.0.3 133 | pandocfilters==1.5.0 134 | parso==0.8.3 135 | pexpect==4.8.0 136 | pickleshare==0.7.5 137 | Pillow==10.0.0 138 | pipdeptree==2.9.3 139 | pkgutil_resolve_name==1.3.10 140 | platformdirs==3.9.1 141 | prometheus-client==0.17.1 142 | prompt-toolkit==3.0.38 143 | protobuf==3.20.0 144 | psutil==5.9.5 145 | ptyprocess==0.7.0 146 | pure-eval==0.2.2 147 | pyasn1==0.5.0 148 | pyasn1-modules==0.3.0 149 | pycparser==2.21 150 | pydantic==2.0.3 151 | pydantic_core==2.3.0 152 | pydot==1.4.2 153 | pydub==0.25.1 154 | Pygments==2.15.1 155 | pyparsing==3.0.9 156 | python-dateutil==2.8.2 157 | python-dotenv==1.0.0 158 | python-json-logger==2.0.7 159 | python-multipart==0.0.6 160 | pytz==2023.3 161 | PyYAML==6.0.1 162 | pyzmq==25.1.0 163 | referencing==0.30.0 164 | regex==2023.6.3 165 | requests==2.31.0 166 | requests-oauthlib==1.3.1 167 | rfc3339-validator==0.1.4 168 | rfc3986-validator==0.1.1 169 | rich==13.3.5 170 | rpds-py==0.9.2 171 | rsa==4.9 172 | ruamel.yaml.clib==0.2.7 173 | s3transfer==0.6.1 174 | safetensors==0.3.1 175 | semantic-version==2.10.0 176 | Send2Trash==1.8.2 177 | shortuuid==1.0.11 178 | shtab==1.6.1 179 | six==1.16.0 180 | smmap==5.0.0 181 | sniffio==1.3.0 182 | soupsieve==2.4.1 183 | stack-data==0.6.2 184 | starlette==0.27.0 185 | sympy==1.12 186 | tabulate==0.9.0 187 | tensorboard==2.13.0 188 | tensorboard-data-server==0.7.1 189 | tensorboard-plugin-wit==1.8.1 190 | terminado==0.17.1 191 | tinycss2==1.2.1 192 | tokenizers==0.13.3 193 | tomli==2.0.1 194 | tomlkit==0.11.8 195 | 
toolz==0.12.0 196 | torch==2.0.1 197 | torchvision==0.15.2 198 | tornado==6.3.2 199 | tqdm==4.65.0 200 | traitlets==5.9.0 201 | transformers==4.31.0 202 | triton==2.0.0 203 | typing_extensions==4.7.1 204 | tzdata==2023.3 205 | uc-micro-py==1.0.2 206 | uri-template==1.3.0 207 | urllib3==1.26.16 208 | uvicorn==0.23.1 209 | uvloop==0.17.0 210 | vine==5.0.0 211 | voluptuous==0.13.1 212 | watchfiles==0.19.0 213 | wcwidth==0.2.6 214 | webcolors==1.13 215 | webencodings==0.5.1 216 | websocket-client==1.6.1 217 | websockets==11.0.3 218 | Werkzeug==2.3.6 219 | wrapt==1.15.0 220 | yarl==1.9.2 221 | zc.lockfile==3.0.post1 222 | zipp==3.15.0 223 | -------------------------------------------------------------------------------- /api-demo/static/styles.css: -------------------------------------------------------------------------------- 1 | h1 { 2 | color: gray; 3 | text-align: center; 4 | } 5 | 6 | .center { 7 | display: block; 8 | margin-left: auto; 9 | margin-right: auto; 10 | width: 50%; 11 | } 12 | -------------------------------------------------------------------------------- /api-demo/templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | Item Details 4 | 5 | 6 | 7 | Logo 8 |

<h1>Hello, {{ name }}</h1>
9 | 10 | 11 | -------------------------------------------------------------------------------- /image-to-image-search/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | clip @ git+https://github.com/openai/CLIP.git 3 | faiss-cpu 4 | 5 | # Download data 6 | roboflow 7 | -------------------------------------------------------------------------------- /leaked_container/.env: -------------------------------------------------------------------------------- 1 | API_KEY=this_is_a_secret_key 2 | -------------------------------------------------------------------------------- /leaked_container/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-alpine 2 | 3 | WORKDIR /app 4 | 5 | COPY . . 6 | 7 | RUN pip install --no-cache-dir -r requirements.txt 8 | 9 | COPY .env /app/.env 10 | 11 | RUN export $(cat /app/.env | xargs) 12 | 13 | RUN echo "API_KEY=${API_KEY}" > /tmp/credentials.txt 14 | 15 | # remove .env 16 | RUN rm /app/.env # INSECURE 17 | 18 | CMD ["python", "src/index.py"] 19 | 20 | EXPOSE 3000 21 | 22 | -------------------------------------------------------------------------------- /leaked_container/requirements.txt: -------------------------------------------------------------------------------- 1 | flask 2 | -------------------------------------------------------------------------------- /leaked_container/src/index.py: -------------------------------------------------------------------------------- 1 | from flask import Flask 2 | import os 3 | 4 | app = Flask(__name__) 5 | 6 | @app.route('/') 7 | def index(): 8 | api_key = os.getenv('API_KEY') 9 | return f"API_KEY is: {api_key}" 10 | 11 | if __name__ == '__main__': 12 | app.run(host='0.0.0.0', port=3000) 13 | 14 | -------------------------------------------------------------------------------- /rag-foundation/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /rag-foundation/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Pre-commit hooks 2 | default_language_version: 3 | python: python3 4 | 5 | repos: 6 | - repo: https://github.com/pre-commit/pre-commit-hooks 7 | rev: v4.5.0 8 | hooks: 9 | - id: check-ast 10 | - id: check-yaml 11 | - id: check-json 12 | - id: check-toml 13 | - id: check-case-conflict 14 | - id: check-docstring-first 15 | # - id: check-added-large-files 16 | - id: trailing-whitespace 17 | - id: detect-aws-credentials 18 | args: ["--allow-missing-credentials"] 19 | - id: detect-private-key 20 | - id: end-of-file-fixer 21 | - id: mixed-line-ending 22 | 23 | - repo: https://github.com/psf/black 24 | rev: 23.12.1 25 | hooks: 26 | - id: black 27 | name: PEP8 formatting 28 | args: [ --skip-string-normalization] 29 | 30 | - repo: https://github.com/PyCQA/isort 31 | rev: 5.13.2 32 | hooks: 33 | - id: isort 34 | name: I-sort imports 35 | args: ["--profile", "black"] 36 | 37 | - repo: https://github.com/PyCQA/flake8 38 | rev: 7.0.0 39 | hooks: 40 | - id: flake8 41 | name: PEP8 checker 42 | 43 | - repo: https://github.com/myint/autoflake 44 | rev: v2.2.1 45 | hooks: 46 | - id: autoflake 47 | args: 48 | [ 49 | "--in-place", 50 | # "--remove-unused-variables", 51 | "--remove-all-unused-imports", 52 | "--ignore-init-module-imports", 53 | ] 54 | -------------------------------------------------------------------------------- /rag-foundation/README.md: -------------------------------------------------------------------------------- 1 | # rag-foundation-exercise 2 | 3 | ## Installation 4 | 5 | **Note:** Prefer `python=3.10.*` 6 | 7 | ### 1. Fork the repo 8 | 9 | ### 2. Set up environment 10 | Assume that the name of your forked repository is also `ai-bootcamp-2024`. 11 | 12 | #### Windows 13 | 14 | - **Open Command Prompt.** 15 | - **Navigate to your project directory:** 16 | 17 | ```sh 18 | cd C:\Path\To\ai-bootcamp-2024 19 | ``` 20 | 21 | - **Create a virtual environment using Python 3.10:** 22 | 23 | Check your python version first using `py -0` or `where python` 24 | 25 | ``` 26 | python -m venv rag-foundation 27 | or 28 | path/to/python3.10 -m venv rag-foundation 29 | ``` 30 | 31 | - **Activate the Virtual Environment:** 32 | 33 | ```sh 34 | rag-foundation\Scripts\activate 35 | ``` 36 | 37 | #### Ubuntu/MacOS 38 | 39 | - **Open a terminal.** 40 | - **Create a new Conda environment with Python 3.10:** 41 | 42 | ```sh 43 | conda create --name rag-foundation python=3.10 44 | ``` 45 | 46 | - **Activate the Conda Environment:** 47 | 48 | ```sh 49 | conda activate rag-foundation 50 | ``` 51 | 52 | ### 3. **Install Required Packages:** 53 | 54 | - Install the required packages from `requirements.txt`: 55 | 56 | ```sh 57 | pip install -r requirements.txt 58 | ``` 59 | 60 | ## Homework 61 | 62 | ### 1. **Fill your implementation** 63 | 64 | Search for `"Your code here"` line in the codebase which will lead you to where you should place your code. 65 | 66 | ### 2. **Run script** 67 | 68 | You should read the code in this repository carefully to understand the setup comprehensively. 
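For orientation before you start, the `"Your code here"` blocks in `vector_store/` mostly reduce to one idea: embed the text chunks once, embed the query, and rank chunks by similarity. The snippet below is only a rough sketch of that idea, not the actual interface in `vector_store/semantic_vector_store.py`; the class name, attributes, and model name are hypothetical, and it only assumes `sentence-transformers` and `numpy`, which are already in `requirements.txt`.

```python
# Minimal illustrative sketch of dense top-k retrieval (hypothetical names,
# NOT this repository's real API): embed chunks once, embed the query,
# rank by cosine similarity, and return the k best chunks.
import numpy as np
from sentence_transformers import SentenceTransformer


class ToyDenseStore:
    def __init__(self, texts, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model_name)
        self.texts = list(texts)
        # Normalized embeddings so a plain dot product equals cosine similarity.
        self.embeddings = self.model.encode(self.texts, normalize_embeddings=True)

    def query(self, question, top_k=5):
        q = self.model.encode([question], normalize_embeddings=True)[0]
        scores = self.embeddings @ q  # cosine similarity per stored chunk
        top_ids = np.argsort(scores)[::-1][:top_k]
        return [(self.texts[i], float(scores[i])) for i in top_ids]


if __name__ == "__main__":
    store = ToyDenseStore(["Llama 2 is a family of open LLMs.", "Paris is in France."])
    print(store.query("What is Llama 2?", top_k=1))
```

A sparse store (the `sparse` mode) typically follows the same shape, with TF-IDF/BM25-style term weights in place of the dense embeddings.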
69 | 70 | You can run the script below to get the results from your pre-built RAG, for example: 71 | 72 | ```sh 73 | python -m scripts.main \ 74 | --data_path \ 75 | --output_path predictions.jsonl \ 76 | --mode \ 77 | --force_index \ 78 | --retrieval_only True \ 79 | --top_k 5 80 | ``` 81 | 82 | where some of the arguments are: 83 | 84 | - `mode`: `sparse` or `semantic` 85 | - `force_index`: `True` or `False` (True: overwrite the old vector store index) 86 | - `retrieval_only`: `True` or `False` (True: return only the retrieval contexts; answers are left empty) 87 | 88 | #### NOTE: 89 | 90 | To use LLM generation in the RAG pipeline, you can use ChatOpenAI by supplying `OPENAI_API_KEY` as an environment variable (assuming you have one). 91 | If you don't have access to the OpenAI API, use the Groq free tier instead: 92 | 93 | - Register an account at https://console.groq.com/keys (free) 94 | - Generate your API key 95 | - Assign the env variable: `export GROQ_API_KEY=` 96 | - Run the main script without `--retrieval_only` to use the LLM 97 | 98 | ### 3. **Run Evaluation:** 99 | ```sh 100 | python evaluate.py --predictions predictions.jsonl --gold data/qasper-test-v0.3.json --retrieval_only 101 | ``` 102 | $\rightarrow$ evaluates only the retrieval contexts. 103 | 104 | ```sh 105 | python evaluate.py --predictions predictions.jsonl --gold data/qasper-test-v0.3.json 106 | ``` 107 | $\rightarrow$ evaluates both the retrieval contexts and the answers. 108 | -------------------------------------------------------------------------------- /rag-foundation/data/llama2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/rag-foundation/data/llama2.pdf -------------------------------------------------------------------------------- /rag-foundation/evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Official script for evaluating models built for the Qasper dataset. The script 3 | outputs Answer F1 and Evidence F1 reported in the paper. 4 | """ 5 | 6 | import argparse 7 | import json 8 | import re 9 | import string 10 | from collections import Counter 11 | 12 | 13 | def normalize_answer(s): 14 | """ 15 | Taken from the official evaluation script for v1.1 of the SQuAD dataset. 16 | Lower text and remove punctuation, articles and extra whitespace. 17 | """ 18 | 19 | def remove_articles(text): 20 | return re.sub(r"\b(a|an|the)\b", " ", text) 21 | 22 | def white_space_fix(text): 23 | return " ".join(text.split()) 24 | 25 | def remove_punc(text): 26 | exclude = set(string.punctuation) 27 | return "".join(ch for ch in text if ch not in exclude) 28 | 29 | def lower(text): 30 | return text.lower() 31 | 32 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 33 | 34 | 35 | def token_f1_score(prediction, ground_truth): 36 | """ 37 | Taken from the official evaluation script for v1.1 of the SQuAD dataset.
38 | """ 39 | prediction_tokens = normalize_answer(prediction).split() 40 | ground_truth_tokens = normalize_answer(ground_truth).split() 41 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 42 | num_same = sum(common.values()) 43 | if num_same == 0: 44 | return 0 45 | precision = 1.0 * num_same / len(prediction_tokens) 46 | recall = 1.0 * num_same / len(ground_truth_tokens) 47 | f1 = (2 * precision * recall) / (precision + recall) 48 | return f1 49 | 50 | 51 | def paragraph_f1_score(prediction, ground_truth): 52 | if not ground_truth and not prediction: 53 | # The question is unanswerable and the prediction is empty. 54 | return 1.0 55 | num_same = len(set(ground_truth).intersection(set(prediction))) 56 | if num_same == 0: 57 | return 0.0 58 | precision = num_same / len(prediction) 59 | recall = num_same / len(ground_truth) 60 | f1 = (2 * precision * recall) / (precision + recall) 61 | return f1 62 | 63 | 64 | def get_answers_and_evidence(data, text_evidence_only): 65 | answers_and_evidence = {} 66 | for paper_data in data.values(): 67 | for qa_info in paper_data["qas"]: 68 | question_id = qa_info["question_id"] 69 | references = [] 70 | for annotation_info in qa_info["answers"]: 71 | answer_info = annotation_info["answer"] 72 | if answer_info["unanswerable"]: 73 | references.append( 74 | {"answer": "Unanswerable", "evidence": [], "type": "none"} 75 | ) 76 | else: 77 | if answer_info["extractive_spans"]: 78 | answer = ", ".join(answer_info["extractive_spans"]) 79 | answer_type = "extractive" 80 | elif answer_info["free_form_answer"]: 81 | answer = answer_info["free_form_answer"] 82 | answer_type = "abstractive" 83 | elif answer_info["yes_no"]: 84 | answer = "Yes" 85 | answer_type = "boolean" 86 | elif answer_info["yes_no"] is not None: 87 | answer = "No" 88 | answer_type = "boolean" 89 | else: 90 | raise RuntimeError( 91 | f"Annotation {answer_info['annotation_id']} does not contain an answer" 92 | ) 93 | if text_evidence_only: 94 | evidence = [ 95 | text 96 | for text in answer_info["evidence"] 97 | if "FLOAT SELECTED" not in text 98 | ] 99 | else: 100 | evidence = answer_info["evidence"] 101 | references.append( 102 | {"answer": answer, "evidence": evidence, "type": answer_type} 103 | ) 104 | answers_and_evidence[question_id] = references 105 | 106 | return answers_and_evidence 107 | 108 | 109 | def evaluate(gold, predicted, retrieval_only=False): 110 | max_answer_f1s = [] 111 | max_evidence_f1s = [] 112 | max_answer_f1s_by_type = { 113 | "extractive": [], 114 | "abstractive": [], 115 | "boolean": [], 116 | "none": [], 117 | } 118 | num_missing_predictions = 0 119 | for question_id, references in gold.items(): 120 | if question_id not in predicted: 121 | num_missing_predictions += 1 122 | max_answer_f1s.append(0.0) 123 | max_evidence_f1s.append(0.0) 124 | continue 125 | answer_f1s_and_types = [ 126 | ( 127 | token_f1_score(predicted[question_id]["answer"], reference["answer"]), 128 | reference["type"], 129 | ) 130 | for reference in gold[question_id] 131 | ] 132 | max_answer_f1, answer_type = sorted( 133 | answer_f1s_and_types, key=lambda x: x[0], reverse=True 134 | )[0] 135 | max_answer_f1s.append(max_answer_f1) 136 | max_answer_f1s_by_type[answer_type].append(max_answer_f1) 137 | evidence_f1s = [ 138 | paragraph_f1_score( 139 | predicted[question_id]["evidence"], reference["evidence"] 140 | ) 141 | for reference in gold[question_id] 142 | ] 143 | max_evidence_f1s.append(max(evidence_f1s)) 144 | 145 | mean = lambda x: sum(x) / len(x) if x else 0.0 146 | 147 | if not 
retrieval_only: 148 | return { 149 | "Answer F1": mean(max_answer_f1s), 150 | "Answer F1 by type": { 151 | key: mean(value) for key, value in max_answer_f1s_by_type.items() 152 | }, 153 | "Evidence F1": mean(max_evidence_f1s), 154 | "Missing predictions": num_missing_predictions, 155 | } 156 | else: 157 | return { 158 | "Evidence F1": mean(max_evidence_f1s), 159 | } 160 | 161 | 162 | if __name__ == "__main__": 163 | parser = argparse.ArgumentParser() 164 | parser.add_argument( 165 | "--predictions", 166 | type=str, 167 | required=True, 168 | help="""JSON lines file with each line in format: 169 | {'question_id': str, 'predicted_answer': str, 'predicted_evidence': List[str]}""", 170 | ) 171 | parser.add_argument( 172 | "--gold", 173 | type=str, 174 | required=True, 175 | help="Test or dev set from the released dataset", 176 | ) 177 | parser.add_argument( 178 | "--retrieval_only", 179 | help="If set, the evaluator will just evaluate the retrieval scores", 180 | action="store_true", 181 | ) 182 | parser.add_argument( 183 | "--text_evidence_only", 184 | action="store_true", 185 | help="If set, the evaluator will ignore evidence in figures and tables while reporting evidence f1", 186 | ) 187 | args = parser.parse_args() 188 | gold_data = json.load(open(args.gold)) 189 | gold_answers_and_evidence = get_answers_and_evidence( 190 | gold_data, args.text_evidence_only 191 | ) 192 | predicted_answers_and_evidence = {} 193 | for line in open(args.predictions): 194 | prediction_data = json.loads(line) 195 | predicted_answers_and_evidence[prediction_data["question_id"]] = { 196 | "answer": prediction_data["predicted_answer"], 197 | "evidence": prediction_data["predicted_evidence"], 198 | } 199 | evaluation_output = evaluate( 200 | gold_answers_and_evidence, 201 | predicted_answers_and_evidence, 202 | retrieval_only=args.retrieval_only, 203 | ) 204 | print(json.dumps(evaluation_output, indent=2)) 205 | -------------------------------------------------------------------------------- /rag-foundation/requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.9.5 2 | aiosignal==1.3.1 3 | annotated-types==0.7.0 4 | anyio==4.4.0 5 | async-timeout==4.0.3 6 | attrs==23.2.0 7 | beautifulsoup4==4.12.3 8 | certifi==2024.7.4 9 | charset-normalizer==3.3.2 10 | click==8.1.7 11 | dataclasses-json==0.6.7 12 | Deprecated==1.2.14 13 | dirtyjson==1.0.8 14 | distro==1.9.0 15 | exceptiongroup==1.2.2 16 | filelock==3.15.4 17 | frozenlist==1.4.1 18 | fsspec==2024.6.1 19 | greenlet==3.0.3 20 | h11==0.14.0 21 | httpcore==1.0.5 22 | httpx==0.27.0 23 | huggingface-hub==0.23.4 24 | idna==3.7 25 | Jinja2==3.1.4 26 | joblib==1.4.2 27 | llama-cloud==0.0.9 28 | llama-index==0.10.55 29 | llama-index-agent-openai==0.2.8 30 | llama-index-cli==0.1.12 31 | llama-index-core==0.10.55 32 | llama-index-embeddings-openai==0.1.10 33 | llama-index-indices-managed-llama-cloud==0.2.5 34 | llama-index-legacy==0.9.48 35 | llama-index-llms-ollama==0.1.5 36 | llama-index-llms-openai==0.1.25 37 | llama-index-multi-modal-llms-openai==0.1.7 38 | llama-index-program-openai==0.1.6 39 | llama-index-question-gen-openai==0.1.3 40 | llama-index-readers-file==0.1.30 41 | llama-index-readers-llama-parse==0.1.6 42 | llama-parse==0.4.7 43 | loguru==0.7.2 44 | MarkupSafe==2.1.5 45 | marshmallow==3.21.3 46 | mpmath==1.3.0 47 | multidict==6.0.5 48 | mypy-extensions==1.0.0 49 | nest-asyncio==1.6.0 50 | networkx==3.3 51 | nltk==3.8.1 52 | numpy==1.26.4 53 | openai==1.35.13 54 | packaging==24.1 55 | pandas==2.2.2 56 | 
pillow==10.4.0 57 | pydantic==2.8.2 58 | pydantic_core==2.20.1 59 | PyMuPDF==1.24.7 60 | PyMuPDFb==1.24.6 61 | pypdf==4.3.0 62 | python-dateutil==2.9.0.post0 63 | pytz==2024.1 64 | PyYAML==6.0.1 65 | regex==2024.5.15 66 | requests==2.32.3 67 | safetensors==0.4.3 68 | scikit-learn==1.5.1 69 | scipy==1.14.0 70 | sentence-transformers==3.0.1 71 | setuptools==69.5.1 72 | six==1.16.0 73 | sniffio==1.3.1 74 | soupsieve==2.5 75 | SQLAlchemy==2.0.31 76 | striprtf==0.0.26 77 | sympy==1.13.0 78 | tenacity==8.5.0 79 | threadpoolctl==3.5.0 80 | tiktoken==0.7.0 81 | tokenizers==0.19.1 82 | torch==2.3.1 83 | tqdm==4.66.4 84 | transformers==4.42.4 85 | typing-inspect==0.9.0 86 | typing_extensions==4.12.2 87 | tzdata==2024.1 88 | urllib3==2.2.2 89 | wheel==0.43.0 90 | wrapt==1.16.0 91 | yarl==1.9.4 92 | fire 93 | langchain-openai 94 | langchain-groq 95 | -------------------------------------------------------------------------------- /rag-foundation/sample_predictions.jsonl: -------------------------------------------------------------------------------- 1 | {"question_id": "397a1e851aab41c455c2b284f5e4947500d797f0", "predicted_answer": "The ANTISCAM dataset consists of 220 human-human dialogs collected from a typing conversation task on the Amazon Mechanical Turk platform.", "predicted_evidence": ["So we count the dialog length as another metric to evaluate system performance.\n\nTask Success Score (TaskSuc) The other goal of the anti-scam system is to elicit attacker's personal information. We count the average type of information (name, address and phone number) that the system obtained from attackers as the task success score.\n\nTable TABREF19 presents the main experiment results on AntiScam dataset, for both automatic evaluation metrics and human evaluation metrics. The experiment results on PersuasionForGood are shown in Table TABREF23. We observe that MISSA outperforms two baseline models (TransferTransfo and hybrid model) on almost all the metrics on both datasets. For further analysis, examples of real dialogs from the human evaluation are presented in Table TABREF21.\n\nCompared to the first TransferTransfo baseline, MISSA outperforms the TransferTransfo baseline on the on-task contents. From Table TABREF19, we observe that MISSA maintains longer conversations (14.9 turns) compared with TransferTransfo (8.5 turns), which means MISSA is better at maintaining the attacker's engagement. MISSA also has a higher task success score (1.294) than TransferTransfo (1.025), which indicates that it elicits information more strategically. In the top two dialogs (A and B) that are shown in Table TABREF21, both attackers were eliciting a credit card number in their first turns. TransferTransfo directly gave away the information, while MISSA replied with a semantically-related question \u201cwhy would you need my credit card number?\" Furthermore, in the next turn, TransferTransfo ignored the context and asked an irrelevant question \u201cwhat is your name?\u201d while MISSA was able to generate the response \u201cwhy can't you use my address?\u201d, which is consistent to the context. We suspect the improved performance of MISSA comes from our proposed annotation scheme: the semantic slot information enables MISSA to keep track of the current entities, and the intent information helps MISSA to maintain coherency and prolong conversations.\n\nCompared to the hybrid model baseline, MISSA performs better on off-task content. 
As shown in the bottom two dialogs in Table TABREF21, attackers in both dialogs introduced their names in their first utterances. MISSA recognized attacker's name, while the hybrid model did not. We suspect it is because the hybrid model does not have the built-in semantic slot predictor.", "MISSA is based on the generative pre-trained transformer BIBREF32. We use an Adam optimizer with a learning rate of 6.25e-5 and $L2$ weight decay of $0.01$, we set the coefficient of language modeling loss to be 2, the coefficient of intent and slot classifiers to be 1, and the coefficient of next-utterance classifier to be 1. We first pre-train the model on the PERSONA-CHAT dataset. When fine-tuning on the AntiScam and the PersuasionForGood datasets, we use $80\\%$ data for training, $10\\%$ data for validation, and $10\\%$ data for testing. Since the original PersuasionForGood dataset is annotated with intents, we separate the original on-task and off-task intents, which are shown in Table TABREF2. To deal with the words out of the vocabulary, we conduct delexicalization to replace slot values with corresponding slot tokens during the training phase, and replace the slot tokens with pre-defined information during testing.\n\nAn example of human-human chat on AntiScam dataset is shown in Table TABREF25.", "MISSA follows the TransferTransfo framework BIBREF0 with three modifications: (i) We first concurrently predict user's, system's intents and semantic slots; (ii) We then perform conditional generation to improve generated response's coherence. Specifically, we generate responses conditioned on the above intermediate representation (intents and slots); (iii) Finally, we generate multiple responses with the nucleus sampling strategy BIBREF5 and then apply a response filter, which contains a set of pre-defined constraints to select coherent responses. The constraints in the filter can be defined according to specific task requirements or general conversational rules.\n\nTo enrich publicly available non-collaborative task datasets, we collect a new dataset AntiScam, where users defend themselves against attackers trying to collect personal information. As non-collaborative tasks are still relatively new to the study of dialog systems, there are insufficiently many meaningful datasets for evaluation and we hope this provides a valuable example. We evaluate MISSA on the newly collected AntiScam dataset and an existing PersuasionForGood dataset. Both automatic and human evaluations suggest that MISSA outperforms multiple competitive baselines.\n\nIn summary, our contributions include: (i) We design a hierarchical intent annotation scheme and a semantic slot annotation scheme to annotate the non-collaborative dialog dataset, we also propose a carefully-designed AntiScam dataset to facilitate the research of non-collaborative dialog systems. (ii) We propose a model that can be applied to all non-collaborative tasks, outperforming other baselines on two different non-collaborative tasks. (iii) We develop an anti-scam dialog system to occupy attacker's attention and elicit their private information for social good. Furthermore, we also build a persuasion dialog system to persuade people to donate to charities. We release the code and data.\n\nThe interest in non-collaborative tasks has been increasing and there have already been several related datasets. For instance, BIBREF1 wang2019persuasion collected conversations where one participant persuades another to donate to a charity. 
BIBREF2 he2018decoupling collected negotiation dialogs where buyers and sellers bargain for items for sale on Craigslist. There are many other non-collaborative tasks, such as the turn-taking game BIBREF6, the multi-party game BIBREF7 and item splitting negotiation BIBREF8.", "We posted a role-playing task on the Amazon Mechanical Turk platform and collected a typing conversation dataset named AntiScam. We collected 220 human-human dialogs. The average conversation length is 12.45 turns and the average utterance length is 11.13 words. Only 172 out of 220 users successfully identified their partner as an attacker, suggesting that the attackers are well trained and not too easily identifiable. We recruited two expert annotators who have linguistic training to annotate 3,044 sentences in 100 dialogs, achieving a 0.874 averaged weighted kappa value.\n\nThe PersuasionForGood dataset BIBREF1 was collected from typing conversations on Amazon Mechanical Turk platform. Two workers were randomly paired, one was assigned the role of persuader, the other was persuadee. The goal of the persuader was to persuade the persuadee to donate a portion of task earning to a specific charity. The dataset consists of 1,017 dialogs, where 300 dialogs are annotated with dialog acts. The average conversation length is 10.43, the vocabulary size is 8,141. Since the original PersuasionForGood dataset is annotated with dialog acts, we select the on-task dialog acts as on-task intents shown in Table TABREF2, and categorize the other dialog acts into our pre-defined off-task intents.\n\nThe TransferTransfo framework was proposed to build open domain dialog systems. BIBREF0 wolf2019transfertransfo fine-tuned the generative pre-training model (GPT) BIBREF32 with the PERSONA-CHAT dataset BIBREF33 in a multi-task fashion, where the language model objective is combined with a next-utterance classification task. The language model's objective is to maximize the following likelihood for a given sequence of tokens, $X = \\lbrace x_1,\\dots ,x_n\\rbrace $:\n\nThe authors also trained a classifier to distinguish the correct next-utterance appended to the input human utterances from a set of randomly selected utterance distractors. In addition, they introduced dialog state embeddings to indicate speaker role in the model. The model significantly outperformed previous baselines over both automatic evaluations and human evaluations in social conversations. Since the TransferTransfo framework performs well in open domain, we adapt it for non-collaborative settings.", "We suspect the underlying reason is that there are more possible responses with the same intent in PersuasionForGood than in AntiScam. This also suggests that we should adjust the model structure according to the nature of the dataset.\n\nWe propose a general dialog system pipeline to build non-collaborative dialog systems, including a hierarchical annotation scheme and an end-to-end neural response generation model called MISSA. With the hierarchical annotation scheme, we can distinguish on-task and off-task intents. MISSA takes both on and off-task intents as supervision in its training and thus can deal with diverse user utterances in non-collaborative settings. Moreover, to validate MISSA's performance, we create a non-collaborate dialog dataset that focuses on deterring phone scammers. MISSA outperforms all baseline methods in terms of fluency, coherency, and user engagement on both the newly proposed anti-scam task and an existing persuasion task. 
However, MISSA still produces responses that are not consistent with their distant conversation history as GPT can only track a limited history span. In future work, we plan to address this issue by developing methods that can effectively track longer dialog context.\n\nThis work was supported by DARPA ASED Program HR001117S0050. The U.S. Government is authorized to reproduce and distribute reprints for governmental purposes not withstanding any copyright notation therein. The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies, either expressed or implied, of DARPA or the U.S. Government.\n\nWe randomly pair two workers: one is assigned the role of the attacker to elicit user information, and the other one is assigned the role of an everyday user who aims to protect her/his information and potentially elicit the attacker's information. We give both workers specific personal data. Instructions are shown in Table TABREF24. The \u201cattacker\u201d additionally receives training on how to elicit information from people. Workers cannot see their partners' instructions.\n\nThere are two tasks for the users: firstly, users are required to chat with their partners and determine if they are attackers or not, reporting their decisions at the end of the task. If users think their partners are attackers, they are instructed to prolong the conversation and elicit information from their partners."]} 2 | {"question_id": "cc8b4ed3985f9bfbe1b5d7761b31d9bd6a965444", "predicted_answer": "The ANTISCAM dataset consists of 220 human-human dialogs collected from a typing conversation task on the Amazon Mechanical Turk platform.", "predicted_evidence": ["The intent predictor achieves a $84\\%$ accuracy and the semantic slot predictor achieves $77\\%$ on the AntiScam dataset. Then we compare the predicted values with human-annotated ground truth in the dataset to compute the response-intent prediction (RIP) and response-slot prediction (RSP).\n\nExtended Response-Intent Prediction (ERIP) $\\&$ Extended Response-Slot Prediction (ERSP) With Response-Intent Prediction, we verify the predicted intents to evaluate the coherence of the dialog. However, the real mapping between human-intent and system-intent is much more complicated as there might be multiple acceptable system-intents for the same human-intent. Therefore, we also design a metric to evaluate if the predicted system-intent is in the set of acceptable intents. Specifically, we estimate the transition probability $p(I_i|I_j)$ by counting the frequency of all the bi-gram human-intent and system-intent pairs in the training data. During the test stage, if the predicted intent matches the ground truth, we set the score as 1, otherwise we set the score as $p(I_{predict}|I_i)$ where $I_i$ is the intent of the input human utterance. We then report the average value of those scores over turns as the final extended response-intent prediction result.\n\nAutomatic metrics only validate the system\u2019s performance on a single dimension at a time. The ultimate holistic evaluation should be conducted by having the trained system interact with human users. Therefore we also conduct human evaluations for the dialog system built on AntiScam. We test our models and baselines with 15 college-student volunteers. Each of them is asked to pretend to be an attacker and interact with all the models for at least three times to avoid randomness. We in total collect 225 number of dialogs. 
Each time, volunteers are required to use similar sentences and strategies to interact with all five models and score each model based on the metrics listed below at the end of the current round. Each model receives a total of 45 human ratings, and the average score is reported as the final human-evaluation score. In total, we design five different metrics to assess the models' conversational ability whilst interacting with humans. The results are shown in Table TABREF19.\n\nFluency Fluency is used to explore different models' language generation quality.", "Compared with these works, MISSA is end-to-end trainable and thus easier to train and update.\n\nTo decouple syntactic and semantic information in utterances and provide detailed supervision, we design a hierarchical intent annotation scheme for non-collaborative tasks. We first separate on-task and off-task intents. As on-task intents are key actions that can vary among different tasks, we need to specifically define on-task intents for each task. On the other hand, since off-task content is too general to design task-specific intents, we choose common dialog acts as the categories. The advantage of this hierarchical annotation scheme is apparent when starting a new non-collaborative task: we only need to focus on designing the on-task categories and semantic slots which are the same as traditional task-oriented dialog systems. Consequently, we don't have to worry about the off-task annotation design since the off-task category is universal.\n\nIn the intent annotation scheme shown in Table TABREF2, we list the designed intent annotation scheme for the newly collected AntiScam dataset and the PersuasionForGood dataset. We first define on-task intents for the datasets, which are key actions in the task. Since our AntiScam focuses on understanding and reacting towards elicitations, we define elicitation, providing_information and refusal as on-task intents. In the PersuasionForGood dataset, we define nine on-task intents in Table TABREF2 based on the original PersuasionForGood dialog act annotation scheme. All these intents are related to donation actions, which are salient on-task intents in the persuasion task. The off-task intents are the same for both tasks, including six general intents and six additional social intents. General intents are more closely related to the syntactic meaning of the sentence (open_question, yes_no_question, positive_answer, negative_answer, responsive_statement, and nonresponsive_statement) while social intents are common social actions (greeting, closing, apology, thanking,respond_to_thank, and hold).\n\nFor specific tasks, we also design a semantic slot annotation scheme for annotating sentences based on their semantic content. We identify 13 main semantic slots in the anti-scam task, for example, credit card numbers. We present a detailed semantic slot annotation in Table TABREF3. Following BIBREF1, we segment each conversation turn into single sentences and then annotate each sentence rather than turns.", "Therefore, we need to design a system that handles both on-task and off-task information appropriately and in a way that leads back to the system's goal.\n\nTo tackle the issue of incoherent system responses to off-task content, previous studies have built hybrid systems to interleave off-task and on-task content. 
BIBREF4 used a rule-based dialog manager for on-task content and a neural model for off-task content, and trained a reinforcement learning model to select between these two models based on the dialog context. However, such a method is difficult to train and struggles to generalize beyond the movie promotion task they considered. To tackle these problems, we propose a hierarchical intent annotation scheme that separates on-task and off-task information in order to provide detailed supervision. For on-task information, we directly use task-related intents for representation. Off-task information, on the other hand, is too general to categorize into specific intents, so we choose dialog acts that convey syntax information. These acts, such as \u201copen question\" are general to all tasks.\n\nPrevious studies use template-based methods to maintain sentence coherence. However, rigid templates lead to limited diversity, causing the user losing engagement. On the other hand, language generation models can generate diverse responses but are bad at being coherent. We propose Multiple Intents and Semantic Slots Annotation Neural Network (MISSA) to combine the advantages of both template and generation models and takes advantage from the hierarchical annotation at the same time. MISSA follows the TransferTransfo framework BIBREF0 with three modifications: (i) We first concurrently predict user's, system's intents and semantic slots; (ii) We then perform conditional generation to improve generated response's coherence. Specifically, we generate responses conditioned on the above intermediate representation (intents and slots); (iii) Finally, we generate multiple responses with the nucleus sampling strategy BIBREF5 and then apply a response filter, which contains a set of pre-defined constraints to select coherent responses. The constraints in the filter can be defined according to specific task requirements or general conversational rules.\n\nTo enrich publicly available non-collaborative task datasets, we collect a new dataset AntiScam, where users defend themselves against attackers trying to collect personal information. As non-collaborative tasks are still relatively new to the study of dialog systems, there are insufficiently many meaningful datasets for evaluation and we hope this provides a valuable example.", "BIBREF9 hardy2002multi followed the DAMSL schemeBIBREF10 and annotated a multilingual human-computer dialog corpus with a hierarchical dialog act annotation scheme. BIBREF11 gupta2018semantic used a hierarchical annotation scheme for semantic parsing. Inspired by these studies, our idea is to annotate the intent and semantic slot separately in non-collaborative tasks. We propose a hierarchical intent annotation scheme that can be adopted by all non-collaborative tasks. With this annotation scheme, MISSA is able to quickly build an end-to-end trainable dialog system for any non-collaborative task.\n\nTraditional task-oriented dialog systems BIBREF12 are usually composed of multiple independent modules, for example, natural language understanding, dialog state tracking BIBREF13, BIBREF14, dialog policy manager BIBREF15, and natural language generation BIBREF16. Conversational intent is adopted to capture the meaning of task content in these dialog systems BIBREF2, BIBREF17. In comparison to this work, we use a hierarchical intent scheme that includes off-task and on-task intents to capture utterance meaning. 
We also train the model in a multi-task fashion to predict decoupled intents and semantic slots. The major defect of a separately trained pipeline is the laborious dialog state design and annotation. In order to mitigate this problem, recent work has explored replacing independent modules with end-to-end neural networks BIBREF18, BIBREF19, BIBREF20. Our model also follows this end-to-end fashion.\n\nOver the last few years, we have witnessed a huge growth in non-task-oriented dialog systems BIBREF21, BIBREF22. Social chatbots such as Gunrock BIBREF23 were able to maintain a conversation for around ten minutes in an open domain. Recent improvements build on top of the transformer and pre-trained language models BIBREF24, BIBREF25, BIBREF26, obtained state-of-the-art results on the Persona-Chat dataset BIBREF0. Pre-trained language models are proposed to build task-oriented dialog systems to drive the progress on leveraging large amounts of available unannotated data. BIBREF27. Similarly, our approach is also built on top of the TransferTransfo framework BIBREF0. BIBREF27 budzianowski2019hello focused on collaborative tasks BIBREF28.", "All these intents are related to donation actions, which are salient on-task intents in the persuasion task. The off-task intents are the same for both tasks, including six general intents and six additional social intents. General intents are more closely related to the syntactic meaning of the sentence (open_question, yes_no_question, positive_answer, negative_answer, responsive_statement, and nonresponsive_statement) while social intents are common social actions (greeting, closing, apology, thanking,respond_to_thank, and hold).\n\nFor specific tasks, we also design a semantic slot annotation scheme for annotating sentences based on their semantic content. We identify 13 main semantic slots in the anti-scam task, for example, credit card numbers. We present a detailed semantic slot annotation in Table TABREF3. Following BIBREF1, we segment each conversation turn into single sentences and then annotate each sentence rather than turns.\n\nWe test our approach on two non-collaborative task datasets: the AntiScam dataset and the PersuasionForGood dataset BIBREF1. Both datasets are collected from the Amazon Mechanical Turk platform in the form of typing conversations and off-task dialog is interleaved in the dialog.\n\nTo enrich available non-collaborative task datasets, we created a corpus of human-human anti-scam dialogs in order to learn human elicitation strategies. We chose a popular Amazon customer service scam scenario to collect dialogs between users and attackers who aim to collect users information. We posted a role-playing task on the Amazon Mechanical Turk platform and collected a typing conversation dataset named AntiScam. We collected 220 human-human dialogs. The average conversation length is 12.45 turns and the average utterance length is 11.13 words. Only 172 out of 220 users successfully identified their partner as an attacker, suggesting that the attackers are well trained and not too easily identifiable. We recruited two expert annotators who have linguistic training to annotate 3,044 sentences in 100 dialogs, achieving a 0.874 averaged weighted kappa value.\n\nThe PersuasionForGood dataset BIBREF1 was collected from typing conversations on Amazon Mechanical Turk platform. Two workers were randomly paired, one was assigned the role of persuader, the other was persuadee. 
The goal of the persuader was to persuade the persuadee to donate a portion of task earning to a specific charity."]} 3 | {"question_id": "f7662b11e87c1e051e13799413f3db459ac3e19c", "predicted_answer": "The ANTISCAM dataset consists of 220 human-human dialogs collected from a typing conversation task on the Amazon Mechanical Turk platform.", "predicted_evidence": ["MISSA is based on the generative pre-trained transformer BIBREF32. We use an Adam optimizer with a learning rate of 6.25e-5 and $L2$ weight decay of $0.01$, we set the coefficient of language modeling loss to be 2, the coefficient of intent and slot classifiers to be 1, and the coefficient of next-utterance classifier to be 1. We first pre-train the model on the PERSONA-CHAT dataset. When fine-tuning on the AntiScam and the PersuasionForGood datasets, we use $80\\%$ data for training, $10\\%$ data for validation, and $10\\%$ data for testing. Since the original PersuasionForGood dataset is annotated with intents, we separate the original on-task and off-task intents, which are shown in Table TABREF2. To deal with the words out of the vocabulary, we conduct delexicalization to replace slot values with corresponding slot tokens during the training phase, and replace the slot tokens with pre-defined information during testing.\n\nAn example of human-human chat on AntiScam dataset is shown in Table TABREF25.", "The results are shown in Table TABREF19. We find that MISSA has higher fluency score and coherence score than MISSA-con (4.18 vs 3.78 for fluency, and 3.75 vs 3.68 for coherence), which suggests that conditioning on the system intent to generate responses improves the quality of the generated sentences. Compared with MISSA-sel, MISSA achieves better performance on all the metrics. For example, the engagement score for MISSA is 3.69 while MISSA-sel only has 2.87. This is because the response filter removed all the incoherent responses, which makes the attacker more willing to keep chatting. The ablation study shows both the conditional language generation mechanism and the response filter are essential to MISSA's good performance.\n\nWe also apply our method to the PersuasionForGood dataset. As shown in Table TABREF23, MISSA and its variants outperform the TransferTransfo and the hybrid models on all evaluation metrics. Such good performance indicates MISSA can be easily applied to a different non-collaborative task and achieve good performance. Particularly, MISSA achieves the lowest perplexity, which confirms that using conditional response generation leads to high quality responses. Compared with the result on AntiScam dataset, MISSA-con performs the best in terms of RIP and ERIP. We suspect the underlying reason is that there are more possible responses with the same intent in PersuasionForGood than in AntiScam. This also suggests that we should adjust the model structure according to the nature of the dataset.\n\nWe propose a general dialog system pipeline to build non-collaborative dialog systems, including a hierarchical annotation scheme and an end-to-end neural response generation model called MISSA. With the hierarchical annotation scheme, we can distinguish on-task and off-task intents. MISSA takes both on and off-task intents as supervision in its training and thus can deal with diverse user utterances in non-collaborative settings. Moreover, to validate MISSA's performance, we create a non-collaborate dialog dataset that focuses on deterring phone scammers. 
MISSA outperforms all baseline methods in terms of fluency, coherency, and user engagement on both the newly proposed anti-scam task and an existing persuasion task.", "Furthermore, in the next turn, TransferTransfo ignored the context and asked an irrelevant question \u201cwhat is your name?\u201d while MISSA was able to generate the response \u201cwhy can't you use my address?\u201d, which is consistent to the context. We suspect the improved performance of MISSA comes from our proposed annotation scheme: the semantic slot information enables MISSA to keep track of the current entities, and the intent information helps MISSA to maintain coherency and prolong conversations.\n\nCompared to the hybrid model baseline, MISSA performs better on off-task content. As shown in the bottom two dialogs in Table TABREF21, attackers in both dialogs introduced their names in their first utterances. MISSA recognized attacker's name, while the hybrid model did not. We suspect it is because the hybrid model does not have the built-in semantic slot predictor. In the second turn, both attackers were explaining the reason of requesting the billing address previously. With semantic slot information, MISSA can easily understand the attacker; but the hybrid model misunderstands that the attacker was talking about the order number, possibly because the token \u201corder\u201d appeared in the attacker's utterance. We suspect that the hybrid model's bad performance on the off-task content leads to its low coherence rating (2.76) and short dialog length (8.2).\n\nTo explore the influence of the intent-based conditional response generation method and the designed response filter, we perform an ablation study. The results are shown in Table TABREF19. We find that MISSA has higher fluency score and coherence score than MISSA-con (4.18 vs 3.78 for fluency, and 3.75 vs 3.68 for coherence), which suggests that conditioning on the system intent to generate responses improves the quality of the generated sentences. Compared with MISSA-sel, MISSA achieves better performance on all the metrics. For example, the engagement score for MISSA is 3.69 while MISSA-sel only has 2.87. This is because the response filter removed all the incoherent responses, which makes the attacker more willing to keep chatting. The ablation study shows both the conditional language generation mechanism and the response filter are essential to MISSA's good performance.\n\nWe also apply our method to the PersuasionForGood dataset.", "MISSA follows the TransferTransfo framework BIBREF0 with three modifications: (i) We first concurrently predict user's, system's intents and semantic slots; (ii) We then perform conditional generation to improve generated response's coherence. Specifically, we generate responses conditioned on the above intermediate representation (intents and slots); (iii) Finally, we generate multiple responses with the nucleus sampling strategy BIBREF5 and then apply a response filter, which contains a set of pre-defined constraints to select coherent responses. The constraints in the filter can be defined according to specific task requirements or general conversational rules.\n\nTo enrich publicly available non-collaborative task datasets, we collect a new dataset AntiScam, where users defend themselves against attackers trying to collect personal information. 
As non-collaborative tasks are still relatively new to the study of dialog systems, there are insufficiently many meaningful datasets for evaluation and we hope this provides a valuable example. We evaluate MISSA on the newly collected AntiScam dataset and an existing PersuasionForGood dataset. Both automatic and human evaluations suggest that MISSA outperforms multiple competitive baselines.\n\nIn summary, our contributions include: (i) We design a hierarchical intent annotation scheme and a semantic slot annotation scheme to annotate the non-collaborative dialog dataset, we also propose a carefully-designed AntiScam dataset to facilitate the research of non-collaborative dialog systems. (ii) We propose a model that can be applied to all non-collaborative tasks, outperforming other baselines on two different non-collaborative tasks. (iii) We develop an anti-scam dialog system to occupy attacker's attention and elicit their private information for social good. Furthermore, we also build a persuasion dialog system to persuade people to donate to charities. We release the code and data.\n\nThe interest in non-collaborative tasks has been increasing and there have already been several related datasets. For instance, BIBREF1 wang2019persuasion collected conversations where one participant persuades another to donate to a charity. BIBREF2 he2018decoupling collected negotiation dialogs where buyers and sellers bargain for items for sale on Craigslist. There are many other non-collaborative tasks, such as the turn-taking game BIBREF6, the multi-party game BIBREF7 and item splitting negotiation BIBREF8.", "So we count the dialog length as another metric to evaluate system performance.\n\nTask Success Score (TaskSuc) The other goal of the anti-scam system is to elicit attacker's personal information. We count the average type of information (name, address and phone number) that the system obtained from attackers as the task success score.\n\nTable TABREF19 presents the main experiment results on AntiScam dataset, for both automatic evaluation metrics and human evaluation metrics. The experiment results on PersuasionForGood are shown in Table TABREF23. We observe that MISSA outperforms two baseline models (TransferTransfo and hybrid model) on almost all the metrics on both datasets. For further analysis, examples of real dialogs from the human evaluation are presented in Table TABREF21.\n\nCompared to the first TransferTransfo baseline, MISSA outperforms the TransferTransfo baseline on the on-task contents. From Table TABREF19, we observe that MISSA maintains longer conversations (14.9 turns) compared with TransferTransfo (8.5 turns), which means MISSA is better at maintaining the attacker's engagement. MISSA also has a higher task success score (1.294) than TransferTransfo (1.025), which indicates that it elicits information more strategically. In the top two dialogs (A and B) that are shown in Table TABREF21, both attackers were eliciting a credit card number in their first turns. TransferTransfo directly gave away the information, while MISSA replied with a semantically-related question \u201cwhy would you need my credit card number?\" Furthermore, in the next turn, TransferTransfo ignored the context and asked an irrelevant question \u201cwhat is your name?\u201d while MISSA was able to generate the response \u201cwhy can't you use my address?\u201d, which is consistent to the context. 
We suspect the improved performance of MISSA comes from our proposed annotation scheme: the semantic slot information enables MISSA to keep track of the current entities, and the intent information helps MISSA to maintain coherency and prolong conversations.\n\nCompared to the hybrid model baseline, MISSA performs better on off-task content. As shown in the bottom two dialogs in Table TABREF21, attackers in both dialogs introduced their names in their first utterances. MISSA recognized attacker's name, while the hybrid model did not. We suspect it is because the hybrid model does not have the built-in semantic slot predictor."]} 4 | {"question_id": "b584739622d0c53830e60430b13fd3ae6ff43669", "predicted_answer": "The ANTISCAM dataset consists of 220 human-human dialogs collected from a role-playing task on the Amazon Mechanical Turk platform.", "predicted_evidence": ["The ultimate holistic evaluation should be conducted by having the trained system interact with human users. Therefore we also conduct human evaluations for the dialog system built on AntiScam. We test our models and baselines with 15 college-student volunteers. Each of them is asked to pretend to be an attacker and interact with all the models for at least three times to avoid randomness. We in total collect 225 number of dialogs. Each time, volunteers are required to use similar sentences and strategies to interact with all five models and score each model based on the metrics listed below at the end of the current round. Each model receives a total of 45 human ratings, and the average score is reported as the final human-evaluation score. In total, we design five different metrics to assess the models' conversational ability whilst interacting with humans. The results are shown in Table TABREF19.\n\nFluency Fluency is used to explore different models' language generation quality.\n\nCoherence Different from single sentence's fluency, coherence focuses more on the logical consistency between sentences in each turn.\n\nEngagement In the anti-scam scenario, one of our missions is to keep engaging with the attackers to waste their time. So we directly ask volunteers (attackers) to what extend they would like to continue chatting with the system.\n\nDialog length (Length) Engagement is a subjective metric. Anti-scam system's goal is to engage user in the conversation longer in order to limit their harm to other potential victims. So we count the dialog length as another metric to evaluate system performance.\n\nTask Success Score (TaskSuc) The other goal of the anti-scam system is to elicit attacker's personal information. We count the average type of information (name, address and phone number) that the system obtained from attackers as the task success score.\n\nTable TABREF19 presents the main experiment results on AntiScam dataset, for both automatic evaluation metrics and human evaluation metrics. The experiment results on PersuasionForGood are shown in Table TABREF23. We observe that MISSA outperforms two baseline models (TransferTransfo and hybrid model) on almost all the metrics on both datasets. For further analysis, examples of real dialogs from the human evaluation are presented in Table TABREF21.\n\nCompared to the first TransferTransfo baseline, MISSA outperforms the TransferTransfo baseline on the on-task contents.", "So we count the dialog length as another metric to evaluate system performance.\n\nTask Success Score (TaskSuc) The other goal of the anti-scam system is to elicit attacker's personal information. 
We count the average type of information (name, address and phone number) that the system obtained from attackers as the task success score.\n\nTable TABREF19 presents the main experiment results on AntiScam dataset, for both automatic evaluation metrics and human evaluation metrics. The experiment results on PersuasionForGood are shown in Table TABREF23. We observe that MISSA outperforms two baseline models (TransferTransfo and hybrid model) on almost all the metrics on both datasets. For further analysis, examples of real dialogs from the human evaluation are presented in Table TABREF21.\n\nCompared to the first TransferTransfo baseline, MISSA outperforms the TransferTransfo baseline on the on-task contents. From Table TABREF19, we observe that MISSA maintains longer conversations (14.9 turns) compared with TransferTransfo (8.5 turns), which means MISSA is better at maintaining the attacker's engagement. MISSA also has a higher task success score (1.294) than TransferTransfo (1.025), which indicates that it elicits information more strategically. In the top two dialogs (A and B) that are shown in Table TABREF21, both attackers were eliciting a credit card number in their first turns. TransferTransfo directly gave away the information, while MISSA replied with a semantically-related question \u201cwhy would you need my credit card number?\" Furthermore, in the next turn, TransferTransfo ignored the context and asked an irrelevant question \u201cwhat is your name?\u201d while MISSA was able to generate the response \u201cwhy can't you use my address?\u201d, which is consistent to the context. We suspect the improved performance of MISSA comes from our proposed annotation scheme: the semantic slot information enables MISSA to keep track of the current entities, and the intent information helps MISSA to maintain coherency and prolong conversations.\n\nCompared to the hybrid model baseline, MISSA performs better on off-task content. As shown in the bottom two dialogs in Table TABREF21, attackers in both dialogs introduced their names in their first utterances. MISSA recognized attacker's name, while the hybrid model did not. We suspect it is because the hybrid model does not have the built-in semantic slot predictor.", "$\\lambda _{LM}$, $\\lambda _{I_h}$, $\\lambda _{S_h}$, $\\lambda _{I_s}$, $\\lambda _{S_s}$, and $\\lambda _{nup}$ are the hyper-parameters that control the relative importance of every loss.\n\nMISSA can generate multiple sentences in a single system turn. Therefore, we perform system generation conditioned on predicted system intents. More specifically, during the training phase, in addition to inserting a special $<$sep$>$ token at the end of each sentence, we also insert the intent of the system response as special tokens at the head of each sentence in the system response. For example, in Figure FIGREF6, we insert a $<$pos_ans$>$ token at the head of $S_t^1$, which is the system response in green. We then use a cross entropy loss function to calculate the loss between the predicted token and the ground truth intent token. During the testing phase, the model first generates a special intent token, then after being conditioned on this intent token, the model keeps generating a sentence until it generates a $<$sep$>$ token. 
After that, the model continues to generate another intent token and another sentence until it generates an $<$eos$>$ token.\n\nSince we only perform conditional generation, a type of soft constraint on the predicted intent of system response, the system can still generate samples that violate simple conversation regulations, such as eliciting information that has already been provided. These corner cases may lead to fatal results in high-risk tasks, for example, health care and education. To improve the robustness of MISSA and improve its ability to generalize to more tasks, we add a response filtering module after the generation. With the nucleus sampling strategy BIBREF5, MISSA is able to generate multiple diverse candidate responses with different intents and semantic slots. We then adopt a task-specific response filtering policy to choose the best candidate response as the final output. In our anti-scam scenario, we set up a few simple rules to filter out some unreasonable candidates, for instance, eliciting the repeated information. The filtering module is easily adaptable to different domains or specific requirements, which makes our dialog system more controllable.\n\nWe evaluate MISSA on two non-collaborative task datasets.", "The intent predictor achieves a $84\\%$ accuracy and the semantic slot predictor achieves $77\\%$ on the AntiScam dataset. Then we compare the predicted values with human-annotated ground truth in the dataset to compute the response-intent prediction (RIP) and response-slot prediction (RSP).\n\nExtended Response-Intent Prediction (ERIP) $\\&$ Extended Response-Slot Prediction (ERSP) With Response-Intent Prediction, we verify the predicted intents to evaluate the coherence of the dialog. However, the real mapping between human-intent and system-intent is much more complicated as there might be multiple acceptable system-intents for the same human-intent. Therefore, we also design a metric to evaluate if the predicted system-intent is in the set of acceptable intents. Specifically, we estimate the transition probability $p(I_i|I_j)$ by counting the frequency of all the bi-gram human-intent and system-intent pairs in the training data. During the test stage, if the predicted intent matches the ground truth, we set the score as 1, otherwise we set the score as $p(I_{predict}|I_i)$ where $I_i$ is the intent of the input human utterance. We then report the average value of those scores over turns as the final extended response-intent prediction result.\n\nAutomatic metrics only validate the system\u2019s performance on a single dimension at a time. The ultimate holistic evaluation should be conducted by having the trained system interact with human users. Therefore we also conduct human evaluations for the dialog system built on AntiScam. We test our models and baselines with 15 college-student volunteers. Each of them is asked to pretend to be an attacker and interact with all the models for at least three times to avoid randomness. We in total collect 225 number of dialogs. Each time, volunteers are required to use similar sentences and strategies to interact with all five models and score each model based on the metrics listed below at the end of the current round. Each model receives a total of 45 human ratings, and the average score is reported as the final human-evaluation score. In total, we design five different metrics to assess the models' conversational ability whilst interacting with humans. 
The results are shown in Table TABREF19.\n\nFluency Fluency is used to explore different models' language generation quality.", "We follow the original TransferTransfo design BIBREF0 and train with undelexicalized data.\n\nHybrid Following BIBREF4 yu2017learning, we also build a hybrid dialog system by combining vanilla TransferTransfo and MISSA. Specifically, we first determine if the human utterances are on-task or off-task with human intent classifier. If the classifier decides that the utterance is on-task, we choose the response from MISSA; otherwise, we choose the response from vanilla TransferTransfo baseline.\n\nIn addition, we perform ablation studies on MISSA to show the effects of different components.\n\nMISSA-sel denotes MISSA without response filtering.\n\nMISSA-con denotes MISSA leaving out the intent token at the start of the response generation.\n\nPerplexity Since the canonical measure of a good language model is perplexity, which indicates the error rate of the expected word. We choose perplexity to evaluate the model performance.\n\nResponse-Intent Prediction (RIP) $\\&$ Response-Slot Prediction (RSP) Different from open-domain dialog systems, we care about the intents of the system response in non-collaborative tasks as we hope to know if the system response satisfies user intents. For example, in the anti-scam task, if the attacker elicits information from the system, we need to know if the system refuses or agrees to provide the information. Therefore we care about intent prediction for the generated system response. Since our baselines are more suited for social chat as they cannot produce system intents, we use the system intent and slot classifiers trained in our model to predict their responses' intents and slots. The intent predictor achieves a $84\\%$ accuracy and the semantic slot predictor achieves $77\\%$ on the AntiScam dataset. Then we compare the predicted values with human-annotated ground truth in the dataset to compute the response-intent prediction (RIP) and response-slot prediction (RSP).\n\nExtended Response-Intent Prediction (ERIP) $\\&$ Extended Response-Slot Prediction (ERSP) With Response-Intent Prediction, we verify the predicted intents to evaluate the coherence of the dialog. However, the real mapping between human-intent and system-intent is much more complicated as there might be multiple acceptable system-intents for the same human-intent. 
Therefore, we also design a metric to evaluate if the predicted system-intent is in the set of acceptable intents."]} 5 | -------------------------------------------------------------------------------- /rag-foundation/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/rag-foundation/scripts/__init__.py -------------------------------------------------------------------------------- /rag-foundation/scripts/main.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import fire 5 | from llama_index.core import Document 6 | from llama_index.core.node_parser import SentenceSplitter 7 | from vector_store.node import TextNode, VectorStoreQueryResult 8 | from vector_store.semantic_vector_store import SemanticVectorStore 9 | from vector_store.sparse_vector_store import SparseVectorStore 10 | 11 | 12 | def prepare_data_nodes(documents: list, chunk_size: int = 200) -> list[TextNode]: 13 | """ 14 | Args: 15 | documents: List of documents. 16 | chunk_size: Chunk size for splitting the documents. 17 | Returns: 18 | text_node: List of TextNode objects. 19 | """ 20 | # Load data 21 | documents = [Document(text=t) for t in documents] 22 | 23 | # Split the documents into nodes 24 | node_parser = SentenceSplitter(chunk_size=chunk_size) 25 | 26 | # Get the nodes from the documents 27 | nodes = node_parser.get_nodes_from_documents(documents) 28 | 29 | # Prepare the nodes for the vector store 30 | text_node = [ 31 | TextNode(id_=str(id_), text=node.text, metadata=node.metadata) 32 | for id_, node in enumerate(nodes) 33 | ] 34 | return text_node 35 | 36 | 37 | def prepare_vector_store(documents: list, mode: str, force_index=False, chunk_size=200): 38 | """ 39 | Prepare the vector store with the given documents. 40 | Args: 41 | documents: List of documents to be indexed. 42 | mode: Mode of the vector store. Choose either `sparse` or `semantic`. 43 | force_index: Whether to force indexing the documents. 44 | chunk_size: Chunk size for splitting the documents. 45 | Returns: 46 | vector_store: Vector store object. 47 | """ 48 | if mode == "sparse": 49 | vector_store = SparseVectorStore( 50 | persist=True, 51 | saved_file="data/sparse.csv", 52 | metadata_file="data/sparse_metadata.json", 53 | force_index=force_index, 54 | ) 55 | elif mode == "semantic": 56 | vector_store = SemanticVectorStore( 57 | persist=True, 58 | saved_file="data/dense.csv", 59 | force_index=force_index, 60 | ) 61 | else: 62 | raise ValueError("Invalid mode. 
Choose either `sparse` or `semantic`.") 63 | 64 | if force_index: 65 | nodes = prepare_data_nodes(documents=documents, chunk_size=chunk_size) 66 | vector_store.add(nodes) 67 | 68 | return vector_store 69 | 70 | 71 | class RAGPipeline: 72 | def __init__(self, vector_store: SemanticVectorStore, prompt_template: str): 73 | self.vector_store = vector_store 74 | self.prompt_template = prompt_template 75 | 76 | # choose your model from groq or openai/azure 77 | self.model = None 78 | 79 | # GROQ 80 | # from langchain_groq import ChatGroq 81 | # self.model = ChatGroq(model="llama3-70b-8192", temperature=0) 82 | 83 | # OpenAI 84 | # from langchain_openai import ChatOpenAI 85 | # self.model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0) 86 | 87 | def retrieve(self, query: str, top_k: int = 5) -> VectorStoreQueryResult: 88 | query_result = self.vector_store.query(query, top_k=top_k) 89 | return query_result 90 | 91 | def answer(self, query: str, top_k: int = 5) -> tuple[str, list[str]]: 92 | # Retrieve the top-k context, fill the prompt template, and call the LLM 93 | result = self.retrieve(query, top_k=top_k) 94 | context_list = [node.text for node in result.nodes] 95 | context = "\n\n".join(context_list) 96 | 97 | prompt = self.prompt_template.format( 98 | query, context 99 | ) 100 | 101 | if not self.model: 102 | raise ValueError("Model not found. Please initialize the model first.") 103 | try: 104 | response = self.model.invoke(prompt) 105 | except Exception as e: 106 | raise Exception(f"Error in calling the model: {e}") 107 | return response.content, context_list 108 | 109 | 110 | def main( 111 | data_path: Path = Path("data/qasper-test-v0.3.json"), 112 | output_path: Path = Path("predictions.jsonl"), 113 | mode: str = "sparse", 114 | force_index: bool = False, 115 | print_context: bool = False, 116 | chunk_size: int = 200, 117 | top_k: int = 5, 118 | retrieval_only: bool = False, 119 | ): 120 | # Entry point exposed through `fire.Fire` (see `__main__` below). 121 | """ 122 | Args: 123 | data_path: Path to the qasper data file. 124 | output_path: Path to save the predictions. 125 | mode: Mode of the vector store. Choose either `sparse` or `semantic`. 126 | force_index: Whether to force indexing the documents. 127 | print_context: Whether to print the context. 128 | chunk_size: Chunk size for splitting the documents. 129 | top_k: Number of top k documents to retrieve. 130 | retrieval_only: Whether to retrieve only.
131 | Returns: 132 | None 133 | """ 134 | # Load the data 135 | raw_data = json.load(open(data_path, "r", encoding="utf-8")) 136 | 137 | question_ids, predicted_answers, predicted_evidences = [], [], [] 138 | 139 | # NOTE: qasper has many papers, each paper has multiple sections 140 | # we will loop through each paper, gather the full text of each section 141 | # and prepare the documents for the vector store 142 | # and answer the query 143 | for _, values in raw_data.items(): 144 | # for each paper in qasper 145 | documents = [] 146 | 147 | for section in values["full_text"]: 148 | # for each section in the paper 149 | documents += section["paragraphs"] 150 | 151 | # initialize the vector store 152 | # and rag pipeline 153 | # Remember to force_index=True if you want to override the existing index 154 | vector_store = prepare_vector_store( 155 | documents, mode=mode, force_index=force_index, chunk_size=chunk_size 156 | ) 157 | 158 | # NOTE: Should design your own template 159 | prompt_template = """Question: {}\n\nGiven context: {}\n\nAnswer:""" 160 | 161 | rag_pipeline = RAGPipeline(vector_store, prompt_template=prompt_template) 162 | 163 | for q in values["qas"]: 164 | # for each question in the paper 165 | query = q["question"] 166 | question_ids.append(q["question_id"]) 167 | 168 | # NOTE: If you just want to retrieve the top_k relevant documents 169 | # set retrieval_only=True 170 | # Otherwise, it will answer the question 171 | if retrieval_only: 172 | result = rag_pipeline.retrieve(query, top_k=top_k) 173 | context_list = [node.text for node in result.nodes] 174 | 175 | if print_context: 176 | for i, context in enumerate(context_list): 177 | print(f"Relevent context {i + 1}:", context) 178 | print("\n\n") 179 | 180 | predicted_evidences.append(context_list) 181 | predicted_answers.append("") 182 | 183 | else: 184 | predicted_answer, context_list = rag_pipeline.answer(query, top_k=top_k) 185 | 186 | # Just In Case. Print out the context list for each question 187 | # if needed. 
188 | if print_context: 189 | for i, context in enumerate(context_list): 190 | print(f"Relevent context {i + 1}:", context) 191 | print("\n\n") 192 | 193 | print("LLM Answer") 194 | print(predicted_answer) 195 | 196 | predicted_evidences.append(context_list) 197 | predicted_answers.append(predicted_answer) 198 | 199 | # save the results 200 | with open(output_path, "w") as f: 201 | for question_id, predicted_answer, predicted_evidence in zip( 202 | question_ids, predicted_answers, predicted_evidences 203 | ): 204 | f.write( 205 | json.dumps( 206 | { 207 | "question_id": question_id, 208 | "predicted_answer": predicted_answer, 209 | "predicted_evidence": predicted_evidence, 210 | } 211 | ) 212 | ) 213 | f.write("\n") 214 | 215 | 216 | if __name__ == "__main__": 217 | fire.Fire(main) 218 | -------------------------------------------------------------------------------- /rag-foundation/setup.cfg: -------------------------------------------------------------------------------- 1 | # Project-wide configuration file, can be used for package metadata and other tool configurations 2 | # Example usage: global configuration for PEP8 (via flake8) setting or default pytest arguments 3 | # Local usage: pip install pre-commit, pre-commit run --all-files 4 | 5 | [isort] 6 | # https://pycqa.github.io/isort/docs/configuration/options.html 7 | line_length = 8 | # see: https://pycqa.github.io/isort/docs/configuration/multi_line_output_modes.html 9 | multi_line_output = 0 10 | include_trailing_comma = True 11 | 12 | [black] 13 | line_length = 120 14 | 15 | [flake8] 16 | # https://flake8.pycqa.org/en/latest/user/options.html 17 | max-line-length = 120 18 | verbose = 2 19 | format = pylint 20 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes 21 | # see: https://www.flake8rules.com/ 22 | select = B, C, E, F, W, T4, B9 23 | ignore = C101, C407, C408, E203, E402, E731, W503 24 | # C101: Coding magic comment not found 25 | # C407: Unnecessary comprehension - can take a generator 26 | # C408: Unnecessary call - rewrite as a literal 27 | # E203 Whitespace before ':' 28 | # E402: module level import not at top of file 29 | # E731: Do not assign a lambda expression, use a def 30 | # W503 Line break occurred before a binary operator 31 | per-file-ignores = 32 | **/__init__.py: F401, F403, F405 33 | # F401: module imported but unused 34 | # F403: ‘from module import *’ used; unable to detect undefined names 35 | # F405: Name may be undefined, or defined from star imports: module 36 | # E501: ignore line length in constants file 37 | -------------------------------------------------------------------------------- /rag-foundation/vector_store/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/rag-foundation/vector_store/__init__.py -------------------------------------------------------------------------------- /rag-foundation/vector_store/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import pandas as pd 5 | from loguru import logger 6 | from pydantic import BaseModel, Field 7 | 8 | from .node import BaseNode, TextNode 9 | 10 | 11 | class BaseVectorStore(BaseModel): 12 | """Simple custom Vector Store. 13 | 14 | Stores documents in a simple in-memory dict. 
15 | """ 16 | 17 | force_index: bool = False 18 | persist: bool = True 19 | node_dict: dict[str, BaseNode] = Field(default_factory=dict) 20 | node_list: list[BaseNode] = Field(default_factory=list) 21 | saved_file: str = "rag-foundation/data/sematic_vectordb_nodes.csv" 22 | csv_file: Path = Path(saved_file) 23 | 24 | class Config: 25 | arbitrary_types_allowed = True 26 | 27 | def __init__(self, **data): 28 | super().__init__(**data) 29 | self.csv_file = Path(self.saved_file) 30 | self._setup_store() 31 | 32 | def _setup_store(self): 33 | if self.persist: 34 | if self.force_index: 35 | self._reset_csv() 36 | self._initialize_csv() 37 | self._load_from_csv() 38 | 39 | def _initialize_csv(self): 40 | """Initialize the CSV file if it doesn't exist.""" 41 | if not self.csv_file.exists(): 42 | logger.warning( 43 | f"Cannot find CSV file at `{self.saved_file}`, creating a new one..." 44 | ) 45 | os.makedirs(self.csv_file.parent, exist_ok=True) 46 | with open(self.csv_file, "w") as f: 47 | f.write("id,text,embedding,metadata\n") 48 | 49 | def _load_from_csv(self): 50 | """Load the node_dict from the CSV file.""" 51 | if self.csv_file.exists(): 52 | df = pd.read_csv(self.csv_file) 53 | for _, row in df.iterrows(): 54 | node_id = row["id"] 55 | text = row["text"] 56 | try: 57 | embedding = eval(row["embedding"]) 58 | metadata = eval(row["metadata"]) 59 | except TypeError: 60 | embedding = None 61 | metadata = None 62 | self.node_dict[node_id] = TextNode( 63 | id_=str(node_id), text=text, embedding=embedding, metadata=metadata 64 | ) 65 | 66 | def _update_csv(self): 67 | """Update the CSV file with the current node_dict if persist is True.""" 68 | if self.persist: 69 | data = {"id": [], "text": [], "embedding": [], "metadata": []} 70 | for key, node in self.node_dict.items(): 71 | data["id"].append(key) 72 | data["text"].append(node.text) 73 | data["embedding"].append(node.embedding) 74 | data["metadata"].append(node.metadata) 75 | df = pd.DataFrame(data) 76 | df.to_csv(self.csv_file, index=False) 77 | else: 78 | logger.warning("`persist` is set to `False`, not updating CSV file.") 79 | 80 | def _reset_csv(self): 81 | """Reset the CSV file by deleting it if it exists.""" 82 | if self.csv_file.exists(): 83 | self.csv_file.unlink() 84 | 85 | def get(self): 86 | """Get embedding.""" 87 | 88 | def add(self): 89 | """Add nodes to index.""" 90 | 91 | def delete(self) -> None: 92 | """Delete nodes using with node_id.""" 93 | 94 | def query(self): 95 | """Get nodes for response.""" 96 | -------------------------------------------------------------------------------- /rag-foundation/vector_store/node.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, List, Optional, Sequence 3 | 4 | from pydantic import BaseModel 5 | 6 | 7 | class BaseNode(BaseModel): 8 | id_: str 9 | embedding: Optional[List[float]] = None 10 | metadata: Optional[Dict[str, Any]] = None 11 | 12 | 13 | class TextNode(BaseNode): 14 | text: str | List[str] 15 | 16 | 17 | @dataclass 18 | class VectorStoreQueryResult: 19 | """Vector store query result.""" 20 | 21 | nodes: Optional[Sequence[BaseNode]] = None 22 | similarities: Optional[List[float]] = None 23 | ids: Optional[List[str]] = None 24 | -------------------------------------------------------------------------------- /rag-foundation/vector_store/semantic_vector_store.py: -------------------------------------------------------------------------------- 1 | # autoflake: off 2 | # flake8: 
noqa: F841 3 | import sys 4 | from typing import Dict, List, cast 5 | 6 | import numpy as np 7 | from loguru import logger 8 | from sentence_transformers import SentenceTransformer 9 | 10 | from .base import BaseVectorStore 11 | from .node import TextNode, VectorStoreQueryResult 12 | 13 | logger.add( 14 | sink=sys.stdout, 15 | colorize=True, 16 | format="{time} {message}", 17 | ) 18 | 19 | 20 | class SemanticVectorStore(BaseVectorStore): 21 | """Semantic Vector Store using SentenceTransformer embeddings.""" 22 | 23 | saved_file: str = "rag-foundation/data/test_db_00.csv" 24 | embed_model_name: str = "all-MiniLM-L6-v2" 25 | embed_model: SentenceTransformer = SentenceTransformer(embed_model_name) 26 | 27 | def __init__(self, **data): 28 | super().__init__(**data) 29 | self._setup_store() 30 | 31 | def get(self, text_id: str) -> TextNode: 32 | """Get node.""" 33 | try: 34 | return self.node_dict[text_id] 35 | except KeyError: 36 | logger.error(f"Node with id `{text_id}` not found.") 37 | return None 38 | 39 | def add(self, nodes: List[TextNode]) -> List[str]: 40 | """Add nodes to index.""" 41 | for node in nodes: 42 | if node.embedding is None: 43 | logger.info( 44 | "Found node without embedding, calculating " 45 | f"embedding with model {self.embed_model_name}" 46 | ) 47 | node.embedding = self._get_text_embedding(node.text) 48 | self.node_dict[node.id_] = node 49 | self._update_csv() # Update CSV after adding nodes 50 | return [node.id_ for node in nodes] 51 | 52 | def _get_text_embedding(self, text: str) -> List[float]: 53 | """Calculate embedding.""" 54 | return self.embed_model.encode(text).tolist() 55 | 56 | def delete(self, node_id: str, **delete_kwargs: Dict) -> None: 57 | """Delete nodes using node_id.""" 58 | if node_id in self.node_dict: 59 | del self.node_dict[node_id] 60 | self._update_csv() # Update CSV after deleting nodes 61 | else: 62 | logger.error(f"Node with id `{node_id}` not found.") 63 | 64 | def _calculate_similarity( 65 | self, 66 | query_embedding: List[float], 67 | doc_embeddings: List[List[float]], 68 | doc_ids: List[str], 69 | similarity_top_k: int = 3, 70 | ) -> tuple[List[float], List[str]]: 71 | """Get top nodes by similarity to the query.""" 72 | qembed_np = np.array(query_embedding) 73 | dembed_np = np.array(doc_embeddings) 74 | 75 | # calculate the dot product of 76 | # the query embedding with the document embeddings 77 | # HINT: np.dot 78 | "Your code here" 79 | dproduct_arr = None 80 | # calculate the cosine similarity 81 | # by dividing the dot product by the norm 82 | # HINT: np.linalg.norm 83 | "Your code here" 84 | cos_sim_arr = None 85 | 86 | # get the indices of the top k similarities 87 | "Your code here" 88 | similarities = None 89 | node_ids = None 90 | 91 | return similarities, node_ids 92 | 93 | def query(self, query: str, top_k: int = 3) -> VectorStoreQueryResult: 94 | """Query similar nodes.""" 95 | query_embedding = cast(List[float], self._get_text_embedding(query)) 96 | doc_embeddings = [node.embedding for node in self.node_dict.values()] 97 | doc_ids = list(self.node_dict.keys()) 98 | if len(doc_embeddings) == 0: 99 | logger.error("No documents found in the index.") 100 | result_nodes, similarities, node_ids = [], [], [] 101 | else: 102 | similarities, node_ids = self._calculate_similarity( 103 | query_embedding, doc_embeddings, doc_ids, top_k 104 | ) 105 | result_nodes = [self.node_dict[node_id] for node_id in node_ids] 106 | return VectorStoreQueryResult( 107 | nodes=result_nodes, similarities=similarities, ids=node_ids 108 | ) 109 | 
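# A small, self-contained sketch (the helper name `_cosine_top_k_sketch` is
# illustrative, not part of this module) of the cosine-similarity retrieval
# that `_calculate_similarity` above leaves as an exercise. It only assumes
# `query_embedding` is a flat list of floats and `doc_embeddings` is a
# non-empty list of equal-length vectors; the 1e-10 term is an added guard
# against zero-norm embeddings.
def _cosine_top_k_sketch(query_embedding, doc_embeddings, doc_ids, similarity_top_k=3):
    qembed_np = np.array(query_embedding)
    dembed_np = np.array(doc_embeddings)
    # dot product of the query with every document embedding -> shape (n_docs,)
    dproduct_arr = np.dot(dembed_np, qembed_np)
    # cosine similarity = dot product / (||doc|| * ||query||)
    cos_sim_arr = dproduct_arr / (
        np.linalg.norm(dembed_np, axis=1) * np.linalg.norm(qembed_np) + 1e-10
    )
    # indices of the top-k similarities, highest score first
    top_idx = np.argsort(cos_sim_arr)[::-1][:similarity_top_k]
    return cos_sim_arr[top_idx].tolist(), [doc_ids[int(i)] for i in top_idx]
# e.g. _cosine_top_k_sketch([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], ["0", "1"], 1)
# returns approximately ([1.0], ["0"])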
110 | def batch_query( 111 | self, query: List[str], top_k: int = 3 112 | ) -> List[VectorStoreQueryResult]: 113 | """Batch query similar nodes.""" 114 | return [self.query(q, top_k) for q in query] 115 | -------------------------------------------------------------------------------- /rag-foundation/vector_store/sparse_vector_store.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa: F841 2 | import json 3 | import sys 4 | from multiprocessing import Pool, cpu_count 5 | from pathlib import Path 6 | from typing import ClassVar, Dict, List 7 | 8 | import numpy as np 9 | from loguru import logger 10 | from pydantic import Field 11 | from transformers import AutoTokenizer 12 | 13 | from .base import BaseVectorStore 14 | from .node import TextNode, VectorStoreQueryResult 15 | 16 | logger.add( 17 | sink=sys.stdout, 18 | colorize=True, 19 | format="{time} {message}", 20 | ) 21 | 22 | TOKENIZER = AutoTokenizer.from_pretrained( 23 | "google-bert/bert-base-uncased", max_length=200, truncation=True 24 | ) 25 | 26 | 27 | class SparseVectorStore(BaseVectorStore): 28 | """VectorStore2 (add/get/delete implemented).""" 29 | 30 | saved_file: str = "rag-foundation/data/test_db_10.csv" 31 | metadata_file: Path = Path("rag-foundation/data/sparse_metadata_tmp.json") 32 | tokenizer: ClassVar[AutoTokenizer] = TOKENIZER 33 | corpus_size: int = Field(default=0, init=False) 34 | avgdl: float = Field(default=0.0, init=False) 35 | doc_freqs: List[Dict[str, int]] = Field(default_factory=list, init=False) 36 | idf: Dict[str, float] = Field(default_factory=dict, init=False) 37 | doc_len: List[int] = Field(default_factory=list, init=False) 38 | nd: int = Field(default=0, init=False) 39 | 40 | # Algorithm specific parameters 41 | k1: float = Field(default=1.2) 42 | b: float = Field(default=0.75) 43 | delta: float = Field(default=0.25) 44 | 45 | def __init__(self, **data): 46 | super().__init__(**data) 47 | if len(self.node_dict) > 0: 48 | self.metadata_file = Path(self.metadata_file) 49 | if self.metadata_file.exists() and not self.force_index: 50 | self._load_from_json() 51 | else: 52 | self._initialize_bm25_assets() 53 | 54 | self.node_list = list(self.node_dict.values()) 55 | 56 | def _initialize_bm25_assets(self): 57 | """Initialize BM25 assets from the node dictionary.""" 58 | self.corpus_size = 0 59 | self.avgdl = 0 60 | self.doc_freqs = [] 61 | self.idf = {} 62 | self.doc_len = [] 63 | self.nd = 0 64 | 65 | corpus = self._tokenize_text([node.text for node in self.node_list]) 66 | self._initialize(corpus) 67 | content = { 68 | "corpus_size": self.corpus_size, 69 | "avgdl": self.avgdl, 70 | "doc_freqs": self.doc_freqs, 71 | "idf": self.idf, 72 | "doc_len": self.doc_len, 73 | "nd": self.nd, 74 | } 75 | with open(self.metadata_file, "w") as f: 76 | json.dump(content, f) 77 | 78 | def _load_from_json(self): 79 | with open(self.metadata_file, "r") as f: 80 | content = json.load(f) 81 | self.corpus_size = content["corpus_size"] 82 | self.avgdl = content["avgdl"] 83 | self.doc_freqs = content["doc_freqs"] 84 | self.idf = content["idf"] 85 | self.doc_len = content["doc_len"] 86 | self.nd = content["nd"] 87 | 88 | def _initialize(self, corpus: List[List[str]]): 89 | nd = {} # word -> number of documents with word 90 | num_doc = 0 91 | for document in corpus: 92 | self.doc_len.append(len(document)) 93 | num_doc += len(document) 94 | 95 | frequencies = {} 96 | for word in document: 97 | if word not in frequencies: 98 | frequencies[word] = 0 99 | frequencies[word] += 1 100 | 
self.doc_freqs.append(frequencies) 101 | 102 | for word, freq in frequencies.items(): 103 | try: 104 | nd[word] += 1 105 | except KeyError: 106 | nd[word] = 1 107 | 108 | self.corpus_size += 1 109 | 110 | self.avgdl = num_doc / self.corpus_size 111 | self.idf = { 112 | word: self._calculate_idf(doc_count, self.corpus_size) 113 | for word, doc_count in nd.items() 114 | } 115 | 116 | def _calculate_idf(self, doc_count: int, corpus_size: int) -> float: 117 | # Calculate the inverse document frequency for a word 118 | # HINT: Use the formula provided in the BM25 algorithm and np.log() 119 | "Your code here" 120 | idf_score = None 121 | return idf_score 122 | 123 | def _tokenize_text(self, corpus: List[str] | str): 124 | if isinstance(corpus, str): 125 | return self.tokenizer.tokenize(corpus) 126 | else: 127 | pool = Pool(cpu_count()) 128 | tokenized_corpus = pool.map(self.tokenizer.tokenize, corpus) 129 | return tokenized_corpus 130 | 131 | def add(self, nodes: List[TextNode]) -> List[str]: 132 | """Add nodes to index.""" 133 | for node in nodes: 134 | self.node_dict[node.id_] = node 135 | self._update_csv() # Update CSV after adding nodes 136 | 137 | # Reinitialize BM25 assets after adding new nodes 138 | self._initialize_bm25_assets() 139 | 140 | return [node.id_ for node in nodes] 141 | 142 | def get(self, text_id: str) -> TextNode: 143 | """Get node.""" 144 | try: 145 | return self.node_dict[text_id] 146 | except KeyError: 147 | logger.error(f"Node with id `{text_id}` not found.") 148 | return None 149 | 150 | def get_scores(self, query: str): 151 | score = np.zeros(self.corpus_size) 152 | tokenized_query = self._tokenize_text(query) 153 | for q in tokenized_query: 154 | # calculate the score for each token in the query 155 | # HINT: use self.doc_freqs, self.idf, self.corpus_size, self.avgdl 156 | "Your code here" 157 | cur_score = None 158 | score += cur_score 159 | return score 160 | 161 | def query(self, query: str, top_k: int = 3) -> VectorStoreQueryResult: 162 | """Query similar nodes. 163 | 164 | Args: 165 | query (str): _description_ 166 | top_k (int, optional): _description_. Defaults to 3. 167 | 168 | Returns: 169 | VectorStoreQueryResult: _description_ 170 | """ 171 | scores = self.get_scores(query) 172 | best_ids = np.argsort(scores)[::-1][:top_k] 173 | nodes = [self.node_list[node_id] for node_id in best_ids] 174 | return VectorStoreQueryResult( 175 | nodes=nodes, 176 | similarities=[scores[doc_id] for doc_id in best_ids], 177 | ids=[node.id_ for node in nodes], 178 | ) 179 | 180 | def batch_query( 181 | self, query: List[str], top_k: int = 3 182 | ) -> List[VectorStoreQueryResult]: 183 | """Batch query similar nodes. 184 | 185 | Args: 186 | query (List[str]): _description_ 187 | top_k (int, optional): _description_. Defaults to 3.
188 | 189 | Returns: 190 | List[VectorStoreQueryResult]: _description_ 191 | """ 192 | return [self.query(q, top_k) for q in query] 193 | -------------------------------------------------------------------------------- /streamlit_demo/.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | __pycache__/ 3 | .idea/ 4 | app_data/ 5 | -------------------------------------------------------------------------------- /streamlit_demo/README.md: -------------------------------------------------------------------------------- 1 | # Streamlit for Object Detection 2 | 3 | --- 4 | 5 | ## Quick Usage 6 | 7 | Install requirements 8 | 9 | ```bash 10 | conda create --name streamlit_demo python=3.11 11 | conda activate streamlit_demo 12 | 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | Run app 17 | 18 | ```bash 19 | python launch.py 20 | ``` 21 | 22 | ## Screenshots 23 | 24 | ![Screen_shot](./assets/screenshot_app.png) -------------------------------------------------------------------------------- /streamlit_demo/assets/screenshot_app.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/streamlit_demo/assets/screenshot_app.png -------------------------------------------------------------------------------- /streamlit_demo/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | APP_DATA_DIR = Path(__file__).parent / "app_data" 5 | os.makedirs(APP_DATA_DIR, exist_ok=True) 6 | 7 | FEEDBACK_DIR = APP_DATA_DIR / "feedback" 8 | os.makedirs(FEEDBACK_DIR, exist_ok=True) 9 | 10 | FEEDBACK_SQL_PATH = f"sqlite:///{FEEDBACK_DIR / 'feedback.sql'}" 11 | 12 | YOLO_OPTIONS = [ 13 | "yolov8s.pt", 14 | "yolov8n.pt" 15 | ] 16 | 17 | YOLO_SUPPORTED_EXTENSIONS = ["jpg", "png", "jpeg"] 18 | 19 | USER_DATA_DIR = APP_DATA_DIR / "user_data" / "images" 20 | os.makedirs(USER_DATA_DIR, exist_ok=True) 21 | 22 | AI_MODEL_CONFIGS = { 23 | "yolov8": { 24 | "model_name": "yolov8s.pt", 25 | "device": "cuda" 26 | } 27 | } 28 | AI_MODEL = "yolov8" 29 | 30 | CLASSES = ['Person', 'Bicycle', 'Car', 'Motorcycle', 'Airplane', 'Bus', 'Train', 'Truck', 'Boat', 'Traffic light', 31 | 'Fire hydrant', 'Stop sign', 'Parking meter', 'Bench', 'Bird', 'Cat', 'Dog', 'Horse', 'Sheep', 'Cow', 32 | 'Elephant', 'Bear', 'Zebra', 'Giraffe', 'Backpack', 'Umbrella', 'Handbag', 'Tie', 'Suitcase', 'Frisbee', 33 | 'Skis', 'Snowboard', 'Sports ball', 'Kite', 'Baseball bat', 'Baseball glove', 'Skateboard', 'Surfboard', 34 | 'Tennis racket', 'Bottle', 'Wine glass', 'Cup', 'Fork', 'Knife', 'Spoon', 'Bowl', 'Banana', 'Apple', 35 | 'Sandwich', 'Orange', 'Broccoli', 'Carrot', 'Hot dog', 'Pizza', 'Donut', 'Cake', 'Chair', 'Couch', 36 | 'Potted plant', 'Bed', 'Dining table', 'Toilet', 'Tv', 'Laptop', 'Mouse', 'Remote', 'Keyboard', 'Cell phone', 37 | 'Microwave', 'Oven', 'Toaster', 'Sink', 'Refrigerator', 'Book', 'Clock', 'Vase', 'Scissors', 'Teddy bear', 38 | 'Hair drier', 'Toothbrush'] 39 | -------------------------------------------------------------------------------- /streamlit_demo/launch.py: -------------------------------------------------------------------------------- 1 | from shared.views import App 2 | from shared.utils.log import custom_logger 3 | from shared.utils.pages import set_page_config 4 | 5 | set_page_config() 6 | custom_logger() 7 | 8 | app = App() 9 | app.view(key="app") 10 | 
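# A rough, hypothetical sketch (not the repo's actual wrapper, which lives in
# shared/models_ai/yolov8.py) of how the YOLO settings in constants.py could
# drive a detection with the `ultralytics` package; the sample image path and
# the parameter values mirror the sidebar fields in lessons/layout.py and are
# assumptions, not fixed choices.
from ultralytics import YOLO

from constants import AI_MODEL, AI_MODEL_CONFIGS

cfg = AI_MODEL_CONFIGS[AI_MODEL]        # {"model_name": "yolov8s.pt", "device": "cuda"}
model = YOLO(cfg["model_name"])         # loads (or downloads) the weights on first use
results = model.predict(
    source="assets/screenshot_app.png", # any image path ultralytics accepts
    device=cfg["device"],
    imgsz=640,                          # "Image Size"
    conf=0.2,                           # "Minimum Confidence Score"
    iou=0.5,                            # "Minimum IOU"
)
print(results[0].boxes)                 # detected boxes, class ids and scores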
-------------------------------------------------------------------------------- /streamlit_demo/lessons/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/streamlit_demo/lessons/__init__.py -------------------------------------------------------------------------------- /streamlit_demo/lessons/cache_flow.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | 4 | import streamlit as st 5 | import pandas as pd 6 | 7 | from functools import lru_cache 8 | 9 | 10 | class Model: 11 | def __init__(self, dct): 12 | self.dct = dct 13 | 14 | 15 | def experiment_1(): 16 | st.code(''' 17 | class Model: 18 | def __init__(self, dct): 19 | self.dct = dct 20 | 21 | # @lru_cache(1) 22 | # @st.cache_data 23 | @st.cache_resource 24 | def load_data(dct: dict) -> Model: 25 | print("I will go sleep for 3s") 26 | time.sleep(3) 27 | 28 | model = Model(dct) 29 | return model 30 | 31 | data_dct = { 32 | 'Column1': [1, 2, 3, 4, 5], 33 | 'Column2': ['A', 'B', 'C', 'D', 'E'], 34 | 'Column3': [10.5, 20.5, 30.5, 40.5, 50.5], 35 | 'Column4': [True, False, True, False, True] 36 | } 37 | 38 | model = load_data(data_dct) 39 | st.json(model.dct) 40 | 41 | model.dct = {} 42 | st.button("Rerun") 43 | ''', language='python') 44 | 45 | # @lru_cache(1) 46 | @st.cache_data 47 | # @st.cache_resource 48 | def load_data(dct: dict) -> Model: 49 | print("I will go sleep for 3s") 50 | time.sleep(3) 51 | 52 | model = Model(dct) 53 | return model 54 | 55 | data_dct = { 56 | 'Column1': [1, 2, 3, 4, 5], 57 | 'Column2': ['A', 'B', 'C', 'D', 'E'], 58 | 'Column3': [10.5, 20.5, 30.5, 40.5, 50.5], 59 | 'Column4': [True, False, True, False, True] 60 | } 61 | 62 | model = load_data(data_dct) 63 | st.json(model.dct) 64 | 65 | btn = st.button("Rerun") 66 | if btn: 67 | model.dct = {} 68 | 69 | 70 | experiment_1() 71 | -------------------------------------------------------------------------------- /streamlit_demo/lessons/execution_flow.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import streamlit as st 3 | 4 | try: 5 | st.set_page_config( 6 | page_title="Execution Flow", 7 | page_icon="🤖", 8 | layout="wide", 9 | initial_sidebar_state="expanded" 10 | ) 11 | finally: 12 | pass 13 | 14 | 15 | def experiment_1(): 16 | st.code(''' 17 | st.info(str(datetime.datetime.now())) 18 | 19 | magic_number = st.slider("Magic number", min_value=0., max_value=1., step=0.1) 20 | print(magic_number) 21 | 22 | btn = st.button("Submit") 23 | input_a = None 24 | if btn: 25 | print("Enter btn function", datetime.datetime.now()) 26 | st.toast("Button pressed") 27 | input_a = f"Hello word. Your magic number is: {magic_number}" 28 | 29 | st.info(magic_number) 30 | st.info(input_a) 31 | ''', language="python") 32 | 33 | st.info(str(datetime.datetime.now())) 34 | 35 | magic_number = st.slider("Magic number", min_value=0., max_value=1., step=0.1) 36 | print(magic_number) 37 | 38 | btn = st.button("Submit") 39 | input_a = None 40 | if btn: 41 | print("Enter btn function", datetime.datetime.now()) 42 | st.toast("Button pressed") 43 | input_a = f"Hello word. 
Your magic number is: {magic_number}" 44 | 45 | st.info(magic_number) 46 | st.info(input_a) 47 | 48 | 49 | def experiment_2(): 50 | st.code(''' 51 | st.info(str(datetime.datetime.now())) 52 | 53 | with st.form("form", clear_on_submit=True): 54 | magic_number = st.slider("Magic number", min_value=0., max_value=1., step=0.1) 55 | print(magic_number) 56 | 57 | btn = st.form_submit_button("Submit") 58 | 59 | if btn: 60 | print("Enter btn function", datetime.datetime.now()) 61 | st.info("Hello World") 62 | 63 | st.info(magic_number) 64 | ''', language="python") 65 | st.info(str(datetime.datetime.now())) 66 | 67 | with st.form("form", clear_on_submit=True): 68 | magic_number = st.slider("Magic number", min_value=0., max_value=1., step=0.1) 69 | print(magic_number) 70 | 71 | btn = st.form_submit_button("Submit") 72 | 73 | if btn: 74 | print("Enter btn function", datetime.datetime.now()) 75 | st.info("Hello World") 76 | 77 | st.info(magic_number) 78 | 79 | 80 | def experiment_3(): 81 | 82 | pass 83 | 84 | 85 | cols = st.columns(2) 86 | 87 | with cols[0]: 88 | experiment_1() 89 | 90 | 91 | with cols[1]: 92 | experiment_2() 93 | 94 | -------------------------------------------------------------------------------- /streamlit_demo/lessons/layout.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | 4 | try: 5 | st.set_page_config( 6 | page_title="Execution Flow", 7 | page_icon="🤖", 8 | layout="wide", 9 | initial_sidebar_state="expanded" 10 | ) 11 | finally: 12 | pass 13 | 14 | 15 | def side_bar_view(): 16 | st.header("Model Configurations") 17 | st.info("Check out the documentation " 18 | "at [link](https://docs.ultralytics.com/modes/predict/#inference-sources)") 19 | 20 | key = "sidebar" 21 | with st.form(f"{key}_upload", clear_on_submit=True): 22 | upload_image = st.file_uploader( 23 | "Upload Image(s)", 24 | accept_multiple_files=False, 25 | type=["png", "jpg", "jpeg"], 26 | key=f"{key}_upload_images" 27 | ) 28 | 29 | col1, col2 = st.columns(2) 30 | with col1: 31 | augment = st.radio( 32 | "Augment", 33 | (True, False), 34 | horizontal=True 35 | ) 36 | with col2: 37 | agnostic_nms = st.radio( 38 | "Agnostic NMS", 39 | (True, False), 40 | horizontal=True 41 | ) 42 | image_size = st.number_input( 43 | "Image Size", 44 | value=640, 45 | step=32, 46 | min_value=640, 47 | max_value=1280 48 | ) 49 | min_iou = st.slider( 50 | "Minimum IOU", 51 | min_value=0.0, 52 | max_value=1.0, 53 | value=0.5, 54 | step=0.01 55 | ) 56 | min_confident_score = st.slider( 57 | "Minimum Confidence Score", 58 | min_value=0.0, 59 | max_value=1.0, 60 | value=0.2, 61 | step=0.01 62 | ) 63 | 64 | submit_btn = st.form_submit_button( 65 | label="Upload", 66 | type="primary", 67 | use_container_width=True 68 | ) 69 | 70 | 71 | def col_1_view(): 72 | st.image("m10.jpg") 73 | 74 | 75 | def col_2_view(): 76 | dummy_counting_dct = { 77 | "Person": 1 78 | } 79 | 80 | with st.container(border=True): 81 | st.markdown("**Counting**") 82 | st.json(dummy_counting_dct) 83 | 84 | with st.expander(label="Object Detail", expanded=True): 85 | cls = st.selectbox(label="Class", options=["Person", "Animal"], index=0) 86 | 87 | st.markdown(f"Confident score :red[0.92]") 88 | 89 | 90 | with st.sidebar: 91 | side_bar_view() 92 | 93 | image_col, info_col = st.columns([8, 2]) 94 | 95 | with image_col: 96 | col_1_view() 97 | 98 | with info_col: 99 | col_2_view() 100 | -------------------------------------------------------------------------------- /streamlit_demo/lessons/m10.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/streamlit_demo/lessons/m10.jpg -------------------------------------------------------------------------------- /streamlit_demo/requirements.txt: -------------------------------------------------------------------------------- 1 | loguru==0.7.2 2 | numpy==2.0.1 3 | opencv_python==4.10.0.84 4 | pandas==2.2.2 5 | Pillow==10.4.0 6 | Requests==2.32.3 7 | SQLAlchemy==2.0.31 8 | streamlit==1.36.0 9 | torch==2.0.1 10 | ultralytics==8.2.64 11 | -------------------------------------------------------------------------------- /streamlit_demo/shared/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /streamlit_demo/shared/crud/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/streamlit_demo/shared/crud/__init__.py -------------------------------------------------------------------------------- /streamlit_demo/shared/crud/feedbacks.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | 3 | from shared.models import Feedback 4 | from shared.models.engine import Session 5 | 6 | 7 | class FeedbackCRUD: 8 | def __init__(self, session: Session): 9 | self.session = session 10 | 11 | def create(self, image_path: str, data: dict) -> bool: 12 | existed_feedback = self.get_by_image_path(image_path) 13 | if existed_feedback: 14 | self.delete_by_id(existed_feedback.id) 15 | logger.info(f"Image path: {image_path} exists. 
Deleted") 16 | 17 | feedback = Feedback(image_path=image_path, data=data) 18 | self.session.add(feedback) 19 | self.session.commit() 20 | 21 | logger.info(f"Added 1 row") 22 | return True 23 | 24 | def delete_by_id(self, feedback_id: int) -> bool: 25 | ( 26 | self.session 27 | .query(Feedback) 28 | .filter(Feedback.id == feedback_id) 29 | .delete(synchronize_session=False) 30 | ) 31 | return True 32 | 33 | def get_by_image_path(self, image_path: str) -> Feedback | None: 34 | result = ( 35 | self.session 36 | .query(Feedback) 37 | .filter(Feedback.image_path == image_path) 38 | .first() 39 | ) 40 | 41 | return result 42 | -------------------------------------------------------------------------------- /streamlit_demo/shared/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .engine import ( 2 | engine 3 | ) 4 | from .models import ( 5 | Feedback, Base 6 | ) 7 | 8 | Base.metadata.create_all(engine) 9 | -------------------------------------------------------------------------------- /streamlit_demo/shared/models/engine.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import create_engine 2 | from sqlalchemy.orm import sessionmaker 3 | 4 | from constants import FEEDBACK_SQL_PATH 5 | 6 | 7 | engine = create_engine(FEEDBACK_SQL_PATH) 8 | Session = sessionmaker( 9 | bind=engine, 10 | ) 11 | 12 | -------------------------------------------------------------------------------- /streamlit_demo/shared/models/models.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, JSON, String 2 | from sqlalchemy.ext.declarative import declarative_base 3 | 4 | 5 | Base = declarative_base() 6 | 7 | 8 | class Feedback(Base): 9 | __tablename__ = 'Feedback' 10 | id = Column(Integer, primary_key=True, autoincrement=True) 11 | image_path = Column(String) 12 | data = Column(JSON, default=dict()) 13 | 14 | -------------------------------------------------------------------------------- /streamlit_demo/shared/models_ai/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | import streamlit as st 4 | 5 | from .base import BaseAIModel 6 | from .yolov8 import Yolov8 7 | 8 | 9 | @st.cache_resource 10 | def get_ai_model(name: str, model_params: dict) -> BaseAIModel | None: 11 | factory: dict[str, BaseAIModel] = { 12 | "yolov8": Yolov8(**model_params), 13 | } 14 | 15 | return factory.get(name, None) 16 | -------------------------------------------------------------------------------- /streamlit_demo/shared/models_ai/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | 6 | 7 | class BaseAIModel(ABC): 8 | @abstractmethod 9 | def process(self, image_in: Path | str | np.ndarray, *args, **kwargs) -> Path: 10 | ... 
11 | -------------------------------------------------------------------------------- /streamlit_demo/shared/models_ai/yolov8.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Literal 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from ultralytics.models import YOLO 8 | from ultralytics.engine.results import Results 9 | from loguru import logger 10 | 11 | from .base import BaseAIModel 12 | from shared.schemas import Parameters, ModelOutput 13 | 14 | 15 | class Yolov8(BaseAIModel): 16 | def __init__(self, model_name: str, device: Literal["cpu", "cuda"] = "cuda"): 17 | self._model = YOLO(model_name, task="detect") 18 | 19 | if device in ["cuda"] and torch.cuda.is_available(): 20 | self._device = torch.device(device) 21 | else: 22 | self._device = torch.device("cpu") 23 | 24 | self._model.to(self._device) 25 | 26 | @staticmethod 27 | def get_default() -> dict: 28 | return { 29 | "augment": False, 30 | "agnostic_nms": False, 31 | "imgsz": 640, 32 | "iou": 0.5, 33 | "conf": 0.01, 34 | "verbose": False 35 | } 36 | 37 | def process( 38 | self, 39 | image_in: Path | str | np.ndarray, 40 | *args, 41 | **kwargs, 42 | ) -> Path: 43 | if type(image_in) is [str, Path]: 44 | image_in = cv2.imread(image_in, cv2.IMREAD_COLOR) 45 | 46 | default_params: dict = self.get_default() 47 | if kwargs.get("params", None): 48 | params: Parameters = kwargs["params"] 49 | 50 | # Update 51 | default_params["augment"] = params.augment 52 | default_params["agnostic_nms"] = params.agnostic_nms 53 | default_params["imgsz"] = params.image_size 54 | default_params["iou"] = params.min_iou 55 | default_params["conf"] = params.min_confident_score 56 | 57 | logger.debug(f"Run with config: {default_params}") 58 | 59 | results: Results = self._model(image_in, **default_params) 60 | result = results[0].cpu().numpy() 61 | 62 | model_out_params = { 63 | "xyxysc": result.boxes.data 64 | } 65 | 66 | return ModelOutput(**model_out_params) 67 | -------------------------------------------------------------------------------- /streamlit_demo/shared/schemas.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, asdict 2 | 3 | import numpy as np 4 | 5 | 6 | @dataclass 7 | class Base: 8 | def to_dict(self): 9 | return asdict(self) 10 | 11 | 12 | @dataclass 13 | class Parameters(Base): 14 | augment: bool 15 | agnostic_nms: bool 16 | image_size: int 17 | min_iou: float 18 | min_confident_score: float 19 | 20 | 21 | @dataclass 22 | class ModelInput(Base): 23 | upload_image: str 24 | params: Parameters 25 | 26 | 27 | @dataclass 28 | class ModelOutput(Base): 29 | xyxysc: np.ndarray # x_min, y_min, x_max, y_max, score, class 30 | 31 | def __len__(self): 32 | return len(self.xyxysc) 33 | 34 | def __getitem__(self, item_id: int) -> np.ndarray: 35 | return self.xyxysc[item_id] 36 | 37 | def count(self) -> dict[int, int]: 38 | cls_dict: dict[int, int] = {} 39 | for c in self.xyxysc[:, -1]: 40 | c = int(c) 41 | if c not in cls_dict: 42 | cls_dict[c] = 0 43 | cls_dict[c] += 1 44 | 45 | return cls_dict 46 | 47 | def to_dict(self) -> dict[int, list]: 48 | result_dict: dict[int, list] = {} 49 | for i, elem in enumerate(self.xyxysc): 50 | x_min, y_min, x_max, y_max = map(int, elem[:4]) 51 | score = float(elem[-2]) 52 | cls = int(elem[-1]) 53 | 54 | result_dict[i] = [ 55 | x_min, y_min, x_max, y_max, score, cls 56 | ] 57 | return result_dict 58 | 59 | 60 | @dataclass 61 | class 
EditedOutput(Base): 62 | cls: int 63 | -------------------------------------------------------------------------------- /streamlit_demo/shared/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/streamlit_demo/shared/utils/__init__.py -------------------------------------------------------------------------------- /streamlit_demo/shared/utils/files.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | from loguru import logger 4 | from PIL import Image 5 | 6 | 7 | def save_uploaded_file(file, dir_out: str) -> str: 8 | """Save uploaded file to local""" 9 | pil_image = Image.open(file) 10 | 11 | path_out = os.path.join(dir_out, file.name) 12 | pil_image.save( 13 | path_out 14 | ) 15 | 16 | assert os.path.isfile(path_out) 17 | logger.info(f"Save file at: {path_out}") 18 | 19 | return path_out 20 | -------------------------------------------------------------------------------- /streamlit_demo/shared/utils/log.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from loguru import logger 4 | 5 | 6 | def custom_logger(): 7 | logger.remove() 8 | logger.add( 9 | sys.stderr, 10 | colorize=True, 11 | format="[{time:MM/DD HH:mm:ss}] {level: ^8}| {message}", 12 | ) 13 | -------------------------------------------------------------------------------- /streamlit_demo/shared/utils/pages.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | 4 | def set_page_config(): 5 | try: 6 | st.set_page_config( 7 | page_title="Object Detection", 8 | page_icon="🤖", 9 | layout="wide", 10 | initial_sidebar_state="expanded" 11 | ) 12 | finally: 13 | pass 14 | -------------------------------------------------------------------------------- /streamlit_demo/shared/views/__init__.py: -------------------------------------------------------------------------------- 1 | from .app.view import App -------------------------------------------------------------------------------- /streamlit_demo/shared/views/app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/streamlit_demo/shared/views/app/__init__.py -------------------------------------------------------------------------------- /streamlit_demo/shared/views/app/view.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import streamlit as st 4 | 5 | import constants as c 6 | from shared.crud.feedbacks import FeedbackCRUD 7 | from shared.models.engine import Session 8 | from shared.utils.files import save_uploaded_file 9 | from shared.schemas import ModelInput, ModelOutput, Parameters, EditedOutput 10 | from shared.models_ai import get_ai_model, BaseAIModel 11 | from shared.views.canvas.canvas import st_annotate_tool 12 | 13 | 14 | class BaseView(ABC): 15 | @abstractmethod 16 | def view(self, key: str): 17 | ... 
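
`InfoPanelView` further down builds its counts from `ModelOutput` (defined in `shared/schemas.py` above) together with the `CLASSES` list in `constants.py`. A small worked example with invented values shows the column layout that code relies on:

```python
import numpy as np

from shared.schemas import ModelOutput

out = ModelOutput(xyxysc=np.array([
    [10.0, 20.0, 110.0, 220.0, 0.91,  0.0],   # class 0  -> "Person"
    [30.0, 40.0,  90.0, 160.0, 0.55, 16.0],   # class 16 -> "Dog"
]))

print(len(out))       # 2
print(out.count())    # {0: 1, 16: 1}
print(out.to_dict())  # {0: [10, 20, 110, 220, 0.91, 0], 1: [30, 40, 90, 160, 0.55, 16]}
```
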
18 | 19 | 20 | class UploadView(BaseView): 21 | def view(self, key: str) -> ModelInput | None: 22 | with st.form(f"{key}_upload", clear_on_submit=True): 23 | upload_image = st.file_uploader( 24 | "Upload Image(s)", 25 | accept_multiple_files=False, 26 | type=c.YOLO_SUPPORTED_EXTENSIONS, 27 | key=f"{key}_upload_images" 28 | ) 29 | 30 | col1, col2 = st.columns(2) 31 | with col1: 32 | augment = st.radio( 33 | "Augment", 34 | (True, False), 35 | horizontal=True 36 | ) 37 | with col2: 38 | agnostic_nms = st.radio( 39 | "Agnostic NMS", 40 | (True, False), 41 | horizontal=True 42 | ) 43 | image_size = st.number_input( 44 | "Image Size", 45 | value=640, 46 | step=32, 47 | min_value=640, 48 | max_value=1280 49 | ) 50 | min_iou = st.slider( 51 | "Minimum IOU", 52 | min_value=0.0, 53 | max_value=1.0, 54 | value=0.5, 55 | step=0.01 56 | ) 57 | min_confident_score = st.slider( 58 | "Minimum Confidence Score", 59 | min_value=0.0, 60 | max_value=1.0, 61 | value=0.2, 62 | step=0.01 63 | ) 64 | 65 | submit_btn = st.form_submit_button( 66 | label="Upload", 67 | type="primary", 68 | use_container_width=True 69 | ) 70 | 71 | if submit_btn: 72 | upload_image_path: str = save_uploaded_file( 73 | upload_image, 74 | c.USER_DATA_DIR 75 | ) 76 | 77 | input_params = { 78 | "augment": augment, 79 | "agnostic_nms": agnostic_nms, 80 | "image_size": image_size, 81 | "min_iou": min_iou, 82 | "min_confident_score": min_confident_score 83 | } 84 | 85 | return ModelInput( 86 | upload_image=upload_image_path, 87 | params=Parameters(**input_params) 88 | ) 89 | 90 | return 91 | 92 | 93 | class ImagePanelView(BaseView): 94 | def view(self, key: str, model_output: ModelOutput, image_path: str): 95 | updated_output, selected_index = st_annotate_tool( 96 | regions=model_output, 97 | background_image=image_path, 98 | key=f"{key}_visual", 99 | canvas_height=900, 100 | canvas_width=900 101 | ) 102 | 103 | updated_output: ModelOutput 104 | selected_index: int 105 | 106 | return updated_output, selected_index 107 | 108 | 109 | class InfoPanelView(BaseView): 110 | def view(self, key: str, model_output: ModelOutput, selected_index: int) -> EditedOutput | None: 111 | # Counting bboxes 112 | cls_name_dict: dict[str, int] = {c.CLASSES[k]: v for k, v in model_output.count().items()} 113 | 114 | with st.container(border=True): 115 | st.markdown("**Counting**") 116 | st.json(cls_name_dict) 117 | 118 | # View selected bbox 119 | if 0 <= selected_index < len(model_output.xyxysc): 120 | x_min, y_min, x_max, y_max, score, cls = model_output.xyxysc[selected_index] 121 | 122 | with st.expander(label="Object Detail", expanded=True): 123 | cls = st.selectbox(label="Class", options=c.CLASSES, index=int(cls)) 124 | 125 | score_in_str = "%.3f" % score 126 | st.markdown(f"Confident score :red[{score_in_str}]") 127 | 128 | cls_index: int = c.CLASSES.index(cls) 129 | 130 | return EditedOutput(cls=cls_index) 131 | 132 | 133 | class App(BaseView): 134 | def __init__(self): 135 | self._upload_view = UploadView() 136 | self._image_panel_view = ImagePanelView() 137 | self._info_panel_view = InfoPanelView() 138 | 139 | self._ai_model: BaseAIModel = get_ai_model( 140 | c.AI_MODEL, 141 | c.AI_MODEL_CONFIGS[c.AI_MODEL] 142 | ) 143 | 144 | self.feedback_crud: FeedbackCRUD = FeedbackCRUD( 145 | session=Session() 146 | ) 147 | 148 | @property 149 | def model_input(self) -> ModelInput | None: 150 | return st.session_state.get("model_input", None) 151 | 152 | @model_input.setter 153 | def model_input(self, model_in: ModelInput): 154 | st.session_state["model_input"] = 
model_in 155 | 156 | @property 157 | def model_output(self): 158 | return st.session_state.get("model_output", None) 159 | 160 | @model_output.setter 161 | def model_output(self, model_output: ModelOutput): 162 | st.session_state["model_output"] = model_output 163 | 164 | def view(self, key: str): 165 | with st.sidebar: 166 | st.header("Model Configurations") 167 | st.info("Check out the documentation " 168 | "at [link](https://docs.ultralytics.com/modes/predict/#inference-sources)") 169 | 170 | model_input: ModelInput | None = self._upload_view.view(key=f"{key}_upload_inputs") 171 | if model_input is not None: 172 | # Run AI model when get new input 173 | with st.spinner("Running AI...."): 174 | model_output: ModelOutput = self._ai_model.process( 175 | image_in=model_input.upload_image, 176 | params=model_input.params 177 | ) 178 | st.toast("Finished AI processing", icon="🎉") 179 | self.model_input = model_input 180 | self.model_output = model_output 181 | 182 | if self.model_input is None: 183 | return 184 | 185 | image_col, info_col = st.columns([8, 2]) 186 | with image_col: 187 | updated_model_output, selected_index = self._image_panel_view.view( 188 | key=f"{key}_images", 189 | model_output=self.model_output, 190 | image_path=self.model_input.upload_image 191 | ) 192 | self.model_output = updated_model_output 193 | 194 | with info_col: 195 | edited_output: EditedOutput | None = self._info_panel_view.view( 196 | key=f"{key}_info", 197 | model_output=self.model_output, 198 | selected_index=selected_index 199 | ) 200 | 201 | save = st.button( 202 | "Edit & Save", 203 | key=f"{key}_save_btn", 204 | use_container_width=True, 205 | type="primary" 206 | ) 207 | 208 | if save and edited_output and 0 <= selected_index <= len(updated_model_output): 209 | updated_model_output[selected_index][-2] = edited_output.cls 210 | self.model_output = updated_model_output 211 | 212 | self.feedback_crud.create( 213 | image_path=self.model_input.upload_image, 214 | data=self.model_output.to_dict() 215 | ) 216 | 217 | st.toast("Saved", icon="🎉") 218 | -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/streamlit_demo/shared/views/canvas/__init__.py -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/canvas.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from typing import Literal 4 | 5 | import streamlit as st 6 | import streamlit.components.v1 as components 7 | import streamlit.elements.image as st_image 8 | from PIL import Image 9 | 10 | from .processor import DataProcessor 11 | from shared.schemas import ModelOutput 12 | 13 | 14 | _RELEASE = True # on packaging, pass this to True 15 | 16 | 17 | if not _RELEASE: 18 | _component_func = components.declare_component( 19 | "st_sparrow_labeling", 20 | url="http://localhost:3001", 21 | ) 22 | else: 23 | parent_dir = os.path.dirname(os.path.abspath(__file__)) 24 | build_dir = os.path.join(parent_dir, "frontend/build") 25 | _component_func = components.declare_component("st_sparrow_labeling", path=build_dir) 26 | 27 | 28 | @lru_cache(1) 29 | def get_background_image_bytes(image_path: str): 30 | background_image = Image.open(image_path) 31 | width, height = 
background_image.size 32 | 33 | format = st_image._validate_image_format_string(background_image, "PNG") 34 | image_data = _pil_to_bytes(background_image, format) 35 | 36 | return image_data, width 37 | 38 | 39 | def check_image_url(url): 40 | import requests 41 | try: 42 | response = requests.get(url) 43 | # Check if the request was successful 44 | if response.status_code == 200: 45 | return True 46 | else: 47 | return False 48 | except Exception as e: 49 | return False 50 | 51 | 52 | def _pil_to_bytes( 53 | image: st_image.PILImage, 54 | format: st_image.ImageFormat = "JPEG", 55 | quality: int = 100, 56 | ) -> bytes: 57 | import io 58 | 59 | """Convert a PIL image to bytes.""" 60 | tmp = io.BytesIO() 61 | 62 | # User must have specified JPEG, so we must convert it 63 | if format == "JPEG" and st_image._image_may_have_alpha_channel(image): 64 | image = image.convert("RGB") 65 | 66 | image.save(tmp, format=format, quality=quality) 67 | 68 | return tmp.getvalue() 69 | 70 | 71 | def st_annotate_tool( 72 | regions: ModelOutput, 73 | fill_color: str = "#eee", 74 | stroke_width: int = 20, 75 | stroke_color: str = "black", 76 | background_image: Image = None, 77 | drawing_mode: Literal["transform", "rect"] = "transform", 78 | point_display_radius: int = 3, 79 | canvas_height: int = 600, 80 | canvas_width: int = 600, 81 | key=None, 82 | ) -> tuple[ModelOutput, int]: 83 | """Create a drawing canvas in Streamlit app. Retrieve the RGBA image data into a 4D numpy array (r, g, b, alpha) 84 | on mouse up event. 85 | 86 | Parameters 87 | ---------- 88 | regions: ModelOutput 89 | Output from ai model, list of (x_min, y_min, x_max, y_max, score, cls) 90 | fill_color: str 91 | Color of fill for Rect in CSS color property. Defaults to "#eee". 92 | stroke_width: str 93 | Width of drawing brush in CSS color property. Defaults to 20. 94 | stroke_color: str 95 | Color of drawing brush in hex. Defaults to "black". 96 | background_image: Image 97 | Pillow Image to display behind canvas. 98 | Automatically resized to canvas dimensions. 99 | Being behind the canvas, it is not sent back to Streamlit on mouse event. 100 | drawing_mode: {'freedraw', 'transform', 'line', 'rect', 'circle', 'point', 'polygon'} 101 | Enable free drawing when "freedraw", object manipulation when "transform", "line", "rect", "circle", "point", "polygon". 102 | Defaults to "freedraw". 103 | point_display_radius: int 104 | The radius to use when displaying point objects. Defaults to 3. 105 | canvas_height: int 106 | Height of canvas in pixels. Defaults to 600. 107 | canvas_width: int 108 | Width of canvas in pixels. Defaults to 600. 109 | key: str 110 | An optional string to use as the unique key for the widget. 111 | Assign a key so the component is not remount every time the script is rerun. 
112 | 113 | Returns 114 | ------- 115 | new_model_output: contains edited bounding boxes 116 | selected_index: select index 117 | """ 118 | # Resize background_image to canvas dimensions by default 119 | # Then override background_color 120 | if canvas_height == 0 or canvas_width == 0: 121 | return regions, -1 122 | 123 | background_image_url = None 124 | if background_image: 125 | image_bytes, width = get_background_image_bytes(background_image) 126 | 127 | # Reduce network traffic and cache when switch another configure, 128 | # use streamlit in-mem filemanager to convert image to URL 129 | background_image_url = st_image.image_to_url( 130 | image_bytes, width, True, "RGB", "PNG", 131 | f"drawable-canvas-bg-{background_image}-{key}" 132 | ) 133 | background_image_url = st._config.get_option("server.baseUrlPath") + background_image_url 134 | 135 | data_processor = DataProcessor() 136 | canvas_rects = data_processor.prepare_canvas_data(regions) 137 | 138 | component_value = _component_func( 139 | fillColor=fill_color, 140 | strokeWidth=stroke_width, 141 | strokeColor=stroke_color, 142 | backgroundImageURL=background_image_url, 143 | canvasHeight=canvas_height, 144 | canvasWidth=canvas_width, 145 | drawingMode=drawing_mode, 146 | initialDrawing=canvas_rects, 147 | displayRadius=point_display_radius, 148 | key=f"{key}_canvas", 149 | default=None, 150 | realtimeUpdateStreamlit=True, 151 | showingMode="All", 152 | displayToolbar=False 153 | ) 154 | 155 | if component_value is None: 156 | return regions, -1 157 | 158 | select_index = component_value.get('selectIndex', -1) 159 | new_model_output, select_index = data_processor.prepare_rect_data( 160 | component_value["raw"], 161 | regions, 162 | select_index 163 | ) 164 | 165 | return ( 166 | new_model_output, 167 | select_index, 168 | ) 169 | -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/frontend/.env: -------------------------------------------------------------------------------- 1 | # Run the component's dev server on :3001 2 | # (The Streamlit dev server already runs on :3000) 3 | PORT=3001 4 | 5 | # Don't automatically open the web browser on `npm run start`. 6 | BROWSER=none 7 | -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/frontend/.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 
2 | 3 | # package-lock.json 4 | 5 | # dependencies 6 | /node_modules 7 | /.pnp 8 | .pnp.js 9 | 10 | # testing 11 | /coverage 12 | 13 | # production 14 | /build 15 | 16 | # misc 17 | .DS_Store 18 | .env.local 19 | .env.development.local 20 | .env.test.local 21 | .env.production.local 22 | 23 | npm-debug.log* 24 | yarn-debug.log* 25 | yarn-error.log* 26 | -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/frontend/.prettierrc: -------------------------------------------------------------------------------- 1 | { 2 | "endOfLine": "lf", 3 | "semi": false, 4 | "trailingComma": "es5" 5 | } 6 | -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/frontend/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "drawable_canvas", 3 | "version": "0.9.3", 4 | "private": true, 5 | "dependencies": { 6 | "apache-arrow": "^0.17.0", 7 | "event-target-shim": "^5.0.1", 8 | "fabric": "4.4.0", 9 | "hoist-non-react-statics": "^3.3.2", 10 | "lodash": "^4.17.20", 11 | "react": "^16.13.1", 12 | "react-dom": "^16.13.1", 13 | "react-scripts": "4.0.3", 14 | "streamlit-component-lib": "^1.3.0", 15 | "typescript": "^4.6.3" 16 | }, 17 | "devDependencies": { 18 | "@types/fabric": "^3.6.2", 19 | "@types/hoist-non-react-statics": "^3.3.1", 20 | "@types/jest": "^24.0.0", 21 | "@types/lodash": "^4.14.161", 22 | "@types/node": "^12.0.0", 23 | "@types/react": "^16.9.0", 24 | "@types/react-dom": "^16.9.0" 25 | }, 26 | "scripts": { 27 | "start": "react-scripts start", 28 | "build": "react-scripts build", 29 | "test": "react-scripts test", 30 | "eject": "react-scripts eject" 31 | }, 32 | "eslintConfig": { 33 | "extends": "react-app" 34 | }, 35 | "browserslist": { 36 | "production": [ 37 | ">0.2%", 38 | "not dead", 39 | "not op_mini all" 40 | ], 41 | "development": [ 42 | "last 1 chrome version", 43 | "last 1 firefox version", 44 | "last 1 safari version" 45 | ] 46 | }, 47 | "homepage": "." 48 | } 49 | -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/frontend/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | Streamlit Component 9 | 10 | 11 | 12 |
13 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/frontend/src/DrawableCanvas.tsx: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState } from "react" 2 | import { 3 | ComponentProps, 4 | Streamlit, 5 | withStreamlitConnection, 6 | } from "streamlit-component-lib" 7 | import { fabric } from "fabric" 8 | import { isEqual } from "lodash" 9 | 10 | import CanvasToolbar from "./components/CanvasToolbar" 11 | 12 | import { useCanvasState } from "./DrawableCanvasState" 13 | import { tools, FabricTool } from "./lib" 14 | 15 | function getStreamlitBaseUrl(): string | null { 16 | const params = new URLSearchParams(window.location.search) 17 | const baseUrl = params.get("streamlitUrl") 18 | if (baseUrl == null) { 19 | return null 20 | } 21 | 22 | try { 23 | return new URL(baseUrl).origin 24 | } catch { 25 | return null 26 | } 27 | } 28 | 29 | interface CustomFabricCanvas extends fabric.Canvas { 30 | isDragging?: boolean; 31 | selection?: boolean; 32 | lastPosX?: number; 33 | lastPosY?: number; 34 | 35 | secondTimeAccess?: boolean; 36 | currentState?: Object; 37 | showingMode?: string; 38 | 39 | } 40 | 41 | /** 42 | * Arguments Streamlit receives from the Python side 43 | */ 44 | export interface PythonArgs { 45 | fillColor: string 46 | strokeWidth: number 47 | strokeColor: string 48 | backgroundColor: string 49 | backgroundImageURL: string 50 | realtimeUpdateStreamlit: boolean 51 | canvasWidth: number 52 | canvasHeight: number 53 | drawingMode: string 54 | initialDrawing: Object 55 | displayToolbar: boolean 56 | displayRadius: number 57 | showingMode: string 58 | } 59 | 60 | /** 61 | * Define logic for the canvas area 62 | */ 63 | const DrawableCanvas = ({ args }: ComponentProps) => { 64 | const { 65 | canvasWidth, 66 | canvasHeight, 67 | backgroundColor, 68 | backgroundImageURL, 69 | realtimeUpdateStreamlit, 70 | drawingMode, 71 | fillColor, 72 | strokeWidth, 73 | strokeColor, 74 | displayRadius, 75 | initialDrawing, 76 | displayToolbar, 77 | showingMode 78 | }: PythonArgs = args 79 | 80 | /** 81 | * State initialization 82 | */ 83 | const [canvas, setCanvas] = useState(new fabric.Canvas("c") as CustomFabricCanvas); 84 | canvas.stopContextMenu = true 85 | canvas.fireRightClick = true 86 | 87 | const [selectedRect, setSelectedRect] = useState(-1) 88 | 89 | const [backgroundCanvas, setBackgroundCanvas] = useState(new fabric.Canvas("c") as CustomFabricCanvas); 90 | const { 91 | canvasState: { 92 | action: { shouldReloadCanvas, forceSendToStreamlit }, 93 | currentState, 94 | initialState, 95 | }, 96 | saveState, 97 | undo, 98 | redo, 99 | canUndo, 100 | canRedo, 101 | forceStreamlitUpdate, 102 | resetState, 103 | } = useCanvasState() 104 | 105 | 106 | /* 107 | * Load background image from URL 108 | */ 109 | // const params = new URLSearchParams(window.location.search); 110 | // const baseUrl = params.get('streamlitUrl') 111 | const baseUrl = getStreamlitBaseUrl() ?? 
"" 112 | let img = new fabric.Image() 113 | 114 | fabric.Image.fromURL(baseUrl + backgroundImageURL, function(oImg) { 115 | img = oImg 116 | img.selectable = false; 117 | backgroundCanvas.add(img); 118 | 119 | if (img.width == null || img.height == null){ 120 | return 121 | } 122 | 123 | // only initialize (image + rects) for canvas 1 124 | const isSecondTimes = (canvas.secondTimeAccess || false) 125 | 126 | /* 127 | * This is the first time UI is created, 128 | * And we try to align the canvas size with image by perform zooming only. 129 | * PS: This happend only for 1st time 130 | */ 131 | if (isSecondTimes === false){ // It means this is the first time 132 | console.log("Render Fist Time") 133 | canvas.loadFromJSON(initialDrawing, () => {}) 134 | 135 | // initialize zoom 136 | const widthRatio = canvas.getWidth() / img.width; 137 | const heightRatio = canvas.getHeight() / img.height; 138 | const zoom = Math.min(widthRatio, heightRatio) 139 | canvas.setZoom(zoom); 140 | backgroundCanvas.setZoom(zoom) 141 | 142 | canvas.secondTimeAccess = true 143 | canvas.requestRenderAll() 144 | backgroundCanvas.requestRenderAll() 145 | 146 | canvas.currentState = { ...initialDrawing } 147 | canvas.showingMode = showingMode 148 | } 149 | 150 | /* 151 | * User can choose some group of boxes to visualie (keys only, value only, or both) 152 | * Refresh the initial canvas 153 | * The current showingMode is different with the previous one! => Trigger to re-load the initialDrawings! 154 | * [07.10.2023] The below code should be erased. We don't allow to do it anymore because of low performance. 155 | */ 156 | if (canvas.showingMode !== showingMode){ 157 | canvas.showingMode = showingMode 158 | 159 | if (!isEqual(canvas.currentState, initialDrawing)){ 160 | canvas.loadFromJSON(initialDrawing, () => { 161 | canvas.currentState = { ...initialDrawing } 162 | 163 | canvas.renderAll() 164 | }) 165 | } 166 | } 167 | 168 | }); 169 | 170 | /** 171 | * Initialize canvases on component mount 172 | * NB: Remount component by changing its key instead of defining deps 173 | */ 174 | useEffect(() => { 175 | const c = new fabric.Canvas("canvas", { 176 | enableRetinaScaling: false, 177 | }) 178 | const imgC = new fabric.Canvas("backgroundimage-canvas", { 179 | enableRetinaScaling: false, 180 | }) 181 | setCanvas(c) 182 | setBackgroundCanvas(imgC) 183 | Streamlit.setFrameHeight() 184 | }, []) 185 | 186 | 187 | /** 188 | * If state changed from undo/redo/reset, update user-facing canvas 189 | */ 190 | useEffect(() => { 191 | if (shouldReloadCanvas) { 192 | canvas.loadFromJSON(currentState, () => {}) 193 | } 194 | }, [canvas, shouldReloadCanvas, currentState]) 195 | 196 | 197 | /** 198 | * Update canvas with selected tool 199 | * PS: add initialDrawing in dependency so user drawing update reinits tool 200 | */ 201 | useEffect(() => { 202 | // Update canvas events with selected tool 203 | const selectedTool = new tools[drawingMode](canvas) as FabricTool 204 | const cleanupToolEvents = selectedTool.configureCanvas({ 205 | fillColor: fillColor, 206 | strokeWidth: strokeWidth, 207 | strokeColor: strokeColor, 208 | displayRadius: displayRadius 209 | }) 210 | 211 | /* 212 | * Ensure zoom/pan do not exceed the boundary of canvas. 
213 | */ 214 | let ensure_boundary: () => void = function (): void { 215 | const T = canvas.viewportTransform; 216 | 217 | if (img.aCoords == null || T == null) return 218 | 219 | const brRaw = img.aCoords.br 220 | const tlRaw = img.aCoords.tl 221 | 222 | const br = fabric.util.transformPoint(brRaw, T); 223 | const tl = fabric.util.transformPoint(tlRaw, T); 224 | 225 | const { 226 | x: left, 227 | y: top 228 | } = tl; 229 | 230 | const { 231 | x: right, 232 | y: bottom 233 | } = br; 234 | 235 | const width = canvas.getWidth() 236 | const height = canvas.getHeight() 237 | 238 | // calculate how far to translate to line up the edge of the object with 239 | // the edge of the canvas 240 | const dLeft = Math.abs(right - width); 241 | const dRight = Math.abs(left); 242 | const dUp = Math.abs(bottom - height); 243 | const dDown = Math.abs(top); 244 | const maxDx = Math.min(dLeft, dRight); 245 | const maxDy = Math.min(dUp, dDown); 246 | 247 | // if the object is larger than the canvas, clamp translation such that 248 | // we don't push the opposite boundary past the edge 249 | const leftIsOver = left < 0; 250 | const rightIsOver = right > width; 251 | const topIsOver = top < 0; 252 | const bottomIsOver = bottom > height; 253 | 254 | const translateLeft = rightIsOver && !leftIsOver; 255 | const translateRight = leftIsOver && !rightIsOver; 256 | const translateUp = bottomIsOver && !topIsOver; 257 | const translateDown = topIsOver && !bottomIsOver; 258 | 259 | const dx = translateLeft ? -maxDx : translateRight ? maxDx : 0; 260 | const dy = translateUp ? -maxDy : translateDown ? maxDy : 0; 261 | 262 | if (dx || dy) { 263 | T[4] += dx; 264 | T[5] += dy; 265 | canvas.requestRenderAll(); 266 | 267 | backgroundCanvas.setViewportTransform(T) 268 | backgroundCanvas.requestRenderAll() 269 | } 270 | 271 | }; 272 | 273 | /* 274 | * Mouse down event. 275 | * IF user press Alt keyboard, then move => Drag & Drop the image. 276 | */ 277 | canvas.on("mouse:down", function (this: CustomFabricCanvas, opt) { 278 | var evt = opt.e as MouseEvent; 279 | 280 | if (evt.altKey === true) { 281 | this.isDragging = true; 282 | this.selection = false; 283 | this.lastPosX = evt.clientX; 284 | this.lastPosY = evt.clientY; 285 | 286 | canvas.setCursor('grab') 287 | // canvas.discardActiveObject(); 288 | // canvas.requestRenderAll(); 289 | 290 | } 291 | 292 | if (opt.target) { 293 | if (opt.target.type === 'rect') { 294 | 295 | const selectObject = canvas.getActiveObject() 296 | const selectIndex = canvas.getObjects().indexOf(selectObject) 297 | 298 | selectObject.selectionBackgroundColor = 'rgba(63,245,39,0.5)' 299 | 300 | // Return selected object. 301 | setSelectedRect(selectIndex) 302 | 303 | const data = canvas 304 | .getContext() 305 | .canvas.toDataURL() 306 | 307 | Streamlit.setComponentValue({ 308 | data: data, 309 | width: canvas.getWidth(), 310 | height: canvas.getHeight(), 311 | raw: canvas.toObject(), 312 | selectIndex: selectIndex 313 | }) 314 | 315 | } 316 | } else { 317 | setSelectedRect(-1) 318 | } 319 | }) 320 | 321 | 322 | /* 323 | * Mouse move event. Only affect while the alt key is pressed. 
324 | */ 325 | canvas.on("mouse:move", function (this: CustomFabricCanvas, opt) { 326 | var e = opt.e as MouseEvent 327 | 328 | if (this.isDragging || false) { 329 | canvas.setCursor('grab') 330 | const delta = new fabric.Point( e.movementX, e.movementY ) 331 | 332 | canvas.relativePan( delta ) 333 | backgroundCanvas.relativePan( delta ) 334 | 335 | ensure_boundary() 336 | 337 | e.preventDefault(); 338 | e.stopPropagation(); 339 | 340 | } 341 | }) 342 | 343 | /* 344 | * Mouse wheel event - Scale in/out 345 | */ 346 | canvas.on("mouse:wheel", function (this: CustomFabricCanvas, opt) { 347 | var e = opt.e as WheelEvent; 348 | var delta = e.deltaY; 349 | var zoom = canvas.getZoom(); 350 | zoom *= 0.999 ** delta; 351 | if (zoom > 10) zoom = 10; 352 | if (zoom < 0.1) zoom = 0.1; 353 | var point = new fabric.Point(e.offsetX, e.offsetY); 354 | canvas.zoomToPoint(point, zoom); 355 | backgroundCanvas.zoomToPoint(point, zoom); 356 | 357 | e.preventDefault(); 358 | e.stopPropagation(); 359 | }) 360 | 361 | canvas.on("mouse:up", (e: any) => { 362 | /* 363 | * There are several events can end with mouse:up: 364 | * 1. [rect] create new object 365 | * 2. [transform] resize selected object 366 | * 3. [transform] choose selected object 367 | * 4. [transform] delete selected object 368 | */ 369 | 370 | // saveState(canvas.toJSON()); 371 | 372 | var isEqualState = isEqual( canvas.toObject(), canvas.currentState ) 373 | if ( (isEqualState === false) && (drawingMode === 'transform') ){ 374 | canvas.currentState = { ...canvas.toObject() } 375 | 376 | const selectObject = canvas.getActiveObject() 377 | const selectIndex = canvas.getObjects().indexOf(selectObject) 378 | 379 | const data = canvas 380 | .getContext() 381 | .canvas.toDataURL() 382 | 383 | Streamlit.setComponentValue({ 384 | data: data, 385 | width: canvas.getWidth(), 386 | height: canvas.getHeight(), 387 | raw: canvas.toObject(), 388 | selectIndex: selectIndex 389 | }) 390 | 391 | } 392 | 393 | // Add your logic here for handling mouse up events 394 | canvas.isDragging = false; 395 | canvas.selection = true; 396 | canvas.setCursor("default") 397 | }); 398 | 399 | canvas.on("mouse:dblclick", () => { 400 | if (drawingMode === 'transform') { 401 | const selectObject = canvas.getActiveObject() 402 | const selectIndex = canvas.getObjects().indexOf(selectObject) 403 | 404 | canvas.remove(selectObject) 405 | 406 | const data = canvas 407 | .getContext() 408 | .canvas.toDataURL() 409 | 410 | Streamlit.setComponentValue({ 411 | data: data, 412 | width: canvas.getWidth(), 413 | height: canvas.getHeight(), 414 | raw: canvas.toObject(), 415 | selectIndex: selectIndex 416 | }) 417 | 418 | } 419 | 420 | }) 421 | 422 | // Cleanup tool + send data to Streamlit events 423 | return () => { 424 | cleanupToolEvents() 425 | canvas.off("mouse:down") 426 | canvas.off("mouse:move") 427 | canvas.off("mouse:up") 428 | canvas.off("mouse:wheel") 429 | canvas.off("mouse:dblclick") 430 | backgroundCanvas.off("mouse:down") 431 | backgroundCanvas.off("mouse:move") 432 | backgroundCanvas.off("mouse:up") 433 | backgroundCanvas.off("mouse:wheel") 434 | backgroundCanvas.off("mouse:dblclick") 435 | } 436 | }, [ 437 | canvas, 438 | backgroundCanvas, 439 | strokeWidth, 440 | strokeColor, 441 | displayRadius, 442 | fillColor, 443 | drawingMode, 444 | initialDrawing, 445 | saveState, 446 | forceStreamlitUpdate, 447 | img 448 | ]) 449 | 450 | /** 451 | * Render canvas w/ toolbar 452 | */ 453 | return ( 454 |
455 |
464 | {/**/} 473 | 474 |
475 |
483 | 488 |
489 |
497 | 503 |
504 | {displayToolbar && ( 505 | { 514 | resetState(initialState) 515 | }} 516 | /> 517 | )} 518 |
519 | ) 520 | } 521 | 522 | export default withStreamlitConnection(DrawableCanvas) -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/frontend/src/DrawableCanvasState.tsx: -------------------------------------------------------------------------------- 1 | import React, { 2 | createContext, 3 | useReducer, 4 | useContext, 5 | useCallback, 6 | } from "react" 7 | import { isEmpty, isEqual } from "lodash" 8 | 9 | const HISTORY_MAX_COUNT = 100 10 | 11 | interface CanvasHistory { 12 | undoStack: Object[] // store previous canvas states 13 | redoStack: Object[] // store undone canvas states 14 | } 15 | 16 | interface CanvasAction { 17 | shouldReloadCanvas: boolean // reload currentState into app canvas, on undo/redo 18 | forceSendToStreamlit: boolean // send currentState back to Streamlit 19 | } 20 | 21 | const NO_ACTION: CanvasAction = { 22 | shouldReloadCanvas: false, 23 | forceSendToStreamlit: false, 24 | } 25 | 26 | const RELOAD_CANVAS: CanvasAction = { 27 | shouldReloadCanvas: true, 28 | forceSendToStreamlit: false, 29 | } 30 | 31 | const SEND_TO_STREAMLIT: CanvasAction = { 32 | shouldReloadCanvas: false, 33 | forceSendToStreamlit: true, 34 | } 35 | 36 | const RELOAD_AND_SEND_TO_STREAMLIT: CanvasAction = { 37 | shouldReloadCanvas: true, 38 | forceSendToStreamlit: true, 39 | } 40 | 41 | interface CanvasState { 42 | history: CanvasHistory 43 | action: CanvasAction 44 | initialState: Object // first currentState for app 45 | currentState: Object // current canvas state as canvas.toJSON() 46 | } 47 | 48 | interface Action { 49 | type: "save" | "undo" | "redo" | "reset" | "forceSendToStreamlit" 50 | state?: Object 51 | } 52 | 53 | /** 54 | * Reducer takes 5 actions: save, undo, redo, reset, forceSendToStreamlit 55 | * 56 | * On reset, clear everything, set initial and current state to cleared canvas 57 | * 58 | * On save: 59 | * - First, if there is no initial state, set it to current 60 | * Since we don't reset history on component initialization 61 | * As backgroundColor/image are applied after component init 62 | * and wouldn't be stored in initial state 63 | * - If the sent state is same as current state, then nothing has changed so don't save 64 | * - Clear redo stack 65 | * - Push current state to undo stack, delete oldest if necessary 66 | * - Set new current state 67 | * 68 | * On undo: 69 | * - Push state to redoStack if it's not the initial 70 | * - Pop state from undoStack into current state 71 | * 72 | * On redo: 73 | * - Pop state from redoStack into current state 74 | * 75 | * For undo/redo/reset, set shouldReloadCanvas to inject currentState into user facing canvas 76 | */ 77 | const canvasStateReducer = ( 78 | state: CanvasState, 79 | action: Action 80 | ): CanvasState => { 81 | switch (action.type) { 82 | case "save": 83 | if (!action.state) throw new Error("No action state to save") 84 | else if (isEmpty(state.currentState)) { 85 | return { 86 | history: { 87 | undoStack: [], 88 | redoStack: [], 89 | }, 90 | action: { ...NO_ACTION }, 91 | initialState: action.state, 92 | currentState: action.state, 93 | } 94 | } else if (isEqual(action.state, state.currentState)) 95 | return { 96 | history: { ...state.history }, 97 | action: { ...NO_ACTION }, 98 | initialState: state.initialState, 99 | currentState: state.currentState, 100 | } 101 | else { 102 | const undoOverHistoryMaxCount = 103 | state.history.undoStack.length >= HISTORY_MAX_COUNT 104 | return { 105 | history: { 106 | undoStack: [ 107 | 
...state.history.undoStack.slice(undoOverHistoryMaxCount ? 1 : 0), 108 | state.currentState, 109 | ], 110 | redoStack: [], 111 | }, 112 | action: { ...NO_ACTION }, 113 | initialState: 114 | state.initialState == null 115 | ? state.currentState 116 | : state.initialState, 117 | currentState: action.state, 118 | } 119 | } 120 | case "undo": 121 | if ( 122 | isEmpty(state.currentState) || 123 | isEqual(state.initialState, state.currentState) 124 | ) { 125 | return { 126 | history: { ...state.history }, 127 | action: { ...NO_ACTION }, 128 | initialState: state.initialState, 129 | currentState: state.currentState, 130 | } 131 | } else { 132 | const isUndoEmpty = state.history.undoStack.length === 0 133 | return { 134 | history: { 135 | undoStack: state.history.undoStack.slice(0, -1), 136 | redoStack: [...state.history.redoStack, state.currentState], 137 | }, 138 | action: { ...RELOAD_CANVAS }, 139 | initialState: state.initialState, 140 | currentState: isUndoEmpty 141 | ? state.currentState 142 | : state.history.undoStack[state.history.undoStack.length - 1], 143 | } 144 | } 145 | case "redo": 146 | if (state.history.redoStack.length > 0) { 147 | // TODO: test currentState empty too ? 148 | return { 149 | history: { 150 | undoStack: [...state.history.undoStack, state.currentState], 151 | redoStack: state.history.redoStack.slice(0, -1), 152 | }, 153 | action: { ...RELOAD_CANVAS }, 154 | initialState: state.initialState, 155 | currentState: 156 | state.history.redoStack[state.history.redoStack.length - 1], 157 | } 158 | } else { 159 | return { 160 | history: { ...state.history }, 161 | action: { ...NO_ACTION }, 162 | initialState: state.initialState, 163 | currentState: state.currentState, 164 | } 165 | } 166 | case "reset": 167 | if (!action.state) throw new Error("No action state to store in reset") 168 | return { 169 | history: { 170 | undoStack: [], 171 | redoStack: [], 172 | }, 173 | action: { ...RELOAD_AND_SEND_TO_STREAMLIT }, 174 | initialState: action.state, 175 | currentState: action.state, 176 | } 177 | case "forceSendToStreamlit": 178 | return { 179 | history: { ...state.history }, 180 | action: { ...SEND_TO_STREAMLIT }, 181 | initialState: state.initialState, 182 | currentState: state.currentState, 183 | } 184 | default: 185 | throw new Error("TS should protect from this") 186 | } 187 | } 188 | 189 | const initialState: CanvasState = { 190 | history: { 191 | undoStack: [], 192 | redoStack: [], 193 | }, 194 | action: { 195 | forceSendToStreamlit: false, 196 | shouldReloadCanvas: false, 197 | }, 198 | initialState: {}, 199 | currentState: {}, 200 | } 201 | 202 | interface CanvasStateContextProps { 203 | canvasState: CanvasState 204 | saveState: (state: Object) => void 205 | undo: () => void 206 | redo: () => void 207 | forceStreamlitUpdate: () => void 208 | canUndo: boolean 209 | canRedo: boolean 210 | resetState: (state: Object) => void 211 | } 212 | 213 | const CanvasStateContext = createContext( 214 | {} as CanvasStateContextProps 215 | ) 216 | 217 | export const CanvasStateProvider = ({ 218 | children, 219 | }: React.PropsWithChildren<{}>) => { 220 | const [canvasState, dispatch] = useReducer(canvasStateReducer, initialState) 221 | 222 | // Setup our callback functions 223 | // We memoize with useCallback to prevent unnecessary re-renders 224 | const saveState = useCallback( 225 | (state) => dispatch({ type: "save", state: state }), 226 | [dispatch] 227 | ) 228 | const undo = useCallback(() => dispatch({ type: "undo" }), [dispatch]) 229 | const redo = useCallback(() => dispatch({ 
type: "redo" }), [dispatch]) 230 | const forceStreamlitUpdate = useCallback( 231 | () => dispatch({ type: "forceSendToStreamlit" }), 232 | [dispatch] 233 | ) 234 | const resetState = useCallback( 235 | (state) => dispatch({ type: "reset", state: state }), 236 | [dispatch] 237 | ) 238 | 239 | const canUndo = canvasState.history.undoStack.length !== 0 240 | const canRedo = canvasState.history.redoStack.length !== 0 241 | 242 | return ( 243 | 255 | {children} 256 | 257 | ) 258 | } 259 | 260 | /** 261 | * Hook to get data out of context 262 | */ 263 | export const useCanvasState = () => { 264 | return useContext(CanvasStateContext) 265 | } 266 | 267 | -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/frontend/src/components/CanvasToolbar.module.css: -------------------------------------------------------------------------------- 1 | .enabled { 2 | cursor: pointer; 3 | background: none; 4 | } 5 | 6 | .disabled { 7 | cursor: not-allowed; 8 | filter: invert(95%) sepia(10%) saturate(657%) hue-rotate(184deg) 9 | brightness(92%) contrast(95%); 10 | } 11 | 12 | .enabled:hover { 13 | filter: invert(41%) sepia(62%) saturate(7158%) hue-rotate(344deg) 14 | brightness(101%) contrast(108%); 15 | } 16 | 17 | .invertx { 18 | transform: scaleX(-1); 19 | } 20 | -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/frontend/src/components/CanvasToolbar.tsx: -------------------------------------------------------------------------------- 1 | import React from "react" 2 | 3 | import styles from "./CanvasToolbar.module.css" 4 | 5 | import bin from "../img/bin.png" 6 | import undo from "../img/undo.png" 7 | import download from "../img/download.png" 8 | 9 | interface SquareIconProps { 10 | imgUrl: string 11 | altText: string 12 | invertX?: boolean 13 | size: number 14 | enabled: boolean 15 | clickCallback: () => void 16 | } 17 | 18 | const SquareIcon = ({ 19 | imgUrl, 20 | altText, 21 | invertX, 22 | size, 23 | enabled, 24 | clickCallback, 25 | }: SquareIconProps) => ( 26 | {altText} 39 | ) 40 | SquareIcon.defaultProps = { 41 | invertX: false, 42 | } 43 | 44 | interface CanvasToolbarProps { 45 | topPosition: number 46 | leftPosition: number 47 | canUndo: boolean 48 | canRedo: boolean 49 | downloadCallback: () => void 50 | undoCallback: () => void 51 | redoCallback: () => void 52 | resetCallback: () => void 53 | } 54 | 55 | const CanvasToolbar = ({ 56 | topPosition, 57 | leftPosition, 58 | canUndo, 59 | canRedo, 60 | downloadCallback, 61 | undoCallback, 62 | redoCallback, 63 | resetCallback, 64 | }: CanvasToolbarProps) => { 65 | const GAP_BETWEEN_ICONS = 4 66 | const ICON_SIZE = 24 67 | 68 | const iconElements = [ 69 | { 70 | imgUrl: download, 71 | altText: "Send to Streamlit", 72 | invertX: false, 73 | enabled: true, 74 | clickCallback: downloadCallback, 75 | }, 76 | { 77 | imgUrl: undo, 78 | altText: "Undo", 79 | invertX: true, 80 | enabled: canUndo, 81 | clickCallback: canUndo ? undoCallback : () => {}, 82 | }, 83 | { 84 | imgUrl: undo, 85 | altText: "Redo", 86 | invertX: false, 87 | enabled: canRedo, 88 | clickCallback: canRedo ? redoCallback : () => {}, 89 | }, 90 | { 91 | imgUrl: bin, 92 | altText: "Reset canvas & history", 93 | invertX: false, 94 | enabled: true, 95 | clickCallback: resetCallback, 96 | }, 97 | ] 98 | 99 | return ( 100 |
110 | {iconElements.map((e) => ( 111 | 120 | ))} 121 |
122 | ) 123 | } 124 | 125 | export default CanvasToolbar 126 | -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/frontend/src/components/UpdateStreamlit.tsx: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState } from "react" 2 | import { Streamlit } from "streamlit-component-lib" 3 | import { fabric } from "fabric" 4 | 5 | const DELAY_DEBOUNCE = 0 6 | 7 | /** 8 | * Download image and JSON data from canvas to send back to Streamlit 9 | */ 10 | const sendDataToStreamlit = (canvas: fabric.Canvas, selectedRect: number): void => { 11 | const data = canvas 12 | .getContext() 13 | .canvas.toDataURL() 14 | Streamlit.setComponentValue({ 15 | data: data, 16 | width: canvas.getWidth(), 17 | height: canvas.getHeight(), 18 | raw: canvas.toObject(), 19 | selectIndex: selectedRect 20 | }) 21 | } 22 | 23 | /** 24 | * This hook allows you to debounce any fast changing value. 25 | * The debounced value will only reflect the latest value when the useDebounce hook has not been called for the specified time period. 26 | * When used in conjunction with useEffect, you can easily ensure that expensive operations like API calls are not executed too frequently. 27 | * https://usehooks.com/useDebounce/ 28 | * @param value value to debounce 29 | * @param delay delay of debounce in ms 30 | */ 31 | const useDebounce = (value: any, delay: number) => { 32 | const [debouncedValue, setDebouncedValue] = useState(value) 33 | 34 | useEffect( 35 | () => { 36 | // Update debounced value after delay 37 | const handler = setTimeout(() => { 38 | setDebouncedValue(value) 39 | }, delay) 40 | 41 | // Cancel the timeout if value changes (also on delay change or unmount) 42 | // This is how we prevent debounced value from updating if value is changed ... 43 | // .. within the delay period. Timeout gets cleared and restarted. 44 | return () => { 45 | clearTimeout(handler) 46 | } 47 | }, 48 | [value, delay] // Only re-call effect if value or delay changes 49 | ) 50 | return debouncedValue 51 | } 52 | 53 | interface UpdateStreamlitProps { 54 | shouldSendToStreamlit: boolean 55 | stateToSendToStreamlit: Object 56 | canvasWidth: number 57 | canvasHeight: number 58 | selectedRect: number 59 | } 60 | 61 | /** 62 | * Canvas whose sole purpose is to draw current state 63 | * to send image data to Streamlit. 64 | * Put it in the background or make it invisible! 
65 | */ 66 | const UpdateStreamlit = (props: UpdateStreamlitProps) => { 67 | const [stCanvas, setStCanvas] = useState(new fabric.Canvas("")) 68 | 69 | // Debounce fast changing canvas states 70 | // Especially when drawing lines and circles which continuously render while drawing 71 | const debouncedStateToSend = useDebounce( 72 | props.stateToSendToStreamlit, 73 | DELAY_DEBOUNCE 74 | ) 75 | 76 | // Initialize canvas 77 | useEffect(() => { 78 | const stC = new fabric.Canvas("canvas-to-streamlit", { 79 | enableRetinaScaling: false, 80 | }) 81 | setStCanvas(stC) 82 | }, []) 83 | 84 | // Load state to canvas, then send content to Streamlit 85 | useEffect(() => { 86 | if (debouncedStateToSend && props.shouldSendToStreamlit) { 87 | stCanvas.loadFromJSON(debouncedStateToSend, () => { 88 | sendDataToStreamlit(stCanvas, props.selectedRect) 89 | }) 90 | } 91 | }, [stCanvas, props.shouldSendToStreamlit, props.selectedRect, debouncedStateToSend]) 92 | 93 | return ( 94 | <canvas 95 | id="canvas-to-streamlit" 96 | width={props.canvasWidth} 97 | height={props.canvasHeight} 98 | /> 99 | ) 100 | } 101 | 102 | export default UpdateStreamlit 103 | -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/frontend/src/img/bin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/streamlit_demo/shared/views/canvas/frontend/src/img/bin.png -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/frontend/src/img/download.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/streamlit_demo/shared/views/canvas/frontend/src/img/download.png -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/frontend/src/img/undo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/streamlit_demo/shared/views/canvas/frontend/src/img/undo.png -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/frontend/src/index.css: -------------------------------------------------------------------------------- 1 | :root { 2 | box-sizing: border-box; 3 | } 4 | 5 | *, 6 | ::before, 7 | ::after { 8 | box-sizing: inherit; 9 | } 10 | 11 | body { 12 | margin: 0; 13 | background: transparent; 14 | } 15 | -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/frontend/src/index.tsx: -------------------------------------------------------------------------------- 1 | import React from "react" 2 | import ReactDOM from "react-dom" 3 | import DrawableCanvas from "./DrawableCanvas" 4 | import { CanvasStateProvider } from "./DrawableCanvasState" 5 | 6 | import "./index.css" 7 | 8 | ReactDOM.render( 9 | <React.StrictMode> 10 | <CanvasStateProvider> 11 | <DrawableCanvas /> 12 | </CanvasStateProvider> 13 | </React.StrictMode>, 14 | document.getElementById("root") 15 | ) 16 | -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/frontend/src/react-app-env.d.ts: -------------------------------------------------------------------------------- 1 | /// <reference types="react-scripts" /> 2 | -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/frontend/tsconfig.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "es5", 4 | "lib": [ 5 | "dom", 6 | "dom.iterable", 7 | "esnext" 8 | ], 9 | "allowJs": true, 10 | "skipLibCheck": true, 11 | "esModuleInterop": true, 12 | "allowSyntheticDefaultImports": true, 13 | "strict": true, 14 | "forceConsistentCasingInFileNames": true, 15 | "module": "esnext", 16 | "moduleResolution": "node", 17 | "resolveJsonModule": true, 18 | "isolatedModules": true, 19 | "noEmit": true, 20 | "jsx": "react-jsx", 21 | "noFallthroughCasesInSwitch": true 22 | }, 23 | "include": [ 24 | "src" 25 | ] 26 | } 27 | -------------------------------------------------------------------------------- /streamlit_demo/shared/views/canvas/processor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import numpy as np 4 | from loguru import logger 5 | 6 | from shared.schemas import ModelOutput 7 | 8 | 9 | class DataProcessor: 10 | def __init__(self, filled_color: str = "rgba(0, 151, 255, 0.25)"): 11 | self._filled_color = filled_color 12 | 13 | def prepare_canvas_data( 14 | self, 15 | data: ModelOutput, 16 | ): 17 | canvas_rects = [] 18 | 19 | for i, box in enumerate(data.xyxysc): 20 | box: np.ndarray 21 | 22 | canvas_rect = self.construct_canvas_group( 23 | box[:4].astype(int), 24 | True, 25 | self._filled_color 26 | ) 27 | canvas_rects += [canvas_rect] 28 | 29 | payload = {"version": "4.4.0", "objects": canvas_rects} 30 | return payload 31 | 32 | @staticmethod 33 | def get_location_from_canvas_rect(canvas_rect: dict) -> tuple: 34 | x2 = round(canvas_rect['left'] + (canvas_rect['width'] * canvas_rect['scaleX'])) 35 | y2 = round(canvas_rect['top'] + (canvas_rect['height'] * canvas_rect['scaleY'])) 36 | x1 = round(canvas_rect['left']) 37 | y1 = round(canvas_rect['top']) 38 | 39 | return x1, y1, x2, y2 40 | 41 | @staticmethod 42 | def construct_canvas_group( 43 | box: np.ndarray, 44 | visibility: bool, 45 | filled_color: str 46 | ): 47 | x_min, y_min, x_max, y_max = map(int, box) 48 | 49 | canvas_rect = { 50 | "type": "rect", 51 | "version": "4.4.0", 52 | "originX": "left", 53 | "originY": "top", 54 | "left": x_min, 55 | "top": y_min, 56 | "width": x_max - x_min, 57 | "height": y_max - y_min, 58 | "fill": filled_color, 59 | "stroke": "rgba(0, 50, 255, 0.7)", 60 | "strokeWidth": 2, 61 | "strokeDashArray": None, 62 | "strokeLineCap": "butt", 63 | "strokeDashOffset": 0, 64 | "strokeLineJoin": "miter", 65 | "strokeUniform": True, 66 | "strokeMiterLimit": 4, 67 | "scaleX": 1, 68 | "scaleY": 1, 69 | "angle": 0, 70 | "flipX": False, 71 | "flipY": False, 72 | "opacity": 1, 73 | "shadow": None, 74 | "visible": visibility, 75 | "backgroundColor": "", 76 | "fillRule": "nonzero", 77 | "paintFirst": "fill", 78 | "globalCompositeOperation": "source-over", 79 | "skewX": 0, 80 | "skewY": 0, 81 | "rx": 0, 82 | "ry": 0, 83 | } 84 | 85 | return canvas_rect 86 | 87 | def prepare_rect_data( 88 | self, 89 | canvas_data, 90 | regions_in: ModelOutput, 91 | select_index: int = -1 92 | ): 93 | regions = [] 94 | n_in = len(regions_in.xyxysc) 95 | n_out = len(canvas_data["objects"]) 96 | 97 | if n_in <= n_out: 98 | # For adding & modify 99 | for i, canvas_rect in enumerate(canvas_data["objects"]): 100 | x_min, y_min, x_max, y_max = self.get_location_from_canvas_rect(canvas_rect) 101 | if i < n_in: 102 | # modifying: update location 103 | old_region = regions_in.xyxysc[i] 104 | old_region[:4] = [x_min, y_min, x_max, y_max] 105 | 
regions += [old_region] 106 | else: 107 | # adding 108 | region = np.array([x_min, y_min, x_max, y_max, 0.0, -1]) 109 | regions += [region] 110 | elif n_in > n_out: 111 | """ 112 | For deleting 113 | """ 114 | regions = [r for i, r in enumerate(regions_in.xyxysc) if i != select_index] 115 | select_index = -1 116 | 117 | xyxysc = np.array(regions) 118 | 119 | return ModelOutput(xyxysc=xyxysc), select_index 120 | --------------------------------------------------------------------------------
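The data flow above can be exercised end to end without Streamlit: prepare_canvas_data turns model boxes into the Fabric.js rect payload the drawable canvas consumes, and prepare_rect_data folds the edited canvas objects back into xyxysc rows. The snippet below is an illustrative sketch only: the box values are made up, and it assumes ModelOutput(xyxysc=...) accepts a NumPy array of [x1, y1, x2, y2, score, class] rows, which is the same shape prepare_rect_data itself returns.

import numpy as np

from shared.schemas import ModelOutput
from shared.views.canvas.processor import DataProcessor

processor = DataProcessor()

# Two hypothetical detections in xyxysc format: [x1, y1, x2, y2, score, class].
detections = ModelOutput(
    xyxysc=np.array(
        [
            [10.0, 20.0, 110.0, 220.0, 0.9, 0.0],
            [50.0, 60.0, 150.0, 260.0, 0.8, 1.0],
        ]
    )
)

# Model boxes -> Fabric.js payload for the canvas component.
payload = processor.prepare_canvas_data(detections)
assert payload["version"] == "4.4.0"
assert len(payload["objects"]) == 2

# Simulate the user dragging the first rectangle 5 px to the right on the canvas,
# then merge the canvas state back into model coordinates.
payload["objects"][0]["left"] += 5
updated, select_index = processor.prepare_rect_data(payload, detections, select_index=-1)
print(updated.xyxysc)  # first row shifted by 5 px in x, score/class columns untouched
print(select_index)    # still -1: no rectangle was deleted

Deletion takes the other branch: when the canvas returns fewer objects than regions_in holds, prepare_rect_data drops the row at select_index and resets the index to -1.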