├── .gitignore ├── Dockerfile ├── Dockerfile_mlcube ├── README.md ├── baselines ├── fpscv │ └── run.py ├── modified_uncertainty_sampling │ ├── get_tag.py │ └── run.py ├── other_experiments │ └── run.ipynb └── pseudo_label_generation │ ├── dataperf_vision_experiments.py │ └── utils.py ├── constants.py ├── data ├── embeddings │ └── .gitignore ├── examples │ ├── alpha_example_set_Cupcake.csv │ ├── alpha_example_set_Hawk.csv │ └── alpha_example_set_Sushi.csv ├── results │ ├── .gitignore │ └── result_for_random_500.json ├── test_sets │ └── .gitignore └── train_sets │ ├── .gitignore │ └── random_500.csv ├── docker-compose.yaml ├── download_data.py ├── eval.py ├── main.py ├── mlcube.py ├── mlcube.yaml ├── requirements.txt ├── run_evaluate.sh ├── selection.py ├── task_setup.yaml ├── utils.py └── workspace └── parameters.yaml /.gitignore: -------------------------------------------------------------------------------- 1 | workdir/ 2 | .vscode 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | # Scratch space 134 | scratch/ 135 | 136 | # MLCube 137 | workspace/* 138 | !workspace/parameters.yaml -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.9-slim 2 | 3 | RUN apt update -y 4 | RUN apt install default-jdk -y 5 | 6 | COPY requirements.txt /app/requirements.txt 7 | RUN pip install -r /app/requirements.txt 8 | 9 | COPY *.py /app/ 10 | COPY *.yaml /app/ 11 | 12 | WORKDIR /app/ 13 | ENTRYPOINT python3 main.py --docker_flag True 14 | -------------------------------------------------------------------------------- /Dockerfile_mlcube: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | 3 | RUN apt-get update && apt-get install -y \ 4 | software-properties-common && apt install default-jdk -y 5 | RUN apt-get install -y python3.8 python3-pip 6 | 7 | COPY requirements.txt /app/requirements.txt 8 | RUN pip3 install --upgrade pip && pip3 install -r /app/requirements.txt 9 | 10 | COPY *.py /app/ 11 | COPY *.yaml /app/ 12 | COPY *.sh /app/ 13 | 14 | WORKDIR /app/ 15 | RUN chmod +x /app/run_evaluate.sh 16 | ENTRYPOINT ["python3", "/app/mlcube.py"] 17 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # `dataperf-vision-selection`: A Data-Centric Visual Benchmark for Training Data Selection 2 | 3 | ### **Current version:** beta 4 | 5 | This GitHub repo serves as the starting point for offline evaluation of submissions for the training data selection visual benchmark. The offline evaluation can be run both in your local environment and as a containerized image for reproducibility of score results. 6 | 7 | For a detailed summary of the benchmark, refer to the provided [documentation](https://www.dataperf.org/training-set-selection-vision). 8 | 9 | *Note that permission is required to view the benchmark documentation and download the required resources. Please contact dataperf@coactive.ai to request access.* 10 | 11 | ## Requirements 12 | 13 | ### Download resources 14 | 15 | The following resources will need to be downloaded locally in order to run offline evaluation: 16 | 17 | - Embeddings for the candidate pool of training images (.parquet file) 18 | - Test sets for each classification task (.parquet files) 19 | 20 | These resources can be downloaded as a .zip file from the following URL: 21 | 22 | ``` 23 | https://drive.google.com/drive/folders/181uI-7NFJwK3IOPy2kOYVIQS4vZOC02A?usp=sharing 24 | ``` 25 | 26 | ### Install dependencies 27 | 28 | For running as a containerized image: 29 | 30 | - `docker` for building the containerized image 31 | - `docker-compose` for running the scoring service with the appropriate resources 32 | 33 | Installation instructions can be found at the following links: [Docker](https://docs.docker.com/get-docker/), [Docker compose](https://docs.docker.com/compose/install/) 34 | 35 | For running locally: 36 | 37 | - Python (>= 3.7) 38 | - An [appropriate version of Java](https://spark.apache.org/docs/latest/) for your version of `python` and `pyspark` 39 | 40 | The current version of this repo has only been tested locally on python 3.9 and java
openjdk-11. 41 | 42 | ## Installation 43 | 44 | Clone this repo to your local machine 45 | 46 | ``` 47 | git clone git@github.com:CoactiveAI/dataperf-vision-selection.git 48 | ``` 49 | 50 | If you want to run the offline evaluation in your local environment, install the required python packages 51 | 52 | ``` 53 | pip install -r dataperf-vision-selection/requirements.txt 54 | ``` 55 | 56 | A template filesystem with the following structure is provided in the repo. Move the embeddings file and the tests sets to the appropriate folders in this template filesystem 57 | 58 | ``` 59 | unzip dataperf-vision-selection-resources.zip 60 | mv dataperf-vision-selection-resources/embeddings/* dataperf-vision-selection/data/embeddings/ 61 | mv dataperf-vision-selection-resources/examples/* dataperf-vision-selection/data/examples/ 62 | mv dataperf-vision-selection-resources/test_sets/* dataperf-vision-selection/data/test_sets/ 63 | mv dataperf-vision-selection-resources/train_sets/* dataperf-vision-selection/data/train_sets/ 64 | mv dataperf-vision-selection-resources/results/* dataperf-vision-selection/data/results/ 65 | ``` 66 | 67 | The resulting filesystem in the repo should look as follows 68 | 69 | ``` 70 | |____data 71 | | |____embeddings 72 | | | |____train_emb_256_dataperf.parquet 73 | | |____examples 74 | | | |____alpha_example_set_Hawk.csv 75 | | | |____alpha_example_set_Cupcake.csv 76 | | | |____alpha_example_set_Sushi.csv 77 | | |____test_sets 78 | | | |____alpha_test_set_Hawk_256.parquet 79 | | | |____alpha_test_set_Cupcake_256.parquet 80 | | | |____alpha_test_set_Sushi_256.parquet 81 | | |____train_sets 82 | | | |____random_500.csv 83 | | |____results 84 | | | |____result_for_random_500.json 85 | ``` 86 | 87 | With the resources in place, you can now test that the system is functioning as expected. 
88 | 89 | To test the containerized offline evaluation, run 90 | 91 | ``` 92 | cd dataperf-vision-selection 93 | docker-compose up 94 | ``` 95 | 96 | Similarly, to test the local python offline evaluation, run 97 | 98 | ``` 99 | cd dataperf-vision-selection 100 | python3 main.py 101 | ``` 102 | 103 | Either test will run the offline evaluation using the setup specified in `task_setup.yaml`, which will utilize a training set of randomly sampled and labeled data points (`data/train_sets/random_500.csv`) to generate a score results file in `data/results/` with a unique UTC timestamp: 104 | 105 | ``` 106 | |____data 107 | | |____results 108 | | | |____result_for_random_500.json 109 | | | |____result_UTC-2022-03-31-20-19-24.json 110 | ``` 111 | 112 | The generated scores in this new results file should be identical to those in `data/results/result_for_random_500.json`. 113 | 114 | # MLCube execution 115 | 116 | The [MLCube](https://github.com/mlcommons/mlcube) implementation allows us to execute the project using the following steps. 117 | 118 | ## Project setup 119 | 120 | ```bash 121 | # Create Python environment and install MLCube Docker runner 122 | virtualenv -p python3 ./env && source ./env/bin/activate && pip install mlcube-docker 123 | 124 | # Fetch the vision selection repo 125 | git clone https://github.com/CoactiveAI/dataperf-vision-selection && cd ./dataperf-vision-selection 126 | ``` 127 | 128 | ## Tasks execution 129 | 130 | ```bash 131 | # Download and extract dataset 132 | mlcube run --task=download -Pdocker.build_strategy=always 133 | 134 | # Run evaluation 135 | mlcube run --task=evaluate -Pdocker.build_strategy=always 136 | ``` 137 | 138 | ## Execute complete pipeline 139 | 140 | ```bash 141 | # Run all steps 142 | mlcube run --task=download,evaluate -Pdocker.build_strategy=always 143 | ``` 144 | 145 | # Guidelines v0.5 146 | 147 | For v0.5 of this benchmark, we will support offline and online evaluation for the open division.
148 | 149 | ## Open Division: Creating a submission 150 | 151 | A valid submission for the open division includes the following: 152 | 153 | - A description of the data selection algorithm/strategy used 154 | - A training set for each classification task as specified below 155 | - (Optional) A script of the algorithm/strategy used 156 | 157 | Each training set file must be a .csv file containing two columns: `ImageID` (the unique identifier for the image) and `Confidence` (the binary label, either a `0` or `1`). The `ImageID`s in the training set files must be limited to the provided candidate pool of training images (i.e. `ImageID`s in the downloaded embeddings file). 158 | 159 | The included training set file serves as a template of a single training set: 160 | 161 | ``` 162 | cat dataperf-vision-selection/data/train_sets/random_500.csv 163 | 164 | ImageID,Confidence 165 | 0002643773a76876,0 166 | 0016a0f096337445,0 167 | 0036043ce525479b,1 168 | 00526f123f84db2f,1 169 | 0080db2599d54447,1 170 | 00978577e9fdd967,1 171 | ... 172 | ``` 173 | 174 | ## Open Division: Offline evaluation 175 | 176 | The configuration for the offline evaluation is specified in `task_setup.yaml` file. For simplicity, the repo comes pre-configured such that for offline evaluation you can simply: 177 | 178 | 1. Copy your training sets to the template filesystem 179 | 2. Modify the config file to specify the training set for each task 180 | 3. Run offline evaluation 181 | 4. See results in stdout and results file in `data/results/` 182 | 183 | For example 184 | 185 | ``` 186 | # 1. Copy training sets for each task 187 | cd dataperf-vision-selection 188 | cp /path/to/your/training/sets/Cupcake.csv data/train_sets/ 189 | cp /path/to/your/training/sets/Hawk.csv data/train_sets/ 190 | cp /path/to/your/training/sets/Sushi.csv data/train_sets/ 191 | 192 | # 2. 
task_setup.yaml: modify the training set relative path for each classification task 193 | Cupcake: ['train_sets/Cupcake.csv', 'test_sets/alpha_test_set_Cupcake_256.parquet'] 194 | Hawk: ['train_sets/Hawk.csv', 'test_sets/alpha_test_set_Hawk_256.parquet'] 195 | Sushi: ['train_sets/Sushi.csv', 'test_sets/alpha_test_set_Sushi_256.parquet'] 196 | 197 | # 3a. Run offline evaluation (docker) 198 | docker-compose up --build --force-recreate 199 | 200 | # 3b. Run offline evaluation (local python) 201 | python3 main.py 202 | 203 | # 4. See results (file will have save timestamp in name) 204 | cat data/results/result_UTC-2022-03-31-20-19-24.json 205 | 206 | { 207 | "Cupcake": { 208 | "accuracy": 0.5401459854014599, 209 | "recall": 0.463768115942029, 210 | "precision": 0.5517241379310345, 211 | "f1": 0.5039370078740157 212 | }, 213 | "Hawk": { 214 | "accuracy": 0.296551724137931, 215 | "recall": 0.16831683168316833, 216 | "precision": 0.4857142857142857, 217 | "f1": 0.25000000000000006 218 | }, 219 | "Sushi": { 220 | "accuracy": 0.5185185185185185, 221 | "recall": 0.6261682242990654, 222 | "precision": 0.638095238095238, 223 | "f1": 0.6320754716981132 224 | } 225 | } 226 | ``` 227 | 228 | Though we recommend working as described above, you can specify a custom task setup .yaml file and/or data folder if needed. 229 | 230 | For the containerized offline evaluation, modify the following files and run as follows 231 | 232 | ``` 233 | # docker-compose.yaml: modify the volume source 234 | volumes: 235 | - path/to/your/data/folder:/app/data 236 | 237 | # Dockerfile: modify the COPY *.yaml command and specify the new file in the entrypoint 238 | COPY path/to/your/custom_task_setup.yaml /app/ 239 | ... 
240 | ENTRYPOINT python3 main.py --docker_flag True --setup_yaml_path 'custom_task_setup.yaml' 241 | 242 | # Run and force rebuild 243 | docker-compose up --build --force-recreate 244 | ``` 245 | 246 | For the local python offline evaluation, modify your custom task setup .yaml file and run as follows 247 | 248 | ``` 249 | # path/to/your/custom_task_setup.yaml: modify data_dir 250 | data_dir: 'path/to/your/data/folder' 251 | 252 | # Run and specify custom .yaml file 253 | python3 main.py --setup_yaml_path 'path/to/your/custom_task_setup.yaml' 254 | ``` 255 | 256 | *Note: when specifying a data folder, ensure all relative paths in the task setup .yaml file are valid* 257 | 258 | ## Open Division: Online evaluation 259 | 260 | For final submissions, we use Dynabench as our online evaluation system. Please submit to the [Vision Dataperf task](https://dynabench.org/tasks/vision-selection). 261 | 262 | 263 | ## Closed Division: Creating a submission 264 | 265 | TBD. 266 | 267 | ## Closed Division: Offline evaluation of a submission 268 | 269 | TBD. 270 | 271 | # Baselines 272 | 273 | In the `baselines/` directory, we include the winning submissions for the beta version of this challenge, which also serve as our baselines. The methods are: 274 | 1. FPSCV (Farthest Point Sampling Cross-Validation) by Paolo Climaco: This method selects negative examples by attempting to sample the feature search space through iterative maximum l2 distances, and then 275 | returns the best coreset under nested cross-validation. The official repository for this submission can be found at [Submissions-Dataperf-Vision-Challenge](https://github.com/PaClimaco/Submissions-Dataperf-Vision-Challenge/), which also includes a detailed report on this method. 276 | 2.
Pseudo Label Generation by Danilo Brajovic: This method trains multiple neural networks and classical models on a subset of data to classify the remainder of points and uses the best-performing model for coreset proposal under multiple sampling experiments 277 | 3. Modified Uncertainty Sampling by Steve Mussmann: This method trains a binary classifier on noisy positive labels from OpenImages and uses this classifier to assign positive and negative image pools, with the coreset randomly sampled from both pools. 278 | Rafael Mosquera experiemnted with some classical and neural network based models, which we have included under `other_experiments/`. 279 | 280 | ``` 281 | # Run the baseline 282 | python3 baselines/fpscv/run.py 283 | python3 baselines/pseudo_label_generation/run.py 284 | python3 baselines/modified_uncertainty_sampling/run.py 285 | ``` -------------------------------------------------------------------------------- /baselines/fpscv/run.py: -------------------------------------------------------------------------------- 1 | # Author: Paolo Climaco 2 | # Author email: climaco@ins.uni-bonn.de 3 | 4 | import numpy as np 5 | from operator import itemgetter 6 | import torch 7 | from dgl.geometry import farthest_point_sampler 8 | from tqdm import tqdm 9 | import pandas as pd 10 | import numpy as np 11 | import shutil 12 | from typing import Dict, List 13 | import sklearn 14 | import sklearn.model_selection 15 | import sklearn.ensemble 16 | import sklearn.linear_model 17 | import tqdm 18 | import os 19 | import csv 20 | import json 21 | from sklearn.metrics import f1_score 22 | import warnings 23 | import requests 24 | 25 | with warnings.catch_warnings(): 26 | warnings.simplefilter("ignore") 27 | 28 | # ------------------------------------------ 29 | # PATHS 30 | # ----------------------------------------- 31 | 32 | embeddings_path = "../../data/embeddings/train_emb_256_dataperf.parquet" 33 | examples_path = "../../data/examples/" 34 | output_path = 
"output_repo/" # Change 35 | machine_human_labels_path = os.path.join( 36 | output_path, "oidv6-train-annotations-human-imagelabels.csv" 37 | ) # DO NOT CHANGE 38 | 39 | # Create output directory if does not exist 40 | if not os.path.exists(output_path): 41 | os.makedirs(output_path) 42 | 43 | 44 | # ------------------------------------------ 45 | # FUNCTIONS 46 | # ----------------------------------------- 47 | def download_csv_from_url(url, local_filename): 48 | try: 49 | response = requests.get(url, stream=True) 50 | response.raise_for_status() 51 | with open(local_filename, "wb") as f: 52 | for chunk in response.iter_content(chunk_size=8192): 53 | f.write(chunk) 54 | print(f"File downloaded and saved as: {local_filename}") 55 | except requests.exceptions.HTTPError as e: 56 | print(f"HTTP Error: {e}") 57 | except requests.exceptions.RequestException as e: 58 | print(f"Request Exception: {e}") 59 | 60 | 61 | def read_parquet(path=str): 62 | return pd.read_parquet(path, engine="pyarrow", use_threads=True) 63 | 64 | 65 | def get_labelled_samples(task=None, mhl=None, embeddings=None, examples_path=None): 66 | # load examples 67 | examples_df = pd.read_csv( 68 | os.path.join(examples_path, f"alpha_example_set_{task}.csv") 69 | ) 70 | # get label name examples 71 | LabelName = examples_df["LabelName"][0] 72 | 73 | mydf = mhl[mhl["ImageID"].isin(embeddings["ImageID"])] 74 | labelled_samples_df = mydf[mydf["LabelName"] == LabelName] 75 | labelled_samples_pd = pd.DataFrame( 76 | embeddings[embeddings["ImageID"].isin(labelled_samples_df["ImageID"])] 77 | ) 78 | labelled_samples_pd.loc[:, "Confidence"] = np.asarray( 79 | [ 80 | labelled_samples_df[labelled_samples_df["ImageID"] == image][ 81 | "Confidence" 82 | ].item() 83 | for image in labelled_samples_pd["ImageID"] 84 | ] 85 | ) 86 | labelled_samples_vect = np.stack(labelled_samples_pd["embedding"].to_numpy()) 87 | # return only those with confidence 1 88 | return ( 89 | 
labelled_samples_vect[labelled_samples_pd["Confidence"] == 1], 90 | labelled_samples_pd[labelled_samples_pd["Confidence"] == 1], 91 | ) 92 | 93 | 94 | def csv_to_json(csv, json_file_path): 95 | data = [] 96 | 97 | data = csv 98 | dictio = {} 99 | for i, c in zip(data["ImageID"], data["Confidence"]): 100 | dictio[i] = int(c) 101 | 102 | # Write JSON data to a file 103 | with open(json_file_path, "w") as json_file: 104 | json_file.write(json.dumps(dictio, indent=4)) 105 | 106 | 107 | # ------------------------------------------ 108 | # DOWNLOAD EMBEDDINGS 109 | # ----------------------------------------- 110 | print("Loading embeddings") 111 | embeddings = read_parquet(embeddings_path) 112 | embeddings_vect = np.stack(embeddings["embedding"].to_numpy()) 113 | print("Embeddings loaded") 114 | 115 | 116 | # ------------------------------------------ 117 | # LOAD MACHINE-HUMAN LABELS 118 | # ----------------------------------------- 119 | 120 | try: 121 | mhl = pd.read_csv(machine_human_labels_path) 122 | except FileNotFoundError: 123 | print("loading human-verified labels...") 124 | url = "https://storage.googleapis.com/openimages/v6/oidv6-train-annotations-human-imagelabels.csv" 125 | download_csv_from_url(url, machine_human_labels_path) 126 | mhl = pd.read_csv(machine_human_labels_path) 127 | 128 | 129 | # ------------------------------------------ 130 | # FARTHEST POINT SAMPLING: it takes around 17 minutes but it is a one time effort. 
SAVE AND LOAD 131 | # ----------------------------------------- 132 | 133 | try: 134 | idx_fps = np.load( 135 | os.path.join(output_path, "FPS_CrossV_init_0.npy") 136 | ) # np.load('/local/hdd/climaco/dataperf-vision-selection/data/selections_folder/FPS_CrossV_init_0.npy') 137 | except FileNotFoundError: 138 | print("FPS running...it takes a few minutes but it is a one-time effort") 139 | t_emb = torch.from_numpy(embeddings_vect).unsqueeze_(0) 140 | idx_fps = farthest_point_sampler(t_emb, 1000, start_idx=0).numpy()[0].tolist() 141 | np.save(os.path.join(output_path, "FPS_CrossV_init_0.npy"), idx_fps) 142 | idx_fps = np.asarray(idx_fps) 143 | 144 | 145 | # ------------------------------------------ 146 | # SELECTION ALGORITHM 147 | # ------------------------------------------ 148 | 149 | 150 | def Selection( 151 | task="None", 152 | n_confidence1=5, 153 | n_confidence0=40, 154 | n_folds=5, 155 | mhl=mhl, 156 | examples_path=examples_path, 157 | embeddings=embeddings, 158 | output_path=output_path, 159 | idx_fps=idx_fps, 160 | ): 161 | """ 162 | Parameters 163 | ---------- 164 | 165 | task : string 166 | The name of the target class, e.g. Hawk, Cupcake or Sushi. 
167 | 168 | n_confidence1 : int 169 | Number of points in the final selected set with confidence 1 170 | 171 | n_confidence0 : int 172 | Number of points in the final selected set with confidence 0 173 | 174 | n_folds : int 175 | Number of folds in the nested cross validation 176 | 177 | mhl : pandas data frame 178 | Dataframe containing machine-generated labels and human annotations of images in oidv6 training set 179 | 180 | examples_path : string 181 | Path the the classes examples provided within the challenge 182 | 183 | embeddings : pandas data frame 184 | Data frame containing the available pool of data points provided within the challenge 185 | 186 | output_path : string 187 | Path to folder where algorithm's output is stored 188 | 189 | idx_fps : ndarray 190 | 191 | Returns 192 | ---------- 193 | csv and .json files containing selected images IDs and associated confidence (0 or 1) 194 | 195 | 196 | """ 197 | 198 | # get elements confidence 1 199 | confidence1_samples_vect, confidence1_samples_df = get_labelled_samples( 200 | task, mhl, embeddings, examples_path 201 | ) 202 | 203 | # Idxs selected elements with FPS sapling algorithm 204 | 205 | # Candidate with confidence 0 and candidates with confiddence 1 206 | confidence1_samples = confidence1_samples_vect 207 | confidence0_samples = embeddings_vect[idx_fps] 208 | confidence1_labels = np.ones(confidence1_samples.shape[0]) 209 | confidence0_labels = np.zeros(confidence0_samples.shape[0]) 210 | 211 | # Cross validation 212 | best_score = 0 213 | best_confidence1_train_ixs = None 214 | best_confidence0_train_ixs = None 215 | confidence0_class_size = n_confidence0 216 | confidence1_class_size = n_confidence1 217 | 218 | crossfold_confidence1 = sklearn.model_selection.StratifiedShuffleSplit( 219 | n_splits=n_folds, 220 | train_size=confidence1_class_size, 221 | random_state=23, 222 | ) 223 | 224 | for confidence1_train_ixs, confidence1_val_ixs in tqdm.tqdm( 225 | 
crossfold_confidence1.split(confidence1_samples, confidence1_labels), 226 | desc="k-fold cross validation", 227 | total=n_folds, 228 | ): 229 | crossfold_confidence0 = sklearn.model_selection.StratifiedShuffleSplit( 230 | n_splits=n_folds, 231 | train_size=confidence0_class_size, 232 | random_state=23, 233 | ) 234 | for confidence0_train_ixs, confidence0_val_ixs in crossfold_confidence0.split( 235 | confidence0_samples, confidence0_labels 236 | ): 237 | train_Xs = np.vstack( 238 | [ 239 | confidence1_samples[confidence1_train_ixs], 240 | confidence0_samples[confidence0_train_ixs], 241 | ] 242 | ) 243 | train_ys = np.concatenate( 244 | [ 245 | confidence1_labels[confidence1_train_ixs], 246 | confidence0_labels[confidence0_train_ixs], 247 | ] 248 | ) 249 | 250 | clf = sklearn.ensemble.VotingClassifier( 251 | estimators=[ 252 | ("lr", sklearn.linear_model.LogisticRegression()), 253 | ], 254 | voting="soft", 255 | weights=None, 256 | n_jobs=-1, 257 | ) 258 | clf.fit(train_Xs, train_ys) 259 | 260 | val_Xs = np.vstack( 261 | [ 262 | confidence1_samples[confidence1_val_ixs], 263 | confidence0_samples[confidence0_val_ixs], 264 | ] 265 | ) 266 | val_ys = np.concatenate( 267 | [ 268 | confidence1_labels[confidence1_val_ixs], 269 | confidence0_labels[confidence0_val_ixs], 270 | ] 271 | ) 272 | 273 | pred_Ys = clf.predict(val_Xs) 274 | 275 | score = f1_score(val_ys, pred_Ys, labels=[0, 1], average="binary") 276 | if score > best_score: 277 | best_score = score 278 | best_confidence1_train_ixs = confidence1_train_ixs 279 | best_confidence0_train_ixs = confidence0_train_ixs 280 | 281 | # store selected elements in csv 282 | submission_confidence1 = pd.DataFrame( 283 | confidence1_samples_df.iloc[best_confidence1_train_ixs]["ImageID"] 284 | ) 285 | submission_confidence1.loc[:, "Confidence"] = np.ones( 286 | len(submission_confidence1), dtype=int 287 | ).tolist() 288 | 289 | submission_confidence0 = pd.DataFrame( 290 | 
embeddings.iloc[idx_fps[best_confidence0_train_ixs]]["ImageID"] 291 | ) 292 | submission_confidence0.loc[:, "Confidence"] = np.zeros( 293 | len(submission_confidence0), dtype=int 294 | ).tolist() 295 | 296 | submission = pd.concat([submission_confidence1, submission_confidence0]) 297 | 298 | # seve selected elements in csv and json format 299 | submission.to_csv(os.path.join(output_path, f"{task}_fpscv_ft2.csv"), index=False) 300 | csv_to_json(submission, os.path.join(output_path, f"{task}_fpscv_ft2.json")) 301 | 302 | 303 | # ------------------------------------------ 304 | # RUN SELECTION ALGORITHM 305 | # ------------------------------------------ 306 | 307 | # fpscv average F1 score 78.10% 308 | """ 309 | print('Hawk') 310 | Selection('Hawk') 311 | 312 | print('Cupcake') 313 | Selection('Cupcake') 314 | 315 | print('Sushi') 316 | Selection('Sushi') 317 | """ 318 | # fpscv_ft average F1 score 79.95% 319 | """ 320 | print('Hawk') 321 | Selection('Hawk', n_confidence1=5, n_confidence0=20) 322 | 323 | print('Cupcake') 324 | Selection('Cupcake', n_confidence1=5, n_confidence0=40) 325 | 326 | print('Sushi') 327 | Selection('Sushi', n_confidence1=10, n_confidence0=20) 328 | """ 329 | 330 | # fpscv_ft2 average F1 score 81% 331 | print("Hawk") 332 | Selection("Hawk", n_confidence1=5, n_confidence0=90, n_folds=4) 333 | 334 | print("Cupcake") 335 | Selection("Cupcake", n_confidence1=5, n_confidence0=40) 336 | 337 | print("Sushi") 338 | Selection("Sushi", n_confidence1=10, n_confidence0=20) 339 | 340 | 341 | # ---------------------------------------------- 342 | # how to improve fpscv algorithm 343 | # ---------------------------------------------- 344 | 345 | """ 346 | Notice that the fpscv, fpscv_ft and fpscv_ft2 are implemented using the same algorithm that has been initialized 347 | considering different values for the parameters n_confidence1, n_confidence0, and n_folds. Different choices of these 348 | parameters lead to different results. 
349 | 350 | 351 | The main issue of the developed approach is its sensitivity to the choices of the parameters 352 | n_confidence1, n_confidence0 and n_folds. Such choices have been made heuristically. Can we do better than heuristics? 353 | Developing a principled approach for the optimization of the mentioned parameters may lead to a substantial improvement in terms 354 | of the F1 score on each of the considered tasks. 355 | 356 | If you develop such a principled approach, please do not hesitate to contact me. I would be curious to hear about it. 357 | 358 | """ 359 | -------------------------------------------------------------------------------- /baselines/modified_uncertainty_sampling/get_tag.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import requests 4 | 5 | if False: 6 | target_tag = "/m/03p1r4" 7 | output_file = "cupcake.txt" 8 | if False: 9 | target_tag = "/m/0fp7c" 10 | output_file = "hawk.txt" 11 | if True: 12 | target_tag = "/m/07030" 13 | output_file = "sushi_human.txt" 14 | 15 | 16 | # ------------------------------------------ 17 | # FUNCTIONS 18 | # ----------------------------------------- 19 | def download_csv_from_url(url, local_filename): 20 | try: 21 | response = requests.get(url, stream=True) 22 | response.raise_for_status() 23 | with open(local_filename, "wb") as f: 24 | for chunk in response.iter_content(chunk_size=8192): 25 | f.write(chunk) 26 | print(f"File downloaded and saved as: {local_filename}") 27 | except requests.exceptions.HTTPError as e: 28 | print(f"HTTP Error: {e}") 29 | except requests.exceptions.RequestException as e: 30 | print(f"Request Exception: {e}") 31 | 32 | 33 | # ------------------------------------------ 34 | # LOAD MACHINE-HUMAN LABELS 35 | # ----------------------------------------- 36 | 37 | output_path = "output_repo/" # Change 38 | machine_human_labels_path = os.path.join( 39 | output_path,
"oidv6-train-annotations-human-imagelabels.csv" 40 | ) # DO NOT CHANGE 41 | try: 42 | mhl = pd.read_csv(machine_human_labels_path) 43 | except FileNotFoundError: 44 | print("loading human-verified labels...") 45 | os.makedirs(output_path, exist_ok=True) # the download target directory may not exist yet 46 | url = "https://storage.googleapis.com/openimages/v6/oidv6-train-annotations-human-imagelabels.csv" 47 | download_csv_from_url(url, machine_human_labels_path) 48 | mhl = pd.read_csv(machine_human_labels_path) 49 | # NOTE(review): mhl is not used below; the tag lists are built from annotations-human.csv / annotations-machine.csv 50 | human_labels = [] 51 | with open("annotations-human.csv", "r") as f: 52 | for line in f: 53 | image_id, _, tag, value = line.strip().split(",") 54 | 55 | if tag == target_tag and float(value) >= 0.5: 56 | human_labels.append(image_id) 57 | 58 | machine_labels = [] 59 | with open("annotations-machine.csv", "r") as f: 60 | for line in f: 61 | image_id, _, tag, value = line.strip().split(",") 62 | 63 | if tag == target_tag and float(value) >= 0.5: 64 | machine_labels.append(image_id) 65 | 66 | 67 | print(len(human_labels)) 68 | print(len(machine_labels)) 69 | 70 | print(len(set(human_labels).intersection(set(machine_labels)))) 71 | 72 | if True: 73 | with open(output_file, "w") as f: 74 | for image_id in set(human_labels): 75 | f.write(image_id + "\n") 76 | 77 | 78 | if False: 79 | with open(output_file, "w") as f: 80 | for image_id in set(human_labels + machine_labels): 81 | f.write(image_id + "\n") 82 | 83 | -------------------------------------------------------------------------------- /baselines/modified_uncertainty_sampling/run.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import OED 4 | from sklearn.linear_model import LogisticRegression 5 | 6 | task = "hawk" 7 | strategy = "diverse" 8 | 9 | 10 | with open("data/embedding_image_ids.pkl","rb") as f: 11 | image_ids = pickle.load(f) 12 | print(len(image_ids)) 13 | 14 | imid_to_idx = {} 15 | for i in range(len(image_ids)): 16 | imid_to_idx[image_ids[i]] = i 17 | 18 | with
open("data/embeddings.npy","rb") as f: 19 | embeddings = np.load(f) 20 | print(embeddings.shape) 21 | 22 | embeddings = embeddings - np.mean(embeddings, axis=0) 23 | 24 | #cov = embeddings.T @ embeddings / embeddings.shape[0] 25 | #print(np.linalg.eigvalsh(cov)) 26 | 27 | with open("data/"+task+".txt","r") as f: 28 | positives = [] 29 | for line in f.readlines(): 30 | positives.append(line.strip()) 31 | print(len(positives)) 32 | 33 | embedded_positives = list(set(positives).intersection(set(image_ids))) 34 | print(len(embedded_positives)) 35 | 36 | 37 | 38 | n, d = embeddings.shape 39 | 40 | labels = np.zeros((n,)) 41 | positive_idx = [] 42 | for pos in embedded_positives: 43 | positive_idx.append(imid_to_idx[pos]) 44 | labels[positive_idx] = 1 45 | 46 | model = LogisticRegression(C = 10**6, max_iter = 1000) 47 | 48 | print("starting fit") 49 | if False: 50 | model.fit(embeddings,labels) 51 | with open("data/model_"+task+".pkl","wb") as f: 52 | pickle.dump(model,f) 53 | else: 54 | with open("data/model_"+task+".pkl","rb") as f: 55 | model = pickle.load(f) 56 | print("finished fit") 57 | fit_labels = model.predict(embeddings) 58 | 59 | TP = np.sum( fit_labels[positive_idx] == 1) 60 | FN = np.sum( fit_labels[positive_idx] != 1) 61 | FP = np.sum( np.logical_and(fit_labels == 1, labels == 0) ) 62 | 63 | print("True positives: {}\nFalse negatives: {}\nFalse positives: {}".format(TP,FN,FP)) 64 | 65 | 66 | probs = model.predict_proba(embeddings)[:,1] 67 | 68 | pp = n//100 69 | 70 | if task == "sushi": 71 | topp = np.argsort(probs)[-5*pp:-2*pp] 72 | if task == "cupcake": 73 | topp = np.argsort(probs)[-5*pp:-2*pp] 74 | if task == "hawk": 75 | topp = np.argsort(probs)[-5*pp:-2*pp] 76 | 77 | #if task == "sushi": 78 | # close_idx = np.logical_and(probs >= 10**-6,probs < 10**-4).nonzero()[0] 79 | #elif task == "cupcake": 80 | # close_idx = np.logical_and(probs >= 10**-6,probs < 10**-5).nonzero()[0] 81 | #elif task == "hawk": 82 | # close_idx = np.logical_and(probs >= 
10**-5,probs < 10**-4).nonzero()[0] 83 | if task == "cupcake": 84 | true_idx = np.logical_and(probs >= 0.5, labels==1).nonzero()[0] 85 | if task == "sushi": 86 | true_idx = (labels==1).nonzero()[0] 87 | if task == "hawk": 88 | true_idx = np.logical_and(probs >= 0.5, labels==1).nonzero()[0] 89 | 90 | close_idx = list(set(topp).difference(set(true_idx))) 91 | 92 | print(len(close_idx)) 93 | print(len(true_idx)) 94 | 95 | if strategy == "even": 96 | with open("data/train_sets/"+task+"_even.csv","w") as f: 97 | f.write("ImageID,Confidence\n") 98 | 99 | true_labels = min(500, len(true_idx)) 100 | 101 | for idx in np.random.choice(true_idx,size=(true_labels,),replace=False): 102 | f.write("{},1\n".format(image_ids[idx])) 103 | 104 | for idx in np.random.choice(close_idx,size=(1000-true_labels,),replace=False): 105 | f.write("{},0\n".format(image_ids[idx])) 106 | 107 | if strategy == "diverse": 108 | with open("data/train_sets/"+task+"_27.csv","w") as f: 109 | f.write("ImageID,Confidence\n") 110 | 111 | if task == "cupcake": 112 | true_labels = 500 113 | if task == "sushi": 114 | true_labels = 500 115 | if task == "hawk": 116 | true_labels = 333 117 | 118 | 119 | true_embeddings = embeddings[true_idx,:] 120 | design_points = OED.random(true_embeddings,true_labels) 121 | selected_idx = [true_idx[point] for point in design_points] 122 | 123 | for idx in selected_idx: 124 | f.write("{},1\n".format(image_ids[idx])) 125 | 126 | 127 | close_embeddings = embeddings[close_idx,:] 128 | design_points = OED.random(close_embeddings,1000-true_labels) 129 | selected_idx = [close_idx[point] for point in design_points] 130 | 131 | for idx in selected_idx: 132 | f.write("{},0\n".format(image_ids[idx])) 133 | 134 | 135 | 136 | 137 | print("done!") 138 | -------------------------------------------------------------------------------- /baselines/other_experiments/run.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | 
"cell_type": "code", 5 | "execution_count": 1, 6 | "id": "d0ef2ba7-b0cd-436a-bdde-716e47f1c67f", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import pandas as pd\n", 11 | "import numpy as np" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "id": "d7a5c8b4-6f1a-4dfe-8a79-e66ea6ddd3aa", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "path = '/Users/rafael/Downloads/dataperf-vision-selection-resources/embeddings/train_emb_256_dataperf.parquet/part-00000-c487d15f-d29d-4c93-96c2-4caf4c12cd59-c000.snappy.parquet'" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 3, 27 | "id": "f9c17b76-ce54-43a0-b673-5b8ee297a5d9", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "df = pd.read_parquet(path)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 4, 37 | "id": "6edd744e-4da3-4aad-8432-88383fd19e4f", 38 | "metadata": { 39 | "tags": [] 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "for i in range(10):\n", 44 | " parquet_num = str(i).zfill(5)\n", 45 | " path = f'/Users/rafael/Downloads/dataperf-vision-selection-resources/embeddings/train_emb_256_dataperf.parquet/part-{parquet_num}-c487d15f-d29d-4c93-96c2-4caf4c12cd59-c000.snappy.parquet'\n", 46 | " df2 = pd.read_parquet(path)\n", 47 | " df = df.append(df2)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 20, 53 | "id": "948a5074-99c6-46fc-aed3-f09d2ed26b67", 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "data": { 58 | "text/html": [ 59 | "
\n", 60 | "\n", 73 | "\n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | "
ImageIDembedding
00001771c86cadcd6[0.7272707919841341, 1.336483982316891, -0.164...
10003db14e5cdc1c4[1.7126857273368188, -0.11685025791386651, -1....
200080e248756d4dd[0.8326104343431097, 0.4016596514351839, 1.017...
300094d307650046a[0.3379979616695768, -0.3716915576725939, -0.9...
4000b76a9b80ba43a[1.0386764605271153, 1.0388317175923953, 0.425...
.........
16677ffee8a942d5f46f5[-0.4032568158266959, -1.1216809493878481, 0.3...
16678fff5531bf50e14ff[-1.3310225644952542, 0.6147678624111613, -0.7...
16679fff8991fe2ddeb7e[-1.085723378698474, -0.998268968641006, -0.29...
16680fffd06d7764cfc52[-1.8009028307864514, 1.2498299457095756, -0.3...
16681fffd81ba57715c4c[-0.12192140594927368, -1.9914731328587736, -0...
\n", 139 | "

182547 rows × 2 columns

\n", 140 | "
" 141 | ], 142 | "text/plain": [ 143 | " ImageID embedding\n", 144 | "0 0001771c86cadcd6 [0.7272707919841341, 1.336483982316891, -0.164...\n", 145 | "1 0003db14e5cdc1c4 [1.7126857273368188, -0.11685025791386651, -1....\n", 146 | "2 00080e248756d4dd [0.8326104343431097, 0.4016596514351839, 1.017...\n", 147 | "3 00094d307650046a [0.3379979616695768, -0.3716915576725939, -0.9...\n", 148 | "4 000b76a9b80ba43a [1.0386764605271153, 1.0388317175923953, 0.425...\n", 149 | "... ... ...\n", 150 | "16677 ffee8a942d5f46f5 [-0.4032568158266959, -1.1216809493878481, 0.3...\n", 151 | "16678 fff5531bf50e14ff [-1.3310225644952542, 0.6147678624111613, -0.7...\n", 152 | "16679 fff8991fe2ddeb7e [-1.085723378698474, -0.998268968641006, -0.29...\n", 153 | "16680 fffd06d7764cfc52 [-1.8009028307864514, 1.2498299457095756, -0.3...\n", 154 | "16681 fffd81ba57715c4c [-0.12192140594927368, -1.9914731328587736, -0...\n", 155 | "\n", 156 | "[182547 rows x 2 columns]" 157 | ] 158 | }, 159 | "execution_count": 20, 160 | "metadata": {}, 161 | "output_type": "execute_result" 162 | } 163 | ], 164 | "source": [ 165 | "df" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 15, 171 | "id": "46290559-0c60-412d-986c-ac65395e7bea", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "import boto3" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 16, 181 | "id": "fd00414c-2760-4316-ab94-9fe7a3459b58", 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "session = boto3.Session(profile_name = 'MLCommons', region_name = \"us-west-1\")\n", 186 | "s3_client = session.client('s3')" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 45, 192 | "id": "c4bccb48-9779-4fa6-9422-4c4843ec62b5", 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "available_ids = []" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 46, 202 | "id": "31d644ea-9d26-48e3-82c4-f8d4c9b8fae8", 
203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "for i in s3_client.list_objects_v2(\n", 207 | " Bucket = 'vision-dataperf',\n", 208 | " Prefix = 'public_train_dataset_256embeddings_dynabench_formatted/train')['Contents']:\n", 209 | " #StartAfter = 'public_train_dataset_256embeddings_dynabench_formatted/train0053e81841897c73.npy' )['Contents']:\n", 210 | " id = i['Key']\n", 211 | " id = id.strip('public_train_dataset_256embeddings_dynabench_formatted/train')\n", 212 | " id = id.strip('.npy')\n", 213 | " if df.loc[df['ImageID'] == id].empty:\n", 214 | " continue\n", 215 | " else:\n", 216 | " available_ids.append(id)\n", 217 | " #print(df.loc[df['ImageID'] == id])" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 89, 223 | "id": "537cf93d-e270-4fed-b97a-ad10ce24610a", 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "n = 1000" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 90, 233 | "id": "66b2c742-5f77-4844-b62f-ce9d01053e19", 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "hawks = {}\n", 238 | "for i in random.sample(available_ids,n):\n", 239 | " if (bool(random.getrandbits(1))):\n", 240 | " hawks[i] = 1\n", 241 | " else:\n", 242 | " hawks[i] = 0\n", 243 | "hawks = {'Hawks': hawks}\n", 244 | "with open(\"samples/Hawks.json\", \"w\") as outfile:\n", 245 | " json.dump(hawks, outfile)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 91, 251 | "id": "b2f959f3-b6de-4f24-9213-31cee1684ee9", 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "sushi = {}\n", 256 | "for i in random.sample(available_ids,n):\n", 257 | " if (bool(random.getrandbits(1))):\n", 258 | " sushi[i] = 1\n", 259 | " else:\n", 260 | " sushi[i] = 0\n", 261 | "sushi = {'Sushi': sushi}\n", 262 | "with open(\"samples/Sushi.json\", \"w\") as outfile:\n", 263 | " json.dump(sushi, outfile)" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | 
"execution_count": 92, 269 | "id": "1cbe05e9-434a-4078-a7bf-5726e570f9b2", 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "cupcake = {}\n", 274 | "for i in random.sample(available_ids,n):\n", 275 | " if (bool(random.getrandbits(1))):\n", 276 | " cupcake[i] = 1\n", 277 | " else:\n", 278 | " cupcake[i] = 0\n", 279 | "cupcake = {'Cupcake': cupcake}\n", 280 | "with open(\"samples/Cupcake.json\", \"w\") as outfile:\n", 281 | " json.dump(cupcake, outfile)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 6, 287 | "id": "5ff94e29-d09d-443c-834f-cedac83461a9", 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "import pickle" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 205, 297 | "id": "60b3acda-59e3-41c1-bfc8-06940db5bff8", 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "with open('outfile', 'wb') as fp:\n", 302 | " pickle.dump(set(available_ids), fp)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 53, 308 | "id": "437d2f25-61c7-49eb-95c5-cbda837a3676", 309 | "metadata": {}, 310 | "outputs": [], 311 | "source": [ 312 | "with open ('outfile', 'rb') as fp:\n", 313 | " available_ids = list(pickle.load(fp))" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 55, 319 | "id": "4b3b482a-7e51-471a-a45d-0b065e567e38", 320 | "metadata": {}, 321 | "outputs": [], 322 | "source": [ 323 | "import random\n", 324 | "import json" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 56, 330 | "id": "c6729787-ff3e-48d0-b174-afd9e63db753", 331 | "metadata": {}, 332 | "outputs": [], 333 | "source": [ 334 | "n = 1000" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 57, 340 | "id": "e9493083-f67d-4902-b97a-aacbe08d5e8a", 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "birds = {}\n", 345 | "for i in random.sample(available_ids,n):\n", 346 | " if 
(bool(random.getrandbits(1))):\n", 347 | " birds[i] = 1\n", 348 | " else:\n", 349 | " birds[i] = 0\n", 350 | "birds = {'Birds': birds}\n", 351 | "with open(\"samples/Birds.json\", \"w\") as outfile:\n", 352 | " json.dump(birds, outfile)" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": 58, 358 | "id": "582e0970-faaf-4a08-9c60-b0b7e9e97f7d", 359 | "metadata": {}, 360 | "outputs": [], 361 | "source": [ 362 | "canoe = {}\n", 363 | "for i in random.sample(available_ids,n):\n", 364 | " if (bool(random.getrandbits(1))):\n", 365 | " canoe[i] = 1\n", 366 | " else:\n", 367 | " canoe[i] = 0\n", 368 | "canoe = {'Canoe': canoe}\n", 369 | "with open(\"samples/Canoe.json\", \"w\") as outfile:\n", 370 | " json.dump(canoe, outfile)" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": 60, 376 | "id": "bd9de99e-213e-426f-b9e6-9e05f9ac4fe4", 377 | "metadata": {}, 378 | "outputs": [], 379 | "source": [ 380 | "croissant = {}\n", 381 | "for i in random.sample(available_ids,n):\n", 382 | " if (bool(random.getrandbits(1))):\n", 383 | " croissant[i] = 1\n", 384 | " else:\n", 385 | " croissant[i] = 0\n", 386 | "croissant = {'Croissant': croissant}\n", 387 | "with open(\"samples/Croissant.json\", \"w\") as outfile:\n", 388 | " json.dump(croissant, outfile)" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 61, 394 | "id": "f79f8173-ff0c-46ef-93fb-17ca85bea5a1", 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [ 398 | "muffin = {}\n", 399 | "for i in random.sample(available_ids,n):\n", 400 | " if (bool(random.getrandbits(1))):\n", 401 | " muffin[i] = 1\n", 402 | " else:\n", 403 | " muffin[i] = 0\n", 404 | "muffin = {'Muffin': muffin}\n", 405 | "with open(\"samples/Muffin.json\", \"w\") as outfile:\n", 406 | " json.dump(muffin, outfile)" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": 62, 412 | "id": "f14f619b-e331-44e4-887f-73ec8afdc865", 413 | "metadata": {}, 414 | 
"outputs": [], 415 | "source": [ 416 | "pizza = {}\n", 417 | "for i in random.sample(available_ids,n):\n", 418 | " if (bool(random.getrandbits(1))):\n", 419 | " pizza[i] = 1\n", 420 | " else:\n", 421 | " pizza[i] = 0\n", 422 | "pizza = {'Pizza': pizza}\n", 423 | "with open(\"samples/Pizza.json\", \"w\") as outfile:\n", 424 | " json.dump(pizza, outfile)" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": 136, 430 | "id": "c45054cd-6ff7-46b0-8674-9115f8715286", 431 | "metadata": {}, 432 | "outputs": [ 433 | { 434 | "name": "stdout", 435 | "output_type": "stream", 436 | "text": [ 437 | "50 50\n" 438 | ] 439 | } 440 | ], 441 | "source": [ 442 | "count = 0\n", 443 | "for id in list(croissant[\"Croissant\"].keys()):\n", 444 | "\n", 445 | " if df.loc[df['ImageID'] == id].empty:\n", 446 | " continue\n", 447 | " else:\n", 448 | " count += 1\n", 449 | "print(count, len(croissant[\"Croissant\"]))" 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": 138, 455 | "id": "5e9ddcf9-fab1-4680-9baf-092dd0c017c1", 456 | "metadata": {}, 457 | "outputs": [ 458 | { 459 | "data": { 460 | "text/html": [ 461 | "
\n", 462 | "\n", 475 | "\n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | "
ImageIDembedding
0000335f65ee227cf[-1.383943737943668, -0.9189053270236168, -0.1...
\n", 491 | "
" 492 | ], 493 | "text/plain": [ 494 | " ImageID embedding\n", 495 | "0 000335f65ee227cf [-1.383943737943668, -0.9189053270236168, -0.1..." 496 | ] 497 | }, 498 | "execution_count": 138, 499 | "metadata": {}, 500 | "output_type": "execute_result" 501 | } 502 | ], 503 | "source": [ 504 | "id = '000335f65ee227cf'\n", 505 | "df.loc[df['ImageID'] == id]" 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": null, 511 | "id": "cadedc40-4a3f-435e-9672-7c722a7fc193", 512 | "metadata": {}, 513 | "outputs": [], 514 | "source": [ 515 | "parquet_paths = ['/Users/rafael/Downloads/part-00000-fed1d8f9-f6ab-4f54-88aa-9908abd59bec-c000.snappy.parquet','/Users/rafael/Downloads/part-00001-fed1d8f9-f6ab-4f54-88aa-9908abd59bec-c000.snappy.parquet']\n", 516 | "df = pd.read_parquet(parquet_paths[0])\n", 517 | "df2 = pd.read_parquet(parquet_paths[1])\n", 518 | "df = df.append(df2)\n", 519 | "df" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": null, 525 | "id": "36563be4-a891-4ae9-91f1-922049477bd6", 526 | "metadata": {}, 527 | "outputs": [], 528 | "source": [] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": 15, 533 | "id": "b8ca8457-5bcc-4121-9cce-637749b8c4cc", 534 | "metadata": {}, 535 | "outputs": [], 536 | "source": [ 537 | "muffin2 = {}\n", 538 | "for i in range(51):\n", 539 | " if i%2 == 1:\n", 540 | " muffin2[df['ImageID'].values[i]] = 1\n", 541 | " else:\n", 542 | " muffin2[df['ImageID'].values[i]] = 0\n", 543 | "muffin2 = {'Muffin': muffin2}\n", 544 | "with open(\"Muffin2.json\", \"w\") as outfile:\n", 545 | " json.dump(muffin2, outfile)" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": 157, 551 | "id": "ded50481-c85d-4b4e-b3f3-049c5b691b31", 552 | "metadata": {}, 553 | "outputs": [ 554 | { 555 | "data": { 556 | "text/plain": [ 557 | "{'dataperf_f1': 54.4,\n", 558 | " 'perf': 54.4,\n", 559 | " 'perf_std': 0.0,\n", 560 | " 'perf_by_tag': [{'tag': 'bird',\n", 561 | " 'pretty_perf': 
'54.4 %',\n", 562 | " 'perf': 54.4,\n", 563 | " 'perf_std': 0.0,\n", 564 | " 'perf_dict': {'dataperf_f1': 54.4}}]}" 565 | ] 566 | }, 567 | "execution_count": 157, 568 | "metadata": {}, 569 | "output_type": "execute_result" 570 | } 571 | ], 572 | "source": [ 573 | "{\"dataperf_f1\": 54.4, \"perf\": 54.4, \"perf_std\": 0.0, \"perf_by_tag\": [{\"tag\": \"bird\", \"pretty_perf\": \"54.4 %\", \"perf\": 54.4, \"perf_std\": 0.0, \"perf_dict\": {\"dataperf_f1\": 54.4}}]}" 574 | ] 575 | }, 576 | { 577 | "cell_type": "code", 578 | "execution_count": 296, 579 | "id": "706ead65-2a25-4d71-8daf-e2ac15f23e22", 580 | "metadata": {}, 581 | "outputs": [], 582 | "source": [ 583 | "df_new = pd.read_parquet('/Users/rafael/Downloads/dataperf-vision-selection-resources/test_sets/test-hawk.parquet')" 584 | ] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "execution_count": 297, 589 | "id": "d726fce3-438b-433a-9d39-02d4390ea718", 590 | "metadata": {}, 591 | "outputs": [ 592 | { 593 | "data": { 594 | "text/plain": [ 595 | "['target_label', 'Embedding', 'ImageID']" 596 | ] 597 | }, 598 | "execution_count": 297, 599 | "metadata": {}, 600 | "output_type": "execute_result" 601 | } 602 | ], 603 | "source": [ 604 | "['target_label','Embedding', 'ImageID']" 605 | ] 606 | }, 607 | { 608 | "cell_type": "code", 609 | "execution_count": 298, 610 | "id": "d138dd58-87cf-41bd-b704-84eac08770b8", 611 | "metadata": {}, 612 | "outputs": [], 613 | "source": [ 614 | "df_new = df_new.rename({'ImageID':'ImageID', 'Confidence': 'target_label', 'embedding': 'Embedding'}, axis=1)" 615 | ] 616 | }, 617 | { 618 | "cell_type": "code", 619 | "execution_count": 299, 620 | "id": "dc671ecf-2c3d-46d7-8929-6be6fb432b4e", 621 | "metadata": {}, 622 | "outputs": [], 623 | "source": [ 624 | "df_new.to_parquet(path = '/Users/rafael/Downloads/dataperf-vision-selection-resources/test_sets/test-hawk.parquet')" 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": 300, 630 | "id": 
"b06f4993-7d45-4a37-a021-6a0e0ae092ef", 631 | "metadata": {}, 632 | "outputs": [ 633 | { 634 | "data": { 635 | "text/html": [ 636 | "
\n", 637 | "\n", 650 | "\n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | "
ImageIDLabelNametarget_labelOriginalURLDisplayNameEmbedding
0e60d148d428e96cd/m/0fp7c0https://c6.staticflickr.com/6/5560/14764818909...Hawk[-0.0038521280919204775, 0.06254096050041212, ...
1bc2260ec5b70b7e5/m/0fp7c0https://farm6.staticflickr.com/3657/3607619201...Hawk[0.19757792792581547, -0.0628706741084504, -1....
2be7aae0d065c11a0/m/0fp7c1https://farm1.staticflickr.com/7253/6996634521...Hawk[0.17347788607282416, -0.37831432327491554, -1...
325880c4f38d01747/m/0fp7c1https://farm8.staticflickr.com/94/257466316_26...Hawk[0.38980357065885335, -0.9966469970143762, -1....
4904078b0695ffe0f/m/0fp7c1https://c8.staticflickr.com/8/7173/6663562541_...Hawk[0.39844726358909366, -0.2868099119322688, -1....
.....................
1407ead8a05548b7dd8/m/0fp7c1https://c1.staticflickr.com/7/6127/6003146852_...Hawk[0.7204737072851969, -0.5799575879679919, -1.7...
141b6030979ffbf6c00/m/0fp7c0https://farm1.staticflickr.com/192/503607905_4...Hawk[0.7399902660512262, -0.34497502398512464, -0....
1426f66217f8ffdbeea/m/0fp7c1https://c3.staticflickr.com/4/3723/13695947494...Hawk[0.5977249684386043, -0.5686997837439453, -0.6...
143b2692b73994377ce/m/0fp7c0https://c7.staticflickr.com/9/8112/8539254331_...Hawk[0.6778727524730443, -0.8354259564874617, -0.7...
144d84b26c3865bec24/m/0fp7c1https://farm5.staticflickr.com/3893/1498221092...Hawk[0.5801936116195033, -0.9496025449658846, -0.8...
\n", 764 | "

145 rows × 6 columns

\n", 765 | "
" 766 | ], 767 | "text/plain": [ 768 | " ImageID LabelName target_label \\\n", 769 | "0 e60d148d428e96cd /m/0fp7c 0 \n", 770 | "1 bc2260ec5b70b7e5 /m/0fp7c 0 \n", 771 | "2 be7aae0d065c11a0 /m/0fp7c 1 \n", 772 | "3 25880c4f38d01747 /m/0fp7c 1 \n", 773 | "4 904078b0695ffe0f /m/0fp7c 1 \n", 774 | ".. ... ... ... \n", 775 | "140 7ead8a05548b7dd8 /m/0fp7c 1 \n", 776 | "141 b6030979ffbf6c00 /m/0fp7c 0 \n", 777 | "142 6f66217f8ffdbeea /m/0fp7c 1 \n", 778 | "143 b2692b73994377ce /m/0fp7c 0 \n", 779 | "144 d84b26c3865bec24 /m/0fp7c 1 \n", 780 | "\n", 781 | " OriginalURL DisplayName \\\n", 782 | "0 https://c6.staticflickr.com/6/5560/14764818909... Hawk \n", 783 | "1 https://farm6.staticflickr.com/3657/3607619201... Hawk \n", 784 | "2 https://farm1.staticflickr.com/7253/6996634521... Hawk \n", 785 | "3 https://farm8.staticflickr.com/94/257466316_26... Hawk \n", 786 | "4 https://c8.staticflickr.com/8/7173/6663562541_... Hawk \n", 787 | ".. ... ... \n", 788 | "140 https://c1.staticflickr.com/7/6127/6003146852_... Hawk \n", 789 | "141 https://farm1.staticflickr.com/192/503607905_4... Hawk \n", 790 | "142 https://c3.staticflickr.com/4/3723/13695947494... Hawk \n", 791 | "143 https://c7.staticflickr.com/9/8112/8539254331_... Hawk \n", 792 | "144 https://farm5.staticflickr.com/3893/1498221092... Hawk \n", 793 | "\n", 794 | " Embedding \n", 795 | "0 [-0.0038521280919204775, 0.06254096050041212, ... \n", 796 | "1 [0.19757792792581547, -0.0628706741084504, -1.... \n", 797 | "2 [0.17347788607282416, -0.37831432327491554, -1... \n", 798 | "3 [0.38980357065885335, -0.9966469970143762, -1.... \n", 799 | "4 [0.39844726358909366, -0.2868099119322688, -1.... \n", 800 | ".. ... \n", 801 | "140 [0.7204737072851969, -0.5799575879679919, -1.7... \n", 802 | "141 [0.7399902660512262, -0.34497502398512464, -0.... \n", 803 | "142 [0.5977249684386043, -0.5686997837439453, -0.6... \n", 804 | "143 [0.6778727524730443, -0.8354259564874617, -0.7... 
\n", 805 | "144 [0.5801936116195033, -0.9496025449658846, -0.8... \n", 806 | "\n", 807 | "[145 rows x 6 columns]" 808 | ] 809 | }, 810 | "execution_count": 300, 811 | "metadata": {}, 812 | "output_type": "execute_result" 813 | } 814 | ], 815 | "source": [ 816 | "df_2 = pd.read_parquet('/Users/rafael/Downloads/dataperf-vision-selection-resources/test_sets/test-hawk.parquet')\n", 817 | "df_2" 818 | ] 819 | }, 820 | { 821 | "cell_type": "code", 822 | "execution_count": 307, 823 | "id": "22d9d865-98af-4c30-8384-917dea73d680", 824 | "metadata": {}, 825 | "outputs": [ 826 | { 827 | "data": { 828 | "text/plain": [ 829 | "256" 830 | ] 831 | }, 832 | "execution_count": 307, 833 | "metadata": {}, 834 | "output_type": "execute_result" 835 | } 836 | ], 837 | "source": [ 838 | "len(df_2['Embedding'][0])" 839 | ] 840 | }, 841 | { 842 | "cell_type": "code", 843 | "execution_count": 235, 844 | "id": "565d7ba7-f8ac-4bbb-9ee6-ef02df8ac110", 845 | "metadata": {}, 846 | "outputs": [], 847 | "source": [ 848 | "df_sushi = pd.read_csv('/Users/rafael/Downloads/dataperf-vision-selection-resources/train_sets/random_500.csv')" 849 | ] 850 | }, 851 | { 852 | "cell_type": "code", 853 | "execution_count": 260, 854 | "id": "4dc48edc-adf0-4956-9047-5017fc6f4b39", 855 | "metadata": {}, 856 | "outputs": [], 857 | "source": [ 858 | "sushi_dict = {'Sushi': {}}" 859 | ] 860 | }, 861 | { 862 | "cell_type": "code", 863 | "execution_count": 264, 864 | "id": "bbf40598-c4d7-45fa-8636-9b4ca469b5e2", 865 | "metadata": {}, 866 | "outputs": [], 867 | "source": [ 868 | "labels = list(df_sushi['Confidence'].values)" 869 | ] 870 | }, 871 | { 872 | "cell_type": "code", 873 | "execution_count": 265, 874 | "id": "c7954136-f869-441c-9767-f5b99c8dac9f", 875 | "metadata": {}, 876 | "outputs": [], 877 | "source": [ 878 | "ids = list(df_sushi['ImageID'].values)" 879 | ] 880 | }, 881 | { 882 | "cell_type": "code", 883 | "execution_count": 274, 884 | "id": "c52e39cb-fbfd-4602-b045-a4cb717c026b", 885 | "metadata": {}, 886 | 
"outputs": [], 887 | "source": [ 888 | "intra_dict = {}\n", 889 | "for i in range(500):\n", 890 | " intra_dict[f'{ids[i]}'] = float(labels[i])\n", 891 | " " 892 | ] 893 | }, 894 | { 895 | "cell_type": "code", 896 | "execution_count": 275, 897 | "id": "17c639f1-930c-4c73-8679-e46ab00553f9", 898 | "metadata": {}, 899 | "outputs": [], 900 | "source": [ 901 | "sushi_dict['Sushi'] = intra_dict" 902 | ] 903 | }, 904 | { 905 | "cell_type": "code", 906 | "execution_count": 277, 907 | "id": "8523dd33-e4ef-446a-a429-c9674e0d980e", 908 | "metadata": {}, 909 | "outputs": [], 910 | "source": [ 911 | "with open(\"Sushi.json\", \"w\") as outfile:\n", 912 | " json.dump(sushi_dict, outfile)" 913 | ] 914 | }, 915 | { 916 | "cell_type": "code", 917 | "execution_count": 316, 918 | "id": "26c87642-9d8f-456f-844f-d24bee7ff103", 919 | "metadata": {}, 920 | "outputs": [], 921 | "source": [ 922 | "embedding_dict = {}" 923 | ] 924 | }, 925 | { 926 | "cell_type": "code", 927 | "execution_count": 394, 928 | "id": "5f8b25a9-32da-4eb1-9364-8f923797747c", 929 | "metadata": {}, 930 | "outputs": [], 931 | "source": [ 932 | "df_new = pd.read_parquet('/Users/rafael/Downloads/dataperf-vision-selection-resources/embeddings/train_emb_256_dataperf.parquet/part-00071-c487d15f-d29d-4c93-96c2-4caf4c12cd59-c000.snappy.parquet')" 933 | ] 934 | }, 935 | { 936 | "cell_type": "code", 937 | "execution_count": 395, 938 | "id": "92678025-174a-40e8-b1aa-bf925a2173d4", 939 | "metadata": {}, 940 | "outputs": [ 941 | { 942 | "data": { 943 | "text/html": [ 944 | "
\n", 945 | "\n", 958 | "\n", 959 | " \n", 960 | " \n", 961 | " \n", 962 | " \n", 963 | " \n", 964 | " \n", 965 | " \n", 966 | " \n", 967 | " \n", 968 | " \n", 969 | " \n", 970 | " \n", 971 | " \n", 972 | " \n", 973 | " \n", 974 | " \n", 975 | " \n", 976 | " \n", 977 | " \n", 978 | " \n", 979 | " \n", 980 | " \n", 981 | " \n", 982 | " \n", 983 | " \n", 984 | " \n", 985 | " \n", 986 | " \n", 987 | " \n", 988 | " \n", 989 | " \n", 990 | " \n", 991 | " \n", 992 | " \n", 993 | " \n", 994 | " \n", 995 | " \n", 996 | " \n", 997 | " \n", 998 | " \n", 999 | " \n", 1000 | " \n", 1001 | " \n", 1002 | " \n", 1003 | " \n", 1004 | " \n", 1005 | " \n", 1006 | " \n", 1007 | " \n", 1008 | " \n", 1009 | " \n", 1010 | " \n", 1011 | " \n", 1012 | " \n", 1013 | " \n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | "
ImageIDembedding
0000506542b3a9dd7[-1.2127153543616391, -0.10673216263881578, 0....
100056c76aa3dd52a[0.7178898916090762, 0.826355497831251, 0.7906...
2000692d348b0f181[0.6396630836552186, 0.898348981106668, 0.4226...
3000cabb2eb284c8c[1.1945492605513477, 0.8496085515788221, -0.67...
40011cf9a929a4e19[0.5329256287734359, 1.8642507257449732, 0.614...
.........
16466ffe1a2e32b45d1dc[-0.5439635132177799, -1.0043923233747394, 0.9...
16467ffe46a38144178dd[-0.4237458416684124, 0.010014072469498137, -0...
16468fff3c2eb1c9863c7[-0.4050140839736052, 0.9564130760744758, 1.04...
16469fff41ccc90430cc0[-0.4732436203454635, 0.21041616335678826, 0.9...
16470ffffdc1308c30d53[-0.5498107897396933, 0.9266576856606651, 0.31...
\n", 1024 | "

16471 rows × 2 columns

\n", 1025 | "
" 1026 | ], 1027 | "text/plain": [ 1028 | " ImageID embedding\n", 1029 | "0 000506542b3a9dd7 [-1.2127153543616391, -0.10673216263881578, 0....\n", 1030 | "1 00056c76aa3dd52a [0.7178898916090762, 0.826355497831251, 0.7906...\n", 1031 | "2 000692d348b0f181 [0.6396630836552186, 0.898348981106668, 0.4226...\n", 1032 | "3 000cabb2eb284c8c [1.1945492605513477, 0.8496085515788221, -0.67...\n", 1033 | "4 0011cf9a929a4e19 [0.5329256287734359, 1.8642507257449732, 0.614...\n", 1034 | "... ... ...\n", 1035 | "16466 ffe1a2e32b45d1dc [-0.5439635132177799, -1.0043923233747394, 0.9...\n", 1036 | "16467 ffe46a38144178dd [-0.4237458416684124, 0.010014072469498137, -0...\n", 1037 | "16468 fff3c2eb1c9863c7 [-0.4050140839736052, 0.9564130760744758, 1.04...\n", 1038 | "16469 fff41ccc90430cc0 [-0.4732436203454635, 0.21041616335678826, 0.9...\n", 1039 | "16470 ffffdc1308c30d53 [-0.5498107897396933, 0.9266576856606651, 0.31...\n", 1040 | "\n", 1041 | "[16471 rows x 2 columns]" 1042 | ] 1043 | }, 1044 | "execution_count": 395, 1045 | "metadata": {}, 1046 | "output_type": "execute_result" 1047 | } 1048 | ], 1049 | "source": [ 1050 | "df_new" 1051 | ] 1052 | }, 1053 | { 1054 | "cell_type": "code", 1055 | "execution_count": 317, 1056 | "id": "69da299b-8829-49ec-a421-5e642cdce133", 1057 | "metadata": {}, 1058 | "outputs": [], 1059 | "source": [ 1060 | "embedding_list = (df_2['Embedding'].to_list())" 1061 | ] 1062 | }, 1063 | { 1064 | "cell_type": "code", 1065 | "execution_count": 319, 1066 | "id": "6d8d94c9-c246-472e-8c54-1c0d168d652a", 1067 | "metadata": {}, 1068 | "outputs": [], 1069 | "source": [ 1070 | "id_names = (df_2['ImageID'].to_list())" 1071 | ] 1072 | }, 1073 | { 1074 | "cell_type": "code", 1075 | "execution_count": 326, 1076 | "id": "3cab8316-bf74-40ae-a7d8-d2826ca70160", 1077 | "metadata": {}, 1078 | "outputs": [], 1079 | "source": [ 1080 | "for i in range(len(embedding_list)):\n", 1081 | " embedding_dict[f'train{id_names[i]}.npy'] = embedding_list[i]" 1082 | ] 1083 | }, 1084 | { 
1085 | "cell_type": "code", 1086 | "execution_count": 330, 1087 | "id": "361e7bb5-6a3c-479e-9ab8-9c5f014f569a", 1088 | "metadata": {}, 1089 | "outputs": [], 1090 | "source": [ 1091 | "import boto3" 1092 | ] 1093 | }, 1094 | { 1095 | "cell_type": "code", 1096 | "execution_count": 361, 1097 | "id": "9150bbd6-d0b6-4530-ba00-1153da0be83a", 1098 | "metadata": {}, 1099 | "outputs": [], 1100 | "source": [ 1101 | "new_emb['train37510ffe72539c23.npy'] = embedding_dict['train37510ffe72539c23.npy']" 1102 | ] 1103 | }, 1104 | { 1105 | "cell_type": "code", 1106 | "execution_count": 376, 1107 | "id": "57d0a835-82df-4583-8d9b-63a115551e19", 1108 | "metadata": {}, 1109 | "outputs": [ 1110 | { 1111 | "data": { 1112 | "text/plain": [ 1113 | "numpy.ndarray" 1114 | ] 1115 | }, 1116 | "execution_count": 376, 1117 | "metadata": {}, 1118 | "output_type": "execute_result" 1119 | } 1120 | ], 1121 | "source": [ 1122 | "type(new_emb['train37510ffe72539c23.npy'])" 1123 | ] 1124 | }, 1125 | { 1126 | "cell_type": "code", 1127 | "execution_count": null, 1128 | "id": "0b72dd39-41b8-4d5c-8a3a-a8f2b4da48e2", 1129 | "metadata": {}, 1130 | "outputs": [], 1131 | "source": [ 1132 | "ftmp = tempfile.NamedTemporaryFile(delete=False)\n", 1133 | "fname = ftmp.name + \".npy\"\n", 1134 | "inp = np.random.rand(100)\n", 1135 | "np.save(fname, inp, allow_pickle=False)\n", 1136 | "out = np.load(fname, allow_pickle=False)" 1137 | ] 1138 | }, 1139 | { 1140 | "cell_type": "code", 1141 | "execution_count": 427, 1142 | "id": "58916299-b236-49a5-9168-d7d86f12da55", 1143 | "metadata": {}, 1144 | "outputs": [ 1145 | { 1146 | "name": "stdout", 1147 | "output_type": "stream", 1148 | "text": [ 1149 | "[ 3.89803571e-01 -9.96646997e-01 -1.81296639e+00 -2.11422996e-01\n", 1150 | " 1.11237384e+00 5.50917144e-01 -2.23019595e-01 1.21128170e+00\n", 1151 | " -6.06581952e-01 -1.22170200e-01 1.91518837e-02 2.59825023e+00\n", 1152 | " -1.46337609e-01 1.14671604e+00 -1.04500072e+00 -1.09172317e+00\n", 1153 | " 1.28390928e+00 
-6.02485930e-01 -1.38633525e+00 -5.86013834e-01\n", 1154 | " 1.40964381e+00 2.83743064e+00 1.24784790e+00 -9.70432275e-01\n", 1155 | " -7.10446552e-02 -1.33872242e+00 4.68498640e-02 -2.96933628e-01\n", 1156 | " 5.03643044e-01 1.77543576e+00 -7.49830448e-01 -2.92437834e-01\n", 1157 | " -4.35559568e-02 -6.97971199e-01 -1.54200581e+00 1.00275982e+00\n", 1158 | " -1.58844006e-01 -1.10935264e+00 -7.16555608e-01 -3.23272959e-01\n", 1159 | " -1.20135634e+00 4.06764028e-01 -7.16184317e-01 2.85940422e-01\n", 1160 | " 1.42247711e-01 4.45880963e-01 2.65902923e+00 6.63447124e-01\n", 1161 | " 2.96592291e-02 1.96929396e-01 1.19715826e+00 -1.85437922e+00\n", 1162 | " -3.22858810e+00 -5.52233908e-01 -2.28968643e+00 1.89881467e+00\n", 1163 | " -2.69826089e-01 -1.00715497e+00 -4.31742139e-01 8.82787116e-01\n", 1164 | " -2.07622814e-01 2.40897019e+00 -2.04291252e+00 3.29187109e-01\n", 1165 | " 3.30926936e-01 5.40466498e-01 5.21807885e-01 -1.01563803e+00\n", 1166 | " 9.14845711e-01 -4.10491149e-01 6.25429816e-01 1.61488801e+00\n", 1167 | " 2.50226381e+00 6.01955673e-02 -7.91900018e-01 -1.76873686e-01\n", 1168 | " -2.28771317e+00 6.39635918e-01 1.16536545e+00 2.11140401e+00\n", 1169 | " 8.89996511e-02 1.33888259e-01 -9.99618052e-01 3.47843187e+00\n", 1170 | " 8.78792862e-01 1.11952807e-01 1.72209416e+00 6.02893362e-01\n", 1171 | " -8.45727799e-02 3.66698320e-01 7.13103677e-01 2.46110374e+00\n", 1172 | " 7.51076212e-01 2.95859024e+00 8.80468588e-01 -2.62122633e-01\n", 1173 | " -1.19645276e+00 6.12483290e-01 -2.70173794e+00 4.63667627e-01\n", 1174 | " -1.07226326e+00 1.20642437e-02 2.24335401e+00 2.29066151e+00\n", 1175 | " -2.63806008e-01 -3.71698775e-01 9.11383716e-01 -5.66627229e-02\n", 1176 | " -8.54939232e-01 -2.07023799e-01 1.79320635e+00 -8.95001470e-01\n", 1177 | " -1.75708440e+00 6.49242586e-01 -5.17078360e-01 4.03644993e-01\n", 1178 | " -5.89727554e-01 -1.64884495e+00 2.47910865e+00 9.47566147e-01\n", 1179 | " -2.24377446e-01 -2.23720885e-03 2.60619694e+00 1.59679699e-01\n", 
1180 | " -3.93530441e-01 -1.51868802e+00 3.48610412e-01 8.99090255e-02\n", 1181 | " 1.15730898e+00 1.10894727e-01 -4.43198262e-02 -4.42949224e-01\n", 1182 | " -7.60400772e-01 -4.39890787e-01 2.51312671e-01 -8.79998981e-01\n", 1183 | " 1.20019254e+00 1.53487517e+00 1.66548900e-02 1.77565735e+00\n", 1184 | " 2.38695643e+00 -2.68797990e-01 1.22994696e+00 1.05242160e-01\n", 1185 | " -2.24284337e-01 -1.74365392e-01 -1.29437818e+00 1.29453605e+00\n", 1186 | " 1.01599930e+00 1.82553687e+00 -9.77754401e-01 7.55106541e-01\n", 1187 | " 4.22414110e-01 6.97701425e-01 2.74640753e+00 2.72244357e+00\n", 1188 | " 7.00360122e-01 -6.34062527e-01 1.53573036e+00 1.03219454e+00\n", 1189 | " 1.37059934e-01 -3.00395361e-01 9.46676673e-01 1.42820776e+00\n", 1190 | " -1.83891427e+00 1.60916704e-01 6.71404128e-01 -7.54523616e-02\n", 1191 | " 5.98854364e-01 -7.71613639e-01 -6.55620813e-01 -1.41284540e+00\n", 1192 | " -8.83293019e-02 3.67647495e+00 6.14790081e-01 -3.52725149e-01\n", 1193 | " 2.49560669e+00 3.17937151e+00 1.41134017e-01 4.70473738e-01\n", 1194 | " 1.56911217e-01 -8.69234430e-01 5.51784682e-01 -1.12584407e+00\n", 1195 | " -2.36649778e+00 -4.99193641e-01 -1.07530067e+00 -3.71462427e-01\n", 1196 | " 4.84255341e-01 -1.89550505e+00 1.51478485e+00 1.84459741e+00\n", 1197 | " 2.19528381e-01 -2.81967307e-01 -3.24608813e-01 1.88795676e+00\n", 1198 | " 1.27942137e+00 -1.37998819e-01 -2.05380238e-01 4.26948425e-01\n", 1199 | " -3.89293401e-01 4.01363498e-01 1.98159891e+00 -1.57827198e-01\n", 1200 | " -5.85166423e-01 -7.13019953e-01 -6.38005962e-01 1.43934620e+00\n", 1201 | " 7.15919559e-01 -5.31886315e-01 3.68193605e-01 -1.90642042e+00\n", 1202 | " 8.62036109e-01 -6.52963357e-01 1.07742582e-01 2.81807473e+00\n", 1203 | " 1.23933507e+00 2.42569902e+00 2.17956315e+00 -9.52152608e-02\n", 1204 | " 9.07720855e-01 1.01096203e+00 -3.24792922e-01 2.39230798e+00\n", 1205 | " -6.16348712e-01 2.36699218e+00 2.11694442e+00 5.38469330e-01\n", 1206 | " -3.01368494e-01 3.72789505e-01 1.87059031e+00 
2.23637478e-03\n", 1207 | " 5.00828638e-01 6.83015162e-01 -2.53589557e-01 1.37718434e+00\n", 1208 | " 1.67791787e+00 4.18906651e-01 -1.38630361e+00 -8.20388707e-01\n", 1209 | " -1.77622874e+00 -2.51923289e+00 1.99978151e+00 -5.16371872e-01\n", 1210 | " 1.42943314e-01 -8.87400970e-01 8.69817718e-02 -1.54702236e-02\n", 1211 | " -2.41089198e-01 -1.48360916e+00 1.91162173e+00 1.27717093e+00\n", 1212 | " 2.13009238e-01 -5.98617984e-01 -6.32361384e-01 1.94918398e+00]\n", 1213 | "[ 1.73477886e-01 -3.78314323e-01 -1.15957538e+00 -4.55457117e-01\n", 1214 | " -1.24829119e-01 -6.82458939e-01 2.47177816e-01 6.75238190e-01\n", 1215 | " 6.32115197e-01 -6.86811856e-01 1.31945937e-01 2.90450028e+00\n", 1216 | " -7.81145498e-02 7.61810122e-01 3.53271576e-01 -1.93634324e+00\n", 1217 | " 4.34813137e-01 7.93498748e-01 -9.99374515e-01 3.57337146e-01\n", 1218 | " 7.67517595e-01 3.23606027e+00 8.09872843e-01 -1.87386090e+00\n", 1219 | " 4.02630880e-01 -1.97024956e+00 -2.02397129e+00 6.84022534e-01\n", 1220 | " -6.97145592e-01 2.57936995e+00 -1.79964854e+00 -8.49397060e-01\n", 1221 | " 6.13147645e-01 8.10354836e-01 -3.39851667e+00 1.22421018e+00\n", 1222 | " -8.60824905e-01 -1.60722203e+00 3.59145122e-01 2.01454811e+00\n", 1223 | " -1.61948743e-01 -1.00224593e+00 5.73652303e-01 6.30755277e-01\n", 1224 | " 4.14952294e-01 2.16821687e+00 2.53577484e+00 1.82889149e+00\n", 1225 | " 1.75935907e-01 -6.57751382e-02 1.10174426e+00 -2.06796838e+00\n", 1226 | " -2.94581082e+00 -1.28551166e+00 -1.51854207e+00 1.67130640e+00\n", 1227 | " -1.03049907e+00 -5.07271596e-01 -6.97514998e-01 -1.84129927e-01\n", 1228 | " -7.49230290e-01 2.12516077e+00 -5.71402462e-01 7.41986199e-01\n", 1229 | " -2.57257231e-01 9.47876568e-01 2.24098504e+00 -5.17886593e-01\n", 1230 | " 2.59016190e+00 2.18350262e+00 -2.46278041e+00 2.97659203e+00\n", 1231 | " 1.79582747e-01 2.31392899e+00 7.35421377e-01 -1.36229226e-01\n", 1232 | " -1.11759859e+00 4.11274083e-01 2.02971659e+00 2.34953279e+00\n", 1233 | " 1.39825615e-01 
1.29747795e+00 -1.92871190e+00 7.62746025e-01\n", 1234 | " 1.78325199e+00 -7.74352229e-01 1.96503445e+00 -8.22589279e-01\n", 1235 | " 4.76711337e-01 1.97345695e-01 2.41591692e-04 3.03953992e+00\n", 1236 | " -3.64576371e-01 -5.85988048e-01 2.30729064e-02 2.40199793e+00\n", 1237 | " 4.04458971e-01 4.62537789e-01 -9.44197385e-01 -3.94466042e-01\n", 1238 | " -7.71524621e-01 -1.44357413e+00 3.80624581e-01 1.02348017e+00\n", 1239 | " -2.51731718e-01 -1.82386160e+00 -1.98006934e-01 1.71773232e+00\n", 1240 | " 3.39659050e-01 -1.61678386e-01 8.64877151e-01 5.27221507e-01\n", 1241 | " -1.33709202e+00 -5.17925673e-01 5.45324675e-01 -3.39876880e-01\n", 1242 | " 6.28603177e-01 1.05901119e+00 3.21904938e+00 5.65501247e-01\n", 1243 | " -1.36764496e+00 -1.05371234e+00 1.36910906e+00 2.13052453e+00\n", 1244 | " -1.53023197e+00 -7.13439829e-02 1.77023797e+00 3.76738890e-01\n", 1245 | " 5.80164396e-01 1.32691407e+00 6.60211046e-01 -5.98368078e-01\n", 1246 | " 6.98442518e-01 -2.54519675e+00 1.57567109e+00 4.16967821e-01\n", 1247 | " 4.76946804e-01 -9.59922631e-01 9.60794297e-02 1.55166016e-01\n", 1248 | " 1.81177522e+00 -1.52597836e+00 2.82341435e+00 1.21299257e-01\n", 1249 | " -8.44809950e-01 7.03000928e-02 1.53342508e-01 3.38649085e+00\n", 1250 | " 5.65857183e-01 3.49518456e+00 -1.98264737e+00 8.48937940e-01\n", 1251 | " -4.80279170e-01 8.66791453e-01 7.33455740e-01 3.36394531e+00\n", 1252 | " -7.11448994e-02 -1.67163688e+00 1.89634404e+00 5.24291752e-01\n", 1253 | " -2.32818349e+00 -9.27940928e-01 -2.84750277e-01 2.78600716e+00\n", 1254 | " -3.27495724e+00 1.23622920e+00 -4.39654288e-01 1.33676926e+00\n", 1255 | " -3.71218248e-02 -4.67805272e-01 -1.19710997e+00 -1.01201737e+00\n", 1256 | " 4.00629907e-01 2.37262771e+00 1.93417388e+00 -3.06882494e-01\n", 1257 | " 1.60856514e+00 2.23441980e+00 1.54979445e+00 1.71643700e+00\n", 1258 | " 1.36638736e-01 1.06117105e+00 -1.02864472e+00 -2.34444189e+00\n", 1259 | " -2.94609160e+00 -1.58258268e+00 2.89173790e+00 1.16163355e+00\n", 1260 | " 
2.59235939e+00 -2.02021420e+00 -7.73131893e-01 1.29344678e+00\n", 1261 | " 3.99642409e-01 2.74359725e-01 -4.29611048e+00 3.21451459e+00\n", 1262 | " 8.61943075e-01 -8.29561193e-02 3.62781221e-01 -1.28975848e+00\n", 1263 | " 1.45744159e+00 -5.07338791e-02 2.80161335e-01 -1.88915363e+00\n", 1264 | " -5.48913797e-02 1.59280869e+00 -2.64753155e+00 3.05718540e-02\n", 1265 | " 8.57470089e-01 3.38945004e-01 -4.86244452e-01 3.47816665e-01\n", 1266 | " -6.40024050e-02 -1.89914441e-01 -1.09029680e+00 1.28981721e+00\n", 1267 | " 1.57027400e+00 1.02784636e+00 3.18322695e+00 -6.64335702e-01\n", 1268 | " 1.79662418e+00 3.36934653e-01 7.86187765e-01 3.36239847e+00\n", 1269 | " -6.71641773e-01 4.16299206e-01 1.67452233e+00 -2.92389720e-01\n", 1270 | " -9.64364373e-01 3.33029921e-01 6.98441369e-01 3.99958218e-01\n", 1271 | " 9.19708275e-02 -8.20748303e-01 3.45360249e-01 1.13682329e+00\n", 1272 | " 6.91683315e-01 -7.31566063e-01 8.30114479e-01 -2.31014512e+00\n", 1273 | " -3.03351615e+00 -7.30246093e-01 -2.97940655e+00 2.73203994e-01\n", 1274 | " -4.12306385e-01 -5.42596202e-01 -6.75748803e-01 -2.75624418e-02\n", 1275 | " 1.45179840e+00 -6.41767557e-01 4.80260974e-01 1.02999982e+00\n", 1276 | " 1.62948661e+00 -7.74186094e-01 -6.38427721e-01 1.36770599e+00]\n", 1277 | "[ 0.74329027 -0.3775754 -1.55002959 -0.08654639 -0.84612519 -0.66517953\n", 1278 | " -0.02449258 0.36550824 -0.09603422 -0.42939465 0.05933956 3.72453531\n", 1279 | " -0.60181018 -0.18968162 -0.33691938 -1.21067376 1.69810332 -0.68387043\n", 1280 | " -2.19420451 0.30623149 0.04384658 3.88534118 2.54776959 -0.48658014\n", 1281 | " 0.48451308 -0.31657583 -2.93257781 -0.09047868 0.27749259 2.14582287\n", 1282 | " -0.7742917 0.61943849 1.17077945 0.65668756 -1.49204705 0.82001543\n", 1283 | " -1.14004735 -1.53238315 -0.9232495 1.49888674 -1.35190679 -1.06768895\n", 1284 | " -0.98028353 0.26184235 0.33910083 -0.6371447 0.83411936 0.6429185\n", 1285 | " 0.25254067 1.61010048 1.70969211 -2.14806023 -1.63120278 0.22985059\n", 
1286 | " -1.83635132 0.81144822 -0.85230331 -0.01214619 -1.0530559 0.78959619\n", 1287 | " -1.29366791 2.32742711 -0.47322131 0.18036712 0.2442932 0.57961274\n", 1288 | " 1.83769874 -1.92854169 -0.32842854 -0.12183421 1.06949212 -1.1666125\n", 1289 | " -0.18457919 1.42867819 0.22614148 1.24973795 -0.27109001 -1.30841724\n", 1290 | " -0.1323186 1.50524927 -0.06151336 0.36280611 -0.33027768 -0.68321206\n", 1291 | " 1.62078746 -1.69239535 0.69330277 2.77379231 -0.89682064 -0.79466231\n", 1292 | " 0.19352146 1.983096 0.32072105 1.41947621 0.45555747 0.90832593\n", 1293 | " -2.51358088 -0.24226676 -1.49537912 -0.2563347 -1.98138404 -0.37085698\n", 1294 | " 1.62641746 1.06244157 -0.82419462 -0.57390422 0.9138385 -1.00704416\n", 1295 | " -0.23448064 -1.37492495 -0.25682802 0.11005032 0.23172159 -1.56330803\n", 1296 | " 0.91391364 -0.34495631 0.50074237 -0.4066586 -0.00549562 3.24345052\n", 1297 | " 0.75700746 -1.69821946 1.77941427 -0.28645015 -0.57915812 1.20586644\n", 1298 | " 0.36845146 0.58970282 0.22338647 0.20552488 0.51302908 1.04770178\n", 1299 | " -0.68046897 0.52678917 -0.6986149 -1.40170292 0.61719286 -0.47468235\n", 1300 | " -0.56355147 -1.18205737 -0.18334463 1.71284032 2.02351803 0.82235456\n", 1301 | " -0.17395866 0.0966321 -0.06158389 0.85559931 -0.34557713 1.28382154\n", 1302 | " 1.02731127 -0.92987219 -0.21203948 0.63788485 0.84467619 1.97563281\n", 1303 | " -0.50550261 -1.95678332 1.12111351 -1.63963881 1.04433216 -1.59185571\n", 1304 | " 1.44506895 2.29460523 1.94414325 -0.33587755 -0.36165364 0.26807456\n", 1305 | " 0.69051003 0.19675298 1.47340625 0.2810714 0.32018736 0.70501098\n", 1306 | " 0.5535334 2.53173343 0.11880176 2.23190851 0.55329793 1.51165165\n", 1307 | " 2.05233722 0.0862875 -0.68701175 -1.41732675 -1.63925079 -1.93015384\n", 1308 | " -0.08638934 2.0572556 -0.34858196 0.83103614 -1.45979835 -1.14297843\n", 1309 | " 2.16423315 -0.06061224 0.13932088 0.5792231 1.30258981 -0.31806488\n", 1310 | " -0.71589099 -0.29994465 -1.2526677 
-1.78460434 1.16588397 -0.328064\n", 1311 | " -0.74002878 0.30247325 0.3442312 -0.28380328 -0.62916205 -0.75991837\n", 1312 | " 1.97814817 0.51295979 0.36566438 -0.25536378 -0.27424215 0.04552544\n", 1313 | " 0.17500421 1.00346983 1.35499899 -0.38314174 -0.50808836 0.24078918\n", 1314 | " 0.60978008 1.91112286 0.82011784 -0.2615277 -0.26694167 -0.67379871\n", 1315 | " -2.74933261 1.37684731 1.99775643 0.10294325 0.5447847 -2.57597607\n", 1316 | " 2.63937386 0.68193744 0.08253477 0.39086499 -0.66228937 -2.01208531\n", 1317 | " -1.61198843 -0.99157155 0.74521818 -0.22949644 -0.66489836 1.76523025\n", 1318 | " 0.11179357 1.58020681 0.23154623 -0.7978364 0.02799542 0.51259032\n", 1319 | " 0.64821907 -1.05749894 -0.83042711 -1.11948847]\n" 1320 | ] 1321 | } 1322 | ], 1323 | "source": [ 1324 | "s3_client = session.client('s3')\n", 1325 | "for i in new_emb:\n", 1326 | " with tempfile.NamedTemporaryFile() as tf:\n", 1327 | " fname = tf.name + \".npy\"\n", 1328 | " np.save(fname, new_emb[i], allow_pickle=False)\n", 1329 | " print(np.load(fname, allow_pickle = False))\n", 1330 | " #s3_client.upload_file(Filename = tf.name, Bucket= 'vision-dataperf', Key = f'public_train_dataset_256_embeddings_dynabench_formatted/{i}')" 1331 | ] 1332 | } 1333 | ], 1334 | "metadata": { 1335 | "kernelspec": { 1336 | "display_name": "Python 3", 1337 | "language": "python", 1338 | "name": "python3" 1339 | }, 1340 | "language_info": { 1341 | "codemirror_mode": { 1342 | "name": "ipython", 1343 | "version": 3 1344 | }, 1345 | "file_extension": ".py", 1346 | "mimetype": "text/x-python", 1347 | "name": "python", 1348 | "nbconvert_exporter": "python", 1349 | "pygments_lexer": "ipython3", 1350 | "version": "3.9.5" 1351 | } 1352 | }, 1353 | "nbformat": 4, 1354 | "nbformat_minor": 5 1355 | } 1356 | -------------------------------------------------------------------------------- /baselines/pseudo_label_generation/dataperf_vision_experiments.py: 
-------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from sklearn.neural_network import MLPClassifier 4 | 5 | import sys 6 | sys.path.append("dataperf-vision-selection/") 7 | from main import run_tasks 8 | from utils import SubmissionCreator, EnsembleOOD 9 | 10 | 11 | # Load the data 12 | data = pd.read_parquet('./dataperf-vision-selection/data/embeddings/train_emb_256_dataperf.parquet', engine='pyarrow') 13 | data_np = np.vstack(data['embedding'].values) 14 | data_ids = data["ImageID"] 15 | 16 | base_folder = './dataperf-vision-selection/data/' 17 | 18 | data_cupcake = pd.read_csv(base_folder + 'examples/alpha_example_set_Cupcake.csv') 19 | cupcake_ids = data_cupcake['ImageID'].values 20 | cupcake_idx = np.where([id_ in cupcake_ids for id_ in data_ids]) 21 | cupcake_np = data_np[cupcake_idx] 22 | print(cupcake_np.shape) 23 | 24 | data_hawk = pd.read_csv(base_folder + 'examples/alpha_example_set_Hawk.csv') 25 | hawk_ids = data_hawk['ImageID'].values 26 | hawk_idx = np.where([id_ in hawk_ids for id_ in data_ids]) 27 | hawk_np = data_np[hawk_idx] 28 | print(hawk_np.shape) 29 | 30 | data_sushi = pd.read_csv(base_folder + 'examples/alpha_example_set_Sushi.csv') 31 | sushi_ids = data_sushi['ImageID'].values 32 | sushi_idx = np.where([id_ in sushi_ids for id_ in data_ids]) 33 | sushi_np = data_np[sushi_idx] 34 | print(sushi_np.shape) 35 | 36 | # Experiment 1: baseline, the provided examples 37 | example_cupcake_ids = data_cupcake["ImageID"].values 38 | example_sushi_ids = data_sushi["ImageID"].values 39 | example_hawk_ids = data_hawk["ImageID"].values 40 | 41 | sc = SubmissionCreator(example_cupcake_ids, example_hawk_ids, example_sushi_ids) 42 | 43 | sc.write(base_folder + 'train_sets/cupcake.csv', 0, 44 | base_folder + 'train_sets/cupcake_baseline.json') 45 | 46 | sc.write(base_folder + 'train_sets/hawk.csv', 1, 47 | base_folder + 'train_sets/hawk_baseline.json') 48 | 49 | sc.write(base_folder
+ 'train_sets/sushi.csv', 2, 50 | base_folder + 'train_sets/sushi_baseline.json') 51 | 52 | run_tasks("task_setup_new.yaml") 53 | 54 | # Experiment 2: Random selection of 333 images per class with pseudo-labels from sklearn classifier 55 | 56 | data_train = np.vstack([cupcake_np, hawk_np, sushi_np]) 57 | labels_train = np.hstack([np.ones(20) * 0, np.ones(20) * 1, np.ones(20) * 2]) 58 | 59 | mlp = MLPClassifier(25) 60 | mlp.fit(data_train, labels_train) 61 | mlp.score(data_train, labels_train) 62 | 63 | pseudo_labels = mlp.predict(data_np) 64 | 65 | sushi_ids = np.random.choice(data_ids[pseudo_labels == 2], 333, replace=False) 66 | hawk_ids = np.random.choice(data_ids[pseudo_labels == 1], 333, replace=False) 67 | cupcake_ids = np.random.choice(data_ids[pseudo_labels == 0], 333, replace=False) 68 | 69 | sc = SubmissionCreator(cupcake_ids, hawk_ids, sushi_ids) 70 | 71 | sc.write(base_folder + "train_sets/cupcake.csv", 0, base_folder + "train_sets/cupcake_nn.json") 72 | sc.write(base_folder + "train_sets/hawk.csv", 1, base_folder + "train_sets/hawk_nn.json") 73 | sc.write(base_folder + "train_sets/sushi.csv", 2, base_folder + "train_sets/sushi_nn.json") 74 | 75 | # Experiment 3: Select samples based on uncertainty 76 | ood = EnsembleOOD(mlp, data_train, labels_train, n_models=50) 77 | pseudo_labels = ood.predict(data_np) 78 | uncertainty = ood.predict_value(data_np, None) 79 | prob = uncertainty / uncertainty.sum() 80 | 81 | sushi_ids = np.random.choice(data_ids[pseudo_labels == 2], 333, replace=False, 82 | p=prob[pseudo_labels == 2] / np.sum(prob[pseudo_labels == 2])) 83 | hawk_ids = np.random.choice(data_ids[pseudo_labels == 1], 333, replace=False, 84 | p=prob[pseudo_labels == 1] / np.sum(prob[pseudo_labels == 1])) 85 | cupcake_ids = np.random.choice(data_ids[pseudo_labels == 0], 333, replace=False, 86 | p=prob[pseudo_labels == 0] / np.sum(prob[pseudo_labels == 0])) 87 | 88 | sc = SubmissionCreator(cupcake_ids, hawk_ids, sushi_ids) 89 | sc.write(base_folder + 
"train_sets/cupcake.csv", 0, base_folder + "train_sets/cupcake_ood.json") 90 | sc.write(base_folder + "train_sets/hawk.csv", 1, base_folder + "train_sets/hawk_ood.json") 91 | sc.write(base_folder + "train_sets/sushi.csv", 2, base_folder + "train_sets/sushi_ood.json") 92 | -------------------------------------------------------------------------------- /baselines/pseudo_label_generation/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import pandas as pd 4 | import sklearn 5 | 6 | 7 | def make_json(train_file, file_name, data_class_name): 8 | """""" 9 | df = pd.read_csv(train_file) 10 | dct_ = {data_class_name: {}} 11 | dct = dct_[data_class_name] 12 | 13 | for index, row in df.iterrows(): 14 | dct.update({row["ImageID"]: row["Confidence"]}) 15 | 16 | json_string = json.dumps(dct_) 17 | json_file = open(file_name, "w") 18 | json_file.write(json_string) 19 | json_file.close() 20 | 21 | 22 | class SubmissionCreator: 23 | 24 | def __init__(self, cupcake_ids, hawk_ids, sushi_ids): 25 | """""" 26 | self.cupcake = cupcake_ids 27 | self.hawk = hawk_ids 28 | self.sushi = sushi_ids 29 | 30 | self.names = ["Cupcake", "Hawk", "Sushi"] 31 | 32 | def write(self, file_name, target_class, file_name_json=""): 33 | """""" 34 | data = np.hstack((self.cupcake, self.hawk, self.sushi)) 35 | labels = np.zeros(data.shape[0]) 36 | 37 | if target_class == 0: 38 | labels[:self.cupcake.shape[0]] = 1 39 | 40 | if target_class == 1: 41 | labels[self.cupcake.shape[0]:self.cupcake.shape[0] + self.hawk.shape[0]] = 1 42 | 43 | if target_class == 2: 44 | labels[-self.sushi.shape[0]:] = 1 45 | 46 | df = pd.DataFrame(data={"ImageID": data, "Confidence": labels.astype(int)}) 47 | df.to_csv(file_name, index=False) 48 | 49 | if file_name_json != "": 50 | make_json(file_name, file_name_json, self.names[target_class]) 51 | 52 | 53 | class EnsembleOOD: 54 | 55 | def __init__(self, model, X, y, n_models=10): # noqa 56 | 
"""""" 57 | self.model = model 58 | self.num_models = n_models 59 | 60 | self.models = [sklearn.base.clone(self.model).fit(X, y) for _ in range(self.num_models)] 61 | 62 | def predict_value(self, X, _y): # noqa 63 | predictions = np.array([model.predict(X) for model in self.models]) 64 | std = np.std(predictions, axis=0) 65 | return std 66 | 67 | def predict(self, X): # noqa 68 | """Compute predictions with majority voting for the ensemble.""" 69 | predictions = np.array([model.predict(X) for model in self.models]) 70 | mean = np.mean(predictions, axis=0) 71 | return np.round(mean) 72 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | # For setup file 2 | DEFAULT_SETUP_YAML_PATH = 'task_setup.yaml' 3 | SETUP_YAML_LOCAL_DATA_DIR_KEY = 'data_dir' 4 | SETUP_YAML_DOCKER_DATA_DIR_KEY = 'docker_data_dir' 5 | SETUP_YAML_DIM_KEY = 'dim' 6 | SETUP_YAML_TASKS_KEY = 'eval_tasks' 7 | SETUP_YAML_EMB_KEY = 'emb_file' 8 | SETUP_YAML_SPARK_MEM_KEY = 'spark_driver_memory' 9 | SETUP_YAML_RESULTS_KEY = 'results_dir' 10 | 11 | # Column names in train, test and embedding files 12 | EMB_COL = 'embedding' 13 | ID_COL = 'ImageID' 14 | LABEL_COL = 'Confidence' 15 | 16 | # For results 17 | RESULT_FILE_PREFIX = 'result' 18 | 19 | # General 20 | RANDOM_SEED = 0 21 | -------------------------------------------------------------------------------- /data/embeddings/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore -------------------------------------------------------------------------------- /data/examples/alpha_example_set_Cupcake.csv: -------------------------------------------------------------------------------- 1 | LabelName,ImageID,Confidence,OriginalURL,DisplayName 2 | /m/03p1r4,04e3345ac6812367,1,https://farm8.staticflickr.com/4041/4703281077_377b2381b6_o.jpg,Cupcake 3 | 
/m/03p1r4,283b1f143f257cb2,1,https://farm3.staticflickr.com/2640/3934727884_2d6c0e840e_o.jpg,Cupcake 4 | /m/03p1r4,0048aff8e997b7bb,1,https://c4.staticflickr.com/4/3452/3959385380_1a4c1a0c72_o.jpg,Cupcake 5 | /m/03p1r4,0e595758eca6bc7b,1,https://c4.staticflickr.com/4/3466/3870361733_458cfb9811_o.jpg,Cupcake 6 | /m/03p1r4,03822e0571376f3a,1,https://c3.staticflickr.com/3/2572/3735793704_ea2426c4a2_o.jpg,Cupcake 7 | /m/03p1r4,088cef39f4d9600b,1,https://c4.staticflickr.com/4/3435/3828923954_c945cb47eb_o.jpg,Cupcake 8 | /m/03p1r4,0c7713ac6a33580d,1,https://c6.staticflickr.com/9/8524/8680026031_c90b37a52f_o.jpg,Cupcake 9 | /m/03p1r4,8661675b69bcd451,1,https://farm2.staticflickr.com/2939/14553454122_a8ed4a39a2_o.jpg,Cupcake 10 | /m/03p1r4,19d61a65a14fc338,1,https://c1.staticflickr.com/5/4014/4303193380_ac6921e649_o.jpg,Cupcake 11 | /m/03p1r4,127064a5f9cbc7fb,1,https://c4.staticflickr.com/8/7047/6818757202_1ebd77a426_o.jpg,Cupcake 12 | /m/03p1r4,4f28a31461581f24,1,https://farm5.staticflickr.com/8086/8543665572_cfcd413833_o.jpg,Cupcake 13 | /m/03p1r4,0cf0c2409697ac19,1,https://c2.staticflickr.com/2/1196/5145082590_9d4c4cc564_o.jpg,Cupcake 14 | /m/03p1r4,0285bbd389b7b61b,1,https://c4.staticflickr.com/3/2756/4391995615_f318b2488f_o.jpg,Cupcake 15 | /m/03p1r4,063eae052bb7d0db,1,https://c4.staticflickr.com/8/7347/12443516385_0a515fd7cc_o.jpg,Cupcake 16 | /m/03p1r4,086dd93ba3986bf4,1,https://farm5.staticflickr.com/3129/2876175548_0e861137d6_o.jpg,Cupcake 17 | /m/03p1r4,0548f98e537f1e6e,1,https://c7.staticflickr.com/4/3531/3826774014_d251f5b27d_o.jpg,Cupcake 18 | /m/03p1r4,1cc97730294411bc,1,https://c5.staticflickr.com/9/8679/16330521258_331b7f6c3b_o.jpg,Cupcake 19 | /m/03p1r4,61c83243b4dfd6a3,1,https://farm4.staticflickr.com/4026/4342942316_a19f9f3407_o.jpg,Cupcake 20 | /m/03p1r4,0574b4df4b28a75c,1,https://c5.staticflickr.com/3/2488/4193123064_6cbecf1b86_o.jpg,Cupcake 21 | /m/03p1r4,0dd9e1b7cdbe2b95,1,https://farm6.staticflickr.com/3725/13572801914_2a15942d70_o.jpg,Cupcake 22 | 
-------------------------------------------------------------------------------- /data/examples/alpha_example_set_Hawk.csv: -------------------------------------------------------------------------------- 1 | LabelName,ImageID,Confidence,OriginalURL,DisplayName 2 | /m/0fp7c,58359817daa6a0ff,1,https://c8.staticflickr.com/7/6119/6406452785_710d4a404e_o.jpg,Hawk 3 | /m/0fp7c,920f445b8580ae55,1,https://farm6.staticflickr.com/44/162525080_c5b31676f0_o.jpg,Hawk 4 | /m/0fp7c,0991a6b74935b991,1,https://farm2.staticflickr.com/7410/8994260910_3f184016c0_o.jpg,Hawk 5 | /m/0fp7c,0bb6ced3dbfd0028,1,https://c1.staticflickr.com/9/8293/7625588028_85c718fde3_o.jpg,Hawk 6 | /m/0fp7c,03225433b89f66f0,1,https://farm1.staticflickr.com/8354/8345997333_7376bf043e_o.jpg,Hawk 7 | /m/0fp7c,055323b81ac99c57,1,https://farm8.staticflickr.com/4067/4275906471_0f02561237_o.jpg,Hawk 8 | /m/0fp7c,0b8cbed949c1348c,1,https://c5.staticflickr.com/7/6211/6299251927_032762f446_o.jpg,Hawk 9 | /m/0fp7c,162987bf2dedd64c,1,https://c5.staticflickr.com/6/5529/9822190545_9f171a476c_o.jpg,Hawk 10 | /m/0fp7c,19b6f08a80e6cf6f,1,https://c8.staticflickr.com/3/2143/2486820398_8cd643f0f2_o.jpg,Hawk 11 | /m/0fp7c,01fe78acc4b6cff2,1,https://farm3.staticflickr.com/8432/7811850262_4c0a670124_o.jpg,Hawk 12 | /m/0fp7c,03d1abe089763813,1,https://farm4.staticflickr.com/5524/11614684325_1788fcd159_o.jpg,Hawk 13 | /m/0fp7c,18bdd57f2c3649e9,1,https://c2.staticflickr.com/1/57/156768050_71668cd26b_o.jpg,Hawk 14 | /m/0fp7c,0ed6cf921a06ffc3,1,https://farm4.staticflickr.com/7543/15837042400_210de31bdb_o.jpg,Hawk 15 | /m/0fp7c,3a17d7c310845da4,1,https://farm5.staticflickr.com/7329/15847340443_7fd343f22e_o.jpg,Hawk 16 | /m/0fp7c,1e5f5e1c51e433a7,1,https://farm8.staticflickr.com/2844/12019952706_01921a24a1_o.jpg,Hawk 17 | /m/0fp7c,33d07430c9e1c1c0,1,https://farm1.staticflickr.com/2414/1676415495_80b38ec6ba_o.jpg,Hawk 18 | /m/0fp7c,25438ec6557bbeef,1,https://c8.staticflickr.com/9/8576/16505512336_22c1400200_o.jpg,Hawk 19 | 
/m/0fp7c,0454fcd2f6aa87b6,1,https://farm7.staticflickr.com/3141/2675478886_a9bf3f97b9_o.jpg,Hawk 20 | /m/0fp7c,172b142392991d65,1,https://farm6.staticflickr.com/3085/3171199512_f938a191fe_o.jpg,Hawk 21 | /m/0fp7c,0fb32f8309f9e657,1,https://farm8.staticflickr.com/1269/4679488219_3a3c9a63cb_o.jpg,Hawk 22 | -------------------------------------------------------------------------------- /data/examples/alpha_example_set_Sushi.csv: -------------------------------------------------------------------------------- 1 | LabelName,ImageID,Confidence,OriginalURL,DisplayName 2 | /m/07030,07b185567034bbb1,1,https://farm2.staticflickr.com/3778/11680681144_8a85e2cb49_o.jpg,Sushi 3 | /m/07030,e12b26befd92e2a9,1,https://c2.staticflickr.com/8/7525/16106981981_5a0ab10fd4_o.jpg,Sushi 4 | /m/07030,41d62e93c8333940,1,https://farm1.staticflickr.com/5070/5644991286_af30efcbd9_o.jpg,Sushi 5 | /m/07030,1e9f698a78a97100,1,https://farm1.staticflickr.com/2031/2439673474_e991aca2d3_o.jpg,Sushi 6 | /m/07030,65776b7c8603ce56,1,https://farm7.staticflickr.com/5174/5596202858_1aa5c16ee9_o.jpg,Sushi 7 | /m/07030,24571f400a6277b8,1,https://c1.staticflickr.com/3/2595/3951286521_3f32cd30ab_o.jpg,Sushi 8 | /m/07030,881c906e26ec47bd,1,https://c6.staticflickr.com/8/7308/16292070217_acf77c41e3_o.jpg,Sushi 9 | /m/07030,bc15b56a519d9eba,1,https://c3.staticflickr.com/2/1403/655402748_5e25e53ba8_o.jpg,Sushi 10 | /m/07030,0bbe9d3461da8005,1,https://farm6.staticflickr.com/8608/16205382902_fb42692876_o.jpg,Sushi 11 | /m/07030,6973329af45abd17,1,https://farm8.staticflickr.com/8434/7830740592_f6d86ba192_o.jpg,Sushi 12 | /m/07030,7d7e2bbd64b31a36,1,https://farm7.staticflickr.com/3132/3147501854_839c8e5c7d_o.jpg,Sushi 13 | /m/07030,3274aab7ff102683,1,https://farm4.staticflickr.com/3932/15360031479_2ea6152bbb_o.jpg,Sushi 14 | /m/07030,2a711993ced55009,1,https://c2.staticflickr.com/1/137/334330355_15cede6219_o.jpg,Sushi 15 | /m/07030,292d254b27992f39,1,https://c2.staticflickr.com/6/5456/9009681840_d763a0a7db_o.jpg,Sushi 
16 | /m/07030,58d70c954b297c6a,1,https://farm4.staticflickr.com/7404/8929944658_2653ca4be0_o.jpg,Sushi 17 | /m/07030,07382aad542bc263,1,https://farm4.staticflickr.com/76/390941710_8e658151ab_o.jpg,Sushi 18 | /m/07030,0be303aafbfe95cb,1,https://c4.staticflickr.com/6/5263/5622576361_f85a9975ac_o.jpg,Sushi 19 | /m/07030,021a6b70ecd95934,1,https://farm5.staticflickr.com/22/28920062_b95d2d605a_o.jpg,Sushi 20 | /m/07030,0f259cd129018e90,1,https://farm1.staticflickr.com/8572/15945575187_17c0a01a22_o.jpg,Sushi 21 | /m/07030,64275fd17d9c11d9,1,https://c2.staticflickr.com/3/2056/2150626616_2cfe4a423d_o.jpg,Sushi 22 | -------------------------------------------------------------------------------- /data/results/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !result_for_random_500.json 4 | !.gitignore -------------------------------------------------------------------------------- /data/results/result_for_random_500.json: -------------------------------------------------------------------------------- 1 | { 2 | "Cupcake": { 3 | "accuracy": 0.5401459854014599, 4 | "recall": 0.463768115942029, 5 | "precision": 0.5517241379310345, 6 | "f1": 0.5039370078740157 7 | }, 8 | "Hawk": { 9 | "accuracy": 0.296551724137931, 10 | "recall": 0.16831683168316833, 11 | "precision": 0.4857142857142857, 12 | "f1": 0.25000000000000006 13 | }, 14 | "Sushi": { 15 | "accuracy": 0.5185185185185185, 16 | "recall": 0.6261682242990654, 17 | "precision": 0.638095238095238, 18 | "f1": 0.6320754716981132 19 | } 20 | } -------------------------------------------------------------------------------- /data/test_sets/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore -------------------------------------------------------------------------------- /data/train_sets/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | 
!random_500.csv 4 | !.gitignore -------------------------------------------------------------------------------- /data/train_sets/random_500.csv: -------------------------------------------------------------------------------- 1 | ImageID,Confidence 2 | 0002643773a76876,0 3 | 0016a0f096337445,0 4 | 0036043ce525479b,1 5 | 00526f123f84db2f,1 6 | 0080db2599d54447,1 7 | 00978577e9fdd967,1 8 | 00999e0a1c66fa9e,0 9 | 00ae6f48022f527e,1 10 | 00ccbda21e2cc314,1 11 | 00fe9ad027307049,0 12 | 014277a08124380f,0 13 | 017b8e86d7a13543,0 14 | 01a1201b26a7fdc3,1 15 | 01a22546102bd9e8,0 16 | 01b0b1eb40d21471,1 17 | 01b66b3974dbeae9,0 18 | 01f39a574ee5fe0b,1 19 | 020861abeb63ced7,0 20 | 0210ebe5c31c9e4a,0 21 | 021b127ef86e33bb,0 22 | 0238a7e4aeb6e22f,1 23 | 027e69f951f3e970,0 24 | 027f4a5d585c6f54,0 25 | 027f6b1b3c35c224,1 26 | 02837ef702581393,1 27 | 02a24c3e6c5978f2,0 28 | 02c9db8dbfa39515,0 29 | 02e9d387fe04ebd5,0 30 | 02ede7ccd20f04ad,1 31 | 0322f572caa700d6,1 32 | 03283ae91d87a97b,0 33 | 033a3f1bd43e3112,1 34 | 033e8f0c9c33f138,0 35 | 036ed5ce09c82389,0 36 | 03bac6cfe5a8dc39,1 37 | 03bc484df524e1dd,0 38 | 03dfc8b7e008218d,1 39 | 03e30b9bfb3fff4d,1 40 | 03e96e9c6306f122,1 41 | 0408b3cd3dbd9bc2,1 42 | 041610af4aba2faf,0 43 | 042e922e3eb7036d,1 44 | 04652580e6ca3182,1 45 | 047067c3af79c3ca,1 46 | 04915967aaf19826,0 47 | 04957b97d746fc3e,1 48 | 04c6b3e741cacab3,1 49 | 050ab2cd054846c9,1 50 | 051540940bf53eac,1 51 | 052fe61dc71b0eba,1 52 | 0538e05f269f934c,0 53 | 05856fdeff423cb0,1 54 | 0589bffe7dff2c8a,0 55 | 0593a46ed8ff0f99,0 56 | 059d11c42caa06e5,0 57 | 059fc3db254afaa9,1 58 | 05c5e3c2b0e14842,0 59 | 05f9dafe672ca3b1,0 60 | 062ce07c09a59d1c,1 61 | 0655152dba1f84cc,0 62 | 072c4fb6f5651a1b,0 63 | 073fe85b4abbaa83,0 64 | 074f6faf7c4f2424,0 65 | 07826e354f362710,0 66 | 079c907f111f0b68,0 67 | 07b0459fc902c4b3,1 68 | 07e670e9df51fae3,0 69 | 07ed8d27cb1aebc9,0 70 | 07f525146d75da12,1 71 | 08085c0be08cbe27,1 72 | 083afb44856c10b4,1 73 | 0858a52cc1a80652,0 74 | 085baf695c29d1f3,0 75 | 
08946585b8ea5eeb,0 76 | 08b7a60c169c8a68,1 77 | 08cf26629ef459fa,1 78 | 08d52981eea69dc5,0 79 | 08ecb8adf2030622,0 80 | 08f9c841c2a39bf5,0 81 | 09370e426b4f2478,1 82 | 095eab7353db3e6d,1 83 | 096dda4366d118ba,1 84 | 09b6e6848368539e,1 85 | 0a0bdb0cebfa67ee,0 86 | 0a3340f3516f2479,1 87 | 0a4e63f6c188485e,1 88 | 0a684b8fc669b812,0 89 | 0a69b418ad5dbdb0,0 90 | 0a78171e2fe9b152,1 91 | 0a983561e47a4398,1 92 | 0a9d308e5fa151bb,0 93 | 0acbb703c2d13cb2,0 94 | 0ad93624486df1a7,0 95 | 0ad977c36c026730,1 96 | 0aeb361cac11225d,0 97 | 0af2834572e5edfb,0 98 | 0b19e9430e1684ae,0 99 | 0b2373e4fcd56ab0,0 100 | 0b35185242de8b20,0 101 | 0b363dcb204ee6fa,0 102 | 0b44fcd48ce7a66b,1 103 | 0b8bc31e82e8ac51,0 104 | 0bb27693b12008d6,1 105 | 0c5d8bfa724854ed,0 106 | 0c6908ccfc636745,1 107 | 0c8bc6311a5f6b28,1 108 | 0cbfea27c75b7854,1 109 | 0d61a766a5e052c1,1 110 | 0d648bbe5b8a9600,0 111 | 0d7e09d990723b07,0 112 | 0d8787e0f401938b,1 113 | 0dc3994cc6311deb,1 114 | 0dd3211c2ad294a3,1 115 | 0ded1bb5d7b9d152,1 116 | 0e26edc5bc0c8311,0 117 | 0e2ddf22657cb827,1 118 | 0e39471801ad2a43,1 119 | 0e56d521cd207050,1 120 | 0e60973b10e64411,1 121 | 0e69efd1f7b44905,0 122 | 0ecc00179ef54ca5,1 123 | 0ef89705e4158105,1 124 | 0f153d6353f417d4,1 125 | 0f3381b38fcd5cdc,1 126 | 0f5834a4f70f7b20,1 127 | 0f872e1abf7ac6da,0 128 | 0fcfc644a29aa10b,1 129 | 0fe4c64f8692ad84,0 130 | 1000a5246fa70d6d,1 131 | 10103bc8d81860d7,1 132 | 10188cd6c38e5ab8,1 133 | 101921f7b6569691,0 134 | 103a0ee96b890b7a,1 135 | 10407470ddbe7f33,0 136 | 1060886ac7527790,0 137 | 1092892be3a1d74a,1 138 | 10bc4f9d4f9cf0c8,0 139 | 10cc91f3caccdaf5,0 140 | 10ed93f362a259b0,1 141 | 1113516133edbabc,0 142 | 11435a877f12aed5,0 143 | 1169b3404fe2843b,0 144 | 1170571276b6bcad,1 145 | 1174fafefae0428b,0 146 | 1182bd717c802dcc,1 147 | 1186520afb0493a8,0 148 | 119b9d88404d8d48,1 149 | 11bb6b1fb5702099,0 150 | 11bf96bd0e9882db,0 151 | 11da6f070d4771ef,0 152 | 120022a6a24e90cd,1 153 | 12043f1344f75048,1 154 | 123847b6982c7e1e,1 155 | 1246d649fd5e931d,1 156 
| 12498a6ca767492e,0 157 | 124cb0d63a578f9e,0 158 | 124e720a8b9a19db,1 159 | 1263ad7c3ee86f27,0 160 | 129686341e997b40,1 161 | 12d2ce2b79ab9573,0 162 | 131407e04be18ace,1 163 | 1320bbabc132dd2f,0 164 | 134b1c18e6af837f,1 165 | 134edfc0e31dd303,0 166 | 135dc87be1b86802,1 167 | 135f588d040ea7fc,0 168 | 1375d5331cf28f83,0 169 | 1391bcfe081f19ad,1 170 | 13c35bcd1a6dd96b,1 171 | 13c98b0387aa0869,1 172 | 13f865503304fe7f,0 173 | 1419b0ba4824bd3a,0 174 | 141de6f0bb105207,0 175 | 1473761fb505e238,0 176 | 14cba893e4652d12,1 177 | 14f955e1ccb4ce9c,1 178 | 150c090d425a726a,0 179 | 1559343cf042ee0a,0 180 | 160d65e19cd0cdee,0 181 | 1612223f74fb4336,1 182 | 16366d31cf297c34,1 183 | 165098cf74ff01fa,1 184 | 168aa9f930331bde,1 185 | 16bda96251f1b2eb,0 186 | 16fae7ccc1c8d87d,1 187 | 17237768756d0617,0 188 | 1729bf4a981ba70e,0 189 | 17878e5abe11fb8b,0 190 | 178c462adad58992,1 191 | 179540365378298e,0 192 | 17ddd6c38fe6a11f,0 193 | 17f6bd54f41ecd24,1 194 | 1829c7528398489f,1 195 | 183368f31e6fa221,0 196 | 184c8e263fe0137d,1 197 | 1856d5dd7a355dfb,1 198 | 187187dd5c763ad3,0 199 | 187e144c1e8f8e41,1 200 | 18bc27b7dcf34a71,0 201 | 18bd73f2e03b0435,1 202 | 18f26cea382ed7b8,1 203 | 18f3454604de81c5,1 204 | 18f4cac0fa308c26,1 205 | 19153d748d8d5a1c,0 206 | 196665cfc99666e8,0 207 | 19c860012d34b7d5,1 208 | 19e8a28edca6a66b,0 209 | 1a17b0157e89480b,1 210 | 1a26647e9bd3a9db,1 211 | 1a3c2aca3ac642f8,0 212 | 1a446826a6216c9f,0 213 | 1a5fafb0a5ebbe5a,1 214 | 1a8111e0b7aadc91,0 215 | 1aa8314c51ba10ed,0 216 | 1aab8e89f1004b4b,0 217 | 1b62252ca65a12ff,1 218 | 1b643fda9337f548,1 219 | 1b977827d49f7759,0 220 | 1bb4b842b7353905,0 221 | 1bc488ddb14b84ac,0 222 | 1bdc4fb329706237,0 223 | 1c13f85e453a54e4,1 224 | 1c24bbdbcadbe3f4,1 225 | 1ca2cec096be16d4,0 226 | 1cdab95f1e784df0,1 227 | 1d0e19f531cb4898,1 228 | 1d19070c56721c8b,1 229 | 1d2f81eb5d0d5211,1 230 | 1d42dfa41fbc0a34,1 231 | 1d5aebd5afbc6f2b,1 232 | 1d5d4c3437b466b2,0 233 | 1d8d014496cb4824,0 234 | 1da26a35f3e8abf4,1 235 | 1dd47f1748a2bcf0,0 236 
| 1de37429bf3cb6e8,1 237 | 1dec3dded7a4dae4,1 238 | 1e244b7e9f0b8fd9,0 239 | 1e28a5c7598ce0f9,0 240 | 1e41948a34480596,0 241 | 1e436799490e7eeb,0 242 | 1e595e59b37e4a27,0 243 | 1e9a4a1d0fc4ca29,0 244 | 1e9b191d90c2266e,1 245 | 1f192b722b0ad354,1 246 | 1f1ceaeded18db96,0 247 | 1f208689b80cee08,1 248 | 1f462ce1ca995b8c,1 249 | 1f6a53c6b5721fea,0 250 | 1fafa775240417ec,0 251 | 1fb250a9ef9d5a0c,0 252 | 1fe24d5640ce14b1,0 253 | 200b7847da358a2d,0 254 | 200d3d72a4523c40,0 255 | 2011183f59037ef0,0 256 | 201547df732c4156,1 257 | 2048581ccb95d815,1 258 | 2066a81e97489b4c,1 259 | 20759e94df7e69d9,0 260 | 20b6978d63ce149c,0 261 | 20d90ec2e11fc5b4,0 262 | 20ec4206c96cbd3d,0 263 | 2195261d7bcc65c2,0 264 | 21fecdabcc50d165,0 265 | 22640048749f1c61,0 266 | 226c6ad77da16dce,0 267 | 227edbb5aeeae5e9,1 268 | 22c1fba27a65f29c,1 269 | 23071dc96012937c,0 270 | 2315fe86de57a0f8,1 271 | 233404499f9c93c2,1 272 | 23754137155e09b8,0 273 | 2387aace9f6caa87,1 274 | 23f193b64cb7b6fe,1 275 | 243efee4fa9871de,0 276 | 244dd102dd6515e1,1 277 | 245d7557726b022c,1 278 | 245ebc538d8f9229,1 279 | 24c96ef04bf38053,1 280 | 24d6cb92e781a022,1 281 | 25473b5dff92c79d,0 282 | 269317351e751f49,1 283 | 26941883c642ca70,0 284 | 26adf38aadb3aecb,1 285 | 26e9798a6cf03d7c,1 286 | 26fd27d06044746e,0 287 | 272b18a95bd6a03a,0 288 | 274bc1c79f1b98e9,0 289 | 281cd9a58dfcd265,1 290 | 28b7d8656314a9d5,0 291 | 28cbd682e022cf51,1 292 | 292db09aebe28d57,0 293 | 29a3465795676349,0 294 | 2a0193b50f7e409a,0 295 | 2a0bec44f6933b30,1 296 | 2a43c69ad62d823f,0 297 | 2a43e6c4d1122207,1 298 | 2a47907f0d80ad46,0 299 | 2a8422b853757ddd,1 300 | 2b030fd83d070666,1 301 | 2b13140e618e8146,1 302 | 2b304285abaaa6c5,0 303 | 2b49e3793542e110,1 304 | 2b62224b42e0bb1f,1 305 | 2b853fea1d218995,1 306 | 2ba2b02bf30d4f12,0 307 | 2bc36d77f0c12ea6,0 308 | 2bf49d8e608783db,1 309 | 2c0c5ce55d3b9632,1 310 | 2c1ff4a2ec540cdc,1 311 | 2c32a3dd8b886ccc,0 312 | 2c3c25c576b1de15,1 313 | 2c54593283207177,0 314 | 2c70b971603d4411,0 315 | 2c7b082ff81eaf95,0 316 
| 2c7d448252768342,0 317 | 2c9da9d471aaebc9,0 318 | 2cb0911be5234a78,1 319 | 2cb3413bd5efa7b7,1 320 | 2cc4a0f1e4d2c9c9,1 321 | 2d1df68ce872bb9f,0 322 | 2d27bbfd78f9f925,1 323 | 2d5759f12fa9ae85,1 324 | 2d734b2824575942,0 325 | 2dbbb4254b5aed3a,0 326 | 2df2b67715389791,1 327 | 2e0068ac814f792d,0 328 | 2e0f9b2466164356,1 329 | 2e1db05316089325,1 330 | 2e807a42b7b37883,0 331 | 2efa2ff77d256312,0 332 | 2f11455e210fdedf,1 333 | 2f179ca1a5f7039b,0 334 | 2f1adc343f8598d5,0 335 | 2f33584c5ffe72d5,0 336 | 2f5c0482b9732b25,1 337 | 2f73b6bff89bf15a,0 338 | 2f8909ce4e9dae11,0 339 | 2f98d2a51394bee2,1 340 | 2f9f04706938f8af,1 341 | 2fd542f6ac430a94,0 342 | 2fec7a7a1436d1c2,0 343 | 2feec97810f64d63,0 344 | 2ffa3756e321ac79,1 345 | 3019fdd7751a0295,0 346 | 30852efeea1cf314,0 347 | 308a9e78d915d9e9,0 348 | 30b1c6c5eca7f9b8,1 349 | 30ba496d7af5d9a0,1 350 | 30f937df9fc3dd31,0 351 | 3109e63fb93c6e65,1 352 | 31103e5a4701c98c,1 353 | 3127bbd0befaf782,0 354 | 3134bc984a2e3d5c,1 355 | 313977213b0d3c67,1 356 | 3144d031e1d9924d,1 357 | 3170776ff4406e6c,0 358 | 317179c2bc38d651,1 359 | 317678b136554c09,1 360 | 3191765aeac56c0d,1 361 | 31cbcdd5de59681d,0 362 | 31d209e52e881592,0 363 | 31e24288dc274d86,0 364 | 31fddabdb730b698,1 365 | 32099764d1d36634,0 366 | 324b83570c210bd6,1 367 | 325060723d3e21f0,1 368 | 3274a8a2ca494416,1 369 | 328c9475df30f96c,0 370 | 32d021d6e16faffe,0 371 | 32f4f6c0f2da7a97,1 372 | 333e36d6df1426f4,1 373 | 335f9a9f2b05e506,1 374 | 3371da82a80bf4cf,0 375 | 337fdbb659255d17,0 376 | 33e2f04f26101766,0 377 | 33fb3b378f677fc0,1 378 | 34833559bf9ac10b,1 379 | 34bcf62c7a6d20d5,1 380 | 34c0fffcfd317ad0,1 381 | 34e9a871c25e5254,1 382 | 34ec2c8d87192de0,1 383 | 35055b02a0cbaf87,1 384 | 3520d48978e4b1a7,1 385 | 35318d0d30b6bc7f,0 386 | 359ad26db2bdfad1,1 387 | 359bb1dd6df593b2,1 388 | 35b3d53c2e85deb0,0 389 | 35ed6e0029a50970,1 390 | 361f25fe375924eb,1 391 | 364b25a08774360a,0 392 | 364d10980606eb0c,1 393 | 36717da18ce3beab,0 394 | 367edbcf140248f0,0 395 | 36a5d8b4ddebccd2,0 396 
| 36b91e58ae716df1,1 397 | 37106bd1e4921449,1 398 | 371d84b0b24527ce,1 399 | 3741db2dfe9ba01c,0 400 | 37af8f2f0515a0de,0 401 | 37b428bbf7cdac16,1 402 | 37d3f43cd3283a65,1 403 | 37f2aec9140b6146,0 404 | 38298c9c9c07b331,0 405 | 384e15769e0ece9d,1 406 | 3877b5b64d41f204,1 407 | 38833fd2cdcef112,0 408 | 38abb45d15a3a84e,0 409 | 38de74053a55c22a,1 410 | 38de92959fa7efcc,0 411 | 391d57896fe7fbad,1 412 | 394d6fc65e3d4f8e,1 413 | 395140593803e508,1 414 | 395d8800157e415b,1 415 | 3970ddb1df7b91b6,1 416 | 397168f7bf3f6653,0 417 | 39794020134fa81e,0 418 | 39e5af85b70fd17e,1 419 | 3a0c7d2da3eb33d4,1 420 | 3a32ebdfdeebc0bb,1 421 | 3a452f4a1cf72231,0 422 | 3acfd3080478564b,0 423 | 3ad343021062e934,0 424 | 3ae6bbb6c71a4e8f,0 425 | 3b028d879278e08c,1 426 | 3b23f46222944cbc,1 427 | 3b93b36bb3df7daf,0 428 | 3be6c6fe9a104131,0 429 | 3bfccb17fff4a25d,0 430 | 3c1e13485fff69fe,1 431 | 3c275a4741f26ff0,1 432 | 3c34214430a3ae97,0 433 | 3c3f0ec0de3fc5ce,1 434 | 3c4450673e1186b5,1 435 | 3c5840e6dd4a9c95,1 436 | 3c6949c3c02f692d,0 437 | 3c858852b73e4e07,1 438 | 3c8d0cdca76aef58,0 439 | 3c9fec64e9e3bb3a,1 440 | 3cbe1a201cc01b30,0 441 | 3d2bef3d8c660d37,1 442 | 3d43ac868b2b469f,1 443 | 3d530359cc8d7722,0 444 | 3d6bbe368fe9c8ce,1 445 | 3d8e69b99c177e6f,0 446 | 3de9a9ce79b2086c,1 447 | 3e07e3bb4bfbacac,1 448 | 3e2ec5f3ffb39ad3,0 449 | 3e46802f0df37952,1 450 | 3ea731bd8cfd98a2,0 451 | 3eb6b4a4ff8f031b,0 452 | 3ec5f97d27abaec6,1 453 | 3f20024b1aa83fe8,1 454 | 3f24d6a573aaca34,1 455 | 3f3aeca80046b05b,1 456 | 3f7bd2c21a3b77df,1 457 | 3fd36c3be1cbb924,1 458 | 3fef955bb604f22f,1 459 | 40242c87b6e22bba,0 460 | 406b8c6a6cefce38,0 461 | 409bfcca8d4b2be5,0 462 | 40b9d8ec678cb0ed,0 463 | 40de594d9ed169b3,0 464 | 4108e999cdb3626f,1 465 | 41168302ba3b3ae6,0 466 | 416fe6a79d6f94ef,1 467 | 417be12f43852235,0 468 | 418d578372036682,0 469 | 41a138381273629a,0 470 | 41a17dfb2113619c,0 471 | 41d325841762cd99,1 472 | 4210237a8adfa3fa,1 473 | 42470b2b40cb0e08,1 474 | 4261645bcc1dd73c,0 475 | 428a82a713007299,0 476 
| 42fa06ecb1d8541d,0
477 | 42fed3d5dee37fae,1
478 | 430670ee980faa9c,0
479 | 430c1a398dfcd405,0
480 | 436cdf23b0ba4d76,0
481 | 43712821cf93e649,0
482 | 43b342b0679aff9b,0
483 | 43d50a312e08b997,0
484 | 43dc00dc34fd00f9,1
485 | 43fc149f47813d4a,0
486 | 4437e11d92557b30,0
487 | 44880d5ba63aefb0,0
488 | 44b22675dae9cfa6,0
489 | 44fdc19754799879,1
490 | 450d0d858061f3ae,1
491 | 454d4cb8233fe870,1
492 | 4550a814fd718224,1
493 | 4591afc0400e567c,1
494 | 4599350cccbb98e1,0
495 | 45b82f7e20e6614f,1
496 | 461871a1e09b1d0b,0
497 | 4632286207e40352,1
498 | 4643592415251268,0
499 | 4645bbfb33cd424f,1
500 | 4651a988c0b7c7ba,0
501 | 46a1b7b5b72e2863,1
502 |
--------------------------------------------------------------------------------
/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 |   dataperf-visual-submission:
3 |     container_name: dataperf-visual-submission
4 |     build: .
5 |     volumes:
6 |       - ./data:/app/data
--------------------------------------------------------------------------------
/download_data.py:
--------------------------------------------------------------------------------
1 | """Download samples and eval data"""
2 | import argparse
3 | import os
4 | import yaml
5 | import gdown
6 | import zipfile
7 | from tqdm import tqdm
8 |
9 |
10 | def download_file(url, folder_path, file_name):
11 |     """Download url from Google Drive to folder_path/file_name.
12 |
13 |     folder_path is created if it does not exist. Returns the local path.
14 |     """
15 |     os.makedirs(folder_path, exist_ok=True)
16 |     output_path = os.path.join(folder_path, file_name)
17 |     gdown.download(url, output_path, quiet=False, fuzzy=True)
18 |     return output_path
19 |
20 |
21 | def main():
22 |     """Download the resources archive and extract it into --output_path."""
23 |
24 |     parser = argparse.ArgumentParser()
25 |     parser.add_argument(
26 |         "--parameters_file",
27 |         type=str,
28 |         required=True,
29 |         help="File containing parameters for the download",
30 |     )
31 |     parser.add_argument(
32 |         "--output_path", type=str, required=True, help="Path where data will be stored",
33 |     )
34 |     args = parser.parse_args()
35 |
36 |     with open(args.parameters_file, "r") as f:
37 |         params = yaml.safe_load(f)
38 |
39 |     output_path = args.output_path
40 |     dataset_url = params["dataset_url"]
41 |     file_name = "dataperf-vision-selection-resources.zip"
42 |
43 |     archive_path = download_file(dataset_url, output_path, file_name)
44 |
45 |     # Extract member by member so a single bad entry is reported and
46 |     # skipped instead of aborting the whole extraction.
47 |     with zipfile.ZipFile(archive_path) as zf:
48 |         for member in tqdm(zf.infolist(), desc="Extracting "):
49 |             try:
50 |                 zf.extract(member, output_path)
51 |             except zipfile.error as e:
52 |                 print(e)
53 |
54 |
55 | if __name__ == "__main__":
56 |     main()
57 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict
2 |
3 | from pyspark.sql import DataFrame
4 | from sklearn.base import BaseEstimator
5 | from sklearn.linear_model import LogisticRegression
6 | from sklearn.metrics import accuracy_score, recall_score,\
7 |     precision_score, f1_score
8 |
9 | import constants as c
10 |
11 |
12 | def _features_and_labels(df: DataFrame):
13 |     """Collect the embedding (X) and label (y) columns of df as lists."""
14 |     pdf = df.select(c.LABEL_COL, c.EMB_COL).toPandas()
15 |     return pdf[c.EMB_COL].values.tolist(), pdf[c.LABEL_COL].values.tolist()
16 |
17 |
18 | def get_trained_classifier(
19 |         df: DataFrame,
20 |         clf: Optional[BaseEstimator] = None) -> BaseEstimator:
21 |     """Fit a classifier on the embeddings and labels of df and return it.
22 |
23 |     Args:
24 |         df: Spark DataFrame with c.LABEL_COL and c.EMB_COL columns.
25 |         clf: Estimator to fit. Defaults to a new
26 |             LogisticRegression(random_state=c.RANDOM_SEED) on every call.
27 |     """
28 |     if clf is None:
29 |         # Build the default model per call; an instance created in the
30 |         # signature would be shared and re-fitted across all tasks.
31 |         clf = LogisticRegression(random_state=c.RANDOM_SEED)
32 |     X, y = _features_and_labels(df)
33 |     return clf.fit(X, y)
34 |
35 |
36 | def score_classifier(df: DataFrame, clf: BaseEstimator) -> Dict[str, float]:
37 |     """Return accuracy, recall, precision and F1 of clf on df.
38 |
39 |     Recall/precision/F1 treat the string label '1' as the positive class;
40 |     labels are presumably strings (the train CSV is read without schema
41 |     inference) -- TODO confirm the test-set label dtype matches.
42 |     """
43 |     X, y = _features_and_labels(df)
44 |     y_pred = clf.predict(X).tolist()
45 |     return {
46 |         'accuracy': accuracy_score(y, y_pred),
47 |         'recall': recall_score(y, y_pred, pos_label='1'),
48 |         'precision': precision_score(y, y_pred, pos_label='1'),
49 |         'f1': f1_score(y, y_pred, pos_label='1'),
50 |     }
51 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import fire
4 |
5 | import constants as c
6 | import utils as utils
7 | # Aliased so the local eval module does not shadow the builtin eval().
8 | import eval as evaluation
9 |
10 |
11 | def run_tasks(
12 |         setup_yaml_path: str = c.DEFAULT_SETUP_YAML_PATH,
13 |         docker_flag: bool = False) -> None:
14 |     """Runs visual benchmark tasks based on config yaml file.
15 |
16 |     For each task listed in the setup file, trains a classifier on the
17 |     task's training set joined with the shared embeddings, scores it on
18 |     the task's test set, and saves all scores as JSON in the results dir.
19 |
20 |     Args:
21 |         setup_yaml_path (str, optional): Path for config file. Defaults
22 |             to path in constants.DEFAULT_SETUP_YAML_PATH.
23 |         docker_flag (bool, optional): True when running in container
24 |     """
25 |     task_setup = utils.load_yaml(setup_yaml_path)
26 |     data_dir_key = c.SETUP_YAML_DOCKER_DATA_DIR_KEY if docker_flag \
27 |         else c.SETUP_YAML_LOCAL_DATA_DIR_KEY
28 |     data_dir = task_setup[data_dir_key]
29 |     dim = task_setup[c.SETUP_YAML_DIM_KEY]
30 |     emb_path = os.path.join(data_dir, task_setup[c.SETUP_YAML_EMB_KEY])
31 |
32 |     ss = utils.get_spark_session(task_setup[c.SETUP_YAML_SPARK_MEM_KEY])
33 |
34 |     print('Loading embeddings\n')
35 |     emb_df = utils.load_emb_df(ss=ss, path=emb_path, dim=dim)
36 |
37 |     task_paths = {
38 |         task: task_setup[task] for task in task_setup[c.SETUP_YAML_TASKS_KEY]}
39 |     task_scores = {}
40 |     for task, paths in task_paths.items():
41 |         print(f'Evaluating task: {task}')
42 |         # Each task entry is [train_path, test_path], relative to data_dir.
43 |         train_path, test_path = [os.path.join(data_dir, p) for p in paths]
44 |
45 |         print(f'Loading training data for {task}...')
46 |         train_df = utils.load_train_df(ss=ss, path=train_path)
47 |         train_df = utils.add_emb_col(df=train_df, emb_df=emb_df)
48 |
49 |         print(f'Loading test data for {task}...')
50 |         test_df = utils.load_test_df(ss=ss, path=test_path, dim=dim)
51 |
52 |         print(f'Training classifier for {task}...')
53 |         clf = evaluation.get_trained_classifier(df=train_df)
54 |
55 |         print(f'Scoring trained classifier for {task}...\n')
56 |         task_scores[task] = evaluation.score_classifier(df=test_df, clf=clf)
57 |
58 |     save_dir = os.path.join(data_dir, task_setup[c.SETUP_YAML_RESULTS_KEY])
59 |     utils.save_results(data=task_scores, save_dir=save_dir, verbose=True)
60 |
61 |
62 | if __name__ == "__main__":
63 |     fire.Fire(run_tasks)
64 |
--------------------------------------------------------------------------------
/mlcube.py:
--------------------------------------------------------------------------------
1 | """MLCube handler file"""
2 | import os
3 | import subprocess
4 |
5 | import typer
6 | import yaml
7 |
8 | from selection import Predictor
9 |
10 | typer_app = typer.Typer()
11 |
12 |
13 | class DownloadTask:
14 |     """Download samples and eval data"""
15 |
16 |     @staticmethod
17 |     def run(parameters_file: str, output_path: str) -> None:
18 |         """Run download_data.py in a child process and wait for it."""
19 |         # Build an argv list rather than splitting a string, so paths that
20 |         # contain spaces reach the script as single arguments.
21 |         cmd = [
22 |             "python3",
23 |             "download_data.py",
24 |             f"--parameters_file={parameters_file}",
25 |             f"--output_path={output_path}",
26 |         ]
27 |         process = subprocess.Popen(cmd, cwd=".")
28 |         process.wait()
29 |
30 |
31 | class SelectTask:
32 |     """Run select algorithm"""
33 |
34 |     @staticmethod
35 |     def run(
36 |         input_path: str, embeddings_path: str, parameters_file: str, output_path: str
37 |     ) -> None:
38 |         """Select images near the labeled-data centroid and write a CSV.
39 |
40 |         n_closest and n_random are read from the YAML parameters file.
41 |         """
42 |         with open(parameters_file, "r") as f:
43 |             params = yaml.safe_load(f)
44 |
45 |         print("Creating predictor")
46 |         predictor = Predictor(embeddings_path)
47 |         predictor.closest_and_furthest(
48 |             input_path, output_path, params["n_closest"], params["n_random"]
49 |         )
50 |
51 |
52 | class EvaluateTask:
53 |     """Execute evaluation script"""
54 |
55 |     @staticmethod
56 |     def run(eval_path: str, log_path: str) -> None:
57 |         """Run run_evaluate.sh, passing eval_path/log_path via the environment."""
58 |         env = os.environ.copy()
59 |         env.update({"eval_path": eval_path, "log_path": log_path})
60 |
61 |         process = subprocess.Popen("./run_evaluate.sh", cwd=".", env=env)
62 |         process.wait()
63 |
64 |
65 | @typer_app.command("download")
66 | def download(
67 |     parameters_file: str = typer.Option(..., "--parameters_file"),
68 |     output_path: str = typer.Option(..., "--output_path"),
69 | ):
70 |     DownloadTask.run(parameters_file, output_path)
71 |
72 |
73 | @typer_app.command("select")
74 | def select(
75 |     input_path: str = typer.Option(..., "--input_path"),
76 |     embeddings_path: str = typer.Option(..., "--embeddings_path"),
77 |     parameters_file: str = typer.Option(..., "--parameters_file"),
78 |     output_path: str = typer.Option(..., "--output_path"),
79 | ):
80 |     SelectTask.run(input_path, embeddings_path, parameters_file, output_path)
81 |
82 |
83 | @typer_app.command("evaluate")
84 | def evaluate(
85 |     eval_path: str = typer.Option(..., "--eval_path"),
86 |     log_path: str = typer.Option(..., "--log_path"),
87 | ):
88 |     EvaluateTask.run(eval_path, log_path)
89 |
90 |
91 | if __name__ == "__main__":
92 |     typer_app()
93 |
--------------------------------------------------------------------------------
/mlcube.yaml:
--------------------------------------------------------------------------------
1 | name: MLCommons DataPerf Vision Example
2 | description: MLCommons DataPerf integration with MLCube
3 | authors:
4 |   - { name: "MLCommons Best Practices Working Group" }
5 |
6 | platform:
7 |   accelerator_count: 0
8 |
9 | docker:
10 |   # Image name.
11 |   image: mlcommons/dataperf_vision:0.0.1
12 |   # Docker build context relative to $MLCUBE_ROOT. Default is `build`.
13 |   build_context: "."
14 |   # Docker file name within docker build context, default is `Dockerfile`.
15 |   build_file: "Dockerfile_mlcube"
16 |
17 | tasks:
18 |   download:
19 |     # Download data
20 |     parameters:
21 |       inputs: { parameters_file: { type: file, default: parameters.yaml } }
22 |       outputs: { output_path: data/ }
23 |
24 |   select:
25 |     # Run selection script
26 |     parameters:
27 |       inputs: {
28 |         input_path: { type: file, default: data/dataperf-vision-selection-resources/train_sets/random_500.csv },
29 |         embeddings_path: { type: file, default: data/dataperf-vision-selection-resources/embeddings/train_emb_256_dataperf.parquet},
30 |         parameters_file: { type: file, default: parameters.yaml }
31 |       }
32 |       outputs: { output_path: { type: file, default: data/selection_output.csv } }
33 |
34 |   evaluate:
35 |     # Perform evaluation
36 |     parameters:
37 |       inputs: { eval_path: data/dataperf-vision-selection-resources/}
38 |       outputs: { log_path: { type: file, default: log.txt } }
39 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pyspark==3.2.1
2 | pyyaml==6.0
3 | fire==0.4.0
4 | scikit-learn==0.24.0
5 | pandas==1.1.5
6 | typer==0.6.1
7 | gdown==4.5.1
8 | tqdm==4.64.0
9 | pyarrow
10 |
--------------------------------------------------------------------------------
/run_evaluate.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # eval_path and log_path arrive as environment variables (set by
4 | # EvaluateTask in mlcube.py); quoted so paths with spaces survive.
5 | ln -sf "${eval_path}" ./data
6 | python3 main.py 2>&1 | tee "${log_path}"
7 |
--------------------------------------------------------------------------------
/selection.py:
--------------------------------------------------------------------------------
1 | """Selection script
2 | Replace this file with your own logic"""
3 | import pandas as pd
4 | import numpy as np
5 | import shutil
6 |
7 |
8 | class Predictor:
9 |     """Pick training images whose embeddings lie near a labeled centroid."""
10 |
11 |     def __init__(self, embeddings_path: str):
12 |         print("Loading embeddings")
13 |         self.embeddings = self.read_parquet(embeddings_path)
14 |         # Stack per-row vectors into one matrix; row i of embeddings_vect
15 |         # corresponds to row position i of self.embeddings.
16 |         self.embeddings_vect = np.stack(self.embeddings["embedding"].to_numpy())
17 |         print("Embeddings loaded")
18 |
19 |     def read_parquet(self, path: str):
20 |         """Load a parquet file into a pandas DataFrame with pyarrow."""
21 |         return pd.read_parquet(path, engine="pyarrow", use_threads=True)
22 |
23 |     def similarity(self, centroid, n_closest, n_random):
24 |         """Return row positions of sampled embeddings closest to centroid.
25 |
26 |         Draws up to n_random distinct rows at random (capped at the number
27 |         of embeddings), ranks them by Euclidean distance to centroid and
28 |         returns the positions of the n_closest nearest, nearest first.
29 |         The positions index self.embeddings / self.embeddings_vect.
30 |         """
31 |         n_rows = len(self.embeddings_vect)
32 |         random_indexes = np.random.choice(
33 |             n_rows, size=min(n_random, n_rows), replace=False
34 |         )
35 |         distances = np.linalg.norm(
36 |             self.embeddings_vect[random_indexes] - centroid, axis=1
37 |         )
38 |         # argsort ranks positions inside the sample; map them back to
39 |         # positions in the full table before they are used with .iloc.
40 |         nearest_in_sample = np.argsort(distances)[:n_closest]
41 |         return random_indexes[nearest_in_sample]
42 |
43 |     def get_embeddings_labeled_data(self, input_path):
44 |         """Return embedding rows whose ImageID is listed in the CSV input_path."""
45 |         df = pd.read_csv(input_path)
46 |         return self.embeddings.loc[self.embeddings["ImageID"].isin(df["ImageID"])]
47 |
48 |     def calculate_centroid(self, df):
49 |         """Return the element-wise mean of the vectors in df["embedding"]."""
50 |         if df.empty:
51 |             raise ValueError("cannot compute a centroid: no labeled embeddings")
52 |         return np.stack(df["embedding"].to_numpy()).mean(axis=0)
53 |
54 |     def closest_and_furthest(
55 |         self, input_path, output_path, n_closest=100, n_random=1000
56 |     ):
57 |         """Write a submission CSV of images closest to the labeled centroid.
58 |
59 |         The centroid is taken over every ImageID in input_path (regardless
60 |         of its label); each selected ImageID is written with Confidence 1.
61 |         """
62 |         df = self.get_embeddings_labeled_data(input_path)
63 |         print("Embeddings loaded")
64 |         print("Getting centroids")
65 |         centroid = self.calculate_centroid(df)
66 |         print("Running similairty")
67 |         similarity = self.similarity(centroid, n_closest, n_random)
68 |         print("Saving submission")
69 |         submission = pd.DataFrame(self.embeddings.iloc[similarity]["ImageID"])
70 |         submission["Confidence"] = 1
71 |         submission.to_csv(output_path, index=False)
72 |
--------------------------------------------------------------------------------
/task_setup.yaml:
--------------------------------------------------------------------------------
1 | # Memory used by Spark driver (recommend > 6g)
2 | spark_driver_memory: '6g'
3 |
4 | # Path for data directory. Note that all other paths in this setup
5 | # file will be relative to this one (e.g.
if data_dir='some/path' 6 | # and emb_file='to/emb.parquet', the embeddings will be loaded from 7 | # 'some/path/to/emb.parquet') 8 | data_dir: 'data' 9 | 10 | # Relative path for embedding file 11 | emb_file: 'embeddings/train_emb_256_dataperf.parquet' 12 | 13 | # Paths to training and test files for each task. Task names must match the names 14 | # listed in eval_tasks further below. See example below: 15 | # task_name: ['relative/path/to/training_set.csv', 'relative/path/to/test_set.csv'] 16 | Cupcake: ['train_sets/Cupcake.csv', 'test_sets/alpha_test_set_Cupcake_256.parquet'] 17 | Hawk: ['train_sets/Hawk.csv', 'test_sets/alpha_test_set_Hawk_256.parquet'] 18 | Sushi: ['train_sets/Sushi.csv', 'test_sets/alpha_test_set_Sushi_256.parquet'] 19 | 20 | # Relative path for results directory 21 | results_dir: 'results' 22 | 23 | # Parameters specific to alpha below (likely not useful to modify) 24 | 25 | # Embedding dimensionality 26 | dim: 256 27 | 28 | # List the names of tasks to evaluate 29 | eval_tasks: ['Cupcake','Hawk','Sushi'] 30 | 31 | # Path for data directory in docker image.
32 | # DO NOT MODIFY
33 | docker_data_dir: 'data'
34 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from datetime import datetime, timezone
4 |
5 | import yaml
6 | from pyspark.sql import SparkSession, DataFrame
7 |
8 | import constants as c
9 |
10 |
11 | def get_spark_session(spark_driver_memory: str) -> SparkSession:
12 |     """Create (or reuse, via getOrCreate) a SparkSession with the given driver memory."""
13 |     return SparkSession.builder\
14 |         .config('spark.driver.memory', spark_driver_memory)\
15 |         .getOrCreate()
16 |
17 |
18 | def get_emb_dim(df: DataFrame) -> int:
19 |     """Return the length of the embedding vector in the first row of df."""
20 |     return len(df.select(c.EMB_COL).take(1)[0][0])
21 |
22 |
23 | def load_yaml(path: str) -> dict:
24 |     """Parse the YAML file at path with yaml.safe_load and return the result.
25 |
26 |     Raises:
27 |         yaml.YAMLError: if the file is not valid YAML; the error is
28 |             printed before being re-raised.
29 |     """
30 |     with open(path, 'r') as stream:
31 |         try:
32 |             return yaml.safe_load(stream)
33 |         except yaml.YAMLError as exc:
34 |             # Re-raise: falling through would return an unbound name.
35 |             print(exc)
36 |             raise
37 |
38 |
39 | def load_emb_df(ss: SparkSession, path: str, dim: int) -> DataFrame:
40 |     """Load the embeddings parquet and check its columns and dimension."""
41 |     df = ss.read.parquet(path)
42 |
43 |     for col in [c.EMB_COL, c.ID_COL]:
44 |         assert col in df.columns, \
45 |             f'Embedding file does not have "{col}" column'
46 |
47 |     actual_dim = get_emb_dim(df)
48 |     assert actual_dim == dim, \
49 |         f'Embedding file dim={actual_dim}, but setup file specifies dim={dim}'
50 |
51 |     return df
52 |
53 |
54 | def load_train_df(ss: SparkSession, path: str) -> DataFrame:
55 |     """Load a training CSV (with header row) and check its required columns.
56 |
57 |     No inferSchema option is set, so Spark's CSV reader keeps every
58 |     column as a string.
59 |     """
60 |     df = ss.read.option('header', True).csv(path)
61 |
62 |     for col in [c.LABEL_COL, c.ID_COL]:
63 |         assert col in df.columns, \
64 |             f'{path}: Train file does not have "{col}" column'
65 |
66 |     return df
67 |
68 |
69 | def add_emb_col(df: DataFrame, emb_df: DataFrame) -> DataFrame:
70 |     """Inner-join df with emb_df on the ID column to attach embeddings."""
71 |     emb_df = emb_df.select(c.ID_COL, c.EMB_COL)
72 |     return df.join(emb_df, c.ID_COL)
73 |
74 |
75 | def load_test_df(ss: SparkSession, path: str, dim: int) -> DataFrame:
76 |     """Load a test parquet and check its columns and embedding dimension."""
77 |     df = ss.read.parquet(path)
78 |
79 |     # The embedding column is read by get_emb_dim below and by scoring.
80 |     for col in [c.LABEL_COL, c.ID_COL, c.EMB_COL]:
81 |         assert col in df.columns, \
82 |             f'{path}: Test file does not have "{col}" column'
83 |
84 |     actual_dim = get_emb_dim(df)
85 |     assert actual_dim == dim, \
86 |         f'Test file dim={actual_dim}, but setup file specifies dim={dim}'
87 |
88 |     return df
89 |
90 |
91 | def save_results(data: dict, save_dir: str, verbose=False) -> None:
92 |     """Write data as JSON to save_dir/<RESULT_FILE_PREFIX>_<UTC time>.json.
93 |
94 |     save_dir is created if needed. With verbose=True, the output path and
95 |     the pretty-printed results are printed too.
96 |     """
97 |     dt = datetime.now(timezone.utc).strftime("UTC-%Y-%m-%d-%H-%M-%S")
98 |     filename = f'{c.RESULT_FILE_PREFIX}_{dt}.json'
99 |     os.makedirs(save_dir, exist_ok=True)
100 |     path = os.path.join(save_dir, filename)
101 |     with open(path, 'w') as f:
102 |         json.dump(data, f, indent=4)
103 |
104 |     if verbose:
105 |         print(f'Results saved in {path}')
106 |         print(json.dumps(data, indent=4))
107 |
--------------------------------------------------------------------------------
/workspace/parameters.yaml:
--------------------------------------------------------------------------------
1 | dataset_url: "https://drive.google.com/drive/folders/181uI-7NFJwK3IOPy2kOYVIQS4vZOC02A?usp=sharing"
2 | n_random: 500
3 | n_closest: 50
--------------------------------------------------------------------------------