├── benchmark ├── __init__.py ├── data │ ├── __init__.py │ ├── labels.py │ └── embeddings.py ├── evaluation │ ├── __init__.py │ ├── utils.py │ ├── metrics.py │ ├── evaluation.py │ ├── results.py │ ├── visualisations.py │ └── linear_probing.py ├── config.yaml └── main.py ├── examples ├── data │ ├── __init__.py │ ├── submission_utils.py │ └── dataset.py ├── S2_dino_embeddings.py └── baseline_compression_mean.ipynb ├── .gitattributes ├── assets └── NeuCoBench.png ├── data ├── baseline_mean_embeddings_eval.csv ├── baseline_random_embeddings_eval.csv ├── AI4G_intern_squad_eval_final_submission.csv ├── KTH_and_Friends_eval_final_submission.csv ├── 404_Embedding_Not_Found_eval_final_submission.csv └── README.md ├── requirements.txt ├── CITATION.cff ├── .github └── CONTRIBUTING.md ├── .gitignore ├── README.md └── LICENSE /benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmark/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmark/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | data/** filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /assets/NeuCoBench.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/embed2scale/NeuCo-Bench/HEAD/assets/NeuCoBench.png -------------------------------------------------------------------------------- /data/baseline_mean_embeddings_eval.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e705368cc9f1a8a73635165f17619eacd56f971bfedcc918fe04068360ef7651 3 | size 90140926 4 | -------------------------------------------------------------------------------- /data/baseline_random_embeddings_eval.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5e7ba9f40846d8675a4f33d0986b847b55dedb2c2cad5c316e417dfcd8531a59 3 | size 163582298 4 | -------------------------------------------------------------------------------- /data/AI4G_intern_squad_eval_final_submission.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c78f9c4f7b794d7da89a315d9a2b1d35b03bb6c8742b6cf52d6a2066e6cfd6f2 3 | size 91817778 4 | -------------------------------------------------------------------------------- /data/KTH_and_Friends_eval_final_submission.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3e800b9e5ca3755a00a934e7262e96b967a66849fe4cf9e7bb8a37dad6caf463 3 | size 61426375 4 | -------------------------------------------------------------------------------- 
/data/404_Embedding_Not_Found_eval_final_submission.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:df16cd87a96fa550ff647b34a8949388bec07d6b270c8e5ef941725133a1046f 3 | size 133028192 4 | -------------------------------------------------------------------------------- /benchmark/config.yaml: -------------------------------------------------------------------------------- 1 | embedding_dim: 1024 2 | batch_size: 64 3 | epochs: 20 4 | learning_rate: 0.001 5 | k_folds: 40 # Eval: 200 6 | standardize_embeddings: true 7 | normalize_labels: true 8 | enable_plots: false 9 | update_leaderboard: false 10 | task_filter: false #eg. ["biomass_mean", "biomass_std"] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.10.1 2 | numcodecs==0.15.1 3 | numpy==2.2.5 4 | pandas==2.2.3 5 | PyYAML==6.0.2 6 | scikit_learn==1.6.1 7 | scipy==1.15.2 8 | timm==1.0.15 9 | torch==2.6.0 10 | torchgeo==0.7.0 11 | torchmetrics==1.7.1 12 | torchvision==0.21.0 13 | tqdm==4.67.1 14 | xarray==2024.3.0 15 | zarr==2.18.0 -------------------------------------------------------------------------------- /benchmark/evaluation/utils.py: -------------------------------------------------------------------------------- 1 | import random, numpy as np, torch 2 | 3 | def fix_all_seeds(seed: int = 42): 4 | """ 5 | Fixes all relevant random seeds to ensure reproducible results. 6 | 7 | This sets seeds for Python's built-in `random` module, NumPy, and PyTorch, 8 | and enforces deterministic behavior in PyTorch operations. 9 | 10 | Args: 11 | seed (int): The seed value to use. Default is 42. 12 | """ 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.use_deterministic_algorithms(True) -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: Wittmann 5 | given-names: Isabelle 6 | orcid: https://orcid.org/0009-0005-2137-6167 7 | - family-names: Vinge 8 | given-names: Rikard 9 | orcid: https://orcid.org/0000-0002-7306-3403 10 | - family-names: Albrecht 11 | given-names: Conrad M. 12 | orcid: https://orcid.org/0009-0009-2422-7289 13 | - family-names: Schneider 14 | given-names: Jannik 15 | title: "NeuCo-Bench" 16 | version: 1.0 17 | date-released: 2025-05-12 18 | url: https://github.com/embed2scale/benchmark -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # NeuCo-Bench Data Directory 2 | 3 | This folder is managed by Git LFS (cf. `.gitattributes`). 4 | 5 | ## Assets 6 | 7 | ### 2025 CVPR EARTHVISION data challenge 8 | 9 | cf. 
https://github.com/DLR-MF-DAS/embed2scale-challenge-supplement 10 | 11 | - `404_Embedding_Not_Found_eval_final_submission.csv`: embeddings of winning solution by Microsoft 12 | - `KTH_and_Friends_eval_final_submission.csv`: embeddings of winning solution by KTH: https://github.com/KerekesDavid/embed2scale-solution 13 | - `AI4G_intern_squad_eval_final_submission.csv`: no-training solution by Wherobots/UniBW/Microsoft/Colorado U: https://github.com/isaaccorley/temporal-mosaiks 14 | - `baseline_mean_embeddings_eval.csv`/`baseline_random_embeddings_eval.csv`: E2S baselines, cf. https://github.com/DLR-MF-DAS/embed2scale-challenge-supplement/blob/main/data_loading_submission_demo/baseline_compression_mean.ipynb 15 | -------------------------------------------------------------------------------- /benchmark/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | import torch 4 | from sklearn.metrics import ( 5 | accuracy_score, 6 | confusion_matrix, 7 | f1_score, 8 | mean_absolute_error, 9 | mean_squared_error, 10 | precision_score, 11 | recall_score, 12 | r2_score, 13 | roc_auc_score, 14 | ) 15 | 16 | def classification_metrics( 17 | y_pred: torch.Tensor, 18 | y_true: List[int], 19 | ) -> Dict[str, Any]: 20 | """Compute binary classification metrics from raw prediction scores (thresholded at 0).""" 21 | y_pred_cls = (y_pred >= 0).long() 22 | 23 | metrics: Dict[str, Any] = {} 24 | metrics["f1"] = f1_score(y_true, y_pred_cls, average='binary') 25 | metrics["accuracy"] = accuracy_score(y_true, y_pred_cls) 26 | metrics["precision"] = precision_score( 27 | y_true, y_pred_cls, average='binary', zero_division=0 28 | ) 29 | metrics["recall"] = recall_score( 30 | y_true, y_pred_cls, average='binary', zero_division=0 31 | ) 32 | metrics["confusion_matrix"] = confusion_matrix(y_true, y_pred_cls).tolist() 33 | metrics["roc_auc"] = roc_auc_score(y_true, y_pred) 34 | 35 | return metrics 36 | 37 | 38 | def regression_metrics( 39 | y_pred: List[float], 40 | y_true: List[float], 41 | ) -> Dict[str, float]: 42 | """Compute standard regression metrics.""" 43 | metrics = { 44 | "r2": r2_score(y_true, y_pred), 45 | "mse": mean_squared_error(y_true, y_pred), 46 | "mae": mean_absolute_error(y_true, y_pred), 47 | } 48 | return metrics 49 | -------------------------------------------------------------------------------- /benchmark/data/labels.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import logging 3 | from pathlib import Path 4 | from typing import Union 5 | import pandas as pd 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | def get_annotations(folder_path: Union[str, Path]) -> pd.DataFrame: 10 | """ 11 | Load annotation entries from all CSV files in a folder into a single DataFrame. 12 | 13 | Each CSV file should be named `<task_name>__<task_type>.csv` and must contain columns 'id' and 'label'. 14 | 15 | Args: 16 | folder_path: Path to the directory containing CSV annotation files. 17 | 18 | Returns: 19 | DataFrame with columns ['id', 'label', 'task_name', 'task_type']. 
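Example: a hypothetical file `cloud_fraction__regression.csv` would yield rows with task_name='cloud_fraction' and task_type='regression'. 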
20 |     """ 21 |     folder = Path(folder_path) 22 |     if not folder.is_dir(): 23 |         raise ValueError(f"Provided path is not a directory: {folder}") 24 | 25 |     entries = [] 26 |     sorted_out = 0 27 | 28 |     for csv_path in folder.glob("*.csv"): 29 |         task_name, task_type = csv_path.stem.split("__", 1) 30 |         n_before = len(entries)  # remember count so the per-file log below is not cumulative 31 |         with csv_path.open(newline='') as csvfile: 32 |             reader = csv.DictReader(csvfile) 33 |             for row in reader: 34 |                 label = row.get('label') 35 |                 if label is None or label == "": 36 |                     sorted_out += 1 37 |                     continue 38 |                 entries.append({ 39 |                     'id': row['id'], 40 |                     'label': float(label), 41 |                     'task_name': task_name, 42 |                     'task_type': task_type, 43 |                 }) 44 |         logger.info("Processed %s: %d valid entries", csv_path.name, len(entries) - n_before) 45 | 46 |     if sorted_out: 47 |         logger.warning("Skipped %d rows due to missing labels", sorted_out) 48 | 49 |     df = pd.DataFrame(entries) 50 |     logger.info("Loaded total of %d annotation entries", len(df)) 51 |     return df 52 | 53 | __all__ = ['get_annotations'] 54 | -------------------------------------------------------------------------------- /benchmark/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from pathlib import Path 4 | import yaml 5 | 6 | from evaluation.evaluation import evaluate 7 | from evaluation.results import summarize_runs 8 | 9 | logging.basicConfig(level=logging.INFO) 10 | logger = logging.getLogger(__name__) 11 | 12 | def parse_args(): 13 |     parser = argparse.ArgumentParser() 14 |     parser.add_argument("--submission_file", type=Path, required=True, help='File containing compressed embeddings to evaluate.') 15 |     parser.add_argument("--exclude_file", type=Path, default=None, required=False, help='File listing embedding IDs to exclude from evaluation (e.g., corrupted samples).') 16 |     parser.add_argument("--annotation_path", type=Path, required=True, help='Folder containing CSV label files per downstream task.') 17 |     parser.add_argument("--config", type=Path, default="config.yaml", help='YAML file with cross-validation settings and logging options. See provided sample config.') 18 |     parser.add_argument("--method_name", type=str, required=True, help='Identifier for your compression method, used to tag outputs and leaderboards.') 19 |     parser.add_argument("--output_dir", type=Path, default=Path("results/"), help='Directory to save per-task reports, plots, and aggregated results.') 20 |     parser.add_argument("--phase", type=str, default="all", help='A label (e.g., “dev”, “eval”) defining a particular benchmark setup. 
Results for each phase are stored in a separate subfolder under output_dir.') 21 | return parser.parse_args() 22 | 23 | def main(): 24 | args = parse_args() 25 | with args.config.open() as f: 26 | config = yaml.safe_load(f) 27 | 28 | evaluate( 29 | submission_file=args.submission_file, 30 | exclude_file=args.exclude_file, 31 | annotation_path=args.annotation_path, 32 | method_name=args.method_name, 33 | output_dir=args.output_dir, 34 | phase=args.phase, 35 | config=config, 36 | ) 37 | 38 | if config["update_leaderboard"]: 39 | summarize_runs(output_dir=args.output_dir, phase=args.phase) 40 | 41 | if __name__ == "__main__": 42 | main() 43 | -------------------------------------------------------------------------------- /examples/data/submission_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | See: https://github.com/DLR-MF-DAS/embed2scale-challenge-supplement/blob/main/data_loading_submission_demo/demo_load_create_submission.ipynb 3 | """ 4 | import pandas as pd 5 | 6 | 7 | def create_submission_from_dict(emb_dict): 8 | """ 9 | Assume dictionary has format: 10 | {hash-id0: embedding0, hash-id1: embedding1, ...} 11 | """ 12 | df_submission = pd.DataFrame.from_dict(emb_dict, orient='index') 13 | 14 | # Reset index with name 'id' 15 | df_submission.index.name = 'id' 16 | df_submission.reset_index(drop=False, inplace=True) 17 | 18 | return df_submission 19 | 20 | 21 | def test_submission( 22 | path_to_submission: str, 23 | expected_embedding_ids: set, 24 | embedding_dim: int = 1024 25 | ) -> bool: 26 | # Load data 27 | df = pd.read_csv(path_to_submission, header=0) 28 | 29 | # Verify that 'id' is in columns 30 | if 'id' not in df.columns: 31 | raise ValueError("Submission file must contain column 'id'.") 32 | 33 | # Temporarily set index to 'id' 34 | df.set_index('id', inplace=True) 35 | 36 | # Check that all samples are included 37 | submitted_embeddings = set(df.index) 38 | missing = expected_embedding_ids.difference(submitted_embeddings) 39 | if missing: 40 | n_missing = len(missing) 41 | raise ValueError(f"Submission is missing {n_missing} embeddings.") 42 | 43 | # Check that embeddings have the correct length 44 | if df.shape[1] != embedding_dim: 45 | raise ValueError( 46 | f"{embedding_dim} embedding dimensions expected, " 47 | f"but provided embeddings have {df.shape[1]} dimensions." 48 | ) 49 | 50 | # Convert columns to float 51 | try: 52 | for col in df.columns: 53 | df[col] = df[col].astype(float) 54 | except Exception as e: 55 | raise ValueError( 56 | "Failed to convert embedding values to float. " 57 | "Check for invalid characters (e.g., empty strings, letters). 
" 58 | f"Original error: {e}" 59 | ) 60 | 61 | # Check for any NaNs 62 | if df.isna().any().any(): 63 | raise ValueError("Embeddings contain NaN values.") 64 | 65 | return True -------------------------------------------------------------------------------- /benchmark/evaluation/evaluation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | from datetime import datetime 5 | from pathlib import Path 6 | import torch 7 | 8 | from data.embeddings import load_submission 9 | from data.labels import get_annotations 10 | from evaluation.linear_probing import cross_validate 11 | from evaluation.results import save_results, summarize_runs 12 | from evaluation.utils import fix_all_seeds 13 | 14 | # Logging setup 15 | logging.basicConfig(level=logging.INFO) 16 | logger = logging.getLogger("benchmarking") 17 | 18 | def evaluate(submission_file: Path, 19 | annotation_path: Path, 20 | method_name: str, 21 | output_dir: Path, 22 | phase: str, 23 | config: dict, 24 | exclude_file: str = None) -> None: 25 | """Evaluate an set of embeddings on one or multiple tasks. 26 | Results are written to output_dir folder. 27 | 28 | Args: 29 | submission_file: Path to the submission file containing embeddings. 30 | annotation_path: Path to folder containing label files 31 | method_name: Name of experiment. Used to distinguish methods compared in the same experiment. 32 | output_dir: Path to folder to which results are written. 33 | phase: Name of phase. Used to distinguish experiments. 34 | config: Dictionary containing evaluation configurations. 35 | exclude_file: Path to file containing embedding IDs to exclude. If not provided, exclude no embeddings. 36 | """ 37 | fix_all_seeds(seed=42) 38 | device = torch.device(config.get("device", "cuda" if torch.cuda.is_available() else "cpu")) 39 | 40 | # Determine device 41 | device = ( 42 | torch.device("cuda" if torch.cuda.is_available() else "cpu") 43 | if config.get("device", "auto") == "auto" 44 | else torch.device(config["device"]) 45 | ) 46 | 47 | # Load data 48 | annotation_df = get_annotations(annotation_path) 49 | submission_df = load_submission( 50 | file_path=submission_file, 51 | valid_ids=set(annotation_df['id']), 52 | expected_dim=config['embedding_dim'], 53 | exclude_file=exclude_file, 54 | standardize=config['standardize_embeddings'] 55 | ) 56 | 57 | merged_df = annotation_df.merge(submission_df, on="id").dropna(subset=["embedding"]) 58 | 59 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 60 | experiment_name = f"{method_name}_{timestamp}" 61 | run_dir = output_dir / phase / experiment_name 62 | run_dir.mkdir(parents=True, exist_ok=True) 63 | 64 | task_results = {} 65 | task_filter = config.get("task_filter") 66 | 67 | for task_name, group in merged_df.groupby("task_name"): 68 | if task_filter is not False and task_name not in task_filter: 69 | logger.info("Skipping task %s due to filter", task_name) 70 | continue 71 | 72 | task_type = group["task_type"].iloc[0] 73 | 74 | if config.get("normalize_labels", True): 75 | group = group.copy() 76 | group["label"] = (group["label"] - group["label"].min()) / (group["label"].max() - group["label"].min()) 77 | 78 | logger.info("Evaluating %s (%s)", task_name, task_type) 79 | 80 | result = cross_validate( 81 | df=group, 82 | task_type=task_type, 83 | task_name=task_name, 84 | device=device, 85 | batch_size=config["batch_size"], 86 | n_splits=config["k_folds"], 87 | epochs=config["epochs"], 88 | 
embedding_dim=config["embedding_dim"], 89 |             learning_rate=config["learning_rate"], 90 |             output_dir=run_dir, 91 |             filename_prefix=submission_file.stem, 92 |             enable_plots=config.get("enable_plots", True), 93 |             output_fold_results=config.get("output_fold_results", False), 94 |         ) 95 | 96 |         task_results[task_name] = result.q_statistic 97 | 98 |     save_results(experiment_name=experiment_name, task_results=task_results, output_dir=run_dir, config=config) 99 |     logger.info("Finished evaluation.") 100 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Thank you for considering contributing to this project! Your efforts help to strengthen the open-source community. We welcome all forms of contributions, including but not limited to the following: 4 | 5 | - Introduction of new downstream tasks and data 6 | - Introduction of new evaluation methods 7 | - Running data challenges 8 | - Documentation updates, bug fixes, and general code improvements 9 | 10 | ## Workflow 11 | 12 | 0. For significant modifications or bug reports, please consider opening an issue for discussion beforehand. 13 | 1. Fork and pull the latest repository (click the **Fork** button on GitHub). 14 | 2. Clone your fork: 15 | ```sh 16 | git clone https://github.com/your-username/NeuCo-Bench.git 17 | ``` 18 | 3. Navigate into the project directory 19 | ```sh 20 | cd NeuCo-Bench 21 | ``` 22 | 4. Add the upstream repository 23 | ```sh 24 | git remote add upstream https://github.com/embed2scale/NeuCo-Bench.git 25 | ``` 26 | 5. Create a local branch 27 | ```sh 28 | git checkout -b feature-branch 29 | ``` 30 | and use a descriptive branch name related to your changes, e.g. `feature/building-count-downstream-task`. 31 | 6. Make your changes following the [PEP 8](https://peps.python.org/pep-0008) coding standard. 32 | 7. Write clear and concise commit messages: 33 | ```sh 34 | git commit -m 'Short description of changes' -m 'and more details' 35 | ``` 36 | 8. Push your branch to your fork 37 | ```sh 38 | git push origin feature-branch 39 | ``` 40 | 9. Go to your fork on GitHub and click **Compare & pull request**. Provide a detailed explanation of your changes and link to relevant issues. 41 | 42 | 43 | # Code of Conduct 44 | As a contributor, you are expected to follow the [Code of Conduct as specified by the Linux Foundation](https://docs.linuxfoundation.org/lfx/mentorship/mentor-guide/code-of-conduct). 45 | 46 | ## License 47 | By contributing, you agree that your work will be licensed under [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0.html). 48 | 49 | ## Need Help? 50 | If you have any questions, feel free to open an issue or contact the project maintainers. 51 | 52 | 53 | # Code structure 54 | 55 | The main code for NeuCo-Bench resides in `benchmark`. 56 | 57 | ## data 58 | 59 | In `data`, basic modules for loading and processing downstream data and labels are located. 60 | 61 | 1. `embeddings.py` loads and processes files containing embeddings. 62 | 2. `labels.py` loads and processes label files. 63 | 64 | ## evaluation 65 | 66 | `evaluation` contains code for evaluating embeddings on downstream tasks. 67 | 68 | 1. `evaluation.py` contains the main evaluation function. 69 | 2. `linear_probing.py` contains code for repeatedly training linear probes to evaluate the embeddings on downstream tasks. 70 | 3. `metrics.py` contains metrics for various types of downstream tasks. 71 | 4. 
`results.py` aggregates results over multiple downstream tasks and creates a leaderboard. 72 | 5. `visualisations.py` contains code to visualize intermediate results during training as well as overall results from evaluation. 73 | 74 | ## examples 75 | 76 | `examples` contains examples of embedding creation, evaluation demos, and more. 77 | 78 | 79 | # Adding new features 80 | 81 | ## Adding new downstream data 82 | 83 | NeuCo-Bench is designed to be data-agnostic. There are two main methods for adding new downstream tasks. 84 | 85 | 1. **Create your own dataset** either locally or on data repositories, e.g. Hugging Face. Use the same structure as the [SSL4EO-S12-downstream](https://huggingface.co/datasets/embed2scale/SSL4EO-S12-downstream) dataset, i.e. one folder `data` containing your data to embed (possibly in subfolders) and one folder `labels` containing one file per task that maps the data (`id`) to the target (`label`). If you use a new folder structure or data types, this may require implementing new data-loading functionality in this code base. 86 | 2. **Extend the [SSL4EO-S12-downstream](https://huggingface.co/datasets/embed2scale/SSL4EO-S12-downstream) dataset** by creating a GitHub issue on this repo, or contacting the admins of the [embed2scale](https://huggingface.co/embed2scale) Hugging Face organization. 87 | 88 | ## Adding new evaluation methods 89 | 90 | Please see the **Contributing** section above. 91 | -------------------------------------------------------------------------------- /benchmark/evaluation/results.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from pathlib import Path 4 | from datetime import datetime 5 | import pandas as pd 6 | from scipy.stats import rankdata 7 | 8 | def save_results(experiment_name, task_results, output_dir: Path, config: dict = None): 9 |     """Save raw results and an optional config snapshot.""" 10 | 11 |     # Cast output_dir to Path 12 |     if not isinstance(output_dir, Path): 13 |         output_dir = Path(output_dir) 14 | 15 |     result_path = output_dir / "results_summary.json" 16 | 17 |     metadata = { 18 |         "experiment": experiment_name, 19 |         "overall_score": np.mean([v for v in task_results.values()]), 20 |         "task_results": task_results, 21 |     } 22 |     if config: 23 |         metadata["config"] = config 24 | 25 |     with open(result_path, "w") as f: 26 |         json.dump(metadata, f, indent=2) 27 | 28 | def aggregate_results(output_dir: Path, phase: str) -> pd.DataFrame: 29 |     """Aggregate all runs under a given phase into a DataFrame.""" 30 |     # Cast output_dir to Path 31 |     if not isinstance(output_dir, Path): 32 |         output_dir = Path(output_dir) 33 | 34 |     phase_path = output_dir / phase 35 |     if not phase_path.exists(): 36 |         raise FileNotFoundError(f"No results found under {phase_path}") 37 | 38 |     rows = [] 39 |     for exp_dir in phase_path.iterdir(): 40 |         for json_file in exp_dir.glob("results_summary.json"): 41 |             with open(json_file) as f: 42 |                 data = json.load(f) 43 |             row = { 44 |                 "experiment": data["experiment"], 45 |             } 46 |             row.update(data["task_results"]) 47 |             rows.append(row) 48 | 49 |     return pd.DataFrame(rows) 50 | 51 | def compute_leaderboard(df: pd.DataFrame, metric_columns: list[str]) -> pd.DataFrame: 52 |     """ 53 |     1. Exclude any runs that have NaN for any metric. 54 |     2. Compute mean score across metrics. 55 |     3. For each metric, rank experiments (highest→1). 56 |     4. Compute weights = stddev per metric / sum(stddevs). 57 |     5. Weighted_score = sum(rank * weight). 58 |     6. 
aggregated_rank = rank(weighted_score). 59 | """ 60 | df = df.copy() 61 | 62 | complete_mask = df[metric_columns].notna().all(axis=1) 63 | df = df.loc[complete_mask].reset_index(drop=True) 64 | 65 | if df.empty: 66 | return pd.DataFrame( 67 | columns=["experiment", "mean_score", "weighted_score", "aggregated_rank"] + metric_columns 68 | ) 69 | 70 | df["mean_score"] = df[metric_columns].mean(axis=1) 71 | 72 | # Per-metric ranks (highest better → negate) 73 | scores = df[metric_columns].to_numpy() 74 | per_task_ranks = rankdata(-scores, axis=0, method="min") 75 | 76 | # Metric weights from stddev 77 | if len(df) > 1: 78 | stds = np.std(scores, axis=0, ddof=1) 79 | weights = stds / stds.sum() 80 | else: 81 | weights = np.ones(len(metric_columns)) / len(metric_columns) 82 | 83 | df["weighted_score"] = (per_task_ranks * weights).sum(axis=1) 84 | df["aggregated_rank"] = rankdata(df["weighted_score"].values, method="min") 85 | 86 | # Sort best first 87 | cols = ["experiment", "mean_score", "weighted_score", "aggregated_rank"] + metric_columns 88 | return df.sort_values("weighted_score")[cols].reset_index(drop=True) 89 | 90 | def save_leaderboard(df: pd.DataFrame, output_dir: Path, phase: str): 91 | """Save leaderboard summary CSV (including mean_score, weighted_score, aggregated_rank).""" 92 | # Cast output_dir to Path 93 | if not isinstance(output_dir, Path): 94 | output_dir = Path(output_dir) 95 | 96 | leaderboard_path = output_dir / "leaderboards" / phase 97 | leaderboard_path.mkdir(parents=True, exist_ok=True) 98 | 99 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 100 | out_file = leaderboard_path / f"leaderboard_{timestamp}.csv" 101 | df.to_csv(out_file, index=False) 102 | print(f"Leaderboard CSV saved to {out_file}") 103 | 104 | def summarize_runs(output_dir: Path, phase: str): 105 | """Print and save leaderboard from all complete runs in a given phase.""" 106 | df = aggregate_results(output_dir, phase) 107 | if df.empty: 108 | print("No results to summarize.") 109 | return 110 | 111 | metric_columns = [col for col in df.columns if col not in {"experiment", "timestamp"}] 112 | leaderboard = compute_leaderboard(df, metric_columns) 113 | 114 | if leaderboard.empty: 115 | print("No complete runs (all have missing metrics); nothing to rank.") 116 | return 117 | 118 | print("\n=== Leaderboard Summary ===") 119 | display_cols = ["experiment", "mean_score", "weighted_score", "aggregated_rank"] + metric_columns 120 | print(leaderboard[display_cols].to_string(index=False)) 121 | 122 | save_leaderboard(leaderboard, output_dir, phase) 123 | -------------------------------------------------------------------------------- /examples/S2_dino_embeddings.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from typing import Callable, Tuple 4 | 5 | import timm 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | from torchvision import transforms 11 | from torchgeo.models import ViTSmall16_Weights 12 | from tqdm import tqdm 13 | 14 | from data.submission_utils import create_submission_from_dict, test_submission 15 | from data.dataset import ( 16 | E2SChallengeDataset, 17 | Normalize, 18 | TemporalMean, 19 | InputResizer, 20 | collate_fn, 21 | ) 22 | 23 | # Configurations 24 | MODALITIES = ["s2l1c"] # We are using the Dino ViT model pretrained on SSL4EO L1C data. 25 | INPUT_SIZE = 224 # Resize input images to the expected 224x224 pixels. 
26 | EMBEDDING_SIZE = 1024 # Set output embedding size, unused dimensions will be padded with zeros. 27 | METHOD = "avg" # How to reduce the per-patch features into a single vector; Options: ["avg", "cls", "max"]. 28 | DATA_PATH = Path("./data") # Path to the SSL4EO-S12-downstream image directory. 29 | OUTPUT_PATH = Path("./results.csv") # Path to the output CSV file. 30 | 31 | 32 | # Load model backbone from TorchGeo 33 | def load_torchgeo_model( 34 | model_name: str, weights_obj: ViTSmall16_Weights, device: torch.device 35 | ) -> Tuple[nn.Module, Callable]: 36 | """ 37 | Loads a Timm model with specified TorchGeo weights. 38 | """ 39 | in_chans = weights_obj.meta.get("in_chans") 40 | model = timm.create_model(model_name, in_chans=in_chans) 41 | state_dict = weights_obj.get_state_dict(progress=True) 42 | model.load_state_dict(state_dict, strict=False) 43 | model.eval() 44 | model.to(device) 45 | 46 | logging.info( 47 | "Loaded TorchGeo model '%s' with in_chans=%d", 48 | model_name, 49 | in_chans, 50 | ) 51 | extractor = getattr(model, "forward_features", model) 52 | return extractor 53 | 54 | 55 | # Main function to extract embeddings 56 | def main() -> None: 57 | logging.basicConfig( 58 | level=logging.INFO, 59 | format='%(asctime)s [%(levelname)s] %(message)s', 60 | ) 61 | logger = logging.getLogger(__name__) 62 | 63 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 64 | 65 | # Load model 66 | weights = ViTSmall16_Weights.SENTINEL2_ALL_DINO 67 | extractor = load_torchgeo_model('vit_small_patch16_224', weights, device) 68 | 69 | transform = transforms.Compose([ 70 | Normalize(), # scale to [0,1] 71 | TemporalMean() # average over 4 seasonal timesteps 72 | ]) 73 | 74 | dataset = E2SChallengeDataset( 75 | DATA_PATH, 76 | modalities=MODALITIES, 77 | seasons=4, 78 | dataset_name='bands', 79 | transform=transform, 80 | concat=True, 81 | output_file_name=True, 82 | ) 83 | 84 | logger.info(f'Dataset length: {len(dataset)} samples') 85 | sample = dataset[0]['data'] 86 | logger.info(f'Sample data shape: {sample.shape}') 87 | 88 | loader = DataLoader( 89 | dataset, 90 | batch_size=1, 91 | num_workers=0, 92 | pin_memory=True, 93 | collate_fn=collate_fn, 94 | ) 95 | 96 | embeddings = {} 97 | resizer = InputResizer(INPUT_SIZE).to(device) 98 | 99 | # Define how to reduce patch embedding to a single vector 100 | reducers = { 101 | 'avg': lambda f: f[:, 1:].mean(dim=1), 102 | 'cls': lambda f: f[:, 0], 103 | 'max': lambda f: f[:, 1:, :].max(dim=1)[0], 104 | } 105 | 106 | # Extract embeddings 107 | for batch in tqdm(loader, desc='Extracting embeddings'): 108 | data = batch['data'].squeeze(0).to(device) 109 | data = resizer(data) 110 | 111 | with torch.no_grad(): 112 | features = extractor(data) 113 | try: 114 | emb_compressed = reducers[METHOD](features) 115 | except KeyError: 116 | raise ValueError(f'Unknown METHOD {METHOD!r}') 117 | emb_flat = emb_compressed.flatten() 118 | 119 | # Pad with zeroes to fixed EMBEDDING_SIZE 120 | if emb_flat.shape[0] < EMBEDDING_SIZE: 121 | emb_flat = F.pad( 122 | emb_flat, 123 | (0, EMBEDDING_SIZE - emb_flat.numel()), 124 | ) 125 | 126 | embeddings[batch["file_name"][0]] = emb_flat.cpu().tolist() 127 | 128 | # Create and save embedding csv file 129 | submission_df = create_submission_from_dict(embeddings) 130 | logger.info(f"Number of embeddings: {len(submission_df)}") 131 | submission_df.to_csv(OUTPUT_PATH, index=False) 132 | logger.info(f"Saved embeddings to {OUTPUT_PATH}") 133 | 134 | # Validate output format 135 | ids = set(embeddings.keys()) 136 
| assert test_submission( 137 | OUTPUT_PATH, 138 | ids, 139 | EMBEDDING_SIZE, 140 | ) 141 | 142 | 143 | if __name__ == "__main__": 144 | main() -------------------------------------------------------------------------------- /benchmark/evaluation/visualisations.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence, Optional 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | plt.style.use("ggplot") 8 | 9 | def save_plot( 10 | fig: plt.Figure, 11 | folder: Path, 12 | filename: str, 13 | dpi: int = 300, 14 | ) -> None: 15 | # Ensure target directory exists 16 | folder.mkdir(parents=True, exist_ok=True) 17 | output_path = folder / filename 18 | # Save and close 19 | fig.savefig(output_path, dpi=dpi, bbox_inches="tight") 20 | plt.close(fig) 21 | 22 | def save_loss_curve( 23 | loss_curves: Sequence[Sequence[float]], 24 | output_folder: Path, 25 | task_name: str, 26 | loss_type: str = "train", 27 | ) -> None: 28 | """ 29 | Plot and save one or multiple loss curves (training and/or validation). 30 | 31 | Args: 32 | loss_curves: An iterable of loss sequences, e.g., [train_loss, val_loss]. 33 | output_folder: Directory where the plot will be saved. 34 | task_name: Identifier for the current task. 35 | loss_type: Descriptor for the loss type (e.g., 'train', 'val', or 'combined'). 36 | """ 37 | # Build target directory for loss plots 38 | target_folder = output_folder / task_name / "loss_curves" 39 | fig, ax = plt.subplots(figsize=(6, 4)) 40 | 41 | # Plot each loss curve 42 | for idx, curve in enumerate(loss_curves, start=1): 43 | ax.plot(curve, label=f"Curve {idx}") 44 | 45 | ax.set_title(f"Loss Curves - {task_name} ({loss_type})") 46 | ax.set_xlabel("Epoch") 47 | ax.set_ylabel("Loss") 48 | ax.grid(True) 49 | 50 | # Save figure 51 | filename = f"{task_name}_{loss_type}_loss_curves.png" 52 | save_plot(fig, target_folder, filename) 53 | 54 | def plot_regression_scatter( 55 | y_train_true: Sequence[float], 56 | y_train_pred: Sequence[float], 57 | y_val_true: Sequence[float], 58 | y_val_pred: Sequence[float], 59 | task_name: str, 60 | fold_idx: int, 61 | output_folder: Path, 62 | base_name: str, 63 | ) -> None: 64 | """ 65 | Generate and save a scatter plot comparing true vs. predicted values for regression. 66 | 67 | Args: 68 | y_train_true: True target values for the training set. 69 | y_train_pred: Predicted values for the training set. 70 | y_val_true: True target values for the validation set. 71 | y_val_pred: Predicted values for the validation set. 72 | task_name: Identifier for the current task. 73 | fold_idx: Index of the cross-validation fold (0-based). 74 | output_folder: Root folder for saving results. 75 | base_name: Base filename prefix. 
76 | """ 77 | # Build target directory for scatter plots 78 | target_folder = output_folder / task_name / "regression_scatter" 79 | fig, ax = plt.subplots(figsize=(6, 5)) 80 | 81 | # Plot training and validation points 82 | ax.scatter(y_train_true, y_train_pred, alpha=0.3, label="Train") 83 | ax.scatter(y_val_true, y_val_pred, alpha=0.6, label="Validation") 84 | 85 | # Draw diagonal identity line 86 | bounds = [min(min(y_train_true), min(y_val_true)), max(max(y_train_true), max(y_val_true))] 87 | ax.plot(bounds, bounds, linestyle='--', linewidth=3, color='green') 88 | 89 | ax.set_xlabel("True values") 90 | ax.set_ylabel("Predicted values") 91 | ax.set_title(f"Regression Scatter: {task_name}, Fold {fold_idx + 1}") 92 | ax.legend(loc="best") 93 | 94 | # Save figure 95 | filename = f"{base_name}_{task_name}_scatter_fold{fold_idx + 1}.png" 96 | save_plot(fig, target_folder, filename) 97 | 98 | 99 | def plot_confusion_matrix( 100 | cm_array: np.ndarray, 101 | task_name: str, 102 | fold_idx: int, 103 | output_folder: Path, 104 | base_name: str, 105 | ) -> None: 106 | """ 107 | Generate and save a confusion matrix plot for a classification task. 108 | 109 | Args: 110 | cm_array: Square confusion matrix as a NumPy array. 111 | task_name: Identifier for the current task. 112 | fold_idx: Index of the cross-validation fold (0-based). 113 | output_folder: Root folder for saving results. 114 | base_name: Base filename prefix. 115 | """ 116 | # Build target directory 117 | target_folder = output_folder / task_name / "confusion_matrices" 118 | fig, ax = plt.subplots(figsize=(6, 5)) 119 | 120 | # Display confusion matrix 121 | im = ax.imshow(cm_array, interpolation="nearest", cmap=plt.cm.Blues) 122 | ax.set_title(f"Confusion Matrix: {task_name}, Fold {fold_idx + 1}") 123 | fig.colorbar(im, ax=ax) 124 | 125 | # Set tick labels 126 | labels = np.arange(cm_array.shape[0]) 127 | ax.set_xticks(labels) 128 | ax.set_yticks(labels) 129 | 130 | # Annotate each cell 131 | thresh = cm_array.max() / 2.0 132 | for i in labels: 133 | for j in labels: 134 | color = 'white' if cm_array[i, j] > thresh else 'black' 135 | ax.text(j, i, f"{cm_array[i, j]}", ha='center', va='center', color=color) 136 | 137 | ax.set_xlabel("Predicted label") 138 | ax.set_ylabel("True label") 139 | 140 | # Save figure 141 | filename = f"{base_name}_{task_name}_cm_fold{fold_idx + 1}.png" 142 | save_plot(fig, target_folder, filename) 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/run_eval.sh 2 | run_eval.sh 3 | **/run_main.sh 4 | run_main.sh 5 | **/__pycache__/ 6 | **/*.pyc 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # UV 105 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | #uv.lock 109 | 110 | # poetry 111 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 112 | # This is especially recommended for binary packages to ensure reproducibility, and is more 113 | # commonly ignored for libraries. 114 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 115 | #poetry.lock 116 | #poetry.toml 117 | 118 | # pdm 119 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 120 | #pdm.lock 121 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 122 | # in version control. 123 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 124 | .pdm.toml 125 | .pdm-python 126 | .pdm-build/ 127 | 128 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 129 | __pypackages__/ 130 | 131 | # Celery stuff 132 | celerybeat-schedule 133 | celerybeat.pid 134 | 135 | # SageMath parsed files 136 | *.sage.py 137 | 138 | # Environments 139 | .env 140 | .venv 141 | env/ 142 | venv/ 143 | ENV/ 144 | env.bak/ 145 | venv.bak/ 146 | 147 | # Spyder project settings 148 | .spyderproject 149 | .spyproject 150 | 151 | # Rope project settings 152 | .ropeproject 153 | 154 | # mkdocs documentation 155 | /site 156 | 157 | # mypy 158 | .mypy_cache/ 159 | .dmypy.json 160 | dmypy.json 161 | 162 | # Pyre type checker 163 | .pyre/ 164 | 165 | # pytype static type analyzer 166 | .pytype/ 167 | 168 | # Cython debug symbols 169 | cython_debug/ 170 | 171 | # PyCharm 172 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 173 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 174 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 175 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 176 | #.idea/ 177 | 178 | # Abstra 179 | # Abstra is an AI-powered process automation framework. 180 | # Ignore directories containing user credentials, local state, and settings. 181 | # Learn more at https://abstra.io/docs 182 | .abstra/ 183 | 184 | # Visual Studio Code 185 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 186 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 187 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 188 | # you could uncomment the following to ignore the entire vscode folder 189 | # .vscode/ 190 | 191 | # Ruff stuff: 192 | .ruff_cache/ 193 | 194 | # PyPI configuration file 195 | .pypirc 196 | 197 | # Cursor 198 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 199 | # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data 200 | # refer to https://docs.cursor.com/context/ignore-files 201 | .cursorignore 202 | .cursorindexingignore 203 | -------------------------------------------------------------------------------- /benchmark/data/embeddings.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import re 4 | import ast 5 | from pathlib import Path 6 | from typing import Optional, Set 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | from torch.utils.data import Dataset 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | class EmbeddingDataset(Dataset): 16 |     """ 17 |     PyTorch Dataset wrapping embeddings and targets. 18 |     Expects a DataFrame with 'embedding' and 'label' columns. 19 |     """ 20 |     def __init__(self, df: pd.DataFrame) -> None: 21 |         """PyTorch Dataset wrapping embeddings and targets. 22 |         Expects a DataFrame with 'embedding' and 'label' columns. 23 | 24 |         Args: 25 |             df: Pandas dataframe with columns 'embedding' and 'label'. 26 |         """ 27 |         self.features = torch.stack( 28 |             [torch.tensor(e, dtype=torch.float32) for e in df['embedding']] 29 |         ) 30 |         self.targets = torch.tensor(df['label'].values, dtype=torch.float32) 31 | 32 |     def __len__(self) -> int: 33 |         return len(self.targets) 34 | 35 |     def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: 36 |         return self.features[idx], self.targets[idx] 37 | 38 | 39 | def unify_nans_in_embeddings(vec_str): 40 |     """Replace missing (NaN) values in string vec_str with 0.""" 41 |     vec_str = vec_str.replace("float('nan')", "'nan'").replace('float("nan")', "'nan'") 42 |     vec_str = re.sub(r'\b(nan)\b', "'nan'", vec_str, flags=re.IGNORECASE)  # quote bare nan tokens so ast.literal_eval can parse them 43 | 44 |     try: 45 |         vec = ast.literal_eval(vec_str) 46 |     except Exception as e: 47 |         raise ValueError(f"Error parsing embedding: {vec_str}, Error: {e}") 48 | 49 |     processed_vec = [] 50 |     for item in vec: 51 |         if isinstance(item, str) and item.lower() == "nan": 52 |             processed_vec.append(0.0) 53 |         else: 54 |             val = float(item) 55 |             processed_vec.append(0.0 if np.isnan(val) else val) 56 | 57 |     return processed_vec 58 | 59 | 60 | def process_embedding(embedding_str, embedding_dim): 61 |     """Parse an embedding string into a list of floats, enforcing that the vector is embedding_dim elements long and contains no missing (NaN) values. 62 | 63 |     Args: 64 |         embedding_str: Embedding in format as string "[val1, val2, ...]". 
65 | embedding_dim: Number of embedding dimensions. 66 | """ 67 | if pd.isna(embedding_str): 68 | return [0.0] * embedding_dim 69 | 70 | vec = unify_nans_in_embeddings(embedding_str) 71 | 72 | if len(vec) != embedding_dim: 73 | raise ValueError(f"Embedding dimension mismatch. Expected {embedding_dim}, got {len(vec)}") 74 | 75 | return vec 76 | 77 | 78 | def load_submission( 79 | file_path: Path, 80 | valid_ids: Set[str], 81 | expected_dim: int, 82 | exclude_file: Optional[Path] = None, 83 | standardize: bool = True, 84 | ) -> pd.DataFrame: 85 | """ 86 | Load and preprocess CSV of embeddings. 87 | 88 | - Filters by IDs in valid_ids and optional exclude list. 89 | - Standardizes embeddings (zero mean, unit variance). 90 | 91 | Returns a DataFrame with 'id' and 'embedding' columns. 92 | 93 | Args: 94 | file_path: Path to file containing embeddings. 95 | valid_ids: Set of valid embedding IDs. 96 | expected_dim: Expected dimension of embeddings, e.g. 1024. 97 | exclude_file: Path to file containing embeddings to exclude from processing. 98 | standardize: Boolean which controls whether embeddings are standardized (default). Standardization is done over the complete embedding_file, not per downstream task. 99 | Returns: 100 | Pandas Dataframe with columns 'id' and 'embedding'. 101 | Raises: 102 | ValueError if embedding_file does not contain column 'id'. 103 | ValueError if embedding_file contains missing (NaN) values. 104 | """ 105 | logger.info("Loading embeddings from %s", file_path) 106 | df = pd.read_csv(file_path) 107 | 108 | if 'id' not in df.columns: 109 | raise ValueError(f"""Submission file must contain column 'id'.""") 110 | 111 | df['id'] = df['id'].str.replace(".zarr.zip", "", regex=False) 112 | df.set_index('id', inplace=True) 113 | 114 | if exclude_file and exclude_file.exists(): 115 | bad_ids = {line.strip() for line in exclude_file.read_text().splitlines() if line.strip()} 116 | logger.info("Excluding %d corrupted IDs", len(bad_ids)) 117 | df = df.drop(index=bad_ids, errors='ignore') 118 | 119 | df = df.loc[df.index.intersection(valid_ids)] 120 | logger.info("Retained %d valid records", len(df)) 121 | 122 | processed_embeddings = [] 123 | for _, row in df.iterrows(): 124 | embedding = process_embedding(str(list(row)), expected_dim) 125 | processed_embeddings.append(embedding) 126 | 127 | embeddings_array = np.array(processed_embeddings) 128 | 129 | if np.isnan(embeddings_array).any(): 130 | raise ValueError("NaN values detected in processed embeddings.") 131 | 132 | if standardize: 133 | mu, sigma = embeddings_array.mean(), embeddings_array.std() 134 | sigma = sigma if sigma != 0 else 1 135 | logger.info("Standardizing embeddings (mean=%.4f, std=%.4f)", mu, sigma) 136 | embeddings_array = (embeddings_array - mu) / sigma 137 | 138 | result = pd.DataFrame({ 139 | 'id': df.index, 140 | 'embedding': list(embeddings_array) 141 | }).reset_index(drop=True) 142 | 143 | return result 144 | 145 | 146 | def parse_annotations(annotation_path: Path) -> pd.DataFrame: 147 | """ 148 | Load annotations JSON and flatten into DataFrame. 
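Expects top-level keys of the form '<task_name>__<task_type>', each mapping to a list of records with 'id' and 'label' fields, e.g. (hypothetical): {"cloud_fraction__regression": [{"id": "abc123", "label": 0.25}]}. 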
149 | """ 150 | raw = json.loads(annotation_path.read_text()) 151 | rows = [] 152 | for composite_key, items in raw.items(): 153 | task, kind = composite_key.split("__") 154 | for rec in items: 155 | rows.append({ 156 | 'id': rec['id'], 157 | 'label': float(rec['label']), 158 | 'task_name': task, 159 | 'task_type': kind, 160 | }) 161 | df = pd.DataFrame(rows) 162 | logger.info("Loaded %d annotation entries", len(df)) 163 | return df 164 | 165 | 166 | __all__ = ['EmbeddingDataset', 'load_submission'] 167 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NeuCo-Bench 2 | 3 | **softare licence**: Apache-2.0 4 | 5 | **TL;DR**: *Originally developed to evaluate challenge submissions for the 2025 EARTHVISION Challenge at CVPR ([competition details](https://www.grss-ieee.org/events/earthvision-2025/?tab=challenge)), NeuCo-Bench is now released for local benchmarking and evaluation - additional tech details in [http://arxiv.org/html/2510.17914](http://arxiv.org/html/2510.17914).* 6 | 7 | --- 8 | 9 | NeuCo-Bench is a **benchmarking framework** designed to evaluate how effectively compressed embeddings preserve information for downstream tasks. 10 | 11 | In domains like Earth Observation (EO), pipelines typically handle large volumes of image data used primarily for analytical tasks. Traditional compression techniques focus on pixel-level reconstruction, while Foundation Model (FM) research does not explicitly consider embedding size. NeuCo-Bench addresses this gap by enforcing strict size constraints and evaluating embeddings directly on real-world EO tasks. 12 | 13 | NeuCo-Bench provides an initial set of EO tasks and invites community contributions of additional tasks and datasets from EO and other domains. 14 | 15 |

16 | <img src="assets/NeuCoBench.png" alt="Framework overview"> 17 | 

18 | 19 | 20 | ## Key Features 21 | 22 | - **Model-agnostic**: Supports evaluation of any fixed-size embedding (e.g. 1024‑dim feature vectors), which enables comparison among compression and representation learning methods. 23 | - **Task-Driven Evaluation**: Utilizes linear probes across diverse EO tasks, including land-cover proportion estimation, cloud detection, and biomass estimation. 24 | - **Metrics**: Incorporates signal-to-noise scores and dynamic rank aggregation to compare methods. 25 | 26 | --- 27 | 28 | ## Quickstart 29 | 30 | ```bash 31 | # start from fresh environment (skip if not needed) 32 | micromamba create -n neuco-bench -c conda-forge python=3.12 33 | micromamba activate neuco-bench 34 | 35 | # clone NeuCo-Bench and install requirements 36 | git clone https://github.com/embed2scale/NeuCo-Bench.git 37 | cd NeuCo-Bench/benchmark 38 | pip install -r ../requirements.txt 39 | 40 | # run standalone NeuCo-Bench evaluation script 41 | python main.py \ 42 | --annotation_path path/to/annotation_folder \ 43 | --submission_file path/to/submission_file.csv \ 44 | --output_dir path/to/results \ 45 | --config path/to/config.yaml \ 46 | --method_name your-method-name \ 47 | --phase phase-name 48 | ``` 49 | 50 | - `--annotation_path` Directory containing CSV label files for each task. 51 | - `--submission_file` CSV file with your embeddings. 52 | - `--output_dir` Destination for per-task reports, plots, and aggregated benchmark results. 53 | - `--config` YAML file specifying cross-validation settings and logging options (see provided sample). 54 | - `--method_name` Identifier for your method used in filenames and leaderboard entries. 55 | - `--phase` Groups evaluation runs under a specified phase name for ranking, creating a subfolder within `output_dir`. 56 | 57 | To disable GPU utilization, run `CUDA_VISIBLE_DEVICES=''` before execution. 58 | 59 | ## Overview 60 | 61 | NeuCo-Bench emphasizes task-oriented semantic evaluation rather than pixel-level reconstruction, measuring how effectively compressed embeddings retain information relevant to EO tasks. 62 | 63 | To evaluate embeddings: 64 | 1. Download the [SSL4EO-S12-downstream dataset](https://huggingface.co/datasets/embed2scale/SSL4EO-S12-downstream) from Hugging Face (see [Data](#data)). 65 | 2. Encode images into fixed-size embeddings, save as CSV (see [Creating Embeddings](#creating-embeddings)). 66 | 3. Run NeuCo-Bench locally to evaluate and aggregate scores, generating a leaderboard (see [Evaluation and Ranking](#evaluation-and-ranking)). 67 | 68 | --- 69 | 70 | ## Data 71 | 72 | The **SSL4EO-S12-downstream** dataset includes: 73 | 74 | - `data/` 75 | Subfolders for modalities (`s1/`, `s2l1c/`, `s2l2a/`) with subsets of 1000 `zarr.zip` files each. 76 | - `labels/` 77 | Annotation files for each downstream task. 78 | 79 | Both `data/` and `labels/` are required. See `examples/data` for a TorchDataset loader; if you experience data-loading errors, verify that `zarr==2.18.0` is used. 80 | 81 | Data format aligns with [SSL4EOS12 v1.1](https://github.com/DLR-MF-DAS/SSL4EO-S12-v1.1), recommended as a pretraining dataset. 82 | 83 | --- 84 | 85 | ## Creating Embeddings 86 | 87 | Generate embeddings and save them as CSV files. Example scripts in `examples/` illustrate the required format and provide two baseline methods: Averaging Baseline (Bilinear interpolation and averaging of the modalities) and downsampled embeddings from a pretrained FM (DINO ViT pretrained on SSL4EO). 
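As a minimal sketch of the expected CSV layout, the snippet below builds a submission with `create_submission_from_dict` from `examples/data/submission_utils.py` (run from within `examples/`; the IDs and random vectors are placeholders):

```python
import numpy as np

from data.submission_utils import create_submission_from_dict

# One row per sample: an 'id' column followed by the embedding values (1024 dimensions).
emb_dict = {f"sample-{i:04d}": np.random.randn(1024).tolist() for i in range(3)}

df_submission = create_submission_from_dict(emb_dict)
df_submission.to_csv("my_submission.csv", index=False)
```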
88 | 89 | To ensure consistent benchmarking, all methods should use the same embedding dimension. We set the embedding size to 1024 dimensions during the 2025 CVPR EARTHVISION data challenge. 90 | As a reference, we provide a selection of CSV files from the 2025 CVPR EARTHVISION data challenge in the repo's top-level `data/` directory. More details in `data/README.md`. 91 | In general, the https://github.com/embed2scale/NeuCo-Bench/tree/main/data folder is tracked by Git LFS to keep initial clones of this repo slim. If you would like to download the approx. 500 MB of embeddings, run: 92 | ```bash 93 | git lfs install 94 | git pull 95 | ``` 96 | 97 | --- 98 | 99 | ## Evaluation and Ranking 100 | 101 | Run the benchmark on your embeddings with: 102 | 103 | ```bash 104 | python main.py \ 105 |   --annotation_path path/to/annotation_folder \ 106 |   --submission_file path/to/submission_file.csv \ 107 |   --output_dir path/to/results \ 108 |   --config path/to/config.yaml \ 109 |   --method_name "your-method-name" \ 110 |   --phase "phase-name" 111 | ``` 112 | 113 | ### Configuration 114 | 115 | A sample config file (`benchmark/config.yaml`) specifies: 116 | 117 | - `batch_size`, `epochs`, `learning_rate`, `k_folds`: Cross-validation settings. 118 | - `standardize_embeddings`: Standardize embeddings using global mean and std (recommended). 119 | - `normalize_labels`: Normalize target labels to [0,1] (recommended). 120 | - `enable_plots`: Generate per-fold plots (e.g., parity plots for regression). 121 | - `update_leaderboard`: Aggregate and update leaderboard after evaluation. 122 | - `task_filter`: Tasks to evaluate (default: all tasks available in `annotation_path`). 123 | 124 | ### Results 125 | 126 | Results saved under `output_dir/<phase>/<experiment_name>/` include: 127 | 128 | - Task-specific metrics and loss curves 129 | - `results_summary.json` with per-task signal-to-noise scores and overall scores 130 | 131 | ### Aggregation 132 | 133 | Aggregate scores into a leaderboard by setting `update_leaderboard` to `True` during the last evaluation, or run manually: 134 | 135 | ```python 136 | from evaluation.results import summarize_runs 137 | summarize_runs(output_dir=output_dir, phase=phase) 138 | ``` 139 | 140 | --- 141 | 142 | ## Future Work & Contributing 143 | 144 | All downstream tasks and labels are published on Hugging Face. We are planning to extend the framework to further tasks (e.g. spatial and temporal downstream tasks). 145 | 146 | We invite the community to collaborate and appreciate contributions, including but not limited to the following: 147 | - Benchmark and contribute new compression techniques 148 | - Incorporate additional downstream tasks and metrics (see the label-file sketch below) 149 | - Extension to further input modalities 150 | 151 | Check out [CONTRIBUTING.md](.github/CONTRIBUTING.md). 
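For orientation, `benchmark/data/labels.py` expects one CSV per task, named `<task_name>__<task_type>.csv` with columns `id` and `label`. A hypothetical example (task name, IDs, and values are made up):

```python
import pandas as pd

# Hypothetical label file for a new regression task; the file name encodes
# the task as <task_name>__<task_type>.csv.
labels = pd.DataFrame({
    "id": ["0a1b2c3d", "4e5f6a7b"],  # sample IDs matching the embedding CSV
    "label": [0.42, 0.13],           # regression targets
})
labels.to_csv("building_density__regression.csv", index=False)
```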
186 | 
187 | ## How to cite
188 | 
189 | ```BibTeX
190 | @article{Vinge2025NeuCoBench,
191 |   author = {Rikard Vinge and Isabelle Wittmann and Jannik Schneider and Michael Marszalek and Luis Gilch and Thomas Brunschwiler and Conrad M Albrecht},
192 |   title = {NeuCo-Bench: A Novel Benchmark Framework for Neural Embeddings in Earth Observation},
193 |   journal = {arXiv preprint arXiv:2510.17914},
194 |   year = {2025},
195 |   url = {https://arxiv.org/abs/2510.17914},
196 |   doi = {10.48550/arXiv.2510.17914},
197 |   note = {Submitted on 19 Oct 2025},
198 | }
199 | ```
200 | 
--------------------------------------------------------------------------------
/examples/data/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | As provided on
3 | https://github.com/DLR-MF-DAS/embed2scale-challenge-supplement/blob/main/data_loading_submission_demo/challenge_dataset.py
4 | """
5 | import torch
6 | from torch.utils.data import Dataset
7 | import os
8 | import glob
9 | import xarray as xr
10 | import numpy as np
11 | from torch import nn
12 | from typing import List, Tuple
13 | 
14 | S2L1C_MEAN = [2607.345, 2393.068, 2320.225, 2373.963, 2562.536, 3110.071, 3392.832, 3321.154, 3583.77, 1838.712, 1021.753, 3205.112, 2545.798]
15 | S2L1C_STD = [786.523, 849.702, 875.318, 1143.578, 1126.248, 1161.98, 1273.505, 1246.79, 1342.755, 576.795, 45.626, 1340.347, 1145.036]
16 | 
17 | S2L2A_MEAN = [1793.243, 1924.863, 2184.553, 2340.936, 2671.402, 3240.082, 3468.412, 3563.244, 3627.704, 3711.071, 3416.714, 2849.625]
18 | S2L2A_STD = [1160.144, 1201.092, 1219.943, 1397.225, 1400.035, 1373.136, 1429.17, 1485.025, 1447.836, 1652.703, 1471.002, 1365.307]
19 | 
20 | S1GRD_MEAN = [-12.577, -20.265]
21 | S1GRD_STD = [5.179, 5.872]
22 | 
23 | 
24 | class E2SChallengeDataset(Dataset):
25 | 
26 |     def __init__(self,
27 |                  data_path: str = None,
28 |                  transform = None,
29 |                  modalities: List[str] = None,
30 |                  dataset_name: str = 'bands',
31 |                  seasons: int = 4,
32 |                  randomize_seasons: bool = False,
33 |                  concat: bool = True,
34 |                  output_file_name: bool = False,
35 |                  shift_s2_channels: bool = True,
36 |                  ):
37 |         """Dataset class for the embed2scale challenge data
38 | 
39 |         Parameters
40 |         ----------
41 |         data_path : str, path-like
42 |             Path to challenge data. Assumes that under data_path there are 3 subfolders, named after the modalities.
43 |         transform : torch.Compose
44 |             Transformations to apply to the data
45 |         modalities : list[str]
46 |             List of modalities to include. Should correspond to the subfolders under data_path.
47 |         dataset_name : str
48 |             Name of dataset in zarr archive. Use 'bands' here. Defaults to 'bands'.
49 |         seasons : int
50 |             Number of seasons to load. Must be integer between 1 and 4. Default is 4.
51 |         randomize_seasons : bool
52 |             Toggle randomized order of seasons. If True, the order of the seasons will be randomized. Default is False.
53 |         concat : bool
54 |             Toggle concatenating the modalities along the channel dimension. Default is True.
55 |         output_file_name : bool
56 |             Toggle output of the file name.
57 |         shift_s2_channels : bool
58 |             Toggle shifting the S2 channels by 1000 to align with SSL4EO-S12 v1.1. Default is True, in which case the challenge data S2 channels are
59 |             shifted upward by 1000 to have the same range as SSL4EO-S12 v1.1. The background is that ESA decided
60 |             from 2022-01-25 to shift the DN values of S2 by 1000 upward. SSL4EO-S12 v1.1 includes this shift,
61 |             while the challenge data does not.
62 | 63 | Returns 64 | ------- 65 | torch.Tensor or dict 66 | If output_file_name=False, outputs a torch.Tensor. 67 | If output_file_name=True, outputs a dictionary with fields 'data' and 'file_name'. 'data' is a torch.Tensor if concat=True and a dict with one field per modality, each containing a torch.Tensor if False. 'file_name' is the id of the loaded file. 68 | """ 69 | 70 | self.data_path = data_path 71 | self.transform = transform 72 | self.modalities = modalities 73 | self.dataset_name = dataset_name 74 | assert isinstance(seasons, int) and (1 <= seasons <= 4), "Number of seasons must be integer between 1 and 4." 75 | 76 | self.seasons = seasons 77 | self.randomize_seasons = randomize_seasons 78 | if not randomize_seasons: 79 | self.possible_seasons = list(range(seasons)) 80 | else: 81 | self.possible_seasons = list(range(4)) 82 | assert len(modalities) > 0, "No modalities provided." 83 | self.concat = concat 84 | self.output_file_name = output_file_name 85 | self.shift_s2_channels = shift_s2_channels 86 | 87 | self.samples = glob.glob(os.path.join(data_path, modalities[0], '*', '*.zarr.zip')) 88 | 89 | def __len__(self): 90 | 91 | return len(self.samples) 92 | 93 | def __getitem__(self, idx): 94 | 95 | sample_path = self.samples[idx] 96 | file_name = os.path.splitext(os.path.basename(sample_path))[0].replace('.zarr', '') 97 | if self.randomize_seasons: 98 | seasons = [self.possible_seasons[ind] for ind in torch.randperm(len(self.possible_seasons)).tolist()[:self.seasons]] 99 | else: 100 | seasons = self.possible_seasons 101 | sample_paths = [sample_path] + [sample_path.replace(self.modalities[0]+'/', modality+'/') for modality in self.modalities[1:]] 102 | data = {} 103 | 104 | for modality, sample_path in zip(self.modalities, sample_paths): 105 | season_index = xr.DataArray(seasons, dims='time') 106 | data[modality] = xr.open_zarr(sample_path).isel(time=season_index)[self.dataset_name].values 107 | 108 | # Add shift to align S2 channels with SSL4EO-S12 v1.1 109 | if self.shift_s2_channels and (modality in ['s2l1c', 's2l2a']): 110 | data[modality] += 1000 111 | 112 | n_bands_per_modality = {m: d.shape[-3] for m, d in data.items()} 113 | start_ind_of_modality = {m: n for m, n in zip(self.modalities, [0] + np.cumsum(list(n_bands_per_modality.values())).tolist())} 114 | 115 | # Concatenate data 116 | data = np.concatenate(list(data.values()), axis=-3) 117 | data = data.astype(np.float32) 118 | data = torch.from_numpy(data) 119 | 120 | # Transform 121 | if self.transform is not None: 122 | data = self.transform(data) 123 | 124 | if not self.concat: 125 | data = {m: data[..., start_ind_of_modality[m]: start_ind_of_modality[m] + n_bands_per_modality[m], :, :] for m in self.modalities} 126 | 127 | if self.output_file_name: 128 | 129 | return {'data': data, 'file_name': file_name} 130 | else: 131 | 132 | return data 133 | 134 | def collate_fn(batch): 135 | if isinstance(batch, dict) or isinstance(batch, torch.Tensor): 136 | # Single sample 137 | return batch 138 | elif isinstance(batch, list) and isinstance(batch[0], torch.Tensor): 139 | # Concatenate tensors along sample dim 140 | return torch.concat(batch, dim=0) 141 | elif isinstance(batch, list) and isinstance(batch[0], dict): 142 | file_names = [sample['file_name'] for sample in batch] 143 | data = [sample['data'] for sample in batch] 144 | if isinstance(data[0], torch.Tensor): 145 | data = torch.concat(data, dim=0) 146 | elif isinstance(data[0], dict): 147 | data = { 148 | m: torch.concat([b[m] for b in data], dim=0) 149 | for m in 
data[0].keys()
150 |         }
151 |         return {'data': data, 'file_name': file_names}
152 | 
153 | class InputResizer(nn.Module):
154 |     """
155 |     Resizes spatial dimensions of input tensor via adaptive average pooling.
156 |     """
157 | 
158 |     def __init__(self, output_size: Tuple[int, int]):
159 |         super().__init__()
160 |         self.adaptive_pool = nn.AdaptiveAvgPool2d(output_size)
161 | 
162 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
163 |         return self.adaptive_pool(x)
164 | 
165 | class Normalize:
166 |     """
167 |     Normalizes image tensor for DINO: scales to [0,1] range by dividing by 10000.
168 |     """
169 | 
170 |     def __call__(self, img: torch.Tensor) -> torch.Tensor:
171 |         img = img.float() / 10000.0
172 |         return torch.clamp(img, 0.0, 1.0)
173 | 
174 | class TemporalMean(nn.Module):
175 |     """
176 |     Averages over the time dimension (dim 1 of a [sample, time, channel, height, width] tensor).
177 |     """
178 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
179 |         return x.mean(dim=1, keepdim=True)
--------------------------------------------------------------------------------
/benchmark/evaluation/linear_probing.py:
--------------------------------------------------------------------------------
1 | # Training and cross-validation pipeline for linear probes
2 | 
3 | import json
4 | import copy
5 | import logging
6 | from dataclasses import dataclass
7 | from pathlib import Path
8 | from tqdm import tqdm
9 | 
10 | import numpy as np
11 | import pandas as pd
12 | import torch
13 | import torch.nn as nn
14 | from torch.utils.data import DataLoader, SubsetRandomSampler
15 | from sklearn.model_selection import ShuffleSplit
16 | from torchmetrics.classification import BinaryF1Score
17 | from torchmetrics.regression import R2Score
18 | 
19 | from data.embeddings import EmbeddingDataset
20 | from evaluation.metrics import classification_metrics, regression_metrics
21 | from evaluation.visualisations import save_loss_curve, plot_confusion_matrix, plot_regression_scatter
22 | 
23 | logger = logging.getLogger(__name__)
24 | 
25 | @dataclass
26 | class FoldResult:
27 |     train_loss: list[float]
28 |     val_loss: list[float]
29 |     metric_history: list[float]
30 |     best_metric: float
31 |     best_model_state: dict
32 | 
33 | @dataclass
34 | class TaskResult:
35 |     q_statistic: float
36 |     mean_score: float
37 |     std_dev: float
38 | 
39 | # We employ a simple, linear model which can be used for regression and classification tasks.
40 | # New models can be added for new downstream tasks in the future.
41 | class LinearProbe(nn.Module):
42 |     """Linear model for downstream tasks."""
43 | 
44 |     def __init__(self, input_dim: int) -> None:
45 |         super().__init__()
46 |         self.linear = nn.Linear(input_dim, 1)
47 | 
48 |     def forward(self, features: torch.Tensor) -> torch.Tensor:
49 |         return self.linear(features).view(-1)
50 | 
51 | class Trainer:
52 |     """Train and evaluate a single fold."""
53 | 
54 |     def __init__(
55 |         self,
56 |         model: nn.Module,
57 |         task_type: str,
58 |         task_name: str,
59 |         device: torch.device,
60 |         learning_rate: float,
61 |         fold_index: int,
62 |         output_dir: Path,
63 |         filename_prefix: str,
64 |         enable_plots: bool,
65 |     ) -> None:
66 |         """Train and evaluate a single fold.
67 | 
68 |         Args:
69 |             model: torch.nn.Module to use in evaluation.
70 |             task_type: Type of task, either "classification" or "regression".
71 |             task_name: Name of task.
72 |             device: Device, CPU or GPU, to run training and inference.
73 |             learning_rate: Learning rate for Linear Probe training.
74 |             fold_index: Index of split. Used in plotting.
75 |             output_dir: Path to folder to store results in.
76 | filename_prefix: Prefix to add to files of visualizations. 77 | enable_plots: Flag controlling the creation of plots. Set to True to plot results. 78 | """ 79 | self.model = model.to(device) 80 | self.task_type = task_type 81 | self.task_name = task_name 82 | self.device = device 83 | self.fold_index = fold_index 84 | self.output_dir = output_dir 85 | self.filename_prefix = filename_prefix 86 | self.enable_plots = enable_plots 87 | 88 | if task_type in ("classification", "cls"): 89 | self.loss_fn = nn.BCEWithLogitsLoss() 90 | self.metric_fn = BinaryF1Score().to(device) 91 | elif task_type in ("regression", "regr"): 92 | self.loss_fn = nn.MSELoss() 93 | self.metric_fn = R2Score().to(device) 94 | else: 95 | raise ValueError(f"Unsupported task_type: {task_type}") 96 | 97 | self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate) 98 | 99 | def _step( 100 | self, 101 | batch: tuple[torch.Tensor, torch.Tensor], 102 | training: bool, 103 | ) -> tuple[torch.Tensor, torch.Tensor]: 104 | inputs, targets = (tensor.to(self.device) for tensor in batch) 105 | with torch.set_grad_enabled(training): 106 | preds = self.model(inputs) 107 | loss = self.loss_fn(preds, targets.float()) 108 | if training: 109 | self.optimizer.zero_grad() 110 | loss.backward() 111 | self.optimizer.step() 112 | return loss.detach(), preds.detach() 113 | 114 | def fit( 115 | self, 116 | train_loader: DataLoader, 117 | val_loader: DataLoader, 118 | epochs: int, 119 | ) -> FoldResult: 120 | """Fit model to data and evaluate. 121 | 122 | Args: 123 | train_loader: Data loader for training data. 124 | val_loader: Data loader for validation data. 125 | epochs: Number of epochs to train network. 126 | Returns: 127 | Instance of FoldResult containing the evaluation results. 
128 | """ 129 | train_losses, val_losses, metrics = [], [], [] 130 | best_metric = float("-inf") 131 | best_state = None 132 | best_train_preds, best_train_targets = None, None 133 | best_val_preds, best_val_targets = None, None 134 | 135 | for _ in range(epochs): 136 | # Training phase 137 | self.model.train() 138 | epoch_train_losses, train_preds, train_targets = [], [], [] 139 | for batch in train_loader: 140 | loss, preds = self._step(batch, training=True) 141 | epoch_train_losses.append(loss) 142 | train_preds.append(preds) 143 | train_targets.append(batch[1]) 144 | train_losses.append(torch.stack(epoch_train_losses).mean().item()) 145 | 146 | # Validation phase 147 | self.model.eval() 148 | epoch_val_losses, val_preds, val_targets = [], [], [] 149 | with torch.no_grad(): 150 | for batch in val_loader: 151 | loss, preds = self._step(batch, training=False) 152 | epoch_val_losses.append(loss) 153 | val_preds.append(preds) 154 | val_targets.append(batch[1]) 155 | val_losses.append(torch.stack(epoch_val_losses).mean().item()) 156 | 157 | # Compute metric on validation data 158 | metric_value = self.metric_fn( 159 | torch.cat(val_preds), torch.cat(val_targets) 160 | ).item() 161 | metrics.append(metric_value) 162 | 163 | # Update best model 164 | if metric_value > best_metric: 165 | best_metric = metric_value 166 | best_state = copy.deepcopy(self.model.state_dict()) 167 | best_train_preds, best_train_targets = train_preds, train_targets 168 | best_val_preds, best_val_targets = val_preds, val_targets 169 | 170 | # Generate plots if enabled 171 | if self.enable_plots: 172 | if self.task_type in ("classification", "cls"): 173 | cm = classification_metrics( 174 | torch.cat(best_train_preds), torch.cat(best_train_targets) 175 | )["confusion_matrix"] 176 | plot_confusion_matrix( 177 | np.array(cm), self.task_name, self.fold_index, 178 | self.output_dir, self.filename_prefix 179 | ) 180 | else: 181 | plot_regression_scatter( 182 | torch.cat(best_train_targets), torch.cat(best_train_preds), 183 | torch.cat(best_val_targets), torch.cat(best_val_preds), 184 | self.task_name, self.fold_index, 185 | self.output_dir, self.filename_prefix, 186 | ) 187 | 188 | return FoldResult(train_losses, val_losses, metrics, best_metric, best_state) 189 | 190 | def cross_validate( 191 | df: pd.DataFrame, 192 | task_type: str, 193 | task_name: str, 194 | device: torch.device, 195 | batch_size: int, 196 | n_splits: int, 197 | epochs: int, 198 | embedding_dim: int, 199 | learning_rate: float, 200 | output_dir: Path, 201 | filename_prefix: str, 202 | enable_plots: bool, 203 | random_seed: int = 42, 204 | output_fold_results: bool = False, 205 | ) -> TaskResult: 206 | """Perform k-fold cross-validation with a linear probe. 207 | 208 | Args: 209 | df: Dataframe to evaluate. 210 | task_type: Type of task, either "classification" or "regression". 211 | task_name: Name of task. 212 | device: Device, CPU or GPU, to run training and inference. 213 | batch_size: Batch size in Linear Probe training. 214 | n_splits: Number of repetitions of Linear Probe evaluation. Use to gather statistics. 215 | epochs: Number of epochs to use in Linear Probe training. 216 | embedding_dim: Size of embeddings. 217 | learning_rate: Learning rate for Linear Probe training. 218 | output_dir: Path to folder to store results in. 219 | enable_plots: Toggle storing of plots. Set to True to store plots. 220 | random_seed: Integer seed for random number generator. 
221 | output_fold_results: Toggle storing performance metric per fold in addition to summary statistics. Default is False, in which case performance per fold is not stored. 222 | Returns: 223 | A TaskResult instance containing the evaluation results. 224 | """ 225 | logger.info("Cross-validation start: %s", task_name) 226 | 227 | dataset = EmbeddingDataset(df) 228 | splitter = ShuffleSplit( 229 | n_splits=n_splits, test_size=0.1, random_state=random_seed 230 | ) 231 | fold_results: list[FoldResult] = [] 232 | 233 | for fold_index, (train_idx, test_idx) in tqdm(enumerate(splitter.split(dataset)), total=n_splits): 234 | #logger.info("Fold %d/%d", fold_index + 1, n_splits) 235 | 236 | train_loader = DataLoader( 237 | dataset, 238 | batch_size=batch_size, 239 | sampler=SubsetRandomSampler(train_idx), 240 | pin_memory=True, 241 | ) 242 | val_loader = DataLoader( 243 | dataset, 244 | batch_size=batch_size, 245 | sampler=SubsetRandomSampler(test_idx), 246 | pin_memory=True, 247 | ) 248 | 249 | model = LinearProbe(embedding_dim) 250 | trainer = Trainer( 251 | model=model, 252 | task_type=task_type, 253 | task_name=task_name, 254 | device=device, 255 | learning_rate=learning_rate, 256 | fold_index=fold_index, 257 | output_dir=output_dir, 258 | filename_prefix=filename_prefix, 259 | enable_plots=enable_plots, 260 | ) 261 | fold_results.append(trainer.fit(train_loader, val_loader, epochs)) 262 | 263 | # Aggregate and save results 264 | train_losses = [fr.train_loss for fr in fold_results] 265 | val_losses = [fr.val_loss for fr in fold_results] 266 | best_metrics = np.array([fr.best_metric for fr in fold_results], dtype=np.float64) 267 | 268 | save_loss_curve(train_losses, output_dir, task_name, loss_type="train") 269 | save_loss_curve(val_losses, output_dir, task_name, loss_type="validation") 270 | 271 | # Compute summary Q-statistic 272 | mean_score = np.nanmean(best_metrics) 273 | std_dev = np.nanstd(best_metrics) 274 | q_stat = mean_score / (0.02 + std_dev) * 2 275 | 276 | result_metrics = { 277 | "q_stat": q_stat, 278 | "mean_score": mean_score, 279 | "std_dev": std_dev, 280 | } 281 | if output_fold_results: 282 | result_metrics["q_t"] = best_metrics.tolist() 283 | 284 | (output_dir / task_name / f"{task_name}_result.json").write_text( 285 | json.dumps(result_metrics, indent=2) 286 | ) 287 | 288 | return TaskResult(q_stat, mean_score, std_dev) 289 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025 Embed2Scale 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/examples/baseline_compression_mean.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "Copy of https://github.com/DLR-MF-DAS/embed2scale-challenge-supplement/blob/main/data_loading_submission_demo/baseline_compression_mean.ipynb"
8 |    ]
9 |   },
10 |   {
11 |    "cell_type": "markdown",
12 |    "metadata": {},
13 |    "source": [
14 |     "# Embed2Scale challenge \"mean\" baseline\n",
15 |     "\n",
16 |     "This notebook creates baseline embeddings by bilinear interpolation and averaging of the modalities.\n",
17 |     "\n",
18 |     "We use the E2SChallengeDataset to load the data. The datacubes of the challenge data are of shape (1, 4, 27, 264, 264), i.e. (number of samples, number of timesteps, number of channels, height, width).\n",
19 |     "\n",
20 |     "The embedding works as follows:\n",
21 |     "1. Subsample each channel to 8x8 pixels using bilinear interpolation -> shape (1, 4, 27, 8, 8)\n",
22 |     "2. Average B01 through B09 for both S2L1C and S2L2A along the channel dimension. Average B11 and B12 along the channel dimension. Average S1 channels along the channel dimension. Concatenate the three averages and B10 along the channel dimension -> shape (1, 4, 4, 8, 8)\n",
23 |     "3. Flatten into 1024 element vector -> shape (1024,)\n",
24 |     "\n",
25 |     "After embedding, a submission file is created in the expected format for the embed2scale eval.ai challenge. If you use this code, verify that it produces the right number of decimals for your output.\n",
26 |     "\n",
27 |     "At the end, a function to test that a submission file is readable for evaluation is provided.\n",
28 |     "\n",
29 |     "Note that parts of this notebook are simplified for demonstration purposes. However, the datasets and dataloaders, as well as the verification of the submission file, are intended to be directly usable and true to the data and the expected submission file formats."
30 |    ]
31 |   },
32 |   {
33 |    "cell_type": "code",
34 |    "execution_count": 1,
35 |    "metadata": {},
36 |    "outputs": [],
37 |    "source": [
38 |     "import numpy as np\n",
39 |     "import pandas as pd\n",
40 |     "from concurrent.futures import ThreadPoolExecutor\n",
41 |     "from scipy.ndimage import zoom\n",
42 |     "from torchvision import transforms\n",
43 |     "\n",
44 |     "from data.dataset import E2SChallengeDataset, S2L1C_MEAN, S2L1C_STD, S2L2A_MEAN, S2L2A_STD, S1GRD_MEAN, S1GRD_STD"
45 |    ]
46 |   },
47 |   {
48 |    "cell_type": "markdown",
49 |    "metadata": {},
50 |    "source": [
51 |     "# Configurations"
52 |    ]
53 |   },
54 |   {
55 |    "cell_type": "code",
56 |    "execution_count": 2,
57 |    "metadata": {},
58 |    "outputs": [],
59 |    "source": [
60 |     "# Order of modalities.\n",
61 |     "# In this demo, modalities are ordered the same as the default order in the SSL4EOS12 dataset class.\n",
62 |     "# Modalities are loaded in the order provided here.\n",
63 |     "# Change the order based on your needs.\n",
64 |     "modalities = ['s2l1c', 's2l2a', 's1']\n",
65 |     "\n",
66 |     "# Path to challenge data folder, i.e.
the folder containing the s1, s2l1c and s2l2a subfolders.\n",
67 |     "path_to_data = '/path/to/challenge/data/'\n",
68 |     "\n",
69 |     "# Path to where the submission file should be written.\n",
70 |     "path_to_output_file = 'path/to/output/file.csv'\n",
71 |     "\n",
72 |     "write_result_to_file = True # Set to True to trigger saving of the csv at the end.\n",
73 |     "\n",
74 |     "# Create data transformation\n",
75 |     "# Get mean and standard deviations for the modalities in the same order as the modalities\n",
76 |     "# Note that we will use the `shift_s2_channels` flag in the challenge dataset, and we should \n",
77 |     "# therefore use the mean and standard deviation of the SSL4EO-S12 v1.1 dataset.\n",
78 |     "mean_data = S2L1C_MEAN + S2L2A_MEAN + S1GRD_MEAN\n",
79 |     "std_data = S2L1C_STD + S2L2A_STD + S1GRD_STD\n",
80 |     "\n",
81 |     "data_transform = transforms.Compose([\n",
82 |     "    # Add additional transformation here\n",
83 |     "    transforms.Normalize(mean=mean_data, std=std_data)\n",
84 |     "])\n",
85 |     "\n",
86 |     "# Note that both E2SChallengeDataset and SSL4EOS12Dataset output torch tensors, so there is no need for a ToTensor transform."
87 |    ]
88 |   },
89 |   {
90 |    "cell_type": "markdown",
91 |    "metadata": {},
92 |    "source": [
93 |     "# Load data"
94 |    ]
95 |   },
96 |   {
97 |    "cell_type": "code",
98 |    "execution_count": 3,
99 |    "metadata": {},
100 |    "outputs": [
101 |     {
102 |      "name": "stdout",
103 |      "output_type": "stream",
104 |      "text": [
105 |       "Length of train dataset: 5149\n",
106 |       "Modality s2l1c shape: torch.Size([1, 4, 13, 264, 264])\n",
107 |       "Modality s2l2a shape: torch.Size([1, 4, 12, 264, 264])\n",
108 |       "Modality s1 shape: torch.Size([1, 4, 2, 264, 264])\n"
109 |      ]
110 |     }
111 |    ],
112 |    "source": [
113 |     "# Concatenate modalities\n",
114 |     "# dataloader output is {'data': concatenated_data, 'file_name': file_name}\n",
115 |     "# The data has shapes [n_samples, n_seasons, n_channels, height, width] (for concatenated_data [1, 4, 27, 264, 264])\n",
116 |     "\n",
117 |     "dataset_e2s = E2SChallengeDataset(path_to_data, \n",
118 |     "                                  modalities = modalities, \n",
119 |     "                                  dataset_name='bands', \n",
120 |     "                                  transform=data_transform, \n",
121 |     "                                  concat=False,\n",
122 |     "                                  output_file_name=True,\n",
123 |     "                                  shift_s2_channels=True\n",
124 |     "                                  )\n",
125 |     "\n",
126 |     "# Print dataset length\n",
127 |     "print(f\"Length of train dataset: {len(dataset_e2s)}\")\n",
128 |     "\n",
129 |     "# Print shape of first sample\n",
130 |     "for m, d in dataset_e2s[0]['data'].items():\n",
131 |     "    print(f'Modality {m} shape:', d.shape)"
132 |    ]
133 |   },
134 |   {
135 |    "cell_type": "markdown",
136 |    "metadata": {},
137 |    "source": [
138 |     "# Create submission file\n",
139 |     "\n",
140 |     "In this section, we define a helper that assembles a dictionary of embeddings into a submission dataframe of the correct format.\n",
141 |     "Further below, we use it to create the submission file.\n",
142 |     "\n",
143 |     "We use the E2SChallengeDataset since we can easily get the sample ID (file name) from it."
144 |    ]
145 |   },
146 |   {
147 |    "cell_type": "code",
148 |    "execution_count": 4,
149 |    "metadata": {},
150 |    "outputs": [],
151 |    "source": [
152 |     "def create_submission_from_dict(emb_dict):\n",
153 |     "    \"\"\"Assume dictionary has format {hash-id0: embedding0, hash-id1: embedding1, ...}\n",
154 |     "    \"\"\"\n",
155 |     "    df_submission = pd.DataFrame.from_dict(emb_dict, orient='index')\n",
156 |     "    \n",
157 |     "    # Reset index with name 'id'\n",
158 |     "    df_submission.index.name = 'id'\n",
159 |     "    df_submission.reset_index(drop=False, inplace=True)\n",
160 |     "    \n",
161 |     "    return df_submission\n",
162 |     "    "
163 |    ]
164 |   },
165 |   {
166 |    "cell_type": "markdown",
167 |    "metadata": {},
168 |    "source": [
169 |     "# Compress by bilinear transform and channel averaging\n",
170 |     "\n",
171 |     "In this section, we create a submission file by processing each sample as follows:\n",
172 |     "1. Subsampling each channel to 8x8 pixels using bilinear interpolation\n",
173 |     "2. Average channels B01 to B09 for both L1C and L2A, average B11 and B12, and average S1 channels. Together with B10, this turns into 4 new channels.\n",
174 |     "3. Flatten into 1024 element vector.\n",
175 |     "\n",
176 |     "We use the dataloader based on the E2SChallengeDataset since we can easily get the sample ID (file name) from the dataloader."
177 |    ]
178 |   },
179 |   {
180 |    "cell_type": "code",
181 |    "execution_count": 5,
182 |    "metadata": {},
183 |    "outputs": [],
184 |    "source": [
185 |     "# Correlation analysis shows that L1C and L2A channels B01 to B09 are correlated, B11 and B12 are correlated, \n",
186 |     "# and S1 VV and VH are correlated, so we average these, leaving (together with B10) 4 averaged channels.\n",
187 |     "\n",
188 |     "def embed(data, file_name, emb_len=1024):\n",
189 |     "    # Bilinear interpolation of each channel separately.\n",
190 |     "    rescaled_mod = {m: zoom(d, (1, 1, 1, 8/d.shape[3], 8/d.shape[4]), order=1) for m, d in data.items()}\n",
191 |     "\n",
192 |     "    # Calculate mean of correlated channels.\n",
193 |     "    b1_b9 = np.mean(np.concatenate((rescaled_mod['s2l1c'][:, :, 0:9, :, :], \n",
194 |     "                                    rescaled_mod['s2l2a'][:, :, 0:9, :, :]), axis=2), \n",
195 |     "                    axis=2, keepdims=True)\n",
196 |     "    b10 = rescaled_mod['s2l1c'][:, :, 9:10, :, :]\n",
197 |     "    b11_b12 = np.mean(np.concatenate((rescaled_mod['s2l1c'][:, :, 10:, :, :], \n",
198 |     "                                      rescaled_mod['s2l2a'][:, :, 10:, :, :]), axis=2), \n",
199 |     "                      axis=2, keepdims=True)\n",
200 |     "    s1 = np.mean(rescaled_mod['s1'], axis=2, keepdims=True)\n",
201 |     "\n",
202 |     "    # Concatenate aggregated channels\n",
203 |     "    emb = np.concatenate((b1_b9, b10, b11_b12, s1), axis=2)\n",
204 |     "\n",
205 |     "    # Flatten\n",
206 |     "    emb = emb.flatten()\n",
207 |     "\n",
208 |     "    return {'file_name': file_name, 'embedding': emb}\n",
209 |     "\n",
210 |     "\n",
211 |     "def mean_embedding_parallel(dataset, n_workers=4, n_samples=None):\n",
212 |     "    \n",
213 |     "    # Initialize result embeddings\n",
214 |     "    embeddings = {}\n",
215 |     "\n",
216 |     "    # Run embedding in parallel\n",
217 |     "    with ThreadPoolExecutor(max_workers=n_workers) as executor:\n",
218 |     "        futures = []\n",
219 |     "        \n",
220 |     "        for ind, data_file_name in enumerate(dataset):\n",
221 |     "            data = data_file_name['data']\n",
222 |     "            # print(data)\n",
223 |     "            file_name = data_file_name['file_name']\n",
224 |     "            # Submit the batch for processing\n",
225 |     "            future = executor.submit(embed, data, file_name)\n",
226 |     "            futures.append(future)\n",
227 |     "\n",
228 |     "            if (n_samples is not None) and (ind-1 > n_samples):\n",
229 |     "                break\n",
230 |     "        \n",
231 |     "        # Extract
results\n", 232 | " for future in futures:\n", 233 | " res = future.result()\n", 234 | " # Compile embeddings\n", 235 | " embeddings[res['file_name']] = res['embedding']\n", 236 | " return embeddings\n" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 6, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "n_workers = 1\n", 246 | "if n_workers != 1:\n", 247 | " # Embed data\n", 248 | " embeddings = mean_embedding_parallel(dataset_e2s, n_workers=n_workers, n_samples=10)\n", 249 | "else:\n", 250 | " embeddings = {}\n", 251 | " for ind, data_file_name in enumerate(dataset_e2s):\n", 252 | " data = data_file_name['data']\n", 253 | " file_name = data_file_name['file_name']\n", 254 | " emb = embed(data, file_name, 1024)\n", 255 | " embeddings[file_name] = emb['embedding']\n", 256 | " " 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 7, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "# Create submission file\n", 266 | "submission_file = create_submission_from_dict(embeddings)" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 8, 272 | "metadata": {}, 273 | "outputs": [ 274 | { 275 | "name": "stdout", 276 | "output_type": "stream", 277 | "text": [ 278 | "Number of embeddings: 5149\n" 279 | ] 280 | } 281 | ], 282 | "source": [ 283 | "print('Number of embeddings:', len(submission_file))" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 9, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "data": { 293 | "text/html": [ 294 | "
294 |        "<p><em>(Garbled HTML preview of submission_file.head() removed from this export; see the text/plain rendering below.)</em></p>"
295 |       ],
296 |       "text/plain": [
297 |        "                                                  id         0         1  \\\n",
298 |        "0  fec24d0cda8793ff55e1059c7b88763fee8d58d3decf78... -0.136362 -0.371835   \n",
299 |        "1  67960f4c8870a8aa52f295da0f0fea6d708c3cee2555a4...  0.259305  0.083660   \n",
300 |        "2  9688abfaebaea5dca2ec8bde771a7bf1e2bba8e661b777...  0.068881  0.112759   \n",
301 |        "3  fa3ae237ee6e2ee569c20a1e088112cf2105300d9272cc... -1.164994 -1.179528   \n",
302 |        "4  430590d31e38c5b345a92dc7d9eb8d126c01abced0cf1a... -0.166036 -0.311182   \n",
303 |        "\n",
304 |        "          2         3         4         5         6         7         8  ...  \\\n",
305 |        "0 -0.404560 -0.508725 -0.460369 -0.422099 -0.455290 -0.303729 -0.100668  ...   \n",
306 |        "1  0.043898  0.332839  0.073862 -0.279074 -0.163957  0.097987  0.054454  ...   \n",
307 |        "2  0.085232  0.119378  0.012920 -0.103890  0.019192  0.134928 -0.128380  ...   \n",
308 |        "3 -1.185304 -1.183173 -1.179835 -1.183128 -1.183431 -1.183904 -1.148454  ...   \n",
309 |        "4 -0.300327 -0.343975 -0.384960 -0.244595 -0.299571 -0.286590 -0.221417  ...   \n",
310 |        "\n",
311 |        "       1014      1015      1016      1017      1018      1019      1020  \\\n",
312 |        "0  0.395265  0.711939  0.601892  0.383944  0.874982  0.449806  0.952038   \n",
313 |        "1  0.113527 -0.575751 -0.560006 -0.238343 -0.913553 -0.952944 -0.011693   \n",
314 |        "2  0.388513 -0.447894 -1.262257 -1.520254 -0.984263 -1.121416 -0.635569   \n",
315 |        "3 -1.174436 -1.286493 -1.486834 -0.839548  0.361805  0.279468 -0.059674   \n",
316 |        "4  1.354896  0.118833  0.745980  1.308391  0.539959  0.529650  0.233003   \n",
317 |        "\n",
318 |        "       1021      1022      1023  \n",
319 |        "0 -0.268883  0.884533  0.243575  \n",
320 |        "1 -0.664440  0.862798 -0.504407  \n",
321 |        "2 -1.050879 -1.350882 -0.926634  \n",
322 |        "3 -0.799558 -0.876158 -1.462009  \n",
323 |        "4  0.646347  0.746715  0.449681  \n",
324 |        "\n",
325 |        "[5 rows x 1025 columns]"
326 |       ]
327 |      },
328 |      "execution_count": 9,
329 |      "metadata": {},
330 |      "output_type": "execute_result"
331 |     }
332 |    ],
333 |    "source": [
334 |     "submission_file.head()"
335 |    ]
336 |   },
337 |   {
338 |    "cell_type": "code",
339 |    "execution_count": 10,
340 |    "metadata": {},
341 |    "outputs": [],
342 |    "source": [
343 |     "# Write submission\n",
344 |     "if write_result_to_file:\n",
345 |     "    submission_file.to_csv(path_to_output_file, index=False)"
346 |    ]
347 |   },
348 |   {
349 |    "cell_type": "markdown",
350 |    "metadata": {},
351 |    "source": [
352 |     "# Verify submission file integrity\n",
353 |     "\n",
354 |     "Below we provide a snippet from a function which will read your embeddings and test for the same errors that the evaluation will check for. The function is similar to how the submission files are loaded.\n",
355 |     "\n",
356 |     "The intention of this function is to help verify that a submission has the right structure and contents and to check for missing embeddings or NaN values prior to submission.\n",
357 |     "\n",
358 |     "The function is intended to be a support. Successfully completing this function does not guarantee a fault-free submission file, but is an indication that the most common errors are not present."
359 |    ]
360 |   },
361 |   {
362 |    "cell_type": "code",
363 |    "execution_count": 11,
364 |    "metadata": {},
365 |    "outputs": [],
366 |    "source": [
367 |     "def test_submission(path_to_submission: str, \n",
368 |     "                    expected_embedding_ids: set, \n",
369 |     "                    embedding_dim: int = 1024):\n",
370 |     "    # Load data\n",
371 |     "    df = pd.read_csv(path_to_submission, header=0)\n",
372 |     "\n",
373 |     "    # Verify that id is in columns\n",
374 |     "    if 'id' not in df.columns:\n",
375 |     "        raise ValueError(f\"\"\"Submission file must contain column 'id'.\"\"\")\n",
376 |     "\n",
377 |     "    # Temporarily set index to 'id'\n",
378 |     "    df.set_index('id', inplace=True)\n",
379 |     "\n",
380 |     "    # Check that all samples are included\n",
381 |     "    submitted_embeddings = set(df.index.to_list())\n",
382 |     "    n_missing_embeddings = len(expected_embedding_ids.difference(submitted_embeddings))\n",
383 |     "    if n_missing_embeddings > 0:\n",
384 |     "        raise ValueError(f\"\"\"Submission is missing {n_missing_embeddings} embeddings.\"\"\")\n",
385 |     "    \n",
386 |     "    # Check that embeddings have the correct length\n",
387 |     "    if len(df.columns) != embedding_dim:\n",
388 |     "        raise ValueError(f\"\"\"Expected {embedding_dim} embedding dimensions, but provided embeddings have {len(df.columns)} dimensions.\"\"\")\n",
389 |     "\n",
390 |     "    # Convert columns to float\n",
391 |     "    try:\n",
392 |     "        for col in df.columns:\n",
393 |     "            df[col] = df[col].astype(float)\n",
394 |     "    except Exception as e:\n",
395 |     "        raise ValueError(f\"\"\"Failed to convert embedding values to float.\n",
396 |     "        Check embeddings for any not-allowed character, for example empty strings, letters, etc.\n",
397 |     "        Original error message: {e}\"\"\")\n",
398 |     "\n",
399 |     "    # Check if any NaNs \n",
400 |     "    if df.isna().any().any():\n",
401 |     "        raise ValueError(f\"\"\"Embeddings contain NaN values.\"\"\")\n",
402 |     "\n",
403 |     "    # Successful completion of the function\n",
404 |     "    return True"
405 |    ]
406 |   },
407 |   {
408 |    "cell_type": "code",
409 |    "execution_count": 12,
410 |    "metadata": {},
411 |    "outputs": [],
412 |    "source": [
413 |     "# We use the created embeddings as the list of all samples.\n",
414 |     "# This can be done since we are sure to have fully looped through the dataset.\n",
415 |     "# A better way would be to find all the IDs in the challenge data separately, e.g. from the dataloader.\n",
416 |     "embedding_ids = set(embeddings.keys())\n",
417 |     "embedding_dim = 1024\n",
418 |     "\n",
419 |     "# Test submission\n",
420 |     "assert test_submission(path_to_output_file, embedding_ids, embedding_dim)"
421 |    ]
422 |   }
423 |  ],
424 |  "metadata": {
425 |   "kernelspec": {
426 |    "display_name": "Python 3 (ipykernel)",
427 |    "language": "python",
428 |    "name": "python3"
429 |   },
430 |   "language_info": {
431 |    "codemirror_mode": {
432 |     "name": "ipython",
433 |     "version": 3
434 |    },
435 |    "file_extension": ".py",
436 |    "mimetype": "text/x-python",
437 |    "name": "python",
438 |    "nbconvert_exporter": "python",
439 |    "pygments_lexer": "ipython3",
440 |    "version": "3.11.0"
441 |   }
442 |  },
443 |  "nbformat": 4,
444 |  "nbformat_minor": 4
445 | }
446 | 
--------------------------------------------------------------------------------