├── .gitignore
├── LICENSE
├── README.md
├── conda_env.yml
├── data-exploration.ipynb
├── demo.ipynb
├── exp1.ipynb
├── exp2.ipynb
├── images
│   ├── 6.png
│   ├── leave.png
│   └── varsity.png
├── sys.png
├── task_descriptions.json
├── tests
│   └── test_data.py
└── videoannotator
    ├── __init__.py
    ├── config.py
    ├── data.py
    ├── dp.py
    ├── experiments.py
    ├── io.py
    └── models.py
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Video Annotator 2 | 3 | ## Video understanding 4 | 5 | Here's a quick demo to get started: 6 | ```python 7 | from videoannotator import data, models 8 | import numpy as np 9 | 10 | # pick a label to work with 11 | label = "car-chase" 12 | 13 | # grab the annotated dataset for the label 14 | dataset = data.get_aggregate_labeled_dataset(label=label) 15 | 16 | # train a classifier with the labeled data 17 | clf = models.LogisticRegression(scoring='average_precision') 18 | clf.fit(dataset.x, dataset.y) 19 | 20 | # get the video embedding for your clip 21 | np.random.seed(0) 22 | emb = np.random.rand(1, 512) 23 | 24 | # score is between 0 and 1 25 | # a higher score means that the clip is more likely to contain `label` 26 | score = clf.predict_proba(emb).item() 27 | ``` 28 | 29 | See more details in [this notebook](demo.ipynb). 30 | 31 | ## What is Video Annotator? 32 | 33 | The best source is our paper [here](https://arxiv.org/abs/2402.06560). Here's the abstract: 34 | 35 | > High-quality and consistent annotations are fundamental to the successful development of robust machine learning models. Traditional data annotation methods are resource-intensive and inefficient, often leading to a reliance on third-party annotators who are not the domain experts. Hard samples, which are usually the most informative for model training, tend to be difficult to label accurately and consistently without business context. These can arise unpredictably during the annotation process, requiring a variable number of iterations and rounds of feedback, leading to unforeseen expenses and time commitments to guarantee quality. 36 | > 37 | > We posit that more direct involvement of domain experts, using a human-in-the-loop system, can resolve many of these practical challenges. We propose a novel framework we call Video Annotator (VA) for annotating, managing, and iterating on video classification datasets. Our approach offers a new paradigm for an end-user-centered model development process, enhancing the efficiency, usability, and effectiveness of video classifiers. Uniquely, VA allows for a continuous annotation process, seamlessly integrating data collection and model training. 38 | > 39 | > We leverage the zero-shot capabilities of vision-language foundation models combined with active learning techniques, and demonstrate that VA enables the efficient creation of high-quality models. VA achieves a median 6.8 point improvement in Average Precision relative to the most competitive baseline across a wide-ranging assortment of tasks. We release a dataset with 153k labels across 56 video understanding tasks annotated by three professional video editors using VA, and also release code to replicate our experiments. 40 | 41 | ## System 42 | ![System](sys.png) 43 | 44 | ## Data 45 | 46 | 1. Download the data tar file from [here](https://drive.google.com/file/d/107Dm1UBYQ8UOSs9rcLA1TGExYKbcg7IA/view) and unpack it into a folder of your choice. 47 | 2. Set `DATA_BASE_PATH` to that folder in `videoannotator/config.py`. E.g. `/root/data/videoannotator/`. 48 | 3. View [this notebook](data-exploration.ipynb) to understand how the data is organized. 49 | 50 | Once unpacked, the data should look like this: 51 | ``` 52 | ├── videoannotator 53 | │ ├── agg 54 | │ │ ├── action.json 55 | │ │ ├── aerial.json 56 | │ │ ├── ...
57 | │ ├── ave 58 | │ │ ├── keys-to-labels.json 59 | │ │ ├── labels-mapped.json 60 | │ │ ├── ... 61 | │ ├── checkpoints 62 | │ │ ├── action.json 63 | │ │ ├── aerial.json 64 | │ │ ├── ... 65 | │ ├── cmp 66 | │ │ ├── action.json 67 | │ │ ├── aerial.json 68 | │ │ ├── ... 69 | │ ├── embeddings.h5 70 | │ ├── queries.json 71 | │ ├── shot-data.csv 72 | │ ├── text-embeddings.json 73 | ``` 74 | 75 | ## Environment setup 76 | ```shell 77 | conda env create -f conda_env.yml 78 | conda activate videoannotator 79 | ``` 80 | 81 | ## Running experiments 82 | - [Experiment 1 notebook](exp1.ipynb) 83 | - [Experiment 2 notebook](exp2.ipynb) 84 | 85 | ## Citation 86 | 87 | If you find this work useful, please consider citing: 88 | 89 | ```bibtex 90 | @misc{ziai2024video, 91 | title={Video Annotator: A framework for efficiently building video classifiers using vision-language models and active learning}, 92 | author={Amir Ziai and Aneesh Vartakavi}, 93 | year={2024}, 94 | eprint={2402.06560}, 95 | archivePrefix={arXiv}, 96 | primaryClass={cs.CV} 97 | } 98 | ``` 99 | -------------------------------------------------------------------------------- /conda_env.yml: -------------------------------------------------------------------------------- 1 | # TODO 2 | name: videoannotator 3 | channels: 4 | - pytorch 5 | - defaults 6 | - conda-forge 7 | - intel 8 | dependencies: 9 | - python==3.7.4 10 | - pip==20.2.4 11 | - scikit-learn==0.24.2 12 | - pandas==1.3.4 13 | - numpy==1.19.5 14 | - pytorch==1.10.0 15 | - toolz==0.10.0 16 | - pip: 17 | - xgboost==1.2.1 18 | - pytorch-metric-learning==0.9.99 19 | - scipy==1.5.4 20 | - mmh3 21 | - h5py 22 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "id": "f9ecf656-36e9-428f-b8a4-020b90f2486e", 7 | "metadata": {}, 8 | "source": [ 9 | "# Video Annotator\n", 10 | "## Video understanding demo\n", 11 | "\n", 12 | "**Note**:\n", 13 | "- read the `data` section of the README first.\n", 14 | "- this notebook assumes that you have already\n", 15 | " - downloaded data\n", 16 | " - placed it in the directory expected by `videoannotator/config.py`" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 18, 22 | "id": "c2f837a9-84d9-4f93-abda-670e59b2a89a", 23 | "metadata": { 24 | "execution": { 25 | "iopub.execute_input": "2024-04-24T16:17:53.477481Z", 26 | "iopub.status.busy": "2024-04-24T16:17:53.476871Z", 27 | "iopub.status.idle": "2024-04-24T16:17:53.480312Z", 28 | "shell.execute_reply": "2024-04-24T16:17:53.479816Z", 29 | "shell.execute_reply.started": "2024-04-24T16:17:53.477452Z" 30 | } 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "import json\n", 35 | "import numpy as np\n", 36 | "import pandas as pd\n", 37 | "from IPython.display import HTML\n", 38 | "\n", 39 | "from videoannotator import data, models, io, config as cfg" 40 | ] 41 | }, 42 | { 43 | "attachments": {}, 44 | "cell_type": "markdown", 45 | "id": "22adfc3e-63ad-41b4-a50f-44a9c58a0158", 46 | "metadata": {}, 47 | "source": [ 48 | "## helpers" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 19, 54 | "id": "4c9e40ed-3a7b-41f2-a8d7-c7f365978031", 55 | "metadata": { 56 | "execution": { 57 | "iopub.execute_input": "2024-04-24T16:17:54.134269Z", 58 | "iopub.status.busy": "2024-04-24T16:17:54.133652Z", 59 | "iopub.status.idle": "2024-04-24T16:17:54.139614Z", 60 | 
"shell.execute_reply": "2024-04-24T16:17:54.139129Z", 61 | "shell.execute_reply.started": "2024-04-24T16:17:54.134241Z" 62 | } 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "def get_model(label: str) -> models.LogisticRegression:\n", 67 | " \"\"\"\n", 68 | " Trains a model given all the aggregated annotations for `label`.\n", 69 | " \"\"\"\n", 70 | " ds = data.get_aggregate_labeled_dataset(label=label)\n", 71 | " clf = models.LogisticRegression(scoring='average_precision')\n", 72 | " clf.fit(ds.x, ds.y)\n", 73 | " return clf\n", 74 | "\n", 75 | "def get_score(label: str, key: str) -> float:\n", 76 | " \"\"\"\n", 77 | " Get classification score for `label` against shot `key`.\n", 78 | " \"\"\"\n", 79 | " if label not in clfs:\n", 80 | " clfs[label] = get_model(label=label)\n", 81 | " emb = data.get_embeddings_dict()[key]\n", 82 | " return clfs[label].predict_proba(emb[None, :]).item()\n", 83 | "\n", 84 | "def get_score_all_labels(key: str) -> dict:\n", 85 | " \"\"\"\n", 86 | " Scores all labels against the shot represented by `key`.\n", 87 | " \"\"\"\n", 88 | " return {label: get_score(label=label, key=key) for label in cfg.LABELS}\n", 89 | "\n", 90 | "def get_color(val: float) -> str:\n", 91 | " \"\"\"Used for table cell formatting.\"\"\"\n", 92 | " if val > 0.5:\n", 93 | " r, g = 0, 128\n", 94 | " a = val\n", 95 | " else:\n", 96 | " r, g = 255, 0\n", 97 | " a = 1 - val\n", 98 | " return f'color: rgba({r}, {g}, 0, {a})'" 99 | ] 100 | }, 101 | { 102 | "attachments": {}, 103 | "cell_type": "markdown", 104 | "id": "b820a5d4-9938-4138-8bea-5c08efa96d2b", 105 | "metadata": {}, 106 | "source": [ 107 | "## load a few models" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 40, 113 | "id": "1e4e3a8a-0e02-49f5-9926-649b51cb0c55", 114 | "metadata": { 115 | "execution": { 116 | "iopub.execute_input": "2024-04-24T16:32:42.113540Z", 117 | "iopub.status.busy": "2024-04-24T16:32:42.112924Z", 118 | "iopub.status.idle": "2024-04-24T16:32:47.059614Z", 119 | "shell.execute_reply": "2024-04-24T16:32:47.058607Z", 120 | "shell.execute_reply.started": "2024-04-24T16:32:42.113512Z" 121 | } 122 | }, 123 | "outputs": [], 124 | "source": [ 125 | "labels = ('action', 'establishing-shots', 'day', 'car-chase')\n", 126 | "clfs = {label: get_model(label=label) for label in labels}" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 21, 132 | "id": "27467ee9-092f-4276-a0fc-184ed09908ab", 133 | "metadata": { 134 | "execution": { 135 | "iopub.execute_input": "2024-04-24T16:17:59.821195Z", 136 | "iopub.status.busy": "2024-04-24T16:17:59.820774Z", 137 | "iopub.status.idle": "2024-04-24T16:17:59.828819Z", 138 | "shell.execute_reply": "2024-04-24T16:17:59.828144Z", 139 | "shell.execute_reply.started": "2024-04-24T16:17:59.821153Z" 140 | } 141 | }, 142 | "outputs": [ 143 | { 144 | "data": { 145 | "text/plain": [ 146 | "'action, aerial, alcohol, anger, animal, car-chase, character-focus, closeup, cowboy-shot, day, drama, drugs, dutch-angle, establishing-shots, extreme-close-up, extreme-wide-shot, eye-level, fantasy, fight, golden-hour, gore, group-shot, handheld, happy, high-angle, horror, insert-shot, interior, interview, intimacy, jump-scare, laughter, low-angle, medium, nudity, object, over-the-shoulder-shot, overhead-shot, pan, point-of-view-shot, romance, run, sad, scared, sci-fi, shutter-shot, single-shot, slow-motion, smoking, static-shot, tilt-shot, timelapse, two-shot, violence, wide, zoom'" 147 | ] 148 | }, 149 | "execution_count": 21, 150 | "metadata": {}, 151 | 
"output_type": "execute_result" 152 | } 153 | ], 154 | "source": [ 155 | "# the full set of available labels is at cfg.LABELS\n", 156 | "', '.join(sorted(cfg.LABELS))" 157 | ] 158 | }, 159 | { 160 | "attachments": {}, 161 | "cell_type": "markdown", 162 | "id": "16126115-6539-4825-b758-c072ad2c0a60", 163 | "metadata": {}, 164 | "source": [ 165 | "## load a few sample embeddings" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 22, 171 | "id": "fa0bdfeb-3564-48a1-bca3-bd23f2eb3831", 172 | "metadata": { 173 | "execution": { 174 | "iopub.execute_input": "2024-04-24T16:17:59.830507Z", 175 | "iopub.status.busy": "2024-04-24T16:17:59.829997Z", 176 | "iopub.status.idle": "2024-04-24T16:18:00.025539Z", 177 | "shell.execute_reply": "2024-04-24T16:18:00.025006Z", 178 | "shell.execute_reply.started": "2024-04-24T16:17:59.830469Z" 179 | } 180 | }, 181 | "outputs": [], 182 | "source": [ 183 | "samples = json.load(open('sample-embeddings.json'))\n", 184 | "samples = {k: np.array(v) for k, v in samples.items()}" 185 | ] 186 | }, 187 | { 188 | "attachments": {}, 189 | "cell_type": "markdown", 190 | "id": "1644163e-f1e7-4771-9025-2bff678820c8", 191 | "metadata": {}, 192 | "source": [ 193 | "## score the sample embeddings" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 41, 199 | "id": "0b0c6702-edca-47e4-93b6-1f5044556f86", 200 | "metadata": { 201 | "execution": { 202 | "iopub.execute_input": "2024-04-24T16:32:47.066975Z", 203 | "iopub.status.busy": "2024-04-24T16:32:47.064680Z", 204 | "iopub.status.idle": "2024-04-24T16:32:47.073915Z", 205 | "shell.execute_reply": "2024-04-24T16:32:47.073259Z", 206 | "shell.execute_reply.started": "2024-04-24T16:32:47.066931Z" 207 | } 208 | }, 209 | "outputs": [], 210 | "source": [ 211 | "def _get_record(title, emb):\n", 212 | " d = dict(\n", 213 | " title=title,\n", 214 | " video_thumbnail=f'',\n", 215 | " )\n", 216 | " scores = {\n", 217 | " f'{label}_score': clf.predict_proba(emb[None, :]).item()\n", 218 | " for label, clf in clfs.items()\n", 219 | " }\n", 220 | " return {**d, **scores}" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 42, 226 | "id": "a7c2a338-9d7a-49b8-8f7d-1ab9dea66930", 227 | "metadata": { 228 | "execution": { 229 | "iopub.execute_input": "2024-04-24T16:32:47.076581Z", 230 | "iopub.status.busy": "2024-04-24T16:32:47.076044Z", 231 | "iopub.status.idle": "2024-04-24T16:32:47.296071Z", 232 | "shell.execute_reply": "2024-04-24T16:32:47.295541Z", 233 | "shell.execute_reply.started": "2024-04-24T16:32:47.076544Z" 234 | } 235 | }, 236 | "outputs": [ 237 | { 238 | "data": { 239 | "text/html": [ 240 | "\n", 278 | "\n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | "
title                    action_score  establishing-shots_score  day_score  car-chase_score
Operation Varsity Blues  0.18          0.99                      0.84       0.01
6 Underground            0.73          0.30                      0.69       0.78
Leave the world behind   0.36          0.12                      0.11       0.04
\n" 324 | ], 325 | "text/plain": [ 326 | "" 327 | ] 328 | }, 329 | "execution_count": 42, 330 | "metadata": {}, 331 | "output_type": "execute_result" 332 | } 333 | ], 334 | "source": [ 335 | "pngs = {'Operation Varsity Blues': 'varsity.png', '6 Underground': '6.png', 'Leave the world behind': 'leave.png'}\n", 336 | "df = pd.DataFrame(\n", 337 | " _get_record(title, emb)\n", 338 | " for title, emb in samples.items()\n", 339 | ").set_index(['title', 'video_thumbnail']) \n", 340 | "df = df.style.applymap(get_color).format(lambda x: f'{x:.2f}')\n", 341 | "HTML(df.to_html(escape=False))" 342 | ] 343 | }, 344 | { 345 | "attachments": {}, 346 | "cell_type": "markdown", 347 | "id": "6a3cc399-f546-4197-8cbf-071ce7ed6d3b", 348 | "metadata": {}, 349 | "source": [ 350 | "## bring your own C4C embeddings\n", 351 | "\n", 352 | "Use [Clip4CLIP](https://github.com/ArrowLuo/CLIP4Clip) to extract video embeddings from your own videos and then apply." 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": 25, 358 | "id": "3d1b3dc0-6a1e-4a38-9d64-1d8467bd499d", 359 | "metadata": { 360 | "execution": { 361 | "iopub.execute_input": "2024-04-24T16:18:01.144671Z", 362 | "iopub.status.busy": "2024-04-24T16:18:01.144166Z", 363 | "iopub.status.idle": "2024-04-24T16:18:01.147823Z", 364 | "shell.execute_reply": "2024-04-24T16:18:01.147172Z", 365 | "shell.execute_reply.started": "2024-04-24T16:18:01.144641Z" 366 | } 367 | }, 368 | "outputs": [], 369 | "source": [ 370 | "np.random.seed(0)\n", 371 | "your_emb = np.random.rand(1, 512) # replace with your own" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": 26, 377 | "id": "f41d1b6d-2df3-4df3-a807-4cc3af2c0f4f", 378 | "metadata": { 379 | "execution": { 380 | "iopub.execute_input": "2024-04-24T16:18:01.728981Z", 381 | "iopub.status.busy": "2024-04-24T16:18:01.728377Z", 382 | "iopub.status.idle": "2024-04-24T16:18:02.332514Z", 383 | "shell.execute_reply": "2024-04-24T16:18:02.331604Z", 384 | "shell.execute_reply.started": "2024-04-24T16:18:01.728952Z" 385 | } 386 | }, 387 | "outputs": [ 388 | { 389 | "data": { 390 | "text/plain": [ 391 | "0.10548243605868816" 392 | ] 393 | }, 394 | "execution_count": 26, 395 | "metadata": {}, 396 | "output_type": "execute_result" 397 | } 398 | ], 399 | "source": [ 400 | "label = 'car-chase' # pick from any model in `cfg.LABELS`\n", 401 | "get_model(label=label).predict_proba(your_emb).item()" 402 | ] 403 | }, 404 | { 405 | "attachments": {}, 406 | "cell_type": "markdown", 407 | "id": "e51aec2e-6a65-4ed8-a46a-5f9f6748e6fd", 408 | "metadata": {}, 409 | "source": [ 410 | "## apply any model to any shot in the corpus from the Condensed Movie Dataset\n", 411 | "More details at [this link](https://www.robots.ox.ac.uk/~vgg/data/condensed-movies/)" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 27, 417 | "id": "81762de3-c875-4f01-892b-da9c335ef4fb", 418 | "metadata": { 419 | "execution": { 420 | "iopub.execute_input": "2024-04-24T16:18:03.529245Z", 421 | "iopub.status.busy": "2024-04-24T16:18:03.528611Z", 422 | "iopub.status.idle": "2024-04-24T16:18:04.588503Z", 423 | "shell.execute_reply": "2024-04-24T16:18:04.587890Z", 424 | "shell.execute_reply.started": "2024-04-24T16:18:03.529216Z" 425 | } 426 | }, 427 | "outputs": [ 428 | { 429 | "data": { 430 | "text/html": [ 431 | "
\n", 432 | "\n", 445 | "\n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | "
yt_iddescgenreframe_inframe_outstartendfpscluster_index
key
HnnBvemHrWA|0HnnBvemHrWAFrankie and Johnny - He Just Asked Her OutComedy0820.0000003.42031123.9744278
gOJJm_cSRds|3987gOJJm_cSRdsThe Wild Bunch - Bank ShootoutAction39874007166.302203167.13642523.9744276
mDViU8OSRkA|2222mDViU8OSRkANon-Stop - How's Your Daughter? SceneNaN2222228492.67591795.26183323.9760243
egB-SG97EcI|805egB-SG97EcIKalifornia - I Think I Gotta Kill You SceneCrime80582533.57519034.40935723.9760370
uGsWYV2bWAc|1013uGsWYV2bWAcKill Bill: Vol. 1 - The Bride vs. GogoAction1013104142.25054243.41837523.9760247
\n", 535 | "
" 536 | ], 537 | "text/plain": [ 538 | " yt_id desc \\\n", 539 | "key \n", 540 | "HnnBvemHrWA|0 HnnBvemHrWA Frankie and Johnny - He Just Asked Her Out \n", 541 | "gOJJm_cSRds|3987 gOJJm_cSRds The Wild Bunch - Bank Shootout \n", 542 | "mDViU8OSRkA|2222 mDViU8OSRkA Non-Stop - How's Your Daughter? Scene \n", 543 | "egB-SG97EcI|805 egB-SG97EcI Kalifornia - I Think I Gotta Kill You Scene \n", 544 | "uGsWYV2bWAc|1013 uGsWYV2bWAc Kill Bill: Vol. 1 - The Bride vs. Gogo \n", 545 | "\n", 546 | " genre frame_in frame_out start end \\\n", 547 | "key \n", 548 | "HnnBvemHrWA|0 Comedy 0 82 0.000000 3.420311 \n", 549 | "gOJJm_cSRds|3987 Action 3987 4007 166.302203 167.136425 \n", 550 | "mDViU8OSRkA|2222 NaN 2222 2284 92.675917 95.261833 \n", 551 | "egB-SG97EcI|805 Crime 805 825 33.575190 34.409357 \n", 552 | "uGsWYV2bWAc|1013 Action 1013 1041 42.250542 43.418375 \n", 553 | "\n", 554 | " fps cluster_index \n", 555 | "key \n", 556 | "HnnBvemHrWA|0 23.974427 8 \n", 557 | "gOJJm_cSRds|3987 23.974427 6 \n", 558 | "mDViU8OSRkA|2222 23.976024 3 \n", 559 | "egB-SG97EcI|805 23.976037 0 \n", 560 | "uGsWYV2bWAc|1013 23.976024 7 " 561 | ] 562 | }, 563 | "execution_count": 27, 564 | "metadata": {}, 565 | "output_type": "execute_result" 566 | } 567 | ], 568 | "source": [ 569 | "df_shots = pd.read_csv(io.PATHS_STATIC.shot_data).set_index('key')\n", 570 | "df_shots.head()" 571 | ] 572 | }, 573 | { 574 | "attachments": {}, 575 | "cell_type": "markdown", 576 | "id": "bb5ae078-757a-4c1d-ab21-ddbaab2c03f5", 577 | "metadata": {}, 578 | "source": [ 579 | "### pick a key from the dataframe above" 580 | ] 581 | }, 582 | { 583 | "cell_type": "code", 584 | "execution_count": 28, 585 | "id": "b30b22db-aa12-420c-9e2d-e64c300e5aff", 586 | "metadata": { 587 | "execution": { 588 | "iopub.execute_input": "2024-04-24T16:18:05.572441Z", 589 | "iopub.status.busy": "2024-04-24T16:18:05.571805Z", 590 | "iopub.status.idle": "2024-04-24T16:18:05.575349Z", 591 | "shell.execute_reply": "2024-04-24T16:18:05.574731Z", 592 | "shell.execute_reply.started": "2024-04-24T16:18:05.572412Z" 593 | } 594 | }, 595 | "outputs": [], 596 | "source": [ 597 | "key = 'HnnBvemHrWA|0'" 598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "execution_count": 29, 603 | "id": "2c547f22-61fa-49a9-ace4-0e73d215f618", 604 | "metadata": { 605 | "execution": { 606 | "iopub.execute_input": "2024-04-24T16:18:06.156529Z", 607 | "iopub.status.busy": "2024-04-24T16:18:06.155931Z", 608 | "iopub.status.idle": "2024-04-24T16:18:06.160842Z", 609 | "shell.execute_reply": "2024-04-24T16:18:06.160237Z", 610 | "shell.execute_reply.started": "2024-04-24T16:18:06.156499Z" 611 | } 612 | }, 613 | "outputs": [ 614 | { 615 | "data": { 616 | "text/plain": [ 617 | "0.15079263042877716" 618 | ] 619 | }, 620 | "execution_count": 29, 621 | "metadata": {}, 622 | "output_type": "execute_result" 623 | } 624 | ], 625 | "source": [ 626 | "get_score(label='action', key=key)" 627 | ] 628 | }, 629 | { 630 | "cell_type": "code", 631 | "execution_count": 30, 632 | "id": "c71691d3-3421-4b74-ab1e-e37da32d2528", 633 | "metadata": { 634 | "execution": { 635 | "iopub.execute_input": "2024-04-24T16:18:06.725140Z", 636 | "iopub.status.busy": "2024-04-24T16:18:06.724526Z", 637 | "iopub.status.idle": "2024-04-24T16:18:07.366689Z", 638 | "shell.execute_reply": "2024-04-24T16:18:07.365931Z", 639 | "shell.execute_reply.started": "2024-04-24T16:18:06.725110Z" 640 | } 641 | }, 642 | "outputs": [ 643 | { 644 | "data": { 645 | "text/plain": [ 646 | "0.8293545496789174" 647 | ] 648 | }, 649 | "execution_count": 30, 650 
| "metadata": {}, 651 | "output_type": "execute_result" 652 | } 653 | ], 654 | "source": [ 655 | "get_score(label='closeup', key=key)" 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": 31, 661 | "id": "2ddb686b-4cd7-44cd-83c4-4a87d9f279b4", 662 | "metadata": { 663 | "execution": { 664 | "iopub.execute_input": "2024-04-24T16:18:07.368947Z", 665 | "iopub.status.busy": "2024-04-24T16:18:07.368171Z", 666 | "iopub.status.idle": "2024-04-24T16:18:07.375109Z", 667 | "shell.execute_reply": "2024-04-24T16:18:07.374429Z", 668 | "shell.execute_reply.started": "2024-04-24T16:18:07.368902Z" 669 | } 670 | }, 671 | "outputs": [ 672 | { 673 | "data": { 674 | "text/plain": [ 675 | "0.02759443746631655" 676 | ] 677 | }, 678 | "execution_count": 31, 679 | "metadata": {}, 680 | "output_type": "execute_result" 681 | } 682 | ], 683 | "source": [ 684 | "get_score(label='establishing-shots', key=key)" 685 | ] 686 | }, 687 | { 688 | "attachments": {}, 689 | "cell_type": "markdown", 690 | "id": "30b13fbd-9f36-4d4c-b5ff-d91eec5f53c3", 691 | "metadata": {}, 692 | "source": [ 693 | "## apply all models to a shot in the corpus" 694 | ] 695 | }, 696 | { 697 | "cell_type": "code", 698 | "execution_count": 32, 699 | "id": "3d9996b5-c655-4fd0-8454-d69489518fbe", 700 | "metadata": { 701 | "execution": { 702 | "iopub.execute_input": "2024-04-24T16:18:08.473189Z", 703 | "iopub.status.busy": "2024-04-24T16:18:08.472662Z", 704 | "iopub.status.idle": "2024-04-24T16:18:08.476199Z", 705 | "shell.execute_reply": "2024-04-24T16:18:08.475710Z", 706 | "shell.execute_reply.started": "2024-04-24T16:18:08.473160Z" 707 | } 708 | }, 709 | "outputs": [], 710 | "source": [ 711 | "def get_link(key: str) -> HTML:\n", 712 | " rec = df_shots.loc[key]\n", 713 | " link = f'https://youtu.be/{rec.yt_id}?t={int(rec.start)}'\n", 714 | " return HTML(f'click here to watch')" 715 | ] 716 | }, 717 | { 718 | "cell_type": "code", 719 | "execution_count": 33, 720 | "id": "1593eec5-fbd5-4eaa-817e-f29b109df8aa", 721 | "metadata": { 722 | "execution": { 723 | "iopub.execute_input": "2024-04-24T16:18:09.089381Z", 724 | "iopub.status.busy": "2024-04-24T16:18:09.088766Z", 725 | "iopub.status.idle": "2024-04-24T16:18:25.917844Z", 726 | "shell.execute_reply": "2024-04-24T16:18:25.917005Z", 727 | "shell.execute_reply.started": "2024-04-24T16:18:09.089350Z" 728 | } 729 | }, 730 | "outputs": [ 731 | { 732 | "data": { 733 | "text/plain": [ 734 | "eye-level 0.89\n", 735 | "interior 0.83\n", 736 | "closeup 0.83\n", 737 | "character-focus 0.75\n", 738 | "single-shot 0.65\n", 739 | "static-shot 0.55\n", 740 | "medium 0.44\n", 741 | "fantasy 0.39\n", 742 | "object 0.39\n", 743 | "low-angle 0.37\n", 744 | "drama 0.34\n", 745 | "violence 0.33\n", 746 | "horror 0.31\n", 747 | "over-the-shoulder-shot 0.31\n", 748 | "day 0.27\n", 749 | "scared 0.22\n", 750 | "smoking 0.22\n", 751 | "anger 0.20\n", 752 | "happy 0.17\n", 753 | "handheld 0.16\n", 754 | "animal 0.15\n", 755 | "action 0.15\n", 756 | "sad 0.12\n", 757 | "tilt-shot 0.11\n", 758 | "pan 0.10\n", 759 | "insert-shot 0.09\n", 760 | "cowboy-shot 0.08\n", 761 | "alcohol 0.07\n", 762 | "zoom 0.07\n", 763 | "intimacy 0.06\n", 764 | "dutch-angle 0.06\n", 765 | "slow-motion 0.05\n", 766 | "run 0.05\n", 767 | "wide 0.05\n", 768 | "fight 0.05\n", 769 | "high-angle 0.05\n", 770 | "two-shot 0.04\n", 771 | "group-shot 0.04\n", 772 | "aerial 0.03\n", 773 | "establishing-shots 0.03\n", 774 | "point-of-view-shot 0.03\n", 775 | "extreme-close-up 0.02\n", 776 | "drugs 0.02\n", 777 | "golden-hour 0.02\n", 778 | 
"sci-fi 0.02\n", 779 | "romance 0.02\n", 780 | "gore 0.01\n", 781 | "laughter 0.01\n", 782 | "nudity 0.01\n", 783 | "shutter-shot 0.01\n", 784 | "extreme-wide-shot 0.01\n", 785 | "car-chase 0.01\n", 786 | "jump-scare 0.01\n", 787 | "overhead-shot 0.01\n", 788 | "interview 0.01\n", 789 | "timelapse 0.00\n", 790 | "dtype: float64" 791 | ] 792 | }, 793 | "execution_count": 33, 794 | "metadata": {}, 795 | "output_type": "execute_result" 796 | } 797 | ], 798 | "source": [ 799 | "# sort by score descending\n", 800 | "pd.Series(get_score_all_labels(key=key)).sort_values(ascending=False).round(2)" 801 | ] 802 | }, 803 | { 804 | "cell_type": "code", 805 | "execution_count": 34, 806 | "id": "8796347f-834c-4062-95e4-a3f556ac9fd1", 807 | "metadata": { 808 | "execution": { 809 | "iopub.execute_input": "2024-04-24T16:18:25.923165Z", 810 | "iopub.status.busy": "2024-04-24T16:18:25.922349Z", 811 | "iopub.status.idle": "2024-04-24T16:18:26.024421Z", 812 | "shell.execute_reply": "2024-04-24T16:18:26.023892Z", 813 | "shell.execute_reply.started": "2024-04-24T16:18:25.923124Z" 814 | } 815 | }, 816 | "outputs": [ 817 | { 818 | "data": { 819 | "text/html": [ 820 | "click here to watch" 821 | ], 822 | "text/plain": [ 823 | "" 824 | ] 825 | }, 826 | "execution_count": 34, 827 | "metadata": {}, 828 | "output_type": "execute_result" 829 | } 830 | ], 831 | "source": [ 832 | "# get the link to watch\n", 833 | "get_link(key=key)" 834 | ] 835 | } 836 | ], 837 | "metadata": { 838 | "kernelspec": { 839 | "display_name": "Python 3.10 (recommended)", 840 | "language": "python", 841 | "name": "python310" 842 | }, 843 | "language_info": { 844 | "codemirror_mode": { 845 | "name": "ipython", 846 | "version": 3 847 | }, 848 | "file_extension": ".py", 849 | "mimetype": "text/x-python", 850 | "name": "python", 851 | "nbconvert_exporter": "python", 852 | "pygments_lexer": "ipython3", 853 | "version": "3.10.14" 854 | } 855 | }, 856 | "nbformat": 4, 857 | "nbformat_minor": 5 858 | } 859 | -------------------------------------------------------------------------------- /exp2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "id": "33a23872-6b85-40c8-baa8-f2a42702b630", 7 | "metadata": {}, 8 | "source": [ 9 | "# Experiment 2\n", 10 | "Need to run experiment 1 first" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "dbe1fb03", 17 | "metadata": { 18 | "execution": { 19 | "iopub.execute_input": "2024-02-09T06:58:20.320622Z", 20 | "iopub.status.busy": "2024-02-09T06:58:20.320375Z", 21 | "iopub.status.idle": "2024-02-09T06:58:20.721700Z", 22 | "shell.execute_reply": "2024-02-09T06:58:20.720953Z", 23 | "shell.execute_reply.started": "2024-02-09T06:58:20.320598Z" 24 | }, 25 | "tags": [] 26 | }, 27 | "outputs": [], 28 | "source": [ 29 | "import pathlib\n", 30 | "import functools\n", 31 | "import pandas as pd\n", 32 | "import numpy as np\n", 33 | "import math\n", 34 | "import pickle\n", 35 | "\n", 36 | "from tqdm.auto import tqdm\n", 37 | "\n", 38 | "from videoannotator import data, dp, experiments" 39 | ] 40 | }, 41 | { 42 | "attachments": {}, 43 | "cell_type": "markdown", 44 | "id": "243bc89c-e52c-4c40-a086-9fa790de4f39", 45 | "metadata": {}, 46 | "source": [ 47 | "## helpers" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 2, 53 | "id": "fb6ab183-2d5e-4b71-9475-ceb7bbf03cbb", 54 | "metadata": { 55 | "execution": { 56 | "iopub.execute_input": 
"2024-02-09T06:58:20.724968Z", 57 | "iopub.status.busy": "2024-02-09T06:58:20.724526Z", 58 | "iopub.status.idle": "2024-02-09T06:58:20.734467Z", 59 | "shell.execute_reply": "2024-02-09T06:58:20.733913Z", 60 | "shell.execute_reply.started": "2024-02-09T06:58:20.724943Z" 61 | }, 62 | "tags": [] 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "dps = {\n", 67 | " 'eps0.1': functools.partial(dp.EpsilonGreedyMean, epsilon=0.1),\n", 68 | " 'eps0.25': functools.partial(dp.EpsilonGreedyMean, epsilon=0.25),\n", 69 | " 'eps0.5': functools.partial(dp.EpsilonGreedyMean, epsilon=0.5),\n", 70 | " 'rr': dp.RoundRobin,\n", 71 | " 'greedy': dp.GreedyOracle,\n", 72 | " 'ucb': dp.UCBMean,\n", 73 | " 'ucb.1': functools.partial(dp.UCBMean, c=1e-1),\n", 74 | " # 'ecbwema': dp.UCBEWMA,\n", 75 | " # 'eps0.25ewma': functools.partial(dp.EpsilonGreedyEWMA, epsilon=0.25),\n", 76 | "}\n", 77 | "\n", 78 | "@functools.lru_cache(10_000)\n", 79 | "def run(dp_name, d, label):\n", 80 | " dp = dps[dp_name]\n", 81 | " try:\n", 82 | " return dict(\n", 83 | " dp=dp,\n", 84 | " d=d,\n", 85 | " label=label,\n", 86 | " res=dp(label=label, d=d, _seed_base=1).run(n=1_000 // d),\n", 87 | " )\n", 88 | " except Exception as e:\n", 89 | " raise ValueError(f'Error in dp={dp}, d={d}, label={label}: {e}')\n", 90 | " \n", 91 | "def get_best(r):\n", 92 | " res = r['res']\n", 93 | " return max(\n", 94 | " (\n", 95 | " (res.average_precision.mean, res.average_precision)\n", 96 | " for _, res in res.res_hist\n", 97 | " ),\n", 98 | " key=lambda x: x[0],\n", 99 | " )[1]\n", 100 | "\n", 101 | "def get_algo_detes(r):\n", 102 | " dp = r['dp']\n", 103 | " if dp.__class__ == functools.partial:\n", 104 | " return dict(\n", 105 | " algo=dp.func.__name__,\n", 106 | " kws=dp.keywords,\n", 107 | " args=dp.args,\n", 108 | " )\n", 109 | " else:\n", 110 | " return dict(algo=dp.__name__)\n", 111 | " \n", 112 | "def eps(q):\n", 113 | " if pd.notnull(q) and 'epsilon' in q:\n", 114 | " return q.get('epsilon') == .25\n", 115 | " \n", 116 | " return True" 117 | ] 118 | }, 119 | { 120 | "attachments": {}, 121 | "cell_type": "markdown", 122 | "id": "35983077-22a6-4045-87f4-1ed162ddb21a", 123 | "metadata": {}, 124 | "source": [ 125 | "## run experiment 2" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 3, 131 | "id": "a8e9e3e5-e6be-44a2-9144-6abb528b3128", 132 | "metadata": { 133 | "execution": { 134 | "iopub.execute_input": "2024-02-09T06:58:20.737344Z", 135 | "iopub.status.busy": "2024-02-09T06:58:20.736926Z", 136 | "iopub.status.idle": "2024-02-09T06:58:25.203956Z", 137 | "shell.execute_reply": "2024-02-09T06:58:25.203228Z", 138 | "shell.execute_reply.started": "2024-02-09T06:58:20.737321Z" 139 | }, 140 | "tags": [] 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "# this takes ~20 mins on a machine with 64 CPUs and 64GB RAM\n", 145 | "path_results_pkl = 'exp2_results.pkl'\n", 146 | "if not pathlib.Path(path_results_pkl).exists():\n", 147 | " # takes a few hours\n", 148 | " items = [\n", 149 | " (dp_name, d, label)\n", 150 | " for i, dp_name in enumerate(dps)\n", 151 | " for d in (25,)\n", 152 | " for label in sorted(data.cfg.LABELS)\n", 153 | " ]\n", 154 | " res = [run(*q) for q in tqdm(items)]\n", 155 | " pickle.dump(res, open(path_results_pkl, 'wb'))\n", 156 | "else:\n", 157 | " res = pickle.load(open(path_results_pkl, 'rb'))" 158 | ] 159 | }, 160 | { 161 | "attachments": {}, 162 | "cell_type": "markdown", 163 | "id": "783eb716-826e-449b-9549-ccbf637b2cc7", 164 | "metadata": {}, 165 | "source": [ 166 | "## analyze" 167 | ] 168 
| }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 4, 172 | "id": "137666f0", 173 | "metadata": { 174 | "execution": { 175 | "iopub.execute_input": "2024-02-09T06:58:25.211433Z", 176 | "iopub.status.busy": "2024-02-09T06:58:25.211089Z", 177 | "iopub.status.idle": "2024-02-09T06:58:26.310674Z", 178 | "shell.execute_reply": "2024-02-09T06:58:26.309992Z", 179 | "shell.execute_reply.started": "2024-02-09T06:58:25.211408Z" 180 | }, 181 | "tags": [] 182 | }, 183 | "outputs": [], 184 | "source": [ 185 | "# this object is computed in experiment 1\n", 186 | "ap_va = pickle.load(open('ap_va_best.pkl', 'rb'))\n", 187 | "ap_va_best = {k: v['best'] for k, v in ap_va.items()}\n", 188 | "ap_va_last = {k: v['last'] for k, v in ap_va.items()}\n", 189 | "\n", 190 | "df_res = pd.DataFrame(\n", 191 | " dict(\n", 192 | " d=r['d'],\n", 193 | " label=r['label'],\n", 194 | " best_idx=r['res'].best_idx,\n", 195 | " best_metric=get_best(r),\n", 196 | " best_metric_mean=get_best(r).mean,\n", 197 | " last_metric=r['res'].last_metric,\n", 198 | " last_metric_mean=r['res'].last_metric.mean,\n", 199 | " res=r['res'],\n", 200 | " after_5=r['res'].res_hist[5][1].average_precision.mean,\n", 201 | " after_10=r['res'].res_hist[10][1].average_precision.mean,\n", 202 | " **get_algo_detes(r),\n", 203 | " )\n", 204 | " for r in res\n", 205 | ")\n", 206 | "df_res = df_res.assign(\n", 207 | " ap_va_best=df_res.label.map(ap_va_best),\n", 208 | " ap_va_last=df_res.label.map(ap_va_last),\n", 209 | ")\n", 210 | "d = df_res[df_res.algo == 'UCBMean'].groupby(['d', 'label']).apply(lambda x: pd.Series(dict(\n", 211 | " best_va=x.ap_va_best.max(),\n", 212 | " best_bandit=x.best_metric_mean.max(),\n", 213 | " last_va=x.ap_va_last.max(),\n", 214 | " last_bandit=x.last_metric_mean.max(),\n", 215 | ")))\n", 216 | "d = d.assign(\n", 217 | " lift_best=d.best_bandit - d.best_va,\n", 218 | " lift_max=d.last_bandit - d.last_va,\n", 219 | ")\n", 220 | "dfr = df_res[df_res.kws.apply(eps)]\n", 221 | "e = dfr[dfr.d == 25].groupby(['algo', 'label']).apply(lambda x: pd.Series(dict(\n", 222 | " best_va=x.ap_va_best.max(),\n", 223 | " best_bandit=x.best_metric_mean.max(),\n", 224 | ")))\n", 225 | "e = e.assign(lift_best=e.best_bandit - e.best_va,)\n", 226 | "algos = ['RoundRobin', 'GreedyOracle', 'EpsilonGreedyMean', 'UCBMean']\n", 227 | "cols = ['min', '25%', '50%', '75%', 'max']" 228 | ] 229 | }, 230 | { 231 | "attachments": {}, 232 | "cell_type": "markdown", 233 | "id": "854e5e4e-0fd0-40db-91e6-e6b766341f87", 234 | "metadata": {}, 235 | "source": [ 236 | "## table 3" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 5, 242 | "id": "fa91f320", 243 | "metadata": { 244 | "execution": { 245 | "iopub.execute_input": "2024-02-09T06:58:26.318997Z", 246 | "iopub.status.busy": "2024-02-09T06:58:26.318663Z", 247 | "iopub.status.idle": "2024-02-09T06:58:26.343227Z", 248 | "shell.execute_reply": "2024-02-09T06:58:26.342682Z", 249 | "shell.execute_reply.started": "2024-02-09T06:58:26.318973Z" 250 | }, 251 | "tags": [] 252 | }, 253 | "outputs": [ 254 | { 255 | "data": { 256 | "text/html": [ 257 | "
\n", 258 | "\n", 271 | "\n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | "
p10p25p50p75p90
algo
RoundRobin-5.9-0.32.57.415.1
GreedyOracle-5.40.93.99.616.2
EpsilonGreedyMean-6.5-0.62.97.616.7
UCBMean-7.5-0.23.49.215.9
\n", 325 | "
" 326 | ], 327 | "text/plain": [ 328 | " p10 p25 p50 p75 p90\n", 329 | "algo \n", 330 | "RoundRobin -5.9 -0.3 2.5 7.4 15.1\n", 331 | "GreedyOracle -5.4 0.9 3.9 9.6 16.2\n", 332 | "EpsilonGreedyMean -6.5 -0.6 2.9 7.6 16.7\n", 333 | "UCBMean -7.5 -0.2 3.4 9.2 15.9" 334 | ] 335 | }, 336 | "execution_count": 5, 337 | "metadata": {}, 338 | "output_type": "execute_result" 339 | } 340 | ], 341 | "source": [ 342 | "ee = e.reset_index().groupby('algo').lift_best.agg(\n", 343 | " p10=functools.partial(np.percentile, q=10),\n", 344 | " p25=functools.partial(np.percentile, q=25),\n", 345 | " p50=functools.partial(np.percentile, q=50),\n", 346 | " p75=functools.partial(np.percentile, q=75),\n", 347 | " p90=functools.partial(np.percentile, q=90)\n", 348 | ").loc[algos]\n", 349 | "(ee * 100).round(1)" 350 | ] 351 | }, 352 | { 353 | "attachments": {}, 354 | "cell_type": "markdown", 355 | "id": "c0560a99-4f29-4de5-8d42-3b87b632fbd2", 356 | "metadata": {}, 357 | "source": [ 358 | "## cumulative results\n", 359 | "Add up the gain from experiment 1 to this gain and compute the median across labels." 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 6, 365 | "id": "6eac9359-e97c-42c0-9176-e716f5c4246f", 366 | "metadata": { 367 | "execution": { 368 | "iopub.execute_input": "2024-02-09T06:58:26.344418Z", 369 | "iopub.status.busy": "2024-02-09T06:58:26.343981Z", 370 | "iopub.status.idle": "2024-02-09T06:58:26.549934Z", 371 | "shell.execute_reply": "2024-02-09T06:58:26.549376Z", 372 | "shell.execute_reply.started": "2024-02-09T06:58:26.344394Z" 373 | }, 374 | "tags": [] 375 | }, 376 | "outputs": [ 377 | { 378 | "data": { 379 | "text/plain": [ 380 | "8.3" 381 | ] 382 | }, 383 | "execution_count": 6, 384 | "metadata": {}, 385 | "output_type": "execute_result" 386 | } 387 | ], 388 | "source": [ 389 | "e1 = pickle.load(open('exp1_gain.pkl', 'rb'))\n", 390 | "f = e.reset_index()\n", 391 | "e2 = f[f.algo == 'UCBMean'].set_index('label').lift_best.to_dict()\n", 392 | "comb = pd.Series({\n", 393 | " k: e1[k] + e2[k]\n", 394 | " for k in e1\n", 395 | " if pd.notnull(e1[k]) and pd.notnull(e2[k])\n", 396 | "})\n", 397 | "(np.percentile(comb, q=50) * 100).round(1)" 398 | ] 399 | } 400 | ], 401 | "metadata": { 402 | "hide_input": false, 403 | "kernelspec": { 404 | "display_name": "Python 3 (ipykernel)", 405 | "language": "python", 406 | "name": "python3" 407 | }, 408 | "language_info": { 409 | "codemirror_mode": { 410 | "name": "ipython", 411 | "version": 3 412 | }, 413 | "file_extension": ".py", 414 | "mimetype": "text/x-python", 415 | "name": "python", 416 | "nbconvert_exporter": "python", 417 | "pygments_lexer": "ipython3", 418 | "version": "3.7.12" 419 | } 420 | }, 421 | "nbformat": 4, 422 | "nbformat_minor": 5 423 | } 424 | -------------------------------------------------------------------------------- /images/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Netflix/videoannotator/526d38dfa4ffcb953423d1f1f583d695904c90a2/images/6.png -------------------------------------------------------------------------------- /images/leave.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Netflix/videoannotator/526d38dfa4ffcb953423d1f1f583d695904c90a2/images/leave.png -------------------------------------------------------------------------------- /images/varsity.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Netflix/videoannotator/526d38dfa4ffcb953423d1f1f583d695904c90a2/images/varsity.png -------------------------------------------------------------------------------- /sys.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Netflix/videoannotator/526d38dfa4ffcb953423d1f1f583d695904c90a2/sys.png -------------------------------------------------------------------------------- /task_descriptions.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": { 3 | "description": "Climactic moments or exciting, exhilarating or fast-paced shots or sequences", 4 | "keywords": "Explosions, Car Chase, Fighting, Running, Shootout" 5 | }, 6 | "intimacy": { 7 | "description": "Two characters being physically close to each other and showing affection", 8 | "keywords": "Hugging, Kissing, Caressing, Cuddling" 9 | }, 10 | "sci-fi": { 11 | "description": "Imaginary concepts that blend science and fantasy", 12 | "keywords": "Space, Aliens, Monsters, Robots, Futuristic Technology, Lasers" 13 | }, 14 | "character-focus": { 15 | "description": "Characters are clearly in focus and properly displayed", 16 | "keywords": "Profile shots, Facial shots" 17 | }, 18 | "establishing-shots": { 19 | "description": "Extremely wide shot that sets the location or time period", 20 | "keywords": "Landscapes, Buildings" 21 | }, 22 | "nudity": { 23 | "description": "Exposed flesh or socially coveted body parts", 24 | "keywords": "genitals, breast, buttocks" 25 | }, 26 | "gore": { 27 | "description": "Extreme violence against another person", 28 | "keywords": "Blood Spewing, Decapitation, Impact Shots, Stabbing" 29 | }, 30 | "smoking": { 31 | "description": "Characters smoking tobacco or drugs", 32 | "keywords": "Cigarette, Cigars, Vapes, Drugs" 33 | }, 34 | "happy": { 35 | "description": "Characters expressing happiness", 36 | "keywords": "Smiling, Joyful, Ecstatic" 37 | }, 38 | "sad": { 39 | "description": "Characters expressing sadness", 40 | "keywords": "Crying, Tearing, Frowning, Looking Down, Despair" 41 | }, 42 | "horror": { 43 | "description": "Content aimed at eliciting fear, shock or disgust. This applies to story, characters, mood and atmosphere. Elements that suggest that the character is in danger or fearing for their life. Subgenres include psychological horror, supernatural horror, gore, slashers", 44 | "keywords": "Death, Gore, Blood, Violence, Devil, Demons, Paranormal, Monsters, Cults, Serial Killers" 45 | }, 46 | "alcohol": { 47 | "description": "Alcohol consumption is when a character drinks an alcoholic beverage that, depending on the quantity consumed, can cause intoxication or feelings of euphoria, inhibition of judgment, reduced anxiety, and/or sociability. The consumption of alcohol does not have to be shown for this tag to apply, but it should be clearly established that alcohol is the cause of actions/ behaviour depicted", 48 | "keywords": "Alcohol bottles, Beer, Shots, Characters Inebriated, Being Drunk" 49 | }, 50 | "anger": { 51 | "description": "Character expressing hostility, discontent, annoyance. 
Anger can be expressed in a range from low to high intensity: passive aggression, open aggression and assertive aggression.", 52 | "keywords": "Yelling, Shouting, Rage, Tense, Displeased, Enacting Violence" 53 | }, 54 | "aerial": { 55 | "description": "A shot captured from an elevated vantage point, taken above the subject.", 56 | "keywords": "Drone Shots, Helicopter shots, Bird’s Eye View, Overhead, Camera pointing down" 57 | }, 58 | "eye-level": { 59 | "description": "When the level of the camera is placed at the same height as the eyes of the characters in your frame. An eye level camera angle does not require the viewer to see the eyes of the actor, nor does the actor need to look directly into the camera for a shot to be considered eye level.", 60 | "keywords": "Character Facing, Head on shots, Looking Straight Ahead, Medium Close up" 61 | }, 62 | "medium": { 63 | "description": "Shots that depict the subject approximately from the waist up, emphasizing both the subject and their surroundings.", 64 | "keywords": "Interview, Table Talk, Conversation" 65 | }, 66 | "closeup": { 67 | "description": "Shots taken at a close range intended to show greater detail.", 68 | "keywords": "Face, Hands, Objects, Details" 69 | }, 70 | "wide": { 71 | "description": "Wide shots, also known as long shots, show the entire subject and their relation to the surrounding environment.", 72 | "keywords": "Establishing shots, City Scapes, Landscape, Field, Nature" 73 | }, 74 | "zoom": { 75 | "description": "A zoom shot is when the focal length of a camera lens is adjusted to give the illusion of moving closer or further away from the subject. Zoom shots are done with a zoom lens, which has a variable focal length. The most typical camera zoom types are the \"zoom in\" and \"zoom out.\" But this technique can also be combined with a dolly shot to create a dolly zoom.", 76 | "keywords": "Zoom In / out, Push In / Out, Dolly In / Out, Hitchcock Zoom" 77 | }, 78 | "pan": { 79 | "description": "A camera pan is a horizontal movement in which the camera swivels left or right around a central axis. Taken from the word \"panorama,\" the physical position of the camera itself does not move. Instead, panning refers to the camera swiveling left or right while mounted in a fixed location. This differs from dolly shots, in which the mounting system itself moves. Of course, a pan can also be combined with different camera movements such as dolly shots or tracking shots.", 80 | "keywords": "Reveal, Turn, Tracking" 81 | }, 82 | "drama": { 83 | "description": "Titles with a serious, sober gravity of overall tone or titles that attempt to present a story that could conceivably happen in real life. A moment of high emotion.", 84 | "keywords": "Emotional, Conflict, Serious, Life, Relationships, Reflection, Realization" 85 | }, 86 | "jump-scare": { 87 | "description": "A technique often used in horror films, intended to scare the audience by surprising them with an abrupt change in image or event, usually co-occurring with a loud, jarring sound.", 88 | "keywords": "Screech, Bang, Scream, Shock, Horror" 89 | }, 90 | "car-chase": { 91 | "description": "An exciting action sequence where characters are pursuing each other in moving vehicles.
These scenes usually involve reckless driving, daring stunts, and tension.", 92 | "keywords": "Car Chase, Speeding, Driving, Car Crash, Swerving" 93 | }, 94 | "violence": { 95 | "description": "Violence is any physical act initiated by a perpetrator to harm or damage oneself, another living creature, a personified object or character, or an inanimate object.", 96 | "keywords": "Hitting, Kicking, Throwing, Shooting, Stabbing, Slashing, Blood, Gore, Death, Harm" 97 | }, 98 | "drugs": { 99 | "description": "Illegal or legal substances that are used specifically with the intent of altering one’s state of consciousness or physical feeling, becoming intoxicated, or detaching from reality.", 100 | "keywords": "Drugs, Pills, High, Intoxicated, Snort, Swallow, Marijuana, Cocaine, Ecstasy, Needles" 101 | }, 102 | "laughter": { 103 | "description": "Characters laughing. Laughter can apply to a range of emotions: a joyous laugh, a nervous laugh, an ominous laugh.", 104 | "keywords": "Laughing, Chuckling, Giggling, Ha Ha" 105 | }, 106 | "animal": { 107 | "description": "A mammal, fish, bird, reptile or amphibian", 108 | "keywords": "Dog, Cat, Elephant, Pigeon, Iguana, Turtle, Bear" 109 | }, 110 | "object": { 111 | "description": "A material thing that can be seen or touched. Items that are being used by characters or are the focus of the shot", 112 | "keywords": "Briefcase, Lamp, Phone, Money, Cups" 113 | }, 114 | "interior": { 115 | "description": "Situated inside something enclosed", 116 | "keywords": "Room, Office, Bedroom, Bathroom, Livingroom, Car, Bus" 117 | }, 118 | "day": { 119 | "description": "Settings where there is daylight. Natural light is seen in the setting", 120 | "keywords": "Daylight, Day time, Sun, Afternoon" 121 | }, 122 | "overhead-shot": { 123 | "description": "When the camera is placed directly above the subject. It's somewhere around a 90-degree angle above the scene taking place. Overhead shots are also called a bird view, bird's eye view, or elevated shot. Its placement is somewhat near the action, differing from the aerial shot. An aerial shot would be closer to a satellite or airplane view, closer to a God's eye view, and an overhead shot is closer to a bird's eye view.", 124 | "keywords": "Overhead, Bird’s Eye, God’s View, Above, Looking down, Floor" 125 | }, 126 | "low-angle": { 127 | "description": "A shot from a camera angle positioned anywhere below the eye line, pointing upward.
It can even be below the character’s feet", 128 | "keywords": "Low, Looking Up, Sky, High Ground" 129 | }, 130 | "single-shot": { 131 | "description": "One subject or character in focus", 132 | "keywords": "Character, Profile, Hero" 133 | }, 134 | "group-shot": { 135 | "description": "Group of three or more characters within the focus of the frame", 136 | "keywords": "Group, Team, Friends, Gathering" 137 | }, 138 | "over-the-shoulder-shot": { 139 | "description": "Shows the main subject from behind the shoulder of another subject", 140 | "keywords": "Conversation, Approach, Shoulder" 141 | }, 142 | "point-of-view-shot": { 143 | "description": "Seeing from a character’s point of view", 144 | "keywords": "POV, Watching, Seeing, Looking" 145 | }, 146 | "handheld": { 147 | "description": "A shot taken with the camera supported only by the operator’s hands and shoulder.", 148 | "keywords": "Camera Footage, Found Footage, Shaky, Unstable Movement, Tense" 149 | }, 150 | "slow-motion": { 151 | "description": "The motion is slower than normal", 152 | "keywords": "Slow motion, Freeze" 153 | }, 154 | "timelapse": { 155 | "description": "A technique that shows a long period of time quickly, usually to capture a progression of action that normally moves slowly. Timelapse is a form of fast motion.", 156 | "keywords": "Fast Forward, Time Span, Fast Movement" 157 | }, 158 | "interview": { 159 | "description": "Subjects breaking the fourth wall and speaking directly to the camera, audience or interviewer.", 160 | "keywords": "Interview, Conversation, Documentary, Explanation" 161 | }, 162 | "romance": { 163 | "description": "Content focused on love, romantic relationships, dating or affairs of the heart that center on falling in love, passion, emotion and the affectionate involvement of the main characters.", 164 | "keywords": "Looks of Affection, Embrace, Kissing, Hugging, Holding Hands" 165 | }, 166 | "fantasy": { 167 | "description": "Titles that contain imaginary and often magical elements which are far removed from reality: wizards, witches, dragons, vampires, elves, trolls, fairies, etc. It may be set in a fictional past of swords, horses, knights and princesses or a grounded reality. Do not confuse this genre with Sci-Fi.", 168 | "keywords": "Magic, Creatures, Spectacle, Mystical" 169 | }, 170 | "fight": { 171 | "description": "Characters physically battling each other using weapons or their bodies", 172 | "keywords": "Combat, Martial Arts, Punch, Kick, Hit, Weapons" 173 | }, 174 | "run": { 175 | "description": "Characters running", 176 | "keywords": "Run, Sprint, Speed, Fast" 177 | }, 178 | "scared": { 179 | "description": "Characters in fear or frightened. This can be in varying degrees.", 180 | "keywords": "Fear, Frightened, Panicked, Screaming" 181 | }, 182 | "golden-hour": { 183 | "description": "The hour as the sun rises or as the sun sets", 184 | "keywords": "Sunrise, Sunset, Dawn" 185 | }, 186 | "extreme-close-up": { 187 | "description": "An extreme close-up shot frames a subject very closely, often so much so that the outer portions of the subject are cut off by the edges of the frame.", 188 | "keywords": "Eyes, Details, Zoom" 189 | }, 190 | "extreme-wide-shot": { 191 | "description": "A shot that covers a wider area than a wide shot. The shot frames the subject from a distance and focuses on its surroundings. Remember: In a wide shot, the character is usually still noticeable in the frame. But in an ELS, it’s not always necessary to see the character.
The EWS functions best as an establishing shot, as more of the location is in the frame.", 192 | "keywords": "Far away, Distant, Extreme Long Shot" 193 | }, 194 | "cowboy-shot": { 195 | "description": "A shot framed from the actor’s mid-waist to right above their head. Many times the camera is placed at hip level as well, to give a slightly low angle to the shot. The idea of a cowboy shot is to present the actor in a heroic, confident fashion while also allowing the viewer to see action taking place around the actor’s waistline, often to see guns or swords drawn.", 196 | "keywords": "Medium Full Shot, Waist, Hip" 197 | }, 198 | "high-angle": { 199 | "description": "A high angle shot is when the camera looks down on the character or subject from an elevated perspective. This is usually achieved by placing the camera higher than the subject and then angling it down on them. This can range from a shallow angle just above eye level all the way to directly above the subject.", 200 | "keywords": "Looking Down, Edge, Cliff, Hanging, Characters Looking Up" 201 | }, 202 | "dutch-angle": { 203 | "description": "A type of camera shot that has a noticeable tilt on the camera’s “x-axis.” Directors often use a Dutch angle to signal to the viewer that something is wrong, disorienting, or unsettling.", 204 | "keywords": "Tilt, Disoriented, Strange, Unsettling" 205 | }, 206 | "two-shot": { 207 | "description": "A two shot is a shot that frames a view of two subjects. The subjects do not have to be next to each other. There are many variations in which one subject is in the foreground and the other subject is in the background. The shot is used to show the emotional reactions between the subjects.", 208 | "keywords": "Two, Pair, Couple, Together" 209 | }, 210 | "insert-shot": { 211 | "description": "An insert shot is any shot whose sole purpose is to focus the viewer’s attention on a specific detail within a scene. These shots are often placed from the point of view of a character, suggesting that we see what the character sees during a moment of concentration.", 212 | "keywords": "Objects, Close-Up, B-Roll" 213 | }, 214 | "tilt-shot": { 215 | "description": "A tilt shot is a shot that uses a cinematographic technique called tilting, in which the camera stays fixed but rotates up and down on a vertical plane. Tilting is similar to the motion of an individual raising or lowering their head. This is different from panning, in which the camera also sits in a fixed position but moves left to right.", 216 | "keywords": "Looking Up, Looking Down" 217 | }, 218 | "static-shot": { 219 | "description": "A static shot in film is a shot that is devoid of camera movement. Also known as a locked-off shot, or an immobile shot.
The frame can be filled with the movement of vehicles, characters, props, weather, etc, but the frame itself does not move in a static shot.", 220 | "keywords": "Frame, Still" 221 | }, 222 | "shutter-shot": { 223 | "description": "The asset contains duplicate frames, stretching the running time of the sequence and creating a sense of slow motion for the audience.", 224 | "keywords": "slow motion" 225 | } 226 | } 227 | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from videoannotator import data 4 | 5 | 6 | class TestLabeledDataset(unittest.TestCase): 7 | def test_valid(self): 8 | lds = data.LabeledDataset( 9 | label="test", pos=frozenset({"x", "y"}), neg=frozenset({"z"}) 10 | ) 11 | assert len(lds) == 3 12 | assert lds.pos_cnt == 2 13 | assert lds.neg_cnt == 1 14 | 15 | def test_invalid_common_keys(self): 16 | with self.assertRaises(ValueError): 17 | data.LabeledDataset( 18 | label="test", pos=frozenset({"x"}), neg=frozenset({"x"}) 19 | ) 20 | 21 | def test_invalid_no_data(self): 22 | with self.assertRaises(ValueError): 23 | data.LabeledDataset(label="test", pos=frozenset(), neg=frozenset()) 24 | -------------------------------------------------------------------------------- /videoannotator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Netflix/videoannotator/526d38dfa4ffcb953423d1f1f583d695904c90a2/videoannotator/__init__.py -------------------------------------------------------------------------------- /videoannotator/config.py: -------------------------------------------------------------------------------- 1 | from fractions import Fraction 2 | 3 | PATH_DATA_BASE = "/root/data/videoannotator" 4 | 5 | SEED = 0 6 | SCORE_THRESHOLD = 0.5 7 | N_SPLITS = 5 8 | N_BOOTSTRAP = 5 9 | N_JOBS = -1 10 | MAX_ITER = 1_000 11 | VALIDATION_FRACTION = Fraction(numerator=2, denominator=10) 12 | MIN_TRAINING_POS_RATE = 0.1 13 | CLUSTER_CNT = 10 14 | ANNOTATOR_CNT = 3 15 | 16 | # AVE experiments 17 | AVE_EXPERIMENTS_NS = (25, 50, 100, 500, 1000) 18 | AVE_EXPERIMENTS_METHODS = ( 19 | "random", 20 | "zero-shot-50-random-50", 21 | "zero-shot", 22 | ) 23 | SCORING = ( 24 | "average_precision", 25 | # "balanced_accuracy", 26 | ) 27 | 28 | LABEL_GROUPS = dict( 29 | motion={ 30 | "zoom", 31 | "pan", 32 | "slow-motion", 33 | "handheld", 34 | # "fast-motion", 35 | "timelapse", 36 | "jump-scare", 37 | }, 38 | genres={ 39 | "action", 40 | "sci-fi", 41 | "horror", 42 | "fantasy", 43 | "drama", 44 | "romance", 45 | }, 46 | emotions={ 47 | "anger", 48 | "sad", 49 | "happy", 50 | "scared", 51 | "laughter", 52 | }, 53 | shot_types={ 54 | "establishing-shots", 55 | "static-shot", 56 | "shutter-shot", 57 | "cowboy-shot", 58 | "extreme-close-up", 59 | "extreme-wide-shot", 60 | "two-shot", 61 | "group-shot", 62 | "aerial", 63 | "eye-level", 64 | "medium", 65 | "closeup", 66 | "wide", 67 | "over-the-shoulder-shot", 68 | "tilt-shot", 69 | "dutch-angle", 70 | "point-of-view-shot", 71 | "high-angle", 72 | "insert-shot", 73 | "low-angle", 74 | "overhead-shot", 75 | "single-shot", 76 | }, 77 | sensitivities={ 78 | "alcohol", 79 | "smoking", 80 | "nudity", 81 | "gore", 82 | "drugs", 83 | "violence", 84 | "intimacy", 85 | }, 86 | events_actions={ 87 | "interview", 88 | "fight", 89 | "car-chase", 90 | "run", 91 | }, 92 | time_location={ 93 | "day", 94 | "golden-hour", 95 | "interior", 96 | 
}, 97 | focus={ 98 | "character-focus", 99 | "animal", 100 | "object", 101 | }, 102 | ) 103 | 104 | LABELS = frozenset( 105 | ( 106 | "action", 107 | "smoking", 108 | "nudity", 109 | "gore", 110 | "sad", 111 | "happy", 112 | "intimacy", 113 | "establishing-shots", 114 | "character-focus", 115 | "sci-fi", 116 | "horror", 117 | "alcohol", 118 | "anger", 119 | "aerial", 120 | "eye-level", 121 | "medium", 122 | "closeup", 123 | "wide", 124 | "zoom", 125 | "pan", 126 | "slow-motion", 127 | "jump-scare", 128 | "timelapse", 129 | "interior", 130 | "overhead-shot", 131 | # "fast-motion", # very low pos rate 132 | "animal", 133 | "group-shot", 134 | "drugs", 135 | "car-chase", 136 | "laughter", 137 | "over-the-shoulder-shot", 138 | "day", 139 | "object", 140 | "handheld", 141 | "low-angle", 142 | "violence", 143 | "drama", 144 | "point-of-view-shot", 145 | "single-shot", 146 | "romance", 147 | "golden-hour", 148 | "extreme-wide-shot", 149 | "high-angle", 150 | "insert-shot", 151 | "run", 152 | "fantasy", 153 | "static-shot", 154 | "shutter-shot", 155 | "cowboy-shot", 156 | "extreme-close-up", 157 | "two-shot", 158 | "fight", 159 | "interview", 160 | "scared", 161 | "tilt-shot", 162 | "dutch-angle", 163 | ) 164 | ) 165 | LABELS_AVE = frozenset( 166 | ( 167 | "group-shot", 168 | "animal", 169 | "single-shot", 170 | "low-angle", 171 | "eye-level", 172 | "medium", 173 | "extreme-wide-shot", 174 | "tilt-shot", 175 | "high-angle", 176 | "over-the-shoulder-shot", 177 | "zoom", 178 | "closeup", 179 | "wide", 180 | "handheld", 181 | "overhead-shot", 182 | "insert-shot", 183 | "two-shot", 184 | "extreme-close-up", 185 | "aerial", 186 | ) 187 | ) 188 | -------------------------------------------------------------------------------- /videoannotator/data.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import fractions 3 | import functools 4 | import json 5 | import mmh3 6 | import logging 7 | import typing as t 8 | from dataclasses import dataclass 9 | 10 | import h5py 11 | import numpy as np 12 | import pandas as pd 13 | 14 | from . 
import config as cfg, io 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | Key = str 19 | Embedding = np.ndarray 20 | Label = str 21 | 22 | 23 | @functools.lru_cache(maxsize=1) 24 | def get_embeddings_dict() -> t.Dict[Key, Embedding]: 25 | logger.info("Loading embeddings for the first time...") 26 | keys = pd.read_csv(io.PATHS_STATIC.shot_data, usecols=["key"]).key.tolist() 27 | with h5py.File(io.PATHS_STATIC.embeddings, "r") as hf: 28 | embs = hf["embeddings"][:] 29 | d = {key: emb for key, emb in zip(keys, embs)} 30 | logger.info("Finished loading embeddings.") 31 | return d 32 | 33 | 34 | @functools.lru_cache(maxsize=1) 35 | def get_text_embedding_dict() -> t.Dict[Key, Embedding]: 36 | logger.info("Loading text embeddings for the first time...") 37 | data = json.load(open(io.PATHS_STATIC.text_embeddings)) 38 | logger.info("Finished loading text embeddings.") 39 | return data 40 | 41 | 42 | @functools.lru_cache(maxsize=1) 43 | def get_key_to_cluster_idx_dict() -> t.Dict[Key, int]: 44 | df = pd.read_csv(io.PATHS_STATIC.shot_data) 45 | return df.set_index("key").cluster_index.to_dict() 46 | 47 | 48 | @dataclass(frozen=True) 49 | class LabeledDataset: 50 | label: str 51 | pos: t.FrozenSet[Key] 52 | neg: t.FrozenSet[Key] 53 | _cluster_cnt: int = cfg.CLUSTER_CNT 54 | 55 | def __post_init__(self): 56 | if len(self) == 0: 57 | raise ValueError(f"Dataset with label={self.label} has no data.") 58 | common = self.pos & self.neg 59 | if len(common) > 0: 60 | raise ValueError( 61 | f"keys cannot be in both positive and negative sets. these keys are in both: {common}" 62 | ) 63 | elif self.pos_cnt == 0: 64 | logger.warning(f"{self} has no positive annotations.") 65 | elif self.neg_cnt == 0: 66 | logger.warning(f"{self} has no negative annotations.") 67 | 68 | def __len__(self) -> int: 69 | return len(self.pos) + len(self.neg) 70 | 71 | def __repr__(self) -> str: 72 | return ( 73 | f"LabeledDataset(label={self.label}, cnt={len(self)}, pos_cnt={self.pos_cnt}, " 74 | f"neg_cnt={self.neg_cnt}, pos_rate={self.pos_rate:.2f})" 75 | ) 76 | 77 | @property 78 | def keys(self) -> t.List[Key]: 79 | return self.keys_pos + self.keys_neg 80 | 81 | @property 82 | def y(self) -> np.array: 83 | return np.array([True] * self.pos_cnt + [False] * self.neg_cnt) 84 | 85 | @property 86 | @functools.lru_cache() 87 | def x(self) -> np.ndarray: 88 | return np.vstack([get_embeddings_dict()[key] for key in self.keys]) 89 | 90 | @property 91 | @functools.lru_cache() 92 | def data_score(self) -> float: 93 | k2c = get_key_to_cluster_idx_dict() 94 | part_pos = self._data_score_part( 95 | keys=self.keys_pos, k2c=k2c, cluster_cnt=self._cluster_cnt 96 | ) 97 | part_neg = self._data_score_part( 98 | keys=self.keys_neg, k2c=k2c, cluster_cnt=self._cluster_cnt 99 | ) 100 | return (part_pos + part_neg) / (2 * self._cluster_cnt) 101 | 102 | @staticmethod 103 | def _data_score_part( 104 | keys: t.Sequence[Key], k2c: t.Dict[Key, int], cluster_cnt: int 105 | ) -> float: 106 | return ( 107 | pd.Series(map(k2c.get, keys)).value_counts().clip(0, cluster_cnt) 108 | / cluster_cnt 109 | ).sum() 110 | 111 | @property 112 | def keys_pos(self) -> t.List[Key]: 113 | return sorted(self.pos) 114 | 115 | @property 116 | def keys_neg(self) -> t.List[Key]: 117 | return sorted(self.neg) 118 | 119 | @property 120 | def pos_cnt(self) -> int: 121 | return len(self.pos) 122 | 123 | @property 124 | def pos_rate(self) -> float: 125 | return self.pos_cnt / len(self) 126 | 127 | @property 128 | def neg_cnt(self) -> int: 129 | return len(self.neg) 130 | 131 | 
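    # A minimal sketch (illustrative only) of the deterministic split used by
    # the helpers below: each key is hashed with mmh3 and mapped to one of
    # `denominator` buckets, and keys whose bucket index falls below
    # `numerator` go to validation. With the default VALIDATION_FRACTION of
    # 2/10:
    #
    #   import mmh3
    #   from fractions import Fraction
    #   vf = Fraction(2, 10)
    #   in_val = mmh3.hash("some-shot-key", 0) % vf.denominator < vf.numerator
    #
    # Because the assignment depends only on (key, seed), a key always lands
    # on the same side of the split, across runs and across label sets.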
@staticmethod 132 | def _key_in_val( 133 | key: Key, validation_fraction: fractions.Fraction, seed: int 134 | ) -> bool: 135 | vf = validation_fraction 136 | h = mmh3.hash(key=key, seed=seed) 137 | bucket = h % vf.denominator 138 | return bucket < vf.numerator 139 | 140 | def split( 141 | self, 142 | validation_fraction: fractions.Fraction = cfg.VALIDATION_FRACTION, 143 | keys_to_remove_from_train: t.Set[Key] = frozenset(), 144 | keys_to_remove_from_validation: t.Set[Key] = frozenset(), 145 | seed: int = cfg.SEED, 146 | ) -> t.Tuple["LabeledDataset", "LabeledDataset"]: 147 | data = dict( 148 | train=collections.defaultdict(set), val=collections.defaultdict(set) 149 | ) 150 | for value, keys in ((True, self.pos), (False, self.neg)): 151 | for key in keys: 152 | in_val = self._key_in_val( 153 | key=key, validation_fraction=validation_fraction, seed=seed 154 | ) 155 | if in_val and key not in keys_to_remove_from_validation: 156 | data["val"][value].add(key) 157 | elif not in_val and key not in keys_to_remove_from_train: 158 | data["train"][value].add(key) 159 | 160 | ds_train = LabeledDataset( 161 | label=self.label, 162 | pos=frozenset(data["train"][True]), 163 | neg=frozenset(data["train"][False]), 164 | ) 165 | ds_val = LabeledDataset( 166 | label=self.label, 167 | pos=frozenset(data["val"][True]), 168 | neg=frozenset(data["val"][False]), 169 | ) 170 | if ds_train.pos_rate == 0: 171 | raise ValueError( 172 | f"The training split for label={self.label} with seed={seed} has no positive instances." 173 | ) 174 | if ds_val.pos_rate == 0: 175 | raise ValueError( 176 | f"The validation split for label={self.label} with seed={seed} has no positive instances." 177 | ) 178 | return ds_train, ds_val 179 | 180 | def boostrap_xyk(self, idx: int) -> t.Tuple[np.ndarray, np.ndarray, t.List[Key]]: 181 | if idx == 0: 182 | # return the exact dataset on exactly one of the bootstrap 183 | # this helps with datasets that have a very small number of positives 184 | return self.x, self.y, self.keys 185 | np.random.seed(idx) 186 | n = len(self) 187 | idxs = np.random.choice(range(n), n) 188 | x = self.x[idxs] 189 | y = self.y[idxs] 190 | keys = np.array(self.keys)[idxs] 191 | return x, y, list(keys) 192 | 193 | def remove_keys(self, keys: t.Set[Key]) -> "LabeledDataset": 194 | return LabeledDataset( 195 | label=self.label, 196 | pos=frozenset(self.pos - keys), 197 | neg=frozenset(self.neg - keys), 198 | ) 199 | 200 | def _create_dataset_from_keys(self, keys: t.Collection[Key]) -> "LabeledDataset": 201 | keys_pos_set = set(self.keys_pos) 202 | keys_neg_set = set(self.keys_neg) 203 | pos = frozenset(k for k in keys if k in keys_pos_set) 204 | neg = frozenset(k for k in keys if k in keys_neg_set) 205 | return LabeledDataset( 206 | label=self.label, 207 | pos=pos, 208 | neg=neg, 209 | ) 210 | 211 | def sample( 212 | self, n: int, seed: t.Optional[int] = None, start_idx: int = 0 213 | ) -> "LabeledDataset": 214 | np.random.seed(seed) 215 | keys_perm = np.random.permutation(self.keys) 216 | idx_last = start_idx + n if n is not None else None 217 | keys_to_use = keys_perm[start_idx:idx_last] 218 | return self._create_dataset_from_keys(keys=keys_to_use) 219 | 220 | def sample_stratified( 221 | self, 222 | n: int, 223 | seed: t.Optional[int] = None, 224 | min_training_pos_rate: float = 0.0, 225 | ) -> "LabeledDataset": 226 | n = min(len(self), n) 227 | np.random.seed(seed) 228 | cnt_pos = round(max(self.pos_rate, min_training_pos_rate) * n) 229 | cnt_neg = n - cnt_pos 230 | pos = np.random.choice(self.keys_pos, 
replace=False, size=cnt_pos) 231 | neg = np.random.choice(self.keys_neg, replace=False, size=cnt_neg) 232 | return self._create_dataset_from_keys(keys=list(pos) + list(neg))  # np.random.choice returns arrays; concatenate as lists 233 | 234 | def __add__(self, other: "LabeledDataset") -> "LabeledDataset": 235 | if self.label != other.label: 236 | raise ValueError( 237 | f"Cannot add datasets with two different labels: {self.label} and {other.label}" 238 | ) 239 | return LabeledDataset( 240 | label=self.label, 241 | pos=self.pos | other.pos, 242 | neg=self.neg | other.neg, 243 | ) 244 | 245 | def __sub__(self, other: "LabeledDataset") -> "LabeledDataset": 246 | if self.label != other.label: 247 | raise ValueError( 248 | f"Cannot subtract datasets with two different labels: {self.label} and {other.label}" 249 | ) 250 | return LabeledDataset( 251 | label=self.label, 252 | pos=self.pos - other.pos, 253 | neg=self.neg - other.neg, 254 | ) 255 | 256 | def __eq__(self, other: "LabeledDataset") -> bool: 257 | return ( 258 | len(self) == len(other) 259 | and self.label == other.label 260 | and self.pos == other.pos 261 | and self.neg == other.neg 262 | ) 263 | 264 | 265 | @dataclass(frozen=True) 266 | class DatasetOrdered: 267 | label: str 268 | pos: t.Tuple[Key, ...] 269 | neg: t.Tuple[Key, ...] 270 | 271 | def __post_init__(self): 272 | if len(set(self.pos) & set(self.neg)) > 0: 273 | raise ValueError( 274 | "trying to create an ordered dataset with common pos and neg keys." 275 | ) 276 | 277 | def __add__(self, other: "DatasetOrdered"): 278 | assert other.label == self.label 279 | pos_set = set(self.pos) 280 | neg_set = set(self.neg) 281 | other_pos = tuple(x for x in other.pos if x not in pos_set) 282 | other_neg = tuple(x for x in other.neg if x not in neg_set) 283 | return DatasetOrdered( 284 | label=self.label, pos=self.pos + other_pos, neg=self.neg + other_neg 285 | ) 286 | 287 | @staticmethod 288 | def _sub(a: tuple, b: tuple) -> tuple: 289 | b_set = set(b) 290 | return tuple(x for x in a if x not in b_set) 291 | 292 | def __sub__(self, other: "DatasetOrdered") -> "DatasetOrdered": 293 | assert other.label == self.label 294 | pos = self._sub(a=self.pos, b=other.pos) 295 | neg = self._sub(a=self.neg, b=other.neg) 296 | return DatasetOrdered(label=self.label, pos=pos, neg=neg) 297 | 298 | @property 299 | def lds(self) -> LabeledDataset: 300 | return LabeledDataset( 301 | label=self.label, 302 | pos=frozenset(self.pos), 303 | neg=frozenset(self.neg), 304 | ) 305 | 306 | def __len__(self) -> int: 307 | return len(self.pos) + len(self.neg) 308 | 309 | @classmethod 310 | def from_labeled_dataset( 311 | cls, lds: LabeledDataset, seed: int = cfg.SEED 312 | ) -> "DatasetOrdered": 313 | np.random.seed(seed) 314 | pos = sorted(lds.pos) 315 | neg = sorted(lds.neg) 316 | return cls( 317 | label=lds.label, 318 | pos=tuple(np.random.permutation(pos)), 319 | neg=tuple(np.random.permutation(neg)), 320 | ) 321 | 322 | def get_n(self, n: int) -> "DatasetOrdered": 323 | pos_rate = len(self.pos) / len(self) 324 | cnt_pos = min(len(self.pos), max(round(pos_rate * n), 1)) 325 | cnt_neg = min(len(self.neg), n - cnt_pos) 326 | return DatasetOrdered( 327 | label=self.label, 328 | pos=self.pos[:cnt_pos], 329 | neg=self.neg[:cnt_neg], 330 | ) 331 | 332 | 333 | @dataclass(frozen=True) 334 | class CompDataset: 335 | label: str 336 | keys_to_remove: t.FrozenSet[Key] = frozenset() 337 | 338 | @property 339 | @functools.lru_cache() 340 | def data(self) -> dict: 341 | data_og = json.load(open(io.LabelPaths(label=self.label).path_cmp)) 342 | return { 343 | k: [v for v in vs if v["key"] not
in self.keys_to_remove] 344 | for k, vs in data_og.items() 345 | } 346 | 347 | def _get_lds( 348 | self, 349 | which: str, 350 | n: t.Optional[int], 351 | start_idx: int, 352 | keys_to_remove: t.Optional[t.Set[Key]] = None, 353 | fail_if_less_available: bool = False, 354 | ) -> LabeledDataset: 355 | d = self.data[which] 356 | if n is not None and n > len(d): 357 | msg = f"Dataset for label={self.label} has {len(d)} records, but requested {n}" 358 | if fail_if_less_available: 359 | raise ValueError(msg) 360 | else: 361 | logger.warning(msg) 362 | data = d[start_idx : start_idx + n] if n is not None else d[start_idx:] 363 | pos, neg = self.__get_pos_neg(data=data, keys_to_remove=keys_to_remove) 364 | return LabeledDataset(label=self.label, pos=frozenset(pos), neg=frozenset(neg)) 365 | 366 | @staticmethod 367 | def __get_pos_neg(data, keys_to_remove: t.Optional[t.Set[Key]]): 368 | keys_to_remove = keys_to_remove if keys_to_remove is not None else set() 369 | data = [d for d in data if d["key"] not in keys_to_remove] 370 | pos = (d["key"] for d in data if d["value"]) 371 | neg = (d["key"] for d in data if not d["value"]) 372 | return pos, neg 373 | 374 | def random( 375 | self, 376 | n: t.Optional[int] = None, 377 | start_idx: int = 0, 378 | keys_to_remove: t.Optional[t.Set[Key]] = None, 379 | ) -> LabeledDataset: 380 | return self._get_lds( 381 | which="random", n=n, start_idx=start_idx, keys_to_remove=keys_to_remove 382 | ) 383 | 384 | def zero_shot( 385 | self, 386 | n: t.Optional[int] = None, 387 | start_idx: int = 0, 388 | keys_to_remove: t.Optional[t.Set[Key]] = None, 389 | ) -> LabeledDataset: 390 | return self._get_lds( 391 | which="zero_shot", n=n, start_idx=start_idx, keys_to_remove=keys_to_remove 392 | ) 393 | 394 | @property 395 | def cnt_zero_shot(self) -> int: 396 | return len(self.data["zero_shot"]) 397 | 398 | @property 399 | def cnt_random(self) -> int: 400 | return len(self.data["random"]) 401 | 402 | def get_ordered_dataset_zero_shot( 403 | self, 404 | keys_to_remove: t.Set[Key] = frozenset(), 405 | ) -> DatasetOrdered: 406 | return self._get_ordered_dataset( 407 | which="zero_shot", keys_to_remove=keys_to_remove 408 | ) 409 | 410 | def get_ordered_dataset_random( 411 | self, keys_to_remove: t.Set[Key] = frozenset() 412 | ) -> DatasetOrdered: 413 | return self._get_ordered_dataset(which="random", keys_to_remove=keys_to_remove) 414 | 415 | def _get_ordered_dataset( 416 | self, which: str, keys_to_remove: t.Set[Key] 417 | ) -> DatasetOrdered: 418 | data = self.data[which] 419 | pos, neg = self.__get_pos_neg(data=data, keys_to_remove=keys_to_remove) 420 | return DatasetOrdered(label=self.label, pos=tuple(pos), neg=tuple(neg)) 421 | 422 | 423 | def get_aggregate_labeled_dataset(label: str) -> LabeledDataset: 424 | path = io.LabelPaths(label=label).path_agg 425 | data = json.load(open(path)) 426 | ds = collections.defaultdict(set) 427 | for x in data["data"]: 428 | ds[x["agg"]].add(x["key"]) 429 | return LabeledDataset( 430 | label=label, pos=frozenset(ds[True]), neg=frozenset(ds[False]) 431 | ) 432 | 433 | 434 | def get_labeled_dataset_positive_cnt_by_key(label: str) -> t.Dict[Key, int]: 435 | path = io.LabelPaths(label=label).path_agg 436 | data = json.load(open(path)) 437 | return {x["key"]: x["pos_cnt"] for x in data["data"]} 438 | 439 | 440 | def get_labeled_dataset_agreement( 441 | label: str, _annotator_cnt: int = cfg.ANNOTATOR_CNT 442 | ) -> float: 443 | pos_counts = get_labeled_dataset_positive_cnt_by_key(label=label) 444 | den = len(pos_counts) 445 | num = sum(cnt in 
{0, _annotator_cnt} for _, cnt in pos_counts.items()) 446 | return num / den 447 | 448 | 449 | def get_labeled_dataset_checkpoints(label: str) -> t.List[LabeledDataset]: 450 | path = io.LabelPaths(label=label).path_checkpoints 451 | data = json.load(open(path)) 452 | return [ 453 | LabeledDataset( 454 | label=label, 455 | pos=frozenset(x["pos"]), 456 | neg=frozenset(x["neg"]), 457 | ) 458 | for x in data 459 | ] 460 | 461 | 462 | def get_ave_validation_labeled_dataset(label: str) -> LabeledDataset: 463 | if label not in cfg.LABELS_AVE: 464 | raise ValueError( 465 | f"AVE data for label={label} does not exist. choices are: {cfg.LABELS_AVE}" 466 | ) 467 | path = io.LabelPaths(label=label).path_ave 468 | data = json.load(open(path)) 469 | return LabeledDataset( 470 | label=label, 471 | pos=frozenset(data["pos"]), 472 | neg=frozenset(data["neg"]), 473 | ) 474 | 475 | 476 | @dataclass(frozen=False) 477 | class DatasetManager: 478 | label: str 479 | _seed: int = cfg.SEED 480 | 481 | def __post_init__(self): 482 | cmp = CompDataset(label=self.label) 483 | self._agg = get_aggregate_labeled_dataset(label=self.label) 484 | self._agg_val_keys = set(self._agg.split(seed=self._seed)[1].keys) 485 | self._ordered_datasets: t.Dict[str, t.Optional[DatasetOrdered]] = dict( 486 | agg=self._get_combined_agg_ordered_dataset(), 487 | zero_shot=cmp.get_ordered_dataset_zero_shot( 488 | keys_to_remove=self._agg_val_keys 489 | ), 490 | random=cmp.get_ordered_dataset_random(keys_to_remove=self._agg_val_keys), 491 | ) 492 | self.lds = None 493 | self.lds_hist = [] 494 | 495 | def has_data(self, which: str) -> bool: 496 | return self._ordered_datasets[which] is not None 497 | 498 | def _get_agg_ordered_dataset(self, lds: LabeledDataset) -> DatasetOrdered: 499 | ods = DatasetOrdered.from_labeled_dataset(lds=lds) 500 | keys = ods.pos + ods.neg 501 | pos = tuple( 502 | k for k in keys if k in self._agg.pos and k not in self._agg_val_keys 503 | ) 504 | neg = tuple( 505 | k for k in keys if k in self._agg.neg and k not in self._agg_val_keys 506 | ) 507 | return DatasetOrdered(label=self.label, pos=pos, neg=neg) 508 | 509 | def _get_combined_agg_ordered_dataset(self) -> DatasetOrdered: 510 | lds_list = get_labeled_dataset_checkpoints(label=self.label) 511 | ods = self._get_agg_ordered_dataset(lds=lds_list[0]) 512 | for lds in lds_list[1:]: 513 | if len(lds) > len(ods): 514 | ods_curr = self._get_agg_ordered_dataset(lds=lds) 515 | ods = ods + (ods_curr - ods) 516 | return ods 517 | 518 | def extend(self, which: str, n: int) -> LabeledDataset: 519 | ods = self._ordered_datasets[which] 520 | if ods is not None: 521 | ods_new = ods.get_n(n=n) 522 | if len(ods_new) < n: 523 | logger.warning( 524 | f"label={self.label} {which} returned {len(ods_new)} vs. the requested {n}." 525 | ) 526 | self._ordered_datasets[which] = None 527 | elif len(ods_new) == len(ods): 528 | self._ordered_datasets[which] = None 529 | else: 530 | self._ordered_datasets[which] = ods - ods_new 531 | lds_new = ods_new.lds 532 | 533 | self.lds_hist.append( 534 | dict(which=which, n=n, lds_new_cnt=len(lds_new), lds_new=lds_new) 535 | ) 536 | self.lds = self.lds + lds_new if self.lds is not None else lds_new 537 | for name, ods in self._ordered_datasets.items(): 538 | if ods is not None: 539 | self._ordered_datasets[name] = ( 540 | self._ordered_datasets[name] - ods_new 541 | ) 542 | else: 543 | logger.warning( 544 | f"label={self.label} {which} is already exhausted. returning the last dataset." 
545 | ) 546 | return self.lds 547 | -------------------------------------------------------------------------------- /videoannotator/dp.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from dataclasses import dataclass 3 | import logging 4 | import typing as t 5 | 6 | import numpy as np 7 | from tqdm.auto import tqdm 8 | 9 | from . import config as cfg 10 | from .data import DatasetManager, get_aggregate_labeled_dataset 11 | from .experiments import ExperimentResults, Metric, run_experiment 12 | from .models import LogisticRegression, Model 13 | 14 | Choice = str 15 | 16 | # TODO: add conf 17 | CHOICES: t.Tuple[str, ...] = ("agg", "zero_shot", "random") 18 | D = 50 19 | EPSILON = 0.2 20 | ALPHA = 0.5 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | @dataclass(frozen=True) 27 | class DecisionProcessResults: 28 | res_hist: t.Tuple[t.Tuple[str, ExperimentResults], ...] 29 | choice_hist: t.Tuple[str, ...] 30 | 31 | @property 32 | def best_metric(self) -> Metric: 33 | return max( 34 | ( 35 | (res.average_precision.mean, res.average_precision) 36 | for _, res in self.res_hist 37 | ), 38 | key=lambda x: x[0], 39 | )[1] 40 | 41 | @property 42 | def last_metric(self) -> Metric: 43 | return self.res_hist[-1][1].average_precision 44 | 45 | @property 46 | def best_idx(self) -> int: 47 | return max( 48 | (res.average_precision.mean, idx) 49 | for idx, (_, res) in enumerate(self.res_hist) 50 | )[1] 51 | 52 | 53 | @dataclass(frozen=False) 54 | class _DecisionProcess: 55 | label: str 56 | choices: t.Tuple[str, ...] = CHOICES 57 | d: int = D 58 | model: Model = LogisticRegression(scoring="average_precision") 59 | _seed_base: int = cfg.SEED 60 | 61 | @staticmethod 62 | def _iter_n(n: int, verbose: bool): 63 | if verbose: 64 | return tqdm(range(n)) 65 | else: 66 | return range(n) 67 | 68 | def run(self, n: int, verbose: bool = False) -> DecisionProcessResults: 69 | self._warmup(verbose=verbose) 70 | for _ in self._iter_n(n=n, verbose=verbose): 71 | choice = self._select_outer() 72 | res = self._act(choice=choice, verbose=verbose) 73 | self._update(choice=choice, res=res) 74 | self._choice_hist.append(choice) 75 | self._res_hist.append((choice, res)) 76 | self.i += 1 77 | return DecisionProcessResults( 78 | res_hist=tuple(self._res_hist), 79 | choice_hist=tuple(self._choice_hist), 80 | ) 81 | 82 | def __post_init__(self): 83 | self.i = 0 84 | self._res_hist = [] 85 | self._choice_hist = [] 86 | self._dsm = DatasetManager(label=self.label, _seed=self._seed_base) 87 | self._choices_exhausted = set() 88 | 89 | @property 90 | def choices_available(self) -> t.Tuple[str, ...]: 91 | return tuple(c for c in self.choices if c not in self._choices_exhausted) 92 | 93 | def _select_outer(self) -> Choice: 94 | if self.i < len(self.choices): 95 | # try each option once first 96 | return self.choices[self.i] 97 | while True: 98 | c = self._select() 99 | if c is None: 100 | raise ValueError( 101 | f"{self.__class__.__name__} select returned option None (not allowed) at i={self.i} with " 102 | f"params: {self.__dict__}" 103 | ) 104 | if self._dsm.has_data(which=c): 105 | return c 106 | else: 107 | self._choices_exhausted.add(c) 108 | if len(self.choices_available) == 0: 109 | raise ValueError( 110 | f"{self.__class__.__name__} ran out of options to select at i={self.i} with " 111 | f"params: {self.__dict__}" 112 | ) 113 | 114 | def _select(self) -> Choice: 115 | pass 116 | 117 | def _warmup(self, verbose: bool) -> None: 118 | pass 119 | 
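    # Summary of the control flow below (descriptive note): at each step the
    # policy picks one of the sources in `choices` ("agg" = aggregated human
    # annotations, "zero_shot" = shots ranked by text-to-video similarity,
    # "random"), `_act` extends the training set by `d` examples from that
    # source via the DatasetManager, retrains the model, and evaluates it on
    # the held-out validation split. Subclasses override `_select` (and
    # optionally `_warmup` / `_update`) to implement round-robin, greedy,
    # epsilon-greedy, UCB, or Exp3 policies; the bandit variants use the
    # change in validation average precision as the reward for the chosen
    # source.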
120 | def _act(self, choice: Choice, verbose: bool) -> ExperimentResults: 121 | _, lds_val = get_aggregate_labeled_dataset(label=self.label).split( 122 | seed=self._seed_base 123 | ) 124 | lds_train = self._dsm.extend(which=choice, n=self.d) 125 | if verbose: 126 | logger.info(f"Running experiment with choice={choice}") 127 | return run_experiment( 128 | model=self.model, 129 | lds_train=lds_train, 130 | lds_validation=lds_val, 131 | ) 132 | 133 | def _update(self, choice: Choice, res: ExperimentResults) -> None: 134 | pass 135 | 136 | @property 137 | def _seed(self) -> int: 138 | return self._seed_base + self.i 139 | 140 | @staticmethod 141 | def extract_metric(res: ExperimentResults) -> float: 142 | return res.average_precision.mean 143 | 144 | 145 | @dataclass 146 | class RoundRobin(_DecisionProcess): 147 | def _select(self) -> Choice: 148 | return self.choices_available[self.i % len(self.choices_available)] 149 | 150 | 151 | @dataclass 152 | class Random(_DecisionProcess): 153 | def _select(self) -> Choice: 154 | np.random.seed(self._seed) 155 | return np.random.choice(self.choices_available) 156 | 157 | 158 | @dataclass 159 | class GreedyOracle(_DecisionProcess): 160 | def _select_outer(self) -> Choice: 161 | while True: 162 | c = self._select() 163 | if self._dsm.has_data(which=c): 164 | return c 165 | else: 166 | self._choices_exhausted.add(c) 167 | if len(self.choices_available) == 0: 168 | raise ValueError( 169 | f"GreedyOracle has no more available choices for label={self.label} at i={self.i}" 170 | ) 171 | 172 | def _select(self) -> Choice: 173 | best, best_score = None, float("-inf") 174 | dsm_copy = deepcopy(self._dsm) 175 | for choice in self.choices_available: 176 | try: 177 | res = self._act(choice=choice, verbose=False) 178 | score = self.extract_metric(res=res) 179 | if np.isnan(score): 180 | raise ValueError("Evaluation score is NaN.") 181 | if score > best_score: 182 | best, best_score = choice, score 183 | except Exception as e: 184 | logger.warning( 185 | f"Running GreedyOracle with choice={choice}, i={self.i}, label={self.label} failed: {e}. " 186 | f"Skipping over this choice and setting the metric to zero. " 187 | f"Params: {self.__dict__}" 188 | ) 189 | finally: 190 | self._dsm = deepcopy(dsm_copy) 191 | if best is None: 192 | raise ValueError( 193 | f"All choices failed for GreedyOracle label={self.label}, i={self.i}." 
194 | f" Params: {self.__dict__}" 195 | ) 196 | return best 197 | 198 | 199 | @dataclass 200 | class EpsilonGreedyMean(_DecisionProcess): 201 | epsilon: float = EPSILON 202 | 203 | def __post_init__(self): 204 | super().__post_init__() 205 | self._metrics = {c: Metric(values=tuple()) for c in self.choices} 206 | self._metrics_abs = {c: Metric(values=tuple()) for c in self.choices} 207 | 208 | @property 209 | def best_choice(self) -> Choice: 210 | return max( 211 | (m.mean, c) 212 | for c, m in self._metrics.items() 213 | if c in set(self.choices_available) 214 | )[1] 215 | 216 | def _select(self) -> Choice: 217 | np.random.seed(self._seed) 218 | if np.random.rand() < self.epsilon: 219 | return np.random.choice(self.choices_available) 220 | else: 221 | return self.best_choice 222 | 223 | def _update(self, choice: Choice, res: ExperimentResults) -> None: 224 | prev = self._metrics_abs[self._choice_hist[-1]].values[-1] 225 | metric = self.extract_metric(res=res) 226 | self._metrics[choice] = self._metrics[choice].add(metric - prev) 227 | self._metrics_abs[choice] = self._metrics_abs[choice].add(metric) 228 | 229 | def _warmup(self, verbose: bool) -> None: 230 | dsm_copy = deepcopy(self._dsm) 231 | metrics = dict() 232 | for choice in self.choices: 233 | try: 234 | res = self._act(choice=choice, verbose=verbose) 235 | val = self.extract_metric(res=res) 236 | except Exception as e: 237 | val = 0 238 | name = self.__class__.__name__ 239 | logger.warning( 240 | f"Running {name} warmup for choice={choice}, label={self.label} failed: {e}. " 241 | f"Skipping over this choice and setting the metric to zero. " 242 | f"Params: {self.__dict__}" 243 | ) 244 | 245 | self._metrics_abs[choice] = self._metrics_abs[choice].add(val) 246 | metrics[choice] = val 247 | self._dsm = deepcopy(dsm_copy) 248 | 249 | choice_worst_val, choice_worst = min((v, c) for c, v in metrics.items()) 250 | self._choice_hist.append(choice_worst) 251 | for choice in self.choices: 252 | self._metrics[choice] = self._metrics[choice].add( 253 | metrics[choice] - choice_worst_val 254 | ) 255 | 256 | 257 | @dataclass 258 | class EpsilonGreedyLast(EpsilonGreedyMean): 259 | @property 260 | def best_choice(self) -> Choice: 261 | return max( 262 | (m.last, c) 263 | for c, m in self._metrics.items() 264 | if c in set(self.choices_available) 265 | )[1] 266 | 267 | 268 | @dataclass 269 | class EpsilonGreedyEWMA(EpsilonGreedyMean): 270 | alpha: float = ALPHA 271 | 272 | @property 273 | def best_choice(self) -> Choice: 274 | return max( 275 | (m.ewma(alpha=self.alpha), c) 276 | for c, m in self._metrics.items() 277 | if c in set(self.choices_available) 278 | )[1] 279 | 280 | 281 | @dataclass 282 | class UCBMean(EpsilonGreedyMean): 283 | c: float = 1e-2 284 | 285 | """ 286 | c <= 1 for best results?! 
287 | """ 288 | 289 | def __post_init__(self): 290 | super().__post_init__() 291 | self._counts = {c: 0 for c in self.choices} 292 | 293 | def _select(self) -> Choice: 294 | t_ = self.i + 1 295 | 296 | best, best_score = None, float("-inf") 297 | for choice, cnt in self._counts.items(): 298 | if choice in set(self.choices_available): 299 | if cnt == 0: 300 | return choice 301 | 302 | # val + c * sqrt(log(t) / cnt) 303 | aug = self._metrics[choice].mean 304 | aug += self.c * np.sqrt(np.log(t_) / cnt) 305 | if aug > best_score: 306 | best = choice 307 | best_score = aug 308 | 309 | return best 310 | 311 | def _update(self, choice: Choice, res: ExperimentResults) -> None: 312 | self._counts[choice] += 1 313 | super()._update(choice=choice, res=res) 314 | 315 | 316 | @dataclass 317 | class UCBEWMA(EpsilonGreedyEWMA, UCBMean): 318 | pass 319 | 320 | 321 | @dataclass 322 | class UCBLast(EpsilonGreedyLast, UCBMean): 323 | pass 324 | 325 | 326 | @dataclass 327 | class Thompson(_DecisionProcess): 328 | pass 329 | 330 | 331 | @dataclass 332 | class Exp3Mean(EpsilonGreedyMean): 333 | gamma: float = 1e-1 334 | _seed: int = cfg.SEED 335 | 336 | def __post_init__(self): 337 | assert 0 <= self.gamma <= 1 338 | super().__post_init__() 339 | self._ws = {c: 1 for c in self.choices_available} 340 | 341 | @property 342 | def _k(self) -> int: 343 | return len(self.choices_available) 344 | 345 | @property 346 | def _ps(self) -> t.Dict[Choice, float]: 347 | w_sum = sum(w for c, w in self._ws.items() if c in self.choices_available) 348 | return { 349 | c: (1 - self.gamma) * self._ws[c] / w_sum + self.gamma / self._k 350 | for c in self.choices_available 351 | } 352 | 353 | def _select(self) -> Choice: 354 | np.random.seed(self._seed) 355 | p = [self._ps[c] for c in self.choices_available] 356 | return np.random.choice(self.choices_available, p=p) 357 | 358 | def _update(self, choice: Choice, res: ExperimentResults) -> None: 359 | super()._update(choice=choice, res=res) 360 | x = self._get_update_metric(choice=choice) / self._ps[choice] 361 | self._ws[choice] *= np.exp(self.gamma * x / self._k) 362 | 363 | def _get_update_metric(self, choice: Choice) -> float: 364 | return self._metrics[choice].mean 365 | 366 | 367 | @dataclass 368 | class Exp3Last(Exp3Mean): 369 | def _get_update_metric(self, choice: Choice) -> float: 370 | return self._metrics[choice].last 371 | 372 | 373 | @dataclass 374 | class Exp3EWMA(Exp3Mean): 375 | alpha: float = ALPHA 376 | 377 | def _get_update_metric(self, choice: Choice) -> float: 378 | return self._metrics[choice].ewma(alpha=self.alpha) 379 | -------------------------------------------------------------------------------- /videoannotator/experiments.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import itertools 3 | import logging 4 | import typing as t 5 | from dataclasses import dataclass 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from sklearn.metrics import average_precision_score as ap, balanced_accuracy_score as ba 10 | from tqdm.auto import tqdm 11 | 12 | from . import config as cfg, data, models 13 | from .data import LabeledDataset 14 | 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | @dataclass(frozen=True) 20 | class Metric: 21 | values: t.Tuple[float, ...] 
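    # Descriptive note: an immutable accumulator of repeated evaluation
    # scores (e.g. bootstrap runs). NaN entries are dropped by
    # `values_finite`; `mean`, `std`, `last` and `ewma` are views over the
    # remaining values. For reference, ewma(alpha) computes
    # sum_i x_{n-i} * (1 - alpha)^i / sum_i (1 - alpha)^i, so the most recent
    # score receives the largest weight.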
22 | 23 | @property 24 | @functools.lru_cache() 25 | def values_finite(self) -> t.List[float]: 26 | if len(self.values) == 0: 27 | return [] 28 | vals = [v for v in self.values if not np.isnan(v)] 29 | if len(vals) == 0: 30 | raise ValueError("Metric only contains NaN.") 31 | if len(vals) < len(self.values): 32 | logger.warning(f"Metric has NaN values: {self.values}") 33 | return vals 34 | 35 | @property 36 | def mean(self) -> float: 37 | return np.mean(self.values_finite).item() if len(self.values_finite) > 0 else 0 38 | 39 | @property 40 | def std(self) -> float: 41 | return np.std(self.values_finite).item() if len(self.values_finite) > 0 else 0 42 | 43 | def __repr__(self): 44 | return f"Metric(mean={self.mean:.2f}, std={self.std:.2f})" 45 | 46 | def add(self, x: float) -> "Metric": 47 | return Metric(values=self.values + (x,)) 48 | 49 | @property 50 | def last(self) -> float: 51 | return self.values_finite[-1] if len(self.values_finite) > 0 else float("-inf") 52 | 53 | def ewma(self, alpha: float) -> float: 54 | assert 0 <= alpha <= 1 55 | if len(self.values_finite) == 0: 56 | return 0 57 | v = 1 - alpha 58 | n = len(self.values_finite) 59 | num = sum(x * (v**i) for i, x in enumerate(self.values_finite[::-1])) 60 | den = sum(v**i for i in range(n)) 61 | return num / den 62 | 63 | 64 | @dataclass(frozen=True) 65 | class ExperimentResults: 66 | lds_train: LabeledDataset 67 | lds_validation: LabeledDataset 68 | average_precision: Metric 69 | balanced_accuracy: Metric 70 | 71 | 72 | def run_experiment( 73 | model: models.Model, 74 | lds_train: LabeledDataset, 75 | lds_validation: LabeledDataset, # TODO: take a list/dict instead? 76 | n_bootstrap: int = cfg.N_BOOTSTRAP, 77 | ) -> ExperimentResults: 78 | common_keys = set(lds_train.keys) & set(lds_validation.keys) 79 | if len(common_keys) > 0: 80 | raise ValueError(f"Train and validation sets have common keys: {common_keys}") 81 | model.fit(lds_train.x, lds_train.y) 82 | aps, bas = [], [] 83 | for idx in range(n_bootstrap): 84 | x, y, _ = lds_validation.boostrap_xyk(idx=idx) 85 | bas.append(ba(y, model.predict(x))) 86 | aps.append(ap(y, model.predict_proba(x))) 87 | return ExperimentResults( 88 | lds_train=lds_train, 89 | lds_validation=lds_validation, 90 | average_precision=Metric(values=tuple(aps)), 91 | balanced_accuracy=Metric(values=tuple(bas)), 92 | ) 93 | 94 | 95 | def _ap_baseline(lds_val: LabeledDataset) -> Metric: 96 | thr = 1_000 97 | p = lds_val.pos_cnt 98 | n = len(lds_val) 99 | if n >= thr: 100 | return Metric(values=(p / n,)) 101 | np.random.seed(0) 102 | y = [True] * p + [False] * (n - p) 103 | rand_scores = np.random.rand(thr, n) 104 | aps = tuple(ap(y, rand_scores[i]) for i in range(thr)) 105 | return Metric(values=aps) 106 | 107 | 108 | @dataclass(frozen=True) 109 | class Experiment: 110 | ns: t.Tuple[int, ...] = cfg.AVE_EXPERIMENTS_NS 111 | methods: t.Tuple[str, ...] = cfg.AVE_EXPERIMENTS_METHODS 112 | seed: int = cfg.SEED 113 | scoring: t.Tuple[str, ...] 
= cfg.SCORING 114 | 115 | def experiment_active_learning(self, label: str) -> dict: 116 | # TODO: make sure zero-shot / random don't overlap with validation 117 | logger.info(f"Running active learning experiment for label={label}...") 118 | ds_agg = data.get_aggregate_labeled_dataset(label=label) 119 | ds_agg_train, ds_agg_val = ds_agg.split() 120 | # ds_val_lookup = dict(agg=ds_agg_val) 121 | res_final = dict() 122 | res_final["baseline_agg"] = _ap_baseline(lds_val=ds_agg_val) 123 | # TODO: remove disagreements and then run (if any positives left) => may still be misaligned 124 | # if label in cfg.LABELS_AVE: 125 | # ds_ave = data.get_ave_validation_labeled_dataset(label=label) 126 | # ds_val_lookup["ave"] = ds_ave 127 | # logger.info("Zero shot with AVE data...") 128 | # res_final["zero_shot_ave"] = run_experiment( 129 | # model=models.ZeroShotText2Video(label=label), 130 | # lds_train=ds_agg, 131 | # lds_validation=ds_ave, 132 | # ) 133 | # res_final["baseline_ave"] = _ap_baseline(lds_val=ds_ave) 134 | # else: 135 | # logger.info(f"AVE data does NOT exist for label={label}") 136 | logger.info("Zero shot with agg data...") 137 | res_final["zero_shot"] = run_experiment( 138 | model=models.ZeroShotText2Video(label=label), 139 | lds_train=ds_agg_train, 140 | lds_validation=ds_agg_val, 141 | ) 142 | val_keys = set(ds_agg_val.keys) 143 | ds_all = data.get_labeled_dataset_checkpoints(label=label) 144 | logger.info("Comparing...") 145 | res_final["cmp"] = self._run_experiments_comp( 146 | label=label, 147 | lds_validation=ds_agg_val, 148 | ) 149 | res_final["checkpoints"] = [ 150 | run_experiment( 151 | models.LogisticRegression(scoring="average_precision"), 152 | lds_train=ds.remove_keys(keys=val_keys), 153 | lds_validation=ds_agg_val, 154 | ) 155 | for ds in tqdm(ds_all, desc="Processing checkpoints") 156 | ] 157 | return res_final 158 | 159 | def experiment_active_learning_batch( 160 | self, 161 | labels: t.Tuple[str, ...], 162 | ) -> dict: 163 | return { 164 | label: self.experiment_active_learning(label=label) 165 | for label in tqdm(labels) 166 | } 167 | 168 | def experiment_active_learning_aggregate( 169 | self, res_all: t.Dict[str, dict] 170 | ) -> pd.DataFrame: 171 | # TODO: AP vs. 
n => lines with CI: 1- baseline 2- zs 3- agg 4- 172 | pass 173 | 174 | @staticmethod 175 | def _get_lds_zero_shot(lds: LabeledDataset, n: int) -> LabeledDataset: 176 | zs = models.ZeroShotText2Video(label=lds.label) 177 | keys = zs.top_n_keys(lds=lds, n=n) 178 | return LabeledDataset( 179 | label=lds.label, 180 | pos=frozenset(k for k in keys if k in lds.pos), 181 | neg=frozenset(k for k in keys if k in lds.neg), 182 | ) 183 | 184 | @staticmethod 185 | def _get_lds_train( 186 | label: str, n: int, method: str, keys_to_remove: t.FrozenSet[data.Key] 187 | ) -> LabeledDataset: 188 | ds = data.CompDataset(label=label, keys_to_remove=frozenset(keys_to_remove)) 189 | if method == "random": 190 | return ds.random(n=n) 191 | elif method == "zero-shot-50-random-50": 192 | n_rand = round(n * 0.5) 193 | lds_rand = ds.random(n=n_rand) 194 | lds_zs = ds.zero_shot(n=n - n_rand) 195 | return lds_rand + lds_zs 196 | elif method == "zero-shot-20-random-80": 197 | n_zs = round(n * 0.2) 198 | lds_zs = ds.zero_shot(n=n_zs) 199 | lds_rand = ds.random(n=n - n_zs) 200 | return lds_rand + lds_zs 201 | elif method == "zero-shot": 202 | return ds.zero_shot(n=n) 203 | else: 204 | raise ValueError(f"method = {method} is not a valid choice.") 205 | 206 | @property 207 | def comp_iter(self) -> list: 208 | return list( 209 | itertools.product( 210 | self.ns, self.methods, ("average_precision", "balanced_accuracy") 211 | ) 212 | ) 213 | 214 | def _run_experiments_comp( 215 | self, 216 | label: str, 217 | lds_validation: LabeledDataset, 218 | ) -> dict: 219 | res = dict() 220 | for n, method, scoring in tqdm(self.comp_iter, desc="Running experiments"): 221 | try: 222 | lds_train = self._get_lds_train( 223 | label=label, 224 | n=n, 225 | method=method, 226 | keys_to_remove=frozenset(lds_validation.keys), 227 | ) 228 | res[(n, method, scoring)] = run_experiment( 229 | model=models.LogisticRegression(scoring=scoring), 230 | lds_train=lds_train, 231 | lds_validation=lds_validation, 232 | ) 233 | except Exception as e: 234 | logger.error( 235 | f"Experiment ({n}, {method}, {scoring}) for label={label} failed: {e}" 236 | ) 237 | return res 238 | 239 | # def experiment_ave_self_comparison( 240 | # self, 241 | # label: str, 242 | # ) -> t.Dict[t.Tuple[int, str, str], ExperimentResults]: 243 | # if label not in cfg.LABELS_AVE: 244 | # raise ValueError(f"label={label} does not have AVE data") 245 | # ds_ave = data.get_ave_validation_labeled_dataset(label=label) 246 | # lds_train_base, lds_validation = ds_ave.split() 247 | # return self._run_experiments_comp( 248 | # label=label, 249 | # lds_train_base=lds_train_base, 250 | # lds_validation=lds_validation, 251 | # ) 252 | # 253 | # def experiment_ave_self_comparison_all(self) -> ...: 254 | # return [ 255 | # self.experiment_ave_self_comparison(label=label) for label in cfg.LABELS_AVE 256 | # ] 257 | -------------------------------------------------------------------------------- /videoannotator/io.py: -------------------------------------------------------------------------------- 1 | # TODO: add text embeddings for top 10k+ common words and store in repo for convenience 2 | 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | 6 | from . 
import config as cfg 7 | 8 | 9 | def _check_path(func): 10 | def __check(*args, **kwargs): 11 | path = func(*args, **kwargs) 12 | if not path.exists(): 13 | raise FileNotFoundError(f"file {path} does not exist.") 14 | return path 15 | 16 | return __check 17 | 18 | 19 | @dataclass(frozen=True) 20 | class _PathsStatic: 21 | _path_base: str = cfg.PATH_DATA_BASE 22 | 23 | @property 24 | @_check_path 25 | def embeddings(self) -> Path: 26 | return Path(f"{self._path_base}/embeddings.h5") 27 | 28 | @property 29 | @_check_path 30 | def shot_data(self) -> Path: 31 | return Path(f"{self._path_base}/shot-data.csv") 32 | 33 | @property 34 | @_check_path 35 | def text_embeddings(self) -> Path: 36 | return Path(f"{self._path_base}/text-embeddings.json") 37 | 38 | 39 | @dataclass(frozen=True) 40 | class _PathAVE(_PathsStatic): 41 | pass # TODO 42 | 43 | 44 | @dataclass(frozen=True) 45 | class LabelPaths: 46 | label: str 47 | _path_base: str = cfg.PATH_DATA_BASE 48 | 49 | @property 50 | @_check_path 51 | def path_agg(self) -> Path: 52 | return Path(f"{self._path_base}/agg/{self.label}.json") 53 | 54 | @property 55 | @_check_path 56 | def path_checkpoints(self) -> Path: 57 | return Path(f"{self._path_base}/checkpoints/{self.label}.json") 58 | 59 | @property 60 | @_check_path 61 | def path_ave(self) -> Path: 62 | return Path(f"{self._path_base}/ave/validation/{self.label}.json") 63 | 64 | @property 65 | @_check_path 66 | def path_cmp(self) -> Path: 67 | return Path(f"{self._path_base}/cmp/{self.label}.json") 68 | 69 | 70 | PATHS_STATIC = _PathsStatic() 71 | PATHS_AVE = _PathAVE() 72 | -------------------------------------------------------------------------------- /videoannotator/models.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | from dataclasses import dataclass 3 | 4 | import numpy as np 5 | from sklearn.ensemble import RandomForestClassifier 6 | from sklearn.linear_model import LogisticRegressionCV 7 | from sklearn.metrics.pairwise import cosine_similarity as cos 8 | from sklearn.model_selection import StratifiedKFold 9 | from xgboost import XGBClassifier 10 | 11 | from . 
import config as cfg, data 12 | 13 | 14 | @dataclass 15 | class _ScikitBinaryClassifierAPI: 16 | def fit(self, x: np.ndarray, y: np.ndarray) -> None: 17 | pass 18 | 19 | def predict_proba(self, x: np.ndarray) -> np.ndarray: 20 | pass 21 | 22 | def predict(self, x: np.ndarray) -> np.ndarray: 23 | pass 24 | 25 | 26 | @dataclass 27 | class ZeroShotText2Video(_ScikitBinaryClassifierAPI): 28 | label: str 29 | 30 | @property 31 | def text_emb(self) -> np.ndarray: 32 | return np.array(data.get_text_embedding_dict()[self.label]) 33 | 34 | def predict_proba(self, x: np.ndarray) -> np.ndarray: 35 | return cos(x, self.text_emb[np.newaxis, :]).squeeze() 36 | 37 | def predict(self, x: np.ndarray) -> np.ndarray: 38 | # doesn't really make sense, but we'll add for compatibility 39 | return np.array([False] * len(x)) 40 | 41 | def top_n_keys(self, lds: data.LabeledDataset, n: int) -> t.List[data.Key]: 42 | scores = self.predict_proba(x=lds.x) 43 | return list(np.array(lds.keys)[scores.argsort()[::-1][:n]]) 44 | 45 | 46 | @dataclass 47 | class LogisticRegression(_ScikitBinaryClassifierAPI): 48 | scoring: str 49 | n_splits: int = cfg.N_SPLITS 50 | n_jobs: int = cfg.N_JOBS 51 | seed: int = cfg.SEED 52 | max_iter: int = cfg.MAX_ITER 53 | 54 | def __post_init__(self): 55 | cv = StratifiedKFold( 56 | n_splits=self.n_splits, shuffle=True, random_state=self.seed 57 | ) 58 | self._model = LogisticRegressionCV( 59 | cv=cv, 60 | scoring=self.scoring, 61 | max_iter=self.max_iter, 62 | n_jobs=self.n_jobs, 63 | ) 64 | 65 | def fit(self, x: np.ndarray, y: np.ndarray) -> None: 66 | self._model.fit(x, y) 67 | 68 | def predict(self, x: np.ndarray) -> np.ndarray: 69 | return self._model.predict(x) 70 | 71 | def predict_proba(self, x: np.ndarray) -> np.ndarray: 72 | return self._model.predict_proba(x)[:, 1] 73 | 74 | 75 | Model = t.Union[ 76 | LogisticRegression, RandomForestClassifier, XGBClassifier, ZeroShotText2Video 77 | ] 78 | --------------------------------------------------------------------------------
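# A minimal end-to-end usage sketch, assuming the preprocessed data files
# referenced in io.py (embeddings.h5, shot-data.csv, text-embeddings.json and
# the per-label agg/cmp/checkpoints JSONs) exist under config.PATH_DATA_BASE:
#
#   from videoannotator import data, dp, experiments, models
#
#   # train and evaluate a classifier on the aggregated annotations
#   lds = data.get_aggregate_labeled_dataset(label="action")
#   lds_train, lds_val = lds.split()
#   res = experiments.run_experiment(
#       model=models.LogisticRegression(scoring="average_precision"),
#       lds_train=lds_train,
#       lds_validation=lds_val,
#   )
#   print(res.average_precision)  # Metric(mean=..., std=...)
#
#   # let a bandit decide which data source to annotate next, d=50 shots/step
#   print(dp.UCBMean(label="action").run(n=10).best_metric)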