├── .gitignore
├── Dockerfile
├── Dockerfile_mlcube
├── README.md
├── baselines
├── fpscv
│ └── run.py
├── modified_uncertainty_sampling
│ ├── get_tag.py
│ └── run.py
├── other_experiments
│ └── run.ipynb
└── pseudo_label_generation
│ ├── dataperf_vision_experiments.py
│ └── utils.py
├── constants.py
├── data
├── embeddings
│ └── .gitignore
├── examples
│ ├── alpha_example_set_Cupcake.csv
│ ├── alpha_example_set_Hawk.csv
│ └── alpha_example_set_Sushi.csv
├── results
│ ├── .gitignore
│ └── result_for_random_500.json
├── test_sets
│ └── .gitignore
└── train_sets
│ ├── .gitignore
│ └── random_500.csv
├── docker-compose.yaml
├── download_data.py
├── eval.py
├── main.py
├── mlcube.py
├── mlcube.yaml
├── requirements.txt
├── run_evaluate.sh
├── selection.py
├── task_setup.yaml
├── utils.py
└── workspace
└── parameters.yaml
/.gitignore:
--------------------------------------------------------------------------------
1 | workdir/
2 | .vscode
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
133 | # Scratch space
134 | scratch/
135 |
136 | # MLCube
137 | workspace/*
138 | !workspace/parameters.yaml
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.9-slim
2 |
3 | RUN apt update -y
4 | RUN apt install default-jdk -y
5 |
6 | COPY requirements.txt /app/requirements.txt
7 | RUN pip install -r /app/requirements.txt
8 |
9 | COPY *.py /app/
10 | COPY *.yaml /app/
11 |
12 | WORKDIR /app/
13 | ENTRYPOINT python3 main.py --docker_flag True
14 |
--------------------------------------------------------------------------------
/Dockerfile_mlcube:
--------------------------------------------------------------------------------
1 | FROM ubuntu:18.04
2 |
3 | RUN apt-get update && apt-get install -y \
4 | software-properties-common && apt install default-jdk -y
5 | RUN apt-get install -y python3.8 python3-pip
6 |
7 | COPY requirements.txt /app/requirements.txt
8 | RUN pip3 install --upgrade pip && pip3 install -r /app/requirements.txt
9 |
10 | COPY *.py /app/
11 | COPY *.yaml /app/
12 | COPY *.sh /app/
13 |
14 | WORKDIR /app/
15 | RUN chmod +x /app/run_evaluate.sh
16 | ENTRYPOINT ["python3", "/app/mlcube.py"]
17 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # `dataperf-vision-selection`: A Data-Centric Visual Benchmark for Training Data Selection
2 |
3 | ### **Current version:** beta
4 |
5 | This GitHub repo serves as the starting point for offline evaluation of submissions for the training data selection visual benchmark. The offline evaluation can be run both in your local environment and in a containerized image for reproducible score results.
6 |
7 | For a detailed summary of the benchmark, refer to the provided [documentation](https://www.dataperf.org/training-set-selection-vision).
8 |
9 | *Note that permission is required to view the benchmark documentation and download the required resources. Please contact dataperf@coactive.ai to request access.*
10 |
11 | ## Requirements
12 |
13 | ### Download resources
14 |
15 | The following resources will need to be downloaded locally in order to run offline evaluation:
16 |
17 | - Embeddings for candidate pool of training images (.parquet file)
18 | - Test sets for each classification task (.parquet files)
19 |
20 | These resources can be downloaded in a .zip file at the following url
21 |
22 | ```
23 | https://drive.google.com/drive/folders/181uI-7NFJwK3IOPy2kOYVIQS4vZOC02A?usp=sharing
24 | ```
25 |
26 | ### Install dependencies
27 |
28 | For running as a containerized image:
29 |
30 | - `docker` for building the containerized image
31 | - `docker-compose` for running the scoring service with the appropriate resources
32 |
33 | Installation instructions can be found at the following links: [Docker](https://docs.docker.com/get-docker/), [Docker compose](https://docs.docker.com/compose/install/)
34 |
35 | For running locally:
36 |
37 | - Python (>= 3.7)
38 | - An [appropriate version of Java](https://spark.apache.org/docs/latest/) for your version of `python` and `pyspark`
39 |
40 | The current version of this repo has only been tested locally on python 3.9 and java openjdk-11.
41 |
42 | ## Installation
43 |
44 | Clone this repo to your local machine
45 |
46 | ```
47 | git clone git@github.com:CoactiveAI/dataperf-vision-selection.git
48 | ```
49 |
50 | If you want to run the offline evaluation in your local environment, install the required python packages
51 |
52 | ```
53 | pip install -r dataperf-vision-selection/requirements.txt
54 | ```
55 |
56 | A template filesystem with the following structure is provided in the repo. Move the embeddings file and the test sets to the appropriate folders in this template filesystem:
57 |
58 | ```
59 | unzip dataperf-vision-selection-resources.zip
60 | mv dataperf-vision-selection-resources/embeddings/* dataperf-vision-selection/data/embeddings/
61 | mv dataperf-vision-selection-resources/examples/* dataperf-vision-selection/data/examples/
62 | mv dataperf-vision-selection-resources/test_sets/* dataperf-vision-selection/data/test_sets/
63 | mv dataperf-vision-selection-resources/train_sets/* dataperf-vision-selection/data/train_sets/
64 | mv dataperf-vision-selection-resources/results/* dataperf-vision-selection/data/results/
65 | ```
66 |
67 | The resulting filesystem in the repo should look as follows
68 |
69 | ```
70 | |____data
71 | | |____embeddings
72 | | | |____train_emb_256_dataperf.parquet
73 | | |____examples
74 | | | |____alpha_example_set_Hawk.csv
75 | | | |____alpha_example_set_Cupcake.csv
76 | | | |____alpha_example_set_Sushi.csv
77 | | |____test_sets
78 | | | |____alpha_test_set_Hawk_256.parquet
79 | | | |____alpha_test_set_Cupcake_256.parquet
80 | | | |____alpha_test_set_Sushi_256.parquet
81 | | |____train_sets
82 | | | |____random_500.csv
83 | | |____results
84 | | | |____result_for_random_500.json
85 | ```
86 |
87 | With the resources in place, you can now test that the system is functioning as expected.
88 |
89 | To test the containerized offline evaluation, run
90 |
91 | ```
92 | cd dataperf-vision-selection
93 | docker-compose up
94 | ```
95 |
96 | Similarly, to test the local python offline evaluation, run
97 |
98 | ```
99 | cd dataperf-vision-selection
100 | python3 main.py
101 | ```
102 |
103 | Either test will run the offline evaluation using the setup specified in `task_setup.yaml`, which will utilize a training set of randomly sampled and labeled data points (`data/train_sets/random_500.csv`) to generate a score results file in `data/results/` with a unique UTC timestamp in its name:
104 |
105 | ```
106 | |____data
107 | | |____results
108 | | | |____result_for_random_500.json
109 | | | |____result_UTC-2022-03-31-20-19-24.json
110 | ```
111 |
112 | The generated scores in this new results file should be identical to those in `data/results/result_for_random_500.json`.
113 |
114 | # MLCube execution
115 |
116 | The [MLCube](https://github.com/mlcommons/mlcube) implementation allows us to execute the project using the following steps.
117 |
118 | ## Project setup
119 |
120 | ```bash
121 | # Create Python environment and install MLCube Docker runner
122 | virtualenv -p python3 ./env && source ./env/bin/activate && pip install mlcube-docker
123 |
124 | # Fetch the vision selection repo
125 | git clone https://github.com/CoactiveAI/dataperf-vision-selection && cd ./dataperf-vision-selection
126 | ```
127 |
128 | ## Tasks execution
129 |
130 | ```bash
131 | # Download and extract dataset
132 | mlcube run --task=download -Pdocker.build_strategy=always
133 |
134 | # Run evaluation
135 | mlcube run --task=evaluate -Pdocker.build_strategy=always
136 | ```
137 |
138 | ## Execute complete pipeline
139 |
140 | ```bash
141 | # Run all steps
142 | mlcube run --task=download,evaluate -Pdocker.build_strategy=always
143 | ```
144 |
145 | # Guidelines v0.5
146 |
147 | For v0.5 of this benchmark, we support offline and online evaluation for the open division.
148 |
149 | ## Open Division: Creating a submission
150 |
151 | A valid submission for the open division includes the following:
152 |
153 | - A description of the data selection algorithm/strategy used
154 | - A training set for each classification task as specified below
155 | - (Optional) A script of the algorithm/strategy used
156 |
157 | Each training set file must be a .csv file containing two columns: `ImageID` (the unique identifier for the image) and `Confidence` (the binary label, either a `0` or `1`). The `ImageID`s in the training set files must be limited to the provided candidate pool of training images (i.e. `ImageID`s in the downloaded embeddings file).
158 |
159 | The included training set file serves as a template of a single training set:
160 |
161 | ```
162 | cat dataperf-vision-selection/data/train_sets/random_500.csv
163 |
164 | ImageID,Confidence
165 | 0002643773a76876,0
166 | 0016a0f096337445,0
167 | 0036043ce525479b,1
168 | 00526f123f84db2f,1
169 | 0080db2599d54447,1
170 | 00978577e9fdd967,1
171 | ...
172 | ```
173 |
174 | ## Open Division: Offline evaluation
175 |
176 | The configuration for the offline evaluation is specified in the `task_setup.yaml` file. For simplicity, the repo comes pre-configured such that for offline evaluation you can simply:
177 |
178 | 1. Copy your training sets to the template filesystem
179 | 2. Modify the config file to specify the training set for each task
180 | 3. Run offline evaluation
181 | 4. See results in stdout and results file in `data/results/`
182 |
183 | For example
184 |
185 | ```
186 | # 1. Copy training sets for each task
187 | cd dataperf-vision-selection
188 | cp /path/to/your/training/sets/Cupcake.csv data/train_sets/
189 | cp /path/to/your/training/sets/Hawk.csv data/train_sets/
190 | cp /path/to/your/training/sets/Sushi.csv data/train_sets/
191 |
192 | # 2. task_setup.yaml: modify the training set relative path for each classification task
193 | Cupcake: ['train_sets/Cupcake.csv', 'test_sets/alpha_test_set_Cupcake_256.parquet']
194 | Hawk: ['train_sets/Hawk.csv', 'test_sets/alpha_test_set_Hawk_256.parquet']
195 | Sushi: ['train_sets/Sushi.csv', 'test_sets/alpha_test_set_Sushi_256.parquet']
196 |
197 | # 3a. Run offline evaluation (docker)
198 | docker-compose up --build --force-recreate
199 |
200 | # 3b. Run offline evaluation (local python)
201 | python3 main.py
202 |
203 | # 4. See results (file will have save timestamp in name)
204 | cat data/results/result_UTC-2022-03-31-20-19-24.json
205 |
206 | {
207 | "Cupcake": {
208 | "accuracy": 0.5401459854014599,
209 | "recall": 0.463768115942029,
210 | "precision": 0.5517241379310345,
211 | "f1": 0.5039370078740157
212 | },
213 | "Hawk": {
214 | "accuracy": 0.296551724137931,
215 | "recall": 0.16831683168316833,
216 | "precision": 0.4857142857142857,
217 | "f1": 0.25000000000000006
218 | },
219 | "Sushi": {
220 | "accuracy": 0.5185185185185185,
221 | "recall": 0.6261682242990654,
222 | "precision": 0.638095238095238,
223 | "f1": 0.6320754716981132
224 | }
225 | }
226 | ```
227 |
228 | Though we recommend working as described above, you can specify a custom task setup .yaml file and/or data folder if needed.
229 |
230 | For the containerized offline evaluation, modify the following files and run as follows
231 |
232 | ```
233 | # docker-compose.yaml: modify the volume source
234 | volumes:
235 | - path/to/your/data/folder:/app/data
236 |
237 | # Dockerfile: modify the COPY *.yaml command and specify the new file in the entrypoint
238 | COPY path/to/your/custom_task_setup.yaml /app/
239 | ...
240 | ENTRYPOINT python3 main.py --docker_flag True --setup_yaml_path 'custom_task_setup.yaml'
241 |
242 | # Run and force rebuild
243 | docker-compose up --build --force-recreate
244 | ```
245 |
246 | For the local python offline evaluation, modify the following files and run as follows
247 |
248 | ```
249 | # path/to/your/custom_task_setup.yaml: modify data_dir
250 | data_dir: 'path/to/your/data/folder'
251 |
252 | # Run and specify custom .yaml file
253 | python3 main.py --setup_yaml_path 'path/to/your/custom_task_setup.yaml'
254 | ```
255 |
256 | *Note: when specifying a data folder, ensure all relative paths in the task setup .yaml file are valid*
257 |
258 | ## Open Division: Online evaluation
259 |
260 | Final submissions are evaluated online with Dynabench, our online evaluation system. Please submit to the [Vision Dataperf task](https://dynabench.org/tasks/vision-selection).
261 |
262 |
263 | ## Closed Division: Creating a submission
264 |
265 | TBD.
266 |
267 | ## Closed Division: Offline evaluation of a submission
268 |
269 | TBD.
270 |
271 | # Baselines
272 |
273 | In the `baselines/` directory, we include the winning submissions for the beta version of this challenge, which also serve as our baselines. The methods are:
274 | 1. FPSCV (Farthest Point Sampling Cross-Validation) by Paolo Climaco: This method selects negative examples by sampling the embedding space with farthest point sampling (iteratively picking the point with the maximum L2 distance from the points already selected), and then
275 | returns the best coreset under nested cross-validation. The official repository for this submission can be found at [Submissions-Dataperf-Vision-Challenge](https://github.com/PaClimaco/Submissions-Dataperf-Vision-Challenge/), which also includes a detailed report on this method.
276 | 2. Pseudo Label Generation by Danilo Brajovic: This method trains multiple neural networks and classical models on a subset of data to classify the remainder of points and uses the best-performing model for coreset proposal under multiple sampling experiments
277 | 3. Modified Uncertainty Sampling by Steve Mussmann: This method trains a binary classifier on noisy positive labels from OpenImages and uses this classifier to assign positive and negative image pools, with the coreset randomly sampled from both pools.
278 | Rafael Mosquera experimented with several classical and neural-network-based models, which we have included under `other_experiments/`.
279 |
280 | ```
281 | # Run the baselines (fpscv reads its data from ../../data/, so run it from baselines/fpscv/)
282 | (cd baselines/fpscv && python3 run.py)
283 | python3 baselines/pseudo_label_generation/run.py
284 | python3 baselines/modified_uncertainty_sampling/run.py
285 | ```
--------------------------------------------------------------------------------
/baselines/fpscv/run.py:
--------------------------------------------------------------------------------
1 | # Author: Paolo Climaco
2 | # Author email: climaco@ins.uni-bonn.de
3 |
4 | import numpy as np
5 | from operator import itemgetter
6 | import torch
7 | from dgl.geometry import farthest_point_sampler
8 | from tqdm import tqdm
9 | import pandas as pd
10 | import numpy as np
11 | import shutil
12 | from typing import Dict, List
13 | import sklearn
14 | import sklearn.model_selection
15 | import sklearn.ensemble
16 | import sklearn.linear_model
17 | import tqdm
18 | import os
19 | import csv
20 | import json
21 | from sklearn.metrics import f1_score
22 | import warnings
23 | import requests
24 |
# Silence all warnings for the whole run. This has to be a plain module-level
# call: inside `with warnings.catch_warnings():` the previous filters are
# restored as soon as the with-block exits, so the filter would not apply to
# anything executed afterwards.
warnings.simplefilter("ignore")

# ------------------------------------------
# PATHS
# -----------------------------------------

# These paths are relative to the current working directory, so run this
# script from baselines/fpscv/ (where ../../data/ is the repository's data/).
embeddings_path = "../../data/embeddings/train_emb_256_dataperf.parquet"
examples_path = "../../data/examples/"
output_path = "output_repo/"  # Change
machine_human_labels_path = os.path.join(
    output_path, "oidv6-train-annotations-human-imagelabels.csv"
)  # DO NOT CHANGE

# Create the output directory if it does not exist yet.
os.makedirs(output_path, exist_ok=True)
42 |
43 |
44 | # ------------------------------------------
45 | # FUNCTIONS
46 | # -----------------------------------------
def download_csv_from_url(url, local_filename, timeout=60):
    """Stream ``url`` to ``local_filename``, printing (not raising) request errors.

    The body is written in 8 KiB chunks to ``local_filename + ".part"`` and
    only renamed to ``local_filename`` once the whole response has been
    received. A failed or interrupted download therefore never leaves a
    truncated file at ``local_filename``, and a caller that retries with
    ``pd.read_csv(local_filename)`` gets ``FileNotFoundError`` instead of
    parsing an incomplete CSV.

    Parameters
    ----------
    url : str
        Address to download from.
    local_filename : str
        Destination path. Its parent directory must already exist.
    timeout : float or tuple, optional
        Passed to ``requests.get``. It bounds connection setup and each socket
        read, so a stalled server cannot hang the script forever.
    """
    part_filename = f"{local_filename}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(part_filename, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(part_filename, local_filename)
        print(f"File downloaded and saved as: {local_filename}")
    except requests.exceptions.HTTPError as e:
        print(f"HTTP Error: {e}")
    except requests.exceptions.RequestException as e:
        print(f"Request Exception: {e}")
    finally:
        # Only present if the download did not complete.
        if os.path.exists(part_filename):
            os.remove(part_filename)
60 |
def read_parquet(path: str) -> pd.DataFrame:
    """Load the parquet file at ``path`` into a DataFrame (pyarrow, multi-threaded).

    Parameters
    ----------
    path : str
        Location of the parquet file; required.
    """
    return pd.read_parquet(path, engine="pyarrow", use_threads=True)
63 |
64 |
def get_labelled_samples(task=None, mhl=None, embeddings=None, examples_path=None):
    """Return the pool images that are human-verified positives for ``task``.

    The task's label is read from the ``LabelName`` column (first row) of
    ``alpha_example_set_{task}.csv`` in ``examples_path``. Rows of ``mhl`` with
    that label whose ``ImageID`` occurs in ``embeddings`` are joined with the
    matching ``embeddings`` rows. Only the rows with ``Confidence == 1`` are kept.

    Parameters
    ----------
    task : str
        Task name used to locate the example-set CSV (e.g. ``"Hawk"``).
    mhl : pandas.DataFrame
        Label table with ``ImageID``, ``LabelName`` and ``Confidence`` columns.
        It must hold at most one row per (ImageID, LabelName).
    embeddings : pandas.DataFrame
        Candidate pool with ``ImageID`` and ``embedding`` columns.
    examples_path : str
        Directory containing the example-set CSVs.

    Returns
    -------
    tuple of (numpy.ndarray, pandas.DataFrame)
        The positives' embeddings stacked row-wise, and the corresponding
        ``embeddings`` rows with an added ``Confidence`` column (all 1).
    """
    examples_df = pd.read_csv(
        os.path.join(examples_path, f"alpha_example_set_{task}.csv")
    )
    label_name = examples_df["LabelName"].iloc[0]

    in_pool = mhl[mhl["ImageID"].isin(embeddings["ImageID"])]
    labelled_samples_df = in_pool[in_pool["LabelName"] == label_name]
    labelled_samples_pd = embeddings[
        embeddings["ImageID"].isin(labelled_samples_df["ImageID"])
    ].copy()

    # One hash-based lookup per image instead of re-filtering
    # labelled_samples_df for every image, which is quadratic in the number of
    # labelled images. Duplicate ImageIDs for this label make the lookup raise.
    confidence_by_id = labelled_samples_df.set_index("ImageID")["Confidence"]
    labelled_samples_pd["Confidence"] = (
        labelled_samples_pd["ImageID"].map(confidence_by_id).to_numpy()
    )

    labelled_samples_vect = np.stack(labelled_samples_pd["embedding"].to_numpy())
    is_positive = (labelled_samples_pd["Confidence"] == 1).to_numpy()
    return labelled_samples_vect[is_positive], labelled_samples_pd[is_positive]
92 |
93 |
def csv_to_json(csv, json_file_path):
    """Write an ``ImageID -> Confidence`` mapping from a selection table as JSON.

    Parameters
    ----------
    csv : pandas.DataFrame
        Table with ``ImageID`` and ``Confidence`` columns. The parameter name
        is kept for backward compatibility; inside this function it shadows
        the ``csv`` module.
    json_file_path : str
        Destination file, overwritten if it exists. It is written with
        4-space indentation. Confidences are cast to ``int``, and a repeated
        ImageID keeps its last value.
    """
    labels = {
        image_id: int(confidence)
        for image_id, confidence in zip(csv["ImageID"], csv["Confidence"])
    }
    with open(json_file_path, "w") as json_file:
        json.dump(labels, json_file, indent=4)
105 |
106 |
107 | # ------------------------------------------
108 | # DOWNLOAD EMBEDDINGS
109 | # -----------------------------------------
print("Loading embeddings")
embeddings = read_parquet(embeddings_path)
# One row per candidate image (assuming every embedding is a 1-D vector).
embeddings_vect = np.stack(embeddings["embedding"].to_numpy())
print("Embeddings loaded")


# ------------------------------------------
# LOAD MACHINE-HUMAN LABELS
# -----------------------------------------


def _load_human_labels(csv_path):
    """Read the OpenImages human-verified label CSV, downloading it if missing."""
    try:
        return pd.read_csv(csv_path)
    except FileNotFoundError:
        print("loading human-verified labels...")
        download_csv_from_url(
            "https://storage.googleapis.com/openimages/v6/oidv6-train-annotations-human-imagelabels.csv",
            csv_path,
        )
        return pd.read_csv(csv_path)


mhl = _load_human_labels(machine_human_labels_path)


# ------------------------------------------
# FARTHEST POINT SAMPLING: around 17 minutes on the first run; the result is
# cached as a .npy file in output_path and reloaded afterwards.
# ------------------------------------------


def _load_or_compute_fps(cache_file, n_points=1000):
    """Return row positions in ``embeddings`` picked by farthest point sampling."""
    try:
        return np.load(cache_file)
    except FileNotFoundError:
        print("FPS running...it takes a few minutes but it is a one-time effort")
        # The sampler is fed a batch of one point cloud, hence the leading axis
        # added here and the [0] taken from its result.
        batch = torch.from_numpy(embeddings_vect).unsqueeze_(0)
        picked = farthest_point_sampler(batch, n_points, start_idx=0).numpy()[0].tolist()
        np.save(cache_file, picked)
        return np.asarray(picked)


idx_fps = _load_or_compute_fps(os.path.join(output_path, "FPS_CrossV_init_0.npy"))
143 |
144 |
145 | # ------------------------------------------
146 | # SELECTION ALGORITHM
147 | # ------------------------------------------
148 |
149 |
def Selection(
    task="None",
    n_confidence1=5,
    n_confidence0=40,
    n_folds=5,
    mhl=mhl,
    examples_path=examples_path,
    embeddings=embeddings,
    output_path=output_path,
    idx_fps=idx_fps,
):
    """Select and save a small labelled training set for ``task`` (FPS + nested CV).

    Candidate pools:
    - Positive (confidence 1) candidates are the task's human-verified
      positives found in ``embeddings``.
    - Negative (confidence 0) candidates are the farthest-point-sampled rows
      ``idx_fps``.

    ``n_folds`` random positive subsets are each paired with ``n_folds`` random
    negative subsets. For every pair, a logistic-regression classifier is
    trained on the subsets and scored with binary F1 on the remaining
    candidates. The best-scoring pair is saved.

    Parameters
    ----------
    task : str
        The name of the target class, e.g. Hawk, Cupcake or Sushi.
    n_confidence1 : int
        Number of points in the final selected set with confidence 1.
    n_confidence0 : int
        Number of points in the final selected set with confidence 0.
    n_folds : int
        Number of random splits drawn per class; n_folds ** 2 models are fit.
    mhl : pandas.DataFrame
        Human-verified OpenImages labels (``ImageID``, ``LabelName``, ``Confidence``).
    examples_path : str
        Directory containing ``alpha_example_set_{task}.csv``.
    embeddings : pandas.DataFrame
        Candidate pool with ``ImageID`` and ``embedding`` columns.
    output_path : str
        Folder where the selection is written.
    idx_fps : numpy.ndarray
        Integer row positions in ``embeddings`` chosen by farthest point
        sampling; these are the negative candidates.

    Returns
    -------
    None
        Writes ``{task}_fpscv_ft2.csv`` and ``{task}_fpscv_ft2.json`` to
        ``output_path`` with the selected ImageIDs and their confidence (0 or 1).
    """
    confidence1_samples, confidence1_samples_df = get_labelled_samples(
        task, mhl, embeddings, examples_path
    )
    # Negative candidates are taken from the ``embeddings`` argument, so they
    # always correspond to the ImageIDs exported below via embeddings.iloc[idx_fps].
    confidence0_samples = np.stack(embeddings["embedding"].iloc[idx_fps].to_numpy())
    confidence1_labels = np.ones(confidence1_samples.shape[0])
    confidence0_labels = np.zeros(confidence0_samples.shape[0])

    # Both splitters use an int random_state, so every .split() call yields the
    # same splits. The negative splitter can therefore be built once, and every
    # positive split is paired with the same sequence of negative splits.
    crossfold_confidence1 = sklearn.model_selection.StratifiedShuffleSplit(
        n_splits=n_folds,
        train_size=n_confidence1,
        random_state=23,
    )
    crossfold_confidence0 = sklearn.model_selection.StratifiedShuffleSplit(
        n_splits=n_folds,
        train_size=n_confidence0,
        random_state=23,
    )

    # Start below any attainable F1 so the first pair is always recorded.
    # Otherwise, if every pair scored 0, the indices would stay None and the
    # export below would fail.
    best_score = -np.inf
    best_confidence1_train_ixs = None
    best_confidence0_train_ixs = None

    for confidence1_train_ixs, confidence1_val_ixs in tqdm.tqdm(
        crossfold_confidence1.split(confidence1_samples, confidence1_labels),
        desc="k-fold cross validation",
        total=n_folds,
    ):
        for confidence0_train_ixs, confidence0_val_ixs in crossfold_confidence0.split(
            confidence0_samples, confidence0_labels
        ):
            train_Xs = np.vstack(
                [
                    confidence1_samples[confidence1_train_ixs],
                    confidence0_samples[confidence0_train_ixs],
                ]
            )
            train_ys = np.concatenate(
                [
                    confidence1_labels[confidence1_train_ixs],
                    confidence0_labels[confidence0_train_ixs],
                ]
            )

            clf = sklearn.ensemble.VotingClassifier(
                estimators=[
                    ("lr", sklearn.linear_model.LogisticRegression()),
                ],
                voting="soft",
                weights=None,
                n_jobs=-1,
            )
            clf.fit(train_Xs, train_ys)

            # Validate on every candidate not used for training.
            val_Xs = np.vstack(
                [
                    confidence1_samples[confidence1_val_ixs],
                    confidence0_samples[confidence0_val_ixs],
                ]
            )
            val_ys = np.concatenate(
                [
                    confidence1_labels[confidence1_val_ixs],
                    confidence0_labels[confidence0_val_ixs],
                ]
            )

            pred_Ys = clf.predict(val_Xs)

            score = f1_score(val_ys, pred_Ys, labels=[0, 1], average="binary")
            if score > best_score:
                best_score = score
                best_confidence1_train_ixs = confidence1_train_ixs
                best_confidence0_train_ixs = confidence0_train_ixs

    # Positives are addressed by row in confidence1_samples_df; negatives are
    # mapped from split positions back to pool rows through idx_fps.
    submission_confidence1 = pd.DataFrame(
        confidence1_samples_df.iloc[best_confidence1_train_ixs]["ImageID"]
    )
    submission_confidence1.loc[:, "Confidence"] = np.ones(
        len(submission_confidence1), dtype=int
    ).tolist()

    submission_confidence0 = pd.DataFrame(
        embeddings.iloc[idx_fps[best_confidence0_train_ixs]]["ImageID"]
    )
    submission_confidence0.loc[:, "Confidence"] = np.zeros(
        len(submission_confidence0), dtype=int
    ).tolist()

    submission = pd.concat([submission_confidence1, submission_confidence0])

    # Save the selected elements in csv and json format.
    submission.to_csv(os.path.join(output_path, f"{task}_fpscv_ft2.csv"), index=False)
    csv_to_json(submission, os.path.join(output_path, f"{task}_fpscv_ft2.json"))
301 |
302 |
303 | # ------------------------------------------
304 | # RUN SELECTION ALGORITHM
305 | # ------------------------------------------
306 |
# The submitted variants all run the same Selection algorithm and differ only
# in n_confidence1 / n_confidence0 / n_folds:
#   fpscv     (average F1 78.10%): Selection(task) with the defaults, for every task
#   fpscv_ft  (average F1 79.95%): Hawk 5/20, Cupcake 5/40, Sushi 10/20
#             (n_confidence1/n_confidence0, default n_folds)
#   fpscv_ft2 (average F1 81%):    the settings below, which are the ones run
FPSCV_FT2_SETTINGS = {
    "Hawk": {"n_confidence1": 5, "n_confidence0": 90, "n_folds": 4},
    "Cupcake": {"n_confidence1": 5, "n_confidence0": 40},
    "Sushi": {"n_confidence1": 10, "n_confidence0": 20},
}

for task_name, settings in FPSCV_FT2_SETTINGS.items():
    print(task_name)
    Selection(task_name, **settings)
339 |
340 |
341 | # ----------------------------------------------
342 | # how to improve fpscv algorithm
343 | # ----------------------------------------------
344 |
345 | """
346 | Notice that the fpscv, fpscv_ft and fpscv_ft2 are implemented using the same algorithm that has been initialized
347 | considering different values for the parameters n_confidence1, n_confidence0, and n_folds. Different choices of these
348 | parameters lead to different results.
349 |
350 |
351 | The main issue of the developed approach is its sensitivity to the choices of the parameters
352 | n_confidence1, n_confidence0 and n_folds. Such choices have been made heuristically. Can we do better than heuristics?
353 | Developing a principled approach for the optimization of the mentioned parameters may lead to a substantial improvement in terms
354 | of the F1 score on each of the considered tasks.
355 |
356 | Please if you do develop such a principled approach do not hesitate to contact me. I will be curious to hear about it.
357 |
358 | """
359 |
--------------------------------------------------------------------------------
/baselines/modified_uncertainty_sampling/get_tag.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import requests
4 |
# Label id (as it appears in the annotation CSVs) and output file per task;
# choose the task to extract with TASK.
TASKS = {
    "cupcake": ("/m/03p1r4", "cupcake.txt"),
    "hawk": ("/m/0fp7c", "hawk.txt"),
    "sushi": ("/m/07030", "sushi_human.txt"),
}
TASK = "sushi"
target_tag, output_file = TASKS[TASK]
14 |
15 |
16 | # ------------------------------------------
17 | # FUNCTIONS
18 | # -----------------------------------------
def download_csv_from_url(url, local_filename, timeout=60):
    """Stream ``url`` to ``local_filename``, printing (not raising) request errors.

    The body is written in 8 KiB chunks to ``local_filename + ".part"`` and
    only renamed to ``local_filename`` once the whole response has been
    received. A failed or interrupted download therefore never leaves a
    truncated file at ``local_filename``, and a caller that retries with
    ``pd.read_csv(local_filename)`` gets ``FileNotFoundError`` instead of
    parsing an incomplete CSV.

    Parameters
    ----------
    url : str
        Address to download from.
    local_filename : str
        Destination path. Its parent directory must already exist.
    timeout : float or tuple, optional
        Passed to ``requests.get``. It bounds connection setup and each socket
        read, so a stalled server cannot hang the script forever.
    """
    part_filename = f"{local_filename}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(part_filename, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(part_filename, local_filename)
        print(f"File downloaded and saved as: {local_filename}")
    except requests.exceptions.HTTPError as e:
        print(f"HTTP Error: {e}")
    except requests.exceptions.RequestException as e:
        print(f"Request Exception: {e}")
    finally:
        # Only present if the download did not complete.
        if os.path.exists(part_filename):
            os.remove(part_filename)
31 |
32 |
33 | # ------------------------------------------
34 | # LOAD MACHINE-HUMAN LABELS
35 | # -----------------------------------------
36 |
output_path = "output_repo/"  # Change
machine_human_labels_path = os.path.join(
    output_path, "oidv6-train-annotations-human-imagelabels.csv"
)  # DO NOT CHANGE

# The download writes into output_path, so the directory must exist first.
os.makedirs(output_path, exist_ok=True)

try:
    mhl = pd.read_csv(machine_human_labels_path)
except FileNotFoundError:
    print("loading human-verified labels...")
    url = "https://storage.googleapis.com/openimages/v6/oidv6-train-annotations-human-imagelabels.csv"
    download_csv_from_url(url, machine_human_labels_path)
    mhl = pd.read_csv(machine_human_labels_path)
# NOTE(review): mhl is not used by the rest of this script, which reads
# annotations-human.csv / annotations-machine.csv instead; confirm whether
# this download is still needed.
48 |
49 |
# Write only human-verified positives; set to True to also include the
# machine-generated positives in output_file.
INCLUDE_MACHINE_LABELS = False


def read_positive_image_ids(annotations_path, tag, threshold=0.5):
    """Return image ids from an annotation CSV whose label is ``tag`` at >= ``threshold``.

    Each line is split into four comma-separated fields: image id, an ignored
    field, label and confidence. A header line is harmless as long as its
    third field differs from ``tag``, because the ``float`` conversion is then
    never evaluated.
    """
    image_ids = []
    with open(annotations_path, "r") as f:
        for line in f:
            image_id, _, label, value = line.strip().split(",")
            if label == tag and float(value) >= threshold:
                image_ids.append(image_id)
    return image_ids


human_labels = read_positive_image_ids("annotations-human.csv", target_tag)
machine_labels = read_positive_image_ids("annotations-machine.csv", target_tag)

print(len(human_labels))
print(len(machine_labels))

print(len(set(human_labels).intersection(machine_labels)))

selected_ids = set(human_labels)
if INCLUDE_MACHINE_LABELS:
    selected_ids.update(machine_labels)

# Sorted so the output file is reproducible; iteration order of a set of
# strings changes between runs with hash randomization.
with open(output_file, "w") as f:
    f.writelines(f"{image_id}\n" for image_id in sorted(selected_ids))
83 |
--------------------------------------------------------------------------------
/baselines/modified_uncertainty_sampling/run.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | import OED
4 | from sklearn.linear_model import LogisticRegression
5 |
# Target concept and training-set construction strategy ("even" or "diverse").
task = "hawk"
strategy = "diverse"

# Image ids; the rest of the script assumes entry i corresponds to row i of
# data/embeddings.npy (TODO confirm against the embedding export).
with open("data/embedding_image_ids.pkl", "rb") as f:
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    image_ids = pickle.load(f)
print(len(image_ids))

# Reverse lookup: image id -> row index (the last occurrence wins on duplicates).
imid_to_idx = {image_id: row for row, image_id in enumerate(image_ids)}

with open("data/embeddings.npy", "rb") as f:
    embeddings = np.load(f)
print(embeddings.shape)

# Center every embedding dimension around zero.
embeddings = embeddings - np.mean(embeddings, axis=0)

# Known positives for this task: one image id per line, blank lines ignored.
with open("data/" + task + ".txt", "r") as f:
    positives = [line.strip() for line in f if line.strip()]
print(len(positives))

# Only positives that have an embedding are usable; sorted so downstream
# processing sees the same order on every run.
embedded_positives = sorted(set(positives).intersection(image_ids))
print(len(embedded_positives))
35 |
36 |
37 |
n, d = embeddings.shape

# Binary target: 1 for rows whose image id is a known positive, 0 otherwise.
labels = np.zeros((n,))
positive_idx = [imid_to_idx[pos] for pos in embedded_positives]
labels[positive_idx] = 1

# C = 1e6 means almost no regularisation.
model = LogisticRegression(C=10**6, max_iter=1000)

model_path = "data/model_" + task + ".pkl"
print("starting fit")
try:
    # Reuse a previously fitted model cached on disk.  A stale cache (fitted
    # on different data) is reused as-is; delete the file to force a refit.
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(model_path, "rb") as f:
        model = pickle.load(f)
except FileNotFoundError:
    # No cached model yet: fit one and cache it for later runs.
    model.fit(embeddings, labels)
    with open(model_path, "wb") as f:
        pickle.dump(model, f)
print("finished fit")
fit_labels = model.predict(embeddings)

# Confusion counts on the training data itself.
TP = np.sum(fit_labels[positive_idx] == 1)
FN = np.sum(fit_labels[positive_idx] != 1)
FP = np.sum(np.logical_and(fit_labels == 1, labels == 0))

print("True positives: {}\nFalse negatives: {}\nFalse positives: {}".format(TP, FN, FP))
64 |
65 |
# Predicted probability of the positive class for every row.
probs = model.predict_proba(embeddings)[:, 1]

# One percent of the pool (0 when there are fewer than 100 rows, which makes
# the band below empty).
pp = n // 100

# Hard-negative band: rows ranked between the top 5% and the top 2% by
# predicted probability.  The band is the same for every task.
topp = np.argsort(probs)[-5 * pp:-2 * pp]

# Positives eligible for the training set.  For "sushi" every labelled
# positive is used.  For every other task, only labelled positives the model
# also predicts with probability >= 0.5 are used.
if task == "sushi":
    true_idx = (labels == 1).nonzero()[0]
else:
    true_idx = np.logical_and(probs >= 0.5, labels == 1).nonzero()[0]

# Band members that are not already eligible positives, in ascending row
# order so sampling downstream is reproducible for a fixed RNG state.
close_idx = sorted(set(topp).difference(true_idx))

print(len(close_idx))
print(len(true_idx))
94 |
# Total number of rows (positives + negatives) in the generated training set.
TRAIN_SET_SIZE = 1000
# Positive budget per task for the "diverse" strategy; unlisted tasks use 500.
DIVERSE_POSITIVES = {"cupcake": 500, "sushi": 500, "hawk": 333}


def _write_train_set(path, positive_rows, negative_rows):
    """Write an ``ImageID,Confidence`` CSV: positives as 1, negatives as 0.

    `positive_rows` / `negative_rows` are row indices into `image_ids`.
    """
    with open(path, "w") as f:
        f.write("ImageID,Confidence\n")
        for idx in positive_rows:
            f.write("{},1\n".format(image_ids[idx]))
        for idx in negative_rows:
            f.write("{},0\n".format(image_ids[idx]))


# Rows are selected before the output file is opened, so a failed selection
# no longer leaves a truncated CSV behind.
if strategy == "even":
    # Up to 500 positives sampled uniformly at random; negatives fill the rest.
    true_labels = min(500, len(true_idx))
    chosen_pos = np.random.choice(true_idx, size=(true_labels,), replace=False)
    chosen_neg = np.random.choice(
        close_idx, size=(TRAIN_SET_SIZE - true_labels,), replace=False
    )
    _write_train_set("data/train_sets/" + task + "_even.csv", chosen_pos, chosen_neg)
elif strategy == "diverse":
    # Same cap as "even": never ask for more positives than are available.
    true_labels = min(DIVERSE_POSITIVES.get(task, 500), len(true_idx))

    # OED.random returns positions into the embedding matrix it is given;
    # map them back to global row indices.
    design_points = OED.random(embeddings[true_idx, :], true_labels)
    chosen_pos = [true_idx[point] for point in design_points]

    design_points = OED.random(embeddings[close_idx, :], TRAIN_SET_SIZE - true_labels)
    chosen_neg = [close_idx[point] for point in design_points]

    _write_train_set("data/train_sets/" + task + "_27.csv", chosen_pos, chosen_neg)
else:
    raise ValueError(f"unknown strategy {strategy!r}; expected 'even' or 'diverse'")

print("done!")
138 |
--------------------------------------------------------------------------------
/baselines/other_experiments/run.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "d0ef2ba7-b0cd-436a-bdde-716e47f1c67f",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import pandas as pd\n",
11 | "import numpy as np"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 2,
17 | "id": "d7a5c8b4-6f1a-4dfe-8a79-e66ea6ddd3aa",
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "path = '/Users/rafael/Downloads/dataperf-vision-selection-resources/embeddings/train_emb_256_dataperf.parquet/part-00000-c487d15f-d29d-4c93-96c2-4caf4c12cd59-c000.snappy.parquet'"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 3,
27 | "id": "f9c17b76-ce54-43a0-b673-5b8ee297a5d9",
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "df = pd.read_parquet(path)"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": 4,
37 | "id": "6edd744e-4da3-4aad-8432-88383fd19e4f",
38 | "metadata": {
39 | "tags": []
40 | },
41 | "outputs": [],
42 | "source": [
43 | "for i in range(10):\n",
44 | " parquet_num = str(i).zfill(5)\n",
45 | " path = f'/Users/rafael/Downloads/dataperf-vision-selection-resources/embeddings/train_emb_256_dataperf.parquet/part-{parquet_num}-c487d15f-d29d-4c93-96c2-4caf4c12cd59-c000.snappy.parquet'\n",
46 | " df2 = pd.read_parquet(path)\n",
47 | " df = df.append(df2)"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 20,
53 | "id": "948a5074-99c6-46fc-aed3-f09d2ed26b67",
54 | "metadata": {},
55 | "outputs": [
56 | {
57 | "data": {
58 | "text/html": [
59 | "
\n",
60 | "\n",
73 | "
\n",
74 | " \n",
75 | " \n",
76 | " | \n",
77 | " ImageID | \n",
78 | " embedding | \n",
79 | "
\n",
80 | " \n",
81 | " \n",
82 | " \n",
83 | " 0 | \n",
84 | " 0001771c86cadcd6 | \n",
85 | " [0.7272707919841341, 1.336483982316891, -0.164... | \n",
86 | "
\n",
87 | " \n",
88 | " 1 | \n",
89 | " 0003db14e5cdc1c4 | \n",
90 | " [1.7126857273368188, -0.11685025791386651, -1.... | \n",
91 | "
\n",
92 | " \n",
93 | " 2 | \n",
94 | " 00080e248756d4dd | \n",
95 | " [0.8326104343431097, 0.4016596514351839, 1.017... | \n",
96 | "
\n",
97 | " \n",
98 | " 3 | \n",
99 | " 00094d307650046a | \n",
100 | " [0.3379979616695768, -0.3716915576725939, -0.9... | \n",
101 | "
\n",
102 | " \n",
103 | " 4 | \n",
104 | " 000b76a9b80ba43a | \n",
105 | " [1.0386764605271153, 1.0388317175923953, 0.425... | \n",
106 | "
\n",
107 | " \n",
108 | " ... | \n",
109 | " ... | \n",
110 | " ... | \n",
111 | "
\n",
112 | " \n",
113 | " 16677 | \n",
114 | " ffee8a942d5f46f5 | \n",
115 | " [-0.4032568158266959, -1.1216809493878481, 0.3... | \n",
116 | "
\n",
117 | " \n",
118 | " 16678 | \n",
119 | " fff5531bf50e14ff | \n",
120 | " [-1.3310225644952542, 0.6147678624111613, -0.7... | \n",
121 | "
\n",
122 | " \n",
123 | " 16679 | \n",
124 | " fff8991fe2ddeb7e | \n",
125 | " [-1.085723378698474, -0.998268968641006, -0.29... | \n",
126 | "
\n",
127 | " \n",
128 | " 16680 | \n",
129 | " fffd06d7764cfc52 | \n",
130 | " [-1.8009028307864514, 1.2498299457095756, -0.3... | \n",
131 | "
\n",
132 | " \n",
133 | " 16681 | \n",
134 | " fffd81ba57715c4c | \n",
135 | " [-0.12192140594927368, -1.9914731328587736, -0... | \n",
136 | "
\n",
137 | " \n",
138 | "
\n",
139 | "
182547 rows × 2 columns
\n",
140 | "
"
141 | ],
142 | "text/plain": [
143 | " ImageID embedding\n",
144 | "0 0001771c86cadcd6 [0.7272707919841341, 1.336483982316891, -0.164...\n",
145 | "1 0003db14e5cdc1c4 [1.7126857273368188, -0.11685025791386651, -1....\n",
146 | "2 00080e248756d4dd [0.8326104343431097, 0.4016596514351839, 1.017...\n",
147 | "3 00094d307650046a [0.3379979616695768, -0.3716915576725939, -0.9...\n",
148 | "4 000b76a9b80ba43a [1.0386764605271153, 1.0388317175923953, 0.425...\n",
149 | "... ... ...\n",
150 | "16677 ffee8a942d5f46f5 [-0.4032568158266959, -1.1216809493878481, 0.3...\n",
151 | "16678 fff5531bf50e14ff [-1.3310225644952542, 0.6147678624111613, -0.7...\n",
152 | "16679 fff8991fe2ddeb7e [-1.085723378698474, -0.998268968641006, -0.29...\n",
153 | "16680 fffd06d7764cfc52 [-1.8009028307864514, 1.2498299457095756, -0.3...\n",
154 | "16681 fffd81ba57715c4c [-0.12192140594927368, -1.9914731328587736, -0...\n",
155 | "\n",
156 | "[182547 rows x 2 columns]"
157 | ]
158 | },
159 | "execution_count": 20,
160 | "metadata": {},
161 | "output_type": "execute_result"
162 | }
163 | ],
164 | "source": [
165 | "df"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": 15,
171 | "id": "46290559-0c60-412d-986c-ac65395e7bea",
172 | "metadata": {},
173 | "outputs": [],
174 | "source": [
175 | "import boto3"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": 16,
181 | "id": "fd00414c-2760-4316-ab94-9fe7a3459b58",
182 | "metadata": {},
183 | "outputs": [],
184 | "source": [
185 | "session = boto3.Session(profile_name = 'MLCommons', region_name = \"us-west-1\")\n",
186 | "s3_client = session.client('s3')"
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": 45,
192 | "id": "c4bccb48-9779-4fa6-9422-4c4843ec62b5",
193 | "metadata": {},
194 | "outputs": [],
195 | "source": [
196 | "available_ids = []"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 46,
202 | "id": "31d644ea-9d26-48e3-82c4-f8d4c9b8fae8",
203 | "metadata": {},
204 | "outputs": [],
205 | "source": [
206 | "for i in s3_client.list_objects_v2(\n",
207 | " Bucket = 'vision-dataperf',\n",
208 | " Prefix = 'public_train_dataset_256embeddings_dynabench_formatted/train')['Contents']:\n",
209 | " #StartAfter = 'public_train_dataset_256embeddings_dynabench_formatted/train0053e81841897c73.npy' )['Contents']:\n",
210 | " id = i['Key']\n",
211 | " id = id.strip('public_train_dataset_256embeddings_dynabench_formatted/train')\n",
212 | " id = id.strip('.npy')\n",
213 | " if df.loc[df['ImageID'] == id].empty:\n",
214 | " continue\n",
215 | " else:\n",
216 | " available_ids.append(id)\n",
217 | " #print(df.loc[df['ImageID'] == id])"
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": 89,
223 | "id": "537cf93d-e270-4fed-b97a-ad10ce24610a",
224 | "metadata": {},
225 | "outputs": [],
226 | "source": [
227 | "n = 1000"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": 90,
233 | "id": "66b2c742-5f77-4844-b62f-ce9d01053e19",
234 | "metadata": {},
235 | "outputs": [],
236 | "source": [
237 | "hawks = {}\n",
238 | "for i in random.sample(available_ids,n):\n",
239 | " if (bool(random.getrandbits(1))):\n",
240 | " hawks[i] = 1\n",
241 | " else:\n",
242 | " hawks[i] = 0\n",
243 | "hawks = {'Hawks': hawks}\n",
244 | "with open(\"samples/Hawks.json\", \"w\") as outfile:\n",
245 | " json.dump(hawks, outfile)"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": 91,
251 | "id": "b2f959f3-b6de-4f24-9213-31cee1684ee9",
252 | "metadata": {},
253 | "outputs": [],
254 | "source": [
255 | "sushi = {}\n",
256 | "for i in random.sample(available_ids,n):\n",
257 | " if (bool(random.getrandbits(1))):\n",
258 | " sushi[i] = 1\n",
259 | " else:\n",
260 | " sushi[i] = 0\n",
261 | "sushi = {'Sushi': sushi}\n",
262 | "with open(\"samples/Sushi.json\", \"w\") as outfile:\n",
263 | " json.dump(sushi, outfile)"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": 92,
269 | "id": "1cbe05e9-434a-4078-a7bf-5726e570f9b2",
270 | "metadata": {},
271 | "outputs": [],
272 | "source": [
273 | "cupcake = {}\n",
274 | "for i in random.sample(available_ids,n):\n",
275 | " if (bool(random.getrandbits(1))):\n",
276 | " cupcake[i] = 1\n",
277 | " else:\n",
278 | " cupcake[i] = 0\n",
279 | "cupcake = {'Cupcake': cupcake}\n",
280 | "with open(\"samples/Cupcake.json\", \"w\") as outfile:\n",
281 | " json.dump(cupcake, outfile)"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": 6,
287 | "id": "5ff94e29-d09d-443c-834f-cedac83461a9",
288 | "metadata": {},
289 | "outputs": [],
290 | "source": [
291 | "import pickle"
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "execution_count": 205,
297 | "id": "60b3acda-59e3-41c1-bfc8-06940db5bff8",
298 | "metadata": {},
299 | "outputs": [],
300 | "source": [
301 | "with open('outfile', 'wb') as fp:\n",
302 | " pickle.dump(set(available_ids), fp)"
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": 53,
308 | "id": "437d2f25-61c7-49eb-95c5-cbda837a3676",
309 | "metadata": {},
310 | "outputs": [],
311 | "source": [
312 | "with open ('outfile', 'rb') as fp:\n",
313 | " available_ids = list(pickle.load(fp))"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": 55,
319 | "id": "4b3b482a-7e51-471a-a45d-0b065e567e38",
320 | "metadata": {},
321 | "outputs": [],
322 | "source": [
323 | "import random\n",
324 | "import json"
325 | ]
326 | },
327 | {
328 | "cell_type": "code",
329 | "execution_count": 56,
330 | "id": "c6729787-ff3e-48d0-b174-afd9e63db753",
331 | "metadata": {},
332 | "outputs": [],
333 | "source": [
334 | "n = 1000"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": 57,
340 | "id": "e9493083-f67d-4902-b97a-aacbe08d5e8a",
341 | "metadata": {},
342 | "outputs": [],
343 | "source": [
344 | "birds = {}\n",
345 | "for i in random.sample(available_ids,n):\n",
346 | " if (bool(random.getrandbits(1))):\n",
347 | " birds[i] = 1\n",
348 | " else:\n",
349 | " birds[i] = 0\n",
350 | "birds = {'Birds': birds}\n",
351 | "with open(\"samples/Birds.json\", \"w\") as outfile:\n",
352 | " json.dump(birds, outfile)"
353 | ]
354 | },
355 | {
356 | "cell_type": "code",
357 | "execution_count": 58,
358 | "id": "582e0970-faaf-4a08-9c60-b0b7e9e97f7d",
359 | "metadata": {},
360 | "outputs": [],
361 | "source": [
362 | "canoe = {}\n",
363 | "for i in random.sample(available_ids,n):\n",
364 | " if (bool(random.getrandbits(1))):\n",
365 | " canoe[i] = 1\n",
366 | " else:\n",
367 | " canoe[i] = 0\n",
368 | "canoe = {'Canoe': canoe}\n",
369 | "with open(\"samples/Canoe.json\", \"w\") as outfile:\n",
370 | " json.dump(canoe, outfile)"
371 | ]
372 | },
373 | {
374 | "cell_type": "code",
375 | "execution_count": 60,
376 | "id": "bd9de99e-213e-426f-b9e6-9e05f9ac4fe4",
377 | "metadata": {},
378 | "outputs": [],
379 | "source": [
380 | "croissant = {}\n",
381 | "for i in random.sample(available_ids,n):\n",
382 | " if (bool(random.getrandbits(1))):\n",
383 | " croissant[i] = 1\n",
384 | " else:\n",
385 | " croissant[i] = 0\n",
386 | "croissant = {'Croissant': croissant}\n",
387 | "with open(\"samples/Croissant.json\", \"w\") as outfile:\n",
388 | " json.dump(croissant, outfile)"
389 | ]
390 | },
391 | {
392 | "cell_type": "code",
393 | "execution_count": 61,
394 | "id": "f79f8173-ff0c-46ef-93fb-17ca85bea5a1",
395 | "metadata": {},
396 | "outputs": [],
397 | "source": [
398 | "muffin = {}\n",
399 | "for i in random.sample(available_ids,n):\n",
400 | " if (bool(random.getrandbits(1))):\n",
401 | " muffin[i] = 1\n",
402 | " else:\n",
403 | " muffin[i] = 0\n",
404 | "muffin = {'Muffin': muffin}\n",
405 | "with open(\"samples/Muffin.json\", \"w\") as outfile:\n",
406 | " json.dump(muffin, outfile)"
407 | ]
408 | },
409 | {
410 | "cell_type": "code",
411 | "execution_count": 62,
412 | "id": "f14f619b-e331-44e4-887f-73ec8afdc865",
413 | "metadata": {},
414 | "outputs": [],
415 | "source": [
416 | "pizza = {}\n",
417 | "for i in random.sample(available_ids,n):\n",
418 | " if (bool(random.getrandbits(1))):\n",
419 | " pizza[i] = 1\n",
420 | " else:\n",
421 | " pizza[i] = 0\n",
422 | "pizza = {'Pizza': pizza}\n",
423 | "with open(\"samples/Pizza.json\", \"w\") as outfile:\n",
424 | " json.dump(pizza, outfile)"
425 | ]
426 | },
427 | {
428 | "cell_type": "code",
429 | "execution_count": 136,
430 | "id": "c45054cd-6ff7-46b0-8674-9115f8715286",
431 | "metadata": {},
432 | "outputs": [
433 | {
434 | "name": "stdout",
435 | "output_type": "stream",
436 | "text": [
437 | "50 50\n"
438 | ]
439 | }
440 | ],
441 | "source": [
442 | "count = 0\n",
443 | "for id in list(croissant[\"Croissant\"].keys()):\n",
444 | "\n",
445 | " if df.loc[df['ImageID'] == id].empty:\n",
446 | " continue\n",
447 | " else:\n",
448 | " count += 1\n",
449 | "print(count, len(croissant[\"Croissant\"]))"
450 | ]
451 | },
452 | {
453 | "cell_type": "code",
454 | "execution_count": 138,
455 | "id": "5e9ddcf9-fab1-4680-9baf-092dd0c017c1",
456 | "metadata": {},
457 | "outputs": [
458 | {
459 | "data": {
460 | "text/html": [
461 | "\n",
462 | "\n",
475 | "
\n",
476 | " \n",
477 | " \n",
478 | " | \n",
479 | " ImageID | \n",
480 | " embedding | \n",
481 | "
\n",
482 | " \n",
483 | " \n",
484 | " \n",
485 | " 0 | \n",
486 | " 000335f65ee227cf | \n",
487 | " [-1.383943737943668, -0.9189053270236168, -0.1... | \n",
488 | "
\n",
489 | " \n",
490 | "
\n",
491 | "
"
492 | ],
493 | "text/plain": [
494 | " ImageID embedding\n",
495 | "0 000335f65ee227cf [-1.383943737943668, -0.9189053270236168, -0.1..."
496 | ]
497 | },
498 | "execution_count": 138,
499 | "metadata": {},
500 | "output_type": "execute_result"
501 | }
502 | ],
503 | "source": [
504 | "id = '000335f65ee227cf'\n",
505 | "df.loc[df['ImageID'] == id]"
506 | ]
507 | },
508 | {
509 | "cell_type": "code",
510 | "execution_count": null,
511 | "id": "cadedc40-4a3f-435e-9672-7c722a7fc193",
512 | "metadata": {},
513 | "outputs": [],
514 | "source": [
515 | "parquet_paths = ['/Users/rafael/Downloads/part-00000-fed1d8f9-f6ab-4f54-88aa-9908abd59bec-c000.snappy.parquet','/Users/rafael/Downloads/part-00001-fed1d8f9-f6ab-4f54-88aa-9908abd59bec-c000.snappy.parquet']\n",
516 | "df = pd.read_parquet(parquet_paths[0])\n",
517 | "df2 = pd.read_parquet(parquet_paths[1])\n",
518 | "df = df.append(df2)\n",
519 | "df"
520 | ]
521 | },
522 | {
523 | "cell_type": "code",
524 | "execution_count": null,
525 | "id": "36563be4-a891-4ae9-91f1-922049477bd6",
526 | "metadata": {},
527 | "outputs": [],
528 | "source": []
529 | },
530 | {
531 | "cell_type": "code",
532 | "execution_count": 15,
533 | "id": "b8ca8457-5bcc-4121-9cce-637749b8c4cc",
534 | "metadata": {},
535 | "outputs": [],
536 | "source": [
537 | "muffin2 = {}\n",
538 | "for i in range(51):\n",
539 | " if i%2 == 1:\n",
540 | " muffin2[df['ImageID'].values[i]] = 1\n",
541 | " else:\n",
542 | " muffin2[df['ImageID'].values[i]] = 0\n",
543 | "muffin2 = {'Muffin': muffin2}\n",
544 | "with open(\"Muffin2.json\", \"w\") as outfile:\n",
545 | " json.dump(muffin2, outfile)"
546 | ]
547 | },
548 | {
549 | "cell_type": "code",
550 | "execution_count": 157,
551 | "id": "ded50481-c85d-4b4e-b3f3-049c5b691b31",
552 | "metadata": {},
553 | "outputs": [
554 | {
555 | "data": {
556 | "text/plain": [
557 | "{'dataperf_f1': 54.4,\n",
558 | " 'perf': 54.4,\n",
559 | " 'perf_std': 0.0,\n",
560 | " 'perf_by_tag': [{'tag': 'bird',\n",
561 | " 'pretty_perf': '54.4 %',\n",
562 | " 'perf': 54.4,\n",
563 | " 'perf_std': 0.0,\n",
564 | " 'perf_dict': {'dataperf_f1': 54.4}}]}"
565 | ]
566 | },
567 | "execution_count": 157,
568 | "metadata": {},
569 | "output_type": "execute_result"
570 | }
571 | ],
572 | "source": [
573 | "{\"dataperf_f1\": 54.4, \"perf\": 54.4, \"perf_std\": 0.0, \"perf_by_tag\": [{\"tag\": \"bird\", \"pretty_perf\": \"54.4 %\", \"perf\": 54.4, \"perf_std\": 0.0, \"perf_dict\": {\"dataperf_f1\": 54.4}}]}"
574 | ]
575 | },
576 | {
577 | "cell_type": "code",
578 | "execution_count": 296,
579 | "id": "706ead65-2a25-4d71-8daf-e2ac15f23e22",
580 | "metadata": {},
581 | "outputs": [],
582 | "source": [
583 | "df_new = pd.read_parquet('/Users/rafael/Downloads/dataperf-vision-selection-resources/test_sets/test-hawk.parquet')"
584 | ]
585 | },
586 | {
587 | "cell_type": "code",
588 | "execution_count": 297,
589 | "id": "d726fce3-438b-433a-9d39-02d4390ea718",
590 | "metadata": {},
591 | "outputs": [
592 | {
593 | "data": {
594 | "text/plain": [
595 | "['target_label', 'Embedding', 'ImageID']"
596 | ]
597 | },
598 | "execution_count": 297,
599 | "metadata": {},
600 | "output_type": "execute_result"
601 | }
602 | ],
603 | "source": [
604 | "['target_label','Embedding', 'ImageID']"
605 | ]
606 | },
607 | {
608 | "cell_type": "code",
609 | "execution_count": 298,
610 | "id": "d138dd58-87cf-41bd-b704-84eac08770b8",
611 | "metadata": {},
612 | "outputs": [],
613 | "source": [
614 | "df_new = df_new.rename({'ImageID':'ImageID', 'Confidence': 'target_label', 'embedding': 'Embedding'}, axis=1)"
615 | ]
616 | },
617 | {
618 | "cell_type": "code",
619 | "execution_count": 299,
620 | "id": "dc671ecf-2c3d-46d7-8929-6be6fb432b4e",
621 | "metadata": {},
622 | "outputs": [],
623 | "source": [
624 | "df_new.to_parquet(path = '/Users/rafael/Downloads/dataperf-vision-selection-resources/test_sets/test-hawk.parquet')"
625 | ]
626 | },
627 | {
628 | "cell_type": "code",
629 | "execution_count": 300,
630 | "id": "b06f4993-7d45-4a37-a021-6a0e0ae092ef",
631 | "metadata": {},
632 | "outputs": [
633 | {
634 | "data": {
635 | "text/html": [
636 | "\n",
637 | "\n",
650 | "
\n",
651 | " \n",
652 | " \n",
653 | " | \n",
654 | " ImageID | \n",
655 | " LabelName | \n",
656 | " target_label | \n",
657 | " OriginalURL | \n",
658 | " DisplayName | \n",
659 | " Embedding | \n",
660 | "
\n",
661 | " \n",
662 | " \n",
663 | " \n",
664 | " 0 | \n",
665 | " e60d148d428e96cd | \n",
666 | " /m/0fp7c | \n",
667 | " 0 | \n",
668 | " https://c6.staticflickr.com/6/5560/14764818909... | \n",
669 | " Hawk | \n",
670 | " [-0.0038521280919204775, 0.06254096050041212, ... | \n",
671 | "
\n",
672 | " \n",
673 | " 1 | \n",
674 | " bc2260ec5b70b7e5 | \n",
675 | " /m/0fp7c | \n",
676 | " 0 | \n",
677 | " https://farm6.staticflickr.com/3657/3607619201... | \n",
678 | " Hawk | \n",
679 | " [0.19757792792581547, -0.0628706741084504, -1.... | \n",
680 | "
\n",
681 | " \n",
682 | " 2 | \n",
683 | " be7aae0d065c11a0 | \n",
684 | " /m/0fp7c | \n",
685 | " 1 | \n",
686 | " https://farm1.staticflickr.com/7253/6996634521... | \n",
687 | " Hawk | \n",
688 | " [0.17347788607282416, -0.37831432327491554, -1... | \n",
689 | "
\n",
690 | " \n",
691 | " 3 | \n",
692 | " 25880c4f38d01747 | \n",
693 | " /m/0fp7c | \n",
694 | " 1 | \n",
695 | " https://farm8.staticflickr.com/94/257466316_26... | \n",
696 | " Hawk | \n",
697 | " [0.38980357065885335, -0.9966469970143762, -1.... | \n",
698 | "
\n",
699 | " \n",
700 | " 4 | \n",
701 | " 904078b0695ffe0f | \n",
702 | " /m/0fp7c | \n",
703 | " 1 | \n",
704 | " https://c8.staticflickr.com/8/7173/6663562541_... | \n",
705 | " Hawk | \n",
706 | " [0.39844726358909366, -0.2868099119322688, -1.... | \n",
707 | "
\n",
708 | " \n",
709 | " ... | \n",
710 | " ... | \n",
711 | " ... | \n",
712 | " ... | \n",
713 | " ... | \n",
714 | " ... | \n",
715 | " ... | \n",
716 | "
\n",
717 | " \n",
718 | " 140 | \n",
719 | " 7ead8a05548b7dd8 | \n",
720 | " /m/0fp7c | \n",
721 | " 1 | \n",
722 | " https://c1.staticflickr.com/7/6127/6003146852_... | \n",
723 | " Hawk | \n",
724 | " [0.7204737072851969, -0.5799575879679919, -1.7... | \n",
725 | "
\n",
726 | " \n",
727 | " 141 | \n",
728 | " b6030979ffbf6c00 | \n",
729 | " /m/0fp7c | \n",
730 | " 0 | \n",
731 | " https://farm1.staticflickr.com/192/503607905_4... | \n",
732 | " Hawk | \n",
733 | " [0.7399902660512262, -0.34497502398512464, -0.... | \n",
734 | "
\n",
735 | " \n",
736 | " 142 | \n",
737 | " 6f66217f8ffdbeea | \n",
738 | " /m/0fp7c | \n",
739 | " 1 | \n",
740 | " https://c3.staticflickr.com/4/3723/13695947494... | \n",
741 | " Hawk | \n",
742 | " [0.5977249684386043, -0.5686997837439453, -0.6... | \n",
743 | "
\n",
744 | " \n",
745 | " 143 | \n",
746 | " b2692b73994377ce | \n",
747 | " /m/0fp7c | \n",
748 | " 0 | \n",
749 | " https://c7.staticflickr.com/9/8112/8539254331_... | \n",
750 | " Hawk | \n",
751 | " [0.6778727524730443, -0.8354259564874617, -0.7... | \n",
752 | "
\n",
753 | " \n",
754 | " 144 | \n",
755 | " d84b26c3865bec24 | \n",
756 | " /m/0fp7c | \n",
757 | " 1 | \n",
758 | " https://farm5.staticflickr.com/3893/1498221092... | \n",
759 | " Hawk | \n",
760 | " [0.5801936116195033, -0.9496025449658846, -0.8... | \n",
761 | "
\n",
762 | " \n",
763 | "
\n",
764 | "
145 rows × 6 columns
\n",
765 | "
"
766 | ],
767 | "text/plain": [
768 | " ImageID LabelName target_label \\\n",
769 | "0 e60d148d428e96cd /m/0fp7c 0 \n",
770 | "1 bc2260ec5b70b7e5 /m/0fp7c 0 \n",
771 | "2 be7aae0d065c11a0 /m/0fp7c 1 \n",
772 | "3 25880c4f38d01747 /m/0fp7c 1 \n",
773 | "4 904078b0695ffe0f /m/0fp7c 1 \n",
774 | ".. ... ... ... \n",
775 | "140 7ead8a05548b7dd8 /m/0fp7c 1 \n",
776 | "141 b6030979ffbf6c00 /m/0fp7c 0 \n",
777 | "142 6f66217f8ffdbeea /m/0fp7c 1 \n",
778 | "143 b2692b73994377ce /m/0fp7c 0 \n",
779 | "144 d84b26c3865bec24 /m/0fp7c 1 \n",
780 | "\n",
781 | " OriginalURL DisplayName \\\n",
782 | "0 https://c6.staticflickr.com/6/5560/14764818909... Hawk \n",
783 | "1 https://farm6.staticflickr.com/3657/3607619201... Hawk \n",
784 | "2 https://farm1.staticflickr.com/7253/6996634521... Hawk \n",
785 | "3 https://farm8.staticflickr.com/94/257466316_26... Hawk \n",
786 | "4 https://c8.staticflickr.com/8/7173/6663562541_... Hawk \n",
787 | ".. ... ... \n",
788 | "140 https://c1.staticflickr.com/7/6127/6003146852_... Hawk \n",
789 | "141 https://farm1.staticflickr.com/192/503607905_4... Hawk \n",
790 | "142 https://c3.staticflickr.com/4/3723/13695947494... Hawk \n",
791 | "143 https://c7.staticflickr.com/9/8112/8539254331_... Hawk \n",
792 | "144 https://farm5.staticflickr.com/3893/1498221092... Hawk \n",
793 | "\n",
794 | " Embedding \n",
795 | "0 [-0.0038521280919204775, 0.06254096050041212, ... \n",
796 | "1 [0.19757792792581547, -0.0628706741084504, -1.... \n",
797 | "2 [0.17347788607282416, -0.37831432327491554, -1... \n",
798 | "3 [0.38980357065885335, -0.9966469970143762, -1.... \n",
799 | "4 [0.39844726358909366, -0.2868099119322688, -1.... \n",
800 | ".. ... \n",
801 | "140 [0.7204737072851969, -0.5799575879679919, -1.7... \n",
802 | "141 [0.7399902660512262, -0.34497502398512464, -0.... \n",
803 | "142 [0.5977249684386043, -0.5686997837439453, -0.6... \n",
804 | "143 [0.6778727524730443, -0.8354259564874617, -0.7... \n",
805 | "144 [0.5801936116195033, -0.9496025449658846, -0.8... \n",
806 | "\n",
807 | "[145 rows x 6 columns]"
808 | ]
809 | },
810 | "execution_count": 300,
811 | "metadata": {},
812 | "output_type": "execute_result"
813 | }
814 | ],
815 | "source": [
816 | "df_2 = pd.read_parquet('/Users/rafael/Downloads/dataperf-vision-selection-resources/test_sets/test-hawk.parquet')\n",
817 | "df_2"
818 | ]
819 | },
820 | {
821 | "cell_type": "code",
822 | "execution_count": 307,
823 | "id": "22d9d865-98af-4c30-8384-917dea73d680",
824 | "metadata": {},
825 | "outputs": [
826 | {
827 | "data": {
828 | "text/plain": [
829 | "256"
830 | ]
831 | },
832 | "execution_count": 307,
833 | "metadata": {},
834 | "output_type": "execute_result"
835 | }
836 | ],
837 | "source": [
838 | "len(df_2['Embedding'][0])"
839 | ]
840 | },
841 | {
842 | "cell_type": "code",
843 | "execution_count": 235,
844 | "id": "565d7ba7-f8ac-4bbb-9ee6-ef02df8ac110",
845 | "metadata": {},
846 | "outputs": [],
847 | "source": [
848 | "df_sushi = pd.read_csv('/Users/rafael/Downloads/dataperf-vision-selection-resources/train_sets/random_500.csv')"
849 | ]
850 | },
851 | {
852 | "cell_type": "code",
853 | "execution_count": 260,
854 | "id": "4dc48edc-adf0-4956-9047-5017fc6f4b39",
855 | "metadata": {},
856 | "outputs": [],
857 | "source": [
858 | "sushi_dict = {'Sushi': {}}"
859 | ]
860 | },
861 | {
862 | "cell_type": "code",
863 | "execution_count": 264,
864 | "id": "bbf40598-c4d7-45fa-8636-9b4ca469b5e2",
865 | "metadata": {},
866 | "outputs": [],
867 | "source": [
868 | "labels = list(df_sushi['Confidence'].values)"
869 | ]
870 | },
871 | {
872 | "cell_type": "code",
873 | "execution_count": 265,
874 | "id": "c7954136-f869-441c-9767-f5b99c8dac9f",
875 | "metadata": {},
876 | "outputs": [],
877 | "source": [
878 | "ids = list(df_sushi['ImageID'].values)"
879 | ]
880 | },
881 | {
882 | "cell_type": "code",
883 | "execution_count": 274,
884 | "id": "c52e39cb-fbfd-4602-b045-a4cb717c026b",
885 | "metadata": {},
886 | "outputs": [],
887 | "source": [
888 | "intra_dict = {}\n",
889 | "for i in range(500):\n",
890 | " intra_dict[f'{ids[i]}'] = float(labels[i])\n",
891 | " "
892 | ]
893 | },
894 | {
895 | "cell_type": "code",
896 | "execution_count": 275,
897 | "id": "17c639f1-930c-4c73-8679-e46ab00553f9",
898 | "metadata": {},
899 | "outputs": [],
900 | "source": [
901 | "sushi_dict['Sushi'] = intra_dict"
902 | ]
903 | },
904 | {
905 | "cell_type": "code",
906 | "execution_count": 277,
907 | "id": "8523dd33-e4ef-446a-a429-c9674e0d980e",
908 | "metadata": {},
909 | "outputs": [],
910 | "source": [
911 | "with open(\"Sushi.json\", \"w\") as outfile:\n",
912 | " json.dump(sushi_dict, outfile)"
913 | ]
914 | },
915 | {
916 | "cell_type": "code",
917 | "execution_count": 316,
918 | "id": "26c87642-9d8f-456f-844f-d24bee7ff103",
919 | "metadata": {},
920 | "outputs": [],
921 | "source": [
922 | "embedding_dict = {}"
923 | ]
924 | },
925 | {
926 | "cell_type": "code",
927 | "execution_count": 394,
928 | "id": "5f8b25a9-32da-4eb1-9364-8f923797747c",
929 | "metadata": {},
930 | "outputs": [],
931 | "source": [
932 | "df_new = pd.read_parquet('/Users/rafael/Downloads/dataperf-vision-selection-resources/embeddings/train_emb_256_dataperf.parquet/part-00071-c487d15f-d29d-4c93-96c2-4caf4c12cd59-c000.snappy.parquet')"
933 | ]
934 | },
935 | {
936 | "cell_type": "code",
937 | "execution_count": 395,
938 | "id": "92678025-174a-40e8-b1aa-bf925a2173d4",
939 | "metadata": {},
940 | "outputs": [
941 | {
942 | "data": {
943 | "text/html": [
944 | "\n",
945 | "\n",
958 | "
\n",
959 | " \n",
960 | " \n",
961 | " | \n",
962 | " ImageID | \n",
963 | " embedding | \n",
964 | "
\n",
965 | " \n",
966 | " \n",
967 | " \n",
968 | " 0 | \n",
969 | " 000506542b3a9dd7 | \n",
970 | " [-1.2127153543616391, -0.10673216263881578, 0.... | \n",
971 | "
\n",
972 | " \n",
973 | " 1 | \n",
974 | " 00056c76aa3dd52a | \n",
975 | " [0.7178898916090762, 0.826355497831251, 0.7906... | \n",
976 | "
\n",
977 | " \n",
978 | " 2 | \n",
979 | " 000692d348b0f181 | \n",
980 | " [0.6396630836552186, 0.898348981106668, 0.4226... | \n",
981 | "
\n",
982 | " \n",
983 | " 3 | \n",
984 | " 000cabb2eb284c8c | \n",
985 | " [1.1945492605513477, 0.8496085515788221, -0.67... | \n",
986 | "
\n",
987 | " \n",
988 | " 4 | \n",
989 | " 0011cf9a929a4e19 | \n",
990 | " [0.5329256287734359, 1.8642507257449732, 0.614... | \n",
991 | "
\n",
992 | " \n",
993 | " ... | \n",
994 | " ... | \n",
995 | " ... | \n",
996 | "
\n",
997 | " \n",
998 | " 16466 | \n",
999 | " ffe1a2e32b45d1dc | \n",
1000 | " [-0.5439635132177799, -1.0043923233747394, 0.9... | \n",
1001 | "
\n",
1002 | " \n",
1003 | " 16467 | \n",
1004 | " ffe46a38144178dd | \n",
1005 | " [-0.4237458416684124, 0.010014072469498137, -0... | \n",
1006 | "
\n",
1007 | " \n",
1008 | " 16468 | \n",
1009 | " fff3c2eb1c9863c7 | \n",
1010 | " [-0.4050140839736052, 0.9564130760744758, 1.04... | \n",
1011 | "
\n",
1012 | " \n",
1013 | " 16469 | \n",
1014 | " fff41ccc90430cc0 | \n",
1015 | " [-0.4732436203454635, 0.21041616335678826, 0.9... | \n",
1016 | "
\n",
1017 | " \n",
1018 | " 16470 | \n",
1019 | " ffffdc1308c30d53 | \n",
1020 | " [-0.5498107897396933, 0.9266576856606651, 0.31... | \n",
1021 | "
\n",
1022 | " \n",
1023 | "
\n",
1024 | "
16471 rows × 2 columns
\n",
1025 | "
"
1026 | ],
1027 | "text/plain": [
1028 | " ImageID embedding\n",
1029 | "0 000506542b3a9dd7 [-1.2127153543616391, -0.10673216263881578, 0....\n",
1030 | "1 00056c76aa3dd52a [0.7178898916090762, 0.826355497831251, 0.7906...\n",
1031 | "2 000692d348b0f181 [0.6396630836552186, 0.898348981106668, 0.4226...\n",
1032 | "3 000cabb2eb284c8c [1.1945492605513477, 0.8496085515788221, -0.67...\n",
1033 | "4 0011cf9a929a4e19 [0.5329256287734359, 1.8642507257449732, 0.614...\n",
1034 | "... ... ...\n",
1035 | "16466 ffe1a2e32b45d1dc [-0.5439635132177799, -1.0043923233747394, 0.9...\n",
1036 | "16467 ffe46a38144178dd [-0.4237458416684124, 0.010014072469498137, -0...\n",
1037 | "16468 fff3c2eb1c9863c7 [-0.4050140839736052, 0.9564130760744758, 1.04...\n",
1038 | "16469 fff41ccc90430cc0 [-0.4732436203454635, 0.21041616335678826, 0.9...\n",
1039 | "16470 ffffdc1308c30d53 [-0.5498107897396933, 0.9266576856606651, 0.31...\n",
1040 | "\n",
1041 | "[16471 rows x 2 columns]"
1042 | ]
1043 | },
1044 | "execution_count": 395,
1045 | "metadata": {},
1046 | "output_type": "execute_result"
1047 | }
1048 | ],
1049 | "source": [
1050 | "df_new"
1051 | ]
1052 | },
1053 | {
1054 | "cell_type": "code",
1055 | "execution_count": 317,
1056 | "id": "69da299b-8829-49ec-a421-5e642cdce133",
1057 | "metadata": {},
1058 | "outputs": [],
1059 | "source": [
1060 | "embedding_list = (df_2['Embedding'].to_list())"
1061 | ]
1062 | },
1063 | {
1064 | "cell_type": "code",
1065 | "execution_count": 319,
1066 | "id": "6d8d94c9-c246-472e-8c54-1c0d168d652a",
1067 | "metadata": {},
1068 | "outputs": [],
1069 | "source": [
1070 | "id_names = (df_2['ImageID'].to_list())"
1071 | ]
1072 | },
1073 | {
1074 | "cell_type": "code",
1075 | "execution_count": 326,
1076 | "id": "3cab8316-bf74-40ae-a7d8-d2826ca70160",
1077 | "metadata": {},
1078 | "outputs": [],
1079 | "source": [
1080 | "for i in range(len(embedding_list)):\n",
1081 | " embedding_dict[f'train{id_names[i]}.npy'] = embedding_list[i]"
1082 | ]
1083 | },
1084 | {
1085 | "cell_type": "code",
1086 | "execution_count": 330,
1087 | "id": "361e7bb5-6a3c-479e-9ab8-9c5f014f569a",
1088 | "metadata": {},
1089 | "outputs": [],
1090 | "source": [
1091 | "import boto3"
1092 | ]
1093 | },
1094 | {
1095 | "cell_type": "code",
1096 | "execution_count": 361,
1097 | "id": "9150bbd6-d0b6-4530-ba00-1153da0be83a",
1098 | "metadata": {},
1099 | "outputs": [],
1100 | "source": [
1101 | "new_emb['train37510ffe72539c23.npy'] = embedding_dict['train37510ffe72539c23.npy']"
1102 | ]
1103 | },
1104 | {
1105 | "cell_type": "code",
1106 | "execution_count": 376,
1107 | "id": "57d0a835-82df-4583-8d9b-63a115551e19",
1108 | "metadata": {},
1109 | "outputs": [
1110 | {
1111 | "data": {
1112 | "text/plain": [
1113 | "numpy.ndarray"
1114 | ]
1115 | },
1116 | "execution_count": 376,
1117 | "metadata": {},
1118 | "output_type": "execute_result"
1119 | }
1120 | ],
1121 | "source": [
1122 | "type(new_emb['train37510ffe72539c23.npy'])"
1123 | ]
1124 | },
1125 | {
1126 | "cell_type": "code",
1127 | "execution_count": null,
1128 | "id": "0b72dd39-41b8-4d5c-8a3a-a8f2b4da48e2",
1129 | "metadata": {},
1130 | "outputs": [],
1131 | "source": [
1132 | "ftmp = tempfile.NamedTemporaryFile(delete=False)\n",
1133 | "fname = ftmp.name + \".npy\"\n",
1134 | "inp = np.random.rand(100)\n",
1135 | "np.save(fname, inp, allow_pickle=False)\n",
1136 | "out = np.load(fname, allow_pickle=False)"
1137 | ]
1138 | },
1139 | {
1140 | "cell_type": "code",
1141 | "execution_count": 427,
1142 | "id": "58916299-b236-49a5-9168-d7d86f12da55",
1143 | "metadata": {},
1144 | "outputs": [
1145 | {
1146 | "name": "stdout",
1147 | "output_type": "stream",
1148 | "text": [
1149 | "[ 3.89803571e-01 -9.96646997e-01 -1.81296639e+00 -2.11422996e-01\n",
1150 | " 1.11237384e+00 5.50917144e-01 -2.23019595e-01 1.21128170e+00\n",
1151 | " -6.06581952e-01 -1.22170200e-01 1.91518837e-02 2.59825023e+00\n",
1152 | " -1.46337609e-01 1.14671604e+00 -1.04500072e+00 -1.09172317e+00\n",
1153 | " 1.28390928e+00 -6.02485930e-01 -1.38633525e+00 -5.86013834e-01\n",
1154 | " 1.40964381e+00 2.83743064e+00 1.24784790e+00 -9.70432275e-01\n",
1155 | " -7.10446552e-02 -1.33872242e+00 4.68498640e-02 -2.96933628e-01\n",
1156 | " 5.03643044e-01 1.77543576e+00 -7.49830448e-01 -2.92437834e-01\n",
1157 | " -4.35559568e-02 -6.97971199e-01 -1.54200581e+00 1.00275982e+00\n",
1158 | " -1.58844006e-01 -1.10935264e+00 -7.16555608e-01 -3.23272959e-01\n",
1159 | " -1.20135634e+00 4.06764028e-01 -7.16184317e-01 2.85940422e-01\n",
1160 | " 1.42247711e-01 4.45880963e-01 2.65902923e+00 6.63447124e-01\n",
1161 | " 2.96592291e-02 1.96929396e-01 1.19715826e+00 -1.85437922e+00\n",
1162 | " -3.22858810e+00 -5.52233908e-01 -2.28968643e+00 1.89881467e+00\n",
1163 | " -2.69826089e-01 -1.00715497e+00 -4.31742139e-01 8.82787116e-01\n",
1164 | " -2.07622814e-01 2.40897019e+00 -2.04291252e+00 3.29187109e-01\n",
1165 | " 3.30926936e-01 5.40466498e-01 5.21807885e-01 -1.01563803e+00\n",
1166 | " 9.14845711e-01 -4.10491149e-01 6.25429816e-01 1.61488801e+00\n",
1167 | " 2.50226381e+00 6.01955673e-02 -7.91900018e-01 -1.76873686e-01\n",
1168 | " -2.28771317e+00 6.39635918e-01 1.16536545e+00 2.11140401e+00\n",
1169 | " 8.89996511e-02 1.33888259e-01 -9.99618052e-01 3.47843187e+00\n",
1170 | " 8.78792862e-01 1.11952807e-01 1.72209416e+00 6.02893362e-01\n",
1171 | " -8.45727799e-02 3.66698320e-01 7.13103677e-01 2.46110374e+00\n",
1172 | " 7.51076212e-01 2.95859024e+00 8.80468588e-01 -2.62122633e-01\n",
1173 | " -1.19645276e+00 6.12483290e-01 -2.70173794e+00 4.63667627e-01\n",
1174 | " -1.07226326e+00 1.20642437e-02 2.24335401e+00 2.29066151e+00\n",
1175 | " -2.63806008e-01 -3.71698775e-01 9.11383716e-01 -5.66627229e-02\n",
1176 | " -8.54939232e-01 -2.07023799e-01 1.79320635e+00 -8.95001470e-01\n",
1177 | " -1.75708440e+00 6.49242586e-01 -5.17078360e-01 4.03644993e-01\n",
1178 | " -5.89727554e-01 -1.64884495e+00 2.47910865e+00 9.47566147e-01\n",
1179 | " -2.24377446e-01 -2.23720885e-03 2.60619694e+00 1.59679699e-01\n",
1180 | " -3.93530441e-01 -1.51868802e+00 3.48610412e-01 8.99090255e-02\n",
1181 | " 1.15730898e+00 1.10894727e-01 -4.43198262e-02 -4.42949224e-01\n",
1182 | " -7.60400772e-01 -4.39890787e-01 2.51312671e-01 -8.79998981e-01\n",
1183 | " 1.20019254e+00 1.53487517e+00 1.66548900e-02 1.77565735e+00\n",
1184 | " 2.38695643e+00 -2.68797990e-01 1.22994696e+00 1.05242160e-01\n",
1185 | " -2.24284337e-01 -1.74365392e-01 -1.29437818e+00 1.29453605e+00\n",
1186 | " 1.01599930e+00 1.82553687e+00 -9.77754401e-01 7.55106541e-01\n",
1187 | " 4.22414110e-01 6.97701425e-01 2.74640753e+00 2.72244357e+00\n",
1188 | " 7.00360122e-01 -6.34062527e-01 1.53573036e+00 1.03219454e+00\n",
1189 | " 1.37059934e-01 -3.00395361e-01 9.46676673e-01 1.42820776e+00\n",
1190 | " -1.83891427e+00 1.60916704e-01 6.71404128e-01 -7.54523616e-02\n",
1191 | " 5.98854364e-01 -7.71613639e-01 -6.55620813e-01 -1.41284540e+00\n",
1192 | " -8.83293019e-02 3.67647495e+00 6.14790081e-01 -3.52725149e-01\n",
1193 | " 2.49560669e+00 3.17937151e+00 1.41134017e-01 4.70473738e-01\n",
1194 | " 1.56911217e-01 -8.69234430e-01 5.51784682e-01 -1.12584407e+00\n",
1195 | " -2.36649778e+00 -4.99193641e-01 -1.07530067e+00 -3.71462427e-01\n",
1196 | " 4.84255341e-01 -1.89550505e+00 1.51478485e+00 1.84459741e+00\n",
1197 | " 2.19528381e-01 -2.81967307e-01 -3.24608813e-01 1.88795676e+00\n",
1198 | " 1.27942137e+00 -1.37998819e-01 -2.05380238e-01 4.26948425e-01\n",
1199 | " -3.89293401e-01 4.01363498e-01 1.98159891e+00 -1.57827198e-01\n",
1200 | " -5.85166423e-01 -7.13019953e-01 -6.38005962e-01 1.43934620e+00\n",
1201 | " 7.15919559e-01 -5.31886315e-01 3.68193605e-01 -1.90642042e+00\n",
1202 | " 8.62036109e-01 -6.52963357e-01 1.07742582e-01 2.81807473e+00\n",
1203 | " 1.23933507e+00 2.42569902e+00 2.17956315e+00 -9.52152608e-02\n",
1204 | " 9.07720855e-01 1.01096203e+00 -3.24792922e-01 2.39230798e+00\n",
1205 | " -6.16348712e-01 2.36699218e+00 2.11694442e+00 5.38469330e-01\n",
1206 | " -3.01368494e-01 3.72789505e-01 1.87059031e+00 2.23637478e-03\n",
1207 | " 5.00828638e-01 6.83015162e-01 -2.53589557e-01 1.37718434e+00\n",
1208 | " 1.67791787e+00 4.18906651e-01 -1.38630361e+00 -8.20388707e-01\n",
1209 | " -1.77622874e+00 -2.51923289e+00 1.99978151e+00 -5.16371872e-01\n",
1210 | " 1.42943314e-01 -8.87400970e-01 8.69817718e-02 -1.54702236e-02\n",
1211 | " -2.41089198e-01 -1.48360916e+00 1.91162173e+00 1.27717093e+00\n",
1212 | " 2.13009238e-01 -5.98617984e-01 -6.32361384e-01 1.94918398e+00]\n",
1213 | "[ 1.73477886e-01 -3.78314323e-01 -1.15957538e+00 -4.55457117e-01\n",
1214 | " -1.24829119e-01 -6.82458939e-01 2.47177816e-01 6.75238190e-01\n",
1215 | " 6.32115197e-01 -6.86811856e-01 1.31945937e-01 2.90450028e+00\n",
1216 | " -7.81145498e-02 7.61810122e-01 3.53271576e-01 -1.93634324e+00\n",
1217 | " 4.34813137e-01 7.93498748e-01 -9.99374515e-01 3.57337146e-01\n",
1218 | " 7.67517595e-01 3.23606027e+00 8.09872843e-01 -1.87386090e+00\n",
1219 | " 4.02630880e-01 -1.97024956e+00 -2.02397129e+00 6.84022534e-01\n",
1220 | " -6.97145592e-01 2.57936995e+00 -1.79964854e+00 -8.49397060e-01\n",
1221 | " 6.13147645e-01 8.10354836e-01 -3.39851667e+00 1.22421018e+00\n",
1222 | " -8.60824905e-01 -1.60722203e+00 3.59145122e-01 2.01454811e+00\n",
1223 | " -1.61948743e-01 -1.00224593e+00 5.73652303e-01 6.30755277e-01\n",
1224 | " 4.14952294e-01 2.16821687e+00 2.53577484e+00 1.82889149e+00\n",
1225 | " 1.75935907e-01 -6.57751382e-02 1.10174426e+00 -2.06796838e+00\n",
1226 | " -2.94581082e+00 -1.28551166e+00 -1.51854207e+00 1.67130640e+00\n",
1227 | " -1.03049907e+00 -5.07271596e-01 -6.97514998e-01 -1.84129927e-01\n",
1228 | " -7.49230290e-01 2.12516077e+00 -5.71402462e-01 7.41986199e-01\n",
1229 | " -2.57257231e-01 9.47876568e-01 2.24098504e+00 -5.17886593e-01\n",
1230 | " 2.59016190e+00 2.18350262e+00 -2.46278041e+00 2.97659203e+00\n",
1231 | " 1.79582747e-01 2.31392899e+00 7.35421377e-01 -1.36229226e-01\n",
1232 | " -1.11759859e+00 4.11274083e-01 2.02971659e+00 2.34953279e+00\n",
1233 | " 1.39825615e-01 1.29747795e+00 -1.92871190e+00 7.62746025e-01\n",
1234 | " 1.78325199e+00 -7.74352229e-01 1.96503445e+00 -8.22589279e-01\n",
1235 | " 4.76711337e-01 1.97345695e-01 2.41591692e-04 3.03953992e+00\n",
1236 | " -3.64576371e-01 -5.85988048e-01 2.30729064e-02 2.40199793e+00\n",
1237 | " 4.04458971e-01 4.62537789e-01 -9.44197385e-01 -3.94466042e-01\n",
1238 | " -7.71524621e-01 -1.44357413e+00 3.80624581e-01 1.02348017e+00\n",
1239 | " -2.51731718e-01 -1.82386160e+00 -1.98006934e-01 1.71773232e+00\n",
1240 | " 3.39659050e-01 -1.61678386e-01 8.64877151e-01 5.27221507e-01\n",
1241 | " -1.33709202e+00 -5.17925673e-01 5.45324675e-01 -3.39876880e-01\n",
1242 | " 6.28603177e-01 1.05901119e+00 3.21904938e+00 5.65501247e-01\n",
1243 | " -1.36764496e+00 -1.05371234e+00 1.36910906e+00 2.13052453e+00\n",
1244 | " -1.53023197e+00 -7.13439829e-02 1.77023797e+00 3.76738890e-01\n",
1245 | " 5.80164396e-01 1.32691407e+00 6.60211046e-01 -5.98368078e-01\n",
1246 | " 6.98442518e-01 -2.54519675e+00 1.57567109e+00 4.16967821e-01\n",
1247 | " 4.76946804e-01 -9.59922631e-01 9.60794297e-02 1.55166016e-01\n",
1248 | " 1.81177522e+00 -1.52597836e+00 2.82341435e+00 1.21299257e-01\n",
1249 | " -8.44809950e-01 7.03000928e-02 1.53342508e-01 3.38649085e+00\n",
1250 | " 5.65857183e-01 3.49518456e+00 -1.98264737e+00 8.48937940e-01\n",
1251 | " -4.80279170e-01 8.66791453e-01 7.33455740e-01 3.36394531e+00\n",
1252 | " -7.11448994e-02 -1.67163688e+00 1.89634404e+00 5.24291752e-01\n",
1253 | " -2.32818349e+00 -9.27940928e-01 -2.84750277e-01 2.78600716e+00\n",
1254 | " -3.27495724e+00 1.23622920e+00 -4.39654288e-01 1.33676926e+00\n",
1255 | " -3.71218248e-02 -4.67805272e-01 -1.19710997e+00 -1.01201737e+00\n",
1256 | " 4.00629907e-01 2.37262771e+00 1.93417388e+00 -3.06882494e-01\n",
1257 | " 1.60856514e+00 2.23441980e+00 1.54979445e+00 1.71643700e+00\n",
1258 | " 1.36638736e-01 1.06117105e+00 -1.02864472e+00 -2.34444189e+00\n",
1259 | " -2.94609160e+00 -1.58258268e+00 2.89173790e+00 1.16163355e+00\n",
1260 | " 2.59235939e+00 -2.02021420e+00 -7.73131893e-01 1.29344678e+00\n",
1261 | " 3.99642409e-01 2.74359725e-01 -4.29611048e+00 3.21451459e+00\n",
1262 | " 8.61943075e-01 -8.29561193e-02 3.62781221e-01 -1.28975848e+00\n",
1263 | " 1.45744159e+00 -5.07338791e-02 2.80161335e-01 -1.88915363e+00\n",
1264 | " -5.48913797e-02 1.59280869e+00 -2.64753155e+00 3.05718540e-02\n",
1265 | " 8.57470089e-01 3.38945004e-01 -4.86244452e-01 3.47816665e-01\n",
1266 | " -6.40024050e-02 -1.89914441e-01 -1.09029680e+00 1.28981721e+00\n",
1267 | " 1.57027400e+00 1.02784636e+00 3.18322695e+00 -6.64335702e-01\n",
1268 | " 1.79662418e+00 3.36934653e-01 7.86187765e-01 3.36239847e+00\n",
1269 | " -6.71641773e-01 4.16299206e-01 1.67452233e+00 -2.92389720e-01\n",
1270 | " -9.64364373e-01 3.33029921e-01 6.98441369e-01 3.99958218e-01\n",
1271 | " 9.19708275e-02 -8.20748303e-01 3.45360249e-01 1.13682329e+00\n",
1272 | " 6.91683315e-01 -7.31566063e-01 8.30114479e-01 -2.31014512e+00\n",
1273 | " -3.03351615e+00 -7.30246093e-01 -2.97940655e+00 2.73203994e-01\n",
1274 | " -4.12306385e-01 -5.42596202e-01 -6.75748803e-01 -2.75624418e-02\n",
1275 | " 1.45179840e+00 -6.41767557e-01 4.80260974e-01 1.02999982e+00\n",
1276 | " 1.62948661e+00 -7.74186094e-01 -6.38427721e-01 1.36770599e+00]\n",
1277 | "[ 0.74329027 -0.3775754 -1.55002959 -0.08654639 -0.84612519 -0.66517953\n",
1278 | " -0.02449258 0.36550824 -0.09603422 -0.42939465 0.05933956 3.72453531\n",
1279 | " -0.60181018 -0.18968162 -0.33691938 -1.21067376 1.69810332 -0.68387043\n",
1280 | " -2.19420451 0.30623149 0.04384658 3.88534118 2.54776959 -0.48658014\n",
1281 | " 0.48451308 -0.31657583 -2.93257781 -0.09047868 0.27749259 2.14582287\n",
1282 | " -0.7742917 0.61943849 1.17077945 0.65668756 -1.49204705 0.82001543\n",
1283 | " -1.14004735 -1.53238315 -0.9232495 1.49888674 -1.35190679 -1.06768895\n",
1284 | " -0.98028353 0.26184235 0.33910083 -0.6371447 0.83411936 0.6429185\n",
1285 | " 0.25254067 1.61010048 1.70969211 -2.14806023 -1.63120278 0.22985059\n",
1286 | " -1.83635132 0.81144822 -0.85230331 -0.01214619 -1.0530559 0.78959619\n",
1287 | " -1.29366791 2.32742711 -0.47322131 0.18036712 0.2442932 0.57961274\n",
1288 | " 1.83769874 -1.92854169 -0.32842854 -0.12183421 1.06949212 -1.1666125\n",
1289 | " -0.18457919 1.42867819 0.22614148 1.24973795 -0.27109001 -1.30841724\n",
1290 | " -0.1323186 1.50524927 -0.06151336 0.36280611 -0.33027768 -0.68321206\n",
1291 | " 1.62078746 -1.69239535 0.69330277 2.77379231 -0.89682064 -0.79466231\n",
1292 | " 0.19352146 1.983096 0.32072105 1.41947621 0.45555747 0.90832593\n",
1293 | " -2.51358088 -0.24226676 -1.49537912 -0.2563347 -1.98138404 -0.37085698\n",
1294 | " 1.62641746 1.06244157 -0.82419462 -0.57390422 0.9138385 -1.00704416\n",
1295 | " -0.23448064 -1.37492495 -0.25682802 0.11005032 0.23172159 -1.56330803\n",
1296 | " 0.91391364 -0.34495631 0.50074237 -0.4066586 -0.00549562 3.24345052\n",
1297 | " 0.75700746 -1.69821946 1.77941427 -0.28645015 -0.57915812 1.20586644\n",
1298 | " 0.36845146 0.58970282 0.22338647 0.20552488 0.51302908 1.04770178\n",
1299 | " -0.68046897 0.52678917 -0.6986149 -1.40170292 0.61719286 -0.47468235\n",
1300 | " -0.56355147 -1.18205737 -0.18334463 1.71284032 2.02351803 0.82235456\n",
1301 | " -0.17395866 0.0966321 -0.06158389 0.85559931 -0.34557713 1.28382154\n",
1302 | " 1.02731127 -0.92987219 -0.21203948 0.63788485 0.84467619 1.97563281\n",
1303 | " -0.50550261 -1.95678332 1.12111351 -1.63963881 1.04433216 -1.59185571\n",
1304 | " 1.44506895 2.29460523 1.94414325 -0.33587755 -0.36165364 0.26807456\n",
1305 | " 0.69051003 0.19675298 1.47340625 0.2810714 0.32018736 0.70501098\n",
1306 | " 0.5535334 2.53173343 0.11880176 2.23190851 0.55329793 1.51165165\n",
1307 | " 2.05233722 0.0862875 -0.68701175 -1.41732675 -1.63925079 -1.93015384\n",
1308 | " -0.08638934 2.0572556 -0.34858196 0.83103614 -1.45979835 -1.14297843\n",
1309 | " 2.16423315 -0.06061224 0.13932088 0.5792231 1.30258981 -0.31806488\n",
1310 | " -0.71589099 -0.29994465 -1.2526677 -1.78460434 1.16588397 -0.328064\n",
1311 | " -0.74002878 0.30247325 0.3442312 -0.28380328 -0.62916205 -0.75991837\n",
1312 | " 1.97814817 0.51295979 0.36566438 -0.25536378 -0.27424215 0.04552544\n",
1313 | " 0.17500421 1.00346983 1.35499899 -0.38314174 -0.50808836 0.24078918\n",
1314 | " 0.60978008 1.91112286 0.82011784 -0.2615277 -0.26694167 -0.67379871\n",
1315 | " -2.74933261 1.37684731 1.99775643 0.10294325 0.5447847 -2.57597607\n",
1316 | " 2.63937386 0.68193744 0.08253477 0.39086499 -0.66228937 -2.01208531\n",
1317 | " -1.61198843 -0.99157155 0.74521818 -0.22949644 -0.66489836 1.76523025\n",
1318 | " 0.11179357 1.58020681 0.23154623 -0.7978364 0.02799542 0.51259032\n",
1319 | " 0.64821907 -1.05749894 -0.83042711 -1.11948847]\n"
1320 | ]
1321 | }
1322 | ],
1323 | "source": [
1324 | "s3_client = session.client('s3')\n",
1325 | "for i in new_emb:\n",
1326 | " with tempfile.NamedTemporaryFile() as tf:\n",
1327 | " fname = tf.name + \".npy\"\n",
1328 | " np.save(fname, new_emb[i], allow_pickle=False)\n",
1329 | " print(np.load(fname, allow_pickle = False))\n",
1330 | " #s3_client.upload_file(Filename = tf.name, Bucket= 'vision-dataperf', Key = f'public_train_dataset_256_embeddings_dynabench_formatted/{i}')"
1331 | ]
1332 | }
1333 | ],
1334 | "metadata": {
1335 | "kernelspec": {
1336 | "display_name": "Python 3",
1337 | "language": "python",
1338 | "name": "python3"
1339 | },
1340 | "language_info": {
1341 | "codemirror_mode": {
1342 | "name": "ipython",
1343 | "version": 3
1344 | },
1345 | "file_extension": ".py",
1346 | "mimetype": "text/x-python",
1347 | "name": "python",
1348 | "nbconvert_exporter": "python",
1349 | "pygments_lexer": "ipython3",
1350 | "version": "3.9.5"
1351 | }
1352 | },
1353 | "nbformat": 4,
1354 | "nbformat_minor": 5
1355 | }
1356 |
--------------------------------------------------------------------------------
/baselines/pseudo_label_generation/dataperf_vision_experiments.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from sklearn.neural_network import MLPClassifier
4 |
5 | import sys
6 | sys.path.append("dataperf-vision-selection/")
7 | from main import run_tasks
8 | from utils import SubmissionCreator, EnsembleOOD
9 |
10 |
# Load the embeddings of every candidate training image.
data = pd.read_parquet('./dataperf-vision-selection/data/embeddings/train_emb_256_dataperf.parquet', engine='pyarrow')
data_np = np.vstack(data['embedding'].values)
data_ids = data["ImageID"]

base_folder = './dataperf-vision-selection/data/'

# File-name stem per class; the tuple position is the ``target_class`` passed to
# SubmissionCreator.write (0 = Cupcake, 1 = Hawk, 2 = Sushi).
CLASS_STEMS = ("cupcake", "hawk", "sushi")
# Number of images sampled per pseudo-labelled class in experiments 2 and 3.
N_PER_CLASS = 333
# Added to the ensemble disagreement so images on which every ensemble member agrees
# (std == 0) keep a tiny, non-zero selection probability. Without it
# np.random.choice(replace=False, p=...) raises when a class has fewer than
# N_PER_CLASS non-zero weights, and normalising divides 0/0 when all are zero.
UNCERTAINTY_FLOOR = 1e-12


def _example_rows(csv_path):
    """Return (ImageIDs listed in ``csv_path``, embedding rows for those IDs)."""
    example_ids = pd.read_csv(csv_path)['ImageID'].values
    # Rows come back in embedding-file order, not CSV order. IDs absent from the
    # embeddings are dropped, so there may be fewer rows than example IDs.
    rows = data_np[data_ids.isin(example_ids).to_numpy()]
    print(rows.shape)
    return example_ids, rows


def _write_all(sc, suffix):
    """Write the three one-vs-rest CSVs and their ``<class>_<suffix>.json`` files."""
    for target_class, stem in enumerate(CLASS_STEMS):
        sc.write(f"{base_folder}train_sets/{stem}.csv", target_class,
                 f"{base_folder}train_sets/{stem}_{suffix}.json")


def _sample_ids(pseudo_labels, label, weights=None, size=N_PER_CLASS):
    """Draw up to ``size`` distinct ImageIDs among images pseudo-labelled ``label``.

    With ``weights`` each candidate is drawn with probability proportional to its
    weight, otherwise uniformly. If fewer than ``size`` candidates exist, all of
    them are returned instead of raising.
    """
    mask = np.asarray(pseudo_labels == label)
    pool = data_ids[mask]
    if len(pool) < size:
        print(f"Only {len(pool)} images pseudo-labelled {label}; taking all of them.")
        size = len(pool)
    p = None
    if weights is not None and size > 0:
        class_weights = weights[mask]
        p = class_weights / class_weights.sum()
    return np.random.choice(pool, size, replace=False, p=p)


example_cupcake_ids, cupcake_np = _example_rows(base_folder + 'examples/alpha_example_set_Cupcake.csv')
example_hawk_ids, hawk_np = _example_rows(base_folder + 'examples/alpha_example_set_Hawk.csv')
example_sushi_ids, sushi_np = _example_rows(base_folder + 'examples/alpha_example_set_Sushi.csv')

# Experiment 1: baseline, the provided examples
sc = SubmissionCreator(example_cupcake_ids, example_hawk_ids, example_sushi_ids)
_write_all(sc, "baseline")

run_tasks("task_setup_new.yaml")

# NOTE(review): experiments 2 and 3 overwrite train_sets/<class>.csv but do not call
# run_tasks, so experiment 2's CSVs are replaced before being evaluated here.
# Confirm whether evaluation is meant to run separately.

# Experiment 2: random selection of N_PER_CLASS images per class with pseudo-labels
# from an sklearn classifier trained on the example embeddings.
class_rows = (cupcake_np, hawk_np, sushi_np)
data_train = np.vstack(class_rows)
# One label per example row actually found in the embeddings (class index as float).
labels_train = np.hstack([np.full(len(rows), label, dtype=float)
                          for label, rows in enumerate(class_rows)])

# random_state is deliberately left unset: EnsembleOOD clones this estimator, and a
# fixed integer seed would make every clone identical (zero disagreement).
mlp = MLPClassifier(25)
mlp.fit(data_train, labels_train)
print("MLP training accuracy:", mlp.score(data_train, labels_train))

pseudo_labels = mlp.predict(data_np)

sc = SubmissionCreator(*(_sample_ids(pseudo_labels, label)
                         for label in range(len(CLASS_STEMS))))
_write_all(sc, "nn")

# Experiment 3: select samples with probability proportional to ensemble disagreement
ood = EnsembleOOD(mlp, data_train, labels_train, n_models=50)
pseudo_labels = ood.predict(data_np)
uncertainty = ood.predict_value(data_np, None) + UNCERTAINTY_FLOOR

sc = SubmissionCreator(*(_sample_ids(pseudo_labels, label, weights=uncertainty)
                         for label in range(len(CLASS_STEMS))))
_write_all(sc, "ood")
92 |
--------------------------------------------------------------------------------
/baselines/pseudo_label_generation/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import pandas as pd
4 | import sklearn
5 |
6 |
def make_json(train_file, file_name, data_class_name):
    """Convert a submission CSV into ``{data_class_name: {ImageID: Confidence}}`` JSON.

    Args:
        train_file: path of a CSV with ``ImageID`` and ``Confidence`` columns.
        file_name: path of the JSON file to create or overwrite.
        data_class_name: top-level key of the JSON object (e.g. ``"Cupcake"``).
    """
    # Read IDs as strings: an all-digit hex ID such as "0012..." would otherwise be
    # parsed as an integer, losing leading zeros (or breaking json with np.int64 keys).
    df = pd.read_csv(train_file, dtype={"ImageID": str})
    # .tolist() yields plain Python scalars, which json can serialise; for duplicate
    # IDs the last row wins, exactly as repeated dict.update calls did.
    mapping = dict(zip(df["ImageID"].tolist(), df["Confidence"].tolist()))
    with open(file_name, "w") as json_file:
        json.dump({data_class_name: mapping}, json_file)
20 |
21 |
class SubmissionCreator:
    """Build one-vs-rest training CSVs from three per-class collections of ImageIDs.

    Rows are always written in the order cupcake, hawk, sushi; ``target_class``
    (an index into ``self.names``) selects which group is labelled 1.
    """

    def __init__(self, cupcake_ids, hawk_ids, sushi_ids):
        """Store the ImageIDs selected for each class (any 1-D array-like)."""
        self.cupcake = cupcake_ids
        self.hawk = hawk_ids
        self.sushi = sushi_ids

        self.names = ["Cupcake", "Hawk", "Sushi"]

    def write(self, file_name, target_class, file_name_json=""):
        """Write a one-vs-rest training CSV and, optionally, its JSON counterpart.

        Every stored ID is written to ``file_name`` with columns ``ImageID`` and
        ``Confidence``; ``Confidence`` is 1 for the group chosen by
        ``target_class`` and 0 for the other two.

        Args:
            file_name: destination CSV path (overwritten).
            target_class: 0 (Cupcake), 1 (Hawk) or 2 (Sushi).
            file_name_json: if non-empty, also convert the CSV with ``make_json``,
                keyed by the class display name.

        Raises:
            ValueError: if ``target_class`` is not 0, 1 or 2.
        """
        groups = (self.cupcake, self.hawk, self.sushi)
        if target_class not in range(len(groups)):
            raise ValueError(f"target_class must be 0, 1 or 2, got {target_class!r}")

        data = np.hstack(groups)
        # Build labels group by group so an empty group gets no rows; slicing with
        # labels[-0:] would otherwise mark *every* row positive.
        labels = np.hstack([np.full(len(ids), int(index == target_class))
                            for index, ids in enumerate(groups)])

        df = pd.DataFrame(data={"ImageID": data, "Confidence": labels.astype(int)})
        df.to_csv(file_name, index=False)

        if file_name_json != "":
            make_json(file_name, file_name_json, self.names[target_class])
51 |
52 |
class EnsembleOOD:
    """Ensemble of independently fitted clones of one scikit-learn estimator.

    ``predict`` combines the members by majority vote; ``predict_value`` returns
    the spread of their predictions, used as an uncertainty score.
    """

    def __init__(self, model, X, y, n_models=10):  # noqa
        """Fit ``n_models`` fresh clones of ``model`` on ``(X, y)``.

        ``sklearn.base.clone`` copies hyper-parameters but not fitted state, so the
        members can only differ through the estimator's own randomness (e.g. an
        unset ``random_state``). With a fixed integer ``random_state`` all members
        are identical and ``predict_value`` is zero everywhere.
        """
        self.model = model
        self.num_models = n_models

        self.models = [sklearn.base.clone(self.model).fit(X, y) for _ in range(self.num_models)]

    def predict_value(self, X, _y):  # noqa
        """Return the per-sample standard deviation of the members' predicted labels.

        0 means every member predicted the same label. ``_y`` is unused.
        NOTE(review): the value depends on the numeric encoding of the labels.
        """
        predictions = np.array([model.predict(X) for model in self.models])
        std = np.std(predictions, axis=0)
        return std

    def predict(self, X):  # noqa
        """Compute predictions with majority voting for the ensemble.

        Returns, for every sample, the label predicted by the most members; ties
        go to the smallest label. The result has the dtype of the members'
        predictions.
        """
        predictions = np.array([model.predict(X) for model in self.models])
        if predictions.shape[-1] == 0:
            return np.empty(0, dtype=predictions.dtype)
        labels, codes = np.unique(predictions, return_inverse=True)
        # Reshape explicitly: the inverse's shape for N-D input differs across NumPy versions.
        codes = codes.reshape(predictions.shape)
        n_samples = predictions.shape[1]
        votes = np.zeros((labels.size, n_samples), dtype=int)
        # votes[label_code, sample] += 1 for every (member, sample) prediction.
        np.add.at(votes, (codes, np.arange(n_samples)), 1)
        return labels[votes.argmax(axis=0)]
72 |
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
"""Shared constants: task-setup YAML keys, data column names and defaults."""

# ---- Task setup file ------------------------------------------------------
DEFAULT_SETUP_YAML_PATH = "task_setup.yaml"        # default setup file path
SETUP_YAML_LOCAL_DATA_DIR_KEY = "data_dir"          # key: local data directory
SETUP_YAML_DOCKER_DATA_DIR_KEY = "docker_data_dir"  # key: data directory in Docker
SETUP_YAML_DIM_KEY = "dim"                          # key: embedding dimension (presumably)
SETUP_YAML_TASKS_KEY = "eval_tasks"                 # key: evaluation tasks
SETUP_YAML_EMB_KEY = "emb_file"                     # key: embedding file
SETUP_YAML_SPARK_MEM_KEY = "spark_driver_memory"    # key: Spark driver memory
SETUP_YAML_RESULTS_KEY = "results_dir"              # key: results directory

# ---- Column names in train, test and embedding files ----------------------
EMB_COL = "embedding"
ID_COL = "ImageID"
LABEL_COL = "Confidence"

# ---- Results ----------------------------------------------------------------
RESULT_FILE_PREFIX = "result"  # result files are named with this prefix

# ---- General ----------------------------------------------------------------
RANDOM_SEED = 0  # NOTE(review): consumers not visible here; confirm what it seeds
21 |
--------------------------------------------------------------------------------
/data/embeddings/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | */
3 | !.gitignore
--------------------------------------------------------------------------------
/data/examples/alpha_example_set_Cupcake.csv:
--------------------------------------------------------------------------------
1 | LabelName,ImageID,Confidence,OriginalURL,DisplayName
2 | /m/03p1r4,04e3345ac6812367,1,https://farm8.staticflickr.com/4041/4703281077_377b2381b6_o.jpg,Cupcake
3 | /m/03p1r4,283b1f143f257cb2,1,https://farm3.staticflickr.com/2640/3934727884_2d6c0e840e_o.jpg,Cupcake
4 | /m/03p1r4,0048aff8e997b7bb,1,https://c4.staticflickr.com/4/3452/3959385380_1a4c1a0c72_o.jpg,Cupcake
5 | /m/03p1r4,0e595758eca6bc7b,1,https://c4.staticflickr.com/4/3466/3870361733_458cfb9811_o.jpg,Cupcake
6 | /m/03p1r4,03822e0571376f3a,1,https://c3.staticflickr.com/3/2572/3735793704_ea2426c4a2_o.jpg,Cupcake
7 | /m/03p1r4,088cef39f4d9600b,1,https://c4.staticflickr.com/4/3435/3828923954_c945cb47eb_o.jpg,Cupcake
8 | /m/03p1r4,0c7713ac6a33580d,1,https://c6.staticflickr.com/9/8524/8680026031_c90b37a52f_o.jpg,Cupcake
9 | /m/03p1r4,8661675b69bcd451,1,https://farm2.staticflickr.com/2939/14553454122_a8ed4a39a2_o.jpg,Cupcake
10 | /m/03p1r4,19d61a65a14fc338,1,https://c1.staticflickr.com/5/4014/4303193380_ac6921e649_o.jpg,Cupcake
11 | /m/03p1r4,127064a5f9cbc7fb,1,https://c4.staticflickr.com/8/7047/6818757202_1ebd77a426_o.jpg,Cupcake
12 | /m/03p1r4,4f28a31461581f24,1,https://farm5.staticflickr.com/8086/8543665572_cfcd413833_o.jpg,Cupcake
13 | /m/03p1r4,0cf0c2409697ac19,1,https://c2.staticflickr.com/2/1196/5145082590_9d4c4cc564_o.jpg,Cupcake
14 | /m/03p1r4,0285bbd389b7b61b,1,https://c4.staticflickr.com/3/2756/4391995615_f318b2488f_o.jpg,Cupcake
15 | /m/03p1r4,063eae052bb7d0db,1,https://c4.staticflickr.com/8/7347/12443516385_0a515fd7cc_o.jpg,Cupcake
16 | /m/03p1r4,086dd93ba3986bf4,1,https://farm5.staticflickr.com/3129/2876175548_0e861137d6_o.jpg,Cupcake
17 | /m/03p1r4,0548f98e537f1e6e,1,https://c7.staticflickr.com/4/3531/3826774014_d251f5b27d_o.jpg,Cupcake
18 | /m/03p1r4,1cc97730294411bc,1,https://c5.staticflickr.com/9/8679/16330521258_331b7f6c3b_o.jpg,Cupcake
19 | /m/03p1r4,61c83243b4dfd6a3,1,https://farm4.staticflickr.com/4026/4342942316_a19f9f3407_o.jpg,Cupcake
20 | /m/03p1r4,0574b4df4b28a75c,1,https://c5.staticflickr.com/3/2488/4193123064_6cbecf1b86_o.jpg,Cupcake
21 | /m/03p1r4,0dd9e1b7cdbe2b95,1,https://farm6.staticflickr.com/3725/13572801914_2a15942d70_o.jpg,Cupcake
22 |
--------------------------------------------------------------------------------
/data/examples/alpha_example_set_Hawk.csv:
--------------------------------------------------------------------------------
1 | LabelName,ImageID,Confidence,OriginalURL,DisplayName
2 | /m/0fp7c,58359817daa6a0ff,1,https://c8.staticflickr.com/7/6119/6406452785_710d4a404e_o.jpg,Hawk
3 | /m/0fp7c,920f445b8580ae55,1,https://farm6.staticflickr.com/44/162525080_c5b31676f0_o.jpg,Hawk
4 | /m/0fp7c,0991a6b74935b991,1,https://farm2.staticflickr.com/7410/8994260910_3f184016c0_o.jpg,Hawk
5 | /m/0fp7c,0bb6ced3dbfd0028,1,https://c1.staticflickr.com/9/8293/7625588028_85c718fde3_o.jpg,Hawk
6 | /m/0fp7c,03225433b89f66f0,1,https://farm1.staticflickr.com/8354/8345997333_7376bf043e_o.jpg,Hawk
7 | /m/0fp7c,055323b81ac99c57,1,https://farm8.staticflickr.com/4067/4275906471_0f02561237_o.jpg,Hawk
8 | /m/0fp7c,0b8cbed949c1348c,1,https://c5.staticflickr.com/7/6211/6299251927_032762f446_o.jpg,Hawk
9 | /m/0fp7c,162987bf2dedd64c,1,https://c5.staticflickr.com/6/5529/9822190545_9f171a476c_o.jpg,Hawk
10 | /m/0fp7c,19b6f08a80e6cf6f,1,https://c8.staticflickr.com/3/2143/2486820398_8cd643f0f2_o.jpg,Hawk
11 | /m/0fp7c,01fe78acc4b6cff2,1,https://farm3.staticflickr.com/8432/7811850262_4c0a670124_o.jpg,Hawk
12 | /m/0fp7c,03d1abe089763813,1,https://farm4.staticflickr.com/5524/11614684325_1788fcd159_o.jpg,Hawk
13 | /m/0fp7c,18bdd57f2c3649e9,1,https://c2.staticflickr.com/1/57/156768050_71668cd26b_o.jpg,Hawk
14 | /m/0fp7c,0ed6cf921a06ffc3,1,https://farm4.staticflickr.com/7543/15837042400_210de31bdb_o.jpg,Hawk
15 | /m/0fp7c,3a17d7c310845da4,1,https://farm5.staticflickr.com/7329/15847340443_7fd343f22e_o.jpg,Hawk
16 | /m/0fp7c,1e5f5e1c51e433a7,1,https://farm8.staticflickr.com/2844/12019952706_01921a24a1_o.jpg,Hawk
17 | /m/0fp7c,33d07430c9e1c1c0,1,https://farm1.staticflickr.com/2414/1676415495_80b38ec6ba_o.jpg,Hawk
18 | /m/0fp7c,25438ec6557bbeef,1,https://c8.staticflickr.com/9/8576/16505512336_22c1400200_o.jpg,Hawk
19 | /m/0fp7c,0454fcd2f6aa87b6,1,https://farm7.staticflickr.com/3141/2675478886_a9bf3f97b9_o.jpg,Hawk
20 | /m/0fp7c,172b142392991d65,1,https://farm6.staticflickr.com/3085/3171199512_f938a191fe_o.jpg,Hawk
21 | /m/0fp7c,0fb32f8309f9e657,1,https://farm8.staticflickr.com/1269/4679488219_3a3c9a63cb_o.jpg,Hawk
22 |
--------------------------------------------------------------------------------
/data/examples/alpha_example_set_Sushi.csv:
--------------------------------------------------------------------------------
1 | LabelName,ImageID,Confidence,OriginalURL,DisplayName
2 | /m/07030,07b185567034bbb1,1,https://farm2.staticflickr.com/3778/11680681144_8a85e2cb49_o.jpg,Sushi
3 | /m/07030,e12b26befd92e2a9,1,https://c2.staticflickr.com/8/7525/16106981981_5a0ab10fd4_o.jpg,Sushi
4 | /m/07030,41d62e93c8333940,1,https://farm1.staticflickr.com/5070/5644991286_af30efcbd9_o.jpg,Sushi
5 | /m/07030,1e9f698a78a97100,1,https://farm1.staticflickr.com/2031/2439673474_e991aca2d3_o.jpg,Sushi
6 | /m/07030,65776b7c8603ce56,1,https://farm7.staticflickr.com/5174/5596202858_1aa5c16ee9_o.jpg,Sushi
7 | /m/07030,24571f400a6277b8,1,https://c1.staticflickr.com/3/2595/3951286521_3f32cd30ab_o.jpg,Sushi
8 | /m/07030,881c906e26ec47bd,1,https://c6.staticflickr.com/8/7308/16292070217_acf77c41e3_o.jpg,Sushi
9 | /m/07030,bc15b56a519d9eba,1,https://c3.staticflickr.com/2/1403/655402748_5e25e53ba8_o.jpg,Sushi
10 | /m/07030,0bbe9d3461da8005,1,https://farm6.staticflickr.com/8608/16205382902_fb42692876_o.jpg,Sushi
11 | /m/07030,6973329af45abd17,1,https://farm8.staticflickr.com/8434/7830740592_f6d86ba192_o.jpg,Sushi
12 | /m/07030,7d7e2bbd64b31a36,1,https://farm7.staticflickr.com/3132/3147501854_839c8e5c7d_o.jpg,Sushi
13 | /m/07030,3274aab7ff102683,1,https://farm4.staticflickr.com/3932/15360031479_2ea6152bbb_o.jpg,Sushi
14 | /m/07030,2a711993ced55009,1,https://c2.staticflickr.com/1/137/334330355_15cede6219_o.jpg,Sushi
15 | /m/07030,292d254b27992f39,1,https://c2.staticflickr.com/6/5456/9009681840_d763a0a7db_o.jpg,Sushi
16 | /m/07030,58d70c954b297c6a,1,https://farm4.staticflickr.com/7404/8929944658_2653ca4be0_o.jpg,Sushi
17 | /m/07030,07382aad542bc263,1,https://farm4.staticflickr.com/76/390941710_8e658151ab_o.jpg,Sushi
18 | /m/07030,0be303aafbfe95cb,1,https://c4.staticflickr.com/6/5263/5622576361_f85a9975ac_o.jpg,Sushi
19 | /m/07030,021a6b70ecd95934,1,https://farm5.staticflickr.com/22/28920062_b95d2d605a_o.jpg,Sushi
20 | /m/07030,0f259cd129018e90,1,https://farm1.staticflickr.com/8572/15945575187_17c0a01a22_o.jpg,Sushi
21 | /m/07030,64275fd17d9c11d9,1,https://c2.staticflickr.com/3/2056/2150626616_2cfe4a423d_o.jpg,Sushi
22 |
--------------------------------------------------------------------------------
/data/results/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | */
3 | !result_for_random_500.json
4 | !.gitignore
--------------------------------------------------------------------------------
/data/results/result_for_random_500.json:
--------------------------------------------------------------------------------
1 | {
2 | "Cupcake": {
3 | "accuracy": 0.5401459854014599,
4 | "recall": 0.463768115942029,
5 | "precision": 0.5517241379310345,
6 | "f1": 0.5039370078740157
7 | },
8 | "Hawk": {
9 | "accuracy": 0.296551724137931,
10 | "recall": 0.16831683168316833,
11 | "precision": 0.4857142857142857,
12 | "f1": 0.25000000000000006
13 | },
14 | "Sushi": {
15 | "accuracy": 0.5185185185185185,
16 | "recall": 0.6261682242990654,
17 | "precision": 0.638095238095238,
18 | "f1": 0.6320754716981132
19 | }
20 | }
--------------------------------------------------------------------------------
/data/test_sets/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | */
3 | !.gitignore
--------------------------------------------------------------------------------
/data/train_sets/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | */
3 | !random_500.csv
4 | !.gitignore
--------------------------------------------------------------------------------
/data/train_sets/random_500.csv:
--------------------------------------------------------------------------------
1 | ImageID,Confidence
2 | 0002643773a76876,0
3 | 0016a0f096337445,0
4 | 0036043ce525479b,1
5 | 00526f123f84db2f,1
6 | 0080db2599d54447,1
7 | 00978577e9fdd967,1
8 | 00999e0a1c66fa9e,0
9 | 00ae6f48022f527e,1
10 | 00ccbda21e2cc314,1
11 | 00fe9ad027307049,0
12 | 014277a08124380f,0
13 | 017b8e86d7a13543,0
14 | 01a1201b26a7fdc3,1
15 | 01a22546102bd9e8,0
16 | 01b0b1eb40d21471,1
17 | 01b66b3974dbeae9,0
18 | 01f39a574ee5fe0b,1
19 | 020861abeb63ced7,0
20 | 0210ebe5c31c9e4a,0
21 | 021b127ef86e33bb,0
22 | 0238a7e4aeb6e22f,1
23 | 027e69f951f3e970,0
24 | 027f4a5d585c6f54,0
25 | 027f6b1b3c35c224,1
26 | 02837ef702581393,1
27 | 02a24c3e6c5978f2,0
28 | 02c9db8dbfa39515,0
29 | 02e9d387fe04ebd5,0
30 | 02ede7ccd20f04ad,1
31 | 0322f572caa700d6,1
32 | 03283ae91d87a97b,0
33 | 033a3f1bd43e3112,1
34 | 033e8f0c9c33f138,0
35 | 036ed5ce09c82389,0
36 | 03bac6cfe5a8dc39,1
37 | 03bc484df524e1dd,0
38 | 03dfc8b7e008218d,1
39 | 03e30b9bfb3fff4d,1
40 | 03e96e9c6306f122,1
41 | 0408b3cd3dbd9bc2,1
42 | 041610af4aba2faf,0
43 | 042e922e3eb7036d,1
44 | 04652580e6ca3182,1
45 | 047067c3af79c3ca,1
46 | 04915967aaf19826,0
47 | 04957b97d746fc3e,1
48 | 04c6b3e741cacab3,1
49 | 050ab2cd054846c9,1
50 | 051540940bf53eac,1
51 | 052fe61dc71b0eba,1
52 | 0538e05f269f934c,0
53 | 05856fdeff423cb0,1
54 | 0589bffe7dff2c8a,0
55 | 0593a46ed8ff0f99,0
56 | 059d11c42caa06e5,0
57 | 059fc3db254afaa9,1
58 | 05c5e3c2b0e14842,0
59 | 05f9dafe672ca3b1,0
60 | 062ce07c09a59d1c,1
61 | 0655152dba1f84cc,0
62 | 072c4fb6f5651a1b,0
63 | 073fe85b4abbaa83,0
64 | 074f6faf7c4f2424,0
65 | 07826e354f362710,0
66 | 079c907f111f0b68,0
67 | 07b0459fc902c4b3,1
68 | 07e670e9df51fae3,0
69 | 07ed8d27cb1aebc9,0
70 | 07f525146d75da12,1
71 | 08085c0be08cbe27,1
72 | 083afb44856c10b4,1
73 | 0858a52cc1a80652,0
74 | 085baf695c29d1f3,0
75 | 08946585b8ea5eeb,0
76 | 08b7a60c169c8a68,1
77 | 08cf26629ef459fa,1
78 | 08d52981eea69dc5,0
79 | 08ecb8adf2030622,0
80 | 08f9c841c2a39bf5,0
81 | 09370e426b4f2478,1
82 | 095eab7353db3e6d,1
83 | 096dda4366d118ba,1
84 | 09b6e6848368539e,1
85 | 0a0bdb0cebfa67ee,0
86 | 0a3340f3516f2479,1
87 | 0a4e63f6c188485e,1
88 | 0a684b8fc669b812,0
89 | 0a69b418ad5dbdb0,0
90 | 0a78171e2fe9b152,1
91 | 0a983561e47a4398,1
92 | 0a9d308e5fa151bb,0
93 | 0acbb703c2d13cb2,0
94 | 0ad93624486df1a7,0
95 | 0ad977c36c026730,1
96 | 0aeb361cac11225d,0
97 | 0af2834572e5edfb,0
98 | 0b19e9430e1684ae,0
99 | 0b2373e4fcd56ab0,0
100 | 0b35185242de8b20,0
101 | 0b363dcb204ee6fa,0
102 | 0b44fcd48ce7a66b,1
103 | 0b8bc31e82e8ac51,0
104 | 0bb27693b12008d6,1
105 | 0c5d8bfa724854ed,0
106 | 0c6908ccfc636745,1
107 | 0c8bc6311a5f6b28,1
108 | 0cbfea27c75b7854,1
109 | 0d61a766a5e052c1,1
110 | 0d648bbe5b8a9600,0
111 | 0d7e09d990723b07,0
112 | 0d8787e0f401938b,1
113 | 0dc3994cc6311deb,1
114 | 0dd3211c2ad294a3,1
115 | 0ded1bb5d7b9d152,1
116 | 0e26edc5bc0c8311,0
117 | 0e2ddf22657cb827,1
118 | 0e39471801ad2a43,1
119 | 0e56d521cd207050,1
120 | 0e60973b10e64411,1
121 | 0e69efd1f7b44905,0
122 | 0ecc00179ef54ca5,1
123 | 0ef89705e4158105,1
124 | 0f153d6353f417d4,1
125 | 0f3381b38fcd5cdc,1
126 | 0f5834a4f70f7b20,1
127 | 0f872e1abf7ac6da,0
128 | 0fcfc644a29aa10b,1
129 | 0fe4c64f8692ad84,0
130 | 1000a5246fa70d6d,1
131 | 10103bc8d81860d7,1
132 | 10188cd6c38e5ab8,1
133 | 101921f7b6569691,0
134 | 103a0ee96b890b7a,1
135 | 10407470ddbe7f33,0
136 | 1060886ac7527790,0
137 | 1092892be3a1d74a,1
138 | 10bc4f9d4f9cf0c8,0
139 | 10cc91f3caccdaf5,0
140 | 10ed93f362a259b0,1
141 | 1113516133edbabc,0
142 | 11435a877f12aed5,0
143 | 1169b3404fe2843b,0
144 | 1170571276b6bcad,1
145 | 1174fafefae0428b,0
146 | 1182bd717c802dcc,1
147 | 1186520afb0493a8,0
148 | 119b9d88404d8d48,1
149 | 11bb6b1fb5702099,0
150 | 11bf96bd0e9882db,0
151 | 11da6f070d4771ef,0
152 | 120022a6a24e90cd,1
153 | 12043f1344f75048,1
154 | 123847b6982c7e1e,1
155 | 1246d649fd5e931d,1
156 | 12498a6ca767492e,0
157 | 124cb0d63a578f9e,0
158 | 124e720a8b9a19db,1
159 | 1263ad7c3ee86f27,0
160 | 129686341e997b40,1
161 | 12d2ce2b79ab9573,0
162 | 131407e04be18ace,1
163 | 1320bbabc132dd2f,0
164 | 134b1c18e6af837f,1
165 | 134edfc0e31dd303,0
166 | 135dc87be1b86802,1
167 | 135f588d040ea7fc,0
168 | 1375d5331cf28f83,0
169 | 1391bcfe081f19ad,1
170 | 13c35bcd1a6dd96b,1
171 | 13c98b0387aa0869,1
172 | 13f865503304fe7f,0
173 | 1419b0ba4824bd3a,0
174 | 141de6f0bb105207,0
175 | 1473761fb505e238,0
176 | 14cba893e4652d12,1
177 | 14f955e1ccb4ce9c,1
178 | 150c090d425a726a,0
179 | 1559343cf042ee0a,0
180 | 160d65e19cd0cdee,0
181 | 1612223f74fb4336,1
182 | 16366d31cf297c34,1
183 | 165098cf74ff01fa,1
184 | 168aa9f930331bde,1
185 | 16bda96251f1b2eb,0
186 | 16fae7ccc1c8d87d,1
187 | 17237768756d0617,0
188 | 1729bf4a981ba70e,0
189 | 17878e5abe11fb8b,0
190 | 178c462adad58992,1
191 | 179540365378298e,0
192 | 17ddd6c38fe6a11f,0
193 | 17f6bd54f41ecd24,1
194 | 1829c7528398489f,1
195 | 183368f31e6fa221,0
196 | 184c8e263fe0137d,1
197 | 1856d5dd7a355dfb,1
198 | 187187dd5c763ad3,0
199 | 187e144c1e8f8e41,1
200 | 18bc27b7dcf34a71,0
201 | 18bd73f2e03b0435,1
202 | 18f26cea382ed7b8,1
203 | 18f3454604de81c5,1
204 | 18f4cac0fa308c26,1
205 | 19153d748d8d5a1c,0
206 | 196665cfc99666e8,0
207 | 19c860012d34b7d5,1
208 | 19e8a28edca6a66b,0
209 | 1a17b0157e89480b,1
210 | 1a26647e9bd3a9db,1
211 | 1a3c2aca3ac642f8,0
212 | 1a446826a6216c9f,0
213 | 1a5fafb0a5ebbe5a,1
214 | 1a8111e0b7aadc91,0
215 | 1aa8314c51ba10ed,0
216 | 1aab8e89f1004b4b,0
217 | 1b62252ca65a12ff,1
218 | 1b643fda9337f548,1
219 | 1b977827d49f7759,0
220 | 1bb4b842b7353905,0
221 | 1bc488ddb14b84ac,0
222 | 1bdc4fb329706237,0
223 | 1c13f85e453a54e4,1
224 | 1c24bbdbcadbe3f4,1
225 | 1ca2cec096be16d4,0
226 | 1cdab95f1e784df0,1
227 | 1d0e19f531cb4898,1
228 | 1d19070c56721c8b,1
229 | 1d2f81eb5d0d5211,1
230 | 1d42dfa41fbc0a34,1
231 | 1d5aebd5afbc6f2b,1
232 | 1d5d4c3437b466b2,0
233 | 1d8d014496cb4824,0
234 | 1da26a35f3e8abf4,1
235 | 1dd47f1748a2bcf0,0
236 | 1de37429bf3cb6e8,1
237 | 1dec3dded7a4dae4,1
238 | 1e244b7e9f0b8fd9,0
239 | 1e28a5c7598ce0f9,0
240 | 1e41948a34480596,0
241 | 1e436799490e7eeb,0
242 | 1e595e59b37e4a27,0
243 | 1e9a4a1d0fc4ca29,0
244 | 1e9b191d90c2266e,1
245 | 1f192b722b0ad354,1
246 | 1f1ceaeded18db96,0
247 | 1f208689b80cee08,1
248 | 1f462ce1ca995b8c,1
249 | 1f6a53c6b5721fea,0
250 | 1fafa775240417ec,0
251 | 1fb250a9ef9d5a0c,0
252 | 1fe24d5640ce14b1,0
253 | 200b7847da358a2d,0
254 | 200d3d72a4523c40,0
255 | 2011183f59037ef0,0
256 | 201547df732c4156,1
257 | 2048581ccb95d815,1
258 | 2066a81e97489b4c,1
259 | 20759e94df7e69d9,0
260 | 20b6978d63ce149c,0
261 | 20d90ec2e11fc5b4,0
262 | 20ec4206c96cbd3d,0
263 | 2195261d7bcc65c2,0
264 | 21fecdabcc50d165,0
265 | 22640048749f1c61,0
266 | 226c6ad77da16dce,0
267 | 227edbb5aeeae5e9,1
268 | 22c1fba27a65f29c,1
269 | 23071dc96012937c,0
270 | 2315fe86de57a0f8,1
271 | 233404499f9c93c2,1
272 | 23754137155e09b8,0
273 | 2387aace9f6caa87,1
274 | 23f193b64cb7b6fe,1
275 | 243efee4fa9871de,0
276 | 244dd102dd6515e1,1
277 | 245d7557726b022c,1
278 | 245ebc538d8f9229,1
279 | 24c96ef04bf38053,1
280 | 24d6cb92e781a022,1
281 | 25473b5dff92c79d,0
282 | 269317351e751f49,1
283 | 26941883c642ca70,0
284 | 26adf38aadb3aecb,1
285 | 26e9798a6cf03d7c,1
286 | 26fd27d06044746e,0
287 | 272b18a95bd6a03a,0
288 | 274bc1c79f1b98e9,0
289 | 281cd9a58dfcd265,1
290 | 28b7d8656314a9d5,0
291 | 28cbd682e022cf51,1
292 | 292db09aebe28d57,0
293 | 29a3465795676349,0
294 | 2a0193b50f7e409a,0
295 | 2a0bec44f6933b30,1
296 | 2a43c69ad62d823f,0
297 | 2a43e6c4d1122207,1
298 | 2a47907f0d80ad46,0
299 | 2a8422b853757ddd,1
300 | 2b030fd83d070666,1
301 | 2b13140e618e8146,1
302 | 2b304285abaaa6c5,0
303 | 2b49e3793542e110,1
304 | 2b62224b42e0bb1f,1
305 | 2b853fea1d218995,1
306 | 2ba2b02bf30d4f12,0
307 | 2bc36d77f0c12ea6,0
308 | 2bf49d8e608783db,1
309 | 2c0c5ce55d3b9632,1
310 | 2c1ff4a2ec540cdc,1
311 | 2c32a3dd8b886ccc,0
312 | 2c3c25c576b1de15,1
313 | 2c54593283207177,0
314 | 2c70b971603d4411,0
315 | 2c7b082ff81eaf95,0
316 | 2c7d448252768342,0
317 | 2c9da9d471aaebc9,0
318 | 2cb0911be5234a78,1
319 | 2cb3413bd5efa7b7,1
320 | 2cc4a0f1e4d2c9c9,1
321 | 2d1df68ce872bb9f,0
322 | 2d27bbfd78f9f925,1
323 | 2d5759f12fa9ae85,1
324 | 2d734b2824575942,0
325 | 2dbbb4254b5aed3a,0
326 | 2df2b67715389791,1
327 | 2e0068ac814f792d,0
328 | 2e0f9b2466164356,1
329 | 2e1db05316089325,1
330 | 2e807a42b7b37883,0
331 | 2efa2ff77d256312,0
332 | 2f11455e210fdedf,1
333 | 2f179ca1a5f7039b,0
334 | 2f1adc343f8598d5,0
335 | 2f33584c5ffe72d5,0
336 | 2f5c0482b9732b25,1
337 | 2f73b6bff89bf15a,0
338 | 2f8909ce4e9dae11,0
339 | 2f98d2a51394bee2,1
340 | 2f9f04706938f8af,1
341 | 2fd542f6ac430a94,0
342 | 2fec7a7a1436d1c2,0
343 | 2feec97810f64d63,0
344 | 2ffa3756e321ac79,1
345 | 3019fdd7751a0295,0
346 | 30852efeea1cf314,0
347 | 308a9e78d915d9e9,0
348 | 30b1c6c5eca7f9b8,1
349 | 30ba496d7af5d9a0,1
350 | 30f937df9fc3dd31,0
351 | 3109e63fb93c6e65,1
352 | 31103e5a4701c98c,1
353 | 3127bbd0befaf782,0
354 | 3134bc984a2e3d5c,1
355 | 313977213b0d3c67,1
356 | 3144d031e1d9924d,1
357 | 3170776ff4406e6c,0
358 | 317179c2bc38d651,1
359 | 317678b136554c09,1
360 | 3191765aeac56c0d,1
361 | 31cbcdd5de59681d,0
362 | 31d209e52e881592,0
363 | 31e24288dc274d86,0
364 | 31fddabdb730b698,1
365 | 32099764d1d36634,0
366 | 324b83570c210bd6,1
367 | 325060723d3e21f0,1
368 | 3274a8a2ca494416,1
369 | 328c9475df30f96c,0
370 | 32d021d6e16faffe,0
371 | 32f4f6c0f2da7a97,1
372 | 333e36d6df1426f4,1
373 | 335f9a9f2b05e506,1
374 | 3371da82a80bf4cf,0
375 | 337fdbb659255d17,0
376 | 33e2f04f26101766,0
377 | 33fb3b378f677fc0,1
378 | 34833559bf9ac10b,1
379 | 34bcf62c7a6d20d5,1
380 | 34c0fffcfd317ad0,1
381 | 34e9a871c25e5254,1
382 | 34ec2c8d87192de0,1
383 | 35055b02a0cbaf87,1
384 | 3520d48978e4b1a7,1
385 | 35318d0d30b6bc7f,0
386 | 359ad26db2bdfad1,1
387 | 359bb1dd6df593b2,1
388 | 35b3d53c2e85deb0,0
389 | 35ed6e0029a50970,1
390 | 361f25fe375924eb,1
391 | 364b25a08774360a,0
392 | 364d10980606eb0c,1
393 | 36717da18ce3beab,0
394 | 367edbcf140248f0,0
395 | 36a5d8b4ddebccd2,0
396 | 36b91e58ae716df1,1
397 | 37106bd1e4921449,1
398 | 371d84b0b24527ce,1
399 | 3741db2dfe9ba01c,0
400 | 37af8f2f0515a0de,0
401 | 37b428bbf7cdac16,1
402 | 37d3f43cd3283a65,1
403 | 37f2aec9140b6146,0
404 | 38298c9c9c07b331,0
405 | 384e15769e0ece9d,1
406 | 3877b5b64d41f204,1
407 | 38833fd2cdcef112,0
408 | 38abb45d15a3a84e,0
409 | 38de74053a55c22a,1
410 | 38de92959fa7efcc,0
411 | 391d57896fe7fbad,1
412 | 394d6fc65e3d4f8e,1
413 | 395140593803e508,1
414 | 395d8800157e415b,1
415 | 3970ddb1df7b91b6,1
416 | 397168f7bf3f6653,0
417 | 39794020134fa81e,0
418 | 39e5af85b70fd17e,1
419 | 3a0c7d2da3eb33d4,1
420 | 3a32ebdfdeebc0bb,1
421 | 3a452f4a1cf72231,0
422 | 3acfd3080478564b,0
423 | 3ad343021062e934,0
424 | 3ae6bbb6c71a4e8f,0
425 | 3b028d879278e08c,1
426 | 3b23f46222944cbc,1
427 | 3b93b36bb3df7daf,0
428 | 3be6c6fe9a104131,0
429 | 3bfccb17fff4a25d,0
430 | 3c1e13485fff69fe,1
431 | 3c275a4741f26ff0,1
432 | 3c34214430a3ae97,0
433 | 3c3f0ec0de3fc5ce,1
434 | 3c4450673e1186b5,1
435 | 3c5840e6dd4a9c95,1
436 | 3c6949c3c02f692d,0
437 | 3c858852b73e4e07,1
438 | 3c8d0cdca76aef58,0
439 | 3c9fec64e9e3bb3a,1
440 | 3cbe1a201cc01b30,0
441 | 3d2bef3d8c660d37,1
442 | 3d43ac868b2b469f,1
443 | 3d530359cc8d7722,0
444 | 3d6bbe368fe9c8ce,1
445 | 3d8e69b99c177e6f,0
446 | 3de9a9ce79b2086c,1
447 | 3e07e3bb4bfbacac,1
448 | 3e2ec5f3ffb39ad3,0
449 | 3e46802f0df37952,1
450 | 3ea731bd8cfd98a2,0
451 | 3eb6b4a4ff8f031b,0
452 | 3ec5f97d27abaec6,1
453 | 3f20024b1aa83fe8,1
454 | 3f24d6a573aaca34,1
455 | 3f3aeca80046b05b,1
456 | 3f7bd2c21a3b77df,1
457 | 3fd36c3be1cbb924,1
458 | 3fef955bb604f22f,1
459 | 40242c87b6e22bba,0
460 | 406b8c6a6cefce38,0
461 | 409bfcca8d4b2be5,0
462 | 40b9d8ec678cb0ed,0
463 | 40de594d9ed169b3,0
464 | 4108e999cdb3626f,1
465 | 41168302ba3b3ae6,0
466 | 416fe6a79d6f94ef,1
467 | 417be12f43852235,0
468 | 418d578372036682,0
469 | 41a138381273629a,0
470 | 41a17dfb2113619c,0
471 | 41d325841762cd99,1
472 | 4210237a8adfa3fa,1
473 | 42470b2b40cb0e08,1
474 | 4261645bcc1dd73c,0
475 | 428a82a713007299,0
476 | 42fa06ecb1d8541d,0
477 | 42fed3d5dee37fae,1
478 | 430670ee980faa9c,0
479 | 430c1a398dfcd405,0
480 | 436cdf23b0ba4d76,0
481 | 43712821cf93e649,0
482 | 43b342b0679aff9b,0
483 | 43d50a312e08b997,0
484 | 43dc00dc34fd00f9,1
485 | 43fc149f47813d4a,0
486 | 4437e11d92557b30,0
487 | 44880d5ba63aefb0,0
488 | 44b22675dae9cfa6,0
489 | 44fdc19754799879,1
490 | 450d0d858061f3ae,1
491 | 454d4cb8233fe870,1
492 | 4550a814fd718224,1
493 | 4591afc0400e567c,1
494 | 4599350cccbb98e1,0
495 | 45b82f7e20e6614f,1
496 | 461871a1e09b1d0b,0
497 | 4632286207e40352,1
498 | 4643592415251268,0
499 | 4645bbfb33cd424f,1
500 | 4651a988c0b7c7ba,0
501 | 46a1b7b5b72e2863,1
502 |
--------------------------------------------------------------------------------
/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 | dataperf-visual-submission:
3 | container_name: dataperf-visual-submission
4 | build: .
5 | volumes:
6 | - ./data:/app/data
--------------------------------------------------------------------------------
/download_data.py:
--------------------------------------------------------------------------------
1 | """Download samples and eval data"""
2 | import argparse
3 | import os
4 | import yaml
5 | import gdown
6 | import zipfile
7 | from tqdm import tqdm
8 |
9 |
def download_file(url, folder_path, file_name):
    """Fetch ``url`` from Google Drive and store it as ``folder_path/file_name``.

    Progress output from gdown is left enabled, and ``fuzzy=True`` is passed
    so gdown can work out the file id from a full share URL.
    """
    destination = os.path.join(folder_path, file_name)
    gdown.download(
        url,
        destination,
        fuzzy=True,
        quiet=False,
    )
14 |
15 |
def main():
    """Download the DataPerf vision resources archive and unpack it.

    Reads ``dataset_url`` from the YAML file given by ``--parameters_file``,
    downloads the archive into ``--output_path`` (creating the directory if
    needed) and extracts every member there. Members that fail to extract
    are reported and skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--parameters_file",
        type=str,
        required=True,
        help="File containing parameters for the download",
    )
    parser.add_argument(
        "--output_path", type=str, required=True, help="Path where data will be stored",
    )
    args = parser.parse_args()

    # The parameters file only holds plain scalars; safe_load is PyYAML's
    # recommended loader for that and cannot build arbitrary Python objects.
    with open(args.parameters_file, "r") as f:
        params = yaml.safe_load(f)

    output_path = args.output_path
    dataset_url = params["dataset_url"]
    file_name = "dataperf-vision-selection-resources.zip"

    # Ensure the destination exists before gdown tries to write into it.
    os.makedirs(output_path, exist_ok=True)
    download_file(dataset_url, output_path, file_name)

    with zipfile.ZipFile(os.path.join(output_path, file_name)) as zf:
        for member in tqdm(zf.infolist(), desc="Extracting "):
            try:
                zf.extract(member, output_path)
            except zipfile.error as e:
                # Best effort: report the broken member, keep extracting the rest.
                print(e)


if __name__ == "__main__":
    main()
50 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict
2 |
3 | from pyspark.sql import DataFrame
4 | from sklearn.base import BaseEstimator
5 | from sklearn.linear_model import LogisticRegression
6 | from sklearn.metrics import accuracy_score, recall_score,\
7 | precision_score, f1_score
8 |
9 | import constants as c
10 |
11 |
def get_trained_classifier(
        df: DataFrame,
        clf: Optional[BaseEstimator] = None) -> BaseEstimator:
    """Fit a classifier on the label and embedding columns of ``df``.

    Args:
        df: Spark DataFrame holding ``c.LABEL_COL`` and ``c.EMB_COL``.
        clf: Estimator to fit. When omitted (or None), a fresh
            ``LogisticRegression(random_state=c.RANDOM_SEED)`` is created for
            this call, so classifiers returned by separate calls are
            independent objects.

    Returns:
        The result of ``clf.fit(X, y)``.
    """
    if clf is None:
        # A default instance in the signature would be created once at import
        # time and then re-fitted (overwritten) by every subsequent call.
        clf = LogisticRegression(random_state=c.RANDOM_SEED)

    pdf = df.select(c.LABEL_COL, c.EMB_COL).toPandas()
    X = pdf[c.EMB_COL].values.tolist()
    y = pdf[c.LABEL_COL].values.tolist()

    return clf.fit(X, y)
22 |
23 |
def score_classifier(df: DataFrame, clf: BaseEstimator) -> Dict[str, float]:
    """Evaluate ``clf`` on ``df`` and return accuracy, recall, precision and F1.

    Recall, precision and F1 are computed with the string label ``'1'`` as
    the positive class.
    """
    test_pdf = df.select(c.LABEL_COL, c.EMB_COL).toPandas()
    features = test_pdf[c.EMB_COL].values.tolist()
    truth = test_pdf[c.LABEL_COL].values.tolist()
    predicted = clf.predict(features).tolist()

    return {
        'accuracy': accuracy_score(truth, predicted),
        'recall': recall_score(truth, predicted, pos_label='1'),
        'precision': precision_score(truth, predicted, pos_label='1'),
        'f1': f1_score(truth, predicted, pos_label='1'),
    }
36 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import fire
4 |
5 | import constants as c
6 | import utils as utils
7 | import eval as eval
8 |
9 |
def run_tasks(
        setup_yaml_path: str = c.DEFAULT_SETUP_YAML_PATH,
        docker_flag: bool = False) -> None:
    """Train and score one classifier per task listed in the setup YAML.

    Args:
        setup_yaml_path (str, optional): Config file to read. Defaults to
            the path in constants.DEFAULT_SETUP_YAML_PATH.
        docker_flag (bool, optional): Set when running inside the container;
            picks the docker data-dir key instead of the local one.
    """
    task_setup = utils.load_yaml(setup_yaml_path)
    if docker_flag:
        data_dir = task_setup[c.SETUP_YAML_DOCKER_DATA_DIR_KEY]
    else:
        data_dir = task_setup[c.SETUP_YAML_LOCAL_DATA_DIR_KEY]
    dim = task_setup[c.SETUP_YAML_DIM_KEY]
    emb_path = os.path.join(data_dir, task_setup[c.SETUP_YAML_EMB_KEY])

    ss = utils.get_spark_session(task_setup[c.SETUP_YAML_SPARK_MEM_KEY])

    print('Loading embeddings\n')
    emb_df = utils.load_emb_df(ss=ss, path=emb_path, dim=dim)

    # Resolve every task's [train, test] pair before any work starts.
    task_paths = {}
    for task_name in task_setup[c.SETUP_YAML_TASKS_KEY]:
        task_paths[task_name] = task_setup[task_name]

    task_scores = {}
    for task, rel_paths in task_paths.items():
        print(f'Evaluating task: {task}')
        full_paths = [os.path.join(data_dir, rel) for rel in rel_paths]
        train_path, test_path = full_paths

        print(f'Loading training data for {task}...')
        train_df = utils.add_emb_col(
            df=utils.load_train_df(ss=ss, path=train_path), emb_df=emb_df)

        print(f'Loading test data for {task}...')
        test_df = utils.load_test_df(ss=ss, path=test_path, dim=dim)

        print(f'Training classifier for {task}...')
        clf = eval.get_trained_classifier(df=train_df)

        print(f'Scoring trained classifier for {task}...\n')
        task_scores[task] = eval.score_classifier(df=test_df, clf=clf)

    utils.save_results(
        data=task_scores,
        save_dir=os.path.join(data_dir, task_setup[c.SETUP_YAML_RESULTS_KEY]),
        verbose=True)


if __name__ == "__main__":
    fire.Fire(run_tasks)
58 |
--------------------------------------------------------------------------------
/mlcube.py:
--------------------------------------------------------------------------------
1 | """MLCube handler file"""
2 | import os
3 | import subprocess
4 |
5 | import typer
6 | import yaml
7 |
8 | from selection import Predictor
9 |
10 | typer_app = typer.Typer()
11 |
12 |
class DownloadTask:
    """Download samples and eval data"""

    @staticmethod
    def run(parameters_file: str, output_path: str) -> None:
        """Run download_data.py in a subprocess with the given arguments.

        Raises:
            subprocess.CalledProcessError: if the download script exits with
                a non-zero status.
        """
        # Build argv as a list: splitting a formatted command string on
        # whitespace would break any path that contains a space.
        cmd = [
            "python3",
            "download_data.py",
            f"--parameters_file={parameters_file}",
            f"--output_path={output_path}",
        ]
        # check=True surfaces download failures instead of silently continuing.
        subprocess.run(cmd, cwd=".", check=True)
25 |
26 |
class SelectTask:
    """Run select algorithm"""

    @staticmethod
    def run(
        input_path: str, embeddings_path: str, parameters_file: str, output_path: str
    ) -> None:
        """Select training images with Predictor and write them to output_path.

        Args:
            input_path: CSV of labeled images handed to the predictor.
            embeddings_path: Parquet file of embeddings loaded by Predictor.
            parameters_file: YAML file providing ``n_closest`` and ``n_random``.
            output_path: Destination CSV for the selection.
        """
        # The parameters file holds plain scalars only; safe_load is PyYAML's
        # recommended loader for that and cannot build arbitrary objects.
        with open(parameters_file, "r") as f:
            params = yaml.safe_load(f)

        print("Creating predictor")
        predictor = Predictor(embeddings_path)
        predictor.closest_and_furthest(
            input_path, output_path, params["n_closest"], params["n_random"]
        )
43 |
44 |
class EvaluateTask:
    """Execute evaluation script"""

    @staticmethod
    def run(eval_path: str, log_path: str) -> None:
        # run_evaluate.sh reads `eval_path` and `log_path` from its
        # environment; pass them on top of a copy of the current environment.
        child_env = {**os.environ, "eval_path": eval_path, "log_path": log_path}
        subprocess.Popen("./run_evaluate.sh", cwd=".", env=child_env).wait()
56 |
57 |
@typer_app.command("download")
def download(
    parameters_file: str = typer.Option(..., "--parameters_file"),
    output_path: str = typer.Option(..., "--output_path"),
):
    # CLI entry point for the `download` task; delegates to DownloadTask.
    # (Comment rather than docstring so typer's --help text is unchanged.)
    DownloadTask.run(parameters_file=parameters_file, output_path=output_path)
64 |
65 |
@typer_app.command("select")
def select(
    input_path: str = typer.Option(..., "--input_path"),
    embeddings_path: str = typer.Option(..., "--embeddings_path"),
    parameters_file: str = typer.Option(..., "--parameters_file"),
    output_path: str = typer.Option(..., "--output_path"),
):
    # CLI entry point for the `select` task; delegates to SelectTask.
    SelectTask.run(
        input_path=input_path,
        embeddings_path=embeddings_path,
        parameters_file=parameters_file,
        output_path=output_path,
    )
74 |
75 |
@typer_app.command("evaluate")
def evaluate(
    eval_path: str = typer.Option(..., "--eval_path"),
    log_path: str = typer.Option(..., "--log_path"),
):
    # CLI entry point for the `evaluate` task; delegates to EvaluateTask.
    EvaluateTask.run(eval_path=eval_path, log_path=log_path)


if __name__ == "__main__":
    typer_app()
86 |
--------------------------------------------------------------------------------
/mlcube.yaml:
--------------------------------------------------------------------------------
1 | name: MLCommons DataPerf Vision Example
2 | description: MLCommons DataPerf integration with MLCube
3 | authors:
4 | - { name: "MLCommons Best Practices Working Group" }
5 |
6 | platform:
7 | accelerator_count: 0
8 |
9 | docker:
10 | # Image name.
11 | image: mlcommons/dataperf_vision:0.0.1
12 | # Docker build context relative to $MLCUBE_ROOT. Default is `build`.
13 | build_context: "."
14 | # Docker file name within docker build context, default is `Dockerfile`.
15 | build_file: "Dockerfile_mlcube"
16 |
17 | tasks:
18 | download:
19 | # Download data
20 | parameters:
21 | inputs: { parameters_file: { type: file, default: parameters.yaml } }
22 | outputs: { output_path: data/ }
23 |
24 | select:
25 | # Run selection script
26 | parameters:
27 | inputs: {
28 | input_path: { type: file, default: data/dataperf-vision-selection-resources/train_sets/random_500.csv },
29 | embeddings_path: { type: file, default: data/dataperf-vision-selection-resources/embeddings/train_emb_256_dataperf.parquet},
30 | parameters_file: { type: file, default: parameters.yaml }
31 | }
32 | outputs: { output_path: { type: file, default: data/selection_output.csv } }
33 |
34 | evaluate:
35 | # Perform evaluation
36 | parameters:
37 | inputs: { eval_path: data/dataperf-vision-selection-resources/}
38 | outputs: { log_path: { type: file, default: log.txt } }
39 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pyspark==3.2.1
2 | pyyaml==6.0
3 | fire==0.4.0
4 | scikit-learn==0.24.0
5 | pandas==1.1.5
6 | typer==0.6.1
7 | gdown==4.5.1
8 | tqdm==4.64.0
9 | pyarrow
10 |
--------------------------------------------------------------------------------
/run_evaluate.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Expose the evaluation data at ./data, then run the evaluator and mirror its
# combined stdout/stderr into the log file.
#
# Inputs (environment variables): eval_path, log_path.

# Make the script's exit status reflect main.py instead of tee.
set -o pipefail

# Quote the paths so values containing spaces survive word splitting.
# -n replaces an existing ./data symlink instead of creating a new link
# inside the directory that symlink points to.
ln -sfn "${eval_path}" ./data
python3 main.py 2>&1 | tee "${log_path}"
5 |
--------------------------------------------------------------------------------
/selection.py:
--------------------------------------------------------------------------------
1 | """Selection script
2 | Replace this file with your own logic"""
3 | import pandas as pd
4 | import numpy as np
5 | import shutil
6 |
7 |
class Predictor:
    """Baseline selector: pick embeddings closest to a labeled-set centroid.

    Loads a parquet file with ``ImageID`` and ``embedding`` columns, averages
    the embeddings of the images listed in an input CSV, then samples
    candidates at random and keeps the ones nearest that centroid.
    """

    def __init__(self, embeddings_path: str):
        print("Loading embeddings")
        self.embeddings = self.read_parquet(embeddings_path)
        # Row i of this matrix is row i of self.embeddings, so a row index here
        # is also a positional (iloc) index into the DataFrame.
        self.embeddings_vect = np.stack(self.embeddings["embedding"].to_numpy())
        print("Embeddings loaded")

    def read_parquet(self, path: str):
        """Read a parquet file into a pandas DataFrame using pyarrow."""
        return pd.read_parquet(path, engine="pyarrow", use_threads=True)

    def similarity(self, centroid, n_closest, n_random):
        """Return positions of the ``n_closest`` sampled embeddings nearest ``centroid``.

        ``n_random`` candidates are drawn without replacement (capped at the
        number of available embeddings). The returned indices point into
        ``self.embeddings_vect`` / ``self.embeddings`` and are ordered by
        increasing Euclidean distance to ``centroid``.
        """
        n_total = len(self.embeddings_vect)
        # np.random.choice(replace=False) raises if asked for more than n_total.
        sample_size = min(n_random, n_total)
        random_indexes = np.random.choice(n_total, size=sample_size, replace=False)
        distances = np.linalg.norm(
            self.embeddings_vect[random_indexes] - centroid, axis=1
        )
        # argsort ranks positions *within the sample*; map them back to
        # positions in the full embedding table before returning.
        return random_indexes[np.argsort(distances)[:n_closest]]

    def get_embeddings_labeled_data(self, input_path):
        """Return the embedding rows whose ImageID appears in the CSV at input_path."""
        # NOTE(review): every listed image is used, whatever its Confidence
        # value (e.g. random_500.csv contains both 0 and 1) — confirm intent.
        df = pd.read_csv(input_path)
        return self.embeddings.loc[self.embeddings["ImageID"].isin(df["ImageID"])]

    def calculate_centroid(self, df):
        """Return the element-wise mean of the ``embedding`` column of ``df``.

        Raises:
            ValueError: if ``df`` has no rows, since there is nothing to average.
        """
        if len(df) == 0:
            raise ValueError(
                "Cannot compute a centroid from an empty set of embeddings"
            )
        vectors = np.stack([np.asarray(vec) for vec in df["embedding"]])
        return vectors.mean(axis=0)

    def closest_and_furthest(
        self, input_path, output_path, n_closest=100, n_random=1000
    ):
        """Write a selection CSV of images near the labeled-set centroid.

        Despite the name, only the closest sampled images are selected. The
        output CSV has columns ``ImageID`` and ``Confidence`` (always 1).
        """
        df = self.get_embeddings_labeled_data(input_path)
        print("Embeddings loaded")
        print("Getting centroids")
        centroid = self.calculate_centroid(df)
        print("Running similairty")
        similarity = self.similarity(centroid, n_closest, n_random)
        print("Saving submission")
        submission = pd.DataFrame(self.embeddings.iloc[similarity]["ImageID"])
        submission["Confidence"] = 1
        submission.to_csv(output_path, index=False)
48 |
--------------------------------------------------------------------------------
/task_setup.yaml:
--------------------------------------------------------------------------------
1 | # Memory used by Spark driver (recommend > 6g)
2 | spark_driver_memory: '6g'
3 |
4 | # Path for data directory. Note that all other paths in this setup
5 | # file will be relative to this one (e.g. if data_dir='some/path'
6 | # and emb_file='to/emb.parquet', the embeddings will be loaded from
7 | # 'some/path/to/emb.parquet')
8 | data_dir: 'data'
9 |
10 | # Relative path for embedding file
11 | emb_file: 'embeddings/train_emb_256_dataperf.parquet'
12 |
13 | # Paths to training and test files for each task. Task names must match names
14 | # in eval_tasks (defined further below). Format:
15 | # task_name: ['relative/path/to/training_set.csv', 'relative/path/to/test_set.csv']
16 | Cupcake: ['train_sets/Cupcake.csv', 'test_sets/alpha_test_set_Cupcake_256.parquet']
17 | Hawk: ['train_sets/Hawk.csv', 'test_sets/alpha_test_set_Hawk_256.parquet']
18 | Sushi: ['train_sets/Sushi.csv', 'test_sets/alpha_test_set_Sushi_256.parquet']
19 |
20 | # Relative path for results file
21 | results_dir: 'results'
22 |
23 | # Parameters specific to alpha below (likely not useful to modify)
24 |
25 | # Embedding dimensionality
26 | dim: 256
27 |
28 | # List the names of tasks to evaluate
29 | eval_tasks: ['Cupcake','Hawk','Sushi']
30 |
31 | # Path for data directory in docker image.
32 | # DO NOT MODIFY
33 | docker_data_dir: 'data'
34 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from datetime import datetime
4 |
5 | import yaml
6 | from pyspark.sql import SparkSession, DataFrame
7 |
8 | import constants as c
9 |
10 |
def get_spark_session(spark_driver_memory: str) -> SparkSession:
    """Return a SparkSession configured with the given driver memory (e.g. '6g')."""
    builder = SparkSession.builder.config(
        'spark.driver.memory', spark_driver_memory)
    return builder.getOrCreate()
15 |
16 |
def get_emb_dim(df: DataFrame) -> int:
    """Return the length of the embedding stored in the first row of ``df``."""
    first_row = df.select(c.EMB_COL).take(1)[0]
    return len(first_row[0])
19 |
20 |
def load_yaml(path: str) -> dict:
    """Parse the YAML file at ``path`` with ``yaml.safe_load``.

    Raises:
        yaml.YAMLError: if the file is not valid YAML. The parse error is
            printed and then re-raised, rather than the function failing
            later on an unassigned result variable.
    """
    with open(path, 'r') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise
28 |
29 |
def load_emb_df(ss: SparkSession, path: str, dim: int) -> DataFrame:
    """Read the embedding parquet at ``path`` and check its columns and dim."""
    emb_df = ss.read.parquet(path)

    present = emb_df.columns
    for required in (c.EMB_COL, c.ID_COL):
        assert required in present, \
            f'Embedding file does not have "{required}" column'

    found_dim = get_emb_dim(emb_df)
    assert found_dim == dim, \
        f'Embedding file dim={found_dim}, but setup file specifies dim={dim}'

    return emb_df
42 |
43 |
def load_train_df(ss: SparkSession, path: str) -> DataFrame:
    """Read a training CSV (with header row) and check its label and id columns."""
    train_df = ss.read.csv(path, header=True)

    present = train_df.columns
    for required in (c.LABEL_COL, c.ID_COL):
        assert required in present, \
            f'{path}: Train file does not have "{required}" column'

    return train_df
52 |
53 |
def add_emb_col(df: DataFrame, emb_df: DataFrame) -> DataFrame:
    """Join ``df`` with the id and embedding columns of ``emb_df`` on the id column."""
    id_and_emb = emb_df.select(c.ID_COL, c.EMB_COL)
    return df.join(id_and_emb, on=c.ID_COL)
57 |
58 |
def load_test_df(ss: SparkSession, path: str, dim: int) -> DataFrame:
    """Read a test parquet file and validate its columns and embedding dim.

    The file must contain the label, id and embedding columns (scoring reads
    the label and embedding columns), and its embeddings must have length
    ``dim``.
    """
    df = ss.read.parquet(path)

    # The embedding column is checked here too: otherwise get_emb_dim below
    # would fail inside Spark with a less helpful error when it is missing.
    for col in [c.LABEL_COL, c.ID_COL, c.EMB_COL]:
        assert col in df.columns, \
            f'{path}: Test file does not have "{col}" column'

    actual_dim = get_emb_dim(df)
    assert actual_dim == dim, \
        f'Test file dim={actual_dim}, but setup file specifies dim={dim}'

    return df
71 |
72 |
def save_results(data: dict, save_dir: str, verbose=False) -> None:
    """Write ``data`` as indented JSON to a UTC-timestamped file in ``save_dir``.

    ``save_dir`` is created if it does not exist yet. With ``verbose``, the
    output path and the JSON content are also printed.
    """
    # Local import: the module itself only imports the `datetime` class.
    from datetime import timezone

    # Timezone-aware "now" replaces the deprecated naive datetime.utcnow();
    # the formatted timestamp is identical.
    dt = datetime.now(timezone.utc).strftime("UTC-%Y-%m-%d-%H-%M-%S")
    filename = f'{c.RESULT_FILE_PREFIX}_{dt}.json'
    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, filename)
    with open(path, 'w') as f:
        json.dump(data, f, indent=4)

    if verbose:
        print(f'Results saved in {path}')
        print(json.dumps(data, indent=4))
83 |
--------------------------------------------------------------------------------
/workspace/parameters.yaml:
--------------------------------------------------------------------------------
1 | dataset_url: "https://drive.google.com/drive/folders/181uI-7NFJwK3IOPy2kOYVIQS4vZOC02A?usp=sharing"
2 | n_random: 500
3 | n_closest: 50
--------------------------------------------------------------------------------