├── .gitignore
├── LICENSE
├── README.md
├── conda_env.yml
├── data-exploration.ipynb
├── demo.ipynb
├── exp1.ipynb
├── exp2.ipynb
├── images
│   ├── 6.png
│   ├── leave.png
│   └── varsity.png
├── sys.png
├── task_descriptions.json
├── tests
│   └── test_data.py
└── videoannotator
├── __init__.py
├── config.py
├── data.py
├── dp.py
├── experiments.py
├── io.py
└── models.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Video Annotator
2 |
3 | ## Video understanding
4 |
5 | Here's a quick demo to get started:
6 | ```python
7 | from videoannotator import data, models
8 | import numpy as np
9 |
10 | # pick a label to work with
11 | label = "car-chase"
12 |
13 | # grab the annotated dataset for the label
14 | dataset = data.get_aggregate_labeled_dataset(label=label)
15 |
16 | # train a classifier with the labeled data
17 | clf = models.LogisticRegression(scoring='average_precision')
18 | clf.fit(dataset.x, dataset.y)
19 |
20 | # stand-in for your clip's 512-d video embedding (extract real ones with CLIP4Clip)
21 | np.random.seed(0)
22 | emb = np.random.rand(1, 512)
23 |
24 | # score is between 0 and 1
25 | # a higher score means the clip is more likely to contain `label`
26 | score = clf.predict_proba(emb).item()
27 | ```
28 |
29 | See more details in [this notebook](demo.ipynb).
30 |
31 | ## What is Video Annotator?
32 |
33 | The best source is our paper [here](https://arxiv.org/abs/2402.06560). Here's the abstract:
34 |
35 | > High-quality and consistent annotations are fundamental to the successful development of robust machine learning models. Traditional data annotation methods are resource-intensive and inefficient, often leading to a reliance on third-party annotators who are not the domain experts. Hard samples, which are usually the most informative for model training, tend to be difficult to label accurately and consistently without business context. These can arise unpredictably during the annotation process, requiring a variable number of iterations and rounds of feedback, leading to unforeseen expenses and time commitments to guarantee quality.
36 | >
37 | >We posit that more direct involvement of domain experts, using a human-in-the-loop system, can resolve many of these practical challenges. We propose a novel framework we call Video Annotator (VA) for annotating, managing, and iterating on video classification datasets. Our approach offers a new paradigm for an end-user-centered model development process, enhancing the efficiency, usability, and effectiveness of video classifiers. Uniquely, VA allows for a continuous annotation process, seamlessly integrating data collection and model training.
38 | >
39 | > We leverage the zero-shot capabilities of vision-language foundation models combined with active learning techniques, and demonstrate that VA enables the efficient creation of high-quality models. VA achieves a median 6.8 point improvement in Average Precision relative to the most competitive baseline across a wide-ranging assortment of tasks. We release a dataset with 153k labels across 56 video understanding tasks annotated by three professional video editors using VA, and also release code to replicate our experiments.
40 |
41 | ## System
42 | 
43 |
44 | ## Data
45 |
46 | 1. Download the data tar file from [here](https://drive.google.com/file/d/107Dm1UBYQ8UOSs9rcLA1TGExYKbcg7IA/view) and unpack into your folder of choice.
47 | 2. Set `DATA_BASE_PATH` to that folder in `videoannotator/config.py`, e.g. `/root/data/videoannotator/` (see the sketch below).
48 | 3. View [this notebook](data-exploration.ipynb) to understand how the data is organized.
49 |
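For example, with the data unpacked to `/root/data/videoannotator/`, the config line would look like this sketch:

```python
# videoannotator/config.py: point this at your unpacked data folder
DATA_BASE_PATH = '/root/data/videoannotator/'
```
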
50 | Once unpacked, the data should look like this:
51 | ```
52 | ├── videoannotator
53 | │ ├── agg
54 | │ │ ├── action.json
55 | │ │ ├── aerial.json
56 | │ │ ├── ...
57 | │ ├── ave
58 | │ │ ├── keys-to-labels.json
59 | │ │ ├── labels-mapped.json
60 | │ │ ├── ...
61 | │ ├── checkpoints
62 | │ │ ├── action.json
63 | │ │ ├── aerial.json
64 | │ │ ├── ...
65 | │ ├── cmp
66 | │ │ ├── action.json
67 | │ │ ├── aerial.json
68 | │ │ ├── ...
69 | │ ├── embeddings.h5
70 | │ ├── queries.json
71 | │ ├── shot-data.csv
72 | │ ├── text-embeddings.json
73 | ```
74 |
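As a quick sanity check of the unpacked files, something like the sketch below should work; the file names come from the tree above, while the exact keys and schemas are best explored in [the data exploration notebook](data-exploration.ipynb):

```python
import json
import h5py  # listed in conda_env.yml

base = '/root/data/videoannotator/'  # your DATA_BASE_PATH

# embeddings.h5 stores the per-shot video embeddings
with h5py.File(base + 'embeddings.h5', 'r') as f:
    print(len(f.keys()))  # number of stored embeddings

# each agg/<label>.json holds the aggregated annotations for one label
agg_action = json.load(open(base + 'agg/action.json'))
```
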
75 | ## Environment setup
76 | ```shell
77 | conda env create -f conda_env.yml
78 | conda activate videoannotator
79 | ```
80 |
81 | ## Running experiments
82 | - [Experiment 1 notebook](exp1.ipynb)
83 | - [Experiment 2 notebook](exp2.ipynb)
84 |
85 | ## Citation
86 |
87 | If you find this work useful, please consider citing:
88 |
89 | ```bibtex
90 | @misc{ziai2024video,
91 | title={Video Annotator: A framework for efficiently building video classifiers using vision-language models and active learning},
92 | author={Amir Ziai and Aneesh Vartakavi},
93 | year={2024},
94 | eprint={2402.06560},
95 | archivePrefix={arXiv},
96 | primaryClass={cs.CV}
97 | }
98 | ```
99 |
--------------------------------------------------------------------------------
/conda_env.yml:
--------------------------------------------------------------------------------
1 | # TODO
2 | name: videoannotator
3 | channels:
4 | - pytorch
5 | - defaults
6 | - conda-forge
7 | - intel
8 | dependencies:
9 | - python==3.7.4
10 | - pip==20.2.4
11 | - scikit-learn==0.24.2
12 | - pandas==1.3.4
13 | - numpy==1.19.5
14 | - pytorch==1.10.0
15 | - toolz==0.10.0
16 | - pip:
17 | - xgboost==1.2.1
18 | - pytorch-metric-learning==0.9.99
19 | - scipy==1.5.4
20 | - mmh3
21 | - h5py
22 |
--------------------------------------------------------------------------------
/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "id": "f9ecf656-36e9-428f-b8a4-020b90f2486e",
7 | "metadata": {},
8 | "source": [
9 | "# Video Annotator\n",
10 | "## Video understanding demo\n",
11 | "\n",
12 | "**Note**:\n",
13 | "- read the `data` section of the README first.\n",
14 | "- this notebook assumes that you have already\n",
15 | " - downloaded data\n",
16 | " - placed it in the directory expected by `videoannotator/config.py`"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 18,
22 | "id": "c2f837a9-84d9-4f93-abda-670e59b2a89a",
23 | "metadata": {
24 | "execution": {
25 | "iopub.execute_input": "2024-04-24T16:17:53.477481Z",
26 | "iopub.status.busy": "2024-04-24T16:17:53.476871Z",
27 | "iopub.status.idle": "2024-04-24T16:17:53.480312Z",
28 | "shell.execute_reply": "2024-04-24T16:17:53.479816Z",
29 | "shell.execute_reply.started": "2024-04-24T16:17:53.477452Z"
30 | }
31 | },
32 | "outputs": [],
33 | "source": [
34 | "import json\n",
35 | "import numpy as np\n",
36 | "import pandas as pd\n",
37 | "from IPython.display import HTML\n",
38 | "\n",
39 | "from videoannotator import data, models, io, config as cfg"
40 | ]
41 | },
42 | {
43 | "attachments": {},
44 | "cell_type": "markdown",
45 | "id": "22adfc3e-63ad-41b4-a50f-44a9c58a0158",
46 | "metadata": {},
47 | "source": [
48 | "## helpers"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 19,
54 | "id": "4c9e40ed-3a7b-41f2-a8d7-c7f365978031",
55 | "metadata": {
56 | "execution": {
57 | "iopub.execute_input": "2024-04-24T16:17:54.134269Z",
58 | "iopub.status.busy": "2024-04-24T16:17:54.133652Z",
59 | "iopub.status.idle": "2024-04-24T16:17:54.139614Z",
60 | "shell.execute_reply": "2024-04-24T16:17:54.139129Z",
61 | "shell.execute_reply.started": "2024-04-24T16:17:54.134241Z"
62 | }
63 | },
64 | "outputs": [],
65 | "source": [
66 | "def get_model(label: str) -> models.LogisticRegression:\n",
67 | " \"\"\"\n",
68 | " Trains a model given all the aggregated annotations for `label`.\n",
69 | " \"\"\"\n",
70 | " ds = data.get_aggregate_labeled_dataset(label=label)\n",
71 | " clf = models.LogisticRegression(scoring='average_precision')\n",
72 | " clf.fit(ds.x, ds.y)\n",
73 | " return clf\n",
74 | "\n",
75 | "def get_score(label: str, key: str) -> float:\n",
76 | " \"\"\"\n",
77 | " Get classification score for `label` against shot `key`.\n",
78 | " \"\"\"\n",
79 | " if label not in clfs:\n",
80 | " clfs[label] = get_model(label=label)\n",
81 | " emb = data.get_embeddings_dict()[key]\n",
82 | " return clfs[label].predict_proba(emb[None, :]).item()\n",
83 | "\n",
84 | "def get_score_all_labels(key: str) -> dict:\n",
85 | " \"\"\"\n",
86 | " Scores all labels against the shot represented by `key`.\n",
87 | " \"\"\"\n",
88 | " return {label: get_score(label=label, key=key) for label in cfg.LABELS}\n",
89 | "\n",
90 | "def get_color(val: float) -> str:\n",
91 | " \"\"\"Used for table cell formatting.\"\"\"\n",
92 | " if val > 0.5:\n",
93 | " r, g = 0, 128\n",
94 | " a = val\n",
95 | " else:\n",
96 | " r, g = 255, 0\n",
97 | " a = 1 - val\n",
98 | " return f'color: rgba({r}, {g}, 0, {a})'"
99 | ]
100 | },
101 | {
102 | "attachments": {},
103 | "cell_type": "markdown",
104 | "id": "b820a5d4-9938-4138-8bea-5c08efa96d2b",
105 | "metadata": {},
106 | "source": [
107 | "## load a few models"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": 40,
113 | "id": "1e4e3a8a-0e02-49f5-9926-649b51cb0c55",
114 | "metadata": {
115 | "execution": {
116 | "iopub.execute_input": "2024-04-24T16:32:42.113540Z",
117 | "iopub.status.busy": "2024-04-24T16:32:42.112924Z",
118 | "iopub.status.idle": "2024-04-24T16:32:47.059614Z",
119 | "shell.execute_reply": "2024-04-24T16:32:47.058607Z",
120 | "shell.execute_reply.started": "2024-04-24T16:32:42.113512Z"
121 | }
122 | },
123 | "outputs": [],
124 | "source": [
125 | "labels = ('action', 'establishing-shots', 'day', 'car-chase')\n",
126 | "clfs = {label: get_model(label=label) for label in labels}"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": 21,
132 | "id": "27467ee9-092f-4276-a0fc-184ed09908ab",
133 | "metadata": {
134 | "execution": {
135 | "iopub.execute_input": "2024-04-24T16:17:59.821195Z",
136 | "iopub.status.busy": "2024-04-24T16:17:59.820774Z",
137 | "iopub.status.idle": "2024-04-24T16:17:59.828819Z",
138 | "shell.execute_reply": "2024-04-24T16:17:59.828144Z",
139 | "shell.execute_reply.started": "2024-04-24T16:17:59.821153Z"
140 | }
141 | },
142 | "outputs": [
143 | {
144 | "data": {
145 | "text/plain": [
146 | "'action, aerial, alcohol, anger, animal, car-chase, character-focus, closeup, cowboy-shot, day, drama, drugs, dutch-angle, establishing-shots, extreme-close-up, extreme-wide-shot, eye-level, fantasy, fight, golden-hour, gore, group-shot, handheld, happy, high-angle, horror, insert-shot, interior, interview, intimacy, jump-scare, laughter, low-angle, medium, nudity, object, over-the-shoulder-shot, overhead-shot, pan, point-of-view-shot, romance, run, sad, scared, sci-fi, shutter-shot, single-shot, slow-motion, smoking, static-shot, tilt-shot, timelapse, two-shot, violence, wide, zoom'"
147 | ]
148 | },
149 | "execution_count": 21,
150 | "metadata": {},
151 | "output_type": "execute_result"
152 | }
153 | ],
154 | "source": [
155 | "# the full set of available labels is at cfg.LABELS\n",
156 | "', '.join(sorted(cfg.LABELS))"
157 | ]
158 | },
159 | {
160 | "attachments": {},
161 | "cell_type": "markdown",
162 | "id": "16126115-6539-4825-b758-c072ad2c0a60",
163 | "metadata": {},
164 | "source": [
165 | "## load a few sample embeddings"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": 22,
171 | "id": "fa0bdfeb-3564-48a1-bca3-bd23f2eb3831",
172 | "metadata": {
173 | "execution": {
174 | "iopub.execute_input": "2024-04-24T16:17:59.830507Z",
175 | "iopub.status.busy": "2024-04-24T16:17:59.829997Z",
176 | "iopub.status.idle": "2024-04-24T16:18:00.025539Z",
177 | "shell.execute_reply": "2024-04-24T16:18:00.025006Z",
178 | "shell.execute_reply.started": "2024-04-24T16:17:59.830469Z"
179 | }
180 | },
181 | "outputs": [],
182 | "source": [
183 | "samples = json.load(open('sample-embeddings.json'))\n",
184 | "samples = {k: np.array(v) for k, v in samples.items()}"
185 | ]
186 | },
187 | {
188 | "attachments": {},
189 | "cell_type": "markdown",
190 | "id": "1644163e-f1e7-4771-9025-2bff678820c8",
191 | "metadata": {},
192 | "source": [
193 | "## score the sample embeddings"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": 41,
199 | "id": "0b0c6702-edca-47e4-93b6-1f5044556f86",
200 | "metadata": {
201 | "execution": {
202 | "iopub.execute_input": "2024-04-24T16:32:47.066975Z",
203 | "iopub.status.busy": "2024-04-24T16:32:47.064680Z",
204 | "iopub.status.idle": "2024-04-24T16:32:47.073915Z",
205 | "shell.execute_reply": "2024-04-24T16:32:47.073259Z",
206 | "shell.execute_reply.started": "2024-04-24T16:32:47.066931Z"
207 | }
208 | },
209 | "outputs": [],
210 | "source": [
211 | "def _get_record(title, emb):\n",
212 | " d = dict(\n",
213 | " title=title,\n",
214 | " video_thumbnail=f'
',\n",
215 | " )\n",
216 | " scores = {\n",
217 | " f'{label}_score': clf.predict_proba(emb[None, :]).item()\n",
218 | " for label, clf in clfs.items()\n",
219 | " }\n",
220 | " return {**d, **scores}"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 42,
226 | "id": "a7c2a338-9d7a-49b8-8f7d-1ab9dea66930",
227 | "metadata": {
228 | "execution": {
229 | "iopub.execute_input": "2024-04-24T16:32:47.076581Z",
230 | "iopub.status.busy": "2024-04-24T16:32:47.076044Z",
231 | "iopub.status.idle": "2024-04-24T16:32:47.296071Z",
232 | "shell.execute_reply": "2024-04-24T16:32:47.295541Z",
233 | "shell.execute_reply.started": "2024-04-24T16:32:47.076544Z"
234 | }
235 | },
236 | "outputs": [
237 | {
238 | "data": {
239 | "text/html": [
240 | "\n",
278 | "
\n",
279 | " \n",
280 | " \n",
281 | " | \n",
282 | " | \n",
283 | " action_score | \n",
284 | " establishing-shots_score | \n",
285 | " day_score | \n",
286 | " car-chase_score | \n",
287 | "
\n",
288 | " \n",
289 | " title | \n",
290 | " video_thumbnail | \n",
291 | " | \n",
292 | " | \n",
293 | " | \n",
294 | " | \n",
295 | "
\n",
296 | " \n",
297 | " \n",
298 | " \n",
299 | " Operation Varsity Blues | \n",
300 | "  | \n",
301 | " 0.18 | \n",
302 | " 0.99 | \n",
303 | " 0.84 | \n",
304 | " 0.01 | \n",
305 | "
\n",
306 | " \n",
307 | " 6 Underground | \n",
308 | "  | \n",
309 | " 0.73 | \n",
310 | " 0.30 | \n",
311 | " 0.69 | \n",
312 | " 0.78 | \n",
313 | "
\n",
314 | " \n",
315 | " Leave the world behind | \n",
316 | "  | \n",
317 | " 0.36 | \n",
318 | " 0.12 | \n",
319 | " 0.11 | \n",
320 | " 0.04 | \n",
321 | "
\n",
322 | " \n",
323 | "
\n"
324 | ],
325 | "text/plain": [
326 | ""
327 | ]
328 | },
329 | "execution_count": 42,
330 | "metadata": {},
331 | "output_type": "execute_result"
332 | }
333 | ],
334 | "source": [
335 | "pngs = {'Operation Varsity Blues': 'varsity.png', '6 Underground': '6.png', 'Leave the world behind': 'leave.png'}\n",
336 | "df = pd.DataFrame(\n",
337 | " _get_record(title, emb)\n",
338 | " for title, emb in samples.items()\n",
339 | ").set_index(['title', 'video_thumbnail']) \n",
340 | "df = df.style.applymap(get_color).format(lambda x: f'{x:.2f}')\n",
341 | "HTML(df.to_html(escape=False))"
342 | ]
343 | },
344 | {
345 | "attachments": {},
346 | "cell_type": "markdown",
347 | "id": "6a3cc399-f546-4197-8cbf-071ce7ed6d3b",
348 | "metadata": {},
349 | "source": [
350 | "## bring your own C4C embeddings\n",
351 | "\n",
352 | "Use [Clip4CLIP](https://github.com/ArrowLuo/CLIP4Clip) to extract video embeddings from your own videos and then apply."
353 | ]
354 | },
355 | {
356 | "cell_type": "code",
357 | "execution_count": 25,
358 | "id": "3d1b3dc0-6a1e-4a38-9d64-1d8467bd499d",
359 | "metadata": {
360 | "execution": {
361 | "iopub.execute_input": "2024-04-24T16:18:01.144671Z",
362 | "iopub.status.busy": "2024-04-24T16:18:01.144166Z",
363 | "iopub.status.idle": "2024-04-24T16:18:01.147823Z",
364 | "shell.execute_reply": "2024-04-24T16:18:01.147172Z",
365 | "shell.execute_reply.started": "2024-04-24T16:18:01.144641Z"
366 | }
367 | },
368 | "outputs": [],
369 | "source": [
370 | "np.random.seed(0)\n",
371 | "your_emb = np.random.rand(1, 512) # replace with your own"
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": 26,
377 | "id": "f41d1b6d-2df3-4df3-a807-4cc3af2c0f4f",
378 | "metadata": {
379 | "execution": {
380 | "iopub.execute_input": "2024-04-24T16:18:01.728981Z",
381 | "iopub.status.busy": "2024-04-24T16:18:01.728377Z",
382 | "iopub.status.idle": "2024-04-24T16:18:02.332514Z",
383 | "shell.execute_reply": "2024-04-24T16:18:02.331604Z",
384 | "shell.execute_reply.started": "2024-04-24T16:18:01.728952Z"
385 | }
386 | },
387 | "outputs": [
388 | {
389 | "data": {
390 | "text/plain": [
391 | "0.10548243605868816"
392 | ]
393 | },
394 | "execution_count": 26,
395 | "metadata": {},
396 | "output_type": "execute_result"
397 | }
398 | ],
399 | "source": [
400 | "label = 'car-chase' # pick from any model in `cfg.LABELS`\n",
401 | "get_model(label=label).predict_proba(your_emb).item()"
402 | ]
403 | },
404 | {
405 | "attachments": {},
406 | "cell_type": "markdown",
407 | "id": "e51aec2e-6a65-4ed8-a46a-5f9f6748e6fd",
408 | "metadata": {},
409 | "source": [
410 | "## apply any model to any shot in the corpus from the Condensed Movie Dataset\n",
411 | "More details at [this link](https://www.robots.ox.ac.uk/~vgg/data/condensed-movies/)"
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "execution_count": 27,
417 | "id": "81762de3-c875-4f01-892b-da9c335ef4fb",
418 | "metadata": {
419 | "execution": {
420 | "iopub.execute_input": "2024-04-24T16:18:03.529245Z",
421 | "iopub.status.busy": "2024-04-24T16:18:03.528611Z",
422 | "iopub.status.idle": "2024-04-24T16:18:04.588503Z",
423 | "shell.execute_reply": "2024-04-24T16:18:04.587890Z",
424 | "shell.execute_reply.started": "2024-04-24T16:18:03.529216Z"
425 | }
426 | },
427 | "outputs": [
428 | {
429 | "data": {
430 | "text/html": [
431 | "\n",
432 | "\n",
445 | "
\n",
446 | " \n",
447 | " \n",
448 | " | \n",
449 | " yt_id | \n",
450 | " desc | \n",
451 | " genre | \n",
452 | " frame_in | \n",
453 | " frame_out | \n",
454 | " start | \n",
455 | " end | \n",
456 | " fps | \n",
457 | " cluster_index | \n",
458 | "
\n",
459 | " \n",
460 | " key | \n",
461 | " | \n",
462 | " | \n",
463 | " | \n",
464 | " | \n",
465 | " | \n",
466 | " | \n",
467 | " | \n",
468 | " | \n",
469 | " | \n",
470 | "
\n",
471 | " \n",
472 | " \n",
473 | " \n",
474 | " HnnBvemHrWA|0 | \n",
475 | " HnnBvemHrWA | \n",
476 | " Frankie and Johnny - He Just Asked Her Out | \n",
477 | " Comedy | \n",
478 | " 0 | \n",
479 | " 82 | \n",
480 | " 0.000000 | \n",
481 | " 3.420311 | \n",
482 | " 23.974427 | \n",
483 | " 8 | \n",
484 | "
\n",
485 | " \n",
486 | " gOJJm_cSRds|3987 | \n",
487 | " gOJJm_cSRds | \n",
488 | " The Wild Bunch - Bank Shootout | \n",
489 | " Action | \n",
490 | " 3987 | \n",
491 | " 4007 | \n",
492 | " 166.302203 | \n",
493 | " 167.136425 | \n",
494 | " 23.974427 | \n",
495 | " 6 | \n",
496 | "
\n",
497 | " \n",
498 | " mDViU8OSRkA|2222 | \n",
499 | " mDViU8OSRkA | \n",
500 | " Non-Stop - How's Your Daughter? Scene | \n",
501 | " NaN | \n",
502 | " 2222 | \n",
503 | " 2284 | \n",
504 | " 92.675917 | \n",
505 | " 95.261833 | \n",
506 | " 23.976024 | \n",
507 | " 3 | \n",
508 | "
\n",
509 | " \n",
510 | " egB-SG97EcI|805 | \n",
511 | " egB-SG97EcI | \n",
512 | " Kalifornia - I Think I Gotta Kill You Scene | \n",
513 | " Crime | \n",
514 | " 805 | \n",
515 | " 825 | \n",
516 | " 33.575190 | \n",
517 | " 34.409357 | \n",
518 | " 23.976037 | \n",
519 | " 0 | \n",
520 | "
\n",
521 | " \n",
522 | " uGsWYV2bWAc|1013 | \n",
523 | " uGsWYV2bWAc | \n",
524 | " Kill Bill: Vol. 1 - The Bride vs. Gogo | \n",
525 | " Action | \n",
526 | " 1013 | \n",
527 | " 1041 | \n",
528 | " 42.250542 | \n",
529 | " 43.418375 | \n",
530 | " 23.976024 | \n",
531 | " 7 | \n",
532 | "
\n",
533 | " \n",
534 | "
\n",
535 | "
"
536 | ],
537 | "text/plain": [
538 | " yt_id desc \\\n",
539 | "key \n",
540 | "HnnBvemHrWA|0 HnnBvemHrWA Frankie and Johnny - He Just Asked Her Out \n",
541 | "gOJJm_cSRds|3987 gOJJm_cSRds The Wild Bunch - Bank Shootout \n",
542 | "mDViU8OSRkA|2222 mDViU8OSRkA Non-Stop - How's Your Daughter? Scene \n",
543 | "egB-SG97EcI|805 egB-SG97EcI Kalifornia - I Think I Gotta Kill You Scene \n",
544 | "uGsWYV2bWAc|1013 uGsWYV2bWAc Kill Bill: Vol. 1 - The Bride vs. Gogo \n",
545 | "\n",
546 | " genre frame_in frame_out start end \\\n",
547 | "key \n",
548 | "HnnBvemHrWA|0 Comedy 0 82 0.000000 3.420311 \n",
549 | "gOJJm_cSRds|3987 Action 3987 4007 166.302203 167.136425 \n",
550 | "mDViU8OSRkA|2222 NaN 2222 2284 92.675917 95.261833 \n",
551 | "egB-SG97EcI|805 Crime 805 825 33.575190 34.409357 \n",
552 | "uGsWYV2bWAc|1013 Action 1013 1041 42.250542 43.418375 \n",
553 | "\n",
554 | " fps cluster_index \n",
555 | "key \n",
556 | "HnnBvemHrWA|0 23.974427 8 \n",
557 | "gOJJm_cSRds|3987 23.974427 6 \n",
558 | "mDViU8OSRkA|2222 23.976024 3 \n",
559 | "egB-SG97EcI|805 23.976037 0 \n",
560 | "uGsWYV2bWAc|1013 23.976024 7 "
561 | ]
562 | },
563 | "execution_count": 27,
564 | "metadata": {},
565 | "output_type": "execute_result"
566 | }
567 | ],
568 | "source": [
569 | "df_shots = pd.read_csv(io.PATHS_STATIC.shot_data).set_index('key')\n",
570 | "df_shots.head()"
571 | ]
572 | },
573 | {
574 | "attachments": {},
575 | "cell_type": "markdown",
576 | "id": "bb5ae078-757a-4c1d-ab21-ddbaab2c03f5",
577 | "metadata": {},
578 | "source": [
579 | "### pick a key from the dataframe above"
580 | ]
581 | },
582 | {
583 | "cell_type": "code",
584 | "execution_count": 28,
585 | "id": "b30b22db-aa12-420c-9e2d-e64c300e5aff",
586 | "metadata": {
587 | "execution": {
588 | "iopub.execute_input": "2024-04-24T16:18:05.572441Z",
589 | "iopub.status.busy": "2024-04-24T16:18:05.571805Z",
590 | "iopub.status.idle": "2024-04-24T16:18:05.575349Z",
591 | "shell.execute_reply": "2024-04-24T16:18:05.574731Z",
592 | "shell.execute_reply.started": "2024-04-24T16:18:05.572412Z"
593 | }
594 | },
595 | "outputs": [],
596 | "source": [
597 | "key = 'HnnBvemHrWA|0'"
598 | ]
599 | },
600 | {
601 | "cell_type": "code",
602 | "execution_count": 29,
603 | "id": "2c547f22-61fa-49a9-ace4-0e73d215f618",
604 | "metadata": {
605 | "execution": {
606 | "iopub.execute_input": "2024-04-24T16:18:06.156529Z",
607 | "iopub.status.busy": "2024-04-24T16:18:06.155931Z",
608 | "iopub.status.idle": "2024-04-24T16:18:06.160842Z",
609 | "shell.execute_reply": "2024-04-24T16:18:06.160237Z",
610 | "shell.execute_reply.started": "2024-04-24T16:18:06.156499Z"
611 | }
612 | },
613 | "outputs": [
614 | {
615 | "data": {
616 | "text/plain": [
617 | "0.15079263042877716"
618 | ]
619 | },
620 | "execution_count": 29,
621 | "metadata": {},
622 | "output_type": "execute_result"
623 | }
624 | ],
625 | "source": [
626 | "get_score(label='action', key=key)"
627 | ]
628 | },
629 | {
630 | "cell_type": "code",
631 | "execution_count": 30,
632 | "id": "c71691d3-3421-4b74-ab1e-e37da32d2528",
633 | "metadata": {
634 | "execution": {
635 | "iopub.execute_input": "2024-04-24T16:18:06.725140Z",
636 | "iopub.status.busy": "2024-04-24T16:18:06.724526Z",
637 | "iopub.status.idle": "2024-04-24T16:18:07.366689Z",
638 | "shell.execute_reply": "2024-04-24T16:18:07.365931Z",
639 | "shell.execute_reply.started": "2024-04-24T16:18:06.725110Z"
640 | }
641 | },
642 | "outputs": [
643 | {
644 | "data": {
645 | "text/plain": [
646 | "0.8293545496789174"
647 | ]
648 | },
649 | "execution_count": 30,
650 | "metadata": {},
651 | "output_type": "execute_result"
652 | }
653 | ],
654 | "source": [
655 | "get_score(label='closeup', key=key)"
656 | ]
657 | },
658 | {
659 | "cell_type": "code",
660 | "execution_count": 31,
661 | "id": "2ddb686b-4cd7-44cd-83c4-4a87d9f279b4",
662 | "metadata": {
663 | "execution": {
664 | "iopub.execute_input": "2024-04-24T16:18:07.368947Z",
665 | "iopub.status.busy": "2024-04-24T16:18:07.368171Z",
666 | "iopub.status.idle": "2024-04-24T16:18:07.375109Z",
667 | "shell.execute_reply": "2024-04-24T16:18:07.374429Z",
668 | "shell.execute_reply.started": "2024-04-24T16:18:07.368902Z"
669 | }
670 | },
671 | "outputs": [
672 | {
673 | "data": {
674 | "text/plain": [
675 | "0.02759443746631655"
676 | ]
677 | },
678 | "execution_count": 31,
679 | "metadata": {},
680 | "output_type": "execute_result"
681 | }
682 | ],
683 | "source": [
684 | "get_score(label='establishing-shots', key=key)"
685 | ]
686 | },
687 | {
688 | "attachments": {},
689 | "cell_type": "markdown",
690 | "id": "30b13fbd-9f36-4d4c-b5ff-d91eec5f53c3",
691 | "metadata": {},
692 | "source": [
693 | "## apply all models to a shot in the corpus"
694 | ]
695 | },
696 | {
697 | "cell_type": "code",
698 | "execution_count": 32,
699 | "id": "3d9996b5-c655-4fd0-8454-d69489518fbe",
700 | "metadata": {
701 | "execution": {
702 | "iopub.execute_input": "2024-04-24T16:18:08.473189Z",
703 | "iopub.status.busy": "2024-04-24T16:18:08.472662Z",
704 | "iopub.status.idle": "2024-04-24T16:18:08.476199Z",
705 | "shell.execute_reply": "2024-04-24T16:18:08.475710Z",
706 | "shell.execute_reply.started": "2024-04-24T16:18:08.473160Z"
707 | }
708 | },
709 | "outputs": [],
710 | "source": [
711 | "def get_link(key: str) -> HTML:\n",
712 | " rec = df_shots.loc[key]\n",
713 | " link = f'https://youtu.be/{rec.yt_id}?t={int(rec.start)}'\n",
714 | " return HTML(f'click here to watch')"
715 | ]
716 | },
717 | {
718 | "cell_type": "code",
719 | "execution_count": 33,
720 | "id": "1593eec5-fbd5-4eaa-817e-f29b109df8aa",
721 | "metadata": {
722 | "execution": {
723 | "iopub.execute_input": "2024-04-24T16:18:09.089381Z",
724 | "iopub.status.busy": "2024-04-24T16:18:09.088766Z",
725 | "iopub.status.idle": "2024-04-24T16:18:25.917844Z",
726 | "shell.execute_reply": "2024-04-24T16:18:25.917005Z",
727 | "shell.execute_reply.started": "2024-04-24T16:18:09.089350Z"
728 | }
729 | },
730 | "outputs": [
731 | {
732 | "data": {
733 | "text/plain": [
734 | "eye-level 0.89\n",
735 | "interior 0.83\n",
736 | "closeup 0.83\n",
737 | "character-focus 0.75\n",
738 | "single-shot 0.65\n",
739 | "static-shot 0.55\n",
740 | "medium 0.44\n",
741 | "fantasy 0.39\n",
742 | "object 0.39\n",
743 | "low-angle 0.37\n",
744 | "drama 0.34\n",
745 | "violence 0.33\n",
746 | "horror 0.31\n",
747 | "over-the-shoulder-shot 0.31\n",
748 | "day 0.27\n",
749 | "scared 0.22\n",
750 | "smoking 0.22\n",
751 | "anger 0.20\n",
752 | "happy 0.17\n",
753 | "handheld 0.16\n",
754 | "animal 0.15\n",
755 | "action 0.15\n",
756 | "sad 0.12\n",
757 | "tilt-shot 0.11\n",
758 | "pan 0.10\n",
759 | "insert-shot 0.09\n",
760 | "cowboy-shot 0.08\n",
761 | "alcohol 0.07\n",
762 | "zoom 0.07\n",
763 | "intimacy 0.06\n",
764 | "dutch-angle 0.06\n",
765 | "slow-motion 0.05\n",
766 | "run 0.05\n",
767 | "wide 0.05\n",
768 | "fight 0.05\n",
769 | "high-angle 0.05\n",
770 | "two-shot 0.04\n",
771 | "group-shot 0.04\n",
772 | "aerial 0.03\n",
773 | "establishing-shots 0.03\n",
774 | "point-of-view-shot 0.03\n",
775 | "extreme-close-up 0.02\n",
776 | "drugs 0.02\n",
777 | "golden-hour 0.02\n",
778 | "sci-fi 0.02\n",
779 | "romance 0.02\n",
780 | "gore 0.01\n",
781 | "laughter 0.01\n",
782 | "nudity 0.01\n",
783 | "shutter-shot 0.01\n",
784 | "extreme-wide-shot 0.01\n",
785 | "car-chase 0.01\n",
786 | "jump-scare 0.01\n",
787 | "overhead-shot 0.01\n",
788 | "interview 0.01\n",
789 | "timelapse 0.00\n",
790 | "dtype: float64"
791 | ]
792 | },
793 | "execution_count": 33,
794 | "metadata": {},
795 | "output_type": "execute_result"
796 | }
797 | ],
798 | "source": [
799 | "# sort by score descending\n",
800 | "pd.Series(get_score_all_labels(key=key)).sort_values(ascending=False).round(2)"
801 | ]
802 | },
803 | {
804 | "cell_type": "code",
805 | "execution_count": 34,
806 | "id": "8796347f-834c-4062-95e4-a3f556ac9fd1",
807 | "metadata": {
808 | "execution": {
809 | "iopub.execute_input": "2024-04-24T16:18:25.923165Z",
810 | "iopub.status.busy": "2024-04-24T16:18:25.922349Z",
811 | "iopub.status.idle": "2024-04-24T16:18:26.024421Z",
812 | "shell.execute_reply": "2024-04-24T16:18:26.023892Z",
813 | "shell.execute_reply.started": "2024-04-24T16:18:25.923124Z"
814 | }
815 | },
816 | "outputs": [
817 | {
818 | "data": {
819 | "text/html": [
820 | "click here to watch"
821 | ],
822 | "text/plain": [
823 | ""
824 | ]
825 | },
826 | "execution_count": 34,
827 | "metadata": {},
828 | "output_type": "execute_result"
829 | }
830 | ],
831 | "source": [
832 | "# get the link to watch\n",
833 | "get_link(key=key)"
834 | ]
835 | }
836 | ],
837 | "metadata": {
838 | "kernelspec": {
839 | "display_name": "Python 3.10 (recommended)",
840 | "language": "python",
841 | "name": "python310"
842 | },
843 | "language_info": {
844 | "codemirror_mode": {
845 | "name": "ipython",
846 | "version": 3
847 | },
848 | "file_extension": ".py",
849 | "mimetype": "text/x-python",
850 | "name": "python",
851 | "nbconvert_exporter": "python",
852 | "pygments_lexer": "ipython3",
853 | "version": "3.10.14"
854 | }
855 | },
856 | "nbformat": 4,
857 | "nbformat_minor": 5
858 | }
859 |
--------------------------------------------------------------------------------
/exp2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "id": "33a23872-6b85-40c8-baa8-f2a42702b630",
7 | "metadata": {},
8 | "source": [
9 | "# Experiment 2\n",
10 | "Need to run experiment 1 first"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "id": "dbe1fb03",
17 | "metadata": {
18 | "execution": {
19 | "iopub.execute_input": "2024-02-09T06:58:20.320622Z",
20 | "iopub.status.busy": "2024-02-09T06:58:20.320375Z",
21 | "iopub.status.idle": "2024-02-09T06:58:20.721700Z",
22 | "shell.execute_reply": "2024-02-09T06:58:20.720953Z",
23 | "shell.execute_reply.started": "2024-02-09T06:58:20.320598Z"
24 | },
25 | "tags": []
26 | },
27 | "outputs": [],
28 | "source": [
29 | "import pathlib\n",
30 | "import functools\n",
31 | "import pandas as pd\n",
32 | "import numpy as np\n",
33 | "import math\n",
34 | "import pickle\n",
35 | "\n",
36 | "from tqdm.auto import tqdm\n",
37 | "\n",
38 | "from videoannotator import data, dp, experiments"
39 | ]
40 | },
41 | {
42 | "attachments": {},
43 | "cell_type": "markdown",
44 | "id": "243bc89c-e52c-4c40-a086-9fa790de4f39",
45 | "metadata": {},
46 | "source": [
47 | "## helpers"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 2,
53 | "id": "fb6ab183-2d5e-4b71-9475-ceb7bbf03cbb",
54 | "metadata": {
55 | "execution": {
56 | "iopub.execute_input": "2024-02-09T06:58:20.724968Z",
57 | "iopub.status.busy": "2024-02-09T06:58:20.724526Z",
58 | "iopub.status.idle": "2024-02-09T06:58:20.734467Z",
59 | "shell.execute_reply": "2024-02-09T06:58:20.733913Z",
60 | "shell.execute_reply.started": "2024-02-09T06:58:20.724943Z"
61 | },
62 | "tags": []
63 | },
64 | "outputs": [],
65 | "source": [
66 | "dps = {\n",
67 | " 'eps0.1': functools.partial(dp.EpsilonGreedyMean, epsilon=0.1),\n",
68 | " 'eps0.25': functools.partial(dp.EpsilonGreedyMean, epsilon=0.25),\n",
69 | " 'eps0.5': functools.partial(dp.EpsilonGreedyMean, epsilon=0.5),\n",
70 | " 'rr': dp.RoundRobin,\n",
71 | " 'greedy': dp.GreedyOracle,\n",
72 | " 'ucb': dp.UCBMean,\n",
73 | " 'ucb.1': functools.partial(dp.UCBMean, c=1e-1),\n",
74 | " # 'ecbwema': dp.UCBEWMA,\n",
75 | " # 'eps0.25ewma': functools.partial(dp.EpsilonGreedyEWMA, epsilon=0.25),\n",
76 | "}\n",
77 | "\n",
78 | "@functools.lru_cache(10_000)\n",
79 | "def run(dp_name, d, label):\n",
80 | " dp = dps[dp_name]\n",
81 | " try:\n",
82 | " return dict(\n",
83 | " dp=dp,\n",
84 | " d=d,\n",
85 | " label=label,\n",
86 | " res=dp(label=label, d=d, _seed_base=1).run(n=1_000 // d),\n",
87 | " )\n",
88 | " except Exception as e:\n",
89 | " raise ValueError(f'Error in dp={dp}, d={d}, label={label}: {e}')\n",
90 | " \n",
91 | "def get_best(r):\n",
92 | " res = r['res']\n",
93 | " return max(\n",
94 | " (\n",
95 | " (res.average_precision.mean, res.average_precision)\n",
96 | " for _, res in res.res_hist\n",
97 | " ),\n",
98 | " key=lambda x: x[0],\n",
99 | " )[1]\n",
100 | "\n",
101 | "def get_algo_detes(r):\n",
102 | " dp = r['dp']\n",
103 | " if dp.__class__ == functools.partial:\n",
104 | " return dict(\n",
105 | " algo=dp.func.__name__,\n",
106 | " kws=dp.keywords,\n",
107 | " args=dp.args,\n",
108 | " )\n",
109 | " else:\n",
110 | " return dict(algo=dp.__name__)\n",
111 | " \n",
112 | "def eps(q):\n",
113 | " if pd.notnull(q) and 'epsilon' in q:\n",
114 | " return q.get('epsilon') == .25\n",
115 | " \n",
116 | " return True"
117 | ]
118 | },
119 | {
120 | "attachments": {},
121 | "cell_type": "markdown",
122 | "id": "35983077-22a6-4045-87f4-1ed162ddb21a",
123 | "metadata": {},
124 | "source": [
125 | "## run experiment 2"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": 3,
131 | "id": "a8e9e3e5-e6be-44a2-9144-6abb528b3128",
132 | "metadata": {
133 | "execution": {
134 | "iopub.execute_input": "2024-02-09T06:58:20.737344Z",
135 | "iopub.status.busy": "2024-02-09T06:58:20.736926Z",
136 | "iopub.status.idle": "2024-02-09T06:58:25.203956Z",
137 | "shell.execute_reply": "2024-02-09T06:58:25.203228Z",
138 | "shell.execute_reply.started": "2024-02-09T06:58:20.737321Z"
139 | },
140 | "tags": []
141 | },
142 | "outputs": [],
143 | "source": [
144 | "# this takes ~20 mins on a machine with 64 CPUs and 64GB RAM\n",
145 | "path_results_pkl = 'exp2_results.pkl'\n",
146 | "if not pathlib.Path(path_results_pkl).exists():\n",
147 | " # takes a few hours\n",
148 | " items = [\n",
149 | " (dp_name, d, label)\n",
150 | " for i, dp_name in enumerate(dps)\n",
151 | " for d in (25,)\n",
152 | " for label in sorted(data.cfg.LABELS)\n",
153 | " ]\n",
154 | " res = [run(*q) for q in tqdm(items)]\n",
155 | " pickle.dump(res, open(path_results_pkl, 'wb'))\n",
156 | "else:\n",
157 | " res = pickle.load(open(path_results_pkl, 'rb'))"
158 | ]
159 | },
160 | {
161 | "attachments": {},
162 | "cell_type": "markdown",
163 | "id": "783eb716-826e-449b-9549-ccbf637b2cc7",
164 | "metadata": {},
165 | "source": [
166 | "## analyze"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": 4,
172 | "id": "137666f0",
173 | "metadata": {
174 | "execution": {
175 | "iopub.execute_input": "2024-02-09T06:58:25.211433Z",
176 | "iopub.status.busy": "2024-02-09T06:58:25.211089Z",
177 | "iopub.status.idle": "2024-02-09T06:58:26.310674Z",
178 | "shell.execute_reply": "2024-02-09T06:58:26.309992Z",
179 | "shell.execute_reply.started": "2024-02-09T06:58:25.211408Z"
180 | },
181 | "tags": []
182 | },
183 | "outputs": [],
184 | "source": [
185 | "# this object is computed in experiment 1\n",
186 | "ap_va = pickle.load(open('ap_va_best.pkl', 'rb'))\n",
187 | "ap_va_best = {k: v['best'] for k, v in ap_va.items()}\n",
188 | "ap_va_last = {k: v['last'] for k, v in ap_va.items()}\n",
189 | "\n",
190 | "df_res = pd.DataFrame(\n",
191 | " dict(\n",
192 | " d=r['d'],\n",
193 | " label=r['label'],\n",
194 | " best_idx=r['res'].best_idx,\n",
195 | " best_metric=get_best(r),\n",
196 | " best_metric_mean=get_best(r).mean,\n",
197 | " last_metric=r['res'].last_metric,\n",
198 | " last_metric_mean=r['res'].last_metric.mean,\n",
199 | " res=r['res'],\n",
200 | " after_5=r['res'].res_hist[5][1].average_precision.mean,\n",
201 | " after_10=r['res'].res_hist[10][1].average_precision.mean,\n",
202 | " **get_algo_detes(r),\n",
203 | " )\n",
204 | " for r in res\n",
205 | ")\n",
206 | "df_res = df_res.assign(\n",
207 | " ap_va_best=df_res.label.map(ap_va_best),\n",
208 | " ap_va_last=df_res.label.map(ap_va_last),\n",
209 | ")\n",
210 | "d = df_res[df_res.algo == 'UCBMean'].groupby(['d', 'label']).apply(lambda x: pd.Series(dict(\n",
211 | " best_va=x.ap_va_best.max(),\n",
212 | " best_bandit=x.best_metric_mean.max(),\n",
213 | " last_va=x.ap_va_last.max(),\n",
214 | " last_bandit=x.last_metric_mean.max(),\n",
215 | ")))\n",
216 | "d = d.assign(\n",
217 | " lift_best=d.best_bandit - d.best_va,\n",
218 | " lift_max=d.last_bandit - d.last_va,\n",
219 | ")\n",
220 | "dfr = df_res[df_res.kws.apply(eps)]\n",
221 | "e = dfr[dfr.d == 25].groupby(['algo', 'label']).apply(lambda x: pd.Series(dict(\n",
222 | " best_va=x.ap_va_best.max(),\n",
223 | " best_bandit=x.best_metric_mean.max(),\n",
224 | ")))\n",
225 | "e = e.assign(lift_best=e.best_bandit - e.best_va,)\n",
226 | "algos = ['RoundRobin', 'GreedyOracle', 'EpsilonGreedyMean', 'UCBMean']\n",
227 | "cols = ['min', '25%', '50%', '75%', 'max']"
228 | ]
229 | },
230 | {
231 | "attachments": {},
232 | "cell_type": "markdown",
233 | "id": "854e5e4e-0fd0-40db-91e6-e6b766341f87",
234 | "metadata": {},
235 | "source": [
236 | "## table 3"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": 5,
242 | "id": "fa91f320",
243 | "metadata": {
244 | "execution": {
245 | "iopub.execute_input": "2024-02-09T06:58:26.318997Z",
246 | "iopub.status.busy": "2024-02-09T06:58:26.318663Z",
247 | "iopub.status.idle": "2024-02-09T06:58:26.343227Z",
248 | "shell.execute_reply": "2024-02-09T06:58:26.342682Z",
249 | "shell.execute_reply.started": "2024-02-09T06:58:26.318973Z"
250 | },
251 | "tags": []
252 | },
253 | "outputs": [
254 | {
255 | "data": {
256 | "text/html": [
257 | "\n",
258 | "\n",
271 | "
\n",
272 | " \n",
273 | " \n",
274 | " | \n",
275 | " p10 | \n",
276 | " p25 | \n",
277 | " p50 | \n",
278 | " p75 | \n",
279 | " p90 | \n",
280 | "
\n",
281 | " \n",
282 | " algo | \n",
283 | " | \n",
284 | " | \n",
285 | " | \n",
286 | " | \n",
287 | " | \n",
288 | "
\n",
289 | " \n",
290 | " \n",
291 | " \n",
292 | " RoundRobin | \n",
293 | " -5.9 | \n",
294 | " -0.3 | \n",
295 | " 2.5 | \n",
296 | " 7.4 | \n",
297 | " 15.1 | \n",
298 | "
\n",
299 | " \n",
300 | " GreedyOracle | \n",
301 | " -5.4 | \n",
302 | " 0.9 | \n",
303 | " 3.9 | \n",
304 | " 9.6 | \n",
305 | " 16.2 | \n",
306 | "
\n",
307 | " \n",
308 | " EpsilonGreedyMean | \n",
309 | " -6.5 | \n",
310 | " -0.6 | \n",
311 | " 2.9 | \n",
312 | " 7.6 | \n",
313 | " 16.7 | \n",
314 | "
\n",
315 | " \n",
316 | " UCBMean | \n",
317 | " -7.5 | \n",
318 | " -0.2 | \n",
319 | " 3.4 | \n",
320 | " 9.2 | \n",
321 | " 15.9 | \n",
322 | "
\n",
323 | " \n",
324 | "
\n",
325 | "
"
326 | ],
327 | "text/plain": [
328 | " p10 p25 p50 p75 p90\n",
329 | "algo \n",
330 | "RoundRobin -5.9 -0.3 2.5 7.4 15.1\n",
331 | "GreedyOracle -5.4 0.9 3.9 9.6 16.2\n",
332 | "EpsilonGreedyMean -6.5 -0.6 2.9 7.6 16.7\n",
333 | "UCBMean -7.5 -0.2 3.4 9.2 15.9"
334 | ]
335 | },
336 | "execution_count": 5,
337 | "metadata": {},
338 | "output_type": "execute_result"
339 | }
340 | ],
341 | "source": [
342 | "ee = e.reset_index().groupby('algo').lift_best.agg(\n",
343 | " p10=functools.partial(np.percentile, q=10),\n",
344 | " p25=functools.partial(np.percentile, q=25),\n",
345 | " p50=functools.partial(np.percentile, q=50),\n",
346 | " p75=functools.partial(np.percentile, q=75),\n",
347 | " p90=functools.partial(np.percentile, q=90)\n",
348 | ").loc[algos]\n",
349 | "(ee * 100).round(1)"
350 | ]
351 | },
352 | {
353 | "attachments": {},
354 | "cell_type": "markdown",
355 | "id": "c0560a99-4f29-4de5-8d42-3b87b632fbd2",
356 | "metadata": {},
357 | "source": [
358 | "## cumulative results\n",
359 | "Add up the gain from experiment 1 to this gain and compute the median across labels."
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "execution_count": 6,
365 | "id": "6eac9359-e97c-42c0-9176-e716f5c4246f",
366 | "metadata": {
367 | "execution": {
368 | "iopub.execute_input": "2024-02-09T06:58:26.344418Z",
369 | "iopub.status.busy": "2024-02-09T06:58:26.343981Z",
370 | "iopub.status.idle": "2024-02-09T06:58:26.549934Z",
371 | "shell.execute_reply": "2024-02-09T06:58:26.549376Z",
372 | "shell.execute_reply.started": "2024-02-09T06:58:26.344394Z"
373 | },
374 | "tags": []
375 | },
376 | "outputs": [
377 | {
378 | "data": {
379 | "text/plain": [
380 | "8.3"
381 | ]
382 | },
383 | "execution_count": 6,
384 | "metadata": {},
385 | "output_type": "execute_result"
386 | }
387 | ],
388 | "source": [
389 | "e1 = pickle.load(open('exp1_gain.pkl', 'rb'))\n",
390 | "f = e.reset_index()\n",
391 | "e2 = f[f.algo == 'UCBMean'].set_index('label').lift_best.to_dict()\n",
392 | "comb = pd.Series({\n",
393 | " k: e1[k] + e2[k]\n",
394 | " for k in e1\n",
395 | " if pd.notnull(e1[k]) and pd.notnull(e2[k])\n",
396 | "})\n",
397 | "(np.percentile(comb, q=50) * 100).round(1)"
398 | ]
399 | }
400 | ],
401 | "metadata": {
402 | "hide_input": false,
403 | "kernelspec": {
404 | "display_name": "Python 3 (ipykernel)",
405 | "language": "python",
406 | "name": "python3"
407 | },
408 | "language_info": {
409 | "codemirror_mode": {
410 | "name": "ipython",
411 | "version": 3
412 | },
413 | "file_extension": ".py",
414 | "mimetype": "text/x-python",
415 | "name": "python",
416 | "nbconvert_exporter": "python",
417 | "pygments_lexer": "ipython3",
418 | "version": "3.7.12"
419 | }
420 | },
421 | "nbformat": 4,
422 | "nbformat_minor": 5
423 | }
424 |
--------------------------------------------------------------------------------
/images/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Netflix/videoannotator/526d38dfa4ffcb953423d1f1f583d695904c90a2/images/6.png
--------------------------------------------------------------------------------
/images/leave.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Netflix/videoannotator/526d38dfa4ffcb953423d1f1f583d695904c90a2/images/leave.png
--------------------------------------------------------------------------------
/images/varsity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Netflix/videoannotator/526d38dfa4ffcb953423d1f1f583d695904c90a2/images/varsity.png
--------------------------------------------------------------------------------
/sys.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Netflix/videoannotator/526d38dfa4ffcb953423d1f1f583d695904c90a2/sys.png
--------------------------------------------------------------------------------
/task_descriptions.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": {
3 | "description": "Climactic moments or exciting, exhilarating or fast-paced shots or sequences",
4 | "keywords": "Explosions, Car Chase, Fighting, Running, Shootout"
5 | },
6 | "intimacy": {
7 | "description": "Two characters being physically close to each other and showing affection",
8 | "keywords": "Hugging, Kissing, Caressing, Cuddling"
9 | },
10 | "sci-fi": {
11 | "description": "Imaginary concepts that blend science and fantasy",
12 | "keywords": "Space, Aliens, Monsters, Robots, Futuristic Technology, Lasers"
13 | },
14 | "character-focus": {
15 | "description": "Characters are clearly in focus and properly displayed",
16 | "keywords": "Profile shots, Facial shots"
17 | },
18 | "establishing-shots": {
19 | "description": "Extremely wide shot that sets the location or time period",
20 | "keywords": "Landscapes, Buildings"
21 | },
22 | "nudity": {
23 | "description": "Exposed flesh or socially coveted body parts",
24 | "keywords": "genitals, breast, buttocks"
25 | },
26 | "gore": {
27 | "description": "Extreme violence against another person",
28 | "keywords": "Blood Spewing, Decapitation, Impact Shots, Stabbing"
29 | },
30 | "smoking": {
31 | "description": "Characters smoking tobacco or drugs",
32 | "keywords": "Cigarette, Cigars, Vapes, Drugs"
33 | },
34 | "happy": {
35 | "description": "Characters expressing happiness",
36 | "keywords": "Smiling, Joyful, Ecstatic"
37 | },
38 | "sad": {
39 | "description": "Characters expressing sadness",
40 | "keywords": "Crying, Tearing, Frowning, Looking Down, Despair"
41 | },
42 | "horror": {
43 | "description": "Content aimed at eliciting fear, shock or disgust. This applies to story, characters, mood and atmosphere. Elements that suggest that the character is in danger or fearing for their life. Subgenres include psychological horror, supernatural horror, gore, slashers",
44 | "keywords": "Death, Gore, Blood, Violence, Devil, Demons, Paranormal, Monsters, Cults, Serial Killers"
45 | },
46 | "alcohol": {
47 | "description": "Alcohol consumption is when a character drinks an alcoholic beverage that, depending on the quantity consumed, can cause intoxication or feelings of euphoria, inhibition of judgment, reduced anxiety, and/or sociability. The consumption of alcohol does not have to be shown for this tag to apply, but it should be clearly established that alcohol is the cause of actions/ behaviour depicted",
48 | "keywords": "Alcohol bottles, Beer, Shots, Characters Inebriated, Being Drunk"
49 | },
50 | "anger": {
51 | "description": "Character expressing hostility, discontent, annoyance. Anger can be expressed in a range of low, medium, high; passive aggression, open aggression and assertive aggression.",
52 | "keywords": "Yelling, Shouting, Rage, Tense, Displeased, Enacting Violence"
53 | },
54 | "aerial": {
55 | "description": "A shot captured from an elevated vantage point, taken above the subject.",
56 | "keywords": "Drone Shots, Helicopter shots, Bird’s Eye View, Overhead, Camera pointing down"
57 | },
58 | "eye-level": {
59 | "description": "When the level of the camera is placed at the same height as the eyes of the characters in your frame. An eye level camera angle does not require the viewer to see the eyes of the actor, nor does the actor need to look directly into the camera for a shot to be considered eye level.",
60 | "keywords": "Character Facing, Head on shots, Looking Straight Ahead, Medium Close up"
61 | },
62 | "medium": {
63 | "description": "Shots depict the subject approximately from the waist up emphasizing both the subject and their surrounding. ",
64 | "keywords": "Interview, Table Talk, Conversation"
65 | },
66 | "closeup": {
67 | "description": "Shots taken at a close range intended to show greater detail.",
68 | "keywords": "Face, Hands, Objects, Details"
69 | },
70 | "wide": {
71 | "description": "Wide shots, also known as long shot, show the entire subject and their relation to the surrounding environment.",
72 | "keywords": "Establishing shots, City Scapes, Landscape, Field, Nature"
73 | },
74 | "zoom": {
75 | "description": "A zoom shot is when the focal length of a camera lens is adjusted to give the illusion of moving closer or further away from the subject. Zoom shots are done with a zoom lens, which have variable focal lengths. The most typical camera zoom types are the \"zoom in\" and \"zoom out.\" But this technique can also be combined with a dolly shot to create a dolly zoom.",
76 | "keywords": "Zoom In / out, Push In / Out, Dolly In / Out, Hitchcock Zoom"
77 | },
78 | "pan": {
79 | "description": "A camera pan is horizontal movement in which the camera moves right to left from a central axis. Taken from the word \"panorama,\" the physical position of the camera itself does not move. Instead, the panning refers to the swiveling left to right mounted on a fixed location. This differs from dolly shots in which the mounting system itself moves. Of course, a pan can also be combined with different camera movements such as dolly shots or tracking shots.",
80 | "keywords": "Reveal, Turn, Tracking"
81 | },
82 | "drama": {
83 | "description": "Titles with a serious, sober gravity of overall tone or titles that attempt to present a story that could conceivably happen in real life. A moment of high emotion.",
84 | "keywords": "Emotional, Conflict, Serious, Life, Relationships, Reflection, Realization"
85 | },
86 | "jump-scare": {
87 | "description": "A technique often used in horror films, intended to scare the audience by surprising them with an abrupt change in image or event usually co-occurring with a loud, jarring sound.",
88 | "keywords": "Screech, Bang, Scream, Shock, Horror"
89 | },
90 | "car-chase": {
91 | "description": "An exciting action sequence where characters are pursuing each other in a moving vehicle. These scenes usually involve reckless driving, daring stunts, and tension.",
92 | "keywords": "Car Chase, Speeding, Driving, Car Crash, Swerving"
93 | },
94 | "violence": {
95 | "description": "Violence is any physical act initiated by a perpetrator to harm or damage oneself, another living creature, a personified object or character, or an inanimate object.",
96 | "keywords": "Hitting, Kicking, Throwing, Shooting, Stabbing, Slashing, Blood, Gore, Death, Harm"
97 | },
98 | "drugs": {
99 | "description": "Drugs that are illegal or legal substances that are used specifically for the intent of altering one’s state of consciousness or physical feeling, becoming intoxicated, or detaching from reality.",
100 | "keywords": "Drugs, Pills, High, Intoxicated, Snort, Swallow, Marijuana, Cocaine, Ecstasy, Needles"
101 | },
102 | "laughter": {
103 | "description": "Characters laughing. Laughter can apply to a range of emotions. Joyous laugh, nervous laugh, ominous laugh.",
104 | "keywords": "Laughing, Chuckling, Giggling, Ha Ha"
105 | },
106 | "animal": {
107 | "description": "A mammal, fish, bird, reptile or amphibian",
108 | "keywords": "Dog, Cat, Elephant, Pigeon, Iguana, Turtle, Bear"
109 | },
110 | "object": {
111 | "description": "A material thing that can be seen or touched. Items that are being used by characters or are the focus of the shot",
112 | "keywords": "Briefcase, Lamp, Phone, Money, Cups"
113 | },
114 | "interior": {
115 | "description": "Situated inside something enclosed",
116 | "keywords": "Room, Office, Bedroom, Bathroom, Livingroom, Car, Bus"
117 | },
118 | "day": {
119 | "description": "Settings where there is daylight. Natural light is scene in the setting",
120 | "keywords": "Daylight, Day time, Sun, Afternoon"
121 | },
122 | "overhead-shot": {
123 | "description": "When the camera is placed directly above the subject. It's somewhere around a 90-degree angle above the scene taking place. Overhead shots are also called a bird view, bird's eye view, or elevated shot. Its placement is somewhat near the action, differing from the aerial shot. An aerial shot would be closer to a satellite or airplane view, closer to a God's eye view, and an overhead shot is closer to a bird's eye view.",
124 | "keywords": "Overhead, Bird’s Eye, God’s View, Above, Looking down, Floor"
125 | },
126 | "low-angle": {
127 | "description": "A shot from a camera angle positioned anywhere below the eye line, pointing upward. It can even be below the character’s feet",
128 | "keywords": "Low, Looking Up, Sky, High Ground"
129 | },
130 | "single-shot": {
131 | "description": "One subject or character in focus",
132 | "keywords": "Character, Profile, Hero"
133 | },
134 | "group-shot": {
135 | "description": "Group of three or more characters within the focus of the frame",
136 | "keywords": "Group, Team, Friends, Gathering"
137 | },
138 | "over-the-shoulder-shot": {
139 | "description": "Shows the main subject from behind the shoulder of another subject",
140 | "keywords": "Conversation, Approach, Shoulder"
141 | },
142 | "point-of-view-shot": {
143 | "description": "Seeing from a character’s point of view",
144 | "keywords": "POV, Watching, Seeing, Looking"
145 | },
146 | "handheld": {
147 | "description": "A shot is taken with the camera being supported only by the operator’s hands and shoulder.",
148 | "keywords": "Camera Footage, Found Footage, Shaky, Unstable Movement, Tense"
149 | },
150 | "slow-motion": {
151 | "description": "The motion of is slower than normal",
152 | "keywords": "Slow motion, Freeze "
153 | },
154 | "timelapse": {
155 | "description": "A technique to show a long period of time shown quickly. Usually to capture a progression of action that is normally moving slow. Timelapse is a form of fast motion.",
156 | "keywords": "Fast Forward, Time Span, Fast Movement"
157 | },
158 | "interview": {
159 | "description": "Subjects breaking the fourth wall and speaking directly to the camera, audience or interviewer.",
160 | "keywords": "Interview, Conversation, Documentary, Explanation"
161 | },
162 | "romance": {
163 | "description": "Content focused on love, romantic relationships, dating or affairs of the heart that center on falling in love, passion, emotion and the affectionate involvement of the main characters.",
164 | "keywords": "Looks of Affection, Embrace, Kissing, Hugging, Holding Hands"
165 | },
166 | "fantasy": {
167 | "description": "Titles that contain imaginary and often magical elements which are far removed from reality: wizards, witches, dragons, vampires, elves, trolls, fairies, etc. It may be set in a fictional past of swords, horses, knights and princesses or a grounded reality. Do not confuse this genre with Sci-Fi.",
168 | "keywords": "Magic, Creatures, Spectacle, Mystical"
169 | },
170 | "fight": {
171 | "description": "Characters physically battling each other using weapons or their body",
172 | "keywords": "Combat, Martial Arts, Punch, Kick, Hit, Weapons"
173 | },
174 | "run": {
175 | "description": "Characters running",
176 | "keywords": "Run, Sprint, Speed, Fast"
177 | },
178 | "scared": {
179 | "description": "Characters in fear or frightened. This can be in varying degrees.",
180 | "keywords": "Fear, Frightened, Panicked, Screaming"
181 | },
182 | "golden-hour": {
183 | "description": "The hour as the sun rises or as the sun sets",
184 | "keywords": "Sunrise, Sunset, Dawn"
185 | },
186 | "extreme-close-up": {
187 | "description": "An extreme close-up shot frames a subject very closely, often so much so that the outer portions of the subject are cut off by the edges of the frame.",
188 | "keywords": "Eyes, Details, Zoom"
189 | },
190 | "extreme-wide-shot": {
191 | "description": "A wide shot that covers a wider area. The shot frames the subject from a distance and focuses on its surroundings. Remember: In a wide shot, the character is usually still noticeable in the frame. But in an ELS, it’s not always necessary to see the character. The EWS functions best as an establishing shot, as more of the location is in the frame.",
192 | "keywords": "Far away, Distant, Extreme Long Shot"
193 | },
194 | "cowboy-shot": {
195 | "description": "A shot framed from the actor’s mid-waist to right above their head. Many times the camera is placed at hip level as well, to give a slightly low angle to the shot. The idea of a cowboy shot is to present the actor in a heroic, confident fashion while also allowing the viewer to see action taking place around the actor’s waistline, often to see guns or swords drawn.",
196 | "keywords": "Medium Full Shot, Waist, Hip"
197 | },
198 | "high-angle": {
199 | "description": "A high angle shot is when the camera looks down on the character or subject from an elevated perspective. This is usually achieved by placing the camera higher than the subject and then angling it down on them. This can range from a shallow angle just above eye level all the way to directly above the subject.",
200 | "keywords": "Looking Down, Edge, Cliff, Hanging, Characters Looking Up"
201 | },
202 | "dutch-angle": {
203 | "description": "A type of camera shot that has a noticeable tilt on the camera’s “x-axis.” Directors often use a Dutch angle to signal to the viewer that something is wrong, disorienting, or unsettling.",
204 | "keywords": "Tilt, Disoriented, Strange, Unsettling"
205 | },
206 | "two-shot": {
207 | "description": "A two shot is a shot in which the frames a view of the subjects. The subjects do not have to be next to each other. There are many which have one subject in the foreground and the other subject in the background. The shot is used to show the emotional reactions between the subjects.",
208 | "keywords": "Two, Pair, Couple, Together"
209 | },
210 | "insert-shot": {
211 | "description": "An insert shot is any shot that’s sole purpose is to focus the viewer’s attention to a specific detail within a scene. These shots are often placed from the point of view of a character, suggesting that we see what the character sees during a moment of concentration.",
212 | "keywords": "Objects, Close-Up, B-Roll"
213 | },
214 | "tilt-shot": {
215 | "description": "A tilt shot is shot that uses a cinematographic technique called tilting in which the camera stays fixed but rotates up and down on a vertical plane. Tilting is similar to the motion of an individual raising or lowering their heads. This is different than the panning of a camera which also sits on a fixed but moves left to right.",
216 | "keywords": "Looking Up, Looking Down"
217 | },
218 | "static-shot": {
219 | "description": "A static shot in film is a shot that is devoid of camera movement. Also known as a locked-off shot, or an immobile shot. The frame can be filled with the movement of vehicles, characters, props, weather, etc, but the frame itself does not move in a static shot.",
220 | "keywords": "Frame, Still"
221 | },
222 | "shutter-shot": {
223 | "description": "The asset contains duplicate frames, stretching the running time of the sequence and creating a sense of slow motion for the audience.",
224 | "keywords": "slow motion"
225 | }
226 | }
227 |
--------------------------------------------------------------------------------
/tests/test_data.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from videoannotator import data
4 |
5 |
6 | class TestLabeledDataset(unittest.TestCase):
7 | def test_valid(self):
8 | lds = data.LabeledDataset(
9 | label="test", pos=frozenset({"x", "y"}), neg=frozenset({"z"})
10 | )
11 | assert len(lds) == 3
12 | assert lds.pos_cnt == 2
13 | assert lds.neg_cnt == 1
14 |
15 | def test_invalid_common_keys(self):
16 | with self.assertRaises(ValueError):
17 | data.LabeledDataset(
18 | label="test", pos=frozenset({"x"}), neg=frozenset({"x"})
19 | )
20 |
21 | def test_invalid_no_data(self):
22 | with self.assertRaises(ValueError):
23 | data.LabeledDataset(label="test", pos=frozenset(), neg=frozenset())
24 |
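25 | # Optional convenience entry point: lets the suite run directly via
26 | # `python tests/test_data.py`, in addition to pytest/unittest discovery.
27 | if __name__ == "__main__":
28 |     unittest.main()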
--------------------------------------------------------------------------------
/videoannotator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Netflix/videoannotator/526d38dfa4ffcb953423d1f1f583d695904c90a2/videoannotator/__init__.py
--------------------------------------------------------------------------------
/videoannotator/config.py:
--------------------------------------------------------------------------------
1 | from fractions import Fraction
2 |
3 | PATH_DATA_BASE = "/root/data/videoannotator"
4 |
5 | SEED = 0
6 | SCORE_THRESHOLD = 0.5
7 | N_SPLITS = 5
8 | N_BOOTSTRAP = 5
9 | N_JOBS = -1
10 | MAX_ITER = 1_000
11 | VALIDATION_FRACTION = Fraction(numerator=2, denominator=10)
12 | MIN_TRAINING_POS_RATE = 0.1
13 | CLUSTER_CNT = 10
14 | ANNOTATOR_CNT = 3
15 |
16 | # AVE experiments
17 | AVE_EXPERIMENTS_NS = (25, 50, 100, 500, 1000)
18 | AVE_EXPERIMENTS_METHODS = (
19 | "random",
20 | "zero-shot-50-random-50",
21 | "zero-shot",
22 | )
23 | SCORING = (
24 | "average_precision",
25 | # "balanced_accuracy",
26 | )
27 |
28 | LABEL_GROUPS = dict(
29 | motion={
30 | "zoom",
31 | "pan",
32 | "slow-motion",
33 | "handheld",
34 | # "fast-motion",
35 | "timelapse",
36 | "jump-scare",
37 | },
38 | genres={
39 | "action",
40 | "sci-fi",
41 | "horror",
42 | "fantasy",
43 | "drama",
44 | "romance",
45 | },
46 | emotions={
47 | "anger",
48 | "sad",
49 | "happy",
50 | "scared",
51 | "laughter",
52 | },
53 | shot_types={
54 | "establishing-shots",
55 | "static-shot",
56 | "shutter-shot",
57 | "cowboy-shot",
58 | "extreme-close-up",
59 | "extreme-wide-shot",
60 | "two-shot",
61 | "group-shot",
62 | "aerial",
63 | "eye-level",
64 | "medium",
65 | "closeup",
66 | "wide",
67 | "over-the-shoulder-shot",
68 | "tilt-shot",
69 | "dutch-angle",
70 | "point-of-view-shot",
71 | "high-angle",
72 | "insert-shot",
73 | "low-angle",
74 | "overhead-shot",
75 | "single-shot",
76 | },
77 | sensitivities={
78 | "alcohol",
79 | "smoking",
80 | "nudity",
81 | "gore",
82 | "drugs",
83 | "violence",
84 | "intimacy",
85 | },
86 | events_actions={
87 | "interview",
88 | "fight",
89 | "car-chase",
90 | "run",
91 | },
92 | time_location={
93 | "day",
94 | "golden-hour",
95 | "interior",
96 | },
97 | focus={
98 | "character-focus",
99 | "animal",
100 | "object",
101 | },
102 | )
103 |
104 | LABELS = frozenset(
105 | (
106 | "action",
107 | "smoking",
108 | "nudity",
109 | "gore",
110 | "sad",
111 | "happy",
112 | "intimacy",
113 | "establishing-shots",
114 | "character-focus",
115 | "sci-fi",
116 | "horror",
117 | "alcohol",
118 | "anger",
119 | "aerial",
120 | "eye-level",
121 | "medium",
122 | "closeup",
123 | "wide",
124 | "zoom",
125 | "pan",
126 | "slow-motion",
127 | "jump-scare",
128 | "timelapse",
129 | "interior",
130 | "overhead-shot",
131 | # "fast-motion", # very low pos rate
132 | "animal",
133 | "group-shot",
134 | "drugs",
135 | "car-chase",
136 | "laughter",
137 | "over-the-shoulder-shot",
138 | "day",
139 | "object",
140 | "handheld",
141 | "low-angle",
142 | "violence",
143 | "drama",
144 | "point-of-view-shot",
145 | "single-shot",
146 | "romance",
147 | "golden-hour",
148 | "extreme-wide-shot",
149 | "high-angle",
150 | "insert-shot",
151 | "run",
152 | "fantasy",
153 | "static-shot",
154 | "shutter-shot",
155 | "cowboy-shot",
156 | "extreme-close-up",
157 | "two-shot",
158 | "fight",
159 | "interview",
160 | "scared",
161 | "tilt-shot",
162 | "dutch-angle",
163 | )
164 | )
165 | LABELS_AVE = frozenset(
166 | (
167 | "group-shot",
168 | "animal",
169 | "single-shot",
170 | "low-angle",
171 | "eye-level",
172 | "medium",
173 | "extreme-wide-shot",
174 | "tilt-shot",
175 | "high-angle",
176 | "over-the-shoulder-shot",
177 | "zoom",
178 | "closeup",
179 | "wide",
180 | "handheld",
181 | "overhead-shot",
182 | "insert-shot",
183 | "two-shot",
184 | "extreme-close-up",
185 | "aerial",
186 | )
187 | )
188 |
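189 | # Consistency sketch (for a REPL or test; not enforced at import time): every
190 | # label referenced by LABEL_GROUPS and LABELS_AVE is expected to be in LABELS.
191 | #
192 | #     assert set().union(*LABEL_GROUPS.values()) <= set(LABELS)
193 | #     assert LABELS_AVE <= LABELS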
--------------------------------------------------------------------------------
/videoannotator/data.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import fractions
3 | import functools
4 | import json
5 | import mmh3
6 | import logging
7 | import typing as t
8 | from dataclasses import dataclass
9 |
10 | import h5py
11 | import numpy as np
12 | import pandas as pd
13 |
14 | from . import config as cfg, io
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 | Key = str
19 | Embedding = np.ndarray
20 | Label = str
21 |
22 |
23 | @functools.lru_cache(maxsize=1)
24 | def get_embeddings_dict() -> t.Dict[Key, Embedding]:
25 | logger.info("Loading embeddings for the first time...")
26 | keys = pd.read_csv(io.PATHS_STATIC.shot_data, usecols=["key"]).key.tolist()
27 | with h5py.File(io.PATHS_STATIC.embeddings, "r") as hf:
28 | embs = hf["embeddings"][:]
29 | d = {key: emb for key, emb in zip(keys, embs)}
30 | logger.info("Finished loading embeddings.")
31 | return d
32 |
33 |
34 | @functools.lru_cache(maxsize=1)
35 | def get_text_embedding_dict() -> t.Dict[Key, Embedding]:
36 | logger.info("Loading text embeddings for the first time...")
37 | data = json.load(open(io.PATHS_STATIC.text_embeddings))
38 | logger.info("Finished loading text embeddings.")
39 | return data
40 |
41 |
42 | @functools.lru_cache(maxsize=1)
43 | def get_key_to_cluster_idx_dict() -> t.Dict[Key, int]:
44 | df = pd.read_csv(io.PATHS_STATIC.shot_data)
45 | return df.set_index("key").cluster_index.to_dict()
46 |
47 |
48 | @dataclass(frozen=True)
49 | class LabeledDataset:
50 | label: str
51 | pos: t.FrozenSet[Key]
52 | neg: t.FrozenSet[Key]
53 | _cluster_cnt: int = cfg.CLUSTER_CNT
54 |
55 | def __post_init__(self):
56 | if len(self) == 0:
57 | raise ValueError(f"Dataset with label={self.label} has no data.")
58 | common = self.pos & self.neg
59 | if len(common) > 0:
60 | raise ValueError(
61 | f"keys cannot be in both positive and negative sets. these keys are in both: {common}"
62 | )
63 | elif self.pos_cnt == 0:
64 | logger.warning(f"{self} has no positive annotations.")
65 | elif self.neg_cnt == 0:
66 | logger.warning(f"{self} has no negative annotations.")
67 |
68 | def __len__(self) -> int:
69 | return len(self.pos) + len(self.neg)
70 |
71 | def __repr__(self) -> str:
72 | return (
73 | f"LabeledDataset(label={self.label}, cnt={len(self)}, pos_cnt={self.pos_cnt}, "
74 | f"neg_cnt={self.neg_cnt}, pos_rate={self.pos_rate:.2f})"
75 | )
76 |
77 | @property
78 | def keys(self) -> t.List[Key]:
79 | return self.keys_pos + self.keys_neg
80 |
81 | @property
82 |     def y(self) -> np.ndarray:
83 | return np.array([True] * self.pos_cnt + [False] * self.neg_cnt)
84 |
85 | @property
86 | @functools.lru_cache()
87 | def x(self) -> np.ndarray:
88 | return np.vstack([get_embeddings_dict()[key] for key in self.keys])
89 |
90 | @property
91 | @functools.lru_cache()
92 | def data_score(self) -> float:
93 | k2c = get_key_to_cluster_idx_dict()
94 | part_pos = self._data_score_part(
95 | keys=self.keys_pos, k2c=k2c, cluster_cnt=self._cluster_cnt
96 | )
97 | part_neg = self._data_score_part(
98 | keys=self.keys_neg, k2c=k2c, cluster_cnt=self._cluster_cnt
99 | )
100 | return (part_pos + part_neg) / (2 * self._cluster_cnt)
101 |
102 | @staticmethod
103 | def _data_score_part(
104 | keys: t.Sequence[Key], k2c: t.Dict[Key, int], cluster_cnt: int
105 | ) -> float:
106 | return (
107 | pd.Series(map(k2c.get, keys)).value_counts().clip(0, cluster_cnt)
108 | / cluster_cnt
109 | ).sum()
110 |
111 | @property
112 | def keys_pos(self) -> t.List[Key]:
113 | return sorted(self.pos)
114 |
115 | @property
116 | def keys_neg(self) -> t.List[Key]:
117 | return sorted(self.neg)
118 |
119 | @property
120 | def pos_cnt(self) -> int:
121 | return len(self.pos)
122 |
123 | @property
124 | def pos_rate(self) -> float:
125 | return self.pos_cnt / len(self)
126 |
127 | @property
128 | def neg_cnt(self) -> int:
129 | return len(self.neg)
130 |
131 | @staticmethod
132 | def _key_in_val(
133 | key: Key, validation_fraction: fractions.Fraction, seed: int
134 | ) -> bool:
135 | vf = validation_fraction
136 | h = mmh3.hash(key=key, seed=seed)
137 | bucket = h % vf.denominator
138 | return bucket < vf.numerator
139 |
140 | def split(
141 | self,
142 | validation_fraction: fractions.Fraction = cfg.VALIDATION_FRACTION,
143 | keys_to_remove_from_train: t.Set[Key] = frozenset(),
144 | keys_to_remove_from_validation: t.Set[Key] = frozenset(),
145 | seed: int = cfg.SEED,
146 | ) -> t.Tuple["LabeledDataset", "LabeledDataset"]:
147 | data = dict(
148 | train=collections.defaultdict(set), val=collections.defaultdict(set)
149 | )
150 | for value, keys in ((True, self.pos), (False, self.neg)):
151 | for key in keys:
152 | in_val = self._key_in_val(
153 | key=key, validation_fraction=validation_fraction, seed=seed
154 | )
155 | if in_val and key not in keys_to_remove_from_validation:
156 | data["val"][value].add(key)
157 | elif not in_val and key not in keys_to_remove_from_train:
158 | data["train"][value].add(key)
159 |
160 | ds_train = LabeledDataset(
161 | label=self.label,
162 | pos=frozenset(data["train"][True]),
163 | neg=frozenset(data["train"][False]),
164 | )
165 | ds_val = LabeledDataset(
166 | label=self.label,
167 | pos=frozenset(data["val"][True]),
168 | neg=frozenset(data["val"][False]),
169 | )
170 | if ds_train.pos_rate == 0:
171 | raise ValueError(
172 | f"The training split for label={self.label} with seed={seed} has no positive instances."
173 | )
174 | if ds_val.pos_rate == 0:
175 | raise ValueError(
176 | f"The validation split for label={self.label} with seed={seed} has no positive instances."
177 | )
178 | return ds_train, ds_val
179 |
180 | def boostrap_xyk(self, idx: int) -> t.Tuple[np.ndarray, np.ndarray, t.List[Key]]:
181 | if idx == 0:
182 |             # return the exact dataset on exactly one of the bootstrap iterations;
183 | # this helps with datasets that have a very small number of positives
184 | return self.x, self.y, self.keys
185 | np.random.seed(idx)
186 | n = len(self)
187 | idxs = np.random.choice(range(n), n)
188 | x = self.x[idxs]
189 | y = self.y[idxs]
190 | keys = np.array(self.keys)[idxs]
191 | return x, y, list(keys)
192 |
193 | def remove_keys(self, keys: t.Set[Key]) -> "LabeledDataset":
194 | return LabeledDataset(
195 | label=self.label,
196 | pos=frozenset(self.pos - keys),
197 | neg=frozenset(self.neg - keys),
198 | )
199 |
200 | def _create_dataset_from_keys(self, keys: t.Collection[Key]) -> "LabeledDataset":
201 | keys_pos_set = set(self.keys_pos)
202 | keys_neg_set = set(self.keys_neg)
203 | pos = frozenset(k for k in keys if k in keys_pos_set)
204 | neg = frozenset(k for k in keys if k in keys_neg_set)
205 | return LabeledDataset(
206 | label=self.label,
207 | pos=pos,
208 | neg=neg,
209 | )
210 |
211 | def sample(
212 | self, n: int, seed: t.Optional[int] = None, start_idx: int = 0
213 | ) -> "LabeledDataset":
214 | np.random.seed(seed)
215 | keys_perm = np.random.permutation(self.keys)
216 | idx_last = start_idx + n if n is not None else None
217 | keys_to_use = keys_perm[start_idx:idx_last]
218 | return self._create_dataset_from_keys(keys=keys_to_use)
219 |
220 | def sample_stratified(
221 | self,
222 | n: int,
223 | seed: t.Optional[int] = None,
224 | min_training_pos_rate: float = 0.0,
225 | ) -> "LabeledDataset":
226 | n = min(len(self), n)
227 | np.random.seed(seed)
228 | cnt_pos = round(max(self.pos_rate, min_training_pos_rate) * n)
229 | cnt_neg = n - cnt_pos
230 | pos = np.random.choice(self.keys_pos, replace=False, size=cnt_pos)
231 | neg = np.random.choice(self.keys_neg, replace=False, size=cnt_neg)
232 |         # np.random.choice returns arrays; concatenate as lists rather than
233 |         # element-wise `+`, which is not defined for string arrays
234 |         return self._create_dataset_from_keys(keys=list(pos) + list(neg))
233 |
234 | def __add__(self, other: "LabeledDataset") -> "LabeledDataset":
235 | if self.label != other.label:
236 | raise ValueError(
237 | f"Cannot add datasets with two different labels: {self.label} and {other.label}"
238 | )
239 | return LabeledDataset(
240 | label=self.label,
241 | pos=self.pos | other.pos,
242 | neg=self.neg | other.neg,
243 | )
244 |
245 | def __sub__(self, other: "LabeledDataset") -> "LabeledDataset":
246 | if self.label != other.label:
247 | raise ValueError(
248 | f"Cannot subtract datasets with two different labels: {self.label} and {other.label}"
249 | )
250 | return LabeledDataset(
251 | label=self.label,
252 | pos=self.pos - other.pos,
253 | neg=self.neg - other.neg,
254 | )
255 |
256 | def __eq__(self, other: "LabeledDataset") -> bool:
257 | return (
258 | len(self) == len(other)
259 | and self.label == other.label
260 | and self.pos == other.pos
261 | and self.neg == other.neg
262 | )
263 |
264 |
265 | @dataclass(frozen=True)
266 | class DatasetOrdered:
267 | label: str
268 | pos: t.Tuple[Key, ...]
269 | neg: t.Tuple[Key, ...]
270 |
271 | def __post_init__(self):
272 | if len(set(self.pos) & set(self.neg)) > 0:
273 | raise ValueError(
274 | f"trying to create an ordered dataset with common pos and neg keys."
275 | )
276 |
277 | def __add__(self, other: "DatasetOrdered"):
278 | assert other.label == self.label
279 | pos_set = set(self.pos)
280 | neg_set = set(self.neg)
281 | other_pos = tuple(x for x in other.pos if x not in pos_set)
282 | other_neg = tuple(x for x in other.neg if x not in neg_set)
283 | return DatasetOrdered(
284 | label=self.label, pos=self.pos + other_pos, neg=self.neg + other_neg
285 | )
286 |
287 | @staticmethod
288 | def _sub(a: tuple, b: tuple) -> tuple:
289 | b_set = set(b)
290 | return tuple(x for x in a if x not in b_set)
291 |
292 | def __sub__(self, other: "DatasetOrdered") -> "DatasetOrdered":
293 | assert other.label == self.label
294 | pos = self._sub(a=self.pos, b=other.pos)
295 | neg = self._sub(a=self.neg, b=other.neg)
296 | return DatasetOrdered(label=self.label, pos=pos, neg=neg)
297 |
298 | @property
299 | def lds(self) -> LabeledDataset:
300 | return LabeledDataset(
301 | label=self.label,
302 | pos=frozenset(self.pos),
303 | neg=frozenset(self.neg),
304 | )
305 |
306 | def __len__(self) -> int:
307 | return len(self.pos) + len(self.neg)
308 |
309 | @classmethod
310 | def from_labeled_dataset(
311 | cls, lds: LabeledDataset, seed: int = cfg.SEED
312 | ) -> "DatasetOrdered":
313 | np.random.seed(seed)
314 | pos = sorted(lds.pos)
315 | neg = sorted(lds.neg)
316 | return cls(
317 | label=lds.label,
318 | pos=tuple(np.random.permutation(pos)),
319 | neg=tuple(np.random.permutation(neg)),
320 | )
321 |
322 | def get_n(self, n: int) -> "DatasetOrdered":
323 | pos_rate = len(self.pos) / len(self)
324 | cnt_pos = min(len(self.pos), max(round(pos_rate * n), 1))
325 | cnt_neg = min(len(self.neg), n - cnt_pos)
326 | return DatasetOrdered(
327 | label=self.label,
328 | pos=self.pos[:cnt_pos],
329 | neg=self.neg[:cnt_neg],
330 | )
331 |
332 |
333 | @dataclass(frozen=True)
334 | class CompDataset:
335 | label: str
336 | keys_to_remove: t.FrozenSet[Key] = frozenset()
337 |
338 | @property
339 | @functools.lru_cache()
340 | def data(self) -> dict:
341 | data_og = json.load(open(io.LabelPaths(label=self.label).path_cmp))
342 | return {
343 | k: [v for v in vs if v["key"] not in self.keys_to_remove]
344 | for k, vs in data_og.items()
345 | }
346 |
347 | def _get_lds(
348 | self,
349 | which: str,
350 | n: t.Optional[int],
351 | start_idx: int,
352 | keys_to_remove: t.Optional[t.Set[Key]] = None,
353 | fail_if_less_available: bool = False,
354 | ) -> LabeledDataset:
355 | d = self.data[which]
356 | if n is not None and n > len(d):
357 | msg = f"Dataset for label={self.label} has {len(d)} records, but requested {n}"
358 | if fail_if_less_available:
359 | raise ValueError(msg)
360 | else:
361 | logger.warning(msg)
362 | data = d[start_idx : start_idx + n] if n is not None else d[start_idx:]
363 | pos, neg = self.__get_pos_neg(data=data, keys_to_remove=keys_to_remove)
364 | return LabeledDataset(label=self.label, pos=frozenset(pos), neg=frozenset(neg))
365 |
366 | @staticmethod
367 | def __get_pos_neg(data, keys_to_remove: t.Optional[t.Set[Key]]):
368 | keys_to_remove = keys_to_remove if keys_to_remove is not None else set()
369 | data = [d for d in data if d["key"] not in keys_to_remove]
370 | pos = (d["key"] for d in data if d["value"])
371 | neg = (d["key"] for d in data if not d["value"])
372 | return pos, neg
373 |
374 | def random(
375 | self,
376 | n: t.Optional[int] = None,
377 | start_idx: int = 0,
378 | keys_to_remove: t.Optional[t.Set[Key]] = None,
379 | ) -> LabeledDataset:
380 | return self._get_lds(
381 | which="random", n=n, start_idx=start_idx, keys_to_remove=keys_to_remove
382 | )
383 |
384 | def zero_shot(
385 | self,
386 | n: t.Optional[int] = None,
387 | start_idx: int = 0,
388 | keys_to_remove: t.Optional[t.Set[Key]] = None,
389 | ) -> LabeledDataset:
390 | return self._get_lds(
391 | which="zero_shot", n=n, start_idx=start_idx, keys_to_remove=keys_to_remove
392 | )
393 |
394 | @property
395 | def cnt_zero_shot(self) -> int:
396 | return len(self.data["zero_shot"])
397 |
398 | @property
399 | def cnt_random(self) -> int:
400 | return len(self.data["random"])
401 |
402 | def get_ordered_dataset_zero_shot(
403 | self,
404 | keys_to_remove: t.Set[Key] = frozenset(),
405 | ) -> DatasetOrdered:
406 | return self._get_ordered_dataset(
407 | which="zero_shot", keys_to_remove=keys_to_remove
408 | )
409 |
410 | def get_ordered_dataset_random(
411 | self, keys_to_remove: t.Set[Key] = frozenset()
412 | ) -> DatasetOrdered:
413 | return self._get_ordered_dataset(which="random", keys_to_remove=keys_to_remove)
414 |
415 | def _get_ordered_dataset(
416 | self, which: str, keys_to_remove: t.Set[Key]
417 | ) -> DatasetOrdered:
418 | data = self.data[which]
419 | pos, neg = self.__get_pos_neg(data=data, keys_to_remove=keys_to_remove)
420 | return DatasetOrdered(label=self.label, pos=tuple(pos), neg=tuple(neg))
421 |
422 |
423 | def get_aggregate_labeled_dataset(label: str) -> LabeledDataset:
424 | path = io.LabelPaths(label=label).path_agg
425 | data = json.load(open(path))
426 | ds = collections.defaultdict(set)
427 | for x in data["data"]:
428 | ds[x["agg"]].add(x["key"])
429 | return LabeledDataset(
430 | label=label, pos=frozenset(ds[True]), neg=frozenset(ds[False])
431 | )
432 |
433 |
434 | def get_labeled_dataset_positive_cnt_by_key(label: str) -> t.Dict[Key, int]:
435 | path = io.LabelPaths(label=label).path_agg
436 | data = json.load(open(path))
437 | return {x["key"]: x["pos_cnt"] for x in data["data"]}
438 |
439 |
440 | def get_labeled_dataset_agreement(
441 | label: str, _annotator_cnt: int = cfg.ANNOTATOR_CNT
442 | ) -> float:
443 | pos_counts = get_labeled_dataset_positive_cnt_by_key(label=label)
444 | den = len(pos_counts)
445 | num = sum(cnt in {0, _annotator_cnt} for _, cnt in pos_counts.items())
446 | return num / den
447 |
448 |
449 | def get_labeled_dataset_checkpoints(label: str) -> t.List[LabeledDataset]:
450 | path = io.LabelPaths(label=label).path_checkpoints
451 | data = json.load(open(path))
452 | return [
453 | LabeledDataset(
454 | label=label,
455 | pos=frozenset(x["pos"]),
456 | neg=frozenset(x["neg"]),
457 | )
458 | for x in data
459 | ]
460 |
461 |
462 | def get_ave_validation_labeled_dataset(label: str) -> LabeledDataset:
463 | if label not in cfg.LABELS_AVE:
464 | raise ValueError(
465 | f"AVE data for label={label} does not exist. choices are: {cfg.LABELS_AVE}"
466 | )
467 | path = io.LabelPaths(label=label).path_ave
468 | data = json.load(open(path))
469 | return LabeledDataset(
470 | label=label,
471 | pos=frozenset(data["pos"]),
472 | neg=frozenset(data["neg"]),
473 | )
474 |
475 |
476 | @dataclass(frozen=False)
477 | class DatasetManager:
478 | label: str
479 | _seed: int = cfg.SEED
480 |
481 | def __post_init__(self):
482 | cmp = CompDataset(label=self.label)
483 | self._agg = get_aggregate_labeled_dataset(label=self.label)
484 | self._agg_val_keys = set(self._agg.split(seed=self._seed)[1].keys)
485 | self._ordered_datasets: t.Dict[str, t.Optional[DatasetOrdered]] = dict(
486 | agg=self._get_combined_agg_ordered_dataset(),
487 | zero_shot=cmp.get_ordered_dataset_zero_shot(
488 | keys_to_remove=self._agg_val_keys
489 | ),
490 | random=cmp.get_ordered_dataset_random(keys_to_remove=self._agg_val_keys),
491 | )
492 | self.lds = None
493 | self.lds_hist = []
494 |
495 | def has_data(self, which: str) -> bool:
496 | return self._ordered_datasets[which] is not None
497 |
498 | def _get_agg_ordered_dataset(self, lds: LabeledDataset) -> DatasetOrdered:
499 | ods = DatasetOrdered.from_labeled_dataset(lds=lds)
500 | keys = ods.pos + ods.neg
501 | pos = tuple(
502 | k for k in keys if k in self._agg.pos and k not in self._agg_val_keys
503 | )
504 | neg = tuple(
505 | k for k in keys if k in self._agg.neg and k not in self._agg_val_keys
506 | )
507 | return DatasetOrdered(label=self.label, pos=pos, neg=neg)
508 |
509 | def _get_combined_agg_ordered_dataset(self) -> DatasetOrdered:
510 | lds_list = get_labeled_dataset_checkpoints(label=self.label)
511 | ods = self._get_agg_ordered_dataset(lds=lds_list[0])
512 | for lds in lds_list[1:]:
513 | if len(lds) > len(ods):
514 | ods_curr = self._get_agg_ordered_dataset(lds=lds)
515 | ods = ods + (ods_curr - ods)
516 | return ods
517 |
518 | def extend(self, which: str, n: int) -> LabeledDataset:
519 | ods = self._ordered_datasets[which]
520 | if ods is not None:
521 | ods_new = ods.get_n(n=n)
522 | if len(ods_new) < n:
523 | logger.warning(
524 | f"label={self.label} {which} returned {len(ods_new)} vs. the requested {n}."
525 | )
526 | self._ordered_datasets[which] = None
527 | elif len(ods_new) == len(ods):
528 | self._ordered_datasets[which] = None
529 | else:
530 | self._ordered_datasets[which] = ods - ods_new
531 | lds_new = ods_new.lds
532 |
533 | self.lds_hist.append(
534 | dict(which=which, n=n, lds_new_cnt=len(lds_new), lds_new=lds_new)
535 | )
536 | self.lds = self.lds + lds_new if self.lds is not None else lds_new
537 | for name, ods in self._ordered_datasets.items():
538 | if ods is not None:
539 | self._ordered_datasets[name] = (
540 | self._ordered_datasets[name] - ods_new
541 | )
542 | else:
543 | logger.warning(
544 | f"label={self.label} {which} is already exhausted. returning the last dataset."
545 | )
546 | return self.lds
547 |
--------------------------------------------------------------------------------
/videoannotator/dp.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from dataclasses import dataclass, field
3 | import logging
4 | import typing as t
5 |
6 | import numpy as np
7 | from tqdm.auto import tqdm
8 |
9 | from . import config as cfg
10 | from .data import DatasetManager, get_aggregate_labeled_dataset
11 | from .experiments import ExperimentResults, Metric, run_experiment
12 | from .models import LogisticRegression, Model
13 |
14 | Choice = str
15 |
16 | # TODO: add conf
17 | CHOICES: t.Tuple[str, ...] = ("agg", "zero_shot", "random")
18 | D = 50
19 | EPSILON = 0.2
20 | ALPHA = 0.5
21 |
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | @dataclass(frozen=True)
27 | class DecisionProcessResults:
28 | res_hist: t.Tuple[t.Tuple[str, ExperimentResults], ...]
29 | choice_hist: t.Tuple[str, ...]
30 |
31 | @property
32 | def best_metric(self) -> Metric:
33 | return max(
34 | (
35 | (res.average_precision.mean, res.average_precision)
36 | for _, res in self.res_hist
37 | ),
38 | key=lambda x: x[0],
39 | )[1]
40 |
41 | @property
42 | def last_metric(self) -> Metric:
43 | return self.res_hist[-1][1].average_precision
44 |
45 | @property
46 | def best_idx(self) -> int:
47 | return max(
48 | (res.average_precision.mean, idx)
49 | for idx, (_, res) in enumerate(self.res_hist)
50 | )[1]
51 |
52 |
53 | @dataclass(frozen=False)
54 | class _DecisionProcess:
55 | label: str
56 | choices: t.Tuple[str, ...] = CHOICES
57 | d: int = D
58 |     model: Model = field(default_factory=lambda: LogisticRegression(scoring="average_precision"))  # fresh model per instance
59 | _seed_base: int = cfg.SEED
60 |
61 | @staticmethod
62 | def _iter_n(n: int, verbose: bool):
63 | if verbose:
64 | return tqdm(range(n))
65 | else:
66 | return range(n)
67 |
68 | def run(self, n: int, verbose: bool = False) -> DecisionProcessResults:
69 | self._warmup(verbose=verbose)
70 | for _ in self._iter_n(n=n, verbose=verbose):
71 | choice = self._select_outer()
72 | res = self._act(choice=choice, verbose=verbose)
73 | self._update(choice=choice, res=res)
74 | self._choice_hist.append(choice)
75 | self._res_hist.append((choice, res))
76 | self.i += 1
77 | return DecisionProcessResults(
78 | res_hist=tuple(self._res_hist),
79 | choice_hist=tuple(self._choice_hist),
80 | )
81 |
82 | def __post_init__(self):
83 | self.i = 0
84 | self._res_hist = []
85 | self._choice_hist = []
86 | self._dsm = DatasetManager(label=self.label, _seed=self._seed_base)
87 | self._choices_exhausted = set()
88 |
89 | @property
90 | def choices_available(self) -> t.Tuple[str, ...]:
91 | return tuple(c for c in self.choices if c not in self._choices_exhausted)
92 |
93 | def _select_outer(self) -> Choice:
94 | if self.i < len(self.choices):
95 | # try each option once first
96 | return self.choices[self.i]
97 | while True:
98 | c = self._select()
99 | if c is None:
100 | raise ValueError(
101 | f"{self.__class__.__name__} select returned option None (not allowed) at i={self.i} with "
102 | f"params: {self.__dict__}"
103 | )
104 | if self._dsm.has_data(which=c):
105 | return c
106 | else:
107 | self._choices_exhausted.add(c)
108 | if len(self.choices_available) == 0:
109 | raise ValueError(
110 | f"{self.__class__.__name__} ran out of options to select at i={self.i} with "
111 | f"params: {self.__dict__}"
112 | )
113 |
114 | def _select(self) -> Choice:
115 | pass
116 |
117 | def _warmup(self, verbose: bool) -> None:
118 | pass
119 |
120 | def _act(self, choice: Choice, verbose: bool) -> ExperimentResults:
121 | _, lds_val = get_aggregate_labeled_dataset(label=self.label).split(
122 | seed=self._seed_base
123 | )
124 | lds_train = self._dsm.extend(which=choice, n=self.d)
125 | if verbose:
126 | logger.info(f"Running experiment with choice={choice}")
127 | return run_experiment(
128 | model=self.model,
129 | lds_train=lds_train,
130 | lds_validation=lds_val,
131 | )
132 |
133 | def _update(self, choice: Choice, res: ExperimentResults) -> None:
134 | pass
135 |
136 | @property
137 | def _seed(self) -> int:
138 | return self._seed_base + self.i
139 |
140 | @staticmethod
141 | def extract_metric(res: ExperimentResults) -> float:
142 | return res.average_precision.mean
143 |
144 |
145 | @dataclass
146 | class RoundRobin(_DecisionProcess):
147 | def _select(self) -> Choice:
148 | return self.choices_available[self.i % len(self.choices_available)]
149 |
150 |
151 | @dataclass
152 | class Random(_DecisionProcess):
153 | def _select(self) -> Choice:
154 | np.random.seed(self._seed)
155 | return np.random.choice(self.choices_available)
156 |
157 |
158 | @dataclass
159 | class GreedyOracle(_DecisionProcess):
160 | def _select_outer(self) -> Choice:
161 | while True:
162 | c = self._select()
163 | if self._dsm.has_data(which=c):
164 | return c
165 | else:
166 | self._choices_exhausted.add(c)
167 | if len(self.choices_available) == 0:
168 | raise ValueError(
169 | f"GreedyOracle has no more available choices for label={self.label} at i={self.i}"
170 | )
171 |
172 | def _select(self) -> Choice:
173 | best, best_score = None, float("-inf")
174 | dsm_copy = deepcopy(self._dsm)
175 | for choice in self.choices_available:
176 | try:
177 | res = self._act(choice=choice, verbose=False)
178 | score = self.extract_metric(res=res)
179 | if np.isnan(score):
180 | raise ValueError("Evaluation score is NaN.")
181 | if score > best_score:
182 | best, best_score = choice, score
183 | except Exception as e:
184 | logger.warning(
185 | f"Running GreedyOracle with choice={choice}, i={self.i}, label={self.label} failed: {e}. "
186 | f"Skipping over this choice and setting the metric to zero. "
187 | f"Params: {self.__dict__}"
188 | )
189 | finally:
190 | self._dsm = deepcopy(dsm_copy)
191 | if best is None:
192 | raise ValueError(
193 | f"All choices failed for GreedyOracle label={self.label}, i={self.i}."
194 | f" Params: {self.__dict__}"
195 | )
196 | return best
197 |
198 |
199 | @dataclass
200 | class EpsilonGreedyMean(_DecisionProcess):
201 | epsilon: float = EPSILON
202 |
203 | def __post_init__(self):
204 | super().__post_init__()
205 | self._metrics = {c: Metric(values=tuple()) for c in self.choices}
206 | self._metrics_abs = {c: Metric(values=tuple()) for c in self.choices}
207 |
208 | @property
209 | def best_choice(self) -> Choice:
210 | return max(
211 | (m.mean, c)
212 | for c, m in self._metrics.items()
213 | if c in set(self.choices_available)
214 | )[1]
215 |
216 | def _select(self) -> Choice:
217 | np.random.seed(self._seed)
218 | if np.random.rand() < self.epsilon:
219 | return np.random.choice(self.choices_available)
220 | else:
221 | return self.best_choice
222 |
223 | def _update(self, choice: Choice, res: ExperimentResults) -> None:
224 |         prev = self._metrics_abs[self._choice_hist[-1]].values[-1]  # score after the previous step (run() appends the new choice only after _update)
225 | metric = self.extract_metric(res=res)
226 | self._metrics[choice] = self._metrics[choice].add(metric - prev)
227 | self._metrics_abs[choice] = self._metrics_abs[choice].add(metric)
228 |
229 | def _warmup(self, verbose: bool) -> None:
230 | dsm_copy = deepcopy(self._dsm)
231 | metrics = dict()
232 | for choice in self.choices:
233 | try:
234 | res = self._act(choice=choice, verbose=verbose)
235 | val = self.extract_metric(res=res)
236 | except Exception as e:
237 | val = 0
238 | name = self.__class__.__name__
239 | logger.warning(
240 | f"Running {name} warmup for choice={choice}, label={self.label} failed: {e}. "
241 | f"Skipping over this choice and setting the metric to zero. "
242 | f"Params: {self.__dict__}"
243 | )
244 |
245 | self._metrics_abs[choice] = self._metrics_abs[choice].add(val)
246 | metrics[choice] = val
247 | self._dsm = deepcopy(dsm_copy)
248 |
249 | choice_worst_val, choice_worst = min((v, c) for c, v in metrics.items())
250 | self._choice_hist.append(choice_worst)
251 | for choice in self.choices:
252 | self._metrics[choice] = self._metrics[choice].add(
253 | metrics[choice] - choice_worst_val
254 | )
255 |
256 |
257 | @dataclass
258 | class EpsilonGreedyLast(EpsilonGreedyMean):
259 | @property
260 | def best_choice(self) -> Choice:
261 | return max(
262 | (m.last, c)
263 | for c, m in self._metrics.items()
264 | if c in set(self.choices_available)
265 | )[1]
266 |
267 |
268 | @dataclass
269 | class EpsilonGreedyEWMA(EpsilonGreedyMean):
270 | alpha: float = ALPHA
271 |
272 | @property
273 | def best_choice(self) -> Choice:
274 | return max(
275 | (m.ewma(alpha=self.alpha), c)
276 | for c, m in self._metrics.items()
277 | if c in set(self.choices_available)
278 | )[1]
279 |
280 |
281 | @dataclass
282 | class UCBMean(EpsilonGreedyMean):
283 | c: float = 1e-2
284 |
285 | """
286 | c <= 1 for best results?!
287 | """
288 |
289 | def __post_init__(self):
290 | super().__post_init__()
291 | self._counts = {c: 0 for c in self.choices}
292 |
293 | def _select(self) -> Choice:
294 | t_ = self.i + 1
295 |
296 | best, best_score = None, float("-inf")
297 | for choice, cnt in self._counts.items():
298 | if choice in set(self.choices_available):
299 | if cnt == 0:
300 | return choice
301 |
302 | # val + c * sqrt(log(t) / cnt)
303 | aug = self._metrics[choice].mean
304 | aug += self.c * np.sqrt(np.log(t_) / cnt)
305 | if aug > best_score:
306 | best = choice
307 | best_score = aug
308 |
309 | return best
310 |
311 | def _update(self, choice: Choice, res: ExperimentResults) -> None:
312 | self._counts[choice] += 1
313 | super()._update(choice=choice, res=res)
314 |
315 |
316 | @dataclass
317 | class UCBEWMA(EpsilonGreedyEWMA, UCBMean):
318 | pass
319 |
320 |
321 | @dataclass
322 | class UCBLast(EpsilonGreedyLast, UCBMean):
323 | pass
324 |
325 |
326 | @dataclass
327 | class Thompson(_DecisionProcess):
328 | pass
329 |
330 |
331 | @dataclass
332 | class Exp3Mean(EpsilonGreedyMean):
333 | gamma: float = 1e-1
334 | _seed: int = cfg.SEED
335 |
336 | def __post_init__(self):
337 | assert 0 <= self.gamma <= 1
338 | super().__post_init__()
339 | self._ws = {c: 1 for c in self.choices_available}
340 |
341 | @property
342 | def _k(self) -> int:
343 | return len(self.choices_available)
344 |
345 | @property
346 | def _ps(self) -> t.Dict[Choice, float]:
347 | w_sum = sum(w for c, w in self._ws.items() if c in self.choices_available)
348 | return {
349 | c: (1 - self.gamma) * self._ws[c] / w_sum + self.gamma / self._k
350 | for c in self.choices_available
351 | }
352 |
353 | def _select(self) -> Choice:
354 | np.random.seed(self._seed)
355 | p = [self._ps[c] for c in self.choices_available]
356 | return np.random.choice(self.choices_available, p=p)
357 |
358 | def _update(self, choice: Choice, res: ExperimentResults) -> None:
359 | super()._update(choice=choice, res=res)
360 | x = self._get_update_metric(choice=choice) / self._ps[choice]
361 | self._ws[choice] *= np.exp(self.gamma * x / self._k)
362 |
363 | def _get_update_metric(self, choice: Choice) -> float:
364 | return self._metrics[choice].mean
365 |
366 |
367 | @dataclass
368 | class Exp3Last(Exp3Mean):
369 | def _get_update_metric(self, choice: Choice) -> float:
370 | return self._metrics[choice].last
371 |
372 |
373 | @dataclass
374 | class Exp3EWMA(Exp3Mean):
375 | alpha: float = ALPHA
376 |
377 | def _get_update_metric(self, choice: Choice) -> float:
378 | return self._metrics[choice].ewma(alpha=self.alpha)
379 |
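380 | # Minimal usage sketch (assumes label data is available locally; see data.py):
381 | #
382 | #     dp = UCBMean(label="action")
383 | #     res = dp.run(n=10)
384 | #     res.best_metric   # best average-precision Metric observed across steps
385 | #     res.choice_hist   # sequence of data-source choices made by the bandit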
--------------------------------------------------------------------------------
/videoannotator/experiments.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import itertools
3 | import logging
4 | import typing as t
5 | from dataclasses import dataclass
6 |
7 | import numpy as np
8 | import pandas as pd
9 | from sklearn.metrics import average_precision_score as ap, balanced_accuracy_score as ba
10 | from tqdm.auto import tqdm
11 |
12 | from . import config as cfg, data, models
13 | from .data import LabeledDataset
14 |
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | @dataclass(frozen=True)
20 | class Metric:
21 | values: t.Tuple[float, ...]
22 |
23 | @property
24 | @functools.lru_cache()
25 | def values_finite(self) -> t.List[float]:
26 | if len(self.values) == 0:
27 | return []
28 | vals = [v for v in self.values if not np.isnan(v)]
29 | if len(vals) == 0:
30 | raise ValueError("Metric only contains NaN.")
31 | if len(vals) < len(self.values):
32 | logger.warning(f"Metric has NaN values: {self.values}")
33 | return vals
34 |
35 | @property
36 | def mean(self) -> float:
37 | return np.mean(self.values_finite).item() if len(self.values_finite) > 0 else 0
38 |
39 | @property
40 | def std(self) -> float:
41 | return np.std(self.values_finite).item() if len(self.values_finite) > 0 else 0
42 |
43 | def __repr__(self):
44 | return f"Metric(mean={self.mean:.2f}, std={self.std:.2f})"
45 |
46 | def add(self, x: float) -> "Metric":
47 | return Metric(values=self.values + (x,))
48 |
49 | @property
50 | def last(self) -> float:
51 | return self.values_finite[-1] if len(self.values_finite) > 0 else float("-inf")
52 |
53 | def ewma(self, alpha: float) -> float:
54 | assert 0 <= alpha <= 1
55 | if len(self.values_finite) == 0:
56 | return 0
57 | v = 1 - alpha
58 | n = len(self.values_finite)
59 |         num = sum(x * (v**i) for i, x in enumerate(self.values_finite[::-1]))  # newest value first, weighted by v**i
60 |         den = sum(v**i for i in range(n))  # normalizing constant
61 | return num / den
62 |
63 |
64 | @dataclass(frozen=True)
65 | class ExperimentResults:
66 | lds_train: LabeledDataset
67 | lds_validation: LabeledDataset
68 | average_precision: Metric
69 | balanced_accuracy: Metric
70 |
71 |
72 | def run_experiment(
73 | model: models.Model,
74 | lds_train: LabeledDataset,
75 | lds_validation: LabeledDataset, # TODO: take a list/dict instead?
76 | n_bootstrap: int = cfg.N_BOOTSTRAP,
77 | ) -> ExperimentResults:
78 | common_keys = set(lds_train.keys) & set(lds_validation.keys)
79 | if len(common_keys) > 0:
80 | raise ValueError(f"Train and validation sets have common keys: {common_keys}")
81 | model.fit(lds_train.x, lds_train.y)
82 | aps, bas = [], []
83 | for idx in range(n_bootstrap):
84 | x, y, _ = lds_validation.boostrap_xyk(idx=idx)
85 | bas.append(ba(y, model.predict(x)))
86 | aps.append(ap(y, model.predict_proba(x)))
87 | return ExperimentResults(
88 | lds_train=lds_train,
89 | lds_validation=lds_validation,
90 | average_precision=Metric(values=tuple(aps)),
91 | balanced_accuracy=Metric(values=tuple(bas)),
92 | )
93 |
94 |
95 | def _ap_baseline(lds_val: LabeledDataset) -> Metric:
96 | thr = 1_000
97 | p = lds_val.pos_cnt
98 | n = len(lds_val)
99 | if n >= thr:
100 | return Metric(values=(p / n,))
101 | np.random.seed(0)
102 | y = [True] * p + [False] * (n - p)
103 | rand_scores = np.random.rand(thr, n)
104 | aps = tuple(ap(y, rand_scores[i]) for i in range(thr))
105 | return Metric(values=aps)
106 |
107 |
108 | @dataclass(frozen=True)
109 | class Experiment:
110 | ns: t.Tuple[int, ...] = cfg.AVE_EXPERIMENTS_NS
111 | methods: t.Tuple[str, ...] = cfg.AVE_EXPERIMENTS_METHODS
112 | seed: int = cfg.SEED
113 | scoring: t.Tuple[str, ...] = cfg.SCORING
114 |
115 | def experiment_active_learning(self, label: str) -> dict:
116 | # TODO: make sure zero-shot / random don't overlap with validation
117 | logger.info(f"Running active learning experiment for label={label}...")
118 | ds_agg = data.get_aggregate_labeled_dataset(label=label)
119 | ds_agg_train, ds_agg_val = ds_agg.split()
120 | # ds_val_lookup = dict(agg=ds_agg_val)
121 | res_final = dict()
122 | res_final["baseline_agg"] = _ap_baseline(lds_val=ds_agg_val)
123 | # TODO: remove disagreements and then run (if any positives left) => may still be misaligned
124 | # if label in cfg.LABELS_AVE:
125 | # ds_ave = data.get_ave_validation_labeled_dataset(label=label)
126 | # ds_val_lookup["ave"] = ds_ave
127 | # logger.info("Zero shot with AVE data...")
128 | # res_final["zero_shot_ave"] = run_experiment(
129 | # model=models.ZeroShotText2Video(label=label),
130 | # lds_train=ds_agg,
131 | # lds_validation=ds_ave,
132 | # )
133 | # res_final["baseline_ave"] = _ap_baseline(lds_val=ds_ave)
134 | # else:
135 | # logger.info(f"AVE data does NOT exist for label={label}")
136 | logger.info("Zero shot with agg data...")
137 | res_final["zero_shot"] = run_experiment(
138 | model=models.ZeroShotText2Video(label=label),
139 | lds_train=ds_agg_train,
140 | lds_validation=ds_agg_val,
141 | )
142 | val_keys = set(ds_agg_val.keys)
143 | ds_all = data.get_labeled_dataset_checkpoints(label=label)
144 | logger.info("Comparing...")
145 | res_final["cmp"] = self._run_experiments_comp(
146 | label=label,
147 | lds_validation=ds_agg_val,
148 | )
149 | res_final["checkpoints"] = [
150 | run_experiment(
151 | models.LogisticRegression(scoring="average_precision"),
152 | lds_train=ds.remove_keys(keys=val_keys),
153 | lds_validation=ds_agg_val,
154 | )
155 | for ds in tqdm(ds_all, desc="Processing checkpoints")
156 | ]
157 | return res_final
158 |
159 | def experiment_active_learning_batch(
160 | self,
161 | labels: t.Tuple[str, ...],
162 | ) -> dict:
163 | return {
164 | label: self.experiment_active_learning(label=label)
165 | for label in tqdm(labels)
166 | }
167 |
168 | def experiment_active_learning_aggregate(
169 | self, res_all: t.Dict[str, dict]
170 | ) -> pd.DataFrame:
171 | # TODO: AP vs. n => lines with CI: 1- baseline 2- zs 3- agg 4-
172 | pass
173 |
174 | @staticmethod
175 | def _get_lds_zero_shot(lds: LabeledDataset, n: int) -> LabeledDataset:
176 | zs = models.ZeroShotText2Video(label=lds.label)
177 | keys = zs.top_n_keys(lds=lds, n=n)
178 | return LabeledDataset(
179 | label=lds.label,
180 | pos=frozenset(k for k in keys if k in lds.pos),
181 | neg=frozenset(k for k in keys if k in lds.neg),
182 | )
183 |
184 | @staticmethod
185 | def _get_lds_train(
186 | label: str, n: int, method: str, keys_to_remove: t.FrozenSet[data.Key]
187 | ) -> LabeledDataset:
188 | ds = data.CompDataset(label=label, keys_to_remove=frozenset(keys_to_remove))
189 | if method == "random":
190 | return ds.random(n=n)
191 | elif method == "zero-shot-50-random-50":
192 | n_rand = round(n * 0.5)
193 | lds_rand = ds.random(n=n_rand)
194 | lds_zs = ds.zero_shot(n=n - n_rand)
195 | return lds_rand + lds_zs
196 | elif method == "zero-shot-20-random-80":
197 | n_zs = round(n * 0.2)
198 | lds_zs = ds.zero_shot(n=n_zs)
199 | lds_rand = ds.random(n=n - n_zs)
200 | return lds_rand + lds_zs
201 | elif method == "zero-shot":
202 | return ds.zero_shot(n=n)
203 | else:
204 | raise ValueError(f"method = {method} is not a valid choice.")
205 |
206 | @property
207 | def comp_iter(self) -> list:
208 | return list(
209 | itertools.product(
210 |                 self.ns, self.methods, self.scoring  # use the configured scoring metrics
211 | )
212 | )
213 |
214 | def _run_experiments_comp(
215 | self,
216 | label: str,
217 | lds_validation: LabeledDataset,
218 | ) -> dict:
219 | res = dict()
220 | for n, method, scoring in tqdm(self.comp_iter, desc="Running experiments"):
221 | try:
222 | lds_train = self._get_lds_train(
223 | label=label,
224 | n=n,
225 | method=method,
226 | keys_to_remove=frozenset(lds_validation.keys),
227 | )
228 | res[(n, method, scoring)] = run_experiment(
229 | model=models.LogisticRegression(scoring=scoring),
230 | lds_train=lds_train,
231 | lds_validation=lds_validation,
232 | )
233 | except Exception as e:
234 | logger.error(
235 | f"Experiment ({n}, {method}, {scoring}) for label={label} failed: {e}"
236 | )
237 | return res
238 |
239 | # def experiment_ave_self_comparison(
240 | # self,
241 | # label: str,
242 | # ) -> t.Dict[t.Tuple[int, str, str], ExperimentResults]:
243 | # if label not in cfg.LABELS_AVE:
244 | # raise ValueError(f"label={label} does not have AVE data")
245 | # ds_ave = data.get_ave_validation_labeled_dataset(label=label)
246 | # lds_train_base, lds_validation = ds_ave.split()
247 |     #     return self._run_experiments_comp(  # NOTE: stale call; _run_experiments_comp no longer accepts lds_train_base
248 | # label=label,
249 | # lds_train_base=lds_train_base,
250 | # lds_validation=lds_validation,
251 | # )
252 | #
253 | # def experiment_ave_self_comparison_all(self) -> ...:
254 | # return [
255 | # self.experiment_ave_self_comparison(label=label) for label in cfg.LABELS_AVE
256 | # ]
257 |
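258 | # Usage sketch (illustrative addition, not part of the experiment pipeline;
259 | # the label names are hypothetical). A batch active-learning run:
260 | #
261 | #     exp = Experiment()
262 | #     res = exp.experiment_active_learning_batch(labels=("dance", "karaoke"))
263 | #     # res[label] holds "baseline_agg", "zero_shot", "cmp" and "checkpoints"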
--------------------------------------------------------------------------------
/videoannotator/io.py:
--------------------------------------------------------------------------------
1 | # TODO: add text embeddings for top 10k+ common words and store in repo for convenience
2 |
3 | from dataclasses import dataclass
4 | from pathlib import Path
5 |
6 | from . import config as cfg
7 |
8 |
9 | def _check_path(func):  # decorator: fail fast if the returned Path does not exist on disk
10 | def __check(*args, **kwargs):
11 | path = func(*args, **kwargs)
12 | if not path.exists():
13 |             raise FileNotFoundError(f"File {path} does not exist.")
14 | return path
15 |
16 | return __check
17 |
18 |
19 | @dataclass(frozen=True)
20 | class _PathsStatic:
21 | _path_base: str = cfg.PATH_DATA_BASE
22 |
23 | @property
24 | @_check_path
25 | def embeddings(self) -> Path:
26 | return Path(f"{self._path_base}/embeddings.h5")
27 |
28 | @property
29 | @_check_path
30 | def shot_data(self) -> Path:
31 | return Path(f"{self._path_base}/shot-data.csv")
32 |
33 | @property
34 | @_check_path
35 | def text_embeddings(self) -> Path:
36 | return Path(f"{self._path_base}/text-embeddings.json")
37 |
38 |
39 | @dataclass(frozen=True)
40 | class _PathAVE(_PathsStatic):
41 | pass # TODO
42 |
43 |
44 | @dataclass(frozen=True)
45 | class LabelPaths:
46 | label: str
47 | _path_base: str = cfg.PATH_DATA_BASE
48 |
49 | @property
50 | @_check_path
51 | def path_agg(self) -> Path:
52 | return Path(f"{self._path_base}/agg/{self.label}.json")
53 |
54 | @property
55 | @_check_path
56 | def path_checkpoints(self) -> Path:
57 | return Path(f"{self._path_base}/checkpoints/{self.label}.json")
58 |
59 | @property
60 | @_check_path
61 | def path_ave(self) -> Path:
62 | return Path(f"{self._path_base}/ave/validation/{self.label}.json")
63 |
64 | @property
65 | @_check_path
66 | def path_cmp(self) -> Path:
67 | return Path(f"{self._path_base}/cmp/{self.label}.json")
68 |
69 |
70 | PATHS_STATIC = _PathsStatic()
71 | PATHS_AVE = _PathAVE()
72 |
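73 | # Usage sketch (illustrative; "dance" is a hypothetical label). Each property
74 | # goes through _check_path, so a missing file raises FileNotFoundError at
75 | # attribute access rather than later at read time:
76 | #
77 | #     emb_path = PATHS_STATIC.embeddings
78 | #     agg_path = LabelPaths(label="dance").path_agg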
--------------------------------------------------------------------------------
/videoannotator/models.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 | from dataclasses import dataclass
3 |
4 | import numpy as np
5 | from sklearn.ensemble import RandomForestClassifier
6 | from sklearn.linear_model import LogisticRegressionCV
7 | from sklearn.metrics.pairwise import cosine_similarity as cos
8 | from sklearn.model_selection import StratifiedKFold
9 | from xgboost import XGBClassifier
10 |
11 | from . import config as cfg, data
12 |
13 |
14 | @dataclass
15 | class _ScikitBinaryClassifierAPI:
16 |     def fit(self, x: np.ndarray, y: np.ndarray) -> None:
17 |         pass  # no-op by default so stateless models (e.g. zero-shot) need no training
18 |
19 |     def predict_proba(self, x: np.ndarray) -> np.ndarray:
20 |         raise NotImplementedError  # subclasses must return a per-sample score array
21 |
22 |     def predict(self, x: np.ndarray) -> np.ndarray:
23 |         raise NotImplementedError  # subclasses must return per-sample hard labels
24 |
25 |
26 | @dataclass
27 | class ZeroShotText2Video(_ScikitBinaryClassifierAPI):
28 | label: str
29 |
30 | @property
31 | def text_emb(self) -> np.ndarray:
32 | return np.array(data.get_text_embedding_dict()[self.label])
33 |
34 | def predict_proba(self, x: np.ndarray) -> np.ndarray:
35 |         return cos(x, self.text_emb[np.newaxis, :]).squeeze()  # cosine similarity in [-1, 1], not a calibrated probability
36 |
37 | def predict(self, x: np.ndarray) -> np.ndarray:
38 |         # hard predictions are not meaningful for zero-shot scoring; all-False kept for API compatibility
39 | return np.array([False] * len(x))
40 |
41 | def top_n_keys(self, lds: data.LabeledDataset, n: int) -> t.List[data.Key]:
42 | scores = self.predict_proba(x=lds.x)
43 |         return list(np.array(lds.keys)[scores.argsort()[::-1][:n]])  # keys of the n highest-scoring clips
44 |
45 |
46 | @dataclass
47 | class LogisticRegression(_ScikitBinaryClassifierAPI):
48 | scoring: str
49 | n_splits: int = cfg.N_SPLITS
50 | n_jobs: int = cfg.N_JOBS
51 | seed: int = cfg.SEED
52 | max_iter: int = cfg.MAX_ITER
53 |
54 | def __post_init__(self):
55 | cv = StratifiedKFold(
56 | n_splits=self.n_splits, shuffle=True, random_state=self.seed
57 | )
58 | self._model = LogisticRegressionCV(
59 | cv=cv,
60 | scoring=self.scoring,
61 | max_iter=self.max_iter,
62 | n_jobs=self.n_jobs,
63 | )
64 |
65 | def fit(self, x: np.ndarray, y: np.ndarray) -> None:
66 | self._model.fit(x, y)
67 |
68 | def predict(self, x: np.ndarray) -> np.ndarray:
69 | return self._model.predict(x)
70 |
71 | def predict_proba(self, x: np.ndarray) -> np.ndarray:
72 |         return self._model.predict_proba(x)[:, 1]  # probability of the positive class
73 |
74 |
75 | Model = t.Union[  # wrappers above and raw estimators alike expose fit/predict_proba
76 | LogisticRegression, RandomForestClassifier, XGBClassifier, ZeroShotText2Video
77 | ]
78 |
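79 | # Usage sketch (illustrative; x_train/x_val are assumed to be (n, d) embedding
80 | # matrices and y_train a boolean array, per the API stubs above):
81 | #
82 | #     clf = LogisticRegression(scoring="average_precision")
83 | #     clf.fit(x_train, y_train)
84 | #     scores = clf.predict_proba(x_val)  # positive-class probabilities
85 | #
86 | #     zs = ZeroShotText2Video(label="dance")  # hypothetical label
87 | #     sims = zs.predict_proba(x_val)  # cosine similarities; no fit() required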
--------------------------------------------------------------------------------