├── .gitignore
├── data_utils
│   ├── .gitignore
│   ├── data_utils
│   │   ├── studies
│   │   │   ├── __init__.py
│   │   │   └── algonauts2025.py
│   │   ├── features
│   │   │   ├── __init__.py
│   │   │   ├── neuro.py
│   │   │   ├── subject.py
│   │   │   ├── text.py
│   │   │   ├── audio.py
│   │   │   └── video.py
│   │   ├── __init__.py
│   │   ├── infra
│   │   │   └── __init__.py
│   │   ├── download.py
│   │   ├── splitting.py
│   │   ├── helpers.py
│   │   ├── utils.py
│   │   ├── dataloader.py
│   │   ├── base.py
│   │   ├── data.py
│   │   ├── segments.py
│   │   └── events.py
│   └── pyproject.toml
├── modeling_utils
│   ├── modeling_utils
│   │   ├── __init__.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── transformer.py
│   │   │   ├── common.py
│   │   │   └── fmri_mlp.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── losses.py
│   │   │   └── base.py
│   │   ├── metrics
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── metrics.py
│   │   ├── optimizers
│   │   │   ├── __init__.py
│   │   │   └── base.py
│   │   └── utils.py
│   └── pyproject.toml
├── algonauts2025
│   ├── grids
│   │   ├── test_run.py
│   │   ├── run_grid.py
│   │   ├── run_ensemble.py
│   │   ├── defaults.py
│   │   └── average_submissions.py
│   ├── callbacks.py
│   ├── pl_module.py
│   ├── model.py
│   └── main.py
├── CONTRIBUTING.md
├── README.md
└── CODE_OF_CONDUCT.md

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
data_utils/build/
modeling_utils/build/
**/*.egg-info
**/*.pyc

--------------------------------------------------------------------------------
/data_utils/.gitignore:
--------------------------------------------------------------------------------
*.egg-info
.vscode
*/__pycache__
*.pyc
.cache/

--------------------------------------------------------------------------------
/data_utils/data_utils/studies/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

--------------------------------------------------------------------------------
/modeling_utils/modeling_utils/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .losses import LossConfig
from .metrics import MetricConfig
from .optimizers import OptimizerConfig

--------------------------------------------------------------------------------
/modeling_utils/pyproject.toml:
--------------------------------------------------------------------------------
[project]
name = "modeling_utils"
version = "0.1.0"
dependencies = [
    "pandas>=2.0.1",
    "numpy>=2.1",
    "scikit-learn>=0.21.2",
    "torch>=2.5.1",
    "pytest>=7.4.0",
    "pydantic>=2.5.0",
    "torchmetrics>=1.1.2",
    "x_transformers>=1.27.20",
    "lightning>=2.0.8",
    "wandb>=0.15.11",
]

--------------------------------------------------------------------------------
/data_utils/data_utils/features/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .audio import *
from .neuro import *
from .subject import *
from .text import *
from .video import VJEPA2

--------------------------------------------------------------------------------
/data_utils/pyproject.toml:
--------------------------------------------------------------------------------
[project]
name = "data_utils"
version = "0.1.0"
dependencies = [
    "pandas>=2.2.2",
    "numpy>=2.1",
    "pyarrow>=17.0.0",
    "mne>=1.4.0",
    "pybv>=0.7.6",
    "mne_bids>=0.16",
    "scikit-learn>=0.21.2",
    "torch>=2.5.1",
    "pytest>=7.4.0",
    "nibabel>=5.1.0",
    "tqdm>=4.65.0",
    "exca>=0.4.5",
]

--------------------------------------------------------------------------------
/modeling_utils/modeling_utils/models/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import warnings

import pydantic

from ..utils import all_subclasses
from .transformer import TransformerEncoderConfig

--------------------------------------------------------------------------------
/data_utils/data_utils/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .base import CACHE_FOLDER as CACHE_FOLDER
from .data import BaseData as BaseData
from .dataloader import CollateSegments as CollateSegments
from .dataloader import SegmentDataset as SegmentDataset

--------------------------------------------------------------------------------
/modeling_utils/modeling_utils/losses/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp

import pydantic

from ..utils import all_subclasses
from .base import BaseLossConfig

LossConfig = BaseLossConfig


def update_config_loss() -> None:
    global LossConfig

    from .base import BaseLossConfig

    LossConfig = tp.Annotated[
        tp.Union[tuple(all_subclasses(BaseLossConfig))],
        pydantic.Field(discriminator="name"),
    ]


update_config_loss()
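The `update_config_*` pattern above (repeated for metrics and optimizers below) rebuilds the exported alias as a pydantic discriminated union over every registered `BaseLossConfig` subclass, so a plain dict with a `name` key resolves to the right config class. A minimal sketch of what this enables, assuming `modeling_utils` is installed and `losses/base.py` (shown later in this listing) has registered `TorchLossConfig`:

```python
# Sketch only: resolving a dict to a concrete loss config through the
# discriminated union that update_config_loss() builds at import time.
import pydantic

from modeling_utils.losses import LossConfig

adapter = pydantic.TypeAdapter(LossConfig)
cfg = adapter.validate_python({"name": "MSELoss"})  # matches TorchLossConfig
loss_fn = cfg.build()  # an nn.MSELoss instance
```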
--------------------------------------------------------------------------------
/modeling_utils/modeling_utils/metrics/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp

import pydantic

from ..utils import all_subclasses
from .base import BaseMetricConfig

MetricConfig = BaseMetricConfig


def update_config_metric() -> None:
    global MetricConfig

    from .base import BaseMetricConfig

    MetricConfig = tp.Annotated[
        tp.Union[tuple(all_subclasses(BaseMetricConfig))],
        pydantic.Field(discriminator="name"),
    ]


update_config_metric()

--------------------------------------------------------------------------------
/modeling_utils/modeling_utils/optimizers/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp

import pydantic

from ..utils import all_subclasses
from .base import BaseOptimizerConfig, LightningOptimizerConfig, TorchLRSchedulerConfig

OptimizerConfig = BaseOptimizerConfig


def update_config_optimizer() -> None:
    global OptimizerConfig

    from .base import BaseOptimizerConfig

    OptimizerConfig = tp.Annotated[
        tp.Union[tuple(all_subclasses(BaseOptimizerConfig))],
        pydantic.Field(discriminator="name"),
    ]


update_config_optimizer()

--------------------------------------------------------------------------------
/data_utils/data_utils/infra/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import exca

if not hasattr(exca, "__version__"):
    raise RuntimeError("exca must be updated to version 0.2.0 or newer.")

from exca import ConfDict as ConfDict
from exca import MapInfra as MapInfra
from exca import TaskInfra as TaskInfra
from exca import helpers as helpers
from exca.base import DEFAULT_CHECK_SKIPS
from exca.cachedict import CacheDict as CacheDict


def _skip_new_event_types(key, val, prev):
    if "event_types" in key and not prev:
        return True
    return False


DEFAULT_CHECK_SKIPS.append(_skip_new_event_types)

--------------------------------------------------------------------------------
/algonauts2025/grids/test_run.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import logging
import os

from exca import ConfDict

from ..main import Experiment  # type: ignore
from .defaults import default_config

logging.getLogger("exca").setLevel(logging.DEBUG)


update = {
    "save_checkpoints": False,
    "n_epochs": 6,
    "infra.cluster": None,
    "infra.gpus_per_node": 1,
    "infra.mode": "force",
    "data.num_workers": 0,
    "data.study.query": "subject_timeline_index<10",
    "data.study.cache_all_timelines": False,
}


def test_run(config: dict) -> None:
    task = Experiment(**config)
    task.infra.clear_job()
    trainer = task.run()


if __name__ == "__main__":
    updated_config = ConfDict(default_config)
    updated_config.update(update)
    folder = os.path.join(updated_config["infra"]["folder"], "test")
    updated_config["infra"]["folder"] = folder
    if os.path.exists(folder):
        import shutil

        shutil.rmtree(folder)
    test_run(updated_config)
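The `update` dict above mixes dotted keys (`"infra.cluster"`) with plain ones; exca's `ConfDict` maps both forms onto the same nested structure, which is what lets the `__main__` block read `updated_config["infra"]["folder"]` afterwards. A minimal sketch of that behaviour, for illustration only:

```python
# Sketch: ConfDict treats "a.b" and {"a": {"b": ...}} as the same address.
from exca import ConfDict

cfg = ConfDict({"infra": {"mode": "cached", "gpus_per_node": 8}})
cfg.update({"infra.mode": "force"})  # dotted key reaches the nested field
assert cfg["infra"]["mode"] == "force"
```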
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# Contributing to algonauts-2025
We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License
By contributing to algonauts-2025, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.

--------------------------------------------------------------------------------
/algonauts2025/grids/run_grid.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from data_utils.infra import ConfDict
from modeling_utils.utils import run_grid

from ..main import Experiment  # type: ignore
from .defaults import PROJECT_NAME, SAVEDIR, default_config

GRID_NAME = "grid"

update = {
    "infra": {
        "cluster": "auto",
        "folder": SAVEDIR,
        "slurm_partition": "partition",
        "job_name": PROJECT_NAME,
    },
    "wandb_config.group": GRID_NAME,
    "save_checkpoints": False,
}

grid = {
    "data.layers": [
        [0, 0.5, 1],
        [0.5, 0.75, 1.0],
        [0.5, 1.0],
        [0, 0.2, 0.4, 0.6, 0.8, 1.0],
    ],
    "seed": list(range(5)),
}


if __name__ == "__main__":
    updated_config = ConfDict(default_config)
    updated_config.update(update)

    out = run_grid(
        Experiment,
        GRID_NAME,
        updated_config,
        grid,
        job_name_keys=["wandb_config.name", "infra.job_name"],
        combinatorial=True,
        overwrite=False,
        dry_run=False,
        infra_mode="force",
    )

--------------------------------------------------------------------------------
/modeling_utils/modeling_utils/losses/losses.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


class PearsonLoss(nn.Module):
    def __init__(self, reduction: str = "mean", dim: int = 1):
        super(PearsonLoss, self).__init__()
        self.reduction = reduction
        self.dim = dim

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x = x.transpose(0, self.dim)
        y = y.transpose(0, self.dim)
        x = x.reshape(x.shape[0], -1)
        y = y.reshape(y.shape[0], -1)

        x_mean = torch.mean(x, dim=1, keepdim=True)
        y_mean = torch.mean(y, dim=1, keepdim=True)
        x = x - x_mean
        y = y - y_mean

        cov = torch.sum(x * y, dim=1)

        x_std = torch.sqrt(torch.sum(x**2, dim=1))
        y_std = torch.sqrt(torch.sum(y**2, dim=1))

        pcc = cov / ((x_std * y_std) + 1e-8)

        loss = 1 - pcc
        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        else:
            raise ValueError(f"Invalid reduction: {self.reduction}")
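`PearsonLoss` computes `1 - r` per unit along `dim` and then reduces. A quick sanity check of that behaviour, assuming only `torch` and this package; highly correlated inputs should give a loss near zero:

```python
# Sketch: PearsonLoss on toy data (shapes follow the (batch, voxels) layout
# used elsewhere in this repo).
import torch

from modeling_utils.losses.losses import PearsonLoss

loss_fn = PearsonLoss(reduction="mean", dim=1)
pred = torch.randn(8, 1000)                  # (batch, voxels)
target = pred + 0.01 * torch.randn(8, 1000)  # strongly correlated targets
loss = loss_fn(pred, target)                 # mean over voxels of 1 - r
assert loss.item() < 0.05
```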
--------------------------------------------------------------------------------
/algonauts2025/grids/run_ensemble.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from data_utils.infra import ConfDict
from modeling_utils.utils import run_grid

from ..main import Experiment  # type: ignore
from .defaults import PROJECT_NAME, SAVEDIR, default_config

GRID_NAME = "model_soup"

update = {
    "infra": {
        "cluster": "auto",
        "folder": SAVEDIR,
        "slurm_partition": "partition",
        "job_name": PROJECT_NAME,
    },
    "wandb_config.group": GRID_NAME,
    "save_checkpoints": False,
    "seed": None,
    "patience": None,
}

grid = {
    "data.layers": [
        [0, 0.5, 1],
        [0.5, 0.75, 1.0],
        [0.5, 1.0],
        [0, 0.2, 0.4, 0.6, 0.8, 1.0],
    ],
    "loss.name": ["MSELoss", "PearsonLoss", "SmoothL1Loss", "HuberLoss"],
    "data.layer_aggregation": [None, "group_mean"],
    "brain_model_config.subject_embedding": [True, False],
    "brain_model_config.layer_aggregation": ["cat", "mean"],
    "brain_model_config.feature_aggregation": ["cat", "sum"],
    "brain_model_config.modality_dropout": [0.0, 0.2, 0.4],
}


if __name__ == "__main__":
    updated_config = ConfDict(default_config)
    updated_config.update(update)

    out = run_grid(
        Experiment,
        GRID_NAME,
        updated_config,
        grid,
        job_name_keys=["wandb_config.name", "infra.job_name"],
        combinatorial=True,
        n_randomly_sampled=1000,
        overwrite=False,
        dry_run=False,
        infra_mode="force",
    )
--------------------------------------------------------------------------------
/modeling_utils/modeling_utils/losses/base.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp
from inspect import isclass

import pydantic
from torch import nn
from torch.nn.modules.loss import _Loss

from data_utils.infra import helpers
from modeling_utils.utils import all_subclasses, convert_to_pydantic

from . import losses

custom_losses = [
    obj for obj in losses.__dict__.values() if isclass(obj) and issubclass(obj, nn.Module)
]

TORCHLOSS_NAMES = [loss_class.__name__ for loss_class in all_subclasses(_Loss)]


class BaseLossConfig(pydantic.BaseModel):

    model_config = pydantic.ConfigDict(extra="forbid")
    name: str

    def build(self) -> nn.Module:
        raise NotImplementedError


for loss_class in custom_losses:
    loss_class_name = loss_class.__name__
    config_cls = convert_to_pydantic(
        loss_class, loss_class_name, parent_class=BaseLossConfig
    )
    locals()[f"{loss_class_name}Config"] = config_cls


class TorchLossConfig(BaseLossConfig):
    name: tp.Literal[tuple(TORCHLOSS_NAMES)]

    kwargs: dict[str, tp.Any] = {}

    def model_post_init(self, log__: tp.Any) -> None:
        super().model_post_init(log__)

        helpers.validate_kwargs(getattr(nn, self.name), self.kwargs)

    def build(self, **kwargs: tp.Any) -> nn.Module:
        if overlap := set(self.kwargs) & set(kwargs):
            raise ValueError(
                f"Build kwargs overlap with config kwargs for keys: {overlap}."
            )
        kwargs = self.kwargs | kwargs
        return getattr(nn, self.name)(**kwargs)

--------------------------------------------------------------------------------
/modeling_utils/modeling_utils/models/transformer.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import typing as tp

import pydantic
import torch.nn as nn

logger = logging.getLogger(__name__)


class TransformerEncoderConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra="forbid")
    name: tp.Literal["TransformerEncoder"] = "TransformerEncoder"
    heads: int = 8
    depth: int = 12

    cross_attend: bool = False
    causal: bool = False
    attn_flash: bool = False
    attn_dropout: float = 0.1

    ff_mult: int = 4
    ff_dropout: float = 0.0

    use_scalenorm: bool = True
    use_rmsnorm: bool = False

    rel_pos_bias: bool = False
    alibi_pos_bias: bool = False
    rotary_pos_emb: bool = True
    rotary_xpos: bool = False

    residual_attn: bool = False
    scale_residual: bool = True
    layer_dropout: float = 0.0

    def build(self, dim: int) -> nn.Module:
        from x_transformers import Decoder, Encoder

        if dim % self.heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by the number of heads ({self.heads})"
            )
        if dim < 256:
            raise ValueError(
                f"dim ({dim}) is less than 256, which causes weird bug in x-transformers"
            )
        kwargs = self.model_dump()
        kwargs["attn_dim_head"] = dim // self.heads
        del kwargs["name"]
        del kwargs["causal"]
        if self.causal:
            return Decoder(dim=dim, **kwargs)
        else:
            return Encoder(dim=dim, **kwargs)
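`TransformerEncoderConfig.build` forwards its fields to x-transformers, deriving `attn_dim_head` from `dim // heads` and switching between `Encoder` and `Decoder` on `causal`. A minimal sketch of instantiating it, assuming `x_transformers` is installed:

```python
# Sketch: building the encoder; dim must be >= 256 and divisible by heads.
import torch

from modeling_utils.models.transformer import TransformerEncoderConfig

cfg = TransformerEncoderConfig(heads=8, depth=2)
encoder = cfg.build(dim=512)
x = torch.randn(1, 10, 512)  # (batch, sequence, dim)
y = encoder(x)               # same shape out
```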
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# [TRIBE: TRImodal Brain Encoder for whole-brain fMRI response prediction](https://www.arxiv.org/abs/2507.22229)

This repository can be used for training and evaluating encoding models to predict fMRI brain responses to naturalistic video stimuli.

## Create the environment

**1.** Create a conda environment for running and evaluating the model:

```bash
export ENVNAME=algonauts-2025
conda create -n $ENVNAME python=3.12 ipython -y
conda activate $ENVNAME

pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0

git clone https://github.com/facebookresearch/algonauts-2025.git
cd algonauts-2025/data_utils
pip install -e .
cd ../modeling_utils
pip install -e .

pip install transformers moviepy spacy nilearn Levenshtein "huggingface_hub[cli]" julius
```

**2.** Get access to the [LLAMA3.2-3B repository on HuggingFace](https://huggingface.co/meta-llama/Llama-3.2-3B). First, run:

```bash
huggingface-cli login
```

Then, create a `read` [token](https://huggingface.co/settings/tokens) and paste it when prompted.


**3.** Set the path to the Algonauts dataset, the directory where your results should be saved, and the Slurm partition to use. This can be done by setting the corresponding values in `algonauts2025/grids/defaults.py`, or alternatively, by adding the following to your shell's startup file (e.g., `.bashrc`, `.zshrc`, etc.):

```bash
export SAVEPATH="/your/save/directory"
export DATAPATH="/path/to/algonauts/dataset"
export SLURM_PARTITION="your-slurm-partition"
```

## Run a test training locally

```
python -m algonauts2025.grids.test_run
```

## Run a grid search on Slurm

```
python -m algonauts2025.grids.run_grid
```

## Train an ensemble of models

```
python -m algonauts2025.grids.run_ensemble
```

Training and results can be monitored using [Weights & Biases](https://docs.wandb.ai/quickstart). See the config key `wandb_config`.


## License

This repository is CC-BY-NC licensed, as found in the LICENSE file.
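For reference, the `wandb_config` block that the README points to has the shape below (the keys come from `algonauts2025/grids/defaults.py`, shown later in this listing; the values here are illustrative):

```python
wandb_config = {
    "project": "algonauts-2025",  # W&B project to log runs under
    "group": "my-grid",           # groups related runs in the W&B UI
    "log_model": False,           # do not upload model checkpoints
    "host": None,                 # use the default W&B host
}
```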
--------------------------------------------------------------------------------
/modeling_utils/modeling_utils/optimizers/base.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import pydantic
import torch
from torch import optim
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer

from data_utils.infra import helpers
from modeling_utils.utils import all_subclasses

TORCH_OPTIMIZER_NAMES = [
    cls.__name__ for cls in all_subclasses(Optimizer) if cls.__name__ != "NewCls"
]
TORCH_LR_SCHEDULER_NAMES = [cls.__name__ for cls in all_subclasses(LRScheduler)]


class BaseOptimizerConfig(pydantic.BaseModel):

    model_config = pydantic.ConfigDict(extra="forbid")
    name: str

    def build(self, params: tp.Iterable[torch.Tensor]) -> Optimizer:
        raise NotImplementedError


class TorchOptimizerConfig(BaseOptimizerConfig):
    name: tp.Literal[tuple(TORCH_OPTIMIZER_NAMES)]

    lr: float
    kwargs: dict[str, tp.Any] = {}

    def model_post_init(self, log__: tp.Any) -> None:
        super().model_post_init(log__)
        assert (
            "lr" not in self.kwargs
        ), "lr should be defined as a base parameter instead of within kwargs."

        helpers.validate_kwargs(getattr(optim, self.name), self.kwargs | {"params": None})

    def build(self, params: tp.Iterable[torch.Tensor]) -> Optimizer:
        return getattr(optim, self.name)(params, lr=self.lr, **self.kwargs)


class BaseLRSchedulerConfig(pydantic.BaseModel):

    model_config = pydantic.ConfigDict(extra="forbid")
    name: str

    def build(self, optimizer: Optimizer) -> LRScheduler:
        raise NotImplementedError


class TorchLRSchedulerConfig(BaseLRSchedulerConfig):
    name: tp.Literal[tuple(TORCH_LR_SCHEDULER_NAMES)]

    kwargs: dict[str, tp.Any] = {}

    def model_post_init(self, log__: tp.Any) -> None:
        super().model_post_init(log__)

        helpers.validate_kwargs(
            getattr(optim.lr_scheduler, self.name), self.kwargs | {"optimizer": None}
        )

    def build(self, optimizer: Optimizer, **build_kwargs: tp.Any) -> LRScheduler:
        return getattr(optim.lr_scheduler, self.name)(
            optimizer, **(self.kwargs | build_kwargs)
        )


class LightningOptimizerConfig(pydantic.BaseModel):

    model_config = pydantic.ConfigDict(extra="forbid")
    name: tp.Literal["LightningOptimizer"] = "LightningOptimizer"
    optimizer: TorchOptimizerConfig
    scheduler: TorchLRSchedulerConfig | None = None
    interval: tp.Literal["step", "epoch"] = "step"

    def build(
        self,
        params: tp.Iterable[torch.Tensor],
        **scheduler_build_kwargs: tp.Any,
    ) -> dict[str, tp.Any]:
        out = {"optimizer": self.optimizer.build(params)}
        if self.scheduler is not None:
            scheduler = self.scheduler.build(out["optimizer"], **scheduler_build_kwargs)
            out["lr_scheduler"] = {"scheduler": scheduler, "interval": self.interval}

        return out
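`LightningOptimizerConfig.build` returns the dict Lightning expects from `configure_optimizers`, forwarding extra kwargs such as `total_steps` to the scheduler. A minimal sketch with toy parameters, mirroring the `optim` block in `defaults.py` later in this listing:

```python
# Sketch: building an Adam + OneCycleLR pair from config dicts.
import torch

from modeling_utils.optimizers.base import LightningOptimizerConfig

cfg = LightningOptimizerConfig(
    optimizer={"name": "Adam", "lr": 1e-4},
    scheduler={"name": "OneCycleLR", "kwargs": {"max_lr": 1e-4, "pct_start": 0.1}},
)
params = [torch.nn.Parameter(torch.zeros(3))]
out = cfg.build(params, total_steps=100)
# out == {"optimizer": Adam(...),
#         "lr_scheduler": {"scheduler": OneCycleLR(...), "interval": "step"}}
```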
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
  advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
  address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@fb.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

--------------------------------------------------------------------------------
/data_utils/data_utils/download.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import subprocess
import typing as tp
from glob import glob
from pathlib import Path

import pydantic
from tqdm import tqdm


class Wildcard(pydantic.BaseModel):
    folder: str


class Datalad(pydantic.BaseModel):
    dset_dir: str | Path
    folder: str = "download"

    _dl_dir: Path = pydantic.PrivateAttr()

    def model_post_init(self, log__: tp.Any) -> None:
        super().model_post_init(log__)

        dset_dir = Path(self.dset_dir).resolve()
        if not dset_dir.parent.exists():
            raise ValueError(f"Parent folder must exist for {dset_dir}")
        dset_dir.mkdir(exist_ok=True)
        self._dl_dir = dset_dir / self.folder
        self._dl_dir.mkdir(exist_ok=True, parents=True)

    def get_success_file(self) -> Path:
        cls_name = self.__class__.__name__.lower()
        return self._dl_dir / f"{cls_name}_Algonauts2025_success_download.txt"

    @tp.final
    def download(self, overwrite: bool = False) -> None:
        if self.get_success_file().exists() and not overwrite:
            return
        print(f"Downloading Algonauts2025 to {self._dl_dir}...")
        self._download()
        self.get_success_file().write_text("success")

    folders: list[str | Wildcard] = []

    @classmethod
    def install_requirements(cls) -> None:
        subprocess.run(
            [
                "datalad-installer",
                "datalad",
                "git-annex",
            ]
        )

    @pydantic.computed_field
    @property
    def repo_name(self) -> str:
        repo_name = Path(
            "https://github.com/courtois-neuromod/algonauts_2025.competitors.git",
        ).name
        if Path(repo_name).suffix == ".git":
            repo_name = repo_name[:-4]
        return repo_name

    def _datalad(self, cmd: str, path: Path | str) -> None:
        proc = subprocess.run(
            cmd, cwd=str(path), capture_output=True, text=True, shell=True
        )
        if "install(error)" in proc.stdout:
            logging.warning("Potential error in datalad clone:\n> %s", proc.stdout)
        if proc.stderr:
            logging.warning("Potential error in datalad clone:\n> %s", proc.stderr)

    def _dl_item(self, cur_path: Path | str) -> None:
        cmd = f'datalad get "{cur_path}"'
        self._datalad(cmd, self._dl_dir / self.repo_name)

    def _download(self) -> None:
        self._datalad(
            "datalad clone https://github.com/courtois-neuromod/algonauts_2025.competitors.git",
            self._dl_dir,
        )

        folders = self.folders if self.folders else [Wildcard(folder="*")]

        all_folders: list[Path] = []
        for folder in folders:
            if isinstance(folder, Wildcard):
                all_folders += [
                    Path(str(p))
                    for p in glob(str(self._dl_dir / self.repo_name / folder.folder))
                ]
            else:
                all_folders += [self._dl_dir / self.repo_name / folder]

        print(f"Loading {len(all_folders)} folders: ", all_folders)

        for item in tqdm(all_folders, desc="Downloading Algonauts2025", ncols=100):
            if not item.is_dir():
                continue
            self._dl_item(item)

        print("\nDownloaded Dataset")
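`Datalad.download` is idempotent through the success-marker file, and `folders` restricts what gets fetched. A sketch of a partial download; the directory and glob pattern here are illustrative only:

```python
# Sketch: clone the competitors repo and fetch only matching subfolders.
from data_utils.download import Datalad, Wildcard

dl = Datalad(
    dset_dir="/path/to/algonauts",        # parent directory must already exist
    folders=[Wildcard(folder="*/fmri")],  # hypothetical glob pattern
)
dl.download()  # no-op if the success file is already present
```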
--------------------------------------------------------------------------------
/data_utils/data_utils/splitting.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import hashlib
import random
import typing as tp
from dataclasses import dataclass

import numpy as np
import pandas as pd

from . import events as event_module


@dataclass
class DeterministicSplitter:
    ratios: tp.Dict[str, float]
    seed: float = 0.0

    def __post_init__(self) -> None:
        assert all(ratio > 0 for ratio in self.ratios.values())
        assert np.allclose(
            sum(self.ratios.values()), 1.0
        ), f"the sum of ratios must be equal to 1. got {self.ratios}"

    def __call__(self, uid: str) -> str:
        hashed = int(hashlib.sha256(uid.encode()).hexdigest(), 16)
        rng = random.Random(hashed + self.seed)
        score = rng.random()

        cdf = np.cumsum(list(self.ratios.values()))
        names = list(self.ratios.keys())

        for idx, cdf_val in enumerate(cdf):
            if score < cdf_val:
                return names[idx]
        raise ValueError


def chunk_events(
    events: pd.DataFrame,
    event_type_to_chunk: tp.Literal["Sound", "Video"],
    event_type_to_use: str | None = None,
    min_duration: float | None = None,
    max_duration: float = np.inf,
):
    added_events: tp.List[tp.Dict] = []
    dropped_rows: tp.List[int] = []
    ns_event_type_to_chunk = getattr(event_module, event_type_to_chunk)
    assert hasattr(
        ns_event_type_to_chunk, "_split"
    ), f"Event type {event_type_to_chunk} is not splittable"
    if event_type_to_use is not None:
        assert "split" in events.columns, "Events must have a split column"

    for _, df in events.groupby("timeline"):
        df.sort_values("start", inplace=True)
        if event_type_to_use is None:
            timepoints: list[float] = np.arange(
                df.start.min(), df.stop.max(), max_duration
            ).tolist()
            if min_duration is not None:
                if df.stop.max() - timepoints[-1] < min_duration:
                    timepoints = timepoints[:-1]
        else:
            timepoints = []
            events_to_use = df.loc[events.type == event_type_to_use].copy()
            previous = events_to_use.copy().shift(1)
            split_change = events_to_use.split.astype(str) != previous.split.astype(str)
            events_to_use["section"] = np.cumsum(split_change.values)

            for _, section in events_to_use.groupby("section"):
                start, end = (
                    section.iloc[0].start,
                    section.iloc[-1].start + section.iloc[-1].duration,
                )
                timepoints.extend(np.arange(start, end, max_duration))

        events_to_chunk = df.loc[events.type == event_type_to_chunk]
        dropped_rows.extend(events_to_chunk.index)
        for row in events_to_chunk.itertuples():
            event_to_chunk = ns_event_type_to_chunk.from_dict(row)
            new_events = event_to_chunk._split(
                [t - event_to_chunk.start for t in timepoints], min_duration
            )

            for new_event in new_events:
                new_event_dict = new_event.to_dict()

                for k, v in row._asdict().items():
                    if k not in new_event_dict:
                        new_event_dict[k] = v
                added_events.append(new_event_dict)

    out_events = events.copy()
    out_events.drop(dropped_rows, inplace=True)
    out_events = pd.concat([out_events, pd.DataFrame(added_events)])
    out_events.reset_index(drop=True, inplace=True)
    return out_events
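`DeterministicSplitter` hashes the uid before seeding the RNG, so the assignment is stable across runs and machines. A quick sketch:

```python
# Sketch: stable 80/20 split assignment keyed on a timeline uid.
from data_utils.splitting import DeterministicSplitter

splitter = DeterministicSplitter(ratios={"train": 0.8, "val": 0.2})
split = splitter("friends-s01e02")
assert split == splitter("friends-s01e02")  # deterministic per uid
```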
--------------------------------------------------------------------------------
/data_utils/data_utils/helpers.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import concurrent.futures
import logging
import typing as tp

import pandas as pd

from data_utils import events, segments

logger = logging.getLogger(__name__)
TypesParam = str | tp.Sequence[str] | tp.Type[events.Event] | events.EventTypesHelper


def extract_events(obj: tp.Any, types: TypesParam | None = None) -> list[events.Event]:
    helper: events.EventTypesHelper | None = None
    if isinstance(types, events.EventTypesHelper):
        helper = types
    elif types is not None:
        helper = events.EventTypesHelper(types)

    if isinstance(obj, (list, tuple)):
        if not obj:
            return []
        if isinstance(obj[0], events.Event):
            if helper is not None:
                obj = [e for e in obj if isinstance(e, helper.classes)]
            return obj
    if isinstance(obj, pd.DataFrame):
        if helper is not None:
            obj = obj.loc[obj.type.isin(helper.names), :]
        unknown = set(obj.type) - set(events.Event._CLASSES)
        if unknown:
            logger.warning("Ignoring unknown event types: %s", unknown)
            obj = obj.loc[~obj.type.isin(unknown), :]

        num = len(obj)
        iterable = (obj.iloc[k, :] for k in range(num)) if num <= 2 else obj.itertuples()
        out = [events.Event.from_dict(r) for r in iterable]
        for i, e in zip(obj.index, out):
            e._index = i

        return out
    if isinstance(obj, events.Event):
        obj = [obj]
    elif isinstance(obj, dict):
        obj = [events.Event.from_dict(obj)]
    if not isinstance(obj, (list, tuple)):
        raise NotImplementedError(f"Conversion of {type(obj)} is not supported")
    if not obj:
        return []
    if isinstance(obj[0], segments.Segment):
        event_dict = dict()
        for segment in obj:
            event_dict.update({id(e): e for e in segment.ns_events})
        obj = list(event_dict.values())
    if not isinstance(obj[0], events.Event):
        raise NotImplementedError(f"Unexpected list of {type(obj[0])} is not supported")
    return extract_events(obj, types=helper)


def prepare_features(
    features: list[tp.Any] | dict[str, tp.Any],
    events: pd.DataFrame | tp.Sequence[events.Event] | tp.Sequence[segments.Segment],
) -> None:
    events = extract_events(events)

    feature_list = list(features.values()) if isinstance(features, dict) else features
    features_using_slurm = [
        feature
        for feature in feature_list
        if hasattr(feature, "infra")
        and getattr(feature.infra, "cluster", None) == "slurm"
    ]
    other_features = [
        feature for feature in feature_list if feature not in features_using_slurm
    ]
    slurm_names = ", ".join(
        feature.__class__.__name__ for feature in features_using_slurm
    )
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = []
        for feature in features_using_slurm:
            futures.append(executor.submit(feature.prepare, events))
            futures[-1].__dict__["_name"] = feature.__class__.__name__
        if features_using_slurm:
            logger.info(
                f"Started parallel preparation of features {slurm_names} on slurm"
            )
        for feature in other_features:
            logger.info(f"Preparing feature: {feature.__class__.__name__}")
            feature.prepare(events)
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                name = future.__dict__.get("_name", "UNKNOWN")
                logger.warning("Error occurred while preparing feature %s: %s", name, e)
                raise
--------------------------------------------------------------------------------
/modeling_utils/modeling_utils/metrics/base.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import typing as tp
from inspect import isclass

import pandas as pd
import pydantic
import torch
import torch.nn as nn
import torchmetrics
from data_utils.infra import helpers
from modeling_utils.metrics import metrics
from modeling_utils.utils import all_subclasses, convert_to_pydantic
from torchmetrics import Metric

custom_metrics = [
    obj for obj in metrics.__dict__.values() if isclass(obj) and issubclass(obj, Metric)
]


class MultidimPearsonCorrCoef(torchmetrics.PearsonCorrCoef):

    def compute(self):
        return super().compute().mean()


TORCHMETRICS_NAMES = {
    metric_class.__name__: metric_class
    for metric_class in all_subclasses(Metric)
    if metric_class not in custom_metrics
}


class GroupedMetric(Metric):
    def __init__(self, metric_name: str, kwargs: dict[str, tp.Any] | None = None) -> None:
        super().__init__()
        if kwargs is None:
            kwargs = {}
        if metric_name in TORCHMETRICS_NAMES:
            self.base_metric_cls = TORCHMETRICS_NAMES[metric_name]
        else:
            assert hasattr(metrics, metric_name), f"Metric {metric_name} not found"
            self.base_metric_cls = getattr(metrics, metric_name)
        self.metric_kwargs = kwargs
        self.metrics = torch.nn.ModuleDict()

    def update(
        self,
        preds: torch.Tensor,
        target: torch.Tensor,
        groups: tp.Optional[torch.Tensor] = None,
    ) -> None:
        if groups is None:
            groups = torch.zeros(preds.shape[0])
        else:
            groups = groups.flatten()
            assert (
                len(groups) == preds.shape[0]
            ), f"Groups must be the same shape as preds/target, got {groups.shape} and {preds.shape}"

        groups_df = pd.DataFrame({"label": groups.tolist()})
        for group_id, group in groups_df.groupby("label", sort=False):
            mask = group.index.to_numpy()
            group_preds = preds[mask]
            group_target = target[mask]

            group_key = str(group_id)
            if group_key not in self.metrics:
                self.metrics[group_key] = self.base_metric_cls(**self.metric_kwargs)
                self.metrics[group_key] = self.metrics[group_key].to(preds.device)

            self.metrics[group_key].update(group_preds, group_target)

    def compute(self) -> dict[str, float]:
        return {
            group_id: metric.compute().item() for group_id, metric in self.metrics.items()
        }

    def reset(self) -> None:
        for metric in self.metrics.values():
            metric.reset()

    def __repr__(self) -> str:
        return f"GroupedMetric({self.base_metric_cls.__name__})"


class BaseMetricConfig(pydantic.BaseModel):

    model_config = pydantic.ConfigDict(extra="forbid")

    log_name: str
    name: str

    def build(self) -> nn.Module:
        raise NotImplementedError


for metric_class in custom_metrics + [GroupedMetric]:
    metric_class_name = metric_class.__name__
    config_cls = convert_to_pydantic(
        metric_class,
        metric_class_name,
        parent_class=BaseMetricConfig,
        exclude_from_build=["log_name"],
    )
    locals()[f"{metric_class_name}Config"] = config_cls


class TorchMetricConfig(BaseMetricConfig):
    name: tp.Literal[tuple(TORCHMETRICS_NAMES.keys())]

    kwargs: dict[str, tp.Any] = {}

    def model_post_init(self, log__: tp.Any) -> None:
        super().model_post_init(log__)

        helpers.validate_kwargs(TORCHMETRICS_NAMES[self.name], self.kwargs)

    def build(self) -> nn.Module:
        return TORCHMETRICS_NAMES[self.name](**self.kwargs)
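`GroupedMetric` lazily instantiates one base metric per group label and reports a dict keyed by the stringified label, which is how the per-subject Pearson scores in `defaults.py` are produced. A minimal sketch on toy data:

```python
# Sketch: per-group Pearson correlation; group ids stand in for subject ids.
import torch

from modeling_utils.metrics.base import GroupedMetric

metric = GroupedMetric("MultidimPearsonCorrCoef", kwargs={"num_outputs": 4})
preds, target = torch.randn(6, 4), torch.randn(6, 4)
groups = torch.tensor([0, 0, 0, 1, 1, 1])  # one label per row
metric.update(preds, target, groups=groups)
print(metric.compute())  # e.g. {"0": 0.12, "1": -0.08}
```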
--------------------------------------------------------------------------------
/algonauts2025/callbacks.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path

import numpy as np
from data_utils.segments import SegmentCreator, _prepare_strided_windows
from lightning.pytorch import Callback

SUBJECT_MAPPINGS = {0: 1, 1: 2, 2: 3, 3: 5}


class JitterWindows(Callback):
    def __init__(
        self,
        start_jitter_amount: float = 0.0,
        duration_jitter_amount: float = 0.0,
    ):
        self.start_jitter_amount = start_jitter_amount
        self.duration_jitter_amount = duration_jitter_amount

    def on_train_epoch_start(self, trainer, pl_module):
        start_jitter = (np.random.rand() * 2 - 1) * self.start_jitter_amount
        duration_jitter = (np.random.rand() * 2 - 1) * self.duration_jitter_amount
        segments = trainer.train_dataloader.dataset.segments
        new_segments = []
        creators = SegmentCreator.from_obj(segments)
        for creator in creators.values():
            starts, durations = _prepare_strided_windows(
                creator.starts.min() - 4.47 + start_jitter,
                creator.stops.max() - 4.47 + start_jitter,
                149,
                149,
                drop_incomplete=False,
            )
            for start_, duration_ in zip(starts, durations):
                seg = creator.select(start=start_, duration=duration_)
                seg._trigger = start_
                new_segments.append(seg)
        assert len(segments) == len(new_segments)
        trainer.train_dataloader.dataset.segments = new_segments


class Benchmark(Callback):

    def __init__(self, root_data_dir):
        self.root_data_dir = Path(root_data_dir)
        self.submission_dict = {}

    def on_test_epoch_start(self, trainer, pl_module):
        self.submission_dict = {}

    def on_test_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        y_pred, _ = outputs  # we do not have the ground truth
        overlap_trs = 0  # integer: number of leading TRs shared with the previous chunk
        for i, segment in enumerate(batch.segments):
            subject = segment.events.subject.unique()[0]
            chunk = segment.events.chunk.unique()[0]
            pred = y_pred[i].cpu().numpy()  # 1000, T
            pred = pred.T
            subject = subject.split("/")[1]
            chunk = "s07" + chunk.split(":")[1]
            if subject not in self.submission_dict:
                self.submission_dict[subject] = {}
            if chunk not in self.submission_dict[subject]:
                self.submission_dict[subject][chunk] = []
            else:
                pred = pred[overlap_trs:]  # remove the overlap except on the first chunk
            self.submission_dict[subject][chunk].append(pred)

    def on_test_epoch_end(self, trainer, pl_module):
        for subject in self.submission_dict.keys():
            samples_file = (
                self.root_data_dir
                / f"algonauts_2025.competitors/fmri/{subject}/target_sample_number/{subject}_friends-s7_fmri_samples.npy"
            )
            target_sample_number = np.load(samples_file, allow_pickle=True).item()
            for chunk, sample_number in target_sample_number.items():
                result = np.concatenate(self.submission_dict[subject][chunk], axis=0)
                if len(result) < sample_number:
                    raise ValueError(
                        f"Warning: {len(result)} predictions for {chunk} but expected at least {sample_number}"
                    )
                self.submission_dict[subject][chunk] = result[:sample_number]

        # save
        submission_name = "submission.npy"
        submission_path = Path(trainer.logger.save_dir) / submission_name
        np.save(submission_path, self.submission_dict)
        import zipfile

        try:
            with zipfile.ZipFile(submission_path.with_suffix(".zip"), "w") as zipf:
                zipf.write(submission_path, arcname=submission_name)
            print(f"Saved submission to {submission_path.with_suffix('.zip')}")
        except Exception:
            print(f"Failed to save submission to {submission_path.with_suffix('.zip')}")
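Both callbacks plug into a standard Lightning trainer; a minimal sketch of wiring them up (the data directory is illustrative):

```python
# Sketch: jitter training windows each epoch and write a submission at test time.
import lightning.pytorch as pl

from algonauts2025.callbacks import Benchmark, JitterWindows

trainer = pl.Trainer(
    max_epochs=15,
    callbacks=[
        JitterWindows(start_jitter_amount=1.0),
        Benchmark(root_data_dir="/path/to/data"),
    ],
)
```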
--------------------------------------------------------------------------------
/algonauts2025/grids/defaults.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from pathlib import Path

PROJECT_NAME = "algonauts-2025"

# Placeholders: point these at your partition, dataset, and save directories
# (or export SAVEPATH, DATAPATH, and SLURM_PARTITION as described in the README).
SLURM_PARTITION = "partition"
DATADIR = "save_dir"
BASEDIR = os.path.expandvars("save_dir")

CACHEDIR = os.path.join(BASEDIR, "cache", PROJECT_NAME)
SAVEDIR = os.path.join(BASEDIR, "results", PROJECT_NAME)

for path in [CACHEDIR, SAVEDIR, DATADIR]:
    Path(path).mkdir(parents=True, exist_ok=True)

text_feature = {
    "name": "LLAMA3p2",
}
video_feature = {
    "name": "VJEPA2",
}
audio_feature = {
    "name": "Wav2VecBert",
}
neuro_feature = {
    "name": "Fmri",
}
for feature in [
    text_feature,
    video_feature,
    audio_feature,
    neuro_feature,
]:
    feature["infra"] = {
        "folder": CACHEDIR,
        "keep_in_ram": True,
        "mode": "cached",
        "version": "final",
    }

default_config = {
    "infra": {
        "cluster": "slurm",  # set to None to run the example locally
        "folder": SAVEDIR,
    },
    "data": {
        "num_workers": 20,
        "study": {
            "path": Path(DATADIR) / "algonauts2025",
            "query": None,
            "infra": {
                "folder": CACHEDIR,
            },
            "enhancers": {
                "addtext": {"name": "AddText"},
                "addsentence": {
                    "name": "AddSentenceToWords",
                    "max_unmatched_ratio": 0.05,
                },
                "addcontext": {
                    "name": "AddContextToWords",
                    "sentence_only": False,
                    "max_context_len": 1024,
                },
                "removemissing": {"name": "RemoveMissing"},
                "extractaudio": {"name": "ExtractAudioFromVideo"},
                "chunkevents": {
                    "name": "ChunkEvents",
                    "event_type_to_chunk": "Sound",
                    "max_duration": 60,
                    "min_duration": 30,
                },
            },
        },
        "neuro": neuro_feature,
        "text_feature": text_feature,
        "video_feature": video_feature,
        "audio_feature": audio_feature,
        "layers": [0.5, 0.75, 1.0],
        "layer_aggregation": "group_mean",
    },
    "wandb_config": {
        "log_model": False,
        "project": "algonauts-2025",
        "group": "default",
        "host": None,
    },
    "brain_model_config": {
        "name": "FmriEncoder",
        "modality_dropout": 0.3,
        "feature_aggregation": "cat",
        "layer_aggregation": "cat",
        "subject_embedding": False,
    },
    "metrics": [
        {
            "log_name": "pearson",
            "name": "MultidimPearsonCorrCoef",
            "kwargs": {"num_outputs": 1000},
        },
        {
            "log_name": "subj_pearson",
            "name": "GroupedMetric",
            "metric_name": "MultidimPearsonCorrCoef",
            "kwargs": {"num_outputs": 1000},
        },
        {
            "log_name": "retrieval_top1",
            "name": "TopkAcc",
            "topk": 1,
        },
    ],
    "loss": {"name": "MSELoss"},
    "optim": {
        "optimizer": {
            "name": "Adam",
            "lr": 1e-4,
            "kwargs": {
                "weight_decay": 0.0,
            },
        },
        "scheduler": {
            "name": "OneCycleLR",
            "kwargs": {
                "max_lr": 1e-4,
                "pct_start": 0.1,
            },
        },
    },
    "n_epochs": 15,
    "limit_train_batches": None,
    "patience": None,
    "enable_progress_bar": True,
    "log_every_n_steps": 5,
    "fast_dev_run": False,
    "seed": 33,
}


if __name__ == "__main__":
    from ..main import Experiment

    exp = Experiment(
        **default_config,
    )

    exp.infra.clear_job()
    out = exp.run()
    print(out)
--------------------------------------------------------------------------------
/algonauts2025/pl_module.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import typing as tp
from pathlib import Path

import lightning.pytorch as pl
from data_utils.dataloader import SegmentData
from einops import rearrange
from modeling_utils.optimizers import OptimizerConfig
from torch import nn
from torchmetrics import Metric


class BrainModule(pl.LightningModule):

    def __init__(
        self,
        model: nn.Module,
        loss: nn.Module,
        optim_config: OptimizerConfig,
        metrics: dict[str, Metric],
        max_epochs: int = 100,
        checkpoint_path: Path | None = None,
        config: dict[str, tp.Any] | None = None,
    ) -> None:
        super().__init__()
        self.model = model
        self.checkpoint_path = checkpoint_path
        self.config = config

        # Optimizer
        self.optim_config = optim_config
        self.max_epochs = max_epochs

        self.loss = loss
        self.metrics = metrics

    def forward(self, batch):
        return self.model(batch)

    def _run_step(self, batch: SegmentData, batch_idx, step_name):
        y_true = batch.data["fmri"]  # B, D, T
        y_pred = self.forward(batch)  # B, D, T
        if step_name == "val":
            y_true = y_true[:, :, 0:]
            y_pred = y_pred[:, :, 0:]
        subject_ids_flat = batch.data["subject_id"].repeat_interleave(y_pred.shape[2], 0)

        y_pred_flat = rearrange(y_pred, "b d t -> (b t) d")
        y_true_flat = rearrange(y_true, "b d t -> (b t) d")
        loss = self.loss(y_pred_flat, y_true_flat)
        log_kwargs = {
            "on_step": True if step_name == "train" else False,
            "on_epoch": True,
            "logger": True,
            "prog_bar": True,
            "batch_size": y_pred.shape[0],
        }

        self.log(
            f"{step_name}/loss",
            loss,
            **log_kwargs,
        )

        # Compute metrics
        for metric_name, metric in self.metrics.items():
            if metric_name.startswith(step_name):
                if "grouped" in metric.__class__.__name__.lower():
                    metric.update(y_pred_flat, y_true_flat, groups=subject_ids_flat)
                else:
                    if "retrieval" in metric_name:
                        metric.update(y_pred.mean(dim=-1), y_true.mean(dim=-1))
                    else:
                        metric.update(y_pred_flat, y_true_flat)
                self.log(
                    metric_name,
                    metric,
                    **log_kwargs,
                )
        return loss, y_pred.detach().cpu(), y_true.detach().cpu()

    def on_val_or_test_epoch_end(self, step_name: str) -> None:
        for metric_name, metric in self.metrics.items():
            if metric_name.startswith(step_name):
                if "grouped" in metric.__class__.__name__.lower():
                    metric_dict = {
                        metric_name + "/" + k: v for k, v in metric.compute().items()
                    }
                    self.log_dict(metric_dict)

    def on_validation_epoch_end(self) -> None:
        self.on_val_or_test_epoch_end("val")
        return super().on_validation_epoch_end()

    def on_test_epoch_end(self) -> None:
        self.on_val_or_test_epoch_end("test")
        return super().on_test_epoch_end()

    def training_step(self, batch: SegmentData, batch_idx):
        loss, _, _ = self._run_step(batch, batch_idx, step_name="train")
        return loss

    def validation_step(self, batch: SegmentData, batch_idx):
        _, y_pred, y_true = self._run_step(batch, batch_idx, step_name="val")
        return y_pred, y_true

    def test_step(self, batch: SegmentData, batch_idx):
        _, y_pred, y_true = self._run_step(batch, batch_idx, step_name="test")
        return y_pred, y_true

    def configure_optimizers(self):
        optim_config = self.optim_config.copy()
        unfrozen_params = [p for p in self.parameters() if p.requires_grad]
        optimizer = optim_config.build(
            unfrozen_params, total_steps=self.trainer.estimated_stepping_batches
        )
        return optimizer
--------------------------------------------------------------------------------
/data_utils/data_utils/features/neuro.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import typing as tp
import warnings

import numpy as np
import pandas as pd
import pydantic
import torch
from tqdm import tqdm

import data_utils as du
from data_utils.base import Frequency, TimedArray
from data_utils.events import Event, EventTypesHelper
from data_utils.infra import MapInfra
from data_utils.segments import Segment

logger = logging.getLogger(__name__)


class Fmri(pydantic.BaseModel):
    _event_types_helper: EventTypesHelper
    _missing_default: torch.Tensor | None = None
    model_config = pydantic.ConfigDict(protected_namespaces=(), extra="forbid")
    name: tp.Literal["Fmri"] = "Fmri"
    infra: MapInfra = MapInfra()

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: tp.Any) -> None:
        super().__pydantic_init_subclass__(**kwargs)
        super().__init_subclass__()

    def model_post_init(self, log__: tp.Any) -> None:
        super().model_post_init(log__)
        self._event_types_helper = EventTypesHelper("Fmri")

    def prepare(
        self, obj: pd.DataFrame | tp.Sequence[Event] | tp.Sequence[Segment]
    ) -> None:
        from data_utils import helpers

        events = helpers.extract_events(obj, types=self._event_types_helper)

        self._get_data(events)
        if events:
            self(
                events[0],
                start=events[0].start,
                duration=0.001,
                trigger=events[0].to_dict(),
            )
None: 98 | self.on_val_or_test_epoch_end("val") 99 | return super().on_validation_epoch_end() 100 | 101 | def on_test_epoch_end(self) -> None: 102 | self.on_val_or_test_epoch_end("test") 103 | return super().on_test_epoch_end() 104 | 105 | def training_step(self, batch: SegmentData, batch_idx): 106 | loss, _, _ = self._run_step(batch, batch_idx, step_name="train") 107 | return loss 108 | 109 | def validation_step(self, batch: SegmentData, batch_idx): 110 | _, y_pred, y_true = self._run_step(batch, batch_idx, step_name="val") 111 | return y_pred, y_true 112 | 113 | def test_step(self, batch: SegmentData, batch_idx): 114 | _, y_pred, y_true = self._run_step(batch, batch_idx, step_name="test") 115 | return y_pred, y_true 116 | 117 | def configure_optimizers(self): 118 | optim_config = self.optim_config.copy() 119 | unfrozen_params = [p for p in self.parameters() if p.requires_grad] 120 | optimizer = optim_config.build( 121 | unfrozen_params, total_steps=self.trainer.estimated_stepping_batches 122 | ) 123 | return optimizer 124 | -------------------------------------------------------------------------------- /data_utils/data_utils/features/neuro.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import logging 7 | import typing as tp 8 | import warnings 9 | 10 | import numpy as np 11 | import pandas as pd 12 | import pydantic 13 | import torch 14 | from tqdm import tqdm 15 | 16 | import data_utils as du 17 | from data_utils.base import Frequency, TimedArray 18 | from data_utils.events import Event, EventTypesHelper 19 | from data_utils.infra import MapInfra 20 | from data_utils.segments import Segment 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class Fmri(pydantic.BaseModel): 26 | _event_types_helper: EventTypesHelper 27 | _missing_default: torch.Tensor | None = None 28 | model_config = pydantic.ConfigDict(protected_namespaces=(), extra="forbid") 29 | name: tp.Literal["Fmri"] = "Fmri" 30 | infra: MapInfra = MapInfra() 31 | 32 | 33 | @classmethod 34 | def __pydantic_init_subclass__(cls, **kwargs: tp.Any) -> None: 35 | super().__pydantic_init_subclass__(**kwargs) 36 | 37 | super().__init_subclass__() 38 | 39 | def model_post_init(self, log__: tp.Any) -> None: 40 | super().model_post_init(log__) 41 | self._event_types_helper = EventTypesHelper("Fmri") 42 | 43 | def prepare( 44 | self, obj: pd.DataFrame | tp.Sequence[Event] | tp.Sequence[Segment] 45 | ) -> None: 46 | from data_utils import helpers 47 | 48 | events = helpers.extract_events(obj, types=self._event_types_helper) 49 | 50 | self._get_data(events) 51 | if events: 52 | 53 | self( 54 | events[0], 55 | start=events[0].start, 56 | duration=0.001, 57 | trigger=events[0].to_dict(), 58 | ) 59 | 60 | def __call__( 61 | self, 62 | events: tp.Any, 63 | start: float, 64 | duration: float, 65 | trigger: float | dict[str, tp.Any] | None = None, 66 | ) -> torch.Tensor: 67 | _input_events = events 68 | 69 | from data_utils import helpers 70 | 71 | assert duration >= 0.0, f"{duration} must be >= 0." 
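        # Times below are in seconds; this feature is sampled at the fMRI
        # acquisition frequency of 1/TR, with a repetition time TR = 1.49 s
        # (the same TR passed to nilearn.signal.clean further down).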
72 | event_types = self._event_types_helper.classes 73 | name = self.__class__.__name__ 74 | 75 | events = helpers.extract_events(events, types=self._event_types_helper) 76 | 77 | if not events and self._missing_default is not None: 78 | default = self._missing_default 79 | freq = Frequency(1/1.49) 80 | if freq: 81 | n_times = max(1, freq.to_ind(duration)) 82 | reps = [1 for _ in range(default.ndim)] + [n_times] 83 | default = default.unsqueeze(-1).repeat(reps) 84 | return default 85 | 86 | 87 | events = events[:1] 88 | tarrays = list( 89 | self._get_timed_arrays(events=events, start=start, duration=duration) 90 | ) 91 | 92 | time_info: dict[str, tp.Any] = { 93 | "start": start, 94 | "frequency": 1/1.49, 95 | "duration": duration, 96 | } 97 | out = TimedArray(aggregation="sum", **time_info) 98 | for ta in tarrays: 99 | out += ta 100 | tensor = torch.from_numpy(out.data) 101 | if not tensor.ndim: 102 | tensor = tensor.unsqueeze(0) 103 | 104 | if self._missing_default is None: 105 | 106 | shape = tuple(tensor.shape[: -1]) 107 | self._missing_default = torch.zeros(*shape, dtype=tensor.dtype) 108 | return tensor 109 | 110 | def _exclude_from_cache_uid(self) -> list[str]: 111 | return [ 112 | "offset", 113 | ] 114 | 115 | def _preprocess_event(self, event: du.events.Fmri) -> np.ndarray: 116 | rec = event.read() 117 | data = rec.get_fdata() 118 | 119 | import nilearn.signal 120 | 121 | data = data.T 122 | 123 | shape = data.shape 124 | data = nilearn.signal.clean( 125 | data.reshape(shape[0], -1), 126 | detrend=False, 127 | high_pass=None, 128 | t_r=1.49, 129 | standardize="zscore_sample", 130 | ) 131 | data = data.reshape(shape).T 132 | return data.astype(np.float32) 133 | 134 | @infra.apply( 135 | item_uid=lambda e: str(e.filepath), 136 | exclude_from_cache_uid=_exclude_from_cache_uid, 137 | cache_type="NumpyMemmapArray", 138 | ) 139 | def _get_data(self, events: tp.List[du.events.Fmri]) -> tp.Iterable[np.ndarray]: 140 | for event in tqdm(events, disable=len(events) < 2, desc="Computing fmri data"): 141 | yield self._preprocess_event(event) 142 | 143 | def _get_timed_arrays( 144 | self, events: list[du.events.Fmri], start: float, duration: float 145 | ) -> tp.Iterable[TimedArray]: 146 | freq = events[0].frequency 147 | for event, data in zip(events, self._get_data(events)): 148 | yield TimedArray( 149 | data=data, 150 | frequency=freq, 151 | start=event.start - 4.47, 152 | duration=event.duration, 153 | ) 154 | -------------------------------------------------------------------------------- /modeling_utils/modeling_utils/models/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
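"""Common building blocks shared by the brain models.

A minimal usage sketch of `SubjectLayers` (shapes are illustrative): each
batch element is projected with the weight matrix of its own subject.

    import torch

    layer = SubjectLayers(in_channels=8, out_channels=4, n_subjects=3, bias=True)
    x = torch.randn(2, 8, 10)        # (batch, channels, time)
    subjects = torch.tensor([0, 2])  # one subject index per batch element
    out = layer(x, subjects)         # -> (2, 4, 10)
"""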
6 | import typing as tp 7 | 8 | import pydantic 9 | import torch 10 | from torch import nn 11 | from torchvision.ops import MLP 12 | 13 | 14 | class SubjectLayers(nn.Module): 15 | 16 | def __init__( 17 | self, 18 | in_channels: int, 19 | out_channels: int, 20 | n_subjects: int, 21 | bias: bool = False, 22 | init_id: bool = False, 23 | average_subjects: bool = False, 24 | ): 25 | super().__init__() 26 | self.weights = nn.Parameter(torch.empty(n_subjects, in_channels, out_channels)) 27 | self.bias = nn.Parameter(torch.empty(n_subjects, out_channels)) if bias else None 28 | if init_id: 29 | if in_channels != out_channels: 30 | raise ValueError( 31 | "in_channels and out_channels must be the same for identity initialization." 32 | ) 33 | self.weights.data[:] = torch.eye(in_channels)[None] 34 | if self.bias is not None: 35 | self.bias.data[:] = 0 36 | else: 37 | self.weights.data.normal_() 38 | if self.bias is not None: 39 | self.bias.data.normal_() 40 | self.weights.data *= 1 / in_channels**0.5 41 | if self.bias is not None: 42 | self.bias.data *= 1 / in_channels**0.5 43 | self.average_subjects = average_subjects 44 | 45 | def forward( 46 | self, 47 | x: torch.Tensor, 48 | subjects: torch.Tensor, 49 | ) -> torch.Tensor: 50 | 51 | B, C, T = x.shape 52 | N, C, D = self.weights.shape 53 | assert ( 54 | subjects.max() < N 55 | ), "Subject index higher than number of subjects used to initialize the weights." 56 | if self.average_subjects: 57 | weights = self.weights.mean(dim=0).expand(B, C, D) 58 | if self.bias is not None: 59 | bias = self.bias.mean(dim=0).view(1, D, 1).expand(B, D, 1) 60 | else: 61 | weights = self.weights.index_select(0, subjects.flatten()) 62 | if self.bias is not None: 63 | bias = self.bias.index_select(0, subjects.flatten()).view(B, D, 1) 64 | out = torch.einsum("bct,bcd->bdt", x, weights) 65 | if self.bias is not None: 66 | out += bias 67 | return out 68 | 69 | def __repr__(self): 70 | S, C, D = self.weights.shape 71 | return f"SubjectLayers({C}, {D}, {S})" 72 | 73 | 74 | class LayerScale(nn.Module): 75 | 76 | def __init__(self, channels: int, init: float = 0.1, boost: float = 5.0): 77 | super().__init__() 78 | self.scale = nn.Parameter(torch.zeros(channels)) 79 | self.scale.data[:] = init / boost 80 | self.boost = boost 81 | 82 | def forward(self, x): 83 | return (self.boost * self.scale[:, None]) * x 84 | 85 | 86 | class MlpConfig(pydantic.BaseModel): 87 | 88 | model_config = pydantic.ConfigDict(extra="forbid") 89 | name: tp.Literal["Mlp"] = "Mlp" 90 | 91 | input_size: int | None = None 92 | hidden_sizes: list[int] | None = None 93 | 94 | norm_layer: tp.Literal["layer", "batch", "instance", "unit", None] = None 95 | activation_layer: tp.Literal["relu", "gelu", "elu", "prelu", None] = "relu" 96 | 97 | bias: bool = True 98 | dropout: float = 0.0 99 | 100 | @staticmethod 101 | def _get_norm_layer(kind: str | None) -> tp.Type[nn.Module] | None: 102 | return { 103 | "batch": nn.BatchNorm1d, 104 | "layer": nn.LayerNorm, 105 | "instance": nn.InstanceNorm1d, 106 | None: None, 107 | }[kind] 108 | 109 | @staticmethod 110 | def _get_activation_layer(kind: str | None) -> tp.Type[nn.Module]: 111 | return { 112 | "gelu": nn.GELU, 113 | "relu": nn.ReLU, 114 | "elu": nn.ELU, 115 | "prelu": nn.PReLU, 116 | None: nn.Identity, 117 | }[kind] 118 | 119 | def build( 120 | self, input_size: int | None = None, output_size: int | None = None 121 | ) -> nn.Sequential | nn.Identity: 122 | input_size = self.input_size if input_size is None else input_size 123 | assert input_size is not None, 
"input_size cannot be None." 124 | if not self.hidden_sizes: 125 | assert ( 126 | output_size is not None 127 | ), "output_size cannot be None if hidden_sizes is empty." 128 | return nn.Linear(input_size, output_size) 129 | 130 | hidden_sizes = self.hidden_sizes 131 | if output_size is not None: 132 | hidden_sizes.append(output_size) 133 | 134 | return MLP( 135 | in_channels=input_size, 136 | hidden_channels=hidden_sizes, 137 | norm_layer=self._get_norm_layer(self.norm_layer), 138 | activation_layer=self._get_activation_layer(self.activation_layer), 139 | bias=self.bias, 140 | dropout=self.dropout, 141 | ) 142 | 143 | 144 | class Mean(nn.Module): 145 | def __init__(self, dim: int, keepdim: bool = False): 146 | super().__init__() 147 | self.dim = dim 148 | self.keepdim = keepdim 149 | 150 | def forward(self, x: torch.Tensor) -> torch.Tensor: 151 | return x.mean(dim=self.dim, keepdim=self.keepdim) 152 | -------------------------------------------------------------------------------- /data_utils/data_utils/features/subject.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import logging 7 | import typing as tp 8 | import warnings 9 | 10 | import pandas as pd 11 | import pydantic 12 | import torch 13 | 14 | import data_utils as du 15 | from data_utils.base import Frequency as Frequency 16 | from data_utils.base import TimedArray as TimedArray 17 | from data_utils.events import Event, EventTypesHelper 18 | from data_utils.segments import Segment 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class SubjectEncoder(pydantic.BaseModel): 24 | _event_types_helper: EventTypesHelper 25 | _missing_default: torch.Tensor | None = None 26 | 27 | def model_post_init(self, log__: tp.Any) -> None: 28 | super().model_post_init(log__) 29 | self._event_types_helper = EventTypesHelper("Event") 30 | 31 | def _get_data(self, events: list[Event]) -> tp.Iterable[tp.Any]: 32 | for _ in events: 33 | yield None 34 | 35 | def _get_timed_arrays( 36 | self, events: list[Event], start: float, duration: float 37 | ) -> tp.Iterable[TimedArray]: 38 | raise NotImplementedError 39 | 40 | def __call__( 41 | self, 42 | events: tp.Any, 43 | start: float, 44 | duration: float, 45 | trigger: float | dict[str, tp.Any] | None = None, 46 | ) -> torch.Tensor: 47 | _input_events = events 48 | 49 | from data_utils import helpers 50 | 51 | assert duration >= 0.0, f"{duration} must be >= 0." 52 | event_types = self._event_types_helper.classes 53 | name = self.__class__.__name__ 54 | events = helpers.extract_events(events, types=self._event_types_helper) 55 | 56 | if not events and self._missing_default is not None: 57 | default = self._missing_default 58 | freq = Frequency(0.0) 59 | if freq: 60 | n_times = max(1, freq.to_ind(duration)) 61 | reps = [1 for _ in range(default.ndim)] + [n_times] 62 | default = default.unsqueeze(-1).repeat(reps) 63 | return default 64 | 65 | if not events: 66 | found_types = {type(e) for e in _input_events} 67 | msg = f"No {event_types} found in segment for feature {name} " 68 | msg += f"(types found: {found_types} in {_input_events}) " 69 | msg += "and feature shape not populated " 70 | msg += '(you may need to call "prepare" on the feature).' 
71 | raise ValueError(msg) 72 | 73 | events = events[:1] 74 | tarrays = list( 75 | self._get_timed_arrays(events=events, start=start, duration=duration) 76 | ) 77 | 78 | 79 | time_info: dict[str, tp.Any] = { 80 | "start": start, 81 | "frequency": 0.0, 82 | "duration": duration, 83 | } 84 | aggreg = "sum" 85 | out = TimedArray(aggregation=aggreg, **time_info) 86 | for ta in tarrays: 87 | out += ta 88 | tensor = torch.from_numpy(out.data) 89 | if not tensor.ndim: 90 | tensor = tensor.unsqueeze(0) 91 | 92 | if self._missing_default is None: 93 | 94 | shape = tuple(tensor.shape[: -1 ]) 95 | self._missing_default = torch.zeros(*shape, dtype=tensor.dtype) 96 | return tensor 97 | 98 | 99 | 100 | def _get_timed_arrays( 101 | self, events: list[Event], start: float, duration: float 102 | ) -> tp.Iterable[TimedArray]: 103 | for event in events: 104 | embedding = self.get_static(event) 105 | ta = TimedArray( 106 | frequency=0, 107 | duration=event.duration, 108 | start=event.start, 109 | data=embedding.numpy(), 110 | ) 111 | yield ta 112 | 113 | name: tp.Literal["SubjectEncoder"] = "SubjectEncoder" 114 | 115 | _label_to_ind: dict[str, int] = {} 116 | 117 | def _extract_event_field(self, event: du.events.Event) -> str: 118 | if hasattr(event, "subject"): 119 | return getattr(event, "subject") 120 | else: 121 | return event.extra["subject"] 122 | 123 | def prepare( 124 | self, obj: pd.DataFrame | tp.Sequence[Event] | tp.Sequence[Segment] 125 | ) -> None: 126 | from data_utils import helpers 127 | 128 | events = helpers.extract_events(obj, types=self._event_types_helper) 129 | field = "subject" 130 | if not all(hasattr(e, field) or field in e.extra for e in events): 131 | msg = f"Field {field} not found in events for {self.__class__.__name__}" 132 | raise TypeError(msg) 133 | labels = set(self._extract_event_field(e) for e in events) 134 | if len(labels) < 2: 135 | logger.warning( 136 | f"SubjectEncoder has only found one label: {labels}. " 137 | "This was probably not intended." 138 | ) 139 | self._label_to_ind = {label: i for i, label in enumerate(sorted(labels))} 140 | if events: 141 | self(events[0], events[0].start, duration=0.001, trigger=events[0].to_dict()) 142 | 143 | def get_static(self, event: du.events.Event) -> torch.Tensor: 144 | if not self._label_to_ind: 145 | msg = "Must call subject_encoder.prepare(events) before using the feature." 146 | raise ValueError(msg) 147 | inds = [self._label_to_ind[self._extract_event_field(event)]] 148 | label = torch.tensor(inds, dtype=torch.long) 149 | return label 150 | -------------------------------------------------------------------------------- /data_utils/data_utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
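"""Miscellaneous helpers for data handling.

A minimal sketch of `match_list` (requires the `Levenshtein` package): it
aligns two sequences via edit-distance operations and returns the indices
that survive the alignment, here dropping the unmatched `3`.

    A_sel, B_sel = match_list([1, 2, 3, 4], [1, 2, 4])
    # A_sel -> array([0, 1, 3]), B_sel -> array([0, 1, 2])
"""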
6 | import contextlib 7 | import dataclasses 8 | import functools 9 | import hashlib 10 | import os 11 | import re 12 | import typing as tp 13 | import warnings 14 | from pathlib import Path 15 | 16 | import numpy as np 17 | 18 | 19 | def all_subclasses(cls): 20 | 21 | subs = set(cls.__subclasses__()) 22 | return subs | {s for c in subs for s in all_subclasses(c)} 23 | 24 | 25 | def match_list(A, B, on_replace="delete"): 26 | 27 | from Levenshtein import editops 28 | 29 | if not isinstance(A, str): 30 | unique = np.unique(np.r_[A, B]) 31 | label_encoder = dict((k, v) for v, k in enumerate(unique)) 32 | 33 | def int_to_unicode(array: np.ndarray) -> str: 34 | return "".join([str(chr(label_encoder[ii])) for ii in array]) 35 | 36 | A = int_to_unicode(A) 37 | B = int_to_unicode(B) 38 | 39 | changes = editops(A, B) 40 | B_sel = np.arange(len(B)).astype(float) 41 | A_sel = np.arange(len(A)).astype(float) 42 | for type_, val_a, val_b in changes: 43 | if type_ == "insert": 44 | B_sel[val_b] = np.nan 45 | elif type_ == "delete": 46 | A_sel[val_a] = np.nan 47 | elif on_replace == "delete": 48 | 49 | A_sel[val_a] = np.nan 50 | B_sel[val_b] = np.nan 51 | elif on_replace == "keep": 52 | 53 | pass 54 | else: 55 | raise NotImplementedError 56 | B_sel = B_sel[np.where(~np.isnan(B_sel))] 57 | A_sel = A_sel[np.where(~np.isnan(A_sel))] 58 | assert len(B_sel) == len(A_sel) 59 | return A_sel.astype(int), B_sel.astype(int) 60 | 61 | 62 | ISSUED_WARNINGS = set() 63 | 64 | 65 | def warn_once(message: str) -> None: 66 | if message not in ISSUED_WARNINGS: 67 | warnings.warn(message) 68 | ISSUED_WARNINGS.add(message) 69 | 70 | 71 | def compress_string(file_) -> str: 72 | def hash_(s: str) -> str: 73 | return hashlib.sha256(s.encode()).hexdigest()[:10] 74 | 75 | file_ = str(file_) 76 | fname = Path(file_).name 77 | 78 | pattern = r"[^a-zA-Z0-9.\-_]" 79 | valid = re.sub(pattern, "", fname) 80 | 81 | if len(fname) > 70: 82 | valid = "_".join([valid[:20], hash_(fname), valid[-20:]]) 83 | 84 | folder = str(Path(file_).parent) 85 | if folder != "." 
or valid != fname: 86 | valid = f"{hash_(file_)}_{valid}" 87 | 88 | return valid 89 | 90 | 91 | @contextlib.contextmanager 92 | def ignore_all() -> tp.Iterator[None]: 93 | with open(os.devnull, "w", encoding="utf8") as fnull: 94 | with contextlib.redirect_stdout(fnull), contextlib.redirect_stderr(fnull): 95 | with warnings.catch_warnings(): 96 | warnings.simplefilter("ignore") 97 | yield 98 | 99 | 100 | @contextlib.contextmanager 101 | def success_writer( 102 | fname: str | Path, suffix: str = "_success.txt", success_msg: str = "done" 103 | ): 104 | 105 | success_fname = Path(str(Path(fname).with_suffix("")) + suffix) 106 | file_exists = success_fname.exists() 107 | yield file_exists 108 | if not file_exists: 109 | with open(success_fname, "w") as f: 110 | f.write(success_msg) 111 | 112 | 113 | class NoApproximateMatch(ValueError): 114 | 115 | def __init__(self, msg: str, matches: tp.Any) -> None: 116 | super().__init__(msg) 117 | self.matches = matches 118 | 119 | 120 | @dataclasses.dataclass 121 | class Tolerance: 122 | 123 | abs_tol: float 124 | rel_tol: float 125 | 126 | def __call__(self, value1: float, value2: float) -> bool: 127 | diff = abs(value1 - value2) 128 | tol = max(self.abs_tol, self.rel_tol * min(abs(value1), abs(value2))) 129 | return diff <= tol 130 | 131 | 132 | @dataclasses.dataclass 133 | class Sequence: 134 | 135 | sequence: tp.Sequence[float] 136 | 137 | current: int 138 | 139 | matches: tp.List[int] 140 | 141 | def valid_index(self, shift: int = 0) -> bool: 142 | return self.current + shift < len(self.sequence) 143 | 144 | def diff(self, shift: int = 0) -> float: 145 | return self.sequence[self.current + shift] - self.last_value 146 | 147 | @property 148 | def last_value(self) -> float: 149 | return self.sequence[self.matches[-1]] 150 | 151 | def diff_to(self, ind: int) -> np.ndarray: 152 | r = self.matches[-1] 153 | sub = self.sequence[r : r + ind] if ind > 0 else self.sequence[r + ind : r] 154 | return np.array(sub) - self.last_value 155 | 156 | 157 | def get_spacy_model(*, model: str = "", language: str = "") -> tp.Any: 158 | 159 | if language and model: 160 | msg = f"Language and model cannot be specified at the same time, got {language=} and {model=}" 161 | raise ValueError(msg) 162 | if not language and not model: 163 | language = "english" 164 | 165 | if language: 166 | defaults = dict( 167 | english="en_core_web_lg", 168 | french="fr_core_news_lg", 169 | spanish="es_core_news_lg", 170 | chinese="zh_core_web_lg", 171 | ) 172 | if language not in defaults: 173 | raise ValueError(f"Language {language!r} not available: {defaults}") 174 | model = defaults[language] 175 | return _get_model(model) 176 | 177 | 178 | @functools.lru_cache(maxsize=3) 179 | def _get_model(model: str) -> tp.Any: 180 | import spacy 181 | 182 | if not spacy.util.is_package(model): 183 | import spacy.cli 184 | 185 | spacy.cli.download(model) 186 | 187 | nlp = spacy.load(model) 188 | return nlp 189 | -------------------------------------------------------------------------------- /algonauts2025/grids/average_submissions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
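"""Ensemble grid submissions by (weighted) averaging.

When `weigh_by_score` is set, run weights come from a softmax over
validation scores with temperature T: w_i = exp(s_i / T) / sum_j exp(s_j / T),
so a lower temperature concentrates mass on the best runs. A small numeric
sketch (the scores are hypothetical):

    import numpy as np

    scores = np.array([0.28, 0.30, 0.31])
    w = np.exp(scores / 0.3) / np.sum(np.exp(scores / 0.3))
    # w ~ [0.32, 0.34, 0.35]; as T -> 0 this approaches picking the argmax
"""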
6 | 7 | import os 8 | import zipfile 9 | from collections import defaultdict 10 | from concurrent.futures import ThreadPoolExecutor, as_completed 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | import pandas as pd 15 | import torch 16 | from tqdm import tqdm 17 | 18 | from .defaults import SAVEDIR 19 | def select_diverse_subset(C, k): 20 | """ 21 | Greedily select the indices of the k mutually least-correlated predictors. 22 | """ 23 | # C is the (n_runs, n_runs) correlation matrix between flattened predictions 24 | n = C.shape[0] 25 | selected = [int(np.argmin(np.sum(np.abs(C), axis=0)))] # Start with the least correlated overall 26 | 27 | while len(selected) < k: 28 | candidates = list(set(range(n)) - set(selected)) 29 | scores = [] 30 | for c in candidates: 31 | total_corr = sum(abs(C[c, s]) for s in selected) 32 | scores.append((c, total_corr)) 33 | best = min(scores, key=lambda x: x[1])[0] 34 | selected.append(best) 35 | 36 | return selected 37 | 38 | def get_k_most_diverse_indices(predictions, k): 39 | """ 40 | Get the indices of the k most diverse predictors. 41 | """ 42 | preds = [] 43 | for sub in predictions[0].keys(): 44 | for chunk in tqdm(predictions[0][sub].keys(), desc="Gathering predictions for diverse subset estimation"): 45 | preds.append(np.array([data[sub][chunk] for data in predictions])) 46 | break # only the first subject 47 | preds = np.concatenate(preds, axis=1) 48 | preds = preds.reshape(preds.shape[0], -1) 49 | assert preds.shape[0] == len(predictions) 50 | 51 | corr_matrix = np.corrcoef(preds) 52 | indices = select_diverse_subset(corr_matrix, k) 53 | return np.array(indices) 54 | 55 | def average_submissions(grid_path: Path, weigh_by_score: bool = False, per_voxel_weights: bool = False, temperature: float = 1.0, max_runs: int | None = None, k_most_diverse: int | None = None): 56 | """ 57 | Average the submissions of a grid, optionally weighting each run by its validation score.
58 | """ 59 | checkpoint_paths = [] 60 | # find all folders in grid_path and get the best.ckpt file 61 | print("Found the following submissions:") 62 | for folder in os.listdir(grid_path): 63 | if max_runs is not None and len(checkpoint_paths) == max_runs: 64 | break 65 | if os.path.isdir(os.path.join(grid_path, folder)): 66 | submission_path = os.path.join(grid_path, folder, "submission.zip") 67 | to_remove = os.path.join(grid_path, folder, "submission.npy") 68 | if os.path.exists(submission_path): 69 | checkpoint_paths.append(submission_path) 70 | print(submission_path) 71 | if os.path.exists(to_remove): 72 | os.remove(to_remove) 73 | print(f"Found {len(checkpoint_paths)} submissions") 74 | 75 | predictions = [] 76 | scores = [] 77 | pearsons = [] 78 | 79 | def load_submission(path): 80 | try: 81 | submission = np.load(path, allow_pickle=True)["submission"].item() 82 | except: 83 | print(f"Error loading submission from {path}") 84 | return None 85 | metrics = pd.read_csv(path.replace("submission.zip", "metrics.csv")) 86 | if os.path.exists(path.replace("submission.zip", "pearson.npy")): 87 | pearson = np.load(path.replace("submission.zip", "pearson.npy")) 88 | else: 89 | pearson = None 90 | return submission, metrics, pearson 91 | 92 | with ThreadPoolExecutor(max_workers=10) as executor: 93 | future_to_path = {executor.submit(load_submission, path): path for path in checkpoint_paths} 94 | 95 | for future in tqdm(as_completed(future_to_path), total=len(checkpoint_paths), desc="Loading submissions"): 96 | output = future.result() 97 | if output is None: 98 | continue 99 | predictions.append(output[0]) 100 | scores.append(output[1]) 101 | pearsons.append(output[2]) 102 | 103 | if k_most_diverse is not None: 104 | indices = get_k_most_diverse_indices(predictions, k_most_diverse) 105 | predictions, scores = [predictions[i] for i in indices], [scores[i] for i in indices] 106 | 107 | if per_voxel_weights: 108 | pearsons = torch.Tensor(pearsons) / temperature 109 | weights = pearsons.softmax(dim=1) # n_submissions x n_voxels 110 | weights = np.array(weights.unsqueeze(1)) 111 | else: 112 | scores = np.array([score["val/pearson"].item() for score in scores]) 113 | weights = np.exp(scores / temperature) / np.sum(np.exp(scores / temperature)) 114 | weights = weights[:, None, None] 115 | print(weights.min(), weights.max()) 116 | 117 | averaged_predictions = defaultdict(dict) 118 | for sub in tqdm(predictions[0].keys(), desc="Averaging submissions"): 119 | for chunk in predictions[0][sub].keys(): 120 | preds = np.array([data[sub][chunk] for data in predictions]) # n_submissions x n_timepoints x n_voxels 121 | if weigh_by_score: 122 | avg_preds = np.sum(preds * weights, axis=0) 123 | else: 124 | avg_preds = np.mean(preds, axis=0) 125 | averaged_predictions[sub][chunk] = avg_preds 126 | 127 | submission_path = grid_path / "submission.npy" 128 | np.save(submission_path, averaged_predictions) 129 | with zipfile.ZipFile(submission_path.with_suffix(".zip"), "w") as zipf: 130 | zipf.write(submission_path, arcname = submission_path.name) 131 | print(f"Saved average submission to {submission_path.with_suffix('.zip')}") 132 | 133 | if __name__ == "__main__": 134 | grid_path = Path(SAVEDIR) / "model_soup" 135 | weigh_by_score = True 136 | per_voxel_weights = True 137 | temperature = 0.3 138 | max_runs = None 139 | k_most_diverse = None 140 | average_submissions(grid_path=grid_path, weigh_by_score=weigh_by_score, per_voxel_weights=per_voxel_weights, temperature=temperature, max_runs=max_runs, 
k_most_diverse=k_most_diverse) -------------------------------------------------------------------------------- /data_utils/data_utils/dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import collections 7 | import dataclasses 8 | import logging 9 | import typing as tp 10 | import warnings 11 | 12 | import torch 13 | 14 | import data_utils as du 15 | 16 | from .base import Frequency 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class CollateSegments: 22 | 23 | def __init__(self, *args: tp.Any, **kwargs: tp.Any) -> None: 24 | raise RuntimeError("CollateSegments is deprecated in favor of SegmentDataset") 25 | 26 | 27 | @dataclasses.dataclass 28 | class SegmentData: 29 | 30 | data: tp.Dict[str, torch.Tensor] 31 | segments: tp.List[du.segments.Segment] 32 | 33 | def __post_init__(self) -> None: 34 | if not isinstance(self.data, dict): 35 | raise TypeError(f"'data' needs to be a dict, got: {self.data}") 36 | if not self.data: 37 | raise ValueError(f"No data in {self}") 38 | if not isinstance(self.segments, list): 39 | raise TypeError(f"'segments' needs to be a list, got {self.segments}") 40 | 41 | batch_size = next(iter(self.data.values())).shape[0] 42 | if len(self.segments) != batch_size: 43 | raise RuntimeError( 44 | f"Incoherent batch size {batch_size} for {len(self.segments)} segments in {self}" 45 | ) 46 | 47 | def to(self, device: str) -> "SegmentData": 48 | 49 | out = {name: d.to(device) for name, d in self.data.items()} 50 | return SegmentData(data=out, segments=self.segments) 51 | 52 | def __getitem__(self, key: str) -> None: 53 | raise RuntimeError("New SegmentData batch is not a dict, use batch.data instead") 54 | 55 | 56 | def validate_features(features: tp.Mapping[str, tp.Any]) -> tp.Mapping[str, tp.Any]: 57 | 58 | if not features: 59 | return {} 60 | 61 | if not isinstance(features, collections.abc.Mapping): 62 | raise ValueError(f"Only dict of features are supported, got {type(features)}") 63 | 64 | return features 65 | 66 | 67 | def get_pad_lengths( 68 | feats: tp.Mapping[str, tp.Any], 69 | pad_duration: float | None, 70 | ) -> tp.Dict[str, int]: 71 | 72 | pad_lengths: tp.Dict[str, int] = {} 73 | if pad_duration is None: 74 | return pad_lengths 75 | for name, f in feats.items(): 76 | if isinstance( 77 | f, 78 | du.features.text.LLAMA3p2 79 | | du.features.audio.Wav2VecBert 80 | | du.features.neuro.Fmri 81 | | du.features.video.VJEPA2 82 | | du.features.SubjectEncoder, 83 | ): 84 | freq = Frequency(f.frequency) 85 | pad_lengths[name] = freq.to_ind(pad_duration) 86 | return pad_lengths 87 | 88 | 89 | def _pad_to(tensor: torch.Tensor, pad_len: int | None): 90 | 91 | if pad_len is None: 92 | return tensor 93 | if pad_len < tensor.shape[-1]: 94 | msg = "Pad duration is shorter than segment duration, cropping."
95 | warnings.warn(msg, UserWarning) 96 | return tensor[:, :pad_len] 97 | else: 98 | return torch.nn.functional.pad(tensor, (0, pad_len - tensor.shape[-1])) 99 | 100 | 101 | def _apply_feature(segment: du.segments.Segment, feature: tp.Any) -> torch.Tensor: 102 | 103 | return feature( 104 | segment.ns_events, 105 | start=segment.start, 106 | duration=segment.duration, 107 | trigger=segment._trigger, 108 | ) 109 | 110 | 111 | class SegmentDataset(torch.utils.data.Dataset[SegmentData]): 112 | 113 | def __init__( 114 | self, 115 | features: tp.Mapping[str, tp.Any], 116 | segments: tp.Sequence[du.segments.Segment], 117 | pad_duration: float | None = None, 118 | ) -> None: 119 | self.features = validate_features(features) 120 | self.segments = segments 121 | self._pad_lengths = get_pad_lengths(self.features, pad_duration) 122 | 123 | def collate_fn(self, batches: tp.List[SegmentData]) -> SegmentData: 124 | 125 | if not batches: 126 | return SegmentData(data={}, segments=[]) 127 | if len(batches) == 1: 128 | return batches[0] 129 | if not batches[0].data: 130 | raise ValueError(f"No feature in first batch: {batches[0]}") 131 | 132 | features = {} 133 | for name in batches[0].data: 134 | data = [b.data[name] for b in batches] 135 | try: 136 | features[name] = torch.cat(data, axis=0) 137 | 138 | except Exception: 139 | string = f"Failed to collate data with shapes {[d.shape for d in data]}\n" 140 | string += "Do you need specifying padding in SegmentDataset?" 141 | logger.warning(string) 142 | raise 143 | segments = [s for b in batches for s in b.segments] 144 | return SegmentData(data=features, segments=segments) 145 | 146 | def __len__(self) -> int: 147 | return len(self.segments) 148 | 149 | def __getitem__(self, idx: int) -> SegmentData: 150 | seg = self.segments[idx] 151 | out: tp.Dict[str, torch.Tensor] = {} 152 | for name, feats in self.features.items(): 153 | 154 | data = _apply_feature(seg, feats) 155 | 156 | data = _pad_to(data, self._pad_lengths.get(name, None)) 157 | 158 | out[name] = data[None, ...] 159 | 160 | return SegmentData(data=out, segments=[seg]) 161 | 162 | def build_dataloader(self, **kwargs: tp.Any) -> torch.utils.data.DataLoader: 163 | 164 | return torch.utils.data.DataLoader(self, collate_fn=self.collate_fn, **kwargs) 165 | 166 | def as_one_batch(self, num_workers: int = 0) -> SegmentData: 167 | 168 | num_workers = min(num_workers, len(self)) 169 | batch_size = len(self) 170 | if num_workers > 1: 171 | batch_size = max(1, len(self) // (3 * num_workers)) 172 | if num_workers == 1: 173 | num_workers = 0 174 | 175 | loader = self.build_dataloader( 176 | num_workers=num_workers, 177 | batch_size=batch_size, 178 | shuffle=False, 179 | ) 180 | return self.collate_fn(list(loader)) 181 | -------------------------------------------------------------------------------- /algonauts2025/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
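"""Multimodal encoder predicting fMRI activity.

A minimal shape sketch (the feature dimensions below are hypothetical):
`feature_dims` maps each modality to `(num_layers, feature_dim)`, tensors in
`batch.data` are shaped (batch, layers, dim, time), and the output is pooled
to (batch, n_outputs, n_output_timesteps).

    config = FmriEncoderConfig(n_subjects=4)
    model = config.build(
        feature_dims={"text": (3, 2048), "audio": (3, 1024)},
        n_outputs=1000,
        n_output_timesteps=10,
    )
    # model(batch) -> tensor of shape (B, 1000, 10)
"""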
6 | 7 | import typing as tp 8 | 9 | import numpy as np 10 | import pydantic 11 | import torch 12 | from data_utils.dataloader import SegmentData 13 | from einops import rearrange 14 | from modeling_utils.models.common import MlpConfig, SubjectLayers 15 | from modeling_utils.models.transformer import TransformerEncoderConfig 16 | from torch import nn 17 | 18 | 19 | class FmriEncoderConfig(pydantic.BaseModel): 20 | model_config = pydantic.ConfigDict(extra="forbid") 21 | name: tp.Literal["FmriEncoder"] = "FmriEncoder" 22 | n_subjects: int | None = None 23 | feature_aggregation: tp.Literal["sum", "cat"] = "cat" 24 | layer_aggregation: tp.Literal["mean", "cat"] = "cat" 25 | subject_embedding: bool = False 26 | modality_dropout: float = 0.0 27 | 28 | def build( 29 | self, feature_dims: dict[int], n_outputs: int, n_output_timesteps: int 30 | ) -> nn.Module: 31 | return FmriEncoder( 32 | feature_dims, 33 | n_outputs, 34 | n_output_timesteps, 35 | config=self, 36 | ) 37 | 38 | 39 | class FmriEncoder(nn.Module): 40 | def __init__( 41 | self, 42 | feature_dims: dict[str, tuple[int, int]], 43 | n_outputs: int, 44 | n_output_timesteps: int, 45 | config: FmriEncoderConfig, 46 | ): 47 | super().__init__() 48 | self.config = config 49 | self.feature_dims = feature_dims 50 | self.n_outputs = n_outputs 51 | self.projectors = nn.ModuleDict() 52 | self.pooler = nn.AdaptiveAvgPool1d(n_output_timesteps) 53 | hidden = 3072 54 | for modality, tup in feature_dims.items(): 55 | if tup is None: 56 | print( 57 | f"Warning: {modality} has no feature dimensions. Skipping projector." 58 | ) 59 | continue 60 | else: 61 | num_layers, feature_dim = tup 62 | input_dim = ( 63 | feature_dim * num_layers 64 | if config.layer_aggregation == "cat" 65 | else feature_dim 66 | ) 67 | output_dim = ( 68 | hidden // len(feature_dims) 69 | if config.feature_aggregation == "cat" 70 | else hidden 71 | ) 72 | self.projectors[modality] = MlpConfig( 73 | norm_layer="layer", activation_layer="gelu", dropout=0.0 74 | ).build(input_dim, output_dim) 75 | input_dim = ( 76 | (hidden // len(feature_dims)) * len(feature_dims) 77 | if config.feature_aggregation == "cat" 78 | else hidden 79 | ) 80 | self.combiner = nn.Identity() 81 | self.predictor = SubjectLayers( 82 | in_channels=hidden, 83 | out_channels=n_outputs, 84 | n_subjects=config.n_subjects, 85 | average_subjects=False, 86 | bias=True, 87 | ) 88 | self.time_pos_embed = nn.Parameter(torch.randn(1, 1024, hidden)) 89 | if config.subject_embedding: 90 | self.subject_embed = nn.Embedding(config.n_subjects, hidden) 91 | self.encoder = TransformerEncoderConfig( 92 | attn_dropout=0.0, ff_dropout=0.0, layer_dropout=0.0, depth=8 93 | ).build(dim=hidden) 94 | 95 | def forward(self, batch: SegmentData, pool_outputs: bool = True) -> torch.Tensor: 96 | x = self.aggregate_features(batch) # B, T, H 97 | subject_id = batch.data.get("subject_id", None) 98 | x = self.transformer_forward(x, subject_id) 99 | x = x.transpose(1, 2) # B, H, T 100 | x = self.predictor(x, subject_id) # B, O, T 101 | if pool_outputs: 102 | out = self.pooler(x) # B, O, T' 103 | else: 104 | out = x 105 | return out 106 | 107 | def aggregate_features(self, batch): 108 | tensors = [] 109 | # get B, T 110 | for modality in batch.data.keys(): 111 | if modality in self.feature_dims: 112 | break 113 | x = batch.data[modality] 114 | B, T = x.shape[0], x.shape[-1] 115 | # select the modalities to dropout, keep at least one modality 116 | modalities_to_dropout = [] 117 | for modality in self.feature_dims.keys(): 118 | if torch.rand(1).item() < 
self.config.modality_dropout and self.training: 119 | modalities_to_dropout.append(modality) 120 | if len(modalities_to_dropout) == len(self.feature_dims): 121 | modalities_to_dropout = np.random.choice( 122 | modalities_to_dropout, len(modalities_to_dropout) - 1, replace=False 123 | ) 124 | for modality in self.feature_dims.keys(): 125 | if modality not in self.projectors: 126 | data = torch.zeros(B, T, 3072 // len(self.feature_dims)).to(x.device) 127 | else: 128 | data = batch.data[modality] # B, L, H, T 129 | data = data.to(torch.float32) 130 | if data.ndim == 3: 131 | data = data.unsqueeze(1) 132 | # mean over layers 133 | if self.config.layer_aggregation == "mean": 134 | data = data.mean(dim=1) 135 | elif self.config.layer_aggregation == "cat": 136 | data = rearrange(data, "b l d t -> b (l d) t") 137 | data = data.transpose(1, 2) 138 | assert data.ndim == 3 # B, T, D 139 | data = self.projectors[modality](data) # B, T, H 140 | if modality in modalities_to_dropout: 141 | data = torch.zeros_like(data) 142 | tensors.append(data) 143 | if self.config.feature_aggregation == "cat": 144 | out = torch.cat(tensors, dim=-1) 145 | elif self.config.feature_aggregation == "sum": 146 | out = sum(tensors) 147 | return out 148 | 149 | def transformer_forward(self, x, subject_id=None): 150 | x = self.combiner(x) 151 | if hasattr(self, "time_pos_embed"): 152 | x = x + self.time_pos_embed[:, : x.size(1)] 153 | if hasattr(self, "subject_embed"): 154 | x = x + self.subject_embed(subject_id) 155 | x = self.encoder(x) 156 | return x 157 | -------------------------------------------------------------------------------- /modeling_utils/modeling_utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
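"""Experiment utilities: config generation, grid launching, W&B logging.

A minimal sketch of `convert_to_pydantic` (the `Toy` class is hypothetical):
it derives a pydantic config from a class' __init__ signature and attaches a
`build()` method that instantiates the class.

    class Toy:
        def __init__(self, width: int = 3):
            self.width = width

    ToyConfig = convert_to_pydantic(Toy, name="Toy")
    assert ToyConfig(width=5).build().width == 5
"""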
6 | 7 | from __future__ import annotations 8 | 9 | import inspect 10 | import random 11 | import shutil 12 | import typing as tp 13 | from itertools import product 14 | from pathlib import Path 15 | 16 | import pydantic 17 | import wandb 18 | from data_utils.infra import ConfDict, TaskInfra 19 | from pydantic import BaseModel, Field, create_model 20 | 21 | 22 | def convert_to_pydantic( 23 | class_to_convert: type, 24 | name: str, 25 | parent_class: tp.Any = None, 26 | exclude_from_build: list[str] | None = None, 27 | ) -> BaseModel: 28 | 29 | init = class_to_convert.__init__ 30 | 31 | sig = inspect.signature(init) 32 | empty = inspect.Parameter.empty 33 | 34 | fields = { 35 | k: ( 36 | v.annotation if v.annotation != empty else tp.Any, 37 | v.default if v.default != empty else ..., 38 | ) 39 | for k, v in sig.parameters.items() 40 | if k != "self" and not k.startswith("_") 41 | } 42 | 43 | assert "name" not in sig.parameters 44 | 45 | Builder = create_model( 46 | name, 47 | name=(tp.Literal[name], Field(default=name)), 48 | __base__=parent_class, 49 | **fields, 50 | ) 51 | Builder._cls = class_to_convert 52 | 53 | if exclude_from_build is None: 54 | exclude_from_build = [] 55 | 56 | def build_method(instance: BaseModel): 57 | params = dict( 58 | (field, getattr(instance, field)) 59 | for field in type(instance).model_fields 60 | if (field != "name" and field not in exclude_from_build) 61 | ) 62 | return instance._cls(**params) 63 | 64 | setattr(Builder, "build", build_method) 65 | 66 | return Builder 67 | 68 | 69 | def all_subclasses(cls): 70 | 71 | return set(cls.__subclasses__()).union( 72 | [s for c in cls.__subclasses__() for s in all_subclasses(c)] 73 | ) 74 | 75 | 76 | def run_grid( 77 | exp_cls, 78 | exp_name: str, 79 | base_config: dict[str, tp.Any], 80 | grid: dict[str, list], 81 | n_randomly_sampled: int | None = None, 82 | job_name_keys: list[str] | None = None, 83 | combinatorial: bool = False, 84 | overwrite: bool = False, 85 | dry_run: bool = False, 86 | infra_mode: str = "retry", 87 | ) -> list[ConfDict]: 88 | 89 | job_array_kwargs = {} 90 | if dry_run: 91 | from importlib.metadata import version 92 | 93 | from pkg_resources import parse_version 94 | 95 | if parse_version(version("exca")) < parse_version("0.4.5"): 96 | raise ImportError("`dry_run` requires `exca>=0.4.5` to be installed.") 97 | job_array_kwargs["allow_empty"] = True 98 | 99 | 100 | base_config["infra.job_name"] = exp_name 101 | base_folder = Path(base_config["infra"]["folder"]) 102 | assert all([isinstance(v, list) for v in grid.values()]), "Grid values must be lists."
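    # Example grid (hypothetical keys): with combinatorial=False each
    # (parameter, value) pair yields one single-change variation of
    # base_config (4 runs below); with combinatorial=True the cartesian
    # product of all values is launched instead.
    #   grid = {"optim.optimizer.lr": [1e-4, 3e-4], "data.num_workers": [10, 20]}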
103 | 104 | task = exp_cls( 105 | **base_config, 106 | ) 107 | 108 | if combinatorial: 109 | grid_product = list(dict(zip(grid.keys(), v)) for v in product(*grid.values())) 110 | else: 111 | grid_product = [ 112 | {param: value} for param, values in grid.items() for value in values 113 | ] 114 | 115 | if n_randomly_sampled is not None: 116 | assert n_randomly_sampled <= len( 117 | grid_product 118 | ), "n_randomly_sampled must be less than the number of grid products" 119 | grid_product = random.sample(grid_product, n_randomly_sampled) 120 | 121 | print(f"Launching {len(grid_product)} tasks") 122 | 123 | out_configs = [] 124 | tmp = task.infra.clone_obj(**{"infra.mode": infra_mode}) 125 | with tmp.infra.job_array(**job_array_kwargs) as tasks: 126 | for params in grid_product: 127 | job_name = ConfDict(params).to_uid() 128 | 129 | config = ConfDict(base_config) 130 | config.update(params) 131 | 132 | folder = base_folder / exp_name / job_name 133 | if folder.exists(): 134 | 135 | print(f"{folder} already exists.") 136 | if overwrite and not dry_run: 137 | 138 | print(f"Folder {folder} already exists. Overwrite? (y/n)") 139 | answer = input() 140 | if answer.lower() != "y": 141 | print("Skipping.") 142 | continue 143 | print(f"Deleting {folder}.") 144 | shutil.rmtree(folder) 145 | folder.mkdir() 146 | 147 | config["infra.folder"] = str(folder) 148 | if job_name_keys is not None: 149 | for key in job_name_keys: 150 | config.update({key: str(job_name)}) 151 | 152 | if not dry_run: 153 | task_ = exp_cls(**config) 154 | tasks.append(task_) 155 | 156 | out_configs.append(config) 157 | 158 | print("Done.") 159 | 160 | return out_configs 161 | 162 | 163 | class WandbLoggerConfig(pydantic.BaseModel): 164 | 165 | model_config = pydantic.ConfigDict(extra="forbid") 166 | 167 | offline: bool = False 168 | host: str | None = None 169 | name: str | None = None 170 | group: str | None = None 171 | entity: str | None = None 172 | 173 | version: str | None = None 174 | dir: Path | None = None 175 | id: str | None = None 176 | anonymous: bool | None = None 177 | project: str | None = None 178 | log_model: str | bool = False 179 | experiment: tp.Any | None = None 180 | prefix: str = "" 181 | 182 | def build( 183 | self, 184 | save_dir: str | Path, 185 | xp_config: dict | pydantic.BaseModel | None = None, 186 | id: str | None = None, 187 | ) -> tp.Any: 188 | if self.offline: 189 | login_kwargs = {"key": "X" * 40} 190 | else: 191 | login_kwargs = {"host": self.host} 192 | 193 | wandb.login(**login_kwargs) 194 | 195 | from lightning.pytorch.loggers import WandbLogger 196 | 197 | if isinstance(xp_config, pydantic.BaseModel): 198 | xp_config = xp_config.model_dump() 199 | config = self.model_dump() 200 | if id is not None: 201 | config["id"] = id 202 | del config["host"] 203 | logger = WandbLogger(**config, save_dir=save_dir, config=xp_config) 204 | try: 205 | logger.experiment.config["_dummy"] = None 206 | 207 | except TypeError: 208 | pass 209 | 210 | return logger 211 | -------------------------------------------------------------------------------- /modeling_utils/modeling_utils/models/fmri_mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
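"""Residual MLP over fMRI channels.

A minimal shape sketch (sizes are illustrative; with the default
"out_linear" time aggregation the time axis is reduced by a linear layer):

    import torch

    config = FmriMlpConfig(hidden=256, n_blocks=2, n_repetition_times=4)
    model = config.build(n_in_channels=128, n_outputs=1000)
    y = model(torch.randn(8, 128, 4))  # -> (8, 1000)
"""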
6 | import logging 7 | import typing as tp 8 | from functools import partial 9 | 10 | import pydantic 11 | import torch 12 | from torch import nn 13 | from torchvision.ops import MLP 14 | 15 | from .common import Mean, MlpConfig, SubjectLayers 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class FmriMlpConfig(pydantic.BaseModel): 21 | model_config = pydantic.ConfigDict(extra="forbid") 22 | name: tp.Literal["FmriMlp"] = "FmriMlp" 23 | 24 | hidden: int = 4096 25 | n_blocks: int = 4 26 | norm_type: str = "ln" 27 | act_first: bool = False 28 | 29 | n_repetition_times: int = 1 30 | time_agg: tp.Literal["in_mean", "in_linear", "out_mean", "out_linear"] = "out_linear" 31 | 32 | use_tr_embeds: bool = False 33 | tr_embed_dim: int = 16 34 | use_tr_layer: bool = False 35 | 36 | out_dim: int | None = None 37 | 38 | subject_layers: bool = False 39 | n_subjects: int = 20 40 | subject_layers_dim: tp.Literal["input", "hidden"] = "hidden" 41 | subject_layers_id: bool = False 42 | 43 | output_head_config: MlpConfig | dict[str, MlpConfig] | None = None 44 | 45 | def build(self, n_in_channels: int, n_outputs: int | None) -> nn.Module: 46 | out_dim = self.out_dim if n_outputs is None else n_outputs 47 | if out_dim is None: 48 | raise ValueError("One of n_outputs or config.out_dim must be set.") 49 | 50 | return FmriMlp( 51 | in_dim=n_in_channels, 52 | out_dim=out_dim, 53 | config=self, 54 | ) 55 | 56 | 57 | class FmriMlp(nn.Module): 58 | 59 | def __init__( 60 | self, 61 | in_dim: int, 62 | out_dim: int, 63 | config: FmriMlpConfig | None = None, 64 | ): 65 | super().__init__() 66 | config = config if config is not None else FmriMlpConfig() 67 | 68 | self.in_time_agg: None | nn.Module = None 69 | self.out_time_agg: None | nn.Module = None 70 | self.n_repetition_times = config.n_repetition_times 71 | if config.time_agg == "in_mean": 72 | self.in_time_agg = Mean(dim=2, keepdim=True) 73 | self.n_repetition_times = 1 74 | elif config.time_agg == "in_linear": 75 | self.in_time_agg = nn.LazyLinear(1) 76 | self.n_repetition_times = 1 77 | elif config.time_agg == "out_mean": 78 | self.out_time_agg = Mean(dim=2) 79 | elif config.time_agg == "out_linear": 80 | self.out_time_agg = nn.LazyLinear(1) 81 | 82 | norm_func = ( 83 | partial(nn.BatchNorm1d, num_features=config.hidden) 84 | if config.norm_type == "bn" 85 | else partial(nn.LayerNorm, normalized_shape=config.hidden) 86 | ) 87 | act_fn = partial(nn.ReLU, inplace=True) if config.norm_type == "bn" else nn.GELU 88 | act_and_norm = (act_fn, norm_func) if config.act_first else (norm_func, act_fn) 89 | 90 | self.subject_layers = None 91 | if config.subject_layers: 92 | dim = {"hidden": config.hidden, "input": in_dim}[config.subject_layers_dim] 93 | self.subject_layers = SubjectLayers( 94 | in_dim, 95 | dim, 96 | config.n_subjects, 97 | # `subject_layers_id` toggles identity initialization of the per-subject weights 98 | init_id=config.subject_layers_id, 99 | ) 100 | in_dim = dim 101 | 102 | self.tr_embeddings = None 103 | if config.use_tr_embeds: 104 | self.tr_embeddings = nn.Embedding( 105 | self.n_repetition_times, config.tr_embed_dim 106 | ) 107 | in_dim += config.tr_embed_dim 108 | 109 | self.lin0: nn.Conv1d | nn.Linear 110 | if config.use_tr_layer: 111 | self.lin0 = nn.Conv1d( 112 | in_channels=self.n_repetition_times, 113 | out_channels=self.n_repetition_times * config.hidden, 114 | kernel_size=in_dim, 115 | groups=self.n_repetition_times, 116 | bias=True, 117 | ) 118 | else: 119 | self.lin0 = nn.Linear(in_dim, config.hidden) 120 | self.post_lin0 = nn.Sequential( 121 | *[item() for item in act_and_norm], nn.Dropout(0.5) 122
| ) 123 | 124 | self.mlp = nn.ModuleList( 125 | [ 126 | nn.Sequential( 127 | nn.Linear(config.hidden, config.hidden), 128 | *[item() for item in act_and_norm], 129 | nn.Dropout(0.15), 130 | ) 131 | for _ in range(config.n_blocks) 132 | ] 133 | ) 134 | self.lin1 = nn.Linear(config.hidden, out_dim, bias=True) 135 | self.n_blocks = config.n_blocks 136 | 137 | self.output_head: None | MLP | dict[str, MLP] = None 138 | if config.output_head_config is not None: 139 | if isinstance(config.output_head_config, MlpConfig): 140 | self.output_head = config.output_head_config.build(input_size=out_dim) 141 | elif isinstance(config.output_head_config, dict): 142 | self.output_head = nn.ModuleDict() 143 | for name, head_config in config.output_head_config.items(): 144 | self.output_head[name] = head_config.build( 145 | input_size=out_dim, 146 | ) 147 | 148 | def forward( 149 | self, 150 | x: torch.Tensor, 151 | subject_ids: torch.Tensor | None = None, 152 | channel_positions: torch.Tensor | None = None, 153 | ) -> torch.Tensor | dict[str, torch.Tensor]: 154 | x = x.reshape(x.shape[0], -1, x.shape[-1]) 155 | 156 | if self.in_time_agg is not None: 157 | x = self.in_time_agg(x) 158 | 159 | B, _, T = x.shape 160 | assert ( 161 | T == self.n_repetition_times 162 | ), f"Mismatch between expected and provided number TRs: {T} != {self.n_repetition_times}" 163 | 164 | if self.subject_layers is not None: 165 | x = self.subject_layers(x, subject_ids) 166 | x = x.permute(0, 2, 1) 167 | 168 | if self.tr_embeddings is not None: 169 | embeds = self.tr_embeddings(torch.arange(T, device=x.device)) 170 | embeds = torch.tile(embeds, dims=(B, 1, 1)) 171 | x = torch.cat([x, embeds], dim=2) 172 | 173 | x = self.lin0(x).reshape(B, T, -1) 174 | 175 | x = self.post_lin0(x) 176 | 177 | residual = x 178 | for res_block in range(self.n_blocks): 179 | x = self.mlp[res_block](x) 180 | x += residual 181 | residual = x 182 | 183 | x = x.permute(0, 2, 1) 184 | 185 | if self.out_time_agg is not None: 186 | x = self.out_time_agg(x) 187 | 188 | x = x.flatten(1) 189 | 190 | x = self.lin1(x) 191 | 192 | if isinstance(self.output_head, MLP): 193 | out = self.output_head(x) 194 | elif isinstance(self.output_head, nn.ModuleDict): 195 | out = {name: head(x) for name, head in self.output_head.items()} 196 | else: 197 | out = x 198 | 199 | return out 200 | -------------------------------------------------------------------------------- /modeling_utils/modeling_utils/metrics/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
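"""Correlation, rank and retrieval metrics.

A minimal sketch of `TopkAcc` (tensors are illustrative): with identical
prediction and target rows, every item retrieves itself at rank 0, so
top-1 accuracy is 1.

    import torch

    metric = TopkAcc(topk=1)
    metric.update(torch.eye(4), torch.eye(4))
    assert metric.compute().item() == 1.0
"""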
6 | 7 | 8 | import typing as tp 9 | from collections import defaultdict 10 | 11 | import numpy as np 12 | import torch 13 | import torchmetrics 14 | 15 | 16 | class OnlinePearsonCorr(torchmetrics.regression.PearsonCorrCoef): 17 | 18 | def __init__( 19 | self, 20 | dim: int, 21 | reduction: tp.Literal["mean", "sum", "none"] | None = "mean", 22 | ): 23 | 24 | super().__init__() 25 | self.dim = dim 26 | self.reduction = reduction 27 | self._initialized = False 28 | 29 | def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: 30 | 31 | if self.dim == 1: 32 | preds = preds.T 33 | target = target.T 34 | 35 | if not self._initialized: 36 | 37 | self.num_outputs = preds.shape[1] 38 | state_names = ["mean_x", "mean_y", "var_x", "var_y", "corr_xy", "n_total"] 39 | for state_name in state_names: 40 | self.add_state( 41 | state_name, 42 | default=torch.zeros(self.num_outputs).to(self.device), 43 | dist_reduce_fx=None, 44 | ) 45 | self._initialized = True 46 | 47 | super().update(preds, target) 48 | 49 | def compute(self): 50 | 51 | corrcoef = super().compute() 52 | 53 | if self.reduction == "mean": 54 | return torch.mean(corrcoef) 55 | elif self.reduction == "sum": 56 | return torch.sum(corrcoef) 57 | else: 58 | 59 | return corrcoef 60 | 61 | def reset(self) -> None: 62 | self._initialized = False 63 | super().reset() 64 | 65 | 66 | class Rank(torchmetrics.Metric): 67 | 68 | is_differentiable: bool = False 69 | higher_is_better: bool = False 70 | full_state_update: bool = True 71 | 72 | def __init__( 73 | self, 74 | reduction: tp.Literal["mean", "median", "std"] = "median", 75 | relative: bool = False, 76 | ): 77 | super().__init__() 78 | 79 | self.reduction = reduction 80 | self.relative = relative 81 | self.add_state( 82 | "ranks", 83 | default=torch.Tensor([]), 84 | dist_reduce_fx="cat", 85 | ) 86 | self.rank_count: torch.Tensor 87 | 88 | @classmethod 89 | def _compute_sim(cls, x, y, norm_kind="y", eps=1e-15): 90 | if norm_kind is None: 91 | eq, inv_norms = "b", torch.ones(x.shape[0]) 92 | elif norm_kind == "x": 93 | eq, inv_norms = "b", 1 / (eps + x.norm(dim=(1), p=2)) 94 | elif norm_kind == "y": 95 | eq, inv_norms = "o", 1 / (eps + y.norm(dim=(1), p=2)) 96 | elif norm_kind == "xy": 97 | eq = "bo" 98 | inv_norms = 1 / ( 99 | eps + torch.outer(x.norm(dim=(1), p=2), y.norm(dim=(1), p=2)) 100 | ) 101 | else: 102 | raise ValueError(f"norm must be None, x, y or xy, got {norm_kind}.") 103 | 104 | return torch.einsum(f"bc,oc,{eq}->bo", x, y, inv_norms) 105 | 106 | def _compute_ranks( 107 | self, 108 | x: torch.Tensor, 109 | y: torch.Tensor, 110 | x_labels: None | list[str] = None, 111 | y_labels: None | list[str] = None, 112 | ) -> torch.Tensor: 113 | scores = self._compute_sim(x, y) 114 | 115 | if x_labels is not None and y_labels is not None: 116 | 117 | true_inds = torch.tensor( 118 | [y_labels.index(x) for x in x_labels], 119 | dtype=torch.long, 120 | device=scores.device, 121 | )[:, None] 122 | true_scores = torch.take_along_dim(scores, true_inds, dim=1) 123 | else: 124 | 125 | assert x_labels is None and y_labels is None 126 | assert x.shape[0] == y.shape[0] 127 | true_scores = torch.diag(scores)[:, None] 128 | 129 | ranks_gt = (scores > true_scores).nansum(axis=1) 130 | ranks_ge = (scores >= true_scores).nansum(axis=1) - 1 131 | ranks = (ranks_gt + ranks_ge) / 2 132 | ranks[ranks < 0] = len(scores) // 2 133 | 134 | if self.relative: 135 | ranks /= len(y) 136 | 137 | return ranks 138 | 139 | @torch.inference_mode() 140 | def update( 141 | self, 142 | x: torch.Tensor, 143 | y: 
torch.Tensor, 144 | x_labels: None | list[str] = None, 145 | y_labels: None | list[str] = None, 146 | ) -> None: 147 | 148 | ranks = self._compute_ranks(x, y, x_labels, y_labels) 149 | self.ranks = torch.cat([self.ranks, ranks]) 150 | 151 | def compute(self) -> torch.Tensor: 152 | agg_func: tp.Callable 153 | if self.reduction == "mean": 154 | agg_func = torch.mean 155 | elif self.reduction == "median": 156 | agg_func = torch.median 157 | elif self.reduction == "std": 158 | agg_func = torch.std 159 | else: 160 | raise ValueError( 161 | f'Unknown aggregation {self.reduction} for computing metric. Available aggregations are: "mean", "median" or "std".' 162 | ) 163 | return agg_func(self.ranks) 164 | 165 | def _compute_macro_average( 166 | self, ranks: torch.Tensor, labels: list[str] 167 | ) -> tp.Dict[str, float]: 168 | 169 | assert len(ranks) == len(labels) 170 | groups = defaultdict(list) 171 | agg_func = np.mean if self.reduction == "mean" else np.median 172 | for i, label in enumerate(labels): 173 | groups[label].append(ranks[i]) 174 | return {label: agg_func(ranks) for label, ranks in groups.items()} 175 | 176 | @classmethod 177 | def _compute_topk_scores( 178 | cls, 179 | x: torch.Tensor, 180 | y: torch.Tensor, 181 | y_labels: list[str], 182 | k: int = 5, 183 | ) -> tp.Tuple[list[list[str]], list[list[float]]]: 184 | 185 | scores = cls._compute_sim(x, y) 186 | topk_inds = torch.argsort(scores, dim=1, descending=True)[:, :k] 187 | topk_labels = [[y_labels[ind] for ind in inds] for inds in topk_inds] 188 | scores = [ 189 | [scores[i, ind].item() for ind in inds] for i, inds in enumerate(topk_inds) 190 | ] 191 | return topk_labels, scores 192 | 193 | 194 | class TopkAcc(Rank): 195 | 196 | is_differentiable: bool = False 197 | higher_is_better: bool = True 198 | full_state_update: bool = True 199 | 200 | def __init__(self, topk: int = 5): 201 | super().__init__(relative=False) 202 | self.topk = topk 203 | 204 | def _compute_macro_average( 205 | self, ranks: torch.Tensor, labels: list[str] 206 | ) -> tp.Dict[str, float]: 207 | 208 | groups = defaultdict(list) 209 | for i, label in enumerate(labels): 210 | groups[label].append(ranks[i]) 211 | return { 212 | label: float(np.mean([r < self.topk for r in ranks])) 213 | for label, ranks in groups.items() 214 | } 215 | 216 | def compute(self) -> torch.Tensor: 217 | ranks = self.ranks 218 | return (ranks < self.topk).float().mean() 219 | -------------------------------------------------------------------------------- /data_utils/data_utils/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
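"""Core timing primitives: `Frequency` and `TimedArray`.

A minimal sketch (values are illustrative; an fMRI TR of 1.49 s corresponds
to a sampling frequency of 1/1.49 Hz):

    import numpy as np

    freq = Frequency(1 / 1.49)
    freq.to_ind(14.9)  # -> 10 samples
    freq.to_sec(10)    # -> 14.9 seconds

    ta = TimedArray(frequency=2.0, start=0.0, data=np.zeros((3, 8)))
    ta.duration        # -> 4.0 seconds (8 samples at 2 Hz)
"""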
6 | import logging 7 | import typing as tp 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | import pydantic 12 | import yaml 13 | from typing_extensions import Annotated 14 | 15 | PathLike = str | Path 16 | 17 | 18 | logger = logging.getLogger("data_utils") 19 | _handler = logging.StreamHandler() 20 | _formatter = logging.Formatter( 21 | "%(asctime)s - %(levelname)s - %(name)s:%(lineno)d - %(message)s", "%Y-%m-%d %H:%M:%S" 22 | ) 23 | _handler.setFormatter(_formatter) 24 | logger.addHandler(_handler) 25 | logger.setLevel(logging.INFO) 26 | 27 | 28 | def _int_cast(v: tp.Any) -> tp.Any: 29 | 30 | if isinstance(v, int): 31 | return str(v) 32 | return v 33 | 34 | 35 | StrCast = Annotated[str, pydantic.BeforeValidator(_int_cast)] 36 | CACHE_FOLDER = Path.home() / ".cache/data_utils/" 37 | CACHE_FOLDER.mkdir(parents=True, exist_ok=True) 38 | 39 | 40 | class Frequency(float): 41 | 42 | @tp.overload 43 | def to_ind(self, seconds: float) -> int: ... 44 | 45 | @tp.overload 46 | def to_ind(self, seconds: np.ndarray) -> np.ndarray: ... 47 | 48 | def to_ind(self, seconds: tp.Any) -> tp.Any: 49 | 50 | if isinstance(seconds, np.ndarray): 51 | return np.round(seconds * self).astype(int) 52 | return int(round(seconds * self)) 53 | 54 | @tp.overload 55 | def to_sec(self, index: int) -> float: ... 56 | 57 | @tp.overload 58 | def to_sec(self, index: np.ndarray) -> np.ndarray: ... 59 | 60 | def to_sec(self, index: tp.Any) -> tp.Any: 61 | 62 | return index / self 63 | 64 | @staticmethod 65 | def _yaml_representer(dumper, data): 66 | 67 | return dumper.represent_scalar("tag:yaml.org,2002:float", str(float(data))) 68 | 69 | 70 | class TimedArray: 71 | def __init__( 72 | self, 73 | *, 74 | frequency: float, 75 | start: float, 76 | data: np.ndarray | None = None, 77 | duration: float | None = None, 78 | aggregation: str = "sum", 79 | ) -> None: 80 | 81 | self.frequency = Frequency(frequency) 82 | self.start = start 83 | self.aggregation = aggregation 84 | exp_size = 0 85 | if duration is not None and duration < 0: 86 | raise ValueError(f"duration should be None or >=0, got {duration}") 87 | 88 | if data is None: 89 | if duration is None: 90 | raise ValueError("Missing data or duration") 91 | 92 | if not frequency: 93 | data = np.zeros((0,)) 94 | else: 95 | exp_size = max(1, self.frequency.to_ind(duration)) 96 | data = np.zeros((0, exp_size)) 97 | self.data = data 98 | if frequency and duration is not None: 99 | exp_size = max(1, self.frequency.to_ind(duration)) 100 | if not self.data.shape[-1]: 101 | msg = "Last dimension is empty but frequency is not null " 102 | msg += f"(shape={self.data.shape})" 103 | raise ValueError(msg) 104 | if abs(data.shape[-1] - exp_size) > 2: 105 | msg = f"Data has incorrect (last) dimension {data.shape} for duration " 106 | msg += f"{duration} and frequency {frequency} (expected {exp_size})" 107 | raise ValueError(msg) 108 | if frequency: 109 | self.duration = self.frequency.to_sec(data.shape[-1]) 110 | elif duration is None: 111 | raise ValueError(f"duration must be provided if {frequency=}") 112 | else: 113 | self.duration = duration 114 | 115 | self._overlapping_data_count: None | np.ndarray = None 116 | if aggregation == "average": 117 | num = self.data.shape[-1] if self.frequency else 1 118 | self._overlapping_data_count = np.zeros(num, dtype=int) 119 | elif aggregation != "sum": 120 | raise ValueError(f"Unknown {aggregation=}") 121 | 122 | def __repr__(self) -> str: 123 | cls = self.__class__.__name__ 124 | fields = 
"frequency,start,duration,aggregation,data".split(",") 125 | string = ",".join(f"{f}={getattr(self, f)}" for f in fields) 126 | return f"{cls}({string})" 127 | 128 | def __iadd__(self, other: "TimedArray") -> "TimedArray": 129 | if other.frequency and self.frequency != other.frequency: 130 | diff = abs(self.frequency - other.frequency) 131 | if diff * max(self.duration, other.duration) >= 0.5: 132 | 133 | msg = f"Cannot add with different (non-0) frequencies ({other.frequency} and {self.frequency})" 134 | raise ValueError(msg) 135 | if not self.data.size: 136 | 137 | last = -1 if other.frequency else None 138 | shape = other.data.shape[:last] 139 | if self.frequency: 140 | shape += (self.data.shape[-1],) 141 | self.data = np.zeros(shape, dtype=other.data.dtype) 142 | if self.frequency: 143 | slices = [ 144 | sa1._overlap_slice(sa2.start, sa2.duration) 145 | for sa1, sa2 in [(self, other), (other, self)] 146 | ] 147 | if slices[0] is None or slices[1] is None: 148 | return self 149 | 150 | self_slice = slices[0][-1] 151 | other_slice = slices[1][-1] 152 | else: 153 | self_slice = None 154 | other_slice = None 155 | if self._overlapping_data_count is None: 156 | 157 | self.data[..., self_slice] += other.data[..., other_slice] 158 | else: 159 | 160 | counts = self._overlapping_data_count[..., self_slice] 161 | upd = counts / (1.0 + counts) 162 | self.data[..., self_slice] *= upd 163 | self.data[..., self_slice] += (1 - upd) * other.data[..., other_slice] 164 | counts += 1 165 | return self 166 | 167 | def _overlap_slice( 168 | self, start: float, duration: float 169 | ) -> tuple[float, float, slice | None] | None: 170 | if duration < 0: 171 | raise ValueError(f"duration should be >=0, got {duration=}") 172 | overlap_start = max(start, self.start) 173 | overlap_stop = min(start + duration, self.start + self.duration) 174 | if overlap_stop < overlap_start: 175 | return None 176 | 177 | if overlap_stop == overlap_start and self.duration and duration: 178 | return None 179 | 180 | if not self.frequency: 181 | return overlap_start, overlap_stop - overlap_start, None 182 | start_ind = self.frequency.to_ind(overlap_start - self.start) 183 | duration_ind = self.frequency.to_ind(overlap_stop - overlap_start) 184 | 185 | if duration_ind <= 0: 186 | 187 | duration_ind = 1 188 | 189 | tps = self.data.shape[-1] 190 | if start_ind > tps - duration_ind: 191 | start_ind = tps - duration_ind 192 | if start_ind < 0: 193 | raise RuntimeError(f"Fail for {start=} {duration=} on {self}") 194 | start = self.frequency.to_sec(start_ind) + self.start 195 | duration = self.frequency.to_sec(duration_ind) 196 | 197 | out = start, duration, slice(start_ind, start_ind + duration_ind) 198 | return out 199 | 200 | def overlap(self, start: float, duration: float) -> tp.Optional["TimedArray"]: 201 | 202 | out = self._overlap_slice(start, duration) 203 | if out is None: 204 | return None 205 | ostart, oduration, sl = out 206 | return TimedArray( 207 | frequency=self.frequency, 208 | start=ostart, 209 | duration=oduration, 210 | data=self.data[..., sl], 211 | ) 212 | 213 | 214 | yaml.representer.SafeRepresenter.add_representer(Frequency, Frequency._yaml_representer) 215 | -------------------------------------------------------------------------------- /data_utils/data_utils/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import importlib 7 | import logging 8 | import shutil 9 | import tempfile 10 | import typing as tp 11 | from collections import OrderedDict 12 | from pathlib import Path 13 | 14 | import exca 15 | import pandas as pd 16 | import pydantic 17 | 18 | from .base import PathLike, StrCast 19 | from .enhancers import Enhancer 20 | from .infra import CacheDict, MapInfra 21 | from .segments import validate_events 22 | from .utils import compress_string 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def _check_folder_path(path: PathLike, name: str) -> Path: 28 | 29 | path = Path(path) 30 | if not path.parent.exists(): 31 | raise RuntimeError(f"Parent folder {path.parent} of {name} must exist first.") 32 | path.mkdir(exist_ok=True) 33 | return path 34 | 35 | 36 | TIMELINES: tp.Dict[str, "BaseData"] = {} 37 | 38 | 39 | class BaseData(pydantic.BaseModel): 40 | 41 | subject: StrCast 42 | path: PathLike 43 | timeline: str = "" 44 | 45 | version: tp.ClassVar[str] = "v2" 46 | study: tp.ClassVar[str] 47 | device: tp.ClassVar[str] = "" 48 | 49 | @tp.final 50 | @classmethod 51 | def iter_timelines(cls, path: PathLike) -> tp.Iterator["BaseData"]: 52 | path = _check_folder_path(path, name="path") 53 | study = "Algonauts2025" 54 | if path.name.lower() != study.lower(): 55 | 56 | for name in (study, study.lower(), study.lower().replace("bold", "")): 57 | if (path / name).exists(): 58 | path = path / name 59 | logger.debug("Updating study path to %s", path) 60 | break 61 | found = False 62 | for tl in cls._iter_timelines(path): 63 | found = True 64 | yield tl 65 | if not found: 66 | raise RuntimeError(f"No timeline found for {study} in {path}") 67 | 68 | def __init_subclass__(cls) -> None: 69 | super().__init_subclass__() 70 | 71 | def model_post_init(self, log__: tp.Any) -> None: 72 | super().model_post_init(log__) 73 | 74 | if not self.timeline: 75 | excludes = "path", "timeline" 76 | timeline = "Algonauts2025" 77 | for name, arg in type(self).model_fields.items(): 78 | if name in excludes or arg.init_var is False: 79 | continue 80 | value = getattr(self, name) 81 | timeline += f"_{name}-{str(value)}" 82 | self.timeline = compress_string(timeline) 83 | 84 | TIMELINES[self.timeline] = self 85 | 86 | @tp.final 87 | def load(self) -> pd.DataFrame: 88 | 89 | events = self._load_events() 90 | 91 | for col in ["subject", "timeline"]: 92 | if col in events: 93 | raise ValueError(f"Column {col} already exists in the events dataframe") 94 | events[col] = getattr(self, col) 95 | events["study"] = "Algonauts2025" 96 | 97 | events = validate_events(events) 98 | return events 99 | 100 | 101 | class StudyLoader(pydantic.BaseModel): 102 | 103 | path: PathLike 104 | query: str | None = None 105 | 106 | enhancers: tp.List[Enhancer] | OrderedDict[str, Enhancer] = [] 107 | infra: MapInfra = MapInfra(cluster="processpool") 108 | _build_infra: MapInfra = MapInfra() 109 | _timelines: tp.List[BaseData] | None = None 110 | 111 | def _exclude_from_cls_uid(self) -> tp.List[str]: 112 | return ["path"] 113 | 114 | def model_post_init(self, log__: tp.Any) -> None: 115 | if isinstance(self.enhancers, dict): 116 | version = exca.__version__ 117 | if tuple(int(n) for n in version.split(".")) < (0, 4, 2): 118 | msg = f"study_loader.enhancers cannot be a dict with exca<0.4.2 ({version=})" 119 | raise RuntimeError(msg) 120 | try: 121 | with tempfile.TemporaryDirectory() as tmp: 122 | _ = 
CacheDict(folder=tmp, cache_type="ParquetPandasDataFrame") 123 | except ValueError as e: 124 | raise RuntimeError('Run "pip install pyarrow" to enable study cache') from e 125 | 126 | study = self.study() 127 | 128 | name = self.__class__.__name__ + ",{version}" 129 | i = self.infra 130 | 131 | if i.cluster is not None and "pool" in i.cluster: 132 | if "max_jobs" not in i.model_fields_set: 133 | i.max_jobs = None 134 | 135 | i.version = type(self).model_fields["infra"].default.version + f"-{study.version}" 136 | folder_name = f"{name},Algonauts2025" 137 | i._uid_string = folder_name + "/{method},{uid}" 138 | 139 | names = ["folder", "version", "_uid_string", "mode"] 140 | self._build_infra._update({x: getattr(i, x) for x in names}) 141 | 142 | if self.infra.mode == "force" and self.infra.folder is not None: 143 | folder = Path(self.infra.folder) / folder_name 144 | if folder.exists(): 145 | shutil.rmtree(folder) 146 | 147 | def study(self) -> tp.Type[BaseData]: 148 | 149 | return getattr( 150 | importlib.import_module("data_utils.studies.algonauts2025"), "Algonauts2025" 151 | ) 152 | 153 | def iter_timelines(self) -> tp.Iterator[BaseData]: 154 | 155 | if self._timelines is None: 156 | self._timelines = list(self.study().iter_timelines(self.path)) 157 | else: 158 | for tl in self._timelines: 159 | TIMELINES[tl.timeline] = tl 160 | 161 | return iter(self._timelines) 162 | 163 | def study_summary(self, apply_query: bool = True) -> pd.DataFrame: 164 | 165 | out = pd.DataFrame([dict(tl) for tl in self.iter_timelines()]) 166 | out["subject"] = out.subject.apply(lambda x: f"Algonauts2025/{x}") 167 | if any(n in out.columns for n in ["subject_index", "timeline_index"]): 168 | msg = "Study dataframes are not allowed to have subject_index nor timeline_index" 169 | msg += f" in their column, found columns: {list(out.columns)}" 170 | raise RuntimeError(msg) 171 | groups = out.groupby("subject") 172 | out.loc[:, "subject_index"] = groups.ngroup() 173 | out.loc[:, "subject_timeline_index"] = groups.cumcount() 174 | out.loc[:, "timeline_index"] = out.index 175 | 176 | if apply_query and self.query is not None: 177 | out = out.query(self.query) 178 | return out 179 | 180 | def build(self) -> pd.DataFrame: 181 | 182 | for _ in self.iter_timelines(): 183 | pass 184 | query = self.query 185 | out = list(self._build([query]))[0] 186 | 187 | return out 188 | 189 | @infra.apply( 190 | item_uid=lambda item: item.timeline, 191 | exclude_from_cache_uid=("enhancers", "query"), 192 | ) 193 | def _load_timelines( 194 | self, timelines: tp.Iterable[BaseData] 195 | ) -> tp.Iterator[pd.DataFrame]: 196 | 197 | for tl in timelines: 198 | TIMELINES[tl.timeline] = tl 199 | 200 | out = tl.load() 201 | out.subject = f"Algonauts2025/{tl.subject}" 202 | yield out 203 | 204 | @_build_infra.apply( 205 | item_uid=str, 206 | exclude_from_cache_uid=("query",), 207 | cache_type="ParquetPandasDataFrame", 208 | ) 209 | def _build(self, queries: tp.Iterable[str | None]) -> tp.Iterator[pd.DataFrame]: 210 | 211 | timelines = list(self.iter_timelines()) 212 | summary: pd.DataFrame | None = None 213 | for query in queries: 214 | sub = timelines 215 | if query is not None: 216 | if summary is None: 217 | summary = self.study_summary(apply_query=False) 218 | selected = summary.query(query) 219 | sub = [timelines[i] for i in selected.index] 220 | if not sub: 221 | msg = f"No timeline found for Algonauts2025 with {query=}" 222 | raise RuntimeError(msg) 223 | events = pd.concat(list(self._load_timelines(sub))).reset_index(drop=True) 224 | if 
isinstance(self.enhancers, dict): 225 | enhancers = list(self.enhancers.values()) 226 | else: 227 | enhancers = list(self.enhancers) 228 | for enhancer in enhancers: 229 | events = enhancer(events) 230 | events = validate_events(events) 231 | yield events 232 | -------------------------------------------------------------------------------- /data_utils/data_utils/studies/algonauts2025.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import typing as tp 7 | from itertools import product 8 | from pathlib import Path 9 | 10 | import h5py 11 | import nibabel 12 | import numpy as np 13 | import pandas as pd 14 | 15 | from data_utils.data import BaseData 16 | from data_utils.download import Datalad 17 | 18 | 19 | class Algonauts2025(BaseData): 20 | task: tp.Literal["friends", "movie10"] 21 | movie: str 22 | chunk: str 23 | run: int = 0 24 | 25 | device: tp.ClassVar[str] = "Fmri" 26 | 27 | @classmethod 28 | def _download(cls, path: Path) -> None: 29 | Datalad( 30 | dset_dir=path, 31 | ).download() 32 | 33 | @classmethod 34 | def _iter_timelines( 35 | cls, 36 | path: str | Path, 37 | ): 38 | for subject in ["sub-01", "sub-02", "sub-03", "sub-05"]: 39 | for task in ["friends", "movie10"]: 40 | if task == "friends": 41 | season_episode_chunk = range(1, 8), range(1, 26), "abcd" 42 | for season, episode, chunk in product(*season_episode_chunk): 43 | timeline = cls( 44 | path=str(path), 45 | subject=subject, 46 | task=task, 47 | movie=str(season), 48 | chunk=f"e{episode:02d}{chunk}", 49 | ) 50 | stim_path = timeline._get_transcript_filepath() 51 | if ( 52 | (season == 5 and episode == 20 and chunk == "a") 53 | or (season == 4 and episode == 1 and chunk == "a") 54 | or (season == 6 and episode == 3 and chunk == "a") 55 | or (season == 4 and episode == 13 and chunk == "b") 56 | or (season == 4 and episode == 1 and chunk == "b") 57 | ): 58 | continue 59 | if stim_path.exists(): 60 | yield timeline 61 | elif task == "movie10": 62 | movie_chunk_run = ( 63 | ["bourne", "wolf", "life", "figures"], 64 | range(1, 18), 65 | [1, 2], 66 | ) 67 | for movie, chunk, run in product(*movie_chunk_run): 68 | 69 | if movie in ["bourne", "wolf"] and run == 2: 70 | continue 71 | timeline = cls( 72 | path=str(path), 73 | subject=subject, 74 | task=task, 75 | movie=movie, 76 | chunk=str(chunk), 77 | run=run, 78 | ) 79 | stim_path = timeline._get_transcript_filepath() 80 | if stim_path.exists(): 81 | yield timeline 82 | 83 | def _get_transcript_filepath(self): 84 | folder = ( 85 | Path(self.path) 86 | / "download" 87 | / "algonauts_2025.competitors" 88 | / "stimuli" 89 | / "transcripts" 90 | / f"{self.task}" 91 | ) 92 | if self.task == "friends": 93 | stim_path = ( 94 | folder 95 | / f"s{self.movie}" 96 | / f"friends_s{int(self.movie):02d}{self.chunk}.tsv" 97 | ) 98 | elif self.task == "movie10": 99 | stim_path = ( 100 | folder 101 | / f"{self.movie}" 102 | / f"movie10_{self.movie}{int(self.chunk):02d}.tsv" 103 | ) 104 | return stim_path 105 | 106 | def _get_movie_filepath(self) -> Path: 107 | folder = ( 108 | Path(self.path) 109 | / "download" 110 | / "algonauts_2025.competitors" 111 | / "stimuli" 112 | / "movies" 113 | / f"{self.task}" 114 | ) 115 | if self.task == "friends": 116 | stim_path = ( 117 | folder 118 | / f"s{self.movie}" 119 | / 
f"friends_s{int(self.movie):02d}{self.chunk}.mkv" 120 | ) 121 | elif self.task == "movie10": 122 | stim_path = ( 123 | folder / f"{self.movie}" / f"{self.movie}{int(self.chunk):02d}.mkv" 124 | ) 125 | return stim_path 126 | 127 | def _get_fmri_filepath(self) -> Path: 128 | folder = Path(self.path) / "download" / "algonauts_2025.competitors" / "fmri" 129 | subj_dir = folder / self.subject / "func" 130 | file_stem = f"{self.subject}_task-{self.task}_space-MNI152NLin2009cAsym_atlas-Schaefer18_parcel-1000Par7Net" 131 | if self.task == "friends": 132 | fmri_file = subj_dir / f"{file_stem}_desc-s123456_bold.h5" 133 | else: 134 | fmri_file = subj_dir / f"{file_stem}_bold.h5" 135 | return fmri_file 136 | 137 | def _load_fmri(self, timeline: str) -> nibabel.Nifti2Image: 138 | fmri_file = self._get_fmri_filepath() 139 | fmri = h5py.File(fmri_file, "r") 140 | if self.task == "friends": 141 | key = f"{int(self.movie):02d}{self.chunk}" 142 | else: 143 | key = f"{self.movie}{int(self.chunk):02d}" 144 | if self.movie in ["life", "figures"]: 145 | key += f"_run-{self.run}" 146 | selected_key = [key_ for key_ in fmri.keys() if key in key_] 147 | if len(selected_key) != 1: 148 | print(key, selected_key, list(fmri.keys())) 149 | raise ValueError(f"Multiple or no keys found, {key}, {list(fmri.keys())}") 150 | fmri = fmri[selected_key[0]] 151 | data = fmri[:].astype(np.float32) 152 | obj = nibabel.Nifti2Image(data.T, affine=np.eye(4)) 153 | return obj 154 | 155 | def _get_split(self) -> str: 156 | 157 | if self.task == "friends": 158 | if int(self.movie) in range(1, 7): 159 | return "train" 160 | elif int(self.movie) == 7: 161 | return "test" 162 | else: 163 | return "train" 164 | 165 | def _load_events(self) -> pd.DataFrame: 166 | 167 | all_events = [] 168 | if not (self.task == "friends" and self.movie == "7"): 169 | uri = f"method:_load_fmri?timeline={self.timeline}" 170 | fmri = self._load_fmri(timeline="") 171 | fmri_duration = fmri.shape[-1] * 1.49 172 | fmri_event = dict( 173 | type="Fmri", 174 | filepath=uri, 175 | start=0, 176 | frequency=1 / 1.49, 177 | duration=fmri_duration, 178 | ) 179 | all_events.append(fmri_event) 180 | 181 | movie_filepath = self._get_movie_filepath() 182 | movie_event = dict(type="Video", filepath=movie_filepath, start=0) 183 | all_events.append(movie_event) 184 | 185 | transcript_path = self._get_transcript_filepath() 186 | transcript_df = pd.read_csv(transcript_path, sep="\t") 187 | word_events = [] 188 | for _, row in transcript_df.iterrows(): 189 | words = eval(row["words_per_tr"]) 190 | starts = eval(row["onsets_per_tr"]) 191 | durations = eval(row["durations_per_tr"]) 192 | for word, start, duration in zip(words, starts, durations): 193 | event = dict( 194 | type="Word", 195 | text=word, 196 | start=start, 197 | duration=duration, 198 | stop=start + duration, 199 | language="english", 200 | ) 201 | word_events.append(event) 202 | if word_events: 203 | word_df = pd.DataFrame(word_events) 204 | text = " ".join(word_df["text"].tolist()) 205 | text_event = dict( 206 | type="Text", 207 | text=text, 208 | start=word_df["start"].min(), 209 | duration=word_df["stop"].max() - word_df["start"].min(), 210 | stop=word_df["stop"].max(), 211 | language="english", 212 | ) 213 | all_events.append(text_event) 214 | all_events.extend(word_events) 215 | 216 | events_df = pd.DataFrame(all_events) 217 | events_df["split"] = self._get_split() 218 | events_df["movie"] = "movie:" + str(self.movie) 219 | events_df["chunk"] = "chunk:" + str(self.chunk) 220 | return events_df 221 | 
-------------------------------------------------------------------------------- /data_utils/data_utils/segments.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import collections 7 | import dataclasses 8 | import logging 9 | import typing as tp 10 | import warnings 11 | 12 | import numpy as np 13 | import pandas as pd 14 | import tqdm 15 | 16 | from .events import Event 17 | from .utils import warn_once 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | @dataclasses.dataclass 23 | class Segment: 24 | 25 | start: float 26 | duration: float 27 | _index: np.ndarray 28 | 29 | ns_events: tp.List[Event] = dataclasses.field(default_factory=list) 30 | _trigger: float | tp.Dict[str, tp.Any] | None = None 31 | 32 | @property 33 | def events(self) -> pd.DataFrame: 34 | 35 | if not self.ns_events: 36 | raise RuntimeError(f"ns_events was not populated in {self}") 37 | if len(self.ns_events) != len(self._index): 38 | msg = f"Cannot recreate events dataframe as some rows were not actual Event\n(on segment={self})" 39 | raise RuntimeError(msg) 40 | return pd.DataFrame(index=self._index, data=[e.to_dict() for e in self.ns_events]) 41 | 42 | def subsegment(self, start: float, duration: float) -> "Segment": 43 | 44 | assert ( 45 | start >= 0 46 | ), "Start is relative to the segment start and must be non-negative" 47 | new_start = self.start + start 48 | new_duration = duration 49 | new_index, new_ns_events = [], [] 50 | for i, e in enumerate(self.ns_events): 51 | if e.start <= new_start + new_duration and e.start + e.duration >= new_start: 52 | new_index.append(self._index[i]) 53 | new_ns_events.append(e) 54 | new_index = np.array(new_index) 55 | return Segment( 56 | start=new_start, 57 | duration=new_duration, 58 | _index=new_index, 59 | ns_events=new_ns_events, 60 | _trigger=self._trigger, 61 | ) 62 | 63 | @property 64 | def event_list(self) -> list[Event]: 65 | raise RuntimeError( 66 | "segment.event_list is deprecated in favor of segment.ns_events" 67 | ) 68 | 69 | @property 70 | def stop(self) -> float: 71 | return self.start + self.duration 72 | 73 | def _to_feature(self) -> dict[str, tp.Any]: 74 | 75 | return { 76 | "start": self.start, 77 | "duration": self.duration, 78 | "events": self.ns_events, 79 | "trigger": self._trigger, 80 | } 81 | 82 | 83 | def _validate_event(event: pd.Series) -> dict[str, tp.Any]: 84 | 85 | event_type = event["type"] 86 | lower = {x.lower() for x in Event._CLASSES} 87 | if event_type in Event._CLASSES: 88 | event_class = Event._CLASSES[event_type] 89 | event_obj = event_class.from_dict(event).to_dict() 90 | 91 | event_dict = {**event, **event_obj} 92 | elif event_type in lower: 93 | raise ValueError(f"Legacy uncapitalized event {event}") 94 | else: 95 | warn_once( 96 | f'Unexpected type "{event["type"]}". Support for new event ' 97 | "types can be added by creating new `Event` classes in " 98 | "`data_utils.events`." 
99 | ) 100 | event_dict = {**event} 101 | 102 | return event_dict 103 | 104 | 105 | def validate_events(events: pd.DataFrame) -> pd.DataFrame: 106 | 107 | if events.empty: 108 | return events.copy() 109 | msg = 'events DataFrame must have a "type" column with strings' 110 | if "type" not in events.keys(): 111 | raise ValueError(msg) 112 | types = events["type"].unique() 113 | if not all(isinstance(typ, str) for typ in types): 114 | raise ValueError(msg) 115 | 116 | df = pd.DataFrame( 117 | events.apply(_validate_event, axis=1).tolist(), 118 | index=events.index, 119 | ) 120 | 121 | null = df.loc[df.duration <= 0, :] 122 | if not null.empty: 123 | types = null["type"].unique() 124 | msg = f"Found {len(null)} event(s) with null duration (types: {types})" 125 | warnings.warn(msg) 126 | 127 | dfs = [] 128 | for _, sub in df.groupby(by="timeline", sort=False): 129 | dfs.append( 130 | sub.sort_values( 131 | by=["start", "duration"], ascending=[True, False], ignore_index=True 132 | ) 133 | ) 134 | important = ["type", "start", "duration", "timeline"] 135 | df = pd.concat(dfs, ignore_index=True) 136 | 137 | columns = important + [c for c in df.columns if c not in important] 138 | df = df.loc[:, columns] 139 | 140 | df = df.assign(stop=lambda x: x.start + x.duration) 141 | return df 142 | 143 | 144 | def _prepare_strided_windows( 145 | start: float, 146 | stop: float, 147 | stride: float, 148 | duration: float, 149 | drop_incomplete: bool = True, 150 | ) -> tuple[np.ndarray, np.ndarray]: 151 | 152 | eps = 1e-8 153 | if drop_incomplete: 154 | stop -= duration 155 | starts = np.arange(start, stop + eps, stride) 156 | durations = np.full_like(starts, fill_value=duration) 157 | return starts, durations 158 | 159 | 160 | def iter_segments( 161 | events: pd.DataFrame, 162 | ) -> tp.Iterator[Segment]: 163 | 164 | starts: tp.Any 165 | durations: tp.Any 166 | creators = SegmentCreator.from_obj(events) 167 | 168 | for creator in creators.values(): 169 | starts, durations = _prepare_strided_windows( 170 | creator.starts.min() - 4.47, 171 | creator.stops.max() - 4.47, 172 | 149.0, 173 | 149.0, 174 | drop_incomplete=False, 175 | ) 176 | for start_, duration_ in zip(starts, durations): 177 | seg = creator.select(start=start_, duration=duration_) 178 | seg._trigger = start_ 179 | yield seg 180 | return 181 | 182 | 183 | def list_segments( 184 | events: pd.DataFrame, 185 | ) -> list[Segment]: 186 | return list(iter_segments(**locals())) 187 | 188 | 189 | def find_enclosed(df: pd.DataFrame, start: float, duration: float) -> pd.Series: 190 | estart = np.array(df.start) 191 | estop = estart + np.array(df.duration) 192 | is_enclosed = np.logical_and(estart >= start, estop <= start + duration) 193 | return pd.Series(df.index[is_enclosed]) 194 | 195 | 196 | def find_overlap( 197 | events: pd.DataFrame, 198 | idx: int | pd.Series | None = None, 199 | *, 200 | start: float = 0.0, 201 | duration: float | np.ndarray | None = None, 202 | ) -> pd.Series: 203 | 204 | if idx is None: 205 | 206 | assert duration is not None 207 | assert events.timeline.nunique() == 1 208 | has_overlap = (events.start >= start) & (events.start < start + duration) 209 | has_overlap |= (events.start + events.duration > start) & ( 210 | events.start + events.duration <= start + duration 211 | ) 212 | has_overlap |= (events.start <= start) & ( 213 | events.start + events.duration >= start + duration 214 | ) 215 | 216 | out = events.index[has_overlap] 217 | return pd.Series(out) 218 | else: 219 | sel: list[int] = [] 220 | for segment in iter_segments( 
221 | events, 222 | idx=idx, 223 | start=start, 224 | duration=duration, 225 | stride=None, 226 | ): 227 | sel.extend(segment._index.tolist()) 228 | 229 | return pd.Series(sel) 230 | 231 | 232 | class SegmentCreator: 233 | 234 | def __init__(self, events: list[Event]) -> None: 235 | timelines = {e.timeline for e in events} 236 | if len(timelines) > 1: 237 | name = self.__class__.__name__ 238 | msg = f"Cannot create {name} on several timelines, got {timelines}" 239 | raise ValueError(msg) 240 | self.events = np.array(events) 241 | self.starts = np.array([e.start for e in events]) 242 | self.indices = np.array([e._index for e in events]) 243 | self.stops = np.array([e.duration for e in events]) + self.starts 244 | 245 | @classmethod 246 | def from_obj(cls, obj: tp.Any) -> dict[str, "SegmentCreator"]: 247 | 248 | from data_utils import helpers 249 | 250 | timeline_events: dict[str, list[Event]] = collections.defaultdict(list) 251 | for e in helpers.extract_events(obj): 252 | timeline_events[e.timeline].append(e) 253 | timelines = list(timeline_events) 254 | if isinstance(obj, pd.DataFrame): 255 | 256 | timelines = list(obj.timeline.unique()) 257 | return {tl: cls(timeline_events[tl]) for tl in timelines} 258 | 259 | def select(self, start: float, duration: float) -> Segment: 260 | 261 | select = self.starts < start + duration 262 | select &= self.stops > start 263 | events = list(self.events[select]) 264 | index = self.indices[select] 265 | return Segment(ns_events=events, start=start, duration=duration, _index=index) 266 | -------------------------------------------------------------------------------- /data_utils/data_utils/features/text.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
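Usage sketch (not from the repository) for the segmentation module above: `validate_events` normalizes a raw events DataFrame, and `iter_segments` then cuts each timeline into fixed 149 s windows shifted back by 4.47 s (three 1.49 s TRs). This assumes `data_utils.helpers.extract_events` can parse the rows, as the feature extractors below also rely on.

import pandas as pd
from data_utils.segments import iter_segments, validate_events

events = validate_events(pd.DataFrame([
    dict(type="Word", text="hello", start=0.5, duration=0.3, timeline="tl0", language="english"),
    dict(type="Word", text="world", start=0.9, duration=0.4, timeline="tl0", language="english"),
]))
for seg in iter_segments(events):
    print(seg.start, seg.duration, len(seg.ns_events))  # -3.97 149.0 2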
6 | import logging 7 | import typing as tp 8 | import warnings 9 | 10 | import numpy as np 11 | import pandas as pd 12 | import pydantic 13 | import torch 14 | from exca import MapInfra 15 | from exca.utils import environment_variables 16 | from torch import nn 17 | from torch.utils.data import DataLoader, Dataset 18 | from tqdm import tqdm 19 | 20 | import data_utils as du 21 | from data_utils.base import Frequency as Frequency 22 | from data_utils.base import TimedArray 23 | from data_utils.events import Event, EventTypesHelper 24 | from data_utils.segments import Segment 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class TextDataset(Dataset): 30 | 31 | def __init__(self, events: tp.List[du.events.Word]): 32 | self.events = events 33 | 34 | def __len__(self): 35 | return len(self.events) 36 | 37 | def __getitem__(self, idx): 38 | sel = self.events[idx] 39 | return sel.text, sel.context 40 | 41 | 42 | class LLAMA3p2(pydantic.BaseModel): 43 | _event_types_helper: EventTypesHelper 44 | _missing_default: torch.Tensor | None = None 45 | layers: list[float] = [0.5, 0.75, 1.0] 46 | layer_aggregation: tp.Literal["group_mean"] | None = "group_mean" 47 | 48 | name: tp.Literal["LLAMA3p2"] = "LLAMA3p2" 49 | 50 | model_config = pydantic.ConfigDict(protected_namespaces=(), extra="forbid") 51 | infra: MapInfra = MapInfra() 52 | 53 | _model: nn.Module = pydantic.PrivateAttr() 54 | _tokenizer: nn.Module = pydantic.PrivateAttr() 55 | _pad_id: int = pydantic.PrivateAttr() 56 | 57 | def model_post_init(self, log__: tp.Any) -> None: 58 | super().model_post_init(log__) 59 | self._event_types_helper = EventTypesHelper("Word") 60 | if self.device == "auto": 61 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 62 | 63 | def prepare( 64 | self, obj: pd.DataFrame | tp.Sequence[Event] | tp.Sequence[Segment] 65 | ) -> None: 66 | from data_utils import helpers 67 | 68 | events = helpers.extract_events(obj, types=self._event_types_helper) 69 | 70 | self._get_data(events) 71 | if events: 72 | 73 | self( 74 | events[0], 75 | start=events[0].start, 76 | duration=0.001, 77 | trigger=events[0].to_dict(), 78 | ) 79 | 80 | def __call__( 81 | self, 82 | events: tp.Any, 83 | start: float, 84 | duration: float, 85 | trigger: float | dict[str, tp.Any] | None = None, 86 | ) -> torch.Tensor: 87 | _input_events = events 88 | 89 | from data_utils import helpers 90 | 91 | assert duration >= 0.0, f"{duration} must be >= 0." 
92 | events = helpers.extract_events(events, types=self._event_types_helper) 93 | 94 | if not events and self._missing_default is not None: 95 | default = self._missing_default 96 | freq = Frequency(2.0) 97 | if freq: 98 | n_times = max(1, freq.to_ind(duration)) 99 | reps = [1 for _ in range(default.ndim)] + [n_times] 100 | default = default.unsqueeze(-1).repeat(reps) 101 | return default 102 | 103 | tarrays = list( 104 | self._get_timed_arrays(events=events, start=start, duration=duration) 105 | ) 106 | 107 | time_info: dict[str, tp.Any] = { 108 | "start": start, 109 | "frequency": 2.0, 110 | "duration": duration, 111 | } 112 | 113 | out = TimedArray(aggregation="sum", **time_info) 114 | for ta in tarrays: 115 | out += ta 116 | tensor = torch.from_numpy(out.data) 117 | if not tensor.ndim: 118 | tensor = tensor.unsqueeze(0) 119 | 120 | if self._missing_default is None: 121 | 122 | shape = tuple(tensor.shape[:-1]) 123 | self._missing_default = torch.zeros(*shape, dtype=tensor.dtype) 124 | return tensor 125 | 126 | 127 | device: tp.Literal["auto", "cpu", "cuda"] = "auto" 128 | 129 | def _aggregate_layers(self, latents: np.ndarray) -> np.ndarray: 130 | layer_indices = np.unique( 131 | [int(i * (latents.shape[0] - 1)) for i in self.layers] 132 | ).tolist() 133 | 134 | if len(layer_indices) == 1: 135 | if self.layer_aggregation is None: 136 | return latents[layer_indices[0]][None, :] 137 | else: 138 | return latents[layer_indices[0]] 139 | else: 140 | if self.layer_aggregation == "group_mean": 141 | groups = [] 142 | layer_indices[-1] += 1 143 | for l1, l2 in zip(layer_indices[:-1], layer_indices[1:]): 144 | groups.append(latents[l1:l2].mean(0)) 145 | return np.stack(groups) 146 | elif self.layer_aggregation is None: 147 | return latents[layer_indices] 148 | else: 149 | raise ValueError(f"Unknown layer aggregation: {self.layer_aggregation}") 150 | 151 | 152 | 153 | @classmethod 154 | def _exclude_from_cls_uid(cls) -> tp.List[str]: 155 | return ["device"] 156 | 157 | def _exclude_from_cache_uid(self) -> tp.List[str]: 158 | return ["device"] + ["layers", "layer_aggregation"] 159 | 160 | @property 161 | def model(self) -> nn.Module: 162 | if not hasattr(self, "_model"): 163 | from transformers import AutoModel, AutoTokenizer 164 | 165 | kwargs: dict[str, tp.Any] = {} 166 | self._tokenizer = AutoTokenizer.from_pretrained( 167 | "meta-llama/Llama-3.2-3B", truncation_side="left", **kwargs 168 | ) 169 | Model = AutoModel 170 | 171 | if self.device == "accelerate": 172 | kwargs = {"device_map": "auto", "torch_dtype": torch.float16} 173 | self._model = Model.from_pretrained("meta-llama/Llama-3.2-3B", **kwargs) 174 | if self.device != "accelerate": 175 | self._model.to(self.device) 176 | self._model.eval() 177 | 178 | if self._tokenizer.pad_token is None: 179 | 180 | self._tokenizer.pad_token = self._tokenizer.eos_token 181 | self._pad_id = self.tokenizer.eos_token_id 182 | 183 | return self._model 184 | 185 | @property 186 | def tokenizer(self) -> nn.Module: 187 | self.model 188 | return self._tokenizer 189 | 190 | def _get_timed_arrays( 191 | self, events: list[du.events.Word], start: float, duration: float 192 | ) -> tp.Iterable[TimedArray]: 193 | 194 | for event, latent in zip(events, self._get_data(events)): 195 | latent = self._aggregate_layers(latent) 196 | ta = TimedArray( 197 | frequency=0, 198 | duration=event.duration, 199 | start=event.start, 200 | data=latent, 201 | ) 202 | yield ta 203 | 204 | @infra.apply( 205 | item_uid=lambda event: f"{event.text}_{event.context}", 206 | 
exclude_from_cache_uid="method:_exclude_from_cache_uid", 207 | cache_type="MemmapArrayFile", 208 | ) 209 | def _get_data(self, events: tp.List[du.events.Word]) -> tp.Iterator[np.ndarray]: 210 | dataset = TextDataset(events) 211 | dloader = DataLoader(dataset, batch_size=8, shuffle=False) 212 | 213 | if len(dloader) > 1: 214 | dloader = tqdm(dloader, desc="Computing word embeddings") 215 | 216 | device = "auto" if self.device == "accelerate" else self.device 217 | if device == "auto": 218 | device = "cuda" if torch.cuda.is_available() else "cpu" 219 | for target_words, context in dloader: 220 | 221 | with environment_variables(TOKENIZERS_PARALLELISM="false"): 222 | text = context 223 | if isinstance(text, tuple): 224 | 225 | text = list(text) 226 | inputs = self.tokenizer( 227 | text, 228 | add_special_tokens=False, 229 | return_tensors="pt", 230 | padding=True, 231 | truncation=True, 232 | ).to(device) 233 | with torch.no_grad(): 234 | outputs = self.model(**inputs, output_hidden_states=True) 235 | if "hidden_states" in outputs: 236 | states = outputs.hidden_states 237 | else: 238 | 239 | states = outputs.encoder_hidden_states + outputs.decoder_hidden_states 240 | hidden_states = torch.stack([layer.cpu() for layer in states]) 241 | n_layers, n_batch, n_tokens, n_dims = hidden_states.shape 242 | 243 | for i, target_word in enumerate(target_words): 244 | 245 | hidden_state = hidden_states[:, i] 246 | 247 | n_pads = sum(inputs["input_ids"][i].cpu().numpy() == self._pad_id) 248 | 249 | if n_pads: 250 | hidden_state = hidden_state[:, :-n_pads] 251 | 252 | word_state = hidden_state[:, -len(target_word) :] 253 | 254 | word_state = word_state.mean(axis=1) 255 | out = word_state.cpu().numpy() 256 | yield out 257 | -------------------------------------------------------------------------------- /data_utils/data_utils/features/audio.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
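Usage sketch (not from the repository) for extracting word embeddings with the `LLAMA3p2` feature above. Passing a single event mirrors what `prepare` does internally; embeddings are cached by the exca `MapInfra` keyed on (text, context). Note that meta-llama/Llama-3.2-3B is a gated checkpoint, so this assumes Hugging Face access has been granted.

from data_utils.events import Word
from data_utils.features.text import LLAMA3p2

feature = LLAMA3p2(device="cpu")
word = Word(text="hello", context="she said hello", start=0.0, duration=0.4, timeline="tl0")
emb = feature(word, start=0.0, duration=0.5)
print(emb.shape)  # (layer groups, hidden size, time bins at 2 Hz)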
6 | 7 | import logging 8 | import typing as tp 9 | import warnings 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import pydantic 14 | import torch 15 | from torch import nn 16 | from torch.nn import functional as F 17 | 18 | import data_utils as du 19 | from data_utils.base import Frequency, TimedArray 20 | from data_utils.events import Event, EventTypesHelper 21 | from data_utils.infra import MapInfra 22 | from data_utils.segments import Segment 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | class Wav2VecBert(pydantic.BaseModel): 28 | 29 | name: tp.Literal["Wav2VecBert"] = "Wav2VecBert" 30 | 31 | 32 | device: tp.Literal["auto", "cpu", "cuda"] = "auto" 33 | layer_aggregation: tp.Literal["group_mean"] | None = "group_mean" 34 | 35 | _model: nn.Module 36 | _feature_extractor: nn.Module 37 | 38 | infra: MapInfra = MapInfra() 39 | _event_types_helper: EventTypesHelper 40 | _missing_default: torch.Tensor | None = None 41 | layers: list[float] = [0.5, 0.75, 1.0] 42 | model_config = pydantic.ConfigDict(protected_namespaces=(), extra="forbid") 43 | 44 | def _get_sound_model(self) -> torch.nn.Module: 45 | from transformers import Wav2Vec2BertModel 46 | 47 | _model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0") 48 | _model.to(self.device) 49 | _model.eval() 50 | return _model 51 | 52 | 53 | def model_post_init(self, log__: tp.Any) -> None: 54 | super().model_post_init(log__) 55 | self._event_types_helper = EventTypesHelper("Sound") 56 | if self.device == "auto": 57 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 58 | 59 | def prepare( 60 | self, obj: pd.DataFrame | tp.Sequence[Event] | tp.Sequence[Segment] 61 | ) -> None: 62 | from data_utils import helpers 63 | 64 | events = helpers.extract_events(obj, types=self._event_types_helper) 65 | 66 | self._get_data(events) 67 | if events: 68 | 69 | self( 70 | events[0], 71 | start=events[0].start, 72 | duration=0.001, 73 | trigger=events[0].to_dict(), 74 | ) 75 | 76 | def __call__( 77 | self, 78 | events: tp.Any, 79 | start: float, 80 | duration: float, 81 | trigger: float | dict[str, tp.Any] | None = None, 82 | ) -> torch.Tensor: 83 | _input_events = events 84 | 85 | from data_utils import helpers 86 | 87 | events = helpers.extract_events(events, types=self._event_types_helper) 88 | 89 | if not events and self._missing_default is not None: 90 | default = self._missing_default 91 | freq = Frequency(2.0) 92 | if freq: 93 | n_times = max(1, freq.to_ind(duration)) 94 | reps = [1 for _ in range(default.ndim)] + [n_times] 95 | default = default.unsqueeze(-1).repeat(reps) 96 | return default 97 | 98 | 99 | 100 | tarrays = list( 101 | self._get_timed_arrays(events=events, start=start, duration=duration) 102 | ) 103 | 104 | time_info: dict[str, tp.Any] = { 105 | "start": start, 106 | "frequency": 2.0, 107 | "duration": duration, 108 | } 109 | out = TimedArray(aggregation="sum", **time_info) 110 | for ta in tarrays: 111 | out += ta 112 | tensor = torch.from_numpy(out.data) 113 | if not tensor.ndim: 114 | tensor = tensor.unsqueeze(0) 115 | 116 | if self._missing_default is None: 117 | 118 | shape = tuple(tensor.shape[:-1]) 119 | self._missing_default = torch.zeros(*shape, dtype=tensor.dtype) 120 | return tensor 121 | 122 | 123 | def _preprocess_wav(self, wav: torch.Tensor) -> torch.Tensor: 124 | wav = torch.mean(wav, dim=1) 125 | 126 | wav = (wav - wav.mean()) / (1e-8 + wav.std()) 127 | return wav 128 | 129 | def _resample_wav( 130 | self, wav: torch.Tensor, old_frequency: float, new_frequency: float 131 | ) 
-> torch.Tensor: 132 | old_frequency, new_frequency = int(old_frequency), int(new_frequency) 133 | import julius 134 | 135 | wav = julius.resample.ResampleFrac(old_sr=old_frequency, new_sr=new_frequency)( 136 | wav.T 137 | ).T 138 | return wav 139 | 140 | @infra.apply( 141 | item_uid=lambda event: f"{event.filepath}_{event.offset:.2f}_{event.duration:.2f}", 142 | exclude_from_cache_uid="method:_exclude_from_cache_uid", 143 | cache_type="MemmapArrayFile", 144 | ) 145 | def _get_data(self, events: list[du.events.Event]) -> tp.Iterator[np.ndarray]: 146 | if len(events) > 1: 147 | from tqdm import tqdm 148 | 149 | events = tqdm(events, desc="Computing audio embeddings") 150 | 151 | for event in events: 152 | if isinstance(event, du.events.Sound): 153 | wav = event.read() 154 | sfreq = event.frequency 155 | elif isinstance(event, du.events.Video): 156 | audio = event.read().audio 157 | wav = torch.tensor(audio.to_soundarray(), dtype=torch.float32) 158 | sfreq = audio.fps 159 | wav = self._resample_wav(wav, sfreq, self._input_frequency) 160 | wav = self._preprocess_wav(wav) 161 | latents = self._process_wav(wav) 162 | 163 | timepoints = Frequency(2.0).to_ind(event.duration) 164 | 165 | if abs(timepoints - latents.shape[-1]) > 0: 166 | if len(latents.shape) == 2: 167 | 168 | latents = F.interpolate(latents[None], timepoints)[0] 169 | else: 170 | 171 | latents = F.interpolate(latents, timepoints) 172 | yield latents.numpy() 173 | 174 | def _aggregate_layers(self, latents: np.ndarray) -> np.ndarray: 175 | layer_indices = np.unique( 176 | [int(i * (latents.shape[0] - 1)) for i in self.layers] 177 | ).tolist() 178 | 179 | if len(layer_indices) == 1: 180 | if self.layer_aggregation is None: 181 | return latents[layer_indices[0]][None, :] 182 | else: 183 | return latents[layer_indices[0]] 184 | else: 185 | if self.layer_aggregation == "group_mean": 186 | groups = [] 187 | layer_indices[-1] += 1 188 | for l1, l2 in zip(layer_indices[:-1], layer_indices[1:]): 189 | groups.append(latents[l1:l2].mean(0)) 190 | return np.stack(groups) 191 | elif self.layer_aggregation is None: 192 | return latents[layer_indices] 193 | else: 194 | raise ValueError(f"Unknown layer aggregation: {self.layer_aggregation}") 195 | 196 | @property 197 | def _input_frequency(self) -> float: 198 | return getattr(self.feature_extractor, "sampling_rate", 16_000) 199 | 200 | @classmethod 201 | def _exclude_from_cls_uid(cls) -> list[str]: 202 | return ["device"] 203 | 204 | def _exclude_from_cache_uid(self) -> list[str]: 205 | return ["device"] + ["layers", "layer_aggregation"] 206 | 207 | @property 208 | def feature_extractor(self) -> nn.Module: 209 | if not hasattr(self, "_feature_extractor"): 210 | self._feature_extractor = self._get_feature_extractor() 211 | return self._feature_extractor 212 | 213 | @property 214 | def model(self) -> nn.Module: 215 | if not hasattr(self, "_model"): 216 | self._model = self._get_sound_model() 217 | return self._model 218 | 219 | def _get_feature_extractor(self) -> torch.nn.Module: 220 | from transformers import AutoFeatureExtractor 221 | 222 | return AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0") 223 | 224 | def _get_features(self, wav): 225 | out = self._feature_extractor( 226 | wav, 227 | return_tensors="pt", 228 | sampling_rate=self.feature_extractor.sampling_rate, 229 | do_normalize=True, 230 | ) 231 | try: 232 | return out["input_features"] 233 | except KeyError: 234 | return out["input_values"] 235 | 236 | def _get_timed_arrays( 237 | self, events: list[du.events.Event], start: 
float, duration: float 238 | ) -> tp.Iterable[TimedArray]: 239 | freq = 2.0 240 | for latent, event in zip(self._get_data(events), events): 241 | if freq is None: 242 | 243 | freq = latent.shape[-1] / event.duration 244 | 245 | tdata = TimedArray(data=latent, start=event.start, frequency=freq) 246 | sub = tdata.overlap(start=start, duration=duration) 247 | if sub is None: 248 | 249 | sub = tdata.overlap(start=tdata.start, duration=0) 250 | sub.data = self._aggregate_layers(sub.data) 251 | yield sub 252 | 253 | def _process_wav(self, wav: torch.Tensor) -> torch.Tensor: 254 | features = self._get_features(wav) 255 | with torch.no_grad(): 256 | outputs = self.model(features.to(self.device), output_hidden_states=True) 257 | out: tp.Any = outputs.get("hidden_states") 258 | if isinstance(out, tuple): 259 | out = torch.stack(out) 260 | 261 | out = out.squeeze(1).detach().cpu().clone().transpose(-1, -2).numpy() 262 | 263 | return torch.Tensor(out) 264 | -------------------------------------------------------------------------------- /data_utils/data_utils/features/video.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import logging 7 | import typing as tp 8 | import warnings 9 | 10 | import numpy as np 11 | import pandas as pd 12 | import pydantic 13 | import torch 14 | import torch.nn as nn 15 | from exca import MapInfra 16 | from tqdm import tqdm 17 | 18 | import data_utils.events as evts 19 | from data_utils.base import Frequency, TimedArray 20 | from data_utils.events import Event, EventTypesHelper 21 | from data_utils.segments import Segment 22 | from data_utils.utils import ignore_all 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def _fix_pixel_values(inputs: dict[str, tp.Any]) -> None: 28 | if "pixel_values" in inputs: 29 | nans = inputs["pixel_values"].isnan() 30 | if nans.any(): 31 | inputs["pixel_values"][nans] = 0 32 | inputs["pixel_values"] = inputs["pixel_values"].float() 33 | 34 | 35 | class _VideoImage(evts.Image): 36 | 37 | start: float = 0.0 38 | timeline: str = "fake" 39 | duration: float = 1.0 40 | video: tp.Any 41 | time: float = 0.0 42 | filepath: str = "" 43 | 44 | def model_post_init(self, log__: tp.Any) -> None: 45 | self.filepath = f"{self.video.filename}:{self.time:.3f}" 46 | super().model_post_init(log__) 47 | 48 | def _read(self) -> tp.Any: 49 | import PIL 50 | 51 | with ignore_all(): 52 | img = self.video.get_frame(self.time) 53 | return PIL.Image.fromarray(img.astype("uint8")) 54 | 55 | 56 | class VJEPA2(pydantic.BaseModel): 57 | _event_types_helper: EventTypesHelper 58 | _missing_default: torch.Tensor | None = None 59 | model_config = pydantic.ConfigDict(protected_namespaces=(), extra="forbid") 60 | layers: list[float] = [0.5, 0.75, 1.0] 61 | layer_aggregation: tp.Literal["group_mean"] | None = "group_mean" 62 | 63 | name: tp.Literal["VJEPA2"] = "VJEPA2" 64 | device: tp.Literal["auto", "cpu", "cuda"] = "auto" 65 | 66 | _model: nn.Module = pydantic.PrivateAttr() 67 | 68 | infra: MapInfra = MapInfra() 69 | 70 | @classmethod 71 | def __pydantic_init_subclass__(cls, **kwargs: tp.Any) -> None: 72 | super().__pydantic_init_subclass__(**kwargs) 73 | 74 | super().__init_subclass__() 75 | 76 | def model_post_init(self, log__: tp.Any) -> None: 77 | super().model_post_init(log__) 78 | 
self._event_types_helper = EventTypesHelper("Video") 79 | if self.device == "auto": 80 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 81 | 82 | def prepare( 83 | self, obj: pd.DataFrame | tp.Sequence[Event] | tp.Sequence[Segment] 84 | ) -> None: 85 | from data_utils import helpers 86 | 87 | events = helpers.extract_events(obj, types=self._event_types_helper) 88 | self._get_data(events) 89 | if events: 90 | 91 | self( 92 | events[0], 93 | start=events[0].start, 94 | duration=0.001, 95 | trigger=events[0].to_dict(), 96 | ) 97 | 98 | def __call__( 99 | self, 100 | events: tp.Any, 101 | start: float, 102 | duration: float, 103 | trigger: float | dict[str, tp.Any] | None = None, 104 | ) -> torch.Tensor: 105 | _input_events = events 106 | 107 | from data_utils import helpers 108 | 109 | assert duration >= 0.0, f"{duration} must be >= 0." 110 | event_types = self._event_types_helper.classes 111 | name = self.__class__.__name__ 112 | events = helpers.extract_events(events, types=self._event_types_helper) 113 | 114 | if not events and self._missing_default is not None: 115 | default = self._missing_default 116 | freq = Frequency(2.0) 117 | if freq: 118 | n_times = max(1, freq.to_ind(duration)) 119 | reps = [1 for _ in range(default.ndim)] + [n_times] 120 | default = default.unsqueeze(-1).repeat(reps) 121 | return default 122 | 123 | 124 | 125 | tarrays = list( 126 | self._get_timed_arrays(events=events, start=start, duration=duration) 127 | ) 128 | 129 | time_info: dict[str, tp.Any] = { 130 | "start": start, 131 | "frequency": 2.0, 132 | "duration": duration, 133 | } 134 | out = TimedArray(aggregation="sum", **time_info) 135 | for ta in tarrays: 136 | out += ta 137 | tensor = torch.from_numpy(out.data) 138 | if not tensor.ndim: 139 | tensor = tensor.unsqueeze(0) 140 | 141 | if self._missing_default is None: 142 | 143 | shape = tuple(tensor.shape[:-1]) 144 | self._missing_default = torch.zeros(*shape, dtype=tensor.dtype) 145 | return tensor 146 | 147 | def _aggregate_layers(self, latents: np.ndarray) -> np.ndarray: 148 | layer_indices = np.unique( 149 | [int(i * (latents.shape[0] - 1)) for i in self.layers] 150 | ).tolist() 151 | 152 | if len(layer_indices) == 1: 153 | if self.layer_aggregation is None: 154 | return latents[layer_indices[0]][None, :] 155 | else: 156 | return latents[layer_indices[0]] 157 | else: 158 | if self.layer_aggregation == "group_mean": 159 | groups = [] 160 | layer_indices[-1] += 1 161 | for l1, l2 in zip(layer_indices[:-1], layer_indices[1:]): 162 | groups.append(latents[l1:l2].mean(0)) 163 | return np.stack(groups) 164 | elif self.layer_aggregation is None: 165 | return latents[layer_indices] 166 | else: 167 | raise ValueError(f"Unknown layer aggregation: {self.layer_aggregation}") 168 | 169 | def _exclude_from_cache_uid(self) -> list[str]: 170 | return ["device"] + ["layers", "layer_aggregation"] 171 | 172 | def _get_timed_arrays( 173 | self, events: list[evts.Video], start: float, duration: float 174 | ) -> tp.Iterable[TimedArray]: 175 | for event, latent in zip(events, self._get_data(events)): 176 | freq = 2.0 177 | ta: TimedArray = TimedArray( 178 | data=latent, 179 | frequency=freq, 180 | start=event.start, 181 | duration=event.duration, 182 | ) 183 | 184 | sub = ta.overlap(start=start, duration=duration) 185 | if sub is None: 186 | 187 | sub = ta.overlap(start=ta.start, duration=0) 188 | sub.data = self._aggregate_layers(sub.data) 189 | yield sub 190 | 191 | @infra.apply( 192 | item_uid=lambda event: 
f"{event.filepath}_{event.offset:.2f}_{event.duration:.2f}", 193 | exclude_from_cache_uid="method:_exclude_from_cache_uid", 194 | cache_type="MemmapArrayFile", 195 | ) 196 | def _get_data(self, events: tp.List[evts.Video]) -> tp.Iterator[np.ndarray]: 197 | logging.getLogger("data_utils").setLevel(logging.DEBUG) 198 | 199 | model = VideoModel() 200 | if model.model.device.type == "cpu": 201 | model.model.to(self.device) 202 | 203 | subtimes = list( 204 | k / model.num_frames * 4.0 for k in reversed(range(model.num_frames)) 205 | ) 206 | 207 | for event in events: 208 | video = event.read() 209 | expect_frames = Frequency(2.0).to_ind(event.duration) 210 | logger.debug( 211 | "Loaded Video (duration %ss at %sfps, shape %s):\n%s", 212 | video.duration, 213 | video.fps, 214 | tuple(video.size), 215 | event.filepath, 216 | ) 217 | 218 | times = np.linspace(0, video.duration, expect_frames + 1)[1:] 219 | 220 | output = np.array([]) 221 | 222 | for k, t in tqdm(enumerate(times), total=len(times), desc="Encoding video"): 223 | ims = [_VideoImage(video=video, time=max(0, t - t2)) for t2 in subtimes] 224 | data = np.array([np.array(i.read()) for i in ims]) 225 | t_embd = model.predict_hidden_states(data) 226 | t_embd = t_embd[0] 227 | 228 | embd = t_embd.mean(axis=1).cpu().numpy() 229 | if not output.size: 230 | output = np.zeros((len(times),) + embd.shape) 231 | logger.debug("Created Tensor with size %s", output.shape) 232 | output[k] = embd 233 | video.close() 234 | 235 | output = output.transpose(list(range(1, output.ndim)) + [0]) 236 | yield output 237 | 238 | 239 | class VideoModel: 240 | def __init__( 241 | self, 242 | ) -> None: 243 | super().__init__() 244 | from transformers import AutoModel as Model 245 | from transformers import AutoVideoProcessor as Processor 246 | 247 | self.model = Model.from_pretrained( 248 | "facebook/vjepa2-vitg-fpc64-256", output_hidden_states=True 249 | ) 250 | self.model.eval() 251 | 252 | self.processor = Processor.from_pretrained( 253 | "facebook/vjepa2-vitg-fpc64-256", do_rescale=True 254 | ) 255 | self.num_frames = 64 256 | 257 | def predict(self, images: np.ndarray) -> tp.Any: 258 | kwargs: dict[str, tp.Any] = {"text": "", "return_tensors": "pt"} 259 | field = "videos" 260 | del kwargs["text"] 261 | kwargs[field] = list(images) 262 | inputs = self.processor(**kwargs) 263 | 264 | _fix_pixel_values(inputs) 265 | inputs = inputs.to(self.model.device) 266 | with torch.inference_mode(): 267 | pred = self.model(**inputs) 268 | return pred 269 | 270 | def predict_hidden_states(self, images: np.ndarray) -> torch.Tensor: 271 | pred = self.predict(images) 272 | states = pred.hidden_states 273 | out = torch.cat([x.unsqueeze(1) for x in states], axis=1) 274 | return out 275 | -------------------------------------------------------------------------------- /data_utils/data_utils/events.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
import functools
import inspect
import logging
import typing as tp
import urllib.parse
from abc import abstractmethod
from pathlib import Path

import numpy as np
import pandas as pd
import pydantic

from .base import Frequency, StrCast
from .utils import ignore_all, warn_once

E = tp.TypeVar("E", bound="Event")
logger = logging.getLogger(__name__)


class Event(pydantic.BaseModel):
    """Base class for timeline events, keyed by start time and duration."""

    start: float
    timeline: str
    duration: pydantic.NonNegativeFloat = 0.0
    extra: dict[str, tp.Any] = {}
    type: tp.ClassVar[str] = "Event"
    _CLASSES: tp.ClassVar[dict[str, tp.Type["Event"]]] = {}
    _index: int | None = None

    def __init_subclass__(cls) -> None:
        super().__init_subclass__()
        # Register every subclass so that from_dict can dispatch on the "type" column.
        cls.type = cls.__name__
        Event._CLASSES[cls.__name__] = cls

    def model_post_init(self, log__: tp.Any) -> None:
        super().model_post_init(log__)
        if pd.isna(self.start):
            raise ValueError(f"Start time needs to be provided for {self!r}")

    @classmethod
    def from_dict(cls: tp.Type[E], row: tp.Any) -> E:
        """Build an event of the appropriate subclass from a dict or DataFrame row."""
        index: int | None = None
        if hasattr(row, "_asdict"):  # namedtuple from DataFrame.itertuples
            index = getattr(row, "Index", None)
            row = row._asdict()

        cls_ = cls._CLASSES[row["type"]]
        if not issubclass(cls_, cls):
            raise TypeError(f"{cls_} is not a subclass of {cls}")
        fs = set(cls_.model_fields)

        kwargs: dict[str, tp.Any] = {}
        extra = {}
        for k, v in row.items():
            if pd.isna(v):
                # missing value: let the field default apply
                continue
            if k in fs:
                kwargs[k] = v
            elif k != "type":
                if k.startswith("extra__"):
                    k = k[len("extra__"):]
                extra[k] = v
        kwargs.setdefault("extra", {}).update(extra)

        try:
            out = cls_(**kwargs)
        except Exception as e:
            logger.warning(
                "Event.from_dict parsing failed for input %s\nmapped to %s\n with error: %s)",
                row.to_string() if hasattr(row, "to_string") else row,
                kwargs,
                e,
            )
            raise
        out._index = index
        return out

    def to_dict(self) -> dict[str, tp.Any]:
        """Flatten the event (including its extra fields) into a plain dict."""
        out = dict(self.extra)
        out["type"] = self.type

        tag = "extra"
        fields = {x: str(y) if isinstance(y, Path) else y for x, y in self if x != tag}
        out.update(fields)
        return out

    @property
    def stop(self) -> float:
        return self.start + self.duration

    def __str__(self) -> str:
        core_fields = {k: v for k, v in self if k != "extra"}
        return ", ".join([f"{k}={v}" for k, v in core_fields.items()])


Event._CLASSES["Event"] = Event


class EventTypesHelper:
    """Resolves event type names or classes to the matching registered subclasses."""

    def __init__(self, event_types: str | tp.Type[Event] | tp.Sequence[str]) -> None:
        self.specified = event_types
        if inspect.isclass(event_types):
            self.classes: tp.Tuple[tp.Type[Event], ...] = (event_types,)
        else:
            if isinstance(event_types, str):
                event_types = (event_types,)
            try:
                self.classes = tuple(Event._CLASSES[x] for x in event_types)
            except KeyError as e:
                avail = list(Event._CLASSES)
                msg = f"{event_types} is an invalid event name, use one of {avail}"
                raise ValueError(msg) from e
        items = Event._CLASSES.items()
        self.names = [x for x, y in items if issubclass(y, self.classes)]


class BaseDataEvent(Event):
    """Event backed by a file on disk, or by a "method:" reader URI."""

    filepath: Path | str = ""
    frequency: float = 0
    _read_method: tp.Any = None

    def model_post_init(self, log__: tp.Any) -> None:
        super().model_post_init(log__)
        if not self.filepath:
            raise ValueError("A filepath must be provided")

        self._set_read_method()
        fp = str(self.filepath)
        self.filepath = fp
        if ":" not in str(fp):
            # plain path (not a "method:" URI): check that the file exists
            if not Path(fp).exists():
                warn_once(f"file missing: {fp}")

    def _set_read_method(self) -> None:
        try:
            if getattr(self, "_read_method", None) is not None:
                return
        except TypeError:
            # private attributes may not be initialized yet
            pass

        tag = "method:"
        fp = str(self.filepath)
        if not fp.startswith(tag):
            self._read_method = self._read
            return

        from .data import TIMELINES

        components = urllib.parse.urlparse(fp)
        assert components.netloc == ""
        assert components.params == ""
        assert components.fragment == ""

        inst = TIMELINES[self.timeline]
        kwargs = dict(urllib.parse.parse_qsl(components.query, strict_parsing=True))
        self._read_method = functools.partial(getattr(inst, components.path), **kwargs)

    def __hash__(self) -> int:
        # dicts are not hashable, so hash a stable string representation instead
        return hash(str(self.to_dict()))

    def __eq__(self, other: tp.Any) -> bool:
        if isinstance(other, self.__class__):
            return self.to_dict() == other.to_dict()
        return False

    def read(self) -> tp.Any:
        self._set_read_method()
        return self._read_method()

    @abstractmethod
    def _read(self) -> tp.Any:
        return

    def _missing_duration_or_frequency(self) -> bool:
        return any(not x or pd.isna(x) for x in [self.duration, self.frequency])


class BaseSplittableEvent(BaseDataEvent):
    """Data event that can be split into contiguous sub-events."""

    offset: pydantic.NonNegativeFloat = 0.0

    def _split(
        self, timepoints: tp.List[float], min_duration: float | None = None
    ) -> tp.Sequence["BaseSplittableEvent"]:
        # keep cut points strictly inside the event
        timepoints = [t for t in timepoints if 0 < t < self.duration]
        timepoints = sorted(set(timepoints))
        if min_duration:
            # drop cut points that would create segments shorter than min_duration
            delta_before = np.diff(timepoints, prepend=0)
            delta_after = np.diff(timepoints, append=self.duration)
            timepoints = [
                t
                for t, db, da in zip(timepoints, delta_before, delta_after)
                if db >= min_duration and da >= min_duration
            ]
        timepoints.append(self.duration)

        start = 0.0
        data = dict(self)
        cls = self.__class__
        events = []
        for stop in list(timepoints):
            if start >= stop:
                raise ValueError(
                    f"Timepoints should be strictly increasing (got {start} and {stop})"
                )
            data.update(
                start=self.start + start,
                duration=stop - start,
                offset=self.offset + start,
            )
            events.append(cls(**data))
            start = stop
        return events


class Image(BaseDataEvent):
    """Static image event, optionally with a caption."""

    caption: str = ""

    def _read(self) -> tp.Any:
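        # Pillow is imported lazily so the module stays importable without it.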
        import PIL.Image

        return PIL.Image.open(self.filepath).convert("RGB")

    def model_post_init(self, log__: tp.Any) -> None:
        super().model_post_init(log__)
        if self.duration <= 0:
            logger.info("Image event has non-positive duration and will be ignored.")


class Sound(BaseSplittableEvent):
    """Audio event; duration and sampling rate are read from the file if missing."""

    def model_post_init(self, log__: tp.Any) -> None:
        if not Path(self.filepath).exists():
            raise ValueError(f"Sound filepath does not exist: {self.filepath}")
        if self._missing_duration_or_frequency():
            import soundfile

            info = soundfile.info(str(self.filepath))
            self.frequency = Frequency(info.samplerate)
            self.duration = info.duration
        super().model_post_init(log__)

    def _read(self) -> tp.Any:
        import soundfile
        import torch

        sr = Frequency(self.frequency)
        offset = sr.to_ind(self.offset)
        num = sr.to_ind(self.duration)
        fp = str(self.filepath)
        wav = soundfile.read(fp, start=offset, frames=num)[0]
        out = torch.Tensor(wav)
        if out.ndim == 1:
            # mono audio: add a channel dimension
            out = out[:, None]
        return out


class Video(BaseSplittableEvent):
    """Video event; duration and frame rate are read from the file if missing."""

    def model_post_init(self, log__: tp.Any) -> None:
        if not Path(self.filepath).exists():
            raise ValueError(f"Missing video file {self.filepath}")
        if self._missing_duration_or_frequency():
            from moviepy import VideoFileClip

            with ignore_all():
                video = VideoFileClip(str(self.filepath))
                self.frequency = Frequency(video.fps)
                self.duration = video.duration
                video.close()
        super().model_post_init(log__)

    def _read(self) -> tp.Any:
        from moviepy import VideoFileClip

        with ignore_all():
            clip = VideoFileClip(str(self.filepath))
            start, end = self.offset, self.offset + self.duration
            assert end <= clip.duration
            clip = clip.subclipped(start, end)
        return clip


class BaseText(Event):
    """Base class for text-bearing events."""

    language: str = ""
    # note: the empty default bypasses validation; provided values must be non-empty
    text: str = pydantic.Field("", min_length=1)
    context: str = ""


class Text(BaseText):
    pass


class Sentence(BaseText):
    pass


class Word(BaseText):
    sentence: str = ""

    sentence_char: int | None = None


class Phoneme(BaseText):
    pass


class Fmri(BaseDataEvent):
    subject: StrCast = ""

    def model_post_init(self, log__: tp.Any) -> None:
        self.subject = str(self.subject)

        if self._missing_duration_or_frequency():
            raise ValueError(
                "Duration and frequency must be provided for Fmri event: "
                "Don't rely on get_zooms as the header is sometimes unreliable.\n"
                f"Got: {self}"
            )
        if not self.subject:
            raise ValueError("Missing 'subject' field")
        super().model_post_init(log__)

    def _read(self) -> tp.Any:
        import nibabel

        nii_img = nibabel.load(self.filepath, mmap=True)
        if nii_img.ndim not in (2, 4):
            msg = f"{self.filepath} should be 2D or 4D with time the last dim."
            raise ValueError(msg)
        return nii_img
--------------------------------------------------------------------------------
/algonauts2025/main.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import typing as tp
from pathlib import Path

import lightning.pytorch as pl
import numpy as np
import pandas as pd
import pydantic
import torch
import yaml
from exca import ConfDict, TaskInfra
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    StochasticWeightAveraging,
)
from lightning.pytorch.loggers import WandbLogger
from torch import nn
from torch.utils.data import DataLoader

import data_utils as du
from data_utils.data import StudyLoader
from data_utils.events import EventTypesHelper
from data_utils.features.audio import Wav2VecBert
from data_utils.features.neuro import Fmri
from data_utils.features.text import LLAMA3p2
from data_utils.features.video import VJEPA2
from data_utils.helpers import prepare_features
from data_utils.splitting import DeterministicSplitter
from modeling_utils.losses import LossConfig
from modeling_utils.metrics import MetricConfig
from modeling_utils.optimizers.base import LightningOptimizerConfig
from modeling_utils.utils import WandbLoggerConfig
from tqdm import tqdm
from einops import rearrange
from scipy.stats import pearsonr

from .callbacks import Benchmark, JitterWindows
from .model import FmriEncoder, FmriEncoderConfig
from .pl_module import BrainModule

# Keep a reference so the FmriEncoder import is not flagged as unused.
dummy = FmriEncoder


# Configure logger
LOGGER = logging.getLogger(__name__)
_handler = logging.StreamHandler()
_formatter = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s", "%H:%M:%S")
_handler.setFormatter(_formatter)
if not LOGGER.handlers:
    LOGGER.addHandler(_handler)
LOGGER.setLevel(logging.INFO)


class Data(pydantic.BaseModel):
    """Handles configuration and creation of DataLoaders from dataset and features."""

    model_config = pydantic.ConfigDict(extra="forbid")

    study: StudyLoader
    neuro: Fmri
    text_feature: LLAMA3p2 | None = None
    audio_feature: Wav2VecBert | None = None
    video_feature: VJEPA2 | None = None
    layers: list[float] | None = None
    layer_aggregation: tp.Literal["group_mean"] | None = None
    num_workers: int = 0

    def model_post_init(self, __context):
        super().model_post_init(__context)
        # Propagate the layer settings to every configured feature.
        for modality in ["text", "audio", "video"]:
            feature = getattr(self, f"{modality}_feature")
            if feature is None:
                continue
            if self.layers is not None:
                setattr(feature, "layers", self.layers)
            if self.layer_aggregation is not None:
                setattr(feature, "layer_aggregation", self.layer_aggregation)

    def get_events(self) -> pd.DataFrame:
        events = self.study.build()

        if "split" not in events.columns:
            events["split"] = "train"

        # Deterministically carve a 10% validation split out of the train chunks.
        train_sel = events.split == "train"
        splitter = DeterministicSplitter(ratios={"train": 0.9, "val": 0.1})
        values = events.loc[train_sel, "chunk"].unique()
        splits = [splitter(value) for value in values]
        if splits and "val" not in splits:
            splits[-1] = "val"  # need at least one val split
        events.loc[train_sel, "split"] = events.loc[train_sel, "chunk"].map(
            dict(zip(values, splits))
        )
        # check that all rows have a split assigned
        unassigned_events = events[events.split.isna()]
        if len(unassigned_events) > 0:
            msg = f"The following events do not have a split assigned: {unassigned_events.type.unique()}"
            if any(
                name.capitalize() in unassigned_events.type.unique()
                for name in ["Fmri", "text", "audio", "video"]
            ):
                raise ValueError(msg)
            else:
                LOGGER.warning(msg)

        cols = ["index", "subject", "timeline"]
        if "movie" in events.columns:
            cols.append("movie")
        if "chunk" in events.columns:
            cols.append("chunk")
        event_summary = events.reset_index().groupby(["split", "type"])[cols].nunique()
        LOGGER.info("Event summary: \n%s", event_summary)
        return events

    def get_loaders(
        self,
        events: pd.DataFrame | None = None,
        split_to_build: (
            tp.Literal["train", "val", "test", "all"]
            | list[tp.Literal["train", "val", "test", "all"]]
            | None
        ) = None,
    ) -> dict[str, DataLoader]:
        if events is None:
            events = self.get_events()
        features = {}
        for modality in ["text", "audio", "video"]:
            feature = getattr(self, f"{modality}_feature")
            # a modality without a configured feature is simply skipped
            if feature is not None:
                features[modality] = feature
        if "Fmri" in events.type.unique():
            features["fmri"] = self.neuro
        subject_id = du.features.SubjectEncoder()
        features["subject_id"] = subject_id

        features_to_type = {
            "text": "Word",
            "audio": "Sound",
            "video": "Video",
            "fmri": "Fmri",
            "subject_id": "Event",
        }

        # Drop features whose event types are absent from the study.
        features_to_remove = set()
        for feature_name in features:
            event_types = EventTypesHelper(features_to_type[feature_name]).names
            if not any(
                event_type in events.type.unique() for event_type in event_types
            ):
                features_to_remove.add(feature_name)
        for feature_name in features_to_remove:
            del features[feature_name]
            LOGGER.warning(
                "Removing feature %s as there are no corresponding events", feature_name
            )
        prepare_features(features, events)

        # Prepare dataloaders
        loaders = {}
        if isinstance(split_to_build, list):
            splits = split_to_build
        elif split_to_build is None:
            splits = ["train", "val", "test"]
        else:
            splits = [split_to_build]
        for split in splits:
            LOGGER.info("Building dataloader for split %s", split)
            if split == "all":
                sel = [True] * len(events)
                shuffle = False
            else:
                sel = events.split == split
                shuffle = {
                    "train": "train" in events.split.unique(),
                    "val": "val" in events.split.unique(),
                    "test": False,
                }[split]
            if not np.sum(sel):
                LOGGER.warning("No events found for split %s", split)
                continue
            segments = du.segments.list_segments(
                events[sel],
            )
            dataset = du.SegmentDataset(
                features=features,
                segments=segments,
            )
            dataloader = dataset.build_dataloader(
                shuffle=shuffle,
                num_workers=self.num_workers,
                batch_size=16,
            )
            loaders[split] = dataloader

        return loaders


class Experiment(pydantic.BaseModel):

    model_config = pydantic.ConfigDict(extra="forbid")

    data: Data
    # Reproducibility
    seed: int | None = 33
    # Model
    brain_model_config: FmriEncoderConfig
    # Loss
    loss: LossConfig
    # Optimization
    optim: LightningOptimizerConfig
    # Metrics
    metrics: list[MetricConfig]
    monitor: str = "val/pearson"
    # Weights & Biases
    wandb_config: WandbLoggerConfig | None = None
    # Hardware
    accelerator: str = "gpu"
    # Optim
    n_epochs: int = 10
    patience: int | None = None
    limit_train_batches: int | None = None
    # Others
    enable_progress_bar: bool = True
    log_every_n_steps: int | None = None
    fast_dev_run: bool = False
    save_checkpoints: bool = True
    # Eval
    checkpoint_path: str | None = None
    test_only: bool = False

    # Internal properties
    _trainer: pl.Trainer | None = None
    _brain_module: BrainModule | None = None
    _logger: WandbLogger | None = None

    # Others
    infra: TaskInfra = TaskInfra(version="1")

    def model_post_init(self, __context: tp.Any) -> None:
        super().model_post_init(__context)
        if self.infra.folder is None:
            msg = "infra.folder needs to be specified to save the results."
            raise ValueError(msg)
        # Update Trainer parameters based on infra
        self.infra.tasks_per_node = self.infra.gpus_per_node
        self.infra.slurm_use_srun = self.infra.gpus_per_node > 1
        if self.infra.gpus_per_node > 1:
            self.metrics = [
                m for m in self.metrics if m.name not in ["TopkAcc"]
            ]  # FIXME: TopkAcc is not supported in DDP

        if self.brain_model_config.n_subjects is None:
            self.brain_model_config.n_subjects = (
                self.data.study.study_summary().subject.nunique()
            )

    def _get_checkpoint_path(self) -> Path | None:
        if self.checkpoint_path:
            assert Path(
                self.checkpoint_path
            ).exists(), f"Checkpoint path {self.checkpoint_path} does not exist."
            checkpoint_path = Path(self.checkpoint_path)
        else:
            checkpoint_path = Path(self.infra.folder) / "last.ckpt"
            if not checkpoint_path.exists():
                checkpoint_path = None
        return checkpoint_path

    def _init_module(self, model: nn.Module) -> pl.LightningModule:
        # Setup torch-lightning module
        checkpoint_path = self._get_checkpoint_path()
        if checkpoint_path is not None:
            LOGGER.info("Loading model from %s", checkpoint_path)
            init_fn = BrainModule.load_from_checkpoint
            init_kwargs = {"checkpoint_path": checkpoint_path, "strict": False}
        else:
            init_fn = BrainModule
            init_kwargs = {}

        metrics = {
            split + "/" + metric.log_name: metric.build()
            for metric in self.metrics
            for split in ["val", "test"]
        }
        metrics = nn.ModuleDict(metrics)
        pl_module = init_fn(
            model=model,
            loss=self.loss.build(),
            optim_config=self.optim,
            metrics=metrics,
            max_epochs=self.n_epochs,
            config=ConfDict(self.model_dump()),
            **init_kwargs,
        )

        return pl_module

    def _setup_trainer(self, train_loader: DataLoader) -> pl.Trainer:
        root_data_dir = Path(self.data.study.path) / "algonauts2025" / "download"
        # Initialize brain model
        batch = next(iter(train_loader))
        feature_dims = {}
        for modality in ["text", "audio", "video"]:
            if modality in batch.data:  # B, L, D, T
                if batch.data[modality].ndim == 4:
                    feature_dims[modality] = (
                        batch.data[modality].shape[1],
                        batch.data[modality].shape[2],
                    )
                elif batch.data[modality].ndim == 3:
                    feature_dims[modality] = (
                        1,
                        batch.data[modality].shape[1],
                    )
                else:
                    raise ValueError(
                        f"Unexpected number of dimensions for modality {modality}: "
                        f"{batch.data[modality].ndim}"
                    )
            else:
                feature_dims[modality] = None
        if "fmri" in batch.data:
            fmri = batch.data["fmri"]
            n_outputs = fmri.shape[1]
            for metric in self.metrics:
                if hasattr(metric, "kwargs") and "num_outputs" in metric.kwargs:
                    metric.kwargs["num_outputs"] = n_outputs
        else:
            n_outputs = 1000
        n_output_timesteps = 100
        brain_model = self.brain_model_config.build(
            feature_dims=feature_dims,
            n_outputs=n_outputs,
            n_output_timesteps=n_output_timesteps,
        )

        LOGGER.info("Feature dims: %s", feature_dims)
        input_data = brain_model.aggregate_features(batch)
        LOGGER.info("Input shapes: %s", input_data.shape)
        LOGGER.info("Target shapes: %s", n_outputs)
        _ = brain_model(batch)
        total_params = sum(p.numel() for p in brain_model.parameters())
        LOGGER.info("Total parameters: %s", total_params)
        self._brain_module = self._init_module(brain_model)
        if self.monitor == "val/pearson":
            mode = "max"
        else:
            mode = "min"
        callbacks = [
            LearningRateMonitor(logging_interval="epoch"),
            JitterWindows(start_jitter_amount=10.0),
        ]
        if self.patience is not None:
            callbacks.append(
                EarlyStopping(monitor=self.monitor, mode=mode, patience=self.patience)
            )
        # SWA starts at 60% of training and anneals over the remaining 40%.
        annealing_epochs = int(self.n_epochs * 0.4)
        callbacks.append(
            StochasticWeightAveraging(
                swa_epoch_start=0.6,
                annealing_epochs=annealing_epochs,
                swa_lrs=1e-5,
                annealing_strategy="cos",
            )
        )
        if self.save_checkpoints:
            callbacks.append(
                ModelCheckpoint(
                    save_last=True,
                    save_top_k=1,
                    dirpath=self.infra.folder,
                    filename="best",
                    monitor=self.monitor,
                    mode=mode,
                    save_on_train_epoch_end=True,
                )
            )
        callbacks.append(Benchmark(root_data_dir))

        trainer = pl.Trainer(
            strategy=(
                "auto"
                if self.infra.gpus_per_node == 1
                else "ddp_find_unused_parameters_true"
            ),
            devices=self.infra.gpus_per_node,
            accelerator=self.accelerator,
            max_epochs=self.n_epochs,
            limit_train_batches=self.limit_train_batches,
            enable_progress_bar=self.enable_progress_bar,
            log_every_n_steps=self.log_every_n_steps,
            fast_dev_run=self.fast_dev_run,
            callbacks=callbacks,
            logger=self._logger,
            enable_checkpointing=self.save_checkpoints,
        )
        self._trainer = trainer
        return trainer

    def fit(self, train_loader: DataLoader, valid_loader: DataLoader) -> None:
        self._trainer.fit(
            model=self._brain_module,
            train_dataloaders=train_loader,
            val_dataloaders=valid_loader,
            ckpt_path=self._get_checkpoint_path(),
        )

    def test(self, test_loader: DataLoader) -> None:
        if self.infra.gpus_per_node > 1:
            LOGGER.info(
                "Destroying DDP process group to enable testing on a single device."
            )
            torch.distributed.destroy_process_group()
            if not self._trainer.is_global_zero:
                return
        if self.checkpoint_path:
            ckpt_path = self.checkpoint_path
        else:
            ckpt_path = None
        self._trainer.test(
            self._brain_module,
            dataloaders=test_loader,
            ckpt_path=ckpt_path,
        )

    def setup_run(self):
        # Symlink the slurm logs into the experiment folder for convenience.
        if self.infra.cluster and self.infra.status() != "not submitted":
            for out_type in ["stdout", "stderr"]:
                old_path = Path(getattr(self.infra.job().paths, out_type))
                new_path = Path(self.infra.folder) / f"log.{out_type}"
                try:
                    if new_path.exists():
                        os.remove(new_path)
                    os.symlink(old_path, new_path)
                except OSError:
                    # best effort: symlinks may be unsupported on some filesystems
                    pass
        config_path = Path(self.infra.folder) / "config.yaml"
        os.makedirs(self.infra.folder, exist_ok=True)
        with open(config_path, "w") as outfile:
            yaml.dump(
                self.model_dump(),
                outfile,
                indent=4,
                default_flow_style=False,
                sort_keys=False,
            )

    def compute_multidim_pearson(self, loader: DataLoader) -> np.ndarray:
        preds, trues = [], []
        model = self._brain_module
        model.eval()
        model.to("cuda")
        with torch.inference_mode():
            for batch in tqdm(loader, desc="Computing multidim pearson"):
                batch = batch.to("cuda")
                y_true = batch.data["fmri"].squeeze(-1)
                trues.append(y_true.detach().cpu().numpy())
                y_pred = model(batch)
                preds.append(y_pred.detach().cpu().numpy())
        preds, trues = np.concatenate(preds), np.concatenate(trues)
        # Flatten batch and time so each output dimension is a single series.
        preds = rearrange(preds, "b d t -> (b t) d")
        trues = rearrange(trues, "b d t -> (b t) d")
        pearson = np.zeros(trues.shape[1], dtype=np.float32)
        for p in range(len(pearson)):
            pearson[p] = pearsonr(trues[:, p], preds[:, p])[0]
        return pearson

    @infra.apply
    def run(self):
        self.setup_run()
        self._logger = (
            self.wandb_config.build(
                save_dir=self.infra.folder,
                xp_config=self.model_dump(),
                id=f"{self.wandb_config.group}-{self.infra.uid().split('-')[-1]}",
            )
            if self.wandb_config
            else None
        )

        if self.seed is not None:
            pl.seed_everything(self.seed, workers=True)
            np.random.seed(self.seed)
            torch.manual_seed(self.seed)

        loaders = self.data.get_loaders(
            split_to_build="test" if self.test_only else None
        )
        self._setup_trainer(next(iter(loaders.values())))

        if not self.test_only:
            self.fit(loaders["train"], loaders["val"])
            self._trainer.validate(self._brain_module, loaders["val"])

            metrics = self._trainer.callback_metrics
            metrics_df = pd.DataFrame([{k: v.item() for k, v in metrics.items()}])
            metrics_df.to_csv(Path(self.infra.folder) / "metrics.csv", index=False)

            pearson = self.compute_multidim_pearson(loaders["val"])
            np.save(Path(self.infra.folder) / "pearson.npy", pearson)

        self.test(loaders["test"])
--------------------------------------------------------------------------------