9 |
10 |
11 |
--------------------------------------------------------------------------------
/image-to-image-search/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | clip @ git+https://github.com/openai/CLIP.git
3 | faiss-cpu
4 |
5 | # Download data
6 | roboflow
7 |
--------------------------------------------------------------------------------
/leaked_container/.env:
--------------------------------------------------------------------------------
1 | API_KEY=this_is_a_secret_key
2 |
--------------------------------------------------------------------------------
/leaked_container/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11-alpine
2 |
3 | WORKDIR /app
4 |
5 | COPY . .
6 |
7 | RUN pip install --no-cache-dir -r requirements.txt
8 |
9 | COPY .env /app/.env
10 |
11 | # INSECURE: `export` does not persist across RUN instructions, so the key is read
12 | # and written within a single layer, baking the secret into the image.
13 | RUN export $(cat /app/.env | xargs) && echo "API_KEY=${API_KEY}" > /tmp/credentials.txt
14 |
15 | # INSECURE: removing .env in a later layer does not scrub it from earlier layers
16 | RUN rm /app/.env
17 |
18 | CMD ["python", "src/index.py"]
19 |
20 | EXPOSE 3000
21 |
22 |
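23 | # A minimal sketch of how the leak can be observed (the image tag `leaked-demo` is
24 | # an assumption, not part of the exercise): after `docker build -t leaked-demo .`,
25 | # running `docker run --rm leaked-demo cat /tmp/credentials.txt` prints the key
26 | # captured at build time, and the deleted .env is still recoverable from the
27 | # earlier COPY layer (e.g. by inspecting the layers exported with `docker save`).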
--------------------------------------------------------------------------------
/leaked_container/requirements.txt:
--------------------------------------------------------------------------------
1 | flask
2 |
--------------------------------------------------------------------------------
/leaked_container/src/index.py:
--------------------------------------------------------------------------------
1 | from flask import Flask
2 | import os
3 |
4 | app = Flask(__name__)
5 |
6 | @app.route('/')
7 | def index():
8 | api_key = os.getenv('API_KEY')
9 | return f"API_KEY is: {api_key}"
10 |
11 | if __name__ == '__main__':
12 | app.run(host='0.0.0.0', port=3000)
13 |
14 |
--------------------------------------------------------------------------------
/rag-foundation/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/rag-foundation/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # Pre-commit hooks
2 | default_language_version:
3 | python: python3
4 |
5 | repos:
6 | - repo: https://github.com/pre-commit/pre-commit-hooks
7 | rev: v4.5.0
8 | hooks:
9 | - id: check-ast
10 | - id: check-yaml
11 | - id: check-json
12 | - id: check-toml
13 | - id: check-case-conflict
14 | - id: check-docstring-first
15 | # - id: check-added-large-files
16 | - id: trailing-whitespace
17 | - id: detect-aws-credentials
18 | args: ["--allow-missing-credentials"]
19 | - id: detect-private-key
20 | - id: end-of-file-fixer
21 | - id: mixed-line-ending
22 |
23 | - repo: https://github.com/psf/black
24 | rev: 23.12.1
25 | hooks:
26 | - id: black
27 | name: PEP8 formatting
28 | args: [ --skip-string-normalization]
29 |
30 | - repo: https://github.com/PyCQA/isort
31 | rev: 5.13.2
32 | hooks:
33 | - id: isort
34 | name: I-sort imports
35 | args: ["--profile", "black"]
36 |
37 | - repo: https://github.com/PyCQA/flake8
38 | rev: 7.0.0
39 | hooks:
40 | - id: flake8
41 | name: PEP8 checker
42 |
43 | - repo: https://github.com/myint/autoflake
44 | rev: v2.2.1
45 | hooks:
46 | - id: autoflake
47 | args:
48 | [
49 | "--in-place",
50 | # "--remove-unused-variables",
51 | "--remove-all-unused-imports",
52 | "--ignore-init-module-imports",
53 | ]
54 |
--------------------------------------------------------------------------------
/rag-foundation/README.md:
--------------------------------------------------------------------------------
1 | # rag-foundation-exercise
2 |
3 | ## Installation
4 |
5 | **Note:** Prefer `python=3.10.*`
6 |
7 | ### 1. Fork the repo
8 |
9 | ### 2. Set up environment
10 | The instructions below assume that your forked repository is also named `ai-bootcamp-2024`.
11 |
12 | #### Windows
13 |
14 | - **Open Command Prompt.**
15 | - **Navigate to your project directory:**
16 |
17 | ```sh
18 | cd C:\Path\To\ai-bootcamp-2024
19 | ```
20 |
21 | - **Create a virtual environment using Python 3.10:**
22 |
23 | Check your python version first using `py -0` or `where python`
24 |
25 | ```sh
26 | python -m venv rag-foundation
27 | # or, using a specific Python 3.10 interpreter:
28 | path/to/python3.10 -m venv rag-foundation
29 | ```
30 |
31 | - **Activate the Virtual Environment:**
32 |
33 | ```sh
34 | rag-foundation\Scripts\activate
35 | ```
36 |
37 | #### Ubuntu/MacOS
38 |
39 | - **Open a terminal.**
40 | - **Create a new Conda environment with Python 3.10:**
41 |
42 | ```sh
43 | conda create --name rag-foundation python=3.10
44 | ```
45 |
46 | - **Activate the Conda Environment:**
47 |
48 | ```sh
49 | conda activate rag-foundation
50 | ```
51 |
52 | ### 3. **Install Required Packages:**
53 |
54 | - Install the required packages from `requirements.txt`:
55 |
56 | ```sh
57 | pip install -r requirements.txt
58 | ```
59 |
60 | ## Homework
61 |
62 | ### 1. **Fill your implementation**
63 |
64 | Search for the `"Your code here"` lines in the codebase; they mark where you should add your implementation.
65 |
66 | ### 2. **Run script**
67 |
68 | You should read the code in this repository carefully to understand the setup comprehensively.
69 |
70 | For example, you can run the script below to get results from the RAG pipeline you have built:
71 |
72 | ```sh
73 | python -m scripts.main \
74 | --data_path \
75 | --output_path predictions.jsonl \
76 | --mode \
77 | --force_index \
78 | --retrieval_only True \
79 | --top_k 5
80 | ```
81 |
82 | where the key arguments are (a worked example follows the list):
83 |
84 | - `mode`: `sparse` or `semantic`
85 | - `force_index`: `True` or `False` (True: overwrite the existing vectorstore index)
86 | - `retrieval_only`: `True` or `False` (True: return only the retrieval contexts; answers are left empty)
87 |
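For instance, a sparse, retrieval-only run could look like the sketch below (the `--data_path` value is an assumption; point it at your local copy of the QASPER dataset):

```sh
# data path is an assumption: use your local copy of the QASPER test file
python -m scripts.main \
    --data_path data/qasper-test-v0.3.json \
    --output_path predictions.jsonl \
    --mode sparse \
    --force_index True \
    --retrieval_only True \
    --top_k 5
```
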
88 | #### NOTE:
89 |
90 | To use LLM generation with the RAG pipeline, you can use ChatOpenAI by supplying `OPENAI_API_KEY` as an environment variable (assuming you have one).
91 | If you don't have access to the OpenAI API, use the Groq free tier instead:
92 |
93 | - Register an account at https://console.groq.com/keys (free)
94 | - Generate your API key
95 | - Assign env variable: `export GROQ_API_KEY=`
96 | - Run the main script without `--retrieval_only` to use the LLM (see the sketch below)
97 |
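A minimal end-to-end run with generation might then look like this sketch (the key value and the data path are placeholders/assumptions, not prescribed by the exercise):

```sh
# placeholder: substitute the key you generated; data path is an assumption
export GROQ_API_KEY=<your-groq-api-key>
python -m scripts.main \
    --data_path data/qasper-test-v0.3.json \
    --output_path predictions.jsonl \
    --mode semantic \
    --top_k 5
```
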
98 | ### 3. **Run Evaluation:**
99 | ```sh
100 | python evaluate.py --predictions predictions.jsonl --gold data/qasper-test-v0.3.json --retrieval_only
101 | ```
102 | $\rightarrow$ evaluates only the retrieval contexts.
103 |
104 | ```sh
105 | python evaluate.py --predictions predictions.jsonl --gold data/qasper-test-v0.3.json
106 | ```
107 | $\rightarrow$ evaluates both the retrieval contexts and the answers.
108 |
--------------------------------------------------------------------------------
/rag-foundation/data/llama2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/rag-foundation/data/llama2.pdf
--------------------------------------------------------------------------------
/rag-foundation/evaluate.py:
--------------------------------------------------------------------------------
1 | """
2 | Official script for evaluating models built for the Qasper dataset. The script
3 | outputs Answer F1 and Evidence F1 reported in the paper.
4 | """
5 |
6 | import argparse
7 | import json
8 | import re
9 | import string
10 | from collections import Counter
11 |
12 |
13 | def normalize_answer(s):
14 | """
15 | Taken from the official evaluation script for v1.1 of the SQuAD dataset.
16 | Lower text and remove punctuation, articles and extra whitespace.
17 | """
18 |
19 | def remove_articles(text):
20 | return re.sub(r"\b(a|an|the)\b", " ", text)
21 |
22 | def white_space_fix(text):
23 | return " ".join(text.split())
24 |
25 | def remove_punc(text):
26 | exclude = set(string.punctuation)
27 | return "".join(ch for ch in text if ch not in exclude)
28 |
29 | def lower(text):
30 | return text.lower()
31 |
32 | return white_space_fix(remove_articles(remove_punc(lower(s))))
33 |
34 |
35 | def token_f1_score(prediction, ground_truth):
36 | """
37 | Taken from the official evaluation script for v1.1 of the SQuAD dataset.
38 | """
39 | prediction_tokens = normalize_answer(prediction).split()
40 | ground_truth_tokens = normalize_answer(ground_truth).split()
41 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
42 | num_same = sum(common.values())
43 | if num_same == 0:
44 | return 0
45 | precision = 1.0 * num_same / len(prediction_tokens)
46 | recall = 1.0 * num_same / len(ground_truth_tokens)
47 | f1 = (2 * precision * recall) / (precision + recall)
48 | return f1
49 |
50 |
51 | def paragraph_f1_score(prediction, ground_truth):
52 | if not ground_truth and not prediction:
53 | # The question is unanswerable and the prediction is empty.
54 | return 1.0
55 | num_same = len(set(ground_truth).intersection(set(prediction)))
56 | if num_same == 0:
57 | return 0.0
58 | precision = num_same / len(prediction)
59 | recall = num_same / len(ground_truth)
60 | f1 = (2 * precision * recall) / (precision + recall)
61 | return f1
62 |
63 |
64 | def get_answers_and_evidence(data, text_evidence_only):
65 | answers_and_evidence = {}
66 | for paper_data in data.values():
67 | for qa_info in paper_data["qas"]:
68 | question_id = qa_info["question_id"]
69 | references = []
70 | for annotation_info in qa_info["answers"]:
71 | answer_info = annotation_info["answer"]
72 | if answer_info["unanswerable"]:
73 | references.append(
74 | {"answer": "Unanswerable", "evidence": [], "type": "none"}
75 | )
76 | else:
77 | if answer_info["extractive_spans"]:
78 | answer = ", ".join(answer_info["extractive_spans"])
79 | answer_type = "extractive"
80 | elif answer_info["free_form_answer"]:
81 | answer = answer_info["free_form_answer"]
82 | answer_type = "abstractive"
83 | elif answer_info["yes_no"]:
84 | answer = "Yes"
85 | answer_type = "boolean"
86 | elif answer_info["yes_no"] is not None:
87 | answer = "No"
88 | answer_type = "boolean"
89 | else:
90 | raise RuntimeError(
91 | f"Annotation {answer_info['annotation_id']} does not contain an answer"
92 | )
93 | if text_evidence_only:
94 | evidence = [
95 | text
96 | for text in answer_info["evidence"]
97 | if "FLOAT SELECTED" not in text
98 | ]
99 | else:
100 | evidence = answer_info["evidence"]
101 | references.append(
102 | {"answer": answer, "evidence": evidence, "type": answer_type}
103 | )
104 | answers_and_evidence[question_id] = references
105 |
106 | return answers_and_evidence
107 |
108 |
109 | def evaluate(gold, predicted, retrieval_only=False):
110 | max_answer_f1s = []
111 | max_evidence_f1s = []
112 | max_answer_f1s_by_type = {
113 | "extractive": [],
114 | "abstractive": [],
115 | "boolean": [],
116 | "none": [],
117 | }
118 | num_missing_predictions = 0
119 | for question_id, references in gold.items():
120 | if question_id not in predicted:
121 | num_missing_predictions += 1
122 | max_answer_f1s.append(0.0)
123 | max_evidence_f1s.append(0.0)
124 | continue
125 | answer_f1s_and_types = [
126 | (
127 | token_f1_score(predicted[question_id]["answer"], reference["answer"]),
128 | reference["type"],
129 | )
130 | for reference in gold[question_id]
131 | ]
132 | max_answer_f1, answer_type = sorted(
133 | answer_f1s_and_types, key=lambda x: x[0], reverse=True
134 | )[0]
135 | max_answer_f1s.append(max_answer_f1)
136 | max_answer_f1s_by_type[answer_type].append(max_answer_f1)
137 | evidence_f1s = [
138 | paragraph_f1_score(
139 | predicted[question_id]["evidence"], reference["evidence"]
140 | )
141 | for reference in gold[question_id]
142 | ]
143 | max_evidence_f1s.append(max(evidence_f1s))
144 |
145 | mean = lambda x: sum(x) / len(x) if x else 0.0
146 |
147 | if not retrieval_only:
148 | return {
149 | "Answer F1": mean(max_answer_f1s),
150 | "Answer F1 by type": {
151 | key: mean(value) for key, value in max_answer_f1s_by_type.items()
152 | },
153 | "Evidence F1": mean(max_evidence_f1s),
154 | "Missing predictions": num_missing_predictions,
155 | }
156 | else:
157 | return {
158 | "Evidence F1": mean(max_evidence_f1s),
159 | }
160 |
161 |
162 | if __name__ == "__main__":
163 | parser = argparse.ArgumentParser()
164 | parser.add_argument(
165 | "--predictions",
166 | type=str,
167 | required=True,
168 | help="""JSON lines file with each line in format:
169 | {'question_id': str, 'predicted_answer': str, 'predicted_evidence': List[str]}""",
170 | )
171 | parser.add_argument(
172 | "--gold",
173 | type=str,
174 | required=True,
175 | help="Test or dev set from the released dataset",
176 | )
177 | parser.add_argument(
178 | "--retrieval_only",
179 | help="If set, the evaluator will just evaluate the retrieval scores",
180 | action="store_true",
181 | )
182 | parser.add_argument(
183 | "--text_evidence_only",
184 | action="store_true",
185 | help="If set, the evaluator will ignore evidence in figures and tables while reporting evidence f1",
186 | )
187 | args = parser.parse_args()
188 | gold_data = json.load(open(args.gold))
189 | gold_answers_and_evidence = get_answers_and_evidence(
190 | gold_data, args.text_evidence_only
191 | )
192 | predicted_answers_and_evidence = {}
193 | for line in open(args.predictions):
194 | prediction_data = json.loads(line)
195 | predicted_answers_and_evidence[prediction_data["question_id"]] = {
196 | "answer": prediction_data["predicted_answer"],
197 | "evidence": prediction_data["predicted_evidence"],
198 | }
199 | evaluation_output = evaluate(
200 | gold_answers_and_evidence,
201 | predicted_answers_and_evidence,
202 | retrieval_only=args.retrieval_only,
203 | )
204 | print(json.dumps(evaluation_output, indent=2))
205 |
--------------------------------------------------------------------------------
/rag-foundation/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp==3.9.5
2 | aiosignal==1.3.1
3 | annotated-types==0.7.0
4 | anyio==4.4.0
5 | async-timeout==4.0.3
6 | attrs==23.2.0
7 | beautifulsoup4==4.12.3
8 | certifi==2024.7.4
9 | charset-normalizer==3.3.2
10 | click==8.1.7
11 | dataclasses-json==0.6.7
12 | Deprecated==1.2.14
13 | dirtyjson==1.0.8
14 | distro==1.9.0
15 | exceptiongroup==1.2.2
16 | filelock==3.15.4
17 | frozenlist==1.4.1
18 | fsspec==2024.6.1
19 | greenlet==3.0.3
20 | h11==0.14.0
21 | httpcore==1.0.5
22 | httpx==0.27.0
23 | huggingface-hub==0.23.4
24 | idna==3.7
25 | Jinja2==3.1.4
26 | joblib==1.4.2
27 | llama-cloud==0.0.9
28 | llama-index==0.10.55
29 | llama-index-agent-openai==0.2.8
30 | llama-index-cli==0.1.12
31 | llama-index-core==0.10.55
32 | llama-index-embeddings-openai==0.1.10
33 | llama-index-indices-managed-llama-cloud==0.2.5
34 | llama-index-legacy==0.9.48
35 | llama-index-llms-ollama==0.1.5
36 | llama-index-llms-openai==0.1.25
37 | llama-index-multi-modal-llms-openai==0.1.7
38 | llama-index-program-openai==0.1.6
39 | llama-index-question-gen-openai==0.1.3
40 | llama-index-readers-file==0.1.30
41 | llama-index-readers-llama-parse==0.1.6
42 | llama-parse==0.4.7
43 | loguru==0.7.2
44 | MarkupSafe==2.1.5
45 | marshmallow==3.21.3
46 | mpmath==1.3.0
47 | multidict==6.0.5
48 | mypy-extensions==1.0.0
49 | nest-asyncio==1.6.0
50 | networkx==3.3
51 | nltk==3.8.1
52 | numpy==1.26.4
53 | openai==1.35.13
54 | packaging==24.1
55 | pandas==2.2.2
56 | pillow==10.4.0
57 | pydantic==2.8.2
58 | pydantic_core==2.20.1
59 | PyMuPDF==1.24.7
60 | PyMuPDFb==1.24.6
61 | pypdf==4.3.0
62 | python-dateutil==2.9.0.post0
63 | pytz==2024.1
64 | PyYAML==6.0.1
65 | regex==2024.5.15
66 | requests==2.32.3
67 | safetensors==0.4.3
68 | scikit-learn==1.5.1
69 | scipy==1.14.0
70 | sentence-transformers==3.0.1
71 | setuptools==69.5.1
72 | six==1.16.0
73 | sniffio==1.3.1
74 | soupsieve==2.5
75 | SQLAlchemy==2.0.31
76 | striprtf==0.0.26
77 | sympy==1.13.0
78 | tenacity==8.5.0
79 | threadpoolctl==3.5.0
80 | tiktoken==0.7.0
81 | tokenizers==0.19.1
82 | torch==2.3.1
83 | tqdm==4.66.4
84 | transformers==4.42.4
85 | typing-inspect==0.9.0
86 | typing_extensions==4.12.2
87 | tzdata==2024.1
88 | urllib3==2.2.2
89 | wheel==0.43.0
90 | wrapt==1.16.0
91 | yarl==1.9.4
92 | fire
93 | langchain-openai
94 | langchain-groq
95 |
--------------------------------------------------------------------------------
/rag-foundation/sample_predictions.jsonl:
--------------------------------------------------------------------------------
1 | {"question_id": "397a1e851aab41c455c2b284f5e4947500d797f0", "predicted_answer": "The ANTISCAM dataset consists of 220 human-human dialogs collected from a typing conversation task on the Amazon Mechanical Turk platform.", "predicted_evidence": ["So we count the dialog length as another metric to evaluate system performance.\n\nTask Success Score (TaskSuc) The other goal of the anti-scam system is to elicit attacker's personal information. We count the average type of information (name, address and phone number) that the system obtained from attackers as the task success score.\n\nTable TABREF19 presents the main experiment results on AntiScam dataset, for both automatic evaluation metrics and human evaluation metrics. The experiment results on PersuasionForGood are shown in Table TABREF23. We observe that MISSA outperforms two baseline models (TransferTransfo and hybrid model) on almost all the metrics on both datasets. For further analysis, examples of real dialogs from the human evaluation are presented in Table TABREF21.\n\nCompared to the first TransferTransfo baseline, MISSA outperforms the TransferTransfo baseline on the on-task contents. From Table TABREF19, we observe that MISSA maintains longer conversations (14.9 turns) compared with TransferTransfo (8.5 turns), which means MISSA is better at maintaining the attacker's engagement. MISSA also has a higher task success score (1.294) than TransferTransfo (1.025), which indicates that it elicits information more strategically. In the top two dialogs (A and B) that are shown in Table TABREF21, both attackers were eliciting a credit card number in their first turns. TransferTransfo directly gave away the information, while MISSA replied with a semantically-related question \u201cwhy would you need my credit card number?\" Furthermore, in the next turn, TransferTransfo ignored the context and asked an irrelevant question \u201cwhat is your name?\u201d while MISSA was able to generate the response \u201cwhy can't you use my address?\u201d, which is consistent to the context. We suspect the improved performance of MISSA comes from our proposed annotation scheme: the semantic slot information enables MISSA to keep track of the current entities, and the intent information helps MISSA to maintain coherency and prolong conversations.\n\nCompared to the hybrid model baseline, MISSA performs better on off-task content. As shown in the bottom two dialogs in Table TABREF21, attackers in both dialogs introduced their names in their first utterances. MISSA recognized attacker's name, while the hybrid model did not. We suspect it is because the hybrid model does not have the built-in semantic slot predictor.", "MISSA is based on the generative pre-trained transformer BIBREF32. We use an Adam optimizer with a learning rate of 6.25e-5 and $L2$ weight decay of $0.01$, we set the coefficient of language modeling loss to be 2, the coefficient of intent and slot classifiers to be 1, and the coefficient of next-utterance classifier to be 1. We first pre-train the model on the PERSONA-CHAT dataset. When fine-tuning on the AntiScam and the PersuasionForGood datasets, we use $80\\%$ data for training, $10\\%$ data for validation, and $10\\%$ data for testing. Since the original PersuasionForGood dataset is annotated with intents, we separate the original on-task and off-task intents, which are shown in Table TABREF2. 
To deal with the words out of the vocabulary, we conduct delexicalization to replace slot values with corresponding slot tokens during the training phase, and replace the slot tokens with pre-defined information during testing.\n\nAn example of human-human chat on AntiScam dataset is shown in Table TABREF25.", "MISSA follows the TransferTransfo framework BIBREF0 with three modifications: (i) We first concurrently predict user's, system's intents and semantic slots; (ii) We then perform conditional generation to improve generated response's coherence. Specifically, we generate responses conditioned on the above intermediate representation (intents and slots); (iii) Finally, we generate multiple responses with the nucleus sampling strategy BIBREF5 and then apply a response filter, which contains a set of pre-defined constraints to select coherent responses. The constraints in the filter can be defined according to specific task requirements or general conversational rules.\n\nTo enrich publicly available non-collaborative task datasets, we collect a new dataset AntiScam, where users defend themselves against attackers trying to collect personal information. As non-collaborative tasks are still relatively new to the study of dialog systems, there are insufficiently many meaningful datasets for evaluation and we hope this provides a valuable example. We evaluate MISSA on the newly collected AntiScam dataset and an existing PersuasionForGood dataset. Both automatic and human evaluations suggest that MISSA outperforms multiple competitive baselines.\n\nIn summary, our contributions include: (i) We design a hierarchical intent annotation scheme and a semantic slot annotation scheme to annotate the non-collaborative dialog dataset, we also propose a carefully-designed AntiScam dataset to facilitate the research of non-collaborative dialog systems. (ii) We propose a model that can be applied to all non-collaborative tasks, outperforming other baselines on two different non-collaborative tasks. (iii) We develop an anti-scam dialog system to occupy attacker's attention and elicit their private information for social good. Furthermore, we also build a persuasion dialog system to persuade people to donate to charities. We release the code and data.\n\nThe interest in non-collaborative tasks has been increasing and there have already been several related datasets. For instance, BIBREF1 wang2019persuasion collected conversations where one participant persuades another to donate to a charity. BIBREF2 he2018decoupling collected negotiation dialogs where buyers and sellers bargain for items for sale on Craigslist. There are many other non-collaborative tasks, such as the turn-taking game BIBREF6, the multi-party game BIBREF7 and item splitting negotiation BIBREF8.", "We posted a role-playing task on the Amazon Mechanical Turk platform and collected a typing conversation dataset named AntiScam. We collected 220 human-human dialogs. The average conversation length is 12.45 turns and the average utterance length is 11.13 words. Only 172 out of 220 users successfully identified their partner as an attacker, suggesting that the attackers are well trained and not too easily identifiable. We recruited two expert annotators who have linguistic training to annotate 3,044 sentences in 100 dialogs, achieving a 0.874 averaged weighted kappa value.\n\nThe PersuasionForGood dataset BIBREF1 was collected from typing conversations on Amazon Mechanical Turk platform. 
Two workers were randomly paired, one was assigned the role of persuader, the other was persuadee. The goal of the persuader was to persuade the persuadee to donate a portion of task earning to a specific charity. The dataset consists of 1,017 dialogs, where 300 dialogs are annotated with dialog acts. The average conversation length is 10.43, the vocabulary size is 8,141. Since the original PersuasionForGood dataset is annotated with dialog acts, we select the on-task dialog acts as on-task intents shown in Table TABREF2, and categorize the other dialog acts into our pre-defined off-task intents.\n\nThe TransferTransfo framework was proposed to build open domain dialog systems. BIBREF0 wolf2019transfertransfo fine-tuned the generative pre-training model (GPT) BIBREF32 with the PERSONA-CHAT dataset BIBREF33 in a multi-task fashion, where the language model objective is combined with a next-utterance classification task. The language model's objective is to maximize the following likelihood for a given sequence of tokens, $X = \\lbrace x_1,\\dots ,x_n\\rbrace $:\n\nThe authors also trained a classifier to distinguish the correct next-utterance appended to the input human utterances from a set of randomly selected utterance distractors. In addition, they introduced dialog state embeddings to indicate speaker role in the model. The model significantly outperformed previous baselines over both automatic evaluations and human evaluations in social conversations. Since the TransferTransfo framework performs well in open domain, we adapt it for non-collaborative settings.", "We suspect the underlying reason is that there are more possible responses with the same intent in PersuasionForGood than in AntiScam. This also suggests that we should adjust the model structure according to the nature of the dataset.\n\nWe propose a general dialog system pipeline to build non-collaborative dialog systems, including a hierarchical annotation scheme and an end-to-end neural response generation model called MISSA. With the hierarchical annotation scheme, we can distinguish on-task and off-task intents. MISSA takes both on and off-task intents as supervision in its training and thus can deal with diverse user utterances in non-collaborative settings. Moreover, to validate MISSA's performance, we create a non-collaborate dialog dataset that focuses on deterring phone scammers. MISSA outperforms all baseline methods in terms of fluency, coherency, and user engagement on both the newly proposed anti-scam task and an existing persuasion task. However, MISSA still produces responses that are not consistent with their distant conversation history as GPT can only track a limited history span. In future work, we plan to address this issue by developing methods that can effectively track longer dialog context.\n\nThis work was supported by DARPA ASED Program HR001117S0050. The U.S. Government is authorized to reproduce and distribute reprints for governmental purposes not withstanding any copyright notation therein. The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies, either expressed or implied, of DARPA or the U.S. Government.\n\nWe randomly pair two workers: one is assigned the role of the attacker to elicit user information, and the other one is assigned the role of an everyday user who aims to protect her/his information and potentially elicit the attacker's information. We give both workers specific personal data. 
Instructions are shown in Table TABREF24. The \u201cattacker\u201d additionally receives training on how to elicit information from people. Workers cannot see their partners' instructions.\n\nThere are two tasks for the users: firstly, users are required to chat with their partners and determine if they are attackers or not, reporting their decisions at the end of the task. If users think their partners are attackers, they are instructed to prolong the conversation and elicit information from their partners."]}
2 | {"question_id": "cc8b4ed3985f9bfbe1b5d7761b31d9bd6a965444", "predicted_answer": "The ANTISCAM dataset consists of 220 human-human dialogs collected from a typing conversation task on the Amazon Mechanical Turk platform.", "predicted_evidence": ["The intent predictor achieves a $84\\%$ accuracy and the semantic slot predictor achieves $77\\%$ on the AntiScam dataset. Then we compare the predicted values with human-annotated ground truth in the dataset to compute the response-intent prediction (RIP) and response-slot prediction (RSP).\n\nExtended Response-Intent Prediction (ERIP) $\\&$ Extended Response-Slot Prediction (ERSP) With Response-Intent Prediction, we verify the predicted intents to evaluate the coherence of the dialog. However, the real mapping between human-intent and system-intent is much more complicated as there might be multiple acceptable system-intents for the same human-intent. Therefore, we also design a metric to evaluate if the predicted system-intent is in the set of acceptable intents. Specifically, we estimate the transition probability $p(I_i|I_j)$ by counting the frequency of all the bi-gram human-intent and system-intent pairs in the training data. During the test stage, if the predicted intent matches the ground truth, we set the score as 1, otherwise we set the score as $p(I_{predict}|I_i)$ where $I_i$ is the intent of the input human utterance. We then report the average value of those scores over turns as the final extended response-intent prediction result.\n\nAutomatic metrics only validate the system\u2019s performance on a single dimension at a time. The ultimate holistic evaluation should be conducted by having the trained system interact with human users. Therefore we also conduct human evaluations for the dialog system built on AntiScam. We test our models and baselines with 15 college-student volunteers. Each of them is asked to pretend to be an attacker and interact with all the models for at least three times to avoid randomness. We in total collect 225 number of dialogs. Each time, volunteers are required to use similar sentences and strategies to interact with all five models and score each model based on the metrics listed below at the end of the current round. Each model receives a total of 45 human ratings, and the average score is reported as the final human-evaluation score. In total, we design five different metrics to assess the models' conversational ability whilst interacting with humans. The results are shown in Table TABREF19.\n\nFluency Fluency is used to explore different models' language generation quality.", "Compared with these works, MISSA is end-to-end trainable and thus easier to train and update.\n\nTo decouple syntactic and semantic information in utterances and provide detailed supervision, we design a hierarchical intent annotation scheme for non-collaborative tasks. We first separate on-task and off-task intents. As on-task intents are key actions that can vary among different tasks, we need to specifically define on-task intents for each task. On the other hand, since off-task content is too general to design task-specific intents, we choose common dialog acts as the categories. The advantage of this hierarchical annotation scheme is apparent when starting a new non-collaborative task: we only need to focus on designing the on-task categories and semantic slots which are the same as traditional task-oriented dialog systems. 
Consequently, we don't have to worry about the off-task annotation design since the off-task category is universal.\n\nIn the intent annotation scheme shown in Table TABREF2, we list the designed intent annotation scheme for the newly collected AntiScam dataset and the PersuasionForGood dataset. We first define on-task intents for the datasets, which are key actions in the task. Since our AntiScam focuses on understanding and reacting towards elicitations, we define elicitation, providing_information and refusal as on-task intents. In the PersuasionForGood dataset, we define nine on-task intents in Table TABREF2 based on the original PersuasionForGood dialog act annotation scheme. All these intents are related to donation actions, which are salient on-task intents in the persuasion task. The off-task intents are the same for both tasks, including six general intents and six additional social intents. General intents are more closely related to the syntactic meaning of the sentence (open_question, yes_no_question, positive_answer, negative_answer, responsive_statement, and nonresponsive_statement) while social intents are common social actions (greeting, closing, apology, thanking,respond_to_thank, and hold).\n\nFor specific tasks, we also design a semantic slot annotation scheme for annotating sentences based on their semantic content. We identify 13 main semantic slots in the anti-scam task, for example, credit card numbers. We present a detailed semantic slot annotation in Table TABREF3. Following BIBREF1, we segment each conversation turn into single sentences and then annotate each sentence rather than turns.", "Therefore, we need to design a system that handles both on-task and off-task information appropriately and in a way that leads back to the system's goal.\n\nTo tackle the issue of incoherent system responses to off-task content, previous studies have built hybrid systems to interleave off-task and on-task content. BIBREF4 used a rule-based dialog manager for on-task content and a neural model for off-task content, and trained a reinforcement learning model to select between these two models based on the dialog context. However, such a method is difficult to train and struggles to generalize beyond the movie promotion task they considered. To tackle these problems, we propose a hierarchical intent annotation scheme that separates on-task and off-task information in order to provide detailed supervision. For on-task information, we directly use task-related intents for representation. Off-task information, on the other hand, is too general to categorize into specific intents, so we choose dialog acts that convey syntax information. These acts, such as \u201copen question\" are general to all tasks.\n\nPrevious studies use template-based methods to maintain sentence coherence. However, rigid templates lead to limited diversity, causing the user losing engagement. On the other hand, language generation models can generate diverse responses but are bad at being coherent. We propose Multiple Intents and Semantic Slots Annotation Neural Network (MISSA) to combine the advantages of both template and generation models and takes advantage from the hierarchical annotation at the same time. MISSA follows the TransferTransfo framework BIBREF0 with three modifications: (i) We first concurrently predict user's, system's intents and semantic slots; (ii) We then perform conditional generation to improve generated response's coherence. 
Specifically, we generate responses conditioned on the above intermediate representation (intents and slots); (iii) Finally, we generate multiple responses with the nucleus sampling strategy BIBREF5 and then apply a response filter, which contains a set of pre-defined constraints to select coherent responses. The constraints in the filter can be defined according to specific task requirements or general conversational rules.\n\nTo enrich publicly available non-collaborative task datasets, we collect a new dataset AntiScam, where users defend themselves against attackers trying to collect personal information. As non-collaborative tasks are still relatively new to the study of dialog systems, there are insufficiently many meaningful datasets for evaluation and we hope this provides a valuable example.", "BIBREF9 hardy2002multi followed the DAMSL schemeBIBREF10 and annotated a multilingual human-computer dialog corpus with a hierarchical dialog act annotation scheme. BIBREF11 gupta2018semantic used a hierarchical annotation scheme for semantic parsing. Inspired by these studies, our idea is to annotate the intent and semantic slot separately in non-collaborative tasks. We propose a hierarchical intent annotation scheme that can be adopted by all non-collaborative tasks. With this annotation scheme, MISSA is able to quickly build an end-to-end trainable dialog system for any non-collaborative task.\n\nTraditional task-oriented dialog systems BIBREF12 are usually composed of multiple independent modules, for example, natural language understanding, dialog state tracking BIBREF13, BIBREF14, dialog policy manager BIBREF15, and natural language generation BIBREF16. Conversational intent is adopted to capture the meaning of task content in these dialog systems BIBREF2, BIBREF17. In comparison to this work, we use a hierarchical intent scheme that includes off-task and on-task intents to capture utterance meaning. We also train the model in a multi-task fashion to predict decoupled intents and semantic slots. The major defect of a separately trained pipeline is the laborious dialog state design and annotation. In order to mitigate this problem, recent work has explored replacing independent modules with end-to-end neural networks BIBREF18, BIBREF19, BIBREF20. Our model also follows this end-to-end fashion.\n\nOver the last few years, we have witnessed a huge growth in non-task-oriented dialog systems BIBREF21, BIBREF22. Social chatbots such as Gunrock BIBREF23 were able to maintain a conversation for around ten minutes in an open domain. Recent improvements build on top of the transformer and pre-trained language models BIBREF24, BIBREF25, BIBREF26, obtained state-of-the-art results on the Persona-Chat dataset BIBREF0. Pre-trained language models are proposed to build task-oriented dialog systems to drive the progress on leveraging large amounts of available unannotated data. BIBREF27. Similarly, our approach is also built on top of the TransferTransfo framework BIBREF0. BIBREF27 budzianowski2019hello focused on collaborative tasks BIBREF28.", "All these intents are related to donation actions, which are salient on-task intents in the persuasion task. The off-task intents are the same for both tasks, including six general intents and six additional social intents. 
General intents are more closely related to the syntactic meaning of the sentence (open_question, yes_no_question, positive_answer, negative_answer, responsive_statement, and nonresponsive_statement) while social intents are common social actions (greeting, closing, apology, thanking,respond_to_thank, and hold).\n\nFor specific tasks, we also design a semantic slot annotation scheme for annotating sentences based on their semantic content. We identify 13 main semantic slots in the anti-scam task, for example, credit card numbers. We present a detailed semantic slot annotation in Table TABREF3. Following BIBREF1, we segment each conversation turn into single sentences and then annotate each sentence rather than turns.\n\nWe test our approach on two non-collaborative task datasets: the AntiScam dataset and the PersuasionForGood dataset BIBREF1. Both datasets are collected from the Amazon Mechanical Turk platform in the form of typing conversations and off-task dialog is interleaved in the dialog.\n\nTo enrich available non-collaborative task datasets, we created a corpus of human-human anti-scam dialogs in order to learn human elicitation strategies. We chose a popular Amazon customer service scam scenario to collect dialogs between users and attackers who aim to collect users information. We posted a role-playing task on the Amazon Mechanical Turk platform and collected a typing conversation dataset named AntiScam. We collected 220 human-human dialogs. The average conversation length is 12.45 turns and the average utterance length is 11.13 words. Only 172 out of 220 users successfully identified their partner as an attacker, suggesting that the attackers are well trained and not too easily identifiable. We recruited two expert annotators who have linguistic training to annotate 3,044 sentences in 100 dialogs, achieving a 0.874 averaged weighted kappa value.\n\nThe PersuasionForGood dataset BIBREF1 was collected from typing conversations on Amazon Mechanical Turk platform. Two workers were randomly paired, one was assigned the role of persuader, the other was persuadee. The goal of the persuader was to persuade the persuadee to donate a portion of task earning to a specific charity."]}
3 | {"question_id": "f7662b11e87c1e051e13799413f3db459ac3e19c", "predicted_answer": "The ANTISCAM dataset consists of 220 human-human dialogs collected from a typing conversation task on the Amazon Mechanical Turk platform.", "predicted_evidence": ["MISSA is based on the generative pre-trained transformer BIBREF32. We use an Adam optimizer with a learning rate of 6.25e-5 and $L2$ weight decay of $0.01$, we set the coefficient of language modeling loss to be 2, the coefficient of intent and slot classifiers to be 1, and the coefficient of next-utterance classifier to be 1. We first pre-train the model on the PERSONA-CHAT dataset. When fine-tuning on the AntiScam and the PersuasionForGood datasets, we use $80\\%$ data for training, $10\\%$ data for validation, and $10\\%$ data for testing. Since the original PersuasionForGood dataset is annotated with intents, we separate the original on-task and off-task intents, which are shown in Table TABREF2. To deal with the words out of the vocabulary, we conduct delexicalization to replace slot values with corresponding slot tokens during the training phase, and replace the slot tokens with pre-defined information during testing.\n\nAn example of human-human chat on AntiScam dataset is shown in Table TABREF25.", "The results are shown in Table TABREF19. We find that MISSA has higher fluency score and coherence score than MISSA-con (4.18 vs 3.78 for fluency, and 3.75 vs 3.68 for coherence), which suggests that conditioning on the system intent to generate responses improves the quality of the generated sentences. Compared with MISSA-sel, MISSA achieves better performance on all the metrics. For example, the engagement score for MISSA is 3.69 while MISSA-sel only has 2.87. This is because the response filter removed all the incoherent responses, which makes the attacker more willing to keep chatting. The ablation study shows both the conditional language generation mechanism and the response filter are essential to MISSA's good performance.\n\nWe also apply our method to the PersuasionForGood dataset. As shown in Table TABREF23, MISSA and its variants outperform the TransferTransfo and the hybrid models on all evaluation metrics. Such good performance indicates MISSA can be easily applied to a different non-collaborative task and achieve good performance. Particularly, MISSA achieves the lowest perplexity, which confirms that using conditional response generation leads to high quality responses. Compared with the result on AntiScam dataset, MISSA-con performs the best in terms of RIP and ERIP. We suspect the underlying reason is that there are more possible responses with the same intent in PersuasionForGood than in AntiScam. This also suggests that we should adjust the model structure according to the nature of the dataset.\n\nWe propose a general dialog system pipeline to build non-collaborative dialog systems, including a hierarchical annotation scheme and an end-to-end neural response generation model called MISSA. With the hierarchical annotation scheme, we can distinguish on-task and off-task intents. MISSA takes both on and off-task intents as supervision in its training and thus can deal with diverse user utterances in non-collaborative settings. Moreover, to validate MISSA's performance, we create a non-collaborate dialog dataset that focuses on deterring phone scammers. 
MISSA outperforms all baseline methods in terms of fluency, coherency, and user engagement on both the newly proposed anti-scam task and an existing persuasion task.", "Furthermore, in the next turn, TransferTransfo ignored the context and asked an irrelevant question \u201cwhat is your name?\u201d while MISSA was able to generate the response \u201cwhy can't you use my address?\u201d, which is consistent to the context. We suspect the improved performance of MISSA comes from our proposed annotation scheme: the semantic slot information enables MISSA to keep track of the current entities, and the intent information helps MISSA to maintain coherency and prolong conversations.\n\nCompared to the hybrid model baseline, MISSA performs better on off-task content. As shown in the bottom two dialogs in Table TABREF21, attackers in both dialogs introduced their names in their first utterances. MISSA recognized attacker's name, while the hybrid model did not. We suspect it is because the hybrid model does not have the built-in semantic slot predictor. In the second turn, both attackers were explaining the reason of requesting the billing address previously. With semantic slot information, MISSA can easily understand the attacker; but the hybrid model misunderstands that the attacker was talking about the order number, possibly because the token \u201corder\u201d appeared in the attacker's utterance. We suspect that the hybrid model's bad performance on the off-task content leads to its low coherence rating (2.76) and short dialog length (8.2).\n\nTo explore the influence of the intent-based conditional response generation method and the designed response filter, we perform an ablation study. The results are shown in Table TABREF19. We find that MISSA has higher fluency score and coherence score than MISSA-con (4.18 vs 3.78 for fluency, and 3.75 vs 3.68 for coherence), which suggests that conditioning on the system intent to generate responses improves the quality of the generated sentences. Compared with MISSA-sel, MISSA achieves better performance on all the metrics. For example, the engagement score for MISSA is 3.69 while MISSA-sel only has 2.87. This is because the response filter removed all the incoherent responses, which makes the attacker more willing to keep chatting. The ablation study shows both the conditional language generation mechanism and the response filter are essential to MISSA's good performance.\n\nWe also apply our method to the PersuasionForGood dataset.", "MISSA follows the TransferTransfo framework BIBREF0 with three modifications: (i) We first concurrently predict user's, system's intents and semantic slots; (ii) We then perform conditional generation to improve generated response's coherence. Specifically, we generate responses conditioned on the above intermediate representation (intents and slots); (iii) Finally, we generate multiple responses with the nucleus sampling strategy BIBREF5 and then apply a response filter, which contains a set of pre-defined constraints to select coherent responses. The constraints in the filter can be defined according to specific task requirements or general conversational rules.\n\nTo enrich publicly available non-collaborative task datasets, we collect a new dataset AntiScam, where users defend themselves against attackers trying to collect personal information. 
As non-collaborative tasks are still relatively new to the study of dialog systems, there are insufficiently many meaningful datasets for evaluation and we hope this provides a valuable example. We evaluate MISSA on the newly collected AntiScam dataset and an existing PersuasionForGood dataset. Both automatic and human evaluations suggest that MISSA outperforms multiple competitive baselines.\n\nIn summary, our contributions include: (i) We design a hierarchical intent annotation scheme and a semantic slot annotation scheme to annotate the non-collaborative dialog dataset, we also propose a carefully-designed AntiScam dataset to facilitate the research of non-collaborative dialog systems. (ii) We propose a model that can be applied to all non-collaborative tasks, outperforming other baselines on two different non-collaborative tasks. (iii) We develop an anti-scam dialog system to occupy attacker's attention and elicit their private information for social good. Furthermore, we also build a persuasion dialog system to persuade people to donate to charities. We release the code and data.\n\nThe interest in non-collaborative tasks has been increasing and there have already been several related datasets. For instance, BIBREF1 wang2019persuasion collected conversations where one participant persuades another to donate to a charity. BIBREF2 he2018decoupling collected negotiation dialogs where buyers and sellers bargain for items for sale on Craigslist. There are many other non-collaborative tasks, such as the turn-taking game BIBREF6, the multi-party game BIBREF7 and item splitting negotiation BIBREF8.", "So we count the dialog length as another metric to evaluate system performance.\n\nTask Success Score (TaskSuc) The other goal of the anti-scam system is to elicit attacker's personal information. We count the average type of information (name, address and phone number) that the system obtained from attackers as the task success score.\n\nTable TABREF19 presents the main experiment results on AntiScam dataset, for both automatic evaluation metrics and human evaluation metrics. The experiment results on PersuasionForGood are shown in Table TABREF23. We observe that MISSA outperforms two baseline models (TransferTransfo and hybrid model) on almost all the metrics on both datasets. For further analysis, examples of real dialogs from the human evaluation are presented in Table TABREF21.\n\nCompared to the first TransferTransfo baseline, MISSA outperforms the TransferTransfo baseline on the on-task contents. From Table TABREF19, we observe that MISSA maintains longer conversations (14.9 turns) compared with TransferTransfo (8.5 turns), which means MISSA is better at maintaining the attacker's engagement. MISSA also has a higher task success score (1.294) than TransferTransfo (1.025), which indicates that it elicits information more strategically. In the top two dialogs (A and B) that are shown in Table TABREF21, both attackers were eliciting a credit card number in their first turns. TransferTransfo directly gave away the information, while MISSA replied with a semantically-related question \u201cwhy would you need my credit card number?\" Furthermore, in the next turn, TransferTransfo ignored the context and asked an irrelevant question \u201cwhat is your name?\u201d while MISSA was able to generate the response \u201cwhy can't you use my address?\u201d, which is consistent to the context. 
We suspect the improved performance of MISSA comes from our proposed annotation scheme: the semantic slot information enables MISSA to keep track of the current entities, and the intent information helps MISSA to maintain coherency and prolong conversations.\n\nCompared to the hybrid model baseline, MISSA performs better on off-task content. As shown in the bottom two dialogs in Table TABREF21, attackers in both dialogs introduced their names in their first utterances. MISSA recognized attacker's name, while the hybrid model did not. We suspect it is because the hybrid model does not have the built-in semantic slot predictor."]}
4 | {"question_id": "b584739622d0c53830e60430b13fd3ae6ff43669", "predicted_answer": "The ANTISCAM dataset consists of 220 human-human dialogs collected from a role-playing task on the Amazon Mechanical Turk platform.", "predicted_evidence": ["The ultimate holistic evaluation should be conducted by having the trained system interact with human users. Therefore we also conduct human evaluations for the dialog system built on AntiScam. We test our models and baselines with 15 college-student volunteers. Each of them is asked to pretend to be an attacker and interact with all the models for at least three times to avoid randomness. We in total collect 225 number of dialogs. Each time, volunteers are required to use similar sentences and strategies to interact with all five models and score each model based on the metrics listed below at the end of the current round. Each model receives a total of 45 human ratings, and the average score is reported as the final human-evaluation score. In total, we design five different metrics to assess the models' conversational ability whilst interacting with humans. The results are shown in Table TABREF19.\n\nFluency Fluency is used to explore different models' language generation quality.\n\nCoherence Different from single sentence's fluency, coherence focuses more on the logical consistency between sentences in each turn.\n\nEngagement In the anti-scam scenario, one of our missions is to keep engaging with the attackers to waste their time. So we directly ask volunteers (attackers) to what extend they would like to continue chatting with the system.\n\nDialog length (Length) Engagement is a subjective metric. Anti-scam system's goal is to engage user in the conversation longer in order to limit their harm to other potential victims. So we count the dialog length as another metric to evaluate system performance.\n\nTask Success Score (TaskSuc) The other goal of the anti-scam system is to elicit attacker's personal information. We count the average type of information (name, address and phone number) that the system obtained from attackers as the task success score.\n\nTable TABREF19 presents the main experiment results on AntiScam dataset, for both automatic evaluation metrics and human evaluation metrics. The experiment results on PersuasionForGood are shown in Table TABREF23. We observe that MISSA outperforms two baseline models (TransferTransfo and hybrid model) on almost all the metrics on both datasets. For further analysis, examples of real dialogs from the human evaluation are presented in Table TABREF21.\n\nCompared to the first TransferTransfo baseline, MISSA outperforms the TransferTransfo baseline on the on-task contents.", "So we count the dialog length as another metric to evaluate system performance.\n\nTask Success Score (TaskSuc) The other goal of the anti-scam system is to elicit attacker's personal information. We count the average type of information (name, address and phone number) that the system obtained from attackers as the task success score.\n\nTable TABREF19 presents the main experiment results on AntiScam dataset, for both automatic evaluation metrics and human evaluation metrics. The experiment results on PersuasionForGood are shown in Table TABREF23. We observe that MISSA outperforms two baseline models (TransferTransfo and hybrid model) on almost all the metrics on both datasets. 
For further analysis, examples of real dialogs from the human evaluation are presented in Table TABREF21.\n\nCompared to the first TransferTransfo baseline, MISSA outperforms the TransferTransfo baseline on the on-task contents. From Table TABREF19, we observe that MISSA maintains longer conversations (14.9 turns) compared with TransferTransfo (8.5 turns), which means MISSA is better at maintaining the attacker's engagement. MISSA also has a higher task success score (1.294) than TransferTransfo (1.025), which indicates that it elicits information more strategically. In the top two dialogs (A and B) that are shown in Table TABREF21, both attackers were eliciting a credit card number in their first turns. TransferTransfo directly gave away the information, while MISSA replied with a semantically-related question \u201cwhy would you need my credit card number?\" Furthermore, in the next turn, TransferTransfo ignored the context and asked an irrelevant question \u201cwhat is your name?\u201d while MISSA was able to generate the response \u201cwhy can't you use my address?\u201d, which is consistent to the context. We suspect the improved performance of MISSA comes from our proposed annotation scheme: the semantic slot information enables MISSA to keep track of the current entities, and the intent information helps MISSA to maintain coherency and prolong conversations.\n\nCompared to the hybrid model baseline, MISSA performs better on off-task content. As shown in the bottom two dialogs in Table TABREF21, attackers in both dialogs introduced their names in their first utterances. MISSA recognized attacker's name, while the hybrid model did not. We suspect it is because the hybrid model does not have the built-in semantic slot predictor.", "$\\lambda _{LM}$, $\\lambda _{I_h}$, $\\lambda _{S_h}$, $\\lambda _{I_s}$, $\\lambda _{S_s}$, and $\\lambda _{nup}$ are the hyper-parameters that control the relative importance of every loss.\n\nMISSA can generate multiple sentences in a single system turn. Therefore, we perform system generation conditioned on predicted system intents. More specifically, during the training phase, in addition to inserting a special $<$sep$>$ token at the end of each sentence, we also insert the intent of the system response as special tokens at the head of each sentence in the system response. For example, in Figure FIGREF6, we insert a $<$pos_ans$>$ token at the head of $S_t^1$, which is the system response in green. We then use a cross entropy loss function to calculate the loss between the predicted token and the ground truth intent token. During the testing phase, the model first generates a special intent token, then after being conditioned on this intent token, the model keeps generating a sentence until it generates a $<$sep$>$ token. After that, the model continues to generate another intent token and another sentence until it generates an $<$eos$>$ token.\n\nSince we only perform conditional generation, a type of soft constraint on the predicted intent of system response, the system can still generate samples that violate simple conversation regulations, such as eliciting information that has already been provided. These corner cases may lead to fatal results in high-risk tasks, for example, health care and education. To improve the robustness of MISSA and improve its ability to generalize to more tasks, we add a response filtering module after the generation. 
With the nucleus sampling strategy BIBREF5, MISSA is able to generate multiple diverse candidate responses with different intents and semantic slots. We then adopt a task-specific response filtering policy to choose the best candidate response as the final output. In our anti-scam scenario, we set up a few simple rules to filter out some unreasonable candidates, for instance, eliciting the repeated information. The filtering module is easily adaptable to different domains or specific requirements, which makes our dialog system more controllable.\n\nWe evaluate MISSA on two non-collaborative task datasets.", "The intent predictor achieves a $84\\%$ accuracy and the semantic slot predictor achieves $77\\%$ on the AntiScam dataset. Then we compare the predicted values with human-annotated ground truth in the dataset to compute the response-intent prediction (RIP) and response-slot prediction (RSP).\n\nExtended Response-Intent Prediction (ERIP) $\\&$ Extended Response-Slot Prediction (ERSP) With Response-Intent Prediction, we verify the predicted intents to evaluate the coherence of the dialog. However, the real mapping between human-intent and system-intent is much more complicated as there might be multiple acceptable system-intents for the same human-intent. Therefore, we also design a metric to evaluate if the predicted system-intent is in the set of acceptable intents. Specifically, we estimate the transition probability $p(I_i|I_j)$ by counting the frequency of all the bi-gram human-intent and system-intent pairs in the training data. During the test stage, if the predicted intent matches the ground truth, we set the score as 1, otherwise we set the score as $p(I_{predict}|I_i)$ where $I_i$ is the intent of the input human utterance. We then report the average value of those scores over turns as the final extended response-intent prediction result.\n\nAutomatic metrics only validate the system\u2019s performance on a single dimension at a time. The ultimate holistic evaluation should be conducted by having the trained system interact with human users. Therefore we also conduct human evaluations for the dialog system built on AntiScam. We test our models and baselines with 15 college-student volunteers. Each of them is asked to pretend to be an attacker and interact with all the models for at least three times to avoid randomness. We in total collect 225 number of dialogs. Each time, volunteers are required to use similar sentences and strategies to interact with all five models and score each model based on the metrics listed below at the end of the current round. Each model receives a total of 45 human ratings, and the average score is reported as the final human-evaluation score. In total, we design five different metrics to assess the models' conversational ability whilst interacting with humans. The results are shown in Table TABREF19.\n\nFluency Fluency is used to explore different models' language generation quality.", "We follow the original TransferTransfo design BIBREF0 and train with undelexicalized data.\n\nHybrid Following BIBREF4 yu2017learning, we also build a hybrid dialog system by combining vanilla TransferTransfo and MISSA. Specifically, we first determine if the human utterances are on-task or off-task with human intent classifier. 
If the classifier decides that the utterance is on-task, we choose the response from MISSA; otherwise, we choose the response from vanilla TransferTransfo baseline.\n\nIn addition, we perform ablation studies on MISSA to show the effects of different components.\n\nMISSA-sel denotes MISSA without response filtering.\n\nMISSA-con denotes MISSA leaving out the intent token at the start of the response generation.\n\nPerplexity Since the canonical measure of a good language model is perplexity, which indicates the error rate of the expected word. We choose perplexity to evaluate the model performance.\n\nResponse-Intent Prediction (RIP) $\\&$ Response-Slot Prediction (RSP) Different from open-domain dialog systems, we care about the intents of the system response in non-collaborative tasks as we hope to know if the system response satisfies user intents. For example, in the anti-scam task, if the attacker elicits information from the system, we need to know if the system refuses or agrees to provide the information. Therefore we care about intent prediction for the generated system response. Since our baselines are more suited for social chat as they cannot produce system intents, we use the system intent and slot classifiers trained in our model to predict their responses' intents and slots. The intent predictor achieves a $84\\%$ accuracy and the semantic slot predictor achieves $77\\%$ on the AntiScam dataset. Then we compare the predicted values with human-annotated ground truth in the dataset to compute the response-intent prediction (RIP) and response-slot prediction (RSP).\n\nExtended Response-Intent Prediction (ERIP) $\\&$ Extended Response-Slot Prediction (ERSP) With Response-Intent Prediction, we verify the predicted intents to evaluate the coherence of the dialog. However, the real mapping between human-intent and system-intent is much more complicated as there might be multiple acceptable system-intents for the same human-intent. Therefore, we also design a metric to evaluate if the predicted system-intent is in the set of acceptable intents."]}
5 |
--------------------------------------------------------------------------------
/rag-foundation/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/rag-foundation/scripts/__init__.py
--------------------------------------------------------------------------------
/rag-foundation/scripts/main.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | import fire
5 | from llama_index.core import Document
6 | from llama_index.core.node_parser import SentenceSplitter
7 | from vector_store.node import TextNode, VectorStoreQueryResult
8 | from vector_store.semantic_vector_store import SemanticVectorStore
9 | from vector_store.sparse_vector_store import SparseVectorStore
10 |
11 |
12 | def prepare_data_nodes(documents: list, chunk_size: int = 200) -> list[TextNode]:
13 | """
14 | Args:
15 | documents: List of documents.
16 | chunk_size: Chunk size for splitting the documents.
17 | Returns:
18 | text_node: List of TextNode objects.
19 | """
20 | # Load data
21 | documents = [Document(text=t) for t in documents]
22 |
23 | # Split the documents into nodes
24 | node_parser = SentenceSplitter(chunk_size=chunk_size)
25 |
26 | # Get the nodes from the documents
27 | nodes = node_parser.get_nodes_from_documents(documents)
28 |
29 | # Prepare the nodes for the vector store
30 | text_node = [
31 | TextNode(id_=str(id_), text=node.text, metadata=node.metadata)
32 | for id_, node in enumerate(nodes)
33 | ]
34 | return text_node
35 |
36 |
37 | def prepare_vector_store(documents: list, mode: str, force_index=False, chunk_size=200):
38 | """
39 | Prepare the vector store with the given documents.
40 | Args:
41 | documents: List of documents to be indexed.
42 | mode: Mode of the vector store. Choose either `sparse` or `semantic`.
43 | force_index: Whether to force indexing the documents.
44 | chunk_size: Chunk size for splitting the documents.
45 | Returns:
46 | vector_store: Vector store object.
47 | """
48 | if mode == "sparse":
49 | vector_store = SparseVectorStore(
50 | persist=True,
51 | saved_file="data/sparse.csv",
52 | metadata_file="data/sparse_metadata.json",
53 | force_index=force_index,
54 | )
55 | elif mode == "semantic":
56 | vector_store = SemanticVectorStore(
57 | persist=True,
58 | saved_file="data/dense.csv",
59 | force_index=force_index,
60 | )
61 | else:
62 | raise ValueError("Invalid mode. Choose either `sparse` or `semantic`.")
63 |
64 | if force_index:
65 | nodes = prepare_data_nodes(documents=documents, chunk_size=chunk_size)
66 | vector_store.add(nodes)
67 |
68 | return vector_store
69 |
70 |
71 | class RAGPipeline:
72 | def __init__(self, vector_store: SemanticVectorStore, prompt_template: str):
73 | self.vector_store = vector_store
74 | self.prompt_template = prompt_template
75 |
76 | # choose your model from groq or openai/azure
77 | self.model = None
78 |
79 | # GROQ
80 | # from langchain_groq import ChatGroq
81 | # self.model = ChatGroq(model="llama3-70b-8192", temperature=0)
82 |
83 | # OpenAI
84 | # from langchain_openai import ChatOpenAI
85 | # self.model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
86 |
87 | def retrieve(self, query: str, top_k: int = 5) -> VectorStoreQueryResult:
88 | query_result = self.vector_store.query(query, top_k=top_k)
89 | return query_result
90 |
91 | def answer(self, query: str, top_k: int = 5) -> tuple[str, list[str]]:
92 |         # Retrieve context, build the prompt, and query the LLM
93 | result = self.retrieve(query, top_k=top_k)
94 | context_list = [node.text for node in result.nodes]
95 | context = "\n\n".join(context_list)
96 |
97 |         # Fill the user-supplied prompt template instead of overwriting it,
98 |         # e.g. "Question: {}\n\nGiven context: {}\n\nAnswer:".format(query, context)
99 |         prompt = self.prompt_template.format(query, context)
100 |
101 | if not self.model:
102 | raise ValueError("Model not found. Please initialize the model first.")
103 | try:
104 |             response = self.model.invoke(prompt)
105 | except Exception as e:
106 | raise Exception(f"Error in calling the model: {e}")
107 | return response.content, context_list
108 |
109 |
110 | def main(
111 | data_path: Path = Path("data/qasper-test-v0.3.json"),
112 | output_path: Path = Path("predictions.jsonl"),
113 | mode: str = "sparse",
114 | force_index: bool = False,
115 | print_context: bool = False,
116 | chunk_size: int = 200,
117 | top_k: int = 5,
118 | retrieval_only: bool = False,
119 | ):
120 |
121 | """
122 | Args:
123 | data_path: Path to the qasper data file.
124 | output_path: Path to save the predictions.
125 | mode: Mode of the vector store. Choose either `sparse` or `semantic`.
126 | force_index: Whether to force indexing the documents.
127 | print_context: Whether to print the context.
128 | chunk_size: Chunk size for splitting the documents.
129 | top_k: Number of top k documents to retrieve.
130 | retrieval_only: Whether to retrieve only.
131 | Returns:
132 | None
133 | """
134 | # Load the data
135 | raw_data = json.load(open(data_path, "r", encoding="utf-8"))
136 |
137 | question_ids, predicted_answers, predicted_evidences = [], [], []
138 |
139 | # NOTE: qasper has many papers, each paper has multiple sections
140 | # we will loop through each paper, gather the full text of each section
141 | # and prepare the documents for the vector store
142 | # and answer the query
143 | for _, values in raw_data.items():
144 | # for each paper in qasper
145 | documents = []
146 |
147 | for section in values["full_text"]:
148 | # for each section in the paper
149 | documents += section["paragraphs"]
150 |
151 | # initialize the vector store
152 | # and rag pipeline
153 | # Remember to force_index=True if you want to override the existing index
154 | vector_store = prepare_vector_store(
155 | documents, mode=mode, force_index=force_index, chunk_size=chunk_size
156 | )
157 |
158 | # NOTE: Should design your own template
159 | prompt_template = """Question: {}\n\nGiven context: {}\n\nAnswer:"""
160 |
161 | rag_pipeline = RAGPipeline(vector_store, prompt_template=prompt_template)
162 |
163 | for q in values["qas"]:
164 | # for each question in the paper
165 | query = q["question"]
166 | question_ids.append(q["question_id"])
167 |
168 | # NOTE: If you just want to retrieve the top_k relevant documents
169 | # set retrieval_only=True
170 | # Otherwise, it will answer the question
171 | if retrieval_only:
172 | result = rag_pipeline.retrieve(query, top_k=top_k)
173 | context_list = [node.text for node in result.nodes]
174 |
175 | if print_context:
176 | for i, context in enumerate(context_list):
177 |                         print(f"Relevant context {i + 1}:", context)
178 | print("\n\n")
179 |
180 | predicted_evidences.append(context_list)
181 | predicted_answers.append("")
182 |
183 | else:
184 | predicted_answer, context_list = rag_pipeline.answer(query, top_k=top_k)
185 |
186 | # Just In Case. Print out the context list for each question
187 | # if needed.
188 | if print_context:
189 | for i, context in enumerate(context_list):
190 |                         print(f"Relevant context {i + 1}:", context)
191 | print("\n\n")
192 |
193 | print("LLM Answer")
194 | print(predicted_answer)
195 |
196 | predicted_evidences.append(context_list)
197 | predicted_answers.append(predicted_answer)
198 |
199 | # save the results
200 | with open(output_path, "w") as f:
201 | for question_id, predicted_answer, predicted_evidence in zip(
202 | question_ids, predicted_answers, predicted_evidences
203 | ):
204 | f.write(
205 | json.dumps(
206 | {
207 | "question_id": question_id,
208 | "predicted_answer": predicted_answer,
209 | "predicted_evidence": predicted_evidence,
210 | }
211 | )
212 | )
213 | f.write("\n")
214 |
215 |
216 | if __name__ == "__main__":
217 | fire.Fire(main)
218 |
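A minimal end-to-end sketch of how this script's pieces could be driven programmatically (not part of the repository). It assumes `rag-foundation` is the working directory, that the exercise placeholders in the vector stores have been completed (see the sketches after those files), and it uses made-up documents and a made-up question; enabling an LLM follows the commented-out options in `RAGPipeline.__init__`:

```python
# Hypothetical usage sketch for rag-foundation/scripts/main.py
from scripts.main import RAGPipeline, prepare_vector_store

docs = [
    "BM25 is a sparse ranking function based on term statistics.",
    "Dense retrieval encodes text into embedding vectors.",
]

# Build and index a sparse (BM25) store over the sample documents
store = prepare_vector_store(docs, mode="sparse", force_index=True, chunk_size=200)

pipeline = RAGPipeline(
    store, prompt_template="Question: {}\n\nGiven context: {}\n\nAnswer:"
)

# Retrieval works without an LLM; answer() additionally needs a model,
# e.g. one of the commented-out options in RAGPipeline.__init__:
# from langchain_openai import ChatOpenAI
# pipeline.model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

result = pipeline.retrieve("What is BM25?", top_k=2)
print([node.text for node in result.nodes])
```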
--------------------------------------------------------------------------------
/rag-foundation/setup.cfg:
--------------------------------------------------------------------------------
1 | # Project-wide configuration file, can be used for package metadata and other tool configurations
2 | # Example usage: global configuration for PEP8 (via flake8) setting or default pytest arguments
3 | # Local usage: pip install pre-commit, pre-commit run --all-files
4 |
5 | [isort]
6 | # https://pycqa.github.io/isort/docs/configuration/options.html
7 | line_length = 120
8 | # see: https://pycqa.github.io/isort/docs/configuration/multi_line_output_modes.html
9 | multi_line_output = 0
10 | include_trailing_comma = True
11 |
12 | [black]
13 | line_length = 120
14 |
15 | [flake8]
16 | # https://flake8.pycqa.org/en/latest/user/options.html
17 | max-line-length = 120
18 | verbose = 2
19 | format = pylint
20 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes
21 | # see: https://www.flake8rules.com/
22 | select = B, C, E, F, W, T4, B9
23 | ignore = C101, C407, C408, E203, E402, E731, W503
24 | # C101: Coding magic comment not found
25 | # C407: Unnecessary comprehension - can take a generator
26 | # C408: Unnecessary call - rewrite as a literal
27 | # E203 Whitespace before ':'
28 | # E402: module level import not at top of file
29 | # E731: Do not assign a lambda expression, use a def
30 | # W503 Line break occurred before a binary operator
31 | per-file-ignores =
32 | **/__init__.py: F401, F403, F405
33 | # F401: module imported but unused
34 | # F403: ‘from module import *’ used; unable to detect undefined names
35 | # F405: Name may be undefined, or defined from star imports: module
36 | # E501: ignore line length in constants file
37 |
--------------------------------------------------------------------------------
/rag-foundation/vector_store/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/rag-foundation/vector_store/__init__.py
--------------------------------------------------------------------------------
/rag-foundation/vector_store/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import pandas as pd
5 | from loguru import logger
6 | from pydantic import BaseModel, Field
7 |
8 | from .node import BaseNode, TextNode
9 |
10 |
11 | class BaseVectorStore(BaseModel):
12 | """Simple custom Vector Store.
13 |
14 | Stores documents in a simple in-memory dict.
15 | """
16 |
17 | force_index: bool = False
18 | persist: bool = True
19 | node_dict: dict[str, BaseNode] = Field(default_factory=dict)
20 | node_list: list[BaseNode] = Field(default_factory=list)
21 |     saved_file: str = "rag-foundation/data/semantic_vectordb_nodes.csv"
22 | csv_file: Path = Path(saved_file)
23 |
24 | class Config:
25 | arbitrary_types_allowed = True
26 |
27 | def __init__(self, **data):
28 | super().__init__(**data)
29 | self.csv_file = Path(self.saved_file)
30 | self._setup_store()
31 |
32 | def _setup_store(self):
33 | if self.persist:
34 | if self.force_index:
35 | self._reset_csv()
36 | self._initialize_csv()
37 | self._load_from_csv()
38 |
39 | def _initialize_csv(self):
40 | """Initialize the CSV file if it doesn't exist."""
41 | if not self.csv_file.exists():
42 | logger.warning(
43 | f"Cannot find CSV file at `{self.saved_file}`, creating a new one..."
44 | )
45 | os.makedirs(self.csv_file.parent, exist_ok=True)
46 | with open(self.csv_file, "w") as f:
47 | f.write("id,text,embedding,metadata\n")
48 |
49 | def _load_from_csv(self):
50 | """Load the node_dict from the CSV file."""
51 | if self.csv_file.exists():
52 | df = pd.read_csv(self.csv_file)
53 | for _, row in df.iterrows():
54 | node_id = row["id"]
55 | text = row["text"]
56 | try:
57 | embedding = eval(row["embedding"])
58 | metadata = eval(row["metadata"])
59 | except TypeError:
60 | embedding = None
61 | metadata = None
62 | self.node_dict[node_id] = TextNode(
63 | id_=str(node_id), text=text, embedding=embedding, metadata=metadata
64 | )
65 |
66 | def _update_csv(self):
67 | """Update the CSV file with the current node_dict if persist is True."""
68 | if self.persist:
69 | data = {"id": [], "text": [], "embedding": [], "metadata": []}
70 | for key, node in self.node_dict.items():
71 | data["id"].append(key)
72 | data["text"].append(node.text)
73 | data["embedding"].append(node.embedding)
74 | data["metadata"].append(node.metadata)
75 | df = pd.DataFrame(data)
76 | df.to_csv(self.csv_file, index=False)
77 | else:
78 | logger.warning("`persist` is set to `False`, not updating CSV file.")
79 |
80 | def _reset_csv(self):
81 | """Reset the CSV file by deleting it if it exists."""
82 | if self.csv_file.exists():
83 | self.csv_file.unlink()
84 |
85 | def get(self):
86 | """Get embedding."""
87 |
88 | def add(self):
89 | """Add nodes to index."""
90 |
91 | def delete(self) -> None:
92 |         """Delete nodes by node_id."""
93 |
94 | def query(self):
95 | """Get nodes for response."""
96 |
--------------------------------------------------------------------------------
/rag-foundation/vector_store/node.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any, Dict, List, Optional, Sequence
3 |
4 | from pydantic import BaseModel
5 |
6 |
7 | class BaseNode(BaseModel):
8 | id_: str
9 | embedding: Optional[List[float]] = None
10 | metadata: Optional[Dict[str, Any]] = None
11 |
12 |
13 | class TextNode(BaseNode):
14 | text: str | List[str]
15 |
16 |
17 | @dataclass
18 | class VectorStoreQueryResult:
19 | """Vector store query result."""
20 |
21 | nodes: Optional[Sequence[BaseNode]] = None
22 | similarities: Optional[List[float]] = None
23 | ids: Optional[List[str]] = None
24 |
--------------------------------------------------------------------------------
/rag-foundation/vector_store/semantic_vector_store.py:
--------------------------------------------------------------------------------
1 | # autoflake: off
2 | # flake8: noqa: F841
3 | import sys
4 | from typing import Dict, List, cast
5 |
6 | import numpy as np
7 | from loguru import logger
8 | from sentence_transformers import SentenceTransformer
9 |
10 | from .base import BaseVectorStore
11 | from .node import TextNode, VectorStoreQueryResult
12 |
13 | logger.add(
14 | sink=sys.stdout,
15 | colorize=True,
16 |     format="{time} {message}",
17 | )
18 |
19 |
20 | class SemanticVectorStore(BaseVectorStore):
21 | """Semantic Vector Store using SentenceTransformer embeddings."""
22 |
23 | saved_file: str = "rag-foundation/data/test_db_00.csv"
24 | embed_model_name: str = "all-MiniLM-L6-v2"
25 | embed_model: SentenceTransformer = SentenceTransformer(embed_model_name)
26 |
27 | def __init__(self, **data):
28 | super().__init__(**data)
29 | self._setup_store()
30 |
31 | def get(self, text_id: str) -> TextNode:
32 | """Get node."""
33 | try:
34 | return self.node_dict[text_id]
35 | except KeyError:
36 | logger.error(f"Node with id `{text_id}` not found.")
37 | return None
38 |
39 | def add(self, nodes: List[TextNode]) -> List[str]:
40 | """Add nodes to index."""
41 | for node in nodes:
42 | if node.embedding is None:
43 | logger.info(
44 | "Found node without embedding, calculating "
45 | f"embedding with model {self.embed_model_name}"
46 | )
47 | node.embedding = self._get_text_embedding(node.text)
48 | self.node_dict[node.id_] = node
49 | self._update_csv() # Update CSV after adding nodes
50 | return [node.id_ for node in nodes]
51 |
52 | def _get_text_embedding(self, text: str) -> List[float]:
53 | """Calculate embedding."""
54 | return self.embed_model.encode(text).tolist()
55 |
56 | def delete(self, node_id: str, **delete_kwargs: Dict) -> None:
57 | """Delete nodes using node_id."""
58 | if node_id in self.node_dict:
59 | del self.node_dict[node_id]
60 | self._update_csv() # Update CSV after deleting nodes
61 | else:
62 | logger.error(f"Node with id `{node_id}` not found.")
63 |
64 | def _calculate_similarity(
65 | self,
66 | query_embedding: List[float],
67 | doc_embeddings: List[List[float]],
68 | doc_ids: List[str],
69 | similarity_top_k: int = 3,
70 | ) -> tuple[List[float], List[str]]:
71 | """Get top nodes by similarity to the query."""
72 | qembed_np = np.array(query_embedding)
73 | dembed_np = np.array(doc_embeddings)
74 |
75 | # calculate the dot product of
76 | # the query embedding with the document embeddings
77 | # HINT: np.dot
78 | "Your code here"
79 | dproduct_arr = None
80 | # calculate the cosine similarity
81 | # by dividing the dot product by the norm
82 | # HINT: np.linalg.norm
83 | "Your code here"
84 | cos_sim_arr = None
85 |
86 | # get the indices of the top k similarities
87 | "Your code here"
88 | similarities = None
89 | node_ids = None
90 |
91 | return similarities, node_ids
92 |
93 | def query(self, query: str, top_k: int = 3) -> VectorStoreQueryResult:
94 | """Query similar nodes."""
95 | query_embedding = cast(List[float], self._get_text_embedding(query))
96 | doc_embeddings = [node.embedding for node in self.node_dict.values()]
97 | doc_ids = list(self.node_dict.keys())
98 | if len(doc_embeddings) == 0:
99 | logger.error("No documents found in the index.")
100 | result_nodes, similarities, node_ids = [], [], []
101 | else:
102 | similarities, node_ids = self._calculate_similarity(
103 | query_embedding, doc_embeddings, doc_ids, top_k
104 | )
105 | result_nodes = [self.node_dict[node_id] for node_id in node_ids]
106 | return VectorStoreQueryResult(
107 | nodes=result_nodes, similarities=similarities, ids=node_ids
108 | )
109 |
110 | def batch_query(
111 | self, query: List[str], top_k: int = 3
112 | ) -> List[VectorStoreQueryResult]:
113 | """Batch query similar nodes."""
114 | return [self.query(q, top_k) for q in query]
115 |
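`_calculate_similarity` above is left as an exercise (`"Your code here"`). A minimal sketch of one possible implementation, following the hints in the comments (plain cosine similarity with `np.dot`, `np.linalg.norm` and `np.argsort`); this is a suggestion, not the official solution:

```python
import numpy as np
from typing import List


def _calculate_similarity(
    self,
    query_embedding: List[float],
    doc_embeddings: List[List[float]],
    doc_ids: List[str],
    similarity_top_k: int = 3,
) -> tuple[List[float], List[str]]:
    """Possible cosine-similarity body for SemanticVectorStore._calculate_similarity."""
    qembed_np = np.array(query_embedding)
    dembed_np = np.array(doc_embeddings)

    # dot product between the query embedding and every document embedding
    dproduct_arr = np.dot(dembed_np, qembed_np)

    # cosine similarity = dot product divided by the product of the norms
    cos_sim_arr = dproduct_arr / (
        np.linalg.norm(dembed_np, axis=1) * np.linalg.norm(qembed_np)
    )

    # indices of the top-k similarities, highest first
    top_idx = np.argsort(cos_sim_arr)[::-1][:similarity_top_k]
    similarities = [float(cos_sim_arr[i]) for i in top_idx]
    node_ids = [doc_ids[i] for i in top_idx]
    return similarities, node_ids
```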
--------------------------------------------------------------------------------
/rag-foundation/vector_store/sparse_vector_store.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa: F841
2 | import json
3 | import sys
4 | from multiprocessing import Pool, cpu_count
5 | from pathlib import Path
6 | from typing import ClassVar, Dict, List
7 |
8 | import numpy as np
9 | from loguru import logger
10 | from pydantic import Field
11 | from transformers import AutoTokenizer
12 |
13 | from .base import BaseVectorStore
14 | from .node import TextNode, VectorStoreQueryResult
15 |
16 | logger.add(
17 | sink=sys.stdout,
18 | colorize=True,
19 |     format="{time} {message}",
20 | )
21 |
22 | TOKENIZER = AutoTokenizer.from_pretrained(
23 | "google-bert/bert-base-uncased", max_length=200, truncation=True
24 | )
25 |
26 |
27 | class SparseVectorStore(BaseVectorStore):
28 | """VectorStore2 (add/get/delete implemented)."""
29 |
30 | saved_file: str = "rag-foundation/data/test_db_10.csv"
31 | metadata_file: Path = Path("rag-foundation/data/sparse_metadata_tmp.json")
32 | tokenizer: ClassVar[AutoTokenizer] = TOKENIZER
33 | corpus_size: int = Field(default=0, init=False)
34 | avgdl: float = Field(default=0.0, init=False)
35 | doc_freqs: List[Dict[str, int]] = Field(default_factory=list, init=False)
36 | idf: Dict[str, float] = Field(default_factory=dict, init=False)
37 | doc_len: List[int] = Field(default_factory=list, init=False)
38 | nd: int = Field(default=0, init=False)
39 |
40 | # Algorithm specific parameters
41 | k1: float = Field(default=1.2)
42 | b: float = Field(default=0.75)
43 | delta: float = Field(default=0.25)
44 |
45 | def __init__(self, **data):
46 | super().__init__(**data)
47 | if len(self.node_dict) > 0:
48 | self.metadata_file = Path(self.metadata_file)
49 | if self.metadata_file.exists() and not self.force_index:
50 | self._load_from_json()
51 | else:
52 | self._initialize_bm25_assets()
53 |
54 | self.node_list = list(self.node_dict.values())
55 |
56 | def _initialize_bm25_assets(self):
57 | """Initialize BM25 assets from the node dictionary."""
58 | self.corpus_size = 0
59 | self.avgdl = 0
60 | self.doc_freqs = []
61 | self.idf = {}
62 | self.doc_len = []
63 | self.nd = 0
64 |
65 | corpus = self._tokenize_text([node.text for node in self.node_list])
66 | self._initialize(corpus)
67 | content = {
68 | "corpus_size": self.corpus_size,
69 | "avgdl": self.avgdl,
70 | "doc_freqs": self.doc_freqs,
71 | "idf": self.idf,
72 | "doc_len": self.doc_len,
73 | "nd": self.nd,
74 | }
75 | with open(self.metadata_file, "w") as f:
76 | json.dump(content, f)
77 |
78 | def _load_from_json(self):
79 | with open(self.metadata_file, "r") as f:
80 | content = json.load(f)
81 | self.corpus_size = content["corpus_size"]
82 | self.avgdl = content["avgdl"]
83 | self.doc_freqs = content["doc_freqs"]
84 | self.idf = content["idf"]
85 | self.doc_len = content["doc_len"]
86 | self.nd = content["nd"]
87 |
88 | def _initialize(self, corpus: List[List[str]]):
89 | nd = {} # word -> number of documents with word
90 | num_doc = 0
91 | for document in corpus:
92 | self.doc_len.append(len(document))
93 | num_doc += len(document)
94 |
95 | frequencies = {}
96 | for word in document:
97 | if word not in frequencies:
98 | frequencies[word] = 0
99 | frequencies[word] += 1
100 | self.doc_freqs.append(frequencies)
101 |
102 | for word, freq in frequencies.items():
103 | try:
104 | nd[word] += 1
105 | except KeyError:
106 | nd[word] = 1
107 |
108 | self.corpus_size += 1
109 |
110 | self.avgdl = num_doc / self.corpus_size
111 | self.idf = {
112 | word: self._calculate_idf(doc_count, self.corpus_size)
113 | for word, doc_count in nd.items()
114 | }
115 |
116 | def _calculate_idf(self, doc_count: int, corpus_size: int) -> float:
117 | # Calculate the inverse document frequency for a word
118 | # HINT: Use the formula provided in the BM25 algorithm and np.log()
119 | "Your code here"
120 | idf_score = None
121 | return idf_score
122 |
123 | def _tokenize_text(self, corpus: List[str] | str):
124 | if isinstance(corpus, str):
125 | return self.tokenizer.tokenize(corpus)
126 | else:
127 | pool = Pool(cpu_count())
128 | tokenized_corpus = pool.map(self.tokenizer.tokenize, corpus)
129 | return tokenized_corpus
130 |
131 | def add(self, nodes: List[TextNode]) -> List[str]:
132 | """Add nodes to index."""
133 | for node in nodes:
134 | self.node_dict[node.id_] = node
135 | self._update_csv() # Update CSV after adding nodes
136 |         self.node_list = list(self.node_dict.values())  # keep node_list in sync with node_dict
137 | # Reinitialize BM25 assets after adding new nodes
138 | self._initialize_bm25_assets()
139 |
140 | return [node.id_ for node in nodes]
141 |
142 | def get(self, text_id: str) -> TextNode:
143 | """Get node."""
144 | try:
145 | return self.node_dict[text_id]
146 | except KeyError:
147 | logger.error(f"Node with id `{text_id}` not found.")
148 | return None
149 |
150 | def get_scores(self, query: str):
151 | score = np.zeros(self.corpus_size)
152 | tokenized_query = self._tokenize_text(query)
153 | for q in tokenized_query:
154 |             # calculate the score for each token in the query
155 | # HINT: use self.doc_freqs, self.idf, self.corpus_size, self.avgdl
156 | "Your code here"
157 | cur_score = None
158 | score += cur_score
159 | return score
160 |
161 | def query(self, query: str, top_k: int = 3) -> VectorStoreQueryResult:
162 | """Query similar nodes.
163 |
164 | Args:
165 | query (str): _description_
166 | top_k (int, optional): _description_. Defaults to 3.
167 |
168 | Returns:
169 | List[TextNode]: _description_
170 | """
171 | scores = self.get_scores(query)
172 | best_ids = np.argsort(scores)[::-1][:top_k]
173 | nodes = [self.node_list[node_id] for node_id in best_ids]
174 | return VectorStoreQueryResult(
175 | nodes=nodes,
176 | similarities=[scores[doc_id] for doc_id in best_ids],
177 | ids=[node.id_ for node in nodes],
178 | )
179 |
180 | def batch_query(
181 | self, query: List[str], top_k: int = 3
182 | ) -> List[VectorStoreQueryResult]:
183 | """Batch query similar nodes.
184 |
185 | Args:
186 | query (List[str]): _description_
187 | top_k (int, optional): _description_. Defaults to 3.
188 |
189 | Returns:
190 | List[VectorStoreQueryResult]: _description_
191 | """
192 | return [self.query(q, top_k) for q in query]
193 |
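`_calculate_idf` and the per-token score in `get_scores` are also exercises. One possible implementation is sketched below; BM25 has several variants, and the exact formulas (Okapi-style IDF plus the `delta` shift used in BM25+) are an assumption chosen to match the `k1`, `b` and `delta` fields defined above:

```python
import numpy as np


def _calculate_idf(self, doc_count: int, corpus_size: int) -> float:
    """Okapi-style inverse document frequency (one common BM25 variant)."""
    # idf = log((N - n + 0.5) / (n + 0.5) + 1)
    idf_score = np.log((corpus_size - doc_count + 0.5) / (doc_count + 0.5) + 1)
    return float(idf_score)


def get_scores(self, query: str):
    """BM25+-style scoring using the k1, b and delta fields of the store."""
    score = np.zeros(self.corpus_size)
    doc_len = np.array(self.doc_len)
    tokenized_query = self._tokenize_text(query)
    for q in tokenized_query:
        # term frequency of q in every document
        q_freq = np.array([doc.get(q, 0) for doc in self.doc_freqs])
        cur_score = self.idf.get(q, 0.0) * (
            self.delta
            + (q_freq * (self.k1 + 1))
            / (q_freq + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl))
        )
        score += cur_score
    return score
```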
--------------------------------------------------------------------------------
/streamlit_demo/.gitignore:
--------------------------------------------------------------------------------
1 | *.pt
2 | __pycache__/
3 | .idea/
4 | app_data/
5 |
--------------------------------------------------------------------------------
/streamlit_demo/README.md:
--------------------------------------------------------------------------------
1 | # Streamlit for Object Detection
2 |
3 | ---
4 |
5 | ## Quick Usage
6 |
7 | Install requirements
8 |
9 | ```bash
10 | conda create --name streamlit_demo python=3.11
11 | conda activate streamlit_demo
12 |
13 | pip install -r requirements.txt
14 | ```
15 |
16 | Run app
17 |
18 | ```bash
19 | python launch.py
20 | ```
21 |
22 | ## Screenshots
23 |
24 | 
--------------------------------------------------------------------------------
/streamlit_demo/assets/screenshot_app.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/streamlit_demo/assets/screenshot_app.png
--------------------------------------------------------------------------------
/streamlit_demo/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | APP_DATA_DIR = Path(__file__).parent / "app_data"
5 | os.makedirs(APP_DATA_DIR, exist_ok=True)
6 |
7 | FEEDBACK_DIR = APP_DATA_DIR / "feedback"
8 | os.makedirs(FEEDBACK_DIR, exist_ok=True)
9 |
10 | FEEDBACK_SQL_PATH = f"sqlite:///{FEEDBACK_DIR / 'feedback.sql'}"
11 |
12 | YOLO_OPTIONS = [
13 | "yolov8s.pt",
14 | "yolov8n.pt"
15 | ]
16 |
17 | YOLO_SUPPORTED_EXTENSIONS = ["jpg", "png", "jpeg"]
18 |
19 | USER_DATA_DIR = APP_DATA_DIR / "user_data" / "images"
20 | os.makedirs(USER_DATA_DIR, exist_ok=True)
21 |
22 | AI_MODEL_CONFIGS = {
23 | "yolov8": {
24 | "model_name": "yolov8s.pt",
25 | "device": "cuda"
26 | }
27 | }
28 | AI_MODEL = "yolov8"
29 |
30 | CLASSES = ['Person', 'Bicycle', 'Car', 'Motorcycle', 'Airplane', 'Bus', 'Train', 'Truck', 'Boat', 'Traffic light',
31 | 'Fire hydrant', 'Stop sign', 'Parking meter', 'Bench', 'Bird', 'Cat', 'Dog', 'Horse', 'Sheep', 'Cow',
32 | 'Elephant', 'Bear', 'Zebra', 'Giraffe', 'Backpack', 'Umbrella', 'Handbag', 'Tie', 'Suitcase', 'Frisbee',
33 | 'Skis', 'Snowboard', 'Sports ball', 'Kite', 'Baseball bat', 'Baseball glove', 'Skateboard', 'Surfboard',
34 | 'Tennis racket', 'Bottle', 'Wine glass', 'Cup', 'Fork', 'Knife', 'Spoon', 'Bowl', 'Banana', 'Apple',
35 | 'Sandwich', 'Orange', 'Broccoli', 'Carrot', 'Hot dog', 'Pizza', 'Donut', 'Cake', 'Chair', 'Couch',
36 | 'Potted plant', 'Bed', 'Dining table', 'Toilet', 'Tv', 'Laptop', 'Mouse', 'Remote', 'Keyboard', 'Cell phone',
37 | 'Microwave', 'Oven', 'Toaster', 'Sink', 'Refrigerator', 'Book', 'Clock', 'Vase', 'Scissors', 'Teddy bear',
38 | 'Hair drier', 'Toothbrush']
39 |
--------------------------------------------------------------------------------
/streamlit_demo/launch.py:
--------------------------------------------------------------------------------
1 | from shared.views import App
2 | from shared.utils.log import custom_logger
3 | from shared.utils.pages import set_page_config
4 |
5 | set_page_config()
6 | custom_logger()
7 |
8 | app = App()
9 | app.view(key="app")
10 |
--------------------------------------------------------------------------------
/streamlit_demo/lessons/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/streamlit_demo/lessons/__init__.py
--------------------------------------------------------------------------------
/streamlit_demo/lessons/cache_flow.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 |
4 | import streamlit as st
5 | import pandas as pd
6 |
7 | from functools import lru_cache
8 |
9 |
10 | class Model:
11 | def __init__(self, dct):
12 | self.dct = dct
13 |
14 |
15 | def experiment_1():
16 | st.code('''
17 | class Model:
18 | def __init__(self, dct):
19 | self.dct = dct
20 |
21 | # @lru_cache(1)
22 | # @st.cache_data
23 | @st.cache_resource
24 | def load_data(dct: dict) -> Model:
25 | print("I will go sleep for 3s")
26 | time.sleep(3)
27 |
28 | model = Model(dct)
29 | return model
30 |
31 | data_dct = {
32 | 'Column1': [1, 2, 3, 4, 5],
33 | 'Column2': ['A', 'B', 'C', 'D', 'E'],
34 | 'Column3': [10.5, 20.5, 30.5, 40.5, 50.5],
35 | 'Column4': [True, False, True, False, True]
36 | }
37 |
38 | model = load_data(data_dct)
39 | st.json(model.dct)
40 |
41 | model.dct = {}
42 | st.button("Rerun")
43 | ''', language='python')
44 |
45 | # @lru_cache(1)
46 | @st.cache_data
47 | # @st.cache_resource
48 | def load_data(dct: dict) -> Model:
49 | print("I will go sleep for 3s")
50 | time.sleep(3)
51 |
52 | model = Model(dct)
53 | return model
54 |
55 | data_dct = {
56 | 'Column1': [1, 2, 3, 4, 5],
57 | 'Column2': ['A', 'B', 'C', 'D', 'E'],
58 | 'Column3': [10.5, 20.5, 30.5, 40.5, 50.5],
59 | 'Column4': [True, False, True, False, True]
60 | }
61 |
62 | model = load_data(data_dct)
63 | st.json(model.dct)
64 |
65 | btn = st.button("Rerun")
66 | if btn:
67 | model.dct = {}
68 |
69 |
70 | experiment_1()
71 |
--------------------------------------------------------------------------------
/streamlit_demo/lessons/execution_flow.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import streamlit as st
3 |
4 | try:
5 | st.set_page_config(
6 | page_title="Execution Flow",
7 | page_icon="🤖",
8 | layout="wide",
9 | initial_sidebar_state="expanded"
10 | )
11 | except Exception:
12 | pass
13 |
14 |
15 | def experiment_1():
16 | st.code('''
17 | st.info(str(datetime.datetime.now()))
18 |
19 | magic_number = st.slider("Magic number", min_value=0., max_value=1., step=0.1)
20 | print(magic_number)
21 |
22 | btn = st.button("Submit")
23 | input_a = None
24 | if btn:
25 | print("Enter btn function", datetime.datetime.now())
26 | st.toast("Button pressed")
27 |     input_a = f"Hello world. Your magic number is: {magic_number}"
28 |
29 | st.info(magic_number)
30 | st.info(input_a)
31 | ''', language="python")
32 |
33 | st.info(str(datetime.datetime.now()))
34 |
35 | magic_number = st.slider("Magic number", min_value=0., max_value=1., step=0.1)
36 | print(magic_number)
37 |
38 | btn = st.button("Submit")
39 | input_a = None
40 | if btn:
41 | print("Enter btn function", datetime.datetime.now())
42 | st.toast("Button pressed")
43 |         input_a = f"Hello world. Your magic number is: {magic_number}"
44 |
45 | st.info(magic_number)
46 | st.info(input_a)
47 |
48 |
49 | def experiment_2():
50 | st.code('''
51 | st.info(str(datetime.datetime.now()))
52 |
53 | with st.form("form", clear_on_submit=True):
54 | magic_number = st.slider("Magic number", min_value=0., max_value=1., step=0.1)
55 | print(magic_number)
56 |
57 | btn = st.form_submit_button("Submit")
58 |
59 | if btn:
60 | print("Enter btn function", datetime.datetime.now())
61 | st.info("Hello World")
62 |
63 | st.info(magic_number)
64 | ''', language="python")
65 | st.info(str(datetime.datetime.now()))
66 |
67 | with st.form("form", clear_on_submit=True):
68 | magic_number = st.slider("Magic number", min_value=0., max_value=1., step=0.1)
69 | print(magic_number)
70 |
71 | btn = st.form_submit_button("Submit")
72 |
73 | if btn:
74 | print("Enter btn function", datetime.datetime.now())
75 | st.info("Hello World")
76 |
77 | st.info(magic_number)
78 |
79 |
80 | def experiment_3():
81 |
82 | pass
83 |
84 |
85 | cols = st.columns(2)
86 |
87 | with cols[0]:
88 | experiment_1()
89 |
90 |
91 | with cols[1]:
92 | experiment_2()
93 |
94 |
--------------------------------------------------------------------------------
/streamlit_demo/lessons/layout.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 |
4 | try:
5 | st.set_page_config(
6 | page_title="Execution Flow",
7 | page_icon="🤖",
8 | layout="wide",
9 | initial_sidebar_state="expanded"
10 | )
11 | except Exception:
12 | pass
13 |
14 |
15 | def side_bar_view():
16 | st.header("Model Configurations")
17 | st.info("Check out the documentation "
18 | "at [link](https://docs.ultralytics.com/modes/predict/#inference-sources)")
19 |
20 | key = "sidebar"
21 | with st.form(f"{key}_upload", clear_on_submit=True):
22 | upload_image = st.file_uploader(
23 | "Upload Image(s)",
24 | accept_multiple_files=False,
25 | type=["png", "jpg", "jpeg"],
26 | key=f"{key}_upload_images"
27 | )
28 |
29 | col1, col2 = st.columns(2)
30 | with col1:
31 | augment = st.radio(
32 | "Augment",
33 | (True, False),
34 | horizontal=True
35 | )
36 | with col2:
37 | agnostic_nms = st.radio(
38 | "Agnostic NMS",
39 | (True, False),
40 | horizontal=True
41 | )
42 | image_size = st.number_input(
43 | "Image Size",
44 | value=640,
45 | step=32,
46 | min_value=640,
47 | max_value=1280
48 | )
49 | min_iou = st.slider(
50 | "Minimum IOU",
51 | min_value=0.0,
52 | max_value=1.0,
53 | value=0.5,
54 | step=0.01
55 | )
56 | min_confident_score = st.slider(
57 | "Minimum Confidence Score",
58 | min_value=0.0,
59 | max_value=1.0,
60 | value=0.2,
61 | step=0.01
62 | )
63 |
64 | submit_btn = st.form_submit_button(
65 | label="Upload",
66 | type="primary",
67 | use_container_width=True
68 | )
69 |
70 |
71 | def col_1_view():
72 | st.image("m10.jpg")
73 |
74 |
75 | def col_2_view():
76 | dummy_counting_dct = {
77 | "Person": 1
78 | }
79 |
80 | with st.container(border=True):
81 | st.markdown("**Counting**")
82 | st.json(dummy_counting_dct)
83 |
84 | with st.expander(label="Object Detail", expanded=True):
85 | cls = st.selectbox(label="Class", options=["Person", "Animal"], index=0)
86 |
87 |         st.markdown("Confidence score :red[0.92]")
88 |
89 |
90 | with st.sidebar:
91 | side_bar_view()
92 |
93 | image_col, info_col = st.columns([8, 2])
94 |
95 | with image_col:
96 | col_1_view()
97 |
98 | with info_col:
99 | col_2_view()
100 |
--------------------------------------------------------------------------------
/streamlit_demo/lessons/m10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/streamlit_demo/lessons/m10.jpg
--------------------------------------------------------------------------------
/streamlit_demo/requirements.txt:
--------------------------------------------------------------------------------
1 | loguru==0.7.2
2 | numpy==2.0.1
3 | opencv_python==4.10.0.84
4 | pandas==2.2.2
5 | Pillow==10.4.0
6 | Requests==2.32.3
7 | SQLAlchemy==2.0.31
8 | streamlit==1.36.0
9 | torch==2.0.1
10 | ultralytics==8.2.64
11 |
--------------------------------------------------------------------------------
/streamlit_demo/shared/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/streamlit_demo/shared/crud/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/streamlit_demo/shared/crud/__init__.py
--------------------------------------------------------------------------------
/streamlit_demo/shared/crud/feedbacks.py:
--------------------------------------------------------------------------------
1 | from loguru import logger
2 |
3 | from shared.models import Feedback
4 | from shared.models.engine import Session
5 |
6 |
7 | class FeedbackCRUD:
8 | def __init__(self, session: Session):
9 | self.session = session
10 |
11 | def create(self, image_path: str, data: dict) -> bool:
12 | existed_feedback = self.get_by_image_path(image_path)
13 | if existed_feedback:
14 | self.delete_by_id(existed_feedback.id)
15 | logger.info(f"Image path: {image_path} exists. Deleted")
16 |
17 | feedback = Feedback(image_path=image_path, data=data)
18 | self.session.add(feedback)
19 | self.session.commit()
20 |
21 | logger.info(f"Added 1 row")
22 | return True
23 |
24 | def delete_by_id(self, feedback_id: int) -> bool:
25 | (
26 | self.session
27 | .query(Feedback)
28 | .filter(Feedback.id == feedback_id)
29 | .delete(synchronize_session=False)
30 | )
31 | return True
32 |
33 | def get_by_image_path(self, image_path: str) -> Feedback | None:
34 | result = (
35 | self.session
36 | .query(Feedback)
37 | .filter(Feedback.image_path == image_path)
38 | .first()
39 | )
40 |
41 | return result
42 |
--------------------------------------------------------------------------------
/streamlit_demo/shared/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .engine import (
2 | engine
3 | )
4 | from .models import (
5 | Feedback, Base
6 | )
7 |
8 | Base.metadata.create_all(engine)
9 |
--------------------------------------------------------------------------------
/streamlit_demo/shared/models/engine.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import create_engine
2 | from sqlalchemy.orm import sessionmaker
3 |
4 | from constants import FEEDBACK_SQL_PATH
5 |
6 |
7 | engine = create_engine(FEEDBACK_SQL_PATH)
8 | Session = sessionmaker(
9 | bind=engine,
10 | )
11 |
12 |
--------------------------------------------------------------------------------
/streamlit_demo/shared/models/models.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import Column, Integer, JSON, String
2 | from sqlalchemy.ext.declarative import declarative_base
3 |
4 |
5 | Base = declarative_base()
6 |
7 |
8 | class Feedback(Base):
9 | __tablename__ = 'Feedback'
10 | id = Column(Integer, primary_key=True, autoincrement=True)
11 | image_path = Column(String)
12 |     data = Column(JSON, default=dict)  # callable default, avoids sharing one dict across rows
13 |
14 |
--------------------------------------------------------------------------------
/streamlit_demo/shared/models_ai/__init__.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 |
3 | import streamlit as st
4 |
5 | from .base import BaseAIModel
6 | from .yolov8 import Yolov8
7 |
8 |
9 | @st.cache_resource
10 | def get_ai_model(name: str, model_params: dict) -> BaseAIModel | None:
11 | factory: dict[str, BaseAIModel] = {
12 | "yolov8": Yolov8(**model_params),
13 | }
14 |
15 | return factory.get(name, None)
16 |
--------------------------------------------------------------------------------
/streamlit_demo/shared/models_ai/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from pathlib import Path
3 |
4 | import numpy as np
5 |
6 |
7 | class BaseAIModel(ABC):
8 | @abstractmethod
9 | def process(self, image_in: Path | str | np.ndarray, *args, **kwargs) -> Path:
10 | ...
11 |
--------------------------------------------------------------------------------
/streamlit_demo/shared/models_ai/yolov8.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Literal
3 |
4 | import cv2
5 | import numpy as np
6 | import torch
7 | from ultralytics.models import YOLO
8 | from ultralytics.engine.results import Results
9 | from loguru import logger
10 |
11 | from .base import BaseAIModel
12 | from shared.schemas import Parameters, ModelOutput
13 |
14 |
15 | class Yolov8(BaseAIModel):
16 | def __init__(self, model_name: str, device: Literal["cpu", "cuda"] = "cuda"):
17 | self._model = YOLO(model_name, task="detect")
18 |
19 | if device in ["cuda"] and torch.cuda.is_available():
20 | self._device = torch.device(device)
21 | else:
22 | self._device = torch.device("cpu")
23 |
24 | self._model.to(self._device)
25 |
26 | @staticmethod
27 | def get_default() -> dict:
28 | return {
29 | "augment": False,
30 | "agnostic_nms": False,
31 | "imgsz": 640,
32 | "iou": 0.5,
33 | "conf": 0.01,
34 | "verbose": False
35 | }
36 |
37 | def process(
38 | self,
39 | image_in: Path | str | np.ndarray,
40 | *args,
41 | **kwargs,
42 |     ) -> ModelOutput:
43 |         if isinstance(image_in, (str, Path)):
44 |             image_in = cv2.imread(str(image_in), cv2.IMREAD_COLOR)
45 |
46 | default_params: dict = self.get_default()
47 | if kwargs.get("params", None):
48 | params: Parameters = kwargs["params"]
49 |
50 | # Update
51 | default_params["augment"] = params.augment
52 | default_params["agnostic_nms"] = params.agnostic_nms
53 | default_params["imgsz"] = params.image_size
54 | default_params["iou"] = params.min_iou
55 | default_params["conf"] = params.min_confident_score
56 |
57 | logger.debug(f"Run with config: {default_params}")
58 |
59 | results: Results = self._model(image_in, **default_params)
60 | result = results[0].cpu().numpy()
61 |
62 | model_out_params = {
63 | "xyxysc": result.boxes.data
64 | }
65 |
66 | return ModelOutput(**model_out_params)
67 |
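A short, hypothetical usage sketch of the wrapper above; the image path is a placeholder, and `device="cpu"` keeps it runnable without a GPU (the first call downloads the `yolov8n.pt` weights via ultralytics):

```python
# Hypothetical usage of shared/models_ai/yolov8.py
from shared.models_ai.yolov8 import Yolov8
from shared.schemas import Parameters

model = Yolov8(model_name="yolov8n.pt", device="cpu")

params = Parameters(
    augment=False,
    agnostic_nms=False,
    image_size=640,
    min_iou=0.5,
    min_confident_score=0.2,
)

# "example.jpg" is a placeholder path
output = model.process("example.jpg", params=params)
print(len(output), "detections")
print(output.count())  # class index -> number of boxes, e.g. {0: 2, 16: 1}
```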
--------------------------------------------------------------------------------
/streamlit_demo/shared/schemas.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, asdict
2 |
3 | import numpy as np
4 |
5 |
6 | @dataclass
7 | class Base:
8 | def to_dict(self):
9 | return asdict(self)
10 |
11 |
12 | @dataclass
13 | class Parameters(Base):
14 | augment: bool
15 | agnostic_nms: bool
16 | image_size: int
17 | min_iou: float
18 | min_confident_score: float
19 |
20 |
21 | @dataclass
22 | class ModelInput(Base):
23 | upload_image: str
24 | params: Parameters
25 |
26 |
27 | @dataclass
28 | class ModelOutput(Base):
29 | xyxysc: np.ndarray # x_min, y_min, x_max, y_max, score, class
30 |
31 | def __len__(self):
32 | return len(self.xyxysc)
33 |
34 | def __getitem__(self, item_id: int) -> np.ndarray:
35 | return self.xyxysc[item_id]
36 |
37 | def count(self) -> dict[int, int]:
38 | cls_dict: dict[int, int] = {}
39 | for c in self.xyxysc[:, -1]:
40 | c = int(c)
41 | if c not in cls_dict:
42 | cls_dict[c] = 0
43 | cls_dict[c] += 1
44 |
45 | return cls_dict
46 |
47 | def to_dict(self) -> dict[int, list]:
48 | result_dict: dict[int, list] = {}
49 | for i, elem in enumerate(self.xyxysc):
50 | x_min, y_min, x_max, y_max = map(int, elem[:4])
51 | score = float(elem[-2])
52 | cls = int(elem[-1])
53 |
54 | result_dict[i] = [
55 | x_min, y_min, x_max, y_max, score, cls
56 | ]
57 | return result_dict
58 |
59 |
60 | @dataclass
61 | class EditedOutput(Base):
62 | cls: int
63 |
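A small illustration (made-up numbers) of the `xyxysc` layout that `ModelOutput` expects; the last two columns are the confidence score and the class index, which is what `count()`, `to_dict()` and the class edit in the app view rely on:

```python
import numpy as np

from shared.schemas import ModelOutput  # assumes streamlit_demo is on the import path

# two made-up detections: x_min, y_min, x_max, y_max, score, class
boxes = np.array([
    [10.0, 20.0, 110.0, 220.0, 0.91, 0.0],   # class 0 ("Person" in constants.CLASSES)
    [50.0, 60.0, 90.0, 120.0, 0.43, 16.0],   # class 16 ("Dog")
])

out = ModelOutput(xyxysc=boxes)
print(len(out))       # 2
print(out.count())    # {0: 1, 16: 1}
print(out.to_dict())  # {0: [10, 20, 110, 220, 0.91, 0], 1: [50, 60, 90, 120, 0.43, 16]}
```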
--------------------------------------------------------------------------------
/streamlit_demo/shared/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/streamlit_demo/shared/utils/__init__.py
--------------------------------------------------------------------------------
/streamlit_demo/shared/utils/files.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | from loguru import logger
4 | from PIL import Image
5 |
6 |
7 | def save_uploaded_file(file, dir_out: str) -> str:
8 | """Save uploaded file to local"""
9 | pil_image = Image.open(file)
10 |
11 | path_out = os.path.join(dir_out, file.name)
12 | pil_image.save(
13 | path_out
14 | )
15 |
16 | assert os.path.isfile(path_out)
17 | logger.info(f"Save file at: {path_out}")
18 |
19 | return path_out
20 |
--------------------------------------------------------------------------------
/streamlit_demo/shared/utils/log.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from loguru import logger
4 |
5 |
6 | def custom_logger():
7 | logger.remove()
8 | logger.add(
9 | sys.stderr,
10 | colorize=True,
11 | format="[{time:MM/DD HH:mm:ss}]> {level: ^8}| {message}",
12 | )
13 |
--------------------------------------------------------------------------------
/streamlit_demo/shared/utils/pages.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 |
4 | def set_page_config():
5 | try:
6 | st.set_page_config(
7 | page_title="Object Detection",
8 | page_icon="🤖",
9 | layout="wide",
10 | initial_sidebar_state="expanded"
11 | )
12 |     except Exception:
13 | pass
14 |
--------------------------------------------------------------------------------
/streamlit_demo/shared/views/__init__.py:
--------------------------------------------------------------------------------
1 | from .app.view import App
--------------------------------------------------------------------------------
/streamlit_demo/shared/views/app/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/streamlit_demo/shared/views/app/__init__.py
--------------------------------------------------------------------------------
/streamlit_demo/shared/views/app/view.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import streamlit as st
4 |
5 | import constants as c
6 | from shared.crud.feedbacks import FeedbackCRUD
7 | from shared.models.engine import Session
8 | from shared.utils.files import save_uploaded_file
9 | from shared.schemas import ModelInput, ModelOutput, Parameters, EditedOutput
10 | from shared.models_ai import get_ai_model, BaseAIModel
11 | from shared.views.canvas.canvas import st_annotate_tool
12 |
13 |
14 | class BaseView(ABC):
15 | @abstractmethod
16 | def view(self, key: str):
17 | ...
18 |
19 |
20 | class UploadView(BaseView):
21 | def view(self, key: str) -> ModelInput | None:
22 | with st.form(f"{key}_upload", clear_on_submit=True):
23 | upload_image = st.file_uploader(
24 | "Upload Image(s)",
25 | accept_multiple_files=False,
26 | type=c.YOLO_SUPPORTED_EXTENSIONS,
27 | key=f"{key}_upload_images"
28 | )
29 |
30 | col1, col2 = st.columns(2)
31 | with col1:
32 | augment = st.radio(
33 | "Augment",
34 | (True, False),
35 | horizontal=True
36 | )
37 | with col2:
38 | agnostic_nms = st.radio(
39 | "Agnostic NMS",
40 | (True, False),
41 | horizontal=True
42 | )
43 | image_size = st.number_input(
44 | "Image Size",
45 | value=640,
46 | step=32,
47 | min_value=640,
48 | max_value=1280
49 | )
50 | min_iou = st.slider(
51 | "Minimum IOU",
52 | min_value=0.0,
53 | max_value=1.0,
54 | value=0.5,
55 | step=0.01
56 | )
57 | min_confident_score = st.slider(
58 | "Minimum Confidence Score",
59 | min_value=0.0,
60 | max_value=1.0,
61 | value=0.2,
62 | step=0.01
63 | )
64 |
65 | submit_btn = st.form_submit_button(
66 | label="Upload",
67 | type="primary",
68 | use_container_width=True
69 | )
70 |
71 | if submit_btn:
72 | upload_image_path: str = save_uploaded_file(
73 | upload_image,
74 | c.USER_DATA_DIR
75 | )
76 |
77 | input_params = {
78 | "augment": augment,
79 | "agnostic_nms": agnostic_nms,
80 | "image_size": image_size,
81 | "min_iou": min_iou,
82 | "min_confident_score": min_confident_score
83 | }
84 |
85 | return ModelInput(
86 | upload_image=upload_image_path,
87 | params=Parameters(**input_params)
88 | )
89 |
90 | return
91 |
92 |
93 | class ImagePanelView(BaseView):
94 | def view(self, key: str, model_output: ModelOutput, image_path: str):
95 | updated_output, selected_index = st_annotate_tool(
96 | regions=model_output,
97 | background_image=image_path,
98 | key=f"{key}_visual",
99 | canvas_height=900,
100 | canvas_width=900
101 | )
102 |
103 | updated_output: ModelOutput
104 | selected_index: int
105 |
106 | return updated_output, selected_index
107 |
108 |
109 | class InfoPanelView(BaseView):
110 | def view(self, key: str, model_output: ModelOutput, selected_index: int) -> EditedOutput | None:
111 | # Counting bboxes
112 | cls_name_dict: dict[str, int] = {c.CLASSES[k]: v for k, v in model_output.count().items()}
113 |
114 | with st.container(border=True):
115 | st.markdown("**Counting**")
116 | st.json(cls_name_dict)
117 |
118 | # View selected bbox
119 | if 0 <= selected_index < len(model_output.xyxysc):
120 | x_min, y_min, x_max, y_max, score, cls = model_output.xyxysc[selected_index]
121 |
122 | with st.expander(label="Object Detail", expanded=True):
123 | cls = st.selectbox(label="Class", options=c.CLASSES, index=int(cls))
124 |
125 | score_in_str = "%.3f" % score
126 |                 st.markdown(f"Confidence score :red[{score_in_str}]")
127 |
128 | cls_index: int = c.CLASSES.index(cls)
129 |
130 | return EditedOutput(cls=cls_index)
131 |
132 |
133 | class App(BaseView):
134 | def __init__(self):
135 | self._upload_view = UploadView()
136 | self._image_panel_view = ImagePanelView()
137 | self._info_panel_view = InfoPanelView()
138 |
139 | self._ai_model: BaseAIModel = get_ai_model(
140 | c.AI_MODEL,
141 | c.AI_MODEL_CONFIGS[c.AI_MODEL]
142 | )
143 |
144 | self.feedback_crud: FeedbackCRUD = FeedbackCRUD(
145 | session=Session()
146 | )
147 |
148 | @property
149 | def model_input(self) -> ModelInput | None:
150 | return st.session_state.get("model_input", None)
151 |
152 | @model_input.setter
153 | def model_input(self, model_in: ModelInput):
154 | st.session_state["model_input"] = model_in
155 |
156 | @property
157 | def model_output(self):
158 | return st.session_state.get("model_output", None)
159 |
160 | @model_output.setter
161 | def model_output(self, model_output: ModelOutput):
162 | st.session_state["model_output"] = model_output
163 |
164 | def view(self, key: str):
165 | with st.sidebar:
166 | st.header("Model Configurations")
167 | st.info("Check out the documentation "
168 | "at [link](https://docs.ultralytics.com/modes/predict/#inference-sources)")
169 |
170 | model_input: ModelInput | None = self._upload_view.view(key=f"{key}_upload_inputs")
171 | if model_input is not None:
172 |             # Run the AI model whenever new input is submitted
173 | with st.spinner("Running AI...."):
174 | model_output: ModelOutput = self._ai_model.process(
175 | image_in=model_input.upload_image,
176 | params=model_input.params
177 | )
178 | st.toast("Finished AI processing", icon="🎉")
179 | self.model_input = model_input
180 | self.model_output = model_output
181 |
182 | if self.model_input is None:
183 | return
184 |
185 | image_col, info_col = st.columns([8, 2])
186 | with image_col:
187 | updated_model_output, selected_index = self._image_panel_view.view(
188 | key=f"{key}_images",
189 | model_output=self.model_output,
190 | image_path=self.model_input.upload_image
191 | )
192 | self.model_output = updated_model_output
193 |
194 | with info_col:
195 | edited_output: EditedOutput | None = self._info_panel_view.view(
196 | key=f"{key}_info",
197 | model_output=self.model_output,
198 | selected_index=selected_index
199 | )
200 |
201 | save = st.button(
202 | "Edit & Save",
203 | key=f"{key}_save_btn",
204 | use_container_width=True,
205 | type="primary"
206 | )
207 |
208 |             if save and edited_output and 0 <= selected_index < len(updated_model_output):
209 | updated_model_output[selected_index][-2] = edited_output.cls
210 | self.model_output = updated_model_output
211 |
212 | self.feedback_crud.create(
213 | image_path=self.model_input.upload_image,
214 | data=self.model_output.to_dict()
215 | )
216 |
217 | st.toast("Saved", icon="🎉")
218 |
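
A minimal entrypoint sketch (not part of the repository) showing how App is intended to be used. The page title and the key string are placeholders; the page-config call mirrors the fragment at the top of this dump.

import streamlit as st

from shared.views import App  # re-exported by shared/views/__init__.py

st.set_page_config(
    page_title="Object Detection Demo",   # hypothetical title
    initial_sidebar_state="expanded",
)

# Instantiating App loads the configured AI model and opens a DB session;
# view() renders the upload form, the annotation canvas, and the info panel.
App().view(key="main")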
--------------------------------------------------------------------------------
/streamlit_demo/shared/views/canvas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cinnamon/ai-bootcamp-2024/87e28e5863621aa989a6ef5d88785505acde5345/streamlit_demo/shared/views/canvas/__init__.py
--------------------------------------------------------------------------------
/streamlit_demo/shared/views/canvas/canvas.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import lru_cache
3 | from typing import Literal
4 |
5 | import streamlit as st
6 | import streamlit.components.v1 as components
7 | import streamlit.elements.image as st_image
8 | from PIL import Image
9 |
10 | from .processor import DataProcessor
11 | from shared.schemas import ModelOutput
12 |
13 |
14 | _RELEASE = True # when packaging, set this to True
15 |
16 |
17 | if not _RELEASE:
18 | _component_func = components.declare_component(
19 | "st_sparrow_labeling",
20 | url="http://localhost:3001",
21 | )
22 | else:
23 | parent_dir = os.path.dirname(os.path.abspath(__file__))
24 | build_dir = os.path.join(parent_dir, "frontend/build")
25 | _component_func = components.declare_component("st_sparrow_labeling", path=build_dir)
26 |
27 |
28 | @lru_cache(1)
29 | def get_background_image_bytes(image_path: str):
30 | background_image = Image.open(image_path)
31 | width, height = background_image.size
32 |
33 | format = st_image._validate_image_format_string(background_image, "PNG")
34 | image_data = _pil_to_bytes(background_image, format)
35 |
36 | return image_data, width
37 |
38 |
39 | def check_image_url(url):
40 | import requests
41 | try:
42 | response = requests.get(url)
43 | # Check if the request was successful
44 | if response.status_code == 200:
45 | return True
46 | else:
47 | return False
48 | except Exception as e:
49 | return False
50 |
51 |
52 | def _pil_to_bytes(
53 |     image: st_image.PILImage,
54 |     format: st_image.ImageFormat = "JPEG",
55 |     quality: int = 100,
56 | ) -> bytes:
57 |     """Convert a PIL image to bytes."""
58 |     import io
59 | 
60 |     tmp = io.BytesIO()
61 | 
62 |     # JPEG cannot store an alpha channel, so convert to RGB first
63 |     if format == "JPEG" and st_image._image_may_have_alpha_channel(image):
64 |         image = image.convert("RGB")
65 | 
66 |     image.save(tmp, format=format, quality=quality)
67 | 
68 |     return tmp.getvalue()
69 |
70 |
71 | def st_annotate_tool(
72 | regions: ModelOutput,
73 | fill_color: str = "#eee",
74 | stroke_width: int = 20,
75 | stroke_color: str = "black",
76 |     background_image: str | None = None,
77 | drawing_mode: Literal["transform", "rect"] = "transform",
78 | point_display_radius: int = 3,
79 | canvas_height: int = 600,
80 | canvas_width: int = 600,
81 | key=None,
82 | ) -> tuple[ModelOutput, int]:
83 |     """Create a drawing canvas in the Streamlit app, pre-populated with the model's bounding
84 |     boxes, and return the (possibly edited) boxes plus the index of the selected box.
85 | 
86 |     Parameters
87 |     ----------
88 |     regions: ModelOutput
89 |         Output from the AI model, a list of (x_min, y_min, x_max, y_max, score, cls)
90 |     fill_color: str
91 |         Fill color for rectangles, as a CSS color value. Defaults to "#eee".
92 |     stroke_width: int
93 |         Width of drawing brush in pixels. Defaults to 20.
94 |     stroke_color: str
95 |         Color of drawing brush in hex. Defaults to "black".
96 |     background_image: str
97 |         Path to the image displayed behind the canvas.
98 |         Automatically scaled to the canvas dimensions.
99 |         Being behind the canvas, it is not sent back to Streamlit on mouse events.
100 |     drawing_mode: {'transform', 'rect'}
101 |         "transform" enables moving/resizing existing boxes, "rect" enables drawing new ones.
102 |         Defaults to "transform".
103 |     point_display_radius: int
104 |         The radius to use when displaying point objects. Defaults to 3.
105 |     canvas_height: int
106 |         Height of canvas in pixels. Defaults to 600.
107 |     canvas_width: int
108 |         Width of canvas in pixels. Defaults to 600.
109 |     key: str
110 |         An optional string to use as the unique key for the widget.
111 |         Assign a key so the component is not remounted every time the script reruns.
112 | 
113 |     Returns
114 |     -------
115 |     new_model_output: ModelOutput containing the edited bounding boxes
116 |     select_index: index of the currently selected box, or -1 if none is selected
117 |     """
118 |     # A zero-sized canvas has nothing to render,
119 |     # so return the regions unchanged
120 | if canvas_height == 0 or canvas_width == 0:
121 | return regions, -1
122 |
123 | background_image_url = None
124 | if background_image:
125 | image_bytes, width = get_background_image_bytes(background_image)
126 |
127 |         # Use Streamlit's in-memory file manager to convert the image to a URL,
128 |         # so it is cached and not re-sent on every rerun
129 | background_image_url = st_image.image_to_url(
130 | image_bytes, width, True, "RGB", "PNG",
131 | f"drawable-canvas-bg-{background_image}-{key}"
132 | )
133 | background_image_url = st._config.get_option("server.baseUrlPath") + background_image_url
134 |
135 | data_processor = DataProcessor()
136 | canvas_rects = data_processor.prepare_canvas_data(regions)
137 |
138 | component_value = _component_func(
139 | fillColor=fill_color,
140 | strokeWidth=stroke_width,
141 | strokeColor=stroke_color,
142 | backgroundImageURL=background_image_url,
143 | canvasHeight=canvas_height,
144 | canvasWidth=canvas_width,
145 | drawingMode=drawing_mode,
146 | initialDrawing=canvas_rects,
147 | displayRadius=point_display_radius,
148 | key=f"{key}_canvas",
149 | default=None,
150 | realtimeUpdateStreamlit=True,
151 | showingMode="All",
152 | displayToolbar=False
153 | )
154 |
155 | if component_value is None:
156 | return regions, -1
157 |
158 | select_index = component_value.get('selectIndex', -1)
159 | new_model_output, select_index = data_processor.prepare_rect_data(
160 | component_value["raw"],
161 | regions,
162 | select_index
163 | )
164 |
165 | return (
166 | new_model_output,
167 | select_index,
168 | )
169 |
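
A minimal usage sketch (not part of the repository) for the component above. The ModelOutput constructor arguments, the image path, and the key are assumptions for illustration; in the demo the call is made by ImagePanelView with the model's real output and the saved upload path.

import streamlit as st

from shared.schemas import ModelOutput
from shared.views.canvas.canvas import st_annotate_tool

# Hypothetical inputs: one detection box (x_min, y_min, x_max, y_max, score, cls)
# and a previously saved image file.
detections = ModelOutput(xyxysc=[[10, 20, 110, 220, 0.9, 0]])   # constructor args assumed
image_path = "user_data/example.jpg"                            # placeholder path

edited_output, selected_index = st_annotate_tool(
    regions=detections,
    background_image=image_path,
    drawing_mode="transform",
    canvas_height=900,
    canvas_width=900,
    key="demo_visual",
)

if selected_index >= 0:
    st.write("Selected box:", edited_output.xyxysc[selected_index])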
--------------------------------------------------------------------------------
/streamlit_demo/shared/views/canvas/frontend/.env:
--------------------------------------------------------------------------------
1 | # Run the component's dev server on :3001
2 | # (The Streamlit dev server already runs on :3000)
3 | PORT=3001
4 |
5 | # Don't automatically open the web browser on `npm run start`.
6 | BROWSER=none
7 |
--------------------------------------------------------------------------------
/streamlit_demo/shared/views/canvas/frontend/.gitignore:
--------------------------------------------------------------------------------
1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
2 |
3 | # package-lock.json
4 |
5 | # dependencies
6 | /node_modules
7 | /.pnp
8 | .pnp.js
9 |
10 | # testing
11 | /coverage
12 |
13 | # production
14 | /build
15 |
16 | # misc
17 | .DS_Store
18 | .env.local
19 | .env.development.local
20 | .env.test.local
21 | .env.production.local
22 |
23 | npm-debug.log*
24 | yarn-debug.log*
25 | yarn-error.log*
26 |
--------------------------------------------------------------------------------
/streamlit_demo/shared/views/canvas/frontend/.prettierrc:
--------------------------------------------------------------------------------
1 | {
2 | "endOfLine": "lf",
3 | "semi": false,
4 | "trailingComma": "es5"
5 | }
6 |
--------------------------------------------------------------------------------
/streamlit_demo/shared/views/canvas/frontend/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "drawable_canvas",
3 | "version": "0.9.3",
4 | "private": true,
5 | "dependencies": {
6 | "apache-arrow": "^0.17.0",
7 | "event-target-shim": "^5.0.1",
8 | "fabric": "4.4.0",
9 | "hoist-non-react-statics": "^3.3.2",
10 | "lodash": "^4.17.20",
11 | "react": "^16.13.1",
12 | "react-dom": "^16.13.1",
13 | "react-scripts": "4.0.3",
14 | "streamlit-component-lib": "^1.3.0",
15 | "typescript": "^4.6.3"
16 | },
17 | "devDependencies": {
18 | "@types/fabric": "^3.6.2",
19 | "@types/hoist-non-react-statics": "^3.3.1",
20 | "@types/jest": "^24.0.0",
21 | "@types/lodash": "^4.14.161",
22 | "@types/node": "^12.0.0",
23 | "@types/react": "^16.9.0",
24 | "@types/react-dom": "^16.9.0"
25 | },
26 | "scripts": {
27 | "start": "react-scripts start",
28 | "build": "react-scripts build",
29 | "test": "react-scripts test",
30 | "eject": "react-scripts eject"
31 | },
32 | "eslintConfig": {
33 | "extends": "react-app"
34 | },
35 | "browserslist": {
36 | "production": [
37 | ">0.2%",
38 | "not dead",
39 | "not op_mini all"
40 | ],
41 | "development": [
42 | "last 1 chrome version",
43 | "last 1 firefox version",
44 | "last 1 safari version"
45 | ]
46 | },
47 | "homepage": "."
48 | }
49 |
--------------------------------------------------------------------------------
/streamlit_demo/shared/views/canvas/frontend/public/index.html:
--------------------------------------------------------------------------------
(Markup stripped during extraction: this appears to be the component frontend's standard public/index.html shell; only the page title, "Streamlit Component", is recoverable.)
--------------------------------------------------------------------------------
/streamlit_demo/shared/views/canvas/frontend/src/DrawableCanvas.tsx:
--------------------------------------------------------------------------------
1 | import React, { useEffect, useState } from "react"
2 | import {
3 | ComponentProps,
4 | Streamlit,
5 | withStreamlitConnection,
6 | } from "streamlit-component-lib"
7 | import { fabric } from "fabric"
8 | import { isEqual } from "lodash"
9 |
10 | import CanvasToolbar from "./components/CanvasToolbar"
11 |
12 | import { useCanvasState } from "./DrawableCanvasState"
13 | import { tools, FabricTool } from "./lib"
14 |
15 | function getStreamlitBaseUrl(): string | null {
16 | const params = new URLSearchParams(window.location.search)
17 | const baseUrl = params.get("streamlitUrl")
18 | if (baseUrl == null) {
19 | return null
20 | }
21 |
22 | try {
23 | return new URL(baseUrl).origin
24 | } catch {
25 | return null
26 | }
27 | }
28 |
29 | interface CustomFabricCanvas extends fabric.Canvas {
30 | isDragging?: boolean;
31 | selection?: boolean;
32 | lastPosX?: number;
33 | lastPosY?: number;
34 |
35 | secondTimeAccess?: boolean;
36 | currentState?: Object;
37 | showingMode?: string;
38 |
39 | }
40 |
41 | /**
42 |  * Arguments the component receives from the Python side via Streamlit
43 | */
44 | export interface PythonArgs {
45 | fillColor: string
46 | strokeWidth: number
47 | strokeColor: string
48 | backgroundColor: string
49 | backgroundImageURL: string
50 | realtimeUpdateStreamlit: boolean
51 | canvasWidth: number
52 | canvasHeight: number
53 | drawingMode: string
54 | initialDrawing: Object
55 | displayToolbar: boolean
56 | displayRadius: number
57 | showingMode: string
58 | }
59 |
60 | /**
61 | * Define logic for the canvas area
62 | */
63 | const DrawableCanvas = ({ args }: ComponentProps) => {
64 | const {
65 | canvasWidth,
66 | canvasHeight,
67 | backgroundColor,
68 | backgroundImageURL,
69 | realtimeUpdateStreamlit,
70 | drawingMode,
71 | fillColor,
72 | strokeWidth,
73 | strokeColor,
74 | displayRadius,
75 | initialDrawing,
76 | displayToolbar,
77 | showingMode
78 | }: PythonArgs = args
79 |
80 | /**
81 | * State initialization
82 | */
83 | const [canvas, setCanvas] = useState(new fabric.Canvas("c") as CustomFabricCanvas);
84 | canvas.stopContextMenu = true
85 | canvas.fireRightClick = true
86 |
87 | const [selectedRect, setSelectedRect] = useState(-1)
88 |
89 | const [backgroundCanvas, setBackgroundCanvas] = useState(new fabric.Canvas("c") as CustomFabricCanvas);
90 | const {
91 | canvasState: {
92 | action: { shouldReloadCanvas, forceSendToStreamlit },
93 | currentState,
94 | initialState,
95 | },
96 | saveState,
97 | undo,
98 | redo,
99 | canUndo,
100 | canRedo,
101 | forceStreamlitUpdate,
102 | resetState,
103 | } = useCanvasState()
104 |
105 |
106 | /*
107 | * Load background image from URL
108 | */
109 | // const params = new URLSearchParams(window.location.search);
110 | // const baseUrl = params.get('streamlitUrl')
111 | const baseUrl = getStreamlitBaseUrl() ?? ""
112 | let img = new fabric.Image()
113 |
114 | fabric.Image.fromURL(baseUrl + backgroundImageURL, function(oImg) {
115 | img = oImg
116 | img.selectable = false;
117 | backgroundCanvas.add(img);
118 |
119 | if (img.width == null || img.height == null){
120 | return
121 | }
122 |
123 | // only initialize (image + rects) for canvas 1
124 | const isSecondTimes = (canvas.secondTimeAccess || false)
125 |
126 |     /*
127 |      * The first time the UI is created, align the canvas with the image
128 |      * by adjusting the zoom only (no resizing).
129 |      * This happens only on the first render.
130 |      */
131 |     if (isSecondTimes === false){ // It means this is the first time
132 |       console.log("Render First Time")
133 | canvas.loadFromJSON(initialDrawing, () => {})
134 |
135 | // initialize zoom
136 | const widthRatio = canvas.getWidth() / img.width;
137 | const heightRatio = canvas.getHeight() / img.height;
138 | const zoom = Math.min(widthRatio, heightRatio)
139 | canvas.setZoom(zoom);
140 | backgroundCanvas.setZoom(zoom)
141 |
142 | canvas.secondTimeAccess = true
143 | canvas.requestRenderAll()
144 | backgroundCanvas.requestRenderAll()
145 |
146 | canvas.currentState = { ...initialDrawing }
147 | canvas.showingMode = showingMode
148 | }
149 |
150 |     /*
151 |      * The user can choose which group of boxes to visualize (keys only, values only, or both).
152 |      * When the current showingMode differs from the previous one, reload the
153 |      * initialDrawing to refresh the canvas.
154 |      * [07.10.2023] The code below should be removed; this path is no longer supported because of its low performance.
155 |      */
156 | if (canvas.showingMode !== showingMode){
157 | canvas.showingMode = showingMode
158 |
159 | if (!isEqual(canvas.currentState, initialDrawing)){
160 | canvas.loadFromJSON(initialDrawing, () => {
161 | canvas.currentState = { ...initialDrawing }
162 |
163 | canvas.renderAll()
164 | })
165 | }
166 | }
167 |
168 | });
169 |
170 | /**
171 | * Initialize canvases on component mount
172 | * NB: Remount component by changing its key instead of defining deps
173 | */
174 | useEffect(() => {
175 | const c = new fabric.Canvas("canvas", {
176 | enableRetinaScaling: false,
177 | })
178 | const imgC = new fabric.Canvas("backgroundimage-canvas", {
179 | enableRetinaScaling: false,
180 | })
181 | setCanvas(c)
182 | setBackgroundCanvas(imgC)
183 | Streamlit.setFrameHeight()
184 | }, [])
185 |
186 |
187 | /**
188 | * If state changed from undo/redo/reset, update user-facing canvas
189 | */
190 | useEffect(() => {
191 | if (shouldReloadCanvas) {
192 | canvas.loadFromJSON(currentState, () => {})
193 | }
194 | }, [canvas, shouldReloadCanvas, currentState])
195 |
196 |
197 | /**
198 | * Update canvas with selected tool
199 |    * PS: initialDrawing is listed as a dependency so a user drawing update re-initializes the tool
200 | */
201 | useEffect(() => {
202 | // Update canvas events with selected tool
203 | const selectedTool = new tools[drawingMode](canvas) as FabricTool
204 | const cleanupToolEvents = selectedTool.configureCanvas({
205 | fillColor: fillColor,
206 | strokeWidth: strokeWidth,
207 | strokeColor: strokeColor,
208 | displayRadius: displayRadius
209 | })
210 |
211 | /*
212 | * Ensure zoom/pan do not exceed the boundary of canvas.
213 | */
214 | let ensure_boundary: () => void = function (): void {
215 | const T = canvas.viewportTransform;
216 |
217 | if (img.aCoords == null || T == null) return
218 |
219 | const brRaw = img.aCoords.br
220 | const tlRaw = img.aCoords.tl
221 |
222 | const br = fabric.util.transformPoint(brRaw, T);
223 | const tl = fabric.util.transformPoint(tlRaw, T);
224 |
225 | const {
226 | x: left,
227 | y: top
228 | } = tl;
229 |
230 | const {
231 | x: right,
232 | y: bottom
233 | } = br;
234 |
235 | const width = canvas.getWidth()
236 | const height = canvas.getHeight()
237 |
238 | // calculate how far to translate to line up the edge of the object with
239 | // the edge of the canvas
240 | const dLeft = Math.abs(right - width);
241 | const dRight = Math.abs(left);
242 | const dUp = Math.abs(bottom - height);
243 | const dDown = Math.abs(top);
244 | const maxDx = Math.min(dLeft, dRight);
245 | const maxDy = Math.min(dUp, dDown);
246 |
247 | // if the object is larger than the canvas, clamp translation such that
248 | // we don't push the opposite boundary past the edge
249 | const leftIsOver = left < 0;
250 | const rightIsOver = right > width;
251 | const topIsOver = top < 0;
252 | const bottomIsOver = bottom > height;
253 |
254 | const translateLeft = rightIsOver && !leftIsOver;
255 | const translateRight = leftIsOver && !rightIsOver;
256 | const translateUp = bottomIsOver && !topIsOver;
257 | const translateDown = topIsOver && !bottomIsOver;
258 |
259 | const dx = translateLeft ? -maxDx : translateRight ? maxDx : 0;
260 | const dy = translateUp ? -maxDy : translateDown ? maxDy : 0;
261 |
262 | if (dx || dy) {
263 | T[4] += dx;
264 | T[5] += dy;
265 | canvas.requestRenderAll();
266 |
267 | backgroundCanvas.setViewportTransform(T)
268 | backgroundCanvas.requestRenderAll()
269 | }
270 |
271 | };
272 |
273 | /*
274 | * Mouse down event.
275 |      * If the user holds the Alt key while moving the mouse, drag (pan) the image.
276 | */
277 | canvas.on("mouse:down", function (this: CustomFabricCanvas, opt) {
278 | var evt = opt.e as MouseEvent;
279 |
280 | if (evt.altKey === true) {
281 | this.isDragging = true;
282 | this.selection = false;
283 | this.lastPosX = evt.clientX;
284 | this.lastPosY = evt.clientY;
285 |
286 | canvas.setCursor('grab')
287 | // canvas.discardActiveObject();
288 | // canvas.requestRenderAll();
289 |
290 | }
291 |
292 | if (opt.target) {
293 | if (opt.target.type === 'rect') {
294 |
295 | const selectObject = canvas.getActiveObject()
296 | const selectIndex = canvas.getObjects().indexOf(selectObject)
297 |
298 | selectObject.selectionBackgroundColor = 'rgba(63,245,39,0.5)'
299 |
300 | // Return selected object.
301 | setSelectedRect(selectIndex)
302 |
303 | const data = canvas
304 | .getContext()
305 | .canvas.toDataURL()
306 |
307 | Streamlit.setComponentValue({
308 | data: data,
309 | width: canvas.getWidth(),
310 | height: canvas.getHeight(),
311 | raw: canvas.toObject(),
312 | selectIndex: selectIndex
313 | })
314 |
315 | }
316 | } else {
317 | setSelectedRect(-1)
318 | }
319 | })
320 |
321 |
322 | /*
323 |      * Mouse move event. Only has an effect while the Alt key is held (panning).
324 | */
325 | canvas.on("mouse:move", function (this: CustomFabricCanvas, opt) {
326 | var e = opt.e as MouseEvent
327 |
328 | if (this.isDragging || false) {
329 | canvas.setCursor('grab')
330 | const delta = new fabric.Point( e.movementX, e.movementY )
331 |
332 | canvas.relativePan( delta )
333 | backgroundCanvas.relativePan( delta )
334 |
335 | ensure_boundary()
336 |
337 | e.preventDefault();
338 | e.stopPropagation();
339 |
340 | }
341 | })
342 |
343 | /*
344 | * Mouse wheel event - Scale in/out
345 | */
346 | canvas.on("mouse:wheel", function (this: CustomFabricCanvas, opt) {
347 | var e = opt.e as WheelEvent;
348 | var delta = e.deltaY;
349 | var zoom = canvas.getZoom();
350 | zoom *= 0.999 ** delta;
351 | if (zoom > 10) zoom = 10;
352 | if (zoom < 0.1) zoom = 0.1;
353 | var point = new fabric.Point(e.offsetX, e.offsetY);
354 | canvas.zoomToPoint(point, zoom);
355 | backgroundCanvas.zoomToPoint(point, zoom);
356 |
357 | e.preventDefault();
358 | e.stopPropagation();
359 | })
360 |
361 | canvas.on("mouse:up", (e: any) => {
362 | /*
363 |        * Several interactions can end with a mouse:up event:
364 | * 1. [rect] create new object
365 | * 2. [transform] resize selected object
366 | * 3. [transform] choose selected object
367 | * 4. [transform] delete selected object
368 | */
369 |
370 | // saveState(canvas.toJSON());
371 |
372 | var isEqualState = isEqual( canvas.toObject(), canvas.currentState )
373 | if ( (isEqualState === false) && (drawingMode === 'transform') ){
374 | canvas.currentState = { ...canvas.toObject() }
375 |
376 | const selectObject = canvas.getActiveObject()
377 | const selectIndex = canvas.getObjects().indexOf(selectObject)
378 |
379 | const data = canvas
380 | .getContext()
381 | .canvas.toDataURL()
382 |
383 | Streamlit.setComponentValue({
384 | data: data,
385 | width: canvas.getWidth(),
386 | height: canvas.getHeight(),
387 | raw: canvas.toObject(),
388 | selectIndex: selectIndex
389 | })
390 |
391 | }
392 |
393 |       // Reset the dragging state once the mouse button is released
394 | canvas.isDragging = false;
395 | canvas.selection = true;
396 | canvas.setCursor("default")
397 | });
398 |
399 | canvas.on("mouse:dblclick", () => {
400 | if (drawingMode === 'transform') {
401 | const selectObject = canvas.getActiveObject()
402 | const selectIndex = canvas.getObjects().indexOf(selectObject)
403 |
404 | canvas.remove(selectObject)
405 |
406 | const data = canvas
407 | .getContext()
408 | .canvas.toDataURL()
409 |
410 | Streamlit.setComponentValue({
411 | data: data,
412 | width: canvas.getWidth(),
413 | height: canvas.getHeight(),
414 | raw: canvas.toObject(),
415 | selectIndex: selectIndex
416 | })
417 |
418 | }
419 |
420 | })
421 |
422 |     // Clean up tool events and the canvas handlers that send data to Streamlit
423 | return () => {
424 | cleanupToolEvents()
425 | canvas.off("mouse:down")
426 | canvas.off("mouse:move")
427 | canvas.off("mouse:up")
428 | canvas.off("mouse:wheel")
429 | canvas.off("mouse:dblclick")
430 | backgroundCanvas.off("mouse:down")
431 | backgroundCanvas.off("mouse:move")
432 | backgroundCanvas.off("mouse:up")
433 | backgroundCanvas.off("mouse:wheel")
434 | backgroundCanvas.off("mouse:dblclick")
435 | }
436 | }, [
437 | canvas,
438 | backgroundCanvas,
439 | strokeWidth,
440 | strokeColor,
441 | displayRadius,
442 | fillColor,
443 | drawingMode,
444 | initialDrawing,
445 | saveState,
446 | forceStreamlitUpdate,
447 | img
448 | ])
449 |
450 | /**
451 | * Render canvas w/ toolbar
452 | */
453 | return (
454 |