├── .gitignore
├── README.md
├── assets
│   └── fig1.png
├── configs
│   ├── README.md
│   ├── run.yaml
│   └── sweep.yaml
├── pyproject.toml
└── scOT
    ├── __init__.py
    ├── inference.py
    ├── metrics.py
    ├── model.py
    ├── problems
    │   ├── __init__.py
    │   ├── base.py
    │   ├── elliptic
    │   │   ├── __init__.py
    │   │   ├── helmholtz.py
    │   │   └── poisson.py
    │   ├── fluids
    │   │   ├── __init__.py
    │   │   ├── compressible.py
    │   │   ├── incompressible.py
    │   │   └── normalization_constants.py
    │   ├── reaction_diffusion
    │   │   ├── __init__.py
    │   │   └── allen_cahn.py
    │   └── wave
    │       ├── __init__.py
    │       └── acoustic.py
    ├── train.py
    ├── trainer.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 | 
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | 
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 | 
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 | 
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 | 
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 | 
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 | 
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 | 
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 | 
119 | # SageMath parsed files
120 | *.sage.py
121 | 
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 | 
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 | 
135 | # Rope project settings
136 | .ropeproject
137 | 
138 | # mkdocs documentation
139 | /site
140 | 
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 | 
146 | # Pyre type checker
147 | .pyre/
148 | 
149 | # pytype static type analyzer
150 | .pytype/
151 | 
152 | # Cython debug symbols
153 | cython_debug/
154 | 
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Poseidon: Efficient Foundation Models for PDEs
2 | 
3 | This is the source code for the paper [*Poseidon: Efficient Foundation Models for PDEs*](https://arxiv.org/abs/2405.19101). It can also be installed as a package if you want to use the models in your own code.
4 | 
5 | ![Poseidon](assets/fig1.png)
6 | 
7 | Find pretrained models and pretraining datasets in our collection on the [🤗 Hub – Pretrained Models and Pretraining Datasets](https://huggingface.co/collections/camlab-ethz/poseidon-664fa125729c53d8607e209a). All datasets corresponding to downstream tasks can be downloaded from the respective collection on the [🤗 Hub – Downstream Tasks](https://huggingface.co/collections/camlab-ethz/poseidon-downstream-tasks-664fa237cd6b0c097971ef14) as well. To use them, follow the respective sections below.
8 | 
9 | ## Usage
10 | 
11 | ### Installation & Requirements
12 | 
13 | After cloning this repository, run the following command inside this folder to install the package and all requirements:
14 | 
15 | ```bash
16 | pip install -e .
17 | ```
18 | 
19 | We recommend running the above command in a [virtual environment](https://docs.python.org/3/library/venv.html).
20 | 
21 | After installation, you can import the models and use the training and inference scripts from anywhere on your system.
22 | 
23 | ### Using the models in your own code
24 | 
25 | To use the (pretrained) models in your own code, you can use the following code snippet (after installing the package):
26 | 
27 | ```python
28 | from scOT.model import ScOT
29 | 
30 | model = ScOT.from_pretrained("camlab-ethz/Poseidon-<size>")
31 | ```
32 | 
33 | This will load the pretrained model from the 🤗 Hub. `<size>` has to be replaced by `T`, `B`, or `L` for the respective pretrained model. You can also load a model from a local path by providing the path to the `from_pretrained` method.
34 | 
35 | To finetune and replace embeddings and recovery parameters, load the model as follows:
36 | 
37 | ```python
38 | from scOT.model import ScOT
39 | 
40 | model = ScOT.from_pretrained("camlab-ethz/Poseidon-<size>", config=model_config, ignore_mismatched_sizes=True)
41 | ```
42 | 
43 | Here, `model_config` is a `ScOTConfig` with the correct input/output dimensions.
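You can also evaluate a loaded model end to end by reusing the helpers from the inference script. The following is a minimal sketch, assuming an assembled copy of the Wave-Layer dataset at the hypothetical path `./data` (see the dataset sections below for download and assembly):

```python
from scOT.problems.base import get_dataset
from scOT.inference import get_trainer, rollout

# Build a test set; "wave.Layer" and ./data are placeholder choices.
dataset = get_dataset(
    "wave.Layer", which="test", num_trajectories=1, data_path="./data"
)
# get_trainer wraps the model in a 🤗 Trainer purely for batched inference.
trainer = get_trainer("camlab-ethz/Poseidon-B", batch_size=16, dataset=dataset)
# One autoregressive step; returns predictions, labels, and error statistics.
predictions, labels, metrics = rollout(trainer, dataset, ar_steps=1)
```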
We also refer to [the training/finetuning script](scOT/train.py) (see below for usage), which might be easier.
44 | 
45 | ### Training & Finetuning
46 | 
47 | The easiest way to finetune **Poseidon** on your own dataset is to plug in your dataset and run the provided training script as follows:
48 | 
49 | ```bash
50 | accelerate launch scOT/train.py \
51 |     --config <CONFIG_PATH> \
52 |     --wandb_run_name <RUN_NAME> \
53 |     --wandb_project_name <PROJECT_NAME> \
54 |     --checkpoint_path <CHECKPOINT_BASE_PATH> \
55 |     --data_path <DATA_PATH> \
56 |     --finetune_from <PRETRAINED_MODEL_PATH_OR_HUB_ID> \
57 |     --replace_embedding_recovery
58 | ```
59 | 
60 | For more arguments and options, see the help message of the script:
61 | 
62 | ```bash
63 | accelerate launch scOT/train.py --help
64 | ```
65 | 
66 | Since the code is built on top of [🤗 Accelerate](https://huggingface.co/docs/accelerate/en/index), you should run `accelerate config` first.
67 | 
68 | We also make heavy use of [Weights and Biases](https://wandb.ai) to log and organise all our runs. The code might run without it (by setting `WANDB_MODE=disabled`), but we make no guarantees, as this will likely break the folder structure.
69 | 
70 | Most of the actual training configuration is set in a YAML config file; see [configs/run.yaml](configs/run.yaml) for all arguments to set for a single W&B run and [configs/sweep.yaml](configs/sweep.yaml) for a W&B sweep (multiple runs; see the [W&B documentation](https://docs.wandb.ai/guides/sweeps) on how to start a sweep). The config file is passed to the training script via the `--config` argument.
71 | 
72 | We run our pretrainings with the same script.
73 | 
74 | ### Inference/Testing
75 | 
76 | To evaluate a model on a dataset, you can use the inference script; for all possible arguments, see its help message:
77 | 
78 | ```bash
79 | python -m scOT.inference --help
80 | ```
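For example, a single-model evaluation could look as follows. This is a sketch: the paths are placeholders, the time indices are illustrative values for Wave-Layer, and note that `--ckpt_dir` is a required argument of the script even outside the sweep modes:

```bash
python -m scOT.inference \
    --mode eval \
    --model_path <MODEL_PATH_OR_HUB_ID> \
    --dataset wave.Layer \
    --data_path <DATA_PATH> \
    --ckpt_dir <CKPT_DIR> \
    --file results.csv \
    --initial_time 0 \
    --final_time 14 \
    --ar_steps 1
```

The computed error statistics are appended as a row to `results.csv`.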
81 | 
82 | ## Datasets
83 | 
84 | We provide all datasets used in the paper on the 🤗 Hub. You can download them from the respective collections:
85 | - [🤗 Hub – Pretraining Datasets](https://huggingface.co/collections/camlab-ethz/poseidon-664fa125729c53d8607e209a)
86 | - [🤗 Hub – Downstream Tasks](https://huggingface.co/collections/camlab-ethz/poseidon-downstream-tasks-664fa237cd6b0c097971ef14)
87 | 
88 | ### Naming convention in the code
89 | 
90 | In the code, we refer to the datasets by a different identifier than on the 🤗 Hub; see the following table for the mapping:
91 | 
92 | | Code Identifier | 🤗 Hub/Paper Identifier |
93 | | ----------------|------------------------- |
94 | |fluids.incompressible.Sines| NS-Sines|
95 | |fluids.incompressible.Gaussians| NS-Gauss|
96 | |fluids.compressible.Riemann|CE-RP|
97 | |fluids.compressible.RiemannCurved|CE-CRP|
98 | |fluids.compressible.KelvinHelmholtz|CE-KH|
99 | |fluids.compressible.Gaussians|CE-Gauss|
100 | |fluids.incompressible.PiecewiseConstants|NS-PwC|
101 | |fluids.incompressible.VortexSheet|NS-SVS|
102 | |fluids.incompressible.BrownianBridge|NS-BB|
103 | |fluids.incompressible.ShearLayer|NS-SL|
104 | |fluids.incompressible.PiecewiseConstants.tracer|NS-Tracer-PwC|
105 | |fluids.incompressible.forcing.KolmogorovFlow|FNS-KF|
106 | |fluids.compressible.RiemannKelvinHelmholtz|CE-RPUI|
107 | |fluids.compressible.RichtmyerMeshkov|CE-RM|
108 | |fluids.compressible.gravity.RayleighTaylor|GCE-RT|
109 | |wave.Layer|Wave-Layer|
110 | |wave.Gaussians|Wave-Gauss|
111 | |reaction_diffusion.AllenCahn|ACE|
112 | |fluids.compressible.steady.Airfoil(.time)|SE-AF|
113 | |elliptic.poisson.Gaussians(.time)|Poisson-Gauss|
114 | |elliptic.Helmholtz(.time)|Helmholtz|
115 | 
116 | Adding the suffix `.time` to a dataset identifier loads it as a time-dependent dataset, i.e. as a long-time limit; use this suffix when finetuning on the time-independent datasets.
117 | 
118 | ### Download & Assembly
119 | 
120 | Download all the datasets used in our paper from the 🤗 Hub. You may want to use the CLI provided by the [Hub Python Library](https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-download):
121 | 
122 | ```bash
123 | huggingface-cli download camlab-ethz/<DATASET> --repo-type dataset --local-dir <LOCAL_DIR>
124 | ```
125 | 
126 | This will download a specific dataset to the specified `<LOCAL_DIR>`. After download, you need to assemble the datasets into the format expected by the code; for that, we refer to the README in the respective dataset repository. After assembly, remove the chunked dataset files, as they are not needed for training, and place the assembled dataset at the path you specify as `--data_path` for the training/inference script. You may also set the 🤗 Hub cache location via the environment variable `HF_HOME`, as this is where the files are downloaded to first.
127 | 
128 | ### Adding your own dataset
129 | 
130 | We encourage adding your own datasets. For that, subclass [BaseDataset or BaseTimeDataset](scOT/problems/base.py) and register the new class in the `get_dataset` selector method. You can then use the dataset in the training script by specifying its identifier in the config file.
131 | 
132 | For subclassing, we refer to the docstrings in the base classes and the existing datasets in the [problems](scOT/problems) folder; a minimal skeleton is sketched below.
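The following skeleton shows the pattern the existing time-independent datasets follow (compare, e.g., `scOT/problems/elliptic/poisson.py`). It is a sketch: the class name, sample counts, and data layout are placeholder assumptions to adapt to your own files:

```python
import torch

from scOT.problems.base import BaseDataset


class MyDataset(BaseDataset):
    """Hypothetical time-independent dataset; all constants are placeholders."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.N_max = 1000  # total number of samples in the data file
        self.N_val = 60  # size of the validation split
        self.N_test = 240  # size of the test split
        self.resolution = 128
        self.input_dim = 1
        self.label_description = "[u]"  # one output channel named u
        # Open your data file here, e.g. self.reader = h5py.File(..., "r").
        self.post_init()  # derives split offsets and channel slices

    def __getitem__(self, idx):
        # Replace with reads from your file; normalize inputs and labels.
        inputs = torch.zeros(self.input_dim, self.resolution, self.resolution)
        labels = torch.zeros(1, self.resolution, self.resolution)
        return {"pixel_values": inputs, "labels": labels}
```

After registering the class in `get_dataset`, you can select it via its identifier in the config file.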
133 | 
134 | ## Pretrained models
135 | 
136 | Pretrained models are available on the 🤗 Hub; see the [Poseidon collection](https://huggingface.co/collections/camlab-ethz/poseidon-664fa125729c53d8607e209a) for all models. You can download them via the 🤗 Hub API or by using the `from_pretrained` method, see above.
137 | 
138 | ## Citation
139 | 
140 | If you use our models, code, or datasets, please consider citing our paper:
141 | 
142 | ```bibtex
143 | @misc{herde2024poseidon,
144 |       title={Poseidon: Efficient Foundation Models for PDEs}, 
145 |       author={Maximilian Herde and Bogdan Raonić and Tobias Rohner and Roger Käppeli and Roberto Molinaro and Emmanuel de Bézenac and Siddhartha Mishra},
146 |       year={2024},
147 |       eprint={2405.19101},
148 |       archivePrefix={arXiv},
149 |       primaryClass={cs.LG}
150 | }
151 | ```
152 | 
--------------------------------------------------------------------------------
/assets/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/camlab-ethz/poseidon/b8fa28f59bd7f7673323f28d11a12c6f3a215c61/assets/fig1.png
--------------------------------------------------------------------------------
/configs/README.md:
--------------------------------------------------------------------------------
1 | # Configuration Files
2 | 
3 | We provide two sample configuration files: one for a single finetuning run and one for a finetuning sweep. Both finetune the Poseidon-B model on Wave-Layer.
--------------------------------------------------------------------------------
/configs/run.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 |   value: "wave.Layer"
3 | num_trajectories:
4 |   value: 128
5 | model_name:
6 |   value: "B"
7 | lr:
8 |   value: 0.00005
9 | lr_embedding_recovery:
10 |   value: 0.0005
11 | lr_time_embedding:
12 |   value: 0.0005
13 | weight_decay:
14 |   value: 0.000001
15 | lr_scheduler:
16 |   value: "cosine"
17 | warmup_ratio:
18 |   value: 0.0
19 | early_stopping_patience:
20 |   value: 200
21 | num_epochs:
22 |   value: 200
23 | batch_size:
24 |   value: 40
25 | max_grad_norm:
26 |   value: 5.0
--------------------------------------------------------------------------------
/configs/sweep.yaml:
--------------------------------------------------------------------------------
1 | project:
2 | entity:
3 | program: scOT/train.py
4 | method: grid
5 | metric:
6 |   name: "eval/loss"
7 |   goal: minimize
8 | command:
9 |   - "HDF5_USE_FILE_LOCKING=FALSE"
10 |   - "accelerate"
11 |   - "launch"
12 |   - ${program}
13 |   - "--disable_tqdm"
14 |   - "--json-config"
15 |   - "--finetune_from"
16 |   - "camlab-ethz/Poseidon-B"
17 |   - "--replace_embedding_recovery"
18 |   - "--config"
19 |   - ${args_json}
20 | parameters:
21 |   dataset:
22 |     value: "wave.Layer"
23 |   num_trajectories:
24 |     values:
25 |       - 1
26 |       - 2
27 |       - 4
28 |       - 8
29 |       - 16
30 |       - 32
31 |       - 64
32 |       - 128
33 |       - 256
34 |       - 512
35 |       - 1024
36 |   model_name:
37 |     value: "B"
38 |   lr:
39 |     value: 0.00005
40 |   lr_embedding_recovery:
41 |     value: 0.0005
42 |   lr_time_embedding:
43 |     value: 0.0005
44 |   weight_decay:
45 |     value: 0.000001
46 |   lr_scheduler:
47 |     value: "cosine"
48 |   warmup_ratio:
49 |     value: 0.0
50 |   early_stopping_patience:
51 |     value: 200
52 |   num_epochs:
53 |     value: 200
54 |   batch_size:
55 |     value: 40
56 |   max_grad_norm:
57 |     value: 5.0
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "scOT"
3 | version = "1.0.0"
4 | description = "Foundation models for PDEs based on a scalable Operator Transformer"
5 | dependencies = [
6 |     "torch == 2.0.1",
7 |     "torchvision == 0.15.2",
8 |     "numpy",
9 | 
"transformers == 4.29.2", 10 | "matplotlib", 11 | "accelerate == 0.31.0", 12 | "wandb == 0.14.2", 13 | "h5py", 14 | "pandas", 15 | "pyyaml", 16 | ] 17 | 18 | [build-system] 19 | build-backend = "flit_core.buildapi" 20 | requires = ["flit_core >=3.2,<4"] 21 | -------------------------------------------------------------------------------- /scOT/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlab-ethz/poseidon/b8fa28f59bd7f7673323f28d11a12c6f3a215c61/scOT/__init__.py -------------------------------------------------------------------------------- /scOT/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Use this script for inference/testing a scOT model. 3 | The script can be used in different modes: 4 | - save_samples: Save samples from a model. 5 | - save_samples_sweep: Save samples from a sweep. 6 | - eval: Evaluate a model on the test set. 7 | - eval_sweep: Evaluate a sweep on the test set. 8 | - eval_accumulation_error: Evaluate the accumulation error of a model. 9 | - eval_resolutions: Evaluate a model on different resolutions. 10 | 11 | See the --help page for more information. 12 | """ 13 | 14 | import argparse 15 | import torch 16 | import numpy as np 17 | import random 18 | import psutil 19 | import os 20 | import pandas as pd 21 | import wandb 22 | from transformers.trainer_utils import EvalPrediction 23 | from scOT.model import ScOT 24 | from scOT.trainer import TrainingArguments, Trainer 25 | from scOT.problems.base import get_dataset, BaseTimeDataset 26 | from scOT.metrics import relative_lp_error, lp_error 27 | 28 | 29 | SEED = 0 30 | torch.manual_seed(SEED) 31 | np.random.seed(SEED) 32 | random.seed(SEED) 33 | 34 | 35 | def get_trainer( 36 | model_path, 37 | batch_size, 38 | dataset, 39 | full_data=False, 40 | output_all_steps=False, 41 | workers=-1, 42 | ): 43 | """ 44 | Get a trainer for the model (actually just using the interface for inference). 45 | 46 | Args: 47 | model_path: str 48 | Path to the model. 49 | batch_size: int 50 | Batch size for evaluation. 51 | dataset: BaseTimeDataset 52 | Test set. 53 | full_data: bool 54 | Whether to save the full data distribution. 55 | output_all_steps: bool 56 | Whether to output all preliminary steps in autoregressive rollout. 57 | workers: int 58 | Number of workers for evaluation. If -1 will use all available cores. 
59 | """ 60 | num_cpu_cores = len(psutil.Process().cpu_affinity()) 61 | if workers == -1: 62 | workers = num_cpu_cores 63 | if workers > num_cpu_cores: 64 | workers = num_cpu_cores 65 | assert workers > 0 66 | 67 | model = ScOT.from_pretrained(model_path) 68 | args = TrainingArguments( 69 | output_dir=".", 70 | per_device_eval_batch_size=batch_size, 71 | eval_accumulation_steps=16, 72 | dataloader_num_workers=workers, 73 | ) 74 | time_involved = isinstance(dataset, BaseTimeDataset) 75 | 76 | def compute_metrics(eval_preds): 77 | if time_involved and output_all_steps: 78 | return {} 79 | channel_list = dataset.channel_slice_list 80 | 81 | def get_relative_statistics(errors): 82 | median_error = np.median(errors, axis=0) 83 | mean_error = np.mean(errors, axis=0) 84 | std_error = np.std(errors, axis=0) 85 | min_error = np.min(errors, axis=0) 86 | max_error = np.max(errors, axis=0) 87 | return { 88 | "median_relative_l1_error": median_error, 89 | "mean_relative_l1_error": mean_error, 90 | "std_relative_l1_error": std_error, 91 | "min_relative_l1_error": min_error, 92 | "max_relative_l1_error": max_error, 93 | } 94 | 95 | def get_statistics(errors): 96 | median_error = np.median(errors, axis=0) 97 | mean_error = np.mean(errors, axis=0) 98 | std_error = np.std(errors, axis=0) 99 | min_error = np.min(errors, axis=0) 100 | max_error = np.max(errors, axis=0) 101 | return { 102 | "median_l1_error": median_error, 103 | "mean_l1_error": mean_error, 104 | "std_l1_error": std_error, 105 | "min_l1_error": min_error, 106 | "max_l1_error": max_error, 107 | } 108 | 109 | relative_errors = [ 110 | relative_lp_error( 111 | eval_preds.predictions[:, channel_list[i] : channel_list[i + 1]], 112 | eval_preds.label_ids[:, channel_list[i] : channel_list[i + 1]], 113 | p=1, 114 | return_percent=True, 115 | ) 116 | for i in range(len(channel_list) - 1) 117 | ] 118 | 119 | errors = [ 120 | lp_error( 121 | eval_preds.predictions[:, channel_list[i] : channel_list[i + 1]], 122 | eval_preds.label_ids[:, channel_list[i] : channel_list[i + 1]], 123 | p=1, 124 | ) 125 | for i in range(len(channel_list) - 1) 126 | ] 127 | 128 | relative_error_statistics = [ 129 | get_relative_statistics(relative_errors[i]) 130 | for i in range(len(channel_list) - 1) 131 | ] 132 | 133 | error_statistics = [ 134 | get_statistics(errors[i]) for i in range(len(channel_list) - 1) 135 | ] 136 | 137 | if dataset.output_dim == 1: 138 | relative_error_statistics = relative_error_statistics[0] 139 | error_statistics = error_statistics[0] 140 | if full_data: 141 | relative_error_statistics["relative_full_data"] = relative_errors[ 142 | 0 143 | ].tolist() 144 | error_statistics["full_data"] = errors[0].tolist() 145 | return {**relative_error_statistics, **error_statistics} 146 | else: 147 | mean_over_relative_means = np.mean( 148 | np.array( 149 | [ 150 | stats["mean_relative_l1_error"] 151 | for stats in relative_error_statistics 152 | ] 153 | ), 154 | axis=0, 155 | ) 156 | mean_over_relative_medians = np.mean( 157 | np.array( 158 | [ 159 | stats["median_relative_l1_error"] 160 | for stats in relative_error_statistics 161 | ] 162 | ), 163 | axis=0, 164 | ) 165 | mean_over_means = np.mean( 166 | np.array([stats["mean_l1_error"] for stats in error_statistics]), axis=0 167 | ) 168 | mean_over_medians = np.mean( 169 | np.array([stats["median_l1_error"] for stats in error_statistics]), 170 | axis=0, 171 | ) 172 | 173 | error_statistics_ = { 174 | "mean_relative_l1_error": mean_over_relative_means, 175 | "mean_over_median_relative_l1_error": 
mean_over_relative_medians, 176 | "mean_l1_error": mean_over_means, 177 | "mean_over_median_l1_error": mean_over_medians, 178 | } 179 | #!! The above is different from train and finetune (here mean_relative_l1_error is mean over medians instead of mean over means) 180 | for i, stats in enumerate(relative_error_statistics): 181 | for key, value in stats.items(): 182 | error_statistics_[ 183 | dataset.printable_channel_description[i] + "/" + key 184 | ] = value 185 | if full_data: 186 | error_statistics_[ 187 | dataset.printable_channel_description[i] 188 | + "/" 189 | + "relative_full_data" 190 | ] = relative_errors[i].tolist() 191 | for i, stats in enumerate(error_statistics): 192 | for key, value in stats.items(): 193 | error_statistics_[ 194 | dataset.printable_channel_description[i] + "/" + key 195 | ] = value 196 | if full_data: 197 | error_statistics_[ 198 | dataset.printable_channel_description[i] + "/" + "full_data" 199 | ] = errors[i].tolist() 200 | return error_statistics_ 201 | 202 | trainer = Trainer( 203 | model=model, 204 | args=args, 205 | compute_metrics=compute_metrics, 206 | ) 207 | return trainer 208 | 209 | 210 | def rollout(trainer, dataset, ar_steps=1, output_all_steps=False): 211 | """ 212 | Do a rollout of the model. 213 | 214 | Args: 215 | trainer: Trainer 216 | Trainer for the model. 217 | dataset: BaseTimeDataset 218 | Test set. 219 | ar_steps: int or list 220 | Number of autoregressive steps to take. A single int n is interpreted as taking n homogeneous steps, a list of ints [j_0, j_1, ...] is interpreted as taking a step of size j_i. 221 | output_all_steps: bool 222 | Whether to output all preliminary steps in autoregressive rollout. 223 | """ 224 | time_involved = isinstance(dataset, BaseTimeDataset) 225 | if time_involved and ar_steps != 1: 226 | trainer.set_ar_steps(ar_steps, output_all_steps=output_all_steps) 227 | else: 228 | trainer.set_ar_steps(ar_steps=1, output_all_steps=False) 229 | 230 | prediction = trainer.predict(dataset, metric_key_prefix="") 231 | 232 | try: 233 | return prediction.predictions, prediction.label_ids, prediction.metrics 234 | except: 235 | return prediction.predictions 236 | 237 | 238 | def get_test_set( 239 | dataset, data_path, initial_time=None, final_time=None, dataset_kwargs={} 240 | ): 241 | """ 242 | Get a test set (input at initial_time, output at final_time). 243 | 244 | Args: 245 | dataset: str 246 | Dataset name. 247 | data_path: str 248 | Path to data. 249 | initial_time: int 250 | Initial time step to start from. 251 | final_time: int 252 | Final time step to end at. 253 | dataset_kwargs: dict 254 | Additional arguments for dataset as in scOT.problems.base.get_dataset. 255 | """ 256 | if initial_time is not None and final_time is not None: 257 | dataset_kwargs = { 258 | **dataset_kwargs, 259 | "fix_input_to_time_step": initial_time, 260 | "time_step_size": final_time - initial_time, 261 | "max_num_time_steps": 1, 262 | } 263 | dataset = get_dataset( 264 | dataset=dataset, 265 | which="test", 266 | num_trajectories=1, 267 | data_path=data_path, 268 | move_to_local_scratch=None, 269 | **dataset_kwargs, 270 | ) 271 | return dataset 272 | 273 | 274 | def get_first_n_inputs(dataset, n): 275 | """ 276 | Helper to get the first n inputs of a dataset. 
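Returns a tensor of shape (n, C, H, W), stacking each sample's "pixel_values" entry.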
277 | """ 278 | inputs = [] 279 | for i in range(n): 280 | inputs.append(dataset[i]["pixel_values"]) 281 | return torch.stack(inputs) 282 | 283 | 284 | def get_trajectories( 285 | dataset, data_path, ar_steps, initial_time, final_time, dataset_kwargs 286 | ): 287 | """ 288 | Get full trajectories in a dataset. Helper for accumulation error evaluation. 289 | 290 | Args: 291 | dataset: str 292 | Dataset name. 293 | data_path: str 294 | Path to data. 295 | ar_steps: int or list 296 | Number of autoregressive steps to take. A single int n is interpreted as taking n homogeneous steps, a list of ints [j_0, j_1, ...] is interpreted as taking a step of size j_i. 297 | initial_time: int 298 | Initial time step to start from. 299 | final_time: int 300 | Final time step to end at. 301 | dataset_kwargs: dict 302 | Additional arguments for dataset as in scOT.problems.base.get_dataset. 303 | """ 304 | trajectories = [] 305 | if isinstance(ar_steps, int): 306 | delta = (final_time - initial_time) // ar_steps 307 | for i in range(ar_steps): 308 | dataset_ = get_test_set( 309 | dataset, 310 | data_path, 311 | initial_time + i * delta, 312 | initial_time + (i + 1) * delta, 313 | dataset_kwargs, 314 | ) 315 | traj_ = [] 316 | for j in range(len(dataset_)): 317 | traj_.append(dataset_[j]["labels"]) 318 | trajectories.append(torch.stack(traj_)) 319 | else: 320 | running_time = initial_time 321 | for i in ar_steps: 322 | dataset_ = get_test_set( 323 | dataset, data_path, running_time, running_time + i, dataset_kwargs 324 | ) 325 | running_time += i 326 | traj_ = [] 327 | for j in range(len(dataset_)): 328 | traj_.append(dataset_[j]["labels"]) 329 | trajectories.append(torch.stack(traj_)) 330 | return torch.stack(trajectories, dim=1) 331 | 332 | 333 | def remove_underscore_dict(d): 334 | return {key[1:] if key.startswith("_") else key: value for key, value in d.items()} 335 | 336 | 337 | if __name__ == "__main__": 338 | parser = argparse.ArgumentParser( 339 | description="Do different evaluations for a model, see --mode." 340 | ) 341 | parser.add_argument( 342 | "--model_path", 343 | type=str, 344 | required=False, 345 | help="Model path. Not required when mode==eval_sweep or save_samples_sweep.", 346 | ) 347 | parser.add_argument( 348 | "--file", 349 | type=str, 350 | required=True, 351 | help="File to load/write to. May also be a directory to save samples.", 352 | ) 353 | parser.add_argument( 354 | "--data_path", 355 | type=str, 356 | required=True, 357 | help="Path to data.", 358 | ) 359 | parser.add_argument( 360 | "--dataset", 361 | type=str, 362 | help="Which test set to load. Not required if mode==eval_sweep or save_samples_sweep.", 363 | ) 364 | parser.add_argument( 365 | "--batch_size", 366 | type=int, 367 | default=64, 368 | help="Batch size for evaluation.", 369 | ) 370 | parser.add_argument( 371 | "--full_data", 372 | action="store_true", 373 | help="Whether to save full data distributions.", 374 | ) 375 | parser.add_argument( 376 | "--initial_time", 377 | type=int, 378 | default=None, 379 | help="Initial time step to start from.", 380 | ) 381 | parser.add_argument( 382 | "--final_time", 383 | type=int, 384 | default=None, 385 | help="Final time step to end at.", 386 | ) 387 | parser.add_argument( 388 | "--ar_steps", 389 | type=int, 390 | nargs="+", 391 | default=[1], 392 | help="Number of autoregressive steps to take. A single int n is interpreted as taking n homogeneous steps, a list of ints [j_0, j_1, ...] 
is interpreted as taking a step of size j_i.", 393 | ) 394 | parser.add_argument( 395 | "--mode", 396 | type=str, 397 | choices=[ 398 | "save_samples", 399 | "save_samples_sweep", 400 | "eval", 401 | "eval_sweep", 402 | "eval_accumulation_error", 403 | "eval_resolutions", 404 | ], 405 | default="eval", 406 | help="Mode to run. Can be either save_samples to save n samples, save_samples_sweep, eval (to evaluate a single model), eval_sweep (to evaluate all models in a wandb sweep), eval_accumulation_error (to evaluate a model's accumulation error), eval_resolutions (to evaluate a model on different resolutions).", 407 | ) 408 | parser.add_argument( 409 | "--save_n_samples", 410 | type=int, 411 | default=1, 412 | help="Number of samples to save. Only required for mode==save_samples or save_samples_sweep.", 413 | ) 414 | parser.add_argument( 415 | "--resolutions", 416 | type=int, 417 | nargs="+", 418 | help="List of resolutions to evaluate. Only required for mode==eval_resolutions.", 419 | ) 420 | parser.add_argument( 421 | "--wandb_project", 422 | type=str, 423 | default="scOT", 424 | help="Wandb project name. Required if mode==eval_sweep or save_samples_sweep.", 425 | ) 426 | parser.add_argument( 427 | "--wandb_entity", 428 | type=str, 429 | required=False, 430 | help="Wandb entity name. Required if mode==eval_sweep or save_samples_sweep.", 431 | ) 432 | parser.add_argument( 433 | "--wandb_sweep_id", 434 | type=str, 435 | default=None, 436 | help="Wandb sweep id. Required if mode==eval_sweep or save_samples_sweep.", 437 | ) 438 | parser.add_argument( 439 | "--ckpt_dir", 440 | type=str, 441 | required=True, 442 | help="Base checkpoint directory. Required if mode==eval_sweep or save_samples_sweep.", 443 | ) 444 | parser.add_argument( 445 | "--exclude_dataset", 446 | type=str, 447 | nargs="+", 448 | default=[], 449 | help="Datasets to exclude from evaluation. Only relevant when mode==eval_sweep or save_samples_sweep.", 450 | ) 451 | parser.add_argument( 452 | "--exclusively_evaluate_dataset", 453 | type=str, 454 | nargs="+", 455 | default=[], 456 | help="Datasets to exclusively evaluate. Only relevant when mode==eval_sweep or save_samples_sweep.", 457 | ) 458 | parser.add_argument( 459 | "--just_velocities", 460 | action="store_true", 461 | help="Use just velocities in incompressible flow data.", 462 | ) 463 | parser.add_argument( 464 | "--allow_failed", 465 | action="store_true", 466 | help="Allow failed runs to be taken into account with eval_sweep.", 467 | ) 468 | parser.add_argument( 469 | "--append_time", 470 | action="store_true", 471 | help="Append .time to dataset name for evaluation.", 472 | ) 473 | parser.add_argument( 474 | "--num_trajectories", 475 | type=int, 476 | default=128, 477 | help="Filter runs for number of training trajectories. 
Only relevant if mode==eval_sweep or save_samples_sweep.", 478 | ) 479 | params = parser.parse_args() 480 | if len(params.ar_steps) == 1: 481 | params.ar_steps = params.ar_steps[0] 482 | ar_steps = params.ar_steps 483 | else: 484 | ar_steps = params.ar_steps 485 | params.ar_steps = [ 486 | step / (params.final_time - params.initial_time) for step in params.ar_steps 487 | ] 488 | dataset_kwargs = {} 489 | if params.just_velocities: 490 | dataset_kwargs["just_velocities"] = True 491 | if params.mode == "save_samples": 492 | dataset = get_test_set( 493 | params.dataset, 494 | params.data_path, 495 | params.initial_time, 496 | params.final_time, 497 | dataset_kwargs, 498 | ) 499 | trainer = get_trainer(params.model_path, params.batch_size, dataset) 500 | inputs = get_first_n_inputs(dataset, params.save_n_samples) 501 | outputs, labels, _ = rollout(trainer, dataset, ar_steps=params.ar_steps) 502 | np.save( 503 | params.file + "/" + params.dataset.replace(".", "-") + "/" + "inputs.npy", 504 | inputs.cpu().numpy(), 505 | ) 506 | np.save( 507 | params.file + "/" + params.dataset.replace(".", "-") + "/" + "labels.npy", 508 | labels[: params.save_n_samples], 509 | ) 510 | np.save( 511 | params.file + "/" + params.dataset.replace(".", "-") + "/" + "outputs.npy", 512 | outputs[: params.save_n_samples], 513 | ) 514 | elif params.mode == "save_samples_sweep": 515 | api = wandb.Api() 516 | sweep = api.sweep( 517 | params.wandb_entity 518 | + "/" 519 | + params.wandb_project 520 | + "/" 521 | + params.wandb_sweep_id 522 | ) 523 | for run in sweep.runs: 524 | if run.state == "finished" or ( 525 | params.allow_failed and run.state == "failed" 526 | ): 527 | dset_name = run.config["dataset"] 528 | if run.config["num_trajectories"] != params.num_trajectories: 529 | continue 530 | if dset_name in params.exclude_dataset: 531 | continue 532 | if ( 533 | len(params.exclusively_evaluate_dataset) > 0 534 | and dset_name not in params.exclusively_evaluate_dataset 535 | ): 536 | continue 537 | num_trajectories = run.config["num_trajectories"] 538 | ckpt_dir = ( 539 | params.ckpt_dir 540 | + "/" 541 | + params.wandb_project 542 | + "/" 543 | + params.wandb_sweep_id 544 | + "/" 545 | + run.name 546 | ) 547 | items = os.listdir(ckpt_dir) 548 | dirs = [ 549 | item 550 | for item in items 551 | if os.path.isdir(os.path.join(ckpt_dir, item)) 552 | ] 553 | if len(dirs) > 1: 554 | print( 555 | "WARNING: more than one checkpoint in run directory " + ckpt_dir 556 | ) 557 | print("choosing " + dirs[0]) 558 | model_path = os.path.join(ckpt_dir, dirs[0]) 559 | dataset = get_test_set( 560 | dset_name, 561 | params.data_path, 562 | params.initial_time, 563 | params.final_time, 564 | dataset_kwargs, 565 | ) 566 | trainer = get_trainer(model_path, params.batch_size, dataset) 567 | inputs = get_first_n_inputs(dataset, params.save_n_samples) 568 | outputs, labels, _ = rollout(trainer, dataset, ar_steps=params.ar_steps) 569 | if not os.path.exists(params.file + "/" + dset_name.replace(".", "-")): 570 | os.makedirs(params.file + "/" + dset_name.replace(".", "-")) 571 | if not os.path.exists( 572 | params.file 573 | + "/" 574 | + dset_name.replace(".", "-") 575 | + "/" 576 | + str(num_trajectories) 577 | ): 578 | os.makedirs( 579 | params.file 580 | + "/" 581 | + dset_name.replace(".", "-") 582 | + "/" 583 | + str(num_trajectories) 584 | ) 585 | np.save( 586 | params.file 587 | + "/" 588 | + dset_name.replace(".", "-") 589 | + "/" 590 | + str(num_trajectories) 591 | + "/inputs.npy", 592 | inputs.cpu().numpy(), 593 | ) 594 | np.save( 595 | 
params.file 596 | + "/" 597 | + dset_name.replace(".", "-") 598 | + "/" 599 | + str(num_trajectories) 600 | + "/labels.npy", 601 | labels[: params.save_n_samples], 602 | ) 603 | np.save( 604 | params.file 605 | + "/" 606 | + dset_name.replace(".", "-") 607 | + "/" 608 | + str(num_trajectories) 609 | + "/" 610 | + "outputs.npy", 611 | outputs[: params.save_n_samples], 612 | ) 613 | else: 614 | if params.mode == "eval": 615 | dataset = get_test_set( 616 | params.dataset, 617 | params.data_path, 618 | params.initial_time, 619 | params.final_time, 620 | dataset_kwargs, 621 | ) 622 | trainer = get_trainer( 623 | params.model_path, 624 | params.batch_size, 625 | dataset, 626 | full_data=params.full_data, 627 | ) 628 | _, _, metrics = rollout( 629 | trainer, 630 | dataset, 631 | ar_steps=params.ar_steps, 632 | output_all_steps=False, 633 | ) 634 | data = { 635 | "dataset": params.dataset, 636 | "initial_time": params.initial_time, 637 | "final_time": params.final_time, 638 | "ar_steps": ar_steps, 639 | **metrics, 640 | } 641 | data = [remove_underscore_dict(data)] 642 | elif params.mode == "eval_sweep": 643 | api = wandb.Api() 644 | sweep = api.sweep( 645 | params.wandb_entity 646 | + "/" 647 | + params.wandb_project 648 | + "/" 649 | + params.wandb_sweep_id 650 | ) 651 | data = [] 652 | for run in sweep.runs: 653 | if run.state == "finished" or ( 654 | params.allow_failed and run.state == "failed" 655 | ): 656 | dset_name = ( 657 | run.config["dataset"] 658 | if not params.append_time 659 | else run.config["dataset"] + ".time" 660 | ) 661 | if dset_name in params.exclude_dataset: 662 | continue 663 | if ( 664 | len(params.exclusively_evaluate_dataset) > 0 665 | and dset_name not in params.exclusively_evaluate_dataset 666 | ): 667 | continue 668 | num_trajectories = run.config["num_trajectories"] 669 | ckpt_dir = ( 670 | params.ckpt_dir 671 | + "/" 672 | + params.wandb_project 673 | + "/" 674 | + params.wandb_sweep_id 675 | + "/" 676 | + run.name 677 | ) 678 | items = os.listdir(ckpt_dir) 679 | dirs = [ 680 | item 681 | for item in items 682 | if os.path.isdir(os.path.join(ckpt_dir, item)) 683 | ] 684 | if len(dirs) > 1: 685 | print( 686 | "WARNING: more than one checkpoint in run directory " 687 | + ckpt_dir 688 | ) 689 | print("choosing " + dirs[0]) 690 | continue 691 | if len(dirs) == 0: 692 | continue 693 | model_path = os.path.join(ckpt_dir, dirs[0]) 694 | dataset = get_test_set( 695 | dset_name, 696 | params.data_path, 697 | params.initial_time, 698 | params.final_time, 699 | dataset_kwargs, 700 | ) 701 | trainer = get_trainer( 702 | model_path, 703 | params.batch_size, 704 | dataset, 705 | full_data=params.full_data, 706 | ) 707 | _, _, metrics = rollout( 708 | trainer, 709 | dataset, 710 | ar_steps=params.ar_steps, 711 | output_all_steps=False, 712 | ) 713 | data.append( 714 | remove_underscore_dict( 715 | { 716 | "dataset": dset_name, 717 | "num_trajectories": num_trajectories, 718 | "initial_time": params.initial_time, 719 | "final_time": params.final_time, 720 | "ar_steps": ar_steps, 721 | **metrics, 722 | } 723 | ) 724 | ) 725 | elif params.mode == "eval_accumulation_error": 726 | dataset = get_test_set( 727 | params.dataset, 728 | params.data_path, 729 | params.initial_time, 730 | params.final_time, 731 | dataset_kwargs, 732 | ) 733 | trainer = get_trainer( 734 | params.model_path, 735 | params.batch_size, 736 | dataset, 737 | output_all_steps=True, 738 | full_data=params.full_data, 739 | ) 740 | predictions, _, _ = rollout( 741 | trainer, 742 | dataset, 743 | 
ar_steps=params.ar_steps, 744 | output_all_steps=True, 745 | ) 746 | labels = get_trajectories( 747 | params.dataset, 748 | params.data_path, 749 | params.ar_steps, 750 | params.initial_time, 751 | params.final_time, 752 | dataset_kwargs, 753 | ) 754 | 755 | def compute_metrics(eval_preds): 756 | channel_list = dataset.channel_slice_list 757 | 758 | def get_relative_statistics(errors): 759 | median_error = np.median(errors, axis=0) 760 | mean_error = np.mean(errors, axis=0) 761 | std_error = np.std(errors, axis=0) 762 | min_error = np.min(errors, axis=0) 763 | max_error = np.max(errors, axis=0) 764 | return { 765 | "median_relative_l1_error": median_error, 766 | "mean_relative_l1_error": mean_error, 767 | "std_relative_l1_error": std_error, 768 | "min_relative_l1_error": min_error, 769 | "max_relative_l1_error": max_error, 770 | } 771 | 772 | def get_statistics(errors): 773 | median_error = np.median(errors, axis=0) 774 | mean_error = np.mean(errors, axis=0) 775 | std_error = np.std(errors, axis=0) 776 | min_error = np.min(errors, axis=0) 777 | max_error = np.max(errors, axis=0) 778 | return { 779 | "median_l1_error": median_error, 780 | "mean_l1_error": mean_error, 781 | "std_l1_error": std_error, 782 | "min_l1_error": min_error, 783 | "max_l1_error": max_error, 784 | } 785 | 786 | relative_errors = [ 787 | relative_lp_error( 788 | eval_preds.predictions[ 789 | :, channel_list[i] : channel_list[i + 1] 790 | ], 791 | eval_preds.label_ids[:, channel_list[i] : channel_list[i + 1]], 792 | p=1, 793 | return_percent=True, 794 | ) 795 | for i in range(len(channel_list) - 1) 796 | ] 797 | 798 | errors = [ 799 | lp_error( 800 | eval_preds.predictions[ 801 | :, channel_list[i] : channel_list[i + 1] 802 | ], 803 | eval_preds.label_ids[:, channel_list[i] : channel_list[i + 1]], 804 | p=1, 805 | ) 806 | for i in range(len(channel_list) - 1) 807 | ] 808 | 809 | relative_error_statistics = [ 810 | get_relative_statistics(relative_errors[i]) 811 | for i in range(len(channel_list) - 1) 812 | ] 813 | 814 | error_statistics = [ 815 | get_statistics(errors[i]) for i in range(len(channel_list) - 1) 816 | ] 817 | 818 | if dataset.output_dim == 1: 819 | relative_error_statistics = relative_error_statistics[0] 820 | error_statistics = error_statistics[0] 821 | if params.full_data: 822 | relative_error_statistics["relative_full_data"] = ( 823 | relative_errors[0].tolist() 824 | ) 825 | error_statistics["full_data"] = errors[0].tolist() 826 | return {**relative_error_statistics, **error_statistics} 827 | else: 828 | mean_over_relative_means = np.mean( 829 | np.array( 830 | [ 831 | stats["mean_relative_l1_error"] 832 | for stats in relative_error_statistics 833 | ] 834 | ), 835 | axis=0, 836 | ) 837 | mean_over_relative_medians = np.mean( 838 | np.array( 839 | [ 840 | stats["median_relative_l1_error"] 841 | for stats in relative_error_statistics 842 | ] 843 | ), 844 | axis=0, 845 | ) 846 | mean_over_means = np.mean( 847 | np.array( 848 | [stats["mean_l1_error"] for stats in error_statistics] 849 | ), 850 | axis=0, 851 | ) 852 | mean_over_medians = np.mean( 853 | np.array( 854 | [stats["median_l1_error"] for stats in error_statistics] 855 | ), 856 | axis=0, 857 | ) 858 | 859 | error_statistics_ = { 860 | "mean_relative_l1_error": mean_over_relative_means, 861 | "mean_over_median_relative_l1_error": mean_over_relative_medians, 862 | "mean_l1_error": mean_over_means, 863 | "mean_over_median_l1_error": mean_over_medians, 864 | } 865 | #!! 
The above is different from train and finetune (here mean_relative_l1_error is mean over medians instead of mean over means) 866 | for i, stats in enumerate(relative_error_statistics): 867 | for key, value in stats.items(): 868 | error_statistics_[ 869 | dataset.printable_channel_description[i] + "/" + key 870 | ] = value 871 | if params.full_data: 872 | error_statistics_[ 873 | dataset.printable_channel_description[i] 874 | + "/" 875 | + "relative_full_data" 876 | ] = relative_errors[i].tolist() 877 | for i, stats in enumerate(error_statistics): 878 | for key, value in stats.items(): 879 | error_statistics_[ 880 | dataset.printable_channel_description[i] + "/" + key 881 | ] = value 882 | if params.full_data: 883 | error_statistics_[ 884 | dataset.printable_channel_description[i] 885 | + "/" 886 | + "full_data" 887 | ] = errors[i].tolist() 888 | return error_statistics_ 889 | 890 | data = [] 891 | for step in range(predictions.shape[1]): 892 | metrics = compute_metrics( 893 | EvalPrediction(predictions[:, step], labels[:, step].cpu().numpy()) 894 | ) 895 | if isinstance(params.ar_steps, int): 896 | delta = (params.final_time - params.initial_time) // params.ar_steps 897 | else: 898 | delta = params.ar_steps[step] 899 | data.append( 900 | remove_underscore_dict( 901 | { 902 | "dataset": params.dataset, 903 | "initial_time": params.initial_time + step * delta, 904 | "final_time": params.initial_time + (step + 1) * delta, 905 | **metrics, 906 | } 907 | ) 908 | ) 909 | elif params.mode == "eval_resolutions": 910 | data = [] 911 | for resolution in params.resolutions: 912 | dataset_kwargs = {"resolution": resolution} 913 | dataset = get_test_set( 914 | params.dataset, 915 | params.data_path, 916 | params.initial_time, 917 | params.final_time, 918 | dataset_kwargs, 919 | ) 920 | trainer = get_trainer( 921 | params.model_path, 922 | params.batch_size, 923 | dataset, 924 | full_data=params.full_data, 925 | ) 926 | _, _, metrics = rollout( 927 | trainer, 928 | dataset, 929 | ar_steps=params.ar_steps, 930 | output_all_steps=False, 931 | ) 932 | data.append( 933 | remove_underscore_dict( 934 | { 935 | "dataset": params.dataset, 936 | "initial_time": params.initial_time, 937 | "final_time": params.final_time, 938 | "ar_steps": ar_steps, 939 | "resolution": resolution, 940 | **metrics, 941 | } 942 | ) 943 | ) 944 | 945 | if os.path.exists(params.file): 946 | df = pd.read_csv(params.file) 947 | else: 948 | df = pd.DataFrame() 949 | df = pd.concat([df, pd.DataFrame(data)], ignore_index=True) 950 | df.to_csv(params.file, index=False) 951 | -------------------------------------------------------------------------------- /scOT/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def lp_error(preds: np.ndarray, targets: np.ndarray, p=1): 5 | num_samples, num_channels, _, _ = preds.shape 6 | preds = preds.reshape(num_samples, num_channels, -1) 7 | targets = targets.reshape(num_samples, num_channels, -1) 8 | errors = np.sum(np.abs(preds - targets) ** p, axis=-1) 9 | return np.sum(errors, axis=-1) ** (1 / p) 10 | 11 | 12 | def relative_lp_error( 13 | preds: np.ndarray, 14 | targets: np.ndarray, 15 | p=1, 16 | return_percent=True, 17 | ): 18 | num_samples, num_channels, _, _ = preds.shape 19 | preds = preds.reshape(num_samples, num_channels, -1) 20 | targets = targets.reshape(num_samples, num_channels, -1) 21 | errors = np.sum(np.abs(preds - targets) ** p, axis=-1) 22 | normalization_factor = np.sum(np.abs(targets) ** p, axis=-1) 23 | 24 
| # catch 0 division 25 | normalization_factor = np.sum(normalization_factor, axis=-1) 26 | normalization_factor = np.where( 27 | normalization_factor == 0, 1e-10, normalization_factor 28 | ) 29 | 30 | errors = (np.sum(errors, axis=-1) / normalization_factor) ** (1 / p) 31 | 32 | if return_percent: 33 | errors *= 100 34 | 35 | return errors 36 | 37 | 38 | def mean_relative_lp_error( 39 | preds: np.ndarray, 40 | targets: np.ndarray, 41 | p=1, 42 | return_percent=True, 43 | ): 44 | errors = relative_lp_error(preds, targets, p, return_percent) 45 | return np.mean(errors, axis=0) 46 | 47 | 48 | def median_relative_lp_error( 49 | preds: np.ndarray, 50 | targets: np.ndarray, 51 | p=1, 52 | return_percent=True, 53 | ): 54 | errors = relative_lp_error(preds, targets, p, return_percent) 55 | return np.median(errors, axis=0) 56 | -------------------------------------------------------------------------------- /scOT/problems/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlab-ethz/poseidon/b8fa28f59bd7f7673323f28d11a12c6f3a215c61/scOT/problems/__init__.py -------------------------------------------------------------------------------- /scOT/problems/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the dataset selector get_dataset, as well as the base 3 | classes for all datasets. 4 | """ 5 | 6 | from torch.utils.data import Dataset, ConcatDataset 7 | from typing import Optional, List, Dict 8 | from abc import ABC 9 | import re 10 | import os 11 | import shutil 12 | from accelerate.utils import broadcast_object_list 13 | 14 | 15 | def get_dataset(dataset, **kwargs): 16 | """ 17 | Get a dataset by name. 18 | If you enter a list of str, will return a ConcatDataset of the datasets. 19 | 20 | Available choices are: 21 | - fluids.incompressible.BrownianBridge(.tracer) 22 | - fluids.incompressible.Gaussians(.tracer) 23 | - fluids.incompressible.ShearLayer 24 | - fluids.incompressible.Sines(.tracer) 25 | - fluids.incompressible.PiecewiseConstants(.tracer) 26 | - fluids.incompressible.VortexSheet(.tracer) 27 | - fluids.incompressible.forcing.KolmogorovFlow 28 | - fluids.compressible.gravity.RayleighTaylor(.tracer) 29 | - fluids.compressible.RiemannKelvinHelmholtz 30 | - fluids.compressible.RiemannCurved 31 | - fluids.compressible.Riemann 32 | - fluids.compressible.KelvinHelmholtz 33 | - fluids.compressible.Gaussians 34 | - fluids.compressible.RichtmyerMeshkov(.tracer) 35 | - fluids.compressible.steady.Airfoil(.time) 36 | - elliptic.poisson.Gaussians(.time) 37 | - elliptic.Helmholtz(.time) 38 | - wave.Layer 39 | - wave.Gaussians 40 | - reaction_diffusion.AllenCahn 41 | 42 | Adding .out at the end of the str, returns a dataset with more time steps. 43 | **kwargs overwrite the default settings. 44 | .time is a time-wrapped time-independent dataset. 
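Example (hypothetical local data path):
    train_set = get_dataset(
        "fluids.incompressible.Sines",
        which="train",
        num_trajectories=128,
        data_path="./data",
    )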
45 | """ 46 | if isinstance(dataset, list): 47 | return ConcatDataset([get_dataset(d, **kwargs) for d in dataset]) 48 | if "fluids" in dataset: 49 | if "fluids.incompressible" in dataset: 50 | if "BrownianBridge" in dataset: 51 | from .fluids.incompressible import BrownianBridge as dset 52 | elif "Gaussians" in dataset: 53 | from .fluids.incompressible import Gaussians as dset 54 | elif "ShearLayer" in dataset: 55 | from .fluids.incompressible import ShearLayer as dset 56 | elif "Sines" in dataset: 57 | from .fluids.incompressible import Sines as dset 58 | elif "PiecewiseConstants" in dataset: 59 | from .fluids.incompressible import PiecewiseConstants as dset 60 | elif "VortexSheet" in dataset: 61 | from .fluids.incompressible import VortexSheet as dset 62 | elif "forcing" in dataset: 63 | if "KolmogorovFlow" in dataset: 64 | from .fluids.incompressible import KolmogorovFlow as dset 65 | else: 66 | raise ValueError(f"Unknown dataset {dataset}") 67 | else: 68 | raise ValueError(f"Unknown dataset {dataset}") 69 | elif "fluids.compressible" in dataset: 70 | if "gravity" in dataset: 71 | if "RayleighTaylor" in dataset: 72 | from .fluids.compressible import RayleighTaylor as dset 73 | 74 | if "out" in dataset: 75 | default_time_settings = { 76 | "max_num_time_steps": 10, 77 | "time_step_size": 1, 78 | } 79 | else: 80 | default_time_settings = { 81 | "max_num_time_steps": 7, 82 | "time_step_size": 1, 83 | } 84 | kwargs = {**default_time_settings, **kwargs} 85 | elif "Blast" in dataset: 86 | from .fluids.compressible import Blast as dset 87 | elif "RiemannKelvinHelmholtz" in dataset: 88 | from .fluids.compressible import RiemannKelvinHelmholtz as dset 89 | elif "RiemannCurved" in dataset: 90 | from .fluids.compressible import RiemannCurved as dset 91 | elif "Riemann" in dataset: 92 | from .fluids.compressible import Riemann as dset 93 | elif "KelvinHelmholtz" in dataset: 94 | from .fluids.compressible import KelvinHelmholtz as dset 95 | elif "Gaussians" in dataset: 96 | from .fluids.compressible import Gaussians as dset 97 | elif "RichtmyerMeshkov" in dataset: 98 | from .fluids.compressible import RichtmyerMeshkov as dset 99 | elif "steady" in dataset: 100 | if "steady.Airfoil" in dataset: 101 | from .fluids.compressible import Airfoil as dset 102 | 103 | if "out" in dataset: 104 | raise ValueError(f"Unknown dataset {dataset}") 105 | else: 106 | raise ValueError(f"Unknown dataset {dataset}") 107 | else: 108 | raise ValueError(f"Unknown dataset {dataset}") 109 | else: 110 | raise ValueError(f"Unknown dataset {dataset}") 111 | if "out" in dataset: 112 | default_time_settings = {"max_num_time_steps": 10, "time_step_size": 2} 113 | else: 114 | default_time_settings = {"max_num_time_steps": 7, "time_step_size": 2} 115 | if "tracer" in dataset: 116 | tracer = True 117 | else: 118 | tracer = False 119 | if not "steady" in dataset: 120 | kwargs = {"tracer": tracer, **default_time_settings, **kwargs} 121 | elif "elliptic" in dataset: 122 | if ".out" in dataset: 123 | raise NotImplementedError(f"Unknown dataset {dataset}") 124 | if "elliptic.poisson" in dataset: 125 | if "Gaussians" in dataset: 126 | from .elliptic.poisson import Gaussians as dset 127 | else: 128 | raise ValueError(f"Unknown dataset {dataset}") 129 | elif "elliptic.Helmholtz" in dataset: 130 | from .elliptic.helmholtz import Helmholtz as dset 131 | else: 132 | raise ValueError(f"Unknown dataset {dataset}") 133 | elif "wave" in dataset: 134 | if "wave.Layer" in dataset: 135 | if "out" in dataset: 136 | default_time_settings = 
{"max_num_time_steps": 10, "time_step_size": 2} 137 | else: 138 | default_time_settings = {"max_num_time_steps": 7, "time_step_size": 2} 139 | kwargs = {**default_time_settings, **kwargs} 140 | from .wave.acoustic import Layer as dset 141 | elif "wave.Gaussians" in dataset: 142 | if "out" in dataset: 143 | raise ValueError(f"Unknown dataset {dataset}") 144 | else: 145 | default_time_settings = {"max_num_time_steps": 7, "time_step_size": 2} 146 | kwargs = {**default_time_settings, **kwargs} 147 | from .wave.acoustic import Gaussians as dset 148 | else: 149 | raise ValueError(f"Unknown dataset {dataset}") 150 | elif "reaction_diffusion" in dataset: 151 | if "reaction_diffusion.AllenCahn" in dataset: 152 | if "out" in dataset: 153 | default_time_settings = {"max_num_time_steps": 9, "time_step_size": 2} 154 | else: 155 | default_time_settings = {"max_num_time_steps": 7, "time_step_size": 2} 156 | kwargs = {**default_time_settings, **kwargs} 157 | from .reaction_diffusion.allen_cahn import AllenCahn as dset 158 | else: 159 | raise ValueError(f"Unknown dataset {dataset}") 160 | 161 | return dset(**kwargs) if ".time" not in dataset else TimeWrapper(dset(**kwargs)) 162 | 163 | 164 | class BaseDataset(Dataset, ABC): 165 | """A base class for all datasets. Can be directly derived from if you have a steady/non-time dependent problem.""" 166 | 167 | def __init__( 168 | self, 169 | which: Optional[str] = None, 170 | num_trajectories: Optional[int] = None, 171 | data_path: Optional[str] = "./data", 172 | move_to_local_scratch: Optional[str] = None, 173 | ) -> None: 174 | """ 175 | Args: 176 | which: Which dataset to use, i.e. train, val, or test. 177 | resolution: The resolution of the dataset. 178 | num_trajectories: The number of trajectories to use for training. 179 | data_path: The path to the data files. 180 | move_to_local_scratch: If not None, move the data to this directory at dataset initialization and use it from there. 181 | """ 182 | assert which in ["train", "val", "test"] 183 | assert num_trajectories is not None and ( 184 | num_trajectories > 0 or num_trajectories in [-1, -2, -8] 185 | ) 186 | 187 | self.num_trajectories = num_trajectories 188 | self.data_path = data_path 189 | self.which = which 190 | self.move_to_local_scratch = move_to_local_scratch 191 | 192 | def _move_to_local_scratch(self, file_path): 193 | if self.move_to_local_scratch is not None: 194 | data_dir = os.path.join(self.data_path, file_path) 195 | file = file_path.split("/")[-1] 196 | scratch_dir = self.move_to_local_scratch 197 | dest_dir = os.path.join(scratch_dir, file) 198 | RANK = int(os.environ.get("LOCAL_RANK", -1)) 199 | if not os.path.exists(dest_dir) and (RANK == 0 or RANK == -1): 200 | print(f"Start copying {file} to {dest_dir}...") 201 | shutil.copy(data_dir, dest_dir) 202 | print("Finished data copy.") 203 | # idk how to do the barrier differently 204 | ls = broadcast_object_list([dest_dir], from_process=0) 205 | dest_dir = ls[0] 206 | return dest_dir 207 | else: 208 | return file_path 209 | 210 | def post_init(self) -> None: 211 | """ 212 | Call after self.N_max, self.N_val, self.N_test, as well as the file_paths and normalization constants are set. 
213 | """ 214 | assert ( 215 | self.N_max is not None 216 | and self.N_max > 0 217 | and self.N_max >= self.N_val + self.N_test 218 | ) 219 | if self.num_trajectories == -1: 220 | self.num_trajectories = self.N_max - self.N_val - self.N_test 221 | elif self.num_trajectories == -2: 222 | self.num_trajectories = (self.N_max - self.N_val - self.N_test) // 2 223 | elif self.num_trajectories == -8: 224 | self.num_trajectories = (self.N_max - self.N_val - self.N_test) // 8 225 | assert self.num_trajectories + self.N_val + self.N_test <= self.N_max 226 | assert self.N_val is not None and self.N_val > 0 227 | assert self.N_test is not None and self.N_test > 0 228 | if self.which == "train": 229 | self.length = self.num_trajectories 230 | self.start = 0 231 | elif self.which == "val": 232 | self.length = self.N_val 233 | self.start = self.N_max - self.N_val - self.N_test 234 | else: 235 | self.length = self.N_test 236 | self.start = self.N_max - self.N_test 237 | 238 | self.output_dim = self.label_description.count(",") + 1 239 | descriptors, channel_slice_list = self.get_channel_lists(self.label_description) 240 | self.printable_channel_description = descriptors 241 | self.channel_slice_list = channel_slice_list 242 | 243 | def __len__(self) -> int: 244 | """ 245 | Returns: overall length of dataset. 246 | """ 247 | return self.length 248 | 249 | def __getitem__(self, idx) -> Dict: 250 | """ 251 | Get an item. OVERWRITE! 252 | 253 | Args: 254 | idx: The index of the sample to get. 255 | 256 | Returns: 257 | A dict of key-value pairs of data. 258 | """ 259 | pass 260 | 261 | @staticmethod 262 | def get_channel_lists(label_description): 263 | matches = re.findall(r"\[([^\[\]]+)\]", label_description) 264 | channel_slice_list = [0] # use as channel_slice_list[i]:channel_slice_list[i+1] 265 | beautiful_descriptors = [] 266 | for match in matches: 267 | channel_slice_list.append(channel_slice_list[-1] + 1 + match.count(",")) 268 | splt = match.split(",") 269 | if len(splt) > 1: 270 | beautiful_descriptors.append("".join(splt)) 271 | else: 272 | beautiful_descriptors.append(match) 273 | return beautiful_descriptors, channel_slice_list 274 | 275 | 276 | class BaseTimeDataset(BaseDataset, ABC): 277 | """A base class for time dependent problems. Inherit time-dependent problems from here.""" 278 | 279 | def __init__( 280 | self, 281 | *args, 282 | max_num_time_steps: Optional[int] = None, 283 | time_step_size: Optional[int] = None, 284 | fix_input_to_time_step: Optional[int] = None, 285 | allowed_time_transitions: Optional[List[int]] = None, 286 | **kwargs, 287 | ) -> None: 288 | """ 289 | Args: 290 | max_num_time_steps: The maximum number of time steps to use. 291 | time_step_size: The size of the time step. 292 | fix_input_to_time_step: If not None, fix the input to this time step. 293 | allowed_time_transitions: If not None, only allow these time transitions (time steps). 
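For example, max_num_time_steps=7 with time_step_size=2 samples input/output pairs from the time indices {0, 2, ..., 14} (see _idx_map and post_init below).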
294 | """ 295 | assert max_num_time_steps is not None and max_num_time_steps > 0 296 | assert time_step_size is not None and time_step_size > 0 297 | assert fix_input_to_time_step is None or fix_input_to_time_step >= 0 298 | 299 | super().__init__(*args, **kwargs) 300 | self.max_num_time_steps = max_num_time_steps 301 | self.time_step_size = time_step_size 302 | self.fix_input_to_time_step = fix_input_to_time_step 303 | self.allowed_time_transitions = allowed_time_transitions 304 | 305 | def _idx_map(self, idx): 306 | i = idx // self.multiplier 307 | _idx = idx - i * self.multiplier 308 | 309 | if self.fix_input_to_time_step is None: 310 | t1, t2 = self.time_indices[_idx] 311 | assert t2 >= t1 312 | t = t2 - t1 313 | else: 314 | t1 = self.fix_input_to_time_step 315 | t2 = self.time_step_size * (_idx + 1) + self.fix_input_to_time_step 316 | t = t2 - t1 317 | return i, t, t1, t2 318 | 319 | def post_init(self) -> None: 320 | """ 321 | Call after self.N_max, self.N_val, self.N_test, as well as the file_paths and normalization constants are set. 322 | self.max_time_step must have already been set. 323 | """ 324 | assert ( 325 | self.N_max is not None 326 | and self.N_max > 0 327 | and self.N_max >= self.N_val + self.N_test 328 | ) 329 | if self.num_trajectories == -1: 330 | self.num_trajectories = self.N_max - self.N_val - self.N_test 331 | elif self.num_trajectories == -2: 332 | self.num_trajectories = (self.N_max - self.N_val - self.N_test) // 2 333 | elif self.num_trajectories == -8: 334 | self.num_trajectories = (self.N_max - self.N_val - self.N_test) // 8 335 | assert self.num_trajectories + self.N_val + self.N_test <= self.N_max 336 | assert self.N_val is not None and self.N_val > 0 337 | assert self.N_test is not None and self.N_test > 0 338 | assert self.max_num_time_steps is not None and self.max_num_time_steps > 0 339 | 340 | if self.fix_input_to_time_step is not None: 341 | self.multiplier = self.max_num_time_steps 342 | else: 343 | self.time_indices = [] 344 | for i in range(self.max_num_time_steps + 1): 345 | for j in range(i, self.max_num_time_steps + 1): 346 | if ( 347 | self.allowed_time_transitions is not None 348 | and (j - i) not in self.allowed_time_transitions 349 | ): 350 | continue 351 | self.time_indices.append( 352 | (self.time_step_size * i, self.time_step_size * j) 353 | ) 354 | self.multiplier = len(self.time_indices) 355 | 356 | if self.which == "train": 357 | self.length = self.num_trajectories * self.multiplier 358 | self.start = 0 359 | elif self.which == "val": 360 | self.length = self.N_val * self.multiplier 361 | self.start = self.N_max - self.N_val - self.N_test 362 | else: 363 | self.length = self.N_test * self.multiplier 364 | self.start = self.N_max - self.N_test 365 | 366 | self.output_dim = self.label_description.count(",") + 1 367 | descriptors, channel_slice_list = self.get_channel_lists(self.label_description) 368 | self.printable_channel_description = descriptors 369 | self.channel_slice_list = channel_slice_list 370 | 371 | 372 | class TimeWrapper(BaseTimeDataset): 373 | """For time-independent problems to be plugged into time-dependent models.""" 374 | 375 | def __init__(self, dataset): 376 | super().__init__( 377 | dataset.which, 378 | dataset.num_trajectories, 379 | dataset.data_path, 380 | None, 381 | max_num_time_steps=1, 382 | time_step_size=1, 383 | ) 384 | self.dataset = dataset 385 | self.resolution = dataset.resolution 386 | self.input_dim = dataset.input_dim 387 | self.output_dim = dataset.output_dim 388 | self.channel_slice_list = 
dataset.channel_slice_list 389 | self.printable_channel_description = dataset.printable_channel_description 390 | 391 | def __len__(self): 392 | return len(self.dataset) 393 | 394 | def __getitem__(self, idx): 395 | return {**self.dataset[idx], "time": 1.0} 396 | -------------------------------------------------------------------------------- /scOT/problems/elliptic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlab-ethz/poseidon/b8fa28f59bd7f7673323f28d11a12c6f3a215c61/scOT/problems/elliptic/__init__.py -------------------------------------------------------------------------------- /scOT/problems/elliptic/helmholtz.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import h5py 4 | import numpy as np 5 | from scOT.problems.base import BaseDataset 6 | 7 | 8 | class Helmholtz(BaseDataset): 9 | def __init__(self, *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | 12 | self.N_max = 19675 13 | self.N_val = 128 14 | self.N_test = 512 15 | self.resolution = 128 16 | 17 | self.file_path = os.path.join( 18 | self.data_path, 19 | "Helmholtz.h5", 20 | ) 21 | self.file_path = self._move_to_local_scratch(self.file_path) 22 | self.reader = h5py.File(self.file_path, "r") 23 | self.mean = 0.11523915668552 24 | self.std = 0.8279975746000605 25 | 26 | self.input_dim = 2 27 | self.label_description = "[u]" 28 | 29 | self.post_init() 30 | 31 | def __getitem__(self, idx): 32 | inputs = ( 33 | torch.from_numpy(self.reader["Sample_" + str(idx + self.start)]["a"][:]) 34 | .type(torch.float32) 35 | .reshape(1, self.resolution, self.resolution) 36 | ) 37 | inputs = inputs - 1 38 | b = float(np.array(self.reader["Sample_" + str(idx + self.start)]["bc"])) 39 | bc = b * torch.ones_like(inputs) 40 | inputs = torch.cat((inputs, bc), dim=0) 41 | 42 | labels = ( 43 | torch.from_numpy(self.reader["Sample_" + str(idx + self.start)]["u"][:]) 44 | .type(torch.float32) 45 | .reshape(1, self.resolution, self.resolution) 46 | ) 47 | labels = (labels - self.mean) / self.std 48 | 49 | return {"pixel_values": inputs, "labels": labels} 50 | -------------------------------------------------------------------------------- /scOT/problems/elliptic/poisson.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import h5py 4 | from scOT.problems.base import BaseDataset 5 | 6 | CONSTANTS = { 7 | "mean_source": 0.014822142414492256, 8 | "std_source": 4.755138816607612, 9 | "mean_solution": 0.0005603458434937093, 10 | "std_solution": 0.02401226126952699, 11 | } 12 | 13 | 14 | class Gaussians(BaseDataset): 15 | def __init__(self, *args, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | self.N_max = 20000 18 | self.N_val = 120 19 | self.N_test = 240 20 | self.resolution = 128 21 | 22 | self.file_path = os.path.join(self.data_path, "Poisson-Gauss.nc") 23 | self.file_path = self._move_to_local_scratch(self.file_path) 24 | self.reader = h5py.File(self.file_path, "r") 25 | self.constants = CONSTANTS 26 | 27 | self.input_dim = 1 28 | self.label_description = "[u]" 29 | 30 | self.post_init() 31 | 32 | def __getitem__(self, idx): 33 | inputs = ( 34 | torch.from_numpy(self.reader["source"][idx + self.start]) 35 | .type(torch.float32) 36 | .reshape(1, self.resolution, self.resolution) 37 | ) 38 | 39 | labels = ( 40 | torch.from_numpy(self.reader["solution"][idx + self.start]) 41 | .type(torch.float32) 42 | .reshape(1, 
self.resolution, self.resolution) 43 | ) 44 | 45 | inputs = (inputs - self.constants["mean_source"]) / self.constants["std_source"] 46 | labels = (labels - self.constants["mean_solution"]) / self.constants[ 47 | "std_solution" 48 | ] 49 | 50 | return {"pixel_values": inputs, "labels": labels} 51 | -------------------------------------------------------------------------------- /scOT/problems/fluids/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlab-ethz/poseidon/b8fa28f59bd7f7673323f28d11a12c6f3a215c61/scOT/problems/fluids/__init__.py -------------------------------------------------------------------------------- /scOT/problems/fluids/compressible.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import h5py 3 | import copy 4 | from scOT.problems.base import BaseTimeDataset, BaseDataset 5 | from scOT.problems.fluids.normalization_constants import CONSTANTS 6 | 7 | 8 | class Airfoil(BaseDataset): 9 | def __init__(self, *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | 12 | self.N_max = 10869 13 | self.N_val = 120 14 | self.N_test = 240 15 | self.resolution = 128 16 | 17 | data_path = self.data_path + "/SE-AF.nc" 18 | data_path = self._move_to_local_scratch(data_path) 19 | self.reader = h5py.File(data_path, "r") 20 | 21 | self.constants = { 22 | "mean": 0.92984116, 23 | "std": 0.10864315, 24 | } 25 | 26 | self.input_dim = 1 27 | self.label_description = "[rho]" 28 | 29 | self.post_init() 30 | 31 | def __getitem__(self, idx): 32 | i = idx 33 | inputs = ( 34 | torch.from_numpy(self.reader["solution"][i + self.start, 0]) 35 | .type(torch.float32) 36 | .reshape(1, self.resolution, self.resolution) 37 | ) 38 | labels = ( 39 | torch.from_numpy(self.reader["solution"][i + self.start, 1]) 40 | .type(torch.float32) 41 | .reshape(1, self.resolution, self.resolution) 42 | ) 43 | 44 | labels = (labels - self.constants["mean"]) / self.constants["std"] 45 | 46 | pixel_mask = inputs == 1 47 | labels[pixel_mask] = 1 48 | 49 | return { 50 | "pixel_values": inputs, 51 | "labels": labels, 52 | "pixel_mask": pixel_mask, 53 | } 54 | 55 | 56 | class RichtmyerMeshkov(BaseTimeDataset): 57 | def __init__(self, *args, tracer=False, **kwargs): 58 | super().__init__(*args, **kwargs) 59 | assert self.max_num_time_steps * self.time_step_size <= 20 60 | 61 | self.N_max = 1260 62 | self.N_val = 100 63 | self.N_test = 130 64 | self.resolution = 128 65 | 66 | data_path = self.data_path + "/CE-RM.nc" 67 | data_path = self._move_to_local_scratch(data_path) 68 | self.reader = h5py.File(data_path, "r") 69 | 70 | self.constants = { 71 | "mean": torch.tensor([1.1964245, -7.164812e-06, 2.8968952e-06, 1.5648036]) 72 | .unsqueeze(1) 73 | .unsqueeze(1), 74 | "std": torch.tensor([0.5543239, 0.24304213, 0.2430597, 0.89639103]) 75 | .unsqueeze(1) 76 | .unsqueeze(1), 77 | "time": 20.0, 78 | } 79 | 80 | self.input_dim = 4 81 | self.label_description = "[rho],[u,v],[p]" 82 | 83 | self.pixel_mask = torch.tensor([False, False, False, False]) 84 | 85 | self.post_init() 86 | 87 | def __getitem__(self, idx): 88 | i, t, t1, t2 = self._idx_map(idx) 89 | time = t / self.constants["time"] 90 | 91 | inputs = ( 92 | torch.from_numpy(self.reader["solution"][i + self.start, t1, 0:4]) 93 | .type(torch.float32) 94 | .reshape(4, self.resolution, self.resolution) 95 | ) 96 | 97 | label = ( 98 | torch.from_numpy(self.reader["solution"][i + self.start, t2, 0:4]) 99 | .type(torch.float32) 100 | .reshape(4, 
self.resolution, self.resolution) 101 | ) 102 | 103 | inputs = (inputs - self.constants["mean"]) / self.constants["std"] 104 | label = (label - self.constants["mean"]) / self.constants["std"] 105 | 106 | return { 107 | "pixel_values": inputs, 108 | "labels": label, 109 | "time": time, 110 | "pixel_mask": self.pixel_mask, 111 | } 112 | 113 | 114 | class RayleighTaylor(BaseTimeDataset): 115 | def __init__(self, *args, tracer=False, **kwargs): 116 | super().__init__(*args, **kwargs) 117 | assert self.max_num_time_steps * self.time_step_size <= 10 118 | 119 | self.N_max = 1260 120 | self.N_val = 100 121 | self.N_test = 130 122 | self.resolution = 128 123 | 124 | data_path = self.data_path + "/GCE-RT.nc" 125 | data_path = self._move_to_local_scratch(data_path) 126 | self.reader = h5py.File(data_path, "r") 127 | 128 | self.constants = { 129 | "mean": torch.tensor( 130 | [0.8970493, 4.0316996e-13, -1.3858967e-13, 0.7133829, -1.7055787] 131 | ) 132 | .unsqueeze(1) 133 | .unsqueeze(1), 134 | "std": torch.tensor( 135 | [0.12857835, 0.014896976, 0.014896975, 0.21293919, 0.40131348] 136 | ) 137 | .unsqueeze(1) 138 | .unsqueeze(1), 139 | "time": 10.0, 140 | } 141 | 142 | self.input_dim = 5 143 | self.label_description = "[rho],[u,v],[p],[g]" 144 | 145 | self.pixel_mask = torch.tensor([False, False, False, False, False]) 146 | 147 | self.post_init() 148 | 149 | def __getitem__(self, idx): 150 | i, t, t1, t2 = self._idx_map(idx) 151 | time = t / self.constants["time"] 152 | 153 | inputs = ( 154 | torch.from_numpy(self.reader["solution"][i + self.start, t1, 0:4]) 155 | .type(torch.float32) 156 | .reshape(4, self.resolution, self.resolution) 157 | ) 158 | label = ( 159 | torch.from_numpy(self.reader["solution"][i + self.start, t2, 0:4]) 160 | .type(torch.float32) 161 | .reshape(4, self.resolution, self.resolution) 162 | ) 163 | 164 | g_1 = ( 165 | torch.from_numpy(self.reader["solution"][i + self.start, t1, 5:6]) 166 | .type(torch.float32) 167 | .reshape(1, self.resolution, self.resolution) 168 | ) 169 | g_2 = ( 170 | torch.from_numpy(self.reader["solution"][i + self.start, t2, 5:6]) 171 | .type(torch.float32) 172 | .reshape(1, self.resolution, self.resolution) 173 | ) 174 | 175 | inputs = (inputs - self.constants["mean"][:4]) / self.constants["std"][:4] 176 | g_1 = (g_1 - self.constants["mean"][4]) / self.constants["std"][4] 177 | g_2 = (g_2 - self.constants["mean"][4]) / self.constants["std"][4] 178 | label = (label - self.constants["mean"][:4]) / self.constants["std"][:4] 179 | 180 | inputs = torch.cat([inputs, g_1], dim=0) 181 | label = torch.cat([label, g_2], dim=0) 182 | 183 | return { 184 | "pixel_values": inputs, 185 | "labels": label, 186 | "time": time, 187 | "pixel_mask": self.pixel_mask, 188 | } 189 | 190 | 191 | class CompressibleBase(BaseTimeDataset): 192 | def __init__(self, file_path, *args, tracer=False, **kwargs): 193 | super().__init__(*args, **kwargs) 194 | assert self.max_num_time_steps * self.time_step_size <= 20 195 | 196 | self.N_max = 10000 197 | self.N_val = 120 198 | self.N_test = 240 199 | self.resolution = 128 200 | self.tracer = tracer 201 | 202 | data_path = self.data_path + file_path 203 | data_path = self._move_to_local_scratch(data_path) 204 | self.reader = h5py.File(data_path, "r") 205 | 206 | self.constants = copy.deepcopy(CONSTANTS) 207 | 208 | self.input_dim = 4 if not tracer else 5 209 | self.label_description = ( 210 | "[rho],[u,v],[p]" if not tracer else "[rho],[u,v],[p],[tracer]" 211 | ) 212 | 213 | self.pixel_mask = ( 214 | torch.tensor([False, False, False, 
False]) 215 | if not tracer 216 | else torch.tensor([False, False, False, False, False]) 217 | ) 218 | 219 | self.post_init() 220 | 221 | def __getitem__(self, idx): 222 | i, t, t1, t2 = self._idx_map(idx) 223 | time = t / self.constants["time"] 224 | 225 | inputs = ( 226 | torch.from_numpy(self.reader["data"][i + self.start, t1, 0:4]) 227 | .type(torch.float32) 228 | .reshape(4, self.resolution, self.resolution) 229 | ) 230 | label = ( 231 | torch.from_numpy(self.reader["data"][i + self.start, t2, 0:4]) 232 | .type(torch.float32) 233 | .reshape(4, self.resolution, self.resolution) 234 | ) 235 | 236 | inputs[3] = inputs[3] - self.mean_pressure 237 | label[3] = label[3] - self.mean_pressure 238 | 239 | inputs = (inputs - self.constants["mean"]) / self.constants["std"] 240 | label = (label - self.constants["mean"]) / self.constants["std"] 241 | 242 | if self.tracer: 243 | input_tracer = ( 244 | torch.from_numpy(self.reader["data"][i + self.start, t1, 4:5]) 245 | .type(torch.float32) 246 | .reshape(1, self.resolution, self.resolution) 247 | ) 248 | output_tracer = ( 249 | torch.from_numpy(self.reader["data"][i + self.start, t2, 4:5]) 250 | .type(torch.float32) 251 | .reshape(1, self.resolution, self.resolution) 252 | ) 253 | inputs = torch.cat([inputs, input_tracer], dim=0) 254 | label = torch.cat([label, output_tracer], dim=0) 255 | 256 | return { 257 | "pixel_values": inputs, 258 | "labels": label, 259 | "time": time, 260 | "pixel_mask": self.pixel_mask, 261 | } 262 | 263 | 264 | class Gaussians(CompressibleBase): 265 | def __init__(self, *args, tracer=False, **kwargs): 266 | self.mean_pressure = 2.513 267 | file_path = "/CE-Gauss.nc" 268 | if tracer: 269 | raise NotImplementedError("Tracer not implemented for Gaussians") 270 | super().__init__(file_path, *args, tracer=tracer, **kwargs) 271 | 272 | 273 | class KelvinHelmholtz(CompressibleBase): 274 | def __init__(self, *args, tracer=False, **kwargs): 275 | self.mean_pressure = 1.0 276 | file_path = "/CE-KH.nc" 277 | if tracer: 278 | raise NotImplementedError("Tracer not implemented for KelvinHelmholtz") 279 | super().__init__(file_path, *args, tracer=tracer, **kwargs) 280 | 281 | 282 | class Riemann(CompressibleBase): 283 | def __init__(self, *args, tracer=False, **kwargs): 284 | self.mean_pressure = 0.215 285 | file_path = "/CE-RP.nc" 286 | if tracer: 287 | raise NotImplementedError("Tracer not implemented for Riemann") 288 | super().__init__(file_path, *args, tracer=tracer, **kwargs) 289 | 290 | 291 | class RiemannCurved(CompressibleBase): 292 | def __init__(self, *args, tracer=False, **kwargs): 293 | self.mean_pressure = 0.553 294 | file_path = "/CE-CRP.nc" 295 | if tracer: 296 | raise NotImplementedError("Tracer not implemented for RiemannCurved") 297 | super().__init__(file_path, *args, tracer=tracer, **kwargs) 298 | 299 | 300 | class RiemannKelvinHelmholtz(CompressibleBase): 301 | def __init__(self, *args, tracer=False, **kwargs): 302 | self.mean_pressure = 1.33 303 | file_path = "/CE-RPUI.nc" 304 | if tracer: 305 | raise NotImplementedError( 306 | "Tracer not implemented for RiemannKelvinHelmholtz" 307 | ) 308 | super().__init__(file_path, *args, tracer=tracer, **kwargs) 309 | -------------------------------------------------------------------------------- /scOT/problems/fluids/incompressible.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import h5py 3 | import numpy as np 4 | import copy 5 | from scOT.problems.base import BaseTimeDataset 6 | from 
scOT.problems.fluids.normalization_constants import CONSTANTS 7 | 8 | 9 | class IncompressibleBase(BaseTimeDataset): 10 | def __init__( 11 | self, 12 | N_max, 13 | file_path, 14 | *args, 15 | tracer=False, 16 | just_velocities=False, 17 | transpose=False, 18 | resolution=None, 19 | **kwargs 20 | ): 21 | """ 22 | just_velocities: If True, only the velocities are used as input and output. 23 | transpose: If True, the input and output are transposed. 24 | """ 25 | super().__init__(*args, **kwargs) 26 | assert self.max_num_time_steps * self.time_step_size <= 20 27 | 28 | self.N_max = N_max 29 | self.N_val = 120 30 | self.N_test = 240 31 | self.resolution = 128 32 | self.tracer = tracer 33 | self.just_velocities = just_velocities 34 | self.transpose = transpose 35 | 36 | data_path = self.data_path + file_path 37 | data_path = self._move_to_local_scratch(data_path) 38 | self.reader = h5py.File(data_path, "r") 39 | 40 | self.constants = copy.deepcopy(CONSTANTS) 41 | if just_velocities: 42 | self.constants["mean"] = self.constants["mean"][1:3] 43 | self.constants["std"] = self.constants["std"][1:3] 44 | 45 | self.density = torch.ones(1, self.resolution, self.resolution) 46 | self.pressure = torch.zeros(1, self.resolution, self.resolution) 47 | 48 | self.input_dim = 4 if not tracer else 5 49 | if just_velocities: 50 | self.input_dim -= 2 51 | self.label_description = "[u,v]" 52 | if not self.just_velocities: 53 | self.label_description = "[rho],[u,v],[p]" 54 | if tracer: 55 | self.label_description += ",[tracer]" 56 | 57 | self.pixel_mask = torch.tensor([False, False]) 58 | if not self.just_velocities: 59 | self.pixel_mask = torch.tensor([False, False, False, True]) 60 | if tracer: 61 | self.pixel_mask = torch.cat( 62 | [self.pixel_mask, torch.tensor([False])], 63 | dim=0, 64 | ) 65 | 66 | if resolution is None: 67 | self.res = None 68 | else: 69 | if resolution > 128: 70 | raise ValueError("Resolution must be <= 128") 71 | self.res = resolution 72 | 73 | self.post_init() 74 | 75 | def _downsample(self, image, target_size): 76 | image = image.unsqueeze(0) 77 | image_size = image.shape[-2] 78 | freqs = torch.fft.fftfreq(image_size, d=1 / image_size) 79 | sel = torch.logical_and(freqs >= -target_size / 2, freqs <= target_size / 2 - 1) 80 | image_hat = torch.fft.fft2(image, norm="forward") 81 | image_hat = image_hat[:, :, sel, :][:, :, :, sel] 82 | image = torch.fft.ifft2(image_hat, norm="forward").real 83 | return image.squeeze(0) 84 | 85 | def __getitem__(self, idx): 86 | i, t, t1, t2 = self._idx_map(idx) 87 | time = t / self.constants["time"] 88 | 89 | inputs_v = ( 90 | torch.from_numpy(self.reader["velocity"][i + self.start, t1, 0:2]) 91 | .type(torch.float32) 92 | .reshape(2, self.resolution, self.resolution) 93 | ) 94 | label_v = ( 95 | torch.from_numpy(self.reader["velocity"][i + self.start, t2, 0:2]) 96 | .type(torch.float32) 97 | .reshape(2, self.resolution, self.resolution) 98 | ) 99 | if self.transpose: 100 | inputs_v = inputs_v.transpose(-2, -1) 101 | label_v = label_v.transpose(-2, -1) 102 | 103 | if not self.just_velocities: 104 | inputs = torch.cat([self.density, inputs_v, self.pressure], dim=0) 105 | label = torch.cat([self.density, label_v, self.pressure], dim=0) 106 | else: 107 | inputs = inputs_v 108 | label = label_v 109 | 110 | inputs = (inputs - self.constants["mean"]) / self.constants["std"] 111 | label = (label - self.constants["mean"]) / self.constants["std"] 112 | 113 | if self.tracer: 114 | input_tracer = ( 115 | torch.from_numpy(self.reader["velocity"][i + self.start, t1, 
2:3]) 116 | .type(torch.float32) 117 | .reshape(1, self.resolution, self.resolution) 118 | ) 119 | output_tracer = ( 120 | torch.from_numpy(self.reader["velocity"][i + self.start, t2, 2:3]) 121 | .type(torch.float32) 122 | .reshape(1, self.resolution, self.resolution) 123 | ) 124 | if self.transpose: 125 | input_tracer = input_tracer.transpose(-2, -1) 126 | output_tracer = output_tracer.transpose(-2, -1) 127 | input_tracer = ( 128 | input_tracer - self.constants["tracer_mean"] 129 | ) / self.constants["tracer_std"] 130 | output_tracer = ( 131 | output_tracer - self.constants["tracer_mean"] 132 | ) / self.constants["tracer_std"] 133 | 134 | inputs = torch.cat([inputs, input_tracer], dim=0) 135 | label = torch.cat([label, output_tracer], dim=0) 136 | 137 | if self.res is not None: 138 | inputs = self._downsample(inputs, self.res) 139 | label = self._downsample(label, self.res) 140 | 141 | return { 142 | "pixel_values": inputs, 143 | "labels": label, 144 | "time": time, 145 | "pixel_mask": self.pixel_mask, 146 | } 147 | 148 | 149 | class KolmogorovFlow(BaseTimeDataset): 150 | def __init__(self, *args, tracer=False, just_velocities=False, **kwargs): 151 | super().__init__(*args, **kwargs) 152 | assert self.max_num_time_steps * self.time_step_size <= 20 153 | 154 | assert tracer == False 155 | 156 | self.N_max = 20000 157 | self.N_val = 120 158 | self.N_test = 240 159 | self.resolution = 128 160 | self.just_velocities = just_velocities 161 | 162 | data_path = self.data_path + "/FNS-KF.nc" 163 | data_path = self._move_to_local_scratch(data_path) 164 | self.reader = h5py.File(data_path, "r") 165 | 166 | self.constants = copy.deepcopy(CONSTANTS) 167 | self.constants["mean"][1] = -2.2424793e-13 168 | self.constants["mean"][2] = 4.1510376e-12 169 | self.constants["std"][1] = 0.22017328 170 | self.constants["std"][2] = 0.22078253 171 | if just_velocities: 172 | self.constants["mean"] = self.constants["mean"][1:3] 173 | self.constants["std"] = self.constants["std"][1:3] 174 | 175 | self.density = torch.ones(1, self.resolution, self.resolution) 176 | self.pressure = torch.zeros(1, self.resolution, self.resolution) 177 | X, Y = torch.meshgrid( 178 | torch.linspace(0, 1, self.resolution), 179 | torch.linspace(0, 1, self.resolution), 180 | indexing="ij", 181 | ) 182 | f = lambda x, y: 0.1 * torch.sin(2.0 * np.pi * (x + y)) 183 | self.forcing = f(X, Y).unsqueeze(0) 184 | self.constants["mean_forcing"] = -1.2996679288335145e-09 185 | self.constants["std_forcing"] = 0.0707106739282608 186 | self.forcing = (self.forcing - self.constants["mean_forcing"]) / self.constants[ 187 | "std_forcing" 188 | ] 189 | 190 | self.input_dim = 5 if not tracer else 6 191 | if just_velocities: 192 | self.input_dim -= 2 193 | self.label_description = "[u,v],[g]" 194 | if not self.just_velocities: 195 | self.label_description = "[rho],[u,v],[p],[g]" 196 | if tracer: 197 | self.label_description += ",[tracer]" 198 | 199 | self.pixel_mask = torch.tensor([False, False, False]) 200 | if not self.just_velocities: 201 | self.pixel_mask = torch.tensor([False, False, False, True, False]) 202 | if tracer: 203 | self.pixel_mask = torch.cat( 204 | [self.pixel_mask, torch.tensor([False])], 205 | dim=0, 206 | ) 207 | 208 | self.post_init() 209 | 210 | def __getitem__(self, idx): 211 | i, t, t1, t2 = self._idx_map(idx) 212 | time = t / self.constants["time"] 213 | 214 | inputs_v = ( 215 | torch.from_numpy(self.reader["solution"][i + self.start, t1, 0:2]) 216 | .type(torch.float32) 217 | .reshape(2, self.resolution, self.resolution) 218 | ) 219 
| label_v = ( 220 | torch.from_numpy(self.reader["solution"][i + self.start, t2, 0:2]) 221 | .type(torch.float32) 222 | .reshape(2, self.resolution, self.resolution) 223 | ) 224 | 225 | if not self.just_velocities: 226 | inputs = torch.cat([self.density, inputs_v, self.pressure], dim=0) 227 | label = torch.cat([self.density, label_v, self.pressure], dim=0) 228 | else: 229 | inputs = inputs_v 230 | label = label_v 231 | 232 | inputs = (inputs - self.constants["mean"]) / self.constants["std"] 233 | label = (label - self.constants["mean"]) / self.constants["std"] 234 | 235 | inputs = torch.cat([inputs, self.forcing], dim=0) 236 | label = torch.cat([label, self.forcing], dim=0) 237 | 238 | return { 239 | "pixel_values": inputs, 240 | "labels": label, 241 | "time": time, 242 | "pixel_mask": self.pixel_mask, 243 | } 244 | 245 | 246 | class BrownianBridge(IncompressibleBase): 247 | def __init__(self, *args, tracer=False, just_velocities=False, **kwargs): 248 | if tracer: 249 | raise ValueError("BrownianBridge does not have a tracer") 250 | file_path = "/NS-BB.nc" 251 | super().__init__( 252 | 20000, 253 | file_path, 254 | *args, 255 | tracer=False, 256 | just_velocities=just_velocities, 257 | **kwargs 258 | ) 259 | 260 | 261 | class PiecewiseConstants(IncompressibleBase): 262 | def __init__(self, *args, tracer=False, just_velocities=False, **kwargs): 263 | file_path = "/NS-PwC.nc" 264 | super().__init__( 265 | 20000, 266 | file_path, 267 | *args, 268 | tracer=tracer, 269 | just_velocities=just_velocities, 270 | **kwargs 271 | ) 272 | 273 | 274 | class Gaussians(IncompressibleBase): 275 | def __init__(self, *args, tracer=False, just_velocities=False, **kwargs): 276 | if tracer: 277 | raise ValueError("Gaussians does not have a tracer") 278 | file_path = "/NS-Gauss.nc" 279 | super().__init__( 280 | 20000, 281 | file_path, 282 | *args, 283 | tracer=False, 284 | just_velocities=just_velocities, 285 | **kwargs 286 | ) 287 | 288 | 289 | class ShearLayer(IncompressibleBase): 290 | def __init__(self, *args, tracer=False, just_velocities=False, **kwargs): 291 | if tracer: 292 | raise ValueError("Shear layer does not have a tracer") 293 | super().__init__( 294 | 40000, 295 | "/NS-SL.nc", 296 | *args, 297 | transpose=True, 298 | tracer=False, 299 | just_velocities=just_velocities, 300 | **kwargs 301 | ) 302 | 303 | 304 | class VortexSheet(IncompressibleBase): 305 | def __init__(self, *args, tracer=False, just_velocities=False, **kwargs): 306 | if tracer: 307 | raise ValueError("VortexSheet does not have a tracer") 308 | file_path = "/NS-SVS.nc" 309 | super().__init__( 310 | 20000, 311 | file_path, 312 | *args, 313 | tracer=False, 314 | just_velocities=just_velocities, 315 | **kwargs 316 | ) 317 | 318 | 319 | class Sines(IncompressibleBase): 320 | def __init__(self, *args, tracer=False, just_velocities=False, **kwargs): 321 | if tracer: 322 | raise ValueError("Sines does not have a tracer") 323 | file_path = "/NS-Sines.nc" 324 | super().__init__( 325 | 20000, 326 | file_path, 327 | *args, 328 | tracer=False, 329 | just_velocities=just_velocities, 330 | **kwargs 331 | ) 332 | -------------------------------------------------------------------------------- /scOT/problems/fluids/normalization_constants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | CONSTANTS = { 4 | "mean": torch.tensor([0.80, 0.0, 0.0, 0.0]).unsqueeze(1).unsqueeze(1), 5 | "std": torch.tensor([0.31, 0.391, 0.356, 0.185]).unsqueeze(1).unsqueeze(1), 6 | "time": 20.0, 7 | 
"tracer_mean": 0.19586183, 8 | "tracer_std": 0.37, 9 | } 10 | -------------------------------------------------------------------------------- /scOT/problems/reaction_diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlab-ethz/poseidon/b8fa28f59bd7f7673323f28d11a12c6f3a215c61/scOT/problems/reaction_diffusion/__init__.py -------------------------------------------------------------------------------- /scOT/problems/reaction_diffusion/allen_cahn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import h5py 3 | from scOT.problems.base import BaseTimeDataset 4 | 5 | 6 | class AllenCahn(BaseTimeDataset): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | assert self.max_num_time_steps * self.time_step_size <= 19 10 | 11 | self.N_max = 15000 12 | self.N_val = 60 13 | self.N_test = 240 14 | self.resolution = 128 15 | 16 | data_path = self.data_path + "/ACE.nc" 17 | data_path = self._move_to_local_scratch(data_path) 18 | self.reader = h5py.File(data_path, "r") 19 | 20 | self.constants = { 21 | "mean": 0.002484262, 22 | "std": 0.65351176, 23 | "time": 19.0, 24 | } 25 | 26 | self.input_dim = 1 27 | self.label_description = "[u]" 28 | 29 | self.post_init() 30 | 31 | def __getitem__(self, idx): 32 | i, t, t1, t2 = self._idx_map(idx) 33 | time = t / self.constants["time"] 34 | 35 | inputs = ( 36 | torch.from_numpy(self.reader["solution"][i + self.start, t1]) 37 | .type(torch.float32) 38 | .reshape(1, self.resolution, self.resolution) 39 | ) 40 | labels = ( 41 | torch.from_numpy(self.reader["solution"][i + self.start, t2]) 42 | .type(torch.float32) 43 | .reshape(1, self.resolution, self.resolution) 44 | ) 45 | 46 | inputs = (inputs - self.constants["mean"]) / self.constants["std"] 47 | labels = (labels - self.constants["mean"]) / self.constants["std"] 48 | 49 | return { 50 | "pixel_values": inputs, 51 | "labels": labels, 52 | "time": time, 53 | } 54 | -------------------------------------------------------------------------------- /scOT/problems/wave/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlab-ethz/poseidon/b8fa28f59bd7f7673323f28d11a12c6f3a215c61/scOT/problems/wave/__init__.py -------------------------------------------------------------------------------- /scOT/problems/wave/acoustic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import h5py 3 | from scOT.problems.base import BaseTimeDataset 4 | 5 | 6 | class Layer(BaseTimeDataset): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | assert self.max_num_time_steps * self.time_step_size <= 20 10 | 11 | self.N_max = 10512 12 | self.N_val = 60 13 | self.N_test = 240 14 | self.resolution = 128 15 | 16 | data_path = self.data_path + "/Wave-Layer.nc" 17 | data_path = self._move_to_local_scratch(data_path) 18 | self.reader = h5py.File(data_path, "r") 19 | 20 | self.constants = { 21 | "mean": 0.03467443221585092, 22 | "std": 0.10442421752963911, 23 | "mean_c": 3498.5644380917424, 24 | "std_c": 647.843958567462, 25 | "time": 20.0, 26 | } 27 | 28 | self.input_dim = 2 29 | self.label_description = "[u],[c]" 30 | 31 | self.post_init() 32 | 33 | def __getitem__(self, idx): 34 | i, t, t1, t2 = self._idx_map(idx) 35 | time = t / self.constants["time"] 36 | 37 | inputs = ( 38 | 
torch.from_numpy(self.reader["solution"][i + self.start, t1]) 39 | .type(torch.float32) 40 | .reshape(1, self.resolution, self.resolution) 41 | ) 42 | inputs_c = ( 43 | torch.from_numpy(self.reader["c"][i + self.start]) 44 | .type(torch.float32) 45 | .reshape(1, self.resolution, self.resolution) 46 | ) 47 | labels = ( 48 | torch.from_numpy(self.reader["solution"][i + self.start, t2]) 49 | .type(torch.float32) 50 | .reshape(1, self.resolution, self.resolution) 51 | ) 52 | 53 | inputs = (inputs - self.constants["mean"]) / self.constants["std"] 54 | inputs_c = (inputs_c - self.constants["mean_c"]) / self.constants["std_c"] 55 | labels = (labels - self.constants["mean"]) / self.constants["std"] 56 | 57 | inputs = torch.cat([inputs, inputs_c], dim=0) 58 | labels = torch.cat([labels, inputs_c], dim=0) 59 | 60 | return { 61 | "pixel_values": inputs, 62 | "labels": labels, 63 | "time": time, 64 | } 65 | 66 | 67 | class Gaussians(BaseTimeDataset): 68 | def __init__(self, *args, **kwargs): 69 | super().__init__(*args, **kwargs) 70 | assert self.max_num_time_steps * self.time_step_size <= 15 71 | 72 | self.N_max = 10512 73 | self.N_val = 60 74 | self.N_test = 240 75 | self.resolution = 128 76 | 77 | data_path = self.data_path + "/Wave-Gauss.nc" 78 | data_path = self._move_to_local_scratch(data_path) 79 | self.reader = h5py.File(data_path, "r") 80 | 81 | self.constants = { 82 | "mean": 0.0334376316, 83 | "std": 0.1171879068, 84 | "mean_c": 2618.4593933, 85 | "std_c": 601.51658913, 86 | "time": 15.0, 87 | } 88 | 89 | self.input_dim = 2 90 | self.label_description = "[u],[c]" 91 | 92 | self.post_init() 93 | 94 | def __getitem__(self, idx): 95 | i, t, t1, t2 = self._idx_map(idx) 96 | time = t / self.constants["time"] 97 | 98 | inputs = ( 99 | torch.from_numpy(self.reader["solution"][i + self.start, t1]) 100 | .type(torch.float32) 101 | .reshape(1, self.resolution, self.resolution) 102 | ) 103 | inputs_c = ( 104 | torch.from_numpy(self.reader["c"][i + self.start]) 105 | .type(torch.float32) 106 | .reshape(1, self.resolution, self.resolution) 107 | ) 108 | labels = ( 109 | torch.from_numpy(self.reader["solution"][i + self.start, t2]) 110 | .type(torch.float32) 111 | .reshape(1, self.resolution, self.resolution) 112 | ) 113 | 114 | inputs = (inputs - self.constants["mean"]) / self.constants["std"] 115 | inputs_c = (inputs_c - self.constants["mean_c"]) / self.constants["std_c"] 116 | labels = (labels - self.constants["mean"]) / self.constants["std"] 117 | 118 | inputs = torch.cat([inputs, inputs_c], dim=0) 119 | labels = torch.cat([labels, inputs_c], dim=0) 120 | 121 | return { 122 | "pixel_values": inputs, 123 | "labels": labels, 124 | "time": time, 125 | } 126 | -------------------------------------------------------------------------------- /scOT/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script trains a scOT or pretrains Poseidon on a PDE dataset. 3 | Can be also used for finetuning Poseidon. 4 | Can be used in a single config or sweep setup. 
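Note that the two acoustic datasets above concatenate the wave field u with the static propagation speed c in both the inputs and the labels, so a model is expected to carry c through unchanged. A minimal shape check follows; this is a sketch that assumes Wave-Layer.nc is available under `data_path` and that the factory in scOT/problems/base.py maps the string `"wave.Layer"` to the Layer class above.

```python
from scOT.problems.base import get_dataset

ds = get_dataset(
    dataset="wave.Layer",  # assumed spelling, following the factory branches
    which="val",
    num_trajectories=-1,   # -1 selects all available training trajectories
    data_path="./data",
)
item = ds[0]
print(item["pixel_values"].shape)  # torch.Size([2, 128, 128]) -> [u, c]
print(item["labels"].shape)        # torch.Size([2, 128, 128]) -> [u, c]
print(item["time"])                # t / 20.0, normalized to [0, 1]
```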
5 | """ 6 | 7 | import argparse 8 | import torch 9 | import wandb 10 | import numpy as np 11 | import random 12 | import json 13 | import psutil 14 | import os 15 | 16 | os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" 17 | import yaml 18 | import matplotlib.pyplot as plt 19 | import transformers 20 | from accelerate.utils import broadcast_object_list 21 | from scOT.trainer import TrainingArguments, Trainer 22 | from transformers import EarlyStoppingCallback 23 | from scOT.model import ScOT, ScOTConfig 24 | from mpl_toolkits.axes_grid1 import ImageGrid 25 | from scOT.problems.base import get_dataset, BaseTimeDataset 26 | from scOT.utils import get_num_parameters, read_cli, get_num_parameters_no_embed 27 | from scOT.metrics import relative_lp_error 28 | 29 | SEED = 0 30 | torch.manual_seed(SEED) 31 | np.random.seed(SEED) 32 | random.seed(SEED) 33 | 34 | 35 | MODEL_MAP = { 36 | "T": { 37 | "num_heads": [3, 6, 12, 24], 38 | "skip_connections": [2, 2, 2, 0], 39 | "window_size": 16, 40 | "patch_size": 4, 41 | "mlp_ratio": 4.0, 42 | "depths": [4, 4, 4, 4], 43 | "embed_dim": 48, 44 | }, 45 | "S": { 46 | "num_heads": [3, 6, 12, 24], 47 | "skip_connections": [2, 2, 2, 0], 48 | "window_size": 16, 49 | "patch_size": 4, 50 | "mlp_ratio": 4.0, 51 | "depths": [8, 8, 8, 8], 52 | "embed_dim": 48, 53 | }, 54 | "B": { 55 | "num_heads": [3, 6, 12, 24], 56 | "skip_connections": [2, 2, 2, 0], 57 | "window_size": 16, 58 | "patch_size": 4, 59 | "mlp_ratio": 4.0, 60 | "depths": [8, 8, 8, 8], 61 | "embed_dim": 96, 62 | }, 63 | "L": { 64 | "num_heads": [3, 6, 12, 24], 65 | "skip_connections": [2, 2, 2, 0], 66 | "window_size": 16, 67 | "patch_size": 4, 68 | "mlp_ratio": 4.0, 69 | "depths": [8, 8, 8, 8], 70 | "embed_dim": 192, 71 | }, 72 | } 73 | 74 | 75 | def create_predictions_plot(predictions, labels, wandb_prefix): 76 | assert predictions.shape[0] >= 4 77 | 78 | indices = random.sample(range(predictions.shape[0]), 4) 79 | 80 | predictions = predictions[indices] 81 | labels = labels[indices] 82 | 83 | fig = plt.figure() 84 | grid = ImageGrid( 85 | fig, 111, nrows_ncols=(predictions.shape[1] + labels.shape[1], 4), axes_pad=0.1 86 | ) 87 | 88 | vmax, vmin = max(predictions.max(), labels.max()), min( 89 | predictions.min(), labels.min() 90 | ) 91 | 92 | for _i, ax in enumerate(grid): 93 | i = _i // 4 94 | j = _i % 4 95 | 96 | if i % 2 == 0: 97 | ax.imshow( 98 | predictions[j, i // 2, :, :], 99 | cmap="gist_ncar", 100 | origin="lower", 101 | vmin=vmin, 102 | vmax=vmax, 103 | ) 104 | else: 105 | ax.imshow( 106 | labels[j, i // 2, :, :], 107 | cmap="gist_ncar", 108 | origin="lower", 109 | vmin=vmin, 110 | vmax=vmax, 111 | ) 112 | 113 | ax.set_xticks([]) 114 | ax.set_yticks([]) 115 | 116 | wandb.log({wandb_prefix + "/predictions": wandb.Image(fig)}) 117 | plt.close() 118 | 119 | 120 | def setup(params, model_map=True): 121 | config = None 122 | RANK = int(os.environ.get("LOCAL_RANK", -1)) 123 | CPU_CORES = len(psutil.Process().cpu_affinity()) 124 | CPU_CORES = min(CPU_CORES, 16) 125 | print(f"Detected {CPU_CORES} CPU cores, will use {CPU_CORES} workers.") 126 | if params.disable_tqdm: 127 | transformers.utils.logging.disable_progress_bar() 128 | if params.json_config: 129 | config = json.loads(params.config) 130 | else: 131 | config = params.config 132 | 133 | if RANK == 0 or RANK == -1: 134 | run = wandb.init( 135 | project=params.wandb_project_name, name=params.wandb_run_name, config=config 136 | ) 137 | config = wandb.config 138 | else: 139 | 140 | def clean_yaml(config): 141 | d = {} 142 | for key, inner_dict in 
config.items(): 143 | d[key] = inner_dict["value"] 144 | return d 145 | 146 | if not params.json_config: 147 | with open(params.config, "r") as s: 148 | config = yaml.safe_load(s) 149 | config = clean_yaml(config) 150 | run = None 151 | 152 | ckpt_dir = "./" 153 | if RANK == 0 or RANK == -1: 154 | if run.sweep_id is not None: 155 | ckpt_dir = ( 156 | params.checkpoint_path 157 | + "/" 158 | + run.project 159 | + "/" 160 | + run.sweep_id 161 | + "/" 162 | + run.name 163 | ) 164 | else: 165 | ckpt_dir = params.checkpoint_path + "/" + run.project + "/" + run.name 166 | if (RANK == 0 or RANK == -1) and not os.path.exists(ckpt_dir): 167 | os.makedirs(ckpt_dir) 168 | ls = broadcast_object_list([ckpt_dir], from_process=0) 169 | ckpt_dir = ls[0] 170 | 171 | if model_map and ( 172 | type(config["model_name"]) == str and config["model_name"] in MODEL_MAP.keys() 173 | ): 174 | config = {**config, **MODEL_MAP[config["model_name"]]} 175 | if RANK == 0 or RANK == -1: 176 | wandb.config.update(MODEL_MAP[config["model_name"]], allow_val_change=True) 177 | 178 | return run, config, ckpt_dir, RANK, CPU_CORES 179 | 180 | 181 | if __name__ == "__main__": 182 | parser = argparse.ArgumentParser(description="Train scOT or pretrain Poseidon.") 183 | parser.add_argument("--resume_training", action="store_true") 184 | parser.add_argument( 185 | "--finetune_from", 186 | type=str, 187 | default=None, 188 | help="Set this to a str pointing to a HF Hub model checkpoint or a directory with a scOT checkpoint if you want to finetune.", 189 | ) 190 | parser.add_argument( 191 | "--replace_embedding_recovery", 192 | action="store_true", 193 | help="Set this if you have to replace the embeddings and recovery layers because you are not just using the density, velocity and pressure channels. 
Only relevant for finetuning.", 194 | ) 195 | params = read_cli(parser).parse_args() 196 | run, config, ckpt_dir, RANK, CPU_CORES = setup(params) 197 | 198 | train_eval_set_kwargs = ( 199 | {"just_velocities": True} 200 | if ("incompressible" in config["dataset"]) and params.just_velocities 201 | else {} 202 | ) 203 | if params.move_data is not None: 204 | train_eval_set_kwargs["move_to_local_scratch"] = params.move_data 205 | if params.max_num_train_time_steps is not None: 206 | train_eval_set_kwargs["max_num_time_steps"] = params.max_num_train_time_steps 207 | if params.train_time_step_size is not None: 208 | train_eval_set_kwargs["time_step_size"] = params.train_time_step_size 209 | if params.train_small_time_transition: 210 | train_eval_set_kwargs["allowed_time_transitions"] = [1] 211 | train_dataset = get_dataset( 212 | dataset=config["dataset"], 213 | which="train", 214 | num_trajectories=config["num_trajectories"], 215 | data_path=params.data_path, 216 | **train_eval_set_kwargs, 217 | ) 218 | eval_dataset = get_dataset( 219 | dataset=config["dataset"], 220 | which="val", 221 | num_trajectories=config["num_trajectories"], 222 | data_path=params.data_path, 223 | **train_eval_set_kwargs, 224 | ) 225 | 226 | config["effective_train_set_size"] = len(train_dataset) 227 | time_involved = isinstance(train_dataset, BaseTimeDataset) or ( 228 | isinstance(train_dataset, torch.utils.data.ConcatDataset) 229 | and isinstance(train_dataset.datasets[0], BaseTimeDataset) 230 | ) 231 | 232 | if not isinstance(train_dataset, torch.utils.data.ConcatDataset): 233 | resolution = train_dataset.resolution 234 | input_dim = train_dataset.input_dim 235 | output_dim = train_dataset.output_dim 236 | channel_slice_list = train_dataset.channel_slice_list 237 | printable_channel_description = train_dataset.printable_channel_description 238 | else: 239 | resolution = train_dataset.datasets[0].resolution 240 | input_dim = train_dataset.datasets[0].input_dim 241 | output_dim = train_dataset.datasets[0].output_dim 242 | channel_slice_list = train_dataset.datasets[0].channel_slice_list 243 | printable_channel_description = train_dataset.datasets[ 244 | 0 245 | ].printable_channel_description 246 | 247 | model_config = ( 248 | ScOTConfig( 249 | image_size=resolution, 250 | patch_size=config["patch_size"], 251 | num_channels=input_dim, 252 | num_out_channels=output_dim, 253 | embed_dim=config["embed_dim"], 254 | depths=config["depths"], 255 | num_heads=config["num_heads"], 256 | skip_connections=config["skip_connections"], 257 | window_size=config["window_size"], 258 | mlp_ratio=config["mlp_ratio"], 259 | qkv_bias=True, 260 | hidden_dropout_prob=0.0, # default 261 | attention_probs_dropout_prob=0.0, # default 262 | drop_path_rate=0.0, 263 | hidden_act="gelu", 264 | use_absolute_embeddings=False, 265 | initializer_range=0.02, 266 | layer_norm_eps=1e-5, 267 | p=1, 268 | channel_slice_list_normalized_loss=channel_slice_list, 269 | residual_model="convnext", 270 | use_conditioning=time_involved, 271 | learn_residual=False, 272 | ) 273 | if params.finetune_from is None or params.replace_embedding_recovery 274 | else None 275 | ) 276 | 277 | train_config = TrainingArguments( 278 | output_dir=ckpt_dir, 279 | overwrite_output_dir=True, #! 
OVERWRITE THIS DIRECTORY IN CASE, also for resuming training 280 | evaluation_strategy="epoch", 281 | per_device_train_batch_size=config["batch_size"], 282 | per_device_eval_batch_size=config["batch_size"], 283 | eval_accumulation_steps=16, 284 | max_grad_norm=config["max_grad_norm"], 285 | num_train_epochs=config["num_epochs"], 286 | optim="adamw_torch", 287 | learning_rate=config["lr"], 288 | learning_rate_embedding_recovery=( 289 | None 290 | if (params.finetune_from is None or "lr_embedding_recovery" not in config) 291 | else config["lr_embedding_recovery"] 292 | ), 293 | learning_rate_time_embedding=( 294 | None 295 | if (params.finetune_from is None or "lr_time_embedding" not in config) 296 | else config["lr_time_embedding"] 297 | ), 298 | weight_decay=config["weight_decay"], 299 | adam_beta1=0.9, # default 300 | adam_beta2=0.999, # default 301 | adam_epsilon=1e-8, # default 302 | lr_scheduler_type=config["lr_scheduler"], 303 | warmup_ratio=config["warmup_ratio"], 304 | log_level="passive", 305 | logging_strategy="steps", 306 | logging_steps=5, 307 | logging_nan_inf_filter=False, 308 | save_strategy="epoch", 309 | save_total_limit=1, 310 | seed=SEED, 311 | fp16=False, 312 | dataloader_num_workers=CPU_CORES, 313 | load_best_model_at_end=True, 314 | metric_for_best_model="loss", 315 | greater_is_better=False, 316 | dataloader_pin_memory=True, 317 | gradient_checkpointing=False, 318 | auto_find_batch_size=False, 319 | full_determinism=False, 320 | torch_compile=False, 321 | report_to="wandb", 322 | run_name=params.wandb_run_name, 323 | ) 324 | 325 | early_stopping = EarlyStoppingCallback( 326 | early_stopping_patience=config["early_stopping_patience"], 327 | early_stopping_threshold=0.0, # set no threshold for now 328 | ) 329 | 330 | if params.finetune_from is not None: 331 | model = ScOT.from_pretrained( 332 | params.finetune_from, config=model_config, ignore_mismatched_sizes=True 333 | ) 334 | else: 335 | model = ScOT(model_config) 336 | num_params = get_num_parameters(model) 337 | config["num_params"] = num_params 338 | num_params_no_embed = get_num_parameters_no_embed(model) 339 | config["num_params_wout_embed"] = num_params_no_embed 340 | if RANK == 0 or RANK == -1: 341 | print(f"Model size: {num_params}") 342 | print(f"Model size without embeddings: {num_params_no_embed}") 343 | 344 | def compute_metrics(eval_preds): 345 | channel_list = channel_slice_list 346 | 347 | def get_statistics(errors): 348 | median_error = np.median(errors, axis=0) 349 | mean_error = np.mean(errors, axis=0) 350 | std_error = np.std(errors, axis=0) 351 | min_error = np.min(errors, axis=0) 352 | max_error = np.max(errors, axis=0) 353 | return { 354 | "median_relative_l1_error": median_error, 355 | "mean_relative_l1_error": mean_error, 356 | "std_relative_l1_error": std_error, 357 | "min_relative_l1_error": min_error, 358 | "max_relative_l1_error": max_error, 359 | } 360 | 361 | error_statistics = [ 362 | get_statistics( 363 | relative_lp_error( 364 | eval_preds.predictions[:, channel_list[i] : channel_list[i + 1]], 365 | eval_preds.label_ids[:, channel_list[i] : channel_list[i + 1]], 366 | p=1, 367 | return_percent=True, 368 | ) 369 | ) 370 | for i in range(len(channel_list) - 1) 371 | ] 372 | 373 | if output_dim == 1: 374 | error_statistics = error_statistics[0] 375 | return error_statistics 376 | else: 377 | mean_over_means = np.mean( 378 | np.array( 379 | [stats["mean_relative_l1_error"] for stats in error_statistics] 380 | ), 381 | axis=0, 382 | ) 383 | mean_over_medians = np.mean( 384 | np.array( 385 
| [stats["median_relative_l1_error"] for stats in error_statistics] 386 | ), 387 | axis=0, 388 | ) 389 | error_statistics_ = { 390 | "mean_relative_l1_error": mean_over_means, 391 | "mean_over_median_relative_l1_error": mean_over_medians, 392 | } 393 | for i, stats in enumerate(error_statistics): 394 | for key, value in stats.items(): 395 | error_statistics_[printable_channel_description[i] + "/" + key] = ( 396 | value 397 | ) 398 | return error_statistics_ 399 | 400 | trainer = Trainer( 401 | model=model, 402 | args=train_config, 403 | train_dataset=train_dataset, 404 | eval_dataset=eval_dataset, 405 | compute_metrics=compute_metrics, 406 | callbacks=[early_stopping], 407 | ) 408 | 409 | trainer.train(resume_from_checkpoint=params.resume_training) 410 | trainer.save_model(train_config.output_dir) 411 | 412 | if (RANK == 0 or RANK == -1) and params.push_to_hf_hub is not None: 413 | model.push_to_hub(params.push_to_hf_hub) 414 | 415 | do_test = ( 416 | True 417 | if params.max_num_train_time_steps is None 418 | and params.train_time_step_size is None 419 | and not params.train_small_time_transition 420 | and not ".time" in config["dataset"] 421 | else False 422 | ) 423 | if do_test: 424 | print("Testing...") 425 | test_set_kwargs = ( 426 | {"just_velocities": True} 427 | if ("incompressible" in config["dataset"]) and params.just_velocities 428 | else {} 429 | ) 430 | out_test_set_kwargs = ( 431 | {"just_velocities": True} 432 | if ("incompressible" in config["dataset"]) and params.just_velocities 433 | else {} 434 | ) 435 | if params.move_data is not None: 436 | test_set_kwargs["move_to_local_scratch"] = params.move_data 437 | out_test_set_kwargs["move_to_local_scratch"] = params.move_data 438 | if time_involved: 439 | test_set_kwargs = { 440 | **test_set_kwargs, 441 | "max_num_time_steps": 1, 442 | "time_step_size": 14, 443 | "allowed_time_transitions": [1], 444 | } 445 | out_test_set_kwargs = { 446 | **out_test_set_kwargs, 447 | "max_num_time_steps": 1, 448 | "time_step_size": 20, 449 | "allowed_time_transitions": [1], 450 | } 451 | if "RayleighTaylor" in config["dataset"]: 452 | test_set_kwargs = { 453 | **test_set_kwargs, 454 | "max_num_time_steps": 1, 455 | "time_step_size": 7, 456 | "allowed_time_transitions": [1], 457 | } 458 | out_test_set_kwargs = { 459 | **out_test_set_kwargs, 460 | "max_num_time_steps": 1, 461 | "time_step_size": 10, 462 | "allowed_time_transitions": [1], 463 | } 464 | 465 | test_dataset = get_dataset( 466 | dataset=config["dataset"], 467 | which="test", 468 | num_trajectories=config["num_trajectories"], 469 | data_path=params.data_path, 470 | **test_set_kwargs, 471 | ) 472 | try: 473 | out_dist_test_dataset = get_dataset( 474 | dataset=config["dataset"] + ".out", 475 | which="test", 476 | num_trajectories=config["num_trajectories"], 477 | data_path=params.data_path, 478 | **out_test_set_kwargs, 479 | ) 480 | except: 481 | out_dist_test_dataset = None 482 | predictions = trainer.predict(test_dataset, metric_key_prefix="") 483 | if RANK == 0 or RANK == -1: 484 | metrics = {} 485 | for key, value in predictions.metrics.items(): 486 | metrics["test/" + key[1:]] = value 487 | wandb.log(metrics) 488 | create_predictions_plot( 489 | predictions.predictions, 490 | predictions.label_ids, 491 | wandb_prefix="test", 492 | ) 493 | 494 | # evaluate on out-of-distribution test set 495 | if out_dist_test_dataset is not None: 496 | predictions = trainer.predict(out_dist_test_dataset, metric_key_prefix="") 497 | if RANK == 0 or RANK == -1: 498 | metrics = {} 499 | for key, value 
in predictions.metrics.items():
500 |                     metrics["test_out_dist/" + key[1:]] = value
501 |                 wandb.log(metrics)
502 |                 create_predictions_plot(
503 |                     predictions.predictions,
504 |                     predictions.label_ids,
505 |                     wandb_prefix="test_out_dist",
506 |                 )
507 |
508 |         if time_involved and (test_set_kwargs["time_step_size"] // 2 > 0):
509 |             trainer.set_ar_steps(test_set_kwargs["time_step_size"] // 2)
510 |             predictions = trainer.predict(test_dataset, metric_key_prefix="")
511 |             if RANK == 0 or RANK == -1:
512 |                 metrics = {}
513 |                 for key, value in predictions.metrics.items():
514 |                     metrics["test/ar/" + key[1:]] = value
515 |                 wandb.log(metrics)
516 |                 create_predictions_plot(
517 |                     predictions.predictions,
518 |                     predictions.label_ids,
519 |                     wandb_prefix="test/ar",
520 |                 )
521 |
522 |             # evaluate on out-of-distribution test set
523 |             if out_dist_test_dataset is not None:
524 |                 trainer.set_ar_steps(out_test_set_kwargs["time_step_size"] // 2)
525 |                 predictions = trainer.predict(
526 |                     out_dist_test_dataset, metric_key_prefix=""
527 |                 )
528 |                 if RANK == 0 or RANK == -1:
529 |                     metrics = {}
530 |                     for key, value in predictions.metrics.items():
531 |                         metrics["test_out_dist/ar/" + key[1:]] = value
532 |                     wandb.log(metrics)
533 |                     create_predictions_plot(
534 |                         predictions.predictions,
535 |                         predictions.label_ids,
536 |                         wandb_prefix="test_out_dist/ar",
537 |                     )
538 |
--------------------------------------------------------------------------------
/scOT/trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Our version of the Huggingface Trainer class.
3 | It adds learning_rate_time_embedding, learning_rate_embedding_recovery as
4 | additional learning rates and groups parameters for the optimizer.
5 | It also allows for autoregressive rollouts by using
6 | trainer.set_ar_steps(AR_STEPS) where AR_STEPS is either an integer for a
7 | homogeneous rollout of AR_STEPS steps or a list of integers for a heterogeneous
8 | rollout where each element is the timestep.
9 | If, additionally, output_all_steps is set, the predict function will
10 | output all intermediate steps as well.
11 |
12 | We subclass a Huggingface Trainer to allow for autoregressive rollouts and multiple parameter groups in the optimizer.
13 | It is specifically subclassed for our purpose.
14 |
15 | A lot of code is copied over because only slight changes have been made.
16 |
17 | The original code of Huggingface Transformers is distributed under the Apache 2.0 license. See below:
18 |
19 | Copyright 2018- The Hugging Face team. All rights reserved.
20 |
21 |                                  Apache License
22 |                            Version 2.0, January 2004
23 |                         http://www.apache.org/licenses/
24 |
25 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
26 |
27 |    1. Definitions.
28 |
29 |       "License" shall mean the terms and conditions for use, reproduction,
30 |       and distribution as defined by Sections 1 through 9 of this document.
31 |
32 |       "Licensor" shall mean the copyright owner or entity authorized by
33 |       the copyright owner that is granting the License.
34 |
35 |       "Legal Entity" shall mean the union of the acting entity and all
36 |       other entities that control, are controlled by, or are under common
37 |       control with that entity. For the purposes of this definition,
38 |       "control" means (i) the power, direct or indirect, to cause the
39 |       direction or management of such entity, whether by contract or
40 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
41 |       outstanding shares, or (iii) beneficial ownership of such entity.
42 | 43 | "You" (or "Your") shall mean an individual or Legal Entity 44 | exercising permissions granted by this License. 45 | 46 | "Source" form shall mean the preferred form for making modifications, 47 | including but not limited to software source code, documentation 48 | source, and configuration files. 49 | 50 | "Object" form shall mean any form resulting from mechanical 51 | transformation or translation of a Source form, including but 52 | not limited to compiled object code, generated documentation, 53 | and conversions to other media types. 54 | 55 | "Work" shall mean the work of authorship, whether in Source or 56 | Object form, made available under the License, as indicated by a 57 | copyright notice that is included in or attached to the work 58 | (an example is provided in the Appendix below). 59 | 60 | "Derivative Works" shall mean any work, whether in Source or Object 61 | form, that is based on (or derived from) the Work and for which the 62 | editorial revisions, annotations, elaborations, or other modifications 63 | represent, as a whole, an original work of authorship. For the purposes 64 | of this License, Derivative Works shall not include works that remain 65 | separable from, or merely link (or bind by name) to the interfaces of, 66 | the Work and Derivative Works thereof. 67 | 68 | "Contribution" shall mean any work of authorship, including 69 | the original version of the Work and any modifications or additions 70 | to that Work or Derivative Works thereof, that is intentionally 71 | submitted to Licensor for inclusion in the Work by the copyright owner 72 | or by an individual or Legal Entity authorized to submit on behalf of 73 | the copyright owner. For the purposes of this definition, "submitted" 74 | means any form of electronic, verbal, or written communication sent 75 | to the Licensor or its representatives, including but not limited to 76 | communication on electronic mailing lists, source code control systems, 77 | and issue tracking systems that are managed by, or on behalf of, the 78 | Licensor for the purpose of discussing and improving the Work, but 79 | excluding communication that is conspicuously marked or otherwise 80 | designated in writing by the copyright owner as "Not a Contribution." 81 | 82 | "Contributor" shall mean Licensor and any individual or Legal Entity 83 | on behalf of whom a Contribution has been received by Licensor and 84 | subsequently incorporated within the Work. 85 | 86 | 2. Grant of Copyright License. Subject to the terms and conditions of 87 | this License, each Contributor hereby grants to You a perpetual, 88 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 89 | copyright license to reproduce, prepare Derivative Works of, 90 | publicly display, publicly perform, sublicense, and distribute the 91 | Work and such Derivative Works in Source or Object form. 92 | 93 | 3. Grant of Patent License. Subject to the terms and conditions of 94 | this License, each Contributor hereby grants to You a perpetual, 95 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 96 | (except as stated in this section) patent license to make, have made, 97 | use, offer to sell, sell, import, and otherwise transfer the Work, 98 | where such license applies only to those patent claims licensable 99 | by such Contributor that are necessarily infringed by their 100 | Contribution(s) alone or by combination of their Contribution(s) 101 | with the Work to which such Contribution(s) was submitted. 
If You 102 | institute patent litigation against any entity (including a 103 | cross-claim or counterclaim in a lawsuit) alleging that the Work 104 | or a Contribution incorporated within the Work constitutes direct 105 | or contributory patent infringement, then any patent licenses 106 | granted to You under this License for that Work shall terminate 107 | as of the date such litigation is filed. 108 | 109 | 4. Redistribution. You may reproduce and distribute copies of the 110 | Work or Derivative Works thereof in any medium, with or without 111 | modifications, and in Source or Object form, provided that You 112 | meet the following conditions: 113 | 114 | (a) You must give any other recipients of the Work or 115 | Derivative Works a copy of this License; and 116 | 117 | (b) You must cause any modified files to carry prominent notices 118 | stating that You changed the files; and 119 | 120 | (c) You must retain, in the Source form of any Derivative Works 121 | that You distribute, all copyright, patent, trademark, and 122 | attribution notices from the Source form of the Work, 123 | excluding those notices that do not pertain to any part of 124 | the Derivative Works; and 125 | 126 | (d) If the Work includes a "NOTICE" text file as part of its 127 | distribution, then any Derivative Works that You distribute must 128 | include a readable copy of the attribution notices contained 129 | within such NOTICE file, excluding those notices that do not 130 | pertain to any part of the Derivative Works, in at least one 131 | of the following places: within a NOTICE text file distributed 132 | as part of the Derivative Works; within the Source form or 133 | documentation, if provided along with the Derivative Works; or, 134 | within a display generated by the Derivative Works, if and 135 | wherever such third-party notices normally appear. The contents 136 | of the NOTICE file are for informational purposes only and 137 | do not modify the License. You may add Your own attribution 138 | notices within Derivative Works that You distribute, alongside 139 | or as an addendum to the NOTICE text from the Work, provided 140 | that such additional attribution notices cannot be construed 141 | as modifying the License. 142 | 143 | You may add Your own copyright statement to Your modifications and 144 | may provide additional or different license terms and conditions 145 | for use, reproduction, or distribution of Your modifications, or 146 | for any such Derivative Works as a whole, provided Your use, 147 | reproduction, and distribution of the Work otherwise complies with 148 | the conditions stated in this License. 149 | 150 | 5. Submission of Contributions. Unless You explicitly state otherwise, 151 | any Contribution intentionally submitted for inclusion in the Work 152 | by You to the Licensor shall be under the terms and conditions of 153 | this License, without any additional terms or conditions. 154 | Notwithstanding the above, nothing herein shall supersede or modify 155 | the terms of any separate license agreement you may have executed 156 | with Licensor regarding such Contributions. 157 | 158 | 6. Trademarks. This License does not grant permission to use the trade 159 | names, trademarks, service marks, or product names of the Licensor, 160 | except as required for reasonable and customary use in describing the 161 | origin of the Work and reproducing the content of the NOTICE file. 162 | 163 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 164 | agreed to in writing, Licensor provides the Work (and each 165 | Contributor provides its Contributions) on an "AS IS" BASIS, 166 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 167 | implied, including, without limitation, any warranties or conditions 168 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 169 | PARTICULAR PURPOSE. You are solely responsible for determining the 170 | appropriateness of using or redistributing the Work and assume any 171 | risks associated with Your exercise of permissions under this License. 172 | 173 | 8. Limitation of Liability. In no event and under no legal theory, 174 | whether in tort (including negligence), contract, or otherwise, 175 | unless required by applicable law (such as deliberate and grossly 176 | negligent acts) or agreed to in writing, shall any Contributor be 177 | liable to You for damages, including any direct, indirect, special, 178 | incidental, or consequential damages of any character arising as a 179 | result of this License or out of the use or inability to use the 180 | Work (including but not limited to damages for loss of goodwill, 181 | work stoppage, computer failure or malfunction, or any and all 182 | other commercial damages or losses), even if such Contributor 183 | has been advised of the possibility of such damages. 184 | 185 | 9. Accepting Warranty or Additional Liability. While redistributing 186 | the Work or Derivative Works thereof, You may choose to offer, 187 | and charge a fee for, acceptance of support, warranty, indemnity, 188 | or other liability obligations and/or rights consistent with this 189 | License. However, in accepting such obligations, You may act only 190 | on Your own behalf and on Your sole responsibility, not on behalf 191 | of any other Contributor, and only if You agree to indemnify, 192 | defend, and hold each Contributor harmless for any liability 193 | incurred by, or claims asserted against, such Contributor by reason 194 | of your accepting any such warranty or additional liability. 195 | 196 | END OF TERMS AND CONDITIONS 197 | 198 | APPENDIX: How to apply the Apache License to your work. 199 | 200 | To apply the Apache License to your work, attach the following 201 | boilerplate notice, with the fields enclosed by brackets "[]" 202 | replaced with your own identifying information. (Don't include 203 | the brackets!) The text should be enclosed in the appropriate 204 | comment syntax for the file format. We also recommend that a 205 | file or class name and description of purpose be included on the 206 | same "printed page" as the copyright notice for easier 207 | identification within third-party archives. 208 | 209 | Copyright [yyyy] [name of copyright owner] 210 | 211 | Licensed under the Apache License, Version 2.0 (the "License"); 212 | you may not use this file except in compliance with the License. 213 | You may obtain a copy of the License at 214 | 215 | http://www.apache.org/licenses/LICENSE-2.0 216 | 217 | Unless required by applicable law or agreed to in writing, software 218 | distributed under the License is distributed on an "AS IS" BASIS, 219 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 220 | See the License for the specific language governing permissions and 221 | limitations under the License. 
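
Example usage (an illustrative sketch, not code shipped in this repository;
the output directory, learning rates, model, and datasets below are
placeholder values):

    args = TrainingArguments(
        output_dir="./ckpts",
        learning_rate=1e-4,
        learning_rate_embedding_recovery=1e-3,
        learning_rate_time_embedding=1e-3,
    )
    trainer = Trainer(model=model, args=args, train_dataset=train_set)
    trainer.set_ar_steps(4)  # homogeneous rollout: 4 equal sub-steps
    predictions = trainer.predict(test_set)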
222 | """ 223 | 224 | import torch 225 | from torch import nn 226 | from typing import List, Optional, Dict, Tuple, Union, Any 227 | from transformers.trainer import * 228 | from transformers import Trainer as Trainer_ 229 | from transformers import TrainingArguments as TrainingArguments_ 230 | from scOT.model import LayerNorm, ConditionalLayerNorm 231 | from dataclasses import dataclass, field 232 | 233 | 234 | @dataclass 235 | class TrainingArguments(TrainingArguments_): 236 | learning_rate_embedding_recovery: Optional[float] = field( 237 | default=None, 238 | metadata={ 239 | "help": "The initial learning rate for the embedding/recovery. When not provided, falls back to `learning_rate`." 240 | }, 241 | ) 242 | 243 | learning_rate_time_embedding: Optional[float] = field( 244 | default=None, 245 | metadata={ 246 | "help": "The initial learning rate for the time embedding. When not provided, falls back to `learning_rate`. Only used when embedding and recovery are also fine-tuned with different lr." 247 | }, 248 | ) 249 | 250 | def set_training( 251 | self, 252 | *args, 253 | learning_rate_embedding_recovery: Optional[float] = None, 254 | learning_rate_time_embedding: Optional[float] = None, 255 | **kwargs, 256 | ): 257 | self = super().set_training(*args, **kwargs) 258 | self.learning_rate_embedding_recovery = learning_rate_embedding_recovery 259 | self.learning_rate_time_embedding = learning_rate_time_embedding 260 | return self 261 | 262 | def set_optimizer( 263 | self, 264 | *args, 265 | learning_rate_embedding_recovery: Optional[float] = None, 266 | learning_rate_time_embedding: Optional[float] = None, 267 | **kwargs, 268 | ): 269 | self = super().set_optimizer(*args, **kwargs) 270 | self.learning_rate_embedding_recovery = learning_rate_embedding_recovery 271 | self.learning_rate_time_embedding = learning_rate_time_embedding 272 | return self 273 | 274 | 275 | class Trainer(Trainer_): 276 | def __init__(self, *args, **kwargs): 277 | super().__init__(*args, **kwargs) 278 | self.ar_steps = None 279 | self.output_all_steps = False 280 | 281 | def get_decay_parameter_names(self, model) -> List[str]: 282 | ALL_LAYERNORM_LAYERS = [torch.nn.LayerNorm, LayerNorm, ConditionalLayerNorm] 283 | decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS) 284 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 285 | return decay_parameters 286 | 287 | def get_conditional_norm_params(self, model): 288 | params = [] 289 | for name, module in model.named_modules(): 290 | if isinstance(module, ConditionalLayerNorm): 291 | for param_name, _ in module.named_parameters(): 292 | params.append(f"{name}.{param_name}") 293 | return params 294 | 295 | def create_optimizer(self): 296 | """This is the same as in the standard trainer, except param groups""" 297 | opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model 298 | if self.optimizer is None: 299 | decay_parameters = self.get_decay_parameter_names(self.model) 300 | if self.args.learning_rate_embedding_recovery is not None: 301 | if self.args.learning_rate_time_embedding is not None: 302 | time_embedding_params = self.get_conditional_norm_params(self.model) 303 | params = { 304 | "standard": [], 305 | "no_weight_decay": [], 306 | "embeddings": [], 307 | "time_embedding": [], 308 | } 309 | for n, p in opt_model.named_parameters(): 310 | if ( 311 | "embeddings" in n or "patch_recovery" in n 312 | ) and p.requires_grad: 313 | params["embeddings"].append(p) 314 | elif n in decay_parameters and 
p.requires_grad: 315 | params["standard"].append(p) 316 | elif p.requires_grad: 317 | if n in time_embedding_params: 318 | params["time_embedding"].append(p) 319 | else: 320 | params["no_weight_decay"].append(p) 321 | optimizer_grouped_parameters = [ 322 | { 323 | "params": params["standard"], 324 | "weight_decay": self.args.weight_decay, 325 | }, 326 | { 327 | "params": params["no_weight_decay"], 328 | "weight_decay": 0.0, 329 | }, 330 | { 331 | "params": params["embeddings"], 332 | "lr": self.args.learning_rate_embedding_recovery, 333 | "weight_decay": self.args.weight_decay, 334 | }, 335 | { 336 | "params": params["time_embedding"], 337 | "lr": self.args.learning_rate_time_embedding, 338 | "weight_decay": 0.0, 339 | }, 340 | ] 341 | else: 342 | params = {"standard": [], "no_weight_decay": [], "embeddings": []} 343 | for n, p in opt_model.named_parameters(): 344 | if ( 345 | "embeddings" in n or "patch_recovery" in n 346 | ) and p.requires_grad: 347 | params["embeddings"].append(p) 348 | elif n in decay_parameters and p.requires_grad: 349 | params["standard"].append(p) 350 | elif p.requires_grad: 351 | params["no_weight_decay"].append(p) 352 | optimizer_grouped_parameters = [ 353 | { 354 | "params": params["standard"], 355 | "weight_decay": self.args.weight_decay, 356 | }, 357 | { 358 | "params": params["no_weight_decay"], 359 | "weight_decay": 0.0, 360 | }, 361 | { 362 | "params": params["embeddings"], 363 | "lr": self.args.learning_rate_embedding_recovery, 364 | "weight_decay": self.args.weight_decay, 365 | }, 366 | ] 367 | elif self.args.learning_rate_time_embedding is not None: 368 | time_embedding_params = self.get_conditional_norm_params(self.model) 369 | params = {"standard": [], "no_weight_decay": [], "time_embedding": []} 370 | for n, p in opt_model.named_parameters(): 371 | if n in decay_parameters and p.requires_grad: 372 | params["standard"].append(p) 373 | elif p.requires_grad: 374 | if n in time_embedding_params: 375 | params["time_embedding"].append(p) 376 | else: 377 | params["no_weight_decay"].append(p) 378 | optimizer_grouped_parameters = [ 379 | { 380 | "params": params["standard"], 381 | "weight_decay": self.args.weight_decay, 382 | }, 383 | { 384 | "params": params["no_weight_decay"], 385 | "weight_decay": 0.0, 386 | }, 387 | { 388 | "params": params["time_embedding"], 389 | "lr": self.args.learning_rate_time_embedding, 390 | "weight_decay": 0.0, 391 | }, 392 | ] 393 | else: 394 | optimizer_grouped_parameters = [ 395 | { 396 | "params": [ 397 | p 398 | for n, p in opt_model.named_parameters() 399 | if (n in decay_parameters and p.requires_grad) 400 | ], 401 | "weight_decay": self.args.weight_decay, 402 | }, 403 | { 404 | "params": [ 405 | p 406 | for n, p in opt_model.named_parameters() 407 | if (n not in decay_parameters and p.requires_grad) 408 | ], 409 | "weight_decay": 0.0, 410 | }, 411 | ] 412 | 413 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( 414 | self.args 415 | ) 416 | 417 | self.optimizer = optimizer_cls( 418 | optimizer_grouped_parameters, **optimizer_kwargs 419 | ) 420 | if optimizer_cls.__name__ == "Adam8bit": 421 | import bitsandbytes 422 | 423 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance() 424 | 425 | skipped = 0 426 | for module in opt_model.modules(): 427 | if isinstance(module, nn.Embedding): 428 | skipped += sum( 429 | { 430 | p.data_ptr(): p.numel() for p in module.parameters() 431 | }.values() 432 | ) 433 | print(f"skipped {module}: {skipped/2**20}M params") 434 | manager.register_module_override( 
435 |                             module, "weight", {"optim_bits": 32}
436 |                         )
437 |                         logger.debug(
438 |                             f"bitsandbytes: will optimize {module} in fp32"
439 |                         )
440 |                 print(f"skipped: {skipped/2**20}M params")
441 | 
442 |         if is_sagemaker_mp_enabled():
443 |             self.optimizer = smp.DistributedOptimizer(self.optimizer)
444 | 
445 |         return self.optimizer
446 | 
447 |     def set_ar_steps(self, ar_steps=None, output_all_steps=False):
448 |         self.ar_steps = ar_steps
449 |         # only output intermediate steps when an autoregressive rollout is set
450 |         self.output_all_steps = output_all_steps and ar_steps is not None
451 | 
452 |     def _model_forward(self, model, inputs):
453 |         if self.ar_steps is not None and model.config.use_conditioning:
454 |             channel_difference = (
455 |                 model.config.num_channels > model.config.num_out_channels
456 |             )
457 |             # TODO: if outputs is not a dataclass this will break
458 |             if isinstance(self.ar_steps, int):
459 |                 inputs = {**inputs, **{"time": inputs["time"] / self.ar_steps}}
460 |                 if self.output_all_steps:
461 |                     loss_ = []
462 |                     outputs_ = []
463 |                     hidden_states_ = []
464 |                     attentions_ = []
465 |                     reshaped_hidden_states_ = []
466 |                 else:
467 |                     loss = 0
468 |                 for i in range(self.ar_steps):
469 |                     outputs = model(**inputs)
470 |                     if self.output_all_steps:
471 |                         outputs_.append(outputs.output.detach())
472 |                         if outputs.hidden_states is not None:
473 |                             hidden_states_.append(outputs.hidden_states)
474 |                         if outputs.attentions is not None:
475 |                             attentions_.append(outputs.attentions)
476 |                         if outputs.reshaped_hidden_states is not None:
477 |                             reshaped_hidden_states_.append(
478 |                                 outputs.reshaped_hidden_states
479 |                             )
480 |                         if outputs.loss is not None:
481 |                             loss_.append(outputs.loss)
482 |                     else:
483 |                         if outputs.loss is not None:
484 |                             loss += outputs.loss
485 |                     inputs = {
486 |                         **inputs,
487 |                         **{
488 |                             "pixel_values": (
489 |                                 outputs.output.detach()
490 |                                 if not channel_difference
491 |                                 else torch.cat(
492 |                                     [
493 |                                         outputs.output.detach(),
494 |                                         inputs["pixel_values"][
495 |                                             :,
496 |                                             model.config.num_out_channels :,
497 |                                         ],
498 |                                     ],
499 |                                     dim=1,
500 |                                 )
501 |                             )
502 |                         },
503 |                     }
504 |                 if self.output_all_steps:
505 |                     outputs.output = torch.stack(outputs_, dim=1)
506 |                     if len(loss_) > 0:
507 |                         outputs.loss = torch.stack(loss_, dim=0)
508 |                     if len(hidden_states_) > 0:
509 |                         outputs.hidden_states = [
510 |                             torch.stack(hs, dim=1) for hs in zip(*hidden_states_)
511 |                         ]
512 |                     if len(attentions_) > 0:
513 |                         outputs.attentions = [
514 |                             torch.stack(att, dim=1) for att in zip(*attentions_)
515 |                         ]
516 |                     if len(reshaped_hidden_states_) > 0:
517 |                         outputs.reshaped_hidden_states = [
518 |                             torch.stack(rhs, dim=1)
519 |                             for rhs in zip(*reshaped_hidden_states_)
520 |                         ]
521 |                 else:
522 |                     loss /= self.ar_steps
523 |                     outputs.loss = loss
524 |             elif isinstance(self.ar_steps, list):
525 |                 if self.output_all_steps:
526 |                     loss_ = []
527 |                     outputs_ = []
528 |                     hidden_states_ = []
529 |                     attentions_ = []
530 |                     reshaped_hidden_states_ = []
531 |                 else:
532 |                     loss = 0
533 |                 lead_time = inputs["time"]
534 |                 for i in self.ar_steps:
535 |                     inputs = {
536 |                         **inputs,
537 |                         **{"time": lead_time * i},
538 |                     }
539 |                     outputs = model(**inputs)
540 |                     # record this step's prediction and auxiliary outputs
541 |                     # exactly once per autoregressive step
542 |                     if self.output_all_steps:
543 |                         outputs_.append(outputs.output.detach())
544 |                         if outputs.hidden_states is not None:
545 |                             hidden_states_.append(outputs.hidden_states)
546 |                         if outputs.attentions is not None:
547 |                             attentions_.append(outputs.attentions)
548 |                         if outputs.reshaped_hidden_states is not None:
549 |                             reshaped_hidden_states_.append(
550 |                                 outputs.reshaped_hidden_states
551 |                             )
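                        # per-step scalar losses are collected here so they can
                        # be stacked into a 1-D tensor once the rollout finishes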
552 |                         if outputs.loss is not None:
553 |                             loss_.append(outputs.loss)
554 |                     else:
555 |                         if outputs.loss is not None:
556 |                             loss += outputs.loss
557 |                     inputs = {
558 |                         **inputs,
559 |                         **{
560 |                             "pixel_values": (
561 |                                 outputs.output.detach()
562 |                                 if not channel_difference
563 |                                 else torch.cat(
564 |                                     [
565 |                                         outputs.output.detach(),
566 |                                         inputs["pixel_values"][
567 |                                             :,
568 |                                             model.config.num_out_channels :,
569 |                                         ],
570 |                                     ],
571 |                                     dim=1,
572 |                                 )
573 |                             )
574 |                         },
575 |                     }
576 |                 if self.output_all_steps:
577 |                     outputs.output = torch.stack(outputs_, dim=1)
578 |                     if len(loss_) > 0:
579 |                         outputs.loss = torch.stack(loss_, dim=0)
580 |                     if len(hidden_states_) > 0:
581 |                         outputs.hidden_states = [
582 |                             torch.stack(hs, dim=1) for hs in zip(*hidden_states_)
583 |                         ]
584 |                     if len(attentions_) > 0:
585 |                         outputs.attentions = [
586 |                             torch.stack(att, dim=1) for att in zip(*attentions_)
587 |                         ]
588 |                     if len(reshaped_hidden_states_) > 0:
589 |                         outputs.reshaped_hidden_states = [
590 |                             torch.stack(rhs, dim=1)
591 |                             for rhs in zip(*reshaped_hidden_states_)
592 |                         ]
593 |                 else:
594 |                     loss /= len(self.ar_steps)
595 |                     outputs.loss = loss
596 |             else:
597 |                 raise ValueError(
598 |                     "ar_steps must be an integer or a list of integers."
599 |                 )
600 |         else:
601 |             outputs = model(**inputs)
602 | 
603 |         return outputs
604 | 
605 |     def compute_loss(self, model, inputs, return_outputs=False):
606 |         if self.label_smoother is not None and "labels" in inputs:
607 |             labels = inputs.pop("labels")
608 |         else:
609 |             labels = None
610 |         outputs = self._model_forward(model, inputs)
611 |         # Save past state if it exists
612 |         # TODO: this needs to be fixed and made cleaner later.
613 |         if self.args.past_index >= 0:
614 |             self._past = outputs[self.args.past_index]
615 | 
616 |         if labels is not None:
617 |             unwrapped_model = unwrap_model(model)
618 |             if _is_peft_model(unwrapped_model):
619 |                 model_name = unwrapped_model.base_model.model._get_name()
620 |             else:
621 |                 model_name = unwrapped_model._get_name()
622 |             if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
623 |                 loss = self.label_smoother(outputs, labels, shift_labels=True)
624 |             else:
625 |                 loss = self.label_smoother(outputs, labels)
626 |         else:
627 |             if isinstance(outputs, dict) and "loss" not in outputs:
628 |                 raise ValueError(
629 |                     "The model did not return a loss from the inputs, only the following keys: "
630 |                     f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
631 |                 )
632 |             # We don't use .loss here since the model may return tuples instead of ModelOutput.
633 |             loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
634 | 
635 |         return (loss, outputs) if return_outputs else loss
636 | 
637 |     def prediction_step(
638 |         self,
639 |         model: nn.Module,
640 |         inputs: Dict[str, Union[torch.Tensor, Any]],
641 |         prediction_loss_only: bool,
642 |         ignore_keys: Optional[List[str]] = None,
643 |     ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
644 |         """
645 |         Perform an evaluation step on `model` using `inputs`.
646 | 
647 |         Subclass and override to inject custom behavior.
648 | 
649 |         Args:
650 |             model (`nn.Module`):
651 |                 The model to evaluate.
652 |             inputs (`Dict[str, Union[torch.Tensor, Any]]`):
653 |                 The inputs and targets of the model.
654 | 
655 |                 The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
656 |                 argument `labels`. Check your model's documentation for all accepted arguments.
657 | prediction_loss_only (`bool`): 658 | Whether or not to return the loss only. 659 | ignore_keys (`List[str]`, *optional*): 660 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 661 | gathering predictions. 662 | 663 | Return: 664 | Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, 665 | logits and labels (each being optional). 666 | """ 667 | has_labels = ( 668 | False 669 | if len(self.label_names) == 0 670 | else all(inputs.get(k) is not None for k in self.label_names) 671 | ) 672 | # For CLIP-like models capable of returning loss values. 673 | # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` 674 | # is `True` in `model.forward`. 675 | return_loss = inputs.get("return_loss", None) 676 | if return_loss is None: 677 | return_loss = self.can_return_loss 678 | loss_without_labels = ( 679 | True if len(self.label_names) == 0 and return_loss else False 680 | ) 681 | 682 | inputs = self._prepare_inputs(inputs) 683 | if ignore_keys is None: 684 | if hasattr(self.model, "config"): 685 | ignore_keys = getattr( 686 | self.model.config, "keys_to_ignore_at_inference", [] 687 | ) 688 | else: 689 | ignore_keys = [] 690 | 691 | # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. 692 | if has_labels or loss_without_labels: 693 | labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) 694 | if len(labels) == 1: 695 | labels = labels[0] 696 | else: 697 | labels = None 698 | 699 | with torch.no_grad(): 700 | if is_sagemaker_mp_enabled(): 701 | raw_outputs = smp_forward_only(model, inputs) 702 | if has_labels or loss_without_labels: 703 | if isinstance(raw_outputs, dict): 704 | loss_mb = raw_outputs["loss"] 705 | logits_mb = tuple( 706 | v 707 | for k, v in raw_outputs.items() 708 | if k not in ignore_keys + ["loss"] 709 | ) 710 | else: 711 | loss_mb = raw_outputs[0] 712 | logits_mb = raw_outputs[1:] 713 | 714 | loss = loss_mb.reduce_mean().detach().cpu() 715 | logits = smp_nested_concat(logits_mb) 716 | else: 717 | loss = None 718 | if isinstance(raw_outputs, dict): 719 | logits_mb = tuple( 720 | v for k, v in raw_outputs.items() if k not in ignore_keys 721 | ) 722 | else: 723 | logits_mb = raw_outputs 724 | logits = smp_nested_concat(logits_mb) 725 | else: 726 | if has_labels or loss_without_labels: 727 | with self.compute_loss_context_manager(): 728 | loss, outputs = self.compute_loss( 729 | model, inputs, return_outputs=True 730 | ) 731 | loss = loss.mean().detach() 732 | 733 | if isinstance(outputs, dict): 734 | logits = tuple( 735 | v 736 | for k, v in outputs.items() 737 | if k not in ignore_keys + ["loss"] 738 | ) 739 | else: 740 | logits = outputs[1:] 741 | else: 742 | loss = None 743 | with self.compute_loss_context_manager(): 744 | outputs = self._model_forward(model, inputs) 745 | if isinstance(outputs, dict): 746 | logits = tuple( 747 | v for k, v in outputs.items() if k not in ignore_keys 748 | ) 749 | else: 750 | logits = outputs 751 | # TODO: this needs to be fixed and made cleaner later. 
752 | if self.args.past_index >= 0: 753 | self._past = outputs[self.args.past_index - 1] 754 | 755 | if prediction_loss_only: 756 | return (loss, None, None) 757 | 758 | logits = nested_detach(logits) 759 | if len(logits) == 1: 760 | logits = logits[0] 761 | 762 | return (loss, logits, labels) 763 | -------------------------------------------------------------------------------- /scOT/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions.""" 2 | 3 | 4 | def read_cli(parser): 5 | """Reads command line arguments.""" 6 | 7 | parser.add_argument( 8 | "--config", 9 | type=str, 10 | required=True, 11 | help="Path to config file or JSON string", 12 | ) 13 | parser.add_argument( 14 | "--json_config", 15 | action="store_true", 16 | help="Whether the config is a JSON string", 17 | ) 18 | parser.add_argument( 19 | "--wandb_run_name", 20 | type=str, 21 | required=False, 22 | default=None, 23 | help="Name of the run in wandb", 24 | ) 25 | parser.add_argument( 26 | "--wandb_project_name", 27 | type=str, 28 | default="scOT", 29 | help="Name of the wandb project", 30 | ) 31 | parser.add_argument( 32 | "--max_num_train_time_steps", 33 | type=int, 34 | default=None, 35 | help="Maximum number of time steps to use for training and validation.", 36 | ) 37 | parser.add_argument( 38 | "--train_time_step_size", 39 | type=int, 40 | default=None, 41 | help="Time step size to use for training and validation.", 42 | ) 43 | parser.add_argument( 44 | "--train_small_time_transition", 45 | action="store_true", 46 | help="Whether to train only for next step prediction.", 47 | ) 48 | parser.add_argument( 49 | "--data_path", 50 | type=str, 51 | required=True, 52 | help="Base path to data.", 53 | ) 54 | parser.add_argument( 55 | "--checkpoint_path", 56 | type=str, 57 | required=True, 58 | help="Path to checkpoint directory. Will be prepended by wandb project and run name.", 59 | ) 60 | parser.add_argument( 61 | "--disable_tqdm", 62 | action="store_true", 63 | help="Whether to disable tqdm progress bar", 64 | ) 65 | parser.add_argument( 66 | "--push_to_hf_hub", 67 | type=str, 68 | default=None, 69 | help="Whether to push the model to Huggingface Hub. Specify the model repository name.", 70 | ) 71 | parser.add_argument( 72 | "--just_velocities", 73 | action="store_true", 74 | help="Whether to only use velocities as input. Only relevant for incompressible flow datasets.", 75 | ) 76 | parser.add_argument( 77 | "--move_data", 78 | type=str, 79 | default=None, 80 | help="If set, moves the data to this directory and trains from there.", 81 | ) 82 | return parser 83 | 84 | 85 | def get_num_parameters(model): 86 | """Returns the number of trainable parameters in a model.""" 87 | 88 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 89 | 90 | 91 | def get_num_parameters_no_embed(model): 92 | """Returns the number of trainable parameters in a scOT model without embedding and recovery.""" 93 | out = 0 94 | for name, p in model.named_parameters(): 95 | if not ("embeddings" in name or "patch_recovery" in name) and p.requires_grad: 96 | out += p.numel() 97 | return out 98 | --------------------------------------------------------------------------------